Gymnasium migration by Rohan138 · Pull Request #177 · facebookresearch/mbrl-lib · GitHub
This repository was archived by the owner on Sep 1, 2024. It is now read-only.

Gymnasium migration #177

Merged
40 commits merged on Mar 29, 2023
Changes from all commits
Commits
40 commits
281c00a
wip
Rohan138 Jan 29, 2023
0077789
wip
Rohan138 Jan 30, 2023
00eb286
wip
Rohan138 Jan 30, 2023
ecd1597
wip
Rohan138 Jan 30, 2023
2734ef6
wip
Rohan138 Jan 30, 2023
3a0811e
wip
Rohan138 Jan 30, 2023
ec638f5
drop python 3.11 for now
Rohan138 Feb 1, 2023
8665ded
wip
Rohan138 Feb 1, 2023
61e17f1
wip
Rohan138 Feb 1, 2023
1a491b2
wip
Rohan138 Feb 1, 2023
9c68439
Merge branch 'main' into gymnasium
Rohan138 Feb 9, 2023
206dfd7
wip
Rohan138 Feb 9, 2023
46fa18a
wip
Rohan138 Feb 10, 2023
f1c8244
wip
Rohan138 Feb 13, 2023
a2a9d1d
wip
Rohan138 Feb 13, 2023
61d6c60
wip
Rohan138 Feb 13, 2023
09caf18
wip
Rohan138 Feb 13, 2023
669af65
wip - passes tests/algorithms
raghavauppuluri13 Feb 14, 2023
dca254f
wip
Rohan138 Feb 17, 2023
0d0d203
wip
Rohan138 Feb 17, 2023
88a8040
wip
Rohan138 Feb 17, 2023
3bfa309
wip
Rohan138 Feb 17, 2023
dec83fa
wip
Rohan138 Feb 19, 2023
edaa68f
Merge branch 'facebookresearch:main' into gymnasium
Rohan138 Feb 19, 2023
83a817e
wip
Rohan138 Feb 19, 2023
611e383
Merge branch 'gymnasium' of https://github.com/Rohan138/mbrl-lib into…
Rohan138 Feb 19, 2023
d777e5a
wip
Rohan138 Feb 19, 2023
6f01689
wip
Rohan138 Feb 19, 2023
4082783
wip
Rohan138 Feb 19, 2023
35f85f8
Fixed errors in notebooks after Gymnasium migration.
luisenp Feb 20, 2023
9e182d4
wip
Rohan138 Feb 20, 2023
e28cafc
wip
Rohan138 Feb 20, 2023
9afe539
wip
Rohan138 Feb 20, 2023
b63c1ed
Merge pull request #2 from facebookresearch/lep.fix_pets_notebook
Rohan138 Feb 20, 2023
7859750
wip
Rohan138 Feb 20, 2023
d18dde3
wip
Rohan138 Feb 21, 2023
ac9abf1
updated cfgs to support new wrappers
raghavauppuluri13 Feb 21, 2023
845ba05
wip
Rohan138 Mar 3, 2023
ac58d46
wip
Rohan138 Mar 7, 2023
aa76f04
Update changelog
Rohan138 Mar 27, 2023
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -34,7 +34,7 @@ jobs:
flake8 mbrl --ignore=E203,W503 --per-file-ignores='mbrl/env/mujoco_envs.py:F401 */__init__.py:F401 tests/*:F401' --max-line-length=100
- name: Lint with mypy
run: |
mypy mbrl --no-strict-optional --ignore-missing-imports
mypy mbrl --no-strict-optional --ignore-missing-imports --follow-imports=skip
- name: Check format with black
run: |
black --check mbrl
5 changes: 2 additions & 3 deletions .pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/psf/black
rev: 22.6.0
rev: 23.1.0
hooks:
- id: black
files: 'mbrl'
@@ -17,8 +17,7 @@ repos:
- id: mypy
files: 'mbrl'
additional_dependencies: [numpy, torch, tokenize-rt==3.2.0, types-PyYAML, types-termcolor]
args: [--no-strict-optional, --ignore-missing-imports]
exclude: setup.py
args: [--no-strict-optional, --ignore-missing-imports, --follow-imports=skip]

- repo: https://github.com/pycqa/isort
rev: 5.12.0
19 changes: 18 additions & 1 deletion CHANGELOG.md
@@ -1,6 +1,23 @@
# Changelog

## main (v0.2.0.dev4)
## main (v0.2.0)
### Breaking changes
- Migrated from [gym](https://github.com/openai/gym) to [Gymnasium](https://github.com/Farama-Foundation/Gymnasium/)
- `gym==0.26.3` is still required for the dm_control and pybullet-gym environments
- `Transition` and `TransitionBatch` now support the `terminated` and `truncated` booleans
  instead of the single `done` boolean previously used by gym
- Migrated calls to `env.reset()`, which now returns a tuple of `obs, info` instead of just `obs`
- Migrated calls to `env.step()`, which now returns `observation, reward, terminated, truncated, info` instead of the old four-value tuple (see the sketch after this list)
- Migrated to the Gymnasium render API; environments are instantiated with `render_mode=None` by default
- DMC and PyBullet envs use the original gym wrappers to turn them into gym environments, which are then wrapped by `gymnasium.envs.GymV20Environment`
- All Mujoco envs use the DeepMind Mujoco [bindings](https://github.com/deepmind/mujoco); [mujoco-py](https://github.com/openai/mujoco-py) is deprecated as a dependency
- Custom Mujoco envs, e.g. `AntTruncatedObsEnv`, inherit from `gymnasium.envs.mujoco_env.MujocoEnv` and access data through `self.data` instead of `self.sim.data`
- Mujoco environment versions have been updated from `v2` to `v4`, e.g. `Hopper-v4`
- [Fixed](https://github.com/facebookresearch/mbrl-lib/blob/ac58d46f585cc90c064b8c989e7ddf64f9e330ce/mbrl/algorithms/planet.py#L147) PlaNet to save model to a directory instead of a file name
- Added `--follow-imports=skip` to the `mypy` CI check to allow for gymnasium/gym wrapper compatibility
- Bumped `black` to version `23.1.0` in CI
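
For reference, a minimal sketch of the new interaction loop under the Gymnasium API (illustrative only; the environment id and loop structure are assumptions, not code taken from this PR):

```python
import gymnasium as gym

env = gym.make("Hopper-v4")        # Mujoco envs are now versioned v4
obs, info = env.reset(seed=0)      # reset() returns (obs, info) and accepts a seed
terminated = truncated = False
while not terminated and not truncated:
    action = env.action_space.sample()
    # step() now returns five values; terminated/truncated replace the old done flag
    obs, reward, terminated, truncated, info = env.step(action)
env.close()
```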

## v0.2.0.dev4
### Main new features
- Added [PlaNet](http://proceedings.mlr.press/v97/hafner19a/hafner19a.pdf) implementation.
- Added support for [PyBullet](https://pybullet.org/wordpress/) environments.
9 changes: 5 additions & 4 deletions README.md
@@ -64,7 +64,7 @@ environment specific configurations for each environment, overriding the
default configurations with the best hyperparameter values we have found so far
for each combination of algorithm and environment. You can run training
by passing the desired override option via command line.
For example, to run MBPO on the gym version of HalfCheetah, you should call
For example, to run MBPO on the [Gymnasium](https://github.com/Farama-Foundation/Gymnasium/) version of HalfCheetah, you should call
```python
python -m mbrl.examples.main algorithm=mbpo overrides=mbpo_halfcheetah
```
@@ -86,16 +86,17 @@ all the available options, take a look at the provided
## Supported environments
Our example configurations are largely based on [Mujoco](https://mujoco.org/), but
our library components (and algorithms) are compatible with any environment that follows
the standard gym syntax. You can try our utilities in other environments
the standard [Gymnasium](https://github.com/Farama-Foundation/Gymnasium/) syntax. You can try our utilities in other environments
by creating your own entry script and Hydra configuration, using our default entry
[`main.py`](https://github.com/facebookresearch/mbrl-lib/blob/main/mbrl/examples/main.py) as guiding template.
See also the example [override](https://github.com/facebookresearch/mbrl-lib/tree/main/mbrl/examples/conf/overrides)
configurations.

Without any modifications, our provided `main.py` can be used to launch experiments with the following environments:
* [`mujoco-py`](https://github.com/openai/mujoco-py) (up to version 2.0)
* [`mujoco`](https://github.com/deepmind/mujoco)
* [`dm_control`](https://github.com/deepmind/dm_control)
* [`pybullet-gym`](https://github.com/benelot/pybullet-gym) (thanks to [dtch1997](https://github.com/dtch1997) for the contribution!)
Note: You must run `pip install gym==0.26.3` to use the dm_control and pybulletgym environments.

You can test your Mujoco and PyBullet installations by running

@@ -106,7 +107,7 @@ To specify the environment to use for `main.py`, there are two possibilities:

* **Preferred way**: Use a Hydra dictionary to specify arguments for your env constructor. See [example](https://github.com/facebookresearch/mbrl-lib/blob/main/mbrl/examples/conf/overrides/planet_cartpole_balance.yaml#L4).
* Less flexible alternative: A single string with the following syntax (a usage sketch follows this list):
- `mujoco-gym`: `"gym___<env-name>"`, where `env-name` is the name of the environment in gym (e.g., "HalfCheetah-v2").
- `mujoco-gym`: `"gym___<env-name>"`, where `env-name` is the name of the environment in Gymnasium (e.g., "HalfCheetah-v2").
- `dm_control`: `"dmcontrol___<domain>--<task>"`, where domain/task are defined as in DMControl (e.g., "cheetah--run").
- `pybullet-gym`: `"pybulletgym___<env-name>"`, where `env-name` is the name of the environment in pybullet gym (e.g., "HopperPyBulletEnv-v0")
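
As an illustration, here is a minimal sketch of the string-based syntax using the handler utilities that also appear in this PR (`mbrl.util.create_handler_from_str` and `make_env_from_str`); the environment id below is an assumption:

```python
import mbrl.util

env_str = "gym___HalfCheetah-v4"  # assumed id; any Gymnasium env name can be substituted
handler = mbrl.util.create_handler_from_str(env_str)
env = handler.make_env_from_str(env_str)

obs, info = env.reset(seed=0)  # Gymnasium-style reset returns (obs, info)
```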

2 changes: 1 addition & 1 deletion mbrl/__init__.py
@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
__version__ = "0.2.0.dev4"
__version__ = "0.2.0"
42 changes: 31 additions & 11 deletions mbrl/algorithms/mbpo.py
@@ -5,7 +5,7 @@
import os
from typing import Optional, Sequence, cast

import gym
import gymnasium as gym
import hydra.utils
import numpy as np
import omegaconf
@@ -50,12 +50,14 @@ def rollout_model_and_populate_sac_buffer(
pred_next_obs, pred_rewards, pred_dones, model_state = model_env.step(
action, model_state, sample=True
)
truncateds = np.zeros_like(pred_dones, dtype=bool)
sac_buffer.add_batch(
obs[~accum_dones],
action[~accum_dones],
pred_next_obs[~accum_dones],
pred_rewards[~accum_dones, 0],
pred_dones[~accum_dones, 0],
truncateds[~accum_dones, 0],
)
obs = pred_next_obs
accum_dones |= pred_dones.squeeze()
@@ -69,13 +71,14 @@ def evaluate(
) -> float:
avg_episode_reward = 0.0
for episode in range(num_episodes):
obs = env.reset()
obs, _ = env.reset()
video_recorder.init(enabled=(episode == 0))
done = False
terminated = False
truncated = False
episode_reward = 0.0
while not done:
while not terminated and not truncated:
action = agent.act(obs)
obs, reward, done, _ = env.step(action)
obs, reward, terminated, truncated, _ = env.step(action)
video_recorder.record(env)
episode_reward += reward
avg_episode_reward += episode_reward
@@ -97,8 +100,15 @@ def maybe_replace_sac_buffer(
new_buffer = mbrl.util.ReplayBuffer(new_capacity, obs_shape, act_shape, rng=rng)
if sac_buffer is None:
return new_buffer
obs, action, next_obs, reward, done = sac_buffer.get_all().astuple()
new_buffer.add_batch(obs, action, next_obs, reward, done)
(
obs,
action,
next_obs,
reward,
terminated,
truncated,
) = sac_buffer.get_all().astuple()
new_buffer.add_batch(obs, action, next_obs, reward, terminated, truncated)
return new_buffer
return sac_buffer

@@ -194,13 +204,23 @@ def train(
sac_buffer = maybe_replace_sac_buffer(
sac_buffer, obs_shape, act_shape, sac_buffer_capacity, cfg.seed
)
obs, done = None, False
obs = None
terminated = False
truncated = False
for steps_epoch in range(cfg.overrides.epoch_length):
if steps_epoch == 0 or done:
if steps_epoch == 0 or terminated or truncated:
steps_epoch = 0
obs, done = env.reset(), False
obs, _ = env.reset()
terminated = False
truncated = False
# --- Doing env step and adding to model dataset ---
next_obs, reward, done, _ = mbrl.util.common.step_env_and_add_to_buffer(
(
next_obs,
reward,
terminated,
truncated,
_,
) = mbrl.util.common.step_env_and_add_to_buffer(
env, obs, agent, {}, replay_buffer
)

17 changes: 12 additions & 5 deletions mbrl/algorithms/pets.py
@@ -5,7 +5,7 @@
import os
from typing import Optional

import gym
import gymnasium as gym
import numpy as np
import omegaconf
import torch
@@ -95,12 +95,13 @@ def train(
current_trial = 0
max_total_reward = -np.inf
while env_steps < cfg.overrides.num_steps:
obs = env.reset()
obs, _ = env.reset()
agent.reset()
done = False
terminated = False
truncated = False
total_reward = 0.0
steps_trial = 0
while not done:
while not terminated and not truncated:
# --------------- Model Training -----------------
if env_steps % cfg.algorithm.freq_train_model == 0:
mbrl.util.common.train_model_and_save_model_and_data(
@@ -112,7 +113,13 @@
)

# --- Doing env step using the agent and adding to model dataset ---
next_obs, reward, done, _ = mbrl.util.common.step_env_and_add_to_buffer(
(
next_obs,
reward,
terminated,
truncated,
_,
) = mbrl.util.common.step_env_and_add_to_buffer(
env, obs, agent, {}, replay_buffer
)

28 changes: 18 additions & 10 deletions mbrl/algorithms/planet.py
@@ -6,11 +6,12 @@
import pathlib
from typing import List, Optional, Union

import gym
import gymnasium as gym
import hydra
import numpy as np
import omegaconf
import torch
from tqdm import tqdm

import mbrl.constants
from mbrl.env.termination_fns import no_termination
@@ -130,7 +131,7 @@ def is_test_episode(episode_):
# PlaNet loop
step = replay_buffer.num_stored
total_rewards = 0.0
for episode in range(cfg.algorithm.num_episodes):
for episode in tqdm(range(cfg.algorithm.num_episodes)):
# Train the model for one epoch of `num_grad_updates`
dataset, _ = get_sequence_buffer_iterator(
replay_buffer,
@@ -143,19 +144,22 @@
trainer.train(
dataset, num_epochs=1, batch_callback=batch_callback, evaluate=False
)
planet.save(work_dir / "planet.pth")
replay_buffer.save(work_dir)
planet.save(work_dir)
if cfg.overrides.get("save_replay_buffer", False):
replay_buffer.save(work_dir)
metrics = get_metrics_and_clear_metric_containers()
logger.log_data("metrics", metrics)

# Collect one episode of data
episode_reward = 0.0
obs = env.reset()
obs, _ = env.reset()
agent.reset()
planet.reset_posterior()
action = None
done = False
while not done:
terminated = False
truncated = False
pbar = tqdm(total=1000)
while not terminated and not truncated:
planet.update_posterior(obs, action=action, rng=rng)
action_noise = (
0
@@ -164,14 +168,18 @@
* np_rng.standard_normal(env.action_space.shape[0])
)
action = agent.act(obs) + action_noise
action = np.clip(action, -1.0, 1.0) # to account for the noise
next_obs, reward, done, info = env.step(action)
replay_buffer.add(obs, action, next_obs, reward, done)
action = np.clip(
action, -1.0, 1.0, dtype=env.action_space.dtype
) # to account for the noise and fix dtype
next_obs, reward, terminated, truncated, _ = env.step(action)
replay_buffer.add(obs, action, next_obs, reward, terminated, truncated)
episode_reward += reward
obs = next_obs
if debug_mode:
print(f"step: {step}, reward: {reward}.")
step += 1
pbar.update(1)
pbar.close()
total_rewards += episode_reward
logger.log_data(
mbrl.constants.RESULTS_LOG_NAME,
12 changes: 5 additions & 7 deletions mbrl/diagnostics/control_env.py
@@ -8,7 +8,7 @@
import time
from typing import Sequence, Tuple, cast

import gym.wrappers
import gymnasium as gym
import numpy as np
import omegaconf
import skvideo.io
@@ -27,8 +27,7 @@ def init(env_name: str, seed: int):
global handler__
handler__ = mbrl.util.create_handler_from_str(env_name)
env__ = handler__.make_env_from_str(env_name)
env__.seed(seed)
env__.reset()
env__.reset(seed=seed)


def step_env(action: np.ndarray):
@@ -85,10 +84,9 @@ def get_random_trajectory(horizon):
mp.set_start_method("spawn")
handler = mbrl.util.create_handler_from_str(args.env)
eval_env = handler.make_env_from_str(args.env)
eval_env.seed(args.seed)
torch.random.manual_seed(args.seed)
np.random.seed(args.seed)
current_obs = eval_env.reset()
current_obs, _ = eval_env.reset(seed=args.seed)

if args.optimizer_type == "cem":
optimizer_cfg = omegaconf.OmegaConf.create(
@@ -158,7 +156,7 @@ def get_random_trajectory(horizon):
values_sizes = [] # for icem
for t in range(args.num_steps):
if args.render:
frames.append(eval_env.render(mode="rgb_array"))
frames.append(eval_env.render())
start = time.time()

current_state__ = handler.get_current_state(
@@ -183,7 +181,7 @@ def compute_population_stats(_population, values, opt_step):
trajectory_eval_fn, callback=compute_population_stats
)
action__ = plan[0]
next_obs__, reward__, done__, _ = eval_env.step(action__)
next_obs__, reward__, terminated__, _, _ = eval_env.step(action__)

total_reward__ += reward__
