xpr#
Signature#
nabla.xpr(fn: 'Callable[..., Any]', *primals) -> 'str'
Description#
Get a JAX-like string representation of the function’s computation graph.
Parameters#
fn: Function to trace (should take positional arguments) *primals: Positional arguments to the function (can be arbitrary pytrees)
Returns#
JAX-like string representation of the computation graph
Notes#
This follows the same flexible API as vjp, jvp, and vmap:
Accepts functions with any number of positional arguments
For functions requiring keyword arguments, use functools.partial or lambda