Large memory usage while validating model with workaround

Hi,

I wanted to calculate the maximum absolute error and the percentile error after training a model using the following code:

import numpy as np
import tensorflow as tf

def compute_percentiles(errors, percentiles=(90, 95, 99)):
    return np.percentile(errors, percentiles)

all_errors = []

BATCH_SIZE = 65536

for batch_features, batch_labels in val_dataset:
    predictions = gwd_model(batch_features, training=False)

    batch_labels = tf.cast(batch_labels, tf.float32)

    # This is the workaround
    batch_labels = tf.reshape(batch_labels, predictions.shape)

    print("Shape of batch_labels:", batch_labels.shape)
    print("Shape of predictions:", predictions.shape)

    errors = tf.abs(tf.subtract(batch_labels, predictions))

    max_error = tf.reduce_max(errors)

    print("Max absolute error for this batch:", max_error.numpy())

    all_errors.extend(tf.reshape(errors, [-1]).numpy())

percentile_values = compute_percentiles(all_errors)

print("90th percentile error:", percentile_values[0])
print("95th percentile error:", percentile_values[1])
print("99th percentile error:", percentile_values[2])

The shape of batch_labels is (65536, 1) and the shape of predictions is (65536,). If you comment out the reshape, tf.subtract wants to allocate 16 GB of memory, causing it to run out of GPU memory. If you then bracket the code with a

with tf.device("CPU:0"):

the code quickly exhausts my 128 GB of RAM. If I reduce the batch size to 1024, the code allocates about 64 GB of RAM, but it runs to completion, although it takes about a minute for only 200K features. Why does the code without the reshape use so much memory? With the reshape, the code no longer runs out of memory and finishes in about 0.1 seconds.
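For reference, the shape blow-up can be reproduced with plain NumPy, which follows the same broadcasting rules as TensorFlow (the small n here is a stand-in for the 65536-element batch):

```python
import numpy as np

n = 4  # stand-in for the 65536-element batch

labels = np.zeros((n, 1), dtype=np.float32)      # like batch_labels, shape (n, 1)
predictions = np.zeros((n,), dtype=np.float32)   # like predictions, shape (n,)

# (n, 1) - (n,) broadcasts to (n, n): every label is paired with every prediction
diff = labels - predictions
print(diff.shape)  # (4, 4)
```

With n = 65536 the broadcast result holds 65536 x 65536 float32 values, which is about 16 GB.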

Regards,
GW

Hi @gwiesenekker, when the operands of an arithmetic operation have different shapes, broadcasting takes place: the smaller array is "broadcast" across the larger array so that the two end up with compatible shapes. Here batch_labels has shape (65536, 1) and predictions has shape (65536,), so the subtraction broadcasts both to shape (65536, 65536), i.e. roughly 4.3 billion float32 values, about 16 GB. That is why the code without the reshape uses so much memory. Please refer to this document to know more about broadcasting. Thank you.
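A quick sketch of the problem and two equivalent fixes, using NumPy as a stand-in since it applies the same broadcasting rules as TensorFlow (in TF the second fix would be tf.squeeze):

```python
import numpy as np

labels = np.zeros((5, 1), dtype=np.float32)  # shape (5, 1)
preds = np.zeros((5,), dtype=np.float32)     # shape (5,)

# The problem: (5, 1) - (5,) broadcasts to (5, 5)
print((labels - preds).shape)  # (5, 5)

# Fix 1: reshape labels to match predictions (the workaround in the question)
print((labels.reshape(preds.shape) - preds).shape)  # (5,)

# Fix 2: squeeze the trailing unit dimension instead
print((np.squeeze(labels, axis=-1) - preds).shape)  # (5,)
```

Either fix makes the subtraction elementwise over 65536 values instead of materializing a 65536 x 65536 matrix.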