1. Fbank features
import torch.nn as nn
import torchaudio

class ExtractAudioFeature(nn.Module):
    def __init__(self, feat_type="fbank", feat_dim=40):
        super(ExtractAudioFeature, self).__init__()
        self.feat_type = feat_type
        # Kaldi-compatible fbank extraction; any other feat_type falls back to mfcc
        self.extract_fn = torchaudio.compliance.kaldi.fbank if feat_type == "fbank" else torchaudio.compliance.kaldi.mfcc
        self.num_mel_bins = feat_dim

    def forward(self, filepath):
        waveform, sample_rate = torchaudio.load(filepath)
        y = self.extract_fn(waveform,
                            num_mel_bins=self.num_mel_bins,
                            channel=-1,
                            sample_frequency=sample_rate,
                            frame_length=25,  # frame length in ms
                            frame_shift=10,   # frame shift in ms
                            dither=0)
        # (num_frames, feat_dim) -> (1, feat_dim, num_frames)
        return y.transpose(0, 1).unsqueeze(0).detach()
extracter = ExtractAudioFeature("fbank", feat_dim=40)
wav = "./data/wav/day0914_990.wav"
wav_feature = extracter(wav)
print(wav_feature.shape)
torch.Size([1, 40, 489])
# 40: feature dimension (number of mel filter banks)
# 489: number of frames ≈ audio duration / frame shift (10 ms per frame)
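As a sanity check, the frame count can be recomputed from the raw waveform. The sketch below is illustrative; it reuses the file path above and assumes Kaldi's default snip_edges=True behaviour, where only full 25 ms windows are kept:

import torchaudio

# Recompute the expected number of frames from the waveform length (sanity check)
waveform, sample_rate = torchaudio.load("./data/wav/day0914_990.wav")
num_samples = waveform.shape[1]
window_size = int(sample_rate * 0.025)   # 25 ms frame length in samples
window_shift = int(sample_rate * 0.010)  # 10 ms frame shift in samples
expected_frames = 1 + (num_samples - window_size) // window_shift
print(expected_frames)  # should match the 489 frames reported above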
import matplotlib.pyplot as plt
plt.figure(dpi=200)
plt.xticks([])
plt.yticks([])
plt.imshow(wav_feature[0])
plt.show()
2. MFCC features
import torch.nn as nn
import torchaudio

class ExtractAudioFeature(nn.Module):
    def __init__(self, feat_type="mfcc", feat_dim=13):
        super(ExtractAudioFeature, self).__init__()
        self.feat_type = feat_type
        self.extract_fn = torchaudio.compliance.kaldi.fbank if feat_type == "fbank" else torchaudio.compliance.kaldi.mfcc
        self.num_mel_bins = feat_dim

    def forward(self, filepath):
        waveform, sample_rate = torchaudio.load(filepath)
        # Note: for mfcc, feat_dim is passed as num_mel_bins (13 mel bins here),
        # while num_ceps keeps its default of 13, so the output has 13 coefficients
        y = self.extract_fn(waveform,
                            num_mel_bins=self.num_mel_bins,
                            channel=-1,
                            sample_frequency=sample_rate,
                            frame_length=25,  # frame length in ms
                            frame_shift=10,   # frame shift in ms
                            dither=0)
        # (num_frames, feat_dim) -> (1, feat_dim, num_frames)
        return y.transpose(0, 1).unsqueeze(0).detach()
extracter = ExtractAudioFeature("mfcc", feat_dim=13)
wav = "./data/wav/day0914_990.wav"
wav_feature = extracter(wav)
print(wav_feature.shape)
torch.Size([1, 13, 489])
# 13: feature dimension (number of cepstral coefficients)
# 489: number of frames ≈ audio duration / frame shift (10 ms), same as the fbank features
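The class above passes feat_dim straight through as num_mel_bins, so the MFCCs here are computed from only 13 mel bins. A more conventional Kaldi-style setup keeps more mel bins and truncates at the DCT stage via num_ceps. The sketch below is illustrative only; the 23/13 values are torchaudio's defaults, not part of the original code:

import torchaudio

waveform, sample_rate = torchaudio.load("./data/wav/day0914_990.wav")
# Illustrative configuration: 23 mel bins reduced to 13 cepstral coefficients
mfcc = torchaudio.compliance.kaldi.mfcc(waveform,
                                        num_mel_bins=23,
                                        num_ceps=13,
                                        channel=-1,
                                        sample_frequency=sample_rate,
                                        frame_length=25,
                                        frame_shift=10,
                                        dither=0)
print(mfcc.shape)  # (num_frames, 13)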
import matplotlib.pyplot as plt
plt.figure(dpi=200)
plt.xticks([])
plt.yticks([])
plt.imshow(wav_feature[0])
plt.show()
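Since both extractors use identical 25 ms / 10 ms framing, the fbank and MFCC features of the same file share the frame axis and differ only in feature dimension. A minimal check, reusing the class and file path above:

fbank_feat = ExtractAudioFeature("fbank", feat_dim=40)("./data/wav/day0914_990.wav")
mfcc_feat = ExtractAudioFeature("mfcc", feat_dim=13)("./data/wav/day0914_990.wav")
print(fbank_feat.shape, mfcc_feat.shape)          # (1, 40, 489) and (1, 13, 489)
assert fbank_feat.shape[2] == mfcc_feat.shape[2]  # same number of frames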
References
- https://github.com/neil-zeng/asr