I’m trying to create a 3-D tensor from indexed rows of a 2-D tensor. For example, assuming I have:
A = tensor(shape=[200, 256]) # 2-D Tensor.
Aidx = tensor(shape=[1000, 10]) # 2-D Tensor holding row indices of A for each of 1000 batches.
I wish to create:
B = tensor(shape=[1000, 10, 256]) # 3-D Tensor with each batch being of dims (10, 256) selected from A.
Right now, I’m doing this in a memory-inefficient manner by calling tf.broadcast_to() and then tf.gather(). This is very fast, but it also takes up a lot of RAM:
A = tf.broadcast_to(A, [1000, A.shape[0], A.shape[1]])
A = tf.gather(A, Aidx, axis=1, batch_dims=1)
Is there a more memory-efficient way of doing the above operation? Naively, one could use a for loop, but that is very compute-inefficient for my use case. Thanks in advance!
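One possible approach, sketched under the assumption that the goal is B[i, j, :] = A[Aidx[i, j], :]: tf.gather() accepts multi-dimensional indices, so gathering along axis 0 of the original 2-D A should produce the (1000, 10, 256) result without first materializing the broadcast (1000, 200, 256) copy. The random tensors below are placeholders for illustration only, not real data.

```python
import tensorflow as tf

# Placeholder inputs with the shapes from the question (assumed values).
A = tf.random.normal([200, 256])                       # 2-D source tensor
Aidx = tf.random.uniform([1000, 10], minval=0, maxval=200,
                         dtype=tf.int32)               # row indices into A

# Gather rows of A directly with the 2-D index tensor.
# Output shape is Aidx.shape + A.shape[1:], i.e. (1000, 10, 256).
B = tf.gather(A, Aidx, axis=0)

print(B.shape)
```

Only A, Aidx, and the output B are held in memory here; the broadcast intermediate is never created.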