Input tensor; shape [N, C_in, ...xs]
Convolution kernel; shape [C_out, C_in, ...ks]
Sequence of n integers setting the fractional stride
Applies padding of dilation * (kernel_size - 1) - padding to
each side of the input, so that the operation acts like the gradient of conv()
Atrous dilation for the kernel
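The following TypeScript sketch is a worked check of the padding rule above. It computes the per-side padding and the resulting output length along one spatial axis. It assumes a single symmetric padding value per side and no output padding. effectivePadding and transposedOutputLength are hypothetical helpers, not part of this API.

```typescript
// Hypothetical helpers (not part of this library): per-side padding and output
// length of a 1-D transposed convolution, assuming a symmetric scalar padding.

/** Padding added to each side of the stride-dilated input. */
function effectivePadding(kernelSize: number, dilation: number, padding: number): number {
  // A negative result means the dilated input is cropped rather than padded.
  return dilation * (kernelSize - 1) - padding;
}

/** Output length along one spatial axis. */
function transposedOutputLength(
  inputLength: number,
  kernelSize: number,
  stride: number,
  dilation: number,
  padding: number,
): number {
  // Inserting (stride - 1) zeros between neighbours gives (n - 1) * stride + 1 samples.
  const dilatedLength = (inputLength - 1) * stride + 1;
  const pad = effectivePadding(kernelSize, dilation, padding);
  // Span covered by the dilated kernel.
  const kernelExtent = dilation * (kernelSize - 1) + 1;
  // A "valid" (unpadded) correlation over the padded, dilated input.
  return dilatedLength + 2 * pad - kernelExtent + 1;
}

// Example: length-4 input, kernel size 3, stride 2, dilation 1, padding 1.
console.log(effectivePadding(3, 1, 1)); // 1
console.log(transposedOutputLength(4, 3, 2, 1, 1)); // 7
```

Expanding the arithmetic gives (n - 1) * stride + dilation * (kernelSize - 1) - 2 * padding + 1. This matches the output size formula PyTorch documents for ConvTranspose1d with output_padding = 0.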
Convenience wrapper for calculating the N-d convolution "transpose".
This function directly calculates a fractionally strided convolution rather than indirectly calculating the gradient (transpose) of a forward convolution. It is equivalent to the JAX version, except:
- The use_consistent_padding option is not available. We only have the consistent padding case (JAX version >0.8.4).
- … lax.conv_general_dilated.

Unlike PyTorch/TensorFlow, by default we don't reverse the kernel's spatial dimensions or the (C_out, C_in) axis order. To get this behavior, set transposeKernel to true.
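The direct computation described above can be sketched for the simplest case. This sketch uses one spatial axis, one input and one output channel, and a symmetric scalar padding. The TypeScript below is illustrative only. convTranspose1d and scatterTranspose1d are hypothetical names. The sketch models only the spatial reversal that transposeKernel performs, not the (C_out, C_in) axis swap.

convTranspose1d dilates the input by the stride, pads it by dilation * (k - 1) - padding, and then correlates it with the kernel. For comparison, scatterTranspose1d computes the gradient-style result by scattering each input element through the kernel.

```typescript
// Hypothetical 1-D, single-channel sketch (not this library's implementation).

/** Direct fractionally strided convolution: dilate by stride, pad, correlate. */
function convTranspose1d(
  x: number[],
  w: number[],
  stride: number,
  padding: number,
  dilation: number,
  transposeKernel = false,
): number[] {
  const k = w.length;
  // transposeKernel reverses the kernel's spatial axis (PyTorch/TensorFlow convention).
  const kernel = transposeKernel ? w.slice().reverse() : w;
  const pad = dilation * (k - 1) - padding; // per-side padding (negative = crop)
  const dilatedLength = (x.length - 1) * stride + 1;
  const outLength = dilatedLength + 2 * pad - dilation * (k - 1);
  // Value of the padded, stride-dilated input at position t; zero between and outside samples.
  const z = (t: number): number => {
    const u = t - pad;
    if (u < 0 || u >= dilatedLength || u % stride !== 0) return 0;
    return x[u / stride];
  };
  const y: number[] = [];
  for (let o = 0; o < outLength; o++) {
    let acc = 0;
    for (let j = 0; j < k; j++) acc += z(o + j * dilation) * kernel[j];
    y.push(acc);
  }
  return y;
}

/** Reference "gradient of conv" form: scatter x[i] * w[j] to i*stride + j*dilation - padding. */
function scatterTranspose1d(
  x: number[],
  w: number[],
  stride: number,
  padding: number,
  dilation: number,
): number[] {
  const outLength = (x.length - 1) * stride + dilation * (w.length - 1) - 2 * padding + 1;
  const y = new Array<number>(Math.max(outLength, 0)).fill(0);
  for (let i = 0; i < x.length; i++) {
    for (let j = 0; j < w.length; j++) {
      const o = i * stride + j * dilation - padding;
      if (o >= 0 && o < outLength) y[o] += x[i] * w[j];
    }
  }
  return y;
}

const x = [1, 2, 3];
const w = [1, 10, 100];
console.log(JSON.stringify(convTranspose1d(x, w, 2, 0, 1, true))); // [1,10,102,20,203,30,300]
console.log(JSON.stringify(scatterTranspose1d(x, w, 2, 0, 1))); // [1,10,102,20,203,30,300]
console.log(JSON.stringify(convTranspose1d(x, w, 2, 0, 1))); // [100,10,201,20,302,30,3]
```

With transposeKernel set to true, the direct result matches the scatter (PyTorch-style) result. With the default false, it matches the scatter applied to the reversed kernel, which is the difference the flag controls.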