PyTorch Ease. JAX Scale. Mojo Speed.

Nabla is a modular scientific computing and machine learning framework combining imperative and functional APIs. Seamlessly drop custom Mojo kernels into the autodiff engine and automatically shard distributed workloads.

pip install --pre --extra-index-url https://whl.modular.com/nightly/simple/ modular nabla-ml

Getting Started: Setup, tensors, autodiff, and your first training loop.
API Reference: Core tensors, transforms, ops, and neural-network primitives.
Examples: Guided notebooks across MLPs, transformers, and fine-tuning.

Why Nabla

Eager Metadata, Compiled Speed

Get the debuggability of PyTorch with the performance of compiled graphs. Nabla computes shapes eagerly but defers graph building, using zero-overhead caching to skip recompilation on hot paths.
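
A minimal sketch of what this looks like in practice (nb.jit is assumed to be Nabla's compilation transform and nb.randn an array constructor; adjust the names if the actual API differs):

import nabla as nb

@nb.jit
def loss_fn(x, w):
    return nb.mean(nb.relu(x @ w))

x = nb.randn((32, 128))
w = nb.randn((128, 10))

print(x.shape)        # shapes and dtypes are available eagerly, before any compile
loss = loss_fn(x, w)  # first call builds and compiles the graph
loss = loss_fn(x, w)  # later calls with the same signature hit the cache, no recompilation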

Automatic Sharding (SPMD)

Write single-device code and let Nabla handle the rest. Nabla automatically propagates sharding specs and handles communication for distributed training and inference workloads.
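
As a minimal illustration of what "single-device code" means here (a sketch only; no sharding-specific Nabla API is shown, since this page does not document one):

import nabla as nb

# Ordinary single-device code: no collectives, no device placement.
def train_step(w, x):
    return nb.mean(nb.relu(x @ w))

# Under SPMD, this function is meant to run unchanged: Nabla propagates the
# sharding specs attached to w and x (for example, a batch dimension split
# across devices) through the graph and inserts the required communication.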

Custom Mojo Kernels in Autodiff

Need more performance? Drop down to Mojo for custom kernels and seamlessly integrate them into Nabla’s autodiff engine via nabla.call_custom_kernel(...).
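
A rough sketch of the intended flow (only the name nb.call_custom_kernel comes from the text above; the kernel reference and argument list below are illustrative assumptions, not the documented signature):

import nabla as nb

def fused_op(x):
    # Hypothetical call pattern: the kernel identifier here is a placeholder,
    # not the real parameter list of call_custom_kernel.
    return nb.call_custom_kernel("my_kernels::fused_op", x)

# The goal is that transforms such as nb.grad work through calls like this,
# since the custom kernel is registered with the autodiff engine.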

Modular by Design

Use nb.nn.Module and nb.nn.functional side by side. Nabla supports imperative and functional workflows in one framework, so you can pick the style that fits each task.
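
A small sketch of the two styles together, reusing only APIs shown on this page (the exact call pattern, such as modules being callable on arrays, is an assumption):

import nabla as nb

# Imperative style: a Module owns its parameters
model = nb.nn.Sequential(
    nb.nn.Linear(128, 10),
    nb.nn.ReLU(),
)

# Functional style: wrap the module in a pure function and compose transforms
def score(x):
    return nb.mean(model(x))  # assumes Modules are callable, PyTorch-style

input_grad = nb.grad(score)  # gradient of the score with respect to the input batch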

Try It Quickly

Build a model with the familiar Module API:

import nabla as nb

model = nb.nn.Sequential(
    nb.nn.Linear(128, 256),
    nb.nn.ReLU(),
    nb.nn.Linear(256, 10),
)

Differentiate pure functions with composable transforms:

import nabla as nb

# Example inputs; nb.randn is assumed to be available as an array constructor
x = nb.randn((32, 128))
w = nb.randn((128, 10))


def loss_fn(x, w):
    return nb.mean(nb.relu(x @ w))


# Gradient with respect to the second argument (w)
grad_w = nb.grad(loss_fn, argnums=1)(x, w)

Fine-tune with QLoRA primitives:

import nabla as nb

# Example frozen pretrained weight and a toy batch;
# nb.randn is assumed to be available as an array constructor
frozen_weight = nb.randn((128, 10))
x = nb.randn((32, 128))
y = nb.randn((32, 10))

# Quantize frozen weights to NF4
qweight = nb.nn.finetune.quantize_nf4(frozen_weight, block_size=64)

# Initialize LoRA adapter
lora_params = nb.nn.finetune.init_lora_adapter(frozen_weight, rank=8)

# Forward pass with QLoRA
def loss_fn(adapter, batch_x, batch_y):
    pred = nb.nn.finetune.qlora_linear(
        batch_x, qweight, adapter, alpha=16.0
    )
    return nb.mean((pred - batch_y) ** 2)

# Compute gradients
loss, grads = nb.value_and_grad(loss_fn)(lora_params, x, y)
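
From here, a training step only updates the adapter; a minimal hand-rolled sketch, assuming lora_params and grads are plain dicts of arrays with matching keys (Nabla may ship optimizers, but none are assumed here):

# Simple SGD step on the LoRA adapter; the frozen NF4 weights are untouched.
lr = 1e-2
lora_params = {k: lora_params[k] - lr * grads[k] for k in lora_params}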

From the team

Blog and release notes section coming soon.

Project Status

Nabla is currently in Alpha. It is an experimental framework designed to explore new ideas in ML infrastructure on top of Modular MAX. APIs are subject to change, and we welcome early adopters to join us in building the next generation of ML tools.