Getting NaN loss while training ArcFace model

Dear Team,

I am running into a NaN loss while training ArcFace with a pre-trained FaceNet model, and also when using a ResNet50 backbone. Please suggest what changes could resolve the NaN issue.
Code below:


import os
import tensorflow as tf
from tensorflow.keras import layers, Model, optimizers
from tensorflow.keras.applications import ResNet50
from tqdm import tqdm
from tensorflow.keras.models import load_model

# Custom ArcFace Loss (unchanged)
class ArcFaceLoss(tf.keras.losses.Loss):
    def __init__(self, num_classes, s=16.0, m=0.50):
        super(ArcFaceLoss, self).__init__()
        self.num_classes = num_classes
        self.s = s  # Scaling factor
        self.m = m  # Margin
        self.class_weights = None

    def call(self, y_true, y_pred):
        if self.class_weights is None:
            raise ValueError("Class weights not set.")
        
        y_pred = tf.nn.l2_normalize(y_pred, axis=1)
        class_weights = tf.nn.l2_normalize(self.class_weights, axis=0)

        cosine_sim = tf.matmul(y_pred, class_weights, transpose_b=True)
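        # Clip cosine values before acos; the gradient of acos blows up at exactly +/-1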
        cosine_sim = tf.clip_by_value(cosine_sim, -0.999, 0.999)
        theta = tf.acos(cosine_sim)
        
        #theta = tf.acos(tf.clip_by_value(cosine_sim, -1.0 + 1e-7, 1.0 - 1e-7))
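        # Add the angular margin m to the angle and take the cosine again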
        target_logits = tf.cos(theta + self.m)
        
        print("target_logits inside calling: ",target_logits)

        output = tf.zeros_like(cosine_sim)
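        # Scatter the margin-adjusted logits into output at the indices given by y_true, then scale by s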
        output = tf.tensor_scatter_nd_update(output, tf.expand_dims(y_true, axis=1), target_logits)
        output *= self.s
        
        loss_value = tf.keras.losses.sparse_categorical_crossentropy(y_true, output, from_logits=True)
        return loss_value

    def set_class_weights(self, class_weights):
        self.class_weights = class_weights


pretrained_model = './FaceNetModel/facenet_keras.h5'

def load_facenet_base(input_shape):
    """Load the FaceNet model without the top classification layer."""
    base_model = load_model(pretrained_model)
    
    # Take the output five layers from the end (the last convolutional feature map)
    facenet_output = base_model.layers[-5].output
    print("facenet_output : ",facenet_output)
    print("facenet_output shape : ",facenet_output.shape)
    
    # Create a new model from the base model with the top layers removed
    facenet_model = tf.keras.Model(inputs=base_model.input, outputs=facenet_output)
    print("facenet_model : ",facenet_model)
    return facenet_model
        
# ArcFace Model Definition with Batch Normalization
class ArcFaceModel(Model):
    def __init__(self, num_classes, embedding_size):
        super(ArcFaceModel, self).__init__()
        # Backbone: pre-trained FaceNet feature extractor (a ResNet50 alternative is left commented out below)
        #self.backbone = ResNet50(include_top=False, input_shape=(112, 112, 3), pooling=None, weights=None)
        self.backbone = load_facenet_base(input_shape=(160, 160, 3))
        self.pooling = layers.GlobalAveragePooling2D()
        self.batch_norm = layers.BatchNormalization()  # Batch Normalization after Global Average Pooling
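        # Linear projection to the embedding size, followed by a plain Dense classifier head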
        self.dense = layers.Dense(embedding_size)
        self.classifier = layers.Dense(num_classes)

    def call(self, inputs):
        x = self.backbone(inputs)
        x = self.pooling(x)
        x = self.batch_norm(x)  # Apply batch normalization before dense layers
        embeddings = self.dense(x)
        logits = self.classifier(embeddings)
        return embeddings, logits

# Load images from directory with error handling
def load_dataset(data_dir, batch_size, img_size=(160, 160)):
    image_paths = []
    labels = []
    
    # Iterate through each uid folder
    for uid in os.listdir(data_dir):
        uid_path = os.path.join(data_dir, uid)
        if os.path.isdir(uid_path):
            for img_file in os.listdir(uid_path):
                if img_file.endswith(('.jpg', '.jpeg', '.png')):  # Add more formats if needed
                    image_paths.append(os.path.join(uid_path, img_file))
                    labels.append(int(uid))  # Assuming uid is the label, adjust as needed

    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))

    def _preprocess_image(image_path, label):
        try:
            image = tf.io.read_file(image_path)
            image = tf.image.decode_jpeg(image, channels=3)
            image = tf.image.resize(image, img_size)
            image = image / 255.0
            return image, label
        except Exception as e:
            print(f"Error processing image: {image_path}, Error: {e}")
            return None, None  # Return None for problematic images

    dataset = dataset.map(_preprocess_image)
    
    # Filter out None values (problematic images)
    dataset = dataset.filter(lambda img, lbl: img is not None and lbl is not None)
    
    # Batch and shuffle the dataset
    dataset = dataset.shuffle(len(image_paths)).batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return dataset

# Hyperparameters (unchanged)
num_classes = 214
embedding_size = 512
batch_size = 16
img_size = (160, 160)

# Model, optimizer, and loss function (unchanged)
model = ArcFaceModel(num_classes=num_classes, embedding_size=embedding_size)
# optimizer = optimizers.Adam(learning_rate=0.01)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01, clipnorm=1.0)
loss_fn = ArcFaceLoss(num_classes=num_classes)

# Load the dataset
data_dir = "../data/test_data/Part_5"  # Path to your data directory
dataset = load_dataset(data_dir, batch_size, img_size)

# Train the model (unchanged)
num_batches = tf.data.experimental.cardinality(dataset).numpy()
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])  

for epoch in range(10):
    with tqdm(total=num_batches, desc=f'Epoch [{epoch + 1}/10]', unit='batch') as pbar:
        for images, labels in dataset:
            with tf.GradientTape() as tape:
                embeddings, logits = model(images)  # Get embeddings and logits
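                # classifier.get_weights()[0] is the Dense kernel, shape (embedding_size, num_classes) = (512, 214)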
                class_weights = model.classifier.get_weights()[0]  # Get classifier weights
                if len(class_weights) == 0:
                    raise ValueError("Classifier weights are not initialized.")
                print("class_weights : ",class_weights)
                print("class_weights shape : ",class_weights.shape)
                print("labels : ",labels)
                print("labels shape : ",labels.shape)
                print("logits : ",logits)
                print("logits shape : ",logits.shape)
                loss_fn.set_class_weights(tf.convert_to_tensor(class_weights))  # Set class weights in loss function
                loss = loss_fn(labels, logits)  # Use logits for loss computation

            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            # Update progress bar
            pbar.set_postfix(loss=loss.numpy())
            pbar.update(1)

# Get embeddings (unchanged)
embeddings = []
for images, labels in dataset:
    features, _ = model(images)  # Get embeddings
    normalized_features = tf.nn.l2_normalize(features, axis=1)
    embeddings.append(normalized_features)
embeddings = tf.concat(embeddings, axis=0)
print("Embeddings shape:", embeddings.shape)

From the above code I am getting the output below.

After a few iterations the weights turn into NaN.

WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
facenet_output :  KerasTensor(type_spec=TensorSpec(shape=(None, 3, 3, 1792), dtype=tf.float32, name=None), name='Block8_6_ScaleSum/add:0', description="created by layer 'Block8_6_ScaleSum'")
facenet_output shape :  (None, 3, 3, 1792)
facenet_model :  <keras.engine.functional.Functional object at 0x7f2c20096cc0>
Epoch [1/10]: 0batch [00:00, ?batch/s]
class_weights :  [[-0.07154729 -0.02306947  0.02076409 ...  0.08830355  0.03795646
  -0.05537736]
 [ 0.02608624 -0.01183958  0.08689994 ...  0.02180763  0.06158453
  -0.02694845]
 [ 0.02542777 -0.00198308  0.06604646 ... -0.08475913  0.00327843
  -0.01243555]
 ...
 [ 0.02010283  0.03593907  0.02092864 ... -0.03379323 -0.03277213
  -0.02229045]
 [-0.08940084 -0.07136245  0.01797614 ...  0.00967152 -0.06468429
  -0.07133079]
 [-0.07091641  0.06512755  0.03746107 ...  0.02718134 -0.04930652
   0.05826384]]
class_weights shape :  (512, 214)
labels :  tf.Tensor(
[ 6805  2451  1351 13363   891  5954   908  9006  6484  5922   477 16645
  2069 13559  2969 11079], shape=(16,), dtype=int32)
labels shape :  (16,)
logits :  tf.Tensor(
[[ 0.26744387  0.5673069  -0.65750664 ...  0.93259597 -0.31617057
  -0.9069006 ]
 [ 0.30575937  0.6854721  -0.47334898 ... -0.8903005   0.34003362
  -0.51950604]
 [-0.2756386   0.57466805  0.60747224 ... -0.38713667  0.01741117
  -1.0266993 ]
 ...
 [-1.5937876  -0.24903166 -0.21675788 ...  0.7609053   0.6301979
  -0.7865187 ]
 [-0.37443778 -0.50634176 -0.05435409 ... -0.17619671  0.28485733
   0.4528828 ]
 [-0.12459118  0.06653331 -0.8503268  ...  0.5178139  -0.01143806
  -0.63348514]], shape=(16, 214), dtype=float32)
logits shape :  (16, 214)
target_logits inside calling:  tf.Tensor(
[[-0.40937608 -0.50916    -0.47429553 ... -0.4966299  -0.41787565
  -0.51206803]
 [-0.5221163  -0.49376774 -0.41838503 ... -0.45142698 -0.44719943
  -0.4908942 ]
 [-0.42787015 -0.45540735 -0.45555636 ... -0.4498175  -0.4395322
  -0.40937698]
 ...
 [-0.34729308 -0.5006037  -0.57512224 ... -0.5177299  -0.49817082
  -0.4743371 ]
 [-0.517252   -0.45060605 -0.4886686  ... -0.58086795 -0.4891091
  -0.39746413]
 [-0.42886612 -0.50017554 -0.467979   ... -0.45910588 -0.5730436
  -0.54143083]], shape=(16, 512), dtype=float32)
Epoch [1/10]: 2batch [00:00,  2.54batch/s, loss=nan]
class_weights :  [[-0.07154729 -0.02306947  0.02076409 ...  0.08830355  0.03795646
  -0.05537736]
 [ 0.02608624 -0.01183958  0.08689994 ...  0.02180763  0.06158453
  -0.02694845]
 [ 0.02542777 -0.00198308  0.06604646 ... -0.08475913  0.00327843
  -0.01243555]
 ...
 [ 0.02010283  0.03593907  0.02092864 ... -0.03379323 -0.03277213
  -0.02229045]
 [-0.08940084 -0.07136245  0.01797614 ...  0.00967152 -0.06468429
  -0.07133079]
 [-0.07091641  0.06512755  0.03746107 ...  0.02718134 -0.04930652
   0.05826384]]
class_weights shape :  (512, 214)
labels :  tf.Tensor(
[11442   314  9505 19431 14649 16141  3666  7328  1959  1184  8669 13249
 13248  5693  3341 14175], shape=(16,), dtype=int32)
labels shape :  (16,)
logits :  tf.Tensor(
[[ 0.09020002  0.1057491  -0.06038473 ...  0.2686601  -0.12298391
  -0.49338812]
 [-0.16962089  1.1862966   0.08852205 ...  0.4877321   0.59952277
  -0.27804232]
 [ 0.48871276  0.47887018 -0.31307012 ... -0.4734539  -0.06406529
  -0.16018376]
 ...
 [-0.26977187 -0.11485683 -0.16163483 ...  0.3033935   0.18887891
  -0.09917035]
 [ 0.28281295  0.3348917   0.12493412 ... -0.20933345 -0.3844156
   0.10291277]
 [-0.4714488   0.7371135  -0.29467827 ... -0.19916914 -0.08629423
  -0.3001826 ]], shape=(16, 214), dtype=float32)
logits shape :  (16, 214)
target_logits inside calling:  tf.Tensor(
[[-0.3956176  -0.4904168  -0.5275642  ... -0.4945123  -0.46351594
  -0.4495377 ]
 [-0.4822794  -0.407266   -0.40311182 ... -0.50610894 -0.519847
  -0.44590583]
 [-0.5073171  -0.47912392 -0.4754053  ... -0.434844   -0.445486
  -0.41657516]
 ...
 [-0.4377993  -0.5066198  -0.53239334 ... -0.54246074 -0.45009303
  -0.4315502 ]
 [-0.4953049  -0.51256824 -0.45097828 ... -0.474633   -0.48660812
  -0.41378236]
 [-0.47897533 -0.46150374 -0.50203556 ... -0.4518656  -0.5071928
  -0.42874914]], shape=(16, 512), dtype=float32)
Epoch [1/10]: 3batch [00:01,  2.93batch/s, loss=nan]
class_weights :  [[-0.07154729 -0.02306947  0.02076409 ...  0.08830355  0.03795646
  -0.05537736]
 [ 0.02608624 -0.01183958  0.08689994 ...  0.02180763  0.06158453
  -0.02694845]
 [ 0.02542777 -0.00198308  0.06604646 ... -0.08475913  0.00327843
  -0.01243555]
 ...
 [ 0.02010283  0.03593907  0.02092864 ... -0.03379323 -0.03277213
  -0.02229045]
 [-0.08940084 -0.07136245  0.01797614 ...  0.00967152 -0.06468429
  -0.07133079]
 [-0.07091641  0.06512755  0.03746107 ...  0.02718134 -0.04930652
   0.05826384]]
class_weights shape :  (512, 214)
labels :  tf.Tensor(
[ 8166 11109  2830  3006  5202  3041  5928 14173 10219  1597  8481 12291
 15977   314  9426   743], shape=(16,), dtype=int32)
labels shape :  (16,)
logits :  tf.Tensor(
[[-0.19675736 -0.21331258  0.16208161 ...  0.02279774 -0.19987005
   0.19420093]
 [-0.07037039  0.59028083  0.02098583 ... -0.01945191 -0.06997629
  -0.7754571 ]
 [ 0.12654018  0.4661865  -0.45868406 ...  0.01586402 -0.16408072
  -0.1076064 ]
 ...
 [-0.16547754  1.1767013  -0.03659337 ...  0.45615825  0.55877113
  -0.48068044]
 [ 0.4288724   0.209414   -0.2569908  ... -0.34497777 -0.39943108
  -0.52091426]
 [-0.08063635  0.34660146  0.6675661  ... -0.2966741  -0.28360367
  -0.6161124 ]], shape=(16, 214), dtype=float32)
logits shape :  (16, 214)
target_logits inside calling:  tf.Tensor(
[[-0.42727485 -0.3947605  -0.48398003 ... -0.49644738 -0.48725283
  -0.44645205]
 [-0.5239179  -0.48973742 -0.4337545  ... -0.5109691  -0.5288607
  -0.48230633]
 [-0.49096835 -0.5117907  -0.39603633 ... -0.47919467 -0.46923852
  -0.47364157]
 ...
 [-0.4652151  -0.42833942 -0.4172487  ... -0.47953323 -0.5277615
  -0.4556489 ]
 [-0.38039064 -0.44641748 -0.4634931  ... -0.49803168 -0.46754384
  -0.49363837]
 [-0.51332855 -0.42935324 -0.43799263 ... -0.50245386 -0.4839796
  -0.42089927]], shape=(16, 512), dtype=float32)
class_weights :  [[-0.07154729 -0.02306947  0.02076409 ...  0.08830355  0.03795646
  -0.05537736]
 [ 0.02608624 -0.01183958  0.08689994 ...  0.02180763  0.06158453
  -0.02694845]
 [ 0.02542777 -0.00198308  0.06604646 ... -0.08475913  0.00327843
  -0.01243555]
 ...
 [ 0.02010283  0.03593907  0.02092864 ... -0.03379323 -0.03277213
  -0.02229045]
 [-0.08940084 -0.07136245  0.01797614 ...  0.00967152 -0.06468429
  -0.07133079]
 [-0.07091641  0.06512755  0.03746107 ...  0.02718134 -0.04930652
   0.05826384]]
class_weights shape :  (512, 214)
labels :  tf.Tensor(
[11154  5409 11169  8270  4493   569  2890 12072  2457   455 13205  4500
  3241 10947  7114  3320], shape=(16,), dtype=int32)
labels shape :  (16,)
logits :  tf.Tensor(
[[ 0.17383376  0.35689956  0.555365   ...  0.49418834 -0.07051814
   0.4163072 ]
 [ 0.0274086  -0.09327105 -0.26898354 ... -0.12536123 -0.06546447
  -0.17825045]
 [ 0.2917659   0.34864125  0.08608174 ... -0.34252894 -0.26545414
  -0.04261594]
 ...
 [ 0.18320861  0.4792547   0.11823319 ... -0.04992133  0.06286443
   0.13358451]
 [ 0.03251099  0.50109315  0.05128436 ...  0.55811834  0.40319735
  -1.0428449 ]
 [ 0.5744581   0.7111312  -0.34238854 ...  0.21014996  0.1358957
   0.18308519]], shape=(16, 214), dtype=float32)
logits shape :  (16, 214)
target_logits inside calling:  tf.Tensor(
[[-0.47088498 -0.37479493 -0.48961684 ... -0.53336936 -0.53514457
  -0.3968008 ]
 [-0.48518842 -0.45421192 -0.50211084 ... -0.52084565 -0.49652934
  -0.4129826 ]
 [-0.5339255  -0.51311576 -0.39257815 ... -0.47371632 -0.54076433
  -0.43326184]
 ...
 [-0.53899187 -0.47603276 -0.46572384 ... -0.44957644 -0.48546147
  -0.42481828]
 [-0.39202803 -0.3986407  -0.47883406 ... -0.5288463  -0.54959095
  -0.52567625]
 [-0.49601558 -0.46867839 -0.40670738 ... -0.45878708 -0.542849
  -0.47709417]], shape=(16, 512), dtype=float32)
Epoch [1/10]: 4batch [00:01,  3.09batch/s, loss=nan]
class_weights :  [[-0.07154729 -0.02306947  0.02076409 ...  0.08830355  0.03795646
  -0.05537736]
 [ 0.02608624 -0.01183958  0.08689994 ...  0.02180763  0.06158453
  -0.02694845]
 [ 0.02542777 -0.00198308  0.06604646 ... -0.08475913  0.00327843
  -0.01243555]
 ...
 [ 0.02010283  0.03593907  0.02092864 ... -0.03379323 -0.03277213
  -0.02229045]
 [-0.08940084 -0.07136245  0.01797614 ...  0.00967152 -0.06468429
  -0.07133079]
 [-0.07091641  0.06512755  0.03746107 ...  0.02718134 -0.04930652
   0.05826384]]
class_weights shape :  (512, 214)
labels :  tf.Tensor(
[ 3868 19169  1493 16145 12473  3869 12782  5913  4790  7617  5904 11342
  5674 19292  3499  3409], shape=(16,), dtype=int32)
labels shape :  (16,)
logits :  tf.Tensor(
[[ 0.57800317  0.1613152  -0.07243444 ... -0.6577404  -0.16849959
  -0.42421365]
 [-0.01219992  0.97159374 -0.5871273  ...  0.5042745  -0.12050397
  -1.3935063 ]
 [ 0.2730779   0.66284895 -0.9893655  ... -0.1894682   0.4959141
  -0.22442843]
 ...
 [-0.6545307   0.4511203  -0.15208882 ... -0.4262018  -0.0690819
  -0.5040807 ]
 [ 0.0931059   0.36629394 -0.9170659  ... -0.39066035 -0.20088665
  -0.8689369 ]
 [ 0.31923646  0.34596732 -0.23662555 ... -0.38889137 -0.2460812
  -0.04167897]], shape=(16, 214), dtype=float32)
logits shape :  (16, 214)
target_logits inside calling:  tf.Tensor(
[[-0.49253947 -0.43537292 -0.3809213  ... -0.4932991  -0.44544566
  -0.5068683 ]
 [-0.4155472  -0.44325832 -0.48157015 ... -0.50359666 -0.427507
  -0.48878798]
 [-0.46081388 -0.46436414 -0.46414536 ... -0.51911986 -0.48245418
  -0.51590776]
 ...
 [-0.47680017 -0.4833611  -0.44492304 ... -0.5321027  -0.4253278
  -0.4180446 ]
 [-0.38592088 -0.46478513 -0.4231593  ... -0.5112595  -0.5244249
  -0.4525988 ]
 [-0.42263517 -0.45989022 -0.46747324 ... -0.4677919  -0.52482826
  -0.4653073 ]], shape=(16, 512), dtype=float32)
Epoch [1/10]: 6batch [00:01,  3.29batch/s, loss=nan]
class_weights :  [[-0.07154729 -0.02306947  0.02076409 ...  0.08830355  0.03795646
  -0.05537736]
 [ 0.02608624 -0.01183958  0.08689994 ...  0.02180763  0.06158453
  -0.02694845]
 [ 0.02542777 -0.00198308  0.06604646 ... -0.08475913  0.00327843
  -0.01243555]
 ...
 [ 0.02010283  0.03593907  0.02092864 ... -0.03379323 -0.03277213
  -0.02229045]
 [-0.08940084 -0.07136245  0.01797614 ...  0.00967152 -0.06468429
  -0.07133079]
 [-0.07091641  0.06512755  0.03746107 ...  0.02718134 -0.04930652
   0.05826384]]
class_weights shape :  (512, 214)
labels :  tf.Tensor(
[14343  1434   870 11088   212  2769 20174 14226  3736 16001  3788  8296
  2272  7533 13250  6786], shape=(16,), dtype=int32)
labels shape :  (16,)
logits :  tf.Tensor(
[[ 0.19983542 -0.18925124 -0.35843274 ... -0.22158289 -0.19040169
   0.01305819]
 [-0.6920266  -0.19054618  0.01261879 ... -0.863125    0.2301442
   0.35670522]
 [-0.54018325  0.5191148   0.12106402 ...  0.5660032  -0.41441488
   0.12721318]
 ...
 [-0.2173877   0.8746421  -0.73231035 ... -0.15774941  0.02040236
  -0.388246  ]
 [ 0.00583356  0.21866369 -0.35246223 ...  0.30253044  0.04051851
  -0.21645775]
 [ 0.38573867  0.6111902  -0.66730255 ... -0.08702751 -0.18722671
  -0.43641695]], shape=(16, 214), dtype=float32)
logits shape :  (16, 214)
target_logits inside calling:  tf.Tensor(
[[-0.46032333 -0.45692754 -0.48155114 ... -0.5220475  -0.55633557
  -0.4543623 ]
 [-0.41653657 -0.49110666 -0.5253643  ... -0.47046193 -0.5037827
  -0.44169816]
 [-0.42483857 -0.4863621  -0.5233862  ... -0.5072314  -0.5566115
  -0.36512554]
 ...
 [-0.42671046 -0.45978054 -0.5164236  ... -0.4811409  -0.49915367
  -0.44097054]
 [-0.486555   -0.50692403 -0.423274   ... -0.5193035  -0.55187196
  -0.49745122]
 [-0.5100029  -0.45428032 -0.4199505  ... -0.4990041  -0.5938168
  -0.50580066]], shape=(16, 512), dtype=float32)
class_weights :  [[-0.07154729 -0.02306947  0.02076409 ...  0.08830355  0.03795646
  -0.05537736]
 [ 0.02608624 -0.01183958  0.08689994 ...  0.02180763  0.06158453
  -0.02694845]
 [ 0.02542777 -0.00198308  0.06604646 ... -0.08475913  0.00327843
  -0.01243555]
 ...
 [ 0.02010283  0.03593907  0.02092864 ... -0.03379323 -0.03277213
  -0.02229045]
 [-0.08940084 -0.07136245  0.01797614 ...  0.00967152 -0.06468429
  -0.07133079]
 [-0.07091641  0.06512755  0.03746107 ...  0.02718134 -0.04930652
   0.05826384]]
class_weights shape :  (512, 214)
labels :  tf.Tensor(
[13616  1449  5932  2977 14012 13012 16148 16574   455 14644  2587  6766
 18602 17569  9325  3736], shape=(16,), dtype=int32)
labels shape :  (16,)
logits :  tf.Tensor(
[[-0.13739923  0.42064998  0.00721037 ...  0.35631368 -0.16070034
  -0.6218335 ]
 [ 0.29295176  0.13757068 -0.04192078 ...  0.4133935  -0.01754976
  -0.08973861]
 [-0.4909365  -0.3258107   0.1204057  ... -0.24937841 -0.07292646
  -0.28140342]
 ...
 [ 0.0793658   1.156419    0.03510895 ...  0.08095186  0.3646526
  -0.351102  ]
 [-0.35604402  0.3830788  -0.09276076 ... -0.0208631   0.09093831
  -0.37243736]
 [ 0.11898601  0.34577584  0.13236059 ... -0.31919676 -0.3411179
   0.02237839]], shape=(16, 214), dtype=float32)
logits shape :  (16, 214)
target_logits inside calling:  tf.Tensor(
[[-0.46862972 -0.5053872  -0.4087431  ... -0.4708995  -0.41095516
  -0.43321416]
 [-0.48458058 -0.47134006 -0.47283274 ... -0.5240599  -0.5742386
  -0.44783652]
 [-0.4497647  -0.4383861  -0.48589748 ... -0.5516401  -0.5692487
  -0.4635873 ]
 ...
 [-0.5488741  -0.48695654 -0.40012178 ... -0.5737456  -0.5103246
  -0.4802511 ]
 [-0.491943   -0.44454038 -0.4249562  ... -0.4730008  -0.5851825
  -0.49085328]
 [-0.48542315 -0.51825905 -0.45620295 ... -0.4629727  -0.5127115
  -0.41201305]], shape=(16, 512), dtype=float32)
Epoch [1/10]: 8batch [00:02,  3.40batch/s, loss=nan]
class_weights :  [[-0.07154729 -0.02306947  0.02076409 ...  0.08830355  0.03795646
  -0.05537736]
 [ 0.02608624 -0.01183958  0.08689994 ...  0.02180763  0.06158453
  -0.02694845]
 [ 0.02542777 -0.00198308  0.06604646 ... -0.08475913  0.00327843
  -0.01243555]
 ...
 [ 0.02010283  0.03593907  0.02092864 ... -0.03379323 -0.03277213
  -0.02229045]
 [-0.08940084 -0.07136245  0.01797614 ...  0.00967152 -0.06468429
  -0.07133079]
 [-0.07091641  0.06512755  0.03746107 ...  0.02718134 -0.04930652
   0.05826384]]
class_weights shape :  (512, 214)
labels :  tf.Tensor(
[ 3341 14364  2192 17606   743  2977   283 17414  2192  1827     4 18226
 14053  2824 11474 16516], shape=(16,), dtype=int32)
labels shape :  (16,)
logits :  tf.Tensor(
[[ 0.48532888 -0.03771351 -0.21448003 ...  0.97941756  0.5101669
  -0.6936798 ]
 [-0.03332526  0.60884523  0.34555247 ...  0.17341597 -0.70665264
  -0.25494382]
 [-0.2862319   0.88382936  0.46006796 ...  0.973786    0.10388324
  -0.5888302 ]
 ...
 [ 0.2091163  -0.02802525  0.1419187  ... -0.20346242 -0.01128776
  -0.0675283 ]
 [ 0.03348187  0.23810017 -0.28930712 ...  0.41748637 -0.09295375
  -0.32645518]
 [ 0.06341116  0.47675174 -0.68711644 ...  0.06052744  0.34673354
  -0.15073118]], shape=(16, 214), dtype=float32)
logits shape :  (16, 214)
target_logits inside calling:  tf.Tensor(
[[-0.44945803 -0.4488667  -0.49205592 ... -0.52650595 -0.48646772
  -0.5146937 ]
 [-0.517638   -0.5168296  -0.45049834 ... -0.48184887 -0.5136248
  -0.42165264]
 [-0.4231148  -0.49899915 -0.53883344 ... -0.51535726 -0.5015623
  -0.4252881 ]
 ...
 [-0.49695227 -0.44860357 -0.44301444 ... -0.4934584  -0.49325225
  -0.42463243]
 [-0.44819695 -0.45737794 -0.48683867 ... -0.4787514  -0.5754766
  -0.44328845]
 [-0.4024756  -0.4680557  -0.51806355 ... -0.52414614 -0.5134484
  -0.46852335]], shape=(16, 512), dtype=float32)
class_weights :  [[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]
class_weights shape :  (512, 214)
labels :  tf.Tensor(
[17379  2825 15697  3588 14825 14825  7814 13478 19322 14031    12  2851
  3259  3791 17036 11676], shape=(16,), dtype=int32)
labels shape :  (16,)
logits :  tf.Tensor(
[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]], shape=(16, 214), dtype=float32)
logits shape :  (16, 214)

Please help. Where am I making a mistake?

Thanks & Regards.
Pritam