Files
maths-cs-ai-compendium-zh/chapter 07: computational linguistics/03. embeddings and sequence models.md
T
flykhan 2536c937e3 feat: 完整中文翻译 maths-cs-ai-compendium(数学·计算机科学·AI 知识大全)
翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。

第01章 向量 | 第02章 矩阵 | 第03章 微积分
第04章 统计学 | 第05章 概率论 | 第06章 机器学习
第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音
第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络
第13章 计算与操作系统 | 第14章 数据结构与算法
第15章 生产级软件工程 | 第16章 SIMD与GPU编程
第17章 AI推理 | 第18章 ML系统设计
第19章 应用人工智能 | 第20章 前沿人工智能

翻译说明:
- 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留
- mkdocs.yml 配置中文导航 + language: zh
- README.md 已翻译为中文(兼 docs/index.md)
- docs/ 目录包含指向各章文件的 symlink
- 约 29,000 行中文内容,排除 .cache/ 构建缓存
2026-05-03 10:23:20 +08:00

390 lines
24 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 嵌入与序列模型
*词嵌入将稀疏的符号化文本压缩到稠密向量空间中,使得语义相似性转化为几何邻近性。本文涵盖 Word2VecCBOW、Skip-gram)、GloVe、FastText、RNN、LSTM、GRU、带注意力机制的 seq2seq、编码器-解码器范式,以及从词袋模型到上下文表示的发展历程。*
- 在文件 01 中,我们介绍了分布假设:出现在相似语境中的词往往具有相似的含义。在文件 02 中,我们使用稀疏的、手工设计的特征(如 TF-IDF 向量)来表示文本。这些向量位于极高维空间中(每个词汇表词占一维),且大部分为零。**词嵌入**将这些信息压缩到稠密的低维向量中,捕捉语义关系,并且直接从数据中学习。
- **Word2Vec**Mikolov et al., 2013)通过在简单的预测任务上训练一个浅层神经网络来学习词嵌入。共有两种架构。
- **连续词袋模型(CBOW)**根据目标词周围的上下文词来预测该词。给定一个窗口大小的上下文词(例如,"the cat ___ on the"),模型求它们的嵌入向量的平均值,并将结果通过一个线性层来预测缺失的词("sat")。训练目标最大化:
$$P(w_t \mid w_{t-k}, \ldots, w_{t-1}, w_{t+1}, \ldots, w_{t+k})$$
- **Skip-gram 模型**则反过来:给定一个目标词,预测其周围的上下文词。对于目标词 "sat",模型分别尝试预测 "the"、"cat"、"on"、"the"。目标最大化:
$$P(w_{t+j} \mid w_t) \quad \text{对于每个 } j \in [-k, k], \; j \neq 0$$
![Skip-gram 与 CBOW 架构对比:CBOW 对上下文嵌入求平均来预测中心词,skip-gram 使用中心词嵌入来预测每个上下文词](../images/word2vec_architectures.svg)
- Skip-gram 通常对罕见词效果更好,因为每个词会产生多个训练样本(每个上下文位置一个)。CBOW 速度更快,对频繁词略优,因为它对多个上下文信号取平均。
- 在整个词汇表上训练代价很高,因为 softmax 分母需要对所有 $V$ 个词求和。**负采样**通过将问题转化为二分类来近似这一过程:区分真实的上下文词(正样本)与随机采样的噪声词(负样本)。模型无需计算完整的 softmax,只需更新目标词、真实上下文词以及少数负样本的嵌入:
$$\mathcal{L} = \log \sigma(v_{w_O}^T v_{w_I}) + \sum_{i=1}^{k} \mathbb{E}_{w_i \sim P_n} [\log \sigma(-v_{w_i}^T v_{w_I})]$$
- 这里 $v_{w_I}$ 是输入词嵌入,$v_{w_O}$ 是输出(上下文)词嵌入,$P_n$ 是噪声分布,通常采用词频的 3/4 次方(这会降低"the"这类高频词的权重)。
- 为什么这个简单的目标函数能产生有意义的嵌入?Levy 和 Goldberg2014)证明,带负采样的 skip-gram 实际上是在分解一个**移位点互信息(PMI)**矩阵。在收敛时,两个词向量的点积近似于:
$$v_w^T v_c \approx \text{PMI}(w, c) - \log k$$
- 其中 $\text{PMI}(w, c) = \log \frac{P(w, c)}{P(w) P(c)}$ 衡量词 $w$ 和 $c$ 共现的频率比随机期望高出多少(见第 05 章信息论),$k$ 是负样本数量。共现远高于随机期望的词具有高 PMI,从而具有高点积(相似的嵌入)。共现低于预期的词具有负 PMI 和不相似的嵌入。这表明 Word2Vec 实际上与经典的分布语义学方法(如潜在语义分析,即对共现矩阵做 SVD)在做同样的事情,只是采用了更具扩展性的在线方式。
- Word2Vec 嵌入最令人惊讶的特性是它们能通过**向量算术**捕捉**类比关系**。向量 $v_{\text{king}} - v_{\text{man}} + v_{\text{woman}}$ 最接近 $v_{\text{queen}}$。这是因为嵌入空间将语义关系编码为近似线性方向:"王室"方向大致为 $v_{\text{king}} - v_{\text{man}}$,将其加到 $v_{\text{woman}}$ 上就会落在 $v_{\text{queen}}$ 附近。这与第 01 章的线性代数相关联:语义关系就是向量平移。
- **GloVe**Global Vectors for Word RepresentationPennington et al., 2014)采用不同的方法。它不是一次一个地从局部上下文窗口学习,而是构建一个全局的词共现矩阵 $X$,其中 $X_{ij}$ 统计在整个语料库中词 $j$ 出现在词 $i$ 上下文中的次数。然后模型学习嵌入,使其点积近似于对数共现次数:
$$w_i^T \tilde{w}_j + b_i + \tilde{b}_j = \log X_{ij}$$
- 损失函数通过一个截断函数 $f(X_{ij})$ 对每一对加权,防止非常频繁的共现主导训练:
$$\mathcal{L} = \sum_{i,j=1}^{V} f(X_{ij}) \left(w_i^T \tilde{w}_j + b_i + \tilde{b}_j - \log X_{ij}\right)^2$$
- GloVe 结合了全局矩阵分解(如潜在语义分析)和 Word2Vec 的局部上下文学习的优点。在实践中,GloVe 和 Word2Vec 生成的嵌入质量相近。
- **FastText**Bojanowski et al., 2017)扩展了 skip-gram,将每个词表示为一组字符 n-gram 的集合。对于 $n = 3$,词 "where" 变成:"\<wh"、"whe"、"her"、"ere"、"re\>",加上完整词标记 "\<where\>"。该词的嵌入是其所有 n-gram 嵌入之和。
- 这有一个关键优势:FastText 能够为训练中从未见过的词生成嵌入。词 "whereabouts" 与 "where" 共享 n-gram,因此即使 "whereabouts" 从未出现在训练数据中,其嵌入也是合理的。这对于形态丰富的语言(文件 01)尤为有用,因为这些语言中的词有许多屈折形式。
- **嵌入评估**通常使用两类基准测试。**类比任务**测试 $v_a - v_b + v_c \approx v_d$ 是否成立(例如,"Paris" $-$ "France" $+$ "Italy" $\approx$ "Rome")。**相似性基准**将词对之间的余弦相似度(第 01 章)与人工判断进行比较。常见的数据集包括 WordSim-353、SimLex-999 和 Google 类比测试集。一个实用注意事项:在类比任务上表现出色的嵌入不一定最适合下游任务,如情感分类。最好的评估往往是任务本身。
- 在第 06 章中,我们介绍了 RNN、LSTM 和 GRU 作为处理序列数据的架构。这里我们重点讨论它们如何具体应用于语言任务。
- **语言模型 RNN** 每次读取一个词元,并在每一步预测下一个词元。隐藏状态 $h_t$ 将整个历史序列 $w_1, \ldots, w_t$ 压缩为一个固定大小的向量,线性层加 softmax 将 $h_t$ 映射到词汇表上的分布。训练使用与真实下一词元的交叉熵损失,这等价于最小化困惑度(文件 02)。关键局限在于:固定大小的隐藏状态必须编码关于历史的所有信息,早期词元的信息会逐渐被覆盖。
- **双向 RNN** 从两个方向处理序列:一个 RNN 从左到右读取,另一个从右到左读取。在每个位置 $t$,前向隐藏状态 $\overrightarrow{h}_t$ 和后向隐藏状态 $\overleftarrow{h}_t$ 被拼接起来,形成上下文感知的表示 $h_t = [\overrightarrow{h}_t ; \overleftarrow{h}_t]$。这使模型能够同时访问过去和未来的上下文,对于词性标注和命名实体识别(文件 02)等任务非常有效,因为这些任务中一个词的标签依赖于其前后的词。双向 RNN 不能用于语言建模,因为在预测未来词元时不能窥视它们。
![双向 RNN:前向 RNN 从左到右读取产生隐藏状态,后向 RNN 从右到左读取,每个位置的输出拼接在一起](../images/bidirectional_rnn.svg)
- **深层堆叠 RNN** 将多个 RNN 层叠放在一起。第 $l$ 层所有时间步的隐藏状态成为第 $l+1$ 层的输入序列。堆叠 2-4 层通常能通过构建层次化表示来提升性能,类似于深层 CNN 构建特征层次结构(第 06 章)。超过 4 层时,梯度消失和过拟合会成为问题,除非在层之间添加残差连接。
- **序列到序列(seq2seq**架构(Sutskever et al., 2014)将可变长度的输入序列映射到可变长度的输出序列。它由一个**编码器** RNN(读取输入并将其压缩为上下文向量,即最终的隐藏状态)和一个**解码器** RNN(基于该上下文向量逐步生成输出)组成。
![Seq2seq 编码器-解码器:编码器 RNN 从左到右读取输入词元,最终隐藏状态作为解码器 RNN 的初始状态,解码器自回归地生成输出词元](../images/seq2seq_architecture.svg)
- Seq2seq 是机器翻译的突破性架构。编码器读取法语句子,解码器生成英文翻译。解码器从一个特殊的序列起始词元开始,自回归地生成词元,直到产生序列结束词元。一个实用的技巧:反转输入序列(输入 "chat le" 而不是 "le chat")可以改善结果,因为这使得第一个输入词在计算图中更靠近第一个输出词,缩短了梯度路径。
- 瓶颈问题:整个输入必须被压缩到一个固定大小的向量中。对于长句子,这个向量无法捕捉所有信息,性能会下降。这推动了**注意力机制**的发展。
- 第 06 章介绍了现代的点积注意力 Q、K、V 形式。NLP 中最早的注意力机制以不同的方式提出,作为编码器和解码器状态之间的对齐模型。
- **Bahdanau 注意力**(加性注意力,Bahdanau et al., 2015)使用一个可学习的前馈网络计算解码器隐藏状态 $s_t$ 与每个编码器隐藏状态 $h_i$ 之间的对齐分数:
$$e_{ti} = v^T \tanh(W_s s_{t-1} + W_h h_i)$$
- 分数通过 softmax 归一化为注意力权重,上下文向量是编码器状态的加权和:
$$\alpha_{ti} = \frac{\exp(e_{ti})}{\sum_j \exp(e_{tj})}, \quad c_t = \sum_i \alpha_{ti} h_i$$
- 然后解码器同时使用 $s_{t-1}$ 和 $c_t$ 来生成下一个输出。关键洞察:不是为整个句子使用一个固定的上下文向量,每个解码步骤获得编码器状态的不同加权组合,使模型能够"回顾"输入的相关部分。
- **Luong 注意力**(乘性注意力,Luong et al., 2015)简化了分数计算。**点积**变体使用 $e_{ti} = s_t^T h_i$。**通用**变体使用 $e_{ti} = s_t^T W h_i$。这些比 Bahdanau 的加性分数更快,因为它们使用矩阵乘法而非前馈网络。Luong 注意力还从当前解码器状态 $s_t$(而非 $s_{t-1}$)计算上下文向量,这使得它能获取更多信息,但计算方式略有不同。
![源句子与其翻译之间的注意力对齐热力图,显示每个目标词关注哪些源词,较亮的单元格表示更高的注意力权重](../images/attention_alignment.svg)
- 注意力权重通常可视化为热力图,显示解码器在生成每个输出词元时关注哪些输入词元。在翻译中,这些热力图大致勾勒出源语言和目标语言之间的词对齐关系,对角模式会被重排序打破(例如,形容词-名词顺序在法语和英语中有所不同)。
- 推理时,解码器每一步必须选择一个词元。**贪心解码**在每个位置选择概率最高的词元,但这可能导致次优序列:一个局部好的选择可能迫使模型进入全局不佳的句子。**束搜索**在每一步维护分数最高的 $k$ 个(束宽)部分序列,对每个序列扩展所有可能的下一词元,并保留总体最好的 $k$ 个。
- 当束宽 $k = 1$ 时,束搜索退化为贪心解码。典型值为 $k = 4$ 到 $k = 10$。更大的束能找到更好的序列,但速度会成比例降低。束搜索还需要**长度归一化**,以避免偏向较短的序列(因为较短的序列乘法项更少,自然具有更高的总概率)。归一化后的分数为:
$$\text{score}(y) = \frac{1}{|y|^\alpha} \sum_{t=1}^{|y|} \log P(y_t \mid y_{<t})$$
- 其中 $|y|$ 是序列长度,$\alpha$(通常为 0.6-0.7)控制长度惩罚的强度。当 $\alpha = 0$ 时,没有长度归一化。当 $\alpha = 1$ 时,分数是每个词元的对数概率(几何平均)。中间值在倾向于简洁输出和不过早截断之间取得平衡。
- 虽然 RNN 顺序处理文本,但 **1D CNN** 通过在词元序列上滑动滤波器来并行处理文本。每个滤波器检测一个局部模式(n-gram 特征)。
- **TextCNN**Kim, 2014)对输入的嵌入矩阵应用多个不同宽度(例如 3、4、5 个词元)的一维卷积滤波器。每个滤波器生成一个特征图,**时序最大池化**从每个特征图中取单一最大值,捕获该模式是否在文本中的任何位置被检测到,而不考虑位置。所有滤波器的池化特征被拼接后传递给分类器。
![TextCNN 架构:输入嵌入通过宽度为 3、4、5 的并行卷积滤波器,每个滤波器后接时序最大池化,然后拼接并馈送到全连接分类器](../images/textcnn_architecture.svg)
- TextCNN 速度快,对于情感分析等文本分类任务效果出奇地好。它能捕获局部 n-gram 模式,但无法建模长距离依赖:宽度为 5 的滤波器只能看到 5 个连续的词元。**膨胀因果卷积**通过在滤波器元素之间插入间隙(膨胀)来解决这个问题。堆叠膨胀率呈指数增长(1、2、4、8、...)的层,可以在不增加参数的情况下指数级地扩大感受野,使模型能够捕获跨越数百个词元的依赖关系。
- 到目前为止讨论的所有嵌入(Word2Vec、GloVe、FastText)针对每个词类型生成单一向量,与上下文无关。"Bank"无论是指金融机构还是河岸,都得到相同的嵌入。这是一个根本性的局限,而**上下文嵌入**解决了这一问题。
- **ELMo**Embeddings from Language ModelsPeters et al., 2018)通过在输入文本上运行一个深层双向 LSTM 语言模型来生成上下文相关的词表示。前向 LSTM 在每个位置预测下一个词;一个独立的后向 LSTM 预测前一个词。两者都在大规模语料库上作为语言模型进行训练。
- 在每个位置 $k$,ELMo 使用任务特定的学习权重组合所有 $L$ 层的隐藏状态:
$$\text{ELMo}_k = \gamma \sum_{j=0}^{L} s_j \, h_{k,j}$$
- 这里 $h_{k,j}$ 是位置 $k$ 层 $j$ 的隐藏状态(层 0 是原始词嵌入),$s_j$ 是 softmax 归一化的标量权重,$\gamma$ 是任务特定的缩放因子。不同层捕获不同信息:较低层捕获句法(词性标注、词形态),较高层捕获语义(词义、语义角色)。通过使用学习到的权重混合所有层,ELMo 嵌入能够适应多样化的下游任务。
- ELMo 标志着**预训练然后微调**范式的开始:在海量无标注文本上训练大型语言模型,然后将其表示用于下游任务。ELMo 具体使用预训练的表示作为固定的或轻度微调的特征,与任务特定的输入拼接在一起。BERT 和 GPT(文件 04)通过端到端地微调整个模型进一步推进了这一范式,事实证明这要有效得多。
- 从 Word2Vec 到 ELMo 的发展过程展示了 NLP 中一个反复出现的主题:从静态表示到动态表示,从局部上下文到全局上下文,从浅层模型到深层模型。每一步都以计算成本换取更丰富的表示。Transformer(文件 04)通过用注意力完全取代循环,实现了深层上下文化和并行计算,完成了这一演进。
## 编程任务(使用 CoLab 或 notebook
1. 从头实现带负采样的 Word2Vec skip-gram。在小型语料库上训练,并使用 PCA 可视化学习到的嵌入。
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
# 小型语料库
corpus = """the king ruled the kingdom . the queen ruled the kingdom .
the prince is the son of the king . the princess is the daughter of the queen .
a man worked in the castle . a woman worked in the castle .
the king and queen lived in the castle . the prince and princess played outside .""".lower().split()
vocab = sorted(set(corpus))
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}
V = len(vocab)
# 生成 skip-gram 对,窗口大小为 2
window = 2
pairs = []
for i, word in enumerate(corpus):
for j in range(max(0, i - window), min(len(corpus), i + window + 1)):
if i != j:
pairs.append((word2idx[word], word2idx[corpus[j]]))
pairs = jnp.array(pairs)
print(f"词汇表大小: {V} 个词, 训练样本数: {len(pairs)}")
# 模型参数
embed_dim = 16
key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)
W_in = jax.random.normal(k1, (V, embed_dim)) * 0.1 # 输入嵌入
W_out = jax.random.normal(k2, (V, embed_dim)) * 0.1 # 输出嵌入
# 单个样本对的负采样损失
def neg_sampling_loss(W_in, W_out, target, context, neg_ids):
v_in = W_in[target] # (embed_dim,)
v_out = W_out[context] # (embed_dim,)
v_neg = W_out[neg_ids] # (k, embed_dim)
pos_loss = -jax.nn.log_sigmoid(jnp.dot(v_in, v_out))
neg_loss = -jnp.sum(jax.nn.log_sigmoid(-v_neg @ v_in))
return pos_loss + neg_loss
# 训练循环
num_neg = 5
lr = 0.05
@jax.jit
def train_step(W_in, W_out, target, context, neg_ids):
loss, (g_in, g_out) = jax.value_and_grad(neg_sampling_loss, argnums=(0, 1))(
W_in, W_out, target, context, neg_ids)
return loss, W_in - lr * g_in, W_out - lr * g_out
key = jax.random.PRNGKey(0)
for epoch in range(50):
total_loss = 0.0
for i in range(len(pairs)):
key, subkey = jax.random.split(key)
neg_ids = jax.random.randint(subkey, (num_neg,), 0, V)
loss, W_in, W_out = train_step(W_in, W_out, pairs[i, 0], pairs[i, 1], neg_ids)
total_loss += loss
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}: avg loss = {total_loss / len(pairs):.4f}")
# 使用 PCA 可视化(第 01 章)
embeddings = W_in
mean = embeddings.mean(axis=0)
centered = embeddings - mean
U, S, Vt = jnp.linalg.svd(centered, full_matrices=False)
coords = centered @ Vt[:2].T # 投影到前两个主成分
plt.figure(figsize=(10, 8))
for i, word in idx2word.items():
plt.scatter(coords[i, 0], coords[i, 1], c='#3498db', s=40)
plt.annotate(word, (coords[i, 0] + 0.02, coords[i, 1] + 0.02), fontsize=9)
plt.title("Word2Vec Skip-gram 嵌入(PCA 投影)")
plt.grid(alpha=0.3); plt.show()
```
2. 构建一个字符级 RNN 语言模型,从一小段训练文本中学习生成文本。
```python
import jax
import jax.numpy as jnp
# 小型训练文本
text = "to be or not to be that is the question "
chars = sorted(set(text))
char2idx = {c: i for i, c in enumerate(chars)}
idx2char = {i: c for c, i in char2idx.items()}
V = len(chars)
data = jnp.array([char2idx[c] for c in text])
# RNN 参数
hidden_dim = 64
key = jax.random.PRNGKey(0)
k1, k2, k3, k4, k5 = jax.random.split(key, 5)
params = {
'Wx': jax.random.normal(k1, (V, hidden_dim)) * 0.1,
'Wh': jax.random.normal(k2, (hidden_dim, hidden_dim)) * 0.05,
'bh': jnp.zeros(hidden_dim),
'Wy': jax.random.normal(k3, (hidden_dim, V)) * 0.1,
'by': jnp.zeros(V),
}
def rnn_step(params, h, x_idx):
x = jnp.eye(V)[x_idx] # one-hot 编码
h = jnp.tanh(x @ params['Wx'] + h @ params['Wh'] + params['bh'])
logits = h @ params['Wy'] + params['by']
return h, logits
def loss_fn(params, inputs, targets):
h = jnp.zeros(hidden_dim)
total_loss = 0.0
for t in range(len(inputs)):
h, logits = rnn_step(params, h, inputs[t])
log_probs = jax.nn.log_softmax(logits)
total_loss -= log_probs[targets[t]]
return total_loss / len(inputs)
grad_fn = jax.jit(jax.grad(loss_fn))
# 训练
inputs = data[:-1]
targets = data[1:]
lr = 0.01
for step in range(500):
grads = grad_fn(params, inputs, targets)
params = {k: params[k] - lr * grads[k] for k in params}
if (step + 1) % 100 == 0:
l = loss_fn(params, inputs, targets)
print(f"Step {step+1}: loss = {l:.4f}")
# 生成文本
def generate(params, seed_char, length=60):
h = jnp.zeros(hidden_dim)
idx = char2idx[seed_char]
result = [seed_char]
key = jax.random.PRNGKey(42)
for _ in range(length):
h, logits = rnn_step(params, h, idx)
key, subkey = jax.random.split(key)
idx = jax.random.categorical(subkey, logits)
result.append(idx2char[int(idx)])
return ''.join(result)
print(f"\n生成文本: {generate(params, 't')}")
```
3. 实现一个带 Bahdanau 注意力的简易 seq2seq 模型,用于序列反转。可视化注意力对齐矩阵。
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
# 任务:反转数字序列(例如,[3, 1, 4] -> [4, 1, 3]
vocab_size = 10 # 数字 0-9
SOS, EOS = 10, 11 # 特殊词元
total_vocab = 12
embed_dim, hidden_dim = 16, 32
max_len = 5
key = jax.random.PRNGKey(42)
keys = jax.random.split(key, 8)
params = {
'embed': jax.random.normal(keys[0], (total_vocab, embed_dim)) * 0.1,
'enc_Wx': jax.random.normal(keys[1], (embed_dim, hidden_dim)) * 0.1,
'enc_Wh': jax.random.normal(keys[2], (hidden_dim, hidden_dim)) * 0.05,
'dec_Wx': jax.random.normal(keys[3], (embed_dim, hidden_dim)) * 0.1,
'dec_Wh': jax.random.normal(keys[4], (hidden_dim, hidden_dim)) * 0.05,
# Bahdanau 注意力
'Ws': jax.random.normal(keys[5], (hidden_dim, hidden_dim)) * 0.1,
'Wh_att': jax.random.normal(keys[6], (hidden_dim, hidden_dim)) * 0.1,
'v_att': jax.random.normal(keys[7], (hidden_dim,)) * 0.1,
# 输出投影(从隐藏状态+上下文到词汇表)
'Wo': jax.random.normal(keys[0], (hidden_dim * 2, total_vocab)) * 0.1,
}
def encode(params, seq):
"""编码输入序列,返回所有隐藏状态。"""
h = jnp.zeros(hidden_dim)
states = []
for t in range(len(seq)):
x = params['embed'][seq[t]]
h = jnp.tanh(x @ params['enc_Wx'] + h @ params['enc_Wh'])
states.append(h)
return jnp.stack(states), h
def bahdanau_attention(params, dec_state, enc_states):
"""计算 Bahdanau 注意力权重和上下文向量。"""
scores = jnp.tanh(enc_states @ params['Wh_att'] + dec_state @ params['Ws'])
e = scores @ params['v_att'] # (src_len,)
alpha = jax.nn.softmax(e)
context = alpha @ enc_states
return context, alpha
def decode_step(params, dec_h, prev_token, enc_states):
x = params['embed'][prev_token]
dec_h = jnp.tanh(x @ params['dec_Wx'] + dec_h @ params['dec_Wh'])
context, alpha = bahdanau_attention(params, dec_h, enc_states)
combined = jnp.concatenate([dec_h, context])
logits = combined @ params['Wo']
return dec_h, logits, alpha
def seq2seq_loss(params, src, tgt):
enc_states, enc_final = encode(params, src)
dec_h = enc_final
loss = 0.0
prev_token = SOS
for t in range(len(tgt)):
dec_h, logits, _ = decode_step(params, dec_h, prev_token, enc_states)
log_probs = jax.nn.log_softmax(logits)
loss -= log_probs[tgt[t]]
prev_token = tgt[t]
return loss / len(tgt)
# 生成训练数据:反转序列
key = jax.random.PRNGKey(0)
train_srcs, train_tgts = [], []
for _ in range(200):
key, subkey = jax.random.split(key)
length = jax.random.randint(subkey, (), 3, max_len + 1)
key, subkey = jax.random.split(key)
seq = jax.random.randint(subkey, (int(length),), 0, vocab_size)
train_srcs.append(seq)
train_tgts.append(seq[::-1]) # 反转
# 训练
grad_fn = jax.grad(seq2seq_loss)
lr = 0.01
for epoch in range(100):
total_loss = 0.0
for src, tgt in zip(train_srcs, train_tgts):
grads = grad_fn(params, src, tgt)
params = {k: params[k] - lr * grads[k] for k in params}
total_loss += seq2seq_loss(params, src, tgt)
if (epoch + 1) % 20 == 0:
print(f"Epoch {epoch+1}: avg loss = {total_loss / len(train_srcs):.4f}")
# 可视化一个示例的注意力
test_src = jnp.array([3, 1, 4, 1, 5])
test_tgt = test_src[::-1]
enc_states, enc_final = encode(params, test_src)
dec_h = enc_final
attentions = []
prev_token = SOS
for t in range(len(test_tgt)):
dec_h, logits, alpha = decode_step(params, dec_h, prev_token, enc_states)
attentions.append(alpha)
prev_token = test_tgt[t]
att_matrix = jnp.stack(attentions)
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(att_matrix, cmap='Blues')
ax.set_xlabel("源位置"); ax.set_ylabel("目标位置")
src_labels = [str(int(x)) for x in test_src]
tgt_labels = [str(int(x)) for x in test_tgt]
ax.set_xticks(range(len(src_labels))); ax.set_xticklabels(src_labels)
ax.set_yticks(range(len(tgt_labels))); ax.set_yticklabels(tgt_labels)
for i in range(len(tgt_labels)):
for j in range(len(src_labels)):
ax.text(j, i, f"{att_matrix[i,j]:.2f}", ha='center', va='center', fontsize=9)
ax.set_title("Bahdanau 注意力对齐(序列反转)")
plt.colorbar(im); plt.tight_layout(); plt.show()
```