# Nabla: High-Performance Distributed ML
Nabla is a JAX-inspired autodiff library with factor-based SPMD sharding, built on Mojo & MAX.
## Installation
Nabla requires the Modular nightly build:
```bash
python -m venv .venv
source .venv/bin/activate
pip install --pre --extra-index-url https://whl.modular.com/nightly/simple/ modular nabla-ml
```
**GPU Support:**

- Linux (AMD/NVIDIA): Supported natively via Modular MAX.
- macOS (Apple Silicon): Requires the Xcode Metal toolchain (`xcode-select --install`).
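A quick sanity check that the environment is wired up (a minimal sketch; it only verifies that the package imports):

```python
# If this runs without an ImportError, the nightly install succeeded.
import nabla
print("nabla imported successfully")
```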
## Development Setup

Install all development dependencies (torch/jax for testing, mypy/black for linting, etc.):

```bash
git clone https://github.com/nabla-ml/nabla.git
cd nabla
python -m venv venv
source venv/bin/activate
pip install -r requirements-dev.txt
pip install -e ".[dev]"
```
## Feature Showcase
### 1. Tensors & Autodiff
Define Python functions and compute gradients using trace-based automatic differentiation. Explore the code.
```python
import nabla

# Use Accelerator (GPU) or CPU for execution
with nabla.default_device(nabla.Accelerator()):
    x = nabla.uniform((4, 8))
    w = nabla.uniform((8, 16))

    # Define loss function
    def compute_loss(x, w):
        return nabla.mean(nabla.relu(x @ w))

    # Compute loss (implicit .realize() on print)
    loss = compute_loss(x, w)
    print("Loss:", loss)

    # Compute gradients via backward replay
    grad_x, grad_w = nabla.grad(compute_loss, argnums=(0, 1))(x, w)
    print("Gradients:", grad_x.shape, grad_w.shape)
```
### 2. SPMD Sharding
Shard tensors on a logical mesh; operations automatically propagate sharding constraints. Explore the code.
```python
# Define a 2×4 device mesh (logical DP × TP)
mesh = nabla.DeviceMesh("my_mini_pod", (2, 4), ("dp", "tp"))

# Shard x on 'dp' (rows), w on 'tp' (columns)
x = nabla.shard(nabla.uniform((32, 128)), mesh, nabla.P("dp", None))
w = nabla.shard(nabla.uniform((128, 256)), mesh, nabla.P(None, "tp"))

def compute_loss(x, w):
    return nabla.mean(nabla.relu(x @ w))

# An AllReduce is inserted automatically for the 'tp' sum
loss = compute_loss(x, w)
print("Loss (Sharded):", loss)
```
### 3. Mojo Integration
Nabla’s core strength is its ability to drop down to Mojo for high-performance custom kernels, bridging the gap between high-level Python and bare-metal execution. Explore the code.
Mojo Kernel (`kernels/custom_kernel.mojo`):

```mojo
@compiler.register("my_kernel")
struct MyKernel:
    @staticmethod
    def execute[target: StaticString](
        output: OutputTensor,
        x: InputTensor[dtype = output.dtype, rank = output.rank],
        ctx: DeviceContextPtr,
    ):
        @parameter
        fn add_one[W: Int](idx: IndexList[x.rank]) -> SIMD[x.dtype, W]:
            return x.load[W](idx) + 1

        foreach[add_one, target=target](output, ctx)
```
Python Usage:

```python
class AddOneOp(nabla.UnaryOperation):
    name = "my_kernel"

    def kernel(self, x, **kwargs):
        # Concise invocation: (func_name, path, inputs, out_types)
        return nabla.call_custom_kernel("my_kernel", "./kernels", x, x.type)

x = nabla.Tensor.constant([1., 2., 3.])
y = AddOneOp()(x)
```
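Since the kernel adds one to every element, `y` should materialize as `[2., 3., 4.]`. A hedged sketch of mixing the custom op with built-in ops (the composition here is an assumption, not documented behavior):

```python
print("y:", y)  # expected: [2., 3., 4.]

# Custom ops should participate in the same traced graph as built-in ops.
print("mean:", nabla.mean(AddOneOp()(x)))
```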
### 4. Distributed Pipeline Parallelism (GPipe)
Define complex distributed schedules like GPipe using vmap for parallel execution and ppermute for explicit data movement. Explore the code.
```python
stages = 4  # number of pipeline stages (illustrative value)

# Parallel execution across the pipeline stages
@nabla.vmap(in_axes=(0, 0), spmd_axis_name="stage")
def stage_compute(x, w):
    return nabla.relu(x @ w)

def pipeline_step(current_state, fresh_input, weights, mask_0):
    # 1. Compute: run all stages in parallel
    computed = stage_compute(current_state, weights)
    # 2. Communicate: shift activations to the next stage (i -> i+1)
    shifted = nabla.ppermute(computed, perm=[(i, (i + 1) % stages) for i in range(stages)])
    # 3. Control: stage 0 takes fresh input; others take the shifted data
    return nabla.where(mask_0, fresh_input, shifted)
```
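A hedged sketch of driving `pipeline_step` over successive ticks. The shapes, the stage-0 mask construction, and the loop length are illustrative assumptions, not part of Nabla's API:

```python
# Illustrative driver loop; all shapes and values are hypothetical.
batch, width = 8, 64
state = nabla.uniform((stages, batch, width))      # activations held by each stage
weights = nabla.uniform((stages, width, width))    # one weight matrix per stage

# Mask of shape (stages, 1, 1): 1.0 for stage 0, 0.0 elsewhere.
# Assumes NumPy-style broadcasting inside nabla.where.
mask_0 = nabla.Tensor.constant([[[1.0]] if i == 0 else [[0.0]] for i in range(stages)])

for tick in range(2 * stages):
    fresh_input = nabla.uniform((stages, batch, width))  # only stage 0 consumes this
    state = pipeline_step(state, fresh_input, weights, mask_0)
```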
### 5. Dynamic Shape Compilation
Compile functions once with symbolic dimensions to handle varying input sizes without recompilation.
```python
# Compile once for ANY batch size (dim 0)
@nabla.compile(dynamic_dims={0: {0: "batch"}})
def square(x):
    return x * x

x_small = nabla.uniform((2, 10))
x_large = nabla.uniform((128, 10))

res1 = square(x_small)  # Triggers compilation
res2 = square(x_large)  # Reuses compiled graph!
```
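The same pattern should extend to multi-argument functions. A hedged sketch, assuming (as the example above suggests) that the outer `dynamic_dims` key is the positional argument index and the inner dict maps a dimension index to its symbolic name:

```python
# Mark dim 0 of the first argument (x) as the symbolic "batch"; w stays static.
@nabla.compile(dynamic_dims={0: {0: "batch"}})
def compiled_loss(x, w):
    return nabla.mean(nabla.relu(x @ w))

w = nabla.uniform((10, 16))
print(compiled_loss(nabla.uniform((4, 10)), w))    # compiles once
print(compiled_loss(nabla.uniform((256, 10)), w))  # reuses the same graph
```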
## Architecture Overview

Nabla relies on three core principles:

- **Lazy Execution:** Shapes are computed eagerly, but the computation graph is built and compiled only when `.realize()` is called.
- **Trace-Based Autodiff:** Gradients are computed by tracing the forward pass and replaying operations in reverse.
- **Factor-Based SPMD:** Sharding is propagated using “semantic factors” (e.g., batch, heads) rather than physical mesh axes.
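A minimal sketch of the lazy-execution model, assuming `.realize()` (mentioned above and in Section 1) is the explicit trigger:

```python
import nabla

# Building the expression is cheap: shapes are inferred eagerly,
# but no kernel runs yet.
x = nabla.uniform((4, 8))
w = nabla.uniform((8, 16))
y = nabla.relu(x @ w)
print(y.shape)  # eager shape inference, no compilation

# Graph construction, compilation, and execution happen here
# (printing y would also trigger this implicitly).
y.realize()
```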
## Contributing

- Bugs/Docs: Submit a PR directly.
- Features: Open an Issue first.
- New Ops: See `nabla/ops/README.md`.
## License
Apache-2.0 — see LICENSE