feat: Add deepscaler dataset by abukharin-nv · Pull Request #335 · NVIDIA-NeMo/RL · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
<div hidden="hidden" data-view-component="true" class="js-stale-session-flash stale-session-flash flash flash-warn flash-full"> Dismiss alert

feat: Add deepscaler dataset #335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions examples/run_grpo_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.data import DataConfig
from nemo_rl.data.datasets import AllTaskProcessedDataset
from nemo_rl.data.hf_datasets.deepscaler import DeepScalerDataset
from nemo_rl.data.hf_datasets.openmathinstruct2 import OpenMathInstruct2Dataset
from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType, TaskDataSpec
from nemo_rl.distributed.virtual_cluster import init_ray
Expand Down Expand Up @@ -52,7 +53,7 @@ def parse_args():
# ===============================================================================


def openinstructmath2_data_processor(
def hf_data_processor(
datum_dict: Dict[str, Any],
task_data_spec: TaskDataSpec,
tokenizer,
Expand Down Expand Up @@ -179,13 +180,16 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig, env_configs):
if data_config["dataset_name"] == "OpenMathInstruct-2":
print("Loading nvidia/OpenMathInstruct2Dataset for training and validation")
data = OpenMathInstruct2Dataset()
elif data_config["dataset_name"] == "DeepScaler":
print(
"Loading agentica-org/DeepScaleR-Preview-Dataset for training and validation"
)
data = DeepScalerDataset()
else:
raise ValueError(f"No processor for dataset {data_config['dataset_name']}.")

task_data_processors = defaultdict(
lambda: (math_task_spec, openinstructmath2_data_processor)
)
task_data_processors["math"] = (math_task_spec, openinstructmath2_data_processor)
task_data_processors = defaultdict(lambda: (math_task_spec, hf_data_processor))
task_data_processors["math"] = (math_task_spec, hf_data_processor)

math_env = MathEnvironment.options(
runtime_env={
Expand Down
73 changes: 73 additions & 0 deletions nemo_rl/data/hf_datasets/deepscaler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from datasets import load_dataset

from nemo_rl.data.interfaces import TaskDataSpec


def format_math(data: dict) -> dict:
    """Convert one raw record into the chat-style layout used for training.

    Args:
        data: A mapping that provides the ``"problem"`` and ``"answer"`` keys.
            Any other keys are ignored.

    Returns:
        A dict with two entries:
        ``"messages"`` -- a two-turn list where the user turn carries
        ``data["problem"]`` and the assistant turn carries ``data["answer"]``;
        ``"task_name"`` -- always the string ``"math"``.

    Raises:
        KeyError: If ``"problem"`` or ``"answer"`` is missing from ``data``.
    """
    user_turn = {
        "role": "user",
        "content": data["problem"],
    }
    assistant_turn = {
        "role": "assistant",
        "content": data["answer"],
    }
    return {
        "messages": [user_turn, assistant_turn],
        # For v0.1 release, nemo rl datasets require a task_name key such that
        # user can map a task processor per unique task.
        "task_name": "math",
    }


def prepare_deepscaler_dataset(seed=42, test_size=128):
    """Load the DeepScaler dataset and split it into train and validation sets.

    The ``train`` split of ``agentica-org/DeepScaleR-Preview-Dataset`` is
    shuffled with ``seed``; the first ``test_size`` shuffled rows become the
    validation set and the remaining rows become the training set. Both parts
    are then mapped through ``format_math`` with the original columns removed.

    Args:
        seed: Seed passed to the dataset shuffle so the split is reproducible.
        test_size: Number of shuffled rows held out for validation. Defaults
            to 128, which reproduces the previous hard-coded split.

    Returns:
        A dict with keys ``"train"`` and ``"validation"`` holding the
        formatted datasets.

    Raises:
        ValueError: If ``test_size`` is negative or larger than the number of
            rows in the loaded dataset.
    """
    # Load the original dataset
    original_ds = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train")

    # Fail early with a clear message instead of an opaque out-of-range index
    # error from ``select`` further down.
    total_rows = len(original_ds)
    if not 0 <= test_size <= total_rows:
        raise ValueError(
            f"test_size must be between 0 and {total_rows}, got {test_size}."
        )

    # Shuffle the dataset with the specified seed
    shuffled_ds = original_ds.shuffle(seed=seed)

    # Hold out the first ``test_size`` rows; everything after them is training data.
    test_ds = shuffled_ds.select(range(test_size))
    train_ds = shuffled_ds.select(range(test_size, total_rows))

    # Format the examples, removing original columns
    train_formatted = train_ds.map(format_math, remove_columns=train_ds.column_names)
    test_formatted = test_ds.map(format_math, remove_columns=test_ds.column_names)

    return {
        "train": train_formatted,
        "validation": test_formatted,
    }


class DeepScalerDataset:
    """Bundle the formatted DeepScaler splits with their task specification."""

    def __init__(self, seed: int = 42):
        """Build the train/validation splits and the task spec.

        Args:
            seed: Random seed forwarded to ``prepare_deepscaler_dataset`` so
                the split is reproducible.
        """
        # Result of prepare_deepscaler_dataset: a dict with "train" and
        # "validation" entries.
        splits = prepare_deepscaler_dataset(seed=seed)
        self.formatted_ds = splits

        # Task specification identifying this dataset by name.
        self.task_spec = TaskDataSpec(task_name="DeepScaler")
Loading
0