15 remapping of one input variable to multiple new ones by sahahner · Pull Request #21 · ecmwf/anemoi-models · GitHub
This repository was archived by the owner on Dec 20, 2024. It is now read-only.

15 remapping of one input variable to multiple new ones #21

Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,8 @@ Keep it human-readable, your future self will thank you!

 ### Added

+- Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables.
+
 ### Changed

 - Update CI to inherit from common infrastructure reusable workflows
25 changes: 23 additions & 2 deletions docs/modules/data_indices.rst
@@ -45,12 +45,33 @@ config entry:
    :alt: Schematic of IndexCollection with Data Indexing on Data and Model levels.
    :align: center

-The are two Index-levels:
+Additionally, prognostic and forcing variables can be remapped and
+converted to multiple variables. The conversion is then done by the
+remapper-preprocessor.
+
+.. code:: yaml
+
+   data:
+     remapped:
+       - d:
+         - "d_1"
+         - "d_2"
+
+There are two main Index-levels:

 - Data: The data at "Zarr"-level provided by Anemoi-Datasets
 - Model: The "squeezed" tensors with irrelevant parts missing.

-These are both split into two versions:
+Additionally, there are two internal model levels (after the preprocessor
+and before the postprocessor) that are necessary because of the possibility
+to remap variables to multiple variables.
+
+- Internal Data: Variables from Data-level that are used internally in
+  the model, but not exposed to the user.
+- Internal Model: Variables from Model-level that are used internally
+  in the model, but not exposed to the user.
+
+All indices at the different levels are split into two versions:

 - Input: The data going into training / model
 - Output: The data produced by training / model
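The internal model-level indexing described in this documentation change can be illustrated with a small standalone sketch. It is not the library code: `build_internal_indices` is a hypothetical helper, and `remapped` is assumed here to be a plain mapping from one variable name to the list of names that replace it. Remapped source variables are dropped from the internal indices and their replacements are appended at the end.

```python
def build_internal_indices(name_to_index, forcing, diagnostic, remapped):
    """Hypothetical helper mirroring the internal model-level bookkeeping.

    `remapped` is assumed to map one source variable to the list of
    variables that replace it (for example {"d": ["d_1", "d_2"]}).
    """
    ordered = sorted(name_to_index, key=name_to_index.get)
    model_input = [name for name in ordered if name not in diagnostic]
    model_output = [name for name in ordered if name not in forcing]

    # Drop the remapped source variables and renumber what is left.
    internal_input = {
        name: i for i, name in enumerate(n for n in model_input if n not in remapped)
    }
    internal_output = {
        name: i for i, name in enumerate(n for n in model_output if n not in remapped)
    }
    forcing_remapped = list(forcing)

    for source, targets in remapped.items():
        for target in targets:
            # New variables are appended at the end of the tensors.
            internal_input[target] = len(internal_input)
            if source in forcing:
                forcing_remapped.append(target)
            else:
                internal_output[target] = len(internal_output)
        if source in forcing:
            forcing_remapped.remove(source)

    return internal_input, internal_output, forcing_remapped


name_to_index = {"a": 0, "b": 1, "d": 2, "f": 3, "tp": 4}
internal_input, internal_output, forcing_remapped = build_internal_indices(
    name_to_index,
    forcing=["f"],
    diagnostic=["tp"],
    remapped={"d": ["d_1", "d_2"]},
)
print(internal_input)
print(internal_output)
```

In this toy setup the prognostic variable `d` disappears from both internal indices, while `d_1` and `d_2` take the last two positions of the input and of the output.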
13 changes: 13 additions & 0 deletions docs/modules/preprocessing.rst
@@ -33,3 +33,16 @@ following classes:
    :members:
    :no-undoc-members:
    :show-inheritance:
+
+**********
+ Remapper
+**********
+
+The remapper module is used to remap one variable to multiple other
+variables that have been listed in ``data.remapped``. The module contains
+the following classes:
+
+.. automodule:: anemoi.models.preprocessing.remapper
+   :members:
+   :no-undoc-members:
+   :show-inheritance:
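To make "remapping one variable to multiple other variables" concrete, here is a minimal sketch of one possible remapping. It assumes, purely for illustration, that an angular variable `d` in degrees is split into a cosine and a sine component; the diff shown here does not include the remapper implementation, and `remap_direction` and `inverse_remap_direction` are hypothetical names.

```python
import math


def remap_direction(values_deg):
    """Hypothetical remapping: one angle in degrees -> (cos, sin) components."""
    radians = [math.radians(v) for v in values_deg]
    return [math.cos(r) for r in radians], [math.sin(r) for r in radians]


def inverse_remap_direction(cos_vals, sin_vals):
    """Recover the angle in degrees, wrapped into [0, 360)."""
    return [
        math.degrees(math.atan2(s, c)) % 360.0 for c, s in zip(cos_vals, sin_vals)
    ]


d = [0.0, 90.0, 225.0]
d_1, d_2 = remap_direction(d)               # one variable becomes two
restored = inverse_remap_direction(d_1, d_2)  # postprocessing step
```

A transform of this kind replaces one channel with two, which is why the number of variables seen by the model differs from the number in the dataset.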
61 changes: 58 additions & 3 deletions src/anemoi/models/data_indices/collection.py
@@ -25,26 +25,74 @@ class IndexCollection:

     def __init__(self, config, name_to_index) -> None:
         self.config = OmegaConf.to_container(config, resolve=True)
-
+        self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1)))
         self.forcing = [] if config.data.forcing is None else OmegaConf.to_container(config.data.forcing, resolve=True)
         self.diagnostic = (
             [] if config.data.diagnostic is None else OmegaConf.to_container(config.data.diagnostic, resolve=True)
         )
+        # config.data.remapped is a list of dictionaries: every remapper is one entry of the list
+        self.remapped = (
+            dict() if config.data.remapped is None else OmegaConf.to_container(config.data.remapped, resolve=True)
+        )
+        self.forcing_remapped = self.forcing.copy()

         assert set(self.diagnostic).isdisjoint(self.forcing), (
             f"Diagnostic and forcing variables overlap: {set(self.diagnostic).intersection(self.forcing)}. ",
             "Please drop them at a dataset-level to exclude them from the training data.",
         )
-        self.name_to_index = dict(sorted(name_to_index.items(), key=operator.itemgetter(1)))
+        assert set(self.remapped).isdisjoint(self.diagnostic), (
+            "Remapped variable overlap with diagnostic variables. Not implemented.",
+        )
+        assert set(self.remapped).issubset(self.name_to_index), (
+            "Remapping a variable that does not exist in the dataset. Check for typos: ",
+            f"{set(self.remapped).difference(self.name_to_index)}",
+        )
         name_to_index_model_input = {
             name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.diagnostic)
         }
         name_to_index_model_output = {
             name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.forcing)
         }
+        # remove remapped variables from internal data and model indices
+        name_to_index_internal_data_input = {
+            name: i for i, name in enumerate(key for key in self.name_to_index if key not in self.remapped)
+        }
+        name_to_index_internal_model_input = {
+            name: i for i, name in enumerate(key for key in name_to_index_model_input if key not in self.remapped)
+        }
+        name_to_index_internal_model_output = {
+            name: i for i, name in enumerate(key for key in name_to_index_model_output if key not in self.remapped)
+        }
+        # for all variables to be remapped we add the resulting remapped variables to the end of the tensors
+        # keep track of that in the index collections
+        for key in self.remapped:
+            for mapped in self.remapped[key]:
+                # add index of remapped variables to dictionary
+                name_to_index_internal_model_input[mapped] = len(name_to_index_internal_model_input)
+                name_to_index_internal_data_input[mapped] = len(name_to_index_internal_data_input)
+                if key not in self.forcing:
+                    # do not include forcing variables in the remapped model output
+                    name_to_index_internal_model_output[mapped] = len(name_to_index_internal_model_output)
+                else:
+                    # add remapped forcing variables to forcing_remapped
+                    self.forcing_remapped += [mapped]
+            if key in self.forcing:
+                # if key is in forcing we need to remove it from forcing_remapped after remapped variables have been added
+                self.forcing_remapped.remove(key)

         self.data = DataIndex(self.diagnostic, self.forcing, self.name_to_index)
+        self.internal_data = DataIndex(
+            self.diagnostic,
+            self.forcing_remapped,
+            name_to_index_internal_data_input,
+        )  # internal after the remapping applied to data (training)
         self.model = ModelIndex(self.diagnostic, self.forcing, name_to_index_model_input, name_to_index_model_output)
+        self.internal_model = ModelIndex(
+            self.diagnostic,
+            self.forcing_remapped,
+            name_to_index_internal_model_input,
+            name_to_index_internal_model_output,
+        )  # internal after the remapping applied to model (inference)

     def __repr__(self) -> str:
         return f"IndexCollection(config={self.config}, name_to_index={self.name_to_index})"
@@ -54,7 +102,12 @@ def __eq__(self, other):
             # don't attempt to compare against unrelated types
             return NotImplemented

-        return self.model == other.model and self.data == other.data
+        return (
+            self.model == other.model
+            and self.data == other.data
+            and self.internal_model == other.internal_model
+            and self.internal_data == other.internal_data
+        )

     def __getitem__(self, key):
         return getattr(self, key)
@@ -63,6 +116,8 @@ def todict(self):
         return {
             "data": self.data.todict(),
             "model": self.model.todict(),
+            "internal_model": self.internal_model.todict(),
+            "internal_data": self.internal_data.todict(),
         }

     @staticmethod
2 changes: 1 addition & 1 deletion src/anemoi/models/interface/__init__.py
@@ -65,7 +65,7 @@ def _build_model(self) -> None:
         """Builds the model and pre- and post-processors."""
         # Instantiate processors
         processors = [
-            [name, instantiate(processor, statistics=self.statistics, data_indices=self.data_indices)]
+            [name, instantiate(processor, data_indices=self.data_indices, statistics=self.statistics)]
             for name, processor in self.config.data.processors.items()
         ]
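This pull request moves `data_indices` ahead of `statistics` in the processor constructors. The sketch below uses two hypothetical stand-in classes (not the anemoi-models processors) to show why call sites that pass these arguments by keyword, as the `instantiate(...)` call does, are unaffected by such a reordering, whereas positional call sites would silently swap the two values.

```python
class OldOrderProcessor:
    """Hypothetical stand-in using the parameter order before the change."""

    def __init__(self, config=None, statistics=None, data_indices=None):
        self.config = config
        self.statistics = statistics
        self.data_indices = data_indices


class NewOrderProcessor:
    """Hypothetical stand-in using the parameter order after the change."""

    def __init__(self, config=None, data_indices=None, statistics=None):
        self.config = config
        self.data_indices = data_indices
        self.statistics = statistics


stats = {"mean": [0.0], "stdev": [1.0]}
indices = {"d": 0}

# Keyword arguments bind by name, so the order in the signature is irrelevant.
by_keyword = NewOrderProcessor(None, statistics=stats, data_indices=indices)

# A positional call written for the old order now binds the values the wrong way round.
old_style = OldOrderProcessor(None, stats, indices)
by_position = NewOrderProcessor(None, stats, indices)
```

This is the reason the subclasses in this change also update their `super().__init__(config, data_indices, statistics)` calls, which are positional.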
19 changes: 10 additions & 9 deletions src/anemoi/models/models/encoder_processor_decoder.py
@@ -104,22 +104,23 @@ def __init__(
         )

     def _calculate_shapes_and_indices(self, data_indices: dict) -> None:
-        self.num_input_channels = len(data_indices.model.input)
-        self.num_output_channels = len(data_indices.model.output)
-        self._internal_input_idx = data_indices.model.input.prognostic
-        self._internal_output_idx = data_indices.model.output.prognostic
+        self.num_input_channels = len(data_indices.internal_model.input)
+        self.num_output_channels = len(data_indices.internal_model.output)
+        self._internal_input_idx = data_indices.internal_model.input.prognostic
+        self._internal_output_idx = data_indices.internal_model.output.prognostic

     def _assert_matching_indices(self, data_indices: dict) -> None:

-        assert len(self._internal_output_idx) == len(data_indices.model.output.full) - len(
-            data_indices.model.output.diagnostic
+        assert len(self._internal_output_idx) == len(data_indices.internal_model.output.full) - len(
+            data_indices.internal_model.output.diagnostic
         ), (
-            f"Mismatch between the internal data indices ({len(self._internal_output_idx)}) and the output indices excluding "
-            f"diagnostic variables ({len(data_indices.model.output.full) - len(data_indices.model.output.diagnostic)})",
+            f"Mismatch between the internal data indices ({len(self._internal_output_idx)}) and "
+            f"the internal output indices excluding diagnostic variables "
+            f"({len(data_indices.internal_model.output.full) - len(data_indices.internal_model.output.diagnostic)})",
         )
         assert len(self._internal_input_idx) == len(
             self._internal_output_idx,
-        ), f"Model indices must match {self._internal_input_idx} != {self._internal_output_idx}"
+        ), f"Internal model indices must match {self._internal_input_idx} != {self._internal_output_idx}"

     def _define_tensor_sizes(self, config: DotDict) -> None:
         self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes
12 changes: 8 additions & 4 deletions src/anemoi/models/preprocessing/__init__.py
@@ -8,12 +8,16 @@
 #

 import logging
+from typing import TYPE_CHECKING
 from typing import Optional

 import torch
 from torch import Tensor
 from torch import nn

+if TYPE_CHECKING:
+    from anemoi.models.data_indices.collection import IndexCollection
+
 LOGGER = logging.getLogger(__name__)


@@ -23,19 +27,19 @@ class BasePreprocessor(nn.Module):
     def __init__(
         self,
         config=None,
+        data_indices: Optional[IndexCollection] = None,
         statistics: Optional[dict] = None,
-        data_indices: Optional[dict] = None,
     ) -> None:
         """Initialize the preprocessor.

         Parameters
         ----------
         config : DotDict
-            configuration object
+            configuration object of the processor
+        data_indices : IndexCollection
+            Data indices for input and output variables
         statistics : dict
             Data statistics dictionary
-        data_indices : dict
-            Data indices for input and output variables
         """
         super().__init__()
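A general note on the `TYPE_CHECKING` import pattern used in this file: a name imported only inside `if TYPE_CHECKING:` is not defined at runtime, so an annotation that refers to it has to be a string (or the module has to use postponed evaluation of annotations). The sketch below shows the quoted-annotation variant with a hypothetical `SketchPreprocessor` class; it is an illustration of the pattern, not a statement about how the module in this diff behaves.

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated by static type checkers only; this import never runs at runtime,
    # which avoids import cycles and import-time cost.
    from anemoi.models.data_indices.collection import IndexCollection


class SketchPreprocessor:
    """Hypothetical class illustrating a type-checking-only annotation."""

    def __init__(
        self,
        config=None,
        data_indices: Optional["IndexCollection"] = None,  # quoted: resolved lazily
        statistics: Optional[dict] = None,
    ) -> None:
        self.config = config
        self.data_indices = data_indices
        self.statistics = statistics


processor = SketchPreprocessor(statistics={"mean": [0.0]})
```

Because the annotation is a string, the class definition does not need `IndexCollection` to exist when the module is imported.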
16 changes: 9 additions & 7 deletions src/anemoi/models/preprocessing/imputer.py
@@ -33,16 +33,15 @@ def __init__(
         Parameters
         ----------
         config : DotDict
-            configuration object
+            configuration object of the processor
+        data_indices : IndexCollection
+            Data indices for input and output variables
         statistics : dict
             Data statistics dictionary
-        data_indices : dict
-            Data indices for input and output variables
         """
-        super().__init__(config, statistics, data_indices)
+        super().__init__(config, data_indices, statistics)

         self.nan_locations = None
-        self.data_indices = data_indices

     def _validate_indices(self):
         assert len(self.index_training_input) == len(self.index_inference_input) <= len(self.replacement), (
@@ -174,8 +173,8 @@ class InputImputer(BaseImputer):
     def __init__(
         self,
         config=None,
+        data_indices: Optional[IndexCollection] = None,
         statistics: Optional[dict] = None,
-        data_indices: Optional[dict] = None,
     ) -> None:
         super().__init__(config, data_indices, statistics)

@@ -201,7 +200,10 @@ class ConstantImputer(BaseImputer):
     """

     def __init__(
-        self, config=None, statistics: Optional[dict] = None, data_indices: Optional[IndexCollection] = None
+        self,
+        config=None,
+        data_indices: Optional[IndexCollection] = None,
+        statistics: Optional[dict] = None,
     ) -> None:
         super().__init__(config, data_indices, statistics)
8 changes: 4 additions & 4 deletions src/anemoi/models/preprocessing/normalizer.py
@@ -34,13 +34,13 @@ def __init__(
         Parameters
         ----------
         config : DotDict
-            configuration object
+            configuration object of the processor
+        data_indices : IndexCollection
+            Data indices for input and output variables
         statistics : dict
             Data statistics dictionary
-        data_indices : dict
-            Data indices for input and output variables
         """
-        super().__init__(config, statistics, data_indices)
+        super().__init__(config, data_indices, statistics)

         name_to_index_training_input = self.data_indices.data.input.name_to_index