query: ArrayLike
Query array; shape [B, L, N, H].
key: ArrayLike
Key array; shape [B, S, K, H].
value: ArrayLike
Value array; same shape as key.
bias?: ArrayLike
Optional bias to add to the attention logits; shape [B, N, L, S] or broadcastable to it.
isCausal?: boolean
If true, applies a causal mask (see the sketch after this parameter list for how this combines with mask and localWindowSize).
keyValueSeqLengths?: ArrayLike
Optional sequence lengths for the keys and values; shape (B,). Taken from the beginning of the tensor.
localWindowSize?: number | [number, number]
If specified, applies a local attention window of the given size. Can be a single number or a tuple [left, right].
mask?: ArrayLike
Optional mask to apply to the attention logits; should be a boolean array broadcastable to [B, N, L, S], where true indicates the element should take part in attention.
querySeqLengths?: ArrayLike
Optional sequence lengths for the queries; shape (B,). Taken from the beginning of the tensor.
scale?: number
Scaling factor override; default is 1 / sqrt(H).
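The isCausal, localWindowSize, querySeqLengths, and keyValueSeqLengths options can be thought of as implicit boolean masks combined (logically ANDed) with any explicit mask. The sketch below builds such a mask for one batch element under one common convention: causal masking keeps key positions j <= i for query position i, a local window [left, right] keeps i - left <= j <= i + right, and sequence lengths keep only the leading valid positions. The helper name and the symmetric interpretation of a single-number window are assumptions for illustration, not part of the documented API.

// Illustrative only: build an explicit [L, S] boolean mask equivalent to the
// isCausal / localWindowSize / sequence-length options for one batch element.
// Assumptions: causal masking keeps j <= i, and a single-number window is
// treated as a symmetric [window, window] pair.
function buildMask(
  L: number, // query sequence length
  S: number, // key/value sequence length
  opts: {
    isCausal?: boolean;
    localWindowSize?: number | [number, number];
    querySeqLength?: number; // number of valid query positions (prefix)
    keyValueSeqLength?: number; // number of valid key/value positions (prefix)
  } = {},
): boolean[][] {
  const [left, right] =
    typeof opts.localWindowSize === "number"
      ? [opts.localWindowSize, opts.localWindowSize]
      : (opts.localWindowSize ?? [Infinity, Infinity]);
  const mask: boolean[][] = [];
  for (let i = 0; i < L; i++) {
    const row: boolean[] = [];
    for (let j = 0; j < S; j++) {
      let keep = true;
      if (opts.isCausal) keep &&= j <= i; // causal: no attending to the future
      keep &&= j >= i - left && j <= i + right; // local attention window
      if (opts.querySeqLength !== undefined) keep &&= i < opts.querySeqLength;
      if (opts.keyValueSeqLength !== undefined) keep &&= j < opts.keyValueSeqLength;
      row.push(keep);
    }
    mask.push(row);
  }
  return mask;
}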
The result of the attention operation; shape is the same as query
[B, L, N, H], or [L, N, H] if B is omitted.
Scaled dot product attention (SDPA).
Computes softmax((Q @ K^T) / sqrt(d) + bias) @ V, where Q is the query, K is the key, V is the value, and d is the dimensionality of each key and query vector. Multi-query attention is applied when the input key and value tensors have fewer heads than query.

We use the following uppercase letters to denote array shapes:

B = batch size
S = length of key/value sequences (source)
L = length of query sequences
N = number of attention heads
H = dimensionality of each attention head
K = number of key/value heads (for grouped-query attention)

The batch size B may be omitted, which is equivalent to B = 1. In this case it must be omitted from all inputs.
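As a concrete illustration of the formula and shape conventions above, the following is a minimal, unoptimized reference computation for a single batch element: query shaped [L, N, H], key and value shaped [S, K, H], output shaped [L, N, H]. It is a sketch, not the library's implementation; it omits bias and mask handling (a bias would be added to the logits before the softmax, and masked-out positions would be set to -Infinity), and it assumes N is a multiple of K, mapping query head n to key/value head floor(n / (N / K)), a common grouped-query convention.

// Reference SDPA for one batch element (illustrative only).
// q: [L, N, H], k: [S, K, H], v: [S, K, H]; returns [L, N, H].
function sdpaReference(
  q: number[][][],
  k: number[][][],
  v: number[][][],
  scale?: number,
): number[][][] {
  const L = q.length, N = q[0].length, H = q[0][0].length;
  const S = k.length, K = k[0].length;
  const s = scale ?? 1 / Math.sqrt(H); // default scale is 1 / sqrt(H)
  const group = N / K; // query heads per key/value head
  const out: number[][][] = [];
  for (let i = 0; i < L; i++) {
    const row: number[][] = [];
    for (let n = 0; n < N; n++) {
      const kvHead = Math.floor(n / group);
      // Attention logits for query position i, head n: scaled dot products.
      const logits: number[] = [];
      for (let j = 0; j < S; j++) {
        let dot = 0;
        for (let h = 0; h < H; h++) dot += q[i][n][h] * k[j][kvHead][h];
        logits.push(dot * s);
      }
      // Numerically stable softmax over the key axis.
      const m = Math.max(...logits);
      const exps = logits.map((x) => Math.exp(x - m));
      const denom = exps.reduce((a, b) => a + b, 0);
      const weights = exps.map((e) => e / denom);
      // Weighted sum of value vectors.
      const o: number[] = new Array(H).fill(0);
      for (let j = 0; j < S; j++) {
        for (let h = 0; h < H; h++) o[h] += weights[j] * v[j][kvHead][h];
      }
      row.push(o);
    }
    out.push(row);
  }
  return out;
}

For example, with a query of shape [128, 8, 64] and key/value of shape [128, 2, 64] (N = 8, K = 2), each group of four consecutive query heads shares one key/value head, and the output has shape [128, 8, 64], matching the query.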