Files

5746 lines
178 KiB
HTML
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
<!doctype html>
<html lang="zh" class="no-js">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<meta name="description" content="一本开源的直觉优先教科书,从零开始覆盖数学、计算机科学和人工智能(中文翻译版)。">
<meta name="author" content="Henry Ndubuaku (flykhan 译)">
<link rel="canonical" href="https://flykhan.github.io/maths-cs-ai-compendium-zh/chapter%2007%3A%20computational%20linguistics/04.%20transformers%20and%20language%20models/">
<link rel="prev" href="../03.%20embeddings%20and%20sequence%20models/">
<link rel="next" href="../05.%20advanced%20text%20generation/">
<link rel="icon" href="../../assets/images/favicon.png">
<meta name="generator" content="mkdocs-1.6.1, mkdocs-material-9.7.6">
<title>Transformer 与语言模型 - 数学、计算机科学与 AI 百科全书</title>
<link rel="stylesheet" href="../../assets/stylesheets/main.484c7ddc.min.css">
<link rel="stylesheet" href="../../assets/stylesheets/palette.ab4e12ef.min.css">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300i,400,400i,700,700i%7CRoboto+Mono:400,400i,700,700i&display=fallback">
<style>:root{--md-text-font:"Roboto";--md-code-font:"Roboto Mono"}</style>
<script>__md_scope=new URL("../..",location),__md_hash=e=>[...e].reduce(((e,_)=>(e<<5)-e+_.charCodeAt(0)),0),__md_get=(e,_=localStorage,t=__md_scope)=>JSON.parse(_.getItem(t.pathname+"."+e)),__md_set=(e,_,t=localStorage,a=__md_scope)=>{try{t.setItem(a.pathname+"."+e,JSON.stringify(_))}catch(e){}}</script>
</head>
<body dir="ltr" data-md-color-scheme="default" data-md-color-primary="slate" data-md-color-accent="indigo">
<input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer" autocomplete="off">
<input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search" autocomplete="off">
<label class="md-overlay" for="__drawer"></label>
<div data-md-component="skip">
<a href="#transformer" class="md-skip">
跳转至
</a>
</div>
<div data-md-component="announce">
</div>
<header class="md-header" data-md-component="header">
<nav class="md-header__inner md-grid" aria-label="页眉">
<a href="../.." title="数学、计算机科学与 AI 百科全书" class="md-header__button md-logo" aria-label="数学、计算机科学与 AI 百科全书" data-md-component="logo">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a3 3 0 0 0 3-3 3 3 0 0 0-3-3 3 3 0 0 0-3 3 3 3 0 0 0 3 3m0 3.54C9.64 9.35 6.5 8 3 8v11c3.5 0 6.64 1.35 9 3.54 2.36-2.19 5.5-3.54 9-3.54V8c-3.5 0-6.64 1.35-9 3.54"/></svg>
</a>
<label class="md-header__button md-icon" for="__drawer">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3 6h18v2H3zm0 5h18v2H3zm0 5h18v2H3z"/></svg>
</label>
<div class="md-header__title" data-md-component="header-title">
<div class="md-header__ellipsis">
<div class="md-header__topic">
<span class="md-ellipsis">
数学、计算机科学与 AI 百科全书
</span>
</div>
<div class="md-header__topic" data-md-component="header-topic">
<span class="md-ellipsis">
Transformer 与语言模型
</span>
</div>
</div>
</div>
<form class="md-header__option" data-md-component="palette">
<input class="md-option" data-md-color-media="" data-md-color-scheme="default" data-md-color-primary="slate" data-md-color-accent="indigo" aria-label="切换到深色模式" type="radio" name="__palette" id="__palette_0">
<label class="md-header__button md-icon" title="切换到深色模式" for="__palette_1" hidden>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a4 4 0 0 0-4 4 4 4 0 0 0 4 4 4 4 0 0 0 4-4 4 4 0 0 0-4-4m0 10a6 6 0 0 1-6-6 6 6 0 0 1 6-6 6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
</label>
<input class="md-option" data-md-color-media="" data-md-color-scheme="slate" data-md-color-primary="slate" data-md-color-accent="indigo" aria-label="切换到浅色模式" type="radio" name="__palette" id="__palette_1">
<label class="md-header__button md-icon" title="切换到浅色模式" for="__palette_0" hidden>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 18c-.89 0-1.74-.2-2.5-.55C11.56 16.5 13 14.42 13 12s-1.44-4.5-3.5-5.45C10.26 6.2 11.11 6 12 6a6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
</label>
</form>
<script>var palette=__md_get("__palette");if(palette&&palette.color){if("(prefers-color-scheme)"===palette.color.media){var media=matchMedia("(prefers-color-scheme: light)"),input=document.querySelector(media.matches?"[data-md-color-media='(prefers-color-scheme: light)']":"[data-md-color-media='(prefers-color-scheme: dark)']");palette.color.media=input.getAttribute("data-md-color-media"),palette.color.scheme=input.getAttribute("data-md-color-scheme"),palette.color.primary=input.getAttribute("data-md-color-primary"),palette.color.accent=input.getAttribute("data-md-color-accent")}for(var[key,value]of Object.entries(palette.color))document.body.setAttribute("data-md-color-"+key,value)}</script>
<label class="md-header__button md-icon" for="__search">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
</label>
<div class="md-search" data-md-component="search" role="dialog">
<label class="md-search__overlay" for="__search"></label>
<div class="md-search__inner" role="search">
<form class="md-search__form" name="search">
<input type="text" class="md-search__input" name="query" aria-label="搜索" placeholder="搜索" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="search-query" required>
<label class="md-search__icon md-icon" for="__search">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
</label>
<nav class="md-search__options" aria-label="查找">
<button type="reset" class="md-search__icon md-icon" title="清空当前内容" aria-label="清空当前内容" tabindex="-1">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z"/></svg>
</button>
</nav>
<div class="md-search__suggest" data-md-component="search-suggest"></div>
</form>
<div class="md-search__output">
<div class="md-search__scrollwrap" tabindex="0" data-md-scrollfix>
<div class="md-search-result" data-md-component="search-result">
<div class="md-search-result__meta">
正在初始化搜索引擎
</div>
<ol class="md-search-result__list" role="presentation"></ol>
</div>
</div>
</div>
</div>
</div>
<div class="md-header__source">
<a href="https://github.com/flykhan/maths-cs-ai-compendium-zh" title="前往仓库" class="md-source" data-md-component="source">
<div class="md-source__icon md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
</div>
<div class="md-source__repository">
flykhan/maths-cs-ai-compendium-zh
</div>
</a>
</div>
</nav>
</header>
<div class="md-container" data-md-component="container">
<nav class="md-tabs" aria-label="标签" data-md-component="tabs">
<div class="md-grid">
<ul class="md-tabs__list">
<li class="md-tabs__item">
<a href="../.." class="md-tabs__link">
首页
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2001%3A%20vectors/01.%20vector%20spaces/" class="md-tabs__link">
向量
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2002%3A%20matrices/01.%20matrix%20properties/" class="md-tabs__link">
矩阵
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2003%3A%20calculus/01.%20differential%20calculus/" class="md-tabs__link">
微积分
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2004%3A%20statistics/01.%20fundamentals/" class="md-tabs__link">
统计学
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2005%3A%20probability/01.%20counting/" class="md-tabs__link">
概率论
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2006%3A%20machine%20learning/01.%20classical%20machine%20learning/" class="md-tabs__link">
机器学习
</a>
</li>
<li class="md-tabs__item md-tabs__item--active">
<a href="../01.%20linguistic%20foundations/" class="md-tabs__link">
计算语言学
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2008%3A%20computer%20vision/01.%20image%20fundamentals/" class="md-tabs__link">
计算机视觉
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2009%3A%20audio%20and%20speech/01.%20digital%20signal%20processing/" class="md-tabs__link">
音频与语音
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/" class="md-tabs__link">
多模态学习
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2011%3A%20autonomous%20systems/01.%20perception/" class="md-tabs__link">
自主系统
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2012%3A%20graph%20neural%20networks/01.%20geometric%20deep%20learning/" class="md-tabs__link">
图神经网络
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/" class="md-tabs__link">
计算机与操作系统
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/" class="md-tabs__link">
数据结构与算法
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/" class="md-tabs__link">
生产级软件工程
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/" class="md-tabs__link">
SIMD 与 GPU 编程
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2017%3A%20AI%20inference/01.%20quantisation/" class="md-tabs__link">
AI 推理
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/" class="md-tabs__link">
ML 系统设计
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2019%3A%20applied%20AI/01.%20AI%20for%20finance/" class="md-tabs__link">
应用 AI
</a>
</li>
<li class="md-tabs__item">
<a href="../../chapter%2020%3A%20bleeding%20edge%20AI/01.%20quantum%20machine%20learning/" class="md-tabs__link">
前沿 AI
</a>
</li>
</ul>
</div>
</nav>
<main class="md-main" data-md-component="main">
<div class="md-main__inner md-grid">
<div class="md-sidebar md-sidebar--primary" data-md-component="sidebar" data-md-type="navigation" >
<div class="md-sidebar__scrollwrap">
<div class="md-sidebar__inner">
<nav class="md-nav md-nav--primary md-nav--lifted" aria-label="导航栏" data-md-level="0">
<label class="md-nav__title" for="__drawer">
<a href="../.." title="数学、计算机科学与 AI 百科全书" class="md-nav__button md-logo" aria-label="数学、计算机科学与 AI 百科全书" data-md-component="logo">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a3 3 0 0 0 3-3 3 3 0 0 0-3-3 3 3 0 0 0-3 3 3 3 0 0 0 3 3m0 3.54C9.64 9.35 6.5 8 3 8v11c3.5 0 6.64 1.35 9 3.54 2.36-2.19 5.5-3.54 9-3.54V8c-3.5 0-6.64 1.35-9 3.54"/></svg>
</a>
数学、计算机科学与 AI 百科全书
</label>
<div class="md-nav__source">
<a href="https://github.com/flykhan/maths-cs-ai-compendium-zh" title="前往仓库" class="md-source" data-md-component="source">
<div class="md-source__icon md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
</div>
<div class="md-source__repository">
flykhan/maths-cs-ai-compendium-zh
</div>
</a>
</div>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../.." class="md-nav__link">
<span class="md-ellipsis">
首页
</span>
</a>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_2" >
<label class="md-nav__link" for="__nav_2" id="__nav_2_label" tabindex="0">
<span class="md-ellipsis">
向量
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_2_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_2">
<span class="md-nav__icon md-icon"></span>
向量
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2001%3A%20vectors/01.%20vector%20spaces/" class="md-nav__link">
<span class="md-ellipsis">
向量空间
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2001%3A%20vectors/02.%20vector%20properties/" class="md-nav__link">
<span class="md-ellipsis">
向量性质
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2001%3A%20vectors/03.%20norms%20and%20metrics/" class="md-nav__link">
<span class="md-ellipsis">
范数与度量
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2001%3A%20vectors/04.%20products/" class="md-nav__link">
<span class="md-ellipsis">
向量积
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2001%3A%20vectors/05.%20basis%20and%20duality/" class="md-nav__link">
<span class="md-ellipsis">
基与对偶性
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_3" >
<label class="md-nav__link" for="__nav_3" id="__nav_3_label" tabindex="0">
<span class="md-ellipsis">
矩阵
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_3_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_3">
<span class="md-nav__icon md-icon"></span>
矩阵
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2002%3A%20matrices/01.%20matrix%20properties/" class="md-nav__link">
<span class="md-ellipsis">
矩阵性质
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2002%3A%20matrices/02.%20matrix%20types/" class="md-nav__link">
<span class="md-ellipsis">
矩阵类型
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2002%3A%20matrices/03.%20operations/" class="md-nav__link">
<span class="md-ellipsis">
矩阵运算
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2002%3A%20matrices/04.%20linear%20transformations/" class="md-nav__link">
<span class="md-ellipsis">
线性变换
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2002%3A%20matrices/05.%20decompositions/" class="md-nav__link">
<span class="md-ellipsis">
矩阵分解
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_4" >
<label class="md-nav__link" for="__nav_4" id="__nav_4_label" tabindex="0">
<span class="md-ellipsis">
微积分
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_4_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_4">
<span class="md-nav__icon md-icon"></span>
微积分
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2003%3A%20calculus/01.%20differential%20calculus/" class="md-nav__link">
<span class="md-ellipsis">
微分
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2003%3A%20calculus/02.%20integral%20calculus/" class="md-nav__link">
<span class="md-ellipsis">
积分
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2003%3A%20calculus/03.%20multivariate%20calculus/" class="md-nav__link">
<span class="md-ellipsis">
多元微积分
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2003%3A%20calculus/04.%20function%20approximation/" class="md-nav__link">
<span class="md-ellipsis">
函数逼近
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2003%3A%20calculus/05.%20optimisation/" class="md-nav__link">
<span class="md-ellipsis">
优化
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_5" >
<label class="md-nav__link" for="__nav_5" id="__nav_5_label" tabindex="0">
<span class="md-ellipsis">
统计学
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_5_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_5">
<span class="md-nav__icon md-icon"></span>
统计学
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2004%3A%20statistics/01.%20fundamentals/" class="md-nav__link">
<span class="md-ellipsis">
基础
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2004%3A%20statistics/02.%20measures/" class="md-nav__link">
<span class="md-ellipsis">
统计量
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2004%3A%20statistics/03.%20sampling/" class="md-nav__link">
<span class="md-ellipsis">
抽样
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2004%3A%20statistics/04.%20hypothesis%20testing/" class="md-nav__link">
<span class="md-ellipsis">
假设检验
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2004%3A%20statistics/05.%20inference/" class="md-nav__link">
<span class="md-ellipsis">
推断
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_6" >
<label class="md-nav__link" for="__nav_6" id="__nav_6_label" tabindex="0">
<span class="md-ellipsis">
概率论
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_6_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_6">
<span class="md-nav__icon md-icon"></span>
概率论
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2005%3A%20probability/01.%20counting/" class="md-nav__link">
<span class="md-ellipsis">
计数
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2005%3A%20probability/02.%20probability%20concepts/" class="md-nav__link">
<span class="md-ellipsis">
概率概念
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2005%3A%20probability/03.%20distributions/" class="md-nav__link">
<span class="md-ellipsis">
分布
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2005%3A%20probability/04.%20bayesian/" class="md-nav__link">
<span class="md-ellipsis">
贝叶斯
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2005%3A%20probability/05.%20information%20theory/" class="md-nav__link">
<span class="md-ellipsis">
信息论
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_7" >
<label class="md-nav__link" for="__nav_7" id="__nav_7_label" tabindex="0">
<span class="md-ellipsis">
机器学习
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_7_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_7">
<span class="md-nav__icon md-icon"></span>
机器学习
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2006%3A%20machine%20learning/01.%20classical%20machine%20learning/" class="md-nav__link">
<span class="md-ellipsis">
经典机器学习
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2006%3A%20machine%20learning/02.%20gradient%20machine%20learning/" class="md-nav__link">
<span class="md-ellipsis">
梯度机器学习
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2006%3A%20machine%20learning/03.%20deep%20learning/" class="md-nav__link">
<span class="md-ellipsis">
深度学习
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2006%3A%20machine%20learning/04.%20reinforcement%20learning/" class="md-nav__link">
<span class="md-ellipsis">
强化学习
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2006%3A%20machine%20learning/05.%20distributed%20deep%20learning/" class="md-nav__link">
<span class="md-ellipsis">
分布式深度学习
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--active md-nav__item--section md-nav__item--nested">
<input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_8" checked>
<label class="md-nav__link" for="__nav_8" id="__nav_8_label" tabindex="">
<span class="md-ellipsis">
计算语言学
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_8_label" aria-expanded="true">
<label class="md-nav__title" for="__nav_8">
<span class="md-nav__icon md-icon"></span>
计算语言学
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../01.%20linguistic%20foundations/" class="md-nav__link">
<span class="md-ellipsis">
语言学基础
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../02.%20text%20processing%20and%20classic%20NLP/" class="md-nav__link">
<span class="md-ellipsis">
文本处理与经典 NLP
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../03.%20embeddings%20and%20sequence%20models/" class="md-nav__link">
<span class="md-ellipsis">
嵌入与序列模型
</span>
</a>
</li>
<li class="md-nav__item md-nav__item--active">
<input class="md-nav__toggle md-toggle" type="checkbox" id="__toc">
<label class="md-nav__link md-nav__link--active" for="__toc">
<span class="md-ellipsis">
Transformer 与语言模型
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<a href="./" class="md-nav__link md-nav__link--active">
<span class="md-ellipsis">
Transformer 与语言模型
</span>
</a>
<nav class="md-nav md-nav--secondary" aria-label="目录">
<label class="md-nav__title" for="__toc">
<span class="md-nav__icon md-icon"></span>
目录
</label>
<ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
<li class="md-nav__item">
<a href="#colab" class="md-nav__link">
<span class="md-ellipsis">
编程任务(使用CoLab或笔记本)
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item">
<a href="../05.%20advanced%20text%20generation/" class="md-nav__link">
<span class="md-ellipsis">
高级文本生成
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_9" >
<label class="md-nav__link" for="__nav_9" id="__nav_9_label" tabindex="0">
<span class="md-ellipsis">
计算机视觉
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_9_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_9">
<span class="md-nav__icon md-icon"></span>
计算机视觉
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2008%3A%20computer%20vision/01.%20image%20fundamentals/" class="md-nav__link">
<span class="md-ellipsis">
图像基础
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2008%3A%20computer%20vision/02.%20convolutional%20networks/" class="md-nav__link">
<span class="md-ellipsis">
卷积网络
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2008%3A%20computer%20vision/03.%20object%20detection%20and%20segmentation/" class="md-nav__link">
<span class="md-ellipsis">
目标检测与分割
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2008%3A%20computer%20vision/04.%20vision%20transformers%20and%20generation/" class="md-nav__link">
<span class="md-ellipsis">
ViT 与生成模型
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2008%3A%20computer%20vision/05.%20video%20and%203D%20vision/" class="md-nav__link">
<span class="md-ellipsis">
视频与 3D 视觉
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_10" >
<label class="md-nav__link" for="__nav_10" id="__nav_10_label" tabindex="0">
<span class="md-ellipsis">
音频与语音
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_10_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_10">
<span class="md-nav__icon md-icon"></span>
音频与语音
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2009%3A%20audio%20and%20speech/01.%20digital%20signal%20processing/" class="md-nav__link">
<span class="md-ellipsis">
数字信号处理
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2009%3A%20audio%20and%20speech/02.%20automatic%20speech%20recognition/" class="md-nav__link">
<span class="md-ellipsis">
自动语音识别
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2009%3A%20audio%20and%20speech/03.%20text%20to%20speech%20and%20voice/" class="md-nav__link">
<span class="md-ellipsis">
语音合成
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2009%3A%20audio%20and%20speech/04.%20speaker%20and%20audio%20analysis/" class="md-nav__link">
<span class="md-ellipsis">
说话人与音频分析
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2009%3A%20audio%20and%20speech/05.%20source%20separation%20and%20noise/" class="md-nav__link">
<span class="md-ellipsis">
源分离与降噪
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_11" >
<label class="md-nav__link" for="__nav_11" id="__nav_11_label" tabindex="0">
<span class="md-ellipsis">
多模态学习
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_11_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_11">
<span class="md-nav__icon md-icon"></span>
多模态学习
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/" class="md-nav__link">
<span class="md-ellipsis">
多模态表征
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/" class="md-nav__link">
<span class="md-ellipsis">
视觉语言模型
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/" class="md-nav__link">
<span class="md-ellipsis">
图像与视频 Token 化
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/" class="md-nav__link">
<span class="md-ellipsis">
跨模态生成
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2010%3A%20multimodal%20learning/05.%20unified%20multimodal%20architectures/" class="md-nav__link">
<span class="md-ellipsis">
统一多模态架构
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_12" >
<label class="md-nav__link" for="__nav_12" id="__nav_12_label" tabindex="0">
<span class="md-ellipsis">
自主系统
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_12_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_12">
<span class="md-nav__icon md-icon"></span>
自主系统
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2011%3A%20autonomous%20systems/01.%20perception/" class="md-nav__link">
<span class="md-ellipsis">
感知
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2011%3A%20autonomous%20systems/02.%20robot%20learning/" class="md-nav__link">
<span class="md-ellipsis">
机器人学习
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2011%3A%20autonomous%20systems/03.%20vision-language-action%20models/" class="md-nav__link">
<span class="md-ellipsis">
视觉-语言-动作模型
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2011%3A%20autonomous%20systems/04.%20self-driving/" class="md-nav__link">
<span class="md-ellipsis">
自动驾驶
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2011%3A%20autonomous%20systems/05.%20space%20and%20extreme%20robotics/" class="md-nav__link">
<span class="md-ellipsis">
太空与极端机器人
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_13" >
<label class="md-nav__link" for="__nav_13" id="__nav_13_label" tabindex="0">
<span class="md-ellipsis">
图神经网络
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_13_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_13">
<span class="md-nav__icon md-icon"></span>
图神经网络
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2012%3A%20graph%20neural%20networks/01.%20geometric%20deep%20learning/" class="md-nav__link">
<span class="md-ellipsis">
几何深度学习
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2012%3A%20graph%20neural%20networks/02.%20graph%20theory/" class="md-nav__link">
<span class="md-ellipsis">
图论
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2012%3A%20graph%20neural%20networks/03.%20graph%20neural%20networks/" class="md-nav__link">
<span class="md-ellipsis">
图神经网络
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2012%3A%20graph%20neural%20networks/04.%20graph%20attention%20networks/" class="md-nav__link">
<span class="md-ellipsis">
图注意力网络
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2012%3A%20graph%20neural%20networks/05.%203d%20graph%20networks/" class="md-nav__link">
<span class="md-ellipsis">
3D 图网络
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_14" >
<label class="md-nav__link" for="__nav_14" id="__nav_14_label" tabindex="0">
<span class="md-ellipsis">
计算机与操作系统
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_14_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_14">
<span class="md-nav__icon md-icon"></span>
计算机与操作系统
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/" class="md-nav__link">
<span class="md-ellipsis">
离散数学
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2013%3A%20computing%20and%20OS/02.%20computer%20architecture/" class="md-nav__link">
<span class="md-ellipsis">
计算机体系结构
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2013%3A%20computing%20and%20OS/03.%20operating%20systems/" class="md-nav__link">
<span class="md-ellipsis">
操作系统
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2013%3A%20computing%20and%20OS/04.%20concurrency%20and%20parallelism/" class="md-nav__link">
<span class="md-ellipsis">
并发与并行
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2013%3A%20computing%20and%20OS/05.%20programming%20languages/" class="md-nav__link">
<span class="md-ellipsis">
编程语言
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_15" >
<label class="md-nav__link" for="__nav_15" id="__nav_15_label" tabindex="0">
<span class="md-ellipsis">
数据结构与算法
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_15_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_15">
<span class="md-nav__icon md-icon"></span>
数据结构与算法
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/" class="md-nav__link">
<span class="md-ellipsis">
基础
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/" class="md-nav__link">
<span class="md-ellipsis">
数组与哈希
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/" class="md-nav__link">
<span class="md-ellipsis">
链表、栈与队列
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/" class="md-nav__link">
<span class="md-ellipsis">
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/" class="md-nav__link">
<span class="md-ellipsis">
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/" class="md-nav__link">
<span class="md-ellipsis">
排序与搜索
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_16" >
<label class="md-nav__link" for="__nav_16" id="__nav_16_label" tabindex="0">
<span class="md-ellipsis">
生产级软件工程
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_16_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_16">
<span class="md-nav__icon md-icon"></span>
生产级软件工程
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/" class="md-nav__link">
<span class="md-ellipsis">
Linux 与命令行
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2015%3A%20production%20software%20engineering/02.%20git%20and%20repository%20management/" class="md-nav__link">
<span class="md-ellipsis">
Git 与仓库管理
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/" class="md-nav__link">
<span class="md-ellipsis">
代码设计
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/" class="md-nav__link">
<span class="md-ellipsis">
测试与质量保障
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2015%3A%20production%20software%20engineering/05.%20deployment%20and%20devops/" class="md-nav__link">
<span class="md-ellipsis">
部署与 DevOps
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_17" >
<label class="md-nav__link" for="__nav_17" id="__nav_17_label" tabindex="0">
<span class="md-ellipsis">
SIMD 与 GPU 编程
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_17_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_17">
<span class="md-nav__icon md-icon"></span>
SIMD 与 GPU 编程
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/" class="md-nav__link">
<span class="md-ellipsis">
为什么是 C++ 及 ML 框架原理
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/" class="md-nav__link">
<span class="md-ellipsis">
硬件基础
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/02.%20ARM%20and%20NEON/" class="md-nav__link">
<span class="md-ellipsis">
ARM 与 NEON
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/" class="md-nav__link">
<span class="md-ellipsis">
x86 与 AVX
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/" class="md-nav__link">
<span class="md-ellipsis">
GPU 架构与 CUDA
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/05.%20triton%2C%20TPUs%20and%20pallax/" class="md-nav__link">
<span class="md-ellipsis">
Triton、TPU 与 Pallas
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/06.%20RISC-V%20and%20embedded%20systems/" class="md-nav__link">
<span class="md-ellipsis">
RISC-V 与嵌入式系统
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/" class="md-nav__link">
<span class="md-ellipsis">
Vulkan Compute 与跨平台 GPU
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_18" >
<label class="md-nav__link" for="__nav_18" id="__nav_18_label" tabindex="0">
<span class="md-ellipsis">
AI 推理
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_18_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_18">
<span class="md-nav__icon md-icon"></span>
AI 推理
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2017%3A%20AI%20inference/01.%20quantisation/" class="md-nav__link">
<span class="md-ellipsis">
量化
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/" class="md-nav__link">
<span class="md-ellipsis">
高效架构
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2017%3A%20AI%20inference/03.%20serving%20and%20batching/" class="md-nav__link">
<span class="md-ellipsis">
服务与批处理
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2017%3A%20AI%20inference/04.%20edge%20inference/" class="md-nav__link">
<span class="md-ellipsis">
边缘推理
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2017%3A%20AI%20inference/05.%20scaling%20and%20deployment/" class="md-nav__link">
<span class="md-ellipsis">
扩缩与部署
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_19" >
<label class="md-nav__link" for="__nav_19" id="__nav_19_label" tabindex="0">
<span class="md-ellipsis">
ML 系统设计
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_19_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_19">
<span class="md-nav__icon md-icon"></span>
ML 系统设计
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/" class="md-nav__link">
<span class="md-ellipsis">
系统设计基础
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/" class="md-nav__link">
<span class="md-ellipsis">
云计算
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/" class="md-nav__link">
<span class="md-ellipsis">
大规模基础设施
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/" class="md-nav__link">
<span class="md-ellipsis">
ML 系统设计
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/" class="md-nav__link">
<span class="md-ellipsis">
ML 设计案例
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_20" >
<label class="md-nav__link" for="__nav_20" id="__nav_20_label" tabindex="0">
<span class="md-ellipsis">
应用 AI
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_20_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_20">
<span class="md-nav__icon md-icon"></span>
应用 AI
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2019%3A%20applied%20AI/01.%20AI%20for%20finance/" class="md-nav__link">
<span class="md-ellipsis">
AI 金融
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2019%3A%20applied%20AI/02.%20protein%20design/" class="md-nav__link">
<span class="md-ellipsis">
蛋白质设计
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2019%3A%20applied%20AI/03.%20drug%20discovery/" class="md-nav__link">
<span class="md-ellipsis">
药物发现
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2019%3A%20applied%20AI/04.%20agentic%20systems/" class="md-nav__link">
<span class="md-ellipsis">
智能体系统
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2019%3A%20applied%20AI/05.%20healthcare/" class="md-nav__link">
<span class="md-ellipsis">
医疗健康
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_21" >
<label class="md-nav__link" for="__nav_21" id="__nav_21_label" tabindex="0">
<span class="md-ellipsis">
前沿 AI
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_21_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_21">
<span class="md-nav__icon md-icon"></span>
前沿 AI
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2020%3A%20bleeding%20edge%20AI/01.%20quantum%20machine%20learning/" class="md-nav__link">
<span class="md-ellipsis">
量子机器学习
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2020%3A%20bleeding%20edge%20AI/02.%20neuromorphic%20computing/" class="md-nav__link">
<span class="md-ellipsis">
神经形态计算
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2020%3A%20bleeding%20edge%20AI/03.%20datacentres%20in%20space/" class="md-nav__link">
<span class="md-ellipsis">
太空数据中心
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2020%3A%20bleeding%20edge%20AI/04.%20decentralised%20AI/" class="md-nav__link">
<span class="md-ellipsis">
去中心化 AI
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2020%3A%20bleeding%20edge%20AI/05.%20brain%20machine%20interfaces/" class="md-nav__link">
<span class="md-ellipsis">
脑机接口
</span>
</a>
</li>
</ul>
</nav>
</li>
</ul>
</nav>
</div>
</div>
</div>
<div class="md-sidebar md-sidebar--secondary" data-md-component="sidebar" data-md-type="toc" >
<div class="md-sidebar__scrollwrap">
<div class="md-sidebar__inner">
<nav class="md-nav md-nav--secondary" aria-label="目录">
<label class="md-nav__title" for="__toc">
<span class="md-nav__icon md-icon"></span>
目录
</label>
<ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
<li class="md-nav__item">
<a href="#colab" class="md-nav__link">
<span class="md-ellipsis">
编程任务(使用CoLab或笔记本)
</span>
</a>
</li>
</ul>
</nav>
</div>
</div>
</div>
<div class="md-content" data-md-component="content">
<article class="md-content__inner md-typeset">
<h1 id="transformer">Transformer与语言模型<a class="headerlink" href="#transformer" title="Permanent link">&para;</a></h1>
<p><em>Transformer用自注意力取代了循环结构,成为语言理解和生成的主导架构。本文件涵盖BERT、GPT、T5、位置编码(正弦编码、RoPE)、预训练目标(MLM、CLM)、微调、提示工程和缩放定律——这些是现代大语言模型背后的蓝图。</em></p>
<ul>
<li>
<p>在第06章中,我们介绍了Transformer架构:自注意力、多头注意力、位置编码以及编码器-解码器结构。这里我们聚焦于Transformer如何适配特定的NLP范式、定义现代NLP的模型(BERT、GPT、T5),以及让它们在大规模下切实可行的技术。</p>
</li>
<li>
<p>回顾核心操作:<strong>缩放点积注意力</strong>计算 <span class="arithmatex">\(\text{softmax}(QK^T / \sqrt{d_k}) V\)</span>,其中查询、键和值都是输入的线性投影。<strong>多头注意力</strong>并行运行 <span class="arithmatex">\(h\)</span> 个注意力头,每个头使用不同的学习投影,然后将结果拼接起来。Transformer块通过残差连接、层归一化和逐位置前馈网络(第06章)将这一切包裹起来。</p>
</li>
<li>
<p>一个微妙但重要的架构选择是<strong>层归一化</strong>的放置位置。原始Transformer使用<strong>后归一化</strong>:残差和归一化在子层之后执行,即 <span class="arithmatex">\(\text{LayerNorm}(x + \text{Sublayer}(x))\)</span></p>
</li>
<li>
<p>大多数现代模型使用<strong>前归一化</strong>:在子层之前进行归一化,即 <span class="arithmatex">\(x + \text{Sublayer}(\text{LayerNorm}(x))\)</span>。前归一化在训练过程中更加稳定,因为残差连接直接将梯度通过恒等路径传递,不受归一化的影响。这使得训练非常深的模型变得更容易,无需仔细的学习率预热。</p>
</li>
<li>
<p>每个Transformer块中的<strong>前馈子层</strong>是一个两层MLP,独立应用于每个标记位置:</p>
</li>
</ul>
<div class="arithmatex">\[\text{FFN}(x) = W_2 \cdot \text{GELU}(W_1 x + b_1) + b_2\]</div>
<ul>
<li>
<p>内部维度通常是模型维度的4倍(例如,<span class="arithmatex">\(d_{\text{model}} = 768\)</span><span class="arithmatex">\(d_{\text{ff}} = 3072\)</span>)。这个FFN约占每个块中参数的三分之二,被认为起到键-值记忆的作用,存储训练过程中学到的事实知识。</p>
</li>
<li>
<p><strong>位置编码</strong>为模型提供标记顺序的信息,因为注意力本身是置换等变的。原始的<strong>正弦编码</strong>(第06章)使用不同频率的固定正弦和余弦函数。<strong>可学习位置嵌入</strong>则简单地为每个位置添加一个可训练向量(用于BERT和GPT-2)。两者都是绝对编码:无论上下文如何,位置5总是得到相同的向量。</p>
</li>
<li>
<p><strong>旋转位置编码(RoPE</strong>通过在二维子空间中旋转查询和键向量来编码位置。对于一对维度 <span class="arithmatex">\((q_{2i}, q_{2i+1})\)</span>,按角度 <span class="arithmatex">\(m\theta_i\)</span> 的旋转(其中 <span class="arithmatex">\(m\)</span> 是位置,<span class="arithmatex">\(\theta_i = 10000^{-2i/d}\)</span>)应用如下:</p>
</li>
</ul>
<div class="arithmatex">\[
\begin{bmatrix} q'_{2i} \\ q'_{2i+1} \end{bmatrix} = \begin{bmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{bmatrix} \begin{bmatrix} q_{2i} \\ q_{2i+1} \end{bmatrix}
\]</div>
<p><img alt="RoPE:每个位置在二维子空间中以不同角度旋转查询和键向量,使注意力分数仅依赖于相对位置" src="../../images/rope_rotation.svg" /></p>
<ul>
<li>
<p>RoPE的精妙之处在于,旋转后的查询和键之间的点积 <span class="arithmatex">\(q'^T k'\)</span> 仅依赖于相对位置 <span class="arithmatex">\(m - n\)</span>,而非绝对位置。</p>
</li>
<li>
<p>为了理解原因,将旋转写为 <span class="arithmatex">\(q' = R_m q\)</span><span class="arithmatex">\(k' = R_n k\)</span>,其中 <span class="arithmatex">\(R_m\)</span> 是一个块对角旋转矩阵。注意力分数变为:</p>
</li>
</ul>
<div class="arithmatex">\[q'^T k' = (R_m q)^T (R_n k) = q^T R_m^T R_n \, k = q^T R_{n-m} \, k\]</div>
<ul>
<li>
<p>最后一步利用了旋转群性质:<span class="arithmatex">\(R_m^T R_n = R_{n-m}\)</span>(先向后旋转 <span class="arithmatex">\(m\)</span> 再向前旋转 <span class="arithmatex">\(n\)</span>,等价于旋转 <span class="arithmatex">\(n - m\)</span>)。</p>
</li>
<li>
<p>这意味着注意力分数仅依赖于相对距离 <span class="arithmatex">\(n - m\)</span>,而非绝对位置 <span class="arithmatex">\(m\)</span><span class="arithmatex">\(n\)</span> 本身。</p>
</li>
<li>
<p>模型无需任何学习的位置参数就能获得自然的距离概念,并且可以泛化到训练时未见过的序列长度。</p>
</li>
<li>
<p><strong>ALiBi</strong>(带线性偏置的注意力)采用了一种更简单的方法:它根据距离向注意力分数添加一个固定的线性惩罚,即 <span class="arithmatex">\(\text{score}_{ij} = q_i^T k_j - m \cdot |i - j|\)</span>,其中 <span class="arithmatex">\(m\)</span> 是每个头特定的斜率。不同的头使用不同的斜率,使一些头可以关注局部信息,另一些头关注全局信息。ALiBi不需要任何可学习的位置参数,并且能够很好地泛化到比训练时更长的序列。</p>
</li>
<li>
<p>基于Transformer的语言模型的三种主导范式是<strong>仅编码器</strong><strong>仅解码器</strong><strong>编码器-解码器</strong>。它们在模型能看到的范围(注意力掩码)以及训练方式上有所不同。</p>
</li>
</ul>
<p><img alt="三种Transformer范式:仅编码器(BERT)使用双向注意力进行分类,仅解码器(GPT)使用因果注意力进行生成,编码器-解码器(T5)结合两者用于序列到序列任务" src="../../images/transformer_paradigms.svg" /></p>
<ul>
<li>
<p><strong>BERT</strong>(来自Transformer的双向编码器表示,Devlin等人,2019)是典型的仅编码器模型。它使用完全的双向注意力处理文本:每个标记可以关注所有其他标记,包括左右两侧。这赋予了BERT丰富的上下文表示,但意味着它不能自回归地生成文本。</p>
</li>
<li>
<p>BERT通过两个目标进行预训练。<strong>掩码语言建模(MLM</strong>随机遮蔽15%的输入标记,并训练模型去预测它们。在被选中的标记中,80%被替换为[MASK]标记,10%被替换为随机词,10%保持不变(以防止模型只学会在看到[MASK]时才进行预测)。训练目标如下:</p>
</li>
</ul>
<div class="arithmatex">\[\mathcal{L}_{\text{MLM}} = -\sum_{i \in \mathcal{M}} \log P(w_i \mid w_{\backslash \mathcal{M}})\]</div>
<ul>
<li>其中 <span class="arithmatex">\(\mathcal{M}\)</span> 是被遮蔽的位置集合,<span class="arithmatex">\(w_{\backslash \mathcal{M}}\)</span> 是这些位置被遮蔽后的句子。这是一个<strong>去噪</strong>目标:模型学习重建被破坏的输入。</li>
</ul>
<p><img alt="BERT掩码语言建模:15%的输入标记被遮蔽,双向Transformer在遮蔽位置预测原始标记" src="../../images/bert_mlm.svg" /></p>
<ul>
<li>
<p><strong>下一句预测(NSP</strong>训练BERT预测两个句子在原始文本中是否连续。输入开头的特殊[CLS]标记用于此二分类。NSP的加入是为了帮助理解句子关系的任务(如问答),不过后来的工作(RoBERTa)表明其贡献很小,可以去掉。</p>
</li>
<li>
<p>BERT的预训练表示通过在其顶部添加特定任务的头部(一个简单的线性层)并微调整个模型来适应下游任务。对于分类任务,使用[CLS]标记的表示。对于标记级任务(命名实体识别、词性标注),使用每个标记的表示。这种<strong>微调</strong>方法将预训练期间学到的语言知识迁移到新任务上,只需相对较少的标注数据。</p>
</li>
<li>
<p><strong>GPT</strong>(生成式预训练TransformerRadford等人,2018)是典型的仅解码器模型。它使用<strong>因果(自回归)注意力</strong>:每个标记只能关注更早位置的标记(以及自身)。这是通过在注意力矩阵中遮蔽未来位置(将其分数设置为 <span class="arithmatex">\(-\infty\)</span>,然后再进行softmax)来实现的。训练目标是简单的<strong>因果语言建模</strong>:根据所有之前的标记预测下一个标记。</p>
</li>
</ul>
<div class="arithmatex">\[\mathcal{L}_{\text{CLM}} = -\sum_{i=1}^{n} \log P(w_i \mid w_1, \ldots, w_{i-1})\]</div>
<ul>
<li>
<p>这与文件02中的n-gram语言模型目标相同,但采用了Transformer参数化方式,可以基于整个前文进行条件建模,而不仅仅是最后 <span class="arithmatex">\(k-1\)</span> 个标记。</p>
</li>
<li>
<p><strong>GPT-2</strong>将其规模扩大到15亿参数,并展现了强大的零样本能力:无需任何微调,它就能通过自然语言提示("将英语翻译成法语:……")来执行任务。</p>
</li>
<li>
<p><strong>GPT-3</strong>(1750亿参数)表明,仅凭规模就能实现<strong>上下文学习</strong>:通过在提示中提供几个输入-输出示例,模型无需任何梯度更新就能执行新任务。</p>
</li>
<li>
<p><strong>编码器-解码器模型</strong><strong>T5</strong>(文本到文本迁移TransformerRaffel等人,2020)将每个NLP任务都视为文本到文本:输入是一个文本字符串(可能带有任务前缀,如"将英语翻译成德语:"),输出也是一个文本字符串。编码器使用双向注意力处理输入,解码器则通过交叉注意力自回归地生成输出。</p>
</li>
<li>
<p>T5通过<strong>跨度破坏</strong>进行预训练:随机连续标记跨度被替换为哨兵标记,模型需要生成原始标记。例如,"The cat sat on the mat"可能变成输入"The [X] on [Y]",目标输出是"[X] cat sat [Y] the mat"。这是BERT的MLM从单个标记向跨度的泛化。</p>
</li>
<li>
<p><strong>BART</strong>(Lewis等人,2020)是另一种编码器-解码器模型,通过去噪目标进行预训练,但它应用了更广泛的破坏策略:标记遮蔽、标记删除、跨度遮蔽、句子置换和文档旋转。多样化的破坏方式迫使模型学习更鲁棒的表示。</p>
</li>
<li>
<p>随着语言模型变得越来越大,<strong>全量微调</strong>(更新所有参数)变得不切实际:一个175B参数的模型仅存储优化器状态就需要数百GB。<strong>参数高效微调(PEFT</strong>方法只调整一小部分参数。</p>
</li>
<li>
<p><strong>适配器</strong>在现有Transformer层之间插入小型瓶颈层(通常是两个线性层加一个非线性激活:下投影到小维度,再上投影回来)。只有适配器的权重被训练;原始模型权重被冻结。这增加了不到5%的新参数,同时在大多数任务上匹配全量微调的性能。</p>
</li>
<li>
<p><strong>LoRA</strong>(低秩适配)直接修改权重矩阵,而不添加新层。LoRA不更新完整的权重矩阵 <span class="arithmatex">\(W\)</span>,而是学习一个低秩分解的更新:<span class="arithmatex">\(W' = W + BA\)</span>,其中 <span class="arithmatex">\(B\)</span><span class="arithmatex">\(d \times r\)</span> 矩阵,<span class="arithmatex">\(A\)</span><span class="arithmatex">\(r \times d\)</span> 矩阵,且 <span class="arithmatex">\(r \ll d\)</span>(通常 <span class="arithmatex">\(r = 4\)</span><span class="arithmatex">\(r = 64\)</span>)。原始 <span class="arithmatex">\(W\)</span> 被冻结;只训练 <span class="arithmatex">\(A\)</span><span class="arithmatex">\(B\)</span>。在推理时,更新可以合并到原始权重中,不会增加额外延迟:</p>
</li>
</ul>
<div class="arithmatex">\[W' = W + BA\]</div>
<p><img alt="LoRA:冻结的权重矩阵W被一个通过小矩阵A和B的低秩路径旁路,可训练参数减少32倍,同时匹配全量微调的性能" src="../../images/lora_decomposition.svg" /></p>
<ul>
<li>
<p><strong>前缀微调</strong>在每个注意力层的键和值矩阵前添加一串可学习的"虚拟标记"。模型像对待真实标记一样关注这些前缀向量,并且只训练前缀参数。这与提示微调类似,但在激活空间而非嵌入空间中操作。</p>
</li>
<li>
<p><strong>提示工程</strong>是设计输入文本的艺术,旨在从预训练模型中引出所需行为,而无需任何参数更新。</p>
<ul>
<li>
<p><strong>零样本提示</strong>用自然语言描述任务("对以下评论的情感进行分类:")。</p>
</li>
<li>
<p><strong>少样本提示</strong>在实际查询之前提供输入-输出示例。</p>
</li>
<li>
<p><strong>链式思维(CoT)提示</strong>添加"让我们一步一步地思考"或在示例中包含推理过程,这通过引导模型分解问题,显著提高了算术和逻辑推理任务的性能。</p>
</li>
</ul>
</li>
<li>
<p><strong>上下文学习(ICL</strong>是大语言模型能够从提示中提供的示例学习执行任务的现象,而无需任何梯度更新。模型的权重没有改变;它将示例作为一种隐式规范来使用。</p>
</li>
<li>
<p>ICL在机制上是如何工作的仍然是一个活跃的研究问题;一种假说是注意力层在前向传播中实现了一种梯度下降形式,实际上是在上下文示例上进行"训练"。</p>
</li>
<li>
<p><strong>缩放定律</strong>描述了模型大小、数据大小、计算预算与性能(以损失衡量)之间的可预测关系。Kaplan等人(2020)发现损失在每个变量上都遵循幂律:</p>
</li>
</ul>
<div class="arithmatex">\[L(N) \propto N^{-\alpha_N}, \quad L(D) \propto D^{-\alpha_D}, \quad L(C) \propto C^{-\alpha_C}\]</div>
<ul>
<li>其中 <span class="arithmatex">\(N\)</span> 是参数量,<span class="arithmatex">\(D\)</span> 是数据集大小,<span class="arithmatex">\(C\)</span> 是计算预算。这些幂律在多个数量级上成立,表明单纯地扩大规模就能带来可预测的改进。</li>
</ul>
<p><img alt="缩放定律:损失在对数-对数坐标轴上以幂律递减,Kaplan和Chinchilla的研究结果表明随规模扩大有可预期的改进" src="../../images/scaling_laws.svg" /></p>
<ul>
<li><strong>Chinchilla缩放定律</strong>Hoffmann等人,2022)修正了这一点,指出大多数大型模型都训练不足。对于固定的计算预算 <span class="arithmatex">\(C\)</span>,最优分配是同等规模地扩大模型大小和训练数据:</li>
</ul>
<div class="arithmatex">\[N_{\text{opt}} \propto C^{0.5}, \quad D_{\text{opt}} \propto C^{0.5}\]</div>
<ul>
<li>
<p>这意味着如果计算预算翻倍,应该同时将模型大小和数据集大小增加 <span class="arithmatex">\(\sqrt{2}\)</span> 倍,而不仅仅是让模型变得更大。</p>
</li>
<li>
<p>Kaplan等人曾建议 <span class="arithmatex">\(N\)</span> 的缩放速度应快于 <span class="arithmatex">\(D\)</span>,这导致了非常大但训练不足的模型。Chinchilla(70B参数,1.4T标记)在相同的计算预算下匹配了Gopher(280B参数,300B标记)的性能,表明早期模型严重缺乏数据。</p>
</li>
<li>
<p>实用的经验法则:大约每个参数训练20个标记。</p>
</li>
<li>
<p><strong>混合专家(MoE</strong>是一种在不成比例增加计算量的情况下扩大模型容量的架构。MoE不采用单一的大型前馈层,而是使用多个<strong>专家</strong>FFN层和一个<strong>门控网络</strong>(路由网络)来选择每个标记应该激活哪些专家。</p>
</li>
<li>
<p>门控函数计算每个专家的路由分数,并选择前 <span class="arithmatex">\(k\)</span> 个(通常 <span class="arithmatex">\(k = 1\)</span><span class="arithmatex">\(k = 2\)</span>):</p>
</li>
</ul>
<div class="arithmatex">\[G(x) = \text{TopK}(\text{softmax}(W_g x))\]</div>
<ul>
<li>只有被选中的专家处理该标记,因此计算成本随 <span class="arithmatex">\(k\)</span>(活跃专家数)而非总专家数 <span class="arithmatex">\(E\)</span> 增长。一个有8个专家且采用top-2路由的模型,参数量是稠密模型的4倍,但计算量仅为2倍。</li>
</ul>
<p><img alt="MoE层:输入标记经过路由网络计算每个专家的分数,选择top-2专家,它们的输出按门控分数加权后求和" src="../../images/moe_layer.svg" /></p>
<ul>
<li>MoE中一个关键的挑战是<strong>负载均衡</strong>:如果路由网络将大多数标记发送给少数热门专家,其他专家就被浪费了。训练时会添加一个辅助的<strong>负载均衡损失</strong>,鼓励均匀的专家利用率:</li>
</ul>
<div class="arithmatex">\[\mathcal{L}_{\text{balance}} = E \cdot \sum_{i=1}^{E} f_i \cdot p_i\]</div>
<ul>
<li>
<p>其中 <span class="arithmatex">\(f_i\)</span> 是分配给专家 <span class="arithmatex">\(i\)</span> 的标记比例,<span class="arithmatex">\(p_i\)</span> 是专家 <span class="arithmatex">\(i\)</span> 的平均路由概率。当标记比例和概率都均匀(各等于 <span class="arithmatex">\(1/E\)</span>)时,该乘积最小。</p>
</li>
<li>
<p><strong>专家并行</strong>将不同的专家分布到不同的加速器上。在前向传播过程中,通过一个全到全的通信步骤将标记路由到其指定专家所在的设备,然后将结果路由回来。这种通信成本是MoE在大规模部署中的主要工程挑战。Switch Transformer、Mixtral和GShard等模型使用MoE来获得强大的性能,同时保持合理的推理成本。</p>
</li>
<li>
<p>构建模型只是工作的一半;衡量它们是否有效是另一半。NLP评估特别困难,因为语言是模糊的、主观的和开放式的。</p>
</li>
<li>
<p>一个翻译可以有多种正确的表达方式。一个摘要即使与参考摘要没有任何完全相同的词汇,也可能是好的。</p>
</li>
<li>
<p>一个聊天机器人的回复可能既有用、又无害、又诚实,但理性的人仍会对此有不同看法。</p>
</li>
<li>
<p><strong>精确匹配(EM</strong>是最简单的指标:模型的输出是否与标准答案完全一致?它用于答案简短且无歧义的任务,如抽取式问答(SQuAD)或封闭式数学问题。</p>
</li>
<li>
<p>EM是严苛的;"New York City"和"new york city"在不做归一化的情况下无法匹配——但它的简单性使其没有歧义。</p>
</li>
<li>
<p><strong>标记级指标</strong>将NLP视为标记级别的分类问题,使用第06章中的精确率、召回率和F1值。</p>
</li>
<li>
<p><strong>精确率(Precision</strong>衡量模型预测的标记中正确部分的比例:<span class="arithmatex">\(P = \text{TP} / (\text{TP} + \text{FP})\)</span>。一个预测很少但全部正确的模型具有高精确率。</p>
</li>
<li>
<p><strong>召回率(Recall</strong>衡量模型找到了多少标准标记:<span class="arithmatex">\(R = \text{TP} / (\text{TP} + \text{FN})\)</span>。一个将所有标记都预测为实体的模型具有完美的召回率但精确率极低。</p>
</li>
<li>
<p><strong>F1</strong>是精确率和召回率的调和平均值:</p>
</li>
</ul>
<div class="arithmatex">\[F_1 = \frac{2PR}{P + R}\]</div>
<ul>
<li>
<p>调和平均值(而非算术平均值)惩罚不均衡:如果 <span class="arithmatex">\(P\)</span><span class="arithmatex">\(R\)</span> 中任何一个较低,F1就会很低。对于命名实体识别(文件02),F1按每个实体类型分别计算,然后跨类型取宏平均。对于词性标注,标记级准确率更常见,因为每个标记都有一个标签。</p>
</li>
<li>
<p><strong>跨度级F1</strong>(用于SQuAD)比较预测跨度中的标记集与标准跨度中的标记集。这比精确匹配更宽容:如果标准答案是"the Eiffel Tower"而模型预测的是"Eiffel Tower",跨度F1很高(5个重叠标记中的4个),即使EM为零。</p>
</li>
<li>
<p><strong>BLEU</strong>(双语评估替补,Papineni等人,2002)是机器翻译的经典指标。它衡量候选翻译与一个或多个参考翻译之间的n-gram重叠。该评分结合了多个n-gram级别(unigram到4-gram)的精确率和一个简短惩罚:</p>
</li>
</ul>
<div class="arithmatex">\[\text{BLEU} = \text{BP} \cdot \exp\!\left(\sum_{n=1}^{N} w_n \log p_n\right)\]</div>
<ul>
<li>
<p>其中 <span class="arithmatex">\(p_n\)</span><strong>修正的n-gram精确率</strong>:候选翻译中每个n-gram的计数被裁剪为其在任何参考翻译中的最大计数,防止像"the the the the"这样的退化候选获得高分。权重 <span class="arithmatex">\(w_n\)</span> 通常是均匀的(<span class="arithmatex">\(w_n = 1/N\)</span>,其中 <span class="arithmatex">\(N = 4\)</span>)。</p>
</li>
<li>
<p><strong>简短惩罚</strong> <span class="arithmatex">\(\text{BP} = \min(1, \exp(1 - r/c))\)</span> 惩罚比参考翻译短的候选(<span class="arithmatex">\(c\)</span> 是候选长度,<span class="arithmatex">\(r\)</span> 是参考长度)。没有这个惩罚,模型可以通过输出很少但非常安全的词来获得高精确率。</p>
</li>
<li>
<p>BLEU在语料级别(对多个句子取平均)与人类判断有合理的相关性,但在句子级别相关性较差。</p>
</li>
<li>
<p>它奖励精确的n-gram匹配,但会遗漏有效的释义:"the cat is on the mat"和"a feline sits atop the rug"尽管意思相同,但二元组重叠为零。</p>
</li>
<li>
<p>BLEU也完全忽略了召回率——只输出最常见词汇的候选在精确率上得分很高。</p>
</li>
<li>
<p><strong>ROUGE</strong>(面向召回率的摘要评估替补,Lin,2004)是摘要的标准指标。与强调精确率的BLEU不同,ROUGE强调召回率:参考n-gram中有多少比例出现在候选摘要中?</p>
</li>
<li>
<p><strong>ROUGE-N</strong>计算n-gram的召回率:<span class="arithmatex">\(\text{ROUGE-N} = \frac{|\text{n-grams}_{\text{ref}} \cap \text{n-grams}_{\text{cand}}|}{|\text{n-grams}_{\text{ref}}|}\)</span>。ROUGE-1unigram)和ROUGE-2bigram)最为常用。</p>
</li>
<li>
<p>ROUGE-L使用候选和参考之间的<strong>最长公共子序列(LCS</strong>,这可以捕捉句子级别的词序信息,而不要求连续匹配。</p>
</li>
<li>
<p>LCS长度除以参考长度得到召回率,除以候选长度得到精确率,F度量则组合两者。</p>
</li>
<li>
<p>LCS通过动态规划在 <span class="arithmatex">\(O(mn)\)</span> 时间内计算(类似于文件02中的编辑距离):</p>
</li>
</ul>
<div class="arithmatex">\[R_{\text{LCS}} = \frac{\text{LCS}(X, Y)}{m}, \quad P_{\text{LCS}} = \frac{\text{LCS}(X, Y)}{n}, \quad F_{\text{LCS}} = \frac{(1 + \beta^2) R_{\text{LCS}} P_{\text{LCS}}}{R_{\text{LCS}} + \beta^2 P_{\text{LCS}}}\]</div>
<ul>
<li>
<p>其中 <span class="arithmatex">\(m\)</span><span class="arithmatex">\(n\)</span> 分别是参考和候选的长度,<span class="arithmatex">\(\beta\)</span> 通常设置为偏向召回率(<span class="arithmatex">\(\beta \to \infty\)</span> 给出纯召回率)。</p>
</li>
<li>
<p><strong>METEOR</strong>(带显式排序的翻译评估度量,Banerjee和Lavie,2005)通过引入同义词、词干提取和词序来解决BLEU的弱点。</p>
</li>
<li>
<p>它首先使用精确匹配、词干匹配(通过文件02中的Porter词干提取算法)和同义词匹配(通过文件01中的WordNet)在候选和参考之间对齐词汇。</p>
</li>
<li>
<p>然后计算unigram精确率和召回率的调和平均值(偏向召回率),并应用一个碎片化惩罚,惩罚那些匹配词顺序与参考不同的候选。</p>
</li>
<li>
<p><strong>ChrF</strong>(字符n-gram F值)计算字符n-gram而非词汇n-gram的F值。这使其对形态变化具有鲁棒性(对文件01中的黏着语至关重要),并部分处理了分词差异。ChrF++在字符n-gram的基础上增加了词汇二元组。</p>
</li>
<li>
<p>它已成为机器翻译中与BLEU一起推荐的度量标准,特别是对于形态丰富的语言。</p>
</li>
<li>
<p><strong>困惑度</strong>(文件02)衡量语言模型在保留测试集上的预测效果。这是语言模型的标准内在指标:<span class="arithmatex">\(\text{PPL} = \exp(-\frac{1}{N} \sum_{i} \log P(w_i \mid w_{&lt;i}))\)</span>。越低越好。</p>
</li>
<li>
<p>困惑度只能在使用了相同分词方法的模型之间进行比较,因为不同的分词器对同一文本会产生不同的序列长度 <span class="arithmatex">\(N\)</span></p>
</li>
<li>
<p>词汇量更大的模型每个标记的困惑度往往更低,但每个句子处理的标记数也更少。</p>
</li>
<li>
<p><strong>每字节比特数(BPB</strong>按照文本中UTF-8字节数而非标记数进行归一化,使其与分词方式无关:</p>
</li>
</ul>
<div class="arithmatex">\[
\text{BPB} = \frac{-\sum_{i} \log_2 P(w_i \mid w_{<i})}{\text{UTF-8字节数}}
\]</div>
<ul>
<li><strong>BERTScore</strong>(Zhang等人,2020)超越了表面的n-gram匹配,在嵌入空间中计算相似度。候选中的每个标记与其在参考中最相似的标记进行匹配,使用上下文嵌入(通常来自预训练的BERT模型)的余弦相似度。分数汇总为精确率、召回率和F1:</li>
</ul>
<div class="arithmatex">\[R_{\text{BERT}} = \frac{1}{|r|} \sum_{r_i \in r} \max_{c_j \in c} \cos(r_i, c_j), \quad P_{\text{BERT}} = \frac{1}{|c|} \sum_{c_j \in c} \max_{r_i \in r} \cos(c_j, r_i)\]</div>
<ul>
<li>
<p>其中 <span class="arithmatex">\(r_i\)</span><span class="arithmatex">\(c_j\)</span> 是参考和候选标记的上下文嵌入。这捕捉了n-gram指标无法捕捉的语义相似性:"automobile"和"car"得分很高,因为它们的BERT嵌入相似,尽管它们没有共享任何字符。</p>
</li>
<li>
<p><strong>BLEURT</strong>(Sellam等人,2020)在此基础上更进一步,直接在人工质量判断上微调BERT模型。给定一个参考和候选对,它输出一个标量质量分数。BLEURT在合成数据(由BLEU和METEOR等指标评分的参考翻译的随机扰动)上训练,然后在人工评分上微调。它与人类判断的相关性优于任何表面级指标。</p>
</li>
<li>
<p><strong>COMET</strong>(翻译评估跨语言优化指标,Rei等人,2020)是一个用于机器翻译的学习度量,它同时以源句、参考和候选作为条件——而不仅仅是参考和候选。它使用多语言编码器(XLM-R)嵌入三者,并预测质量分数。通过看到源句,COMET可以检测仅基于参考的指标所遗漏的意义错误(例如,流畅但事实错误的翻译)。</p>
</li>
<li>
<p><strong>大语言模型作为裁判(LLM-as-judge</strong>是大规模评估的现代方法。不再计算与参考的指标,而是让一个强大的语言模型(GPT-4、Claude)被提示评估模型输出的质量。裁判接收输入、模型的回复以及可选的参考答案,并给出评分(例如1-5分)或成对偏好(回复A优于回复B)。</p>
</li>
<li>
<p><strong>成对比较</strong>(用于Chatbot Arena)是最可靠的LLM-as-judge格式。裁判看到两个回复并选择更好的那个,而不是给出绝对分数。这避免了校准问题(不同的裁判可能对"3/5"有不同的基准)。结果汇总为<strong>Elo评分</strong>(源自国际象棋),每个模型从一个基准评分开始,根据与其他模型的对战胜负增减分数。模型 <span class="arithmatex">\(A\)</span> 对模型 <span class="arithmatex">\(B\)</span> 的预期获胜概率为:</p>
</li>
</ul>
<div class="arithmatex">\[P(A \succ B) = \frac{1}{1 + 10^{(R_B - R_A) / 400}}\]</div>
<ul>
<li>
<p>其中 <span class="arithmatex">\(R_A, R_B\)</span> 是Elo评分。每次比较后,评分更新:<span class="arithmatex">\(R_A' = R_A + K(S - P(A \succ B))\)</span>,其中 <span class="arithmatex">\(S \in \{0, 1\}\)</span> 是实际结果,<span class="arithmatex">\(K\)</span> 控制更新幅度。持续击败强对手的模型快速上升;输给弱对手的模型下降。</p>
</li>
<li>
<p><strong>位置偏置</strong>是LLM裁判的一个已知问题:它们倾向于偏好先展示的回复(或者在某些模型中,后展示的回复)。<strong>交换</strong>(以两种顺序对每对进行评估)并平均结果可以缓解这一问题。</p>
</li>
<li>
<p><strong>冗长偏置</strong>是另一个问题:裁判倾向于偏好更长、更详细的回复,即使简洁的回答更好。</p>
</li>
<li>
<p><strong>自一致性</strong>检查裁判在多次评估同一输入时是否给出相同的评分。高方差表明评估信号存在噪音。</p>
</li>
<li>
<p><strong>标注者间一致性</strong>Cohen's kappa或Krippendorff's alpha)衡量多个裁判是否一致,为评估可靠性提供了一个上限。</p>
</li>
<li>
<p><strong>数据污染</strong>是一个关键问题:如果评估数据出现在模型的训练集中,基准分数就会被夸大且毫无意义。</p>
</li>
<li>
<p>这对于在网页抓取数据上训练的大语言模型尤其有问题,因为流行的基准很可能出现在其中。缓解策略包括:使用未公开发布的保留测试集、创建定期重新生成问题的动态基准、<strong>金丝雀字符串</strong>(嵌入在基准数据中用于检测泄露的唯一标识符),以及比较在污染与清洁子集上的性能。</p>
</li>
<li>
<p><strong>标准NLU基准</strong>评估跨多种任务的语言理解能力。</p>
</li>
<li>
<p><strong>GLUE</strong>(通用语言理解评估)和<strong>SuperGLUE</strong>是多任务基准,涵盖情感分析(SST-2)、文本相似度(STS-B)、自然语言推理(MNLI、RTE)、共指消解(WSC)和问答(BoolQ)。</p>
</li>
<li>
<p>模型在每个任务上分别评估,并按聚合指标打分。GLUE现在被认为已经饱和(模型在大多数任务上已超过人类表现);SuperGLUE仍然更具挑战性。</p>
</li>
<li>
<p><strong>MMLU</strong>(大规模多任务语言理解)通过多项选择题评估57个学术科目(数学、历史、法律、医学、计算机科学等)中的知识和推理能力。</p>
</li>
<li>
<p>它测试模型在预训练期间是否吸收了广泛的知识。分数按科目报告并作为宏平均给出。</p>
</li>
<li>
<p><strong>MMLU-Pro</strong>增加了更困难的多步推理问题,有10个选项而非4个。</p>
</li>
<li>
<p><strong>HellaSwag</strong>通过要求模型选择一个场景最合理的续写来测试常识推理。错误的答案是通过模型对抗性生成的,表面看似合理但语义错误。</p>
</li>
<li>
<p><strong>WinoGrande</strong>通过仅一词之差的极小对测试常识共指消解。</p>
</li>
<li>
<p><strong>ARC</strong>(AI2推理挑战)使用小学科学问题,分为简单和挑战集,测试事实和推理能力。</p>
</li>
<li>
<p><strong>推理和数学基准</strong>评估区分强大LLM与弱小LLM的问题解决能力。</p>
</li>
<li>
<p><strong>GSM8K</strong>(小学数学8K)包含8,500道小学算术应用题,需要多步算术推理。它是基础数学推理和评估链式思维提示(文件04)的标准基准。</p>
</li>
<li>
<p><strong>MATH</strong>是一个更难的数据集,包含代数、数论、几何、计数和概率方面的竞赛级数学问题。问题需要多步符号推理,MATH-500是常用的500题子集。</p>
</li>
<li>
<p><strong>AIME</strong>(美国数学邀请赛)问题是竞赛级的:正确解答需要跨越多个步骤的深度数学推理。DeepSeek-R1在AIME 2024上得分为79.8%,展示了经过RL训练的推理模型(文件05)可以接近人类高手。</p>
</li>
<li>
<p><strong>HumanEval</strong><strong>MBPP</strong>(基础编程问题)通过检查模型生成的代码是否通过单元测试来评估代码生成能力。HumanEval包含164个Python问题,包括函数签名和文档字符串;模型需要生成函数体。</p>
</li>
<li>
<p>指标是<strong>pass@k</strong>:在 <span class="arithmatex">\(k\)</span> 个生成的解决方案中至少有一个通过所有测试的概率。对于单个样本:</p>
</li>
</ul>
<div class="arithmatex">\[\text{pass@}k = 1 - \frac{\binom{n-c}{k}}{\binom{n}{k}}\]</div>
<ul>
<li>
<p>其中 <span class="arithmatex">\(n\)</span> 是生成的样本总数,<span class="arithmatex">\(c\)</span> 是通过的数量。这个公式修正了简单取 <span class="arithmatex">\(k\)</span> 个样本中最好结果的偏差。</p>
</li>
<li>
<p><strong>SWE-bench</strong>更进一步,评估模型能否通过修改现有代码库来解决真实的GitHub问题——这是对实际软件工程能力的更困难测试。</p>
</li>
<li>
<p><strong>GPQA</strong>(研究生级Google-proof问答)包含生物学、物理学和化学领域的专家级问题,即使是领域专家也很难解答。它测试模型是否具有真正的理解能力而非模式匹配。"Diamond"子集是最难的部分。</p>
</li>
<li>
<p><strong>安全和对齐基准</strong>评估模型是否有用、无害和诚实。</p>
</li>
<li>
<p><strong>TruthfulQA</strong>测试模型是否复现了常见的误解。问题设计为最常见的互联网答案是错误的(例如,"如果吞下口香糖会怎样?",常见的谣言是它会在胃里停留7年,但事实是它会正常通过)。那些记忆了流行但不正确说法的模型得分很低。</p>
</li>
<li>
<p><strong>BBQ</strong>(问答偏置基准)测试在年龄、性别、种族和宗教等类别上的社会偏置。问题的结构使得有偏置的模型会系统地选择刻板印象的答案。<strong>Toxigen</strong>评估模型针对特定人口群体生成有害内容的倾向。</p>
</li>
<li>
<p><strong>MT-Bench</strong>使用80个精心设计的问题评估多轮对话能力,涵盖写作、角色扮演、推理、数学、编程、信息抽取、STEM和人文学科。LLM裁判(GPT-4)按1-10分对回复评分。多轮格式测试模型是否能进行后续提问、保持上下文和处理澄清请求。</p>
</li>
<li>
<p><strong>Chatbot Arena</strong>(LMSYS)使用真实用户对匿名模型进行盲法成对比较。用户提交提示并对更好的回复投票,而不知道是哪个模型生成的。由此产生的Elo排行榜被认为是对通用LLM质量最生态有效的评估,因为它反映了真实用户在多样化、未经策划的提示上的偏好。</p>
</li>
<li>
<p><strong>AlpacaEval</strong>通过在一组固定的指令上将模型输出与参考模型(GPT-4)进行比较来自动化成对评估。由裁判模型决定胜率。</p>
</li>
<li>
<p><strong>AlpacaEval 2.0</strong>使用长度控制的胜率来纠正冗长偏置。</p>
</li>
<li>
<p><strong>任务特定评估</strong>需要针对专业领域量身定制的指标。</p>
</li>
<li>
<p><strong>词错误率(WER</strong>用于语音识别:<span class="arithmatex">\(\text{WER} = (S + D + I) / N\)</span>,其中 <span class="arithmatex">\(S\)</span><span class="arithmatex">\(D\)</span><span class="arithmatex">\(I\)</span> 分别是替换、删除和插入错误,<span class="arithmatex">\(N\)</span> 是参考词的数量。这是按参考长度归一化的编辑距离(文件02),应用于词汇级别。</p>
</li>
<li>
<p><strong>槽位F1</strong>用于任务导向的对话系统,衡量模型是否正确地从用户话语中提取结构化信息(例如,从"帮我订一张明天去巴黎的机票"中提取"目的地:巴黎"和"日期:明天")。</p>
</li>
<li>
<p><strong>引用准确率</strong>用于RAG系统(文件05),检查模型生成的引用是否确实支持所提出的主张。将主张与检索到的段落进行验证,指标统计完全支持、部分支持和不支持的主张比例。</p>
</li>
<li>
<p><strong>评估陷阱</strong>很常见,可能使整个基准比较无效。</p>
</li>
<li>
<p><strong>对测试投其所好</strong>:优化基准性能而非真正能力。在MMLU风格的多项选择上微调的模型在MMLU上得分很高,但在以开放式形式提出的相同问题上可能失败。</p>
</li>
<li>
<p><strong>指标游戏化</strong>:模型可以被优化以产生在自动指标上得分很高的输出(高BLEU、低困惑度),但并非真正优秀。BLEU最优的翻译往往是安全、通用的释义,而非自然流畅的翻译。</p>
</li>
<li>
<p><strong>基准饱和</strong>:当模型在基准上接近或超过人类表现时,该基准就不再提供信息。GLUE、SQuAD 1.1和其他几个基准现在已经饱和。</p>
</li>
<li>
<p>该领域不断创建更难的新基准,但这种创建、饱和和替换的循环使得纵向比较变得困难。</p>
</li>
<li>
<p><strong>人工评估</strong>仍然是黄金标准,但成本高、速度慢且难以复现。不同的标注者群体(众包工作者与领域专家、不同文化、不同语言)会产生不同的判断。报告标注者间一致性和标注者人口统计信息对可复现性至关重要。</p>
</li>
</ul>
<h2 id="colab">编程任务(使用CoLab或笔记本)<a class="headerlink" href="#colab" title="Permanent link">&para;</a></h2>
<ol>
<li>
<p>从头实现一个完整的Transformer编码器块(多头注意力、前馈网络、残差连接、层归一化)。将其应用于一个简单的序列分类任务。
<div class="highlight"><pre><span></span><code><a id="__codelineno-0-1" name="__codelineno-0-1" href="#__codelineno-0-1"></a><span class="kn">import</span><span class="w"> </span><span class="nn">jax</span>
<a id="__codelineno-0-2" name="__codelineno-0-2" href="#__codelineno-0-2"></a><span class="kn">import</span><span class="w"> </span><span class="nn">jax.numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">jnp</span>
<a id="__codelineno-0-3" name="__codelineno-0-3" href="#__codelineno-0-3"></a><span class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span>
<a id="__codelineno-0-4" name="__codelineno-0-4" href="#__codelineno-0-4"></a>
<a id="__codelineno-0-5" name="__codelineno-0-5" href="#__codelineno-0-5"></a><span class="k">def</span><span class="w"> </span><span class="nf">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">gamma</span><span class="p">,</span> <span class="n">beta</span><span class="p">,</span> <span class="n">eps</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">):</span>
<a id="__codelineno-0-6" name="__codelineno-0-6" href="#__codelineno-0-6"></a> <span class="n">mean</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<a id="__codelineno-0-7" name="__codelineno-0-7" href="#__codelineno-0-7"></a> <span class="n">var</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">var</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<a id="__codelineno-0-8" name="__codelineno-0-8" href="#__codelineno-0-8"></a> <span class="k">return</span> <span class="n">gamma</span> <span class="o">*</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span> <span class="o">+</span> <span class="n">beta</span>
<a id="__codelineno-0-9" name="__codelineno-0-9" href="#__codelineno-0-9"></a>
<a id="__codelineno-0-10" name="__codelineno-0-10" href="#__codelineno-0-10"></a><span class="k">def</span><span class="w"> </span><span class="nf">multi_head_attention</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">W_q</span><span class="p">,</span> <span class="n">W_k</span><span class="p">,</span> <span class="n">W_v</span><span class="p">,</span> <span class="n">W_o</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">):</span>
<a id="__codelineno-0-11" name="__codelineno-0-11" href="#__codelineno-0-11"></a> <span class="n">B</span><span class="p">,</span> <span class="n">T</span><span class="p">,</span> <span class="n">D</span> <span class="o">=</span> <span class="n">Q</span><span class="o">.</span><span class="n">shape</span>
<a id="__codelineno-0-12" name="__codelineno-0-12" href="#__codelineno-0-12"></a> <span class="n">head_dim</span> <span class="o">=</span> <span class="n">D</span> <span class="o">//</span> <span class="n">n_heads</span>
<a id="__codelineno-0-13" name="__codelineno-0-13" href="#__codelineno-0-13"></a>
<a id="__codelineno-0-14" name="__codelineno-0-14" href="#__codelineno-0-14"></a> <span class="n">q</span> <span class="o">=</span> <span class="n">Q</span> <span class="o">@</span> <span class="n">W_q</span> <span class="c1"># (B, T, D)</span>
<a id="__codelineno-0-15" name="__codelineno-0-15" href="#__codelineno-0-15"></a> <span class="n">k</span> <span class="o">=</span> <span class="n">K</span> <span class="o">@</span> <span class="n">W_k</span>
<a id="__codelineno-0-16" name="__codelineno-0-16" href="#__codelineno-0-16"></a> <span class="n">v</span> <span class="o">=</span> <span class="n">V</span> <span class="o">@</span> <span class="n">W_v</span>
<a id="__codelineno-0-17" name="__codelineno-0-17" href="#__codelineno-0-17"></a>
<a id="__codelineno-0-18" name="__codelineno-0-18" href="#__codelineno-0-18"></a> <span class="c1"># Reshape to (B, n_heads, T, head_dim)</span>
<a id="__codelineno-0-19" name="__codelineno-0-19" href="#__codelineno-0-19"></a> <span class="n">q</span> <span class="o">=</span> <span class="n">q</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">T</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">head_dim</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<a id="__codelineno-0-20" name="__codelineno-0-20" href="#__codelineno-0-20"></a> <span class="n">k</span> <span class="o">=</span> <span class="n">k</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">T</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">head_dim</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<a id="__codelineno-0-21" name="__codelineno-0-21" href="#__codelineno-0-21"></a> <span class="n">v</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">T</span><span class="p">,</span> <span class="n">n_heads</span><span class="p">,</span> <span class="n">head_dim</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<a id="__codelineno-0-22" name="__codelineno-0-22" href="#__codelineno-0-22"></a>
<a id="__codelineno-0-23" name="__codelineno-0-23" href="#__codelineno-0-23"></a> <span class="n">scores</span> <span class="o">=</span> <span class="n">q</span> <span class="o">@</span> <span class="n">k</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">head_dim</span><span class="p">)</span>
<a id="__codelineno-0-24" name="__codelineno-0-24" href="#__codelineno-0-24"></a> <span class="n">weights</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<a id="__codelineno-0-25" name="__codelineno-0-25" href="#__codelineno-0-25"></a> <span class="n">out</span> <span class="o">=</span> <span class="p">(</span><span class="n">weights</span> <span class="o">@</span> <span class="n">v</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">T</span><span class="p">,</span> <span class="n">D</span><span class="p">)</span>
<a id="__codelineno-0-26" name="__codelineno-0-26" href="#__codelineno-0-26"></a> <span class="k">return</span> <span class="n">out</span> <span class="o">@</span> <span class="n">W_o</span><span class="p">,</span> <span class="n">weights</span>
<a id="__codelineno-0-27" name="__codelineno-0-27" href="#__codelineno-0-27"></a>
<a id="__codelineno-0-28" name="__codelineno-0-28" href="#__codelineno-0-28"></a><span class="k">def</span><span class="w"> </span><span class="nf">transformer_block</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">params</span><span class="p">):</span>
<a id="__codelineno-0-29" name="__codelineno-0-29" href="#__codelineno-0-29"></a> <span class="c1"># Pre-norm multi-head self-attention</span>
<a id="__codelineno-0-30" name="__codelineno-0-30" href="#__codelineno-0-30"></a> <span class="n">normed</span> <span class="o">=</span> <span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;ln1_g&#39;</span><span class="p">],</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;ln1_b&#39;</span><span class="p">])</span>
<a id="__codelineno-0-31" name="__codelineno-0-31" href="#__codelineno-0-31"></a> <span class="n">attn_out</span><span class="p">,</span> <span class="n">weights</span> <span class="o">=</span> <span class="n">multi_head_attention</span><span class="p">(</span>
<a id="__codelineno-0-32" name="__codelineno-0-32" href="#__codelineno-0-32"></a> <span class="n">normed</span><span class="p">,</span> <span class="n">normed</span><span class="p">,</span> <span class="n">normed</span><span class="p">,</span>
<a id="__codelineno-0-33" name="__codelineno-0-33" href="#__codelineno-0-33"></a> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;W_q&#39;</span><span class="p">],</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;W_k&#39;</span><span class="p">],</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;W_v&#39;</span><span class="p">],</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;W_o&#39;</span><span class="p">],</span>
<a id="__codelineno-0-34" name="__codelineno-0-34" href="#__codelineno-0-34"></a> <span class="n">n_heads</span><span class="o">=</span><span class="mi">4</span>
<a id="__codelineno-0-35" name="__codelineno-0-35" href="#__codelineno-0-35"></a> <span class="p">)</span>
<a id="__codelineno-0-36" name="__codelineno-0-36" href="#__codelineno-0-36"></a> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">attn_out</span>
<a id="__codelineno-0-37" name="__codelineno-0-37" href="#__codelineno-0-37"></a>
<a id="__codelineno-0-38" name="__codelineno-0-38" href="#__codelineno-0-38"></a> <span class="c1"># Pre-norm feed-forward</span>
<a id="__codelineno-0-39" name="__codelineno-0-39" href="#__codelineno-0-39"></a> <span class="n">normed</span> <span class="o">=</span> <span class="n">layer_norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;ln2_g&#39;</span><span class="p">],</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;ln2_b&#39;</span><span class="p">])</span>
<a id="__codelineno-0-40" name="__codelineno-0-40" href="#__codelineno-0-40"></a> <span class="n">ff</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">gelu</span><span class="p">(</span><span class="n">normed</span> <span class="o">@</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;W1&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;b1&#39;</span><span class="p">])</span>
<a id="__codelineno-0-41" name="__codelineno-0-41" href="#__codelineno-0-41"></a> <span class="n">ff</span> <span class="o">=</span> <span class="n">ff</span> <span class="o">@</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;W2&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;b2&#39;</span><span class="p">]</span>
<a id="__codelineno-0-42" name="__codelineno-0-42" href="#__codelineno-0-42"></a> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">ff</span>
<a id="__codelineno-0-43" name="__codelineno-0-43" href="#__codelineno-0-43"></a> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">weights</span>
<a id="__codelineno-0-44" name="__codelineno-0-44" href="#__codelineno-0-44"></a>
<a id="__codelineno-0-45" name="__codelineno-0-45" href="#__codelineno-0-45"></a><span class="c1"># Initialise parameters</span>
<a id="__codelineno-0-46" name="__codelineno-0-46" href="#__codelineno-0-46"></a><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">,</span> <span class="n">n_heads</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">4</span>
<a id="__codelineno-0-47" name="__codelineno-0-47" href="#__codelineno-0-47"></a><span class="n">key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<a id="__codelineno-0-48" name="__codelineno-0-48" href="#__codelineno-0-48"></a><span class="n">keys</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span>
<a id="__codelineno-0-49" name="__codelineno-0-49" href="#__codelineno-0-49"></a>
<a id="__codelineno-0-50" name="__codelineno-0-50" href="#__codelineno-0-50"></a><span class="n">params</span> <span class="o">=</span> <span class="p">{</span>
<a id="__codelineno-0-51" name="__codelineno-0-51" href="#__codelineno-0-51"></a> <span class="s1">&#39;W_q&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span> <span class="o">*</span> <span class="mf">0.05</span><span class="p">,</span>
<a id="__codelineno-0-52" name="__codelineno-0-52" href="#__codelineno-0-52"></a> <span class="s1">&#39;W_k&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span> <span class="o">*</span> <span class="mf">0.05</span><span class="p">,</span>
<a id="__codelineno-0-53" name="__codelineno-0-53" href="#__codelineno-0-53"></a> <span class="s1">&#39;W_v&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span> <span class="o">*</span> <span class="mf">0.05</span><span class="p">,</span>
<a id="__codelineno-0-54" name="__codelineno-0-54" href="#__codelineno-0-54"></a> <span class="s1">&#39;W_o&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span> <span class="o">*</span> <span class="mf">0.05</span><span class="p">,</span>
<a id="__codelineno-0-55" name="__codelineno-0-55" href="#__codelineno-0-55"></a> <span class="s1">&#39;ln1_g&#39;</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">d_model</span><span class="p">),</span> <span class="s1">&#39;ln1_b&#39;</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">d_model</span><span class="p">),</span>
<a id="__codelineno-0-56" name="__codelineno-0-56" href="#__codelineno-0-56"></a> <span class="s1">&#39;ln2_g&#39;</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">d_model</span><span class="p">),</span> <span class="s1">&#39;ln2_b&#39;</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">d_model</span><span class="p">),</span>
<a id="__codelineno-0-57" name="__codelineno-0-57" href="#__codelineno-0-57"></a> <span class="s1">&#39;W1&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">4</span><span class="p">],</span> <span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_ff</span><span class="p">))</span> <span class="o">*</span> <span class="mf">0.05</span><span class="p">,</span>
<a id="__codelineno-0-58" name="__codelineno-0-58" href="#__codelineno-0-58"></a> <span class="s1">&#39;b1&#39;</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">d_ff</span><span class="p">),</span>
<a id="__codelineno-0-59" name="__codelineno-0-59" href="#__codelineno-0-59"></a> <span class="s1">&#39;W2&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">5</span><span class="p">],</span> <span class="p">(</span><span class="n">d_ff</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span> <span class="o">*</span> <span class="mf">0.05</span><span class="p">,</span>
<a id="__codelineno-0-60" name="__codelineno-0-60" href="#__codelineno-0-60"></a> <span class="s1">&#39;b2&#39;</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">d_model</span><span class="p">),</span>
<a id="__codelineno-0-61" name="__codelineno-0-61" href="#__codelineno-0-61"></a><span class="p">}</span>
<a id="__codelineno-0-62" name="__codelineno-0-62" href="#__codelineno-0-62"></a>
<a id="__codelineno-0-63" name="__codelineno-0-63" href="#__codelineno-0-63"></a><span class="c1"># Test with random input</span>
<a id="__codelineno-0-64" name="__codelineno-0-64" href="#__codelineno-0-64"></a><span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">6</span><span class="p">],</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span> <span class="c1"># batch=2, seq_len=8</span>
<a id="__codelineno-0-65" name="__codelineno-0-65" href="#__codelineno-0-65"></a><span class="n">out</span><span class="p">,</span> <span class="n">attn_weights</span> <span class="o">=</span> <span class="n">transformer_block</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">params</span><span class="p">)</span>
<a id="__codelineno-0-66" name="__codelineno-0-66" href="#__codelineno-0-66"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Input shape: </span><span class="si">{</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-0-67" name="__codelineno-0-67" href="#__codelineno-0-67"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Output shape: </span><span class="si">{</span><span class="n">out</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-0-68" name="__codelineno-0-68" href="#__codelineno-0-68"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Attention weights shape: </span><span class="si">{</span><span class="n">attn_weights</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="c1"># (B, n_heads, T, T)</span>
<a id="__codelineno-0-69" name="__codelineno-0-69" href="#__codelineno-0-69"></a>
<a id="__codelineno-0-70" name="__codelineno-0-70" href="#__codelineno-0-70"></a><span class="c1"># Visualise attention patterns for each head</span>
<a id="__codelineno-0-71" name="__codelineno-0-71" href="#__codelineno-0-71"></a><span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mf">3.5</span><span class="p">))</span>
<a id="__codelineno-0-72" name="__codelineno-0-72" href="#__codelineno-0-72"></a><span class="k">for</span> <span class="n">h</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">):</span>
<a id="__codelineno-0-73" name="__codelineno-0-73" href="#__codelineno-0-73"></a> <span class="n">im</span> <span class="o">=</span> <span class="n">axes</span><span class="p">[</span><span class="n">h</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">attn_weights</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">h</span><span class="p">],</span> <span class="n">cmap</span><span class="o">=</span><span class="s1">&#39;Blues&#39;</span><span class="p">,</span> <span class="n">vmin</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<a id="__codelineno-0-74" name="__codelineno-0-74" href="#__codelineno-0-74"></a> <span class="n">axes</span><span class="p">[</span><span class="n">h</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Head </span><span class="si">{</span><span class="n">h</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-0-75" name="__codelineno-0-75" href="#__codelineno-0-75"></a> <span class="n">axes</span><span class="p">[</span><span class="n">h</span><span class="p">]</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s2">&quot;Key pos&quot;</span><span class="p">);</span> <span class="n">axes</span><span class="p">[</span><span class="n">h</span><span class="p">]</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">&quot;Query pos&quot;</span><span class="p">)</span>
<a id="__codelineno-0-76" name="__codelineno-0-76" href="#__codelineno-0-76"></a><span class="n">plt</span><span class="o">.</span><span class="n">suptitle</span><span class="p">(</span><span class="s2">&quot;Multi-Head Attention Patterns&quot;</span><span class="p">)</span>
<a id="__codelineno-0-77" name="__codelineno-0-77" href="#__codelineno-0-77"></a><span class="n">plt</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">();</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></p>
</li>
<li>
<p>实现因果(自回归)注意力掩码,并与双向注意力进行比较。展示掩码如何防止信息从未来流向过去的标记。
<div class="highlight"><pre><span></span><code><a id="__codelineno-1-1" name="__codelineno-1-1" href="#__codelineno-1-1"></a><span class="kn">import</span><span class="w"> </span><span class="nn">jax</span>
<a id="__codelineno-1-2" name="__codelineno-1-2" href="#__codelineno-1-2"></a><span class="kn">import</span><span class="w"> </span><span class="nn">jax.numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">jnp</span>
<a id="__codelineno-1-3" name="__codelineno-1-3" href="#__codelineno-1-3"></a><span class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span>
<a id="__codelineno-1-4" name="__codelineno-1-4" href="#__codelineno-1-4"></a>
<a id="__codelineno-1-5" name="__codelineno-1-5" href="#__codelineno-1-5"></a><span class="k">def</span><span class="w"> </span><span class="nf">attention</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
<a id="__codelineno-1-6" name="__codelineno-1-6" href="#__codelineno-1-6"></a> <span class="n">d_k</span> <span class="o">=</span> <span class="n">Q</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<a id="__codelineno-1-7" name="__codelineno-1-7" href="#__codelineno-1-7"></a> <span class="n">scores</span> <span class="o">=</span> <span class="n">Q</span> <span class="o">@</span> <span class="n">K</span><span class="o">.</span><span class="n">T</span> <span class="o">/</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">d_k</span><span class="p">)</span>
<a id="__codelineno-1-8" name="__codelineno-1-8" href="#__codelineno-1-8"></a> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<a id="__codelineno-1-9" name="__codelineno-1-9" href="#__codelineno-1-9"></a> <span class="n">scores</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">scores</span><span class="p">,</span> <span class="o">-</span><span class="mf">1e9</span><span class="p">)</span>
<a id="__codelineno-1-10" name="__codelineno-1-10" href="#__codelineno-1-10"></a> <span class="n">weights</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<a id="__codelineno-1-11" name="__codelineno-1-11" href="#__codelineno-1-11"></a> <span class="k">return</span> <span class="n">weights</span> <span class="o">@</span> <span class="n">V</span><span class="p">,</span> <span class="n">weights</span>
<a id="__codelineno-1-12" name="__codelineno-1-12" href="#__codelineno-1-12"></a>
<a id="__codelineno-1-13" name="__codelineno-1-13" href="#__codelineno-1-13"></a><span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span> <span class="o">=</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">8</span>
<a id="__codelineno-1-14" name="__codelineno-1-14" href="#__codelineno-1-14"></a><span class="n">key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span>
<a id="__codelineno-1-15" name="__codelineno-1-15" href="#__codelineno-1-15"></a><span class="n">k1</span><span class="p">,</span> <span class="n">k2</span><span class="p">,</span> <span class="n">k3</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<a id="__codelineno-1-16" name="__codelineno-1-16" href="#__codelineno-1-16"></a><span class="n">Q</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">k1</span><span class="p">,</span> <span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span>
<a id="__codelineno-1-17" name="__codelineno-1-17" href="#__codelineno-1-17"></a><span class="n">K</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">k2</span><span class="p">,</span> <span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span>
<a id="__codelineno-1-18" name="__codelineno-1-18" href="#__codelineno-1-18"></a><span class="n">V</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">k3</span><span class="p">,</span> <span class="p">(</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span>
<a id="__codelineno-1-19" name="__codelineno-1-19" href="#__codelineno-1-19"></a>
<a id="__codelineno-1-20" name="__codelineno-1-20" href="#__codelineno-1-20"></a><span class="c1"># Bidirectional (encoder-style): all positions visible</span>
<a id="__codelineno-1-21" name="__codelineno-1-21" href="#__codelineno-1-21"></a><span class="n">bidir_mask</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
<a id="__codelineno-1-22" name="__codelineno-1-22" href="#__codelineno-1-22"></a><span class="n">bidir_out</span><span class="p">,</span> <span class="n">bidir_weights</span> <span class="o">=</span> <span class="n">attention</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">bidir_mask</span><span class="p">)</span>
<a id="__codelineno-1-23" name="__codelineno-1-23" href="#__codelineno-1-23"></a>
<a id="__codelineno-1-24" name="__codelineno-1-24" href="#__codelineno-1-24"></a><span class="c1"># Causal (decoder-style): only past and current positions visible</span>
<a id="__codelineno-1-25" name="__codelineno-1-25" href="#__codelineno-1-25"></a><span class="n">causal_mask</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">tril</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">seq_len</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">))</span>
<a id="__codelineno-1-26" name="__codelineno-1-26" href="#__codelineno-1-26"></a><span class="n">causal_out</span><span class="p">,</span> <span class="n">causal_weights</span> <span class="o">=</span> <span class="n">attention</span><span class="p">(</span><span class="n">Q</span><span class="p">,</span> <span class="n">K</span><span class="p">,</span> <span class="n">V</span><span class="p">,</span> <span class="n">causal_mask</span><span class="p">)</span>
<a id="__codelineno-1-27" name="__codelineno-1-27" href="#__codelineno-1-27"></a>
<a id="__codelineno-1-28" name="__codelineno-1-28" href="#__codelineno-1-28"></a><span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">14</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
<a id="__codelineno-1-29" name="__codelineno-1-29" href="#__codelineno-1-29"></a><span class="n">tokens</span> <span class="o">=</span> <span class="p">[</span><span class="sa">f</span><span class="s2">&quot;t</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">&quot;</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">)]</span>
<a id="__codelineno-1-30" name="__codelineno-1-30" href="#__codelineno-1-30"></a>
<a id="__codelineno-1-31" name="__codelineno-1-31" href="#__codelineno-1-31"></a><span class="n">axes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">bidir_weights</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s1">&#39;Blues&#39;</span><span class="p">,</span> <span class="n">vmin</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">vmax</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
<a id="__codelineno-1-32" name="__codelineno-1-32" href="#__codelineno-1-32"></a><span class="n">axes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Bidirectional Attention</span><span class="se">\n</span><span class="s2">(BERT-style)&quot;</span><span class="p">)</span>
<a id="__codelineno-1-33" name="__codelineno-1-33" href="#__codelineno-1-33"></a><span class="n">axes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_xticks</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">));</span> <span class="n">axes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_xticklabels</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span>
<a id="__codelineno-1-34" name="__codelineno-1-34" href="#__codelineno-1-34"></a><span class="n">axes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_yticks</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">));</span> <span class="n">axes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_yticklabels</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span>
<a id="__codelineno-1-35" name="__codelineno-1-35" href="#__codelineno-1-35"></a>
<a id="__codelineno-1-36" name="__codelineno-1-36" href="#__codelineno-1-36"></a><span class="n">axes</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">causal_mask</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">float</span><span class="p">),</span> <span class="n">cmap</span><span class="o">=</span><span class="s1">&#39;Greys&#39;</span><span class="p">,</span> <span class="n">vmin</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">vmax</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<a id="__codelineno-1-37" name="__codelineno-1-37" href="#__codelineno-1-37"></a><span class="n">axes</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Causal Mask</span><span class="se">\n</span><span class="s2">(1 = allowed, 0 = blocked)&quot;</span><span class="p">)</span>
<a id="__codelineno-1-38" name="__codelineno-1-38" href="#__codelineno-1-38"></a><span class="n">axes</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_xticks</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">));</span> <span class="n">axes</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_xticklabels</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span>
<a id="__codelineno-1-39" name="__codelineno-1-39" href="#__codelineno-1-39"></a><span class="n">axes</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_yticks</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">));</span> <span class="n">axes</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_yticklabels</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span>
<a id="__codelineno-1-40" name="__codelineno-1-40" href="#__codelineno-1-40"></a>
<a id="__codelineno-1-41" name="__codelineno-1-41" href="#__codelineno-1-41"></a><span class="n">axes</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">causal_weights</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s1">&#39;Blues&#39;</span><span class="p">,</span> <span class="n">vmin</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">vmax</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
<a id="__codelineno-1-42" name="__codelineno-1-42" href="#__codelineno-1-42"></a><span class="n">axes</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Causal Attention</span><span class="se">\n</span><span class="s2">(GPT-style)&quot;</span><span class="p">)</span>
<a id="__codelineno-1-43" name="__codelineno-1-43" href="#__codelineno-1-43"></a><span class="n">axes</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set_xticks</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">));</span> <span class="n">axes</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set_xticklabels</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span>
<a id="__codelineno-1-44" name="__codelineno-1-44" href="#__codelineno-1-44"></a><span class="n">axes</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set_yticks</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">seq_len</span><span class="p">));</span> <span class="n">axes</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set_yticklabels</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span>
<a id="__codelineno-1-45" name="__codelineno-1-45" href="#__codelineno-1-45"></a>
<a id="__codelineno-1-46" name="__codelineno-1-46" href="#__codelineno-1-46"></a><span class="k">for</span> <span class="n">ax</span> <span class="ow">in</span> <span class="n">axes</span><span class="p">:</span>
<a id="__codelineno-1-47" name="__codelineno-1-47" href="#__codelineno-1-47"></a> <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s2">&quot;Key&quot;</span><span class="p">);</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">&quot;Query&quot;</span><span class="p">)</span>
<a id="__codelineno-1-48" name="__codelineno-1-48" href="#__codelineno-1-48"></a><span class="n">plt</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">();</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
<a id="__codelineno-1-49" name="__codelineno-1-49" href="#__codelineno-1-49"></a>
<a id="__codelineno-1-50" name="__codelineno-1-50" href="#__codelineno-1-50"></a><span class="c1"># Verify: in causal attention, output at position i depends only on positions &lt;= i</span>
<a id="__codelineno-1-51" name="__codelineno-1-51" href="#__codelineno-1-51"></a><span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Causal attention weight at position 2 (should only attend to 0, 1, 2):&quot;</span><span class="p">)</span>
<a id="__codelineno-1-52" name="__codelineno-1-52" href="#__codelineno-1-52"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot; Weights: </span><span class="si">{</span><span class="n">causal_weights</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-1-53" name="__codelineno-1-53" href="#__codelineno-1-53"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot; Sum of future weights (should be ~0): </span><span class="si">{</span><span class="n">causal_weights</span><span class="p">[</span><span class="mi">2</span><span class="p">,</span><span class="w"> </span><span class="mi">3</span><span class="p">:]</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</code></pre></div></p>
</li>
<li>
<p>实现LoRA(低秩适配),并展示它如何以远少于全量微调的可训练参数来修改权重矩阵。
<div class="highlight"><pre><span></span><code><a id="__codelineno-2-1" name="__codelineno-2-1" href="#__codelineno-2-1"></a><span class="kn">import</span><span class="w"> </span><span class="nn">jax</span>
<a id="__codelineno-2-2" name="__codelineno-2-2" href="#__codelineno-2-2"></a><span class="kn">import</span><span class="w"> </span><span class="nn">jax.numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">jnp</span>
<a id="__codelineno-2-3" name="__codelineno-2-3" href="#__codelineno-2-3"></a>
<a id="__codelineno-2-4" name="__codelineno-2-4" href="#__codelineno-2-4"></a><span class="n">d_model</span> <span class="o">=</span> <span class="mi">256</span>
<a id="__codelineno-2-5" name="__codelineno-2-5" href="#__codelineno-2-5"></a><span class="n">rank</span> <span class="o">=</span> <span class="mi">4</span> <span class="c1"># LoRA rank (much smaller than d_model)</span>
<a id="__codelineno-2-6" name="__codelineno-2-6" href="#__codelineno-2-6"></a>
<a id="__codelineno-2-7" name="__codelineno-2-7" href="#__codelineno-2-7"></a><span class="n">key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<a id="__codelineno-2-8" name="__codelineno-2-8" href="#__codelineno-2-8"></a><span class="n">k1</span><span class="p">,</span> <span class="n">k2</span><span class="p">,</span> <span class="n">k3</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<a id="__codelineno-2-9" name="__codelineno-2-9" href="#__codelineno-2-9"></a>
<a id="__codelineno-2-10" name="__codelineno-2-10" href="#__codelineno-2-10"></a><span class="c1"># Original frozen weight matrix</span>
<a id="__codelineno-2-11" name="__codelineno-2-11" href="#__codelineno-2-11"></a><span class="n">W_frozen</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">k1</span><span class="p">,</span> <span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span> <span class="o">*</span> <span class="mf">0.02</span>
<a id="__codelineno-2-12" name="__codelineno-2-12" href="#__codelineno-2-12"></a>
<a id="__codelineno-2-13" name="__codelineno-2-13" href="#__codelineno-2-13"></a><span class="c1"># LoRA matrices (only these are trainable)</span>
<a id="__codelineno-2-14" name="__codelineno-2-14" href="#__codelineno-2-14"></a><span class="n">B</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">d_model</span><span class="p">,</span> <span class="n">rank</span><span class="p">))</span> <span class="c1"># initialised to zero</span>
<a id="__codelineno-2-15" name="__codelineno-2-15" href="#__codelineno-2-15"></a><span class="n">A</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">k2</span><span class="p">,</span> <span class="p">(</span><span class="n">rank</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span> <span class="o">*</span> <span class="mf">0.01</span> <span class="c1"># random init</span>
<a id="__codelineno-2-16" name="__codelineno-2-16" href="#__codelineno-2-16"></a>
<a id="__codelineno-2-17" name="__codelineno-2-17" href="#__codelineno-2-17"></a><span class="c1"># Forward pass: W_effective = W_frozen + B @ A</span>
<a id="__codelineno-2-18" name="__codelineno-2-18" href="#__codelineno-2-18"></a><span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">k3</span><span class="p">,</span> <span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span>
<a id="__codelineno-2-19" name="__codelineno-2-19" href="#__codelineno-2-19"></a>
<a id="__codelineno-2-20" name="__codelineno-2-20" href="#__codelineno-2-20"></a><span class="c1"># Without LoRA</span>
<a id="__codelineno-2-21" name="__codelineno-2-21" href="#__codelineno-2-21"></a><span class="n">y_original</span> <span class="o">=</span> <span class="n">x</span> <span class="o">@</span> <span class="n">W_frozen</span><span class="o">.</span><span class="n">T</span>
<a id="__codelineno-2-22" name="__codelineno-2-22" href="#__codelineno-2-22"></a>
<a id="__codelineno-2-23" name="__codelineno-2-23" href="#__codelineno-2-23"></a><span class="c1"># With LoRA</span>
<a id="__codelineno-2-24" name="__codelineno-2-24" href="#__codelineno-2-24"></a><span class="n">W_effective</span> <span class="o">=</span> <span class="n">W_frozen</span> <span class="o">+</span> <span class="n">B</span> <span class="o">@</span> <span class="n">A</span>
<a id="__codelineno-2-25" name="__codelineno-2-25" href="#__codelineno-2-25"></a><span class="n">y_lora</span> <span class="o">=</span> <span class="n">x</span> <span class="o">@</span> <span class="n">W_effective</span><span class="o">.</span><span class="n">T</span>
<a id="__codelineno-2-26" name="__codelineno-2-26" href="#__codelineno-2-26"></a>
<a id="__codelineno-2-27" name="__codelineno-2-27" href="#__codelineno-2-27"></a><span class="c1"># Parameter counts</span>
<a id="__codelineno-2-28" name="__codelineno-2-28" href="#__codelineno-2-28"></a><span class="n">full_params</span> <span class="o">=</span> <span class="n">d_model</span> <span class="o">*</span> <span class="n">d_model</span>
<a id="__codelineno-2-29" name="__codelineno-2-29" href="#__codelineno-2-29"></a><span class="n">lora_params</span> <span class="o">=</span> <span class="n">d_model</span> <span class="o">*</span> <span class="n">rank</span> <span class="o">+</span> <span class="n">rank</span> <span class="o">*</span> <span class="n">d_model</span> <span class="c1"># B + A</span>
<a id="__codelineno-2-30" name="__codelineno-2-30" href="#__codelineno-2-30"></a>
<a id="__codelineno-2-31" name="__codelineno-2-31" href="#__codelineno-2-31"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Model dimension: </span><span class="si">{</span><span class="n">d_model</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-2-32" name="__codelineno-2-32" href="#__codelineno-2-32"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;LoRA rank: </span><span class="si">{</span><span class="n">rank</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-2-33" name="__codelineno-2-33" href="#__codelineno-2-33"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Full fine-tuning parameters: </span><span class="si">{</span><span class="n">full_params</span><span class="si">:</span><span class="s2">,</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-2-34" name="__codelineno-2-34" href="#__codelineno-2-34"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;LoRA parameters: </span><span class="si">{</span><span class="n">lora_params</span><span class="si">:</span><span class="s2">,</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-2-35" name="__codelineno-2-35" href="#__codelineno-2-35"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Parameter reduction: </span><span class="si">{</span><span class="n">full_params</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="n">lora_params</span><span class="si">:</span><span class="s2">.1f</span><span class="si">}</span><span class="s2">x&quot;</span><span class="p">)</span>
<a id="__codelineno-2-36" name="__codelineno-2-36" href="#__codelineno-2-36"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Since B is initialised to zeros, initial LoRA output matches original:&quot;</span><span class="p">)</span>
<a id="__codelineno-2-37" name="__codelineno-2-37" href="#__codelineno-2-37"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot; Max difference: </span><span class="si">{</span><span class="n">jnp</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">y_original</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="n">y_lora</span><span class="p">)</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="si">:</span><span class="s2">.2e</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-2-38" name="__codelineno-2-38" href="#__codelineno-2-38"></a>
<a id="__codelineno-2-39" name="__codelineno-2-39" href="#__codelineno-2-39"></a><span class="c1"># Simulate training: only update A and B</span>
<a id="__codelineno-2-40" name="__codelineno-2-40" href="#__codelineno-2-40"></a><span class="k">def</span><span class="w"> </span><span class="nf">lora_forward</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">W_frozen</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
<a id="__codelineno-2-41" name="__codelineno-2-41" href="#__codelineno-2-41"></a> <span class="k">return</span> <span class="n">x</span> <span class="o">@</span> <span class="p">(</span><span class="n">W_frozen</span> <span class="o">+</span> <span class="n">B</span> <span class="o">@</span> <span class="n">A</span><span class="p">)</span><span class="o">.</span><span class="n">T</span>
<a id="__codelineno-2-42" name="__codelineno-2-42" href="#__codelineno-2-42"></a>
<a id="__codelineno-2-43" name="__codelineno-2-43" href="#__codelineno-2-43"></a><span class="k">def</span><span class="w"> </span><span class="nf">dummy_loss</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">W_frozen</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">target</span><span class="p">):</span>
<a id="__codelineno-2-44" name="__codelineno-2-44" href="#__codelineno-2-44"></a> <span class="n">pred</span> <span class="o">=</span> <span class="n">lora_forward</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">W_frozen</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
<a id="__codelineno-2-45" name="__codelineno-2-45" href="#__codelineno-2-45"></a> <span class="k">return</span> <span class="n">jnp</span><span class="o">.</span><span class="n">mean</span><span class="p">((</span><span class="n">pred</span> <span class="o">-</span> <span class="n">target</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span>
<a id="__codelineno-2-46" name="__codelineno-2-46" href="#__codelineno-2-46"></a>
<a id="__codelineno-2-47" name="__codelineno-2-47" href="#__codelineno-2-47"></a><span class="c1"># Target: some transformation of x</span>
<a id="__codelineno-2-48" name="__codelineno-2-48" href="#__codelineno-2-48"></a><span class="n">target</span> <span class="o">=</span> <span class="n">x</span> <span class="o">@</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">99</span><span class="p">),</span> <span class="p">(</span><span class="n">d_model</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span><span class="o">.</span><span class="n">T</span> <span class="o">*</span> <span class="mf">0.02</span>
<a id="__codelineno-2-49" name="__codelineno-2-49" href="#__codelineno-2-49"></a>
<a id="__codelineno-2-50" name="__codelineno-2-50" href="#__codelineno-2-50"></a><span class="n">grad_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">dummy_loss</span><span class="p">,</span> <span class="n">argnums</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span>
<a id="__codelineno-2-51" name="__codelineno-2-51" href="#__codelineno-2-51"></a><span class="n">lr</span> <span class="o">=</span> <span class="mf">0.01</span>
<a id="__codelineno-2-52" name="__codelineno-2-52" href="#__codelineno-2-52"></a>
<a id="__codelineno-2-53" name="__codelineno-2-53" href="#__codelineno-2-53"></a><span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">200</span><span class="p">):</span>
<a id="__codelineno-2-54" name="__codelineno-2-54" href="#__codelineno-2-54"></a> <span class="n">gA</span><span class="p">,</span> <span class="n">gB</span> <span class="o">=</span> <span class="n">grad_fn</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">W_frozen</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
<a id="__codelineno-2-55" name="__codelineno-2-55" href="#__codelineno-2-55"></a> <span class="n">A</span> <span class="o">=</span> <span class="n">A</span> <span class="o">-</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">gA</span>
<a id="__codelineno-2-56" name="__codelineno-2-56" href="#__codelineno-2-56"></a> <span class="n">B</span> <span class="o">=</span> <span class="n">B</span> <span class="o">-</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">gB</span>
<a id="__codelineno-2-57" name="__codelineno-2-57" href="#__codelineno-2-57"></a>
<a id="__codelineno-2-58" name="__codelineno-2-58" href="#__codelineno-2-58"></a><span class="n">loss_before</span> <span class="o">=</span> <span class="n">dummy_loss</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">A</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">B</span><span class="p">),</span> <span class="n">W_frozen</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
<a id="__codelineno-2-59" name="__codelineno-2-59" href="#__codelineno-2-59"></a><span class="n">loss_after</span> <span class="o">=</span> <span class="n">dummy_loss</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">W_frozen</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
<a id="__codelineno-2-60" name="__codelineno-2-60" href="#__codelineno-2-60"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Loss before LoRA: </span><span class="si">{</span><span class="n">loss_before</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-2-61" name="__codelineno-2-61" href="#__codelineno-2-61"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Loss after LoRA: </span><span class="si">{</span><span class="n">loss_after</span><span class="si">:</span><span class="s2">.6f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-2-62" name="__codelineno-2-62" href="#__codelineno-2-62"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Effective weight change rank: </span><span class="si">{</span><span class="n">jnp</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">matrix_rank</span><span class="p">(</span><span class="n">B</span><span class="w"> </span><span class="o">@</span><span class="w"> </span><span class="n">A</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</code></pre></div></p>
</li>
</ol>
</article>
</div>
<script>var target=document.getElementById(location.hash.slice(1));target&&target.name&&(target.checked=target.name.startsWith("__tabbed_"))</script>
</div>
<button type="button" class="md-top md-icon" data-md-component="top" hidden>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8z"/></svg>
回到页面顶部
</button>
</main>
<footer class="md-footer">
<nav class="md-footer__inner md-grid" aria-label="页脚" >
<a href="../03.%20embeddings%20and%20sequence%20models/" class="md-footer__link md-footer__link--prev" aria-label="上一页: 嵌入与序列模型">
<div class="md-footer__button md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
</div>
<div class="md-footer__title">
<span class="md-footer__direction">
上一页
</span>
<div class="md-ellipsis">
嵌入与序列模型
</div>
</div>
</a>
<a href="../05.%20advanced%20text%20generation/" class="md-footer__link md-footer__link--next" aria-label="下一页: 高级文本生成">
<div class="md-footer__title">
<span class="md-footer__direction">
下一页
</span>
<div class="md-ellipsis">
高级文本生成
</div>
</div>
<div class="md-footer__button md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M4 11v2h12l-5.5 5.5 1.42 1.42L19.84 12l-7.92-7.92L10.5 5.5 16 11z"/></svg>
</div>
</a>
</nav>
<div class="md-footer-meta md-typeset">
<div class="md-footer-meta__inner md-grid">
<div class="md-copyright">
Made with
<a href="https://squidfunk.github.io/mkdocs-material/" target="_blank" rel="noopener">
Material for MkDocs
</a>
</div>
<div class="md-social">
<a href="https://github.com/flykhan/maths-cs-ai-compendium-zh" target="_blank" rel="noopener" title="github.com" class="md-social__link">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 512 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M173.9 397.4c0 2-2.3 3.6-5.2 3.6-3.3.3-5.6-1.3-5.6-3.6 0-2 2.3-3.6 5.2-3.6 3-.3 5.6 1.3 5.6 3.6m-31.1-4.5c-.7 2 1.3 4.3 4.3 4.9 2.6 1 5.6 0 6.2-2s-1.3-4.3-4.3-5.2c-2.6-.7-5.5.3-6.2 2.3m44.2-1.7c-2.9.7-4.9 2.6-4.6 4.9.3 2 2.9 3.3 5.9 2.6 2.9-.7 4.9-2.6 4.6-4.6-.3-1.9-3-3.2-5.9-2.9M252.8 8C114.1 8 8 113.3 8 252c0 110.9 69.8 205.8 169.5 239.2 12.8 2.3 17.3-5.6 17.3-12.1 0-6.2-.3-40.4-.3-61.4 0 0-70 15-84.7-29.8 0 0-11.4-29.1-27.8-36.6 0 0-22.9-15.7 1.6-15.4 0 0 24.9 2 38.6 25.8 21.9 38.6 58.6 27.5 72.9 20.9 2.3-16 8.8-27.1 16-33.7-55.9-6.2-112.3-14.3-112.3-110.5 0-27.5 7.6-41.3 23.6-58.9-2.6-6.5-11.1-33.3 2.6-67.9 20.9-6.5 69 27 69 27 20-5.6 41.5-8.5 62.8-8.5s42.8 2.9 62.8 8.5c0 0 48.1-33.6 69-27 13.7 34.7 5.2 61.4 2.6 67.9 16 17.7 25.8 31.5 25.8 58.9 0 96.5-58.9 104.2-114.8 110.5 9.2 7.9 17 22.9 17 46.4 0 33.7-.3 75.4-.3 83.6 0 6.5 4.6 14.4 17.3 12.1C436.2 457.8 504 362.9 504 252 504 113.3 391.5 8 252.8 8M105.2 352.9c-1.3 1-1 3.3.7 5.2 1.6 1.6 3.9 2.3 5.2 1 1.3-1 1-3.3-.7-5.2-1.6-1.6-3.9-2.3-5.2-1m-10.8-8.1c-.7 1.3.3 2.9 2.3 3.9 1.6 1 3.6.7 4.3-.7.7-1.3-.3-2.9-2.3-3.9-2-.6-3.6-.3-4.3.7m32.4 35.6c-1.6 1.3-1 4.3 1.3 6.2 2.3 2.3 5.2 2.6 6.5 1 1.3-1.3.7-4.3-1.3-6.2-2.2-2.3-5.2-2.6-6.5-1m-11.4-14.7c-1.6 1-1.6 3.6 0 5.9s4.3 3.3 5.6 2.3c1.6-1.3 1.6-3.9 0-6.2-1.4-2.3-4-3.3-5.6-2"/></svg>
</a>
</div>
</div>
</div>
</footer>
</div>
<div class="md-dialog" data-md-component="dialog">
<div class="md-dialog__inner md-typeset"></div>
</div>
<script id="__config" type="application/json">{"annotate": null, "base": "../..", "features": ["navigation.tabs", "navigation.sections", "navigation.expand", "navigation.top", "navigation.footer", "search.suggest", "search.highlight", "content.code.copy", "toc.follow"], "search": "../../assets/javascripts/workers/search.2c215733.min.js", "tags": null, "translations": {"clipboard.copied": "\u5df2\u590d\u5236", "clipboard.copy": "\u590d\u5236", "search.result.more.one": "\u5728\u8be5\u9875\u4e0a\u8fd8\u6709 1 \u4e2a\u7b26\u5408\u6761\u4ef6\u7684\u7ed3\u679c", "search.result.more.other": "\u5728\u8be5\u9875\u4e0a\u8fd8\u6709 # \u4e2a\u7b26\u5408\u6761\u4ef6\u7684\u7ed3\u679c", "search.result.none": "\u6ca1\u6709\u627e\u5230\u7b26\u5408\u6761\u4ef6\u7684\u7ed3\u679c", "search.result.one": "\u627e\u5230 1 \u4e2a\u7b26\u5408\u6761\u4ef6\u7684\u7ed3\u679c", "search.result.other": "# \u4e2a\u7b26\u5408\u6761\u4ef6\u7684\u7ed3\u679c", "search.result.placeholder": "\u952e\u5165\u4ee5\u5f00\u59cb\u641c\u7d22", "search.result.term.missing": "\u7f3a\u5c11", "select.version": "\u9009\u62e9\u5f53\u524d\u7248\u672c"}, "version": null}</script>
<script src="../../assets/javascripts/bundle.79ae519e.min.js"></script>
<script src="../../javascripts/mathjax.js"></script>
<script src="https://unpkg.com/mathjax@3/es5/tex-mml-chtml.js"></script>
</body>
</html>