From e7d333c8ec155a26a30d178b325229f7953c4f3c Mon Sep 17 00:00:00 2001 From: Robert Ioffe Date: Wed, 4 Mar 2026 10:52:41 -0800 Subject: [PATCH] Add --precision flag for reduced-precision inference Selectively casts heavy encoder/backbone modules (monodepth_model, feature_model) to bfloat16 or float16 while keeping lightweight heads in float32 for numerical stability. Achieves ~2x inference speedup on MPS with bfloat16. Co-Authored-By: Claude Opus 4.6 --- README.md | 10 ++++++ src/sharp/cli/predict.py | 75 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 59557530..4bb754f2 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,16 @@ To use a manually downloaded checkpoint, specify it with the `-c` flag: sharp predict -i /path/to/input/images -o /path/to/output/gaussians -c sharp_2572gikvuh.pt ``` +### Reduced-precision inference + +Use the `--precision` flag to run the heavy encoder modules in bfloat16 or float16. This can significantly reduce inference time (up to ~2x on supported hardware) with minimal impact on output quality: + +``` +sharp predict -i /path/to/input/images -o /path/to/output/gaussians --precision bfloat16 +``` + +Accepted values are `float32` (default), `bfloat16`, and `float16`. Only the large encoder/backbone modules are cast to reduced precision; lightweight heads remain in float32 for numerical stability. + The results will be 3D gaussian splats (3DGS) in the output folder. The 3DGS `.ply` files are compatible to various public 3DGS renderers. We follow the OpenCV coordinate convention (x right, y down, z forward). The 3DGS scene center is roughly at (0, 0, +z). When dealing with 3rdparty renderers, please scale and rotate to re-center the scene accordingly. ### Rendering trajectories (CUDA GPU only) diff --git a/src/sharp/cli/predict.py b/src/sharp/cli/predict.py index 8914bb56..a00f310d 100644 --- a/src/sharp/cli/predict.py +++ b/src/sharp/cli/predict.py @@ -72,6 +72,12 @@ default="default", help="Device to run on. ['cpu', 'mps', 'cuda']", ) +@click.option( + "--precision", + type=click.Choice(["float32", "bfloat16", "float16"]), + default="float32", + help="Inference precision. bfloat16/float16 reduce memory and may be faster.", +) @click.option("-v", "--verbose", is_flag=True, help="Activate debug logs.") def predict_cli( input_path: Path, @@ -79,6 +85,7 @@ def predict_cli( checkpoint_path: Path, with_rendering: bool, device: str, + precision: str, verbose: bool, ): """Predict Gaussians from input images.""" @@ -142,7 +149,9 @@ def predict_cli( device=device, dtype=torch.float32, ) - gaussians = predict_image(gaussian_predictor, image, f_px, torch.device(device)) + gaussians = predict_image( + gaussian_predictor, image, f_px, torch.device(device), precision + ) LOGGER.info("Saving 3DGS to %s", output_path) save_ply(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.ply") @@ -161,6 +170,7 @@ def predict_image( image: np.ndarray, f_px: float, device: torch.device, + precision: str = "float32", ) -> Gaussians3D: """Predict Gaussians from an image.""" internal_shape = (1536, 1536) @@ -179,7 +189,68 @@ def predict_image( # Predict Gaussians in the NDC space. LOGGER.info("Running inference.") - gaussians_ndc = predictor(image_resized_pt, disparity_factor) + + # Selective precision casting: only cast heavy encoder/backbone modules to + # the target dtype. Lightweight modules (init_model, prediction_head, + # gaussian_composer) stay in float32 for numerical stability. Forward hooks + # cast inputs on entry and outputs back to float32 on exit. + use_autocast = precision != "float32" + dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16} + autocast_dtype = dtype_map.get(precision, torch.float32) + + cast_modules: list[torch.nn.Module] = [] + hooks: list[torch.utils.hooks.RemovableHandle] = [] + if use_autocast: + + def _cast_input_tensor(obj): + if isinstance(obj, torch.Tensor) and obj.is_floating_point(): + return obj.to(autocast_dtype) + if isinstance(obj, list): + return [_cast_input_tensor(x) for x in obj] + if isinstance(obj, tuple) and hasattr(obj, "_asdict"): + fields = {k: _cast_input_tensor(v) for k, v in obj._asdict().items()} + return type(obj)(**fields) + if isinstance(obj, tuple): + return tuple(_cast_input_tensor(x) for x in obj) + return obj + + def _cast_inputs(_mod, args, kwargs): + new_args = tuple(_cast_input_tensor(a) for a in args) + new_kwargs = {k: _cast_input_tensor(v) for k, v in kwargs.items()} + return new_args, new_kwargs + + def _cast_output_to_float(obj): + if isinstance(obj, torch.Tensor) and obj.is_floating_point(): + return obj.float() + if isinstance(obj, list): + return [_cast_output_to_float(x) for x in obj] + if isinstance(obj, dict): + return {k: _cast_output_to_float(v) for k, v in obj.items()} + if isinstance(obj, tuple) and hasattr(obj, "_asdict"): + fields = {k: _cast_output_to_float(v) for k, v in obj._asdict().items()} + return type(obj)(**fields) + if isinstance(obj, tuple): + return tuple(_cast_output_to_float(x) for x in obj) + return obj + + def _cast_outputs_hook(_mod, _input, output): + return _cast_output_to_float(output) + + for name in ("monodepth_model", "feature_model"): + mod = getattr(predictor, name, None) + if mod is not None: + mod.to(autocast_dtype) + cast_modules.append(mod) + hooks.append(mod.register_forward_pre_hook(_cast_inputs, with_kwargs=True)) + hooks.append(mod.register_forward_hook(_cast_outputs_hook)) + + try: + gaussians_ndc = predictor(image_resized_pt, disparity_factor) + finally: + for hook in hooks: + hook.remove() + for mod in cast_modules: + mod.float() LOGGER.info("Running postprocessing.") intrinsics = (