Error when using a custom metric in keras 3.0

I am working on an binary application problem where the required metric is partial AUC with true positive rate of 80% threshold. I have writen a custom loss in keras 3.0 for this but I run into an error during training,
Here is my code

class pAUC(keras.metrics.Metric):
    def __init__(self, name="pauc", **kwargs):
        super().__init__(name=name, **kwargs)
        self.y_true = self.add_weight(name='y_true', shape=(0,), initializer='zeros', dtype="float32")
        self.y_pred = self.add_weight(name='y_pred', shape=(0,), initializer='zeros', dtype="float32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = ops.reshape(y_true, [-1])
        y_pred = ops.reshape(y_pred, [-1])

        self.y_true.assign(ops.concatenate([self.y_true, y_true], axis=0))
        self.y_pred.assign(ops.concatenate([self.y_pred, y_pred], axis=0))

    def result(self):
        y_true = self.y_true.numpy()
        y_pred = self.y_pred.numpy()

        fpr, tpr, _ = roc_curve(y_true, y_pred)
        mask = tpr >= 0.8
        filtered_fpr = fpr[mask]
        filtered_tpr = tpr[mask]

        if len(filtered_fpr) > 1:
            return auc(filtered_fpr, filtered_tpr)
        else:
            return 0.0

    def reset_states(self):
        self.y_true.assign(ops.zeros((0,)))
        self.y_pred.assign(ops.zeros((0,)))

And this is the error I am getting

ValueError: The shape of the target variable and the shape of the target value in `variable.assign(value)` must match. variable.shape=(32,), Received: value.shape=(64,). Target variable: <KerasVariable shape=(32,), dtype=float32, path=pauc/y_true>

I understand the error is because I am not able allowed to use the assign method on arrays with different sizes. However I am unable to figure out how to solve it. Any help would be appreciated.

I am training on a minibatch of 32 in case that helps.

Never mind. I rewrote the metrics in pure python and wrapped it in a tf.py_function and it works. For anyone who might want to do same here is the code.

def partial_auc_metric(y_true, y_pred, min_tpr=0.8):
    y_true = ops.reshape(y_true, [-1])
    y_pred = ops.reshape(y_pred, [-1])
    
    v_gt = ops.abs(y_true - 1)
    v_pred = 1 - y_pred
    max_fpr = ops.abs(1 - min_tpr)
    partial_auc_scaled = roc_auc_score(v_gt.numpy(), v_pred.numpy(), max_fpr=max_fpr.numpy())
    partial_auc = 0.5 * max_fpr**2 + (max_fpr - 0.5 * max_fpr**2) / (1.0 - 0.5) * (partial_auc_scaled - 0.5)
    return partial_auc

@keras.saving.register_keras_serializable(package="custom_auc", name="partial_auc")
def partial_auc(y_true, y_pred):
    return tf.py_function(func=partial_auc_metric, inp=[y_true, y_pred], Tout=tf.float32)

The decorator is necessary if you will end up saving the entrie model architecture. This prevents any errors that may arise when loading the model.