Hi,
If this question seems too basic, it is because I am new to TensorFlow.
I was implementing a toy encoder-decoder problem using the TensorFlow Addons (TFA) seq2seq implementation in TensorFlow 2.
The API was clear to me until I tried to replace my BasicDecoder with a BeamSearchDecoder.
My question is about how to initialize the start_tokens and end_token arguments of BeamSearchDecoder (I sketch my current understanding right after the code below).
Here is a copy of the implementation; any help is appreciated.
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_addons as tfa

tf.keras.backend.clear_session()
tf.random.set_seed(42)

# vocabulary sizes (train_vocab and target_vocab are built earlier in the notebook)
enc_vocab_size = len(train_vocab) + 1
dec_vocab_size = len(target_vocab) + 1
embed_size = 10

# encoder
encoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)
decoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)
sequence_lengths = keras.layers.Input(shape=[], dtype=np.int32)
encoder_embeddings = keras.layers.Embedding(enc_vocab_size, embed_size)(encoder_inputs)
encoder = keras.layers.LSTM(512, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_embeddings)
encoder_state = [state_h, state_c]

# decoder
sampler = tfa.seq2seq.sampler.TrainingSampler()  # left over from my BasicDecoder version
decoder_embeddings = keras.layers.Embedding(dec_vocab_size, embed_size)(decoder_inputs)
decoder_cell = keras.layers.LSTMCell(512)
output_layer = keras.layers.Dense(dec_vocab_size)

# beam search decoding
beam_width = 10
start_tokens = tf.zeros([32], tf.dtypes.int32)   # one start-token id per example (batch size hard-coded to 32)
end_tokens = tf.constant([1], tf.dtypes.int32)   # defined here, but end_token=0 is what I actually pass below
decoder = tfa.seq2seq.beam_search_decoder.BeamSearchDecoder(
    cell=decoder_cell, beam_width=beam_width, output_layer=output_layer)
decoder_initial_state = tfa.seq2seq.beam_search_decoder.tile_batch(
    encoder_state, multiplier=beam_width)
outputs, _, _ = decoder(decoder_embeddings, start_tokens=start_tokens,
                        end_token=0, initial_state=decoder_initial_state)
Y_proba = tf.nn.softmax(outputs.rnn_output)

model = keras.models.Model(inputs=[encoder_inputs, decoder_inputs], outputs=[Y_proba])
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
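For what it's worth, this is how I currently understand those two arguments (this is just my assumption, and sos_id / eos_id below are placeholders for whatever ids my target_vocab actually assigns to the start and end tokens), so please correct me if this is wrong:

# my (possibly wrong) understanding of the two arguments
batch_size = 32                               # same hard-coded batch size as above
sos_id = 0                                    # placeholder: id of the start-of-sequence token
eos_id = 1                                    # placeholder: id of the end-of-sequence token
start_tokens = tf.fill([batch_size], sos_id)  # one <sos> id per example in the batch, dtype int32
end_token = eos_id                            # a single scalar id that marks the end of decoding

Is this the intended way to build them, and does the batch size really have to be hard-coded here?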
Here is the error trace I get:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-99-8287ffcfd4fa> in <module>()
34 decoder = tfa.seq2seq.beam_search_decoder.BeamSearchDecoder(cell = decoder_cell, beam_width = beam_width, output_layer = output_layer)
35 decoder_initial_state = tfa.seq2seq.beam_search_decoder.tile_batch(encoder_state, multiplier = beam_width)
---> 36 outputs, _, _ = decoder(decoder_embeddings, start_tokens = start_tokens, end_token = 0, initial_state = decoder_initial_state)
37 Y_proba = tf.nn.softmax(outputs.rnn_output)
38
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
690 except Exception as e: # pylint:disable=broad-except
691 if hasattr(e, 'ag_error_metadata'):
--> 692 raise e.ag_error_metadata.to_exception(e)
693 else:
694 raise
ValueError: Exception encountered when calling layer "beam_search_decoder" (type BeamSearchDecoder).
in user code:
File "/usr/local/lib/python3.7/dist-packages/tensorflow_addons/seq2seq/beam_search_decoder.py", line 941, in call *
self,
File "/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py", line 262, in wrapper *
retval = func(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/tensorflow_addons/seq2seq/decoder.py", line 430, in body *
(next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(
File "/usr/local/lib/python3.7/dist-packages/tensorflow_addons/seq2seq/beam_search_decoder.py", line 705, in step *
cell_outputs, next_cell_state = self._cell(
File "/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler **
raise e.with_traceback(filtered_tb) from None
ValueError: Exception encountered when calling layer "lstm_cell_1" (type LSTMCell).
Dimensions must be equal, but are 80 and 320 for '{{node beam_search_decoder/decoder/while/BeamSearchDecoderStep/lstm_cell_1/mul}} = Mul[T=DT_FLOAT](beam_search_decoder/decoder/while/BeamSearchDecoderStep/lstm_cell_1/Sigmoid_1, beam_search_decoder/decoder/while/BeamSearchDecoderStep/Reshape_2)' with input shapes: [320,80,2048], [320,512].
Call arguments received:
• inputs=tf.Tensor(shape=(320, None, 10), dtype=float32)
• states=ListWrapper(['tf.Tensor(shape=(320, 512), dtype=float32)', 'tf.Tensor(shape=(320, 512), dtype=float32)'])
• training=None
Call arguments received:
• embedding=tf.Tensor(shape=(None, None, 10), dtype=float32)
• start_tokens=tf.Tensor(shape=(32,), dtype=int32)
• end_token=0
• initial_state=['tf.Tensor(shape=(None, 512), dtype=float32)', 'tf.Tensor(shape=(None, 512), dtype=float32)']
• training=None
• kwargs=<class 'inspect._empty'>
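If I am reading the trace correctly, the leading dimension of 320 is just my batch size times the beam width, and the LSTM cell seems to receive the whole embedded decoder sequence of shape (320, None, 10) rather than a single time step, which is presumably where the shapes stop lining up. A quick sanity check on the numbers (my interpretation, so it may be off):

batch_size = 32                   # hard-coded into start_tokens
beam_width = 10
print(batch_size * beam_width)    # 320 -> the leading dimension of the tensors in the trace

Does this mean I am passing the wrong thing as the first argument to the decoder, or is my start_tokens / end_token setup the problem?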