Intermittent shape error in custom layer

I have design custom layer that is presenting a shape mismatch eventually. It run for many batches but then a mismatch shape error. It looks like a dynamic shape error, but I’m failing in fix it.

Here is my custom layer for which the input shape is (299,):

import tensorflow as tf

class LinePlotMatrixLayer(tf.keras.layers.Layer):
    def __init__(self, num_lines=299, **kwargs):
        super(LinePlotMatrixLayer, self).__init__(**kwargs)
        self.num_lines = num_lines

    def build(self, input_shape):
        super(LinePlotMatrixLayer, self).build(input_shape)

    def call(self, inputs):
        inputs = tf.squeeze(inputs, axis=-1)  # Shape [batch_size, seq_length]

        # Normalize inputs
        min_val = tf.reduce_min(inputs, axis=1, keepdims=True)
        max_val = tf.reduce_max(inputs, axis=1, keepdims=True)
        inputs_normalized = (inputs - min_val) / (max_val - min_val + 1e-5) * (self.num_lines - 1)

        # Round and convert to integer indices
        inputs_rounded = tf.cast(tf.round(inputs_normalized), tf.int32)
        inputs_rounded = tf.clip_by_value(inputs_rounded, 0, self.num_lines - 1)

        # Dynamic batch size and time steps
        batch_size = tf.shape(inputs)[0]
        time_steps = tf.shape(inputs)[1]

        # Create row_indices
        row_indices = tf.range(self.num_lines, dtype=tf.int32)
        row_indices = tf.reshape(row_indices, [1, self.num_lines, 1])
        row_indices = tf.tile(row_indices, [batch_size, 1, time_steps])

        inputs_expanded = tf.expand_dims(inputs_rounded, axis=1)
        inputs_expanded = tf.tile(inputs_expanded, [1, self.num_lines, 1])

        tf.print("Shape of row_indices:", tf.shape(row_indices))
        tf.print("Shape of inputs_expanded:", tf.shape(inputs_expanded))

        # Perform comparison and create mask
        mask = tf.cast(tf.less_equal(row_indices, inputs_expanded), tf.float32)
        matrix_rgb = tf.stack([mask] * 3, axis=-1)

        return matrix_rgb

    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], self.num_lines, input_shape[1], 3])

    def get_config(self):
        config = super(LinePlotMatrixLayer, self).get_config()
        config.update({"num_lines": self.num_lines})
        return config

The error:

get_v input shape: (16, 299, 1)
get_v input shape: (16, 299, 1)
Shape of row_indices: [1]
Traceback (most recent call last):

  File ~\anaconda3\envs\tf_rl_env\lib\site-packages\spyder_kernels\py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File c:\users\administrador\desktop\we-bronx\reinforcement learning\ppo_script\main.py:118
    ppo.train(states.numpy(), actions.numpy(), discounted_r.numpy(), current_entropy_multiplier, global_step, writer)

  File ~\Desktop\we-bronx\Reinforcement Learning\PPO_script\ppo.py:52 in train
    adv = self.advantage_function(states, rewards)

  File ~\Desktop\we-bronx\Reinforcement Learning\PPO_script\ppo.py:47 in advantage_function
    v = self.get_v(s)

  File ~\Desktop\we-bronx\Reinforcement Learning\PPO_script\ppo.py:42 in get_v
    value = self.critic_model(s)

  File ~\anaconda3\envs\tf_rl_env\lib\site-packages\keras\utils\traceback_utils.py:70 in error_handler
    raise e.with_traceback(filtered_tb) from None

  File ~\Desktop\we-bronx\Reinforcement Learning\PPO_script\custom_layer.py:36 in call
    tf.print("Shape of inputs_expanded:", tf.shape(inputs_expanded))

InvalidArgumentError: Exception encountered when calling layer "line_plot_matrix_layer_5" "                 f"(type LinePlotMatrixLayer).

{{function_node __wrapped__Shape_device_/job:localhost/replica:0/task:0/device:GPU:0}} Expected begin, end, and strides to be 1D equal size tensors, but got shapes [16,299,299], [1], and [1] instead. [Op:Shape]

Call arguments received by layer "line_plot_matrix_layer_5" "                 f"(type LinePlotMatrixLayer):
  • inputs=tf.Tensor(shape=(16, 299, 1), dtype=float32)

Notice that suddenly the Shape of row_indices gets [1] when it is Shape of row_indices: [16 299 299].

Apparently, passing the shape dimentions as arguments and using the @tf.function as a decorator in the call method has solved the bug. I’m not sure why, but I would guess that in some of those methods used in the call there are some pythonic functions.

If someone can tell me something more concrete I’d love to hear it.