jax-js

    Function categorical

    • Sample random values from categorical distributions.

      Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k trick for sampling without replacement.
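      The Gumbel max trick can be illustrated outside jax-js. The following is a minimal plain-TypeScript sketch, not the jax-js implementation. It draws one index from a single 1-D array of logits. The seeded generator `mulberry32` and the helpers `gumbel` and `sampleGumbelMax` are hypothetical stand-ins for the PRNG key and the library internals.

```typescript
// Hypothetical seeded PRNG (mulberry32); stands in for a jax-js PRNG key.
// Returns floats in [0, 1).
function mulberry32(seed: number): () => number {
  let a = seed >>> 0;
  return () => {
    a = (a + 0x6d2b79f5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

// Standard Gumbel(0, 1) noise: -log(-log(u)) with u uniform on (0, 1).
function gumbel(rng: () => number): number {
  const u = rng() + 0.5 / 4294967296; // shift into the open interval (0, 1)
  return -Math.log(-Math.log(u));
}

// Gumbel max trick: argmax_i (logits[i] + g_i) with i.i.d. g_i ~ Gumbel(0, 1)
// is distributed as Categorical(softmax(logits)).
function sampleGumbelMax(logits: number[], rng: () => number): number {
  let best = -1;
  let bestScore = -Infinity;
  for (let i = 0; i < logits.length; i++) {
    const score = logits[i] + gumbel(rng);
    if (score > bestScore) {
      bestScore = score;
      best = i;
    }
  }
  return best;
}
```

      Drawing many samples with logits `[Math.log(0.2), Math.log(0.8)]` should select index 1 roughly 80% of the time, because softmax of those logits is `[0.2, 0.8]`. Sampling with replacement amounts to repeating this draw independently with fresh noise.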

      Note: Sampling without replacement currently performs a full argsort of the perturbed logits and keeps the last k indices (the k largest). This should be replaced with a more efficient topK implementation.
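      Sampling without replacement with the Gumbel top-k trick, done the way the note describes (full argsort, keep the last k), might look like the sketch below. It is plain TypeScript rather than jax-js, works on a single 1-D logits array, and `mulberry32`, `gumbel` and `sampleGumbelTopK` are hypothetical names. It returns indices highest-score first; this page does not say what order jax-js uses.

```typescript
// Hypothetical seeded PRNG (mulberry32); stands in for a jax-js PRNG key.
function mulberry32(seed: number): () => number {
  let a = seed >>> 0;
  return () => {
    a = (a + 0x6d2b79f5) >>> 0;
    let t = a;
    t = Math.imul(t ^ (t >>> 15), t | 1);
    t ^= t + Math.imul(t ^ (t >>> 7), t | 61);
    return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
  };
}

// Standard Gumbel(0, 1) noise with u kept strictly inside (0, 1).
function gumbel(rng: () => number): number {
  const u = rng() + 0.5 / 4294967296;
  return -Math.log(-Math.log(u));
}

// Gumbel top-k trick: the indices of the k largest perturbed logits are a
// sample of k distinct categories drawn without replacement.
function sampleGumbelTopK(logits: number[], k: number, rng: () => number): number[] {
  if (k > logits.length) throw new RangeError("k exceeds the number of categories");
  const scores = logits.map((l) => l + gumbel(rng));
  // Ascending argsort of the perturbed scores: O(n log n).
  const order = scores.map((_, i) => i).sort((a, b) => scores[a] - scores[b]);
  // The last k entries are the k largest; reverse so the highest score comes first.
  return order.slice(order.length - k).reverse();
}
```

      A dedicated top-k selection (for example a partial sort or a size-k heap) would avoid sorting all n scores, which is the kind of improvement the note points at.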

      • key - PRNG key
      • logits - Unnormalized log probabilities of the categorical distribution(s). softmax(logits, axis) gives the corresponding probabilities.
      • axis - Axis along which logits belong to the same categorical distribution.
      • shape - Result batch shape. Must be broadcast-compatible with logits.shape with axis removed. Default is logits.shape with axis removed.
      • replace - If true (default), sample with replacement. If false, sample without replacement (each category can only be selected once per batch).

      Parameters

      • ...args: [
            key: ArrayLike,
            logits: ArrayLike,
            { axis?: number; replace?: boolean; shape?: number[] }?,
        ]

      Returns Array

      A random array with int dtype and shape given by shape if provided, otherwise logits.shape with axis removed.
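      The shape rules for `axis`, `shape` and the return value can be summarized by a small helper. This is a hypothetical sketch, not a jax-js function. It assumes `axis` defaults to -1 (as in JAX; this page does not state the default) and reads "broadcast-compatible" as NumPy-style broadcasting of the batch shape to `shape`.

```typescript
// Hypothetical helper (not part of jax-js): result shape of
// categorical(key, logits, { axis, shape }) under the assumptions above.
function categoricalResultShape(logitsShape: number[], axis = -1, shape?: number[]): number[] {
  const ndim = logitsShape.length;
  const ax = axis < 0 ? axis + ndim : axis; // normalize a negative axis
  if (ax < 0 || ax >= ndim) throw new RangeError(`axis ${axis} out of range for ndim ${ndim}`);
  // Batch shape: logits.shape with the category axis removed (the default result shape).
  const batch = logitsShape.filter((_, i) => i !== ax);
  if (shape === undefined) return batch;
  // Assumed rule: right-align the batch shape against `shape`; every batch
  // dimension must equal the matching target dimension or be 1.
  if (batch.length > shape.length) throw new Error("shape has fewer dimensions than the batch shape");
  const offset = shape.length - batch.length;
  batch.forEach((d, i) => {
    if (d !== 1 && d !== shape[offset + i]) {
      throw new Error(`batch shape [${batch}] does not broadcast to [${shape}]`);
    }
  });
  return [...shape];
}
```

      For example, logits of shape `[3, 5]` describe three 5-way distributions along the last axis, so the default result shape is `[3]`; passing `shape: [4, 3]` would draw four samples from each of the three distributions.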

    Properties

    dispose: () => void