Add filename components in Checkpoint #2498
Merged: vfdev-5 merged 10 commits into pytorch:master from sdesrozis:checkpoint_load_from_filename on Mar 9, 2022
Changes from all commits (10):
a9b282b  add filename components in Checkpoint
d96fc1b  remove debug print
9d5304f  add doc
71d59b9  add Path to valid type
e53e93b  handle Path as a valid type
f8f3568  fix mypy
8cd20a5  remove mypy (error with win32 only)
fb79a42  add mandatory empty lines for doctest
e83533e  rename private function
2278d9f  Merge branch 'master' into checkpoint_load_from_filename
@@ -327,6 +327,18 @@ def __init__(
         self.include_self = include_self
         self.greater_or_equal = greater_or_equal

+    def _get_filename_pattern(self, global_step: Optional[int]) -> str:
+        if self.filename_pattern is None:
+            filename_pattern = self.setup_filename_pattern(
+                with_prefix=len(self.filename_prefix) > 0,
+                with_score=self.score_function is not None,
+                with_score_name=self.score_name is not None,
+                with_global_step=global_step is not None,
+            )
+        else:
+            filename_pattern = self.filename_pattern
+        return filename_pattern
+
     def reset(self) -> None:
         """Method to reset saved checkpoint names.

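The new helper centralizes the pattern selection that `__call__` previously did inline. For context, a quick sketch of what `setup_filename_pattern` produces; the printed pattern is illustrative and the exact string is defined by ignite, so it may vary by version:

    from ignite.handlers import Checkpoint

    # Default pattern when prefix, score, score name and global step are all present.
    pattern = Checkpoint.setup_filename_pattern(
        with_prefix=True, with_score=True, with_score_name=True, with_global_step=True
    )
    print(pattern)
    # expected along the lines of:
    # {filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}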
@@ -402,15 +414,7 @@ def __call__(self, engine: Engine) -> None:
                 name = k
             checkpoint = checkpoint[name]

-        if self.filename_pattern is None:
-            filename_pattern = self.setup_filename_pattern(
-                with_prefix=len(self.filename_prefix) > 0,
-                with_score=self.score_function is not None,
-                with_score_name=self.score_name is not None,
-                with_global_step=global_step is not None,
-            )
-        else:
-            filename_pattern = self.filename_pattern
+        filename_pattern = self._get_filename_pattern(global_step)

         filename_dict = {
             "filename_prefix": self.filename_prefix,
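After this refactor, `__call__` still renders the checkpoint file name by formatting the selected pattern with a `filename_dict`. A minimal sketch of that mechanism, with a hypothetical pattern and values rather than ignite's actual defaults:

    # Sketch of the pattern/format mechanism used in __call__
    # (pattern and values are hypothetical; keys unused by the pattern are ignored).
    filename_pattern = "{filename_prefix}_{name}_{global_step}.{ext}"
    filename_dict = {
        "filename_prefix": "myprefix",
        "name": "checkpoint",
        "score_name": None,   # ignored by this pattern
        "score": None,        # ignored by this pattern
        "global_step": 40,
        "ext": "pt",
    }
    print(filename_pattern.format(**filename_dict))  # -> myprefix_checkpoint_40.pt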
@@ -519,41 +523,51 @@ def _check_objects(objs: Mapping, attr: str) -> None:
            raise TypeError(f"Object {type(obj)} should have `{attr}` method")

     @staticmethod
-    def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: Any) -> None:
+    def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping, Path], **kwargs: Any) -> None:
         """Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.

         Args:
             to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
-            checkpoint: a string filepath or a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
-                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain
-                directly corresponding state_dict.
+            checkpoint: a path, a string filepath or a dictionary with state_dicts to load, e.g.
+                `{"model": model_state_dict, "optimizer": opt_state_dict}`. If `to_load` contains a single key,
+                then checkpoint can contain directly corresponding state_dict.
             kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
                 the user to load part of the pretrained model (useful for example, in Transfer Learning)

         Examples:
             .. code-block:: python

+                import tempfile
+                from pathlib import Path
+
                 import torch

                 from ignite.engine import Engine, Events
                 from ignite.handlers import ModelCheckpoint, Checkpoint

                 trainer = Engine(lambda engine, batch: None)
-                handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
-                model = torch.nn.Linear(3, 3)
-                optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
-                to_save = {"weights": model, "optimizer": optimizer}
-                trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
-                trainer.run(torch.randn(10, 1), 5)
-
-                to_load = to_save
-                checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
-                checkpoint = torch.load(checkpoint_fp)
-                Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    handler = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)
+
+                    model = torch.nn.Linear(3, 3)
+                    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+
+                    to_save = {"weights": model, "optimizer": optimizer}
+
+                    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
+                    trainer.run(torch.randn(10, 1), 5)
+
+                    to_load = to_save
+                    checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
+                    checkpoint = torch.load(checkpoint_fp)
+                    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
Comment on lines +550 to +564:

@sdesrozis this example is not working:

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    Input In [2], in <cell line: 10>()
         21 to_load = to_save
         22 # load checkpoint myprefix_checkpoint_40.pt
    ---> 23 checkpoint.load_objects(to_load=to_load, global_step=40)

    TypeError: load_objects() missing 1 required positional argument: 'checkpoint'
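The failing line comes from the docstring example of the new `reload_objects` method shown further down in this diff; presumably it was meant to call that instance method rather than the static `load_objects`, which requires an explicit `checkpoint` argument. A sketch of the correction (an assumption, not part of this diff):

    # Presumed fix: reload_objects resolves the checkpoint path itself from the
    # filename components, so no explicit `checkpoint` argument is needed.
    to_load = to_save
    checkpoint.reload_objects(to_load=to_load, global_step=40)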
-                # or using a string for checkpoint filepath
-
-                to_load = to_save
-                checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
-                Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
+                    # or using a string for checkpoint filepath
+
+                    to_load = to_save
+                    checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
+                    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)

         Note:
             If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
@@ -564,13 +578,13 @@ def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: An
         .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
         """

-        if isinstance(checkpoint, str):
+        if isinstance(checkpoint, (str, Path)):
             checkpoint_obj = torch.load(checkpoint)
         else:
             checkpoint_obj = checkpoint

         Checkpoint._check_objects(to_load, "load_state_dict")
-        if not isinstance(checkpoint, (collections.Mapping, str)):
+        if not isinstance(checkpoint, (collections.Mapping, str, Path)):
             raise TypeError(f"Argument checkpoint should be a string or a dictionary, but given {type(checkpoint)}")

         if len(kwargs) > 1 or any(k for k in kwargs if k not in ["strict"]):
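With this change a `pathlib.Path` works anywhere a string filepath did, since `torch.load` accepts both. A minimal sketch; the file path is hypothetical and assumes a checkpoint was saved there earlier:

    from pathlib import Path

    import torch
    from ignite.handlers import Checkpoint

    model = torch.nn.Linear(3, 3)
    to_load = {"model": model}

    # str and Path are now equally valid ways to point at a saved checkpoint.
    checkpoint_fp = Path("/tmp/models") / "myprefix_checkpoint_40.pt"  # hypothetical path
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)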
@@ -599,6 +613,82 @@ def _load_object(obj: Any, chkpt_obj: Any) -> None:
                 raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
             _load_object(obj, checkpoint_obj[k])

+    def reload_objects(self, to_load: Mapping, load_kwargs: Optional[Dict] = None, **filename_components: Any) -> None:
+        """Helper method to apply ``load_state_dict`` on the objects from ``to_load``. Filename components such as
+        name, score and global step can be configured.
+
+        Args:
+            to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
+            load_kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
+                the user to load part of the pretrained model (useful for example, in Transfer Learning)
+            filename_components: Filename components used to define the checkpoint file path.
+                Keyword arguments accepted are `name`, `score` and `global_step`.
+
+        Examples:
+            .. code-block:: python
+
+                import tempfile
+
+                import torch
+
+                from ignite.engine import Engine, Events
+                from ignite.handlers import ModelCheckpoint, Checkpoint
+
+                trainer = Engine(lambda engine, batch: None)
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    checkpoint = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)
+
+                    model = torch.nn.Linear(3, 3)
+                    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+
+                    to_save = {"weights": model, "optimizer": optimizer}
+
+                    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), checkpoint, to_save)
+                    trainer.run(torch.randn(10, 1), 5)
+
+                    to_load = to_save
+                    # load checkpoint myprefix_checkpoint_40.pt
+                    checkpoint.load_objects(to_load=to_load, global_step=40)
+
+        Note:
+            If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
+            `DataParallel`_, method ``load_state_dict`` will be applied to their internal wrapped model (``obj.module``).
+
+            .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
+                torch.nn.parallel.DistributedDataParallel.html
+            .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
+        """
+
+        global_step = filename_components.get("global_step", None)
+
+        filename_pattern = self._get_filename_pattern(global_step)
+
+        checkpoint = self._setup_checkpoint()
+        name = "checkpoint"
+        if len(checkpoint) == 1:
+            for k in checkpoint:
+                name = k
+        name = filename_components.get("name", name)
+        score = filename_components.get("score", None)
+
+        filename_dict = {
+            "filename_prefix": self.filename_prefix,
+            "ext": self.ext,
+            "name": name,
+            "score_name": self.score_name,
+            "score": score,
+            "global_step": global_step,
+        }
+
+        checkpoint_fp = filename_pattern.format(**filename_dict)
+
+        path = self.save_handler.dirname / checkpoint_fp
+
+        load_kwargs = {} if load_kwargs is None else load_kwargs
+
+        Checkpoint.load_objects(to_load=to_load, checkpoint=path, **load_kwargs)
+
     def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
         """Method returns state dict with saved items: list of ``(priority, filename)`` pairs.
         Can be used to save internal state of the class.
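`reload_objects` formats the handler's filename pattern with the given components and delegates to `load_objects` with the resulting path. A hedged sketch of reloading a scored checkpoint; the handler setup, score name and values are assumptions for illustration:

    import torch

    model = torch.nn.Linear(3, 3)

    # `handler` is assumed to be the Checkpoint/ModelCheckpoint instance that saved
    # e.g. 'myprefix_model_10_accuracy=0.9400.pt' (score_name="accuracy" assumed).
    handler.reload_objects(
        to_load={"model": model},
        load_kwargs={"strict": True},  # forwarded to nn.Module.load_state_dict
        name="model",
        global_step=10,
        score="0.9400",
    )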
We can also remove the unused `Checkpoint` import.