pr #1629 breaks autograd when entropy_coeff != 0 · Issue #1970 · volcengine/verl · GitHub
Open
@linxxx3

Description


#1629 introduced a "logits_processor" that can break autograd.

Refer to the following code in verl/workers/actor/megatron_actor.py L409:

            def logits_processor(logits, label, label_mask):
                assert logits.shape[:2] == label.shape[:2]
                assert label.shape == label_mask.shape

                ret = {}

                if calculate_entropy:
                    entropy = vocab_parallel_entropy(logits)
                    ret["entropy"] = entropy

                log_probs = vocab_parallel_log_probs_from_logits(logits, label)
                log_probs = log_probs.masked_fill(~label_mask, 0.0)
                ret["log_probs"] = log_probs
                return ret

If entropy_coeff != 0, then inside logits_processor, vocab_parallel_entropy runs its forward pass and saves logits for backward. vocab_parallel_log_probs_from_logits then modifies logits in-place, as #1672 noted. This invalidates the saved tensor and breaks autograd during vocab_parallel_entropy's backward pass.

The error log looks like:

  File ".../verl/utils/megatron/tensor_parallel.py", line 134, in backward
    vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors
                                                                      ^^^^^^^^^^^^^^^^^
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 16228, 75968]], which is output 0 of ToCopyBackward0, is at version 5; expected version 0 instead.
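This failure mode can be reproduced in a few lines of plain PyTorch, independent of verl or Megatron. The names below (`SavesLogits`, `run`) are hypothetical stand-ins for verl's kernels, not the real code: `SavesLogits` plays the role of vocab_parallel_entropy (it saves its input for backward), and the in-place `sub_` plays the role of vocab_parallel_log_probs_from_logits mutating the logits. Cloning the tensor before the in-place write avoids the error:

```python
import torch

# Minimal sketch of the failure mode, independent of verl/Megatron.
# SavesLogits stands in for vocab_parallel_entropy: it saves its input
# tensor for the backward pass, so any later in-place write to that
# tensor bumps its version counter and breaks backward.
class SavesLogits(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits):
        ctx.save_for_backward(logits)   # saved at the current version
        return torch.logsumexp(logits, dim=-1).sum()

    @staticmethod
    def backward(ctx, grad_out):
        (logits,) = ctx.saved_tensors   # raises if the version changed
        return grad_out * torch.softmax(logits, dim=-1)

def run(clone_before_inplace: bool) -> None:
    leaf = torch.randn(2, 5, requires_grad=True)
    logits = leaf * 2.0                 # non-leaf tensor, like model output
    loss = SavesLogits.apply(logits)
    # Stand-in for the log-prob kernel mutating its input in-place:
    target = logits.clone() if clone_before_inplace else logits
    target.sub_(1.0)                    # in-place write after forward
    loss.backward()

try:
    run(clone_before_inplace=False)
except RuntimeError as e:
    # "...has been modified by an inplace operation..."
    print("without clone:", e)

run(clone_before_inplace=True)          # clone keeps the saved tensor intact
print("with clone: backward succeeded")
```

In logits_processor, the analogous workaround would be to pass logits.clone() to vocab_parallel_log_probs_from_logits when calculate_entropy is set (at the cost of an extra vocab-sized buffer), or to make the log-prob kernel operate out-of-place.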
