说话人与音频分析¶
说话人与音频分析识别谁在说话、何时说话以及存在哪些非语言声音。本文涵盖说话人确认与识别、i向量、d向量、x向量、说话人日志、音频事件分类、音乐信息检索以及语音情感识别。
-
在文件 01 中,我们构建了信号处理基础:语谱图、MFCC 和梅尔滤波器组。在文件 02 中,我们识别了所说的内容。现在我们要问:是谁说的、何时说的、以及音频中还在发生什么。说话人识别、说话人日志、音频分类和音乐分析都共享一条主线:学习能够为当前任务捕捉正确不变性的紧凑嵌入,这与第 06 章中的嵌入思想一脉相承。
-
可以把说话人识别想象成在电话中辨认朋友的声音。你不需要理解词汇;某种关于音色、语速和嗓音特质的东西对这个人来说是独一无二的。说话人识别系统学会从原始音频中提取这种"声纹",忽略说的是什么,专注于怎么说的。
-
说话人识别是两类相关任务的总称:
- 说话人确认(SV):给定一个声明的身份和一段音频片段,判断说话人是否与其声称的身份一致。这是一个二元决策(接受或拒绝),是基于语音的身份验证技术("嘿 Siri,这是我的声音吗?")背后的核心原理。
- 说话人识别(SI):给定一段音频片段和一个已知说话人库,判断该片段由哪个说话人产生。这是一个多分类问题。
-
两种任务共享相同的底层表示:一个固定维度的说话人嵌入,它捕捉说话人的身份特征而与所说内容无关。区别仅在于决策阶段:确认比较两个嵌入,识别则在候选嵌入中找到最近邻。
-
余弦相似度是比较说话人嵌入的标准度量。给定注册嵌入 \(e\) 和测试嵌入 \(t\):
-
阈值 \(\theta\) 决定接受/拒绝决策:若 \(s > \theta\),则接受。阈值在错误接受率(FAR)和错误拒绝率(FRR)之间权衡。等错误率(EER),即 FAR = FRR 时的值,是标准评估指标。EER 越低表示性能越好。最先进的系统在标准基准(VoxCeleb)上可实现低于 1% 的 EER。
-
i向量(Dehak 等人,2010)是深度学习之前主导性的说话人嵌入方法。其思想源于因子分析(第 02 章的矩阵分解和第 04 章的降维)。一个通用背景模型(UBM)——基于多样本说话人训练的大型 GMM——定义了一个超向量空间。每条语音的 GMM 超向量被投影到低维的全可变性空间:
-
其中 \(M\) 是该语音的 GMM 超向量,\(m\) 是 UBM 均值超向量,\(T\) 是全可变性矩阵(从数据中学习得到),\(w\) 是 i 向量,一个低维(通常为 400-600 维)表示,同时捕捉说话人变异和信道变异。
-
为了从 i 向量中去除信道变异,概率线性判别分析(PLDA)将 i 向量建模为说话人特定潜变量和信道特定潜变量之和。PLDA 为确认任务提供了一个有原则的对数似然比分数:
-
d向量(Variani 等人,2014)是第一个神经说话人嵌入。一个为说话人分类训练的 DNN 处理帧级特征,通过对整条语音中最后一层隐藏层激活值求平均,提取出固定维度的表示。虽然简单但有效,d向量证明了神经网络可以在没有 i 向量复杂统计机制的情况下学习到说话人判别性特征。
-
x向量(Snyder 等人,2018)使用时延神经网络(TDNN)架构显著推进了神经说话人嵌入。TDNN 是具有特定上下文窗口的 1D 卷积,与文件 03 中 WaveNet 的扩张卷积有关,但应用于帧级特征而非原始波形样本。
- x向量架构包含三个阶段:
- 帧级层:一组 TDNN 层处理 MFCC(来自文件 01),时间上下文逐步扩大。每一层都有一个固定的上下文窗口(例如第一层为 \(\{t-2, t-1, t, t+1, t+2\}\),后续层窗口更宽)。
- 统计池化:在帧级层之后,计算帧级输出在整个语音上的均值和标准差,产生一个与语音时长无关的固定维度向量:
-
其中 \(h_t\) 是时间 \(t\) 的帧级输出。拼接 \([\mu; \sigma]\) 即为池化后的表示。
- 段级层:全连接层处理池化后的表示。第一个段级层的输出(softmax 之前)即为 x 向量嵌入。
-
x向量使用说话人身份上的标准交叉熵损失进行训练。尽管是为分类任务训练的,但学习到的中间表示(x向量)能很好地泛化到未见过的说话人,因为网络学习的是提取说话人判别性特征,而非记忆特定说话人。
-
ECAPA-TDNN(Desplanques 等人,2020)是目前最先进的基于 TDNN 的说话人识别架构。它在 x 向量基础上引入了三项改进:
- 压缩激励(SE)模块:通道注意力(来自第 08 章的 SENet),根据全局上下文重新加权特征通道,使模型能够强调与说话人相关的通道。
- Res2Net 风格的多尺度特征:在每个 TDNN 模块内,通道被分成若干组,以层级方式处理,在多个时间分辨率上创建特征(类似于第 08 章的多尺度特征提取)。
- 注意力统计池化:不再使用等权平均,而是通过注意力机制为每一帧对池化统计量的贡献分配权重。包含更多说话人判别性内容的帧(如元音,承载更多说话人信息)获得更高的注意力权重:
-
其中 \(f\) 是一个小型神经网络,\(v\) 是一个学习到的注意力向量。注意力加权的均值和标准差变为 \(\tilde{\mu} = \sum_t \alpha_t h_t\) 和 \(\tilde{\sigma} = \sqrt{\sum_t \alpha_t (h_t - \tilde{\mu})^2}\)。
-
ECAPA-TDNN 通常使用 AAM-Softmax(附加角度间隔 Softmax)进行训练,它在分类损失中添加了角度间隔惩罚,将同一说话人的嵌入推得更近,不同说话人的嵌入在超球面上推得更远:
-
其中 \(\theta_{y_i}\) 是嵌入与真实类别权重向量之间的夹角,\(m\) 是间隔(通常为 0.2),\(s\) 是缩放因子(通常为 30)。该损失函数来自人脸识别(第 08 章的 ArcFace),在说话人确认中非常有效。
-
说话人日志回答了多方录音中"谁在什么时候说话"的问题。可以把这想象成给时间线上色:每种颜色代表一个不同的说话人,系统必须确定每个说话人何时活跃,包括重叠语音的情况。
-
基于聚类的说话人日志是传统的流水线方法:
- 分割:将音频划分为短段(通常为 1-2 秒),使用滑动窗口或说话人变化检测。
- 嵌入提取:为每个片段提取说话人嵌入(x向量、ECAPA-TDNN)。
- 聚类:按说话人对片段进行分组。凝聚层次聚类(AHC)是标准方法:开始时每个片段自成一类,然后迭代合并两个最相似的类,直到满足停止条件(基于距离阈值或目标说话人数)。
- 重分割:使用基于维特比算法的重对齐来优化边界。
-
说话人数量通常事先未知,这使得该问题比标准聚类更困难。使用基于特征值阈值确定 \(k\) 的谱聚类是另一种常见方法。
-
端到端神经说话人日志(EEND)(Fujita 等人,2019)将说话人日志框架化为一个多标签分类问题。一个神经网络(通常是基于自注意力的模型,第 07 章的 transformer)将整段录音作为输入,为每一帧输出每个说话人的二元活动标签。这直接处理了重叠语音,而这是基于聚类方法的主要弱点。
-
EEND 对 \(S\) 个说话人在帧 \(t\) 的输出为:
-
其中 \(h_t\) 是帧 \(t\) 处的 transformer 输出,\(f_s\) 是说话人 \(s\) 的线性投影。训练损失是在说话人和帧上求和得到的二元交叉熵。一个关键挑战是说话人数量必须固定,或者使用可变输出架构(EEND-EDA 使用带吸引子的编码器-解码器)来处理。
-
置换不变训练(PIT)用于处理说话人日志中的标签歧义问题:由于说话人没有固有顺序,需要对所有可能的说话人到输出分配计算损失,并取最小值(这与文件 05 中源分离使用的 PIT 相同)。
-
音频分类为整段音频片段分配一个标签。与转录语音的 ASR(文件 02)不同,音频分类涵盖更广的范围:环境声音(警笛、雨声、狗吠)、音乐流派(摇滚、爵士、古典)以及一般音频事件。
-
标准方法遵循第 08 章的图像分类范式:将音频表示为语谱图(一个二维时间-频率图像),然后应用 CNN 或 transformer 分类器。这种谱图-图像方法利用了计算机视觉几十年来的进展。
-
环境声音分类(ESC)使用 ESC-50(50 类,2000 个片段)和 UrbanSound8K 等数据集。典型架构是应用于对数梅尔语谱图的 CNN(第 06 章)。数据增强至关重要:时间拉伸、音高偏移、添加背景噪声以及 SpecAugment(文件 02 的掩码方法应用于语谱图)都能提升泛化能力。
-
音频事件检测(声音事件检测,SED)是分类的时间维度对应任务:不仅仅要知道存在哪些事件,还要知道它们何时开始和结束。AudioSet(Gemmeke 等人,2017)是大规模基准,包含 527 个事件类别和超过 200 万个来自 YouTube 的 10 秒片段,每个片段都有弱标注(片段级标签,而非帧级)。
-
弱监督 SED 必须从片段级标签学习帧级预测。标准方法使用 CNN 产生帧级类别概率,然后通过注意力池化聚合成片段级预测:
-
其中 \(f_{t,c}\) 是类别 \(c\) 在时间 \(t\) 的帧级 logit,\(\alpha_{t,c}\) 是注意力权重。片段级预测 \(\hat{Y}_c\) 根据片段级标签进行训练。
-
声学场景分类(ASC)对整体环境进行分类:"机场"、"公园"、"地铁站"、"办公室"。这是一个整体性任务:模型必须捕捉一般的声学纹理而非特定事件。DCASE 挑战系列每年对 ASC 进行基准测试,获奖系统通常使用多分辨率语谱图上的 CNN 集成。
-
音频嵌入是从大规模音频数据中学习到的通用表示,类似于可迁移到下游任务的词嵌入(第 07 章)或图像特征(第 08 章)。
-
VGGish(Hershey 等人,2017)将 VGG 图像分类网络(第 08 章)适配到音频领域。它通过一个在 AudioSet 上预训练的类 VGG CNN 处理 0.96 秒的对数梅尔语谱图块,每块产生一个 128 维嵌入。VGGish 嵌入可作为下游任务的通用音频特征,类似于 ImageNet 预训练 CNN 提供视觉特征的方式。
-
PANNs(预训练音频神经网络,Kong 等人,2020)是一系列 CNN 架构(CNN6、CNN10、CNN14),在完整的 AudioSet 上为音频标记任务训练。CNN14 使用最广泛,是一个 14 层 CNN,将对数梅尔语谱图作为输入,使用 \(3 \times 3\) 卷积。PANNs 产生 2048 维嵌入,在多种音频任务上实现了最先进的迁移学习性能。
-
音频语谱图 Transformer(AST)(Gong 等人,2021)将视觉 Transformer(ViT,第 08 章)架构直接应用于音频语谱图。语谱图被分割成 \(16 \times 16\) 的块(就像 ViT 分割图像一样),每个块被线性投影为令牌嵌入,添加位置嵌入,然后由标准 Transformer 编码器(第 07 章)处理序列。[CLS] 令牌的输出用于分类。
-
AST 受益于 ImageNet 预训练:由于语谱图是 2D 图像,AST 从 ImageNet 图像上预训练的 ViT 初始化,然后在音频上微调。这种跨模态迁移出奇地有效,因为两个域共享低级特征(边缘、纹理),并且位置嵌入可以插值以处理不同大小的语谱图。
-
HTS-AT(Chen 等人,2022)使用分层 Swin Transformer 架构(第 08 章的移位窗口注意力)改进了 AST,在降低计算成本的同时通过多尺度特征提取提升了性能。
-
BEATs(Chen 等人,2023)使用了一种音频特定的预训练策略:使用离散标记器进行迭代掩码预测(类似于文件 02 中 wav2vec 2.0 的方法,但应用于通用音频)。标记器逐步细化,创建越来越具有语义意义的离散音频令牌。
-
基于嵌入的说话人日志结合了说话人嵌入与时序建模。像 Pyannote.audio 这样的现代系统使用三阶段流水线:(1) 检测说话人切换和重叠语音的神经分割模型,(2) 应用于每个检测到的片段的嵌入提取阶段(ECAPA-TDNN),以及 (3) 聚类以在整个录音中分配说话人身份。
-
音乐信息检索(MIR)将音频分析应用于音乐。文件 01 中的谱图表示在这里尤其有用,因为音乐具有丰富的和声结构。
-
节拍跟踪检测音乐的节奏脉冲。标准方法从语谱图计算起始强度包络(检测表示音符起始的能量增加),然后使用自相关或节拍图谱找到节奏,最后使用动态规划跟踪单个节拍位置,找到最能匹配起始包络同时保持稳定节奏的节拍时间序列。
-
和弦识别识别随时间变化的和声内容。输入通常是色度图(也称为音高类别分布图):一个 12 维表示,将所有八度折叠在一起,显示 12 个音高类别(C、C#、D、…、B)中每个类别的能量。CNN 或 RNN(第 06 章)将每个时间帧分类到标准和弦标签之一(C 大调、A 小调、G7 等)。
-
色度图通过将每个频率区间映射到其音高类别,从 STFT(文件 01)计算得到:
-
其中 \(p \in \{0, 1, \ldots, 11\}\) 是音高类别,\(\text{pitch}(k)\) 将频率区间 \(k\) 映射到其 MIDI 音符编号。
-
源分离基础(详见文件 05)将音乐录音分离为单独的乐器(人声、鼓、贝斯、其他)。这是混音、卡拉 OK 和音乐转录等 MIR 应用的核心。像 Demucs(文件 05)这样的模型在标准 MUSDB18 基准上达到了非常好的分离质量。
-
音乐标记为歌曲分配标签(流派、情感、乐器、时代)。它本质上是应用于音乐的音频分类,使用相同的 CNN-语谱图方法。Million Song Dataset 和 MagnaTagATune 是标准基准。
-
音频指纹从短片段中识别特定录音,即使存在噪声、混响或压缩伪影。经典系统是 Shazam,它对星座图(语谱图中的显著峰值)进行哈希处理。神经方法学习对声学退化具有不变性、同时对不同录音保持判别性的鲁棒嵌入,这与第 06 章和第 08 章中的不变特征学习一脉相承。
编程任务(使用 Colab 或笔记本)¶
- 任务 1:带统计池化的说话人嵌入提取。 构建一个简单的 x向量风格模型,通过 TDNN 层和统计池化处理帧级特征以产生说话人嵌入。
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
# Simulate frame-level MFCC features for multiple speakers
def generate_speaker_data(key, n_speakers=5, utterances_per_speaker=20,
n_frames=100, n_features=40):
"""Generate synthetic speaker data with speaker-dependent patterns."""
keys = jr.split(key, 3)
all_features = []
all_labels = []
# Each speaker has a characteristic spectral pattern
speaker_patterns = jr.normal(keys[0], (n_speakers, n_features)) * 0.5
for spk in range(n_speakers):
for utt in range(utterances_per_speaker):
k = jr.fold_in(keys[1], spk * utterances_per_speaker + utt)
noise = jr.normal(k, (n_frames, n_features)) * 0.3
features = speaker_patterns[spk][None, :] + noise
all_features.append(features)
all_labels.append(spk)
perm = jr.permutation(keys[2], len(all_features))
features = jnp.stack(all_features)[perm]
labels = jnp.array(all_labels)[perm]
return features, labels
key = jr.PRNGKey(42)
features, labels = generate_speaker_data(key)
n_speakers = 5
n_features = 40
# x-vector-style model
def init_xvector(key, n_features=40, hidden=128, embed_dim=64, n_speakers=5):
keys = jr.split(key, 8)
params = {
# TDNN layer 1: context [-2, 2]
'tdnn1_w': jr.normal(keys[0], (5, n_features, hidden)) * jnp.sqrt(2.0 / (5 * n_features)),
'tdnn1_b': jnp.zeros(hidden),
# TDNN layer 2: context [-2, 2]
'tdnn2_w': jr.normal(keys[1], (5, hidden, hidden)) * jnp.sqrt(2.0 / (5 * hidden)),
'tdnn2_b': jnp.zeros(hidden),
# TDNN layer 3: context [-3, 3]
'tdnn3_w': jr.normal(keys[2], (7, hidden, hidden)) * jnp.sqrt(2.0 / (7 * hidden)),
'tdnn3_b': jnp.zeros(hidden),
# Segment-level layers (after pooling: 2*hidden -> embed_dim)
'seg1_w': jr.normal(keys[3], (2 * hidden, embed_dim)) * jnp.sqrt(2.0 / (2 * hidden)),
'seg1_b': jnp.zeros(embed_dim),
# Classification head
'cls_w': jr.normal(keys[4], (embed_dim, n_speakers)) * jnp.sqrt(2.0 / embed_dim),
'cls_b': jnp.zeros(n_speakers),
}
return params
def xvector_forward(params, x, return_embedding=False):
"""x: (batch, frames, features) -> logits or embeddings."""
# TDNN layers (1D convolutions)
h = jax.lax.conv_general_dilated(
x.transpose(0, 2, 1), params['tdnn1_w'].transpose(2, 1, 0),
window_strides=(1,), padding='SAME'
).transpose(0, 2, 1) + params['tdnn1_b']
h = jax.nn.relu(h)
h = jax.lax.conv_general_dilated(
h.transpose(0, 2, 1), params['tdnn2_w'].transpose(2, 1, 0),
window_strides=(1,), padding='SAME'
).transpose(0, 2, 1) + params['tdnn2_b']
h = jax.nn.relu(h)
h = jax.lax.conv_general_dilated(
h.transpose(0, 2, 1), params['tdnn3_w'].transpose(2, 1, 0),
window_strides=(1,), padding='SAME'
).transpose(0, 2, 1) + params['tdnn3_b']
h = jax.nn.relu(h)
# Statistics pooling: mean and std over time
mu = jnp.mean(h, axis=1)
sigma = jnp.std(h, axis=1)
pooled = jnp.concatenate([mu, sigma], axis=-1)
# Segment-level layer -> embedding
embedding = jax.nn.relu(pooled @ params['seg1_w'] + params['seg1_b'])
if return_embedding:
return embedding
# Classification
logits = embedding @ params['cls_w'] + params['cls_b']
return logits
def cross_entropy_loss(params, features, labels):
logits = xvector_forward(params, features)
one_hot = jax.nn.one_hot(labels, n_speakers)
log_probs = jax.nn.log_softmax(logits)
return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
grad_fn = jax.jit(jax.value_and_grad(cross_entropy_loss))
# Train
params = init_xvector(jr.PRNGKey(0))
lr = 1e-3
losses = []
for epoch in range(300):
loss_val, grads = grad_fn(params, features, labels)
params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
losses.append(float(loss_val))
# Extract embeddings and visualise with t-SNE-style 2D projection (using PCA)
embeddings = xvector_forward(params, features, return_embedding=True)
# Simple PCA to 2D
emb_centered = embeddings - jnp.mean(embeddings, axis=0)
_, _, Vt = jnp.linalg.svd(emb_centered, full_matrices=False)
proj_2d = emb_centered @ Vt[:2].T
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
axes[0].plot(losses, color='#3498db', linewidth=1.5)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Cross-Entropy Loss')
axes[0].set_title('Speaker Classification Training')
axes[0].set_yscale('log')
colors = ['#3498db', '#e74c3c', '#27ae60', '#f39c12', '#9b59b6']
for spk in range(n_speakers):
mask = labels == spk
axes[1].scatter(proj_2d[mask, 0], proj_2d[mask, 1], c=colors[spk],
label=f'Speaker {spk}', alpha=0.7, s=30)
axes[1].set_xlabel('PC 1')
axes[1].set_ylabel('PC 2')
axes[1].set_title('Speaker Embeddings (PCA projection)')
axes[1].legend()
plt.tight_layout()
plt.show()
# Verification demo: cosine similarity
emb_norm = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
sim_matrix = emb_norm @ emb_norm.T
print(f"Embedding shape: {embeddings.shape}")
print(f"Avg same-speaker similarity: {jnp.mean(sim_matrix[labels[:, None] == labels[None, :]]):.4f}")
print(f"Avg diff-speaker similarity: {jnp.mean(sim_matrix[labels[:, None] != labels[None, :]]):.4f}")
- 任务 2:基于余弦相似度评分的说话人确认。 给定预计算的说话人嵌入,实现一个计算 EER(等错误率)并绘制 DET 曲线的确认系统。
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
def generate_verification_pairs(key, n_speakers=20, dim=64, n_pairs=2000):
"""Generate speaker embeddings and verification trial pairs."""
keys = jr.split(key, 5)
# Speaker centroids with some variance
centroids = jr.normal(keys[0], (n_speakers, dim))
centroids = centroids / jnp.linalg.norm(centroids, axis=-1, keepdims=True)
# Generate enrollment and test embeddings with intra-speaker variance
enroll_embs = []
test_embs = []
trial_labels = [] # 1 = same speaker (target), 0 = different (impostor)
for i in range(n_pairs):
k1, k2, k3 = jr.split(jr.fold_in(keys[1], i), 3)
is_target = jr.bernoulli(k1).astype(int)
spk1 = jr.randint(k2, (), 0, n_speakers)
emb1 = centroids[spk1] + jr.normal(jr.fold_in(k3, 0), (dim,)) * 0.15
if is_target:
spk2 = spk1
else:
spk2 = (spk1 + jr.randint(jr.fold_in(k3, 1), (), 1, n_speakers)) % n_speakers
emb2 = centroids[spk2] + jr.normal(jr.fold_in(k3, 2), (dim,)) * 0.15
enroll_embs.append(emb1)
test_embs.append(emb2)
trial_labels.append(int(is_target))
return (jnp.stack(enroll_embs), jnp.stack(test_embs),
jnp.array(trial_labels))
key = jr.PRNGKey(42)
enroll, test, labels = generate_verification_pairs(key)
# Compute cosine similarity scores
enroll_norm = enroll / jnp.linalg.norm(enroll, axis=-1, keepdims=True)
test_norm = test / jnp.linalg.norm(test, axis=-1, keepdims=True)
scores = jnp.sum(enroll_norm * test_norm, axis=-1)
# Compute FAR and FRR at various thresholds
thresholds = jnp.linspace(-1.0, 1.0, 500)
target_scores = scores[labels == 1]
impostor_scores = scores[labels == 0]
fars = []
frrs = []
for thresh in thresholds:
far = jnp.mean(impostor_scores >= thresh) # false accepts
frr = jnp.mean(target_scores < thresh) # false rejects
fars.append(float(far))
frrs.append(float(frr))
fars = jnp.array(fars)
frrs = jnp.array(frrs)
# Find EER: where FAR ≈ FRR
eer_idx = jnp.argmin(jnp.abs(fars - frrs))
eer = float((fars[eer_idx] + frrs[eer_idx]) / 2)
eer_threshold = float(thresholds[eer_idx])
print(f"Equal Error Rate (EER): {eer:.4f} ({eer*100:.2f}%)")
print(f"EER threshold: {eer_threshold:.4f}")
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
# Score distributions
bins = jnp.linspace(-0.5, 1.0, 60)
axes[0].hist(target_scores, bins=bins, alpha=0.6, color='#27ae60',
label='Target (same speaker)', density=True)
axes[0].hist(impostor_scores, bins=bins, alpha=0.6, color='#e74c3c',
label='Impostor (different speaker)', density=True)
axes[0].axvline(eer_threshold, color='#f39c12', linestyle='--', linewidth=2,
label=f'EER threshold = {eer_threshold:.3f}')
axes[0].set_xlabel('Cosine Similarity Score')
axes[0].set_ylabel('Density')
axes[0].set_title('Score Distributions')
axes[0].legend()
# FAR vs FRR
axes[1].plot(thresholds, fars, color='#e74c3c', linewidth=2, label='FAR')
axes[1].plot(thresholds, frrs, color='#3498db', linewidth=2, label='FRR')
axes[1].axvline(eer_threshold, color='#f39c12', linestyle='--', linewidth=1.5)
axes[1].scatter([eer_threshold], [eer], color='#f39c12', s=100, zorder=5,
label=f'EER = {eer:.4f}')
axes[1].set_xlabel('Threshold')
axes[1].set_ylabel('Error Rate')
axes[1].set_title('FAR and FRR vs Threshold')
axes[1].legend()
# DET curve (FAR vs FRR)
axes[2].plot(fars, frrs, color='#9b59b6', linewidth=2)
axes[2].plot([0, 1], [0, 1], 'k--', alpha=0.3)
axes[2].scatter([eer], [eer], color='#f39c12', s=100, zorder=5,
label=f'EER = {eer:.4f}')
axes[2].set_xlabel('False Acceptance Rate')
axes[2].set_ylabel('False Rejection Rate')
axes[2].set_title('DET Curve')
axes[2].set_xlim([0, 0.5])
axes[2].set_ylim([0, 0.5])
axes[2].legend()
axes[2].set_aspect('equal')
plt.tight_layout()
plt.show()
- 任务 3:音频语谱图块嵌入(AST 风格)。 实现音频语谱图 Transformer 的块提取和嵌入层,可视化语谱图如何被令牌化。
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
# Generate a synthetic spectrogram (harmonic structure + noise)
def generate_spectrogram(key, n_time=128, n_freq=128):
"""Create a synthetic spectrogram with harmonic patterns."""
k1, k2 = jr.split(key)
spec = jr.normal(k1, (n_time, n_freq)) * 0.1
# Add harmonic bands (simulating speech formants)
for f0 in [15, 30, 45, 70]:
width = 3
envelope = jnp.exp(-0.5 * ((jnp.arange(n_freq) - f0) / width) ** 2)
time_mod = 0.5 + 0.5 * jnp.sin(2 * jnp.pi * jnp.arange(n_time) / 40)
spec += jnp.outer(time_mod, envelope)
return jnp.clip(spec, 0, None)
key = jr.PRNGKey(42)
spectrogram = generate_spectrogram(key)
n_time, n_freq = spectrogram.shape
# Patch extraction parameters
patch_h = 16 # time
patch_w = 16 # frequency
stride_h = 16
stride_w = 16
embed_dim = 192 # ViT-Small dimension
n_patches_h = n_time // stride_h
n_patches_w = n_freq // stride_w
n_patches = n_patches_h * n_patches_w
print(f"Spectrogram: {n_time} x {n_freq}")
print(f"Patch size: {patch_h} x {patch_w}")
print(f"Number of patches: {n_patches_h} x {n_patches_w} = {n_patches}")
# Extract patches
def extract_patches(spec, patch_h, patch_w, stride_h, stride_w):
"""Extract non-overlapping patches from spectrogram."""
patches = []
positions = []
for i in range(0, spec.shape[0] - patch_h + 1, stride_h):
for j in range(0, spec.shape[1] - patch_w + 1, stride_w):
patch = spec[i:i+patch_h, j:j+patch_w]
patches.append(patch.flatten())
positions.append((i, j))
return jnp.stack(patches), positions
patches, positions = extract_patches(spectrogram, patch_h, patch_w, stride_h, stride_w)
print(f"Patches shape: {patches.shape}") # (n_patches, patch_h * patch_w)
# Linear projection (patch embedding)
patch_dim = patch_h * patch_w
k1, k2 = jr.split(jr.PRNGKey(0))
W_embed = jr.normal(k1, (patch_dim, embed_dim)) * jnp.sqrt(2.0 / patch_dim)
b_embed = jnp.zeros(embed_dim)
# Learnable positional embeddings
pos_embed = jr.normal(k2, (n_patches + 1, embed_dim)) * 0.02 # +1 for CLS
# CLS token
cls_token = jnp.zeros((1, embed_dim))
# Forward pass
patch_tokens = patches @ W_embed + b_embed # (n_patches, embed_dim)
tokens = jnp.concatenate([cls_token, patch_tokens], axis=0) # (n_patches+1, embed_dim)
tokens = tokens + pos_embed # Add positional embeddings
print(f"Token sequence shape: {tokens.shape}")
print(f"Each token has dimension: {embed_dim}")
# Visualisation
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Original spectrogram with patch grid
axes[0, 0].imshow(spectrogram.T, aspect='auto', origin='lower', cmap='magma')
for i in range(0, n_time + 1, stride_h):
axes[0, 0].axvline(i - 0.5, color='white', linewidth=0.5, alpha=0.5)
for j in range(0, n_freq + 1, stride_w):
axes[0, 0].axhline(j - 0.5, color='white', linewidth=0.5, alpha=0.5)
axes[0, 0].set_title(f'Spectrogram with {patch_h}x{patch_w} Patch Grid')
axes[0, 0].set_xlabel('Time frame')
axes[0, 0].set_ylabel('Frequency bin')
# Individual patches visualised
n_show = min(16, n_patches)
patch_grid = patches[:n_show].reshape(n_show, patch_h, patch_w)
combined = jnp.concatenate([patch_grid[i] for i in range(min(8, n_show))], axis=1)
axes[0, 1].imshow(combined.T, aspect='auto', origin='lower', cmap='magma')
axes[0, 1].set_title(f'First {min(8, n_show)} Patches (concatenated)')
axes[0, 1].set_xlabel('Patch index (horizontal)')
axes[0, 1].set_ylabel('Frequency within patch')
# Token embeddings similarity matrix
token_norms = tokens / jnp.linalg.norm(tokens, axis=-1, keepdims=True)
sim = token_norms @ token_norms.T
im = axes[1, 0].imshow(sim, cmap='RdBu_r', vmin=-1, vmax=1)
axes[1, 0].set_title('Token Similarity Matrix (cosine)')
axes[1, 0].set_xlabel('Token index')
axes[1, 0].set_ylabel('Token index')
plt.colorbar(im, ax=axes[1, 0], fraction=0.046)
# Positional embedding similarity
pos_norms = pos_embed / jnp.linalg.norm(pos_embed, axis=-1, keepdims=True)
pos_sim = pos_norms @ pos_norms.T
im2 = axes[1, 1].imshow(pos_sim, cmap='RdBu_r', vmin=-1, vmax=1)
axes[1, 1].set_title('Positional Embedding Similarity')
axes[1, 1].set_xlabel('Position index')
axes[1, 1].set_ylabel('Position index')
plt.colorbar(im2, ax=axes[1, 1], fraction=0.046)
plt.tight_layout()
plt.show()
- 任务 4:用于和弦分析的简单色度图计算。 从合成和声信号计算并可视化色度图,展示音乐信息检索中使用的音高类别折叠方法。
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
# Generate a synthetic musical signal: C major chord -> G major chord
sr = 16000
duration = 2.0
t = jnp.linspace(0, duration, int(sr * duration))
# C major (C4=261.6, E4=329.6, G4=392.0) for first half
# G major (G3=196.0, B3=246.9, D4=293.7) for second half
half = len(t) // 2
c_major = (0.5 * jnp.sin(2 * jnp.pi * 261.63 * t[:half]) +
0.4 * jnp.sin(2 * jnp.pi * 329.63 * t[:half]) +
0.3 * jnp.sin(2 * jnp.pi * 392.00 * t[:half]))
g_major = (0.5 * jnp.sin(2 * jnp.pi * 196.00 * t[:half]) +
0.4 * jnp.sin(2 * jnp.pi * 246.94 * t[:half]) +
0.3 * jnp.sin(2 * jnp.pi * 293.66 * t[:half]))
signal = jnp.concatenate([c_major, g_major])
# Compute STFT
n_fft = 4096 # high resolution for pitch accuracy
hop_length = 512
window = jnp.hanning(n_fft)
def stft(signal, n_fft, hop_length, window):
n_frames = 1 + (len(signal) - n_fft) // hop_length
frames = jnp.stack([
signal[i * hop_length : i * hop_length + n_fft] * window
for i in range(n_frames)
])
return jnp.fft.rfft(frames, n=n_fft)
S = stft(signal, n_fft, hop_length, window)
power_spec = jnp.abs(S) ** 2
freqs = jnp.fft.rfftfreq(n_fft, 1.0 / sr)
# Compute chromagram by mapping frequency bins to pitch classes
# MIDI note number from frequency: 69 + 12 * log2(f / 440)
note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
def freq_to_chroma(freq):
"""Map frequency to pitch class (0-11). Returns -1 for freq <= 0."""
midi = 69 + 12 * jnp.log2(jnp.clip(freq, 1e-10, None) / 440.0)
return jnp.round(midi).astype(int) % 12
# Build chromagram: sum power spectrum energy for each pitch class
chromagram = jnp.zeros((power_spec.shape[0], 12))
valid_freqs = freqs[1:] # skip DC
valid_power = power_spec[:, 1:]
for p in range(12):
# Find frequency bins belonging to this pitch class
chroma_bins = freq_to_chroma(valid_freqs)
mask = (chroma_bins == p).astype(jnp.float32)
chromagram = chromagram.at[:, p].set(
jnp.sum(valid_power * mask[None, :], axis=1)
)
# Normalise each frame
chromagram = chromagram / (jnp.max(chromagram, axis=1, keepdims=True) + 1e-8)
# Visualisation
fig, axes = plt.subplots(3, 1, figsize=(14, 10))
# Waveform
axes[0].plot(t[:3000], signal[:3000], color='#3498db', linewidth=0.5,
label='C major')
axes[0].plot(t[half:half+3000], signal[half:half+3000], color='#e74c3c',
linewidth=0.5, label='G major')
axes[0].set_title('Waveform: C major → G major')
axes[0].set_ylabel('Amplitude')
axes[0].set_xlabel('Time (s)')
axes[0].legend()
# Spectrogram (log scale)
time_axis = jnp.arange(power_spec.shape[0]) * hop_length / sr
axes[1].imshow(jnp.log1p(power_spec[:, :500].T), aspect='auto', origin='lower',
cmap='magma', extent=[0, time_axis[-1], 0, freqs[500]])
axes[1].set_title('Power Spectrogram')
axes[1].set_ylabel('Frequency (Hz)')
axes[1].set_xlabel('Time (s)')
# Chromagram
im = axes[2].imshow(chromagram.T, aspect='auto', origin='lower', cmap='YlOrRd',
extent=[0, time_axis[-1], -0.5, 11.5])
axes[2].set_yticks(range(12))
axes[2].set_yticklabels(note_names)
axes[2].set_title('Chromagram (pitch class energy over time)')
axes[2].set_ylabel('Pitch class')
axes[2].set_xlabel('Time (s)')
plt.colorbar(im, ax=axes[2], fraction=0.046, label='Normalised energy')
# Mark expected active pitch classes
mid_frame = chromagram.shape[0] // 2
print(f"C major region - expected: C, E, G")
print(f" Chroma values: {dict(zip(note_names, [f'{v:.2f}' for v in chromagram[mid_frame//2]]))}")
print(f"G major region - expected: G, B, D")
print(f" Chroma values: {dict(zip(note_names, [f'{v:.2f}' for v in chromagram[mid_frame + mid_frame//2]]))}")
plt.tight_layout()
plt.show()