Compute the inner product of two arrays.
Unlike jax.numpy.matmul() or jax.numpy.dot(), this always performs a contraction on the last axis of both inputs.
For inputs x and y, the returned array has shape [...x.shape[:-1], ...y.shape[:-1]].
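Here is a minimal sketch of the shape rule and the last-axis contraction. The example arrays are arbitrary choices for illustration. For a 2-D y, jnp.inner(x, y) should match jnp.dot(x, y.T), because dot contracts the last axis of x with the second-to-last axis of its second argument. That equivalence is specific to this 2-D case and is not a general identity.

```python
import jax.numpy as jnp

# Arbitrary example inputs: both share a trailing axis of length 4.
x = jnp.arange(24.0).reshape(2, 3, 4)
y = jnp.arange(20.0).reshape(5, 4)

# inner contracts the last axis of x with the last axis of y,
# so the result shape is x.shape[:-1] + y.shape[:-1].
out = jnp.inner(x, y)
print(out.shape)  # (2, 3, 5)

# For a 2-D y, transposing y lets dot perform the same contraction.
ref = jnp.dot(x, y.T)

# With 1-D inputs, inner reduces to the ordinary vector dot product.
v = jnp.inner(jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 5.0, 6.0]))
print(float(v))  # 32.0
```

Note that jnp.matmul(x, y) would fail for these shapes. It contracts the last axis of x (length 4) with the second-to-last axis of y (length 5).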