Description
For example, the following command trains without raising an error:
uv run examples/run_grpo_math.py --config=examples/configs/grpo_math_1B.yaml grpo.val_at_start=False checkpointing.enabled=False logger.wandb_enabled=False cluster.gpus_per_node=2 policy.model_name=Qwen/Qwen2.5-1.5B policy.dtensor_cfg.enabled=False
Changing cluster.gpus_per_node to 1 correctly raises the error:
uv run examples/run_grpo_math.py --config=examples/configs/grpo_math_1B.yaml grpo.val_at_start=False checkpointing.enabled=False logger.wandb_enabled=False cluster.gpus_per_node=1 policy.model_name=Qwen/Qwen2.5-1.5B policy.dtensor_cfg.enabled=False
This seems to be because find_tied_parameters (from accelerate) doesn't work correctly with FSDP-wrapped models. From my testing, transformers.modeling_utils._get_tied_weight_keys does handle this case correctly.
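
As a rough illustration, the sketch below contrasts the two checks on a model with tied input/output embeddings. It assumes accelerate's find_tied_parameters and the private transformers.modeling_utils._get_tied_weight_keys helper; the FSDP wrapping is left commented out because it needs an initialized distributed process group, and the exact outputs may vary by library version.

```python
# Sketch comparing the two tied-weight checks; the model choice and the FSDP
# wrapping here are illustrative assumptions, not the exact repro above.
from transformers import AutoModelForCausalLM
from transformers.modeling_utils import _get_tied_weight_keys  # private API
from accelerate.utils import find_tied_parameters

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B")

# On the plain model, both report the tied lm_head/embed_tokens pair.
print(find_tied_parameters(model))
print(_get_tied_weight_keys(model))

# After FSDP wrapping (requires torch.distributed to be initialized),
# find_tied_parameters can come back empty because it inspects the wrapped
# module's (flattened/sharded) parameters, whereas _get_tied_weight_keys
# walks the module tree and reads the _tied_weights_keys metadata, so it
# still reports the tie.
# from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# fsdp_model = FSDP(model)
# print(find_tied_parameters(fsdp_model))   # may be [] -- the bug above
# print(_get_tied_weight_keys(fsdp_model))  # still lists the tied keys
```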