From 2b932b84a48a77706bb6207150401bbb6d4fac72 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Thu, 26 Feb 2026 18:13:59 +0800 Subject: [PATCH 1/2] Add MultiStepPadding Operator --- examples/agentscope_frozenlake/agent.py | 1 + .../multi_step_padding.yaml | 92 +++++++++++++++ trinity/buffer/operators/__init__.py | 1 + .../buffer/operators/multi_step_operator.py | 105 ++++++++++++++++++ trinity/common/experience.py | 15 +++ trinity/trainer/verl/utils.py | 21 +++- 6 files changed, 234 insertions(+), 1 deletion(-) create mode 100644 examples/agentscope_frozenlake/multi_step_padding.yaml create mode 100644 trinity/buffer/operators/multi_step_operator.py diff --git a/examples/agentscope_frozenlake/agent.py b/examples/agentscope_frozenlake/agent.py index f05689d6b2..b96be61d2c 100644 --- a/examples/agentscope_frozenlake/agent.py +++ b/examples/agentscope_frozenlake/agent.py @@ -26,6 +26,7 @@ def __init__(self, model: OpenAIChatModel, max_steps: int = 20): formatter=OpenAIChatFormatter(), max_iters=2, ) + self.agent.set_console_output_enabled(False) self.response_structure = FrozenLakeAction self.current_step = 0 self.last_action = None diff --git a/examples/agentscope_frozenlake/multi_step_padding.yaml b/examples/agentscope_frozenlake/multi_step_padding.yaml new file mode 100644 index 0000000000..734bd6fcc3 --- /dev/null +++ b/examples/agentscope_frozenlake/multi_step_padding.yaml @@ -0,0 +1,92 @@ +project: "FrozenLake" +name: "Qwen25-3B-padding" +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +algorithm: + algorithm_type: multi_step_grpo + repeat_times: 8 + kl_loss_fn: "low_var_kl" + kl_loss_fn_args: + kl_coef: 0 + advantage_fn_args: + epsilon: 1e-6 + optimizer: + lr: 1e-6 + policy_loss_fn_args: + clip_range_low: 0.2 + clip_range_high: 0.28 +data_processor: + experience_pipeline: + operators: # NOTE + - name: multi_step_padding + args: + max_steps: 10 +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct} + 
max_response_tokens: 2048 + min_response_tokens: 0 + max_model_len: 25600 + temperature: 0.7 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 5 + batch_size: 64 + train_batch_size: 5120 # NOTE: 64 * 8 * 10 = batch_size * repeat_times * max_steps + explorer_input: + taskset: + name: frozenlake + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH} + split: train + workflow_args: + env_max_steps: 8 + agent_max_steps: 10 + is_slippery: false + eval_tasksets: + - name: frozenlake + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH} + split: test + workflow_args: + env_max_steps: 8 + agent_max_steps: 10 + is_slippery: false + repeat_times: 4 + rollout_args: + top_p: 0.8 + top_k: 20 + default_workflow_type: 'examples.agentscope_frozenlake.workflow.FrozenLakeWorkflow' + trainer_input: + experience_buffer: + name: frozenlake_experience_buffer + storage_type: queue + max_read_timeout: 7200 +explorer: + eval_on_startup: true + eval_interval: 20 + runner_per_model: 6 + max_repeat_times_per_runner: 4 + rollout_model: + engine_num: 4 + tensor_parallel_size: 1 + enable_chunked_prefill: true + enforce_eager: false + enable_openai_api: true + enable_log_requests: false + enable_history: true + enable_auto_tool_choice: true + tool_call_parser: hermes + enable_thinking: true + dtype: bfloat16 + seed: 42 +trainer: + save_interval: 50 + use_dynamic_bsz: true + grad_clip: 1.0 + ulysses_sequence_parallel_size: 2 +synchronizer: + sync_method: nccl + sync_style: fixed + sync_interval: 1 + sync_timeout: 1200 diff --git a/trinity/buffer/operators/__init__.py b/trinity/buffer/operators/__init__.py index e83b7d05ee..88840cc366 100644 --- a/trinity/buffer/operators/__init__.py +++ b/trinity/buffer/operators/__init__.py @@ -10,6 +10,7 @@ "pass_rate_calculator": "trinity.buffer.operators.mappers.pass_rate_calculator.PassRateCalculator", "data_juicer": "trinity.buffer.operators.data_juicer_operator.DataJuicerOperator", "invalid_reward_filter": 
from typing import List, Tuple

import torch

from trinity.buffer.operators import ExperienceOperator
from trinity.common.experience import EID, Experience, group_by
from trinity.utils.log import get_logger

logger = get_logger(__name__)


class MultiStepPadding(ExperienceOperator):
    """Pad every multi-step rollout ("run") up to ``max_steps`` experiences.

    Runs shorter than ``max_steps`` are front-filled with placeholder
    experiences so that each run contributes the same number of samples to a
    training batch. Placeholders carry the run's average reward, the run's
    average response length, and an all-zero action mask; downstream code
    recognizes them via ``info["status"] == "placeholder"`` and
    ``eid.step == -1``.

    Note: This operator assumes that the reward is already calculated and
    stored in the Experience object.
    """

    def __init__(self, max_steps: int = 0):
        # Runs with at least max_steps experiences are left untouched, so
        # max_steps == 0 effectively disables padding.
        self.max_steps = max_steps

    def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]:
        """Pad each rollout to ``max_steps`` steps.

        Args:
            exps: Experiences from possibly several tasks/runs, each with
                ``tokens`` and ``reward`` already populated.

        Returns:
            The (possibly padded) experience list and a metrics dict holding
            the total number of placeholder experiences added.
        """
        logger.debug(f"Processing {len(exps)} experiences")
        total_num_placeholder_exps = 0
        all_exps: List[Experience] = []

        for task_exp in group_by(exps, "task").values():
            for exps_same_run in group_by(task_exp, "run").values():
                if len(exps_same_run) == 0:
                    continue
                num_placeholder_exps = max(self.max_steps - len(exps_same_run), 0)
                if num_placeholder_exps > 0:
                    placeholders = self._build_placeholders(
                        exps_same_run, num_placeholder_exps
                    )
                    logger.debug(f"Adding {num_placeholder_exps} placeholder experiences")
                    # Put empty at the beginning, as the adv is computed using
                    # the last exp
                    all_exps.extend(placeholders)
                all_exps.extend(exps_same_run)
                total_num_placeholder_exps += num_placeholder_exps

        metrics = {"total_num_placeholder_exps": total_num_placeholder_exps}
        logger.debug(f"After padding: {len(all_exps)}")
        return all_exps, metrics

    @staticmethod
    def _build_placeholders(
        exps_same_run: List[Experience], num_placeholder_exps: int
    ) -> List[Experience]:
        """Create ``num_placeholder_exps`` placeholder experiences for one run.

        Each placeholder keeps the run's prompt and pads the response to the
        run's average response length (so length-based metrics stay
        unchanged); its reward is the run's average reward.
        """
        # Calculate average response length to keep metrics unchanged
        assert all(
            exp.tokens is not None for exp in exps_same_run
        ), "Tokens are not provided"
        response_lengths = [
            len(exp.tokens) - exp.prompt_length for exp in exps_same_run  # type: ignore
        ]
        # Ensure at least 1 to avoid zero-length response
        avg_response_length = max(int(sum(response_lengths) / len(response_lengths)), 1)

        # Use the first experience as a template
        template_exp = exps_same_run[0]
        prompt_length = template_exp.prompt_length

        # Keep the prompt part, pad the response part to average length,
        # using the last token of the prompt as the padding token.
        prompt_tokens = template_exp.tokens[:prompt_length]  # type: ignore
        pad_token = prompt_tokens[-1] if len(prompt_tokens) > 0 else 0
        response_tokens = torch.full(
            (avg_response_length,),
            pad_token,
            dtype=template_exp.tokens.dtype,  # type: ignore
        )
        avg_tokens = torch.cat([prompt_tokens, response_tokens])
        avg_logprobs = (
            torch.zeros(avg_response_length, dtype=torch.float32)
            if template_exp.logprobs is not None
            else None
        )

        assert all(
            exp.reward is not None for exp in exps_same_run
        ), "Rewards are not provided"
        rewards = [exp.reward for exp in exps_same_run if exp.reward is not None]
        avg_reward = sum(rewards) / len(rewards)

        template_eid = template_exp.eid
        return [
            Experience(
                eid=EID(
                    batch=template_eid.batch,
                    task=template_eid.task,
                    run=template_eid.run,
                    step=-1,
                ),  # -1 means placeholder
                tokens=avg_tokens,
                logprobs=avg_logprobs,
                prompt_length=prompt_length,
                action_mask=torch.zeros(avg_response_length, dtype=torch.bool),
                truncate_status="prompt_truncated",  # TODO: merge with the following
                info={"status": "placeholder"},  # TODO: use another field
                reward=avg_reward,
            )
            for _ in range(num_placeholder_exps)
        ]
def gather_statuses(experiences) -> Tensor:
    """Build a boolean effectiveness mask for a batch of experiences.

    An experience whose ``info["status"]`` is ``"placeholder"`` (inserted by
    padding operators) maps to False/0; every other experience maps to
    True/1. The result has shape [batch_size] with dtype ``torch.bool``.
    """
    flags = [
        exp.info.get("status", None) != "placeholder" for exp in experiences
    ]
    return torch.tensor(flags, dtype=torch.bool)
def print_effective_experience_stats(experiences: List[Experience], logger: Logger) -> None:
    """Log the effective (non-placeholder) experience count and reweight factor.

    Placeholder experiences (``info["status"] == "placeholder"``) are inserted
    by padding operators; this reports how many real samples the batch holds
    and the factor ``batch_size / effective_count`` that would rescale
    statistics computed over the padded batch. Logging only — nothing is
    returned or modified.

    Args:
        experiences: The batch about to be converted for training.
        logger: Logger used to report the statistics.
    """
    statuses = gather_statuses(experiences)
    effective_experiences = torch.sum(statuses).item()
    batch_size = len(experiences)
    if effective_experiences == 0:
        # Whole batch is placeholders; avoid division by zero with a neutral weight.
        effective_weight = 1.0
        logger.info("No effective experiences found, using default weight 1.0")
    else:
        effective_weight = batch_size / effective_experiences
    logger.info(
        f"Effective experiences: {effective_experiences}, batch size: {batch_size}, effective_weight: {effective_weight}"
    )
the beginning, as the adv is computed using the last exp exps_same_run = empty_experiences + exps_same_run all_exps.extend(exps_same_run)