test: consistent naming in Cleaner, Maze, Game2048, and Minesweeper by dluo96 · Pull Request #112 · instadeepai/jumanji · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

test: consistent naming in Cleaner, Maze, Game2048, and Minesweeper #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 5 commits were merged on Mar 30, 2023.
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions jumanji/environments/logic/game_2048/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def board() -> Board:
return board


def test_env_reset_jit(game_2048: Game2048) -> None:
def test_game_2048__reset_jit(game_2048: Game2048) -> None:
"""Confirm that the reset method is only compiled once when jitted."""
chex.clear_trace_counter()
reset_fn = jax.jit(chex.assert_max_traces(game_2048.reset, n=1))
Expand All @@ -57,7 +57,7 @@ def test_env_reset_jit(game_2048: Game2048) -> None:
assert isinstance(state, State)


def test_env_step_jit(game_2048: Game2048) -> None:
def test_game_2048__step_jit(game_2048: Game2048) -> None:
"""Confirm that the step is only compiled once when jitted."""
key = jax.random.PRNGKey(0)
state, timestep = game_2048.reset(key)
Expand All @@ -81,7 +81,7 @@ def test_env_step_jit(game_2048: Game2048) -> None:
assert not jnp.array_equal(new_state.board, state.board)


def test_generate_board(game_2048: Game2048) -> None:
def test_game_2048__generate_board(game_2048: Game2048) -> None:
"""Confirm that `generate_board` method creates an initial board that
follows the rules of the 2048 game."""
key = jax.random.PRNGKey(0)
Expand All @@ -94,7 +94,7 @@ def test_generate_board(game_2048: Game2048) -> None:
assert new_value_is_one ^ new_value_is_two


def test_add_random_cell(game_2048: Game2048, board: Board) -> None:
def test_game_2048__add_random_cell(game_2048: Game2048, board: Board) -> None:
"""Validate that add_random_cell places a 1 or 2 in an empty spot on the board."""
key = jax.random.PRNGKey(0)
add_random_cell = jax.jit(game_2048._add_random_cell)
Expand All @@ -108,14 +108,14 @@ def test_add_random_cell(game_2048: Game2048, board: Board) -> None:
assert new_value_is_one ^ new_value_is_two


def test_get_action_mask(game_2048: Game2048, board: Board) -> None:
def test_game_2048__get_action_mask(game_2048: Game2048, board: Board) -> None:
"""Verify that the action mask generated by `get_action_mask` is correct."""
action_mask_fn = jax.jit(game_2048._get_action_mask)
action_mask = action_mask_fn(board)
expected_action_mask = jnp.array([False, True, True, True])
assert jnp.array_equal(action_mask, expected_action_mask)


def test_2048__does_not_smoke(game_2048: Game2048) -> None:
def test_game_2048__does_not_smoke(game_2048: Game2048) -> None:
"""Test that we can run an episode without any errors."""
check_env_does_not_smoke(game_2048)
12 changes: 6 additions & 6 deletions jumanji/environments/logic/minesweeper/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_default_reward_and_done_signals(
assert episode_length == len(actions)


def test_minesweeper_env_reset(minesweeper_env: Minesweeper) -> None:
def test_minesweeper__reset(minesweeper_env: Minesweeper) -> None:
"""Validates the jitted reset of the environment."""
reset_fn = jit(minesweeper_env.reset)
key = random.PRNGKey(0)
Expand All @@ -120,7 +120,7 @@ def test_minesweeper_env_reset(minesweeper_env: Minesweeper) -> None:
assert_is_jax_array_tree(state)


def test_minesweeper_env_step(minesweeper_env: Minesweeper) -> None:
def test_minesweeper__step(minesweeper_env: Minesweeper) -> None:
"""Validates the jitted step of the environment."""
chex.clear_trace_counter()
step_fn = chex.assert_max_traces(minesweeper_env.step, n=2)
Expand Down Expand Up @@ -154,12 +154,12 @@ def test_minesweeper_env_step(minesweeper_env: Minesweeper) -> None:
assert jnp.array_equal(next_next_state.board, next_next_timestep.observation.board)


def test_minesweeper_env_does_not_smoke(minesweeper_env: Minesweeper) -> None:
def test_minesweeper__does_not_smoke(minesweeper_env: Minesweeper) -> None:
"""Test that we can run an episode without any errors."""
check_env_does_not_smoke(env=minesweeper_env)


def test_minesweeper_env_render(
def test_minesweeper__render(
monkeypatch: pytest.MonkeyPatch, minesweeper_env: Minesweeper
) -> None:
"""Check that the render method builds the figure but does not display it."""
Expand All @@ -173,7 +173,7 @@ def test_minesweeper_env_render(
minesweeper_env.close()


def test_minesweeper_env_done_invalid_action(minesweeper_env: Minesweeper) -> None:
def test_minesweeper__done_invalid_action(minesweeper_env: Minesweeper) -> None:
"""Test that the strict done signal is sent correctly"""
# Note that this action corresponds to not stepping on a mine
action = minesweeper_env.action_spec().generate_value()
Expand All @@ -183,7 +183,7 @@ def test_minesweeper_env_done_invalid_action(minesweeper_env: Minesweeper) -> No
assert episode_length == 2


def test_minesweeper_env_solved(minesweeper_env: Minesweeper) -> None:
def test_minesweeper__solved(minesweeper_env: Minesweeper) -> None:
"""Solve the game and verify that things are as expected"""
state, timestep = jit(minesweeper_env.reset)(random.PRNGKey(0))
step_fn = jit(minesweeper_env.step)
Expand Down
60 changes: 30 additions & 30 deletions jumanji/environments/routing/cleaner/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,18 @@ def __call__(self, key: chex.PRNGKey) -> Maze:

class TestCleaner:
@pytest.fixture
def env(self) -> Cleaner:
def cleaner(self) -> Cleaner:
generator = DummyGenerator()
return Cleaner(num_agents=N_AGENT, generator=generator)

@pytest.fixture
def key(self) -> chex.PRNGKey:
return jax.random.PRNGKey(0)

def test_env__reset_jit(self, env: Cleaner) -> None:
def test_cleaner__reset_jit(self, cleaner: Cleaner) -> None:
"""Confirm that the reset is only compiled once when jitted."""
chex.clear_trace_counter()
reset_fn = jax.jit(chex.assert_max_traces(env.reset, n=1))
reset_fn = jax.jit(chex.assert_max_traces(cleaner.reset, n=1))
key = jax.random.PRNGKey(0)
state, timestep = reset_fn(key)

Expand All @@ -67,8 +67,8 @@ def test_env__reset_jit(self, env: Cleaner) -> None:
assert isinstance(timestep, TimeStep)
assert isinstance(state, State)

def test_env_cleaner__reset(self, env: Cleaner, key: chex.PRNGKey) -> None:
reset_fn = jax.jit(env.reset)
def test_cleaner__reset(self, cleaner: Cleaner, key: chex.PRNGKey) -> None:
reset_fn = jax.jit(cleaner.reset)
state, timestep = reset_fn(key)

assert isinstance(timestep, TimeStep)
Expand All @@ -80,33 +80,33 @@ def test_env_cleaner__reset(self, env: Cleaner, key: chex.PRNGKey) -> None:

assert_is_jax_array_tree(state)

def test_env__step_jit(self, env: Cleaner) -> None:
def test_cleaner__step_jit(self, cleaner: Cleaner) -> None:
"""Confirm that the step is only compiled once when jitted."""
key = jax.random.PRNGKey(0)
state, timestep = env.reset(key)
state, timestep = cleaner.reset(key)
action = jnp.array([1, 2, 3], jnp.int32)

chex.clear_trace_counter()
step_fn = jax.jit(chex.assert_max_traces(env.step, n=1))
step_fn = jax.jit(chex.assert_max_traces(cleaner.step, n=1))
next_state, next_timestep = step_fn(state, action)

# Call again to check it does not compile twice
next_state, next_timestep = step_fn(state, action)
assert isinstance(next_timestep, TimeStep)
assert isinstance(next_state, State)

def test_env_cleaner__step(self, env: Cleaner, key: chex.PRNGKey) -> None:
initial_state, timestep = env.reset(key)
def test_cleaner__step(self, cleaner: Cleaner, key: chex.PRNGKey) -> None:
initial_state, timestep = cleaner.reset(key)

step_fn = jax.jit(env.step)
step_fn = jax.jit(cleaner.step)

# First action: all agents move right
actions = jnp.array([1] * N_AGENT)
state, timestep = step_fn(initial_state, actions)
# Assert only one tile changed, on the right of the initial pos
assert jnp.sum(state.grid != initial_state.grid) == 1
assert state.grid[0, 1] == CLEAN
assert timestep.reward == 1 - env.penalty_per_timestep
assert timestep.reward == 1 - cleaner.penalty_per_timestep
assert jnp.all(state.agents_locations == jnp.array([0, 1]))

# Second action: agent 0 and 2 move down, agent 1 moves left
Expand All @@ -116,19 +116,19 @@ def test_env_cleaner__step(self, env: Cleaner, key: chex.PRNGKey) -> None:
assert jnp.sum(state.grid != initial_state.grid) == 2
assert state.grid[0, 1] == CLEAN
assert state.grid[1, 1] == CLEAN
assert timestep.reward == 1 - env.penalty_per_timestep
assert timestep.reward == 1 - cleaner.penalty_per_timestep
assert timestep.step_type == StepType.MID

assert jnp.all(state.agents_locations[0] == jnp.array([1, 1]))
assert jnp.all(state.agents_locations[1] == jnp.array([0, 0]))
assert jnp.all(state.agents_locations[2] == jnp.array([1, 1]))

def test_env_cleaner__step_invalid_action(
self, env: Cleaner, key: chex.PRNGKey
def test_cleaner__step_invalid_action(
self, cleaner: Cleaner, key: chex.PRNGKey
) -> None:
state, _ = env.reset(key)
state, _ = cleaner.reset(key)

step_fn = jax.jit(env.step)
step_fn = jax.jit(cleaner.step)
# Invalid action for agent 0, valid for 1 and 2
actions = jnp.array([0, 1, 1])
state, timestep = step_fn(state, actions)
Expand All @@ -139,12 +139,12 @@ def test_env_cleaner__step_invalid_action(
assert jnp.all(state.agents_locations[1] == jnp.array([0, 1]))
assert jnp.all(state.agents_locations[2] == jnp.array([0, 1]))

assert timestep.reward == 1 - env.penalty_per_timestep
assert timestep.reward == 1 - cleaner.penalty_per_timestep

def test_env_cleaner__initial_action_mask(
self, env: Cleaner, key: chex.PRNGKey
def test_cleaner__initial_action_mask(
self, cleaner: Cleaner, key: chex.PRNGKey
) -> None:
state, _ = env.reset(key)
state, _ = cleaner.reset(key)

# All agents can only move right in the initial state
expected_action_mask = jnp.array(
Expand All @@ -153,21 +153,21 @@ def test_env_cleaner__initial_action_mask(

assert jnp.all(state.action_mask == expected_action_mask)

action_mask = env._compute_action_mask(state.grid, state.agents_locations)
action_mask = cleaner._compute_action_mask(state.grid, state.agents_locations)
assert jnp.all(action_mask == expected_action_mask)

def test_env_cleaner__action_mask(self, env: Cleaner, key: chex.PRNGKey) -> None:
state, _ = env.reset(key)
def test_cleaner__action_mask(self, cleaner: Cleaner, key: chex.PRNGKey) -> None:
state, _ = cleaner.reset(key)

# Test action mask for different agent locations
agents_locations = jnp.array([[1, 1], [2, 2], [4, 4]])
action_mask = env._compute_action_mask(state.grid, agents_locations)
action_mask = cleaner._compute_action_mask(state.grid, agents_locations)

assert jnp.all(action_mask[0] == jnp.array([True, False, True, False]))
assert jnp.all(action_mask[1] == jnp.array([False, True, False, True]))
assert jnp.all(action_mask[2] == jnp.array([False, False, False, True]))

def test_env_cleaner__does_not_smoke(self, env: Cleaner) -> None:
def test_cleaner__does_not_smoke(self, cleaner: Cleaner) -> None:
def select_actions(key: chex.PRNGKey, observation: Observation) -> chex.Array:
@jax.vmap # map over the keys and agents
def select_action(
Expand All @@ -180,12 +180,12 @@ def select_action(
subkeys = jax.random.split(key, N_AGENT)
return select_action(subkeys, observation.action_mask)

check_env_does_not_smoke(env, select_actions)
check_env_does_not_smoke(cleaner, select_actions)

def test_env_cleaner__compute_extras(self, env: Cleaner, key: chex.PRNGKey) -> None:
state, _ = env.reset(key)
def test_cleaner__compute_extras(self, cleaner: Cleaner, key: chex.PRNGKey) -> None:
state, _ = cleaner.reset(key)

extras = env._compute_extras(state)
extras = cleaner._compute_extras(state)
assert list(extras.keys()) == ["ratio_dirty_tiles", "num_dirty_tiles"]
assert 0 <= extras["ratio_dirty_tiles"] <= 1
grid = state.grid
Expand Down
Loading
0