Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ To use a manually downloaded checkpoint, specify it with the `-c` flag:
sharp predict -i /path/to/input/images -o /path/to/output/gaussians -c sharp_2572gikvuh.pt
```

### Reduced-precision inference

Use the `--precision` flag to run the heavy encoder modules in bfloat16 or float16. This can significantly reduce inference time (up to ~2x on supported hardware) with minimal impact on output quality:

```
sharp predict -i /path/to/input/images -o /path/to/output/gaussians --precision bfloat16
```

Accepted values are `float32` (default), `bfloat16`, and `float16`. Only the large encoder/backbone modules are cast to reduced precision; lightweight heads remain in float32 for numerical stability.

The results will be 3D gaussian splats (3DGS) in the output folder. The 3DGS `.ply` files are compatible with various public 3DGS renderers. We follow the OpenCV coordinate convention (x right, y down, z forward). The 3DGS scene center is roughly at (0, 0, +z). When using third-party renderers, please scale and rotate to re-center the scene accordingly.

### Rendering trajectories (CUDA GPU only)
Expand Down
75 changes: 73 additions & 2 deletions src/sharp/cli/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,20 @@
default="default",
help="Device to run on. ['cpu', 'mps', 'cuda']",
)
@click.option(
"--precision",
type=click.Choice(["float32", "bfloat16", "float16"]),
default="float32",
help="Inference precision. bfloat16/float16 reduce memory and may be faster.",
)
@click.option("-v", "--verbose", is_flag=True, help="Activate debug logs.")
def predict_cli(
input_path: Path,
output_path: Path,
checkpoint_path: Path,
with_rendering: bool,
device: str,
precision: str,
verbose: bool,
):
"""Predict Gaussians from input images."""
Expand Down Expand Up @@ -142,7 +149,9 @@ def predict_cli(
device=device,
dtype=torch.float32,
)
gaussians = predict_image(gaussian_predictor, image, f_px, torch.device(device))
gaussians = predict_image(
gaussian_predictor, image, f_px, torch.device(device), precision
)

LOGGER.info("Saving 3DGS to %s", output_path)
save_ply(gaussians, f_px, (height, width), output_path / f"{image_path.stem}.ply")
Expand All @@ -161,6 +170,7 @@ def predict_image(
image: np.ndarray,
f_px: float,
device: torch.device,
precision: str = "float32",
) -> Gaussians3D:
"""Predict Gaussians from an image."""
internal_shape = (1536, 1536)
Expand All @@ -179,7 +189,68 @@ def predict_image(

# Predict Gaussians in the NDC space.
LOGGER.info("Running inference.")
gaussians_ndc = predictor(image_resized_pt, disparity_factor)

# Selective precision casting: only cast heavy encoder/backbone modules to
# the target dtype. Lightweight modules (init_model, prediction_head,
# gaussian_composer) stay in float32 for numerical stability. Forward hooks
# cast inputs on entry and outputs back to float32 on exit.
use_autocast = precision != "float32"
dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16}
autocast_dtype = dtype_map.get(precision, torch.float32)

cast_modules: list[torch.nn.Module] = []
hooks: list[torch.utils.hooks.RemovableHandle] = []
if use_autocast:

def _cast_input_tensor(obj):
if isinstance(obj, torch.Tensor) and obj.is_floating_point():
return obj.to(autocast_dtype)
if isinstance(obj, list):
return [_cast_input_tensor(x) for x in obj]
if isinstance(obj, tuple) and hasattr(obj, "_asdict"):
fields = {k: _cast_input_tensor(v) for k, v in obj._asdict().items()}
return type(obj)(**fields)
if isinstance(obj, tuple):
return tuple(_cast_input_tensor(x) for x in obj)
return obj

def _cast_inputs(_mod, args, kwargs):
new_args = tuple(_cast_input_tensor(a) for a in args)
new_kwargs = {k: _cast_input_tensor(v) for k, v in kwargs.items()}
return new_args, new_kwargs

def _cast_output_to_float(obj):
if isinstance(obj, torch.Tensor) and obj.is_floating_point():
return obj.float()
if isinstance(obj, list):
return [_cast_output_to_float(x) for x in obj]
if isinstance(obj, dict):
return {k: _cast_output_to_float(v) for k, v in obj.items()}
if isinstance(obj, tuple) and hasattr(obj, "_asdict"):
fields = {k: _cast_output_to_float(v) for k, v in obj._asdict().items()}
return type(obj)(**fields)
if isinstance(obj, tuple):
return tuple(_cast_output_to_float(x) for x in obj)
return obj

def _cast_outputs_hook(_mod, _input, output):
return _cast_output_to_float(output)

for name in ("monodepth_model", "feature_model"):
mod = getattr(predictor, name, None)
if mod is not None:
mod.to(autocast_dtype)
cast_modules.append(mod)
hooks.append(mod.register_forward_pre_hook(_cast_inputs, with_kwargs=True))
hooks.append(mod.register_forward_hook(_cast_outputs_hook))

try:
gaussians_ndc = predictor(image_resized_pt, disparity_factor)
finally:
for hook in hooks:
hook.remove()
for mod in cast_modules:
mod.float()

LOGGER.info("Running postprocessing.")
intrinsics = (
Expand Down