From 79fba098a7bf42782cc9b0e6dcc6c859ce2d1c59 Mon Sep 17 00:00:00 2001 From: Hummer12007 Date: Fri, 21 Apr 2023 11:57:00 +0300 Subject: [PATCH 1/7] Added score_sign to add_early_stopping_by_val_score --- ignite/contrib/engines/common.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ignite/contrib/engines/common.py b/ignite/contrib/engines/common.py index b22f52c18760..a08265c017c6 100644 --- a/ignite/contrib/engines/common.py +++ b/ignite/contrib/engines/common.py @@ -672,7 +672,7 @@ def save_best_model_by_val_score( def add_early_stopping_by_val_score( - patience: int, evaluator: Engine, trainer: Engine, metric_name: str + patience: int, evaluator: Engine, trainer: Engine, metric_name: str, score_sign: float = 1.0, ) -> EarlyStopping: """Method setups early stopping handler based on the score (named by `metric_name`) provided by `evaluator`. Metric value should increase in order to keep training and not early stop. @@ -687,7 +687,11 @@ def add_early_stopping_by_val_score( Returns: A :class:`~ignite.handlers.early_stopping.EarlyStopping` handler. """ - es_handler = EarlyStopping(patience=patience, score_function=get_default_score_fn(metric_name), trainer=trainer) + es_handler = EarlyStopping( + patience=patience, + score_function=get_default_score_fn(metric_name, score_sign=score_sign), + trainer=trainer + ) evaluator.add_event_handler(Events.COMPLETED, es_handler) return es_handler From 40acd15849bd6adb3fa2529f08a3018b942042d1 Mon Sep 17 00:00:00 2001 From: Hummer12007 Date: Fri, 21 Apr 2023 12:05:32 +0300 Subject: [PATCH 2/7] Updated docstring --- ignite/contrib/engines/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ignite/contrib/engines/common.py b/ignite/contrib/engines/common.py index a08265c017c6..63b8b1e896a9 100644 --- a/ignite/contrib/engines/common.py +++ b/ignite/contrib/engines/common.py @@ -683,6 +683,8 @@ def add_early_stopping_by_val_score( trainer: trainer engine to stop the run if no improvement. metric_name: metric name to use for score evaluation. This metric should be present in `evaluator.state.metrics`. + score_sign: sign of the score: 1.0 or -1.0. For error-like metrics, e.g. smaller is better, + a negative score sign should be used (objects with larger score are retained). Default, 1.0. Returns: A :class:`~ignite.handlers.early_stopping.EarlyStopping` handler. 
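Usage sketch (illustrative only, not part of the patch series): how the `score_sign` argument introduced in patches 1/7 and 2/7 is meant to be called for an error-like metric. Only APIs visible in this series are used (`Engine`, `Events`, `add_early_stopping_by_val_score`); the dummy engines and the hand-assigned loss values mirror the test added in patch 3/7, whereas a real setup would attach ignite metrics to the evaluator instead.

    from ignite.engine import Engine, Events
    from ignite.contrib.engines.common import add_early_stopping_by_val_score

    # Dummy engines standing in for real training / validation loops.
    trainer = Engine(lambda engine, batch: None)
    evaluator = Engine(lambda engine, batch: None)

    # Error-like metric: smaller is better. Improves until epoch 4, then degrades.
    loss_scores = [0.9, 0.8, 0.7, 0.6, 0.7, 0.7, 0.8, 0.9, 0.9, 1.0]

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        evaluator.run([0, 1])

    @evaluator.on(Events.EPOCH_COMPLETED)
    def set_eval_metric(engine):
        # Hand-written metrics for illustration; normally produced by attached metrics.
        engine.state.metrics = {"loss": loss_scores[trainer.state.epoch - 1]}

    # score_sign=-1.0 negates the error-like "loss" so a lower loss counts as a
    # better score; training stops after `patience` evaluations without improvement.
    add_early_stopping_by_val_score(
        patience=3, evaluator=evaluator, trainer=trainer,
        metric_name="loss", score_sign=-1.0,
    )

    state = trainer.run([0, 1], max_epochs=len(loss_scores))
    print(state.epoch)  # expected: 7, the same stopping point asserted in patch 3/7
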
From c21610818a735de79c4913d3cd3c436516514b7c Mon Sep 17 00:00:00 2001 From: Hummer12007 Date: Fri, 21 Apr 2023 12:13:40 +0300 Subject: [PATCH 3/7] Updated tests for add_early_stopping_by_val_score --- tests/ignite/contrib/engines/test_common.py | 33 +++++++++++++++------ 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/tests/ignite/contrib/engines/test_common.py b/tests/ignite/contrib/engines/test_common.py index 824ec6b15699..a4775d921a24 100644 --- a/tests/ignite/contrib/engines/test_common.py +++ b/tests/ignite/contrib/engines/test_common.py @@ -293,18 +293,25 @@ def set_eval_metric(engine): def test_add_early_stopping_by_val_score(): - trainer = Engine(lambda e, b: None) - evaluator = Engine(lambda e, b: None) - acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.3, 0.2, 0.1, 0.1, 0.0] - @trainer.on(Events.EPOCH_COMPLETED) - def validate(engine): - evaluator.run([0, 1]) + def setup_trainer(): + trainer = Engine(lambda e, b: None) + evaluator = Engine(lambda e, b: None) - @evaluator.on(Events.EPOCH_COMPLETED) - def set_eval_metric(engine): - engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]} + + @trainer.on(Events.EPOCH_COMPLETED) + def validate(engine): + evaluator.run([0, 1]) + + @evaluator.on(Events.EPOCH_COMPLETED) + def set_eval_metric(engine): + acc = acc_scores[trainer.state.epoch - 1] + engine.state.metrics = {"acc": acc, "loss": 1 - acc} + + return trainer, evaluator + + trainer, evaluator = setup_trainer() add_early_stopping_by_val_score(patience=3, evaluator=evaluator, trainer=trainer, metric_name="acc") @@ -312,6 +319,14 @@ def set_eval_metric(engine): assert state.epoch == 7 + trainer, evaluator = setup_trainer() + + add_early_stopping_by_val_score(patience=3, evaluator=evaluator, trainer=trainer, metric_name="loss", score_sign=-1.0) + + state = trainer.run([0, 1], max_epochs=len(acc_scores)) + + assert state.epoch == 7 + def test_deprecated_setup_any_logging(): From ab7f2bc45a5130577da0d2de509cd3c28be8eceb Mon Sep 17 00:00:00 2001 From: Hummer12007 Date: Fri, 21 Apr 2023 20:04:33 +0300 Subject: [PATCH 4/7] Added score_sign to save_best_model_by_val_score --- ignite/contrib/engines/common.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ignite/contrib/engines/common.py b/ignite/contrib/engines/common.py index 63b8b1e896a9..b6306003d9b7 100644 --- a/ignite/contrib/engines/common.py +++ b/ignite/contrib/engines/common.py @@ -580,6 +580,7 @@ def gen_save_best_models_by_val_score( n_saved: int = 3, trainer: Optional[Engine] = None, tag: str = "val", + score_sign: float = 1.0, **kwargs: Any, ) -> Checkpoint: """Method adds a handler to ``evaluator`` to save ``n_saved`` of best models based on the metric @@ -602,6 +603,8 @@ def gen_save_best_models_by_val_score( n_saved: number of best models to store trainer: trainer engine to fetch the epoch when saving the best model. tag: score name prefix: `{tag}_{metric_name}`. By default, tag is "val". + score_sign: sign of the score: 1.0 or -1.0. For error-like metrics, e.g. smaller is better, + a negative score sign should be used (objects with larger score are retained). Default, 1.0. kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`. 
Returns: @@ -623,7 +626,7 @@ def gen_save_best_models_by_val_score( n_saved=n_saved, global_step_transform=global_step_transform, score_name=f"{tag}_{metric_name.lower()}", - score_function=Checkpoint.get_default_score_fn(metric_name), + score_function=Checkpoint.get_default_score_fn(metric_name, score_sign=score_sign), **kwargs, ) evaluator.add_event_handler(Events.COMPLETED, best_model_handler) @@ -639,6 +642,7 @@ def save_best_model_by_val_score( n_saved: int = 3, trainer: Optional[Engine] = None, tag: str = "val", + score_sign: float = 1.0, **kwargs: Any, ) -> Checkpoint: """Method adds a handler to ``evaluator`` to save on a disk ``n_saved`` of best models based on the metric @@ -654,6 +658,9 @@ def save_best_model_by_val_score( n_saved: number of best models to store trainer: trainer engine to fetch the epoch when saving the best model. tag: score name prefix: `{tag}_{metric_name}`. By default, tag is "val". + score_sign: sign of the score: 1.0 or -1.0. For error-like metrics, e.g. smaller is better, + a negative score sign should be used (objects with larger score are retained). Default, 1.0. + kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`. Returns: @@ -667,6 +674,7 @@ def save_best_model_by_val_score( n_saved=n_saved, trainer=trainer, tag=tag, + score_sign=score_sign, **kwargs, ) From 05f722ca0b60c32ede389e269728debbac8d2b77 Mon Sep 17 00:00:00 2001 From: Hummer12007 Date: Fri, 21 Apr 2023 20:21:14 +0300 Subject: [PATCH 5/7] Add tests for score_sign with save_best_model_by_val_score --- tests/ignite/contrib/engines/test_common.py | 116 ++++++++++++++------ 1 file changed, 80 insertions(+), 36 deletions(-) diff --git a/tests/ignite/contrib/engines/test_common.py b/tests/ignite/contrib/engines/test_common.py index a4775d921a24..4749d5db1086 100644 --- a/tests/ignite/contrib/engines/test_common.py +++ b/tests/ignite/contrib/engines/test_common.py @@ -48,7 +48,6 @@ def _test_setup_common_training_handlers( save_handler=None, output_transform=lambda loss: loss, ): - lr = 0.01 step_size = 100 gamma = 0.5 @@ -218,7 +217,6 @@ def test_setup_common_training_handlers(dirname, capsys): def test_setup_common_training_handlers_using_save_handler(dirname, capsys): - save_handler = DiskSaver(dirname=dirname, require_empty=False) _test_setup_common_training_handlers(dirname=None, device="cpu", save_handler=save_handler) @@ -231,20 +229,25 @@ def test_setup_common_training_handlers_using_save_handler(dirname, capsys): def test_save_best_model_by_val_score(dirname): + acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5] - trainer = Engine(lambda e, b: None) - evaluator = Engine(lambda e, b: None) - model = DummyModel() + def setup_trainer(): + trainer = Engine(lambda e, b: None) + evaluator = Engine(lambda e, b: None) + model = DummyModel() - acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5] + @trainer.on(Events.EPOCH_COMPLETED) + def validate(engine): + evaluator.run([0, 1]) + + @evaluator.on(Events.EPOCH_COMPLETED) + def set_eval_metric(engine): + acc = acc_scores[trainer.state.epoch - 1] + engine.state.metrics = {"acc": acc, "loss": 1 - acc} - @trainer.on(Events.EPOCH_COMPLETED) - def validate(engine): - evaluator.run([0, 1]) + return trainer, evaluator, model - @evaluator.on(Events.EPOCH_COMPLETED) - def set_eval_metric(engine): - engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]} + trainer, evaluator, model = setup_trainer() save_best_model_by_val_score(dirname, evaluator, model, 
metric_name="acc", n_saved=2, trainer=trainer) @@ -252,22 +255,42 @@ def set_eval_metric(engine): assert set(os.listdir(dirname)) == {"best_model_8_val_acc=0.6100.pt", "best_model_9_val_acc=0.7000.pt"} + for fname in os.listdir(dirname): + os.unlink(f"{dirname}/{fname}") -def test_gen_save_best_models_by_val_score(): + trainer, evaluator, model = setup_trainer() - trainer = Engine(lambda e, b: None) - evaluator = Engine(lambda e, b: None) - model = DummyModel() + save_best_model_by_val_score( + dirname, evaluator, model, metric_name="loss", n_saved=2, trainer=trainer, score_sign=-1.0 + ) + + trainer.run([0, 1], max_epochs=len(acc_scores)) + + assert set(os.listdir(dirname)) == {"best_model_8_val_loss=-0.3900.pt", "best_model_9_val_loss=-0.3000.pt"} + +def test_gen_save_best_models_by_val_score(): acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5] + loss_scores = [0.9, 0.8, 0.7, 0.6, 0.7, 0.5, 0.4, 0.39, 0.3, 0.5] + + def setup_trainer(): + trainer = Engine(lambda e, b: None) + evaluator = Engine(lambda e, b: None) + model = DummyModel() + + @trainer.on(Events.EPOCH_COMPLETED) + def validate(engine): + evaluator.run([0, 1]) + + @evaluator.on(Events.EPOCH_COMPLETED) + def set_eval_metric(engine): + acc = acc_scores[trainer.state.epoch - 1] + loss = loss_scores[trainer.state.epoch - 1] + engine.state.metrics = {"acc": acc, "loss": loss} - @trainer.on(Events.EPOCH_COMPLETED) - def validate(engine): - evaluator.run([0, 1]) + return trainer, evaluator, model - @evaluator.on(Events.EPOCH_COMPLETED) - def set_eval_metric(engine): - engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]} + trainer, evaluator, model = setup_trainer() save_handler = MagicMock() @@ -291,6 +314,36 @@ def set_eval_metric(engine): any_order=True, ) + trainer, evaluator, model = setup_trainer() + + save_handler = MagicMock() + + gen_save_best_models_by_val_score( + save_handler, + evaluator, + {"a": model, "b": model}, + metric_name="loss", + n_saved=2, + trainer=trainer, + score_sign=-1.0, + ) + + trainer.run([0, 1], max_epochs=len(acc_scores)) + + assert save_handler.call_count == len(acc_scores) - 2 # 2 score values (-0.7 and -0.5) are not the best + obj_to_save = {"a": model.state_dict(), "b": model.state_dict()} + save_handler.assert_has_calls( + [ + call( + obj_to_save, + f"best_checkpoint_{e}_val_loss={p:.4f}.pt", + dict([("basename", "best_checkpoint"), ("score_name", "val_loss"), ("priority", p)]), + ) + for e, p in zip([1, 2, 3, 4, 6, 7, 8, 9], [-0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.39, -0.3]) + ], + any_order=True, + ) + def test_add_early_stopping_by_val_score(): acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.3, 0.2, 0.1, 0.1, 0.0] @@ -299,7 +352,6 @@ def setup_trainer(): trainer = Engine(lambda e, b: None) evaluator = Engine(lambda e, b: None) - @trainer.on(Events.EPOCH_COMPLETED) def validate(engine): evaluator.run([0, 1]) @@ -307,8 +359,8 @@ def validate(engine): @evaluator.on(Events.EPOCH_COMPLETED) def set_eval_metric(engine): acc = acc_scores[trainer.state.epoch - 1] - engine.state.metrics = {"acc": acc, "loss": 1 - acc} - + engine.state.metrics = {"acc": acc, "loss": 1 - acc} + return trainer, evaluator trainer, evaluator = setup_trainer() @@ -321,7 +373,9 @@ def set_eval_metric(engine): trainer, evaluator = setup_trainer() - add_early_stopping_by_val_score(patience=3, evaluator=evaluator, trainer=trainer, metric_name="loss", score_sign=-1.0) + add_early_stopping_by_val_score( + patience=3, evaluator=evaluator, trainer=trainer, metric_name="loss", score_sign=-1.0 + ) state = 
trainer.run([0, 1], max_epochs=len(acc_scores)) @@ -329,13 +383,11 @@ def set_eval_metric(engine): def test_deprecated_setup_any_logging(): - with pytest.raises(DeprecationWarning, match=r"deprecated since version 0.4.0"): setup_any_logging(None, None, None, None, None, None) def test__setup_logging_wrong_args(): - with pytest.raises(TypeError, match=r"Argument optimizers should be either a single optimizer or"): _setup_logging(MagicMock(), MagicMock(), "abc", MagicMock(), 1) @@ -421,7 +473,6 @@ def set_eval_metric(engine): def test_setup_tb_logging(dirname): - tb_logger = _test_setup_logging( setup_logging_fn=setup_tb_logging, kwargs_dict={"output_path": dirname / "t1"}, @@ -477,7 +528,6 @@ def test_setup_visdom_logging(visdom_offline_logfile): def test_setup_plx_logging(): - os.environ["POLYAXON_NO_OP"] = "1" _test_setup_logging( @@ -521,7 +571,6 @@ def test_setup_mlflow_logging(dirname): def test_setup_wandb_logging(dirname): - from unittest.mock import patch with patch("ignite.contrib.engines.common.WandBLogger") as _: @@ -529,7 +578,6 @@ def test_setup_wandb_logging(dirname): def test_setup_clearml_logging(): - handlers.clearml_logger.ClearMLLogger.set_bypass_mode(True) with pytest.warns(UserWarning, match=r"running in bypass mode"): @@ -598,7 +646,6 @@ def test_setup_neptune_logging(dirname): @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") def test_distrib_nccl_gpu(dirname, distributed_context_single_node_nccl): - local_rank = distributed_context_single_node_nccl["local_rank"] device = idist.device() _test_setup_common_training_handlers(dirname, device, rank=local_rank, local_rank=local_rank, distributed=True) @@ -608,7 +655,6 @@ def test_distrib_nccl_gpu(dirname, distributed_context_single_node_nccl): @pytest.mark.distributed @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") def test_distrib_gloo_cpu_or_gpu(dirname, distributed_context_single_node_gloo): - device = idist.device() local_rank = distributed_context_single_node_gloo["local_rank"] _test_setup_common_training_handlers(dirname, device, rank=local_rank, local_rank=local_rank, distributed=True) @@ -625,7 +671,6 @@ def test_distrib_gloo_cpu_or_gpu(dirname, distributed_context_single_node_gloo): @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_gloo_cpu_or_gpu(dirname, distributed_context_multi_node_gloo): - device = idist.device() rank = distributed_context_multi_node_gloo["rank"] _test_setup_common_training_handlers(dirname, device, rank=rank) @@ -636,7 +681,6 @@ def test_multinode_distrib_gloo_cpu_or_gpu(dirname, distributed_context_multi_no @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") def test_multinode_distrib_nccl_gpu(dirname, distributed_context_multi_node_nccl): - local_rank = distributed_context_multi_node_nccl["local_rank"] rank = distributed_context_multi_node_nccl["rank"] device = idist.device() From 0ba32fe797fbd7ea43f74dba09a52aaed4e1974a Mon Sep 17 00:00:00 2001 From: Hummer12007 Date: Wed, 26 Apr 2023 01:48:31 +0000 Subject: [PATCH 6/7] Fmt fix --- ignite/contrib/engines/common.py | 10 ++++++---- 1 file changed, 6 
insertions(+), 4 deletions(-) diff --git a/ignite/contrib/engines/common.py b/ignite/contrib/engines/common.py index b6306003d9b7..3896f5a28433 100644 --- a/ignite/contrib/engines/common.py +++ b/ignite/contrib/engines/common.py @@ -680,7 +680,11 @@ def save_best_model_by_val_score( def add_early_stopping_by_val_score( - patience: int, evaluator: Engine, trainer: Engine, metric_name: str, score_sign: float = 1.0, + patience: int, + evaluator: Engine, + trainer: Engine, + metric_name: str, + score_sign: float = 1.0, ) -> EarlyStopping: """Method setups early stopping handler based on the score (named by `metric_name`) provided by `evaluator`. Metric value should increase in order to keep training and not early stop. @@ -698,9 +702,7 @@ def add_early_stopping_by_val_score( A :class:`~ignite.handlers.early_stopping.EarlyStopping` handler. """ es_handler = EarlyStopping( - patience=patience, - score_function=get_default_score_fn(metric_name, score_sign=score_sign), - trainer=trainer + patience=patience, score_function=get_default_score_fn(metric_name, score_sign=score_sign), trainer=trainer ) evaluator.add_event_handler(Events.COMPLETED, es_handler) From 6d8d0987d6b47d477fbe2f06225f0bfeee36a21e Mon Sep 17 00:00:00 2001 From: Hummer12007 Date: Wed, 26 Apr 2023 04:49:19 +0300 Subject: [PATCH 7/7] Update ignite/contrib/engines/common.py Co-authored-by: vfdev --- ignite/contrib/engines/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/contrib/engines/common.py b/ignite/contrib/engines/common.py index 3896f5a28433..95e4e09cb3b1 100644 --- a/ignite/contrib/engines/common.py +++ b/ignite/contrib/engines/common.py @@ -626,7 +626,7 @@ def gen_save_best_models_by_val_score( n_saved=n_saved, global_step_transform=global_step_transform, score_name=f"{tag}_{metric_name.lower()}", - score_function=Checkpoint.get_default_score_fn(metric_name, score_sign=score_sign), + score_function=get_default_score_fn(metric_name, score_sign=score_sign), **kwargs, ) evaluator.add_event_handler(Events.COMPLETED, best_model_handler)
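
Usage sketch for the checkpointing side (illustrative only, not part of the patches): `save_best_model_by_val_score` with `score_sign=-1.0` keeps the checkpoints with the lowest validation loss. The stored score is the negated metric value, which is why the filenames asserted in patch 5/7 carry negative scores such as `val_loss=-0.3000`. The output directory and the `nn.Linear` stand-in model below are assumptions for the example; the loss sequence is the one used by the updated test (loss = 1 - acc).

    import torch.nn as nn
    from ignite.engine import Engine, Events
    from ignite.contrib.engines.common import save_best_model_by_val_score

    model = nn.Linear(4, 2)                        # stand-in for a real model
    trainer = Engine(lambda engine, batch: None)   # dummy training loop
    evaluator = Engine(lambda engine, batch: None)

    # Validation losses matching the test in patch 5/7.
    loss_scores = [0.9, 0.8, 0.7, 0.6, 0.7, 0.5, 0.4, 0.39, 0.3, 0.5]

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        evaluator.run([0, 1])

    @evaluator.on(Events.EPOCH_COMPLETED)
    def set_eval_metric(engine):
        # Hand-written metrics for illustration; normally produced by attached metrics.
        engine.state.metrics = {"loss": loss_scores[trainer.state.epoch - 1]}

    # n_saved=2 with score_sign=-1.0: keep the two checkpoints with the lowest
    # "loss". Per the test assertions, the saved files end up named
    # best_model_8_val_loss=-0.3900.pt and best_model_9_val_loss=-0.3000.pt.
    save_best_model_by_val_score(
        "/tmp/best_models",  # assumed output directory
        evaluator, model,
        metric_name="loss", n_saved=2, trainer=trainer, score_sign=-1.0,
    )

    trainer.run([0, 1], max_epochs=len(loss_scores))

The same `score_sign=-1.0` argument flows through to `gen_save_best_models_by_val_score` (patch 4/7), which `save_best_model_by_val_score` wraps with a `DiskSaver`, so the sketch above applies to both entry points.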