enh: Device and dtype management #93
base: main
…Module, and TensorStatistic classes.
…, GrowingModule, and TensorStatistic classes
…ner, GrowingModule, and TensorStatistic classes; refactor growing layers initialization to use ModuleList in GrowingMLPBlock, GrowingMixerLayer, and GrowingResidualMLP.
Pull Request Overview
This PR refactors device and dtype management by introducing consistent “to” methods and updating growing layer collections in several modules. Key changes include adding a “to” method to TensorStatistic and GrowingModule, converting plain Python lists to torch.nn.ModuleList in container classes, and adjusting the device selection when creating additional tensors.
Reviewed Changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
src/gromo/utils/tensor_statistic.py | Added a new “to” method to update tensor and device attributes. |
src/gromo/modules/linear_growing_module.py | Updated inner tensor creation to use the device of the input tensor. |
src/gromo/modules/growing_module.py | Added a “to” method to propagate device/dtype updates across module attributes. |
src/gromo/containers/growing_residual_mlp.py, growing_mlp_mixer.py, growing_mlp.py, growing_container.py | Replaced plain list constructs with torch.nn.ModuleList to improve parameter registration and maintain consistency across growing layers. |
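The "parameter registration" point in the last row is the core reason for the ModuleList switch. The minimal sketch below uses two hypothetical containers (not gromo classes) to show that layers kept only in a plain Python list are invisible to `nn.Module`. They are not returned by `parameters()` and are not converted by `.to()`:

```python
import torch
from torch import nn


class ListContainer(nn.Module):
    """Hypothetical container that stores its layers in a plain Python list."""

    def __init__(self) -> None:
        super().__init__()
        # A plain list is not inspected by nn.Module: these layers are NOT registered.
        self.layers = [nn.Linear(4, 4), nn.Linear(4, 2)]


class ModuleListContainer(nn.Module):
    """Hypothetical container that stores the same layers in an nn.ModuleList."""

    def __init__(self) -> None:
        super().__init__()
        # nn.ModuleList registers every layer as a submodule.
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 2)])


plain = ListContainer()
registered = ModuleListContainer()

print(len(list(plain.parameters())))       # 0: nothing registered
print(len(list(registered.parameters())))  # 4: two weights and two biases

# .to() only converts registered parameters, so the plain-list layers stay float32.
plain.to(dtype=torch.float64)
registered.to(dtype=torch.float64)
print(plain.layers[0].weight.dtype)        # torch.float32
print(registered.layers[0].weight.dtype)   # torch.float64
```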
Files not reviewed (1)
- docs/source/whats_new.rst: Language not supported
Comments suppressed due to low confidence (1)
src/gromo/modules/linear_growing_module.py:388
- [nitpick] Ensure that switching from using self.device to self.input.device for creating the ones-tensor is intentional and consistent with the overall device management strategy.
torch.ones(*self.input.shape[:-1], 1, device=self.input.device)
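The nitpick concerns where the column of ones used for the bias term is allocated. The sketch below uses a hypothetical helper, `extend_with_ones`, that builds a bias-extended input. It takes both device and dtype from the input tensor itself; passing `dtype` is an addition not shown in the PR. With this approach, `torch.cat` always receives tensors on the same device, whatever a cached module-level device attribute says:

```python
import torch


def extend_with_ones(x: torch.Tensor) -> torch.Tensor:
    """Append a column of ones to the last dimension of x (hypothetical helper).

    The ones tensor inherits x's device and dtype, so the concatenation never
    mixes devices, even if some cached module-level device attribute is stale.
    """
    ones = torch.ones(*x.shape[:-1], 1, device=x.device, dtype=x.dtype)
    return torch.cat((x, ones), dim=-1)


x = torch.randn(8, 3, 5, dtype=torch.float64)
x_ext = extend_with_ones(x)
print(x_ext.shape)  # torch.Size([8, 3, 6])
print(x_ext.dtype)  # torch.float64
```

`torch.ones_like(x[..., :1])` is an equivalent shorthand, because `ones_like` copies the shape, device, and dtype of its argument.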
src/gromo/containers/growing_mlp.py
Outdated
@@ -84,7 +84,7 @@ def __init__(
         self.set_growing_layers()

     def set_growing_layers(self) -> None:
-        self._growing_layers = list(self.layers[1:])
+        self._growing_layers = self.layers[1:]
Consider wrapping the sliced growing layers in torch.nn.ModuleList rather than using a plain list to ensure proper parameter registration.
Suggested change:
- self._growing_layers = self.layers[1:]
+ self._growing_layers = nn.ModuleList(self.layers[1:])
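Whether this wrapper is needed depends on the type of `self.layers`, which the diff does not show. Assuming it is already an `nn.ModuleList`, slicing it returns a new `nn.ModuleList` that holds the same module objects, so `self.layers[1:]` is registered correctly without extra wrapping. The `Net` class below is a hypothetical container that mirrors the diff:

```python
from torch import nn

layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 2)])

# Slicing an nn.ModuleList returns a new nn.ModuleList holding the same modules.
tail = layers[1:]
print(type(tail).__name__)   # ModuleList
print(tail[0] is layers[1])  # True: the slice shares modules, it does not copy them


class Net(nn.Module):
    """Hypothetical container mirroring the diff: layers plus a growing-layer view."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 2)])
        # Registers the same modules a second time under another attribute name.
        self._growing_layers = self.layers[1:]


net = Net()
# parameters() skips duplicates, so shared layers count once: 3 weights + 3 biases.
print(len(list(net.parameters())))  # 6
```

One side effect is worth checking: unlike `parameters()`, `state_dict()` does not deduplicate shared submodules, so the shared layers appear under both the `layers.` and `_growing_layers.` prefixes.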
Pull Request Overview
This PR implements improved device and dtype management across GrowingContainer, GrowingModule, and tensor statistics classes. Key changes include adding to() methods for moving modules and tensors to a specified device/dtype, and converting growing layers to torch.nn.ModuleList in several containers for consistent behavior.
Reviewed Changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 1 comment.
File | Description |
---|---|
src/gromo/utils/tensor_statistic.py | Added a to() method to update device and dtype for tensors |
src/gromo/modules/linear_growing_module.py | Updated bias creation to use self.input.device instead of self.device |
src/gromo/modules/growing_module.py | Added a to() method for the module and its attributes |
src/gromo/containers/growing_residual_mlp.py | Converted _growing_layers to torch.nn.ModuleList for consistency |
src/gromo/containers/growing_mlp_mixer.py | Converted _growing_layers and aggregated layers to torch.nn.ModuleList |
src/gromo/containers/growing_mlp.py | Removed list() wrapping for _growing_layers |
src/gromo/containers/growing_container.py | Added a to() method that iterates through layers to set device/dtype |
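A statistic object that is not an `nn.Module` is ignored by `nn.Module.to()`, so it needs its own `to()`. The actual `TensorStatistic` code is not shown here. The sketch below uses a hypothetical `RunningStatistic` class that stores an accumulated tensor and a cached `device` attribute, and its `to()` keeps the two in sync:

```python
from __future__ import annotations

import torch


class RunningStatistic:
    """Hypothetical stand-in for a tensor statistic that is not an nn.Module."""

    def __init__(self, device: str | None = None) -> None:
        self.device = torch.device(device if device is not None else "cpu")
        self._tensor: torch.Tensor | None = None
        self.samples = 0

    def update(self, value: torch.Tensor, n: int) -> None:
        # Accumulate on the statistic's current device.
        value = value.to(self.device)
        self._tensor = value if self._tensor is None else self._tensor + value
        self.samples += n

    def __call__(self) -> torch.Tensor:
        assert self._tensor is not None, "statistic has not been updated yet"
        return self._tensor / self.samples

    def to(self, device: str | None = None, dtype: torch.dtype | None = None) -> RunningStatistic:
        # Move the stored tensor and keep the cached device attribute in sync.
        if device is not None:
            self.device = torch.device(device)
        if self._tensor is not None:
            self._tensor = self._tensor.to(device=device, dtype=dtype)
        return self


stat = RunningStatistic()
stat.update(torch.ones(3), n=2)
stat.to(dtype=torch.float64)
print(stat().dtype)  # torch.float64
```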
Files not reviewed (1)
- docs/source/whats_new.rst: Language not supported
Comments suppressed due to low confidence (1)
src/gromo/containers/growing_mlp.py:87
- [nitpick] For consistency with other growing containers that use torch.nn.ModuleList for _growing_layers, consider converting self.layers[1:] to a ModuleList, e.g. torch.nn.ModuleList(self.layers[1:]).
self._growing_layers = self.layers[1:]
-(self.input, torch.ones(*self.input.shape[:-1], 1, device=self.device)),
+(
+    self.input,
+    torch.ones(*self.input.shape[:-1], 1, device=self.input.device),
The change from using self.device to self.input.device could lead to inconsistent device handling if self.input.device is not kept in sync with the module's device. Consider using self.device for bias creation to maintain a unified device management strategy.
Suggested change:
- torch.ones(*self.input.shape[:-1], 1, device=self.input.device),
+ torch.ones(*self.input.shape[:-1], 1, device=self.device),
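The concern here is that a cached `self.device` can drift away from where the weights actually live. If the module owns at least one parameter, one way to sidestep the issue is to derive the device from that parameter instead of caching it. The sketch below is a hypothetical layer, not the gromo implementation:

```python
import torch
from torch import nn


class TinyGrowingLayer(nn.Module):
    """Hypothetical layer whose device is always read from its weights."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.layer = nn.Linear(in_features, out_features)

    @property
    def device(self) -> torch.device:
        # Derived, not cached: this stays correct after .to(), .cuda(), etc.
        return self.layer.weight.device

    def extended_input(self, x: torch.Tensor) -> torch.Tensor:
        # Bias column allocated on the module's (derived) device.
        ones = torch.ones(*x.shape[:-1], 1, device=self.device, dtype=x.dtype)
        return torch.cat((x, ones), dim=-1)


layer = TinyGrowingLayer(3, 2)
print(layer.device)  # cpu
```

This approach assumes a parameter always exists. It also means the bias column follows the module, so an input on the wrong device fails loudly in `torch.cat` instead of being silently accepted.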
…uleList for improved layer management
…MLP, GrowingResidualMLP, and various GrowingContainer subclasses; remove device handling from GrowingContainer.
Pull Request Overview
This pull request improves device and dtype management by implementing new to() methods across various modules and containers. Key changes include updating tensor statistic handling, converting growing layers to torch.nn.ModuleList, and adding device/dtype transfer methods.
Reviewed Changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
src/gromo/utils/tensor_statistic.py | Added a to() method for transferring tensor to a new device/dtype |
src/gromo/modules/linear_growing_module.py | Reformatted torch.cat call with multi-line tuple syntax |
src/gromo/modules/growing_module.py | Implemented a to() method to move module sub-components to a new device/dtype |
src/gromo/containers/growing_residual_mlp.py | Updated growing layers to a ModuleList and added to() methods for device management |
src/gromo/containers/growing_mlp_mixer.py | Updated growing layers to a ModuleList and added to() methods for device management |
src/gromo/containers/growing_mlp.py | Converted growing layers to nn.ModuleList and added to() method for device/dtype handling |
src/gromo/containers/growing_container.py | Converted growing layers to ModuleList and introduced a placeholder to() method |
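A custom `to()` on a module matters because `nn.Module.to()` converts only registered parameters and buffers. A tensor stored as a plain attribute, such as a cached input, is left where it was. The sketch below uses a hypothetical `ModuleWithCache` class with a narrow `(device, dtype)` signature to show the pattern of calling the base implementation and then moving the unregistered tensors by hand:

```python
import torch
from torch import nn


class ModuleWithCache(nn.Module):
    """Hypothetical module holding a plain (unregistered) tensor attribute."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(3, 2)
        # Plain tensor attribute: nn.Module.to() does not touch it.
        self.stored_input = torch.zeros(4, 3)

    def to(self, device=None, dtype=None):
        # Registered parameters and buffers are handled by nn.Module.to ...
        super().to(device=device, dtype=dtype)
        # ... but unregistered tensors must be moved explicitly.
        self.stored_input = self.stored_input.to(device=device, dtype=dtype)
        return self


m = ModuleWithCache().to(dtype=torch.float64)
print(m.layer.weight.dtype)  # torch.float64
print(m.stored_input.dtype)  # torch.float64
```

Two caveats apply to this design. First, the narrower signature no longer accepts `nn.Module.to`'s positional dtype form, such as `m.to(torch.float64)`. Second, `.cuda()`, `.double()`, and a parent module's own `.to()` all go through `nn.Module._apply` and never call this override. Registering the tensor with `register_buffer(..., persistent=False)` avoids both caveats when that fits the design.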
Files not reviewed (1)
- docs/source/whats_new.rst: Language not supported
…arameters and implementation requirements for child classes.
… for consistent behavior
Pull Request Overview
This PR improves device and dtype management by replacing native PyTorch calls with a unified .to() method across various modules and containers within the gromo library. The changes update the GrowingContainer, GrowingModule, and TensorStatistic classes, and also refactor layer storage to use torch.nn.ModuleList.
Reviewed Changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated no comments.
File | Description |
---|---|
src/gromo/utils/tensor_statistic.py | Added a .to() method to update device and dtype of tensor statistics. |
src/gromo/modules/linear_growing_module.py | Reformatted tensor concatenation in input_extended for clarity. |
src/gromo/modules/growing_module.py | Introduced a comprehensive .to() method to move modules and attributes. |
src/gromo/containers/growing_residual_mlp.py | Replaced list with ModuleList and added .to() methods for layers. |
src/gromo/containers/growing_mlp_mixer.py | Replaced list with ModuleList and introduced multiple .to() methods. |
src/gromo/containers/growing_mlp.py | Updated layer storage using nn.ModuleList and added a .to() method. |
src/gromo/containers/growing_container.py | Added an abstract .to() method raising NotImplementedError. |
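The abstract `.to()` pattern in the last row can be sketched as follows, using hypothetical `BaseContainer` and `TinyMLP` classes. The base class refuses to guess, and each child forwards the call to its layers one by one. This forwarding matters if the layers define their own `to()`, because a parent's `nn.Module.to()` would bypass them by going through `_apply`:

```python
import torch
from torch import nn


class BaseContainer(nn.Module):
    """Hypothetical base class: subclasses must define how to move their parts."""

    def to(self, device=None, dtype=None):
        raise NotImplementedError("subclasses must implement to()")


class TinyMLP(BaseContainer):
    """Hypothetical child container that delegates to each layer's own to()."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(3, 4), nn.Linear(4, 2)])

    def to(self, device=None, dtype=None):
        # Calling layer.to() (not nn.Module._apply) lets any custom
        # per-layer to() move its extra, unregistered state as well.
        for layer in self.layers:
            layer.to(device=device, dtype=dtype)
        return self


mlp = TinyMLP().to(dtype=torch.float64)
print(mlp.layers[0].weight.dtype)  # torch.float64
```

Raising in the base class turns a forgotten override into a loud error instead of a silently half-moved model. The trade-off is that a subclass without an override can no longer use the stock `nn.Module.to`.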
Files not reviewed (1)
- docs/source/whats_new.rst: Language not supported
Comments suppressed due to low confidence (1)
src/gromo/containers/growing_mlp_mixer.py:121
- Multiple definitions of the 'to' method are present in this file, which may lead to unexpected behavior. Consider consolidating these implementations into a single method to ensure consistent device and dtype management.
def to(
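The risk flagged here is a plain Python rule. If a class body defines the same method name twice, the second definition silently replaces the first. The same name in distinct classes is fine. A minimal illustration, using hypothetical classes:

```python
class Example:
    def to(self):
        return "first"

    def to(self):  # silently replaces the definition above
        return "second"


print(Example().to())  # second


class A:
    def to(self):
        return "A"


class B:
    def to(self):
        return "B"


print(A().to(), B().to())  # A B
```

Linters such as pyflakes report the first case as a redefinition of an unused name (F811), so running one in CI would catch an accidental duplicate.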
…okenMixer, GrowingChannelMixer, and GrowingMixerLayer for improved device management
…GrowingResidualMLP for improved device management
Pull Request Overview
This pull request enhances device and dtype handling by introducing custom .to methods across several module, container, and utility classes while also standardizing how growing layers are managed. Key changes include:
- Adding new .to methods to GrowingModule, GrowingMLP, GrowingMLPMixer, and TensorStatistic classes for consistent device/dtype conversion.
- Replacing plain lists with torch.nn.ModuleList for managing growing layers.
- Updating function signatures and docstrings to reflect the new device and dtype handling.
Reviewed Changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 1 comment.
File | Description |
---|---|
src/gromo/utils/tensor_statistic.py | Added a .to method that updates the tensor and device attributes. |
src/gromo/modules/linear_growing_module.py | Reformatted torch.cat arguments for clarity. |
src/gromo/modules/growing_module.py | Introduced a comprehensive .to method that moves module components and attributes. |
src/gromo/containers/growing_residual_mlp.py | Replaced list with torch.nn.ModuleList and added relevant .to functionality. |
src/gromo/containers/growing_mlp_mixer.py | Added several .to methods; some formatting changes to ensure device/dtype propagation. |
src/gromo/containers/growing_mlp.py | Updated growing layer initialization and added a .to method iterating over layers. |
src/gromo/containers/growing_container.py | Changed growing layer initialization and introduced an abstract .to method. |
Files not reviewed (1)
- docs/source/whats_new.rst: Language not supported
Comments suppressed due to low confidence (2)
src/gromo/modules/growing_module.py:581
- Ensure that the new device and dtype handling in the .to method is covered by tests, particularly for the movement of non-standard attributes like optimal_delta_layer and tensor statistics.
def to(
src/gromo/containers/growing_mlp_mixer.py:349
- Multiple definitions of the 'to' method are present in this file; please verify that each one belongs to a distinct class to prevent unintended overrides.
def to(
Using keyword arguments (i.e. device=device, dtype=dtype) when calling .to on PyTorch modules for consistency. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
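The keyword form from this commit works directly with `nn.Module.to`. The sketch below applies it to an ordinary `nn.Sequential` and also shows a related detail: the `dtype` argument is applied only to floating-point (and complex) parameters and buffers. Integer buffers such as BatchNorm's `num_batches_tracked` keep their dtype:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(3, 2), nn.BatchNorm1d(2))
model.to(device="cpu", dtype=torch.float64)

print(model[0].weight.dtype)               # torch.float64
print(model[1].running_mean.dtype)         # torch.float64
print(model[1].num_batches_tracked.dtype)  # torch.int64: integer buffers are not cast
```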
Pull Request Overview
This PR improves device and dtype management across the growing network components by adding uniform “to” methods and updating layer collections. Key changes include:
- Adding a “to” method in multiple classes (e.g. GrowingModule, GrowingContainer, TensorStatistic) to support moving modules and tensors to a specified device/dtype.
- Replacing plain Python lists with torch.nn.ModuleList for growing layers to ensure proper module registration.
- Updating device handling in various container and module classes for a cleaner and more scalable implementation.
Reviewed Changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated no comments.
File | Description |
---|---|
src/gromo/utils/tensor_statistic.py | Added a “to” method to update tensor device/dtype. |
src/gromo/modules/linear_growing_module.py | Reformatted the input extension and maintained device handling in the “to” method. |
src/gromo/modules/growing_module.py | Introduced a “to” method to propagate device/dtype changes to submodules. |
src/gromo/containers/growing_residual_mlp.py | Updated growing layers to torch.nn.ModuleList and added a “to” method. |
src/gromo/containers/growing_mlp_mixer.py | Added several “to” methods across components for device and dtype propagation. |
src/gromo/containers/growing_mlp.py | Converted layer collections to nn.ModuleList and added a “to” method. |
src/gromo/containers/growing_container.py | Updated growing layers and defined an abstract “to” method for child classes. |
Files not reviewed (1)
- docs/source/whats_new.rst: Language not supported
Code coverage is too low for merging, but I am not sure we can actually write tests for device management. Do GitHub runners have access to a GPU or an MPS (macOS) device, so that I can implement proper tests?
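This sketch makes no assumption about which accelerators CI runners provide. One common pattern is to test dtype conversion on CPU, which runs anywhere and exercises the same `to()` code paths, and to skip device-specific tests when the device is absent. The sketch uses the standard library's `unittest`. `make_model` is a hypothetical placeholder for the gromo container under test:

```python
import unittest

import torch
from torch import nn


def make_model() -> nn.Module:
    # Hypothetical stand-in; replace with the growing container under test.
    return nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))


def check_moved(model: nn.Module, device: str, dtype: torch.dtype) -> None:
    # Every parameter must live on the target device with the target dtype.
    for p in model.parameters():
        assert p.device.type == device, p.device
        assert p.dtype == dtype, p.dtype
    # A forward pass confirms nothing was left behind on another device.
    out = model(torch.randn(5, 3, device=device, dtype=dtype))
    assert out.device.type == device and out.dtype == dtype


class TestDeviceAndDtype(unittest.TestCase):
    def test_dtype_on_cpu(self) -> None:
        # Runs on any CI runner: dtype conversion needs no accelerator.
        model = make_model().to(device="cpu", dtype=torch.float64)
        check_moved(model, "cpu", torch.float64)

    @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available")
    def test_cuda(self) -> None:
        check_moved(make_model().to(device="cuda"), "cuda", torch.float32)

    @unittest.skipUnless(torch.backends.mps.is_available(), "MPS not available")
    def test_mps(self) -> None:
        check_moved(make_model().to(device="mps"), "mps", torch.float32)
```

Run the tests with `python -m unittest`. On a CPU-only runner, the CUDA and MPS cases are reported as skipped rather than failed.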
Implement device and dtype handling in the GrowingContainer, GrowingModule, and TensorStatistic classes. Before these modifications, device and dtype were handled with native PyTorch functions. This pull request provides a better implementation.