
    @jax-js/optax

    This is a port of Optax to jax-js, providing composable gradient processing and optimization transformations. It includes implementations of common algorithms like Adam and SGD. The example below fits a parameter vector to np.ones([3]) by minimizing a squared error with Adam.

    import { grad, numpy as np } from "jax-js"; // core import path may vary
    import { adam, applyUpdates, squaredError } from "@jax-js/optax";

    // Minimize f(x) = sum((x - 1)^2), starting from [1, 2, 3].
    const f = (x: np.Array) => squaredError(x, np.ones([3])).sum();

    let params = np.array([1.0, 2.0, 3.0]);
    const solver = adam(1e-3);
    let optState = solver.init(params.ref);
    let updates: np.Array;

    for (let i = 0; i < 100; i++) {
      const paramsGrad = grad(f)(params.ref); // gradient of f at params
      [updates, optState] = solver.update(paramsGrad, optState);
      params = applyUpdates(params, updates);
    }
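
    Each step turns the raw gradient into an Adam-rescaled update and applies it, so params moves toward np.ones([3]), the minimizer of f.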

    Interfaces

    GradientTransformation
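
    Every optimizer in this module, including the adam solver used above, is a GradientTransformation: a pair of pure init and update functions. The sketch below is inferred from the usage example at the top of this page; the library's exact signatures may differ.

    import { numpy as np } from "jax-js"; // core import path may vary
    type OptState = unknown; // stand-in for the module's exported OptState alias

    // Sketch only: shapes inferred from solver.init / solver.update above.
    interface GradientTransformation {
      // Build the initial optimizer state (e.g. Adam's moment estimates).
      init(params: np.Array): OptState;
      // Turn raw gradients into updates, threading the state through.
      update(grads: np.Array, state: OptState): [np.Array, OptState];
    }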

    Type Aliases

    AdamWOptions
    AddDecayedWeightsOptions
    NormOrd
    OptState
    ScaleByAdamOptions
    SgdOptions
    TraceOptions
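
    The *Options aliases configure the corresponding optimizer factories. A hedged illustration follows: the option names below mirror upstream Optax and are assumptions about this port, so consult AdamWOptions and SgdOptions for the actual fields.

    import { adamw, sgd } from "@jax-js/optax";

    // Assumed option fields, mirroring upstream Optax; check the
    // AdamWOptions / SgdOptions type aliases for the real names.
    const decayedAdam = adamw(1e-3, { weightDecay: 1e-4 });
    const momentumSgd = sgd(1e-2, { momentum: 0.9, nesterov: true });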

    Functions

    adam
    adamw
    addDecayedWeights
    applyUpdates
    chain
    clipByGlobalNorm
    identity
    l2Loss
    scale
    scaleByAdam
    scaleByLearningRate
    scaleBySchedule
    setToZero
    sgd
    squaredError
    trace
    treeBiasCorrection
    treeMax
    treeNorm
    treeOnesLike
    treeSum
    treeUpdateMoment
    treeZerosLike
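
    These building blocks compose with chain, the same way upstream Optax defines its aliases (adam is essentially scaleByAdam followed by scaleByLearningRate). A minimal sketch, assuming the factories take the arguments shown:

    import {
      chain,
      clipByGlobalNorm,
      scaleByAdam,
      scaleByLearningRate,
    } from "@jax-js/optax";

    // Hedged sketch of composing gradient transformations; exact
    // factory arguments are assumptions based on upstream Optax.
    const solver = chain(
      clipByGlobalNorm(1.0),     // clip gradients to a max global norm of 1.0
      scaleByAdam(),             // rescale by bias-corrected Adam moments
      scaleByLearningRate(1e-3), // in Optax this scales by -learningRate
    );
    // The result is itself a GradientTransformation, usable with the
    // init / update / applyUpdates loop shown at the top of this page.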