8000 Add x_orig param in SequentialEx to allow split models by Patataman · Pull Request #4042 · fastai/fastai · GitHub

Add x_orig param in SequentialEx to allow split models #4042


Open
wants to merge 1 commit into master

Conversation

@Patataman Patataman commented Aug 20, 2024

I was trying to split a model created with unet_learner and found that, no matter how you split it, if the model contains a ResizeToOrig layer you cannot split it, because that layer uses the original network input as its reference. Until now.
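fastai's SequentialEx stores the block's input on each intermediate result as `.orig`, which is what layers like ResizeToOrig and MergeLayer read. The following torch-free sketch shows the idea behind an `x_orig` override as proposed in this PR's title; `Value`, `SequentialExSketch`, `resize_to_orig`, and `double` are all hypothetical stand-ins, not fastai APIs:

```python
class Value:
    """Stand-in for a tensor that can carry an `.orig` attribute."""
    def __init__(self, data):
        self.data, self.orig = data, None

class SequentialExSketch:
    """Sketch of SequentialEx's forward, with the proposed x_orig override."""
    def __init__(self, *layers):
        self.layers = layers

    def __call__(self, x, x_orig=None):
        # Without x_orig, behave as before: the reference input is x itself.
        ref = x if x_orig is None else x_orig
        res = x
        for layer in self.layers:
            res.orig = ref                    # every layer can see the reference input
            nres = layer(res)
            res.orig = nres.orig = None       # avoid dangling references
            res = nres
        return res

# Toy "ResizeToOrig": trim the data back to the reference input's length.
def resize_to_orig(v):
    return Value(v.data[:len(v.orig.data)])

def double(v):
    return Value(v.data * 2)

block = SequentialExSketch(double, resize_to_orig)
x = Value([1, 2, 3])
y = block(x)                 # reference is x itself
print(y.data)                # [1, 2, 3]

mid = Value([9, 9])          # pretend this came from an earlier split stage
z = block(mid, x_orig=x)     # resize against the *real* original input
print(z.data)                # [9, 9, 9]
```

With `x_orig`, a SequentialEx that is only a slice of a larger model can still resize against the input of the whole network rather than the input of the slice.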

I have this example code that I would like to add to the tests, but I could not find any documentation on how to include it:

from fastai.vision.all import *
from fastai.vision.gan import *
import torch

# Used to split the model created with unet_learner
class SimplifiedModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        layers = [m for m in model.layers]
        m_len = len(model.layers)
        self.layer1 = SequentialEx(*layers[:m_len//3])
        self.layer2 = SequentialEx(*layers[m_len//3:m_len//3*2])
        self.layer3 = SequentialEx(*layers[m_len//3*2:])

    def forward(self, x):
        _x = self.layer1(x)
        # This would have failed before this PR because it takes the original value as the input,
        # which is not the real original input of the net
        _x = self.layer2(_x, x_orig=x)
        return self.layer3(_x, x_orig=x)


n_samples = 100
n_channels = 3 
image_size = 128
n_classes = 2

# Random image tensors and labels
X = torch.randn(n_samples, n_channels, image_size, image_size)
y = torch.randint(0, n_classes, (n_samples, image_size, image_size))

train_dl = DataLoader(list(zip(X,y)), batch_size=32, shuffle=True, device="cuda:0")
dls = DataLoaders(train_dl, train_dl)

model = resnet34

learn = unet_learner(
    dls, model, loss_func=nn.CrossEntropyLoss(),
    normalize=False, n_out=n_classes, n_in=n_channels
)

n_epochs = 5
learn.fit_one_cycle(n_epochs)

# Alternative version, because "reasons"
learn2 = unet_learner(
    dls, model, loss_func=nn.CrossEntropyLoss(),
    normalize=False, n_out=n_classes, n_in=n_channels
)
newmodel = SimplifiedModel(learn.model)
newmodel.to("cuda:0")

loss = learn2.loss_func
opt = learn2.opt  # note: only the forward pass is exercised below, so the optimizer is unused

for n in range(n_epochs):
    print("Epoch", n)
    for inp, target in dls[0]:
        inp = inp.to("cuda")
        target = target.to("cuda")
        pred = newmodel(inp)  # inp is already a tensor; no need to re-wrap it
        error = loss(pred, target)
    print(error)  # Loss differs slightly because the logic is not identical to fastai's, but it works

Edit: I submitted this before I finished writing.

As far as I know, there is no easy way to do model parallelism in fastai across multiple nodes (not multiple GPUs in one node, but multiple nodes with one GPU each). With this PR, it would be possible to use the PyTorch RPC module to split the model using SequentialEx.
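The key point for the multi-node case is that the original input has to travel along with the activations, so every downstream slice can pass it as `x_orig`. A stdlib-only sketch of that pattern, where `stage1`, `stage2`, and `run_on_node` are illustrative stand-ins (in torch.distributed.rpc terms, `run_on_node` would roughly correspond to `rpc.rpc_sync` against a remote worker):

```python
def stage1(x):                        # lives on node 0
    return [v * 2 for v in x]

def stage2(activation, x_orig):       # lives on node 1; needs the *original* input
    # e.g. a ResizeToOrig-like op: trim back to the original input's length
    return activation[:len(x_orig)]

def run_on_node(node_id, fn, *args):  # stand-in for a remote procedure call
    print(f"running {fn.__name__} on node {node_id}")
    return fn(*args)

x = [1, 2, 3]
act = run_on_node(0, stage1, x)
# Crucially, x is shipped to node 1 alongside the activations:
out = run_on_node(1, stage2, act, x)
print(out)  # [2, 4, 6]
```

This is only a conceptual sketch; the real version would serialize tensors across processes, but the data flow is the same one the `x_orig` parameter makes possible.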

@Patataman Patataman requested a review from jph00 as a code owner August 20, 2024 11:06

@Patataman Patataman closed this Aug 20, 2024
@Patataman
Author

I need to fix some things that I didn't think about before.

@Patataman
Author
Patataman commented Aug 22, 2024

OK, I found that (at least for this example) splitting the first 8 layers across different machines does not work, because of the hooks triggered on the first 4 layers, which store values later used by 4 UnetBlocks.

However, putting all of those layers on the same node "fixes" this problem. It is not a real solution to the hook issue, but at least it works, and you can still split the model among different nodes.

I'm trying to figure out how to address the hook situation, but for now this change in SequentialEx would still be necessary.
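The reason those layers have to stay together can be illustrated without torch: the values captured by a forward hook live in the memory of whichever process ran the hooked layer, so a UnetBlock running on another node cannot see them. In this hypothetical sketch, `NODE_LOCAL_STORE` stands in for a hook's stored activation, and `encoder_layer`/`unet_block` are illustrative, not fastai functions:

```python
NODE_LOCAL_STORE = {}                  # per-process memory, not shared across nodes

def encoder_layer(x):
    NODE_LOCAL_STORE["skip"] = x       # what a forward hook would capture
    return [v + 1 for v in x]

def unet_block(x, store):
    skip = store.get("skip")           # fails if the hook ran on a different node
    if skip is None:
        raise RuntimeError("skip connection not available on this node")
    return [a + b for a, b in zip(x, skip)]

# Same node: the hook's stored value is visible, so the block works.
act = encoder_layer([1, 2, 3])
print(unet_block(act, NODE_LOCAL_STORE))   # [3, 5, 7]

# Different node: simulate a fresh process with its own (empty) store.
try:
    unet_block(act, {})
except RuntimeError as e:
    print(e)
```

Keeping the hooked encoder layers and the UnetBlocks that consume them on the same node sidesteps this, which matches the workaround described above.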

@Patataman Patataman reopened this Aug 22, 2024