To reproduce:
!pip install tensorflow-text==2.7.0
import tensorflow_text as text
import tensorflow_hub as hub
# ... other tf imports....
# Distribution strategy whose scope is used for the hub layers here and for the
# "with scope" model build further down.
strategy = tf.distribute.MirroredStrategy()
print(f'Number of GPU: {strategy.num_replicas_in_sync}')  # 1 or 2, shouldn't matter

# Number of units in the final Dense layer built by get_model().
NUM_CLASS = 2

# The TF Hub layers are created once, inside the strategy scope; get_model()
# reuses these same two layer objects for every model it builds.
with strategy.scope():
    bert_preprocess = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
    bert_encoder = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4")
def get_model():
    """Build the text model: string input -> BERT preprocess -> BERT encoder -> Dense.

    The module-level ``bert_preprocess`` and ``bert_encoder`` hub layers are
    reused on every call. A ``Dense`` head with ``NUM_CLASS`` units is applied
    to the encoder's ``sequence_output``.

    Returns:
        A ``Model`` with a single string input named ``'text'`` and a single
        output holding the raw (un-activated) outputs of the Dense head.
    """
    text_input = Input(shape=(), dtype=tf.string, name='text')
    preprocessed_text = bert_preprocess(text_input)
    outputs = bert_encoder(preprocessed_text)
    output_sequence = outputs['sequence_output']
    # No activation here: the model is compiled with
    # CategoricalCrossentropy(from_logits=True), which expects raw logits.
    # Applying a sigmoid first would feed already-squashed values to the loss.
    x = Dense(NUM_CLASS)(output_sequence)
    model = Model(inputs=[text_input], outputs=[x])
    return model
# Baseline: build and compile the model WITHOUT entering strategy.scope().
optimizer = Adam()
model = get_model()
model.compile(
    loss=CategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    # Use the 'accuracy' string so Keras picks the accuracy variant matching
    # the loss. The Accuracy() metric class checks exact equality between
    # predictions and labels, which is not meaningful for continuous outputs.
    metrics=['accuracy'],
)
model.summary()  # <- look at the output 1
tf.keras.utils.plot_model(model, show_shapes=True, to_file='model.png')  # <- look at the figure 1
# Same build as the baseline, but with everything created inside strategy.scope().
with strategy.scope():
    optimizer = Adam()
    model = get_model()
    model.compile(
        loss=CategoricalCrossentropy(from_logits=True),
        optimizer=optimizer,
        # Use the 'accuracy' string so Keras picks the accuracy variant
        # matching the loss. The Accuracy() metric class checks exact equality
        # between predictions and labels, which is not meaningful for
        # continuous outputs.
        metrics=['accuracy'],
    )
    model.summary()  # <- compare with output 1, it has already lost its shape
    tf.keras.utils.plot_model(model, show_shapes=True, to_file='model_scoped.png')  # <- compare this figure too, for ease
With `strategy.scope()`, BERT loses `seq_length`, and it becomes `None`.
Model summary without scope (note the 128 in the output shape of the very last layer, which is `seq_length`):
Model: "model_6"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
text (InputLayer) [(None,)] 0 []
keras_layer_2 (KerasLayer) {'input_mask': (Non 0 ['text[0][0]']
e, 128),
'input_word_ids':
(None, 128),
'input_type_ids':
(None, 128)}
keras_layer_3 (KerasLayer) multiple 109482241 ['keras_layer_2[6][0]',
'keras_layer_2[6][1]',
'keras_layer_2[6][2]']
dense_6 (Dense) (None, 128, 2) 1538 ['keras_layer_3[6][14]']
==================================================================================================
Total params: 109,483,779
Trainable params: 1,538
Non-trainable params: 109,482,241
__________________________________________________________________________________________________
Model summary with scope (the very last layer's output shape is now `(None, None, 2)`):
Model: "model_7"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
text (InputLayer) [(None,)] 0 []
keras_layer_2 (KerasLayer) {'input_mask': (Non 0 ['text[0][0]']
e, 128),
'input_word_ids':
(None, 128),
'input_type_ids':
(None, 128)}
keras_layer_3 (KerasLayer) multiple 109482241 ['keras_layer_2[7][0]',
'keras_layer_2[7][1]',
'keras_layer_2[7][2]']
dense_7 (Dense) (None, None, 2) 1538 ['keras_layer_3[7][14]']
==================================================================================================
Total params: 109,483,779
Trainable params: 1,538
Non-trainable params: 109,482,241
__________________________________________________________________________________________________
If these images help:
Another notable thing: `encoder_outputs` is also missing if you take a look at the 2nd Keras layer (i.e. the 3rd layer) of both models.