I implemented it as follows:
import tensorflow as tf
def bMatrix(m, lmbd=1600, d=2, dtype=tf.float32):
    """Return the Whittaker smoother matrix ``(I + lmbd * D^T D)^-1``.

    ``D`` is the ``d``-th order finite-difference operator on ``m`` points
    (obtained by differencing the identity matrix along its rows), so
    ``B @ y`` is the Whittaker smooth of a length-``m`` signal ``y``.

    Args:
        m: Signal length; the returned matrix is ``(m, m)``.
        lmbd: Weight of the difference penalty. May be an int or a float.
        d: Order of the difference penalty.
        dtype: Floating-point dtype of the returned matrix.

    Returns:
        An ``(m, m)`` tensor of ``dtype``.
    """
    # Build everything directly in the target float dtype so that a
    # non-integer ``lmbd`` is accepted and the penalty term cannot overflow
    # integer arithmetic before being converted.
    eye = tf.eye(m, dtype=dtype)
    diff_op = tf.experimental.numpy.diff(eye, n=d, axis=0)  # (m - d, m) for d < m
    penalty = tf.linalg.matmul(diff_op, diff_op, transpose_a=True)  # D^T D, (m, m)
    return tf.linalg.inv(eye + lmbd * penalty)
def whittaker1D(m, lmbd=5, d=2, dtype=tf.float32):
    """Build a ``tf.function`` that smooths a single length-``m`` vector.

    Args:
        m: Length of the input vector.
        lmbd: Penalty weight forwarded to ``bMatrix``.
        d: Difference order forwarded to ``bMatrix``.
        dtype: dtype of the smoother matrix and of the expected input.

    Returns:
        A ``tf.function`` taking an ``(m,)`` tensor of ``dtype`` and returning
        ``B @ y`` with ``B = bMatrix(m, lmbd, d, dtype)``.
    """
    # B is computed once here and captured by the traced function.
    B = bMatrix(m, lmbd, d, dtype)

    def W(y):
        return tf.linalg.matvec(B, y)

    # Trailing comma matters: ``(m)`` is a bare int, ``(m,)`` is a shape tuple.
    return tf.function(W, input_signature=(tf.TensorSpec(shape=(m,), dtype=dtype),))
def whittaker2D(m, lmbd=5, d=2, transpose=False, dtype=tf.float32):
    """Build a ``tf.function`` that smooths a 2-D tensor with the m-point smoother.

    ``tf.linalg.matvec(B, y)`` contracts ``B`` with the last axis of ``y``:

    * ``transpose=False``: input ``(N, m)`` -> output ``(N, m)``; each row
      is smoothed.
    * ``transpose=True``: input ``(m, N)`` is transposed to ``(N, m)`` first,
      so each column is smoothed and the result is returned as ``(N, m)``.

    Args:
        m: Length of the axis being smoothed.
        lmbd: Penalty weight forwarded to ``bMatrix``.
        d: Difference order forwarded to ``bMatrix``.
        transpose: Whether the input is laid out as ``(m, N)`` instead of ``(N, m)``.
        dtype: dtype of the smoother matrix and of the expected input.

    Returns:
        A ``tf.function`` with a fixed input signature matching the layout above.
    """
    B = bMatrix(m, lmbd, d, dtype)

    if transpose:
        def W(y):
            return tf.linalg.matvec(B, tf.transpose(y))
        spec = tf.TensorSpec(shape=(m, None), dtype=dtype)
    else:
        def W(y):
            return tf.linalg.matvec(B, y)
        spec = tf.TensorSpec(shape=(None, m), dtype=dtype)

    return tf.function(W, input_signature=(spec,))
def whittakerImage(shape, lmbd=5, d=2, dtype=tf.float32):
    """Build a ``tf.function`` that Whittaker-smooths a batch of images.

    Each channel of each image is smoothed separably:
    * along axis 1 (height) with the ``shape[1]``-point smoother;
    * along axis 2 (width) with the ``shape[2]``-point smoother.
    Per channel, this is ``B_H @ X @ B_W^T``.

    Args:
        shape: Static input shape ``(batch, height, width, channels)``. The
            batch entry may be ``None``; height and width must be known.
        lmbd: Penalty weight forwarded to ``bMatrix``.
        d: Difference order forwarded to ``bMatrix``.
        dtype: dtype of the smoother matrices and of the expected input.

    Returns:
        A ``tf.function`` mapping a tensor of ``shape`` to a tensor of the
        same shape.

    Raises:
        ValueError: If ``shape`` does not have 4 entries, or height/width is
            ``None``.
    """
    # Explicit checks rather than ``assert``, which is stripped under ``-O``.
    if len(shape) != 4:
        raise ValueError(
            f"shape must be (batch, height, width, channels), got {shape!r}"
        )
    height, width = shape[1], shape[2]
    if height is None or width is None:
        raise ValueError(
            f"height and width must be statically known, got {shape!r}"
        )

    # Separate matrices per axis so non-square images are handled correctly.
    b_h = bMatrix(height, lmbd, d, dtype)
    b_w = bMatrix(width, lmbd, d, dtype)

    def W(z):
        # out[b, i, w, c] = sum_h B_H[i, h] * z[b, h, w, c]  (smooth height)
        z = tf.einsum("ih,bhwc->biwc", b_h, z)
        # out[b, h, j, c] = sum_w B_W[j, w] * z[b, h, w, c]  (smooth width)
        # Axis order stays (batch, height, width, channels) throughout, so no
        # reshape/transpose is needed and channels are never mixed together.
        return tf.einsum("jw,bhwc->bhjc", b_w, z)

    return tf.function(W, input_signature=(tf.TensorSpec(shape=shape, dtype=dtype),))
# Image smoother for inputs of shape (None, 28, 28, 1), using the default
# lmbd/d/dtype of whittakerImage.
smoother = whittakerImage(shape=(None, 28, 28, 1))
# NOTE(review): `layer` is not created in this snippet; presumably it is an
# upstream Keras tensor of shape (None, 28, 28, 1) — confirm against caller.
layer = tf.keras.layers.Lambda(smoother)(layer)