Can anyone help with the TensorFlow implementation of the VICReg loss terms? Thank you in advance.
I have been able to implement them. For anyone interested, here it is:
def off_diagonal(x):
    """Return the off-diagonal elements of a square matrix as a flat tensor.

    Args:
        x: A square 2-D tensor of shape (n, n).

    Returns:
        A 1-D tensor with the n * (n - 1) off-diagonal entries of ``x``.

    Raises:
        ValueError: If ``x`` is not square.
    """
    n, m = x.shape[0], x.shape[1]
    # An assert would be stripped under `python -O`; raise explicitly instead.
    if n != m:
        raise ValueError(f"Not a square tensor, dimensions found: {n} and {m}")
    # Drop the final element, then view the remaining n*n - 1 values as
    # (n - 1) rows of (n + 1) entries: each row then begins with a diagonal
    # element, so slicing off the first column removes every diagonal entry.
    flattened_tensor = tf.reshape(x, [-1])[:-1]
    elements = tf.reshape(flattened_tensor, [n - 1, n + 1])[:, 1:]
    return tf.reshape(elements, [-1])
def invariance_loss(z_a, z_b):
    """VICReg invariance term: scalar mean squared error between the two
    branch embeddings.

    Fix: the original used ``tf.keras.metrics.mean_squared_error``, which
    averages only over the last axis and returns a per-sample vector rather
    than the single scalar a loss term should be. Averaging the squared
    difference over every element matches the reference (PyTorch
    ``F.mse_loss``) behavior.

    Args:
        z_a: Embeddings from the first branch, shape (N, D).
        z_b: Embeddings from the second branch, shape (N, D).

    Returns:
        A scalar tensor: mean of (z_a - z_b) ** 2 over all elements.
    """
    return tf.math.reduce_mean(tf.math.square(z_a - z_b))
def variance_loss(z_a, z_b):
    """VICReg variance term: a hinge loss encouraging the per-dimension
    standard deviation of each branch's embeddings (over the batch) to stay
    above 1, averaged across the two branches.

    Args:
        z_a: Embeddings from the first branch, shape (N, D).
        z_b: Embeddings from the second branch, shape (N, D).

    Returns:
        A scalar tensor: 0.5 * (hinge(z_a) + hinge(z_b)).
    """
    eps = 1e-4  # keeps sqrt differentiable when a dimension has ~zero variance

    def hinge_on_std(z):
        # Std-dev of each embedding dimension across the batch, then penalize
        # only dimensions whose std has collapsed below 1.
        std = tf.math.sqrt(tf.math.reduce_variance(z, axis=0) + eps)
        return tf.math.reduce_mean(tf.nn.relu(1 - std))

    return (hinge_on_std(z_a) + hinge_on_std(z_b)) * 0.5
def covariance_loss(z_a, z_b):
    """VICReg covariance term: penalizes off-diagonal entries of each
    branch's embedding covariance matrix, pushing embedding dimensions
    toward decorrelation.

    Fix: the original read the batch size ``N`` and embedding dimension
    ``D`` from undefined module globals; they are now derived from the
    input tensor's shape.

    Args:
        z_a: Embeddings from the first branch, shape (N, D).
        z_b: Embeddings from the second branch, shape (N, D).

    Returns:
        A scalar tensor: sum over both branches of
        sum(off_diagonal(cov) ** 2) / D.
    """
    # Batch size as a float of the embedding dtype (for the 1/(N-1) factor)
    # and the static embedding dimension.
    n = tf.cast(tf.shape(z_a)[0], z_a.dtype)
    d = z_a.shape[1]
    # Center each dimension over the batch before forming the covariance.
    z_a = z_a - tf.math.reduce_mean(z_a, axis=0)
    z_b = z_b - tf.math.reduce_mean(z_b, axis=0)
    # Unbiased (N - 1) sample covariance, shape (D, D).
    cov_z_a = tf.linalg.matmul(z_a, z_a, transpose_a=True) / (n - 1)
    cov_z_b = tf.linalg.matmul(z_b, z_b, transpose_a=True) / (n - 1)
    # Sum of squared off-diagonal covariances, scaled by the dimension D.
    cov_loss_z_a = tf.math.divide(
        tf.math.reduce_sum(tf.math.pow(off_diagonal(cov_z_a), 2)), d
    )
    cov_loss_z_b = tf.math.divide(
        tf.math.reduce_sum(tf.math.pow(off_diagonal(cov_z_b), 2)), d
    )
    return cov_loss_z_a + cov_loss_z_b
where N and D are the batch size and the embedding dimension, respectively.
1 Like