# 说话人与音频分析 *说话人与音频分析识别谁在说话、何时说话以及存在哪些非语言声音。本文涵盖说话人确认与识别、i向量、d向量、x向量、说话人日志、音频事件分类、音乐信息检索以及语音情感识别。* - 在文件 01 中,我们构建了信号处理基础:语谱图、MFCC 和梅尔滤波器组。在文件 02 中,我们识别了所说的内容。现在我们要问:是谁说的、何时说的、以及音频中还在发生什么。说话人识别、说话人日志、音频分类和音乐分析都共享一条主线:学习能够为当前任务捕捉正确不变性的紧凑嵌入,这与第 06 章中的嵌入思想一脉相承。 - 可以把说话人识别想象成在电话中辨认朋友的声音。你不需要理解词汇;某种关于音色、语速和嗓音特质的东西对这个人来说是独一无二的。说话人识别系统学会从原始音频中提取这种"声纹",忽略说的是什么,专注于怎么说的。 - **说话人识别**是两类相关任务的总称: - **说话人确认**(SV):给定一个声明的身份和一段音频片段,判断说话人是否与其声称的身份一致。这是一个二元决策(接受或拒绝),是基于语音的身份验证技术("嘿 Siri,这是我的声音吗?")背后的核心原理。 - **说话人识别**(SI):给定一段音频片段和一个已知说话人库,判断该片段由哪个说话人产生。这是一个多分类问题。 ![说话人确认:注册音频被嵌入,测试音频被嵌入,计算嵌入之间的余弦相似度,通过阈值决定接受或拒绝](../images/speaker_verification.svg) - 两种任务共享相同的底层表示:一个固定维度的**说话人嵌入**,它捕捉说话人的身份特征而与所说内容无关。区别仅在于决策阶段:确认比较两个嵌入,识别则在候选嵌入中找到最近邻。 - **余弦相似度**是比较说话人嵌入的标准度量。给定注册嵌入 $e$ 和测试嵌入 $t$: $$s = \frac{e \cdot t}{\|e\| \, \|t\|}$$ - 阈值 $\theta$ 决定接受/拒绝决策:若 $s > \theta$,则接受。阈值在**错误接受率(FAR)**和**错误拒绝率(FRR)**之间权衡。**等错误率(EER)**,即 FAR = FRR 时的值,是标准评估指标。EER 越低表示性能越好。最先进的系统在标准基准(VoxCeleb)上可实现低于 1% 的 EER。 - **i向量**(Dehak 等人,2010)是深度学习之前主导性的说话人嵌入方法。其思想源于因子分析(第 02 章的矩阵分解和第 04 章的降维)。一个**通用背景模型(UBM)**——基于多样本说话人训练的大型 GMM——定义了一个超向量空间。每条语音的 GMM 超向量被投影到低维的**全可变性空间**: $$M = m + Tw$$ - 其中 $M$ 是该语音的 GMM 超向量,$m$ 是 UBM 均值超向量,$T$ 是全可变性矩阵(从数据中学习得到),$w$ 是 i 向量,一个低维(通常为 400-600 维)表示,同时捕捉说话人变异和信道变异。 - 为了从 i 向量中去除信道变异,**概率线性判别分析(PLDA)**将 i 向量建模为说话人特定潜变量和信道特定潜变量之和。PLDA 为确认任务提供了一个有原则的对数似然比分数: $$\text{score}(w_1, w_2) = \log \frac{P(w_1, w_2 \mid \text{同一说话人})}{P(w_1 \mid \text{说话人}_1) \, P(w_2 \mid \text{说话人}_2)}$$ - **d向量**(Variani 等人,2014)是第一个神经说话人嵌入。一个为说话人分类训练的 DNN 处理帧级特征,通过对整条语音中最后一层隐藏层激活值求平均,提取出固定维度的表示。虽然简单但有效,d向量证明了神经网络可以在没有 i 向量复杂统计机制的情况下学习到说话人判别性特征。 - **x向量**(Snyder 等人,2018)使用**时延神经网络(TDNN)**架构显著推进了神经说话人嵌入。TDNN 是具有特定上下文窗口的 1D 卷积,与文件 03 中 WaveNet 的扩张卷积有关,但应用于帧级特征而非原始波形样本。 ![x向量架构:TDNN 层以递增的上下文处理帧级特征,统计池化在时间维度上聚合,全连接层产生说话人嵌入](../images/xvector_architecture.svg) - x向量架构包含三个阶段: - **帧级层**:一组 TDNN 层处理 MFCC(来自文件 01),时间上下文逐步扩大。每一层都有一个固定的上下文窗口(例如第一层为 $\{t-2, t-1, t, t+1, t+2\}$,后续层窗口更宽)。 - **统计池化**:在帧级层之后,计算帧级输出在整个语音上的均值和标准差,产生一个与语音时长无关的固定维度向量: ```math \begin{aligned} \mu &= \frac{1}{T} \sum_{t=1}^{T} h_t \\ \sigma &= \sqrt{\frac{1}{T} \sum_{t=1}^{T} (h_t - \mu)^2} \end{aligned} ``` - 其中 $h_t$ 是时间 $t$ 的帧级输出。拼接 $[\mu; \sigma]$ 即为池化后的表示。 - **段级层**:全连接层处理池化后的表示。第一个段级层的输出(softmax 之前)即为 x 向量嵌入。 - x向量使用说话人身份上的标准交叉熵损失进行训练。尽管是为分类任务训练的,但学习到的中间表示(x向量)能很好地泛化到未见过的说话人,因为网络学习的是提取说话人判别性特征,而非记忆特定说话人。 - **ECAPA-TDNN**(Desplanques 等人,2020)是目前最先进的基于 TDNN 的说话人识别架构。它在 x 向量基础上引入了三项改进: - **压缩激励(SE)模块**:通道注意力(来自第 08 章的 SENet),根据全局上下文重新加权特征通道,使模型能够强调与说话人相关的通道。 - **Res2Net 风格的多尺度特征**:在每个 TDNN 模块内,通道被分成若干组,以层级方式处理,在多个时间分辨率上创建特征(类似于第 08 章的多尺度特征提取)。 - **注意力统计池化**:不再使用等权平均,而是通过注意力机制为每一帧对池化统计量的贡献分配权重。包含更多说话人判别性内容的帧(如元音,承载更多说话人信息)获得更高的注意力权重: $$\alpha_t = \frac{\exp(v^T f(h_t))}{\sum_{\tau} \exp(v^T f(h_\tau))}$$ - 其中 $f$ 是一个小型神经网络,$v$ 是一个学习到的注意力向量。注意力加权的均值和标准差变为 $\tilde{\mu} = \sum_t \alpha_t h_t$ 和 $\tilde{\sigma} = \sqrt{\sum_t \alpha_t (h_t - \tilde{\mu})^2}$。 - ECAPA-TDNN 通常使用 **AAM-Softmax**(附加角度间隔 Softmax)进行训练,它在分类损失中添加了角度间隔惩罚,将同一说话人的嵌入推得更近,不同说话人的嵌入在超球面上推得更远: $$L = -\log \frac{e^{s \cos(\theta_{y_i} + m)}}{e^{s \cos(\theta_{y_i} + m)} + \sum_{j \neq y_i} e^{s \cos \theta_j}}$$ - 其中 $\theta_{y_i}$ 是嵌入与真实类别权重向量之间的夹角,$m$ 是间隔(通常为 0.2),$s$ 是缩放因子(通常为 30)。该损失函数来自人脸识别(第 08 章的 ArcFace),在说话人确认中非常有效。 - **说话人日志**回答了多方录音中"谁在什么时候说话"的问题。可以把这想象成给时间线上色:每种颜色代表一个不同的说话人,系统必须确定每个说话人何时活跃,包括重叠语音的情况。 ![说话人日志:音频时间线被分割并用说话人身份标注,展示交替说话和重叠区域](../images/speaker_diarisation.svg) - **基于聚类的说话人日志**是传统的流水线方法: - **分割**:将音频划分为短段(通常为 1-2 秒),使用滑动窗口或说话人变化检测。 - **嵌入提取**:为每个片段提取说话人嵌入(x向量、ECAPA-TDNN)。 - **聚类**:按说话人对片段进行分组。**凝聚层次聚类(AHC)**是标准方法:开始时每个片段自成一类,然后迭代合并两个最相似的类,直到满足停止条件(基于距离阈值或目标说话人数)。 - **重分割**:使用基于维特比算法的重对齐来优化边界。 - 说话人数量通常事先未知,这使得该问题比标准聚类更困难。使用基于特征值阈值确定 $k$ 的谱聚类是另一种常见方法。 - **端到端神经说话人日志(EEND)**(Fujita 等人,2019)将说话人日志框架化为一个多标签分类问题。一个神经网络(通常是基于自注意力的模型,第 07 章的 transformer)将整段录音作为输入,为每一帧输出每个说话人的二元活动标签。这直接处理了重叠语音,而这是基于聚类方法的主要弱点。 - EEND 对 $S$ 个说话人在帧 $t$ 的输出为: $$\hat{y}_{t,s} = \sigma(f_s(h_t))$$ - 其中 $h_t$ 是帧 $t$ 处的 transformer 输出,$f_s$ 是说话人 $s$ 的线性投影。训练损失是在说话人和帧上求和得到的二元交叉熵。一个关键挑战是说话人数量必须固定,或者使用可变输出架构(EEND-EDA 使用带吸引子的编码器-解码器)来处理。 - **置换不变训练(PIT)**用于处理说话人日志中的标签歧义问题:由于说话人没有固有顺序,需要对所有可能的说话人到输出分配计算损失,并取最小值(这与文件 05 中源分离使用的 PIT 相同)。 - **音频分类**为整段音频片段分配一个标签。与转录语音的 ASR(文件 02)不同,音频分类涵盖更广的范围:环境声音(警笛、雨声、狗吠)、音乐流派(摇滚、爵士、古典)以及一般音频事件。 - 标准方法遵循第 08 章的图像分类范式:将音频表示为语谱图(一个二维时间-频率图像),然后应用 CNN 或 transformer 分类器。这种谱图-图像方法利用了计算机视觉几十年来的进展。 - **环境声音分类(ESC)**使用 ESC-50(50 类,2000 个片段)和 UrbanSound8K 等数据集。典型架构是应用于对数梅尔语谱图的 CNN(第 06 章)。数据增强至关重要:时间拉伸、音高偏移、添加背景噪声以及 **SpecAugment**(文件 02 的掩码方法应用于语谱图)都能提升泛化能力。 - **音频事件检测**(声音事件检测,SED)是分类的时间维度对应任务:不仅仅要知道存在哪些事件,还要知道它们何时开始和结束。**AudioSet**(Gemmeke 等人,2017)是大规模基准,包含 527 个事件类别和超过 200 万个来自 YouTube 的 10 秒片段,每个片段都有弱标注(片段级标签,而非帧级)。 - **弱监督 SED** 必须从片段级标签学习帧级预测。标准方法使用 CNN 产生帧级类别概率,然后通过注意力池化聚合成片段级预测: $$\hat{Y}_c = \sigma\left(\sum_t \alpha_{t,c} \cdot f_{t,c}\right)$$ - 其中 $f_{t,c}$ 是类别 $c$ 在时间 $t$ 的帧级 logit,$\alpha_{t,c}$ 是注意力权重。片段级预测 $\hat{Y}_c$ 根据片段级标签进行训练。 - **声学场景分类(ASC)**对整体环境进行分类:"机场"、"公园"、"地铁站"、"办公室"。这是一个整体性任务:模型必须捕捉一般的声学纹理而非特定事件。DCASE 挑战系列每年对 ASC 进行基准测试,获奖系统通常使用多分辨率语谱图上的 CNN 集成。 - **音频嵌入**是从大规模音频数据中学习到的通用表示,类似于可迁移到下游任务的词嵌入(第 07 章)或图像特征(第 08 章)。 - **VGGish**(Hershey 等人,2017)将 VGG 图像分类网络(第 08 章)适配到音频领域。它通过一个在 AudioSet 上预训练的类 VGG CNN 处理 0.96 秒的对数梅尔语谱图块,每块产生一个 128 维嵌入。VGGish 嵌入可作为下游任务的通用音频特征,类似于 ImageNet 预训练 CNN 提供视觉特征的方式。 - **PANNs**(预训练音频神经网络,Kong 等人,2020)是一系列 CNN 架构(CNN6、CNN10、CNN14),在完整的 AudioSet 上为音频标记任务训练。CNN14 使用最广泛,是一个 14 层 CNN,将对数梅尔语谱图作为输入,使用 $3 \times 3$ 卷积。PANNs 产生 2048 维嵌入,在多种音频任务上实现了最先进的迁移学习性能。 - **音频语谱图 Transformer(AST)**(Gong 等人,2021)将视觉 Transformer(ViT,第 08 章)架构直接应用于音频语谱图。语谱图被分割成 $16 \times 16$ 的块(就像 ViT 分割图像一样),每个块被线性投影为令牌嵌入,添加位置嵌入,然后由标准 Transformer 编码器(第 07 章)处理序列。[CLS] 令牌的输出用于分类。 ![音频语谱图 Transformer:梅尔语谱图被分割成块,每个块展平并线性投影为令牌,添加位置嵌入,Transformer 编码器通过 CLS 令牌产生分类输出](../images/audio_spectrogram_transformer.svg) - AST 受益于 **ImageNet 预训练**:由于语谱图是 2D 图像,AST 从 ImageNet 图像上预训练的 ViT 初始化,然后在音频上微调。这种跨模态迁移出奇地有效,因为两个域共享低级特征(边缘、纹理),并且位置嵌入可以插值以处理不同大小的语谱图。 - **HTS-AT**(Chen 等人,2022)使用分层 Swin Transformer 架构(第 08 章的移位窗口注意力)改进了 AST,在降低计算成本的同时通过多尺度特征提取提升了性能。 - **BEATs**(Chen 等人,2023)使用了一种音频特定的预训练策略:使用离散标记器进行迭代掩码预测(类似于文件 02 中 wav2vec 2.0 的方法,但应用于通用音频)。标记器逐步细化,创建越来越具有语义意义的离散音频令牌。 - **基于嵌入的说话人日志**结合了说话人嵌入与时序建模。像 Pyannote.audio 这样的现代系统使用三阶段流水线:(1) 检测说话人切换和重叠语音的神经分割模型,(2) 应用于每个检测到的片段的嵌入提取阶段(ECAPA-TDNN),以及 (3) 聚类以在整个录音中分配说话人身份。 - **音乐信息检索(MIR)**将音频分析应用于音乐。文件 01 中的谱图表示在这里尤其有用,因为音乐具有丰富的和声结构。 - **节拍跟踪**检测音乐的节奏脉冲。标准方法从语谱图计算**起始强度包络**(检测表示音符起始的能量增加),然后使用自相关或节拍图谱找到节奏,最后使用动态规划跟踪单个节拍位置,找到最能匹配起始包络同时保持稳定节奏的节拍时间序列。 - **和弦识别**识别随时间变化的和声内容。输入通常是**色度图**(也称为音高类别分布图):一个 12 维表示,将所有八度折叠在一起,显示 12 个音高类别(C、C#、D、…、B)中每个类别的能量。CNN 或 RNN(第 06 章)将每个时间帧分类到标准和弦标签之一(C 大调、A 小调、G7 等)。 - 色度图通过将每个频率区间映射到其音高类别,从 STFT(文件 01)计算得到: $$\text{chroma}(p) = \sum_{k : \text{pitch}(k) \bmod 12 = p} |X(k)|^2$$ - 其中 $p \in \{0, 1, \ldots, 11\}$ 是音高类别,$\text{pitch}(k)$ 将频率区间 $k$ 映射到其 MIDI 音符编号。 - **源分离基础**(详见文件 05)将音乐录音分离为单独的乐器(人声、鼓、贝斯、其他)。这是混音、卡拉 OK 和音乐转录等 MIR 应用的核心。像 Demucs(文件 05)这样的模型在标准 MUSDB18 基准上达到了非常好的分离质量。 - **音乐标记**为歌曲分配标签(流派、情感、乐器、时代)。它本质上是应用于音乐的音频分类,使用相同的 CNN-语谱图方法。Million Song Dataset 和 MagnaTagATune 是标准基准。 - **音频指纹**从短片段中识别特定录音,即使存在噪声、混响或压缩伪影。经典系统是 Shazam,它对星座图(语谱图中的显著峰值)进行哈希处理。神经方法学习对声学退化具有不变性、同时对不同录音保持判别性的鲁棒嵌入,这与第 06 章和第 08 章中的不变特征学习一脉相承。 ## 编程任务(使用 Colab 或笔记本) - **任务 1:带统计池化的说话人嵌入提取。** 构建一个简单的 x向量风格模型,通过 TDNN 层和统计池化处理帧级特征以产生说话人嵌入。 ```python import jax import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt # Simulate frame-level MFCC features for multiple speakers def generate_speaker_data(key, n_speakers=5, utterances_per_speaker=20, n_frames=100, n_features=40): """Generate synthetic speaker data with speaker-dependent patterns.""" keys = jr.split(key, 3) all_features = [] all_labels = [] # Each speaker has a characteristic spectral pattern speaker_patterns = jr.normal(keys[0], (n_speakers, n_features)) * 0.5 for spk in range(n_speakers): for utt in range(utterances_per_speaker): k = jr.fold_in(keys[1], spk * utterances_per_speaker + utt) noise = jr.normal(k, (n_frames, n_features)) * 0.3 features = speaker_patterns[spk][None, :] + noise all_features.append(features) all_labels.append(spk) perm = jr.permutation(keys[2], len(all_features)) features = jnp.stack(all_features)[perm] labels = jnp.array(all_labels)[perm] return features, labels key = jr.PRNGKey(42) features, labels = generate_speaker_data(key) n_speakers = 5 n_features = 40 # x-vector-style model def init_xvector(key, n_features=40, hidden=128, embed_dim=64, n_speakers=5): keys = jr.split(key, 8) params = { # TDNN layer 1: context [-2, 2] 'tdnn1_w': jr.normal(keys[0], (5, n_features, hidden)) * jnp.sqrt(2.0 / (5 * n_features)), 'tdnn1_b': jnp.zeros(hidden), # TDNN layer 2: context [-2, 2] 'tdnn2_w': jr.normal(keys[1], (5, hidden, hidden)) * jnp.sqrt(2.0 / (5 * hidden)), 'tdnn2_b': jnp.zeros(hidden), # TDNN layer 3: context [-3, 3] 'tdnn3_w': jr.normal(keys[2], (7, hidden, hidden)) * jnp.sqrt(2.0 / (7 * hidden)), 'tdnn3_b': jnp.zeros(hidden), # Segment-level layers (after pooling: 2*hidden -> embed_dim) 'seg1_w': jr.normal(keys[3], (2 * hidden, embed_dim)) * jnp.sqrt(2.0 / (2 * hidden)), 'seg1_b': jnp.zeros(embed_dim), # Classification head 'cls_w': jr.normal(keys[4], (embed_dim, n_speakers)) * jnp.sqrt(2.0 / embed_dim), 'cls_b': jnp.zeros(n_speakers), } return params def xvector_forward(params, x, return_embedding=False): """x: (batch, frames, features) -> logits or embeddings.""" # TDNN layers (1D convolutions) h = jax.lax.conv_general_dilated( x.transpose(0, 2, 1), params['tdnn1_w'].transpose(2, 1, 0), window_strides=(1,), padding='SAME' ).transpose(0, 2, 1) + params['tdnn1_b'] h = jax.nn.relu(h) h = jax.lax.conv_general_dilated( h.transpose(0, 2, 1), params['tdnn2_w'].transpose(2, 1, 0), window_strides=(1,), padding='SAME' ).transpose(0, 2, 1) + params['tdnn2_b'] h = jax.nn.relu(h) h = jax.lax.conv_general_dilated( h.transpose(0, 2, 1), params['tdnn3_w'].transpose(2, 1, 0), window_strides=(1,), padding='SAME' ).transpose(0, 2, 1) + params['tdnn3_b'] h = jax.nn.relu(h) # Statistics pooling: mean and std over time mu = jnp.mean(h, axis=1) sigma = jnp.std(h, axis=1) pooled = jnp.concatenate([mu, sigma], axis=-1) # Segment-level layer -> embedding embedding = jax.nn.relu(pooled @ params['seg1_w'] + params['seg1_b']) if return_embedding: return embedding # Classification logits = embedding @ params['cls_w'] + params['cls_b'] return logits def cross_entropy_loss(params, features, labels): logits = xvector_forward(params, features) one_hot = jax.nn.one_hot(labels, n_speakers) log_probs = jax.nn.log_softmax(logits) return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1)) grad_fn = jax.jit(jax.value_and_grad(cross_entropy_loss)) # Train params = init_xvector(jr.PRNGKey(0)) lr = 1e-3 losses = [] for epoch in range(300): loss_val, grads = grad_fn(params, features, labels) params = jax.tree.map(lambda p, g: p - lr * g, params, grads) losses.append(float(loss_val)) # Extract embeddings and visualise with t-SNE-style 2D projection (using PCA) embeddings = xvector_forward(params, features, return_embedding=True) # Simple PCA to 2D emb_centered = embeddings - jnp.mean(embeddings, axis=0) _, _, Vt = jnp.linalg.svd(emb_centered, full_matrices=False) proj_2d = emb_centered @ Vt[:2].T fig, axes = plt.subplots(1, 2, figsize=(14, 5)) axes[0].plot(losses, color='#3498db', linewidth=1.5) axes[0].set_xlabel('Epoch') axes[0].set_ylabel('Cross-Entropy Loss') axes[0].set_title('Speaker Classification Training') axes[0].set_yscale('log') colors = ['#3498db', '#e74c3c', '#27ae60', '#f39c12', '#9b59b6'] for spk in range(n_speakers): mask = labels == spk axes[1].scatter(proj_2d[mask, 0], proj_2d[mask, 1], c=colors[spk], label=f'Speaker {spk}', alpha=0.7, s=30) axes[1].set_xlabel('PC 1') axes[1].set_ylabel('PC 2') axes[1].set_title('Speaker Embeddings (PCA projection)') axes[1].legend() plt.tight_layout() plt.show() # Verification demo: cosine similarity emb_norm = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True) sim_matrix = emb_norm @ emb_norm.T print(f"Embedding shape: {embeddings.shape}") print(f"Avg same-speaker similarity: {jnp.mean(sim_matrix[labels[:, None] == labels[None, :]]):.4f}") print(f"Avg diff-speaker similarity: {jnp.mean(sim_matrix[labels[:, None] != labels[None, :]]):.4f}") ``` - **任务 2:基于余弦相似度评分的说话人确认。** 给定预计算的说话人嵌入,实现一个计算 EER(等错误率)并绘制 DET 曲线的确认系统。 ```python import jax import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt def generate_verification_pairs(key, n_speakers=20, dim=64, n_pairs=2000): """Generate speaker embeddings and verification trial pairs.""" keys = jr.split(key, 5) # Speaker centroids with some variance centroids = jr.normal(keys[0], (n_speakers, dim)) centroids = centroids / jnp.linalg.norm(centroids, axis=-1, keepdims=True) # Generate enrollment and test embeddings with intra-speaker variance enroll_embs = [] test_embs = [] trial_labels = [] # 1 = same speaker (target), 0 = different (impostor) for i in range(n_pairs): k1, k2, k3 = jr.split(jr.fold_in(keys[1], i), 3) is_target = jr.bernoulli(k1).astype(int) spk1 = jr.randint(k2, (), 0, n_speakers) emb1 = centroids[spk1] + jr.normal(jr.fold_in(k3, 0), (dim,)) * 0.15 if is_target: spk2 = spk1 else: spk2 = (spk1 + jr.randint(jr.fold_in(k3, 1), (), 1, n_speakers)) % n_speakers emb2 = centroids[spk2] + jr.normal(jr.fold_in(k3, 2), (dim,)) * 0.15 enroll_embs.append(emb1) test_embs.append(emb2) trial_labels.append(int(is_target)) return (jnp.stack(enroll_embs), jnp.stack(test_embs), jnp.array(trial_labels)) key = jr.PRNGKey(42) enroll, test, labels = generate_verification_pairs(key) # Compute cosine similarity scores enroll_norm = enroll / jnp.linalg.norm(enroll, axis=-1, keepdims=True) test_norm = test / jnp.linalg.norm(test, axis=-1, keepdims=True) scores = jnp.sum(enroll_norm * test_norm, axis=-1) # Compute FAR and FRR at various thresholds thresholds = jnp.linspace(-1.0, 1.0, 500) target_scores = scores[labels == 1] impostor_scores = scores[labels == 0] fars = [] frrs = [] for thresh in thresholds: far = jnp.mean(impostor_scores >= thresh) # false accepts frr = jnp.mean(target_scores < thresh) # false rejects fars.append(float(far)) frrs.append(float(frr)) fars = jnp.array(fars) frrs = jnp.array(frrs) # Find EER: where FAR ≈ FRR eer_idx = jnp.argmin(jnp.abs(fars - frrs)) eer = float((fars[eer_idx] + frrs[eer_idx]) / 2) eer_threshold = float(thresholds[eer_idx]) print(f"Equal Error Rate (EER): {eer:.4f} ({eer*100:.2f}%)") print(f"EER threshold: {eer_threshold:.4f}") fig, axes = plt.subplots(1, 3, figsize=(18, 5)) # Score distributions bins = jnp.linspace(-0.5, 1.0, 60) axes[0].hist(target_scores, bins=bins, alpha=0.6, color='#27ae60', label='Target (same speaker)', density=True) axes[0].hist(impostor_scores, bins=bins, alpha=0.6, color='#e74c3c', label='Impostor (different speaker)', density=True) axes[0].axvline(eer_threshold, color='#f39c12', linestyle='--', linewidth=2, label=f'EER threshold = {eer_threshold:.3f}') axes[0].set_xlabel('Cosine Similarity Score') axes[0].set_ylabel('Density') axes[0].set_title('Score Distributions') axes[0].legend() # FAR vs FRR axes[1].plot(thresholds, fars, color='#e74c3c', linewidth=2, label='FAR') axes[1].plot(thresholds, frrs, color='#3498db', linewidth=2, label='FRR') axes[1].axvline(eer_threshold, color='#f39c12', linestyle='--', linewidth=1.5) axes[1].scatter([eer_threshold], [eer], color='#f39c12', s=100, zorder=5, label=f'EER = {eer:.4f}') axes[1].set_xlabel('Threshold') axes[1].set_ylabel('Error Rate') axes[1].set_title('FAR and FRR vs Threshold') axes[1].legend() # DET curve (FAR vs FRR) axes[2].plot(fars, frrs, color='#9b59b6', linewidth=2) axes[2].plot([0, 1], [0, 1], 'k--', alpha=0.3) axes[2].scatter([eer], [eer], color='#f39c12', s=100, zorder=5, label=f'EER = {eer:.4f}') axes[2].set_xlabel('False Acceptance Rate') axes[2].set_ylabel('False Rejection Rate') axes[2].set_title('DET Curve') axes[2].set_xlim([0, 0.5]) axes[2].set_ylim([0, 0.5]) axes[2].legend() axes[2].set_aspect('equal') plt.tight_layout() plt.show() ``` - **任务 3:音频语谱图块嵌入(AST 风格)。** 实现音频语谱图 Transformer 的块提取和嵌入层,可视化语谱图如何被令牌化。 ```python import jax import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt # Generate a synthetic spectrogram (harmonic structure + noise) def generate_spectrogram(key, n_time=128, n_freq=128): """Create a synthetic spectrogram with harmonic patterns.""" k1, k2 = jr.split(key) spec = jr.normal(k1, (n_time, n_freq)) * 0.1 # Add harmonic bands (simulating speech formants) for f0 in [15, 30, 45, 70]: width = 3 envelope = jnp.exp(-0.5 * ((jnp.arange(n_freq) - f0) / width) ** 2) time_mod = 0.5 + 0.5 * jnp.sin(2 * jnp.pi * jnp.arange(n_time) / 40) spec += jnp.outer(time_mod, envelope) return jnp.clip(spec, 0, None) key = jr.PRNGKey(42) spectrogram = generate_spectrogram(key) n_time, n_freq = spectrogram.shape # Patch extraction parameters patch_h = 16 # time patch_w = 16 # frequency stride_h = 16 stride_w = 16 embed_dim = 192 # ViT-Small dimension n_patches_h = n_time // stride_h n_patches_w = n_freq // stride_w n_patches = n_patches_h * n_patches_w print(f"Spectrogram: {n_time} x {n_freq}") print(f"Patch size: {patch_h} x {patch_w}") print(f"Number of patches: {n_patches_h} x {n_patches_w} = {n_patches}") # Extract patches def extract_patches(spec, patch_h, patch_w, stride_h, stride_w): """Extract non-overlapping patches from spectrogram.""" patches = [] positions = [] for i in range(0, spec.shape[0] - patch_h + 1, stride_h): for j in range(0, spec.shape[1] - patch_w + 1, stride_w): patch = spec[i:i+patch_h, j:j+patch_w] patches.append(patch.flatten()) positions.append((i, j)) return jnp.stack(patches), positions patches, positions = extract_patches(spectrogram, patch_h, patch_w, stride_h, stride_w) print(f"Patches shape: {patches.shape}") # (n_patches, patch_h * patch_w) # Linear projection (patch embedding) patch_dim = patch_h * patch_w k1, k2 = jr.split(jr.PRNGKey(0)) W_embed = jr.normal(k1, (patch_dim, embed_dim)) * jnp.sqrt(2.0 / patch_dim) b_embed = jnp.zeros(embed_dim) # Learnable positional embeddings pos_embed = jr.normal(k2, (n_patches + 1, embed_dim)) * 0.02 # +1 for CLS # CLS token cls_token = jnp.zeros((1, embed_dim)) # Forward pass patch_tokens = patches @ W_embed + b_embed # (n_patches, embed_dim) tokens = jnp.concatenate([cls_token, patch_tokens], axis=0) # (n_patches+1, embed_dim) tokens = tokens + pos_embed # Add positional embeddings print(f"Token sequence shape: {tokens.shape}") print(f"Each token has dimension: {embed_dim}") # Visualisation fig, axes = plt.subplots(2, 2, figsize=(14, 10)) # Original spectrogram with patch grid axes[0, 0].imshow(spectrogram.T, aspect='auto', origin='lower', cmap='magma') for i in range(0, n_time + 1, stride_h): axes[0, 0].axvline(i - 0.5, color='white', linewidth=0.5, alpha=0.5) for j in range(0, n_freq + 1, stride_w): axes[0, 0].axhline(j - 0.5, color='white', linewidth=0.5, alpha=0.5) axes[0, 0].set_title(f'Spectrogram with {patch_h}x{patch_w} Patch Grid') axes[0, 0].set_xlabel('Time frame') axes[0, 0].set_ylabel('Frequency bin') # Individual patches visualised n_show = min(16, n_patches) patch_grid = patches[:n_show].reshape(n_show, patch_h, patch_w) combined = jnp.concatenate([patch_grid[i] for i in range(min(8, n_show))], axis=1) axes[0, 1].imshow(combined.T, aspect='auto', origin='lower', cmap='magma') axes[0, 1].set_title(f'First {min(8, n_show)} Patches (concatenated)') axes[0, 1].set_xlabel('Patch index (horizontal)') axes[0, 1].set_ylabel('Frequency within patch') # Token embeddings similarity matrix token_norms = tokens / jnp.linalg.norm(tokens, axis=-1, keepdims=True) sim = token_norms @ token_norms.T im = axes[1, 0].imshow(sim, cmap='RdBu_r', vmin=-1, vmax=1) axes[1, 0].set_title('Token Similarity Matrix (cosine)') axes[1, 0].set_xlabel('Token index') axes[1, 0].set_ylabel('Token index') plt.colorbar(im, ax=axes[1, 0], fraction=0.046) # Positional embedding similarity pos_norms = pos_embed / jnp.linalg.norm(pos_embed, axis=-1, keepdims=True) pos_sim = pos_norms @ pos_norms.T im2 = axes[1, 1].imshow(pos_sim, cmap='RdBu_r', vmin=-1, vmax=1) axes[1, 1].set_title('Positional Embedding Similarity') axes[1, 1].set_xlabel('Position index') axes[1, 1].set_ylabel('Position index') plt.colorbar(im2, ax=axes[1, 1], fraction=0.046) plt.tight_layout() plt.show() ``` - **任务 4:用于和弦分析的简单色度图计算。** 从合成和声信号计算并可视化色度图,展示音乐信息检索中使用的音高类别折叠方法。 ```python import jax import jax.numpy as jnp import matplotlib.pyplot as plt # Generate a synthetic musical signal: C major chord -> G major chord sr = 16000 duration = 2.0 t = jnp.linspace(0, duration, int(sr * duration)) # C major (C4=261.6, E4=329.6, G4=392.0) for first half # G major (G3=196.0, B3=246.9, D4=293.7) for second half half = len(t) // 2 c_major = (0.5 * jnp.sin(2 * jnp.pi * 261.63 * t[:half]) + 0.4 * jnp.sin(2 * jnp.pi * 329.63 * t[:half]) + 0.3 * jnp.sin(2 * jnp.pi * 392.00 * t[:half])) g_major = (0.5 * jnp.sin(2 * jnp.pi * 196.00 * t[:half]) + 0.4 * jnp.sin(2 * jnp.pi * 246.94 * t[:half]) + 0.3 * jnp.sin(2 * jnp.pi * 293.66 * t[:half])) signal = jnp.concatenate([c_major, g_major]) # Compute STFT n_fft = 4096 # high resolution for pitch accuracy hop_length = 512 window = jnp.hanning(n_fft) def stft(signal, n_fft, hop_length, window): n_frames = 1 + (len(signal) - n_fft) // hop_length frames = jnp.stack([ signal[i * hop_length : i * hop_length + n_fft] * window for i in range(n_frames) ]) return jnp.fft.rfft(frames, n=n_fft) S = stft(signal, n_fft, hop_length, window) power_spec = jnp.abs(S) ** 2 freqs = jnp.fft.rfftfreq(n_fft, 1.0 / sr) # Compute chromagram by mapping frequency bins to pitch classes # MIDI note number from frequency: 69 + 12 * log2(f / 440) note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'] def freq_to_chroma(freq): """Map frequency to pitch class (0-11). Returns -1 for freq <= 0.""" midi = 69 + 12 * jnp.log2(jnp.clip(freq, 1e-10, None) / 440.0) return jnp.round(midi).astype(int) % 12 # Build chromagram: sum power spectrum energy for each pitch class chromagram = jnp.zeros((power_spec.shape[0], 12)) valid_freqs = freqs[1:] # skip DC valid_power = power_spec[:, 1:] for p in range(12): # Find frequency bins belonging to this pitch class chroma_bins = freq_to_chroma(valid_freqs) mask = (chroma_bins == p).astype(jnp.float32) chromagram = chromagram.at[:, p].set( jnp.sum(valid_power * mask[None, :], axis=1) ) # Normalise each frame chromagram = chromagram / (jnp.max(chromagram, axis=1, keepdims=True) + 1e-8) # Visualisation fig, axes = plt.subplots(3, 1, figsize=(14, 10)) # Waveform axes[0].plot(t[:3000], signal[:3000], color='#3498db', linewidth=0.5, label='C major') axes[0].plot(t[half:half+3000], signal[half:half+3000], color='#e74c3c', linewidth=0.5, label='G major') axes[0].set_title('Waveform: C major → G major') axes[0].set_ylabel('Amplitude') axes[0].set_xlabel('Time (s)') axes[0].legend() # Spectrogram (log scale) time_axis = jnp.arange(power_spec.shape[0]) * hop_length / sr axes[1].imshow(jnp.log1p(power_spec[:, :500].T), aspect='auto', origin='lower', cmap='magma', extent=[0, time_axis[-1], 0, freqs[500]]) axes[1].set_title('Power Spectrogram') axes[1].set_ylabel('Frequency (Hz)') axes[1].set_xlabel('Time (s)') # Chromagram im = axes[2].imshow(chromagram.T, aspect='auto', origin='lower', cmap='YlOrRd', extent=[0, time_axis[-1], -0.5, 11.5]) axes[2].set_yticks(range(12)) axes[2].set_yticklabels(note_names) axes[2].set_title('Chromagram (pitch class energy over time)') axes[2].set_ylabel('Pitch class') axes[2].set_xlabel('Time (s)') plt.colorbar(im, ax=axes[2], fraction=0.046, label='Normalised energy') # Mark expected active pitch classes mid_frame = chromagram.shape[0] // 2 print(f"C major region - expected: C, E, G") print(f" Chroma values: {dict(zip(note_names, [f'{v:.2f}' for v in chromagram[mid_frame//2]]))}") print(f"G major region - expected: G, B, D") print(f" Chroma values: {dict(zip(note_names, [f'{v:.2f}' for v in chromagram[mid_frame + mid_frame//2]]))}") plt.tight_layout() plt.show() ```