I’m trying to build a two-input model with Keras, where each input is a string.
Here’s the code for the model:
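import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import (TextVectorization, Embedding, Bidirectional,
                                     LSTM, Dense, concatenate)
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.utils import plot_model
# MAX_STRING_SIZE, vocab and LABELS are defined elsewhere in my script
vectorize_layer1 = TextVectorization(split="character", output_sequence_length=512,
                                     max_tokens=MAX_STRING_SIZE)
vectorize_layer1.adapt(list(vocab))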
# define two sets of inputs
inputA = Input(shape=(1,), dtype=tf.string)
inputB = Input(shape=(1,), dtype=tf.string)
# the first branch operates on the first input
x = vectorize_layer1(inputA)
x = Embedding(len(vectorize_layer1.get_vocabulary()), MAX_STRING_SIZE)(x)
x = Bidirectional(LSTM(MAX_STRING_SIZE, return_sequences=True, dropout=.2))(x)
x = LSTM(MAX_STRING_SIZE, activation="tanh", return_sequences=False, dropout=.2)(x)
x = Model(inputs=inputA, outputs=x)
# the second branch operates on the second input
y = vectorize_layer1(inputB)
y = Embedding(len(vectorize_layer1.get_vocabulary()), MAX_STRING_SIZE)(y)
y = Bidirectional(LSTM(MAX_STRING_SIZE, return_sequences=True, dropout=.2))(y)
y = LSTM(MAX_STRING_SIZE, activation="tanh", return_sequences=False, dropout=.2)(y)
y = Model(inputs=inputB, outputs=y)
# combine the output of the two branches
combined = concatenate([x.output, y.output])
# apply a FC layer and then a prediction on the categories
z = Dense(2, activation="relu")(combined)
z = Dense(len(LABELS), activation="softmax")(z)
# our model will accept the inputs of the two branches and output a probability for each label
model = Model(inputs=[x.input, y.input], outputs=z)
print(model.predict((np.array(["i love python"]), np.array(["test"]))))
# That prediction works fine!
model.summary()
plot_model(model, to_file="model.png", show_shapes=True, show_layer_names=True)
model.compile(optimizer='Adam',
              loss=CategoricalCrossentropy(from_logits=False),
              metrics=["categorical_accuracy"])
However, I’m getting a "Shapes are incompatible" error:
line 5119, in categorical_crossentropy
target.shape.assert_is_compatible_with(output.shape)
ValueError: Shapes (None, 1) and (None, 20) are incompatible
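For context, the failing assertion sits inside categorical crossentropy, which compares each row of the target with the matching row of the softmax output, so both are expected to have shape (batch, num_classes). Below is a rough NumPy sketch of that loss on probabilities (from_logits=False). `categorical_crossentropy_np` is a hypothetical helper for illustration, not a Keras function, and its exact shape-equality check is stricter than Keras's `assert_is_compatible_with`, which also accepts `None` dimensions.

```python
import numpy as np

def categorical_crossentropy_np(target, output, eps=1e-7):
    """Hypothetical NumPy version of categorical crossentropy on
    probabilities (from_logits=False), for illustration only."""
    target = np.asarray(target, dtype=float)
    output = np.asarray(output, dtype=float)
    # Roughly the precondition behind the Keras assertion: one-hot targets
    # and softmax outputs must both be (batch, num_classes).
    if target.shape != output.shape:
        raise ValueError(f"Shapes {target.shape} and {output.shape} are incompatible")
    # Renormalize and clip the probabilities so log() stays finite.
    output = output / output.sum(axis=-1, keepdims=True)
    output = np.clip(output, eps, 1.0 - eps)
    # One loss value per sample: -sum_k y_k * log(p_k).
    return -(target * np.log(output)).sum(axis=-1)

num_classes = 20
probs = np.full((2, num_classes), 1.0 / num_classes)  # uniform softmax output, shape (2, 20)
one_hot = np.eye(num_classes)[[2, 5]]                 # labels 2 and 5, shape (2, 20)
print(categorical_crossentropy_np(one_hot, probs))    # two values, each approximately log(20)

class_ids = np.array([[2], [5]])                      # shape (2, 1): one number per row, not one-hot
try:
    categorical_crossentropy_np(class_ids, probs)
except ValueError as err:
    print(err)
```

In this sketch a (batch, 1) target can only work against a (batch, 20) output if it is first expanded to one-hot rows.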
Here is an example of the training/validation data:
((array(['foo'], dtype='<U6'), array(['bar'], dtype='<U26')), array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]))
Any ideas what is going wrong? The data looks OK to me, and the categories are one-hot encoded.
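The sample above is a single, unbatched example: each string input has shape (1,) and the label has shape (20,). Keras treats the first axis of whatever it receives as the batch axis, so the arrays passed to `fit` would normally be (num_samples, 1) for each string input and (num_samples, 20) for the labels. A minimal NumPy sketch of stacking per-example tuples into that layout, assuming the examples sit in a Python list (`samples` and its contents are hypothetical):

```python
import numpy as np

num_classes = 20  # length of the one-hot label vector in the sample above

# Hypothetical per-example tuples in the same layout as the sample:
# ((input_a, input_b), one_hot_label), each input with shape (1,).
samples = [
    ((np.array(["foo"]), np.array(["bar"])), np.eye(num_classes)[2]),
    ((np.array(["i love python"]), np.array(["test"])), np.eye(num_classes)[7]),
]

# Stack along a new leading (batch) axis.
inputs_a = np.stack([a for (a, _), _ in samples])  # shape (2, 1)
inputs_b = np.stack([b for (_, b), _ in samples])  # shape (2, 1)
labels = np.stack([y for _, y in samples])         # shape (2, 20)

print(inputs_a.shape, inputs_b.shape, labels.shape)
# With arrays laid out like this, the targets line up with the (None, 20)
# softmax output, e.g.:
# model.fit([inputs_a, inputs_b], labels, batch_size=32, epochs=10)
```

If the examples are instead yielded by a `tf.data.Dataset`, calling `.batch(...)` on it adds the same leading batch axis; an unbatched dataset might lead Keras to read the 20-element label vector itself as a batch, which is one possible source of a (None, 1) target.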