fix: chat template improvements by ashors1 · Pull Request #148 · NVIDIA-NeMo/RL

Merged · 29 commits · Apr 16, 2025

Commits
77b03d7
make chat template configurable from config, save chat template as at…
ashors1 Apr 5, 2025
6116f77
Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/chat-…
ashors1 Apr 7, 2025
4f17127
save hf tokenizer
ashors1 Apr 7, 2025
525296a
add sft example with json data
ashors1 Apr 4, 2025
ce04b9b
improved configurability
ashors1 Apr 7, 2025
1597d0f
fixes
ashors1 Apr 8, 2025
70f11a3
update grpo and clean up
ashors1 Apr 8, 2025
83552a6
fix unit tests
ashors1 Apr 8, 2025
23f06d2
address comments
ashors1 Apr 9, 2025
1816456
Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/chat-…
ashors1 Apr 10, 2025
58fe349
add unit tests and documentation
ashors1 Apr 11, 2025
956c7f0
copyright header
ashors1 Apr 11, 2025
4913f25
address comments
ashors1 Apr 11, 2025
c016902
small fixes
ashors1 Apr 11, 2025
326b151
fix typo
ashors1 Apr 11, 2025
5068487
fix tests
ashors1 Apr 11, 2025
8e835ba
update chat template documentation
ashors1 Apr 15, 2025
0eeaa1e
Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/chat-…
ashors1 Apr 15, 2025
d95d88b
fix unit tests
ashors1 Apr 15, 2025
1958bb9
fix doctest
ashors1 Apr 15, 2025
c536b55
Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/chat-…
ashors1 Apr 16, 2025
2c7c5c5
fix checkpoint save when tokenizer not provided
ashors1 Apr 16, 2025
1df9a9f
feat: add a unique seed for each vllm llm engine (#171)
parthchadha Apr 15, 2025
f35ad95
fix: unit test script halts on first failure (#189)
terrykong Apr 15, 2025
844e470
fix new vllm test and doctest
ashors1 Apr 16, 2025
7e50e8e
Merge branch 'main' of github.com:NVIDIA/reinforcer into ashors/chat-…
ashors1 Apr 16, 2025
ac4b6ea
remove old comment
ashors1 Apr 16, 2025
c5328f0
fix doctest
ashors1 Apr 16, 2025
5c4f849
Merge branch 'main' into ashors/chat-template-improvements
SahilJain314 Apr 16, 2025
26 changes: 15 additions & 11 deletions docs/guides/sft.md
@@ -27,7 +27,7 @@ uv run examples/run_sft.py \

SFT datasets in Reinforcer are encapsulated using classes. Each SFT data class is expected to have the following attributes:
1. `formatted_ds`: The dictionary of formatted datasets. This dictionary should contain `train` and `validation` splits, and each split should conform to the format described below.
2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset as well as the `custom_template` for this dataset. More on custom templates below.
2. `task_spec`: The `TaskDataSpec` for this dataset. This should specify the name you choose for this dataset.

SFT datasets are expected to follow the HuggingFace chat format. Refer to the [chat dataset document](../design-docs/chat-datasets.md) for details. If your data is not in the correct format, simply write a preprocessing script to convert the data into this format. [data/hf_datasets/squad.py](../../nemo_reinforcer/data/hf_datasets/squad.py) has an example:

@@ -51,17 +51,21 @@ def format_squad(data):
}
```

Reinforcer SFT uses HuggingFace chat templates to format the individual examples. If you would like to use a custom template, create a string template in [jinja format](https://huggingface.co/docs/transformers/v4.34.0/en/chat_templating#how-do-i-create-a-chat-template) and pass it to the dataset's `TaskDataSpec`. For example,
Reinforcer SFT uses HuggingFace chat templates to format the individual examples. Three types of chat templates are supported, which can be configured via `tokenizer.chat_template` in your yaml config (see [sft.yaml](../../examples/configs/sft.yaml) for an example):

1. Apply the tokenizer's default chat template. To use the tokenizer's default, either omit `tokenizer.chat_template` from the config altogether, or set `tokenizer.chat_template="default"`.
2. Use a "passthrough" template which simply concatenates all messages. This is desirable if the chat template has been applied to your dataset as an offline preprocessing step. In this case, you should set `tokenizer.chat_template` to None as follows:
```yaml
tokenizer:
  chat_template: NULL
```
3. Use a custom template: create a string template in [jinja format](https://huggingface.co/docs/transformers/v4.34.0/en/chat_templating#how-do-i-create-a-chat-template) and add that string to the config. For example,

```yaml
tokenizer:
custom_template: "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer: '}}{%- elif message['role'] == 'assistant' %}{{message['content'].strip()}}{%- endif %}{% endfor %}"
```
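To sanity-check a custom template before wiring it into training, you can render it directly with `jinja2` (a minimal sketch, assuming the squad-style template above; the example messages are illustrative):

```python
from jinja2 import Template

# The squad-style custom template from the config above.
custom_template = (
    "{% for message in messages %}"
    "{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}"
    "{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer: '}}"
    "{%- elif message['role'] == 'assistant' %}{{message['content'].strip()}}"
    "{%- endif %}{% endfor %}"
)

# Illustrative squad-style messages.
messages = [
    {"role": "system", "content": "The Norman conquest began in 1066."},
    {"role": "user", "content": "When did the conquest begin?"},
    {"role": "assistant", "content": "1066"},
]

print(Template(custom_template).render(messages=messages))
# Context: The Norman conquest began in 1066. Question: When did the conquest begin? Answer: 1066
```

`tokenizer.apply_chat_template` renders the same template (with extra variables such as `bos_token` available), so the output above should match what the tokenizer produces for these messages.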

The previous approach, removed by this PR, passed a custom template to the dataset's `TaskDataSpec`:

```python
custom_template = (
    "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer: '}}{%- elif message['role'] == 'assistant' %}{{message['content'].strip()}}{%- endif %}{% endfor %}"
)
task_spec = TaskDataSpec(
    task_name="squad",
    custom_template=custom_template,
)
```

By default, NeMo-Reinforcer has support for `Squad` and `OpenAssistant` datasets. Both of these datasets are downloaded from HuggingFace and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk.

3 changes: 2 additions & 1 deletion examples/configs/grpo_math_1B.yaml
@@ -25,7 +25,8 @@ checkpointing:

policy:
  model_name: "meta-llama/Llama-3.2-1B-Instruct"
  tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
  tokenizer:
    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
  train_global_batch_size: 512
  train_micro_batch_size: 4
  generation_batch_size: 32 # Only used when generating using HF backend
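For anyone carrying local copies of these configs, the same one-line migration applies to each config touched in this PR: the flat `tokenizer_name` key becomes a nested `tokenizer` section. A before/after sketch:

```yaml
# before
policy:
  tokenizer_name: ${policy.model_name}

# after
policy:
  tokenizer:
    name: ${policy.model_name}
```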
3 changes: 2 additions & 1 deletion examples/configs/grpo_math_8B.yaml
@@ -7,7 +7,8 @@ grpo:

policy:
  model_name: "meta-llama/Llama-3.1-8B-Instruct"
  tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
  tokenizer:
    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
  train_global_batch_size: 512
  train_micro_batch_size: 1
  generation_batch_size: 32 # Only used when generating using HF backend
6 changes: 5 additions & 1 deletion examples/configs/sft.yaml
@@ -18,7 +18,9 @@ checkpointing:

policy:
  model_name: "meta-llama/Llama-3.2-1B"
  tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
  tokenizer:
    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
    chat_template: "{% for message in messages %}{%- if message['role'] == 'system' %}{{'Context: ' + message['content'].strip()}}{%- elif message['role'] == 'user' %}{{' Question: ' + message['content'].strip() + ' Answer:'}}{%- elif message['role'] == 'assistant' %}{{' ' + message['content'].strip()}}{%- endif %}{% endfor %}"
  train_global_batch_size: 32
  train_micro_batch_size: 1
  max_total_sequence_length: 1024
@@ -35,6 +37,8 @@ policy:
data:
  max_input_seq_length: ${policy.max_total_sequence_length}
  dataset_name: "squad"
  add_bos: true
  add_eos: true

logger:
  log_dir: "logs" # Base directory for all logs
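Conversely, if the chat template was already applied to the data offline (option 2 in the docs above), a config using these same keys would disable templating and keep the BOS/EOS flags explicit. A sketch showing only the relevant keys:

```yaml
policy:
  tokenizer:
    name: ${policy.model_name}
    chat_template: NULL  # passthrough: concatenate message contents as-is

data:
  add_bos: true
  add_eos: true
```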
4 changes: 1 addition & 3 deletions examples/run_grpo_math.py
@@ -64,15 +64,13 @@ def openinstructmath2_data_processor(
    problem = user_message[0]["content"]
    extra_env_info = {"ground_truth": user_message[1]["content"]}

    template = task_data_spec.custom_template
    message_log: LLMMessageLogType = []
    user_message = {
        "role": "user",
        "content": task_data_spec.prompt.format(problem),
    }
    message = tokenizer.apply_chat_template(
        [user_message],
        chat_template=template,
        tokenize=False,
        add_generation_prompt=True,
        add_special_tokens=False,
@@ -254,7 +252,7 @@ def main():
    init_ray()

    # setup tokenizer
    tokenizer = get_tokenizer(config["policy"]["model_name"])
    tokenizer = get_tokenizer(config["policy"]["tokenizer"])
    config["policy"]["generation"] = configure_generation_config(
        config["policy"]["generation"], tokenizer
    )
33 changes: 27 additions & 6 deletions examples/run_sft.py
@@ -15,6 +15,7 @@
import argparse
import os
import pprint
from functools import partial
from typing import Dict, Any

from omegaconf import OmegaConf
@@ -56,10 +57,16 @@ def sft_preprocessor(
    tokenizer,
    max_seq_length: int,
    idx: int,
    add_bos: bool = True,
    add_eos: bool = True,
) -> DatumSpec:
    """Process a datum dictionary for SFT training."""
    message_log = get_formatted_message_log(
        datum_dict["messages"], tokenizer, task_data_spec
        datum_dict["messages"],
        tokenizer,
        task_data_spec,
        add_bos_token=add_bos,
        add_eos_token=add_eos,
    )

    length = sum(len(m["token_ids"]) for m in message_log)
@@ -90,6 +97,13 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
        data = hf_datasets.OasstDataset(output_dir="/tmp/open_assistant")
    elif data_cls == "squad":
        data = hf_datasets.SquadDataset()
    elif data_cls == "prompt_response_dataset":
        data = hf_datasets.PromptResponseDataset(
            data_config["train_data_path"],
            data_config["val_data_path"],
            data_config["input_key"],
            data_config["output_key"],
        )
    else:
        raise ValueError(f"Unknown dataset class: {data_cls}")
    print(
@@ -104,15 +118,23 @@
        train_dataset,
        tokenizer,
        sft_task_spec,
        sft_preprocessor,
        partial(
            sft_preprocessor,
            add_bos=data_config["add_bos"],
            add_eos=data_config["add_eos"],
        ),
        max_seq_length=data_config["max_input_seq_length"],
    )

    val_dataset = AllTaskProcessedDataset(
        val_dataset,
        tokenizer,
        sft_task_spec,
        sft_preprocessor,
        partial(
            sft_preprocessor,
            add_bos=data_config.get("add_bos", True),
            add_eos=data_config.get("add_eos", True),
        ),
        max_seq_length=data_config["max_input_seq_length"],
    )

@@ -151,7 +173,7 @@ def main():
    init_ray()

    # setup tokenizer
    tokenizer = get_tokenizer(config["policy"]["model_name"])
    tokenizer = get_tokenizer(config["policy"]["tokenizer"])

    # setup data
    (
@@ -170,8 +192,7 @@
        checkpointer,
        sft_save_state,
        master_config,
    ) = setup(config, dataset, val_dataset)

    ) = setup(config, tokenizer, dataset, val_dataset)
    sft_train(
        policy,
        train_dataloader,
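The new `prompt_response_dataset` branch in `setup_data` reads four additional keys from the data config. A hypothetical config for it might look like the following (the paths and key names are placeholders; `dataset_name` and the four keys are taken from the code above):

```yaml
data:
  dataset_name: "prompt_response_dataset"
  train_data_path: "/path/to/train.json"
  val_data_path: "/path/to/val.json"
  input_key: "prompt"
  output_key: "response"
  add_bos: true
  add_eos: true
  max_input_seq_length: ${policy.max_total_sequence_length}
```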
4 changes: 4 additions & 0 deletions nemo_reinforcer/algorithms/grpo.py
@@ -236,6 +236,7 @@ def setup(
    policy = HfPolicy(
        cluster=cluster,
        config=policy_config,
        tokenizer=tokenizer,
        weights_path=Path(last_checkpoint_path) / "policy" / "weights"
        if last_checkpoint_path
        else None,
@@ -628,6 +629,9 @@
optimizer_path=os.path.join(
checkpoint_path, "policy", "optimizer"
),
tokenizer_path=os.path.join(
checkpoint_path, "policy", "tokenizer"
),
save_hf=is_last_checkpoint,
)
torch.save(
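With the new `tokenizer_path` argument, each checkpoint now saves the tokenizer next to the weights and optimizer state. Based on the paths constructed above, the resulting layout is roughly:

```
<checkpoint_path>/
└── policy/
    ├── weights/
    ├── optimizer/
    └── tokenizer/
```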
6 changes: 6 additions & 0 deletions nemo_reinforcer/algorithms/sft.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from transformers import AutoTokenizer
from pathlib import Path
from typing import Optional, Tuple, TypedDict

@@ -76,6 +77,7 @@ class MasterConfig(TypedDict):
# =======================================================
def setup(
    master_config: MasterConfig,
    tokenizer: AutoTokenizer,
    train_dataset: AllTaskProcessedDataset,
    val_dataset: AllTaskProcessedDataset,
) -> Tuple[
@@ -175,6 +177,7 @@ def setup(
    policy = HfPolicy(
        cluster=cluster,
        config=policy_config,
        tokenizer=tokenizer,
        weights_path=Path(last_checkpoint_path) / "policy" / "weights"
        if last_checkpoint_path
        else None,
@@ -416,6 +419,9 @@
optimizer_path=os.path.join(
checkpoint_path, "policy", "optimizer"
),
tokenizer_path=os.path.join(
checkpoint_path, "policy", "tokenizer"
),
save_hf=is_last_checkpoint,
)
torch.save(
77 changes: 74 additions & 3 deletions nemo_reinforcer/algorithms/utils.py
@@ -20,6 +20,9 @@
from torch.masked import as_masked_tensor
from transformers import AutoTokenizer

from nemo_reinforcer.data import hf_datasets
from nemo_reinforcer.models.policy import TokenizerConfig


def calculate_kl_penalty_joschu2020(
    logprobs_policy: torch.Tensor, logprobs_reference: torch.Tensor
@@ -133,9 +136,77 @@ def set_seed(seed: int):
    torch.cuda.manual_seed_all(seed)


def get_tokenizer(model_name: str) -> AutoTokenizer:
    """Get the tokenizer and set pad token to eos token if it is not already set."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
def get_tokenizer(tokenizer_config: TokenizerConfig) -> AutoTokenizer:
    """Get the tokenizer and set pad token to eos token if it is not already set.

    This function initializes a tokenizer from the Hugging Face transformers library
    and configures it with appropriate chat templates and padding tokens.

    Args:
        tokenizer_config: A dictionary containing tokenizer configuration.
            Required keys:
            - name: The name or path of the pretrained tokenizer
            Optional keys:
            - chat_template: The chat template to use. Can be:
                - None: Uses a passthrough template that just returns message content
                - "default": Uses the tokenizer's default template
                - A custom jinja2 template string
              If not specified, the tokenizer's default template will be used.

    Returns:
        AutoTokenizer: The configured tokenizer instance

    Examples:
        ```{doctest}
        >>> from transformers import AutoTokenizer
        >>> from nemo_reinforcer.algorithms.utils import get_tokenizer
        >>> # not specifying a chat template uses the tokenizer's default
        >>> config = {"name": "meta-llama/Llama-3.2-1B-Instruct"}
        >>> tokenizer = get_tokenizer(config)
        No chat template provided, using tokenizer's default
        >>> messages = [
        ...     {"role": "system", "content": "You are a helpful AI assistant."},
        ...     {"role": "user", "content": "Hello!"}
        ... ]
        >>> formatted = tokenizer.apply_chat_template(messages, tokenize=False)
        >>> assert formatted == AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct").apply_chat_template(messages, tokenize=False)

        >>> # Using a passthrough template
        >>> config = {
        ...     "name": "meta-llama/Llama-3.2-1B-Instruct",
        ...     "chat_template": None
        ... }
        >>> tokenizer = get_tokenizer(config)
        Using passthrough chat template
        >>> formatted = tokenizer.apply_chat_template(messages, tokenize=False)
        >>> assert formatted == "".join(msg["content"] for msg in messages)

        >>> # Using a custom template
        >>> config = {
        ...     "name": "meta-llama/Llama-3.2-1B-Instruct",
        ...     "chat_template": "{% for message in messages %}{{ ' START: ' + message['content'] + ' END.' }}{% endfor %}"
        ... }
        >>> tokenizer = get_tokenizer(config)
        Using custom chat template
        >>> formatted = tokenizer.apply_chat_template(messages, tokenize=False)
        >>> assert formatted == " START: You are a helpful AI assistant. END. START: Hello! END."
        ```
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_config["name"])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if "chat_template" in tokenizer_config:
        if tokenizer_config["chat_template"] is None:
            print("Using passthrough chat template")
            tokenizer.chat_template = (
                hf_datasets.COMMON_CHAT_TEMPLATES.passthrough_prompt_response
            )
        elif tokenizer_config["chat_template"].lower() == "default":
            print("Using tokenizer's default chat template")
        else:
            print("Using custom chat template")
            tokenizer.chat_template = tokenizer_config["chat_template"]
    else:
        print("No chat template provided, using tokenizer's default")

    return tokenizer
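For context, the training scripts above now hand the entire `tokenizer` sub-config to this function. A minimal end-to-end sketch follows; the `TokenizerConfigSketch` TypedDict is an assumption about the shape of `TokenizerConfig` (defined in `nemo_reinforcer/models/policy.py`, which is not shown in this diff):

```python
from typing import Optional, TypedDict

from omegaconf import OmegaConf

from nemo_reinforcer.algorithms.utils import get_tokenizer


# Assumed shape of TokenizerConfig: "name" is required, "chat_template" optional.
class TokenizerConfigSketch(TypedDict, total=False):
    name: str
    chat_template: Optional[str]


# Mirrors run_sft.py: load the yaml, resolve interpolations, and pass the
# tokenizer sub-config straight through.
config = OmegaConf.to_container(OmegaConf.load("examples/configs/sft.yaml"), resolve=True)
tokenizer = get_tokenizer(config["policy"]["tokenizer"])
```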
2 changes: 2 additions & 0 deletions nemo_reinforcer/data/__init__.py
@@ -21,6 +21,8 @@ class DataConfig(TypedDict):
    system_prompt_file: Optional[str]
    dataset_name: str
    val_dataset_name: Optional[str]
    add_bos: Optional[bool]
    add_eos: Optional[bool]


class MathDataConfig(DataConfig):
11 changes: 10 additions & 1 deletion nemo_reinforcer/data/hf_datasets/__init__.py
@@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_reinforcer.data.hf_datasets.prompt_response_dataset import (
    PromptResponseDataset,
)
from nemo_reinforcer.data.hf_datasets.oasst import OasstDataset
from nemo_reinforcer.data.hf_datasets.squad import SquadDataset
from nemo_reinforcer.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES

__all__ = ["OasstDataset", "SquadDataset"]
__all__ = [
    "OasstDataset",
    "PromptResponseDataset",
    "SquadDataset",
    "COMMON_CHAT_TEMPLATES",
]
nemo_reinforcer/data/hf_datasets/chat_templates.py
@@ -12,28 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Any, Optional
from nemo_reinforcer.data.interfaces import TaskDataSpec


## a reference to frequently used chat templates for convenience
class COMMON_CHAT_TEMPLATES:
    ### simple template which prepends a role header to the content
    simple_role_header = "{% for message in messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"


class HfDataset:
    """Interface for HuggingFace datasets."""

    formatted_ds: Dict[str, Any]

    def __init__(
        self,
        dataset_name: str,
        custom_template: Optional[
            str
        ] = None,  ## "None" means use HuggingFace's tokenizer's template
    ):
        self.task_spec = TaskDataSpec(
            task_name=dataset_name,
            custom_template=custom_template,
        )

    ### passthrough template which just concatenates the content of the messages with no special tokens
    passthrough_prompt_response = (
        "{% for message in messages %}{{ message['content'] }}{% endfor %}"
    )
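As a quick illustration, the passthrough template reduces `apply_chat_template` to plain concatenation (a sketch; the model name and messages are placeholders):

```python
from transformers import AutoTokenizer

from nemo_reinforcer.data import hf_datasets

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer.chat_template = hf_datasets.COMMON_CHAT_TEMPLATES.passthrough_prompt_response

messages = [
    {"role": "user", "content": "2 + 2 = "},
    {"role": "assistant", "content": "4"},
]
# Message contents are concatenated with no special tokens added.
assert tokenizer.apply_chat_template(messages, tokenize=False) == "2 + 2 = 4"
```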