Hi
I have a pruned TF model that I need to retrain on streaming data. I want to retrain only the non-zero weights of the 80%-pruned model. I'd like to avoid creating a mask and performing extra computation, since I want to keep retraining time to a minimum. Here's the training function I'm currently using:
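```python
import tensorflow as tf

# `model`, `loss_fn` and `optimizer` are defined elsewhere.
@tf.function
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss_value = loss_fn(targets, predictions)
    # Gradients and updates currently cover every trainable weight, pruned or not.
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value
```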
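For comparison with the step above, a variant that writes updates only into the surviving entries can precompute the non-zero coordinates once and apply a plain SGD update with `tf.Variable.scatter_nd_sub`. This is a minimal sketch, assuming plain SGD (no momentum or Adam slots), a hypothetical learning rate `LR`, and a tiny stand-in model made of one raw `tf.Variable` with no bias. Gradients are still computed densely; only the write is sparse.

```python
import tensorflow as tf

# Hypothetical stand-in for the pruned model: one 8x4 kernel, ~80% zeros, no bias.
tf.random.set_seed(0)
w_init = tf.random.normal([8, 4])
keep = tf.cast(tf.random.uniform([8, 4]) < 0.2, tf.float32)
kernel = tf.Variable(w_init * keep)
trainable_weights = [kernel]


def model(x):
    return tf.matmul(x, kernel)


LR = 0.01  # assumed plain-SGD learning rate

# Computed once, outside the training loop: coordinates of the non-zero entries.
nonzero_idx = [tf.where(tf.not_equal(w, 0.0)) for w in trainable_weights]


@tf.function
def sparse_train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        loss_value = tf.reduce_mean(tf.square(targets - predictions))
    # Gradients are still computed densely for every weight.
    grads = tape.gradient(loss_value, trainable_weights)
    for w, g, idx in zip(trainable_weights, grads, nonzero_idx):
        # Gather the gradient at the surviving coordinates and subtract the
        # SGD step from those coordinates only; pruned entries are never written.
        w.scatter_nd_sub(idx, LR * tf.gather_nd(g, idx))
    return loss_value
```

Because the indices are fixed, pruned entries stay exactly zero without a per-step mask multiply. The trade-off is that stateful optimizers such as Adam would need their own sparse bookkeeping for their slot variables.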
Is there a way to compute gradients only for the non-zero weights, or to include only the non-zero weights in `model.trainable_weights`? If not, is there a way to use `tf.IndexedSlices` to update just the non-zero weights efficiently?
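To get only the surviving weights into the trainable set, one option is to store them as a compact 1-D variable together with a fixed index tensor, and rebuild the dense kernel with `tf.scatter_nd` in the forward pass. The gradient then has one entry per non-zero weight, and the update touches nothing else. This is a minimal sketch, assuming a single pruned matmul kernel; `CompactSparseKernel` is a hypothetical helper (not a TensorFlow or Keras API), and a plain SGD `assign_sub` with an assumed `LR` stands in for the optimizer.

```python
import tensorflow as tf


class CompactSparseKernel(tf.Module):
    """Hypothetical helper (not a TensorFlow API): keeps only the non-zero
    entries of a pruned kernel as a 1-D trainable variable."""

    def __init__(self, dense_kernel, name="compact_sparse_kernel"):
        super().__init__(name=name)
        dense_kernel = tf.convert_to_tensor(dense_kernel)
        # Fixed int64 coordinates of the surviving weights (not trainable).
        self.indices = tf.where(tf.not_equal(dense_kernel, 0.0))
        self.dense_shape = tf.shape(dense_kernel, out_type=tf.int64)
        # The only trainable state: one value per surviving weight.
        self.values = tf.Variable(tf.gather_nd(dense_kernel, self.indices))

    def dense(self):
        # Rebuild the full kernel; pruned positions are structurally zero.
        return tf.scatter_nd(self.indices, self.values, self.dense_shape)


# Hypothetical 80%-pruned 8x4 kernel standing in for a real layer's weights.
tf.random.set_seed(0)
keep = tf.cast(tf.random.uniform([8, 4]) < 0.2, tf.float32)
pruned = tf.random.normal([8, 4]) * keep
layer = CompactSparseKernel(pruned)
LR = 0.01  # assumed plain-SGD learning rate


@tf.function
def compact_train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = tf.matmul(inputs, layer.dense())
        loss_value = tf.reduce_mean(tf.square(targets - predictions))
    # Only `layer.values` is trainable, so each gradient has one entry per
    # surviving weight and the update writes nothing else.
    grads = tape.gradient(loss_value, layer.trainable_variables)
    for var, grad in zip(layer.trainable_variables, grads):
        var.assign_sub(LR * grad)
    return loss_value
```

Two caveats with this design. The forward pass still materializes the dense kernel, and backprop forms the dense kernel gradient internally before gathering it, so the savings are mainly in optimizer state and update size rather than matmul cost. Also, `tf.IndexedSlices` represents whole slices along a variable's first dimension (rows), so it fits row-sparse updates such as embedding lookups better than element-wise pruning patterns.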
I'd highly appreciate any help with this.