Hi, I am trying to measure the runtime of a JIT-compiled model inference using `time.time()`. Here is an example snippet with a random model:
```python
import time

import tensorflow as tf

# Build a small random model: 15 stacked 3x3 convolutions.
model = tf.keras.Sequential()
layers = 15
for _ in range(layers):
    model.add(tf.keras.layers.Conv2D(3, 3, data_format="channels_first"))

model_input = tf.random.uniform(shape=(12, 3, 256, 256))

# Wrap the model in a tf.function and compile it with XLA.
xla_fn = tf.function(model, jit_compile=True)

iterations = 50
for i in range(iterations):
    start_time = time.time()
    xla_fn(model_input)
    end_time = time.time()
    print(f"Iteration {i} time: {1000 * (end_time - start_time)}")  # milliseconds
```
The resulting output is:

```
Iteration 0 time: 6884.105682373047
Iteration 1 time: 1.5869140625
Iteration 2 time: 1.2054443359375
Iteration 3 time: 1.1410713195800781
Iteration 4 time: 1.1148452758789062
Iteration 5 time: 1.1096000671386719
Iteration 6 time: 1.1081695556640625
Iteration 7 time: 1.1057853698730469
Iteration 8 time: 1.1026859283447266
Iteration 9 time: 1.0957717895507812
Iteration 10 time: 1.115560531616211
Iteration 11 time: 1.100301742553711
Iteration 12 time: 1.1072158813476562
Iteration 13 time: 1.0993480682373047
Iteration 14 time: 1.102447509765625
Iteration 15 time: 1.100778579711914
Iteration 16 time: 1.0983943939208984
Iteration 17 time: 1.111745834350586
Iteration 18 time: 1.0929107666015625
Iteration 19 time: 1.0929107666015625
Iteration 20 time: 1.10626220703125
Iteration 21 time: 1.1260509490966797
Iteration 22 time: 1.1012554168701172
Iteration 23 time: 1.0957717895507812
Iteration 24 time: 1.1103153228759766
Iteration 25 time: 1.0924339294433594
Iteration 26 time: 1.0981559753417969
Iteration 27 time: 1.0950565338134766
Iteration 28 time: 1.1153221130371094
Iteration 29 time: 1.0952949523925781
Iteration 30 time: 1.100301742553711
Iteration 31 time: 1.1081695556640625
Iteration 32 time: 1.0983943939208984
Iteration 33 time: 1.0955333709716797
Iteration 34 time: 1.0974407196044922
Iteration 35 time: 1.1048316955566406
Iteration 36 time: 13.811588287353516
Iteration 37 time: 18.66292953491211
Iteration 38 time: 18.66936683654785
Iteration 39 time: 18.68462562561035
Iteration 40 time: 18.68152618408203
Iteration 41 time: 19.069910049438477
Iteration 42 time: 17.874717712402344
Iteration 43 time: 18.11075210571289
Iteration 44 time: 18.109560012817383
Iteration 45 time: 18.10741424560547
Iteration 46 time: 18.124818801879883
Iteration 47 time: 18.130064010620117
Iteration 48 time: 17.90904998779297
Iteration 49 time: 17.884492874145508
```
From my understanding, I would expect the first iteration to take longer, since the function needs to be traced and compiled. However, I'm confused about why the runtime increases significantly from iteration 36 onward. Is there an issue with the way I am measuring the runtime? If not, what would explain the increase?
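One thing I wondered: if the GPU kernels are dispatched asynchronously, then `time.time()` around the call might only be measuring enqueue time until some internal queue fills up. Here is a minimal sketch of how I would force a sync before stopping the clock; `timed_call` and the `sync` hook are my own names (not TensorFlow API), and I use a stand-in callable here instead of the real `xla_fn`:

```python
import time

def timed_call(fn, *args, sync=None, repeats=5):
    """Time fn(*args) with perf_counter; optionally block on the result
    via the sync hook before stopping the clock."""
    times_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn(*args)
        if sync is not None:
            sync(result)  # force device-to-host sync before reading the clock
        times_ms.append((time.perf_counter() - start) * 1000.0)
    return times_ms

# Stand-in callable in place of xla_fn(model_input).
durations = timed_call(lambda x: x * 2, 21)
```

With the real model I would presumably pass `sync=lambda r: r.numpy()` so the measured wall-clock time includes waiting for the GPU result. Would that be the right way to check this?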
I am running this on a P40 GPU, by the way. Thanks in advance.