Strassen's Matrix Multiplication on TensorFlow

I am trying to implement Strassen’s matrix multiplication in TensorFlow using graph execution, with the following code:

import tensorflow as tf

@tf.function  # The decorator converts `split_tf` into a `Function`.
def split_tf(matrix):
  shape = tf.shape(matrix)
  shape2 = tf.cast(tf.divide(shape, 2), dtype=tf.int32)
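  # Slice the function argument `matrix` (not the global `A`) into quadrants.
  # Parentheses let the returned tuple span several lines.
  return (tf.slice(matrix, [0, 0], shape2),
          tf.slice(matrix, [0, shape2[1]], shape2),
          tf.slice(matrix, [shape2[0], 0], shape2),
          tf.slice(matrix, [shape2[0], shape2[1]], shape2))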

@tf.function  # The decorator converts `strassen_tf` into a `Function`.
def strassen_tf(x, y):
  # Base case when size of matrices is 1x1
  if tf.rank(x) == 1:
      return tf.math.multiply(x, y)

  # Splitting the matrices into quadrants. This will be done recursively
  # until the base case is reached.

  a, b, c, d = split_tf(x)
  e, f, g, h = split_tf(y)

  # Computing the 7 products, recursively (p1, p2...p7)
  p1 = strassen_tf(a, tf.math.subtract(f, h)) 
  p2 = strassen_tf(tf.math.add(a, b), h)
  p3 = strassen_tf(tf.math.add(c, d), e)
  p4 = strassen_tf(d, tf.math.subtract(g, e))
  p5 = strassen_tf(tf.math.add(a, d), tf.math.add(e, h))
  p6 = strassen_tf(tf.math.subtract(b, d), tf.math.add(g, h))
  p7 = strassen_tf(tf.math.subtract(a, c), tf.math.add(e, f))

  # Computing the values of the 4 quadrants of the final matrix c
  c11 = tf.math.add(tf.math.subtract(tf.math.add(p5, p4), p2), p6)
  c12 = tf.math.add(p1, p2)          
  c21 = tf.math.add(p3, p4)
  c22 = tf.math.subtract(tf.math.subtract(tf.math.add(p1, p5), p3), p7)

  # Combining the 4 quadrants into a single matrix by stacking horizontally and vertically.
  c = tf.concat([tf.concat([c11, c12], axis=1), tf.concat([c21, c22], axis=1)], axis=0)

  return c

Now when I use the above code to multiply two tensors as below:

n = 2
A = tf.random.normal([n, n], 0, 1, dtype=tf.float64)
B = tf.random.normal([n, n], 0, 1, dtype=tf.float64)
D = strassen_tf(A, B)

the code gets stuck at the last statement.

Thanks in advance!

First, I would 100% advise against using recursion in a tf.function. I don’t think tf.function will play well with recursion: tf.function needs to know the entire graph at trace time, so it will (I think) trace through the recursion until it has a complete graph. With large matrices this results in REALLY big graphs that are difficult to compile. You should opt for TF control flow instead, e.g. tf.scan, tf.while_loop, etc.
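To make the tracing point concrete, here is a minimal sketch, assuming square matrices whose size is a power of two and whose shapes are fully known when the function is traced. It avoids calling a tf.function recursively: only the entry point `strassen_graph` is decorated, and the recursion lives in a plain Python helper `_strassen` (both names are hypothetical, not from the thread) that branches on `x.shape[0]`, which is a Python int at trace time. Tracing therefore unrolls the whole recursion into one graph.

```python
import tensorflow as tf


def _strassen(x, y):
    """Plain-Python recursion that branches on the *static* shape.

    Assumes square matrices whose size is a power of two; the recursion
    is unrolled while tf.function traces the caller.
    """
    n = x.shape[0]  # a Python int here, not a Tensor
    if n == 1:
        return x * y  # 1x1 base case: element-wise product
    m = n // 2
    a, b, c, d = x[:m, :m], x[:m, m:], x[m:, :m], x[m:, m:]
    e, f, g, h = y[:m, :m], y[:m, m:], y[m:, :m], y[m:, m:]

    # The same 7 products as in the question
    p1 = _strassen(a, f - h)
    p2 = _strassen(a + b, h)
    p3 = _strassen(c + d, e)
    p4 = _strassen(d, g - e)
    p5 = _strassen(a + d, e + h)
    p6 = _strassen(b - d, g + h)
    p7 = _strassen(a - c, e + f)

    # Recombine the quadrants
    c11 = p5 + p4 - p2 + p6
    c12 = p1 + p2
    c21 = p3 + p4
    c22 = p1 + p5 - p3 - p7
    return tf.concat([tf.concat([c11, c12], axis=1),
                      tf.concat([c21, c22], axis=1)], axis=0)


@tf.function  # only the entry point is a tf.function; it never calls itself
def strassen_graph(x, y):
    return _strassen(x, y)


A = tf.random.normal([4, 4], 0, 1, dtype=tf.float64)
B = tf.random.normal([4, 4], 0, 1, dtype=tf.float64)
D = strassen_graph(A, B)
```

For n = 2^k the traced graph contains 7^k leaf multiplications plus all the additions and concats, and each new input shape triggers a retrace, which is why large matrices lead to very big graphs. The sketch would not work with unknown dimensions (e.g. an `input_signature` containing `None`), because `x.shape[0]` would then be `None`.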

Second, if you add some debug prints and inspect the rank, you’ll see that it always remains 2. Your comment says the base case is a 1 x 1 matrix, but tf.rank returns the number of dimensions of the Tensor, which is 2 for any matrix, including a 1 x 1 one. What you’re looking for is tf.size(x) == 1, which checks the number of elements. (tf.linalg.matrix_rank is a different notion again: it returns the linear-algebraic rank, which is 1 for many larger matrices too, so tf.linalg.matrix_rank(x) == 1 is not a reliable 1 x 1 test.) Changing that and commenting out the tf.function decorators makes the function evaluate correctly. If you want to use tf.function, try to rework the recursion to use native TensorFlow control flow.
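A quick way to see the difference between these three notions on small constant matrices (the variable names are just for illustration):

```python
import tensorflow as tf

one_by_one = tf.constant([[3.0]])
print(tf.rank(one_by_one).numpy())  # 2: number of dimensions, not the matrix size
print(tf.size(one_by_one).numpy())  # 1: number of elements

# Linear-algebraic rank is yet another notion: this 2x2 matrix has
# matrix rank 1, so "matrix_rank == 1" does not mean "1x1".
diag = tf.constant([[1.0, 0.0], [0.0, 0.0]])
print(tf.linalg.matrix_rank(diag).numpy())
```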


Thanks a lot, @Sean_Moriarity, for the detailed reply. My mistake about the rank. I do have two questions, though.

  1. As I understand it, tf.function uses graphs (operations + tensors) for the function’s computations. If I don’t use tf.function, how is the graph generated? And if no graph is generated, how does the code run faster, given that graphs are optimized internally, e.g. by running independent operations in parallel? (This is my understanding; I may be wrong.) Do tf.scan and tf.while_loop also create graphs?

  2. The code runs fine after removing the decorators, i.e. eagerly. How can I run it non-eagerly, using graphs?

Yes, recursion is not supported. We already had a ticket on this specific aspect of your issue:


You are correct: tf.function builds a graph. tf.scan and tf.while_loop are native TensorFlow operations that are included in the graph, so you can use them within tf.function. Here’s an internal discussion on TF control flow: https://www.youtube.com/watch?v=IzKXEbpT9Lg
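A minimal sketch of this: inside a tf.function, tf.while_loop is recorded in the traced graph as a loop op instead of being unrolled by Python. The function `sum_of_squares` is a made-up example, not from the thread; the graph is inspected with `get_concrete_function(...).graph.get_operations()`.

```python
import tensorflow as tf


@tf.function
def sum_of_squares(n):
    """Returns 0^2 + 1^2 + ... + (n-1)^2 using a graph-level loop."""
    def cond(i, acc):
        return i < n

    def body(i, acc):
        return i + 1, acc + i * i

    _, acc = tf.while_loop(cond, body, loop_vars=(tf.constant(0), tf.constant(0)))
    return acc


print(sum_of_squares(tf.constant(4)).numpy())  # 14

# Inspect the traced graph: the loop shows up as a (Stateless)While op.
concrete = sum_of_squares.get_concrete_function(tf.TensorSpec([], tf.int32))
print(sorted({op.type for op in concrete.graph.get_operations()}))
```

Because the loop is a graph op, the number of iterations can depend on a runtime tensor (`n` here) without retracing, which is what Python-level recursion cannot offer.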


Okay, got it. That means implementations of recursive linear-algebra algorithms like Strassen’s and other matrix inversion techniques will have to wait until recursion support is released.

Thanks @Bhack for the pointer.