Conversation
…ora-training-pipeline Add full fine-tune mode, dataset validation, length bucketing and dataset cache
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the CodeRabbit settings. Use the following commands to manage reviews:
📝 Walkthrough

Adds length-based bucketed batching and an in-memory LRU cache for preprocessed datasets, a dataset validation CLI, and a "full" decoder-only fine-tuning mode with per-family LR multipliers; wires new flags through CLI, TrainingConfigV2, data modules, trainer, and checkpoint/resume helpers.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as CLI/User
    participant Config as TrainingConfigV2
    participant Trainer as Trainer
    participant Module as FixedLoRAModule
    participant Optimizer as Optimizer
    participant DataModule as DataModule
    participant Dataset as PreprocessedTensorDataset
    rect rgba(100,150,200,0.5)
        Note over User,Module: Initialization with full-mode and data options
        User->>Trainer: start (training_mode="full", length_bucket=...)
        Trainer->>Config: read training_mode, lr multipliers, caching flags
        Config-->>Trainer: config
        Trainer->>Module: init with training_config
        Module->>Module: _configure_full_finetune() (freeze except decoder)
        Module-->>Trainer: module ready
    end
    rect rgba(150,100,200,0.5)
        Note over Trainer,Optimizer: Build optimizer params for full-mode
        Trainer->>Module: build_full_mode_param_groups()
        Module-->>Trainer: param groups (attn/ffn/other)
        Trainer->>Optimizer: init with grouped params
        Optimizer-->>Trainer: optimizer ready
    end
    rect rgba(100,200,150,0.5)
        Note over Trainer,Dataset: Data loading with bucketing & caching
        Trainer->>DataModule: create with length_bucket, cache_policy
        DataModule->>Dataset: init (compute latent_lengths, init LRU cache)
        DataModule->>DataModule: use BucketedBatchSampler if enabled
        DataModule-->>Trainer: dataloader ready
    end
    rect rgba(200,150,100,0.5)
        Note over Trainer,Dataset: Training loop sample flow
        Trainer->>Dataset: __getitem__(idx)
        Dataset->>Dataset: check _cache
        alt cache hit
            Dataset-->>Trainer: return cached tensors
        else cache miss
            Dataset->>Dataset: load from disk, store in LRU cache
            Dataset-->>Trainer: return tensors
        end
        Trainer->>Optimizer: backward & step
        Optimizer-->>Trainer: updated
    end
```
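The cache hit/miss flow in the last block can be sketched with a plain `OrderedDict`-based LRU. This is an illustrative stand-in, not the repository's actual class: `CachedDataset` and `loader` are hypothetical names, and `loader` replaces the real `torch.load` call.

```python
from collections import OrderedDict

class CachedDataset:
    """Illustrative dataset wrapper with a bounded LRU read cache."""

    def __init__(self, loader, cache_max_items=2):
        self.loader = loader              # callable: idx -> sample (stands in for torch.load)
        self.cache_max_items = cache_max_items
        self._cache = OrderedDict()
        self.hits = 0
        self.misses = 0

    def __getitem__(self, idx):
        if idx in self._cache:
            self._cache.move_to_end(idx)  # mark as most recently used
            self.hits += 1
            return self._cache[idx]
        self.misses += 1
        sample = self.loader(idx)
        self._cache[idx] = sample
        if len(self._cache) > self.cache_max_items:
            self._cache.popitem(last=False)  # evict least recently used
        return sample
```

Repeated reads of hot indices then skip disk entirely, which is the behavior the diagram's "cache hit" branch describes.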
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches: 🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
acestep/training/data_module.py (1)
534-557: ⚠️ Potential issue | 🔴 Critical: Missing parameters in `__init__` signature cause `NameError` at runtime.

The parameters `length_bucket`, `cache_policy`, and `cache_max_items` are assigned to `self` at lines 554-556, but they are not declared in the `__init__` signature. This will raise a `NameError` when instantiating `AceStepDataModule`.

🐛 Proposed fix to add missing parameters
```diff
 def __init__(
     self,
     samples: List[Dict[str, Any]],
     dit_handler,
     batch_size: int = 1,
     num_workers: int = 4,
     pin_memory: bool = True,
     max_duration: float = 240.0,
     val_split: float = 0.0,
+    length_bucket: bool = False,
+    cache_policy: str = "none",
+    cache_max_items: int = 0,
 ):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training/data_module.py` around lines 534 - 557, The __init__ of AceStepDataModule assigns self.length_bucket, self.cache_policy, and self.cache_max_items but those parameters are missing from the signature; add them to the __init__ signature (e.g. length_bucket: Optional[Sequence[int]] = None, cache_policy: str = "none", cache_max_items: int = 0) and ensure the assignments remain (self.length_bucket = length_bucket, self.cache_policy = cache_policy, self.cache_max_items = cache_max_items) so instantiating AceStepDataModule no longer raises NameError.
🧹 Nitpick comments (5)
acestep/training/data_module.py (3)
179-188: Cache is per-worker and won't be shared across DataLoader workers.

With `num_workers > 0`, PyTorch spawns separate processes, each with its own `Dataset` copy. The LRU cache won't be shared, reducing its effectiveness and increasing total memory usage proportionally to worker count.

Consider documenting this behavior or recommending `num_workers=0` when using RAM caching. Alternatively, `cache_policy="none"` with `persistent_workers=True` may provide better performance for disk-based loading.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training/data_module.py` around lines 179 - 188, The LRU RAM cache implemented by dataset attributes cache_policy, _cache and cache_max_items (in the data loading block that reads valid_paths and uses torch.load) is per-process and will be duplicated for each DataLoader worker when num_workers>0; update documentation and/or dataset behavior: add a note in the class docstring or surrounding comments explaining that RAM caching is process-local and recommend using num_workers=0 for ram_lru, or suggest using cache_policy="none" together with persistent_workers=True for disk-based loading, and (optionally) add a runtime warning inside the Dataset __init__ or before loading (checking torch.utils.data.get_worker_info or os.getpid differences) when cache_policy=="ram_lru" and num_workers>0 to inform users of the per-worker duplication.
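The optional runtime warning suggested in the prompt could look like the sketch below. `check_cache_settings` is a hypothetical helper name; the flag values mirror the ones discussed above.

```python
import warnings

def check_cache_settings(cache_policy: str, num_workers: int) -> None:
    """Warn when a process-local RAM cache is combined with worker processes."""
    if cache_policy == "ram_lru" and num_workers > 0:
        warnings.warn(
            "cache_policy='ram_lru' with num_workers>0: the LRU cache is "
            "per-worker, multiplying RAM use; consider num_workers=0, or "
            "cache_policy='none' with persistent_workers=True."
        )
```

Calling this once from the data module's constructor would surface the per-worker duplication to users without changing loading behavior.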
59-61: `__len__` may return an inaccurate batch count.

The current implementation divides total samples by batch size, but the actual batch count depends on bucket distribution. Each bucket's last batch may be smaller, leading to more batches than this estimate. This can cause misleading progress bars or epoch-length mismatches.

Consider counting actual batches:
♻️ Proposed fix for accurate batch count
```diff
 def __len__(self) -> int:
-    total = len(self.lengths)
-    return (total + self.batch_size - 1) // self.batch_size
+    buckets: Dict[int, int] = {}
+    for length in self.lengths:
+        bucket = int(length // 64)
+        buckets[bucket] = buckets.get(bucket, 0) + 1
+    return sum((count + self.batch_size - 1) // self.batch_size for count in buckets.values())
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training/data_module.py` around lines 59 - 61, The `__len__` method currently computes batches as `(len(self.lengths) + self.batch_size - 1) // self.batch_size`, which is wrong because batches are formed per bucket; change `__len__` (in the class defining `__len__`) to group `self.lengths` into buckets, then sum `(count + self.batch_size - 1) // self.batch_size` (i.e. `math.ceil(count / self.batch_size)`) over each bucket's sample count so the returned value equals the actual number of batches across all buckets.
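The discrepancy the comment describes can be seen in a few lines. This sketch is standalone; `naive_num_batches` and `bucketed_num_batches` are illustrative names, and the bucket width 64 mirrors the proposed fix rather than a confirmed constant.

```python
def naive_num_batches(lengths, batch_size):
    """Batch count ignoring buckets: ceil(total / batch_size)."""
    return (len(lengths) + batch_size - 1) // batch_size

def bucketed_num_batches(lengths, batch_size, bucket_width=64):
    """Batch count when batches never cross bucket boundaries."""
    buckets = {}
    for length in lengths:
        b = length // bucket_width
        buckets[b] = buckets.get(b, 0) + 1
    # each bucket's last batch may be partial, so take the ceiling per bucket
    return sum((n + batch_size - 1) // batch_size for n in buckets.values())
```

With three samples that each land in a different bucket and `batch_size=2`, the naive estimate is 2 batches while the sampler actually yields 3, which is exactly the progress-bar mismatch described above.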
554-556: Attributes stored but never used in this deprecated class.

Even after fixing the signature, `length_bucket`, `cache_policy`, and `cache_max_items` are never referenced in `setup()` or `train_dataloader()`. Consider either implementing the functionality for consistency with `PreprocessedDataModule`, or removing these unused attributes since the class is deprecated.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@acestep/training/data_module.py` around lines 554 - 556, The deprecated data module stores unused attributes length_bucket, cache_policy, and cache_max_items; either remove these attributes and any constructor parameters that set them, or wire them up so the class behaves like PreprocessedDataModule by using length_bucket in setup() to group examples into buckets, and applying cache_policy/cache_max_items in train_dataloader() when building the in-memory or disk cache for preprocessed samples; update the constructor, setup(), and train_dataloader() implementations (referencing the same attribute names and method names) to ensure the attributes are actually consumed or are removed to avoid dead state.

acestep/training_v2/dataset_validation_test.py (1)
45-49: Add assertions for latent-length statistics to cover new behavior.

Current assertions validate counts, but not `min_latent_length`, `max_latent_length`, or `avg_latent_length`, which are part of the new validator output.

Proposed test additions
```diff
 self.assertEqual(report["total_samples"], 3)
 self.assertEqual(report["valid_samples"], 2)
 self.assertEqual(report["invalid_samples"], 1)
 self.assertEqual(report["nan_or_inf_samples"], 1)
+self.assertEqual(report["min_latent_length"], 1)
+self.assertEqual(report["max_latent_length"], 8)
+self.assertAlmostEqual(report["avg_latent_length"], 4.5)
 self.assertGreaterEqual(len(report["errors"]), 1)
```

As per coding guidelines `**/*_test.py`: "Add or update tests for every behavior change and bug fix".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training_v2/dataset_validation_test.py` around lines 45 - 49, The test currently asserts counts in the `report` but misses the new latent-length stats; update the test that inspects the `report` (the assertions block where `self.assertEqual(report["total_samples"], 3)` etc.) to also assert `report["min_latent_length"]`, `report["max_latent_length"]`, and `report["avg_latent_length"]` with the expected numeric values for this fixture (compute expected min/max/avg from the three sample latents used in the test) so the new validator output is covered.

acestep/training_v2/fixed_lora_module_full_mode_test.py (1)
71-74: `set` comparison does not verify "exactly once" membership.

Using sets here can hide duplicate parameters across groups. Add an explicit duplicate check before the set-equality assertion.
Suggested test hardening
```diff
 groups = module.build_full_mode_param_groups()
-grouped = {id(p) for group in groups for p in group["params"]}
+grouped_ids = [id(p) for group in groups for p in group["params"]]
+self.assertEqual(
+    len(grouped_ids),
+    len(set(grouped_ids)),
+    "A trainable parameter appears in more than one optimizer group",
+)
+grouped = set(grouped_ids)
 trainable = {id(p) for p in module.model.parameters() if p.requires_grad}
 self.assertSetEqual(grouped, trainable)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training_v2/fixed_lora_module_full_mode_test.py` around lines 71 - 74, The test uses sets (grouped/trainable) which hide duplicate parameter entries across groups; modify the test around build_full_mode_param_groups() to first flatten the group["params"] into a list (e.g., groups -> list_of_params_ids) and assert there are no duplicates by checking len(list_of_params_ids) == len(set(list_of_params_ids)) (or by asserting Counter(list_of_params_ids).most_common(1)[0][1] == 1), then proceed to compare the sets (grouped vs trainable) as before; reference build_full_mode_param_groups, groups, group["params"], grouped, and trainable when making the change.
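Outside the test, the same "exactly once" check can be expressed as a small helper. This is a sketch with a hypothetical name (`find_duplicate_params`); it works on plain dicts shaped like optimizer param groups and compares by object identity, as the test does.

```python
from collections import Counter

def find_duplicate_params(param_groups):
    """Return ids of parameters that appear in more than one group."""
    ids = [id(p) for g in param_groups for p in g["params"]]
    return [pid for pid, n in Counter(ids).items() if n > 1]
```

An empty return value means every parameter was assigned to exactly one group, which is the invariant the hardened assertion enforces.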
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@acestep/training_v2/dataset_validation.py`:
- Around line 25-29: validate_preprocessed_dataset currently treats a
non-existent or non-directory path as valid by returning zero samples; after
computing root = Path(dataset_dir) add an explicit existence and directory check
(e.g., if not root.exists() or not root.is_dir()) and raise a clear exception
(FileNotFoundError or ValueError) so callers get an immediate, descriptive
failure instead of a false "valid" report; reference
validate_preprocessed_dataset, root, and files when making the change.
In `@acestep/training_v2/fixed_lora_module_full_mode_test.py`:
- Around line 17-23: Add a concise docstring inside _DummyModel.__init__
describing the purpose of the initializer and the attributes it sets (e.g.,
creates decoder: nn.Sequential of two Linear(4,4), encoder: nn.Linear(4,4),
null_condition_emb: nn.Parameter of zeros with shape (1,1,4), and a simple
config object), and note there are no parameters to __init__; place the
docstring as the first statement in the __init__ method so it documents the
constructor and its initialized members.
In `@acestep/training_v2/trainer_helpers.py`:
- Around line 367-369: The stray block in resume_checkpoint incorrectly
references output_dir and calls _save_full_decoder_state during resume, causing
a runtime crash; remove this entire conditional block (the if checking
getattr(trainer.training_config, "training_mode", "adapter") == "full" that
calls _save_full_decoder_state and returns) from resume_checkpoint so resume
logic does not perform saves or reference undefined output_dir, or if a mode
check is needed keep it readonly (no save/return) and use existing save paths
elsewhere.
In `@acestep/training/data_module.py`:
- Around line 355-362: The bucketing code fails silently when self.train_dataset
is a torch.utils.data.Subset because hasattr(self.train_dataset,
"latent_lengths") is false for the Subset wrapper; update the condition in the
block that sets BucketedBatchSampler (symbols: length_bucket,
self.train_dataset, latent_lengths, BucketedBatchSampler) to look through the
Subset wrapper and obtain the underlying dataset's latent_lengths (e.g., check
getattr(self.train_dataset, "dataset", self.train_dataset) for latent_lengths)
before deciding to pop batch_size/shuffle and assign batch_sampler, so the
BucketedBatchSampler still receives the correct lengths when val_split > 0.
- Around line 122-128: The loop that builds self.latent_lengths in DataModule
(using self.valid_paths) currently uses a bare except which masks errors; change
it to catch specific exceptions when loading torch files (e.g.,
FileNotFoundError, PermissionError, EOFError, RuntimeError) and log the failure
including the path and exception (use the module/class logger or python
logging), append 0 only for those expected/recoverable errors, and let other
unexpected exceptions propagate (or re-raise) so real issues aren’t silently
ignored; update the try/except around torch.load and the
sample["target_latents"] access accordingly and include the path (vp) and
exception text in the log message.
In `@train.py`:
- Around line 213-230: The current exit logic uses only
report["invalid_samples"] to decide success, allowing empty datasets to return
success; update the final return condition to require report["total_samples"] >
0 in addition to report["invalid_samples"] == 0 (i.e., return success only when
there is at least one sample and no invalid samples). Locate the block that
calls validate_preprocessed_dataset and references report (the print/report
summary block) and change the return expression to check both
report["total_samples"] > 0 and report["invalid_samples"] == 0; optionally add a
short warning/log when total_samples == 0 before returning failure.
---
Outside diff comments:
In `@acestep/training/data_module.py`:
- Around line 534-557: The __init__ of AceStepDataModule assigns
self.length_bucket, self.cache_policy, and self.cache_max_items but those
parameters are missing from the signature; add them to the __init__ signature
(e.g. length_bucket: Optional[Sequence[int]] = None, cache_policy: str = "none",
cache_max_items: int = 0) and ensure the assignments remain (self.length_bucket
= length_bucket, self.cache_policy = cache_policy, self.cache_max_items =
cache_max_items) so instantiating AceStepDataModule no longer raises NameError.
---
Nitpick comments:
In `@acestep/training_v2/dataset_validation_test.py`:
- Around line 45-49: The test currently asserts counts in the `report` but
misses the new latent-length stats; update the test that inspects the `report`
(the assertions block where `self.assertEqual(report["total_samples"], 3)` etc.)
to also assert `report["min_latent_length"]`, `report["max_latent_length"]`, and
`report["avg_latent_length"]` with the expected numeric values for this fixture
(compute expected min/max/avg from the three sample latents used in the test) so
the new validator output is covered.
In `@acestep/training_v2/fixed_lora_module_full_mode_test.py`:
- Around line 71-74: The test uses sets (grouped/trainable) which hide duplicate
parameter entries across groups; modify the test around
build_full_mode_param_groups() to first flatten the group["params"] into a list
(e.g., groups -> list_of_params_ids) and assert there are no duplicates by
checking len(list_of_params_ids) == len(set(list_of_params_ids)) (or by
asserting Counter(list_of_params_ids).most_common(1)[0][1] == 1), then proceed
to compare the sets (grouped vs trainable) as before; reference
build_full_mode_param_groups, groups, group["params"], grouped, and trainable
when making the change.
In `@acestep/training/data_module.py`:
- Around line 179-188: The LRU RAM cache implemented by dataset attributes
cache_policy, _cache and cache_max_items (in the data loading block that reads
valid_paths and uses torch.load) is per-process and will be duplicated for each
DataLoader worker when num_workers>0; update documentation and/or dataset
behavior: add a note in the class docstring or surrounding comments explaining
that RAM caching is process-local and recommend using num_workers=0 for ram_lru,
or suggest using cache_policy="none" together with persistent_workers=True for
disk-based loading, and (optionally) add a runtime warning inside the Dataset
__init__ or before loading (checking torch.utils.data.get_worker_info or
os.getpid differences) when cache_policy=="ram_lru" and num_workers>0 to inform
users of the per-worker duplication.
- Around line 59-61: The __len__ method currently computes batches using total =
len(self.lengths) which is wrong because batches depend on each bucket's sample
count; change __len__ (in the class defining __len__) to group self.lengths into
buckets and sum the per-bucket batch counts, adding (count + self.batch_size -
1) // self.batch_size (i.e. math.ceil(count / self.batch_size)) for each
bucket's sample count so the returned value equals the actual number of batches
across all buckets.
- Around line 554-556: The deprecated data module stores unused attributes
length_bucket, cache_policy, and cache_max_items; either remove these attributes
and any constructor parameters that set them, or wire them up so the class
behaves like PreprocessedDataModule by using length_bucket in setup() to group
examples into buckets, and applying cache_policy/cache_max_items in
train_dataloader() when building the in-memory or disk cache for preprocessed
samples; update the constructor, setup(), and train_dataloader() implementations
(referencing the same attribute names and method names) to ensure the attributes
are actually consumed or are removed to avoid dead state.
ℹ️ Review info
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
acestep/training/data_module.py
acestep/training_v2/cli/args.py
acestep/training_v2/cli/config_builder.py
acestep/training_v2/configs.py
acestep/training_v2/dataset_validation.py
acestep/training_v2/dataset_validation_test.py
acestep/training_v2/fixed_lora_module.py
acestep/training_v2/fixed_lora_module_full_mode_test.py
acestep/training_v2/trainer_basic_loop.py
acestep/training_v2/trainer_fixed.py
acestep/training_v2/trainer_helpers.py
docs/lora_full_finetune_pipeline_proposal.md
train.py
```python
def validate_preprocessed_dataset(dataset_dir: str) -> Dict[str, Any]:
    """Validate a preprocessed dataset directory and return summary stats."""
    root = Path(dataset_dir)
    files = sorted(p for p in root.glob("*.pt") if p.name != "manifest.json")
```
Missing dataset directory validation can produce false “valid” reports.
At Line 27-29, a non-existent path silently returns zero samples instead of failing, which can mask configuration mistakes.
Proposed fix
```diff
 def validate_preprocessed_dataset(dataset_dir: str) -> Dict[str, Any]:
     """Validate a preprocessed dataset directory and return summary stats."""
     root = Path(dataset_dir)
+    if not root.is_dir():
+        raise FileNotFoundError(f"Dataset directory not found: {root}")
     files = sorted(p for p in root.glob("*.pt") if p.name != "manifest.json")
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def validate_preprocessed_dataset(dataset_dir: str) -> Dict[str, Any]:
    """Validate a preprocessed dataset directory and return summary stats."""
    root = Path(dataset_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"Dataset directory not found: {root}")
    files = sorted(p for p in root.glob("*.pt") if p.name != "manifest.json")
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/training_v2/dataset_validation.py` around lines 25 - 29,
validate_preprocessed_dataset currently treats a non-existent or non-directory
path as valid by returning zero samples; after computing root =
Path(dataset_dir) add an explicit existence and directory check (e.g., if not
root.exists() or not root.is_dir()) and raise a clear exception
(FileNotFoundError or ValueError) so callers get an immediate, descriptive
failure instead of a false "valid" report; reference
validate_preprocessed_dataset, root, and files when making the change.
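The guard itself is small enough to demonstrate standalone. Assumption: only the fail-fast check is shown; `require_dataset_dir` is a hypothetical helper name, and the real validator does much more than this.

```python
from pathlib import Path

def require_dataset_dir(dataset_dir: str) -> Path:
    """Fail fast on a missing or non-directory dataset path."""
    root = Path(dataset_dir)
    if not root.is_dir():
        raise FileNotFoundError(f"Dataset directory not found: {root}")
    return root
```

Raising here turns a silent "0 samples, 0 invalid" report into an immediate, descriptive error for typo'd paths.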
```python
def __init__(self) -> None:
    super().__init__()
    self.decoder = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
    self.encoder = nn.Linear(4, 4)
    self.null_condition_emb = nn.Parameter(torch.zeros(1, 1, 4))
    self.config = type("Cfg", (), {})()
```
Add a docstring to _DummyModel.__init__.
The initializer is new but undocumented; please add a concise purpose/inputs docstring to keep module docs consistent.
As per coding guidelines: “Docstrings are mandatory for all new or modified Python modules, classes, and functions.”
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/training_v2/fixed_lora_module_full_mode_test.py` around lines 17 -
23, Add a concise docstring inside _DummyModel.__init__ describing the purpose
of the initializer and the attributes it sets (e.g., creates decoder:
nn.Sequential of two Linear(4,4), encoder: nn.Linear(4,4), null_condition_emb:
nn.Parameter of zeros with shape (1,1,4), and a simple config object), and note
there are no parameters to __init__; place the docstring as the first statement
in the __init__ method so it documents the constructor and its initialized
members.
```python
self.latent_lengths: List[int] = []
for vp in self.valid_paths:
    try:
        sample = torch.load(vp, map_location="cpu", weights_only=True)
        self.latent_lengths.append(int(sample["target_latents"].shape[0]))
    except Exception:
        self.latent_lengths.append(0)
```
Avoid blind `except Exception`; handle specific errors and log failures.
The current code silently catches all exceptions and defaults to length 0, which could mask real issues (corrupt files, permission errors) and cause unexpected bucketing behavior for failed samples.
🛡️ Proposed fix with specific exception handling
```diff
 self.latent_lengths: List[int] = []
 for vp in self.valid_paths:
     try:
         sample = torch.load(vp, map_location="cpu", weights_only=True)
         self.latent_lengths.append(int(sample["target_latents"].shape[0]))
-    except Exception:
+    except (OSError, KeyError, RuntimeError) as e:
+        logger.warning(f"Failed to read latent length from {vp}: {e}")
         self.latent_lengths.append(0)
```

🧰 Tools
self.latent_lengths.append(0)🧰 Tools
🪛 Ruff (0.15.2)
[warning] 127-127: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/training/data_module.py` around lines 122 - 128, The loop that builds
self.latent_lengths in DataModule (using self.valid_paths) currently uses a bare
except which masks errors; change it to catch specific exceptions when loading
torch files (e.g., FileNotFoundError, PermissionError, EOFError, RuntimeError)
and log the failure including the path and exception (use the module/class
logger or python logging), append 0 only for those expected/recoverable errors,
and let other unexpected exceptions propagate (or re-raise) so real issues
aren’t silently ignored; update the try/except around torch.load and the
sample["target_latents"] access accordingly and include the path (vp) and
exception text in the log message.
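The narrowed error handling can be sketched independently of torch by injecting the loader. Assumptions: `read_latent_lengths` and `load_fn` are illustrative names, and `load_fn` stands in for `torch.load`; the exception tuple mirrors the proposed fix.

```python
import logging

logger = logging.getLogger(__name__)

def read_latent_lengths(paths, load_fn):
    """Collect per-sample latent lengths; expected load failures map to 0."""
    lengths = []
    for vp in paths:
        try:
            sample = load_fn(vp)
            lengths.append(int(sample["target_latents"].shape[0]))
        except (OSError, KeyError, RuntimeError) as e:
            # recoverable: log with path and error, keep a placeholder length
            logger.warning("Failed to read latent length from %s: %s", vp, e)
            lengths.append(0)
    return lengths
```

Anything outside the named tuple (for example a `ValueError` from corrupted metadata handling) still propagates, so genuine bugs are not silently converted into zero-length buckets.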
```python
if self.length_bucket and hasattr(self.train_dataset, "latent_lengths"):
    kwargs.pop("batch_size", None)
    kwargs.pop("shuffle", None)
    kwargs["batch_sampler"] = BucketedBatchSampler(
        lengths=list(getattr(self.train_dataset, "latent_lengths", [])),
        batch_size=self.batch_size,
        shuffle=True,
    )
```
Bucketing silently fails when val_split > 0.
When validation split is enabled, self.train_dataset becomes a torch.utils.data.Subset, which lacks the latent_lengths attribute. The hasattr check passes silently, falling back to regular batching without any warning.
🐛 Proposed fix to access underlying dataset's latent_lengths
```diff
-if self.length_bucket and hasattr(self.train_dataset, "latent_lengths"):
+latent_lengths = None
+if self.length_bucket:
+    ds = self.train_dataset
+    # Handle Subset from random_split
+    if hasattr(ds, "dataset") and hasattr(ds, "indices"):
+        underlying = ds.dataset
+        if hasattr(underlying, "latent_lengths"):
+            latent_lengths = [underlying.latent_lengths[i] for i in ds.indices]
+    elif hasattr(ds, "latent_lengths"):
+        latent_lengths = ds.latent_lengths
+if latent_lengths is not None:
     kwargs.pop("batch_size", None)
     kwargs.pop("shuffle", None)
     kwargs["batch_sampler"] = BucketedBatchSampler(
-        lengths=list(getattr(self.train_dataset, "latent_lengths", [])),
+        lengths=latent_lengths,
         batch_size=self.batch_size,
         shuffle=True,
     )
```
Verify each finding against the current code and only fix it if needed.
In `@acestep/training/data_module.py` around lines 355 - 362, The bucketing code
fails silently when self.train_dataset is a torch.utils.data.Subset because
hasattr(self.train_dataset, "latent_lengths") is false for the Subset wrapper;
update the condition in the block that sets BucketedBatchSampler (symbols:
length_bucket, self.train_dataset, latent_lengths, BucketedBatchSampler) to look
through the Subset wrapper and obtain the underlying dataset's latent_lengths
(e.g., check getattr(self.train_dataset, "dataset", self.train_dataset) for
latent_lengths) before deciding to pop batch_size/shuffle and assign
batch_sampler, so the BucketedBatchSampler still receives the correct lengths
when val_split > 0.
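The unwrap logic in the proposed fix is easy to isolate and test without torch. This sketch uses a hypothetical helper name (`resolve_latent_lengths`) and plain stub objects in place of `torch.utils.data.Dataset`/`Subset`; it relies only on the `.dataset`/`.indices` attributes that `Subset` exposes.

```python
def resolve_latent_lengths(dataset):
    """Return per-sample lengths, looking through a Subset-like wrapper."""
    if hasattr(dataset, "dataset") and hasattr(dataset, "indices"):
        underlying = dataset.dataset
        if hasattr(underlying, "latent_lengths"):
            # remap lengths through the subset's index selection
            return [underlying.latent_lengths[i] for i in dataset.indices]
    return getattr(dataset, "latent_lengths", None)
```

A `None` return signals that bucketing cannot be applied, which is the explicit fallback the review asks for instead of a silent `hasattr` miss.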
```python
report = validate_preprocessed_dataset(args.dataset_dir)
print("\n" + "=" * 60)
print(" Dataset Validation")
print("=" * 60)
print(f" Dataset: {args.dataset_dir}")
print(f" Samples: {report['total_samples']}")
print(f" Valid: {report['valid_samples']}")
print(f" Invalid: {report['invalid_samples']}")
print(f" NaN/Inf: {report['nan_or_inf_samples']}")
print(f" Min latent T: {report['min_latent_length']}")
print(f" Max latent T: {report['max_latent_length']}")
print(f" Avg latent T: {report['avg_latent_length']:.2f}")
if report["errors"]:
    print("\n[WARN] Sample errors:")
    for err in report["errors"][:10]:
        print(f" - {err}")
print("=" * 60)
return 0 if report["invalid_samples"] == 0 else 1
```
validate-dataset can return success for empty datasets.
At Line 230, success is based only on invalid_samples == 0; an empty dataset (total_samples == 0) currently exits 0.
Proposed fix
```diff
-return 0 if report["invalid_samples"] == 0 else 1
+has_samples = report["total_samples"] > 0
+is_clean = report["invalid_samples"] == 0
+return 0 if (has_samples and is_clean) else 1
```
+ return 0 if (has_samples and is_clean) else 1🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@train.py` around lines 213 - 230, The current exit logic uses only
report["invalid_samples"] to decide success, allowing empty datasets to return
success; update the final return condition to require report["total_samples"] >
0 in addition to report["invalid_samples"] == 0 (i.e., return success only when
there is at least one sample and no invalid samples). Locate the block that
calls validate_preprocessed_dataset and references report (the print/report
summary block) and change the return expression to check both
report["total_samples"] > 0 and report["invalid_samples"] == 0; optionally add a
short warning/log when total_samples == 0 before returning failure.
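The tightened exit condition can be factored into a pure function for testing. A sketch, assuming `validation_exit_code` is a hypothetical helper and `report` is a plain dict with the keys printed above:

```python
def validation_exit_code(report: dict) -> int:
    """Exit 0 only when there is at least one sample and none are invalid."""
    has_samples = report["total_samples"] > 0
    is_clean = report["invalid_samples"] == 0
    return 0 if (has_samples and is_clean) else 1
```

An empty dataset now exits non-zero, so a mistyped `--dataset-dir` fails CI instead of passing vacuously.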
…pull-request-#1 Fix legacy AceStepDataModule init and full-mode resume fallback
…ora-training-pipeline Codex/propose improvements for lora training pipeline
Actionable comments posted: 3
🧹 Nitpick comments (2)
acestep/training/data_module_test.py (1)
196-207: Test is valid but a coverage gap exists for new features.

The regression test for legacy `AceStepDataModule` initialization is appropriate. However, the PR introduces `BucketedBatchSampler` and RAM LRU caching in `PreprocessedTensorDataset` without corresponding unit tests.

Consider adding tests for:
- `BucketedBatchSampler`: bucket grouping, shuffle behavior, batch size enforcement
- LRU caching: cache hits/misses, eviction when exceeding `cache_max_items`
- `latent_lengths` computation with valid/invalid tensor files

Do you want me to generate unit tests for `BucketedBatchSampler` and the caching behavior, or open an issue to track this task?

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training/data_module_test.py` around lines 196 - 207, Add unit tests covering the new BucketedBatchSampler and PreprocessedTensorDataset LRU caching and latent_lengths behavior: write tests that instantiate BucketedBatchSampler with a simple length list to verify bucket grouping semantics, that shuffle=True/False affects order but not batch sizes, and that batches respect the specified batch_size; add tests for PreprocessedTensorDataset that simulate loading small tensors to assert cache hit/miss counts and that items are evicted when exceeding cache_max_items (reference PreprocessedTensorDataset.cache_max_items and its internal LRU behavior), and add tests for latent_lengths computation that validate correct lengths for valid tensor files and that invalid/missing tensor files raise or are handled as expected (reference latent_lengths). Ensure tests import BucketedBatchSampler and PreprocessedTensorDataset from the training module and use deterministic seeds/mocks for file I/O to keep them deterministic.

acestep/training_v2/trainer_helpers.py (1)
307-425: Split `resume_checkpoint` into mode-specific helpers.

`resume_checkpoint` now bundles path normalization, mode dispatch, state restoration, and user messaging for full/LoKR/LoRA flows. Please extract per-mode resume helpers plus one shared `training_state.pt` loader to reduce branching drift and simplify maintenance.

As per coding guidelines "Function decomposition: do one thing at a time; if a function description naturally contains 'and', split it" and "Split functions by responsibility, not by convenience".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training_v2/trainer_helpers.py` around lines 307 - 425, Split resume_checkpoint into focused helpers: create a shared _load_training_state(ckpt_dir, module, optimizer, scheduler) that loads training_state.pt and returns (epoch, global_step, state) or None, then implement _resume_full_decoder(trainer, ckpt_dir, module, state_loader) to encapsulate the full-mode logic including calling _load_full_decoder_state and using _load_training_state, _resume_lokr(trainer, ckpt_dir, module, state_loader) to handle lokr_weights.safetensors loading and state restore (use load_lokr_weights and module.lycoris_net), and _resume_lora(trainer, ckpt_dir, module, optimizer, scheduler) to call load_training_checkpoint and load adapter_model.safetensors/.bin into module.model.decoder; finally simplify resume_checkpoint to do path normalization, detect mode and dispatch to these helpers and yield their TrainingUpdate/return values. Ensure helper names (_load_training_state, _resume_full_decoder, _resume_lokr, _resume_lora) are used so callers in resume_checkpoint are easy to find.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@acestep/training/data_module.py`:
- Around line 59-61: The __len__ method currently returns ceil(total_samples /
batch_size) which underestimates when samples are grouped into buckets in
__iter__; change __len__ in data_module.py to sum per-bucket batch counts
instead of using total alone — compute for each bucket: (len(bucket) +
self.batch_size - 1) // self.batch_size and return the sum. If buckets are
stored as self.buckets (lists of indices) iterate over that; if you only have a
per-sample bucket id array (e.g. self.bucket_ids) derive counts per bucket first
and then sum their ceilings; keep the method name __len__ and reference
self.batch_size exactly as in the diff.
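The per-bucket ceiling arithmetic this prompt asks for can be sanity-checked in isolation. Below is a minimal sketch — `batches_per_bucket` and the `buckets` structure are illustrative stand-ins, not the sampler's actual API:

```python
import math
from typing import List


def batches_per_bucket(buckets: List[List[int]], batch_size: int) -> int:
    """Sum per-bucket batch counts instead of using ceil(total / batch_size)."""
    return sum((len(bucket) + batch_size - 1) // batch_size for bucket in buckets)


# Three buckets of 5, 5 and 2 samples with batch_size=4: the naive
# ceil(12 / 4) = 3 underestimates the real batch count of 2 + 2 + 1 = 5.
buckets = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11]]
assert batches_per_bucket(buckets, batch_size=4) == 5
assert math.ceil(sum(len(b) for b in buckets) / 4) == 3  # the underestimate
```

Because batches never cross bucket boundaries, each partially filled bucket contributes one extra batch that the global ceiling misses.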
ℹ️ Review info
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
acestep/training/data_module.py
acestep/training/data_module_test.py
acestep/training_v2/trainer_helpers.py
acestep/training_v2/trainer_helpers_test.py
```python
def test_full_mode_missing_decoder_state_yields_warning(self) -> None:
    """Missing ``full_decoder_state.pt`` should warn and return ``None``."""
    trainer = SimpleNamespace(
        module=SimpleNamespace(device=torch.device("cpu"), lycoris_net=None),
        training_config=SimpleNamespace(training_mode="full"),
        adapter_type="lora",
    )
    optimizer = SimpleNamespace(load_state_dict=lambda state: None)
    scheduler = SimpleNamespace(load_state_dict=lambda state: None)

    with tempfile.TemporaryDirectory() as tmpdir:
        generator = trainer_helpers.resume_checkpoint(
            trainer,
            tmpdir,
            optimizer,
            scheduler,
        )
        updates = []
        try:
            while True:
                updates.append(next(generator))
        except StopIteration as stop:
            result = stop.value

    self.assertIsNone(result)
    self.assertEqual(len(updates), 1)
    self.assertEqual(updates[0].kind, "warn")
    self.assertIn("full_decoder_state.pt not found", updates[0].msg)
```
Add a full-mode success-path resume test.
This file currently tests only the missing-file warning path. Please add a success-path case that creates full_decoder_state.pt (and optionally training_state.pt) and verifies returned (epoch, step) plus optimizer/scheduler restore calls.
As per coding guidelines "Include at least: one success-path test, one regression/edge-case test for the bug being fixed, and one non-target behavior check when relevant".
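The generator-draining pattern such a success-path test relies on — capturing a generator's return value via `StopIteration.value` — can be sketched independently of the real `resume_checkpoint`; `fake_resume` below is a hypothetical stand-in:

```python
from typing import Generator, List, Tuple


def fake_resume() -> Generator[str, None, Tuple[int, int]]:
    """Hypothetical stand-in for resume_checkpoint: yields updates, returns (epoch, step)."""
    yield "[OK] weights loaded"
    yield "[OK] training state restored"
    return (3, 1200)


def drain(gen) -> Tuple[List[str], Tuple[int, int]]:
    """Collect every yielded update plus the generator's return value."""
    updates = []
    try:
        while True:
            updates.append(next(gen))
    except StopIteration as stop:
        # A generator's `return X` surfaces as StopIteration(value=X).
        return updates, stop.value


updates, result = drain(fake_resume())
assert result == (3, 1200)
assert all(u.startswith("[OK]") for u in updates)
```

A success-path test would drain the real generator the same way and then assert on both the collected updates and the returned `(epoch, step)` tuple.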
🧰 Tools
🪛 Ruff (0.15.2)
[warning] 24-24: Unused lambda argument: state
(ARG005)
[warning] 25-25: Unused lambda argument: state
(ARG005)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/training_v2/trainer_helpers_test.py` around lines 17 - 45, Add a new
test in trainer_helpers_test.py that exercises the success path of
trainer_helpers.resume_checkpoint: create a TemporaryDirectory and write a valid
full_decoder_state.pt (and optionally training_state.pt) file(s) with a saved
dict containing epoch and global_step, call resume_checkpoint(trainer, tmpdir,
optimizer, scheduler) using the existing SimpleNamespace
trainer/optimizer/scheduler, iterate the returned generator to completion and
assert the final result is (epoch, step), assert optimizer.load_state_dict and
scheduler.load_state_dict were invoked with the saved state dicts, and assert a
non-warning update was emitted; reference resume_checkpoint,
full_decoder_state.pt, training_state.pt, optimizer.load_state_dict, and
scheduler.load_state_dict to locate the code to test.
```python
optimizer = SimpleNamespace(load_state_dict=lambda state: None)
scheduler = SimpleNamespace(load_state_dict=lambda state: None)
```
Silence Ruff ARG005 for test doubles.
Use underscore-prefixed lambda arguments for intentionally unused parameters.
Proposed lint-only fix
- optimizer = SimpleNamespace(load_state_dict=lambda state: None)
- scheduler = SimpleNamespace(load_state_dict=lambda state: None)
+ optimizer = SimpleNamespace(load_state_dict=lambda _state: None)
+ scheduler = SimpleNamespace(load_state_dict=lambda _state: None)

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
optimizer = SimpleNamespace(load_state_dict=lambda _state: None)
scheduler = SimpleNamespace(load_state_dict=lambda _state: None)
```
🧰 Tools
🪛 Ruff (0.15.2)
[warning] 24-24: Unused lambda argument: state
(ARG005)
[warning] 25-25: Unused lambda argument: state
(ARG005)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/training_v2/trainer_helpers_test.py` around lines 24 - 25, The test
uses SimpleNamespace with lambdas that accept an unused parameter (optimizer =
SimpleNamespace(load_state_dict=lambda state: None) and scheduler =
SimpleNamespace(load_state_dict=lambda state: None)), which triggers Ruff
ARG005; change the lambda parameter name to be underscore-prefixed (e.g., lambda
_state: None) or use a splat (lambda *_: None) so the unused argument is
explicitly acknowledged and ARG005 is silenced for the test doubles in
trainer_helpers_test.py.
…into codex/fix-high-priority-bugs-in-pull-request-#1-fvat6p
…pull-request-#1-fvat6p Codex/fix high priority bugs in pull request #1 fvat6p
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
acestep/training/data_module.py (1)
84-90: ⚠️ Potential issue | 🟡 Minor — Update `__init__` docstrings for newly added parameters.
Both modified constructors added cache/bucketing args, but their docstrings still document only the previous parameter set.
📝 Proposed docstring updates
```diff
 def __init__(self, tensor_dir: str, cache_policy: str = "none", cache_max_items: int = 0):
     """Initialize from a directory of preprocessed .pt files.

     Args:
         tensor_dir: Directory containing preprocessed .pt files and manifest.json
+        cache_policy: Cache mode ("none" or "ram_lru")
+        cache_max_items: Maximum number of cached items when using RAM LRU
```

```diff
 def __init__(
     self,
     tensor_dir: str,
     @@
 ):
     """Initialize the data module.
     @@
     Args:
         tensor_dir: Directory containing preprocessed .pt files
         batch_size: Training batch size
         num_workers: Number of data loading workers
         pin_memory: Whether to pin memory for faster GPU transfer
         val_split: Fraction of data for validation (0 = no validation)
+        length_bucket: Whether to bucket training samples by latent length
+        cache_policy: Dataset cache mode ("none" or "ram_lru")
+        cache_max_items: Maximum number of cached items when using RAM LRU
     """
```

As per coding guidelines: "Docstrings are mandatory for all new or modified Python modules, classes, and functions" and "Docstrings must be concise and include purpose plus key inputs/outputs (and raised exceptions when relevant)".
Also applies to: 296-304
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training/data_module.py` around lines 84 - 90, The __init__ docstring for the constructor (method __init__ in data_module.py) was not updated to document the newly added parameters; update the docstring for __init__ to briefly state the purpose of the class, list and describe all parameters including the new cache_policy (str, allowed values like "none"/"lru"/etc. and default "none") and cache_max_items (int, default 0 and effect on caching), and mention any raised exceptions; make the same concise updates to the other modified constructor docstring referenced around lines 296-304 (identify the other __init__ there) so both constructors include purpose plus key inputs/outputs and raised exceptions per project docstring guidelines.
♻️ Duplicate comments (3)
acestep/training_v2/trainer_helpers.py (1)
425-438: ⚠️ Potential issue | 🔴 Critical — Restore LoKR resume success path and remove unreachable full-mode branch.
After Line 425, a successful `training_state` load has no return path, so `_resume_lokr` implicitly returns `None` and `_resume_lokr_or_lora` falls through to LoRA resume. Also, the full-mode check at Line 431 is unreachable because `resume_checkpoint` already routes full mode to `_resume_full_decoder`.
💡 Proposed fix
```diff
     training_state = state_loader(ckpt_dir=ckpt_dir)
     if training_state is None:
         yield TrainingUpdate(0, 0.0, "[OK] LoKR weights loaded (no training state)", kind="info")
         return None
-    # Full mode resume: no full decoder checkpoint found.
-    if getattr(trainer.training_config, "training_mode", "adapter") == "full":
-        yield TrainingUpdate(
-            0,
-            0.0,
-            f"[WARN] full_decoder_state.pt not found in {ckpt_dir}",
-            kind="warn",
-        )
-        return
+    epoch, step, _ = training_state
+    yield TrainingUpdate(
+        0,
+        0.0,
+        f"[OK] Resumed LoKR from epoch {epoch}, step {step}",
+        kind="info",
+    )
+    return (epoch, step)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training_v2/trainer_helpers.py` around lines 425 - 438, The success path after calling state_loader (training_state) currently yields a TrainingUpdate but then falls through and returns None implicitly, causing _resume_lokr to not signal a completed LoKR resume and letting _resume_lokr_or_lora continue to LoRA resume; also the subsequent full-mode check (getattr(trainer.training_config, "training_mode", "adapter") == "full") is unreachable because resume_checkpoint routes full mode to _resume_full_decoder. Fix by making the training_state success path explicitly return (e.g., return a truthy result or stop further processing) so _resume_lokr exits after yielding the "[OK] LoKR weights loaded" update, and remove the unreachable full-mode branch block (or move full-mode handling into the correct resume_checkpoint flow); refer to state_loader, training_state, _resume_lokr, _resume_lokr_or_lora, resume_checkpoint, _resume_full_decoder, and trainer.training_config to locate and update the logic.
acestep/training/data_module.py (2)
131-135: ⚠️ Potential issue | 🟡 Minor — Narrow exception handling when deriving `latent_lengths`.
`except Exception` hides corrupt sample/schema issues and silently forces length 0, which can skew bucketing. Catch expected load/key errors, log path + error, and let unexpected exceptions surface.
🔧 Proposed fix
```diff
 for vp in self.valid_paths:
     try:
         sample = torch.load(vp, map_location="cpu", weights_only=True)
         self.latent_lengths.append(int(sample["target_latents"].shape[0]))
-    except Exception:
+    except (
+        FileNotFoundError,
+        PermissionError,
+        EOFError,
+        OSError,
+        KeyError,
+        RuntimeError,
+    ) as exc:
+        logger.warning(f"Failed to read latent length from {vp}: {exc}")
         self.latent_lengths.append(0)
```

Based on learnings: Applies to `**/*.py`: Handle errors explicitly in Python; avoid bare except.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training/data_module.py` around lines 131 - 135, The current broad except around torch.load in the latent length derivation swallows all errors and appends 0; change it to catch only expected errors (e.g., RuntimeError, EOFError, KeyError) when loading/parsing the sample returned by torch.load and on those cases append 0 but also log the file path (vp) and the exception via the module or class logger (e.g., self.logger or logging.getLogger) so the failure is visible; for any other unexpected exceptions re-raise them so they surface instead of being hidden. Ensure this change is applied to the torch.load + self.latent_lengths.append(...) block so only known load/key issues are handled gracefully while other errors propagate.
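The shape of the narrowed handler is independent of torch. Here is a minimal sketch with plain file I/O — `read_length` and the logger setup are illustrative, not the module's actual code:

```python
import logging

logger = logging.getLogger(__name__)


def read_length(path: str) -> int:
    """Read an integer length from a text file; fall back to 0 only on expected errors."""
    try:
        with open(path, encoding="utf-8") as fh:
            return int(fh.read().strip())
    except (FileNotFoundError, PermissionError, OSError, ValueError) as exc:
        # Log the offending path and error instead of swallowing the failure silently.
        logger.warning("Failed to read length from %s: %s", path, exc)
        return 0
    # Anything unexpected (e.g. MemoryError, KeyboardInterrupt) propagates to the caller.


assert read_length("/nonexistent/sample.txt") == 0
```

The key point is that the fallback value is still produced for known failure modes, but every fallback now leaves a log trail, and genuinely unexpected errors are no longer masked.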
362-369: ⚠️ Potential issue | 🟠 Major — Bucketing is skipped when `train_dataset` is a `Subset` (val_split > 0).
After `random_split`, `self.train_dataset` is a `torch.utils.data.Subset`, so the current `hasattr(..., "latent_lengths")` check fails and bucketing is not used (with `shuffle=False` in this mode). Resolve latent lengths from the wrapped dataset + indices.
🐛 Proposed fix
```diff
-        if self.length_bucket and hasattr(self.train_dataset, "latent_lengths"):
+        latent_lengths: Optional[List[int]] = None
+        if self.length_bucket:
+            ds = self.train_dataset
+            if (
+                isinstance(ds, torch.utils.data.Subset)
+                and hasattr(ds.dataset, "latent_lengths")
+            ):
+                latent_lengths = [ds.dataset.latent_lengths[i] for i in ds.indices]
+            elif hasattr(ds, "latent_lengths"):
+                latent_lengths = list(getattr(ds, "latent_lengths", []))
+
+        if latent_lengths is not None:
             kwargs.pop("batch_size", None)
             kwargs.pop("shuffle", None)
             kwargs["batch_sampler"] = BucketedBatchSampler(
-                lengths=list(getattr(self.train_dataset, "latent_lengths", [])),
+                lengths=latent_lengths,
                 batch_size=self.batch_size,
                 shuffle=True,
             )
+        elif self.length_bucket:
+            kwargs["shuffle"] = True
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training/data_module.py` around lines 362 - 369, The bucketing check skips when train_dataset is a torch.utils.data.Subset because hasattr(self.train_dataset, "latent_lengths") is false; update the logic in the DataModule where BucketedBatchSampler is created to detect a Subset (torch.utils.data.Subset) and extract latent_lengths from the wrapped dataset using the Subset.indices (e.g., resolve base = self.train_dataset.dataset and indices = self.train_dataset.indices), build lengths = [getattr(base, "latent_lengths")[i] for i in indices] (falling back to getattr(..., []) if missing), then pass that lengths list into BucketedBatchSampler; keep the existing batch_size/shuffle pop behavior and the BucketedBatchSampler instantiation name to locate the change.
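The Subset-unwrapping logic can be exercised without torch at all; `FakeDataset` and `FakeSubset` below are hypothetical stand-ins that mimic the `.dataset`/`.indices` attributes of `torch.utils.data.Subset`:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeDataset:
    """Stand-in for a dataset exposing latent_lengths."""
    latent_lengths: List[int]


@dataclass
class FakeSubset:
    """Mimics torch.utils.data.Subset's .dataset / .indices attributes."""
    dataset: FakeDataset
    indices: List[int]


def resolve_lengths(ds) -> Optional[List[int]]:
    """Return latent lengths for a dataset or a Subset-like wrapper, else None."""
    if hasattr(ds, "dataset") and hasattr(ds, "indices"):
        base = ds.dataset
        if hasattr(base, "latent_lengths"):
            return [base.latent_lengths[i] for i in ds.indices]
        return None
    if hasattr(ds, "latent_lengths"):
        return list(ds.latent_lengths)
    return None


full = FakeDataset(latent_lengths=[10, 20, 30, 40])
subset = FakeSubset(dataset=full, indices=[3, 1])
assert resolve_lengths(full) == [10, 20, 30, 40]
assert resolve_lengths(subset) == [40, 20]  # lengths follow the subset's indices
assert resolve_lengths(object()) is None
```

The important invariant is that the returned lengths are reindexed through `Subset.indices`, so sampler positions line up with the subset's items rather than the full dataset's.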
🧹 Nitpick comments (2)
acestep/training_v2/trainer_helpers.py (1)
137-507: Split checkpoint save/load/resume paths into smaller modules. This file now combines many responsibilities (save, verify, load, resume orchestration) and exceeds the module hard cap, which makes future changes riskier.
Based on learnings: raise module-size concerns only when a Python file exceeds 200 LOC; this file is above that threshold. As per coding guidelines:
`**/*.py`: Target module size: optimal <= 150 LOC, hard cap 200 LOC.
acestep/training/data_module.py (1)
36-391: This module exceeds the Python hard cap and now mixes too many responsibilities.
`data_module.py` is well beyond 200 LOC and now contains sampler logic, multiple datasets, and multiple DataModules in one place. Please split this into focused modules to reduce maintenance risk. As per coding guidelines:
`**/*.py`: "Target module size: optimal <= 150 LOC, hard cap 200 LOC".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training/data_module.py` around lines 36 - 391, This file is too large and mixes responsibilities; split it into focused modules: move BucketedBatchSampler into a new sampler module (e.g., sampler.py) and export the class there; move PreprocessedTensorDataset and its helper _resolve_manifest_path into a dataset module (e.g., preprocessed_dataset.py); move collate_preprocessed_batch into a collate module (e.g., collate.py); keep PreprocessedDataModule in its own data_module.py that imports BucketedBatchSampler, PreprocessedTensorDataset, and collate_preprocessed_batch; update imports in PreprocessedDataModule to reference the new modules and ensure any shared utilities (safe_path, logger) are imported from their original locations. Ensure behavior and public APIs (class/function names and signatures: BucketedBatchSampler, PreprocessedTensorDataset, collate_preprocessed_batch, PreprocessedDataModule) remain unchanged.
ℹ️ Review info
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
acestep/training/data_module.py
acestep/training/data_module_test.py
acestep/training_v2/trainer_helpers.py
```python
    module: Any,
    optimizer: Any,
    scheduler: Any,
) -> Optional[Tuple[int, int, Dict[str, Any]]]:
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Find the file
file_path=$(find . -name "trainer_helpers.py" -path "*/training_v2/*" 2>/dev/null | head -1)
if [ -z "$file_path" ]; then
echo "File not found"
exit 1
fi
echo "=== File found at: $file_path ==="
echo ""
# Show line 352 and surrounding context
echo "=== Context around line 352 ==="
sed -n '340,360p' "$file_path" | cat -n
echo ""
echo "=== Checking imports at top of file ==="
head -40 "$file_path" | grep -E "^import|^from" || echo "No imports found in first 40 lines"
echo ""
echo "=== Running verification script ==="
python - <<'PY'
import ast
from pathlib import Path
path = next(Path(".").rglob("trainer_helpers.py"))
tree = ast.parse(path.read_text(encoding="utf-8"))
typing_imports = set()
for node in tree.body:
if isinstance(node, ast.ImportFrom) and node.module == "typing":
typing_imports.update(alias.name for alias in node.names)
dict_refs = []
for node in ast.walk(tree):
if isinstance(node, ast.Name) and node.id == "Dict":
dict_refs.append((node.lineno, 'Dict' in typing_imports))
print(f"Found {len(dict_refs)} reference(s) to 'Dict'")
print(f"Typing imports: {typing_imports}")
for line_no, is_imported in dict_refs:
print(f" Line {line_no}: imported_from_typing={is_imported}")
PY

Repository: ace-step/ACE-Step-1.5
Length of output: 1470
Use built-in dict generic for type annotation.
Line 352 references Dict which is not imported from typing, triggering F821. Since the project targets Python 3.11-3.12, use the built-in dict[str, Any] syntax instead:
Proposed fix
-) -> Optional[Tuple[int, int, Dict[str, Any]]]:
+) -> Optional[Tuple[int, int, dict[str, Any]]]:

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
) -> Optional[Tuple[int, int, dict[str, Any]]]:
```
🧰 Tools
🪛 Ruff (0.15.2)
[error] 352-352: Undefined name Dict
(F821)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/training_v2/trainer_helpers.py` at line 352, The type annotation on
the function returning Optional[Tuple[int, int, Dict[str, Any]]] uses
typing.Dict which isn't imported; change the annotation to use the built-in
generic syntax dict[str, Any] (i.e., Optional[Tuple[int, int, dict[str, Any]]])
in the function signature where that return type appears (locate the signature
around the function named in trainer_helpers.py that currently returns
Optional[Tuple[int, int, Dict[str, Any]]]); also remove any unused import of
Dict from typing if present to avoid F401.
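The built-in generic works in annotations on Python 3.9+ without any `typing.Dict` import, which is why the fix drops `Dict` entirely. A toy sketch (function and values are illustrative only):

```python
from typing import Any, Optional, Tuple


def load_state(found: bool) -> Optional[Tuple[int, int, dict[str, Any]]]:
    """Built-in dict[str, Any] needs no `from typing import Dict` on Python 3.9+."""
    if not found:
        return None
    # (epoch, global_step, extra state) -- purely illustrative values.
    return (1, 100, {"optimizer": "adamw"})


assert load_state(True) == (1, 100, {"optimizer": "adamw"})
assert load_state(False) is None
```

Using `dict[str, Any]` (PEP 585) also sidesteps the F821/F401 pair entirely: nothing to import, nothing left unused.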
```python
decoder = module.model.decoder
if hasattr(decoder, "_forward_module"):
    decoder = decoder._forward_module
decoder.load_state_dict(state_dict, strict=False)
```
Load LoRA weights into the fully unwrapped decoder.
At Line 472, targeting module.model.decoder with only one _forward_module unwrap is brittle and can silently skip keys under other wrappers. Use _unwrap_decoder(module.model) for consistency with save/load helpers.
💡 Proposed fix
- decoder = module.model.decoder
- if hasattr(decoder, "_forward_module"):
- decoder = decoder._forward_module
+ decoder = _unwrap_decoder(module.model)
  decoder.load_state_dict(state_dict, strict=False)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/training_v2/trainer_helpers.py` around lines 472 - 475, The current
code unwraps the decoder by checking only one _forward_module (decoder =
module.model.decoder; if hasattr(decoder, "_forward_module"): decoder =
decoder._forward_module) which is brittle; replace that logic by calling the
existing helper _unwrap_decoder(module.model) to obtain the fully unwrapped
decoder before calling decoder.load_state_dict(state_dict, strict=False), i.e.,
locate the block around decoder.load_state_dict and use
_unwrap_decoder(module.model) to get the decoder so LoRA keys under nested
wrappers are correctly loaded.
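The idea behind a shared unwrap helper can be sketched with dummy wrapper objects — the class names and `unwrap_decoder` below are hypothetical stand-ins, not the actual `_unwrap_decoder` in `trainer_helpers.py`:

```python
class Decoder:
    """Stand-in for the real decoder module."""


class Wrapper:
    """Stand-in for a compile/precision wrapper exposing _forward_module."""

    def __init__(self, inner):
        self._forward_module = inner


def unwrap_decoder(module):
    """Peel every _forward_module layer, not just the first one."""
    while hasattr(module, "_forward_module"):
        module = module._forward_module
    return module


decoder = Decoder()
assert unwrap_decoder(Wrapper(Wrapper(decoder))) is decoder
assert unwrap_decoder(decoder) is decoder  # no-op when already unwrapped
```

A single `if hasattr(...)` check only peels one layer; the loop makes the helper robust to arbitrarily nested wrappers, which is the failure mode the review comment describes.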
```python
class AceStepDataModuleInitTests(unittest.TestCase):
    """Regression tests for legacy ``AceStepDataModule`` initialization."""

    def test_init_does_not_require_preprocessed_only_cache_args(self):
        """Legacy raw-audio datamodule should initialize without NameError."""
        from acestep.training.data_module import AceStepDataModule

        module = AceStepDataModule(samples=[], dit_handler=object())

        self.assertEqual(module.samples, [])
        self.assertIsNotNone(module.dit_handler)
```
Coverage is missing for the new bucketing/cache behaviors introduced by this PR.
This addition only validates legacy initialization; it does not test the new length_bucket and ram_lru paths (including split-dataset behavior). Please add deterministic tests for those new code paths.
As per coding guidelines: **/*_test.py: "Add or update tests for every behavior change and bug fix" and "Include at least: one success-path test, one regression/edge-case test for the bug being fixed, and one non-target behavior check when relevant".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/training/data_module_test.py` around lines 201 - 212, Add
deterministic tests in acestep/training/data_module_test.py covering the new
bucketing/cache behavior for AceStepDataModule: create tests that instantiate
AceStepDataModule with length_bucket enabled to verify samples are assigned to
expected buckets (success path), a regression/edge-case that uses ram_lru
caching with a tiny capacity to confirm eviction/split-dataset behavior occurs
deterministically, and a non-target behavior check ensuring legacy
initialization (samples=[], dit_handler=...) still works when these options are
present; reference the AceStepDataModule constructor and its length_bucket and
ram_lru configuration flags/symbols to locate where to set parameters and assert
bucket assignments, eviction counts, and that dit_handler remains set.
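As a concrete starting point for such deterministic tests, here is a framework-free sketch of the bucket-assignment property one could assert; the toy `bucket_indices` helper below is hypothetical and stands in for the real `length_bucket` grouping inside `AceStepDataModule`.

```python
import random
from collections import defaultdict


def bucket_indices(lengths, bucket_size, seed):
    """Toy stand-in for length bucketing: group sample indices by
    length // bucket_size, then shuffle deterministically per bucket."""
    buckets = defaultdict(list)
    for idx, length in enumerate(lengths):
        buckets[length // bucket_size].append(idx)
    rng = random.Random(seed)
    for key in sorted(buckets):
        rng.shuffle(buckets[key])
    return dict(buckets)


# Same seed yields the same assignment, so tests can assert exact buckets.
assignment = bucket_indices([10, 11, 30, 31, 10], bucket_size=16, seed=0)
```

A test against the real data module would make the analogous assertions: identical bucket contents across two constructions with the same seed, and expected membership per bucket.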
…ora-training-pipeline Add bucketed/cached preprocessed dataset, dataset validator, and full fine-tune mode
Actionable comments posted: 4
♻️ Duplicate comments (2)
acestep/training/data_module.py (2)
129-135: ⚠️ Potential issue | 🟠 Major
Avoid bare except Exception in latent-length loading.
This swallows unexpected failures and silently maps them to length 0, which can mask real data corruption or schema issues.
🔧 Proposed fix
  self.latent_lengths: List[int] = []
  for vp in self.valid_paths:
      try:
          sample = torch.load(vp, map_location="cpu", weights_only=True)
          self.latent_lengths.append(int(sample["target_latents"].shape[0]))
-     except Exception:
+     except (FileNotFoundError, PermissionError, EOFError, OSError, KeyError, RuntimeError) as exc:
+         logger.warning(f"Failed to read latent length from {vp}: {exc}")
          self.latent_lengths.append(0)
#!/bin/bash
rg -nP '^\s*except\s+Exception\s*:' --type=py
Based on learnings: Applies to **/*.py : Handle errors explicitly in Python; avoid bare except.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training/data_module.py` around lines 129 - 135, The bare "except Exception" in the loop that loads samples for self.valid_paths (where you call torch.load and access sample["target_latents"] to populate self.latent_lengths) should be replaced with explicit error handling: catch only the expected exceptions (e.g., FileNotFoundError, RuntimeError from torch.load, KeyError when "target_latents" is missing) and handle each case (log the error with context including the path and exception, append 0 only for recoverable/known issues), while letting unexpected exceptions propagate (or re-raise) so real bugs or schema changes are not silently swallowed.
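As an illustration of the narrow-exception pattern, the sketch below uses `pickle` as a stand-in for `torch.load` so it runs without torch; the exception tuple mirrors the one proposed above, with `pickle.UnpicklingError` substituting for torch's `RuntimeError`, and the `target_latents_len` key is invented for the demo.

```python
import io
import logging
import pickle

logger = logging.getLogger(__name__)


def read_latent_length(raw: bytes) -> int:
    """Return the stored length, or 0 on known, recoverable failures.

    Demo stand-in: pickle.load replaces torch.load, and the
    "target_latents_len" key is hypothetical.
    """
    try:
        sample = pickle.load(io.BytesIO(raw))
        return int(sample["target_latents_len"])
    except (OSError, EOFError, KeyError, pickle.UnpicklingError) as exc:
        # Only expected failure modes are caught; real bugs still propagate.
        logger.warning("Failed to read latent length: %s", exc)
        return 0
```

Unexpected exceptions (say, a `TypeError` from a schema change) now surface immediately instead of being silently mapped to length 0.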
375-392: ⚠️ Potential issue | 🟠 Major
Bucketing still bypasses Subset and the new helper is unused.
_resolve_train_latent_lengths() was added, but train_dataloader() still checks hasattr(self.train_dataset, "latent_lengths"). With random_split, this skips bucketing and may also leave training unshuffled when length_bucket=True.
🔧 Proposed fix
  kwargs = dict(
      dataset=self.train_dataset,
      batch_size=self.batch_size,
-     shuffle=not self.length_bucket,
+     shuffle=True,
      num_workers=self.num_workers,
      pin_memory=self.pin_memory,
      collate_fn=collate_preprocessed_batch,
      drop_last=False,
      prefetch_factor=prefetch_factor,
      persistent_workers=persistent_workers,
  )
  if self.pin_memory_device:
      kwargs["pin_memory_device"] = self.pin_memory_device
- if self.length_bucket and hasattr(self.train_dataset, "latent_lengths"):
+ latent_lengths = self._resolve_train_latent_lengths()
+ if latent_lengths is not None:
      kwargs.pop("batch_size", None)
      kwargs.pop("shuffle", None)
      kwargs["batch_sampler"] = BucketedBatchSampler(
-         lengths=list(getattr(self.train_dataset, "latent_lengths", [])),
+         lengths=latent_lengths,
          batch_size=self.batch_size,
          shuffle=True,
      )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training/data_module.py` around lines 375 - 392, train_dataloader currently checks hasattr(self.train_dataset, "latent_lengths") and thus skips bucketing when train_dataset is a Subset (e.g., from random_split) and the new helper _resolve_train_latent_lengths() is unused; update train_dataloader to call _resolve_train_latent_lengths() (or otherwise obtain latent lengths via that helper) and use its result to decide whether to set batch_sampler to BucketedBatchSampler and remove batch_size/shuffle, ensuring Subset-wrapped datasets are handled and shuffling is preserved when length_bucket=True; refer to symbols train_dataloader, _resolve_train_latent_lengths, self.train_dataset, length_bucket, BucketedBatchSampler, and Subset when making the change.
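For the Subset case specifically, a resolver might walk wrapper layers like the sketch below; this is a hypothetical illustration of the idea behind `_resolve_train_latent_lengths()` (the real helper's implementation is not shown in this review), relying only on the `.dataset`/`.indices` attributes that `torch.utils.data.Subset` exposes.

```python
def resolve_latent_lengths(dataset):
    """Follow Subset-like wrappers (.dataset/.indices) down to a dataset
    exposing latent_lengths, composing index maps along the way."""
    indices = None
    while hasattr(dataset, "dataset") and hasattr(dataset, "indices"):
        layer = list(dataset.indices)
        # Compose: output position j maps through each wrapper layer.
        indices = layer if indices is None else [layer[j] for j in indices]
        dataset = dataset.dataset
    lengths = getattr(dataset, "latent_lengths", None)
    if lengths is None:
        return None
    return list(lengths) if indices is None else [lengths[i] for i in indices]
```

Returning `None` when no `latent_lengths` is found lets the dataloader fall back to plain shuffled batching.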
🧹 Nitpick comments (9)
docs/lora_full_finetune_pipeline_proposal.md (7)
34-40: Document bucketing's impact on training dynamics.
While latent-length bucketing improves throughput by reducing padding waste, it can affect training dynamics if samples are not properly shuffled within and across buckets. The proposal mentions "deterministic shuffling by bucket + epoch seed," which is good.
Consider adding a note about:
- Validating that bucketing doesn't introduce unintended correlations in training order
- Including bucket-aware shuffle validation in the QA command (A6)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/lora_full_finetune_pipeline_proposal.md` around lines 34 - 40, Add a short note under A2 (Add duration/latent-length bucketing in dataloader) that describes potential training-dynamics risks from bucketing (e.g., unintended correlations from grouping by length) and prescribe validation steps: (1) run epoch-level bucket-aware shuffle checks to ensure samples are randomly interleaved across epochs and within buckets, (2) compare training metrics (loss/accuracy/validation curves) with and without bucketing to detect shifts, and (3) include a bullet in QA command A6 to execute these bucket-aware shuffle validations and metric comparisons as part of the QA checklist.
133-144: Consider clarifying measurement methodology for acceptance criteria.
The acceptance criteria are well-defined with specific targets, but some measurement methodologies could be clarified for reproducibility.
Consider specifying:
- How "data-loader stall time" will be measured (profiling tool, specific metrics)
- What constitutes "supported GPU setup" for full FT testing (minimum VRAM, specific models)
- How to validate safety checks (e.g., test cases with known OOM-prone configs)
This would make the criteria more actionable during validation.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/lora_full_finetune_pipeline_proposal.md` around lines 133 - 144, Update the Acceptance criteria paragraph to specify exact measurement methods: define "data-loader stall time" measurement using a profiler (e.g., PyTorch profiler or perf metrics) and the exact metric (percent time waiting vs total batch time) and sampling window; define "step time" measurement method (median of N runs with fixed seed and dataset shard, include warmup steps and measurement steps); enumerate what "supported GPU setup" means (minimum VRAM, example device models such as NVIDIA A100/RTX3090, CUDA/cuDNN versions) and the test protocol for "Successfully trains decoder-only full FT" (one epoch, fixed batch size, seed, and dataset subset); and add explicit validation for "CLI safety checks" by listing specific OOM-prone config test cases and expected CLI responses (warnings/failures) so acceptance tests are reproducible.
20-32: Consider adding migration tooling to the implementation plan.
While backward compatibility and auto-detection are mentioned, the proposal doesn't explicitly address tooling to migrate existing datasets from the per-sample .pt format to the new v3 sharded format. Users with large existing datasets will need a conversion utility.
Consider adding to the implementation plan:
- A dataset migration command (e.g., convert-dataset --input-format pt --output-format v3_sharded)
- Documentation on when migration is worthwhile vs. keeping existing format
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/lora_full_finetune_pipeline_proposal.md` around lines 20 - 32, Add migration tooling and docs for converting per-sample `.pt` datasets to the new v3 sharded format: implement a CLI command (e.g., convert-dataset or convert_dataset) that accepts flags matching the loader options (--dataset-format {pt,v3_sharded}, input/output paths) to produce `shard-xxxxx.safetensors` and the corresponding `index.jsonl`, include progress/error reporting and verification (checksums/sample counts), and update the proposal and docs with guidance on when to migrate vs keep the `.pt` format and examples of using the migration command.
99-101: Specify VRAM estimation methodology for accuracy.
The preflight VRAM estimator is a critical safety feature, but the proposal doesn't specify the estimation methodology. Inaccurate estimates could lead to user distrust or unexpected OOM failures.
Consider documenting:
- Whether estimation will be empirical (formula-based) or profile-based (dry-run with dummy tensors)
- How to account for framework overhead, optimizer states, and gradient accumulation
- Acceptable margin of error and how to communicate uncertainty to users
- Fallback behavior when estimation confidence is low
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/lora_full_finetune_pipeline_proposal.md` around lines 99 - 101, Update the proposal to specify the VRAM estimation methodology for the "preflight VRAM estimator" referenced alongside the opt-in flags (`--training-mode full --i-understand-vram-risk`) and the "Auto-suggest fallback to LoRA" behavior: state whether the estimator is formula-based (analytical) or profile-based (dry-run with dummy tensors), describe how you will account for framework/runtime overhead, optimizer/AMP/optimizer state sizes and gradient accumulation, define an acceptable margin-of-error/confidence interval and how that uncertainty is surfaced to users, and specify explicit fallback behavior and messaging when confidence is low (e.g., suggest LoRA, require extra confirmation, or perform an on-demand dry-run).
147-151: Minor style consideration: repetitive sentence structure.
The static analysis tool noted that lines 148-151 all begin with "Add," which creates repetitive sentence structure. While this is clear and easy to scan, you could optionally vary the structure for better readability.
Example variations:
- "Implement DatasetBackend abstraction..."
- "Introduce BucketedBatchSampler..."
- "Extend train.py with validate-dataset command..."
However, the current structure is perfectly functional for a technical proposal, so this is purely a stylistic consideration.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/lora_full_finetune_pipeline_proposal.md` around lines 147 - 151, The four bullet items in section "Immediate next patch candidates" use the same leading verb "Add", creating repetitive sentence structure; update each line to vary the verb while preserving meaning and the referenced symbols (e.g., change "Add `DatasetBackend` abstraction and a sharded backend implementation." to "Implement `DatasetBackend` abstraction with a sharded backend implementation.", change "Add `BucketedBatchSampler` keyed by latent length." to "Introduce `BucketedBatchSampler` keyed by latent length.", change "Add `train.py validate-dataset --dataset-dir ...` command." to "Extend `train.py` with a `validate-dataset --dataset-dir ...` command.", and change "Add `training_mode` enum + `full` branch in config/trainer with decoder-only unfreeze." to "Add a `training_mode` enum and a `full` branch in config/trainer that enables decoder-only unfreeze."); keep the same symbols (`DatasetBackend`, `BucketedBatchSampler`, `train.py validate-dataset`, `training_mode` and `full`) so reviewers can locate the changes.
109-130: Consider adding testing strategy to implementation phases.
The phased implementation plan is well-structured, but doesn't explicitly address testing strategy for each phase. Given the complexity of the changes, especially around full fine-tuning, a testing plan would strengthen the proposal.
Consider adding for each phase:
- Unit test coverage expectations (e.g., new samplers, cache policies, parameter grouping)
- Integration test scenarios (e.g., end-to-end training with new features)
- Backward compatibility validation (especially for Phase 1-2 data format changes)
- Performance regression tests (to validate throughput improvements)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/lora_full_finetune_pipeline_proposal.md` around lines 109 - 130, Add a testing strategy tied to each phase: for Phase 1, add unit tests for the packed dataset writer/loader, bucketed sampler, and preprocessing QA validator (e.g., PackedDatasetWriterTests, BucketedSamplerTests, QAValidatorTests) plus backward-compatibility tests for old vs. new data formats; for Phase 2, add unit/integration tests for mmap/LRU cache policy and optional device prefetch (CachePolicyTests, DevicePrefetchIntegration) and performance regression benchmarks that validate throughput improvements and record baseline metrics by GPU class; for Phase 3, add unit tests for training_mode=full behavior, parameter group LR multipliers, and full-state checkpoint/resume (TrainingModeFullTests, ParamGroupTests, CheckpointResumeTests) as well as end-to-end training integration tests; for Phase 4, add tests for distributed/sharded optimizer options, staged unfreeze profiles, and eval/early-stop controls (DistributedOptimizerTests, StagedUnfreezeTests, EarlyStopIntegration) and ensure all phases are wired into CI with clear pass/fail criteria, recorded performance thresholds, and reproducible test data fixtures.
93-96: Consider checkpoint format versioning.
As the full fine-tuning feature evolves across phases, checkpoint format may change. Consider adding explicit version metadata to checkpoints to ensure graceful handling of incompatible formats during resume.
This could prevent cryptic errors when users try to resume from checkpoints created with older code versions.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@docs/lora_full_finetune_pipeline_proposal.md` around lines 93 - 96, Update the checkpointing metadata to include an explicit checkpoint format version field (e.g., "checkpoint_version" using semantic versioning) whenever saving the full model state, sharded states, optimizer, scheduler, scaler, or EMA checkpoints; on resume (load) validate this metadata in the resume logic and emit a clear, actionable error if the version is incompatible (or route to compatibility/upgrade handlers), and ensure any code paths that change the on-disk layout bump the checkpoint_version and add a loader/upgrade path for older versions to avoid cryptic failures.
acestep/training/preprocessed_collate.py (1)
10-61: Deduplicate collate logic to prevent drift.
This implementation is effectively duplicated in acestep/training/data_module.py (Lines 206–271 from provided context). Keep a single canonical function and import it where needed to avoid divergence during future edits.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training/preprocessed_collate.py` around lines 10 - 61, The collate logic is duplicated; keep collate_preprocessed_batch as the single canonical function and remove the copy in data_module.py: delete the duplicate function body there (the one that pads target_latents, attention_mask, context_latents, encoder_hidden_states, and encoder_attention_mask) and replace it with an import of collate_preprocessed_batch from this module, updating any reference in DataModule (or functions that previously called the duplicate) to call collate_preprocessed_batch; run a quick search for other copies and update imports/usages so all code uses the single collate_preprocessed_batch implementation to avoid drift.
acestep/training/data_module.py (1)
36-68: De-duplicate sampler/dataset definitions to a single source of truth.
BucketedBatchSampler and PreprocessedTensorDataset now exist both here and in acestep/training/preprocessed_sampler.py + acestep/training/preprocessed_dataset.py, and behavior has already drifted. Import shared implementations instead of maintaining parallel copies.
Based on learnings: only raise module-size concerns when a Python file exceeds 200 LOC; this file exceeds that threshold. As per coding guidelines: **/*.py hard cap is 200 LOC and functions/modules should be split by responsibility.
Also applies to: 71-205
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training/data_module.py` around lines 36 - 68, This file duplicates BucketedBatchSampler and PreprocessedTensorDataset which have drifted from the canonical implementations; remove the local classes and import BucketedBatchSampler from the shared preprocessed_sampler module and PreprocessedTensorDataset from the shared preprocessed_dataset module (update all usages to the imported symbols), then split any remaining functionality in this file so the module stays under the 200‑line guideline by moving unrelated helpers/classes into separate modules; ensure tests/imports still reference the imported names and run linters to catch any unresolved references.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@acestep/training/preprocessed_collate.py`:
- Around line 5-10: The return annotation Dict[str, torch.Tensor] is incorrect
because collate_preprocessed_batch returns a mix of tensors and a Python list
for "metadata"; define a TypedDict (e.g., PreprocessedBatch: TypedDict) that
lists each returned key and its precise type (for example tensor fields as
torch.Tensor and "metadata" as List[Dict[str, Any]]), import TypedDict and Any
from typing, and update the return type of collate_preprocessed_batch (in this
module) and the matching function in data_module.py to -> PreprocessedBatch so
static typing matches the actual payload.
In `@acestep/training/preprocessed_dataset.py`:
- Around line 96-115: The return type of __getitem__ is too narrow: it returns
tensors for most keys but "metadata" is a dict (or could be other non-tensor
values), so update the type hint for __getitem__ from Dict[str, torch.Tensor] to
a broader type such as Dict[str, Any] (and add Any to the typing imports) so the
signature reflects that values can be non-tensor; ensure the change is applied
to the __getitem__ definition and any related annotations in the same module
that reference this method.
- Around line 57-71: The logger calls in PreprocessedTensorDataset use
percent-style formatting which Loguru ignores; update all logger.warning and
logger.info usages in this region (the calls inside the constructor/initializer
that reference logger.warning("Some tensor files not found: %d missing", ...),
logger.warning("Failed to read latent length from %s: %s", vp, exc), and
logger.info("PreprocessedTensorDataset: %d samples from %s", ...)) to use
brace-style placeholders (e.g., "Some tensor files not found: {} missing") and
pass the values as positional args so Loguru will interpolate them; keep the
exception handling (in the try/except around torch.load and appending to
self.latent_lengths) unchanged.
In `@acestep/training/preprocessed_sampler.py`:
- Around line 42-45: The current __len__ returns ceil(total_items / batch_size)
which undercounts when __iter__ yields batches per bucket; update __len__ to sum
over buckets instead: for each bucket in self.lengths compute (len(bucket) +
self.batch_size - 1) // self.batch_size and return the total, ensuring you use
the same bucket structure that __iter__ uses (reference __len__, __iter__,
self.lengths, self.batch_size).
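To illustrate the count mismatch described above, a per-bucket `__len__` under an assumed bucket layout (a mapping of bucket key to index list — the real sampler's internals may differ) would be:

```python
def num_batches(buckets, batch_size):
    """Number of batches when iteration yields batches per bucket and
    never mixes buckets: sum of per-bucket ceilings."""
    return sum((len(idxs) + batch_size - 1) // batch_size
               for idxs in buckets.values())

# With buckets of 3 and 5 items at batch_size=4, the naive global count
# ceil(8 / 4) = 2 undercounts the true 1 + 2 = 3 batches.
```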
---
Duplicate comments:
In `@acestep/training/data_module.py`:
- Around line 129-135: The bare "except Exception" in the loop that loads
samples for self.valid_paths (where you call torch.load and access
sample["target_latents"] to populate self.latent_lengths) should be replaced
with explicit error handling: catch only the expected exceptions (e.g.,
FileNotFoundError, RuntimeError from torch.load, KeyError when "target_latents"
is missing) and handle each case (log the error with context including the path
and exception, append 0 only for recoverable/known issues), while letting
unexpected exceptions propagate (or re-raise) so real bugs or schema changes are
not silently swallowed.
- Around line 375-392: train_dataloader currently checks
hasattr(self.train_dataset, "latent_lengths") and thus skips bucketing when
train_dataset is a Subset (e.g., from random_split) and the new helper
_resolve_train_latent_lengths() is unused; update train_dataloader to call
_resolve_train_latent_lengths() (or otherwise obtain latent lengths via that
helper) and use its result to decide whether to set batch_sampler to
BucketedBatchSampler and remove batch_size/shuffle, ensuring Subset-wrapped
datasets are handled and shuffling is preserved when length_bucket=True; refer
to symbols train_dataloader, _resolve_train_latent_lengths, self.train_dataset,
length_bucket, BucketedBatchSampler, and Subset when making the change.
---
Nitpick comments:
In `@acestep/training/data_module.py`:
- Around line 36-68: This file duplicates BucketedBatchSampler and
PreprocessedTensorDataset which have drifted from the canonical implementations;
remove the local classes and import BucketedBatchSampler from the shared
preprocessed_sampler module and PreprocessedTensorDataset from the shared
preprocessed_dataset module (update all usages to the imported symbols), then
split any remaining functionality in this file so the module stays under the
200‑line guideline by moving unrelated helpers/classes into separate modules;
ensure tests/imports still reference the imported names and run linters to catch
any unresolved references.
In `@acestep/training/preprocessed_collate.py`:
- Around line 10-61: The collate logic is duplicated; keep
collate_preprocessed_batch as the single canonical function and remove the copy
in data_module.py: delete the duplicate function body there (the one that pads
target_latents, attention_mask, context_latents, encoder_hidden_states, and
encoder_attention_mask) and replace it with an import of
collate_preprocessed_batch from this module, updating any reference in
DataModule (or functions that previously called the duplicate) to call
collate_preprocessed_batch; run a quick search for other copies and update
imports/usages so all code uses the single collate_preprocessed_batch
implementation to avoid drift.
In `@docs/lora_full_finetune_pipeline_proposal.md`:
- Around line 34-40: Add a short note under A2 (Add duration/latent-length
bucketing in dataloader) that describes potential training-dynamics risks from
bucketing (e.g., unintended correlations from grouping by length) and prescribe
validation steps: (1) run epoch-level bucket-aware shuffle checks to ensure
samples are randomly interleaved across epochs and within buckets, (2) compare
training metrics (loss/accuracy/validation curves) with and without bucketing to
detect shifts, and (3) include a bullet in QA command A6 to execute these
bucket-aware shuffle validations and metric comparisons as part of the QA
checklist.
- Around line 133-144: Update the Acceptance criteria paragraph to specify exact
measurement methods: define "data-loader stall time" measurement using a
profiler (e.g., PyTorch profiler or perf metrics) and the exact metric (percent
time waiting vs total batch time) and sampling window; define "step time"
measurement method (median of N runs with fixed seed and dataset shard, include
warmup steps and measurement steps); enumerate what "supported GPU setup" means
(minimum VRAM, example device models such as NVIDIA A100/RTX3090, CUDA/cuDNN
versions) and the test protocol for "Successfully trains decoder-only full FT"
(one epoch, fixed batch size, seed, and dataset subset); and add explicit
validation for "CLI safety checks" by listing specific OOM-prone config test
cases and expected CLI responses (warnings/failures) so acceptance tests are
reproducible.
- Around line 20-32: Add migration tooling and docs for converting per-sample
`.pt` datasets to the new v3 sharded format: implement a CLI command (e.g.,
convert-dataset or convert_dataset) that accepts flags matching the loader
options (--dataset-format {pt,v3_sharded}, input/output paths) to produce
`shard-xxxxx.safetensors` and the corresponding `index.jsonl`, include
progress/error reporting and verification (checksums/sample counts), and update
the proposal and docs with guidance on when to migrate vs keep the `.pt` format
and examples of using the migration command.
- Around line 99-101: Update the proposal to specify the VRAM estimation
methodology for the "preflight VRAM estimator" referenced alongside the opt-in
flags (`--training-mode full --i-understand-vram-risk`) and the "Auto-suggest
fallback to LoRA" behavior: state whether the estimator is formula-based
(analytical) or profile-based (dry-run with dummy tensors), describe how you
will account for framework/runtime overhead, optimizer/AMP/optimizer state sizes
and gradient accumulation, define an acceptable margin-of-error/confidence
interval and how that uncertainty is surfaced to users, and specify explicit
fallback behavior and messaging when confidence is low (e.g., suggest LoRA,
require extra confirmation, or perform an on-demand dry-run).
- Around line 147-151: The four bullet items in section "Immediate next patch
candidates" use the same leading verb "Add", creating repetitive sentence
structure; update each line to vary the verb while preserving meaning and the
referenced symbols (e.g., change "Add `DatasetBackend` abstraction and a sharded
backend implementation." to "Implement `DatasetBackend` abstraction with a
sharded backend implementation.", change "Add `BucketedBatchSampler` keyed by
latent length." to "Introduce `BucketedBatchSampler` keyed by latent length.",
change "Add `train.py validate-dataset --dataset-dir ...` command." to "Extend
`train.py` with a `validate-dataset --dataset-dir ...` command.", and change
"Add `training_mode` enum + `full` branch in config/trainer with decoder-only
unfreeze." to "Add a `training_mode` enum and a `full` branch in config/trainer
that enables decoder-only unfreeze."); keep the same symbols (`DatasetBackend`,
`BucketedBatchSampler`, `train.py validate-dataset`, `training_mode` and `full`)
so reviewers can locate the changes.
- Around line 109-130: Add a testing strategy tied to each phase: for Phase 1,
add unit tests for the packed dataset writer/loader, bucketed sampler, and
preprocessing QA validator (e.g., PackedDatasetWriterTests,
BucketedSamplerTests, QAValidatorTests) plus backward-compatibility tests for
old vs. new data formats; for Phase 2, add unit/integration tests for mmap/LRU
cache policy and optional device prefetch (CachePolicyTests,
DevicePrefetchIntegration) and performance regression benchmarks that validate
throughput improvements and record baseline metrics by GPU class; for Phase 3,
add unit tests for training_mode=full behavior, parameter group LR multipliers,
and full-state checkpoint/resume (TrainingModeFullTests, ParamGroupTests,
CheckpointResumeTests) as well as end-to-end training integration tests; for
Phase 4, add tests for distributed/sharded optimizer options, staged unfreeze
profiles, and eval/early-stop controls (DistributedOptimizerTests,
StagedUnfreezeTests, EarlyStopIntegration) and ensure all phases are wired into
CI with clear pass/fail criteria, recorded performance thresholds, and
reproducible test data fixtures.
- Around line 93-96: Update the checkpointing metadata to include an explicit
checkpoint format version field (e.g., "checkpoint_version" using semantic
versioning) whenever saving the full model state, sharded states, optimizer,
scheduler, scaler, or EMA checkpoints; on resume (load) validate this metadata
in the resume logic and emit a clear, actionable error if the version is
incompatible (or route to compatibility/upgrade handlers), and ensure any code
paths that change the on-disk layout bump the checkpoint_version and add a
loader/upgrade path for older versions to avoid cryptic failures.
ℹ️ Review info
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
acestep/training/data_module.py
acestep/training/preprocessed_collate.py
acestep/training/preprocessed_dataset.py
acestep/training/preprocessed_sampler.py
docs/lora_full_finetune_pipeline_proposal.md
from typing import Dict, List

import torch


def collate_preprocessed_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]:
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# First, locate the file and read it
fd -t f "preprocessed_collate.py" | head -5
Repository: ace-step/ACE-Step-1.5
Length of output: 105
🏁 Script executed:
#!/bin/bash
# Read the file content, focusing on the function signature and return statement
if [ -f "acestep/training/preprocessed_collate.py" ]; then
cat -n "acestep/training/preprocessed_collate.py" | head -70
else
find . -name "preprocessed_collate.py" -exec cat -n {} \; | head -70
fi
Repository: ace-step/ACE-Step-1.5
Length of output: 2713
🏁 Script executed:
#!/bin/bash
# Search for usages of collate_preprocessed_batch function
rg -n "collate_preprocessed_batch" --type=py
Repository: ace-step/ACE-Step-1.5
Length of output: 604
🏁 Script executed:
#!/bin/bash
# Check if metadata field is accessed anywhere to understand usage patterns
rg -n "\[\"metadata\"\]|\['metadata'\]" --type=py -B2 -A2
Repository: ace-step/ACE-Step-1.5
Length of output: 3934
🏁 Script executed:
#!/bin/bash
# Read the function in data_module.py to see if it has the same issue
sed -n '207,275p' acestep/training/data_module.py
Repository: ace-step/ACE-Step-1.5
Length of output: 2832
Fix the return type annotation to match the actual payload in both collate_preprocessed_batch functions.
Both acestep/training/preprocessed_collate.py (line 10) and acestep/training/data_module.py (line 207) declare -> Dict[str, torch.Tensor], but both functions return "metadata" as a Python list (lines 60 and 271 respectively). This type mismatch breaks static type checking and misleads callers about the actual structure.
Replace the imprecise Dict[str, torch.Tensor] with a TypedDict that accurately describes the returned structure:
Proposed typing fix
-from typing import Dict, List
+from typing import Any, Dict, List, TypedDict
+
+
+class CollatedPreprocessedBatch(TypedDict):
+ target_latents: torch.Tensor
+ attention_mask: torch.Tensor
+ encoder_hidden_states: torch.Tensor
+ encoder_attention_mask: torch.Tensor
+ context_latents: torch.Tensor
+ metadata: List[Dict[str, Any]]
-def collate_preprocessed_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]:
+def collate_preprocessed_batch(batch: List[Dict[str, Any]]) -> CollatedPreprocessedBatch:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/training/preprocessed_collate.py` around lines 5 - 10, The return
annotation Dict[str, torch.Tensor] is incorrect because
collate_preprocessed_batch returns a mix of tensors and a Python list for
"metadata"; define a TypedDict (e.g., PreprocessedBatch: TypedDict) that lists
each returned key and its precise type (for example tensor fields as
torch.Tensor and "metadata" as List[Dict[str, Any]]), import TypedDict and Any
from typing, and update the return type of collate_preprocessed_batch (in this
module) and the matching function in data_module.py to -> PreprocessedBatch so
static typing matches the actual payload.
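As a dependency-free sketch of the proposed annotation, a TypedDict documents the mixed payload without changing runtime behavior. Only two of the real keys are shown, and plain lists stand in for torch.Tensor:

```python
from typing import Any, Dict, List, TypedDict


# Sketch of the proposed return type; the field set is abbreviated and
# plain lists stand in for torch.Tensor to keep the example dependency-free.
class CollatedBatch(TypedDict):
    target_latents: Any              # torch.Tensor in the real code
    metadata: List[Dict[str, Any]]   # per-sample metadata, not a tensor


def collate(batch: List[Dict[str, Any]]) -> CollatedBatch:
    """Gather per-sample fields; metadata stays a plain Python list."""
    return {
        "target_latents": [s["target_latents"] for s in batch],
        "metadata": [s.get("metadata", {}) for s in batch],
    }
```

Note that TypedDict is a static-checking construct only; nothing is validated at runtime, so the change is safe for existing callers.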
logger.warning(
    "Some tensor files not found: %d missing",
    len(self.sample_paths) - len(self.valid_paths),
)

self.latent_lengths: List[int] = []
for vp in self.valid_paths:
    try:
        sample = torch.load(vp, map_location="cpu", weights_only=True)
        self.latent_lengths.append(int(sample["target_latents"].shape[0]))
    except (FileNotFoundError, PermissionError, EOFError, OSError, KeyError, RuntimeError) as exc:
        logger.warning("Failed to read latent length from %s: %s", vp, exc)
        self.latent_lengths.append(0)

logger.info("PreprocessedTensorDataset: %d samples from %s", len(self.valid_paths), self.tensor_dir)
🧩 Analysis chain
🌐 Web query:
loguru Python logger formatting API {} vs % placeholders documentation
💡 Result:
Loguru’s logging methods (logger.debug/info/...) do not support %-style placeholders (%s, %(name)s, etc.). They support {}-style formatting only, using Python’s str.format() under the hood:
- Replace stdlib-style: logger.debug("Some variable: %s", var)
- With Loguru-style: logger.debug("Some variable: {}", var)
(loguru.readthedocs.io)
How it works:
- Any *args/**kwargs you pass are applied as message.format(*args, **kwargs) (and may be ignored if not referenced by a {} field). (loguru.readthedocs.io)
- Example with named fields: logger.info("My name is {name}", name="John"). (loguru.readthedocs.io)
Sink/handler format= strings (for logger.add(..., format="...")) also use brace fields like {time}, {level}, {message}. (loguru.readthedocs.io)
Citations:
- 1: https://loguru.readthedocs.io/en/stable/resources/migration.html?utm_source=openai
- 2: https://loguru.readthedocs.io/en/stable/resources/migration.html?utm_source=openai
- 3: https://loguru.readthedocs.io/en/latest/resources/troubleshooting.html?utm_source=openai
- 4: https://loguru.readthedocs.io/en/stable/api/logger.html?utm_source=openai
🏁 Script executed:
rg "logger\.(debug|info|warning|error)" acestep/training/preprocessed_dataset.py -A 2
Repository: ace-step/ACE-Step-1.5
Length of output: 767
Loguru uses {} placeholders, not % style—log arguments are currently being silently dropped.
These logger calls use %-style formatting which loguru does not support. The arguments are passed to .format(*args, **kwargs) but ignored because no {} fields exist in the message strings. Replace with {} formatting across all five instances (lines 57, 65, 69, 85, and 90).
🔧 Proposed fix
logger.warning(
- "Some tensor files not found: %d missing",
+ "Some tensor files not found: {} missing",
len(self.sample_paths) - len(self.valid_paths),
)
@@
- logger.warning("Failed to read latent length from %s: %s", vp, exc)
+ logger.warning("Failed to read latent length from {}: {}", vp, exc)
self.latent_lengths.append(0)
@@
- logger.info("PreprocessedTensorDataset: %d samples from %s", len(self.valid_paths), self.tensor_dir)
+ logger.info("PreprocessedTensorDataset: {} samples from {}", len(self.valid_paths), self.tensor_dir)
@@
- logger.debug("Resolved legacy manifest path via safe root: %s", raw)
+ logger.debug("Resolved legacy manifest path via safe root: {}", raw)
@@
- logger.warning("Skipping unresolvable manifest path: %s", raw)
+ logger.warning("Skipping unresolvable manifest path: {}", raw)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/training/preprocessed_dataset.py` around lines 57 - 71, The logger
calls in PreprocessedTensorDataset use percent-style formatting which Loguru
ignores; update all logger.warning and logger.info usages in this region (the
calls inside the constructor/initializer that reference logger.warning("Some
tensor files not found: %d missing", ...), logger.warning("Failed to read latent
length from %s: %s", vp, exc), and logger.info("PreprocessedTensorDataset: %d
samples from %s", ...)) to use brace-style placeholders (e.g., "Some tensor
files not found: {} missing") and pass the values as positional args so Loguru
will interpolate them; keep the exception handling (in the try/except around
torch.load and appending to self.latent_lengths) unchanged.
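The silent drop is easy to reproduce with plain str.format(), which is what loguru applies to the message string; no loguru install is needed for the demonstration:

```python
# loguru applies message.format(*args, **kwargs) to the message string.
# A %-style message has no {} fields, so the extra positional argument is
# silently ignored and the placeholder is printed verbatim.

percent_msg = "Some tensor files not found: %d missing"
brace_msg = "Some tensor files not found: {} missing"

print(percent_msg.format(5))  # Some tensor files not found: %d missing
print(brace_msg.format(5))    # Some tensor files not found: 5 missing
```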
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
    """Load one preprocessed sample."""
    if self.cache_policy == "ram_lru" and idx in self._cache:
        data = self._cache.pop(idx)
        self._cache[idx] = data
    else:
        data = torch.load(self.valid_paths[idx], map_location="cpu", weights_only=True)
        if self.cache_policy == "ram_lru" and self.cache_max_items > 0:
            self._cache[idx] = data
            while len(self._cache) > self.cache_max_items:
                self._cache.popitem(last=False)

    return {
        "target_latents": data["target_latents"],
        "attention_mask": data["attention_mask"],
        "encoder_hidden_states": data["encoder_hidden_states"],
        "encoder_attention_mask": data["encoder_attention_mask"],
        "context_latents": data["context_latents"],
        "metadata": data.get("metadata", {}),
    }
__getitem__ return annotation is too narrow for the returned payload.
The method returns "metadata" as a dict, so Dict[str, torch.Tensor] is incorrect.
🔧 Proposed fix
- def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/training/preprocessed_dataset.py` around lines 96 - 115, The return
type of __getitem__ is too narrow: it returns tensors for most keys but
"metadata" is a dict (or could be other non-tensor values), so update the type
hint for __getitem__ from Dict[str, torch.Tensor] to a broader type such as
Dict[str, Any] (and add Any to the typing imports) so the signature reflects
that values can be non-tensor; ensure the change is applied to the __getitem__
definition and any related annotations in the same module that reference this
method.
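The pop-and-reinsert plus evict-from-front pattern in __getitem__ is the standard OrderedDict LRU idiom. A standalone sketch (the class name is hypothetical, not part of the codebase):

```python
from collections import OrderedDict


class RamLRUCache:
    """Minimal LRU cache sketch mirroring the dataset's ram_lru policy."""

    def __init__(self, max_items: int):
        self.max_items = max_items
        self._cache: "OrderedDict[int, object]" = OrderedDict()

    def __contains__(self, key: int) -> bool:
        return key in self._cache

    def get(self, key: int):
        # Pop + reinsert moves the entry to the most-recently-used end.
        value = self._cache.pop(key)
        self._cache[key] = value
        return value

    def put(self, key: int, value: object) -> None:
        if self.max_items <= 0:
            return
        self._cache[key] = value
        # Evict from the least-recently-used end until within budget.
        while len(self._cache) > self.max_items:
            self._cache.popitem(last=False)
```

With max_items=2, inserting 0 and 1, touching 0, then inserting 2 evicts 1, since 0 was refreshed to the most-recently-used position.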
def __len__(self) -> int:
    """Return estimated number of batches."""
    total = len(self.lengths)
    return (total + self.batch_size - 1) // self.batch_size
__len__ can underreport actual batch count for bucketed iteration.
__iter__ batches per bucket, but ceil(total / batch_size) ignores per-bucket remainders and can be smaller than the real yielded batch count.
🔧 Proposed fix
def __len__(self) -> int:
- """Return estimated number of batches."""
- total = len(self.lengths)
- return (total + self.batch_size - 1) // self.batch_size
+ """Return number of batches produced by bucketed iteration."""
+ bucket_counts: Dict[int, int] = {}
+ for length in self.lengths:
+ bucket = int(length // 64)
+ bucket_counts[bucket] = bucket_counts.get(bucket, 0) + 1
+ return sum(
+ (count + self.batch_size - 1) // self.batch_size
+ for count in bucket_counts.values()
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/training/preprocessed_sampler.py` around lines 42 - 45, The current
__len__ returns ceil(total_items / batch_size) which undercounts when __iter__
yields batches per bucket; update __len__ to sum over buckets instead: for each
bucket in self.lengths compute (len(bucket) + self.batch_size - 1) //
self.batch_size and return the total, ensuring you use the same bucket structure
that __iter__ uses (reference __len__, __iter__, self.lengths, self.batch_size).
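The undercount is easy to see with concrete (illustrative) numbers: three buckets of 5 samples each at batch_size 4 actually yield 6 batches, while the global ceil reports only 4. An accurate __len__ matters because progress bars and schedulers size themselves from it:

```python
import math

bucket_sizes = [5, 5, 5]  # samples per length bucket (illustrative)
batch_size = 4

# Global ceil, as the current __len__ computes it:
global_len = math.ceil(sum(bucket_sizes) / batch_size)

# Per-bucket ceil, matching what bucketed __iter__ actually yields
# (each bucket contributes its own partial final batch):
true_len = sum(math.ceil(n / batch_size) for n in bucket_sizes)

print(global_len, true_len)  # 4 6
```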
…ora-training-pipeline Preprocessed dataset improvements: bucketing, RAM-LRU cache, validator, and full fine-tune mode
Actionable comments posted: 1
♻️ Duplicate comments (2)
acestep/training_v2/trainer_helpers.py (2)
374-374: ⚠️ Potential issue | 🟡 Minor
Fix unresolved Dict annotation (F821) in the return type.
Dict is referenced but not imported. Prefer the built-in generic for the Python 3.11+ target.
Proposed fix
-) -> Optional[Tuple[int, int, Dict[str, Any]]]:
+) -> Optional[Tuple[int, int, dict[str, Any]]]:
#!/bin/bash
set -euo pipefail
# Verify there are no unresolved typing.Dict annotations left in this module.
fd -p 'trainer_helpers.py$' --exec rg -nP '\bDict\s*\[' {}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training_v2/trainer_helpers.py` at line 374, The return type currently uses typing.Dict in the signature ")-> Optional[Tuple[int, int, Dict[str, Any]]]:", which is unresolved; replace typing.Dict with the built-in generic by changing the annotation to "Optional[Tuple[int, int, dict[str, Any]]]" and ensure Any remains imported (or replace Any with an appropriate concrete type); also remove any unused "Dict" import from typing and scan the module for other occurrences of "Dict[" and update them to "dict[" to avoid F821 errors.
494-497: ⚠️ Potential issue | 🟠 Major
Use the shared unwrapping helper before load_state_dict().
This still unwraps only one _forward_module level and can miss nested wrappers. Use _unwrap_decoder(module.model) here for parity with the full-mode helpers.
Proposed fix
-        decoder = module.model.decoder
-        if hasattr(decoder, "_forward_module"):
-            decoder = decoder._forward_module
+        decoder = _unwrap_decoder(module.model)
         decoder.load_state_dict(state_dict, strict=False)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@acestep/training_v2/trainer_helpers.py` around lines 494 - 497, Replace the manual one-level unwrap of module.model.decoder with the shared helper: call _unwrap_decoder(module.model) to obtain the real decoder instance, then call decoder.load_state_dict(state_dict, strict=False); ensure you still pass strict=False and remove the hasattr(... "_forward_module") branch so nested wrappers are correctly unwrapped before load_state_dict.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@acestep/training_v2/trainer_helpers.py`:
- Around line 447-460: _resume_lokr can return None on handled paths which
_resume_lokr_or_lora treats as "not handled" and falls back to LoRA; change
_resume_lokr (the function that calls state_loader and yields
TrainingUpdate/TrainingUpdate warnings) to return an explicit handled boolean
(e.g., True when LoKR was loaded or when no training_state/handled warning path
occurred, False only when truly unhandled), update callers in
_resume_lokr_or_lora to check that boolean instead of truthiness/None so LoRA
loading is skipped when LoKR already handled the resume; apply the same explicit
handled-return fix to the other mirrored block (the second occurrence referenced
around lines 520-529).
---
Duplicate comments:
In `@acestep/training_v2/trainer_helpers.py`:
- Line 374: The return type currently uses typing.Dict in the signature ")->
Optional[Tuple[int, int, Dict[str, Any]]]:", which is unresolved; replace
typing.Dict with the built-in generic by changing the annotation to
"Optional[Tuple[int, int, dict[str, Any]]]" and ensure Any remains imported (or
replace Any with an appropriate concrete type); also remove any unused "Dict"
import from typing and scan the module for other occurrences of "Dict[" and
update them to "dict[" to avoid F821 errors.
- Around line 494-497: Replace the manual one-level unwrap of
module.model.decoder with the shared helper: call _unwrap_decoder(module.model)
to obtain the real decoder instance, then call
decoder.load_state_dict(state_dict, strict=False); ensure you still pass
strict=False and remove the hasattr(... "_forward_module") branch so nested
wrappers are correctly unwrapped before load_state_dict.
training_state = state_loader(ckpt_dir=ckpt_dir)
if training_state is None:
    yield TrainingUpdate(0, 0.0, "[OK] LoKR weights loaded (no training state)", kind="info")
    return None

# Full mode resume: no full decoder checkpoint found.
if getattr(trainer.training_config, "training_mode", "adapter") == "full":
    yield TrainingUpdate(
        0,
        0.0,
        f"[WARN] full_decoder_state.pt not found in {ckpt_dir}",
        kind="warn",
    )
    return
LoKR resume path can incorrectly fall through into LoRA resume.
After LoKR weights are loaded, _resume_lokr() still returns None on handled paths, and _resume_lokr_or_lora() treats that as “not handled” and attempts LoRA loading. This can mis-resume or override the intended LoKR state.
Proposed fix
def _resume_lokr(
@@
load_lokr_weights(module.lycoris_net, str(lokr_weights_path))
training_state = state_loader(ckpt_dir=ckpt_dir)
if training_state is None:
yield TrainingUpdate(0, 0.0, "[OK] LoKR weights loaded (no training state)", kind="info")
return None
-
- # Full mode resume: no full decoder checkpoint found.
- if getattr(trainer.training_config, "training_mode", "adapter") == "full":
- yield TrainingUpdate(
- 0,
- 0.0,
- f"[WARN] full_decoder_state.pt not found in {ckpt_dir}",
- kind="warn",
- )
- return
+ epoch, step, _ = training_state
+ yield TrainingUpdate(
+ 0,
+ 0.0,
+ f"[OK] Resumed LoKR from epoch {epoch}, step {step}",
+ kind="info",
+ )
+ return (epoch, step)
@@
def _resume_lokr_or_lora(
@@
- lokr_result = yield from _resume_lokr(
+ has_lokr_checkpoint = (ckpt_dir / "lokr_weights.safetensors").exists()
+ lokr_result = yield from _resume_lokr(
trainer,
resume_path,
ckpt_dir,
module,
state_loader,
)
+ if has_lokr_checkpoint and module.lycoris_net is not None:
+ return lokr_result
if lokr_result is not None:
return lokr_result
return (yield from _resume_lora(ckpt_dir, module, optimizer, scheduler))Also applies to: 520-529
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@acestep/training_v2/trainer_helpers.py` around lines 447 - 460, _resume_lokr
can return None on handled paths which _resume_lokr_or_lora treats as "not
handled" and falls back to LoRA; change _resume_lokr (the function that calls
state_loader and yields TrainingUpdate/TrainingUpdate warnings) to return an
explicit handled boolean (e.g., True when LoKR was loaded or when no
training_state/handled warning path occurred, False only when truly unhandled),
update callers in _resume_lokr_or_lora to check that boolean instead of
truthiness/None so LoRA loading is skipped when LoKR already handled the resume;
apply the same explicit handled-return fix to the other mirrored block (the
second occurrence referenced around lines 520-529).
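The handled-flag fix can be sketched with a dependency-free generator pair (function names, checkpoint keys, and message strings are illustrative, not the project's actual helpers): the sub-generator returns an explicit (handled, state) tuple through its StopIteration value, so the dispatcher never falls through to the LoRA branch on a path LoKR already handled:

```python
def resume_lokr(checkpoint):
    """Yield progress messages; return (handled, state)."""
    if "lokr" not in checkpoint:
        return (False, None)  # truly unhandled: let the caller try LoRA
    yield "[OK] LoKR weights loaded"
    state = checkpoint.get("training_state")
    if state is None:
        yield "[OK] no training state"
        return (True, None)   # handled, nothing more to restore
    return (True, state)


def resume_lokr_or_lora(checkpoint):
    """Dispatch on the explicit handled flag instead of truthiness."""
    handled, state = yield from resume_lokr(checkpoint)
    if handled:
        return state
    yield "[OK] falling back to LoRA resume"
    return checkpoint.get("lora_state")


def run(gen):
    """Drain a generator, collecting yields and the return value."""
    msgs = []
    try:
        while True:
            msgs.append(next(gen))
    except StopIteration as stop:
        return msgs, stop.value
```

With a LoKR checkpoint and no training state, the dispatcher now returns without touching the LoRA branch; with no LoKR checkpoint, it falls back as before.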
Proposing improvements for the LoRA training pipeline and a full-model fine-tune option.
Implemented a concrete dataset QA phase with a new validate-dataset subcommand in CLI parsing and dispatch, plus a validator that checks required tensor keys, flags non-finite values, and reports latent-length stats before training.
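A dependency-free sketch of those checks (the real validator operates on torch tensors; nested lists stand in for them here, and the key names mirror the collate payload):

```python
import math

REQUIRED_KEYS = (
    "target_latents", "attention_mask", "encoder_hidden_states",
    "encoder_attention_mask", "context_latents",
)


def _flat(x):
    """Yield scalar leaves of an arbitrarily nested list/tuple."""
    if isinstance(x, (list, tuple)):
        for item in x:
            yield from _flat(item)
    else:
        yield x


def validate_sample(sample):
    """Return a list of human-readable problems found in one sample."""
    problems = []
    for key in REQUIRED_KEYS:
        if key not in sample:
            problems.append(f"missing key: {key}")
    for key in REQUIRED_KEYS:
        values = sample.get(key)
        if values is None:
            continue
        if any(not math.isfinite(v) for v in _flat(values)):
            problems.append(f"non-finite values in: {key}")
    return problems
```

An empty problems list means the sample passes; collecting the latent lengths for stats would happen in the same pass in the real validator.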
Implemented data pipeline throughput improvements in the runtime loader: latent-length bucketing (BucketedBatchSampler) and optional RAM LRU caching in PreprocessedTensorDataset, and wired them through PreprocessedDataModule and trainer config flow.
Implemented full fine-tuning MVP (decoder-only) with training_mode=full: explicit freeze/unfreeze behavior, param-family LR group construction, and optimizer group usage in both Fabric and non-Fabric loops.
Added full-mode checkpoint/resume support (full_decoder_state.pt) and compatibility checks in save/verify/resume helpers so full FT can be resumed like adapter training state flows.
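The per-family LR grouping can be sketched as a substring classifier over parameter names (the family names and multiplier values below are illustrative, not the project's actual config keys); the resulting list has the shape torch optimizers accept as param_groups:

```python
def build_param_groups(named_params, base_lr, multipliers):
    """Group (name, param) pairs by substring family and scale the LR."""
    groups = {family: [] for family in multipliers}
    groups["other"] = []
    for name, param in named_params:
        for family in multipliers:
            if family in name:
                groups[family].append(param)
                break
        else:
            groups["other"].append(param)  # no family matched
    param_groups = []
    for family, params in groups.items():
        if not params:
            continue  # optimizers reject empty param groups
        lr = base_lr * multipliers.get(family, 1.0)
        param_groups.append({"params": params, "lr": lr, "name": family})
    return param_groups
```

For example, with multipliers {"attn": 0.5, "ffn": 2.0} and base_lr 1e-4, attention weights train at 5e-5, FFN weights at 2e-4, and everything else at the base rate.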
Added focused unit tests for new behavior:
dataset validation success+edge path,
full-mode decoder-only trainability and param-group coverage.
Summary by CodeRabbit
New Features
Tests
Documentation