I’m trying to set up a diarization system with a speaker embedder for downstream clustering.
I’ve decided to use the GE2E-XS loss, an extension of the GE2E loss used in Turn-to-Diarize, which is my main inspiration. Computing this loss requires batches of shape [n speakers x m utterances]. Randomly composing such batches is simple when operating on a list of files, and I’ve found a PyTorch implementation of exactly that kind of dataset and dataloader.
My problem is doing this with tf.data.Dataset, which I build from TFRecords; otherwise my training is input-bound by a network HDD that handles many small files poorly (slow indexing). I’d like to get different speaker combinations with different utterances in each epoch.
Each of my tf.train.Example records has the following structure: {spectrogram, speaker_id, misc…}. Note that each speaker has a different number of samples.
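For context, a minimal sketch of the loading/parsing step (feature names, dtypes and shapes below are placeholders, not my exact schema):

```python
import tensorflow as tf

# Placeholder feature spec -- names/dtypes are illustrative; my real Examples
# also contain the "misc" fields mentioned above.
FEATURE_SPEC = {
    "spectrogram": tf.io.FixedLenFeature([], tf.string),  # serialized float32 tensor
    "speaker_id": tf.io.FixedLenFeature([], tf.int64),
}

def parse_example(serialized):
    parsed = tf.io.parse_single_example(serialized, FEATURE_SPEC)
    spectrogram = tf.io.parse_tensor(parsed["spectrogram"], out_type=tf.float32)
    return spectrogram, parsed["speaker_id"]

files = tf.io.gfile.glob("data/*.tfrecord")  # placeholder path
dataset = (tf.data.TFRecordDataset(files)
           .map(parse_example, num_parallel_calls=tf.data.AUTOTUNE))
```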
My idea was to:
- Load the dataset from TFRecords and parse it;
- Shuffle it each epoch (to get a different order of utterances every epoch);
- Filter the dataset in a way that produces one dataset per speaker, i.e. Dataset = {dataset_speaker_0, dataset_speaker_1, …, dataset_speaker_k}. The list of speakers is known and is loaded from the CSV of sources that I use to build the TFRecords (a rough sketch of this and the remaining steps follows the list);
- Shuffle the list of these speaker datasets (so I get a different order and combination of speakers in each epoch);
- Batch by selecting n speakers (somehow) and iterating through their datasets, m utterances at a time, for as long as there is enough data left (any unused data would come from speakers with fewer than m utterances remaining, or from the point where there are no longer n speakers left to select).
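To make steps 2–4 concrete, this is roughly what I have in mind (only a sketch: n, m, the speaker count and buffer size are placeholder values, and `dataset` is the parsed dataset from the snippet above):

```python
N_SPEAKERS = 8      # n -- placeholder value
M_UTTERANCES = 4    # m -- placeholder value
speaker_ids = list(range(100))  # in reality loaded from the CSV of sources

def per_speaker_dataset(speaker_id):
    # One filter pass over the full parsed dataset per speaker -- this is
    # exactly the part whose cost I'm unsure about.
    speaker_id = tf.constant(speaker_id, dtype=tf.int64)
    return (dataset
            .filter(lambda spec, sid: tf.equal(sid, speaker_id))
            .shuffle(1024, reshuffle_each_iteration=True)
            # groups of m utterances; assumes fixed-size spectrograms
            .batch(M_UTTERANCES, drop_remainder=True))

speaker_datasets = [per_speaker_dataset(s) for s in speaker_ids]
```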
I’m not an expert in tf.data by any means; previously I worked mostly with cases where each sample was independent. I’m a bit worried about the complexity of these operations (how fast can filtering one big dataset into individual per-speaker datasets be? Maybe I should prepare separate TFRecords for each speaker instead?). I also have no idea how to solve step #5. I’ve seen the choose_from_datasets method, but I’m still not sure how to use it in this scenario.
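My current (possibly wrong) guess at how choose_from_datasets could cover step #5: a choice dataset emits speaker indices, choose_from_datasets pulls the next m-utterance batch from the chosen per-speaker dataset, and every n of those are stacked into one [n, m, …] batch. This is only a sketch built on the placeholder objects from the previous snippets:

```python
num_speakers = len(speaker_datasets)

# A fresh permutation of speaker indices on every pass (reshuffled each epoch).
choice_dataset = (tf.data.Dataset.range(num_speakers)
                  .shuffle(num_speakers, reshuffle_each_iteration=True)
                  .repeat())

# Pull one m-utterance batch from whichever speaker dataset the choice picks;
# skip speakers whose data has run out instead of stopping immediately.
# (tf.data.Dataset.choose_from_datasets in TF >= 2.7,
#  tf.data.experimental.choose_from_datasets in older versions.)
per_speaker_batches = tf.data.Dataset.choose_from_datasets(
    speaker_datasets, choice_dataset, stop_on_empty_dataset=False)

# Stack n consecutive per-speaker batches into one [n, m, ...] GE2E-XS batch.
ge2e_batches = per_speaker_batches.batch(N_SPEAKERS, drop_remainder=True)
```

I’m not sure this guarantees n distinct speakers per batch (especially once some per-speaker datasets start running empty), which is part of what I’m asking about.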
I’d appreciate your feedback on these two topics: the memory/time complexity of filtering one big dataset per speaker, and the batching itself.