What’s the equivalent of PyTorch InstanceNorm2d(channels)? I use BatchNormalization(axis=1), and applying it to input of shape (n_test, image_channels, image_size, image_size) doesn’t seem to transform anything at all.
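Hi @khteh, BatchNormalization normalizes with a mean and variance computed across the entire batch, whereas InstanceNormalization computes the mean and variance over the spatial dimensions of each channel independently, separately for every sample. The two layers therefore behave differently. To get the equivalent of PyTorch InstanceNorm2d(channels), consider GroupNormalization. GroupNormalization divides the channels into groups and computes the mean and variance within each group for normalization. If you set the number of groups equal to the number of channels (groups=image_channels, with axis=1 for your channels-first input), each channel forms its own group, which is exactly instance normalization.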
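A minimal NumPy sketch of the difference, assuming channels-first (N, C, H, W) input like yours. The helper names `instance_norm_nchw` and `batch_norm_train_nchw` and the random test data are hypothetical; `eps=1e-5` and the biased variance follow PyTorch's InstanceNorm2d defaults (which also has `affine=False`). In Keras, the matching layer would be something like `keras.layers.GroupNormalization(groups=image_channels, axis=1, epsilon=1e-5, center=False, scale=False)`, and its output can be checked against `instance_norm_nchw` on the same array.

```python
import numpy as np


def instance_norm_nchw(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Normalize each (sample, channel) slice over its spatial dims (H, W).

    Hypothetical helper mirroring the math of torch.nn.InstanceNorm2d with
    its defaults (affine=False, biased variance, eps=1e-5).
    """
    mean = x.mean(axis=(2, 3), keepdims=True)  # shape (N, C, 1, 1)
    var = x.var(axis=(2, 3), keepdims=True)    # biased variance, shape (N, C, 1, 1)
    return (x - mean) / np.sqrt(var + eps)


def batch_norm_train_nchw(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Hypothetical helper: batch norm with training-mode statistics.

    Each channel gets ONE mean/variance, shared by all samples,
    computed over the batch and spatial dims (N, H, W).
    """
    mean = x.mean(axis=(0, 2, 3), keepdims=True)  # shape (1, C, 1, 1)
    var = x.var(axis=(0, 2, 3), keepdims=True)    # shape (1, C, 1, 1)
    return (x - mean) / np.sqrt(var + eps)


# Hypothetical input: 4 images, 3 channels, 8x8 pixels.
# Each sample gets its own offset (0, 10, 20, 30) and scale (1, 2, 3, 4).
rng = np.random.default_rng(0)
n, c, h, w = 4, 3, 8, 8
offsets = np.arange(n, dtype=float).reshape(n, 1, 1, 1) * 10.0
x = rng.normal(size=(n, c, h, w)) * (1.0 + offsets / 10.0) + offsets

inst = instance_norm_nchw(x)
bn = batch_norm_train_nchw(x)

# Instance norm: every (sample, channel) slice has mean ~0 and variance ~1.
print("instance norm, max |slice mean|:", np.abs(inst.mean(axis=(2, 3))).max())
print("instance norm, max |slice var - 1|:", np.abs(inst.var(axis=(2, 3)) - 1.0).max())

# Batch norm: only the batch-wide statistics are 0/1; per-sample means differ.
print("batch norm, per-sample means of channel 0:", bn.mean(axis=(2, 3))[:, 0])
```

The per-sample offsets are chosen so the difference is visible: after batch normalization, sample 0 keeps a clearly negative mean and sample 3 a clearly positive one, while instance normalization brings every sample and channel to zero mean and unit variance on its own.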