CUDA Memory leak w/ torch.compile in both stable and trunk #119607


Closed

xmfan opened this issue Feb 9, 2024 · 27 comments
Labels
high priority · module: dynamo · oncall: pt2 · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@xmfan (Member) commented Feb 9, 2024

🐛 Describe the bug

Models traced with torch.compile don't seem to be freeing CUDA memory:

import torch
import gc

def main():
    x = torch.randn(1000, 3000, device="cuda", requires_grad=True)
    model = torch.nn.Sequential(
        torch.nn.Linear(3000, 10000),
        torch.nn.ReLU(),
        torch.nn.Linear(10000, 50000),
        torch.nn.ReLU(),
        torch.nn.Linear(50000, 20000),
        torch.nn.ReLU(),
        torch.nn.Linear(20000, 1234),
    ).to("cuda")
    model = torch.compile(model, backend="eager")
    model(x)

if __name__ == "__main__":
    main()

    # Everything CUDA-related is local to main(), so the memory should be
    # reclaimable here. Tried clearing it in a few ways:
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch._C._cuda_clearCublasWorkspaces()
    gc.collect()

    print(f"{torch.cuda.memory_allocated()/1e9} GB!!")  # 6.219729408 GB!!

One high priority use case for fixing this is compiled autograd, which calls torch.compile once for the compiled forward and once for the compiled backward, leading to 2x memory use.
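For reference, a minimal sketch of that compiled-autograd pattern, assuming the trunk-era torch._dynamo.compiled_autograd.enable context manager (an internal API whose name and signature may differ across versions): the forward is compiled by torch.compile, and the backward graph is compiled separately, so two sets of compiled artifacts and their captured state are live at once.

import torch

def compiler_fn(gm):
    # Compile the captured autograd graph the same way the forward is compiled.
    return torch.compile(gm, backend="eager")

model = torch.compile(
    torch.nn.Linear(3000, 1234).to("cuda"), backend="eager"
)  # first torch.compile use: the forward
x = torch.randn(1000, 3000, device="cuda", requires_grad=True)

with torch._dynamo.compiled_autograd.enable(compiler_fn):
    out = model(x)
    out.sum().backward()  # second compilation: the backward graph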

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @Chillee

Versions

2.2.0
trunk

@xmfan changed the title: Memory leak in nightly → CUDA Memory leak w/ torch.compile in nightly Feb 9, 2024
@xmfan changed the title: CUDA Memory leak w/ torch.compile in nightly → CUDA Memory leak w/ torch.compile in both stable and nightly Feb 10, 2024
@xmfan (Member, Author) commented Feb 10, 2024

Marking as dynamo since it happens with backend="eager".

@xmfan changed the title: CUDA Memory leak w/ torch.compile in both stable and nightly → CUDA Memory leak w/ torch.compile in both stable and trunk Feb 10, 2024
@anijain2305 (Contributor)

Cc @williamwen42

@malfet (Contributor) commented Feb 12, 2024

Dynamo has its own mechanism for cleaning compiled artifacts; wouldn't that be sufficient? And perhaps something like that is needed on the Triton side as well.
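Presumably the mechanism in question is torch._dynamo.reset(); a minimal sketch of calling it after the original repro (as the rest of the thread shows, the leak here turns out to be a stray reference to the module rather than cache growth, so clearing caches alone may not help):

import torch

# Drop Dynamo's compilation caches (compiled code, guards, ...).
torch._dynamo.reset()
# Return cached, unused blocks in the CUDA caching allocator to the driver.
torch.cuda.empty_cache()
print(f"{torch.cuda.memory_allocated() / 1e9} GB still allocated")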

@williamwen42 (Member) commented Feb 12, 2024

This is not intended behavior, but I find that if I wrap the torch.nn.Sequential inside a custom nn.Module, then the memory gets freed:

import gc
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # self.fc1 = torch.nn.Linear(3000, 50000)
        self.fc1 = torch.nn.Sequential(
            torch.nn.Linear(3000, 10000),
            torch.nn.ReLU(),
            torch.nn.Linear(10000, 50000),
            torch.nn.ReLU(),
            torch.nn.Linear(50000, 20000),
            torch.nn.ReLU(),
            torch.nn.Linear(20000, 1234),
        )

    def forward(self, out):
        out = self.fc1(out)
        return out

def run(compile):
    mod = MyModel().cuda()
    if compile:
        mod = torch.compile(mod, backend="eager")
    inp = torch.rand(10000, 3000).cuda()
    mod(inp)

def clean_and_report_memory():
    gc.collect()
    print(f"max memory: {torch.cuda.max_memory_allocated()}, curr memory: {torch.cuda.memory_allocated()}")

run(False)
clean_and_report_memory()

run(True)
clean_and_report_memory()

torch._dynamo.reset()
clean_and_report_memory()

Output:

max memory: 2730451456, curr memory: 8519680
max memory: 2730451456, curr memory: 8519680
max memory: 2730451456, curr memory: 8519680

I will continue to investigate why memory is not being freed in the original code snippet.

@gchanan (Contributor) commented Feb 12, 2024

Is it a regression?

@soulitzer added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) and removed the triage review label Feb 13, 2024
@atalman (Contributor) commented Feb 14, 2024

Moving to release 2.2.2, since the fix is not out yet.

@atalman modified the milestones: 2.2.1 → 2.2.2 Feb 14, 2024
@williamwen42 (Member) commented Feb 15, 2024

I've simplified the repro:

import gc
import weakref
import torch

mod = torch.nn.Linear(3000, 50000).cuda()
def fn(x):
    return mod(x)

ref = weakref.ref(mod, lambda _: print("mod deleted"))
weakref.finalize(fn, lambda: print("fn deleted"))

inp = torch.rand(10000, 3000).cuda()

torch.compile(backend="eager")(fn)(inp)

del mod
del fn

gc.collect()

# expect finalizers to run before this point
breakpoint()

It seems that dynamo holds on to a reference to mod somewhere.
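One way to hunt for that reference (a debugging aid added here for illustration, not part of the original repro) is to resolve the weakref after the gc.collect() and walk the referrers of the still-alive module:

import gc

obj = ref()  # `ref` is the weakref to `mod` created in the repro above
if obj is not None:
    # `mod` is still alive; list what is keeping it so. The current frame
    # will show up too, since `obj` is a local here.
    for referrer in gc.get_referrers(obj):
        print(type(referrer), str(referrer)[:120])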

@lezcano (Collaborator) commented Feb 22, 2024

A shot in the dark, but might this be related to the memory leak that you were once hunting, @Fidget-Spinner?

@Fidget-Spinner (Collaborator)

The memory leak I was once hunting concerned a circular reference between the compiled code cache and the code object itself, IIRC. If anyone is aware of a reference from the compiled artefact object to the symbolic evaluator, that might be a source of leaks, because the symbolic stuff definitely holds a reference to mod in this case. Otherwise, it wouldn't be that.
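For illustration only (plain Python, not Dynamo's actual cache layout): a function caught in a reference cycle with a cache entry can only be reclaimed by the cycle collector, and leaks outright if the cache entry stays reachable from elsewhere.

import gc
import weakref

def fn(x):
    return x + 1

# Hypothetical cache entry referencing both the code object and the function.
entry = {"code": fn.__code__, "fn": fn}
fn.cache_entry = entry        # fn -> entry -> fn: a reference cycle

ref = weakref.ref(fn)
del fn, entry
print(ref() is None)          # False: refcounting alone cannot break the cycle
gc.collect()
print(ref() is None)          # True: the cycle collector reclaims it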

@lezcano (Collaborator) commented Feb 22, 2024

Alas, @IvanYashchuk's patch in #109422 didn't seem to help with this one, so I guess there's something else going on here.

@ezyang (Contributor) commented Feb 26, 2024

@williamwen42 any updates on this?

@williamwen42 (Member) commented Feb 26, 2024

Still working on this - got blocked recently because #112090 was happening again, but it has been resolved.

@williamwen42 (Member)

The repro no longer leaks due to #120578, but I still observe a leak if the model is a builtin module (e.g. torch.nn.Linear), or if we are on Python 3.11+.
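A minimal sketch of the remaining builtin-module case described above (compiling a torch.nn.Linear directly), for anyone who wants to reproduce it:

import gc
import torch

def run():
    mod = torch.compile(torch.nn.Linear(3000, 50000).cuda(), backend="eager")
    inp = torch.rand(10000, 3000).cuda()
    mod(inp)

run()
gc.collect()
torch._dynamo.reset()
# Expected to drop back toward zero once freed; stays high while the leak is present.
print(f"curr memory: {torch.cuda.memory_allocated()}")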

williamwen42 added a commit that referenced this issue Apr 17, 2024
Fixes #119607 for 3.11+.

In 3.11+, `_PyFrame_FastToLocalsWithError` could implicitly run `COPY_FREE_VARS` on the original frame, leading to double increfs, since the dynamo shadow frame can rerun `COPY_FREE_VARS`. So the solution is to skip the first `COPY_FREE_VARS` instruction in the shadow frame if it was already executed in the original frame.

Also move the location for clearing the original frame in 3.12 to handle error cases more thoroughly.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
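For readers unfamiliar with the bytecode involved: on CPython 3.11+ a function with free variables (such as the fn closing over mod in the earlier repro) begins its bytecode with a COPY_FREE_VARS instruction, which is the instruction the fix skips in the shadow frame when the original frame already ran it. A quick, illustrative way to look at it with the standard library:

import dis

def make(mod):
    def fn(x):
        return mod(x)  # `mod` is a free variable of fn
    return fn

fn = make(print)
# On 3.11+ the prologue is expected to include COPY_FREE_VARS (assumption based
# on the commit message above); inspect the first few opcodes to see it.
print([ins.opname for ins in dis.get_instructions(fn)][:4])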
facebook-github-bot pushed a commit to pytorch/benchmark that referenced this issue Apr 18, 2024
Summary:
Fixes pytorch/pytorch#119607 for 3.11+.

X-link: pytorch/pytorch#124238
Approved by: https://github.com/jansel

Reviewed By: PaliC

Differential Revision: D56289286

Pulled By: williamwen42

fbshipit-source-id: 121abe4d8165d3bb4a2145841a8909bbd23a98dc
@JerrickLiu

@williamwen42 is there a fix for this? I am also experiencing a memory leak with torch.compile and Python 3.11.

@williamwen42 (Member)

Do you have a repro? The fix only recently went in, so you should try the nightly binaries.

@JerrickLiu

Not a local repro, unfortunately. I can try the nightly binary. How long does it take to make it into the default stable installation?

Where can I find the nightly binaries that would have your fix?

@JerrickLiu

I'm also hitting this leak with backend=cudagraphs, if your fix accounts for that

@JerrickLiu commented Apr 19, 2024

@williamwen42 bump on the nightly binary. Can I just use the one found here (https://pytorch.org/get-started/locally/) and select nightly? I have a nightly build, but with the above repro I still see the memory leak, most likely because I don't have your changes. Is there a way to verify I have your change?

@williamwen42 (Member)

Yeah, that's the right link. Give it another day and it should be in the nightlies. The fix will not be in the stable binaries until the next release (2.4). The leak occurs at a very high level in the PT2 stack (dynamo) - it should occur even on the eager backend. Can you confirm this?
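A quick way to check the installed build and re-run the simplified repro from earlier on the eager backend (the exact number printed is machine-dependent; the expectation, once the fix is in, is that it drops to roughly zero):

import gc
import torch

print(torch.__version__)  # nightly builds report a dev/date-suffixed version

mod = torch.nn.Linear(3000, 50000).cuda()
fn = torch.compile(lambda x: mod(x), backend="eager")
fn(torch.rand(10000, 3000).cuda())

del fn, mod
gc.collect()
torch._dynamo.reset()
print(f"curr memory: {torch.cuda.memory_allocated()}")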

pytorch-bot bot pushed a commit that referenced this issue Apr 22, 2024
…120756)

Fixes remaining refleaks found when debugging #119607, tests added in #120657.

Also fixes some tests that xfail: #120631 (not entirely sure why), but introduced tests now fail.

Pull Request resolved: #120756
Approved by: https://github.com/jansel
pytorch-bot bot pushed a commit that referenced this issue Apr 22, 2024
Fixes #119607 for 3.11+.

Pull Request resolved: #124238
Approved by: https://github.com/jansel
@JerrickLiu

@williamwen42 I confirmed with a nightly build that the memory leak is fixed

petrex pushed a commit to petrex/pytorch that referenced this issue May 3, 2024
Fixes pytorch#119607 for 3.11+.

Pull Request resolved: pytorch#124238
Approved by: https://github.com/jansel
pytorchbot pushed a commit that referenced this issue May 13, 2024
Fixes #119607 for 3.11+.

Pull Request resolved: #124238
Approved by: https://github.com/jansel

(cherry picked from commit 812bae0)
williamwen42 added a commit that referenced this issue May 15, 2024
…120756)

Fixes remaining refleaks found when debugging #119607, tests added in #120657.

Pull Request resolved: #120756
Approved by: https://github.com/jansel
atalman pushed a commit that referenced this issue May 22, 2024
…126332)

* [dynamo] use proxies to nn.Module in dynamo generated GraphModules (#120756)

Fixes remaining refleaks found when debugging #119607, tests added in #120657.

Pull Request resolved: #120756
Approved by: https://github.com/jansel
@huydhn (Contributor) commented May 30, 2024

This issue has been fixed on Python 3.11 in the upcoming 2.3.1 release: https://dev-discuss.pytorch.org/t/pytorch-release-2-3-1-final-rc-is-available/2126
