Help with refactoring nested loops

Greetings!
I’ve written a class to compute a loss function in which I load some external vectors (centers and anchors) and compute some distances. My first attempt used three nested Python loops, which is obviously painfully slow. Since I lack the knowledge/expertise to manipulate tensors, my question is: is there a way to refactor this function into pure TensorFlow calls?
The code in question:

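import tensorflow as tf

class AnchorLoss:
    def __init__(self) -> None:
        # Code to load the vectors I want to use as centers and anchors.
        # mean_anchor_dict maps key -> {'center': vector, 'anchor': vector}
        self.mean_anchor_dict = {}  # placeholder, filled by the loading code

    # O(n**3), not nice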
    def inner_loop(self, embedding):
        # Sum over all (center, anchor) pairs of
        # ||center - embedding|| + ||center - anchor||
        aux = []
        for arrdict in self.mean_anchor_dict.values():
            center = tf.convert_to_tensor(arrdict['center'], dtype=tf.float32)
            anchor = tf.convert_to_tensor(arrdict['anchor'], dtype=tf.float32)

            d1 = tf.norm(center - embedding)
            d2 = tf.norm(center - anchor)
            aux.append(d1 + d2)

        aux = tf.convert_to_tensor(aux)
        return tf.reduce_sum(aux)

    def mid_loop(self, vect):
        # Sum inner_loop over every embedding in one vect
        aux = []
        for embedding in tf.unstack(vect):
            aux.append(self.inner_loop(embedding))

        aux = tf.convert_to_tensor(aux)
        return tf.reduce_sum(aux)

    def loss(self, batch: tf.Tensor):
        # Mean over the batch of the per-vect sums
        batch_loss = []
        for vect in tf.unstack(batch):
            batch_loss.append(self.mid_loop(vect))

        batch_loss = tf.convert_to_tensor(batch_loss)
        return tf.reduce_mean(batch_loss)

(There might be an extra inner loop that doesn’t belong, but the question remains the same.)
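For reference, one way to drop all three loops is broadcasting. This is a minimal sketch under stated assumptions, not a tested drop-in replacement: it assumes `batch` has shape [B, N, D] (B vects of N embeddings each), that every 'center' and 'anchor' in `mean_anchor_dict` is a 1-D vector of the same length D, and the helper names `stack_anchor_dict` and `vectorized_anchor_loss` are hypothetical. Under those assumptions it should give the same value as the looped `loss`.

```python
import tensorflow as tf


def stack_anchor_dict(mean_anchor_dict):
    """Stack the per-key centers and anchors once (e.g. in __init__).

    Assumes every 'center' and 'anchor' is a 1-D vector of the same length D.
    Returns two float32 tensors of shape [K, D].
    """
    values = list(mean_anchor_dict.values())
    centers = tf.stack(
        [tf.convert_to_tensor(v['center'], dtype=tf.float32) for v in values])
    anchors = tf.stack(
        [tf.convert_to_tensor(v['anchor'], dtype=tf.float32) for v in values])
    return centers, anchors


def vectorized_anchor_loss(batch, centers, anchors):
    """Loop-free version of AnchorLoss.loss (hypothetical helper).

    batch:   [B, N, D] float32 (B vects, N embeddings each)
    centers: [K, D] float32, anchors: [K, D] float32
    """
    # d1[b, n, k] = ||centers[k] - batch[b, n]||, broadcast to [B, N, K, D]
    diff = centers[tf.newaxis, tf.newaxis, :, :] - batch[:, :, tf.newaxis, :]
    d1 = tf.norm(diff, axis=-1)                       # [B, N, K]
    # d2[k] = ||centers[k] - anchors[k]||, independent of the embedding
    d2 = tf.norm(centers - anchors, axis=-1)          # [K]
    per_embedding = tf.reduce_sum(d1 + d2, axis=-1)   # [B, N], was inner_loop
    per_vect = tf.reduce_sum(per_embedding, axis=-1)  # [B], was mid_loop
    return tf.reduce_mean(per_vect)                   # scalar, was loss
```

The stacking would happen once in `__init__`, and `loss` would then only call the vectorized function. Note that d2 does not depend on the embedding, so it only adds a constant N * sum_k ||c_k - a_k|| per vect; that may be the extra inner loop mentioned in the question. The [B, N, K, D] intermediate can get large; if memory becomes a problem, the distances can instead be built from the expansion ||e||² - 2 e·c + ||c||², at some cost in numerical precision.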
Thanks!!

Sometimes it can be hard to think in vectorized terms or to find a vectorized approach.
One easy solution is to check what performance you get with:

You can also try enabling jit_compile (XLA compilation).
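A rough sketch of the jit_compile idea applied to the existing loops, without vectorizing them. It assumes the centers and anchors have already been stacked into [K, D] tensors, and it uses a hypothetical standalone function `looped_loss` in place of the class methods; shapes must be static, because `tf.unstack` needs a known size along axis 0.

```python
import tensorflow as tf


def looped_loss(batch, centers, anchors):
    """The question's triple loop over tensors (hypothetical standalone form).

    batch: [B, N, D]; centers, anchors: [K, D], all float32 with static shapes.
    """
    per_vect = []
    for vect in tf.unstack(batch):            # loss: loop over vects
        per_embedding = []
        for embedding in tf.unstack(vect):    # mid_loop: loop over embeddings
            # inner_loop: loop over (center, anchor) pairs
            terms = [tf.norm(c - embedding) + tf.norm(c - a)
                     for c, a in zip(tf.unstack(centers), tf.unstack(anchors))]
            per_embedding.append(tf.add_n(terms))
        per_vect.append(tf.add_n(per_embedding))
    return tf.reduce_mean(tf.stack(per_vect))


# Inside tf.function the Python loops are unrolled at trace time, and
# jit_compile=True asks XLA to compile (and fuse) the resulting graph.
compiled_looped_loss = tf.function(looped_loss, jit_compile=True)
```

Because tracing unrolls B * N * K iterations, trace and compile time grow quickly with the batch size and the number of anchors, so this is mostly useful for small, fixed shapes; removing the loops entirely is usually the bigger win.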

Check the thread at:

https://github.com/keras-team/keras-cv/pull/146

You could also check:

https://github.com/keras-team/keras-cv/pull/161

Thanks!! Will check those out