Your understanding is correct. If you want to contract away all non-batch dimensions, you can do something like this with the custom layer API:
class ExtendedDense(tf.keras.layers.Layer):
    """Dense layer that contracts away every non-batch input dimension.

    For an input of shape ``(batch, d1, ..., dn)`` the layer owns a kernel of
    shape ``(d1, ..., dn, units)`` and a bias of shape ``(units,)``, and
    returns a tensor of shape ``(batch, units)``.

    Args:
        units: Size of the output feature dimension.
        **kwargs: Standard layer keyword arguments (e.g. ``name``, ``dtype``)
            forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create the kernel and bias once the input shape is known.

        Raises:
            ValueError: If any non-batch dimension of ``input_shape`` is
                undefined, since the kernel needs a concrete shape.
        """
        # Normalise to a plain tuple so this works whether Keras hands us a
        # TensorShape or a tuple.
        feature_dims = tuple(input_shape)[1:]
        if any(dim is None for dim in feature_dims):
            raise ValueError(
                "ExtendedDense requires all non-batch dimensions to be "
                f"defined, got input shape {tuple(input_shape)}."
            )
        rank = len(feature_dims) + 1

        # Contract input axes 1..rank-1 against kernel axes 0..rank-2,
        # leaving only (batch, units).
        self.axes = [list(range(1, rank)), list(range(0, rank - 1))]

        # Pass `name` by keyword: the first positional parameter of
        # `add_weight` is `name` in older Keras but `shape` in Keras 3.
        self.kernel = self.add_weight(
            name="kernel",
            shape=feature_dims + (self.units,),
        )
        # Zero-initialised bias, matching the convention of the built-in
        # Dense layer.
        self.bias = self.add_weight(
            name="bias",
            shape=(self.units,),
            initializer="zeros",
        )

    def call(self, x):
        """Return ``tensordot(x, kernel)`` over all non-batch axes, plus bias."""
        return tf.tensordot(x, self.kernel, self.axes) + self.bias

    def get_config(self):
        """Return the layer configuration, including ``units``."""
        config = super().get_config()
        config.update({"units": self.units})
        return config
# Wire the layer into a functional model: each sample is a 64 x 128 matrix
# and the layer maps it to a 256-dimensional vector.
inp = tf.keras.Input(shape=(64, 128))
out = ExtendedDense(256)(inp)
model = tf.keras.Model(inputs=inp, outputs=out)
And usage:
# A random batch of 32 samples, each of shape (64, 128); calling the model
# yields a tensor of shape (32, 256).
x = tf.random.normal(shape=(32, 64, 128))
model(x)
Output:
<tf.Tensor: shape=(32, 256), dtype=float32, numpy=
array([[ 5.87857895e+01, 7.20561981e+01, 4.76380234e+01, ...,
1.35572968e+02, -7.85310516e+01, 8.38421555e+01],
[-7.71801834e+01, 1.77247452e+02, -1.48629440e+02, ...,
3.40917473e+01, 1.80250931e+00, -8.22327881e+01],
[ 2.06383377e+02, 5.67772293e+01, -1.80862427e-01, ...,
6.32748642e+01, 3.95997772e+01, -4.52198296e+01],
...,
[ 5.45089569e+01, 1.02864510e+02, 1.28402420e+02, ...,
-1.03204765e+02, -1.85087776e+01, 3.91985130e+01],
[-2.94154816e+01, -8.31197052e+01, -9.01190338e+01, ...,
1.73478546e+02, -1.15511642e+02, -8.12764282e+01],
[-1.32574982e+02, 5.21116562e+01, 1.01657272e+02, ...,
-2.43008301e+02, 5.90627060e+01, 1.45800140e+02]], dtype=float32)>
<tf.Tensor: shape=(32, 256), dtype=float32, numpy=
array([[-2.30272830e-01, -1.33582354e+00, -2.89380431e-01, ...,
1.00469515e-01, -1.17660590e-01, 2.11346745e-01],
[-5.74482441e-01, -1.97411704e+00, -1.08441144e-01, ...,
-1.35504484e+00, 3.19112301e-01, 1.21400483e-01],
[ 2.24482596e-01, -3.09928954e-01, 1.68636054e-01, ...,
-1.44872054e-01, -8.07262778e-01, 2.01987267e-01],
...,
[ 1.22694087e+00, -1.46654761e+00, -5.89887261e-01, ...,
-1.97368228e+00, -1.25852346e+00, 1.33283436e-04],
[-1.33344865e+00, -5.02050817e-01, 8.51780176e-01, ...,
-7.39977777e-01, -3.36842000e-01, -2.30333179e-01],
[ 2.70515770e-01, -9.70930576e-01, 1.15700647e-01, ...,
2.89812565e-01, -1.02232158e+00, -6.77653134e-01]], dtype=float32)>