emb-ai/octo-pytorch

Octo-PyTorch

This repo contains code for training and finetuning Octo generalist robotic policies (GRPs) using the PyTorch framework (a.k.a. OctoPt). The base architecture and finetuning procedures were reimplemented from the original repo, which can be useful for further research on robotic manipulation. We tried to keep the code as close as possible to the original JAX implementation, but some parts were reimplemented from scratch.

Get Started

Follow the installation instructions, then load a pretrained Octo model! See the examples for guides to zero-shot evaluation and finetuning.

You can load the original JAX weights directly from HuggingFace or from a local checkpoint and use them to initialize the PyTorch module.

from octo.model.octo_model_pt import OctoModelPt
model = OctoModelPt.load_pretrained_from_jax("hf://rail-berkeley/octo-small-1.5")['octo_model']

Or load from a PyTorch checkpoint:

from octo.model.octo_model_pt import OctoModelPt
model = OctoModelPt.load_pretrained(checkpoint_path)['octo_model']

Installation

conda create -n octo_pt python=3.10
conda activate octo_pt
pip install -e .
pip install -r requirements.txt
pip install torch torchvision
pip install accelerate

Note: JAX is used only for loading weights, so a JAX build with CUDA support is not required.

See the JAX GitHub page for more details on installing JAX.

Test the installation by finetuning on the debug dataset:

torchrun --nproc_per_node 1 scripts/finetune_pt.py --config.pretrained_path=hf://rail-berkeley/octo-small-1.5 --debug

Checkpoints

You can find the original pretrained Octo checkpoints for JAX here. At the moment, the authors provide the following model versions:

Model        Size
Octo-Base    93M Params
Octo-Small   27M Params

Examples

We provide simple example scripts that demonstrate how to use and finetune Octo models, as well as how to use our data loader independently:

OctoPt Inference: Minimal example for loading weights from JAX model checkpoint and running an OctoPt model
OctoPt Finetuning: Minimal example for finetuning a pretrained OctoPt model on a small dataset with a new observation and action space
OctoPt Rollout: Run a rollout of a pretrained OctoPt policy in a Gym environment

OctoPt Finetuning

We provide a minimal example for finetuning with a new observation and action space.

We also provide a more advanced finetuning script that allows you to change hyperparameters via a config file and logs finetuning metrics. To run advanced finetuning, use:

torchrun --nproc_per_node 8 scripts/finetune_pt.py --config.pretrained_path=hf://rail-berkeley/octo-small-1.5

The finetuning script accepts the same command-line arguments as the original JAX implementation.

Initializing OctoPt Model with Weights from JAX Checkpoint

Instantiation from JAX Checkpoint

The simplest way to instantiate the new OctoModelPt from a JAX checkpoint:

from octo.model.octo_model_pt import OctoModelPt
loaded_dict = OctoModelPt.load_pretrained_from_jax("hf://rail-berkeley/octo-small-1.5")
model = loaded_dict['octo_model']

The loaded_dict contains a ready-to-use torch model in loaded_dict['octo_model'], raw JAX weights in loaded_dict['jax_params'] and lists of missing and skipped keys in loaded_dict['missing_keys'] and loaded_dict['skipped_keys'] respectively.

Since the text tokenizer is frozen during training, we initialize it directly from HuggingFace. Therefore, text tokenizer weights are always missing during loading. To suppress warning messages about missing parameters, you can explicitly specify the keys to be skipped using a regular expression:

loaded_dict = OctoModelPt.load_pretrained_from_jax("hf://rail-berkeley/octo-small-1.5", skip_keys_regex='.*hf_model')
model = loaded_dict['octo_model']

Instantiation from the Configuration

More often, you might want to make modifications to the architecture (add extra heads or tokenizers, or change the default action head). In such cases, it is better to initialize the model from the configuration and load weights manually. Here is how you can do it:

meta = OctoModelPt.load_config_and_meta_from_jax("hf://rail-berkeley/octo-small-1.5")

The meta dictionary contains everything you need for instantiation: configuration, example batch, dataset statistics, and text processor. You can modify the configuration as you wish (all information about the model architecture is stored in meta['config']['model']). For example, you may want to use an L1 Action Head instead of a Diffusion Head and add an observation tokenizer for the secondary image.

First, change the action head:

meta['config']['model']['heads'] = {'action': {
    'args': [],
    'kwargs': {
        'input_dim': 384,
        'action_horizon': 4,
        'dropout_rate': 0.0,
        'action_dim': 7,
        'readout_key': 'readout_action',
    },
    'module': 'octo.model.components.action_heads_pt',
    'name': 'L1ActionHeadPt'
    }
}

Or, more concisely, using ModuleSpec.create:

from octo.utils.spec import ModuleSpec
from octo.model.components.action_heads_pt import L1ActionHeadPt

# Assign to the 'action' key so the result matches the dictionary form above.
meta['config']['model']['heads']['action'] = ModuleSpec.create(
    L1ActionHeadPt,
    input_dim=384,
    action_horizon=4,
    action_dim=7,
    readout_key="readout_action",
)
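For orientation, the relationship between the dictionary form and a spec-building helper can be sketched in plain Python. This is a simplified stand-in, not the code in octo.utils.spec: the names make_spec and instantiate_spec are hypothetical, and the sketch only assumes that a spec stores a module path, a class name, positional arguments, and keyword arguments, as the dictionary form of the action head does.

```python
import importlib
from fractions import Fraction


def make_spec(cls, *args, **kwargs):
    """Describe a constructor call as a plain, serializable dictionary."""
    return {
        "module": cls.__module__,
        "name": cls.__qualname__,
        "args": list(args),
        "kwargs": dict(kwargs),
    }


def instantiate_spec(spec):
    """Import the class named in the spec and call it with the stored arguments."""
    cls = getattr(importlib.import_module(spec["module"]), spec["name"])
    return cls(*spec["args"], **spec["kwargs"])


# Fraction stands in for a model component such as L1ActionHeadPt.
spec = make_spec(Fraction, numerator=3, denominator=4)
print(spec["module"], spec["name"])  # fractions Fraction
print(instantiate_spec(spec))        # 3/4
```

Storing the class by module path and name keeps the configuration serializable, which is why it can be edited as ordinary nested dictionaries before the model is built.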

Then add an observation tokenizer for the secondary image:

from octo.model.components.vit_encoders_pt import SmallStem16Pt
from octo.model.components.tokenizers_pt import ImageTokenizerPt

meta['config']['model']['observation_tokenizers']['secondary'] = ModuleSpec.create(
    ImageTokenizerPt,
    encoder = ModuleSpec.create(
        SmallStem16Pt
    ),
    obs_stack_keys = ["image_secondary"],
    task_stack_keys = ["image_secondary"]
)

We also have to provide the number of input tokens for all observations and tasks. Since we added a new input (the secondary image), we need to update the number of tokens in the configuration:

meta['config']['model']['num_tokens_dict'] = {
    'primary': 256,
    'secondary': 256, # add this item
    'wrist': 64,
    'language': 16,
    'action': 1
}
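The image entries above are consistent with one token per 16x16 image patch. The sketch below checks that arithmetic. It assumes the encoder downsamples by a factor of 16 (as the name SmallStem16Pt suggests) and that primary and secondary images are 256x256 while wrist images are 128x128; none of these resolutions is stated in this README. The helper patch_tokens is hypothetical and not part of the repository.

```python
def patch_tokens(height, width, patch_size=16):
    """Number of non-overlapping patch tokens for an image of the given size."""
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be a multiple of the patch size")
    return (height // patch_size) * (width // patch_size)


num_tokens = {
    "primary": patch_tokens(256, 256),    # assumed 256x256 primary camera
    "secondary": patch_tokens(256, 256),  # assumed same resolution as primary
    "wrist": patch_tokens(128, 128),      # assumed 128x128 wrist camera
}
print(num_tokens)  # {'primary': 256, 'secondary': 256, 'wrist': 64}
```

If your secondary camera uses a different resolution, the 'secondary' entry should change accordingly.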

Lastly, we have to update meta['example_batch'] and meta['dataset_statistics'] according to the new data. Then we call the constructor:

model = OctoModelPt.from_config(
        **meta
    )

Then we can initialize the weights of the new model from the JAX checkpoint:

missing_keys, skipped_keys = model.load_weights_from_jax("hf://rail-berkeley/octo-small-1.5", skip_keys_regex='.*hf_model')

Note: In this case, there will be some missing parameters corresponding to the new modules and skipped parameters corresponding to the old Diffusion Head.
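The distinction between missing and skipped keys can be illustrated with set differences over parameter names. This is only a sketch of the bookkeeping, not the loader's actual logic, and every key name below is made up for illustration.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, skipped) parameter names.

    missing: expected by the new model but absent from the checkpoint.
    skipped: present in the checkpoint but unused by the new model.
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    skipped = sorted(checkpoint_keys - model_keys)
    return missing, skipped


# Hypothetical names: a shared transformer weight, an old diffusion-head
# weight in the checkpoint, and new L1-head and tokenizer weights in the model.
checkpoint = ["transformer.block0.weight", "heads.action.diffusion_model.w"]
model = [
    "transformer.block0.weight",
    "heads.action.mean_proj.weight",
    "observation_tokenizers.secondary.encoder.weight",
]

missing, skipped = compare_keys(model, checkpoint)
print(missing)  # parameters of the new modules
print(skipped)  # parameters of the old diffusion head
```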

Refer to the OctoPt finetuning example for more information.

Non-Strict Initialization

Sometimes you may not want to initialize some parameters from the JAX checkpoint. For example, suppose you change the action horizon from 4 to 8 in the default pretrained model:

meta = OctoModelPt.load_config_and_meta_from_jax("hf://rail-berkeley/octo-small-1.5")

# change horizon
meta['config']['model']['heads']['action']['kwargs']['action_horizon'] = 8 

model = OctoModelPt.from_config(
        **meta
    )
model.load_weights_from_jax("hf://rail-berkeley/octo-small-1.5")

The code snippet above raises an AssertionError because of a shape mismatch in the action head:

AssertionError: New value of 'heads.action.diffusion_model.reverse_network.linear1.weight' has shape torch.Size([256, 444]), but torch.Size([256, 472]) expected
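The two widths in the message differ by 28, which is exactly four extra timesteps of a 7-dimensional action. A rough accounting is sketched below. It assumes the reverse network's input is the concatenation of a 384-dimensional readout embedding, a 32-dimensional time embedding, and the flattened action chunk; the time-embedding width in particular is an assumption that this README does not state, and the helper name is hypothetical.

```python
def reverse_network_input_dim(embed_dim, time_dim, action_dim, action_horizon):
    """Input width if the network sees [embedding, time embedding, flattened actions]."""
    return embed_dim + time_dim + action_dim * action_horizon


# embed_dim=384 and action_dim=7 match the head configuration shown earlier;
# time_dim=32 is assumed for illustration.
checkpoint_width = reverse_network_input_dim(384, 32, 7, action_horizon=4)
new_model_width = reverse_network_input_dim(384, 32, 7, action_horizon=8)
print(checkpoint_width, new_model_width)  # 444 472
```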

More specifically, the weights and biases of the diffusion_model.reverse_network modules have mismatched shapes. For finetuning, this is not critical, and we can skip them:

model.load_weights_from_jax("hf://rail-berkeley/octo-small-1.5",
    skip_keys_regex=".*diffusion_model.reverse_network.(linear1|linear2)"
)

However, we can do slightly better than random initialization for these modules by using non-strict weight initialization:

model.load_weights_from_jax("hf://rail-berkeley/octo-small-1.5",
    non_strict_keys_regex=".*diffusion_model.reverse_network.(linear1|linear2)"
)
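When writing these patterns, keep in mind that in Python regular expressions square brackets form a character class (one character from a set), while parentheses with | express alternation between whole names. The check below uses only the standard re module. The parameter names are illustrative (only the linear1 name is taken from the error message above), and it assumes the loader matches patterns from the start of each key, which this README does not specify.

```python
import re

keys = [
    "heads.action.diffusion_model.reverse_network.linear1.weight",
    "heads.action.diffusion_model.reverse_network.linear2.bias",
    "heads.action.diffusion_model.reverse_network.layers.0.weight",  # hypothetical
]

char_class = re.compile(r".*diffusion_model\.reverse_network\.[linear1|linear2]")
alternation = re.compile(r".*diffusion_model\.reverse_network\.(linear1|linear2)")

# The character class accepts any key whose next character is one of
# l, i, n, e, a, r, 1, 2 or |, so the "layers" key matches as well.
class_hits = [bool(char_class.match(k)) for k in keys]
print(class_hits)  # [True, True, True]

# The group accepts only the two intended submodule names.
group_hits = [bool(alternation.match(k)) for k in keys]
print(group_hits)  # [True, True, False]
```

Running a pattern against the model's parameter names in this way before loading is a cheap way to confirm that it selects only the intended modules.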
