jax-js
    Function vjp

    • Calculate the reverse-mode vector-Jacobian product for a function.

      The return value is a tuple of [out, vjpFn], where out is the output of f(primals), and vjpFn is a function that takes in cotangents for each output and returns the cotangents for each input.
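
      To make the [out, vjpFn] shape concrete, here is a minimal hand-written sketch that does not use jax-js at all. The helper name vjpManual and the type Vec2 are hypothetical and exist only for this illustration; the Jacobian is worked out by hand for one fixed function rather than derived automatically.

```typescript
// Illustration only: a hand-derived VJP, not the jax-js implementation.
type Vec2 = [number, number];

// f(x) = [x0 * x1, x0 + x1], with Jacobian J = [[x1, x0], [1, 1]].
function vjpManual(x: Vec2): [Vec2, (cotangent: Vec2) => Vec2] {
  const [x0, x1] = x;
  const out: Vec2 = [x0 * x1, x0 + x1];
  // vjpFn maps an output cotangent v to the input cotangent v^T J.
  const vjpFn = (cotangent: Vec2): Vec2 => [
    cotangent[0] * x1 + cotangent[1],
    cotangent[0] * x0 + cotangent[1],
  ];
  return [out, vjpFn];
}

const [yManual, pullback] = vjpManual([2, 3]);
// Pulling back the cotangent [1, 0] selects the gradient of the first output.
const gradFirstOutput = pullback([1, 0]);
```

      Passing a one-hot cotangent picks out one row of the Jacobian, which is why a scalar-output gradient needs only a single call to vjpFn.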

      When { hasAux: true } is passed, the function f is expected to return an [out, aux] tuple, and vjp returns [out, vjpFn, aux].
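
      The [out, vjpFn, aux] shape can be sketched the same way, again without jax-js. The names sumSquaresWithAux and maxAbs below are hypothetical; the point is only that aux is returned alongside the output and plays no part in the cotangent computation.

```typescript
// Illustration only: mirrors the [out, vjpFn, aux] return shape by hand.
type Pair = [number, number];

// out = x0^2 + x1^2 is differentiated; aux carries extra data that is not.
function sumSquaresWithAux(
  x: Pair,
): [number, (cotangent: number) => Pair, { maxAbs: number }] {
  const [x0, x1] = x;
  const out = x0 * x0 + x1 * x1;
  // d(out)/dx = [2 * x0, 2 * x1], scaled by the incoming cotangent.
  const vjpFn = (cotangent: number): Pair => [
    2 * x0 * cotangent,
    2 * x1 * cotangent,
  ];
  const aux = { maxAbs: Math.max(Math.abs(x0), Math.abs(x1)) };
  return [out, vjpFn, aux];
}

const [outAux, pullbackAux, auxInfo] = sumSquaresWithAux([3, 4]);
const gradAux = pullbackAux(1);
```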

      Type Parameters

      • F extends (...args: any[]) => JsTree<Array>
      • const HA extends boolean = false

      Parameters

      Returns HA extends true
          ? ReturnType<F> extends [Out, Aux]
              ? [
                  Out,
                  OwnedFunction<
                      (
                          cotangents: MappedJsTree<Out, Array, ArrayLike>,
                      ) => MappedJsTree<Parameters<F>, ArrayLike, Array>,
                  >,
                  Aux,
              ]
              : never
          : [
              ReturnType<F>,
              OwnedFunction<
                  (
                      cotangents: MappedJsTree<ReturnType<F>, Array, ArrayLike>,
                  ) => MappedJsTree<Parameters<F>, ArrayLike, Array>,
              >,
          ]

      // Default: f returns the output only.
      const [y, vjpFn] = vjp(f, [x]);

      // With hasAux: here f must return an [out, aux] tuple.
      const [out, vjpFnAux, aux] = vjp(f, [x], { hasAux: true });