feat(MMST): multi minimum spanning tree environment by ulricharmel · Pull Request #135 · instadeepai/jumanji · GitHub
feat(MMST): multi minimum spanning tree environment #135

Merged: 35 commits, Jun 2, 2023

Commits
efd8cd2 added folder with code for cooperative minimum spanning tree (ulricharmel, May 7, 2023)
1c88cb7 fix typo (ulricharmel, May 7, 2023)
e46df34 add networkx to requirements (ulricharmel, May 7, 2023)
8d327ac modified the observation shapes to be fully single agent (ulricharmel, May 12, 2023)
9eb31bf entire environment with networks and docs (ulricharmel, May 14, 2023)
5beae7d Merge branch 'main' into cmst-trim (ulricharmel, May 14, 2023)
ebf4be4 fixed linters (ulricharmel, May 14, 2023)
f5b3c2a removed certain attributes from viewer (ulricharmel, May 14, 2023)
c390372 fixed comment in viewer test (ulricharmel, May 14, 2023)
2e5ed25 renamed to cmst and address initial PR comments (ulricharmel, May 15, 2023)
9fd9d8c added cmst config (ulricharmel, May 15, 2023)
883c33d changed the environment name from cmst to mmst (ulricharmel, May 18, 2023)
e12401d updated readme files (ulricharmel, May 18, 2023)
1f46623 Merge branch 'main' into feat-cmst (ulricharmel, May 23, 2023)
9c5431e modified reward values and settings for default environment (ulricharmel, May 24, 2023)
edf2325 Merge branch 'main' into feat-cmst (ulricharmel, May 24, 2023)
d169b05 address PR comments and breakup some functions (ulricharmel, May 25, 2023)
0c4fe18 Merge branch 'main' into feat-cmst (ulricharmel, May 29, 2023)
874112a add extras to the environment (ulricharmel, May 29, 2023)
e4c5615 defualt configuration for paper (ulricharmel, May 30, 2023)
51e136f minor refactoring from PR review (ulricharmel, May 31, 2023)
6c81948 fixed comments (ulricharmel, May 31, 2023)
de1055e fixed more comments (ulricharmel, May 31, 2023)
1b18bb2 sum rewards in reward function (ulricharmel, May 31, 2023)
ee8b52c fixed issue with reward summing after move (ulricharmel, May 31, 2023)
4872349 update after computing rewards (ulricharmel, May 31, 2023)
5597b52 Merge branch 'main' into feat-cmst (ulricharmel, Jun 1, 2023)
dd22f7c fixed merge conflicts (ulricharmel, Jun 1, 2023)
9b4abcf Merge branch 'main' into feat-cmst (clement-bonnet, Jun 1, 2023)
ca82077 Merge branch 'main' into feat-cmst (ulricharmel, Jun 1, 2023)
3217cbf modified generator (ulricharmel, Jun 1, 2023)
c5787fa removed truncation (ulricharmel, Jun 1, 2023)
f4a5438 fixed few comment typos (ulricharmel, Jun 1, 2023)
b974b7d fixed linting (ulricharmel, Jun 1, 2023)
ba7d6cc Merge branch 'main' into feat-cmst (clement-bonnet, Jun 2, 2023)
3 changes: 3 additions & 0 deletions README.md
@@ -29,6 +29,8 @@
<img src="docs/env_anim/rubiks_cube.gif" alt="RubiksCube" width="30%" />
<img src="docs/env_anim/graph_coloring.gif" alt="GraphColoring" width="30%" />
<img src="docs/env_anim/game_2048.gif" alt="Game2048" width="30%" />
<img src="docs/env_anim/minesweeper.gif" alt="Minesweeper" width="30%" />
<img src="docs/env_anim/mmst.gif" alt="MMST" width="30%" />
<img src="docs/env_anim/sudoku.gif" alt="Sudoku" width="30%" />
</p>

@@ -90,6 +92,7 @@ problems.
| :robot: RobotWarehouse | Routing | `RobotWarehouse-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/robot_warehouse/) | [doc](https://instadeepai.github.io/jumanji/environments/robot_warehouse/) |
| 🐍 Snake | Routing | `Snake-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/snake/) | [doc](https://instadeepai.github.io/jumanji/environments/snake/) |
| 📬 TSP (Travelling Salesman Problem) | Routing | `TSP-v1` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/tsp/) | [doc](https://instadeepai.github.io/jumanji/environments/tsp/) |
| Multi Minimum Spanning Tree Problem | Routing | `MMST-v0` | [code](https://github.com/instadeepai/jumanji/tree/main/jumanji/environments/routing/mmst) | [doc](https://instadeepai.github.io/jumanji/environments/mmst/) |

## Installation 🎬

9 changes: 9 additions & 0 deletions docs/api/environments/mmst.md
@@ -0,0 +1,9 @@
::: jumanji.environments.routing.mmst.env.MMST
    selection:
        members:
            - __init__
            - observation_spec
            - action_spec
            - reset
            - step
            - render
Binary file added docs/env_anim/mmst.gif
Binary file added docs/env_img/mmst.png
76 changes: 76 additions & 0 deletions docs/environments/mmst.md
@@ -0,0 +1,76 @@
# MMST Environment

<p align="center">
<img src="../env_anim/mmst.gif" width="600"/>
</p>

The multi minimum spanning tree (MMST) environment consists of a random connected graph
with groups of nodes (nodes of the same type) that need to be connected.
The goal is to connect all nodes of the same type together in the shortest time possible,
without any two groups using the same utility node (a node that does not belong to any group).

An episode ends when all groups of nodes are connected or the maximum number of steps is reached.

> Note:
>
> This environment can be treated as a multi-agent problem, with each agent attempting to connect
> one group of nodes. In this implementation, we treat the problem as a single agent that outputs
> multiple actions per step, one for each group of nodes.


## Observation
At each step, the observation contains 4 items: the node types, an adjacency matrix for the graph,
an action mask for each group of nodes (agent), and the current node position of each agent.

- `node_types`: Array representing the types of nodes in the problem.
    For example, if we have 12 nodes, their indices are 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11.
    Suppose we have 2 agents. Agent 0 wants to connect nodes (0, 1, 9),
    and agent 1 wants to connect nodes (3, 5, 8).
    The remaining nodes are considered utility nodes.
    Therefore, in the state view, the node_types are
    represented as [0, 0, -1, 1, -1, 1, 0, -1, 1, 0, -1, -1].
    When generating the problem, each agent starts from one of its nodes.
    So, if agent 0 starts on node 1 and agent 1 on node 3,
    the connected_nodes arrays will have values [1, -1, ...] and [3, -1, ...] respectively.
    The agent's observation is represented using the following rules:
    - Each agent sees its own connected nodes on the path as 0.
    - Nodes that the agent still needs to connect are represented as 1.
    - The next agent's nodes are represented by 2 (connected) and 3 (still to connect), the next by 4 and 5, and so on.
    - Unconnected utility nodes are represented by -1.
    For the 12-node example above,
    the expected observation view of node_types has the following values:
    node_types = jnp.array(
        [
            [1, 0, -1, 2, -1, 3, 1, -1, 3, 1, -1, -1],
            [3, 2, -1, 0, -1, 1, 3, -1, 1, 3, -1, -1],
        ],
        dtype=jnp.int32,
    )
    Note: to make the environment single-agent, we use the first agent's observation.

- `adj_matrix`: Adjacency matrix representing the connections between nodes.

- `positions`: Current node positions of the agents.
    In the example above, this is represented as jnp.array([1, 3]).

- `action_mask`: Binary mask indicating the validity of each action.
    Given the node on which each agent is currently located,
    the mask indicates, for every other node, whether there is a valid edge to it.
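
The relabelling rules for `node_types` can be sketched in plain Python. This is only an illustrative reading of the rules, not the environment's code: `agent_view` and its arguments are hypothetical names, the cyclic ordering of the other agents is an assumption inferred from the two-agent example, and utility nodes are always shown as -1 because the rules only cover unconnected ones.

```python
from typing import Dict, List, Set

UTILITY_NODE = -1  # same value as UTILITY_NODE in this PR's constants.py


def agent_view(
    node_types: List[int],
    connected: Dict[int, Set[int]],
    viewer: int,
    num_agents: int,
) -> List[int]:
    """Relabel the global node types from the point of view of `viewer`.

    Hypothetical helper: `connected[a]` is the set of nodes agent `a`
    has already connected.
    """
    view = []
    for node, owner in enumerate(node_types):
        if owner == UTILITY_NODE:
            view.append(UTILITY_NODE)
            continue
        # Assumption: other agents are ordered cyclically after the viewer.
        relative = (owner - viewer) % num_agents
        is_connected = node in connected[owner]
        # Even value: connected node, odd value: node still to connect.
        view.append(2 * relative + (0 if is_connected else 1))
    return view


node_types = [0, 0, -1, 1, -1, 1, 0, -1, 1, 0, -1, -1]
connected = {0: {1}, 1: {3}}  # agent 0 starts on node 1, agent 1 on node 3
views = [agent_view(node_types, connected, agent, 2) for agent in range(2)]
```

With the 12-node example, `views` matches the two rows of the `node_types` array shown in the documentation.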


## Action
The action space is a `MultiDiscreteArray` of shape `(num_agents,)` with integer values in the range
`[0, num_nodes-1]`. At every step, each agent picks the next node it wants to move to.
An action is invalid if the agent picks a node it has no edge to, or if the node is a utility node that has already
been used by another agent.
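
The two invalidity conditions can be sketched as a small check. This is a minimal sketch under stated assumptions rather than the environment's implementation: `is_valid_action` and the `used_by` mapping (utility node to the agent that traversed it) are hypothetical, and only the two conditions described here are encoded.

```python
from typing import Dict, List

UTILITY_NODE = -1  # same value as UTILITY_NODE in this PR's constants.py


def is_valid_action(
    adj_matrix: List[List[int]],
    node_types: List[int],
    used_by: Dict[int, int],
    agent: int,
    position: int,
    target: int,
) -> bool:
    """Return True if `agent`, standing on `position`, may move to `target`."""
    # Condition 1: there must be an edge from the current node to the target.
    if not adj_matrix[position][target]:
        return False
    # Condition 2: a utility node already used by another agent is off limits.
    if node_types[target] == UTILITY_NODE and used_by.get(target, agent) != agent:
        return False
    return True


# Tiny path graph 0 - 1 - 2, where node 1 is a utility node.
adj_matrix = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
node_types = [0, UTILITY_NODE, 1]
free_move = is_valid_action(adj_matrix, node_types, {}, 0, 0, 1)
blocked_move = is_valid_action(adj_matrix, node_types, {1: 1}, 0, 0, 1)
```

Here `free_move` is `True` because node 1 is unused, while `blocked_move` is `False` because agent 1 has already traversed it.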


## Reward
An agent receives a reward of 10.0 at every step in which it makes a valid connection, a reward of -1.0 if it does not
connect, and an extra penalty of -1.0 if it chooses an invalid action.

The total step reward is the sum of rewards per agent.
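
As a worked example of this arithmetic, the sketch below computes the step reward for three agents. The function names are hypothetical, and it assumes that the invalid-action penalty is added on top of the -1.0 for not connecting (so an invalid action yields -2.0); the reward function in the PR may differ in detail.

```python
from typing import List, Tuple

REWARD_CONNECTED = 10.0
REWARD_NOT_CONNECTED = -1.0
PENALTY_INVALID = -1.0


def agent_reward(connected: bool, invalid_action: bool) -> float:
    """Reward for one agent at one step, following the description above."""
    reward = REWARD_CONNECTED if connected else REWARD_NOT_CONNECTED
    if invalid_action:
        reward += PENALTY_INVALID  # assumption: the penalty stacks
    return reward


def step_reward(outcomes: List[Tuple[bool, bool]]) -> float:
    """Total step reward: the sum of the per-agent rewards."""
    return sum(agent_reward(connected, invalid) for connected, invalid in outcomes)


# Agent 0 connects, agent 1 moves without connecting, agent 2 acts invalidly.
total = step_reward([(True, False), (False, False), (False, True)])
```

In this example the total is 10.0 - 1.0 - 2.0 = 7.0.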


## Registered Versions 📖
- `MMST-v0`: 3 agents, 36 nodes, 72 edges, 4 nodes to connect per agent, and a step limit of 70.
3 changes: 3 additions & 0 deletions jumanji/__init__.py
@@ -101,6 +101,9 @@
# Connector with grid size of 10 and 5 agents.
register(id="Connector-v1", entry_point="jumanji.environments:Connector")

# MMST with 3 agents, 36 nodes, 72 edges, 4 nodes to connect per agent, and a time limit of 70.
register(id="MMST-v0", entry_point="jumanji.environments:MMST")

# CVRP with 20 randomly generated nodes, a maximum capacity of 30,
# a maximum demand for each node of 10, and a dense reward function.
register(id="CVRP-v1", entry_point="jumanji.environments:CVRP")
2 changes: 2 additions & 0 deletions jumanji/environments/__init__.py
@@ -29,6 +29,7 @@
    connector,
    cvrp,
    maze,
    mmst,
    robot_warehouse,
    snake,
    tsp,
@@ -37,6 +38,7 @@
from jumanji.environments.routing.connector.env import Connector
from jumanji.environments.routing.cvrp.env import CVRP
from jumanji.environments.routing.maze.env import Maze
from jumanji.environments.routing.mmst.env import MMST
from jumanji.environments.routing.robot_warehouse.env import RobotWarehouse
from jumanji.environments.routing.snake.env import Snake
from jumanji.environments.routing.tsp.env import TSP
16 changes: 16 additions & 0 deletions jumanji/environments/routing/mmst/__init__.py
@@ -0,0 +1,16 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jumanji.environments.routing.mmst.env import MMST
from jumanji.environments.routing.mmst.types import Observation, State
147 changes: 147 additions & 0 deletions jumanji/environments/routing/mmst/conftest.py
@@ -0,0 +1,147 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import jax
import jax.numpy as jnp
import pytest

from jumanji.environments.routing.mmst.env import MMST
from jumanji.environments.routing.mmst.generator import SplitRandomGenerator
from jumanji.environments.routing.mmst.types import State
from jumanji.environments.routing.mmst.utils import (
    build_adjecency_matrix,
    make_action_mask,
    update_active_edges,
)
from jumanji.types import TimeStep, restart


@pytest.fixture(scope="module")
def mmst_split_gn_env() -> MMST:
    """Instantiates a default `MMST` environment."""

    return MMST(
        generator=None,
        reward_fn=None,
    )


@pytest.fixture(scope="module")
def deterministic_mmst_env() -> Tuple[MMST, State, TimeStep]:
    """Instantiates a `MMST` environment."""

    num_nodes_per_agent = 3

    env = MMST(
        generator=SplitRandomGenerator(
            num_nodes=12,
            num_edges=18,
            max_degree=5,
            num_agents=2,
            num_nodes_per_agent=num_nodes_per_agent,
            max_step=12,
        ),
        reward_fn=None,
        time_limit=12,
    )

    state, timestep = env.reset(jax.random.PRNGKey(10))

    key = jax.random.PRNGKey(0)

    num_agents = 2
    num_nodes = 12
    nodes_to_connect = jnp.array([[0, 1, 6], [3, 5, 8]], dtype=jnp.int32)

    edges = jnp.array(
        [
            [0, 1],
            [0, 3],
            [1, 2],
            [1, 4],
            [2, 4],
            [2, 5],
            [3, 4],
            [3, 7],
            [4, 5],
            [4, 8],
            [6, 7],
            [6, 10],
            [7, 8],
            [8, 9],
            [8, 10],
            [9, 10],
            [9, 11],
        ],
        dtype=jnp.int32,
    )

    adj_matrix = build_adjecency_matrix(12, edges)

    node_edges = jnp.ones((12, 12)) * -1
    node_edges = node_edges.at[0, [1, 3]].set(jnp.array([1, 3]))
    node_edges = node_edges.at[1, [0, 2, 4]].set(jnp.array([0, 2, 4]))
    node_edges = node_edges.at[2, [1, 4, 5]].set(jnp.array([1, 4, 5]))
    node_edges = node_edges.at[3, [0, 4, 7]].set(jnp.array([0, 4, 7]))
    node_edges = node_edges.at[4, [1, 2, 3, 5, 8]].set(jnp.array([1, 2, 3, 5, 8]))
    node_edges = node_edges.at[5, [2, 4]].set(jnp.array([2, 4]))
    node_edges = node_edges.at[6, [7, 10]].set(jnp.array([7, 10]))
    node_edges = node_edges.at[7, [3, 6, 8]].set(jnp.array([3, 6, 8]))
    node_edges = node_edges.at[8, [4, 7, 9, 10]].set(jnp.array([4, 7, 9, 10]))
    node_edges = node_edges.at[9, [8, 10, 11]].set(jnp.array([8, 10, 11]))
    node_edges = node_edges.at[10, [6, 8, 9]].set(jnp.array([6, 8, 9]))
    node_edges = node_edges.at[11, [9]].set(jnp.array([9]))

    node_edges = jnp.array(node_edges, dtype=jnp.int32)

    node_types = jnp.array([0, 0, -1, 1, -1, 1, 0, -1, 1, 0, -1, -1], dtype=jnp.int32)

    conn_nodes = -1 * jnp.ones((2, 12), dtype=jnp.int32)
    conn_nodes = conn_nodes.at[0, 0].set(1)
    conn_nodes = conn_nodes.at[1, 0].set(3)

    conn_nodes_index = -1 * jnp.ones((2, 12), dtype=jnp.int32)
    conn_nodes_index = conn_nodes_index.at[0, 1].set(1)
    conn_nodes_index = conn_nodes_index.at[1, 3].set(3)

    positions = jnp.array([1, 3], dtype=jnp.int32)

    active_node_edges = jnp.repeat(node_edges[None, ...], num_agents, axis=0)
    active_node_edges = update_active_edges(
        num_agents, active_node_edges, positions, node_types
    )
    finished_agents = jnp.zeros((num_agents), dtype=bool)

    state = State(
        node_types=node_types,
        adj_matrix=adj_matrix,
        connected_nodes=conn_nodes,
        connected_nodes_index=conn_nodes_index,
        nodes_to_connect=nodes_to_connect,
        position_index=jnp.zeros((num_agents), dtype=jnp.int32),
        positions=positions,
        node_edges=active_node_edges,
        action_mask=make_action_mask(
            num_agents, num_nodes, active_node_edges, positions, finished_agents
        ),
        finished_agents=finished_agents,
        step_count=jnp.array(0, int),
        key=key,
    )

    timestep = restart(observation=env._state_to_observation(state), shape=num_agents)

    return env, state, timestep
27 changes: 27 additions & 0 deletions jumanji/environments/routing/mmst/constants.py
@@ -0,0 +1,27 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Node
INVALID_NODE = -1
UTILITY_NODE = -1
EMPTY_NODE = -1
DUMMY_NODE = -10

# Edges
EMPTY_EDGE = -1

# Actions
INVALID_CHOICE = -1
INVALID_TIE_BREAK = -2
INVALID_ALREADY_TRAVERSED = -3