Integrate SwanLab for offline/online experiment tracking for Accelerate by ShaohonChen · Pull Request #3605 · huggingface/accelerate · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Integrate SwanLab for offline/online experiment tracking for Accelerate #3605

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/package_reference/tracking.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,8 @@ rendered properly in your Markdown viewer.

[[autodoc]] tracking.ClearMLTracker
- __init__

## SwanLabTracker

[[autodoc]] tracking.SwanLabTracker
- __init__
2 changes: 1 addition & 1 deletion examples/by_feature/deepspeed_with_config_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def parse_args():
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"`, `"comet_ml"`, and `"dvclive"`. Use `"all"` (default) to report to all integrations.'
' `"wandb"`, `"comet_ml"`, `"dvclive"`, and `"swanlab"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
Expand Down
2 changes: 1 addition & 1 deletion examples/by_feature/megatron_lm_gpt_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def parse_args():
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"`, `"comet_ml"`, and `"dvclive"`. Use `"all"` (default) to report to all integrations.'
' `"wandb"`, `"comet_ml"`, and `"dvclive"`, and `"swanlab"`. Use `"all"` (default) to report 10000 to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
Expand Down
10 changes: 9 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@
extras["rich"] = ["rich"]

extras["test_fp8"] = ["torchao"] # note: TE for now needs to be done via pulling down the docker image directly
extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard", "dvclive", "mlflow", "matplotlib"]
extras["test_trackers"] = [
"wandb",
"comet-ml",
"tensorboard",
"dvclive",
"mlflow",
"matplotlib",
"swanlab",
]
extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"]

extras["sagemaker"] = [
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ class Accelerator:
- `"tensorboard"`
- `"wandb"`
- `"comet_ml"`
- `"swanlab"`
If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
project_config ([`~utils.ProjectConfiguration`], *optional*):
Expand Down
10 changes: 9 additions & 1 deletion src/accelerate/test_utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
is_pytest_available,
is_schedulefree_available,
is_sdaa_available,
is_swanlab_available,
is_tensorboard_available,
is_timm_available,
is_torch_version,
Expand Down Expand Up @@ -482,6 +483,13 @@ def require_dvclive(test_case):
return unittest.skipUnless(is_dvclive_available(), "test requires dvclive")(test_case)


def require_swanlab(test_case):
    """
    Decorator that skips the decorated test unless the `swanlab` package is installed.
    """
    skip_unless_swanlab = unittest.skipUnless(is_swanlab_available(), "test requires swanlab")
    return skip_unless_swanlab(test_case)


def require_pandas(test_case):
"""
Decorator marking a test that requires pandas installed. These tests are skipped when pandas isn't installed
Expand Down Expand Up @@ -536,7 +544,7 @@ def require_matplotlib(test_case):


_atleast_one_tracker_available = (
any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available()
any([is_wandb_available(), is_tensorboard_available(), is_swanlab_available()]) and not is_comet_ml_available()
)


Expand Down
106 changes: 106 additions & 0 deletions src/accelerate/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
is_comet_ml_available,
is_dvclive_available,
is_mlflow_available,
is_swanlab_available,
is_tensorboard_available,
is_wandb_available,
listify,
Expand Down Expand Up @@ -63,6 +64,9 @@
if is_dvclive_available():
_available_trackers.append(LoggerType.DVCLIVE)

if is_swanlab_available():
_available_trackers.append(LoggerType.SWANLAB)

logger = get_logger(__name__)


Expand Down Expand Up @@ -1061,6 +1065,106 @@ def finish(self):
self.live.end()


class SwanLabTracker(GeneralTracker):
    """
    A `Tracker` class that supports `swanlab`. Should be initialized at the start of your script.

    Args:
        run_name (`str`):
            The name of the experiment run. It is forwarded to `swanlab.init` as the `project` argument.
        **kwargs (additional keyword arguments, *optional*):
            Additional keyword arguments passed along to the `swanlab.init` method.
    """

    name = "swanlab"
    requires_logging_directory = False
    main_process_only = False

    def __init__(self, run_name: str, **kwargs):
        super().__init__()
        self.run_name = run_name
        # Stored here and only consumed in `start`, when `swanlab.init` is actually called.
        self.init_kwargs = kwargs

    @on_main_process
    def start(self):
        """
        Creates the SwanLab run via `swanlab.init` and records the framework name in the run configuration.
        """
        import swanlab

        self.run = swanlab.init(project=self.run_name, **self.init_kwargs)
        swanlab.config["FRAMEWORK"] = "🤗Accelerate"  # add accelerate logo in config
        logger.debug(f"Initialized SwanLab project {self.run_name}")
        logger.debug(
            "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
        )

    @property
    def tracker(self):
        """
        The underlying SwanLab run object created by `start`.
        """
        return self.run

    @on_main_process
    def store_init_configuration(self, values: dict):
        """
        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.

        Args:
            values (Dictionary `str` to `bool`, `str`, `float` or `int`):
                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
                `str`, `float`, `int`, or `None`.
        """
        import swanlab

        # `allow_val_change` is a wandb-only option. SwanLab's `config.update` treats extra keyword arguments as
        # config entries, so passing it would store a spurious `allow_val_change` hyperparameter.
        swanlab.config.update(values)
        logger.debug("Stored initial configuration hyperparameters to SwanLab")

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Logs `values` to the current run.

        Args:
            values (`Dict[str, DataType]`):
                Values to be logged as key-value pairs. Keys must be strings made of 0-9, a-z, A-Z, " ", "_", "-",
                "/". Values must be a `float`, a `float`-convertible object, an `int` or a `swanlab.data.BaseType`.
            step (`int`, *optional*):
                The step number of the current data. If not provided, it will be automatically incremented. If the
                step is duplicated, the data will be ignored.
            kwargs:
                Additional keyword arguments passed along to the `swanlab.log` method, such as:
                    print_to_console (`bool`, *optional*):
                        Whether to print the data to the console. Defaults to False.
        """
        self.run.log(values, step=step, **kwargs)
        logger.debug("Successfully logged to SwanLab")

    @on_main_process
    def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Logs `images` to the current run.

        Args:
            values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
                Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
                `PIL.Image`; each image is wrapped in `swanlab.Image` before logging.
            step (`int`, *optional*):
                The run step. If included, the log will be affiliated with this step.
            kwargs:
                Additional keyword arguments passed along to the `swanlab.log` method, such as:
                    print_to_console (`bool`, *optional*):
                        Whether to print the data to the console. Defaults to False.
        """
        import swanlab

        # Log every key in a single call. With separate calls per key and `step=None`, each call would
        # auto-increment the step, scattering images that belong together across different steps.
        images = {key: [swanlab.Image(image) for image in image_list] for key, image_list in values.items()}
        if images:
            self.log(images, step=step, **kwargs)
        logger.debug("Successfully logged images to SwanLab")

    @on_main_process
    def finish(self):
        """
        Closes the `swanlab` run.
        """
        self.run.finish()
        logger.debug("SwanLab run closed")


LOGGER_TYPE_TO_CLASS = {
"aim": AimTracker,
"comet_ml": CometMLTracker,
Expand All @@ -1069,6 +1173,7 @@ def finish(self):
"wandb": WandBTracker,
"clearml": ClearMLTracker,
"dvclive": DVCLiveTracker,
"swanlab": SwanLabTracker,
}


Expand All @@ -1093,6 +1198,7 @@ def filter_trackers(
- `"comet_ml"`
- `"mlflow"`
- `"dvclive"`
- `"swanlab"`
If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
logging_dir (`str`, `os.PathLike`, *optional*):
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
is_sagemaker_available,
is_schedulefree_available,
is_sdaa_available,
is_swanlab_available,
is_tensorboard_available,
is_timm_available,
is_torch_xla_available,
Expand Down
2 changes: 2 additions & 0 deletions src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ class LoggerType(BaseEnum):
- **WANDB** -- wandb as an experiment tracker
- **COMETML** -- comet_ml as an experiment tracker
- **DVCLIVE** -- dvclive as an experiment tracker
- **SWANLAB** -- swanlab as an experiment tracker
"""

ALL = "all"
Expand All @@ -711,6 +712,7 @@ class LoggerType(BaseEnum):
MLFLOW = "mlflow"
CLEARML = "clearml"
DVCLIVE = "dvclive"
SWANLAB = "swanlab"


class PrecisionType(str, BaseEnum):
Expand Down
4 changes: 4 additions & 0 deletions src/accelerate/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ def is_comet_ml_available():
return _is_package_available("comet_ml")


def is_swanlab_available():
    """Return whether the optional `swanlab` tracking package is found by `_is_package_available`."""
    swanlab_package = "swanlab"
    return _is_package_available(swanlab_package)


def is_boto3_available():
    """Return whether the optional `boto3` package is found by `_is_package_available`."""
    boto3_package = "boto3"
    return _is_package_available(boto3_package)

Expand Down
5 changes: 4 additions & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,10 @@ def test_schedulefree(self):
run_command(self.launch_args + testargs)

@require_trackers
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
@mock.patch.dict(
os.environ,
{"WANDB_MODE": "offline", "DVCLIVE_TEST": "true", "SWANLAB_MODE": "offline"},
)
def test_tracking(self):
with tempfile.TemporaryDirectory() as tmpdir:
testargs = f"""
Expand Down
Loading
Loading
0