Closed
Description
What happened + What you expected to happen
I opened an issue with PettingZoo, which I believe is the main cause of the problem, but it may also be related to RLlib. I'm posting it here in case someone who has worked on the `ParallelPettingZooEnv` class can help diagnose the problem: Farama-Foundation/PettingZoo#889
Versions / Dependencies
gym==0.23.1
Gymnasium==0.26.3
numpy==1.23.5
PettingZoo==1.22.3
Pillow==9.4.0
pygame==2.1.2
ray==2.3.0
SuperSuit==3.7.1
tianshou==0.4.11
torch==1.13.1
Reproduction script
Basic working example without Ray. It shows that the environment works on its own; as far as I know, the SuperSuit (`ss`) preprocessing steps are not the problem:
import supersuit as ss
from pettingzoo.butterfly import pistonball_v6
def env_creator(args):
    """Create the preprocessed Pistonball parallel environment with on-screen rendering.

    ``args`` is accepted so the signature matches the RLlib variant; it is not read.
    """
    pistonball_kwargs = dict(
        n_pistons=20,
        time_penalty=-0.1,
        continuous=True,
        random_drop=True,
        random_rotate=True,
        ball_mass=0.75,
        ball_friction=0.3,
        ball_elasticity=1.5,
        max_cycles=125,
        render_mode="human",
    )
    wrapped = pistonball_v6.parallel_env(**pistonball_kwargs)

    # SuperSuit preprocessing chain; each wrapper wraps the previous one.
    wrapped = ss.color_reduction_v0(wrapped, mode="B")  # presumably keeps only the blue channel
    wrapped = ss.dtype_v0(wrapped, "float32")
    wrapped = ss.resize_v1(wrapped, x_size=84, y_size=84)
    wrapped = ss.normalize_obs_v0(wrapped, env_min=0, env_max=1)
    wrapped = ss.frame_stack_v1(wrapped, 3)  # stack 3 consecutive frames
    return wrapped
if __name__ == "__main__":
    # Sanity check without Ray: play one episode with uniformly random actions.
    env = env_creator({})
    try:
        env.reset()
        # The loop relies on env.agents becoming empty once the episode is over.
        while env.agents:
            actions = {agent: env.action_space(agent).sample() for agent in env.agents}
            observations, rewards, terminations, truncations, infos = env.step(actions)
    finally:
        # Always release the environment, including the render_mode="human" window.
        env.close()
RLlib training script, `rllib_pistonball.py`:
"""Uses Ray's RLLib to train agents to play Pistonball.
Author: Rohan (https://github.com/Rohan138)
"""
import os
import ray
import supersuit as ss
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.tune.registry import register_env
from torch import nn
from pettingzoo.butterfly import pistonball_v6
# raise NotImplementedError(
# "There are currently bugs in this tutorial, we will fix them soon."
# )
class CNNModelV2(TorchModelV2, nn.Module):
    """Convolutional RLlib model with separate policy and value heads.

    Observations are expected channels-last, ``(batch, H, W, C)``; ``forward``
    permutes them to the channels-first layout that ``nn.Conv2d`` expects.
    """

    def __init__(self, obs_space, act_space, num_outputs, *args, **kwargs):
        TorchModelV2.__init__(self, obs_space, act_space, num_outputs, *args, **kwargs)
        nn.Module.__init__(self)

        # NOTE(review): assumes obs_space.shape == (H, W, C), e.g. (84, 84, 3) for the
        # 84x84 resize + 3-frame stack in env_creator -- confirm for other preprocessing.
        height, width, in_channels = map(int, obs_space.shape)

        # (out_channels, kernel_size, stride) for each conv layer.
        conv_specs = ((32, 8, 4), (64, 4, 2), (64, 3, 1))
        layers = []
        for out_channels, kernel, stride in conv_specs:
            layers.append(
                nn.Conv2d(in_channels, out_channels, [kernel, kernel], stride=(stride, stride))
            )
            layers.append(nn.ReLU())
            in_channels = out_channels
            # Spatial size after an unpadded conv: floor((size - kernel) / stride) + 1.
            height = (height - kernel) // stride + 1
            width = (width - kernel) // stride + 1

        # Derived instead of hard-coded; equals 64 * 7 * 7 == 3136 for 84x84 input.
        flat_size = in_channels * height * width
        # Same layer order as before, so state_dict keys (model.0 ... model.8) are unchanged.
        self.model = nn.Sequential(
            *layers,
            nn.Flatten(),
            nn.Linear(flat_size, 512),
            nn.ReLU(),
        )
        self.policy_fn = nn.Linear(512, num_outputs)
        self.value_fn = nn.Linear(512, 1)
        # Filled by forward(); value_function() returns the most recent value.
        self._value_out = None

    def forward(self, input_dict, state, seq_lens):
        """Return policy logits and pass ``state`` through unchanged."""
        # (batch, H, W, C) -> (batch, C, H, W) for nn.Conv2d.
        model_out = self.model(input_dict["obs"].permute(0, 3, 1, 2))
        self._value_out = self.value_fn(model_out)
        return self.policy_fn(model_out), state

    def value_function(self):
        """Return the value estimates from the last ``forward`` call as a 1-D tensor."""
        assert self._value_out is not None, "must call forward() first"
        return self._value_out.flatten()
def env_creator(args):
    """Create the preprocessed Pistonball parallel environment for RLlib training.

    ``args`` receives the env config passed to the registered env lambda; it is not read.
    """
    pistonball_kwargs = dict(
        n_pistons=20,
        time_penalty=-0.1,
        continuous=True,
        random_drop=True,
        random_rotate=True,
        ball_mass=0.75,
        ball_friction=0.3,
        ball_elasticity=1.5,
        max_cycles=125,
    )
    wrapped = pistonball_v6.parallel_env(**pistonball_kwargs)

    # SuperSuit preprocessing chain; each wrapper wraps the previous one.
    wrapped = ss.color_reduction_v0(wrapped, mode="B")  # presumably keeps only the blue channel
    wrapped = ss.dtype_v0(wrapped, "float32")
    wrapped = ss.resize_v1(wrapped, x_size=84, y_size=84)
    wrapped = ss.normalize_obs_v0(wrapped, env_min=0, env_max=1)
    wrapped = ss.frame_stack_v1(wrapped, 3)  # stack 3 consecutive frames
    return wrapped
if __name__ == "__main__":
    # NOTE(review): local_mode=True is a Ray debugging aid; confirm it is intended
    # together with num_rollout_workers=4 below.
    ray.init(local_mode=True)

    env_name = "pistonball_v6"
    register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator(config)))
    ModelCatalog.register_custom_model("CNNModelV2", CNNModelV2)

    config = (
        PPOConfig()
        .rollouts(num_rollout_workers=4, rollout_fragment_length=128)
        .training(
            train_batch_size=512,
            lr=2e-5,
            gamma=0.99,
            lambda_=0.9,
            use_gae=True,
            clip_param=0.4,
            grad_clip=None,
            entropy_coeff=0.1,
            vf_loss_coeff=0.25,
            sgd_minibatch_size=64,
            num_sgd_iter=10,
            # Use the model registered above; otherwise the registration is dead
            # code and RLlib builds its default model instead of CNNModelV2.
            model={"custom_model": "CNNModelV2"},
        )
        .environment(env=env_name, clip_actions=True)
        .debugging(log_level="ERROR")
        .framework(framework="torch")
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
    )

    tune.run(
        "PPO",
        name="PPO",
        stop={"timesteps_total": 5_000_000},
        checkpoint_freq=10,
        local_dir=os.path.join("~/ray_results", env_name),
        config=config.to_dict(),
    )
Issue Severity
High: It blocks me from completing my task.