Added sequence iterator by luisenp · Pull Request #91 · facebookresearch/mbrl-lib · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
This repository was archived by the owner on Sep 1, 2024. It is now read-only.

Added sequence iterator #91

Merged
merged 21 commits into from
Jul 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
c8c02f8
[refactor] Deprecated method ReplayBuffer.get_iterators() and moved i…
luisenp Jun 10, 2021
d2a020d
Added an iterator for batches of sequences of fixed length.
luisenp Jun 11, 2021
513d727
Added option to loop a limited number of batches in the sequence iter…
luisenp Jun 11, 2021
3c220e5
Added option to return training and validation splits in get_sequence…
luisenp Jun 11, 2021
7030219
Merge branch 'master' into sequence_iterator
luisenp Jun 13, 2021
7fec7de
Removed unnecessary train_ensemble= arg from get_basic_buffer_iterators.
luisenp Jun 13, 2021
d3875bb
[bug-fix] SequenceIterator was off in the number of valid start state…
luisenp Jun 13, 2021
05be22d
Merge branch 'master' into sequence_iterator
luisenp Jun 14, 2021
029e0ab
Added option shuffle_each_epoch to the sequence iterator and to get_s…
luisenp Jul 5, 2021
c02b702
[bug-fix] Utility get_sequence_iterators() was not creating random tr…
luisenp Jul 5, 2021
8925341
Added separate options to get_sequence_iterators() for the max_batche…
luisenp Jul 5, 2021
eb22e0e
Merge branch 'master' into sequence_iterator
luisenp Jul 5, 2021
0d36e93
Merge branch 'master' into sequence_iterator
luisenp Jul 7, 2021
3a6147e
Merge branch 'master' into sequence_iterator
luisenp Jul 8, 2021
6223b51
Fixed a bug where the length attribute ignores max_batches_per_loop
jan1854 Jul 9, 2021
f7b0ec4
Merge pull request #105 from JanS97/sequence_iterator
luisenp Jul 9, 2021
c4e5d3a
Fixed a bug where the default argument for ensemble_size caused a Typ…
jan1854 Jul 12, 2021
e9ff5c2
Added an assertion that the replay buffer passed to get_sequence_buff…
jan1854 Jul 12, 2021
75779fe
Merge pull request #106 from JanS97/sequence_iterator
luisenp Jul 19, 2021
82005a1
Updated changelog and version number to v0.1.2
luisenp Jul 19, 2021
2e39169
Updated PETS notebook example to use util.common.get_basic_buffer_ite…
luisenp Jul 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
# Changelog

## v0.1.2
- Multiple bug fixes
- Added a training browser to compare results of multiple runs
- Deprecated `ReplayBuffer.get_iterators()` and replaced with `mbrl.util.common.get_basic_buffer_iterators()`
- Added an iterator that returns batches of sequences of transitions of a given length

## v0.1.1
- Multiple bug fixes
- Added `third_party` folder for `pytorch_sac` and `dmc2gym`
- Library now available in `pypi`
- Moved example configurations to package mbrl.examples, which can now be
run as `python -m mbrl.examples.main`, after `pip` installation.
- Moved example configurations to package `mbrl.examples`, which can now be
run as `python -m mbrl.examples.main`, after `pip` installation

## v0.1.0

Expand Down
2 changes: 1 addition & 1 deletion mbrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
__version__ = "0.1.1"
__version__ = "0.1.2"
4 changes: 2 additions & 2 deletions mbrl/diagnostics/eval_model_on_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def run(self):
# Some models (e.g., GaussianMLP) require the batch size to be
# a multiple of number of models
batch_size = len(self.dynamics_model) * 8
dataset, _ = self.replay_buffer.get_iterators(
batch_size=batch_size, val_ratio=0
dataset, _ = mbrl.util.common.get_basic_buffer_iterators(
self.replay_buffer, batch_size=batch_size, val_ratio=0
)

self.plot_dataset_results(dataset)
Expand Down
4 changes: 2 additions & 2 deletions mbrl/diagnostics/finetune_model_with_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ def run(
logger=logger,
)

dataset_train, dataset_val = self.replay_buffer.get_iterators(
dataset_train, dataset_val = mbrl.util.common.get_basic_buffer_iterators(
self.replay_buffer,
batch_size,
val_ratio,
train_ensemble=len(self.dynamics_model.model) > 1,
ensemble_size=len(self.dynamics_model.model),
shuffle_each_epoch=True,
bootstrap_permutes=False,
Expand Down
21 changes: 21 additions & 0 deletions mbrl/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,26 @@ def __getitem__(self, item):
self.dones[item],
)

@staticmethod
def _get_new_shape(old_shape: Tuple[int, ...], batch_size: int):
    """Computes the shape obtained by splitting the leading dimension in two.

    Args:
        old_shape (tuple of int): the shape to convert. Its first entry is the
            dimension that gets split.
        batch_size (int): the size of the new leading dimension.

    Returns:
        (tuple of int): ``(batch_size, old_shape[0] // batch_size)`` followed by
        the remaining entries of ``old_shape``.
    """
    # Floor division: divisibility is not checked here, so any remainder of the
    # leading dimension is silently dropped from the resulting shape.
    inner_size = old_shape[0] // batch_size
    return tuple((batch_size, inner_size) + old_shape[1:])

def add_new_batch_dim(self, batch_size: int):
    """Returns a new batch whose fields have an extra leading batch dimension.

    Every field (``obs``, ``act``, ``next_obs``, ``rewards``, ``dones``) is
    reshaped so that its leading dimension is split into
    ``(batch_size, leading_dim // batch_size)``. The target shape of each field
    is derived from that same field's own shape.

    Args:
        batch_size (int): the size of the new leading dimension. ``len(self)``
            must be a multiple of it.

    Returns:
        (TransitionBatch): a new batch built from the reshaped fields.

    Raises:
        ValueError: if ``len(self)`` is not a multiple of ``batch_size``.
    """
    if len(self) % batch_size != 0:
        raise ValueError(
            "Current batch of transitions size is not a "
            "multiple of the new batch size. "
        )

    def _with_batch_dim(field):
        # Use the field's own shape so that a ``next_obs`` whose shape differs
        # from ``obs`` is not reshaped with the wrong target shape.
        return field.reshape(self._get_new_shape(field.shape, batch_size))

    return TransitionBatch(
        _with_batch_dim(self.obs),
        _with_batch_dim(self.act),
        _with_batch_dim(self.next_obs),
        _with_batch_dim(self.rewards),
        _with_batch_dim(self.dones),
    )


ModelInput = Union[torch.Tensor, TransitionBatch]
8 changes: 2 additions & 6 deletions mbrl/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .logger import Logger
from .replay_buffer import ReplayBuffer, TransitionIterator
from .replay_buffer import ReplayBuffer, SequenceTransitionIterator, TransitionIterator

__all__ = [
"Logger",
"ReplayBuffer",
"TransitionIterator",
]
__all__ = ["Logger", "ReplayBuffer", "TransitionIterator", "SequenceTransitionIterator"]
141 changes: 136 additions & 5 deletions mbrl/util/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
import mbrl.planning
import mbrl.types

from .replay_buffer import ReplayBuffer
from .replay_buffer import (
BootstrapIterator,
ReplayBuffer,
SequenceTransitionIterator,
TransitionIterator,
)


# TODO read model from hydra
Expand Down Expand Up @@ -158,8 +163,7 @@ def create_replay_buffer(
batches. If None (default value), a new default generator will be used.

Returns:
(tuple of :class:`mbrl.replay_buffer.IterableReplayBuffer`): the training and validation
buffers, respectively.
(:class:`mbrl.replay_buffer.ReplayBuffer`): the replay buffer.
"""
dataset_size = (
cfg.algorithm.get("dataset_size", None) if "algorithm" in cfg else None
Expand Down Expand Up @@ -190,6 +194,132 @@ def create_replay_buffer(
return replay_buffer


def get_basic_buffer_iterators(
    replay_buffer: ReplayBuffer,
    batch_size: int,
    val_ratio: float,
    ensemble_size: int = 1,
    shuffle_each_epoch: bool = True,
    bootstrap_permutes: bool = False,
) -> Tuple[TransitionIterator, Optional[TransitionIterator]]:
    """Builds training/validation iterators over the contents of a replay buffer.

    All stored data is retrieved once in shuffled order and then split: the
    first part feeds a bootstrap training iterator, and the remainder (if any)
    feeds a plain validation iterator that never reshuffles.

    Args:
        replay_buffer (:class:`mbrl.util.ReplayBuffer`): the replay buffer from which
            data will be sampled.
        batch_size (int): the batch size for both iterators.
        val_ratio (float): the proportion of stored data to use for validation.
            The validation size is ``int(num_stored * val_ratio)``; if that is 0,
            the validation iterator is ``None``.
        ensemble_size (int): the size of the ensemble being trained. Defaults to 1.
        shuffle_each_epoch (bool): passed to the training iterator; if ``True``,
            the order is shuffled each time a loop starts. Defaults to ``True``.
        bootstrap_permutes (bool): if ``True``, the bootstrap iterator creates
            the bootstrap data using permutations of the original data. Otherwise
            it uses sampling with replacement. Defaults to ``False``.

    Returns:
        (tuple): the training iterator and the validation iterator (or ``None``),
        in that order.
    """
    shuffled = replay_buffer.get_all(shuffle=True)
    num_val = int(replay_buffer.num_stored * val_ratio)
    num_train = replay_buffer.num_stored - num_val

    train_iter = BootstrapIterator(
        shuffled[:num_train],
        batch_size,
        ensemble_size,
        shuffle_each_epoch=shuffle_each_epoch,
        permute_indices=bootstrap_permutes,
        rng=replay_buffer.rng,
    )

    # Nothing was set aside for validation, so there is no second iterator.
    if num_val <= 0:
        return train_iter, None

    val_iter = TransitionIterator(
        shuffled[num_train:],
        batch_size,
        shuffle_each_epoch=False,
        rng=replay_buffer.rng,
    )
    return train_iter, val_iter


def get_sequence_buffer_iterator(
    replay_buffer: ReplayBuffer,
    batch_size: int,
    val_ratio: float,
    sequence_length: int,
    ensemble_size: int = 1,
    shuffle_each_epoch: bool = True,
    max_batches_per_loop_train: Optional[int] = None,
    max_batches_per_loop_val: Optional[int] = None,
) -> Tuple[SequenceTransitionIterator, Optional[SequenceTransitionIterator]]:
    """Builds training/validation sequence iterators over a replay buffer.

    The split is made over whole trajectories: the buffer's trajectory indices
    are randomly permuted with the buffer's generator, the first part goes to
    the training iterator and the remainder (if any) to the validation iterator.
    Both iterators receive the full set of stored transitions.

    Args:
        replay_buffer (:class:`mbrl.util.ReplayBuffer`): the replay buffer from which
            data will be sampled. It must store trajectory information.
        batch_size (int): the batch size for both iterators.
        val_ratio (float): the proportion of trajectories to use for validation.
            The validation size is ``int(num_trajectories * val_ratio)``; if that
            is 0, the validation iterator is ``None``.
        sequence_length (int): the length of the sequences returned by the iterators.
        ensemble_size (int): the number of models in the ensemble. Only passed to
            the training iterator; the validation iterator is built with size 1.
        shuffle_each_epoch (bool): passed to both iterators; if ``True``, the
            order is shuffled each time a loop starts. Defaults to ``True``.
        max_batches_per_loop_train (int, optional): if given, specifies how many batches
            to return (at most) over a full loop of the training iterator.
        max_batches_per_loop_val (int, optional): if given, specifies how many batches
            to return (at most) over a full loop of the validation iterator.

    Returns:
        (tuple): the training iterator and the validation iterator (or ``None``),
        in that order.
    """
    assert replay_buffer.stores_trajectories, (
        "The passed replay buffer does not store trajectory information. "
        "Make sure that the replay buffer is created with the max_trajectory_length "
        "parameter set."
    )

    all_transitions = replay_buffer.get_all()
    trajectory_count = len(replay_buffer.trajectory_indices)
    num_val = int(trajectory_count * val_ratio)
    num_train = trajectory_count - num_val
    shuffled_trajectories = replay_buffer.rng.permutation(
        replay_buffer.trajectory_indices
    )

    train_iterator = SequenceTransitionIterator(
        all_transitions,
        shuffled_trajectories[:num_train],
        batch_size,
        sequence_length,
        ensemble_size,
        shuffle_each_epoch=shuffle_each_epoch,
        rng=replay_buffer.rng,
        max_batches_per_loop=max_batches_per_loop_train,
    )

    # No trajectories were set aside for validation.
    if num_val <= 0:
        return train_iterator, None

    val_iterator = SequenceTransitionIterator(
        all_transitions,
        shuffled_trajectories[num_train:],
        batch_size,
        sequence_length,
        1,
        shuffle_each_epoch=shuffle_each_epoch,
        rng=replay_buffer.rng,
        max_batches_per_loop=max_batches_per_loop_val,
    )
    # NOTE(review): presumably this turns bootstrapping off for validation;
    # confirm against SequenceTransitionIterator.toggle_bootstrap().
    val_iterator.toggle_bootstrap()
    return train_iterator, val_iterator


def train_model_and_save_model_and_data(
model: mbrl.models.Model,
model_trainer: mbrl.models.ModelTrainer,
Expand All @@ -209,6 +339,7 @@ def train_model_and_save_model_and_data(
model_trainer (:class:`mbrl.models.ModelTrainer`): the model trainer.
cfg (:class:`omegaconf.DictConfig`): configuration to use for training. It
must contain the following fields::

-model_batch_size (int)
-validation_ratio (float)
-num_epochs_train_model (int, optional)
Expand All @@ -220,10 +351,10 @@ def train_model_and_save_model_and_data(
callback (callable, optional): if provided, this function will be called after
every training epoch. See :class:`mbrl.models.ModelTrainer` for signature.
"""
dataset_train, dataset_val = replay_buffer.get_iterators(
dataset_train, dataset_val = mbrl.util.common.get_basic_buffer_iterators(
replay_buffer,
cfg.model_batch_size,
cfg.validation_ratio,
train_ensemble=len(model) > 1,
ensemble_size=len(model),
shuffle_each_epoch=True,
bootstrap_permutes=cfg.get("bootstrap_permutes", False),
Expand Down
Loading
0