Understanding Keras Dense Layer for rank > 2

Hey all,

the official API documentation for tf.keras.layers.Dense states:

Note: If the input to the layer has a rank greater than 2, then Dense computes the dot product between the inputs and the kernel along the last axis of the inputs and axis 0 of the kernel (using tf.tensordot). For example, if input has dimensions (batch_size, d0, d1), then we create a kernel with shape (d1, units), and the kernel operates along axis 2 of the input, on every sub-tensor of shape (1, 1, d1) (there are batch_size * d0 such sub-tensors). The output in this case will have shape (batch_size, d0, units).
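The behaviour described in the note can be mirrored in a few lines of NumPy, whose np.tensordot uses the same axes convention as tf.tensordot. This is a sketch with small, arbitrary sizes chosen for illustration. It checks that contracting the last input axis against axis 0 of a (d1, units) kernel gives the same result as applying that one kernel to every length-d1 vector.

```python
import numpy as np

# Shapes from the documentation note: input (batch_size, d0, d1), kernel (d1, units).
batch_size, d0, d1, units = 4, 3, 5, 2
rng = np.random.default_rng(0)
x = rng.normal(size=(batch_size, d0, d1))
kernel = rng.normal(size=(d1, units))
bias = rng.normal(size=(units,))

# Contract the last axis of x with axis 0 of the kernel (same axes convention as tf.tensordot).
out = np.tensordot(x, kernel, axes=[[2], [0]]) + bias
print(out.shape)  # (4, 3, 2), i.e. (batch_size, d0, units)

# Equivalent: apply the same kernel to each of the batch_size * d0 vectors of length d1.
looped = np.empty((batch_size, d0, units))
for b in range(batch_size):
    for i in range(d0):
        looped[b, i] = x[b, i] @ kernel + bias
assert np.allclose(out, looped)
```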

My understanding is that only one kernel is created for ranks larger than 2. That would mean that (in the case of rank 3) all indices of the “middle” dimension are acted upon by the same kernel, and thus the outputs for different indices of the “middle” dimension are not independent (they share the same weights).
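A quick way to see that the middle axis is handled by one shared kernel: reordering the middle axis of the input only reorders the output in the same way, and the parameter count does not depend on d0. The NumPy sketch below assumes arbitrary sizes and a hypothetical permutation perm. Note that the output at middle index i depends only on input slice i, so the coupling between indices is entirely through the shared weights.

```python
import numpy as np

batch_size, d0, d1, units = 2, 4, 3, 5
rng = np.random.default_rng(1)
x = rng.normal(size=(batch_size, d0, d1))
kernel = rng.normal(size=(d1, units))  # the single kernel shared by all middle indices
bias = np.zeros(units)


def dense_last_axis(t):
    """Dense-style contraction over the last axis only."""
    return np.tensordot(t, kernel, axes=[[2], [0]]) + bias


perm = np.array([2, 0, 3, 1])  # hypothetical reordering of the middle axis
# Reordering the middle axis before the layer just reorders the output the same way,
# because every middle index is multiplied by the same kernel.
assert np.allclose(dense_last_axis(x[:, perm, :]), dense_last_axis(x)[:, perm, :])

# Parameter count is d1 * units + units, independent of d0.
n_params = kernel.size + bias.size
print(n_params)  # 20
```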

Is that understanding correct? And if it is, is there a simple way to get the Dense layer to use a stack of kernels? (The network currently has the tensor multiplication implemented manually.)
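For reference, a "stack of kernels" (one kernel per middle index) can be written as a single einsum. The sketch below uses NumPy with arbitrary sizes, and the names kernels and biases are illustrative only. tf.einsum accepts the same equation string, so the same contraction could go inside the call method of a custom layer. tf.keras.layers.EinsumDense (available in recent TensorFlow versions) also builds its kernel from an einsum equation and may cover this case directly; check its documentation for the exact equation and output_shape arguments.

```python
import numpy as np

batch_size, d0, d1, units = 2, 3, 4, 5
rng = np.random.default_rng(2)
x = rng.normal(size=(batch_size, d0, d1))
# Hypothetical "stack of kernels": one (d1, units) kernel per middle index.
kernels = rng.normal(size=(d0, d1, units))
biases = rng.normal(size=(d0, units))

# b = batch, i = middle index, j = input features, k = units.
# i appears in both operands and in the output, so it is not summed over:
# each slice x[:, i, :] is multiplied by its own kernels[i].
out = np.einsum("bij,ijk->bik", x, kernels) + biases
print(out.shape)  # (2, 3, 5)

for i in range(d0):
    assert np.allclose(out[:, i, :], x[:, i, :] @ kernels[i] + biases[i])
```

This uses d0 * d1 * units kernel weights instead of d1 * units, since no weights are shared across the middle axis.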

Thank you all in advance.
Fabian

Your understanding is correct. If you want to contract away all non-batch dimensions, you can do something like this with the custom layer API:

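import tensorflow as tf

class ExtendedDense(tf.keras.layers.Layer):
  """Dense layer that contracts all non-batch axes of the input at once."""

  def __init__(self, units, **kwargs):
    super().__init__(**kwargs)
    self.units = units

  def build(self, input_shape):
    rank = len(input_shape)
    # Contract input axes 1..rank-1 with kernel axes 0..rank-2.
    self.axes = [list(range(1, rank)), list(range(0, rank - 1))]
    # One weight per (non-batch input position, output unit).
    kernel_shape = tuple(input_shape[1:]) + (self.units,)
    # Keyword arguments keep this working whether add_weight takes name or shape first.
    self.kernel = self.add_weight(name="kernel", shape=kernel_shape,
                                  initializer="glorot_uniform")
    self.bias = self.add_weight(name="bias", shape=(self.units,),
                                initializer="zeros")

  def call(self, x):
    return tf.tensordot(x, self.kernel, self.axes) + self.bias

inp = tf.keras.layers.Input(shape=(64, 128))
out = ExtendedDense(256)(inp)  # output shape: (batch_size, 256)

model = tf.keras.models.Model(inp, out)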
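Contracting all non-batch axes like this is equivalent to flattening them and using an ordinary 2-D kernel, i.e. a Flatten layer followed by a standard Dense layer, with the kernel stored in a different shape. The NumPy sketch below (arbitrary sizes, np.tensordot standing in for tf.tensordot) checks this for a rank-3 input, using the same axes list that build computes.

```python
import numpy as np

batch_size, d0, d1, units = 3, 4, 5, 6
rng = np.random.default_rng(3)
x = rng.normal(size=(batch_size, d0, d1))
# Same kernel shape the layer builds: input_shape[1:] + (units,).
kernel = rng.normal(size=(d0, d1, units))
bias = rng.normal(size=(units,))

rank = x.ndim
axes = [list(range(1, rank)), list(range(0, rank - 1))]  # [[1, 2], [0, 1]]
out = np.tensordot(x, kernel, axes) + bias
print(out.shape)  # (3, 6)

# Same result as flattening the non-batch axes and using an ordinary 2-D kernel.
flat = x.reshape(batch_size, -1) @ kernel.reshape(-1, units) + bias
assert np.allclose(out, flat)
```

One practical consequence: the kernel has prod(input_shape[1:]) * units weights, i.e. 64 * 128 * 256 = 2,097,152 for the model above, exactly as many as Flatten followed by Dense(256) would use.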

And usage:

x = tf.random.normal((32, 64, 128))
model(x)

Output:



<tf.Tensor: shape=(32, 256), dtype=float32, numpy=
array([[ 5.87857895e+01,  7.20561981e+01,  4.76380234e+01, ...,
         1.35572968e+02, -7.85310516e+01,  8.38421555e+01],
       [-7.71801834e+01,  1.77247452e+02, -1.48629440e+02, ...,
         3.40917473e+01,  1.80250931e+00, -8.22327881e+01],
       [ 2.06383377e+02,  5.67772293e+01, -1.80862427e-01, ...,
         6.32748642e+01,  3.95997772e+01, -4.52198296e+01],
       ...,
       [ 5.45089569e+01,  1.02864510e+02,  1.28402420e+02, ...,
        -1.03204765e+02, -1.85087776e+01,  3.91985130e+01],
       [-2.94154816e+01, -8.31197052e+01, -9.01190338e+01, ...,
         1.73478546e+02, -1.15511642e+02, -8.12764282e+01],
       [-1.32574982e+02,  5.21116562e+01,  1.01657272e+02, ...,
        -2.43008301e+02,  5.90627060e+01,  1.45800140e+02]], dtype=float32)>

<tf.Tensor: shape=(32, 256), dtype=float32, numpy=
array([[-2.30272830e-01, -1.33582354e+00, -2.89380431e-01, ...,
         1.00469515e-01, -1.17660590e-01,  2.11346745e-01],
       [-5.74482441e-01, -1.97411704e+00, -1.08441144e-01, ...,
        -1.35504484e+00,  3.19112301e-01,  1.21400483e-01],
       [ 2.24482596e-01, -3.09928954e-01,  1.68636054e-01, ...,
        -1.44872054e-01, -8.07262778e-01,  2.01987267e-01],
       ...,
       [ 1.22694087e+00, -1.46654761e+00, -5.89887261e-01, ...,
        -1.97368228e+00, -1.25852346e+00,  1.33283436e-04],
       [-1.33344865e+00, -5.02050817e-01,  8.51780176e-01, ...,
        -7.39977777e-01, -3.36842000e-01, -2.30333179e-01],
       [ 2.70515770e-01, -9.70930576e-01,  1.15700647e-01, ...,
         2.89812565e-01, -1.02232158e+00, -6.77653134e-01]], dtype=float32)>

Hey Sean,

Thank you for the quick reply and the example.

Fabian