On-device training with TFLite

I have been looking at the On-Device Training with TensorFlow Lite guide.

It gives an example of a very simple model.

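```python
import tensorflow as tf

IMG_SIZE = 28  # image size used in the tutorial (Fashion MNIST, 28x28)

class Model(tf.Module):

  def __init__(self):
    self.model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(IMG_SIZE, IMG_SIZE), name='flatten'),
        tf.keras.layers.Dense(128, activation='relu', name='dense_1'),
        tf.keras.layers.Dense(10, name='dense_2')
    ])

    self.model.compile(
        optimizer='sgd',
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True))

  # The `train` function takes a batch of input images and labels.
  @tf.function(input_signature=[
      tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
      tf.TensorSpec([None, 10], tf.float32),
  ])
  def train(self, x, y):
    with tf.GradientTape() as tape:
      prediction = self.model(x)
      loss = self.model.loss(y, prediction)
    gradients = tape.gradient(loss, self.model.trainable_variables)
    self.model.optimizer.apply_gradients(
        zip(gradients, self.model.trainable_variables))
    return {"loss": loss}
```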

I have been thinking I could try to implement it with the excellent kws_streaming framework (google-research/kws_streaming in the google-research/google-research repository on GitHub), as it's such a good framework for KWS.
What I want to ask about is the signature: is it basically just the input and output layers, or does it look like that only because the model above is very simple?

Also, while I'm here: if anyone else is using the kws_streaming framework or on-device training, please say hi or link a GitHub repo.

But yeah, does anyone have any info about the signature requirements for more complex models? Is it just inputs and outputs, and if the model contains a softmax, should that be the output?

Hi @rolyan_trauts,

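In general, a signature in TFLite is a named entry point into the model (for example `train` or `infer`) that defines the names, data types, and shapes of its input and output tensors, so that the data you feed the model on an edge device matches what it expects. A model can have several signatures, each with its own inputs and outputs. Here is the reference documentation for signatures in TFLite.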
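To see what a signature contains in practice, here is a minimal sketch, assuming TensorFlow 2.x with `tf.lite` available. The `TinyClassifier` module, its 4-feature input and its 3 classes are hypothetical. It exports a single `infer` signature, converts it to TFLite, and lists the signature's inputs (taken from `input_signature`) and outputs (taken from the keys of the returned dict).

```python
import tempfile

import numpy as np
import tensorflow as tf


class TinyClassifier(tf.Module):
  """Hypothetical toy model: 4 input features -> 3 class probabilities."""

  def __init__(self):
    super().__init__()
    self.w = tf.Variable(tf.random.normal([4, 3]), name="w")
    self.b = tf.Variable(tf.zeros([3]), name="b")

  # input_signature fixes each argument's dtype and shape; None = any batch size.
  @tf.function(input_signature=[tf.TensorSpec([None, 4], tf.float32, name="x")])
  def infer(self, x):
    # The keys of the returned dict become the signature's output names.
    return {"probs": tf.nn.softmax(tf.matmul(x, self.w) + self.b)}


module = TinyClassifier()
with tempfile.TemporaryDirectory() as saved_model_dir:
  # Export the tf.function as a named SavedModel signature, then convert it.
  tf.saved_model.save(
      module, saved_model_dir,
      signatures={"infer": module.infer.get_concrete_function()})
  converter = tf.lite.TFLiteConverter.from_saved_model(
      saved_model_dir, signature_keys=["infer"])
  tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)  # input and output names for each signature key

# Call the signature by key with named inputs; outputs come back as a dict.
runner = interpreter.get_signature_runner("infer")
result = runner(x=np.ones((2, 4), dtype=np.float32))
print(result["probs"].shape)
```

A `train` signature is exported the same way, as another entry in the `signatures` dict; the on-device training guide additionally sets `converter.experimental_enable_resource_variables = True` and enables Select TF ops so the weights can be updated inside the TFLite model.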

If you wrap a model from the kws_streaming framework for on-device training, the function signatures define the expected inputs and outputs of your keyword spotting (KWS) model in the same way. Here is an example signature definition for a KWS model with multiple inputs.

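```python
import tensorflow as tf

# Hypothetical placeholder shapes; replace with your model's actual values.
dim1, dim2, comp1, label_count = 49, 40, 16, 12

# Sample signature for a KWS model with two inputs; in a training function the
# last spec would be the one-hot labels (input_signature describes inputs only).
@tf.function(input_signature=[
    tf.TensorSpec([None, dim1, dim2], tf.float32),
    tf.TensorSpec([None, comp1], tf.float32),
    tf.TensorSpec([None, label_count], tf.float32),
])
def train(features, extra_input, labels):
  pass  # forward pass, loss computation and weight update go here
```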
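A minimal runnable sketch of the multi-input case, under stated assumptions: the shapes `DIM1`, `DIM2`, `COMP1` and `LABEL_COUNT` are hypothetical, and a toy linear classifier (`MultiInputKWS`, also hypothetical) stands in for a real kws_streaming model. It separates a training signature (both model inputs plus one-hot labels) from an inference signature (model inputs only), mirroring the `train`/`infer` split in the on-device training guide.

```python
import tensorflow as tf

# Hypothetical shapes for illustration only; they are not taken from kws_streaming.
DIM1, DIM2, COMP1, LABEL_COUNT = 49, 40, 16, 12


class MultiInputKWS(tf.Module):
  """Toy linear classifier with two inputs (not a real KWS architecture)."""

  def __init__(self, learning_rate=0.001):
    super().__init__()
    self.learning_rate = learning_rate
    self.w_features = tf.Variable(tf.zeros([DIM1 * DIM2, LABEL_COUNT]))
    self.w_extra = tf.Variable(tf.zeros([COMP1, LABEL_COUNT]))
    self.bias = tf.Variable(tf.zeros([LABEL_COUNT]))

  def _logits(self, features, extra):
    flat = tf.reshape(features, [-1, DIM1 * DIM2])
    return (tf.matmul(flat, self.w_features)
            + tf.matmul(extra, self.w_extra) + self.bias)

  # Training signature: both model inputs plus the one-hot labels.
  @tf.function(input_signature=[
      tf.TensorSpec([None, DIM1, DIM2], tf.float32),
      tf.TensorSpec([None, COMP1], tf.float32),
      tf.TensorSpec([None, LABEL_COUNT], tf.float32),
  ])
  def train(self, features, extra, labels):
    variables = [self.w_features, self.w_extra, self.bias]
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
          labels=labels, logits=self._logits(features, extra)))
    grads = tape.gradient(loss, variables)
    for var, grad in zip(variables, grads):
      var.assign_sub(self.learning_rate * grad)  # plain SGD step
    return {"loss": loss}

  # Inference signature: only the model inputs; the output is softmax probabilities.
  @tf.function(input_signature=[
      tf.TensorSpec([None, DIM1, DIM2], tf.float32),
      tf.TensorSpec([None, COMP1], tf.float32),
  ])
  def infer(self, features, extra):
    return {"output": tf.nn.softmax(self._logits(features, extra))}


model = MultiInputKWS()
features = tf.random.normal([8, DIM1, DIM2])
extra = tf.random.normal([8, COMP1])
labels = tf.one_hot(tf.range(8) % LABEL_COUNT, LABEL_COUNT)

first_loss = float(model.train(features, extra, labels)["loss"])
for _ in range(20):
  last_loss = float(model.train(features, extra, labels)["loss"])
probs = model.infer(features, extra)["output"]
print(first_loss, last_loss, probs.shape)
```

Plain SGD on raw `tf.Variable`s is used instead of a Keras optimizer only to keep the sketch short and self-contained; the point is that the labels appear only in the training signature, and the model outputs appear only as return values.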

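If your model ends with a softmax layer, the output returned by the inference signature will be class probabilities, so that is what the signature's output should represent. Note that `input_signature` only declares the function's arguments; the outputs are whatever the decorated function returns. Also make sure the training loss matches the model: with a softmax at the end of the model, use `CategoricalCrossentropy(from_logits=False)` rather than `from_logits=True`.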
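A quick numerical check of the softmax point, using made-up logits and labels: `CategoricalCrossentropy(from_logits=True)` on logits gives the same loss as `from_logits=False` on softmax probabilities, whereas feeding probabilities into a `from_logits=True` loss applies softmax twice and yields a different (here larger) loss.

```python
import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])  # made-up raw scores
labels = tf.constant([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])   # one-hot targets
probs = tf.nn.softmax(logits)  # what a model ending in softmax would output

# Model without a final softmax outputs logits: let the loss apply softmax.
loss_logits = tf.keras.losses.CategoricalCrossentropy(from_logits=True)(labels, logits)
# Model with a final softmax outputs probabilities: the loss must not apply it again.
loss_probs = tf.keras.losses.CategoricalCrossentropy(from_logits=False)(labels, probs)
# Mismatch: probabilities fed to a from_logits=True loss get softmax applied twice.
loss_mismatch = tf.keras.losses.CategoricalCrossentropy(from_logits=True)(labels, probs)

print(float(loss_logits), float(loss_probs), float(loss_mismatch))
```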

Thank you