I’m upgrading my training code from TensorFlow 2.8 to 2.12. Fitting a simple multi-label classifier on 2.12.0 raises an exception, but only when I use the following metric:

```python
tf.keras.metrics.AUC(curve="ROC", multi_label=True)
```

No exception is raised without this metric, and it is raised whether or not I pass a validation dataset, so it is happening while computing metrics on the training data. The error is:
```
File ".tox/train/lib/python3.10/site-packages/keras/engine/training.py", line 1284, in train_function *
    return step_function(self, iterator)
File ".tox/train/lib/python3.10/site-packages/keras/engine/training.py", line 1268, in step_function **
    outputs = model.distribute_strategy.run(run_step, args=(data,))
File ".tox/train/lib/python3.10/site-packages/keras/engine/training.py", line 1249, in run_step **
    outputs = model.train_step(data)
File ".tox/train/lib/python3.10/site-packages/keras/engine/training.py", line 1055, in train_step
    return self.compute_metrics(x, y, y_pred, sample_weight)
File ".tox/train/lib/python3.10/site-packages/keras/engine/training.py", line 1149, in compute_metrics
    self.compiled_metrics.update_state(y, y_pred, sample_weight)
File ".tox/train/lib/python3.10/site-packages/keras/engine/compile_utils.py", line 605, in update_state
    metric_obj.update_state(y_t, y_p, sample_weight=mask)
File ".tox/train/lib/python3.10/site-packages/keras/utils/metrics_utils.py", line 77, in decorated
    update_op = update_state_fn(*args, **kwargs)
File ".tox/train/lib/python3.10/site-packages/keras/metrics/base_metric.py", line 140, in update_state_fn
    return ag_update_state(*args, **kwargs)
File ".tox/train/lib/python3.10/site-packages/keras/metrics/confusion_metrics.py", line 1453, in update_state **
    self._build(tf.TensorShape(y_pred.shape))
File ".tox/train/lib/python3.10/site-packages/keras/metrics/confusion_metrics.py", line 1402, in _build
    raise ValueError(
ValueError: `y_true` must have rank 2 when `multi_label=True`. Found rank None. Full shape received for `y_true`: <unknown>
```
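For what it’s worth, the metric behaves fine when I call it eagerly with explicit rank-2 tensors, so the unknown rank presumably only arises inside the traced `train_function` (the label values and shapes below are just a toy sanity check, not my real data):

```python
import tensorflow as tf

# Eager-mode sanity check: with explicit rank-2 inputs the metric updates
# without complaint, so the "rank None" must come from graph tracing.
m = tf.keras.metrics.AUC(curve="ROC", multi_label=True)
y_true = tf.constant([[0.0, 1.0], [1.0, 0.0]])  # shape (2, 2), rank 2
y_pred = tf.constant([[0.3, 0.8], [0.6, 0.4]])
m.update_state(y_true, y_pred)
print(float(m.result()))  # mean per-label ROC AUC, in [0, 1]
```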
One thing I noticed from the stacktrace: the ValueError message says it’s about `y_true`, but one frame up we see the code is actually passing `y_pred`. Maybe I’m being pedantic to notice the mismatch, but as an experiment I directly edited `confusion_metrics.py` line 1453 to call `self._build` on `y_true`, as `_build`’s error message says it expects. With that hack, my model training finishes!
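Concretely, my local edit was just this one line (obviously a hack, not a real fix):

```python
# keras/metrics/confusion_metrics.py, in AUC.update_state, line 1453
# self._build(tf.TensorShape(y_pred.shape))  # original
self._build(tf.TensorShape(y_true.shape))    # my local hack
```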
So now I have two questions:
- Is it a bug in Keras to call `self._build(tf.TensorShape(y_pred.shape))` when `AUC._build`’s error message indicates it expects to receive `y_true.shape`? I’m shocked that `y_true` and `y_pred` would ever have different shapes, but clearly they sometimes do.
- What API change would cause this error to start appearing only when I upgrade TensorFlow? I browsed the changelogs and didn’t see anything that felt relevant.
I could probably produce a minimal example if it’s helpful.
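In case it helps, here’s a rough sketch of the shape my real code has. This is untested as a faithful reproduction (the layer sizes and label count are placeholders, and my guess is that unknown static shapes from `tf.data` are what actually trigger the rank check; with plain NumPy arrays as below, training works for me):

```python
import numpy as np
import tensorflow as tf

n_features, n_labels = 4, 3  # placeholder sizes, not from my real model

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu", input_shape=(n_features,)),
    tf.keras.layers.Dense(n_labels, activation="sigmoid"),  # multi-label head
])
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=[tf.keras.metrics.AUC(curve="ROC", multi_label=True)],
)

# Random multi-label data; with in-memory arrays the static shapes are
# known, so this trains without hitting the rank check.
x = np.random.rand(32, n_features).astype("float32")
y = (np.random.rand(32, n_labels) > 0.5).astype("float32")
hist = model.fit(x, y, epochs=1, verbose=0)
```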