AIFS
AIFS((model): AnemoiModelInterface((pre_processors): Processors [forward](ModuleDict((normalizer): InputNormalizer()))(post_processors): Processors [inverse](ModuleDict((normalizer): InputNormalizer()))(model): AnemoiModelEncProcDec((node_attributes): NamedNodesAttributes((trainable_tensors): ModuleDict((data): TrainableTensor()(hidden): TrainableTensor()))(encoder): GraphTransformerForwardMapper((trainable): TrainableTensor()(proc): GraphTransformerMapperBlock((lin_key): Linear(in_features=1024, out_features=1024, bias=True)(lin_query): Linear(in_features=1024, out_features=1024, bias=True)(lin_value): Linear(in_features=1024, out_features=1024, bias=True)(lin_self): Linear(in_features=1024, out_features=1024, bias=True)(lin_edge): Linear(in_features=11, out_features=1024, bias=True)(conv): GraphTransformerConv()(projection): Linear(in_features=1024, out_features=1024, bias=True)(layer_norm_attention): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(layer_norm_mlp_dst): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(node_dst_mlp): Sequential((0): Linear(in_features=1024, out_features=4096, bias=True)(1): GELU(approximate='none')(2): Linear(in_features=4096, out_features=1024, bias=True))(layer_norm_attention_src): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(layer_norm_attention_dest): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(layer_norm_mlp_src): Identity()(node_src_mlp): Identity())(emb_nodes_dst): Linear(in_features=12, out_features=1024, bias=True)(emb_nodes_src): Linear(in_features=218, out_features=1024, bias=True))(processor): TransformerProcessor((proc): ModuleList((0-1): 2 x TransformerProcessorChunk((blocks): ModuleList((0-7): 8 x TransformerProcessorBlock((layer_norm_attention): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(layer_norm_mlp): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(attention): MultiHeadSelfAttention((attention): FlashAttentionWrapper()(lin_qkv): Linear(in_features=1024, 
out_features=3072, bias=False)(projection): Linear(in_features=1024, out_features=1024, bias=True))(mlp): Sequential((0): Linear(in_features=1024, out_features=4096, bias=True)(1): GELU(approximate='none')(2): Linear(in_features=4096, out_features=1024, bias=True)))))))(decoder): GraphTransformerBackwardMapper((trainable): TrainableTensor()(proc): GraphTransformerMapperBlock((lin_key): Linear(in_features=1024, out_features=1024, bias=True)(lin_query): Linear(in_features=1024, out_features=1024, bias=True)(lin_value): Linear(in_features=1024, out_features=1024, bias=True)(lin_self): Linear(in_features=1024, out_features=1024, bias=True)(lin_edge): Linear(in_features=11, out_features=1024, bias=True)(conv): GraphTransformerConv()(projection): Linear(in_features=1024, out_features=1024, bias=True)(layer_norm_attention): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(layer_norm_mlp_dst): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(node_dst_mlp): Sequential((0): Linear(in_features=1024, out_features=4096, bias=True)(1): GELU(approximate='none')(2): Linear(in_features=4096, out_features=1024, bias=True))(layer_norm_attention_src): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(layer_norm_attention_dest): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(layer_norm_mlp_src): Identity()(node_src_mlp): Identity())(emb_nodes_dst): Linear(in_features=218, out_features=1024, bias=True)(node_data_extractor): Sequential((0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)(1): Linear(in_features=1024, out_features=102, bias=True)))(boundings): ModuleList((0): ReluBounding()(1): HardtanhBounding()(2-3): 2 x FractionBounding()))) )
model.VARIABLES = ['u10m', 'v10m', 'd2m', 't2m', 'cos_julian_day', 'cos_latitude', 'cos_local_time', 'cos_longitude', 'cp06', 'insolation',
'lsm', 'msl', 'q100', 'q1000', 'q150', 'q200', 'q250', 'q300', 'q400', 'q50',
'q500', 'q600', 'q700', 'q850', 'q925', 'sdor', 'sin_julian_day', 'sin_latitude', 'sin_local_time', 'sin_longitude',
'skt', 'slor', 'sp', 't100', 't1000', 't150', 't200', 't250', 't300', 't400',
't50', 't500', 't600', 't700', 't850', 't925', 'tcw', 'tp06', 'u100', 'u1000',
'u150', 'u200', 'u250', 'u300', 'u400', 'u50', 'u500', 'u600', 'u700', 'u850',
'u925', 'v100', 'v1000', 'v150', 'v200', 'v250', 'v300', 'v400', 'v50', 'v500',
'v600', 'v700', 'v850', 'v925', 'w100', 'w1000', 'w150', 'w200', 'w250', 'w300',
'w400', 'w50', 'w500', 'w600', 'w700', 'w850', 'w925', 'z', 'z100', 'z1000',
'z150', 'z200', 'z250', 'z300', 'z400', 'z50', 'z500', 'z600', 'z700', 'z850',
'z925', 'swvl1', 'swvl2', 'stl1', 'stl2', 'ssrd06', 'strd06', 'sf', 'tcc', 'mcc',
'hcc', 'lcc', 'u100m', 'v100m', 'ro'] len = 115
input_full_ids: tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104], dtype=torch.int32) len = 103
input_ids: tensor([ 0, 1, 2, 3, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 30, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104], dtype=torch.int32) len = 90
output_ids: tensor([ 0, 1, 2, 3, 8, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 30, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114], dtype=torch.int32) len = 102
model.VARIABLE_FORCINGS = ['cos_latitude', 'cos_longitude', 'sin_latitude', 'sin_longitude', 'cos_julian_day', 'cos_local_time', 'sin_julian_day', 'sin_local_time', 'insolation']
forcing_ids: tensor([ 5, 7, 27, 29, 4, 6, 26, 28, 9], dtype=torch.int32)
model.VARIABLE_INVARIANTS = ['lsm', 'sdor', 'slor', 'z']
aifs.py
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import zipfile
from collections import OrderedDict
from collections.abc import Generator, Iterator

import numpy as np
import torch

# from earth2studio.data import IFS
# from earth2studio.data.utils import fetch_data
from earth2studio.models.auto import AutoModelMixin
from earth2studio.models.batch import batch_coords, batch_func

# from earth2studio.models.px.base import PrognosticModel
from earth2studio.models.px.utils import PrognosticMixin
from earth2studio.utils import handshake_coords, handshake_dim
from earth2studio.utils.imports import (
    OptionalDependencyFailure,
    check_optional_dependencies,
)
from earth2studio.utils.type import CoordSystem

try:
    import anemoi.models  # noqa: F401
    import earthkit.regrid  # noqa: F401
    import ecmwf.opendata  # noqa: F401
    import flash_attn  # noqa: F401
except ImportError:
    OptionalDependencyFailure("aifs")


@check_optional_dependencies()
class AIFS(torch.nn.Module, AutoModelMixin, PrognosticMixin):
    """Artificial Intelligence Forecasting System (AIFS), a data driven forecast model
    developed by the European Centre for Medium-Range Weather Forecasts (ECMWF). AIFS is
    based on a graph neural network (GNN) encoder and decoder, and a sliding window
    transformer processor, and is trained on ECMWF's ERA5 re-analysis and ECMWF's
    operational numerical weather prediction (NWP) analyses.

    Consists of a single model with a time-step size of 6 hours.

    Note
    ----
    This model uses the checkpoints provided by ECMWF.
    Multiple checkpoint versions are supported. Use:

    - `AIFS.load_default_package()` for the default (AIFS-Single v1.0)
    - `AIFS.load_default_package(version="1.1")` for AIFS-Single v1.1

    The checkpoint metadata (`ai-models.json`) is used to derive the correct variable
    ordering and indices for each checkpoint version.

    For additional information see the following resources:

    - https://arxiv.org/abs/2406.01465
    - https://huggingface.co/ecmwf/aifs-single-1.1
    - https://github.com/ecmwf/anemoi-core

    Parameters
    ----------
    model : torch.nn.Module
        Core PyTorch module with the pretrained AIFS weights loaded.
    latitudes : torch.Tensor
        Latitude values for the native octahedral grid, registered as a buffer for
        interpolation.
    longitudes : torch.Tensor
        Longitude values for the native octahedral grid, registered as a buffer for
        interpolation.
    interpolation_matrix : torch.Tensor
        CSR sparse matrix mapping ERA5 lat/lon inputs onto the octahedral grid.
    inverse_interpolation_matrix : torch.Tensor
        CSR sparse matrix mapping outputs from the octahedral grid back to ERA5
        lat/lon.
    invariants : torch.Tensor
        Tensor of shape [4, 721, 1440] containing the invariant fields "lsm", "sdor",
        "slor" and "z"

    Warning
    -------
    We encourage users to familiarize themselves with the license restrictions of this
    model's checkpoints.

    Badges
    ------
    region:global class:mrf product:wind product:precip product:temp product:atmos
    product:land product:solar year:2025 gpu:40gb
    """

    # Earth2Studio variable IDs in checkpoint index order (115 entries). The
    # ordering must match the checkpoint's name_to_index mapping; see __init__.
    VARIABLES = [
        "u10m",
        "v10m",
        "d2m",
        "t2m",
        "cos_julian_day",
        "cos_latitude",
        "cos_local_time",
        "cos_longitude",
        "cp06",
        "insolation",
        "lsm",
        "msl",
        "q100",
        "q1000",
        "q150",
        "q200",
        "q250",
        "q300",
        "q400",
        "q50",
        "q500",
        "q600",
        "q700",
        "q850",
        "q925",
        "sdor",
        "sin_julian_day",
        "sin_latitude",
        "sin_local_time",
        "sin_longitude",
        "skt",
        "slor",
        "sp",
        "t100",
        "t1000",
        "t150",
        "t200",
        "t250",
        "t300",
        "t400",
        "t50",
        "t500",
        "t600",
        "t700",
        "t850",
        "t925",
        "tcw",
        "tp06",
        "u100",
        "u1000",
        "u150",
        "u200",
        "u250",
        "u300",
        "u400",
        "u50",
        "u500",
        "u600",
        "u700",
        "u850",
        "u925",
        "v100",
        "v1000",
        "v150",
        "v200",
        "v250",
        "v300",
        "v400",
        "v50",
        "v500",
        "v600",
        "v700",
        "v850",
        "v925",
        "w100",
        "w1000",
        "w150",
        "w200",
        "w250",
        "w300",
        "w400",
        "w50",
        "w500",
        "w600",
        "w700",
        "w850",
        "w925",
        "z",
        "z100",
        "z1000",
        "z150",
        "z200",
        "z250",
        "z300",
        "z400",
        "z50",
        "z500",
        "z600",
        "z700",
        "z850",
        "z925",
        "swvl1",
        "swvl2",
        "stl1",
        "stl2",
        "ssrd06",
        "strd06",
        "sf",
        "tcc",
        "mcc",
        "hcc",
        "lcc",
        "u100m",
        "v100m",
        "ro",
    ]
    # Time-invariant fields supplied from the `invariants` tensor, not the input.
    VARIABLE_INVARIANTS = ["lsm", "sdor", "slor", "z"]
    # Forcing fields generated on the fly from time / lat / lon (see _prepare_input).
    VARIABLE_FORCINGS = [
        "cos_latitude",
        "cos_longitude",
        "sin_latitude",
        "sin_longitude",
        "cos_julian_day",
        "cos_local_time",
        "sin_julian_day",
        "sin_local_time",
        "insolation",
    ]

    def __init__(
        self,
        model: torch.nn.Module,
        latitudes: torch.Tensor,
        longitudes: torch.Tensor,
        interpolation_matrix: torch.Tensor,
        inverse_interpolation_matrix: torch.Tensor,
        invariants: torch.Tensor,
    ) -> None:
        super().__init__()
        self.model = model
        self.register_buffer("invariants", invariants)
        self.register_buffer("latitudes", latitudes)
        self.register_buffer("longitudes", longitudes)
        self.register_buffer("interpolation_matrix", interpolation_matrix)
        self.register_buffer(
            "inverse_interpolation_matrix", inverse_interpolation_matrix
        )
        # Check to make sure that the models total variable input variables are
        # consistent with the wrappers
        # Useful: https://github.com/ecmwf/anemoi-core/blob/main/models/src/anemoi/models/data_indices/tensor.py
        # https://anemoi.readthedocs.io/projects/models/en/latest/modules/data_indices.html#usage-information
        name_to_index = self.model.data_indices.data._name_to_index
        # Checkpoint variable names sorted by their index position
        variables = [
            name for name, idx in sorted(name_to_index.items(), key=lambda x: x[1])
        ]
        if self._ckpt_var_to_e2s(variables) != self.VARIABLES:
            raise ValueError(
                "Model variables are not the same as wrapper VARIABLES, your checkpoint is not expected..."
            )
        # Index buffers derived from the checkpoint's data indices:
        # full model input (prognostic + forcing + invariant)
        self.register_buffer("input_full_ids", self.model.data_indices.data.input.full)
        self.register_buffer(
            "invariant_ids",
            torch.IntTensor([name_to_index[v] for v in self.VARIABLE_INVARIANTS]),
        )
        self.register_buffer(
            "forcing_ids",
            torch.IntTensor([name_to_index[v] for v in self.VARIABLE_FORCINGS]),
        )
        self.register_buffer("input_ids", self.model.data_indices.data.input.prognostic)
        # Output is prognostic + diagnostic, sorted into checkpoint index order
        self.register_buffer(
            "output_ids",
            torch.sort(
                torch.cat(
                    [
                        self.model.data_indices.data.output.prognostic,
                        self.model.data_indices.data.output.diagnostic,
                    ]
                )
            )[0],
        )

    def input_coords(self) -> CoordSystem:
        """Input coordinate system of the prognostic model

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        return OrderedDict(
            {
                "batch": np.empty(0),
                "time": np.empty(0),
                "lead_time": np.array(
                    [np.timedelta64(-6, "h"), np.timedelta64(0, "h")]
                ),
                "variable": np.array([self.VARIABLES[i] for i in self.input_ids]),
                "lat": np.linspace(90.0, -90.0, 721),
                "lon": np.linspace(0, 360, 1440, endpoint=False),
            }
        )

    @batch_coords()
    def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
        """Output coordinate system of the prognostic model

        Parameters
        ----------
        input_coords : CoordSystem
            Input coordinate system to transform into output_coords
            by default None, will use self.input_coords.

        Returns
        -------
        CoordSystem
            Coordinate system dictionary
        """
        output_coords = OrderedDict(
            {
                "batch": np.empty(0),
                "time": np.empty(0),
                "lead_time": np.array([np.timedelta64(6, "h")]),
                "variable": np.array([self.VARIABLES[i] for i in self.output_ids]),
                "lat": np.linspace(90.0, -90.0, 721),
                "lon": np.linspace(0, 360, 1440, endpoint=False),
            }
        )
        if input_coords is None:
            return output_coords
        # Validate the input coordinates against the expected input coordinate
        # system, with lead_time rebased so the last input step is zero.
        test_coords = input_coords.copy()
        test_coords["lead_time"] = (
            test_coords["lead_time"] - input_coords["lead_time"][-1]
        )
        target_input_coords = self.input_coords()
        for i, key in enumerate(target_input_coords):
            if key not in ["batch", "time"]:
                handshake_dim(test_coords, key, i)
                handshake_coords(test_coords, target_input_coords, key)
        output_coords["batch"] = input_coords["batch"]
        output_coords["time"] = input_coords["time"]
        output_coords["lead_time"] = (
            input_coords["lead_time"][-1] + output_coords["lead_time"]
        )
        return output_coords

    # @classmethod
    # def load_default_package(cls) -> Package:
    #     """Load prognostic package"""
    #     package = Package(
    #         "hf://ecmwf/aifs-single-1.1",
    #         cache_options={
    #             "cache_storage": Package.default_cache("aifs-single-1.1"),
    #             "same_names": True,
    #         },
    #     )
    #     return package

    @classmethod
    @check_optional_dependencies()
    def load_model(cls, package=None):  # -> PrognosticModel:
        """Load prognostic from package"""
        # Load model
        # model_path = package.resolve("aifs-single-mse-1.1.ckpt")
        # TODO(review): hardcoded local cache paths below are a development
        # shortcut; restore the Package-based resolution (commented throughout
        # this method) before release.
        model_path = "/root/.cache/earth2studio/aifs-single-1.1/aifs-single-mse-1.1.ckpt"
        interpolation_matrix_path = "/root/.cache/earth2studio/aifs-single-1.1_interpolation_matrix/9533e90f8433424400ab53c7fafc87ba1a04453093311c0b5bd0b35fedc1fb83.npz"
        inverse_interpolation_matrix_path = "/root/.cache/earth2studio/aifs-single-1.1_inverse_interpolation_matrix/7f0be51c7c1f522592c7639e0d3f95bcbff8a044292aa281c1e73b842736d9bf.npz"
        # NOTE(security): weights_only=False unpickles arbitrary objects from the
        # checkpoint; only load checkpoints from trusted sources.
        model = torch.load(
            model_path, weights_only=False, map_location=torch.ones(1).device
        )
        model.eval()
        # Define the path to the metadata file
        metadata_path = "inference-last/anemoi-metadata/ai-models.json"
        # The .ckpt checkpoint is itself a zip archive that also carries the
        # anemoi metadata and supporting arrays; open it a second time as a zip
        # to extract them.
        with zipfile.ZipFile(model_path, "r") as zipf:
            # Load metadata
            metadata = json.load(zipf.open(metadata_path))
            # Load supporting arrays
            supporting_arrays = {}
            for key, entry in metadata.get("supporting_arrays_paths", {}).items():
                supporting_arrays[key] = np.frombuffer(
                    zipf.read(entry["path"]),
                    dtype=entry["dtype"],
                ).reshape(entry["shape"])
        # Load interpolation matrix
        # TODO: Maybe change this to allow for multiple packages?
        # interpolation_package = Package(
        #     "https://get.ecmwf.int/repository/earthkit/regrid/db/1/mir_16_linear",
        #     cache_options={
        #         "cache_storage": Package.default_cache(
        #             "aifs-single-1.1_interpolation_matrix"
        #         ),
        #         "same_names": True,
        #     },
        # )
        # interpolation_matrix_path = interpolation_package.resolve(
        #     "9533e90f8433424400ab53c7fafc87ba1a04453093311c0b5bd0b35fedc1fb83.npz"
        # )
        interpolation_matrix = np.load(interpolation_matrix_path)
        torch_interpolation_matrix = torch.sparse_csr_tensor(
            crow_indices=torch.from_numpy(interpolation_matrix["indptr"]),
            col_indices=torch.from_numpy(interpolation_matrix["indices"]),
            values=torch.from_numpy(interpolation_matrix["data"]),
            size=(interpolation_matrix["shape"][0], interpolation_matrix["shape"][1]),
            dtype=torch.float64,
        )
        # inverse_interpolation_package = Package(
        #     "https://get.ecmwf.int/repository/earthkit/regrid/db/1/mir_16_linear/",
        #     cache_options={
        #         "cache_storage": Package.default_cache(
        #             "aifs-single-1.1_inverse_interpolation_matrix"
        #         ),
        #         "same_names": True,
        #     },
        # )
        # inverse_interpolation_matrix_path = inverse_interpolation_package.resolve(
        #     "7f0be51c7c1f522592c7639e0d3f95bcbff8a044292aa281c1e73b842736d9bf.npz"
        # )
        inverse_interpolation_matrix = np.load(inverse_interpolation_matrix_path)
        torch_inverse_interpolation_matrix = torch.sparse_csr_tensor(
            crow_indices=torch.from_numpy(inverse_interpolation_matrix["indptr"]),
            col_indices=torch.from_numpy(inverse_interpolation_matrix["indices"]),
            values=torch.from_numpy(inverse_interpolation_matrix["data"]),
            size=(
                inverse_interpolation_matrix["shape"][0],
                inverse_interpolation_matrix["shape"][1],
            ),
            dtype=torch.float64,
        )
        # Fetch invariants from IFS, note that there are deviations between these
        # invariant fields depending on where and what time the data is fetched.
        # For this model, we will use ECMWF's own invarints in the IFS data store.
        # ifs = IFS(cache=True, verbose=False)
        # invariants, _ = fetch_data(
        #     source=ifs,
        #     time=np.array([np.datetime64("2026-01-01T00:00:00")]),
        #     variable=["lsm", "sdor", "slor", "z"],
        # )
        # invariants = invariants.squeeze()
        # torch.save(invariants, "invariants.pt")
        # TODO(review): relative path "invariants.pt" depends on the current
        # working directory; see commented IFS fetch above for its provenance.
        invariants = torch.load("invariants.pt")
        # Can also fetch from NCAR ERA5 backup but these have some differences
        # invariant_package = Package(
        #     "https://nsf-ncar-era5.s3.amazonaws.com/e5.oper.invariant/197901/",
        #     cache_options={
        #         "cache_storage": Package.default_cache(
        #             "aifs-single-1.0"
        #         ),
        #         "same_names": True,
        #     },
        # )
        # invariant_arrays = []
        # for key, value in {"lsm": 172, "sdor": 160, "slor": 163, "z": 129}.items():
        #     ds = xr.load_dataset(invariant_package.resolve(f"e5.oper.invariant.128_{value:03d}_{key}.ll025sc.1979010100_1979010100.nc"))
        #     invariant_arrays.append(ds[key.upper()].values)
        # invariants = torch.Tensor(invariant_arrays).squeeze()
        return cls(
            model,
            latitudes=torch.Tensor(supporting_arrays["latitudes"]).reshape(1, 1, -1, 1),
            longitudes=torch.Tensor(supporting_arrays["longitudes"]).reshape(
                1, 1, -1, 1
            ),
            interpolation_matrix=torch_interpolation_matrix,
            inverse_interpolation_matrix=torch_inverse_interpolation_matrix,
            invariants=invariants,
        )

    @staticmethod
    def _ckpt_var_to_e2s(names: list[str]) -> list[str]:
        """Translate checkpoint variable names into Earth2Studio variable IDs."""
        surface_map = {
            "10u": "u10m",
            "10v": "v10m",
            "2d": "d2m",
            "2t": "t2m",
            "100u": "u100m",
            "100v": "v100m",
        }
        accum_6h_map = {
            "cp": "cp06",
            "tp": "tp06",
            "ssrd": "ssrd06",
            "strd": "strd06",
        }

        def _map_one(name: str) -> str:
            # Surface shorthand
            if name in surface_map:
                return surface_map[name]
            # 6-hour accumulations
            if name in accum_6h_map:
                return accum_6h_map[name]
            # Pressure level e.g. q_50 -> q50
            if "_" in name:
                parts = name.split("_", 1)
                if len(parts) == 2 and parts[1].isdigit():
                    return f"{parts[0]}{parts[1]}"
            return name

        return [_map_one(n) for n in names]

    def get_cos_sin_julian_day(
        self,
        time_array: np.datetime64,
        longitudes: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Get cosine and sine of Julian day

        Returns constant fields (same shape as `longitudes`) holding the cosine
        and sine of the fractional day-of-year angle.
        """
        days = (
            time_array.astype("datetime64[D]") - time_array.astype("datetime64[Y]")
        ).astype(np.float32)
        hours = (
            time_array.astype("datetime64[h]") - time_array.astype("datetime64[D]")
        ).astype(np.float32)
        julian_days = days + (hours / 24.0)
        normalized = 2 * np.pi * (julian_days / 365.25)
        cos_julian_day = torch.full_like(
            longitudes, np.cos(normalized), dtype=torch.float32
        )
        sin_julian_day = torch.full_like(
            longitudes, np.sin(normalized), dtype=torch.float32
        )
        return cos_julian_day, sin_julian_day

    def get_cos_sin_local_time(
        self,
        time_array: np.datetime64,
        longitudes: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Get cosine and sine of local time

        Local time is UTC hour-of-day shifted by longitude, expressed as an angle.
        """
        hours = (
            time_array.astype("datetime64[h]") - time_array.astype("datetime64[D]")
        ).astype(np.float32)
        normalized_time = 2 * np.pi * (hours / 24.0)
        normalized_longitudes = 2 * np.pi * (longitudes / 360.0)
        tau = normalized_time + normalized_longitudes
        cos_local_time = torch.cos(tau)
        sin_local_time = torch.sin(tau)
        return cos_local_time, sin_local_time

    def get_cosine_zenith_fields(
        self,
        date: np.datetime64,
        latitudes: torch.Tensor,
        longitudes: torch.Tensor,
    ) -> torch.Tensor:
        """Get cosine zenith fields for input time array"""
        # Get Julian day
        days = (date.astype("datetime64[D]") - date.astype("datetime64[Y]")).astype(
            np.float32
        )
        hours = (date.astype("datetime64[h]") - date.astype("datetime64[D]")).astype(
            np.float32
        )
        seconds = (date.astype("datetime64[s]") - date.astype("datetime64[h]")).astype(
            np.float32
        )
        # NOTE(review): julian_day uses seconds-past-the-hour / 86400 and ignores
        # `hours`; for 6-hourly steps on whole hours `seconds` is 0, but confirm
        # this is the intended fractional day for sub-hourly times.
        julian_day = days + seconds / 86400.0
        # Convert angle to tensor
        angle = torch.tensor(
            julian_day / 365.25 * torch.pi * 2, device=latitudes.device
        )
        # declination in [degrees]
        declination = (
            0.396372
            - 22.91327 * torch.cos(angle)
            + 4.025430 * torch.sin(angle)
            - 0.387205 * torch.cos(2 * angle)
            + 0.051967 * torch.sin(2 * angle)
            - 0.154527 * torch.cos(3 * angle)
            + 0.084798 * torch.sin(3 * angle)
        )
        # time correction in [h.degrees]
        time_correction = (
            0.004297
            + 0.107029 * torch.cos(angle)
            - 1.837877 * torch.sin(angle)
            - 0.837378 * torch.cos(2 * angle)
            - 2.340475 * torch.sin(2 * angle)
        )
        # Convert to radians
        declination = torch.deg2rad(declination)
        latitudes = torch.deg2rad(latitudes)
        # Calculate sine and cosine of declination and latitude
        sindec_sinlat = torch.sin(declination) * torch.sin(latitudes)
        cosdec_coslat = torch.cos(declination) * torch.cos(latitudes)
        # Solar hour angle
        solar_angle = torch.deg2rad((hours - 12) * 15 + longitudes + time_correction)
        zenith_angle = sindec_sinlat + cosdec_coslat * torch.cos(solar_angle)
        # Clip negative values
        return torch.clamp(zenith_angle, min=0.0)

    def _prepare_input(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> torch.Tensor:
        """Prepare input tensor and coordinates for the AIFS model."""
        # Interpolate the input tensor to the model grid
        shape = x.shape
        x = x.flatten(start_dim=4)
        x = x.flatten(end_dim=3)
        x = torch.swapaxes(x, 0, -1)
        x = x.to(dtype=torch.float64)
        x = self.interpolation_matrix @ x
        x = x.to(dtype=torch.float32)
        x = torch.swapaxes(x, 0, -1)
        x = x.reshape([shape[0] * shape[1], shape[2], shape[3], -1])
        x = torch.swapaxes(x, 2, 3)
        n_bt = shape[0] * shape[1]
        n_lead = shape[2]
        n_nodes = x.shape[2]
        # Interpolate invariants
        i = self.invariants.flatten(start_dim=1)
        i = torch.swapaxes(i, 0, -1)
        i = i.to(dtype=torch.float64)
        i = self.interpolation_matrix @ i
        i = i.to(dtype=torch.float32)
        # Reconstruct full feature tensor in checkpoint variable space (ordering and
        # indices are checkpoint dependent).
        x_full = torch.zeros(
            (n_bt, n_lead, n_nodes, len(self.VARIABLES)),
            device=x.device,
            dtype=torch.float32,
        )
        x_full[..., self.input_ids] = x
        x_full[..., self.invariant_ids] = i
        # Compute generated forcings
        cos_latitude = torch.cos(torch.deg2rad(self.latitudes)).to(dtype=torch.float32)
        sin_latitude = torch.sin(torch.deg2rad(self.latitudes)).to(dtype=torch.float32)
        cos_longitude = torch.cos(torch.deg2rad(self.longitudes)).to(
            dtype=torch.float32
        )
        sin_longitude = torch.sin(torch.deg2rad(self.longitudes)).to(
            dtype=torch.float32
        )
        cos_latitude = cos_latitude.repeat(n_bt, n_lead, 1, 1)
        sin_latitude = sin_latitude.repeat(n_bt, n_lead, 1, 1)
        cos_longitude = cos_longitude.repeat(n_bt, n_lead, 1, 1)
        sin_longitude = sin_longitude.repeat(n_bt, n_lead, 1, 1)
        # Time-dependent forcings at both input lead times (t-6h and t)
        cos_julian_day_0, sin_julian_day_0 = self.get_cos_sin_julian_day(
            coords["time"][0] - np.timedelta64(6, "h"), self.longitudes
        )
        cos_julian_day_1, sin_julian_day_1 = self.get_cos_sin_julian_day(
            coords["time"][0], self.longitudes
        )
        cos_julian_day = torch.cat([cos_julian_day_0, cos_julian_day_1], dim=1).repeat(
            n_bt, 1, 1, 1
        )
        sin_julian_day = torch.cat([sin_julian_day_0, sin_julian_day_1], dim=1).repeat(
            n_bt, 1, 1, 1
        )
        cos_local_time_0, sin_local_time_0 = self.get_cos_sin_local_time(
            coords["time"][0] - np.timedelta64(6, "h"), self.longitudes
        )
        cos_local_time_1, sin_local_time_1 = self.get_cos_sin_local_time(
            coords["time"][0], self.longitudes
        )
        cos_local_time = torch.cat([cos_local_time_0, cos_local_time_1], dim=1).repeat(
            n_bt, 1, 1, 1
        )
        sin_local_time = torch.cat([sin_local_time_0, sin_local_time_1], dim=1).repeat(
            n_bt, 1, 1, 1
        )
        cos_zenith_angle_0 = self.get_cosine_zenith_fields(
            coords["time"][0] - np.timedelta64(6, "h"), self.latitudes, self.longitudes
        )
        cos_zenith_angle_1 = self.get_cosine_zenith_fields(
            coords["time"][0], self.latitudes, self.longitudes
        )
        cos_zenith_angle = torch.cat(
            [cos_zenith_angle_0, cos_zenith_angle_1], dim=1
        ).repeat(n_bt, 1, 1, 1)
        # Write forcings into their checkpoint slots; order here must match
        # VARIABLE_FORCINGS (forcing_ids was built from that list in __init__).
        x_full[..., self.forcing_ids[0]] = cos_latitude[..., 0]
        x_full[..., self.forcing_ids[1]] = cos_longitude[..., 0]
        x_full[..., self.forcing_ids[2]] = sin_latitude[..., 0]
        x_full[..., self.forcing_ids[3]] = sin_longitude[..., 0]
        x_full[..., self.forcing_ids[4]] = cos_julian_day[..., 0]
        x_full[..., self.forcing_ids[5]] = cos_local_time[..., 0]
        x_full[..., self.forcing_ids[6]] = sin_julian_day[..., 0]
        x_full[..., self.forcing_ids[7]] = sin_local_time[..., 0]
        x_full[..., self.forcing_ids[8]] = cos_zenith_angle[..., 0]
        x_full = x_full[..., self.input_full_ids]  # Select input (prognostic + forcing + invar)
        return x_full

    def _update_input(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> torch.Tensor:
        """Update time based inputs."""
        time0 = coords["time"][0] + coords["lead_time"][0]
        time1 = coords["time"][0] + coords["lead_time"][1]
        # Get cos, sin of Julian day
        cos_julian_day_0, sin_julian_day_0 = self.get_cos_sin_julian_day(
            time0, self.longitudes
        )
        cos_julian_day_1, sin_julian_day_1 = self.get_cos_sin_julian_day(
            time1, self.longitudes
        )
        cos_julian_day = torch.cat([cos_julian_day_0, cos_julian_day_1], dim=1)
        sin_julian_day = torch.cat([sin_julian_day_0, sin_julian_day_1], dim=1)
        # Get cos, sin local time
        cos_local_time_0, sin_local_time_0 = self.get_cos_sin_local_time(
            time0, self.longitudes
        )
        cos_local_time_1, sin_local_time_1 = self.get_cos_sin_local_time(
            time1, self.longitudes
        )
        cos_local_time = torch.cat([cos_local_time_0, cos_local_time_1], dim=1)
        sin_local_time = torch.cat([sin_local_time_0, sin_local_time_1], dim=1)
        # Get cosine zenith angle
        # Add insolation / cosine zenith angle
        cos_zenith_angle_0 = self.get_cosine_zenith_fields(
            time0, self.latitudes, self.longitudes
        )
        cos_zenith_angle_1 = self.get_cosine_zenith_fields(
            time1, self.latitudes, self.longitudes
        )
        cos_zenith_angle = torch.cat([cos_zenith_angle_0, cos_zenith_angle_1], dim=1)
        # Only the time-dependent forcing slots need refreshing; lat/lon forcings
        # and invariants were written by _prepare_input / _forward.
        x[..., self.forcing_ids[4]] = cos_julian_day[..., 0]
        x[..., self.forcing_ids[5]] = cos_local_time[..., 0]
        x[..., self.forcing_ids[6]] = sin_julian_day[..., 0]
        x[..., self.forcing_ids[7]] = sin_local_time[..., 0]
        x[..., self.forcing_ids[8]] = cos_zenith_angle[..., 0]
        # Select out actual input variables from the full fields set
        x = x[..., self.input_full_ids]
        return x

    def _prepare_output(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> torch.Tensor:
        """Prepare the model output tensor for return on the lat/lon grid."""
        # Remove generated forcings
        x = x[..., self.output_ids]
        shape = x.shape
        # Interpolate the model grid to the lat lon grid
        x = x[:, 1:2]
        x = x.flatten(end_dim=1)
        x = torch.swapaxes(x, 0, 1)
        x = x.flatten(start_dim=1)
        x = x.to(dtype=torch.float64)
        x = self.inverse_interpolation_matrix @ x
        x = x.to(dtype=torch.float32)
        x = torch.reshape(x, [x.shape[0], shape[0], shape[-1]])
        x = torch.swapaxes(x, 0, 1)
        x = torch.swapaxes(x, 1, 2)
        x = torch.reshape(
            x,
            [
                coords["batch"].shape[0],
                coords["time"].shape[0],
                coords["lead_time"].shape[0],
                coords["variable"].shape[0],
                coords["lat"].shape[0],
                coords["lon"].shape[0],
            ],
        )
        return x

    def _forward(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
        step: int = 1,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Run one 6-hour model step and assemble the next two-step input window."""
        output_coords = self.output_coords(coords)
        # BUGFIX: torch.autocast expects a device *type* ("cuda"/"cpu"); the
        # previous str(x.device) yields e.g. "cuda:0", which is rejected.
        with torch.autocast(device_type=x.device.type, dtype=torch.float16):
            y = self.model.predict_step(x, fcstep=step)
        out = torch.zeros(
            (x.shape[0], x.shape[1], x.shape[2], len(self.VARIABLES)),
            device=x.device,
        )
        # Slot the previous step and the new prediction into the full variable space
        out[:, 0, :, self.input_full_ids] = x[:, 1]
        out[:, 1, :, self.model.data_indices.data.output.full] = y[:, 0]
        # Carry forcings and invariants over from the previous input window
        forcing_full = torch.sort(
            torch.cat([self.forcing_ids, self.invariant_ids], dim=0)
        )[0]
        mask = torch.isin(self.input_full_ids, forcing_full)
        out[:, 1, :, forcing_full] = x[:, 1, :, torch.where(mask)[0]]
        return out, output_coords

    @batch_func()
    def __call__(
        self,
        x: torch.Tensor,
        coords: CoordSystem,
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Runs prognostic model 1 step.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        coords : CoordSystem
            Input coordinate system

        Returns
        -------
        tuple[torch.Tensor, CoordSystem]
            Output tensor and coordinate system 6 hours in the future
        """
        _ = self.output_coords(coords)  # NOTE: Quick fix for exception handling
        x = self._prepare_input(x, coords)
        x, coords = self._forward(x, coords)
        x = self._prepare_output(x, coords)
        return x, coords

    def _fill_input(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> tuple[torch.Tensor, CoordSystem]:
        """Helper function of create a lat/lon tensor with the input prognostic and zero
        filled diagnostic variables."""
        # add invariants to prognostics
        batch, time, lead, _, height, width = x.shape
        # Prepare empty output tensor with VARIABLE dimension
        out = torch.zeros(
            (batch, time, lead, len(self.VARIABLES), height, width),
            device=x.device,
        )
        # Fill tensor: copy input slices into selected variable slots
        # NOTE(review): source is x[0, 0, lead] regardless of batch/time index —
        # assumes batch/time broadcast of the first sample; confirm intended.
        out[:, :, 0, self.input_ids] = x[0, 0, 0, ...]
        out[:, :, 0, self.invariant_ids] = self.invariants
        out[:, :, 1, self.input_ids] = x[0, 0, 1, ...]
        out[:, :, 1, self.invariant_ids] = self.invariants
        out = out[:, :, :, self.output_ids, ...]
        out_coords = coords.copy()
        out_coords["variable"] = np.array([self.VARIABLES[i] for i in self.output_ids])
        return out, out_coords

    @batch_func()
    def _default_generator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
        coords = coords.copy()
        self.output_coords(coords)
        # Yield the initial condition (0th step) first
        output_tensor, coords_out = self._fill_input(x, coords)
        coords_out["lead_time"] = coords["lead_time"][1:]
        yield output_tensor[:, :, 1:], coords_out
        # Prepare input tensor
        x = self._prepare_input(x, coords)
        step = 1
        while True:
            # Front hook
            x, coords = self.front_hook(x, coords)
            # Forward is identity operator
            y, coords_out = self._forward(x, coords, step=step)
            # Prepare output tensor
            output_tensor = self._prepare_output(y, coords_out)
            # Rear hook
            output_tensor, coords_out = self.rear_hook(output_tensor, coords_out)
            # Yield output tensor
            yield output_tensor, coords_out.copy()
            # Update coordinates
            coords["lead_time"] = (
                coords["lead_time"]
                + self.output_coords(self.input_coords())["lead_time"]
            )
            # Prepare input tensor
            x = self._update_input(y, coords)
            step += 1

    def create_iterator(
        self, x: torch.Tensor, coords: CoordSystem
    ) -> Iterator[tuple[torch.Tensor, CoordSystem]]:
        """Creates a iterator which can be used to perform time-integration of the
        prognostic model. Will return the initial condition first (0th step).

        Parameters
        ----------
        x : torch.Tensor
            Input tensor
        coords : CoordSystem
            Input coordinate system

        Yields
        ------
        Iterator[tuple[torch.Tensor, CoordSystem]]
            Iterator that generates time-steps of the prognostic model container the
            output data tensor and coordinate system dictionary.
        """
        yield from self._default_generator(x, coords)
