diff --git a/acestep/training/data_module.py b/acestep/training/data_module.py index 68e0b23d..fa5096af 100644 --- a/acestep/training/data_module.py +++ b/acestep/training/data_module.py @@ -8,6 +8,7 @@ import os import json import random +from collections import OrderedDict from typing import Optional, List, Dict, Any, Tuple from loguru import logger @@ -32,6 +33,41 @@ class LightningDataModule: # Preprocessed Tensor Dataset (Recommended for Training) # ============================================================================ +class BucketedBatchSampler: + """Batch sampler that groups indices by latent length buckets.""" + + def __init__(self, lengths: List[int], batch_size: int, shuffle: bool = True) -> None: + self.lengths = lengths + self.batch_size = max(1, int(batch_size)) + self.shuffle = shuffle + + def __iter__(self): + buckets: Dict[int, List[int]] = {} + for idx, length in enumerate(self.lengths): + bucket = int(length // 64) + buckets.setdefault(bucket, []).append(idx) + bucket_keys = list(buckets.keys()) + if self.shuffle: + random.shuffle(bucket_keys) + for key in bucket_keys: + group = buckets[key] + if self.shuffle: + random.shuffle(group) + for start in range(0, len(group), self.batch_size): + yield group[start:start + self.batch_size] + + def __len__(self) -> int: + bucket_counts: Dict[int, int] = {} + for length in self.lengths: + bucket = int(length // 64) + bucket_counts[bucket] = bucket_counts.get(bucket, 0) + 1 + + return sum( + (count + self.batch_size - 1) // self.batch_size + for count in bucket_counts.values() + ) + + class PreprocessedTensorDataset(Dataset): """Dataset that loads preprocessed tensor files. @@ -45,7 +81,7 @@ class PreprocessedTensorDataset(Dataset): No VAE/text encoder needed during training - just load tensors directly! """ - def __init__(self, tensor_dir: str): + def __init__(self, tensor_dir: str, cache_policy: str = "none", cache_max_items: int = 0): """Initialize from a directory of preprocessed .pt files. 
Args: @@ -59,6 +95,9 @@ def __init__(self, tensor_dir: str): raise ValueError(f"Not an existing directory: {tensor_dir}") self.tensor_dir = validated_dir self.sample_paths: List[str] = [] + self.cache_policy = cache_policy + self.cache_max_items = max(0, int(cache_max_items)) + self._cache: "OrderedDict[int, Dict[str, Any]]" = OrderedDict() # Load manifest if exists manifest_path = safe_path("manifest.json", base=self.tensor_dir) @@ -87,6 +126,14 @@ def __init__(self, tensor_dir: str): f"{len(self.sample_paths) - len(self.valid_paths)} missing" ) + self.latent_lengths: List[int] = [] + for vp in self.valid_paths: + try: + sample = torch.load(vp, map_location="cpu", weights_only=True) + self.latent_lengths.append(int(sample["target_latents"].shape[0])) + except Exception: + self.latent_lengths.append(0) + logger.info( f"PreprocessedTensorDataset: {len(self.valid_paths)} samples " f"from {self.tensor_dir}" @@ -136,15 +183,23 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: Returns: Dictionary containing all pre-computed tensors for training """ - tensor_path = self.valid_paths[idx] - data = torch.load(tensor_path, map_location='cpu', weights_only=True) - + if self.cache_policy == "ram_lru" and idx in self._cache: + data = self._cache.pop(idx) + self._cache[idx] = data + else: + tensor_path = self.valid_paths[idx] + data = torch.load(tensor_path, map_location='cpu', weights_only=True) + if self.cache_policy == "ram_lru" and self.cache_max_items > 0: + self._cache[idx] = data + while len(self._cache) > self.cache_max_items: + self._cache.popitem(last=False) + return { - "target_latents": data["target_latents"], # [T, 64] - "attention_mask": data["attention_mask"], # [T] - "encoder_hidden_states": data["encoder_hidden_states"], # [L, D] - "encoder_attention_mask": data["encoder_attention_mask"], # [L] - "context_latents": data["context_latents"], # [T, 65] + "target_latents": data["target_latents"], + "attention_mask": data["attention_mask"], + 
"encoder_hidden_states": data["encoder_hidden_states"], + "encoder_attention_mask": data["encoder_attention_mask"], + "context_latents": data["context_latents"], "metadata": data.get("metadata", {}), } @@ -219,11 +274,10 @@ def collate_preprocessed_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]: class PreprocessedDataModule(LightningDataModule if LIGHTNING_AVAILABLE else object): """DataModule for preprocessed tensor files. - - This is the recommended DataModule for training. It loads pre-computed tensors - directly without needing VAE, text encoder, or condition encoder at training time. + + Loads precomputed tensors directly, avoiding VAE/text encoding at train time. """ - + def __init__( self, tensor_dir: str, @@ -234,19 +288,28 @@ def __init__( persistent_workers: bool = True, pin_memory_device: str = "", val_split: float = 0.0, + length_bucket: bool = False, + cache_policy: str = "none", + cache_max_items: int = 0, ): - """Initialize the data module. - + """Initialize the preprocessed data module. + Args: - tensor_dir: Directory containing preprocessed .pt files - batch_size: Training batch size - num_workers: Number of data loading workers - pin_memory: Whether to pin memory for faster GPU transfer - val_split: Fraction of data for validation (0 = no validation) + tensor_dir: Directory containing preprocessed ``.pt`` files. + batch_size: Training batch size. + num_workers: Number of DataLoader worker processes. + pin_memory: Pin host memory for faster GPU transfer. + prefetch_factor: Number of prefetched batches per worker. + persistent_workers: Keep worker processes alive between epochs. + pin_memory_device: Device string used by pinned memory allocator. + val_split: Fraction of data reserved for validation. + length_bucket: Whether to bucket training samples by latent length. + cache_policy: Dataset cache mode ("none" or "ram_lru"). + cache_max_items: Maximum cached entries when RAM LRU is enabled. 
""" if LIGHTNING_AVAILABLE: super().__init__() - + self.tensor_dir = tensor_dir self.batch_size = batch_size self.num_workers = num_workers @@ -255,28 +318,53 @@ def __init__( self.persistent_workers = persistent_workers self.pin_memory_device = pin_memory_device self.val_split = val_split + self.length_bucket = length_bucket + self.cache_policy = cache_policy + self.cache_max_items = cache_max_items self.train_dataset = None self.val_dataset = None - + def setup(self, stage: Optional[str] = None): """Setup datasets.""" if stage == 'fit' or stage is None: # Create full dataset - full_dataset = PreprocessedTensorDataset(self.tensor_dir) + full_dataset = PreprocessedTensorDataset( + self.tensor_dir, + cache_policy=self.cache_policy, + cache_max_items=self.cache_max_items, + ) # Split if validation requested if self.val_split > 0 and len(full_dataset) > 1: n_val = max(1, int(len(full_dataset) * self.val_split)) n_train = len(full_dataset) - n_val - self.train_dataset, self.val_dataset = torch.utils.data.random_split( full_dataset, [n_train, n_val] ) else: self.train_dataset = full_dataset self.val_dataset = None - + + def _resolve_train_latent_lengths(self) -> Optional[List[int]]: + """Resolve latent lengths for bucketed sampling, including Subset splits.""" + if not self.length_bucket or self.train_dataset is None: + return None + + ds = self.train_dataset + if isinstance(ds, torch.utils.data.Subset): + base = ds.dataset + indices = list(ds.indices) + base_lengths = getattr(base, "latent_lengths", None) + if base_lengths is None: + return None + return [base_lengths[i] for i in indices] + + base_lengths = getattr(ds, "latent_lengths", None) + if base_lengths is None: + return None + return list(base_lengths) + def train_dataloader(self) -> DataLoader: """Create training dataloader.""" prefetch_factor = None if self.num_workers == 0 else self.prefetch_factor @@ -284,7 +372,7 @@ def train_dataloader(self) -> DataLoader: kwargs = dict( dataset=self.train_dataset, 
batch_size=self.batch_size, - shuffle=True, + shuffle=not self.length_bucket, num_workers=self.num_workers, pin_memory=self.pin_memory, collate_fn=collate_preprocessed_batch, @@ -294,8 +382,16 @@ def train_dataloader(self) -> DataLoader: ) if self.pin_memory_device: kwargs["pin_memory_device"] = self.pin_memory_device + if self.length_bucket and hasattr(self.train_dataset, "latent_lengths"): + kwargs.pop("batch_size", None) + kwargs.pop("shuffle", None) + kwargs["batch_sampler"] = BucketedBatchSampler( + lengths=list(getattr(self.train_dataset, "latent_lengths", [])), + batch_size=self.batch_size, + shuffle=True, + ) return DataLoader(**kwargs) - + def val_dataloader(self) -> Optional[DataLoader]: """Create validation dataloader.""" if self.val_dataset is None: @@ -474,10 +570,27 @@ def __init__( pin_memory: bool = True, max_duration: float = 240.0, val_split: float = 0.0, + length_bucket: bool = False, + cache_policy: str = "none", + cache_max_items: int = 0, ): + """Initialize legacy raw-audio datamodule. + + Args: + samples: Raw training sample metadata entries. + dit_handler: Model handler used by legacy training flows. + batch_size: Number of samples per batch. + num_workers: Number of dataloader workers. + pin_memory: Whether to enable pinned memory in dataloaders. + max_duration: Max audio duration (seconds) for clipping. + val_split: Validation split fraction. + length_bucket: Accepted for compatibility; unused for raw audio mode. + cache_policy: Accepted for compatibility; unused for raw audio mode. + cache_max_items: Accepted for compatibility; unused for raw audio mode. 
+ """ if LIGHTNING_AVAILABLE: super().__init__() - + self.samples = samples self.dit_handler = dit_handler self.batch_size = batch_size @@ -485,6 +598,9 @@ def __init__( self.pin_memory = pin_memory self.max_duration = max_duration self.val_split = val_split + self.length_bucket = length_bucket + self.cache_policy = cache_policy + self.cache_max_items = cache_max_items self.train_dataset = None self.val_dataset = None diff --git a/acestep/training/data_module_test.py b/acestep/training/data_module_test.py index a6e43769..97bea97c 100644 --- a/acestep/training/data_module_test.py +++ b/acestep/training/data_module_test.py @@ -7,11 +7,16 @@ import os import json +import random import tempfile import unittest +from unittest import mock + +import torch from acestep.training.path_safety import safe_path, set_safe_root from acestep.training.data_module import ( + BucketedBatchSampler, PreprocessedTensorDataset, load_dataset_from_json, ) @@ -193,5 +198,19 @@ def test_valid_json(self): os.unlink(path) +class AceStepDataModuleInitTests(unittest.TestCase): + """Regression tests for legacy ``AceStepDataModule`` initialization.""" + + def test_init_does_not_require_preprocessed_only_cache_args(self): + """Legacy raw-audio datamodule should initialize without NameError.""" + from acestep.training.data_module import AceStepDataModule + + module = AceStepDataModule(samples=[], dit_handler=object()) + + self.assertEqual(module.samples, []) + self.assertIsNotNone(module.dit_handler) + + + if __name__ == "__main__": unittest.main() diff --git a/acestep/training/preprocessed_collate.py b/acestep/training/preprocessed_collate.py new file mode 100644 index 00000000..533e9929 --- /dev/null +++ b/acestep/training/preprocessed_collate.py @@ -0,0 +1,61 @@ +"""Collate helpers for preprocessed ACE-Step training tensors.""" + +from __future__ import annotations + +from typing import Dict, List + +import torch + + +def collate_preprocessed_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]: + 
"""Pad and stack variable-length preprocessed training tensors. + + Args: + batch: Per-sample tensor dictionaries. + + Returns: + Batched tensor dictionary with padding to max lengths in batch. + """ + max_latent_len = max(s["target_latents"].shape[0] for s in batch) + max_encoder_len = max(s["encoder_hidden_states"].shape[0] for s in batch) + + target_latents = [] + attention_masks = [] + encoder_hidden_states = [] + encoder_attention_masks = [] + context_latents = [] + + for sample in batch: + tl = sample["target_latents"] + if tl.shape[0] < max_latent_len: + tl = torch.cat([tl, tl.new_zeros(max_latent_len - tl.shape[0], tl.shape[1])], dim=0) + target_latents.append(tl) + + am = sample["attention_mask"] + if am.shape[0] < max_latent_len: + am = torch.cat([am, am.new_zeros(max_latent_len - am.shape[0])], dim=0) + attention_masks.append(am) + + cl = sample["context_latents"] + if cl.shape[0] < max_latent_len: + cl = torch.cat([cl, cl.new_zeros(max_latent_len - cl.shape[0], cl.shape[1])], dim=0) + context_latents.append(cl) + + ehs = sample["encoder_hidden_states"] + if ehs.shape[0] < max_encoder_len: + ehs = torch.cat([ehs, ehs.new_zeros(max_encoder_len - ehs.shape[0], ehs.shape[1])], dim=0) + encoder_hidden_states.append(ehs) + + eam = sample["encoder_attention_mask"] + if eam.shape[0] < max_encoder_len: + eam = torch.cat([eam, eam.new_zeros(max_encoder_len - eam.shape[0])], dim=0) + encoder_attention_masks.append(eam) + + return { + "target_latents": torch.stack(target_latents), + "attention_mask": torch.stack(attention_masks), + "encoder_hidden_states": torch.stack(encoder_hidden_states), + "encoder_attention_mask": torch.stack(encoder_attention_masks), + "context_latents": torch.stack(context_latents), + "metadata": [s["metadata"] for s in batch], + } diff --git a/acestep/training/preprocessed_dataset.py b/acestep/training/preprocessed_dataset.py new file mode 100644 index 00000000..7ad9381f --- /dev/null +++ b/acestep/training/preprocessed_dataset.py @@ -0,0 
+1,115 @@ +"""Preprocessed tensor dataset for ACE-Step LoRA/full fine-tuning.""" + +from __future__ import annotations + +import json +import os +from collections import OrderedDict +from typing import Any, Dict, List, Optional + +import torch +from loguru import logger +from torch.utils.data import Dataset + +from acestep.training.path_safety import safe_path + + +class PreprocessedTensorDataset(Dataset): + """Dataset that loads preprocessed tensor files for training.""" + + def __init__(self, tensor_dir: str, cache_policy: str = "none", cache_max_items: int = 0): + """Initialize dataset from preprocessed tensor directory. + + Args: + tensor_dir: Directory containing preprocessed ``.pt`` samples. + cache_policy: Cache mode ("none" or "ram_lru"). Defaults to "none". + cache_max_items: Max samples kept in RAM when ``cache_policy='ram_lru'``. + ``0`` disables RAM caching. + + Raises: + ValueError: If *tensor_dir* is invalid or escapes safe root. + """ + validated_dir = safe_path(tensor_dir) + if not os.path.isdir(validated_dir): + raise ValueError(f"Not an existing directory: {tensor_dir}") + + self.tensor_dir = validated_dir + self.sample_paths: List[str] = [] + self.cache_policy = cache_policy + self.cache_max_items = max(0, int(cache_max_items)) + self._cache: "OrderedDict[int, Dict[str, Any]]" = OrderedDict() + + manifest_path = safe_path("manifest.json", base=self.tensor_dir) + if os.path.exists(manifest_path): + with open(manifest_path, "r", encoding="utf-8") as handle: + manifest = json.load(handle) + for raw in manifest.get("samples", []): + resolved = self._resolve_manifest_path(raw) + if resolved is not None: + self.sample_paths.append(resolved) + else: + for filename in os.listdir(self.tensor_dir): + if filename.endswith(".pt") and filename != "manifest.json": + self.sample_paths.append(safe_path(filename, base=self.tensor_dir)) + + self.valid_paths = [p for p in self.sample_paths if os.path.exists(p)] + if len(self.valid_paths) != len(self.sample_paths): 
+ logger.warning( + "Some tensor files not found: %d missing", + len(self.sample_paths) - len(self.valid_paths), + ) + + self.latent_lengths: List[int] = [] + for vp in self.valid_paths: + try: + sample = torch.load(vp, map_location="cpu", weights_only=True) + self.latent_lengths.append(int(sample["target_latents"].shape[0])) + except (FileNotFoundError, PermissionError, EOFError, OSError, KeyError, RuntimeError) as exc: + logger.warning("Failed to read latent length from %s: %s", vp, exc) + self.latent_lengths.append(0) + + logger.info("PreprocessedTensorDataset: %d samples from %s", len(self.valid_paths), self.tensor_dir) + + def _resolve_manifest_path(self, raw: str) -> Optional[str]: + """Resolve and validate manifest sample path.""" + try: + candidate = safe_path(raw, base=self.tensor_dir) + if os.path.exists(candidate): + return candidate + except ValueError: + pass + + try: + candidate = safe_path(raw) + if os.path.exists(candidate): + logger.debug("Resolved legacy manifest path via safe root: %s", raw) + return candidate + except ValueError: + pass + + logger.warning("Skipping unresolvable manifest path: %s", raw) + return None + + def __len__(self) -> int: + return len(self.valid_paths) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + """Load one preprocessed sample.""" + if self.cache_policy == "ram_lru" and idx in self._cache: + data = self._cache.pop(idx) + self._cache[idx] = data + else: + data = torch.load(self.valid_paths[idx], map_location="cpu", weights_only=True) + if self.cache_policy == "ram_lru" and self.cache_max_items > 0: + self._cache[idx] = data + while len(self._cache) > self.cache_max_items: + self._cache.popitem(last=False) + + return { + "target_latents": data["target_latents"], + "attention_mask": data["attention_mask"], + "encoder_hidden_states": data["encoder_hidden_states"], + "encoder_attention_mask": data["encoder_attention_mask"], + "context_latents": data["context_latents"], + "metadata": data.get("metadata", 
{}), + } diff --git a/acestep/training/preprocessed_sampler.py b/acestep/training/preprocessed_sampler.py new file mode 100644 index 00000000..a99ef9e2 --- /dev/null +++ b/acestep/training/preprocessed_sampler.py @@ -0,0 +1,45 @@ +"""Sampling utilities for preprocessed tensor training datasets.""" + +from __future__ import annotations + +import random +from typing import Dict, List + + +class BucketedBatchSampler: + """Batch sampler that groups indices by latent-length buckets.""" + + def __init__(self, lengths: List[int], batch_size: int, shuffle: bool = True) -> None: + """Initialize bucket sampler. + + Args: + lengths: Per-sample latent lengths. + batch_size: Number of items per yielded batch. + shuffle: Whether to shuffle buckets and samples each epoch. + """ + self.lengths = lengths + self.batch_size = max(1, int(batch_size)) + self.shuffle = shuffle + + def __iter__(self): + """Yield batches of indices grouped by coarse latent-length buckets.""" + buckets: Dict[int, List[int]] = {} + for idx, length in enumerate(self.lengths): + bucket = int(length // 64) + buckets.setdefault(bucket, []).append(idx) + + bucket_keys = list(buckets.keys()) + if self.shuffle: + random.shuffle(bucket_keys) + + for key in bucket_keys: + group = buckets[key] + if self.shuffle: + random.shuffle(group) + for start in range(0, len(group), self.batch_size): + yield group[start:start + self.batch_size] + + def __len__(self) -> int: + """Return estimated number of batches.""" + total = len(self.lengths) + return (total + self.batch_size - 1) // self.batch_size diff --git a/acestep/training_v2/cli/args.py b/acestep/training_v2/cli/args.py index 49130d15..408dd817 100644 --- a/acestep/training_v2/cli/args.py +++ b/acestep/training_v2/cli/args.py @@ -108,6 +108,13 @@ def build_root_parser() -> argparse.ArgumentParser: help="Random seed (default: 42)", ) + p_validate = subparsers.add_parser( + "validate-dataset", + help="Validate preprocessed tensors and report dataset stats", + 
formatter_class=formatter_class, + ) + p_validate.add_argument("--dataset-dir", type=str, required=True, help="Directory containing preprocessed .pt files") + return root @@ -201,6 +208,9 @@ def _add_common_training_args(parser: argparse.ArgumentParser) -> None: default=_DEFAULT_NUM_WORKERS > 0, help="Keep workers alive between epochs (default: True; False on Windows)", ) + g_data.add_argument("--length-bucket", action=argparse.BooleanOptionalAction, default=False, help="Bucket samples by latent length") + g_data.add_argument("--cache-policy", type=str, default="none", choices=["none", "ram_lru"], help="Dataset cache policy") + g_data.add_argument("--cache-max-items", type=int, default=0, help="Max cached samples for ram_lru cache") # -- Training hyperparams ------------------------------------------------ g_train = parser.add_argument_group("Training") @@ -221,7 +231,12 @@ def _add_common_training_args(parser: argparse.ArgumentParser) -> None: # -- Adapter selection --------------------------------------------------- g_adapter = parser.add_argument_group("Adapter") - g_adapter.add_argument("--adapter-type", type=str, default="lora", choices=["lora", "lokr"], help="Adapter type: lora (PEFT) or lokr (LyCORIS) (default: lora)") + g_adapter.add_argument("--training-mode", type=str, default="adapter", choices=["adapter", "full"], help="Training mode: adapter (LoRA/LoKR) or full decoder fine-tune") + g_adapter.add_argument("--adapter-type", type=str, default="lora", choices=["lora", "lokr"], help="Adapter type (used in adapter mode)") + g_adapter.add_argument("--full-train-include", type=str, default="decoder", choices=["decoder"], help="Scope for full fine-tuning") + g_adapter.add_argument("--full-lr-mult-attn", type=float, default=1.0, help="LR multiplier for attention params in full mode") + g_adapter.add_argument("--full-lr-mult-ffn", type=float, default=1.0, help="LR multiplier for FFN params in full mode") + g_adapter.add_argument("--full-lr-mult-other", 
type=float, default=1.0, help="LR multiplier for remaining decoder params in full mode") # -- LoRA hyperparams --------------------------------------------------- g_lora = parser.add_argument_group("LoRA (used when --adapter-type=lora)") diff --git a/acestep/training_v2/cli/config_builder.py b/acestep/training_v2/cli/config_builder.py index 2fbc7a78..c90423d9 100644 --- a/acestep/training_v2/cli/config_builder.py +++ b/acestep/training_v2/cli/config_builder.py @@ -153,6 +153,14 @@ def build_configs(args: argparse.Namespace) -> Tuple[AdapterConfig, TrainingConf persistent_workers=persistent_workers, # V2 extensions adapter_type=adapter_type, + training_mode=getattr(args, "training_mode", "adapter"), + full_train_include=getattr(args, "full_train_include", "decoder"), + full_lr_mult_attn=getattr(args, "full_lr_mult_attn", 1.0), + full_lr_mult_ffn=getattr(args, "full_lr_mult_ffn", 1.0), + full_lr_mult_other=getattr(args, "full_lr_mult_other", 1.0), + length_bucket=getattr(args, "length_bucket", False), + cache_policy=getattr(args, "cache_policy", "none"), + cache_max_items=getattr(args, "cache_max_items", 0), optimizer_type=getattr(args, "optimizer_type", "adamw"), scheduler_type=getattr(args, "scheduler_type", "cosine"), gradient_checkpointing=getattr(args, "gradient_checkpointing", True), diff --git a/acestep/training_v2/configs.py b/acestep/training_v2/configs.py index eae85de9..422ef666 100644 --- a/acestep/training_v2/configs.py +++ b/acestep/training_v2/configs.py @@ -155,6 +155,30 @@ class TrainingConfigV2(TrainingConfig): adapter_type: str = "lora" """Adapter type: 'lora' (PEFT) or 'lokr' (LyCORIS).""" + training_mode: str = "adapter" + """Training mode: 'adapter' (LoRA/LoKR) or 'full' (decoder full fine-tune).""" + + full_train_include: str = "decoder" + """Scope for full fine-tuning. 
Currently supports 'decoder'.""" + + full_lr_mult_attn: float = 1.0 + """LR multiplier for attention parameters in full fine-tuning.""" + + full_lr_mult_ffn: float = 1.0 + """LR multiplier for FFN/MLP parameters in full fine-tuning.""" + + full_lr_mult_other: float = 1.0 + """LR multiplier for non-attention, non-FFN decoder parameters in full fine-tuning.""" + + length_bucket: bool = False + """Enable latent-length bucketing to reduce padding waste.""" + + cache_policy: str = "none" + """Dataset cache policy: 'none' or 'ram_lru'.""" + + cache_max_items: int = 0 + """Max RAM cache entries when cache_policy='ram_lru' (0 disables cache).""" + # --- Model / paths ------------------------------------------------------ model_variant: str = "turbo" """Model variant: 'turbo', 'base', or 'sft'.""" @@ -252,6 +276,14 @@ def to_dict(self) -> dict: "offload_encoder": self.offload_encoder, "vram_profile": self.vram_profile, "adapter_type": self.adapter_type, + "training_mode": self.training_mode, + "full_train_include": self.full_train_include, + "full_lr_mult_attn": self.full_lr_mult_attn, + "full_lr_mult_ffn": self.full_lr_mult_ffn, + "full_lr_mult_other": self.full_lr_mult_other, + "length_bucket": self.length_bucket, + "cache_policy": self.cache_policy, + "cache_max_items": self.cache_max_items, "cfg_ratio": self.cfg_ratio, "timestep_mu": self.timestep_mu, "timestep_sigma": self.timestep_sigma, diff --git a/acestep/training_v2/dataset_validation.py b/acestep/training_v2/dataset_validation.py new file mode 100644 index 00000000..7ad1c748 --- /dev/null +++ b/acestep/training_v2/dataset_validation.py @@ -0,0 +1,66 @@ +"""Dataset validation utilities for preprocessed training tensors.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, List + +import torch + + +REQUIRED_KEYS = ( + "target_latents", + "attention_mask", + "encoder_hidden_states", + "encoder_attention_mask", + "context_latents", +) + + +def _has_nonfinite(tensors: 
List[torch.Tensor]) -> bool: + """Return True if any tensor contains NaN or Inf values.""" + return any(not torch.isfinite(t).all() for t in tensors if isinstance(t, torch.Tensor)) + + +def validate_preprocessed_dataset(dataset_dir: str) -> Dict[str, Any]: + """Validate a preprocessed dataset directory and return summary stats.""" + root = Path(dataset_dir) + files = sorted(p for p in root.glob("*.pt") if p.name != "manifest.json") + + valid = 0 + invalid = 0 + nonfinite = 0 + lengths: List[int] = [] + errors: List[str] = [] + + for file_path in files: + try: + data = torch.load(str(file_path), map_location="cpu", weights_only=True) + missing = [k for k in REQUIRED_KEYS if k not in data] + if missing: + invalid += 1 + errors.append(f"{file_path.name}: missing keys {missing}") + continue + + tensors = [data[k] for k in REQUIRED_KEYS] + if _has_nonfinite(tensors): + nonfinite += 1 + + latent_len = int(data["target_latents"].shape[0]) + lengths.append(latent_len) + valid += 1 + except Exception as exc: # explicit diagnostic handling + invalid += 1 + errors.append(f"{file_path.name}: {exc}") + + avg_len = float(sum(lengths) / max(len(lengths), 1)) + return { + "total_samples": len(files), + "valid_samples": valid, + "invalid_samples": invalid, + "nan_or_inf_samples": nonfinite, + "min_latent_length": min(lengths) if lengths else 0, + "max_latent_length": max(lengths) if lengths else 0, + "avg_latent_length": avg_len, + "errors": errors, + } diff --git a/acestep/training_v2/dataset_validation_test.py b/acestep/training_v2/dataset_validation_test.py new file mode 100644 index 00000000..5b28af54 --- /dev/null +++ b/acestep/training_v2/dataset_validation_test.py @@ -0,0 +1,53 @@ +"""Unit tests for training_v2 dataset validation helpers.""" + +from __future__ import annotations + +import tempfile +import unittest +from pathlib import Path + +import torch + +from acestep.training_v2.dataset_validation import validate_preprocessed_dataset + + +class 
DatasetValidationTest(unittest.TestCase): + """Covers success and regression paths for dataset validation.""" + + def test_validate_dataset_reports_valid_and_invalid_samples(self) -> None: + """Validator should count valid, invalid, and non-finite samples correctly.""" + with tempfile.TemporaryDirectory() as td: + root = Path(td) + torch.save( + { + "target_latents": torch.zeros(8, 64), + "attention_mask": torch.ones(8), + "encoder_hidden_states": torch.zeros(6, 16), + "encoder_attention_mask": torch.ones(6), + "context_latents": torch.zeros(8, 65), + }, + root / "ok.pt", + ) + torch.save( + { + "target_latents": torch.tensor([[float("nan")]]), + "attention_mask": torch.ones(1), + "encoder_hidden_states": torch.zeros(1, 1), + "encoder_attention_mask": torch.ones(1), + "context_latents": torch.zeros(1, 1), + }, + root / "nan.pt", + ) + torch.save({"target_latents": torch.zeros(4, 64)}, root / "bad.pt") + + report = validate_preprocessed_dataset(str(root)) + + self.assertEqual(report["total_samples"], 3) + self.assertEqual(report["valid_samples"], 2) + self.assertEqual(report["invalid_samples"], 1) + self.assertEqual(report["nan_or_inf_samples"], 1) + self.assertGreaterEqual(len(report["errors"]), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/training_v2/fixed_lora_module.py b/acestep/training_v2/fixed_lora_module.py index baa5028e..b79c50b2 100644 --- a/acestep/training_v2/fixed_lora_module.py +++ b/acestep/training_v2/fixed_lora_module.py @@ -36,6 +36,18 @@ AdapterConfig = Union[LoRAConfigV2, LoKRConfigV2] +def _is_attention_name(name: str) -> bool: + """Return True if *name* looks like an attention parameter path.""" + key = name.lower() + return "attn" in key or any(part in key for part in ("q_proj", "k_proj", "v_proj", "o_proj")) + + +def _is_ffn_name(name: str) -> bool: + """Return True if *name* looks like an FFN/MLP parameter path.""" + key = name.lower() + return any(part in key for part in ("mlp", "ffn", "feed_forward", "fc", 
"up_proj", "down_proj", "gate_proj")) + + class _LastLossAccessor: """Lightweight wrapper that provides ``[-1]`` and bool access. @@ -138,8 +150,16 @@ def __init__( self.lycoris_net: Any = None self.adapter_info: Dict[str, Any] = {} - # -- Adapter injection ----------------------------------------------- - if self.adapter_type == "lokr": + self.training_mode = getattr(training_config, "training_mode", "adapter") + + # -- Adapter or full fine-tune setup --------------------------------- + if self.training_mode == "full": + self.model = model + self._configure_full_finetune() + trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) + self.adapter_info = {"trainable_params": trainable_params} + logger.info("[OK] Full fine-tune enabled: %s trainable params", f"{trainable_params:,}") + elif self.adapter_type == "lokr": self._inject_lokr(model, adapter_config) # type: ignore[arg-type] else: self._inject_lora(model, adapter_config) # type: ignore[arg-type] @@ -230,6 +250,48 @@ def _inject_lokr(self, model: nn.Module, cfg: LoKRConfigV2) -> None: self.device, ) + def _configure_full_finetune(self) -> None: + """Enable decoder-only full fine-tuning with explicit freezing.""" + for p in self.model.parameters(): + p.requires_grad = False + + include = getattr(self.training_config, "full_train_include", "decoder") + if include != "decoder" or not hasattr(self.model, "decoder"): + raise ValueError("Only decoder full fine-tuning is supported currently") + + for p in self.model.decoder.parameters(): + p.requires_grad = True + + def build_full_mode_param_groups(self) -> list[dict]: + """Build optimizer param groups with per-family LR multipliers.""" + base_lr = float(self.training_config.learning_rate) + attn_mult = float(getattr(self.training_config, "full_lr_mult_attn", 1.0)) + ffn_mult = float(getattr(self.training_config, "full_lr_mult_ffn", 1.0)) + other_mult = float(getattr(self.training_config, "full_lr_mult_other", 1.0)) + + attn_params = [] + 
ffn_params = [] + other_params = [] + + for name, param in self.model.decoder.named_parameters(): + if not param.requires_grad: + continue + if _is_attention_name(name): + attn_params.append(param) + elif _is_ffn_name(name): + ffn_params.append(param) + else: + other_params.append(param) + + groups = [] + if attn_params: + groups.append({"params": attn_params, "lr": base_lr * attn_mult}) + if ffn_params: + groups.append({"params": ffn_params, "lr": base_lr * ffn_mult}) + if other_params: + groups.append({"params": other_params, "lr": base_lr * other_mult}) + return groups + # ----------------------------------------------------------------------- # Training step # ----------------------------------------------------------------------- diff --git a/acestep/training_v2/fixed_lora_module_full_mode_test.py b/acestep/training_v2/fixed_lora_module_full_mode_test.py new file mode 100644 index 00000000..994503c7 --- /dev/null +++ b/acestep/training_v2/fixed_lora_module_full_mode_test.py @@ -0,0 +1,78 @@ +"""Tests for full fine-tuning mode in FixedLoRAModule.""" + +from __future__ import annotations + +import unittest + +import torch +import torch.nn as nn + +from acestep.training_v2.configs import LoRAConfigV2, TrainingConfigV2 +from acestep.training_v2.fixed_lora_module import FixedLoRAModule + + +class _DummyModel(nn.Module): + """Minimal ACE-Step-like model with decoder and null condition embedding.""" + + def __init__(self) -> None: + super().__init__() + self.decoder = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)) + self.encoder = nn.Linear(4, 4) + self.null_condition_emb = nn.Parameter(torch.zeros(1, 1, 4)) + self.config = type("Cfg", (), {})() + + +class FixedLoRAModuleFullModeTest(unittest.TestCase): + """Verifies full-mode freezing and optimizer grouping behavior.""" + + def test_full_mode_trains_decoder_only(self) -> None: + """Only decoder parameters should remain trainable in full mode.""" + model = _DummyModel() + cfg = TrainingConfigV2( + dataset_dir=".", + 
checkpoint_dir=".", + output_dir=".", + training_mode="full", + ) + module = FixedLoRAModule( + model=model, + adapter_config=LoRAConfigV2(), + training_config=cfg, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + decoder_trainable = all(p.requires_grad for p in module.model.decoder.parameters()) + encoder_trainable = any(p.requires_grad for p in module.model.encoder.parameters()) + + self.assertTrue(decoder_trainable) + self.assertFalse(encoder_trainable) + + def test_full_mode_param_groups_include_all_trainable_params(self) -> None: + """Param groups should cover every trainable decoder parameter once.""" + model = _DummyModel() + cfg = TrainingConfigV2( + dataset_dir=".", + checkpoint_dir=".", + output_dir=".", + training_mode="full", + full_lr_mult_attn=1.1, + full_lr_mult_ffn=0.9, + full_lr_mult_other=1.0, + ) + module = FixedLoRAModule( + model=model, + adapter_config=LoRAConfigV2(), + training_config=cfg, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + groups = module.build_full_mode_param_groups() + grouped = {id(p) for group in groups for p in group["params"]} + trainable = {id(p) for p in module.model.parameters() if p.requires_grad} + self.assertSetEqual(grouped, trainable) + + +if __name__ == "__main__": + unittest.main() diff --git a/acestep/training_v2/trainer_basic_loop.py b/acestep/training_v2/trainer_basic_loop.py index b8c2de9e..97907b30 100644 --- a/acestep/training_v2/trainer_basic_loop.py +++ b/acestep/training_v2/trainer_basic_loop.py @@ -110,6 +110,9 @@ def run_basic_training_loop( train_loader = data_module.train_dataloader() trainable_params = [p for p in module.model.parameters() if p.requires_grad] + optimizer_params = trainable_params + if getattr(cfg, "training_mode", "adapter") == "full": + optimizer_params = module.build_full_mode_param_groups() if not trainable_params: yield TrainingUpdate(0, 0.0, "[FAIL] No trainable parameters found", kind="fail") tb.close() @@ -118,7 +121,7 @@ def run_basic_training_loop( 
device_type = module.device_type if hasattr(module, "device_type") else str(module.device).split(":")[0] optimizer_type = getattr(cfg, "optimizer_type", "adamw") optimizer = build_optimizer( - trainable_params, + optimizer_params, optimizer_type=optimizer_type, lr=cfg.learning_rate, weight_decay=cfg.weight_decay, diff --git a/acestep/training_v2/trainer_fixed.py b/acestep/training_v2/trainer_fixed.py index 1854257c..b070a229 100644 --- a/acestep/training_v2/trainer_fixed.py +++ b/acestep/training_v2/trainer_fixed.py @@ -149,6 +149,9 @@ def train( prefetch_factor=cfg.prefetch_factor if num_workers > 0 else None, persistent_workers=cfg.persistent_workers if num_workers > 0 else False, pin_memory_device=cfg.pin_memory_device, + length_bucket=getattr(cfg, "length_bucket", False), + cache_policy=getattr(cfg, "cache_policy", "none"), + cache_max_items=getattr(cfg, "cache_max_items", 0), ) data_module.setup("fit") @@ -256,6 +259,9 @@ def _train_fabric( # -- Trainable params / optimizer ----------------------------------- trainable_params = [p for p in self.module.model.parameters() if p.requires_grad] + optimizer_params = trainable_params + if getattr(cfg, "training_mode", "adapter") == "full": + optimizer_params = self.module.build_full_mode_param_groups() if not trainable_params: yield TrainingUpdate(0, 0.0, "[FAIL] No trainable parameters found", kind="fail") tb.close() @@ -265,7 +271,7 @@ def _train_fabric( optimizer_type = getattr(cfg, "optimizer_type", "adamw") optimizer = build_optimizer( - trainable_params, + optimizer_params, optimizer_type=optimizer_type, lr=cfg.learning_rate, weight_decay=cfg.weight_decay, diff --git a/acestep/training_v2/trainer_helpers.py b/acestep/training_v2/trainer_helpers.py index b5f44320..456b0d92 100644 --- a/acestep/training_v2/trainer_helpers.py +++ b/acestep/training_v2/trainer_helpers.py @@ -12,6 +12,7 @@ import os from pathlib import Path from typing import Any, Generator, Optional, Tuple +from functools import partial import 
torch import torch.nn as nn @@ -143,6 +144,25 @@ def offload_non_decoder(model: nn.Module) -> int: # --------------------------------------------------------------------------- +def _save_full_decoder_state(module: Any, output_dir: str) -> None: + """Save decoder state for full fine-tuning mode.""" + decoder = _unwrap_decoder(module.model) + path = os.path.join(output_dir, "full_decoder_state.pt") + torch.save(decoder.state_dict(), path) + logger.info("[OK] Full decoder weights saved to %s", path) + + +def _load_full_decoder_state(module: Any, ckpt_dir: Path) -> bool: + """Load full decoder checkpoint when present.""" + path = ckpt_dir / "full_decoder_state.pt" + if not path.exists(): + return False + state = torch.load(str(path), map_location=module.device, weights_only=True) + decoder = _unwrap_decoder(module.model) + decoder.load_state_dict(state, strict=False) + return True + + def save_adapter_flat(trainer: Any, output_dir: str) -> None: """Save adapter weights directly into *output_dir* (no nesting). @@ -154,6 +174,10 @@ def save_adapter_flat(trainer: Any, output_dir: str) -> None: assert module is not None os.makedirs(output_dir, exist_ok=True) + if getattr(trainer.training_config, "training_mode", "adapter") == "full": + _save_full_decoder_state(module, output_dir) + return + if trainer.adapter_type == "lokr": if module.lycoris_net is None: logger.error( @@ -240,6 +264,11 @@ def verify_saved_adapter(output_dir: str) -> None: a warning if the weights appear to be all zeros (which would mean the LoRA has no effect during inference). 
""" + full_path = os.path.join(output_dir, "full_decoder_state.pt") + if os.path.exists(full_path): + logger.info("[OK] Full fine-tune weights saved: %s", full_path) + return + safetensors_path = os.path.join(output_dir, "adapter_model.safetensors") config_path = os.path.join(output_dir, "adapter_config.json") @@ -315,95 +344,186 @@ def resume_checkpoint( ) ckpt_dir = ckpt_dir.parent - # -- Detect format: LoKR uses lokr_weights.safetensors --------------- - lokr_weights_path = ckpt_dir / "lokr_weights.safetensors" + state_loader = partial( + _load_training_state, + module=module, + optimizer=optimizer, + scheduler=scheduler, + ) + + training_mode = getattr(trainer.training_config, "training_mode", "adapter") + if training_mode == "full": + return (yield from _resume_full_decoder(trainer, ckpt_dir, module, state_loader)) + + return (yield from _resume_lokr_or_lora( + trainer, + resume_path, + ckpt_dir, + module, + optimizer, + scheduler, + state_loader, + )) + + +def _load_training_state( + ckpt_dir: Path, + module: Any, + optimizer: Any, + scheduler: Any, +) -> Optional[Tuple[int, int, Dict[str, Any]]]: + """Load training progress and optimizer/scheduler state from ``training_state.pt``.""" state_path = ckpt_dir / "training_state.pt" + if not state_path.exists(): + return None - if lokr_weights_path.exists() and module.lycoris_net is not None: - # LoKR resume - if trainer.adapter_type != "lokr": - logger.warning( - "[WARN] Found lokr_weights.safetensors but adapter_type is '%s' " - "-- loading as LoKR anyway", - trainer.adapter_type, - ) - load_lokr_weights(module.lycoris_net, str(lokr_weights_path)) - if state_path.exists(): - state = torch.load( - str(state_path), map_location=module.device, weights_only=False - ) - epoch = state.get("epoch", 0) - step = state.get("global_step", 0) - if "optimizer_state_dict" in state: - optimizer.load_state_dict(state["optimizer_state_dict"]) - if "scheduler_state_dict" in state: - 
scheduler.load_state_dict(state["scheduler_state_dict"]) - yield TrainingUpdate( - 0, - 0.0, - f"[OK] Resumed LoKR from epoch {epoch}, step {step}", - kind="info", - ) - return (epoch, step) + state = torch.load(str(state_path), map_location=module.device, weights_only=False) + if "optimizer_state_dict" in state: + optimizer.load_state_dict(state["optimizer_state_dict"]) + if "scheduler_state_dict" in state: + scheduler.load_state_dict(state["scheduler_state_dict"]) + + return state.get("epoch", 0), state.get("global_step", 0), state + + +def _resume_full_decoder( + trainer: Any, + ckpt_dir: Path, + module: Any, + state_loader: Any, +) -> Generator[TrainingUpdate, None, Optional[Tuple[int, int]]]: + """Handle full fine-tune checkpoint resume path.""" + if not _load_full_decoder_state(module, ckpt_dir): yield TrainingUpdate( - 0, 0.0, "[OK] LoKR weights loaded (no training state)", kind="info" + 0, + 0.0, + f"[WARN] full_decoder_state.pt not found in {ckpt_dir}", + kind="warn", ) return None - # Warn if LoKR was expected but checkpoint is LoRA-format - if trainer.adapter_type == "lokr": - if not lokr_weights_path.exists(): - logger.warning( - "[WARN] adapter_type is 'lokr' but no lokr_weights.safetensors " - "found in %s -- falling back to LoRA resume format", - resume_path, - ) - elif module.lycoris_net is None: - logger.warning( - "[WARN] adapter_type is 'lokr' and lokr_weights.safetensors exists " - "but lycoris_net is None -- cannot load LoKR checkpoint", - ) + training_state = state_loader(ckpt_dir=ckpt_dir) + if training_state is None: + yield TrainingUpdate(0, 0.0, "[OK] Loaded full decoder weights (no training state)", kind="info") + return None - # LoRA resume (original logic) + epoch, step, _ = training_state + yield TrainingUpdate(0, 0.0, f"[OK] Resumed full fine-tune from epoch {epoch}, step {step}", kind="info") + return (epoch, step) + + +def _resume_lokr( + trainer: Any, + resume_path: str, + ckpt_dir: Path, + module: Any, + state_loader: Any, +) -> 
Generator[TrainingUpdate, None, Optional[Tuple[int, int]]]: + """Handle LoKR checkpoint resume path when LoKR weights are present.""" + lokr_weights_path = ckpt_dir / "lokr_weights.safetensors" + if not (lokr_weights_path.exists() and module.lycoris_net is not None): + if trainer.adapter_type == "lokr": + if not lokr_weights_path.exists(): + logger.warning( + "[WARN] adapter_type is 'lokr' but no lokr_weights.safetensors " + "found in %s -- falling back to LoRA resume format", + resume_path, + ) + else: + logger.warning( + "[WARN] adapter_type is 'lokr' and lokr_weights.safetensors exists " + "but lycoris_net is None -- cannot load LoKR checkpoint", + ) + return None + + if trainer.adapter_type != "lokr": + logger.warning( + "[WARN] Found lokr_weights.safetensors but adapter_type is '%s' " + "-- loading as LoKR anyway", + trainer.adapter_type, + ) + + load_lokr_weights(module.lycoris_net, str(lokr_weights_path)) + training_state = state_loader(ckpt_dir=ckpt_dir) + if training_state is None: + yield TrainingUpdate(0, 0.0, "[OK] LoKR weights loaded (no training state)", kind="info") + return None + + # Full mode resume: no full decoder checkpoint found. 
+ if getattr(trainer.training_config, "training_mode", "adapter") == "full": + yield TrainingUpdate( + 0, + 0.0, + f"[WARN] full_decoder_state.pt not found in {ckpt_dir}", + kind="warn", + ) + return + + +def _resume_lora( + ckpt_dir: Path, + module: Any, + optimizer: Any, + scheduler: Any, +) -> Generator[TrainingUpdate, None, Optional[Tuple[int, int]]]: + """Handle LoRA checkpoint resume path.""" ckpt_info = load_training_checkpoint( str(ckpt_dir), optimizer=optimizer, scheduler=scheduler, device=module.device, ) - if ckpt_info["adapter_path"]: - adapter_path = ckpt_info["adapter_path"] - aw_path = os.path.join(adapter_path, "adapter_model.safetensors") - if not os.path.exists(aw_path): - aw_path = os.path.join(adapter_path, "adapter_model.bin") - - if os.path.exists(aw_path): - from safetensors.torch import load_file - - state_dict = ( - load_file(aw_path) - if aw_path.endswith(".safetensors") - else torch.load(aw_path, map_location=module.device, weights_only=True) - ) - decoder = module.model.decoder - if hasattr(decoder, "_forward_module"): - decoder = decoder._forward_module - decoder.load_state_dict(state_dict, strict=False) - - start_epoch = ckpt_info["epoch"] - g_step = ckpt_info["global_step"] - parts = [f"[OK] Resumed from epoch {start_epoch}, step {g_step}"] - if ckpt_info["loaded_optimizer"]: - parts.append("optimizer OK") - if ckpt_info["loaded_scheduler"]: - parts.append("scheduler OK") - yield TrainingUpdate(0, 0.0, ", ".join(parts), kind="info") - return (start_epoch, g_step) - yield TrainingUpdate( - 0, 0.0, f"[WARN] Adapter weights not found in {adapter_path}", kind="warn" - ) + if not ckpt_info["adapter_path"]: + yield TrainingUpdate(0, 0.0, f"[WARN] No valid checkpoint in {ckpt_dir}", kind="warn") + return None + + adapter_path = ckpt_info["adapter_path"] + aw_path = os.path.join(adapter_path, "adapter_model.safetensors") + if not os.path.exists(aw_path): + aw_path = os.path.join(adapter_path, "adapter_model.bin") + if not 
os.path.exists(aw_path): + yield TrainingUpdate(0, 0.0, f"[WARN] Adapter weights not found in {adapter_path}", kind="warn") return None - yield TrainingUpdate( - 0, 0.0, f"[WARN] No valid checkpoint in {ckpt_dir}", kind="warn" + + from safetensors.torch import load_file + + state_dict = ( + load_file(aw_path) if aw_path.endswith(".safetensors") + else torch.load(aw_path, map_location=module.device, weights_only=True) + ) + decoder = module.model.decoder + if hasattr(decoder, "_forward_module"): + decoder = decoder._forward_module + decoder.load_state_dict(state_dict, strict=False) + + start_epoch = ckpt_info["epoch"] + g_step = ckpt_info["global_step"] + parts = [f"[OK] Resumed from epoch {start_epoch}, step {g_step}"] + if ckpt_info["loaded_optimizer"]: + parts.append("optimizer OK") + if ckpt_info["loaded_scheduler"]: + parts.append("scheduler OK") + yield TrainingUpdate(0, 0.0, ", ".join(parts), kind="info") + return (start_epoch, g_step) + + +def _resume_lokr_or_lora( + trainer: Any, + resume_path: str, + ckpt_dir: Path, + module: Any, + optimizer: Any, + scheduler: Any, + state_loader: Any, +) -> Generator[TrainingUpdate, None, Optional[Tuple[int, int]]]: + """Resume adapter-mode training by trying LoKR then LoRA format.""" + lokr_result = yield from _resume_lokr( + trainer, + resume_path, + ckpt_dir, + module, + state_loader, ) - return None + if lokr_result is not None: + return lokr_result + return (yield from _resume_lora(ckpt_dir, module, optimizer, scheduler)) diff --git a/acestep/training_v2/trainer_helpers_test.py b/acestep/training_v2/trainer_helpers_test.py new file mode 100644 index 00000000..bbf746a9 --- /dev/null +++ b/acestep/training_v2/trainer_helpers_test.py @@ -0,0 +1,48 @@ +"""Unit tests for ``trainer_helpers.resume_checkpoint`` full-mode fallback.""" + +from __future__ import annotations + +import tempfile +import unittest +from types import SimpleNamespace + +import torch + +from acestep.training_v2 import trainer_helpers + + +class 
ResumeCheckpointFullModeTests(unittest.TestCase): + """Validate full-mode resume behavior when full checkpoint file is absent.""" + + def test_full_mode_missing_decoder_state_yields_warning(self) -> None: + """Missing ``full_decoder_state.pt`` should warn and return ``None``.""" + trainer = SimpleNamespace( + module=SimpleNamespace(device=torch.device("cpu"), lycoris_net=None), + training_config=SimpleNamespace(training_mode="full"), + adapter_type="lora", + ) + optimizer = SimpleNamespace(load_state_dict=lambda state: None) + scheduler = SimpleNamespace(load_state_dict=lambda state: None) + + with tempfile.TemporaryDirectory() as tmpdir: + generator = trainer_helpers.resume_checkpoint( + trainer, + tmpdir, + optimizer, + scheduler, + ) + updates = [] + try: + while True: + updates.append(next(generator)) + except StopIteration as stop: + result = stop.value + + self.assertIsNone(result) + self.assertEqual(len(updates), 1) + self.assertEqual(updates[0].kind, "warn") + self.assertIn("full_decoder_state.pt not found", updates[0].msg) + + +if __name__ == "__main__": + unittest.main() diff --git a/docs/lora_full_finetune_pipeline_proposal.md b/docs/lora_full_finetune_pipeline_proposal.md new file mode 100644 index 00000000..7acc9345 --- /dev/null +++ b/docs/lora_full_finetune_pipeline_proposal.md @@ -0,0 +1,151 @@ +# Proposal: Improve ACE-Step LoRA Training Execution Pipeline and Add Full Fine-Tuning Capability + +## 1) Current-state findings (from code walk) + +### What is already good +- The training loops (`vanilla` and `fixed`) consume **preprocessed `.pt` tensors** via `PreprocessedDataModule`, not raw audio. +- The v2 preprocessing path already runs a **two-pass offline pipeline**: + - Pass 1: audio decode + VAE encode + text/lyrics encode into temporary tensors. + - Pass 2: DIT encoder + context latent construction into final training tensors. 
+
+### Where inefficiency/risk still exists
+- Each sample is stored as a separate `.pt` file; the dataset loader does `torch.load` per sample at runtime. This avoids real-time VAE decode but can still bottleneck on filesystem seeks + Python deserialization for large datasets.
+- Pass 2 currently re-loads each temporary file independently and repeatedly transfers multiple tensors to the GPU; this is functional but not optimized for throughput.
+- The training stack is adapter-centric (LoRA/LoKR) and does not expose a first-class, safe “full fine-tune” mode with optimizer/parameter-group controls, freezing policies, and checkpoint strategy.
+
+---
+
+## 2) Proposal A — Better LoRA execution pipeline (precompute-first, throughput-oriented)
+
+### A1. Introduce a packed latent dataset format (v3)
+Create a new optional preprocessing output format:
+- `shard-xxxxx.safetensors` (or `.pt` if needed) with contiguous arrays.
+- `index.jsonl` containing sample metadata + byte/range mapping.
+
+Benefits:
+- Fewer file opens/seeks vs one-file-per-sample.
+- Faster startup and epoch iteration.
+- Easier weighted sampling and bucketing using index metadata only.
+
+Backward compatibility:
+- Keep the existing `.pt`-per-sample format as the default for compatibility.
+- Add `--dataset-format {pt,v3_sharded}` and auto-detect in the loader.
+
+### A2. Add duration/latent-length bucketing in the dataloader
+- Precompute latent length in the manifest/index.
+- Build bucketed batches (near-equal lengths) to reduce padding waste.
+- Keep deterministic shuffling by bucket + epoch seed.
+
+Expected gain:
+- Better GPU utilization and lower wasted FLOPs for variable-length audio.
+
+### A3. Add optional in-memory or mmap cache layer
+- New runtime flag: `--cache-policy {none,mmap,ram_lru}`.
+- `mmap`: avoid repeated deserialization cost.
+- `ram_lru`: cache the hottest tensors with a memory cap (`--cache-max-gb`).
+
+Expected gain:
+- Less CPU overhead and reduced storage I/O jitter.
+
+### A4. 
Precompute and persist all conditioning tensors once
+The current two-pass preprocessing already computes major tensors. Extend this by:
+- Persisting any variant-dependent conditioning expansions that remain deterministic.
+- Recording a preprocessing fingerprint (`model variant`, `checkpoint hash`, `tokenizer hash`) to enforce dataset/model compatibility before training starts.
+
+### A5. Add asynchronous prefetch + pinned staging buffer
+- A small CUDA prefetcher can move the next batch to the device while the current batch trains.
+- Expose `--device-prefetch` with a safe fallback.
+
+### A6. Add preprocessing QA command
+New CLI check command before training:
+- Verifies tensor shapes/dtypes/masks.
+- Checks for NaNs/infs.
+- Reports a padding-efficiency estimate and the bucket distribution.
+
+---
+
+## 3) Proposal B — Full model fine-tuning mode (beyond LoRA)
+
+### B1. Add `adapter_type=full` (or `training_mode=full`)
+- Reuse the existing trainer shell and UI flow.
+- Skip adapter injection and mark selected modules trainable.
+
+Initial scope (safe):
+- Fine-tune **decoder-only** first.
+- Keep encoder/VAE frozen by default.
+
+Advanced scope:
+- Optional staged unfreezing (decoder → encoder).
+- Optional text encoder unfreezing for domain transfer.
+
+### B2. Parameter-group and LR policy
+Provide explicit parameter groups:
+- `decoder.attn`, `decoder.ffn`, `norm`, `embeddings` (as applicable).
+- Distinct LR multipliers (`--lr-mult-attn`, etc.).
+- Weight-decay exclusions for norms/biases.
+
+### B3. Memory-safe full FT path
+- Mandatory gradient-checkpointing path validation.
+- bf16/fp16 mixed precision with fp32 master weights where needed.
+- Optional 8-bit optimizer support where stable.
+- Optional FSDP/ZeRO integration behind an explicit flag.
+
+### B4. Checkpointing and resume semantics for full FT
+- Save full model state (or sharded states) + optimizer + scheduler + scaler.
+- Add a periodic EMA checkpoint option for stability.
+- Add strict compatibility validation on resume. 
+
+### B5. Safety controls
+- Require explicit opt-in (`--training-mode full --i-understand-vram-risk`).
+- Preflight VRAM estimator for model variant + seq length + batch size.
+- Auto-suggest fallback to LoRA when the estimate exceeds the threshold.
+
+### B6. Evaluation and regression hooks
+- Add minimal validation hooks (loss-only + optional sample generation every N epochs).
+- Track train/val divergence and expose early-stop patience.
+
+---
+
+## 4) Suggested implementation plan (low-risk phases)
+
+### Phase 1: Data throughput
+1. Add a packed dataset writer + loader (behind flags).
+2. Add a bucketed sampler.
+3. Add a preprocessing QA validator.
+
+### Phase 2: Runtime pipeline polish
+1. Add mmap/LRU cache policy.
+2. Add optional device prefetch.
+3. Benchmark and publish defaults by GPU class.
+
+### Phase 3: Full fine-tuning MVP
+1. Add `training_mode=full` with decoder-only unfreeze.
+2. Add parameter groups + LR multipliers.
+3. Add full-state checkpoint/resume.
+
+### Phase 4: Advanced scaling
+1. Add distributed/sharded optimizer options.
+2. Add staged unfreeze profiles.
+3. Add stronger eval/early-stop controls.
+
+---
+
+## 5) Acceptance criteria (measurable)
+
+For LoRA pipeline improvements:
+- >=20% faster step time on a medium dataset vs the current `.pt` baseline.
+- >=15% reduction in data-loader stall time.
+- No regression in the final training-loss curve over a fixed-seed smoke run.
+
+For the full fine-tuning MVP:
+- Successfully trains decoder-only full FT for at least one epoch on a supported GPU setup.
+- Resume from checkpoint reproduces optimizer/scheduler state correctly.
+- CLI safety checks prevent accidental OOM-prone config starts.
+
+---
+
+## 6) Immediate next patch candidates
+1. Add a `DatasetBackend` abstraction and a sharded backend implementation.
+2. Add `BucketedBatchSampler` keyed by latent length.
+3. Add a `train.py validate-dataset --dataset-dir ...` command.
+4. Add a `training_mode` enum + `full` branch in config/trainer with decoder-only unfreeze. 
diff --git a/train.py b/train.py index f28b54e6..192450f6 100644 --- a/train.py +++ b/train.py @@ -61,7 +61,7 @@ def _has_subcommand() -> bool: args = sys.argv[1:] if "--help" in args or "-h" in args: return True # let argparse handle help - known = {"vanilla", "fixed", "estimate"} + known = {"vanilla", "fixed", "estimate", "validate-dataset"} return bool(known & set(args)) @@ -89,8 +89,8 @@ def _dispatch(args) -> int: sub = args.subcommand - # All subcommands need path validation - if not validate_paths(args): + # Most subcommands need path validation + if sub != "validate-dataset" and not validate_paths(args): return 1 if sub == "vanilla": @@ -104,6 +104,9 @@ def _dispatch(args) -> int: elif sub == "estimate": return _run_estimate(args) + elif sub == "validate-dataset": + return _run_validate_dataset(args) + else: print(f"[FAIL] Unknown subcommand: {sub}", file=sys.stderr) return 1 @@ -202,6 +205,30 @@ def _run_preprocess(args) -> int: return 0 + +def _run_validate_dataset(args) -> int: + """Validate preprocessed tensors and print quality stats.""" + from acestep.training_v2.dataset_validation import validate_preprocessed_dataset + + report = validate_preprocessed_dataset(args.dataset_dir) + print("\n" + "=" * 60) + print(" Dataset Validation") + print("=" * 60) + print(f" Dataset: {args.dataset_dir}") + print(f" Samples: {report['total_samples']}") + print(f" Valid: {report['valid_samples']}") + print(f" Invalid: {report['invalid_samples']}") + print(f" NaN/Inf: {report['nan_or_inf_samples']}") + print(f" Min latent T: {report['min_latent_length']}") + print(f" Max latent T: {report['max_latent_length']}") + print(f" Avg latent T: {report['avg_latent_length']:.2f}") + if report["errors"]: + print("\n[WARN] Sample errors:") + for err in report["errors"][:10]: + print(f" - {err}") + print("=" * 60) + return 0 if report["invalid_samples"] == 0 else 1 + def _run_estimate(args) -> int: """Run gradient sensitivity estimation.""" import json as _json