2536c937e3
翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。 第01章 向量 | 第02章 矩阵 | 第03章 微积分 第04章 统计学 | 第05章 概率论 | 第06章 机器学习 第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音 第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络 第13章 计算与操作系统 | 第14章 数据结构与算法 第15章 生产级软件工程 | 第16章 SIMD与GPU编程 第17章 AI推理 | 第18章 ML系统设计 第19章 应用人工智能 | 第20章 前沿人工智能 翻译说明: - 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留 - mkdocs.yml 配置中文导航 + language: zh - README.md 已翻译为中文(兼 docs/index.md) - docs/ 目录包含指向各章文件的 symlink - 约 29,000 行中文内容,排除 .cache/ 构建缓存
644 lines
34 KiB
Markdown
644 lines
34 KiB
Markdown
# 说话人与音频分析
|
||
|
||
*说话人与音频分析识别谁在说话、何时说话以及存在哪些非语言声音。本文涵盖说话人确认与识别、i向量、d向量、x向量、说话人日志、音频事件分类、音乐信息检索以及语音情感识别。*
|
||
|
||
- 在文件 01 中,我们构建了信号处理基础:语谱图、MFCC 和梅尔滤波器组。在文件 02 中,我们识别了所说的内容。现在我们要问:是谁说的、何时说的、以及音频中还在发生什么。说话人识别、说话人日志、音频分类和音乐分析都共享一条主线:学习能够为当前任务捕捉正确不变性的紧凑嵌入,这与第 06 章中的嵌入思想一脉相承。
|
||
|
||
- 可以把说话人识别想象成在电话中辨认朋友的声音。你不需要理解词汇;某种关于音色、语速和嗓音特质的东西对这个人来说是独一无二的。说话人识别系统学会从原始音频中提取这种"声纹",忽略说的是什么,专注于怎么说的。
|
||
|
||
- **说话人识别**是两类相关任务的总称:
|
||
- **说话人确认**(SV):给定一个声明的身份和一段音频片段,判断说话人是否与其声称的身份一致。这是一个二元决策(接受或拒绝),是基于语音的身份验证技术("嘿 Siri,这是我的声音吗?")背后的核心原理。
|
||
- **说话人识别**(SI):给定一段音频片段和一个已知说话人库,判断该片段由哪个说话人产生。这是一个多分类问题。
|
||
|
||

|
||
|
||
- 两种任务共享相同的底层表示:一个固定维度的**说话人嵌入**,它捕捉说话人的身份特征而与所说内容无关。区别仅在于决策阶段:确认比较两个嵌入,识别则在候选嵌入中找到最近邻。
|
||
|
||
- **余弦相似度**是比较说话人嵌入的标准度量。给定注册嵌入 $e$ 和测试嵌入 $t$:
|
||
|
||
$$s = \frac{e \cdot t}{\|e\| \, \|t\|}$$
|
||
|
||
- 阈值 $\theta$ 决定接受/拒绝决策:若 $s > \theta$,则接受。阈值在**错误接受率(FAR)**和**错误拒绝率(FRR)**之间权衡。**等错误率(EER)**,即 FAR = FRR 时的值,是标准评估指标。EER 越低表示性能越好。最先进的系统在标准基准(VoxCeleb)上可实现低于 1% 的 EER。
|
||
|
||
- **i向量**(Dehak 等人,2010)是深度学习之前主导性的说话人嵌入方法。其思想源于因子分析(第 02 章的矩阵分解和第 04 章的降维)。一个**通用背景模型(UBM)**——基于多样本说话人训练的大型 GMM——定义了一个超向量空间。每条语音的 GMM 超向量被投影到低维的**全可变性空间**:
|
||
|
||
$$M = m + Tw$$
|
||
|
||
- 其中 $M$ 是该语音的 GMM 超向量,$m$ 是 UBM 均值超向量,$T$ 是全可变性矩阵(从数据中学习得到),$w$ 是 i 向量,一个低维(通常为 400-600 维)表示,同时捕捉说话人变异和信道变异。
|
||
|
||
- 为了从 i 向量中去除信道变异,**概率线性判别分析(PLDA)**将 i 向量建模为说话人特定潜变量和信道特定潜变量之和。PLDA 为确认任务提供了一个有原则的对数似然比分数:
|
||
|
||
$$\text{score}(w_1, w_2) = \log \frac{P(w_1, w_2 \mid \text{同一说话人})}{P(w_1 \mid \text{说话人}_1) \, P(w_2 \mid \text{说话人}_2)}$$
|
||
|
||
- **d向量**(Variani 等人,2014)是第一个神经说话人嵌入。一个为说话人分类训练的 DNN 处理帧级特征,通过对整条语音中最后一层隐藏层激活值求平均,提取出固定维度的表示。虽然简单但有效,d向量证明了神经网络可以在没有 i 向量复杂统计机制的情况下学习到说话人判别性特征。
|
||
|
||
- **x向量**(Snyder 等人,2018)使用**时延神经网络(TDNN)**架构显著推进了神经说话人嵌入。TDNN 是具有特定上下文窗口的 1D 卷积,与文件 03 中 WaveNet 的扩张卷积有关,但应用于帧级特征而非原始波形样本。
|
||
|
||

|
||
|
||
- x向量架构包含三个阶段:
|
||
- **帧级层**:一组 TDNN 层处理 MFCC(来自文件 01),时间上下文逐步扩大。每一层都有一个固定的上下文窗口(例如第一层为 $\{t-2, t-1, t, t+1, t+2\}$,后续层窗口更宽)。
|
||
- **统计池化**:在帧级层之后,计算帧级输出在整个语音上的均值和标准差,产生一个与语音时长无关的固定维度向量:
|
||
|
||
```math
|
||
\begin{aligned}
|
||
\mu &= \frac{1}{T} \sum_{t=1}^{T} h_t \\
|
||
\sigma &= \sqrt{\frac{1}{T} \sum_{t=1}^{T} (h_t - \mu)^2}
|
||
\end{aligned}
|
||
```
|
||
|
||
- 其中 $h_t$ 是时间 $t$ 的帧级输出。拼接 $[\mu; \sigma]$ 即为池化后的表示。
|
||
- **段级层**:全连接层处理池化后的表示。第一个段级层的输出(softmax 之前)即为 x 向量嵌入。
|
||
|
||
- x向量使用说话人身份上的标准交叉熵损失进行训练。尽管是为分类任务训练的,但学习到的中间表示(x向量)能很好地泛化到未见过的说话人,因为网络学习的是提取说话人判别性特征,而非记忆特定说话人。
|
||
|
||
- **ECAPA-TDNN**(Desplanques 等人,2020)是目前最先进的基于 TDNN 的说话人识别架构。它在 x 向量基础上引入了三项改进:
|
||
- **压缩激励(SE)模块**:通道注意力(来自第 08 章的 SENet),根据全局上下文重新加权特征通道,使模型能够强调与说话人相关的通道。
|
||
- **Res2Net 风格的多尺度特征**:在每个 TDNN 模块内,通道被分成若干组,以层级方式处理,在多个时间分辨率上创建特征(类似于第 08 章的多尺度特征提取)。
|
||
- **注意力统计池化**:不再使用等权平均,而是通过注意力机制为每一帧对池化统计量的贡献分配权重。包含更多说话人判别性内容的帧(如元音,承载更多说话人信息)获得更高的注意力权重:
|
||
|
||
$$\alpha_t = \frac{\exp(v^T f(h_t))}{\sum_{\tau} \exp(v^T f(h_\tau))}$$
|
||
|
||
- 其中 $f$ 是一个小型神经网络,$v$ 是一个学习到的注意力向量。注意力加权的均值和标准差变为 $\tilde{\mu} = \sum_t \alpha_t h_t$ 和 $\tilde{\sigma} = \sqrt{\sum_t \alpha_t (h_t - \tilde{\mu})^2}$。
|
||
|
||
- ECAPA-TDNN 通常使用 **AAM-Softmax**(附加角度间隔 Softmax)进行训练,它在分类损失中添加了角度间隔惩罚,将同一说话人的嵌入推得更近,不同说话人的嵌入在超球面上推得更远:
|
||
|
||
$$L = -\log \frac{e^{s \cos(\theta_{y_i} + m)}}{e^{s \cos(\theta_{y_i} + m)} + \sum_{j \neq y_i} e^{s \cos \theta_j}}$$
|
||
|
||
- 其中 $\theta_{y_i}$ 是嵌入与真实类别权重向量之间的夹角,$m$ 是间隔(通常为 0.2),$s$ 是缩放因子(通常为 30)。该损失函数来自人脸识别(第 08 章的 ArcFace),在说话人确认中非常有效。
|
||
|
||
- **说话人日志**回答了多方录音中"谁在什么时候说话"的问题。可以把这想象成给时间线上色:每种颜色代表一个不同的说话人,系统必须确定每个说话人何时活跃,包括重叠语音的情况。
|
||
|
||

|
||
|
||
- **基于聚类的说话人日志**是传统的流水线方法:
|
||
- **分割**:将音频划分为短段(通常为 1-2 秒),使用滑动窗口或说话人变化检测。
|
||
- **嵌入提取**:为每个片段提取说话人嵌入(x向量、ECAPA-TDNN)。
|
||
- **聚类**:按说话人对片段进行分组。**凝聚层次聚类(AHC)**是标准方法:开始时每个片段自成一类,然后迭代合并两个最相似的类,直到满足停止条件(基于距离阈值或目标说话人数)。
|
||
- **重分割**:使用基于维特比算法的重对齐来优化边界。
|
||
|
||
- 说话人数量通常事先未知,这使得该问题比标准聚类更困难。使用基于特征值阈值确定 $k$ 的谱聚类是另一种常见方法。
|
||
|
||
- **端到端神经说话人日志(EEND)**(Fujita 等人,2019)将说话人日志框架化为一个多标签分类问题。一个神经网络(通常是基于自注意力的模型,第 07 章的 transformer)将整段录音作为输入,为每一帧输出每个说话人的二元活动标签。这直接处理了重叠语音,而这是基于聚类方法的主要弱点。
|
||
|
||
- EEND 对 $S$ 个说话人在帧 $t$ 的输出为:
|
||
|
||
$$\hat{y}_{t,s} = \sigma(f_s(h_t))$$
|
||
|
||
- 其中 $h_t$ 是帧 $t$ 处的 transformer 输出,$f_s$ 是说话人 $s$ 的线性投影。训练损失是在说话人和帧上求和得到的二元交叉熵。一个关键挑战是说话人数量必须固定,或者使用可变输出架构(EEND-EDA 使用带吸引子的编码器-解码器)来处理。
|
||
|
||
- **置换不变训练(PIT)**用于处理说话人日志中的标签歧义问题:由于说话人没有固有顺序,需要对所有可能的说话人到输出分配计算损失,并取最小值(这与文件 05 中源分离使用的 PIT 相同)。
|
||
|
||
- **音频分类**为整段音频片段分配一个标签。与转录语音的 ASR(文件 02)不同,音频分类涵盖更广的范围:环境声音(警笛、雨声、狗吠)、音乐流派(摇滚、爵士、古典)以及一般音频事件。
|
||
|
||
- 标准方法遵循第 08 章的图像分类范式:将音频表示为语谱图(一个二维时间-频率图像),然后应用 CNN 或 transformer 分类器。这种谱图-图像方法利用了计算机视觉几十年来的进展。
|
||
|
||
- **环境声音分类(ESC)**使用 ESC-50(50 类,2000 个片段)和 UrbanSound8K 等数据集。典型架构是应用于对数梅尔语谱图的 CNN(第 06 章)。数据增强至关重要:时间拉伸、音高偏移、添加背景噪声以及 **SpecAugment**(文件 02 的掩码方法应用于语谱图)都能提升泛化能力。
|
||
|
||
- **音频事件检测**(声音事件检测,SED)是分类的时间维度对应任务:不仅仅要知道存在哪些事件,还要知道它们何时开始和结束。**AudioSet**(Gemmeke 等人,2017)是大规模基准,包含 527 个事件类别和超过 200 万个来自 YouTube 的 10 秒片段,每个片段都有弱标注(片段级标签,而非帧级)。
|
||
|
||
- **弱监督 SED** 必须从片段级标签学习帧级预测。标准方法使用 CNN 产生帧级类别概率,然后通过注意力池化聚合成片段级预测:
|
||
|
||
$$\hat{Y}_c = \sigma\left(\sum_t \alpha_{t,c} \cdot f_{t,c}\right)$$
|
||
|
||
- 其中 $f_{t,c}$ 是类别 $c$ 在时间 $t$ 的帧级 logit,$\alpha_{t,c}$ 是注意力权重。片段级预测 $\hat{Y}_c$ 根据片段级标签进行训练。
|
||
|
||
- **声学场景分类(ASC)**对整体环境进行分类:"机场"、"公园"、"地铁站"、"办公室"。这是一个整体性任务:模型必须捕捉一般的声学纹理而非特定事件。DCASE 挑战系列每年对 ASC 进行基准测试,获奖系统通常使用多分辨率语谱图上的 CNN 集成。
|
||
|
||
- **音频嵌入**是从大规模音频数据中学习到的通用表示,类似于可迁移到下游任务的词嵌入(第 07 章)或图像特征(第 08 章)。
|
||
|
||
- **VGGish**(Hershey 等人,2017)将 VGG 图像分类网络(第 08 章)适配到音频领域。它通过一个在 AudioSet 上预训练的类 VGG CNN 处理 0.96 秒的对数梅尔语谱图块,每块产生一个 128 维嵌入。VGGish 嵌入可作为下游任务的通用音频特征,类似于 ImageNet 预训练 CNN 提供视觉特征的方式。
|
||
|
||
- **PANNs**(预训练音频神经网络,Kong 等人,2020)是一系列 CNN 架构(CNN6、CNN10、CNN14),在完整的 AudioSet 上为音频标记任务训练。CNN14 使用最广泛,是一个 14 层 CNN,将对数梅尔语谱图作为输入,使用 $3 \times 3$ 卷积。PANNs 产生 2048 维嵌入,在多种音频任务上实现了最先进的迁移学习性能。
|
||
|
||
- **音频语谱图 Transformer(AST)**(Gong 等人,2021)将视觉 Transformer(ViT,第 08 章)架构直接应用于音频语谱图。语谱图被分割成 $16 \times 16$ 的块(就像 ViT 分割图像一样),每个块被线性投影为令牌嵌入,添加位置嵌入,然后由标准 Transformer 编码器(第 07 章)处理序列。[CLS] 令牌的输出用于分类。
|
||
|
||

|
||
|
||
- AST 受益于 **ImageNet 预训练**:由于语谱图是 2D 图像,AST 从 ImageNet 图像上预训练的 ViT 初始化,然后在音频上微调。这种跨模态迁移出奇地有效,因为两个域共享低级特征(边缘、纹理),并且位置嵌入可以插值以处理不同大小的语谱图。
|
||
|
||
- **HTS-AT**(Chen 等人,2022)使用分层 Swin Transformer 架构(第 08 章的移位窗口注意力)改进了 AST,在降低计算成本的同时通过多尺度特征提取提升了性能。
|
||
|
||
- **BEATs**(Chen 等人,2023)使用了一种音频特定的预训练策略:使用离散标记器进行迭代掩码预测(类似于文件 02 中 wav2vec 2.0 的方法,但应用于通用音频)。标记器逐步细化,创建越来越具有语义意义的离散音频令牌。
|
||
|
||
- **基于嵌入的说话人日志**结合了说话人嵌入与时序建模。像 Pyannote.audio 这样的现代系统使用三阶段流水线:(1) 检测说话人切换和重叠语音的神经分割模型,(2) 应用于每个检测到的片段的嵌入提取阶段(ECAPA-TDNN),以及 (3) 聚类以在整个录音中分配说话人身份。
|
||
|
||
- **音乐信息检索(MIR)**将音频分析应用于音乐。文件 01 中的谱图表示在这里尤其有用,因为音乐具有丰富的和声结构。
|
||
|
||
- **节拍跟踪**检测音乐的节奏脉冲。标准方法从语谱图计算**起始强度包络**(检测表示音符起始的能量增加),然后使用自相关或节拍图谱找到节奏,最后使用动态规划跟踪单个节拍位置,找到最能匹配起始包络同时保持稳定节奏的节拍时间序列。
|
||
|
||
- **和弦识别**识别随时间变化的和声内容。输入通常是**色度图**(也称为音高类别分布图):一个 12 维表示,将所有八度折叠在一起,显示 12 个音高类别(C、C#、D、…、B)中每个类别的能量。CNN 或 RNN(第 06 章)将每个时间帧分类到标准和弦标签之一(C 大调、A 小调、G7 等)。
|
||
|
||
- 色度图通过将每个频率区间映射到其音高类别,从 STFT(文件 01)计算得到:
|
||
|
||
$$\text{chroma}(p) = \sum_{k : \text{pitch}(k) \bmod 12 = p} |X(k)|^2$$
|
||
|
||
- 其中 $p \in \{0, 1, \ldots, 11\}$ 是音高类别,$\text{pitch}(k)$ 将频率区间 $k$ 映射到其 MIDI 音符编号。
|
||
|
||
- **源分离基础**(详见文件 05)将音乐录音分离为单独的乐器(人声、鼓、贝斯、其他)。这是混音、卡拉 OK 和音乐转录等 MIR 应用的核心。像 Demucs(文件 05)这样的模型在标准 MUSDB18 基准上达到了非常好的分离质量。
|
||
|
||
- **音乐标记**为歌曲分配标签(流派、情感、乐器、时代)。它本质上是应用于音乐的音频分类,使用相同的 CNN-语谱图方法。Million Song Dataset 和 MagnaTagATune 是标准基准。
|
||
|
||
- **音频指纹**从短片段中识别特定录音,即使存在噪声、混响或压缩伪影。经典系统是 Shazam,它对星座图(语谱图中的显著峰值)进行哈希处理。神经方法学习对声学退化具有不变性、同时对不同录音保持判别性的鲁棒嵌入,这与第 06 章和第 08 章中的不变特征学习一脉相承。
|
||
|
||
## 编程任务(使用 Colab 或笔记本)
|
||
|
||
- **任务 1:带统计池化的说话人嵌入提取。** 构建一个简单的 x向量风格模型,通过 TDNN 层和统计池化处理帧级特征以产生说话人嵌入。
|
||
|
||
```python
|
||
import jax
|
||
import jax.numpy as jnp
|
||
import jax.random as jr
|
||
import matplotlib.pyplot as plt
|
||
|
||
# Simulate frame-level MFCC features for multiple speakers
|
||
def generate_speaker_data(key, n_speakers=5, utterances_per_speaker=20,
|
||
n_frames=100, n_features=40):
|
||
"""Generate synthetic speaker data with speaker-dependent patterns."""
|
||
keys = jr.split(key, 3)
|
||
all_features = []
|
||
all_labels = []
|
||
|
||
# Each speaker has a characteristic spectral pattern
|
||
speaker_patterns = jr.normal(keys[0], (n_speakers, n_features)) * 0.5
|
||
|
||
for spk in range(n_speakers):
|
||
for utt in range(utterances_per_speaker):
|
||
k = jr.fold_in(keys[1], spk * utterances_per_speaker + utt)
|
||
noise = jr.normal(k, (n_frames, n_features)) * 0.3
|
||
features = speaker_patterns[spk][None, :] + noise
|
||
all_features.append(features)
|
||
all_labels.append(spk)
|
||
|
||
perm = jr.permutation(keys[2], len(all_features))
|
||
features = jnp.stack(all_features)[perm]
|
||
labels = jnp.array(all_labels)[perm]
|
||
return features, labels
|
||
|
||
key = jr.PRNGKey(42)
|
||
features, labels = generate_speaker_data(key)
|
||
n_speakers = 5
|
||
n_features = 40
|
||
|
||
# x-vector-style model
|
||
def init_xvector(key, n_features=40, hidden=128, embed_dim=64, n_speakers=5):
|
||
keys = jr.split(key, 8)
|
||
params = {
|
||
# TDNN layer 1: context [-2, 2]
|
||
'tdnn1_w': jr.normal(keys[0], (5, n_features, hidden)) * jnp.sqrt(2.0 / (5 * n_features)),
|
||
'tdnn1_b': jnp.zeros(hidden),
|
||
# TDNN layer 2: context [-2, 2]
|
||
'tdnn2_w': jr.normal(keys[1], (5, hidden, hidden)) * jnp.sqrt(2.0 / (5 * hidden)),
|
||
'tdnn2_b': jnp.zeros(hidden),
|
||
# TDNN layer 3: context [-3, 3]
|
||
'tdnn3_w': jr.normal(keys[2], (7, hidden, hidden)) * jnp.sqrt(2.0 / (7 * hidden)),
|
||
'tdnn3_b': jnp.zeros(hidden),
|
||
# Segment-level layers (after pooling: 2*hidden -> embed_dim)
|
||
'seg1_w': jr.normal(keys[3], (2 * hidden, embed_dim)) * jnp.sqrt(2.0 / (2 * hidden)),
|
||
'seg1_b': jnp.zeros(embed_dim),
|
||
# Classification head
|
||
'cls_w': jr.normal(keys[4], (embed_dim, n_speakers)) * jnp.sqrt(2.0 / embed_dim),
|
||
'cls_b': jnp.zeros(n_speakers),
|
||
}
|
||
return params
|
||
|
||
def xvector_forward(params, x, return_embedding=False):
|
||
"""x: (batch, frames, features) -> logits or embeddings."""
|
||
# TDNN layers (1D convolutions)
|
||
h = jax.lax.conv_general_dilated(
|
||
x.transpose(0, 2, 1), params['tdnn1_w'].transpose(2, 1, 0),
|
||
window_strides=(1,), padding='SAME'
|
||
).transpose(0, 2, 1) + params['tdnn1_b']
|
||
h = jax.nn.relu(h)
|
||
|
||
h = jax.lax.conv_general_dilated(
|
||
h.transpose(0, 2, 1), params['tdnn2_w'].transpose(2, 1, 0),
|
||
window_strides=(1,), padding='SAME'
|
||
).transpose(0, 2, 1) + params['tdnn2_b']
|
||
h = jax.nn.relu(h)
|
||
|
||
h = jax.lax.conv_general_dilated(
|
||
h.transpose(0, 2, 1), params['tdnn3_w'].transpose(2, 1, 0),
|
||
window_strides=(1,), padding='SAME'
|
||
).transpose(0, 2, 1) + params['tdnn3_b']
|
||
h = jax.nn.relu(h)
|
||
|
||
# Statistics pooling: mean and std over time
|
||
mu = jnp.mean(h, axis=1)
|
||
sigma = jnp.std(h, axis=1)
|
||
pooled = jnp.concatenate([mu, sigma], axis=-1)
|
||
|
||
# Segment-level layer -> embedding
|
||
embedding = jax.nn.relu(pooled @ params['seg1_w'] + params['seg1_b'])
|
||
|
||
if return_embedding:
|
||
return embedding
|
||
|
||
# Classification
|
||
logits = embedding @ params['cls_w'] + params['cls_b']
|
||
return logits
|
||
|
||
def cross_entropy_loss(params, features, labels):
|
||
logits = xvector_forward(params, features)
|
||
one_hot = jax.nn.one_hot(labels, n_speakers)
|
||
log_probs = jax.nn.log_softmax(logits)
|
||
return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
|
||
|
||
grad_fn = jax.jit(jax.value_and_grad(cross_entropy_loss))
|
||
|
||
# Train
|
||
params = init_xvector(jr.PRNGKey(0))
|
||
lr = 1e-3
|
||
losses = []
|
||
|
||
for epoch in range(300):
|
||
loss_val, grads = grad_fn(params, features, labels)
|
||
params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
|
||
losses.append(float(loss_val))
|
||
|
||
# Extract embeddings and visualise with t-SNE-style 2D projection (using PCA)
|
||
embeddings = xvector_forward(params, features, return_embedding=True)
|
||
|
||
# Simple PCA to 2D
|
||
emb_centered = embeddings - jnp.mean(embeddings, axis=0)
|
||
_, _, Vt = jnp.linalg.svd(emb_centered, full_matrices=False)
|
||
proj_2d = emb_centered @ Vt[:2].T
|
||
|
||
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
||
|
||
axes[0].plot(losses, color='#3498db', linewidth=1.5)
|
||
axes[0].set_xlabel('Epoch')
|
||
axes[0].set_ylabel('Cross-Entropy Loss')
|
||
axes[0].set_title('Speaker Classification Training')
|
||
axes[0].set_yscale('log')
|
||
|
||
colors = ['#3498db', '#e74c3c', '#27ae60', '#f39c12', '#9b59b6']
|
||
for spk in range(n_speakers):
|
||
mask = labels == spk
|
||
axes[1].scatter(proj_2d[mask, 0], proj_2d[mask, 1], c=colors[spk],
|
||
label=f'Speaker {spk}', alpha=0.7, s=30)
|
||
axes[1].set_xlabel('PC 1')
|
||
axes[1].set_ylabel('PC 2')
|
||
axes[1].set_title('Speaker Embeddings (PCA projection)')
|
||
axes[1].legend()
|
||
|
||
plt.tight_layout()
|
||
plt.show()
|
||
|
||
# Verification demo: cosine similarity
|
||
emb_norm = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
|
||
sim_matrix = emb_norm @ emb_norm.T
|
||
print(f"Embedding shape: {embeddings.shape}")
|
||
print(f"Avg same-speaker similarity: {jnp.mean(sim_matrix[labels[:, None] == labels[None, :]]):.4f}")
|
||
print(f"Avg diff-speaker similarity: {jnp.mean(sim_matrix[labels[:, None] != labels[None, :]]):.4f}")
|
||
```
|
||
|
||
- **任务 2:基于余弦相似度评分的说话人确认。** 给定预计算的说话人嵌入,实现一个计算 EER(等错误率)并绘制 DET 曲线的确认系统。
|
||
|
||
```python
|
||
import jax
|
||
import jax.numpy as jnp
|
||
import jax.random as jr
|
||
import matplotlib.pyplot as plt
|
||
|
||
def generate_verification_pairs(key, n_speakers=20, dim=64, n_pairs=2000):
|
||
"""Generate speaker embeddings and verification trial pairs."""
|
||
keys = jr.split(key, 5)
|
||
|
||
# Speaker centroids with some variance
|
||
centroids = jr.normal(keys[0], (n_speakers, dim))
|
||
centroids = centroids / jnp.linalg.norm(centroids, axis=-1, keepdims=True)
|
||
|
||
# Generate enrollment and test embeddings with intra-speaker variance
|
||
enroll_embs = []
|
||
test_embs = []
|
||
trial_labels = [] # 1 = same speaker (target), 0 = different (impostor)
|
||
|
||
for i in range(n_pairs):
|
||
k1, k2, k3 = jr.split(jr.fold_in(keys[1], i), 3)
|
||
is_target = jr.bernoulli(k1).astype(int)
|
||
|
||
spk1 = jr.randint(k2, (), 0, n_speakers)
|
||
emb1 = centroids[spk1] + jr.normal(jr.fold_in(k3, 0), (dim,)) * 0.15
|
||
|
||
if is_target:
|
||
spk2 = spk1
|
||
else:
|
||
spk2 = (spk1 + jr.randint(jr.fold_in(k3, 1), (), 1, n_speakers)) % n_speakers
|
||
|
||
emb2 = centroids[spk2] + jr.normal(jr.fold_in(k3, 2), (dim,)) * 0.15
|
||
|
||
enroll_embs.append(emb1)
|
||
test_embs.append(emb2)
|
||
trial_labels.append(int(is_target))
|
||
|
||
return (jnp.stack(enroll_embs), jnp.stack(test_embs),
|
||
jnp.array(trial_labels))
|
||
|
||
key = jr.PRNGKey(42)
|
||
enroll, test, labels = generate_verification_pairs(key)
|
||
|
||
# Compute cosine similarity scores
|
||
enroll_norm = enroll / jnp.linalg.norm(enroll, axis=-1, keepdims=True)
|
||
test_norm = test / jnp.linalg.norm(test, axis=-1, keepdims=True)
|
||
scores = jnp.sum(enroll_norm * test_norm, axis=-1)
|
||
|
||
# Compute FAR and FRR at various thresholds
|
||
thresholds = jnp.linspace(-1.0, 1.0, 500)
|
||
|
||
target_scores = scores[labels == 1]
|
||
impostor_scores = scores[labels == 0]
|
||
|
||
fars = []
|
||
frrs = []
|
||
for thresh in thresholds:
|
||
far = jnp.mean(impostor_scores >= thresh) # false accepts
|
||
frr = jnp.mean(target_scores < thresh) # false rejects
|
||
fars.append(float(far))
|
||
frrs.append(float(frr))
|
||
|
||
fars = jnp.array(fars)
|
||
frrs = jnp.array(frrs)
|
||
|
||
# Find EER: where FAR ≈ FRR
|
||
eer_idx = jnp.argmin(jnp.abs(fars - frrs))
|
||
eer = float((fars[eer_idx] + frrs[eer_idx]) / 2)
|
||
eer_threshold = float(thresholds[eer_idx])
|
||
|
||
print(f"Equal Error Rate (EER): {eer:.4f} ({eer*100:.2f}%)")
|
||
print(f"EER threshold: {eer_threshold:.4f}")
|
||
|
||
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
|
||
|
||
# Score distributions
|
||
bins = jnp.linspace(-0.5, 1.0, 60)
|
||
axes[0].hist(target_scores, bins=bins, alpha=0.6, color='#27ae60',
|
||
label='Target (same speaker)', density=True)
|
||
axes[0].hist(impostor_scores, bins=bins, alpha=0.6, color='#e74c3c',
|
||
label='Impostor (different speaker)', density=True)
|
||
axes[0].axvline(eer_threshold, color='#f39c12', linestyle='--', linewidth=2,
|
||
label=f'EER threshold = {eer_threshold:.3f}')
|
||
axes[0].set_xlabel('Cosine Similarity Score')
|
||
axes[0].set_ylabel('Density')
|
||
axes[0].set_title('Score Distributions')
|
||
axes[0].legend()
|
||
|
||
# FAR vs FRR
|
||
axes[1].plot(thresholds, fars, color='#e74c3c', linewidth=2, label='FAR')
|
||
axes[1].plot(thresholds, frrs, color='#3498db', linewidth=2, label='FRR')
|
||
axes[1].axvline(eer_threshold, color='#f39c12', linestyle='--', linewidth=1.5)
|
||
axes[1].scatter([eer_threshold], [eer], color='#f39c12', s=100, zorder=5,
|
||
label=f'EER = {eer:.4f}')
|
||
axes[1].set_xlabel('Threshold')
|
||
axes[1].set_ylabel('Error Rate')
|
||
axes[1].set_title('FAR and FRR vs Threshold')
|
||
axes[1].legend()
|
||
|
||
# DET curve (FAR vs FRR)
|
||
axes[2].plot(fars, frrs, color='#9b59b6', linewidth=2)
|
||
axes[2].plot([0, 1], [0, 1], 'k--', alpha=0.3)
|
||
axes[2].scatter([eer], [eer], color='#f39c12', s=100, zorder=5,
|
||
label=f'EER = {eer:.4f}')
|
||
axes[2].set_xlabel('False Acceptance Rate')
|
||
axes[2].set_ylabel('False Rejection Rate')
|
||
axes[2].set_title('DET Curve')
|
||
axes[2].set_xlim([0, 0.5])
|
||
axes[2].set_ylim([0, 0.5])
|
||
axes[2].legend()
|
||
axes[2].set_aspect('equal')
|
||
|
||
plt.tight_layout()
|
||
plt.show()
|
||
```
|
||
|
||
- **任务 3:音频语谱图块嵌入(AST 风格)。** 实现音频语谱图 Transformer 的块提取和嵌入层,可视化语谱图如何被令牌化。
|
||
|
||
```python
|
||
import jax
|
||
import jax.numpy as jnp
|
||
import jax.random as jr
|
||
import matplotlib.pyplot as plt
|
||
|
||
# Generate a synthetic spectrogram (harmonic structure + noise)
|
||
def generate_spectrogram(key, n_time=128, n_freq=128):
|
||
"""Create a synthetic spectrogram with harmonic patterns."""
|
||
k1, k2 = jr.split(key)
|
||
spec = jr.normal(k1, (n_time, n_freq)) * 0.1
|
||
|
||
# Add harmonic bands (simulating speech formants)
|
||
for f0 in [15, 30, 45, 70]:
|
||
width = 3
|
||
envelope = jnp.exp(-0.5 * ((jnp.arange(n_freq) - f0) / width) ** 2)
|
||
time_mod = 0.5 + 0.5 * jnp.sin(2 * jnp.pi * jnp.arange(n_time) / 40)
|
||
spec += jnp.outer(time_mod, envelope)
|
||
|
||
return jnp.clip(spec, 0, None)
|
||
|
||
key = jr.PRNGKey(42)
|
||
spectrogram = generate_spectrogram(key)
|
||
n_time, n_freq = spectrogram.shape
|
||
|
||
# Patch extraction parameters
|
||
patch_h = 16 # time
|
||
patch_w = 16 # frequency
|
||
stride_h = 16
|
||
stride_w = 16
|
||
embed_dim = 192 # ViT-Small dimension
|
||
|
||
n_patches_h = n_time // stride_h
|
||
n_patches_w = n_freq // stride_w
|
||
n_patches = n_patches_h * n_patches_w
|
||
|
||
print(f"Spectrogram: {n_time} x {n_freq}")
|
||
print(f"Patch size: {patch_h} x {patch_w}")
|
||
print(f"Number of patches: {n_patches_h} x {n_patches_w} = {n_patches}")
|
||
|
||
# Extract patches
|
||
def extract_patches(spec, patch_h, patch_w, stride_h, stride_w):
|
||
"""Extract non-overlapping patches from spectrogram."""
|
||
patches = []
|
||
positions = []
|
||
for i in range(0, spec.shape[0] - patch_h + 1, stride_h):
|
||
for j in range(0, spec.shape[1] - patch_w + 1, stride_w):
|
||
patch = spec[i:i+patch_h, j:j+patch_w]
|
||
patches.append(patch.flatten())
|
||
positions.append((i, j))
|
||
return jnp.stack(patches), positions
|
||
|
||
patches, positions = extract_patches(spectrogram, patch_h, patch_w, stride_h, stride_w)
|
||
print(f"Patches shape: {patches.shape}") # (n_patches, patch_h * patch_w)
|
||
|
||
# Linear projection (patch embedding)
|
||
patch_dim = patch_h * patch_w
|
||
k1, k2 = jr.split(jr.PRNGKey(0))
|
||
W_embed = jr.normal(k1, (patch_dim, embed_dim)) * jnp.sqrt(2.0 / patch_dim)
|
||
b_embed = jnp.zeros(embed_dim)
|
||
|
||
# Learnable positional embeddings
|
||
pos_embed = jr.normal(k2, (n_patches + 1, embed_dim)) * 0.02 # +1 for CLS
|
||
|
||
# CLS token
|
||
cls_token = jnp.zeros((1, embed_dim))
|
||
|
||
# Forward pass
|
||
patch_tokens = patches @ W_embed + b_embed # (n_patches, embed_dim)
|
||
tokens = jnp.concatenate([cls_token, patch_tokens], axis=0) # (n_patches+1, embed_dim)
|
||
tokens = tokens + pos_embed # Add positional embeddings
|
||
|
||
print(f"Token sequence shape: {tokens.shape}")
|
||
print(f"Each token has dimension: {embed_dim}")
|
||
|
||
# Visualisation
|
||
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
|
||
|
||
# Original spectrogram with patch grid
|
||
axes[0, 0].imshow(spectrogram.T, aspect='auto', origin='lower', cmap='magma')
|
||
for i in range(0, n_time + 1, stride_h):
|
||
axes[0, 0].axvline(i - 0.5, color='white', linewidth=0.5, alpha=0.5)
|
||
for j in range(0, n_freq + 1, stride_w):
|
||
axes[0, 0].axhline(j - 0.5, color='white', linewidth=0.5, alpha=0.5)
|
||
axes[0, 0].set_title(f'Spectrogram with {patch_h}x{patch_w} Patch Grid')
|
||
axes[0, 0].set_xlabel('Time frame')
|
||
axes[0, 0].set_ylabel('Frequency bin')
|
||
|
||
# Individual patches visualised
|
||
n_show = min(16, n_patches)
|
||
patch_grid = patches[:n_show].reshape(n_show, patch_h, patch_w)
|
||
combined = jnp.concatenate([patch_grid[i] for i in range(min(8, n_show))], axis=1)
|
||
axes[0, 1].imshow(combined.T, aspect='auto', origin='lower', cmap='magma')
|
||
axes[0, 1].set_title(f'First {min(8, n_show)} Patches (concatenated)')
|
||
axes[0, 1].set_xlabel('Patch index (horizontal)')
|
||
axes[0, 1].set_ylabel('Frequency within patch')
|
||
|
||
# Token embeddings similarity matrix
|
||
token_norms = tokens / jnp.linalg.norm(tokens, axis=-1, keepdims=True)
|
||
sim = token_norms @ token_norms.T
|
||
im = axes[1, 0].imshow(sim, cmap='RdBu_r', vmin=-1, vmax=1)
|
||
axes[1, 0].set_title('Token Similarity Matrix (cosine)')
|
||
axes[1, 0].set_xlabel('Token index')
|
||
axes[1, 0].set_ylabel('Token index')
|
||
plt.colorbar(im, ax=axes[1, 0], fraction=0.046)
|
||
|
||
# Positional embedding similarity
|
||
pos_norms = pos_embed / jnp.linalg.norm(pos_embed, axis=-1, keepdims=True)
|
||
pos_sim = pos_norms @ pos_norms.T
|
||
im2 = axes[1, 1].imshow(pos_sim, cmap='RdBu_r', vmin=-1, vmax=1)
|
||
axes[1, 1].set_title('Positional Embedding Similarity')
|
||
axes[1, 1].set_xlabel('Position index')
|
||
axes[1, 1].set_ylabel('Position index')
|
||
plt.colorbar(im2, ax=axes[1, 1], fraction=0.046)
|
||
|
||
plt.tight_layout()
|
||
plt.show()
|
||
```
|
||
|
||
- **任务 4:用于和弦分析的简单色度图计算。** 从合成和声信号计算并可视化色度图,展示音乐信息检索中使用的音高类别折叠方法。
|
||
|
||
```python
|
||
import jax
|
||
import jax.numpy as jnp
|
||
import matplotlib.pyplot as plt
|
||
|
||
# Generate a synthetic musical signal: C major chord -> G major chord
|
||
sr = 16000
|
||
duration = 2.0
|
||
t = jnp.linspace(0, duration, int(sr * duration))
|
||
|
||
# C major (C4=261.6, E4=329.6, G4=392.0) for first half
|
||
# G major (G3=196.0, B3=246.9, D4=293.7) for second half
|
||
half = len(t) // 2
|
||
|
||
c_major = (0.5 * jnp.sin(2 * jnp.pi * 261.63 * t[:half]) +
|
||
0.4 * jnp.sin(2 * jnp.pi * 329.63 * t[:half]) +
|
||
0.3 * jnp.sin(2 * jnp.pi * 392.00 * t[:half]))
|
||
|
||
g_major = (0.5 * jnp.sin(2 * jnp.pi * 196.00 * t[:half]) +
|
||
0.4 * jnp.sin(2 * jnp.pi * 246.94 * t[:half]) +
|
||
0.3 * jnp.sin(2 * jnp.pi * 293.66 * t[:half]))
|
||
|
||
signal = jnp.concatenate([c_major, g_major])
|
||
|
||
# Compute STFT
|
||
n_fft = 4096 # high resolution for pitch accuracy
|
||
hop_length = 512
|
||
window = jnp.hanning(n_fft)
|
||
|
||
def stft(signal, n_fft, hop_length, window):
|
||
n_frames = 1 + (len(signal) - n_fft) // hop_length
|
||
frames = jnp.stack([
|
||
signal[i * hop_length : i * hop_length + n_fft] * window
|
||
for i in range(n_frames)
|
||
])
|
||
return jnp.fft.rfft(frames, n=n_fft)
|
||
|
||
S = stft(signal, n_fft, hop_length, window)
|
||
power_spec = jnp.abs(S) ** 2
|
||
freqs = jnp.fft.rfftfreq(n_fft, 1.0 / sr)
|
||
|
||
# Compute chromagram by mapping frequency bins to pitch classes
|
||
# MIDI note number from frequency: 69 + 12 * log2(f / 440)
|
||
note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
|
||
|
||
def freq_to_chroma(freq):
|
||
"""Map frequency to pitch class (0-11). Returns -1 for freq <= 0."""
|
||
midi = 69 + 12 * jnp.log2(jnp.clip(freq, 1e-10, None) / 440.0)
|
||
return jnp.round(midi).astype(int) % 12
|
||
|
||
# Build chromagram: sum power spectrum energy for each pitch class
|
||
chromagram = jnp.zeros((power_spec.shape[0], 12))
|
||
valid_freqs = freqs[1:] # skip DC
|
||
valid_power = power_spec[:, 1:]
|
||
|
||
for p in range(12):
|
||
# Find frequency bins belonging to this pitch class
|
||
chroma_bins = freq_to_chroma(valid_freqs)
|
||
mask = (chroma_bins == p).astype(jnp.float32)
|
||
chromagram = chromagram.at[:, p].set(
|
||
jnp.sum(valid_power * mask[None, :], axis=1)
|
||
)
|
||
|
||
# Normalise each frame
|
||
chromagram = chromagram / (jnp.max(chromagram, axis=1, keepdims=True) + 1e-8)
|
||
|
||
# Visualisation
|
||
fig, axes = plt.subplots(3, 1, figsize=(14, 10))
|
||
|
||
# Waveform
|
||
axes[0].plot(t[:3000], signal[:3000], color='#3498db', linewidth=0.5,
|
||
label='C major')
|
||
axes[0].plot(t[half:half+3000], signal[half:half+3000], color='#e74c3c',
|
||
linewidth=0.5, label='G major')
|
||
axes[0].set_title('Waveform: C major → G major')
|
||
axes[0].set_ylabel('Amplitude')
|
||
axes[0].set_xlabel('Time (s)')
|
||
axes[0].legend()
|
||
|
||
# Spectrogram (log scale)
|
||
time_axis = jnp.arange(power_spec.shape[0]) * hop_length / sr
|
||
axes[1].imshow(jnp.log1p(power_spec[:, :500].T), aspect='auto', origin='lower',
|
||
cmap='magma', extent=[0, time_axis[-1], 0, freqs[500]])
|
||
axes[1].set_title('Power Spectrogram')
|
||
axes[1].set_ylabel('Frequency (Hz)')
|
||
axes[1].set_xlabel('Time (s)')
|
||
|
||
# Chromagram
|
||
im = axes[2].imshow(chromagram.T, aspect='auto', origin='lower', cmap='YlOrRd',
|
||
extent=[0, time_axis[-1], -0.5, 11.5])
|
||
axes[2].set_yticks(range(12))
|
||
axes[2].set_yticklabels(note_names)
|
||
axes[2].set_title('Chromagram (pitch class energy over time)')
|
||
axes[2].set_ylabel('Pitch class')
|
||
axes[2].set_xlabel('Time (s)')
|
||
plt.colorbar(im, ax=axes[2], fraction=0.046, label='Normalised energy')
|
||
|
||
# Mark expected active pitch classes
|
||
mid_frame = chromagram.shape[0] // 2
|
||
print(f"C major region - expected: C, E, G")
|
||
print(f" Chroma values: {dict(zip(note_names, [f'{v:.2f}' for v in chromagram[mid_frame//2]]))}")
|
||
print(f"G major region - expected: G, B, D")
|
||
print(f" Chroma values: {dict(zip(note_names, [f'{v:.2f}' for v in chromagram[mid_frame + mid_frame//2]]))}")
|
||
|
||
plt.tight_layout()
|
||
plt.show()
|
||
```
|