Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
74b53e1
feat(training_v2): implement dataset validation, bucketing cache, and…
azazeal04 Feb 26, 2026
a3ff595
Merge pull request #1 from azazeal04/codex/propose-improvements-for-l…
azazeal04 Feb 26, 2026
b8d9def
Fix legacy datamodule init and full-mode resume fallback
azazeal04 Feb 26, 2026
23769e3
Merge pull request #2 from azazeal04/codex/fix-high-priority-bugs-in-…
azazeal04 Feb 26, 2026
bddafc0
Merge pull request #3 from azazeal04/codex/propose-improvements-for-l…
azazeal04 Feb 26, 2026
2b9fd32
Delete docs/lora_full_finetune_pipeline_proposal.md
azazeal04 Feb 26, 2026
add3f66
Merge branch 'ace-step:main' into main
azazeal04 Feb 26, 2026
634ee5f
Address review follow-ups for resume helpers and data tests
azazeal04 Feb 26, 2026
bd21f9a
Merge branch 'codex/propose-improvements-for-lora-training-pipeline' …
azazeal04 Feb 26, 2026
2f870f9
Merge pull request #5 from azazeal04/codex/fix-high-priority-bugs-in-…
azazeal04 Feb 26, 2026
6d10eab
fix(training): address review comments for bucketing, cache docs, and…
azazeal04 Feb 26, 2026
12ea99d
Merge branch 'main' into codex/propose-improvements-for-lora-training…
azazeal04 Feb 26, 2026
a8146d4
Merge pull request #6 from azazeal04/codex/propose-improvements-for-l…
azazeal04 Feb 26, 2026
a418ec6
fix(training): address review feedback on typing, logging, and sample…
azazeal04 Feb 26, 2026
dcba69e
Merge branch 'main' into codex/propose-improvements-for-lora-training…
azazeal04 Feb 26, 2026
14a997a
Merge pull request #7 from azazeal04/codex/propose-improvements-for-l…
azazeal04 Feb 26, 2026
d0ce851
Merge branch 'ace-step:main' into main
azazeal04 Feb 27, 2026
5af7826
Merge branch 'main' into main
azazeal04 Feb 28, 2026
461c79f
Merge branch 'main' into main
azazeal04 Feb 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 144 additions & 28 deletions acestep/training/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import json
import random
from collections import OrderedDict
from typing import Optional, List, Dict, Any, Tuple
from loguru import logger

Expand All @@ -32,6 +33,41 @@ class LightningDataModule:
# Preprocessed Tensor Dataset (Recommended for Training)
# ============================================================================

class BucketedBatchSampler:
    """Batch sampler that groups indices by latent length buckets.

    Samples whose latent length falls in the same 64-frame bucket are
    batched together so per-batch padding stays small.  With
    ``shuffle=True`` both the bucket visit order and the order of indices
    inside each bucket are randomized each epoch (via the global
    ``random`` RNG).

    Note: the final batch of each bucket may be smaller than
    ``batch_size``; no batch mixes samples from different buckets.
    """

    def __init__(self, lengths: List[int], batch_size: int, shuffle: bool = True) -> None:
        """Store bucketing inputs.

        Args:
            lengths: Latent length per dataset index (index i -> lengths[i]).
            batch_size: Requested batch size; clamped to at least 1.
            shuffle: Randomize bucket and intra-bucket order per epoch.
        """
        self.lengths = lengths
        self.batch_size = max(1, int(batch_size))
        self.shuffle = shuffle

    def _build_buckets(self) -> Dict[int, List[int]]:
        """Group sample indices by their ``length // 64`` bucket key.

        Recomputed on demand so later mutation of ``self.lengths`` is
        reflected, matching the original per-call behavior.
        """
        buckets: Dict[int, List[int]] = {}
        for idx, length in enumerate(self.lengths):
            buckets.setdefault(int(length // 64), []).append(idx)
        return buckets

    def __iter__(self):
        """Yield lists of dataset indices, one list per batch."""
        buckets = self._build_buckets()
        bucket_keys = list(buckets)
        if self.shuffle:
            random.shuffle(bucket_keys)
        for key in bucket_keys:
            group = buckets[key]
            if self.shuffle:
                random.shuffle(group)
            for start in range(0, len(group), self.batch_size):
                yield group[start:start + self.batch_size]

    def __len__(self) -> int:
        """Number of batches: sum of ceil(bucket_size / batch_size)."""
        return sum(
            (len(group) + self.batch_size - 1) // self.batch_size
            for group in self._build_buckets().values()
        )


class PreprocessedTensorDataset(Dataset):
"""Dataset that loads preprocessed tensor files.

Expand All @@ -45,7 +81,7 @@ class PreprocessedTensorDataset(Dataset):
No VAE/text encoder needed during training - just load tensors directly!
"""

def __init__(self, tensor_dir: str):
def __init__(self, tensor_dir: str, cache_policy: str = "none", cache_max_items: int = 0):
"""Initialize from a directory of preprocessed .pt files.

Args:
Expand All @@ -59,6 +95,9 @@ def __init__(self, tensor_dir: str):
raise ValueError(f"Not an existing directory: {tensor_dir}")
self.tensor_dir = validated_dir
self.sample_paths: List[str] = []
self.cache_policy = cache_policy
self.cache_max_items = max(0, int(cache_max_items))
self._cache: "OrderedDict[int, Dict[str, Any]]" = OrderedDict()

# Load manifest if exists
manifest_path = safe_path("manifest.json", base=self.tensor_dir)
Expand Down Expand Up @@ -87,6 +126,14 @@ def __init__(self, tensor_dir: str):
f"{len(self.sample_paths) - len(self.valid_paths)} missing"
)

self.latent_lengths: List[int] = []
for vp in self.valid_paths:
try:
sample = torch.load(vp, map_location="cpu", weights_only=True)
self.latent_lengths.append(int(sample["target_latents"].shape[0]))
except Exception:
self.latent_lengths.append(0)
Comment on lines +129 to +135
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Avoid bare except Exception; handle specific errors and log failures.

The current code silently catches all exceptions and defaults to length 0, which could mask real issues (corrupt files, permission errors) and cause unexpected bucketing behavior for failed samples.

🛡️ Proposed fix with specific exception handling
         self.latent_lengths: List[int] = []
         for vp in self.valid_paths:
             try:
                 sample = torch.load(vp, map_location="cpu", weights_only=True)
                 self.latent_lengths.append(int(sample["target_latents"].shape[0]))
-            except Exception:
+            except (OSError, KeyError, RuntimeError) as e:
+                logger.warning(f"Failed to read latent length from {vp}: {e}")
                 self.latent_lengths.append(0)
🧰 Tools
🪛 Ruff (0.15.2)

[warning] 127-127: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 122 - 128, The loop that builds
self.latent_lengths in DataModule (using self.valid_paths) currently uses a bare
except which masks errors; change it to catch specific exceptions when loading
torch files (e.g., FileNotFoundError, PermissionError, EOFError, RuntimeError)
and log the failure including the path and exception (use the module/class
logger or python logging), append 0 only for those expected/recoverable errors,
and let other unexpected exceptions propagate (or re-raise) so real issues
aren’t silently ignored; update the try/except around torch.load and the
sample["target_latents"] access accordingly and include the path (vp) and
exception text in the log message.


logger.info(
f"PreprocessedTensorDataset: {len(self.valid_paths)} samples "
f"from {self.tensor_dir}"
Expand Down Expand Up @@ -136,15 +183,23 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
Returns:
Dictionary containing all pre-computed tensors for training
"""
tensor_path = self.valid_paths[idx]
data = torch.load(tensor_path, map_location='cpu', weights_only=True)

if self.cache_policy == "ram_lru" and idx in self._cache:
data = self._cache.pop(idx)
self._cache[idx] = data
else:
tensor_path = self.valid_paths[idx]
data = torch.load(tensor_path, map_location='cpu', weights_only=True)
if self.cache_policy == "ram_lru" and self.cache_max_items > 0:
self._cache[idx] = data
while len(self._cache) > self.cache_max_items:
self._cache.popitem(last=False)

return {
"target_latents": data["target_latents"], # [T, 64]
"attention_mask": data["attention_mask"], # [T]
"encoder_hidden_states": data["encoder_hidden_states"], # [L, D]
"encoder_attention_mask": data["encoder_attention_mask"], # [L]
"context_latents": data["context_latents"], # [T, 65]
"target_latents": data["target_latents"],
"attention_mask": data["attention_mask"],
"encoder_hidden_states": data["encoder_hidden_states"],
"encoder_attention_mask": data["encoder_attention_mask"],
"context_latents": data["context_latents"],
"metadata": data.get("metadata", {}),
}

Expand Down Expand Up @@ -219,11 +274,10 @@ def collate_preprocessed_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]:

class PreprocessedDataModule(LightningDataModule if LIGHTNING_AVAILABLE else object):
"""DataModule for preprocessed tensor files.

This is the recommended DataModule for training. It loads pre-computed tensors
directly without needing VAE, text encoder, or condition encoder at training time.

Loads precomputed tensors directly, avoiding VAE/text encoding at train time.
"""

def __init__(
self,
tensor_dir: str,
Expand All @@ -234,19 +288,28 @@ def __init__(
persistent_workers: bool = True,
pin_memory_device: str = "",
val_split: float = 0.0,
length_bucket: bool = False,
cache_policy: str = "none",
cache_max_items: int = 0,
):
"""Initialize the data module.
"""Initialize the preprocessed data module.

Args:
tensor_dir: Directory containing preprocessed .pt files
batch_size: Training batch size
num_workers: Number of data loading workers
pin_memory: Whether to pin memory for faster GPU transfer
val_split: Fraction of data for validation (0 = no validation)
tensor_dir: Directory containing preprocessed ``.pt`` files.
batch_size: Training batch size.
num_workers: Number of DataLoader worker processes.
pin_memory: Pin host memory for faster GPU transfer.
prefetch_factor: Number of prefetched batches per worker.
persistent_workers: Keep worker processes alive between epochs.
pin_memory_device: Device string used by pinned memory allocator.
val_split: Fraction of data reserved for validation.
length_bucket: Whether to bucket training samples by latent length.
cache_policy: Dataset cache mode ("none" or "ram_lru").
cache_max_items: Maximum cached entries when RAM LRU is enabled.
"""
if LIGHTNING_AVAILABLE:
super().__init__()

self.tensor_dir = tensor_dir
self.batch_size = batch_size
self.num_workers = num_workers
Expand All @@ -255,36 +318,61 @@ def __init__(
self.persistent_workers = persistent_workers
self.pin_memory_device = pin_memory_device
self.val_split = val_split
self.length_bucket = length_bucket
self.cache_policy = cache_policy
self.cache_max_items = cache_max_items

self.train_dataset = None
self.val_dataset = None

def setup(self, stage: Optional[str] = None):
    """Instantiate train/val datasets for the ``fit`` stage (or when stage is None)."""
    if stage not in ('fit', None):
        return

    dataset = PreprocessedTensorDataset(
        self.tensor_dir,
        cache_policy=self.cache_policy,
        cache_max_items=self.cache_max_items,
    )

    # Split only when a validation fraction is requested and there is
    # more than one sample, so both sides end up non-empty.
    if self.val_split > 0 and len(dataset) > 1:
        n_val = max(1, int(len(dataset) * self.val_split))
        split = torch.utils.data.random_split(dataset, [len(dataset) - n_val, n_val])
        self.train_dataset, self.val_dataset = split
    else:
        self.train_dataset = dataset
        self.val_dataset = None


def _resolve_train_latent_lengths(self) -> Optional[List[int]]:
    """Return latent lengths aligned with the train dataset, or ``None``.

    Looks through a ``torch.utils.data.Subset`` wrapper (produced by
    ``random_split``) and remaps its indices into the underlying
    dataset's precomputed ``latent_lengths``.  Returns ``None`` when
    bucketing is disabled, no train dataset exists, or the dataset does
    not expose ``latent_lengths``.
    """
    if not self.length_bucket or self.train_dataset is None:
        return None

    dataset = self.train_dataset
    if isinstance(dataset, torch.utils.data.Subset):
        lengths = getattr(dataset.dataset, "latent_lengths", None)
        if lengths is None:
            return None
        return [lengths[i] for i in dataset.indices]

    lengths = getattr(dataset, "latent_lengths", None)
    return None if lengths is None else list(lengths)

def train_dataloader(self) -> DataLoader:
"""Create training dataloader."""
prefetch_factor = None if self.num_workers == 0 else self.prefetch_factor
persistent_workers = False if self.num_workers == 0 else self.persistent_workers
kwargs = dict(
dataset=self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
shuffle=not self.length_bucket,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_preprocessed_batch,
Expand All @@ -294,8 +382,16 @@ def train_dataloader(self) -> DataLoader:
)
if self.pin_memory_device:
kwargs["pin_memory_device"] = self.pin_memory_device
if self.length_bucket and hasattr(self.train_dataset, "latent_lengths"):
kwargs.pop("batch_size", None)
kwargs.pop("shuffle", None)
kwargs["batch_sampler"] = BucketedBatchSampler(
lengths=list(getattr(self.train_dataset, "latent_lengths", [])),
batch_size=self.batch_size,
shuffle=True,
)
Comment on lines +385 to +392
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Bucketing silently fails when val_split > 0.

When validation split is enabled, self.train_dataset becomes a torch.utils.data.Subset, which lacks the latent_lengths attribute. The hasattr check passes silently, falling back to regular batching without any warning.

🐛 Proposed fix to access underlying dataset's latent_lengths
-        if self.length_bucket and hasattr(self.train_dataset, "latent_lengths"):
+        latent_lengths = None
+        if self.length_bucket:
+            ds = self.train_dataset
+            # Handle Subset from random_split
+            if hasattr(ds, "dataset") and hasattr(ds, "indices"):
+                underlying = ds.dataset
+                if hasattr(underlying, "latent_lengths"):
+                    latent_lengths = [underlying.latent_lengths[i] for i in ds.indices]
+            elif hasattr(ds, "latent_lengths"):
+                latent_lengths = ds.latent_lengths
+        if latent_lengths is not None:
             kwargs.pop("batch_size", None)
             kwargs.pop("shuffle", None)
             kwargs["batch_sampler"] = BucketedBatchSampler(
-                lengths=list(getattr(self.train_dataset, "latent_lengths", [])),
+                lengths=latent_lengths,
                 batch_size=self.batch_size,
                 shuffle=True,
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module.py` around lines 355 - 362, The bucketing code
fails silently when self.train_dataset is a torch.utils.data.Subset because
hasattr(self.train_dataset, "latent_lengths") is false for the Subset wrapper;
update the condition in the block that sets BucketedBatchSampler (symbols:
length_bucket, self.train_dataset, latent_lengths, BucketedBatchSampler) to look
through the Subset wrapper and obtain the underlying dataset's latent_lengths
(e.g., check getattr(self.train_dataset, "dataset", self.train_dataset) for
latent_lengths) before deciding to pop batch_size/shuffle and assign
batch_sampler, so the BucketedBatchSampler still receives the correct lengths
when val_split > 0.

return DataLoader(**kwargs)

def val_dataloader(self) -> Optional[DataLoader]:
"""Create validation dataloader."""
if self.val_dataset is None:
Expand Down Expand Up @@ -474,17 +570,37 @@ def __init__(
pin_memory: bool = True,
max_duration: float = 240.0,
val_split: float = 0.0,
length_bucket: bool = False,
cache_policy: str = "none",
cache_max_items: int = 0,
):
"""Initialize legacy raw-audio datamodule.

Args:
samples: Raw training sample metadata entries.
dit_handler: Model handler used by legacy training flows.
batch_size: Number of samples per batch.
num_workers: Number of dataloader workers.
pin_memory: Whether to enable pinned memory in dataloaders.
max_duration: Max audio duration (seconds) for clipping.
val_split: Validation split fraction.
length_bucket: Accepted for compatibility; unused for raw audio mode.
cache_policy: Accepted for compatibility; unused for raw audio mode.
cache_max_items: Accepted for compatibility; unused for raw audio mode.
"""
if LIGHTNING_AVAILABLE:
super().__init__()

self.samples = samples
self.dit_handler = dit_handler
self.batch_size = batch_size
self.num_workers = num_workers
self.pin_memory = pin_memory
self.max_duration = max_duration
self.val_split = val_split
self.length_bucket = length_bucket
self.cache_policy = cache_policy
self.cache_max_items = cache_max_items

self.train_dataset = None
self.val_dataset = None
Expand Down
19 changes: 19 additions & 0 deletions acestep/training/data_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@

import os
import json
import random
import tempfile
import unittest
from unittest import mock

import torch

from acestep.training.path_safety import safe_path, set_safe_root
from acestep.training.data_module import (
BucketedBatchSampler,
PreprocessedTensorDataset,
load_dataset_from_json,
)
Expand Down Expand Up @@ -193,5 +198,19 @@ def test_valid_json(self):
os.unlink(path)


class AceStepDataModuleInitTests(unittest.TestCase):
    """Regression tests for legacy ``AceStepDataModule`` initialization."""

    def test_init_does_not_require_preprocessed_only_cache_args(self):
        """Legacy raw-audio datamodule should initialize without NameError."""
        from acestep.training.data_module import AceStepDataModule

        handler = object()
        dm = AceStepDataModule(samples=[], dit_handler=handler)

        self.assertEqual(dm.samples, [])
        self.assertIsNotNone(dm.dit_handler)

Comment on lines +201 to +212
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Coverage is missing for the new bucketing/cache behaviors introduced by this PR.

This addition only validates legacy initialization; it does not test the new length_bucket and ram_lru paths (including split-dataset behavior). Please add deterministic tests for those new code paths.

As per coding guidelines: **/*_test.py: "Add or update tests for every behavior change and bug fix" and "Include at least: one success-path test, one regression/edge-case test for the bug being fixed, and one non-target behavior check when relevant".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@acestep/training/data_module_test.py` around lines 201 - 212, Add
deterministic tests in acestep/training/data_module_test.py covering the new
bucketing/cache behavior for AceStepDataModule: create tests that instantiate
AceStepDataModule with length_bucket enabled to verify samples are assigned to
expected buckets (success path), a regression/edge-case that uses ram_lru
caching with a tiny capacity to confirm eviction/split-dataset behavior occurs
deterministically, and a non-target behavior check ensuring legacy
initialization (samples=[], dit_handler=...) still works when these options are
present; reference the AceStepDataModule constructor and its length_bucket and
ram_lru configuration flags/symbols to locate where to set parameters and assert
bucket assignments, eviction counts, and that dit_handler remains set.



if __name__ == "__main__":
unittest.main()
61 changes: 61 additions & 0 deletions acestep/training/preprocessed_collate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Collate helpers for preprocessed ACE-Step training tensors."""

from __future__ import annotations

from typing import Dict, List

import torch


def _pad_to(tensor: torch.Tensor, length: int) -> torch.Tensor:
    """Zero-pad ``tensor`` along dim 0 up to ``length`` (no-op if already long enough)."""
    missing = length - tensor.shape[0]
    if missing <= 0:
        return tensor
    # new_zeros keeps dtype/device; trailing dims (if any) are preserved.
    return torch.cat([tensor, tensor.new_zeros(missing, *tensor.shape[1:])], dim=0)


def collate_preprocessed_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Pad and stack variable-length preprocessed training tensors.

    Latent-aligned tensors (``target_latents``, ``attention_mask``,
    ``context_latents``) are zero-padded to the longest latent sequence in
    the batch; encoder-aligned tensors (``encoder_hidden_states``,
    ``encoder_attention_mask``) to the longest encoder sequence.

    Args:
        batch: Per-sample tensor dictionaries. Must be non-empty
            (``max()`` raises ``ValueError`` on an empty batch).

    Returns:
        Batched dictionary with padded, stacked tensors. Note that the
        ``"metadata"`` entry is a plain Python list of the per-sample
        metadata dicts, not a tensor.
    """
    max_latent_len = max(s["target_latents"].shape[0] for s in batch)
    max_encoder_len = max(s["encoder_hidden_states"].shape[0] for s in batch)

    return {
        "target_latents": torch.stack(
            [_pad_to(s["target_latents"], max_latent_len) for s in batch]
        ),
        "attention_mask": torch.stack(
            [_pad_to(s["attention_mask"], max_latent_len) for s in batch]
        ),
        "encoder_hidden_states": torch.stack(
            [_pad_to(s["encoder_hidden_states"], max_encoder_len) for s in batch]
        ),
        "encoder_attention_mask": torch.stack(
            [_pad_to(s["encoder_attention_mask"], max_encoder_len) for s in batch]
        ),
        "context_latents": torch.stack(
            [_pad_to(s["context_latents"], max_latent_len) for s in batch]
        ),
        "metadata": [s["metadata"] for s in batch],
    }
Loading