feat: Remove 'last 100' hack for math verifier by SahilJain314 · Pull Request #287 · NVIDIA-NeMo/RL · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

feat: Remove 'last 100' hack for math verifier #287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits on
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions nemo_rl/environments/math_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@

import ray
import torch
from math_verify import parse, verify
from math_verify.metric import math_metric
from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig

from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES
Expand Down Expand Up @@ -53,9 +54,23 @@ def verify(
results = []
for response, ground_truth in zip(pred_responses, ground_truths):
try:
gold = parse(ground_truth)
pred = parse(response[-100:]) # avoid looking at the whole string
results.append(float(verify(gold, pred)))
# Use Latex and plain math extraction from predictions
# https://github.com/huggingface/Math-Verify?tab=readme-ov-file#extraction-targets
verify_func = math_metric(
gold_extraction_target=(LatexExtractionConfig(),),
pred_extraction_target=(
ExprExtractionConfig(),
LatexExtractionConfig(),
),
)

ground_truth_parsable = "\\boxed{" + ground_truth + "}"
try:
ret_score, _ = verify_func([ground_truth_parsable], [response])
except Exception:
ret_score = 0.0

results.append(float(ret_score))
except Exception:
results.append(0)
return results
Expand Down
10 changes: 5 additions & 5 deletions tests/test_suites/nightly.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
########

# Short 1N/1B runs (go past 200 steps - usually divergence happens by now) -- going to 4 nodes doesn't help that much
tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh
tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.sh
tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v2.sh
tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v2.sh

# FSDP1 vs DTensor (Qwen/Qwen2.5-7B-Instruct)
tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.sh
tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.sh
tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.v2.sh
tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v2.sh

# Functional 32b run
tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.sh
tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.v2.sh

#######
# SFT #
Expand Down
4 changes: 2 additions & 2 deletions tests/test_suites/release.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
########

# Long 8b run
tests/test_suites/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.sh
tests/test_suites/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v2.sh

# Long 32b run
tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.sh
tests/test_suites/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.v2.sh

#######
# SFT #
Expand Down