I have a model that takes two inputs and produces two outputs:
```python
import tensorflow as tf
from tensorflow.keras import Input, Model

input1 = tf.keras.layers.Input(shape=(None,))
input2 = tf.keras.layers.Input(shape=(None,))
inputs = tf.keras.layers.Concatenate(axis=1)([input1, input2])
inputs = tf.keras.layers.Reshape((-1, 2))(inputs)
# No input_shape here: in the functional API the shape comes from the Input layers.
x = tf.keras.layers.Conv1D(filters=16, kernel_size=3, strides=1, padding="causal", activation="relu")(inputs)
x = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(128, activation="tanh", return_sequences=True))(x)
x = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(256, activation="tanh", return_sequences=True))(x)
x = tf.keras.layers.Dense(128, activation="tanh")(x)
o1 = tf.keras.layers.Dense(1, activation="linear", name="ed")(x)
o2 = tf.keras.layers.Dense(1, activation="sigmoid", name="sd")(x)
model = Model(inputs=[input1, input2], outputs=[o1, o2])
model.compile(
    loss={'ed': 'mean_squared_error',
          'sd': 'binary_crossentropy'},
    optimizer='adam',
    metrics={'ed': tf.keras.metrics.MeanSquaredError(name="mean_squared_error", dtype=None),
             'sd': tf.keras.metrics.BinaryCrossentropy(name="binary_crossentropy", dtype=None, from_logits=False, label_smoothing=0)})
```
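One thing worth checking in the model itself: `Concatenate(axis=1)` joins the two sequences end to end along the time axis, and `Reshape((-1, 2))` then folds consecutive values into pairs, so each "time step" holds two neighbouring values of the same input rather than one value from each input. If the intent is to pair `input1[t]` with `input2[t]`, stacking on a new last axis does that (in Keras, e.g. `Input(shape=(None, 1))` for both inputs followed by `Concatenate(axis=-1)`). A minimal sketch with toy constants; `a` and `b` are illustrative stand-ins, not data from the post:

```python
import tensorflow as tf

# Toy batch of one sequence with T = 4 steps per input (hypothetical values).
a = tf.constant([[1.0, 2.0, 3.0, 4.0]])      # stands in for input1
b = tf.constant([[10.0, 20.0, 30.0, 40.0]])  # stands in for input2

# What the model above does: join along time, then fold into pairs.
concat_reshape = tf.reshape(tf.concat([a, b], axis=1), (1, -1, 2))
# rows: (1, 2), (3, 4), (10, 20), (30, 40)

# Pairing a[t] with b[t]: stack on a new last (feature) axis.
stacked = tf.stack([a, b], axis=-1)
# rows: (1, 10), (2, 20), (3, 30), (4, 40)

print(concat_reshape.shape, stacked.shape)
```

Both tensors have shape `(1, 4, 2)`, so the model builds either way; only the stacked version lines the two inputs up step by step.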
Before feeding in the inputs, I window and shuffle them:
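```python
def windowed_dataset(series, window_size):
    # Add a trailing feature axis: (time,) -> (time, 1)
    series = tf.expand_dims(series, axis=-1)
    ds = tf.data.Dataset.from_tensor_slices(series)
    # Overlapping windows of window_size + 1 steps, advancing one step at a time
    ds = ds.window(window_size + 1, shift=1, drop_remainder=True)
    ds = ds.flat_map(lambda w: w.batch(window_size + 1))
    return ds
```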
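To see what this function yields, it can be run on a short toy series (`toy` below is hypothetical). Each element is a single window of shape `(window_size + 1, 1)`, not a batch, and a series of length `n` gives `n - window_size` windows. Note that with `window_size = 512` every window is 513 steps long, because the "+1" step is not split off as a label here.

```python
import numpy as np
import tensorflow as tf

def windowed_dataset(series, window_size):
    # Same function as above.
    series = tf.expand_dims(series, axis=-1)
    ds = tf.data.Dataset.from_tensor_slices(series)
    ds = ds.window(window_size + 1, shift=1, drop_remainder=True)
    return ds.flat_map(lambda w: w.batch(window_size + 1))

toy = np.arange(10, dtype="float32")  # hypothetical 10-step series
windows = list(windowed_dataset(toy, 3).as_numpy_iterator())

print(len(windows))        # 7 windows: 10 - (3 + 1) + 1
print(windows[0].shape)    # (4, 1)
print(windows[0].ravel())  # [0. 1. 2. 3.]
```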
I pass each variable through the windowing function separately:
```python
w1 = 512  # window size
at = windowed_dataset(at, w1)      # input
timt = windowed_dataset(timt, w1)  # input
wt = windowed_dataset(wt, w1)      # output
wbt = windowed_dataset(wbt, w1)    # output
```
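The four results are independent `tf.data` datasets. `tf.data.Dataset.zip` can combine them element by element into one dataset nested as `((input1, input2), (output1, output2))`, mirroring `Model(inputs=[input1, input2], outputs=[o1, o2])`. Because all four use the same window size and shift, element `i` of each covers the same time span, assuming the four series have the same length (`zip` stops at the shortest). A sketch with hypothetical toy series and a small window size:

```python
import numpy as np
import tensorflow as tf

def windowed_dataset(series, window_size):
    # Same function as above.
    series = tf.expand_dims(series, axis=-1)
    ds = tf.data.Dataset.from_tensor_slices(series)
    ds = ds.window(window_size + 1, shift=1, drop_remainder=True)
    return ds.flat_map(lambda w: w.batch(window_size + 1))

n, w1 = 20, 4  # hypothetical series length and window size
at, timt, wt, wbt = (np.random.rand(n).astype("float32") for _ in range(4))

# One dataset whose elements are ((input1, input2), (output1, output2)).
ds = tf.data.Dataset.zip((
    (windowed_dataset(timt, w1), windowed_dataset(at, w1)),
    (windowed_dataset(wt, w1), windowed_dataset(wbt, w1)),
))

(x1, x2), (y1, y2) = next(iter(ds))
print(x1.shape, x2.shape, y1.shape, y2.shape)  # each (5, 1)
```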
Finally, I call model.fit:
```python
lr_schedule = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: 1e-8 * 10**(epoch / 20))
history = model.fit([timt, at], [wt, wbt], epochs=100, callbacks=[lr_schedule])
```
The problem is that Keras can't find a data adapter for FlatMapDataset:
```
ValueError: Failed to find data adapter that can handle input: (<class 'list'> containing values of types {"<class 'tensorflow.python.data.ops.dataset_ops.FlatMapDataset'>"}), (<class 'list'> containing values of types {"<class 'tensorflow.python.data.ops.dataset_ops.FlatMapDataset'>"})
```
Please guide me on how to solve this.
Thanks
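The error arises because `model.fit` accepts a single `tf.data.Dataset` as `x` (with `y` left unspecified), not Python lists of datasets. One way to make it work is to zip the windowed datasets into one dataset of `(inputs, targets)` elements, then shuffle and batch it, since the windowing function yields single windows and `fit` expects batches. The sketch below is a scaled-down assumption, not the original setup: toy random series stand in for `at`, `timt`, `wt`, `wbt`; window, batch and layer sizes are small; and both inputs use `Input(shape=(None, 1))` with `Concatenate(axis=-1)` so the windows' trailing axis of size 1 lines up and the two inputs are paired per time step.

```python
import numpy as np
import tensorflow as tf

def windowed_dataset(series, window_size):
    # Same windowing function as in the question.
    series = tf.expand_dims(series, axis=-1)
    ds = tf.data.Dataset.from_tensor_slices(series)
    ds = ds.window(window_size + 1, shift=1, drop_remainder=True)
    return ds.flat_map(lambda w: w.batch(window_size + 1))

# Hypothetical stand-ins for the real series (same length for all four).
n, w1, batch_size = 200, 16, 8
rng = np.random.default_rng(0)
at, timt, wt = (rng.random(n).astype("float32") for _ in range(3))
wbt = (rng.random(n) > 0.5).astype("float32")  # 0/1 targets for the sigmoid head

# One dataset of ((input1, input2), (output1, output2)), shuffled and batched.
train_ds = tf.data.Dataset.zip((
    (windowed_dataset(timt, w1), windowed_dataset(at, w1)),
    (windowed_dataset(wt, w1), windowed_dataset(wbt, w1)),
)).shuffle(1000).batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Scaled-down model; inputs keep the trailing axis of size 1.
input1 = tf.keras.layers.Input(shape=(None, 1))
input2 = tf.keras.layers.Input(shape=(None, 1))
x = tf.keras.layers.Concatenate(axis=-1)([input1, input2])  # (batch, time, 2)
x = tf.keras.layers.Conv1D(16, 3, padding="causal", activation="relu")(x)
x = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(8, return_sequences=True))(x)
o1 = tf.keras.layers.Dense(1, name="ed")(x)
o2 = tf.keras.layers.Dense(1, activation="sigmoid", name="sd")(x)
model = tf.keras.Model(inputs=[input1, input2], outputs=[o1, o2])
model.compile(optimizer="adam",
              loss={"ed": "mean_squared_error", "sd": "binary_crossentropy"})

# The dataset carries both inputs and targets, so no y argument is passed.
history = model.fit(train_ds, epochs=1, verbose=0)
```

With the original code, the equivalent call would be `model.fit(train_ds, epochs=100, callbacks=[lr_schedule])`, where `train_ds` is built from `timt`, `at`, `wt` and `wbt` the same way.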