Add support for logging to W&B by LukasSchaefer · Pull Request #71 · uoe-agents/epymarl

Merged · 4 commits · Jun 27, 2024
README.md (20 additions, 0 deletions)

@@ -2,6 +2,7 @@

EPyMARL is an extension of [PyMARL](https://github.com/oxwhirl/pymarl), and includes
- **New!** Support for training in environments with individual rewards for all agents (for all algorithms that support such settings)
- **New!** Support for logging to [Weights and Biases (W&B)](https://wandb.ai/)
- Additional algorithms (IA2C, IPPO, MADDPG, MAA2C and MAPPO)
- Support for [Gym](https://github.com/openai/gym) environments (on top of the existing SMAC support)
- Option for no-parameter sharing between agents (original PyMARL only allowed for parameter sharing)
@@ -27,6 +28,25 @@ When using the `common_reward=True` setup in environments which naturally provide
### Plotting script
We have added a simple plotting script under `plot_results.py` to load data from sacred logs and visualise them for executed experiments. The script supports plotting of any logged metric, can apply simple window-smoothing, aggregates results across multiple runs of the same algorithm, and can filter which results to plot based on algorithm and environment names.

If multiple configs of the same algorithm exist within the loaded data and you only want to plot the best config per algorithm, add the `--best_per_alg` argument! If this argument is not set, the script will visualise all configs of each (filtered) algorithm and show the hyperparameter values that differ across the present configs in the legend.
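
For example, a minimal invocation might look like this (only `--best_per_alg` is taken from the description above; check `python plot_results.py --help` for the available filtering and smoothing flags):

```sh
# aggregate results from the sacred logs and keep only the best config per algorithm
python plot_results.py --best_per_alg
```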

### Weights and Biases (W&B) Logging
We now support logging to W&B! To log data to W&B, you need to install the library with `pip install wandb` and set up W&B (see their [documentation](https://docs.wandb.ai/quickstart)). To tell EPyMARL to log data to W&B, you then need to specify the following config parameters:
```yaml
use_wandb: True # Log results to W&B
wandb_team: null # W&B team name
wandb_project: null # W&B project name
```
to specify the team and project you wish to log to within your account, and set `use_wandb=True`. By default, we log all W&B runs in "offline" mode, i.e. the data will only be stored locally and can be uploaded to your W&B account via `wandb sync ...`. To directly log runs online, please specify `wandb_mode="online"` within the config.
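
For example, a finished offline run can be uploaded later with the standard `wandb sync` CLI; the run directory name below is purely illustrative (W&B typically stores offline runs under `wandb/offline-run-<timestamp>-<id>`):

```sh
# upload a locally stored offline run to your W&B account
wandb sync wandb/offline-run-20240627_120000-abc123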

We also support logging all stored models directly to W&B so you can download and inspect these from the W&B online dashboard. To do so, use the following config parameters:
```yaml
wandb_save_model: True # Save models to W&B (only done if use_wandb is True and save_model is True)
save_model: True # Save the models to disk
save_model_interval: 50000
```
Note that models are only saved at all if `save_model=True`, and to additionally log them to W&B you need to specify `use_wandb`, `wandb_team`, `wandb_project`, and `wandb_save_model=True`.
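
Putting it together, a sketch of a full run with online W&B logging and model uploading could look as follows (this assumes the usual `src/main.py` entry point with sacred-style `with` overrides; the team and project names are placeholders, and the environment configuration is omitted):

```sh
# train MAPPO and stream metrics and saved models to W&B
python src/main.py --config=mappo --env-config=gymma with \
    use_wandb=True wandb_team=my_team wandb_project=my_project \
    wandb_mode=online save_model=True wandb_save_model=True
```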

## Update as of *15th July 2023*!
We have released our _Pareto Actor-Critic_ algorithm, accepted in TMLR, as part of the EPyMARL source code.

src/config/default.yaml (5 additions, 0 deletions)

@@ -20,6 +20,11 @@ buffer_cpu_only: True # If true we won't keep all of the replay buffer in vram

# --- Logging options ---
use_tensorboard: False # Log results to tensorboard
use_wandb: False # Log results to W&B
wandb_team: null # W&B team name
wandb_project: null # W&B project name
wandb_mode: "offline" # W&B mode (online/offline)
wandb_save_model: False # Save models to W&B (only done if use_wandb is True and save_model is True)
save_model: False # Save the models to disk
save_model_interval: 50000 # Save models after this many timesteps
checkpoint_path: "" # Load a checkpoint from this path
src/run.py (21 additions, 4 deletions)

@@ -1,17 +1,19 @@
 import datetime
 import os
+from os.path import dirname, abspath
 import pprint
+import shutil
 import time
 import threading
-import torch as th
 from types import SimpleNamespace as SN
-from os.path import dirname, abspath
 
-from learners import REGISTRY as le_REGISTRY
-from runners import REGISTRY as r_REGISTRY
+import torch as th
+
 from controllers import REGISTRY as mac_REGISTRY
 from components.episode_buffer import ReplayBuffer
 from components.transforms import OneHot
+from learners import REGISTRY as le_REGISTRY
+from runners import REGISTRY as r_REGISTRY
 from utils.general_reward_support import test_alg_config_supports_reward
 from utils.logging import Logger
 from utils.timehelper import time_left, time_str
@@ -53,6 +55,11 @@ def run(_run, _config, _log):
        tb_exp_direc = os.path.join(tb_logs_direc, "{}").format(unique_token)
        logger.setup_tb(tb_exp_direc)

    if args.use_wandb:
        logger.setup_wandb(
            _config, args.wandb_team, args.wandb_project, args.wandb_mode
        )

    # sacred is on by default
    logger.setup_sacred(_run)

@@ -236,6 +243,16 @@ def run_sequential(args, logger):
            # use appropriate filenames to do critics, optimizer states
            learner.save_models(save_path)

            if args.use_wandb and args.wandb_save_model:
                wandb_save_dir = os.path.join(
                    logger.wandb.dir, "models", args.unique_token, str(runner.t_env)
                )
                os.makedirs(wandb_save_dir, exist_ok=True)
                for f in os.listdir(save_path):
                    shutil.copyfile(
                        os.path.join(save_path, f), os.path.join(wandb_save_dir, f)
                    )

        episode += args.batch_size_run

        if (runner.t_env - last_log_T) >= args.log_interval:
src/runners/episode_runner.py (4 additions, 2 deletions)

@@ -1,8 +1,10 @@
-from envs import REGISTRY as env_REGISTRY
 from functools import partial
-from components.episode_buffer import EpisodeBatch
+
 import numpy as np
+
+from components.episode_buffer import EpisodeBatch
+from envs import REGISTRY as env_REGISTRY


class EpisodeRunner:
    def __init__(self, args, logger):
src/runners/parallel_runner.py (4 additions, 2 deletions)

@@ -1,9 +1,11 @@
-from envs import REGISTRY as env_REGISTRY
 from functools import partial
-from components.episode_buffer import EpisodeBatch
 from multiprocessing import Pipe, Process
+
 import numpy as np
+
+from components.episode_buffer import EpisodeBatch
+from envs import REGISTRY as env_REGISTRY


# Based (very) heavily on SubprocVecEnv from OpenAI Baselines
# https://github.com/openai/baselines/blob/master/baselines/common/vec_env/subproc_vec_env.py
src/utils/logging.py (79 additions, 6 deletions)

@@ -1,12 +1,17 @@
from collections import defaultdict
from hashlib import sha256
import json
import logging

import numpy as np


class Logger:
    def __init__(self, console_logger):
        self.console_logger = console_logger

        self.use_tb = False
        self.use_wandb = False
        self.use_sacred = False
        self.use_hdf = False

@@ -15,10 +20,63 @@ def __init__(self, console_logger):
    def setup_tb(self, directory_name):
        # Import here so it doesn't have to be installed if you don't use it
        from tensorboard_logger import configure, log_value

        configure(directory_name)
        self.tb_logger = log_value
        self.use_tb = True

        self.console_logger.info("*******************")
        self.console_logger.info("Tensorboard logging dir:")
        self.console_logger.info(f"{directory_name}")
        self.console_logger.info("*******************")

    def setup_wandb(self, config, team_name, project_name, mode):
        import wandb

        assert (
            team_name is not None and project_name is not None
        ), "W&B logging requires specification of both `wandb_team` and `wandb_project`."
        assert (
            mode in ["offline", "online"]
        ), f"Invalid value for `wandb_mode`. Received {mode} but only 'online' and 'offline' are supported."

        self.use_wandb = True

        alg_name = config["name"]
        env_name = config["env"]
        if "map_name" in config["env_args"]:
            env_name += "_" + config["env_args"]["map_name"]
        elif "key" in config["env_args"]:
            env_name += "_" + config["env_args"]["key"]

        non_hash_keys = ["seed"]
        self.config_hash = sha256(
            json.dumps(
                {k: v for k, v in config.items() if k not in non_hash_keys},
                sort_keys=True,
            ).encode("utf8")
        ).hexdigest()[-10:]

        group_name = "_".join([alg_name, env_name, self.config_hash])

        self.wandb = wandb.init(
            entity=team_name,
            project=project_name,
            config=config,
            group=group_name,
            mode=mode,
        )

        self.console_logger.info("*******************")
        self.console_logger.info("WANDB RUN ID:")
        self.console_logger.info(f"{self.wandb.id}")
        self.console_logger.info("*******************")

        # accumulate data at same timestep and only log in one batch once
        # all data has been gathered
        self.wandb_current_t = -1
        self.wandb_current_data = {}

    def setup_sacred(self, sacred_run_dict):
        self._run_obj = sacred_run_dict
        self.sacred_info = sacred_run_dict.info
@@ -30,6 +88,16 @@ def log_stat(self, key, value, t, to_sacred=True):
        if self.use_tb:
            self.tb_logger(key, value, t)

        if self.use_wandb:
            if self.wandb_current_t != t and self.wandb_current_data:
                # self.console_logger.info(
                #     f"Logging to WANDB: {self.wandb_current_data} at t={self.wandb_current_t}"
                # )
                self.wandb.log(self.wandb_current_data, step=self.wandb_current_t)
                self.wandb_current_data = {}
            self.wandb_current_t = t
            self.wandb_current_data[key] = value

        if self.use_sacred and to_sacred:
            if key in self.sacred_info:
                self.sacred_info["{}_T".format(key)].append(t)
@@ -41,17 +109,21 @@ def log_stat(self, key, value, t, to_sacred=True):
             self._run_obj.log_scalar(key, value, t)
 
     def print_recent_stats(self):
-        log_str = "Recent Stats | t_env: {:>10} | Episode: {:>8}\n".format(*self.stats["episode"][-1])
+        log_str = "Recent Stats | t_env: {:>10} | Episode: {:>8}\n".format(
+            *self.stats["episode"][-1]
+        )
         i = 0
-        for (k, v) in sorted(self.stats.items()):
+        for k, v in sorted(self.stats.items()):
             if k == "episode":
                 continue
             i += 1
             window = 5 if k != "epsilon" else 1
             try:
                 item = "{:.4f}".format(np.mean([x[1] for x in self.stats[k][-window:]]))
             except:
-                item = "{:.4f}".format(np.mean([x[1].item() for x in self.stats[k][-window:]]))
+                item = "{:.4f}".format(
+                    np.mean([x[1].item() for x in self.stats[k][-window:]])
+                )
             log_str += "{:<25}{:>8}".format(k + ":", item)
             log_str += "\n" if i % 4 == 0 else "\t"
         self.console_logger.info(log_str)
@@ -62,10 +134,11 @@ def get_logger():
     logger = logging.getLogger()
     logger.handlers = []
     ch = logging.StreamHandler()
-    formatter = logging.Formatter('[%(levelname)s %(asctime)s] %(name)s %(message)s', '%H:%M:%S')
+    formatter = logging.Formatter(
+        "[%(levelname)s %(asctime)s] %(name)s %(message)s", "%H:%M:%S"
+    )
     ch.setFormatter(formatter)
     logger.addHandler(ch)
-    logger.setLevel('DEBUG')
+    logger.setLevel("DEBUG")
 
     return logger
-