This repo contains the official PyTorch implementation of Navigation World Models (NWM), including the training code for the Conditional Diffusion Transformer (CDiT) model. See the project page for additional results.
Navigation World Models
Amir Bar, Gaoyue "Kathy" Zhou, Danny Tran, Trevor Darrell, Yann LeCun
AI at Meta, UC Berkeley, New York University
First, download and set up the repo:
git clone https://github.com/facebookresearch/nwm
cd nwm
To download and preprocess data, please follow the steps from NoMaD, specifically:
- Download the datasets
- Change the preprocessing resolution from (160, 120) to (320, 240) for higher resolution
- Run process_bags.py and process_recon.py to save each processed dataset to path/to/nwm_repo/data/<dataset_name>.
For SACSon/HuRoN, we used a private version that contains higher-resolution images. Please contact the authors for access.
Finally, you should have the following structure:
nwm/data
└── <dataset_name>
    ├── <name_of_traj1>
    │   ├── 0.jpg
    │   ├── 1.jpg
    │   ├── ...
    │   ├── T_1.jpg
    │   └── traj_data.pkl
    ├── <name_of_traj2>
    │   ├── 0.jpg
    │   ├── 1.jpg
    │   ├── ...
    │   ├── T_2.jpg
    │   └── traj_data.pkl
    ├── ...
    └── <name_of_trajN>
        ├── 0.jpg
        ├── 1.jpg
        ├── ...
        ├── T_N.jpg
        └── traj_data.pkl
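The layout above can be sanity-checked before training. The following is a minimal sketch, not part of the repository: `check_dataset_dir` is a hypothetical helper, and it assumes (reading the tree) that each trajectory folder holds frames `0.jpg` through `T.jpg` with no gaps, where `T_1`, `T_2`, ... stand for each trajectory's last frame index, plus a `traj_data.pkl` file.

```python
from pathlib import Path


def check_dataset_dir(dataset_dir):
    """Return a list of problems found in one <dataset_name> directory.

    Hypothetical helper. Assumes one sub-directory per trajectory containing
    frames 0.jpg ... T.jpg (contiguous, starting at 0) and a traj_data.pkl file.
    """
    problems = []
    trajectories = sorted(p for p in Path(dataset_dir).iterdir() if p.is_dir())
    for traj in trajectories:
        if not (traj / "traj_data.pkl").is_file():
            problems.append(f"{traj.name}: missing traj_data.pkl")
        # Collect integer frame indices from file names such as 0.jpg, 1.jpg, ...
        indices = sorted(int(f.stem) for f in traj.glob("*.jpg") if f.stem.isdigit())
        if not indices:
            problems.append(f"{traj.name}: no frames")
        elif indices != list(range(len(indices))):
            problems.append(f"{traj.name}: frame indices are not contiguous from 0")
    return problems
```

For example, `check_dataset_dir("data/recon")` run from the repo root returns an empty list when every trajectory passes.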
Next, create the environment and install the dependencies:
mamba create -n nwm python=3.10
mamba activate nwm
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126
mamba install ffmpeg
pip3 install decord einops evo transformers diffusers tqdm timm notebook dreamsim torcheval lpips ipywidgets
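To confirm the installation, a small check such as the sketch below (hypothetical helpers, not shipped with the repo) reports which of the packages above cannot be found in the active environment and whether the `ffmpeg` binary is on `PATH`. It assumes each pip package name matches its import name, which we believe holds for the packages listed here.

```python
import importlib.util
import shutil

# Import names of the Python packages installed above.
REQUIRED = [
    "torch", "torchvision", "torchaudio", "decord", "einops", "evo",
    "transformers", "diffusers", "tqdm", "timm", "notebook", "dreamsim",
    "torcheval", "lpips", "ipywidgets",
]


def missing_packages(names=REQUIRED):
    """Return the top-level module names the current interpreter cannot find."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def ffmpeg_available():
    """Return True if an `ffmpeg` executable is on PATH (installed via mamba)."""
    return shutil.which("ffmpeg") is not None
```

Run `print(missing_packages(), ffmpeg_available())` inside the `nwm` environment; an empty list and `True` mean everything was found.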
To train with torchrun (8 nodes with 8 GPUs each):
export NUM_NODES=8
export HOST_NODE_ADDR=<HOST_ADDR>
export CURR_NODE_RANK=<NODE_RANK>
torchrun \
--nnodes=${NUM_NODES} \
--nproc-per-node=8 \
--node-rank=${CURR_NODE_RANK} \
--rdzv-backend=c10d \
--rdzv-endpoint=${HOST_NODE_ADDR}:29500 \
train.py --config config/nwm_cdit_xl.yaml --ckpt-every 2000 --eval-every 10000 --bfloat16 1 --epochs 300 --torch-compile 0
Or using submitit and Slurm (8 machines with 8 GPUs each):
python submitit_train_cw.py --nodes 8 --partition <partition_name> --qos <qos> --config config/nwm_cdit_xl.yaml --ckpt-every 2000 --eval-every 10000 --bfloat16 1 --epochs 300 --torch-compile 0
Or locally on a single GPU for debugging:
python train.py --config config/nwm_cdit_xl.yaml --ckpt-every 2000 --eval-every 10000 --bfloat16 1 --epochs 300 --torch-compile 0
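Under torchrun, each worker learns its place in the job from environment variables: torchrun sets `WORLD_SIZE`, `RANK`, and `LOCAL_RANK` for every process it starts, and with `--nnodes=8` and `--nproc-per-node=8` that is 8 * 8 = 64 processes. The sketch below is illustrative only (`dist_info` is a hypothetical name, not a function in this repo) and falls back to a single process when the variables are absent, as in the one-GPU debug run.

```python
import os


def dist_info(env=None):
    """Return (world_size, rank, local_rank) for the current process.

    torchrun exports WORLD_SIZE, RANK and LOCAL_RANK to every worker it starts.
    When they are missing (e.g. `python train.py` on one GPU), assume a single
    process: rank 0 in a world of size 1.
    """
    env = os.environ if env is None else env
    return (
        int(env.get("WORLD_SIZE", 1)),
        int(env.get("RANK", 0)),
        int(env.get("LOCAL_RANK", 0)),
    )
```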
Note: enabling torch.compile (the --torch-compile flag) can make training roughly 40% faster. However, it may cause instabilities and inconsistent behavior across PyTorch versions, so use it with care.
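The trade-off in the note above amounts to an opt-in switch. A minimal sketch, assuming a hypothetical `maybe_compile` helper (not the repository's code) that mirrors the `--torch-compile 0/1` flag: the model is wrapped with `torch.compile` only when the flag is set, so the default path stays in eager mode.

```python
def maybe_compile(model, torch_compile):
    """Return `model` unchanged unless `torch_compile` is truthy (e.g. 1).

    Hypothetical helper. torch is imported only on the compile path;
    torch.compile requires PyTorch 2.0 or later.
    """
    if not torch_compile:
        return model  # eager mode: slower, but the more predictable default
    import torch

    return torch.compile(model)
```

Because compilation happens during the first forward passes, any speedup appears only after warm-up, so compare throughput over many steps rather than the first iterations.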
To use a pretrained CDiT/XL model:
- Download a pretrained model from Hugging Face
- Place the checkpoint in ./logs/nwm_cdit_xl/checkpoints
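Before running evaluation, it can help to list which checkpoints are in place. The sketch below is hypothetical: `available_steps` is not part of the repo, and it assumes checkpoint file names begin with a zero-padded training step followed by an extension, as suggested by the `--ckp 0100000` argument used in the evaluation commands.

```python
from pathlib import Path

CKPT_DIR = Path("./logs/nwm_cdit_xl/checkpoints")


def available_steps(ckpt_dir=CKPT_DIR):
    """Return checkpoint step ids (as strings) found in ckpt_dir, sorted by step.

    Assumed naming (hypothetical): <zero-padded step>.<ext>, e.g. 0100000.<ext>.
    """
    ckpt_dir = Path(ckpt_dir)
    if not ckpt_dir.is_dir():
        return []
    # Take everything before the first dot so multi-part extensions are ignored.
    stems = {f.name.split(".")[0] for f in ckpt_dir.iterdir() if f.is_file()}
    return sorted((s for s in stems if s.isdigit()), key=int)
```

The last entry of `available_steps()` is the most recent step and could be passed as the `--ckp` value.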
For evaluation, set the directory where results will be saved:
export RESULTS_FOLDER=/path/to/res_folder/
For the time evaluation (--eval_type time), first save the ground-truth frames (--gt 1):
python isolated_nwm_infer.py \
--exp config/nwm_cdit_xl.yaml \
--datasets recon,scand,sacson,tartan_drive \
--batch_size 96 \
--num_workers 12 \
--eval_type time \
--output_dir ${RESULTS_FOLDER} \
--gt 1
Next, run inference with the checkpoint selected by --ckp on a given dataset:
python isolated_nwm_infer.py \
--exp config/nwm_cdit_xl.yaml \
--ckp 0100000 \
--datasets <dataset_name> \
--batch_size 64 \
--num_workers 12 \
--eval_type time \
--output_dir ${RESULTS_FOLDER}
Finally, compute metrics by comparing the predictions against the ground truth:
python isolated_nwm_eval.py \
--datasets <dataset_name> \
--gt_dir ${RESULTS_FOLDER}/gt \
--exp_dir ${RESULTS_FOLDER}/nwm_cdit_xl \
--eval_types time
Results are saved in ${RESULTS_FOLDER}/nwm_cdit_xl/<dataset_name>
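The three time-evaluation steps can be scripted per dataset. The sketch below only rebuilds the commands shown above and runs them with Python's `subprocess`; `time_eval_commands` and `run` are hypothetical helpers, and the experiment name `nwm_cdit_xl` is assumed to follow the config file name, as in the result paths above.

```python
import os
import subprocess

CONFIG = "config/nwm_cdit_xl.yaml"
EXP_NAME = "nwm_cdit_xl"  # assumed to match the config file name
ALL_DATASETS = "recon,scand,sacson,tartan_drive"


def time_eval_commands(dataset, results_folder, ckp="0100000"):
    """Return [ground truth, inference, metrics] argv lists for one dataset."""
    common = ["--num_workers", "12", "--eval_type", "time", "--output_dir", results_folder]
    gt = ["python", "isolated_nwm_infer.py", "--exp", CONFIG,
          "--datasets", ALL_DATASETS, "--batch_size", "96", *common, "--gt", "1"]
    infer = ["python", "isolated_nwm_infer.py", "--exp", CONFIG, "--ckp", ckp,
             "--datasets", dataset, "--batch_size", "64", *common]
    metrics = ["python", "isolated_nwm_eval.py", "--datasets", dataset,
               "--gt_dir", os.path.join(results_folder, "gt"),
               "--exp_dir", os.path.join(results_folder, EXP_NAME),
               "--eval_types", "time"]
    return [gt, infer, metrics]


def run(commands):
    """Run each command in order, stopping at the first non-zero exit code."""
    for cmd in commands:
        subprocess.run(cmd, check=True)
```

For example, `run(time_eval_commands("recon", os.environ["RESULTS_FOLDER"]))`. Since the ground-truth step already covers all four datasets, it presumably needs to run only once; pass `time_eval_commands(...)[1:]` on later runs.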
For the rollout evaluation (--eval_type rollout), first save the ground-truth frames:
python isolated_nwm_infer.py \
--exp config/nwm_cdit_xl.yaml \
--datasets recon,scand,sacson,tartan_drive \
--batch_size 96 \
--num_workers 12 \
--eval_type rollout \
--output_dir ${RESULTS_FOLDER} \
--gt 1 \
--rollout_fps_values 1,4
Next, run rollout inference with the checkpoint:
python isolated_nwm_infer.py \
--exp config/nwm_cdit_xl.yaml \
--ckp 0100000 \
--datasets <dataset_name> \
--batch_size 64 \
--num_workers 12 \
--eval_type rollout \
--output_dir ${RESULTS_FOLDER} \
--rollout_fps_values 1,4
Finally, compute the rollout metrics (shown here for recon):
python isolated_nwm_eval.py \
--datasets recon \
--gt_dir ${RESULTS_FOLDER}/gt \
--exp_dir ${RESULTS_FOLDER}/nwm_cdit_xl \
--eval_types rollout
Results are saved in ${RESULTS_FOLDER}/nwm_cdit_xl/<dataset_name>
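As we understand the rollout evaluation, the model follows the ground-truth actions for several steps while feeding its own predictions back in as context, so errors can compound. The sketch below illustrates that generic autoregressive loop with a stand-in `predict` function; it is not the repository's implementation, and `context_size=4` is a hypothetical value rather than one read from the NWM config. If `--rollout_fps_values` sets the frame rate of the rollout, then covering the same time horizon at 4 fps takes four times as many autoregressive steps as at 1 fps.

```python
from collections import deque


def autoregressive_rollout(predict, context_frames, actions, context_size=4):
    """Roll a world model forward by feeding each prediction back as context.

    predict(context, action) -> next frame is a stand-in for the model;
    context_size is a hypothetical window length, not taken from NWM.
    """
    context = deque(context_frames, maxlen=context_size)
    predictions = []
    for action in actions:
        frame = predict(list(context), action)
        predictions.append(frame)
        context.append(frame)  # once full, the oldest frame drops out of the window
    return predictions
```

With a toy `predict` such as `lambda ctx, a: ctx[-1] + a`, starting from `[0]` and actions `[1, 1, 1]`, the rollout yields `[1, 2, 3]`: each step builds on the previous prediction rather than on ground truth.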
To evaluate planning, run 1-step Cross-Entropy Method (CEM) planning on 8 GPUs, sampling 120 trajectories:
torchrun --nproc-per-node=8 planning_eval.py \
--exp config/nwm_cdit_xl.yaml \
--datasets recon \
--rollout_stride 1 \
--batch_size 1 \
--num_samples 120 \
--topk 5 \
--num_workers 12 \
--output_dir ${RESULTS_FOLDER} \
--save_preds \
--ckp 0100000 \
--opt_steps 1 \
--num_repeat_eval 3
Results are saved in ${RESULTS_FOLDER}/nwm_cdit_xl/<dataset_name>
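The Cross-Entropy Method itself is simple: sample candidates from a Gaussian, score them, keep the best few (the elites), and refit the Gaussian to them. Below is a minimal, self-contained sketch of that loop, not the repository's planner. Its parameters mirror the flags above (`num_samples` for `--num_samples 120`, `topk` for `--topk 5`, `opt_steps` for `--opt_steps 1`), assuming `--opt_steps` counts CEM iterations. `cost` is a stand-in; in NWM, scoring a candidate action sequence would presumably involve imagining its outcome with the world model and comparing it to the goal.

```python
import random
import statistics


def cem_plan(cost, dim, num_samples=120, topk=5, opt_steps=1, init_std=1.0, seed=None):
    """Minimise cost(x) over x in R^dim with the Cross-Entropy Method.

    Each iteration samples num_samples candidates from a diagonal Gaussian,
    keeps the topk lowest-cost ones (the elites), and refits mean/std to them.
    Returns the final Gaussian mean.
    """
    rng = random.Random(seed)
    mean = [0.0] * dim
    std = [init_std] * dim
    for _ in range(opt_steps):
        samples = [[rng.gauss(m, s) for m, s in zip(mean, std)] for _ in range(num_samples)]
        elites = sorted(samples, key=cost)[:topk]
        # Refit each dimension independently to the elite set.
        mean = [statistics.fmean(col) for col in zip(*elites)]
        std = [statistics.pstdev(col) for col in zip(*elites)]
    return mean
```

For example, with a toy quadratic cost, `cem_plan(lambda x: (x[0] - 0.5) ** 2 + (x[1] + 0.5) ** 2, dim=2, opt_steps=10, seed=0)` ends close to `[0.5, -0.5]`. With `opt_steps=1`, as in the command above, the result is simply the mean of the 5 best of 120 samples.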
@article{bar2024navigation,
title={Navigation world models},
author={Bar, Amir and Zhou, Gaoyue and Tran, Danny and Darrell, Trevor and LeCun, Yann},
journal={arXiv preprint arXiv:2412.03572},
year={2024}
}
We thank Noriaki Hirose for his help with the HuRoN dataset and for sharing his insights, and we thank Manan Tomar, David Fan, Sonia Joseph, Angjoo Kanazawa, Ethan Weber, Nicolas Ballas, and the anonymous reviewers for their helpful discussions and feedback.
The code and model weights are licensed under Creative Commons Attribution-NonCommercial 4.0 International. See LICENSE.txt for details.