Hi
We are using tf.lookup.experimental.DenseHashTable inside a tf.map_fn to perform specialised sorting of items within rows of a RaggedTensor. When loading the SavedModel with TensorFlow Serving, we get the following MLIR failure message:
error: 'tfg.While' op body function argument #6 type 'tensor<!tf_type.resource<tensor<!tf_type.string>>>' is not compatible with corresponding operand type: 'tensor<!tf_type.resource<tensor<!tf_type.string>, tensor<i32>>>'
2024-05-15 06:42:15.209491: E external/org_tensorflow/tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] tfg_optimizer{tfg-consolidate-attrs,tfg-toposort,tfg-shape-inference{graph-version=0},tfg-prepare-attrs-export} failed: INVALID_ARGUMENT: MLIR Graph Optimizer failed:
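The tfg.While is generated from the map_fn. Argument #6 (zero-indexed, i.e. the MutableDenseHashTable:table_handle:0 input, whose entry in T is DT_RESOURCE) is the DenseHashTable, which has tf.string keys and tf.int32 values. The operand type carries both of these subtypes (tensor<!tf_type.string>, tensor<i32>), while the body function argument carries only the string subtype.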
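To check which While input carries the table, here is a minimal sketch on toy data (`vocab`, `table` and `lookup_rows` are illustrative names, not our production code). It traces a map_fn that looks up a DenseHashTable and prints every input of the resulting While op with its dtype. It assumes TF 2.x control flow, where map_fn inside a tf.function produces a While (or StatelessWhile) op and any tensor captured by the loop body, including the table's resource handle, is added as an extra loop input.

```python
import tensorflow as tf

# Toy keys for illustration only.
vocab = tf.constant(["a", "b", "c"])
table = tf.lookup.experimental.DenseHashTable(
    key_dtype=tf.string,
    value_dtype=tf.int32,
    default_value=-1,
    empty_key="$",
    deleted_key="£",
)
table.insert(vocab, tf.range(tf.size(vocab), dtype=tf.int32))


@tf.function
def lookup_rows(rows):
    # map_fn builds a while loop; the table handle used inside the body is
    # captured and passed into the loop as an extra input.
    return tf.map_fn(
        table.lookup,
        rows,
        fn_output_signature=tf.RaggedTensorSpec(shape=[None], dtype=tf.int32),
    )


rows_spec = tf.RaggedTensorSpec(shape=[None, None], dtype=tf.string)
graph = lookup_rows.get_concrete_function(rows_spec).graph

while_ops = [
    op for op in graph.get_operations() if op.type in ("While", "StatelessWhile")
]
for op in while_ops:
    # Presumably input index i lines up with "argument #i" in the MLIR error.
    for i, t in enumerate(op.inputs):
        print(op.name, i, t.dtype.name, t.name)
```

In our understanding the resource-typed input in this listing plays the same role as MutableDenseHashTable:table_handle:0 in our real graph.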
The code:
import tensorflow as tf


@tf.function
def sort_rail_horizontally_hashmap_map_fn(sorted_list: tf.Tensor, rails_to_sort: tf.RaggedTensor):
    @tf.function
    def reorder_rail(rail):
        # Position of each rail item in sorted_list (-1 if absent).
        # lookup_table is captured from the enclosing scope; it is defined
        # below, before map_fn first calls this function.
        lookup_indexes = lookup_table.lookup(rail)
        match_mask = ~tf.equal(lookup_indexes, -1)
        match_indexes = tf.where(condition=match_mask)
        extracted_match_indexes = tf.gather(
            lookup_indexes, match_indexes[:, 0], name="gather_extracted_match_indexes"
        )
        extracted_sorted_list = tf.gather(
            sorted_list, extracted_match_indexes, name="gather_extracted_sorted_list"
        )
        # Reorder the matched items by their position in sorted_list.
        sorted_match_indexes = tf.argsort(extracted_match_indexes)
        reordered_extracted_sorted_list = tf.gather(
            extracted_sorted_list,
            sorted_match_indexes,
            name="gather_reordered_extracted_sorted_list",
        )
        # Write the reordered items back into the matched slots; unmatched
        # items keep their original positions.
        composite_rail = tf.tensor_scatter_nd_update(
            tensor=rail,
            indices=match_indexes,
            updates=reordered_extracted_sorted_list,
            name="composite_rail_scatter",
        )
        return composite_rail

    # Map each item of sorted_list to its index in sorted_list.
    lookup_table = tf.lookup.experimental.DenseHashTable(
        key_dtype=tf.string,
        value_dtype=tf.int32,
        default_value=-1,
        empty_key="$",
        deleted_key="£",
    )
    lookup_table.insert(
        sorted_list,
        tf.range(0, tf.size(sorted_list), dtype=tf.int32),
        name="lookup_table_insert"
    )
    # This map_fn becomes the tfg.While op; the table handle is passed into
    # the loop as one of its inputs.
    ragged_rails = tf.map_fn(
        reorder_rail,
        rails_to_sort,
        parallel_iterations=50,
        fn_output_signature=tf.RaggedTensorSpec(shape=[None], dtype=tf.string),
        name="rails_to_sort_map_fn",
    )
    return ragged_rails
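A possible workaround we have not yet tried under Serving: drop map_fn and do the per-row reordering on the RaggedTensor's flat values, so the graph contains no While loop carrying the table handle. This is only a sketch under our own assumptions; `build_table` and `reorder_rails_vectorized` are hypothetical names. It relies on the matched slots in flat_values already being grouped by row, so sorting the matches by the composite key (row, position in sorted_list) and writing them back reorders items within each row only.

```python
import tensorflow as tf


def build_table(sorted_list):
    # Same table configuration as in the code above.
    table = tf.lookup.experimental.DenseHashTable(
        key_dtype=tf.string,
        value_dtype=tf.int32,
        default_value=-1,
        empty_key="$",
        deleted_key="£",
    )
    table.insert(sorted_list, tf.range(tf.size(sorted_list), dtype=tf.int32))
    return table


def reorder_rails_vectorized(table, sorted_list, rails):
    """Hypothetical map_fn-free variant of the per-row reordering."""
    flat = rails.flat_values                     # all items, row-major
    row_ids = rails.value_rowids()               # int64 row index of each item
    idx = tf.cast(table.lookup(flat), tf.int64)  # -1 if not in sorted_list
    match_coords = tf.where(idx >= 0)            # [M, 1], ascending flat positions
    match_pos = match_coords[:, 0]
    match_rows = tf.gather(row_ids, match_pos)
    match_idx = tf.gather(idx, match_pos)
    # Sort matches by row first, then by position in sorted_list.
    n = tf.cast(tf.size(sorted_list), tf.int64)
    order = tf.argsort(match_rows * n + match_idx)
    new_values = tf.gather(sorted_list, tf.gather(match_idx, order))
    # Matched slots are grouped by row, so writing the sorted values back
    # into them only moves items within their own row.
    new_flat = tf.tensor_scatter_nd_update(flat, match_coords, new_values)
    return rails.with_flat_values(new_flat)


sorted_list = tf.constant(["c", "a", "b"])
rails = tf.ragged.constant(
    [["a", "x", "c"], ["b", "c", "y", "a"], []], dtype=tf.string
)
result = reorder_rails_vectorized(build_table(sorted_list), sorted_list, rails)
print(result.to_list())
```

On this toy input each row should come out in the same order the map_fn version would produce: matched items follow the order of sorted_list and unmatched items ("x", "y") stay in place.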
The While node looks like this:
node_def {
  name: "rails_to_sort_map_fn/while"
  op: "While"
  input: "rails_to_sort_map_fn/while/loop_counter:output:0"
  input: "rails_to_sort_map_fn/strided_slice:output:0"
  input: "rails_to_sort_map_fn/Const:output:0"
  input: "rails_to_sort_map_fn/TensorArrayV2_1:handle:0"
  input: "rails_to_sort_map_fn/strided_slice:output:0"
  input: "rails_to_sort_map_fn/TensorArrayUnstack/TensorListFromTensor:output_handle:0"
  input: "MutableDenseHashTable:table_handle:0"
  input: "default_value:output:0"
  input: "sorted_list"
  input: "^lookup_table_insert/LookupTableInsertV2"
  attr {
    key: "T"
    value {
      list {
        type: DT_INT32
        type: DT_INT32
        type: DT_INT32
        type: DT_VARIANT
        type: DT_INT32
        type: DT_VARIANT
        type: DT_RESOURCE
        type: DT_INT32
        type: DT_STRING
      }
    }
  }
  attr {
    key: "_lower_using_switch_merge"
    value {
      b: true
    }
  }
  attr {
    key: "_num_original_outputs"
    value {
      i: 9
    }
  }
  attr {
    key: "_read_only_resource_inputs"
    value {
      list {
      }
    }
  }
  attr {
    key: "body"
    value {
      func {
        name: "rails_to_sort_map_fn_while_body_18098"
      }
    }
  }
  attr {
    key: "cond"
    value {
      func {
        name: "rails_to_sort_map_fn_while_cond_18097"
      }
    }
  }
  attr {
    key: "output_shapes"
    value {
      list {
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
          dim {
            size: 828
          }
        }
      }
    }
  }
  attr {
    key: "parallel_iterations"
    value {
      i: 50
    }
  }
}
The graph executes as expected, so we suspect that either MLIR does not yet support DenseHashTable resource handles passed into a While loop, or this is an MLIR bug.
Our primary concern is the effect of this failure: does it stop all graph optimization on the target server, or only the optimization of this node?
Thanks
Adrian