General n-dimensional convolution operator, with optional dilation.
The semantics of this operation mimic the jax.lax.conv_general_dilated function in JAX, which wraps XLA's general convolution operator.
Input tensor; shape [N, C_in, ...xs]
Convolution kernel; shape [C_out, C_in / G, ...ks]
Strides for each spatial dimension
Padding for each spatial dimension, or a string ("VALID", "SAME", or "SAME_LOWER")
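Since the semantics follow jax.lax.conv_general_dilated, the shape conventions above can be checked against the JAX function itself. The following is a minimal sketch, not this operator's own API: the sizes N, C_in, C_out, and G and the helper conv are hypothetical, and it assumes G corresponds to JAX's feature_group_count and that the layout is NCHW/OIHW (JAX's default).

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical sizes, chosen only for illustration.
N, C_in, C_out, G = 2, 4, 6, 2

x = jnp.ones((N, C_in, 8, 8))           # input:  [N, C_in, ...xs]
w = jnp.ones((C_out, C_in // G, 3, 3))  # kernel: [C_out, C_in / G, ...ks]


def conv(padding, strides=(1, 1)):
    """Grouped 2-D convolution via the JAX reference function."""
    return lax.conv_general_dilated(
        x,
        w,
        window_strides=strides,   # one stride per spatial dimension
        padding=padding,          # a string, or (low, high) pairs per spatial dim
        feature_group_count=G,    # assumed to play the role of G above
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )


valid = conv("VALID")                    # no padding: spatial size shrinks
same = conv("SAME")                      # padded so the spatial size is kept
explicit = conv([(1, 1), (1, 1)])        # explicit (low, high) padding per dim
strided = conv("SAME", strides=(2, 2))   # stride 2 halves each spatial dim

print(valid.shape, same.shape, explicit.shape, strided.shape)
```

With a 3x3 kernel and stride 1, "SAME" pads one element on each side, so it matches the explicit [(1, 1), (1, 1)] padding in this sketch. "SAME_LOWER" is left out here; in JAX it differs from "SAME" only in which side receives the extra element when the total padding is odd.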