
Changes for training #708

Open
azazeal04 wants to merge 19 commits into ace-step:main from azazeal04:main

Conversation

@azazeal04 azazeal04 commented Feb 26, 2026

Proposing improvements for the LoRA training and full-model fine-tune options:
Implemented a concrete dataset QA phase with a new validate-dataset subcommand in CLI parsing and dispatch, plus a validator that checks required tensor keys, flags non-finite values, and reports latent-length stats before training.
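The QA pass described above can be sketched in a torch-free form, with nested lists standing in for loaded tensors; the key names in REQUIRED_KEYS and the exact report fields are illustrative assumptions, not the PR's actual schema:

```python
# Torch-free sketch of a dataset QA pass: required-key checks, non-finite
# flagging, and latent-length stats. Key names and fields are assumptions.
import math
from statistics import mean

REQUIRED_KEYS = ("target_latents", "attention_mask")  # hypothetical schema

def validate_samples(samples):
    """Aggregate a QA report over already-loaded sample dicts."""
    report = {
        "total_samples": 0, "valid_samples": 0, "invalid_samples": 0,
        "nan_or_inf_samples": 0, "latent_lengths": [], "errors": [],
    }
    for i, sample in enumerate(samples):
        report["total_samples"] += 1
        missing = [k for k in REQUIRED_KEYS if k not in sample]
        if missing:
            report["invalid_samples"] += 1
            report["errors"].append(f"sample {i}: missing keys {missing}")
            continue
        flat = [v for row in sample["target_latents"] for v in row]
        if not all(math.isfinite(v) for v in flat):
            # Flag non-finite samples but still count them as loadable.
            report["nan_or_inf_samples"] += 1
            report["errors"].append(f"sample {i}: non-finite values")
        report["valid_samples"] += 1
        report["latent_lengths"].append(len(sample["target_latents"]))
    if report["latent_lengths"]:
        report["avg_latent_length"] = mean(report["latent_lengths"])
    return report
```

A real implementation would obtain each sample dict by iterating the dataset directory's .pt files with torch.load, as the validator in this PR does.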

Implemented data pipeline throughput improvements in the runtime loader: latent-length bucketing (BucketedBatchSampler) and optional RAM LRU caching in PreprocessedTensorDataset, and wired them through PreprocessedDataModule and trainer config flow.

Implemented full fine-tuning MVP (decoder-only) with training_mode=full: explicit freeze/unfreeze behavior, param-family LR group construction, and optimizer group usage in both Fabric and non-Fabric loops.
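The param-family LR group construction might look roughly like the sketch below; the substring matching on parameter names, the family names, and the multiplier defaults are assumptions rather than the PR's actual helper:

```python
# Sketch of per-family LR group construction for full-mode fine-tuning.
# Frozen (requires_grad=False) params never reach the optimizer groups.
def build_full_mode_param_groups(named_params, base_lr,
                                 attn_lr_mult=1.0, ffn_lr_mult=1.0,
                                 other_lr_mult=1.0):
    """Classify trainable params into attn / ffn / other LR groups."""
    families = {"attn": [], "ffn": [], "other": []}
    for name, param in named_params:
        if not getattr(param, "requires_grad", True):
            continue  # frozen by the freeze/unfreeze pass
        if "attn" in name:
            families["attn"].append(param)
        elif "ffn" in name or "mlp" in name:
            families["ffn"].append(param)
        else:
            families["other"].append(param)
    mults = {"attn": attn_lr_mult, "ffn": ffn_lr_mult, "other": other_lr_mult}
    return [{"name": fam, "params": ps, "lr": base_lr * mults[fam]}
            for fam, ps in families.items() if ps]
```

Groups shaped like this can be handed straight to an optimizer constructor such as torch.optim.AdamW, which is how grouped LRs are typically consumed.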

Added full-mode checkpoint/resume support (full_decoder_state.pt) and compatibility checks in save/verify/resume helpers, so full fine-tuning can be resumed the same way the adapter training state flow is.
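A hedged sketch of what a mode-compatibility check during resume could look like; the file name full_decoder_state.pt comes from the PR text, while the adapter file name, function name, and error messages are hypothetical:

```python
# Sketch of resume-state selection with a mode-compatibility check:
# full-mode checkpoints and adapter checkpoints must not be mixed up.
from pathlib import Path

def pick_resume_state(checkpoint_dir: str, training_mode: str) -> Path:
    """Choose which state file to resume from, refusing mode mismatches."""
    ckpt = Path(checkpoint_dir)
    full_state = ckpt / "full_decoder_state.pt"
    adapter_state = ckpt / "adapter_model.safetensors"  # hypothetical name
    if training_mode == "full":
        if not full_state.exists():
            raise FileNotFoundError(f"full-mode resume needs {full_state}")
        return full_state
    if full_state.exists():
        raise ValueError("checkpoint holds a full-mode decoder state; "
                         "resume with training_mode='full' instead")
    if not adapter_state.exists():
        raise FileNotFoundError(f"adapter resume needs {adapter_state}")
    return adapter_state
```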

Added focused unit tests for new behavior:

dataset validation success+edge path,

full-mode decoder-only trainability and param-group coverage.

Summary by CodeRabbit

  • New Features

    • Dataset validation command to scan preprocessed training data and report stats/errors
    • Length-based bucketing for training batches
    • RAM LRU caching option for dataset loading with configurable size
    • New "full" decoder-only fine-tuning mode with per-component LR multipliers
    • New collate helper for preprocessed tensors and added CLI/config flags for bucketing, caching, and full-mode
  • Tests

    • Unit tests for dataset validation and full-mode training behavior
  • Documentation

    • Proposal doc outlining full fine-tune and data throughput plans

…ora-training-pipeline

Add full fine-tune mode, dataset validation, length bucketing and dataset cache

coderabbitai bot commented Feb 26, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds length-based bucketed batching and an in-memory LRU cache for preprocessed datasets, a dataset validation CLI, and a "full" decoder-only fine-tuning mode with per-family LR multipliers; wires new flags through CLI, TrainingConfigV2, data modules, trainer, and checkpoint/resume helpers.

Changes

• Data module & legacy API (acestep/training/data_module.py):
Added BucketedBatchSampler; extended PreprocessedTensorDataset with latent_lengths, a RAM LRU cache (_cache) and cache_policy/cache_max_items; PreprocessedDataModule and AceStepDataModule accept and propagate length_bucket, cache_policy, cache_max_items; train_dataloader switches to the bucketed sampler when enabled.
• Preprocessed dataset & sampler (acestep/training/preprocessed_dataset.py, acestep/training/preprocessed_sampler.py, acestep/training/preprocessed_collate.py):
New PreprocessedTensorDataset that resolves manifests and optionally RAM-LRU caches samples; new BucketedBatchSampler grouping indices by coarse latent-length buckets; new collate_preprocessed_batch pads and stacks batched tensors.
• CLI, config wiring & docs (acestep/training_v2/cli/args.py, acestep/training_v2/cli/config_builder.py, acestep/training_v2/configs.py, docs/lora_full_finetune_pipeline_proposal.md, train.py):
Added validate-dataset subcommand; new CLI flags --length-bucket, --cache-policy, --cache-max-items, plus training_mode / full-train options and LR multipliers; TrainingConfigV2 extended with corresponding fields and serialization; train.py handles dataset validation reporting; design doc added.
• Full fine-tune support (FixedLoRA) (acestep/training_v2/fixed_lora_module.py, acestep/training_v2/fixed_lora_module_full_mode_test.py):
Added training_mode="full" support: helpers to classify attention/FFN params, _configure_full_finetune() to freeze non-decoder params, and build_full_mode_param_groups() creating LR-multiplier groups; tests added for decoder-only training and param-group coverage.
• Trainer integration & helpers (acestep/training_v2/trainer_basic_loop.py, acestep/training_v2/trainer_fixed.py, acestep/training_v2/trainer_helpers.py, acestep/training_v2/trainer_helpers_test.py):
Trainer now uses full-mode param groups when training_mode=="full"; data module constructed with new caching/bucketing flags; added _save_full_decoder_state()/load plus unified resume logic to support full-mode checkpointing and adapter/LoRA/LoKR resume paths; tests for resume warnings.
• Dataset validation & tests (acestep/training_v2/dataset_validation.py, acestep/training_v2/dataset_validation_test.py):
New validate_preprocessed_dataset() that scans .pt files for required keys, non-finite checks, and latent-length stats; unit tests added verifying valid/invalid/NaN reporting.
• Tests & small adjustments (acestep/training/data_module_test.py):
Added/updated tests to cover AceStepDataModule initialization and exported BucketedBatchSampler reference for compatibility testing.

Sequence Diagram

sequenceDiagram
    participant User as CLI/User
    participant Config as TrainingConfigV2
    participant Trainer as Trainer
    participant Module as FixedLoRAModule
    participant Optimizer as Optimizer
    participant DataModule as DataModule
    participant Dataset as PreprocessedTensorDataset

    rect rgba(100,150,200,0.5)
    Note over User,Module: Initialization with full-mode and data options
    User->>Trainer: start (training_mode="full", length_bucket=...)
    Trainer->>Config: read training_mode, lr multipliers, caching flags
    Config-->>Trainer: config
    Trainer->>Module: init with training_config
    Module->>Module: _configure_full_finetune() (freeze except decoder)
    Module-->>Trainer: module ready
    end

    rect rgba(150,100,200,0.5)
    Note over Trainer,Optimizer: Build optimizer params for full-mode
    Trainer->>Module: build_full_mode_param_groups()
    Module-->>Trainer: param groups (attn/ffn/other)
    Trainer->>Optimizer: init with grouped params
    Optimizer-->>Trainer: optimizer ready
    end

    rect rgba(100,200,150,0.5)
    Note over Trainer,Dataset: Data loading with bucketing & caching
    Trainer->>DataModule: create with length_bucket, cache_policy
    DataModule->>Dataset: init (compute latent_lengths, init LRU cache)
    DataModule->>DataModule: use BucketedBatchSampler if enabled
    DataModule-->>Trainer: dataloader ready
    end

    rect rgba(200,150,100,0.5)
    Note over Trainer,Dataset: Training loop sample flow
    Trainer->>Dataset: __getitem__(idx)
    Dataset->>Dataset: check _cache
    alt cache hit
        Dataset-->>Trainer: return cached tensors
    else cache miss
        Dataset->>Dataset: load from disk, store in LRU cache
        Dataset-->>Trainer: return tensors
    end
    Trainer->>Optimizer: backward & step
    Optimizer-->>Trainer: updated
    end
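The cache hit/miss flow in the diagram above can be sketched with an OrderedDict-based LRU; the loader callable stands in for torch.load, and the attribute names mirror those listed in the walkthrough but are otherwise assumptions:

```python
# Sketch of the __getitem__ cache flow: hit -> mark most-recent and return;
# miss -> load from disk, insert, evict the least-recently-used entry.
from collections import OrderedDict

class CachingDataset:
    """Dataset wrapper with an optional RAM LRU cache over loaded samples."""

    def __init__(self, paths, loader, cache_policy="none", cache_max_items=0):
        self.paths = paths
        self.loader = loader            # e.g. torch.load in the real pipeline
        self.cache_policy = cache_policy
        self.cache_max_items = cache_max_items
        self._cache = OrderedDict()

    def __getitem__(self, idx):
        if self.cache_policy == "ram_lru" and idx in self._cache:
            self._cache.move_to_end(idx)        # cache hit: mark most-recent
            return self._cache[idx]
        sample = self.loader(self.paths[idx])   # cache miss: load from disk
        if self.cache_policy == "ram_lru" and self.cache_max_items > 0:
            self._cache[idx] = sample
            if len(self._cache) > self.cache_max_items:
                self._cache.popitem(last=False)  # evict least-recently used
        return sample
```

Note that with num_workers > 0 each DataLoader worker process holds its own copy of such a cache, a point the review below also raises.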

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • ChuxiJ

Poem

🐰 Buckets and caches, a nimble little hop,
Decoder stays cozy while gradients flip-flop,
Validators sniff latents, tidy and bright,
LRU keeps memory snug through day and night,
Hooray — training hops forward with a twitch and a pop!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

• Title check: ⚠️ Warning. The title "Changes for training" is too vague and generic; it uses non-descriptive language that does not convey the specific nature or scope of the substantial changes in this PR. Resolution: revise to a more specific title that highlights the main feature(s), such as "Add full fine-tuning MVP and dataset QA for training" or "Implement full-mode training with bucketing and validation".
✅ Passed checks (2 passed)
• Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
• Docstring Coverage: ✅ Passed. Docstring coverage is 86.05%, which meets the required threshold of 80.00%.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
acestep/training/data_module.py (1)

534-557: ⚠️ Potential issue | 🔴 Critical

Missing parameters in __init__ signature cause NameError at runtime.

The parameters length_bucket, cache_policy, and cache_max_items are assigned to self at lines 554-556, but they are not declared in the __init__ signature. This will raise a NameError when instantiating AceStepDataModule.

🐛 Proposed fix to add missing parameters
     def __init__(
         self,
         samples: List[Dict[str, Any]],
         dit_handler,
         batch_size: int = 1,
         num_workers: int = 4,
         pin_memory: bool = True,
         max_duration: float = 240.0,
         val_split: float = 0.0,
+        length_bucket: bool = False,
+        cache_policy: str = "none",
+        cache_max_items: int = 0,
     ):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 534 - 557, The __init__ of
AceStepDataModule assigns self.length_bucket, self.cache_policy, and
self.cache_max_items but those parameters are missing from the signature; add
them to the __init__ signature (e.g. length_bucket: Optional[Sequence[int]] =
None, cache_policy: str = "none", cache_max_items: int = 0) and ensure the
assignments remain (self.length_bucket = length_bucket, self.cache_policy =
cache_policy, self.cache_max_items = cache_max_items) so instantiating
AceStepDataModule no longer raises NameError.
🧹 Nitpick comments (5)
acestep/training/data_module.py (3)

179-188: Cache is per-worker and won't be shared across DataLoader workers.

With num_workers > 0, PyTorch spawns separate processes, each with its own Dataset copy. The LRU cache won't be shared, reducing its effectiveness and increasing total memory usage proportionally to worker count.

Consider documenting this behavior or recommending num_workers=0 when using RAM caching. Alternatively, cache_policy="none" with persistent_workers=True may provide better performance for disk-based loading.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 179 - 188, The LRU RAM cache
implemented by dataset attributes cache_policy, _cache and cache_max_items (in
the data loading block that reads valid_paths and uses torch.load) is
per-process and will be duplicated for each DataLoader worker when
num_workers>0; update documentation and/or dataset behavior: add a note in the
class docstring or surrounding comments explaining that RAM caching is
process-local and recommend using num_workers=0 for ram_lru, or suggest using
cache_policy="none" together with persistent_workers=True for disk-based
loading, and (optionally) add a runtime warning inside the Dataset __init__ or
before loading (checking torch.utils.data.get_worker_info or os.getpid
differences) when cache_policy=="ram_lru" and num_workers>0 to inform users of
the per-worker duplication.

59-61: __len__ may return inaccurate batch count.

The current implementation divides total samples by batch size, but actual batch count depends on bucket distribution. Each bucket's last batch may be smaller, leading to more batches than this estimate. This can cause misleading progress bars or epoch-length mismatches.

Consider counting actual batches:

♻️ Proposed fix for accurate batch count
     def __len__(self) -> int:
-        total = len(self.lengths)
-        return (total + self.batch_size - 1) // self.batch_size
+        buckets: Dict[int, int] = {}
+        for length in self.lengths:
+            bucket = int(length // 64)
+            buckets[bucket] = buckets.get(bucket, 0) + 1
+        return sum((count + self.batch_size - 1) // self.batch_size for count in buckets.values())
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 59 - 61, The __len__ method
currently returns (total + self.batch_size - 1) // self.batch_size over all
samples, which undercounts because each bucket's ragged last batch forms its
own batch; change __len__ (in the class defining __len__) to group
self.lengths into buckets (length // bucket width), count the samples in each
bucket, and return the sum of (count + self.batch_size - 1) // self.batch_size
over those per-bucket counts so the returned value equals the actual number of
batches across all buckets.

554-556: Attributes stored but never used in this deprecated class.

Even after fixing the signature, length_bucket, cache_policy, and cache_max_items are never referenced in setup() or train_dataloader(). Consider either implementing the functionality for consistency with PreprocessedDataModule, or removing these unused attributes since the class is deprecated.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 554 - 556, The deprecated data
module stores unused attributes length_bucket, cache_policy, and
cache_max_items; either remove these attributes and any constructor parameters
that set them, or wire them up so the class behaves like PreprocessedDataModule
by using length_bucket in setup() to group examples into buckets, and applying
cache_policy/cache_max_items in train_dataloader() when building the in-memory
or disk cache for preprocessed samples; update the constructor, setup(), and
train_dataloader() implementations (referencing the same attribute names and
method names) to ensure the attributes are actually consumed or are removed to
avoid dead state.
acestep/training_v2/dataset_validation_test.py (1)

45-49: Add assertions for latent-length statistics to cover new behavior.

Current assertions validate counts, but not min_latent_length, max_latent_length, or avg_latent_length, which are part of the new validator output.

Proposed test additions
         self.assertEqual(report["total_samples"], 3)
         self.assertEqual(report["valid_samples"], 2)
         self.assertEqual(report["invalid_samples"], 1)
         self.assertEqual(report["nan_or_inf_samples"], 1)
+        self.assertEqual(report["min_latent_length"], 1)
+        self.assertEqual(report["max_latent_length"], 8)
+        self.assertAlmostEqual(report["avg_latent_length"], 4.5)
         self.assertGreaterEqual(len(report["errors"]), 1)

As per coding guidelines **/*_test.py: "Add or update tests for every behavior change and bug fix".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training_v2/dataset_validation_test.py` around lines 45 - 49, The
test currently asserts counts in the `report` but misses the new latent-length
stats; update the test that inspects the `report` (the assertions block where
`self.assertEqual(report["total_samples"], 3)` etc.) to also assert
`report["min_latent_length"]`, `report["max_latent_length"]`, and
`report["avg_latent_length"]` with the expected numeric values for this fixture
(compute expected min/max/avg from the three sample latents used in the test) so
the new validator output is covered.
acestep/training_v2/fixed_lora_module_full_mode_test.py (1)

71-74: set comparison does not verify “exactly once” membership.

Using sets here can hide duplicate parameters across groups. Add an explicit duplicate check before the set-equality assertion.

Suggested test hardening
         groups = module.build_full_mode_param_groups()
-        grouped = {id(p) for group in groups for p in group["params"]}
+        grouped_ids = [id(p) for group in groups for p in group["params"]]
+        self.assertEqual(
+            len(grouped_ids),
+            len(set(grouped_ids)),
+            "A trainable parameter appears in more than one optimizer group",
+        )
+        grouped = set(grouped_ids)
         trainable = {id(p) for p in module.model.parameters() if p.requires_grad}
         self.assertSetEqual(grouped, trainable)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training_v2/fixed_lora_module_full_mode_test.py` around lines 71 -
74, The test uses sets (grouped/trainable) which hide duplicate parameter
entries across groups; modify the test around build_full_mode_param_groups() to
first flatten the group["params"] into a list (e.g., groups ->
list_of_params_ids) and assert there are no duplicates by checking
len(list_of_params_ids) == len(set(list_of_params_ids)) (or by asserting
Counter(list_of_params_ids).most_common(1)[0][1] == 1), then proceed to compare
the sets (grouped vs trainable) as before; reference
build_full_mode_param_groups, groups, group["params"], grouped, and trainable
when making the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:


In `@acestep/training_v2/trainer_helpers.py`:
- Around line 367-369: The stray block in resume_checkpoint incorrectly
references output_dir and calls _save_full_decoder_state during resume, causing
a runtime crash; remove this entire conditional block (the if checking
getattr(trainer.training_config, "training_mode", "adapter") == "full" that
calls _save_full_decoder_state and returns) from resume_checkpoint so resume
logic does not perform saves or reference undefined output_dir, or if a mode
check is needed keep it readonly (no save/return) and use existing save paths
elsewhere.

In `@acestep/training/data_module.py`:
- Around line 355-362: The bucketing code fails silently when self.train_dataset
is a torch.utils.data.Subset because hasattr(self.train_dataset,
"latent_lengths") is false for the Subset wrapper; update the condition in the
block that sets BucketedBatchSampler (symbols: length_bucket,
self.train_dataset, latent_lengths, BucketedBatchSampler) to look through the
Subset wrapper and obtain the underlying dataset's latent_lengths (e.g., check
getattr(self.train_dataset, "dataset", self.train_dataset) for latent_lengths)
before deciding to pop batch_size/shuffle and assign batch_sampler, so the
BucketedBatchSampler still receives the correct lengths when val_split > 0.

In `@train.py`:
- Around line 213-230: The current exit logic uses only
report["invalid_samples"] to decide success, allowing empty datasets to return
success; update the final return condition to require report["total_samples"] >
0 in addition to report["invalid_samples"] == 0 (i.e., return success only when
there is at least one sample and no invalid samples). Locate the block that
calls validate_preprocessed_dataset and references report (the print/report
summary block) and change the return expression to check both
report["total_samples"] > 0 and report["invalid_samples"] == 0; optionally add a
short warning/log when total_samples == 0 before returning failure.


ℹ️ Review info

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b4a045b and a3ff595.

📒 Files selected for processing (13)
  • acestep/training/data_module.py
  • acestep/training_v2/cli/args.py
  • acestep/training_v2/cli/config_builder.py
  • acestep/training_v2/configs.py
  • acestep/training_v2/dataset_validation.py
  • acestep/training_v2/dataset_validation_test.py
  • acestep/training_v2/fixed_lora_module.py
  • acestep/training_v2/fixed_lora_module_full_mode_test.py
  • acestep/training_v2/trainer_basic_loop.py
  • acestep/training_v2/trainer_fixed.py
  • acestep/training_v2/trainer_helpers.py
  • docs/lora_full_finetune_pipeline_proposal.md
  • train.py

Comment on lines +25 to +29
def validate_preprocessed_dataset(dataset_dir: str) -> Dict[str, Any]:
    """Validate a preprocessed dataset directory and return summary stats."""
    root = Path(dataset_dir)
    files = sorted(p for p in root.glob("*.pt") if p.name != "manifest.json")


⚠️ Potential issue | 🟠 Major

Missing dataset directory validation can produce false “valid” reports.

At Line 27-29, a non-existent path silently returns zero samples instead of failing, which can mask configuration mistakes.

Proposed fix
 def validate_preprocessed_dataset(dataset_dir: str) -> Dict[str, Any]:
     """Validate a preprocessed dataset directory and return summary stats."""
     root = Path(dataset_dir)
+    if not root.is_dir():
+        raise FileNotFoundError(f"Dataset directory not found: {root}")
     files = sorted(p for p in root.glob("*.pt") if p.name != "manifest.json")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def validate_preprocessed_dataset(dataset_dir: str) -> Dict[str, Any]:
    """Validate a preprocessed dataset directory and return summary stats."""
    root = Path(dataset_dir)
    files = sorted(p for p in root.glob("*.pt") if p.name != "manifest.json")
def validate_preprocessed_dataset(dataset_dir: str) -> Dict[str, Any]:
    """Validate a preprocessed dataset directory and return summary stats."""
    root = Path(dataset_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"Dataset directory not found: {root}")
    files = sorted(p for p in root.glob("*.pt") if p.name != "manifest.json")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training_v2/dataset_validation.py` around lines 25 - 29,
validate_preprocessed_dataset currently treats a non-existent or non-directory
path as valid by returning zero samples; after computing root =
Path(dataset_dir) add an explicit existence and directory check (e.g., if not
root.exists() or not root.is_dir()) and raise a clear exception
(FileNotFoundError or ValueError) so callers get an immediate, descriptive
failure instead of a false "valid" report; reference
validate_preprocessed_dataset, root, and files when making the change.

Comment on lines +17 to +23
    def __init__(self) -> None:
        super().__init__()
        self.decoder = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
        self.encoder = nn.Linear(4, 4)
        self.null_condition_emb = nn.Parameter(torch.zeros(1, 1, 4))
        self.config = type("Cfg", (), {})()


⚠️ Potential issue | 🟡 Minor

Add a docstring to _DummyModel.__init__.

The initializer is new but undocumented; please add a concise purpose/inputs docstring to keep module docs consistent.

As per coding guidelines: “Docstrings are mandatory for all new or modified Python modules, classes, and functions.”

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training_v2/fixed_lora_module_full_mode_test.py` around lines 17 -
23, Add a concise docstring inside _DummyModel.__init__ describing the purpose
of the initializer and the attributes it sets (e.g., creates decoder:
nn.Sequential of two Linear(4,4), encoder: nn.Linear(4,4), null_condition_emb:
nn.Parameter of zeros with shape (1,1,4), and a simple config object), and note
there are no parameters to __init__; place the docstring as the first statement
in the __init__ method so it documents the constructor and its initialized
members.

Comment on lines +122 to +128
        self.latent_lengths: List[int] = []
        for vp in self.valid_paths:
            try:
                sample = torch.load(vp, map_location="cpu", weights_only=True)
                self.latent_lengths.append(int(sample["target_latents"].shape[0]))
            except Exception:
                self.latent_lengths.append(0)

⚠️ Potential issue | 🟡 Minor

Avoid catching broad except Exception; handle specific errors and log failures.

The current code silently catches all exceptions and defaults to length 0, which could mask real issues (corrupt files, permission errors) and cause unexpected bucketing behavior for failed samples.

🛡️ Proposed fix with specific exception handling
         self.latent_lengths: List[int] = []
         for vp in self.valid_paths:
             try:
                 sample = torch.load(vp, map_location="cpu", weights_only=True)
                 self.latent_lengths.append(int(sample["target_latents"].shape[0]))
-            except Exception:
+            except (OSError, KeyError, RuntimeError) as e:
+                logger.warning(f"Failed to read latent length from {vp}: {e}")
                 self.latent_lengths.append(0)
🧰 Tools
🪛 Ruff (0.15.2)

[warning] 127-127: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 122 - 128, The loop that builds
self.latent_lengths in DataModule (using self.valid_paths) currently uses a bare
except which masks errors; change it to catch specific exceptions when loading
torch files (e.g., FileNotFoundError, PermissionError, EOFError, RuntimeError)
and log the failure including the path and exception (use the module/class
logger or python logging), append 0 only for those expected/recoverable errors,
and let other unexpected exceptions propagate (or re-raise) so real issues
aren’t silently ignored; update the try/except around torch.load and the
sample["target_latents"] access accordingly and include the path (vp) and
exception text in the log message.

Comment on lines +355 to +362
if self.length_bucket and hasattr(self.train_dataset, "latent_lengths"):
kwargs.pop("batch_size", None)
kwargs.pop("shuffle", None)
kwargs["batch_sampler"] = BucketedBatchSampler(
lengths=list(getattr(self.train_dataset, "latent_lengths", [])),
batch_size=self.batch_size,
shuffle=True,
)

⚠️ Potential issue | 🟠 Major

Bucketing silently fails when val_split > 0.

When validation split is enabled, self.train_dataset becomes a torch.utils.data.Subset, which lacks the latent_lengths attribute. The hasattr check silently evaluates to False, so the loader falls back to regular batching without any warning.

🐛 Proposed fix to access underlying dataset's latent_lengths
-        if self.length_bucket and hasattr(self.train_dataset, "latent_lengths"):
+        latent_lengths = None
+        if self.length_bucket:
+            ds = self.train_dataset
+            # Handle Subset from random_split
+            if hasattr(ds, "dataset") and hasattr(ds, "indices"):
+                underlying = ds.dataset
+                if hasattr(underlying, "latent_lengths"):
+                    latent_lengths = [underlying.latent_lengths[i] for i in ds.indices]
+            elif hasattr(ds, "latent_lengths"):
+                latent_lengths = ds.latent_lengths
+        if latent_lengths is not None:
             kwargs.pop("batch_size", None)
             kwargs.pop("shuffle", None)
             kwargs["batch_sampler"] = BucketedBatchSampler(
-                lengths=list(getattr(self.train_dataset, "latent_lengths", [])),
+                lengths=latent_lengths,
                 batch_size=self.batch_size,
                 shuffle=True,
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 355 - 362, The bucketing code
fails silently when self.train_dataset is a torch.utils.data.Subset because
hasattr(self.train_dataset, "latent_lengths") is false for the Subset wrapper;
update the condition in the block that sets BucketedBatchSampler (symbols:
length_bucket, self.train_dataset, latent_lengths, BucketedBatchSampler) to look
through the Subset wrapper and obtain the underlying dataset's latent_lengths
(e.g., check getattr(self.train_dataset, "dataset", self.train_dataset) for
latent_lengths) before deciding to pop batch_size/shuffle and assign
batch_sampler, so the BucketedBatchSampler still receives the correct lengths
when val_split > 0.

Comment on lines +213 to +230
report = validate_preprocessed_dataset(args.dataset_dir)
print("\n" + "=" * 60)
print(" Dataset Validation")
print("=" * 60)
print(f" Dataset: {args.dataset_dir}")
print(f" Samples: {report['total_samples']}")
print(f" Valid: {report['valid_samples']}")
print(f" Invalid: {report['invalid_samples']}")
print(f" NaN/Inf: {report['nan_or_inf_samples']}")
print(f" Min latent T: {report['min_latent_length']}")
print(f" Max latent T: {report['max_latent_length']}")
print(f" Avg latent T: {report['avg_latent_length']:.2f}")
if report["errors"]:
print("\n[WARN] Sample errors:")
for err in report["errors"][:10]:
print(f" - {err}")
print("=" * 60)
return 0 if report["invalid_samples"] == 0 else 1

⚠️ Potential issue | 🟠 Major

validate-dataset can return success for empty datasets.

At Line 230, success is based only on invalid_samples == 0; an empty dataset (total_samples == 0) currently exits 0.

Proposed fix
-    return 0 if report["invalid_samples"] == 0 else 1
+    has_samples = report["total_samples"] > 0
+    is_clean = report["invalid_samples"] == 0
+    return 0 if (has_samples and is_clean) else 1
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@train.py` around lines 213 - 230, The current exit logic uses only
report["invalid_samples"] to decide success, allowing empty datasets to return
success; update the final return condition to require report["total_samples"] >
0 in addition to report["invalid_samples"] == 0 (i.e., return success only when
there is at least one sample and no invalid samples). Locate the block that
calls validate_preprocessed_dataset and references report (the print/report
summary block) and change the return expression to check both
report["total_samples"] > 0 and report["invalid_samples"] == 0; optionally add a
short warning/log when total_samples == 0 before returning failure.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (2)
acestep/training/data_module_test.py (1)

196-207: The test is valid, but a coverage gap remains for the new features.

The regression test for legacy AceStepDataModule initialization is appropriate. However, the PR introduces BucketedBatchSampler and RAM LRU caching in PreprocessedTensorDataset without corresponding unit tests.

Consider adding tests for:

  • BucketedBatchSampler: bucket grouping, shuffle behavior, batch size enforcement
  • LRU caching: cache hits/misses, eviction when exceeding cache_max_items
  • latent_lengths computation with valid/invalid tensor files

Do you want me to generate unit tests for BucketedBatchSampler and the caching behavior, or open an issue to track this task?

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module_test.py` around lines 196 - 207, Add unit tests
covering the new BucketedBatchSampler and PreprocessedTensorDataset LRU caching
and latent_lengths behavior: write tests that instantiate BucketedBatchSampler
with a simple length list to verify bucket grouping semantics, that
shuffle=True/False affects order but not batch sizes, and that batches respect
the specified batch_size; add tests for PreprocessedTensorDataset that simulate
loading small tensors to assert cache hit/miss counts and that items are evicted
when exceeding cache_max_items (reference
PreprocessedTensorDataset.cache_max_items and its internal LRU behavior), and
add tests for latent_lengths computation that validate correct lengths for valid
tensor files and that invalid/missing tensor files raise or are handled as
expected (reference latent_lengths). Ensure tests import BucketedBatchSampler
and PreprocessedTensorDataset from the training module and use deterministic
seeds/mocks for file I/O to keep them deterministic.
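For the caching test above, the eviction semantics under review can be exercised independently of the dataset class. The `RamLruCache` below is a hypothetical stand-in, not the actual `PreprocessedTensorDataset` internals (whose cache API this PR does not show); it uses `collections.OrderedDict` for the LRU bookkeeping that `cache_max_items` implies:

```python
from collections import OrderedDict
from typing import Callable


class RamLruCache:
    """Minimal LRU cache sketch: evicts the least-recently-used item
    once the store grows past max_items. Tracks hits/misses so tests
    can assert on caching behavior."""

    def __init__(self, max_items: int):
        self.max_items = max_items
        self._store: "OrderedDict[str, object]" = OrderedDict()
        self.hits = 0
        self.misses = 0

    def get(self, key: str, loader: Callable[[str], object]) -> object:
        if key in self._store:
            self.hits += 1
            self._store.move_to_end(key)  # mark as most recently used
            return self._store[key]
        self.misses += 1
        value = loader(key)  # e.g. torch.load in the real dataset
        self._store[key] = value
        if len(self._store) > self.max_items:
            self._store.popitem(last=False)  # evict least recently used
        return value


cache = RamLruCache(max_items=2)
cache.get("a", str.upper)
cache.get("b", str.upper)
cache.get("a", str.upper)  # hit: "a" becomes most recently used
cache.get("c", str.upper)  # exceeds capacity, evicts "b"
```

A test along these lines gives deterministic hit/miss and eviction assertions without any file I/O; the real dataset test would swap `str.upper` for a mocked tensor loader.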
acestep/training_v2/trainer_helpers.py (1)

307-425: Split resume_checkpoint into mode-specific helpers.

resume_checkpoint now bundles path normalization, mode dispatch, state restoration, and user messaging for full/LoKR/LoRA flows. Please extract per-mode resume helpers plus one shared training_state.pt loader to reduce branching drift and simplify maintenance.

As per coding guidelines "Function decomposition: do one thing at a time; if a function description naturally contains 'and', split it" and "Split functions by responsibility, not by convenience".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training_v2/trainer_helpers.py` around lines 307 - 425, Split
resume_checkpoint into focused helpers: create a shared
_load_training_state(ckpt_dir, module, optimizer, scheduler) that loads
training_state.pt and returns (epoch, global_step, state) or None, then
implement _resume_full_decoder(trainer, ckpt_dir, module, state_loader) to
encapsulate the full-mode logic including calling _load_full_decoder_state and
using _load_training_state, _resume_lokr(trainer, ckpt_dir, module,
state_loader) to handle lokr_weights.safetensors loading and state restore (use
load_lokr_weights and module.lycoris_net), and _resume_lora(trainer, ckpt_dir,
module, optimizer, scheduler) to call load_training_checkpoint and load
adapter_model.safetensors/.bin into module.model.decoder; finally simplify
resume_checkpoint to do path normalization, detect mode and dispatch to these
helpers and yield their TrainingUpdate/return values. Ensure helper names
(_load_training_state, _resume_full_decoder, _resume_lokr, _resume_lora) are
used so callers in resume_checkpoint are easy to find.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@acestep/training_v2/trainer_helpers_test.py`:
- Around line 17-45: Add a new test in trainer_helpers_test.py that exercises
the success path of trainer_helpers.resume_checkpoint: create a
TemporaryDirectory and write a valid full_decoder_state.pt (and optionally
training_state.pt) file(s) with a saved dict containing epoch and global_step,
call resume_checkpoint(trainer, tmpdir, optimizer, scheduler) using the existing
SimpleNamespace trainer/optimizer/scheduler, iterate the returned generator to
completion and assert the final result is (epoch, step), assert
optimizer.load_state_dict and scheduler.load_state_dict were invoked with the
saved state dicts, and assert a non-warning update was emitted; reference
resume_checkpoint, full_decoder_state.pt, training_state.pt,
optimizer.load_state_dict, and scheduler.load_state_dict to locate the code to
test.
- Around line 24-25: The test uses SimpleNamespace with lambdas that accept an
unused parameter (optimizer = SimpleNamespace(load_state_dict=lambda state:
None) and scheduler = SimpleNamespace(load_state_dict=lambda state: None)),
which triggers Ruff ARG005; change the lambda parameter name to be
underscore-prefixed (e.g., lambda _state: None) or use a splat (lambda *_: None)
so the unused argument is explicitly acknowledged and ARG005 is silenced for the
test doubles in trainer_helpers_test.py.

In `@acestep/training/data_module.py`:
- Around line 59-61: The __len__ method currently returns ceil(total_samples /
batch_size) which underestimates when samples are grouped into buckets in
__iter__; change __len__ in data_module.py to sum per-bucket batch counts
instead of using total alone — compute for each bucket: (len(bucket) +
self.batch_size - 1) // self.batch_size and return the sum. If buckets are
stored as self.buckets (lists of indices) iterate over that; if you only have a
per-sample bucket id array (e.g. self.bucket_ids) derive counts per bucket first
and then sum their ceilings; keep the method name __len__ and reference
self.batch_size exactly as in the diff.
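The per-bucket `__len__` fix described above can be illustrated with a minimal, self-contained sampler. This is a sketch, not the project's implementation: the `bucket_width` grouping rule is an assumption for illustration, and only the `lengths`/`batch_size`/`shuffle` constructor arguments come from the diff in this PR.

```python
import random
from typing import Iterator, List


class BucketedBatchSampler:
    """Sketch: group sample indices into length buckets and yield
    per-bucket batches, so batches never mix very different lengths."""

    def __init__(self, lengths: List[int], batch_size: int,
                 shuffle: bool = True, bucket_width: int = 64):
        self.batch_size = batch_size
        self.shuffle = shuffle
        buckets: dict = {}
        for idx, length in enumerate(lengths):
            buckets.setdefault(length // bucket_width, []).append(idx)
        self.buckets = list(buckets.values())

    def __len__(self) -> int:
        # Sum per-bucket batch counts: batches never cross bucket
        # boundaries, so ceil(total / batch_size) would undercount.
        return sum(
            (len(bucket) + self.batch_size - 1) // self.batch_size
            for bucket in self.buckets
        )

    def __iter__(self) -> Iterator[List[int]]:
        batches = []
        for bucket in self.buckets:
            indices = bucket[:]
            if self.shuffle:
                random.shuffle(indices)
            for i in range(0, len(indices), self.batch_size):
                batches.append(indices[i:i + self.batch_size])
        if self.shuffle:
            random.shuffle(batches)
        yield from batches
```

With lengths `[10, 200, 210, 205]` and `batch_size=2`, the short bucket yields one batch and the long bucket two, so `len(sampler)` is 3, while the naive `ceil(4 / 2)` would report 2 — exactly the undercount the review flags.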

---

Nitpick comments:
In `@acestep/training_v2/trainer_helpers.py`:
- Around line 307-425: Split resume_checkpoint into focused helpers: create a
shared _load_training_state(ckpt_dir, module, optimizer, scheduler) that loads
training_state.pt and returns (epoch, global_step, state) or None, then
implement _resume_full_decoder(trainer, ckpt_dir, module, state_loader) to
encapsulate the full-mode logic including calling _load_full_decoder_state and
using _load_training_state, _resume_lokr(trainer, ckpt_dir, module,
state_loader) to handle lokr_weights.safetensors loading and state restore (use
load_lokr_weights and module.lycoris_net), and _resume_lora(trainer, ckpt_dir,
module, optimizer, scheduler) to call load_training_checkpoint and load
adapter_model.safetensors/.bin into module.model.decoder; finally simplify
resume_checkpoint to do path normalization, detect mode and dispatch to these
helpers and yield their TrainingUpdate/return values. Ensure helper names
(_load_training_state, _resume_full_decoder, _resume_lokr, _resume_lora) are
used so callers in resume_checkpoint are easy to find.

In `@acestep/training/data_module_test.py`:
- Around line 196-207: Add unit tests covering the new BucketedBatchSampler and
PreprocessedTensorDataset LRU caching and latent_lengths behavior: write tests
that instantiate BucketedBatchSampler with a simple length list to verify bucket
grouping semantics, that shuffle=True/False affects order but not batch sizes,
and that batches respect the specified batch_size; add tests for
PreprocessedTensorDataset that simulate loading small tensors to assert cache
hit/miss counts and that items are evicted when exceeding cache_max_items
(reference PreprocessedTensorDataset.cache_max_items and its internal LRU
behavior), and add tests for latent_lengths computation that validate correct
lengths for valid tensor files and that invalid/missing tensor files raise or
are handled as expected (reference latent_lengths). Ensure tests import
BucketedBatchSampler and PreprocessedTensorDataset from the training module and
use deterministic seeds/mocks for file I/O to keep them deterministic.

ℹ️ Review info

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a3ff595 and 2b9fd32.

📒 Files selected for processing (4)
  • acestep/training/data_module.py
  • acestep/training/data_module_test.py
  • acestep/training_v2/trainer_helpers.py
  • acestep/training_v2/trainer_helpers_test.py

Comment on lines +17 to +45
def test_full_mode_missing_decoder_state_yields_warning(self) -> None:
"""Missing ``full_decoder_state.pt`` should warn and return ``None``."""
trainer = SimpleNamespace(
module=SimpleNamespace(device=torch.device("cpu"), lycoris_net=None),
training_config=SimpleNamespace(training_mode="full"),
adapter_type="lora",
)
optimizer = SimpleNamespace(load_state_dict=lambda state: None)
scheduler = SimpleNamespace(load_state_dict=lambda state: None)

with tempfile.TemporaryDirectory() as tmpdir:
generator = trainer_helpers.resume_checkpoint(
trainer,
tmpdir,
optimizer,
scheduler,
)
updates = []
try:
while True:
updates.append(next(generator))
except StopIteration as stop:
result = stop.value

self.assertIsNone(result)
self.assertEqual(len(updates), 1)
self.assertEqual(updates[0].kind, "warn")
self.assertIn("full_decoder_state.pt not found", updates[0].msg)


⚠️ Potential issue | 🟠 Major

Add a full-mode success-path resume test.

This file currently tests only the missing-file warning path. Please add a success-path case that creates full_decoder_state.pt (and optionally training_state.pt) and verifies returned (epoch, step) plus optimizer/scheduler restore calls.

As per coding guidelines "Include at least: one success-path test, one regression/edge-case test for the bug being fixed, and one non-target behavior check when relevant".

🧰 Tools
🪛 Ruff (0.15.2)

[warning] 24-24: Unused lambda argument: state

(ARG005)


[warning] 25-25: Unused lambda argument: state

(ARG005)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training_v2/trainer_helpers_test.py` around lines 17 - 45, Add a new
test in trainer_helpers_test.py that exercises the success path of
trainer_helpers.resume_checkpoint: create a TemporaryDirectory and write a valid
full_decoder_state.pt (and optionally training_state.pt) file(s) with a saved
dict containing epoch and global_step, call resume_checkpoint(trainer, tmpdir,
optimizer, scheduler) using the existing SimpleNamespace
trainer/optimizer/scheduler, iterate the returned generator to completion and
assert the final result is (epoch, step), assert optimizer.load_state_dict and
scheduler.load_state_dict were invoked with the saved state dicts, and assert a
non-warning update was emitted; reference resume_checkpoint,
full_decoder_state.pt, training_state.pt, optimizer.load_state_dict, and
scheduler.load_state_dict to locate the code to test.
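Independent of the checkpoint file format, the driving pattern the prompt describes — iterating the resume generator to exhaustion and capturing its return value from `StopIteration.value` — can be sketched in isolation. `resume_like` here is a toy stand-in, not the real `resume_checkpoint`:

```python
from typing import Any, Generator, List, Tuple


def resume_like() -> Generator[dict, None, Tuple[int, int]]:
    """Toy stand-in for resume_checkpoint: yields status updates,
    then returns the resumed (epoch, global_step)."""
    yield {"kind": "info", "msg": "[OK] state loaded"}
    return (3, 1200)


def drain(gen: Generator) -> Tuple[List[Any], Any]:
    """Collect all yielded updates plus the generator's return value,
    which Python delivers via StopIteration.value."""
    updates: List[Any] = []
    try:
        while True:
            updates.append(next(gen))
    except StopIteration as stop:
        return updates, stop.value


updates, result = drain(resume_like())
```

A success-path test would replace `resume_like()` with the real `resume_checkpoint(...)` call against a temp directory, then assert on `result` and on the recorded `load_state_dict` calls.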

Comment on lines +24 to +25
optimizer = SimpleNamespace(load_state_dict=lambda state: None)
scheduler = SimpleNamespace(load_state_dict=lambda state: None)

⚠️ Potential issue | 🟡 Minor

Silence Ruff ARG005 for test doubles.

Use underscore-prefixed lambda arguments for intentionally unused parameters.

Proposed lint-only fix
-        optimizer = SimpleNamespace(load_state_dict=lambda state: None)
-        scheduler = SimpleNamespace(load_state_dict=lambda state: None)
+        optimizer = SimpleNamespace(load_state_dict=lambda _state: None)
+        scheduler = SimpleNamespace(load_state_dict=lambda _state: None)
🧰 Tools
🪛 Ruff (0.15.2)

[warning] 24-24: Unused lambda argument: state

(ARG005)


[warning] 25-25: Unused lambda argument: state

(ARG005)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training_v2/trainer_helpers_test.py` around lines 24 - 25, The test
uses SimpleNamespace with lambdas that accept an unused parameter (optimizer =
SimpleNamespace(load_state_dict=lambda state: None) and scheduler =
SimpleNamespace(load_state_dict=lambda state: None)), which triggers Ruff
ARG005; change the lambda parameter name to be underscore-prefixed (e.g., lambda
_state: None) or use a splat (lambda *_: None) so the unused argument is
explicitly acknowledged and ARG005 is silenced for the test doubles in
trainer_helpers_test.py.

…into codex/fix-high-priority-bugs-in-pull-request-#1-fvat6p
…pull-request-#1-fvat6p

Codex/fix high priority bugs in pull request #1 fvat6p
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
acestep/training/data_module.py (1)

84-90: ⚠️ Potential issue | 🟡 Minor

Update __init__ docstrings for newly added parameters.

Both modified constructors added cache/bucketing args, but their docstrings still document only the previous parameter set.

📝 Proposed docstring updates
     def __init__(self, tensor_dir: str, cache_policy: str = "none", cache_max_items: int = 0):
         """Initialize from a directory of preprocessed .pt files.
         
         Args:
             tensor_dir: Directory containing preprocessed .pt files and manifest.json
+            cache_policy: Cache mode ("none" or "ram_lru")
+            cache_max_items: Maximum number of cached items when using RAM LRU
     def __init__(
         self,
         tensor_dir: str,
@@
     ):
         """Initialize the data module.
@@
         Args:
             tensor_dir: Directory containing preprocessed .pt files
             batch_size: Training batch size
             num_workers: Number of data loading workers
             pin_memory: Whether to pin memory for faster GPU transfer
             val_split: Fraction of data for validation (0 = no validation)
+            length_bucket: Whether to bucket training samples by latent length
+            cache_policy: Dataset cache mode ("none" or "ram_lru")
+            cache_max_items: Maximum number of cached items when using RAM LRU
         """

As per coding guidelines: "Docstrings are mandatory for all new or modified Python modules, classes, and functions" and "Docstrings must be concise and include purpose plus key inputs/outputs (and raised exceptions when relevant)".

Also applies to: 296-304

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 84 - 90, The __init__ docstring
for the constructor (method __init__ in data_module.py) was not updated to
document the newly added parameters; update the docstring for __init__ to
briefly state the purpose of the class, list and describe all parameters
including the new cache_policy (str, allowed values like "none"/"lru"/etc. and
default "none") and cache_max_items (int, default 0 and effect on caching), and
mention any raised exceptions; make the same concise updates to the other
modified constructor docstring referenced around lines 296-304 (identify the
other __init__ there) so both constructors include purpose plus key
inputs/outputs and raised exceptions per project docstring guidelines.
♻️ Duplicate comments (3)
acestep/training_v2/trainer_helpers.py (1)

425-438: ⚠️ Potential issue | 🔴 Critical

Restore LoKR resume success path and remove unreachable full-mode branch.

After Line 425, a successful training_state load has no return path, so _resume_lokr implicitly returns None and _resume_lokr_or_lora falls through to LoRA resume. Also, the full-mode check at Line 431 is unreachable because resume_checkpoint already routes full mode to _resume_full_decoder.

💡 Proposed fix
     training_state = state_loader(ckpt_dir=ckpt_dir)
     if training_state is None:
         yield TrainingUpdate(0, 0.0, "[OK] LoKR weights loaded (no training state)", kind="info")
         return None

-    # Full mode resume: no full decoder checkpoint found.
-    if getattr(trainer.training_config, "training_mode", "adapter") == "full":
-        yield TrainingUpdate(
-            0,
-            0.0,
-            f"[WARN] full_decoder_state.pt not found in {ckpt_dir}",
-            kind="warn",
-        )
-        return
+    epoch, step, _ = training_state
+    yield TrainingUpdate(
+        0,
+        0.0,
+        f"[OK] Resumed LoKR from epoch {epoch}, step {step}",
+        kind="info",
+    )
+    return (epoch, step)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training_v2/trainer_helpers.py` around lines 425 - 438, The success
path after calling state_loader (training_state) currently yields a
TrainingUpdate but then falls through and returns None implicitly, causing
_resume_lokr to not signal a completed LoKR resume and letting
_resume_lokr_or_lora continue to LoRA resume; also the subsequent full-mode
check (getattr(trainer.training_config, "training_mode", "adapter") == "full")
is unreachable because resume_checkpoint routes full mode to
_resume_full_decoder. Fix by making the training_state success path explicitly
return (e.g., return a truthy result or stop further processing) so _resume_lokr
exits after yielding the "[OK] LoKR weights loaded" update, and remove the
unreachable full-mode branch block (or move full-mode handling into the correct
resume_checkpoint flow); refer to state_loader, training_state, _resume_lokr,
_resume_lokr_or_lora, resume_checkpoint, _resume_full_decoder, and
trainer.training_config to locate and update the logic.
acestep/training/data_module.py (2)

131-135: ⚠️ Potential issue | 🟡 Minor

Narrow exception handling when deriving latent_lengths.

except Exception hides corrupt sample/schema issues and silently forces length 0, which can skew bucketing. Catch expected load/key errors, log path + error, and let unexpected exceptions surface.

🔧 Proposed fix
         for vp in self.valid_paths:
             try:
                 sample = torch.load(vp, map_location="cpu", weights_only=True)
                 self.latent_lengths.append(int(sample["target_latents"].shape[0]))
-            except Exception:
+            except (
+                FileNotFoundError,
+                PermissionError,
+                EOFError,
+                OSError,
+                KeyError,
+                RuntimeError,
+            ) as exc:
+                logger.warning(f"Failed to read latent length from {vp}: {exc}")
                 self.latent_lengths.append(0)

Based on learnings: Applies to **/*.py : Handle errors explicitly in Python; avoid bare except.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 131 - 135, The current broad
except around torch.load in the latent length derivation swallows all errors and
appends 0; change it to catch only expected errors (e.g., RuntimeError,
EOFError, KeyError) when loading/parsing the sample returned by torch.load and
on those cases append 0 but also log the file path (vp) and the exception via
the module or class logger (e.g., self.logger or logging.getLogger) so the
failure is visible; for any other unexpected exceptions re-raise them so they
surface instead of being hidden. Ensure this change is applied to the torch.load
+ self.latent_lengths.append(...) block so only known load/key issues are
handled gracefully while other errors propagate.

362-369: ⚠️ Potential issue | 🟠 Major

Bucketing is skipped when train_dataset is a Subset (val_split > 0).

After random_split, self.train_dataset is a torch.utils.data.Subset, so the current hasattr(..., "latent_lengths") check fails and bucketing is not used (with shuffle=False in this mode). Resolve latent lengths from the wrapped dataset + indices.

🐛 Proposed fix
-        if self.length_bucket and hasattr(self.train_dataset, "latent_lengths"):
+        latent_lengths: Optional[List[int]] = None
+        if self.length_bucket:
+            ds = self.train_dataset
+            if (
+                isinstance(ds, torch.utils.data.Subset)
+                and hasattr(ds.dataset, "latent_lengths")
+            ):
+                latent_lengths = [ds.dataset.latent_lengths[i] for i in ds.indices]
+            elif hasattr(ds, "latent_lengths"):
+                latent_lengths = list(getattr(ds, "latent_lengths", []))
+
+        if latent_lengths is not None:
             kwargs.pop("batch_size", None)
             kwargs.pop("shuffle", None)
             kwargs["batch_sampler"] = BucketedBatchSampler(
-                lengths=list(getattr(self.train_dataset, "latent_lengths", [])),
+                lengths=latent_lengths,
                 batch_size=self.batch_size,
                 shuffle=True,
             )
+        elif self.length_bucket:
+            kwargs["shuffle"] = True
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 362 - 369, The bucketing check
skips when train_dataset is a torch.utils.data.Subset because
hasattr(self.train_dataset, "latent_lengths") is false; update the logic in the
DataModule where BucketedBatchSampler is created to detect a Subset
(torch.utils.data.Subset) and extract latent_lengths from the wrapped dataset
using the Subset.indices (e.g., resolve base = self.train_dataset.dataset and
indices = self.train_dataset.indices), build lengths = [getattr(base,
"latent_lengths")[i] for i in indices] (falling back to getattr(..., []) if
missing), then pass that lengths list into BucketedBatchSampler; keep the
existing batch_size/shuffle pop behavior and the BucketedBatchSampler
instantiation name to locate the change.
🧹 Nitpick comments (2)
acestep/training_v2/trainer_helpers.py (1)

137-507: Split checkpoint save/load/resume paths into smaller modules.

This file now combines many responsibilities (save, verify, load, resume orchestration) and exceeds the module hard cap, which makes future changes riskier.

Based on learnings: raise module-size concerns only when a Python file exceeds 200 LOC; this file is above that threshold. As per coding guidelines: **/*.py: Target module size: optimal <= 150 LOC, hard cap 200 LOC.

acestep/training/data_module.py (1)

36-391: This module exceeds the Python hard cap and now mixes too many responsibilities.

data_module.py is well beyond 200 LOC and now contains sampler logic, multiple datasets, and multiple DataModules in one place. Please split this into focused modules to reduce maintenance risk.

As per coding guidelines: **/*.py: "Target module size: optimal <= 150 LOC, hard cap 200 LOC".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 36 - 391, This file is too
large and mixes responsibilities; split it into focused modules: move
BucketedBatchSampler into a new sampler module (e.g., sampler.py) and export the
class there; move PreprocessedTensorDataset and its helper
_resolve_manifest_path into a dataset module (e.g., preprocessed_dataset.py);
move collate_preprocessed_batch into a collate module (e.g., collate.py); keep
PreprocessedDataModule in its own data_module.py that imports
BucketedBatchSampler, PreprocessedTensorDataset, and collate_preprocessed_batch;
update imports in PreprocessedDataModule to reference the new modules and ensure
any shared utilities (safe_path, logger) are imported from their original
locations. Ensure behavior and public APIs (class/function names and signatures:
BucketedBatchSampler, PreprocessedTensorDataset, collate_preprocessed_batch,
PreprocessedDataModule) remain unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@acestep/training_v2/trainer_helpers.py`:
- Around line 472-475: The current code unwraps the decoder by checking only one
_forward_module (decoder = module.model.decoder; if hasattr(decoder,
"_forward_module"): decoder = decoder._forward_module) which is brittle; replace
that logic by calling the existing helper _unwrap_decoder(module.model) to
obtain the fully unwrapped decoder before calling
decoder.load_state_dict(state_dict, strict=False), i.e., locate the block around
decoder.load_state_dict and use _unwrap_decoder(module.model) to get the decoder
so LoRA keys under nested wrappers are correctly loaded.
- Line 352: The type annotation on the function returning Optional[Tuple[int,
int, Dict[str, Any]]] uses typing.Dict which isn't imported; change the
annotation to use the built-in generic syntax dict[str, Any] (i.e.,
Optional[Tuple[int, int, dict[str, Any]]]) in the function signature where that
return type appears (locate the signature around the function named in
trainer_helpers.py that currently returns Optional[Tuple[int, int, Dict[str,
Any]]]); also remove any unused import of Dict from typing if present to avoid
F401.

In `@acestep/training/data_module_test.py`:
- Around line 201-212: Add deterministic tests in
acestep/training/data_module_test.py covering the new bucketing/cache behavior
for AceStepDataModule: create tests that instantiate AceStepDataModule with
length_bucket enabled to verify samples are assigned to expected buckets
(success path), a regression/edge-case that uses ram_lru caching with a tiny
capacity to confirm eviction/split-dataset behavior occurs deterministically,
and a non-target behavior check ensuring legacy initialization (samples=[],
dit_handler=...) still works when these options are present; reference the
AceStepDataModule constructor and its length_bucket and ram_lru configuration
flags/symbols to locate where to set parameters and assert bucket assignments,
eviction counts, and that dit_handler remains set.

---

Outside diff comments:
In `@acestep/training/data_module.py`:
- Around line 84-90: The __init__ docstring for the constructor (method __init__
in data_module.py) was not updated to document the newly added parameters;
update the docstring for __init__ to briefly state the purpose of the class,
list and describe all parameters including the new cache_policy (str, allowed
values like "none"/"lru"/etc. and default "none") and cache_max_items (int,
default 0 and effect on caching), and mention any raised exceptions; make the
same concise updates to the other modified constructor docstring referenced
around lines 296-304 (identify the other __init__ there) so both constructors
include purpose plus key inputs/outputs and raised exceptions per project
docstring guidelines.

---

Duplicate comments:
In `@acestep/training_v2/trainer_helpers.py`:
- Around line 425-438: The success path after calling state_loader
(training_state) currently yields a TrainingUpdate but then falls through and
returns None implicitly, causing _resume_lokr to not signal a completed LoKR
resume and letting _resume_lokr_or_lora continue to LoRA resume; also the
subsequent full-mode check (getattr(trainer.training_config, "training_mode",
"adapter") == "full") is unreachable because resume_checkpoint routes full mode
to _resume_full_decoder. Fix by making the training_state success path
explicitly return (e.g., return a truthy result or stop further processing) so
_resume_lokr exits after yielding the "[OK] LoKR weights loaded" update, and
remove the unreachable full-mode branch block (or move full-mode handling into
the correct resume_checkpoint flow); refer to state_loader, training_state,
_resume_lokr, _resume_lokr_or_lora, resume_checkpoint, _resume_full_decoder, and
trainer.training_config to locate and update the logic.

In `@acestep/training/data_module.py`:
- Around line 131-135: The current broad except around torch.load in the latent
length derivation swallows all errors and appends 0; change it to catch only
expected errors (e.g., RuntimeError, EOFError, KeyError) when loading/parsing
the sample returned by torch.load and on those cases append 0 but also log the
file path (vp) and the exception via the module or class logger (e.g.,
self.logger or logging.getLogger) so the failure is visible; for any other
unexpected exceptions re-raise them so they surface instead of being hidden.
Ensure this change is applied to the torch.load +
self.latent_lengths.append(...) block so only known load/key issues are handled
gracefully while other errors propagate.
- Around line 362-369: The bucketing check skips when train_dataset is a
torch.utils.data.Subset because hasattr(self.train_dataset, "latent_lengths") is
false; update the logic in the DataModule where BucketedBatchSampler is created
to detect a Subset (torch.utils.data.Subset) and extract latent_lengths from the
wrapped dataset using the Subset.indices (e.g., resolve base =
self.train_dataset.dataset and indices = self.train_dataset.indices), build
lengths = [getattr(base, "latent_lengths")[i] for i in indices] (falling back to
getattr(..., []) if missing), then pass that lengths list into
BucketedBatchSampler; keep the existing batch_size/shuffle pop behavior and the
BucketedBatchSampler instantiation name to locate the change.

---

Nitpick comments:
In `@acestep/training/data_module.py`:
- Around line 36-391: This file is too large and mixes responsibilities; split
it into focused modules: move BucketedBatchSampler into a new sampler module
(e.g., sampler.py) and export the class there; move PreprocessedTensorDataset
and its helper _resolve_manifest_path into a dataset module (e.g.,
preprocessed_dataset.py); move collate_preprocessed_batch into a collate module
(e.g., collate.py); keep PreprocessedDataModule in its own data_module.py that
imports BucketedBatchSampler, PreprocessedTensorDataset, and
collate_preprocessed_batch; update imports in PreprocessedDataModule to
reference the new modules and ensure any shared utilities (safe_path, logger)
are imported from their original locations. Ensure behavior and public APIs
(class/function names and signatures: BucketedBatchSampler,
PreprocessedTensorDataset, collate_preprocessed_batch, PreprocessedDataModule)
remain unchanged.

ℹ️ Review info

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2b9fd32 and 2f870f9.

📒 Files selected for processing (3)
  • acestep/training/data_module.py
  • acestep/training/data_module_test.py
  • acestep/training_v2/trainer_helpers.py

module: Any,
optimizer: Any,
scheduler: Any,
) -> Optional[Tuple[int, int, Dict[str, Any]]]:

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Find the file
file_path=$(find . -name "trainer_helpers.py" -path "*/training_v2/*" 2>/dev/null | head -1)

if [ -z "$file_path" ]; then
    echo "File not found"
    exit 1
fi

echo "=== File found at: $file_path ==="
echo ""

# Show line 352 and surrounding context
echo "=== Context around line 352 ==="
sed -n '340,360p' "$file_path" | cat -n

echo ""
echo "=== Checking imports at top of file ==="
head -40 "$file_path" | grep -E "^import|^from" || echo "No imports found in first 40 lines"

echo ""
echo "=== Running verification script ==="
python - <<'PY'
import ast
from pathlib import Path

path = next(Path(".").rglob("trainer_helpers.py"))
tree = ast.parse(path.read_text(encoding="utf-8"))

typing_imports = set()
for node in tree.body:
    if isinstance(node, ast.ImportFrom) and node.module == "typing":
        typing_imports.update(alias.name for alias in node.names)

dict_refs = []
for node in ast.walk(tree):
    if isinstance(node, ast.Name) and node.id == "Dict":
        dict_refs.append((node.lineno, 'Dict' in typing_imports))

print(f"Found {len(dict_refs)} reference(s) to 'Dict'")
print(f"Typing imports: {typing_imports}")
for line_no, is_imported in dict_refs:
    print(f"  Line {line_no}: imported_from_typing={is_imported}")
PY

Repository: ace-step/ACE-Step-1.5

Length of output: 1470


Use the built-in dict generic for the type annotation.

Line 352 references `Dict`, which is not imported from `typing`, triggering Ruff F821. Since the project targets Python 3.11-3.12, use the built-in `dict[str, Any]` syntax instead:

Proposed fix
-) -> Optional[Tuple[int, int, Dict[str, Any]]]:
+) -> Optional[Tuple[int, int, dict[str, Any]]]:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
) -> Optional[Tuple[int, int, Dict[str, Any]]]:
) -> Optional[Tuple[int, int, dict[str, Any]]]:
🧰 Tools
🪛 Ruff (0.15.2)

[error] 352-352: Undefined name Dict

(F821)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training_v2/trainer_helpers.py` at line 352, The type annotation on
the function returning Optional[Tuple[int, int, Dict[str, Any]]] uses
typing.Dict which isn't imported; change the annotation to use the built-in
generic syntax dict[str, Any] (i.e., Optional[Tuple[int, int, dict[str, Any]]])
in the function signature where that return type appears (locate the signature
around the function named in trainer_helpers.py that currently returns
Optional[Tuple[int, int, Dict[str, Any]]]); also remove any unused import of
Dict from typing if present to avoid F401.
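
The fix is mechanical; here is a minimal sketch of the PEP 585 style the reviewer is asking for (the function body is a hypothetical stub, not the real helper in trainer_helpers.py):

```python
from typing import Any, Optional, Tuple

# Python 3.9+ accepts built-in containers as generics (PEP 585), so the
# return annotation below needs no `Dict` import from typing.
def load_training_state(path: str) -> Optional[Tuple[int, int, dict[str, Any]]]:
    # Hypothetical stub mirroring only the signature under review; the real
    # function restores step counters and optimizer/scheduler state.
    if not path:
        return None
    return 0, 0, {"checkpoint": path}
```

`Optional` and `Tuple` could likewise be modernized to `| None` and `tuple[...]`, but only `Dict` is undefined here, so the one-token change is the smallest safe fix.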

Comment on lines +472 to +475
decoder = module.model.decoder
if hasattr(decoder, "_forward_module"):
decoder = decoder._forward_module
decoder.load_state_dict(state_dict, strict=False)

⚠️ Potential issue | 🟠 Major

Load LoRA weights into the fully unwrapped decoder.

At line 472, targeting `module.model.decoder` with only one `_forward_module` unwrap is brittle and can silently skip keys under other wrappers. Use `_unwrap_decoder(module.model)` for consistency with the save/load helpers.

💡 Proposed fix
-    decoder = module.model.decoder
-    if hasattr(decoder, "_forward_module"):
-        decoder = decoder._forward_module
+    decoder = _unwrap_decoder(module.model)
     decoder.load_state_dict(state_dict, strict=False)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training_v2/trainer_helpers.py` around lines 472 - 475, The current
code unwraps the decoder by checking only one _forward_module (decoder =
module.model.decoder; if hasattr(decoder, "_forward_module"): decoder =
decoder._forward_module) which is brittle; replace that logic by calling the
existing helper _unwrap_decoder(module.model) to obtain the fully unwrapped
decoder before calling decoder.load_state_dict(state_dict, strict=False), i.e.,
locate the block around decoder.load_state_dict and use
_unwrap_decoder(module.model) to get the decoder so LoRA keys under nested
wrappers are correctly loaded.
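
The distinction matters because wrappers can nest. A minimal sketch of what a fully recursive unwrap buys over the single `hasattr` check (the `_unwrap_decoder` name comes from the review; the wrapper attribute and object shapes here are stand-ins):

```python
from types import SimpleNamespace

def unwrap_decoder(model):
    """Peel framework wrappers (e.g. Fabric's _forward_module) until the
    innermost decoder is reached, instead of unwrapping exactly once."""
    decoder = model.decoder
    while hasattr(decoder, "_forward_module"):
        decoder = decoder._forward_module
    return decoder

# Two wrapper levels: a single hasattr/unwrap step would stop at `mid`,
# and load_state_dict(strict=False) would then silently match zero keys.
inner = SimpleNamespace()
mid = SimpleNamespace(_forward_module=inner)
model = SimpleNamespace(decoder=SimpleNamespace(_forward_module=mid))
assert unwrap_decoder(model) is inner
```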

Comment on lines +201 to +212
class AceStepDataModuleInitTests(unittest.TestCase):
"""Regression tests for legacy ``AceStepDataModule`` initialization."""

def test_init_does_not_require_preprocessed_only_cache_args(self):
"""Legacy raw-audio datamodule should initialize without NameError."""
from acestep.training.data_module import AceStepDataModule

module = AceStepDataModule(samples=[], dit_handler=object())

self.assertEqual(module.samples, [])
self.assertIsNotNone(module.dit_handler)


⚠️ Potential issue | 🟠 Major

Coverage is missing for the new bucketing/cache behaviors introduced by this PR.

This addition only validates legacy initialization; it does not test the new length_bucket and ram_lru paths (including split-dataset behavior). Please add deterministic tests for those new code paths.

As per coding guidelines: **/*_test.py: "Add or update tests for every behavior change and bug fix" and "Include at least: one success-path test, one regression/edge-case test for the bug being fixed, and one non-target behavior check when relevant".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module_test.py` around lines 201 - 212, Add
deterministic tests in acestep/training/data_module_test.py covering the new
bucketing/cache behavior for AceStepDataModule: create tests that instantiate
AceStepDataModule with length_bucket enabled to verify samples are assigned to
expected buckets (success path), a regression/edge-case that uses ram_lru
caching with a tiny capacity to confirm eviction/split-dataset behavior occurs
deterministically, and a non-target behavior check ensuring legacy
initialization (samples=[], dit_handler=...) still works when these options are
present; reference the AceStepDataModule constructor and its length_bucket and
ram_lru configuration flags/symbols to locate where to set parameters and assert
bucket assignments, eviction counts, and that dit_handler remains set.
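
For the `ram_lru` eviction test this comment asks for, the behavior under test is plain LRU bookkeeping, which can be pinned down deterministically. A torch-free sketch of the cache semantics (class and attribute names are illustrative, not the PR's actual API):

```python
from collections import OrderedDict

class RamLruCache:
    """Tiny LRU keyed by sample index, sketching the ram_lru idea."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.evictions = 0
        self._store: OrderedDict = OrderedDict()

    def get(self, key, loader):
        if key in self._store:
            self._store.move_to_end(key)      # mark as most recently used
            return self._store[key]
        value = loader(key)
        self._store[key] = value
        if len(self._store) > self.capacity:
            self._store.popitem(last=False)   # evict least recently used
            self.evictions += 1
        return value

cache = RamLruCache(capacity=2)
for idx in [0, 1, 0, 2, 3]:
    cache.get(idx, loader=lambda k: f"sample-{k}")
# 4 distinct keys through a capacity-2 cache -> exactly 2 evictions,
# and the survivors are the two most recently used keys.
assert cache.evictions == 2
assert list(cache._store) == [2, 3]
```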

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

♻️ Duplicate comments (2)
acestep/training/data_module.py (2)

129-135: ⚠️ Potential issue | 🟠 Major

Avoid bare except Exception in latent-length loading.

This swallows unexpected failures and silently maps them to length 0, which can mask real data corruption or schema issues.

🔧 Proposed fix
         self.latent_lengths: List[int] = []
         for vp in self.valid_paths:
             try:
                 sample = torch.load(vp, map_location="cpu", weights_only=True)
                 self.latent_lengths.append(int(sample["target_latents"].shape[0]))
-            except Exception:
+            except (FileNotFoundError, PermissionError, EOFError, OSError, KeyError, RuntimeError) as exc:
+                logger.warning(f"Failed to read latent length from {vp}: {exc}")
                 self.latent_lengths.append(0)
#!/bin/bash
rg -nP '^\s*except\s+Exception\s*:' --type=py

Based on learnings: Applies to **/*.py : Handle errors explicitly in Python; avoid bare except.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 129 - 135, The bare "except
Exception" in the loop that loads samples for self.valid_paths (where you call
torch.load and access sample["target_latents"] to populate self.latent_lengths)
should be replaced with explicit error handling: catch only the expected
exceptions (e.g., FileNotFoundError, RuntimeError from torch.load, KeyError when
"target_latents" is missing) and handle each case (log the error with context
including the path and exception, append 0 only for recoverable/known issues),
while letting unexpected exceptions propagate (or re-raise) so real bugs or
schema changes are not silently swallowed.
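
A self-contained sketch of the narrowed handler (the loader and tensor are stubbed out; the real code calls `torch.load`, and the project's loguru logger would take brace-style placeholders rather than the stdlib `%s` style shown here):

```python
import logging

logger = logging.getLogger(__name__)

class _FakeTensor:
    """Stand-in for a tensor exposing .shape."""
    def __init__(self, n):
        self.shape = (n,)

def read_latent_lengths(paths, loader):
    """Collect latent lengths, tolerating only known load/schema failures."""
    lengths = []
    for vp in paths:
        try:
            sample = loader(vp)
            lengths.append(int(sample["target_latents"].shape[0]))
        except (FileNotFoundError, EOFError, KeyError, RuntimeError) as exc:
            # Known, recoverable issue: log with context, fall back to 0.
            logger.warning("Failed to read latent length from %s: %s", vp, exc)
            lengths.append(0)
        # Anything else (e.g. a TypeError from a schema change) propagates.
    return lengths

def fake_loader(path):
    if path == "bad.pt":
        raise RuntimeError("corrupt file")  # torch.load's usual failure type
    return {"target_latents": _FakeTensor(128)}

assert read_latent_lengths(["ok.pt", "bad.pt"], fake_loader) == [128, 0]
```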

375-392: ⚠️ Potential issue | 🟠 Major

Bucketing still bypasses Subset and the new helper is unused.

_resolve_train_latent_lengths() was added, but train_dataloader() still checks hasattr(self.train_dataset, "latent_lengths"). With random_split, this skips bucketing and may also leave training unshuffled when length_bucket=True.

🔧 Proposed fix
         kwargs = dict(
             dataset=self.train_dataset,
             batch_size=self.batch_size,
-            shuffle=not self.length_bucket,
+            shuffle=True,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             collate_fn=collate_preprocessed_batch,
             drop_last=False,
             prefetch_factor=prefetch_factor,
             persistent_workers=persistent_workers,
         )
         if self.pin_memory_device:
             kwargs["pin_memory_device"] = self.pin_memory_device
-        if self.length_bucket and hasattr(self.train_dataset, "latent_lengths"):
+        latent_lengths = self._resolve_train_latent_lengths()
+        if latent_lengths is not None:
             kwargs.pop("batch_size", None)
             kwargs.pop("shuffle", None)
             kwargs["batch_sampler"] = BucketedBatchSampler(
-                lengths=list(getattr(self.train_dataset, "latent_lengths", [])),
+                lengths=latent_lengths,
                 batch_size=self.batch_size,
                 shuffle=True,
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 375 - 392, train_dataloader
currently checks hasattr(self.train_dataset, "latent_lengths") and thus skips
bucketing when train_dataset is a Subset (e.g., from random_split) and the new
helper _resolve_train_latent_lengths() is unused; update train_dataloader to
call _resolve_train_latent_lengths() (or otherwise obtain latent lengths via
that helper) and use its result to decide whether to set batch_sampler to
BucketedBatchSampler and remove batch_size/shuffle, ensuring Subset-wrapped
datasets are handled and shuffling is preserved when length_bucket=True; refer
to symbols train_dataloader, _resolve_train_latent_lengths, self.train_dataset,
length_bucket, BucketedBatchSampler, and Subset when making the change.
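
The suggested resolution logic is small enough to sketch in full. This duck-typed version mirrors `torch.utils.data.Subset`'s `dataset`/`indices` attributes without importing torch (the function name comes from the review; everything else is illustrative):

```python
from types import SimpleNamespace

def resolve_train_latent_lengths(dataset):
    """Return latent lengths for `dataset`, unwrapping Subset-style wrappers."""
    if hasattr(dataset, "latent_lengths"):
        return list(dataset.latent_lengths)
    # random_split produces Subset objects exposing .dataset and .indices.
    if hasattr(dataset, "dataset") and hasattr(dataset, "indices"):
        base = getattr(dataset.dataset, "latent_lengths", [])
        return [base[i] for i in dataset.indices]
    return None  # caller falls back to the plain shuffled DataLoader

base = SimpleNamespace(latent_lengths=[10, 20, 30, 40])
subset = SimpleNamespace(dataset=base, indices=[3, 1])
assert resolve_train_latent_lengths(base) == [10, 20, 30, 40]
assert resolve_train_latent_lengths(subset) == [40, 20]   # reindexed via Subset
assert resolve_train_latent_lengths(object()) is None
```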
🧹 Nitpick comments (9)
docs/lora_full_finetune_pipeline_proposal.md (7)

34-40: Document bucketing's impact on training dynamics.

While latent-length bucketing improves throughput by reducing padding waste, it can affect training dynamics if samples are not properly shuffled within and across buckets. The proposal mentions "deterministic shuffling by bucket + epoch seed," which is good.

Consider adding a note about:

  • Validating that bucketing doesn't introduce unintended correlations in training order
  • Including bucket-aware shuffle validation in the QA command (A6)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/lora_full_finetune_pipeline_proposal.md` around lines 34 - 40, Add a
short note under A2 (Add duration/latent-length bucketing in dataloader) that
describes potential training-dynamics risks from bucketing (e.g., unintended
correlations from grouping by length) and prescribe validation steps: (1) run
epoch-level bucket-aware shuffle checks to ensure samples are randomly
interleaved across epochs and within buckets, (2) compare training metrics
(loss/accuracy/validation curves) with and without bucketing to detect shifts,
and (3) include a bullet in QA command A6 to execute these bucket-aware shuffle
validations and metric comparisons as part of the QA checklist.
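
A deterministic bucket-aware shuffle is easy to property-test: the same epoch and seed must reproduce the same batch order, every sample must appear exactly once, and no batch may mix buckets. A sketch of such a check (the bucket width and seeding scheme are assumptions, not the PR's actual parameters):

```python
import random

def bucketed_epoch_batches(lengths, batch_size, epoch, seed=0, bucket_width=64):
    """Shuffle within buckets and across batches, re-seeded per epoch."""
    rng = random.Random(seed + epoch)
    buckets = {}
    for idx, n in enumerate(lengths):
        buckets.setdefault(n // bucket_width, []).append(idx)
    batches = []
    for items in buckets.values():
        items = items[:]
        rng.shuffle(items)                     # shuffle within each bucket
        for i in range(0, len(items), batch_size):
            batches.append(items[i:i + batch_size])
    rng.shuffle(batches)                       # shuffle bucket order too
    return batches

lengths = [10, 70, 12, 75, 11, 72]
a = bucketed_epoch_batches(lengths, 2, epoch=0)
b = bucketed_epoch_batches(lengths, 2, epoch=0)
assert a == b                                  # reproducible per epoch+seed
assert sorted(i for batch in a for i in batch) == [0, 1, 2, 3, 4, 5]
for batch in a:                                # no batch mixes buckets
    assert len({lengths[i] // 64 for i in batch}) == 1
```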

133-144: Consider clarifying measurement methodology for acceptance criteria.

The acceptance criteria are well-defined with specific targets, but some measurement methodologies could be clarified for reproducibility.

Consider specifying:

  • How "data-loader stall time" will be measured (profiling tool, specific metrics)
  • What constitutes "supported GPU setup" for full FT testing (minimum VRAM, specific models)
  • How to validate safety checks (e.g., test cases with known OOM-prone configs)

This would make the criteria more actionable during validation.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/lora_full_finetune_pipeline_proposal.md` around lines 133 - 144, Update
the Acceptance criteria paragraph to specify exact measurement methods: define
"data-loader stall time" measurement using a profiler (e.g., PyTorch profiler or
perf metrics) and the exact metric (percent time waiting vs total batch time)
and sampling window; define "step time" measurement method (median of N runs
with fixed seed and dataset shard, include warmup steps and measurement steps);
enumerate what "supported GPU setup" means (minimum VRAM, example device models
such as NVIDIA A100/RTX3090, CUDA/cuDNN versions) and the test protocol for
"Successfully trains decoder-only full FT" (one epoch, fixed batch size, seed,
and dataset subset); and add explicit validation for "CLI safety checks" by
listing specific OOM-prone config test cases and expected CLI responses
(warnings/failures) so acceptance tests are reproducible.

20-32: Consider adding migration tooling to the implementation plan.

While backward compatibility and auto-detection are mentioned, the proposal doesn't explicitly address tooling to migrate existing datasets from the per-sample .pt format to the new v3 sharded format. Users with large existing datasets will need a conversion utility.

Consider adding to the implementation plan:

  • A dataset migration command (e.g., convert-dataset --input-format pt --output-format v3_sharded)
  • Documentation on when migration is worthwhile vs. keeping existing format
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/lora_full_finetune_pipeline_proposal.md` around lines 20 - 32, Add
migration tooling and docs for converting per-sample `.pt` datasets to the new
v3 sharded format: implement a CLI command (e.g., convert-dataset or
convert_dataset) that accepts flags matching the loader options
(--dataset-format {pt,v3_sharded}, input/output paths) to produce
`shard-xxxxx.safetensors` and the corresponding `index.jsonl`, include
progress/error reporting and verification (checksums/sample counts), and update
the proposal and docs with guidance on when to migrate vs keep the `.pt` format
and examples of using the migration command.

99-101: Specify VRAM estimation methodology for accuracy.

The preflight VRAM estimator is a critical safety feature, but the proposal doesn't specify the estimation methodology. Inaccurate estimates could lead to user distrust or unexpected OOM failures.

Consider documenting:

  • Whether estimation will be empirical (formula-based) or profile-based (dry-run with dummy tensors)
  • How to account for framework overhead, optimizer states, and gradient accumulation
  • Acceptable margin of error and how to communicate uncertainty to users
  • Fallback behavior when estimation confidence is low
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/lora_full_finetune_pipeline_proposal.md` around lines 99 - 101, Update
the proposal to specify the VRAM estimation methodology for the "preflight VRAM
estimator" referenced alongside the opt-in flags (`--training-mode full
--i-understand-vram-risk`) and the "Auto-suggest fallback to LoRA" behavior:
state whether the estimator is formula-based (analytical) or profile-based
(dry-run with dummy tensors), describe how you will account for
framework/runtime overhead, optimizer/AMP/optimizer state sizes and gradient
accumulation, define an acceptable margin-of-error/confidence interval and how
that uncertainty is surfaced to users, and specify explicit fallback behavior
and messaging when confidence is low (e.g., suggest LoRA, require extra
confirmation, or perform an on-demand dry-run).
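
For the formula-based option, even a back-of-the-envelope estimator exposes the variables the comment asks the proposal to document. The constants below (bf16 weights/grads, fp32 Adam states, a 1.5x activation/runtime fudge factor) are illustrative assumptions, not measured values for this model:

```python
def estimate_training_vram_gb(
    n_params: int,
    dtype_bytes: int = 2,             # bf16 weights and grads (assumed)
    optimizer_state_copies: int = 2,  # Adam: exp_avg + exp_avg_sq
    overhead_factor: float = 1.5,     # activations, fragmentation, framework
) -> float:
    """Analytical VRAM estimate: weights + grads + optimizer states,
    scaled by an overhead fudge factor. The margin of error is large,
    which is exactly why the proposal should surface it to users."""
    weights = n_params * dtype_bytes
    grads = n_params * dtype_bytes
    optimizer = n_params * 4 * optimizer_state_copies  # states kept in fp32
    return (weights + grads + optimizer) * overhead_factor / 1024**3

# ~1B trainable decoder params -> roughly 17 GB under these assumptions.
estimate = estimate_training_vram_gb(1_000_000_000)
assert 16.0 < estimate < 17.5
```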

147-151: Minor style consideration: repetitive sentence structure.

The static analysis tool noted that lines 148-151 all begin with "Add," which creates repetitive sentence structure. While this is clear and easy to scan, you could optionally vary the structure for better readability.

Example variations:

  • "Implement DatasetBackend abstraction..."
  • "Introduce BucketedBatchSampler..."
  • "Extend train.py with validate-dataset command..."

However, the current structure is perfectly functional for a technical proposal, so this is purely a stylistic consideration.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/lora_full_finetune_pipeline_proposal.md` around lines 147 - 151, The
four bullet items in section "Immediate next patch candidates" use the same
leading verb "Add", creating repetitive sentence structure; update each line to
vary the verb while preserving meaning and the referenced symbols (e.g., change
"Add `DatasetBackend` abstraction and a sharded backend implementation." to
"Implement `DatasetBackend` abstraction with a sharded backend implementation.",
change "Add `BucketedBatchSampler` keyed by latent length." to "Introduce
`BucketedBatchSampler` keyed by latent length.", change "Add `train.py
validate-dataset --dataset-dir ...` command." to "Extend `train.py` with a
`validate-dataset --dataset-dir ...` command.", and change "Add `training_mode`
enum + `full` branch in config/trainer with decoder-only unfreeze." to "Add a
`training_mode` enum and a `full` branch in config/trainer that enables
decoder-only unfreeze."); keep the same symbols (`DatasetBackend`,
`BucketedBatchSampler`, `train.py validate-dataset`, `training_mode` and `full`)
so reviewers can locate the changes.

109-130: Consider adding testing strategy to implementation phases.

The phased implementation plan is well-structured, but doesn't explicitly address testing strategy for each phase. Given the complexity of the changes, especially around full fine-tuning, a testing plan would strengthen the proposal.

Consider adding for each phase:

  • Unit test coverage expectations (e.g., new samplers, cache policies, parameter grouping)
  • Integration test scenarios (e.g., end-to-end training with new features)
  • Backward compatibility validation (especially for Phase 1-2 data format changes)
  • Performance regression tests (to validate throughput improvements)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/lora_full_finetune_pipeline_proposal.md` around lines 109 - 130, Add a
testing strategy tied to each phase: for Phase 1, add unit tests for the packed
dataset writer/loader, bucketed sampler, and preprocessing QA validator (e.g.,
PackedDatasetWriterTests, BucketedSamplerTests, QAValidatorTests) plus
backward-compatibility tests for old vs. new data formats; for Phase 2, add
unit/integration tests for mmap/LRU cache policy and optional device prefetch
(CachePolicyTests, DevicePrefetchIntegration) and performance regression
benchmarks that validate throughput improvements and record baseline metrics by
GPU class; for Phase 3, add unit tests for training_mode=full behavior,
parameter group LR multipliers, and full-state checkpoint/resume
(TrainingModeFullTests, ParamGroupTests, CheckpointResumeTests) as well as
end-to-end training integration tests; for Phase 4, add tests for
distributed/sharded optimizer options, staged unfreeze profiles, and
eval/early-stop controls (DistributedOptimizerTests, StagedUnfreezeTests,
EarlyStopIntegration) and ensure all phases are wired into CI with clear
pass/fail criteria, recorded performance thresholds, and reproducible test data
fixtures.

93-96: Consider checkpoint format versioning.

As the full fine-tuning feature evolves across phases, checkpoint format may change. Consider adding explicit version metadata to checkpoints to ensure graceful handling of incompatible formats during resume.

This could prevent cryptic errors when users try to resume from checkpoints created with older code versions.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@docs/lora_full_finetune_pipeline_proposal.md` around lines 93 - 96, Update
the checkpointing metadata to include an explicit checkpoint format version
field (e.g., "checkpoint_version" using semantic versioning) whenever saving the
full model state, sharded states, optimizer, scheduler, scaler, or EMA
checkpoints; on resume (load) validate this metadata in the resume logic and
emit a clear, actionable error if the version is incompatible (or route to
compatibility/upgrade handlers), and ensure any code paths that change the
on-disk layout bump the checkpoint_version and add a loader/upgrade path for
older versions to avoid cryptic failures.
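
The versioning idea fits in a few lines at save/load time. A sketch with a hypothetical `checkpoint_version` key and a major-version compatibility gate (field name and policy are suggestions, not existing code):

```python
def save_with_version(state: dict, version: str = "1.0") -> dict:
    """Embed a format version alongside the checkpoint payload."""
    return {"checkpoint_version": version, **state}

def load_with_version(payload: dict, supported_major: int = 1) -> dict:
    version = payload.get("checkpoint_version")
    if version is None:
        raise ValueError(
            "Checkpoint predates versioning; re-save it with current code."
        )
    if int(version.split(".")[0]) != supported_major:
        raise ValueError(
            f"Incompatible checkpoint_version {version}; "
            f"this build supports major version {supported_major}."
        )
    return {k: v for k, v in payload.items() if k != "checkpoint_version"}

ckpt = save_with_version({"step": 100, "decoder_state": {}})
assert load_with_version(ckpt) == {"step": 100, "decoder_state": {}}
```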
acestep/training/preprocessed_collate.py (1)

10-61: Deduplicate collate logic to prevent drift.

This implementation is effectively duplicated in acestep/training/data_module.py (Lines 206–271 from provided context). Keep a single canonical function and import it where needed to avoid divergence during future edits.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/preprocessed_collate.py` around lines 10 - 61, The collate
logic is duplicated; keep collate_preprocessed_batch as the single canonical
function and remove the copy in data_module.py: delete the duplicate function
body there (the one that pads target_latents, attention_mask, context_latents,
encoder_hidden_states, and encoder_attention_mask) and replace it with an import
of collate_preprocessed_batch from this module, updating any reference in
DataModule (or functions that previously called the duplicate) to call
collate_preprocessed_batch; run a quick search for other copies and update
imports/usages so all code uses the single collate_preprocessed_batch
implementation to avoid drift.
acestep/training/data_module.py (1)

36-68: De-duplicate sampler/dataset definitions to a single source of truth.

BucketedBatchSampler and PreprocessedTensorDataset now exist both here and in acestep/training/preprocessed_sampler.py + acestep/training/preprocessed_dataset.py, and behavior has already drifted. Import shared implementations instead of maintaining parallel copies.

Based on learnings: only raise module-size concerns when a Python file exceeds 200 LOC; this file exceeds that threshold. As per coding guidelines: **/*.py hard cap is 200 LOC and functions/modules should be split by responsibility.

Also applies to: 71-205

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 36 - 68, This file duplicates
BucketedBatchSampler and PreprocessedTensorDataset which have drifted from the
canonical implementations; remove the local classes and import
BucketedBatchSampler from the shared preprocessed_sampler module and
PreprocessedTensorDataset from the shared preprocessed_dataset module (update
all usages to the imported symbols), then split any remaining functionality in
this file so the module stays under the 200‑line guideline by moving unrelated
helpers/classes into separate modules; ensure tests/imports still reference the
imported names and run linters to catch any unresolved references.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@acestep/training/preprocessed_collate.py`:
- Around line 5-10: The return annotation Dict[str, torch.Tensor] is incorrect
because collate_preprocessed_batch returns a mix of tensors and a Python list
for "metadata"; define a TypedDict (e.g., PreprocessedBatch: TypedDict) that
lists each returned key and its precise type (for example tensor fields as
torch.Tensor and "metadata" as List[Dict[str, Any]]), import TypedDict and Any
from typing, and update the return type of collate_preprocessed_batch (in this
module) and the matching function in data_module.py to -> PreprocessedBatch so
static typing matches the actual payload.
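
A torch-free sketch of the `TypedDict` shape (the field set is trimmed to three keys, `Any` stands in for `torch.Tensor` so the example stays self-contained, and the mini-collate is hypothetical):

```python
from typing import Any, TypedDict

class PreprocessedBatch(TypedDict):
    target_latents: Any             # torch.Tensor in the real code
    attention_mask: Any             # torch.Tensor in the real code
    metadata: list[dict[str, Any]]  # the non-tensor field that broke Dict[str, Tensor]

def collate(samples: list[dict[str, Any]]) -> PreprocessedBatch:
    # The real collate pads and stacks tensors; this only shows the typing.
    return PreprocessedBatch(
        target_latents=[s["target_latents"] for s in samples],
        attention_mask=[s["attention_mask"] for s in samples],
        metadata=[s["metadata"] for s in samples],
    )

batch = collate([{"target_latents": 1, "attention_mask": 1, "metadata": {"id": "a"}}])
assert batch["metadata"] == [{"id": "a"}]
```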

In `@acestep/training/preprocessed_dataset.py`:
- Around line 96-115: The return type of __getitem__ is too narrow: it returns
tensors for most keys but "metadata" is a dict (or could be other non-tensor
values), so update the type hint for __getitem__ from Dict[str, torch.Tensor] to
a broader type such as Dict[str, Any] (and add Any to the typing imports) so the
signature reflects that values can be non-tensor; ensure the change is applied
to the __getitem__ definition and any related annotations in the same module
that reference this method.
- Around line 57-71: The logger calls in PreprocessedTensorDataset use
percent-style formatting which Loguru ignores; update all logger.warning and
logger.info usages in this region (the calls inside the constructor/initializer
that reference logger.warning("Some tensor files not found: %d missing", ...),
logger.warning("Failed to read latent length from %s: %s", vp, exc), and
logger.info("PreprocessedTensorDataset: %d samples from %s", ...)) to use
brace-style placeholders (e.g., "Some tensor files not found: {} missing") and
pass the values as positional args so Loguru will interpolate them; keep the
exception handling (in the try/except around torch.load and appending to
self.latent_lengths) unchanged.
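
The mismatch is easy to demonstrate without importing loguru: loguru interpolates messages with `str.format`, so `%s` placeholders pass through to the log line verbatim while `{}` placeholders interpolate:

```python
msg_percent = "Failed to read latent length from %s: %s"
msg_braces = "Failed to read latent length from {}: {}"

# str.format (what loguru applies to its args) ignores printf-style tokens...
assert msg_percent.format("a.pt", "EOFError") == msg_percent
# ...but fills brace placeholders as intended.
assert msg_braces.format("a.pt", "EOFError") == (
    "Failed to read latent length from a.pt: EOFError"
)
```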

In `@acestep/training/preprocessed_sampler.py`:
- Around line 42-45: The current __len__ returns ceil(total_items / batch_size)
which undercounts when __iter__ yields batches per bucket; update __len__ to sum
over buckets instead: for each bucket in self.lengths compute (len(bucket) +
self.batch_size - 1) // self.batch_size and return the total, ensuring you use
the same bucket structure that __iter__ uses (reference __len__, __iter__,
self.lengths, self.batch_size).
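
The undercount is easy to see with ragged buckets, since each bucket flushes its own final partial batch. A sketch comparing the buggy global ceiling against the per-bucket sum the comment prescribes:

```python
import math

def len_global(total_items, batch_size):
    # Buggy __len__: one global ceiling over all items.
    return math.ceil(total_items / batch_size)

def len_per_bucket(buckets, batch_size):
    # Correct __len__: each bucket yields its own ceil(len/batch_size)
    # batches, matching what a per-bucket __iter__ actually produces.
    return sum((len(b) + batch_size - 1) // batch_size for b in buckets)

buckets = [list(range(5)), list(range(5)), list(range(5))]  # 15 items total
assert len_global(15, 4) == 4            # reported by the buggy __len__
assert len_per_bucket(buckets, 4) == 6   # batches actually yielded
```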

---

Duplicate comments:
In `@acestep/training/data_module.py`:
- Around line 129-135: The bare "except Exception" in the loop that loads
samples for self.valid_paths (where you call torch.load and access
sample["target_latents"] to populate self.latent_lengths) should be replaced
with explicit error handling: catch only the expected exceptions (e.g.,
FileNotFoundError, RuntimeError from torch.load, KeyError when "target_latents"
is missing) and handle each case (log the error with context including the path
and exception, append 0 only for recoverable/known issues), while letting
unexpected exceptions propagate (or re-raise) so real bugs or schema changes are
not silently swallowed.
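The narrowed exception handling could look like this sketch; `load_sample` stands in for `torch.load` so the example is self-contained, the `print` calls are placeholders for logger calls, and the real code reads `sample["target_latents"].shape[0]` rather than `len(...)`:

```python
from typing import Callable, Dict, List


def scan_latent_lengths(paths: List[str], load_sample: Callable[[str], Dict]) -> List[int]:
    """Collect latent lengths, handling only the expected failure modes.

    Any exception type not listed here propagates, so real bugs or schema
    changes are not silently swallowed.
    """
    lengths: List[int] = []
    for path in paths:
        try:
            sample = load_sample(path)
            lengths.append(len(sample["target_latents"]))
        except FileNotFoundError:
            print(f"missing file, recording length 0: {path}")
            lengths.append(0)
        except KeyError as exc:
            print(f"sample at {path} lacks key {exc}; recording length 0")
            lengths.append(0)
        except RuntimeError as exc:  # torch.load raises RuntimeError on corrupt files
            print(f"unreadable tensor file {path}: {exc}")
            lengths.append(0)
    return lengths
```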
- Around line 375-392: train_dataloader currently checks
hasattr(self.train_dataset, "latent_lengths") and thus skips bucketing when
train_dataset is a Subset (e.g., from random_split) and the new helper
_resolve_train_latent_lengths() is unused; update train_dataloader to call
_resolve_train_latent_lengths() (or otherwise obtain latent lengths via that
helper) and use its result to decide whether to set batch_sampler to
BucketedBatchSampler and remove batch_size/shuffle, ensuring Subset-wrapped
datasets are handled and shuffling is preserved when length_bucket=True; refer
to symbols train_dataloader, _resolve_train_latent_lengths, self.train_dataset,
length_bucket, BucketedBatchSampler, and Subset when making the change.
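Resolving lengths through `Subset` wrappers might look like this sketch; `SubsetStandIn` is a minimal stand-in mirroring `torch.utils.data.Subset`'s `dataset`/`indices` attributes so the example runs without torch, and `resolve_latent_lengths` is an illustrative name for the helper's logic:

```python
from typing import List, Optional, Sequence


class SubsetStandIn:
    """Minimal stand-in for torch.utils.data.Subset: subset[i] == dataset[indices[i]]."""

    def __init__(self, dataset, indices: Sequence[int]):
        self.dataset = dataset
        self.indices = list(indices)


def resolve_latent_lengths(dataset) -> Optional[List[int]]:
    """Walk through Subset wrappers to a dataset exposing latent_lengths."""
    indices: Optional[List[int]] = None
    current = dataset
    while isinstance(current, SubsetStandIn):
        # Compose index maps so nested Subsets resolve to base-dataset indices.
        if indices is None:
            indices = list(current.indices)
        else:
            indices = [current.indices[i] for i in indices]
        current = current.dataset
    base = getattr(current, "latent_lengths", None)
    if base is None:
        return None  # caller should then skip bucketing
    if indices is None:
        return list(base)
    return [base[i] for i in indices]
```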

---

Nitpick comments:
In `@acestep/training/data_module.py`:
- Around line 36-68: This file duplicates BucketedBatchSampler and
PreprocessedTensorDataset which have drifted from the canonical implementations;
remove the local classes and import BucketedBatchSampler from the shared
preprocessed_sampler module and PreprocessedTensorDataset from the shared
preprocessed_dataset module (update all usages to the imported symbols), then
split any remaining functionality in this file so the module stays under the
200‑line guideline by moving unrelated helpers/classes into separate modules;
ensure tests/imports still reference the imported names and run linters to catch
any unresolved references.

In `@acestep/training/preprocessed_collate.py`:
- Around line 10-61: The collate logic is duplicated; keep
collate_preprocessed_batch as the single canonical function and remove the copy
in data_module.py: delete the duplicate function body there (the one that pads
target_latents, attention_mask, context_latents, encoder_hidden_states, and
encoder_attention_mask) and replace it with an import of
collate_preprocessed_batch from this module, updating any reference in
DataModule (or functions that previously called the duplicate) to call
collate_preprocessed_batch; run a quick search for other copies and update
imports/usages so all code uses the single collate_preprocessed_batch
implementation to avoid drift.

In `@docs/lora_full_finetune_pipeline_proposal.md`:
- Around line 34-40: Add a short note under A2 (Add duration/latent-length
bucketing in dataloader) that describes potential training-dynamics risks from
bucketing (e.g., unintended correlations from grouping by length) and prescribe
validation steps: (1) run epoch-level bucket-aware shuffle checks to ensure
samples are randomly interleaved across epochs and within buckets, (2) compare
training metrics (loss/accuracy/validation curves) with and without bucketing to
detect shifts, and (3) include a bullet in QA command A6 to execute these
bucket-aware shuffle validations and metric comparisons as part of the QA
checklist.
- Around line 133-144: Update the Acceptance criteria paragraph to specify exact
measurement methods: define "data-loader stall time" measurement using a
profiler (e.g., PyTorch profiler or perf metrics) and the exact metric (percent
time waiting vs total batch time) and sampling window; define "step time"
measurement method (median of N runs with fixed seed and dataset shard, include
warmup steps and measurement steps); enumerate what "supported GPU setup" means
(minimum VRAM, example device models such as NVIDIA A100/RTX3090, CUDA/cuDNN
versions) and the test protocol for "Successfully trains decoder-only full FT"
(one epoch, fixed batch size, seed, and dataset subset); and add explicit
validation for "CLI safety checks" by listing specific OOM-prone config test
cases and expected CLI responses (warnings/failures) so acceptance tests are
reproducible.
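The "median of N runs with warmup" step-time protocol could be pinned down as in this sketch; the function name and default run counts are illustrative assumptions:

```python
import time
from statistics import median
from typing import Callable, List


def measure_step_time(step_fn: Callable[[], None], warmup: int = 5, steps: int = 20) -> float:
    """Median wall-clock step time over `steps` runs, after `warmup` discarded runs."""
    for _ in range(warmup):
        step_fn()  # warmup steps excluded from measurement
    samples: List[float] = []
    for _ in range(steps):
        start = time.perf_counter()
        step_fn()
        samples.append(time.perf_counter() - start)
    return median(samples)
```

With a fixed seed and dataset shard, comparing this median before and after a change gives a reproducible step-time metric.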
- Around line 20-32: Add migration tooling and docs for converting per-sample
`.pt` datasets to the new v3 sharded format: implement a CLI command (e.g.,
convert-dataset or convert_dataset) that accepts flags matching the loader
options (--dataset-format {pt,v3_sharded}, input/output paths) to produce
`shard-xxxxx.safetensors` and the corresponding `index.jsonl`, include
progress/error reporting and verification (checksums/sample counts), and update
the proposal and docs with guidance on when to migrate vs keep the `.pt` format
and examples of using the migration command.
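A hypothetical `convert-dataset` CLI skeleton matching the flags named above; the argument names, defaults, and `--shard-size`/`--verify` options are assumptions for illustration only:

```python
import argparse


def build_convert_parser() -> argparse.ArgumentParser:
    """Parser skeleton for a hypothetical convert-dataset subcommand."""
    parser = argparse.ArgumentParser(prog="convert-dataset")
    parser.add_argument("--input-dir", required=True,
                        help="Directory of per-sample .pt files")
    parser.add_argument("--output-dir", required=True,
                        help="Destination for shard-xxxxx.safetensors and index.jsonl")
    parser.add_argument("--dataset-format", choices=["pt", "v3_sharded"],
                        default="v3_sharded")
    parser.add_argument("--shard-size", type=int, default=1024,
                        help="Samples per shard (assumed default)")
    parser.add_argument("--verify", action="store_true",
                        help="Re-read shards and compare sample counts/checksums")
    return parser
```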
- Around line 99-101: Update the proposal to specify the VRAM estimation
methodology for the "preflight VRAM estimator" referenced alongside the opt-in
flags (`--training-mode full --i-understand-vram-risk`) and the "Auto-suggest
fallback to LoRA" behavior: state whether the estimator is formula-based
(analytical) or profile-based (dry-run with dummy tensors), describe how you
will account for framework/runtime overhead, optimizer/AMP/optimizer state sizes
and gradient accumulation, define an acceptable margin-of-error/confidence
interval and how that uncertainty is surfaced to users, and specify explicit
fallback behavior and messaging when confidence is low (e.g., suggest LoRA,
require extra confirmation, or perform an on-demand dry-run).
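A formula-based (analytical) estimator could look like this sketch; every constant here is an assumption to be calibrated per setup, and the activation term would come from a profiled dry run rather than a formula:

```python
def estimate_training_vram_gb(
    trainable_params: int,
    frozen_params: int = 0,
    bytes_per_param: int = 2,             # bf16/fp16 weights (assumption)
    optimizer_states_per_param: int = 2,  # Adam: exp_avg + exp_avg_sq
    optimizer_bytes_per_state: int = 4,   # fp32 optimizer state (assumption)
    grad_bytes_per_param: int = 2,
    activation_gb: float = 0.0,           # supply from a profiled dry run
    overhead_fraction: float = 0.2,       # framework/fragmentation margin (assumption)
) -> float:
    """Rough analytical VRAM estimate in GiB for the weights/grads/optimizer terms."""
    weights = (trainable_params + frozen_params) * bytes_per_param
    grads = trainable_params * grad_bytes_per_param
    opt = trainable_params * optimizer_states_per_param * optimizer_bytes_per_state
    total_gb = (weights + grads + opt) / (1024 ** 3) + activation_gb
    return total_gb * (1.0 + overhead_fraction)
```

For a 1B-parameter decoder-only full FT under these assumptions (12 bytes/param plus margin), the estimate lands around 13 GiB before activations, which is the kind of number the preflight check would compare against free VRAM, with the margin surfaced to the user as the confidence band.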
- Around line 147-151: The four bullet items in section "Immediate next patch
candidates" use the same leading verb "Add", creating repetitive sentence
structure; update each line to vary the verb while preserving meaning and the
referenced symbols (e.g., change "Add `DatasetBackend` abstraction and a sharded
backend implementation." to "Implement `DatasetBackend` abstraction with a
sharded backend implementation.", change "Add `BucketedBatchSampler` keyed by
latent length." to "Introduce `BucketedBatchSampler` keyed by latent length.",
change "Add `train.py validate-dataset --dataset-dir ...` command." to "Extend
`train.py` with a `validate-dataset --dataset-dir ...` command.", and change
"Add `training_mode` enum + `full` branch in config/trainer with decoder-only
unfreeze." to "Add a `training_mode` enum and a `full` branch in config/trainer
that enables decoder-only unfreeze."); keep the same symbols (`DatasetBackend`,
`BucketedBatchSampler`, `train.py validate-dataset`, `training_mode` and `full`)
so reviewers can locate the changes.
- Around line 109-130: Add a testing strategy tied to each phase: for Phase 1,
add unit tests for the packed dataset writer/loader, bucketed sampler, and
preprocessing QA validator (e.g., PackedDatasetWriterTests,
BucketedSamplerTests, QAValidatorTests) plus backward-compatibility tests for
old vs. new data formats; for Phase 2, add unit/integration tests for mmap/LRU
cache policy and optional device prefetch (CachePolicyTests,
DevicePrefetchIntegration) and performance regression benchmarks that validate
throughput improvements and record baseline metrics by GPU class; for Phase 3,
add unit tests for training_mode=full behavior, parameter group LR multipliers,
and full-state checkpoint/resume (TrainingModeFullTests, ParamGroupTests,
CheckpointResumeTests) as well as end-to-end training integration tests; for
Phase 4, add tests for distributed/sharded optimizer options, staged unfreeze
profiles, and eval/early-stop controls (DistributedOptimizerTests,
StagedUnfreezeTests, EarlyStopIntegration) and ensure all phases are wired into
CI with clear pass/fail criteria, recorded performance thresholds, and
reproducible test data fixtures.
- Around line 93-96: Update the checkpointing metadata to include an explicit
checkpoint format version field (e.g., "checkpoint_version" using semantic
versioning) whenever saving the full model state, sharded states, optimizer,
scheduler, scaler, or EMA checkpoints; on resume (load) validate this metadata
in the resume logic and emit a clear, actionable error if the version is
incompatible (or route to compatibility/upgrade handlers), and ensure any code
paths that change the on-disk layout bump the checkpoint_version and add a
loader/upgrade path for older versions to avoid cryptic failures.
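A minimal version-stamping sketch; the `checkpoint_meta.json` filename, the semver constant, and the major-version compatibility rule are all assumptions for illustration:

```python
import json
from pathlib import Path

CHECKPOINT_VERSION = "1.0.0"  # bump whenever the on-disk layout changes


def save_checkpoint_meta(ckpt_dir: str) -> None:
    """Write the version metadata next to the checkpoint payload."""
    meta = {"checkpoint_version": CHECKPOINT_VERSION}
    Path(ckpt_dir, "checkpoint_meta.json").write_text(json.dumps(meta))


def validate_checkpoint_meta(ckpt_dir: str) -> None:
    """Fail fast with an actionable error when resuming an incompatible checkpoint."""
    path = Path(ckpt_dir, "checkpoint_meta.json")
    if not path.exists():
        raise RuntimeError(f"{path} missing: checkpoint predates versioned metadata")
    version = json.loads(path.read_text()).get("checkpoint_version", "0.0.0")
    if version.split(".")[0] != CHECKPOINT_VERSION.split(".")[0]:
        raise RuntimeError(
            f"Incompatible checkpoint_version {version}; "
            f"expected major version {CHECKPOINT_VERSION.split('.')[0]}. "
            "Re-export the checkpoint or run the upgrade path."
        )
```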

ℹ️ Review info

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2f870f9 and a8146d4.

📒 Files selected for processing (5)
  • acestep/training/data_module.py
  • acestep/training/preprocessed_collate.py
  • acestep/training/preprocessed_dataset.py
  • acestep/training/preprocessed_sampler.py
  • docs/lora_full_finetune_pipeline_proposal.md

Comment on lines +5 to +10
from typing import Dict, List

import torch


def collate_preprocessed_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]:

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, locate the file and read it
fd -t f "preprocessed_collate.py" | head -5

Repository: ace-step/ACE-Step-1.5

Length of output: 105


🏁 Script executed:

#!/bin/bash
# Read the file content, focusing on the function signature and return statement
if [ -f "acestep/training/preprocessed_collate.py" ]; then
  cat -n "acestep/training/preprocessed_collate.py" | head -70
else
  find . -name "preprocessed_collate.py" -exec cat -n {} \; | head -70
fi

Repository: ace-step/ACE-Step-1.5

Length of output: 2713


🏁 Script executed:

#!/bin/bash
# Search for usages of collate_preprocessed_batch function
rg -n "collate_preprocessed_batch" --type=py

Repository: ace-step/ACE-Step-1.5

Length of output: 604


🏁 Script executed:

#!/bin/bash
# Check if metadata field is accessed anywhere to understand usage patterns
rg -n "\[\"metadata\"\]|\['metadata'\]" --type=py -B2 -A2

Repository: ace-step/ACE-Step-1.5

Length of output: 3934


🏁 Script executed:

#!/bin/bash
# Read the function in data_module.py to see if it has the same issue
sed -n '207,275p' acestep/training/data_module.py

Repository: ace-step/ACE-Step-1.5

Length of output: 2832


Fix the return type annotation to match the actual payload in both collate_preprocessed_batch functions.

Both acestep/training/preprocessed_collate.py (line 10) and acestep/training/data_module.py (line 207) declare -> Dict[str, torch.Tensor], but both functions return "metadata" as a Python list (lines 60 and 271 respectively). This type mismatch breaks static type checking and misleads callers about the actual structure.

Replace the imprecise Dict[str, torch.Tensor] with a TypedDict that accurately describes the returned structure:

Proposed typing fix
-from typing import Dict, List
+from typing import Any, Dict, List, TypedDict
+
+
+class CollatedPreprocessedBatch(TypedDict):
+    target_latents: torch.Tensor
+    attention_mask: torch.Tensor
+    encoder_hidden_states: torch.Tensor
+    encoder_attention_mask: torch.Tensor
+    context_latents: torch.Tensor
+    metadata: List[Dict[str, Any]]
 
 
-def collate_preprocessed_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]:
+def collate_preprocessed_batch(batch: List[Dict[str, Any]]) -> CollatedPreprocessedBatch:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/preprocessed_collate.py` around lines 5 - 10, The return
annotation Dict[str, torch.Tensor] is incorrect because
collate_preprocessed_batch returns a mix of tensors and a Python list for
"metadata"; define a TypedDict (e.g., PreprocessedBatch: TypedDict) that lists
each returned key and its precise type (for example tensor fields as
torch.Tensor and "metadata" as List[Dict[str, Any]]), import TypedDict and Any
from typing, and update the return type of collate_preprocessed_batch (in this
module) and the matching function in data_module.py to -> PreprocessedBatch so
static typing matches the actual payload.

Comment on lines +57 to +71
logger.warning(
"Some tensor files not found: %d missing",
len(self.sample_paths) - len(self.valid_paths),
)

self.latent_lengths: List[int] = []
for vp in self.valid_paths:
try:
sample = torch.load(vp, map_location="cpu", weights_only=True)
self.latent_lengths.append(int(sample["target_latents"].shape[0]))
except (FileNotFoundError, PermissionError, EOFError, OSError, KeyError, RuntimeError) as exc:
logger.warning("Failed to read latent length from %s: %s", vp, exc)
self.latent_lengths.append(0)

logger.info("PreprocessedTensorDataset: %d samples from %s", len(self.valid_paths), self.tensor_dir)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

loguru Python logger formatting API {} vs % placeholders documentation

💡 Result:

Loguru’s logging methods (logger.debug/info/...) do not support %-style placeholders (%s, %(name)s, etc.). They support {}-style formatting only, using Python’s str.format() under the hood:

  • Replace stdlib-style: logger.debug("Some variable: %s", var)
  • With Loguru-style: logger.debug("Some variable: {}", var) (loguru.readthedocs.io)

How it works:

  • Any *args / **kwargs you pass are applied as message.format(*args, **kwargs) (and may be ignored if not referenced by a {} field). (loguru.readthedocs.io)
  • Example with named fields: logger.info("My name is {name}", name="John"). (loguru.readthedocs.io)

Sink/handler format= strings (for logger.add(..., format="...")) also use brace fields like {time}, {level}, {message}. (loguru.readthedocs.io)



🏁 Script executed:

rg "logger\.(debug|info|warning|error)" acestep/training/preprocessed_dataset.py -A 2

Repository: ace-step/ACE-Step-1.5

Length of output: 767


Loguru uses {} placeholders, not % style—log arguments are currently being silently dropped.

These logger calls use %-style formatting which loguru does not support. The arguments are passed to .format(*args, **kwargs) but ignored because no {} fields exist in the message strings. Replace with {} formatting across all five instances (lines 57, 65, 69, 85, and 90).

🔧 Proposed fix
             logger.warning(
-                "Some tensor files not found: %d missing",
+                "Some tensor files not found: {} missing",
                 len(self.sample_paths) - len(self.valid_paths),
             )
@@
-                logger.warning("Failed to read latent length from %s: %s", vp, exc)
+                logger.warning("Failed to read latent length from {}: {}", vp, exc)
                 self.latent_lengths.append(0)
@@
-        logger.info("PreprocessedTensorDataset: %d samples from %s", len(self.valid_paths), self.tensor_dir)
+        logger.info("PreprocessedTensorDataset: {} samples from {}", len(self.valid_paths), self.tensor_dir)
@@
-                logger.debug("Resolved legacy manifest path via safe root: %s", raw)
+                logger.debug("Resolved legacy manifest path via safe root: {}", raw)
@@
-        logger.warning("Skipping unresolvable manifest path: %s", raw)
+        logger.warning("Skipping unresolvable manifest path: {}", raw)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/preprocessed_dataset.py` around lines 57 - 71, The logger
calls in PreprocessedTensorDataset use percent-style formatting which Loguru
ignores; update all logger.warning and logger.info usages in this region (the
calls inside the constructor/initializer that reference logger.warning("Some
tensor files not found: %d missing", ...), logger.warning("Failed to read latent
length from %s: %s", vp, exc), and logger.info("PreprocessedTensorDataset: %d
samples from %s", ...)) to use brace-style placeholders (e.g., "Some tensor
files not found: {} missing") and pass the values as positional args so Loguru
will interpolate them; keep the exception handling (in the try/except around
torch.load and appending to self.latent_lengths) unchanged.

Comment on lines +96 to +115
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""Load one preprocessed sample."""
if self.cache_policy == "ram_lru" and idx in self._cache:
data = self._cache.pop(idx)
self._cache[idx] = data
else:
data = torch.load(self.valid_paths[idx], map_location="cpu", weights_only=True)
if self.cache_policy == "ram_lru" and self.cache_max_items > 0:
self._cache[idx] = data
while len(self._cache) > self.cache_max_items:
self._cache.popitem(last=False)

return {
"target_latents": data["target_latents"],
"attention_mask": data["attention_mask"],
"encoder_hidden_states": data["encoder_hidden_states"],
"encoder_attention_mask": data["encoder_attention_mask"],
"context_latents": data["context_latents"],
"metadata": data.get("metadata", {}),
}

⚠️ Potential issue | 🟡 Minor

__getitem__ return annotation is too narrow for the returned payload.

The method returns "metadata" as a dict, so Dict[str, torch.Tensor] is incorrect.

🔧 Proposed fix
-    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+    def __getitem__(self, idx: int) -> Dict[str, Any]:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/preprocessed_dataset.py` around lines 96 - 115, The return
type of __getitem__ is too narrow: it returns tensors for most keys but
"metadata" is a dict (or could be other non-tensor values), so update the type
hint for __getitem__ from Dict[str, torch.Tensor] to a broader type such as
Dict[str, Any] (and add Any to the typing imports) so the signature reflects
that values can be non-tensor; ensure the change is applied to the __getitem__
definition and any related annotations in the same module that reference this
method.

Comment on lines +42 to +45
def __len__(self) -> int:
"""Return estimated number of batches."""
total = len(self.lengths)
return (total + self.batch_size - 1) // self.batch_size

⚠️ Potential issue | 🟡 Minor

__len__ can underreport actual batch count for bucketed iteration.

__iter__ batches per bucket, but ceil(total / batch_size) ignores per-bucket remainders and can be smaller than the real yielded batch count.

🔧 Proposed fix
     def __len__(self) -> int:
-        """Return estimated number of batches."""
-        total = len(self.lengths)
-        return (total + self.batch_size - 1) // self.batch_size
+        """Return number of batches produced by bucketed iteration."""
+        bucket_counts: Dict[int, int] = {}
+        for length in self.lengths:
+            bucket = int(length // 64)
+            bucket_counts[bucket] = bucket_counts.get(bucket, 0) + 1
+        return sum(
+            (count + self.batch_size - 1) // self.batch_size
+            for count in bucket_counts.values()
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/preprocessed_sampler.py` around lines 42 - 45, The current
__len__ returns ceil(total_items / batch_size) which undercounts when __iter__
yields batches per bucket; update __len__ to sum over buckets instead: for each
bucket in self.lengths compute (len(bucket) + self.batch_size - 1) //
self.batch_size and return the total, ensuring you use the same bucket structure
that __iter__ uses (reference __len__, __iter__, self.lengths, self.batch_size).

…ora-training-pipeline

Preprocessed dataset improvements: bucketing, RAM-LRU cache, validator, and full fine-tune mode
@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
acestep/training_v2/trainer_helpers.py (2)

374-374: ⚠️ Potential issue | 🟡 Minor

Fix unresolved Dict annotation (F821) in the return type.

Dict is referenced but not imported. Prefer the built-in generic for the Python 3.11+ target.

Proposed fix
-) -> Optional[Tuple[int, int, Dict[str, Any]]]:
+) -> Optional[Tuple[int, int, dict[str, Any]]]:
#!/bin/bash
set -euo pipefail

# Verify there are no unresolved typing.Dict annotations left in this module.
fd -p 'trainer_helpers.py$' --exec rg -nP '\bDict\s*\[' {}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training_v2/trainer_helpers.py` at line 374, The return type
currently uses typing.Dict in the signature ")-> Optional[Tuple[int, int,
Dict[str, Any]]]:", which is unresolved; replace typing.Dict with the built-in
generic by changing the annotation to "Optional[Tuple[int, int, dict[str,
Any]]]" and ensure Any remains imported (or replace Any with an appropriate
concrete type); also remove any unused "Dict" import from typing and scan the
module for other occurrences of "Dict[" and update them to "dict[" to avoid F821
errors.

494-497: ⚠️ Potential issue | 🟠 Major

Use the shared unwrapping helper before load_state_dict().

This still unwraps only one _forward_module level and can miss nested wrappers. Use _unwrap_decoder(module.model) here for parity with the full-mode helpers.

Proposed fix
-    decoder = module.model.decoder
-    if hasattr(decoder, "_forward_module"):
-        decoder = decoder._forward_module
+    decoder = _unwrap_decoder(module.model)
     decoder.load_state_dict(state_dict, strict=False)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training_v2/trainer_helpers.py` around lines 494 - 497, Replace the
manual one-level unwrap of module.model.decoder with the shared helper: call
_unwrap_decoder(module.model) to obtain the real decoder instance, then call
decoder.load_state_dict(state_dict, strict=False); ensure you still pass
strict=False and remove the hasattr(... "_forward_module") branch so nested
wrappers are correctly unwrapped before load_state_dict.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@acestep/training_v2/trainer_helpers.py`:
- Around line 447-460: _resume_lokr can return None on handled paths which
_resume_lokr_or_lora treats as "not handled" and falls back to LoRA; change
_resume_lokr (the function that calls state_loader and yields
TrainingUpdate/TrainingUpdate warnings) to return an explicit handled boolean
(e.g., True when LoKR was loaded or when no training_state/handled warning path
occurred, False only when truly unhandled), update callers in
_resume_lokr_or_lora to check that boolean instead of truthiness/None so LoRA
loading is skipped when LoKR already handled the resume; apply the same explicit
handled-return fix to the other mirrored block (the second occurrence referenced
around lines 520-529).

---

Duplicate comments:
In `@acestep/training_v2/trainer_helpers.py`:
- Line 374: The return type currently uses typing.Dict in the signature ")->
Optional[Tuple[int, int, Dict[str, Any]]]:", which is unresolved; replace
typing.Dict with the built-in generic by changing the annotation to
"Optional[Tuple[int, int, dict[str, Any]]]" and ensure Any remains imported (or
replace Any with an appropriate concrete type); also remove any unused "Dict"
import from typing and scan the module for other occurrences of "Dict[" and
update them to "dict[" to avoid F821 errors.
- Around line 494-497: Replace the manual one-level unwrap of
module.model.decoder with the shared helper: call _unwrap_decoder(module.model)
to obtain the real decoder instance, then call
decoder.load_state_dict(state_dict, strict=False); ensure you still pass
strict=False and remove the hasattr(... "_forward_module") branch so nested
wrappers are correctly unwrapped before load_state_dict.

ℹ️ Review info

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a8146d4 and 461c79f.

📒 Files selected for processing (2)
  • acestep/training_v2/fixed_lora_module.py
  • acestep/training_v2/trainer_helpers.py

Comment on lines +447 to +460
training_state = state_loader(ckpt_dir=ckpt_dir)
if training_state is None:
yield TrainingUpdate(0, 0.0, "[OK] LoKR weights loaded (no training state)", kind="info")
return None

# Full mode resume: no full decoder checkpoint found.
if getattr(trainer.training_config, "training_mode", "adapter") == "full":
yield TrainingUpdate(
0,
0.0,
f"[WARN] full_decoder_state.pt not found in {ckpt_dir}",
kind="warn",
)
return

⚠️ Potential issue | 🔴 Critical

LoKR resume path can incorrectly fall through into LoRA resume.

After LoKR weights are loaded, _resume_lokr() still returns None on handled paths, and _resume_lokr_or_lora() treats that as “not handled” and attempts LoRA loading. This can mis-resume or override the intended LoKR state.

Proposed fix
 def _resume_lokr(
@@
     load_lokr_weights(module.lycoris_net, str(lokr_weights_path))
     training_state = state_loader(ckpt_dir=ckpt_dir)
     if training_state is None:
         yield TrainingUpdate(0, 0.0, "[OK] LoKR weights loaded (no training state)", kind="info")
         return None
-
-    # Full mode resume: no full decoder checkpoint found.
-    if getattr(trainer.training_config, "training_mode", "adapter") == "full":
-        yield TrainingUpdate(
-            0,
-            0.0,
-            f"[WARN] full_decoder_state.pt not found in {ckpt_dir}",
-            kind="warn",
-        )
-        return
+    epoch, step, _ = training_state
+    yield TrainingUpdate(
+        0,
+        0.0,
+        f"[OK] Resumed LoKR from epoch {epoch}, step {step}",
+        kind="info",
+    )
+    return (epoch, step)
@@
 def _resume_lokr_or_lora(
@@
-    lokr_result = yield from _resume_lokr(
+    has_lokr_checkpoint = (ckpt_dir / "lokr_weights.safetensors").exists()
+    lokr_result = yield from _resume_lokr(
         trainer,
         resume_path,
         ckpt_dir,
         module,
         state_loader,
     )
+    if has_lokr_checkpoint and module.lycoris_net is not None:
+        return lokr_result
     if lokr_result is not None:
         return lokr_result
     return (yield from _resume_lora(ckpt_dir, module, optimizer, scheduler))

Also applies to: 520-529

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training_v2/trainer_helpers.py` around lines 447 - 460, _resume_lokr
can return None on handled paths which _resume_lokr_or_lora treats as "not
handled" and falls back to LoRA; change _resume_lokr (the function that calls
state_loader and yields TrainingUpdate/TrainingUpdate warnings) to return an
explicit handled boolean (e.g., True when LoKR was loaded or when no
training_state/handled warning path occurred, False only when truly unhandled),
update callers in _resume_lokr_or_lora to check that boolean instead of
truthiness/None so LoRA loading is skipped when LoKR already handled the resume;
apply the same explicit handled-return fix to the other mirrored block (the
second occurrence referenced around lines 520-529).

