[fix] Fix math reward hanging [WIP] by BlankCheng · Pull Request #109 · LLM360/Reasoning360 · GitHub

[fix] Fix math reward hanging [WIP] #109

Open · wants to merge 2 commits into main
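
The diff below drops the `signal.SIGALRM`-based `timeout` context manager from `naive_dapo.py`. SIGALRM handlers can only be installed from the main thread, and CPython only runs a Python-level signal handler between bytecode instructions, so a grading call stuck inside a long C-level operation may never be interrupted; that is one plausible source of the hang this PR targets (an assumption, since the PR description is not shown here). For comparison only, a minimal process-based timeout sketch; `run_with_timeout` and `_grade_to_queue` are hypothetical names, not code from this PR:

```python
# Illustrative sketch (assumed helper names, not code from this PR): a
# process-based timeout that avoids signal.SIGALRM, which can only be
# armed from the main thread.
import multiprocessing as mp


def _grade_to_queue(fn, args, queue):
    # Run the grading call in a child process and ship the result back.
    try:
        queue.put(fn(*args))
    except Exception:
        queue.put(None)


def run_with_timeout(fn, args=(), seconds=8):
    """Return fn(*args), or None if it does not finish within `seconds`."""
    queue = mp.Queue()
    proc = mp.Process(target=_grade_to_queue, args=(fn, args, queue))
    proc.start()
    proc.join(seconds)
    if proc.is_alive():
        # The computation is stuck (e.g. a pathological simplify call);
        # kill the child instead of letting the reward worker hang.
        proc.terminate()
        proc.join()
        return None
    return queue.get() if not queue.empty() else None
```

Terminating the child abandons the stuck computation outright, at the cost of spawning a process per grading call.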
104 changes: 40 additions & 64 deletions verl/utils/reward_score/naive_dapo.py
@@ -13,35 +13,17 @@
# limitations under the License.
# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py

import math
import os
import re
import signal
from typing import Optional

import sympy
from pylatexenc import latex2text
from sympy.parsing import sympy_parser
import os

from .prime_math import math_normalize
from .prime_math.grader import math_equal


class timeout:

def __init__(self, seconds=1, error_message="Timeout"):
self.seconds = seconds
self.error_message = error_message

def handle_timeout(self, signum, frame):
raise TimeoutError(self.error_message)

def __enter__(self):
signal.signal(signal.SIGALRM, self.handle_timeout)
signal.alarm(self.seconds)

def __exit__(self, type, value, traceback):
signal.alarm(0)


# Constants for normalization
SUBSTITUTIONS = [
("an ", ""),
@@ -103,10 +85,10 @@ def __exit__(self, type, value, traceback):

def normalize_final_answer(final_answer: str) -> str:
"""Normalize a final answer to a quantitative reasoning question.

Args:
final_answer: The answer string to normalize

Returns:
Normalized answer string
"""
@@ -153,7 +135,6 @@ def timeout(timeout_seconds: int = 8):
import signal

def decorator(func):

def handler(signum, frame):
raise TimeoutError("Operation timed out!")

@@ -213,7 +194,7 @@ def _is_float(num: str) -> bool:
def _is_int(x: float) -> bool:
try:
return abs(x - int(round(x))) <= 1e-7
except:
except Exception:
return False


@@ -226,7 +207,7 @@ def _str_is_int(x: str) -> bool:
x = _strip_properly_formatted_commas(x)
x = float(x)
return abs(x - int(round(x))) <= 1e-7
except:
except Exception:
return False


@@ -279,26 +260,26 @@ def _normalize(expr: str) -> str:
expr = expr.replace("trillion", "*10^12")

for unit in [
"degree",
"cm",
"centimeter",
"meter",
"mile",
"second",
"minute",
"hour",
"day",
"week",
"month",
"year",
"foot",
"feet",
"inch",
"yard",
"liter",
"degree",
"cm",
"centimeter",
"meter",
"mile",
"second",
"minute",
"hour",
"day",
"week",
"month",
"year",
"foot",
"feet",
"inch",
"yard",
"liter",
]:
expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
expr = re.sub(f"\^ *\\\\circ", "", expr)
expr = re.sub("\^ *\\\\circ", "", expr)

if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
expr = expr[1:-1]
@@ -309,7 +290,7 @@ def _normalize(expr: str) -> str:
if "\\" in expr:
try:
expr = _parse_latex(expr)
except:
except Exception:
pass

# edge case with mixed numbers and negative signs
@@ -359,7 +340,7 @@ def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
simplified = sympy.simplify(sympy_diff)
if simplified == 0:
are_equal = True
except:
except Exception:
pass
return are_equal

@@ -371,8 +352,7 @@ def split_tuple(expr: str):
expr = _strip_properly_formatted_commas(expr)
if len(expr) == 0:
return []
if (len(expr) > 2 and expr[0] in TUPLE_CHARS and expr[-1] in TUPLE_CHARS and
all([ch not in expr[1:-1] for ch in TUPLE_CHARS])):
if len(expr) > 2 and expr[0] in TUPLE_CHARS and expr[-1] in TUPLE_CHARS and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]):
elems = [elem.strip() for elem in expr[1:-1].split(",")]
else:
elems = [expr]
@@ -411,8 +391,7 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]:
ground_truth_elems = split_tuple(ground_truth_normalized)
given_elems = split_tuple(given_normalized)

if len(ground_truth_elems) > 1 and (ground_truth_normalized[0] != given_normalized[0] or
ground_truth_normalized[-1] != given_normalized[-1]):
if len(ground_truth_elems) > 1 and (ground_truth_normalized[0] != given_normalized[0] or ground_truth_normalized[-1] != given_normalized[-1]):
is_correct = False
elif len(ground_truth_elems) != len(given_elems):
is_correct = False
@@ -432,6 +411,7 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]:

return is_correct, given_normalized


def _last_boxed_only_string(string):
idx = string.rfind("\\boxed")
if idx < 0:
@@ -459,7 +439,7 @@ def _last_boxed_only_string(string):
if left_brace_idx is None or right_brace_idx is None:
return None

return string[left_brace_idx + 1:right_brace_idx].strip()
return string[left_brace_idx + 1 : right_brace_idx].strip()


def match_answer(response):
@@ -471,21 +451,18 @@ def match_answer(response):
if ans_boxed:
is_matched = True
response = ans_boxed

return is_matched, response

import math

def compute_score(solution_str: str,
ground_truth: str,
extra_info: dict) -> float:
def compute_score(solution_str: str, ground_truth: str, extra_info: dict) -> float:
"""Compute the reward score for a solution. This draws heavily from the LLM-as-judge and PRIME reward functions

Args:
solution_str: The solution string
ground_truth: The ground truth answer
extra_info: dict with additional info for the score computation

Returns:
Reward score (1.0 for correct, -1.0 for incorrect)
"""
@@ -495,13 +472,13 @@ def compute_score(solution_str: str,

# Extract answer from generated output
is_matched, extracted_model_output = match_answer(model_output)

# TWK NOTE: WE REMOVED THE RESPONSE TRUNCATION FROM math_dapo.compute_score

# Verify the solution, first check simple comparisons.
correct, pred = grade_answer(extracted_model_output, ground_truth)

if not correct:
if not correct:
try:
if "\\pi" in extracted_model_output or "\\pi" in ground_truth:
equivs = []
@@ -510,15 +487,14 @@ def compute_score(solution_str: str,
correct = any(equivs)
else:
correct = math_equal(extracted_model_output, ground_truth, timeout=True)
except:
except Exception:
correct = False


# reward = 1.0 if correct else -1.0
reward = 1.0 if correct else 0.
reward = 1.0 if correct else 0.0
acc = correct

return {
"score": reward,
"acc": acc,
}
}
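
For reference, a minimal call to the updated scorer might look like the sketch below; the example strings and the exact return values are assumptions inferred from the diff above, not verified against the repository.

```python
# Hypothetical usage sketch; the module path is taken from the diff header above.
from verl.utils.reward_score.naive_dapo import compute_score

result = compute_score(
    solution_str="... so the final answer is \\boxed{42}.",
    ground_truth="42",
    extra_info={},
)
# Per the diff, a correct answer should yield {"score": 1.0, "acc": True}
# and an incorrect one {"score": 0.0, "acc": False}.
print(result)
```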