jax-js

    Function convTranspose

    • Convenience wrapper for calculating the N-d convolution "transpose".

      This function directly computes a fractionally strided convolution rather than indirectly computing the gradient (transpose) of a forward convolution. It is equivalent to JAX's lax.conv_transpose, except:

      • The use_consistent_padding option is not available. Only the consistent-padding behavior (JAX versions >0.8.4) is implemented.
      • The dimension order matches lax.conv_general_dilated: lhs is [N, C_in, ...xs] and rhs is [C_out, C_in, ...ks].

      Unlike PyTorch and TensorFlow, by default this function does not reverse the kernel's spatial dimensions or swap its (C_out, C_in) axes. To get that behavior, set transposeKernel to true.
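      As a rough illustration of the kernel transformation that transposeKernel: true is described as enabling, the plain TypeScript sketch below reverses the single spatial axis of a kernel and swaps its first two axes. It is not the jax-js implementation: the helper name transposeKernel1d and the nested-array layout [A][B][k] are hypothetical, and only one spatial dimension is handled.

```typescript
// Kernel with one spatial dimension, stored as nested arrays [A][B][k].
type Kernel1d = number[][][];

// Hypothetical helper: reverse the spatial axis and swap the first two axes,
// turning an [A][B][k] kernel into a [B][A][k] kernel.
// Assumes a non-empty, rectangular kernel.
function transposeKernel1d(rhs: Kernel1d): Kernel1d {
  const a = rhs.length;
  const b = rhs[0].length;
  const k = rhs[0][0].length;
  const out: Kernel1d = [];
  for (let j = 0; j < b; j++) {
    const row: number[][] = [];
    for (let i = 0; i < a; i++) {
      const taps: number[] = [];
      // Read the taps back to front to reverse the spatial axis.
      for (let t = 0; t < k; t++) taps.push(rhs[i][j][k - 1 - t]);
      row.push(taps);
    }
    out.push(row);
  }
  return out;
}

// A [1][2][3] kernel becomes a [2][1][3] kernel with reversed taps.
console.log(JSON.stringify(transposeKernel1d([[[1, 2, 3], [4, 5, 6]]])));
// prints [[[3,2,1]],[[6,5,4]]]
```

      Both steps are their own inverse, so applying the sketch twice returns the original kernel.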

      Parameters

      • lhs: Array

        Input tensor; shape [N, C_in, ...xs]

      • rhs: Array

        Convolution kernel; shape [C_out, C_in, ...ks]

      • strides: number[]

        Sequence of n integers, one per spatial dimension, setting the fractional stride
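        A fractional stride s can be pictured as inserting s - 1 zeros between neighbouring input elements and then running an ordinary stride-1 convolution. The minimal 1-D sketch below shows only that zero-insertion step; it is plain TypeScript, not jax-js code, and the helper name insertZeros1d is hypothetical.

```typescript
// Hypothetical helper: insert (stride - 1) zeros between neighbouring
// elements of a 1-D signal, i.e. the input dilation behind a fractional stride.
function insertZeros1d(x: number[], stride: number): number[] {
  if (x.length === 0) return [];
  // Output length is (n - 1) * stride + 1.
  const out = new Array<number>((x.length - 1) * stride + 1).fill(0);
  x.forEach((v, i) => {
    out[i * stride] = v;
  });
  return out;
}

console.log(insertZeros1d([1, 2, 3], 2)); // returns [1, 0, 2, 0, 3]
```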

      • padding: PaddingType

        Applies padding of dilation * (kernel_size - 1) - padding to each side of the input, so that the result acts like the gradient of conv()
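        Putting the stride, padding and kernel dilation together, the sketch below computes a 1-D fractionally strided convolution: insert stride - 1 zeros between inputs, pad each side by d * (k - 1) - p, then run a "VALID" cross-correlation with the dilated kernel, which is not flipped (matching the default transposeKernel behavior). It is a plain TypeScript illustration under stated assumptions, not the jax-js implementation: convTranspose1d is a hypothetical name, there is a single channel and one spatial dimension, and padding is a symmetric integer p rather than a PaddingType. Under these assumptions an input of length n yields (n - 1) * s + d * (k - 1) + 1 - 2p outputs.

```typescript
// Hypothetical 1-D, single-channel sketch of a fractionally strided convolution.
// Assumes symmetric integer padding p with 0 <= p <= d * (k - 1).
function convTranspose1d(
  x: number[],
  w: number[],
  stride: number,
  p: number,
  d = 1,
): number[] {
  if (x.length === 0) return [];
  const k = w.length;
  // 1. Fractional stride: insert (stride - 1) zeros between input elements.
  const dilated = new Array<number>((x.length - 1) * stride + 1).fill(0);
  x.forEach((v, i) => {
    dilated[i * stride] = v;
  });
  // 2. Pad each side by d * (k - 1) - p, mirroring the gradient of a forward conv.
  const pad = d * (k - 1) - p;
  if (pad < 0) throw new Error("sketch assumes p <= d * (k - 1)");
  const zeros = new Array<number>(pad).fill(0);
  const padded = [...zeros, ...dilated, ...zeros];
  // 3. "VALID" cross-correlation with the unflipped, dilated kernel.
  const span = d * (k - 1) + 1; // effective kernel extent
  const out: number[] = [];
  for (let i = 0; i + span <= padded.length; i++) {
    let acc = 0;
    for (let t = 0; t < k; t++) acc += w[t] * padded[i + d * t];
    out.push(acc);
  }
  return out;
}

console.log(convTranspose1d([1, 2], [1, 1, 1], 2, 0)); // returns [1, 1, 3, 2, 2]
```

        With a symmetric kernel the result equals the transpose of a stride-2 forward convolution; with an asymmetric kernel such as [1, 2, 3] and a single input 1, the output is [3, 2, 1], showing why transposeKernel exists to reverse the taps.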

      • rhsDilation: { rhsDilation?: number[]; transposeKernel?: boolean } = {}

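        Options object. rhsDilation sets the atrous dilation for the kernel. transposeKernel (default false) reverses the kernel's spatial dimensions and swaps its (C_out, C_in) axes, as PyTorch and TensorFlow do.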
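        Atrous dilation d spreads a kernel's k taps over an effective extent of d * (k - 1) + 1 input positions. The sketch below shows this by indexing directly instead of materializing the zeros; it is plain TypeScript rather than jax-js code, atrousCorrelate1d is a hypothetical name, and it assumes one channel, one spatial dimension, and an unpadded ("VALID") cross-correlation.

```typescript
// Hypothetical helper: "VALID" 1-D cross-correlation with a dilated kernel.
// Tap t reads the input d * t positions to the right of the window start.
function atrousCorrelate1d(x: number[], w: number[], d: number): number[] {
  const span = d * (w.length - 1) + 1; // effective kernel extent
  const out: number[] = [];
  for (let i = 0; i + span <= x.length; i++) {
    let acc = 0;
    for (let t = 0; t < w.length; t++) acc += w[t] * x[i + d * t];
    out.push(acc);
  }
  return out;
}

console.log(atrousCorrelate1d([1, 2, 3, 4, 5], [1, 1], 2)); // returns [4, 6, 8]
```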

      Returns Array