feat: optimize get logprobs when cp enabled. by joyang-nv · Pull Request #528 · NVIDIA-NeMo/RL · GitHub

feat: optimize get logprobs when cp enabled. #528


Draft: joyang-nv wants to merge 3 commits into main from joyang/cp_opt

Conversation

@joyang-nv (Contributor) commented on Jun 18, 2025

What does this PR do?

This PR optimizes logprob computation when context parallelism (CP) is enabled for FSDP2. Issue #549

Issues

In the previous PR, the logits were gathered from the sharded DTensor (local tensor shape [b, s / cp_size, v / tp_size]) into a full tensor of shape [b, s, v] and passed to the loss function.
The key reason was that we had to ensure the sequence order was correct when computing log probs.
This PR instead passes the permuted sequence to the loss function along with an additional full tensor, seq_index, which records the order of the permuted sequence. This allows logprobs to be computed in parallel, even when TP is enabled alongside CP.
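A minimal sketch of the idea, assuming the permuted sequence dimension is split contiguously across CP ranks (the function name and exact shard layout here are illustrative assumptions, not the PR's actual code):

```python
import torch

def local_logprobs_from_permuted_shard(
    local_logits: torch.Tensor,  # [b, s_local, v]: this CP rank's permuted logits shard
    targets: torch.Tensor,       # [b, s]: full, unpermuted target token ids
    seq_index: torch.Tensor,     # [s]: permuted position j holds original position seq_index[j]
    cp_rank: int,
    cp_size: int,
) -> torch.Tensor:
    s_local = local_logits.shape[1]
    assert seq_index.numel() == cp_size * s_local
    # Slice out the part of the permutation owned by this CP rank.
    local_index = seq_index[cp_rank * s_local : (cp_rank + 1) * s_local]
    # Reorder the targets to line up with the permuted logits shard.
    local_targets = targets[:, local_index]  # [b, s_local]
    # Token logprobs on the local shard only; the full [b, s, v] logits
    # tensor is never materialized.
    logprobs = torch.log_softmax(local_logits.float(), dim=-1)
    return logprobs.gather(-1, local_targets.unsqueeze(-1)).squeeze(-1)
```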

Test Results

cp8

[convergence and time-cost plots]

tp4cp2

[convergence and time-cost plots]

| Run | timing/train/policy_training (s) |
| --- | --- |
| TP4CP2-0624 (this PR) | 62.53703141 |
| CP8-0624 (this PR) | 52.19302438 |
| LLAMA-3.1-8B-INSTRUCT-CP8-0610 (baseline) | 55.94293963 |
| LLAMA-3.1-8B-INSTRUCT-TP4CP2-0609 (baseline) | 66.29522389 |

On average, this saves 3+ seconds per training step.

@github-actions bot added the documentation (Improvements or additions to documentation) label on Jun 18, 2025
@github-actions bot removed the documentation (Improvements or additions to documentation) label on Jun 18, 2025
@joyang-nv force-pushed the joyang/cp_opt branch 3 times, most recently from 34f829d to a068b07, on June 25, 2025 07:14
@joyang-nv changed the title from "Optimize get logprobs when cp enabled." to "feat: optimize get logprobs when cp enabled." on Jun 25, 2025
@joyang-nv added the CI:L1 (Run doctests, unit tests, and functional tests) label on Jun 25, 2025
@joyang-nv requested review from gshennvm, SahilJain314, and abukharin-nv and removed the request for SahilJain314 on June 25, 2025 08:23
Signed-off-by: Jonas yang <joyang@nvidia.com>
Signed-off-by: Jonas yang <joyang@nvidia.com>
Signed-off-by: Jonas yang <joyang@nvidia.com>
@joyang-nv added and removed the CI:L1 (Run doctests, unit tests, and functional tests) label on Jun 25, 2025
@joyang-nv requested a review from terrykong on June 25, 2025 16:06
```python
logits = DTensor.from_local(
    local_logits,
    device_mesh=self.device_mesh["cp", "tp"],
    placements=[Shard(sequence_dim), Shard(-1)],
)
```
Contributor:
Is there any benefit to doing the redistribute like this vs. logits.redistribute(device_mesh=..., placements=...)?

Also, can this be set to async_op=True?

Contributor Author:
Accepted. :) Just wanted to unify the full tensor/DTensor format.
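For context, a minimal sketch of the two approaches being discussed (mesh sizes and shapes are illustrative assumptions, and the import paths assume a recent PyTorch; run under torchrun with 8 processes):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, Replicate

# Illustrative 2D mesh: 2-way CP x 4-way TP.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("cp", "tp"))
b, s, v = 2, 1024, 32000
local_logits = torch.randn(b, s // 2, v // 4, device="cuda")  # this rank's shard

# Option 1 (as in the diff): wrap the already-sharded local tensor.
# from_local only attaches sharding metadata; it does not communicate.
logits = DTensor.from_local(
    local_logits,
    device_mesh=mesh,
    placements=[Shard(1), Shard(-1)],  # seq dim over CP, vocab dim over TP
)

# Option 2 (the reviewer's question): redistribute an existing DTensor.
# redistribute issues collectives to change placements, and accepts
# async_op=True to overlap communication with compute.
replicated = logits.redistribute(placements=[Replicate(), Replicate()], async_op=True)
```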

```python
assert isinstance(target, DTensor), (
    "target must be a DTensor if seq_index is provided"
)
cp_mesh = target.device_mesh
```
Contributor:
@SahilJain314 to comment on CP making its appearance in the model agnostic utilities
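As one hedged illustration of what such a model-agnostic utility might do (only target, seq_index, and cp_mesh appear in the diff; the helper name and logic here are assumptions):

```python
import torch
from torch.distributed.tensor import DTensor

def unpermute_logprobs(local_logprobs: torch.Tensor,
                       target: DTensor,
                       seq_index: torch.Tensor) -> torch.Tensor:
    # Reuse the CP mesh and placements of the sharded target to wrap
    # the permuted local logprobs as a DTensor.
    cp_mesh = target.device_mesh
    lp = DTensor.from_local(local_logprobs, device_mesh=cp_mesh,
                            placements=target.placements)
    # Gather the full permuted sequence onto every rank.
    full_permuted = lp.full_tensor()  # [b, s]
    # Undo the permutation: permuted position j holds original position
    # seq_index[j], so scatter each column back to where it came from.
    out = torch.empty_like(full_permuted)
    out[:, seq_index] = full_permuted
    return out
```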

Labels
CI:L1 Run doctests, unit tests, and functional tests