Hi everyone,
I am using the CRFModelWrapper approach from the tutorial at addons/layers_crf.ipynb at add_crf_tutorial · howl-anderson/addons · GitHub to implement a Bi-LSTM-CRF network for a multi-class NER problem. The model I built (code shown below) trains fine on multiple GPUs, and it can be saved and loaded with the tf.keras model save/load functions. However, when I saved the model, the warning below appeared, and I am not sure whether it matters. When I then loaded the trained model, I got many warnings about inconsistent output shapes (also posted below). Finally, when I try to continue training the loaded model with fit, it fails with the error shown further down.
Saving warning:
WARNING:absl:Found untraced functions such as embedding_layer_call_and_return_conditional_losses, embedding_layer_call_fn, embedding_1_layer_call_and_return_conditional_losses, embedding_1_layer_call_fn, multi_head_attention_layer_call_and_return_conditional_losses while saving (showing 5 of 65). These functions will not be directly callable after loading.
Loading warnings:
2021-11-21 00:28:34.222763: W tensorflow/core/common_runtime/graph_constructor.cc:803] Node 'cond/while' has 14 outputs but the _output_shapes attribute specifies shapes for 48 outputs. Output shapes may be inaccurate.
2021-11-21 00:28:34.374067: W tensorflow/core/common_runtime/graph_constructor.cc:803] Node 'cond' has 5 outputs but the _output_shapes attribute specifies shapes for 48 outputs. Output shapes may be inaccurate.
(The same two 'cond/while' and 'cond' warnings are repeated dozens of times with different timestamps; the remaining lines are omitted here.)
Despite these warnings, the model can still be saved, loaded, and used for prediction, but the loaded model cannot be re-trained. The error message is posted below, and it looks as if the loss function inside the CRFModelWrapper cannot be called again, so the gradient calculation fails. I am therefore wondering whether CRFModelWrapper does not support this kind of serialization (save -> load -> continue training), or whether I have made a mistake somewhere. If it is the former, is there a workaround that would let me retrain the loaded model?
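For reference, the save/load pattern in question is essentially the following (a minimal sketch; the explicit model.save call here is only an illustration, and FILE_PATH4 is the same path used by load_model in the re-training code further below):

# save the trained wrapper model (SavedModel format)
model.save(FILE_PATH4)

# later, load it back for prediction or further training
from tensorflow.keras.models import load_model
model = load_model(FILE_PATH4)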
Thank you very much.
Error message when re-training the trained model:
Traceback (most recent call last):
File "<input>", line 1, in <module>
File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\engine\training.py", line 1184, in fit
tmp_logs = self.train_function(iterator)
File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\eager\def_function.py", line 885, in __call__
result = self._call(*args, **kwds)
File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\eager\def_function.py", line 933, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\eager\def_function.py", line 759, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\eager\function.py", line 3066, in _get_concrete_function_internal_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\eager\function.py", line 3463, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\eager\function.py", line 3298, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\framework\func_graph.py", line 1007, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\eager\def_function.py", line 668, in wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File "C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\framework\func_graph.py", line 994, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\engine\training.py:853 train_function *
return step_function(self, iterator)
C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\engine\training.py:842 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1286 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2849 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:3632 _call_for_each_replica
return fn(*args, **kwargs)
C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\engine\training.py:835 run_step **
outputs = model.train_step(data)
C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\engine\training.py:791 train_step
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\optimizer_v2\optimizer_v2.py:522 minimize
return self.apply_gradients(grads_and_vars, name=name)
C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\optimizer_v2\optimizer_v2.py:622 apply_gradients
grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
C:\Users\YenPangLai\anaconda3\envs\tf2p6\lib\site-packages\keras\optimizer_v2\utils.py:72 filter_empty_gradients
raise ValueError("No gradients provided for any variable: %s." %
ValueError: No gradients provided for any variable: ['chain_kernel:0', 'left_boundary:0', 'right_boundary:0', 'crf_model_wrapper/crf/dense/kernel:0', 'crf_model_wrapper/crf/dense/bias:0', 'embedding/embeddings:0', 'bidirectional/forward_bilstm/lstm_cell_1/kernel:0', 'bidirectional/forward_bilstm/lstm_cell_1/recurrent_kernel:0', 'bidirectional/forward_bilstm/lstm_cell_1/bias:0', 'bidirectional/backward_bilstm/lstm_cell_2/kernel:0', 'bidirectional/backward_bilstm/lstm_cell_2/recurrent_kernel:0', 'bidirectional/backward_bilstm/lstm_cell_2/bias:0', 'time_distributed/kernel:0', 'time_distributed/bias:0'].
Below is my code, except for the data preprocessing. These are the imports it relies on:

import os
import time
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.models import load_model

#%% Build the base model_1
def build_bilstm_crf_model(
        lstm_unit,
        fc_unit
) -> tf.keras.Model:
    x = tf.keras.layers.Input(shape=(None,), dtype=tf.float32, name="inn")
    y = tf.keras.layers.Embedding(1, 1, mask_zero=True)(x)
    y = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(lstm_unit, return_sequences=True, name="bilstm")
    )(y)
    y = tf.keras.layers.TimeDistributed(
        tf.keras.layers.Dense(fc_unit, name="fc")
    )(y)
    return tf.keras.Model(inputs=x, outputs=y)
# CRF wrapper model
class CRFModelWrapper(tf.keras.Model):
    def __init__(
            self,
            model: tf.keras.Model,
            units: int,
            chain_initializer="orthogonal",
            use_boundary: bool = True,
            boundary_initializer="zeros",
            use_kernel: bool = True,
            **kwargs
    ):
        super().__init__()
        self.crf_layer = tfa.layers.CRF(
            units=units,
            chain_initializer=chain_initializer,
            use_boundary=use_boundary,
            boundary_initializer=boundary_initializer,
            use_kernel=use_kernel,
            **kwargs
        )
        self.base_model = model
    def unpack_training_data(self, data):
        # override me if this does not suit your task
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            x, y = data
            sample_weight = None
        return x, y, sample_weight
    def call(self, inputs, training=None, mask=None, return_crf_internal=False):
        base_model_outputs = self.base_model(inputs, training, mask)
        # change next line, if your model has more outputs
        crf_input = base_model_outputs
        decode_sequence, potentials, sequence_length, kernel = self.crf_layer(crf_input)
        # potentials = predicted y during training.
        # change next line, if your base model has more outputs.
        # Always keep `(potentials, sequence_length, kernel), decode_sequence`
        # as the first two outputs of the model;
        # the current `self.train_step()` expects that layout.
        outputs = (potentials, sequence_length, kernel), decode_sequence
        if return_crf_internal:
            return outputs
        else:
            # outputs[0] is the crf internal, skip it
            output_without_crf_internal = outputs[1:]
            # nicer to return a tensor instead of a one-tensor list
            if len(output_without_crf_internal) == 1:
                return output_without_crf_internal[0]
            else:
                return output_without_crf_internal
    def compute_crf_loss(self, potentials, sequence_length, kernel, y, sample_weight=None):
        # Added to reshape the labels (y): one-hot labels are converted to
        # sparse integer tags; already-sparse labels are just cast to int32.
        if len(y.shape) > 2:
            y_1 = tf.argmax(y, -1, output_type=tf.int32)
        else:
            y_1 = tf.cast(y, tf.int32)
        crf_likelihood, _ = tfa.text.crf_log_likelihood(
            potentials, y_1, sequence_length, kernel
        )
        # convert likelihood to loss
        flat_crf_loss = -1 * crf_likelihood
        if sample_weight is not None:
            flat_crf_loss = flat_crf_loss * sample_weight
        crf_loss = tf.reduce_mean(flat_crf_loss)
        return crf_loss
    def train_step(self, data):
        x, y, sample_weight = self.unpack_training_data(data)
        with tf.GradientTape() as tape:
            (potentials, sequence_length, kernel), decoded_sequence, *_ = self(
                x, training=True, return_crf_internal=True
            )
            crf_loss = self.compute_crf_loss(
                potentials, sequence_length, kernel, y, sample_weight
            )
            loss = crf_loss + tf.reduce_sum(self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, decoded_sequence)
        # Return a dict mapping metric names to current values
        results = {m.name: m.result() for m in self.metrics}
        results.update({"loss": loss, "crf_loss": crf_loss})  # append loss
        return results
    def test_step(self, data):
        x, y, sample_weight = self.unpack_training_data(data)
        (potentials, sequence_length, kernel), decode_sequence, *_ = self(
            x, training=False, return_crf_internal=True
        )
        crf_loss = self.compute_crf_loss(
            potentials, sequence_length, kernel, y, sample_weight
        )
        loss = crf_loss + tf.reduce_sum(self.losses)
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, decode_sequence)
        # Return a dict mapping metric names to current values
        results = {m.name: m.result() for m in self.metrics}
        results.update({"loss": loss, "crf_loss": crf_loss})  # append loss
        return results
When training the model:
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
with strategy.scope():
    base_model = build_bilstm_crf_model(units_lstm, TAG_SIZE)
    model = CRFModelWrapper(base_model, TAG_SIZE)
    model.compile(optimizer=tf.keras.optimizers.Adam(lr))

num_epochs = 10
name = 'BLD_CRF_Lut{}_lr{}_{}'.format(units_lstm, lrr, int(time.time()))
name_csv = 'BLD_CRF_Lut{}_lr{}'.format(units_lstm, lrr)
mc = tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(name, '{epoch:02d}'), verbose=1,
                                        save_best_only=False, save_weights_only=False)
csv_log = tf.keras.callbacks.CSVLogger('{}.csv'.format(name_csv), append=True)
model.fit(tr_gen, epochs=num_epochs, validation_data=val_gen, verbose=2, batch_size=bt_sz, callbacks=[mc, csv_log])
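In the re-training script below, FILE_PATH4 is assumed to point to one of the SavedModel directories written by the ModelCheckpoint callback above, i.e. something like (the exact epoch is hypothetical):

FILE_PATH4 = os.path.join(name, '05')  # hypothetical: the epoch-05 checkpoint directory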
When re-training the model:
#%%
strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
with strategy.scope():
    # base_model = build_bilstm_crf_model(units_lstm, TAG_SIZE)
    # model = CRFModelWrapper(base_model, TAG_SIZE)
    # model.compile(optimizer=tf.keras.optimizers.Adam(lr))
    model = load_model(FILE_PATH4)

num_epochs = 10
name = 'BLD_CRF_Lut{}_lr{}_{}'.format(units_lstm, lrr, int(time.time()))
name_csv = 'BLD_CRF_Lut{}_lr{}'.format(units_lstm, lrr)
mc = tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(name, '{epoch:02d}'), verbose=1,
                                        save_best_only=False, save_weights_only=False)
csv_log = tf.keras.callbacks.CSVLogger('{}.csv'.format(name_csv), append=True)
model.fit(tr_gen, epochs=num_epochs, validation_data=val_gen, verbose=2, batch_size=bt_sz, callbacks=[mc, csv_log])
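In case it helps to clarify what kind of workaround I have in mind: would it be acceptable to rebuild the wrapper in code and restore only the weights, instead of calling load_model, so that the custom train_step defined above is still the one being used? A minimal sketch of that idea, assuming WEIGHTS_PATH points to a checkpoint written with save_weights_only=True (both the path and that checkpoint are hypothetical, not part of my current scripts):

with strategy.scope():
    base_model = build_bilstm_crf_model(units_lstm, TAG_SIZE)
    model = CRFModelWrapper(base_model, TAG_SIZE)
    model.compile(optimizer=tf.keras.optimizers.Adam(lr))
    # restore weights only, so the Python-side CRFModelWrapper.train_step is kept
    model.load_weights(WEIGHTS_PATH)  # WEIGHTS_PATH is hypothetical

model.fit(tr_gen, epochs=num_epochs, validation_data=val_gen, verbose=2,
          batch_size=bt_sz, callbacks=[mc, csv_log])

Is something like this the recommended way to continue training a CRFModelWrapper model, or should load_model work here?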