jaxsplat

jaxsplat is a port of 3D Gaussian Splatting to JAX. It is fully differentiable and CUDA-accelerated.
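In 3D Gaussian Splatting, each Gaussian has a mean, a per-axis scale and a rotation quaternion. These presumably correspond to the means3d, scales and quats arguments of jaxsplat.render. The anisotropic 3D covariance is built as Σ = R S Sᵀ Rᵀ, where R is the rotation matrix and S = diag(scale). The NumPy code below is an illustrative sketch of that construction, not jaxsplat's internal code. The (w, x, y, z) quaternion order is an assumption; this README does not specify it.

```python
import numpy as np


def quat_to_rotmat(q):
    """Rotation matrix from a quaternion ASSUMED to be ordered (w, x, y, z)."""
    q = np.asarray(q, dtype=float)
    w, x, y, z = q / np.linalg.norm(q)  # normalize so R is a proper rotation
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])


def covariance3d(scale, quat):
    """Sigma = R S S^T R^T with S = diag(scale)."""
    R = quat_to_rotmat(quat)
    S = np.diag(np.asarray(scale, dtype=float))
    M = R @ S
    return M @ M.T  # symmetric positive semi-definite by construction


if __name__ == "__main__":
    # Identity rotation: the covariance is diag(scale**2).
    print(covariance3d([1.0, 2.0, 3.0], [1.0, 0.0, 0.0, 0.0]))
```

With the identity quaternion the result is diag(1, 4, 9). A 90° rotation about z swaps the x and y variances. Swapping numpy for jax.numpy would make the same construction differentiable with jax.grad.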

Installation

Installation requires a working CUDA toolchain. Installing directly from source with pip should build and install jaxsplat:

$ python -m venv venv && . venv/bin/activate
$ pip install git+https://github.com/yklcs/jaxsplat
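A quick sanity check after installing is to confirm that both packages are importable and that JAX can see a CUDA device. The minimal sketch below uses only importlib.util.find_spec and jax.devices(). The helper name check_install is hypothetical and is not part of jaxsplat.

```python
import importlib.util


def check_install() -> dict:
    """Report whether jax and jaxsplat are importable and whether JAX sees a GPU."""
    status = {
        # find_spec locates a top-level module without importing it.
        "jax": importlib.util.find_spec("jax") is not None,
        "jaxsplat": importlib.util.find_spec("jaxsplat") is not None,
        "gpu": False,
    }
    if status["jax"]:
        import jax

        # CUDA devices report the platform name "gpu".
        status["gpu"] = any(d.platform == "gpu" for d in jax.devices())
    return status


if __name__ == "__main__":
    print(check_install())
```

If "gpu" comes back False, JAX is probably running on a CPU-only build and the CUDA kernels will not be usable.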

Usage

The primary function of jaxsplat is jaxsplat.render, which differentiably renders 3D Gaussians to a 2D image. See the rendering API docs for complete documentation.

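import jaxsplat

# means3d, scales, quats, colors, opacities and the camera/render settings
# must already be defined; see the rendering API docs for their expected shapes.
img = jaxsplat.render(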
    means3d,
    scales,
    quats,
    colors,
    opacities,
    viewmat=viewmat,
    background=background,
    img_shape=img_shape,
    f=f,
    c=c,
    glob_scale=glob_scale,
    clip_thresh=clip_thresh,
    block_size=block_size,
)
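The camera arguments can be derived from a conventional pinhole model. The sketch below is hypothetical and makes several assumptions this README does not confirm: img_shape is (height, width), f = (fx, fy) is the focal length in pixels, c = (cx, cy) is the principal point in pixels, and viewmat is a 4x4 world-to-camera matrix [[R, t], [0, 1]]. The helper names intrinsics_from_fov and viewmat_from_rt are illustrative only. Check the rendering API docs for the exact conventions and types jaxsplat expects.

```python
import math

import numpy as np


def intrinsics_from_fov(img_shape, fov_x_deg):
    """Pinhole focal lengths and principal point in pixels (ASSUMED conventions).

    Assumes img_shape = (height, width), square pixels (fy = fx), and a
    principal point at the image centre.
    """
    height, width = img_shape
    fx = 0.5 * width / math.tan(math.radians(fov_x_deg) / 2)
    f = (fx, fx)
    c = (width / 2, height / 2)
    return f, c


def viewmat_from_rt(R, t):
    """4x4 homogeneous matrix [[R, t], [0, 1]], ASSUMED to map world -> camera."""
    V = np.eye(4)
    V[:3, :3] = np.asarray(R, dtype=float)
    V[:3, 3] = np.asarray(t, dtype=float)
    return V


if __name__ == "__main__":
    f, c = intrinsics_from_fov((480, 640), fov_x_deg=60.0)
    viewmat = viewmat_from_rt(np.eye(3), [0.0, 0.0, 4.0])  # camera 4 units back
    print(f, c)
    print(viewmat)
```

Under these assumptions, the results would be passed as f=f, c=c and viewmat=viewmat in the call above.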

Bibliography