I am trying to create a custom layer which computes W*x + b, where W is a sparse tensor. It is important that I never form the dense version of W, because it would be too large to store in memory. My understanding is that computing W*x with tf.sparse.sparse_dense_matmul(W, x) does not have a supported gradient with respect to the values of W.
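For context, the forward pass I want is just the following (a minimal sketch with toy shapes; indices, values, and dense_shape play the same roles as self.indices, w, and self.shape in my real layer):

import tensorflow as tf

# Toy stand-ins for my real layer (illustrative shapes only):
# W is 3x2 with 3 nonzeros; x holds 2 input examples as columns.
indices = tf.constant([[0, 0], [1, 1], [2, 0]], dtype=tf.int64)  # like self.indices
w = tf.constant([1.0, 2.0, 3.0])                                 # like my weight vector w
W = tf.sparse.SparseTensor(indices, w, dense_shape=(3, 2))       # like self.shape
x = tf.constant([[1.0, 5.0],
                 [4.0, 6.0]])                                    # like self.inputs

y = tf.sparse.sparse_dense_matmul(W, x)  # shape (3, 2): one output column per example
print(y)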
To make this work, I am trying to implement a custom gradient using the following code:
import numpy as np
import tensorflow as tf

@tf.custom_gradient
def sparse_weight_multiply(self, w):
    # compute the product sparse_W * inputs, where sparse_W is a sparse
    # tensor formed from the entries in w
    self.sparse_W = tf.sparse.SparseTensor(self.indices, w, self.shape)
    w_inputs = tf.sparse.sparse_dense_matmul(self.sparse_W, self.inputs)

    # define gradient for this function
    def sparse_weight_grad(upstream_grad):
        '''
        upstream_grad is the gradient accumulated so far in the computational
        graph. The output of this function is the gradient of
        sparse_weight_multiply times upstream_grad, by the chain rule.
        '''
        # check the shapes of the tensors involved
        print("Shape of upstream grad: {}".format(upstream_grad.shape))
        print("Shape of inputs: {}".format(self.inputs.shape))
        print("Shape of weight: {}".format(self.shape))
        # map entries of the input to the corresponding locations in the
        # gradient of weights*inputs
        n_out = upstream_grad.shape[0]
        num_RHS = upstream_grad.shape[1]
        indices_i = range(0, self.num_connections)
        indices_j = self.indices[:, 0]
        indices_k = self.indices[:, 1]
        # Each nonzero w[i] sits at position (indices_j[i], indices_k[i]) of W,
        # so d(W x)[j, l] / d w[i] = x[indices_k[i], l] whenever j == indices_j[i].
        # sparse_J encodes this Jacobian for column l of the inputs.
        J_indices = tf.cast(tf.transpose(tf.concat([[indices_i], [indices_j]], 0)), tf.int64)
        grad_weights = []
        for l in range(num_RHS):
            input_permuted = np.array(self.inputs)[:, l][indices_k]
            sparse_J = tf.sparse.SparseTensor(J_indices, input_permuted, (self.num_connections, n_out))
            grad_weights.append(tf.sparse.sparse_dense_matmul(sparse_J, tf.reshape(upstream_grad[:, l], (n_out, 1))))
        grad_weights = tf.transpose(tf.squeeze(tf.convert_to_tensor(grad_weights)))
        return grad_weights

    return w_inputs, sparse_weight_grad
However, I get the error:
tensorflow.python.framework.errors_impl.InvalidArgumentError: var and grad do not have the same shape[9632] [9632,638] [Op:ResourceApplyAdam]
I believe this is because my input w is a tensor of shape [9632]. However, I want to compute the gradient of W*x for each input x to the layer, of which I have 638 in my training set. Thus, the gradient I return has shape [9632, 638], corresponding to a gradient of shape [9632] for each input. This is consistent with the upstream gradient I am given, which has shape (3836, 638), i.e. one column per training example. I definitely want to pass a gradient for each input, but how do I tell TensorFlow that that is what I am doing?
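For what it is worth, my reading of the error is that ResourceApplyAdam requires the gradient to have exactly the same shape as the variable. Here is a minimal sketch of the shape bookkeeping (the numbers mirror mine; the reduce_sum is only my guess at what TensorFlow may want, not something I have confirmed):

import tensorflow as tf

w = tf.Variable(tf.ones(9632))       # my weight vector
grad_weights = tf.ones((9632, 638))  # one gradient column per training example
# Adam rejects grad_weights because its shape differs from w's shape.
# Summing over the batch axis makes the shapes agree, and by linearity the
# gradient of a summed batch loss is the sum of the per-example gradients:
total_grad = tf.reduce_sum(grad_weights, axis=1)
print(total_grad.shape)              # (9632,), matching w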