When I use tf.signal.irfft
in TensorFlow 2.9, I experience a continuous increase in GPU memory usage. The GPU I'm using is an NVIDIA GeForce 4090. However, the same code doesn't exhibit this issue when running TensorFlow 2.4 on a GeForce 2080.
Hi @lucky_hu
Welcome to the TensorFlow Forum!
Could you please retry with the latest TensorFlow version 2.13 or 2.14 and let us know if the issue still persists. Thank you.
I cannot use TensorFlow version 2.14. When I use tf.keras.layers.LayerNormalization(axis=[-1, -2])
, the program shows an error: ‘cuDNN launch failure: input shape ([1, 1992, 257, 1]).’ The original input shape is [8, 249, 257, 1].
I tried 2.14 and still couldn’t solve my problem
Could you please share the reproducible code to replicate the error and understand the issue? Thank you.
class stft(keras.layers.Layer):
    """Keras layer computing the STFT of a 1-D time signal.

    The signal is split into 50%-overlapping frames of ``block_len``
    samples, each frame is multiplied by a Hann window, and an rFFT is
    taken.  Depending on ``mode`` the layer returns either
    ``[magnitude, phase]`` (``mode == "mag_pha"``) or ``[real, imag]``.
    """

    def __init__(self, block_len=512, mode="mag_pha"):
        super().__init__()
        self.block_len = block_len
        # Hann analysis window shared by every frame.
        self.win = tf.signal.hann_window(block_len)
        self.mode = mode

    def call(self, x):
        # Hop size is half the frame length -> 50% overlap.
        windowed = self.win * tf.signal.frame(x, self.block_len, self.block_len // 2)
        spectrum = tf.signal.rfft(windowed)
        if self.mode == "mag_pha":
            return [tf.math.abs(spectrum), tf.math.angle(spectrum)]
        return [tf.math.real(spectrum), tf.math.imag(spectrum)]
class ifftLayer(keras.layers.Layer):
    """Keras layer inverting an rFFT spectrum back to windowed time frames.

    Accepts a two-element input ``x``: either ``[magnitude, phase]``
    (``mode == "mag_pha"``) or ``[real, imag]``.  The complex spectrum is
    reassembled, inverse-transformed with ``tf.signal.irfft``, and each
    resulting frame is multiplied by a Hann synthesis window.
    """

    def __init__(self, block_len=512, mode="mag_pha"):
        super().__init__()
        self.block_len = block_len
        # Hann synthesis window applied after the inverse transform.
        self.win = tf.signal.hann_window(block_len)
        self.mode = mode

    def call(self, x):
        first, second = x[0], x[1]
        if self.mode == "mag_pha":
            # Rebuild the complex spectrum from magnitude and phase.
            complex_spec = tf.cast(first, tf.complex64) * tf.exp(
                1j * tf.cast(second, tf.complex64)
            )
        else:
            # Inputs already are the real and imaginary parts.
            complex_spec = (
                tf.cast(first, tf.complex64) + 1j * tf.cast(second, tf.complex64)
            )
        return tf.signal.irfft(complex_spec) * self.win
def Spon_DRC_NET_Block_Concat_T(batch_size=256, block_len=512, gain=None, num=16, epoch=2, input_norm='iLN'):
    """Build the speech-enhancement model: STFT magnitude in, enhanced time signal out.

    Args:
        batch_size: fixed batch dimension used for the Input layer and reshape.
        block_len: STFT frame length; the frequency axis is block_len // 2 + 1.
        gain: forwarded to the `mk_mag_mask` helper (semantics defined by that
            helper elsewhere in this file — TODO confirm).
        num: number of convolution channels / recurrent units.
        epoch: forwarded to `DRC_Block_T` (semantics defined elsewhere).
        input_norm: 'iLN' for instance LayerNormalization over the last two
            axes, 'BN' for BatchNormalization.

    Returns:
        A compiled-free `tf.keras.Model` mapping the raw time signal to the
        enhanced time signal.

    Raises:
        ValueError: if `input_norm` is not 'iLN' or 'BN' (the original code
            silently fell through and crashed later with NameError).
    """
    # Input layer for the raw time signal; fixed batch, variable length.
    time_data = tf.keras.layers.Input(batch_shape=(batch_size, None))
    # Magnitude spectrogram only; phase is discarded here (mask-based method).
    mag, _ = stft()(time_data)
    # Add a trailing channel axis: [batch, frames, freq, 1].
    mag = tf.keras.layers.Lambda(
        reshape, arguments={'axis': [batch_size, -1, block_len // 2 + 1, 1]}
    )(mag)

    # --- encoder input normalization ---
    if input_norm == "iLN":
        input_complex_spec = tf.keras.layers.LayerNormalization(axis=[-1, -2])(mag)
    elif input_norm == "BN":
        input_complex_spec = tf.keras.layers.BatchNormalization()(mag)
    else:
        raise ValueError(f"input_norm must be 'iLN' or 'BN', got {input_norm!r}")

    # --- encoder ---
    # NOTE(review): kernel (2, 5) with "valid" padding implies causal padding
    # [1, 0] on time and [0, 2] on frequency is expected upstream — confirm.
    conv_1 = tf.keras.layers.Conv2D(num, (2, 5), (1, 2), padding="valid", name="CONV")(input_complex_spec)
    bn_1 = tf.keras.layers.LayerNormalization(axis=[-1, -2])(conv_1)
    out_1 = PReLU(shared_axes=[1, 2])(bn_1)
    out_1 = DRC_Block_T(
        numUnits=num, batch_size=batch_size, L=-1, width=127, channel=num,
        epoch=epoch, causal=True,
    )(out_1)

    # --- decoder: 1x1 conv, skip concat, transposed conv back to 1 channel ---
    c_out_1 = tf.keras.layers.Conv2D(num, (1, 1), (1, 1), padding='same')(out_1)
    skipcon_5 = tf.keras.layers.concatenate([c_out_1, out_1])
    deconv_5 = tf.keras.layers.Conv2DTranspose(
        1, (2, 5), (1, 2), padding="valid", use_bias=False, name="DCONV4"
    )(skipcon_5)

    # Sigmoid bounds the mask to (0, 1); it is applied to the noisy spectrum.
    output_mask = tf.keras.activations.sigmoid(deconv_5)
    enh_spec = tf.keras.layers.Lambda(mk_mag_mask, arguments={'gain': gain})([time_data, output_mask])
    # Inverse STFT of masked spectrum, then overlap-add back to a time signal.
    enh_frame = ifftLayer()(enh_spec)
    enh_time = tf.keras.layers.Lambda(overlapAddLayer, name='enhanced_time')(enh_frame)

    model = tf.keras.models.Model(time_data, enh_time)
    model.summary()
    return model