Single-channel speech separation using neural networks

Hello,
I'm trying to train a model that separates 2 speakers from a noisy, reverberated mixed audio file. As a sanity check, I first train the model on a single mixed file and then test it on that same file, to see whether it can perform the task at all. This is not the case: the two generated masks produce practically the same output, and the only difference is the perceived loudness (both speakers remain audible in each estimate). Since I train and test on the same file, I would expect near-perfect separation. I've provided the code for the architecture, the loss function and the training pipeline below. One thing I already tried (without success) is adding the phase information as input to the model; a rough sketch of what I mean by that follows.
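The phase-input variant was something along these lines (simplified sketch with hypothetical names, not the exact code I used; `audio` stands for a 1-D float waveform tensor at 16 kHz, and conv1 of the model would then need in_channels=3 instead of 1):

import torch
from torchaudio.transforms import Spectrogram

# `audio`: placeholder for a loaded 1-D float waveform tensor at 16 kHz
spec = Spectrogram(n_fft=512, hop_length=256, power=None)(audio)   # complex STFT, shape (F, T)
mag = torch.pow(torch.abs(spec), 0.3)                              # power-compressed magnitude
phase = torch.angle(spec)
# Stack magnitude, cos(phase) and sin(phase) as three channels and reorder to (3, T, F)
model_in = torch.stack([mag, torch.cos(phase), torch.sin(phase)], dim=0).permute(0, 2, 1)
# conv1 then takes in_channels=3 instead of 1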
Thanks in advance for the help

MODEL ARCHITECTURE

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import math

class Cruse(nn.Module):

    def __init__(self):
        super().__init__()
        kernel_size = [2, 3]
        stride = [1, 2]
        # Encoder: 4 conv layers that downsample along the frequency axis
        self.conv1 = nn.Conv2d(1, 16, kernel_size=kernel_size, stride=stride)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=kernel_size, stride=stride)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=kernel_size, stride=stride)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=kernel_size, stride=stride)
        # Bottleneck: 4 parallel GRUs, each over a quarter of the flattened features
        feat_size = 1920
        self.gru1 = nn.GRU(input_size=feat_size // 4, hidden_size=feat_size // 4, num_layers=1, batch_first=True)
        self.gru2 = nn.GRU(input_size=feat_size // 4, hidden_size=feat_size // 4, num_layers=1, batch_first=True)
        self.gru3 = nn.GRU(input_size=feat_size // 4, hidden_size=feat_size // 4, num_layers=1, batch_first=True)
        self.gru4 = nn.GRU(input_size=feat_size // 4, hidden_size=feat_size // 4, num_layers=1, batch_first=True)
        self.feat_size = feat_size
        # Decoder: transposed convs back to full frequency resolution, 2 output channels (one mask per speaker)
        self.deconv1 = nn.ConvTranspose2d(128, 64, kernel_size=kernel_size, stride=stride)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=kernel_size, stride=stride)
        self.deconv3 = nn.ConvTranspose2d(32, 16, kernel_size=kernel_size, stride=stride, output_padding=(0, 1))
        self.deconv4 = nn.ConvTranspose2d(16, 2, kernel_size=kernel_size, stride=stride)

        # Skip connections: 1x1 depthwise convs between encoder and decoder
        self.skip1 = nn.Conv2d(128, 128, kernel_size=1, groups=128)
        self.skip2 = nn.Conv2d(64, 64, kernel_size=1, groups=64)
        self.skip3 = nn.Conv2d(32, 32, kernel_size=1, groups=32)
        self.skip4 = nn.Conv2d(16, 16, kernel_size=1, groups=16)

        self.activation = nn.LeakyReLU(negative_slope=0.03, inplace=True)
        self.activation_outlayer = nn.Sigmoid()

    def forward(self, x):
        encoded = []
        x = self.activation(self.conv1(x))
        encoded.append(x)
        x = self.activation(self.conv2(x))
        encoded.append(x)
        x = self.activation(self.conv3(x))
        encoded.append(x)
        x = self.activation(self.conv4(x))
        encoded.append(x)

        # (B, C, T, F) -> (B, T, C*F) so the GRUs run over the time axis
        gru_in = x.permute(0, 2, 1, 3)
        input_size = gru_in.shape
        gru_in = gru_in.reshape(gru_in.shape[0], gru_in.shape[1], -1)
        feat_size = self.feat_size

        gru_out1, _ = self.gru1(gru_in[:, :, :feat_size // 4])
        gru_out2, _ = self.gru2(gru_in[:, :, feat_size // 4:feat_size // 4 * 2])
        gru_out3, _ = self.gru3(gru_in[:, :, feat_size // 4 * 2:feat_size // 4 * 3])
        gru_out4, _ = self.gru4(gru_in[:, :, feat_size // 4 * 3:feat_size // 4 * 4])

        gru_out = torch.concat((gru_out1, gru_out2, gru_out3, gru_out4), 2)
        gru_out = gru_out.reshape(input_size)
        x = gru_out.permute(0, 2, 1, 3)

        temp = self.skip1(encoded[-1])
        x = x + temp
        x = self.activation(self.deconv1(x))

        temp = self.skip2(encoded[-2])
        x = x + temp
        x = self.activation(self.deconv2(x))

        temp = self.skip3(encoded[-3])
        x = x + temp
        x = self.activation(self.deconv3(x))

        temp = self.skip4(encoded[-4])
        x = x + temp

        mask = self.activation_outlayer(self.deconv4(x))
        mask1 = mask[:, [0], :, :]
        mask2 = mask[:, [1], :, :]

        return mask1, mask2  # Return two separate masks for the two speakers
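For reference, a quick shape check of the model (the 100 frames are arbitrary; the 257 frequency bins correspond to the n_fft=512 STFT used in the training pipeline below):

import torch

model = Cruse()
dummy = torch.rand(1, 1, 100, 257)   # (batch, channels, time frames, frequency bins)
mask1, mask2 = model(dummy)
print(mask1.shape, mask2.shape)      # both torch.Size([1, 1, 100, 257]), values in (0, 1) from the sigmoid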

LOSS FUNCTION

import torch
import torch.nn as nn
class Signal_Approximation_Loss(nn.Module):
    def __init__(self, power=0.3, normalization=False):
        super(Signal_Approximation_Loss, self).__init__()
        self.power = power
        self.normalization = normalization

    def power_compression(self, x):
        # mag = torch.clamp(torch.abs(x), min=1e-10, max=1e3)
        mag = torch.abs(x)
        mag_pc = torch.pow(mag, self.power)
        angle = torch.angle(x)
        phase = torch.exp(1j * angle)
        x = mag_pc * phase
        return x

    def complex_mse_loss(self, pred, target):
        mse_nominator = torch.linalg.norm(target - pred) ** 2
        mse_denominator = torch.linalg.norm(target) ** 2 if self.normalization else torch.numel(target)
        # mse_denominator = torch.clamp(torch.linalg.norm(target)**2, min=1e-8) if self.normalization else torch.numel(target)  ### avoid 0 in denominator
        return mse_nominator / mse_denominator

    @staticmethod
    def replace_denormals(x: torch.Tensor, threshold=1e-10):
        ### References:
        ### https://discuss.pytorch.org/t/anglebackward-returned-nan-values/122336/2
        ### https://discuss.pytorch.org/t/strange-behavior-of-torch-angle-s-anglebackward/133621
        y = x.clone()
        y[(x < threshold) & (x > -1.0 * threshold)] = threshold
        return y

    def forward(self, pred, target):
        pred = self.replace_denormals(pred.real) + 1j * self.replace_denormals(pred.imag)
        target = self.replace_denormals(target.real) + 1j * self.replace_denormals(target.imag)

        pred_pc = self.power_compression(pred)
        target_pc = self.power_compression(target)
        b = 0.5
        loss = (1 - b) * self.complex_mse_loss(torch.abs(pred_pc), torch.abs(target_pc)) + b * self.complex_mse_loss(pred_pc, target_pc)
        return loss

import itertools

class Pit_Cruse(nn.Module):
    def __init__(self, speakerwise=False, power=0.3, normalization=False):
        super(Pit_Cruse, self).__init__()
        self.speakerwise = speakerwise
        self.mse_loss = Signal_Approximation_Loss(power=power, normalization=normalization)

    def forward(self, preds, targets):
        batch, sp, _, _ = targets.shape
        perms_list = list(itertools.permutations(range(sp)))
        all_losses = torch.ones((batch, len(perms_list)), device=preds.device)
        for i_batch in range(batch):
            for i_perm, perm in enumerate(perms_list):
                if self.speakerwise:  ## MSE per speaker, averaged over speakers
                    losses_perm_spk_list = [self.mse_loss(preds[i_batch, perm[i_spk], :, :], targets[i_batch, i_spk, :, :]) for i_spk in range(sp)]
                    losses_perm = sum(losses_perm_spk_list) / sp
                else:  ## MSE over the whole permutation at once
                    losses_perm = self.mse_loss(preds[i_batch, perm, :, :], targets[i_batch, :, :, :])
                all_losses[i_batch, i_perm] = losses_perm

        losses, _ = torch.min(all_losses, dim=1)  ## best permutation per batch element
        loss = torch.mean(losses)
        return loss
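As a quick sanity check of the permutation-invariant part, swapping the two target channels should give exactly the same loss (toy complex tensors, shapes chosen arbitrarily):

import torch

pit = Pit_Cruse(speakerwise=False, power=0.3, normalization=True)
preds = torch.randn(2, 2, 100, 257, dtype=torch.cfloat)    # (batch, speakers, T, F)
targets = torch.randn(2, 2, 100, 257, dtype=torch.cfloat)
loss_a = pit(preds, targets)
loss_b = pit(preds, targets[:, [1, 0]])                    # speaker order swapped
print(loss_a.item(), loss_b.item())                        # identical, since PIT keeps the best permutation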

TRAINING PIPELINE

import argparse
import logging
import os
import random
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from pathlib import Path
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
from os import listdir
import matplotlib.pyplot as plt
from scipy.io import wavfile
from torchaudio.transforms import Spectrogram, InverseSpectrogram
from os.path import isfile, join
from cru_loss import Pit_Cruse
from cruseModel import Cruse
import warnings
warnings.filterwarnings("ignore")

dir_input = '/project_ghent/mixed_datasets2/wav16k/tr/mix_both/'      # Training set directory
dir_s1 = '/project_ghent/mixed_datasets2/wav16k/tr/s1/'               # Speaker 1 ground truth directory
dir_s2 = '/project_ghent/mixed_datasets2/wav16k/tr/s2/'               # Speaker 2 ground truth directory
dir_input_val = '/project_ghent/mixed_datasets2/wav16k/tr/mix_both/'  # Validation set directory
dir_checkpoint = '/project_ghent/Model2/chekpoints/'                  # Directory to save model checkpoints
dir_lossplots = '/project_ghent/Model2/loss_plots/'                   # Directory to save loss plots

todb = lambda x: torch.log10(torch.abs(x) + 1e-10)

def collate_fn(batch):
    return {
        'input': torch.tensor([x['input'] for x in batch], dtype=torch.cfloat),  # Noisy mixture
        'gt': torch.tensor([x['gt'] for x in batch], dtype=torch.cfloat),        # Ground truth for speaker 1
        'gt_2': torch.tensor([x['gt_2'] for x in batch], dtype=torch.cfloat)     # Ground truth for speaker 2
    }

class MyDataset(Dataset):
    def __init__(self, dir_input, Train=True, data_num=9000):
        file_list = [file for file in listdir(dir_input) if file.endswith(".wav") and isfile(join(dir_input, file))]
        if Train:
            sub_list = random.sample(file_list, data_num)
        else:
            sub_list = random.sample(file_list, data_num)
        noisy_files = []
        gt_files = []
        gt_files_2 = []
        for file in sub_list:
            if file.endswith(".wav"):
                name = file
                noisy_files.append(str(dir_input + name))
                gt_files.append(str(dir_s1 + 's1_' + name))
                gt_files_2.append(str(dir_s2 + 's2_' + name))

        self.noisy_files = noisy_files
        self.gt_files = gt_files
        self.gt_files_2 = gt_files_2

    def __getitem__(self, idx):
        noisy_path = self.noisy_files[idx]
        gt_path = self.gt_files[idx]
        gt_2_path = self.gt_files_2[idx]

        win = lambda x: torch.sqrt(torch.hann_window(x))

        sr, audio_gt = wavfile.read(gt_path)
        audio_gt = torch.tensor(audio_gt.astype(np.float32))
        gt = Spectrogram(n_fft=512, hop_length=256, power=None, window_fn=win)(audio_gt).cfloat()
        gt = gt.permute(1, 0).unsqueeze(0)

        sr, audio_gt_2 = wavfile.read(gt_2_path)
        audio_gt_2 = torch.tensor(audio_gt_2.astype(np.float32))
        gt_2 = Spectrogram(n_fft=512, hop_length=256, power=None, window_fn=win)(audio_gt_2).cfloat()
        gt_2 = gt_2.permute(1, 0).unsqueeze(0)

        sr, audio_temp = wavfile.read(noisy_path)
        audio_temp = torch.tensor(audio_temp.astype(np.float32))
        audio_temp = Spectrogram(n_fft=512, hop_length=256, power=None, window_fn=win)(audio_temp).cfloat()
        audio_temp = audio_temp.permute(1, 0).unsqueeze(0)

        return {
            'input': audio_temp,
            'gt': gt,
            'gt_2': gt_2
        }

    def __len__(self):
        return len(self.gt_files)

def data_generator(dir_input, Train=True, data_num=9000):
    dataset = MyDataset(dir_input, Train, data_num)
    datasize = len(dataset)
    return dataset, datasize

def train_model(
    model,
    device,
    epochs: int = 5,
    batch_size: int = 1,
    learning_rate: float = 8e-5,
    save_checkpoint: bool = True,
    amp: bool = False,
    weight_decay: float = 1e-2,
    momentum: float = 0.999,
    gradient_clipping: float = 1.0,
):
    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)

    dataset_val, datasize_val = data_generator(dir_input_val, Train=False, data_num=1)  # data_num=1: single-file experiment
    n_val = datasize_val
    logging.info(f'Validation size: {n_val}')
    val_loader = DataLoader(dataset_val, shuffle=True, drop_last=True, **loader_args)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {learning_rate}
        Validation size: {n_val}
        Checkpoints:     {save_checkpoint}
        Device:          {device.type}
        Mixed Precision: {amp}
    ''')

    optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.99), eps=1e-08, weight_decay=weight_decay)
    global_step = 0
    train_loss_list = []
    val_loss_list = []
    x = []

    for epoch in range(1, epochs + 1):
        model.train()

        dataset_train, datasize = data_generator(dir_input, Train=True, data_num=1)  # data_num=1: single-file experiment
        n_train = datasize
        logging.info(f'Training size: {n_train}')
        train_loader = DataLoader(dataset_train, shuffle=True, drop_last=True, **loader_args)

        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='aud') as pbar:
            for batch in train_loader:
                optimizer.zero_grad()
                inputs, gts, gts_2 = batch['input'], batch['gt'], batch['gt_2']
                inputs = inputs.to(device=device, dtype=torch.cfloat)
                gts = gts.to(device=device, dtype=torch.cfloat)
                gts_2 = gts_2.to(device=device, dtype=torch.cfloat)

                # Power-compressed magnitude spectrogram as model input
                inputs_mag = torch.pow(torch.abs(inputs), 0.3)
                mask_preds_1, mask_preds_2 = model(inputs_mag)

                # Apply each mask, undo the compression (0.3 * 10/3 = 1) and reuse the mixture phase
                outputs_1 = inputs_mag * mask_preds_1
                outputs_1 = torch.pow(outputs_1, (10 / 3))
                outputs_1 = outputs_1 * torch.exp(1j * torch.angle(inputs))  # Keep phase intact

                outputs_2 = inputs_mag * mask_preds_2
                outputs_2 = torch.pow(outputs_2, (10 / 3))
                outputs_2 = outputs_2 * torch.exp(1j * torch.angle(inputs))  # Keep phase intact

                gts_concat = torch.cat([gts, gts_2], dim=1)
                outputs_concat = torch.cat([outputs_1, outputs_2], dim=1)
                train_loss = Pit_Cruse(speakerwise=False, power=0.3, normalization=True)(outputs_concat, gts_concat)

                train_loss.backward()
                optimizer.step()
                pbar.update(inputs.shape[0])
                global_step += 1
                epoch_loss += train_loss.item()

            val_loss = 0
            for batch in val_loader:
                inputs, gts, gts_2 = batch['input'], batch['gt'], batch['gt_2']
                inputs = inputs.to(device=device, dtype=torch.cfloat)
                gts = gts.to(device=device, dtype=torch.cfloat)
                gts_2 = gts_2.to(device=device, dtype=torch.cfloat)
                inputs_mag = torch.pow(torch.abs(inputs), 0.3)

                mask_preds_1, mask_preds_2 = model(inputs_mag)
                outputs_1 = inputs_mag * mask_preds_1
                outputs_1 = torch.pow(outputs_1, (10 / 3))
                outputs_1 = outputs_1 * torch.exp(1j * torch.angle(inputs))  # Keep phase intact

                outputs_2 = inputs_mag * mask_preds_2
                outputs_2 = torch.pow(outputs_2, (10 / 3))
                outputs_2 = outputs_2 * torch.exp(1j * torch.angle(inputs))  # Keep phase intact

                gts_concat = torch.cat([gts, gts_2], dim=1)
                outputs_concat = torch.cat([outputs_1, outputs_2], dim=1)
                val_loss_bat = Pit_Cruse(speakerwise=False, power=0.3, normalization=True)(outputs_concat, gts_concat)
                val_loss += val_loss_bat.item()

        logging.info(f'epoch {epoch}, train_loss: {epoch_loss / n_train}, val_loss: {val_loss / n_val}')
        train_loss_list.append(epoch_loss / n_train)
        val_loss_list.append(val_loss / n_val)
        x.append(epoch)

        Path(dir_lossplots).mkdir(parents=True, exist_ok=True)
        if epoch % 10 == 0:
            plt.figure(figsize=(6, 6), dpi=100)
            plt.plot(x, train_loss_list, 'r', lw=1)
            plt.plot(x, val_loss_list, 'b', lw=1)
            plt.title("Loss")
            plt.xlabel("Epoch")
            plt.ylabel("Loss")
            plt.legend(["Train Loss", "Validation Loss"])
            plt.savefig(f'{dir_lossplots}/loss_epoch{epoch}.png')

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            if epoch % 1 == 0:
                state_dict = model.state_dict()
                torch.save(state_dict, f'{dir_checkpoint}/checkpoint_epoch{epoch}.pth')
                logging.info(f'Checkpoint {epoch} saved!')

def get_args():
    parser = argparse.ArgumentParser(description='Train a model for two-speaker speech separation')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=500, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=4e-5, help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Path to load model from .pth file')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of speakers/classes')

    return parser.parse_args()

if __name__ == '__main__':
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    model = Cruse()

    logging.info(f'Network: {model}')

    if args.load:
        state_dict = torch.load(args.load, map_location=device)
        model.load_state_dict(state_dict)
        logging.info(f'Model loaded from {args.load}')

    model.to(device=device)

    try:
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            amp=args.amp
        )
    except torch.cuda.OutOfMemoryError:
        logging.error('Out of memory error detected! Reducing batch size or using AMP may help.')
        torch.cuda.empty_cache()
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            amp=args.amp
        )