Hi all,
I have a trained LSTM model that performs binary classification of waveform data. The waveforms have different lengths, so they are stored in ragged tensors.
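For reference, `tf.RaggedTensor` stores a batch of variable-length rows as one concatenated `flat_values` tensor plus `row_splits` offsets, where row i is `flat_values[row_splits[i]:row_splits[i+1]]`. Anything computed per time step on the flat values (a gradient, for instance) can therefore be split back into individual waveforms using the same offsets. Below is a minimal pure-Python sketch of that layout. The helper names are hypothetical and the waveform values are made up for illustration.

```python
from itertools import accumulate


def row_splits_from_lengths(lengths):
    """Cumulative offsets in the same form as tf.RaggedTensor.row_splits."""
    return [0] + list(accumulate(lengths))


def flatten(waveforms):
    """Concatenate variable-length waveforms into flat values plus row splits."""
    flat_values = [x for wave in waveforms for x in wave]
    return flat_values, row_splits_from_lengths(len(w) for w in waveforms)


def unflatten(flat_values, row_splits):
    """Recover row i as flat_values[row_splits[i]:row_splits[i + 1]]."""
    return [flat_values[a:b] for a, b in zip(row_splits[:-1], row_splits[1:])]


# Three toy waveforms of lengths 3, 2 and 4.
waves = [[0.1, 0.5, -0.2], [0.3, 0.0], [1.0, 0.7, 0.2, -0.4]]
flat, splits = flatten(waves)
print(splits)  # [0, 3, 5, 9]
```

In TensorFlow the equivalent objects come from `tf.ragged.constant(waves)` via its `.flat_values` and `.row_splits` attributes.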
I want to visualise which elements of a given waveform the model considers most important. Ideally, the end result would be a plot of the waveform with each data point coloured according to its importance.
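Once there is one saliency value per time step (for example, the absolute gradient at each data point), turning it into colours is mostly a normalisation step. Below is a minimal stdlib sketch. It assumes one saliency value per time step, so a multi-feature gradient would need to be reduced first, for example by taking the largest absolute value across features. `importance_weights` and `to_rgb` are hypothetical helpers, and the blue-to-red ramp is an arbitrary choice.

```python
def importance_weights(saliency):
    """Scale |saliency| to [0, 1] within one waveform so it can drive a colour map."""
    mags = [abs(s) for s in saliency]
    lo, hi = min(mags), max(mags)
    if hi == lo:  # constant saliency: avoid dividing by zero
        return [0.0 for _ in mags]
    return [(m - lo) / (hi - lo) for m in mags]


def to_rgb(weight):
    """Hypothetical linear ramp: 0 -> blue (least important), 1 -> red (most)."""
    return (weight, 0.0, 1.0 - weight)


saliency = [0.0, -2.0, 1.0]  # toy gradient values for a 3-sample waveform
weights = importance_weights(saliency)
colours = [to_rgb(w) for w in weights]
print(weights)  # [0.0, 1.0, 0.5]
```

If matplotlib is available, the weights can go straight into a colour map instead of `to_rgb`, e.g. `plt.scatter(range(len(wave)), wave, c=weights, cmap="viridis")`. Normalising per waveform keeps the colours comparable within a plot. It does not make them comparable across waveforms.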
I came across a saliency model that can output gradients, but I can't get it working with ragged tensors.
I was wondering if anyone has ideas on how best to achieve this. I've pasted my current code below.
Thanks in advance!
# Define a Saliency Model
import tensorflow as tf


class SaliencyModel(tf.keras.Model):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def call(self, inputs):
        # `inputs` is a tf.RaggedTensor of shape (batch, None, features).
        # The gradient has to be taken w.r.t. a tensor the forward pass actually
        # uses; a separate weight that never enters the model yields None.
        flat = inputs.flat_values
        with tf.GradientTape() as tape:
            tape.watch(flat)
            output = self.base_model(inputs.with_flat_values(flat))
        # Each waveform's output depends only on its own values, so the
        # batch-summed gradient is still a per-waveform gradient.
        gradients = tape.gradient(output, flat)
        # Put the gradients back into the same ragged layout as the input.
        return inputs.with_flat_values(gradients)


saliency_model = SaliencyModel(base_model=model)
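You can sanity-check whatever gradients the saliency model returns with finite differences. Perturb one time step at a time and measure how much the model output changes. This is a different technique from `tf.GradientTape`. It needs two forward passes per time step, so it is only practical for spot checks on short waveforms. In the sketch below, `toy_model` is a hypothetical stand-in for the trained classifier: a sigmoid of a weighted sum, chosen because its exact gradient is known. In practice, `f` would wrap a call to the real model on a single waveform.

```python
import math


def toy_model(wave, weights):
    """Hypothetical stand-in classifier: sigmoid of a weighted sum of the samples."""
    z = sum(w * x for w, x in zip(weights, wave))
    return 1.0 / (1.0 + math.exp(-z))


def finite_difference_saliency(f, wave, eps=1e-5):
    """Approximate d f / d wave[t] for every time step using central differences."""
    grads = []
    for t in range(len(wave)):
        up, down = list(wave), list(wave)
        up[t] += eps
        down[t] -= eps
        grads.append((f(up) - f(down)) / (2 * eps))
    return grads


wave = [0.2, -0.5, 1.0]
weights = [0.5, 2.0, -1.0]
approx = finite_difference_saliency(lambda w: toy_model(w, weights), wave)

# Analytic gradient of the sigmoid model, for comparison: p * (1 - p) * w_t.
p = toy_model(wave, weights)
exact = [p * (1 - p) * w for w in weights]
```

If the autodiff gradients and the finite-difference estimates agree on a few waveforms, the ragged plumbing is probably correct.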