vjp#
Signature#
nabla.vjp(func: collections.abc.Callable[..., typing.Any], *primals, has_aux: bool = False) -> tuple[typing.Any, collections.abc.Callable] | tuple[typing.Any, collections.abc.Callable, typing.Any]
Description#
Compute vector-Jacobian product (reverse-mode autodiff).
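For reference, the quantity being computed can be written out as below. The symbols f, J_f, x and v are notation introduced here for illustration and do not come from the nabla documentation; the worked case uses the sum-of-squares function from the Examples section.

```latex
% f : R^n -> R^m with Jacobian J_f(x) in R^{m x n}; v in R^m is the cotangent.
\[
  \operatorname{vjp}(v) = v^{\top} J_f(x), \qquad
  \bigl(v^{\top} J_f(x)\bigr)_j
    = \sum_{i=1}^{m} v_i \, \frac{\partial f_i}{\partial x_j}(x).
\]
% Worked case: f(x) = \sum_j x_j^2 (so m = 1) has J_f(x) = 2 x^T.
\[
  v^{\top} J_f(x) = 2 v \, x^{\top}, \qquad
  x = (1, 2, 3),\ v = 1 \;\Longrightarrow\; v^{\top} J_f(x) = (2, 4, 6).
\]
```

When the output is a scalar and the cotangent is 1, the vector-Jacobian product is simply the gradient of f.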
Parameters#
func: Function to differentiate (should take positional arguments)
*primals: Positional arguments to the function (can be arbitrary pytrees)
has_aux: Optional, bool. If True, func must return a pair (output, aux): the first element is the output of the mathematical function to be differentiated, and the second element is auxiliary data that is returned as-is and not differentiated. Default False.
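The effect of has_aux on the return shape can be sketched without nabla. The helper toy_vjp below is a hypothetical stand-in written for this illustration: it only handles scalar inputs and a scalar output, and it approximates derivatives with central differences instead of reverse-mode autodiff. It is meant to mirror the calling convention described above, not nabla's implementation.

```python
def toy_vjp(func, *primals, has_aux=False, eps=1e-6):
    """Hypothetical stand-in for illustration only (not part of nabla).

    Scalar inputs, scalar output, derivatives by central differences.
    """
    def output_only(*args):
        # Drop the auxiliary data so only the differentiated output remains.
        result = func(*args)
        return result[0] if has_aux else result

    result = func(*primals)
    if has_aux:
        out, aux = result  # func returned the pair (output, aux)
    else:
        out, aux = result, None

    def vjp_fn(cotangent):
        grads = []
        for i in range(len(primals)):
            up = list(primals)
            down = list(primals)
            up[i] += eps
            down[i] -= eps
            slope = (output_only(*up) - output_only(*down)) / (2 * eps)
            grads.append(cotangent * slope)
        return tuple(grads)  # always a tuple, one entry per primal

    return (out, vjp_fn, aux) if has_aux else (out, vjp_fn)


def f(x):
    square = x * x
    # First element is differentiated; the dict is passed through untouched.
    return 3.0 * square, {"square": square}


out, vjp_fn, aux = toy_vjp(f, 2.0, has_aux=True)
(grad_x,) = vjp_fn(1.0)  # approximately 12.0, since d(3x^2)/dx = 6x
```

With nabla itself, the corresponding call would presumably be nb.vjp(f, x, has_aux=True), unpacked into the same three values.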
Returns#
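If has_aux is False: Tuple of (outputs, vjp_function), where vjp_function takes a cotangent with the same structure as outputs and returns the gradients with respect to each primal.
If has_aux is True: Tuple of (outputs, vjp_function, aux), where aux is the auxiliary data returned by func.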
The vjp_function always returns gradients as a tuple (matching JAX behavior):
Single argument: vjp_fn(cotangent) -> (gradient,)
Multiple arguments: vjp_fn(cotangent) -> (grad1, grad2, …)
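The tuple convention affects how results are unpacked. The sketch below uses two hand-derived VJPs for scalar functions (vjp_square and vjp_product are hypothetical names introduced here, not nabla functions) purely to show the shapes involved.

```python
def vjp_square(x):
    """Hand-written VJP of f(x) = x * x (illustrative, not nabla code)."""
    out = x * x

    def vjp_fn(cotangent):
        # One input, so the result is a one-element tuple.
        return (cotangent * 2.0 * x,)

    return out, vjp_fn


def vjp_product(x, y):
    """Hand-written VJP of f(x, y) = x * y (illustrative, not nabla code)."""
    out = x * y

    def vjp_fn(cotangent):
        # Two inputs, so the result holds one gradient per input, in order.
        return (cotangent * y, cotangent * x)

    return out, vjp_fn


out1, vjp1 = vjp_square(3.0)
(grad_x,) = vjp1(1.0)        # unpack the one-element tuple: 6.0

out2, vjp2 = vjp_product(2.0, 5.0)
grad_a, grad_b = vjp2(1.0)   # 5.0 and 2.0
```

Note the trailing comma in `(grad_x,) = ...`: without it, the variable would hold the tuple and not the gradient itself.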
Examples#
import nabla as nb

# Vector-Jacobian product for reverse-mode AD
def f(x):
    return nb.sum(x ** 2)

x = nb.array([1.0, 2.0, 3.0])
output, vjp_fn = nb.vjp(f, x)

# vjp_fn returns a tuple with one gradient per primal, even for a single argument
gradients = vjp_fn(nb.ones_like(output))
print(gradients[0])  # gradient values: [2.0, 4.0, 6.0], i.e. 2 * x
Notes#
This follows JAX’s vjp API exactly:
Only accepts positional arguments
Always returns gradients as tuple
For functions requiring keyword arguments, use functools.partial or lambda
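A minimal sketch of the keyword-argument workaround, using only the standard library. The function scaled_square_sum and its scale parameter are hypothetical, and plain Python lists stand in for arrays so the snippet runs without nabla.

```python
from functools import partial


def scaled_square_sum(x, scale=1.0):
    """Hypothetical function with a keyword argument we want to fix."""
    return scale * sum(v * v for v in x)


# Both wrappers bind scale=3.0 and expose a positional-only signature f(x).
by_partial = partial(scaled_square_sum, scale=3.0)
by_lambda = lambda x: scaled_square_sum(x, scale=3.0)

value = by_partial([1.0, 2.0])  # 3.0 * (1.0 + 4.0) = 15.0
```

Either wrapper can then be passed as func together with positional primals, for example nb.vjp(by_partial, x).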
See Also#
jvp - Jacobian-vector product
grad - Automatic differentiation