From 88c25e119ea826f669deba9091502513036d3f68 Mon Sep 17 00:00:00 2001
From: ShaohonChen
Date: Tue, 3 Jun 2025 19:19:02 +0800
Subject: [PATCH 1/5] add support for SwanLabTracker and update related documentation

---
 docs/source/package_reference/tracking.md      |   5 +
 .../deepspeed_with_config_support.py           |   2 +-
 .../by_feature/megatron_lm_gpt_pretraining.py  |   2 +-
 setup.py                                       |  10 +-
 src/accelerate/accelerator.py                  |   1 +
 src/accelerate/test_utils/testing.py           |  13 ++-
 src/accelerate/tracking.py                     | 107 ++++++++++++++++++
 src/accelerate/utils/__init__.py               |   1 +
 src/accelerate/utils/dataclasses.py            |   2 +
 src/accelerate/utils/imports.py                |   4 +
 10 files changed, 143 insertions(+), 4 deletions(-)

diff --git a/docs/source/package_reference/tracking.md b/docs/source/package_reference/tracking.md
index 4f69e027b11..9e0e8cba0f6 100644
--- a/docs/source/package_reference/tracking.md
+++ b/docs/source/package_reference/tracking.md
@@ -48,3 +48,8 @@ rendered properly in your Markdown viewer.
 
 [[autodoc]] tracking.ClearMLTracker
     - __init__
+
+## SwanLabTracker
+
+[[autodoc]] tracking.SwanLabTracker
+    - __init__
diff --git a/examples/by_feature/deepspeed_with_config_support.py b/examples/by_feature/deepspeed_with_config_support.py
index ff7535761af..c7dcbf1ae6a 100755
--- a/examples/by_feature/deepspeed_with_config_support.py
+++ b/examples/by_feature/deepspeed_with_config_support.py
@@ -218,7 +218,7 @@ def parse_args():
         default="all",
         help=(
             'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
-            ' `"wandb"`, `"comet_ml"`, and `"dvclive"`. Use `"all"` (default) to report to all integrations.'
+            ' `"wandb"`, `"comet_ml"`, `"dvclive"`, and `"swanlab"`. Use `"all"` (default) to report to all integrations.'
             "Only applicable when `--with_tracking` is passed."
         ),
     )
diff --git a/examples/by_feature/megatron_lm_gpt_pretraining.py b/examples/by_feature/megatron_lm_gpt_pretraining.py
index b357106e6d5..7c1153bc49d 100644
--- a/examples/by_feature/megatron_lm_gpt_pretraining.py
+++ b/examples/by_feature/megatron_lm_gpt_pretraining.py
@@ -215,7 +215,7 @@ def parse_args():
         default="all",
         help=(
             'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
-            ' `"wandb"`, `"comet_ml"`, and `"dvclive"`. Use `"all"` (default) to report to all integrations.'
+            ' `"wandb"`, `"comet_ml"`, `"dvclive"`, and `"swanlab"`. Use `"all"` (default) to report to all integrations.'
             "Only applicable when `--with_tracking` is passed."
         ),
     )
diff --git a/setup.py b/setup.py
index 9607ccb6a13..16656189a2a 100644
--- a/setup.py
+++ b/setup.py
@@ -41,7 +41,15 @@
 extras["rich"] = ["rich"]
 
 extras["test_fp8"] = ["torchao"]  # note: TE for now needs to be done via pulling down the docker image directly
-extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard", "dvclive", "mlflow", "matplotlib"]
+extras["test_trackers"] = [
+    "wandb",
+    "comet-ml",
+    "tensorboard",
+    "dvclive",
+    "mlflow",
+    "matplotlib",
+    "swanlab",
+]
 extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"]
 
 extras["sagemaker"] = [
diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 368b041134d..feb5d2dbcf2 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -229,6 +229,7 @@ class Accelerator:
             - `"tensorboard"`
             - `"wandb"`
             - `"comet_ml"`
+            - `"swanlab"`
             If `"all"` is selected, will pick up all available trackers in the environment and initialize them.
             Can also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
         project_config ([`~utils.ProjectConfiguration`], *optional*):
diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py
index a833bd51de1..1d71319608f 100644
--- a/src/accelerate/test_utils/testing.py
+++ b/src/accelerate/test_utils/testing.py
@@ -61,6 +61,7 @@
     is_pytest_available,
     is_schedulefree_available,
     is_sdaa_available,
+    is_swanlab_available,
     is_tensorboard_available,
     is_timm_available,
     is_torch_version,
@@ -482,6 +483,15 @@ def require_dvclive(test_case):
     return unittest.skipUnless(is_dvclive_available(), "test requires dvclive")(test_case)
 
 
+def require_swanlab(test_case):
+    """
+    Decorator marking a test that requires swanlab installed. These tests are skipped when swanlab isn't installed
+    """
+    return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(
+        test_case
+    )
+
+
 def require_pandas(test_case):
     """
     Decorator marking a test that requires pandas installed. These tests are skipped when pandas isn't installed
@@ -536,7 +546,8 @@ def require_matplotlib(test_case):
 
 
 _atleast_one_tracker_available = (
-    any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available()
+    any([is_wandb_available(), is_tensorboard_available(), is_swanlab_available()])
+    and not is_comet_ml_available()
 )
diff --git a/src/accelerate/tracking.py b/src/accelerate/tracking.py
index d4478296815..addd66acbae 100644
--- a/src/accelerate/tracking.py
+++ b/src/accelerate/tracking.py
@@ -34,6 +34,7 @@
     is_comet_ml_available,
     is_dvclive_available,
     is_mlflow_available,
+    is_swanlab_available,
     is_tensorboard_available,
     is_wandb_available,
     listify,
@@ -63,6 +64,9 @@
 if is_dvclive_available():
     _available_trackers.append(LoggerType.DVCLIVE)
 
+if is_swanlab_available():
+    _available_trackers.append(LoggerType.SWANLAB)
+
 logger = get_logger(__name__)
 
 
@@ -1061,6 +1065,107 @@ def finish(self):
         self.live.end()
 
 
+class SwanLabTracker(GeneralTracker):
+    """
+    A `Tracker` class that supports `swanlab`. Should be initialized at the start of your script.
+
+    Args:
+        run_name (`str`):
+            The name of the experiment run.
+        **kwargs (additional keyword arguments, *optional*):
+            Additional key word arguments passed along to the `swanlab.init` method.
+    """
+
+    name = "swanlab"
+    requires_logging_directory = False
+    main_process_only = False
+
+    def __init__(self, run_name: str, **kwargs):
+        super().__init__()
+        self.run_name = run_name
+        self.init_kwargs = kwargs
+
+    @on_main_process
+    def start(self):
+        import swanlab
+
+        self.run = swanlab.init(project=self.run_name, **self.init_kwargs)
+        swanlab.config["FRAMEWORK"] = "accelerate"  # add accelerate logo in config
+        logger.debug(f"Initialized SwanLab project {self.run_name}")
+        logger.debug(
+            "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
+        )
+
+    @property
+    def tracker(self):
+        return self.run
+
+    @on_main_process
+    def store_init_configuration(self, values: dict):
+        """
+        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
+
+        Args:
+            values (Dictionary `str` to `bool`, `str`, `float` or `int`):
+                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
+                `str`, `float`, `int`, or `None`.
+        """
+        import swanlab
+
+        swanlab.config.update(values, allow_val_change=True)
+        logger.debug("Stored initial configuration hyperparameters to SwanLab")
+
+    @on_main_process
+    def log(self, values: dict, step: Optional[int] = None, **kwargs):
+        """
+        Logs `values` to the current run.
+
+        Args:
+            data : Dict[str, DataType]
+                Data must be a dict.
+                The key must be a string with 0-9, a-z, A-Z, " ", "_", "-", "/".
+                The value must be a `float`, `float convertible object`, `int` or `swanlab.data.BaseType`.
+            step : int, optional
+                The step number of the current data, if not provided, it will be automatically incremented.
+                If step is duplicated, the data will be ignored.
+            kwargs:
+                Additional key word arguments passed along to the `swanlab.log` method. For example:
+                print_to_console : bool, optional
+                    Whether to print the data to the console, the default is False.
+        """
+        self.run.log(values, step=step, **kwargs)
+        logger.debug("Successfully logged to SwanLab")
+
+    @on_main_process
+    def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
+        """
+        Logs `images` to the current run.
+
+        Args:
+            values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
+                Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
+                `PIL.Image`.
+            step (`int`, *optional*):
+                The run step. If included, the log will be affiliated with this step.
+            kwargs:
+                Additional key word arguments passed along to the `swanlab.log` method. For example:
+                print_to_console : bool, optional
+                    Whether to print the data to the console, the default is False.
+        """
+        import swanlab
+
+        for k, v in values.items():
+            self.log({k: [swanlab.Image(image) for image in v]}, step=step, **kwargs)
+        logger.debug("Successfully logged images to SwanLab")
+
+    @on_main_process
+    def finish(self):
+        """
+        Closes `swanlab` writer
+        """
+        self.run.finish()
+        logger.debug("SwanLab run closed")
+
+
 LOGGER_TYPE_TO_CLASS = {
     "aim": AimTracker,
     "comet_ml": CometMLTracker,
@@ -1069,6 +1174,7 @@
     "wandb": WandBTracker,
     "clearml": ClearMLTracker,
     "dvclive": DVCLiveTracker,
+    "swanlab": SwanLabTracker,
 }
 
 
@@ -1093,6 +1199,7 @@ def filter_trackers(
             - `"comet_ml"`
             - `"mlflow"`
             - `"dvclive"`
+            - `"swanlab"`
             If `"all"` is selected, will pick up all available trackers in the environment and initialize them.
             Can also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
         logging_dir (`str`, `os.PathLike`, *optional*):
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index 2fe8c19dd7b..84f4f742856 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -121,6 +121,7 @@
     is_sagemaker_available,
     is_schedulefree_available,
     is_sdaa_available,
+    is_swanlab_available,
     is_tensorboard_available,
     is_timm_available,
     is_torch_xla_available,
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 5d414cf3222..71be9198346 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -701,6 +701,7 @@ class LoggerType(BaseEnum):
     - **WANDB** -- wandb as an experiment tracker
     - **COMETML** -- comet_ml as an experiment tracker
     - **DVCLIVE** -- dvclive as an experiment tracker
+    - **SWANLAB** -- swanlab as an experiment tracker
     """
 
     ALL = "all"
@@ -711,6 +712,7 @@ class LoggerType(BaseEnum):
     MLFLOW = "mlflow"
     CLEARML = "clearml"
     DVCLIVE = "dvclive"
+    SWANLAB = "swanlab"
 
 
 class PrecisionType(str, BaseEnum):
diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py
index 3a7a5c6f6fe..ad7df7ec6e9 100644
--- a/src/accelerate/utils/imports.py
+++ b/src/accelerate/utils/imports.py
@@ -281,6 +281,10 @@ def is_comet_ml_available():
     return _is_package_available("comet_ml")
 
 
+def is_swanlab_available():
+    return _is_package_available("swanlab")
+
+
 def is_boto3_available():
     return _is_package_available("boto3")
 

From a4f3102bba59c171ad642fa42514f6aefae7b8a3 Mon Sep 17 00:00:00 2001
From: ShaohonChen
Date: Tue, 3 Jun 2025 20:15:35 +0800
Subject: [PATCH 2/5] add emoji in FRAMEWORK

---
 src/accelerate/tracking.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/accelerate/tracking.py b/src/accelerate/tracking.py
index addd66acbae..0b85cf4a228 100644
--- a/src/accelerate/tracking.py
+++ b/src/accelerate/tracking.py
@@ -1090,7 +1090,7 @@ def start(self):
         import swanlab
 
         self.run = swanlab.init(project=self.run_name, **self.init_kwargs)
-        swanlab.config["FRAMEWORK"] = "accelerate"  # add accelerate logo in config
+        swanlab.config["FRAMEWORK"] = "🤗Accelerate"  # add accelerate logo in config
         logger.debug(f"Initialized SwanLab project {self.run_name}")
         logger.debug(
             "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
         )

From 63afb5ce0b0d0a08e3c2ec9c551df7c01dee0755 Mon Sep 17 00:00:00 2001
From: ShaohonChen
Date: Tue, 3 Jun 2025 20:24:59 +0800
Subject: [PATCH 3/5] apply the style corrections and quality control

---
 src/accelerate/test_utils/testing.py | 7 ++-----
 src/accelerate/tracking.py           | 5 ++---
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py
index 1d71319608f..4e96b245198 100644
--- a/src/accelerate/test_utils/testing.py
+++ b/src/accelerate/test_utils/testing.py
@@ -487,9 +487,7 @@ def require_swanlab(test_case):
     """
     Decorator marking a test that requires swanlab installed. These tests are skipped when swanlab isn't installed
     """
-    return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(
-        test_case
-    )
+    return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(test_case)
 
 
 def require_pandas(test_case):
@@ -546,8 +544,7 @@ def require_matplotlib(test_case):
 
 
 _atleast_one_tracker_available = (
-    any([is_wandb_available(), is_tensorboard_available(), is_swanlab_available()])
-    and not is_comet_ml_available()
+    any([is_wandb_available(), is_tensorboard_available(), is_swanlab_available()]) and not is_comet_ml_available()
 )
diff --git a/src/accelerate/tracking.py b/src/accelerate/tracking.py
index 0b85cf4a228..ec7e62cb435 100644
--- a/src/accelerate/tracking.py
+++ b/src/accelerate/tracking.py
@@ -1122,9 +1122,8 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs):
 
         Args:
             data : Dict[str, DataType]
-                Data must be a dict.
-                The key must be a string with 0-9, a-z, A-Z, " ", "_", "-", "/".
-                The value must be a `float`, `float convertible object`, `int` or `swanlab.data.BaseType`.
+                Data must be a dict. The key must be a string with 0-9, a-z, A-Z, " ", "_", "-", "/". The value must be a
+                `float`, `float convertible object`, `int` or `swanlab.data.BaseType`.
             step : int, optional
                 The step number of the current data, if not provided, it will be automatically incremented.
                 If step is duplicated, the data will be ignored.

From 55e7f60e5487b9fc675482d574065069ca070309 Mon Sep 17 00:00:00 2001
From: ShaohonChen
Date: Sun, 8 Jun 2025 18:25:57 +0800
Subject: [PATCH 4/5] add support for SwanLabTracker in tests

---
 tests/test_examples.py |   5 +-
 tests/test_tracking.py | 129 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 133 insertions(+), 1 deletion(-)

diff --git a/tests/test_examples.py b/tests/test_examples.py
index e1b7d0a53d3..9e67794aa4e 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -239,7 +239,10 @@ def test_schedulefree(self):
         run_command(self.launch_args + testargs)
 
     @require_trackers
-    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
+    @mock.patch.dict(
+        os.environ,
+        {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true", "SWANLAB_MODE": "offline"},
+    )
     def test_tracking(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             testargs = f"""
diff --git a/tests/test_tracking.py b/tests/test_tracking.py
index e5fe2ba3ce1..3201c775734 100644
--- a/tests/test_tracking.py
+++ b/tests/test_tracking.py
@@ -16,6 +16,7 @@
 import json
 import logging
 import os
+import random
 import re
 import subprocess
 import tempfile
@@ -42,6 +43,7 @@
     require_matplotlib,
     require_mlflow,
     require_pandas,
+    require_swanlab,
     require_tensorboard,
     require_wandb,
     skip,
@@ -53,6 +55,7 @@
     DVCLiveTracker,
     GeneralTracker,
     MLflowTracker,
+    SwanLabTracker,
     TensorBoardTracker,
     WandBTracker,
 )
@@ -520,6 +523,123 @@ def test_log_table_pandas(self):
         self.assertCountEqual(plot["data"][0]["cells"]["values"], [[1, 2], [3, 4], [5, 6]])
 
 
+@require_swanlab
+@mock.patch.dict(os.environ, {"SWANLAB_MODE": "offline"})
+class SwanLabTrackingTest(TempDirTestCase, MockingTestCase):
+    def setUp(self):
+        super().setUp()
+        # Setting Path where SwanLab parsed log files are saved via the SWANLAB_LOG_DIR env var
+        self.add_mocks(mock.patch.dict(os.environ, {"SWANLAB_LOG_DIR": self.tmpdir}))
+
+    @skip
+    def test_swanlab(self):
+        # Disable hardware monitoring to prevent errors in test mode.
+        import swanlab
+        from swanlab.log.backup import BackupHandler
+        from swanlab.log.backup.datastore import DataStore
+        from swanlab.log.backup.models import ModelsParser
+
+        swanlab.merge_settings(swanlab.Settings(hardware_monitor=False))
+        # Start a fake training session.
+        accelerator = Accelerator(log_with="swanlab")
+        project_name = "test_project_with_config"
+        experiment_name = "test"
+        description = "test project for swanlab"
+        tags = ["my_tag"]
+        config = {
+            "epochs": 10,
+            "learning_rate": 0.01,
+            "offset": 0.1,
+        }
+        kwargs = {
+            "swanlab": {
+                "experiment_name": experiment_name,
+                "description": description,
+                "tags": tags,
+            }
+        }
+        accelerator.init_trackers(project_name, config, kwargs)
+        record_metrics = []
+        record_scalars = []
+        record_images_count = 0
+        record_logs = []
+        for epoch in range(1, swanlab.config.epochs):
+            acc = 1 - 2**-epoch - random.random() / epoch - 0.1
+            loss = 2**-epoch + random.random() / epoch + 0.1
+            ll = swanlab.log(
+                {
+                    "accuracy": acc,
+                    "loss": loss,
+                    "image": swanlab.Image(np.random.random((3, 3, 3))),
+                },
+                step=epoch,
+            )
+            log = f"epoch={epoch}, accuracy={acc}, loss={loss}"
+            print(log)
+            record_scalars.extend([acc, loss])
+            record_images_count += 1
+            record_logs.append(log)
+            record_metrics.extend([x for _, x in ll.items()])
+        accelerator.end_training()
+
+        # Load latest offline log
+        run_dir = os.path.join("./swanlog", swanlab.get_run().public.run_dir)
+        assert os.path.exists(run_dir) is True
+        ds = DataStore()
+        ds.open_for_scan(os.path.join(run_dir.__str__(), BackupHandler.BACKUP_FILE).__str__())
+        with ModelsParser() as models_parser:
+            for record in ds:
+                if record is None:
+                    continue
+                models_parser.parse_record(record)
+        header, project, experiment, logs, runtime, columns, scalars, medias, footer = models_parser.get_parsed()
+
+        # test file header
+        assert header.backup_type == "DEFAULT"
+
+        # test project info
+        assert project.name == project_name
+        assert project.workspace is None
+        assert project.public is None
+
+        # test experiment info
+        assert experiment.name is not None
+        assert experiment.description == description
+        assert experiment.tags == tags
+
+        # test log record
+        backup_logs = [log.message for log in logs]
+        for record_log in record_logs:
+            assert record_log in backup_logs, "Log not found in backup logs: " + record_log
+
+        # test runtime info
+        runtime_info = runtime.to_file_model(os.path.join(run_dir.__str__(), "files"))
+        assert runtime_info.conda is None, "Not using conda, should be None"
+        assert isinstance(runtime_info.requirements, str), "Requirements should be a string"
+        assert isinstance(runtime_info.metadata, dict), "Metadata should be a dictionary"
+        assert isinstance(runtime_info.config, dict), "Config should be a dictionary"
+        for key in runtime_info.config:
+            assert key in config, f"Config key {key} not found in original config"
+            assert runtime_info.config[key]["value"] == config[key], (
+                f"Config value for {key} does not match original value"
+            )
+
+        # test scalar
+        assert len(scalars) + len(medias) == len(record_metrics), "Total metrics count does not match"
+        backup_scalars = [
+            metric.metric["data"]
+            for metric in record_metrics
+            if metric.column_info.chart_type.value.column_type == "FLOAT"
+        ]
+        assert len(backup_scalars) == len(scalars), "Total scalars count does not match"
+        for scalar in backup_scalars:
+            assert scalar in record_scalars, f"Scalar {scalar} not found in original scalars"
+        backup_images = [
+            metric for metric in record_metrics if metric.column_info.chart_type.value.column_type == "IMAGE"
+        ]
+        assert len(backup_images) == record_images_count, "Total images count does not match"
+
+
 class MyCustomTracker(GeneralTracker):
     "Basic tracker that writes to a csv for testing"
 
@@ -728,3 +848,12 @@ def test_dvclive_deferred_init(self):
         self.assertEqual(PartialState._shared_state, {})
         _ = Accelerator(log_with=tracker)
         self.assertNotEqual(PartialState._shared_state, {})
+
+    @require_swanlab
+    def test_swanlab_deferred_init(self):
+        """Test that SwanLab tracker initialization doesn't initialize distributed"""
+        PartialState._reset_state()
+        tracker = SwanLabTracker(run_name="test_swanlab")
+        self.assertEqual(PartialState._shared_state, {})
+        _ = Accelerator(log_with=tracker)
+        self.assertNotEqual(PartialState._shared_state, {})

From a404b2904223367481d77f08892d993f7551567c Mon Sep 17 00:00:00 2001
From: ShaohonChen
Date: Tue, 10 Jun 2025 19:52:43 +0800
Subject: [PATCH 5/5] fix bug in test_tracking

---
 tests/test_tracking.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_tracking.py b/tests/test_tracking.py
index 3201c775734..a055a53de1f 100644
--- a/tests/test_tracking.py
+++ b/tests/test_tracking.py
@@ -583,7 +583,7 @@ def test_swanlab(self):
         accelerator.end_training()
 
         # Load latest offline log
-        run_dir = os.path.join("./swanlog", swanlab.get_run().public.run_dir)
+        run_dir = swanlab.get_run().public.run_dir
         assert os.path.exists(run_dir) is True
         ds = DataStore()
         ds.open_for_scan(os.path.join(run_dir.__str__(), BackupHandler.BACKUP_FILE).__str__())
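
Below is a minimal usage sketch of what this patch series enables once it is applied. It is not part of the diff: it assumes `swanlab` is installed (`pip install swanlab`), and the project name, config values, and metrics are placeholders rather than anything referenced above.

from accelerate import Accelerator

# Select the new tracker by name; log_with="all" would also pick it up when swanlab is installed.
accelerator = Accelerator(log_with="swanlab")

# config is stored through SwanLabTracker.store_init_configuration (swanlab.config.update);
# the "swanlab" entry of init_kwargs is passed through to swanlab.init, as in the test above.
accelerator.init_trackers(
    project_name="my_project",
    config={"learning_rate": 3e-4, "epochs": 3},
    init_kwargs={"swanlab": {"experiment_name": "demo", "tags": ["accelerate"]}},
)

for step in range(10):
    loss = 1.0 / (step + 1)  # placeholder metric
    accelerator.log({"loss": loss}, step=step)  # routed to SwanLabTracker.log -> swanlab.log

accelerator.end_training()  # calls SwanLabTracker.finish(), which closes the swanlab run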