The speed of training is reduced using a custom method in tensorflow.keras.layers

I’m using and custom layers to solve the bottleneck of data augmentation, but I found that using alone is faster than mixing, I don’t know what’s going on in the custom layers, can someone please tell me?

Thanks in advance!

This is my data augmentation code, mainly to do standardization and resize.

def random_normalization(data, mean, std):
    mean = tf.multiply(mean, tf.random.uniform(shape=(), minval=0.5,maxval=0.9, dtype=tf.float64))
    std = tf.multiply(std, tf.random.uniform(shape=(), minval=0.5,maxval=0.9, dtype=tf.float64))
    return tf.divide((tf.subtract(data, mean)), std)

def random_resize(data):
    def resizing(index, data, choice, enable, new_data, number, overlap):        
        FrontEnd = tf.cond(tf.math.greater_equal(tf.subtract(index, overlap), tf.constant(0)),
                           lambda: tf.subtract(index, overlap),
                           lambda: index)
        BackEnd = tf.cond(tf.math.less(tf.add(tf.add(index, 10),overlap),tf.constant(2000)),
                          lambda: tf.add(tf.add(index, 10),overlap),
                          lambda: index)
        z1 = tf.gather(data, indices=[0], axis=1)
        z1 = tf.gather(z1, indices=tf.range(FrontEnd, BackEnd), axis=0)
        z2 = tf.gather(data, indices=[1], axis=1)
        z2 = tf.gather(z2, indices=tf.range(FrontEnd, BackEnd), axis=0)
        z3 = tf.gather(data, indices=[2], axis=1)
        z3 = tf.gather(z3, indices=tf.range(FrontEnd, BackEnd), axis=0)
        z4 = tf.gather(data, indices=[3], axis=1)
        z4 = tf.gather(z4, indices=tf.range(FrontEnd, BackEnd), axis=0)
        z5 = tf.gather(data, indices=[4], axis=1)
        z5 = tf.gather(z5, indices=tf.range(FrontEnd, BackEnd), axis=0)
        z6 = tf.gather(data, indices=[5], axis=1)
        z6 = tf.gather(z6, indices=tf.range(FrontEnd, BackEnd), axis=0)
        new_data = tf.tensor_scatter_nd_update(new_data, [[number, 0], [number, 1], [number, 2],
                                                          [number, 3], [number, 4], [number, 5]], 
                                               [tf.math.reduce_mean(z1), tf.math.reduce_mean(z2),
                                                tf.math.reduce_mean(z3), tf.math.reduce_mean(z4),
                                                tf.math.reduce_mean(z5), tf.math.reduce_mean(z6)])
        return tf.add(index, 10), data, choice, enable, new_data, tf.add(number, 1), overlap
    choice = tf.random.uniform(shape=(), minval=0,maxval=4,dtype=tf.int32)
    enable = tf.random.uniform(shape=(), minval=0,maxval=1,dtype=tf.float64)
    overlap = tf.random.uniform(shape=(), minval=5,maxval=21,dtype=tf.int32)
    new_data = tf.zeros((200,6), dtype=tf.float64)
    index = tf.constant(0)
    number = tf.constant(0)
    condition = lambda index, data, choice, enable, new_data, number, overlap: tf.less(index, 2000)
    r = tf.while_loop(condition, resizing, loop_vars=(index, data, choice, enable, new_data, number, overlap))
    return r[4]

def normal_resize(data):
    data = tf.reshape(data, (2000,6,1))
    data = tf.image.resize(data, size=[200,6])
    return tf.cast(tf.reshape(data, (200,6)),dtype=tf.float64)

def augmentation(data, labels):
    mean = tf.math.reduce_mean(data,axis=0)
    std = tf.math.reduce_std(data,axis=0)
    data = tf.cond(tf.random.uniform(shape=(), minval=0, maxval=1,dtype=tf.float64) < tf.constant(0.8,dtype=tf.float64), 
                   lambda: random_normalization(data, mean, std), 
                   lambda: tf.divide((tf.subtract(data, mean)), std))
    # 2000 resize to 200
    data = tf.cond(tf.random.uniform(shape=(), minval=0, maxval=1,dtype=tf.float64) < tf.constant(0.8,dtype=tf.float64), 
                   lambda: random_resize(data), 
                   lambda: normal_resize(data))

    return data, labels

Main code, including and model

if __name__ == '__main__':
    trainDS =,2000,6),
    trainDS = (
        .shuffle(1000, reshuffle_each_iteration=False)
        .batch(128, drop_remainder=True)
    input = Input((200,6))
    x = LSTM(64, return_sequences=True)(input)
    output = Dense(1,activation='sigmoid')(x)
    model = Model(input, output)
    model.compile(optimizer='adam', loss='BinaryCrossentropy'), epochs=3)

Then this is the code of my custom layer, although it is a bit cumbersome, it still achieves the result I want.

import tensorflow as tf
from tensorflow.keras.layers import LSTM, Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer
import numpy as np

class CustomLayer(Layer):
    def __init__(self, **kwargs):
    def execute(self, data, batch_size, new_data, _type):
        def _fun(index, data, _type, new_data):
            resized = tf.cond(_type,
                              lambda:augmentation(tf.reshape(tf.gather(data,[index]), (2000,6))),
                              lambda:normal_resize(tf.reshape(tf.gather(data,[index]), (2000,6))))
            values = tf.reshape(resized, (1,-1))[0]
            _Indices = self.createIndices(index)
            new_data = tf.tensor_scatter_nd_update(new_data, _Indices, values)
            return tf.add(index,1), data, _type, new_data
        index = tf.constant(0)
        condition = lambda index, data, _type, new_data: tf.less(index, batch_size)
        r = tf.while_loop(condition, _fun, loop_vars=(index, data, _type, new_data))
        return r[-1]
    def createIndices(self, BatchSizeIndex):
        def loop1(_i, BatchSizeIndex, col_num, _Indices):
            def loop2(_i, _j, BatchSizeIndex, col_num, _Indices):
                _Indices = tf.tensor_scatter_nd_update(_Indices, [[col_num, 0], [col_num, 1], [col_num, 2]], 
                                                        [BatchSizeIndex, _i, _j])
                return _i, tf.add(_j,1), BatchSizeIndex, tf.add(col_num,1), _Indices
            _j = tf.constant(0)
            condition_loop2 = lambda _i, _j, BatchSizeIndex, col_num, _Indices: tf.less(_j, 6)
            r_loop2 = tf.while_loop(condition_loop2, loop2, loop_vars=(_i, _j, BatchSizeIndex, col_num, _Indices))  
            return tf.add(_i,1), BatchSizeIndex, r_loop2[3], r_loop2[4]

        _Indices = tf.zeros((1200,3), dtype=tf.int32)
        col_num = tf.constant(0)
        _i = tf.constant(0)
        condition_loop1 = lambda _i, BatchSizeIndex, col_num, _Indices: tf.less(_i, 200)
        r_loop1 = tf.while_loop(condition_loop1, loop1, loop_vars=(_i, BatchSizeIndex, col_num, _Indices))
        return r_loop1[-1]
    def call(self, images, training):
        batch_size = tf.shape(images)[0]
        new_data = tf.zeros((batch_size, 200, 6), dtype=tf.float64)
        images = tf.cast(images, dtype=tf.float64)
        if training:
            data = self.execute(images, batch_size, new_data, tf.constant(True))
            data = self.execute(images, batch_size, new_data, tf.constant(False))
        return data

The final code can be modified to execute like this.

def augmentation(data):
    return data

if __name__ == '__main__':
    trainDS =,2000,6),
    trainDS = (
        .shuffle(1000, reshuffle_each_iteration=False)
        .batch(128, drop_remainder=True)
    input = Input((2000,6))
    x = CustomLayer()(input)
    x = LSTM(64, return_sequences=True)(x)
    output = Dense(1,activation='sigmoid')(x)
    model = Model(input, output)
    model.compile(optimizer='adam', loss='BinaryCrossentropy'), epochs=3)

Results: Alone spend about 18s, spend about 38s.

The thing I want to clarify is that the use of map in to run augmentation is on the CPU, but if I write augmentation in the Layer, it should theoretically run on the GPU. Why is there such a big gap between the two?

Environment: python3.6, tensorflow2.4.0

I suggest to profile the two pipelines performances with:

I’ve not checked your specific model but generally if your preprocessing pipeline is faster then the inference+backward steps of your model with enough margin you could prepare your data on CPU in parallel when the GPU is “computing” your model.

So, in terms of performance, is using the augmentation pipeline faster than building on the model layer?

I don’t think that we can do this claim in general as the two profiling methods that I’ve linked are always the profiling GT for your specific model and augmentation needs (as It could depend on model size, input size, augmentation complexity etc…).

You can check how was organized some recent STOA augmentations in:

Check also my comment and the related reply at:

/cc @Luke_Wood

Maybe I use too many tf.gather function in my augmentation function, causing the whole performance to be seriously reduced. The loop reduction in the random_resizez method tf.gather helps to speed up the training, but this is only valid for, and is constructed in the layer’s way too slow.
Are there any faster alternatives to tf.gather?

I suggest to profile your code with then mentioned resources so you can investigate the bottlenecks instead of guessing.

Profile help:

import tensorflow as tf
from tensorflow._api.v2.compat.v1 import ConfigProto
from tensorflow._api.v2.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
from tensorflow.keras.layers import LSTM, Dense, Input, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import TensorBoard
import numpy as np
from datetime import datetime
import os
os.system('rm -r logs/*')

class CustomLayer(Layer):
    def __init__(self, **kwargs):
    def call(self, images, training):
        cutmix_augment = self._custom(images)
        return cutmix_augment
    def _custom(self, images):
        images = tf.map_fn(
            lambda x: _resize(*x),
        return images

def _resize(data):
    def resizing(index, number, data, new_data):
        select_range = tf.gather(data, indices=tf.range(index, tf.add(index, 10)), axis=0)
        new_data = tf.tensor_scatter_nd_update(new_data, [[number, 0], [number, 1], [number, 2],
                                                          [number, 3], [number, 4], [number, 5]], 
                                                          tf.cast(tf.reduce_mean(select_range, axis=0),dtype=tf.float64))
        return tf.add(index,10), tf.add(number,1), data, new_data
    new_data = tf.zeros((200,6), dtype=tf.float64)
    index = tf.constant(0)
    number = tf.constant(0)
    condition = lambda index, number, data, new_data: tf.less(index, 2000)
    results = tf.while_loop(condition, resizing, loop_vars=(index, number, data, new_data))
    return results[-1]

if __name__ == '__main__':

    logs = "logs/" +"%Y%m%d-%H%M%S")
    tboard_callback = TensorBoard(log_dir = logs,
                                  histogram_freq = 1,
                                  profile_batch = (1,2))

    trainDS =,2000,6),
    trainDS = (
        .shuffle(1000, reshuffle_each_iteration=False)
        .batch(128, drop_remainder=True)
    input = Input((2000,6))
    x = CustomLayer()(input)
    x = LSTM(32, return_sequences=True)(x)
    output = Dense(1,activation='sigmoid')(x)
    model = Model(input, output)
    model.compile(optimizer='adam', loss='BinaryCrossentropy'), epochs=2, callbacks=[tboard_callback])

I simplified the code, only kept the part that I think is the problem, and ran Profile, but I can’t understand the content of Profile clearly, the download link file and screenshot of logs.

I ran the annotation tf.tensor_scatter_nd_update separately and found that it affects my execution speed.

You have many ops occupancy views as you can see in:

What do you want to do in resize?

In the data of 2000 points, the average is taken every 10 points, so I first establish tf.zeros((200,6)), and then update the content in while_loop.

def _resize(data):   
    new_data = tf.TensorArray(dtype=tf.float64, size=200, dynamic_size=False)
    number = tf.constant(0)    
    for index in tf.range(0,10,2000):
        select_range = tf.gather(data, indices=tf.range(index, tf.add(index, 10)), axis=0)
        new_data = new_data.write(number, tf.cast(tf.reduce_mean(select_range, axis=0),dtype=tf.float64))
        number = tf.add(number,1)
    return new_data.stack()

I changed tf.while_loop to for loop, and the training time changed from 2min to 2s.