Make memory logging great again by RdoubleA · Pull Request #817 · pytorch/torchtune


Merged · 2 commits · Apr 20, 2024
6 changes: 3 additions & 3 deletions recipes/full_finetune_distributed.py
@@ -295,8 +295,8 @@ def _setup_model(
                model, auto_wrap_policy={modules.TransformerDecoderLayer}
            )
        if self._is_rank_zero:
-            memory_stats = utils.memory_stats_log(device=self._device)
-            log.info(f"Memory Stats after model init:\n{memory_stats}")
+            memory_stats = utils.get_memory_stats(device=self._device)
+            utils.log_memory_stats(memory_stats)

        # synchronize before training begins
        torch.distributed.barrier()
@@ -477,7 +477,7 @@ def train(self) -> None:
                    and self._is_rank_zero
                ):
                    # Log peak memory for iteration
-                    memory_stats = utils.memory_stats_log(device=self._device)
+                    memory_stats = utils.get_memory_stats(device=self._device)
                    self._metric_logger.log_dict(
                        memory_stats, step=self.total_training_steps
                    )
7 changes: 4 additions & 3 deletions recipes/full_finetune_single_device.py
@@ -261,8 +261,9 @@ def _setup_model(
            log.info("Compiling model with torch.compile...")
            model = utils.wrap_compile(model)
        if self._device.type == "cuda":
-            memory_stats = utils.memory_stats_log(device=self._device)
-            log.info(f"Memory Stats after model init:\n{memory_stats}")
+            memory_stats = utils.get_memory_stats(device=self._device)
+            utils.log_memory_stats(memory_stats)
+
        return model

    def _setup_optimizer(
@@ -444,7 +445,7 @@ def train(self) -> None:
                    self.total_training_steps % self._log_peak_memory_every_n_steps == 0
                    and self._device.type == "cuda"
                ):
-                    memory_stats = utils.memory_stats_log(device=self._device)
+                    memory_stats = utils.get_memory_stats(device=self._device)
                    self._metric_logger.log_dict(
                        memory_stats, step=self.total_training_steps
                    )
7 changes: 3 additions & 4 deletions recipes/lora_dpo_single_device.py
@@ -196,7 +196,6 @@ def setup(self, cfg: DictConfig) -> None:
        self._steps_per_epoch = (
            len(self._dataloader) // self._gradient_accumulation_steps
        )
-        steps_per_epoch = len(self._dataloader)
        if (
            self.max_steps_per_epoch is not None
            and self.max_steps_per_epoch < self._steps_per_epoch
@@ -257,8 +256,8 @@ def _setup_model(

        log.info(f"Model is initialized with precision {self._dtype}.")
        if self._device == torch.device("cuda"):
-            memory_stats = utils.memory_stats_log(device=self._device)
-            log.info(f"Memory Stats after model init:\n{memory_stats}")
+            memory_stats = utils.get_memory_stats(device=self._device)
+            utils.log_memory_stats(memory_stats)
        return model

    def _setup_optimizer(
@@ -523,7 +522,7 @@ def train(self) -> None:
                    and self._device == torch.device("cuda")
                ):
                    # Log peak memory for iteration
-                    memory_stats = utils.memory_stats_log(device=self._device)
+                    memory_stats = utils.get_memory_stats(device=self._device)
                    self._metric_logger.log_dict(
                        memory_stats, step=self.total_training_steps
                    )
6 changes: 3 additions & 3 deletions recipes/lora_finetune_distributed.py
@@ -356,8 +356,8 @@ def _setup_model(
                model, auto_wrap_policy={modules.TransformerDecoderLayer}
            )
        if self._is_rank_zero:
-            memory_stats = utils.memory_stats_log(device=self._device)
-            log.info(f"Memory Stats after model init:\n{memory_stats}")
+            memory_stats = utils.get_memory_stats(device=self._device)
+            utils.log_memory_stats(memory_stats)

        # synchronize before training begins
        torch.distributed.barrier()
@@ -572,7 +572,7 @@ def train(self) -> None:
                    and self._is_rank_zero
                ):
                    # Log peak memory for iteration
-                    memory_stats = utils.memory_stats_log(device=self._device)
+                    memory_stats = utils.get_memory_stats(device=self._device)
                    self._metric_logger.log_dict(
                        memory_stats, step=self.total_training_steps
                    )
7 changes: 3 additions & 4 deletions recipes/lora_finetune_single_device.py
@@ -224,7 +224,6 @@ def setup(self, cfg: DictConfig) -> None:
        self._steps_per_epoch = (
            len(self._dataloader) // self._gradient_accumulation_steps
        )
-        steps_per_epoch = len(self._dataloader)
        if (
            self.max_steps_per_epoch is not None
            and self.max_steps_per_epoch < self._steps_per_epoch
@@ -296,8 +295,8 @@ def _setup_model(
            log.info("Compiling model with torch.compile...")
            model = utils.wrap_compile(model)
        if self._device.type == "cuda":
-            memory_stats = utils.memory_stats_log(device=self._device)
-            log.info(f"Memory Stats after model init:\n{memory_stats}")
+            memory_stats = utils.get_memory_stats(device=self._device)
+            utils.log_memory_stats(memory_stats)
        return model

    def _setup_optimizer(
@@ -479,7 +478,7 @@ def train(self) -> None:
                    and self._device.type == "cuda"
                ):
                    # Log peak memory for iteration
-                    memory_stats = utils.memory_stats_log(device=self._device)
+                    memory_stats = utils.get_memory_stats(device=self._device)
                    self._metric_logger.log_dict(
                        memory_stats, step=self.total_training_steps
                    )
6 changes: 4 additions & 2 deletions torchtune/utils/__init__.py
@@ -41,7 +41,8 @@
from .memory import (  # noqa
    cleanup_before_training,
    create_optim_in_bwd_wrapper,
-    memory_stats_log,
+    get_memory_stats,
+    log_memory_stats,
    OptimizerInBackwardWrapper,
    register_optim_in_bwd_hooks,
    set_activation_checkpointing,
@@ -62,7 +63,8 @@
    "transform_opt_state_dict",
    "validate_checkpoint",
    "get_autocast",
-    "memory_stats_log",
+    "get_memory_stats",
+    "log_memory_stats",
    "get_device",
    "get_dtype",
    "wrap_fsdp",
24 changes: 23 additions & 1 deletion torchtune/utils/memory.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import gc
+import logging

from typing import Any, Dict, Optional, Set

@@ -15,6 +16,9 @@
    apply_activation_checkpointing,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+from torchtune.utils.logging import get_logger
+
+_log: logging.Logger = get_logger()


def set_activation_checkpointing(
@@ -160,7 +164,7 @@ def optim_step(param) -> None:
        p.register_post_accumulate_grad_hook(optim_step)


-def memory_stats_log(device: torch.device, reset_stats: bool = True) -> dict:
+def get_memory_stats(device: torch.device, reset_stats: bool = True) -> dict:
    """
    Computes a memory summary for the passed in device. If ``reset_stats`` is ``True``, this will
    also reset CUDA's peak memory tracking. This is useful to get data around relative use of peak
@@ -196,3 +200,21 @@ def memory_stats_log(device: torch.device, reset_stats: bool = True) -> dict:
        "peak_memory_reserved": peak_mem_reserved,
    }
    return memory_stats
+
+
+def log_memory_stats(stats: Dict[str, float]) -> None:
+    """
+    Logs a dict containing memory stats to the logger. This expects the fields
+    `peak_memory_active`, `peak_memory_alloc`, and `peak_memory_reserved` as
+    returned by `get_memory_stats`.
+
+    Args:
+        stats (Dict[str, float]): A dictionary containing the peak memory active, peak memory
+            allocated, and peak memory reserved stats.
+    """
+    _log.info(
+        "Memory stats after model init:"
+        f"\n\tGPU peak memory allocation: {stats['peak_memory_alloc']:.2f} GB"
+        f"\n\tGPU peak memory reserved: {stats['peak_memory_reserved']:.2f} GB"
+        f"\n\tGPU peak memory active: {stats['peak_memory_active']:.2f} GB"
+    )
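
For context, a minimal sketch of how the renamed and newly added helpers fit together, mirroring the recipe changes above. This is not part of the PR: the standalone device handle is a placeholder for a recipe's self._device, and the metric-logger call is shown only as a comment because the logger lives on the recipe.

# Sketch only: assumes a CUDA device is available and that torchtune's utils
# namespace exposes get_memory_stats / log_memory_stats as added in this PR.
import torch
from torchtune import utils

device = torch.device("cuda")

# After model init: collect peak-memory stats (reset_stats=True by default,
# which also resets CUDA's peak tracking) and emit a readable summary.
memory_stats = utils.get_memory_stats(device=device)
utils.log_memory_stats(memory_stats)

# During training, the same dict of floats can be forwarded to a metric logger,
# as the recipes do:
#     self._metric_logger.log_dict(memory_stats, step=self.total_training_steps)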