Sample random values from the standard normal distribution,
p(x) = 1/sqrt(2*pi) * exp(-x^2/2).
Unlike JAX, which inverts the CDF directly using the erf_inv primitive, this uses the
Box-Muller transform, because we don't have support for erf_inv yet. As a result, outputs
will not be bitwise identical to JAX.
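
For illustration, the Box-Muller transform can be sketched in plain Python as below. This is a minimal sketch using only the standard library, not the implementation described here: the names `box_muller` and `sample_normal` and the `seed` parameter are hypothetical, and Python's `random` module stands in for whatever uniform source the real code uses.

```python
import math
import random


def box_muller(u1, u2):
    """Map uniforms u1 in (0, 1] and u2 in [0, 1) to two independent N(0, 1) samples."""
    r = math.sqrt(-2.0 * math.log(u1))  # radius; u1 must be nonzero so log is finite
    theta = 2.0 * math.pi * u2          # angle, uniform on [0, 2*pi)
    return r * math.cos(theta), r * math.sin(theta)


def sample_normal(n, seed=0):
    """Draw n standard normal samples (hypothetical helper, for illustration only)."""
    rng = random.Random(seed)
    out = []
    while len(out) < n:
        u1 = 1.0 - rng.random()  # random() lies in [0, 1), so this lies in (0, 1]
        u2 = rng.random()
        out.extend(box_muller(u1, u2))
    return out[:n]  # drop the extra sample when n is odd
```

Each pair of uniforms yields two normal samples, so the transform needs only log, sqrt, sin and cos rather than an inverse error function. The cost is that the mapping from uniform bits to outputs differs from CDF inversion, which is why results cannot match JAX bit for bit.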