Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/agentscope_frozenlake/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self, model: OpenAIChatModel, max_steps: int = 20):
formatter=OpenAIChatFormatter(),
max_iters=2,
)
self.agent.set_console_output_enabled(False)
self.response_structure = FrozenLakeAction
self.current_step = 0
self.last_action = None
Expand Down
92 changes: 92 additions & 0 deletions examples/agentscope_frozenlake/multi_step_padding.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
project: "FrozenLake"
name: "Qwen25-3B-padding"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: multi_step_grpo
repeat_times: 8
kl_loss_fn: "low_var_kl"
kl_loss_fn_args:
kl_coef: 0
advantage_fn_args:
epsilon: 1e-6
optimizer:
lr: 1e-6
policy_loss_fn_args:
clip_range_low: 0.2
clip_range_high: 0.28
data_processor:
experience_pipeline:
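operators: # NOTE: multi_step_padding pads every run up to max_steps experiences, so train_batch_size below must account for max_steps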
- name: multi_step_padding
args:
max_steps: 10
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct}
max_response_tokens: 2048
min_response_tokens: 0
max_model_len: 25600
temperature: 0.7
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 5
batch_size: 64
train_batch_size: 5120 # NOTE: 64 * 8 * 10 = batch_size * repeat_times * max_steps
explorer_input:
taskset:
name: frozenlake
storage_type: file
path: ${oc.env:TRINITY_TASKSET_PATH}
split: train
workflow_args:
env_max_steps: 8
agent_max_steps: 10
is_slippery: false
eval_tasksets:
- name: frozenlake
storage_type: file
path: ${oc.env:TRINITY_TASKSET_PATH}
split: test
workflow_args:
env_max_steps: 8
agent_max_steps: 10
is_slippery: false
repeat_times: 4
rollout_args:
top_p: 0.8
top_k: 20
default_workflow_type: 'examples.agentscope_frozenlake.workflow.FrozenLakeWorkflow'
trainer_input:
experience_buffer:
name: frozenlake_experience_buffer
storage_type: queue
max_read_timeout: 7200
explorer:
eval_on_startup: true
eval_interval: 20
runner_per_model: 6
max_repeat_times_per_runner: 4
rollout_model:
engine_num: 4
tensor_parallel_size: 1
enable_chunked_prefill: true
enforce_eager: false
enable_openai_api: true
enable_log_requests: false
enable_history: true
enable_auto_tool_choice: true
tool_call_parser: hermes
enable_thinking: true
dtype: bfloat16
seed: 42
trainer:
save_interval: 50
use_dynamic_bsz: true
grad_clip: 1.0
ulysses_sequence_parallel_size: 2
synchronizer:
sync_method: nccl
sync_style: fixed
sync_interval: 1
sync_timeout: 1200
1 change: 1 addition & 0 deletions trinity/buffer/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"pass_rate_calculator": "trinity.buffer.operators.mappers.pass_rate_calculator.PassRateCalculator",
"data_juicer": "trinity.buffer.operators.data_juicer_operator.DataJuicerOperator",
"invalid_reward_filter": "trinity.buffer.operators.filters.reward_filter.InvalidRewardFilter",
"multi_step_padding": "trinity.buffer.operators.multi_step_operator.MultiStepPadding",
},
)

Expand Down
103 changes: 103 additions & 0 deletions trinity/buffer/operators/multi_step_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing import List, Tuple

import torch

from trinity.buffer.operators import ExperienceOperator
from trinity.common.experience import EID, Experience, group_by
from trinity.utils.log import get_logger

logger = get_logger(__name__)


class MultiStepPadding(ExperienceOperator):
    """
    Pad every run up to ``max_steps`` experiences with placeholder experiences.

    Experiences are grouped by task and then by run. A run holding fewer than
    ``max_steps`` experiences gets enough placeholders prepended to reach
    ``max_steps``; a run that already holds ``max_steps`` or more is passed
    through untouched.

    Note: This operator assumes that the reward is already calculated and stored in the Experience object.
    """

    def __init__(self, max_steps: int = 0):
        # Target number of experiences per run; with the default of 0 no run
        # is ever shorter than the target, so nothing gets padded.
        self.max_steps = max_steps

    def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]:
        """Pad each rollout to ``max_steps`` experiences.

        Args:
            exps: Experiences to pad; they are regrouped by task, then by run.

        Returns:
            A tuple of the padded experience list (emitted group by group,
            placeholders first within a run) and a metrics dict holding the
            number of placeholders added under ``"total_num_placeholder_exps"``.
        """
        logger.debug(f"Processing {len(exps)} experiences")
        padded: List[Experience] = []
        placeholder_total = 0

        for same_task in group_by(exps, "task").values():
            for run_steps in group_by(same_task, "run").values():
                if not run_steps:
                    continue
                if len(run_steps) >= self.max_steps:
                    # Already long enough: forward the run unchanged.
                    padded.extend(run_steps)
                    continue

                missing = self.max_steps - len(run_steps)
                fillers = self._build_placeholders(run_steps, missing)
                logger.debug(f"Adding {missing} placeholder experiences")
                # Placeholders go first; per the original author's note the
                # advantage is computed using the last experience of a run.
                padded.extend(fillers + run_steps)
                placeholder_total += missing

        logger.debug(f"After padding: {len(padded)}")
        return padded, {"total_num_placeholder_exps": placeholder_total}

    def _build_placeholders(self, run_steps: List[Experience], count: int) -> List[Experience]:
        """Create ``count`` placeholder experiences modelled on ``run_steps``.

        Every placeholder reuses the prompt of the run's first experience,
        carries a response of the run's average response length (at least one
        token) with an all-False action mask, and is given the run's average
        reward. The token and logprob tensors are shared by all placeholders.
        """
        assert all(step.tokens is not None for step in run_steps), "Tokens are not provided"
        # Use the average response length to keep length metrics unchanged,
        # clamped to 1 so a placeholder never has a zero-length response.
        lengths = [len(step.tokens) - step.prompt_length for step in run_steps]  # type: ignore
        filler_len = max(int(sum(lengths) / len(lengths)), 1)

        # The first experience of the run acts as the template.
        template = run_steps[0]
        prompt_len = template.prompt_length
        prompt_part = template.tokens[:prompt_len]  # type: ignore
        # The response part repeats the last prompt token (0 if the prompt is empty).
        fill_value = prompt_part[-1] if len(prompt_part) > 0 else 0
        filler_response = torch.full(
            (filler_len,),
            fill_value,
            dtype=template.tokens.dtype,  # type: ignore
        )
        filler_tokens = torch.cat([prompt_part, filler_response])

        if template.logprobs is not None:
            filler_logprobs = torch.zeros(filler_len, dtype=torch.float32)
        else:
            filler_logprobs = None

        assert all(step.reward is not None for step in run_steps), "Rewards are not provided"
        known_rewards = [step.reward for step in run_steps if step.reward is not None]
        mean_reward = sum(known_rewards) / len(known_rewards)

        source_eid = template.eid
        return [
            Experience(
                eid=EID(
                    batch=source_eid.batch,
                    task=source_eid.task,
                    run=source_eid.run,
                    step=-1,
                ),  # -1 means placeholder
                tokens=filler_tokens,
                logprobs=filler_logprobs,
                prompt_length=prompt_len,
                action_mask=torch.zeros(filler_len, dtype=torch.bool),
                truncate_status="prompt_truncated",  # TODO: merge with the following
                info={"status": "placeholder"},  # TODO: use another field
                reward=mean_reward,
            )
            for _ in range(count)
        ]
15 changes: 15 additions & 0 deletions trinity/common/experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,9 @@ def gather(
else:
teacher_logprobs = None

# gather statuses
statuses = gather_statuses(experiences)

exps = Experiences(
eids=eids,
tokens=tokens,
Expand All @@ -379,6 +382,7 @@ def gather(
logprobs=logprobs,
multi_modal_inputs=multi_modal_inputs,
teacher_logprobs=teacher_logprobs,
statuses=statuses,
)
if custom_fields is not None:
for custom_field in custom_fields:
Expand Down Expand Up @@ -465,6 +469,7 @@ class Experiences:
prompt_length: int
logprobs: Optional[Tensor] # [batch_size, response_length]
multi_modal_inputs: Optional[Any]
statuses: Optional[Tensor] = None # [batch_size] # 1 for effective, 0 for placeholder
custom_fields: List[str] = field(
default_factory=list
) # Custom fields to include in the gathered experiences
Expand Down Expand Up @@ -605,6 +610,16 @@ def gather_multi_modal_inputs(experiences) -> Dict[str, Tensor]:
return {key: [exp.multi_modal_inputs[key] for exp in experiences] for key in keys}


def gather_statuses(experiences) -> Tensor:
    """Build a boolean mask separating effective experiences from placeholders.

    An experience is a placeholder when its ``info`` mapping has
    ``"status" == "placeholder"``; every other experience is effective. An
    experience whose ``info`` is ``None`` or empty is treated as effective
    rather than raising ``AttributeError``.

    Args:
        experiences: Iterable of objects exposing an ``info`` attribute that
            is a mapping or ``None``.

    Returns:
        A 1-D ``torch.bool`` tensor with one entry per experience, ``True``
        for effective and ``False`` for placeholder. Empty input yields an
        empty tensor.
    """
    return torch.tensor(
        [(exp.info or {}).get("status") != "placeholder" for exp in experiences],
        dtype=torch.bool,
    )


def group_by(
experiences: List[Experience], id_type: Literal["task", "run", "step"]
) -> Dict[str, List[Experience]]:
Expand Down
21 changes: 20 additions & 1 deletion trinity/trainer/verl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,35 @@
gather_action_masks,
gather_attention_masks,
gather_response_attrs,
gather_statuses,
gather_token_ids,
split_dpo_experience_to_single_turn,
)


def print_effective_experience_stats(experiences: List[Experience], logger: Logger) -> None:
    """Log the number of effective experiences and the matching reweight factor.

    The effective count is the sum of the mask returned by
    ``gather_statuses``. The reweight factor is the batch size divided by that
    count, or 1.0 when the count is zero. Nothing is returned; the values are
    only written to ``logger`` at INFO level.

    Args:
        experiences: The batch of experiences to inspect.
        logger: Logger that receives the statistics.
    """
    effective_experiences = gather_statuses(experiences).sum().item()
    batch_size = len(experiences)
    if effective_experiences != 0:
        effective_weight = float(batch_size / effective_experiences)
    else:
        # Avoid dividing by zero when every experience is a placeholder.
        effective_weight = 1.0
        logger.info("No effective experiences found, using default weight 1.0")
    logger.info(
        f"Effective experiences: {effective_experiences}, batch size: {batch_size}, effective_weight: {effective_weight}"
    )


def to_data_proto(
experiences: List[Experience], pad_token_id: int, processor: ProcessorMixin, logger: Logger
) -> DataProto: # noqa: C901
"""Convert List[Experience] to verl DataProto."""

print_effective_experience_stats(experiences, logger)

assert len(experiences) > 0, "No experiences provided."
if experiences[0].experience_type == "dpo":
experiences = split_dpo_experience_to_single_turn(experiences)
Expand All @@ -49,7 +69,6 @@ def to_data_proto(
"attention_mask": attention_mask,
"response_mask": gather_action_masks(experiences, max_response_length),
}

have_reward = all(exp.reward is not None for exp in experiences)
have_token_level_reward = all(exp.token_level_reward is not None for exp in experiences)
if have_reward or have_token_level_reward:
Expand Down