Log-mean-exp reduction: like jax.nn.logsumexp(), but subtracts log(n), where n is the number of elements being reduced.
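As an illustration of the identity log(mean(exp(x))) = logsumexp(x) - log(n), here is a minimal sketch written as a hypothetical helper named logmeanexp. The name, the axis/keepdims parameters and the element-counting logic are assumptions for this example, not this library's actual implementation. It relies only on jax.nn.logsumexp, which already applies the max-subtraction trick for numerical stability.

```python
import math

import jax
import jax.numpy as jnp


def logmeanexp(x, axis=None, keepdims=False):
    """Hypothetical helper: log(mean(exp(x))) computed stably via logsumexp."""
    x = jnp.asarray(x)
    # n is the number of elements reduced over (all of them when axis is None).
    if axis is None:
        n = x.size
    else:
        axes = (axis,) if isinstance(axis, int) else tuple(axis)
        n = math.prod(x.shape[a] for a in axes)
    # logsumexp sums exp(x) stably; subtracting log(n) turns the sum into a mean.
    return jax.nn.logsumexp(x, axis=axis, keepdims=keepdims) - jnp.log(n)
```

For example, logmeanexp(jnp.array([0.0, 0.0])) should give 0 (up to rounding), since logsumexp returns log(2) and log(n) is also log(2). Subtracting log(n) after the reduction, instead of evaluating jnp.log(jnp.mean(jnp.exp(x))) directly, keeps inputs such as [1000.0, 1000.0] from overflowing exp().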