Writing a checkpoint:
class Mymodule(tf.Module):
    """Linear model computing ``m * x + b`` from two ``tf.Variable`` attributes.

    Because this subclasses ``tf.Module``, the variables assigned to ``self.m``
    and ``self.b`` are tracked, and their attribute names appear in the saved
    checkpoint keys (e.g. ``model/m/.ATTRIBUTES/VARIABLE_VALUE`` when the
    module is registered on a ``tf.train.Checkpoint`` as ``model``).
    """

    def __init__(self, *, initial_m=5.0, initial_b=5.0):
        """Create the ``m`` and ``b`` variables.

        Args:
            initial_m: Initial value for ``self.m``. Defaults to 5.0, matching
                the previously hard-coded value.
            initial_b: Initial value for ``self.b``. Defaults to 5.0, matching
                the previously hard-coded value.
        """
        super().__init__()
        # Attribute names are significant: they become part of checkpoint keys,
        # so renaming them would break restoring from existing checkpoints.
        self.m = tf.Variable(initial_m)
        self.b = tf.Variable(initial_b)

    def __call__(self, x):
        """Return ``self.m * x + self.b``."""
        return self.m * x + self.b
# File prefix for the checkpoint, relative to the current working directory.
CHECKPOINT_PREFIX = "mycheckpoint"

m = Mymodule()

# Register the module under the key "model"; every saved variable key is
# rooted there (see the "model/..." entries printed below).
cp = tf.train.Checkpoint(model=m)
print("cp:", cp)

# Persist the tracked variables under CHECKPOINT_PREFIX.
cp.write(CHECKPOINT_PREFIX)

# Read back the (name, shape) listing of what was written: the object graph
# entry plus the model's variables.
saved_variables = tf.train.list_variables(CHECKPOINT_PREFIX)
print("list_variables:", saved_variables)
Output:
cp: <tensorflow.python.training.tracking.util.Checkpoint object at 0x7f95fc917940>
list_variables: [('_CHECKPOINTABLE_OBJECT_GRAPH', []), ('model/b/.ATTRIBUTES/VARIABLE_VALUE', []), ('model/m/.ATTRIBUTES/VARIABLE_VALUE', [])]