| | import json |
| | import random |
| | import torch |
| | import torchaudio |
| | from torch.utils.data import Dataset |
| |
|
| |
|
class AudioTextDataset(Dataset):
    """Sample fixed-length mono audio clips paired with their captions.

    Every datafile is a JSON document with a top-level ``'data'`` list; each
    entry of that list provides a ``'wav'`` path and a ``'caption'`` string.
    The entries of all datafiles are concatenated into one index.

    Params:
        datafiles: iterable of paths to the JSON index files
        sampling_rate: audio sampling rate (Hz) of the returned waveforms
        max_clip_len: max length (seconds) of audio clip to be sampled
    """

    # Upper bound on how many entries `_read_audio` tries before giving up, so
    # a dataset made mostly of unreadable files fails loudly instead of
    # recursing/looping forever.
    _MAX_READ_ATTEMPTS = 100

    def __init__(
        self,
        datafiles=('',),
        sampling_rate=32000,
        max_clip_len=5,
    ):
        all_data_json = []
        for datafile in datafiles:
            # JSON is UTF-8 by specification; do not depend on the platform
            # default encoding (captions may contain non-ASCII text).
            with open(datafile, 'r', encoding='utf-8') as fp:
                all_data_json.extend(json.load(fp)['data'])
        self.all_data_json = all_data_json

        self.sampling_rate = sampling_rate
        # Must be an int: it is used as a tensor size and as a slice bound,
        # and `max_clip_len` may be fractional (e.g. 2.5 seconds).
        self.max_length = int(round(max_clip_len * sampling_rate))

    def __len__(self):
        """Return the number of indexed audio-caption entries."""
        return len(self.all_data_json)

    def _cut_or_randomcrop(self, waveform):
        """Return `waveform` with exactly `self.max_length` samples.

        Longer inputs are cropped at a uniformly random offset; shorter (or
        equal-length) inputs are copied into a zero-padded tensor.

        Params:
            waveform: tensor indexed as (channels, samples)
        Returns:
            tensor of shape (channels, self.max_length)
        """
        n_samples = waveform.size(1)
        if n_samples > self.max_length:
            start = random.randint(0, n_samples - self.max_length)
            return waveform[:, start:start + self.max_length]

        # `new_zeros` keeps the dtype/device of the input and supports any
        # channel count instead of assuming a single float32 CPU channel.
        padded = waveform.new_zeros(waveform.size(0), self.max_length)
        padded[:, :n_samples] = waveform
        return padded

    @staticmethod
    def _to_mono(audio_data):
        """Average all channels of a (channels, samples) tensor into one.

        Returns a tensor of shape (1, samples); single-channel input is
        returned unchanged.
        """
        if audio_data.shape[0] > 1:
            # Average over every channel, not only the first two.
            return audio_data.mean(dim=0, keepdim=True)
        return audio_data

    def _read_audio(self, index):
        """Load the caption and raw audio of entry `index`.

        If the entry cannot be loaded or its audio is shorter than one
        second, the error is printed and another randomly chosen entry is
        tried instead.

        Params:
            index: position in `self.all_data_json`
        Returns:
            (text, audio_data, audio_rate) as produced by `torchaudio.load`
        Raises:
            IndexError: if `index` is out of range
            RuntimeError: if `_MAX_READ_ATTEMPTS` entries in a row fail
        """
        last_error = None
        for _ in range(self._MAX_READ_ATTEMPTS):
            # Looked up outside the `try` so that a bad caller-supplied index
            # surfaces as IndexError rather than being silently replaced.
            # Retry indices below are always in range.
            entry = self.all_data_json[index]
            # Pre-bound so the error message below can never hit an unbound
            # local when the failure happens before the path is known.
            audio_path = None
            try:
                audio_path = entry['wav']
                audio_data, audio_rate = torchaudio.load(audio_path, channels_first=True)
                text = entry['caption']

                # Drop clips shorter than one second. The sample count is
                # still at the file's native rate here (resampling happens
                # later), so it is compared against `audio_rate`.
                if audio_data.size(1) < audio_rate * 1:
                    raise ValueError(f'{audio_path} is too short, drop it ...')

                return text, audio_data, audio_rate

            except Exception as e:
                # Deliberately broad: any unreadable entry is skipped
                # (best-effort loading), but the failure is reported.
                last_error = e
                print(f'error: {e} occurs, when loading {audio_path}')
                index = random.randint(0, len(self.all_data_json) - 1)

        raise RuntimeError(
            f'failed to load a valid audio clip after '
            f'{self._MAX_READ_ATTEMPTS} attempts'
        ) from last_error

    def __getitem__(self, index):
        """Return one training example.

        Returns:
            dict with keys 'text' (caption), 'waveform' (tensor of shape
            (1, self.max_length) at `self.sampling_rate`) and 'modality'
            (always 'audio_text')
        """
        text, audio_data, audio_rate = self._read_audio(index)

        # Downmix to a single channel, kept as (1, samples).
        audio_data = self._to_mono(audio_data)

        # Resampling operates on the last (time) dimension.
        if audio_rate != self.sampling_rate:
            audio_data = torchaudio.functional.resample(
                audio_data, orig_freq=audio_rate, new_freq=self.sampling_rate
            )

        audio_data = self._cut_or_randomcrop(audio_data)

        return {
            'text': text,
            'waveform': audio_data,
            'modality': 'audio_text',
        }
| |
|