Sample random values from categorical distributions.
Uses the Gumbel max trick for sampling with replacement, or the Gumbel top-k
trick for sampling without replacement.
Note: Sampling without replacement currently uses argsort and slices the last
k elements. This should be replaced with a more efficient topK implementation.
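
The two tricks above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not this function's implementation: it uses plain Python lists instead of arrays, a seeded `random.Random` instead of a PRNG key, and the helper names (`gumbel_perturb`, `top_k_argsort`, `top_k_heap`) are hypothetical.

```python
import heapq
import math
import random


def gumbel_perturb(logits, rng):
    """Add independent standard Gumbel noise, -log(-log(U)), to each logit."""
    perturbed = []
    for logit in logits:
        u = rng.random()
        while u == 0.0:  # avoid log(0); random() can return exactly 0.0
            u = rng.random()
        perturbed.append(logit - math.log(-math.log(u)))
    return perturbed


def sample_with_replacement(logits, num_samples, rng):
    """Gumbel max: the argmax of freshly perturbed logits is one categorical draw."""
    draws = []
    for _ in range(num_samples):
        perturbed = gumbel_perturb(logits, rng)
        draws.append(max(range(len(logits)), key=perturbed.__getitem__))
    return draws


def top_k_argsort(values, k):
    """Full ascending argsort, then slice the last k indices (as the note describes)."""
    order = sorted(range(len(values)), key=values.__getitem__)
    return order[len(order) - k:]


def top_k_heap(values, k):
    """Partial selection of the k largest indices, without sorting everything."""
    return heapq.nlargest(k, range(len(values)), key=values.__getitem__)


def sample_without_replacement(logits, k, rng):
    """Gumbel top-k: perturb once and keep the k largest, so indices are distinct."""
    return top_k_argsort(gumbel_perturb(logits, rng), k)


rng = random.Random(0)  # stand-in for the PRNG key
with_repl = sample_with_replacement([0.0, 1.0, 2.0], 5, rng)
without_repl = sample_without_replacement([0.0, 1.0, 2.0, 3.0], 3, rng)
```

`top_k_heap` selects the same indices as `top_k_argsort` (in a different order) and shows one way a partial top-k selection can avoid the full sort.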
key - PRNG key
logits - Unnormalized log probabilities of the categorical distribution(s).
softmax(logits, axis) gives the corresponding probabilities.
axis - Axis along which logits belong to the same categorical distribution.
shape - Result batch shape. Must be broadcast-compatible with
logits.shape with axis removed. Default is logits.shape with axis removed.
replace - If true (default), sample with replacement. If false, sample
without replacement (each category can only be selected once per batch).
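
The relationship between logits, softmax(logits, axis), and the default result shape can be checked with a small sketch. It makes several assumptions: logits are a list of rows, the last axis is the categorical axis, a seeded `random.Random` replaces the PRNG key, and `sample_rows` with its `num_draws` argument is a hypothetical helper rather than this function's signature.

```python
import math
import random


def softmax(row):
    """Probabilities implied by one row of unnormalized log probabilities."""
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def draw_index(row, rng):
    """One Gumbel max draw from a single categorical distribution."""
    best, best_val = 0, -math.inf
    for i, logit in enumerate(row):
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        val = logit - math.log(-math.log(u))
        if val > best_val:
            best, best_val = i, val
    return best


def sample_rows(logits, rng, num_draws=None):
    """Sample along the last axis of a 2-D list of logits.

    num_draws=None mimics the default shape: one index per row, i.e. the
    shape of logits with the categorical axis removed.
    num_draws=n mimics a result shape of (n, rows), which is
    broadcast-compatible with that default shape.
    """
    if num_draws is None:
        return [draw_index(row, rng) for row in logits]
    return [[draw_index(row, rng) for row in logits] for _ in range(num_draws)]


logits = [
    [0.0, 0.0, 0.0],                                # uniform over 3 categories
    [math.log(0.1), math.log(0.2), math.log(0.7)],  # probabilities 0.1, 0.2, 0.7
]
rng = random.Random(0)  # stand-in for the PRNG key
default_result = sample_rows(logits, rng)         # one index per row
batched_result = sample_rows(logits, rng, 20000)  # 20000 draws for each row

# Empirical frequency of category 2 in the second row; it should be close
# to softmax(logits[1])[2].
freq = sum(1 for draw in batched_result if draw[1] == 2) / len(batched_result)
```

With enough draws, `freq` approaches `softmax(logits[1])[2]`, which illustrates that the logits only need to be log probabilities up to an additive constant.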