Gymnax - Classic Gym Environments in JAX


Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? gymnax brings the power of jit and vmap to classic OpenAI gym environments.

Basic gymnax API Usage 🍲

  • Classic OpenAI gym-style API including gymnax.make, env.reset, and env.step:
import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_policy, key_step = jax.random.split(rng, 4)

env, env_params = gymnax.make("Pendulum-v1")

obs, state = env.reset(key_reset, env_params)
action = env.action_space(env_params).sample(key_policy)
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
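
Because env.reset and env.step are pure functions of their inputs, the transition itself can also be jit-compiled directly. A minimal sketch, continuing from the snippet above:

# Compile the step transition once; subsequent calls run the accelerated version
jit_step = jax.jit(env.step)
n_obs, n_state, reward, done, _ = jit_step(key_step, state, action, env_params)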

Episode Rollouts, Vectorization & Acceleration

  • Easy composition of JAX primitives (e.g. jit, vmap, pmap):
import jax.numpy as jnp  # needed for the scan inputs and reward masking

def rollout(rng_input, policy_params, env_params, num_env_steps):
    """Rollout a jitted gymnax episode with lax.scan (uses `env` from gymnax.make above)."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, state = env.reset(rng_reset, env_params)

    def policy_step(state_input, tmp):
        """lax.scan-compatible step transition in the jax env."""
        obs, state, policy_params, rng = state_input
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        # `network` is a policy network (e.g. a Flax module) defined elsewhere
        action = network.apply({"params": policy_params}, obs, rng=rng_net)
        next_o, next_s, reward, done, _ = env.step(
            rng_step, state, action, env_params
        )
        carry = [next_o.squeeze(), next_s, policy_params, rng]
        return carry, [reward, done]

    # Scan over the episode step loop (dummy inputs fix the number of steps)
    _, scan_out = jax.lax.scan(
        policy_step,
        [obs, state, policy_params, rng_episode],
        [jnp.zeros((num_env_steps, 2))],
    )
    # Return masked sum of rewards accumulated by the agent within the episode
    rewards, dones = scan_out[0], scan_out[1]
    rewards = rewards.reshape(num_env_steps, 1)
    ep_mask = (jnp.cumsum(dones) < 1).reshape(num_env_steps, 1)
    return jnp.sum(rewards * ep_mask)

# Jit-compiled episode rollout (num_env_steps is static, so the scan length is fixed at trace time)
jit_rollout = jax.jit(rollout, static_argnums=3)

# Vmap across random keys for batch rollouts (policy and env params are shared across the batch)
batch_rollout = jax.vmap(jit_rollout, in_axes=(0, None, None, None))
  • Vectorization over different environment parametrizations:
env.step(key_step, state, action, env_params)
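
Since env_params is just another argument, batching over environment parametrizations uses the same vmap mechanics as batching over random keys. The snippet below is a minimal sketch: batch_keys and stacked_params are illustrative names, policy_params is assumed to be your policy's parameter pytree, and stacked_params assumes the parameter pytree has been stacked along a leading batch axis, which may differ from how you construct parameters in practice.

# Evaluate the batched rollout from above across a batch of random keys
batch_keys = jax.random.split(jax.random.PRNGKey(42), 32)
returns = batch_rollout(batch_keys, policy_params, env_params, 200)

# Sketch: vmap a single step transition over a batch of parameter settings.
# `stacked_params` is a hypothetical pytree with leaves stacked along axis 0.
batched_step = jax.vmap(env.step, in_axes=(None, None, None, 0))
n_obs, n_state, reward, done, _ = batched_step(key_step, state, action, stacked_params)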

Implemented Accelerated Environments 🌍

Classic Control OpenAI gym environments.

| Environment Name | Implemented | Tested | Single Step Speed Gain (JAX vs. NumPy) |
| --- | --- | --- | --- |
| Pendulum-v0 | ✔️ | ✔️ | |
| CartPole-v0 | ✔️ | ✔️ | |
| MountainCar-v0 | ✔️ | ✔️ | |
| MountainCarContinuous-v0 | ✔️ | ✔️ | |
| Acrobot-v1 | ✔️ | ✔️ | |

DeepMind's BSuite environments.

| Environment Name | Implemented | Tested | Single Step Speed Gain (JAX vs. NumPy) |
| --- | --- | --- | --- |
| Catch-bsuite | ✔️ | ✔️ | |
| DeepSea-bsuite | ✔️ | ✔️ | |
| MemoryChain-bsuite | ✔️ | ✔️ | |
| UmbrellaChain-bsuite | ✔️ | ✔️ | |
| DiscountingChain-bsuite | ✔️ | ✔️ | |
| MNISTBandit-bsuite | ✔️ | ✔️ | |
| SimpleBandit-bsuite | ✔️ | ✔️ | |

K. Young's and T. Tian's MinAtar environments.

| Environment Name | Implemented | Tested | Single Step Speed Gain (JAX vs. NumPy) |
| --- | --- | --- | --- |
| Asterix-MinAtar | ✔️ | ✔️ | |
| Breakout-MinAtar | ✔️ | ✔️ | |
| Freeway-MinAtar | ✔️ | ✔️ | |
| Seaquest-MinAtar | ❌ | ❌ | |
| SpaceInvaders-MinAtar | ✔️ | ✔️ | |

Miscellaneous Environments.

| Environment Name | Implemented | Tested | Single Step Speed Gain (JAX vs. NumPy) |
| --- | --- | --- | --- |
| BernoulliBandit-misc | ✔️ | ✔️ | |
| GaussianBandit-misc | ✔️ | ✔️ | |
| FourRooms-misc | ✔️ | ✔️ | |

Installation 📝

gymnax can be directly installed from PyPI.

pip install gymnax

Alternatively, you can clone this repository and 'manually' install gymnax:

git clone https://github.com/RobertTLange/gymnax.git
cd gymnax
pip install -e .

Benchmarking Details 🚋

Examples 🎒

  • 📓 Environment API - Check out the API and accelerated control environments.
  • 📓 Anakin Agent - Check out DeepMind's Anakin agent with gymnax's Catch-bsuite environment.
  • 📓 CMA-ES - CMA-ES in JAX with vectorized population evaluation.

Acknowledgements & Citing gymnax ✏️

To cite this repository:

@software{gymnax2021github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.1},
  year = {2021},
}

Much of the design of gymnax has been inspired by the classic OpenAI gym RL environment API and DeepMind's JAX eco-system. I am grateful to the JAX team and Matteo Hessel for their support and motivating words. Finally, a big thank you goes out to the TRC team at Google for granting me TPU quota for benchmarking gymnax.

Notes, Development & Questions ❓

  • If you find a bug or want a new feature, feel free to contact me @RobertTLange or create an issue 🤗
  • You can check out the history of release modifications in CHANGELOG.md (added, changed, fixed).
  • You can find a set of open milestones in CONTRIBUTING.md.
Design Notes (control flow, random numbers, episode termination).
  1. Each step transition requires you to pass a set of environment parameters, env.step(rng, state, action, env_params), which specify the 'hyperparameters' of the environment. You can vectorize over different parameter settings as shown in the rollout section above.
  2. gymnax automatically resets an episode after termination, so that fixed-length trajectory rollouts keep generating transitions.
  3. If you want to calculate evaluation returns, simply mask the reward sum using the binary done (discount) vector; see the sketch below.
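
A minimal sketch of the masking in point 3, assuming rewards and dones are the stacked per-step outputs of a fixed-length rollout (as collected by lax.scan above):

import jax.numpy as jnp

# Steps from the first recorded termination onward are masked out, so later
# transitions (generated by the automatic resets) do not count towards the return.
ep_mask = jnp.cumsum(dones) < 1
episode_return = jnp.sum(rewards * ep_mask)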
