Files
maths-cs-ai-compendium-zh/chapter 07: computational linguistics/03. embeddings and sequence models.md
T
flykhan 2536c937e3 feat: 完整中文翻译 maths-cs-ai-compendium(数学·计算机科学·AI 知识大全)
翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。

第01章 向量 | 第02章 矩阵 | 第03章 微积分
第04章 统计学 | 第05章 概率论 | 第06章 机器学习
第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音
第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络
第13章 计算与操作系统 | 第14章 数据结构与算法
第15章 生产级软件工程 | 第16章 SIMD与GPU编程
第17章 AI推理 | 第18章 ML系统设计
第19章 应用人工智能 | 第20章 前沿人工智能

翻译说明:
- 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留
- mkdocs.yml 配置中文导航 + language: zh
- README.md 已翻译为中文(兼 docs/index.md)
- docs/ 目录包含指向各章文件的 symlink
- 约 29,000 行中文内容,排除 .cache/ 构建缓存
2026-05-03 10:23:20 +08:00

24 KiB
Raw Blame History

嵌入与序列模型

词嵌入将稀疏的符号化文本压缩到稠密向量空间中,使得语义相似性转化为几何邻近性。本文涵盖 Word2VecCBOW、Skip-gram)、GloVe、FastText、RNN、LSTM、GRU、带注意力机制的 seq2seq、编码器-解码器范式,以及从词袋模型到上下文表示的发展历程。

  • 在文件 01 中,我们介绍了分布假设:出现在相似语境中的词往往具有相似的含义。在文件 02 中,我们使用稀疏的、手工设计的特征(如 TF-IDF 向量)来表示文本。这些向量位于极高维空间中(每个词汇表词占一维),且大部分为零。词嵌入将这些信息压缩到稠密的低维向量中,捕捉语义关系,并且直接从数据中学习。

  • Word2VecMikolov et al., 2013)通过在简单的预测任务上训练一个浅层神经网络来学习词嵌入。共有两种架构。

  • **连续词袋模型(CBOW)**根据目标词周围的上下文词来预测该词。给定一个窗口大小的上下文词(例如,"the cat ___ on the"),模型求它们的嵌入向量的平均值,并将结果通过一个线性层来预测缺失的词("sat")。训练目标最大化:

P(w_t \mid w_{t-k}, \ldots, w_{t-1}, w_{t+1}, \ldots, w_{t+k})
  • Skip-gram 模型则反过来:给定一个目标词,预测其周围的上下文词。对于目标词 "sat",模型分别尝试预测 "the"、"cat"、"on"、"the"。目标最大化:
P(w_{t+j} \mid w_t) \quad \text{对于每个 } j \in [-k, k], \; j \neq 0

Skip-gram 与 CBOW 架构对比:CBOW 对上下文嵌入求平均来预测中心词,skip-gram 使用中心词嵌入来预测每个上下文词

  • Skip-gram 通常对罕见词效果更好,因为每个词会产生多个训练样本(每个上下文位置一个)。CBOW 速度更快,对频繁词略优,因为它对多个上下文信号取平均。

  • 在整个词汇表上训练代价很高,因为 softmax 分母需要对所有 V 个词求和。负采样通过将问题转化为二分类来近似这一过程:区分真实的上下文词(正样本)与随机采样的噪声词(负样本)。模型无需计算完整的 softmax,只需更新目标词、真实上下文词以及少数负样本的嵌入:

\mathcal{L} = \log \sigma(v_{w_O}^T v_{w_I}) + \sum_{i=1}^{k} \mathbb{E}_{w_i \sim P_n} [\log \sigma(-v_{w_i}^T v_{w_I})]
  • 这里 v_{w_I} 是输入词嵌入,v_{w_O} 是输出(上下文)词嵌入,P_n 是噪声分布,通常采用词频的 3/4 次方(这会降低"the"这类高频词的权重)。

  • 为什么这个简单的目标函数能产生有意义的嵌入?Levy 和 Goldberg2014)证明,带负采样的 skip-gram 实际上是在分解一个**移位点互信息(PMI)**矩阵。在收敛时,两个词向量的点积近似于:

v_w^T v_c \approx \text{PMI}(w, c) - \log k
  • 其中 \text{PMI}(w, c) = \log \frac{P(w, c)}{P(w) P(c)} 衡量词 wc 共现的频率比随机期望高出多少(见第 05 章信息论),k 是负样本数量。共现远高于随机期望的词具有高 PMI,从而具有高点积(相似的嵌入)。共现低于预期的词具有负 PMI 和不相似的嵌入。这表明 Word2Vec 实际上与经典的分布语义学方法(如潜在语义分析,即对共现矩阵做 SVD)在做同样的事情,只是采用了更具扩展性的在线方式。

  • Word2Vec 嵌入最令人惊讶的特性是它们能通过向量算术捕捉类比关系。向量 v_{\text{king}} - v_{\text{man}} + v_{\text{woman}} 最接近 $v_{\text{queen}}$。这是因为嵌入空间将语义关系编码为近似线性方向:"王室"方向大致为 $v_{\text{king}} - v_{\text{man}}$,将其加到 v_{\text{woman}} 上就会落在 v_{\text{queen}} 附近。这与第 01 章的线性代数相关联:语义关系就是向量平移。

  • GloVeGlobal Vectors for Word RepresentationPennington et al., 2014)采用不同的方法。它不是一次一个地从局部上下文窗口学习,而是构建一个全局的词共现矩阵 $X$,其中 X_{ij} 统计在整个语料库中词 j 出现在词 i 上下文中的次数。然后模型学习嵌入,使其点积近似于对数共现次数:

w_i^T \tilde{w}_j + b_i + \tilde{b}_j = \log X_{ij}
  • 损失函数通过一个截断函数 f(X_{ij}) 对每一对加权,防止非常频繁的共现主导训练:
\mathcal{L} = \sum_{i,j=1}^{V} f(X_{ij}) \left(w_i^T \tilde{w}_j + b_i + \tilde{b}_j - \log X_{ij}\right)^2
  • GloVe 结合了全局矩阵分解(如潜在语义分析)和 Word2Vec 的局部上下文学习的优点。在实践中,GloVe 和 Word2Vec 生成的嵌入质量相近。

  • FastTextBojanowski et al., 2017)扩展了 skip-gram,将每个词表示为一组字符 n-gram 的集合。对于 $n = 3$,词 "where" 变成:"<wh"、"whe"、"her"、"ere"、"re>",加上完整词标记 "<where>"。该词的嵌入是其所有 n-gram 嵌入之和。

  • 这有一个关键优势:FastText 能够为训练中从未见过的词生成嵌入。词 "whereabouts" 与 "where" 共享 n-gram,因此即使 "whereabouts" 从未出现在训练数据中,其嵌入也是合理的。这对于形态丰富的语言(文件 01)尤为有用,因为这些语言中的词有许多屈折形式。

  • 嵌入评估通常使用两类基准测试。类比任务测试 v_a - v_b + v_c \approx v_d 是否成立(例如,"Paris" - "France" + "Italy" \approx "Rome")。相似性基准将词对之间的余弦相似度(第 01 章)与人工判断进行比较。常见的数据集包括 WordSim-353、SimLex-999 和 Google 类比测试集。一个实用注意事项:在类比任务上表现出色的嵌入不一定最适合下游任务,如情感分类。最好的评估往往是任务本身。

  • 在第 06 章中,我们介绍了 RNN、LSTM 和 GRU 作为处理序列数据的架构。这里我们重点讨论它们如何具体应用于语言任务。

  • 语言模型 RNN 每次读取一个词元,并在每一步预测下一个词元。隐藏状态 h_t 将整个历史序列 w_1, \ldots, w_t 压缩为一个固定大小的向量,线性层加 softmax 将 h_t 映射到词汇表上的分布。训练使用与真实下一词元的交叉熵损失,这等价于最小化困惑度(文件 02)。关键局限在于:固定大小的隐藏状态必须编码关于历史的所有信息,早期词元的信息会逐渐被覆盖。

  • 双向 RNN 从两个方向处理序列:一个 RNN 从左到右读取,另一个从右到左读取。在每个位置 $t$,前向隐藏状态 \overrightarrow{h}_t 和后向隐藏状态 \overleftarrow{h}_t 被拼接起来,形成上下文感知的表示 $h_t = [\overrightarrow{h}_t ; \overleftarrow{h}_t]$。这使模型能够同时访问过去和未来的上下文,对于词性标注和命名实体识别(文件 02)等任务非常有效,因为这些任务中一个词的标签依赖于其前后的词。双向 RNN 不能用于语言建模,因为在预测未来词元时不能窥视它们。

双向 RNN:前向 RNN 从左到右读取产生隐藏状态,后向 RNN 从右到左读取,每个位置的输出拼接在一起

  • 深层堆叠 RNN 将多个 RNN 层叠放在一起。第 l 层所有时间步的隐藏状态成为第 l+1 层的输入序列。堆叠 2-4 层通常能通过构建层次化表示来提升性能,类似于深层 CNN 构建特征层次结构(第 06 章)。超过 4 层时,梯度消失和过拟合会成为问题,除非在层之间添加残差连接。

  • 序列到序列(seq2seq架构(Sutskever et al., 2014)将可变长度的输入序列映射到可变长度的输出序列。它由一个编码器 RNN(读取输入并将其压缩为上下文向量,即最终的隐藏状态)和一个解码器 RNN(基于该上下文向量逐步生成输出)组成。

Seq2seq 编码器-解码器:编码器 RNN 从左到右读取输入词元,最终隐藏状态作为解码器 RNN 的初始状态,解码器自回归地生成输出词元

  • Seq2seq 是机器翻译的突破性架构。编码器读取法语句子,解码器生成英文翻译。解码器从一个特殊的序列起始词元开始,自回归地生成词元,直到产生序列结束词元。一个实用的技巧:反转输入序列(输入 "chat le" 而不是 "le chat")可以改善结果,因为这使得第一个输入词在计算图中更靠近第一个输出词,缩短了梯度路径。

  • 瓶颈问题:整个输入必须被压缩到一个固定大小的向量中。对于长句子,这个向量无法捕捉所有信息,性能会下降。这推动了注意力机制的发展。

  • 第 06 章介绍了现代的点积注意力 Q、K、V 形式。NLP 中最早的注意力机制以不同的方式提出,作为编码器和解码器状态之间的对齐模型。

  • Bahdanau 注意力(加性注意力,Bahdanau et al., 2015)使用一个可学习的前馈网络计算解码器隐藏状态 s_t 与每个编码器隐藏状态 h_i 之间的对齐分数:

e_{ti} = v^T \tanh(W_s s_{t-1} + W_h h_i)
  • 分数通过 softmax 归一化为注意力权重,上下文向量是编码器状态的加权和:
\alpha_{ti} = \frac{\exp(e_{ti})}{\sum_j \exp(e_{tj})}, \quad c_t = \sum_i \alpha_{ti} h_i
  • 然后解码器同时使用 s_{t-1}c_t 来生成下一个输出。关键洞察:不是为整个句子使用一个固定的上下文向量,每个解码步骤获得编码器状态的不同加权组合,使模型能够"回顾"输入的相关部分。

  • Luong 注意力(乘性注意力,Luong et al., 2015)简化了分数计算。点积变体使用 $e_{ti} = s_t^T h_i$。通用变体使用 $e_{ti} = s_t^T W h_i$。这些比 Bahdanau 的加性分数更快,因为它们使用矩阵乘法而非前馈网络。Luong 注意力还从当前解码器状态 $s_t$(而非 $s_{t-1}$)计算上下文向量,这使得它能获取更多信息,但计算方式略有不同。

源句子与其翻译之间的注意力对齐热力图,显示每个目标词关注哪些源词,较亮的单元格表示更高的注意力权重

  • 注意力权重通常可视化为热力图,显示解码器在生成每个输出词元时关注哪些输入词元。在翻译中,这些热力图大致勾勒出源语言和目标语言之间的词对齐关系,对角模式会被重排序打破(例如,形容词-名词顺序在法语和英语中有所不同)。

  • 推理时,解码器每一步必须选择一个词元。贪心解码在每个位置选择概率最高的词元,但这可能导致次优序列:一个局部好的选择可能迫使模型进入全局不佳的句子。束搜索在每一步维护分数最高的 k 个(束宽)部分序列,对每个序列扩展所有可能的下一词元,并保留总体最好的 k 个。

  • 当束宽 k = 1 时,束搜索退化为贪心解码。典型值为 k = 4 到 $k = 10$。更大的束能找到更好的序列,但速度会成比例降低。束搜索还需要长度归一化,以避免偏向较短的序列(因为较短的序列乘法项更少,自然具有更高的总概率)。归一化后的分数为:

\text{score}(y) = \frac{1}{|y|^\alpha} \sum_{t=1}^{|y|} \log P(y_t \mid y_{<t})
  • 其中 |y| 是序列长度,$\alpha$(通常为 0.6-0.7)控制长度惩罚的强度。当 \alpha = 0 时,没有长度归一化。当 \alpha = 1 时,分数是每个词元的对数概率(几何平均)。中间值在倾向于简洁输出和不过早截断之间取得平衡。

  • 虽然 RNN 顺序处理文本,但 1D CNN 通过在词元序列上滑动滤波器来并行处理文本。每个滤波器检测一个局部模式(n-gram 特征)。

  • TextCNN(Kim, 2014)对输入的嵌入矩阵应用多个不同宽度(例如 3、4、5 个词元)的一维卷积滤波器。每个滤波器生成一个特征图,时序最大池化从每个特征图中取单一最大值,捕获该模式是否在文本中的任何位置被检测到,而不考虑位置。所有滤波器的池化特征被拼接后传递给分类器。

TextCNN 架构:输入嵌入通过宽度为 3、4、5 的并行卷积滤波器,每个滤波器后接时序最大池化,然后拼接并馈送到全连接分类器

  • TextCNN 速度快,对于情感分析等文本分类任务效果出奇地好。它能捕获局部 n-gram 模式,但无法建模长距离依赖:宽度为 5 的滤波器只能看到 5 个连续的词元。膨胀因果卷积通过在滤波器元素之间插入间隙(膨胀)来解决这个问题。堆叠膨胀率呈指数增长(1、2、4、8、...)的层,可以在不增加参数的情况下指数级地扩大感受野,使模型能够捕获跨越数百个词元的依赖关系。

  • 到目前为止讨论的所有嵌入(Word2Vec、GloVe、FastText)针对每个词类型生成单一向量,与上下文无关。"Bank"无论是指金融机构还是河岸,都得到相同的嵌入。这是一个根本性的局限,而上下文嵌入解决了这一问题。

  • ELMoEmbeddings from Language ModelsPeters et al., 2018)通过在输入文本上运行一个深层双向 LSTM 语言模型来生成上下文相关的词表示。前向 LSTM 在每个位置预测下一个词;一个独立的后向 LSTM 预测前一个词。两者都在大规模语料库上作为语言模型进行训练。

  • 在每个位置 $k$,ELMo 使用任务特定的学习权重组合所有 L 层的隐藏状态:

\text{ELMo}_k = \gamma \sum_{j=0}^{L} s_j \, h_{k,j}
  • 这里 h_{k,j} 是位置 kj 的隐藏状态(层 0 是原始词嵌入),s_j 是 softmax 归一化的标量权重,\gamma 是任务特定的缩放因子。不同层捕获不同信息:较低层捕获句法(词性标注、词形态),较高层捕获语义(词义、语义角色)。通过使用学习到的权重混合所有层,ELMo 嵌入能够适应多样化的下游任务。

  • ELMo 标志着预训练然后微调范式的开始:在海量无标注文本上训练大型语言模型,然后将其表示用于下游任务。ELMo 具体使用预训练的表示作为固定的或轻度微调的特征,与任务特定的输入拼接在一起。BERT 和 GPT(文件 04)通过端到端地微调整个模型进一步推进了这一范式,事实证明这要有效得多。

  • 从 Word2Vec 到 ELMo 的发展过程展示了 NLP 中一个反复出现的主题:从静态表示到动态表示,从局部上下文到全局上下文,从浅层模型到深层模型。每一步都以计算成本换取更丰富的表示。Transformer(文件 04)通过用注意力完全取代循环,实现了深层上下文化和并行计算,完成了这一演进。

编程任务(使用 CoLab 或 notebook

  1. 从头实现带负采样的 Word2Vec skip-gram。在小型语料库上训练,并使用 PCA 可视化学习到的嵌入。
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# 小型语料库
corpus = """the king ruled the kingdom . the queen ruled the kingdom .
the prince is the son of the king . the princess is the daughter of the queen .
a man worked in the castle . a woman worked in the castle .
the king and queen lived in the castle . the prince and princess played outside .""".lower().split()

vocab = sorted(set(corpus))
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}
V = len(vocab)

# 生成 skip-gram 对,窗口大小为 2
window = 2
pairs = []
for i, word in enumerate(corpus):
    for j in range(max(0, i - window), min(len(corpus), i + window + 1)):
        if i != j:
            pairs.append((word2idx[word], word2idx[corpus[j]]))

pairs = jnp.array(pairs)
print(f"词汇表大小: {V} 个词, 训练样本数: {len(pairs)}")

# 模型参数
embed_dim = 16
key = jax.random.PRNGKey(42)
k1, k2 = jax.random.split(key)
W_in = jax.random.normal(k1, (V, embed_dim)) * 0.1    # 输入嵌入
W_out = jax.random.normal(k2, (V, embed_dim)) * 0.1   # 输出嵌入

# 单个样本对的负采样损失
def neg_sampling_loss(W_in, W_out, target, context, neg_ids):
    v_in = W_in[target]      # (embed_dim,)
    v_out = W_out[context]   # (embed_dim,)
    v_neg = W_out[neg_ids]   # (k, embed_dim)

    pos_loss = -jax.nn.log_sigmoid(jnp.dot(v_in, v_out))
    neg_loss = -jnp.sum(jax.nn.log_sigmoid(-v_neg @ v_in))
    return pos_loss + neg_loss

# 训练循环
num_neg = 5
lr = 0.05

@jax.jit
def train_step(W_in, W_out, target, context, neg_ids):
    loss, (g_in, g_out) = jax.value_and_grad(neg_sampling_loss, argnums=(0, 1))(
        W_in, W_out, target, context, neg_ids)
    return loss, W_in - lr * g_in, W_out - lr * g_out

key = jax.random.PRNGKey(0)
for epoch in range(50):
    total_loss = 0.0
    for i in range(len(pairs)):
        key, subkey = jax.random.split(key)
        neg_ids = jax.random.randint(subkey, (num_neg,), 0, V)
        loss, W_in, W_out = train_step(W_in, W_out, pairs[i, 0], pairs[i, 1], neg_ids)
        total_loss += loss
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: avg loss = {total_loss / len(pairs):.4f}")

# 使用 PCA 可视化(第 01 章)
embeddings = W_in
mean = embeddings.mean(axis=0)
centered = embeddings - mean
U, S, Vt = jnp.linalg.svd(centered, full_matrices=False)
coords = centered @ Vt[:2].T  # 投影到前两个主成分

plt.figure(figsize=(10, 8))
for i, word in idx2word.items():
    plt.scatter(coords[i, 0], coords[i, 1], c='#3498db', s=40)
    plt.annotate(word, (coords[i, 0] + 0.02, coords[i, 1] + 0.02), fontsize=9)
plt.title("Word2Vec Skip-gram 嵌入(PCA 投影)")
plt.grid(alpha=0.3); plt.show()
  1. 构建一个字符级 RNN 语言模型,从一小段训练文本中学习生成文本。
import jax
import jax.numpy as jnp

# 小型训练文本
text = "to be or not to be that is the question "
chars = sorted(set(text))
char2idx = {c: i for i, c in enumerate(chars)}
idx2char = {i: c for c, i in char2idx.items()}
V = len(chars)
data = jnp.array([char2idx[c] for c in text])

# RNN 参数
hidden_dim = 64
key = jax.random.PRNGKey(0)
k1, k2, k3, k4, k5 = jax.random.split(key, 5)

params = {
    'Wx': jax.random.normal(k1, (V, hidden_dim)) * 0.1,
    'Wh': jax.random.normal(k2, (hidden_dim, hidden_dim)) * 0.05,
    'bh': jnp.zeros(hidden_dim),
    'Wy': jax.random.normal(k3, (hidden_dim, V)) * 0.1,
    'by': jnp.zeros(V),
}

def rnn_step(params, h, x_idx):
    x = jnp.eye(V)[x_idx]  # one-hot 编码
    h = jnp.tanh(x @ params['Wx'] + h @ params['Wh'] + params['bh'])
    logits = h @ params['Wy'] + params['by']
    return h, logits

def loss_fn(params, inputs, targets):
    h = jnp.zeros(hidden_dim)
    total_loss = 0.0
    for t in range(len(inputs)):
        h, logits = rnn_step(params, h, inputs[t])
        log_probs = jax.nn.log_softmax(logits)
        total_loss -= log_probs[targets[t]]
    return total_loss / len(inputs)

grad_fn = jax.jit(jax.grad(loss_fn))

# 训练
inputs = data[:-1]
targets = data[1:]
lr = 0.01

for step in range(500):
    grads = grad_fn(params, inputs, targets)
    params = {k: params[k] - lr * grads[k] for k in params}
    if (step + 1) % 100 == 0:
        l = loss_fn(params, inputs, targets)
        print(f"Step {step+1}: loss = {l:.4f}")

# 生成文本
def generate(params, seed_char, length=60):
    h = jnp.zeros(hidden_dim)
    idx = char2idx[seed_char]
    result = [seed_char]
    key = jax.random.PRNGKey(42)
    for _ in range(length):
        h, logits = rnn_step(params, h, idx)
        key, subkey = jax.random.split(key)
        idx = jax.random.categorical(subkey, logits)
        result.append(idx2char[int(idx)])
    return ''.join(result)

print(f"\n生成文本: {generate(params, 't')}")
  1. 实现一个带 Bahdanau 注意力的简易 seq2seq 模型,用于序列反转。可视化注意力对齐矩阵。
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# 任务:反转数字序列(例如,[3, 1, 4] -> [4, 1, 3]
vocab_size = 10  # 数字 0-9
SOS, EOS = 10, 11  # 特殊词元
total_vocab = 12
embed_dim, hidden_dim = 16, 32
max_len = 5

key = jax.random.PRNGKey(42)
keys = jax.random.split(key, 8)

params = {
    'embed': jax.random.normal(keys[0], (total_vocab, embed_dim)) * 0.1,
    'enc_Wx': jax.random.normal(keys[1], (embed_dim, hidden_dim)) * 0.1,
    'enc_Wh': jax.random.normal(keys[2], (hidden_dim, hidden_dim)) * 0.05,
    'dec_Wx': jax.random.normal(keys[3], (embed_dim, hidden_dim)) * 0.1,
    'dec_Wh': jax.random.normal(keys[4], (hidden_dim, hidden_dim)) * 0.05,
    # Bahdanau 注意力
    'Ws': jax.random.normal(keys[5], (hidden_dim, hidden_dim)) * 0.1,
    'Wh_att': jax.random.normal(keys[6], (hidden_dim, hidden_dim)) * 0.1,
    'v_att': jax.random.normal(keys[7], (hidden_dim,)) * 0.1,
    # 输出投影(从隐藏状态+上下文到词汇表)
    'Wo': jax.random.normal(keys[0], (hidden_dim * 2, total_vocab)) * 0.1,
}

def encode(params, seq):
    """编码输入序列,返回所有隐藏状态。"""
    h = jnp.zeros(hidden_dim)
    states = []
    for t in range(len(seq)):
        x = params['embed'][seq[t]]
        h = jnp.tanh(x @ params['enc_Wx'] + h @ params['enc_Wh'])
        states.append(h)
    return jnp.stack(states), h

def bahdanau_attention(params, dec_state, enc_states):
    """计算 Bahdanau 注意力权重和上下文向量。"""
    scores = jnp.tanh(enc_states @ params['Wh_att'] + dec_state @ params['Ws'])
    e = scores @ params['v_att']  # (src_len,)
    alpha = jax.nn.softmax(e)
    context = alpha @ enc_states
    return context, alpha

def decode_step(params, dec_h, prev_token, enc_states):
    x = params['embed'][prev_token]
    dec_h = jnp.tanh(x @ params['dec_Wx'] + dec_h @ params['dec_Wh'])
    context, alpha = bahdanau_attention(params, dec_h, enc_states)
    combined = jnp.concatenate([dec_h, context])
    logits = combined @ params['Wo']
    return dec_h, logits, alpha

def seq2seq_loss(params, src, tgt):
    enc_states, enc_final = encode(params, src)
    dec_h = enc_final
    loss = 0.0
    prev_token = SOS
    for t in range(len(tgt)):
        dec_h, logits, _ = decode_step(params, dec_h, prev_token, enc_states)
        log_probs = jax.nn.log_softmax(logits)
        loss -= log_probs[tgt[t]]
        prev_token = tgt[t]
    return loss / len(tgt)

# 生成训练数据:反转序列
key = jax.random.PRNGKey(0)
train_srcs, train_tgts = [], []
for _ in range(200):
    key, subkey = jax.random.split(key)
    length = jax.random.randint(subkey, (), 3, max_len + 1)
    key, subkey = jax.random.split(key)
    seq = jax.random.randint(subkey, (int(length),), 0, vocab_size)
    train_srcs.append(seq)
    train_tgts.append(seq[::-1])  # 反转

# 训练
grad_fn = jax.grad(seq2seq_loss)
lr = 0.01

for epoch in range(100):
    total_loss = 0.0
    for src, tgt in zip(train_srcs, train_tgts):
        grads = grad_fn(params, src, tgt)
        params = {k: params[k] - lr * grads[k] for k in params}
        total_loss += seq2seq_loss(params, src, tgt)
    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1}: avg loss = {total_loss / len(train_srcs):.4f}")

# 可视化一个示例的注意力
test_src = jnp.array([3, 1, 4, 1, 5])
test_tgt = test_src[::-1]

enc_states, enc_final = encode(params, test_src)
dec_h = enc_final
attentions = []
prev_token = SOS
for t in range(len(test_tgt)):
    dec_h, logits, alpha = decode_step(params, dec_h, prev_token, enc_states)
    attentions.append(alpha)
    prev_token = test_tgt[t]

att_matrix = jnp.stack(attentions)
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(att_matrix, cmap='Blues')
ax.set_xlabel("源位置"); ax.set_ylabel("目标位置")
src_labels = [str(int(x)) for x in test_src]
tgt_labels = [str(int(x)) for x in test_tgt]
ax.set_xticks(range(len(src_labels))); ax.set_xticklabels(src_labels)
ax.set_yticks(range(len(tgt_labels))); ax.set_yticklabels(tgt_labels)
for i in range(len(tgt_labels)):
    for j in range(len(src_labels)):
        ax.text(j, i, f"{att_matrix[i,j]:.2f}", ha='center', va='center', fontsize=9)
ax.set_title("Bahdanau 注意力对齐(序列反转)")
plt.colorbar(im); plt.tight_layout(); plt.show()