From 45957688278716361ae2813dab42bd70b494bb70 Mon Sep 17 00:00:00 2001 From: AndreasXie <1206470742@qq.com> Date: Mon, 23 Jun 2025 23:59:39 +0000 Subject: [PATCH] fix saving json error for aime --- verl/trainer/main_generation.py | 95 ++++++++++++++++----------------- 1 file changed, 47 insertions(+), 48 deletions(-) diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py index daa8c691..ee901162 100644 --- a/verl/trainer/main_generation.py +++ b/verl/trainer/main_generation.py @@ -65,33 +65,34 @@ def extract_content(p): def merge_aime_responses(dataset, output_lst, prompt_key="prompt", response_key="responses"): """Merge responses for AIME dataset based on prompt content""" # Convert to pandas DataFrame if it's not already - if hasattr(dataset, 'to_pandas'): # polars DataFrame + if hasattr(dataset, "to_pandas"): # polars DataFrame df = dataset.to_pandas() is_polars_df = True else: df = dataset.copy() is_polars_df = False - + # Add responses to dataframe df[response_key] = output_lst - + # Extract prompt content df["prompt_content"] = df[prompt_key].apply(extract_content) - + # Merge responses by prompt content group_keys = ["prompt_content"] agg_dict = {response_key: merge_responses} - + # Keep first value for other columns for col in df.columns: if col not in group_keys + [response_key]: agg_dict[col] = "first" - + df_merged = df.groupby(group_keys, as_index=False).agg(agg_dict) - + # Convert back to original format if needed if is_polars_df: import polars as pl + return pl.DataFrame(df_merged) else: return df_merged @@ -123,8 +124,26 @@ def main_task(config): tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) # NOTE: added by Reasoning360 - if 'olmoe' in local_path.lower() and 'instruct' not in local_path.lower(): - tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'system' %}{{ '<|system|>\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{% if not loop.last %}{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}{% else %}{{ '<|assistant|>\n' + message['content'] + eos_token }}{% endif %}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}{% endfor %}" + if "olmoe" in local_path.lower() and "instruct" not in local_path.lower(): + tokenizer.chat_template = ( + "{{ bos_token }}" + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|system|>\\n' + message['content'] + '\\n' }}" + "{% elif message['role'] == 'user' %}" + "{{ '<|user|>\\n' + message['content'] + '\\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{% if not loop.last %}" + "{{ '<|assistant|>\\n' + message['content'] + eos_token + '\\n' }}" + "{% else %}" + "{{ '<|assistant|>\\n' + message['content'] + eos_token }}" + "{% endif %}" + "{% endif %}" + "{% if loop.last and add_generation_prompt %}" + "{{ '<|assistant|>\\n' }}" + "{% endif %}" + "{% endfor %}" + ) if config.rollout.temperature == 0.0: assert config.data.n_samples == 1, "When temperature=0, n_samples must be 1." @@ -132,8 +151,9 @@ def main_task(config): # read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary) is_polars_df = False - if 'livecodebench' in config.data.path: + if "livecodebench" in config.data.path: import polars as pl + dataset = pl.read_parquet(config.data.path) chat_lst = list(dataset[config.data.prompt_key]) chat_lst = [list(chat) for chat in chat_lst] @@ -144,7 +164,7 @@ def main_task(config): chat_lst = dataset[config.data.prompt_key].tolist() chat_lst = [chat.tolist() for chat in chat_lst] ground_truth_lst = dataset["reward_model"].tolist() - + # NOTE: added by Reasoning360. handle n_samples if config.data.n_samples > 1: chat_lst = chat_lst * config.data.n_samples @@ -160,10 +180,8 @@ def main_task(config): wg.init_model() # NOTE: updated by Reasoning360. Sample n times together - total_samples = len(chat_lst) # chat_lst is repeated - # real_batch_size = data.batch['input_ids'].shape[0] + total_samples = len(chat_lst) # chat_lst is repeated config_batch_size = config.data.batch_size - dispatch_dp_size = wg.world_size num_batch = -(-total_samples // config_batch_size) output_lst = [] @@ -195,9 +213,7 @@ def main_task(config): # NOTE: modified by Reasoning360. Sample n times altogether. data_padded, pad_size = pad_dataproto_to_divisor(data, wg.world_size) - batch_size = data_padded.batch['input_ids'].shape[0] - - print(f'[{batch_idx+1}/{num_batch}] Start to generate.') + print(f"[{batch_idx + 1}/{num_batch}] Start to generate.") # START TO GENERATE FOR 1 TIME SINCE WE'VE ALREADY HANDLED n_samples beforehand output_padded = wg.generate_sequences(data_padded) # remove dummy data @@ -216,7 +232,7 @@ def main_task(config): pad_token = tokenizer.pad_token output_text_unpad = [] for text in output_texts: - output_text_unpad.append(text.replace(pad_token, '')) + output_text_unpad.append(text.replace(pad_token, "")) output_lst.extend(output_text_unpad) @@ -224,51 +240,34 @@ def main_task(config): original_data_size = len(dataset) output_lst = np.array(output_lst).reshape(config.data.n_samples, original_data_size) output_lst = output_lst.T.tolist() - + original_chat_lst = chat_lst[:original_data_size] original_ground_truth_lst = ground_truth_lst[:original_data_size] # Check if 'aime' is in the output path to determine if we should merge responses - should_merge_aime = 'aime' in config.data.output_path.lower() - + should_merge_aime = "aime" in config.data.output_path.lower() + if should_merge_aime: print("Detected 'aime' in output path, merging responses by prompt content...") # Use merge logic for AIME dataset merged_dataset = merge_aime_responses(dataset, output_lst, config.data.prompt_key, "responses") - + # Save merged dataset output_dir = os.path.dirname(config.data.output_path) makedirs(output_dir, exist_ok=True) - - if hasattr(merged_dataset, 'write_parquet'): # polars DataFrame + + if hasattr(merged_dataset, "write_parquet"): # polars DataFrame merged_dataset.write_parquet(config.data.output_path) else: # pandas DataFrame merged_dataset.to_parquet(config.data.output_path) - + print(f"Saved merged AIME responses to {config.data.output_path}") - - # Also save the merged results as JSON - merged_results = [] - df_to_iterate = merged_dataset if hasattr(merged_dataset, 'iterrows') else merged_dataset.to_pandas() - for _, row in df_to_iterate.iterrows(): - merged_results.append({ - "prompt": row[config.data.prompt_key], - "prompt_content": row["prompt_content"], - "responses": row["responses"], - "ground_truth": str(row.get("reward_model", "")), - }) - - model_name = config.model.path.split('/')[-1] - json_output_path = config.data.output_path.replace('.parquet', f'_merged_{model_name}.json') - with open(json_output_path, 'w', encoding='utf-8') as f: - json.dump(merged_results, f, indent=2, ensure_ascii=False) - print(f"Saved merged AIME results as JSON to {json_output_path}") - else: # Original logic for non-AIME datasets # add to the data frame if is_polars_df: import polars as pl + dataset = dataset.with_columns(pl.Series("responses", output_lst)) # write to a new parquet output_dir = os.path.dirname(config.data.output_path) @@ -276,23 +275,23 @@ def main_task(config): dataset.write_parquet(config.data.output_path) else: # For pandas, use standard bracket assignment - dataset['responses'] = output_lst + dataset["responses"] = output_lst # write to a new parquet output_dir = os.path.dirname(config.data.output_path) makedirs(output_dir, exist_ok=True) dataset.to_parquet(config.data.output_path) - + # NOTE: added by Reasoning360. dump results result_list = [ { "prompt": chat, "response": output, "ground_truth": str(ground_truth), - } + } for chat, output, ground_truth in zip(original_chat_lst, output_lst, original_ground_truth_lst) ] - model_name = config.model.path.split('/')[-1] - with open(config.data.output_path.replace('.parquet', f'_{model_name}.json'), 'w', encoding='utf-8') as f: + model_name = config.model.path.split("/")[-1] + with open(config.data.output_path.replace(".parquet", f"_{model_name}.json"), "w", encoding="utf-8") as f: json.dump(result_list, f, indent=2, ensure_ascii=False)