diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst
index 3fd8956..29f2ce7 100644
--- a/docs/source/whats_new.rst
+++ b/docs/source/whats_new.rst
@@ -20,6 +20,7 @@ Develop branch
 
 Enhancements
 ~~~~~~~~~~~~
+- Add a `to` method to the `GrowingContainer`, `GrowingModule` and `TensorStatistic` classes to move the module to a different device and/or dtype. Propagate device management to all growing modules and containers. (:gh:`93` by `Stéphane Rivaud`_)
 - Split Conv2dGrowingModule into two subclass `FullConv2dGrowingModule`(that does the same as the previous class) and `RestrictedConv2dGrowingModule` (that compute only the best 1x1 convolution as the second layer at growth time) (:gh:`92` by `Théo Rudkiewicz`_).
 - Code factorization of methods `compute_optimal_added_parameters` and `compute_optimal_delta` that are now abstracted in the `GrowingModule` class. (:gh:`87` by `Théo Rudkiewicz`_).
 - Stops automatically computing parameter update in `Conv2dGrowingModule.compute_optimal_added_parameters`to be consistent with `LinearGrowingModule.compute_optimal_added_parameters` (:gh:`87` by `Théo Rudkiewicz`_) .
diff --git a/src/gromo/containers/growing_container.py b/src/gromo/containers/growing_container.py
index e16561a..400435d 100644
--- a/src/gromo/containers/growing_container.py
+++ b/src/gromo/containers/growing_container.py
@@ -2,6 +2,7 @@
 
 from gromo.config.loader import load_config
 from gromo.modules.growing_module import GrowingModule, MergeGrowingModule
+from gromo.utils.tensor_statistic import TensorStatistic
 from gromo.utils.utils import get_correct_device, global_device
 
 
@@ -43,7 +44,7 @@ def __init__(
         self.in_features = in_features
         self.out_features = out_features
 
-        self._growing_layers = list()
+        self._growing_layers = torch.nn.ModuleList()
         self.currently_updated_layer_index = None
 
     def set_growing_layers(self):
@@ -136,3 +137,19 @@ def number_of_parameters(self) -> int:
             Number of parameters.
         """
         return sum(p.numel() for p in self.parameters())
+
+    def to(
+        self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
+    ):
+        """Move the module to a new device and/or dtype.
+
+        Child classes should implement this to handle their specific attributes.
+
+        Parameters
+        ----------
+        device: torch.device | str | None
+            Target device
+        dtype: torch.dtype | None
+            Target dtype
+        """
+        raise NotImplementedError
diff --git a/src/gromo/containers/growing_mlp.py b/src/gromo/containers/growing_mlp.py
index d360caa..161194e 100644
--- a/src/gromo/containers/growing_mlp.py
+++ b/src/gromo/containers/growing_mlp.py
@@ -84,7 +84,7 @@ def __init__(
         self.set_growing_layers()
 
     def set_growing_layers(self) -> None:
-        self._growing_layers = list(self.layers[1:])
+        self._growing_layers = nn.ModuleList(self.layers[1:])
 
     def forward(self, x: Tensor) -> Tensor:
         """
@@ -205,6 +205,27 @@ def __getitem__(self, item: int) -> LinearGrowingModule:
         ), f"{item=} should be in [0, {len(self.layers)})"
         return self.layers[item]
 
+    def to(
+        self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
+    ):
+        """
+        Move the module to a new device and/or dtype.
+
+        Parameters
+        ----------
+        device: torch.device | str | None
+            device to move the module to
+        dtype: torch.dtype | None
+            dtype to move the module to
+        """
+        if device is not None:
+            self.device = device
+
+        for layer in self.layers:
+            layer.to(device=device, dtype=dtype)
+
+        return self
+
 
 class Perceptron(GrowingMLP):
     def __init__(
diff --git a/src/gromo/containers/growing_mlp_mixer.py b/src/gromo/containers/growing_mlp_mixer.py
index 5011ac8..2af767f 100644
--- a/src/gromo/containers/growing_mlp_mixer.py
+++ b/src/gromo/containers/growing_mlp_mixer.py
@@ -68,7 +68,7 @@ def __init__(
         self.set_growing_layers()
 
     def set_growing_layers(self) -> None:
-        self._growing_layers = [self.second_layer]
+        self._growing_layers = torch.nn.ModuleList([self.second_layer])
 
     def extended_forward(self, x: Tensor) -> Tensor:
         """
@@ -118,6 +118,18 @@ def forward(self, x: Tensor) -> Tensor:
         y = self.dropout(y)
         return y
 
+    def to(
+        self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
+    ):
+        """
+        Move the module to a new device and/or dtype.
+        """
+        if device is not None:
+            self.device = device
+        self.first_layer.to(device=device, dtype=dtype)
+        self.second_layer.to(device=device, dtype=dtype)
+        return self
+
     @staticmethod
     def tensor_statistics(tensor: Tensor) -> Dict[str, float]:
         min_value = tensor.min().item()
@@ -241,6 +253,18 @@ def extended_forward(self, x: Tensor) -> Tensor:
         out = y + residual
         return out
 
+    def to(
+        self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
+    ):
+        """
+        Move the module to a new device and/or dtype.
+        """
+        if device is not None:
+            self.device = device
+        self.norm.to(device=device, dtype=dtype)
+        self.mlp.to(device=device, dtype=dtype)
+        return self
+
     def weights_statistics(self) -> Dict[int, Dict[str, Any]]:
         return self.mlp.weights_statistics()
 
@@ -322,6 +346,18 @@ def extended_forward(self, x: Tensor) -> Tensor:
         out = x + residual
         return out
 
+    def to(
+        self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
+    ):
+        """
+        Move the module to a new device and/or dtype.
+        """
+        if device is not None:
+            self.device = device
+        self.norm.to(device=device, dtype=dtype)
+        self.mlp.to(device=device, dtype=dtype)
+        return self
+
     def weights_statistics(self) -> Dict[int, Dict[str, Any]]:
         return self.mlp.weights_statistics()
 
@@ -371,7 +407,7 @@ def __init__(
         self.set_growing_layers()
 
     def set_growing_layers(self) -> None:
-        self._growing_layers = list()
+        self._growing_layers = torch.nn.ModuleList()
         self._growing_layers.extend(self.token_mixer._growing_layers)
         self._growing_layers.extend(self.channel_mixer._growing_layers)
 
@@ -411,6 +447,18 @@ def extended_forward(self, x: Tensor) -> Tensor:
         x = self.channel_mixer.extended_forward(x)
         return x
 
+    def to(
+        self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
+    ):
+        """
+        Move the module to a new device and/or dtype.
+ """ + if device is not None: + self.device = device + self.token_mixer.to(device=device, dtype=dtype) + self.channel_mixer.to(device=device, dtype=dtype) + return self + def weights_statistics(self) -> Dict[int, Dict[str, Any]]: statistics = {} statistics[0] = self.token_mixer.weights_statistics() @@ -478,8 +526,8 @@ def __init__( super().__init__( in_features=torch.tensor(in_features).prod().int().item(), out_features=out_features, + device=device, ) - self.device = device self.patcher = nn.Conv2d( in_channels, num_features, @@ -503,7 +551,7 @@ def __init__( self.set_growing_layers() def set_growing_layers(self) -> None: - self._growing_layers = list() + self._growing_layers = torch.nn.ModuleList() for mixer in self.mixers: self._growing_layers.append(mixer.token_mixer.mlp.second_layer) self._growing_layers.append(mixer.channel_mixer.mlp.second_layer) @@ -556,6 +604,20 @@ def extended_forward(self, x: Tensor) -> Tensor: logits = self.classifier(embedding) return logits + def to( + self, device: torch.device | str | None = None, dtype: torch.dtype | None = None + ): + """ + Move the module to a new device and/or dtype. + """ + if device is not None: + self.device = device + self.patcher.to(device=device, dtype=dtype) + self.classifier.to(device=device, dtype=dtype) + for mixer in self.mixers: + mixer.to(device=device, dtype=dtype) + return self + def weights_statistics(self) -> Dict[int, Dict[str, Any]]: statistics = {} for i, mixer in enumerate(self.mixers): diff --git a/src/gromo/containers/growing_residual_mlp.py b/src/gromo/containers/growing_residual_mlp.py index bf90b43..3b85368 100644 --- a/src/gromo/containers/growing_residual_mlp.py +++ b/src/gromo/containers/growing_residual_mlp.py @@ -79,7 +79,7 @@ def __init__( self.set_growing_layers() def set_growing_layers(self) -> None: - self._growing_layers = [self.second_layer] + self._growing_layers = torch.nn.ModuleList([self.second_layer]) def extended_forward(self, x: Tensor) -> Tensor: """ @@ -129,6 +129,19 @@ def forward(self, x: Tensor) -> Tensor: x = y + x return x + def to( + self, device: torch.device | str | None = None, dtype: torch.dtype | None = None + ): + """ + Move the module to a new device and/or dtype. + """ + if device is not None: + self.device = device + self.norm.to(device=device, dtype=dtype) + self.first_layer.to(device=device, dtype=dtype) + self.second_layer.to(device=device, dtype=dtype) + return self + @staticmethod def tensor_statistics(tensor: Tensor) -> Dict[str, float]: min_value = tensor.min().item() @@ -212,7 +225,9 @@ def __init__( self.set_growing_layers() def set_growing_layers(self): - self._growing_layers = list(block.second_layer for block in self.blocks) + self._growing_layers = torch.nn.ModuleList( + block.second_layer for block in self.blocks + ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.embedding(x) @@ -244,6 +259,20 @@ def select_update(self, layer_index: int, verbose: bool = False) -> int: self.currently_updated_layer_index = i return self.currently_updated_layer_index + def to( + self, device: torch.device | str | None = None, dtype: torch.dtype | None = None + ): + """ + Move the module to a new device and/or dtype. 
+ """ + if device is not None: + self.device = device + self.embedding.to(device=device, dtype=dtype) + self.projection.to(device=device, dtype=dtype) + for block in self.blocks: + block.to(device=device, dtype=dtype) + return self + @staticmethod def tensor_statistics(tensor) -> dict[str, float]: min_value = tensor.min().item() diff --git a/src/gromo/modules/growing_module.py b/src/gromo/modules/growing_module.py index 1ccb850..5c630b7 100644 --- a/src/gromo/modules/growing_module.py +++ b/src/gromo/modules/growing_module.py @@ -578,6 +578,52 @@ def number_of_parameters(self) -> int: """ return sum(p.numel() for p in self.parameters()) + def to( + self, device: torch.device | str | None = None, dtype: torch.dtype | None = None + ): + """ + Move the module to a new device and/or dtype. + + Parameters + ---------- + device: torch.device | str | None + device to move the module to + dtype: torch.dtype | None + dtype to move the module to + """ + if device is not None: + self.device = device + + # Move the pytorch modules + self.layer.to(device=device, dtype=dtype) + self.post_layer_function.to(device=device, dtype=dtype) + if self.optimal_delta_layer is not None: + self.optimal_delta_layer.to(device=device, dtype=dtype) + if self.extended_input_layer is not None: + self.extended_input_layer.to(device=device, dtype=dtype) + if self.extended_output_layer is not None: + self.extended_output_layer.to(device=device, dtype=dtype) + + # Move the tensor statistics + self.tensor_s.to(device=device, dtype=dtype) + self.tensor_m.to(device=device, dtype=dtype) + self.tensor_m_prev.to(device=device, dtype=dtype) + self.cross_covariance.to(device=device, dtype=dtype) + if self.s_growth_is_needed: + self.tensor_s_growth.to(device=device, dtype=dtype) + + # Move the other attributes + if self.delta_raw is not None: + self.delta_raw.to(device=device, dtype=dtype) + if self.parameter_update_decrease is not None: + self.parameter_update_decrease.to(device=device, dtype=dtype) + if self.eigenvalues_extension is not None: + self.eigenvalues_extension.to(device=device, dtype=dtype) + if self.scaling_factor is not None: + self.scaling_factor.to(device=device, dtype=dtype) + if self._scaling_factor_next_module is not None: + self._scaling_factor_next_module.to(device=device, dtype=dtype) + def __str__(self, verbose=0): if verbose == 0: return f"{self.name} module with {self.number_of_parameters()} parameters." 
diff --git a/src/gromo/modules/linear_growing_module.py b/src/gromo/modules/linear_growing_module.py
index 0feda86..c05fe2b 100644
--- a/src/gromo/modules/linear_growing_module.py
+++ b/src/gromo/modules/linear_growing_module.py
@@ -383,7 +383,10 @@ def input_extended(self) -> torch.Tensor:
         if self.use_bias:
             # TODO (optimize this): we could directly store the extended input
             return torch.cat(
-                (self.input, torch.ones(*self.input.shape[:-1], 1, device=self.device)),
+                (
+                    self.input,
+                    torch.ones(*self.input.shape[:-1], 1, device=self.device),
+                ),
                 dim=-1,
             )
         else:
diff --git a/src/gromo/utils/tensor_statistic.py b/src/gromo/utils/tensor_statistic.py
index 72fed6b..7da775a 100644
--- a/src/gromo/utils/tensor_statistic.py
+++ b/src/gromo/utils/tensor_statistic.py
@@ -72,6 +72,16 @@ def __init__(
     def __str__(self):
         return f"{self.name} tensor of shape {self._shape} with {self.samples} samples"
 
+    def to(
+        self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
+    ):
+        """Move the stored tensor to a new device and/or dtype."""
+        if device is not None:
+            self.device = device
+        if self._tensor is not None:
+            self._tensor = self._tensor.to(device=device, dtype=dtype)
+        return self
+
     def update(self, **kwargs):
         assert (
             not self._shape or self._tensor is not None
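All the container-level `to` overrides in this patch follow the same shape: record the target device, delegate to each child (submodules, growing modules, tensor statistics), and return `self` so calls can be chained. A minimal toy mirror of that pattern, with hypothetical classes that are not part of gromo:

    import torch
    from torch import nn


    class ToyStatistic:
        # Mirrors the TensorStatistic.to behaviour added in this patch:
        # remember the device and move the stored tensor if it exists.
        def __init__(self) -> None:
            self.device: torch.device | str = torch.device("cpu")
            self._tensor: torch.Tensor | None = None

        def to(self, device=None, dtype=None):
            if device is not None:
                self.device = device
            if self._tensor is not None:
                self._tensor = self._tensor.to(device=device, dtype=dtype)
            return self


    class ToyContainer(nn.Module):
        # Mirrors the container pattern: delegate to children, return self.
        def __init__(self) -> None:
            super().__init__()
            self.device: torch.device | str = torch.device("cpu")
            self.layer = nn.Linear(8, 8)
            self.statistic = ToyStatistic()

        def to(self, device=None, dtype=None):
            if device is not None:
                self.device = device
            self.layer.to(device=device, dtype=dtype)
            self.statistic.to(device=device, dtype=dtype)
            return self


    model = ToyContainer().to(device="cpu", dtype=torch.float64)
    assert model.layer.weight.dtype == torch.float64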