jvp#

Signature#

nabla.jvp(func: collections.abc.Callable[..., typing.Any], primals, tangents, has_aux: bool = False) -> tuple[typing.Any, typing.Any] | tuple[typing.Any, typing.Any, typing.Any]

Description#

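Compute a Jacobian-vector product (forward-mode autodiff): evaluate func at primals and propagate tangents through it to obtain the directional derivative of the outputs.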

Parameters#

  • func: Function to differentiate (should take positional arguments).

  • primals: Positional arguments to the function (can be arbitrary pytrees).

  • tangents: Tangent vectors for directional derivatives (matching the structure of primals).

  • has_aux: Optional, bool. Indicates whether func returns a pair where the first element is the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
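As a minimal sketch of how these parameters fit together, the example below uses jax.jvp rather than nabla.jvp, on the assumption (based on the note on this page that the API follows JAX's) that the call shape is the same. The function affine and its params dictionary are hypothetical names chosen for illustration; whether nabla accepts the same array types is not verified here.

```python
import jax
import jax.numpy as jnp

# Hypothetical function: takes a pytree (dict) and an array, both positionally.
def affine(params, x):
    return params["w"] * x + params["b"]

# Primals: one entry per positional argument of `affine`.
params = {"w": jnp.array(2.0, dtype=jnp.float32),
          "b": jnp.array(1.0, dtype=jnp.float32)}
x = jnp.array(3.0, dtype=jnp.float32)

# Tangents: same pytree structure and dtypes as the primals.
# Here we perturb only `w`, so the result is d(out)/dw.
params_tangent = {"w": jnp.array(1.0, dtype=jnp.float32),
                  "b": jnp.array(0.0, dtype=jnp.float32)}
x_tangent = jnp.array(0.0, dtype=jnp.float32)

# Assumption: nabla.jvp would be called the same way as jax.jvp here.
out, out_tangent = jax.jvp(affine, (params, x), (params_tangent, x_tangent))
```

With these inputs, out is w * x + b and out_tangent is the derivative along the chosen direction, which reduces to x because only w is perturbed.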

Returns#

If has_aux is False, returns an (outputs, output_tangents) pair. If has_aux is True, returns an (outputs, output_tangents, aux) tuple, where aux is the auxiliary data returned by func.
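The two return shapes can be illustrated with a short sketch. It again uses jax.jvp as a stand-in, assuming nabla.jvp behaves the same way for has_aux; the function sum_of_squares and the aux dictionary key are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical function returning (output, aux). Only the first element
# is differentiated; the second is passed through unchanged.
def sum_of_squares(x):
    squares = x ** 2
    return jnp.sum(squares), {"squares": squares}

x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
v = jnp.array([1.0, 0.0, 0.0], dtype=jnp.float32)  # direction: first coordinate

# Assumption: nabla.jvp returns the same 3-tuple as jax.jvp when has_aux=True.
out, out_tangent, aux = jax.jvp(sum_of_squares, (x,), (v,), has_aux=True)
```

Here out_tangent is the derivative of the sum along v (that is, 2 * x[0]), while aux carries the intermediate squares without a tangent attached.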

Notes#

This follows JAX’s jvp API:

  • Only accepts positional arguments

  • For functions requiring keyword arguments, use functools.partial or lambda
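A sketch of the keyword-argument workaround described above, written against jax.jvp on the assumption that nabla.jvp imposes the same positional-only restriction. The function scaled_square and its scale keyword are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical function with a keyword-only argument.
def scaled_square(x, *, scale=1.0):
    return scale * x ** 2

x = jnp.array(2.0, dtype=jnp.float32)
v = jnp.array(1.0, dtype=jnp.float32)

# Option 1: bind the keyword argument with functools.partial.
f_partial = partial(scaled_square, scale=3.0)
out_p, tangent_p = jax.jvp(f_partial, (x,), (v,))

# Option 2: bind it with a lambda that takes only positional arguments.
f_lambda = lambda x: scaled_square(x, scale=3.0)
out_l, tangent_l = jax.jvp(f_lambda, (x,), (v,))
```

Both wrappers expose a function of x alone, so scale is treated as a constant and receives no tangent.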

See Also#

  • vjp - Vector-Jacobian product

  • grad - Gradient computation