feat(tsp): generator by surana01 · Pull Request #137 · instadeepai/jumanji · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

feat(tsp): generator #137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 23, 2023
57 changes: 55 additions & 2 deletions jumanji/environments/routing/tsp/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import chex
import jax
import jax.numpy as jnp
import pytest

from jumanji.environments.routing.tsp.env import TSP
from jumanji.environments.routing.tsp.generator import Generator, UniformGenerator
from jumanji.environments.routing.tsp.reward import DenseReward, SparseReward
from jumanji.environments.routing.tsp.types import State


@pytest.fixture
Expand All @@ -31,10 +36,58 @@ def sparse_reward() -> SparseReward:
@pytest.fixture
def tsp_dense_reward(dense_reward: DenseReward) -> TSP:
"""Instantiates a TSP environment with dense rewards and 5 cities."""
return TSP(num_cities=5, reward_fn=dense_reward)
return TSP(generator=UniformGenerator(num_cities=5), reward_fn=dense_reward)


@pytest.fixture
def tsp_sparse_reward(sparse_reward: SparseReward) -> TSP:
"""Instantiates a TSP environment with sparse rewards and 5 cities."""
return TSP(num_cities=5, reward_fn=sparse_reward)
return TSP(generator=UniformGenerator(num_cities=5), reward_fn=sparse_reward)


class DummyGenerator(Generator):
    """Deterministic `Generator` intended for testing and debugging.

    Every call returns the same hardcoded 5-city instance, whatever key is given.
    """

    def __init__(self) -> None:
        """Registers the hardcoded instance size (5 cities) with the base class."""
        super().__init__(num_cities=5)

    def __call__(self, key: chex.PRNGKey) -> State:
        """Builds the hardcoded travelling salesman instance with no visited city.

        Args:
            key: jax random key, accepted to satisfy the `Generator` interface. It is
                discarded because this generator involves no randomness.

        Returns:
            A TSP `State` holding the fixed 5-city instance.
        """
        # The output is fully hardcoded, so the incoming key is dropped right away.
        del key

        return State(
            # Fixed city coordinates, one (x, y) pair per city.
            coordinates=jnp.array(
                [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.5, 0.5]],
                float,
            ),
            # A position of -1 means the agent is not in any city yet.
            position=jnp.array(-1, jnp.int32),
            # No city has been visited, so the mask is all False and the
            # trajectory is filled with the -1 placeholder.
            visited_mask=jnp.array([False, False, False, False, False]),
            trajectory=jnp.array([-1, -1, -1, -1, -1], jnp.int32),
            # Counter of visited cities starts at 0.
            num_visited=jnp.array(0, jnp.int32),
            # The state always carries the same fixed key.
            key=jax.random.PRNGKey(0),
        )
25 changes: 10 additions & 15 deletions jumanji/environments/routing/tsp/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from jumanji import specs
from jumanji.env import Environment
from jumanji.environments.routing.tsp.generator import Generator, UniformGenerator
from jumanji.environments.routing.tsp.reward import DenseReward, RewardFn
from jumanji.environments.routing.tsp.types import Observation, State
from jumanji.environments.routing.tsp.viewer import TSPViewer
Expand Down Expand Up @@ -90,22 +91,27 @@ class TSP(Environment[State]):

def __init__(
self,
num_cities: int = 20,
generator: Optional[Generator] = None,
reward_fn: Optional[RewardFn] = None,
viewer: Optional[Viewer[State]] = None,
):
"""Instantiates a `TSP` environment.

Args:
num_cities: number of cities to visit. Defaults to 20.
generator: `Generator` whose `__call__` instantiates an environment instance.
The default option is 'UniformGenerator' which randomly generates
TSP instances with 20 cities sampled from a uniform distribution.
reward_fn: RewardFn whose `__call__` method computes the reward of an environment
transition. The function must compute the reward based on the current state,
the chosen action and the next state.
Implemented options are [`DenseReward`, `SparseReward`]. Defaults to `DenseReward`.
viewer: `Viewer` used for rendering. Defaults to `TSPViewer` with "human" render mode.
"""

self.num_cities = num_cities
self.generator = generator or UniformGenerator(
num_cities=20,
)
self.num_cities = self.generator.num_cities
self.reward_fn = reward_fn or DenseReward()
self._viewer = viewer or TSPViewer(name="TSP", render_mode="human")

Expand All @@ -123,18 +129,7 @@ def reset(self, key: PRNGKey) -> Tuple[State, TimeStep[Observation]]:
timestep: TimeStep object corresponding to the first timestep returned
by the environment.
"""
key, sample_key = jax.random.split(key)
coordinates = jax.random.uniform(
sample_key, (self.num_cities, 2), minval=0, maxval=1
)
state = State(
coordinates=coordinates,
position=jnp.array(-1, jnp.int32),
visited_mask=jnp.zeros(self.num_cities, dtype=bool),
trajectory=jnp.full(self.num_cities, -1, jnp.int32),
num_visited=jnp.array(0, jnp.int32),
key=key,
)
state = self.generator(key)
timestep = restart(observation=self._state_to_observation(state))
return state, timestep

Expand Down
85 changes: 85 additions & 0 deletions jumanji/environments/routing/tsp/generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc

import chex
import jax
import jax.numpy as jnp

from jumanji.environments.routing.tsp.types import State


class Generator(abc.ABC):
    """Abstract base class for TSP instance generators.

    A concrete `Generator` is called with a jax random key and returns the problem
    instance (a `State`) used when the environment is reset.
    """

    def __init__(self, num_cities: int) -> None:
        """Stores the `num_cities` attribute shared by all generators.

        Args:
            num_cities: the number of cities in the problem instance.
        """
        self.num_cities = num_cities

    @abc.abstractmethod
    def __call__(self, key: chex.PRNGKey) -> State:
        """Generates a new state; must be implemented by subclasses.

        Args:
            key: jax random key, available in case the generation process is stochastic.

        Returns:
            A `TSP` environment state.
        """


class UniformGenerator(Generator):
    """Generator of random uniform travelling salesman instances.

    For the configured number of cities, city coordinates are drawn from a uniform
    distribution on the unit square.
    """

    def __init__(self, num_cities: int):
        """Instantiates the generator.

        Args:
            num_cities: the number of cities in each generated instance.
        """
        super().__init__(num_cities)

    def __call__(self, key: chex.PRNGKey) -> State:
        """Samples a fresh instance in which no city has been visited.

        Args:
            key: jax random key used to sample the city coordinates.

        Returns:
            A `TSP` environment state with randomly sampled coordinates.
        """
        # The first half of the split is stored in the state; the second half is
        # consumed by the coordinate sampling.
        state_key, coordinates_key = jax.random.split(key)
        n_cities = self.num_cities

        return State(
            # One (x, y) pair per city, sampled uniformly between 0 and 1.
            coordinates=jax.random.uniform(
                coordinates_key, (n_cities, 2), minval=0, maxval=1
            ),
            # A position of -1 means the agent is not in any city yet.
            position=jnp.array(-1, jnp.int32),
            # No city has been visited: empty mask and a trajectory of -1 placeholders.
            visited_mask=jnp.zeros(n_cities, dtype=bool),
            trajectory=jnp.full(n_cities, -1, jnp.int32),
            # Counter of visited cities starts at 0.
            num_visited=jnp.array(0, jnp.int32),
            key=state_key,
        )
69 changes: 69 additions & 0 deletions jumanji/environments/routing/tsp/generator_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import chex
import jax
import pytest

from jumanji.environments.routing.tsp.conftest import DummyGenerator
from jumanji.environments.routing.tsp.generator import UniformGenerator
from jumanji.environments.routing.tsp.types import State
from jumanji.testing.pytrees import assert_trees_are_different, assert_trees_are_equal


class TestDummyGenerator:
    """Tests for the hardcoded `DummyGenerator`."""

    @pytest.fixture
    def dummy_generator(self) -> DummyGenerator:
        """Provides a fresh `DummyGenerator` to each test."""
        return DummyGenerator()

    def test_dummy_generator__properties(self, dummy_generator: DummyGenerator) -> None:
        """Validate that the dummy instance generator reports 5 cities."""
        expected_num_cities = 5
        assert dummy_generator.num_cities == expected_num_cities

    def test_dummy_generator__call(self, dummy_generator: DummyGenerator) -> None:
        """Validate that the dummy instance generator's call function is jit-able,
        compiles only once, and returns the same state for different keys.
        """
        # Reset chex's trace counter so the single-trace budget below starts fresh.
        chex.clear_trace_counter()
        traced_call = chex.assert_max_traces(dummy_generator.__call__, n=1)
        jitted_call = jax.jit(traced_call)

        # The second call reuses the compiled function; a retrace would exceed n=1.
        first_state = jitted_call(jax.random.PRNGKey(1))
        second_state = jitted_call(jax.random.PRNGKey(2))

        assert_trees_are_equal(first_state, second_state)


class TestUniformGenerator:
    """Tests for the random `UniformGenerator`."""

    @pytest.fixture
    def uniform_generator(self) -> UniformGenerator:
        """Provides a `UniformGenerator` configured for 50 cities."""
        return UniformGenerator(num_cities=50)

    def test_uniform_generator__properties(
        self, uniform_generator: UniformGenerator
    ) -> None:
        """Validate that the random instance generator reports 50 cities."""
        expected_num_cities = 50
        assert uniform_generator.num_cities == expected_num_cities

    def test_uniform_generator__call(self, uniform_generator: UniformGenerator) -> None:
        """Validate that the random instance generator's call function is jit-able and
        compiles only once, and that two different keys give two different instances.
        """
        # Reset chex's trace counter so the single-trace budget below starts fresh.
        chex.clear_trace_counter()
        traced_call = chex.assert_max_traces(uniform_generator.__call__, n=1)
        jitted_call = jax.jit(traced_call)

        first_state = jitted_call(key=jax.random.PRNGKey(1))
        assert isinstance(first_state, State)

        # The second call reuses the compiled function; a retrace would exceed n=1.
        second_state = jitted_call(key=jax.random.PRNGKey(2))
        assert_trees_are_different(first_state, second_state)
0