Training a model using the GradientTape() API

Hi, I am trying to use the TensorFlow GradientTape() API and need some help modifying my training script for it.
Some more context: I am using an AWS ml.p3.2xlarge GPU instance, whose GPU has 16 GB of memory, while the training data takes 17.8 GB in memory on the notebook instance.
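For a rough sense of scale: when training batch by batch, the GPU typically only has to hold the current batch (plus the model's weights, activations and gradients), not the full 17.8 GB array. The helper below is a hypothetical back-of-the-envelope estimate of the size of one input batch; the shapes in the example are made up rather than taken from this dataset, and the estimate ignores activations and optimizer state.

```python
import numpy as np

def batch_nbytes(sample_shape, batch_size, dtype=np.float32):
    """Bytes needed to hold one input batch (inputs only, no activations)."""
    # batch_size samples, each with prod(sample_shape) values of the given dtype
    return batch_size * int(np.prod(sample_shape)) * np.dtype(dtype).itemsize

# Hypothetical shapes: 100 timesteps x 50 features per sequence, batch of 256.
print(batch_nbytes((100, 50), 256) / 2**20, "MiB")  # 5,120,000 bytes, about 4.9 MiB
```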
Previously I used the from_tensor_slices API to load batches into GPU memory and train on them. Here is the script for that:

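    import os

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras.layers import LSTM, BatchNormalization, Dense, Dropout
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.optimizers import Adam

    # training_dir, validation_dir, model_dir, lr, batch_size, epochs,
    # class_weight and METRICS are defined elsewhere in the script

    # load data (open each .npz archive once and read both arrays from it)
    with np.load(os.path.join(training_dir, 'train.npz')) as train:
        X_train, y_train = train['X'], train['y']
    with np.load(os.path.join(validation_dir, 'val.npz')) as val:
        X_val, y_val = val['X'], val['y']

    # create model: two stacked LSTM layers and a small dense head
    model = Sequential()
    model.add(LSTM(32, input_shape=X_train.shape[1:], return_sequences=True))
    model.add(Dropout(0.2))
    model.add(BatchNormalization())

    model.add(LSTM(32))
    model.add(Dropout(0.2))
    model.add(BatchNormalization())

    model.add(Dense(32, activation='relu'))
    model.add(Dropout(0.2))

    model.add(Dense(1, activation='sigmoid'))  # probability of the positive class

    # compile model
    # `lr` is a deprecated alias of `learning_rate`, and `decay` is rejected by the
    # newer Keras optimizers (TF 2.11+); InverseTimeDecay with decay_steps=1 gives
    # the same lr / (1 + 1e-6 * step) schedule
    lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=lr, decay_steps=1, decay_rate=1e-6)
    model.compile(loss=tf.keras.losses.binary_crossentropy,
                  optimizer=Adam(learning_rate=lr_schedule),
                  metrics=METRICS)

    # slice the in-memory arrays into batches with the tf.data API
    train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(batch_size)
    validation_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(batch_size)

    # train model
    model.fit(train_dataset,
              epochs=epochs,
              class_weight=class_weight,
              validation_data=validation_dataset,
              verbose=2)

    # save model
    model.save(os.path.join(model_dir, "000000001"))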

Instead of from_tensor_slices, I want to use GradientTape(). I was referring to this article: https://saturncloud.io/blog/how-to-accumulate-gradients-in-tensorflow/
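The linked article is about gradient accumulation: computing gradients on several small micro-batches with tf.GradientTape, summing them, and applying one averaged optimizer step, which approximates training with a larger batch while only a micro-batch has to sit on the GPU at a time. Below is a minimal sketch on hypothetical toy data; the shapes, `micro_batch_size`, `accum_steps` and the helper `train_epoch_with_accumulation` are illustrative choices, not values from the script above. Note that the tf.data pipeline still supplies the batches.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data standing in for X_train / y_train:
# 256 sequences, 20 timesteps, 8 features, binary labels.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 20, 8)).astype("float32")
y = rng.integers(0, 2, size=(256, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(20, 8)),
    tf.keras.layers.LSTM(32),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
loss_fn = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

micro_batch_size = 16   # what actually has to fit on the GPU at once
accum_steps = 4         # effective batch size = 16 * 4 = 64

dataset = tf.data.Dataset.from_tensor_slices((X, y)).batch(micro_batch_size)

def train_epoch_with_accumulation(dataset):
    """One epoch; returns the number of optimizer updates applied."""
    accum, n, updates = None, 0, 0
    for x_batch, y_batch in dataset:
        with tf.GradientTape() as tape:
            preds = model(x_batch, training=True)
            loss = loss_fn(y_batch, preds)
        grads = tape.gradient(loss, model.trainable_variables)
        # Add this micro-batch's gradients to the running sum.
        accum = grads if accum is None else [a + g for a, g in zip(accum, grads)]
        n += 1
        if n == accum_steps:
            # Average over the micro-batches, then take one optimizer step.
            optimizer.apply_gradients(
                zip([a / n for a in accum], model.trainable_variables))
            accum, n, updates = None, 0, updates + 1
    if n > 0:  # leftover micro-batches at the end of the epoch
        optimizer.apply_gradients(
            zip([a / n for a in accum], model.trainable_variables))
        updates += 1
    return updates

updates = train_epoch_with_accumulation(dataset)
print(updates)  # 16 micro-batches / 4 accumulation steps -> 4 updates
```

With equal-sized micro-batches, the averaged gradient matches the gradient of one 64-sample batch; layers such as BatchNormalization still only see 16 samples at a time, so the match is approximate for the original model.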

I would really appreciate some help understanding how I can use GradientTape() in my case.
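GradientTape does not replace from_tensor_slices: the tf.data dataset still produces the batches, and the tape takes over from model.fit by recording the forward pass so that tape.gradient can compute the gradients for each batch. A minimal sketch of such a loop, assuming small toy stand-ins for X_train, y_train, class_weight, batch_size and epochs, and turning class_weight into per-sample weights to approximate fit(class_weight=...):

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy stand-ins for the script's variables.
rng = np.random.default_rng(0)
X_train = rng.normal(size=(128, 10, 4)).astype("float32")
y_train = rng.integers(0, 2, size=(128, 1)).astype("float32")
class_weight = {0: 1.0, 1: 3.0}
batch_size = 32
epochs = 2

# The input pipeline is unchanged: from_tensor_slices + batch.
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(batch_size)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(10, 4)),
    tf.keras.layers.LSTM(16, return_sequences=True),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LSTM(16),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
loss_fn = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Lookup table: weight for label 0 at index 0, label 1 at index 1.
weight_table = tf.constant([class_weight[0], class_weight[1]], dtype=tf.float32)

@tf.function
def train_step(x, y):
    # Per-sample weight chosen by each sample's label.
    sample_weight = tf.gather(weight_table, tf.cast(tf.reshape(y, [-1]), tf.int32))
    with tf.GradientTape() as tape:
        preds = model(x, training=True)  # Dropout active, BatchNorm uses batch stats
        loss = loss_fn(y, preds, sample_weight=sample_weight)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

for epoch in range(epochs):
    for x_batch, y_batch in train_dataset:
        loss = train_step(x_batch, y_batch)
    print(f"epoch {epoch + 1}: last batch loss = {float(loss):.4f}")
```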

Thanks and regards,
Priyanshi

Hi @Priyanshi_Jajoo,
Can I ask what you have tried out so far and where you got stuck?
This keras.io page explains nicely why and how to use tf.GradientTape(). Maybe you can have a look at it.