Description
What happened + What you expected to happen
Good afternoon,

I tried to train a competitive environment in which both agents were policy-mapped to the same RLModule using PPO, and ran into an issue where TorchMultiDistribution cannot correctly compute

batch[Columns.ACTION_LOGP] = action_dist.logp(actions)

in the GetActions connector when the batch size is greater than 1.

I've provided a fixed implementation below (I trained PPO with it for several hours with no issues).
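For context, here is a minimal, standalone torch sketch (independent of RLlib, and not necessarily the exact failure path inside the connector) of the kind of shape problem that an elementwise sum of per-component logp tensors can hide once the batch size is larger than 1; it motivates the stack-and-sum aggregation in the fix below:

import torch

B = 4  # any batch size > 1

# Hypothetical per-component logp outputs of a multi-action distribution:
# one component yields shape (B,), another yields shape (B, 1).
logp_a = torch.zeros(B)     # shape (4,)
logp_b = torch.zeros(B, 1)  # shape (4, 1)

# An elementwise sum silently broadcasts to a (B, B) matrix instead of a
# per-sample total logp of shape (B,).
print((logp_a + logp_b).shape)  # torch.Size([4, 4])

# Stacking instead requires every component to be (B,) and gives the
# expected per-sample total of shape (B,).
total = torch.sum(torch.stack([logp_a, logp_b.squeeze(-1)], dim=1), dim=1)
print(total.shape)  # torch.Size([4])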
Versions / Dependencies
ray 2.47.1
Reproduction script
from ray.rllib.models.torch.torch_distributions import TorchMultiDistribution, TorchCategorical, TorchMultiCategorical
from ray.rllib.utils.annotations import override
import torch
import tree
class FixedTorchMultiDistribution(TorchMultiDistribution):
    # Copy-paste of TorchMultiDistribution.logp, but with the map_ function
    # and the final reduction changed.
    @override(TorchMultiDistribution)
    def logp(self, value):
        # Different places in RLlib use this method with different inputs.
        # We therefore need to handle a flattened and concatenated input, as
        # well as a nested one.
        # TODO(Artur): Deprecate tensor inputs, only allow nested structures.
        if isinstance(value, torch.Tensor):
            split_indices = []
            for dist in self._flat_child_distributions:
                if isinstance(dist, TorchCategorical):
                    split_indices.append(1)
                elif isinstance(dist, TorchMultiCategorical):
                    split_indices.append(len(dist._cats))
                else:
                    sample = dist.sample()
                    # Cover Box(shape=()) case.
                    if len(sample.shape) == 1:
                        split_indices.append(1)
                    else:
                        split_indices.append(sample.size()[1])
            split_value = list(torch.split(value, split_indices, dim=1))
        else:
            split_value = tree.flatten(value)

        def map_(val, dist):
            # Let each child distribution compute logp on its own action piece.
            return dist.logp(val)

        flat_logps = tree.map_structure(
            map_, split_value, self._flat_child_distributions
        )
        # Stack the per-component logps and sum over them, yielding one total
        # logp per batch item (shape (B,)).
        return torch.sum(torch.stack(flat_logps, dim=1), dim=1)
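One possible way to apply the fix without modifying RLlib itself (just a sketch, and only one of several options; pointing a custom RLModule or catalog at the fixed class would also work) is to monkey-patch the original method before any RLModules are built:

# Sketch: replace the original method with the fixed one. Note that this has
# to take effect in every Ray worker process as well, e.g. by placing it in a
# module that your custom environment or RLModule imports.
TorchMultiDistribution.logp = FixedTorchMultiDistribution.logp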
Issue Severity
Low: It annoys or frustrates me.