Hello, I am quite confused about the new_axis_mask parameter of the tf.strided_slice() function. The official documentation says:
"If the ith bit of new_axis_mask is set, then begin, end, and stride are ignored and a new length 1 dimension is added at this point in the output tensor.

For example, foo[:4, tf.newaxis, :2] would produce a shape (4, 1, 2) tensor."
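To make the documented rule concrete, here is a minimal pure-Python sketch that only computes output shapes and does not call TensorFlow. The helper strided_slice_shape is hypothetical, not a TensorFlow API. It models only new_axis_mask with explicit begin/end/strides (no begin_mask, end_mask, ellipsis_mask or shrink_axis_mask), and it assumes out-of-range indices are clamped the way Python slices clamp them. It encodes one reading of "at this point": bit i refers to position i of begin/end/strides, and a new axis uses up that position without consuming an input dimension. The documentation example implies foo has rank 2; the shape (5, 3) below is an arbitrary assumption.

```python
def strided_slice_shape(input_shape, begin, end, strides, new_axis_mask=0):
    """Hypothetical helper: output shape of a strided slice.

    Models only new_axis_mask; all other masks are assumed to be 0.
    """
    out = []
    dim = 0  # index of the next input dimension to be sliced
    for i, (b, e, s) in enumerate(zip(begin, end, strides)):
        if new_axis_mask & (1 << i):
            # Bit i set: begin[i], end[i] and strides[i] are ignored and a
            # length-1 axis is inserted; no input dimension is consumed.
            out.append(1)
        else:
            # Slice input dimension `dim`; Python slicing clamps b and e.
            out.append(len(range(input_shape[dim])[b:e:s]))
            dim += 1
    # Input dimensions not covered by any begin/end/strides entry are kept whole.
    out.extend(input_shape[dim:])
    return tuple(out)


# foo[:4, tf.newaxis, :2] for an assumed foo of shape (5, 3):
# spec positions 0, 1, 2 -> slice, new axis (bit 1 = 0b010), slice.
print(strided_slice_shape((5, 3), [0, 0, 0], [4, 0, 2], [1, 1, 1],
                          new_axis_mask=0b010))  # (4, 1, 2)
```

Under this reading, the values at a masked position are placeholders, and every later entry of begin/end/strides applies to the next input dimension rather than the one at the same index.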
I did some tests:
import tensorflow as tf
import numpy as np
arr = np.random.rand(4,6,7)
out1 = tf.strided_slice(arr, begin=[1,1,1], end=[4,5,6], strides=[1,1,1])
print(out1.shape)  # TensorShape([3, 4, 5])
So, according to my understanding, the shape of the result when using new_axis_mask:
out2 = tf.strided_slice(arr, begin=[1,1,1], end=[4,5,6], strides=[1,1,1], new_axis_mask=1)
print(out2.shape)
should be TensorShape([1, 3, 4, 5]), because we add an extra dimension to the output at the first axis. However, the result is TensorShape([1, 3, 5, 7]).
Can anyone explain to me how the new_axis_mask parameter of the tf.strided_slice() function works?
Code is here: Google Colab