Hello
I’m trying to train a model that separates two speakers from a noisy, reverberant mixed audio file. As a first sanity check, I train the model on a single mixed file and test it on that same file, to verify the model can perform the task at all. However, this is not the case: the two generated masks produce practically the same output, differing only in perceived loudness (both speakers remain audible in each output). Since I train and test on the same file, I would expect near-perfect separation. I’ve provided the code for the architecture, loss function, and training pipeline below. One of the things I tried (without success) is adding the phase information as an input to the model.
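For reference, this is a minimal version of the check I run after training to see how much the two masks actually differ (the checkpoint path is a placeholder, and the random input stands in for the real compressed mixture magnitude):

import torch
from cruseModel import Cruse

model = Cruse()
model.load_state_dict(torch.load('checkpoint_epoch100.pth', map_location='cpu'))  # placeholder path
model.eval()
inputs_mag = torch.randn(1, 1, 100, 257).abs() ** 0.3  # stand-in for the compressed mixture magnitude
with torch.no_grad():
    mask1, mask2 = model(inputs_mag)
print('mean |mask1 - mask2|:', (mask1 - mask2).abs().mean().item())  # near zero, matching the symptom described above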
Thanks in advance for the help
MODEL ARCHITECTURE
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import math
class Cruse(nn.Module):
    def __init__(self):
        super().__init__()
        kernel_size = [2, 3]
        stride = [1, 2]
        # Encoder: each conv roughly halves the frequency dimension
        self.conv1 = nn.Conv2d(1, 16, kernel_size=kernel_size, stride=stride)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=kernel_size, stride=stride)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=kernel_size, stride=stride)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=kernel_size, stride=stride)
        feat_size = 1920  # 128 channels x 15 frequency bins after the encoder (for 257-bin input)
        self.gru1 = nn.GRU(input_size=feat_size // 4, hidden_size=feat_size // 4, num_layers=1, batch_first=True)
        self.gru2 = nn.GRU(input_size=feat_size // 4, hidden_size=feat_size // 4, num_layers=1, batch_first=True)
        self.gru3 = nn.GRU(input_size=feat_size // 4, hidden_size=feat_size // 4, num_layers=1, batch_first=True)
        self.gru4 = nn.GRU(input_size=feat_size // 4, hidden_size=feat_size // 4, num_layers=1, batch_first=True)
        self.feat_size = feat_size
        # Decoder: mirrors the encoder; the final layer outputs one mask per speaker
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=kernel_size, stride=stride)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=kernel_size, stride=stride)
        self.deconv3 = nn.ConvTranspose2d(32, 16, kernel_size=kernel_size, stride=stride, output_padding=(0, 1))
        self.deconv4 = nn.ConvTranspose2d(16, 2, kernel_size=kernel_size, stride=stride)
        # Depthwise 1x1 convolutions on the skip connections
        self.skip1 = nn.Conv2d(128, 128, kernel_size=1, groups=128)
        self.skip2 = nn.Conv2d(64, 64, kernel_size=1, groups=64)
        self.skip3 = nn.Conv2d(32, 32, kernel_size=1, groups=32)
        self.skip4 = nn.Conv2d(16, 16, kernel_size=1, groups=16)
        self.activation = nn.LeakyReLU(negative_slope=0.03, inplace=True)
        self.activation_outlayer = nn.Sigmoid()

    def forward(self, x):
        encoded = []
        x = self.activation(self.conv1(x))
        encoded.append(x)
        x = self.activation(self.conv2(x))
        encoded.append(x)
        x = self.activation(self.conv3(x))
        encoded.append(x)
        x = self.activation(self.conv4(x))
        encoded.append(x)
        # Flatten (channels, freq) per frame and run four parallel GRUs over time
        gru_in = x.permute(0, 2, 1, 3)
        input_size = gru_in.shape
        gru_in = gru_in.reshape(gru_in.shape[0], gru_in.shape[1], -1)
        feat_size = self.feat_size
        gru_out1, _ = self.gru1(gru_in[:, :, :feat_size // 4])
        gru_out2, _ = self.gru2(gru_in[:, :, feat_size // 4:feat_size // 4 * 2])
        gru_out3, _ = self.gru3(gru_in[:, :, feat_size // 4 * 2:feat_size // 4 * 3])
        gru_out4, _ = self.gru4(gru_in[:, :, feat_size // 4 * 3:feat_size // 4 * 4])
        gru_out = torch.concat((gru_out1, gru_out2, gru_out3, gru_out4), 2)
        gru_out = gru_out.reshape(input_size)
        x = gru_out.permute(0, 2, 1, 3)
        # Decoder with skip connections from the encoder
        x = x + self.skip1(encoded[-1])
        x = self.activation(self.deconv1(x))
        x = x + self.skip2(encoded[-2])
        x = self.activation(self.deconv2(x))
        x = x + self.skip3(encoded[-3])
        x = self.activation(self.deconv3(x))
        x = x + self.skip4(encoded[-4])
        mask = self.activation_outlayer(self.deconv4(x))
        mask1 = mask[:, [0], :, :]
        mask2 = mask[:, [1], :, :]
        return mask1, mask2  # Two separate masks, one per speaker
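A quick shape sanity check I run on the model (assuming n_fft=512, i.e. 257 frequency bins, which gives the flattened GRU feature size of 128 * 15 = 1920):

import torch

model = Cruse()
x = torch.randn(1, 1, 100, 257)  # (batch, channel, frames, freq bins)
mask1, mask2 = model(x)
print(mask1.shape, mask2.shape)  # both torch.Size([1, 1, 100, 257])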
LOSS FUNCTION
import torch
import torch.nn as nn
class Signal_Approximation_Loss(nn.Module):
    def __init__(self, power=0.3, normalization=False):
        super(Signal_Approximation_Loss, self).__init__()
        self.power = power
        self.normalization = normalization

    def power_compression(self, x):
        # mag = torch.clamp(torch.abs(x), min=1e-10, max=1e3)
        mag = torch.abs(x)
        mag_pc = torch.pow(mag, self.power)
        angle = torch.angle(x)
        phase = torch.exp(1j * angle)
        x = mag_pc * phase
        return x

    def complex_mse_loss(self, pred, target):
        mse_numerator = torch.linalg.norm(target - pred) ** 2
        mse_denominator = torch.linalg.norm(target) ** 2 if self.normalization else torch.numel(target)
        # mse_denominator = torch.clamp(torch.linalg.norm(target)**2, min=1e-8) if self.normalization else torch.numel(target)  # avoid 0 in denominator
        return mse_numerator / mse_denominator

    @staticmethod
    def replace_denormals(x: torch.Tensor, threshold=1e-10):
        # Avoid NaN gradients from torch.angle() near zero. References:
        # https://discuss.pytorch.org/t/anglebackward-returned-nan-values/122336/2
        # https://discuss.pytorch.org/t/strange-behavior-of-torch-angle-s-anglebackward/133621
        y = x.clone()
        y[(x < threshold) & (x > -1.0 * threshold)] = threshold
        return y

    def forward(self, pred, target):
        pred = self.replace_denormals(pred.real) + 1j * self.replace_denormals(pred.imag)
        target = self.replace_denormals(target.real) + 1j * self.replace_denormals(target.imag)
        pred_pc = self.power_compression(pred)
        target_pc = self.power_compression(target)
        b = 0.5
        # Weighted sum of magnitude-only MSE and complex MSE on the compressed spectrograms
        loss = (1 - b) * self.complex_mse_loss(torch.abs(pred_pc), torch.abs(target_pc)) + b * self.complex_mse_loss(pred_pc, target_pc)
        return loss
import itertools

class Pit_Cruse(nn.Module):
    def __init__(self, speakerwise=False, power=0.3, normalization=False):
        super(Pit_Cruse, self).__init__()
        self.speakerwise = speakerwise
        self.mse_loss = Signal_Approximation_Loss(power=power, normalization=normalization)

    def forward(self, preds, targets):
        batch, sp, _, _ = targets.shape  # (batch, speakers, frames, freq bins)
        perms_list = list(itertools.permutations(range(sp)))
        all_losses = torch.ones((batch, len(perms_list)), device=preds.device)
        for i_batch in range(batch):
            for i_perm, perm in enumerate(perms_list):
                if self.speakerwise:  # MSE per speaker, then averaged
                    losses_perm_spk_list = [self.mse_loss(preds[i_batch, perm[i_spk], :, :], targets[i_batch, i_spk, :, :]) for i_spk in range(sp)]
                    losses_perm = sum(losses_perm_spk_list) / sp
                else:  # MSE over the whole permutation at once
                    losses_perm = self.mse_loss(preds[i_batch, list(perm), :, :], targets[i_batch, :, :, :])
                all_losses[i_batch, i_perm] = losses_perm
        losses, _ = torch.min(all_losses, dim=1)  # best permutation per batch item
        loss = torch.mean(losses)
        return loss
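A toy check of the permutation logic: with the predictions deliberately swapped relative to the targets, PIT should pick the (1, 0) permutation and drive the loss to ~0 (shapes are arbitrary):

import torch

pit = Pit_Cruse(speakerwise=False, power=0.3, normalization=True)
targets = torch.randn(2, 2, 100, 257, dtype=torch.cfloat)  # (batch, speakers, frames, freq bins)
swapped = targets[:, [1, 0], :, :]                         # predictions in the wrong speaker order
print(pit(swapped, targets).item())                        # ~0: the swap-correcting permutation wins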
TRAINING PIPELINE
import argparse
import logging
import os
import random
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from pathlib import Path
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
from os import listdir
import matplotlib.pyplot as plt
from scipy.io import wavfile
from torchaudio.transforms import Spectrogram, InverseSpectrogram
from os.path import isfile, join
from cru_loss import Pit_Cruse
from cruseModel import Cruse
import warnings
warnings.filterwarnings("ignore")
dir_input = '/project_ghent/mixed_datasets2/wav16k/tr/mix_both/'  # Training set directory
dir_s1 = '/project_ghent/mixed_datasets2/wav16k/tr/s1/'  # Directory for speaker 1
dir_s2 = '/project_ghent/mixed_datasets2/wav16k/tr/s2/'  # Directory for speaker 2
dir_input_val = '/project_ghent/mixed_datasets2/wav16k/tr/mix_both/'  # Validation dataset directory
dir_checkpoint = '/project_ghent/Model2/chekpoints/'  # Directory to save model checkpoints
dir_lossplots = '/project_ghent/Model2/loss_plots/'  # Directory to save loss plots
todb = lambda x: torch.log10(torch.abs(x) + 1e-10)
def collate_fn(batch):
    # torch.stack keeps the per-item (1, T, F) tensors intact; torch.tensor() on a
    # list of multi-element tensors would raise an error here
    return {
        'input': torch.stack([x['input'] for x in batch]).to(torch.cfloat),  # Noisy mixture
        'gt': torch.stack([x['gt'] for x in batch]).to(torch.cfloat),        # Ground truth for speaker 1
        'gt_2': torch.stack([x['gt_2'] for x in batch]).to(torch.cfloat)     # Ground truth for speaker 2
    }
class MyDataset(Dataset):
    def __init__(self, dir_input, Train=True, data_num=9000):
        file_list = [file for file in listdir(dir_input) if file.endswith(".wav") and isfile(join(dir_input, file))]
        sub_list = random.sample(file_list, data_num)  # same random sampling for train and validation
        noisy_files = []
        gt_files = []
        gt_files_2 = []
        for name in sub_list:
            noisy_files.append(str(dir_input + name))
            gt_files.append(str(dir_s1 + 's1_' + name))
            gt_files_2.append(str(dir_s2 + 's2_' + name))
        self.noisy_files = noisy_files
        self.gt_files = gt_files
        self.gt_files_2 = gt_files_2

    def __getitem__(self, idx):
        noisy_path = self.noisy_files[idx]
        gt_path = self.gt_files[idx]
        gt_2_path = self.gt_files_2[idx]
        win = lambda x: torch.sqrt(torch.hann_window(x))
        # Complex STFTs, permuted to (channel, frames, freq bins)
        sr, audio_gt = wavfile.read(gt_path)
        audio_gt = torch.tensor(audio_gt.astype(np.float32))
        gt = Spectrogram(n_fft=512, hop_length=256, power=None, window_fn=win)(audio_gt).cfloat()
        gt = gt.permute(1, 0).unsqueeze(0)
        sr, audio_gt_2 = wavfile.read(gt_2_path)
        audio_gt_2 = torch.tensor(audio_gt_2.astype(np.float32))
        gt_2 = Spectrogram(n_fft=512, hop_length=256, power=None, window_fn=win)(audio_gt_2).cfloat()
        gt_2 = gt_2.permute(1, 0).unsqueeze(0)
        sr, audio_temp = wavfile.read(noisy_path)
        audio_temp = torch.tensor(audio_temp.astype(np.float32))
        audio_temp = Spectrogram(n_fft=512, hop_length=256, power=None, window_fn=win)(audio_temp).cfloat()
        audio_temp = audio_temp.permute(1, 0).unsqueeze(0)
        return {
            'input': audio_temp,
            'gt': gt,
            'gt_2': gt_2
        }

    def __len__(self):
        return len(self.gt_files)
def data_generator(dir_input, Train=True, data_num=9000):
    dataset = MyDataset(dir_input, Train, data_num)
    datasize = len(dataset)
    return dataset, datasize
def train_model(
    model,
    device,
    epochs: int = 5,
    batch_size: int = 1,
    learning_rate: float = 8e-5,
    save_checkpoint: bool = True,
    amp: bool = False,
    weight_decay: float = 1e-2,
    momentum: float = 0.999,
    gradient_clipping: float = 1.0,
):
    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True, collate_fn=collate_fn)
    dataset_val, datasize_val = data_generator(dir_input_val, Train=False, data_num=1)  # data_num=1: single-file sanity check
    n_val = datasize_val
    logging.info(f'Validation size: {n_val}')
    val_loader = DataLoader(dataset_val, shuffle=True, drop_last=True, **loader_args)
    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Mixed Precision: {amp}
    ''')
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.99), eps=1e-08, weight_decay=weight_decay)
    criterion = Pit_Cruse(speakerwise=False, power=0.3, normalization=True)
    global_step = 0
    train_loss_list = []
    val_loss_list = []
    x = []
    for epoch in range(1, epochs + 1):
        model.train()
        dataset_train, datasize = data_generator(dir_input, Train=True, data_num=1)  # data_num=1: overfit a single file
        n_train = datasize
        logging.info(f'Training size: {n_train}')
        train_loader = DataLoader(dataset_train, shuffle=True, drop_last=True, **loader_args)
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='aud') as pbar:
            for batch in train_loader:
                optimizer.zero_grad()
                inputs, gts, gts_2 = batch['input'], batch['gt'], batch['gt_2']
                inputs = inputs.to(device=device, dtype=torch.cfloat)
                gts = gts.to(device=device, dtype=torch.cfloat)
                gts_2 = gts_2.to(device=device, dtype=torch.cfloat)
                # Power-compressed magnitude as the model input
                inputs_mag = torch.pow(torch.abs(inputs), 0.3)
                mask_preds_1, mask_preds_2 = model(inputs_mag)
                # Apply masks, undo the compression, and reattach the mixture phase
                outputs_1 = torch.pow(inputs_mag * mask_preds_1, 10 / 3)
                outputs_1 = outputs_1 * torch.exp(1j * torch.angle(inputs))  # Keep phase intact
                outputs_2 = torch.pow(inputs_mag * mask_preds_2, 10 / 3)
                outputs_2 = outputs_2 * torch.exp(1j * torch.angle(inputs))  # Keep phase intact
                gts_concat = torch.cat([gts, gts_2], dim=1)
                outputs_concat = torch.cat([outputs_1, outputs_2], dim=1)
                train_loss = criterion(outputs_concat, gts_concat)
                train_loss.backward()
                optimizer.step()
                pbar.update(inputs.shape[0])
                global_step += 1
                epoch_loss += train_loss.item()
        # Validation: same reconstruction as training, but no gradients
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                inputs, gts, gts_2 = batch['input'], batch['gt'], batch['gt_2']
                inputs = inputs.to(device=device, dtype=torch.cfloat)
                gts = gts.to(device=device, dtype=torch.cfloat)
                gts_2 = gts_2.to(device=device, dtype=torch.cfloat)
                inputs_mag = torch.pow(torch.abs(inputs), 0.3)
                mask_preds_1, mask_preds_2 = model(inputs_mag)
                outputs_1 = torch.pow(inputs_mag * mask_preds_1, 10 / 3) * torch.exp(1j * torch.angle(inputs))
                outputs_2 = torch.pow(inputs_mag * mask_preds_2, 10 / 3) * torch.exp(1j * torch.angle(inputs))
                gts_concat = torch.cat([gts, gts_2], dim=1)
                outputs_concat = torch.cat([outputs_1, outputs_2], dim=1)
                val_loss += criterion(outputs_concat, gts_concat).item()
        logging.info(f'epoch {epoch}, train_loss: {epoch_loss/n_train}, val_loss: {val_loss/n_val}')
        train_loss_list.append(epoch_loss / n_train)
        val_loss_list.append(val_loss / n_val)
        x.append(epoch)
        Path(dir_lossplots).mkdir(parents=True, exist_ok=True)
        if epoch % 10 == 0:
            plt.figure(figsize=(6, 6), dpi=100)
            plt.plot(x, train_loss_list, 'r', lw=1)
            plt.plot(x, val_loss_list, 'b', lw=1)
            plt.title("Loss")
            plt.xlabel("Epoch")
            plt.ylabel("Loss")
            plt.legend(["Train Loss", "Validation Loss"])
            plt.savefig(f'{dir_lossplots}/loss_epoch{epoch}.png')
        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            state_dict = model.state_dict()
            torch.save(state_dict, f'{dir_checkpoint}/checkpoint_epoch{epoch}.pth')
            logging.info(f'Checkpoint {epoch} saved!')
def get_args():
    parser = argparse.ArgumentParser(description='Train a model for two-speaker speech separation')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=500, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=4e-5, help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Path to load model from .pth file')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of speakers/classes')
    return parser.parse_args()
if __name__ == '__main__':
    args = get_args()
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    model = Cruse()
    logging.info(f'Network: {model}')
    if args.load:
        state_dict = torch.load(args.load, map_location=device)
        model.load_state_dict(state_dict)
        logging.info(f'Model loaded from {args.load}')
    model.to(device=device)
    try:
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            amp=args.amp
        )
    except torch.cuda.OutOfMemoryError:
        logging.error('Out of memory error detected! Reducing the batch size or enabling AMP may help.')
        torch.cuda.empty_cache()
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            amp=args.amp
        )
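For completeness, this is roughly how I check the full pipeline at test time, from waveform to the two separated waveforms, using the imported InverseSpectrogram (the random waveform stands in for a real 16 kHz file, and the model is untrained here):

import torch
from torchaudio.transforms import Spectrogram, InverseSpectrogram
from cruseModel import Cruse

win = lambda x: torch.sqrt(torch.hann_window(x))
stft = Spectrogram(n_fft=512, hop_length=256, power=None, window_fn=win)
istft = InverseSpectrogram(n_fft=512, hop_length=256, window_fn=win)

model = Cruse()
model.eval()

audio = torch.randn(4 * 16000)  # placeholder: 4 s of "audio" at 16 kHz
inputs = stft(audio).cfloat().permute(1, 0).unsqueeze(0).unsqueeze(0)  # (1, 1, 251, 257) = (batch, channel, frames, freq)
with torch.no_grad():
    inputs_mag = torch.pow(torch.abs(inputs), 0.3)
    mask1, mask2 = model(inputs_mag)
    # Same reconstruction as in the training loop: undo compression, reuse mixture phase
    s1 = torch.pow(inputs_mag * mask1, 10 / 3) * torch.exp(1j * torch.angle(inputs))
    s2 = torch.pow(inputs_mag * mask2, 10 / 3) * torch.exp(1j * torch.angle(inputs))
# InverseSpectrogram expects (..., freq, time), so undo the (frames, freq) permute
wav1 = istft(s1.squeeze(1).permute(0, 2, 1))
wav2 = istft(s2.squeeze(1).permute(0, 2, 1))
print(wav1.shape, wav2.shape)  # one waveform per speaker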