torch.index_select() in TensorFlow?

Hello,

I found the following code as an alternative to the torch.index_select function, but the Python loop makes it very slow. Do you have any suggestions?

    import tensorflow as tf

    def tf_index_select(input_, dim, indices):
        """
        input_ (tensor): input tensor
        dim (int): dimension to select along
        indices (list): list of indices to select
        """
        shape = input_.get_shape().as_list()
        if dim < 0:
            dim += len(shape)  # e.g. -1 -> last dimension
        shape[dim] = 1  # every slice is 1 element wide along `dim`

        tmp = []
        # One tf.slice op per index: this Python loop is the bottleneck
        for idx in indices:
            begin = [0] * len(shape)
            begin[dim] = tf.dtypes.cast(idx, tf.int64)
            tmp.append(tf.slice(input_, begin, shape))
        res = tf.concat(tmp, axis=dim)
        return res

Thank you


Is tf.gather enough for your use case?
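A minimal sketch of that replacement, assuming TensorFlow 2.x with eager execution: `tf.gather(params, indices, axis=...)` picks whole slices along one axis in a single op, which is what `torch.index_select` does, so the loop and `tf.concat` go away. The wrapper name `index_select` and the sample tensor are illustrative only.

```python
import tensorflow as tf

def index_select(input_, dim, indices):
    """Hypothetical helper mirroring torch.index_select(input_, dim, indices).

    tf.gather selects all requested slices along `axis` in one op;
    it also accepts negative axis values such as -1.
    """
    return tf.gather(input_, indices, axis=dim)

x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])

rows = index_select(x, 0, [2, 0])   # rows 2 and 0
cols = index_select(x, -1, [1, 1])  # column 1, twice
print(rows.numpy().tolist())  # [[7, 8, 9], [1, 2, 3]]
print(cols.numpy().tolist())  # [[2, 2], [5, 5], [8, 8]]
```

Two differences from the PyTorch function are worth keeping in mind: `tf.gather` also accepts multi-dimensional `indices` (the output shape then changes accordingly), and an out-of-range index raises an error on CPU but yields zeros on GPU.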

There is also tf.gather_nd.
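Where `tf.gather` selects along a single axis, `tf.gather_nd` takes coordinates: the last dimension of `indices` indexes into the leading dimensions of `params`, so it can pick individual elements or whole sub-slices. A minimal sketch, again assuming TensorFlow 2.x eager mode; the tensor and index values are illustrative.

```python
import tensorflow as tf

x = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])

# Full coordinates (row, col): picks single elements x[0, 2] and x[2, 1]
elems = tf.gather_nd(x, [[0, 2], [2, 1]])

# Partial coordinates (row only): picks whole rows x[2] and x[0]
rows = tf.gather_nd(x, [[2], [0]])

# One column per row: pair each row number with a column index
col_per_row = tf.constant([2, 0, 1])
coords = tf.stack([tf.range(3), col_per_row], axis=1)  # [[0, 2], [1, 0], [2, 1]]
per_row = tf.gather_nd(x, coords)

print(elems.numpy().tolist())    # [3, 8]
print(rows.numpy().tolist())     # [[7, 8, 9], [1, 2, 3]]
print(per_row.numpy().tolist())  # [3, 4, 8]
```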


Perfect! Thank you very much.
