In this code, I try to optimize the phase of a hologram. I used Gradient Descent Algorithm through Tensorflow. I used Adam optimizer. the problem I am facing is the following:
In the Gradient descent algorithm, we must compute the partial derivatives with respect to a certain variables. In here, the variable is a tensor (tf.Variable) that have the same shape of the input image (i.e. 256x256). I analytically computed the partial derivative of the loss function. it is given in the code. I also tried to implement an autodifferentiation method base on taensorflow GradientTap. the problem is that it doesn’t converge as the manually computed one.
The loss function I used is :
MSE (I_predicted - I_target).
I_predicted = Abs( FFT( np.exp(1j*phase) ),
the autodifferentiation should handle the derivation with the presence of the FFT(Complex function). I have seen some people succeded carrying the autodifferentiation. I cannot achieve that. Please Help. These people used Pytorch instead of Tensorflow. Also, they didn’t exlicitly give a relation between the I_predicted and the variable (i.e. phase).
Thank you for your help.
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from scipy.fftpack import fft2, fftshift, ifft2
imgBGR = cv2.imread('CirCTria.png')
imgGray = cv2.cvtColor(imgBGR, cv2.COLOR_BGR2GRAY)
def FFT(inputArray):
shift = tf.cast( tf.signal.ifftshift( inputArray ), tf.complex64 )
FourierTransform = tf.signal.fft2d( shift )
unshift = tf.signal.fftshift( FourierTransform )
return unshift
def iFFT(inputArray):
shift = tf.cast( tf.signal.ifftshift( inputArray ), tf.complex64 )
FourierTransform = tf.signal.ifft2d( shift )
unshift = tf.signal.fftshift( FourierTransform )
return unshift
def SGDHoloCompute(inputImg, nEpochs, initialPhase='FFT', lr=0.5):
optimizer = tf.keras.optimizers.Adam(lr) ### Optimizer based on ADAM
train_loss_results = []
train_psnr_results = []
h,w = inputImg.shape
inputImg = np.fliplr(inputImg)/np.amax(inputImg) ### Normalize image Array
phase = tf.Variable( np.random.rand(h,w) *2*np.pi , name = 'Phase', dtype = tf.float32)
for epoch in range(nEpochs):
epoch_loss_avg = tf.keras.metrics.Mean()
epoch_psnr_avg = tf.keras.metrics.Mean()
holo = tf.math.exp(1.j * tf.dtypes.cast(phase, tf.complex64))
reconstruction = iFFT(holo)
phaseImagePlane = tf.math.angle(reconstruction) # Phase image plane
reconIntensity = tf.math.abs(reconstruction) **2
I_t = inputImg
I_r = reconIntensity / tf.math.reduce_mean(reconIntensity, axis = [0, 1], keepdims = True) * np.mean(I_t, axis = (0, 1))
I_r = tf.clip_by_value(I_r, 0.0, 1.0)
I_rPSNR = I_r[:,:, np.newaxis]
I_tPSNR = I_t[:,:, np.newaxis]
phase_mod = tf.experimental.numpy.mod( phase , 2*np.pi)
## Define Loss
mse = tf.keras.losses.MeanSquaredError()
loss_mse = mse(I_t, I_r)
loss_psnr = tf.reduce_mean(tf.image.psnr(I_tPSNR, I_rPSNR, max_val = 1))
################ Please uncomment this pat for MANUAL DIFFERENTIATION GRADIENT DESCENT
# delta_f = 2 * tf.dtypes.cast((I_r - I_t), dtype = tf.complex64) * 2 *reconstruction/ h / w
# delta = FFT(delta_f)
# recon_grad = - 1.j * tf.math.exp(-1.j * tf.dtypes.cast(phase, dtype = tf.complex64)) * delta
# recon_const = tf.math.real(recon_grad)
# total_grad = [recon_const]
############### Please Un comment this part for AUTOMATIC DIFFERENTIATION FROM TENSORFLOW
with tf.GradientTape() as tape:
tape.watch(phase)
Lossfunction = tf.math.l2_normalize( tf.abs(iFFT( tf.exp(tf.convert_to_tensor(1j, dtype=tf.complex64)* tf.cast(phase, dtype=tf.complex64 ) )))**2 - I_t )
grads = tape.gradient(Lossfunction, [phase])
total_grad = grads
# Update Gradient Descent
optimizer.apply_gradients(zip(total_grad, [phase]))
# Track the loss and PSNR
epoch_loss_avg(loss_mse)
epoch_psnr_avg(loss_psnr)
train_loss_results.append(epoch_loss_avg.result())
train_psnr_results.append(epoch_psnr_avg.result())
print("Epoch: {:03d} MSE Loss: {:.6f} PSNR: {:.2f}".format(epoch, epoch_loss_avg.result(), epoch_psnr_avg.result()))
plt.imshow(np.flipud(np.abs(fftshift( fft2( fftshift( 1*np.exp(1j*phase.numpy() )) ) )))**2)
plt.pause(0.05)
return phase_mod
if __name__ == "__main__":
phase = SGDHoloCompute(imgGray, 100, lr=0.4)