[RLlib] new API stack crashes when using Repeated space #52093
Open
@Phirefly9

Description


What happened + What you expected to happen

I'm attempting to upgrade our code framework that uses RLlib. We currently depend on the Repeated space object your team has implemented, and attempting to use it on 2.44.1 results in a crash.

As a side item: is there any documentation in RLlib for supporting custom Gymnasium spaces, or even just the full set of "core" Gymnasium spaces? I don't think there is any way for me to work around this error currently. I would love the capability to implement custom spaces.
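
For concreteness, here is a minimal sketch of the kind of custom space I would want to implement (a hypothetical VariableList modeled on your Repeated; the name and implementation are mine, not RLlib's):

import gymnasium as gym


class VariableList(gym.Space):
    """A variable-length list of child-space samples, at most `max_len` long.

    Hypothetical custom space, sketched only to illustrate the request above.
    """

    def __init__(self, child_space: gym.Space, max_len: int):
        super().__init__(shape=None, dtype=None)
        self.child_space = child_space
        self.max_len = max_len

    def sample(self, mask=None):
        # Draw a random length in [1, max_len], then sample that many children.
        n = int(self.np_random.integers(1, self.max_len + 1))
        return [self.child_space.sample() for _ in range(n)]

    def contains(self, x):
        return (
            isinstance(x, list)
            and len(x) <= self.max_len
            and all(self.child_space.contains(c) for c in x)
        )

As far as I can tell, without some hook for custom spaces in the new stack, even a space like this (with a correct contains) would hit the same episode observation_space check.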

Error message:

(SingleAgentEnvRunner pid=452302)     return func(self, *args, **kwargs)
(SingleAgentEnvRunner pid=452302)   File "/home/clong/.cache/pypoetry/virtualenvs/corl-QTKLdzdH-py3.10/lib/python3.10/site-packages/ray/rllib/execution/rollout_ops.py", line 110, in <lambda>
(SingleAgentEnvRunner pid=452302)     else (lambda w: (w.sample(**random_action_kwargs), w.get_metrics()))
(SingleAgentEnvRunner pid=452302)   File "/home/clong/.cache/pypoetry/virtualenvs/corl-QTKLdzdH-py3.10/lib/python3.10/site-packages/ray/util/tracing/tracing_helper.py", line 463, in _resume_span
(SingleAgentEnvRunner pid=452302)     return method(self, *_args, **_kwargs)
(SingleAgentEnvRunner pid=452302)   File "/home/clong/.cache/pypoetry/virtualenvs/corl-QTKLdzdH-py3.10/lib/python3.10/site-packages/ray/rllib/env/single_agent_env_runner.py", line 206, in sample
(SingleAgentEnvRunner pid=452302)     samples = self._sample(
(SingleAgentEnvRunner pid=452302)   File "/home/clong/.cache/pypoetry/virtualenvs/corl-QTKLdzdH-py3.10/lib/python3.10/site-packages/ray/util/tracing/tracing_helper.py", line 463, in _resume_span
(SingleAgentEnvRunner pid=452302)     return method(self, *_args, **_kwargs)
(SingleAgentEnvRunner pid=452302)   File "/home/clong/.cache/pypoetry/virtualenvs/corl-QTKLdzdH-py3.10/lib/python3.10/site-packages/ray/rllib/env/single_agent_env_runner.py", line 261, in _sample
(SingleAgentEnvRunner pid=452302)     self._reset_envs(episodes, shared_data, explore)
(SingleAgentEnvRunner pid=452302)   File "/home/clong/.cache/pypoetry/virtualenvs/corl-QTKLdzdH-py3.10/lib/python3.10/site-packages/ray/util/tracing/tracing_helper.py", line 463, in _resume_span
(SingleAgentEnvRunner pid=452302)     return method(self, *_args, **_kwargs)
(SingleAgentEnvRunner pid=452302)   File "/home/clong/.cache/pypoetry/virtualenvs/corl-QTKLdzdH-py3.10/lib/python3.10/site-packages/ray/rllib/env/single_agent_env_runner.py", line 729, in _reset_envs
(SingleAgentEnvRunner pid=452302)     episodes[env_index].add_env_reset(
(SingleAgentEnvRunner pid=452302)   File "/home/clong/.cache/pypoetry/virtualenvs/corl-QTKLdzdH-py3.10/lib/python3.10/site-packages/ray/rllib/env/single_agent_episode.py", line 376, in add_env_reset
(SingleAgentEnvRunner pid=452302)     assert self.observation_space.contains(observation), (
(SingleAgentEnvRunner pid=452302) AssertionError: `observation` {'angular-pos': {'some_random_stuff': 1, 'value': 0.037983317}, 'test': ([-0.41928768, 1.0369213],), 'velocs': (array([0.04648657], dtype=float32), 0.0046693585), 'x-pos': array([-0.01215742], dtype=float32)} does NOT fit SingleAgentEpisode's observation_space: Dict('angular-pos': Dict('some_random_stuff': Discrete(3), 'value': Box(-0.41887903, 0.41887903, (), float32)), 'test': Repeated(Box(-inf, inf, (1,), float32), 2), 'velocs': Tuple(Box(-inf, inf, (1,), float32), Box(-inf, inf, (), float32)), 'x-pos': Box(-4.8, 4.8, (1,), float32))!
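
If I read the assertion correctly, the "test" entry reaches the episode as a plain tuple of floats rather than the list of (1,)-shaped arrays that Repeated.sample() produces. Here is a minimal sketch isolating that check (where exactly the new stack performs the conversion is an assumption on my part):

import gymnasium as gym
import numpy as np

from ray.rllib.utils.spaces.repeated import Repeated

space = Repeated(gym.spaces.Box(-np.inf, np.inf, (1,), dtype=np.float32), max_len=2)

valid = space.sample()          # a list of (1,)-shaped float32 arrays
print(space.contains(valid))    # expected: True

# The value the episode actually received for "test", copied from the
# AssertionError above: a tuple wrapping plain Python floats.
mangled = ([-0.41928768, 1.0369213],)
print(space.contains(mangled))  # expected: False, which triggers the assertion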

Versions / Dependencies

ray 2.44.1
torch 2.6
numpy 1.26.1

Reproduction script

(Grabbed from one of your unit tests; run with --enable-new-api-stack and you will get the failure. The old API stack works fine.)

from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)
from ray.tune.registry import get_trainable_cls, register_env

parser = add_rllib_example_script_args(
    default_iters=200,
    default_timesteps=100000,
    default_reward=600.0,
)
# TODO (sven): This arg is currently ignored (hard-set to 2).
# parser.add_argument("--num-policies", type=int, default=2)

import gymnasium as gym
from gymnasium.envs.classic_control.cartpole import CartPoleEnv
import numpy as np
from gymnasium.spaces.sequence import Sequence
from ray.rllib.utils.spaces.repeated import Repeated


class CartPoleWithDictObservationSpace(CartPoleEnv):
    """CartPole gym environment that has a dict observation space.

    Otherwise, the information content in each observation remains the same.

    https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/classic_control/cartpole.py  # noqa

    The new observation space looks as follows (a little quirky, but this is
    for testing purposes only):

    gym.spaces.Dict({
        "x-pos": [x-pos],
        "angular-pos": gym.spaces.Dict({
            "value": [angular-pos],
            "some_random_stuff": Discrete(3),
        }),
        "velocs": gym.spaces.Tuple([x-veloc, angular-veloc]),
        "test": Repeated([x-veloc], max_len=2),
    })
    """

    def __init__(self, config=None):
        super().__init__()

        # Fix our observation-space as described above.
        low = self.observation_space.low
        high = self.observation_space.high

        # Test as many quirks and oddities as possible: Dict, Dict inside a Dict,
        # Tuple inside a Dict, and both (1,)-shapes as well as ()-shapes for Boxes.
        # Also add a random discrete variable and a Repeated space.
        self.observation_space = gym.spaces.Dict(
            {
                "x-pos": gym.spaces.Box(low[0], high[0], (1,), dtype=np.float32),
                "angular-pos": gym.spaces.Dict(
                    {
                        "value": gym.spaces.Box(low[2], high[2], (), dtype=np.float32),
                        # Add some random non-essential information.
                        "some_random_stuff": gym.spaces.Discrete(3),
                    }
                ),
                "velocs": gym.spaces.Tuple(
                    [
                        # x-veloc
                        gym.spaces.Box(low[1], high[1], (1,), dtype=np.float32),
                        # angular-veloc
                        gym.spaces.Box(low[3], high[3], (), dtype=np.float32),
                    ]
                ),
                "test": Repeated(gym.spaces.Box(low[1], high[1], (1,), dtype=np.float32), max_len=2),
                # "test2": Sequence(gym.spaces.Discrete(5))
            }
        )

    def step(self, action):
        next_obs, reward, done, truncated, info = super().step(action)
        return self._compile_current_obs(next_obs), reward, done, truncated, info

    def reset(self, *, seed=None, options=None):
        init_obs, init_info = super().reset(seed=seed, options=options)
        return self._compile_current_obs(init_obs), init_info

    def _compile_current_obs(self, original_cartpole_obs):
        # original_cartpole_obs is [x-pos, x-veloc, angle, angle-veloc]
        return {
            "x-pos": np.array([original_cartpole_obs[0]], np.float32),
            "angular-pos": {
                "value": np.array(original_cartpole_obs[2]),
                "some_random_stuff": np.random.randint(3),
            },
            "velocs": (
                np.array([original_cartpole_obs[1]], np.float32),
                np.array(original_cartpole_obs[3], np.float32),
            ),
            "test": self.observation_space["test"].sample(),
            # "test2": self.observation_space["test2"].sample()
        }


register_env("env", lambda _: CartPoleWithDictObservationSpace(config=None))

from ray.rllib.connectors.env_to_module import FlattenObservations


def _env_to_module_pipeline(env):
    return FlattenObservations(multi_agent=False)

if __name__ == "__main__":
    args = parser.parse_args()

    base_config = (
        get_trainable_cls(args.algo)
        .get_default_config()
        .environment("env")
        .env_runners(
            num_env_runners=10,
            env_to_module_connector=_env_to_module_pipeline,
            num_envs_per_env_runner=1
        )
        # .training(
        #     train_batch_size_per_learner=40000,
        #     minibatch_size=4000,
        # )
    )

    run_rllib_example_script_experiment(base_config, args)
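
To reproduce, save the script (the filename here is just a placeholder) and run it with the new stack flag:

python repro.py --enable-new-api-stack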

Issue Severity

High: It blocks me from completing my task.

Labels

P1 (Issue that should be fixed within a few weeks), bug (Something that is supposed to be working; but isn't), community-backlog, rllib (RLlib related issues)
