import os
import datetime
import tensorflow as tf
print(tf.__version__)
class ExampleModel(tf.Module):
def __init__(self):
stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = 'logs/test11/%s' % stamp
self.summary_writer = tf.summary.create_file_writer(logdir)
@tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])
def capture_fn(self, x):
if not hasattr(self, 'weight'):
self.weight = tf.Variable(5.0, name='weight')
self.weight.assign_add(x * self.weight)
# no error if these two lines are commented
with self.summary_writer.as_default():
tf.summary.scalar('loss', tf.constant(0, dtype=tf.int64), step=tf.constant(0, dtype=tf.int64))
return self.weight
@tf.function
def polymorphic_fn(self, x):
return tf.constant(3.0) * x
model = ExampleModel()
model.polymorphic_fn(tf.constant([1.0, 2.0, 3.0]))
tf.saved_model.save(
model, "/tmp/example-model", signatures={'capture_fn': model.capture_fn})
error message
AssertionError Traceback (most recent call last)
in ()
31 model.polymorphic_fn(tf.constant([1.0, 2.0, 3.0]))
32 tf.saved_model.save(
—> 33 model, “/tmp/example-model”, signatures={‘capture_fn’: model.capture_fn})
7 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/save.py in _map_captures_to_created_tensors(original_captures, resource_map)
474 “(from gc.get_referrers, limited to two hops):\n{}”
475 ).format(interior,
→ 476 “\n”.join([repr(obj) for obj in trackable_referrers])))
477 export_captures.append(mapped_resource)
478 return export_captures
AssertionError: Tried to export a function which references untracked resource Tensor(“185:0”, shape=(), dtype=resource). TensorFlow objects (e.g. tf.Variable) captured by functions must be tracked by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly.
Trackable Python objects referring to this tensor (from gc.get_referrers, limited to two hops):