I am trying to do Strassen’s matrix multiplication on TensorFlow. I am trying to perform a graph execution using the following code
import tensorflow as tf
@tf.function # The decorator converts `split_tf` into a `Function`.
def split_tf(matrix):
shape = tf.shape(matrix)
shape2 = tf.cast(tf.divide(shape, 2), dtype=tf.int32)
return tf.slice(A, [0, 0], shape2), tf.slice(A, [0, shape2[1]], shape2),
tf.slice(A, [shape2[0], 0], shape2),
tf.slice(A, [shape2[0], shape2[1]], shape2)
@tf.function # The decorator converts `strassen_tf` into a `Function`.
def strassen_tf(x, y):
# Base case when size of matrices is 1x1
if tf.rank(x) == 1:
return tf.math.multiply(x, y)
# Splitting the matrices into quadrants. This will be done recursively
# until the base case is reached.
a, b, c, d = split_tf(x)
e, f, g, h = split_tf(y)
# Computing the 7 products, recursively (p1, p2...p7)
p1 = strassen_tf(a, tf.math.subtract(f, h))
p2 = strassen_tf(tf.math.add(a, b), h)
p3 = strassen_tf(tf.math.add(c, d), e)
p4 = strassen_tf(d, tf.math.subtract(g, e))
p5 = strassen_tf(tf.math.add(a, d), tf.math.add(e, h))
p6 = strassen_tf(tf.math.subtract(b, d), tf.math.add(g, h))
p7 = strassen_tf(tf.math.subtract(a, c), tf.math.add(e, f))
# Computing the values of the 4 quadrants of the final matrix c
c11 = tf.math.add(tf.math.subtract(tf.math.add(p5, p4), p2), p6)
c12 = tf.math.add(p1, p2)
c21 = tf.math.add(p3, p4)
c22 = tf.math.subtract(tf.math.subtract(tf.math.add(p1, p5), p3), p7)
# Combining the 4 quadrants into a single matrix by stacking horizontally and vertically.
c = tf.concat([tf.concat([C11, C12], axis=1), tf.concat([C21, C22], axis=1)], axis=0)
return c
Now when I use the above code to multiply two tensors as below:
n = 2;
A = tf.random.normal([n, n], 0, 1, dtype=tf.float64)
B = tf.random.normal([n, n], 0, 1, dtype=tf.float64)
D = strassen_tf(A, B)
the code stuck at the last statement.
Thanks in advance!