feat: optimize get logprobs when cp enabled. #528
base: main
Conversation
Force-pushed from 34f829d to a068b07
Signed-off-by: Jonas yang <joyang@nvidia.com>
logits = DTensor.from_local(
    local_logits,
    device_mesh=self.device_mesh["cp", "tp"],
    placements=[Shard(sequence_dim), Shard(-1)],
)
Is there any benefit to doing the redistribute like this vs. logits.redistribute(device_mesh=..., placements=...)?
Also, can this be set to async_op=True?
Accepted. :) Just want to unify full tensor/dtensor format.
assert isinstance(target, DTensor), (
    "target must be a DTensor if seq_index is provided"
)
cp_mesh = target.device_mesh
@SahilJain314 to comment on CP making its appearance in the model agnostic utilities
What does this PR do?
This PR optimizes getting logprobs when CP is enabled for FSDP2. Issue #549
Issues
In the previous PR, the logits were gathered from the sharded local tensor (shape [b, s / cp_size, v / tp_size]) into a full tensor of shape [b, s, v] before being passed to the loss function. The key reason was that we had to ensure the sequence order was correct when computing log probs.
This PR instead passes the permuted sequence to the loss function together with an additional full tensor
seq_index
which records the order of the permuted sequence and allows parallel logprob computation even when TP is also enabled.
Test Result
cp8
tp4cp2
The average step time is reduced by more than 3 seconds.
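The seq_index idea above can be sketched with plain tensors (no CP ranks involved). All names and shapes here are illustrative assumptions, not taken from the PR:

```python
import torch

# Under CP the sequence reaches the loss function permuted; a full tensor
# seq_index records which original position each permuted slot came from,
# so per-token results can be written back in order without first
# gathering the full [b, s, v] logits.
b, s = 2, 8
seq_index = torch.randperm(s)            # stand-in for the CP permutation
logprobs_in_order = torch.randn(b, s)    # per-token logprobs, original order
logprobs_permuted = logprobs_in_order[:, seq_index]  # what the CP ranks see

# Scatter back into original sequence order using seq_index.
restored = torch.empty_like(logprobs_permuted)
restored[:, seq_index] = logprobs_permuted
assert torch.equal(restored, logprobs_in_order)
```

Because the scatter is per-position, each rank's slice of the permuted sequence can be processed independently, which is what enables the parallel logprob computation even with TP sharding on the vocab dimension.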