[`memory leak`] Replace GradientState -> DataLoader reference with weakrefs by tomaarsen · Pull Request #3391 · huggingface/accelerate · GitHub

[memory leak] Replace GradientState -> DataLoader reference with weakrefs #3391

Merged

Conversation

tomaarsen
Member

Pull Request overview

  • Replace GradientState -> DataLoader reference with weakref.ref(...) to break the undetectable cycle

Details

Bug Report

After preparing a dataloader and starting to iterate over it, the DataLoaderShard and the underlying GradientState cannot be cleaned up by the garbage collector, even if the class instance (e.g. a Trainer) that called self.dataloader = self.accelerator.prepare(dataloader) is deleted.

import gc
from torch.utils.data import DataLoader
import accelerate

class Foo:
    def __init__(self):
        dataloader = DataLoader([1, 2, 3, 4, 5], batch_size=2, shuffle=True)
        self.accelerator = accelerate.Accelerator()
        self.dataloader = self.accelerator.prepare(dataloader)
        self.iter = iter(self.dataloader)
        print(next(self.iter))

def get_dls_instances():
    # Utility function to find all DataLoaderShard instances in all memory, even if they don't have a global reference
    instances = []
    for obj in gc.get_objects():
        try:
            if isinstance(obj, accelerate.data_loader.DataLoaderShard):
                instances.append(obj)
        except ReferenceError:
            pass
    return instances

instance = Foo()

print(get_dls_instances())
# [<accelerate.data_loader.DataLoaderShard object at 0x00000260D85F3AD0>]

# Attempt to remove the DataLoaderShard instance
instance.accelerator.free_memory()
# instance.accelerator.gradient_state.active_dataloader = None
# instance.accelerator.gradient_state.dataloader_references = []
del instance

instances = get_dls_instances()
print(instances)
# [<accelerate.data_loader.DataLoaderShard object at 0x00000260D85F3AD0>]
# Uh-oh! It still exists.

As you can see, the DataLoaderShard still exists after we delete the Foo instance!
If we uncomment the lines that reset active_dataloader and dataloader_references (both lines are needed), then the DataLoaderShard does get deleted.

Here are two larger scripts that show this in action in transformers and sentence-transformers. It is a big problem in sentence-transformers because keeping the DataLoaderShard in memory also keeps the data collator in memory, and the sentence-transformers data collator uses a tokenization function that is bound to the model. In short: this prevents the model from being cleaned up!

Memory leak in Transformers
import gc
import accelerate
from datasets import load_dataset
import torch
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
import evaluate
import numpy as np
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer


def get_dls_instances():
    instances = []
    for obj in gc.get_objects():
        try:
            if isinstance(obj, accelerate.data_loader.DataLoaderShard):
                instances.append(obj)
        except ReferenceError:
            pass
    return instances

for it in range(5):

    imdb = load_dataset("imdb")
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    def preprocess_function(examples):
        return tokenizer(examples["text"], truncation=True)

    tokenized_imdb = imdb.map(preprocess_function, batched=True)

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    accuracy = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return accuracy.compute(predictions=predictions, references=labels)

    id2label = {0: "NEGATIVE", 1: "POSITIVE"}
    label2id = {"NEGATIVE": 0, "POSITIVE": 1}

    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
    )

    training_args = TrainingArguments(
        output_dir="my_awesome_model",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=2,
        weight_decay=0.01,
        evaluation_strategy="no",
        save_strategy="no",
        load_best_model_at_end=True,
        push_to_hub=True,
        max_steps=6,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_imdb["train"],
        eval_dataset=tokenized_imdb["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    trainer.accelerator.free_memory()

    del imdb, tokenized_imdb, data_collator, accuracy, model, training_args, trainer

    gc.collect()
    torch.cuda.empty_cache()

    print(get_dls_instances())
Memory leak in Sentence Transformers
import gc
import logging
from datetime import datetime

import accelerate
from datasets import load_dataset
import torch

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import TripletLoss
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
model_name = "distilbert-base-uncased"
batch_size = 16
num_train_epochs = 1

for it in range(5):
    output_dir = "output/training-wikipedia-sections-" + model_name + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically
    # create one with "mean" pooling.
    model = SentenceTransformer(model_name)
    # If we want, we can limit the maximum sequence length for the model
    # model.max_seq_length = 75
    # logging.info(model)

    # 2. Load the Wikipedia-Sections dataset: https://huggingface.co/datasets/sentence-transformers/wikipedia-sections
    train_dataset = load_dataset("sentence-transformers/wikipedia-sections", "triplet", split="train").select(
        range(10_000)
    )
    logging.info(train_dataset)

    # 3. Define our training loss
    # TripletLoss (https://sbert.net/docs/package_reference/sentence_transformer/losses.html#tripletloss) needs three text columns
    train_loss = TripletLoss(model)


    # 5. Define the training arguments
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=output_dir,
        # Optional training parameters:
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_ratio=0.1,
        max_steps=6,
        fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=False,  # Set to True if you have a GPU that supports BF16
        # Optional tracking/debugging parameters:
        eval_strategy="no",
        save_strategy="no",
        logging_steps=100,
        run_name="wikipedia-sections-triplet",  # Will be used in W&B if `wandb` is installed
    )

    # 6. Create the trainer & start training
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=train_loss,
    )
    trainer.train()

    
    del train_dataset, train_loss, args, trainer, model

    gc.collect()
    torch.cuda.empty_cache()

    print(f'iter: {it} memory_allocated: {torch.cuda.memory_allocated() / 1024**3}')
    print(f'iter: {it} memory_reserved:  {torch.cuda.memory_reserved() / 1024**3}')

    dl_shards = []
    for obj in gc.get_objects():
        try:
            if isinstance(obj, accelerate.data_loader.DataLoaderShard):
                dl_shards.append(obj)
        except ReferenceError:
            pass
    print(dl_shards)

"""
iter: 0 memory_allocated: 0.2649421691894531
iter: 0 memory_reserved:  0.3046875

iter: 1 memory_allocated: 0.5138320922851562
iter: 1 memory_reserved:  0.548828125

iter: 2 memory_allocated: 0.7621116638183594
iter: 2 memory_reserved:  0.8125

iter: 3 memory_allocated: 1.0103912353515625
iter: 3 memory_reserved:  1.076171875

iter: 4 memory_allocated: 1.2586708068847656
iter: 4 memory_reserved:  1.33984375
"""

Why?

I believe this issue is caused by two things: 1) a reference cycle between two classes, DataLoaderShard and GradientState, and 2) GradientState doing some __dict__ hacking:

class GradientState:
    _shared_state = SharedDict()

    def __init__(self, gradient_accumulation_plugin: Optional[GradientAccumulationPlugin] = None):
        self.__dict__ = self._shared_state

I can reproduce this with much smaller classes here:

import gc

class InnerClass:
    _state = dict()

    def __init__(self):
        self.__dict__ = self._state

    def register(self, outer):
        self.outer = outer

class OuterClass:
    def __init__(self):
        self.info = "I am the outer class"
        self.inner = InnerClass()
        self.inner.register(self)

def get_instances():
    instances = []
    for obj in gc.get_objects():
        try:
            if isinstance(obj, (InnerClass, OuterClass)):
                instances.append(obj)
        except ReferenceError:
            pass
    return instances


outer = OuterClass()

print(get_instances())
# [<__main__.OuterClass object at 0x0000019762014750>, <__main__.InnerClass object at 0x0000019762014050>]

del outer
print(get_instances())
# [<__main__.OuterClass object at 0x0000019762014750>, <__main__.InnerClass object at 0x0000019762014050>]

Here, the inner class (i.e. like GradientState) maps its __dict__ to a class attribute. Then, an outer class instance (i.e. like DataLoaderShard) can call a method to add itself to the __dict__ of the inner class instance. This creates a cycle, which is very common in Python classes. Normally (i.e. without the __dict__ hacking), the garbage collector can recognize that once the outer class instance is deleted, the inner class instance can also go. Here that does not happen: the shared dict is a class attribute, so it stays reachable through the class itself, and everything stored in it, including the outer instance, stays alive.

In short: the instances never get removed, even when we delete the outer class instance.
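To make the cause more concrete, here is a minimal sketch that reuses the toy classes from above and inspects the shared dict directly. The helper name count_instances is only illustrative, and nothing here is accelerate code; it just shows that the class-level dict is what still holds the outer instance.

```python
import gc


class InnerClass:
    _state = dict()  # owned by the class, shared by every instance

    def __init__(self):
        self.__dict__ = self._state

    def register(self, outer):
        self.outer = outer  # ends up in InnerClass._state["outer"]


class OuterClass:
    def __init__(self):
        self.inner = InnerClass()
        self.inner.register(self)


def count_instances(cls):
    # Illustrative helper: count live instances of cls known to the GC.
    count = 0
    for obj in gc.get_objects():
        try:
            if isinstance(obj, cls):
                count += 1
        except ReferenceError:
            pass
    return count


outer = OuterClass()
del outer
gc.collect()

# The class attribute still holds a strong reference to the instance.
print("outer" in InnerClass._state)  # True
print(count_instances(OuterClass))   # 1

# Dropping that reference by hand lets both instances go away.
InnerClass._state.clear()
gc.collect()
print(count_instances(OuterClass))   # 0
```

Clearing the shared dict by hand is roughly what resetting active_dataloader and dataloader_references does in the bug report script; the weak reference approach described next avoids the need for manual cleanup.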

The Fix

To stop our DataLoaderShard <-> GradientState cycle from keeping both objects alive, we can replace one direction of the referencing with a "weak reference". A weak reference does not keep its target alive: if only weak references to an object exist, Python will clean it up. In short: there's no longer a strong "cycle".

Weak references are very simple (note: they're standard library, this doesn't add dependencies):

  1. Create them with variable_ref = weakref.ref(variable).
  2. Get the variable again by calling the ref: variable_ref(). If the variable was killed, we get None by calling variable_ref().
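As a quick illustration of the two steps above, here is a minimal sketch using only the standard library. The Payload class is a made-up placeholder; any instance of a user-defined class would behave the same way.

```python
import gc
import weakref


class Payload:
    """Placeholder class; instances of user-defined classes can be weakly referenced."""


obj = Payload()
obj_ref = weakref.ref(obj)  # step 1: create the weak reference

print(obj_ref() is obj)     # True: calling the ref returns the live object

del obj                     # drop the only strong reference
gc.collect()                # not needed on CPython, but harmless elsewhere
print(obj_ref() is None)    # True: the object is gone, so the ref returns None
```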

We can do this very neatly & automatically with a property:

import weakref

class InnerClass:
    _state = dict()

    def __init__(self):
        self.__dict__ = self._state

    def register(self, outer):
        self.outer = outer

    @property
    def outer(self):
        return self._outer()

    @outer.setter
    def outer(self, value):
        self._outer = weakref.ref(value)

Here, setting outer triggers the setter, and we store the weak reference. Getting the outer calls the getter, and we grab the actual variable by calling the weak reference that we stored.

If we call the same script that was broken before, but with these 2 property methods, then we get:

...

outer = OuterClass()

print(get_instances())
# [<__main__.OuterClass object at 0x000002BEDF988C50>, <__main__.InnerClass object at 0x000002BEDF988550>]

del outer
print(get_instances())
# []

Perfect, they get cleaned up!

The PR:

For this PR, I applied the fix above.
Additionally, I replaced active_dataloader with a property that simply returns the last element of dataloader_references, which matches the previous behaviour.
Also note that you cannot call weakref.ref(None), so if we're adding None, we just store it outright. And we can't use .append(...) or .remove(...) on dataloader_references, as we have to trigger the setter.
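The points above can be sketched roughly as follows. This is not the code from the PR: State and Loader are hypothetical stand-ins for GradientState and a dataloader, the add/remove method names are made up, and starting the list as [None] is an assumption made so that active_dataloader always has a last element.

```python
import weakref


class Loader:
    """Hypothetical stand-in for a dataloader."""


class State:
    """Hypothetical stand-in for GradientState, not the accelerate implementation."""

    def __init__(self):
        self._refs = [None]  # assumption: stored as weakref.ref objects or a literal None

    @property
    def dataloader_references(self):
        # Resolve each weak reference; None entries are passed through as-is.
        return [ref() if ref is not None else None for ref in self._refs]

    @dataloader_references.setter
    def dataloader_references(self, loaders):
        # weakref.ref(None) raises TypeError, so None is stored directly.
        self._refs = [weakref.ref(dl) if dl is not None else None for dl in loaders]

    @property
    def active_dataloader(self):
        # The active dataloader is simply the most recently added reference.
        return self.dataloader_references[-1]

    def add(self, loader):
        # Reassign instead of .append(): the getter returns a fresh list, so
        # mutating that list in place would never reach the setter.
        self.dataloader_references = self.dataloader_references + [loader]

    def remove(self, loader):
        self.dataloader_references = [
            dl for dl in self.dataloader_references if dl is not loader
        ]


state = State()
loader = Loader()
state.add(loader)
print(state.active_dataloader is loader)  # True

del loader
print(state.active_dataloader)  # None on CPython: only a weak reference was held
```

The design choice worth noting is that callers keep working with plain objects: the weak references only exist inside the private list, and the property pair converts in both directions.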

I've added a simple test case that also uses weak references to the objects that should be destroyed. If they indeed get destroyed, then calling the weak reference returns None. Before this PR, they would simply return the class instances - because they didn't get destroyed.
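A test of that kind might look like the sketch below. It is written against toy Inner and Outer classes rather than the real DataLoaderShard and GradientState, so the class and test names are illustrative only and this is not the test added in the PR.

```python
import gc
import weakref


class Inner:
    _state = dict()

    def __init__(self):
        self.__dict__ = self._state

    @property
    def outer(self):
        return self._outer()

    @outer.setter
    def outer(self, value):
        self._outer = weakref.ref(value)


class Outer:
    def __init__(self):
        self.inner = Inner()
        self.inner.outer = self  # goes through the setter, stores a weak reference


def test_outer_is_collected():
    outer = Outer()
    # Watch both objects through weak references so the test itself
    # does not keep them alive.
    outer_ref = weakref.ref(outer)
    inner_ref = weakref.ref(outer.inner)

    del outer
    gc.collect()

    # Calling a weak reference returns None once its target is destroyed.
    assert outer_ref() is None
    assert inner_ref() is None


test_outer_is_collected()
```

If the setter stored the value directly instead of wrapping it in weakref.ref, both assertions would fail, which is the behaviour described for the code before this PR.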

I've verified that the fix solves the transformers and sentence-transformers memory leaks! Both scripts run without keeping all DataLoaderShard (and dependents like data collator and/or model) in memory.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr @BenjaminBossan @SunMarc

So they can be cleaned up. Otherwise, they will always stay in memory, leading to notable memory leaks. Note: even accelerator.free_memory() did not work!
@tomaarsen tomaarsen requested a review from muellerzr February 10, 2025 14:56

Member
@BenjaminBossan BenjaminBossan left a comment

First of all, I have to say this is an outstanding PR description. Thanks a lot for putting in so much effort.

Generally, I agree with the approach taken here (but I'm not an expert on Python gc). I stumbled a bit trying to understand the implementation; I think a few clarifying comments could help readers understand why it needs to be implemented this way. And I wonder if the test could be made a bit more realistic.

Member
@BenjaminBossan BenjaminBossan left a comment

Thanks for the adjustments. LGTM.

Contributor
@muellerzr muellerzr left a comment

Thanks for this! I'll include it as part of the release this week.
