-
Notifications
You must be signed in to change notification settings - Fork 889
Changes for training #708
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes for training #708
Changes from all commits
74b53e1
a3ff595
b8d9def
23769e3
bddafc0
2b9fd32
add3f66
634ee5f
bd21f9a
2f870f9
6d10eab
12ea99d
a8146d4
a418ec6
dcba69e
14a997a
d0ce851
5af7826
461c79f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |
| import os | ||
| import json | ||
| import random | ||
| from collections import OrderedDict | ||
| from typing import Optional, List, Dict, Any, Tuple | ||
| from loguru import logger | ||
|
|
||
|
|
@@ -32,6 +33,41 @@ class LightningDataModule: | |
| # Preprocessed Tensor Dataset (Recommended for Training) | ||
| # ============================================================================ | ||
|
|
||
| class BucketedBatchSampler: | ||
| """Batch sampler that groups indices by latent length buckets.""" | ||
|
|
||
| def __init__(self, lengths: List[int], batch_size: int, shuffle: bool = True) -> None: | ||
| self.lengths = lengths | ||
| self.batch_size = max(1, int(batch_size)) | ||
| self.shuffle = shuffle | ||
|
|
||
| def __iter__(self): | ||
| buckets: Dict[int, List[int]] = {} | ||
| for idx, length in enumerate(self.lengths): | ||
| bucket = int(length // 64) | ||
| buckets.setdefault(bucket, []).append(idx) | ||
| bucket_keys = list(buckets.keys()) | ||
| if self.shuffle: | ||
| random.shuffle(bucket_keys) | ||
| for key in bucket_keys: | ||
| group = buckets[key] | ||
| if self.shuffle: | ||
| random.shuffle(group) | ||
| for start in range(0, len(group), self.batch_size): | ||
| yield group[start:start + self.batch_size] | ||
|
|
||
| def __len__(self) -> int: | ||
| bucket_counts: Dict[int, int] = {} | ||
| for length in self.lengths: | ||
| bucket = int(length // 64) | ||
| bucket_counts[bucket] = bucket_counts.get(bucket, 0) + 1 | ||
|
|
||
| return sum( | ||
| (count + self.batch_size - 1) // self.batch_size | ||
| for count in bucket_counts.values() | ||
| ) | ||
|
|
||
|
|
||
| class PreprocessedTensorDataset(Dataset): | ||
| """Dataset that loads preprocessed tensor files. | ||
|
|
||
|
|
@@ -45,7 +81,7 @@ class PreprocessedTensorDataset(Dataset): | |
| No VAE/text encoder needed during training - just load tensors directly! | ||
| """ | ||
|
|
||
| def __init__(self, tensor_dir: str): | ||
| def __init__(self, tensor_dir: str, cache_policy: str = "none", cache_max_items: int = 0): | ||
| """Initialize from a directory of preprocessed .pt files. | ||
|
|
||
| Args: | ||
|
|
@@ -59,6 +95,9 @@ def __init__(self, tensor_dir: str): | |
| raise ValueError(f"Not an existing directory: {tensor_dir}") | ||
| self.tensor_dir = validated_dir | ||
| self.sample_paths: List[str] = [] | ||
| self.cache_policy = cache_policy | ||
| self.cache_max_items = max(0, int(cache_max_items)) | ||
| self._cache: "OrderedDict[int, Dict[str, Any]]" = OrderedDict() | ||
|
|
||
| # Load manifest if exists | ||
| manifest_path = safe_path("manifest.json", base=self.tensor_dir) | ||
|
|
@@ -87,6 +126,14 @@ def __init__(self, tensor_dir: str): | |
| f"{len(self.sample_paths) - len(self.valid_paths)} missing" | ||
| ) | ||
|
|
||
| self.latent_lengths: List[int] = [] | ||
| for vp in self.valid_paths: | ||
| try: | ||
| sample = torch.load(vp, map_location="cpu", weights_only=True) | ||
| self.latent_lengths.append(int(sample["target_latents"].shape[0])) | ||
| except Exception: | ||
| self.latent_lengths.append(0) | ||
|
|
||
| logger.info( | ||
| f"PreprocessedTensorDataset: {len(self.valid_paths)} samples " | ||
| f"from {self.tensor_dir}" | ||
|
|
@@ -136,15 +183,23 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: | |
| Returns: | ||
| Dictionary containing all pre-computed tensors for training | ||
| """ | ||
| tensor_path = self.valid_paths[idx] | ||
| data = torch.load(tensor_path, map_location='cpu', weights_only=True) | ||
|
|
||
| if self.cache_policy == "ram_lru" and idx in self._cache: | ||
| data = self._cache.pop(idx) | ||
| self._cache[idx] = data | ||
| else: | ||
| tensor_path = self.valid_paths[idx] | ||
| data = torch.load(tensor_path, map_location='cpu', weights_only=True) | ||
| if self.cache_policy == "ram_lru" and self.cache_max_items > 0: | ||
| self._cache[idx] = data | ||
| while len(self._cache) > self.cache_max_items: | ||
| self._cache.popitem(last=False) | ||
|
|
||
| return { | ||
| "target_latents": data["target_latents"], # [T, 64] | ||
| "attention_mask": data["attention_mask"], # [T] | ||
| "encoder_hidden_states": data["encoder_hidden_states"], # [L, D] | ||
| "encoder_attention_mask": data["encoder_attention_mask"], # [L] | ||
| "context_latents": data["context_latents"], # [T, 65] | ||
| "target_latents": data["target_latents"], | ||
| "attention_mask": data["attention_mask"], | ||
| "encoder_hidden_states": data["encoder_hidden_states"], | ||
| "encoder_attention_mask": data["encoder_attention_mask"], | ||
| "context_latents": data["context_latents"], | ||
| "metadata": data.get("metadata", {}), | ||
| } | ||
|
|
||
|
|
@@ -219,11 +274,10 @@ def collate_preprocessed_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]: | |
|
|
||
| class PreprocessedDataModule(LightningDataModule if LIGHTNING_AVAILABLE else object): | ||
| """DataModule for preprocessed tensor files. | ||
|
|
||
| This is the recommended DataModule for training. It loads pre-computed tensors | ||
| directly without needing VAE, text encoder, or condition encoder at training time. | ||
|
|
||
| Loads precomputed tensors directly, avoiding VAE/text encoding at train time. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| tensor_dir: str, | ||
|
|
@@ -234,19 +288,28 @@ def __init__( | |
| persistent_workers: bool = True, | ||
| pin_memory_device: str = "", | ||
| val_split: float = 0.0, | ||
| length_bucket: bool = False, | ||
| cache_policy: str = "none", | ||
| cache_max_items: int = 0, | ||
| ): | ||
| """Initialize the data module. | ||
| """Initialize the preprocessed data module. | ||
|
|
||
| Args: | ||
| tensor_dir: Directory containing preprocessed .pt files | ||
| batch_size: Training batch size | ||
| num_workers: Number of data loading workers | ||
| pin_memory: Whether to pin memory for faster GPU transfer | ||
| val_split: Fraction of data for validation (0 = no validation) | ||
| tensor_dir: Directory containing preprocessed ``.pt`` files. | ||
| batch_size: Training batch size. | ||
| num_workers: Number of DataLoader worker processes. | ||
| pin_memory: Pin host memory for faster GPU transfer. | ||
| prefetch_factor: Number of prefetched batches per worker. | ||
| persistent_workers: Keep worker processes alive between epochs. | ||
| pin_memory_device: Device string used by pinned memory allocator. | ||
| val_split: Fraction of data reserved for validation. | ||
| length_bucket: Whether to bucket training samples by latent length. | ||
| cache_policy: Dataset cache mode ("none" or "ram_lru"). | ||
| cache_max_items: Maximum cached entries when RAM LRU is enabled. | ||
| """ | ||
| if LIGHTNING_AVAILABLE: | ||
| super().__init__() | ||
|
|
||
| self.tensor_dir = tensor_dir | ||
| self.batch_size = batch_size | ||
| self.num_workers = num_workers | ||
|
|
@@ -255,36 +318,61 @@ def __init__( | |
| self.persistent_workers = persistent_workers | ||
| self.pin_memory_device = pin_memory_device | ||
| self.val_split = val_split | ||
| self.length_bucket = length_bucket | ||
| self.cache_policy = cache_policy | ||
| self.cache_max_items = cache_max_items | ||
|
|
||
| self.train_dataset = None | ||
| self.val_dataset = None | ||
|
|
||
| def setup(self, stage: Optional[str] = None): | ||
| """Setup datasets.""" | ||
| if stage == 'fit' or stage is None: | ||
| # Create full dataset | ||
| full_dataset = PreprocessedTensorDataset(self.tensor_dir) | ||
| full_dataset = PreprocessedTensorDataset( | ||
| self.tensor_dir, | ||
| cache_policy=self.cache_policy, | ||
| cache_max_items=self.cache_max_items, | ||
| ) | ||
|
|
||
| # Split if validation requested | ||
| if self.val_split > 0 and len(full_dataset) > 1: | ||
| n_val = max(1, int(len(full_dataset) * self.val_split)) | ||
| n_train = len(full_dataset) - n_val | ||
|
|
||
| self.train_dataset, self.val_dataset = torch.utils.data.random_split( | ||
| full_dataset, [n_train, n_val] | ||
| ) | ||
| else: | ||
| self.train_dataset = full_dataset | ||
| self.val_dataset = None | ||
|
|
||
|
|
||
| def _resolve_train_latent_lengths(self) -> Optional[List[int]]: | ||
| """Resolve latent lengths for bucketed sampling, including Subset splits.""" | ||
| if not self.length_bucket or self.train_dataset is None: | ||
| return None | ||
|
|
||
| ds = self.train_dataset | ||
| if isinstance(ds, torch.utils.data.Subset): | ||
| base = ds.dataset | ||
| indices = list(ds.indices) | ||
| base_lengths = getattr(base, "latent_lengths", None) | ||
| if base_lengths is None: | ||
| return None | ||
| return [base_lengths[i] for i in indices] | ||
|
|
||
| base_lengths = getattr(ds, "latent_lengths", None) | ||
| if base_lengths is None: | ||
| return None | ||
| return list(base_lengths) | ||
|
|
||
| def train_dataloader(self) -> DataLoader: | ||
| """Create training dataloader.""" | ||
| prefetch_factor = None if self.num_workers == 0 else self.prefetch_factor | ||
| persistent_workers = False if self.num_workers == 0 else self.persistent_workers | ||
| kwargs = dict( | ||
| dataset=self.train_dataset, | ||
| batch_size=self.batch_size, | ||
| shuffle=True, | ||
| shuffle=not self.length_bucket, | ||
| num_workers=self.num_workers, | ||
| pin_memory=self.pin_memory, | ||
| collate_fn=collate_preprocessed_batch, | ||
|
|
@@ -294,8 +382,16 @@ def train_dataloader(self) -> DataLoader: | |
| ) | ||
| if self.pin_memory_device: | ||
| kwargs["pin_memory_device"] = self.pin_memory_device | ||
| if self.length_bucket and hasattr(self.train_dataset, "latent_lengths"): | ||
| kwargs.pop("batch_size", None) | ||
| kwargs.pop("shuffle", None) | ||
| kwargs["batch_sampler"] = BucketedBatchSampler( | ||
| lengths=list(getattr(self.train_dataset, "latent_lengths", [])), | ||
| batch_size=self.batch_size, | ||
| shuffle=True, | ||
| ) | ||
|
Comment on lines +385 to +392
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bucketing silently fails when validation split is enabled: `random_split` wraps the dataset in a `Subset`, which has no `latent_lengths` attribute, so the `hasattr` check fails and the bucketed sampler is skipped. 🐛 Proposed fix to access underlying dataset's latent_lengths: - if self.length_bucket and hasattr(self.train_dataset, "latent_lengths"):
+ latent_lengths = None
+ if self.length_bucket:
+ ds = self.train_dataset
+ # Handle Subset from random_split
+ if hasattr(ds, "dataset") and hasattr(ds, "indices"):
+ underlying = ds.dataset
+ if hasattr(underlying, "latent_lengths"):
+ latent_lengths = [underlying.latent_lengths[i] for i in ds.indices]
+ elif hasattr(ds, "latent_lengths"):
+ latent_lengths = ds.latent_lengths
+ if latent_lengths is not None:
kwargs.pop("batch_size", None)
kwargs.pop("shuffle", None)
kwargs["batch_sampler"] = BucketedBatchSampler(
- lengths=list(getattr(self.train_dataset, "latent_lengths", [])),
+ lengths=latent_lengths,
batch_size=self.batch_size,
shuffle=True,
)🤖 Prompt for AI Agents |
||
| return DataLoader(**kwargs) | ||
|
|
||
| def val_dataloader(self) -> Optional[DataLoader]: | ||
| """Create validation dataloader.""" | ||
| if self.val_dataset is None: | ||
|
|
@@ -474,17 +570,37 @@ def __init__( | |
| pin_memory: bool = True, | ||
| max_duration: float = 240.0, | ||
| val_split: float = 0.0, | ||
| length_bucket: bool = False, | ||
| cache_policy: str = "none", | ||
| cache_max_items: int = 0, | ||
| ): | ||
| """Initialize legacy raw-audio datamodule. | ||
|
|
||
| Args: | ||
| samples: Raw training sample metadata entries. | ||
| dit_handler: Model handler used by legacy training flows. | ||
| batch_size: Number of samples per batch. | ||
| num_workers: Number of dataloader workers. | ||
| pin_memory: Whether to enable pinned memory in dataloaders. | ||
| max_duration: Max audio duration (seconds) for clipping. | ||
| val_split: Validation split fraction. | ||
| length_bucket: Accepted for compatibility; unused for raw audio mode. | ||
| cache_policy: Accepted for compatibility; unused for raw audio mode. | ||
| cache_max_items: Accepted for compatibility; unused for raw audio mode. | ||
| """ | ||
| if LIGHTNING_AVAILABLE: | ||
| super().__init__() | ||
|
|
||
| self.samples = samples | ||
| self.dit_handler = dit_handler | ||
| self.batch_size = batch_size | ||
| self.num_workers = num_workers | ||
| self.pin_memory = pin_memory | ||
| self.max_duration = max_duration | ||
| self.val_split = val_split | ||
| self.length_bucket = length_bucket | ||
| self.cache_policy = cache_policy | ||
| self.cache_max_items = cache_max_items | ||
|
|
||
| self.train_dataset = None | ||
| self.val_dataset = None | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,11 +7,16 @@ | |
|
|
||
| import os | ||
| import json | ||
| import random | ||
| import tempfile | ||
| import unittest | ||
| from unittest import mock | ||
|
|
||
| import torch | ||
|
|
||
| from acestep.training.path_safety import safe_path, set_safe_root | ||
| from acestep.training.data_module import ( | ||
| BucketedBatchSampler, | ||
| PreprocessedTensorDataset, | ||
| load_dataset_from_json, | ||
| ) | ||
|
|
@@ -193,5 +198,19 @@ def test_valid_json(self): | |
| os.unlink(path) | ||
|
|
||
|
|
||
| class AceStepDataModuleInitTests(unittest.TestCase): | ||
| """Regression tests for legacy ``AceStepDataModule`` initialization.""" | ||
|
|
||
| def test_init_does_not_require_preprocessed_only_cache_args(self): | ||
| """Legacy raw-audio datamodule should initialize without NameError.""" | ||
| from acestep.training.data_module import AceStepDataModule | ||
|
|
||
| module = AceStepDataModule(samples=[], dit_handler=object()) | ||
|
|
||
| self.assertEqual(module.samples, []) | ||
| self.assertIsNotNone(module.dit_handler) | ||
|
|
||
|
Comment on lines +201 to +212
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Coverage is missing for the new bucketing/cache behaviors introduced by this PR. This addition only validates legacy initialization; it does not test the new `BucketedBatchSampler`, `length_bucket`, or `cache_policy`/`cache_max_items` behaviors. As per coding guidelines: 🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| """Collate helpers for preprocessed ACE-Step training tensors.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Dict, List | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| def collate_preprocessed_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]: | ||
|
Comment on lines +5 to +10
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain. 🏁 Script executed: #!/bin/bash
# First, locate the file and read it
fd -t f "preprocessed_collate.py" | head -5Repository: ace-step/ACE-Step-1.5 Length of output: 105 🏁 Script executed: #!/bin/bash
# Read the file content, focusing on the function signature and return statement
if [ -f "acestep/training/preprocessed_collate.py" ]; then
cat -n "acestep/training/preprocessed_collate.py" | head -70
else
find . -name "preprocessed_collate.py" -exec cat -n {} \; | head -70
fiRepository: ace-step/ACE-Step-1.5 Length of output: 2713 🏁 Script executed: #!/bin/bash
# Search for usages of collate_preprocessed_batch function
rg -n "collate_preprocessed_batch" --type=pyRepository: ace-step/ACE-Step-1.5 Length of output: 604 🏁 Script executed: #!/bin/bash
# Check if metadata field is accessed anywhere to understand usage patterns
rg -n "\[\"metadata\"\]|\['metadata'\]" --type=py -B2 -A2Repository: ace-step/ACE-Step-1.5 Length of output: 3934 🏁 Script executed: #!/bin/bash
# Read the function in data_module.py to see if it has the same issue
sed -n '207,275p' acestep/training/data_module.pyRepository: ace-step/ACE-Step-1.5 Length of output: 2832 Fix the return type annotation to match the actual payload in both Both Replace the imprecise Proposed typing fix-from typing import Dict, List
+from typing import Any, Dict, List, TypedDict
+
+
+class CollatedPreprocessedBatch(TypedDict):
+ target_latents: torch.Tensor
+ attention_mask: torch.Tensor
+ encoder_hidden_states: torch.Tensor
+ encoder_attention_mask: torch.Tensor
+ context_latents: torch.Tensor
+ metadata: List[Dict[str, Any]]
-def collate_preprocessed_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]:
+def collate_preprocessed_batch(batch: List[Dict[str, Any]]) -> CollatedPreprocessedBatch:🤖 Prompt for AI Agents |
||
| """Pad and stack variable-length preprocessed training tensors. | ||
|
|
||
| Args: | ||
| batch: Per-sample tensor dictionaries. | ||
|
|
||
| Returns: | ||
| Batched tensor dictionary with padding to max lengths in batch. | ||
| """ | ||
| max_latent_len = max(s["target_latents"].shape[0] for s in batch) | ||
| max_encoder_len = max(s["encoder_hidden_states"].shape[0] for s in batch) | ||
|
|
||
| target_latents = [] | ||
| attention_masks = [] | ||
| encoder_hidden_states = [] | ||
| encoder_attention_masks = [] | ||
| context_latents = [] | ||
|
|
||
| for sample in batch: | ||
| tl = sample["target_latents"] | ||
| if tl.shape[0] < max_latent_len: | ||
| tl = torch.cat([tl, tl.new_zeros(max_latent_len - tl.shape[0], tl.shape[1])], dim=0) | ||
| target_latents.append(tl) | ||
|
|
||
| am = sample["attention_mask"] | ||
| if am.shape[0] < max_latent_len: | ||
| am = torch.cat([am, am.new_zeros(max_latent_len - am.shape[0])], dim=0) | ||
| attention_masks.append(am) | ||
|
|
||
| cl = sample["context_latents"] | ||
| if cl.shape[0] < max_latent_len: | ||
| cl = torch.cat([cl, cl.new_zeros(max_latent_len - cl.shape[0], cl.shape[1])], dim=0) | ||
| context_latents.append(cl) | ||
|
|
||
| ehs = sample["encoder_hidden_states"] | ||
| if ehs.shape[0] < max_encoder_len: | ||
| ehs = torch.cat([ehs, ehs.new_zeros(max_encoder_len - ehs.shape[0], ehs.shape[1])], dim=0) | ||
| encoder_hidden_states.append(ehs) | ||
|
|
||
| eam = sample["encoder_attention_mask"] | ||
| if eam.shape[0] < max_encoder_len: | ||
| eam = torch.cat([eam, eam.new_zeros(max_encoder_len - eam.shape[0])], dim=0) | ||
| encoder_attention_masks.append(eam) | ||
|
|
||
| return { | ||
| "target_latents": torch.stack(target_latents), | ||
| "attention_mask": torch.stack(attention_masks), | ||
| "encoder_hidden_states": torch.stack(encoder_hidden_states), | ||
| "encoder_attention_mask": torch.stack(encoder_attention_masks), | ||
| "context_latents": torch.stack(context_latents), | ||
| "metadata": [s["metadata"] for s in batch], | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid bare `except Exception`; handle specific errors and log failures.
🛡️ Proposed fix with specific exception handling
self.latent_lengths: List[int] = [] for vp in self.valid_paths: try: sample = torch.load(vp, map_location="cpu", weights_only=True) self.latent_lengths.append(int(sample["target_latents"].shape[0])) - except Exception: + except (OSError, KeyError, RuntimeError) as e: + logger.warning(f"Failed to read latent length from {vp}: {e}") self.latent_lengths.append(0)🧰 Tools
🪛 Ruff (0.15.2)
[warning] 127-127: Do not catch blind exception: `Exception` (BLE001)
🤖 Prompt for AI Agents