enh: Device and dtype management by stephane-rivaud · Pull Request #93 · growingnet/gromo · GitHub

enh: Device and dtype management #93


Open · wants to merge 14 commits into base: main
Changes from all commits
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
@@ -20,6 +20,7 @@ Develop branch
Enhancements
~~~~~~~~~~~~

- Add a `to` method to the `GrowingContainer`, `GrowingModule` and `TensorStatistic` classes to move the module to a different device and/or dtype. Propagate device management to all growing modules and containers. (:gh:`93` by `Stéphane Rivaud`_)
- Split Conv2dGrowingModule into two subclasses: `FullConv2dGrowingModule` (that does the same as the previous class) and `RestrictedConv2dGrowingModule` (that computes only the best 1x1 convolution as the second layer at growth time) (:gh:`92` by `Théo Rudkiewicz`_).
- Code factorization of methods `compute_optimal_added_parameters` and `compute_optimal_delta` that are now abstracted in the `GrowingModule` class. (:gh:`87` by `Théo Rudkiewicz`_).
- Stop automatically computing the parameter update in `Conv2dGrowingModule.compute_optimal_added_parameters` to be consistent with `LinearGrowingModule.compute_optimal_added_parameters` (:gh:`87` by `Théo Rudkiewicz`_).
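For context, a minimal usage sketch of the new API (the constructor arguments below are hypothetical and may not match the real `GrowingMLP` signature):

```python
import torch

from gromo.containers.growing_mlp import GrowingMLP

# Hypothetical constructor arguments, for illustration only.
model = GrowingMLP(in_features=784, out_features=10, hidden_size=64, number_hidden_layers=2)

device = "cuda" if torch.cuda.is_available() else "cpu"

# A single call now moves the layers, the growing layers and the associated
# tensor statistics, and keeps model.device in sync.
model = model.to(device=device, dtype=torch.float64)
```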
19 changes: 18 additions & 1 deletion src/gromo/containers/growing_container.py
@@ -2,6 +2,7 @@

from gromo.config.loader import load_config
from gromo.modules.growing_module import GrowingModule, MergeGrowingModule
from gromo.utils.tensor_statistic import TensorStatistic
from gromo.utils.utils import get_correct_device, global_device


@@ -43,7 +44,7 @@ def __init__(
self.in_features = in_features
self.out_features = out_features

self._growing_layers = list()
self._growing_layers = torch.nn.ModuleList()
self.currently_updated_layer_index = None

def set_growing_layers(self):
@@ -136,3 +137,19 @@ def number_of_parameters(self) -> int:
Number of parameters.
"""
return sum(p.numel() for p in self.parameters())

def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
):
"""Move the module to a new device and/or dtype.

Child classes should implement this to handle their specific attributes.

Parameters
----------
device: torch.device | str | None
Target device
dtype: torch.dtype | None
Target dtype
"""
raise NotImplementedError
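The base implementation above only fixes the contract and raises `NotImplementedError`; each concrete container overrides it. A minimal sketch of the override pattern used throughout this PR (hypothetical `MyContainer` subclass with a single `layers` attribute; the import path is assumed from the file layout):

```python
import torch

from gromo.containers.growing_container import GrowingContainer


class MyContainer(GrowingContainer):  # illustration of the override pattern only
    def to(
        self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
    ):
        # Keep the container-level device attribute in sync.
        if device is not None:
            self.device = device
        # Delegate to each child module, which moves its own layers and
        # tensor statistics.
        for layer in self.layers:
            layer.to(device=device, dtype=dtype)
        return self
```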
23 changes: 22 additions & 1 deletion src/gromo/containers/growing_mlp.py
@@ -84,7 +84,7 @@
self.set_growing_layers()

def set_growing_layers(self) -> None:
self._growing_layers = list(self.layers[1:])
self._growing_layers = nn.ModuleList(self.layers[1:])

def forward(self, x: Tensor) -> Tensor:
"""
@@ -205,6 +205,27 @@
), f"{item=} should be in [0, {len(self.layers)})"
return self.layers[item]

def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
):
"""
Move the module to a new device and/or dtype.

Parameters
----------
device: torch.device | str | None
device to move the module to
dtype: torch.dtype | None
dtype to move the module to
"""
if device is not None:
self.device = device

for layer in self.layers:
layer.to(device=device, dtype=dtype)

return self


class Perceptron(GrowingMLP):
def __init__(
70 changes: 66 additions & 4 deletions src/gromo/containers/growing_mlp_mixer.py
@@ -68,7 +68,7 @@
self.set_growing_layers()

def set_growing_layers(self) -> None:
self._growing_layers = [self.second_layer]
self._growing_layers = torch.nn.ModuleList([self.second_layer])

def extended_forward(self, x: Tensor) -> Tensor:
"""
@@ -118,6 +118,18 @@
y = self.dropout(y)
return y

def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
):
"""
Move the module to a new device and/or dtype.
"""
if device is not None:
self.device = device
self.first_layer.to(device=device, dtype=dtype)
self.second_layer.to(device=device, dtype=dtype)
return self

@staticmethod
def tensor_statistics(tensor: Tensor) -> Dict[str, float]:
min_value = tensor.min().item()
@@ -241,6 +253,18 @@
out = y + residual
return out

def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
):
"""
Move the module to a new device and/or dtype.
"""
if device is not None:
self.device = device
self.norm.to(device=device, dtype=dtype)
self.mlp.to(device=device, dtype=dtype)
return self

def weights_statistics(self) -> Dict[int, Dict[str, Any]]:
return self.mlp.weights_statistics()

@@ -322,6 +346,18 @@
out = x + residual
return out

def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
):
"""
Move the module to a new device and/or dtype.
"""
if device is not None:
self.device = device
self.norm.to(device=device, dtype=dtype)
self.mlp.to(device=device, dtype=dtype)
return self

def weights_statistics(self) -> Dict[int, Dict[str, Any]]:
return self.mlp.weights_statistics()

@@ -371,7 +407,7 @@
self.set_growing_layers()

def set_growing_layers(self) -> None:
self._growing_layers = list()
self._growing_layers = torch.nn.ModuleList()
self._growing_layers.extend(self.token_mixer._growing_layers)
self._growing_layers.extend(self.channel_mixer._growing_layers)

@@ -411,6 +447,18 @@
x = self.channel_mixer.extended_forward(x)
return x

def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
):
"""
Move the module to a new device and/or dtype.
"""
if device is not None:
self.device = device
self.token_mixer.to(device=device, dtype=dtype)
self.channel_mixer.to(device=device, dtype=dtype)
return self

def weights_statistics(self) -> Dict[int, Dict[str, Any]]:
statistics = {}
statistics[0] = self.token_mixer.weights_statistics()
@@ -478,8 +526,8 @@
super().__init__(
in_features=torch.tensor(in_features).prod().int().item(),
out_features=out_features,
device=device,
)
self.device = device
self.patcher = nn.Conv2d(
in_channels,
num_features,
self.set_growing_layers()

def set_growing_layers(self) -> None:
self._growing_layers = list()
self._growing_layers = torch.nn.ModuleList()
for mixer in self.mixers:
self._growing_layers.append(mixer.token_mixer.mlp.second_layer)
self._growing_layers.append(mixer.channel_mixer.mlp.second_layer)
@@ -556,6 +604,20 @@
logits = self.classifier(embedding)
return logits

def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
):
"""
Move the module to a new device and/or dtype.
"""
if device is not None:
self.device = device
self.patcher.to(device=device, dtype=dtype)
self.classifier.to(device=device, dtype=dtype)

for mixer in self.mixers:
mixer.to(device=device, dtype=dtype)
return self


def weights_statistics(self) -> Dict[int, Dict[str, Any]]:
statistics = {}
for i, mixer in enumerate(self.mixers):
33 changes: 31 additions & 2 deletions src/gromo/containers/growing_residual_mlp.py
@@ -79,7 +79,7 @@
self.set_growing_layers()

def set_growing_layers(self) -> None:
self._growing_layers = [self.second_layer]
self._growing_layers = torch.nn.ModuleList([self.second_layer])

def extended_forward(self, x: Tensor) -> Tensor:
"""
@@ -129,6 +129,19 @@
x = y + x
return x

def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
):
"""
Move the module to a new device and/or dtype.
"""
if device is not None:
self.device = device
self.norm.to(device=device, dtype=dtype)
self.first_layer.to(device=device, dtype=dtype)
self.second_layer.to(device=device, dtype=dtype)
return self

@staticmethod
def tensor_statistics(tensor: Tensor) -> Dict[str, float]:
min_value = tensor.min().item()
@@ -212,7 +225,9 @@
self.set_growing_layers()

def set_growing_layers(self):
self._growing_layers = list(block.second_layer for block in self.blocks)
self._growing_layers = torch.nn.ModuleList(
block.second_layer for block in self.blocks
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.embedding(x)
@@ -244,6 +259,20 @@
self.currently_updated_layer_index = i
return self.currently_updated_layer_index

def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
):
"""
Move the module to a new device and/or dtype.
"""
if device is not None:
self.device = device
self.embedding.to(device=device, dtype=dtype)
self.projection.to(device=device, dtype=dtype)

for block in self.blocks:
block.to(device=device, dtype=dtype)
return self

@staticmethod
def tensor_statistics(tensor) -> dict[str, float]:
min_value = tensor.min().item()
46 changes: 46 additions & 0 deletions src/gromo/modules/growing_module.py
@@ -578,6 +578,52 @@
"""
return sum(p.numel() for p in self.parameters())

def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
):
"""
Move the module to a new device and/or dtype.

Parameters
----------
device: torch.device | str | None
device to move the module to
dtype: torch.dtype | None
dtype to move the module to
"""
if device is not None:
self.device = device

# Move the pytorch modules
self.layer.to(device=device, dtype=dtype)
self.post_layer_function.to(device=device, dtype=dtype)

if self.optimal_delta_layer is not None:
self.optimal_delta_layer.to(device=device, dtype=dtype)

if self.extended_input_layer is not None:
self.extended_input_layer.to(device=device, dtype=dtype)

if self.extended_output_layer is not None:
self.extended_output_layer.to(device=device, dtype=dtype)

# Move the tensor statistics
self.tensor_s.to(device=device, dtype=dtype)
self.tensor_m.to(device=device, dtype=dtype)
self.tensor_m_prev.to(device=device, dtype=dtype)
self.cross_covariance.to(device=device, dtype=dtype)

if self.s_growth_is_needed:
self.tensor_s_growth.to(device=device, dtype=dtype)

# Move the other attributes
if self.delta_raw is not None:
self.delta_raw.to(device=device, dtype=dtype)

if self.parameter_update_decrease is not None:
self.parameter_update_decrease.to(device=device, dtype=dtype)

if self.eigenvalues_extension is not None:
self.eigenvalues_extension.to(device=device, dtype=dtype)

if self.scaling_factor is not None:
self.scaling_factor.to(device=device, dtype=dtype)

if self._scaling_factor_next_module is not None:
self._scaling_factor_next_module.to(device=device, dtype=dtype)

def __str__(self, verbose=0):
if verbose == 0:
return f"{self.name} module with {self.number_of_parameters()} parameters."
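The `GrowingModule.to` override above is needed because, besides the wrapped `torch.nn` layers, a `GrowingModule` holds state that a plain `torch.nn.Module.to` would not reach: `TensorStatistic` objects and cached tensors such as `delta_raw` or `eigenvalues_extension`. Note that `Tensor.to` is not in-place, so if some of these attributes are plain tensors rather than modules, the returned value may need to be rebound. A small helper (not part of this PR, shown only as a possible factorization of the repeated `if attr is not None` guards) would make that explicit:

```python
def _move_optional(attr, device=None, dtype=None):
    """Move an optional attribute and return the result.

    Returning the moved object lets callers rebind plain tensors, whose
    .to() is out-of-place, while modules are moved in place anyway.
    """
    if attr is None:
        return None
    return attr.to(device=device, dtype=dtype)


# Possible use inside GrowingModule.to (sketch only):
# self.delta_raw = _move_optional(self.delta_raw, device=device, dtype=dtype)
# self.eigenvalues_extension = _move_optional(
#     self.eigenvalues_extension, device=device, dtype=dtype
# )
```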
5 changes: 4 additions & 1 deletion src/gromo/modules/linear_growing_module.py
@@ -383,7 +383,10 @@ def input_extended(self) -> torch.Tensor:
if self.use_bias:
# TODO (optimize this): we could directly store the extended input
return torch.cat(
(self.input, torch.ones(*self.input.shape[:-1], 1, device=self.device)),
(
self.input,
torch.ones(*self.input.shape[:-1], 1, device=self.device),
),
dim=-1,
)
else:
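For reference, the reformatted `torch.cat` above builds a bias-augmented input: a column of ones is appended along the last dimension so that the bias can be handled as one extra weight column. A small self-contained check (shapes are illustrative):

```python
import torch

x = torch.randn(32, 5, 10)  # (batch, ..., in_features)
ones = torch.ones(*x.shape[:-1], 1, device=x.device)
x_extended = torch.cat((x, ones), dim=-1)

assert x_extended.shape == (32, 5, 11)
# A linear layer with bias applied to x is equivalent to a bias-free layer
# applied to x_extended whose last weight column holds the bias.
```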
9 changes: 9 additions & 0 deletions src/gromo/utils/tensor_statistic.py
@@ -72,6 +72,15 @@
def __str__(self):
return f"{self.name} tensor of shape {self._shape} with {self.samples} samples"

def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
):
if device is not None:
self.device = device

if self._tensor is not None:
self._tensor = self._tensor.to(device=device, dtype=dtype)
return self

def update(self, **kwargs):
assert (
not self._shape or self._tensor is not None
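The `TensorStatistic.to` above is deliberately lazy: the target device is recorded immediately, while the backing tensor is only moved if it has already been allocated. A toy analogue of that behaviour (illustration only, not the actual class):

```python
import torch


class LazyStat:
    """Toy analogue of TensorStatistic's lazy device handling."""

    def __init__(self):
        self.device = torch.device("cpu")
        self._tensor = None  # allocated on the first update

    def to(self, device=None, dtype=None):
        if device is not None:
            self.device = device  # remembered for future allocations
        if self._tensor is not None:  # move only if already materialized
            self._tensor = self._tensor.to(device=device, dtype=dtype)
        return self
```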