feat: Update sft config to use single GPU by ashors1 · Pull Request #90 · NVIDIA/NeMo-RL

Merged · 4 commits · Mar 31, 2025
README.md (14 changes: 8 additions & 6 deletions)
@@ -60,20 +60,22 @@ We provide a sample SFT experiment that uses the [SQuAD dataset](https://rajpurk

#### Single Node

- The experiment is set up to run on 8 GPUs. If using a machine that has access to 8 GPUs, you can launch the experiment as follows:
+ The default SFT experiment is configured to run on a single GPU. To launch the experiment, run:

```sh
uv run python examples/run_sft.py
```

- This trains `Llama3.1-8B` on 8 GPUs. To run on a single GPU, we'll have to override a few of the experiment settings. We replace the 8B model with a smaller 1B model, decrease the batch size, and update the cluster configuration to use a single gpu:
+ This trains `Llama3.2-1B` on one GPU using the SQuAD dataset.
+
+ If you have access to more GPUs, you can update the experiment accordingly. To run on 8 GPUs, we update the cluster configuration. We also switch to an 8B Llama base model and increase the batch size:

```sh
uv run python examples/run_sft.py \
policy.model_name="meta-llama/Llama-3.2-1B" \
policy.train_global_batch_size=16 \
sft.val_global_batch_size=16 \
cluster.gpus_per_node=1
policy.model_name="meta-llama/Meta-Llama-3-8B" \
policy.train_global_batch_size=128 \
sft.val_global_batch_size=128 \
cluster.gpus_per_node=8
```
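Global batch size, micro batch size, and GPU count scale together because each global batch is split across data-parallel workers and accumulated over micro batches. A minimal sketch of that bookkeeping under the usual convention (NeMo-RL's exact accounting may differ):

```python
def grad_accum_steps(global_batch: int, micro_batch: int, dp_workers: int) -> int:
    """Gradient-accumulation steps implied by a global/micro batch split.

    Assumes the common convention:
        global_batch = micro_batch * dp_workers * accumulation_steps
    """
    assert global_batch % (micro_batch * dp_workers) == 0, "global batch must divide evenly"
    return global_batch // (micro_batch * dp_workers)

# Single-GPU default from sft.yaml: 32 = 1 * 1 * 32 accumulation steps.
print(grad_accum_steps(global_batch=32, micro_batch=1, dp_workers=1))   # 32
# 8-GPU override above: 128 = 1 * 8 * 16 accumulation steps.
print(grad_accum_steps(global_batch=128, micro_batch=1, dp_workers=8))  # 16
```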

Refer to [sft.yaml](examples/configs/sft.yaml) for a full list of parameters that can be overridden.
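The `section.key=value` arguments mirror the YAML structure. To preview what a set of overrides resolves to without launching a run, the merge can be reproduced with OmegaConf; this is a sketch assuming dot-list-style overrides, and `run_sft.py`'s actual argument parsing may differ:

```python
from omegaconf import OmegaConf

# Load the base experiment config and apply command-line-style overrides.
base = OmegaConf.load("examples/configs/sft.yaml")
overrides = OmegaConf.from_dotlist([
    "policy.model_name=meta-llama/Meta-Llama-3-8B",
    "policy.train_global_batch_size=128",
    "sft.val_global_batch_size=128",
    "cluster.gpus_per_node=8",
])
cfg = OmegaConf.merge(base, overrides)

# Inspect the fully resolved configuration before training.
print(OmegaConf.to_yaml(cfg, resolve=True))
```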
examples/configs/sft.yaml (27 changes: 10 additions & 17 deletions)
@@ -1,9 +1,9 @@
# SFT Algorithm Configuration
sft:
- max_num_steps: 1000
+ max_num_steps: 60
val_period: 10
val_batches: 8
- val_global_batch_size: 128
+ val_global_batch_size: 32
val_micro_batch_size: 1
val_at_start: true
seed: 42
@@ -17,10 +17,10 @@ checkpointing:
save_period: 10

policy:
model_name: "meta-llama/Meta-Llama-3-8B"
train_global_batch_size: 128
model_name: "meta-llama/Llama-3.2-1B"
train_global_batch_size: 32
train_micro_batch_size: 1
- max_total_sequence_length: 2048
+ max_total_sequence_length: 1024
precision: "float32"

optimizer:
@@ -30,32 +30,25 @@ policy:
weight_decay: 0.1
betas: [0.9, 0.98]
eps: 1e-5

- scheduler:
-   name: "torch.optim.lr_scheduler.LinearLR"
-   kwargs:
-     start_factor: 0.0196078
-     end_factor: 1.0
-     total_iters: 50

data:
max_input_seq_length: ${policy.max_total_sequence_length}
dataset_name: "squad"

logger:
log_dir: "logs" # Base directory for all logs
- wandb_enabled: false
- tensorboard_enabled: false
+ wandb_enabled: true  # Make sure you run ``wandb login [Your API key]'' before launching
+ tensorboard_enabled: true
monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
wandb:
project: "sft-dev"
name: "sft-dev-logger"
name: "sft-dev-${data.dataset_name}"
tensorboard:
log_dir: "tb_logs"
log_dir: "tb_logs-sft-dev-${data.dataset_name}"
gpu_monitoring:
collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
- gpus_per_node: 8
+ gpus_per_node: 1
num_nodes: 1
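The `${...}` values are interpolations: `data.max_input_seq_length` tracks `policy.max_total_sequence_length`, and the wandb run name and TensorBoard directory pick up `data.dataset_name`, so switching datasets renames the logs automatically. A small, self-contained illustration of how such references resolve, assuming OmegaConf-style interpolation (which this syntax matches):

```python
from omegaconf import OmegaConf

# Trimmed-down stand-in for sft.yaml, just to show interpolation behaviour.
cfg = OmegaConf.create("""
policy:
  max_total_sequence_length: 1024
data:
  max_input_seq_length: ${policy.max_total_sequence_length}
  dataset_name: squad
logger:
  wandb:
    name: sft-dev-${data.dataset_name}
""")

print(cfg.data.max_input_seq_length)  # 1024
print(cfg.logger.wandb.name)          # sft-dev-squad
```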