From 47b365c5816c7a4349519362cb3bfd018689e066 Mon Sep 17 00:00:00 2001
From: Mark Neumann
Date: Tue, 17 Jul 2018 13:13:36 -0700
Subject: [PATCH 1/3] save the best validation metrics for all metrics

---
 allennlp/tests/training/trainer_test.py | 9 +++++++++
 allennlp/training/trainer.py            | 7 +++++--
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/allennlp/tests/training/trainer_test.py b/allennlp/tests/training/trainer_test.py
index 19dcb3a4e53..a448ee83827 100644
--- a/allennlp/tests/training/trainer_test.py
+++ b/allennlp/tests/training/trainer_test.py
@@ -51,8 +51,13 @@ def test_trainer_can_run(self):
                           validation_dataset=self.instances,
                           num_epochs=2)
         metrics = trainer.train()
+        print("metrics dict: ", metrics)
         assert 'best_validation_loss' in metrics
         assert isinstance(metrics['best_validation_loss'], float)
+        assert 'best_validation_accuracy' in metrics
+        assert isinstance(metrics['best_validation_accuracy'], float)
+        assert 'best_validation_accuracy3' in metrics
+        assert isinstance(metrics['best_validation_accuracy3'], float)
         assert 'best_epoch' in metrics
         assert isinstance(metrics['best_epoch'], int)
 
@@ -67,6 +72,10 @@ def test_trainer_can_run(self):
         metrics = trainer.train()
         assert 'best_validation_loss' in metrics
         assert isinstance(metrics['best_validation_loss'], float)
+        assert 'best_validation_accuracy' in metrics
+        assert isinstance(metrics['best_validation_accuracy'], float)
+        assert 'best_validation_accuracy3' in metrics
+        assert isinstance(metrics['best_validation_accuracy3'], float)
         assert 'best_epoch' in metrics
         assert isinstance(metrics['best_epoch'], int)
 
diff --git a/allennlp/training/trainer.py b/allennlp/training/trainer.py
index 6699c27479e..8e38066d393 100644
--- a/allennlp/training/trainer.py
+++ b/allennlp/training/trainer.py
@@ -706,6 +706,7 @@ def train(self) -> Dict[str, Any]:
 
         train_metrics: Dict[str, float] = {}
         val_metrics: Dict[str, float] = {}
+        best_epoch_val_metrics: Dict[str, float] = {}
         epochs_trained = 0
         training_start_time = time.time()
         for epoch in range(epoch_counter, self._num_epochs):
@@ -723,7 +724,8 @@ def train(self) -> Dict[str, Any]:
 
                 # Check validation metric to see if it's the best so far
                 is_best_so_far = self._is_best_so_far(this_epoch_val_metric, validation_metric_per_epoch)
-
+                if is_best_so_far:
+                    best_epoch_val_metrics = val_metrics
                 validation_metric_per_epoch.append(this_epoch_val_metric)
                 if self._should_stop_early(validation_metric_per_epoch):
                     logger.info("Ran out of patience. Stopping training.")
@@ -733,6 +735,7 @@ def train(self) -> Dict[str, Any]:
                 # No validation set, so just assume it's the best so far.
                 is_best_so_far = True
                 val_metrics = {}
+                best_epoch_val_metrics = {}
                 this_epoch_val_metric = None
 
             self._save_checkpoint(epoch, validation_metric_per_epoch, is_best=is_best_so_far)
@@ -773,7 +776,7 @@ def train(self) -> Dict[str, Any]:
             best_validation_metric = min(validation_metric_per_epoch)
         else:
             best_validation_metric = max(validation_metric_per_epoch)
-        metrics[f"best_validation_{self._validation_metric}"] = best_validation_metric
+        metrics.update({f"best_validation_{k}": v for k,v in best_epoch_val_metrics.items()})
         metrics['best_epoch'] = [i for i, value in enumerate(validation_metric_per_epoch)
                                  if value == best_validation_metric][-1]
         return metrics

From 815ce8862cf43d2bc756a3872f75696ba5e3318e Mon Sep 17 00:00:00 2001
From: Mark Neumann
Date: Tue, 17 Jul 2018 13:17:10 -0700
Subject: [PATCH 2/3] lint

---
 allennlp/training/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/allennlp/training/trainer.py b/allennlp/training/trainer.py
index 8e38066d393..5f91dd65e9d 100644
--- a/allennlp/training/trainer.py
+++ b/allennlp/training/trainer.py
@@ -725,7 +725,7 @@ def train(self) -> Dict[str, Any]:
                 # Check validation metric to see if it's the best so far
                 is_best_so_far = self._is_best_so_far(this_epoch_val_metric, validation_metric_per_epoch)
                 if is_best_so_far:
-                    best_epoch_val_metrics = val_metrics
+                    best_epoch_val_metrics = val_metrics.copy()
                 validation_metric_per_epoch.append(this_epoch_val_metric)
                 if self._should_stop_early(validation_metric_per_epoch):
                     logger.info("Ran out of patience. Stopping training.")
@@ -776,7 +776,7 @@ def train(self) -> Dict[str, Any]:
             best_validation_metric = min(validation_metric_per_epoch)
         else:
             best_validation_metric = max(validation_metric_per_epoch)
-        metrics.update({f"best_validation_{k}": v for k,v in best_epoch_val_metrics.items()})
+        metrics.update({f"best_validation_{k}": v for k, v in best_epoch_val_metrics.items()})
         metrics['best_epoch'] = [i for i, value in enumerate(validation_metric_per_epoch)
                                  if value == best_validation_metric][-1]
         return metrics

From bc7aa458c728ef806273f72921eeabcbbab2d890 Mon Sep 17 00:00:00 2001
From: Mark Neumann
Date: Tue, 17 Jul 2018 13:17:56 -0700
Subject: [PATCH 3/3] remove print

---
 allennlp/tests/training/trainer_test.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/allennlp/tests/training/trainer_test.py b/allennlp/tests/training/trainer_test.py
index a448ee83827..8c146dff588 100644
--- a/allennlp/tests/training/trainer_test.py
+++ b/allennlp/tests/training/trainer_test.py
@@ -51,7 +51,6 @@ def test_trainer_can_run(self):
                           validation_dataset=self.instances,
                           num_epochs=2)
         metrics = trainer.train()
-        print("metrics dict: ", metrics)
         assert 'best_validation_loss' in metrics
         assert isinstance(metrics['best_validation_loss'], float)
         assert 'best_validation_accuracy' in metrics
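For reference, here is a minimal, runnable sketch of the bookkeeping these three patches add to Trainer.train(). It is not AllenNLP code: the helper name collect_best_metrics, the metric names, and the per-epoch values are all made up for illustration, and model selection is simplified to "higher is better" (the real trainer also handles a direction prefix on the validation metric name, where lower is better).

# Sketch of the pattern the patches introduce: snapshot the *whole* validation-
# metrics dict at the best epoch and export every entry under a
# "best_validation_" prefix, instead of exporting only the selection metric.
from typing import Any, Dict, List


def collect_best_metrics(per_epoch_val_metrics: List[Dict[str, float]],
                         validation_metric: str = "accuracy") -> Dict[str, Any]:
    validation_metric_per_epoch: List[float] = []
    best_epoch_val_metrics: Dict[str, float] = {}
    for val_metrics in per_epoch_val_metrics:
        this_epoch_val_metric = val_metrics[validation_metric]
        # Simplified "best so far": a new maximum of the selection metric.
        is_best_so_far = (not validation_metric_per_epoch or
                          this_epoch_val_metric > max(validation_metric_per_epoch))
        if is_best_so_far:
            # Copy defensively (as PATCH 2/3 does) so the snapshot cannot
            # change if the per-epoch dict is later mutated in place.
            best_epoch_val_metrics = val_metrics.copy()
        validation_metric_per_epoch.append(this_epoch_val_metric)

    metrics: Dict[str, Any] = {}
    # Export every metric from the best epoch, not just the selection metric.
    metrics.update({f"best_validation_{k}": v for k, v in best_epoch_val_metrics.items()})
    best_validation_metric = max(validation_metric_per_epoch)
    metrics['best_epoch'] = [i for i, value in enumerate(validation_metric_per_epoch)
                             if value == best_validation_metric][-1]
    return metrics


if __name__ == "__main__":
    epochs = [{"accuracy": 0.61, "accuracy3": 0.80, "loss": 1.2},
              {"accuracy": 0.74, "accuracy3": 0.88, "loss": 0.9}]
    print(collect_best_metrics(epochs))
    # -> {'best_validation_accuracy': 0.74, 'best_validation_accuracy3': 0.88,
    #     'best_validation_loss': 0.9, 'best_epoch': 1}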
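One note on PATCH 2/3: besides the spacing fix, swapping val_metrics for val_metrics.copy() makes the snapshot defensive. Without the copy, best_epoch_val_metrics aliases whatever dict val_metrics referenced at the best epoch, which stays correct only as long as that dict is never mutated in place afterwards.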