Issues with Custom Loss and setting tf.config.run_functions_eagerly(True)

Experimental environment: Windows 11 with WSL2, TensorFlow 2.18.0

import tensorflow as tf
from tensorflow.keras import backend as K


class CIoULossMultiClass(tf.keras.losses.Loss):
    def __init__(self, num_classes, reduction=None, name='ciou_loss_multiclass'):
        """
        num_classes: number of classes
        """
        super(CIoULossMultiClass, self).__init__(reduction=reduction, name=name)
        self.num_classes = num_classes

    def call(self, y_true, y_pred):
        return self.ciou_loss_multiclass(y_true, y_pred)

    def compute_ciou_loss(self, boxes_true, boxes_pred):

        true_x, true_y, true_w, true_h = tf.split(boxes_true, 4, axis=-1)
        pred_x, pred_y, pred_w, pred_h = tf.split(boxes_pred, 4, axis=-1)

        true_x = tf.expand_dims(true_x, axis=1)  # (N, 1, 1)
        true_y = tf.expand_dims(true_y, axis=1)
        true_w = tf.expand_dims(true_w, axis=1)
        true_h = tf.expand_dims(true_h, axis=1)

        pred_x = tf.expand_dims(pred_x, axis=0)  # (1, M, 1)
        pred_y = tf.expand_dims(pred_y, axis=0)
        pred_w = tf.expand_dims(pred_w, axis=0)
        pred_h = tf.expand_dims(pred_h, axis=0)

        true_area = true_w * true_h
        pred_area = pred_w * pred_h

        true_x_min = true_x - true_w / 2.0
        true_y_min = true_y - true_h / 2.0
        true_x_max = true_x + true_w / 2.0
        true_y_max = true_y + true_h / 2.0

        pred_x_min = pred_x - pred_w / 2.0
        pred_y_min = pred_y - pred_h / 2.0
        pred_x_max = pred_x + pred_w / 2.0
        pred_y_max = pred_y + pred_h / 2.0

        intersect_x_min = tf.maximum(true_x_min, pred_x_min)
        intersect_y_min = tf.maximum(true_y_min, pred_y_min)
        intersect_x_max = tf.minimum(true_x_max, pred_x_max)
        intersect_y_max = tf.minimum(true_y_max, pred_y_max)

        intersect_w = tf.maximum(0.0, intersect_x_max - intersect_x_min)
        intersect_h = tf.maximum(0.0, intersect_y_max - intersect_y_min)
        intersect_area = intersect_w * intersect_h

        # IoU
        union_area = true_area + pred_area - intersect_area
        iou = intersect_area / (union_area + K.epsilon())

        # center distance
        center_distance = tf.square(true_x - pred_x) + tf.square(true_y - pred_y)

        enclose_x_min = tf.minimum(true_x_min, pred_x_min)
        enclose_y_min = tf.minimum(true_y_min, pred_y_min)
        enclose_x_max = tf.maximum(true_x_max, pred_x_max)
        enclose_y_max = tf.maximum(true_y_max, pred_y_max)

        enclose_w = enclose_x_max - enclose_x_min
        enclose_h = enclose_y_max - enclose_y_min
        enclose_diag = tf.square(enclose_w) + tf.square(enclose_h)

        atan_true = tf.atan(true_w / (true_h + K.epsilon()))
        atan_pred = tf.atan(pred_w / (pred_h + K.epsilon()))
        v = (4 / (3.14159265359 ** 2)) * tf.square(atan_true - atan_pred)
        alpha = v / (1 - iou + v + K.epsilon())

        # CIoU
        ciou = iou - (center_distance / (enclose_diag + K.epsilon())) - alpha * v

        # loss: 1 - ciou
        pairwise_loss = 1 - ciou
        pairwise_loss = tf.squeeze(pairwise_loss, axis=-1)  # shape: (N, M)
        return pairwise_loss

    def ciou_loss_multiclass(self, y_true, y_pred):
        """
        y_true: Tensor of shape (N, 5)
        y_pred: Tensor of shape (M, 5)
        """
        losses = []
        for cls in range(self.num_classes):
            cls_true_mask = tf.equal(tf.cast(y_true[:, -1], tf.int32), cls)
            cls_pred_mask = tf.equal(tf.cast(y_pred[:, -1], tf.int32), cls)

            boxes_true_cls = tf.boolean_mask(y_true[:, :4], cls_true_mask)
            boxes_pred_cls = tf.boolean_mask(y_pred[:, :4], cls_pred_mask)

            def no_boxes():
                return 0.0

            def compute_loss():
                pairwise_loss = self.compute_ciou_loss(boxes_true_cls, boxes_pred_cls)
                loss_true = tf.reduce_min(pairwise_loss, axis=1)  # (N_cls,)
                loss_pred = tf.reduce_min(pairwise_loss, axis=0)  # (M_cls,)
                return (tf.reduce_mean(loss_true) + tf.reduce_mean(loss_pred)) / 2.0

            loss_cls = tf.cond(
                tf.logical_or(tf.equal(tf.shape(boxes_true_cls)[0], 0),
                              tf.equal(tf.shape(boxes_pred_cls)[0], 0)),
                no_boxes,
                compute_loss
            )
            losses.append(loss_cls)

        final_loss = tf.reduce_mean(tf.stack(losses))
        return final_loss

def decoder(classes_wh, cfg_voc):
    """
    Args:
        classes_wh: encoded tensor (H, W, C)
    Returns:
        decoded_boxes: [x_center, y_center, width, height, label] shaped tensor
    """
    selected_indices = tf.where(classes_wh[:, :, cfg_voc.n_class] != 0)
    selected_classes = tf.gather_nd(classes_wh[..., :cfg_voc.n_class], selected_indices)
    selected_wh = tf.gather_nd(classes_wh[..., cfg_voc.n_class:], selected_indices)

    decoded_centroid_boxes = tf.concat([tf.cast(selected_indices, tf.float32), selected_wh], axis=1)
    decoded_labels = tf.argmax(selected_classes, axis=1, output_type=tf.int32)

    # [x_center, y_center, width, height, label]
    decoded_labels = tf.cast(tf.expand_dims(decoded_labels, axis=1), tf.float32)
    decoded_boxes = tf.concat([decoded_centroid_boxes, decoded_labels], axis=1)

    return decoded_boxes  # shape: (num_objects, 5)
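To illustrate the encoding, here is a minimal example on a hypothetical 4x4 grid with 2 classes (so C = n_class + 2: one-hot class scores followed by width and height); CfgStub is a made-up stand-in for cfg_voc:

import numpy as np

class CfgStub:  # hypothetical stand-in for cfg_voc
    n_class = 2

# one object encoded at grid cell (row=1, col=2): one-hot class [0, 1], wh = [0.3, 0.5]
grid = np.zeros((4, 4, CfgStub.n_class + 2), dtype=np.float32)
grid[1, 2] = [0.0, 1.0, 0.3, 0.5]

print(decoder(tf.constant(grid), CfgStub()))
# -> [[1. 2. 0.3 0.5 1.]]  (grid indices, width/height, argmax class id)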

class Loss(tf.keras.losses.Loss):
    def __init__(self, cfg, name='Loss'):
        """
        Classification loss:
        tf.keras.losses.BinaryCrossentropy(): object detection is multi-label classification.
        tf.keras.losses.CategoricalCrossentropy(): image classification is multi-class classification.
        """
        super().__init__(name=name)
        self.cfg = cfg
        self.bce_loss = tf.keras.losses.BinaryCrossentropy()
        self.ciou_loss = CIoULossMultiClass(reduction="none", num_classes=cfg.n_class)

        # loss log kept in tf.Variable so it can be read back from graph mode
        self.loss_log = {
            "cls_loss": tf.Variable(0.0, trainable=False, dtype=tf.float32),
            "box_center_loss": tf.Variable(0.0, trainable=False, dtype=tf.float32),  # smooth L1 loss
            "box_wh_loss": tf.Variable(0.0, trainable=False, dtype=tf.float32),      # CIoU loss
            "regr_loss": tf.Variable(0.0, trainable=False, dtype=tf.float32),
            "total_loss": tf.Variable(0.0, trainable=False, dtype=tf.float32),
        }

    def call(self, y_true, y_pred):

        classes, wh = y_pred[..., :self.cfg.n_class], y_pred[..., self.cfg.n_class:]
        gt_classes, gt_wh = y_true[..., :self.cfg.n_class], y_true[..., self.cfg.n_class:]

        # cls loss
        cls_loss = self.cfg.alpha * self.bce_loss(gt_classes, classes)

        # ciou loss
        # Eager mode:
        # ciou_loss_eager_value = tf.constant(0.0, dtype=tf.float32)
        # for i in range(self.cfg.batch_size):
        #     ciou_loss_eager_value += self.ciou_loss(decoder(y_true[i, ...], self.cfg), decoder(y_pred[i, ...], self.cfg))

        def compute_loss(i):
            decoded_y_true = decoder(y_true[i, ...], self.cfg)
            decoded_y_pred = decoder(y_pred[i, ...], self.cfg)
            return self.ciou_loss(decoded_y_true, decoded_y_pred)

        # not working
        # ciou_loss_graph_values = tf.map_fn(fn=lambda i: compute_loss(i), elems=tf.range(self.cfg.batch_size))
        # ciou_loss_graph_values = tf.map_fn(fn=lambda i: compute_loss(i), elems=tf.range(self.cfg.batch_size), fn_output_signature=tf.float32)
        # ciou_loss_graph_value = tf.reduce_sum(ciou_loss_graph_values)

        # manually unrolled over a fixed batch size of 16
        ciou_loss_0, ciou_loss_1, ciou_loss_2, ciou_loss_3 = compute_loss(0), compute_loss(1), compute_loss(2), compute_loss(3)
        ciou_loss_4, ciou_loss_5, ciou_loss_6, ciou_loss_7 = compute_loss(4), compute_loss(5), compute_loss(6), compute_loss(7)
        ciou_loss_8, ciou_loss_9, ciou_loss_a, ciou_loss_b = compute_loss(8), compute_loss(9), compute_loss(10), compute_loss(11)
        ciou_loss_c, ciou_loss_d, ciou_loss_e, ciou_loss_f = compute_loss(12), compute_loss(13), compute_loss(14), compute_loss(15)

        ciou_sum = tf.stack([ciou_loss_0, ciou_loss_1, ciou_loss_2, ciou_loss_3, ciou_loss_4, ciou_loss_5, ciou_loss_6, ciou_loss_7,
                             ciou_loss_8, ciou_loss_9, ciou_loss_a, ciou_loss_b, ciou_loss_c, ciou_loss_d, ciou_loss_e, ciou_loss_f])
        ciou_loss_graph_value = tf.math.reduce_sum(ciou_sum)

        # tf.print((ciou_loss_eager_value - ciou_loss_graph_value).numpy())
        # breakpoint()

        ciou_loss = ciou_loss_graph_value

        # box center loss
        box_center_loss = self.cfg.beta * ciou_loss

        # box wh loss
        box_wh_loss = self.cfg.gamma * ciou_loss

        # regression loss
        regr_loss = (box_center_loss + box_wh_loss) / (tf.cast(2, dtype=tf.float32) + 1e-7)

        # total loss
        total_loss = cls_loss + regr_loss

        # assign values to tf.Variable
        self.loss_log["cls_loss"].assign(cls_loss)
        self.loss_log["box_center_loss"].assign(box_center_loss)
        self.loss_log["box_wh_loss"].assign(box_wh_loss)
        self.loss_log["regr_loss"].assign(regr_loss)
        self.loss_log["total_loss"].assign(total_loss)

        return total_loss  # single scalar value
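For completeness, a quick eager-mode smoke test of the whole Loss; the Cfg fields and tensor shapes below are made-up assumptions based only on how cfg is used above (the encoded maps have n_class class channels followed by width and height, and the batch size must be 16 because call() unrolls exactly 16 images):

# hypothetical config values, for illustration only
class Cfg:
    n_class = 2
    alpha = 1.0
    beta = 1.0
    gamma = 1.0
    batch_size = 16

loss_fn = Loss(Cfg())
dummy = tf.random.uniform((Cfg.batch_size, 8, 8, Cfg.n_class + 2))  # (B, H, W, n_class + 2)
print(loss_fn(dummy, dummy))  # scalar total loss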

I designed the custom loss above and want to use it for training. With tf.config.run_functions_eagerly(True) it trains, but very slowly. If I disable that setting, training stalls for a very long time and then aborts with a message like this:
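For context, train_step is a custom training step compiled with XLA, roughly like this (a simplified sketch; model, optimizer, and loss_fn stand in for my actual objects, with loss_fn = Loss(cfg)):

@tf.function(jit_compile=True)  # removing jit_compile=True avoids the crash, see the update at the end
def train_step(images, targets):
    with tf.GradientTape() as tape:
        preds = model(images, training=True)
        loss = loss_fn(targets, preds)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss_fn.loss_log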

Epoch 1/256
Epoch 1: 0%| | 0/557 [00:00<?, ?batch/s]WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1738828043.106763 4934 service.cc:148] XLA service 0x6adb4ba0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1738828043.106821 4934 service.cc:156] StreamExecutor device (0): NVIDIA GeForce RTX 3090, Compute Capability 8.6
2025-02-06 16:47:24.471641: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var MLIR_CRASH_REPRODUCER_DIRECTORY to enable.
I0000 00:00:1738828059.794380 4934 cuda_dnn.cc:529] Loaded cuDNN version 90300
2025-02-06 16:56:39.396321: E external/local_xla/xla/service/slow_operation_alarm.cc:65]


[Compiling module a_inference_train_step_145069__XlaMustCompile_true_config_proto_2201667018877855759_executor_type_11160318154034397263_.252032] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.


2025-02-06 16:57:52.164506: E external/local_xla/xla/service/slow_operation_alarm.cc:133] The operation took 3m12.768764716s


[Compiling module a_inference_train_step_145069__XlaMustCompile_true_config_proto_2201667018877855759_executor_type_11160318154034397263_.252032] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.


I0000 00:00:1738828672.173032 4934 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
Epoch 1: 0%| | 0/557 [11:17<?, ?batch/s]
Traceback (most recent call last):
File "/home/sphinx/Ws/my_AFOD/my_OD1/trainval_voc.py", line 343, in <module>
train()
File "/home/sphinx/Ws/my_AFOD/my_OD1/trainval_voc.py", line 107, in train
batch_loss_log = train_step(images, targets)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/sphinx/tf-gpu-env/lib/python3.12/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/sphinx/tf-gpu-env/lib/python3.12/site-packages/tensorflow/python/eager/execute.py", line 53, in quick_execute
tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tensorflow.python.framework.errors_impl.InternalError: CustomCall failed: Buffers have different size at runtime [Op:__inference_train_step_145069]

I don't know what this error means or how to fix it; the important thing is that I need fast training.

Currently, on Ubuntu 22.04 in VMware, TensorFlow trains fine in CPU mode. I have also reduced batch_size to 1. When I am back at the office I will check whether this is a WSL2 issue, a batch_size issue, or a GPU issue, and share what I find.

I found the problem: training works when train_step is decorated with @tf.function, and it breaks when it is decorated with @tf.function(jit_compile=True). I'm not sure exactly why, but the loss seems to produce tensors whose shapes are only known at runtime (for example from tf.boolean_mask and the per-class tf.cond), which XLA apparently cannot compile into fixed-size buffers. tf.map_fn works fine now, too. I guess I need to be careful about using jit_compile.
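In case it helps someone, this is a sketch of the variant that works for me now: plain @tf.function (without jit_compile=True) on train_step, and the per-image loop inside Loss.call written with tf.map_fn instead of the 16 hand-unrolled compute_loss(...) calls. fn_output_signature=tf.float32 is needed because compute_loss returns float32 while elems is an int32 range:

        # inside Loss.call, replacing the manually unrolled batch loop
        ciou_per_image = tf.map_fn(
            compute_loss,
            elems=tf.range(self.cfg.batch_size),
            fn_output_signature=tf.float32)  # shape: (batch_size,)
        ciou_loss_graph_value = tf.reduce_sum(ciou_per_image)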