From ed13d82e2a1c4f734fabe41b173411467e737606 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Bonnet?= <56230714+clement-bonnet@users.noreply.github.com> Date: Thu, 15 Jun 2023 14:33:38 +0200 Subject: [PATCH 1/6] fix(examples): port notebook to colab (#169) --- examples/load_checkpoints.ipynb | 173 +++++--- examples/training.ipynb | 387 ++++++++++++++++-- jumanji/training/configs/env/cleaner.yaml | 14 +- jumanji/training/configs/env/game_2048.yaml | 4 +- jumanji/training/configs/env/job_shop.yaml | 4 +- jumanji/training/configs/env/maze.yaml | 4 +- jumanji/training/configs/env/rubiks_cube.yaml | 4 +- jumanji/training/configs/env/snake.yaml | 4 +- jumanji/training/configs/env/sudoku.yaml | 4 +- jumanji/training/configs/env/tetris.yaml | 4 +- 10 files changed, 501 insertions(+), 101 deletions(-) diff --git a/examples/load_checkpoints.ipynb b/examples/load_checkpoints.ipynb index 410883e5c..1f3c82226 100644 --- a/examples/load_checkpoints.ipynb +++ b/examples/load_checkpoints.ipynb @@ -1,5 +1,16 @@ { "cells": [ + { + "cell_type": "markdown", + "source": [ + "\n", + " \"Open\n", + "" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "code", "execution_count": 1, @@ -12,26 +23,59 @@ }, "scrolled": true, "ExecuteTime": { - "end_time": "2023-06-09T12:54:46.203755781Z", - "start_time": "2023-06-09T12:54:34.946097225Z" + "end_time": "2023-06-14T10:11:06.832854981Z", + "start_time": "2023-06-14T10:10:51.403505913Z" } }, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install --quiet -U \"jumanji[train] @ git+https://github.com/instadeepai/jumanji.git@main\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Note: you may need to restart the kernel to use updated packages.\n" + "Only CPU accelerator is connected.\n" ] } ], "source": [ - "%pip install --quiet -U pip ../.[train]" - ] + "# @title Set up JAX for available hardware (run me) { display-mode: \"form\" 
}\n", + "\n", + "import subprocess\n", + "import os\n", + "\n", + "# Based on https://stackoverflow.com/questions/67504079/how-to-check-if-an-nvidia-gpu-is-available-on-my-system\n", + "try:\n", + " subprocess.check_output('nvidia-smi')\n", + " print(\"a GPU is connected.\")\n", + "except Exception:\n", + " # TPU or CPU\n", + " if \"COLAB_TPU_ADDR\" in os.environ and os.environ[\"COLAB_TPU_ADDR\"]:\n", + " import jax.tools.colab_tpu\n", + "\n", + " jax.tools.colab_tpu.setup_tpu()\n", + " print(\"A TPU is connected.\")\n", + " else:\n", + " print(\"Only CPU accelerator is connected.\")\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-06-14T10:11:06.844131189Z", + "start_time": "2023-06-14T10:11:06.837796509Z" + } + } }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": { "jupyter": { "outputs_hidden": false @@ -40,8 +84,8 @@ "is_executing": true }, "ExecuteTime": { - "end_time": "2023-06-09T12:54:47.789167404Z", - "start_time": "2023-06-09T12:54:46.209657499Z" + "end_time": "2023-06-14T10:11:08.370733527Z", + "start_time": "2023-06-14T10:11:06.842722444Z" } }, "outputs": [ @@ -54,7 +98,6 @@ } ], "source": [ - "import os\n", "import pickle\n", "\n", "import jax\n", @@ -76,26 +119,62 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 4, + "outputs": [], + "source": [ + "env = \"bin_pack\" # @param ['bin_pack', 'cleaner', 'connector', 'cvrp', 'game_2048', 'graph_coloring', 'job_shop', 'knapsack', 'maze', 'minesweeper', 'mmst', 'multi_cvrp', 'robot_warehouse', 'rubiks_cube', 'snake', 'sudoku', 'tetris', 'tsp']\n", + "agent = \"a2c\" # @param ['random', 'a2c']" + ], "metadata": { + "collapsed": false, "ExecuteTime": { - "end_time": "2023-06-09T14:15:36.323183536Z", - "start_time": "2023-06-09T14:15:35.716424914Z" + "end_time": "2023-06-14T10:11:08.373857448Z", + "start_time": "2023-06-14T10:11:08.371354210Z" } - }, + } + }, + { + "cell_type": "code", + "execution_count": 5, 
"outputs": [], "source": [ - "env = \"bin_pack\" # @param ['bin_pack', 'cleaner', 'connector', 'cvrp', 'game_2048', 'graph_coloring', 'job_shop', 'knapsack', 'maze', 'minesweeper', 'mmst', 'multi_cvrp', 'robot_warehouse', 'rubiks_cube', 'snake', 'sudoku', 'tetris', 'tsp']\n", - "agent = \"a2c\" # @param ['random', 'a2c']" - ] + "#@title Download Jumanji Configs (run me) { display-mode: \"form\" }\n", + "\n", + "import os\n", + "import requests\n", + "\n", + "def download_file(url: str, file_path: str) -> None:\n", + " # Send an HTTP GET request to the URL\n", + " response = requests.get(url)\n", + " # Check if the request was successful (status code 200)\n", + " if response.status_code == 200:\n", + " with open(file_path, \"wb\") as f:\n", + " f.write(response.content)\n", + " else:\n", + " print(\"Failed to download the file.\")\n", + "\n", + "os.makedirs(\"configs\", exist_ok=True)\n", + "config_url = \"https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/config.yaml\"\n", + "download_file(config_url, \"configs/config.yaml\")\n", + "env_url = f\"https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/env/{env}.yaml\"\n", + "os.makedirs(\"configs/env\", exist_ok=True)\n", + "download_file(env_url, f\"configs/env/{env}.yaml\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-06-14T10:11:08.479313689Z", + "start_time": "2023-06-14T10:11:08.376210715Z" + } + } }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 6, "metadata": { "ExecuteTime": { - "end_time": "2023-06-09T14:15:36.416627047Z", - "start_time": "2023-06-09T14:15:36.312210869Z" + "end_time": "2023-06-14T10:11:08.701848858Z", + "start_time": "2023-06-14T10:11:08.480541997Z" } }, "outputs": [ @@ -103,13 +182,13 @@ "data": { "text/plain": "{'agent': 'a2c', 'seed': 0, 'logger': {'type': 'terminal', 'save_checkpoint': False, 'name': '${agent}_${env.name}'}, 'env': {'name': 'bin_pack', 
'registered_version': 'BinPack-v2', 'network': {'num_transformer_layers': 2, 'transformer_num_heads': 8, 'transformer_key_size': 16, 'transformer_mlp_units': [512]}, 'training': {'num_epochs': 550, 'num_learner_steps_per_epoch': 100, 'n_steps': 30, 'total_batch_size': 64}, 'evaluation': {'eval_total_batch_size': 10000, 'greedy_eval_total_batch_size': 10000}, 'a2c': {'normalize_advantage': False, 'discount_factor': 1.0, 'bootstrapping_factor': 0.95, 'l_pg': 1.0, 'l_td': 1.0, 'l_en': 0.005, 'learning_rate': 0.0001}}}" }, - "execution_count": 25, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "with initialize(version_base=None, config_path=\"../jumanji/training/configs\"):\n", + "with initialize(version_base=None, config_path=\"configs\"):\n", " cfg = compose(config_name=\"config.yaml\", overrides=[f\"env={env}\", f\"agent={agent}\"])\n", "cfg" ] @@ -123,7 +202,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 7, "outputs": [], "source": [ "# Chose the corresponding checkpoint from the InstaDeep Model Hub\n", @@ -139,18 +218,18 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-06-09T14:15:40.190993641Z", - "start_time": "2023-06-09T14:15:39.866341022Z" + "end_time": "2023-06-14T10:11:10.226606119Z", + "start_time": "2023-06-14T10:11:08.702541986Z" } } }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 8, "metadata": { "ExecuteTime": { - "end_time": "2023-06-09T14:15:44.450615139Z", - "start_time": "2023-06-09T14:15:43.939253885Z" + "end_time": "2023-06-14T10:11:10.646574619Z", + "start_time": "2023-06-14T10:11:10.232194710Z" } }, "outputs": [], @@ -172,37 +251,27 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": { - "tags": [], - "ExecuteTime": { - "end_time": "2023-06-09T14:50:32.328783179Z", - "start_time": "2023-06-09T14:47:35.773302568Z" - } + "tags": [] }, - "outputs": [ - { - "name": "stderr", - 
"output_type": "stream", - "text": [ - "/home/clement/jumanji/venv/lib/python3.8/site-packages/jax/_src/ops/scatter.py:89: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float32 to dtype=int32. In future JAX releases this will result in an error.\n", - " warnings.warn(\"scatter inputs have incompatible types: cannot safely cast \"\n" - ] - } - ], + "outputs": [], "source": [ "NUM_EPISODES = 2\n", "\n", + "reset_fn = jax.jit(env.reset)\n", + "step_fn = jax.jit(env.step)\n", "states = []\n", "key = jax.random.PRNGKey(cfg.seed)\n", "for episode in range(NUM_EPISODES):\n", " key, reset_key = jax.random.split(key) \n", - " state, timestep = jax.jit(env.reset)(reset_key)\n", + " state, timestep = reset_fn(reset_key)\n", + " states.append(state)\n", " while not timestep.last():\n", " key, action_key = jax.random.split(key)\n", " observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)\n", " action, _ = policy(observation, action_key)\n", - " state, timestep = jax.jit(env.step)(state, action.squeeze(axis=0))\n", + " state, timestep = step_fn(state, action.squeeze(axis=0))\n", " states.append(state)\n", " # Freeze the terminal frame to pause the GIF.\n", " for _ in range(3):\n", @@ -218,14 +287,14 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 10, "metadata": { "pycharm": { "is_executing": true }, "ExecuteTime": { - "end_time": "2023-06-09T14:50:36.340139662Z", - "start_time": "2023-06-09T14:50:32.329168660Z" + "end_time": "2023-06-14T10:11:23.572860540Z", + "start_time": "2023-06-14T10:11:19.277668279Z" } }, "outputs": [ @@ -240,17 +309,17 @@ { "data": { "text/plain": "", - "text/html": "
" + "text/html": "
" }, "metadata": {}, "output_type": "display_data" }, { "data": { - "text/plain": "", - "text/html": "\n\n\n\n\n\n
\n \n
\n \n
\n \n \n \n \n \n \n \n \n \n
\n
\n \n \n \n \n \n \n
\n
\n
\n\n\n\n" + "text/plain": "", + "text/html": "\n\n\n\n\n\n
\n \n
\n \n
\n \n \n \n \n \n \n \n \n \n
\n
\n \n \n \n \n \n \n
\n
\n
\n\n\n\n" }, - "execution_count": 32, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } diff --git a/examples/training.ipynb b/examples/training.ipynb index 0d67cd438..b338dc2ab 100644 --- a/examples/training.ipynb +++ b/examples/training.ipynb @@ -1,45 +1,89 @@ { "cells": [ + { + "cell_type": "markdown", + "source": [ + "\n", + " \"Open\n", + "" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 8, "metadata": { "collapsed": true, "jupyter": { "outputs_hidden": true + }, + "ExecuteTime": { + "end_time": "2023-06-14T10:11:33.230999708Z", + "start_time": "2023-06-14T10:11:13.526881698Z" } }, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install --quiet -U \"jumanji[train] @ git+https://github.com/instadeepai/jumanji.git@main\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Note: you may need to restart the kernel to use updated packages.\n" + "Only CPU accelerator is connected.\n" ] } ], "source": [ - "%pip install --quiet -U pip -r ../requirements/requirements-training.txt ../." 
- ] + "# @title Set up JAX for available hardware (run me) { display-mode: \"form\" }\n", + "\n", + "import subprocess\n", + "import os\n", + "\n", + "# Based on https://stackoverflow.com/questions/67504079/how-to-check-if-an-nvidia-gpu-is-available-on-my-system\n", + "try:\n", + " subprocess.check_output('nvidia-smi')\n", + " print(\"a GPU is connected.\")\n", + "except Exception:\n", + " # TPU or CPU\n", + " if \"COLAB_TPU_ADDR\" in os.environ and os.environ[\"COLAB_TPU_ADDR\"]:\n", + " import jax.tools.colab_tpu\n", + "\n", + " jax.tools.colab_tpu.setup_tpu()\n", + " print(\"A TPU is connected.\")\n", + " else:\n", + " print(\"Only CPU accelerator is connected.\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-06-14T10:11:33.245117659Z", + "start_time": "2023-06-14T10:11:33.237735383Z" + } + } }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 10, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false + }, + "ExecuteTime": { + "end_time": "2023-06-14T10:11:33.268137075Z", + "start_time": "2023-06-14T10:11:33.246267189Z" } }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" - ] - } - ], + "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings(\"ignore\")\n", @@ -50,41 +94,328 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 11, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false + }, + "ExecuteTime": { + "end_time": "2023-06-14T10:11:33.279561988Z", + "start_time": "2023-06-14T10:11:33.268947238Z" } }, "outputs": [], "source": [ - "env = \"maze\" # @param ['bin_pack', 'cleaner', 'connector', 'cvrp', 'game_2048', 'job_shop', 'knapsack', 'maze', 'minesweeper', 'rubiks_cube', 'snake', \"sudoku\", 'tsp']\n", + "env = \"maze\" # @param ['bin_pack', 'cleaner', 'connector', 'cvrp', 'game_2048', 'graph_coloring', 'job_shop', 'knapsack', 'maze', 'minesweeper', 'mmst', 'multi_cvrp', 'robot_warehouse', 'rubiks_cube', 'snake', 'sudoku', 'tetris', 'tsp']\n", "agent = \"random\" # @param ['random', 'a2c']" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 12, + "outputs": [], + "source": [ + "#@title Download Jumanji Configs (run me) { display-mode: \"form\" }\n", + "\n", + "import os\n", + "import requests\n", + "\n", + "\n", + "def download_file(url: str, file_path: str) -> None:\n", + " # Send an HTTP GET request to the URL\n", + " response = requests.get(url)\n", + " # Check if the request was successful (status code 200)\n", + " if response.status_code == 200:\n", + " with open(file_path, \"wb\") as f:\n", + " f.write(response.content)\n", + " else:\n", + " print(\"Failed to download the file.\")\n", + "\n", + "\n", + "os.makedirs(\"configs\", exist_ok=True)\n", + "config_url = \"https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/config.yaml\"\n", + "download_file(config_url, \"configs/config.yaml\")\n", + "env_url = f\"https://raw.githubusercontent.com/instadeepai/jumanji/main/jumanji/training/configs/env/{env}.yaml\"\n", + "os.makedirs(\"configs/env\", 
exist_ok=True)\n", + "download_file(env_url, f\"configs/env/{env}.yaml\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-06-14T10:11:33.662474073Z", + "start_time": "2023-06-14T10:11:33.281569701Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 13, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false + }, + "ExecuteTime": { + "end_time": "2023-06-14T10:12:46.061682766Z", + "start_time": "2023-06-14T10:11:33.664132133Z" } }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:agent: random\n", + "seed: 0\n", + "logger:\n", + " type: terminal\n", + " save_checkpoint: true\n", + " name: ${agent}_${env.name}\n", + "env:\n", + " name: maze\n", + " registered_version: Maze-v0\n", + " network:\n", + " num_channels:\n", + " - 32\n", + " - 32\n", + " - 8\n", + " policy_layers:\n", + " - 64\n", + " - 64\n", + " value_layers:\n", + " - 128\n", + " - 128\n", + " training:\n", + " num_epochs: 100\n", + " num_learner_steps_per_epoch: 500\n", + " n_steps: 10\n", + " total_batch_size: 128\n", + " evaluation:\n", + " eval_total_batch_size: 500\n", + " greedy_eval_total_batch_size: 500\n", + " a2c:\n", + " normalize_advantage: false\n", + " discount_factor: 0.99\n", + " bootstrapping_factor: 0.95\n", + " l_pg: 1.0\n", + " l_td: 1.0\n", + " l_en: 0.01\n", + " learning_rate: 0.0003\n", + "\n", + "INFO:root:{'devices': [CpuDevice(id=0)]}\n", + "INFO:root:Experiment: random_maze.\n", + "INFO:root:Starting logger.\n", + "INFO:root:Eval Stochastic >> Env Steps: 0.00e+00 | Episode Length: 80.542 | Episode Return: 0.294 | Time: 2.283\n", + "INFO:root:Train >> Env Steps: 0.00e+00 | Steps Per Second: 311,313 | Time: 2.056\n", + "INFO:root:Eval Stochastic >> Env Steps: 6.40e+05 | Episode Length: 81.860 | Episode Return: 0.294 | Time: 0.039\n", + "INFO:root:Train >> Env Steps: 6.40e+05 | Steps Per Second: 2,491,524 | Time: 0.257\n", + "INFO:root:Eval Stochastic >> Env 
Steps: 1.28e+06 | Episode Length: 83.506 | Episode Return: 0.268 | Time: 0.043\n", + "INFO:root:Train >> Env Steps: 1.28e+06 | Steps Per Second: 2,262,205 | Time: 0.283\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.92e+06 | Episode Length: 80.668 | Episode Return: 0.298 | Time: 0.040\n", + "INFO:root:Train >> Env Steps: 1.92e+06 | Steps Per Second: 2,267,689 | Time: 0.282\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.56e+06 | Episode Length: 81.108 | Episode Return: 0.304 | Time: 0.042\n", + "INFO:root:Train >> Env Steps: 2.56e+06 | Steps Per Second: 2,072,973 | Time: 0.309\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.20e+06 | Episode Length: 78.792 | Episode Return: 0.314 | Time: 0.050\n", + "INFO:root:Train >> Env Steps: 3.20e+06 | Steps Per Second: 2,117,558 | Time: 0.302\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.84e+06 | Episode Length: 82.330 | Episode Return: 0.280 | Time: 0.041\n", + "INFO:root:Train >> Env Steps: 3.84e+06 | Steps Per Second: 2,249,738 | Time: 0.284\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.48e+06 | Episode Length: 80.168 | Episode Return: 0.296 | Time: 0.050\n", + "INFO:root:Train >> Env Steps: 4.48e+06 | Steps Per Second: 2,226,935 | Time: 0.287\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.12e+06 | Episode Length: 79.178 | Episode Return: 0.314 | Time: 0.040\n", + "INFO:root:Train >> Env Steps: 5.12e+06 | Steps Per Second: 2,167,084 | Time: 0.295\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.76e+06 | Episode Length: 82.756 | Episode Return: 0.296 | Time: 0.037\n", + "INFO:root:Train >> Env Steps: 5.76e+06 | Steps Per Second: 2,550,027 | Time: 0.251\n", + "INFO:root:Eval Stochastic >> Env Steps: 6.40e+06 | Episode Length: 80.560 | Episode Return: 0.304 | Time: 0.041\n", + "INFO:root:Train >> Env Steps: 6.40e+06 | Steps Per Second: 2,279,612 | Time: 0.281\n", + "INFO:root:Eval Stochastic >> Env Steps: 7.04e+06 | Episode Length: 80.342 | Episode Return: 0.310 | Time: 0.051\n", + "INFO:root:Train >> Env Steps: 
7.04e+06 | Steps Per Second: 2,461,171 | Time: 0.260\n", + "INFO:root:Eval Stochastic >> Env Steps: 7.68e+06 | Episode Length: 82.324 | Episode Return: 0.276 | Time: 0.037\n", + "INFO:root:Train >> Env Steps: 7.68e+06 | Steps Per Second: 2,453,811 | Time: 0.261\n", + "INFO:root:Eval Stochastic >> Env Steps: 8.32e+06 | Episode Length: 81.474 | Episode Return: 0.278 | Time: 0.038\n", + "INFO:root:Train >> Env Steps: 8.32e+06 | Steps Per Second: 2,455,146 | Time: 0.261\n", + "INFO:root:Eval Stochastic >> Env Steps: 8.96e+06 | Episode Length: 81.528 | Episode Return: 0.290 | Time: 0.037\n", + "INFO:root:Train >> Env Steps: 8.96e+06 | Steps Per Second: 2,481,307 | Time: 0.258\n", + "INFO:root:Eval Stochastic >> Env Steps: 9.60e+06 | Episode Length: 81.986 | Episode Return: 0.276 | Time: 0.038\n", + "INFO:root:Train >> Env Steps: 9.60e+06 | Steps Per Second: 2,356,032 | Time: 0.272\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.02e+07 | Episode Length: 82.106 | Episode Return: 0.290 | Time: 0.040\n", + "INFO:root:Train >> Env Steps: 1.02e+07 | Steps Per Second: 2,326,346 | Time: 0.275\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.09e+07 | Episode Length: 82.614 | Episode Return: 0.292 | Time: 0.036\n", + "INFO:root:Train >> Env Steps: 1.09e+07 | Steps Per Second: 2,436,756 | Time: 0.263\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.15e+07 | Episode Length: 80.540 | Episode Return: 0.298 | Time: 0.039\n", + "INFO:root:Train >> Env Steps: 1.15e+07 | Steps Per Second: 2,323,687 | Time: 0.275\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.22e+07 | Episode Length: 81.710 | Episode Return: 0.270 | Time: 0.039\n", + "INFO:root:Train >> Env Steps: 1.22e+07 | Steps Per Second: 2,309,338 | Time: 0.277\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.28e+07 | Episode Length: 79.808 | Episode Return: 0.320 | Time: 0.038\n", + "INFO:root:Train >> Env Steps: 1.28e+07 | Steps Per Second: 2,353,846 | Time: 0.272\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.34e+07 | Episode 
Length: 81.976 | Episode Return: 0.286 | Time: 0.039\n", + "INFO:root:Train >> Env Steps: 1.34e+07 | Steps Per Second: 2,341,133 | Time: 0.273\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.41e+07 | Episode Length: 82.834 | Episode Return: 0.278 | Time: 0.039\n", + "INFO:root:Train >> Env Steps: 1.41e+07 | Steps Per Second: 2,367,474 | Time: 0.270\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.47e+07 | Episode Length: 81.438 | Episode Return: 0.292 | Time: 0.038\n", + "INFO:root:Train >> Env Steps: 1.47e+07 | Steps Per Second: 2,347,204 | Time: 0.273\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.54e+07 | Episode Length: 79.628 | Episode Return: 0.304 | Time: 0.041\n", + "INFO:root:Train >> Env Steps: 1.54e+07 | Steps Per Second: 2,335,717 | Time: 0.274\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.60e+07 | Episode Length: 82.752 | Episode Return: 0.270 | Time: 0.038\n", + "INFO:root:Train >> Env Steps: 1.60e+07 | Steps Per Second: 2,550,539 | Time: 0.251\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.66e+07 | Episode Length: 83.584 | Episode Return: 0.258 | Time: 0.041\n", + "INFO:root:Train >> Env Steps: 1.66e+07 | Steps Per Second: 2,411,240 | Time: 0.265\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.73e+07 | Episode Length: 81.678 | Episode Return: 0.288 | Time: 0.037\n", + "INFO:root:Train >> Env Steps: 1.73e+07 | Steps Per Second: 2,354,060 | Time: 0.272\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.79e+07 | Episode Length: 80.766 | Episode Return: 0.302 | Time: 0.037\n", + "INFO:root:Train >> Env Steps: 1.79e+07 | Steps Per Second: 2,420,398 | Time: 0.264\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.86e+07 | Episode Length: 82.046 | Episode Return: 0.306 | Time: 0.039\n", + "INFO:root:Train >> Env Steps: 1.86e+07 | Steps Per Second: 2,306,851 | Time: 0.277\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.92e+07 | Episode Length: 80.440 | Episode Return: 0.290 | Time: 0.040\n", + "INFO:root:Train >> Env Steps: 1.92e+07 | Steps Per Second: 
2,302,742 | Time: 0.278\n", + "INFO:root:Eval Stochastic >> Env Steps: 1.98e+07 | Episode Length: 85.246 | Episode Return: 0.246 | Time: 0.043\n", + "INFO:root:Train >> Env Steps: 1.98e+07 | Steps Per Second: 2,310,177 | Time: 0.277\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.05e+07 | Episode Length: 82.386 | Episode Return: 0.288 | Time: 0.040\n", + "INFO:root:Train >> Env Steps: 2.05e+07 | Steps Per Second: 2,502,355 | Time: 0.256\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.11e+07 | Episode Length: 83.376 | Episode Return: 0.256 | Time: 0.038\n", + "INFO:root:Train >> Env Steps: 2.11e+07 | Steps Per Second: 2,496,007 | Time: 0.256\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.18e+07 | Episode Length: 82.706 | Episode Return: 0.280 | Time: 0.038\n", + "INFO:root:Train >> Env Steps: 2.18e+07 | Steps Per Second: 2,227,243 | Time: 0.287\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.24e+07 | Episode Length: 81.022 | Episode Return: 0.294 | Time: 0.042\n", + "INFO:root:Train >> Env Steps: 2.24e+07 | Steps Per Second: 2,234,127 | Time: 0.286\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.30e+07 | Episode Length: 81.408 | Episode Return: 0.282 | Time: 0.043\n", + "INFO:root:Train >> Env Steps: 2.30e+07 | Steps Per Second: 2,291,360 | Time: 0.279\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.37e+07 | Episode Length: 82.474 | Episode Return: 0.272 | Time: 0.038\n", + "INFO:root:Train >> Env Steps: 2.37e+07 | Steps Per Second: 2,581,666 | Time: 0.248\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.43e+07 | Episode Length: 81.540 | Episode Return: 0.288 | Time: 0.040\n", + "INFO:root:Train >> Env Steps: 2.43e+07 | Steps Per Second: 2,432,661 | Time: 0.263\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.50e+07 | Episode Length: 81.974 | Episode Return: 0.272 | Time: 0.037\n", + "INFO:root:Train >> Env Steps: 2.50e+07 | Steps Per Second: 2,411,056 | Time: 0.265\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.56e+07 | Episode Length: 82.248 | Episode 
Return: 0.278 | Time: 0.037\n", + "INFO:root:Train >> Env Steps: 2.56e+07 | Steps Per Second: 2,496,048 | Time: 0.256\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.62e+07 | Episode Length: 83.648 | Episode Return: 0.268 | Time: 0.049\n", + "INFO:root:Train >> Env Steps: 2.62e+07 | Steps Per Second: 2,352,356 | Time: 0.272\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.69e+07 | Episode Length: 83.536 | Episode Return: 0.268 | Time: 0.045\n", + "INFO:root:Train >> Env Steps: 2.69e+07 | Steps Per Second: 2,106,113 | Time: 0.304\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.75e+07 | Episode Length: 81.730 | Episode Return: 0.292 | Time: 0.040\n", + "INFO:root:Train >> Env Steps: 2.75e+07 | Steps Per Second: 2,300,960 | Time: 0.278\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.82e+07 | Episode Length: 83.238 | Episode Return: 0.258 | Time: 0.040\n", + "INFO:root:Train >> Env Steps: 2.82e+07 | Steps Per Second: 2,326,204 | Time: 0.275\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.88e+07 | Episode Length: 82.600 | Episode Return: 0.258 | Time: 0.039\n", + "INFO:root:Train >> Env Steps: 2.88e+07 | Steps Per Second: 2,293,995 | Time: 0.279\n", + "INFO:root:Eval Stochastic >> Env Steps: 2.94e+07 | Episode Length: 80.240 | Episode Return: 0.298 | Time: 0.040\n", + "INFO:root:Train >> Env Steps: 2.94e+07 | Steps Per Second: 2,333,271 | Time: 0.274\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.01e+07 | Episode Length: 81.502 | Episode Return: 0.290 | Time: 0.041\n", + "INFO:root:Train >> Env Steps: 3.01e+07 | Steps Per Second: 2,171,676 | Time: 0.295\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.07e+07 | Episode Length: 82.788 | Episode Return: 0.278 | Time: 0.037\n", + "INFO:root:Train >> Env Steps: 3.07e+07 | Steps Per Second: 2,283,905 | Time: 0.280\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.14e+07 | Episode Length: 82.726 | Episode Return: 0.274 | Time: 0.039\n", + "INFO:root:Train >> Env Steps: 3.14e+07 | Steps Per Second: 2,098,806 | Time: 
0.305\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.20e+07 | Episode Length: 83.738 | Episode Return: 0.274 | Time: 0.051\n", + "INFO:root:Train >> Env Steps: 3.20e+07 | Steps Per Second: 2,210,576 | Time: 0.290\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.26e+07 | Episode Length: 80.448 | Episode Return: 0.298 | Time: 0.043\n", + "INFO:root:Train >> Env Steps: 3.26e+07 | Steps Per Second: 2,418,761 | Time: 0.265\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.33e+07 | Episode Length: 81.492 | Episode Return: 0.278 | Time: 0.038\n", + "INFO:root:Train >> Env Steps: 3.33e+07 | Steps Per Second: 2,249,525 | Time: 0.285\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.39e+07 | Episode Length: 82.242 | Episode Return: 0.280 | Time: 0.044\n", + "INFO:root:Train >> Env Steps: 3.39e+07 | Steps Per Second: 2,014,416 | Time: 0.318\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.46e+07 | Episode Length: 82.854 | Episode Return: 0.274 | Time: 0.043\n", + "INFO:root:Train >> Env Steps: 3.46e+07 | Steps Per Second: 2,098,892 | Time: 0.305\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.52e+07 | Episode Length: 84.762 | Episode Return: 0.250 | Time: 0.046\n", + "INFO:root:Train >> Env Steps: 3.52e+07 | Steps Per Second: 2,090,125 | Time: 0.306\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.58e+07 | Episode Length: 83.244 | Episode Return: 0.260 | Time: 0.046\n", + "INFO:root:Train >> Env Steps: 3.58e+07 | Steps Per Second: 1,879,605 | Time: 0.340\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.65e+07 | Episode Length: 83.160 | Episode Return: 0.276 | Time: 0.038\n", + "INFO:root:Train >> Env Steps: 3.65e+07 | Steps Per Second: 2,352,184 | Time: 0.272\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.71e+07 | Episode Length: 79.962 | Episode Return: 0.304 | Time: 0.047\n", + "INFO:root:Train >> Env Steps: 3.71e+07 | Steps Per Second: 2,448,137 | Time: 0.261\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.78e+07 | Episode Length: 86.560 | Episode Return: 0.234 | Time: 
0.038\n", + "INFO:root:Train >> Env Steps: 3.78e+07 | Steps Per Second: 2,333,372 | Time: 0.274\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.84e+07 | Episode Length: 83.018 | Episode Return: 0.270 | Time: 0.041\n", + "INFO:root:Train >> Env Steps: 3.84e+07 | Steps Per Second: 2,164,049 | Time: 0.296\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.90e+07 | Episode Length: 80.088 | Episode Return: 0.322 | Time: 0.049\n", + "INFO:root:Train >> Env Steps: 3.90e+07 | Steps Per Second: 2,011,495 | Time: 0.318\n", + "INFO:root:Eval Stochastic >> Env Steps: 3.97e+07 | Episode Length: 82.300 | Episode Return: 0.278 | Time: 0.044\n", + "INFO:root:Train >> Env Steps: 3.97e+07 | Steps Per Second: 2,043,342 | Time: 0.313\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.03e+07 | Episode Length: 84.022 | Episode Return: 0.256 | Time: 0.045\n", + "INFO:root:Train >> Env Steps: 4.03e+07 | Steps Per Second: 2,142,862 | Time: 0.299\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.10e+07 | Episode Length: 82.592 | Episode Return: 0.260 | Time: 0.045\n", + "INFO:root:Train >> Env Steps: 4.10e+07 | Steps Per Second: 1,806,697 | Time: 0.354\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.16e+07 | Episode Length: 84.444 | Episode Return: 0.260 | Time: 0.040\n", + "INFO:root:Train >> Env Steps: 4.16e+07 | Steps Per Second: 2,497,808 | Time: 0.256\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.22e+07 | Episode Length: 82.130 | Episode Return: 0.272 | Time: 0.038\n", + "INFO:root:Train >> Env Steps: 4.22e+07 | Steps Per Second: 2,456,166 | Time: 0.261\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.29e+07 | Episode Length: 81.448 | Episode Return: 0.300 | Time: 0.045\n", + "INFO:root:Train >> Env Steps: 4.29e+07 | Steps Per Second: 2,247,462 | Time: 0.285\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.35e+07 | Episode Length: 81.246 | Episode Return: 0.284 | Time: 0.046\n", + "INFO:root:Train >> Env Steps: 4.35e+07 | Steps Per Second: 2,096,755 | Time: 0.305\n", + "INFO:root:Eval 
Stochastic >> Env Steps: 4.42e+07 | Episode Length: 79.680 | Episode Return: 0.318 | Time: 0.045\n", + "INFO:root:Train >> Env Steps: 4.42e+07 | Steps Per Second: 2,104,420 | Time: 0.304\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.48e+07 | Episode Length: 79.308 | Episode Return: 0.330 | Time: 0.043\n", + "INFO:root:Train >> Env Steps: 4.48e+07 | Steps Per Second: 1,884,177 | Time: 0.340\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.54e+07 | Episode Length: 81.016 | Episode Return: 0.308 | Time: 0.049\n", + "INFO:root:Train >> Env Steps: 4.54e+07 | Steps Per Second: 1,882,314 | Time: 0.340\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.61e+07 | Episode Length: 82.860 | Episode Return: 0.262 | Time: 0.051\n", + "INFO:root:Train >> Env Steps: 4.61e+07 | Steps Per Second: 2,209,167 | Time: 0.290\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.67e+07 | Episode Length: 82.430 | Episode Return: 0.288 | Time: 0.044\n", + "INFO:root:Train >> Env Steps: 4.67e+07 | Steps Per Second: 2,420,382 | Time: 0.264\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.74e+07 | Episode Length: 80.886 | Episode Return: 0.294 | Time: 0.037\n", + "INFO:root:Train >> Env Steps: 4.74e+07 | Steps Per Second: 2,254,018 | Time: 0.284\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.80e+07 | Episode Length: 82.700 | Episode Return: 0.278 | Time: 0.046\n", + "INFO:root:Train >> Env Steps: 4.80e+07 | Steps Per Second: 2,204,178 | Time: 0.290\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.86e+07 | Episode Length: 81.658 | Episode Return: 0.282 | Time: 0.041\n", + "INFO:root:Train >> Env Steps: 4.86e+07 | Steps Per Second: 2,216,029 | Time: 0.289\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.93e+07 | Episode Length: 80.420 | Episode Return: 0.302 | Time: 0.053\n", + "INFO:root:Train >> Env Steps: 4.93e+07 | Steps Per Second: 2,041,316 | Time: 0.314\n", + "INFO:root:Eval Stochastic >> Env Steps: 4.99e+07 | Episode Length: 80.008 | Episode Return: 0.304 | Time: 0.046\n", + 
"INFO:root:Train >> Env Steps: 4.99e+07 | Steps Per Second: 2,415,941 | Time: 0.265\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.06e+07 | Episode Length: 83.252 | Episode Return: 0.268 | Time: 0.042\n", + "INFO:root:Train >> Env Steps: 5.06e+07 | Steps Per Second: 2,308,797 | Time: 0.277\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.12e+07 | Episode Length: 82.196 | Episode Return: 0.284 | Time: 0.038\n", + "INFO:root:Train >> Env Steps: 5.12e+07 | Steps Per Second: 2,298,530 | Time: 0.278\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.18e+07 | Episode Length: 81.020 | Episode Return: 0.302 | Time: 0.042\n", + "INFO:root:Train >> Env Steps: 5.18e+07 | Steps Per Second: 2,249,649 | Time: 0.284\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.25e+07 | Episode Length: 80.792 | Episode Return: 0.294 | Time: 0.046\n", + "INFO:root:Train >> Env Steps: 5.25e+07 | Steps Per Second: 2,238,783 | Time: 0.286\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.31e+07 | Episode Length: 83.096 | Episode Return: 0.272 | Time: 0.044\n", + "INFO:root:Train >> Env Steps: 5.31e+07 | Steps Per Second: 2,115,799 | Time: 0.302\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.38e+07 | Episode Length: 81.346 | Episode Return: 0.286 | Time: 0.044\n", + "INFO:root:Train >> Env Steps: 5.38e+07 | Steps Per Second: 2,210,072 | Time: 0.290\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.44e+07 | Episode Length: 81.972 | Episode Return: 0.284 | Time: 0.050\n", + "INFO:root:Train >> Env Steps: 5.44e+07 | Steps Per Second: 2,327,461 | Time: 0.275\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.50e+07 | Episode Length: 83.524 | Episode Return: 0.274 | Time: 0.037\n", + "INFO:root:Train >> Env Steps: 5.50e+07 | Steps Per Second: 2,500,349 | Time: 0.256\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.57e+07 | Episode Length: 82.860 | Episode Return: 0.282 | Time: 0.037\n", + "INFO:root:Train >> Env Steps: 5.57e+07 | Steps Per Second: 2,319,887 | Time: 0.276\n", + "INFO:root:Eval Stochastic 
>> Env Steps: 5.63e+07 | Episode Length: 81.936 | Episode Return: 0.288 | Time: 0.048\n", + "INFO:root:Train >> Env Steps: 5.63e+07 | Steps Per Second: 1,927,306 | Time: 0.332\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.70e+07 | Episode Length: 81.932 | Episode Return: 0.288 | Time: 0.051\n", + "INFO:root:Train >> Env Steps: 5.70e+07 | Steps Per Second: 1,521,043 | Time: 0.421\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.76e+07 | Episode Length: 84.466 | Episode Return: 0.260 | Time: 0.066\n", + "INFO:root:Train >> Env Steps: 5.76e+07 | Steps Per Second: 1,035,243 | Time: 0.618\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.82e+07 | Episode Length: 81.380 | Episode Return: 0.298 | Time: 0.097\n", + "INFO:root:Train >> Env Steps: 5.82e+07 | Steps Per Second: 467,005 | Time: 1.370\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.89e+07 | Episode Length: 81.800 | Episode Return: 0.302 | Time: 0.483\n", + "INFO:root:Train >> Env Steps: 5.89e+07 | Steps Per Second: 187,811 | Time: 3.408\n", + "INFO:root:Eval Stochastic >> Env Steps: 5.95e+07 | Episode Length: 82.188 | Episode Return: 0.276 | Time: 0.577\n", + "INFO:root:Train >> Env Steps: 5.95e+07 | Steps Per Second: 191,302 | Time: 3.345\n", + "INFO:root:Eval Stochastic >> Env Steps: 6.02e+07 | Episode Length: 81.502 | Episode Return: 0.276 | Time: 0.554\n", + "INFO:root:Train >> Env Steps: 6.02e+07 | Steps Per Second: 191,432 | Time: 3.343\n", + "INFO:root:Eval Stochastic >> Env Steps: 6.08e+07 | Episode Length: 81.666 | Episode Return: 0.288 | Time: 0.541\n", + "INFO:root:Train >> Env Steps: 6.08e+07 | Steps Per Second: 182,772 | Time: 3.502\n", + "INFO:root:Eval Stochastic >> Env Steps: 6.14e+07 | Episode Length: 82.468 | Episode Return: 0.268 | Time: 0.636\n", + "INFO:root:Train >> Env Steps: 6.14e+07 | Steps Per Second: 192,009 | Time: 3.333\n", + "INFO:root:Eval Stochastic >> Env Steps: 6.21e+07 | Episode Length: 80.256 | Episode Return: 0.296 | Time: 0.599\n", + "INFO:root:Train >> Env Steps: 
6.21e+07 | Steps Per Second: 192,225 | Time: 3.329\n", + "INFO:root:Eval Stochastic >> Env Steps: 6.27e+07 | Episode Length: 83.312 | Episode Return: 0.258 | Time: 0.479\n", + "INFO:root:Train >> Env Steps: 6.27e+07 | Steps Per Second: 210,990 | Time: 3.033\n", + "INFO:root:Eval Stochastic >> Env Steps: 6.34e+07 | Episode Length: 81.808 | Episode Return: 0.280 | Time: 0.252\n", + "INFO:root:Train >> Env Steps: 6.34e+07 | Steps Per Second: 533,761 | Time: 1.199\n", + "INFO:root:Saving checkpoint...\n", + "INFO:root:Checkpoint saved at 'training_state'.\n", + "INFO:root:Closing logger...\n" + ] + } + ], "source": [ - "with initialize(version_base=None, config_path=\"../jumanji/training/configs\"):\n", + "with initialize(version_base=None, config_path=\"configs\"):\n", " cfg = compose(config_name=\"config.yaml\", overrides=[f\"env={env}\", f\"agent={agent}\", \"logger.type=terminal\", \"logger.save_checkpoint=true\"])\n", - "cfg" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "\n", "train(cfg)" ] } diff --git a/jumanji/training/configs/env/cleaner.yaml b/jumanji/training/configs/env/cleaner.yaml index 1064ec8f6..11f199b3a 100644 --- a/jumanji/training/configs/env/cleaner.yaml +++ b/jumanji/training/configs/env/cleaner.yaml @@ -1,20 +1,20 @@ name: cleaner registered_version: Cleaner-v0 +network: + num_conv_channels: [4, 4, 1] + policy_layers: [64] + value_layers: [128] + training: num_epochs: 300 num_learner_steps_per_epoch: 500 n_steps: 10 total_batch_size: 128 -network: - num_conv_channels: [4, 4, 1] - policy_layers: [64] - value_layers: [128] - evaluation: - eval_total_batch_size: 500 - greedy_eval_total_batch_size: 500 + eval_total_batch_size: 512 + greedy_eval_total_batch_size: 512 a2c: normalize_advantage: False diff --git a/jumanji/training/configs/env/game_2048.yaml b/jumanji/training/configs/env/game_2048.yaml index a51b68e83..7b7c0e368 100644 --- a/jumanji/training/configs/env/game_2048.yaml 
+++ b/jumanji/training/configs/env/game_2048.yaml @@ -13,8 +13,8 @@ training: total_batch_size: 32 evaluation: - eval_total_batch_size: 500 - greedy_eval_total_batch_size: 500 + eval_total_batch_size: 512 + greedy_eval_total_batch_size: 512 a2c: normalize_advantage: False diff --git a/jumanji/training/configs/env/job_shop.yaml b/jumanji/training/configs/env/job_shop.yaml index 2449c475d..a3bdfac0f 100644 --- a/jumanji/training/configs/env/job_shop.yaml +++ b/jumanji/training/configs/env/job_shop.yaml @@ -16,8 +16,8 @@ training: total_batch_size: 128 evaluation: - eval_total_batch_size: 500 - greedy_eval_total_batch_size: 500 + eval_total_batch_size: 512 + greedy_eval_total_batch_size: 512 a2c: normalize_advantage: False diff --git a/jumanji/training/configs/env/maze.yaml b/jumanji/training/configs/env/maze.yaml index 54a02efa1..8486f8570 100644 --- a/jumanji/training/configs/env/maze.yaml +++ b/jumanji/training/configs/env/maze.yaml @@ -13,8 +13,8 @@ training: total_batch_size: 128 evaluation: - eval_total_batch_size: 500 - greedy_eval_total_batch_size: 500 + eval_total_batch_size: 512 + greedy_eval_total_batch_size: 512 a2c: normalize_advantage: False diff --git a/jumanji/training/configs/env/rubiks_cube.yaml b/jumanji/training/configs/env/rubiks_cube.yaml index 03d6d9033..e0b452bca 100644 --- a/jumanji/training/configs/env/rubiks_cube.yaml +++ b/jumanji/training/configs/env/rubiks_cube.yaml @@ -13,8 +13,8 @@ training: total_batch_size: 256 evaluation: - eval_total_batch_size: 1000 - greedy_eval_total_batch_size: 1000 + eval_total_batch_size: 1024 + greedy_eval_total_batch_size: 1024 a2c: normalize_advantage: False diff --git a/jumanji/training/configs/env/snake.yaml b/jumanji/training/configs/env/snake.yaml index d4669feaf..2310c3f21 100644 --- a/jumanji/training/configs/env/snake.yaml +++ b/jumanji/training/configs/env/snake.yaml @@ -13,8 +13,8 @@ training: total_batch_size: 128 evaluation: - eval_total_batch_size: 1000 - greedy_eval_total_batch_size: 1000 + 
eval_total_batch_size: 1024 + greedy_eval_total_batch_size: 1024 a2c: normalize_advantage: False diff --git a/jumanji/training/configs/env/sudoku.yaml b/jumanji/training/configs/env/sudoku.yaml index 8b0ce9735..53c0c9eac 100644 --- a/jumanji/training/configs/env/sudoku.yaml +++ b/jumanji/training/configs/env/sudoku.yaml @@ -14,8 +14,8 @@ training: total_batch_size: 128 evaluation: - eval_total_batch_size: 1000 - greedy_eval_total_batch_size: 1000 + eval_total_batch_size: 1024 + greedy_eval_total_batch_size: 1024 a2c: normalize_advantage: False diff --git a/jumanji/training/configs/env/tetris.yaml b/jumanji/training/configs/env/tetris.yaml index 05b3a49b5..f6ba7d534 100644 --- a/jumanji/training/configs/env/tetris.yaml +++ b/jumanji/training/configs/env/tetris.yaml @@ -13,8 +13,8 @@ training: total_batch_size: 128 evaluation: - eval_total_batch_size: 1000 - greedy_eval_total_batch_size: 1000 + eval_total_batch_size: 1024 + greedy_eval_total_batch_size: 1024 a2c: normalize_advantage: False From d8ecfae58c70375546627599afc70669c4628bcb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Bonnet?= <56230714+clement-bonnet@users.noreply.github.com> Date: Sun, 18 Jun 2023 19:52:07 +0200 Subject: [PATCH 2/6] ci: update to latest jax and chex (#174) --- jumanji/testing/pytrees.py | 5 +++-- jumanji/testing/pytrees_test.py | 25 ++++++++++++------------- jumanji/wrappers.py | 2 +- requirements/requirements.txt | 6 +++--- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/jumanji/testing/pytrees.py b/jumanji/testing/pytrees.py index be9f88b10..a60dee52e 100644 --- a/jumanji/testing/pytrees.py +++ b/jumanji/testing/pytrees.py @@ -16,6 +16,7 @@ import chex import jax +import jax.numpy as jnp import jax.tree_util import numpy as np import tree as tree_lib @@ -89,8 +90,8 @@ def assert_tree_with_leaves_of_type(input_tree: Any, *leaf_type: Type) -> None: def assert_is_jax_array_tree(input_tree: chex.ArrayTree) -> None: - """Asserts that the `input_tree` has leaves 
that are exclusively of type `chex.Array`.""" - assert_tree_with_leaves_of_type(input_tree, chex.Array, type(None)) + """Asserts that the `input_tree` has leaves that are exclusively of type `jnp.ndarray`.""" + assert_tree_with_leaves_of_type(input_tree, jnp.ndarray, type(None)) def has_at_least_rank(input_tree: chex.ArrayTree, rank: int) -> bool: diff --git a/jumanji/testing/pytrees_test.py b/jumanji/testing/pytrees_test.py index 8ce3ca5fe..cef9e0d02 100644 --- a/jumanji/testing/pytrees_test.py +++ b/jumanji/testing/pytrees_test.py @@ -15,7 +15,6 @@ import re from typing import Dict -import chex import jax.numpy as jnp import numpy as np import pytest @@ -108,13 +107,13 @@ def test_is_tree_with_leaves_of_type( is composed exclusively of leaves of the specified type, and `False` if there is at least one leaf that is not of the specified type. """ - assert pytree_test_utils.is_tree_with_leaves_of_type(jax_tree, chex.Array) - assert pytree_test_utils.is_tree_with_leaves_of_type(np_tree, chex.ArrayNumpy) + assert pytree_test_utils.is_tree_with_leaves_of_type(jax_tree, jnp.ndarray) + assert pytree_test_utils.is_tree_with_leaves_of_type(np_tree, np.ndarray) assert not pytree_test_utils.is_tree_with_leaves_of_type( - jax_and_numpy_tree, chex.Array + jax_and_numpy_tree, jnp.ndarray ) assert not pytree_test_utils.is_tree_with_leaves_of_type( - jax_and_numpy_tree, chex.ArrayNumpy + jax_and_numpy_tree, np.ndarray ) @@ -128,21 +127,21 @@ def test_assert_tree_with_leaves_of_type( is composed exclusively of leaves of a specified type, and raises an AssertionError if there is at least one leaf of a different type. 
""" - pytree_test_utils.assert_tree_with_leaves_of_type(jax_tree, chex.Array) - pytree_test_utils.assert_tree_with_leaves_of_type(np_tree, chex.ArrayNumpy) + pytree_test_utils.assert_tree_with_leaves_of_type(jax_tree, jnp.ndarray) + pytree_test_utils.assert_tree_with_leaves_of_type(np_tree, np.ndarray) with pytest.raises( AssertionError, - match=f"The tree has at least one leaf that is not of type {chex.Array}.", + match=f"The tree has at least one leaf that is not of type {jnp.ndarray}.", ): pytree_test_utils.assert_tree_with_leaves_of_type( - jax_and_numpy_tree, chex.Array + jax_and_numpy_tree, jnp.ndarray ) with pytest.raises( AssertionError, - match=f"The tree has at least one leaf that is not of type {chex.ArrayNumpy}.", + match=f"The tree has at least one leaf that is not of type {np.ndarray}.", ): pytree_test_utils.assert_tree_with_leaves_of_type( - jax_and_numpy_tree, chex.ArrayNumpy + jax_and_numpy_tree, np.ndarray ) @@ -159,12 +158,12 @@ def test_assert_is_jax_array_tree( pytree_test_utils.assert_is_jax_array_tree(jax_tree) with pytest.raises( AssertionError, - match=f"The tree has at least one " f"leaf that is not of type {chex.Array}.", + match=f"The tree has at least one " f"leaf that is not of type {jnp.ndarray}.", ): pytree_test_utils.assert_is_jax_array_tree(np_tree) with pytest.raises( AssertionError, - match=f"The tree has at least one " f"leaf that is not of type {chex.Array}.", + match=f"The tree has at least one " f"leaf that is not of type {jnp.ndarray}.", ): pytree_test_utils.assert_is_jax_array_tree(jax_and_numpy_tree) diff --git a/jumanji/wrappers.py b/jumanji/wrappers.py index 72f38be0b..3dbb9d338 100644 --- a/jumanji/wrappers.py +++ b/jumanji/wrappers.py @@ -642,7 +642,7 @@ def jumanji_to_gym_obs(observation: Observation) -> GymObservation: Returns: Numpy array or nested dictionary of numpy arrays. 
""" - if isinstance(observation, chex.Array): + if isinstance(observation, jnp.ndarray): return np.asarray(observation) elif hasattr(observation, "__dict__"): # Applies to various containers including `chex.dataclass` diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 3d1d3e301..3aca4d41f 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,8 +1,8 @@ -chex>=0.1.3,<0.1.6 +chex>=0.1.3 dm-env>=1.5 gym>=0.22.0 -jax>=0.2.26,<=0.4.10 -jaxlib>=0.1.74,<=0.4.10 +jax>=0.2.26 +jaxlib>=0.1.74 matplotlib>=3.3.4 numpy>=1.19.5 Pillow>=9.0.0 From 77878f5d1f62e232c6cd296ee7d1fe1fe4aa1b20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Bonnet?= <56230714+clement-bonnet@users.noreply.github.com> Date: Mon, 19 Jun 2023 10:40:42 +0200 Subject: [PATCH 3/6] build: remove jaxlib (#175) --- requirements/requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 3aca4d41f..23c26ee59 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -2,7 +2,6 @@ chex>=0.1.3 dm-env>=1.5 gym>=0.22.0 jax>=0.2.26 -jaxlib>=0.1.74 matplotlib>=3.3.4 numpy>=1.19.5 Pillow>=9.0.0 From 96e8e52464a7ee5e9884839d10ee43527c5ec727 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Bonnet?= <56230714+clement-bonnet@users.noreply.github.com> Date: Mon, 19 Jun 2023 11:35:46 +0200 Subject: [PATCH 4/6] docs: update readme citation (#176) --- README.md | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 09c6ad6bc..2db8ae9d5 100644 --- a/README.md +++ b/README.md @@ -222,15 +222,21 @@ details on how to submit pull requests, our Contributor License Agreement, and c If you use Jumanji in your work, please cite the library using: ``` -@software{jumanji2023github, - author = {Clément Bonnet and Daniel Luo and Donal Byrne and Sasha Abramowitz - and Vincent Coyette and Paul Duckworth and 
Daniel Furelos-Blanco and - Nathan Grinsztajn and Tristan Kalloniatis and Victor Le and Omayma Mahjoub - and Laurence Midgley and Shikha Surana and Cemlyn Waters and Alexandre Laterre}, - title = {Jumanji: a Suite of Diverse and Challenging Reinforcement Learning Environments in JAX}, - url = {https://github.com/instadeepai/jumanji}, - version = {0.2.2}, - year = {2023}, +@misc{bonnet2023jumanji, + title={Jumanji: a Diverse Suite of Scalable Reinforcement Learning Environments in JAX}, + author={ + Clément Bonnet and Daniel Luo and Donal Byrne and Shikha Surana and Vincent Coyette and + Paul Duckworth and Laurence I. Midgley and Tristan Kalloniatis and Sasha Abramowitz and + Cemlyn N. Waters and Andries P. Smit and Nathan Grinsztajn and Ulrich A. Mbou Sob and + Omayma Mahjoub and Elshadai Tegegn and Mohamed A. Mimouni and Raphael Boige and + Ruan de Kock and Daniel Furelos-Blanco and Victor Le and Arnu Pretorius and + Alexandre Laterre + }, + year={2023}, + eprint={2306.09884}, + url={https://arxiv.org/abs/2306.09884}, + archivePrefix={arXiv}, + primaryClass={cs.LG} } ``` From 32685cb9b21b35b6bff01e56186933f9803a7dbe Mon Sep 17 00:00:00 2001 From: aar65537 <115365716+aar65537@users.noreply.github.com> Date: Tue, 20 Jun 2023 10:01:43 -0500 Subject: [PATCH 5/6] feat(2048): environment performance improvements (#172) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Clément Bonnet <56230714+clement-bonnet@users.noreply.github.com> --- jumanji/environments/logic/game_2048/env.py | 29 +- jumanji/environments/logic/game_2048/utils.py | 417 +++++++++--------- .../logic/game_2048/utils_test.py | 36 ++ 3 files changed, 258 insertions(+), 224 deletions(-) diff --git a/jumanji/environments/logic/game_2048/env.py b/jumanji/environments/logic/game_2048/env.py index 323995bdd..ba8115c16 100644 --- a/jumanji/environments/logic/game_2048/env.py +++ b/jumanji/environments/logic/game_2048/env.py @@ -23,12 +23,7 @@ from jumanji 
import specs from jumanji.env import Environment from jumanji.environments.logic.game_2048.types import Board, Observation, State -from jumanji.environments.logic.game_2048.utils import ( - move_down, - move_left, - move_right, - move_up, -) +from jumanji.environments.logic.game_2048.utils import can_move, move from jumanji.environments.logic.game_2048.viewer import Game2048Viewer from jumanji.types import TimeStep, restart, termination, transition from jumanji.viewer import Viewer @@ -181,11 +176,7 @@ def step( timestep: the next timestep. """ # Take the action in the environment: Up, Right, Down, Left. - updated_board, additional_reward = jax.lax.switch( - action, - [move_up, move_right, move_down, move_left], - state.board, - ) + updated_board, reward = move(state.board, action) # Generate new key. random_cell_key, new_state_key = jax.random.split(state.key) @@ -209,7 +200,7 @@ def step( action_mask=action_mask, step_count=state.step_count + 1, key=new_state_key, - score=state.score + additional_reward.astype(float), + score=state.score + reward, ) # Generate the observation from the environment state. @@ -227,12 +218,12 @@ def step( timestep = jax.lax.cond( done, lambda: termination( - reward=additional_reward, + reward=reward, observation=observation, extras=extras, ), lambda: transition( - reward=additional_reward, + reward=reward, observation=observation, extras=extras, ), @@ -303,15 +294,7 @@ def _get_action_mask(self, board: Board) -> chex.Array: Returns: action_mask: action mask for the current state of the environment. 
""" - action_mask = jnp.array( - [ - jnp.any(move_up(board, final_shift=False)[0] != board), - jnp.any(move_right(board, final_shift=False)[0] != board), - jnp.any(move_down(board, final_shift=False)[0] != board), - jnp.any(move_left(board, final_shift=False)[0] != board), - ], - ) - return action_mask + return jax.vmap(can_move, (None, 0))(board, jnp.arange(4)) def render(self, state: State) -> Optional[NDArray]: """Renders the current state of the game board. diff --git a/jumanji/environments/logic/game_2048/utils.py b/jumanji/environments/logic/game_2048/utils.py index ae17c3e66..0cea3389a 100644 --- a/jumanji/environments/logic/game_2048/utils.py +++ b/jumanji/environments/logic/game_2048/utils.py @@ -12,229 +12,244 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools -from typing import Tuple +from typing import NamedTuple, Tuple +import chex import jax import jax.numpy as jnp -from jax.numpy import DeviceArray from jumanji.environments.logic.game_2048.types import Board -def shift_nonzero_element(carry: Tuple) -> Tuple[DeviceArray, int]: - """Shift nonzero element from index i to index j and increment j. - For example, in the case of this column [2, 0, 2, 0], this method will be invoked - when `i` equals 0 and 2, and it will return successively ([2, 0, 2, 0], `j` = 1) - and ([2, 2, 2, 0], `j` = 2). - - Args: - carry: - col: a column of the board. - i: the current index. - j: the index of the nonzero element. It also represents the number of nonzero - elements that have been shifted so far. - - Returns: - A tuple containing the updated array (col) and the incremented target index (j). 
- """ - col, j, i = carry - col = col.at[j].set(col[i]) - j += 1 - return col, j - - -def shift_column_elements_up(carry: Tuple, i: int) -> Tuple[DeviceArray, None]: - """This method calls `shift_nonzero_element` to shift non-zero elements in the column, - and conducts the identity operation if the element is zero. - - Agrs: - carry: - col: a one-dimensional array representing a column of the board. - j: the index of the non zero element. It also represents the number of non-zero - elements that have been shifted so far. - i: the current index. - - Returns: - A tuple containing the updated column and None. - """ - col, j = carry - col, j = jax.lax.cond( - col[i] != 0, - shift_nonzero_element, - lambda col_j_i: col_j_i[:2], - (col, j, i), +def transform_board(board: Board, action: int) -> Board: + """Transform board so that move_left is analogous to move_action. Also, transform back.""" + return jax.lax.switch( + action, + [ + lambda: jnp.transpose(board), + lambda: jnp.flip(board, 1), + lambda: jnp.flip(jnp.transpose(board)), + lambda: board, + ], ) - return (col, j), None - - -def fill_with_zero(carry: Tuple[DeviceArray, int]) -> Tuple[DeviceArray, int]: - """Fill the remaining elements of the column with zeros after shifting non-zero elements to the up. - For example: if the initial column is [2, 0, 2, 0] then this method will be invoked when `j` - equals to 2 and 3. - - Args: - carry: - col: a column of the board. - j: the index of the nonzero element. It also represents the number of nonzero - elements that have been shifted so far. - - Returns: - A tuple containing the updated column and incremented index. - """ - col, j = carry - col = col.at[j].set(0) - j += 1 - return col, j - - -def shift_up(col: DeviceArray) -> DeviceArray: - """Shift all the elements in a column up. - For example: [2, 0, 2, 0] -> [2, 2, 0, 0] - - Args: - col: a column of the board. - - Returns: - The modified column with all the elements shifted up.
- """ - j = 0 - (col, j), _ = jax.lax.scan( # In example: [2, 0, 2, 0] -> [2, 2, 2, 0] - f=shift_column_elements_up, init=(col, j), xs=jnp.arange(len(col)) + + +class CanMoveCarry(NamedTuple): + """Carry value for while loop in can_move_left_row.""" + + can_move: bool + row: chex.Array + target_idx: int + origin_idx: int + + @property + def target(self) -> chex.Numeric: + """Tile at target index of row.""" + return self.row[self.target_idx] + + @property + def origin(self) -> chex.Numeric: + """Tile at origin index of row.""" + return self.row[self.origin_idx] + + +def can_move_left_row_cond(carry: CanMoveCarry) -> chex.Numeric: + """Terminate loop when valid move is found or origin reaches end of row.""" + return ~carry.can_move & (carry.origin_idx < carry.row.shape[0]) + + +def can_move_left_row_body(carry: CanMoveCarry) -> CanMoveCarry: + """Check if the current tiles can move and increment the indices.""" + # Check if tiles can move + can_move = (carry.origin != 0) & ( + (carry.target == 0) | (carry.target == carry.origin) ) - col, j = jax.lax.while_loop( # In example: [2, 2, 2, 0] -> [2, 2, 0, 0] - lambda col_j: col_j[1] < len(col_j[0]), - fill_with_zero, - (col, j), + + # Increment indices as if performed a no op + # If not performing no op, loop will be terminated anyways + target_idx = carry.target_idx + (carry.origin != 0) + origin_idx = jax.lax.select( + (carry.origin == 0) | (target_idx == carry.origin_idx), + carry.origin_idx + 1, + carry.origin_idx, ) - return col - - -def merge_elements(carry: Tuple) -> Tuple[DeviceArray, float]: - """Merge two adjacent elements in a column. - For example: col = [1, 1, 2, 2] and i = 2 -> [1, 1, 3, 0], with a reward equal to 2³. - - Args: - carry: a tuple containing the current state of the column, the current index, - and the current reward. - - Returns: - A tuple containing the modified column, and the updated reward. 
- """ - col, reward, i = carry - new_col_i = col[i] + 1 - col = col.at[i].set(new_col_i) - col = col.at[i + 1].set(0) - reward += 2**new_col_i - return col, reward - - -def merge_equal_elements( - carry: Tuple[DeviceArray, float], i: int -) -> Tuple[Tuple[DeviceArray, float], None]: - """This function merges adjacent non-zero elements in the column of the board, if the - two adjacent elements are equal. - This function will examine each element individually to locate two adjacent equal elements. - For example in the case of [1, 1, 2, 2], this method will call `merge_elements` for `i` equals - to 0 and 2. - - Args: - carry: a tuple containing the current state of the column, and the current reward. - i: the current index. - - Returns: - Tuple containing the updated column and the reward. - """ - col, reward = carry - col, reward = jax.lax.cond( - ((col[i] != 0) & (col[i] == col[i + 1])), - merge_elements, - lambda col_reward_i: col_reward_i[:2], - (col, reward, i), + + # Return updated carry + return carry._replace( + can_move=can_move, target_idx=target_idx, origin_idx=origin_idx ) - return (col, reward), None -def merge_col(col: DeviceArray) -> Tuple[DeviceArray, float]: - """Merge the elements of a column according to the rules of the 2048 game. - For example: [0, 0, 2, 2] -> [0, 0, 3, 0] with a reward equal to 2³. 
+def can_move_left_row(row: chex.Array) -> bool: + """Check if row can move left.""" + carry = CanMoveCarry(can_move=False, row=row, target_idx=0, origin_idx=1) + can_move: bool = jax.lax.while_loop( + can_move_left_row_cond, can_move_left_row_body, carry + )[0] + return can_move + + +def can_move_left(board: Board) -> bool: + """Check if board can move left.""" + can_move: bool = jax.vmap(can_move_left_row)(board).any() + return can_move + + +def can_move(board: Board, action: int) -> bool: + """Check if board can move with action.""" + return can_move_left(transform_board(board, action)) + + +def can_move_up(board: Board) -> bool: + """Check if board can move up.""" + return can_move(board, 0) + + +def can_move_right(board: Board) -> bool: + """Check if board can move right.""" + return can_move(board, 1) + + +def can_move_down(board: Board) -> bool: + """Check if board can move down.""" + return can_move(board, 2) + + +class MoveUpdate(NamedTuple): + """Update to move carry.""" + + target: chex.Numeric + origin: chex.Numeric + additional_reward: float + target_idx: int + origin_idx: int - Args: - col: a column of the board. - Returns: - A tuple containing the modified column and the total reward obtained by - merging the elements. - """ - reward = 0.0 - elements_indices = jnp.arange(len(col) - 1) - (col, reward), _ = jax.lax.scan( - f=merge_equal_elements, init=(col, reward), xs=elements_indices +class MoveCarry(NamedTuple): + """Carry value for while loop in move_left_row.""" + + row: chex.Array + reward: float + target_idx: int + origin_idx: int + + @property + def target(self) -> chex.Numeric: + """Tile at target index of row.""" + return self.row[self.target_idx] + + @property + def origin(self) -> chex.Numeric: + """Tile at origin index of row.""" + return self.row[self.origin_idx] + + def update(self, update: MoveUpdate) -> "MoveCarry": + """Return new updated carry. 
This method will cause row to be copied when called within a + jax conditional primitive such as `jax.lax.cond` or `jax.lax.switch`. + """ + # Update row + row = self.row + row = row.at[self.target_idx].set(update.target) + row = row.at[self.origin_idx].set(update.origin) + + # Return updated carry + return self._replace( + row=row, + reward=self.reward + update.additional_reward, + target_idx=update.target_idx, + origin_idx=update.origin_idx, + ) + + +def no_op(carry: MoveCarry) -> MoveUpdate: + """Return a move update equivalent to performing a no op.""" + target_idx = carry.target_idx + (carry.origin != 0) + origin_idx = jax.lax.select( + (carry.origin == 0) | (target_idx == carry.origin_idx), + carry.origin_idx + 1, + carry.origin_idx, ) - return col, reward - - - -def move_up_col( - carry: Tuple[Board, float], c: int, final_shift: bool = True -) -> Tuple[Tuple[Board, float], None]: - """Move the elements in the specified column up and merge those that are equal in - a single pass. `final_shift` is not needed when computing the action mask - this is - because creating the action mask only requires knowledge of whether the board will - have changed as a result of the action. - - For example: [2, 2, 1, 1] -> [3, 2, 0, 0]. - - Args: - carry: tuple containing the board and the additional reward. - c: column index to perform the move and merge on. - final_shift: is a flag to determine if the column should be shifted up once or - twice. In the "get_action_mask" method, it is set to False, as the purpose is - to check if the action is allowed and one shift is enough for this determination. - - Returns: - Tuple containing the updated board and the additional reward.
- """ - board, additional_reward = carry - col = board[:, c] - col = shift_up(col) # In example: [2, 2, 1, 1] -> [2, 2, 1, 1] - col, reward = merge_col(col) # In example: [2, 2, 1, 1] -> [3, 0, 2, 0] - if final_shift: - col = shift_up(col) # In example: [3, 0, 2, 0] -> [3, 2, 0, 0] - additional_reward += reward - return (board.at[:, c].set(col), additional_reward), None - - -def move_up(board: Board, final_shift: bool = True) -> Tuple[Board, float]: - """Move up.""" - additional_reward = 0.0 - col_indices = jnp.arange(board.shape[0]) # Board of size 4 -> [0, 1, 2, 3] - (board, additional_reward), _ = jax.lax.scan( - f=functools.partial(move_up_col, final_shift=final_shift), - init=(board, additional_reward), - xs=col_indices, + return MoveUpdate( + target=carry.target, + origin=carry.origin, + additional_reward=0.0, + target_idx=target_idx, + origin_idx=origin_idx, ) - return board, additional_reward -def move_down(board: Board, final_shift: bool = True) -> Tuple[Board, float]: - """Move down.""" - board, additional_reward = move_up( - board=jnp.flip(board, 0), final_shift=final_shift +def shift(carry: MoveCarry) -> MoveUpdate: + """Return a move update equivalent to shifting origin to target.""" + return MoveUpdate( + target=carry.origin, + origin=0, + additional_reward=0.0, + target_idx=carry.target_idx, + origin_idx=carry.origin_idx + 1, ) - return jnp.flip(board, 0), additional_reward -def move_left(board: Board, final_shift: bool = True) -> Tuple[Board, float]: - """Move left.""" - board, additional_reward = move_up( - board=jnp.rot90(board, k=-1), final_shift=final_shift +def merge(carry: MoveCarry) -> MoveUpdate: + """Return a move update equivalent to merging origin with target.""" + return MoveUpdate( + target=carry.target + 1, + origin=0, + additional_reward=2.0 ** (carry.target + 1), + target_idx=carry.target_idx + 1, + origin_idx=carry.origin_idx + 1, ) - return jnp.rot90(board, k=1), additional_reward -def move_right(board: Board, final_shift: bool = 
True) -> Tuple[Board, float]: - """Move right.""" - board, additional_reward = move_up( - board=jnp.rot90(board, k=1), final_shift=final_shift - ) - return jnp.rot90(board, k=-1), additional_reward +def move_left_row_cond(carry: MoveCarry) -> chex.Numeric: + """Terminate loop when origin reaches end of row.""" + return carry.origin_idx < carry.row.shape[0] + + +def move_left_row_body(carry: MoveCarry) -> MoveCarry: + """Move the current tiles and increment the indices.""" + # Determine move type + can_shift = (carry.origin != 0) & (carry.target == 0) + can_merge = (carry.origin != 0) & (carry.target == carry.origin) + move_type = can_shift.astype(int) + 2 * can_merge.astype(int) + + # Get update based on move type + update = jax.lax.switch(move_type, [no_op, shift, merge], carry) + + # Return updated carry + return carry.update(update) + + +def move_left_row(row: chex.Array) -> Tuple[chex.Array, float]: + """Move the row left.""" + carry = MoveCarry(row=row, reward=0.0, target_idx=0, origin_idx=1) + row, reward, *_ = jax.lax.while_loop(move_left_row_cond, move_left_row_body, carry) + return row, reward + + +def move_left(board: Board) -> Tuple[Board, float]: + """Move the board left.""" + board, reward = jax.vmap(move_left_row)(board) + return board, reward.sum() + + +def move(board: Board, action: int) -> Tuple[Board, float]: + """Move the board with action.""" + board = transform_board(board, action) + board, reward = move_left(board) + board = transform_board(board, action) + return board, reward + + +def move_up(board: Board) -> Tuple[Board, float]: + """Move the board up.""" + return move(board, 0) + + +def move_right(board: Board) -> Tuple[Board, float]: + """Move the board right.""" + return move(board, 1) + + +def move_down(board: Board) -> Tuple[Board, float]: + """Move the board down.""" + return move(board, 2) diff --git a/jumanji/environments/logic/game_2048/utils_test.py b/jumanji/environments/logic/game_2048/utils_test.py index 0e27c8dcc..5d38983b9 
100644 --- a/jumanji/environments/logic/game_2048/utils_test.py +++ b/jumanji/environments/logic/game_2048/utils_test.py @@ -17,6 +17,10 @@ from jumanji.environments.logic.game_2048.types import Board from jumanji.environments.logic.game_2048.utils import ( + can_move_down, + can_move_left, + can_move_right, + can_move_up, move_down, move_left, move_right, @@ -72,6 +76,38 @@ def board8x8() -> Board: return board +def test_can_move_down(board: Board, another_board: Board) -> None: + """Test checking if the board can move down.""" + assert can_move_down(board) + assert can_move_down(another_board) + board = jnp.array([[0, 0, 0, 0], [1, 0, 0, 0], [2, 1, 0, 0], [3, 2, 1, 0]]) + assert ~can_move_down(board) + + +def test_can_move_up(board: Board, another_board: Board) -> None: + """Test checking if the board can move up.""" + assert can_move_up(board) + assert can_move_up(another_board) + board = jnp.array([[4, 2, 1, 0], [3, 1, 0, 0], [2, 0, 0, 0], [1, 0, 0, 0]]) + assert ~can_move_up(board) + + +def test_can_move_right(board: Board, another_board: Board) -> None: + """Test checking if the board can move right.""" + assert can_move_right(board) + assert can_move_right(another_board) + board = jnp.array([[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 2], [0, 1, 2, 3]]) + assert ~can_move_right(board) + + +def test_can_move_left(board: Board, another_board: Board) -> None: + """Test checking if the board can move left.""" + assert can_move_left(board) + assert can_move_left(another_board) + board = jnp.array([[1, 2, 3, 4], [1, 2, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]]) + assert ~can_move_left(board) + + def test_move_down(board: Board, another_board: Board) -> None: """Test shifting the board cells down.""" # First example. 
From 666d70954f1d99236f8cee5bd23560841f897f79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Bonnet?= <56230714+clement-bonnet@users.noreply.github.com> Date: Tue, 20 Jun 2023 17:39:41 +0200 Subject: [PATCH 6/6] build: bump version to 0.3.1 (#177) --- jumanji/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jumanji/version.py b/jumanji/version.py index b38e3a9e5..7d26f5fd9 100644 --- a/jumanji/version.py +++ b/jumanji/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.3.0" +__version__ = "0.3.1"