Hi
We are using tf.lookup.experimental.DenseHashTable within a tf.map_fn to perform specialised sorting of items within rows of a RaggedTensor. When loading the SavedModel with TensorFlow Serving, we get the following MLIR failure message:
error: 'tfg.While' op body function argument #6 type 'tensor<!tf_type.resource<tensor<!tf_type.string>>>' is not compatible with corresponding operand type: 'tensor<!tf_type.resource<tensor<!tf_type.string>, tensor<i32>>>'
2024-05-15 06:42:15.209491: E external/org_tensorflow/tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] tfg_optimizer{tfg-consolidate-attrs,tfg-toposort,tfg-shape-inference{graph-version=0},tfg-prepare-attrs-export} failed: INVALID_ARGUMENT: MLIR Graph Optimizer failed:
The tfg.While is generated from the map_fn. Argument #6 is the DenseHashTable, which has tf.string keys and tf.int32 values.
The code:
import tensorflow as tf


@tf.function
def sort_rail_horizontally_hashmap_map_fn(sorted_list: tf.Tensor, rails_to_sort: tf.RaggedTensor):
    @tf.function
    def reorder_rail(rail):
        # Position of each rail item in sorted_list (-1 if the item is absent).
        lookup_indexes = lookup_table.lookup(rail)
        match_mask = ~tf.equal(lookup_indexes, -1)
        # Positions in the rail whose item was found in sorted_list.
        match_indexes = tf.where(condition=match_mask)
        extracted_match_indexes = tf.gather(
            lookup_indexes, match_indexes[:, 0], name="gather_extracted_match_indexes"
        )
        extracted_sorted_list = tf.gather(
            sorted_list, extracted_match_indexes, name="gather_extracted_sorted_list"
        )
        # Order the matched items by their position in sorted_list.
        sorted_match_indexes = tf.argsort(extracted_match_indexes)
        reordered_extracted_sorted_list = tf.gather(
            extracted_sorted_list,
            sorted_match_indexes,
            name="gather_reordered_extracted_sorted_list",
        )
        # Write the reordered items back into the matched positions.
        composite_rail = tf.tensor_scatter_nd_update(
            tensor=rail,
            indices=match_indexes,
            updates=reordered_extracted_sorted_list,
            name="composite_rail_scatter",
        )
        return composite_rail

    # Table mapping each item of sorted_list to its index in sorted_list.
    lookup_table = tf.lookup.experimental.DenseHashTable(
        key_dtype=tf.string,
        value_dtype=tf.int32,
        default_value=-1,
        empty_key="$",
        deleted_key="£",
    )
    lookup_table.insert(
        sorted_list,
        tf.range(0, tf.size(sorted_list), dtype=tf.int32),
        name="lookup_table_insert"
    )
    ragged_rails = tf.map_fn(
        reorder_rail,
        rails_to_sort,
        parallel_iterations=50,
        fn_output_signature=tf.RaggedTensorSpec(shape=[None], dtype=tf.string),
        name="rails_to_sort_map_fn",
    )
    return ragged_rails
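The function below is a plain-Python sketch of the per-row reordering that reorder_rail performs. It assumes sorted_list contains no duplicate items, and reorder_rail_reference is a hypothetical name used only for illustration. Items found in sorted_list are rearranged into sorted_list order within the slots they already occupy. All other items stay where they are.

```python
from typing import Dict, List, Sequence


def reorder_rail_reference(rail: Sequence[str], rank: Dict[str, int]) -> List[str]:
    """Hypothetical pure-Python equivalent of reorder_rail, for illustration only."""
    # Positions in the rail whose item is present in sorted_list (lookup != -1).
    slots = [i for i, item in enumerate(rail) if rank.get(item, -1) != -1]
    # The matched items, ordered by their position in sorted_list.
    ordered = sorted((rail[i] for i in slots), key=lambda item: rank[item])
    # Write the ordered items back into the matched slots; other items are unchanged.
    result = list(rail)
    for slot, item in zip(slots, ordered):
        result[slot] = item
    return result


sorted_list = ["c", "a", "b"]
# Stand-in for the DenseHashTable: item -> index in sorted_list.
rank = {item: i for i, item in enumerate(sorted_list)}
print(reorder_rail_reference(["a", "x", "b", "c"], rank))  # ['c', 'x', 'a', 'b']
```

Here "x" is not in sorted_list, so it keeps its slot, and "a", "b" and "c" are rearranged into the order c, a, b across the remaining slots.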
The While node looks like this:
node_def {
  name: "rails_to_sort_map_fn/while"
  op: "While"
  input: "rails_to_sort_map_fn/while/loop_counter:output:0"
  input: "rails_to_sort_map_fn/strided_slice:output:0"
  input: "rails_to_sort_map_fn/Const:output:0"
  input: "rails_to_sort_map_fn/TensorArrayV2_1:handle:0"
  input: "rails_to_sort_map_fn/strided_slice:output:0"
  input: "rails_to_sort_map_fn/TensorArrayUnstack/TensorListFromTensor:output_handle:0"
  input: "MutableDenseHashTable:table_handle:0"
  input: "default_value:output:0"
  input: "sorted_list"
  input: "^lookup_table_insert/LookupTableInsertV2"
  attr {
    key: "T"
    value {
      list {
        type: DT_INT32
        type: DT_INT32
        type: DT_INT32
        type: DT_VARIANT
        type: DT_INT32
        type: DT_VARIANT
        type: DT_RESOURCE
        type: DT_INT32
        type: DT_STRING
      }
    }
  }
  attr {
    key: "_lower_using_switch_merge"
    value {
      b: true
    }
  }
  attr {
    key: "_num_original_outputs"
    value {
      i: 9
    }
  }
  attr {
    key: "_read_only_resource_inputs"
    value {
      list {
      }
    }
  }
  attr {
    key: "body"
    value {
      func {
        name: "rails_to_sort_map_fn_while_body_18098"
      }
    }
  }
  attr {
    key: "cond"
    value {
      func {
        name: "rails_to_sort_map_fn_while_cond_18097"
      }
    }
  }
  attr {
    key: "output_shapes"
    value {
      list {
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
          dim {
            size: 828
          }
        }
      }
    }
  }
  attr {
    key: "parallel_iterations"
    value {
      i: 50
    }
  }
}
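To check which operand "argument #6" in the error refers to, we can pair the While node's data inputs with its T attribute, counting from zero. This is a small sketch using values copied from the node_def. Control inputs, which are prefixed with ^, are not loop variables and are skipped.

```python
# Inputs and the "T" attribute copied from the While node_def.
inputs = [
    "rails_to_sort_map_fn/while/loop_counter:output:0",
    "rails_to_sort_map_fn/strided_slice:output:0",
    "rails_to_sort_map_fn/Const:output:0",
    "rails_to_sort_map_fn/TensorArrayV2_1:handle:0",
    "rails_to_sort_map_fn/strided_slice:output:0",
    "rails_to_sort_map_fn/TensorArrayUnstack/TensorListFromTensor:output_handle:0",
    "MutableDenseHashTable:table_handle:0",
    "default_value:output:0",
    "sorted_list",
    "^lookup_table_insert/LookupTableInsertV2",
]
types = [
    "DT_INT32", "DT_INT32", "DT_INT32", "DT_VARIANT", "DT_INT32",
    "DT_VARIANT", "DT_RESOURCE", "DT_INT32", "DT_STRING",
]

# Control dependencies ("^...") carry no data, so drop them before pairing.
data_inputs = [name for name in inputs if not name.startswith("^")]
assert len(data_inputs) == len(types)

for index, (name, dtype) in enumerate(zip(data_inputs, types)):
    print(f"#{index}: {dtype:<11} {name}")
```

With these values, entry #6 is the DT_RESOURCE input MutableDenseHashTable:table_handle:0, i.e. the DenseHashTable handle.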
The graph executes as expected, so we wonder whether MLIR does not yet support references to a DenseHashTable, or whether this is an MLIR bug.
Our primary concern is the effect of this failure: does it stop all graph optimization on the target server, or only the optimization of that node of the graph?
Thanks
Adrian