From 395fdecd115bce622715f9cb88033da85f66185d Mon Sep 17 00:00:00 2001 From: Daniel Luo Date: Thu, 30 Mar 2023 10:44:54 +0100 Subject: [PATCH 1/4] test: fix naming in cleaner tests --- .../environments/routing/cleaner/env_test.py | 60 +++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/jumanji/environments/routing/cleaner/env_test.py b/jumanji/environments/routing/cleaner/env_test.py index 1be7176a8..f324f126f 100644 --- a/jumanji/environments/routing/cleaner/env_test.py +++ b/jumanji/environments/routing/cleaner/env_test.py @@ -47,7 +47,7 @@ def __call__(self, key: chex.PRNGKey) -> Maze: class TestCleaner: @pytest.fixture - def env(self) -> Cleaner: + def cleaner(self) -> Cleaner: generator = DummyGenerator() return Cleaner(num_agents=N_AGENT, generator=generator) @@ -55,10 +55,10 @@ def env(self) -> Cleaner: def key(self) -> chex.PRNGKey: return jax.random.PRNGKey(0) - def test_env__reset_jit(self, env: Cleaner) -> None: + def test_cleaner__reset_jit(self, cleaner: Cleaner) -> None: """Confirm that the reset is only compiled once when jitted.""" chex.clear_trace_counter() - reset_fn = jax.jit(chex.assert_max_traces(env.reset, n=1)) + reset_fn = jax.jit(chex.assert_max_traces(cleaner.reset, n=1)) key = jax.random.PRNGKey(0) state, timestep = reset_fn(key) @@ -67,8 +67,8 @@ def test_env__reset_jit(self, env: Cleaner) -> None: assert isinstance(timestep, TimeStep) assert isinstance(state, State) - def test_env_cleaner__reset(self, env: Cleaner, key: chex.PRNGKey) -> None: - reset_fn = jax.jit(env.reset) + def test_cleaner__reset(self, cleaner: Cleaner, key: chex.PRNGKey) -> None: + reset_fn = jax.jit(cleaner.reset) state, timestep = reset_fn(key) assert isinstance(timestep, TimeStep) @@ -80,14 +80,14 @@ def test_env_cleaner__reset(self, env: Cleaner, key: chex.PRNGKey) -> None: assert_is_jax_array_tree(state) - def test_env__step_jit(self, env: Cleaner) -> None: + def test_cleaner__step_jit(self, cleaner: Cleaner) -> None: """Confirm that the step is only compiled once when jitted.""" key = jax.random.PRNGKey(0) - state, timestep = env.reset(key) + state, timestep = cleaner.reset(key) action = jnp.array([1, 2, 3], jnp.int32) chex.clear_trace_counter() - step_fn = jax.jit(chex.assert_max_traces(env.step, n=1)) + step_fn = jax.jit(chex.assert_max_traces(cleaner.step, n=1)) next_state, next_timestep = step_fn(state, action) # Call again to check it does not compile twice @@ -95,10 +95,10 @@ def test_env__step_jit(self, env: Cleaner) -> None: assert isinstance(next_timestep, TimeStep) assert isinstance(next_state, State) - def test_env_cleaner__step(self, env: Cleaner, key: chex.PRNGKey) -> None: - initial_state, timestep = env.reset(key) + def test_cleaner__step(self, cleaner: Cleaner, key: chex.PRNGKey) -> None: + initial_state, timestep = cleaner.reset(key) - step_fn = jax.jit(env.step) + step_fn = jax.jit(cleaner.step) # First action: all agents move right actions = jnp.array([1] * N_AGENT) @@ -106,7 +106,7 @@ def test_env_cleaner__step(self, env: Cleaner, key: chex.PRNGKey) -> None: # Assert only one tile changed, on the right of the initial pos assert jnp.sum(state.grid != initial_state.grid) == 1 assert state.grid[0, 1] == CLEAN - assert timestep.reward == 1 - env.penalty_per_timestep + assert timestep.reward == 1 - cleaner.penalty_per_timestep assert jnp.all(state.agents_locations == jnp.array([0, 1])) # Second action: agent 0 and 2 move down, agent 1 moves left @@ -116,19 +116,19 @@ def test_env_cleaner__step(self, env: Cleaner, key: chex.PRNGKey) -> None: assert jnp.sum(state.grid != initial_state.grid) == 2 assert state.grid[0, 1] == CLEAN assert state.grid[1, 1] == CLEAN - assert timestep.reward == 1 - env.penalty_per_timestep + assert timestep.reward == 1 - cleaner.penalty_per_timestep assert timestep.step_type == StepType.MID assert jnp.all(state.agents_locations[0] == jnp.array([1, 1])) assert jnp.all(state.agents_locations[1] == jnp.array([0, 0])) assert jnp.all(state.agents_locations[2] == jnp.array([1, 1])) - def test_env_cleaner__step_invalid_action( - self, env: Cleaner, key: chex.PRNGKey + def test_cleaner__step_invalid_action( + self, cleaner: Cleaner, key: chex.PRNGKey ) -> None: - state, _ = env.reset(key) + state, _ = cleaner.reset(key) - step_fn = jax.jit(env.step) + step_fn = jax.jit(cleaner.step) # Invalid action for agent 0, valid for 1 and 2 actions = jnp.array([0, 1, 1]) state, timestep = step_fn(state, actions) @@ -139,12 +139,12 @@ def test_env_cleaner__step_invalid_action( assert jnp.all(state.agents_locations[1] == jnp.array([0, 1])) assert jnp.all(state.agents_locations[2] == jnp.array([0, 1])) - assert timestep.reward == 1 - env.penalty_per_timestep + assert timestep.reward == 1 - cleaner.penalty_per_timestep - def test_env_cleaner__initial_action_mask( - self, env: Cleaner, key: chex.PRNGKey + def test_cleaner__initial_action_mask( + self, cleaner: Cleaner, key: chex.PRNGKey ) -> None: - state, _ = env.reset(key) + state, _ = cleaner.reset(key) # All agents can only move right in the initial state expected_action_mask = jnp.array( @@ -153,21 +153,21 @@ def test_env_cleaner__initial_action_mask( assert jnp.all(state.action_mask == expected_action_mask) - action_mask = env._compute_action_mask(state.grid, state.agents_locations) + action_mask = cleaner._compute_action_mask(state.grid, state.agents_locations) assert jnp.all(action_mask == expected_action_mask) - def test_env_cleaner__action_mask(self, env: Cleaner, key: chex.PRNGKey) -> None: - state, _ = env.reset(key) + def test_cleaner__action_mask(self, cleaner: Cleaner, key: chex.PRNGKey) -> None: + state, _ = cleaner.reset(key) # Test action mask for different agent locations agents_locations = jnp.array([[1, 1], [2, 2], [4, 4]]) - action_mask = env._compute_action_mask(state.grid, agents_locations) + action_mask = cleaner._compute_action_mask(state.grid, agents_locations) assert jnp.all(action_mask[0] == jnp.array([True, False, True, False])) assert jnp.all(action_mask[1] == jnp.array([False, True, False, True])) assert jnp.all(action_mask[2] == jnp.array([False, False, False, True])) - def test_env_cleaner__does_not_smoke(self, env: Cleaner) -> None: + def test_cleaner__does_not_smoke(self, cleaner: Cleaner) -> None: def select_actions(key: chex.PRNGKey, observation: Observation) -> chex.Array: @jax.vmap # map over the keys and agents def select_action( @@ -180,12 +180,12 @@ def select_action( subkeys = jax.random.split(key, N_AGENT) return select_action(subkeys, observation.action_mask) - check_env_does_not_smoke(env, select_actions) + check_env_does_not_smoke(cleaner, select_actions) - def test_env_cleaner__compute_extras(self, env: Cleaner, key: chex.PRNGKey) -> None: - state, _ = env.reset(key) + def test_cleaner__compute_extras(self, cleaner: Cleaner, key: chex.PRNGKey) -> None: + state, _ = cleaner.reset(key) - extras = env._compute_extras(state) + extras = cleaner._compute_extras(state) assert list(extras.keys()) == ["ratio_dirty_tiles", "num_dirty_tiles"] assert 0 <= extras["ratio_dirty_tiles"] <= 1 grid = state.grid From 4360f3254f772190926711589d06368a19c9fe87 Mon Sep 17 00:00:00 2001 From: Daniel Luo Date: Thu, 30 Mar 2023 10:46:57 +0100 Subject: [PATCH 2/4] test: fix naming in maze tests --- jumanji/environments/routing/maze/env_test.py | 60 +++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/jumanji/environments/routing/maze/env_test.py b/jumanji/environments/routing/maze/env_test.py index dae73f55a..e84c9b942 100644 --- a/jumanji/environments/routing/maze/env_test.py +++ b/jumanji/environments/routing/maze/env_test.py @@ -27,13 +27,13 @@ class TestMazeEnvironment: @pytest.fixture(scope="module") - def maze_env(self) -> Maze: + def maze(self) -> Maze: """Instantiates a default Maze environment.""" generator = RandomGenerator(num_rows=5, num_cols=5) return Maze(generator=generator, time_limit=15) - def test_env_maze__reset(self, maze_env: Maze) -> None: - reset_fn = jax.jit(maze_env.reset) + def test_maze__reset(self, maze: Maze) -> None: + reset_fn = jax.jit(maze.reset) key = jax.random.PRNGKey(0) state, timestep = reset_fn(key) @@ -49,10 +49,10 @@ def test_env_maze__reset(self, maze_env: Maze) -> None: assert not state.walls[tuple(state.agent_position)] assert not state.walls[tuple(state.target_position)] - def test_env__reset_jit(self, maze_env: Maze) -> None: + def test_env__reset_jit(self, maze: Maze) -> None: """Confirm that the reset is only compiled once when jitted.""" chex.clear_trace_counter() - reset_fn = jax.jit(chex.assert_max_traces(maze_env.reset, n=1)) + reset_fn = jax.jit(chex.assert_max_traces(maze.reset, n=1)) key = jax.random.PRNGKey(0) state, timestep = reset_fn(key) @@ -61,16 +61,16 @@ def test_env__reset_jit(self, maze_env: Maze) -> None: assert isinstance(timestep, TimeStep) assert isinstance(state, State) - def test_env__step_jit(self, maze_env: Maze) -> None: + def test_env__step_jit(self, maze: Maze) -> None: """Confirm that the step is only compiled once when jitted.""" key = jax.random.PRNGKey(0) - state, timestep = maze_env.reset(key) + state, timestep = maze.reset(key) assert isinstance(timestep, TimeStep) assert isinstance(state, State) action = jnp.array(2, jnp.int32) chex.clear_trace_counter() - step_fn = jax.jit(chex.assert_max_traces(maze_env.step, n=1)) + step_fn = jax.jit(chex.assert_max_traces(maze.step, n=1)) next_state, next_timestep = step_fn(state, action) # Call again to check it does not compile twice @@ -78,18 +78,18 @@ def test_env__step_jit(self, maze_env: Maze) -> None: assert isinstance(next_timestep, TimeStep) assert isinstance(next_state, State) - def test_env_maze__random_agent_start(self, maze_env: Maze) -> None: + def test_maze__random_agent_start(self, maze: Maze) -> None: key1, key2 = jax.random.PRNGKey(0), jax.random.PRNGKey(1) - state1, _ = maze_env.reset(key1) - state2, _ = maze_env.reset(key2) + state1, _ = maze.reset(key1) + state2, _ = maze.reset(key2) # Check random positions are different assert state1.agent_position != state2.agent_position assert state1.target_position != state2.target_position - def test_env_maze__step(self, maze_env: Maze) -> None: + def test_maze__step(self, maze: Maze) -> None: key = jax.random.PRNGKey(0) - state, _ = maze_env.reset(key) + state, _ = maze.reset(key) # Fixed agent start state agent_position = Position(row=4, col=0) @@ -97,12 +97,12 @@ def test_env_maze__step(self, maze_env: Maze) -> None: agent_position=agent_position, target_position=state.target_position, walls=state.walls, - action_mask=maze_env._compute_action_mask(state.walls, agent_position), + action_mask=maze._compute_action_mask(state.walls, agent_position), key=state.key, step_count=jnp.array(0, jnp.int32), ) - step_fn = jax.jit(maze_env.step) + step_fn = jax.jit(maze.step) # Agent takes a step right action = jnp.array(1, jnp.int32) @@ -144,16 +144,16 @@ def test_env_maze__step(self, maze_env: Maze) -> None: assert timestep.step_type == StepType.MID assert state.agent_position == Position(row=2, col=2) - def test_env_maze__action_mask(self, maze_env: Maze) -> None: + def test_maze__action_mask(self, maze: Maze) -> None: key = jax.random.PRNGKey(0) - state, _ = maze_env.reset(key) + state, _ = maze.reset(key) # Fixed agent start state agent_position = Position(row=4, col=0) # The agent can only move up or right in the initial state expected_action_mask = jnp.array([True, True, False, False]) - action_mask = maze_env._compute_action_mask(state.walls, agent_position) + action_mask = maze._compute_action_mask(state.walls, agent_position) assert jnp.all(action_mask == expected_action_mask) # Check another position @@ -161,12 +161,12 @@ def test_env_maze__action_mask(self, maze_env: Maze) -> None: # The agent can move up, right or down expected_action_mask = jnp.array([True, True, True, False]) - action_mask = maze_env._compute_action_mask(state.walls, another_position) + action_mask = maze._compute_action_mask(state.walls, another_position) assert jnp.all(action_mask == expected_action_mask) - def test_env_maze__reward(self, maze_env: Maze) -> None: + def test_maze__reward(self, maze: Maze) -> None: key = jax.random.PRNGKey(0) - state, timestep = maze_env.reset(key) + state, timestep = maze.reset(key) # Fixed agent and target positions agent_position = Position(row=4, col=0) @@ -176,7 +176,7 @@ def test_env_maze__reward(self, maze_env: Maze) -> None: agent_position=agent_position, target_position=target_position, walls=state.walls, - action_mask=maze_env._compute_action_mask(state.walls, agent_position), + action_mask=maze._compute_action_mask(state.walls, agent_position), key=state.key, step_count=jnp.array(0, jnp.int32), ) @@ -186,19 +186,19 @@ def test_env_maze__reward(self, maze_env: Maze) -> None: for a in actions: assert timestep.reward == 0 assert timestep.step_type < StepType.LAST - state, timestep = maze_env.step(state, a) + state, timestep = maze.step(state, a) # Final step into the target assert timestep.reward == 1 assert timestep.last() assert state.agent_position == state.target_position - def test_env_maze__toy_generator(self) -> None: + def test_maze__toy_generator(self) -> None: key = jax.random.PRNGKey(0) toy_generator = ToyGenerator() - maze_env = Maze(generator=toy_generator, time_limit=25) - state, timestep = maze_env.reset(key) + maze = Maze(generator=toy_generator, time_limit=25) + state, timestep = maze.reset(key) # Fixed agent and target positions agent_position = Position(row=4, col=0) @@ -208,7 +208,7 @@ def test_env_maze__toy_generator(self) -> None: agent_position=agent_position, target_position=target_position, walls=state.walls, - action_mask=maze_env._compute_action_mask(state.walls, agent_position), + action_mask=maze._compute_action_mask(state.walls, agent_position), key=state.key, step_count=jnp.array(0, jnp.int32), ) @@ -218,12 +218,12 @@ def test_env_maze__toy_generator(self) -> None: for a in actions: assert timestep.reward == 0 assert timestep.step_type < StepType.LAST - state, timestep = maze_env.step(state, a) + state, timestep = maze.step(state, a) # Final step into the target assert timestep.reward == 1 assert timestep.last() assert state.agent_position == state.target_position - def test_env_maze__does_not_smoke(self, maze_env: Maze) -> None: - check_env_does_not_smoke(maze_env) + def test_maze__does_not_smoke(self, maze: Maze) -> None: + check_env_does_not_smoke(maze) From 8b364364ffabb8b98a4d27c172bd7a988e8c0fc5 Mon Sep 17 00:00:00 2001 From: Daniel Luo Date: Thu, 30 Mar 2023 10:50:11 +0100 Subject: [PATCH 3/4] test: fix naming in game 2048 tests --- jumanji/environments/logic/game_2048/env_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/jumanji/environments/logic/game_2048/env_test.py b/jumanji/environments/logic/game_2048/env_test.py index 559766491..59d1a1eab 100644 --- a/jumanji/environments/logic/game_2048/env_test.py +++ b/jumanji/environments/logic/game_2048/env_test.py @@ -37,7 +37,7 @@ def board() -> Board: return board -def test_env_reset_jit(game_2048: Game2048) -> None: +def test_game_2048__reset_jit(game_2048: Game2048) -> None: """Confirm that the reset method is only compiled once when jitted.""" chex.clear_trace_counter() reset_fn = jax.jit(chex.assert_max_traces(game_2048.reset, n=1)) @@ -57,7 +57,7 @@ def test_env_reset_jit(game_2048: Game2048) -> None: assert isinstance(state, State) -def test_env_step_jit(game_2048: Game2048) -> None: +def test_game_2048__step_jit(game_2048: Game2048) -> None: """Confirm that the step is only compiled once when jitted.""" key = jax.random.PRNGKey(0) state, timestep = game_2048.reset(key) @@ -81,7 +81,7 @@ def test_env_step_jit(game_2048: Game2048) -> None: assert not jnp.array_equal(new_state.board, state.board) -def test_generate_board(game_2048: Game2048) -> None: +def test_game_2048__generate_board(game_2048: Game2048) -> None: """Confirm that `generate_board` method creates an initial board that follows the rules of the 2048 game.""" key = jax.random.PRNGKey(0) @@ -94,7 +94,7 @@ def test_generate_board(game_2048: Game2048) -> None: assert new_value_is_one ^ new_value_is_two -def test_add_random_cell(game_2048: Game2048, board: Board) -> None: +def test_game_2048__add_random_cell(game_2048: Game2048, board: Board) -> None: """Validate that add_random_cell places a 1 or 2 in an empty spot on the board.""" key = jax.random.PRNGKey(0) add_random_cell = jax.jit(game_2048._add_random_cell) @@ -108,7 +108,7 @@ def test_add_random_cell(game_2048: Game2048, board: Board) -> None: assert new_value_is_one ^ new_value_is_two -def test_get_action_mask(game_2048: Game2048, board: Board) -> None: +def test_game_2048__get_action_mask(game_2048: Game2048, board: Board) -> None: """Verify that the action mask generated by `get_action_mask` is correct.""" action_mask_fn = jax.jit(game_2048._get_action_mask) action_mask = action_mask_fn(board) @@ -116,6 +116,6 @@ def test_get_action_mask(game_2048: Game2048, board: Board) -> None: assert jnp.array_equal(action_mask, expected_action_mask) -def test_2048__does_not_smoke(game_2048: Game2048) -> None: +def test_game_2048__does_not_smoke(game_2048: Game2048) -> None: """Test that we can run an episode without any errors.""" check_env_does_not_smoke(game_2048) From e375c9399f2f866c3113e64b7e51024c3caf069b Mon Sep 17 00:00:00 2001 From: Daniel Luo Date: Thu, 30 Mar 2023 11:05:22 +0100 Subject: [PATCH 4/4] test: consistent naming for minesweeper --- jumanji/environments/logic/minesweeper/env_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/jumanji/environments/logic/minesweeper/env_test.py b/jumanji/environments/logic/minesweeper/env_test.py index 129d81583..cbe8220ee 100644 --- a/jumanji/environments/logic/minesweeper/env_test.py +++ b/jumanji/environments/logic/minesweeper/env_test.py @@ -99,7 +99,7 @@ def test_default_reward_and_done_signals( assert episode_length == len(actions) -def test_minesweeper_env_reset(minesweeper_env: Minesweeper) -> None: +def test_minesweeper__reset(minesweeper_env: Minesweeper) -> None: """Validates the jitted reset of the environment.""" reset_fn = jit(minesweeper_env.reset) key = random.PRNGKey(0) @@ -120,7 +120,7 @@ def test_minesweeper_env_reset(minesweeper_env: Minesweeper) -> None: assert_is_jax_array_tree(state) -def test_minesweeper_env_step(minesweeper_env: Minesweeper) -> None: +def test_minesweeper__step(minesweeper_env: Minesweeper) -> None: """Validates the jitted step of the environment.""" chex.clear_trace_counter() step_fn = chex.assert_max_traces(minesweeper_env.step, n=2) @@ -154,12 +154,12 @@ def test_minesweeper_env_step(minesweeper_env: Minesweeper) -> None: assert jnp.array_equal(next_next_state.board, next_next_timestep.observation.board) -def test_minesweeper_env_does_not_smoke(minesweeper_env: Minesweeper) -> None: +def test_minesweeper__does_not_smoke(minesweeper_env: Minesweeper) -> None: """Test that we can run an episode without any errors.""" check_env_does_not_smoke(env=minesweeper_env) -def test_minesweeper_env_render( +def test_minesweeper__render( monkeypatch: pytest.MonkeyPatch, minesweeper_env: Minesweeper ) -> None: """Check that the render method builds the figure but does not display it.""" @@ -173,7 +173,7 @@ def test_minesweeper_env_render( minesweeper_env.close() -def test_minesweeper_env_done_invalid_action(minesweeper_env: Minesweeper) -> None: +def test_minesweeper__done_invalid_action(minesweeper_env: Minesweeper) -> None: """Test that the strict done signal is sent correctly""" # Note that this action corresponds to not stepping on a mine action = minesweeper_env.action_spec().generate_value() @@ -183,7 +183,7 @@ def test_minesweeper_env_done_invalid_action(minesweeper_env: Minesweeper) -> No assert episode_length == 2 -def test_minesweeper_env_solved(minesweeper_env: Minesweeper) -> None: +def test_minesweeper__solved(minesweeper_env: Minesweeper) -> None: """Solve the game and verify that things are as expected""" state, timestep = jit(minesweeper_env.reset)(random.PRNGKey(0)) step_fn = jit(minesweeper_env.step)