[memory leak] Replace GradientState -> DataLoader reference with weakrefs #3391
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[memory leak
] Replace GradientState -> DataLoader reference with weakrefs
#3391
Conversation
So they can be cleaned up. Otherwise, they will always stay in memory, leading to notable memory leaks. Note: even `accelerator.free_memory()` did not work!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
First of all, I have to say this is an outstanding PR description. Thanks a lot for putting in so much effort.
Generally, I agree with the approach taken here (but I'm not an expert on Python gc). I stumbled a bit trying to understand the implementation; I think a few clarifying comments could help ensure that readers understand why it needs to be implemented as is. And I wonder if the test could be made a bit more realistic.
Thanks for the adjustments. LGTM.
Thanks for this! I'll include it as part of the release this week.
Pull Request overview

- Replace the `GradientState` -> `DataLoader` reference with `weakref.ref(...)` to break the undetectable cycle

Details

Bug Report
After preparing a dataloader and starting to use it, the `DataLoaderShard` and underlying `GradientState` cannot be cleaned up by garbage collection, even if some class instance (e.g. a `Trainer`) which called `self.dataloader = self.accelerator.prepare(dataloader)` is deleted.

As you can see here, if we delete the `Foo` instance, the `DataLoaderShard` still exists! If we uncomment the lines regarding `active_dataloader` and `dataloader_references` (we need to uncomment both!), then the `DataLoaderShard` does get deleted.

Here are 2 bigger scripts that show this in action in `transformers` and `sentence-transformers`. It is a big problem in `sentence-transformers` because the `DataLoaderShard` being kept in memory also keeps the data collator in memory, and the `sentence-transformers` data collator uses a tokenization function that is bound to the `model`. In short: this prevents the `model` from being cleaned up!!!

Memory leak in Transformers

Memory leak in Sentence Transformers
Why?
I believe this issue is caused by 1) a cycle between 2 classes, `DataLoaderShard` and `GradientState`, and 2) `GradientState` using some `__dict__` hacking:

accelerate/src/accelerate/state.py
Lines 1161 to 1164 in f19b957
I can reproduce this with much smaller classes here:
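The reproduction script itself isn't embedded above; a minimal sketch along those lines (hypothetical class names, not the actual accelerate code) might look like this:

```python
import gc
import weakref

class Inner:
    # Like GradientState: every instance shares one class-level dict
    # as its __dict__ (a simplified, hypothetical reproduction)
    _shared_state = {}

    def __init__(self):
        self.__dict__ = self._shared_state

    def register(self, outer):
        # The outer instance stores itself on the inner instance,
        # closing the Outer <-> Inner cycle
        self.outer = outer

class Outer:
    def __init__(self):
        self.inner = Inner()
        self.inner.register(self)

outer = Outer()
outer_ref = weakref.ref(outer)

del outer
gc.collect()

# The class-level _shared_state dict still holds a strong reference
# to the Outer instance, so it is never freed
print(outer_ref() is None)  # False: the instance leaked
```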
Here, the inner class (i.e. like `GradientState`) maps the `__dict__` to some class attribute. Then, an outer class (i.e. like `DataLoaderShard`) instance can call a method to add itself to the `__dict__` of the inner class instance. This creates a cycle, which is very common in Python classes. Normally (i.e. without the `__dict__`-hacking), the Garbage Collector can recognize that deleting the outer class instance means that the inner class instance can also go, but it seems messing with `__dict__` prevents that.

In short: the instances never get removed, even when we delete the outer class instance.
The Fix
To prevent the Garbage Collector from getting messed up by our `DataLoaderShard` <-> `GradientState` cycle, we can update one of the directions of the referencing with a "weak reference". If only weak references to an object exist, the Python GC will collect it. In short: there's no "cycle".

Weak references are very simple (note: they're standard library, this doesn't add dependencies):
- You create a weak reference with `variable_ref = weakref.ref(variable)`.
- You can access `variable` again by calling the ref: `variable_ref()`. If the `variable` was killed, we get `None` by calling `variable_ref()`.
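For instance, a quick self-contained demonstration of that behaviour (`Dummy` is just a placeholder class):

```python
import gc
import weakref

class Dummy:
    pass

variable = Dummy()
variable_ref = weakref.ref(variable)  # create the weak reference
print(variable_ref() is variable)     # True: calling the ref returns the object

del variable                          # drop the only strong reference
gc.collect()
print(variable_ref())                 # None: the object was collected
```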
We can do this very neatly & automatically with a property:
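The original snippet isn't rendered here; a minimal sketch of such a property (illustrative names, not the actual accelerate code) could be:

```python
import weakref

class Inner:
    @property
    def outer(self):
        # Dereference the stored weakref; returns None once the
        # outer object has been garbage collected
        return self._outer_ref()

    @outer.setter
    def outer(self, value):
        # Store only a weak reference, so Inner does not keep
        # the outer object alive
        self._outer_ref = weakref.ref(value)
```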
Here, setting `outer` triggers the setter, and we store the weak reference. Getting `outer` calls the getter, and we grab the actual variable by calling the weak reference that we stored.

If we call the same script that was broken before, but with these 2 property methods, then we get:
Perfect, they get cleaned up!
The PR:
For this PR, I applied the fix above.
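Concretely, the shape of the change is roughly the following (an illustrative sketch with a hypothetical class name, not the exact code merged in accelerate):

```python
import weakref

class GradientStateSketch:
    def __init__(self):
        # The list starts with None: "no active dataloader" yet
        self.dataloader_references = [None]

    @property
    def active_dataloader(self):
        # The active dataloader is simply the last registered reference
        return self.dataloader_references[-1]

    @property
    def dataloader_references(self):
        # Dereference each weakref on access; dead refs become None
        return [ref if ref is None else ref() for ref in self._dataloader_references]

    @dataloader_references.setter
    def dataloader_references(self, references):
        # weakref.ref(None) raises TypeError, so None entries are stored as-is
        self._dataloader_references = [
            ref if ref is None else weakref.ref(ref) for ref in references
        ]
```

Note that callers then reassign the whole list, e.g. `state.dataloader_references = state.dataloader_references + [dataloader]`, rather than using `.append(...)`, so that the setter runs.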
Additionally, I replace `active_dataloader` with a property that simply gets the last element in `dataloader_references`, which is what the previous behaviour already was.

Also note that you cannot `weakref.ref(None)`, so if we're adding `None`, then we just add it outright. And we can't use `.append(...)` or `.remove(...)`, as we have to trigger the setter.

I've added a simple test case that also uses weak references to the objects that should be destroyed. If they indeed get destroyed, then calling the weak reference returns `None`. Before this PR, they would simply return the class instances, because they didn't get destroyed.

I've verified that the fix solves the `transformers` and `sentence-transformers` memory leaks! Both scripts run without keeping all `DataLoaderShard` instances (and dependents like the data collator and/or model) in memory.

Before submitting
Who can review?
@muellerzr @BenjaminBossan @SunMarc