Add filename components in Checkpoint by sdesrozis · Pull Request #2498 · pytorch/ignite · GitHub

Add filename components in Checkpoint #2498

Merged: 10 commits, Mar 9, 2022
Changes from all commits
148 changes: 119 additions & 29 deletions ignite/handlers/checkpoint.py
@@ -327,6 +327,18 @@ def __init__(
self.include_self = include_self
self.greater_or_equal = greater_or_equal

def _get_filename_pattern(self, global_step: Optional[int]) -> str:
if self.filename_pattern is None:
filename_pattern = self.setup_filename_pattern(
with_prefix=len(self.filename_prefix) > 0,
with_score=self.score_function is not None,
with_score_name=self.score_name is not None,
with_global_step=global_step is not None,
)
else:
filename_pattern = self.filename_pattern
return filename_pattern
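
For reference, this helper delegates to the existing ``Checkpoint.setup_filename_pattern``, which composes a pattern from the enabled components. A minimal sketch (not part of this diff) of the patterns it is expected to return, assuming ignite's current ``setup_filename_pattern`` behaviour:

```python
from ignite.handlers import Checkpoint

# All components enabled -> '{filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext}'
full = Checkpoint.setup_filename_pattern(
    with_prefix=True, with_score=True, with_score_name=True, with_global_step=True
)

# Only a global step (the plain ModelCheckpoint case) -> '{name}_{global_step}.{ext}'
step_only = Checkpoint.setup_filename_pattern(
    with_prefix=False, with_score=False, with_score_name=False, with_global_step=True
)

print(full, step_only)
```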

def reset(self) -> None:
"""Method to reset saved checkpoint names.

@@ -402,15 +414,7 @@ def __call__(self, engine: Engine) -> None:
name = k
checkpoint = checkpoint[name]

if self.filename_pattern is None:
filename_pattern = self.setup_filename_pattern(
with_prefix=len(self.filename_prefix) > 0,
with_score=self.score_function is not None,
with_score_name=self.score_name is not None,
with_global_step=global_step is not None,
)
else:
filename_pattern = self.filename_pattern
filename_pattern = self._get_filename_pattern(global_step)

filename_dict = {
"filename_prefix": self.filename_prefix,
@@ -519,41 +523,51 @@ def _check_objects(objs: Mapping, attr: str) -> None:
raise TypeError(f"Object {type(obj)} should have `{attr}` method")

@staticmethod
def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: Any) -> None:
def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping, Path], **kwargs: Any) -> None:
"""Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.

Args:
to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
checkpoint: a string filepath or a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
"optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain
directly corresponding state_dict.
checkpoint: a path, a string filepath or a dictionary with state_dicts to load, e.g.
`{"model": model_state_dict, "optimizer": opt_state_dict}`. If `to_load` contains a single key,
then checkpoint can contain directly corresponding state_dict.
kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
the user to load part of the pretrained model (useful, for example, in Transfer Learning)

Examples:
.. code-block:: python

import tempfile
from pathlib import Path

import torch

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Checkpoint
Review comment (Collaborator): We can also remove unused Checkpoint


trainer = Engine(lambda engine, batch: None)
handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
to_save = {"weights": model, "optimizer": optimizer}
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
trainer.run(torch.randn(10, 1), 5)

to_load = to_save
checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
checkpoint = torch.load(checkpoint_fp)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
with tempfile.TemporaryDirectory() as tmpdirname:
handler = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)

model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

to_save = {"weights": model, "optimizer": optimizer}

trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
trainer.run(torch.randn(10, 1), 5)

to_load = to_save
checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
checkpoint = torch.load(checkpoint_fp)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
Review comment (Collaborator) on lines +550 to +564:
@sdesrozis this example is not working:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [2], in <cell line: 10>()
     21 to_load = to_save
     22 # load checkpoint myprefix_checkpoint_40.pt
---> 23 checkpoint.load_objects(to_load=to_load, global_step=40)

TypeError: load_objects() missing 1 required positional argument: 'checkpoint'
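
The failing line passes filename components to the static ``load_objects``, which only accepts ``checkpoint``; presumably the new ``reload_objects`` instance method added in this PR is what was meant. A minimal corrected sketch, assuming the ``checkpoint`` handler and ``to_load`` mapping from the quoted example:

```python
# reload_objects rebuilds the checkpoint path (e.g. myprefix_checkpoint_40.pt)
# from the given filename components and then delegates to Checkpoint.load_objects
checkpoint.reload_objects(to_load=to_load, global_step=40)
```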


# or using a string for checkpoint filepath
# or using a string for checkpoint filepath

to_load = to_save
checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)
to_load = to_save
checkpoint_fp = Path(tmpdirname) / 'myprefix_checkpoint_40.pt'
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)

Note:
If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
@@ -564,13 +578,13 @@ def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: An
.. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
"""

if isinstance(checkpoint, str):
if isinstance(checkpoint, (str, Path)):
checkpoint_obj = torch.load(checkpoint)
else:
checkpoint_obj = checkpoint

Checkpoint._check_objects(to_load, "load_state_dict")
if not isinstance(checkpoint, (collections.Mapping, str)):
if not isinstance(checkpoint, (collections.Mapping, str, Path)):
raise TypeError(f"Argument checkpoint should be a string or a dictionary, but given {type(checkpoint)}")

if len(kwargs) > 1 or any(k for k in kwargs if k not in ["strict"]):
@@ -599,6 +613,82 @@ def _load_object(obj: Any, chkpt_obj: Any) -> None:
raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
_load_object(obj, checkpoint_obj[k])

def reload_objects(self, to_load: Mapping, load_kwargs: Optional[Dict] = None, **filename_components: Any) -> None:
"""Helper method to apply ``load_state_dict`` on the objects from ``to_load``. Filename components such as
name, score and global step can be configured.

Args:
to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
load_kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Passing `strict=False` enables
the user to load part of the pretrained model (useful, for example, in Transfer Learning)
filename_components: Filename components used to define the checkpoint file path.
Keyword arguments accepted are `name`, `score` and `global_step`.

Examples:
.. code-block:: python

import tempfile

import torch

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Checkpoint

trainer = Engine(lambda engine, batch: None)

with tempfile.TemporaryDirectory() as tmpdirname:
checkpoint = ModelCheckpoint(tmpdirname, 'myprefix', n_saved=None, create_dir=True)

model = torch.nn.Linear(3, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

to_save = {"weights": model, "optimizer": optimizer}

trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), checkpoint, to_save)
trainer.run(torch.randn(10, 1), 5)

to_load = to_save
# load checkpoint myprefix_checkpoint_40.pt
checkpoint.reload_objects(to_load=to_load, global_step=40)

Note:
If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
`DataParallel`_, method ``load_state_dict`` will be applied to their internal wrapped model (``obj.module``).

.. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
torch.nn.parallel.DistributedDataParallel.html
.. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
"""

global_step = filename_components.get("global_step", None)

filename_pattern = self._get_filename_pattern(global_step)

checkpoint = self._setup_checkpoint()
name = "checkpoint"
if len(checkpoint) == 1:
for k in checkpoint:
name = k
name = filename_components.get("name", name)
score = filename_components.get("score", None)

filename_dict = {
"filename_prefix": self.filename_prefix,
"ext": self.ext,
"name": name,
"score_name": self.score_name,
"score": score,
"global_step": global_step,
}

checkpoint_fp = filename_pattern.format(**filename_dict)

path = self.save_handler.dirname / checkpoint_fp

load_kwargs = {} if load_kwargs is None else load_kwargs

Checkpoint.load_objects(to_load=to_load, checkpoint=path, **load_kwargs)
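
To illustrate the path resolution above, a small standalone sketch of the filename ``reload_objects`` would look up, under the assumed configuration of a ``ModelCheckpoint(dirname, 'myprefix')`` handler saved at global step 40 with no score function (mirroring the docstring example):

```python
# With a prefix and a global step but no score, the generated pattern is
# '{filename_prefix}_{name}_{global_step}.{ext}'
filename_dict = {
    "filename_prefix": "myprefix",
    "ext": "pt",
    "name": "checkpoint",
    "score_name": None,
    "score": None,
    "global_step": 40,
}
pattern = "{filename_prefix}_{name}_{global_step}.{ext}"
assert pattern.format(**filename_dict) == "myprefix_checkpoint_40.pt"
# reload_objects then loads dirname / 'myprefix_checkpoint_40.pt' via Checkpoint.load_objects
```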

def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
"""Method returns state dict with saved items: list of ``(priority, filename)`` pairs.
Can be used to save internal state of the class.
7 changes: 7 additions & 0 deletions tests/ignite/handlers/test_checkpoint.py
@@ -576,6 +576,9 @@ def test_model_checkpoint_simple_recovery(dirname):
assert fname.exists()
loaded_objects = torch.load(fname)
assert loaded_objects == model.state_dict()
to_load = {"model": DummyModel()}
h.reload_objects(to_load=to_load, global_step=1)
assert to_load["model"].state_dict() == model.state_dict()


def test_model_checkpoint_simple_recovery_from_existing_non_empty(dirname):
@@ -600,6 +603,9 @@ def _test(ext, require_empty):
assert previous_fname.exists()
loaded_objects = torch.load(fname)
assert loaded_objects == model.state_dict()
to_load = {"model": DummyModel()}
h.reload_objects(to_load=to_load, global_step=1)
assert to_load["model"].state_dict() == model.state_dict()
fname.unlink()

_test(".txt", require_empty=True)
@@ -1118,6 +1124,7 @@ def _get_multiple_objs_to_save():
assert str(dirname / _PREFIX) in str(fname)
assert fname.exists()
Checkpoint.load_objects(to_save, str(fname))
Checkpoint.load_objects(to_save, fname)
fname.unlink()

# case: multiple objects