[RLlib] Bug in TorchMultiDistribution logp prevents policy mapping from being used · Issue #53994 · ray-project/ray · GitHub
Open

Description

@Phirefly9

What happened + What you expected to happen

Good Afternoon,

I tried to train a competitive environment with PPO where both agents were policy-mapped to the same `rl_module`, and ran into an issue where `TorchMultiDistribution` fails to compute the
`batch[Columns.ACTION_LOGP] = action_dist.logp(actions)` call in the `get_actions` connector whenever the batch size is greater than 1.

I've provided a fixed implementation below (I trained PPO with it for several hours with no issues).

Versions / Dependencies

ray 2.47.1

Reproduction script

from ray.rllib.models.torch.torch_distributions import TorchMultiDistribution, TorchCategorical, TorchMultiCategorical
from ray.rllib.utils.annotations import override
import torch
import tree

class FixedTorchMultiDistribution(TorchMultiDistribution):
    # Copy-paste of TorchMultiDistribution.logp, except that the map_
    # function and the return value are changed so that batched inputs
    # (batch size > 1) are handled correctly.

    @override(TorchMultiDistribution)
    def logp(self, value):
        # Different places in RLlib use this method with different inputs.
        # We therefore need to handle a flattened and concatenated input, as well as
        # a nested one.
        # TODO(Artur): Deprecate tensor inputs, only allow nested structures.
        if isinstance(value, torch.Tensor):
            split_indices = []
            for dist in self._flat_child_distributions:
                if isinstance(dist, TorchCategorical):
                    split_indices.append(1)
                elif isinstance(dist, TorchMultiCategorical):
                    split_indices.append(len(dist._cats))
                else:
                    sample = dist.sample()
                    # Cover Box(shape=()) case.
                    if len(sample.shape) == 1:
                        split_indices.append(1)
                    else:
                        split_indices.append(sample.size()[1])
            split_value = list(torch.split(value, split_indices, dim=1))
        else:
            split_value = tree.flatten(value)

        def map_(val, dist):
            return dist.logp(val)

        flat_logps = tree.map_structure(
            map_, split_value, self._flat_child_distributions
        )

        return torch.sum(torch.stack(flat_logps, dim=1), dim=1)
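
To make the split-then-sum behavior of `logp` concrete without pulling in Ray or Torch, here is a minimal pure-Python sketch of the same idea. `TinyCategorical` is a hypothetical stand-in for a child distribution such as `TorchCategorical`, not an RLlib class:

```python
import math

# Hypothetical stand-in for an RLlib child distribution (e.g. TorchCategorical):
# holds per-class probabilities and exposes logp() over a batch of actions.
class TinyCategorical:
    def __init__(self, probs):
        self.probs = probs

    def logp(self, actions):
        # One log-probability per batch row.
        return [math.log(self.probs[a]) for a in actions]

# Two child distributions and a batch of 3 composite actions.
children = [TinyCategorical([0.5, 0.5]), TinyCategorical([0.25, 0.75])]
batch_actions = [[0, 1], [1, 0], [0, 0]]  # row = (action for child 0, child 1)

# "Split" the batched composite actions column-wise (one column per child),
# query each child's logp, then sum per batch row -- mirroring the
# torch.split / tree.map_structure / torch.stack / torch.sum chain above.
per_child = [dist.logp([row[i] for row in batch_actions])
             for i, dist in enumerate(children)]
total_logp = [sum(col) for col in zip(*per_child)]
print(total_logp)
```

The key point the fix addresses is that each child's `logp` must be evaluated on its own column of the batched action tensor, and the per-child results must be combined row-wise so every batch element gets one total log-probability.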

Issue Severity

Low: It annoys or frustrates me.

Metadata

Assignees

No one assigned

Labels

bug (Something that is supposed to be working; but isn't), rllib (RLlib related issues), stability, triage (Needs triage (eg: priority, bug/not-bug, and owning component))
