-
Notifications
You must be signed in to change notification settings - Fork 55
[Experimental] Add a MultiStepPadding Operator #511
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
hiyuchang
wants to merge
2
commits into
agentscope-ai:main
Choose a base branch
from
hiyuchang:feat/regularize_multi_turn
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| project: "FrozenLake" | ||
| name: "Qwen25-3B-padding" | ||
| checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} | ||
| algorithm: | ||
| algorithm_type: multi_step_grpo | ||
| repeat_times: 8 | ||
| kl_loss_fn: "low_var_kl" | ||
| kl_loss_fn_args: | ||
| kl_coef: 0 | ||
| advantage_fn_args: | ||
| epsilon: 1e-6 | ||
| optimizer: | ||
| lr: 1e-6 | ||
| policy_loss_fn_args: | ||
| clip_range_low: 0.2 | ||
| clip_range_high: 0.28 | ||
| data_processor: | ||
| experience_pipeline: | ||
| operators: # NOTE | ||
| - name: multi_step_padding | ||
| args: | ||
| max_steps: 10 | ||
| model: | ||
| model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct} | ||
| max_response_tokens: 2048 | ||
| min_response_tokens: 0 | ||
| max_model_len: 25600 | ||
| temperature: 0.7 | ||
| cluster: | ||
| node_num: 1 | ||
| gpu_per_node: 8 | ||
| buffer: | ||
| total_epochs: 5 | ||
| batch_size: 64 | ||
| train_batch_size: 5120 # NOTE: 64 * 8 * 10 = batch_size * repeat_times * max_steps | ||
| explorer_input: | ||
| taskset: | ||
| name: frozenlake | ||
| storage_type: file | ||
| path: ${oc.env:TRINITY_TASKSET_PATH} | ||
| split: train | ||
| workflow_args: | ||
| env_max_steps: 8 | ||
| agent_max_steps: 10 | ||
| is_slippery: false | ||
| eval_tasksets: | ||
| - name: frozenlake | ||
| storage_type: file | ||
| path: ${oc.env:TRINITY_TASKSET_PATH} | ||
| split: test | ||
| workflow_args: | ||
| env_max_steps: 8 | ||
| agent_max_steps: 10 | ||
| is_slippery: false | ||
| repeat_times: 4 | ||
| rollout_args: | ||
| top_p: 0.8 | ||
| top_k: 20 | ||
| default_workflow_type: 'examples.agentscope_frozenlake.workflow.FrozenLakeWorkflow' | ||
| trainer_input: | ||
| experience_buffer: | ||
| name: frozenlake_experience_buffer | ||
| storage_type: queue | ||
| max_read_timeout: 7200 | ||
| explorer: | ||
| eval_on_startup: true | ||
| eval_interval: 20 | ||
| runner_per_model: 6 | ||
| max_repeat_times_per_runner: 4 | ||
| rollout_model: | ||
| engine_num: 4 | ||
| tensor_parallel_size: 1 | ||
| enable_chunked_prefill: true | ||
| enforce_eager: false | ||
| enable_openai_api: true | ||
| enable_log_requests: false | ||
| enable_history: true | ||
| enable_auto_tool_choice: true | ||
| tool_call_parser: hermes | ||
| enable_thinking: true | ||
| dtype: bfloat16 | ||
| seed: 42 | ||
| trainer: | ||
| save_interval: 50 | ||
| use_dynamic_bsz: true | ||
| grad_clip: 1.0 | ||
| ulysses_sequence_parallel_size: 2 | ||
| synchronizer: | ||
| sync_method: nccl | ||
| sync_style: fixed | ||
| sync_interval: 1 | ||
| sync_timeout: 1200 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| from typing import List, Tuple | ||
|
|
||
| import torch | ||
|
|
||
| from trinity.buffer.operators import ExperienceOperator | ||
| from trinity.common.experience import EID, Experience, group_by | ||
| from trinity.utils.log import get_logger | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
|
|
||
| class MultiStepPadding(ExperienceOperator): | ||
| """ | ||
| Padding experiences of one run to the max step. | ||
|
|
||
| Note: This operator assumes that the reward is already calculated and stored in the Experience object. | ||
| """ | ||
|
|
||
| def __init__(self, max_steps: int = 0): | ||
| self.max_steps = max_steps | ||
|
|
||
| def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]: | ||
| """Padding each rollout to the max step.""" | ||
| logger.debug(f"Processing {len(exps)} experiences") | ||
| total_num_placeholder_exps = 0 | ||
| all_exps = [] | ||
|
|
||
| task_exps = group_by(exps, "task") | ||
| for _, task_exp in task_exps.items(): | ||
| run_exps = group_by(task_exp, "run") | ||
| for _, exps_same_run in run_exps.items(): | ||
| if len(exps_same_run) == 0: | ||
| continue | ||
| num_placeholder_exps = 0 | ||
| if len(exps_same_run) < self.max_steps: | ||
| num_placeholder_exps = self.max_steps - len(exps_same_run) | ||
| # Calculate average response length to keep metrics unchanged | ||
| assert all( | ||
| exp.tokens is not None for exp in exps_same_run | ||
| ), "Tokens are not provided" | ||
| response_lengths = [ | ||
| len(exp.tokens) - exp.prompt_length for exp in exps_same_run # type: ignore | ||
| ] | ||
| avg_response_length = int(sum(response_lengths) / len(response_lengths)) | ||
| # Ensure at least 1 to avoid zero-length response | ||
| avg_response_length = max(avg_response_length, 1) | ||
|
|
||
| # Use the first experience as a template | ||
| template_exp = exps_same_run[0] | ||
| prompt_length = template_exp.prompt_length | ||
|
|
||
| # Create tokens with average response length | ||
| # Keep the prompt part, pad the response part to average length | ||
| prompt_tokens = template_exp.tokens[:prompt_length] # type: ignore | ||
| # Use the last token of prompt as padding token for response part | ||
| pad_token = prompt_tokens[-1] if len(prompt_tokens) > 0 else 0 | ||
| response_tokens = torch.full( | ||
| (avg_response_length,), | ||
| pad_token, | ||
| dtype=template_exp.tokens.dtype, # type: ignore | ||
| ) | ||
| avg_tokens = torch.cat([prompt_tokens, response_tokens]) | ||
| avg_logprobs = ( | ||
| torch.zeros(avg_response_length, dtype=torch.float32) | ||
| if template_exp.logprobs is not None | ||
| else None | ||
| ) | ||
| assert all( | ||
| exp.reward is not None for exp in exps_same_run | ||
| ), "Rewards are not provided" | ||
| rewards = [exp.reward for exp in exps_same_run if exp.reward is not None] | ||
| avg_reward = sum(rewards) / len(rewards) | ||
|
|
||
| template_eid = template_exp.eid | ||
|
|
||
| empty_experiences = [ | ||
| Experience( | ||
| eid=EID( | ||
| batch=template_eid.batch, | ||
| task=template_eid.task, | ||
| run=template_eid.run, | ||
| step=-1, | ||
| ), # -1 means placeholder | ||
| tokens=avg_tokens, | ||
| logprobs=avg_logprobs, | ||
| prompt_length=prompt_length, | ||
| action_mask=torch.zeros(avg_response_length, dtype=torch.bool), | ||
| truncate_status="prompt_truncated", # TODO: merge with the following | ||
| info={"status": "placeholder"}, # TODO: use another field | ||
| reward=avg_reward, | ||
| ) | ||
| for _ in range(num_placeholder_exps) | ||
| ] | ||
| logger.debug(f"Adding {num_placeholder_exps} placeholder experiences") | ||
| # Put empty at the beginning, as the adv is computed using the last exp | ||
| exps_same_run = empty_experiences + exps_same_run | ||
| all_exps.extend(exps_same_run) | ||
| else: | ||
| all_exps.extend(exps_same_run) | ||
| total_num_placeholder_exps += num_placeholder_exps | ||
| metrics = {"total_num_placeholder_exps": total_num_placeholder_exps} | ||
| logger.debug(f"After padding: {len(all_exps)}") | ||
| return all_exps, metrics | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.