How to stack tensors in Java

I am trying to stack tensors (all of the same shape) into one tensor. In Python I believe you can simply do tf.stack([my, tensors]).

I am struggling to stack x tensors of shape [1, 480, 640, 3] into one tensor of shape [x, 480, 640, 3].
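A detail that matters here: a stack op adds a new axis instead of growing an existing one. Stacking x inputs of shape [1, 480, 640, 3] therefore gives [x, 1, 480, 640, 3], not [x, 480, 640, 3]. The plain-Java sketch below only does that shape arithmetic and makes no TensorFlow calls. ShapeMath, stackShape and concatShape are hypothetical helper names, and the sketch assumes stacking along axis 0, which is the default for tf.stack.

```java
import java.util.Arrays;

// Hypothetical helper: pure shape bookkeeping, no TensorFlow classes involved.
final class ShapeMath {
    private ShapeMath() {}

    // Shape after stacking n same-shaped inputs along a NEW leading axis (axis 0).
    static long[] stackShape(int n, long... inputShape) {
        long[] out = new long[inputShape.length + 1];
        out[0] = n;
        System.arraycopy(inputShape, 0, out, 1, inputShape.length);
        return out;
    }

    // Shape after concatenating n same-shaped inputs along the EXISTING axis 0.
    static long[] concatShape(int n, long... inputShape) {
        long[] out = inputShape.clone();
        out[0] = inputShape[0] * n;
        return out;
    }

    public static void main(String[] args) {
        int x = 4;
        // Stacking [1, 480, 640, 3] inputs keeps the 1: prints [4, 1, 480, 640, 3]
        System.out.println(Arrays.toString(stackShape(x, 1, 480, 640, 3)));
        // Stacking [480, 640, 3] inputs gives the wanted shape: prints [4, 480, 640, 3]
        System.out.println(Arrays.toString(stackShape(x, 480, 640, 3)));
        // Concatenating [1, 480, 640, 3] inputs on axis 0: prints [4, 480, 640, 3]
        System.out.println(Arrays.toString(concatShape(x, 1, 480, 640, 3)));
    }
}
```

So either give each image the shape [480, 640, 3] before stacking, or concatenate the [1, 480, 640, 3] tensors along axis 0 instead of stacking them.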

I have attempted using the Stack and TensorListStack classes:

    try (EagerSession eagerSession = EagerSession.create())
    {
        Scope scope = new Scope(eagerSession);
        // Fails to compile: tensorsList holds Tensor objects, but Stack.create() expects Operand objects
        Tensor<?> combinedTensor = Stack.create(scope, tensorsList);
    }

Here tensorsList is an ArrayList<Tensor<?>>.
create() requires an Iterable<Operand> rather than an ArrayList, but I am unsure what this means.

Clearly I’m missing something but can’t figure it out…

Edit:
The tensors were originally ByteBuffers in BGR24 format.

You can’t operate on tensors themselves. Tensors are mostly used for feeding a graph, reading its output or looking at the result of an eager operation.

What you need are constants. You can skip creating the tensors entirely by creating constants directly, hypothetically something like this:

    List<ByteBuffer> byteBuffers = ....;

    try (EagerSession eagerSession = EagerSession.create()) {
        Ops tf = Ops.create(eagerSession);

        var images = byteBuffers.stream()
                .map(b -> tf.constant(Shape.of(1, 480, 640, 3), DataBuffers.of(b)))
                .collect(Collectors.toList());
        var stack = tf.stack(images);
    }
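TensorFlow stores dense tensors in row-major order, so stacking same-shaped inputs along axis 0 produces exactly the inputs' bytes laid end to end. Another option is therefore to pack all the BGR24 frames (3 bytes per pixel, 480 x 640 pixels) into one buffer yourself. The sketch below shows only the packing, in plain java.nio. BgrPacking, pack and offset are hypothetical names, and it assumes each input buffer's remaining bytes are exactly one 480 x 640 BGR24 frame.

```java
import java.nio.ByteBuffer;
import java.util.List;

// Hypothetical helper: packs BGR24 frames into one row-major [x, 480, 640, 3] byte buffer.
final class BgrPacking {
    static final int HEIGHT = 480;
    static final int WIDTH = 640;
    static final int CHANNELS = 3;                            // B, G, R: one byte each
    static final int FRAME_BYTES = HEIGHT * WIDTH * CHANNELS; // 921,600 bytes per frame

    private BgrPacking() {}

    // Byte offset of (frame, row, col, channel) in the packed row-major layout.
    // Uses int arithmetic, so it assumes fewer than about 2,300 frames.
    static int offset(int frame, int row, int col, int channel) {
        return ((frame * HEIGHT + row) * WIDTH + col) * CHANNELS + channel;
    }

    // Copies the frames back to back: the same byte layout as stacking them on axis 0.
    static ByteBuffer pack(List<ByteBuffer> frames) {
        // multiplyExact throws instead of silently overflowing for very many frames
        ByteBuffer packed = ByteBuffer.allocate(Math.multiplyExact(frames.size(), FRAME_BYTES));
        for (ByteBuffer frame : frames) {
            ByteBuffer view = frame.duplicate(); // own position/limit, caller's buffer untouched
            if (view.remaining() != FRAME_BYTES) {
                throw new IllegalArgumentException(
                        "expected " + FRAME_BYTES + " bytes per frame, got " + view.remaining());
            }
            packed.put(view);
        }
        packed.flip(); // position 0, limit = total packed bytes
        return packed;
    }
}
```

The packed buffer could then go through DataBuffers.of(...) and tf.constant with Shape.of(x, 480, 640, 3), producing one constant without calling tf.stack at all. That last step is an untested assumption, not something confirmed in this thread.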

Thanks for the clarification on using tensors, Karl!

The above code was pretty much the solution; I just needed to cast the constants to Operand<TUint8> and change Shape.of(1, 480, 640, 3) to Shape.of(480, 640, 3).

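    try (EagerSession eagerSession = EagerSession.create()) {
        Ops tf = Ops.create(eagerSession);

        // The cast widens each Constant<TUint8> to Operand<TUint8>, so images is a List<Operand<TUint8>>.
        // Shape.of(480, 640, 3): tf.stack() adds the leading x dimension itself.
        var images = byteBuffers.stream()
                .map(b -> (Operand<TUint8>) tf.constant(Shape.of(480, 640, 3), DataBuffers.of(b)))
                .collect(Collectors.toList());
        var stack = tf.stack(images); // shape [x, 480, 640, 3]
    }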

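For anyone who may not understand why:
tf.stack() requires an Iterable<Operand<T>>.
Constant implements the Operand interface, but Java generics are invariant, so a List<Constant<TUint8>> is not an Iterable<Operand<TUint8>>. Casting each constant to Operand<TUint8> makes images a List<Operand<TUint8>>, which can be passed to stack().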
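The invariance point can be reproduced without TensorFlow at all. The sketch below uses hypothetical stand-in types (InvarianceDemo, Operand, Constant, constant, stack and withCast are placeholders, not the org.tensorflow classes, and String stands in for TUint8). The stand-in stack only mirrors the Iterable<Operand<T>> parameter type described above.

```java
import java.util.List;
import java.util.stream.Collectors;

// Stand-ins only: these are NOT the org.tensorflow classes, just the same generic shapes.
final class InvarianceDemo {
    interface Operand<T> { T value(); }

    static final class Constant<T> implements Operand<T> {
        private final T value;
        Constant(T value) { this.value = value; }
        @Override public T value() { return value; }
    }

    // Plays the role of tf.constant(...): it returns a Constant, not an Operand.
    static Constant<String> constant(String v) { return new Constant<>(v); }

    // Mirrors only the parameter type of tf.stack: Iterable<Operand<T>>.
    static <T> int stack(Iterable<Operand<T>> values) {
        int count = 0;
        for (Operand<T> ignored : values) {
            count++;
        }
        return count; // stand-in result: how many inputs were "stacked"
    }

    static int withCast(List<String> frames) {
        // Without the cast, images would be a List<Constant<String>> and
        // stack(images) would not compile: generics are invariant, so
        // List<Constant<String>> is not an Iterable<Operand<String>>.
        var images = frames.stream()
                .map(f -> (Operand<String>) constant(f)) // upcast, checked at compile time
                .collect(Collectors.toList());           // List<Operand<String>>
        return stack(images);
    }
}
```

Because Constant<T> implements Operand<T>, the cast is a plain upcast that the compiler verifies statically, so for these types it cannot fail at run time.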


Great! But I’m not sure you should cast a tensor like that. I understand it works, but it might result in errors in other cases. You can either create TUint8 tensors upfront and convert them to constants using tf.constant(tensor), or use tf.dtypes.cast(tensor, TUint8.class) to let TensorFlow deal with the cast and, potentially, the conversion.


I’m not too sure I understand.

tf.constant(Shape shape, ByteDataBuffer byteDataBuffer) returns a Constant<TUint8>, so I’m not trying to cast to TUint8.

Is it that it’s bad to cast Constant to Operand?

Thanks again!

Oh sorry, you’re right. I see now that your cast is not meant to force the correct datatype but to make Stack happy, since it expects a list of Operand<?> and not a list of ? extends Operand<?> as I originally thought… I can’t tell if this is the desired behaviour, but your code is correct. Sorry for the confusion!
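To make the Operand<?> versus ? extends Operand<?> distinction concrete, the sketch below contrasts the invariant parameter type with a hypothetical covariant one. It also shows a cast-free option for the invariant case: declaring the list type explicitly instead of using var, so that collect() infers the wider element type. Everything in it (CastFreeStack, Operand, Constant, constant, stackInvariant, stackCovariant, and String standing in for TUint8) is a hypothetical stand-in, not TensorFlow Java.

```java
import java.util.List;
import java.util.stream.Collectors;

// Stand-ins only: these are NOT the org.tensorflow classes, just the same generic shapes.
final class CastFreeStack {
    interface Operand<T> { T value(); }

    static final class Constant<T> implements Operand<T> {
        private final T value;
        Constant(T value) { this.value = value; }
        @Override public T value() { return value; }
    }

    // Plays the role of tf.constant(...): it returns a Constant, not an Operand.
    static Constant<String> constant(String v) { return new Constant<>(v); }

    // Invariant parameter, like the stack signature described in this thread.
    static <T> int stackInvariant(Iterable<Operand<T>> values) {
        int count = 0;
        for (Operand<T> ignored : values) {
            count++;
        }
        return count;
    }

    // Hypothetical covariant parameter: what a "? extends" signature would accept.
    static <T> int stackCovariant(Iterable<? extends Operand<T>> values) {
        int count = 0;
        for (Operand<T> ignored : values) {
            count++;
        }
        return count;
    }

    // With "? extends", a List<Constant<String>> would be accepted as-is.
    static int viaCovariant(List<String> frames) {
        var images = frames.stream().map(CastFreeStack::constant).collect(Collectors.toList());
        return stackCovariant(images); // images is a List<Constant<String>>
    }

    // With the invariant signature, an explicit target type avoids the cast:
    // collect() then infers List<Operand<String>> instead of List<Constant<String>>.
    static int viaTargetType(List<String> frames) {
        List<Operand<String>> images = frames.stream()
                .map(CastFreeStack::constant)
                .collect(Collectors.toList());
        return stackInvariant(images);
    }
}
```

Applied to the code in this thread, that would mean declaring images as List<Operand<TUint8>> rather than var; whether this compiles against a given TensorFlow Java release is an untested assumption.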
