I am trying to implement Unet with TensorFlow subclassing API and something does not seem to work properly, and I get the following error:
OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.
Furthermore, I am uncertain if I have correctly implemented the logic inside the call()
function. Any help to correct my mistakes would be much appreciated.
Here I am attaching the full copy of the implementation and the error tracks:
Code Implementation:
from functools import partial
keras.backend.clear_session()
tf.random.set_seed(42)
np.random.seed(42)
conv2d = partial(keras.layers.Conv2D, kernel_size = 3,
padding = 'SAME',
kernel_initializer = 'he_normal',
use_bias = False)
conv2dtranspose = partial(keras.layers.Conv2DTranspose,
kernel_size = 2, strides = 2,
padding = 'SAME')
class encoder(keras.layers.Layer):
def __init__(self, filters, **kwargs):
super(encoder, self).__init__(**kwargs)
self.convs = [
conv2d(filters),
keras.layers.BatchNormalization(),
keras.layers.Activation('relu'),
conv2d(filters),
keras.layers.BatchNormalization(),
keras.layers.Activation('relu')
]
def call(self, inputs):
Z = inputs
for layer in self.convs:
Z = layer(Z)
return Z
class UNet(keras.models.Model):
def __init__(self, filters, inputs_shape = [128, 128, 1], **kwargs):
super(UNet, self).__init__(**kwargs)
self.filters = filters
self.inputs = keras.layers.Input(shape = inputs_shape)
self.maxpool2d = keras.layers.MaxPool2D(pool_size = (2, 2), strides = 2)
self.conv2dtranspose = conv2dtranspose
self.concat = keras.layers.Concatenate()
def call(self, inputs):
skips = {}
Z, inpt = inputs
#implementing encoder path
for fId in range(len(self.filters)):
Z = encoder(filters = self.filters[fId])(Z)
if fId < len(self.filters) - 1:
skips[fId] = Z
Z = self.maxpool2d(Z)
#implementing decoder path
for fId in reversed(range(len(self.filters) - 1)):
Z = self.conv2dtranspose(self.filters[fId])(Z)
Z = self.concat([Z, skips[fId]])
Z = encoder(self.filters[::-1][fId])(Z)
output = keras.layers.Conv2D(1, kernel_size = 1, activation = 'sigmoid')(Z)
return keras.Model(inputs = [inpt], outputs = [output])
filters = [64, 128, 256, 512]
inpt = keras.layers.Input(shape = [128, 128, 1])
model = UNet(filters = filters)(inpt)
#Generating some test data
x = tf.random.normal(shape = (10, 128, 128, 1))
y = tf.random.normal(shape = (10, 128, 128, 1))
model.compile(loss = 'binary_crossentropy', optimizer = keras.optimizers.SGD(), metrics = ['accuracy'])
model.fit(x, y, epochs = 3)
Error Tracks:
WARNING:tensorflow:AutoGraph could not transform <bound method UNet.call of <__main__.UNet object at 0x2930b3d30>> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <bound method UNet.call of <__main__.UNet object at 0x2930b3d30>> and will run it as-is.
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
446 program_ctx = converter.ProgramContext(options=options)
--> 447 converted_f = _convert_actual(target_entity, program_ctx)
448 if logging.has_verbosity(2):
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in _convert_actual(entity, program_ctx)
283
--> 284 transformed, module, source_map = _TRANSPILER.transform(entity, program_ctx)
285
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transpiler.py in transform(self, obj, user_context)
285 if inspect.isfunction(obj) or inspect.ismethod(obj):
--> 286 return self.transform_function(obj, user_context)
287
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transpiler.py in transform_function(self, fn, user_context)
469 # TODO(mdan): Confusing overloading pattern. Fix.
--> 470 nodes, ctx = super(PyToPy, self).transform_function(fn, user_context)
471
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/transpiler.py in transform_function(self, fn, user_context)
362 node = self._erase_arg_defaults(node)
--> 363 result = self.transform_ast(node, context)
364
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in transform_ast(self, node, ctx)
251 unsupported_features_checker.verify(node)
--> 252 node = self.initial_analysis(node, ctx)
253
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in initial_analysis(self, node, ctx)
238 graphs = cfg.build(node)
--> 239 node = qual_names.resolve(node)
240 node = activity.resolve(node, ctx, None)
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/qual_names.py in resolve(node)
251 def resolve(node):
--> 252 return QnResolver().visit(node)
253
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
372
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
446 if isinstance(value, AST):
--> 447 value = self.visit(value)
448 if value is None:
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
372
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
446 if isinstance(value, AST):
--> 447 value = self.visit(value)
448 if value is None:
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
372
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
455 elif isinstance(old_value, AST):
--> 456 new_node = self.visit(old_value)
457 if new_node is None:
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
372
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
455 elif isinstance(old_value, AST):
--> 456 new_node = self.visit(old_value)
457 if new_node is None:
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
372
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
446 if isinstance(value, AST):
--> 447 value = self.visit(value)
448 if value is None:
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
372
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in generic_visit(self, node)
455 elif isinstance(old_value, AST):
--> 456 new_node = self.visit(old_value)
457 if new_node is None:
~/miniforge3/envs/mlm1-engine/lib/python3.8/ast.py in visit(self, node)
370 visitor = getattr(self, method, self.generic_visit)
--> 371 return visitor(node)
372
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/pyct/qual_names.py in visit_Subscript(self, node)
231 s = node.slice
--> 232 if not isinstance(s, gast.Index):
233 # TODO(mdan): Support range and multi-dimensional indices.
AttributeError: module 'gast' has no attribute 'Index'
During handling of the above exception, another exception occurred:
OperatorNotAllowedInGraphError Traceback (most recent call last)
<ipython-input-449-e6f92329b0db> in <module>
2
3 inpt = keras.layers.Input(shape = [128, 128, 1])
----> 4 model = UNet(filters = filters)(inpt)
5
6 #Generating some test data
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
944 # >> model = tf.keras.Model(inputs, outputs)
945 if _in_functional_construction_mode(self, inputs, args, kwargs, input_list):
--> 946 return self._functional_construction_call(inputs, args, kwargs,
947 input_list)
948
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _functional_construction_call(self, inputs, args, kwargs, input_list)
1083 layer=self, inputs=inputs, build_graph=True, training=training_value):
1084 # Check input assumptions set after layer building, e.g. input shape.
-> 1085 outputs = self._keras_tensor_symbolic_call(
1086 inputs, input_masks, args, kwargs)
1087
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _keras_tensor_symbolic_call(self, inputs, input_masks, args, kwargs)
815 return nest.map_structure(keras_tensor.KerasTensor, output_signature)
816 else:
--> 817 return self._infer_output_signature(inputs, args, kwargs, input_masks)
818
819 def _infer_output_signature(self, inputs, args, kwargs, input_masks):
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _infer_output_signature(self, inputs, args, kwargs, input_masks)
856 # TODO(kaftan): do we maybe_build here, or have we already done it?
857 self._maybe_build(inputs)
--> 858 outputs = call_fn(inputs, *args, **kwargs)
859
860 self._handle_activity_regularization(inputs, outputs)
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in wrapper(*args, **kwargs)
665 try:
666 with conversion_ctx:
--> 667 return converted_call(f, args, kwargs, options=options)
668 except Exception as e: # pylint:disable=broad-except
669 if hasattr(e, 'ag_error_metadata'):
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in converted_call(f, args, kwargs, caller_fn_scope, options)
452 if is_autograph_strict_conversion_mode():
453 raise
--> 454 return _fall_back_unconverted(f, args, kwargs, options, e)
455
456 with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in _fall_back_unconverted(f, args, kwargs, options, exc)
499 logging.warn(warning_template, f, file_bug_message, exc)
500
--> 501 return _call_unconverted(f, args, kwargs, options)
502
503
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py in _call_unconverted(f, args, kwargs, options, update_cache)
476
477 if kwargs is not None:
--> 478 return f(*args, **kwargs)
479 return f(*args)
480
<ipython-input-448-ce9f55fd84b1> in call(self, inputs)
49 skips = {}
50
---> 51 Z, inpt = inputs
52
53 #implementing encoder path
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in __iter__(self)
503 def __iter__(self):
504 if not context.executing_eagerly():
--> 505 self._disallow_iteration()
506
507 shape = self._shape_tuple()
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _disallow_iteration(self)
499 else:
500 # Default: V1-style Graph execution.
--> 501 self._disallow_in_graph_mode("iterating over `tf.Tensor`")
502
503 def __iter__(self):
~/miniforge3/envs/mlm1-engine/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _disallow_in_graph_mode(self, task)
477
478 def _disallow_in_graph_mode(self, task):
--> 479 raise errors.OperatorNotAllowedInGraphError(
480 "{} is not allowed in Graph execution. Use Eager execution or decorate"
481 " this function with @tf.function.".format(task))
OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.