Hi

We are using `tf.lookup.experimental.DenseHashTable` within a `tf.map_fn` to perform specialised sorting of items within rows of a `RaggedTensor`. When loading the `SavedModel` with TensorFlow Serving, we get the following MLIR failure message:

```
error: 'tfg.While' op body function argument #6 type 'tensor<!tf_type.resource<tensor<!tf_type.string>>>' is not compatible with corresponding operand type: 'tensor<!tf_type.resource<tensor<!tf_type.string>, tensor<i32>>>'
2024-05-15 06:42:15.209491: E external/org_tensorflow/tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] tfg_optimizer{tfg-consolidate-attrs,tfg-toposort,tfg-shape-inference{graph-version=0},tfg-prepare-attrs-export} failed: INVALID_ARGUMENT: MLIR Graph Optimizer failed:
```

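The `tfg.While` op is generated from the `map_fn`. Argument #6 is the `DenseHashTable`, which has `tf.string` keys and `tf.int32` values. The operand type carries both the key and value dtypes as resource subtypes, while the body function argument carries only the `tf.string` key dtype.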

The code:

```python
import tensorflow as tf


@tf.function
def sort_rail_horizontally_hashmap_map_fn(sorted_list: tf.Tensor, rails_to_sort: tf.RaggedTensor):
    @tf.function
    def reorder_rail(rail):
        # Index of each rail item in sorted_list, or -1 if it is absent.
        lookup_indexes = lookup_table.lookup(rail)
        match_mask = ~tf.equal(lookup_indexes, -1)
        # Positions within the rail that hold an item from sorted_list.
        match_indexes = tf.where(condition=match_mask)
        extracted_match_indexes = tf.gather(
            lookup_indexes, match_indexes[:, 0], name="gather_extracted_match_indexes"
        )
        extracted_sorted_list = tf.gather(
            sorted_list, extracted_match_indexes, name="gather_extracted_sorted_list"
        )
        # Reorder the matched items by their index in sorted_list.
        sorted_match_indexes = tf.argsort(extracted_match_indexes)
        reordered_extracted_sorted_list = tf.gather(
            extracted_sorted_list,
            sorted_match_indexes,
            name="gather_reordered_extracted_sorted_list",
        )
        # Write the reordered items back into the matched slots; unmatched
        # items keep their original positions.
        composite_rail = tf.tensor_scatter_nd_update(
            tensor=rail,
            indices=match_indexes,
            updates=reordered_extracted_sorted_list,
            name="composite_rail_scatter",
        )
        return composite_rail

    # Maps each entry of sorted_list to its index; missing keys return -1.
    lookup_table = tf.lookup.experimental.DenseHashTable(
        key_dtype=tf.string,
        value_dtype=tf.int32,
        default_value=-1,
        empty_key="$",
        deleted_key="£",
    )
    lookup_table.insert(
        sorted_list,
        tf.range(0, tf.size(sorted_list), dtype=tf.int32),
        name="lookup_table_insert",
    )
    # map_fn is lowered to a While op; the lookup_table resource handle
    # captured by reorder_rail becomes one of its inputs.
    ragged_rails = tf.map_fn(
        reorder_rail,
        rails_to_sort,
        parallel_iterations=50,
        fn_output_signature=tf.RaggedTensorSpec(shape=[None], dtype=tf.string),
        name="rails_to_sort_map_fn",
    )
    return ragged_rails
```
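For clarity, this is what `reorder_rail` is meant to compute for a single rail, written as a plain-Python sketch. The function name `reorder_rail_reference` is hypothetical, and it assumes the entries of `sorted_list` are unique. Items that appear in `sorted_list` are rearranged into `sorted_list` order within the slots they already occupy, and all other items stay where they are.

```python
def reorder_rail_reference(rail, sorted_list):
    """Plain-Python sketch of reorder_rail for one rail (hypothetical helper).

    Assumes the entries of sorted_list are unique, mirroring the hash table
    that maps each entry to its index.
    """
    # Equivalent of lookup_table: item -> index in sorted_list.
    position = {item: index for index, item in enumerate(sorted_list)}
    # Equivalent of match_indexes: slots in the rail holding a known item.
    slots = [i for i, item in enumerate(rail) if item in position]
    # Matched items, reordered by their index in sorted_list.
    reordered = sorted((rail[i] for i in slots), key=position.__getitem__)
    # Equivalent of tensor_scatter_nd_update: write them back into the same slots.
    result = list(rail)
    for slot, item in zip(slots, reordered):
        result[slot] = item
    return result


print(reorder_rail_reference(["c", "x", "a", "b"], ["a", "b", "c"]))
# ['a', 'x', 'b', 'c']
```

Here `"x"` is not in `sorted_list`, so it keeps slot 1, while `"c"`, `"a"` and `"b"` are rewritten into slots 0, 2 and 3 in `sorted_list` order.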

The `While` node looks like this:

```
node_def {
  name: "rails_to_sort_map_fn/while"
  op: "While"
  input: "rails_to_sort_map_fn/while/loop_counter:output:0"
  input: "rails_to_sort_map_fn/strided_slice:output:0"
  input: "rails_to_sort_map_fn/Const:output:0"
  input: "rails_to_sort_map_fn/TensorArrayV2_1:handle:0"
  input: "rails_to_sort_map_fn/strided_slice:output:0"
  input: "rails_to_sort_map_fn/TensorArrayUnstack/TensorListFromTensor:output_handle:0"
  input: "MutableDenseHashTable:table_handle:0"
  input: "default_value:output:0"
  input: "sorted_list"
  input: "^lookup_table_insert/LookupTableInsertV2"
  attr {
    key: "T"
    value {
      list {
        type: DT_INT32
        type: DT_INT32
        type: DT_INT32
        type: DT_VARIANT
        type: DT_INT32
        type: DT_VARIANT
        type: DT_RESOURCE
        type: DT_INT32
        type: DT_STRING
      }
    }
  }
  attr {
    key: "_lower_using_switch_merge"
    value {
      b: true
    }
  }
  attr {
    key: "_num_original_outputs"
    value {
      i: 9
    }
  }
  attr {
    key: "_read_only_resource_inputs"
    value {
      list {
      }
    }
  }
  attr {
    key: "body"
    value {
      func {
        name: "rails_to_sort_map_fn_while_body_18098"
      }
    }
  }
  attr {
    key: "cond"
    value {
      func {
        name: "rails_to_sort_map_fn_while_cond_18097"
      }
    }
  }
  attr {
    key: "output_shapes"
    value {
      list {
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
        }
        shape {
          dim {
            size: 828
          }
        }
      }
    }
  }
  attr {
    key: "parallel_iterations"
    value {
      i: 50
    }
  }
}
```
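To double-check which operand the error refers to, the node's data inputs can be paired with its `T` attribute by zero-based index. This small script only replays the values copied from the node above and needs no TensorFlow.

```python
# Inputs and "T" attribute of the While node, copied verbatim.
inputs = [
    "rails_to_sort_map_fn/while/loop_counter:output:0",
    "rails_to_sort_map_fn/strided_slice:output:0",
    "rails_to_sort_map_fn/Const:output:0",
    "rails_to_sort_map_fn/TensorArrayV2_1:handle:0",
    "rails_to_sort_map_fn/strided_slice:output:0",
    "rails_to_sort_map_fn/TensorArrayUnstack/TensorListFromTensor:output_handle:0",
    "MutableDenseHashTable:table_handle:0",
    "default_value:output:0",
    "sorted_list",
    "^lookup_table_insert/LookupTableInsertV2",
]
types = [
    "DT_INT32", "DT_INT32", "DT_INT32", "DT_VARIANT", "DT_INT32",
    "DT_VARIANT", "DT_RESOURCE", "DT_INT32", "DT_STRING",
]

# Inputs prefixed with "^" are control dependencies: they carry no value,
# so they have no entry in T and do not map to body function arguments.
data_inputs = [name for name in inputs if not name.startswith("^")]
assert len(data_inputs) == len(types)

for index, (dtype, name) in enumerate(zip(types, data_inputs)):
    print(f"#{index}: {dtype:<11} {name}")
```

Argument #6 is paired with `DT_RESOURCE` and `MutableDenseHashTable:table_handle:0`, i.e. the hash table handle.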

The graph executes as expected, so we suspect that either MLIR does not yet support references to `DenseHashTable` resources, or this is an MLIR bug. Which is it?

Our primary concern is the effect of the failure: does it stop all graph optimization on the target server, or only optimization of that node?

Thanks

Adrian