Files
flykhan 2536c937e3 feat: 完整中文翻译 maths-cs-ai-compendium(数学·计算机科学·AI 知识大全)
翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。

第01章 向量 | 第02章 矩阵 | 第03章 微积分
第04章 统计学 | 第05章 概率论 | 第06章 机器学习
第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音
第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络
第13章 计算与操作系统 | 第14章 数据结构与算法
第15章 生产级软件工程 | 第16章 SIMD与GPU编程
第17章 AI推理 | 第18章 ML系统设计
第19章 应用人工智能 | 第20章 前沿人工智能

翻译说明:
- 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留
- mkdocs.yml 配置中文导航 + language: zh
- README.md 已翻译为中文(兼 docs/index.md)
- docs/ 目录包含指向各章文件的 symlink
- 约 29,000 行中文内容,排除 .cache/ 构建缓存
2026-05-03 10:23:20 +08:00

644 lines
34 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 说话人与音频分析
*说话人与音频分析识别谁在说话、何时说话以及存在哪些非语言声音。本文涵盖说话人确认与识别、i向量、d向量、x向量、说话人日志、音频事件分类、音乐信息检索以及语音情感识别。*
- 在文件 01 中,我们构建了信号处理基础:语谱图、MFCC 和梅尔滤波器组。在文件 02 中,我们识别了所说的内容。现在我们要问:是谁说的、何时说的、以及音频中还在发生什么。说话人识别、说话人日志、音频分类和音乐分析都共享一条主线:学习能够为当前任务捕捉正确不变性的紧凑嵌入,这与第 06 章中的嵌入思想一脉相承。
- 可以把说话人识别想象成在电话中辨认朋友的声音。你不需要理解词汇;某种关于音色、语速和嗓音特质的东西对这个人来说是独一无二的。说话人识别系统学会从原始音频中提取这种"声纹",忽略说的是什么,专注于怎么说的。
- **说话人识别**是两类相关任务的总称:
- **说话人确认**(SV):给定一个声明的身份和一段音频片段,判断说话人是否与其声称的身份一致。这是一个二元决策(接受或拒绝),是基于语音的身份验证技术("嘿 Siri,这是我的声音吗?")背后的核心原理。
- **说话人识别**(SI):给定一段音频片段和一个已知说话人库,判断该片段由哪个说话人产生。这是一个多分类问题。
![说话人确认:注册音频被嵌入,测试音频被嵌入,计算嵌入之间的余弦相似度,通过阈值决定接受或拒绝](../images/speaker_verification.svg)
- 两种任务共享相同的底层表示:一个固定维度的**说话人嵌入**,它捕捉说话人的身份特征而与所说内容无关。区别仅在于决策阶段:确认比较两个嵌入,识别则在候选嵌入中找到最近邻。
- **余弦相似度**是比较说话人嵌入的标准度量。给定注册嵌入 $e$ 和测试嵌入 $t$:
$$s = \frac{e \cdot t}{\|e\| \, \|t\|}$$
- 阈值 $\theta$ 决定接受/拒绝决策:若 $s > \theta$,则接受。阈值在**错误接受率(FAR)**和**错误拒绝率(FRR)**之间权衡。**等错误率(EER)**,即 FAR = FRR 时的值,是标准评估指标。EER 越低表示性能越好。最先进的系统在标准基准(VoxCeleb)上可实现低于 1% 的 EER。
- **i向量**Dehak 等人,2010)是深度学习之前主导性的说话人嵌入方法。其思想源于因子分析(第 02 章的矩阵分解和第 04 章的降维)。一个**通用背景模型(UBM)**——基于多样本说话人训练的大型 GMM——定义了一个超向量空间。每条语音的 GMM 超向量被投影到低维的**全可变性空间**:
$$M = m + Tw$$
- 其中 $M$ 是该语音的 GMM 超向量,$m$ 是 UBM 均值超向量,$T$ 是全可变性矩阵(从数据中学习得到),$w$ 是 i 向量,一个低维(通常为 400-600 维)表示,同时捕捉说话人变异和信道变异。
- 为了从 i 向量中去除信道变异,**概率线性判别分析(PLDA)**将 i 向量建模为说话人特定潜变量和信道特定潜变量之和。PLDA 为确认任务提供了一个有原则的对数似然比分数:
$$\text{score}(w_1, w_2) = \log \frac{P(w_1, w_2 \mid \text{同一说话人})}{P(w_1 \mid \text{说话人}_1) \, P(w_2 \mid \text{说话人}_2)}$$
- **d向量**Variani 等人,2014)是第一个神经说话人嵌入。一个为说话人分类训练的 DNN 处理帧级特征,通过对整条语音中最后一层隐藏层激活值求平均,提取出固定维度的表示。虽然简单但有效,d向量证明了神经网络可以在没有 i 向量复杂统计机制的情况下学习到说话人判别性特征。
- **x向量**Snyder 等人,2018)使用**时延神经网络(TDNN)**架构显著推进了神经说话人嵌入。TDNN 是具有特定上下文窗口的 1D 卷积,与文件 03 中 WaveNet 的扩张卷积有关,但应用于帧级特征而非原始波形样本。
![x向量架构:TDNN 层以递增的上下文处理帧级特征,统计池化在时间维度上聚合,全连接层产生说话人嵌入](../images/xvector_architecture.svg)
- x向量架构包含三个阶段:
- **帧级层**:一组 TDNN 层处理 MFCC(来自文件 01),时间上下文逐步扩大。每一层都有一个固定的上下文窗口(例如第一层为 $\{t-2, t-1, t, t+1, t+2\}$,后续层窗口更宽)。
- **统计池化**:在帧级层之后,计算帧级输出在整个语音上的均值和标准差,产生一个与语音时长无关的固定维度向量:
```math
\begin{aligned}
\mu &= \frac{1}{T} \sum_{t=1}^{T} h_t \\
\sigma &= \sqrt{\frac{1}{T} \sum_{t=1}^{T} (h_t - \mu)^2}
\end{aligned}
```
- 其中 $h_t$ 是时间 $t$ 的帧级输出。拼接 $[\mu; \sigma]$ 即为池化后的表示。
- **段级层**:全连接层处理池化后的表示。第一个段级层的输出(softmax 之前)即为 x 向量嵌入。
- x向量使用说话人身份上的标准交叉熵损失进行训练。尽管是为分类任务训练的,但学习到的中间表示(x向量)能很好地泛化到未见过的说话人,因为网络学习的是提取说话人判别性特征,而非记忆特定说话人。
- **ECAPA-TDNN**Desplanques 等人,2020)是目前最先进的基于 TDNN 的说话人识别架构。它在 x 向量基础上引入了三项改进:
- **压缩激励(SE)模块**:通道注意力(来自第 08 章的 SENet),根据全局上下文重新加权特征通道,使模型能够强调与说话人相关的通道。
- **Res2Net 风格的多尺度特征**:在每个 TDNN 模块内,通道被分成若干组,以层级方式处理,在多个时间分辨率上创建特征(类似于第 08 章的多尺度特征提取)。
- **注意力统计池化**:不再使用等权平均,而是通过注意力机制为每一帧对池化统计量的贡献分配权重。包含更多说话人判别性内容的帧(如元音,承载更多说话人信息)获得更高的注意力权重:
$$\alpha_t = \frac{\exp(v^T f(h_t))}{\sum_{\tau} \exp(v^T f(h_\tau))}$$
- 其中 $f$ 是一个小型神经网络,$v$ 是一个学习到的注意力向量。注意力加权的均值和标准差变为 $\tilde{\mu} = \sum_t \alpha_t h_t$ 和 $\tilde{\sigma} = \sqrt{\sum_t \alpha_t (h_t - \tilde{\mu})^2}$。
- ECAPA-TDNN 通常使用 **AAM-Softmax**(附加角度间隔 Softmax)进行训练,它在分类损失中添加了角度间隔惩罚,将同一说话人的嵌入推得更近,不同说话人的嵌入在超球面上推得更远:
$$L = -\log \frac{e^{s \cos(\theta_{y_i} + m)}}{e^{s \cos(\theta_{y_i} + m)} + \sum_{j \neq y_i} e^{s \cos \theta_j}}$$
- 其中 $\theta_{y_i}$ 是嵌入与真实类别权重向量之间的夹角,$m$ 是间隔(通常为 0.2),$s$ 是缩放因子(通常为 30)。该损失函数来自人脸识别(第 08 章的 ArcFace),在说话人确认中非常有效。
- **说话人日志**回答了多方录音中"谁在什么时候说话"的问题。可以把这想象成给时间线上色:每种颜色代表一个不同的说话人,系统必须确定每个说话人何时活跃,包括重叠语音的情况。
![说话人日志:音频时间线被分割并用说话人身份标注,展示交替说话和重叠区域](../images/speaker_diarisation.svg)
- **基于聚类的说话人日志**是传统的流水线方法:
- **分割**:将音频划分为短段(通常为 1-2 秒),使用滑动窗口或说话人变化检测。
- **嵌入提取**:为每个片段提取说话人嵌入(x向量、ECAPA-TDNN)。
- **聚类**:按说话人对片段进行分组。**凝聚层次聚类(AHC)**是标准方法:开始时每个片段自成一类,然后迭代合并两个最相似的类,直到满足停止条件(基于距离阈值或目标说话人数)。
- **重分割**:使用基于维特比算法的重对齐来优化边界。
- 说话人数量通常事先未知,这使得该问题比标准聚类更困难。使用基于特征值阈值确定 $k$ 的谱聚类是另一种常见方法。
- **端到端神经说话人日志(EEND)**(Fujita 等人,2019)将说话人日志框架化为一个多标签分类问题。一个神经网络(通常是基于自注意力的模型,第 07 章的 transformer)将整段录音作为输入,为每一帧输出每个说话人的二元活动标签。这直接处理了重叠语音,而这是基于聚类方法的主要弱点。
- EEND 对 $S$ 个说话人在帧 $t$ 的输出为:
$$\hat{y}_{t,s} = \sigma(f_s(h_t))$$
- 其中 $h_t$ 是帧 $t$ 处的 transformer 输出,$f_s$ 是说话人 $s$ 的线性投影。训练损失是在说话人和帧上求和得到的二元交叉熵。一个关键挑战是说话人数量必须固定,或者使用可变输出架构(EEND-EDA 使用带吸引子的编码器-解码器)来处理。
- **置换不变训练(PIT)**用于处理说话人日志中的标签歧义问题:由于说话人没有固有顺序,需要对所有可能的说话人到输出分配计算损失,并取最小值(这与文件 05 中源分离使用的 PIT 相同)。
- **音频分类**为整段音频片段分配一个标签。与转录语音的 ASR(文件 02)不同,音频分类涵盖更广的范围:环境声音(警笛、雨声、狗吠)、音乐流派(摇滚、爵士、古典)以及一般音频事件。
- 标准方法遵循第 08 章的图像分类范式:将音频表示为语谱图(一个二维时间-频率图像),然后应用 CNN 或 transformer 分类器。这种谱图-图像方法利用了计算机视觉几十年来的进展。
- **环境声音分类(ESC**使用 ESC-5050 类,2000 个片段)和 UrbanSound8K 等数据集。典型架构是应用于对数梅尔语谱图的 CNN(第 06 章)。数据增强至关重要:时间拉伸、音高偏移、添加背景噪声以及 **SpecAugment**(文件 02 的掩码方法应用于语谱图)都能提升泛化能力。
- **音频事件检测**(声音事件检测,SED)是分类的时间维度对应任务:不仅仅要知道存在哪些事件,还要知道它们何时开始和结束。**AudioSet**Gemmeke 等人,2017)是大规模基准,包含 527 个事件类别和超过 200 万个来自 YouTube 的 10 秒片段,每个片段都有弱标注(片段级标签,而非帧级)。
- **弱监督 SED** 必须从片段级标签学习帧级预测。标准方法使用 CNN 产生帧级类别概率,然后通过注意力池化聚合成片段级预测:
$$\hat{Y}_c = \sigma\left(\sum_t \alpha_{t,c} \cdot f_{t,c}\right)$$
- 其中 $f_{t,c}$ 是类别 $c$ 在时间 $t$ 的帧级 logit$\alpha_{t,c}$ 是注意力权重。片段级预测 $\hat{Y}_c$ 根据片段级标签进行训练。
- **声学场景分类(ASC)**对整体环境进行分类:"机场"、"公园"、"地铁站"、"办公室"。这是一个整体性任务:模型必须捕捉一般的声学纹理而非特定事件。DCASE 挑战系列每年对 ASC 进行基准测试,获奖系统通常使用多分辨率语谱图上的 CNN 集成。
- **音频嵌入**是从大规模音频数据中学习到的通用表示,类似于可迁移到下游任务的词嵌入(第 07 章)或图像特征(第 08 章)。
- **VGGish**Hershey 等人,2017)将 VGG 图像分类网络(第 08 章)适配到音频领域。它通过一个在 AudioSet 上预训练的类 VGG CNN 处理 0.96 秒的对数梅尔语谱图块,每块产生一个 128 维嵌入。VGGish 嵌入可作为下游任务的通用音频特征,类似于 ImageNet 预训练 CNN 提供视觉特征的方式。
- **PANNs**(预训练音频神经网络,Kong 等人,2020)是一系列 CNN 架构(CNN6、CNN10、CNN14),在完整的 AudioSet 上为音频标记任务训练。CNN14 使用最广泛,是一个 14 层 CNN,将对数梅尔语谱图作为输入,使用 $3 \times 3$ 卷积。PANNs 产生 2048 维嵌入,在多种音频任务上实现了最先进的迁移学习性能。
- **音频语谱图 TransformerAST**Gong 等人,2021)将视觉 TransformerViT,第 08 章)架构直接应用于音频语谱图。语谱图被分割成 $16 \times 16$ 的块(就像 ViT 分割图像一样),每个块被线性投影为令牌嵌入,添加位置嵌入,然后由标准 Transformer 编码器(第 07 章)处理序列。[CLS] 令牌的输出用于分类。
![音频语谱图 Transformer:梅尔语谱图被分割成块,每个块展平并线性投影为令牌,添加位置嵌入,Transformer 编码器通过 CLS 令牌产生分类输出](../images/audio_spectrogram_transformer.svg)
- AST 受益于 **ImageNet 预训练**:由于语谱图是 2D 图像,AST 从 ImageNet 图像上预训练的 ViT 初始化,然后在音频上微调。这种跨模态迁移出奇地有效,因为两个域共享低级特征(边缘、纹理),并且位置嵌入可以插值以处理不同大小的语谱图。
- **HTS-AT**Chen 等人,2022)使用分层 Swin Transformer 架构(第 08 章的移位窗口注意力)改进了 AST,在降低计算成本的同时通过多尺度特征提取提升了性能。
- **BEATs**Chen 等人,2023)使用了一种音频特定的预训练策略:使用离散标记器进行迭代掩码预测(类似于文件 02 中 wav2vec 2.0 的方法,但应用于通用音频)。标记器逐步细化,创建越来越具有语义意义的离散音频令牌。
- **基于嵌入的说话人日志**结合了说话人嵌入与时序建模。像 Pyannote.audio 这样的现代系统使用三阶段流水线:(1) 检测说话人切换和重叠语音的神经分割模型,(2) 应用于每个检测到的片段的嵌入提取阶段(ECAPA-TDNN),以及 (3) 聚类以在整个录音中分配说话人身份。
- **音乐信息检索(MIR)**将音频分析应用于音乐。文件 01 中的谱图表示在这里尤其有用,因为音乐具有丰富的和声结构。
- **节拍跟踪**检测音乐的节奏脉冲。标准方法从语谱图计算**起始强度包络**(检测表示音符起始的能量增加),然后使用自相关或节拍图谱找到节奏,最后使用动态规划跟踪单个节拍位置,找到最能匹配起始包络同时保持稳定节奏的节拍时间序列。
- **和弦识别**识别随时间变化的和声内容。输入通常是**色度图**(也称为音高类别分布图):一个 12 维表示,将所有八度折叠在一起,显示 12 个音高类别(C、C#、D、…、B)中每个类别的能量。CNN 或 RNN(第 06 章)将每个时间帧分类到标准和弦标签之一(C 大调、A 小调、G7 等)。
- 色度图通过将每个频率区间映射到其音高类别,从 STFT(文件 01)计算得到:
$$\text{chroma}(p) = \sum_{k : \text{pitch}(k) \bmod 12 = p} |X(k)|^2$$
- 其中 $p \in \{0, 1, \ldots, 11\}$ 是音高类别,$\text{pitch}(k)$ 将频率区间 $k$ 映射到其 MIDI 音符编号。
- **源分离基础**(详见文件 05)将音乐录音分离为单独的乐器(人声、鼓、贝斯、其他)。这是混音、卡拉 OK 和音乐转录等 MIR 应用的核心。像 Demucs(文件 05)这样的模型在标准 MUSDB18 基准上达到了非常好的分离质量。
- **音乐标记**为歌曲分配标签(流派、情感、乐器、时代)。它本质上是应用于音乐的音频分类,使用相同的 CNN-语谱图方法。Million Song Dataset 和 MagnaTagATune 是标准基准。
- **音频指纹**从短片段中识别特定录音,即使存在噪声、混响或压缩伪影。经典系统是 Shazam,它对星座图(语谱图中的显著峰值)进行哈希处理。神经方法学习对声学退化具有不变性、同时对不同录音保持判别性的鲁棒嵌入,这与第 06 章和第 08 章中的不变特征学习一脉相承。
## 编程任务(使用 Colab 或笔记本)
- **任务 1:带统计池化的说话人嵌入提取。** 构建一个简单的 x向量风格模型,通过 TDNN 层和统计池化处理帧级特征以产生说话人嵌入。
```python
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
# Simulate frame-level MFCC features for multiple speakers
def generate_speaker_data(key, n_speakers=5, utterances_per_speaker=20,
n_frames=100, n_features=40):
"""Generate synthetic speaker data with speaker-dependent patterns."""
keys = jr.split(key, 3)
all_features = []
all_labels = []
# Each speaker has a characteristic spectral pattern
speaker_patterns = jr.normal(keys[0], (n_speakers, n_features)) * 0.5
for spk in range(n_speakers):
for utt in range(utterances_per_speaker):
k = jr.fold_in(keys[1], spk * utterances_per_speaker + utt)
noise = jr.normal(k, (n_frames, n_features)) * 0.3
features = speaker_patterns[spk][None, :] + noise
all_features.append(features)
all_labels.append(spk)
perm = jr.permutation(keys[2], len(all_features))
features = jnp.stack(all_features)[perm]
labels = jnp.array(all_labels)[perm]
return features, labels
key = jr.PRNGKey(42)
features, labels = generate_speaker_data(key)
n_speakers = 5
n_features = 40
# x-vector-style model
def init_xvector(key, n_features=40, hidden=128, embed_dim=64, n_speakers=5):
keys = jr.split(key, 8)
params = {
# TDNN layer 1: context [-2, 2]
'tdnn1_w': jr.normal(keys[0], (5, n_features, hidden)) * jnp.sqrt(2.0 / (5 * n_features)),
'tdnn1_b': jnp.zeros(hidden),
# TDNN layer 2: context [-2, 2]
'tdnn2_w': jr.normal(keys[1], (5, hidden, hidden)) * jnp.sqrt(2.0 / (5 * hidden)),
'tdnn2_b': jnp.zeros(hidden),
# TDNN layer 3: context [-3, 3]
'tdnn3_w': jr.normal(keys[2], (7, hidden, hidden)) * jnp.sqrt(2.0 / (7 * hidden)),
'tdnn3_b': jnp.zeros(hidden),
# Segment-level layers (after pooling: 2*hidden -> embed_dim)
'seg1_w': jr.normal(keys[3], (2 * hidden, embed_dim)) * jnp.sqrt(2.0 / (2 * hidden)),
'seg1_b': jnp.zeros(embed_dim),
# Classification head
'cls_w': jr.normal(keys[4], (embed_dim, n_speakers)) * jnp.sqrt(2.0 / embed_dim),
'cls_b': jnp.zeros(n_speakers),
}
return params
def xvector_forward(params, x, return_embedding=False):
"""x: (batch, frames, features) -> logits or embeddings."""
# TDNN layers (1D convolutions)
h = jax.lax.conv_general_dilated(
x.transpose(0, 2, 1), params['tdnn1_w'].transpose(2, 1, 0),
window_strides=(1,), padding='SAME'
).transpose(0, 2, 1) + params['tdnn1_b']
h = jax.nn.relu(h)
h = jax.lax.conv_general_dilated(
h.transpose(0, 2, 1), params['tdnn2_w'].transpose(2, 1, 0),
window_strides=(1,), padding='SAME'
).transpose(0, 2, 1) + params['tdnn2_b']
h = jax.nn.relu(h)
h = jax.lax.conv_general_dilated(
h.transpose(0, 2, 1), params['tdnn3_w'].transpose(2, 1, 0),
window_strides=(1,), padding='SAME'
).transpose(0, 2, 1) + params['tdnn3_b']
h = jax.nn.relu(h)
# Statistics pooling: mean and std over time
mu = jnp.mean(h, axis=1)
sigma = jnp.std(h, axis=1)
pooled = jnp.concatenate([mu, sigma], axis=-1)
# Segment-level layer -> embedding
embedding = jax.nn.relu(pooled @ params['seg1_w'] + params['seg1_b'])
if return_embedding:
return embedding
# Classification
logits = embedding @ params['cls_w'] + params['cls_b']
return logits
def cross_entropy_loss(params, features, labels):
logits = xvector_forward(params, features)
one_hot = jax.nn.one_hot(labels, n_speakers)
log_probs = jax.nn.log_softmax(logits)
return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
grad_fn = jax.jit(jax.value_and_grad(cross_entropy_loss))
# Train
params = init_xvector(jr.PRNGKey(0))
lr = 1e-3
losses = []
for epoch in range(300):
loss_val, grads = grad_fn(params, features, labels)
params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
losses.append(float(loss_val))
# Extract embeddings and visualise with t-SNE-style 2D projection (using PCA)
embeddings = xvector_forward(params, features, return_embedding=True)
# Simple PCA to 2D
emb_centered = embeddings - jnp.mean(embeddings, axis=0)
_, _, Vt = jnp.linalg.svd(emb_centered, full_matrices=False)
proj_2d = emb_centered @ Vt[:2].T
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
axes[0].plot(losses, color='#3498db', linewidth=1.5)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Cross-Entropy Loss')
axes[0].set_title('Speaker Classification Training')
axes[0].set_yscale('log')
colors = ['#3498db', '#e74c3c', '#27ae60', '#f39c12', '#9b59b6']
for spk in range(n_speakers):
mask = labels == spk
axes[1].scatter(proj_2d[mask, 0], proj_2d[mask, 1], c=colors[spk],
label=f'Speaker {spk}', alpha=0.7, s=30)
axes[1].set_xlabel('PC 1')
axes[1].set_ylabel('PC 2')
axes[1].set_title('Speaker Embeddings (PCA projection)')
axes[1].legend()
plt.tight_layout()
plt.show()
# Verification demo: cosine similarity
emb_norm = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
sim_matrix = emb_norm @ emb_norm.T
print(f"Embedding shape: {embeddings.shape}")
print(f"Avg same-speaker similarity: {jnp.mean(sim_matrix[labels[:, None] == labels[None, :]]):.4f}")
print(f"Avg diff-speaker similarity: {jnp.mean(sim_matrix[labels[:, None] != labels[None, :]]):.4f}")
```
- **任务 2:基于余弦相似度评分的说话人确认。** 给定预计算的说话人嵌入,实现一个计算 EER(等错误率)并绘制 DET 曲线的确认系统。
```python
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
def generate_verification_pairs(key, n_speakers=20, dim=64, n_pairs=2000):
"""Generate speaker embeddings and verification trial pairs."""
keys = jr.split(key, 5)
# Speaker centroids with some variance
centroids = jr.normal(keys[0], (n_speakers, dim))
centroids = centroids / jnp.linalg.norm(centroids, axis=-1, keepdims=True)
# Generate enrollment and test embeddings with intra-speaker variance
enroll_embs = []
test_embs = []
trial_labels = [] # 1 = same speaker (target), 0 = different (impostor)
for i in range(n_pairs):
k1, k2, k3 = jr.split(jr.fold_in(keys[1], i), 3)
is_target = jr.bernoulli(k1).astype(int)
spk1 = jr.randint(k2, (), 0, n_speakers)
emb1 = centroids[spk1] + jr.normal(jr.fold_in(k3, 0), (dim,)) * 0.15
if is_target:
spk2 = spk1
else:
spk2 = (spk1 + jr.randint(jr.fold_in(k3, 1), (), 1, n_speakers)) % n_speakers
emb2 = centroids[spk2] + jr.normal(jr.fold_in(k3, 2), (dim,)) * 0.15
enroll_embs.append(emb1)
test_embs.append(emb2)
trial_labels.append(int(is_target))
return (jnp.stack(enroll_embs), jnp.stack(test_embs),
jnp.array(trial_labels))
key = jr.PRNGKey(42)
enroll, test, labels = generate_verification_pairs(key)
# Compute cosine similarity scores
enroll_norm = enroll / jnp.linalg.norm(enroll, axis=-1, keepdims=True)
test_norm = test / jnp.linalg.norm(test, axis=-1, keepdims=True)
scores = jnp.sum(enroll_norm * test_norm, axis=-1)
# Compute FAR and FRR at various thresholds
thresholds = jnp.linspace(-1.0, 1.0, 500)
target_scores = scores[labels == 1]
impostor_scores = scores[labels == 0]
fars = []
frrs = []
for thresh in thresholds:
far = jnp.mean(impostor_scores >= thresh) # false accepts
frr = jnp.mean(target_scores < thresh) # false rejects
fars.append(float(far))
frrs.append(float(frr))
fars = jnp.array(fars)
frrs = jnp.array(frrs)
# Find EER: where FAR ≈ FRR
eer_idx = jnp.argmin(jnp.abs(fars - frrs))
eer = float((fars[eer_idx] + frrs[eer_idx]) / 2)
eer_threshold = float(thresholds[eer_idx])
print(f"Equal Error Rate (EER): {eer:.4f} ({eer*100:.2f}%)")
print(f"EER threshold: {eer_threshold:.4f}")
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
# Score distributions
bins = jnp.linspace(-0.5, 1.0, 60)
axes[0].hist(target_scores, bins=bins, alpha=0.6, color='#27ae60',
label='Target (same speaker)', density=True)
axes[0].hist(impostor_scores, bins=bins, alpha=0.6, color='#e74c3c',
label='Impostor (different speaker)', density=True)
axes[0].axvline(eer_threshold, color='#f39c12', linestyle='--', linewidth=2,
label=f'EER threshold = {eer_threshold:.3f}')
axes[0].set_xlabel('Cosine Similarity Score')
axes[0].set_ylabel('Density')
axes[0].set_title('Score Distributions')
axes[0].legend()
# FAR vs FRR
axes[1].plot(thresholds, fars, color='#e74c3c', linewidth=2, label='FAR')
axes[1].plot(thresholds, frrs, color='#3498db', linewidth=2, label='FRR')
axes[1].axvline(eer_threshold, color='#f39c12', linestyle='--', linewidth=1.5)
axes[1].scatter([eer_threshold], [eer], color='#f39c12', s=100, zorder=5,
label=f'EER = {eer:.4f}')
axes[1].set_xlabel('Threshold')
axes[1].set_ylabel('Error Rate')
axes[1].set_title('FAR and FRR vs Threshold')
axes[1].legend()
# DET curve (FAR vs FRR)
axes[2].plot(fars, frrs, color='#9b59b6', linewidth=2)
axes[2].plot([0, 1], [0, 1], 'k--', alpha=0.3)
axes[2].scatter([eer], [eer], color='#f39c12', s=100, zorder=5,
label=f'EER = {eer:.4f}')
axes[2].set_xlabel('False Acceptance Rate')
axes[2].set_ylabel('False Rejection Rate')
axes[2].set_title('DET Curve')
axes[2].set_xlim([0, 0.5])
axes[2].set_ylim([0, 0.5])
axes[2].legend()
axes[2].set_aspect('equal')
plt.tight_layout()
plt.show()
```
- **任务 3:音频语谱图块嵌入(AST 风格)。** 实现音频语谱图 Transformer 的块提取和嵌入层,可视化语谱图如何被令牌化。
```python
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
# Generate a synthetic spectrogram (harmonic structure + noise)
def generate_spectrogram(key, n_time=128, n_freq=128):
"""Create a synthetic spectrogram with harmonic patterns."""
k1, k2 = jr.split(key)
spec = jr.normal(k1, (n_time, n_freq)) * 0.1
# Add harmonic bands (simulating speech formants)
for f0 in [15, 30, 45, 70]:
width = 3
envelope = jnp.exp(-0.5 * ((jnp.arange(n_freq) - f0) / width) ** 2)
time_mod = 0.5 + 0.5 * jnp.sin(2 * jnp.pi * jnp.arange(n_time) / 40)
spec += jnp.outer(time_mod, envelope)
return jnp.clip(spec, 0, None)
key = jr.PRNGKey(42)
spectrogram = generate_spectrogram(key)
n_time, n_freq = spectrogram.shape
# Patch extraction parameters
patch_h = 16 # time
patch_w = 16 # frequency
stride_h = 16
stride_w = 16
embed_dim = 192 # ViT-Small dimension
n_patches_h = n_time // stride_h
n_patches_w = n_freq // stride_w
n_patches = n_patches_h * n_patches_w
print(f"Spectrogram: {n_time} x {n_freq}")
print(f"Patch size: {patch_h} x {patch_w}")
print(f"Number of patches: {n_patches_h} x {n_patches_w} = {n_patches}")
# Extract patches
def extract_patches(spec, patch_h, patch_w, stride_h, stride_w):
"""Extract non-overlapping patches from spectrogram."""
patches = []
positions = []
for i in range(0, spec.shape[0] - patch_h + 1, stride_h):
for j in range(0, spec.shape[1] - patch_w + 1, stride_w):
patch = spec[i:i+patch_h, j:j+patch_w]
patches.append(patch.flatten())
positions.append((i, j))
return jnp.stack(patches), positions
patches, positions = extract_patches(spectrogram, patch_h, patch_w, stride_h, stride_w)
print(f"Patches shape: {patches.shape}") # (n_patches, patch_h * patch_w)
# Linear projection (patch embedding)
patch_dim = patch_h * patch_w
k1, k2 = jr.split(jr.PRNGKey(0))
W_embed = jr.normal(k1, (patch_dim, embed_dim)) * jnp.sqrt(2.0 / patch_dim)
b_embed = jnp.zeros(embed_dim)
# Learnable positional embeddings
pos_embed = jr.normal(k2, (n_patches + 1, embed_dim)) * 0.02 # +1 for CLS
# CLS token
cls_token = jnp.zeros((1, embed_dim))
# Forward pass
patch_tokens = patches @ W_embed + b_embed # (n_patches, embed_dim)
tokens = jnp.concatenate([cls_token, patch_tokens], axis=0) # (n_patches+1, embed_dim)
tokens = tokens + pos_embed # Add positional embeddings
print(f"Token sequence shape: {tokens.shape}")
print(f"Each token has dimension: {embed_dim}")
# Visualisation
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
# Original spectrogram with patch grid
axes[0, 0].imshow(spectrogram.T, aspect='auto', origin='lower', cmap='magma')
for i in range(0, n_time + 1, stride_h):
axes[0, 0].axvline(i - 0.5, color='white', linewidth=0.5, alpha=0.5)
for j in range(0, n_freq + 1, stride_w):
axes[0, 0].axhline(j - 0.5, color='white', linewidth=0.5, alpha=0.5)
axes[0, 0].set_title(f'Spectrogram with {patch_h}x{patch_w} Patch Grid')
axes[0, 0].set_xlabel('Time frame')
axes[0, 0].set_ylabel('Frequency bin')
# Individual patches visualised
n_show = min(16, n_patches)
patch_grid = patches[:n_show].reshape(n_show, patch_h, patch_w)
combined = jnp.concatenate([patch_grid[i] for i in range(min(8, n_show))], axis=1)
axes[0, 1].imshow(combined.T, aspect='auto', origin='lower', cmap='magma')
axes[0, 1].set_title(f'First {min(8, n_show)} Patches (concatenated)')
axes[0, 1].set_xlabel('Patch index (horizontal)')
axes[0, 1].set_ylabel('Frequency within patch')
# Token embeddings similarity matrix
token_norms = tokens / jnp.linalg.norm(tokens, axis=-1, keepdims=True)
sim = token_norms @ token_norms.T
im = axes[1, 0].imshow(sim, cmap='RdBu_r', vmin=-1, vmax=1)
axes[1, 0].set_title('Token Similarity Matrix (cosine)')
axes[1, 0].set_xlabel('Token index')
axes[1, 0].set_ylabel('Token index')
plt.colorbar(im, ax=axes[1, 0], fraction=0.046)
# Positional embedding similarity
pos_norms = pos_embed / jnp.linalg.norm(pos_embed, axis=-1, keepdims=True)
pos_sim = pos_norms @ pos_norms.T
im2 = axes[1, 1].imshow(pos_sim, cmap='RdBu_r', vmin=-1, vmax=1)
axes[1, 1].set_title('Positional Embedding Similarity')
axes[1, 1].set_xlabel('Position index')
axes[1, 1].set_ylabel('Position index')
plt.colorbar(im2, ax=axes[1, 1], fraction=0.046)
plt.tight_layout()
plt.show()
```
- **任务 4:用于和弦分析的简单色度图计算。** 从合成和声信号计算并可视化色度图,展示音乐信息检索中使用的音高类别折叠方法。
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
# Generate a synthetic musical signal: C major chord -> G major chord
sr = 16000
duration = 2.0
t = jnp.linspace(0, duration, int(sr * duration))
# C major (C4=261.6, E4=329.6, G4=392.0) for first half
# G major (G3=196.0, B3=246.9, D4=293.7) for second half
half = len(t) // 2
c_major = (0.5 * jnp.sin(2 * jnp.pi * 261.63 * t[:half]) +
0.4 * jnp.sin(2 * jnp.pi * 329.63 * t[:half]) +
0.3 * jnp.sin(2 * jnp.pi * 392.00 * t[:half]))
g_major = (0.5 * jnp.sin(2 * jnp.pi * 196.00 * t[:half]) +
0.4 * jnp.sin(2 * jnp.pi * 246.94 * t[:half]) +
0.3 * jnp.sin(2 * jnp.pi * 293.66 * t[:half]))
signal = jnp.concatenate([c_major, g_major])
# Compute STFT
n_fft = 4096 # high resolution for pitch accuracy
hop_length = 512
window = jnp.hanning(n_fft)
def stft(signal, n_fft, hop_length, window):
n_frames = 1 + (len(signal) - n_fft) // hop_length
frames = jnp.stack([
signal[i * hop_length : i * hop_length + n_fft] * window
for i in range(n_frames)
])
return jnp.fft.rfft(frames, n=n_fft)
S = stft(signal, n_fft, hop_length, window)
power_spec = jnp.abs(S) ** 2
freqs = jnp.fft.rfftfreq(n_fft, 1.0 / sr)
# Compute chromagram by mapping frequency bins to pitch classes
# MIDI note number from frequency: 69 + 12 * log2(f / 440)
note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
def freq_to_chroma(freq):
"""Map frequency to pitch class (0-11). Returns -1 for freq <= 0."""
midi = 69 + 12 * jnp.log2(jnp.clip(freq, 1e-10, None) / 440.0)
return jnp.round(midi).astype(int) % 12
# Build chromagram: sum power spectrum energy for each pitch class
chromagram = jnp.zeros((power_spec.shape[0], 12))
valid_freqs = freqs[1:] # skip DC
valid_power = power_spec[:, 1:]
for p in range(12):
# Find frequency bins belonging to this pitch class
chroma_bins = freq_to_chroma(valid_freqs)
mask = (chroma_bins == p).astype(jnp.float32)
chromagram = chromagram.at[:, p].set(
jnp.sum(valid_power * mask[None, :], axis=1)
)
# Normalise each frame
chromagram = chromagram / (jnp.max(chromagram, axis=1, keepdims=True) + 1e-8)
# Visualisation
fig, axes = plt.subplots(3, 1, figsize=(14, 10))
# Waveform
axes[0].plot(t[:3000], signal[:3000], color='#3498db', linewidth=0.5,
label='C major')
axes[0].plot(t[half:half+3000], signal[half:half+3000], color='#e74c3c',
linewidth=0.5, label='G major')
axes[0].set_title('Waveform: C major → G major')
axes[0].set_ylabel('Amplitude')
axes[0].set_xlabel('Time (s)')
axes[0].legend()
# Spectrogram (log scale)
time_axis = jnp.arange(power_spec.shape[0]) * hop_length / sr
axes[1].imshow(jnp.log1p(power_spec[:, :500].T), aspect='auto', origin='lower',
cmap='magma', extent=[0, time_axis[-1], 0, freqs[500]])
axes[1].set_title('Power Spectrogram')
axes[1].set_ylabel('Frequency (Hz)')
axes[1].set_xlabel('Time (s)')
# Chromagram
im = axes[2].imshow(chromagram.T, aspect='auto', origin='lower', cmap='YlOrRd',
extent=[0, time_axis[-1], -0.5, 11.5])
axes[2].set_yticks(range(12))
axes[2].set_yticklabels(note_names)
axes[2].set_title('Chromagram (pitch class energy over time)')
axes[2].set_ylabel('Pitch class')
axes[2].set_xlabel('Time (s)')
plt.colorbar(im, ax=axes[2], fraction=0.046, label='Normalised energy')
# Mark expected active pitch classes
mid_frame = chromagram.shape[0] // 2
print(f"C major region - expected: C, E, G")
print(f" Chroma values: {dict(zip(note_names, [f'{v:.2f}' for v in chromagram[mid_frame//2]]))}")
print(f"G major region - expected: G, B, D")
print(f" Chroma values: {dict(zip(note_names, [f'{v:.2f}' for v in chromagram[mid_frame + mid_frame//2]]))}")
plt.tight_layout()
plt.show()
```