fix saving json error for aime by AndreasXie · Pull Request #101 · LLM360/Reasoning360 · GitHub

fix saving json error for aime #101


Merged

merged 1 commit on Jun 24, 2025

95 changes: 47 additions & 48 deletions verl/trainer/main_generation.py
@@ -65,33 +65,34 @@ def extract_content(p):
def merge_aime_responses(dataset, output_lst, prompt_key="prompt", response_key="responses"):
    """Merge responses for AIME dataset based on prompt content"""
    # Convert to pandas DataFrame if it's not already
-   if hasattr(dataset, 'to_pandas'): # polars DataFrame
+   if hasattr(dataset, "to_pandas"): # polars DataFrame
        df = dataset.to_pandas()
        is_polars_df = True
    else:
        df = dataset.copy()
        is_polars_df = False

    # Add responses to dataframe
    df[response_key] = output_lst

    # Extract prompt content
    df["prompt_content"] = df[prompt_key].apply(extract_content)

    # Merge responses by prompt content
    group_keys = ["prompt_content"]
    agg_dict = {response_key: merge_responses}

    # Keep first value for other columns
    for col in df.columns:
        if col not in group_keys + [response_key]:
            agg_dict[col] = "first"

    df_merged = df.groupby(group_keys, as_index=False).agg(agg_dict)

    # Convert back to original format if needed
    if is_polars_df:
        import polars as pl

        return pl.DataFrame(df_merged)
    else:
        return df_merged
@@ -123,17 +124,36 @@ def main_task(config):
    tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

    # NOTE: added by Reasoning360
-   if 'olmoe' in local_path.lower() and 'instruct' not in local_path.lower():
-       tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' %}{{ '<|system|>\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{% if not loop.last %}{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}{% else %}{{ '<|assistant|>\n' + message['content'] + eos_token }}{% endif %}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}{% endfor %}"
+   if "olmoe" in local_path.lower() and "instruct" not in local_path.lower():
+       tokenizer.chat_template = (
+           "{{ bos_token }}"
+           "{% for message in messages %}"
+           "{% if message['role'] == 'system' %}"
+           "{{ '<|system|>\\n' + message['content'] + '\\n' }}"
+           "{% elif message['role'] == 'user' %}"
+           "{{ '<|user|>\\n' + message['content'] + '\\n' }}"
+           "{% elif message['role'] == 'assistant' %}"
+           "{% if not loop.last %}"
+           "{{ '<|assistant|>\\n' + message['content'] + eos_token + '\\n' }}"
+           "{% else %}"
+           "{{ '<|assistant|>\\n' + message['content'] + eos_token }}"
+           "{% endif %}"
+           "{% endif %}"
+           "{% if loop.last and add_generation_prompt %}"
+           "{{ '<|assistant|>\\n' }}"
+           "{% endif %}"
+           "{% endfor %}"
+       )

    if config.rollout.temperature == 0.0:
        assert config.data.n_samples == 1, "When temperature=0, n_samples must be 1."
    assert config.data.n_samples >= 1, "n_samples should always >= 1"

    # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary)
    is_polars_df = False
-   if 'livecodebench' in config.data.path:
+   if "livecodebench" in config.data.path:
        import polars as pl

        dataset = pl.read_parquet(config.data.path)
        chat_lst = list(dataset[config.data.prompt_key])
        chat_lst = [list(chat) for chat in chat_lst]
@@ -144,7 +164,7 @@ def main_task(config):
        chat_lst = dataset[config.data.prompt_key].tolist()
        chat_lst = [chat.tolist() for chat in chat_lst]
        ground_truth_lst = dataset["reward_model"].tolist()

    # NOTE: added by Reasoning360. handle n_samples
    if config.data.n_samples > 1:
        chat_lst = chat_lst * config.data.n_samples
@@ -160,10 +180,8 @@ def main_task(config):
    wg.init_model()

    # NOTE: updated by Reasoning360. Sample n times together
-   total_samples = len(chat_lst) # chat_lst is repeated
-   # real_batch_size = data.batch['input_ids'].shape[0]
+   total_samples = len(chat_lst)  # chat_lst is repeated
    config_batch_size = config.data.batch_size
    dispatch_dp_size = wg.world_size
    num_batch = -(-total_samples // config_batch_size)

    output_lst = []
Expand Down Expand Up @@ -195,9 +213,7 @@ def main_task(config):
        # NOTE: modified by Reasoning360. Sample n times altogether.
        data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size)

-       batch_size = data_padded.batch['input_ids'].shape[0]
-
-       print(f'[{batch_idx+1}/{num_batch}] Start to generate.')
+       print(f"[{batch_idx + 1}/{num_batch}] Start to generate.")
        # START TO GENERATE FOR 1 TIME SINCE WE'VE ALREADY HANDLED n_samples beforehand
        output_padded = wg.generate_sequences(data_padded)
        # remove dummy data
@@ -216,83 +232,66 @@ def main_task(config):
        pad_token = tokenizer.pad_token
        output_text_unpad = []
        for text in output_texts:
-           output_text_unpad.append(text.replace(pad_token, ''))
+           output_text_unpad.append(text.replace(pad_token, ""))

        output_lst.extend(output_text_unpad)

    # convert output_lst from (n_samples * n_data,) to (n_data, n_samples)
    original_data_size = len(dataset)
    output_lst = np.array(output_lst).reshape(config.data.n_samples, original_data_size)
    output_lst = output_lst.T.tolist()

    original_chat_lst = chat_lst[:original_data_size]
    original_ground_truth_lst = ground_truth_lst[:original_data_size]

    # Check if 'aime' is in the output path to determine if we should merge responses
-   should_merge_aime = 'aime' in config.data.output_path.lower()
+   should_merge_aime = "aime" in config.data.output_path.lower()

    if should_merge_aime:
        print("Detected 'aime' in output path, merging responses by prompt content...")
        # Use merge logic for AIME dataset
        merged_dataset = merge_aime_responses(dataset, output_lst, config.data.prompt_key, "responses")

        # Save merged dataset
        output_dir = os.path.dirname(config.data.output_path)
        makedirs(output_dir, exist_ok=True)
-       if hasattr(merged_dataset, 'write_parquet'): # polars DataFrame
+
+       if hasattr(merged_dataset, "write_parquet"): # polars DataFrame
            merged_dataset.write_parquet(config.data.output_path)
        else: # pandas DataFrame
            merged_dataset.to_parquet(config.data.output_path)

        print(f"Saved merged AIME responses to {config.data.output_path}")

        # Also save the merged results as JSON
        merged_results = []
        df_to_iterate = merged_dataset if hasattr(merged_dataset, 'iterrows') else merged_dataset.to_pandas()
        for _, row in df_to_iterate.iterrows():
            merged_results.append({
                "prompt": row[config.data.prompt_key],
                "prompt_content": row["prompt_content"],
                "responses": row["responses"],
                "ground_truth": str(row.get("reward_model", "")),
            })

        model_name = config.model.path.split('/')[-1]
        json_output_path = config.data.output_path.replace('.parquet', f'_merged_{model_name}.json')
        with open(json_output_path, 'w', encoding='utf-8') as f:
            json.dump(merged_results, f, indent=2, ensure_ascii=False)
        print(f"Saved merged AIME results as JSON to {json_output_path}")

    else:
        # Original logic for non-AIME datasets
        # add to the data frame
        if is_polars_df:
            import polars as pl

            dataset = dataset.with_columns(pl.Series("responses", output_lst))
            # write to a new parquet
            output_dir = os.path.dirname(config.data.output_path)
            makedirs(output_dir, exist_ok=True)
            dataset.write_parquet(config.data.output_path)
        else:
            # For pandas, use standard bracket assignment
-           dataset['responses'] = output_lst
+           dataset["responses"] = output_lst
            # write to a new parquet
            output_dir = os.path.dirname(config.data.output_path)
            makedirs(output_dir, exist_ok=True)
            dataset.to_parquet(config.data.output_path)

        # NOTE: added by Reasoning360. dump results
        result_list = [
            {
                "prompt": chat,
                "response": output,
                "ground_truth": str(ground_truth),
            }
            for chat, output, ground_truth in zip(original_chat_lst, output_lst, original_ground_truth_lst)
        ]
-       model_name = config.model.path.split('/')[-1]
-       with open(config.data.output_path.replace('.parquet', f'_{model_name}.json'), 'w', encoding='utf-8') as f:
+       model_name = config.model.path.split("/")[-1]
+       with open(config.data.output_path.replace(".parquet", f"_{model_name}.json"), "w", encoding="utf-8") as f:
            json.dump(result_list, f, indent=2, ensure_ascii=False)
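
A quick way to sanity-check the reformatted OLMoE chat template in the diff is to render it with plain jinja2, outside of any tokenizer. This is a minimal sketch, not part of the PR; the bos/eos tokens and the toy conversation are assumed placeholders (a real run would go through tokenizer.apply_chat_template):

from jinja2 import Template

# The multi-line template from the diff, joined exactly as Python joins
# adjacent string literals.
chat_template = (
    "{{ bos_token }}"
    "{% for message in messages %}"
    "{% if message['role'] == 'system' %}"
    "{{ '<|system|>\\n' + message['content'] + '\\n' }}"
    "{% elif message['role'] == 'user' %}"
    "{{ '<|user|>\\n' + message['content'] + '\\n' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{% if not loop.last %}"
    "{{ '<|assistant|>\\n' + message['content'] + eos_token + '\\n' }}"
    "{% else %}"
    "{{ '<|assistant|>\\n' + message['content'] + eos_token }}"
    "{% endif %}"
    "{% endif %}"
    "{% if loop.last and add_generation_prompt %}"
    "{{ '<|assistant|>\\n' }}"
    "{% endif %}"
    "{% endfor %}"
)

rendered = Template(chat_template).render(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 1 + 1?"},
    ],
    bos_token="<s>",  # assumed placeholder; a real run takes this from the tokenizer
    eos_token="</s>",  # assumed placeholder
    add_generation_prompt=True,
)
print(rendered)
# <s><|system|>
# You are a helpful assistant.
# <|user|>
# What is 1 + 1?
# <|assistant|>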


Expand Down
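For readers who want the merge step in isolation: below is a minimal, self-contained sketch of the grouping that merge_aime_responses performs. The toy rows and the simplified extract_content / merge_responses helpers are assumptions for illustration; the real helpers are defined earlier in main_generation.py and are not shown in this diff.

import pandas as pd


def extract_content(prompt):
    # Hypothetical stand-in: flatten a chat-format prompt (list of message
    # dicts) into a single comparable string.
    if isinstance(prompt, list):
        return " ".join(m.get("content", "") for m in prompt)
    return str(prompt)


def merge_responses(series):
    # Concatenate the per-row response lists of one group into a single list.
    merged = []
    for responses in series:
        merged.extend(responses if isinstance(responses, list) else [responses])
    return merged


df = pd.DataFrame({
    "prompt": [
        [{"role": "user", "content": "AIME problem 1"}],
        [{"role": "user", "content": "AIME problem 1"}],  # repeated prompt
        [{"role": "user", "content": "AIME problem 2"}],
    ],
    "responses": [["answer A"], ["answer B"], ["answer C"]],
})

# Same recipe as the diff: group rows by extracted prompt content, merge the
# response lists, and keep the first value of every other column.
df["prompt_content"] = df["prompt"].apply(extract_content)
agg_dict = {"responses": merge_responses}
for col in df.columns:
    if col not in ["prompt_content", "responses"]:
        agg_dict[col] = "first"
merged = df.groupby("prompt_content", as_index=False).agg(agg_dict)

print(merged[["prompt_content", "responses"]])
# "AIME problem 1" collapses to one row whose responses are ["answer A", "answer B"]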
0
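Finally, on the error in the PR title: both JSON dumps above cast the ground truth through str(...), e.g. str(row.get("reward_model", "")). Values read back from parquet are often numpy scalars, arrays, or nested dicts, which the standard json encoder refuses to serialize, and casting to str is a common guard against that. A small illustration of the failure mode, with an assumed numpy value standing in for a real reward_model cell:

import json

import numpy as np

row = {"reward_model": np.int64(42)}  # assumed stand-in for a parquet-derived cell

try:
    json.dumps({"ground_truth": row["reward_model"]})
except TypeError as err:
    print(err)  # Object of type int64 is not JSON serializable

# Casting to str first, as the diff does, always serializes cleanly.
print(json.dumps({"ground_truth": str(row.get("reward_model", ""))}))  # {"ground_truth": "42"}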