The `scan` method is neat. It lets you iterate over the elements of a dataset while carrying a state from one element to the next.
Consider a simple dataset that produces the values 1, 2, 3, 4, 5.
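A dataset like this can be built in a few ways. The sketch below assumes TensorFlow 2.x with eager execution and uses `tf.data.Dataset.from_tensor_slices` (an illustrative choice, not necessarily how the article's dataset was built). The element dtype matters for `scan` later: a list of Python ints becomes `int32`, matching an initial state of `tf.constant(0)`, whereas `tf.data.Dataset.range` yields `int64`.

```python
import tensorflow as tf

# A dataset whose elements are the scalars 1, 2, 3, 4, 5.
# Python ints become int32 tensors; tf.data.Dataset.range would yield int64 instead.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5])

print(dataset.element_spec)  # shape () and dtype int32

# Convert each NumPy scalar to a plain int for readable printing.
values = [int(x) for x in dataset.as_numpy_iterator()]
print(values)  # [1, 2, 3, 4, 5]
```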
With the `map` method, you can apply the same transformation to every element of this dataset, one element at a time:
```python
dataset.map(lambda x: x * 2)  # 2, 4, 6, 8, 10
```
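A self-contained version of the `map` call, again assuming TensorFlow 2.x in eager mode, shows that each output depends only on its own input element. `map` keeps no memory of earlier elements, which is exactly the gap `scan` fills.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5])

# map applies a stateless function to each element independently.
doubled = dataset.map(lambda x: x * 2)

values = [int(x) for x in doubled.as_numpy_iterator()]
print(values)  # [2, 4, 6, 8, 10]
```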
With `scan`, instead, you can carry information over from previous iterations.
For example, if you want each output to be the sum of the current element and the previous one, you can use the `scan` method:
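```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5])
initial_state = tf.constant(0)

def scan_func(old_state, input_element):
    # Remember the current element so the next call can add it.
    new_state = input_element
    # Add the previous element (the initial state, 0, on the first call).
    output_element = input_element + old_state
    return new_state, output_element

dataset.scan(initial_state, scan_func)  # 1 + 0, 2 + 1, 3 + 2, ...
```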
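To make the semantics explicit, here is a plain-Python model of what `scan` computes, written without TensorFlow. It is only a sketch of the behaviour, not how `tf.data` implements it; `scan_model` and `add_previous` are hypothetical names introduced for illustration. Note that `Dataset.scan` exists as a method in recent TensorFlow 2.x releases; older releases expose the same transformation as `tf.data.experimental.scan`, applied with `dataset.apply(...)`.

```python
from typing import Callable, Iterable, Iterator, Tuple, TypeVar

S = TypeVar("S")  # state type
T = TypeVar("T")  # input element type
U = TypeVar("U")  # output element type

def scan_model(initial_state: S,
               scan_func: Callable[[S, T], Tuple[S, U]],
               elements: Iterable[T]) -> Iterator[U]:
    """Hypothetical plain-Python model of the scan semantics (not TF code)."""
    state = initial_state
    for element in elements:
        # scan_func returns (new_state, output_element); the new state
        # is fed back in as old_state on the next iteration.
        state, output = scan_func(state, element)
        yield output

def add_previous(old_state, input_element):
    # Same logic as scan_func in the article: keep the current element as
    # state and emit the sum of the current and previous elements.
    return input_element, input_element + old_state

result = list(scan_model(0, add_previous, [1, 2, 3, 4, 5]))
print(result)  # [1, 3, 5, 7, 9]
```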
In other words, on each iteration `scan_func` receives the `old_state` carried over from the previous iteration (the `initial_state` on the first call) together with the current `input_element`. From these it returns a `new_state`, which becomes the `old_state` for the next iteration, and an `output_element`, which is what the resulting dataset yields.
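The state does not have to be the previous element. As a further sketch (assuming TensorFlow 2.x and a version where `Dataset.scan` is available; `running_sum` is a hypothetical name), keeping the running total as the state turns `scan` into a cumulative sum:

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5])

def running_sum(old_state, input_element):
    # The state is the total so far; it is both carried over and emitted.
    new_state = old_state + input_element
    output_element = new_state
    return new_state, output_element

# The int32 initial state matches the int32 elements and the returned state.
cumulative = dataset.scan(tf.constant(0), running_sum)

values = [int(x) for x in cumulative.as_numpy_iterator()]
print(values)  # [1, 3, 6, 10, 15]
```

The only change from the previous example is what gets stored in `new_state`: the state's dtype and shape must stay the same on every iteration, while `output_element` can be anything derived from the state and the input.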
I used `scan` in several of my solutions, where it proved very helpful. I suggest reading the other articles and searching their code to see how I used it. I hope this helps you understand a bit more about how to use this great feature.