Hello,
I have found the following code as alternative to the torch.select_index function, but the loop makes the process very slow. Do you have any suggestion ?
def tf_index_select(self, input_, dim, indices):
"""
input_(tensor): input tensor
dim(int): dimension
indices(list): selected indices list
"""
start = time.time()
shape = input_.get_shape().as_list()
if dim == -1:
dim = len(shape)-1
shape[dim] = 1
tmp = []
for idx in indices:
begin = [0]*len(shape)
begin[dim] = tf.dtypes.cast(idx, tf.int64)
tmp.append(tf.slice(input_, begin, shape))
res = tf.concat(tmp, axis=dim)
return res
Thank you