Hi, this is a similar question to: python - How to create a keras layer with a custom gradient in TF2.0? - Stack Overflow
Only, I would like to introduce a learnable parameter into the custom layer that I am training.
Here’s a toy example of my current approach here:
# Method for calculation custom gradient
@tf.custom_gradient
def scaler(x, s):
def grad(upstream):
dy_dx = s
dy_ds = x
return dy_dx, dy_ds
return x * s, grad
# Keras Layer with trainable parameter
class TestLayer(tf.keras.layers.Layer):
def build(self, input_shape):
self.scale = self.add_weight("scale",
shape=[1,],
initializer=tf.keras.initializers.Constant(value=2.0),
trainable=True)
def call(self, inputs):
return scaler(inputs, self.scale)
# Creates Keras Model that uses the layer
def Model():
x_in = tf.keras.layers.Input(shape=(1,))
x_out = TestLayer()(x_in)
return tf.keras.Model(inputs=x_in, outputs=x_out, name="fp8_test")
# Create toy dataset, want to learn `scale` such to satisfy 5 = 2 * scale (i.e, `scale` should learn ~2.5)
def Dataset():
inps = tf.ones(shape=(10**5,)) * 2 # inputs
expected = tf.ones(shape=(10**5,)) * 5 # targets
data_in = tf.data.Dataset.from_tensors(inps)
data_exp = tf.data.Dataset.from_tensors(expected)
dataset = tf.data.Dataset.zip((data_in, data_exp))
return dataset
model = Model()
model.summary()
dataset = Dataset()
# Use `MSE` loss and `SGD` optimizer
model.compile(
loss=tf.keras.losses.MSE,
optimizer=tf.keras.optimizers.SGD(),
)
model.fit(dataset, epochs=100)
This is failing with the following shape related error in the optimizer:
ValueError: Shapes must be equal rank, but are 1 and 2 for '{{node SGD/SGD/update/ResourceApplyGradientDescent}} = ResourceApplyGradientDescent[T=DT_FLOAT, use_locking=true](fp8_test/test_layer_1/ReadVariableOp/resource, SGD/Identity, SGD/IdentityN)' with input shapes: [], [], [100000,1].
I’ve been staring at the docs for a while, I’m a bit stumped as to why this isn’t working, I would really appreciate any input on how to fix this toy example.
Thanks in advance.
I attempted the above toy example.