Update Checkpointing to support Adapter Weights by kartikayk · Pull Request #494 · pytorch/torchtune · GitHub

Update Checkpointing to support Adapter Weights #494


Merged · 4 commits · Mar 14, 2024

Conversation

kartikayk (Contributor)

Context

#442 introduced the Checkpointer component and added it for full finetuning. In this PR, I extend support to the LoRA single-device recipe and update the checkpointers to handle adapter weights. There are a few changes in how we approach checkpointing for LoRA:

  1. We always output the merged checkpoint file instead of guarding this behind a flag. Users can't really do anything with the adapter weights other than use them to resume a previously failed TorchTune run; for running evaluation or inference on intermediate checkpoints, they need the merged checkpoint file.

  2. During training, we also output the adapter weights and the recipe state. Both are needed to correctly resume a previously failed training run. Note that to resume the run, the user should provide the original LLM weights along with the adapter and recipe checkpoints. This can be confusing and needs to be clarified in the documentation.

Instead of creating a new class of checkpointers, I added support for adapters to the current set. The API is quite clean.
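
To make the two rules above concrete, here is a minimal sketch (not torchtune's actual code) of how a LoRA recipe might assemble the checkpoint dict it hands to the checkpointer. The constant values, the merge_lora_into_base helper, and the intermediate_checkpoint flag are assumptions for illustration only.

from typing import Dict

import torch

# Assumed stand-ins for the shared constants this PR introduces in torchtune.utils
MODEL_KEY = "model"
ADAPTER_KEY = "adapter"
OPT_KEY = "optimizer"
SEED_KEY = "seed"


def merge_lora_into_base(
    base_sd: Dict[str, torch.Tensor], adapter_sd: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Hypothetical stand-in for the LoRA weight-merge utility.

    For each LoRA-adapted linear layer the merged weight is
    W_merged = W + (alpha / rank) * B @ A.
    """
    merged = dict(base_sd)
    ...  # merge math elided
    return merged


def build_checkpoint_dict(
    base_sd: Dict[str, torch.Tensor],
    adapter_sd: Dict[str, torch.Tensor],
    optimizer: torch.optim.Optimizer,
    seed: int,
    intermediate_checkpoint: bool,
) -> Dict:
    # (1) Always include the merged weights so that any checkpoint can be used
    #     directly for evaluation or inference.
    ckpt = {MODEL_KEY: merge_lora_into_base(base_sd, adapter_sd)}
    # (2) While training is still in flight, also include the adapter weights and
    #     the recipe state needed to correctly resume a failed run.
    if intermediate_checkpoint:
        ckpt[ADAPTER_KEY] = adapter_sd
        ckpt[OPT_KEY] = optimizer.state_dict()
        ckpt[SEED_KEY] = seed
    return ckpt

The resume configs in the test plan below point at the files this scheme produces: the merged model weights (for example meta_model_0.pt or the hf_model_000*_0.pt shards), adapter_0.pt for the adapter weights, and recipe_state.pt for the recipe state.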

Changelog

  • Update the lora_finetune_single_device recipe and config. I'm not updating the distributed version yet since it has a bug that @ebsmothers is working on fixing.
  • Extend the current checkpointers to handle adapter weights. This likely means the checkpointer names should be updated, but I'll do that as a follow-up.
  • Remove the hard-coded strings and replace them with constants.
  • Mark TestLoRAFinalCheckpoints within test_lora_finetuning as skipped since this test has multiple issues: it currently fails when running pytest tests and is confusing to understand and extend. It needs to be replaced with a better, cleaner test.

Test plan

Run the complete test suite:

pytest tests

Full Finetune with Meta format Checkpoints on single device

tune --nnodes 1 --nproc_per_node 1 recipes/full_finetune.py \
--config recipes/configs/alpaca_llama2_full_finetune.yaml \
--override seed=30 epochs=2 batch_size=2 \
enable_activation_checkpointing=False \
enable_fsdp=False \
max_steps_per_epoch=4

# Checkpointer config:
checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/llama2
  checkpoint_files: [consolidated.00.pth]
  output_dir: /tmp/llama2
  model_type: LLAMA2
resume_from_checkpoint: False

Full Finetune with Meta format Checkpoints on single device - resume training

tune --nnodes 1 --nproc_per_node 1 recipes/full_finetune.py \
--config recipes/configs/alpaca_llama2_full_finetune.yaml \
--override seed=30 epochs=2 batch_size=2 \
enable_activation_checkpointing=False \
enable_fsdp=False \
max_steps_per_epoch=4

# Checkpointer config:
checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/llama2
  checkpoint_files: [meta_model_0.pt]
  recipe_checkpoint: recipe_state.pt
  output_dir: /tmp/llama2
  model_type: LLAMA2
resume_from_checkpoint: True

Full Finetune with HF format Checkpoints on single device

tune --nnodes 1 --nproc_per_node 1 recipes/full_finetune.py \
--config recipes/configs/alpaca_llama2_full_finetune.yaml \
--override seed=30 epochs=2 batch_size=2 \
enable_activation_checkpointing=False \
enable_fsdp=False \
max_steps_per_epoch=4

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/llama2-hf
  checkpoint_files: [pytorch_model-00001-of-00002.bin, pytorch_model-00002-of-00002.bin]
  adapter_checkpoint: null
  recipe_checkpoint: null
  output_dir: /tmp/llama2-hf
  model_type: LLAMA2
resume_from_checkpoint: False

Full Finetune with HF format Checkpoints on single device - resume training

tune --nnodes 1 --nproc_per_node 1 recipes/full_finetune.py \
--config recipes/configs/alpaca_llama2_full_finetune.yaml \
--override seed=30 epochs=2 batch_size=2 \
enable_activation_checkpointing=False \
enable_fsdp=False \
max_steps_per_epoch=4

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/llama2-hf
  checkpoint_files: [hf_model_0001_0.pt, hf_model_0002_0.pt]
  adapter_checkpoint: null
  recipe_checkpoint: recipe_state.pt
  output_dir: /tmp/llama2-hf
  model_type: LLAMA2
resume_from_checkpoint: True

LoRA Finetune with HF format Checkpoints on single device

tune --nnodes 1 --nproc_per_node 1 recipes/lora_finetune_single_device.py \
--config recipes/configs/alpaca_llama2_lora_finetune_single_device.yaml \
--override seed=30 epochs=2 batch_size=2 \
enable_activation_checkpointing=False \
enable_fsdp=False max_steps_per_epoch=4

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/llama2-hf
  checkpoint_files: [pytorch_model-00001-of-00002.bin, pytorch_model-00002-of-00002.bin]
  adapter_checkpoint: null
  recipe_checkpoint: null
  output_dir: /tmp/llama2-hf
  model_type: LLAMA2
resume_from_checkpoint: False

LoRA Finetune with HF format Checkpoints on single device - resume training

tune --nnodes 1 --nproc_per_node 1 recipes/lora_finetune_single_device.py \
--config recipes/configs/alpaca_llama2_lora_finetune_single_device.yaml \
--override seed=30 epochs=2 batch_size=2 \
enable_activation_checkpointing=False \
enable_fsdp=False max_steps_per_epoch=4

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/llama2-hf
  checkpoint_files: [hf_model_0001_0.pt, hf_model_0002_0.pt]
  adapter_checkpoint: adapter_0.pt
  recipe_checkpoint: recipe_state.pt
  output_dir: /tmp/llama2-hf
  model_type: LLAMA2
resume_from_checkpoint: True

LoRA Finetune with Meta format Checkpoints on single device

tune --nnodes 1 --nproc_per_node 1 recipes/lora_finetune_single_device.py \
--config recipes/configs/alpaca_llama2_lora_finetune_single_device.yaml \
--override seed=30 epochs=2 batch_size=2 \
enable_activation_checkpointing=False \
enable_fsdp=False max_steps_per_epoch=4

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/llama2
  checkpoint_files: [consolidated.00.pth]
  adapter_checkpoint: null
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA2
resume_from_checkpoint: False

LoRA Finetune with Meta format Checkpoints on single device - resume training

tune --nnodes 1 --nproc_per_node 1 recipes/lora_finetune_single_device.py \
--config recipes/configs/alpaca_llama2_lora_finetune_single_device.yaml \
--override seed=30 epochs=2 batch_size=2 \
enable_activation_checkpointing=False \
enable_fsdp=False max_steps_per_epoch=4

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/llama2
  checkpoint_files: [meta_model_0.pt]
  adapter_checkpoint: adapter_0.pt
  recipe_checkpoint: recipe_state.pt
  output_dir: ${output_dir}
  model_type: LLAMA2
resume_from_checkpoint: True

kartikayk requested a review from ebsmothers on Mar 13, 2024.

ebsmothers (Contributor) left a comment:


Thanks for this, no major concerns from my side. Happy to have this land and then work on getting the LoRA weight merge + FSDP story sorted out as a follow-up (especially because I think it's easier to do the necessary refactor with this checkpointer integrated into the recipe). Lmk if that makes sense to you

        except KeyError as e:
            raise KeyError from e(
                "Checkpoint does not contain the required keys needed for updating recipe state."
                "Are you suare you passed in the right recipe checkpoint?"
ebsmothers (Contributor):


nit

Suggested change
"Are you suare you passed in the right recipe checkpoint?"
"Are you sure you passed in the right recipe checkpoint?"

        # If seed, total_epoch or max_steps_per_epoch don't match,
        # warn the user and overwrite
        if (
            self.seed != ckpt_dict[SEED_KEY]
ebsmothers (Contributor):


nit: any reason not to use utils.SEED_KEY, etc here, as you did in the full finetune recipe?

kartikayk (Author):


Ah good catch, I should update this

checkpoint_files: [consolidated.00.pth]
adapter_checkpoint: null
recipe_checkpoint: null
output_dir: /tmp/llama2/
ebsmothers (Contributor):


nit: I feel there's some potential for confusion across checkpointer.output_dir and the logger's output_dir configs (I had to look at it a couple times myself)

kartikayk (Author):


I originally used the same, but I feel like users might want to output the checkpoints to the checkpoint dir since that's where we have all of the required JSON files. Any specific aspect which is confusing? Should I change the name?



Comment on lines 115 to 116
    # if ckpt == "lora_small_test_ckpt":
    #     return "/tmp/test-artifacts/small-ckpt-01242024"
ebsmothers (Contributor):


remove?

            raise ValueError(
                f"Checkpoint file {self._checkpoint_path} is not a valid checkpoint file. "
                "Checkpointer expects a valid .pt file."
            )

        self._adapter_checkpoint = (
ebsmothers (Contributor):


We may want to consider changing the name FullModelTorchTuneCheckpointer given these changes to support adapters. Though tbh I don't have a better suggestion at the moment

kartikayk (Author):


Agreed, I plan to do this as a follow-up.

            merged_state_dict,
            num_heads=self._config["num_attention_heads"],
            num_kv_heads=self._config["num_key_value_heads"],
            dim=self._config["hidden_size"],
        )

        if self._adapter_checkpoint:
            adapter_state_dict = safe_torch_load(self._adapter_checkpoint)
            converted_state_dict[utils.ADAPTER_KEY] = adapter_state_dict
ebsmothers (Contributor):


Just wanna make sure I understand this: is the assumption here that we only load adapter checkpoints in TorchTune format (since we are not converting them)? If so, does that mean we only load them on resume?

kartikayk (Author):


Yes, that's right: we only load them on resume. We don't do any conversions since we can't really do anything with them outside of TorchTune.
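
A minimal sketch of the load path being discussed, based on the quoted hunk; the surrounding structure and the safe_torch_load stand-in are assumptions, not the checkpointer's exact code. The point is that the base weights go through HF-to-TorchTune key conversion while the adapter file is loaded verbatim, since it is only ever consumed by TorchTune itself when resuming a run:

from typing import Any, Dict, Optional

import torch

ADAPTER_KEY = "adapter"  # stands in for utils.ADAPTER_KEY


def safe_torch_load(checkpoint_path: str) -> Dict[str, torch.Tensor]:
    """Stand-in for torchtune's safe_torch_load: load a .pt file onto CPU."""
    return torch.load(checkpoint_path, map_location="cpu", weights_only=True)


def attach_adapter_weights(
    converted_state_dict: Dict[str, Any], adapter_checkpoint: Optional[str]
) -> None:
    # converted_state_dict already holds the base weights mapped from the HF
    # naming convention; the adapter weights skip that conversion entirely.
    if adapter_checkpoint:
        converted_state_dict[ADAPTER_KEY] = safe_torch_load(adapter_checkpoint)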

    Args:
        input_dir (Path): Directory containing the file
        filename (str): Name of the file
        missing_ok (bool): Whether to raise an error if the file is missing.
ebsmothers (Contributor):


nit: technically your docstring description is more for missing_not_ok
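
For reference, a hedged sketch of a helper matching the intent of this docstring once the nit is addressed; the function name and the error type are assumptions, not the PR's actual code:

from pathlib import Path


def get_path(input_dir: Path, filename: str, missing_ok: bool = False) -> Path:
    """Return the path to filename inside input_dir.

    Args:
        input_dir (Path): Directory containing the file
        filename (str): Name of the file
        missing_ok (bool): If False, raise an error when the file is missing;
            if True, return the path regardless of whether the file exists.
    """
    file_path = input_dir / filename
    if not missing_ok and not file_path.is_file():
        raise ValueError(f"No file with name {filename} found in {input_dir}.")
    return file_path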

pytorch-bot commented on Mar 14, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/494

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cc9435a with merge base c60d10a:
💚 Looks good so far! There are no failures yet. 💚


SLR722 (Contributor) commented on Mar 16, 2024:

Thanks for implementing this! I think we no longer need this step https://github.com/pytorch/torchtune?tab=readme-ov-file#converting-the-checkpoint-into-pytorch-native-for-lora after this change?

kartikayk (Author):

@SLR722 we still needed this step after this PR until #506 landed. Now that #506 has landed, we don't need it anymore; I'll update the README shortly.
