Files

5503 lines
170 KiB
HTML
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
<!doctype html>
<html lang="zh" class="no-js">
<head>
  <!-- Document metadata: encoding, viewport, description/author, canonical URL and prev/next page links. -->
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width,initial-scale=1">
  <meta name="description" content="一本开源的直觉优先教科书,从零开始覆盖数学、计算机科学和人工智能(中文翻译版)。">
  <meta name="author" content="Henry Ndubuaku (flykhan 译)">
  <link rel="canonical" href="https://flykhan.github.io/maths-cs-ai-compendium-zh/chapter%2008%3A%20computer%20vision/04.%20vision%20transformers%20and%20generation/">
  <link rel="prev" href="../03.%20object%20detection%20and%20segmentation/">
  <link rel="next" href="../05.%20video%20and%203D%20vision/">
  <link rel="icon" href="../../assets/images/favicon.png">
  <meta name="generator" content="mkdocs-1.6.1, mkdocs-material-9.7.6">
  <title>ViT 与生成模型 - 数学、计算机科学与 AI 百科全书</title>
  <!-- Theme stylesheets (main layout and colour palette). -->
  <link rel="stylesheet" href="../../assets/stylesheets/main.484c7ddc.min.css">
  <link rel="stylesheet" href="../../assets/stylesheets/palette.ab4e12ef.min.css">
  <!-- Web fonts. The ampersand in the query string is written as &amp; so the attribute value holds no raw "&"; the decoded URL is unchanged. -->
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
  <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300i,400,400i,700,700i%7CRoboto+Mono:400,400i,700,700i&amp;display=fallback">
  <style>:root{--md-text-font:"Roboto";--md-code-font:"Roboto Mono"}</style>
  <!--
    Storage helpers used by other inline scripts on this page:
    __md_scope is the URL of "../.." resolved against the current location;
    __md_hash folds the char codes of a string into an integer;
    __md_get and __md_set read and write JSON values in localStorage (the default store)
    under a key built from the scope pathname, a dot and the given name.
    __md_set silently ignores storage errors.
  -->
  <script>__md_scope=new URL("../..",location),__md_hash=e=>[...e].reduce(((e,_)=>(e<<5)-e+_.charCodeAt(0)),0),__md_get=(e,_=localStorage,t=__md_scope)=>JSON.parse(_.getItem(t.pathname+"."+e)),__md_set=(e,_,t=localStorage,a=__md_scope)=>{try{t.setItem(a.pathname+"."+e,JSON.stringify(_))}catch(e){}}</script>
</head>
<body dir="ltr" data-md-color-scheme="default" data-md-color-primary="slate" data-md-color-accent="indigo">
<!-- Checkbox toggles tagged "drawer" and "search" via data-md-toggle; labels elsewhere on the page target them through their for attribute. -->
<input class="md-toggle" id="__drawer" data-md-toggle="drawer" type="checkbox" autocomplete="off">
<input class="md-toggle" id="__search" data-md-toggle="search" type="checkbox" autocomplete="off">
<label class="md-overlay" for="__drawer"></label>
<!-- Skip link pointing at the in-page anchor #transformer. -->
<div data-md-component="skip">
  <a class="md-skip" href="#transformer">
    跳转至
  </a>
</div>
<!-- Announcement slot; it has no content on this page. -->
<div data-md-component="announce">
</div>
<header class="md-header" data-md-component="header">
  <!-- Top bar: logo, drawer button, titles, colour scheme switch, search and repository link. -->
  <nav class="md-header__inner md-grid" aria-label="页眉">
    <a href="../.." title="数学、计算机科学与 AI 百科全书" class="md-header__button md-logo" aria-label="数学、计算机科学与 AI 百科全书" data-md-component="logo">
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a3 3 0 0 0 3-3 3 3 0 0 0-3-3 3 3 0 0 0-3 3 3 3 0 0 0 3 3m0 3.54C9.64 9.35 6.5 8 3 8v11c3.5 0 6.64 1.35 9 3.54 2.36-2.19 5.5-3.54 9-3.54V8c-3.5 0-6.64 1.35-9 3.54"/></svg>
    </a>
    <!-- Label for the #__drawer checkbox. -->
    <label class="md-header__button md-icon" for="__drawer">
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3 6h18v2H3zm0 5h18v2H3zm0 5h18v2H3z"/></svg>
    </label>
    <!-- Site title followed by the current page title. -->
    <div class="md-header__title" data-md-component="header-title">
      <div class="md-header__ellipsis">
        <div class="md-header__topic">
          <span class="md-ellipsis">
            数学、计算机科学与 AI 百科全书
          </span>
        </div>
        <div class="md-header__topic" data-md-component="header-topic">
          <span class="md-ellipsis">
            ViT 与生成模型
          </span>
        </div>
      </div>
    </div>
    <!-- Colour scheme switch: two radio inputs sharing name="__palette"; each (initially hidden) label targets the other input. -->
    <form class="md-header__option" data-md-component="palette">
      <input class="md-option" data-md-color-media="" data-md-color-scheme="default" data-md-color-primary="slate" data-md-color-accent="indigo" aria-label="切换到深色模式" type="radio" name="__palette" id="__palette_0">
      <label class="md-header__button md-icon" title="切换到深色模式" for="__palette_1" hidden>
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a4 4 0 0 0-4 4 4 4 0 0 0 4 4 4 4 0 0 0 4-4 4 4 0 0 0-4-4m0 10a6 6 0 0 1-6-6 6 6 0 0 1 6-6 6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
      </label>
      <input class="md-option" data-md-color-media="" data-md-color-scheme="slate" data-md-color-primary="slate" data-md-color-accent="indigo" aria-label="切换到浅色模式" type="radio" name="__palette" id="__palette_1">
      <label class="md-header__button md-icon" title="切换到浅色模式" for="__palette_0" hidden>
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 18c-.89 0-1.74-.2-2.5-.55C11.56 16.5 13 14.42 13 12s-1.44-4.5-3.5-5.45C10.26 6.2 11.11 6 12 6a6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
      </label>
    </form>
    <!-- Restores a saved palette: reads "__palette" through __md_get and copies every stored colour entry onto the body element as a data-md-color attribute. When the stored media is "(prefers-color-scheme)", the entry is first resolved from the input matching the current light or dark preference. -->
    <script>var palette=__md_get("__palette");if(palette&&palette.color){if("(prefers-color-scheme)"===palette.color.media){var media=matchMedia("(prefers-color-scheme: light)"),input=document.querySelector(media.matches?"[data-md-color-media='(prefers-color-scheme: light)']":"[data-md-color-media='(prefers-color-scheme: dark)']");palette.color.media=input.getAttribute("data-md-color-media"),palette.color.scheme=input.getAttribute("data-md-color-scheme"),palette.color.primary=input.getAttribute("data-md-color-primary"),palette.color.accent=input.getAttribute("data-md-color-accent")}for(var[key,value]of Object.entries(palette.color))document.body.setAttribute("data-md-color-"+key,value)}</script>
    <!-- Label for the #__search checkbox. -->
    <label class="md-header__button md-icon" for="__search">
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
    </label>
    <!-- Search dialog. The dialog role needs an accessible name, so it reuses the label text of the search input. -->
    <div class="md-search" data-md-component="search" role="dialog" aria-label="搜索">
      <label class="md-search__overlay" for="__search"></label>
      <div class="md-search__inner" role="search">
        <form class="md-search__form" name="search">
          <input type="text" class="md-search__input" name="query" aria-label="搜索" placeholder="搜索" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="search-query" required>
          <label class="md-search__icon md-icon" for="__search">
            <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
            <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
          </label>
          <nav class="md-search__options" aria-label="查找">
            <button type="reset" class="md-search__icon md-icon" title="清空当前内容" aria-label="清空当前内容" tabindex="-1">
              <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z"/></svg>
            </button>
          </nav>
          <div class="md-search__suggest" data-md-component="search-suggest"></div>
        </form>
        <div class="md-search__output">
          <div class="md-search__scrollwrap" tabindex="0" data-md-scrollfix>
            <div class="md-search-result" data-md-component="search-result">
              <div class="md-search-result__meta">
                正在初始化搜索引擎
              </div>
              <ol class="md-search-result__list" role="presentation"></ol>
            </div>
          </div>
        </div>
      </div>
    </div>
    <!-- Link to the source repository. -->
    <div class="md-header__source">
      <a href="https://github.com/flykhan/maths-cs-ai-compendium-zh" title="前往仓库" class="md-source" data-md-component="source">
        <div class="md-source__icon md-icon">
          <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
        </div>
        <div class="md-source__repository">
          flykhan/maths-cs-ai-compendium-zh
        </div>
      </a>
    </div>
  </nav>
</header>
<div class="md-container" data-md-component="container">
<!-- Chapter tabs: the home page plus the first page of every chapter; the current chapter carries the active modifier class. -->
<nav class="md-tabs" data-md-component="tabs" aria-label="标签">
  <div class="md-grid">
    <ul class="md-tabs__list">
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../..">
          首页
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2001%3A%20vectors/01.%20vector%20spaces/">
          向量
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2002%3A%20matrices/01.%20matrix%20properties/">
          矩阵
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2003%3A%20calculus/01.%20differential%20calculus/">
          微积分
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2004%3A%20statistics/01.%20fundamentals/">
          统计学
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2005%3A%20probability/01.%20counting/">
          概率论
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2006%3A%20machine%20learning/01.%20classical%20machine%20learning/">
          机器学习
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2007%3A%20computational%20linguistics/01.%20linguistic%20foundations/">
          计算语言学
        </a>
      </li>
      <li class="md-tabs__item md-tabs__item--active">
        <a class="md-tabs__link" href="../01.%20image%20fundamentals/">
          计算机视觉
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2009%3A%20audio%20and%20speech/01.%20digital%20signal%20processing/">
          音频与语音
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/">
          多模态学习
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2011%3A%20autonomous%20systems/01.%20perception/">
          自主系统
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2012%3A%20graph%20neural%20networks/01.%20geometric%20deep%20learning/">
          图神经网络
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/">
          计算机与操作系统
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/">
          数据结构与算法
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/">
          生产级软件工程
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/">
          SIMD 与 GPU 编程
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2017%3A%20AI%20inference/01.%20quantisation/">
          AI 推理
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/">
          ML 系统设计
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2019%3A%20applied%20AI/01.%20AI%20for%20finance/">
          应用 AI
        </a>
      </li>
      <li class="md-tabs__item">
        <a class="md-tabs__link" href="../../chapter%2020%3A%20bleeding%20edge%20AI/01.%20quantum%20machine%20learning/">
          前沿 AI
        </a>
      </li>
    </ul>
  </div>
</nav>
<main class="md-main" data-md-component="main">
<div class="md-main__inner md-grid">
<div class="md-sidebar md-sidebar--primary" data-md-component="sidebar" data-md-type="navigation" >
<div class="md-sidebar__scrollwrap">
<div class="md-sidebar__inner">
<nav class="md-nav md-nav--primary md-nav--lifted" aria-label="导航栏" data-md-level="0">
<!-- Sidebar heading: a label for the #__drawer checkbox holding the logo link and site name, followed by the repository link. -->
<label class="md-nav__title" for="__drawer">
  <a class="md-nav__button md-logo" href="../.." title="数学、计算机科学与 AI 百科全书" aria-label="数学、计算机科学与 AI 百科全书" data-md-component="logo">
    <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a3 3 0 0 0 3-3 3 3 0 0 0-3-3 3 3 0 0 0-3 3 3 3 0 0 0 3 3m0 3.54C9.64 9.35 6.5 8 3 8v11c3.5 0 6.64 1.35 9 3.54 2.36-2.19 5.5-3.54 9-3.54V8c-3.5 0-6.64 1.35-9 3.54"/></svg>
  </a>
  数学、计算机科学与 AI 百科全书
</label>
<div class="md-nav__source">
  <a class="md-source" href="https://github.com/flykhan/maths-cs-ai-compendium-zh" title="前往仓库" data-md-component="source">
    <div class="md-source__icon md-icon">
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M439.6 236.1 244 40.5c-5.4-5.5-12.8-8.5-20.4-8.5s-15 3-20.4 8.4L162.5 81l51.5 51.5c27.1-9.1 52.7 16.8 43.4 43.7l49.7 49.7c34.2-11.8 61.2 31 35.5 56.7-26.5 26.5-70.2-2.9-56-37.3L240.3 199v121.9c25.3 12.5 22.3 41.8 9.1 55-6.4 6.4-15.2 10.1-24.3 10.1s-17.8-3.6-24.3-10.1c-17.6-17.6-11.1-46.9 11.2-56v-123c-20.8-8.5-24.6-30.7-18.6-45L142.6 101 8.5 235.1C3 240.6 0 247.9 0 255.5s3 15 8.5 20.4l195.6 195.7c5.4 5.4 12.7 8.4 20.4 8.4s15-3 20.4-8.4l194.7-194.7c5.4-5.4 8.4-12.8 8.4-20.4s-3-15-8.4-20.4"/></svg>
    </div>
    <div class="md-source__repository">
      flykhan/maths-cs-ai-compendium-zh
    </div>
  </a>
</div>
<ul class="md-nav__list" data-md-scrollfix>
<!-- Sidebar entry linking to the site root. -->
<li class="md-nav__item">
  <a class="md-nav__link" href="../..">
    <span class="md-ellipsis">
      首页
    </span>
  </a>
</li>
<!-- Sidebar section "向量": checkbox toggle, its label, and links to the five pages under "chapter 01: vectors". -->
<li class="md-nav__item md-nav__item--nested">
  <input class="md-nav__toggle md-toggle md-toggle--indeterminate" id="__nav_2" type="checkbox">
  <label class="md-nav__link" id="__nav_2_label" for="__nav_2" tabindex="0">
    <span class="md-ellipsis">
      向量
    </span>
    <span class="md-nav__icon md-icon"></span>
  </label>
  <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_2_label" aria-expanded="false">
    <label class="md-nav__title" for="__nav_2">
      <span class="md-nav__icon md-icon"></span>
      向量
    </label>
    <ul class="md-nav__list" data-md-scrollfix>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2001%3A%20vectors/01.%20vector%20spaces/">
          <span class="md-ellipsis">
            向量空间
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2001%3A%20vectors/02.%20vector%20properties/">
          <span class="md-ellipsis">
            向量性质
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2001%3A%20vectors/03.%20norms%20and%20metrics/">
          <span class="md-ellipsis">
            范数与度量
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2001%3A%20vectors/04.%20products/">
          <span class="md-ellipsis">
            向量积
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2001%3A%20vectors/05.%20basis%20and%20duality/">
          <span class="md-ellipsis">
            基与对偶性
          </span>
        </a>
      </li>
    </ul>
  </nav>
</li>
<!-- Sidebar section "矩阵": checkbox toggle, its label, and links to the five pages under "chapter 02: matrices". -->
<li class="md-nav__item md-nav__item--nested">
  <input class="md-nav__toggle md-toggle md-toggle--indeterminate" id="__nav_3" type="checkbox">
  <label class="md-nav__link" id="__nav_3_label" for="__nav_3" tabindex="0">
    <span class="md-ellipsis">
      矩阵
    </span>
    <span class="md-nav__icon md-icon"></span>
  </label>
  <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_3_label" aria-expanded="false">
    <label class="md-nav__title" for="__nav_3">
      <span class="md-nav__icon md-icon"></span>
      矩阵
    </label>
    <ul class="md-nav__list" data-md-scrollfix>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2002%3A%20matrices/01.%20matrix%20properties/">
          <span class="md-ellipsis">
            矩阵性质
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2002%3A%20matrices/02.%20matrix%20types/">
          <span class="md-ellipsis">
            矩阵类型
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2002%3A%20matrices/03.%20operations/">
          <span class="md-ellipsis">
            矩阵运算
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2002%3A%20matrices/04.%20linear%20transformations/">
          <span class="md-ellipsis">
            线性变换
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2002%3A%20matrices/05.%20decompositions/">
          <span class="md-ellipsis">
            矩阵分解
          </span>
        </a>
      </li>
    </ul>
  </nav>
</li>
<!-- Sidebar section "微积分": checkbox toggle, its label, and links to the five pages under "chapter 03: calculus". -->
<li class="md-nav__item md-nav__item--nested">
  <input class="md-nav__toggle md-toggle md-toggle--indeterminate" id="__nav_4" type="checkbox">
  <label class="md-nav__link" id="__nav_4_label" for="__nav_4" tabindex="0">
    <span class="md-ellipsis">
      微积分
    </span>
    <span class="md-nav__icon md-icon"></span>
  </label>
  <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_4_label" aria-expanded="false">
    <label class="md-nav__title" for="__nav_4">
      <span class="md-nav__icon md-icon"></span>
      微积分
    </label>
    <ul class="md-nav__list" data-md-scrollfix>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2003%3A%20calculus/01.%20differential%20calculus/">
          <span class="md-ellipsis">
            微分
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2003%3A%20calculus/02.%20integral%20calculus/">
          <span class="md-ellipsis">
            积分
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2003%3A%20calculus/03.%20multivariate%20calculus/">
          <span class="md-ellipsis">
            多元微积分
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2003%3A%20calculus/04.%20function%20approximation/">
          <span class="md-ellipsis">
            函数逼近
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2003%3A%20calculus/05.%20optimisation/">
          <span class="md-ellipsis">
            优化
          </span>
        </a>
      </li>
    </ul>
  </nav>
</li>
<!-- Sidebar section "统计学": checkbox toggle, its label, and links to the five pages under "chapter 04: statistics". -->
<li class="md-nav__item md-nav__item--nested">
  <input class="md-nav__toggle md-toggle md-toggle--indeterminate" id="__nav_5" type="checkbox">
  <label class="md-nav__link" id="__nav_5_label" for="__nav_5" tabindex="0">
    <span class="md-ellipsis">
      统计学
    </span>
    <span class="md-nav__icon md-icon"></span>
  </label>
  <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_5_label" aria-expanded="false">
    <label class="md-nav__title" for="__nav_5">
      <span class="md-nav__icon md-icon"></span>
      统计学
    </label>
    <ul class="md-nav__list" data-md-scrollfix>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2004%3A%20statistics/01.%20fundamentals/">
          <span class="md-ellipsis">
            基础
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2004%3A%20statistics/02.%20measures/">
          <span class="md-ellipsis">
            统计量
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2004%3A%20statistics/03.%20sampling/">
          <span class="md-ellipsis">
            抽样
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2004%3A%20statistics/04.%20hypothesis%20testing/">
          <span class="md-ellipsis">
            假设检验
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2004%3A%20statistics/05.%20inference/">
          <span class="md-ellipsis">
            推断
          </span>
        </a>
      </li>
    </ul>
  </nav>
</li>
<!-- Sidebar section "概率论": checkbox toggle, its label, and links to the five pages under "chapter 05: probability". -->
<li class="md-nav__item md-nav__item--nested">
  <input class="md-nav__toggle md-toggle md-toggle--indeterminate" id="__nav_6" type="checkbox">
  <label class="md-nav__link" id="__nav_6_label" for="__nav_6" tabindex="0">
    <span class="md-ellipsis">
      概率论
    </span>
    <span class="md-nav__icon md-icon"></span>
  </label>
  <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_6_label" aria-expanded="false">
    <label class="md-nav__title" for="__nav_6">
      <span class="md-nav__icon md-icon"></span>
      概率论
    </label>
    <ul class="md-nav__list" data-md-scrollfix>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2005%3A%20probability/01.%20counting/">
          <span class="md-ellipsis">
            计数
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2005%3A%20probability/02.%20probability%20concepts/">
          <span class="md-ellipsis">
            概率概念
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2005%3A%20probability/03.%20distributions/">
          <span class="md-ellipsis">
            分布
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2005%3A%20probability/04.%20bayesian/">
          <span class="md-ellipsis">
            贝叶斯
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2005%3A%20probability/05.%20information%20theory/">
          <span class="md-ellipsis">
            信息论
          </span>
        </a>
      </li>
    </ul>
  </nav>
</li>
<!-- Sidebar section "机器学习": checkbox toggle, its label, and links to the five pages under "chapter 06: machine learning". -->
<li class="md-nav__item md-nav__item--nested">
  <input class="md-nav__toggle md-toggle md-toggle--indeterminate" id="__nav_7" type="checkbox">
  <label class="md-nav__link" id="__nav_7_label" for="__nav_7" tabindex="0">
    <span class="md-ellipsis">
      机器学习
    </span>
    <span class="md-nav__icon md-icon"></span>
  </label>
  <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_7_label" aria-expanded="false">
    <label class="md-nav__title" for="__nav_7">
      <span class="md-nav__icon md-icon"></span>
      机器学习
    </label>
    <ul class="md-nav__list" data-md-scrollfix>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2006%3A%20machine%20learning/01.%20classical%20machine%20learning/">
          <span class="md-ellipsis">
            经典机器学习
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2006%3A%20machine%20learning/02.%20gradient%20machine%20learning/">
          <span class="md-ellipsis">
            梯度机器学习
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2006%3A%20machine%20learning/03.%20deep%20learning/">
          <span class="md-ellipsis">
            深度学习
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2006%3A%20machine%20learning/04.%20reinforcement%20learning/">
          <span class="md-ellipsis">
            强化学习
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2006%3A%20machine%20learning/05.%20distributed%20deep%20learning/">
          <span class="md-ellipsis">
            分布式深度学习
          </span>
        </a>
      </li>
    </ul>
  </nav>
</li>
<!-- Sidebar section "计算语言学": checkbox toggle, its label, and links to the five pages under "chapter 07: computational linguistics". -->
<li class="md-nav__item md-nav__item--nested">
  <input class="md-nav__toggle md-toggle md-toggle--indeterminate" id="__nav_8" type="checkbox">
  <label class="md-nav__link" id="__nav_8_label" for="__nav_8" tabindex="0">
    <span class="md-ellipsis">
      计算语言学
    </span>
    <span class="md-nav__icon md-icon"></span>
  </label>
  <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_8_label" aria-expanded="false">
    <label class="md-nav__title" for="__nav_8">
      <span class="md-nav__icon md-icon"></span>
      计算语言学
    </label>
    <ul class="md-nav__list" data-md-scrollfix>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2007%3A%20computational%20linguistics/01.%20linguistic%20foundations/">
          <span class="md-ellipsis">
            语言学基础
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2007%3A%20computational%20linguistics/02.%20text%20processing%20and%20classic%20NLP/">
          <span class="md-ellipsis">
            文本处理与经典 NLP
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2007%3A%20computational%20linguistics/03.%20embeddings%20and%20sequence%20models/">
          <span class="md-ellipsis">
            嵌入与序列模型
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2007%3A%20computational%20linguistics/04.%20transformers%20and%20language%20models/">
          <span class="md-ellipsis">
            Transformer 与语言模型
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2007%3A%20computational%20linguistics/05.%20advanced%20text%20generation/">
          <span class="md-ellipsis">
            高级文本生成
          </span>
        </a>
      </li>
    </ul>
  </nav>
</li>
<!-- Active sidebar section "计算机视觉": its toggle is checked, and the entry for the current page also embeds that page's table of contents. -->
<li class="md-nav__item md-nav__item--active md-nav__item--section md-nav__item--nested">
  <input class="md-nav__toggle md-toggle" type="checkbox" id="__nav_9" checked>
  <!-- This label has no tabindex attribute: an empty value is not a valid integer, and leaving the attribute out keeps the label outside the tab order. -->
  <label class="md-nav__link" for="__nav_9" id="__nav_9_label">
    <span class="md-ellipsis">
      计算机视觉
    </span>
    <span class="md-nav__icon md-icon"></span>
  </label>
  <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_9_label" aria-expanded="true">
    <label class="md-nav__title" for="__nav_9">
      <span class="md-nav__icon md-icon"></span>
      计算机视觉
    </label>
    <ul class="md-nav__list" data-md-scrollfix>
      <li class="md-nav__item">
        <a href="../01.%20image%20fundamentals/" class="md-nav__link">
          <span class="md-ellipsis">
            图像基础
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a href="../02.%20convolutional%20networks/" class="md-nav__link">
          <span class="md-ellipsis">
            卷积网络
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a href="../03.%20object%20detection%20and%20segmentation/" class="md-nav__link">
          <span class="md-ellipsis">
            目标检测与分割
          </span>
        </a>
      </li>
      <!-- Current page: a label for the #__toc checkbox, a self link, and the table of contents. -->
      <li class="md-nav__item md-nav__item--active">
        <input class="md-nav__toggle md-toggle" type="checkbox" id="__toc">
        <label class="md-nav__link md-nav__link--active" for="__toc">
          <span class="md-ellipsis">
            ViT 与生成模型
          </span>
          <span class="md-nav__icon md-icon"></span>
        </label>
        <a href="./" class="md-nav__link md-nav__link--active">
          <span class="md-ellipsis">
            ViT 与生成模型
          </span>
        </a>
        <nav class="md-nav md-nav--secondary" aria-label="目录">
          <label class="md-nav__title" for="__toc">
            <span class="md-nav__icon md-icon"></span>
            目录
          </label>
          <ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
            <li class="md-nav__item">
              <a href="#colabnotebook" class="md-nav__link">
                <span class="md-ellipsis">
                  编程练习(使用CoLab或notebook)
                </span>
              </a>
            </li>
          </ul>
        </nav>
      </li>
      <li class="md-nav__item">
        <a href="../05.%20video%20and%203D%20vision/" class="md-nav__link">
          <span class="md-ellipsis">
            视频与 3D 视觉
          </span>
        </a>
      </li>
    </ul>
  </nav>
</li>
<!-- Sidebar section "音频与语音": checkbox toggle, its label, and links to the five pages under "chapter 09: audio and speech". -->
<li class="md-nav__item md-nav__item--nested">
  <input class="md-nav__toggle md-toggle md-toggle--indeterminate" id="__nav_10" type="checkbox">
  <label class="md-nav__link" id="__nav_10_label" for="__nav_10" tabindex="0">
    <span class="md-ellipsis">
      音频与语音
    </span>
    <span class="md-nav__icon md-icon"></span>
  </label>
  <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_10_label" aria-expanded="false">
    <label class="md-nav__title" for="__nav_10">
      <span class="md-nav__icon md-icon"></span>
      音频与语音
    </label>
    <ul class="md-nav__list" data-md-scrollfix>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2009%3A%20audio%20and%20speech/01.%20digital%20signal%20processing/">
          <span class="md-ellipsis">
            数字信号处理
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2009%3A%20audio%20and%20speech/02.%20automatic%20speech%20recognition/">
          <span class="md-ellipsis">
            自动语音识别
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2009%3A%20audio%20and%20speech/03.%20text%20to%20speech%20and%20voice/">
          <span class="md-ellipsis">
            语音合成
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2009%3A%20audio%20and%20speech/04.%20speaker%20and%20audio%20analysis/">
          <span class="md-ellipsis">
            说话人与音频分析
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2009%3A%20audio%20and%20speech/05.%20source%20separation%20and%20noise/">
          <span class="md-ellipsis">
            源分离与降噪
          </span>
        </a>
      </li>
    </ul>
  </nav>
</li>
<!-- Sidebar section "多模态学习": checkbox toggle, its label, and links to the five pages under "chapter 10: multimodal learning". -->
<li class="md-nav__item md-nav__item--nested">
  <input class="md-nav__toggle md-toggle md-toggle--indeterminate" id="__nav_11" type="checkbox">
  <label class="md-nav__link" id="__nav_11_label" for="__nav_11" tabindex="0">
    <span class="md-ellipsis">
      多模态学习
    </span>
    <span class="md-nav__icon md-icon"></span>
  </label>
  <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_11_label" aria-expanded="false">
    <label class="md-nav__title" for="__nav_11">
      <span class="md-nav__icon md-icon"></span>
      多模态学习
    </label>
    <ul class="md-nav__list" data-md-scrollfix>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/">
          <span class="md-ellipsis">
            多模态表征
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/">
          <span class="md-ellipsis">
            视觉语言模型
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/">
          <span class="md-ellipsis">
            图像与视频 Token 化
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/">
          <span class="md-ellipsis">
            跨模态生成
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2010%3A%20multimodal%20learning/05.%20unified%20multimodal%20architectures/">
          <span class="md-ellipsis">
            统一多模态架构
          </span>
        </a>
      </li>
    </ul>
  </nav>
</li>
<!-- Sidebar section "自主系统": checkbox toggle, its label, and links to the five pages under "chapter 11: autonomous systems". -->
<li class="md-nav__item md-nav__item--nested">
  <input class="md-nav__toggle md-toggle md-toggle--indeterminate" id="__nav_12" type="checkbox">
  <label class="md-nav__link" id="__nav_12_label" for="__nav_12" tabindex="0">
    <span class="md-ellipsis">
      自主系统
    </span>
    <span class="md-nav__icon md-icon"></span>
  </label>
  <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_12_label" aria-expanded="false">
    <label class="md-nav__title" for="__nav_12">
      <span class="md-nav__icon md-icon"></span>
      自主系统
    </label>
    <ul class="md-nav__list" data-md-scrollfix>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2011%3A%20autonomous%20systems/01.%20perception/">
          <span class="md-ellipsis">
            感知
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2011%3A%20autonomous%20systems/02.%20robot%20learning/">
          <span class="md-ellipsis">
            机器人学习
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2011%3A%20autonomous%20systems/03.%20vision-language-action%20models/">
          <span class="md-ellipsis">
            视觉-语言-动作模型
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2011%3A%20autonomous%20systems/04.%20self-driving/">
          <span class="md-ellipsis">
            自动驾驶
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2011%3A%20autonomous%20systems/05.%20space%20and%20extreme%20robotics/">
          <span class="md-ellipsis">
            太空与极端机器人
          </span>
        </a>
      </li>
    </ul>
  </nav>
</li>
<!-- Sidebar section "图神经网络": checkbox toggle, its label, and links to the five pages under "chapter 12: graph neural networks". -->
<li class="md-nav__item md-nav__item--nested">
  <input class="md-nav__toggle md-toggle md-toggle--indeterminate" id="__nav_13" type="checkbox">
  <label class="md-nav__link" id="__nav_13_label" for="__nav_13" tabindex="0">
    <span class="md-ellipsis">
      图神经网络
    </span>
    <span class="md-nav__icon md-icon"></span>
  </label>
  <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_13_label" aria-expanded="false">
    <label class="md-nav__title" for="__nav_13">
      <span class="md-nav__icon md-icon"></span>
      图神经网络
    </label>
    <ul class="md-nav__list" data-md-scrollfix>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2012%3A%20graph%20neural%20networks/01.%20geometric%20deep%20learning/">
          <span class="md-ellipsis">
            几何深度学习
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2012%3A%20graph%20neural%20networks/02.%20graph%20theory/">
          <span class="md-ellipsis">
            图论
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2012%3A%20graph%20neural%20networks/03.%20graph%20neural%20networks/">
          <span class="md-ellipsis">
            图神经网络
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2012%3A%20graph%20neural%20networks/04.%20graph%20attention%20networks/">
          <span class="md-ellipsis">
            图注意力网络
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2012%3A%20graph%20neural%20networks/05.%203d%20graph%20networks/">
          <span class="md-ellipsis">
            3D 图网络
          </span>
        </a>
      </li>
    </ul>
  </nav>
</li>
<!-- Sidebar section "计算机与操作系统": checkbox toggle, its label, and links to the five pages under "chapter 13: computing and OS". -->
<li class="md-nav__item md-nav__item--nested">
  <input class="md-nav__toggle md-toggle md-toggle--indeterminate" id="__nav_14" type="checkbox">
  <label class="md-nav__link" id="__nav_14_label" for="__nav_14" tabindex="0">
    <span class="md-ellipsis">
      计算机与操作系统
    </span>
    <span class="md-nav__icon md-icon"></span>
  </label>
  <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_14_label" aria-expanded="false">
    <label class="md-nav__title" for="__nav_14">
      <span class="md-nav__icon md-icon"></span>
      计算机与操作系统
    </label>
    <ul class="md-nav__list" data-md-scrollfix>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/">
          <span class="md-ellipsis">
            离散数学
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2013%3A%20computing%20and%20OS/02.%20computer%20architecture/">
          <span class="md-ellipsis">
            计算机体系结构
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2013%3A%20computing%20and%20OS/03.%20operating%20systems/">
          <span class="md-ellipsis">
            操作系统
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2013%3A%20computing%20and%20OS/04.%20concurrency%20and%20parallelism/">
          <span class="md-ellipsis">
            并发与并行
          </span>
        </a>
      </li>
      <li class="md-nav__item">
        <a class="md-nav__link" href="../../chapter%2013%3A%20computing%20and%20OS/05.%20programming%20languages/">
          <span class="md-ellipsis">
            编程语言
          </span>
        </a>
      </li>
    </ul>
  </nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_15" >
<label class="md-nav__link" for="__nav_15" id="__nav_15_label" tabindex="0">
<span class="md-ellipsis">
数据结构与算法
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_15_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_15">
<span class="md-nav__icon md-icon"></span>
数据结构与算法
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/" class="md-nav__link">
<span class="md-ellipsis">
基础
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/" class="md-nav__link">
<span class="md-ellipsis">
数组与哈希
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/" class="md-nav__link">
<span class="md-ellipsis">
链表、栈与队列
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/" class="md-nav__link">
<span class="md-ellipsis">
树
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/" class="md-nav__link">
<span class="md-ellipsis">
图
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/" class="md-nav__link">
<span class="md-ellipsis">
排序与搜索
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_16" >
<label class="md-nav__link" for="__nav_16" id="__nav_16_label" tabindex="0">
<span class="md-ellipsis">
生产级软件工程
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_16_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_16">
<span class="md-nav__icon md-icon"></span>
生产级软件工程
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/" class="md-nav__link">
<span class="md-ellipsis">
Linux 与命令行
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2015%3A%20production%20software%20engineering/02.%20git%20and%20repository%20management/" class="md-nav__link">
<span class="md-ellipsis">
Git 与仓库管理
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/" class="md-nav__link">
<span class="md-ellipsis">
代码设计
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/" class="md-nav__link">
<span class="md-ellipsis">
测试与质量保障
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2015%3A%20production%20software%20engineering/05.%20deployment%20and%20devops/" class="md-nav__link">
<span class="md-ellipsis">
部署与 DevOps
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_17" >
<label class="md-nav__link" for="__nav_17" id="__nav_17_label" tabindex="0">
<span class="md-ellipsis">
SIMD 与 GPU 编程
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_17_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_17">
<span class="md-nav__icon md-icon"></span>
SIMD 与 GPU 编程
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/" class="md-nav__link">
<span class="md-ellipsis">
为什么是 C++ 及 ML 框架原理
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/" class="md-nav__link">
<span class="md-ellipsis">
硬件基础
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/02.%20ARM%20and%20NEON/" class="md-nav__link">
<span class="md-ellipsis">
ARM 与 NEON
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/" class="md-nav__link">
<span class="md-ellipsis">
x86 与 AVX
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/" class="md-nav__link">
<span class="md-ellipsis">
GPU 架构与 CUDA
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/05.%20triton%2C%20TPUs%20and%20pallax/" class="md-nav__link">
<span class="md-ellipsis">
Triton、TPU 与 Pallas
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/06.%20RISC-V%20and%20embedded%20systems/" class="md-nav__link">
<span class="md-ellipsis">
RISC-V 与嵌入式系统
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/" class="md-nav__link">
<span class="md-ellipsis">
Vulkan Compute 与跨平台 GPU
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_18" >
<label class="md-nav__link" for="__nav_18" id="__nav_18_label" tabindex="0">
<span class="md-ellipsis">
AI 推理
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_18_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_18">
<span class="md-nav__icon md-icon"></span>
AI 推理
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2017%3A%20AI%20inference/01.%20quantisation/" class="md-nav__link">
<span class="md-ellipsis">
量化
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/" class="md-nav__link">
<span class="md-ellipsis">
高效架构
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2017%3A%20AI%20inference/03.%20serving%20and%20batching/" class="md-nav__link">
<span class="md-ellipsis">
服务与批处理
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2017%3A%20AI%20inference/04.%20edge%20inference/" class="md-nav__link">
<span class="md-ellipsis">
边缘推理
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2017%3A%20AI%20inference/05.%20scaling%20and%20deployment/" class="md-nav__link">
<span class="md-ellipsis">
扩缩与部署
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_19" >
<label class="md-nav__link" for="__nav_19" id="__nav_19_label" tabindex="0">
<span class="md-ellipsis">
ML 系统设计
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_19_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_19">
<span class="md-nav__icon md-icon"></span>
ML 系统设计
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/" class="md-nav__link">
<span class="md-ellipsis">
系统设计基础
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/" class="md-nav__link">
<span class="md-ellipsis">
云计算
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/" class="md-nav__link">
<span class="md-ellipsis">
大规模基础设施
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/" class="md-nav__link">
<span class="md-ellipsis">
ML 系统设计
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/" class="md-nav__link">
<span class="md-ellipsis">
ML 设计案例
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_20" >
<label class="md-nav__link" for="__nav_20" id="__nav_20_label" tabindex="0">
<span class="md-ellipsis">
应用 AI
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_20_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_20">
<span class="md-nav__icon md-icon"></span>
应用 AI
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2019%3A%20applied%20AI/01.%20AI%20for%20finance/" class="md-nav__link">
<span class="md-ellipsis">
AI 金融
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2019%3A%20applied%20AI/02.%20protein%20design/" class="md-nav__link">
<span class="md-ellipsis">
蛋白质设计
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2019%3A%20applied%20AI/03.%20drug%20discovery/" class="md-nav__link">
<span class="md-ellipsis">
药物发现
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2019%3A%20applied%20AI/04.%20agentic%20systems/" class="md-nav__link">
<span class="md-ellipsis">
智能体系统
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2019%3A%20applied%20AI/05.%20healthcare/" class="md-nav__link">
<span class="md-ellipsis">
医疗健康
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_21" >
<label class="md-nav__link" for="__nav_21" id="__nav_21_label" tabindex="0">
<span class="md-ellipsis">
前沿 AI
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_21_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_21">
<span class="md-nav__icon md-icon"></span>
前沿 AI
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../chapter%2020%3A%20bleeding%20edge%20AI/01.%20quantum%20machine%20learning/" class="md-nav__link">
<span class="md-ellipsis">
量子机器学习
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2020%3A%20bleeding%20edge%20AI/02.%20neuromorphic%20computing/" class="md-nav__link">
<span class="md-ellipsis">
神经形态计算
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2020%3A%20bleeding%20edge%20AI/03.%20datacentres%20in%20space/" class="md-nav__link">
<span class="md-ellipsis">
太空数据中心
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2020%3A%20bleeding%20edge%20AI/04.%20decentralised%20AI/" class="md-nav__link">
<span class="md-ellipsis">
去中心化 AI
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../chapter%2020%3A%20bleeding%20edge%20AI/05.%20brain%20machine%20interfaces/" class="md-nav__link">
<span class="md-ellipsis">
脑机接口
</span>
</a>
</li>
</ul>
</nav>
</li>
</ul>
</nav>
</div>
</div>
</div>
<div class="md-sidebar md-sidebar--secondary" data-md-component="sidebar" data-md-type="toc" >
<div class="md-sidebar__scrollwrap">
<div class="md-sidebar__inner">
<nav class="md-nav md-nav--secondary" aria-label="目录">
<label class="md-nav__title" for="__toc">
<span class="md-nav__icon md-icon"></span>
目录
</label>
<ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
<li class="md-nav__item">
<a href="#colabnotebook" class="md-nav__link">
<span class="md-ellipsis">
编程练习(使用CoLab或notebook)
</span>
</a>
</li>
</ul>
</nav>
</div>
</div>
</div>
<div class="md-content" data-md-component="content">
<article class="md-content__inner md-typeset">
<h1 id="transformer">视觉Transformer与生成模型<a class="headerlink" href="#transformer" title="Permanent link">&para;</a></h1>
<p><em>视觉Transformer将自注意力应用于图像块,通过数据驱动的空间学习挑战了CNN的主导地位。本文涵盖ViT、DeiT、Swin Transformer、基于GAN的图像生成(StyleGAN)、VAE和扩散模型(DDPM、Stable Diffusion),以及超分辨率和神经风格迁移。</em></p>
<ul>
<li>
<p>CNN(文件02)内置了很强的空间归纳偏置:局部连接、权重共享和平移等变性。视觉Transformer(ViT)提出了一个启发性的问题:如果我们完全抛弃这些偏置,仅使用第06章中的注意力机制,让模型从数据中学习空间结构,结果会怎样?</p>
</li>
<li>
<p><strong>ViT</strong>(Vision Transformer,Dosovitskiy等人,2021)将标准的Transformer编码器直接应用于图像。其核心思想是将图像视为一个图像块序列,就像NLP将文本视为一个词元序列一样。</p>
</li>
<li>
<p>其处理流程如下:</p>
</li>
<li>将图像(高度<span class="arithmatex">\(H\)</span>,宽度<span class="arithmatex">\(W\)</span>,通道数<span class="arithmatex">\(C\)</span>)分割成<span class="arithmatex">\(P \times P\)</span>大小的不重叠图像块网格。得到<span class="arithmatex">\(N = HW / P^2\)</span>个图像块。</li>
<li>将每个图像块展平成长度为<span class="arithmatex">\(P^2 \cdot C\)</span>的向量,并通过一个可学习的线性嵌入(单个矩阵乘法,第02章)将其投影到模型维度<span class="arithmatex">\(D\)</span>。</li>
<li>在前面添加一个可学习的<strong>[CLS]标记</strong>嵌入(类似于BERT的[CLS],第07章)。该标记会关注所有图像块,其最终表示用于分类。</li>
<li>添加<strong>位置嵌入</strong>(每个位置一个可学习向量)以提供空间信息,因为注意力是置换等变的。</li>
<li><span class="arithmatex">\((N + 1)\)</span>个标记嵌入序列通过标准的Transformer编码器(多头自注意力 + FFN,第06章)。</li>
<li>[CLS]标记的最终表示通过一个分类头(小型MLP)进行分类。</li>
</ul>
<p><img alt="ViT流程:将图像分割为16x16图像块,每个块展平并线性投影,添加[CLS]标记,加上位置嵌入,然后由Transformer编码器块处理" src="../../images/vit_pipeline.svg" /></p>
<ul>
<li>
<p><strong>图像块嵌入</strong>等价于一个卷积核大小为<span class="arithmatex">\(P\)</span>、步长为<span class="arithmatex">\(P\)</span>(不重叠)的卷积操作。ViT将2D图像字面地转换为1D序列,然后用与处理语言相同的架构来处理它。</p>
</li>
<li>
<p>ViT的归纳偏置比CNN少:它不强制局部连接或平移等变性。这意味着它需要更多的训练数据才能从头学习空间结构。在小型数据集上,CNN优于ViT。但在非常大的数据集(JFT-300M,3亿张图像)上训练时,ViT达到或超过了最佳CNN的性能,这表明CNN的归纳偏置有助于数据效率,但对于最终性能并非必需。</p>
</li>
<li>
<p>ViT自注意力的复杂度为<span class="arithmatex">\(O(N^2)\)</span>,其中N是图像块数量。对于224x224的图像和16x16的图像块,<span class="arithmatex">\(N = 196\)</span>,这在可控范围内。但对于更高分辨率的图像或更小的图像块,二次成本变得难以承受。</p>
</li>
<li>
<p><strong>DeiT</strong>(数据高效的图像Transformer,Touvron等人,2021)表明,仅使用ImageNet(无需庞大的JFT数据集)并借助强数据增强、正则化(随机深度、标签平滑、dropout)和<strong>知识蒸馏</strong>,就可以有效训练ViT:一个预训练的CNN教师提供软标签,ViT学生学习匹配这些标签。DeiT在[CLS]标记旁边添加了一个<strong>蒸馏标记</strong>,训练用于预测教师的输出。</p>
</li>
<li>
<p><strong>Swin Transformer</strong>(Liu等人,2021)解决了ViT的两个主要局限:随图像大小呈二次增长的计算成本,以及缺少层次化特征图(检测和分割需要层次化特征)。</p>
</li>
<li>
<p>Swin引入了<strong>移动窗口</strong>:不再对所有图像块进行全局自注意力,而是在局部窗口内(例如7x7个图像块)计算注意力。这使得计算成本与图像大小呈线性关系:<span class="arithmatex">\(O(N)\)</span>而非<span class="arithmatex">\(O(N^2)\)</span>。但仅靠局部窗口会阻止区域之间的信息流动。</p>
</li>
<li>
<p><strong>窗口移动</strong>解决了这个问题:在交替层中,窗口划分会偏移半个窗口大小。这创建了跨窗口连接,使得信息可以在所有图像部分之间流动,而无需全局注意力的成本。</p>
</li>
</ul>
<p><img alt="Swin Transformer:第l层在常规窗口内计算注意力,第l+1层将窗口划分偏移一半,创建跨窗口连接" src="../../images/swin_shifted_windows.svg" /></p>
<ul>
<li>
<p>Swin还通过跨阶段合并图像块来构建<strong>层次化表示</strong>。每个阶段之后,相邻的2x2图像块被拼接并投影,使通道维度加倍、空间分辨率减半。这产生了多尺度特征图,类似于CNN和FPN(文件03)中的特征图,使得Swin可以直接兼容Faster R-CNN等检测头和U-Net等分割头。</p>
</li>
<li>
<p><strong>PVT</strong>(金字塔视觉Transformer)采用了类似的层次化方法,具有空间缩减注意力:在每个阶段,键和值在计算注意力之前先进行空间下采样,从而在保持全局感受野的同时降低二次成本。</p>
</li>
<li>
<p><strong>自监督视觉学习</strong>从未标注的图像中训练表示。标注成本高,但图像资源丰富。目标是在没有任何人工标注的情况下,学习能很好地迁移到下游任务的特征。</p>
</li>
<li>
<p><strong>对比学习</strong>训练模型识别:同一张图像的两个增广视图("正样本对")应具有相似的表示,而不同图像的视图("负样本对")应具有不相似的表示。</p>
</li>
<li>
<p><strong>SimCLR</strong>(Chen等人,2020)对一个批次中的每张图像创建两个增广视图,用共享主干网络+投影头对两者进行编码,并应用<strong>NT-Xent损失</strong>(归一化温度标度交叉熵):</p>
</li>
</ul>
<div class="arithmatex">\[\ell_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k \neq i} \exp(\text{sim}(z_i, z_k) / \tau)}\]</div>
<ul>
<li>
<p>其中<span class="arithmatex">\(\text{sim}\)</span>是余弦相似度(第01章),<span class="arithmatex">\(\tau\)</span>是温度参数。分子将正样本对拉近;分母将负样本对推远。SimCLR需要大批量大小(4,096+)来提供足够的负样本。</p>
</li>
<li>
<p><strong>MoCo</strong>(动量对比,He等人,2020)通过维护一个<strong>动量更新的负嵌入队列</strong>来解决大批量需求。查询编码器通过梯度下降更新;键编码器作为查询编码器的指数移动平均(EMA,第04章)进行更新:<span class="arithmatex">\(\theta_k \leftarrow m \theta_k + (1 - m) \theta_q\)</span>,其中<span class="arithmatex">\(m = 0.999\)</span>。队列存储最近的键嵌入,提供了大量且一致的负样本集,无需巨大的批次。</p>
</li>
<li>
<p><strong>BYOL</strong>(自举你自己的隐空间,Grill等人,2020)完全消除了负样本对。它使用两个网络:"在线"网络和"目标"网络(在线的EMA)。在线网络预测目标网络对另一增广视图的表示。无需负样本,BYOL通过预测头的不对称性和EMA目标避免了坍塌问题(模型对所有输入输出相同向量)。</p>
</li>
<li>
<p><strong>DINO</strong>(无标签自蒸馏,Caron等人,2021)将自蒸馏应用于ViT。学生网络预测教师网络(学生的EMA)在不同增广视图下的输出。教师使用更大的裁剪区域;学生使用更小的裁剪区域。DINO产生的特征包含关于场景布局的显式信息:DINO训练的ViT的自注意力图自然地对物体进行分割,无需任何分割监督。</p>
</li>
<li>
<p><strong>掩码图像建模</strong>是BERT掩码语言建模(第07章)在视觉领域的类比。输入图像块的一大部分被掩码,模型学习重建它们。</p>
</li>
<li>
<p><strong>MAE</strong>(掩码自编码器,He等人,2022)掩码了75%的图像块,并训练一个ViT编码器-解码器来重建缺失的像素值。只有未掩码的图像块由编码器处理(在预训练期间节省4倍计算量),轻量级解码器从编码后的可见图像块加上可学习的掩码标记重建完整图像。</p>
</li>
<li>
<p><strong>BEiT</strong>(图像Transformer的BERT预训练,Bao等人,2022)掩码图像块并预测离散的视觉标记(从预训练的dVAE分词器获得),而不是原始像素。这类似于BERT预测离散词标记,避免了像素重建的低层细节。</p>
</li>
<li>
<p><strong>图像生成</strong>旨在生成训练集中不存在的新颖、逼真的图像。核心挑战是对自然图像的高维概率分布进行建模。</p>
</li>
<li>
<p><strong>生成对抗网络(GAN)</strong>(Goodfellow等人,2014)使用两个相互竞争的网络:一个<strong>生成器</strong><span class="arithmatex">\(G\)</span>从随机噪声中创建假图像,和一个<strong>判别器</strong><span class="arithmatex">\(D\)</span>试图区分真实图像和假图像。它们通过对抗性训练:<span class="arithmatex">\(G\)</span>试图欺骗<span class="arithmatex">\(D\)</span>,而<span class="arithmatex">\(D\)</span>试图抓住<span class="arithmatex">\(G\)</span>。</p>
</li>
</ul>
<div class="arithmatex">\[\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))]\]</div>
<ul>
<li>
<p>生成器接收随机隐向量<span class="arithmatex">\(z\)</span>(从高斯分布等简单分布中采样),通过一系列转置卷积将其映射生成图像。判别器是一个标准的CNN分类器。在均衡状态下,<span class="arithmatex">\(G\)</span>生成的图像与真实数据无法区分,<span class="arithmatex">\(D\)</span>对所有输入输出0.5。</p>
</li>
<li>
<p><strong>模式坍塌</strong>是GAN的主要失败模式:生成器学会只生成少数几种能欺骗判别器的图像,忽略了训练数据的多样性。生成器找到一小部分"安全"输出,而不是覆盖完整的数据分布。</p>
</li>
<li>
<p>稳定GAN的训练技巧包括:谱归一化(约束判别器的Lipschitz常数)、渐进式增长(先在低分辨率训练,然后逐步提高)、特征匹配(匹配中间判别器特征的统计量而非最终输出),以及使用Wasserstein距离替代原始的JS散度目标。</p>
</li>
<li>
<p><strong>StyleGAN</strong>(Karras等人,2019)是最具影响力的高质量图像合成GAN架构。其关键创新是<strong>基于风格的生成器</strong>:不是将隐向量<span class="arithmatex">\(z\)</span>直接输入生成器,而是先通过一个<strong>映射网络</strong>(8层MLP)生成风格向量<span class="arithmatex">\(w\)</span>。该风格向量通过<strong>自适应实例归一化(AdaIN)</strong>注入到生成器的每一层,调节特征图的统计量:</p>
</li>
</ul>
<div class="arithmatex">\[\text{AdaIN}(x, y) = y_{s} \cdot \frac{x - \mu(x)}{\sigma(x)} + y_{b}\]</div>
<ul>
<li>
<p>其中<span class="arithmatex">\(y_s\)</span>和<span class="arithmatex">\(y_b\)</span>是从<span class="arithmatex">\(w\)</span>推导出的缩放和偏置。不同层控制不同方面:早期层控制粗粒度特征(姿态、脸型),中间层控制中粒度特征(发型、眼睛),后期层控制细粒度细节(雀斑、发质纹理)。StyleGAN能以1024x1024分辨率生成照片级逼真的人脸。</p>
</li>
<li>
<p><strong>变分自编码器(VAE)</strong>(第06章)提供了另一种生成方法。与GAN不同,VAE有一个原则性的概率框架,具有清晰的训练目标(ELBO)。它们生成的图像通常比GAN模糊,但提供了更平滑、更结构化的隐空间。VAE是隐扩散模型中用于将图像压缩到隐空间和从隐空间重建的编码器-解码器对。</p>
</li>
<li>
<p><strong>扩散模型</strong>已成为图像生成的主导范式,在质量和多样性上都超越了GAN。其思想概念上很简单:逐步向数据添加噪声直到变成纯高斯噪声(<strong>前向过程</strong>),然后学习逐步逆转这一过程(<strong>反向过程</strong>)。</p>
</li>
<li>
<p><strong>前向过程</strong>在<span class="arithmatex">\(T\)</span>个时间步中添加高斯噪声:</p>
</li>
</ul>
<div class="arithmatex">\[q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} \, x_{t-1}, \beta_t I)\]</div>
<ul>
<li>其中<span class="arithmatex">\(\beta_t\)</span>是一个随时间递增的噪声调度。经过足够多的步骤后,无论原始图像<span class="arithmatex">\(x_0\)</span>如何,<span class="arithmatex">\(x_T\)</span>都近似于纯高斯噪声。利用重参数化技巧(第06章),设<span class="arithmatex">\(\alpha_t = 1 - \beta_t\)</span>和<span class="arithmatex">\(\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s\)</span>,我们可以直接从<span class="arithmatex">\(x_0\)</span>采样<span class="arithmatex">\(x_t\)</span>:</li>
</ul>
<div class="arithmatex">\[x_t = \sqrt{\bar{\alpha}_t} \, x_0 + \sqrt{1 - \bar{\alpha}_t} \, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\]</div>
<ul>
<li><strong>反向过程</strong>学习去噪:从纯噪声<span class="arithmatex">\(x_T\)</span>开始,模型预测每一步添加的噪声<span class="arithmatex">\(\epsilon\)</span>并将其减去以恢复<span class="arithmatex">\(x_{t-1}\)</span>。这由一个神经网络<span class="arithmatex">\(\epsilon_\theta\)</span>(通常是U-Net,来自文件03)参数化,使用简单的MSE损失训练:</li>
</ul>
<div class="arithmatex">\[\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon}\left[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\right]\]</div>
<p><img alt="扩散前向和反向过程:干净图像在T步中逐渐被噪声破坏(前向),神经网络学习逆转每一步(反向),从纯噪声开始生成干净图像" src="../../images/diffusion_process.svg" /></p>
<ul>
<li>
<p><strong>DDPM</strong>(去噪扩散概率模型,Ho等人,2020)建立了这个框架。采样需要迭代所有<span class="arithmatex">\(T\)</span>步(通常为1,000步),这很慢。<strong>DDIM</strong>(去噪扩散隐式模型,Song等人,2021)将采样过程重新表述为确定性映射,允许大跨度跳过(例如50步代替1,000步)且质量损失极小。</p>
</li>
<li>
<p><strong>基于分数的模型</strong>(Song和Ermon,2019)提供了另一种视角。该模型不是预测噪声<span class="arithmatex">\(\epsilon\)</span>,而是估计<strong>分数函数</strong><span class="arithmatex">\(\nabla_{x_t} \log p(x_t)\)</span>,即对数概率相对于含噪图像的梯度。该梯度指向数据分布中更高概率(更干净)的区域。采样使用Langevin动力学沿着该梯度进行。基于分数的模型和DDPM在<strong>随机微分方程(SDE)</strong>的框架下被统一:前向过程是添加噪声的SDE,反向过程是时间反转的SDE。</p>
</li>
<li>
<p><strong>无分类器引导</strong>(Ho和Salimans,2022)控制样本质量和多样性之间的权衡。模型同时进行条件训练(使用文本提示或类别标签)和无条件训练(条件随机丢弃)。在采样时,预测是加权组合:</p>
</li>
</ul>
<div class="arithmatex">\[\hat{\epsilon} = \epsilon_\theta(x_t, \varnothing) + s \cdot (\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing))\]</div>
<ul>
<li>
<p>其中<span class="arithmatex">\(c\)</span>是条件,<span class="arithmatex">\(\varnothing\)</span>是空条件,<span class="arithmatex">\(s &gt; 1\)</span>是引导尺度。<span class="arithmatex">\(s\)</span>越高,生成的图像越符合条件,但多样性越低。<span class="arithmatex">\(s = 1\)</span>是无引导模型;<span class="arithmatex">\(s = 7.5\)</span>是常见的默认值。</p>
</li>
<li>
<p><strong>隐扩散</strong>(Rombach等人,2022,Stable Diffusion)将扩散过程从像素空间转移到学习的隐空间中。一个预训练的VAE编码器将图像压缩为较低维度的隐空间表示(通常空间下采样4倍或8倍),扩散在这个压缩空间中进行,VAE解码器从去噪后的隐变量重建像素。这大大提高了效率:在像素空间扩散512x512图像需要处理<span class="arithmatex">\(512 \times 512 \times 3\)</span>的张量,但在隐空间中仅需处理<span class="arithmatex">\(64 \times 64 \times 4\)</span>的张量。</p>
</li>
<li>
<p>隐扩散中的去噪U-Net接收含噪隐变量、时间步(编码为正弦嵌入,类似于Transformer中的位置编码)和条件信号(来自冻结的CLIP或T5文本编码器的文本嵌入)。文本条件通过U-Net内的交叉注意力层进入:文本嵌入作为键和值,图像特征作为查询。这使得模型在每个空间位置都能关注文本提示的相关部分。</p>
</li>
<li>
<p><strong>流匹配</strong>是扩散模型的一个新兴替代方案,它学习噪声和数据之间的直接传输路径,而不是DDPM的迭代去噪。</p>
</li>
<li>
<p><strong>连续归一化流(CNF)</strong>定义了一个时间相关的速度场<span class="arithmatex">\(v_\theta(x, t)\)</span>,沿着平滑轨迹将样本从简单分布<span class="arithmatex">\(p_0\)</span>(噪声)推送到数据分布<span class="arithmatex">\(p_1\)</span>。该变换遵循一个常微分方程(ODE):</p>
</li>
</ul>
<div class="arithmatex">\[\frac{dx}{dt} = v_\theta(x, t), \quad t \in [0, 1]\]</div>
<ul>
<li>
<p>从<span class="arithmatex">\(x_0 \sim \mathcal{N}(0, I)\)</span>开始,将ODE向前积分到<span class="arithmatex">\(t = 1\)</span>即可得到数据分布中的样本。速度场由神经网络参数化,训练目标是匹配目标条件流。</p>
</li>
<li>
<p><strong>最优传输(OT)流匹配</strong>(Lipman等人,2023)使用噪声和数据之间的直线路径作为目标流:从噪声样本<span class="arithmatex">\(x_0\)</span>到数据样本<span class="arithmatex">\(x_1\)</span>的条件路径简单地是<span class="arithmatex">\(x_t = (1 - t) x_0 + t x_1\)</span>,目标速度为<span class="arithmatex">\(v = x_1 - x_0\)</span>。训练损失变为:</p>
</li>
</ul>
<div class="arithmatex">\[\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[\|v_\theta(x_t, t) - (x_1 - x_0)\|^2\right]\]</div>
<ul>
<li>
<p><strong>整流流</strong>(Liu等人,2022)通过迭代方式拉直学习到的流路径。在初始训练后,模型通过模拟ODE生成(噪声,数据)对。这些比随机配对更紧密对齐的对用于重新训练模型。重复此过程会产生越来越直的路径,可以通过更少的ODE步骤(甚至单步)来遍历,从而实现极快速的生成。</p>
</li>
<li>
<p>流匹配相比扩散有几个优势:训练目标更简单(直接的速度回归,无需噪声调度),采样ODE更平滑(需要的积分步骤更少),与最优传输的联系提供了理论依据。Stable Diffusion 3和Flux使用流匹配替代了传统的DDPM。</p>
</li>
</ul>
<h2 id="colabnotebook">编程练习(使用CoLab或notebook)<a class="headerlink" href="#colabnotebook" title="Permanent link">&para;</a></h2>
<ol>
<li>
<p>从头实现ViT图像块嵌入。将图像分割成图像块,展平,投影到模型维度,添加位置嵌入,并前置[CLS]标记。
<div class="highlight"><pre><span></span><code><a id="__codelineno-0-1" name="__codelineno-0-1" href="#__codelineno-0-1"></a><span class="kn">import</span><span class="w"> </span><span class="nn">jax</span>
<a id="__codelineno-0-2" name="__codelineno-0-2" href="#__codelineno-0-2"></a><span class="kn">import</span><span class="w"> </span><span class="nn">jax.numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">jnp</span>
<a id="__codelineno-0-3" name="__codelineno-0-3" href="#__codelineno-0-3"></a><span class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span>
<a id="__codelineno-0-4" name="__codelineno-0-4" href="#__codelineno-0-4"></a>
<a id="__codelineno-0-5" name="__codelineno-0-5" href="#__codelineno-0-5"></a><span class="k">def</span><span class="w"> </span><span class="nf">create_patch_embedding</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">params</span><span class="p">):</span>
<a id="__codelineno-0-6" name="__codelineno-0-6" href="#__codelineno-0-6"></a><span class="w"> </span><span class="sd">&quot;&quot;&quot;将图像转换为图像块嵌入序列。&quot;&quot;&quot;</span>
<a id="__codelineno-0-7" name="__codelineno-0-7" href="#__codelineno-0-7"></a> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">shape</span>
<a id="__codelineno-0-8" name="__codelineno-0-8" href="#__codelineno-0-8"></a> <span class="n">n_patches_h</span> <span class="o">=</span> <span class="n">H</span> <span class="o">//</span> <span class="n">patch_size</span>
<a id="__codelineno-0-9" name="__codelineno-0-9" href="#__codelineno-0-9"></a> <span class="n">n_patches_w</span> <span class="o">=</span> <span class="n">W</span> <span class="o">//</span> <span class="n">patch_size</span>
<a id="__codelineno-0-10" name="__codelineno-0-10" href="#__codelineno-0-10"></a> <span class="n">n_patches</span> <span class="o">=</span> <span class="n">n_patches_h</span> <span class="o">*</span> <span class="n">n_patches_w</span>
<a id="__codelineno-0-11" name="__codelineno-0-11" href="#__codelineno-0-11"></a>
<a id="__codelineno-0-12" name="__codelineno-0-12" href="#__codelineno-0-12"></a> <span class="c1"># 提取图像块</span>
<a id="__codelineno-0-13" name="__codelineno-0-13" href="#__codelineno-0-13"></a> <span class="n">patches</span> <span class="o">=</span> <span class="p">[]</span>
<a id="__codelineno-0-14" name="__codelineno-0-14" href="#__codelineno-0-14"></a> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_patches_h</span><span class="p">):</span>
<a id="__codelineno-0-15" name="__codelineno-0-15" href="#__codelineno-0-15"></a> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_patches_w</span><span class="p">):</span>
<a id="__codelineno-0-16" name="__codelineno-0-16" href="#__codelineno-0-16"></a> <span class="n">patch</span> <span class="o">=</span> <span class="n">image</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="n">patch_size</span><span class="p">:(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">patch_size</span><span class="p">,</span>
<a id="__codelineno-0-17" name="__codelineno-0-17" href="#__codelineno-0-17"></a> <span class="n">j</span><span class="o">*</span><span class="n">patch_size</span><span class="p">:(</span><span class="n">j</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="n">patch_size</span><span class="p">,</span> <span class="p">:]</span>
<a id="__codelineno-0-18" name="__codelineno-0-18" href="#__codelineno-0-18"></a> <span class="n">patches</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">patch</span><span class="o">.</span><span class="n">ravel</span><span class="p">())</span>
<a id="__codelineno-0-19" name="__codelineno-0-19" href="#__codelineno-0-19"></a> <span class="n">patches</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">patches</span><span class="p">)</span> <span class="c1"># (N, P*P*C)</span>
<a id="__codelineno-0-20" name="__codelineno-0-20" href="#__codelineno-0-20"></a>
<a id="__codelineno-0-21" name="__codelineno-0-21" href="#__codelineno-0-21"></a> <span class="c1"># 线性投影到d_model</span>
<a id="__codelineno-0-22" name="__codelineno-0-22" href="#__codelineno-0-22"></a> <span class="n">embeddings</span> <span class="o">=</span> <span class="n">patches</span> <span class="o">@</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;proj_w&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;proj_b&#39;</span><span class="p">]</span> <span class="c1"># (N, d_model)</span>
<a id="__codelineno-0-23" name="__codelineno-0-23" href="#__codelineno-0-23"></a>
<a id="__codelineno-0-24" name="__codelineno-0-24" href="#__codelineno-0-24"></a> <span class="c1"># 前置CLS标记</span>
<a id="__codelineno-0-25" name="__codelineno-0-25" href="#__codelineno-0-25"></a> <span class="n">cls_token</span> <span class="o">=</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;cls_token&#39;</span><span class="p">]</span> <span class="c1"># (1, d_model)</span>
<a id="__codelineno-0-26" name="__codelineno-0-26" href="#__codelineno-0-26"></a> <span class="n">embeddings</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">cls_token</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># (N+1, d_model)</span>
<a id="__codelineno-0-27" name="__codelineno-0-27" href="#__codelineno-0-27"></a>
<a id="__codelineno-0-28" name="__codelineno-0-28" href="#__codelineno-0-28"></a> <span class="c1"># 添加位置嵌入</span>
<a id="__codelineno-0-29" name="__codelineno-0-29" href="#__codelineno-0-29"></a> <span class="n">embeddings</span> <span class="o">=</span> <span class="n">embeddings</span> <span class="o">+</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;pos_embed&#39;</span><span class="p">]</span> <span class="c1"># (N+1, d_model)</span>
<a id="__codelineno-0-30" name="__codelineno-0-30" href="#__codelineno-0-30"></a>
<a id="__codelineno-0-31" name="__codelineno-0-31" href="#__codelineno-0-31"></a> <span class="k">return</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">patches</span>
<a id="__codelineno-0-32" name="__codelineno-0-32" href="#__codelineno-0-32"></a>
<a id="__codelineno-0-33" name="__codelineno-0-33" href="#__codelineno-0-33"></a><span class="c1"># 设置</span>
<a id="__codelineno-0-34" name="__codelineno-0-34" href="#__codelineno-0-34"></a><span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span>
<a id="__codelineno-0-35" name="__codelineno-0-35" href="#__codelineno-0-35"></a><span class="n">patch_size</span> <span class="o">=</span> <span class="mi">8</span>
<a id="__codelineno-0-36" name="__codelineno-0-36" href="#__codelineno-0-36"></a><span class="n">d_model</span> <span class="o">=</span> <span class="mi">64</span>
<a id="__codelineno-0-37" name="__codelineno-0-37" href="#__codelineno-0-37"></a><span class="n">n_patches</span> <span class="o">=</span> <span class="p">(</span><span class="n">H</span> <span class="o">//</span> <span class="n">patch_size</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">W</span> <span class="o">//</span> <span class="n">patch_size</span><span class="p">)</span> <span class="c1"># 16</span>
<a id="__codelineno-0-38" name="__codelineno-0-38" href="#__codelineno-0-38"></a>
<a id="__codelineno-0-39" name="__codelineno-0-39" href="#__codelineno-0-39"></a><span class="n">key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<a id="__codelineno-0-40" name="__codelineno-0-40" href="#__codelineno-0-40"></a><span class="n">keys</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
<a id="__codelineno-0-41" name="__codelineno-0-41" href="#__codelineno-0-41"></a>
<a id="__codelineno-0-42" name="__codelineno-0-42" href="#__codelineno-0-42"></a><span class="c1"># 创建具有不同象限的合成图像</span>
<a id="__codelineno-0-43" name="__codelineno-0-43" href="#__codelineno-0-43"></a><span class="n">image</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span><span class="p">))</span>
<a id="__codelineno-0-44" name="__codelineno-0-44" href="#__codelineno-0-44"></a><span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">at</span><span class="p">[:</span><span class="mi">16</span><span class="p">,</span> <span class="p">:</span><span class="mi">16</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span> <span class="c1"># 红色 左上</span>
<a id="__codelineno-0-45" name="__codelineno-0-45" href="#__codelineno-0-45"></a><span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">at</span><span class="p">[:</span><span class="mi">16</span><span class="p">,</span> <span class="mi">16</span><span class="p">:,</span> <span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span> <span class="c1"># 绿色 右上</span>
<a id="__codelineno-0-46" name="__codelineno-0-46" href="#__codelineno-0-46"></a><span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">at</span><span class="p">[</span><span class="mi">16</span><span class="p">:,</span> <span class="p">:</span><span class="mi">16</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span> <span class="c1"># 蓝色 左下</span>
<a id="__codelineno-0-47" name="__codelineno-0-47" href="#__codelineno-0-47"></a><span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">at</span><span class="p">[</span><span class="mi">16</span><span class="p">:,</span> <span class="mi">16</span><span class="p">:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span> <span class="c1"># 黄色 右下</span>
<a id="__codelineno-0-48" name="__codelineno-0-48" href="#__codelineno-0-48"></a>
<a id="__codelineno-0-49" name="__codelineno-0-49" href="#__codelineno-0-49"></a><span class="n">params</span> <span class="o">=</span> <span class="p">{</span>
<a id="__codelineno-0-50" name="__codelineno-0-50" href="#__codelineno-0-50"></a> <span class="s1">&#39;proj_w&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">(</span><span class="n">patch_size</span><span class="o">**</span><span class="mi">2</span> <span class="o">*</span> <span class="n">C</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span> <span class="o">*</span> <span class="mf">0.02</span><span class="p">,</span>
<a id="__codelineno-0-51" name="__codelineno-0-51" href="#__codelineno-0-51"></a> <span class="s1">&#39;proj_b&#39;</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">d_model</span><span class="p">),</span>
<a id="__codelineno-0-52" name="__codelineno-0-52" href="#__codelineno-0-52"></a> <span class="s1">&#39;cls_token&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span> <span class="o">*</span> <span class="mf">0.02</span><span class="p">,</span>
<a id="__codelineno-0-53" name="__codelineno-0-53" href="#__codelineno-0-53"></a> <span class="s1">&#39;pos_embed&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="p">(</span><span class="n">n_patches</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">d_model</span><span class="p">))</span> <span class="o">*</span> <span class="mf">0.02</span><span class="p">,</span>
<a id="__codelineno-0-54" name="__codelineno-0-54" href="#__codelineno-0-54"></a><span class="p">}</span>
<a id="__codelineno-0-55" name="__codelineno-0-55" href="#__codelineno-0-55"></a>
<a id="__codelineno-0-56" name="__codelineno-0-56" href="#__codelineno-0-56"></a><span class="n">embeddings</span><span class="p">,</span> <span class="n">patches</span> <span class="o">=</span> <span class="n">create_patch_embedding</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">,</span> <span class="n">d_model</span><span class="p">,</span> <span class="n">params</span><span class="p">)</span>
<a id="__codelineno-0-57" name="__codelineno-0-57" href="#__codelineno-0-57"></a>
<a id="__codelineno-0-58" name="__codelineno-0-58" href="#__codelineno-0-58"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;图像形状: </span><span class="si">{</span><span class="n">image</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-0-59" name="__codelineno-0-59" href="#__codelineno-0-59"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;图像块大小: </span><span class="si">{</span><span class="n">patch_size</span><span class="si">}</span><span class="s2">x</span><span class="si">{</span><span class="n">patch_size</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-0-60" name="__codelineno-0-60" href="#__codelineno-0-60"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;图像块数量: </span><span class="si">{</span><span class="n">n_patches</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-0-61" name="__codelineno-0-61" href="#__codelineno-0-61"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;图像块向量长度: </span><span class="si">{</span><span class="n">patch_size</span><span class="o">**</span><span class="mi">2</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">C</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<a id="__codelineno-0-62" name="__codelineno-0-62" href="#__codelineno-0-62"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;嵌入形状: </span><span class="si">{</span><span class="n">embeddings</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2"> (CLS + </span><span class="si">{</span><span class="n">n_patches</span><span class="si">}</span><span class="s2"> 个图像块)&quot;</span><span class="p">)</span>
<a id="__codelineno-0-63" name="__codelineno-0-63" href="#__codelineno-0-63"></a>
<a id="__codelineno-0-64" name="__codelineno-0-64" href="#__codelineno-0-64"></a><span class="c1"># 可视化图像块</span>
<a id="__codelineno-0-65" name="__codelineno-0-65" href="#__codelineno-0-65"></a><span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">14</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span>
<a id="__codelineno-0-66" name="__codelineno-0-66" href="#__codelineno-0-66"></a><span class="n">axes</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="p">);</span> <span class="n">axes</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s1">&#39;完整图像&#39;</span><span class="p">);</span> <span class="n">axes</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">&#39;off&#39;</span><span class="p">)</span>
<a id="__codelineno-0-67" name="__codelineno-0-67" href="#__codelineno-0-67"></a><span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">min</span><span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="n">n_patches</span><span class="p">)):</span>
<a id="__codelineno-0-68" name="__codelineno-0-68" href="#__codelineno-0-68"></a> <span class="n">ax</span> <span class="o">=</span> <span class="n">axes</span><span class="p">[(</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span> <span class="o">//</span> <span class="mi">5</span><span class="p">,</span> <span class="p">(</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span> <span class="o">%</span> <span class="mi">5</span><span class="p">]</span>
<a id="__codelineno-0-69" name="__codelineno-0-69" href="#__codelineno-0-69"></a> <span class="n">patch_img</span> <span class="o">=</span> <span class="n">patches</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">patch_size</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">,</span> <span class="n">C</span><span class="p">)</span>
<a id="__codelineno-0-70" name="__codelineno-0-70" href="#__codelineno-0-70"></a> <span class="n">ax</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">patch_img</span><span class="p">);</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;图像块 </span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">);</span> <span class="n">ax</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">&#39;off&#39;</span><span class="p">)</span>
<a id="__codelineno-0-71" name="__codelineno-0-71" href="#__codelineno-0-71"></a><span class="n">plt</span><span class="o">.</span><span class="n">suptitle</span><span class="p">(</span><span class="s1">&#39;ViT 图像块分解&#39;</span><span class="p">)</span>
<a id="__codelineno-0-72" name="__codelineno-0-72" href="#__codelineno-0-72"></a><span class="n">plt</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">();</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></p>
</li>
<li>
<p>实现一个简单的 GAN 训练循环。在二维数据上训练生成器和判别器,并可视化生成分布逐渐收敛到真实分布的过程。
<div class="highlight"><pre><span></span><code><a id="__codelineno-1-1" name="__codelineno-1-1" href="#__codelineno-1-1"></a><span class="kn">import</span><span class="w"> </span><span class="nn">jax</span>
<a id="__codelineno-1-2" name="__codelineno-1-2" href="#__codelineno-1-2"></a><span class="kn">import</span><span class="w"> </span><span class="nn">jax.numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">jnp</span>
<a id="__codelineno-1-3" name="__codelineno-1-3" href="#__codelineno-1-3"></a><span class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span>
<a id="__codelineno-1-4" name="__codelineno-1-4" href="#__codelineno-1-4"></a>
<a id="__codelineno-1-5" name="__codelineno-1-5" href="#__codelineno-1-5"></a><span class="k">def</span><span class="w"> </span><span class="nf">generator</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">params</span><span class="p">):</span>
<a id="__codelineno-1-6" name="__codelineno-1-6" href="#__codelineno-1-6"></a> <span class="n">h</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">z</span> <span class="o">@</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;g_w1&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;g_b1&#39;</span><span class="p">])</span>
<a id="__codelineno-1-7" name="__codelineno-1-7" href="#__codelineno-1-7"></a> <span class="n">h</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">h</span> <span class="o">@</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;g_w2&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;g_b2&#39;</span><span class="p">])</span>
<a id="__codelineno-1-8" name="__codelineno-1-8" href="#__codelineno-1-8"></a> <span class="k">return</span> <span class="n">h</span> <span class="o">@</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;g_w3&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;g_b3&#39;</span><span class="p">]</span>
<a id="__codelineno-1-9" name="__codelineno-1-9" href="#__codelineno-1-9"></a>
<a id="__codelineno-1-10" name="__codelineno-1-10" href="#__codelineno-1-10"></a><span class="k">def</span><span class="w"> </span><span class="nf">discriminator</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">params</span><span class="p">):</span>
<a id="__codelineno-1-11" name="__codelineno-1-11" href="#__codelineno-1-11"></a> <span class="n">h</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="n">x</span> <span class="o">@</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;d_w1&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;d_b1&#39;</span><span class="p">],</span> <span class="mf">0.2</span><span class="p">)</span>
<a id="__codelineno-1-12" name="__codelineno-1-12" href="#__codelineno-1-12"></a> <span class="n">h</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="n">h</span> <span class="o">@</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;d_w2&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;d_b2&#39;</span><span class="p">],</span> <span class="mf">0.2</span><span class="p">)</span>
<a id="__codelineno-1-13" name="__codelineno-1-13" href="#__codelineno-1-13"></a> <span class="k">return</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">h</span> <span class="o">@</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;d_w3&#39;</span><span class="p">]</span> <span class="o">+</span> <span class="n">params</span><span class="p">[</span><span class="s1">&#39;d_b3&#39;</span><span class="p">])</span>
<a id="__codelineno-1-14" name="__codelineno-1-14" href="#__codelineno-1-14"></a>
<a id="__codelineno-1-15" name="__codelineno-1-15" href="#__codelineno-1-15"></a><span class="k">def</span><span class="w"> </span><span class="nf">init_params</span><span class="p">(</span><span class="n">key</span><span class="p">):</span>
<a id="__codelineno-1-16" name="__codelineno-1-16" href="#__codelineno-1-16"></a> <span class="n">keys</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="mi">6</span><span class="p">)</span>
<a id="__codelineno-1-17" name="__codelineno-1-17" href="#__codelineno-1-17"></a> <span class="n">z_dim</span><span class="p">,</span> <span class="n">h_dim</span><span class="p">,</span> <span class="n">data_dim</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">2</span>
<a id="__codelineno-1-18" name="__codelineno-1-18" href="#__codelineno-1-18"></a> <span class="n">scale</span> <span class="o">=</span> <span class="mf">0.1</span>
<a id="__codelineno-1-19" name="__codelineno-1-19" href="#__codelineno-1-19"></a> <span class="k">return</span> <span class="p">{</span>
<a id="__codelineno-1-20" name="__codelineno-1-20" href="#__codelineno-1-20"></a> <span class="s1">&#39;g_w1&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">(</span><span class="n">z_dim</span><span class="p">,</span> <span class="n">h_dim</span><span class="p">))</span> <span class="o">*</span> <span class="n">scale</span><span class="p">,</span>
<a id="__codelineno-1-21" name="__codelineno-1-21" href="#__codelineno-1-21"></a> <span class="s1">&#39;g_b1&#39;</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">h_dim</span><span class="p">),</span>
<a id="__codelineno-1-22" name="__codelineno-1-22" href="#__codelineno-1-22"></a> <span class="s1">&#39;g_w2&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="p">(</span><span class="n">h_dim</span><span class="p">,</span> <span class="n">h_dim</span><span class="p">))</span> <span class="o">*</span> <span class="n">scale</span><span class="p">,</span>
<a id="__codelineno-1-23" name="__codelineno-1-23" href="#__codelineno-1-23"></a> <span class="s1">&#39;g_b2&#39;</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">h_dim</span><span class="p">),</span>
<a id="__codelineno-1-24" name="__codelineno-1-24" href="#__codelineno-1-24"></a> <span class="s1">&#39;g_w3&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="p">(</span><span class="n">h_dim</span><span class="p">,</span> <span class="n">data_dim</span><span class="p">))</span> <span class="o">*</span> <span class="n">scale</span><span class="p">,</span>
<a id="__codelineno-1-25" name="__codelineno-1-25" href="#__codelineno-1-25"></a> <span class="s1">&#39;g_b3&#39;</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">data_dim</span><span class="p">),</span>
<a id="__codelineno-1-26" name="__codelineno-1-26" href="#__codelineno-1-26"></a> <span class="s1">&#39;d_w1&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="p">(</span><span class="n">data_dim</span><span class="p">,</span> <span class="n">h_dim</span><span class="p">))</span> <span class="o">*</span> <span class="n">scale</span><span class="p">,</span>
<a id="__codelineno-1-27" name="__codelineno-1-27" href="#__codelineno-1-27"></a> <span class="s1">&#39;d_b1&#39;</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">h_dim</span><span class="p">),</span>
<a id="__codelineno-1-28" name="__codelineno-1-28" href="#__codelineno-1-28"></a> <span class="s1">&#39;d_w2&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">4</span><span class="p">],</span> <span class="p">(</span><span class="n">h_dim</span><span class="p">,</span> <span class="n">h_dim</span><span class="p">))</span> <span class="o">*</span> <span class="n">scale</span><span class="p">,</span>
<a id="__codelineno-1-29" name="__codelineno-1-29" href="#__codelineno-1-29"></a> <span class="s1">&#39;d_b2&#39;</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">h_dim</span><span class="p">),</span>
<a id="__codelineno-1-30" name="__codelineno-1-30" href="#__codelineno-1-30"></a> <span class="s1">&#39;d_w3&#39;</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">keys</span><span class="p">[</span><span class="mi">5</span><span class="p">],</span> <span class="p">(</span><span class="n">h_dim</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="o">*</span> <span class="n">scale</span><span class="p">,</span>
<a id="__codelineno-1-31" name="__codelineno-1-31" href="#__codelineno-1-31"></a> <span class="s1">&#39;d_b3&#39;</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span>
<a id="__codelineno-1-32" name="__codelineno-1-32" href="#__codelineno-1-32"></a> <span class="p">}</span>
<a id="__codelineno-1-33" name="__codelineno-1-33" href="#__codelineno-1-33"></a>
<a id="__codelineno-1-34" name="__codelineno-1-34" href="#__codelineno-1-34"></a><span class="k">def</span><span class="w"> </span><span class="nf">d_loss</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">real_data</span><span class="p">,</span> <span class="n">fake_data</span><span class="p">):</span>
<a id="__codelineno-1-35" name="__codelineno-1-35" href="#__codelineno-1-35"></a> <span class="n">real_score</span> <span class="o">=</span> <span class="n">discriminator</span><span class="p">(</span><span class="n">real_data</span><span class="p">,</span> <span class="n">params</span><span class="p">)</span>
<a id="__codelineno-1-36" name="__codelineno-1-36" href="#__codelineno-1-36"></a> <span class="n">fake_score</span> <span class="o">=</span> <span class="n">discriminator</span><span class="p">(</span><span class="n">fake_data</span><span class="p">,</span> <span class="n">params</span><span class="p">)</span>
<a id="__codelineno-1-37" name="__codelineno-1-37" href="#__codelineno-1-37"></a> <span class="k">return</span> <span class="o">-</span><span class="n">jnp</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">real_score</span> <span class="o">+</span> <span class="mf">1e-7</span><span class="p">)</span> <span class="o">+</span> <span class="n">jnp</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">fake_score</span> <span class="o">+</span> <span class="mf">1e-7</span><span class="p">))</span>
<a id="__codelineno-1-38" name="__codelineno-1-38" href="#__codelineno-1-38"></a>
<a id="__codelineno-1-39" name="__codelineno-1-39" href="#__codelineno-1-39"></a><span class="k">def</span><span class="w"> </span><span class="nf">g_loss</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">fake_data</span><span class="p">):</span>
<a id="__codelineno-1-40" name="__codelineno-1-40" href="#__codelineno-1-40"></a> <span class="n">fake_score</span> <span class="o">=</span> <span class="n">discriminator</span><span class="p">(</span><span class="n">fake_data</span><span class="p">,</span> <span class="n">params</span><span class="p">)</span>
<a id="__codelineno-1-41" name="__codelineno-1-41" href="#__codelineno-1-41"></a> <span class="k">return</span> <span class="o">-</span><span class="n">jnp</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">log</span><span class="p">(</span><span class="n">fake_score</span> <span class="o">+</span> <span class="mf">1e-7</span><span class="p">))</span>
<a id="__codelineno-1-42" name="__codelineno-1-42" href="#__codelineno-1-42"></a>
<a id="__codelineno-1-43" name="__codelineno-1-43" href="#__codelineno-1-43"></a><span class="c1"># 真实数据:环形分布</span>
<a id="__codelineno-1-44" name="__codelineno-1-44" href="#__codelineno-1-44"></a><span class="n">key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<a id="__codelineno-1-45" name="__codelineno-1-45" href="#__codelineno-1-45"></a><span class="n">theta</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="p">(</span><span class="mi">512</span><span class="p">,))</span> <span class="o">*</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">jnp</span><span class="o">.</span><span class="n">pi</span>
<a id="__codelineno-1-46" name="__codelineno-1-46" href="#__codelineno-1-46"></a><span class="n">real_data</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">jnp</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">theta</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">theta</span><span class="p">)],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<a id="__codelineno-1-47" name="__codelineno-1-47" href="#__codelineno-1-47"></a><span class="n">real_data</span> <span class="o">=</span> <span class="n">real_data</span> <span class="o">+</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">real_data</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.05</span>
<a id="__codelineno-1-48" name="__codelineno-1-48" href="#__codelineno-1-48"></a>
<a id="__codelineno-1-49" name="__codelineno-1-49" href="#__codelineno-1-49"></a><span class="n">params</span> <span class="o">=</span> <span class="n">init_params</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<a id="__codelineno-1-50" name="__codelineno-1-50" href="#__codelineno-1-50"></a><span class="n">d_grad</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">d_loss</span><span class="p">)</span>
<a id="__codelineno-1-51" name="__codelineno-1-51" href="#__codelineno-1-51"></a><span class="n">g_grad</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">g_loss</span><span class="p">)</span>
<a id="__codelineno-1-52" name="__codelineno-1-52" href="#__codelineno-1-52"></a><span class="n">lr</span> <span class="o">=</span> <span class="mf">0.001</span>
<a id="__codelineno-1-53" name="__codelineno-1-53" href="#__codelineno-1-53"></a>
<a id="__codelineno-1-54" name="__codelineno-1-54" href="#__codelineno-1-54"></a><span class="n">snapshots</span> <span class="o">=</span> <span class="p">[]</span>
<a id="__codelineno-1-55" name="__codelineno-1-55" href="#__codelineno-1-55"></a><span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3000</span><span class="p">):</span>
<a id="__codelineno-1-56" name="__codelineno-1-56" href="#__codelineno-1-56"></a> <span class="n">key</span><span class="p">,</span> <span class="n">k1</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
<a id="__codelineno-1-57" name="__codelineno-1-57" href="#__codelineno-1-57"></a> <span class="n">z</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">k1</span><span class="p">,</span> <span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>
<a id="__codelineno-1-58" name="__codelineno-1-58" href="#__codelineno-1-58"></a> <span class="n">fake_data</span> <span class="o">=</span> <span class="n">generator</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">params</span><span class="p">)</span>
<a id="__codelineno-1-59" name="__codelineno-1-59" href="#__codelineno-1-59"></a>
<a id="__codelineno-1-60" name="__codelineno-1-60" href="#__codelineno-1-60"></a> <span class="c1"># 更新判别器</span>
<a id="__codelineno-1-61" name="__codelineno-1-61" href="#__codelineno-1-61"></a> <span class="n">grads</span> <span class="o">=</span> <span class="n">d_grad</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">real_data</span><span class="p">,</span> <span class="n">fake_data</span><span class="p">)</span>
<a id="__codelineno-1-62" name="__codelineno-1-62" href="#__codelineno-1-62"></a> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;d_w1&#39;</span><span class="p">,</span> <span class="s1">&#39;d_b1&#39;</span><span class="p">,</span> <span class="s1">&#39;d_w2&#39;</span><span class="p">,</span> <span class="s1">&#39;d_b2&#39;</span><span class="p">,</span> <span class="s1">&#39;d_w3&#39;</span><span class="p">,</span> <span class="s1">&#39;d_b3&#39;</span><span class="p">]:</span>
<a id="__codelineno-1-63" name="__codelineno-1-63" href="#__codelineno-1-63"></a> <span class="n">params</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">params</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">-</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">grads</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
<a id="__codelineno-1-64" name="__codelineno-1-64" href="#__codelineno-1-64"></a>
<a id="__codelineno-1-65" name="__codelineno-1-65" href="#__codelineno-1-65"></a> <span class="c1"># 更新生成器</span>
<a id="__codelineno-1-66" name="__codelineno-1-66" href="#__codelineno-1-66"></a> <span class="n">fake_data</span> <span class="o">=</span> <span class="n">generator</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">params</span><span class="p">)</span>
<a id="__codelineno-1-67" name="__codelineno-1-67" href="#__codelineno-1-67"></a> <span class="n">grads</span> <span class="o">=</span> <span class="n">g_grad</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">fake_data</span><span class="p">)</span>
<a id="__codelineno-1-68" name="__codelineno-1-68" href="#__codelineno-1-68"></a> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;g_w1&#39;</span><span class="p">,</span> <span class="s1">&#39;g_b1&#39;</span><span class="p">,</span> <span class="s1">&#39;g_w2&#39;</span><span class="p">,</span> <span class="s1">&#39;g_b2&#39;</span><span class="p">,</span> <span class="s1">&#39;g_w3&#39;</span><span class="p">,</span> <span class="s1">&#39;g_b3&#39;</span><span class="p">]:</span>
<a id="__codelineno-1-69" name="__codelineno-1-69" href="#__codelineno-1-69"></a> <span class="n">params</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="n">params</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">-</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">grads</span><span class="p">[</span><span class="n">k</span><span class="p">]</span>
<a id="__codelineno-1-70" name="__codelineno-1-70" href="#__codelineno-1-70"></a>
<a id="__codelineno-1-71" name="__codelineno-1-71" href="#__codelineno-1-71"></a> <span class="k">if</span> <span class="n">step</span> <span class="ow">in</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">500</span><span class="p">,</span> <span class="mi">1500</span><span class="p">,</span> <span class="mi">2999</span><span class="p">]:</span>
<a id="__codelineno-1-72" name="__codelineno-1-72" href="#__codelineno-1-72"></a> <span class="n">snapshots</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">step</span><span class="p">,</span> <span class="n">fake_data</span><span class="o">.</span><span class="n">copy</span><span class="p">()))</span>
<a id="__codelineno-1-73" name="__codelineno-1-73" href="#__codelineno-1-73"></a>
<a id="__codelineno-1-74" name="__codelineno-1-74" href="#__codelineno-1-74"></a><span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
<a id="__codelineno-1-75" name="__codelineno-1-75" href="#__codelineno-1-75"></a><span class="k">for</span> <span class="n">ax</span><span class="p">,</span> <span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">fake</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">axes</span><span class="p">,</span> <span class="n">snapshots</span><span class="p">):</span>
<a id="__codelineno-1-76" name="__codelineno-1-76" href="#__codelineno-1-76"></a> <span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">real_data</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">real_data</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">s</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s1">&#39;#3498db&#39;</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">&#39;真实&#39;</span><span class="p">)</span>
<a id="__codelineno-1-77" name="__codelineno-1-77" href="#__codelineno-1-77"></a> <span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">fake</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">fake</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">s</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="s1">&#39;#e74c3c&#39;</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s1">&#39;生成&#39;</span><span class="p">)</span>
<a id="__codelineno-1-78" name="__codelineno-1-78" href="#__codelineno-1-78"></a> <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;步骤 </span><span class="si">{</span><span class="n">step</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">);</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_xlim</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">);</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_ylim</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<a id="__codelineno-1-79" name="__codelineno-1-79" href="#__codelineno-1-79"></a> <span class="n">ax</span><span class="o">.</span><span class="n">set_aspect</span><span class="p">(</span><span class="s1">&#39;equal&#39;</span><span class="p">);</span> <span class="n">ax</span><span class="o">.</span><span class="n">legend</span><span class="p">(</span><span class="n">markerscale</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<a id="__codelineno-1-80" name="__codelineno-1-80" href="#__codelineno-1-80"></a><span class="n">plt</span><span class="o">.</span><span class="n">suptitle</span><span class="p">(</span><span class="s1">&#39;GAN训练:生成器学习环形分布&#39;</span><span class="p">)</span>
<a id="__codelineno-1-81" name="__codelineno-1-81" href="#__codelineno-1-81"></a><span class="n">plt</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">();</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></p>
</li>
<li>
<p>实现扩散前向过程:在不同时间步向图像添加噪声,并可视化逐步破坏过程。然后实现单步去噪。
<div class="highlight"><pre><span></span><code><a id="__codelineno-2-1" name="__codelineno-2-1" href="#__codelineno-2-1"></a><span class="kn">import</span><span class="w"> </span><span class="nn">jax</span>
<a id="__codelineno-2-2" name="__codelineno-2-2" href="#__codelineno-2-2"></a><span class="kn">import</span><span class="w"> </span><span class="nn">jax.numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">jnp</span>
<a id="__codelineno-2-3" name="__codelineno-2-3" href="#__codelineno-2-3"></a><span class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span>
<a id="__codelineno-2-4" name="__codelineno-2-4" href="#__codelineno-2-4"></a>
<a id="__codelineno-2-5" name="__codelineno-2-5" href="#__codelineno-2-5"></a><span class="k">def</span><span class="w"> </span><span class="nf">noise_schedule</span><span class="p">(</span><span class="n">T</span><span class="p">,</span> <span class="n">beta_start</span><span class="o">=</span><span class="mf">0.0001</span><span class="p">,</span> <span class="n">beta_end</span><span class="o">=</span><span class="mf">0.02</span><span class="p">):</span>
<a id="__codelineno-2-6" name="__codelineno-2-6" href="#__codelineno-2-6"></a><span class="w"> </span><span class="sd">&quot;&quot;&quot;线性噪声调度。&quot;&quot;&quot;</span>
<a id="__codelineno-2-7" name="__codelineno-2-7" href="#__codelineno-2-7"></a> <span class="n">betas</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">beta_start</span><span class="p">,</span> <span class="n">beta_end</span><span class="p">,</span> <span class="n">T</span><span class="p">)</span>
<a id="__codelineno-2-8" name="__codelineno-2-8" href="#__codelineno-2-8"></a> <span class="n">alphas</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">betas</span>
<a id="__codelineno-2-9" name="__codelineno-2-9" href="#__codelineno-2-9"></a> <span class="n">alpha_bars</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">cumprod</span><span class="p">(</span><span class="n">alphas</span><span class="p">)</span>
<a id="__codelineno-2-10" name="__codelineno-2-10" href="#__codelineno-2-10"></a> <span class="k">return</span> <span class="n">betas</span><span class="p">,</span> <span class="n">alphas</span><span class="p">,</span> <span class="n">alpha_bars</span>
<a id="__codelineno-2-11" name="__codelineno-2-11" href="#__codelineno-2-11"></a>
<a id="__codelineno-2-12" name="__codelineno-2-12" href="#__codelineno-2-12"></a><span class="k">def</span><span class="w"> </span><span class="nf">forward_diffusion</span><span class="p">(</span><span class="n">x0</span><span class="p">,</span> <span class="n">t</span><span class="p">,</span> <span class="n">alpha_bars</span><span class="p">,</span> <span class="n">key</span><span class="p">):</span>
<a id="__codelineno-2-13" name="__codelineno-2-13" href="#__codelineno-2-13"></a><span class="w"> </span><span class="sd">&quot;&quot;&quot;在时间步t向x0添加噪声。&quot;&quot;&quot;</span>
<a id="__codelineno-2-14" name="__codelineno-2-14" href="#__codelineno-2-14"></a> <span class="n">alpha_bar_t</span> <span class="o">=</span> <span class="n">alpha_bars</span><span class="p">[</span><span class="n">t</span><span class="p">]</span>
<a id="__codelineno-2-15" name="__codelineno-2-15" href="#__codelineno-2-15"></a> <span class="n">noise</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">x0</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<a id="__codelineno-2-16" name="__codelineno-2-16" href="#__codelineno-2-16"></a> <span class="n">xt</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">alpha_bar_t</span><span class="p">)</span> <span class="o">*</span> <span class="n">x0</span> <span class="o">+</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">alpha_bar_t</span><span class="p">)</span> <span class="o">*</span> <span class="n">noise</span>
<a id="__codelineno-2-17" name="__codelineno-2-17" href="#__codelineno-2-17"></a> <span class="k">return</span> <span class="n">xt</span><span class="p">,</span> <span class="n">noise</span>
<a id="__codelineno-2-18" name="__codelineno-2-18" href="#__codelineno-2-18"></a>
<a id="__codelineno-2-19" name="__codelineno-2-19" href="#__codelineno-2-19"></a><span class="c1"># 创建简单的2D&quot;图像&quot;(棋盘格)</span>
<a id="__codelineno-2-20" name="__codelineno-2-20" href="#__codelineno-2-20"></a><span class="n">img</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">))</span>
<a id="__codelineno-2-21" name="__codelineno-2-21" href="#__codelineno-2-21"></a><span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">):</span>
<a id="__codelineno-2-22" name="__codelineno-2-22" href="#__codelineno-2-22"></a> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">):</span>
<a id="__codelineno-2-23" name="__codelineno-2-23" href="#__codelineno-2-23"></a> <span class="k">if</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="n">j</span><span class="p">)</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<a id="__codelineno-2-24" name="__codelineno-2-24" href="#__codelineno-2-24"></a> <span class="n">img</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">at</span><span class="p">[</span><span class="n">i</span><span class="o">*</span><span class="mi">8</span><span class="p">:(</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="mi">8</span><span class="p">,</span> <span class="n">j</span><span class="o">*</span><span class="mi">8</span><span class="p">:(</span><span class="n">j</span><span class="o">+</span><span class="mi">1</span><span class="p">)</span><span class="o">*</span><span class="mi">8</span><span class="p">]</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span>
<a id="__codelineno-2-25" name="__codelineno-2-25" href="#__codelineno-2-25"></a>
<a id="__codelineno-2-26" name="__codelineno-2-26" href="#__codelineno-2-26"></a><span class="n">T</span> <span class="o">=</span> <span class="mi">1000</span>
<a id="__codelineno-2-27" name="__codelineno-2-27" href="#__codelineno-2-27"></a><span class="n">betas</span><span class="p">,</span> <span class="n">alphas</span><span class="p">,</span> <span class="n">alpha_bars</span> <span class="o">=</span> <span class="n">noise_schedule</span><span class="p">(</span><span class="n">T</span><span class="p">)</span>
<a id="__codelineno-2-28" name="__codelineno-2-28" href="#__codelineno-2-28"></a>
<a id="__codelineno-2-29" name="__codelineno-2-29" href="#__codelineno-2-29"></a><span class="c1"># 可视化前向过程</span>
<a id="__codelineno-2-30" name="__codelineno-2-30" href="#__codelineno-2-30"></a><span class="n">timesteps</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">200</span><span class="p">,</span> <span class="mi">500</span><span class="p">,</span> <span class="mi">999</span><span class="p">]</span>
<a id="__codelineno-2-31" name="__codelineno-2-31" href="#__codelineno-2-31"></a><span class="n">key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>
<a id="__codelineno-2-32" name="__codelineno-2-32" href="#__codelineno-2-32"></a>
<a id="__codelineno-2-33" name="__codelineno-2-33" href="#__codelineno-2-33"></a><span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">timesteps</span><span class="p">),</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mf">3.5</span><span class="p">))</span>
<a id="__codelineno-2-34" name="__codelineno-2-34" href="#__codelineno-2-34"></a><span class="k">for</span> <span class="n">ax</span><span class="p">,</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">axes</span><span class="p">,</span> <span class="n">timesteps</span><span class="p">):</span>
<a id="__codelineno-2-35" name="__codelineno-2-35" href="#__codelineno-2-35"></a> <span class="n">key</span><span class="p">,</span> <span class="n">subkey</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
<a id="__codelineno-2-36" name="__codelineno-2-36" href="#__codelineno-2-36"></a> <span class="n">xt</span><span class="p">,</span> <span class="n">noise</span> <span class="o">=</span> <span class="n">forward_diffusion</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">t</span><span class="p">,</span> <span class="n">alpha_bars</span><span class="p">,</span> <span class="n">subkey</span><span class="p">)</span>
<a id="__codelineno-2-37" name="__codelineno-2-37" href="#__codelineno-2-37"></a> <span class="n">ax</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">xt</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s1">&#39;gray&#39;</span><span class="p">,</span> <span class="n">vmin</span><span class="o">=-</span><span class="mi">2</span><span class="p">,</span> <span class="n">vmax</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<a id="__codelineno-2-38" name="__codelineno-2-38" href="#__codelineno-2-38"></a> <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;t=</span><span class="si">{</span><span class="n">t</span><span class="si">}</span><span class="se">\n</span><span class="s1">$</span><span class="se">\\</span><span class="s1">bar</span><span class="se">{{\\</span><span class="s1">alpha</span><span class="se">}}</span><span class="s1">$=</span><span class="si">{</span><span class="n">alpha_bars</span><span class="p">[</span><span class="n">t</span><span class="p">]</span><span class="si">:</span><span class="s1">.3f</span><span class="si">}</span><span class="s1">&#39;</span><span class="p">)</span>
<a id="__codelineno-2-39" name="__codelineno-2-39" href="#__codelineno-2-39"></a> <span class="n">ax</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">&#39;off&#39;</span><span class="p">)</span>
<a id="__codelineno-2-40" name="__codelineno-2-40" href="#__codelineno-2-40"></a><span class="n">plt</span><span class="o">.</span><span class="n">suptitle</span><span class="p">(</span><span class="s1">&#39;扩散前向过程:逐步添加噪声&#39;</span><span class="p">)</span>
<a id="__codelineno-2-41" name="__codelineno-2-41" href="#__codelineno-2-41"></a><span class="n">plt</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">();</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
<a id="__codelineno-2-42" name="__codelineno-2-42" href="#__codelineno-2-42"></a>
<a id="__codelineno-2-43" name="__codelineno-2-43" href="#__codelineno-2-43"></a><span class="c1"># 简单去噪:在t=200时用梯度下降拟合噪声估计(仅作演示,并未训练真正的网络)</span>
<a id="__codelineno-2-44" name="__codelineno-2-44" href="#__codelineno-2-44"></a><span class="n">t_denoise</span> <span class="o">=</span> <span class="mi">200</span>
<a id="__codelineno-2-45" name="__codelineno-2-45" href="#__codelineno-2-45"></a><span class="n">key</span><span class="p">,</span> <span class="n">k1</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">)</span>
<a id="__codelineno-2-46" name="__codelineno-2-46" href="#__codelineno-2-46"></a><span class="n">xt</span><span class="p">,</span> <span class="n">true_noise</span> <span class="o">=</span> <span class="n">forward_diffusion</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">t_denoise</span><span class="p">,</span> <span class="n">alpha_bars</span><span class="p">,</span> <span class="n">k1</span><span class="p">)</span>
<a id="__codelineno-2-47" name="__codelineno-2-47" href="#__codelineno-2-47"></a>
<a id="__codelineno-2-48" name="__codelineno-2-48" href="#__codelineno-2-48"></a><span class="c1"># 小型&quot;去噪器&quot;:仅学习恒定的噪声估计(用于演示)</span>
<a id="__codelineno-2-49" name="__codelineno-2-49" href="#__codelineno-2-49"></a><span class="n">noise_estimate</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
<a id="__codelineno-2-50" name="__codelineno-2-50" href="#__codelineno-2-50"></a><span class="n">lr</span> <span class="o">=</span> <span class="mf">0.01</span>
<a id="__codelineno-2-51" name="__codelineno-2-51" href="#__codelineno-2-51"></a><span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
<a id="__codelineno-2-52" name="__codelineno-2-52" href="#__codelineno-2-52"></a> <span class="n">residual</span> <span class="o">=</span> <span class="n">noise_estimate</span> <span class="o">-</span> <span class="n">true_noise</span>
<a id="__codelineno-2-53" name="__codelineno-2-53" href="#__codelineno-2-53"></a> <span class="n">noise_estimate</span> <span class="o">=</span> <span class="n">noise_estimate</span> <span class="o">-</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">residual</span>
<a id="__codelineno-2-54" name="__codelineno-2-54" href="#__codelineno-2-54"></a>
<a id="__codelineno-2-55" name="__codelineno-2-55" href="#__codelineno-2-55"></a><span class="c1"># 反向一步</span>
<a id="__codelineno-2-56" name="__codelineno-2-56" href="#__codelineno-2-56"></a><span class="n">alpha_bar_t</span> <span class="o">=</span> <span class="n">alpha_bars</span><span class="p">[</span><span class="n">t_denoise</span><span class="p">]</span>
<a id="__codelineno-2-57" name="__codelineno-2-57" href="#__codelineno-2-57"></a><span class="n">x_denoised</span> <span class="o">=</span> <span class="p">(</span><span class="n">xt</span> <span class="o">-</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">alpha_bar_t</span><span class="p">)</span> <span class="o">*</span> <span class="n">noise_estimate</span><span class="p">)</span> <span class="o">/</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">alpha_bar_t</span><span class="p">)</span>
<a id="__codelineno-2-58" name="__codelineno-2-58" href="#__codelineno-2-58"></a>
<a id="__codelineno-2-59" name="__codelineno-2-59" href="#__codelineno-2-59"></a><span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
<a id="__codelineno-2-60" name="__codelineno-2-60" href="#__codelineno-2-60"></a><span class="n">axes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s1">&#39;gray&#39;</span><span class="p">);</span> <span class="n">axes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s1">&#39;原始 $x_0$&#39;</span><span class="p">);</span> <span class="n">axes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">&#39;off&#39;</span><span class="p">)</span>
<a id="__codelineno-2-61" name="__codelineno-2-61" href="#__codelineno-2-61"></a><span class="n">axes</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">xt</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s1">&#39;gray&#39;</span><span class="p">,</span> <span class="n">vmin</span><span class="o">=-</span><span class="mi">2</span><span class="p">,</span> <span class="n">vmax</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<a id="__codelineno-2-62" name="__codelineno-2-62" href="#__codelineno-2-62"></a><span class="n">axes</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s1">&#39;含噪 $x_</span><span class="se">{{</span><span class="s1">200</span><span class="se">}}</span><span class="s1">$&#39;</span><span class="p">);</span> <span class="n">axes</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">&#39;off&#39;</span><span class="p">)</span>
<a id="__codelineno-2-63" name="__codelineno-2-63" href="#__codelineno-2-63"></a><span class="n">axes</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">x_denoised</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s1">&#39;gray&#39;</span><span class="p">)</span>
<a id="__codelineno-2-64" name="__codelineno-2-64" href="#__codelineno-2-64"></a><span class="n">axes</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s1">&#39;去噪后(单步)&#39;</span><span class="p">);</span> <span class="n">axes</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">&#39;off&#39;</span><span class="p">)</span>
<a id="__codelineno-2-65" name="__codelineno-2-65" href="#__codelineno-2-65"></a><span class="n">plt</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">();</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
<a id="__codelineno-2-66" name="__codelineno-2-66" href="#__codelineno-2-66"></a>
<a id="__codelineno-2-67" name="__codelineno-2-67" href="#__codelineno-2-67"></a><span class="n">mse</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">mean</span><span class="p">((</span><span class="n">x_denoised</span> <span class="o">-</span> <span class="n">img</span><span class="p">)</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span>
<a id="__codelineno-2-68" name="__codelineno-2-68" href="#__codelineno-2-68"></a><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;去噪MSE: </span><span class="si">{</span><span class="n">mse</span><span class="si">:</span><span class="s2">.4f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</code></pre></div></p>
</li>
</ol>
</article>
</div>
<script>
  // Look up the element named by the URL fragment (without the leading "#").
  // When it exists and has a non-empty `name`, set its `checked` property to
  // whether that name starts with "__tabbed_": true for such names, false for
  // any other named element.
  var target = document.getElementById(location.hash.slice(1));
  if (target && target.name) {
    target.checked = target.name.startsWith("__tabbed_");
  }
</script>
</div>
<!-- Back-to-top control. It ships with the `hidden` attribute; presumably the theme bundle reveals it on scroll ("navigation.top" is listed in the page config) — confirm against the bundle. -->
<button type="button" class="md-top md-icon" data-md-component="top" hidden>
  <!-- Decorative arrow: the visible text below already names the button, so the icon is hidden from assistive technology. -->
  <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" aria-hidden="true"><path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8z"/></svg>
  回到页面顶部
</button>
</main>
<!-- Page footer: previous/next chapter navigation, followed by the generator credit and the social link. -->
<footer class="md-footer">
  <nav class="md-footer__inner md-grid" aria-label="页脚">
    <!-- Previous chapter. The aria-label supplies the link's accessible name, so the arrow icon is decorative. -->
    <a href="../03.%20object%20detection%20and%20segmentation/" class="md-footer__link md-footer__link--prev" aria-label="上一页: 目标检测与分割">
      <div class="md-footer__button md-icon">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" aria-hidden="true"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
      </div>
      <div class="md-footer__title">
        <span class="md-footer__direction">
          上一页
        </span>
        <div class="md-ellipsis">
          目标检测与分割
        </div>
      </div>
    </a>
    <!-- Next chapter. Same structure as above, with the icon placed after the title. -->
    <a href="../05.%20video%20and%203D%20vision/" class="md-footer__link md-footer__link--next" aria-label="下一页: 视频与 3D 视觉">
      <div class="md-footer__title">
        <span class="md-footer__direction">
          下一页
        </span>
        <div class="md-ellipsis">
          视频与 3D 视觉
        </div>
      </div>
      <div class="md-footer__button md-icon">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" aria-hidden="true"><path d="M4 11v2h12l-5.5 5.5 1.42 1.42L19.84 12l-7.92-7.92L10.5 5.5 16 11z"/></svg>
      </div>
    </a>
  </nav>
  <div class="md-footer-meta md-typeset">
    <div class="md-footer-meta__inner md-grid">
      <div class="md-copyright">
        Made with
        <a href="https://squidfunk.github.io/mkdocs-material/" target="_blank" rel="noopener">
          Material for MkDocs
        </a>
      </div>
      <div class="md-social">
        <!-- Icon-only link: aria-label provides a reliable accessible name (a title attribute alone is not consistently announced); the title is kept as the hover tooltip. -->
        <a href="https://github.com/flykhan/maths-cs-ai-compendium-zh" target="_blank" rel="noopener" title="github.com" class="md-social__link" aria-label="GitHub">
          <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 512 512" aria-hidden="true"><!--! Font Awesome Free 7.1.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2025 Fonticons, Inc.--><path d="M173.9 397.4c0 2-2.3 3.6-5.2 3.6-3.3.3-5.6-1.3-5.6-3.6 0-2 2.3-3.6 5.2-3.6 3-.3 5.6 1.3 5.6 3.6m-31.1-4.5c-.7 2 1.3 4.3 4.3 4.9 2.6 1 5.6 0 6.2-2s-1.3-4.3-4.3-5.2c-2.6-.7-5.5.3-6.2 2.3m44.2-1.7c-2.9.7-4.9 2.6-4.6 4.9.3 2 2.9 3.3 5.9 2.6 2.9-.7 4.9-2.6 4.6-4.6-.3-1.9-3-3.2-5.9-2.9M252.8 8C114.1 8 8 113.3 8 252c0 110.9 69.8 205.8 169.5 239.2 12.8 2.3 17.3-5.6 17.3-12.1 0-6.2-.3-40.4-.3-61.4 0 0-70 15-84.7-29.8 0 0-11.4-29.1-27.8-36.6 0 0-22.9-15.7 1.6-15.4 0 0 24.9 2 38.6 25.8 21.9 38.6 58.6 27.5 72.9 20.9 2.3-16 8.8-27.1 16-33.7-55.9-6.2-112.3-14.3-112.3-110.5 0-27.5 7.6-41.3 23.6-58.9-2.6-6.5-11.1-33.3 2.6-67.9 20.9-6.5 69 27 69 27 20-5.6 41.5-8.5 62.8-8.5s42.8 2.9 62.8 8.5c0 0 48.1-33.6 69-27 13.7 34.7 5.2 61.4 2.6 67.9 16 17.7 25.8 31.5 25.8 58.9 0 96.5-58.9 104.2-114.8 110.5 9.2 7.9 17 22.9 17 46.4 0 33.7-.3 75.4-.3 83.6 0 6.5 4.6 14.4 17.3 12.1C436.2 457.8 504 362.9 504 252 504 113.3 391.5 8 252.8 8M105.2 352.9c-1.3 1-1 3.3.7 5.2 1.6 1.6 3.9 2.3 5.2 1 1.3-1 1-3.3-.7-5.2-1.6-1.6-3.9-2.3-5.2-1m-10.8-8.1c-.7 1.3.3 2.9 2.3 3.9 1.6 1 3.6.7 4.3-.7.7-1.3-.3-2.9-2.3-3.9-2-.6-3.6-.3-4.3.7m32.4 35.6c-1.6 1.3-1 4.3 1.3 6.2 2.3 2.3 5.2 2.6 6.5 1 1.3-1.3.7-4.3-1.3-6.2-2.2-2.3-5.2-2.6-6.5-1m-11.4-14.7c-1.6 1-1.6 3.6 0 5.9s4.3 3.3 5.6 2.3c1.6-1.3 1.6-3.9 0-6.2-1.4-2.3-4-3.3-5.6-2"/></svg>
        </a>
      </div>
    </div>
  </div>
</footer>
</div>
<!-- Dialog container with an empty inner element. Nothing in this file fills it; presumably the theme bundle injects transient messages here (the page config carries a "clipboard.copied" string) — confirm against the bundle. -->
<div class="md-dialog" data-md-component="dialog">
  <div class="md-dialog__inner md-typeset"></div>
</div>
<!-- Page configuration embedded as JSON. type="application/json" marks it as a data block, so the browser does not execute it. The payload must stay strict JSON (no comments inside); the translation strings are \u-escaped. -->
<script id="__config" type="application/json">{"annotate": null, "base": "../..", "features": ["navigation.tabs", "navigation.sections", "navigation.expand", "navigation.top", "navigation.footer", "search.suggest", "search.highlight", "content.code.copy", "toc.follow"], "search": "../../assets/javascripts/workers/search.2c215733.min.js", "tags": null, "translations": {"clipboard.copied": "\u5df2\u590d\u5236", "clipboard.copy": "\u590d\u5236", "search.result.more.one": "\u5728\u8be5\u9875\u4e0a\u8fd8\u6709 1 \u4e2a\u7b26\u5408\u6761\u4ef6\u7684\u7ed3\u679c", "search.result.more.other": "\u5728\u8be5\u9875\u4e0a\u8fd8\u6709 # \u4e2a\u7b26\u5408\u6761\u4ef6\u7684\u7ed3\u679c", "search.result.none": "\u6ca1\u6709\u627e\u5230\u7b26\u5408\u6761\u4ef6\u7684\u7ed3\u679c", "search.result.one": "\u627e\u5230 1 \u4e2a\u7b26\u5408\u6761\u4ef6\u7684\u7ed3\u679c", "search.result.other": "# \u4e2a\u7b26\u5408\u6761\u4ef6\u7684\u7ed3\u679c", "search.result.placeholder": "\u952e\u5165\u4ee5\u5f00\u59cb\u641c\u7d22", "search.result.term.missing": "\u7f3a\u5c11", "select.version": "\u9009\u62e9\u5f53\u524d\u7248\u672c"}, "version": null}</script>
<!-- Classic scripts with neither async nor defer: they are fetched and executed in document order, so the theme bundle runs first, then the site-local mathjax.js, then the MathJax library. Presumably mathjax.js sets up the MathJax configuration and therefore has to precede the library — confirm before reordering or adding async/defer. -->
<script src="../../assets/javascripts/bundle.79ae519e.min.js"></script>
<script src="../../javascripts/mathjax.js"></script>
<!-- NOTE(review): third-party script loaded with a floating major version (mathjax@3) and no integrity attribute; an exact version would have to be pinned before Subresource Integrity can be added. -->
<script src="https://unpkg.com/mathjax@3/es5/tex-mml-chtml.js"></script>
</body>
</html>