Added score_sign to add_early_stopping_by_val_score and gen_save_best_models_by_val_score by Hummer12007 · Pull Request #2929 · pytorch/ignite · GitHub

Added score_sign to add_early_stopping_by_val_score and gen_save_best_models_by_val_score #2929

Merged
merged 8 commits into from
Apr 26, 2023
Changes from all commits
22 changes: 19 additions & 3 deletions ignite/contrib/engines/common.py
@@ -580,6 +580,7 @@ def gen_save_best_models_by_val_score(
n_saved: int = 3,
trainer: Optional[Engine] = None,
tag: str = "val",
score_sign: float = 1.0,
**kwargs: Any,
) -> Checkpoint:
"""Method adds a handler to ``evaluator`` to save ``n_saved`` of best models based on the metric
@@ -602,6 +603,8 @@ def gen_save_best_models_by_val_score(
n_saved: number of best models to store
trainer: trainer engine to fetch the epoch when saving the best model.
tag: score name prefix: `{tag}_{metric_name}`. By default, tag is "val".
score_sign: sign of the score: 1.0 or -1.0. For error-like metrics, e.g. smaller is better,
a negative score sign should be used (objects with larger score are retained). Default, 1.0.
kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.

Returns:
@@ -623,7 +626,7 @@ def gen_save_best_models_by_val_score(
n_saved=n_saved,
global_step_transform=global_step_transform,
score_name=f"{tag}_{metric_name.lower()}",
score_function=Checkpoint.get_default_score_fn(metric_name),
score_function=get_default_score_fn(metric_name, score_sign=score_sign),
**kwargs,
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)
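Note: the only behavioural change in this hunk is which value the checkpoint ranks on. Assuming `Checkpoint.get_default_score_fn` keeps its current semantics, the score function built above amounts to `score_sign * engine.state.metrics[metric_name]`; a tiny stand-alone sketch with a made-up metric value:

from ignite.engine import Engine
from ignite.handlers import Checkpoint

# Roughly what the handler above ranks on for an error-like metric:
score_fn = Checkpoint.get_default_score_fn("loss", score_sign=-1.0)

evaluator = Engine(lambda engine, batch: None)
evaluator.state.metrics = {"loss": 0.25}  # pretend a validation run just finished
print(score_fn(evaluator))  # -0.25: a smaller loss yields a larger score, so it is retained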
@@ -639,6 +642,7 @@ def save_best_model_by_val_score(
n_saved: int = 3,
trainer: Optional[Engine] = None,
tag: str = "val",
score_sign: float = 1.0,
**kwargs: Any,
) -> Checkpoint:
"""Method adds a handler to ``evaluator`` to save on a disk ``n_saved`` of best models based on the metric
@@ -654,6 +658,9 @@ def save_best_model_by_val_score(
n_saved: number of best models to store
trainer: trainer engine to fetch the epoch when saving the best model.
tag: score name prefix: `{tag}_{metric_name}`. By default, tag is "val".
score_sign: sign of the score: 1.0 or -1.0. For error-like metrics, e.g. smaller is better,
a negative score sign should be used (objects with larger score are retained). Default, 1.0.

kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.

Returns:
@@ -667,12 +674,17 @@ def save_best_model_by_val_score(
n_saved=n_saved,
trainer=trainer,
tag=tag,
score_sign=score_sign,
**kwargs,
)
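Note: in practice the new keyword is used exactly as in the test added below; to keep the checkpoints with the smallest validation loss, pass the loss metric together with `score_sign=-1.0`. A minimal sketch (the dummy engines, output path and fabricated loss values are illustrative only, not part of this PR):

import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.contrib.engines.common import save_best_model_by_val_score

trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)
model = nn.Linear(2, 2)

@trainer.on(Events.EPOCH_COMPLETED)
def run_validation(engine):
    evaluator.run([0, 1])

@evaluator.on(Events.EPOCH_COMPLETED)
def set_metrics(engine):
    # Pretend the validation loss shrinks as training progresses.
    engine.state.metrics = {"loss": 1.0 / trainer.state.epoch}

# Keep the two checkpoints with the smallest "loss"; with score_sign=-1.0
# the underlying Checkpoint ranks on -loss, so a lower loss wins.
save_best_model_by_val_score(
    "/tmp/best_models", evaluator, model, metric_name="loss",
    n_saved=2, trainer=trainer, score_sign=-1.0,
)

trainer.run([0, 1], max_epochs=5)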


def add_early_stopping_by_val_score(
patience: int, evaluator: Engine, trainer: Engine, metric_name: str
patience: int,
evaluator: Engine,
trainer: Engine,
metric_name: str,
score_sign: float = 1.0,
) -> EarlyStopping:
"""Method setups early stopping handler based on the score (named by `metric_name`) provided by `evaluator`.
Metric value should increase in order to keep training and not early stop.
@@ -683,11 +695,15 @@ def add_early_stopping_by_val_score(
trainer: trainer engine to stop the run if no improvement.
metric_name: metric name to use for score evaluation. This metric should be present in
`evaluator.state.metrics`.
score_sign: sign of the score: 1.0 or -1.0. For error-like metrics, e.g. smaller is better,
a negative score sign should be used (objects with larger score are retained). Default, 1.0.

Returns:
A :class:`~ignite.handlers.early_stopping.EarlyStopping` handler.
"""
es_handler = EarlyStopping(patience=patience, score_function=get_default_score_fn(metric_name), trainer=trainer)
es_handler = EarlyStopping(
patience=patience, score_function=get_default_score_fn(metric_name, score_sign=score_sign), trainer=trainer
)
evaluator.add_event_handler(Events.COMPLETED, es_handler)

return es_handler
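Note: the early-stopping counterpart works the same way. Assuming a trainer/evaluator pair wired up as in the sketch above (so `evaluator.state.metrics["loss"]` is filled after every validation), a negative `score_sign` makes a decreasing loss count as an improvement:

from ignite.contrib.engines.common import add_early_stopping_by_val_score

# Stop the run if "loss" has not improved (i.e. decreased) for 3 validations in a row.
es_handler = add_early_stopping_by_val_score(
    patience=3, evaluator=evaluator, trainer=trainer,
    metric_name="loss", score_sign=-1.0,
)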
141 changes: 100 additions & 41 deletions tests/ignite/contrib/engines/test_common.py
@@ -48,7 +48,6 @@ def _test_setup_common_training_handlers(
save_handler=None,
output_transform=lambda loss: loss,
):

lr = 0.01
step_size = 100
gamma = 0.5
@@ -218,7 +217,6 @@ def test_setup_common_training_handlers(dirname, capsys):


def test_setup_common_training_handlers_using_save_handler(dirname, capsys):

save_handler = DiskSaver(dirname=dirname, require_empty=False)
_test_setup_common_training_handlers(dirname=None, device="cpu", save_handler=save_handler)

@@ -231,43 +229,68 @@ def test_setup_common_training_handlers_using_save_handler(dirname, capsys):


def test_save_best_model_by_val_score(dirname):
acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5]

trainer = Engine(lambda e, b: None)
evaluator = Engine(lambda e, b: None)
model = DummyModel()
def setup_trainer():
trainer = Engine(lambda e, b: None)
evaluator = Engine(lambda e, b: None)
model = DummyModel()

acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5]
@trainer.on(Events.EPOCH_COMPLETED)
def validate(engine):
evaluator.run([0, 1])

@trainer.on(Events.EPOCH_COMPLETED)
def validate(engine):
evaluator.run([0, 1])
@evaluator.on(Events.EPOCH_COMPLETED)
def set_eval_metric(engine):
acc = acc_scores[trainer.state.epoch - 1]
engine.state.metrics = {"acc": acc, "loss": 1 - acc}

return trainer, evaluator, model

@evaluator.on(Events.EPOCH_COMPLETED)
def set_eval_metric(engine):
engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]}
trainer, evaluator, model = setup_trainer()

save_best_model_by_val_score(dirname, evaluator, model, metric_name="acc", n_saved=2, trainer=trainer)

trainer.run([0, 1], max_epochs=len(acc_scores))

assert set(os.listdir(dirname)) == {"best_model_8_val_acc=0.6100.pt", "best_model_9_val_acc=0.7000.pt"}

for fname in os.listdir(dirname):
os.unlink(f"{dirname}/{fname}")

def test_gen_save_best_models_by_val_score():
trainer, evaluator, model = setup_trainer()

save_best_model_by_val_score(
dirname, evaluator, model, metric_name="loss", n_saved=2, trainer=trainer, score_sign=-1.0
)

trainer.run([0, 1], max_epochs=len(acc_scores))

assert set(os.listdir(dirname)) == {"best_model_8_val_loss=-0.3900.pt", "best_model_9_val_loss=-0.3000.pt"}
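Note: the negative numbers in these file names are expected. The checkpoint priority is `score_sign * metric`, so with `score_sign=-1.0` the two best losses (0.39 and 0.30) are stored as -0.3900 and -0.3000; the name template below is inferred from the assertion above, not taken from Checkpoint's source:

for epoch, loss in [(8, 0.39), (9, 0.30)]:
    priority = -1.0 * loss  # score_sign * metric
    print(f"best_model_{epoch}_val_loss={priority:.4f}.pt")
# best_model_8_val_loss=-0.3900.pt
# best_model_9_val_loss=-0.3000.pt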

trainer = Engine(lambda e, b: None)
evaluator = Engine(lambda e, b: None)
model = DummyModel()

def test_gen_save_best_models_by_val_score():
acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5]
loss_scores = [0.9, 0.8, 0.7, 0.6, 0.7, 0.5, 0.4, 0.39, 0.3, 0.5]

def setup_trainer():
trainer = Engine(lambda e, b: None)
evaluator = Engine(lambda e, b: None)
model = DummyModel()

@trainer.on(Events.EPOCH_COMPLETED)
def validate(engine):
evaluator.run([0, 1])

@trainer.on(Events.EPOCH_COMPLETED)
def validate(engine):
evaluator.run([0, 1])
@evaluator.on(Events.EPOCH_COMPLETED)
def set_eval_metric(engine):
acc = acc_scores[trainer.state.epoch - 1]
loss = loss_scores[trainer.state.epoch - 1]
engine.state.metrics = {"acc": acc, "loss": loss}

@evaluator.on(Events.EPOCH_COMPLETED)
def set_eval_metric(engine):
engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]}
return trainer, evaluator, model

trainer, evaluator, model = setup_trainer()

save_handler = MagicMock()

@@ -291,36 +314,80 @@ def set_eval_metric(engine):
any_order=True,
)

trainer, evaluator, model = setup_trainer()

def test_add_early_stopping_by_val_score():
trainer = Engine(lambda e, b: None)
evaluator = Engine(lambda e, b: None)
save_handler = MagicMock()

gen_save_best_models_by_val_score(
save_handler,
evaluator,
{"a": model, "b": model},
metric_name="loss",
n_saved=2,
trainer=trainer,
score_sign=-1.0,
)

trainer.run([0, 1], max_epochs=len(acc_scores))

assert save_handler.call_count == len(acc_scores) - 2 # 2 score values (-0.7 and -0.5) are not the best
obj_to_save = {"a": model.state_dict(), "b": model.state_dict()}
save_handler.assert_has_calls(
[
call(
obj_to_save,
f"best_checkpoint_{e}_val_loss={p:.4f}.pt",
dict([("basename", "best_checkpoint"), ("score_name", "val_loss"), ("priority", p)]),
)
for e, p in zip([1, 2, 3, 4, 6, 7, 8, 9], [-0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.39, -0.3])
],
any_order=True,
)
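Note on the expected call count: with `score_sign=-1.0` the priorities are the negated losses, and the save handler only fires when a new priority enters the current top `n_saved=2`. A stand-alone re-computation of the epochs asserted above (this mirrors the behaviour the test expects, not Checkpoint's actual bookkeeping):

loss_scores = [0.9, 0.8, 0.7, 0.6, 0.7, 0.5, 0.4, 0.39, 0.3, 0.5]
saved, kept = [], []  # "kept" holds the best two priorities seen so far
for epoch, loss in enumerate(loss_scores, start=1):
    priority = -1.0 * loss
    if len(kept) < 2 or priority > min(kept):
        kept = sorted(kept + [priority])[-2:]
        saved.append((epoch, priority))
print(saved)
# [(1, -0.9), (2, -0.8), (3, -0.7), (4, -0.6), (6, -0.5), (7, -0.4), (8, -0.39), (9, -0.3)]
# Epochs 5 and 10 (priorities -0.7 and -0.5) never beat the stored ones, hence call_count == 8.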


def test_add_early_stopping_by_val_score():
acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.3, 0.2, 0.1, 0.1, 0.0]

@trainer.on(Events.EPOCH_COMPLETED)
def validate(engine):
evaluator.run([0, 1])
def setup_trainer():
trainer = Engine(lambda e, b: None)
evaluator = Engine(lambda e, b: None)

@trainer.on(Events.EPOCH_COMPLETED)
def validate(engine):
evaluator.run([0, 1])

@evaluator.on(Events.EPOCH_COMPLETED)
def set_eval_metric(engine):
engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]}
@evaluator.on(Events.EPOCH_COMPLETED)
def set_eval_metric(engine):
acc = acc_scores[trainer.state.epoch - 1]
engine.state.metrics = {"acc": acc, "loss": 1 - acc}

return trainer, evaluator

trainer, evaluator = setup_trainer()

add_early_stopping_by_val_score(patience=3, evaluator=evaluator, trainer=trainer, metric_name="acc")

state = trainer.run([0, 1], max_epochs=len(acc_scores))

assert state.epoch == 7

trainer, evaluator = setup_trainer()

def test_deprecated_setup_any_logging():
add_early_stopping_by_val_score(
patience=3, evaluator=evaluator, trainer=trainer, metric_name="loss", score_sign=-1.0
)

state = trainer.run([0, 1], max_epochs=len(acc_scores))

assert state.epoch == 7
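Note: both runs stop at epoch 7 for the same reason. Accuracy peaks at 0.4 in epoch 4 and never improves again, and since this test defines loss as 1 - acc, the negated loss follows the exact same trajectory; a quick check of that equivalence (illustration only):

acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.3, 0.2, 0.1, 0.1, 0.0]
acc_view = [1.0 * a for a in acc_scores]          # what EarlyStopping sees for "acc"
loss_view = [-1.0 * (1 - a) for a in acc_scores]  # what it sees for "loss" with score_sign=-1.0
# Both sequences peak at epoch 4 and never improve afterwards,
# so with patience=3 the run stops at epoch 4 + 3 = 7.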


def test_deprecated_setup_any_logging():
with pytest.raises(DeprecationWarning, match=r"deprecated since version 0.4.0"):
setup_any_logging(None, None, None, None, None, None)


def test__setup_logging_wrong_args():

with pytest.raises(TypeError, match=r"Argument optimizers should be either a single optimizer or"):
_setup_logging(MagicMock(), MagicMock(), "abc", MagicMock(), 1)

@@ -406,7 +473,6 @@ def set_eval_metric(engine):


def test_setup_tb_logging(dirname):

tb_logger = _test_setup_logging(
setup_logging_fn=setup_tb_logging,
kwargs_dict={"output_path": dirname / "t1"},
@@ -462,7 +528,6 @@ def test_setup_visdom_logging(visdom_offline_logfile):


def test_setup_plx_logging():

os.environ["POLYAXON_NO_OP"] = "1"

_test_setup_logging(
@@ -506,15 +571,13 @@ def test_setup_mlflow_logging(dirname):


def test_setup_wandb_logging(dirname):

from unittest.mock import patch

with patch("ignite.contrib.engines.common.WandBLogger") as _:
setup_wandb_logging(MagicMock())


def test_setup_clearml_logging():

handlers.clearml_logger.ClearMLLogger.set_bypass_mode(True)

with pytest.warns(UserWarning, match=r"running in bypass mode"):
@@ -583,7 +646,6 @@ def test_setup_neptune_logging(dirname):
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_nccl_gpu(dirname, distributed_context_single_node_nccl):

local_rank = distributed_context_single_node_nccl["local_rank"]
device = idist.device()
_test_setup_common_training_handlers(dirname, device, rank=local_rank, local_rank=local_rank, distributed=True)
@@ -593,7 +655,6 @@ def test_distrib_nccl_gpu(dirname, distributed_context_single_node_nccl):
@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_gloo_cpu_or_gpu(dirname, distributed_context_single_node_gloo):

device = idist.device()
local_rank = distributed_context_single_node_gloo["local_rank"]
_test_setup_common_training_handlers(dirname, device, rank=local_rank, local_rank=local_rank, distributed=True)
@@ -610,7 +671,6 @@ def test_distrib_gloo_cpu_or_gpu(dirname, distributed_context_single_node_gloo):
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gloo_cpu_or_gpu(dirname, distributed_context_multi_node_gloo):

device = idist.device()
rank = distributed_context_multi_node_gloo["rank"]
_test_setup_common_training_handlers(dirname, device, rank=rank)
@@ -621,7 +681,6 @@ def test_multinode_distrib_gloo_cpu_or_gpu(dirname, distributed_context_multi_node_gloo):
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_nccl_gpu(dirname, distributed_context_multi_node_nccl):

local_rank = distributed_context_multi_node_nccl["local_rank"]
rank = distributed_context_multi_node_nccl["rank"]
device = idist.device()