Update Checkpointing to support Adapter Weights #494
Conversation
Force-pushed from f820846 to 09dd731.
Thanks for this, no major concerns from my side. Happy to have this land and then work on getting the LoRA weight merge + FSDP story sorted out as a follow-up (especially because I think it's easier to do the necessary refactor with this checkpointer integrated into the recipe). Lmk if that makes sense to you
recipes/full_finetune.py (outdated)
except KeyError as e:
    raise KeyError from e(
        "Checkpoint does not contain the required keys needed for updating recipe state."
        "Are you suare you passed in the right recipe checkpoint?"
nit: suggested change:
-        "Are you suare you passed in the right recipe checkpoint?"
+        "Are you sure you passed in the right recipe checkpoint?"
# If seed, total_epoch or max_steps_per_epoch don't match,
# warn the user and overwrite
if (
    self.seed != ckpt_dict[SEED_KEY]
nit: any reason not to use utils.SEED_KEY, etc. here, as you did in the full finetune recipe?
Ah good catch, I should update this
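For reference, a rough sketch of the check being discussed, with the constants pulled from a shared module rather than re-declared locally. The non-seed key names and the surrounding structure below are assumptions, not the actual recipe code:

import warnings
from typing import Any

# Assumed key names; in the recipes these would come from torchtune's utils module
# (utils.SEED_KEY and friends), as the reviewer suggests.
SEED_KEY = "seed"
TOTAL_EPOCHS_KEY = "total_epochs"
MAX_STEPS_KEY = "max_steps_per_epoch"


def sync_recipe_state(recipe: Any, ckpt_dict: dict) -> None:
    # If seed, total_epochs or max_steps_per_epoch don't match the checkpoint,
    # warn the user and overwrite the recipe values with the checkpointed ones.
    if (
        recipe.seed != ckpt_dict[SEED_KEY]
        or recipe.total_epochs != ckpt_dict[TOTAL_EPOCHS_KEY]
        or recipe.max_steps_per_epoch != ckpt_dict[MAX_STEPS_KEY]
    ):
        warnings.warn(
            "Config values for seed, total_epochs or max_steps_per_epoch differ "
            "from the checkpoint; using the values from the checkpoint."
        )
        recipe.seed = ckpt_dict[SEED_KEY]
        recipe.total_epochs = ckpt_dict[TOTAL_EPOCHS_KEY]
        recipe.max_steps_per_epoch = ckpt_dict[MAX_STEPS_KEY]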
checkpoint_files: [consolidated.00.pth]
adapter_checkpoint: null
recipe_checkpoint: null
output_dir: /tmp/llama2/
nit: I feel there's some potential for confusion across checkpointer.output_dir and the logger's output_dir configs (I had to look at it a couple times myself)
I originally used the same one, but I feel like users might want to output the checkpoints to the checkpoint dir since that's where we have all of the required json files. Any specific aspect which is confusing? Should I change the name?
@pytest.mark.skip(reason="This test is broken in many ways and needs to be refactored") |
tests/recipes/utils.py (outdated)
# if ckpt == "lora_small_test_ckpt":
#     return "/tmp/test-artifacts/small-ckpt-01242024"
remove?
raise ValueError(
    f"Checkpoint file {self._checkpoint_path} is not a valid checkpoint file. "
    "Checkpointer expects a valid .pt file."
)

self._adapter_checkpoint = (
We may want to consider changing the name FullModelTorchTuneCheckpointer given these changes to support adapters. Though tbh I don't have a better suggestion at the moment.
Agreed, I plan to do this as a follow-up.
    merged_state_dict,
    num_heads=self._config["num_attention_heads"],
    num_kv_heads=self._config["num_key_value_heads"],
    dim=self._config["hidden_size"],
)

if self._adapter_checkpoint:
    adapter_state_dict = safe_torch_load(self._adapter_checkpoint)
    converted_state_dict[utils.ADAPTER_KEY] = adapter_state_dict
Just wanna make sure I understand this: is the assumption here that we only load adapter checkpoints in TorchTune format (since we are not converting them)? If so, does that mean we only load them on resume?
Yes that's right, we only load them on resume. We don't do any conversions since we can't really do anything with them outside of TorchTune.
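For illustration, a simplified sketch of that load path. The function itself is hypothetical; torch.load stands in for safe_torch_load and the key literal stands in for utils.ADAPTER_KEY:

from typing import Optional

import torch

ADAPTER_KEY = "adapter"  # assumed literal; the recipes use utils.ADAPTER_KEY


def load_checkpoint_sketch(base_ckpt_path: str, adapter_ckpt_path: Optional[str]) -> dict:
    # Base model weights come from an HF / Meta checkpoint and are converted into
    # TorchTune's format as in the diff above (the conversion itself is elided here).
    converted_state_dict = {"model": torch.load(base_ckpt_path, map_location="cpu")}

    # Adapter weights only exist when resuming a previous TorchTune run and are
    # loaded verbatim: no format conversion, since they are only meaningful
    # inside TorchTune.
    if adapter_ckpt_path is not None:
        converted_state_dict[ADAPTER_KEY] = torch.load(adapter_ckpt_path, map_location="cpu")

    return converted_state_dict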
Args:
    input_dir (Path): Directory containing the file
    filename (str): Name of the file
    missing_ok (bool): Whether to raise an error if the file is missing.
nit: technically your docstring description is more for missing_not_ok
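A possible rewording along the lines of this nit, using a hypothetical stand-in for the helper being documented:

from pathlib import Path


def get_file_path(input_dir: Path, filename: str, missing_ok: bool = False) -> Path:
    """
    Hypothetical stand-in for the helper being documented.

    Args:
        input_dir (Path): Directory containing the file
        filename (str): Name of the file
        missing_ok (bool): If True, a missing file is tolerated and the path is
            returned anyway; if False (default), a missing file raises a ValueError.
    """
    file_path = input_dir / filename
    if not missing_ok and not file_path.is_file():
        raise ValueError(f"No file with name {filename} found in {input_dir}.")
    return file_path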
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/494. ✅ No failures as of commit cc9435a with merge base c60d10a. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Force-pushed from 1908bb3 to cc9435a.
Thanks for implementing this! I think we no longer need this step https://github.com/pytorch/torchtune?tab=readme-ov-file#converting-the-checkpoint-into-pytorch-native-for-lora after this change?
Context
#442 introduced the Checkpointer component and added it for full finetuning. In this PR, I extend support to the LoRA single-device recipe and update the checkpointers to handle adapter weights. There are a few changes in how we approach checkpointing for LoRA:
We always output the merged checkpoint file instead of guarding this behind a flag. This is because users can't really do anything with the adapter weights other than use them to resume a previously failed TorchTune run. For running evaluation or inference on intermediate checkpoints, users will need the merged checkpoint file.
During training, we also output the adapter weights and the recipe state. Both of these are needed to correctly resume a previously failed training run. Note that to resume the run, the user should provide the original LLM weights with the adapter and recipe checkpoint. This can potentially be confusing and needs to be clarified through documentation.
Instead of creating a new class of checkpointers, I added support for adapters to the current set. The API is quite clean.
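To make this concrete, here is a rough sketch (not the actual recipe or checkpointer code; the key literals and structure are assumptions) of what a LoRA run might hand to the checkpointer at an intermediate epoch:

# Key literals stand in for the constants in torchtune's utils module
# (e.g. utils.ADAPTER_KEY, utils.SEED_KEY); the exact checkpointer API is not shown.


def build_lora_checkpoint_dict(
    merged_model_sd: dict,  # base weights with LoRA weights merged in -- always written
    adapter_sd: dict,       # adapter weights -- only useful for resuming a TorchTune run
    seed: int,
    total_epochs: int,
    max_steps_per_epoch: int,
) -> dict:
    # The merged weights are what users run eval/inference on; the adapter weights
    # and recipe state ride along so a failed run can be resumed from the original
    # base weights plus this adapter/recipe checkpoint.
    return {
        "model": merged_model_sd,
        "adapter": adapter_sd,                        # utils.ADAPTER_KEY
        "seed": seed,                                 # utils.SEED_KEY
        "total_epochs": total_epochs,                 # assumed key name
        "max_steps_per_epoch": max_steps_per_epoch,   # assumed key name
    }

On resume, the adapter and recipe-state entries are read back while the base weights come from the original checkpoint files, per the second point above.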
Changelog

- Update the lora_finetune_single_device recipe and config. Currently not updating the distributed version since this has a bug which @ebsmothers is working on fixing.
- Mark TestLoRAFinalCheckpoints within test_lora_finetuning as skipped since this test has multiple issues. It's currently failing when we run pytest tests and is really confusing to understand and extend. This needs to be replaced with a better and cleaner test.

Test plan
Run the complete test suite:
Full Finetune with Meta format Checkpoints on single device
Full Finetune with Meta format Checkpoints on single device - resume training
Full Finetune with HF format Checkpoints on single device
Full Finetune with HF format Checkpoints on single device - resume training
LoRA Finetune with HF format Checkpoints on single device
LoRA Finetune with HF format Checkpoints on single device - resume training
LoRA Finetune with Meta format Checkpoints on single device
LoRA Finetune with Meta format Checkpoints on single device - resume training