
feat(archon): implement LoRA infrastructure with FSDP2/DTensor compatibility and PEFT checkpointing#1015

Open
MikaStars39 wants to merge 8 commits into inclusionAI:main from MikaStars39:lora

Conversation

@MikaStars39

Summary
This PR introduces Phases 1 and 2 of the LoRA (Low-Rank Adaptation) infrastructure for the Archon engine: a parallel-aware implementation of LoRA that integrates with Tensor Parallelism (TP) and FSDP2. Crucially, it resolves a known deadlock between FSDP2 Data Parallel (DP) reduce-scatter and DTensor TP operations during the backward pass. It also introduces HuggingFace PEFT-compatible checkpointing for adapter weights.

Key Features & Architectural Changes

  • FSDP2-Safe LoRA Implementation (lora_linear.py): implements a custom LoRALinear module.

  • Deadlock Fix: LoRA weights ($A$ and $B$ matrices) are stored as plain tensors (via object.__setattr__) rather than nn.Parameter. This intentional design bypasses FSDP2's post_accumulate_grad_hook, preventing the DP reduce-scatter operations from interleaving with DTensor TP operations, which previously caused diamond deadlocks.

  • Added sync_lora_grads to manually all-reduce LoRA weight gradients across both TP and DP groups before the optimizer step.

  • Archon Engine Integration (archon_engine.py):

  • Added dynamic LoRA application (_apply_lora) to target modules based on LoRAConfig.

  • Implemented _freeze_non_lora_params to lock base model weights while keeping adapter parameters trainable.

  • Integrated LoRA initialization with the existing parallelization pipeline, ensuring LoRA is injected after TP/CP so that tensor-parallel planning operates correctly on nn.Linear.

  • PEFT-Compatible Checkpointing (archon_lora_checkpoint.py & base.py):

  • LoRA adapters are saved and loaded in HuggingFace's PEFT format (adapter_model.safetensors and adapter_config.json).

  • Introduced module name mapping in Qwen2StateDictAdapter to automatically translate Archon-specific FQN paths (e.g., layers.0.attention.wq.lora_a) to HF PEFT paths (e.g., self_attn.q_proj.lora_A).

  • Added stripping of adapter parameters from base HuggingFace checkpoints during the initial load to prevent missing key errors.

  • Weight Sync & Reliability (archon_weight_sync.py & remote_inf_engine.py):

  • Improved the reliability of cross-node weight synchronization by implementing an atomic swap (.tmp to final) for the .areal_weight_update_ready signal file.

  • Updated the remote inference engine to prioritize checking the disk-based ready file over the legacy name-resolve key to prevent timeouts.

  • Bug Fixes:

  • Gradient Norm (grad.py): Fixed a hanging issue during get_grad_norm_fp32 when grads_for_norm is empty on certain ranks by ensuring they still participate in the all_reduce with a zero contribution.

  • Removed stale training debug info for cleaner logging.
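The plain-tensor trick behind the deadlock fix can be sketched as follows. This is an illustrative minimal version, not the code in lora_linear.py; the sync_lora_grads signature here is assumed, and the real implementation additionally handles meta-device materialization and DTensor sharding:

```python
import math

import torch
import torch.distributed as dist
import torch.nn as nn


class LoRALinear(nn.Module):
    """Sketch: LoRA weights held as plain tensors so FSDP2 never wraps them."""

    def __init__(self, base: nn.Linear, rank: int, alpha: float):
        super().__init__()
        self.base = base
        self.scaling = alpha / rank
        # object.__setattr__ bypasses nn.Module parameter registration, so
        # these tensors are invisible to FSDP2 and its
        # post_accumulate_grad_hook is never installed on them -- this is the
        # deadlock avoidance described above.
        object.__setattr__(
            self, "_lora_a_weight",
            torch.randn(rank, base.in_features) / math.sqrt(rank),
        )
        object.__setattr__(
            self, "_lora_b_weight", torch.zeros(base.out_features, rank)
        )
        self._lora_a_weight.requires_grad_(True)
        self._lora_b_weight.requires_grad_(True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Zero-initialized B means the LoRA branch starts as an exact no-op.
        lora = (x @ self._lora_a_weight.t()) @ self._lora_b_weight.t()
        return self.base(x) + self.scaling * lora


def sync_lora_grads(model: nn.Module, process_groups=()) -> None:
    """Manually average LoRA grads over each group before optimizer.step()."""
    for m in model.modules():
        if isinstance(m, LoRALinear):
            for w in (m._lora_a_weight, m._lora_b_weight):
                if w.grad is not None:
                    for pg in process_groups:
                        dist.all_reduce(w.grad, op=dist.ReduceOp.SUM, group=pg)
                        w.grad /= dist.get_world_size(group=pg)
```

Because the weights are plain tensors, they never show up in model.parameters() and must be handed to the optimizer explicitly (which is what a helper like this PR's _get_all_parameters is for).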

Testing

  • Added comprehensive unit tests for LoRALinear forward/backward passes, dropout behavior, and PEFT mathematical equivalence.
  • Added tests for HuggingFace/Archon state dict key translation and round-trip conversion.
  • Verified TP/CP ordering with mocked parallelize functions.

NJX-njx and others added 5 commits March 6, 2026 21:18
Port LoRA core infrastructure and checkpointing from fw/archon-lora branch
to main, implementing Phase 1 (core LoRA modules) and Phase 2 (PEFT-
compatible checkpoint I/O) from the lora-global-plan.

Phase 1 - Core LoRA Infrastructure:
- LoRALinear module following torchtune patterns (FSDP2-compatible)
- AdapterModule protocol for parameter extraction
- Utilities: get_adapter_params, set_trainable_params, disable/enable_adapter
- PEFT-compatible naming (lowercase lora_a/lora_b internally)
- Zero-initialization of lora_b ensures initial output matches base model
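The adapter utilities above can be pictured like this. A sketch under the assumption that LoRA weights are plain-tensor attributes named _lora_a_weight/_lora_b_weight; the real helpers live in adapter.py:

```python
import torch
import torch.nn as nn


def get_adapter_params(model: nn.Module) -> dict:
    """Walk all modules and pick up plain-tensor LoRA weights via getattr."""
    found = {}
    for name, module in model.named_modules():
        for attr in ("_lora_a_weight", "_lora_b_weight"):
            w = getattr(module, attr, None)
            if isinstance(w, torch.Tensor):
                found[f"{name}.{attr}" if name else attr] = w
    return found


def set_trainable_params(model: nn.Module, adapter_params: dict) -> None:
    """Freeze every registered parameter, then re-enable only the adapters."""
    for p in model.parameters():
        p.requires_grad_(False)
    for w in adapter_params.values():
        w.requires_grad_(True)


class _Demo(nn.Module):
    """Toy module with LoRA-style plain tensors attached to a Linear."""

    def __init__(self):
        super().__init__()
        self.wq = nn.Linear(4, 4)
        object.__setattr__(self.wq, "_lora_a_weight", torch.zeros(2, 4))
        object.__setattr__(self.wq, "_lora_b_weight", torch.zeros(4, 2))


demo = _Demo()
adapters = get_adapter_params(demo)
set_trainable_params(demo, adapters)
```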

Phase 2 - Checkpointing & PEFT Conversion:
- save_lora_adapter: Save adapter in PEFT format (safetensors + config)
- load_lora_adapter: Load adapter from PEFT format with key validation
- is_lora_adapter_checkpoint: Detect PEFT adapter checkpoints
- Qwen2StateDictAdapter: 16 LoRA key mappings + to_peft_module_map
- BaseStateDictAdapter: create_peft_adapter_config method
- ArchonEngine: lora_config attribute, LoRA-aware save/load
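For reference, the adapter_config.json that PEFT expects contains fields like the ones below. The field names are standard PEFT; the helper itself is a hypothetical sketch, not create_peft_adapter_config verbatim:

```python
import json


def make_peft_adapter_config(rank, alpha, target_modules, base_model_path):
    """Minimal subset of HF PEFT's adapter_config.json schema."""
    return {
        "peft_type": "LORA",
        "r": rank,
        "lora_alpha": alpha,
        "lora_dropout": 0.0,
        "target_modules": sorted(target_modules),
        "base_model_name_or_path": base_model_path,
        "bias": "none",
        "task_type": "CAUSAL_LM",
    }


config = make_peft_adapter_config(32, 32, ["q_proj", "v_proj"], "Qwen/Qwen2-7B")
print(json.dumps(config, indent=2))
```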

Tests:
- 18+ unit tests for LoRALinear (forward, gradient, from_linear, dropout)
- PEFT compatibility tests (skipped if PEFT not installed)
- State dict key conversion tests (all 16 mappings)
- Checkpoint detection tests
- Round-trip conversion tests

Ref inclusionAI#945
## Summary
This commit fixes critical distributed training deadlocks when using LoRA with FSDP2 and Tensor Parallelism, and improves weight synchronization reliability across training and rollout nodes.

## Key Changes

### 1. LoRA Architecture Refactor (Core Fix)
**Problem**: LoRA weights as `nn.Parameter` were managed by FSDP2, causing DP reduce-scatter to conflict with DTensor TP all-reduce during backward, leading to NCCL deadlocks.

**Solution**: Store LoRA weights as plain tensors via `object.__setattr__`, making them invisible to FSDP2.

**Files**:
- `areal/experimental/models/archon/lora/lora_linear.py`
  - Rewrote `LoRALinear` to use `object.__setattr__` for `_lora_a_weight`, `_lora_b_weight`
  - Added `lora_parameters()`, `materialize_lora()` for explicit weight management
  - Custom `_save_to_state_dict`/`_load_from_state_dict` with meta-device check
  - `sync_lora_grads` handles both TP and DP all-reduces

- `areal/experimental/models/archon/lora/adapter.py`
  - `get_adapter_params` now uses `getattr` to find plain tensor attributes

- `areal/experimental/engine/archon_engine.py`
  - `_get_all_parameters`: Collects both nn.Parameters and plain LoRA tensors
  - `_freeze_non_lora_params`: Calls `materialize_lora()` to move weights from meta device
  - `train_batch`: Updated `sync_lora_grads` call with explicit TP/DP groups

- `areal/experimental/models/archon/qwen2/model/state_dict_adapter.py`
  - Updated key mapping from `lora_a.weight` to `_lora_a_weight`

- `areal/experimental/engine/archon_lora_checkpoint.py`
  - Recognizes new `_lora_a`/`_lora_b` key prefixes
  - Uses `engine.cpu_group` (gloo) for barrier instead of NCCL
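The key translation can be sketched as a two-stage string rewrite. The map entries below are illustrative; the full 16-entry table lives in Qwen2StateDictAdapter:

```python
# Hypothetical subset of the module-name map (attention projections only).
TO_PEFT_MODULE_MAP = {
    "attention.wq": "self_attn.q_proj",
    "attention.wk": "self_attn.k_proj",
    "attention.wv": "self_attn.v_proj",
    "attention.wo": "self_attn.o_proj",
}
# Internal plain-tensor attribute names -> PEFT's lora_A/lora_B convention.
LORA_SUFFIX_MAP = {
    "_lora_a_weight": "lora_A.weight",
    "_lora_b_weight": "lora_B.weight",
}


def archon_to_peft_key(key: str) -> str:
    """Rewrite an Archon FQN into the corresponding HF PEFT key."""
    for src, dst in TO_PEFT_MODULE_MAP.items():
        key = key.replace(src, dst)
    for src, dst in LORA_SUFFIX_MAP.items():
        key = key.replace(src, dst)
    return key


print(archon_to_peft_key("layers.0.attention.wq._lora_a_weight"))
# layers.0.self_attn.q_proj.lora_A.weight
```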

### 2. Gradient Norm Fix
**Problem**: `get_grad_norm_fp32` returned early when `grads_for_norm` was empty (on TP rank > 0), causing all-reduce deadlock.

**Solution**: Ensure all ranks participate in all-reduce, contributing 0.0 if no local gradients.

**File**: `areal/engine/fsdp_utils/grad.py`
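A simplified sketch of the fix (the real function also handles norm types and sharded gradients):

```python
import torch
import torch.distributed as dist


def get_grad_norm_fp32(grads_for_norm, group=None):
    """L2 grad norm where every rank joins the all-reduce."""
    if grads_for_norm:
        local_sq = sum((g.float() ** 2).sum() for g in grads_for_norm)
    else:
        # The old code returned early here, leaving peer ranks blocked inside
        # all_reduce; contributing 0.0 keeps the collective call matched.
        local_sq = torch.zeros((), dtype=torch.float32)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(local_sq, op=dist.ReduceOp.SUM, group=group)
    return local_sq.sqrt()
```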

### 3. Meta Tensor Fix
**Problem**: `Cannot copy out of meta tensor` error during model initialization when state dict was accessed before `materialize_lora()` was called.

**Solution**: Skip adding LoRA weights to state dict if still on meta device.

**File**: `areal/experimental/models/archon/lora/lora_linear.py`

### 4. Weight Sync Reliability
**Problem**: NFS-based `name_resolve` signals had cross-node visibility issues causing 600s timeouts.

**Solution**: Dual-signal mechanism using checkpoint directory ready file + legacy name_resolve.

**Files**:
- `areal/experimental/engine/archon_weight_sync.py`
  - Writes `.areal_weight_update_ready` file in checkpoint directory after save
  - Timeout increased from 120s to 600s

- `areal/infra/remote_inf_engine.py`
  - `_wait_for_disk_weight_update_ready`: Prefer ready file in checkpoint dir, fallback to name_resolve
  - More informative timeout error messages
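The atomic ready-file handshake can be sketched as follows. The marker filename and the 600 s default mirror the PR; the function names are assumptions for illustration:

```python
import os
import tempfile
import time

READY_FILE = ".areal_weight_update_ready"


def signal_weights_ready(ckpt_dir: str, version: int) -> None:
    """Publish the marker via tmp-file + rename; readers never see a partial file."""
    tmp = os.path.join(ckpt_dir, READY_FILE + ".tmp")
    final = os.path.join(ckpt_dir, READY_FILE)
    with open(tmp, "w") as f:
        f.write(str(version))
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, final)  # rename within one directory is atomic on POSIX


def wait_weights_ready(ckpt_dir: str, timeout_s: float = 600.0,
                       poll_s: float = 0.5) -> int:
    """Poll for the marker; callers fall back to name_resolve on timeout."""
    final = os.path.join(ckpt_dir, READY_FILE)
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if os.path.exists(final):
            with open(final) as f:
                return int(f.read())
        time.sleep(poll_s)
    raise TimeoutError(f"no {READY_FILE} in {ckpt_dir} after {timeout_s}s")


ckpt_dir = tempfile.mkdtemp()
signal_weights_ready(ckpt_dir, version=7)
```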

### 5. Configuration
**File**: `recipe/areal/lora_qwen.yaml`
- `cluster.name_resolve.nfs_record_root`: Changed from `/tmp/areal/name_resolve` to `${cluster.fileroot}/name_resolve` for shared storage

## Testing
- LoRA PPO training with 4D parallelism [pp=1, dp=4, cp=1, tp=2] on DeepSeek-R1-Distill-Qwen-7B
- Verified 32 PPO minibatches complete successfully
- Gradient norm calculation works correctly across all ranks
- Weight synchronization stable across multiple training steps

## Related Issues
- Error 15-19: Sequential fixes for training hangs, autograd chain breaks, NCCL deadlocks, meta tensor errors
- Weight sync timeout: Rolled out dual-signal approach for cross-node reliability
@gemini-code-assist

Summary of Changes


This pull request significantly enhances the Archon engine by introducing a robust LoRA (Low-Rank Adaptation) infrastructure. It provides a parallel-aware implementation that integrates seamlessly with existing Tensor Parallelism and FSDP2 setups, crucially addressing a known deadlock. The changes also enable PEFT-compatible checkpointing for LoRA adapter weights, allowing for interoperability with HuggingFace ecosystems. Additionally, it refines weight synchronization mechanisms and includes a critical bug fix related to gradient norm calculation in distributed environments.

Highlights

  • LoRA Infrastructure: Implemented a custom LoRALinear module and integrated LoRA application into the Archon engine, ensuring compatibility with Tensor Parallelism (TP) and FSDP2. LoRA is injected after TP/CP to maintain correct tensor-parallel planning.
  • Deadlock Fix: Resolved a deadlock issue between FSDP2 Data Parallel (DP) reduce-scatter and DTensor TP operations by storing LoRA weights as plain tensors (not nn.Parameter), bypassing FSDP2's post_accumulate_grad_hook. Manual gradient all-reduction for LoRA weights is performed via sync_lora_grads.
  • PEFT-Compatible Checkpointing: Introduced functionality to save and load LoRA adapters in HuggingFace's PEFT format (adapter_model.safetensors and adapter_config.json). This includes module name mapping for Archon-specific FQN paths to HF PEFT paths and stripping adapter parameters from base HF checkpoints during initial load.
  • Weight Synchronization Reliability: Improved cross-node weight synchronization by implementing an atomic swap for the .areal_weight_update_ready signal file and updated the remote inference engine to prioritize this disk-based ready file.
  • Gradient Norm Bug Fix: Fixed a hanging issue in get_grad_norm_fp32 by ensuring all ranks, even those without gradients, participate in the all_reduce operation with a zero contribution.


Changelog
  • areal/engine/fsdp_utils/grad.py
    • Modified get_grad_norm_fp32 to prevent hangs by ensuring all ranks participate in all_reduce even if they have no gradients.
  • areal/experimental/engine/archon_checkpoint.py
    • Added logic to strip LoRA adapter parameters from HuggingFace state dicts before loading base models.
    • Modified missing keys filtering to exclude LoRA adapter keys during model loading.
  • areal/experimental/engine/archon_engine.py
    • Added lora_config attribute to ArchonEngine and initialized it based on TrainEngineConfig.
    • Integrated _freeze_non_lora_params call into the initialize method when LoRA is enabled.
    • Added sync_lora_grads call after forward_backward_batch to manually all-reduce LoRA gradients.
    • Modified save method to conditionally use save_lora_adapter for LoRA-enabled models.
    • Modified load method to conditionally use load_lora_adapter for PEFT adapter checkpoints.
    • Passed _apply_lora function to parallelization methods to inject LoRA after TP/CP.
    • Implemented _apply_lora to dynamically replace nn.Linear modules with LoRALinear.
    • Implemented _freeze_non_lora_params to manage LoRA parameter trainability and initialization.
    • Updated _get_all_parameters to include LoRA parameters.
  • areal/experimental/engine/archon_lora_checkpoint.py
    • Added save_lora_adapter function to save LoRA adapters in PEFT format.
    • Added load_lora_adapter function to load LoRA adapters from PEFT format checkpoints.
    • Added is_lora_adapter_checkpoint function to detect PEFT LoRA adapter checkpoints.
  • areal/experimental/engine/archon_runner.py
    • Added total_mbs and mb_idx to the minibatch processing loop.
    • Removed a redundant comment regarding result types.
  • areal/experimental/engine/archon_weight_sync.py
    • Defined WEIGHT_UPDATE_READY_FILE constant for atomic weight update signaling.
    • Modified update_weights_from_disk to use save_lora_adapter for LoRA models and implemented an atomic file-based ready signal.
  • areal/experimental/models/archon/base.py
    • Added to_peft_module_map attribute to BaseStateDictAdapter for LoRA module name mapping.
    • Added create_peft_adapter_config method to generate PEFT-compatible adapter_config.json.
  • areal/experimental/models/archon/lora/__init__.py
    • Created the lora package and exposed its public API for LoRA modules and utilities.
  • areal/experimental/models/archon/lora/adapter.py
    • Defined AdapterModule protocol for modules containing adapter parameters.
    • Added get_adapter_params to extract adapter parameters from a model.
    • Added set_trainable_params to freeze/unfreeze model parameters.
    • Added get_adapter_state_dict to filter state dictionaries for adapter parameters.
    • Added disable_adapter and enable_adapter functions to control LoRA adapter activation.
  • areal/experimental/models/archon/lora/lora_linear.py
    • Implemented LoRALinear module, a custom linear layer for LoRA.
    • Designed LoRALinear to store LoRA weights as plain tensors to prevent FSDP2 deadlocks.
    • Included sync_lora_grads function for manual all-reduction of LoRA gradients across TP and DP groups.
  • areal/experimental/models/archon/qwen2/infra/parallelize.py
    • Imported Callable type for function annotations.
    • Added apply_lora_fn parameter to parallelize_qwen2.
    • Called apply_lora_fn after TP/CP to ensure correct LoRA injection order.
  • areal/experimental/models/archon/qwen3/infra/parallelize.py
    • Imported Callable type for function annotations.
    • Added apply_lora_fn parameter to parallelize_qwen3.
    • Called apply_lora_fn after TP/EP/CP for correct LoRA injection order.
  • areal/infra/remote_inf_engine.py
    • Added _wait_for_disk_weight_update_ready function to prioritize file-based weight update signals.
    • Updated _wait_for_disk_weight_update_ready call and increased its timeout.
  • areal/trainer/ppo/actor.py
    • Added _n_mbs and _mb_idx variables to the minibatch loop for internal tracking.
  • areal/trainer/rl_trainer.py
    • Removed a comment related to waiting for async checkpoint staging.
  • areal/utils/logging.py
    • Added 'LoRACheckpoint' to the LOG_COLORS dictionary for colored logging output.
  • tests/experimental/archon/test_archon_engine_lora.py
    • Added unit tests for ArchonEngine's LoRA integration, covering _apply_lora and _freeze_non_lora_params.
    • Included tests for update_weights_from_disk behavior with LoRA.
    • Verified the correct application order of LoRA within parallelization strategies.
  • tests/experimental/archon/test_archon_lora_checkpoint.py
    • Added unit tests for Qwen2StateDictAdapter's LoRA key conversions.
    • Included tests for PEFT adapter configuration generation.
    • Verified the functionality of is_lora_adapter_checkpoint for detecting PEFT checkpoints.
    • Tested state dict round-trip conversion with LoRA keys.
  • tests/experimental/archon/test_lora_linear.py
    • Added comprehensive unit tests for the LoRALinear module, including initialization, forward/backward passes, and dropout.
    • Tested from_linear conversion and AdapterModule protocol implementation.
    • Included tests for adapter utility functions like get_adapter_params and set_trainable_params.
    • Provided compatibility tests against HuggingFace PEFT's LoRA Linear module for forward pass, gradient flow, and scaling factor.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a comprehensive LoRA (Low-Rank Adaptation) infrastructure compatible with FSDP2 and DTensor, which is a significant feature. The implementation correctly handles the complexities of distributed training, including a fix for a potential deadlock. It also adds PEFT-compatible checkpointing, which is great for interoperability. The code is generally well-structured, but there are a few areas for improvement regarding code style, clarity, and correctness in tests. Most notably, several new test files appear to be written for a different, outdated implementation of LoRALinear and will not work with the submitted code. This is a critical issue that needs to be addressed to ensure the new functionality is properly tested.

Comment on lines +69 to +80
def test_freeze_non_lora_params_keeps_only_adapter_trainable():
model = _ToyBlock()
engine = _make_engine(model, ["wq"])

engine._apply_lora()
engine._freeze_non_lora_params()

assert model.wq.weight.requires_grad is False
assert model.wq.lora_a.weight.requires_grad is True
assert model.wq.lora_b.weight.requires_grad is True
assert model.other.weight.requires_grad is False
assert model.inner.wv.weight.requires_grad is False

critical

This test appears to be written for a different implementation of LoRALinear. It accesses model.wq.lora_a.weight, but the LoRALinear implementation in this PR does not have a lora_a submodule; it uses a plain tensor attribute _lora_a_weight. This test will fail and does not correctly validate the freezing logic for the submitted code. The test needs to be updated to access the LoRA weights correctly (e.g., model.wq._lora_a_weight).

@@ -0,0 +1,528 @@
"""Unit tests for LoRALinear module and adapter utilities."""

critical

The tests in this file appear to be written for a different implementation of LoRALinear. The current implementation in areal/experimental/models/archon/lora/lora_linear.py stores LoRA weights as plain tensor attributes (e.g., _lora_a_weight) to avoid FSDP hooks. However, these tests attempt to access them as if they were nn.Linear submodules (e.g., lora_linear.lora_a.weight).

This mismatch means the tests will fail and are not validating the submitted code. The tests need to be updated to reflect the actual LoRALinear implementation. For example, lora_linear.lora_b.weight should be lora_linear._lora_b_weight.

Comment on lines +198 to +206
@dataclass
class LoRAConfig:
enabled: bool
rank: int
alpha: float
target_modules: list[str]

self.lora_config = LoRAConfig(
enabled=True,

medium

Defining the LoRAConfig dataclass inside the __init__ method causes it to be redefined on every instantiation of ArchonEngine. It would be better for clarity, performance, and potential reuse to define it at the module level.

Comment on lines +513 to +515
)
sync_lora_grads(
self.model,

medium

Local imports can obscure dependencies and are best avoided unless there's a specific reason like preventing circular imports. Consider moving this import to the top of the file for better code clarity and consistency.

Comment on lines +1085 to +1086
if self.lora_config is not None:
from areal.experimental.models.archon.lora import LoRALinear

medium

The method _get_all_parameters is type-hinted to return list[nn.Parameter], but it is being extended with torch.Tensor objects from module.lora_parameters(). This creates a type inconsistency. Please update the type hint to reflect the actual return type, for example list[torch.Tensor] or typing.Union[nn.Parameter, torch.Tensor].

Suggested change
-    if self.lora_config is not None:
-        from areal.experimental.models.archon.lora import LoRALinear
+    def _get_all_parameters(self) -> list[torch.Tensor]:
+        params: list[torch.Tensor] = [p for m in self.model_parts for p in m.parameters()]

if self._tp_enabled:
result = self._tp_lora_forward(x, base_out)
if result.requires_grad and hasattr(self, "_debug_name"):
_name = self._debug_name

medium

The variable _name is assigned but never used. This appears to be dead code and should be removed.

forward_only: bool,
) -> list[torch.Tensor | dict[int, torch.Tensor]]:
results: list[torch.Tensor | dict[int, torch.Tensor]] = []
total_mbs = len(mb_list)

medium

The variable total_mbs is defined but never used. This is dead code and should be removed to improve clarity.

     total_mbs = len(mb_list)

-    for mb_item in mb_list:
+    for mb_idx, mb_item in enumerate(mb_list):

medium

The variable mb_idx is defined but never used. Consider using for mb_item in mb_list: instead.

Suggested change
-    for mb_idx, mb_item in enumerate(mb_list):
+    for mb_item in mb_list:

with stats_tracker.scope("update"):
# Get current version for proximal approximation metrics
current_version = self.engine.get_version()
_n_mbs = len(mb_inputs.mbs)

medium

The variable _n_mbs is defined but never used. This is dead code and should be removed to improve clarity.

     _n_mbs = len(mb_inputs.mbs)

-    for mb in mb_inputs.mbs:
+    for _mb_idx, mb in enumerate(mb_inputs.mbs):

medium

The variable _mb_idx is defined but never used. Consider using for mb in mb_inputs.mbs: instead.

Suggested change
-    for _mb_idx, mb in enumerate(mb_inputs.mbs):
+    for mb in mb_inputs.mbs:

@MikaStars39
Author

Testing results on DAPO-math-17k with DeepSeek-R1-Distill-Qwen-1.5B:
[training results screenshot]

settings:

experiment_name: gsm8k-grpo
trial_name: trial0

seed: 1
enable_offload: false
total_train_epochs: 10
tokenizer_path: ${actor.path}

cluster:
  n_nodes: 1
  n_gpus_per_node: 8
  fileroot: /your_path/qingyu/PeRL/outputs
  name_resolve:
    type: nfs
    nfs_record_root: ${cluster.fileroot}/name_resolve

allocation_mode: sglang:d4+archon:d2t2

scheduler:
  type: null

rollout:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  max_concurrent_rollouts: 128
  queue_size: null
  consumer_batch_size: ${train_dataset.batch_size}
  max_head_offpolicyness: 2
  enable_rollout_tracing: true
  scheduling_spec: ${actor.scheduling_spec}
  use_lora: true
  fileroot: ${cluster.fileroot}
  tokenizer_path: ${tokenizer_path}
  dump_to_file: true

gconfig:
  n_samples: 8
  min_new_tokens: 0
  max_new_tokens: 16384
  greedy: false
  temperature: 1.0
  lora_name: "lora"

actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: /your_path/qingyu/.cache/DeepSeek-R1-Distill-Qwen-1.5B
  init_from_scratch: false
  disable_dropout: true
  gradient_checkpointing: true
  dtype: bfloat16
  mb_spec:
    max_tokens_per_mb: 32768
  optimizer:
    type: adam
    lr: 2e-5
    weight_decay: 0.01
    beta1: 0.9
    beta2: 0.98
    eps: 1e-8
    lr_scheduler_type: constant
    gradient_clipping: 1.0
    warmup_steps_proportion: 0.001
  eps_clip: 0.2
  eps_clip_higher: 0.28
  temperature: ${gconfig.temperature}
  reward_scaling: 10.0
  reward_bias: -0.5
  kl_ctl: 0.0
  ppo_n_minibatches: 4
  recompute_logprob: true
  use_decoupled_loss: true
  behave_imp_weight_cap: 5.0
  reward_norm:
    mean_level: group
    std_level: group
    group_size: ${gconfig.n_samples}
  adv_norm:
    mean_level: batch
    std_level: batch
  max_new_tokens: ${gconfig.max_new_tokens}
  weight_update_mode: disk  # must be disk

  # lora
  use_lora: ${rollout.use_lora}
  peft_type: lora
  lora_rank: 32
  lora_alpha: 32
  target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]
  scheduling_spec:
    - task_type: worker
      port_count: 2
      gpu: 1
      mem: 32
      cmd: python3 -m areal.infra.rpc.rpc_server
      env_vars: {}

ref:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: ${actor.path}
  init_from_scratch: false
  disable_dropout: true
  dtype: ${actor.dtype}
  mb_spec:
    max_tokens_per_mb: 32768
  optimizer: null
  scheduling_strategy:
    type: colocation
    target: actor
  scheduling_spec: ${actor.scheduling_spec}

# SGLang
sglang:
  model_path: ${actor.path}
  random_seed: ${seed}
  skip_tokenizer_init: true
  dtype: ${actor.dtype}
  max_running_requests: null
  context_length: 32768
  mem_fraction_static: 0.8
  # lora
  enable_lora: ${actor.use_lora}
  max_lora_rank: ${actor.lora_rank}

# datasets
train_dataset:
  batch_size: 64
  shuffle: true
  pin_memory: true
  num_workers: 4
  path: /your_path/qingyu/.cache/dapo-math-17k
  type: rl

valid_dataset:
  batch_size: 64
  pin_memory: true
  num_workers: 4
  path: /your_path/qingyu/.cache/aime-2024
  type: rl

# Utilities
saver:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: null
  freq_steps: 32
  freq_secs: null

recover:
  mode: disabled
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: null
  freq_steps: 32
  freq_secs: null

evaluator:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  freq_epochs: null
  freq_steps: 32
  freq_secs: null

stats_logger:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}

perf_tracer:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  fileroot: ${cluster.fileroot}
  enabled: false
  session_tracer:
    enabled: false
