jax-js
    Preparing search index...

    Function grad

    • Compute the gradient of a scalar-valued function f with respect to its first argument.

      Pass a different `argnums` to differentiate with respect to other arguments. If an array of argument indices is provided (e.g. `{ argnums: [0, 2] }`), the return value is a tuple containing one gradient per listed index, in the same order as the indices.

      When `{ hasAux: true }` is passed, `f` is expected to return an `[out, aux]` tuple, where `out` is the scalar value being differentiated and `aux` is auxiliary data. The return value is then `[gradient, aux]`.

      Type Parameters

      • F extends (...args: any[]) => JsTree<Array>
      • const I extends number | number[] | undefined = undefined
      • const HA extends boolean = false

      Parameters

      • f: F
      • opts (optional): Omit<{ argnums?: number | number[]; hasAux?: boolean }, "hasAux" | "argnums"> & {
            argnums?: I;
            hasAux?: HA;
        }

      Returns (
          ...primals: MappedJsTree<Parameters<F>, Array, ArrayLike>,
      ) => HA extends true
          ? ReturnType<F> extends [any, Aux]
              ? [
                  MappedJsTree<
                      I extends undefined
                          ? Parameters<F>[0]
                          : I extends number
                              ? Parameters<F>[I]
                              : I extends number[]
                                  ? {
                                      [K in keyof I]: I[K] extends number ? Parameters<F>[I[K]] : never
                                  }
                                  : never,
                      ArrayLike,
                      Array,
                  >,
                  Aux,
              ]
              : never
          : MappedJsTree<
              I extends undefined
                  ? Parameters<F>[0]
                  : I extends number
                      ? Parameters<F>[I]
                      : I extends number[]
                          ? {
                              [K in keyof I]: I[K] extends number ? Parameters<F>[I[K]] : never
                          }
                          : never,
              ArrayLike,
              Array,
          >

      const gradient = grad(f)(x);

      // With `argnums`
      const [gradientX, gradientZ] = grad(f, { argnums: [0, 2] })(x, y, z);

      // With `hasAux`
      const [gradient, aux] = grad(f, { hasAux: true })(x);