jax-js

    Function dotProductAttention

    • Scaled dot product attention (SDPA).

      Computes softmax((Q @ K^T) / sqrt(d) + bias) @ V, where Q is the query, K is the key, V is the value, and d is the dimensionality of each key and query vector (H below).

      Multi-query attention is applied when the input key and value tensors have fewer heads than the query tensor.

      We use the following uppercase letters to denote array shapes:

      • B = batch size
      • S = length of key/value sequences (source)
      • L = length of query sequences
      • N = number of attention heads
      • H = dimensionality of each attention head
      • K = number of key/value heads (for grouped-query attention)

      The batch size B may be omitted, which is equivalent to B = 1; if omitted, it must be omitted from all inputs.
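
      As a point of reference, the following is a dependency-free TypeScript sketch of the single-head, unbatched semantics (B and N omitted, no bias or mask). It is illustrative only; none of these names are part of the jax-js API.

          // Reference semantics for softmax((q @ k^T) / sqrt(H)) @ v on one head.
          // q: [L, H], k: [S, H], v: [S, H]; returns [L, H].
          function referenceSdpa(q: number[][], k: number[][], v: number[][]): number[][] {
            const H = q[0].length;
            const scale = 1 / Math.sqrt(H);
            return q.map((qRow) => {
              // Scaled dot-product logits against every key position.
              const logits = k.map(
                (kRow) => kRow.reduce((acc, x, i) => acc + x * qRow[i], 0) * scale,
              );
              // Numerically stable softmax over the source dimension S.
              const max = Math.max(...logits);
              const exps = logits.map((x) => Math.exp(x - max));
              const denom = exps.reduce((a, b) => a + b, 0);
              // Output row is the attention-weighted sum of the value rows.
              const out = new Array<number>(v[0].length).fill(0);
              exps.forEach((e, s) => {
                v[s].forEach((x, i) => {
                  out[i] += (e / denom) * x;
                });
              });
              return out;
            });
          }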

      Parameters

      • query: ArrayLike

        Query array; shape [B, L, N, H]

      • key: ArrayLike

        Key array; shape [B, S, K, H]

      • value: ArrayLike

        Value array; same shape as key

      • opts: {
            bias?: ArrayLike;
            isCausal?: boolean;
            keyValueSeqLengths?: ArrayLike;
            localWindowSize?: number | [number, number];
            mask?: ArrayLike;
            querySeqLengths?: ArrayLike;
            scale?: number;
        } = {}
        • Optional bias?: ArrayLike

          Optional bias to add to the attention logits; shape [B, N, L, S] or broadcastable to it.

        • Optional isCausal?: boolean

          If true, applies a causal mask.

        • Optional keyValueSeqLengths?: ArrayLike

          Optional sequence lengths for the keys and values; shape [B]. Valid positions are taken from the beginning of each sequence.

        • Optional localWindowSize?: number | [number, number]

          If specified, restricts attention to a local window of the given size; either a single number or a [left, right] pair of window sizes.

        • Optional mask?: ArrayLike

          Optional mask to apply to the attention logits; should be a boolean array broadcastable to [B, N, L, S], where true indicates the element should take part in attention.

        • Optional querySeqLengths?: ArrayLike

          Optional sequence lengths for the queries; shape [B]. Valid positions are taken from the beginning of each sequence.

        • Optional scale?: number

          Scaling factor override; the default is 1 / sqrt(H).

      Returns Array

      The result of the attention operation; shape is the same as query [B, L, N, H], or [L, N, H] if B is omitted.
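
      A hypothetical usage sketch tying the shapes together. The import path is an assumption about jax-js's packaging, and zeros() is a local helper defined here, not a jax-js function; only dotProductAttention and its options are documented above.

          // Hypothetical usage sketch; the import path is an assumption, and
          // zeros() below is a local helper, not a jax-js function.
          import { dotProductAttention } from "jax-js";

          // Build a nested array of zeros with the given shape (an ArrayLike input).
          function zeros(shape: number[]): any {
            return shape.length === 0
              ? 0
              : Array.from({ length: shape[0] }, () => zeros(shape.slice(1)));
          }

          const [B, L, S, N, K, H] = [2, 16, 16, 8, 2, 64];

          // Grouped-query attention: N = 8 query heads share K = 2 key/value heads.
          const query = zeros([B, L, N, H]);
          const key = zeros([B, S, K, H]);
          const value = zeros([B, S, K, H]);

          // Causal attention with the default 1 / sqrt(H) scaling.
          const out = dotProductAttention(query, key, value, { isCausal: true });
          // out has the same shape as query: [B, L, N, H].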