From 2536c937e3b9fa19545dfac5350b88f15a2c2620 Mon Sep 17 00:00:00 2001 From: flykhan Date: Sun, 3 May 2026 10:23:20 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=8C=E6=95=B4=E4=B8=AD=E6=96=87?= =?UTF-8?q?=E7=BF=BB=E8=AF=91=20maths-cs-ai-compendium=EF=BC=88=E6=95=B0?= =?UTF-8?q?=E5=AD=A6=C2=B7=E8=AE=A1=E7=AE=97=E6=9C=BA=E7=A7=91=E5=AD=A6?= =?UTF-8?q?=C2=B7AI=20=E7=9F=A5=E8=AF=86=E5=A4=A7=E5=85=A8=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。 第01章 向量 | 第02章 矩阵 | 第03章 微积分 第04章 统计学 | 第05章 概率论 | 第06章 机器学习 第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音 第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络 第13章 计算与操作系统 | 第14章 数据结构与算法 第15章 生产级软件工程 | 第16章 SIMD与GPU编程 第17章 AI推理 | 第18章 ML系统设计 第19章 应用人工智能 | 第20章 前沿人工智能 翻译说明: - 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留 - mkdocs.yml 配置中文导航 + language: zh - README.md 已翻译为中文(兼 docs/index.md) - docs/ 目录包含指向各章文件的 symlink - 约 29,000 行中文内容,排除 .cache/ 构建缓存 --- .gitignore | 1 + LICENSE | 201 +++++ README.md | 79 ++ chapter 01: vectors/01. vector spaces.md | 147 ++++ chapter 01: vectors/02. vector properties.md | 97 +++ chapter 01: vectors/03. norms and metrics.md | 90 ++ chapter 01: vectors/04. products.md | 113 +++ chapter 01: vectors/05. basis and duality.md | 89 ++ chapter 02: matrices/01. matrix properties.md | 166 ++++ chapter 02: matrices/02. matrix types.md | 139 +++ chapter 02: matrices/03. operations.md | 146 ++++ .../04. linear transformations.md | 163 ++++ chapter 02: matrices/05. decompositions.md | 211 +++++ .../01. differential calculus.md | 205 +++++ chapter 03: calculus/02. integral calculus.md | 110 +++ .../03. multivariate calculus.md | 219 +++++ .../04. function approximation.md | 143 ++++ chapter 03: calculus/05. optimisation.md | 145 ++++ chapter 04: statistics/01. fundamentals.md | 186 ++++ chapter 04: statistics/02. measures.md | 178 ++++ chapter 04: statistics/03. sampling.md | 164 ++++ .../04. hypothesis testing.md | 177 ++++ chapter 04: statistics/05. inference.md | 210 +++++ chapter 05: probability/01. counting.md | 167 ++++ .../02. probability concepts.md | 243 ++++++ chapter 05: probability/03. distributions.md | 238 ++++++ chapter 05: probability/04. bayesian.md | 292 +++++++ .../05. information theory.md | 191 +++++ .../01. classical machine learning.md | 381 +++++++++ .../02. gradient machine learning.md | 408 +++++++++ .../03. deep learning.md | 354 ++++++++ .../04. reinforcement learning.md | 353 ++++++++ .../05. distributed deep learning.md | 264 ++++++ .../01. linguistic foundations.md | 303 +++++++ .../02. text processing and classic NLP.md | 340 ++++++++ .../03. embeddings and sequence models.md | 389 +++++++++ .../04. transformers and language models.md | 517 +++++++++++ .../05. advanced text generation.md | 567 ++++++++++++ .../01. image fundamentals.md | 363 ++++++++ .../02. convolutional networks.md | 382 +++++++++ .../03. object detection and segmentation.md | 376 ++++++++ .../04. vision transformers and generation.md | 360 ++++++++ .../05. video and 3D vision.md | 347 ++++++++ .../01. digital signal processing.md | 498 +++++++++++ .../02. automatic speech recognition.md | 592 +++++++++++++ .../03. text to speech and voice.md | 711 ++++++++++++++++ .../04. speaker and audio analysis.md | 643 ++++++++++++++ .../05. source separation and noise.md | 804 ++++++++++++++++++ .../01. multimodal representations.md | 366 ++++++++ .../02. vision language models.md | 388 +++++++++ .../03. image and video tokenisation.md | 419 +++++++++ .../04. cross-modal generation.md | 405 +++++++++ .../05. unified multimodal architectures.md | 322 +++++++ .../01. perception.md | 287 +++++++ .../02. robot learning.md | 298 +++++++ .../03. vision-language-action models.md | 217 +++++ .../04. self-driving.md | 312 +++++++ .../05. space and extreme robotics.md | 311 +++++++ .../01. geometric deep learning.md | 170 ++++ .../02. graph theory.md | 236 +++++ .../03. graph neural networks.md | 271 ++++++ .../04. graph attention networks.md | 258 ++++++ .../05. 3d graph networks.md | 274 ++++++ .../01. discrete maths.md | 253 ++++++ .../02. computer architecture.md | 233 +++++ .../03. operating systems.md | 258 ++++++ .../04. concurrency and parallelism.md | 226 +++++ .../05. programming languages.md | 273 ++++++ .../00. foundations.md | 500 +++++++++++ .../01. arrays and hashing.md | 526 ++++++++++++ .../02. linked lists, stacks, and queues.md | 410 +++++++++ .../03. trees.md | 381 +++++++++ .../04. graphs.md | 323 +++++++ .../05. sorting and search.md | 553 ++++++++++++ .../01. linux and CMD.md | 318 +++++++ .../02. git and repository management.md | 218 +++++ .../03. codebase design.md | 383 +++++++++ .../04. testing and quality assurance.md | 322 +++++++ .../05. deployment and devops.md | 233 +++++ .../00. why C++ and how ML frameworks work.md | 419 +++++++++ .../01. hardware fundamentals.md | 282 ++++++ .../02. ARM and NEON.md | 484 +++++++++++ .../03. x86 and AVX.md | 450 ++++++++++ .../04. GPU architecture and CUDA.md | 598 +++++++++++++ .../05. triton, TPUs and pallax.md | 393 +++++++++ .../06. RISC-V and embedded systems.md | 428 ++++++++++ .... vulkan compute and cross-platform GPU.md | 668 +++++++++++++++ chapter 17: AI inference/01. quantisation.md | 339 ++++++++ .../02. efficient architectures.md | 238 ++++++ .../03. serving and batching.md | 236 +++++ .../04. edge inference.md | 212 +++++ .../05. scaling and deployment.md | 251 ++++++ .../01. systems design fundamentals.md | 176 ++++ .../02. cloud computing.md | 165 ++++ .../03. large scale infrastructure.md | 200 +++++ .../04. ML systems design.md | 183 ++++ .../05. ML design examples.md | 337 ++++++++ chapter 19: applied AI/01. AI for finance.md | 10 + chapter 19: applied AI/02. protein design.md | 10 + chapter 19: applied AI/03. drug discovery.md | 0 chapter 19: applied AI/04. agentic systems.md | 0 chapter 19: applied AI/05. healthcare.md | 0 .../01. quantum machine learning.md | 11 + .../02. neuromorphic computing.md | 10 + .../03. datacentres in space.md | 0 .../04. decentralised AI.md | 0 .../05. brain machine interfaces.md | 0 docs/chapter 01: vectors | 1 + docs/chapter 02: matrices | 1 + docs/chapter 03: calculus | 1 + docs/chapter 04: statistics | 1 + docs/chapter 05: probability | 1 + docs/chapter 06: machine learning | 1 + docs/chapter 07: computational linguistics | 1 + docs/chapter 08: computer vision | 1 + docs/chapter 09: audio and speech | 1 + docs/chapter 10: multimodal learning | 1 + docs/chapter 11: autonomous systems | 1 + docs/chapter 12: graph neural networks | 1 + docs/chapter 13: computing and OS | 1 + ...chapter 14: data structures and algorithms | 1 + ...hapter 15: production software engineering | 1 + docs/chapter 16: SIMD and GPU programming | 1 + docs/chapter 17: AI inference | 1 + docs/chapter 18: ML systems design | 1 + docs/chapter 19: applied AI | 1 + docs/chapter 20: bleeding edge AI | 1 + docs/images | 1 + docs/index.md | 1 + docs/javascripts | 1 + images/ab_testing.svg | 46 + images/action_tokenisation.svg | 65 ++ images/activation_functions.svg | 51 ++ images/actor_critic.svg | 45 + images/additive_inverse.svg | 21 + images/amdahl_serial_bottleneck.svg | 33 + images/any_to_any_architectures.svg | 187 ++++ images/area_under_curve.svg | 40 + images/asr_pipeline.svg | 65 ++ images/attention_alignment.svg | 79 ++ images/attention_captioning.svg | 116 +++ images/attention_sparsity_patterns.svg | 48 ++ images/audio_spectrogram_transformer.svg | 96 +++ images/audio_visual_correspondence.svg | 95 +++ images/audio_waveform.svg | 68 ++ images/autonomous_driving_stack.svg | 54 ++ images/bag_of_words.svg | 107 +++ images/basis_transform.svg | 44 + images/bayes_components.svg | 41 + images/beamforming.svg | 106 +++ images/bernoulli_binomial.svg | 75 ++ images/bert_mlm.svg | 81 ++ images/bev_fusion_pipeline.svg | 62 ++ images/bidirectional_rnn.svg | 91 ++ images/bio_tagging.svg | 43 + images/cache_aside_pattern.svg | 41 + images/cap_theorem.svg | 33 + images/central_limit_theorem.svg | 72 ++ images/chain_rule.svg | 32 + images/clip_contrastive_matrix.svg | 92 ++ images/cloud_service_layers.svg | 60 ++ images/cnn_convolution.svg | 123 +++ images/cocktail_party.svg | 86 ++ images/codebook_collapse.svg | 111 +++ images/cofactor.svg | 39 + images/column_space.svg | 46 + images/common_distributions.svg | 66 ++ images/commutativity.svg | 29 + images/compilation_pipeline.svg | 66 ++ images/concurrency_vs_parallelism.svg | 62 ++ images/conditional_probability.svg | 47 + images/confidence_interval.svg | 37 + images/conformer_block.svg | 88 ++ images/constituency_tree.svg | 73 ++ images/container_vs_vm.svg | 75 ++ images/continuous_vs_discrete_tokens.svg | 103 +++ images/contrastive_temperature.svg | 87 ++ images/conv_tasnet.svg | 109 +++ images/convex_nonconvex.svg | 76 ++ images/correlation_scatter.svg | 63 ++ images/counting_outfits.svg | 71 ++ images/cpu_pipeline.svg | 62 ++ images/critical_points.svg | 51 ++ images/crnn_ocr_pipeline.svg | 68 ++ images/cross_modal_generation_overview.svg | 92 ++ images/ctc_alignment.svg | 121 +++ images/dalle_autoregressive_pipeline.svg | 98 +++ images/data_model_parallelism.svg | 88 ++ images/deadlock_cycle.svg | 51 ++ images/decision_tree_split.svg | 61 ++ images/deeplab_aspp.svg | 105 +++ images/densenet_block.svg | 65 ++ images/dependency_tree.svg | 43 + images/depthwise_separable_conv.svg | 84 ++ images/detection_boxes.svg | 54 ++ images/determinant.svg | 29 + images/difference_quotient.svg | 43 + images/diffusion_process.svg | 109 +++ images/distribution_shift_bc.svg | 32 + images/distribution_types.svg | 37 + images/distributivity.svg | 38 + images/dit_architecture.svg | 119 +++ images/dot_product.svg | 36 + images/dual_vs_fusion_encoder.svg | 113 +++ images/earth_mars_delay.svg | 34 + images/efficientnet_scaling.svg | 71 ++ images/eigenvector.svg | 47 + images/ensemble_methods.svg | 93 ++ images/faster_rcnn.svg | 81 ++ images/fcos_detection.svg | 78 ++ images/feature_store.svg | 53 ++ images/five_geometric_domains.svg | 83 ++ images/flamingo_architecture.svg | 126 +++ images/focal_loss.svg | 69 ++ images/fpn_pyramid.svg | 96 +++ images/fraud_detection_pipeline.svg | 64 ++ images/fusion_strategies.svg | 124 +++ images/gat_attention_weights.svg | 52 ++ images/gaussian_elimination.svg | 43 + images/generation_evaluation_metrics.svg | 135 +++ images/gpu_cluster_topology.svg | 109 +++ images/grad_cam.svg | 79 ++ images/gradient_contour.svg | 46 + images/gradient_descent_landscape.svg | 50 ++ images/graph_adjacency_matrix.svg | 57 ++ images/graph_laplacian_smoothness.svg | 49 ++ images/grounding_coordinate_tokens.svg | 109 +++ images/hifi_gan_generator.svg | 127 +++ images/hmm_structure.svg | 68 ++ images/human_vectors.svg | 45 + images/hypothesis_test.svg | 48 ++ images/ieee754_float.svg | 35 + images/image_histogram.svg | 68 ++ images/image_pyramid.svg | 65 ++ images/image_tokenisation_overview.svg | 100 +++ images/image_tokeniser_comparison.svg | 115 +++ images/inception_module.svg | 67 ++ images/instructpix2pix_pipeline.svg | 94 ++ images/invariance_vs_equivariance.svg | 85 ++ images/joint_embedding_space.svg | 81 ++ images/kl_divergence.svg | 35 + images/kmeans_clustering.svg | 67 ++ images/lagrange_multiplier.svg | 47 + images/lidar_time_of_flight.svg | 39 + images/line_equation.svg | 44 + images/linear_regression_fit.svg | 52 ++ images/llava_architecture.svg | 96 +++ images/lms_adaptive_filter.svg | 84 ++ images/load_balancer.svg | 49 ++ images/logo.png | Bin 0 -> 802543 bytes images/lora_decomposition.svg | 73 ++ images/lu_decomposition.svg | 37 + images/magnitude_direction.svg | 43 + images/markov_chain.svg | 68 ++ images/mask_rcnn.svg | 106 +++ images/matrix_rank.svg | 45 + images/matrix_trace.svg | 20 + images/mdp_agent_loop.svg | 41 + images/mel_filterbank.svg | 105 +++ images/memory_hierarchy.svg | 37 + images/message_passing_gnn.svg | 63 ++ images/mfcc_pipeline.svg | 95 +++ images/mha_gqa_mqa.svg | 113 +++ images/microservices_architecture.svg | 71 ++ images/ml_lifecycle.svg | 49 ++ images/mla_vs_gqa.svg | 97 +++ images/mle_vs_map.svg | 55 ++ images/moe_layer.svg | 72 ++ images/moe_routing.svg | 72 ++ images/moments_shape.svg | 49 ++ images/monte_carlo_pi.svg | 56 ++ images/morpheme_tree.svg | 33 + images/multimodal_action_distribution.svg | 45 + images/multimodal_agent_loop.svg | 103 +++ images/multimodal_overview.svg | 91 ++ images/multimodal_tokenisation_sequence.svg | 98 +++ images/multimodal_world_model.svg | 114 +++ images/naive_bayes_classify.svg | 33 + images/newtons_method.svg | 53 ++ images/normal_empirical.svg | 48 ++ images/normalization_types.svg | 103 +++ images/occupancy_vs_bbox.svg | 63 ++ images/ocr_free_document_understanding.svg | 90 ++ images/open_vs_closed_loop.svg | 65 ++ images/optical_flow.svg | 93 ++ images/optimizer_memory.svg | 63 ++ images/optimizer_muon.svg | 60 ++ images/optimizer_trajectories.svg | 45 + images/orthogonal_vectors.svg | 21 + images/over_smoothing_gnn.svg | 47 + images/p_np_complexity.svg | 34 + images/paged_attention.svg | 92 ++ images/parallel_vectors.svg | 28 + images/partial_derivative.svg | 40 + images/pca.svg | 48 ++ images/permutation_vs_combination.svg | 61 ++ images/pid_response.svg | 28 + images/pinhole_camera.svg | 63 ++ images/polynomial_buildup.svg | 48 ++ images/precision_formats_memory.svg | 110 +++ images/process_states.svg | 53 ++ images/quantisation_granularity.svg | 68 ++ images/quartiles_boxplot.svg | 43 + images/rag_architecture.svg | 62 ++ images/random_variable.svg | 69 ++ images/receptive_field.svg | 127 +++ images/recommendation_pipeline.svg | 54 ++ images/reflection.svg | 27 + images/residual_quantisation.svg | 107 +++ images/resnet_block.svg | 60 ++ images/retrieval_recall_at_k.svg | 110 +++ images/reynolds_flocking.svg | 71 ++ images/rgb_channels.svg | 76 ++ images/ring_allreduce.svg | 50 ++ images/rlhf_pipeline.svg | 66 ++ images/rnn_lstm_cell.svg | 79 ++ images/robot_arm_fk.svg | 40 + images/rope_rotation.svg | 63 ++ images/rotation.svg | 27 + images/sae_autonomy_levels.svg | 59 ++ images/sampling_aliasing.svg | 113 +++ images/sampling_methods.svg | 102 +++ images/scalar_multiplication.svg | 29 + images/scaling.svg | 19 + images/scaling_laws.svg | 80 ++ images/scaling_vlms_comparison.svg | 96 +++ images/se3_equivariance.svg | 64 ++ images/semantic_segmentation.svg | 60 ++ images/sensor_comparison.svg | 56 ++ images/seq2seq_architecture.svg | 84 ++ images/shared_autonomy_spectrum.svg | 33 + images/shared_backbone_multimodal.svg | 146 ++++ images/shearing.svg | 23 + images/sigmoid_logistic.svg | 40 + images/sim_to_real.svg | 54 ++ images/sobel_edges.svg | 52 ++ images/social_network_matrix.svg | 77 ++ images/sparse_attention_patterns.svg | 216 +++++ images/sparse_dense.svg | 47 + images/speaker_diarisation.svg | 73 ++ images/speaker_verification.svg | 92 ++ images/spectrogram_stft.svg | 155 ++++ images/speculative_decoding.svg | 78 ++ images/ssm_dual_view.svg | 92 ++ images/stable_diffusion_architecture.svg | 106 +++ images/stack_vs_heap.svg | 43 + images/staged_multimodal_training.svg | 141 +++ images/static_vs_continuous_batching.svg | 80 ++ images/subspaces.svg | 33 + images/surprisal_entropy.svg | 56 ++ images/svd.svg | 45 + images/svm_margin.svg | 72 ++ images/swarm_consensus.svg | 50 ++ images/swin_shifted_windows.svg | 87 ++ images/tacotron2_architecture.svg | 115 +++ images/tangent_line.svg | 34 + images/taylor_approximation.svg | 46 + images/tcp_ip_layers.svg | 29 + images/td_update.svg | 40 + images/temporal_compression_strategies.svg | 110 +++ images/text_diffusion.svg | 106 +++ images/text_to_audio_pipeline.svg | 157 ++++ images/text_to_video_pipeline.svg | 100 +++ images/textcnn_architecture.svg | 116 +++ images/tokenisation_comparison.svg | 50 ++ images/training_memory_breakdown.svg | 47 + images/transformer_block.svg | 68 ++ images/transformer_paradigms.svg | 117 +++ images/tts_pipeline.svg | 102 +++ images/type_errors.svg | 43 + images/unet_architecture.svg | 89 ++ images/unified_multimodal_overview.svg | 125 +++ images/unified_vision_language_tokens.svg | 100 +++ images/variance_spread.svg | 19 + images/vector_3d.svg | 34 + images/vector_addition.svg | 25 + images/vector_equality.svg | 28 + images/venn_diagram.svg | 35 + images/vgg_architecture.svg | 89 ++ images/video_3d_vqvae.svg | 139 +++ images/visual_token_pipeline.svg | 143 ++++ images/vit_pipeline.svg | 83 ++ images/vla_architecture.svg | 63 ++ images/vlm_taxonomy.svg | 101 +++ images/voice_conversion_pipeline.svg | 114 +++ images/vqa_pipeline.svg | 82 ++ images/vqgan_training.svg | 108 +++ images/vqvae_architecture.svg | 94 ++ images/wav2vec2_pretraining.svg | 157 ++++ images/word2vec_architectures.svg | 87 ++ images/xvector_architecture.svg | 95 +++ images/yolo_grid.svg | 57 ++ images/zero_shot_classification.svg | 97 +++ images/zero_vector.svg | 17 + javascripts/mathjax.js | 19 + llms.txt | 156 ++++ mcp/package.json | 19 + mcp/src/index.ts | 334 ++++++++ mcp/tsconfig.json | 12 + mkdocs.yml | 214 +++++ 400 files changed, 49040 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 chapter 01: vectors/01. vector spaces.md create mode 100644 chapter 01: vectors/02. vector properties.md create mode 100644 chapter 01: vectors/03. norms and metrics.md create mode 100644 chapter 01: vectors/04. products.md create mode 100644 chapter 01: vectors/05. basis and duality.md create mode 100644 chapter 02: matrices/01. matrix properties.md create mode 100644 chapter 02: matrices/02. matrix types.md create mode 100644 chapter 02: matrices/03. operations.md create mode 100644 chapter 02: matrices/04. linear transformations.md create mode 100644 chapter 02: matrices/05. decompositions.md create mode 100644 chapter 03: calculus/01. differential calculus.md create mode 100644 chapter 03: calculus/02. integral calculus.md create mode 100644 chapter 03: calculus/03. multivariate calculus.md create mode 100644 chapter 03: calculus/04. function approximation.md create mode 100644 chapter 03: calculus/05. optimisation.md create mode 100644 chapter 04: statistics/01. fundamentals.md create mode 100644 chapter 04: statistics/02. measures.md create mode 100644 chapter 04: statistics/03. sampling.md create mode 100644 chapter 04: statistics/04. hypothesis testing.md create mode 100644 chapter 04: statistics/05. inference.md create mode 100644 chapter 05: probability/01. counting.md create mode 100644 chapter 05: probability/02. probability concepts.md create mode 100644 chapter 05: probability/03. distributions.md create mode 100644 chapter 05: probability/04. bayesian.md create mode 100644 chapter 05: probability/05. information theory.md create mode 100644 chapter 06: machine learning/01. classical machine learning.md create mode 100644 chapter 06: machine learning/02. gradient machine learning.md create mode 100644 chapter 06: machine learning/03. deep learning.md create mode 100644 chapter 06: machine learning/04. reinforcement learning.md create mode 100644 chapter 06: machine learning/05. distributed deep learning.md create mode 100644 chapter 07: computational linguistics/01. linguistic foundations.md create mode 100644 chapter 07: computational linguistics/02. text processing and classic NLP.md create mode 100644 chapter 07: computational linguistics/03. embeddings and sequence models.md create mode 100644 chapter 07: computational linguistics/04. transformers and language models.md create mode 100644 chapter 07: computational linguistics/05. advanced text generation.md create mode 100644 chapter 08: computer vision/01. image fundamentals.md create mode 100644 chapter 08: computer vision/02. convolutional networks.md create mode 100644 chapter 08: computer vision/03. object detection and segmentation.md create mode 100644 chapter 08: computer vision/04. vision transformers and generation.md create mode 100644 chapter 08: computer vision/05. video and 3D vision.md create mode 100644 chapter 09: audio and speech/01. digital signal processing.md create mode 100644 chapter 09: audio and speech/02. automatic speech recognition.md create mode 100644 chapter 09: audio and speech/03. text to speech and voice.md create mode 100644 chapter 09: audio and speech/04. speaker and audio analysis.md create mode 100644 chapter 09: audio and speech/05. source separation and noise.md create mode 100644 chapter 10: multimodal learning/01. multimodal representations.md create mode 100644 chapter 10: multimodal learning/02. vision language models.md create mode 100644 chapter 10: multimodal learning/03. image and video tokenisation.md create mode 100644 chapter 10: multimodal learning/04. cross-modal generation.md create mode 100644 chapter 10: multimodal learning/05. unified multimodal architectures.md create mode 100644 chapter 11: autonomous systems/01. perception.md create mode 100644 chapter 11: autonomous systems/02. robot learning.md create mode 100644 chapter 11: autonomous systems/03. vision-language-action models.md create mode 100644 chapter 11: autonomous systems/04. self-driving.md create mode 100644 chapter 11: autonomous systems/05. space and extreme robotics.md create mode 100644 chapter 12: graph neural networks/01. geometric deep learning.md create mode 100644 chapter 12: graph neural networks/02. graph theory.md create mode 100644 chapter 12: graph neural networks/03. graph neural networks.md create mode 100644 chapter 12: graph neural networks/04. graph attention networks.md create mode 100644 chapter 12: graph neural networks/05. 3d graph networks.md create mode 100644 chapter 13: computing and OS/01. discrete maths.md create mode 100644 chapter 13: computing and OS/02. computer architecture.md create mode 100644 chapter 13: computing and OS/03. operating systems.md create mode 100644 chapter 13: computing and OS/04. concurrency and parallelism.md create mode 100644 chapter 13: computing and OS/05. programming languages.md create mode 100644 chapter 14: data structures and algorithms/00. foundations.md create mode 100644 chapter 14: data structures and algorithms/01. arrays and hashing.md create mode 100644 chapter 14: data structures and algorithms/02. linked lists, stacks, and queues.md create mode 100644 chapter 14: data structures and algorithms/03. trees.md create mode 100644 chapter 14: data structures and algorithms/04. graphs.md create mode 100644 chapter 14: data structures and algorithms/05. sorting and search.md create mode 100644 chapter 15: production software engineering/01. linux and CMD.md create mode 100644 chapter 15: production software engineering/02. git and repository management.md create mode 100644 chapter 15: production software engineering/03. codebase design.md create mode 100644 chapter 15: production software engineering/04. testing and quality assurance.md create mode 100644 chapter 15: production software engineering/05. deployment and devops.md create mode 100644 chapter 16: SIMD and GPU programming/00. why C++ and how ML frameworks work.md create mode 100644 chapter 16: SIMD and GPU programming/01. hardware fundamentals.md create mode 100644 chapter 16: SIMD and GPU programming/02. ARM and NEON.md create mode 100644 chapter 16: SIMD and GPU programming/03. x86 and AVX.md create mode 100644 chapter 16: SIMD and GPU programming/04. GPU architecture and CUDA.md create mode 100644 chapter 16: SIMD and GPU programming/05. triton, TPUs and pallax.md create mode 100644 chapter 16: SIMD and GPU programming/06. RISC-V and embedded systems.md create mode 100644 chapter 16: SIMD and GPU programming/07. vulkan compute and cross-platform GPU.md create mode 100644 chapter 17: AI inference/01. quantisation.md create mode 100644 chapter 17: AI inference/02. efficient architectures.md create mode 100644 chapter 17: AI inference/03. serving and batching.md create mode 100644 chapter 17: AI inference/04. edge inference.md create mode 100644 chapter 17: AI inference/05. scaling and deployment.md create mode 100644 chapter 18: ML systems design/01. systems design fundamentals.md create mode 100644 chapter 18: ML systems design/02. cloud computing.md create mode 100644 chapter 18: ML systems design/03. large scale infrastructure.md create mode 100644 chapter 18: ML systems design/04. ML systems design.md create mode 100644 chapter 18: ML systems design/05. ML design examples.md create mode 100644 chapter 19: applied AI/01. AI for finance.md create mode 100644 chapter 19: applied AI/02. protein design.md create mode 100644 chapter 19: applied AI/03. drug discovery.md create mode 100644 chapter 19: applied AI/04. agentic systems.md create mode 100644 chapter 19: applied AI/05. healthcare.md create mode 100644 chapter 20: bleeding edge AI/01. quantum machine learning.md create mode 100644 chapter 20: bleeding edge AI/02. neuromorphic computing.md create mode 100644 chapter 20: bleeding edge AI/03. datacentres in space.md create mode 100644 chapter 20: bleeding edge AI/04. decentralised AI.md create mode 100644 chapter 20: bleeding edge AI/05. brain machine interfaces.md create mode 120000 docs/chapter 01: vectors create mode 120000 docs/chapter 02: matrices create mode 120000 docs/chapter 03: calculus create mode 120000 docs/chapter 04: statistics create mode 120000 docs/chapter 05: probability create mode 120000 docs/chapter 06: machine learning create mode 120000 docs/chapter 07: computational linguistics create mode 120000 docs/chapter 08: computer vision create mode 120000 docs/chapter 09: audio and speech create mode 120000 docs/chapter 10: multimodal learning create mode 120000 docs/chapter 11: autonomous systems create mode 120000 docs/chapter 12: graph neural networks create mode 120000 docs/chapter 13: computing and OS create mode 120000 docs/chapter 14: data structures and algorithms create mode 120000 docs/chapter 15: production software engineering create mode 120000 docs/chapter 16: SIMD and GPU programming create mode 120000 docs/chapter 17: AI inference create mode 120000 docs/chapter 18: ML systems design create mode 120000 docs/chapter 19: applied AI create mode 120000 docs/chapter 20: bleeding edge AI create mode 120000 docs/images create mode 120000 docs/index.md create mode 120000 docs/javascripts create mode 100644 images/ab_testing.svg create mode 100644 images/action_tokenisation.svg create mode 100644 images/activation_functions.svg create mode 100644 images/actor_critic.svg create mode 100644 images/additive_inverse.svg create mode 100644 images/amdahl_serial_bottleneck.svg create mode 100644 images/any_to_any_architectures.svg create mode 100644 images/area_under_curve.svg create mode 100644 images/asr_pipeline.svg create mode 100644 images/attention_alignment.svg create mode 100644 images/attention_captioning.svg create mode 100644 images/attention_sparsity_patterns.svg create mode 100644 images/audio_spectrogram_transformer.svg create mode 100644 images/audio_visual_correspondence.svg create mode 100644 images/audio_waveform.svg create mode 100644 images/autonomous_driving_stack.svg create mode 100644 images/bag_of_words.svg create mode 100644 images/basis_transform.svg create mode 100644 images/bayes_components.svg create mode 100644 images/beamforming.svg create mode 100644 images/bernoulli_binomial.svg create mode 100644 images/bert_mlm.svg create mode 100644 images/bev_fusion_pipeline.svg create mode 100644 images/bidirectional_rnn.svg create mode 100644 images/bio_tagging.svg create mode 100644 images/cache_aside_pattern.svg create mode 100644 images/cap_theorem.svg create mode 100644 images/central_limit_theorem.svg create mode 100644 images/chain_rule.svg create mode 100644 images/clip_contrastive_matrix.svg create mode 100644 images/cloud_service_layers.svg create mode 100644 images/cnn_convolution.svg create mode 100644 images/cocktail_party.svg create mode 100644 images/codebook_collapse.svg create mode 100644 images/cofactor.svg create mode 100644 images/column_space.svg create mode 100644 images/common_distributions.svg create mode 100644 images/commutativity.svg create mode 100644 images/compilation_pipeline.svg create mode 100644 images/concurrency_vs_parallelism.svg create mode 100644 images/conditional_probability.svg create mode 100644 images/confidence_interval.svg create mode 100644 images/conformer_block.svg create mode 100644 images/constituency_tree.svg create mode 100644 images/container_vs_vm.svg create mode 100644 images/continuous_vs_discrete_tokens.svg create mode 100644 images/contrastive_temperature.svg create mode 100644 images/conv_tasnet.svg create mode 100644 images/convex_nonconvex.svg create mode 100644 images/correlation_scatter.svg create mode 100644 images/counting_outfits.svg create mode 100644 images/cpu_pipeline.svg create mode 100644 images/critical_points.svg create mode 100644 images/crnn_ocr_pipeline.svg create mode 100644 images/cross_modal_generation_overview.svg create mode 100644 images/ctc_alignment.svg create mode 100644 images/dalle_autoregressive_pipeline.svg create mode 100644 images/data_model_parallelism.svg create mode 100644 images/deadlock_cycle.svg create mode 100644 images/decision_tree_split.svg create mode 100644 images/deeplab_aspp.svg create mode 100644 images/densenet_block.svg create mode 100644 images/dependency_tree.svg create mode 100644 images/depthwise_separable_conv.svg create mode 100644 images/detection_boxes.svg create mode 100644 images/determinant.svg create mode 100644 images/difference_quotient.svg create mode 100644 images/diffusion_process.svg create mode 100644 images/distribution_shift_bc.svg create mode 100644 images/distribution_types.svg create mode 100644 images/distributivity.svg create mode 100644 images/dit_architecture.svg create mode 100644 images/dot_product.svg create mode 100644 images/dual_vs_fusion_encoder.svg create mode 100644 images/earth_mars_delay.svg create mode 100644 images/efficientnet_scaling.svg create mode 100644 images/eigenvector.svg create mode 100644 images/ensemble_methods.svg create mode 100644 images/faster_rcnn.svg create mode 100644 images/fcos_detection.svg create mode 100644 images/feature_store.svg create mode 100644 images/five_geometric_domains.svg create mode 100644 images/flamingo_architecture.svg create mode 100644 images/focal_loss.svg create mode 100644 images/fpn_pyramid.svg create mode 100644 images/fraud_detection_pipeline.svg create mode 100644 images/fusion_strategies.svg create mode 100644 images/gat_attention_weights.svg create mode 100644 images/gaussian_elimination.svg create mode 100644 images/generation_evaluation_metrics.svg create mode 100644 images/gpu_cluster_topology.svg create mode 100644 images/grad_cam.svg create mode 100644 images/gradient_contour.svg create mode 100644 images/gradient_descent_landscape.svg create mode 100644 images/graph_adjacency_matrix.svg create mode 100644 images/graph_laplacian_smoothness.svg create mode 100644 images/grounding_coordinate_tokens.svg create mode 100644 images/hifi_gan_generator.svg create mode 100644 images/hmm_structure.svg create mode 100644 images/human_vectors.svg create mode 100644 images/hypothesis_test.svg create mode 100644 images/ieee754_float.svg create mode 100644 images/image_histogram.svg create mode 100644 images/image_pyramid.svg create mode 100644 images/image_tokenisation_overview.svg create mode 100644 images/image_tokeniser_comparison.svg create mode 100644 images/inception_module.svg create mode 100644 images/instructpix2pix_pipeline.svg create mode 100644 images/invariance_vs_equivariance.svg create mode 100644 images/joint_embedding_space.svg create mode 100644 images/kl_divergence.svg create mode 100644 images/kmeans_clustering.svg create mode 100644 images/lagrange_multiplier.svg create mode 100644 images/lidar_time_of_flight.svg create mode 100644 images/line_equation.svg create mode 100644 images/linear_regression_fit.svg create mode 100644 images/llava_architecture.svg create mode 100644 images/lms_adaptive_filter.svg create mode 100644 images/load_balancer.svg create mode 100644 images/logo.png create mode 100644 images/lora_decomposition.svg create mode 100644 images/lu_decomposition.svg create mode 100644 images/magnitude_direction.svg create mode 100644 images/markov_chain.svg create mode 100644 images/mask_rcnn.svg create mode 100644 images/matrix_rank.svg create mode 100644 images/matrix_trace.svg create mode 100644 images/mdp_agent_loop.svg create mode 100644 images/mel_filterbank.svg create mode 100644 images/memory_hierarchy.svg create mode 100644 images/message_passing_gnn.svg create mode 100644 images/mfcc_pipeline.svg create mode 100644 images/mha_gqa_mqa.svg create mode 100644 images/microservices_architecture.svg create mode 100644 images/ml_lifecycle.svg create mode 100644 images/mla_vs_gqa.svg create mode 100644 images/mle_vs_map.svg create mode 100644 images/moe_layer.svg create mode 100644 images/moe_routing.svg create mode 100644 images/moments_shape.svg create mode 100644 images/monte_carlo_pi.svg create mode 100644 images/morpheme_tree.svg create mode 100644 images/multimodal_action_distribution.svg create mode 100644 images/multimodal_agent_loop.svg create mode 100644 images/multimodal_overview.svg create mode 100644 images/multimodal_tokenisation_sequence.svg create mode 100644 images/multimodal_world_model.svg create mode 100644 images/naive_bayes_classify.svg create mode 100644 images/newtons_method.svg create mode 100644 images/normal_empirical.svg create mode 100644 images/normalization_types.svg create mode 100644 images/occupancy_vs_bbox.svg create mode 100644 images/ocr_free_document_understanding.svg create mode 100644 images/open_vs_closed_loop.svg create mode 100644 images/optical_flow.svg create mode 100644 images/optimizer_memory.svg create mode 100644 images/optimizer_muon.svg create mode 100644 images/optimizer_trajectories.svg create mode 100644 images/orthogonal_vectors.svg create mode 100644 images/over_smoothing_gnn.svg create mode 100644 images/p_np_complexity.svg create mode 100644 images/paged_attention.svg create mode 100644 images/parallel_vectors.svg create mode 100644 images/partial_derivative.svg create mode 100644 images/pca.svg create mode 100644 images/permutation_vs_combination.svg create mode 100644 images/pid_response.svg create mode 100644 images/pinhole_camera.svg create mode 100644 images/polynomial_buildup.svg create mode 100644 images/precision_formats_memory.svg create mode 100644 images/process_states.svg create mode 100644 images/quantisation_granularity.svg create mode 100644 images/quartiles_boxplot.svg create mode 100644 images/rag_architecture.svg create mode 100644 images/random_variable.svg create mode 100644 images/receptive_field.svg create mode 100644 images/recommendation_pipeline.svg create mode 100644 images/reflection.svg create mode 100644 images/residual_quantisation.svg create mode 100644 images/resnet_block.svg create mode 100644 images/retrieval_recall_at_k.svg create mode 100644 images/reynolds_flocking.svg create mode 100644 images/rgb_channels.svg create mode 100644 images/ring_allreduce.svg create mode 100644 images/rlhf_pipeline.svg create mode 100644 images/rnn_lstm_cell.svg create mode 100644 images/robot_arm_fk.svg create mode 100644 images/rope_rotation.svg create mode 100644 images/rotation.svg create mode 100644 images/sae_autonomy_levels.svg create mode 100644 images/sampling_aliasing.svg create mode 100644 images/sampling_methods.svg create mode 100644 images/scalar_multiplication.svg create mode 100644 images/scaling.svg create mode 100644 images/scaling_laws.svg create mode 100644 images/scaling_vlms_comparison.svg create mode 100644 images/se3_equivariance.svg create mode 100644 images/semantic_segmentation.svg create mode 100644 images/sensor_comparison.svg create mode 100644 images/seq2seq_architecture.svg create mode 100644 images/shared_autonomy_spectrum.svg create mode 100644 images/shared_backbone_multimodal.svg create mode 100644 images/shearing.svg create mode 100644 images/sigmoid_logistic.svg create mode 100644 images/sim_to_real.svg create mode 100644 images/sobel_edges.svg create mode 100644 images/social_network_matrix.svg create mode 100644 images/sparse_attention_patterns.svg create mode 100644 images/sparse_dense.svg create mode 100644 images/speaker_diarisation.svg create mode 100644 images/speaker_verification.svg create mode 100644 images/spectrogram_stft.svg create mode 100644 images/speculative_decoding.svg create mode 100644 images/ssm_dual_view.svg create mode 100644 images/stable_diffusion_architecture.svg create mode 100644 images/stack_vs_heap.svg create mode 100644 images/staged_multimodal_training.svg create mode 100644 images/static_vs_continuous_batching.svg create mode 100644 images/subspaces.svg create mode 100644 images/surprisal_entropy.svg create mode 100644 images/svd.svg create mode 100644 images/svm_margin.svg create mode 100644 images/swarm_consensus.svg create mode 100644 images/swin_shifted_windows.svg create mode 100644 images/tacotron2_architecture.svg create mode 100644 images/tangent_line.svg create mode 100644 images/taylor_approximation.svg create mode 100644 images/tcp_ip_layers.svg create mode 100644 images/td_update.svg create mode 100644 images/temporal_compression_strategies.svg create mode 100644 images/text_diffusion.svg create mode 100644 images/text_to_audio_pipeline.svg create mode 100644 images/text_to_video_pipeline.svg create mode 100644 images/textcnn_architecture.svg create mode 100644 images/tokenisation_comparison.svg create mode 100644 images/training_memory_breakdown.svg create mode 100644 images/transformer_block.svg create mode 100644 images/transformer_paradigms.svg create mode 100644 images/tts_pipeline.svg create mode 100644 images/type_errors.svg create mode 100644 images/unet_architecture.svg create mode 100644 images/unified_multimodal_overview.svg create mode 100644 images/unified_vision_language_tokens.svg create mode 100644 images/variance_spread.svg create mode 100644 images/vector_3d.svg create mode 100644 images/vector_addition.svg create mode 100644 images/vector_equality.svg create mode 100644 images/venn_diagram.svg create mode 100644 images/vgg_architecture.svg create mode 100644 images/video_3d_vqvae.svg create mode 100644 images/visual_token_pipeline.svg create mode 100644 images/vit_pipeline.svg create mode 100644 images/vla_architecture.svg create mode 100644 images/vlm_taxonomy.svg create mode 100644 images/voice_conversion_pipeline.svg create mode 100644 images/vqa_pipeline.svg create mode 100644 images/vqgan_training.svg create mode 100644 images/vqvae_architecture.svg create mode 100644 images/wav2vec2_pretraining.svg create mode 100644 images/word2vec_architectures.svg create mode 100644 images/xvector_architecture.svg create mode 100644 images/yolo_grid.svg create mode 100644 images/zero_shot_classification.svg create mode 100644 images/zero_vector.svg create mode 100644 javascripts/mathjax.js create mode 100644 llms.txt create mode 100644 mcp/package.json create mode 100644 mcp/src/index.ts create mode 100644 mcp/tsconfig.json create mode 100644 mkdocs.yml diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ceddaa3 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.cache/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..c0fa0eb --- /dev/null +++ b/README.md @@ -0,0 +1,79 @@ +# 数学、计算机科学与人工智能纲要 + +Logo + +**在线阅读**: [henryndubuaku.github.io/maths-cs-ai-compendium](https://henryndubuaku.github.io/maths-cs-ai-compendium/) + +## 概述 +大多数教科书将好的思想埋没在密集的符号之下,跳过直觉,假设你已经掌握了一半的内容,并且在人工智能等快速发展的领域很快过时。这是一本开放、非传统的教科书,从零开始涵盖数学、计算机科学和人工智能。为那些希望深入理解知识、而不仅仅是为了通过考试或面试的好奇实践者而编写。 + +## 背景 +在过去几年从事AI/ML工作的过程中,我用笔记本记录了数学、计算机科学和人工智能概念的直觉优先、结合实际、不打马虎眼的解释。2025年,几位朋友用这些笔记准备DeepMind、OpenAI、Nvidia等公司的面试。他们全部被录用,目前在工作中表现出色。而我去年也进入了Y Combinator。所以现在我把这些分享给所有人。 + +## MCP 服务器 +本仓库包含一个MCP服务器,允许任何AI助手(Claude Code、Cursor、VS Code等)将这本纲要作为知识库使用。它需要本地克隆该仓库。内置教育用途的工具和示例实现。 + +## 内容大纲 + +| # | 章节 | 简介 | 状态 | +|---|------|------|------| +| 01 | [向量](chapter%2001%3A%20vectors/01.%20vector%20spaces.md) | 空间、模长、方向、范数、度量、点积/叉积/外积、基、对偶性 | 已完成 | +| 02 | [矩阵](chapter%2002%3A%20matrices/01.%20matrix%20properties.md) | 性质、特殊类型、运算、线性变换、分解(LU、QR、SVD) | 已完成 | +| 03 | [微积分](chapter%2003%3A%20calculus/01.%20differential%20calculus.md) | 导数、积分、多元微积分、泰勒近似、优化与梯度下降 | 已完成 | +| 04 | [统计学](chapter%2004%3A%20statistics/01.%20fundamentals.md) | 描述性度量、抽样、中心极限定理、假设检验、置信区间 | 已完成 | +| 05 | [概率论](chapter%2005%3A%20probability/01.%20counting.md) | 计数、条件概率、分布、贝叶斯方法、信息论 | 已完成 | +| 06 | [机器学习](chapter%2006%3A%20machine%20learning/01.%20classical%20machine%20learning.md) | 经典机器学习、梯度方法、深度学习、强化学习、分布式训练 | 已完成 | +| 07 | [计算语言学](chapter%2007%3A%20computational%20linguistics/01.%20linguistic%20foundations.md) | 句法学、语义学、语用学、自然语言处理、语言模型、RNN、CNN、注意力机制、Transformer、文本扩散、文本OCR、MoE、SSM、现代LLM架构、自然语言处理评估 | 已完成 | +| 08 | [计算机视觉](chapter%2008%3A%20computer%20vision/01.%20image%20fundamentals.md) | 图像处理、目标检测、分割、视频处理、SLAM、CNN、视觉Transformer、扩散模型、流匹配、VR/AR | 已完成 | +| 09 | [音频与语音](chapter%2009%3A%20audio%20and%20speech/01.%20digital%20signal%20processing.md) | 数字信号处理、自动语音识别、文本转语音、语音与声学活动检测、说话人分离、源分离、主动降噪、WaveNet、Conformer | 已完成 | +| 10 | [多模态学习](chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations.md) | 融合策略、对比学习、CLIP、视觉语言模型、图像/视频分词、跨模态生成、统一架构、世界模型 | 已完成 | +| 11 | [自主系统](chapter%2011%3A%20autonomous%20systems/01.%20perception.md) | 感知、机器人学习、视觉-语言-动作模型、自动驾驶、太空机器人 | 已完成 | +| 12 | [图神经网络](chapter%2012%3A%20graph%20neural%20networks/01.%20geometric%20deep%20learning.md) | 几何深度学习、图论、GNN、图注意力机制、图Transformer、三维等变网络 | 已完成 | +| 13 | [计算与操作系统](chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths.md) | 离散数学、计算机体系结构、操作系统、并发、并行、编程语言 | 已完成 | +| 14 | [数据结构与算法](chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations.md) | 大O表示法、递归、回溯、动态规划、数组、哈希、链表、栈、树、图、排序、二分查找 | 已完成 | +| 15 | [生产级软件工程](chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD.md) | Linux、Git、代码库设计、测试、CI/CD、Docker、模型服务、MLOps、监控、使用编码代理的最佳实践 | 已完成 | +| 16 | [SIMD与GPU编程](chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work.md) | 面向机器学习的C++、框架工作原理、硬件基础、ARM NEON/I8MM/SME2、x86 AVX、GPU/CUDA、Triton、TPU、RISC-V、Vulkan、WebGPU | 已完成 | +| 17 | [AI推理](chapter%2017%3A%20AI%20inference/01.%20quantisation.md) | 量化、高效架构、服务与批处理、边缘推理、推测解码、成本优化 | 已完成 | +| 18 | [ML系统设计](chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals.md) | 系统基础、云计算、分布式系统、ML生命周期、特征存储、A/B测试、推荐/搜索/广告/欺诈设计实例 | 已完成 | +| 19 | 应用人工智能 | 金融、医疗健康、蛋白质、药物发现中的人工智能 | 待完成 | +| 20 | 前沿人工智能 | 量子机器学习、神经形态机器学习、去中心化人工智能、太空数据中心、脑机接口 | 待完成 | + +## 前言 + +新生婴儿的大脑是一个新初始化的神经网络,通过现实世界的数据和经验训练直至成年……直至永远。能够用法语流利交流并拥有完美口音,意味着接触到了优秀的法语和完美口音。同样,优秀的人工智能研究员和工程师具备出色的问题解决能力,意味着他们吸收了高质量的知识并拥有丰富的经验。 + +科瓦舍夫实验是一项长期的塞尔维亚研究,表明为期三年的高强度创造性问题解决训练可以显著提高智力,尤其是流体智力,提升10-15个IQ点。当然,天生高IQ是真实存在的,就像优质的权重初始化能带来更好的训练效果一样——先天与后天之争的实验结果也证明了这一点。 + +然而,高IQ个体的真正优势仅在于能更快地学习和识别模式。但重复使用一种模式可以使任何概念都变得绝对可学。查尔斯·达尔文被他的老师和父亲认为是一个非常普通、甚至低于平均水平的学生。他自称并不机智,感觉自己像一个"慢处理器",需要时间来吸收数据。 + +在3到10岁之间,我的学习成绩很好,自然而然地理解概念,从不做笔记或复习。11到13岁之间我有点自大,用这种方式在一个80人的班级中跌到了下半部分。14到15岁之间,我开始像普通学生一样读书,在中学最后一个学期取得了第一名。早期学校课程与自然IQ配合得很好,但现实世界的才华源于高质量的知识摄入和执行力度。 + +事实上,大多数学习成绩好的学生只是更勤奋,但学术系统是为快速学习者设计的。这本纲要提供了一个全面且相互关联的知识流,以帮助世界上那些"达尔文们"更好地学习。你只需要初等数学基础和基本的Python编程知识,其他一切都会逐步掌握——只需阅读并相信这个过程! + +## 如何更好地学习 + +大学第一学期,我同时选了17门课,成绩并不理想,于是我采用了一个技巧: + +**第一阶段:课后累积阅读** +只阅读每张幻灯片/笔记的标题/大标题,合上书,然后在脑海中可视化并写出对该概念的解释。只重读你遗漏的部分,类似于机器学习中的掩码语言建模。重读之后,最终将概念用代码实现。这样你就能对每个概念形成肌肉记忆。 + +**第二阶段:考前影子阅读** +阅读每张幻灯片/笔记的副标题,合上书,然后在脑海中可视化并写出对该概念的解释。只重读你遗漏的部分,类似于机器学习中的掩码语言建模。重读之后,最终将概念用代码实现。这样你就能对每个概念形成肌肉记忆。 + +这个方法对我不太自信的朋友们非常有效。事实上,其中一位朋友在高等工程数学(涵盖海森矩阵和优化)这门课上超过了我。她现在在一家大型石油天然气公司工作。灵魂的意愿比我们与之工作的身体更重要(罗森塔尔实验)。 + +## 关于作者 + +查看GitHub个人资料! + +## 引用 + +```bibtex +@book{ndubuaku2025compendium, + title = {Maths, CS & AI Compendium}, + author = {Henry Ndubuaku}, + year = {2026}, + publisher = {GitHub}, + url = {https://github.com/HenryNdubuaku/maths-cs-ai-compendium} +} +``` diff --git a/chapter 01: vectors/01. vector spaces.md b/chapter 01: vectors/01. vector spaces.md new file mode 100644 index 0000000..9323e0f --- /dev/null +++ b/chapter 01: vectors/01. vector spaces.md @@ -0,0 +1,147 @@ +# 向量空间 + +*向量空间构成了机器学习的数学舞台。本文涵盖向量加法、标量乘法、封闭性公理、子空间,以及为什么AI中几乎所有东西都表示为向量。* + +- 将向量空间想象成一种特定类型的舞台,数学对象生活在其中,每个对象被称为一个**向量**。 + +- 为了机器学习(ML)中的几何直觉,我们始终将向量视为欧几里得空间中的一个点,由其坐标表示。 + +- 向量 $\mathbf{a}$(数学上用粗体小写字母表示)有 $n$ 个坐标,每个坐标代表沿一个轴的位置。 + +$$\mathbf{a} = [a_1, a_2, a_3]$$ + +![向量 a = (3, 2, 4) 在三维空间中沿 x、y、z 轴绘制](../images/vector_3d.svg) + +- 向量空间中的向量遵循一套非常具体、不可打破的规则: + + - **向量加法(组合)**: + 你可以取任意两个向量并将它们组合起来创建新向量。 + 把向量想象成移动的指令。 + 如果向量 A 表示"向前走 3 步",向量 B 表示"向右走 2 步", + 将它们相加(A + B)就创建了一条新的单一指令:"向前走 3 步并向右走 2 步。" + + - **标量乘法(缩放)**: + 你可以使用一个普通数字("标量")来缩放任意向量。 + 你可以拉伸它、缩小它或反转它。 + 如果向量 A 是"向前走 3 步",将其乘以 2 就变成"向前走 6 步。" + 将其乘以 -1 则完全翻转成"向后走 3 步。" + +- 向量空间的**维度**是其包含的独立方向的数量。$\mathbb{R}^2$ 是二维的(需要 2 个坐标),而上面的 $\mathbf{a}$ 存在于 $\mathbb{R}^3$ 中。 + +- 例如,我们可以将任何对象(比如一个人)表示为一个向量,其中 $h_1$ = 身高(厘米),$h_2$ = 体重(公斤),$h_3$ = 年龄。 + +$$\mathbf{h} = [185, 75, 30]$$ + +- 我们现在已经创建了一个包含表示人的向量的向量空间。 + +- 我们可以表示多个人,并观察他们之间的远近! + +![将三个人表示为向量:Alice 和 Carol 很近,Alice 和 Bob 很远](../images/human_vectors.svg) + +- 我们可以添加更多特征,创建丰富的人体表示,在 ML 中通常称为特征向量。 + +- 你拥有的独特且有意义的特征越多,特征向量的描述性就越强,这是需要记住的一个重要因素。 + +- 超过 3 维后,向量变得非常难以直观检查,这催生了一个名为**线性代数**的数学领域。 + +- 现在,**线性代数**是研究向量、向量空间以及向量之间映射关系的学科。 + +- 我们在 AI/ML 中将几乎所有东西都表示为向量,这使得线性代数成为该领域的基石。 + +- 向量加法可以通过将一个向量放在另一个向量的尾部,然后从原点画到终点的可视化方式执行。 + +![向量加法:a(红色)加 b(蓝色)得到结果 a + b(绿色虚线)](../images/vector_addition.svg) + +- 对于两个向量 $\mathbf{a} = (a_1, a_2)$ 和 $\mathbf{b} = (b_1, b_2)$:$\mathbf{a} + \mathbf{b} = (a_1 + b_1, a_2 + b_2)$ + +- 向量也可以相减,所有加法规则同样适用。 + +- 将向量乘以标量会在相同方向上按该因子缩放向量。 + +![标量乘法:v(红色)、2v(蓝色,加倍)、-v(紫色,反向)](../images/scalar_multiplication.svg) + +- 对于标量 $c$ 和向量 $\mathbf{v} = (v_1, v_2)$:$c\mathbf{v} = (cv_1, cv_2)$ + +- **加法封闭性**:如果将向量空间中的任意两个向量相加,结果也属于同一空间:如果 $\mathbf{u} \in V$ 且 $\mathbf{v} \in V$,则 $\mathbf{u} + \mathbf{v} \in V$ + +- **标量乘法封闭性**:如果将向量空间中的任意向量乘以标量,结果也属于同一空间:如果 $\mathbf{v} \in V$ 且 $c \in F$,则 $c\mathbf{v} \in V$ + +- **加法结合律**:对于任意三个向量 $\mathbf{u}$、$\mathbf{v}$ 和 $\mathbf{w}$:$(\mathbf{u} + \mathbf{v}) + \mathbf{w} = \mathbf{u} + (\mathbf{v} + \mathbf{w})$ + +- **加法交换律**:对于任意两个向量 $\mathbf{u}$ 和 $\mathbf{v}$:$\mathbf{u} + \mathbf{v} = \mathbf{v} + \mathbf{u}$ + +![平行四边形法则:两条路径(先 u 后 v,或先 v 后 u)到达同一点](../images/commutativity.svg) + +- 通过平行四边形的两条路径都到达同一点。 + +- **(零向量)**:存在一个向量 $\mathbf{0}$,使得对于任何向量 $\mathbf{v}$:$\mathbf{v} + \mathbf{0} = \mathbf{v}$ + +![零向量:v + 0 = v](../images/zero_vector.svg) + +- **加法逆元**:对于每个向量 $\mathbf{v}$,存在一个向量 $-\mathbf{v}$,使得:$\mathbf{v} + (-\mathbf{v}) = \mathbf{0}$ + +![加法逆元:v(红色)和 -v(蓝色)抵消为零](../images/additive_inverse.svg) + +- **分配律 1**:对于任意标量 $c$ 和向量 $\mathbf{u}$、$\mathbf{v}$:$c(\mathbf{u} + \mathbf{v}) = c\mathbf{u} + c\mathbf{v}$ + +![分配律:缩放和(金色)等于缩放后的向量之和](../images/distributivity.svg) + +- 缩放和(金色)与分别缩放向量再求和的结果相同。 + +- **分配律 2**:对于任意标量 $c$、$d$ 和向量 $\mathbf{v}$:$(c + d)\mathbf{v} = c\mathbf{v} + d\mathbf{v}$ + +- **结合律**:对于任意标量 $c$、$d$ 和向量 $\mathbf{v}$:$(cd)\mathbf{v} = c(d\mathbf{v})$ + +- **单位元**:对于任何向量 $\mathbf{v}$:$1\mathbf{v} = \mathbf{v}$,其中 $1$ 是标量域中的乘法单位元。 + +- **子空间**就是大空间内部的一个较小舞台。把三维空间想象成一个房间。一张穿过房间中心的平坦纸片就是一个子空间,穿过中心的一根直导线也是子空间。 + +- 关键要求是子空间必须经过原点。如果你把那片纸移开中心,它就不再是子空间了,因为零向量不再位于其上。 + +![子空间:经过原点的直线和平面在三维空间内部](../images/subspaces.svg) + +- 向量空间的所有规则(加法、缩放、封闭性)在子空间内部仍然有效。你可以在子空间内添加或缩放向量,永远不会"掉出"到更大的空间。 + +- 经过原点的直线是一维子空间,经过原点的平面是二维子空间,而整个空间是自身的子空间。 + +- 在 ML 中,子空间自然出现。高维数据通常具有存在于低维子空间上的结构。PCA 等技术找到那个子空间,这样我们可以更高效地处理数据。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 运行代码验证分配律性质,然后修改并尝试测试其他规则! +```python +import jax.numpy as jnp + +u = jnp.array([1, 2]) +v = jnp.array([3, 0]) +c = 2 + +lhs = c * (u + v) +rhs = c*u + c*v + +print(f"LHS: {lhs}") +print(f"RHS: {rhs}") +``` + +2. 运行代码可视化不同的向量,然后修改不同坐标的值以理解每个轴如何影响位置。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 尝试修改这些向量! +a = jnp.array([3, 2, 4]) +b = jnp.array([1, 4, 2]) +c = jnp.array([4, 1, 3]) + +fig = plt.figure() +ax = fig.add_subplot(111, projection="3d") + +for vec, name, color in [(a, "a", "red"), (b, "b", "blue"), (c, "c", "green")]: + ax.quiver(0, 0, 0, *vec, color=color, arrow_length_ratio=0.1, linewidth=2, label=name) + +lim = int(jnp.abs(jnp.stack([a, b, c])).max()) + 1 +ax.set_xlim([0, lim]); ax.set_ylim([0, lim]); ax.set_zlim([0, lim]) +ax.set_xlabel("X"); ax.set_ylabel("Y"); ax.set_zlabel("Z") +ax.legend() +plt.show() +``` diff --git a/chapter 01: vectors/02. vector properties.md b/chapter 01: vectors/02. vector properties.md new file mode 100644 index 0000000..e74939f --- /dev/null +++ b/chapter 01: vectors/02. vector properties.md @@ -0,0 +1,97 @@ +# 向量性质 + +*向量性质描述了定义向量行为的几何和代数特征。本文涵盖模长、方向、单位向量、相等性、平行性、正交性和线性无关性,它们是每个 ML 特征空间的基石。* + +- 向量的**模长**(或长度)告诉你它延伸了*多远*。把它想象成箭头的长度。对于向量 $\mathbf{a} = (a_1, a_2, a_3)$,其模长为: + +$$\|\mathbf{a}\| = \sqrt{a_1^2 + a_2^2 + a_3^2}$$ + +- 这只是勾股定理推广到更高维度,测量从原点到该点的直线距离。 + +- 向量的**方向**告诉你它指向*哪里*;只需想象从原点到坐标点的一条直线即可。 + +- 当没有明确指定原点时,我们通常隐含地使用 $(0,0,\ldots,0)$ 即中心点,至少为了可视化目的如此。 + +- 位置并不重要,它总是关于位移:从原点画出的向量 $(3, 2)$ 和从另一个点画出的同样的 $(3, 2)$ 仍然是相等的。 + +![向量相等:从两个不同起点画出的相同 (3,2) 向量](../images/vector_equality.svg) + +- 两个向量可以有相同的长度但指向完全不同的方向,或者指向相同方向但长度不同。 + +![相同方向、不同模长(v 和 2v)对比相同模长、不同方向](../images/magnitude_direction.svg) + +- 两个向量**相等**当且仅当它们所有对应的分量都匹配;相同的长度,相同的方向,完全相同的箭头。 + +$$\mathbf{a} = \mathbf{b} \iff a_i = b_i \text{ 对所有 } i$$ + +- 两个向量**平行**如果一个是另一个的标量倍数。它们沿着同一条直线,要么同向,要么完全反向。 + +$$\mathbf{a} \parallel \mathbf{b} \iff \mathbf{a} = k\mathbf{b} \text{ 对于某个标量 } k \neq 0$$ + +![平行向量:a 和 b 指向相同方向,a 和 -b 指向相反方向](../images/parallel_vectors.svg) + +- 如果 $k > 0$,它们指向相同方向。如果 $k < 0$,它们指向相反方向。无论哪种情况,它们都位于经过原点的同一条直线上。 + +- 直观地说,平行向量不携带任何"新的"方向信息。一个只是另一个的拉伸或翻转版本。 + +- 两个向量**正交**(垂直)如果它们指向完全独立的方向。沿一个方向移动不会让你在另一个方向上有任何进展。 + +![正交向量:u 和 v 成直角相交](../images/orthogonal_vectors.svg) + +- 想象向北走然后向东走,这些是正交方向,无论向北走多远都不会使你向东移动。我们经常会遇到正交性。 + +- 正交性对 ML 至关重要:正交的特征携带完全独立的信息,这对于表示是最理想的。 + +- 更一般地,任意两个向量之间都有一个**夹角** $\theta$,范围从 $0°$ 到 $180°$。 + +- 这个角度捕捉了两个方向之间的全部关系:$0°$ 表示平行(相同方向),$180°$ 表示平行(相反方向),$90°$ 表示正交。介于之间的都是混合情况。 + +- ML 中的大多数向量关系都处在这个范围的某处。稍后,我们将看到精确的工具(点积、余弦相似度)来计算这个角度。 + +- 一组向量是**线性相关**的,如果其中至少一个可以通过缩放和相加从其他向量构造出来。它没有为该集合带来新的信息。 + +- 例如,如果 $\mathbf{c} = 2\mathbf{a} + 3\mathbf{b}$,那么 $\mathbf{c}$ 是冗余的,你已经通过 $\mathbf{a}$ 和 $\mathbf{b}$ 拥有了 $\mathbf{c}$ 所提供的全部信息。 + +- 平行向量总是线性相关的,因为一个只是另一个的缩放副本。任何包含零向量的集合也是线性相关的。 + +- 向量是**线性无关**的,如果其中没有一个能从其他向量构造出来。每个向量都贡献了一个真正的新方向。正交向量总是线性无关的。 + +- 在二维中,两个线性无关的向量可以到达平面上的任何点。在三维中,你需要三个。"需要多少个独立的向量"这个想法直接与维度相关。 + +- 当向量的大多数分量为零时,该向量是**稀疏**的。相反,大多数分量非零称为**稠密**。 + +$$\mathbf{s} = [0, 0, 3, 0, 0, 0, 1, 0, 0, 0]$$ + +- 稀疏性很重要,因为它影响存储和计算。稀疏向量可以通过只跟踪非零条目来更高效地存储和处理。 + +- **单位向量**是模长正好为 1 的向量。它纯粹表示方向,不包含长度信息。你可以通过除以模长将任何向量变成单位向量: + +$$\hat{\mathbf{a}} = \frac{\mathbf{a}}{\|\mathbf{a}\|}$$ + +- 这个过程称为**归一化**。它剥离了"多远",只保留"往哪走。" + +- 标准单位向量指向每个轴:$\hat{\mathbf{i}} = (1, 0, 0)$,$\hat{\mathbf{j}} = (0, 1, 0)$,$\hat{\mathbf{k}} = (0, 0, 1)$。任何向量都可以写成这些向量的组合,例如 $(3, 2, 4) = 3\hat{\mathbf{i}} + 2\hat{\mathbf{j}} + 4\hat{\mathbf{k}}$。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 计算向量的模长并验证它符合勾股定理,然后修改代码计算单位向量。 +```python +import jax.numpy as jnp + +a = jnp.array([3.0, 4.0]) + +magnitude = jnp.sqrt(jnp.sum(a ** 2)) +print(f"Magnitude of a: {magnitude}") +``` + +2. 通过测试一个向量是否是另一个的标量倍数来检查两个向量是否平行。 +```python +import jax.numpy as jnp + +a = jnp.array([2, 4, 6]) +b = jnp.array([1, 2, 3]) + +ratios = a / b +print(f"Ratios: {ratios}") +print(f"Parallel: {jnp.allclose(ratios, ratios[0])}") +``` diff --git a/chapter 01: vectors/03. norms and metrics.md b/chapter 01: vectors/03. norms and metrics.md new file mode 100644 index 0000000..4f4abde --- /dev/null +++ b/chapter 01: vectors/03. norms and metrics.md @@ -0,0 +1,90 @@ +# 度量与范数 + +*范数衡量单个向量的大小;度量衡量两个向量之间的距离。本文涵盖 L1、L2 和 L-无穷范数、欧几里得距离和余弦距离,以及为什么为 kNN、聚类和 ML 中的检索选择合适的距离函数至关重要。* + +- 我们知道向量有模长和方向。但我们如何实际衡量单个向量"有多大",或者两个向量"有多远"?这就是**范数**和**度量**发挥作用的地方。 + +- 对标量而言,我们知道 10 > 5,因为它们的值对它们进行了量化,但是我们如何量化一个向量?它的**范数**衡量单个向量的大小。 + +- 最熟悉的范数是**欧几里得范数**(L2),它就是我们已知的模长公式: + +$$\|\mathbf{v}\|_2 = \sqrt{v_1^2 + v_2^2 + \cdots + v_n^2}$$ + +- 但还有其他衡量大小的方法。想象你在一个街道呈网格状的城市中。你不能斜穿建筑物,所以你旅程的"长度"是沿着每条街道行走的总街区数。这就是**曼哈顿范数**(L1): + +$$\|\mathbf{v}\|_1 = |v_1| + |v_2| + \cdots + |v_n|$$ + +- 或者你可能只关心单个最大的分量,忽略其余部分。这就是**最大范数**(L-无穷): + +$$\|\mathbf{v}\|_\infty = \max(|v_1|, |v_2|, \ldots, |v_n|)$$ + +- 这三个都是**一般 Lp 范数**的特例: + +$$\|\mathbf{v}\|_p = (|v_1|^p + |v_2|^p + \cdots + |v_n|^p)^{1/p}$$ + +- 设置 $p = 2$ 得到欧几里得,$p = 1$ 得到曼哈顿,而当 $p \to \infty$ 时得到最大范数。随着 $p$ 增大,最大分量贡献越来越大,直到最终只有它重要。 + +- 每个范数必须遵守三条规则: + + - **非负性**:$\|\mathbf{v}\| \geq 0$,且 $\|\mathbf{v}\| = 0$ 仅当 $\mathbf{v} = \mathbf{0}$。大小从不为负,只有零向量的大小为零。 + + - **缩放性**:$\|c\mathbf{v}\| = |c| \cdot \|\mathbf{v}\|$。将向量加倍,其大小也加倍。 + + - **三角不等式**:$\|\mathbf{u} + \mathbf{v}\| \leq \|\mathbf{u}\| + \|\mathbf{v}\|$。捷径永远不会比绕远路更长。 + +- 现在,**度量**衡量*两个*向量之间的距离。把它想象成问:"这两个点相距多远?" + +- 获得度量的最简单方法是使用差值的范数:$d(\mathbf{u}, \mathbf{v}) = \|\mathbf{u} - \mathbf{v}\|$。减去两个向量,然后测量剩余部分的大小。 + +- 使用欧几里得范数,我们得到熟悉的**欧几里得距离**: + +$$d(\mathbf{u}, \mathbf{v}) = \sqrt{(u_1 - v_1)^2 + (u_2 - v_2)^2 + \cdots + (u_n - v_n)^2}$$ + +- 使用曼哈顿范数得到**曼哈顿距离**,沿着每个轴的总差异,就像计算两个位置之间的城市街区数。 + +- 每个度量必须遵守四条规则: + + - **非负性**:$d(\mathbf{u}, \mathbf{v}) \geq 0$。距离从不为负。 + + - **同一性**:$d(\mathbf{u}, \mathbf{v}) = 0$ 当且仅当 $\mathbf{u} = \mathbf{v}$。零距离意味着同一点。 + + - **对称性**:$d(\mathbf{u}, \mathbf{v}) = d(\mathbf{v}, \mathbf{u})$。从 A 到 B 的距离与从 B 到 A 的距离相同。 + + - **三角不等式**:$d(\mathbf{u}, \mathbf{w}) \leq d(\mathbf{u}, \mathbf{v}) + d(\mathbf{v}, \mathbf{w})$。直接走永远不会比绕路更长。 + +- 那么两者之间的关系是什么?范数衡量一个向量,度量衡量两个向量之间的差距。每个范数自然地创建一个度量(通过测量差值),但并非每个度量都来自范数。 + +- 例如,**汉明距离**计算两个向量不同的位置数量。它是一个有效的度量,但并非来自任何范数。 + +- 在 ML 中,选择合适的范数或度量很重要。 + +- L2 距离在求和前对每个差值平方,因此单个大的差值会主导结果。 + +- L1 距离对绝对差值求和,平等对待每个差值。与 L2 相比,单个大的差值影响较小。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 计算同一向量的 L1 和 L2 范数。尝试更改值,注意哪个范数对大的分量最敏感,哪个对许多小分量最敏感。然后尝试计算 p 值递增(例如 1、2、5、10、50、100)时的 Lp 范数,观察它如何收敛到 L-无穷值。 +```python +import jax.numpy as jnp + +v = jnp.array([3.0, -4.0, 1.0]) + +l1 = jnp.sum(jnp.abs(v)) +l2 = jnp.sqrt(jnp.sum(v ** 2)) + +print(f"L1: {l1}, L2: {l2:.2f}") +``` + +2. 计算两个向量之间的欧几里得距离和曼哈顿距离。尝试让向量彼此靠近或远离,观察每种距离如何不同地响应。 +```python +import jax.numpy as jnp + +u = jnp.array([1.0, 2.0, 3.0]) +v = jnp.array([4.0, 0.0, 1.0]) + +euclidean = jnp.sqrt(jnp.sum((u - v) ** 2)) +manhattan = jnp.sum(jnp.abs(u - v)) + +print(f"Euclidean: {euclidean:.2f}, Manhattan: {manhattan}") +``` diff --git a/chapter 01: vectors/04. products.md b/chapter 01: vectors/04. products.md new file mode 100644 index 0000000..85c8136 --- /dev/null +++ b/chapter 01: vectors/04. products.md @@ -0,0 +1,113 @@ +# 向量积 + +*向量积是衡量相似性和计算投影的基本运算。本文涵盖内积、点积、余弦相似度、叉积和外积,这些运算支撑了 AI 中的注意力机制、嵌入和几何推理。* + +- 我们已经看到如何相加和缩放向量。但是我们可以*相乘*两个向量吗?事实证明不止一种方法,每种方法回答不同的问题。 + +- **内积**是一个广义概念:一个接受两个向量并产生一个标量的函数。它是"相乘"向量的抽象蓝图。 + +- 任何内积必须满足三条规则: + + - **正定性**:$\langle \mathbf{v}, \mathbf{v} \rangle \geq 0$,且仅对零向量等于零。向量与自身相乘总是给出非负结果。 + + - **对称性**:$\langle \mathbf{u}, \mathbf{v} \rangle = \langle \mathbf{v}, \mathbf{u} \rangle$。顺序无关紧要。 + + - **线性性**:$\langle a\mathbf{u} + b\mathbf{v}, \mathbf{w} \rangle = a\langle \mathbf{u}, \mathbf{w} \rangle + b\langle \mathbf{v}, \mathbf{w} \rangle$。它对加法和缩放具有分配性。 + +- **点积**是最常见的内积。它是你几乎到处都会用到的具体版本。对于两个向量 $\mathbf{a} = (a_1, a_2, \ldots, a_n)$ 和 $\mathbf{b} = (b_1, b_2, \ldots, b_n)$: + +$$\mathbf{a} \cdot \mathbf{b} = a_1 b_1 + a_2 b_2 + \cdots + a_n b_n$$ + +- 将匹配的分量相乘,然后全部加起来。这就是全部。 + +- 但这个数字*意味着*什么?点积有一个优美的几何解释: + +$$\mathbf{a} \cdot \mathbf{b} = \|\mathbf{a}\| \, \|\mathbf{b}\| \cos(\theta)$$ + +![点积:向量 a 投影到 b 上,显示角度 θ 和投影](../images/dot_product.svg) + +- 这将点积直接与两个向量之间的角度 $\theta$ 联系起来。结果告诉你两个向量在方向上"一致"的程度。 + +- 如果它们指向相同方向($\theta = 0°$),$\cos(\theta) = 1$ 且点积最大。 + +- 如果它们正交($\theta = 90°$),$\cos(\theta) = 0$ 且点积恰好为零。这给出了正交性的精确检验。 + +- 如果它们指向相反方向($\theta = 180°$),$\cos(\theta) = -1$ 且点积为负。 + +- 向量与自身的点积给出其模长的平方:$\mathbf{a} \cdot \mathbf{a} = \|\mathbf{a}\|^2$。 + +- 点积还给出了**投影**,即一个向量在另一个向量上投下的影子。$\mathbf{a}$ 在 $\mathbf{b}$ 上的投影为: + +$$\text{proj}_{\mathbf{b}}(\mathbf{a}) = \frac{\mathbf{a} \cdot \mathbf{b}}{\|\mathbf{b}\|^2} \, \mathbf{b}$$ + +- 想象一束光线直射到 $\mathbf{b}$ 上。$\mathbf{a}$ 在那条线上的影子就是投影。它告诉你 $\mathbf{a}$ 有多少位于 $\mathbf{b}$ 的方向上。 + +- **余弦相似度**通过除以两个模长来归一化点积: + +$$\cos(\theta) = \frac{\mathbf{a} \cdot \mathbf{b}}{\|\mathbf{a}\| \, \|\mathbf{b}\|}$$ + +- 这会给出一个介于 $-1$ 和 $1$ 之间的值,衡量方向对齐程度,忽略向量的长度。它广泛应用于 ML 中来比较文档、嵌入和用户偏好等事物。 + +- 现在,点积接受两个向量并返回标量。**叉积**则相反,它接受两个向量并返回一个*新向量*。 + +- 叉积 $\mathbf{a} \times \mathbf{b}$ 产生一个同时垂直于 $\mathbf{a}$ 和 $\mathbf{b}$ 的向量: + +$$\mathbf{a} \times \mathbf{b} = (a_2 b_3 - a_3 b_2, \; a_3 b_1 - a_1 b_3, \; a_1 b_2 - a_2 b_1)$$ + +- 叉积只适用于三维。点积适用于任意维度,而叉积是三维空间特有的。 + +- 其模长等于由这两个向量形成的平行四边形的面积: + +$$\|\mathbf{a} \times \mathbf{b}\| = \|\mathbf{a}\| \, \|\mathbf{b}\| \sin(\theta)$$ + +- 注意模式:点积使用 $\cos(\theta)$,叉积使用 $\sin(\theta)$。点积衡量两个向量对齐的程度,叉积衡量它们在方向上*差异*的程度。 + +- 结果的方向遵循**右手定则**:将右手的手指从 $\mathbf{a}$ 弯向 $\mathbf{b}$,拇指指向 $\mathbf{a} \times \mathbf{b}$ 的方向。 + +- 与点积不同,叉积**不可交换**:$\mathbf{a} \times \mathbf{b} = -(\mathbf{b} \times \mathbf{a})$。交换顺序会翻转方向。 + +- 如果两个向量平行,它们的叉积是零向量(因为 $\sin(0°) = 0$)。没有面积,没有垂直方向。 + +- 当你使用两个乘积结合三个向量会发生什么?这就得到了**三重积**。 + +- xxxxxxxxxx9 1import jax.numpy as jnp2​3u = jnp.array([1.0, 2.0, 3.0])4v = jnp.array([4.0, 0.0, 1.0])5​6euclidean = jnp.sqrt(jnp.sum((u - v) ** 2))7manhattan = jnp.sum(jnp.abs(u - v))8​9print(f"Euclidean: {euclidean:.2f}, Manhattan: {manhattan}")python + +- 如果标量三重积为零,则这三个向量**共面**,它们都位于同一个平坦平面上,不形成体积。 + +- 顺序可以循环而不改变结果:$\mathbf{a} \cdot (\mathbf{b} \times \mathbf{c}) = \mathbf{b} \cdot (\mathbf{c} \times \mathbf{a}) = \mathbf{c} \cdot (\mathbf{a} \times \mathbf{b})$。 + +- **向量三重积** $\mathbf{a} \times (\mathbf{b} \times \mathbf{c})$ 应用两次叉积并返回一个向量。它可以使用恒等式简洁展开: + +$$\mathbf{a} \times (\mathbf{b} \times \mathbf{c}) = (\mathbf{a} \cdot \mathbf{c})\mathbf{b} - (\mathbf{a} \cdot \mathbf{b})\mathbf{c}$$ + +- 结果总是位于由 $\mathbf{b}$ 和 $\mathbf{c}$ 张成的平面内。注意叉积**不满足结合律**:$\mathbf{a} \times (\mathbf{b} \times \mathbf{c}) \neq (\mathbf{a} \times \mathbf{b}) \times \mathbf{c}$。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 计算两个向量的点积并用它求出它们之间的角度。尝试让它们正交、平行或反向,观察角度如何变化。 +```python +import jax.numpy as jnp + +a = jnp.array([1.0, 2.0, 3.0]) +b = jnp.array([4.0, -1.0, 2.0]) + +dot = jnp.dot(a, b) +angle = jnp.arccos(dot / (jnp.linalg.norm(a) * jnp.linalg.norm(b))) + +print(f"Dot product: {dot}") +print(f"Angle: {jnp.degrees(angle):.1f}°") +``` + +2. 计算两个三维向量的叉积,并通过检查结果与每个原始向量的点积为零来验证结果垂直于两者。 +```python +import jax.numpy as jnp + +a = jnp.array([1.0, 0.0, 0.0]) +b = jnp.array([0.0, 1.0, 0.0]) + +cross = jnp.cross(a, b) + +print(f"a x b = {cross}") +print(f"Perpendicular to a: {jnp.dot(cross, a) == 0}") +print(f"Perpendicular to b: {jnp.dot(cross, b) == 0}") +``` diff --git a/chapter 01: vectors/05. basis and duality.md b/chapter 01: vectors/05. basis and duality.md new file mode 100644 index 0000000..108ad65 --- /dev/null +++ b/chapter 01: vectors/05. basis and duality.md @@ -0,0 +1,89 @@ +# 基与对偶性 + +*基定义了向量空间的坐标系,而对偶性揭示了线性函数如何作用于向量。本文涵盖线性无关性、生成集、基变换、对偶空间和余向量,这些概念支撑了 ML 中的 PCA、特征变换和注意力查询。* + +- 我们已经看到向量存在于具有一定维度数的空间中。但什么定义了这些维度?这就是**基向量**发挥作用的地方。 + +- **基**是一组向量,可以通过缩放和相加(线性组合)构建空间中的每个其他向量,且没有冗余。它们是空间的构建块。 + +- 基必须满足两个条件: + + - **线性无关**:没有基向量能从其他基向量构造出来。每个都贡献了一个真正的新方向。 + + - **生成性**:空间中的每个向量都可以表示为基向量的组合。没有任何遗漏。 + +- 基中的向量数量等于空间的**维度**。在 $\mathbb{R}^2$ 中你需要 2 个,在 $\mathbb{R}^3$ 中你需要 3 个,依此类推。 + +- 最自然的基是**标准基**,即沿每个轴的单位向量: + + - 在 $\mathbb{R}^2$ 中:$\hat{\mathbf{i}} = (1, 0)$ 和 $\hat{\mathbf{j}} = (0, 1)$ + - 在 $\mathbb{R}^3$ 中:$\hat{\mathbf{i}} = (1, 0, 0)$,$\hat{\mathbf{j}} = (0, 1, 0)$,$\hat{\mathbf{k}} = (0, 0, 1)$ + +- 任何向量都是这些基向量的加权和。向量 $(3, 2)$ 实际上是 $3\hat{\mathbf{i}} + 2\hat{\mathbf{j}}$。权重(3 和 2)是该基下向量的**坐标**。 + +- 但标准基并不是唯一有效的基。在 $\mathbb{R}^2$ 中,向量 $(1, 1)$ 和 $(-1, 1)$ 也构成基。它们线性无关,并且可以到达平面上的任何点。同一个向量在这个新基下只是有不同的坐标。 + +- **基变换**使用不同的构建块重新表达同一个向量。向量没有移动,我们只是从不同的角度描述它。 + +- 这是通过乘以一个**基变换矩阵** $P$ 来完成的,其列是用旧坐标表示的新基向量。要变回去,乘以 $P^{-1}$。 + +- 在 ML 中,基变换经常出现。例如,PCA 找到一个新基(主成分),在该基下数据更容易理解,坐标轴与最大变化方向对齐。 + +- 现在,这里隐藏着一个更深层的想法。当我们写 $\mathbf{v} = (3, 2)$ 时,坐标 3 和 2 实际上是沿着每个基方向"测量" $\mathbf{v}$ 的结果。第一个坐标问"$\hat{\mathbf{i}}$ 在 $\mathbf{v}$ 中有多少?",第二个问"$\hat{\mathbf{j}}$ 呢?" + +- 这些测量中的每一个都是一个**线性泛函**,一个接受向量并返回单个标量的函数。所有这样的线性泛函的集合构成了**对偶空间** $V^\ast$。 + +- 这样想:向量是被测对象,线性泛函是测量它们的*标尺*。对偶空间是所有可能的标尺的集合。 + +- 对于每个基 $\{\mathbf{e}_1, \mathbf{e}_2, \ldots, \mathbf{e}_n\}$,存在一个对应的**对偶基** $\{\mathbf{e}_1^\ast, \mathbf{e}_2^\ast, \ldots, \mathbf{e}_n^\ast\}$。每个对偶基向量恰好提取一个坐标: + +```math +\mathbf{e}_i^\ast(\mathbf{e}_j) = \delta_{ij} = \begin{cases} 1 & \text{if } i = j \\ 0 & \text{if } i \neq j \end{cases} +``` + +- $\mathbf{e}_1^\ast$ 作用于 $\mathbf{e}_1$ 时返回 1,对其它所有向量返回 0。它完美地隔离了第一个坐标。 + +- 点积连接了这两个世界。当你计算 $\mathbf{u} \cdot \mathbf{v}$ 时,你可以把其中一个向量看作"标尺"在测量另一个向量。点积 $\mathbf{u} \cdot \mathbf{v}$ 等同于将由 $\mathbf{u}$ 定义的线性泛函应用于向量 $\mathbf{v}$。 + +- 这意味着每个向量都隐含地定义了一个线性泛函,并且每个线性泛函都可以用一个向量表示。在有限维空间中,对偶空间本质上是原始空间的镜像。 + +- 对偶性现在可能看起来很抽象,但它支撑着许多实际的概念:坐标是对偶基的评估,点积是对偶配对,而神经网络中的注意力等变换通过让一组向量"查询"另一组向量来运作,这正是对偶性在起作用。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 在两个不同的基中表达一个向量,并验证它们代表同一个点。尝试创建你自己的基,观察向量得到什么坐标。 +```python +import jax.numpy as jnp + +v = jnp.array([3.0, 2.0]) + +# 标准基:坐标就是分量本身 +print(f"Standard basis coords: {v}") + +# 新基:(1,1) 和 (-1,1) +P = jnp.array([[1.0, -1.0], + [1.0, 1.0]]) +new_coords = jnp.linalg.solve(P, v) +print(f"New basis coords: {new_coords}") + +# 验证:从新坐标重建 +reconstructed = new_coords[0] * P[:, 0] + new_coords[1] * P[:, 1] +print(f"Reconstructed: {reconstructed}") +``` + +2. 验证对偶基性质:每个对偶基向量恰好提取一个坐标,对其他向量返回零。 +```python +import jax.numpy as jnp + +# R3 中的标准基 +e1 = jnp.array([1.0, 0.0, 0.0]) +e2 = jnp.array([0.0, 1.0, 0.0]) +e3 = jnp.array([0.0, 0.0, 1.0]) + +v = jnp.array([5.0, 3.0, 7.0]) + +# 每个点积提取一个坐标 +print(f"e1 · v = {jnp.dot(e1, v)}") +print(f"e2 · v = {jnp.dot(e2, v)}") +print(f"e3 · v = {jnp.dot(e3, v)}") +``` diff --git a/chapter 02: matrices/01. matrix properties.md b/chapter 02: matrices/01. matrix properties.md new file mode 100644 index 0000000..dc986eb --- /dev/null +++ b/chapter 02: matrices/01. matrix properties.md @@ -0,0 +1,166 @@ +# 矩阵性质 + +*矩阵是存储数据集、编码变换和定义每个神经网络层的数据结构。本文涵盖矩阵维度、元素、转置、迹、行列式、逆、秩和零空间,这些是贯穿线性代数和 ML 的基础性质。* + +- 核心而言,**矩阵**是按行列排列的数字矩形网格。如果向量是数字的单个列表,那么矩阵就是数字的一张表格。 + +```math +A = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} +``` + +- 你也可以将矩阵视为向量的堆叠。 + +- 如果一个人由向量 $[\text{age}, \text{height}, \text{weight}]$ 描述,那么三个人就形成一个矩阵,其中每行是一个人: + +```math +\begin{bmatrix} 25 & 170 & 65 \\ 30 & 180 & 80 \\ 22 & 160 & 55 \end{bmatrix} +``` + +- 这个矩阵有 3 行和 3 列,所以我们称它为 $3 \times 3$ 矩阵。 + +- 网格中的每个数字称为一个**元素**或**条目**,由其行列标识:$A_{ij}$ 是第 $i$ 行第 $j$ 列的元素。 + +- 矩阵的**转置**沿其对角线翻转,将行变为列,列变为行。如果 $A$ 是 $m \times n$,那么 $A^T$ 是 $n \times m$。 + +```math +A = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \quad \Rightarrow \quad A^T = \begin{bmatrix} 1 & 4 \\ 2 & 5 \\ 3 & 6 \end{bmatrix} +``` + +- 矩阵乘以其转置总是得到一个方阵:$AA^T$ 是 $m \times m$,$A^TA$ 是 $n \times n$。 + +- 方阵的**迹**是其对角线元素之和:$\text{tr}(A) = A_{11} + A_{22} + \cdots + A_{nn}$。迹等于特征值之和(我们稍后会看到)。 + +![迹:对角线元素之和](../images/matrix_trace.svg) + +- 对于上面的矩阵,$\text{tr}(A) = 1 + 4 + 9 = 14$。只有高亮的对角线部分重要。 + +- 如果两个矩阵在不同基下表示相同的线性变换,它们的迹相同。迹是"与基无关的。" + +- 矩阵的**秩**是线性无关的行(或等价地,列)的数量。它告诉你矩阵携带了多少"有用信息。" + +- 例如,以下矩阵的秩为 2,因为两行之间互不为倍数: + +```math +\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} +``` + +但以下矩阵的秩为 1,因为第二行只是第一行的两倍,所以它没有增加新信息: + +```math +\begin{bmatrix} 1 & 2 \\ 2 & 4 \end{bmatrix} +``` + +- 一个 $5 \times 3$ 矩阵的秩最多为 3。如果某些行只是其他行的缩放或组合版本,秩就会下降。具有最大可能秩的矩阵称为**满秩**。 + +![秩:独立行张成整个空间 vs. 相关行只张成一个子空间](../images/matrix_rank.svg) + +- 方阵可逆(有逆矩阵)当且仅当它是满秩的。 + +- 秩通过**秩-零化度定理**与**零空间**(矩阵映射到零的向量的集合)相连:$\text{rank}(A) + \text{nullity}(A) = \text{列数 of } A$。矩阵保留的(秩)加上它破坏的(零化度)等于总维度。 + +- 矩阵的**列空间**是当你将矩阵乘以任意向量时所有可能输出的集合。它由矩阵的列张成。如果矩阵有 3 列但只有 2 列独立,列空间是一个二维平面,而不是整个三维空间。 + +![列空间:独立列张成一个平面,相关列只张成一条线](../images/column_space.svg) + +- **行空间**是同样的概念,但从行的角度来看。秩等于列空间和行空间的维度,所以它们总是一致的。 + +- 一起来看,列空间告诉你"这个矩阵能产生什么输出?"零空间告诉你"什么输入被映射到零?"这两个空间完整描述了矩阵的功能。 + +- 方阵的**行列式**是一个标量,捕捉矩阵如何缩放空间。想象一个 $2 \times 2$ 矩阵将一个单位正方形变换成一个平行四边形。行列式就是那个平行四边形的面积(带有符号)。 + +```math +\det\begin{bmatrix} a & b \\ c & d \end{bmatrix} = ad - bc +``` + +![行列式:线性变换的面积缩放因子](../images/determinant.svg) + +- 例如: + +```math +\det\begin{bmatrix} 2 & 1 \\ 0 & 3 \end{bmatrix} = 2 \cdot 3 - 1 \cdot 0 = 6 +``` + +这个变换将单位正方形拉伸成一个面积为 6 的平行四边形。 + +- 如果行列式为正,变换保持定向(事物不会被"翻转")。如果为负,它翻转定向(像镜面反射)。如果为零,矩阵将空间压缩到更低维度,将平行四边形坍缩成一条线或一个点。 + +- 行列式为零的矩阵称为**奇异矩阵**。它没有逆矩阵且已永久丢失信息。 + +- 对于大于 $2 \times 2$ 的矩阵,行列式使用**余子式**和**代数余子式**计算。**余子式** $M_{ij}$ 是通过删除第 $i$ 行和第 $j$ 列得到的较小矩阵的行列式。 + +![余子式:删除一行一列得到更小的矩阵](../images/cofactor.svg) + +- **代数余子式** $C_{ij} = (-1)^{i+j} M_{ij}$ 为每个余子式附加一个符号(像棋盘一样交替:$+, -, +, \ldots$)。整个矩阵的行列式然后沿着任意行或列求和:$\det(A) = \sum_j A_{1j} \cdot C_{1j}$。这称为**代数余子式展开**。 + +- 方阵 $A$ 的**逆**,记作 $A^{-1}$,是撤销 $A$ 的矩阵:$AA^{-1} = A^{-1}A = I$(单位矩阵)。只有非奇异矩阵才有逆。 + +- 对于 $2 \times 2$ 矩阵,逆有一个直接公式: + +```math +\begin{bmatrix} a & b \\ c & d \end{bmatrix}^{-1} = \frac{1}{ad - bc}\begin{bmatrix} d & -b \\ -c & a \end{bmatrix} +``` + +注意分母中的行列式,这就是为什么奇异矩阵(行列式为零)没有逆。 + +- **条件数**衡量矩阵对其输入微小变化的敏感程度。它定义为 $\kappa(A) = \|A\| \cdot \|A^{-1}\|$。 + +- 接近 1 的条件数意味着矩阵是**良态的**:微小的输入变化产生微小的输出变化。大的条件数意味着它是**病态的**:微小的误差被极大放大。正交矩阵和单位矩阵的条件数为 1,而奇异矩阵的条件数为无穷大。 + +- 例如,以下矩阵的条件数为 $10^8$。一个方向被正常缩放,而另一个几乎被压缩为零,所以沿该方向的小扰动会被严重扭曲: + +```math +\begin{bmatrix} 1 & 0 \\ 0 & 10^{-8} \end{bmatrix} +``` + +- 就像向量有范数(长度)一样,矩阵也有衡量其"大小"的**范数**。最常见的是**弗罗贝尼乌斯范数**,它将矩阵视为一个长向量并计算其长度: + +```math +\|A\|_F = \sqrt{\sum_{i}\sum_{j} A_{ij}^2} +``` + +- 例如: + +```math +\left\|\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}\right\|_F = \sqrt{1 + 4 + 9 + 16} = \sqrt{30} \approx 5.48 +``` + +- **谱范数** $\|A\|_2$ 是 $A$ 的最大奇异值。它衡量矩阵可以拉伸任何单位向量的最大程度。在 ML 中,矩阵范数用于权重正则化(惩罚大权重)和监控训练稳定性。 + +- 对称矩阵 $A$ 是**正定的**,如果对每个非零向量 $\mathbf{x}$:$\mathbf{x}^T A \mathbf{x} > 0$。这个二次型总是产生正数。 + +- 例如,以下矩阵是正定的: + +```math +A = \begin{bmatrix} 2 & 1 \\ 1 & 3 \end{bmatrix} +``` + +取任意向量,比如 $\mathbf{x} = [1, -1]^T$:$\mathbf{x}^T A \mathbf{x} = 2 - 1 - 1 + 3 = 3 > 0$。无论你尝试哪个非零 $\mathbf{x}$,你总是得到正的结果。 + +- 正定矩阵很重要,因为它们保证优化问题有唯一的最小值。 + +- 如果条件放宽到 $\mathbf{x}^T A \mathbf{x} \geq 0$(允许为零),矩阵是**半正定**(PSD)。PSD 矩阵经常出现:协方差矩阵、SVM 中的核矩阵以及局部最小值处的 Hessian 矩阵都是 PSD。区别在于 PSD 允许某些方向是"平坦的"(零曲率),而不是严格向上弯曲。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 计算矩阵的迹、秩和行列式。尝试使一行成为另一行的倍数,观察秩和行列式如何变化。 +```python +import jax.numpy as jnp + +A = jnp.array([[1.0, 2.0], + [3.0, 4.0]]) + +print(f"Trace: {jnp.trace(A)}") +print(f"Rank: {jnp.linalg.matrix_rank(A)}") +print(f"Determinant: {jnp.linalg.det(A):.2f}") +``` + +2. 计算矩阵的逆,将其乘以原矩阵,验证得到单位矩阵。然后尝试奇异矩阵并观察会发生什么。 +```python +import jax.numpy as jnp + +A = jnp.array([[1.0, 2.0], + [3.0, 4.0]]) + +A_inv = jnp.linalg.inv(A) +print(f"A * A_inv:\n{A @ A_inv}") +``` diff --git a/chapter 02: matrices/02. matrix types.md b/chapter 02: matrices/02. matrix types.md new file mode 100644 index 0000000..65efea9 --- /dev/null +++ b/chapter 02: matrices/02. matrix types.md @@ -0,0 +1,139 @@ +# 矩阵类型 + +*特殊的矩阵结构能够解锁计算捷径和数学保证。本文涵盖单位矩阵、对角矩阵、对称矩阵、三角矩阵、正交矩阵、正定矩阵、稀疏矩阵和随机矩阵,这些类型出现在协方差估计、图算法、正则化和马尔可夫链中。* + +- 并非所有矩阵都一样。不同的结构赋予矩阵特殊的性质,使它们计算更快、更易于推理,或两者兼得。以下是你最常遇到的类型。 + +- **方阵**的行数和列数相同($n \times n$)。大多数有趣的性质(行列式、特征值、逆)只适用于方阵。 + +- **单位矩阵** $I$ 是一个对角线为 1、其余为 0 的方阵。它是"什么都不做"的变换:$AI = IA = A$ 对任何兼容的矩阵 $A$。 + +```math +I = \begin{bmatrix} 1 & 0 & 0 \\ 0 & 1 & 0 \\ 0 & 0 & 1 \end{bmatrix} +``` + +- **零矩阵** $O$ 的所有元素都为零。它将每个向量映射到零向量,破坏所有信息。 + +- **对角矩阵**除主对角线外全为零。将向量乘以对角矩阵只是独立地缩放每个分量,非常高效。 + +```math +D = \begin{bmatrix} 3 & 0 \\ 0 & 7 \end{bmatrix} +``` + +- **对称矩阵**等于其转置:$A = A^T$,意味着 $A_{ij} = A_{ji}$。对称矩阵有一个特殊性质:它们的特征向量总是相互垂直。协方差矩阵总是对称的。 + +```math +S = \begin{bmatrix} 3 & -1 \\ -1 & 6 \end{bmatrix} +``` + +- **三角矩阵**在对角线的一侧全为零。**下三角**在上方全为零,**上三角**在下方全为零。它们对于通过前向或回代高效求解方程组至关重要。 + +```math +L = \begin{bmatrix} 2 & 0 & 0 \\ 1 & 3 & 0 \\ -1 & 2 & 4 \end{bmatrix} \qquad U = \begin{bmatrix} 5 & -1 & 2 \\ 0 & 1 & 3 \\ 0 & 0 & -2 \end{bmatrix} +``` + +- 三角矩阵的行列式就是其对角线元素的乘积。 + +- **正交矩阵**具有转置等于逆的性质:$Q^TQ = QQ^T = I$。 + +- 这意味着你只需转置就能"撤销"变换,计算成本很低。其列是标准正交的(单位长度且相互垂直)。 + +- **稀疏矩阵**的大多数元素为零,而**稠密矩阵**的大多数元素非零。 + +![稀疏 vs 稠密:点表示非零元素](../images/sparse_dense.svg) + +- 在实践中,许多现实世界的矩阵是极其稀疏的。 + +- 一个拥有百万用户的社交网络可以表示为一个 $10^6 \times 10^6$ 的矩阵,但每个人只连接到少数其他人,所以几乎所有元素都是零。 + +![一个小型社交网络及其邻接矩阵:大多数元素为零](../images/social_network_matrix.svg) + +- **置换矩阵**是通过重排单位矩阵的行得到的。乘以它会打乱向量的元素。每行每列恰好有一个 1,其余为 0。 + +- 例如,下面的矩阵将元素 3 移到位置 1,元素 1 移到位置 2,元素 2 移到位置 3: + +```math +P = \begin{bmatrix} 0 & 0 & 1 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{bmatrix} +``` + +- **托普利茨矩阵**沿每条对角线(左上到右下)具有相同的值。注意每条对角线是如何恒定的: + +```math +T = \begin{bmatrix} a & b & c \\ d & a & b \\ e & d & a \end{bmatrix} +``` + +- 这种结构出现在信号处理和卷积中,因为将固定滤波器滑过信号等价于乘以托普利茨矩阵。 + +- **循环矩阵**是一种特殊的托普利茨矩阵,其中每一行是上一行的循环移位。当一行到达末尾时,它会绕回: + +```math +C = \begin{bmatrix} 1 & 3 & 2 \\ 2 & 1 & 3 \\ 3 & 2 & 1 \end{bmatrix} +``` + +- 循环矩阵与离散傅里叶变换(DFT)密切相关,并且是循环卷积如何工作的核心。 + +- **埃尔米特矩阵**是对称矩阵在复数域中的等价形式:$A = A^\ast$(其中 $A^\ast$ 是共轭转置)。 + +- 对于实值矩阵,埃尔米特矩阵和对称矩阵是一回事。你会在量子计算和信号处理中遇到它们。 + +- **酉矩阵**是正交矩阵在复数域中的等价形式:$U^\ast U = UU^\ast = I$。正如正交矩阵在实空间中保持长度,酉矩阵在复空间中保持长度。 + +- **幂等矩阵**满足 $A^2 = A$。应用变换两次等同于应用一次,这使得它成为一个**投影**。一旦你投影了,再次投影不会改变任何东西。 + +- **幂零矩阵**满足对某个幂次 $k$ 有 $A^k = O$(零矩阵)。应用变换足够多次后,所有东西都坍缩为零。例如: + +```math +\begin{bmatrix} 0 & 1 \\ 0 & 0 \end{bmatrix}^2 = \begin{bmatrix} 0 & 0 \\ 0 & 0 \end{bmatrix} +``` + +- **布尔矩阵**(或二元矩阵)只包含 0 和 1。它表示是/否关系。例如,在一个有 3 个节点的图中,**邻接矩阵**记录哪些节点相连: + +```math +B = \begin{bmatrix} 0 & 1 & 1 \\ 1 & 0 & 0 \\ 1 & 0 & 0 \end{bmatrix} +``` + +- 这里,节点 1 连接到节点 2 和 3,但节点 2 和 3 之间没有连接。 + +- **范德蒙矩阵**由一组值的连续幂次构成。给定值 $x_1, x_2, x_3$: + +```math +V = \begin{bmatrix} 1 & x_1 & x_1^2 \\ 1 & x_2 & x_2^2 \\ 1 & x_3 & x_3^2 \end{bmatrix} +``` + +- 这种结构出现在多项式插值中:找到通过给定点集的唯一多项式。 + +- **海森堡矩阵**是"几乎"三角的,在第一次次对角线以下全为零: + +```math +H = \begin{bmatrix} 4 & 2 & 1 \\ 3 & 5 & -1 \\ 0 & 1 & 6 \end{bmatrix} +``` + +- 它是有效计算特征值的有用中间形式。先将矩阵化为海森堡形式可以使迭代算法收敛更快。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 创建一个正交矩阵(旋转矩阵),乘以其转置,验证得到单位矩阵。尝试不同的角度。 +```python +import jax.numpy as jnp + +theta = jnp.pi / 4 +Q = jnp.array([[jnp.cos(theta), -jnp.sin(theta)], + [jnp.sin(theta), jnp.cos(theta)]]) + +print(f"Q @ Q.T:\n{Q @ Q.T}") +print(f"Determinant: {jnp.linalg.det(Q):.2f}") +``` + +2. 创建一个对称矩阵并验证它等于其转置。然后计算其特征值并检查特征向量是否垂直。 +```python +import jax.numpy as jnp + +S = jnp.array([[4.0, 2.0], + [2.0, 3.0]]) + +print(f"Symmetric: {jnp.allclose(S, S.T)}") + +eigenvalues, eigenvectors = jnp.linalg.eigh(S) +print(f"Eigenvalues: {eigenvalues}") +print(f"Dot product of eigenvectors: {jnp.dot(eigenvectors[:, 0], eigenvectors[:, 1]):.6f}") +``` diff --git a/chapter 02: matrices/03. operations.md b/chapter 02: matrices/03. operations.md new file mode 100644 index 0000000..5a9f3cb --- /dev/null +++ b/chapter 02: matrices/03. operations.md @@ -0,0 +1,146 @@ +# 矩阵运算 + +*矩阵运算是深度学习的计算引擎。本文涵盖矩阵加法、标量乘法、矩阵-向量积、矩阵乘法、逐元素运算、Kronecker积和广播——支撑每一次前向传播和梯度更新的运算。* + +- 矩阵可以像向量一样进行加法和缩放。 + +- 加法要求两个矩阵维度相同,然后逐元素相加: + +```math +\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} + \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} = \begin{bmatrix} 6 & 8 \\ 10 & 12 \end{bmatrix} +``` + +- 标量乘法将每个元素乘以标量: + +```math +3 \times \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} = \begin{bmatrix} 3 & 6 \\ 9 & 12 \end{bmatrix} +``` + +- 矩阵能做的最简单的事情是乘以一个向量。**矩阵-向量乘法** $A\mathbf{x}$ 使用 $\mathbf{x}$ 的分量作为权重来组合 $A$ 的列: + +```math +\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \begin{bmatrix} 5 \\ 6 \end{bmatrix} = 5 \begin{bmatrix} 1 \\ 3 \end{bmatrix} + 6 \begin{bmatrix} 2 \\ 4 \end{bmatrix} = \begin{bmatrix} 17 \\ 39 \end{bmatrix} +``` + +- 这是机器学习中的核心运算。每个神经网络层都计算 $A\mathbf{x} + \mathbf{b}$:矩阵乘以输入向量,再加上偏置。 + +- 一般情况是**矩阵乘法**。给定 $A$($m \times n$)和 $B$($n \times p$),乘积 $C = AB$ 是一个 $m \times p$ 矩阵,每个元素都是一个点积: + +$$C_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}$$ + +- 结果中的每个条目都是 $A$ 的一行与 $B$ 的一列的点积。内部维度必须匹配($n$),结果取外部维度($m \times p$)。 + +- 另一种理解方式:结果的每一列都是 $A$ 的列的**加权和**,其中权重来自 $B$ 的对应列。 + +- 如果 $B$ 的某一列为 $[2, 3]^T$,则结果列就是 $2 \times (\text{A的第1列}) + 3 \times (\text{A的第2列})$。 + +- 一个有用的特例:矩阵与其转置相乘总是得到一个方阵。$AA^T$ 是 $m \times m$,$A^TA$ 是 $n \times n$: + +```math +\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 & 4 \\ 2 & 5 \\ 3 & 6 \end{bmatrix} = \begin{bmatrix} 14 & 32 \\ 32 & 77 \end{bmatrix} +``` + +- 矩阵乘法有重要的运算规则: + + - **不满足交换律**:通常 $AB \neq BA$。顺序很重要。 + + - **满足结合律**:$(AB)C = A(BC)$。你可以任意分组乘法。 + + - **满足分配律**:$A(B + C) = AB + AC$。 + + - **单位矩阵**:$AI = IA = A$。 + +- **Hadamard积**(逐元素乘积)将两个相同大小的矩阵逐项相乘,记作 $A \odot B$: + +```math +\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix} \odot \begin{bmatrix} 5 & 6 \\ 7 & 8 \end{bmatrix} = \begin{bmatrix} 5 & 12 \\ 21 & 32 \end{bmatrix} +``` + +- 与标准矩阵乘法不同,Hadamard积满足交换律($A \odot B = B \odot A$),且要求两个矩阵维度相同。它在机器学习中广泛用于门控机制:通过与一个取值在0到1之间的掩码逐元素相乘,控制每个条目"通过"多少。 + +- 两个向量 $\mathbf{u}$ 和 $\mathbf{v}$ 的**外积**产生一个矩阵:$\mathbf{u}\mathbf{v}^T$。每个条目是 $\mathbf{u}$ 的一个元素与 $\mathbf{v}$ 的一个元素的乘积: + +```math +\begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} \begin{bmatrix} 4 & 5 \end{bmatrix} = \begin{bmatrix} 4 & 5 \\ 8 & 10 \\ 12 & 15 \end{bmatrix} +``` + +- 结果总是秩为1,因为每一行都是 $\mathbf{v}^T$ 的缩放版本。任何矩阵都可以写成秩-1外积之和,这正是SVD所做的事情(见分解章节)。 + +- 矩阵乘法的计算开销很大。两个 $n \times n$ 矩阵相乘需要 $O(n^3)$ 次运算。对于一个 $1000 \times 1000$ 的矩阵,那就是十亿次乘法。 + +- 当矩阵是**稀疏的**(大部分为零)时,朴素的乘法会浪费时间乘以零。**压缩稀疏行(CSR)**格式只存储非零元素及其位置: + + - **值**:按行顺序排列的非零条目 + - **列索引**:每个值属于哪一列 + - **行偏移**:每一行在值列表中的起始位置 + +- 例如,矩阵: + +```math +A = \begin{bmatrix} 5 & 0 & 0 & 2 \\ 0 & 0 & 3 & 0 \\ 0 & 0 & 0 & -1 \end{bmatrix} +``` + +- 存储为:values = [5, 2, 3, -1], columns = [0, 3, 2, 3], row offsets = [0, 2, 3, 4]。这跳过了所有零,使稀疏运算快得多。 + +- 矩阵的一个核心用途是求解**线性方程组**。方程组 $A\mathbf{x} = \mathbf{b}$ 问的是:"什么向量 $\mathbf{x}$ 被 $A$ 变换后,会得到 $\mathbf{b}$?" + +- 例如,假设你在买水果。苹果每个 $x_1$ 元,香蕉每个 $x_2$ 元。已知2个苹果和1个香蕉共5元,1个苹果和3个香蕉共10元。用矩阵形式表示: + +```math +\begin{bmatrix} 2 & 1 \\ 1 & 3 \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} = \begin{bmatrix} 5 \\ 10 \end{bmatrix} +``` + +- 矩阵逐行乘以向量(每一行与 $[x_1, x_2]^T$ 点积)得到两个方程: + +$$2x_1 + 1x_2 = 5 \qquad \text{(第1行)} \qquad \qquad x_1 + 3x_2 = 10 \qquad \text{(第2行)}$$ + +- 从第1行得 $x_2 = 5 - 2x_1$。代入第2行:$x_1 + 3(5 - 2x_1) = 10$,解得 $x_1 = 1$,则 $x_2 = 3$。苹果每个1元,香蕉每个3元。 + +- 验证——结果正确: + +```math +\begin{bmatrix} 2 & 1 \\ 1 & 3 \end{bmatrix} \begin{bmatrix} 1 \\ 3 \end{bmatrix} = \begin{bmatrix} 2 + 3 \\ 1 + 9 \end{bmatrix} = \begin{bmatrix} 5 \\ 10 \end{bmatrix} +``` + +- 如果 $A$ 有逆矩阵,解就是简单的 $\mathbf{x} = A^{-1}\mathbf{b}$。但直接计算逆矩阵代价高昂且数值不稳定。实践中我们使用分解方法。 + +- 并非所有矩阵都是方阵,也不是所有方阵都可逆。**伪逆** $A^+$ 将逆推广到任意矩阵。它总是存在,并提供"尽可能好的"逆: + +$$A^+ = (A^TA)^{-1}A^T$$ + +- 当 $A$ 是下三角矩阵时,通过**前向代入**求解 $L\mathbf{x} = \mathbf{b}$ 很容易:先解出 $x_1$,然后用它求出 $x_2$,依此类推。 + +- 当 $A$ 是上三角矩阵时,通过**回代**求解 $U\mathbf{x} = \mathbf{b}$:先解出最后一个变量,然后向上求解。 + +- 这就是为什么将矩阵分解为三角因子(如分解章节所述)如此有用——它将一个难题转化为两个简单问题。 + +## 编程练习(使用CoLab或Jupyter Notebook) + +1. 将两个矩阵相乘并验证维度。然后交换顺序,观察结果如何变化(或者,如果维度不匹配,运算失败)。 + +```python +import jax.numpy as jnp + +A = jnp.array([[1.0, 2.0], + [3.0, 4.0]]) +B = jnp.array([[5.0, 6.0], + [7.0, 8.0]]) + +print(f"A @ B:\n{A @ B}") +print(f"B @ A:\n{B @ A}") +print(f"Equal: {jnp.allclose(A @ B, B @ A)}") +``` + +2. 求解线性方程组 $A\mathbf{x} = \mathbf{b}$,并通过回代乘法验证解。尝试改变 $\mathbf{b}$,观察解如何变化。 + +```python +import jax.numpy as jnp + +A = jnp.array([[2.0, 1.0], + [5.0, 3.0]]) +b = jnp.array([4.0, 7.0]) + +x = jnp.linalg.solve(A, b) +print(f"Solution x: {x}") +print(f"A @ x: {A @ x}") +``` diff --git a/chapter 02: matrices/04. linear transformations.md b/chapter 02: matrices/04. linear transformations.md new file mode 100644 index 0000000..ab895d0 --- /dev/null +++ b/chapter 02: matrices/04. linear transformations.md @@ -0,0 +1,163 @@ +# 线性变换 + +*每个矩阵乘法都是一个线性变换——一个在保持线性性质的同时重塑、旋转或投影向量的函数。本文涵盖旋转、反射、缩放、剪切、投影、映射的核与像,以及神经网络层如何串联这些变换。* + +- **线性变换**(或线性映射)是一个接收向量并产生另一个向量的函数,同时保持加法和缩放性质。如果 $T$ 是线性的,则: + + - $T(\mathbf{u} + \mathbf{v}) = T(\mathbf{u}) + T(\mathbf{v})$ + - $T(c\mathbf{u}) = cT(\mathbf{u})$ + +- 每个线性变换都可以表示为矩阵乘法。矩阵*就是*变换本身。当你用一个矩阵乘以一个向量时,就是在对它施加一个线性变换。 + +- 可以把一个 $2 \times 2$ 矩阵想象成一个机器:它接收二维向量,输出新的二维向量。矩阵的列告诉你标准基向量 $\hat{\mathbf{i}}$ 和 $\hat{\mathbf{j}}$ 经过变换后到了哪里。其余一切都由线性性质导出。 + +![矩阵的列显示了基向量落在何处](../images/basis_transform.svg) + +- 例如,如果 + +```math +A = \begin{bmatrix} 2 & 1 \\ 1 & 2 \end{bmatrix} +``` + + 那么 $\hat{\mathbf{i}} = [1, 0]^T$ 落在 $[2, 1]^T$(第1列),$\hat{\mathbf{j}} = [0, 1]^T$ 落在 $[1, 2]^T$(第2列)。其他所有向量都是这两个向量的组合,因此其输出自动遵循。 + +- 将两个矩阵相乘可以理解为依次施加两个变换。如果 $B$ 将向量从一个空间变换,然后 $A$ 变换结果,那么 $AB$ 按顺序完成这两个操作。在游戏引擎中,先旋转角色再向前移动,与先移动再旋转,结果完全不同——这就是矩阵乘法不满足交换律的原因。 + +- **旋转**将向量绕一定角度 $\theta$ 转动而不改变其长度。向量大小不变,只是指向新的方向。 + +![旋转保持长度不变但改变方向](../images/rotation.svg) + +- 二维中的旋转矩阵为: + +```math +R(\theta) = \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix} +``` + +- 当 $\theta = 90°$ 时: + +```math +R = \begin{bmatrix} 0 & -1 \\ 1 & 0 \end{bmatrix} +``` + + 因此 $[1, 0]^T$ 变成 $[0, 1]^T$。原来指向右侧的向量现在指向上方。旋转矩阵是正交的,且行列式始终为1。当你在手机上旋转照片时,就是对每个像素坐标应用这个矩阵。 + +- 在三维中,每个坐标轴都有对应的旋转矩阵。机械臂的每个关节绕特定轴旋转,每个关节就是一个旋转矩阵。绕z轴旋转看起来像是嵌入三维的二维情况: + +```math +R_z(\theta) = \begin{bmatrix} \cos\theta & -\sin\theta & 0 \\ \sin\theta & \cos\theta & 0 \\ 0 & 0 & 1 \end{bmatrix} +``` + +- **缩放**沿每个坐标轴独立地拉伸或压缩向量: + +```math +S(s_x, s_y) = \begin{bmatrix} s_x & 0 \\ 0 & s_y \end{bmatrix} +``` + +![缩放沿每个轴以不同因子拉伸](../images/scaling.svg) + +- $S(2, 1.5)$ 将x分量加倍,y分量乘以1.5。沿某轴缩放 $-1$ 会翻转该分量。对角矩阵总是缩放变换。当你将图片缩小到50%时,就是对每个像素坐标应用 $S(0.5, 0.5)$。 + +- **反射**像镜子一样将向量翻转到某个轴或直线的另一侧。沿x轴的反射保持x分量不变,取反y分量: + +```math +\text{Ref}_x = \begin{bmatrix} 1 & 0 \\ 0 & -1 \end{bmatrix} +``` + +![沿x轴反射翻转y分量](../images/reflection.svg) + +- 例如,$[3, 2]^T$ 变成 $[3, -2]^T$。当你的手机水平翻转自拍照使文字正确显示时,就是在应用反射矩阵。沿直线 $y = x$ 的反射交换两个分量: + +```math +\text{Ref}_{y=x} = \begin{bmatrix} 0 & 1 \\ 1 & 0 \end{bmatrix} +``` + +- 反射矩阵的行列式为 $-1$,表明它们翻转了方向。 + +- 旋转和反射都是**刚性变换**:它们保持距离和角度不变。表示这些变换的矩阵是正交矩阵,这就是为什么正交矩阵的行列式总是 $+1$(旋转)或 $-1$(反射)。 + +- **剪切**沿一个坐标轴按另一坐标轴的比例倾斜向量。水平剪切因子 $k$: + +```math +\text{Sh}_x(k) = \begin{bmatrix} 1 & k \\ 0 & 1 \end{bmatrix} +``` + +![剪切使顶部侧向滑动而底部保持不动](../images/shearing.svg) + +- 每个点水平滑动 $k$ 倍于其高度的距离。当 $k = 0.5$ 时,高度为2的点向右移动1。最下面一行保持不动,最上面一行滑动最多。这就是斜体文字的工作原理:正立的字母被剪切,从而向右倾斜。 + +- 以上所有变换(旋转、缩放、反射、剪切)都是**线性**变换。它们保持原点固定,并保持直线为直线。但**平移**(将所有点按固定量移动)呢? + +- 平移*不是*线性变换,因为它移动了原点。如果将每个点向右移动3,零向量会移动到 $[3, 0]^T$,从而破坏了线性性质。为了处理平移,我们使用**仿射变换**,它将线性变换与平移结合起来: + +$$\mathbf{y} = A\mathbf{x} + \mathbf{t}$$ + +- 为了将其表示为单个矩阵乘法,我们使用**齐次坐标**:为每个向量添加一个额外的1,并使用一个 $(n+1) \times (n+1)$ 的矩阵: + +```math +\begin{bmatrix} A & \mathbf{t} \\ \mathbf{0}^T & 1 \end{bmatrix} \begin{bmatrix} \mathbf{x} \\ 1 \end{bmatrix} = \begin{bmatrix} A\mathbf{x} + \mathbf{t} \\ 1 \end{bmatrix} +``` + +- 仿射变换保持直线和平行性,但不一定保持角度或长度。电子游戏中的每个物体都使用仿射变换来定位:旋转它、缩放它,然后放置到正确的位置——所有这些都编码在一个矩阵中。 + +- **退化变换**(奇异矩阵)将空间坍缩到更低维度。 + +- 例如,矩阵 + +```math +\begin{bmatrix} 1 & 2 \\ 2 & 4 \end{bmatrix} +``` + + 将每个二维向量映射到一条直线上,因为两列指向同一方向。行列式为零,信息丢失,且该变换不可逆。 + +- 将彩色图像(每个像素有3个值:红、绿、蓝)转换为灰度图(每个像素1个值)就是退化变换:颜色信息永久丢失。 + +- 在机器学习中,线性变换是神经网络的核心。数据被表示为矩阵(向量的堆叠,这些向量代表对象的特征——人、飞机、文本、图像……任何东西!) + +- 每一层应用一个矩阵乘法(线性变换),详细内容在其他章节中提供,我们需要解释如何组织这些数据并恰当地引出神经网络。 + +- 然而,当今最常用的技术几乎完全是将数据通过一系列线性变换传递,我们称之为**Transformer**。 + +- Gemini、ChatGPT、Claude、Qwen、DeepSeek以及当今世界上性能最好的AI,都是Transformer! + +## 编程练习(使用CoLab或Jupyter Notebook) + +1. 对向量应用旋转矩阵,并绘制原始向量和旋转后的向量。尝试不同的角度。 + +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +theta = jnp.pi / 3 +R = jnp.array([[jnp.cos(theta), -jnp.sin(theta)], + [jnp.sin(theta), jnp.cos(theta)]]) + +v = jnp.array([1.0, 0.0]) +v_rot = R @ v + +plt.figure(figsize=(5, 5)) +plt.quiver(0, 0, v[0], v[1], angles='xy', scale_units='xy', scale=1, color='red', label='original') +plt.quiver(0, 0, v_rot[0], v_rot[1], angles='xy', scale_units='xy', scale=1, color='blue', label='rotated') +plt.xlim(-1.5, 1.5); plt.ylim(-1.5, 1.5) +plt.grid(True); plt.legend(); plt.gca().set_aspect('equal') +plt.show() +``` + +2. 对构成正方形的一组点应用剪切变换,并可视化变形后的形状。 + +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +square = jnp.array([[0,0],[1,0],[1,1],[0,1],[0,0]]).T + +k = 0.5 +shear = jnp.array([[1, k], + [0, 1]]) +sheared = shear @ square + +plt.figure(figsize=(6, 4)) +plt.plot(square[0], square[1], 'r-o', label='original') +plt.plot(sheared[0], sheared[1], 'b-o', label='sheared') +plt.grid(True); plt.legend(); plt.gca().set_aspect('equal') +plt.show() +``` diff --git a/chapter 02: matrices/05. decompositions.md b/chapter 02: matrices/05. decompositions.md new file mode 100644 index 0000000..8152d83 --- /dev/null +++ b/chapter 02: matrices/05. decompositions.md @@ -0,0 +1,211 @@ +# 矩阵分解 + +*矩阵分解将复杂矩阵拆分为更简单的因子,用于求解方程组、计算逆矩阵和数据压缩。本文涵盖高斯消元、LU、QR、Cholesky、特征分解和SVD——这些算法是PCA、推荐系统和机器学习数值稳定性的基石。* + +- 矩阵分解(或因子分解)将一个矩阵拆分成更容易处理的更简单的部分。可以把它类比为因数分解:$12 = 3 \times 4$ 比单独的12更容易理解。 + +- 我们分解矩阵是为了更快地求解方程组、稳定地计算逆矩阵、寻找特征值、压缩数据以及理解变换的几何结构。 + +- 最基本的技术是**高斯消元**(行化简)。思路很简单:给定方程组 $A\mathbf{x} = \mathbf{b}$,使用三种允许的操作简化 $A$,直到答案显而易见。 + +- 这些操作是:交换两行、将一行乘以非零标量、或将一行的倍数加到另一行上。 + +- 例如,要消除主元下方的第一列,从下面的行中减去第1行的倍数: + +```math +\begin{bmatrix} 2 & 1 & 5 \\ 4 & 3 & 7 \\ 6 & 5 & 9 \end{bmatrix} \xrightarrow{R_2 - 2R_1} \begin{bmatrix} 2 & 1 & 5 \\ 0 & 1 & -3 \\ 6 & 5 & 9 \end{bmatrix} \xrightarrow{R_3 - 3R_1} \begin{bmatrix} 2 & 1 & 5 \\ 0 & 1 & -3 \\ 0 & 2 & -6 \end{bmatrix} +``` + +- 目标是**行阶梯形(REF)**:每个主元(每行第一个非零条目)下方全为零,且每个主元在其上方主元的右侧。矩阵呈现阶梯形状。 + +![高斯消元:行操作产生三角形形式,然后从下往上求解](../images/gaussian_elimination.svg) + +- 进一步得到**简化行阶梯形(RREF)**,使每个主元为1且是该列中唯一的非零条目。每个矩阵有唯一的RREF。 + +- 一旦转换为三角形形式,我们通过**回代**求解:最下面一行直接给出最后一个变量,然后向上求解。 + +- 这是所有其他分解方法所建立的基础,分解的目标就是将矩阵简化为三角形形式,从而可以通过回代求解变量。 + +- **LU分解**将高斯消元形式化,将方阵分解为 $A = LU$(或通过行交换得到 $A = PLU$),其中 $L$ 是下三角矩阵,$U$ 是上三角矩阵。 + +![LU分解:将一个困难的矩阵拆分为两个简单的三角矩阵](../images/lu_decomposition.svg) + +- 求解 $A\mathbf{x} = \mathbf{b}$:先通过前向代入(从上到下)求解 $L\mathbf{y} = \mathbf{b}$,然后通过回代(从下到上)求解 $U\mathbf{x} = \mathbf{y}$。两次简单的三角求解代替了一次困难的一般求解。 + +- 相比原始高斯消元的优势在于可复用。一旦得到 $L$ 和 $U$,就可以对许多不同的 $\mathbf{b}$ 向量求解,而无需重新进行分解。 + +- 如果你需要用1000个不同的右端项求解同一个方程组(这在模拟中很常见),只需分解一次然后重复使用。 + +- 当矩阵是对称正定矩阵时(如协方差矩阵),我们可以做得更好。 + +- **Cholesky分解**将其分解为 $A = LL^T$,其中 $L$ 是下三角矩阵。例如: + +```math +\begin{bmatrix} 4 & 2 \\ 2 & 5 \end{bmatrix} = \begin{bmatrix} 2 & 0 \\ 1 & 2 \end{bmatrix} \begin{bmatrix} 2 & 1 \\ 0 & 2 \end{bmatrix} +``` + +- 这大约比LU快两倍,并且保证数值稳定。可以将其视为矩阵的"平方根"。 + +- 如果分解失败(平方根下出现负值),则该矩阵不是正定的。因此Cholesky分解也可以作为正定性的检验方法。 + +- 方阵 $A$ 的**特征向量**是特殊方向,该变换在这些方向上只进行拉伸或压缩,而不旋转。**特征值**是缩放因子: + +$$A\mathbf{x} = \lambda\mathbf{x}$$ + +![特征向量保持在同一直线上(仅被缩放),普通向量被旋转](../images/eigenvector.svg) + +- 大多数向量在乘以矩阵时方向会改变。但特征向量是特殊的:输出方向与输入方向相同,仅被 $\lambda$ 缩放。如果 $\lambda = 2$,特征向量长度加倍。如果 $\lambda = -1$,它翻转方向。如果 $\lambda = 0$,它被压缩为零。 + +- 例如,对于: + +```math +A = \begin{bmatrix} 3 & 1 \\ 0 & 2 \end{bmatrix} +``` + + 向量 $[1, 0]^T$ 是特征向量,$\lambda = 3$,因为 $A[1, 0]^T = [3, 0]^T = 3[1, 0]^T$。 + +- 求特征值需要解**特征多项式** $\det(A - \lambda I) = 0$。根即为特征值。然后将每个 $\lambda$ 代回 $(A - \lambda I)\mathbf{x} = \mathbf{0}$ 中,求出对应的特征向量。 + +- 关键性质: + + - $A$ 的迹等于其特征值之和。 + - $A$ 的行列式等于其特征值之积。 + - 对称矩阵的特征向量互相垂直,特征值为实数。 + - 正定矩阵的所有特征值为正。 + - 协方差矩阵(我们将在统计学中遇到)总是半正定的。 + +- 通过特征多项式计算特征值对于大型矩阵来说是不切实际的。相反,使用迭代方法: + + - **幂迭代**:反复乘以 $A$ 并归一化。收敛到主特征向量(最大特征值)。简单但只能找到一个特征对。 + + - **QR算法**:最常用的方法。使用QR分解反复分解和重组矩阵,直到矩阵收敛到三角形形式,对角线上的元素即为所有特征值。 + + - **反迭代**:寻找最接近给定目标值的特征向量。当你大致知道想要哪个特征值时很有用。 + + - 对于大型稀疏矩阵,**Arnoldi**和**Lanczos**迭代利用稀疏性提高效率。 + +- 如果方阵有一组完整的线性无关的特征向量,它可以被**对角化**:$A = PDP^{-1}$,其中 $D$ 是以特征值为对角元的对角矩阵,$P$ 的列是特征向量。 + +- 这有什么用?对角矩阵非常容易处理。需要计算 $A^{100}$?不用将 $A$ 自乘100次,计算 $PD^{100}P^{-1}$ 即可——而对角矩阵的幂只需独立地对每个对角元求幂。这将一个昂贵的运算变成了廉价运算。 + +- **特征基**是完全由特征向量构成的基。在这个基下,矩阵变成对角矩阵,变换仅仅是沿每个特征向量方向的独立缩放。这就像是找到了变换的自然坐标系。 + +- **QR分解**将任意矩阵 $A$ 分解为 $A = QR$,其中 $Q$ 是正交矩阵(其列是标准正交的),$R$ 是上三角矩阵。可以理解为将"方向"信息($Q$)与"缩放和混合"信息($R$)分开。 + +- **Gram-Schmidt过程**逐列构建 $Q$。取 $A$ 的第一列并归一化。取第二列,减去其在第一列上的投影(使其垂直),再归一化。对每一列重复此过程。结果是一组标准正交向量。 + +- QR分解是QR算法求特征值背后的引擎。它也直接用于求解最小二乘问题:当 $A\mathbf{x} = \mathbf{b}$ 没有精确解(方程多于未知数)时,QR找到最佳近似解。 + +- **SVD**(奇异值分解)是最通用、也可以说是最重要的分解。每个矩阵(任意形状、任意秩)都有SVD:$A = U\Sigma V^T$ + + - $V^T$($n \times n$,正交):旋转输入 + - $\Sigma$($m \times n$,对角):沿正交坐标轴缩放(奇异值,非负,递减排列) + - $U$($m \times m$,正交):旋转输出 + +![SVD:任何变换 = 旋转,然后缩放,再旋转](../images/svd.svg) + +- 几何上,SVD表明每个线性变换,无论多么复杂,都只是一个旋转、一个沿坐标轴的拉伸、再一个旋转的组合。一个圆变成了一个椭圆。 + +- 奇异值($\sigma_1 \geq \sigma_2 \geq \ldots$)揭示了每个方向的"重要性"。大的奇异值对应最重要的方向。$A$ 的秩等于非零奇异值的个数。 + +- **低秩近似**:只保留最大的 $k$ 个奇异值,将其他置零,就得到了 $A$ 的最佳秩-$k$ 近似。这就是图像压缩的原理:一张 $1000 \times 1000$ 的图像可能只需要 $k = 50$ 个奇异值就能看起来几乎一模一样,压缩了20倍。 + +- SVD也提供了伪逆:$A^+ = V\Sigma^+U^T$,其中 $\Sigma^+$ 是对非零奇异值取倒数。 + +- 特征分解只对方阵有效,而SVD对任意矩阵都有效。这是它的关键优势。 + +- **PCA**(主成分分析)使用特征分解(或SVD)进行降维。 + +- 想象一个数据集,每个样本有100个特征(堆叠成矩阵的100维向量)。其中许多特征是相关的、冗余的。 + +- PCA找到数据实际变化的那些方向,让你只保留重要的部分。 + +![PCA找到数据中方差最大的方向](../images/pca.svg) + +- 第一主成分(PC1)是方差最大的方向。 + +- 第二主成分(PC2)捕获剩余部分的最大方差,且与第一主成分垂直。 + +- 如果大部分方差只集中在少数几个方向上,你可以将数据投影到这些维度上,丢弃其余部分,损失极小。 + +- 步骤: + + - 标准化数据(减去均值,除以标准差),使所有特征贡献平等 + - 计算协方差矩阵 + - 求其特征值和特征向量 + - 选择 $k$ 个最大特征值对应的特征向量(即主成分) + - 将数据投影到这些主成分上 + +- 标准化至关重要:如果不做标准化,用公里测量的特征会主导用厘米测量的特征,而不论其实际重要性如何。 + +- 在实践中,PCA用于可视化(将高维数据投影到2D或3D)、降噪(丢弃主要是噪声的低方差方向),以及通过减少输入特征数量来加速机器学习模型。 + +- **核PCA**将PCA扩展到非线性关系。它通过核函数将数据映射到更高维空间,在那里结构变得线性,然后应用标准PCA并投影回来。 + +- **Schur分解**将方阵分解为 $A = QTQ^\ast$,其中 $Q$ 是酉矩阵,$T$ 是上三角矩阵。每个方阵都有Schur分解,即使它不能被对角化。 + +- **非负矩阵分解(NMF)** 将一个矩阵分解为两个非负矩阵:$A \approx WH$,其中 $W$ 和 $H$ 的所有条目都 $\geq 0$。与可能产生负条目的SVD不同,NMF只做加法,从不做减法。这使得各部分可解释:在主题建模中,$W$ 给出每个文档的主题权重,$H$ 给出每个主题的词权重,全部非负,这与我们对文档"包含多少某个主题"的思考方式相符。 + +- **谱定理**指出,对称(或Hermitian)矩阵总可以用正交(或酉)矩阵对角化。它们的特征值总是实数,特征向量总是正交的。这是PCA的理论基础。 + +## 编程练习(使用CoLab或Jupyter Notebook) + +1. 计算对称矩阵的特征值和特征向量。验证特征向量互相垂直,并从特征分解重建矩阵。 + +```python +import jax.numpy as jnp + +A = jnp.array([[4.0, 2.0], + [2.0, 3.0]]) + +eigenvalues, eigenvectors = jnp.linalg.eigh(A) +print(f"Eigenvalues: {eigenvalues}") +print(f"Eigenvectors orthogonal: {jnp.dot(eigenvectors[:,0], eigenvectors[:,1]):.6f}") + +# Reconstruct: A = P D P^T +D = jnp.diag(eigenvalues) +A_reconstructed = eigenvectors @ D @ eigenvectors.T +print(f"Reconstruction matches: {jnp.allclose(A, A_reconstructed)}") +``` + +2. 实现幂迭代求最大特征值,以及反迭代求最小特征值。与 `jnp.linalg.eigh` 比较。然后尝试自己实现QR算法。 + +```python +import jax.numpy as jnp + +A = jnp.array([[4.0, 2.0], + [2.0, 3.0]]) + +# Power iteration: finds the LARGEST eigenvalue +v = jnp.array([1.0, 0.0]) +for _ in range(20): + v = A @ v + v = v / jnp.linalg.norm(v) +print(f"Largest eigenvalue: {v @ A @ v:.4f}") + +# Inverse iteration: multiply by A^{-1} instead of A, finds the SMALLEST eigenvalue +v = jnp.array([1.0, 0.0]) +for _ in range(20): + v = jnp.linalg.solve(A, v) + v = v / jnp.linalg.norm(v) +print(f"Smallest eigenvalue: {1.0 / (v @ jnp.linalg.solve(A, v)):.4f}") + +print(f"jnp.linalg.eigh: {jnp.linalg.eigh(A)[0]}") +``` + +3. 计算矩阵的SVD,然后仅使用前k个奇异值重建矩阵,观察近似质量随k的变化。 + +```python +import jax.numpy as jnp + +A = jnp.array([[1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0]]) + +U, S, Vt = jnp.linalg.svd(A) + +for k in [1, 2, 3]: + approx = U[:, :k] @ jnp.diag(S[:k]) @ Vt[:k, :] + error = jnp.linalg.norm(A - approx) + print(f"k={k}, reconstruction error: {error:.4f}") +``` diff --git a/chapter 03: calculus/01. differential calculus.md b/chapter 03: calculus/01. differential calculus.md new file mode 100644 index 0000000..e0fe6b2 --- /dev/null +++ b/chapter 03: calculus/01. differential calculus.md @@ -0,0 +1,205 @@ +# 微分 + +*微分学研究瞬时变化率。本节涵盖极限、导数、微分法则、链式法则(反向传播的基础),以及机器学习中常用的导数。* + +- 在前面的章节中,我们学会了如何将数据表示为向量,并用矩阵对其进行变换。但现实世界中的许多现象并非静止不变的。汽车在加速,股价在波动,神经网络的损失随着权重的更新而变化。**微积分**是研究变化的数学。 + +- 微积分回答两个问题:某个量在当前时刻变化得有多快?(微分学)以及它在一段时间内累积了多少?(积分学)。本节讨论的是"多快"的问题。 + +- 想象一下你正在开车,瞥了一眼车速表。上面显示 60 km/h。这个数字不是你整个行程的平均速度,而是你在这一瞬间的瞬时速度。微分学为我们提供了计算这种瞬时变化率的工具。 + +- 但首先,让我们回顾一下直线方程:$y = mx + b$。 + +- 这是两个量之间最简单的关系。 + + - $b$ 是 **y 截距**,即直线与 y 轴的交点(当 $x = 0$ 时的起始值)。 + - $m$ 是**斜率**,即变化率:$x$ 每增加 1 个单位,$y$ 就变化 $m$ 个单位。 +- 如果 $m = 3$,直线陡峭上升;如果 $m = 0$,直线水平;如果 $m = -2$,直线下降。 + +- 斜率计算公式为 $m = \frac{\Delta y}{\Delta x} = \frac{y_2 - y_1}{x_2 - x_1}$,即"$y$ 变化了多少"与"$x$ 变化了多少"的比值。 + +![直线方程:b 是 y 截距,m 是斜率(纵坐标变化除以横坐标变化)](../images/line_equation.svg) + +- 一旦知道了 $m$ 和 $b$,就可以计算任意 $x$ 对应的 $y$ 值。 + +- 例如,若 $m = 2$ 且 $b = 3$,则在 $x = 5$ 处:$y = 2(5) + 3 = 13$。 + +- 这两个参数完全决定了这条直线,预测任何输出只需代入公式即可。 + +- 对于直线,斜率处处相同。 + +- 这一思想可以推广到直线之外。任何函数都是一个将输入映射到输出的规则,一旦知道了它的公式(参数和形状),就可以计算任意输入对应的输出,并将结果绘制成图。 + +- $y = x^2$ 给出抛物线,$y = \sin(x)$ 给出波形,$y = e^x$ 给出指数增长。每个公式都定义了一条特定的曲线,能够熟练地将函数理解为一种形状,对于后续内容至关重要。 + +- 对于直线,斜率处处相同。但大多数有趣的函数都是弯曲的,因此斜率在不同点处各不相同。微积分给了我们一种方法来求曲线上任意一点的斜率。 + +- 我们还需要**极限**的概念。极限描述的是当输入越来越接近某个目标值时,函数趋近于什么值,而不一定非要达到该值。 + +$$\lim_{x \to a} f(x) = L$$ + +- 这读作:"当 $x$ 趋近于 $a$ 时,$f(x)$ 趋近于 $L$。"函数在 $x = a$ 处不一定等于 $L$,只需无限接近即可。 + +- 例如,考虑 $f(x) = \frac{x^2 - 1}{x - 1}$。如果直接代入 $x = 1$,会得到 $\frac{0}{0}$,这是未定义的。 + +- 但尝试接近 1 的值:$f(0.9) = 1.9$,$f(0.99) = 1.99$,$f(1.01) = 2.01$。输出显然朝着 2 靠近。 + +- 从代数角度看,我们可以理解原因:将分子因式分解为 $(x-1)(x+1)$,约去 $(x-1)$ 项,对于所有 $x \neq 1$ 得到 $f(x) = x + 1$。因此当 $x \to 1$ 时,$f(x) \to 2$。 + +- 该函数在 $x = 1$ 处有一个空洞,但极限仍然存在。 + +- 极限是微积分中其他一切内容的基础。 + +- 函数 $f(x)$ 在点 $x = a$ 处的**导数**衡量的是瞬时变化率。从几何角度看,它是该点处曲线切线的斜率。 + +![导数就是曲线上某点处切线的斜率](../images/tangent_line.svg) + +- 要计算这个斜率,我们首先在曲线上取两个点,计算通过这两个点的直线(**割线**)的斜率。然后让第二个点逐渐靠近第一个点,观察割线的斜率趋近于什么值。这就是**差商**: + +$$f'(a) = \lim_{h \to 0} \frac{f(a + h) - f(a)}{h}$$ + +![随着 h 趋近于 0,割线趋近于切线](../images/difference_quotient.svg) + +- 分子 $f(a+h) - f(a)$ 是输出的变化量。分母 $h$ 是输入的变化量。它们的比值是在一个极小区间上的平均变化率。当 $h \to 0$ 时,这个平均值就变成了瞬时变化率。 + +- 例如,设 $f(x) = x^2$。在 $x = 3$ 处: + +$$f'(3) = \lim_{h \to 0} \frac{(3+h)^2 - 9}{h} = \lim_{h \to 0} \frac{9 + 6h + h^2 - 9}{h} = \lim_{h \to 0} (6 + h) = 6$$ + +- 因此在 $x = 3$ 处,函数 $x^2$ 以每单位输入变化 6 单位输出的速率增加。 + +- 如果这个极限存在,则称函数在该点是**可微**的。要做到这一点,函数必须连续(没有跳跃)、光滑(没有尖角),并且在点附近有定义。 + +- 如果你能笔不离纸地画出曲线,且没有任何折点,那么它在该点很可能是可微的。 + +- 每次都从极限定义出发计算导数会很繁琐。幸运的是,少数几条法则就能让我们快速微分几乎任何函数。 + +- **常数法则**:常数的导数为零。若 $f(x) = 5$,则 $f'(x) = 0$。水平线的斜率为零。 + +- **幂法则**:微分的主力法则。将指数提到前面,然后将指数减一: + +$$\frac{d}{dx} x^n = n x^{n-1}$$ + +- 例如:$\frac{d}{dx} x^3 = 3x^2$。三次函数变成了二次函数。该法则适用于任何实数指数,包括负数和分数:$\frac{d}{dx} x^{-1} = -x^{-2}$ 以及 $\frac{d}{dx} \sqrt{x} = \frac{d}{dx} x^{1/2} = \frac{1}{2}x^{-1/2}$。 + +- **和/差法则**:逐项求导。 + +$$\frac{d}{dx}[f(x) \pm g(x)] = f'(x) \pm g'(x)$$ + +- **乘积法则**:当两个函数相乘时,导数并非简单地将各自的导数相乘。而是: + +$$\frac{d}{dx}[f(x) \cdot g(x)] = f'(x)g(x) + f(x)g'(x)$$ + +- 可以这样理解:"第一个的变化率乘以第二个,加上第一个乘以第二个的变化率。"例如,$\frac{d}{dx}[x^2 \sin x] = 2x \sin x + x^2 \cos x$。 + +- **商法则**:对于函数的比值: + +$$\frac{d}{dx}\left[\frac{f(x)}{g(x)}\right] = \frac{f'(x)g(x) - f(x)g'(x)}{[g(x)]^2}$$ + +- 一个有用的记忆口诀:"上导下不导减去上不导下导,除以分母的平方。" + +- **链式法则**:对机器学习最重要的法则。当函数复合(一个函数嵌套在另一个函数内部)时,导数等于沿链各导数的乘积: + +$$\frac{d}{dx} f(g(x)) = f'(g(x)) \cdot g'(x)$$ + +- 可以把它想象成剥洋葱。先对外层函数求导(内层函数保持不变),然后乘以内层函数的导数。 + +![链式法则:对外层求导,乘以内层的导数](../images/chain_rule.svg) + +- 例如,$\frac{d}{dx} (3x + 1)^5 = 5(3x+1)^4 \cdot 3 = 15(3x+1)^4$。外层函数是 $(\cdot)^5$,内层是 $3x+1$。 + +- 链式法则是神经网络中**反向传播**的数学基础。一个深层网络就是一个由多个复合函数组成的长链。要计算损失相对于每个权重的变化率,我们从输出层开始逐层向输入层反复应用链式法则,将每一步的局部导数相乘。 + +- 以下是你会遇到的最常见导数。每一个都可以从极限定义推导出来,但熟记它们可以节省时间: + +| 函数 | 导数 | 备注 | +|---|---|---| +| $e^x$ | $e^x$ | 唯一一个导数等于自身的函数 | +| $a^x$ | $a^x \ln a$ | 指数函数的一般形式 | +| $\ln x$ | $\frac{1}{x}$ | 自然对数 | +| $\log_a x$ | $\frac{1}{x \ln a}$ | 一般对数 | +| $\sin x$ | $\cos x$ | | +| $\cos x$ | $-\sin x$ | 注意负号 | +| $\tan x$ | $\sec^2 x$ | | + +- 指数函数 $e^x$ 非常特别:它是唯一一个导数等于自身的函数。这就是为什么 $e$ 在机器学习中无处不在,从 softmax 激活函数到概率分布都能见到它的身影。 + +- **洛必达法则**用于处理形如 $\frac{0}{0}$ 或 $\frac{\infty}{\infty}$ 的不定式极限。当直接代入得到这类形式时,可以分别对分子和分母求导,然后再次尝试求极限: + +$$\lim_{x \to a} \frac{f(x)}{g(x)} = \lim_{x \to a} \frac{f'(x)}{g'(x)}$$ + +- 条件:$f$ 和 $g$ 都必须在 $a$ 附近可微,并且 $g'(x)$ 在 $a$ 附近(可能除去 $a$ 本身)不为零。原极限必须是不定式。 + +- 例如:$\lim_{x \to 0} \frac{\sin x}{x}$。直接代入得到 $\frac{0}{0}$。应用洛必达法则:$\lim_{x \to 0} \frac{\cos x}{1} = 1$。这个极限是基础的,在信号处理和傅里叶分析中都会出现。 + +- 如果结果仍然是不定式,可以反复应用该法则。例如,$\lim_{x \to 0} \frac{1 - \cos x}{x^2}$ 得到 $\frac{0}{0}$。第一次应用:$\lim_{x \to 0} \frac{\sin x}{2x}$,仍然是 $\frac{0}{0}$。第二次应用:$\lim_{x \to 0} \frac{\cos x}{2} = \frac{1}{2}$。 + +- 如果两个函数是可微的,那么它们的和、差、积、复合以及商(分母不为零时)也是可微的。这就是为什么我们可以自信地对由简单部分组成的复杂表达式进行微分。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 可视化常见函数。在同一张图中绘制 $x^2$、$\sin(x)$ 和 $e^x$,建立对不同公式产生不同形状的直观感受。尝试修改参数(例如 $2x^2$、$\sin(2x)$),观察曲线如何变化。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +x = jnp.linspace(-3, 3, 300) + +fig, axes = plt.subplots(1, 3, figsize=(12, 3)) +axes[0].plot(x, x**2, color="#e74c3c") +axes[0].set_title("x² (抛物线)") +axes[1].plot(x, jnp.sin(x), color="#3498db") +axes[1].set_title("sin(x) (波形)") +axes[2].plot(x, jnp.exp(x), color="#27ae60") +axes[2].set_title("eˣ (指数函数)") +for ax in axes: + ax.axhline(0, color="gray", linewidth=0.5) + ax.axvline(0, color="gray", linewidth=0.5) +plt.tight_layout() +plt.show() +``` + +2. 使用 JAX 的自动微分计算 $f(x) = x^3 - 2x + 1$ 在若干点处的导数,并与解析导数 $f'(x) = 3x^2 - 2$ 进行比较。 +```python +import jax +import jax.numpy as jnp + +f = lambda x: x**3 - 2*x + 1 +df = jax.grad(f) + +for x in [0.0, 1.0, 2.0, -1.0]: + print(f"x={x:5.1f} 自动微分: {df(x):.4f} 解析解: {3*x**2 - 2:.4f}") +``` + +2. 数值验证链式法则。定义 $f(x) = \sin(x^2)$,通过 `jax.grad` 计算其导数,并与解析结果 $2x\cos(x^2)$ 进行比较。 +```python +import jax +import jax.numpy as jnp + +f = lambda x: jnp.sin(x**2) +df = jax.grad(f) + +for x in [0.5, 1.0, 2.0]: + auto = df(x) + analytical = 2*x * jnp.cos(x**2) + print(f"x={x:.1f} 自动微分: {auto:.6f} 解析解: {analytical:.6f}") +``` + +3. 可视化导数。将 $f(x) = x^3 - 3x$ 与其导数 $f'(x) = 3x^2 - 3$ 绘制在同一张图上。观察 $f'(x) = 0$ 的位置与 $f$ 的峰谷之间的对应关系。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +f = lambda x: x**3 - 3*x +# jax.grad 用于标量;jax.vmap 将其向量化,可同时处理一组输入 +df = jax.vmap(jax.grad(f)) + +x = jnp.linspace(-2.5, 2.5, 200) +plt.plot(x, jax.vmap(f)(x), label="f(x)") +plt.plot(x, df(x), label="f'(x)", linestyle="--") +plt.axhline(0, color="gray", linewidth=0.5) +plt.legend() +plt.title("函数及其导数") +plt.show() +``` diff --git a/chapter 03: calculus/02. integral calculus.md b/chapter 03: calculus/02. integral calculus.md new file mode 100644 index 0000000..18a428b --- /dev/null +++ b/chapter 03: calculus/02. integral calculus.md @@ -0,0 +1,110 @@ +# 积分学 + +*积分学在区间上累积量,将局部变化率还原为总量。本文涵盖定积分与不定积分、微积分基本定理、积分技巧,以及在机器学习中与概率密度和期望值的应用。* + +- 微分告诉我们单个点的变化率。积分则相反:它将许多微小片段累积起来,计算出一个总量。 + +- 如果导数回答的是"多快?",那么积分回答的是"多少?" + +- 理解积分最简单的方式是将其视为**曲线下的面积**。如果绘制出函数 $f(x)$ 的图像,并将从 $x = a$ 到 $x = b$ 之间曲线与 x 轴之间的区域涂上阴影,积分给出的就是该区域的有符号面积。 + +![积分通过求和薄矩形来计算曲线下的面积](../images/area_under_curve.svg) + +- 为什么是"有符号"的?在 x 轴上方的区域贡献正面积,在下方的区域贡献负面积。这在物理上是有意义的:如果 $f(x)$ 代表速度,积分给出的是净位移(正向减去反向),而不是总路程。 + +- 为了计算这个面积,想象将区域切成 $n$ 个细长的竖直矩形,每个矩形的宽度为 $\Delta x$。每个矩形的高度是该切片内某一点的函数值。将它们求和: + +$$\text{面积} \approx \sum_{i=1}^{n} f(x_i^\ast) \, \Delta x$$ + +- 当我们让矩形越来越薄时($n \to \infty$,$\Delta x \to 0$),这个和就变得精确。这个极限过程定义了**定积分**: + +$$\int_a^b f(x)\, dx = \lim_{n \to \infty} \sum_{i=1}^{n} f(x_i^\ast) \, \Delta x$$ + +- $\int$ 符号是拉长的"S",代表"求和"(Sum)。$dx$ 提醒我们,我们是在沿 x 轴方向对无穷薄的切片求和。 + +- **不定积分**(或**原函数**)是一个函数 $F(x)$,其导数为 $f(x)$。我们写作: + +$$\int f(x)\, dx = F(x) + C$$ + +- $+ C$ 是**积分常数**。因为任何常数的导数都是零,所以存在无穷多个仅相差一个常数的原函数。例如,$\int 2x\, dx = x^2 + C$,因为 $x^2 + 7$ 或 $x^2 - 3$ 的导数仍然是 $2x$。 + +- **微积分基本定理**是连接微分与积分的桥梁。它包含两部分: + +- **第一部分**:如果 $F(x)$ 是 $f(x)$ 的一个原函数,那么定积分等于 $F$ 在端点处的值之差: + +$$\int_a^b f(x)\, dx = F(b) - F(a)$$ + +- 这非常实用。我们不再需要计算一个和的极限(这很困难),而是找到一个原函数并在两点处求值(这通常很简单)。 + +- **第二部分**:如果我们定义 $F(x) = \int_a^x f(t)\, dt$,那么 $F'(x) = f(x)$。微分与积分是互逆运算,它们相互抵消。 + +- 例如,计算 $\int_1^3 x^2\, dx$:$x^2$ 的原函数是 $\frac{x^3}{3}$。所以 $\int_1^3 x^2\, dx = \frac{27}{3} - \frac{1}{3} = \frac{26}{3} \approx 8.67$。 + +- 正如微分有运算法则一样,积分也有相应的逆向运算法则: + +| 函数 | 积分 | 条件 | +|---|---|---| +| $x^n$ | $\frac{x^{n+1}}{n+1} + C$ | $n \neq -1$ | +| $\frac{1}{x}$ | $\ln\|x\| + C$ | | +| $e^x$ | $e^x + C$ | | +| $a^x$ | $\frac{a^x}{\ln a} + C$ | | +| $\sin x$ | $-\cos x + C$ | | +| $\cos x$ | $\sin x + C$ | | +| $k$(常数) | $kx + C$ | | + +- **和/差法则**同样适用:$\int [f(x) \pm g(x)]\, dx = \int f(x)\, dx \pm \int g(x)\, dx$。常数可以提出来:$\int k\, f(x)\, dx = k \int f(x)\, dx$。 + +- 当一个函数太复杂而无法直接积分时,我们有简化它的技巧。 + +- **换元积分法(u 代换)**是链式法则的逆过程。如果发现一个复合函数 $f(g(x))$ 乘以 $g'(x)$,则令 $u = g(x)$,于是 $du = g'(x)\, dx$,积分得以简化。 + +- 例如:$\int 2x \cos(x^2)\, dx$。令 $u = x^2$,则 $du = 2x\, dx$。积分变为 $\int \cos(u)\, du = \sin(u) + C = \sin(x^2) + C$。 + +- **分部积分法**是乘积法则的逆过程。如果被积函数是两个函数的乘积: + +$$\int u\, dv = uv - \int v\, du$$ + +- 策略性地选择 $u$ 和 $dv$,使得剩下的积分 $\int v\, du$ 比原来的更简单。选择 $u$ 的常用助记法是 **LIATE**:对数函数(Logarithmic)、反三角函数(Inverse trig)、代数函数(Algebraic)、三角函数(Trigonometric)、指数函数(Exponential)(优先从靠前的类别中选择 $u$)。 + +- 例如:$\int x\, e^x\, dx$。令 $u = x$(代数函数)和 $dv = e^x\, dx$。则 $du = dx$,$v = e^x$。因此:$\int x\, e^x\, dx = x\, e^x - \int e^x\, dx = x\, e^x - e^x + C = e^x(x - 1) + C$。 + +- 在机器学习中,积分出现在概率论中(通过对密度函数积分来计算概率)、期望值中(连续分布上的加权平均),以及计算 ROC 曲线下的面积。虽然在实际中我们很少手动积分,但理解积分的含义有助于解释这些量。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 使用黎曼和,用不断增加数量的矩形来数值逼近 $\int_0^1 x^2\, dx$。与精确答案 $\frac{1}{3}$ 进行比较。 +```python +import jax.numpy as jnp + +for n in [10, 100, 1000, 10000]: + x = jnp.linspace(0, 1, n, endpoint=False) + dx = 1.0 / n + area = jnp.sum(x**2 * dx) + print(f"n={n:5d} approx: {area:.6f} exact: {1/3:.6f}") +``` + +2. 数值验证微积分基本定理。定义 $F(x) = \int_0^x t^2\, dt = \frac{x^3}{3}$,并验证其导数(通过 `jax.grad` 计算)等于 $x^2$。 +```python +import jax +import jax.numpy as jnp + +F = lambda x: x**3 / 3 +dF = jax.grad(F) + +for x in [0.5, 1.0, 2.0, 3.0]: + print(f"x={x:.1f} F'(x)={dF(x):.4f} x^2={x**2:.4f}") +``` + +3. 可视化 $f(x) = \sin(x)$ 从 $0$ 到 $\pi$ 的曲线下面积。使用 `plt.fill_between` 填充该区域,并用黎曼和数值计算面积。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +x = jnp.linspace(0, jnp.pi, 500) +y = jnp.sin(x) + +plt.plot(x, y, color="purple", linewidth=2) +plt.fill_between(x, y, alpha=0.2, color="purple") +plt.title(f"Area = {jnp.sum(jnp.sin(x) * (jnp.pi / 500)):.4f} (exact: 2.0)") +plt.show() +``` diff --git a/chapter 03: calculus/03. multivariate calculus.md b/chapter 03: calculus/03. multivariate calculus.md new file mode 100644 index 0000000..b82c56f --- /dev/null +++ b/chapter 03: calculus/03. multivariate calculus.md @@ -0,0 +1,219 @@ +# 多元微积分 + +*多元微积分将导数和积分扩展到多变量函数,这对于机器学习模型拥有数百万参数的情形至关重要。本章涵盖偏导数、梯度、雅可比矩阵、海森矩阵以及使反向传播成为可能的多变量链式法则。* + +- 到目前为止,我们的函数都只接受单个输入 $x$ 并产生单个输出 $f(x)$。但在机器学习中,我们几乎从不只处理一个变量。 + +- 考虑一个双变量函数,例如 $f(x, y) = x^2 + y^2$。它在三维空间中定义了一个曲面,形状像一个碗。我们想知道:如果我们在保持 $y$ 固定的同时稍微调整 $x$,$f$ 会如何变化?这就是**偏导数**。 + +- $f$ 对 $x$ 的**偏导数**,记作 $\frac{\partial f}{\partial x}$,将其他所有变量视为常数,然后对 $x$ 正常求导。 + +- 对于 $f(x, y) = x^2y + 3x - 2y$: + +$$\frac{\partial f}{\partial x} = 2xy + 3 \qquad \frac{\partial f}{\partial y} = x^2 - 2$$ + +- 计算 $\frac{\partial f}{\partial x}$ 时,我们将 $y$ 视为常数,因此 $x^2y$ 求导得 $2xy$,$3x$ 求导得 $3$,$-2y$ 求导得 $0$。 + +- 计算 $\frac{\partial f}{\partial y}$ 时,我们将 $x$ 视为常数,因此 $x^2y$ 求导得 $x^2$,$3x$ 求导得 $0$,$-2y$ 求导得 $-2$。 + +- 从几何上看,对 $x$ 求偏导数就像用一个平行于 $xz$ 平面的平面(在固定的 $y$ 值处)切割三维曲面,然后求所得曲线的斜率。 + +![偏导数:固定一个变量来切割曲面](../images/partial_derivative.svg) + +- **梯度**将所有偏导数收集到一个向量中: + +$$\nabla f = \left(\frac{\partial f}{\partial x_1}, \frac{\partial f}{\partial x_2}, \ldots, \frac{\partial f}{\partial x_n}\right)$$ + +- 对于 $f(x, y) = x^2 + y^2$:$\nabla f(x, y) = (2x, 2y)$。在点 $(1, 2)$ 处:$\nabla f(1, 2) = (2, 4)$。 + +- 梯度有两个关键性质: + + - **方向**:它指向上升最陡的方向。想象一位登山者在山上。他们所在位置的梯度指向正上方,沿着最陡的路径。 + + - **大小**:$\|\nabla f\|$ 给出了最陡方向上的变化率。梯度大意味着地形陡峭;梯度小意味着地形近乎平坦。 + +![梯度向量指向山坡上方,垂直于等高线](../images/gradient_contour.svg) + +- 由于梯度指向上坡,沿相反方向($-\nabla f$)移动就是下坡,走向更低的值。这个简单的想法是**梯度下降**的基础,我们将在后续章节中详细探讨这种优化技术。现在,关键要点是:梯度告诉你哪个方向是"上坡",以及攀登的陡峭程度。 + +- **方向导数**推广了偏导数。它不问"$f$ 沿 $x$ 轴如何变化?",而是问"$f$ 沿任意方向 $\mathbf{u}$ 如何变化?"它通过梯度与单位向量的点积来计算: + +$$D_{\mathbf{u}} f = \nabla f \cdot \mathbf{u}$$ + +- 对于 $f(x, y) = x^2 + y^2$ 在点 $(1, 2)$ 处,沿方向 $\mathbf{v} = (3, 4)$:首先归一化得到 $\mathbf{u} = (3/5, 4/5)$,然后 $D_{\mathbf{u}} f = (2, 4) \cdot (3/5, 4/5) = 6/5 + 16/5 = 22/5$。 + +- 偏导数是方向导数的特例,其中方向沿着坐标轴。如果方向导数在某个方向上为零,则函数在该点沿该方向是平坦的。 + +- **等高线**(或水平曲线)连接函数值相等的点。对于 $f(x, y) = x^2 + y^2$,等高线是以原点为中心的圆:对应不同 $c$ 值的 $x^2 + y^2 = c$。 + +- 等高线永不相交(一个点不可能有两个不同的函数值)。 + +- 梯度始终垂直于等高线,从低值指向高值。 + +- 等高线密集表示地形陡峭;等高线稀疏表示坡度平缓。 + +- 到目前为止,我们的函数都只产生单个输出。但许多函数会产生多个输出。函数 $\mathbf{F}: \mathbb{R}^n \to \mathbb{R}^m$ 接收 $n$ 个输入并产生 $m$ 个输出。**雅可比矩阵**组织了这样一个向量值函数的所有偏导数: + +```math +J = \begin{bmatrix} \frac{\partial f_1}{\partial x_1} & \cdots & \frac{\partial f_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial f_m}{\partial x_1} & \cdots & \frac{\partial f_m}{\partial x_n} \end{bmatrix} +``` + +- 雅可比矩阵的每一行是一个输出分量的梯度。对于一个有 3 个输入和 2 个输出的函数,雅可比矩阵是一个 $2 \times 3$ 矩阵。 + +- 雅可比矩阵将导数推广到向量值函数。 + +- 就像标量函数的导数告诉你每单位输入变化对应的输出变化量一样,雅可比矩阵告诉你每个输出相对于每个输入的变化情况。 + +- **雅可比行列式**衡量一个变换局部拉伸或压缩空间的程度。 + +- 如果行列式为 2,小区域的面积加倍。如果行列式为 0,该变换将空间压缩到更低维度(回想我们在矩阵章节中学到的:行列式为零意味着奇异变换,不可逆)。 + +- 当多个变换组合在一起(一个变换的输出作为下一个变换的输入)时,整体映射的雅可比矩阵是各个雅可比矩阵的乘积。我们将在后续章节中看到这个思想变得至关重要。 + +- 梯度捕获一阶信息(斜率),而**海森矩阵**捕获二阶信息(曲率)。 + +- 对于标量函数 $f(x_1, \ldots, x_n)$,海森矩阵是所有二阶偏导数的 $n \times n$ 矩阵: + +```math +H = \begin{bmatrix} \frac{\partial^2 f}{\partial x_1^2} & \frac{\partial^2 f}{\partial x_1 \partial x_2} & \cdots \\ \frac{\partial^2 f}{\partial x_2 \partial x_1} & \frac{\partial^2 f}{\partial x_2^2} & \cdots \\ \vdots & \vdots & \ddots \end{bmatrix} +``` + +- 对于 $f(x, y) = x^3 + 2xy^2 - y^3$,梯度为 $(3x^2 + 2y^2,\; 4xy - 3y^2)$,海森矩阵为: + +```math +H = \begin{bmatrix} 6x & 4y \\ 4y & 4x - 6y \end{bmatrix} +``` + +- 对角线元素($6x$ 和 $4x - 6y$)告诉你 $x$ 方向的斜率随 $x$ 移动如何变化,$y$ 方向同理。 + +- 非对角线元素($4y$)告诉你一个方向的斜率随另一个方向的移动如何变化。 + +- **克莱罗定理**保证:对于具有连续二阶导数的函数,混合偏导数相等:$\frac{\partial^2 f}{\partial x \partial y} = \frac{\partial^2 f}{\partial y \partial x}$。 + +- 这意味着海森矩阵是对称的,这(正如我们在矩阵章节中看到的)保证了实特征值和正交特征向量。 + +- 海森矩阵告诉我们临界点(梯度为零的点)附近函数的形状: + + - 如果 $H$ 是正定的(所有特征值为正),则该点是**局部极小值点**,曲面像碗一样向各个方向向上弯曲。 + - 如果 $H$ 是负定的(所有特征值为负),则该点是**局部极大值点**,曲面像倒扣的碗一样向各个方向向下弯曲。 + - 如果 $H$ 同时具有正负特征值,则该点是**鞍点**,曲面在某些方向上向上弯曲,在另一些方向上向下弯曲,就像山坳一样。 + +- **多变量链式法则**将链式法则扩展到多变量函数。如果 $z = f(x, y)$,其中 $x = g(t)$ 且 $y = h(t)$,则: + +$$\frac{dz}{dt} = \frac{\partial f}{\partial x}\frac{dx}{dt} + \frac{\partial f}{\partial y}\frac{dy}{dt}$$ + +- 从 $t$ 到 $z$ 的每条路径都贡献一项:沿该路径的偏导数乘以中间变量对 $t$ 的导数。 + +- 例如,如果 $z = x^2 y + 3x - y^2$,$x = \cos(t)$,$y = \sin(t)$: + +$$\frac{dz}{dt} = (2xy + 3)(-\sin t) + (x^2 - 2y)(\cos t)$$ + +- 除了手动计算导数,还有三种方法: + + - **数值微分**:用 $f'(x) \approx \frac{f(x+h) - f(x-h)}{2h}$(取很小的 $h$)来近似。简单但有噪声且不精确。 + - **符号微分**:通过代数地应用求导法则产生精确表达式。可能导致表达式呈指数级膨胀。 + - **自动微分(autodiff)**:跟踪运算链并高效地计算精确导数。JAX、PyTorch 和 TensorFlow 都使用这种方法。它能给出精确的数值(而非近似值),且不会产生臃肿的符号表达式。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 使用 `jax.grad` 计算函数 $f(x, y) = x^2 y + 3x - 2y$ 在点 $(1, 2)$ 处的梯度。由于 $f$ 接收向量输入,请使用带 `argnums` 参数的 `jax.grad`。 +```python +import jax +import jax.numpy as jnp + +def f(x, y): + return x**2 * y + 3*x - 2*y + +df_dx = jax.grad(f, argnums=0) +df_dy = jax.grad(f, argnums=1) + +x, y = 1.0, 2.0 +print(f"∂f/∂x = {df_dx(x, y):.4f} (期望: {2*x*y + 3:.4f})") +print(f"∂f/∂y = {df_dy(x, y):.4f} (期望: {x**2 - 2:.4f})") +``` + +2. 使用 `jax.jacobian` 计算向量值函数的雅可比矩阵,并与手动计算结果进行比较。 +```python +import jax +import jax.numpy as jnp + +def F(x): + return jnp.array([x[0]**2 + x[1], x[0] * x[1]**2]) + +J = jax.jacobian(F) +x = jnp.array([1.0, 2.0]) +print(f"在 (1,2) 处的雅可比矩阵:\n{J(x)}") +# 期望: [[2*x[0], 1], [x[1]**2, 2*x[0]*x[1]]] = [[2, 1], [4, 4]] +``` + +3. 使用 `jax.hessian` 计算 $f(x, y) = x^3 + 2xy^2 - y^3$ 的海森矩阵,并验证其对称性。 +```python +import jax +import jax.numpy as jnp + +def f(xy): + x, y = xy[0], xy[1] + return x**3 + 2*x*y**2 - y**3 + +H = jax.hessian(f) +point = jnp.array([1.0, 2.0]) +hess = H(point) +print(f"海森矩阵:\n{hess}") +print(f"是否对称: {jnp.allclose(hess, hess.T)}") +# 期望: [[6x, 4y], [4y, 4x-6y]] = [[6, 8], [8, -8]] +``` + +4. 从头构建一个极简的自动微分引擎。 + - 每个 `Var` 追踪其值以及如何通过链式法则反向传播梯度。 + - 尝试扩展更多运算(除法、幂运算等)。 + - 这是 JAX、PyTorch 和 Numpy 的设计基础。 +```python +class Var: + def __init__(self, val, children=(), backward_fn=None): + self.val = val + self.grad = 0.0 + self.children = children + self.backward_fn = backward_fn + + def __add__(self, other): + out = Var(self.val + other.val, children=(self, other)) + def _backward(): + self.grad += out.grad # d(a+b)/da = 1 + other.grad += out.grad # d(a+b)/db = 1 + out.backward_fn = _backward + return out + + def __mul__(self, other): + out = Var(self.val * other.val, children=(self, other)) + def _backward(): + self.grad += other.val * out.grad # d(a*b)/da = b + other.grad += self.val * out.grad # d(a*b)/db = a + out.backward_fn = _backward + return out + + def backward(self): + # 拓扑排序,然后传播梯度 + # 我们将在数据结构与算法章节中详细介绍 + order, visited = [], set() + def topo(v): + if v not in visited: + visited.add(v) + for c in v.children: + topo(c) + order.append(v) + topo(self) + self.grad = 1.0 + for v in reversed(order): + if v.backward_fn: + v.backward_fn() + +# f(x, y) = x*x*y + x 在 (3, 2) 处 +x = Var(3.0) +y = Var(2.0) +f = x * x * y + x # = 3*3*2 + 3 = 21 + +f.backward() +print(f"f = {f.val}") # 21.0 +print(f"df/dx = {x.grad}") # 2*x*y + 1 = 13.0 +print(f"df/dy = {y.grad}") # x*x = 9.0 +``` diff --git a/chapter 03: calculus/04. function approximation.md b/chapter 03: calculus/04. function approximation.md new file mode 100644 index 0000000..b046214 --- /dev/null +++ b/chapter 03: calculus/04. function approximation.md @@ -0,0 +1,143 @@ +# 函数逼近 + +*函数逼近用足够接近原函数的简单函数来替代复杂函数。本文涵盖线性化、泰勒级数、多项式逼近、傅里叶级数以及通用逼近定理——这些是神经网络能够学习任意映射的理论基础。* + +- 我们遇到的许多函数都过于复杂,无法直接处理。例如,在纸上计算 $e^{0.1}$、预测卫星轨迹等,都涉及没有简单封闭形式答案的函数。 + +- **函数逼近**用更简单的函数来替代复杂函数,使其在关心区域内"足够接近"原函数。 + +- 最自然的逼近是多项式。多项式只是 $x$ 的幂次与系数的和,易于求值、微分和积分。 + +- 但为什么多项式作为逼近器如此有效?看看 $x$ 的每个幂次贡献了什么。 + + - 常数项 $a_0$ 设定基准值。 + - $a_1 x$ 项增加斜率。 + - $a_2 x^2$ 项增加曲率。 + - 更高的幂次则捕捉函数形状的更多细节。 + +![每个多项式项为逼近增加一层细节](../images/polynomial_buildup.svg) + +- 通过选择合适的系数,我们可以逐次匹配函数在某一点的值、斜率、曲率以及高阶行为。 + +- 当项数足够时,多项式几乎可以模仿任何光滑函数。 + +- 问题在于:如何找到正确的系数? + +- **线性化**是最简单的逼近。在点 $x = a$ 附近,我们用函数的切线来代替它: + +$$L(x) = f(a) + f'(a)(x - a)$$ + +- 这是一阶**泰勒逼近**。它的思路是:从已知值 $f(a)$ 出发,然后加上斜率乘以距离 $a$ 的偏移量。 + +- 例如,在 $x = 0$ 处对 $\sin(x)$ 线性化:$f(0) = 0$,$f'(0) = \cos(0) = 1$,所以 $L(x) = x$。在零附近,$\sin(x) \approx x$。试试看:$\sin(0.1) = 0.0998\ldots \approx 0.1$。 + +- 但线性化仅在非常接近 $a$ 的地方有效。离得稍远,逼近就失效了。为了做得更好,我们需要引入高阶项。 + +- **泰勒级数**将函数表示为无穷多个多项式项的和,每一项都捕捉到函数在点 $a$ 附近行为的更精细细节: + +$$f(x) = \sum_{n=0}^{\infty} \frac{f^{(n)}(a)}{n!}(x - a)^n = f(a) + f'(a)(x-a) + \frac{f''(a)}{2!}(x-a)^2 + \frac{f'''(a)}{3!}(x-a)^3 + \cdots$$ + +![泰勒级数:项数越多,逼近越精确](../images/taylor_approximation.svg) + +- 每一项依次增加一个修正项。第一项匹配函数值,第二项匹配斜率,第三项匹配曲率,依此类推。包含的项越多,逼近精确的区域就越大。 + +- 分母中的 $n!$ 并非随意选择。当你对 $(x - a)^n$ 恰好微分 $n$ 次时,会得到 $n!$。阶乘抵消了这个结果,从而确保泰勒多项式的 $n$ 阶导数在 $x = a$ 处与原函数的 $n$ 阶导数相等。 + +- **麦克劳林级数**就是中心在 $a = 0$ 的泰勒级数: + +$$f(x) = \sum_{n=0}^{\infty} \frac{f^{(n)}(0)}{n!} x^n$$ + +- 一些著名的麦克劳林级数: + +$$e^x = 1 + x + \frac{x^2}{2!} + \frac{x^3}{3!} + \cdots$$ + +$$\sin x = x - \frac{x^3}{3!} + \frac{x^5}{5!} - \frac{x^7}{7!} + \cdots$$ + +$$\cos x = 1 - \frac{x^2}{2!} + \frac{x^4}{4!} - \frac{x^6}{6!} + \cdots$$ + +- 注意 $\sin x$ 只有奇次幂(它是奇函数),而 $\cos x$ 只有偶次幂(它是偶函数)。交替的符号使得逼近在真实值周围振荡,从两侧同时收敛。 + +- 让我们用四项来逼近 $e^{0.5}$:$1 + 0.5 + \frac{0.25}{2} + \frac{0.125}{6} = 1 + 0.5 + 0.125 + 0.02083 \approx 1.6458$。真实值为 $1.6487\ldots$,因此四项已经给出了三个正确的小数位。 + +- 并非所有泰勒级数都处处收敛。**收敛半径**告诉我们,在距离中心 $a$ 多远的范围内,级数给出有效的结果。在此半径内,通过增加项数,多项式逼近可以达到任意所需的精度。超出此半径,级数发散。 + +- **幂级数**的一般形式是:$\sum_{n=0}^{\infty} a_n (x - c)^n$。泰勒级数是系数由导数确定的幂级数。其他幂级数可能由其他规则定义。**比值判别法**用于判定收敛性:计算 $\lim_{n \to \infty} \left|\frac{a_{n+1}}{a_n}\right|$。如果该极限为 $L$,则收敛半径为 $R = 1/L$。 + +- 将泰勒级数截断到 $n$ 项时,会产生误差。**拉格朗日余项**给出了这个误差的界限: + +$$R_n(x) = \frac{f^{(n+1)}(c)}{(n+1)!}(x-a)^{n+1}$$ + +- 这里 $c$ 是 $a$ 和 $x$ 之间的某个未知点。我们无法确切知道 $c$,但通常可以限定 $|f^{(n+1)}(c)|$ 来得到最坏情况下的误差估计。分母中的 $(n+1)!$ 增长极快,因此随着项数增加,误差迅速减小(对于收敛半径内的函数而言)。 + +- 对于多变量函数,泰勒展开包含混合偏导数。$f(\mathbf{x})$ 在点 $\mathbf{a}$ 附近的二阶逼近为: + +$$f(\mathbf{x}) \approx f(\mathbf{a}) + \nabla f(\mathbf{a})^T (\mathbf{x} - \mathbf{a}) + \frac{1}{2} (\mathbf{x} - \mathbf{a})^T H(\mathbf{a}) (\mathbf{x} - \mathbf{a})$$ + +- 第一项是函数值,第二项使用梯度(向量,如我们在多元微积分中看到的),第三项使用海森矩阵(捕捉曲率)。这直接将我们的矩阵章节与微积分联系起来:海森矩阵是一个由二阶导数组成的矩阵,描述了函数表面的形状。 + +- 这种多变量二阶逼近是牛顿法和其他二阶优化技术的基础,我们将在下一个文件中看到。 + +- 除了多项式,还有其他值得了解的逼近方法: + + - **样条插值**:不用单个高次多项式,而是将多个低次多项式光滑拼接在一起。这避免了高次多项式可能产生的剧烈振荡。 + - **傅里叶级数**:将周期函数逼近为正弦和余弦的和。在信号处理和音频中至关重要。 + - **神经网络**:通用函数逼近器。只要有足够的神经元,它们可以任意精度逼近任何连续函数。这就是深度学习的理论基础。 + +- 如果一个函数具有使逼近可靠的性质——连续性(无跳跃)、可微性(无尖角)、光滑性(所有阶导数都存在)和有界性(输出保持有限),我们就称其为"行为良好"的函数。 + +- 多项式、指数函数和三角函数都属于行为良好的函数。函数行为越好,获得良好逼近所需的泰勒项数就越少。 + +## 编程练习(使用 CoLab 或 Jupyter Notebook) + +1. 用递增数量的泰勒项逼近 $e^x$,并可视化逼近效果如何改善。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +x = jnp.linspace(-2, 3, 300) +plt.plot(x, jnp.exp(x), "k-", linewidth=2, label="eˣ (精确值)") + +colors = ["#e74c3c", "#3498db", "#27ae60", "#9b59b6"] +for n, color in zip([1, 2, 4, 8], colors): + approx = sum(x**k / jnp.array(float(jnp.prod(jnp.arange(1, k+1)) if k > 0 else 1)) + for k in range(n+1)) + plt.plot(x, approx, color=color, linestyle="--", label=f"{n} 项") + +plt.ylim(-2, 15) +plt.legend() +plt.title("eˣ 的泰勒逼近") +plt.show() +``` + +2. 计算拉格朗日余项,以限定用不同数量的泰勒项逼近 $\sin(1)$ 时的误差。 +```python +import jax.numpy as jnp + +x = 1.0 +exact = jnp.sin(x) + +taylor = 0.0 +for n in range(8): + sign = (-1)**n + factorial = float(jnp.prod(jnp.arange(1, 2*n+2))) + taylor += sign * x**(2*n+1) / factorial + error = abs(exact - taylor) + bound = x**(2*n+3) / float(jnp.prod(jnp.arange(1, 2*n+4))) + print(f"项数={n+1} 近似值={taylor:.10f} 误差={error:.2e} 界限={bound:.2e}") +``` + +3. 比较在 $x=0$ 附近 $\cos(x)$ 的线性化逼近与二次泰勒逼近。在同一张图上绘制两个逼近和真实函数,观察各自精确的范围。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +x = jnp.linspace(-3, 3, 300) +plt.plot(x, jnp.cos(x), "k-", linewidth=2, label="cos(x)") +plt.plot(x, jnp.ones_like(x), "--", color="#e74c3c", label="线性: 1") +plt.plot(x, 1 - x**2/2, "--", color="#3498db", label="二次: 1 - x²/2") +plt.plot(x, 1 - x**2/2 + x**4/24, "--", color="#27ae60", label="四阶") +plt.ylim(-2, 2) +plt.legend() +plt.title("cos(x) 的泰勒逼近") +plt.show() +``` diff --git a/chapter 03: calculus/05. optimisation.md b/chapter 03: calculus/05. optimisation.md new file mode 100644 index 0000000..6732dd7 --- /dev/null +++ b/chapter 03: calculus/05. optimisation.md @@ -0,0 +1,145 @@ +# 优化 + +*优化是模型训练的数学核心——寻找使损失函数最小的参数。本文涵盖驻点、凸性、梯度下降、牛顿法、带拉格朗日乘数的约束优化,以及驱动现代深度学习的主流优化器(SGD、Adam)。* + +- 训练神经网络、拟合回归线、调优超参数:几乎所有机器学习算法的核心都是一个**优化**问题。 + +- 我们有一个函数(损失函数、代价函数、目标函数),希望找到使其尽可能小(或大)的输入。 + +- 在优化之前,我们需要理解函数的**零点**(或根)。$f(x)$ 的零点是指满足 $f(x) = 0$ 的 $x$ 值。从图形上看,这些点就是与 x 轴的交点。 + +- 例如,$f(x) = x^2 - 3x + 2 = (x-1)(x-2)$ 的零点在 $x = 1$ 和 $x = 2$ 处。在两个零点之间,函数为负($f(1.5) = -0.25$);在零点之外,函数为正。零点将数轴分割成若干个区域,在每个区域中函数保持相同符号。 + +- 零点的**重数**是指对应因式出现的次数。 + +- 在单零点(重数为 1)处,图像穿过 x 轴。在二重零点(重数为 2)处,图像接触 x 轴但反弹回去而不穿过,在该点处看起来是"平坦"的。 + +- 寻找零点之所以重要,是因为导数 $f'(x)$ 的零点正是 $f(x)$ 的**驻点**——即极大值或极小值的候选点。 + +- 在极大值或极小值处,切线是水平的(斜率为 0),因此 $f'(x) = 0$。 + +![驻点:当导数为零时,函数出现波峰、波谷或鞍点](../images/critical_points.svg) + +- 但并非每个驻点都是极大值或极小值。$f'(x) = 0$ 的点也可能是**拐点**(如 $f(x) = x^3$ 在 $x = 0$ 处),函数在该点暂时变平但并未改变方向。 + +- **二阶导数检验**可以解决这个问题。在驻点 $x = c$(即 $f'(c) = 0$)处: + + - 若 $f''(c) > 0$:曲线向下凸(碗状),因此 $c$ 是**局部极小值**。 + - 若 $f''(c) < 0$:曲线向上凸(山丘状),因此 $c$ 是**局部极大值**。 + - 若 $f''(c) = 0$:检验无效,需要使用更高阶导数或其他方法。 + +- 例如,$f(x) = x^3 - 3x$。导数为 $f'(x) = 3x^2 - 3 = 3(x-1)(x+1)$,因此驻点在 $x = -1$ 和 $x = 1$ 处。二阶导数为 $f''(x) = 6x$。在 $x = -1$ 处:$f''(-1) = -6 < 0$(局部极大值)。在 $x = 1$ 处:$f''(1) = 6 > 0$(局部极小值)。 + +- 如果连接函数图像上任意两点的线段位于图像之上(或与之重合),则该函数是**凸的**。可以想象成一个碗形,处处向上弯曲。数学上,若对所有 $x$ 有 $f''(x) \geq 0$,则 $f$ 是凸函数。 + +![凸函数具有唯一的全局最小值;非凸函数可能有多个局部最小值](../images/convex_nonconvex.svg) + +- 凸性的强大之处在于凸函数有一个卓越的性质:每个局部极小值同时也是**全局最小值**。不存在会让人陷入的欺骗性局部低谷。如果你把一个球滚入凸碗中,它总是会到达底部。 + +- 若 $-f$ 是凸的,则函数是**凹的**(向下弯曲)。函数从凹性过渡到凸性的点称为**拐点**,出现在 $f''(x) = 0$ 处。 + +- **牛顿法**利用切线寻找函数的零点(进而也可用于寻找其导数的驻点)。从初始猜测 $x_0$ 出发,迭代更新: + +$$x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)}$$ + +![牛顿法:沿切线方向逼近根的更好近似值](../images/newtons_method.svg) + +- 其思想是:在 $x_n$ 处画出切线,找到它与 x 轴的交点,该交点即为 $x_{n+1}$。对于性质良好且初始点选取恰当的函数,牛顿法收敛非常快(二次收敛,即每步正确位数大致翻倍)。 + +- 例如,求 $\sqrt{5}$(即 $f(x) = x^2 - 5$ 的零点):$f'(x) = 2x$,因此 $x_{n+1} = x_n - \frac{x_n^2 - 5}{2x_n}$。从 $x_0 = 2$ 开始:$x_1 = 2.25$,$x_2 = 2.2361\ldots$,已精确到小数点后四位。 + +- 如果初始猜测离根太远、根附近 $f'(x) = 0$,或函数在附近有拐点,牛顿法可能会失败。此外,它还需要计算导数,这可能代价高昂。 + +- 对于优化(寻找极小值而非零点),我们将牛顿法应用于 $f'(x) = 0$,得到更新公式: + +$$x_{n+1} = x_n - \frac{f'(x_n)}{f''(x_n)}$$ + +- 在多维情形下,这变为 $\mathbf{x}_{n+1} = \mathbf{x}_n - H^{-1} \nabla f(\mathbf{x}_n)$,其中 $H$ 是 Hessian 矩阵。这正是上一节中二阶泰勒近似的实际应用:将函数近似为二次型,跳到该二次型的极小值点,然后重复。 + +- **拉格朗日乘数**用于求解**约束优化**:在约束条件 $g(x, y) = c$ 下求 $f(x, y)$ 的最优值。我们不是在 $\mathbb{R}^n$ 中全域搜索,而是限制在约束条件成立的集合(一条曲线或曲面)上。 + +- 关键见解是几何层面的:在约束最优解处,$f$ 的梯度必须与 $g$ 的梯度平行。如果它们不平行,我们可以沿着约束条件朝某个方向移动,从而继续改进 $f$ 的值,这意味着还没有达到最优。 + +- 我们引入一个新变量 $\lambda$(拉格朗日乘数),定义**拉格朗日函数**: + +$$\mathcal{L}(x, y, \lambda) = f(x, y) - \lambda(g(x, y) - c)$$ + +- 令所有偏导数为零,得到一个方程组,其解即为约束最优解: + +$$\frac{\partial \mathcal{L}}{\partial x} = 0, \quad \frac{\partial \mathcal{L}}{\partial y} = 0, \quad \frac{\partial \mathcal{L}}{\partial \lambda} = 0$$ + +![拉格朗日乘数:在最优解处,f 和 g 的梯度平行](../images/lagrange_multiplier.svg) + +- 例如,在 $x^2 + y^2 = 1$ 的约束下最大化 $f(x,y) = x^2 y$。拉格朗日函数为 $\mathcal{L} = x^2 y - \lambda(x^2 + y^2 - 1)$。求偏导: + +$$2xy - 2\lambda x = 0, \quad x^2 - 2\lambda y = 0, \quad x^2 + y^2 = 1$$ + +- 由第一个方程(假设 $x \neq 0$):$\lambda = y$。代入第二个方程:$x^2 = 2y^2$。结合约束条件:$2y^2 + y^2 = 1$,得 $y = \frac{1}{\sqrt{3}}$。最大值为 $f = \frac{2}{3\sqrt{3}}$。 + +- 对于不等式约束($g(x,y) \leq c$ 而非 $= c$),**Karush-Kuhn-Tucker(KKT)条件**推广了拉格朗日乘数法。约束要么是激活的(有效约束,按等式处理),要么是非激活的(解在内部,约束无关紧要)。 + +- 在实践中,我们很少手工进行优化。以下是主要的算法家族: + + - **一阶方法**(仅使用梯度):梯度下降、随机梯度下降(SGD)、Adam。这些方法每步计算成本低,但收敛可能较慢,尤其是在病态问题上。 + + - **二阶方法**(使用梯度和 Hessian 矩阵):牛顿法收敛快,但计算和求逆 Hessian 矩阵代价高昂(对于 $n$ 个参数为 $O(n^3)$)。**拟牛顿法**(如 BFGS 和 L-BFGS)仅利用梯度信息近似 Hessian 矩阵,比一阶方法收敛更快,又无需承担完全的二阶方法计算成本。 + + - **共轭梯度法**:适用于大型稀疏系统,仅需矩阵-向量乘积,无需存储完整的 Hessian 矩阵。 + + - **高斯-牛顿法**和**莱文贝格-马夸尔特法**:专门用于最小二乘问题(在回归中常见),通过 Jacobian 矩阵近似 Hessian 矩阵。 + + - **自然梯度下降**:利用 Fisher 信息矩阵考虑参数空间的几何结构,对概率模型可能更有效。 + +- 优化器的选择取决于具体问题。对于深度学习,一阶方法(尤其是 Adam)占主导地位,因为参数量巨大(数百万到数十亿),计算 Hessian 矩阵不切实际。对于目标函数光滑的小规模问题,二阶方法可能快得多。 + +## 编程练习(在 CoLab 或 notebook 中完成) + +1. 实现牛顿法求 $\sqrt{7}$(即 $f(x) = x^2 - 7$ 的零点)。观察其快速收敛。 +```python +import jax.numpy as jnp + +f = lambda x: x**2 - 7 +df = lambda x: 2*x + +x = 3.0 # 初始猜测 +for i in range(6): + x = x - f(x) / df(x) + print(f"step {i+1}: x = {x:.10f} (error: {abs(x - jnp.sqrt(7.0)):.2e})") +``` + +2. 使用梯度下降最小化 $f(x, y) = (x - 3)^2 + (y + 1)^2$。最小值在 $(3, -1)$ 处。尝试不同的学习率。 +```python +import jax +import jax.numpy as jnp + +def f(params): + x, y = params + return (x - 3)**2 + (y + 1)**2 + +grad_f = jax.grad(f) +params = jnp.array([0.0, 0.0]) +lr = 0.1 + +for i in range(20): + g = grad_f(params) + params = params - lr * g + if i % 5 == 0 or i == 19: + print(f"step {i:2d}: ({params[0]:.4f}, {params[1]:.4f}) loss={f(params):.6f}") +``` + +3. 数值求解约束优化问题。在 $x + y = 10$ 的约束下最大化 $f(x,y) = xy$,通过参数化 $y = 10 - x$ 并求单变量函数的最优值。 +```python +import jax +import jax.numpy as jnp + +# 代入约束条件:y = 10 - x,所以 f = x(10 - x) = 10x - x² +f = lambda x: x * (10 - x) +df = jax.grad(f) + +# 梯度上升(我们要求最大值,所以加上梯度) +x = 1.0 +lr = 0.1 +for i in range(20): + x = x + lr * df(x) +print(f"x={x:.4f}, y={10-x:.4f}, f={f(x):.4f}") # 应为 x=5, y=5, f=25 +``` diff --git a/chapter 04: statistics/01. fundamentals.md b/chapter 04: statistics/01. fundamentals.md new file mode 100644 index 0000000..f1ed0af --- /dev/null +++ b/chapter 04: statistics/01. fundamentals.md @@ -0,0 +1,186 @@ +# 统计学基础 + +*统计学提供了描述数据和量化不确定性的语言。本节涵盖分布、随机变量、PMF、PDF、CDF、期望、方差、矩以及中心极限定理——这些概念支撑着每一个机器学习评估指标和损失函数。* + +- 统计学是从数据中学习的科学。你收集观测值,对其进行汇总,并得出结论——通常针对那些无法直接测量的事物。 + +- 假设你想知道某个国家所有成年人的平均身高。你不可能测量每一个人,因此你测量一个**样本**,并利用统计学对整个**总体**做出有根据的推测。 + +- 统计学有两个主要分支: + - **描述性统计**:对已有数据进行汇总(平均值、图表、表格) + - **推断性统计**:利用样本对更大群体做出推断 + +- 统计学的基本构件是**分布**——一种描述数值如何分布的方式。其他一切——平均值、检验、预测——都源于对分布的理解。 + +- **频率分布**统计数据中每个值(或值区间)出现的次数。想象一下把考试成绩分到不同的区间,然后统计每个区间中有多少学生。结果就是直方图。 + +- **概率分布**用概率代替原始计数。它不说"12 名学生的分数在 70 到 80 之间",而是说"分数在 70 到 80 之间的概率为 0.24"。当数据连续时,直方图的柱状会变成一条平滑曲线。 + +![频率分布直方图与概率分布平滑曲线对比](../images/distribution_types.svg) + +- 左侧的直方图基于你实际收集的数据构建。右侧的平滑曲线是一个数学模型,描述了数据背后的模式。一个是经验性的,另一个是理论性的。 + +- 为了从数学上处理分布,我们需要一种将结果赋予数值的方法。这正是**随机变量**所做的。 + +- 随机变量是一个将每次试验的结果映射到实数的函数。抛一枚硬币:结果是"正面"或"反面",但随机变量 $X$ 将其转换为 $X(正面) = 1$ 和 $X(反面) = 0$。现在我们就可以进行算术运算了。 + +![随机变量将结果(硬币、骰子)映射到数轴](../images/random_variable.svg) + +- **离散**随机变量取值为可数集:10 次抛掷中的正面次数、骰子的点数、一小时内收到的电子邮件数量。 + +- **连续**随机变量可以在一个区间内取任意值:你的精确身高、下一班公交车到达的时间、中午的温度。 + +- 这种区别很重要,因为它改变了我们计算概率的方式。对于离散变量,我们求和。对于连续变量,我们积分(回顾第 3 章的积分内容)。 + +- 对于离散随机变量,**概率质量函数(PMF)**给出每个具体值的概率: + +$$P(X = x) = p(x), \quad \text{其中 } \sum_{x} p(x) = 1$$ + +- 对于连续随机变量,**概率密度函数(PDF)**给出落在某个区间内的概率。任何单个精确值的概率为零;只有区间才具有正概率: + +$$P(a \le X \le b) = \int_a^b f(x)\, dx, \quad \text{其中 } \int_{-\infty}^{\infty} f(x)\, dx = 1$$ + +- 既然我们可以将结果赋予数值,最自然的问题就是:平均而言我们期望得到什么值? + +- **期望**(或期望值)是所有可能值的加权平均值,权重即为概率。可以将其视为分布的"重心"。 + +- 如果你多次掷一个公平的骰子,你的平均掷点数会收敛到 3.5。这就是期望值,尽管你实际上永远掷不出 3.5。 + +- 对于离散随机变量: + +$$E[X] = \sum_{x} x \cdot p(x)$$ + +- 对于连续随机变量(使用第 3 章的积分): + +$$E[X] = \int_{-\infty}^{\infty} x \cdot f(x)\, dx$$ + +- 示例:一个公平的六面骰子,对于 $x = 1, 2, 3, 4, 5, 6$,有 $p(x) = 1/6$。 + +$$E[X] = 1 \cdot \tfrac{1}{6} + 2 \cdot \tfrac{1}{6} + 3 \cdot \tfrac{1}{6} + 4 \cdot \tfrac{1}{6} + 5 \cdot \tfrac{1}{6} + 6 \cdot \tfrac{1}{6} = \frac{21}{6} = 3.5$$ + +- 期望具有线性性质,即 $E[aX + b] = aE[X] + b$。这一性质极其有用,在机器学习损失函数中频繁出现。 + +- 期望告诉我们中心位置,但完全没有说明数值的分散程度。为了描述分布的完整形状,我们需要**矩**。 + +- 矩是 $X$ 的某次幂的期望。第 $k$ 阶**原点矩**为: + +$$\mu_k' = E[X^k]$$ + +- 一阶原点矩($k = 1$)就是均值:$\mu_1' = E[X] = \mu$。 + +- 原点矩是从零点开始度量的。通常我们更关心相对于均值的偏差。第 $k$ 阶**中心矩**将测量中心化: + +$$\mu_k = E[(X - \mu)^k]$$ + +- 一阶中心矩始终为零(均值上下方的偏差相互抵消)。二阶中心矩就是**方差**。 + +- 为了比较不同尺度上的分布,我们通过除以标准差 $\sigma$ 的适当幂次来进行**标准化**: + +$$\tilde{\mu}_k = \frac{\mu_k}{\sigma^k}$$ + +- 每个矩捕捉分布形状的不同方面: + +![钟形曲线标注了每个矩所捕捉的特征:均值(中心)、方差(分散程度)、偏度(不对称性)、峰度(尾部重量)](../images/moments_shape.svg) + +- **1 阶矩(均值)**:分布的中心位置。平衡点。 +- **2 阶矩(方差)**:数值围绕均值的分散程度。方差越大,分布越宽。 +- **3 阶矩(偏度)**:分布向左还是向右倾斜。偏度为零表示对称。 +- **4 阶矩(峰度)**:尾部的重量。峰度越高,极端异常值越多。 + +- 让我们对具体数据集 $X = \{2, 4, 4, 4, 5, 5, 7, 9\}$ 计算全部四个矩。 + +- **步骤 1:均值**(一阶原点矩) + +$$\mu = \frac{2 + 4 + 4 + 4 + 5 + 5 + 7 + 9}{8} = \frac{40}{8} = 5$$ + +- **步骤 2:方差**(二阶中心矩)。从每个值中减去均值,平方,然后取平均: + +$$\sigma^2 = \frac{(2{-}5)^2 + (4{-}5)^2 + (4{-}5)^2 + (4{-}5)^2 + (5{-}5)^2 + (5{-}5)^2 + (7{-}5)^2 + (9{-}5)^2}{8}$$ + +$$= \frac{9 + 1 + 1 + 1 + 0 + 0 + 4 + 16}{8} = \frac{32}{8} = 4$$ + +- **标准差**为 $\sigma = \sqrt{4} = 2$。 + +- **步骤 3:偏度**(标准化三阶中心矩)。偏差取三次方,求平均,再除以 $\sigma^3$: + +$$\tilde{\mu}_3 = \frac{1}{8} \cdot \frac{(-3)^3 + (-1)^3 + (-1)^3 + (-1)^3 + 0^3 + 0^3 + 2^3 + 4^3}{2^3}$$ + +$$= \frac{1}{8} \cdot \frac{-27 -1 -1 -1 + 0 + 0 + 8 + 64}{8} = \frac{42}{64} = 0.656$$ + +- 正偏度表示右尾更长,这很合理,因为 9 远高于均值。 + +- **步骤 4:峰度**(标准化四阶中心矩)。偏差取四次方: + +$$\tilde{\mu}_4 = \frac{1}{8} \cdot \frac{(-3)^4 + (-1)^4 + (-1)^4 + (-1)^4 + 0^4 + 0^4 + 2^4 + 4^4}{2^4}$$ + +$$= \frac{1}{8} \cdot \frac{81 + 1 + 1 + 1 + 0 + 0 + 16 + 256}{16} = \frac{356}{128} = 2.781$$ + +- 正态分布的峰度为 3(称为"常峰态")。我们的 2.781 很接近,表明尾部大致呈正态。大于 3 的值("尖峰态")表示尾部更重;小于 3("低峰态")表示尾部更轻。某些公式会报告**超值峰度**(减去 3),因此我们的超值峰度为 $-0.219$。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 计算一个加载骰子的期望值,其中面 6 的概率为 0.3,其余面均分剩余概率。通过模拟 100,000 次投掷进行验证。 +```python +import jax +import jax.numpy as jnp + +# 加载骰子:面 6 的 p=0.3,其余面均分 0.7 +probs = jnp.array([0.14, 0.14, 0.14, 0.14, 0.14, 0.30]) +faces = jnp.array([1, 2, 3, 4, 5, 6]) + +# 解析法计算期望值 +ev = jnp.sum(faces * probs) +print(f"期望值(公式法): {ev:.4f}") + +# 模拟 +key = jax.random.PRNGKey(42) +rolls = jax.random.choice(key, faces, shape=(100_000,), p=probs) +print(f"期望值(模拟法): {rolls.mean():.4f}") +``` + +2. 计算示例数据集的所有四个矩(均值、方差、偏度、峰度),然后修改数据并观察每个矩如何变化。 +```python +import jax.numpy as jnp + +x = jnp.array([2, 4, 4, 4, 5, 5, 7, 9], dtype=jnp.float32) + +mean = jnp.mean(x) +variance = jnp.mean((x - mean) ** 2) +std = jnp.sqrt(variance) +skewness = jnp.mean(((x - mean) / std) ** 3) +kurtosis = jnp.mean(((x - mean) / std) ** 4) + +print(f"均值: {mean:.3f}") +print(f"方差: {variance:.3f}") +print(f"标准差: {std:.3f}") +print(f"偏度: {skewness:.3f}") +print(f"峰度: {kurtosis:.3f}") +print(f"超值峰度: {kurtosis - 3:.3f}") +``` + +3. 并排可视化公平骰子的 PMF 和 CDF。尝试修改概率以观察形状如何变化。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +faces = jnp.array([1, 2, 3, 4, 5, 6]) +pmf = jnp.ones(6) / 6 # 公平骰子;试试修改这些值! +cdf = jnp.cumsum(pmf) + +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4)) + +ax1.bar(faces, pmf, color="#3498db", alpha=0.8) +ax1.set_title("PMF") +ax1.set_xlabel("面值") +ax1.set_ylabel("P(X = x)") +ax1.set_ylim(0, 0.5) + +ax2.step(faces, cdf, where="mid", color="#e74c3c", linewidth=2) +ax2.set_title("CDF") +ax2.set_xlabel("面值") +ax2.set_ylabel("P(X ≤ x)") +ax2.set_ylim(0, 1.1) + +plt.tight_layout() +plt.show() +``` diff --git a/chapter 04: statistics/02. measures.md b/chapter 04: statistics/02. measures.md new file mode 100644 index 0000000..7f65a56 --- /dev/null +++ b/chapter 04: statistics/02. measures.md @@ -0,0 +1,178 @@ +# 统计量 + +*统计量用单个数值概括数据,捕捉其离散程度、位置、形状和关联。本节涵盖方差、标准差、四分位数、偏度、峰度、协方差、相关和 z 分数——这是探索性数据分析和机器学习特征工程的基础工具集。* + +- 在上一节中,我们介绍了矩作为一组概括性统计量家族。在此,我们展开讨论从矩中衍生出的实用工具:度量离散程度、位置、形状和关联的统计量。 + +- **离散程度**回答了这样一个问题:数据的分布有多分散?两个班级的平均考试成绩可能相同,但其分散程度却可能大相径庭。 + +![均值相同但离散程度不同的两个分布](../images/variance_spread.svg) + +- 窄(蓝色)分布的方差较小:大部分数值紧密聚集在均值周围。宽(红色)分布的方差较大:数值散布得更远。 + +- **方差**是距均值距离的平方的平均值。取平方是为了避免正负偏差相互抵消。 + +$$\sigma^2 = \frac{1}{N} \sum_{i=1}^{N} (x_i - \mu)^2$$ + +- 当处理样本(而非整个总体)时,我们用 $N - 1$ 而不是 $N$ 来除。这种修正(称为**贝塞尔校正**)是因为样本往往会低估真实的变异性: + +$$s^2 = \frac{1}{N-1} \sum_{i=1}^{N} (x_i - \bar{x})^2$$ + +- **标准差**是方差的平方根:$\sigma = \sqrt{\sigma^2}$。它将度量单位恢复为原始单位。如果数据的单位是厘米,方差的单位是 cm$^2$,而标准差的单位又回到了 cm。 + +- **平均绝对偏差(MAD)**是一个更简单的替代方案。它不取平方,而是取每个偏差的绝对值: + +$$\text{MAD} = \frac{1}{N} \sum_{i=1}^{N} |x_i - \mu|$$ + +- MAD 对方差而言对异常值更稳健,因为它不会通过平方来放大大的偏差。然而,方差在数学上更便利(在证明和机器学习优化中更容易分解)。 + +- **位置**回答了一个不同的问题:特定数值相对于其余数据的位置在哪里? + +- **四分位数**将排序后的数据分成四个相等的部分。Q1(第 25 百分位数)是低于该值的数据占 25% 的值。Q2 是中位数(第 50 百分位数)。Q3 是第 75 百分位数。 + +- **四分位距(IQR)**是 $Q3 - Q1$。它捕捉了中间 50% 数据的离散程度,排除了极端值。 + +![显示 Q1、中位数、Q3、IQR、须线和异常值的箱线图](../images/quartiles_boxplot.svg) + +- **箱线图**是统计学中最有用的可视化工具之一。箱体从 Q1 延伸到 Q3,中间的线为中位数,须线延伸到最远的非异常值,而须线之外的点则为异常值。 + +- **百分位数**是四分位数的推广。第 $p$ 百分位数是低于该值的观测值占 $p\%$ 的值。Q1 是第 25 百分位数,中位数是第 50 百分位数,Q3 是第 75 百分位数。 + +- **z 分数**告诉你一个值距均值有多少个标准差: + +$$z = \frac{x - \mu}{\sigma}$$ + +- z 分数为 2 表示该值高于均值 2 个标准差。z 分数为 $-1.5$ 表示低于均值 1.5 个标准差。这也称为**标准化**,在机器学习中广泛用于特征缩放,因为它将任何分布变换为均值为 0、标准差为 1。 + +- **形状**描述了分布超出其中心和离散程度之外的几何特征。 + +- **偏度**(上一节中的标准化三阶矩)衡量不对称性。像正态曲线这样完全对称的分布,其偏度为零。正偏度表示右尾较长(如收入分布)。负偏度表示左尾较长(如退休年龄分布)。 + +$$\text{偏度} = \frac{1}{N} \sum_{i=1}^{N} \left(\frac{x_i - \mu}{\sigma}\right)^3$$ + +- **峰度**(标准化四阶矩)衡量尾部厚度。正态分布的峰度为 3。尾部更厚(更容易出现异常值)的分布的峰度大于 3。 + +$$\text{峰度} = \frac{1}{N} \sum_{i=1}^{N} \left(\frac{x_i - \mu}{\sigma}\right)^4$$ + +- **相关**衡量两个变量之间关系的强度和方向。它回答了:当一个变量上升时,另一个变量倾向于上升、下降,还是基本不变? + +![三个散点图,分别显示正相关、无相关和负相关](../images/correlation_scatter.svg) + +- **皮尔森相关**($r$)衡量*线性*关联。其取值范围从 $-1$(完全负相关)到 $0$(无相关)再到 $+1$(完全正相关)。 + +$$r = \frac{\sum_{i=1}^{N} (x_i - \bar{x})(y_i - \bar{y})}{\sqrt{\sum (x_i - \bar{x})^2} \cdot \sqrt{\sum (y_i - \bar{y})^2}}$$ + +- 如果你还记得第 1 章中的点积,皮尔森相关本质上就是 $\mathbf{x}$ 和 $\mathbf{y}$ 均值中心化之后的余弦相似度。 + +- **斯皮尔曼相关**($\rho$)衡量*单调*关联。它不使用原始值,而是先对它们进行排序,然后在排序上计算皮尔森相关。这使得它对异常值稳健,并且即使关系是非线性的,只要是一致递增或递减的,也能正常工作。 + +- **几何平均数**是当数值相互乘除时(如增长率)合适的平均值。如果你的投资分别增长了 10%、20% 和 30%,那么平均增长因子并不是这些增长率的算术平均数。而是: + +$$\bar{x}_{\text{geo}} = \left(\prod_{i=1}^{N} x_i\right)^{1/N}$$ + +- 具体到增长率,先将百分比转换为因子(1.10、1.20、1.30),计算几何平均数,再减去 1。 + +- **指数移动平均(EMA)**赋予最近观测值更高的权重。与简单移动平均中窗口内所有点权重相等不同,EMA 呈指数衰减: + +$$\text{EMA}_t = \alpha \cdot x_t + (1 - \alpha) \cdot \text{EMA}_{t-1}$$ + +- 平滑因子 $\alpha$(介于 0 和 1 之间)控制旧观测值失去影响的速度。$\alpha$ 越大,对近期变化的响应越灵敏;$\alpha$ 越小,曲线越平滑。在机器学习中,EMA 被用于 Adam 等优化器以及批归一化的运行统计中。 + +- **异常值检测**识别出与其余数据异常遥远的数点。两种常用方法: + - **IQR 法**:如果一个点低于 $Q1 - 1.5 \times \text{IQR}$ 或高于 $Q3 + 1.5 \times \text{IQR}$,则为异常值 + - **Z 分数法**:如果 $|z| > 3$(距均值超过 3 个标准差),则为异常值 + +- IQR 法更稳健,因为它不假设正态分布。Z 分数法在数据近似正态时效果良好,但当分布高度偏斜时可能失效。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 计算数据集的方差、标准差和 MAD,并进行比较。观察添加极端异常值时发生的变化。 +```python +import jax.numpy as jnp + +data = jnp.array([4, 8, 6, 5, 3, 7, 9, 5, 6, 7], dtype=jnp.float32) + +mean = jnp.mean(data) +variance = jnp.var(data) +std = jnp.std(data) +mad = jnp.mean(jnp.abs(data - mean)) + +print("原始数据:") +print(f" 方差:{variance:.3f},标准差:{std:.3f},MAD:{mad:.3f}") + +# 添加一个异常值并重新计算 +data_outlier = jnp.append(data, 100.0) +mean2 = jnp.mean(data_outlier) +print(f"\n添加异常值(100)后:") +print(f" 方差:{jnp.var(data_outlier):.3f},标准差:{jnp.std(data_outlier):.3f},MAD:{jnp.mean(jnp.abs(data_outlier - mean2)):.3f}") +``` + +2. 计算两个变量之间的皮尔森相关和斯皮尔曼相关。尝试不同的关系。 +```python +import jax +import jax.numpy as jnp + +# 完全线性关系 +x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=jnp.float32) +y = 2 * x + 1 # 试试修改这个! + +def pearson(a, b): + a_c = a - jnp.mean(a) + b_c = b - jnp.mean(b) + return jnp.sum(a_c * b_c) / (jnp.sqrt(jnp.sum(a_c**2)) * jnp.sqrt(jnp.sum(b_c**2))) + +def spearman(a, b): + rank_a = jnp.argsort(jnp.argsort(a)).astype(jnp.float32) + rank_b = jnp.argsort(jnp.argsort(b)).astype(jnp.float32) + return pearson(rank_a, rank_b) + +print(f"皮尔森 r: {pearson(x, y):.4f}") +print(f"斯皮尔曼 ρ:{spearman(x, y):.4f}") +``` + +3. 分别使用 IQR 和 Z 分数方法实现异常值检测,然后比较它们在偏斜数据上的结果。 +```python +import jax.numpy as jnp + +data = jnp.array([2, 3, 3, 4, 5, 5, 5, 6, 6, 7, 50], dtype=jnp.float32) + +# IQR 方法 +q1, q3 = jnp.percentile(data, 25), jnp.percentile(data, 75) +iqr = q3 - q1 +lower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr +iqr_outliers = data[(data < lower) | (data > upper)] +print(f"IQR 边界:[{lower:.1f}, {upper:.1f}]") +print(f"IQR 异常值:{iqr_outliers}") + +# Z 分数方法 +z_scores = (data - jnp.mean(data)) / jnp.std(data) +z_outliers = data[jnp.abs(z_scores) > 3] +print(f"\nZ 分数:{z_scores}") +print(f"Z 分数异常值(|z| > 3):{z_outliers}") +``` + +4. 在不同平滑因子下计算并绘制带噪声数据的指数移动平均。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 生成带噪声的数据 +key = __import__("jax").random.PRNGKey(0) +noise = __import__("jax").random.normal(key, shape=(50,)) +signal = jnp.linspace(0, 5, 50) + noise + +def ema(data, alpha): + result = jnp.zeros_like(data) + result = result.at[0].set(data[0]) + for t in range(1, len(data)): + result = result.at[t].set(alpha * data[t] + (1 - alpha) * result[t - 1]) + return result + +plt.figure(figsize=(10, 4)) +plt.plot(signal, "o", alpha=0.3, label="原始数据", color="#999") +for alpha, color in [(0.1, "#e74c3c"), (0.3, "#3498db"), (0.7, "#27ae60")]: + plt.plot(ema(signal, alpha), label=f"α={alpha}", color=color, linewidth=2) +plt.legend() +plt.title("不同平滑因子下的 EMA") +plt.show() +``` diff --git a/chapter 04: statistics/03. sampling.md b/chapter 04: statistics/03. sampling.md new file mode 100644 index 0000000..350387d --- /dev/null +++ b/chapter 04: statistics/03. sampling.md @@ -0,0 +1,164 @@ +# 抽样 + +*抽样决定了我们如何收集数据,并直接控制着我们所做每项结论的质量。本文涵盖随机抽样、分层抽样、整群抽样与系统抽样、抽样分布、大数定律以及自助法——这些方法对于机器学习中的训练/测试划分和数据集整理至关重要。* + +- 在理想世界中,你会测量所关心群体中的每一个成员。但在实践中,这几乎永远不可能做到。你无法调查每一位选民,无法测试每一只灯泡,也无法扫描每一位患者。所以你只能抽取一个**样本**,并用它来了解整体。 + +- **总体**是你想研究的个体或项目的完整集合。**样本**是你实际观测到的子集。 + +- **参数**是描述总体的数值(例如,某个国家所有成年人的真实平均身高)。 + +- **统计量**是从样本中计算出的数值(例如,你测量的 500 人的平均身高)。统计量用于估计参数。 + +- 结论的质量完全取决于你如何选择样本。一个有偏的样本会导致有偏的结论,无论你的分析多么复杂。 + +- **抽样框**是你实际从中抽取样本的所有个体的列表。理想情况下,抽样框与总体完全吻合,但在实践中总会存在差距。 + +- 例如,如果你通过电话调查人群,就会漏掉所有没有电话的人。抽样框与总体之间的差异称为**覆盖误差**。 + +- **抽样误差**是样本统计量与总体参数之间的自然差异。 + +- 即使是完全随机的样本也不会与总体完全一致。更大的样本可以减少抽样误差。 + +- 抽样有两大类:概率抽样和非概率抽样。 + +- **概率抽样**意味着总体中的每一个成员都有已知的、非零的概率被选中。这让你能够量化不确定性并推广结果。 + +- **简单随机抽样**:每个个体被选中的概率相等,且每个大小为 $n$ 的可能样本出现的概率相同。就像把每个名字放进一顶帽子里,然后蒙眼抽取。 + +- **分层抽样**:根据某个共同特征(如年龄组、地区)将总体划分为互不重叠的组(层),然后从每一层中随机抽样。这保证了每个群体的代表性,并且当层与层之间存在差异时,可以降低方差。 + +- **整群抽样**:将总体划分为若干组(群),随机选择一些群,然后将所选群中的全部个体都纳入样本。当总体在地理上分散时这种方法很实用,比如在整个学区中抽取整所学校而非单个学生。 + +- **系统抽样**:随机选择一个起点,然后从列表中每隔 $k$ 个个体选取一个。例如,从第 7 个人开始,然后每隔 10 个人取一个(7, 17, 27, ...)。这种方法易于实施,但如果列表中存在隐藏模式,则可能引入偏差。 + +![三种概率抽样方法对比:简单随机、分层和整群](../images/sampling_methods.svg) + +- **非概率抽样**并不给每个成员已知的入选机会。其结果无法被严格推广,但这些方法通常更快、更便宜。 + +- **便利抽样**:选择最容易接触到的人。在购物中心调查人群很方便,但会遗漏那些不去购物中心的人。 + +- **配额抽样**:与分层抽样类似,但没有随机性。研究者通过从每个群体中选取方便接触的个体来填补配额(例如 50 名男性和 50 名女性)。 + +- **雪球抽样**:从少数参与者开始,然后请他们招募其他人。适用于难以接触到的人群(例如研究罕见疾病),但会严重偏向于有社交联系的个体。 + +- 一旦你有了抽样方法,一个自然的问题就出现了:如果抽取一个不同的样本,会得到不同的统计量吗?几乎肯定会。**抽样分布**是一个统计量(如样本均值)在所有相同大小的可能样本上的分布。 + +- 想象一下抽取 1000 个不同的 30 人样本,并计算每个样本的平均身高。这 1000 个均值形成了一个分布。有些会略高于真实的总体均值,有些会略低于,而大多数会聚集在真实值周围。 + +- 这个抽样分布的标准差称为**标准误**: + +$$SE = \frac{\sigma}{\sqrt{n}}$$ + +- 注意标准误随着 $n$ 的增大而缩小。更大的样本能给出更精确的估计。样本量扩大到四倍,标准误减半。 + +- 统计学中最重要的结果是**中心极限定理(CLT)**。它指出:无论原始总体的分布形态如何,随着样本量的增大,样本均值的分布都趋近于正态分布。 + +![CLT:偏态总体产生正态分布的样本均值](../images/central_limit_theorem.svg) + +- 更精确地说,如果 $X_1, X_2, \ldots, X_n$ 是来自任意分布的独立观测值,该分布具有均值 $\mu$ 和有限方差 $\sigma^2$,那么随着 $n$ 增大: + +$$\bar{X} \approx \text{Normal}\!\left(\mu, \frac{\sigma^2}{n}\right)$$ + +- CLT 是大部分推断统计得以进行的基础。它让我们能够使用正态分布作为近似,即使底层数据不是正态分布,只要样本量足够大即可。 + +- "足够大"是多大?一个常见的经验法则是 $n \ge 30$,但这取决于总体的非正态程度。对于高度偏态的分布,你可能需要更大的样本量。对于大致对称的总体,即使 $n = 10$ 也可能足够了。 + +- CLT 有三个关键条件: + - **独立性**:每个观测值不能影响其他观测值 + - **有限方差**:总体方差必须存在(排除了某些特殊分布) + - **同分布**:所有观测值来自同一分布 + +## 编程任务(使用 CoLab 或 notebook) + +1. 可视化演示 CLT:从高度偏态的分布中抽取样本,计算样本均值,观察均值直方图如何变成钟形。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +key = jax.random.PRNGKey(0) + +# 指数分布(高度偏态) +population = jax.random.exponential(key, shape=(100_000,)) + +fig, axes = plt.subplots(1, 4, figsize=(14, 3)) +sample_sizes = [1, 5, 30, 100] + +for ax, n in zip(axes, sample_sizes): + keys = jax.random.split(key, 2000) + means = jnp.array([jax.random.choice(k, population, shape=(n,)).mean() for k in keys]) + ax.hist(means, bins=40, color="#3498db", alpha=0.7, density=True) + ax.set_title(f"n = {n}") + ax.set_xlim(0, 4) + +fig.suptitle("CLT:随着 n 增大,样本均值趋近正态分布", fontsize=13) +plt.tight_layout() +plt.show() +``` + +2. 比较简单随机抽样与分层抽样。创建一个具有不同分组的总体,展示分层抽样能给出更低的估计方差。 +```python +import jax +import jax.numpy as jnp + +key = jax.random.PRNGKey(42) + +# 总体:两个不同的组 +group_a = jax.random.normal(key, shape=(500,)) + 10 # 均值 ~10 +key, subkey = jax.random.split(key) +group_b = jax.random.normal(subkey, shape=(500,)) + 20 # 均值 ~20 +population = jnp.concatenate([group_a, group_b]) + +# 简单随机抽样:1000 次试验,样本量 20 +srs_means = [] +for i in range(1000): + key, subkey = jax.random.split(key) + sample = jax.random.choice(subkey, population, shape=(20,), replace=False) + srs_means.append(sample.mean()) +srs_means = jnp.array(srs_means) + +# 分层抽样:每组各取 10 个 +strat_means = [] +for i in range(1000): + key, k1, k2 = jax.random.split(key, 3) + s_a = jax.random.choice(k1, group_a, shape=(10,), replace=False) + s_b = jax.random.choice(k2, group_b, shape=(10,), replace=False) + strat_means.append(jnp.concatenate([s_a, s_b]).mean()) +strat_means = jnp.array(strat_means) + +print(f"简单随机 - 均值: {srs_means.mean():.3f}, 标准差: {srs_means.std():.3f}") +print(f"分层抽样 - 均值: {strat_means.mean():.3f}, 标准差: {strat_means.std():.3f}") +print(f"分层抽样降低了方差 {(1 - strat_means.var()/srs_means.var())*100:.1f}%") +``` + +3. 探索样本量如何影响标准误。绘制标准误随样本量变化的曲线,验证 $1/\sqrt{n}$ 的关系。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +key = jax.random.PRNGKey(7) +population = jax.random.normal(key, shape=(50_000,)) * 10 + 50 + +sample_sizes = [5, 10, 20, 50, 100, 200, 500, 1000] +std_errors = [] + +for n in sample_sizes: + means = [] + for _ in range(500): + key, subkey = jax.random.split(key) + sample = jax.random.choice(subkey, population, shape=(n,)) + means.append(sample.mean()) + std_errors.append(jnp.array(means).std()) + +plt.figure(figsize=(8, 4)) +plt.plot(sample_sizes, std_errors, "o-", color="#e74c3c", label="观测到的 SE") +theoretical = population.std() / jnp.sqrt(jnp.array(sample_sizes, dtype=jnp.float32)) +plt.plot(sample_sizes, theoretical, "--", color="#3498db", label="σ/√n(理论值)") +plt.xlabel("样本量 (n)") +plt.ylabel("标准误") +plt.legend() +plt.title("标准误随样本量增大而缩小") +plt.show() +``` diff --git a/chapter 04: statistics/04. hypothesis testing.md b/chapter 04: statistics/04. hypothesis testing.md new file mode 100644 index 0000000..821cb5b --- /dev/null +++ b/chapter 04: statistics/04. hypothesis testing.md @@ -0,0 +1,177 @@ +# 假设检验 + +*假设检验提供了一个严谨的框架,用于判断观测到的效应是真实存在的还是由随机因素造成的。本文涵盖原假设与备择假设、p值、显著性水平、t检验、卡方检验、方差分析以及第一类/第二类错误——这些逻辑同样应用于A/B测试、模型比较和研究中。* + +- 统计学不仅仅是对数据进行描述。你经常需要做出决策:新药是否有效?某个算法是否比另一个更快?平均值是否发生了变化?假设检验为你提供了一个基于数据回答这些问题的结构化框架。 + +- 其思路很简单:假设没有任何变化("原假设"),然后检验数据是否极端到让这个假设难以令人置信。 + +- **原假设**($H_0$)是默认的主张,通常表述为"无效应"或"无差异"。例如:"平均配送时间仍为30分钟"或"新模型并不比旧模型更好"。 + +- **备择假设**($H_1$ 或 $H_a$)是你认为可能成立的替代情况:"平均配送时间发生了变化"或"新模型更好"。 + +- 你永远无法直接证明 $H_1$。相反,你提出这样一个问题:如果 $H_0$ 成立,观察到如此极端的数据的可能性有多大?如果这种可能性非常小,你就拒绝 $H_0$,转而接受 $H_1$。 + +- **检验统计量**是一个单一数值,它概括了你的样本结果与 $H_0$ 预测值之间的偏差程度。不同的检验使用不同的公式,但逻辑始终一致:度量观测值与期望值之间的距离。 + +- **p值**是在假设 $H_0$ 成立的前提下,观察到至少与当前检验统计量一样极端的结果的概率。p值越小,意味着在 $H_0$ 下数据越令人意外。 + +- **显著性水平**($\alpha$)是你在看到数据之前设定的阈值。如果 $p \le \alpha$,则拒绝 $H_0$。常用的选择有 $\alpha = 0.05$(5%)和 $\alpha = 0.01$(1%)。 + +![正态曲线,阴影区域为拒绝域,标出了检验统计量,并高亮了p值区域](../images/hypothesis_test.svg) + +- 阴影尾部即拒绝域。如果你的检验统计量落在此区域,说明在 $H_0$ 下数据足够极端,你拒绝 $H_0$。绿色区域显示了某个特定检验统计量对应的p值。 + +- 以下是逐步流程: + - **第1步**:提出 $H_0$ 和 $H_1$ + - **第2步**:选择显著性水平 $\alpha$ + - **第3步**:收集数据并计算检验统计量 + - **第4步**:计算p值(或将检验统计量与临界值比较) + - **第5步**:如果 $p \le \alpha$,拒绝 $H_0$;否则,无法拒绝 $H_0$ + +- **实例演练**:某工厂声称其螺栓的平均长度为10 cm。你测量了36个螺栓,得到样本均值为10.3 cm。已知总体标准差为0.9 cm。是否有证据表明均值发生了变化? + +- $H_0$:$\mu = 10$,$H_1$:$\mu \neq 10$,$\alpha = 0.05$ + +- 检验统计量(z检验,因为 $\sigma$ 已知且 $n$ 较大): + +$$z = \frac{\bar{x} - \mu_0}{\sigma / \sqrt{n}} = \frac{10.3 - 10}{0.9 / \sqrt{36}} = \frac{0.3}{0.15} = 2.0$$ + +- 对于 $\alpha = 0.05$ 的双侧检验,临界值为 $\pm 1.96$。我们的 $z = 2.0 > 1.96$,因此拒绝 $H_0$。p值约为0.046,小于0.05。 + +- 结论:有统计学上的显著证据表明,螺栓的平均长度与10 cm不同。 + +- **单侧检验**检查效应是否朝着某个特定方向发生($H_1$:$\mu > 10$ 或 $\mu < 10$)。整个 $\alpha$ 集中于一个尾部,使得在该方向上更容易拒绝 $H_0$,但无法检测到相反方向的效应。 + +- **双侧检验**检查是否存在任何差异($H_1$:$\mu \neq 10$)。$\alpha$ 被分配到两个尾部(各 $\alpha/2$)。这种方法更保守,但能捕捉到两个方向上的效应。 + +- 即使有了良好的流程,错误仍然可能发生。共有两种类型的错误: + +![2×2矩阵图,展示第一类错误和第二类错误:实际情况与决策结果](../images/type_errors.svg) + +- **第一类错误**(假阳性):当 $H_0$ 实际为真时,你错误地拒绝了它。其概率为 $\alpha$,你可以通过选择显著性水平来控制。就像没有火灾时火灾报警器却响了。 + +- **第二类错误**(假阴性):当 $H_0$ 实际为假时,你未能拒绝它。其概率为 $\beta$。就像发生真实火灾时火灾报警器保持沉默。 + +- **检验功效**为 $1 - \beta$,即正确拒绝错误 $H_0$ 的概率。功效越高,意味着你检测真实效应的能力越强。功效随以下因素增加: + - 真实效应量更大(差异越大越容易检测) + - 样本量更大(更多数据 = 更高精度) + - 显著性水平 $\alpha$ 更大(但这会增加第一类错误的风险) + - 变异性更低(噪声更小) + +- 第一类错误与第二类错误之间存在权衡关系。降低 $\alpha$(对假阳性更加谨慎)会增加 $\beta$(更多假阴性)。在固定样本量下,你无法同时最小化这两类错误。 + +- **参数检验**假设数据服从特定的分布(通常是正态分布)。当假设条件成立时,参数检验的功效更高。 + +- **Z检验**:在 $\sigma$ 已知且 $n$ 较大($n \ge 30$)时,将样本均值与已知值进行比较。检验统计量: + +$$z = \frac{\bar{x} - \mu_0}{\sigma / \sqrt{n}}$$ + +- **T检验**:类似于z检验,但适用于 $\sigma$ 未知(由样本估计)或 $n$ 较小的情况。使用t分布,其尾部比正态分布更厚。更厚的尾部反映了估计 $\sigma$ 所引入的额外不确定性。 + +$$t = \frac{\bar{x} - \mu_0}{s / \sqrt{n}}$$ + +- t分布有一个称为**自由度**($df = n - 1$)的参数。随着 $df$ 增大,t分布趋近于正态分布。 + +- t检验有几种变体: + - **单样本t检验**:样本均值是否与某个特定值不同? + - **独立双样本t检验**:两个独立组的均值是否不同? + - **配对t检验**:两个相关测量值的均值是否不同(例如同一批受试者治疗前后的测量值)? + +- **方差分析**:检验三个或更多组的均值是否相等。与运行多次t检验(这会膨胀第一类错误率)不同,方差分析通过比较组间方差与组内方差进行一次统一检验。 + +$$F = \frac{\text{组间方差}}{\text{组内方差}}$$ + +- 较大的 $F$ 比值意味着各组之间的差异超出了随机变异所能解释的范围。 + +- **非参数检验**对数据分布的假设较少。它们基于秩次而非原始值进行运算,因此对异常值和非正态性具有稳健性。 + +- **卡方检验**($\chi^2$):检验观测频数与期望频数是否一致。用于分类数据。例如:红、蓝、绿三种颜色汽车的比例是否与制造商声称的比例一致? + +$$\chi^2 = \sum \frac{(O_i - E_i)^2}{E_i}$$ + +- **Mann-Whitney U检验**:独立双样本t检验的非参数替代方法。通过比较秩次来检验一组是否倾向于比另一组有更大的值。 + +- **Wilcoxon符号秩检验**:配对t检验的非参数替代方法。通过考察差异的大小和方向来比较配对观测值。 + +- **Kruskal-Wallis检验**:单因素方差分析的非参数替代方法。通过比较所有组的秩次来检验多个组是否来自同一分布。 + +- **拟合优度检验**检查数据是否服从某个特定的理论分布。卡方拟合优度检验将观测到的区间计数与假设分布下的期望计数进行比较。 + +- **正态性检验**专门检验数据是否服从正态分布。常用的检验包括Shapiro-Wilk检验(对小样本检验力强)和Kolmogorov-Smirnov检验(将样本经验分布函数与理论分布函数进行比较)。 + +- 在机器学习中,假设检验出现在比较模型性能时。如果模型A达到92%的准确率,模型B达到91%的准确率,这种差异是真实的还是仅仅是噪声?对交叉验证得分进行配对t检验可以回答这个问题。 + +## 编程练习(使用CoLab或notebook) + +1. 对文中的螺栓工厂示例执行z检验。计算检验统计量、p值并做出决策。 +```python +import jax.numpy as jnp + +x_bar = 10.3 # 样本均值 +mu_0 = 10.0 # 原假设值 +sigma = 0.9 # 已知总体标准差 +n = 36 # 样本量 +alpha = 0.05 + +# 检验统计量 +z = (x_bar - mu_0) / (sigma / jnp.sqrt(n)) +print(f"z = {z:.4f}") + +# p值(双侧检验)使用正态CDF近似 +# 对于 |z| = 2.0,p ≈ 0.0456 +from jax.scipy.stats import norm +p_value = 2 * (1 - norm.cdf(jnp.abs(z))) +print(f"p值 = {p_value:.4f}") +print(f"拒绝H₀?{p_value <= alpha}") +``` + +2. 模拟第一类错误:当 $H_0$ 为真时,我们犯错误的频率有多高?运行10,000次实验,检验拒绝率是否与 $\alpha$ 一致。 +```python +import jax +import jax.numpy as jnp + +key = jax.random.PRNGKey(0) +mu_0 = 50.0 +sigma = 10.0 +n = 30 +alpha = 0.05 +n_experiments = 10_000 + +rejections = 0 +for i in range(n_experiments): + key, subkey = jax.random.split(key) + sample = mu_0 + sigma * jax.random.normal(subkey, shape=(n,)) + z = (sample.mean() - mu_0) / (sigma / jnp.sqrt(n)) + p_value = 2 * (1 - __import__("jax").scipy.stats.norm.cdf(jnp.abs(z))) + if p_value <= alpha: + rejections += 1 + +print(f"拒绝率:{rejections/n_experiments:.4f}") +print(f"期望值(α): {alpha}") +``` + +3. 对两组数据分别运行t检验和Mann-Whitney U检验。生成一组均值略高于另一组的数据,观察哪种检验能检测出差异。 +```python +import jax +import jax.numpy as jnp + +key = jax.random.PRNGKey(99) +k1, k2 = jax.random.split(key) + +group_a = jax.random.normal(k1, shape=(25,)) * 5 + 100 +group_b = jax.random.normal(k2, shape=(25,)) * 5 + 103 # 均值略高 + +# 双样本t检验(假设方差相等) +n_a, n_b = len(group_a), len(group_b) +mean_a, mean_b = group_a.mean(), group_b.mean() +pooled_var = ((n_a - 1) * group_a.var() + (n_b - 1) * group_b.var()) / (n_a + n_b - 2) +se = jnp.sqrt(pooled_var * (1/n_a + 1/n_b)) +t_stat = (mean_a - mean_b) / se +print(f"t检验统计量:{t_stat:.4f}") + +# Mann-Whitney:统计group_a的值小于group_b值的次数 +u_stat = jnp.sum(group_a[:, None] < group_b[None, :]) +print(f"Mann-Whitney U: {u_stat}") +print(f"\nA组均值:{mean_a:.2f},B组均值:{mean_b:.2f}") +``` diff --git a/chapter 04: statistics/05. inference.md b/chapter 04: statistics/05. inference.md new file mode 100644 index 0000000..3cd67e0 --- /dev/null +++ b/chapter 04: statistics/05. inference.md @@ -0,0 +1,210 @@ +# 统计推断 + +*统计推断超越了简单的"是/否"决策,以量化的不确定性来估计总体参数。本节涵盖置信区间、点估计与区间估计、极大似然估计、矩法以及回归分析——这是连接原始数据与机器学习预测模型的桥梁。* + +- 假设检验给出一个"是/否"的结论:拒绝或不拒绝原假设。但通常你希望得到更有信息量的结果——你正在估计的参数的一个合理取值区间。这正是**置信区间**所提供的。 + +- **点估计**是从样本中计算出的单一数值,比如样本均值 $\bar{x}$。它是你对总体参数的最佳猜测,但仅凭它本身无法反映估计的精确程度。 + +- **置信区间**在点估计周围包裹一个反映不确定性的范围。其形式为: + +$$\text{CI} = \bar{x} \pm \text{ME}$$ + +- **误差范围**取决于三个因素:你希望多高的置信度、数据的变异程度有多大、以及样本量有多大: + +$$\text{ME} = z^\ast \cdot \frac{\sigma}{\sqrt{n}}$$ + +- 其中 $z^\ast$ 是从正态分布中查得的临界值,与你期望的置信水平对应。对于 95% 置信度,$z^\ast = 1.96$;对于 99% 置信度,$z^\ast = 2.576$。 + +![置信区间:点估计及其两侧的误差范围](../images/confidence_interval.svg) + +- **95% 置信区间**的含义是:如果你重复进行多次实验,每次构建一个区间,那么大约 95% 的区间会包含真实的总体参数。这并不意味着该参数有 95% 的概率落在这个特定的区间内。参数是一个固定值;变化的是区间本身。 + +- **示例**:你测量了 50 人的身高,得到 $\bar{x} = 170$ cm,$\sigma = 8$ cm。构建一个 95% 置信区间。 + +$$\text{ME} = 1.96 \cdot \frac{8}{\sqrt{50}} = 1.96 \cdot 1.131 = 2.22 \text{ cm}$$ + +$$\text{CI} = [170 - 2.22, \; 170 + 2.22] = [167.78, \; 172.22]$$ + +- 你可以说,有 95% 的把握认为真正的平均身高介于 167.78 cm 和 172.22 cm 之间。 + +- 当 $\sigma$ 未知时(这是常见情况),改用样本标准差 $s$ 和 t 分布: + +$$\text{CI} = \bar{x} \pm t^\ast_{n-1} \cdot \frac{s}{\sqrt{n}}$$ + +- 越宽的区间置信度越高,但精度越低;越窄的区间精度越高,但置信度越低。在不降低置信度的前提下,增大样本量可以缩小区间。 + +- **功效分析**帮助你在实验开始前进行规划。要回答的问题是:为了检测到某个给定大小的效应并达到指定的检验功效,我需要多大的样本量? + +- 回顾上一节的内容,功效 = $1 - \beta$,即正确拒绝错误原假设 $H_0$ 的概率。常见的功效目标是 80%。 + +- 对于 z 检验,检测差异 $\delta$ 所需样本量(给定显著性水平 $\alpha$ 和功效 $1-\beta$)为: + +$$n = \left(\frac{(z_{\alpha/2} + z_{\beta}) \cdot \sigma}{\delta}\right)^2$$ + +- 例如,要检测平均身高 2 cm 的差异($\sigma = 8$),取 $\alpha = 0.05$、功效 80%($z_{0.025} = 1.96$,$z_{0.20} = 0.84$): + +$$n = \left(\frac{(1.96 + 0.84) \cdot 8}{2}\right)^2 = \left(\frac{22.4}{2}\right)^2 = 11.2^2 \approx 126$$ + +- 你大约需要每组 126 人。 + +- 功效分析可以防止两种常见错误:实验规模太小,无法检测到真实的效应(功效不足);或者浪费资源做远超必要规模的实验(功效过剩)。 + +- **蒙特卡洛方法**利用随机抽样来求解难以或无法解析求解的问题。其核心思想是:如果你无法精确计算某个量,那就多次模拟并用结果作为近似值。 + +- 名称来源于蒙特卡洛赌场,寓意随机性的角色。这些方法是机器学习中的重要工具,用于估计积分、评估模型不确定性以及近似复杂分布等任务。 + +- 蒙特卡洛的一般步骤: + - 定义可能输入的取值范围 + - 从该范围中随机生成输入 + - 对每个输入评估某个函数 + - 汇总结果(平均值、计数等) + +- 一个经典例子是估算 $\pi$。想象一个边长为 2 的正方形,中心在原点,内切一个半径为 1 的圆。正方形的面积为 4,圆的面积为 $\pi$。 + +![正方形及其内切圆,随机点按圆内/圆外着色](../images/monte_carlo_pi.svg) + +- 在正方形内均匀地随机投点。落在圆内的点的比例近似 $\pi/4$: + +$$\pi \approx 4 \times \frac{\text{圆内点数}}{\text{总点数}}$$ + +- 点 $(x, y)$ 在圆内的条件是 $x^2 + y^2 \le 1$。投的点越多,估算值就越接近 $\pi$ 的真实值。 + +- 在机器学习中,蒙特卡洛方法出现在: + - **蒙特卡洛 Dropout**:多次执行推理(启用 dropout)来估计预测不确定性 + - **MCMC(马尔可夫链蒙特卡洛)**:在贝叶斯模型中从复杂的后验分布中抽样 + - **策略梯度方法**:通过采样轨迹来估计强化学习中的梯度 + +- **因子分析**是一种发现隐藏(潜在)变量的技术,这些变量解释了观测变量之间的相关性。如果 10 个个性调查问题可以由 3 个潜在特质(外向性、宜人性、责任心)解释,因子分析就能找出这些特质。 + +- 该模型假设每个观测变量 $x_i$ 是少数潜在因子 $f_j$ 的线性组合加上噪声: + +$$x_i = \lambda_{i1} f_1 + \lambda_{i2} f_2 + \ldots + \lambda_{ik} f_k + \epsilon_i$$ + +- $\lambda$ 值称为**因子载荷**,表示每个观测变量与各因子的关联强度。这与第 2 章的矩阵分解直接相关;因子分析与特征值分解和 SVD 密切相关。 + +- **实验设计**是安排实验结构的艺术,使你能够得出有效的结论。糟糕的设计甚至会使大量数据变得毫无价值。 + +- 良好实验设计的关键要素: + - **自变量**:你操控的变量(例如药物剂量、模型架构) + - **因变量**:你测量的变量(例如恢复时间、准确率) + - **对照组**:不接受处理(或接受安慰剂),提供比较的基线 + - **随机分配**:参与者被随机分配到各组,从而平衡掉未测量的混杂变量 + +- **常见的实验设计**: + - **完全随机设计**:受试者被随机分配到处理组。在各组可比的情况下,简单有效。 + - **随机区组设计**:受试者先按区组分组(例如按年龄),然后在每个区组内随机分配到处理组。这降低了区组因素带来的变异,类似于分层抽样的思路。 + - **析因设计**:同时测试多个自变量。一个 $2 \times 3$ 的析因设计包含一个变量的 2 个水平和另一个变量的 3 个水平,共 6 种处理组合。这使你能够检测到**交互作用**——即一个变量的效应取决于另一个变量的水平。 + - **交叉设计**:每个受试者按顺序接受所有处理(其间有洗脱期)。每个受试者作为自身的对照,减少了个体差异的影响。 + +- 在机器学习实验中,这些原则至关重要。比较模型时,应控制随机种子、数据集划分和硬件环境。交叉验证是一种交叉设计形式。逐次移除一个组件的消融研究则遵循析因设计的逻辑。 + +## 编程练习(在 CoLab 或 notebook 中完成) + +1. 为身高示例构建一个 95% 置信区间,然后尝试不同的置信水平和样本量。 +```python +import jax.numpy as jnp + +x_bar = 170.0 # 样本均值 +sigma = 8.0 # 总体标准差(已知) +n = 50 # 样本量 + +# 常用置信水平的临界值 +z_stars = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576} + +for conf, z_star in z_stars.items(): + me = z_star * (sigma / jnp.sqrt(n)) + lower, upper = x_bar - me, x_bar + me + print(f"{conf*100:.0f}% CI: [{lower:.2f}, {upper:.2f}] (ME = {me:.2f})") +``` + +2. 使用蒙特卡洛模拟估算 $\pi$。绘制随着点数增加估算值收敛的曲线。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +key = jax.random.PRNGKey(42) + +# 在 [-1, 1] x [-1, 1] 内生成随机点 +n_points = 100_000 +k1, k2 = jax.random.split(key) +x = jax.random.uniform(k1, shape=(n_points,), minval=-1, maxval=1) +y = jax.random.uniform(k2, shape=(n_points,), minval=-1, maxval=1) + +# 检查哪些点在单位圆内 +inside = (x**2 + y**2) <= 1.0 +cumulative_inside = jnp.cumsum(inside) +counts = jnp.arange(1, n_points + 1) +pi_estimates = 4.0 * cumulative_inside / counts + +plt.figure(figsize=(10, 4)) +plt.plot(pi_estimates, color="#3498db", alpha=0.7, linewidth=0.5) +plt.axhline(y=jnp.pi, color="#e74c3c", linestyle="--", label=f"π = {jnp.pi:.6f}") +plt.xlabel("点数") +plt.ylabel("π 的估算值") +plt.title("蒙特卡洛估算 π") +plt.legend() +plt.ylim(2.8, 3.5) +plt.show() + +print(f"最终估算值: {pi_estimates[-1]:.6f}") +print(f"真实值: {jnp.pi:.6f}") +print(f"误差: {abs(pi_estimates[-1] - jnp.pi):.6f}") +``` + +3. 执行一个简单的功效分析:给定效应大小和标准差,计算所需样本量并通过模拟验证。 +```python +import jax +import jax.numpy as jnp + +# 参数 +delta = 2.0 # 效应大小(均值差) +sigma = 8.0 # 总体标准差 +alpha = 0.05 +power_target = 0.80 + +# 解析计算的样本量 +z_alpha = 1.96 # 双尾,alpha=0.05 +z_beta = 0.84 # power=0.80 +n_required = ((z_alpha + z_beta) * sigma / delta) ** 2 +print(f"每组所需样本量: {n_required:.0f}") + +# 通过模拟验证 +key = jax.random.PRNGKey(7) +n = int(jnp.ceil(n_required)) +n_sims = 5000 +rejections = 0 + +for _ in range(n_sims): + key, k1, k2 = jax.random.split(key, 3) + group_a = jax.random.normal(k1, shape=(n,)) * sigma + 50 + group_b = jax.random.normal(k2, shape=(n,)) * sigma + 50 + delta + pooled_se = jnp.sqrt(2 * sigma**2 / n) + z = (group_b.mean() - group_a.mean()) / pooled_se + p = 2 * (1 - __import__("jax").scipy.stats.norm.cdf(jnp.abs(z))) + if p <= alpha: + rejections += 1 + +print(f"模拟功效: {rejections/n_sims:.3f}") +print(f"目标功效: {power_target:.3f}") +``` + +4. 可视化置信区间宽度随样本量的变化。这展示了为什么收集更多数据可以得到更精确的估计。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +sigma = 8.0 +z_star = 1.96 # 95% 置信度 + +sample_sizes = jnp.array([10, 20, 30, 50, 100, 200, 500, 1000], dtype=jnp.float32) +margins = z_star * sigma / jnp.sqrt(sample_sizes) + +plt.figure(figsize=(8, 4)) +plt.bar([str(int(n)) for n in sample_sizes], margins, color="#3498db", alpha=0.7) +plt.xlabel("样本量") +plt.ylabel("误差范围 (cm)") +plt.title("95% CI 误差范围随样本量增大而缩小") +plt.show() +``` diff --git a/chapter 05: probability/01. counting.md b/chapter 05: probability/01. counting.md new file mode 100644 index 0000000..139f2e3 --- /dev/null +++ b/chapter 05: probability/01. counting.md @@ -0,0 +1,167 @@ +# 计数 + +*计数是计算概率的前提——在分配可能性之前,你必须先知道有多少种结果。本文涵盖乘法与加法规则、阶乘、排列、组合、二项式系数,以及支撑机器学习中采样、哈希和概率分析的基本组合工具。* + +- 在计算概率之前,我们需要先数清结果的数量。如果你想知道在扑克中拿到一手赢牌的概率,你必须先知道一共有多少种可能的牌型,以及其中有多少种是赢牌。计数正是让概率精确化的基础工具。 + +- 最简单的计数原则是**乘法规则**。如果一个选择有 $m$ 种选项,另一个独立的选择有 $n$ 种选项,那么组合起来的总结果数为 $m \times n$。 + +- 想象早上穿衣服的场景。你有 3 件衬衫和 4 条裤子。每件衬衫都能与每条裤子搭配,共有 $3 \times 4 = 12$ 种穿搭。 + +![树状图:3 件衬衫 × 4 条裤子 = 12 种穿搭](../images/counting_outfits.svg) + +- 乘法规则可以推广到任意数量的选择。如果你还有 2 双鞋,那么总穿搭数就变成 $3 \times 4 \times 2 = 24$。每个新的独立选择都会乘到总计数中。 + +- **加法规则**处理"或"的场景。如果事件 A 有 $m$ 种发生方式,事件 B 有 $n$ 种发生方式,且它们不能同时发生(互斥),那么总的方式数为 $m + n$。 + +- 假设你要从城市 X 前往城市 Y:开车有 3 条路线,坐火车有 2 条路线。你无法同时选择两者,因此总选项数为 $3 + 2 = 5$。 + +- 当事件有重叠时,需要减去被重复计数的结果。如果 $A$ 和 $B$ 可以同时发生,计数为 $|A \cup B| = |A| + |B| - |A \cap B|$。这就是容斥原理,它将在我们讨论概率加法规则时再次出现。 + +- 非负整数 $n$ 的**阶乘**是从 1 到 $n$ 的所有正整数的乘积: + +$$n! = n \times (n-1) \times (n-2) \times \cdots \times 2 \times 1$$ + +- 可以将阶乘理解为:将 $n$ 个不同的物体排成一列有多少种方式?三本书在书架上有 $3! = 3 \times 2 \times 1 = 6$ 种排列方式。按约定,$0! = 1$。 + +- 阶乘的增长速度极快。$10! = 3{,}628{,}800$,而 $20!$ 已经超过 $2.4 \times 10^{18}$。这种爆炸式增长正是暴力搜索在组合问题中变得不切实际的原因。 + +- **排列**是对物体的有序安排。当你从 $n$ 个不同的物体中选取 $r$ 个且顺序重要时,排列数为: + +$$P(n, r) = \frac{n!}{(n - r)!}$$ + +- 想象从一个 10 人的俱乐部中选出会长、副会长和财务主管。第一个职位有 10 个候选人,第二个有 9 个,第三个有 8 个。因此 $P(10, 3) = 10 \times 9 \times 8 = 720$。公式也印证了这一点:$\frac{10!}{7!} = 720$。 + +- **组合**是无序的选择。当你从 $n$ 个中选取 $r$ 个且顺序无关紧要时,需要除去重复的排列顺序: + +$$C(n, r) = \binom{n}{r} = \frac{n!}{r!(n - r)!}$$ + +- 符号 $\binom{n}{r}$ 读作"n 选 r"。核心思想是:每个组合对应 $r!$ 种排列(选出的 $r$ 个物品可以有 $r!$ 种重新排列的方式),因此我们将排列数除以 $r!$。 + +![并列对比:排列计数所有顺序,组合将相同集合合并](../images/permutation_vs_combination.svg) + +- 示例:从 10 人中组成一个 3 人委员会有多少种方式?顺序无关紧要(没有会长或副会长之分,只有成员),因此我们使用组合: + +$$\binom{10}{3} = \frac{10!}{3! \cdot 7!} = \frac{10 \times 9 \times 8}{3 \times 2 \times 1} = 120$$ + +- 同样的 10 个人产生 720 种排列,但只有 120 种组合,因为每个 3 人组内部有 $3! = 6$ 种排序方式。 + +- 组合在概率中至关重要。二项式系数 $\binom{n}{r}$ 统计了在 $n$ 次试验中恰好获得 $r$ 次成功的方式数,这正是二项分布(见文件 03)的核心。 + +- 让我们通过一个经典的委员会问题来综合运用多种计数思路。 + +- **问题**:一个俱乐部有 8 名男性和 6 名女性。要组成一个 5 人委员会,其中恰好包含 3 名男性和 2 名女性,有多少种方式? + +- **第 1 步**:从 8 人中选 3 名男性。 + +$$\binom{8}{3} = \frac{8!}{3! \cdot 5!} = \frac{8 \times 7 \times 6}{3 \times 2 \times 1} = 56$$ + +- **第 2 步**:从 6 人中选 2 名女性。 + +$$\binom{6}{2} = \frac{6!}{2! \cdot 4!} = \frac{6 \times 5}{2 \times 1} = 15$$ + +- **第 3 步**:应用乘法规则。每种男性选择可以与每种女性选择配对: + +$$56 \times 15 = 840 \text{ 个委员会}$$ + +- 这种将复杂计数问题分解为独立子选择再相乘的模式,是组合数学中的标准方法。 + +- 还有**可重复的排列**。当物品可以重复时,从 $n$ 种类型中选 $r$ 个会产生 $n^r$ 种结果。一个使用数字 0-9 的 4 位 PIN 码有 $10^4 = 10{,}000$ 种可能性。每一位都有 10 种选择,乘法规则即可解决。 + +- **可重复的组合**(也称"星条法")统计从 $n$ 种类型中选 $r$ 个、允许重复且顺序无关的方式数: + +$$\binom{n + r - 1}{r} = \frac{(n + r - 1)!}{r!(n - 1)!}$$ + +- 示例:从 4 种冰淇淋口味中选择 3 勺(允许重复)有 $\binom{4 + 3 - 1}{3} = \binom{6}{3} = 20$ 种选项。 + +- 总结计数工具箱: + +| 场景 | 公式 | +|---|---| +| 有序,无重复(排列) | $P(n,r) = \frac{n!}{(n-r)!}$ | +| 无序,无重复(组合) | $\binom{n}{r} = \frac{n!}{r!(n-r)!}$ | +| 有序,可重复 | $n^r$ | +| 无序,可重复 | $\binom{n+r-1}{r}$ | + +- 每个涉及等可能结果(等概率结果)的概率计算都使用公式 $P(\text{事件}) = \frac{\text{有利结果数}}{\text{总结果数}}$。计数为我们提供了这两个数字。有了这个基础,我们将在下一个文件中正式定义概率本身。 + +## 编程练习(在 CoLab 或 notebook 中完成) + +1. 使用阶乘公式和直接计算两种方式计算 $P(10, 3)$ 和 $\binom{10}{3}$。验证排列数总是组合数的 $r!$ 倍。 +```python +import jax.numpy as jnp +from math import factorial + +n, r = 10, 3 + +perm = factorial(n) // factorial(n - r) +comb = factorial(n) // (factorial(r) * factorial(n - r)) + +print(f"P({n},{r}) = {perm}") +print(f"C({n},{r}) = {comb}") +print(f"P / C = {perm // comb} (应等于 {r}! = {factorial(r)})") +``` + +2. 通过程序解决委员会问题(8 人中选 3 名男性,6 人中选 2 名女性),并通过枚举所有有效委员会来验证。 +```python +from itertools import combinations +from math import factorial + +def comb_count(n, r): + return factorial(n) // (factorial(r) * factorial(n - r)) + +# 公式法 +men_ways = comb_count(8, 3) +women_ways = comb_count(6, 2) +print(f"公式法: {men_ways} × {women_ways} = {men_ways * women_ways}") + +# 枚举法 +men = [f"M{i}" for i in range(1, 9)] +women = [f"W{i}" for i in range(1, 7)] +count = sum(1 for _ in combinations(men, 3) for _ in combinations(women, 2)) +print(f"枚举法: {count}") +``` + +3. 统计由 26 个小写字母组成的 4 位密码有多少种(允许重复)。然后统计没有重复字母的密码有多少种。 +```python +from math import factorial + +n = 26 +r = 4 + +with_rep = n ** r +without_rep = factorial(n) // factorial(n - r) + +print(f"允许重复: {with_rep:>10,}") +print(f"不允许重复: {without_rep:>10,}") +print(f"含重复的比例: {1 - without_rep/with_rep:.2%}") +``` + +4. 模拟生日问题:在 $k$ 人的群体中,至少两人共享生日的概率是多少?绘制 $k = 1$ 到 $60$ 的概率曲线,并找出概率超过 50% 的位置。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def birthday_prob_exact(k): + \"\"\"k 人群体中至少有一对共享生日的概率。\"\"\" + p_no_match = 1.0 + for i in range(k): + p_no_match *= (365 - i) / 365 + return 1 - p_no_match + +ks = list(range(1, 61)) +probs = [birthday_prob_exact(k) for k in ks] + +plt.figure(figsize=(8, 4)) +plt.plot(ks, probs, color="#3498db", linewidth=2) +plt.axhline(y=0.5, color="#e74c3c", linestyle="--", alpha=0.7, label="50%") +cross = next(k for k, p in zip(ks, probs) if p >= 0.5) +plt.axvline(x=cross, color="#e74c3c", linestyle="--", alpha=0.7) +plt.xlabel("群体大小 (k)") +plt.ylabel("P(至少两人共享生日)") +plt.title(f"生日问题(在 k={cross} 时超过 50%)") +plt.legend() +plt.grid(alpha=0.3) +plt.show() +``` diff --git a/chapter 05: probability/02. probability concepts.md b/chapter 05: probability/02. probability concepts.md new file mode 100644 index 0000000..a416039 --- /dev/null +++ b/chapter 05: probability/02. probability concepts.md @@ -0,0 +1,243 @@ +# 概率概念 + +*概率论形式化了不确定性,并提供了在此框架下进行推理的规则。本文涵盖样本空间、事件、概率公理、条件概率、独立性、贝叶斯定理、频率派与贝叶斯派解释,这是机器学习中每个生成模型和判别模型背后的数学框架。* + +- 概率为一个事件赋予一个介于 0 和 1 之间的数字,衡量该事件发生的可能性。 + +- 概率为 0 表示不可能,1 表示必然,0.5 则像抛硬币一样。 + +- 有两种主要解释。**频率派**观点认为概率是长期相对频率:抛一枚均匀硬币 10,000 次,正面大约会出现 50% 的次数。 + +- **贝叶斯派**观点认为概率是信念程度:你可能会说明天降雨的概率是 70%,尽管明天只会发生一次。 + +- 两种解释使用相同的数学规则。区别在于哲学层面,但在机器学习中这很重要。频率派方法给出点估计。贝叶斯派方法给出参数的完整分布。 + +- **样本空间** $S$ 是实验所有可能结果的集合。抛一枚硬币:$S = \{H, T\}$。掷一个骰子:$S = \{1, 2, 3, 4, 5, 6\}$。 + +- **事件**是样本空间的任意子集。"掷出偶数"是事件 $A = \{2, 4, 6\}$,它是 $S$ 的一个子集。 + +- 当所有结果等可能时,事件的概率就是简单的计数(来自文件 01): + +$$P(A) = \frac{|A|}{|S|} = \frac{\text{有利结果}}{\text{总结果}}$$ + +- 对于偶数例子:$P(\text{偶数}) = \frac{3}{6} = 0.5$。 + +![样本空间 S 中事件 A 和 B 的维恩图,显示交集与补集](../images/venn_diagram.svg) + +- 事件 $A$ 的**补集**,记作 $A'$ 或 $A^c$,是 $S$ 中所有不在 $A$ 中的元素。由于每个结果要么在 $A$ 中,要么不在: + +$$P(A') = 1 - P(A)$$ + +- 补集通常是更简便的途径。与其计算 5 次抛硬币中至少得到一个正面的所有方式,不如计算得到全反面的一种方式然后相减:$P(\text{至少一个正面}) = 1 - P(\text{全反面}) = 1 - (0.5)^5 = 0.969$。 + +- 如果两个事件不能同时发生,则它们是**互斥**(不相交)的:$A \cap B = \emptyset$。一次掷骰子中掷出 2 和掷出 5 是互斥事件。 + +- **互斥事件的加法法则**很直接: + +$$P(A \cup B) = P(A) + P(B) \quad \text{(若 } A \cap B = \emptyset\text{)}$$ + +- 当事件可能有重叠时,需要使用**一般加法法则**来避免重复计算交集: + +$$P(A \cup B) = P(A) + P(B) - P(A \cap B)$$ + +- 这与计数中的容斥原理相对应。上方的维恩图说明了原因:紫色区域(交集)在 $P(A)$ 中被计算一次,在 $P(B)$ 中又被计算一次,因此我们减去一次。 + +- **联合概率** $P(A \cap B)$ 是 $A$ 和 $B$ 同时发生的概率。在一副扑克牌中,$P(\text{红色} \cap \text{国王}) = \frac{2}{52}$,因为有 2 张红色国王。 + +- **边际概率**是单个事件不考虑其他事件时的概率。$P(\text{红色}) = \frac{26}{52} = 0.5$ 是一个边际概率。如果你有关于两个变量的联合分布,通过对另一个变量求和(或积分)即可得到边际概率。 + +- **条件概率**回答的是:已知 $B$ 已经发生,$A$ 的概率是多少?我们将样本空间从 $S$ 缩小到 $B$,并问 $B$ 中同时属于 $A$ 的比例是多少: + +$$P(A | B) = \frac{P(A \cap B)}{P(B)}, \quad P(B) > 0$$ + +![条件概率:将样本空间从 S 缩小到 B](../images/conditional_probability.svg) + +- 示例:你抽一张牌,有人告诉你它是红色。它是国王的概率是多少?有 26 张红色牌,其中 2 张是国王,所以 $P(\text{国王} | \text{红色}) = \frac{2}{26} = \frac{1}{13}$。使用公式:$P(\text{国王} \cap \text{红色}) / P(\text{红色}) = \frac{2/52}{26/52} = \frac{1}{13}$。 + +- 如果知道一个事件发生了不会告诉你关于另一个事件的任何信息,则这两个事件是**独立**的。形式化定义: + +$$P(A \cap B) = P(A) \cdot P(B)$$ + +- 等价地,$P(A | B) = P(A)$。分别抛两枚硬币是独立事件。无放回地抽两张牌不是独立的(第一次抽取会改变剩余牌的数量)。 + +- 独立性是一个巨大的简化工具。对于独立事件,联合概率分解为乘积形式,这使得计算可处理。许多机器学习模型假设特征之间独立(例如朴素贝叶斯),正是基于这种简化。 + +- 任意两个事件的**乘法法则**由条件概率公式重新排列得到: + +$$P(A \cap B) = P(A | B) \cdot P(B) = P(B | A) \cdot P(A)$$ + +- 对于独立事件,由于条件概率等于边际概率,上式简化为 $P(A \cap B) = P(A) \cdot P(B)$。 + +- **贝叶斯定理**是概率论中最重要的结论之一,也是贝叶斯机器学习的基础。它让我们可以反转条件概率的方向: + +$$P(A | B) = \frac{P(B | A) \cdot P(A)}{P(B)}$$ + +- 该定理直接源于将 $P(A \cap B)$ 写成两种形式:$P(B|A) \cdot P(A) = P(A|B) \cdot P(B)$,然后解出 $P(A|B)$。 + +![贝叶斯定理的组成部分:后验、似然、先验和证据](../images/bayes_components.svg) + +- 每个部分都有名称: + - **先验** $P(A)$:看到证据之前的初始信念 + - **似然** $P(B|A)$:假设 $A$ 为真的前提下,证据出现的概率 + - **证据** $P(B)$:看到证据的总概率,起归一化作用 + - **后验** $P(A|B)$:看到证据之后更新后的信念 + +- 让我们通过经典的医学诊断例子来理解。假设某种疾病影响 1% 的人口。针对该疾病的检测准确率为 95%:它能正确识别 95% 的患病者(灵敏度),并能正确识别 90% 的健康人(特异度)。 + +- 你的检测结果为阳性。你实际患病的概率是多少? + +- 设 $D$ = 患病,$+$ = 检测阳性。 + - 先验:$P(D) = 0.01$ + - 似然:$P(+ | D) = 0.95$ + - 假阳性率:$P(+ | D') = 0.10$ + +- 我们需要 $P(+)$。根据全概率公式: + +$$P(+) = P(+ | D) \cdot P(D) + P(+ | D') \cdot P(D')$$ +$$= 0.95 \times 0.01 + 0.10 \times 0.99 = 0.0095 + 0.099 = 0.1085$$ + +- 现在应用贝叶斯定理: + +$$P(D | +) = \frac{P(+ | D) \cdot P(D)}{P(+)} = \frac{0.95 \times 0.01}{0.1085} \approx 0.088$$ + +- 尽管检测"准确率高达 95%",但阳性结果只能给你约 8.8% 的患病概率。先验至关重要。由于该疾病罕见,大多数阳性结果都是假阳性。这对机器学习中的任何分类问题都是一个关键见解:当类别不平衡时,仅靠准确率是具有误导性的。 + +- **全概率公式**将样本空间划分为互斥且完备的事件 $B_1, B_2, \ldots, B_n$,并将任意事件 $A$ 表示为: + +$$P(A) = \sum_{i=1}^{n} P(A | B_i) \cdot P(B_i)$$ + +- 这正是我们在医学例子中计算 $P(+)$ 所用的方法:我们将人群分为"患病"和"未患病"两类。 + +- **概率的链式法则**将乘法法则推广到任意数量的事件: + +$$P(A_1 \cap A_2 \cap \cdots \cap A_n) = P(A_1) \cdot P(A_2 | A_1) \cdot P(A_3 | A_1 \cap A_2) \cdots P(A_n | A_1 \cap \cdots \cap A_{n-1})$$ + +- 每个因子都以前面所有事件为条件。这是自回归语言模型的基石:一个句子的概率等于每个单词在给定前面所有单词条件下的概率的乘积。 + +- **条件独立**意味着两个事件在给定第三个事件的条件下是独立的。如果满足下式,则 $A$ 和 $B$ 在给定 $C$ 的条件下条件独立: + +$$P(A \cap B | C) = P(A | C) \cdot P(B | C)$$ + +- 事件可以边际上相关但条件独立,反之亦然。例如,两名学生的考试成绩可能相关(都依赖于考试的难度),但给定考试难度后,他们的成绩是独立的。 + +- 条件独立是贝叶斯网络等图模型背后的关键假设。它允许将复杂的联合分布分解为可管理的部分,使推断在计算上可行。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 模拟医学诊断问题。生成 100,000 人的总体,应用疾病患病率和检测准确率,验证贝叶斯定理给出正确的后验概率。 +```python +import jax +import jax.numpy as jnp + +key = jax.random.PRNGKey(42) +n = 100_000 + +# 生成总体 +k1, k2 = jax.random.split(key) +has_disease = jax.random.bernoulli(k1, p=0.01, shape=(n,)) + +# 生成检测结果 +k3, k4 = jax.random.split(k2) +# 灵敏度:P(+|D) = 0.95,特异度:P(-|D') = 0.90 +test_positive = jnp.where( + has_disease, + jax.random.bernoulli(k3, p=0.95, shape=(n,)), + jax.random.bernoulli(k4, p=0.10, shape=(n,)) +) + +# 在检测阳性的人群中,实际患病的比例是多少? +positives = test_positive.astype(bool) +true_positives = (has_disease & positives).sum() +total_positives = positives.sum() + +print(f"检测阳性总人数: {total_positives}") +print(f"真阳性人数: {true_positives}") +print(f"P(患病 | 阳性) = {true_positives / total_positives:.4f}") +print(f"贝叶斯公式: {0.95 * 0.01 / 0.1085:.4f}") +``` + +2. 通过模拟验证加法法则。生成具有已知概率和重叠程度的随机事件 A 和 B,然后验证 $P(A \cup B) = P(A) + P(B) - P(A \cap B)$。 +```python +import jax +import jax.numpy as jnp + +key = jax.random.PRNGKey(0) +n = 200_000 +k1, k2 = jax.random.split(key) + +# 事件:A = 值 < 0.4,B = 值 < 0.6(在 < 0.4 处重叠) +vals_a = jax.random.uniform(k1, shape=(n,)) +vals_b = jax.random.uniform(k2, shape=(n,)) + +A = vals_a < 0.4 +B = vals_b < 0.6 + +p_a = A.mean() +p_b = B.mean() +p_a_and_b = (A & B).mean() +p_a_or_b = (A | B).mean() + +print(f"P(A) = {p_a:.4f}") +print(f"P(B) = {p_b:.4f}") +print(f"P(A ∩ B) = {p_a_and_b:.4f}") +print(f"P(A ∪ B) 模拟值 = {p_a_or_b:.4f}") +print(f"P(A) + P(B) - P(A∩B) = {p_a + p_b - p_a_and_b:.4f}") +``` + +3. 演示条件概率随证据变化。模拟掷两个骰子,计算 $P(\text{和} = 7)$,然后计算 $P(\text{和} = 7 | \text{第一个骰子} = 3)$。 +```python +import jax +import jax.numpy as jnp + +key = jax.random.PRNGKey(1) +n = 500_000 +k1, k2 = jax.random.split(key) + +d1 = jax.random.randint(k1, shape=(n,), minval=1, maxval=7) +d2 = jax.random.randint(k2, shape=(n,), minval=1, maxval=7) +total = d1 + d2 + +# 无条件概率 +p_sum7 = (total == 7).mean() +print(f"P(和=7) = {p_sum7:.4f} (精确值: {6/36:.4f})") + +# 条件于第一个骰子 = 3 +mask = d1 == 3 +p_sum7_given_d1_3 = (total[mask] == 7).mean() +print(f"P(和=7 | d1=3) = {p_sum7_given_d1_3:.4f} (精确值: {1/6:.4f})") +``` + +4. 将贝叶斯定理实现为一个函数,并用它迭代更新信念。从硬币偏向的均匀先验开始,在观察到每次抛掷后更新。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def bayes_update(prior, likelihood): + """将先验乘以似然并归一化。""" + posterior = prior * likelihood + return posterior / posterior.sum() + +# 离散化可能的偏向值 +theta = jnp.linspace(0, 1, 200) +prior = jnp.ones_like(theta) # 均匀先验 +prior = prior / prior.sum() + +# 观测到的抛掷结果:1=正面,0=反面 +flips = [1, 1, 0, 1, 1, 1, 0, 1, 0, 1] + +plt.figure(figsize=(10, 5)) +plt.plot(theta, prior, "--", color="#999", label="先验") + +for i, flip in enumerate(flips): + likelihood = theta if flip == 1 else (1 - theta) + prior = bayes_update(prior, likelihood) + if i in [0, 2, 4, 9]: + plt.plot(theta, prior, label=f"经过 {i+1} 次抛掷后", linewidth=2) + +plt.xlabel("硬币偏向 θ") +plt.ylabel("信念(归一化)") +plt.title("贝叶斯更新:关于硬币偏向的信念") +plt.legend() +plt.grid(alpha=0.3) +plt.show() +``` diff --git a/chapter 05: probability/03. distributions.md b/chapter 05: probability/03. distributions.md new file mode 100644 index 0000000..707dda2 --- /dev/null +++ b/chapter 05: probability/03. distributions.md @@ -0,0 +1,238 @@ +# 概率分布 + +*概率分布描述了随机结果如何在可能取值上分布。本文档整理了关键的离散和连续分布:伯努利分布、二项分布、泊松分布、高斯分布、指数分布、贝塔分布等,给出了各自的公式、直观理解及其在机器学习中的应用(损失函数、先验、噪声模型)。* + +- 在第4章中,我们介绍了随机变量、PMF、PDF和CDF。本章列出你在机器学习和统计学中最常遇到的重要概率分布,给出每个分布的直观理解、公式、均值和方差。 + +- 三种核心函数的快速回顾(完整定义见第4章): + - **PMF** $P(X = x)$:给出每个离散结果的概率。即条形图中每个条形的高度。 + - **PDF** $f(x)$:给出连续变量在每个点上的密度。两点之间曲线下的面积即为概率。 + - **CDF** $F(x) = P(X \le x)$:累积到 $x$ 为止的概率。取值范围始终从0到1且单调不减。 + +- 分布的**支撑集**是指PMF或PDF取正值的集合。对掷骰子而言,支撑集为 $\{1,2,3,4,5,6\}$。对正态分布而言,支撑集为全体实数 $(-\infty, \infty)$。 + +- 分布清晰地分为两个家族:离散分布(结果可数,使用PMF)和连续分布(结果不可数,使用PDF)。 + +- **伯努利分布**:最简单的分布。单次试验有两种结果:成功(1)的概率为 $p$,失败(0)的概率为 $1-p$。 + +$$P(X = x) = p^x (1 - p)^{1-x}, \quad x \in \{0, 1\}$$ + +- 均值:$E[X] = p$。方差:$\text{Var}(X) = p(1-p)$。 + +- 每一次抛硬币、每一个是/否分类、每一个二元结果都是伯努利试验。在机器学习中,sigmoid函数的输出正是伯努利分布的参数 $p$。 + +- **二项分布**:计算 $n$ 次独立伯努利试验中成功的次数,每次试验的成功概率 $p$ 相同。 + +$$P(X = k) = \binom{n}{k} p^k (1-p)^{n-k}, \quad k = 0, 1, \ldots, n$$ + +- 二项式系数 $\binom{n}{k}$(见文件01)计算了 $k$ 次成功在 $n$ 次试验中的排列方式数量。 + +- 均值:$E[X] = np$。方差:$\text{Var}(X) = np(1-p)$。 + +![伯努利分布作为单一条形图与二项分布作为计数上的分布对比](../images/bernoulli_binomial.svg) + +- 示例:抛一枚有偏硬币($p = 0.7$)八次。恰好得到6次正面的概率为 $\binom{8}{6}(0.7)^6(0.3)^2 = 28 \times 0.1176 \times 0.09 \approx 0.296$。 + +- **泊松分布**:在固定的时间或空间区间内,以已知的平均速率 $\lambda$ 计算事件发生的次数。适用于事件稀少且相互独立的情形。 + +$$P(X = k) = \frac{\lambda^k e^{-\lambda}}{k!}, \quad k = 0, 1, 2, \ldots$$ + +- 均值:$E[X] = \lambda$。方差:$\text{Var}(X) = \lambda$。均值等于方差是其标志性特征。 + +- 示例:每小时收到的邮件数($\lambda = 5$)、每页的错别字数、每秒的服务器请求数。在机器学习中,泊松回归用于建模计数数据,而线性模型可能会预测出负的计数值。 + +- 当 $n \to \infty$ 且 $p \to 0$,且 $np = \lambda$ 保持不变时,二项分布 Binomial$(n,p)$ 收敛于泊松分布 Poisson$(\lambda)$。这就是泊松分布适用于大总体中稀有事件的原因。 + +- **几何分布**:计算直到首次成功所需的试验次数。"我要抛多少次硬币才能第一次得到正面?" + +$$P(X = k) = (1-p)^{k-1} p, \quad k = 1, 2, 3, \ldots$$ + +- 均值:$E[X] = 1/p$。方差:$\text{Var}(X) = (1-p)/p^2$。 + +- 几何分布具有**无记忆性**:再等待 $k$ 次试验才成功的概率与你已经等待了多少次试验无关。这使得它在离散分布中非常特殊。 + +- **负二项分布**:推广了几何分布,计算直到第 $r$ 次成功所需的试验次数(几何分布是 $r=1$ 的特殊情形)。 + +$$P(X = k) = \binom{k-1}{r-1} p^r (1-p)^{k-r}, \quad k = r, r+1, r+2, \ldots$$ + +- 均值:$E[X] = r/p$。方差:$\text{Var}(X) = r(1-p)/p^2$。 + +- 负二项分布在实践中也用于建模过度离散的计数数据(方差超过均值的情形),这是泊松分布无法处理的。 + +- 接下来我们进入连续分布。 + +- **均匀分布**:区间 $[a, b]$ 内的所有值等可能。其PDF是一个平坦的矩形。 + +$$f(x) = \frac{1}{b - a}, \quad a \le x \le b$$ + +- 均值:$E[X] = \frac{a+b}{2}$。方差:$\text{Var}(X) = \frac{(b-a)^2}{12}$。 + +- 随机数生成器以生成均匀分布 Uniform(0,1) 样本为起点。其他分布通过对这些均匀样本进行变换得到。 + +- **正态(高斯)分布**:统计学中最重要的分布。它由中心极限定理(见第4章)自然导出:大量独立随机变量的平均值趋于正态分布,无论原始分布是什么。 + +$$f(x) = \frac{1}{\sigma\sqrt{2\pi}} \exp\!\left(-\frac{(x - \mu)^2}{2\sigma^2}\right)$$ + +- 均值:$E[X] = \mu$。方差:$\text{Var}(X) = \sigma^2$。 + +- **标准正态分布**的 $\mu = 0$ 且 $\sigma = 1$。任意正态变量 $X$ 可通过 $Z = (X - \mu)/\sigma$ 标准化为标准正态变量 $Z$。 + +![带有68-95-99.7经验法则区域的钟形曲线](../images/normal_empirical.svg) + +- **经验法则**(68-95-99.7法则)指出: + - 约68%的数据落在均值 $\pm 1\sigma$ 范围内 + - 约95%的数据落在 $\pm 2\sigma$ 范围内 + - 约99.7%的数据落在 $\pm 3\sigma$ 范围内 + +- 在机器学习中,正态分布无处不在:权重初始化、数据增强中的噪声、MSE损失背后的假设(其隐含假设高斯误差)、以及变分自编码器中的重参数化技巧。 + +- **指数分布**:模拟泊松过程中事件之间的时间间隔。如果事件以速率 $\lambda$ 到达,则它们之间的等待时间服从指数分布 Exponential$(\lambda)$。 + +$$f(x) = \lambda e^{-\lambda x}, \quad x \ge 0$$ + +- 均值:$E[X] = 1/\lambda$。方差:$\text{Var}(X) = 1/\lambda^2$。 + +- 与离散变量中的几何分布类似,指数分布也具有**无记忆性**:$P(X > s + t | X > s) = P(X > t)$。再等待 $t$ 个时间单位的概率与你已经等待了多长时间无关。 + +- **伽马分布**:推广了指数分布。它模拟泊松过程中第 $\alpha$ 个事件发生的时间(指数分布是 $\alpha = 1$ 的特殊情形)。 + +$$f(x) = \frac{\beta^\alpha}{\Gamma(\alpha)} x^{\alpha - 1} e^{-\beta x}, \quad x > 0$$ + +- 这里 $\alpha$(形状参数)控制形状,$\beta$(速率参数)控制尺度。$\Gamma(\alpha)$ 是伽马函数,它将阶乘推广到实数:对正整数 $n$ 有 $\Gamma(n) = (n-1)!$。 + +- 均值:$E[X] = \alpha/\beta$。方差:$\text{Var}(X) = \alpha/\beta^2$。 + +- **贝塔分布**:定义在区间 $[0, 1]$ 上,非常适合对概率、比例和比率进行建模。 + +$$f(x) = \frac{x^{\alpha - 1}(1 - x)^{\beta - 1}}{B(\alpha, \beta)}, \quad 0 \le x \le 1$$ + +- 分母 $B(\alpha, \beta) = \frac{\Gamma(\alpha)\Gamma(\beta)}{\Gamma(\alpha + \beta)}$ 是贝塔函数,起到归一化常数的作用。 + +- 均值:$E[X] = \frac{\alpha}{\alpha + \beta}$。方差:$\text{Var}(X) = \frac{\alpha\beta}{(\alpha+\beta)^2(\alpha+\beta+1)}$。 + +- 贝塔分布是伯努利和二项似然函数的共轭先验。这意味着如果先验是贝塔分布且数据服从伯努利分布,则后验也是贝塔分布,这使得贝叶斯更新在解析上易于处理。我们将在文件04中使用这一性质。 + +![四种常见的分布形状:均匀分布、指数分布、贝塔分布、泊松分布](../images/common_distributions.svg) + +- **卡方分布**($\chi^2$):如果你取 $k$ 个独立的标准正态随机变量并求其平方和,结果服从自由度为 $k$ 的 $\chi^2$ 分布。 + +$$f(x) = \frac{1}{2^{k/2}\Gamma(k/2)} x^{k/2 - 1} e^{-x/2}, \quad x > 0$$ + +- 均值:$E[X] = k$。方差:$\text{Var}(X) = 2k$。 + +- $\chi^2$ 分布实际上是伽马分布的特殊情形,其中 $\alpha = k/2$ 且 $\beta = 1/2$。它出现在假设检验(第4章中的卡方检验)、拟合优度检验以及方差置信区间的计算中。 + +- **学生t分布**:形状类似于正态分布但尾部更重。当你使用小样本且总体方差未知时,对正态分布总体的均值进行估计时就会出现t分布。 + +$$f(x) = \frac{\Gamma\!\left(\frac{\nu+1}{2}\right)}{\sqrt{\nu\pi}\,\Gamma\!\left(\frac{\nu}{2}\right)} \left(1 + \frac{x^2}{\nu}\right)^{-(\nu+1)/2}$$ + +- 参数 $\nu$(自由度)。当 $\nu \to \infty$ 时,t分布收敛于标准正态分布。当 $\nu$ 较小时,更重的尾部赋予极端值更高的概率,反映了小样本带来的额外不确定性。 + +- 均值:$E[X] = 0$(当 $\nu > 1$ 时)。方差:$\text{Var}(X) = \frac{\nu}{\nu - 2}$(当 $\nu > 2$ 时)。 + +- t分布用于t检验(第4章),并出现在贝叶斯推断中,作为在积分消去未知方差时的边缘分布。 + +- 关键分布总结: + +| 分布 | 类型 | 支撑集 | 均值 | 方差 | +|---|---|---|---|---| +| Bernoulli$(p)$ | 离散 | $\{0,1\}$ | $p$ | $p(1-p)$ | +| Binomial$(n,p)$ | 离散 | $\{0,\ldots,n\}$ | $np$ | $np(1-p)$ | +| Poisson$(\lambda)$ | 离散 | $\{0,1,2,\ldots\}$ | $\lambda$ | $\lambda$ | +| Geometric$(p)$ | 离散 | $\{1,2,3,\ldots\}$ | $1/p$ | $(1-p)/p^2$ | +| Uniform$(a,b)$ | 连续 | $[a,b]$ | $(a+b)/2$ | $(b-a)^2/12$ | +| Normal$(\mu,\sigma^2)$ | 连续 | $(-\infty,\infty)$ | $\mu$ | $\sigma^2$ | +| Exponential$(\lambda)$ | 连续 | $[0,\infty)$ | $1/\lambda$ | $1/\lambda^2$ | +| Gamma$(\alpha,\beta)$ | 连续 | $(0,\infty)$ | $\alpha/\beta$ | $\alpha/\beta^2$ | +| Beta$(\alpha,\beta)$ | 连续 | $[0,1]$ | $\alpha/(\alpha+\beta)$ | 见上文 | +| $\chi^2(k)$ | 连续 | $(0,\infty)$ | $k$ | $2k$ | +| Student's $t(\nu)$ | 连续 | $(-\infty,\infty)$ | $0$ | $\nu/(\nu-2)$ | + +## 编程练习(使用CoLab或笔记本) + +1. 绘制 $n=20$ 时二项分布PMF在不同 $p$ 取值下的图像。观察形状如何从左偏变为对称再变为右偏。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt +from math import comb + +n = 20 +ks = jnp.arange(0, n + 1) + +fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True) +for ax, p, color in zip(axes, [0.2, 0.5, 0.8], ["#e74c3c", "#3498db", "#27ae60"]): + pmf = jnp.array([comb(n, int(k)) * p**k * (1-p)**(n-k) for k in ks]) + ax.bar(ks, pmf, color=color, alpha=0.7) + ax.set_title(f"Binomial(n={n}, p={p})") + ax.set_xlabel("k") +axes[0].set_ylabel("P(X = k)") +plt.tight_layout() +plt.show() +``` + +2. 验证泊松分布对二项分布的近似。设 $n = 1000$,$p = 0.003$,比较二项分布 Binomial$(n, p)$ 和泊松分布 Poisson$(\lambda = np)$。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt +from math import comb, factorial, exp + +n, p = 1000, 0.003 +lam = n * p +ks = jnp.arange(0, 15) + +binom_pmf = jnp.array([comb(n, int(k)) * p**k * (1-p)**(n-k) for k in ks]) +poisson_pmf = jnp.array([lam**k * exp(-lam) / factorial(int(k)) for k in ks]) + +plt.figure(figsize=(8, 4)) +plt.bar(ks - 0.15, binom_pmf, width=0.3, color="#3498db", alpha=0.7, label=f"Binomial({n},{p})") +plt.bar(ks + 0.15, poisson_pmf, width=0.3, color="#e74c3c", alpha=0.7, label=f"Poisson({lam})") +plt.xlabel("k") +plt.ylabel("P(X = k)") +plt.title("泊松分布对二项分布的近似") +plt.legend() +plt.show() +``` + +3. 从正态分布中采样并验证经验法则。计算落在1、2和3个标准差内的样本比例。 +```python +import jax +import jax.numpy as jnp + +key = jax.random.PRNGKey(42) +mu, sigma = 5.0, 2.0 +samples = mu + sigma * jax.random.normal(key, shape=(100_000,)) + +for k in [1, 2, 3]: + within = jnp.abs(samples - mu) <= k * sigma + print(f"Within {k}σ: {within.mean():.4f} (expected: {[0.6827, 0.9545, 0.9973][k-1]:.4f})") +``` + +4. 通过改变 $\alpha$ 和 $\beta$ 探索贝塔分布。绘制几种形状,观察分布如何从均匀变为偏斜再变为集中。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +x = jnp.linspace(0.01, 0.99, 200) + +def beta_pdf(x, a, b): + # 未归一化,用于形状比较 + return x**(a-1) * (1-x)**(b-1) + +plt.figure(figsize=(10, 5)) +params = [(1,1,"均匀"), (2,5,"左偏"), (5,2,"右偏"), + (5,5,"对称"), (0.5,0.5,"U形")] +colors = ["#999", "#e74c3c", "#3498db", "#27ae60", "#9b59b6"] + +for (a, b, label), color in zip(params, colors): + y = beta_pdf(x, a, b) + y = y / jnp.trapezoid(y, x) # 归一化 + plt.plot(x, y, label=f"α={a}, β={b} ({label})", color=color, linewidth=2) + +plt.xlabel("x") +plt.ylabel("密度") +plt.title("贝塔分布形状") +plt.legend() +plt.grid(alpha=0.3) +plt.show() +``` diff --git a/chapter 05: probability/04. bayesian.md b/chapter 05: probability/04. bayesian.md new file mode 100644 index 0000000..34be1ed --- /dev/null +++ b/chapter 05: probability/04. bayesian.md @@ -0,0 +1,292 @@ +# 贝叶斯方法与序列模型 + +*贝叶斯方法将先验信念与观测数据相结合,生成模型参数的后验分布。本文涵盖最大似然估计、最大后验估计、共轭先验、贝叶斯推断、隐马尔可夫模型和EM算法——这些技术是垃圾邮件过滤器、语言模型和不确定性感知机器学习的基础。* + +- 到目前为止,我们介绍了各种分布以及如何计算概率。现在我们来处理机器学习的核心问题:给定观测数据,如何找到模型的最佳参数? + +- **最大似然估计 (MLE)** 直接回答了这个问题:选择使观测数据概率最大的参数值。 + +- 形式上,给定数据 $D = \{x_1, x_2, \ldots, x_n\}$ 和带有参数 $\theta$ 的模型,**似然函数**为: + +$$L(\theta | D) = P(D | \theta) = \prod_{i=1}^{n} P(x_i | \theta)$$ + +- 乘积假设数据点独立同分布(i.i.d.)。MLE估计量为: + +$$\hat{\theta}_{\text{MLE}} = \arg\max_\theta L(\theta | D)$$ + +- 实践中我们最大化**对数似然**,因为对数将乘积转化为求和,并防止数值下溢: + +$$\ell(\theta) = \log L(\theta | D) = \sum_{i=1}^{n} \log P(x_i | \theta)$$ + +- 由于 $\log$ 是单调递增函数,使得 $\ell(\theta)$ 最大的 $\theta$ 也同样使得 $L(\theta)$ 最大。 + +- **抛硬币示例**:你抛一枚硬币10次,得到7次正面。硬币偏置 $p$(正面概率)的MLE估计是多少? + +- 每次抛掷服从 Bernoulli($p$),因此10次抛掷中出现7次正面的似然为: + +$$L(p) = \binom{10}{7} p^7 (1-p)^3$$ + +- 取对数并求导:$\frac{d\ell}{dp} = \frac{7}{p} - \frac{3}{1-p} = 0$,解得 $\hat{p}_{\text{MLE}} = 7/10 = 0.7$。 + +- MLE直观且简单。如果10次抛掷中得到7次正面,最可能的偏置是0.7。但注意一个问题:如果10次抛掷中得到10次正面,MLE会得出 $\hat{p} = 1$,意味着硬币将永远正面朝上。仅凭10次观测就得出这样的结论似乎过于自信。 + +- **最大后验估计 (MAP)** 通过加入先验信念来修复这个问题。MAP不是仅最大化似然,而是最大化后验: + +$$\hat{\theta}_{\text{MAP}} = \arg\max_\theta P(\theta | D) = \arg\max_\theta P(D | \theta) \cdot P(\theta)$$ + +- 我们省略了分母 $P(D)$,因为它不依赖于 $\theta$,不影响argmax的结果。 + +- 先验 $P(\theta)$ 编码了我们在看到数据之前对 $\theta$ 的信念。如果我们使用 Beta(2, 2) 先验来表示硬币偏置(表达"硬币大致是公平的"这一温和信念),MAP估计就不再仅仅是正面的比例,而是被拉向0.5。 + +![MLE 找到似然的峰值;MAP 找到似然乘以先验的峰值](../images/mle_vs_map.svg) + +- 使用 Beta($\alpha$, $\beta$) 先验,观测到 $h$ 次正面和 $t$ 次反面后,后验为 Beta($\alpha + h$, $\beta + t$),MAP估计为: + +$$\hat{p}_{\text{MAP}} = \frac{\alpha + h - 1}{\alpha + \beta + h + t - 2}$$ + +- 对于我们的示例,Beta(2,2)先验,7次正面,3次反面:$\hat{p}_{\text{MAP}} = \frac{2 + 7 - 1}{2 + 2 + 10 - 2} = \frac{8}{12} = 0.667$。 + +- 注意MAP估计(0.667)相比MLE(0.7)如何被拉向0.5。先验起到了正则化的作用。在机器学习中,L2正则化(权重衰减)完全等价于在权重上使用高斯先验的MAP估计。 + +- **完整的贝叶斯推断**比MAP更进一步。它不是寻找单一的最佳 $\theta$,而是维护整个后验分布 $P(\theta | D)$。这不仅给你一个点估计,还给出了不确定性的度量。 + +- 对于具有Beta(2,2)先验和7次正面、3次反面的偏置硬币,完整的后验是 Beta(9, 5)。该分布的均值为 $9/14 \approx 0.643$,其弥散程度告诉我们置信度的高低。数据越多,后验越窄。 + +- 三种方法形成了一个谱系: + - **MLE**:无先验,仅依赖数据。速度快,但数据少时可能过拟合。 + - **MAP**:带先验正则化的点估计。增加鲁棒性。 + - **完整贝叶斯**:完整的后验分布。信息量最大,但通常计算成本高。 + +- **马尔可夫链**对序列进行建模,其中下一状态仅依赖于当前状态,而不依赖于历史。这种"无记忆性"称为**马尔可夫性**: + +$$P(X_{t+1} | X_t, X_{t-1}, \ldots, X_1) = P(X_{t+1} | X_t)$$ + +- 以天气为例。明天的天气取决于今天的天气,但不取决于上周的天气(这是一个简化,但出奇地有用)。 + +- 马尔可夫链具有有限个**状态**和一个**转移矩阵** $T$,其中元素 $T_{ij}$ 表示从状态 $i$ 转移到状态 $j$ 的概率。每一行之和为1。 + +![天气马尔可夫链,状态有雨天、晴天、多云,以及转移概率](../images/markov_chain.svg) + +- 对于上图的天气示例,转移矩阵为: + +```math +T = \begin{pmatrix} 0.3 & 0.4 & 0.3 \\ 0.2 & 0.5 & 0.3 \\ 0.4 & 0.3 & 0.3 \end{pmatrix} +``` + +- 如果今天是雨天(状态向量 $\mathbf{s}_0 = [1, 0, 0]$),明天天气的概率分布为 $\mathbf{s}_1 = \mathbf{s}_0 T = [0.3, 0.4, 0.3]$。两天后:$\mathbf{s}_2 = \mathbf{s}_0 T^2$。这使用了第一章中的矩阵乘法。 + +- 许多马尔可夫链会收敛到一个**平稳分布** $\pi$,满足 $\pi T = \pi$。无论从哪里出发,经过足够多的步数后,链会收敛到 $\pi$。这一性质是MCMC(马尔可夫链蒙特卡罗)的基础,MCMC是贝叶斯机器学习中广泛使用的采样技术。 + +- **隐马尔可夫模型 (HMM)** 通过增加一层间接性来扩展马尔可夫链。真实状态是隐藏的(不可观测的),每个时间步隐藏状态会发出一个可观测的信号。 + +![HMM 结构:上方隐藏状态由转移连接,下方观测由发射连接](../images/hmm_structure.svg) + +- HMM 有三个组成部分: + - **转移概率** $P(z_t | z_{t-1})$:隐藏状态如何演化(马尔可夫链) + - **发射概率** $P(x_t | z_t)$:每个隐藏状态产生什么可观测输出 + - **初始分布** $P(z_1)$:起始隐藏状态的概率 + +- **雨伞示例**:假设你不能直接看到天气,但可以观察到你的朋友是否带伞。隐藏状态为 {雨天, 晴天},观测为 {带伞, 不带伞}。 + +- 转移概率:$P(\text{雨天}|\text{雨天}) = 0.7$,$P(\text{晴天}|\text{雨天}) = 0.3$,$P(\text{雨天}|\text{晴天}) = 0.4$,$P(\text{晴天}|\text{晴天}) = 0.6$。 + +- 发射概率:$P(\text{带伞}|\text{雨天}) = 0.9$,$P(\text{不带伞}|\text{雨天}) = 0.1$,$P(\text{带伞}|\text{晴天}) = 0.2$,$P(\text{不带伞}|\text{晴天}) = 0.8$。 + +- HMM 的关键问题有: + - **解码**:给定观测,最可能的隐藏状态序列是什么?由**维特比算法**求解。 + - **评估**:观测序列的概率是多少?由**前向算法**求解。 + - **学习**:给定观测,最佳模型参数是什么?由**Baum-Welch算法**求解(期望最大化算法的一个实例)。 + +- **维特比演算**:假设你观测到 [带伞, 带伞, 不带伞],想找到最可能的天气序列。 + +- 从初始概率开始。假设 $P(R) = 0.5$,$P(S) = 0.5$。 + +- **第1天**(观测到带伞): + - $V_1(R) = P(R) \cdot P(U|R) = 0.5 \times 0.9 = 0.45$ + - $V_1(S) = P(S) \cdot P(U|S) = 0.5 \times 0.2 = 0.10$ + +- **第2天**(观测到带伞): + - $V_2(R) = \max(V_1(R) \cdot P(R|R), V_1(S) \cdot P(R|S)) \cdot P(U|R)$ + - $= \max(0.45 \times 0.7, 0.10 \times 0.4) \times 0.9 = \max(0.315, 0.04) \times 0.9 = 0.2835$ + - $V_2(S) = \max(V_1(R) \cdot P(S|R), V_1(S) \cdot P(S|S)) \cdot P(U|S)$ + - $= \max(0.45 \times 0.3, 0.10 \times 0.6) \times 0.2 = \max(0.135, 0.06) \times 0.2 = 0.027$ + +- **第3天**(观测到不带伞): + - $V_3(R) = \max(0.2835 \times 0.7, 0.027 \times 0.4) \times 0.1 = 0.1985 \times 0.1 = 0.01985$ + - $V_3(S) = \max(0.2835 \times 0.3, 0.027 \times 0.6) \times 0.8 = 0.08505 \times 0.8 = 0.06804$ + +- 第3天的最大值在晴天。回溯:第3天 = 晴天(来自R),第2天 = 雨天(来自R),第1天 = 雨天。最可能的序列:**雨天, 雨天, 晴天**。 + +- **前向-后向算法**计算在给定整个观测序列条件下,每个时间步处于每个隐藏状态的概率。前向过程计算 $P(z_t, x_{1:t})$,后向过程计算 $P(x_{t+1:T} | z_t)$。两者相乘得到平滑后的状态概率。 + +- **Baum-Welch算法**在隐藏状态不可观测时从数据中学习HMM参数。它是一种期望最大化(EM)算法:E步使用前向-后向算法估计哪些隐藏状态生成了观测,M步更新转移概率和发射概率。 + +- HMM在历史上主导了语音识别(隐藏的音素状态发出声学信号)和生物信息学(隐藏的基因状态发出DNA碱基对)。虽然深度学习在很大程度上已取代了这些领域中的HMM,但隐藏状态、发射和序列推断的思想仍然是序列模型的核心。 + +- **条件随机场 (CRF)** 通过去除发射独立假设来改进HMM。在HMM中,时间 $t$ 的观测仅依赖于时间 $t$ 的隐藏状态。CRF允许位置 $t$ 的标签依赖于整个输入序列。 + +- 线性链CRF对给定输入序列 $\mathbf{x}$ 条件下标签序列 $\mathbf{y}$ 的条件概率建模: + +$$P(\mathbf{y} | \mathbf{x}) = \frac{1}{Z(\mathbf{x})} \exp\!\left(\sum_t \left[\sum_k \lambda_k f_k(y_t, y_{t-1}, \mathbf{x}, t)\right]\right)$$ + +- 其中 $f_k$ 是特征函数(可以查看输入的任意部分),$\lambda_k$ 是学习到的权重,$Z(\mathbf{x})$ 是归一化常数。 + +- CRF是判别式模型(直接建模 $P(\mathbf{y}|\mathbf{x})$),而HMM是生成式模型(建模 $P(\mathbf{x}, \mathbf{y})$)。这一区别与逻辑回归(判别式)和朴素贝叶斯(生成式)之间的区别相同。 + +- 在现代NLP中,CRF层通常被加在神经网络之上(BiLSTM-CRF、BERT-CRF),用于命名实体识别和词性标注等需要捕捉标签依赖关系的任务。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 实现抛硬币实验的MLE和MAP。观察MAP估计如何随不同的先验和不同的数据量而变化。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 数据:观测到的硬币抛掷结果 +heads, tails = 7, 3 + +# MLE +p_mle = heads / (heads + tails) +print(f"MLE: {p_mle:.4f}") + +# 使用 Beta 先验的 MAP +for alpha, beta in [(1,1), (2,2), (5,5), (10,10)]: + p_map = (alpha + heads - 1) / (alpha + beta + heads + tails - 2) + print(f"MAP (Beta({alpha},{beta})): {p_map:.4f}") + +# 可视化 Beta(2,2) 先验下的后验 +theta = jnp.linspace(0.01, 0.99, 200) +# 后验为 Beta(alpha+heads, beta+tails) +a_post, b_post = 2 + heads, 2 + tails +posterior = theta**(a_post-1) * (1-theta)**(b_post-1) +posterior = posterior / jnp.trapezoid(posterior, theta) + +plt.figure(figsize=(8, 4)) +plt.plot(theta, posterior, color="#e74c3c", linewidth=2, label=f"后验 Beta({a_post},{b_post})") +plt.axvline(p_mle, color="#3498db", linestyle="--", label=f"MLE = {p_mle:.2f}") +plt.axvline((a_post-1)/(a_post+b_post-2), color="#e74c3c", linestyle="--", label=f"MAP = {(a_post-1)/(a_post+b_post-2):.3f}") +plt.xlabel("θ (硬币偏置)") +plt.ylabel("密度") +plt.title("7次正面、3次反面后 Beta(2,2) 先验下的后验分布") +plt.legend() +plt.grid(alpha=0.3) +plt.show() +``` + +2. 为天气模型构建一个马尔可夫链并进行模拟。分别通过模拟和求解 $\pi T = \pi$ 计算平稳分布。 +```python +import jax +import jax.numpy as jnp + +# 转移矩阵:R(雨天), S(晴天), C(多云) +T = jnp.array([ + [0.3, 0.4, 0.3], + [0.2, 0.5, 0.3], + [0.4, 0.3, 0.3] +]) +states = ["雨天", "晴天", "多云"] + +# 模拟 100,000 步 +key = jax.random.PRNGKey(42) +n_steps = 100_000 +state = 0 # 从雨天开始 +counts = jnp.zeros(3) + +for i in range(n_steps): + key, subkey = jax.random.split(key) + state = jax.random.choice(subkey, 3, p=T[state]) + counts = counts.at[state].add(1) + +sim_stationary = counts / n_steps +print("模拟得到的平稳分布:") +for s, p in zip(states, sim_stationary): + print(f" {s}: {p:.4f}") + +# 解析法:找到特征值为1的左特征向量 +eigenvalues, eigenvectors = jnp.linalg.eig(T.T) +idx = jnp.argmin(jnp.abs(eigenvalues - 1.0)) +pi = jnp.real(eigenvectors[:, idx]) +pi = pi / pi.sum() +print("\n解析得到的平稳分布:") +for s, p in zip(states, pi): + print(f" {s}: {p:.4f}") +``` + +3. 为雨伞HMM实现维特比算法,并解码一个观测序列。 +```python +import jax.numpy as jnp + +# HMM 参数 +states = ["雨天", "晴天"] +obs_names = ["带伞", "不带伞"] + +trans = jnp.array([[0.7, 0.3], # R->R, R->S + [0.4, 0.6]]) # S->R, S->S + +emit = jnp.array([[0.9, 0.1], # R->带伞, R->不带伞 + [0.2, 0.8]]) # S->带伞, S->不带伞 + +init = jnp.array([0.5, 0.5]) + +# 观测:带伞=0,不带伞=1 +observations = [0, 0, 1] # 带伞, 带伞, 不带伞 + +def viterbi(obs, init, trans, emit): + n_states = len(init) + T = len(obs) + V = jnp.zeros((T, n_states)) + path = jnp.zeros((T, n_states), dtype=int) + + # 初始化 + V = V.at[0].set(init * emit[:, obs[0]]) + + # 递推 + for t in range(1, T): + for j in range(n_states): + probs = V[t-1] * trans[:, j] + V = V.at[t, j].set(jnp.max(probs) * emit[j, obs[t]]) + path = path.at[t, j].set(jnp.argmax(probs)) + + # 回溯 + best = [int(jnp.argmax(V[-1]))] + for t in range(T-1, 0, -1): + best.insert(0, int(path[t, best[0]])) + return best, V + +decoded, scores = viterbi(observations, init, trans, emit) +print("观测序列:", [obs_names[o] for o in observations]) +print("解码结果:", [states[s] for s in decoded]) +``` + +4. 可视化随着观测更多抛硬币结果,后验如何演化。从 Beta(1,1) 先验(均匀分布)开始,每次抛掷后更新后验。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +theta = jnp.linspace(0.01, 0.99, 300) +key = jax.random.PRNGKey(7) + +# 真实偏置 = 0.65 +flips = jax.random.bernoulli(key, p=0.65, shape=(50,)) + +plt.figure(figsize=(10, 5)) +a, b = 1, 1 # Beta(1,1) = 均匀分布 + +for n_obs in [0, 1, 5, 10, 25, 50]: + h = int(flips[:n_obs].sum()) + t = n_obs - h + a_post = a + h + b_post = b + t + y = theta**(a_post-1) * (1-theta)**(b_post-1) + y = y / jnp.trapezoid(y, theta) + plt.plot(theta, y, linewidth=2, label=f"n={n_obs} (h={h})") + +plt.axvline(0.65, color="black", linestyle=":", alpha=0.5, label="真实 p=0.65") +plt.xlabel("θ") +plt.ylabel("密度") +plt.title("贝叶斯更新:数据越多后验越窄") +plt.legend() +plt.grid(alpha=0.3) +plt.show() +``` diff --git a/chapter 05: probability/05. information theory.md b/chapter 05: probability/05. information theory.md new file mode 100644 index 0000000..d241270 --- /dev/null +++ b/chapter 05: probability/05. information theory.md @@ -0,0 +1,191 @@ +# 信息论 + +*信息论量化了信息、惊奇度以及概率分布之间的差异。本文涵盖熵、交叉熵、KL散度、互信息和自信息——这些概念是机器学习中每一个分类损失函数、VAE目标和数据压缩方案背后的理论基础。* + +- 信息论由克劳德·香农于1948年创立,为我们提供了量化信息的数学框架。它回答了诸如此类的问题:一个事件应当让你感到多惊讶?一条消息携带了多少信息?两个概率分布之间有多大的差异? + +- 这些问题看似抽象,但它们是机器学习损失函数、数据压缩和通信系统的基础。交叉熵损失——分类中最常见的损失函数——直接源于信息论。 + +- 从最简单的问题开始:单个事件携带了多少信息? + +- **自信息**(surprisal,也称 self-information)衡量一个事件的惊奇程度。如果某件极有可能发生的事情真的发生了,你几乎学不到任何新信息。如果某件罕见的事情发生了,你则会获得大量信息。 + +- 如果你住在沙漠里,有人告诉你今天是大晴天,这并没有什么信息量。但如果他们告诉你正在下雪,那信息量就极其丰富。自信息将这种直觉形式化: + +$$I(x) = \log_2 \frac{1}{p(x)} = -\log_2 p(x)$$ + +- 使用 $\log_2$ 时,单位是**比特**。一枚公平的硬币抛掷的自信息为 $-\log_2(0.5) = 1$ 比特。一个概率为 $1/8$ 的事件具有 $ \log_2(8) = 3$ 比特的自信息。 + +- 为什么用对数而不是简单的 $1/p$?三个原因: + - 必然事件($p = 1$)应给出零信息:$\log(1) = 0$ 但 $1/1 = 1$。 + - 独立事件的信息应该是可加的:$\log(1/p_1 p_2) = \log(1/p_1) + \log(1/p_2)$。 + - 我们需要一个平滑、性质良好的函数。$1/p$ 会爆炸;$\log(1/p)$ 则平缓增长。 + +- **熵**是自信息的期望值,即从一个分布中每次采样获得的平均信息量。它衡量该分布的不确定性或"不可预测性": + +$$H(X) = E[I(X)] = -\sum_{x} p(x) \log_2 p(x)$$ + +![柱状图展示高概率事件具有低自信息,反之亦然;熵是加权平均值](../images/surprisal_entropy.svg) + +- 一枚公平硬币的熵为 $H = -0.5\log_2(0.5) - 0.5\log_2(0.5) = 1$ 比特。不确定性最大。 + +- 一枚偏倚硬币,$p = 0.9$,其熵为 $H = -0.9\log_2(0.9) - 0.1\log_2(0.1) \approx 0.469$ 比特。不太确定,因此熵更小。 + +- 一个确定性事件($p = 1$)的熵为 $H = 0$。完全没有不确定性。 + +- 当所有结果等可能时,熵达到最大。对于 $n$ 个等可能结果,$H = \log_2 n$。一颗公平骰子的熵为 $\log_2 6 \approx 2.585$ 比特。 + +- 熵的实际意义在于**压缩**。香农的源编码定理指出,如果不丢失信息,你无法将数据压缩到低于其熵率。一幅每个像素都等可能的图像(最大熵)无法压缩。一幅几乎全是白色的图像(低熵)则可以很好地压缩。 + +- 快速感受一下数量级:一个灰度像素(256 个值)的最大熵为 8 比特。一张 1080p 的灰度图像最多有 $1920 \times 1080 \times 8 \approx 1660$ 万比特。真实图像的熵要低得多,因为相邻像素是相关的——这正是 JPEG 压缩能够工作的原因。 + +- 对于连续随机变量,离散求和变为积分。**微分熵**定义为: + +$$h(X) = -\int_{-\infty}^{\infty} f(x) \log f(x)\, dx$$ + +- 方差为 $\sigma^2$ 的高斯分布的微分熵为 $h = \frac{1}{2}\log_2(2\pi e \sigma^2)$。在所有具有相同方差的分布中,高斯分布具有最大熵。这也是高斯分布在建模中如此常见的原因之一:它在指定均值和方差之外做出了最少的假设。 + +- **互信息**衡量知道一个变量能告诉你关于另一个变量的多少信息。它是观察到 $Y$ 后 $X$ 不确定性的减少量: + +$$I(X; Y) = H(X) - H(X|Y) = H(Y) - H(Y|X)$$ + +- 等价形式: + +$$I(X; Y) = \sum_{x,y} p(x,y) \log_2 \frac{p(x,y)}{p(x) p(y)}$$ + +- 如果 $X$ 和 $Y$ 独立,则 $p(x,y) = p(x)p(y)$,互信息为零。它们依赖程度越高,互信息就越大。 + +- 在机器学习中,互信息用于特征选择(挑选与目标具有高 MI 的特征)、信息瓶颈方法以及聚类质量评估。 + +- **交叉熵**衡量使用针对分布 $q$ 优化的编码方案来编码来自分布 $p$ 的事件所需的平均比特数: + +$$H(p, q) = -\sum_{x} p(x) \log_2 q(x)$$ + +- 如果 $q$ 与 $p$ 完全匹配,则交叉熵等于熵:$H(p, p) = H(p)$。如果 $q$ 是一个糟糕的近似,交叉熵就会更高。"额外"的比特来自这种不匹配。 + +- 这正是交叉熵成为机器学习中分类标准损失函数的原因。真实标签定义了 $p$(一个 one-hot 分布),模型的预测概率定义了 $q$。最小化交叉熵推动 $q$ 趋近于 $p$: + +$$\mathcal{L} = -\sum_{c} y_c \log \hat{y}_c$$ + +- 对于单个样本,若真实类别为 $c$,上式简化为 $\mathcal{L} = -\log \hat{y}_c$。该损失就是模型预测下真实类别的自信息。如果模型对正确类别赋予高概率,则损失较低。 + +- **KL 散度**(Kullback-Leibler 散度,也称相对熵)衡量一个分布与另一个分布的差异程度: + +$$D_{\text{KL}}(p \| q) = \sum_{x} p(x) \log \frac{p(x)}{q(x)} = H(p, q) - H(p)$$ + +- KL 散度是"使用分布 $q$ 而非真实分布 $p$ 的额外代价"。它总是非负的($D_{\text{KL}} \ge 0$),且仅在 $p = q$ 时为零。 + +![两个分布 p 和 q,它们之间的间隙表示 KL 散度](../images/kl_divergence.svg) + +- KL 散度不是对称的:$D_{\text{KL}}(p \| q) \ne D_{\text{KL}}(q \| p)$。这种不对称性很重要。$D_{\text{KL}}(p \| q)$ 惩罚 $q$ 在 $p$ 具有高概率处放置低概率(因为 $\log(p/q)$ 会趋于无穷大)。$D_{\text{KL}}(q \| p)$ 则惩罚相反的情况。 + +- 这种不对称性导致了两种近似风格: + - 最小化 $D_{\text{KL}}(p \| q)$ 产生**矩匹配**行为:$q$ 覆盖 $p$ 的所有模态,但可能过于分散。 + - 最小化 $D_{\text{KL}}(q \| p)$ 产生**模式寻找**行为:$q$ 集中于 $p$ 的某一个模态,但可能错过其他模态。变分推断使用的正是这一种。 + +- 由于 $H(p)$ 相对于模型是常数,最小化交叉熵 $H(p, q)$ 等价于最小化 $D_{\text{KL}}(p \| q)$。这就是为什么我们可以使用交叉熵损失,同时知道我们也在最小化真实分布与预测分布之间的 KL 散度。 + +- KL 散度在**贝叶斯更新**中扮演着核心角色。后验 $P(\theta | D)$ 是在 KL 散度意义上与先验 $P(\theta)$ 最接近且与观测数据一致的分布。每一次新的观测都会更新后验,减少关于 $\theta$ 的不确定性。 + +- 在变分自编码器(VAE)中,损失函数包含两项:重构损失(交叉熵)和一个 KL 散度项,后者对潜在空间进行正则化,使其保持接近标准正态分布。 + +- 将所有概念联系起来:熵告诉你一个分布内在的不确定性,交叉熵告诉你的模型对现实的近似程度,而 KL 散度则告诉你两者之间的差距。这三个量构成了现代机器学习优化的基石。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 计算各种分布的熵,并验证在给定结果数量下,均匀分布的熵最大。 +```python +import jax.numpy as jnp + +def entropy(p): + """以比特为单位计算熵。过滤掉概率为零的事件。""" + p = p[p > 0] + return -jnp.sum(p * jnp.log2(p)) + +# 公平骰子 +fair = jnp.ones(6) / 6 +print(f"公平骰子熵: {entropy(fair):.4f} 比特 (最大 = log2(6) = {jnp.log2(6.):.4f})") + +# 灌铅骰子 +loaded = jnp.array([0.1, 0.1, 0.1, 0.1, 0.1, 0.5]) +print(f"灌铅骰子熵: {entropy(loaded):.4f} 比特") + +# 确定性 +det = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 1.0]) +print(f"确定性: {entropy(det):.4f} 比特") + +# 公平硬币 +coin = jnp.array([0.5, 0.5]) +print(f"公平硬币熵: {entropy(coin):.4f} 比特") +``` + +2. 计算真实分布与多个近似分布之间的交叉熵和 KL 散度。验证 $D_{\text{KL}}(p \| q) = H(p, q) - H(p)$。 +```python +import jax.numpy as jnp + +def cross_entropy(p, q): + return -jnp.sum(p * jnp.log2(jnp.clip(q, 1e-10, 1.0))) + +def kl_divergence(p, q): + mask = p > 0 + return jnp.sum(jnp.where(mask, p * jnp.log2(p / jnp.clip(q, 1e-10, 1.0)), 0.0)) + +def entropy(p): + p = p[p > 0] + return -jnp.sum(p * jnp.log2(p)) + +p = jnp.array([0.4, 0.3, 0.2, 0.1]) # 真实分布 + +for name, q in [("完全匹配", p), + ("轻微偏差", jnp.array([0.35, 0.30, 0.25, 0.10])), + ("严重偏差", jnp.array([0.1, 0.1, 0.1, 0.7]))]: + h_p = entropy(p) + h_pq = cross_entropy(p, q) + kl = kl_divergence(p, q) + print(f"{name:20s}: H(p)={h_p:.4f}, H(p,q)={h_pq:.4f}, " + f"KL={kl:.4f}, H(p,q)-H(p)={h_pq-h_p:.4f}") +``` + +3. 通过计算两个不同分布之间的 $D_{\text{KL}}(p \| q)$ 和 $D_{\text{KL}}(q \| p)$,证明 KL 散度不是对称的。 +```python +import jax.numpy as jnp + +def kl_div(p, q): + mask = p > 0 + return float(jnp.sum(jnp.where(mask, p * jnp.log2(p / jnp.clip(q, 1e-10, 1.0)), 0.0))) + +p = jnp.array([0.9, 0.1]) +q = jnp.array([0.5, 0.5]) + +print(f"D_KL(p || q) = {kl_div(p, q):.4f}") +print(f"D_KL(q || p) = {kl_div(q, p):.4f}") +print("不相同!KL 散度是不对称的。") +``` + +4. 模拟训练过程中交叉熵损失的变化。创建一个"真实"的 one-hot 标签,展示随着模型预测概率的改善,损失如何下降。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 真实标签:4 个类别中的第 2 类 +true_label = jnp.array([0, 0, 1, 0]) + +# 模拟预测逐步改善 +steps = [] +losses = [] +for confidence in jnp.linspace(0.25, 0.99, 50): + # 模型对类别 2 的置信度逐渐提高 + remaining = (1 - confidence) / 3 + pred = jnp.array([remaining, remaining, confidence, remaining]) + loss = -jnp.sum(true_label * jnp.log(jnp.clip(pred, 1e-10, 1.0))) + steps.append(float(confidence)) + losses.append(float(loss)) + +plt.figure(figsize=(8, 4)) +plt.plot(steps, losses, color="#e74c3c", linewidth=2) +plt.xlabel("模型对真实类别的置信度") +plt.ylabel("交叉熵损失") +plt.title("交叉熵损失随预测改善而下降") +plt.grid(alpha=0.3) +plt.show() +``` diff --git a/chapter 06: machine learning/01. classical machine learning.md b/chapter 06: machine learning/01. classical machine learning.md new file mode 100644 index 0000000..4a70b63 --- /dev/null +++ b/chapter 06: machine learning/01. classical machine learning.md @@ -0,0 +1,381 @@ +# 经典机器学习 + +*经典机器学习算法通过数据学习模式而无需显式编程,使用闭式解或启发式搜索而非梯度下降。本文涵盖朴素贝叶斯、k-NN、决策树、随机森林、支持向量机、k-means聚类和主成分分析* + +- 机器学习是研究算法通过从数据中学习来提升其在某项任务上表现的学科,而非通过显式规则编程。与其编写"如果收入 > 50k 且年龄 < 30 则批准贷款",不如将数千条历史贷款决策交给算法,让它自行找出模式。 + +- 存在三大范式。**监督学习**使用带标签数据,即每个输入都有已知的正确输出。算法学习从输入到输出的映射。**无监督学习**处理未标签数据,试图发现隐藏结构,如聚类或压缩表示。**强化学习**通过试错学习,根据在环境中采取的动作接收奖励或惩罚(在第04篇中介绍)。 + +- 在监督学习中,**分类**预测离散类别(垃圾邮件或非垃圾邮件,猫或狗),而**回归**预测连续值(房价、明天温度)。边界并不总是清晰:逻辑回归虽然名为"回归",但实际上执行分类任务。 + +- 概率模型中的一个关键区分是**生成式 vs 判别式**。生成模型学习联合分布 $P(x, y)$,这意味着它理解数据本身的生成方式。它能产生新样本。判别模型直接学习 $P(y \mid x)$,仅关注类别之间的边界。朴素贝叶斯是生成式的;逻辑回归(第02篇)是判别式的。生成模型更灵活但更难训练好;判别模型在数据充足时通常给出更好的分类准确率。 + +- **朴素贝叶斯**是最简单且最有效的分类器之一。它直接应用贝叶斯定理(来自第05章): + +$$P(C_k \mid x) = \frac{P(x \mid C_k) \, P(C_k)}{P(x)}$$ + +- "朴素"之处在于一个强烈的独立性假设:它假设给定类别后每个特征相互独立。如果你正在将电子邮件分类为垃圾邮件,朴素贝叶斯假设一旦你知道邮件是垃圾邮件,单词"免费"的出现告诉你关于单词"赢家"是否出现的信息为零。这在现实中几乎从不成立,但分类器仍然出奇地好用。 + +- 由于 $P(x)$ 对所有类别都一样,分类简化为选择最大化分子的类别: + +$$\hat{y} = \arg\max_{k} \; P(C_k) \prod_{i=1}^{n} P(x_i \mid C_k)$$ + +- 先验 $P(C_k)$ 就是每个类别中训练样本的比例。似然 $P(x_i \mid C_k)$ 取决于特征的类型,从而产生三种常见变体。 + +- **多项式朴素贝叶斯**专为计数数据设计,如文档中的词频。每个特征 $x_i$ 表示单词 $i$ 出现的次数,似然遵循多项分布。这是文本分类、情感分析和垃圾邮件过滤的标准选择。 + +- **高斯朴素贝叶斯**假设每个特征在每个类别内服从正态分布。你从训练数据中估计特征 $i$ 对类别 $k$ 的均值 $\mu_{ik}$ 和方差 $\sigma_{ik}^2$,然后计算: + +$$P(x_i \mid C_k) = \frac{1}{\sqrt{2\pi\sigma_{ik}^2}} \exp\!\left(-\frac{(x_i - \mu_{ik})^2}{2\sigma_{ik}^2}\right)$$ + +- 当特征为连续测量值时,如身高、体重或传感器读数,这是自然的选择。 + +![两个重叠的高斯类条件分布,后验概率交叉处的决策边界](../images/naive_bayes_classify.svg) + +- **伯努利朴素贝叶斯**对二元特征建模:每个特征要么存在(1)要么不存在(0)。你不再统计单词出现的次数,而是只跟踪它是否出现。这适用于短文本或二元特征向量。 + +- 一个实际问题是,当某个特征值在训练数据中从未与某个类别一起出现时,似然变为零,由于所有概率相乘,整个后验概率也归零。**拉普拉斯平滑**通过为每个特征-类别组合添加一个小计数(通常为1)来解决这个问题: + +$$P(x_i \mid C_k) = \frac{\text{count}(x_i, C_k) + \alpha}{\text{count}(C_k) + \alpha \cdot V}$$ + +- 这里 $\alpha$ 是平滑参数(通常为1),$V$ 是该特征的可能取值数量。这确保了任何概率永远不会精确为零。 + +- **决策树**采用了一种完全不同的方法。它不是计算概率,而是通过一系列的"是/否"问题来划分特征空间。想象"二十问"游戏:每一步,你问一个最能缩小可能性范围的问题。 + +- 树从根节点开始,包含所有训练样本。在每个内部节点,它选择一个特征和一个阈值进行分裂(例如,"年龄 < 30?")。样本根据答案向左或向右流动。这一过程递归进行直到叶节点,叶节点中存放预测结果:分类任务中的多数类别,或回归任务中的均值。 + +![深度为2的决策树,特征分裂、是/否分支和彩色叶节点显示类别预测](../images/decision_tree_split.svg) + +- 关键问题是:应该选择哪个特征进行分裂?你希望分裂产生最"纯"的子节点,即大多数样本属于同一类别。衡量不纯度的两种常用指标是**基尼不纯度**和**熵**。 + +- **基尼不纯度**衡量的是如果按照该节点中的分布标记,随机选择的样本被错误分类的概率: + +$$\text{Gini}(S) = 1 - \sum_{k=1}^{K} p_k^2$$ + +- 如果节点完全纯(全部属于一个类别),基尼值为0。如果类别完全平衡(比如两类各占50%),基尼值达到最大值0.5。 + +- **熵**(来自第05章的信息论部分)衡量平均惊讶程度: + +$$H(S) = -\sum_{k=1}^{K} p_k \log_2 p_k$$ + +- 纯节点的熵为0。完全平衡的二元节点的熵为1比特。实际上,基尼和熵产生的树非常相似;基尼计算稍快,因为它避免了对数运算。 + +- **信息增益**是由一次分裂带来的不纯度降低。对于将集合 $S$ 划分为子集 $S_L$ 和 $S_R$ 的分裂: + +$$\text{IG}(S, \text{split}) = H(S) - \frac{|S_L|}{|S|} H(S_L) - \frac{|S_R|}{|S|} H(S_R)$$ + +- 算法在每一节点贪心地选择信息增益最高的分裂。这是一种局部最优策略,而非全局最优,但在实践中效果很好。 + +- **回归树**工作原理相同,但叶子预测连续值(到达该叶子的样本的均值),分裂准则使用方差减少而非基尼或熵。 + +- 如果不加约束,决策树会一直分裂直到每个叶子都纯,本质上是在记忆训练数据。这是严重的过拟合。**剪枝**用于应对这一问题。预剪枝在树生长之前设置限制:最大深度、每个叶子的最少样本数、或进行分裂的最小信息增益。后剪枝先生长完整树,然后移除那些不能提升验证集性能的分支。 + +- 单个决策树易于解释,但往往不稳定:数据的微小变化可能导致完全不同的树。**集成方法**组合多个模型,以获得比任何单个模型更好的预测结果。 + +- 核心思想是"群众智慧"。如果你问100个平庸的分类器然后进行多数投票,只要各个分类器做出一定程度上独立的错误,集成结果可以非常出色。 + +- **Bagging**(自助汇聚法)在数据的不同随机子集上训练多个模型,采用有放回抽样(bootstrap样本)。每个模型大约看到原始数据的63%。在预测时,你对输出取平均(回归)或进行多数投票(分类)。由于每个模型看到不同的数据,它们犯不同的错误,平均操作抵消了大部分方差。 + +- **随机森林**是将bagging应用于决策树并增加一个额外技巧:在每个分裂处,树只考虑一个随机的特征子集(通常是从 $d$ 个总特征中选 $\sqrt{d}$ 个)。这进一步去除了树之间的相关性,使集成更强大。随机森林是整个机器学习中最可靠的现成分类器之一。 + +![并排对比:bagging并行训练模型并取平均,boosting顺序训练模型并纠正之前的错误](../images/ensemble_methods.svg) + +- **Boosting**采取了相反的哲学。它不是独立地训练模型,而是顺序地训练,每个新模型专注于之前模型分类错误的样本。 + +- **AdaBoost**(自适应提升)为每个训练样本维护一个权重。最初所有权重相等。训练一个弱学习器(通常是深度很浅的决策树,称为"桩")后,被错误分类的样本获得更高的权重,因此下一个学习器更加关注它们。最终预测是所有学习器的加权投票,表现更好的学习器拥有更大的发言权: + +$$H(x) = \text{sign}\!\left(\sum_{t=1}^{T} \alpha_t \, h_t(x)\right)$$ + +- 学习器 $t$ 的权重 $\alpha_t$ 取决于其错误率 $\epsilon_t$: + +$$\alpha_t = \frac{1}{2} \ln\!\left(\frac{1 - \epsilon_t}{\epsilon_t}\right)$$ + +- 错误率低的学习器获得大的正权重;表现与随机水平持平($\epsilon = 0.5$)的学习器获得零权重。 + +- **梯度提升**推广了这一思想。不同于重新加权样本,每个新模型被训练来预测当前集成整体的残差误差(损失函数的负梯度)。对于平方误差损失,残差就是预测值与目标值之间的差值。基于决策树的梯度提升(GBDT)是结构化数据竞赛中许多获胜方案背后的方法(XGBoost、LightGBM、CatBoost是流行的实现)。 + +- 关键对比:bagging降低**方差**(通过平均消除噪声),而boosting降低**偏差**(纠正系统性错误)。Bagging在个别模型过拟合时效果最好;boosting在模型欠拟合时效果最好。 + +- 转向无监督学习,**K-Means聚类**是最简单且使用最广泛的聚类算法。给定 $n$ 个数据点和目标聚类数 $K$,它通过最小化每个点到其聚类中心的距离总和,将每个点分配给 $K$ 个组之一。 + +- 算法交替进行两个步骤。首先,将每个点**分配**到最近的中心点。其次,将每个中心点**更新**为分配给它的所有点的均值。重复直到分配不再变化。这保证收敛,因为每一步总簇内距离都会减小(或保持不变)。 + +![具有三个彩色点簇、中心点标记和虚线簇边界的2D散点图](../images/kmeans_clustering.svg) + +- 形式上,K-Means最小化簇内平方和,称为**惯性**: + +$$J = \sum_{k=1}^{K} \sum_{x \in C_k} \|x - \mu_k\|^2$$ + +- 其中 $\mu_k$ 是簇 $C_k$ 的中心点。 + +- K-Means对初始化敏感。糟糕的起始中心点可能导致较差的局部最小值。**K-Means++** 初始化策略首先随机选择一个中心点,然后每个后续中心点的选择概率与其距离最近现有中心点的平方距离成正比。这分散了初始中心点,几乎总是能给出更好的结果。 + +- 如何选择 $K$?两种常用工具。**肘部法**绘制惯性随 $K$ 变化的曲线,寻找"肘部"——增加更多簇不再显著帮助的点。**轮廓系数**衡量一个点与其自身簇的相似度相对于最近其他簇的相似度,范围从-1(错误簇)到+1(良好聚类)。所有点的平均轮廓系数给出了聚类质量的整体衡量。 + +- K-Means有局限性:它假设大致相等大小的球形簇,并且它做出"硬"分配(每个点恰好属于一个簇)。**高斯混合模型(GMM)** 放松了这两个限制。 + +- GMM将数据建模为 $K$ 个高斯分布的混合,每个分布有自己的均值 $\mu_k$、协方差 $\Sigma_k$ 和混合权重 $\pi_k$(所有权重之和为1): + +$$P(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k)$$ + +- 不同于硬分配,每个点得到一个**软分配**:它属于每个簇的概率(称为"责任")。位于两个高斯边界附近的点可能是60%属于簇A,40%属于簇B。 + +- GMM使用**期望-最大化(EM)算法**进行拟合,该算法交替两个步骤,与K-Means非常类似。**E步**计算责任:对于每个点,它来自每个高斯的概率是多少?**M步**更新参数:给定责任,最佳的均值、协方差和混合权重是什么?EM保证每次迭代增加数据似然,并收敛到局部最大值。 + +- K-Means实际上是GMM的EM算法的一个特例:它对应于具有相等协方差的球形高斯和硬(0/1)责任分配。 + +- **支持向量机(SVM)** 从几何视角处理分类问题。给定两个线性可分的类别,存在无限多个超平面可以将它们分开。SVM找到**最大间隔**的那个——超平面与每个类别最近数据点之间的最大可能间隙。 + +- 最近的点,即恰好位于间隔边缘的点,称为**支持向量**。它们是定义决策边界唯一重要的点;你可以移除所有其他训练点,仍然得到相同的超平面。 + +![两个类别被最大间隔超平面分开,带有间隔带和圈出的支持向量](../images/svm_margin.svg) + +- 对于线性分类器 $f(x) = w \cdot x + b$,找到最大间隔等价于求解: + +$$\min_{w, b} \; \frac{1}{2}\|w\|^2 \quad \text{subject to} \quad y_i(w \cdot x_i + b) \geq 1 \; \text{for all } i$$ + +- 这是一个凸二次规划问题,因此有唯一的全局解(无需担心局部最小值)。 + +- 真实数据很少完美可分。**软间隔SVM** 通过引入松弛变量 $\xi_i \geq 0$ 允许一些点违反间隔: + +$$\min_{w, b, \xi} \; \frac{1}{2}\|w\|^2 + C \sum_{i=1}^{n} \xi_i \quad \text{subject to} \quad y_i(w \cdot x_i + b) \geq 1 - \xi_i$$ + +- 超参数 $C$ 控制权衡:大的 $C$ 对错误分类施加高惩罚(更紧的拟合,有过拟合风险),小的 $C$ 允许更多违规(更宽的间隔,更强的正则化)。 + +- SVM最强大的特性是**核技巧**。许多在原始特征空间中不是线性可分的数据集,在映射到高维空间后变得可分。核技巧让你能够在那个高维空间中计算点积,而无需显式计算变换。 + +- 核函数 $K(x_i, x_j) = \phi(x_i) \cdot \phi(x_j)$ 替换SVM优化中的每个点积。最流行的核是**径向基函数(RBF)核**: + +$$K(x_i, x_j) = \exp\!\left(-\gamma \|x_i - x_j\|^2\right)$$ + +- RBF核隐式地将数据映射到无限维空间。参数 $\gamma$ 控制单个训练点的影响范围:大的 $\gamma$ 意味着每个点只影响其紧邻区域(过拟合风险),小的 $\gamma$ 给出更平滑的边界。 + +- 其他常见核包括多项式核 $K(x_i, x_j) = (x_i \cdot x_j + c)^d$ 和线性核 $K(x_i, x_j) = x_i \cdot x_j$(即没有任何变换的标准SVM)。 + +- 实际上,带RBF核的SVM在深度学习出现之前是主导分类器。它们在中小规模数据集上仍然表现良好,特别是当特征数量相对于样本数量较大时。 + +- SVM与第02章(矩阵)的联系很深。优化通常以其对偶形式求解,其中解仅依赖于训练样本之间的点积——这正是使核技巧成为可能的原因。整个算法以内积和线性代数的语言运作。 + +- 汇总经典ML工具箱: + +| 算法 | 类型 | 关键优势 | 关键劣势 | +|---|---|---|---| +| 朴素贝叶斯 | 监督(生成式) | 快速,少量数据即可工作 | 独立性假设 | +| 决策树 | 监督 | 可解释 | 容易过拟合 | +| 随机森林 | 监督(集成) | 稳健,超参数少 | 可解释性较差 | +| 梯度提升 | 监督(集成) | 表格数据上的最优水平 | 较慢,调参更多 | +| K-Means | 无监督(聚类) | 简单,可扩展 | 假设球形簇 | +| GMM | 无监督(聚类) | 软分配,形状灵活 | 对初始化敏感 | +| SVM | 监督 | 高维有效 | 大数据集上慢 | + +## 编程任务(在CoLab或笔记本中完成) + +1. 从头实现高斯朴素贝叶斯。在合成二维数据(两个类别)上训练并可视化决策边界。与scikit-learn的实现进行比较。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt +from sklearn.datasets import make_classification + +# 生成合成数据 +X, y = make_classification(n_samples=300, n_features=2, n_redundant=0, + n_informative=2, n_clusters_per_class=1, random_state=42) +X, y = jnp.array(X), jnp.array(y) + +# 从头拟合高斯朴素贝叶斯 +classes = jnp.unique(y) +params = {} +for c in classes: + c = int(c) + mask = y == c + X_c = X[mask] + params[c] = { + 'mean': jnp.mean(X_c, axis=0), + 'var': jnp.var(X_c, axis=0), + 'prior': jnp.sum(mask) / len(y) + } + +def gaussian_log_likelihood(x, mean, var): + return -0.5 * jnp.sum(jnp.log(2 * jnp.pi * var) + (x - mean)**2 / var) + +def predict(X): + preds = [] + for x in X: + log_posts = [] + for c in [0, 1]: + log_post = jnp.log(params[c]['prior']) + gaussian_log_likelihood( + x, params[c]['mean'], params[c]['var']) + log_posts.append(log_post) + preds.append(jnp.argmax(jnp.array(log_posts))) + return jnp.array(preds) + +# 决策边界可视化 +xx, yy = jnp.meshgrid(jnp.linspace(X[:,0].min()-1, X[:,0].max()+1, 200), + jnp.linspace(X[:,1].min()-1, X[:,1].max()+1, 200)) +grid = jnp.column_stack([xx.ravel(), yy.ravel()]) +zz = predict(grid).reshape(xx.shape) + +plt.figure(figsize=(8, 6)) +plt.contourf(xx, yy, zz, alpha=0.3, cmap='coolwarm') +plt.scatter(X[y==0, 0], X[y==0, 1], c='#3498db', label='Class 0', edgecolors='k', s=20) +plt.scatter(X[y==1, 0], X[y==1, 1], c='#e74c3c', label='Class 1', edgecolors='k', s=20) +plt.title("Gaussian Naive Bayes Decision Boundary") +plt.legend() +plt.grid(alpha=0.3) +plt.show() + +accuracy = jnp.mean(predict(X) == y) +print(f"Training accuracy: {accuracy:.2%}") +``` + +2. 构建一个使用基尼不纯度进行分裂的决策树。实现单个节点的分裂逻辑,并展示信息增益如何选择最佳特征和阈值。 +```python +import jax.numpy as jnp + +def gini_impurity(y): + """计算标签数组的基尼不纯度。""" + classes, counts = jnp.unique(y, return_counts=True) + probs = counts / len(y) + return 1.0 - jnp.sum(probs ** 2) + +def information_gain(y, left_mask): + """通过布尔掩码将y分裂为左/右后的信息增益。""" + parent_gini = gini_impurity(y) + left_y, right_y = y[left_mask], y[~left_mask] + n = len(y) + if len(left_y) == 0 or len(right_y) == 0: + return 0.0 + child_gini = (len(left_y)/n) * gini_impurity(left_y) + \ + (len(right_y)/n) * gini_impurity(right_y) + return float(parent_gini - child_gini) + +def best_split(X, y): + """找到最大化信息增益的特征和阈值。""" + best_ig, best_feat, best_thresh = -1, None, None + for feat in range(X.shape[1]): + thresholds = jnp.unique(X[:, feat]) + for thresh in thresholds: + mask = X[:, feat] <= float(thresh) + ig = information_gain(y, mask) + if ig > best_ig: + best_ig, best_feat, best_thresh = ig, feat, float(thresh) + return best_feat, best_thresh, best_ig + +# 示例:合成数据 +from sklearn.datasets import make_classification +X, y = make_classification(n_samples=100, n_features=4, n_redundant=0, random_state=0) +X, y = jnp.array(X), jnp.array(y) + +feat, thresh, ig = best_split(X, y) +print(f"Best split: feature {feat}, threshold {thresh:.3f}, info gain {ig:.4f}") +print(f"Parent Gini: {gini_impurity(y):.4f}") +mask = X[:, feat] <= thresh +print(f"Left Gini: {gini_impurity(y[mask]):.4f} ({int(jnp.sum(mask))} samples)") +print(f"Right Gini: {gini_impurity(y[~mask]):.4f} ({int(jnp.sum(~mask))} samples)") +``` + +3. 从头实现带K-Means++初始化的K-Means。对合成数据集进行聚类并可视化每次迭代的簇。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +from sklearn.datasets import make_blobs + +# 生成合成簇 +X, y_true = make_blobs(n_samples=300, centers=4, cluster_std=0.8, random_state=42) +X = jnp.array(X) + +def kmeans_plus_plus_init(X, K, key): + """K-Means++初始化。""" + n = X.shape[0] + idx = jax.random.randint(key, (), 0, n) + centroids = [X[idx]] + for _ in range(1, K): + dists = jnp.min(jnp.stack([jnp.sum((X - c)**2, axis=1) for c in centroids]), axis=0) + probs = dists / jnp.sum(dists) + key, subkey = jax.random.split(key) + idx = jax.random.choice(subkey, n, p=probs) + centroids.append(X[idx]) + return jnp.stack(centroids) + +def kmeans(X, K, max_iters=20, key=jax.random.PRNGKey(0)): + centroids = kmeans_plus_plus_init(X, K, key) + history = [centroids] + for _ in range(max_iters): + # 分配步骤 + dists = jnp.stack([jnp.sum((X - c)**2, axis=1) for c in centroids]) + labels = jnp.argmin(dists, axis=0) + # 更新步骤 + new_centroids = jnp.stack([ + jnp.mean(X[labels == k], axis=0) for k in range(K) + ]) + history.append(new_centroids) + if jnp.allclose(centroids, new_centroids): + break + centroids = new_centroids + return labels, centroids, history + +K = 4 +labels, centroids, history = kmeans(X, K) + +# 绘制最终结果 +colors = ['#3498db', '#e74c3c', '#27ae60', '#9b59b6'] +plt.figure(figsize=(8, 6)) +for k in range(K): + mask = labels == k + plt.scatter(X[mask, 0], X[mask, 1], c=colors[k], s=20, alpha=0.6) + plt.scatter(centroids[k, 0], centroids[k, 1], c=colors[k], marker='X', + s=200, edgecolors='k', linewidths=1.5) +plt.title(f"K-Means Clustering (K={K}, {len(history)-1} iterations)") +plt.grid(alpha=0.3) +plt.show() + +# 计算惯性 +inertia = sum(jnp.sum((X[labels == k] - centroids[k])**2) for k in range(K)) +print(f"Final inertia: {inertia:.2f}") +``` + +4. 演示核技巧。通过比较核矩阵与多项式核的显式特征映射,展示RBF核如何在高维空间中计算点积。 +```python +import jax.numpy as jnp + +# 简单2D数据 +X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + +# 多项式核:K(x,y) = (x·y + 1)^2 +def poly_kernel(X, degree=2, c=1.0): + return (X @ X.T + c) ** degree + +# 2D的显式二次特征映射:(1, sqrt(2)*x1, sqrt(2)*x2, x1^2, x2^2, sqrt(2)*x1*x2) +def poly_features(X): + x1, x2 = X[:, 0], X[:, 1] + return jnp.column_stack([ + jnp.ones(len(X)), + jnp.sqrt(2) * x1, + jnp.sqrt(2) * x2, + x1 ** 2, + x2 ** 2, + jnp.sqrt(2) * x1 * x2 + ]) + +K_trick = poly_kernel(X) +phi = poly_features(X) +K_explicit = phi @ phi.T + +print("Kernel trick (polynomial degree 2):") +print(K_trick) +print("\nExplicit feature map dot products:") +print(K_explicit) +print(f"\nMatrices match: {jnp.allclose(K_trick, K_explicit)}") + +# RBF核:不存在有限的显式映射 +def rbf_kernel(X, gamma=0.5): + sq_dists = jnp.sum(X**2, axis=1, keepdims=True) + \ + jnp.sum(X**2, axis=1) - 2 * X @ X.T + return jnp.exp(-gamma * sq_dists) + +K_rbf = rbf_kernel(X) +print("\nRBF kernel matrix:") +print(K_rbf) +print("Diagonal is always 1 (a point is identical to itself)") +print("Off-diagonal entries decay with distance") +``` diff --git a/chapter 06: machine learning/02. gradient machine learning.md b/chapter 06: machine learning/02. gradient machine learning.md new file mode 100644 index 0000000..1e4c5e4 --- /dev/null +++ b/chapter 06: machine learning/02. gradient machine learning.md @@ -0,0 +1,408 @@ +# 梯度机器学习 + +*基于梯度的学习通过沿着损失曲面的斜率迭代优化模型参数。本文涵盖线性回归、逻辑回归、softmax分类、梯度下降变体、正则化(L1/L2)和偏差-方差权衡* + +- 第01篇中的经典方法使用巧妙的启发式或闭式解。本文涵盖通过沿着梯度学习、在损失曲面上小步下坡直到找到好参数的算法。基于梯度的学习是从线性回归到最大神经网络的一切背后的引擎。 + +- **线性回归**是最简单的基于梯度的模型,它也有闭式解,这使其成为完美的起点。模型是一条直线(或更高维的超平面): + +$$\hat{y} = w \cdot x + b = \sum_{i=1}^{d} w_i x_i + b$$ + +- 用矩阵符号(来自第02章),如果我们将所有训练输入堆叠为矩阵 $X$ 的行,并通过追加一列1将偏置吸收到 $w$ 中,这就变成了 $\hat{y} = Xw$。 + +- 目标是最小化**均方误差(MSE)**,即预测值与实际值之间平均平方差: + +$$\mathcal{L}(w) = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{y}_i)^2 = \frac{1}{n} \|y - Xw\|^2$$ + +- 为什么采用平方误差?它有概率论上的依据:如果你假设目标值由 $y = Xw + \epsilon$ 生成,其中 $\epsilon \sim \mathcal{N}(0, \sigma^2)$,那么最大化数据的高斯似然(第05章)等价于最小化MSE。平方误差还比小错误更严厉地惩罚大错误,这通常是可取的。 + +![具有最佳拟合线和显示误差的虚线垂直残差线的数据点散点图](../images/linear_regression_fit.svg) + +- 由于MSE是 $w$ 的二次函数,它具有唯一的全局最小值,我们可以通过解析方法找到。求导、设为零并求解,得到**正规方程**: + +$$w^{*} = (X^T X)^{-1} X^T y$$ + +- 这直接使用了第02章的矩阵逆运算。表达式 $X^T X$ 是一个 $d \times d$ 矩阵(其中 $d$ 是特征数量),$X^T y$ 是一个 $d$ 维向量。正规方程一次性给出精确的最优权重。 + +- 正规方程何时失效?当 $X^T X$ 奇异(不可逆)时,这发生在特征线性相关或特征数量多于样本数量($d > n$)的情况下。在这些情况下,你需要正则化(后续介绍)或梯度下降。 + +- **逻辑回归**将线性模型适用于二元分类。我们不预测连续值,而是想要一个介于0和1之间的概率。**Sigmoid函数**将所有实数压缩到这个范围内: + +$$\sigma(z) = \frac{1}{1 + e^{-z}}$$ + +- 模型计算 $z = w \cdot x + b$(线性得分,与线性回归相同),然后将其通过sigmoid:$\hat{y} = \sigma(w \cdot x + b)$。输出 $\hat{y}$ 被解释为 $P(y = 1 \mid x)$。 + +![带有0.5阈值标记的Sigmoid曲线,显示预测0和预测1的分类区域](../images/sigmoid_logistic.svg) + +- Sigmoid具有良好的性质:$\sigma(0) = 0.5$,$\sigma(z) \to 1$ 当 $z \to \infty$,$\sigma(z) \to 0$ 当 $z \to -\infty$,且其导数具有优雅的形式 $\sigma'(z) = \sigma(z)(1 - \sigma(z))$。 + +- 逻辑回归的损失函数是**二元交叉熵(BCE)**,直接来自于伯努利似然(第05章): + +$$\mathcal{L} = -\frac{1}{n} \sum_{i=1}^{n} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right]$$ + +- 当真实标签为1时,只有第一项起作用,它惩罚过低的预测。当真实标签为0时,只有第二项起作用,它惩罚过高的预测。对数使得对于自信的错误预测,惩罚极其陡峭:当真实标签为1时预测0.01,代价远高于预测0.4。 + +- 与线性回归的MSE不同,BCE最小化权重没有闭式解。我们需要一种迭代方法:**梯度下降**。 + +- 梯度下降的直觉很简单:想象你身处大雾中的丘陵地带(损失曲面)。你看不到全局最小值,但可以感受到脚下的坡度。你向下坡走一步,再次感受坡度,然后重复。最终你到达一个山谷。 + +$$w \leftarrow w - \eta \frac{\partial \mathcal{L}}{\partial w}$$ + +- 学习率 $\eta$ 控制你的步长。太大则越过山谷,来回弹跳而不收敛。太小则缓慢前行,可能陷入局部最小值。 + +![一维损失曲线,三个球:大学习率越过,好学习率收敛,小学习率卡住](../images/gradient_descent_landscape.svg) + +- 梯度 $\frac{\partial \mathcal{L}}{\partial w}$ 是一个指向最陡上升方向的向量。我们减去它是因为想向下坡走。这是第03章中的链式法则应用于损失函数。 + +- **批量梯度下降**每一步使用整个训练集计算梯度。这给出精确梯度,但当 $n$ 很大时计算代价高昂。 + +- **随机梯度下降(SGD)** 每一步使用单个随机样本。梯度带有噪声(它从一个样本估计真实梯度),但每一步非常快。噪声实际上可以帮助逃离浅的局部极小值。 + +- **小批量梯度下降**折中:每一步使用 $B$ 个样本的批次(通常为32、64或256)。这平衡了计算效率(对批次的向量化操作)与梯度质量。几乎所有深度学习都使用小批量SGD。 + +- **反向传播**是我们实际计算具有许多参数的模型(如神经网络)中梯度的方法。它是第03章的链式法则通过计算图系统化地应用。 + +- 任何模型都可以表示为操作的有向无环图:输入流入,乘以权重,加在一起,通过非线性函数传递,最终产生损失值。**前向传播**通过让数据从输入到输出流经此图来计算输出(和损失)。 + +- **反向传播**反向流动梯度。从损失开始,你使用每个节点的链式法则计算损失相对于每个中间值的变化。如果 $L$ 依赖于 $z$,而 $z$ 依赖于 $w$,则: + +$$\frac{\partial L}{\partial w} = \frac{\partial L}{\partial z} \cdot \frac{\partial z}{\partial w}$$ + +- 每个节点只需要知道自己的局部导数和从上方流入的梯度。这使得反向传播模块化且高效:代价大约是前向传播的两倍(一次前向,一次反向)。 + +- 原始SGD有一个问题:它在陡峭曲率方向上振荡,而在平坦方向上进展缓慢。**优化器**通过根据梯度历史调整步长来改进这一点。 + +- **带动量的SGD**维护过去梯度的运行平均值(指数移动平均,来自第04章)。这平滑了振荡并加速了沿一致方向的进展: + +$$v_t = \beta v_{t-1} + (1 - \beta) \nabla \mathcal{L}$$ +$$w \leftarrow w - \eta \, v_t$$ + +- 想象一个滚下山的球:动量让它沿一致方向积累速度并抑制侧向抖动。典型值为 $\beta = 0.9$。 + +- **内斯特罗夫加速梯度(NAG)** 是一个小巧但巧妙的调整:不在当前位置计算梯度,而是在"前瞻"位置 $w - \eta \beta v_{t-1}$ 计算梯度。这一修正步骤减少了过冲: + +$$v_t = \beta \, v_{t-1} + \nabla \mathcal{L}(w - \eta \beta \, v_{t-1})$$ +$$w \leftarrow w - \eta \, v_t$$ + +- **Adagrad** 为每个参数调整学习率。接收大梯度的参数获得较小的学习率,反之亦然。它累积平方梯度: + +$$G_t = G_{t-1} + g_t^2, \quad w \leftarrow w - \frac{\eta}{\sqrt{G_t + \epsilon}} g_t$$ + +- 问题在于:$G_t$ 只增不减,因此有效学习率单调递减,最终变得太小而无法学习任何东西。 + +- **RMSprop** 通过使用平方梯度的指数移动平均而非求和来修复此问题,使得近期梯度比早期梯度更重要: + +$$s_t = \beta \, s_{t-1} + (1 - \beta) g_t^2, \quad w \leftarrow w - \frac{\eta}{\sqrt{s_t + \epsilon}} g_t$$ + +- **Adam**(自适应矩估计)结合了动量和RMSprop。它同时维护一阶矩估计(梯度的均值,像动量)和二阶矩估计(平方梯度的均值,像RMSprop): + +$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$$ +$$v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$$ + +- 由于 $m_t$ 和 $v_t$ 初始化为零,它们在早期步骤中有偏近于零。偏差修正解决了这个问题: + +$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$ +$$w \leftarrow w - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t$$ + +![二维等高线图显示SGD呈锯齿形、动量沿更平滑路径、Adam走最直接路线到达最小值](../images/optimizer_trajectories.svg) + +- 默认超参数($\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$)在广泛的问题上表现良好,这就是为什么Adam是大多数深度学习工作中的默认优化器。 + +- **AdamW** 将权重衰减与梯度更新解耦。标准L2正则化和权重衰减对于SGD是等价的,但对于Adam则不然。AdamW直接将权重衰减应用于参数,而不是将 $\lambda w$ 加到梯度上。这带来了更好的泛化性能,现在是Transformer训练的标准: + +$$w \leftarrow w - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \, w \right)$$ + +- **LION**(演化符号动量)是通过程序搜索发现的新优化器。它只使用动量更新的符号(而不是幅度),使得每次更新的尺度均匀。LION比Adam使用更少的内存(没有二阶矩缓冲区),并且在许多任务上可以匹配或超越Adam: + +$$w \leftarrow w - \eta \cdot \text{sign}(\beta_1 \, m_{t-1} + (1 - \beta_1) \, g_t)$$ +$$m_t = \beta_2 \, m_{t-1} + (1 - \beta_2) \, g_t$$ + +- **Muon**(动量 + 正交化)应用内斯特罗夫动量,然后使用Newton-Schulz迭代对更新矩阵进行正交化,该迭代近似极分解。得到的更新方向位于Stiefel流形上,每次更新在所有奇异方向上具有大致相等的幅度,防止任何单一方向主导。这消除了对自适应二阶矩估计(如Adam的 $v_t$ 缓冲区)的需求,减少了内存使用。Muon在Transformer训练中表现出色,通常以更快的收敛速度达到与AdamW相当的质量,尤其适用于注意力矩阵和MLP权重矩阵。嵌入层和输出层通常仍由AdamW处理。 + +$$G_t = \text{NesterovMomentum}(\nabla \mathcal{L})$$ +$$U_t = \text{NewtonSchulz}(G_t) \approx G_t (G_t^T G_t)^{-1/2}$$ +$$W \leftarrow W - \eta \, U_t$$ + +- Newton-Schulz迭代通过重复 $X_{k+1} = \frac{1}{2} X_k (3I - X_k^T X_k)$ 几个步骤(通常5-10步)来计算正交因子。这避免了完整SVD的计算代价,同时提供了良好的近似。 + +![Muon正交化:动量更新具有偏斜的奇异值,Newton-Schulz迭代使它们均衡,所有方向均匀更新](../images/optimizer_muon.svg) + +![优化器内存比较:每个优化器在每个参数上存储的内容](../images/optimizer_memory.svg) + +- 除了MSE和BCE之外,还有几种常用的**损失函数**。 + +- **平均绝对误差(MAE)**,或L1损失,取绝对差的平均值:$\frac{1}{n}\sum|y_i - \hat{y}_i|$。它对异常值比MSE更鲁棒,因为它不对大误差进行平方。 + +- **Huber损失**结合了两者的优点:对于小误差表现像MSE(平滑,易于优化),对于大误差表现像MAE(对异常值鲁棒)。它有一个控制过渡的阈值 $\delta$。 + +- **分类交叉熵(CCE)** 将BCE推广到多个类别。如果 $\hat{y}_k$ 是类别 $k$ 的预测概率,真实类别为 $c$: + +$$\mathcal{L} = -\log(\hat{y}_c)$$ + +- 这只是正确类别的负对数概率。最小化交叉熵等价于最大化似然,这联系到第05章的信息论:交叉熵衡量当你使用预测分布代替真实分布时需要多少额外比特。 + +- **Hinge损失** 被SVM使用:$\mathcal{L} = \max(0, 1 - y \cdot f(x))$。它只惩罚在间隔错误一侧或间隔内的预测。一旦一个点被足够置信地正确分类,损失为零。 + +- **正则化**通过添加对复杂模型的惩罚来防止过拟合。正则化后的损失为: + +$$\mathcal{L}_{\text{reg}} = \mathcal{L}_{\text{data}} + \lambda \, R(w)$$ + +- **L2正则化**(Ridge,权重衰减)惩罚平方权重之和:$R(w) = \|w\|^2 = \sum w_i^2$。它阻止任何单个权重变得过大,有效地将所有权重向零收缩,但很少使它们精确为零。 + +- **L1正则化**(Lasso)惩罚绝对权重之和:$R(w) = \|w\|_1 = \sum |w_i|$。它鼓励稀疏性,将许多权重驱动到精确为零,实现自动特征选择。 + +- **弹性网络** 结合了两者:$R(w) = \alpha \|w\|_1 + (1 - \alpha) \|w\|^2$,融合了稀疏性和收缩。 + +- 有一个优美的贝叶斯解释(来自第05章)。L2正则化等价于在权重上放置高斯先验并寻找MAP估计。L1正则化对应于拉普拉斯先验。正则化强度 $\lambda$ 控制你相对于数据信任先验的程度。 + +- **评估指标**告诉你模型是否真正有效。对于回归,MSE和MAE是标准指标。对于分类,情况更为微妙。 + +- **混淆矩阵**是一个二元分类的四格表: + - 真正例(TP):预测为正,实际为正 + - 假正例(FP):预测为正,实际为负 + - 真负例(TN):预测为负,实际为负 + - 假负例(FN):预测为负,实际为正 + +- **准确率** = $\frac{TP + TN}{TP + TN + FP + FN}$ 在类别不平衡时可能具有误导性。如果99%的电子邮件不是垃圾邮件,一个总是预测"非垃圾邮件"的模型有99%的准确率,但没有用处。 + +- **精确率** = $\frac{TP}{TP + FP}$ 回答:在所有预测为正的样本中,有多少实际为正?高精确率意味着误报少。 + +- **召回率**(敏感度)= $\frac{TP}{TP + FN}$ 回答:在所有实际为正的样本中,你捕获了多少?高召回率意味着漏检少。 + +- **F1分数** = $\frac{2 \cdot \text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}$ 是精确率和召回率的调和平均数,平衡了两者。 + +- **ROC曲线**绘制了真正率(召回率)对假正率($\frac{FP}{FP + TN}$)随分类阈值从0到1变化的曲线。完美分类器紧贴左上角。**AUC**(ROC曲线下面积)用一个数字概括性能:1.0为完美,0.5为随机猜测。 + +- **交叉验证**提供了更可靠的泛化性能估计。在 $k$ 折交叉验证中,你将数据分成 $k$ 份,在 $k-1$ 份上训练,在剩余一份上测试,然后轮换。所有 $k$ 折的平均测试性能就是你的估计。这使用了所有数据进行训练和测试(只是不在同一时间),在数据稀缺时尤为宝贵。 + +- **偏差-方差权衡**(来自第04章)是ML中的基本张力。模型期望误差分解为: + +$$\text{Error} = \text{Bias}^2 + \text{Variance} + \text{Irreducible Noise}$$ + +- **偏差**是错误假设带来的系统性误差(例如,用直线拟合曲线数据)。**方差**是对训练数据波动的敏感度(例如,20次多项式拟合噪声)。简单模型具有高偏差和低方差;复杂模型具有低偏差和高方差。最优在两者之间。 + +- **学习率调度**在训练期间调整 $\eta$。常见策略: + - 步长衰减:每 $N$ 个epoch将 $\eta$ 乘以一个因子(如0.1) + - 余弦退火:按照余弦曲线从初始值平滑降低 $\eta$ 到接近零 + - 预热:从一个非常小的 $\eta$ 开始,在前几千步线性增加,然后衰减。这防止了大的初始梯度破坏训练稳定性 + - 1cycle:一个先升后降的余弦周期,可以带来更快的收敛 + +- **超参数调优**是找到学习率、批量大小、正则化强度和其他不由梯度下降学习的设置的良好值的过程。常用方法: + - 网格搜索:在预定义的网格上尝试每一种组合(穷举但代价高) + - 随机搜索:随机采样组合,通常更高效,因为并非所有超参数同等重要 + - 贝叶斯优化:构建目标函数的模型并智能选择下一个要尝试的超参数 + - **ASHA**(异步连续减半算法):使用小预算并行运行许多试验,然后将最有希望的提升到更大预算,同时及早终止其余试验。它结合了早停的高效性和大规模并行性——不是运行100次完整的训练,而是廉价地启动所有100次,在每级保留前四分之一,只有少数运行到完成。这是现代大规模调优框架(如Ray Tune)的骨干。 + +- **无调度学习**完全消除了对学习率调度的需求。它不是在固定曲线上衰减 $\eta$,而是维护两个序列:一个缓慢移动的迭代平均值 $z_t$(收敛到最优值)和一个快速探索的迭代 $y_t$(在其上评估梯度)。最终输出是平均序列,被证明在事后能匹配最佳调度的收敛速度。这完全消除了调度作为一个超参数——你只需设置基础学习率,优化器处理其余部分。SGD和Adam的无调度变体已被证明能达到或超越其经过调度的对应版本。 + +## 编程任务(在CoLab或笔记本中完成) + +1. 使用正规方程和梯度下降两种方法实现线性回归。比较求解结果,并绘制GD损失随迭代的收敛曲线。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 生成合成数据:y = 3x + 2 + noise +key = jax.random.PRNGKey(42) +n = 100 +X = jax.random.uniform(key, (n, 1), minval=0, maxval=10) +y = 3 * X[:, 0] + 2 + jax.random.normal(key, (n,)) * 1.5 + +# 添加偏置列 +X_b = jnp.column_stack([X, jnp.ones(n)]) + +# 正规方程 +w_exact = jnp.linalg.solve(X_b.T @ X_b, X_b.T @ y) +print(f"Normal equation: w={w_exact[0]:.4f}, b={w_exact[1]:.4f}") + +# 梯度下降 +w_gd = jnp.zeros(2) +lr = 0.005 +losses = [] +for step in range(500): + pred = X_b @ w_gd + error = pred - y + loss = jnp.mean(error ** 2) + losses.append(float(loss)) + grad = (2 / n) * X_b.T @ error + w_gd = w_gd - lr * grad + +print(f"Gradient descent: w={w_gd[0]:.4f}, b={w_gd[1]:.4f}") + +fig, axes = plt.subplots(1, 2, figsize=(12, 4)) +axes[0].scatter(X[:, 0], y, s=15, alpha=0.5, color='#3498db') +axes[0].plot([0, 10], [w_exact[1], w_exact[0]*10 + w_exact[1]], color='#e74c3c', linewidth=2) +axes[0].set_title("Linear Regression Fit") +axes[0].set_xlabel("x"); axes[0].set_ylabel("y") + +axes[1].plot(losses, color='#27ae60', linewidth=1.5) +axes[1].set_title("GD Loss Convergence") +axes[1].set_xlabel("Step"); axes[1].set_ylabel("MSE") +axes[1].set_yscale('log') +plt.tight_layout() +plt.show() +``` + +2. 从头实现带梯度下降的逻辑回归。在二维数据集上训练并可视化学习到的决策边界。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +from sklearn.datasets import make_moons + +# 生成数据 +X, y = make_moons(n_samples=300, noise=0.2, random_state=42) +X, y = jnp.array(X), jnp.array(y, dtype=jnp.float32) + +def sigmoid(z): + return 1 / (1 + jnp.exp(-z)) + +# 添加偏置列 +X_b = jnp.column_stack([X, jnp.ones(len(X))]) +w = jnp.zeros(3) +lr = 0.5 +losses = [] + +for step in range(2000): + z = X_b @ w + pred = sigmoid(z) + # BCE损失 + loss = -jnp.mean(y * jnp.log(pred + 1e-8) + (1 - y) * jnp.log(1 - pred + 1e-8)) + losses.append(float(loss)) + # 梯度 + grad = X_b.T @ (pred - y) / len(y) + w = w - lr * grad + +# 决策边界 +xx, yy = jnp.meshgrid(jnp.linspace(-2, 3, 200), jnp.linspace(-1.5, 2, 200)) +grid = jnp.column_stack([xx.ravel(), yy.ravel(), jnp.ones(xx.size)]) +zz = sigmoid(grid @ w).reshape(xx.shape) + +plt.figure(figsize=(8, 6)) +plt.contourf(xx, yy, zz, levels=[0, 0.5, 1], alpha=0.3, colors=['#e74c3c', '#3498db']) +plt.contour(xx, yy, zz, levels=[0.5], colors='#9b59b6', linewidths=2) +plt.scatter(X[y==0, 0], X[y==0, 1], c='#e74c3c', s=15, label='Class 0') +plt.scatter(X[y==1, 0], X[y==1, 1], c='#3498db', s=15, label='Class 1') +plt.title("Logistic Regression Decision Boundary") +plt.legend() +plt.grid(alpha=0.3) +plt.show() +``` + +3. 在二维二次曲面上比较优化器的轨迹。从相同的起点运行SGD、SGD+Momentum和Adam,绘制它们的路径。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 拉长的二次曲面:L(w1, w2) = 0.5*w1^2 + 10*w2^2 +def loss_fn(w): + return 0.5 * w[0]**2 + 10 * w[1]**2 + +grad_fn = jax.grad(loss_fn) + +def run_sgd(w0, lr=0.05, steps=80): + w = w0.copy() + path = [w.copy()] + for _ in range(steps): + g = grad_fn(w) + w = w - lr * g + path.append(w.copy()) + return jnp.stack(path) + +def run_momentum(w0, lr=0.05, beta=0.9, steps=80): + w, v = w0.copy(), jnp.zeros(2) + path = [w.copy()] + for _ in range(steps): + g = grad_fn(w) + v = beta * v + (1 - beta) * g + w = w - lr * v + path.append(w.copy()) + return jnp.stack(path) + +def run_adam(w0, lr=0.05, b1=0.9, b2=0.999, eps=1e-8, steps=80): + w, m, v = w0.copy(), jnp.zeros(2), jnp.zeros(2) + path = [w.copy()] + for t in range(1, steps + 1): + g = grad_fn(w) + m = b1 * m + (1 - b1) * g + v = b2 * v + (1 - b2) * g**2 + m_hat = m / (1 - b1**t) + v_hat = v / (1 - b2**t) + w = w - lr * m_hat / (jnp.sqrt(v_hat) + eps) + path.append(w.copy()) + return jnp.stack(path) + +w0 = jnp.array([8.0, 3.0]) +sgd_path = run_sgd(w0) +mom_path = run_momentum(w0) +adam_path = run_adam(w0) + +# 绘图 +fig, ax = plt.subplots(figsize=(8, 6)) +w1 = jnp.linspace(-10, 10, 100) +w2 = jnp.linspace(-4, 4, 100) +W1, W2 = jnp.meshgrid(w1, w2) +L = 0.5 * W1**2 + 10 * W2**2 +ax.contour(W1, W2, L, levels=20, cmap='Greys', alpha=0.4) +ax.plot(sgd_path[:,0], sgd_path[:,1], 'o-', color='#3498db', markersize=2, linewidth=1, label='SGD') +ax.plot(mom_path[:,0], mom_path[:,1], 'o-', color='#27ae60', markersize=2, linewidth=1, label='Momentum') +ax.plot(adam_path[:,0], adam_path[:,1], 'o-', color='#e74c3c', markersize=2, linewidth=1, label='Adam') +ax.plot(0, 0, 'k*', markersize=15, label='Minimum') +ax.set_xlabel('w₁'); ax.set_ylabel('w₂') +ax.set_title("Optimizer Trajectories on Elongated Quadratic") +ax.legend() +plt.grid(alpha=0.3) +plt.show() +``` + +4. 展示L1与L2正则化对权重稀疏性的影响。使用两种惩罚训练线性回归,并比较得到的权重向量。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 合成数据:20个特征中只有前3个是相关的 +key = jax.random.PRNGKey(0) +n, d = 200, 20 +w_true = jnp.zeros(d).at[:3].set(jnp.array([3.0, -2.0, 1.5])) +X = jax.random.normal(key, (n, d)) +y = X @ w_true + 0.5 * jax.random.normal(key, (n,)) + +def train_ridge(X, y, lam=1.0, lr=0.01, steps=2000): + """通过GD进行L2正则化线性回归。""" + w = jnp.zeros(X.shape[1]) + for _ in range(steps): + pred = X @ w + grad = (2/len(y)) * X.T @ (pred - y) + 2 * lam * w + w = w - lr * grad + return w + +def train_lasso(X, y, lam=1.0, lr=0.01, steps=2000): + """通过近端GD进行L1正则化线性回归。""" + w = jnp.zeros(X.shape[1]) + for _ in range(steps): + pred = X @ w + grad = (2/len(y)) * X.T @ (pred - y) + w = w - lr * grad + # 软阈值(L1的近端算子) + w = jnp.sign(w) * jnp.maximum(jnp.abs(w) - lr * lam, 0) + return w + +w_l2 = train_ridge(X, y, lam=0.1) +w_l1 = train_lasso(X, y, lam=0.1) + +fig, axes = plt.subplots(1, 3, figsize=(14, 4)) +axes[0].bar(range(d), w_true, color='#333', alpha=0.7) +axes[0].set_title("True Weights"); axes[0].set_xlabel("Feature") +axes[1].bar(range(d), w_l2, color='#3498db', alpha=0.7) +axes[1].set_title("L2 (Ridge): shrinks all"); axes[1].set_xlabel("Feature") +axes[2].bar(range(d), w_l1, color='#e74c3c', alpha=0.7) +axes[2].set_title("L1 (Lasso): zeros out irrelevant"); axes[2].set_xlabel("Feature") +plt.tight_layout() +plt.show() + +print(f"L2 non-zero weights: {int(jnp.sum(jnp.abs(w_l2) > 0.01))}/{d}") +print(f"L1 non-zero weights: {int(jnp.sum(jnp.abs(w_l1) > 0.01))}/{d}") +``` diff --git a/chapter 06: machine learning/03. deep learning.md b/chapter 06: machine learning/03. deep learning.md new file mode 100644 index 0000000..bd33e0a --- /dev/null +++ b/chapter 06: machine learning/03. deep learning.md @@ -0,0 +1,354 @@ +# 深度学习 + +*深度学习堆叠非线性层来构建层次化表示,自动将原始输入转换为有用的特征。本文涵盖MLP、激活函数、反向传播、CNN、RNN、LSTM、注意力机制、Transformer、GAN、VAE、扩散模型和归一化技术* + +- 什么使网络"深"?浅网络只有一个隐藏层;深网络有许多层。深度让网络构建层次化表示,早期层学习简单特征(边缘、音调),后期层将它们组合成复杂概念(人脸、句子)。这种组合性正是深度学习力量的来源。 + +- 最简单的深度网络是**多层感知器(MLP)**,也称为全连接或密集网络。每层计算: + +$$h = \sigma(Wx + b)$$ + +- 这里 $W$ 是权重矩阵(第02章),$b$ 是偏置向量,$\sigma$ 是非线性激活函数。一层的输出成为下一层的输入。没有非线性,堆叠层将毫无意义:$W_2(W_1 x) = (W_2 W_1)x$,这只是另一个线性变换。这正是第02章中的矩阵乘法塌缩。 + +- **激活函数**引入使深度有意义的非线性。 + +- **ReLU**(修正线性单元):$\text{ReLU}(x) = \max(0, x)$。它是使用最广泛的激活函数。计算速度快,正输入不饱和,并产生稀疏激活(许多神经元输出精确为零)。缺点:负输入的神经元总是输出零,如果它们永久卡在那里,就会"死亡"并停止学习。 + +- **Sigmoid**:$\sigma(x) = \frac{1}{1+e^{-x}}$,将输入压缩到 $(0, 1)$。适用于二元分类的输出层,但在隐藏层中有问题,因为当输入远离零时梯度消失(曲线几乎平坦)。 + +- **Tanh**:$\tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}$,压缩到 $(-1, 1)$。零中心(不同于sigmoid),有助于梯度流动,但在极端值处仍存在梯度消失问题。 + +- **GELU**(高斯误差线性单元):$\text{GELU}(x) = x \cdot \Phi(x)$,其中 $\Phi$ 是标准正态CDF。它是ReLU的平滑近似,允许微小的负值通过。GELU是GPT和BERT中的默认选择。 + +- **Swish**:$\text{Swish}(x) = x \cdot \sigma(x)$,另一种平滑门控。实际使用中与GELU类似。 + +![ReLU、Sigmoid、Tanh和GELU及其关键属性的并列图](../images/activation_functions.svg) + +- 一个具有 $d_{\text{in}}$ 个输入和 $d_{\text{out}}$ 个输出的密集层有 $d_{\text{in}} \times d_{\text{out}} + d_{\text{out}}$ 个参数(权重加偏置)。矩阵乘法 $Wx$ 就是第02章中的矩阵-向量乘法。在批处理设置中,输入是形状为 $(B, d_{\text{in}})$ 的矩阵 $X$,输出是形状为 $(B, d_{\text{out}})$ 的 $XW^T + b$。 + +- **万能近似定理**指出,一个具有足够神经元的隐藏层可以在紧致域上以任意精度逼近任何连续函数。这听起来似乎深度无关紧要,但关键在于"足够的神经元"。实际上,深层网络可以用指数级少于浅层网络的参数来表示相同的函数。深度带来的是效率,而不仅仅是表达能力。 + +- 随着网络变深,出现两种梯度病理。**梯度消失**:当梯度通过许多层时(通过链式法则,第03章),它们被乘以许多因子。如果这些因子都小于1(如sigmoid和tanh饱和时发生的情况),梯度呈指数级缩小趋近于零。早期层几乎无法学习。**梯度爆炸**:如果因子都大于1,梯度呈指数级增长,导致数值溢出和训练不稳定。 + +- 梯度消失/爆炸的解决方案: + - 使用ReLU或GELU激活函数(正输入时梯度为1,无饱和) + - 仔细的权重初始化 + - 归一化层 + - 残差连接(跳跃连接) + - 梯度裁剪(针对梯度爆炸):将梯度范数限制在最大值 + +- **权重初始化**很重要,因为它决定了训练开始时激活值和梯度的尺度。如果权重太大,激活值爆炸;太小,它们消失。 + +- **Xavier (Glorot) 初始化**从方差为 $\frac{2}{d_{\text{in}} + d_{\text{out}}}$ 的分布中设置权重。这假设使用线性或tanh激活函数时,能使激活值的方差在各层大致保持恒定。 + +- **He (Kaiming) 初始化**使用方差 $\frac{2}{d_{\text{in}}}$,针对ReLU激活函数校准(由于ReLU将半数激活值置零,需要双倍方差来补偿)。 + +- **归一化层**通过确保每层的输入具有一致的统计特性(大致零均值、单位方差)来稳定训练。 + +- **批归一化(BatchNorm)** 在批次维度上进行归一化:对于每个通道/特征,计算小批次中所有样本的均值和方差,然后归一化。它添加了可学习的尺度($\gamma$)和偏移($\beta$)参数,以便网络在需要时撤销归一化: + +$$\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \quad y = \gamma \hat{x} + \beta$$ + +- BatchNorm有一个问题:它依赖于批量大小。当批次非常小时,统计数据有噪声。在推理时,使用运行平均值而非批次统计,这造成了训练/测试不一致。 + +- **层归一化(LayerNorm)** 对每个单独样本在特征维度上进行归一化。它不依赖于批次中的其他样本,使其成为Transformer和循环网络的标准选择。 + +- **实例归一化** 对每个样本和每个通道独立地在空间维度上进行归一化。在风格迁移中很流行。 + +- **组归一化** 将通道分成组并在每个组内进行归一化。它是LayerNorm和InstanceNorm之间的折中。 + +![3D张量,彩色切片显示BatchNorm、LayerNorm和InstanceNorm在哪些维度上进行归一化](../images/normalization_types.svg) + +- **Dropout** 是一种正则化技术,在训练期间随机将一部分 $p$ 的神经元置零。这迫使网络不依赖任何单个神经元,鼓励冗余表示。测试时,所有神经元都被激活。**逆置Dropout** 在训练期间将激活值缩放 $\frac{1}{1-p}$,以便测试时无需缩放。这是标准实现。 + +- **卷积神经网络(CNN)** 利用了空间结构。卷积层不是将每个输入连接到每个输出(如密集层),而是在输入上滑动一个小滤波器(核),在每个位置计算点积。相同的滤波器权重在所有位置共享,这大大减少了参数并内建了平移不变性。 + +- 二维输入与大小为 $k \times k$ 的滤波器 $K$ 的**卷积操作**: + +$$(\text{input} * K)[i,j] = \sum_{m=0}^{k-1} \sum_{n=0}^{k-1} \text{input}[i+m, j+n] \cdot K[m, n]$$ + +![输入网格上滑动3x3滤波器,在每个位置进行逐元素乘加产生输出特征图](../images/cnn_convolution.svg) + +- 输出大小取决于三个超参数。**步幅**控制滤波器在位置之间移动多少像素(步幅2使空间维度减半)。**填充**在输入边界周围添加零("same"填充保持空间大小,"valid"填充不填充)。输出大小公式:$\text{out} = \lfloor (\text{in} - k + 2p) / s \rfloor + 1$。 + +- **池化**层对特征图进行下采样。最大池化取每个窗口中的最大值;平均池化取均值。池化在保留最重要信息的同时减少空间维度。 + +- **扩张卷积** 在滤波器元素之间插入间隙,增加感受野而不增加参数。扩张率为2意味着3x3滤波器覆盖5x5区域。 + +- **1x1卷积** 是使用1x1滤波器的卷积。它们不查看空间邻居;而是跨通道混合信息。可以将其视为在每个空间位置应用密集层。用于廉价地改变通道数。 + +- **跳跃连接**(残差连接)让输入绕过一层或多层:$\text{output} = F(x) + x$。该层只需学习残差 $F(x) = \text{output} - x$,当最优变换接近恒等映射时这更容易。ResNet(残差网络)使用这一技巧堆叠超过100层,解决了更深的网络表现比浅层网络更差的退化问题。 + +- CNN构建了一个**特征层次结构**。早期层检测边缘和纹理。中间层将这些组合成部件(眼睛、轮子)。后期层识别整个物体。每层的感受野(它"看到"的输入区域)随深度增加。 + +- **嵌入**将离散的标记(单词、字符、物品ID)映射到密集向量。嵌入层只是一个查找表:一个形状为(词汇表大小,嵌入维度)的矩阵 $E$。查找标记 $i$ 意味着选择 $E$ 的第 $i$ 行。这等价于乘以one-hot向量,这只是矩阵-向量乘法的一个特例(第02章)。嵌入在训练期间学习,因此相似的标记最终具有相似的向量。 + +- **分词**是将原始文本转换为标记序列的过程。词级分词按空格分割,但无法处理未见过的词。**子词分词**(BPE、WordPiece、SentencePiece)将文本分解为频繁的子词单元,平衡词汇表大小和覆盖率。单词"unhappiness"可能变成["un", "happiness"]或["un", "happ", "iness"]。 + +- **循环神经网络(RNN)** 一次处理一个序列元素,维护一个向前传递信息的隐藏状态: + +$$h_t = \tanh(W_h h_{t-1} + W_x x_t + b)$$ + +- 隐藏状态 $h_t$ 是网络到时间 $t$ 为止所看到内容的压缩摘要。相同的权重 $W_h$ 和 $W_x$ 在所有时间步共享(权重共享,如同CNN共享空间权重)。 + +- 原始RNN在长序列上存在梯度消失问题:从步骤 $t$ 到步骤 $t - k$ 的梯度信号经过 $k$ 次与 $W_h$ 的乘法,呈指数级缩小(或爆炸)。 + +- **LSTM**(长短时记忆网络)通过引入一个独立的细胞状态 $c_t$ 来解决这一问题,该状态以最小干扰流过时间。三个门控制哪些信息进入、离开和持续存在: + +- **遗忘门**决定从细胞状态中擦除什么:$f_t = \sigma(W_f [h_{t-1}, x_t] + b_f)$ +- **输入门**决定写入什么新信息:$i_t = \sigma(W_i [h_{t-1}, x_t] + b_i)$,候选值 $\tilde{c}_t = \tanh(W_c [h_{t-1}, x_t] + b_c)$ +- 细胞状态更新:$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$ +- **输出门**决定暴露什么:$o_t = \sigma(W_o [h_{t-1}, x_t] + b_o)$,$h_t = o_t \odot \tanh(c_t)$ + +![LSTM单元显示遗忘门、输入门、输出门、细胞状态高速公路和数据流连接](../images/rnn_lstm_cell.svg) + +- 细胞状态像传送带一样工作:信息可以不变地流过许多时间步(遗忘门保持接近1),这解决了长距离依赖的梯度消失问题。 + +- **GRU**(门控循环单元)通过将细胞状态和隐藏状态合并为一个,并使用两个门(更新门和重置门)代替三个门来简化LSTM。GRU参数更少,通常表现与LSTM相当。 + +- RNN(包括LSTM)的根本限制是顺序处理:必须按顺序处理标记1、标记2、标记3。这阻止了并行化并造成信息瓶颈,因为所有上下文必须通过固定大小的隐藏状态。 + +- **注意力机制**解决了这两个问题。注意力机制不是将整个输入压缩为固定向量,而是让模型回顾所有输入位置并决定哪些位置与当前输出相关。 + +- 现代公式使用**查询、键和值(Q, K, V)**。将其想象为图书馆搜索:你有一个查询(你在找什么)、键(每本书的标签)和值(实际书籍内容)。你将查询与所有键比较,以确定检索哪些值。 + +- **缩放点积注意力**: + +$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$ + +- $QK^T$ 计算每个查询和每个键之间的相似度。这是矩阵乘法(第02章),其中的条目是点积,衡量余弦相似度(第01章)。除以 $\sqrt{d_k}$ 防止点积变得太大(这会使softmax饱和并产生接近one-hot分布,导致梯度消失)。Softmax将相似度转换为概率分布。乘以 $V$ 产生值的加权组合。 + +- **多头注意力**运行 $h$ 个并行的注意力操作,每个使用不同的Q、K、V学习投影。这让模型同时从不同的表示子空间关注信息。一个头可能关注句法关系,而另一个关注语义关系。输出被拼接并投影: + +$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$ + +- **Transformer**架构(Vaswani等人,2017)完全由注意力和前馈层构建,没有循环。编码器块重复:多头自注意力、加法和层归一化、前馈网络、加法和层归一化。解码器块添加了掩码自注意力(防止模型看到未来的标记)和关注编码器输出的交叉注意力层。 + +![Transformer编码器块:多头注意力、加法和层归一化、前馈网络、加法和层归一化,带有残差连接](../images/transformer_block.svg) + +- **位置编码**是必需的,因为注意力是排列等变的,意味着它将输入视为集合而非序列。没有位置信息,"猫坐在垫子上"和"垫子坐在猫上"将是相同的。原始Transformer使用正弦位置编码: + +$$PE_{(pos, 2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \quad PE_{(pos, 2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)$$ + +- 每个位置获得一个唯一的向量,模型可以用来区分位置。现代模型通常使用学习的位置嵌入或相对位置编码(RoPE、ALiBi)代替。 + +- Transformer并行处理所有标记(自注意力矩阵 $QK^T$ 在一次矩阵乘法中计算),这使得它们在现代硬件上比RNN训练更快。权衡是自注意力在序列长度上是 $O(n^2)$(每个标记关注每个其他标记),而RNN是 $O(n)$。这就是为什么长上下文模型需要特殊的注意力变体(稀疏注意力、线性注意力、Flash Attention)。 + +- **视觉Transformer(ViT)** 通过将图像分割为固定大小的块(如16x16),将每个块展平为向量,并将这些块视为标记序列,将Transformer应用于图像。一个可学习的[CLS]标记被前置,其最终表示用于分类。尽管没有卷积的归纳偏置,ViT在足够数据上训练时可以匹配或超越CNN。 + +- **MLP-Mixer** 是一种更简单的架构,用MLP替代了注意力和卷积。它在"标记混合"MLP(跨空间位置应用)和"通道混合"MLP(跨特征应用)之间交替。它的表现具有竞争力,表明现代架构的关键洞察不是注意力本身,而是跨标记和特征的高效信息混合。 + +- **自编码器**通过训练网络重构自身输入来学习压缩表示。编码器将输入映射到低维瓶颈(潜码),解码器将其映射回来: + +$$z = f_{\text{enc}}(x), \quad \hat{x} = f_{\text{dec}}(z), \quad \mathcal{L} = \|x - \hat{x}\|^2$$ + +- 瓶颈迫使网络学习最重要的特征。自编码器用于降维、去噪(在噪声输入上训练,重构干净输出)和异常检测(高重构误差表明输入异常)。 + +- **变分自编码器(VAE)** 增加了概率的变体。编码器不是编码到单个点 $z$,而是输出分布的参数(高斯的均值 $\mu$ 和方差 $\sigma^2$)。潜码从此分布中采样:$z = \mu + \sigma \odot \epsilon$,其中 $\epsilon \sim \mathcal{N}(0, I)$。这个**重参数化技巧**使采样可微,梯度可以流过。 + +- VAE损失有两个项: + +$$\mathcal{L} = \underbrace{\|x - \hat{x}\|^2}_{\text{reconstruction}} + \underbrace{D_{\text{KL}}(q(z|x) \| p(z))}_{\text{regularisation}}$$ + +- KL散度项(来自第05章)将学习到的后验 $q(z|x)$ 推向先验 $p(z) = \mathcal{N}(0, I)$,确保潜空间平滑且结构良好。然后你可以从先验中采样并解码以生成新数据。这就是使VAE成为生成模型的原因。 + +## 编程任务(在CoLab或笔记本中完成) + +1. 在JAX中从头构建一个简单的MLP。在二维分类问题(如同心圆)上训练并可视化决策边界。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +from sklearn.datasets import make_circles + +# 数据 +X, y = make_circles(n_samples=500, noise=0.1, factor=0.5, random_state=42) +X, y = jnp.array(X), jnp.array(y, dtype=jnp.float32) + +# 初始化一个2层MLP:2 -> 16 -> 16 -> 1 +def init_params(key): + k1, k2, k3 = jax.random.split(key, 3) + return { + 'W1': jax.random.normal(k1, (2, 16)) * 0.5, + 'b1': jnp.zeros(16), + 'W2': jax.random.normal(k2, (16, 16)) * 0.5, + 'b2': jnp.zeros(16), + 'W3': jax.random.normal(k3, (16, 1)) * 0.5, + 'b3': jnp.zeros(1), + } + +def forward(params, x): + h = jnp.maximum(0, x @ params['W1'] + params['b1']) # ReLU + h = jnp.maximum(0, h @ params['W2'] + params['b2']) # ReLU + logit = (h @ params['W3'] + params['b3']).squeeze() + return jax.nn.sigmoid(logit) + +def loss_fn(params, X, y): + pred = forward(params, X) + return -jnp.mean(y * jnp.log(pred + 1e-7) + (1 - y) * jnp.log(1 - pred + 1e-7)) + +grad_fn = jax.jit(jax.grad(loss_fn)) +params = init_params(jax.random.PRNGKey(0)) +lr = 0.1 + +for step in range(2000): + grads = grad_fn(params, X, y) + params = {k: params[k] - lr * grads[k] for k in params} + +# 绘制决策边界 +xx, yy = jnp.meshgrid(jnp.linspace(-2, 2, 200), jnp.linspace(-2, 2, 200)) +grid = jnp.column_stack([xx.ravel(), yy.ravel()]) +zz = forward(params, grid).reshape(xx.shape) + +plt.figure(figsize=(7, 6)) +plt.contourf(xx, yy, zz, levels=[0, 0.5, 1], alpha=0.3, colors=['#e74c3c', '#3498db']) +plt.scatter(X[y==0,0], X[y==0,1], c='#e74c3c', s=10, label='Class 0') +plt.scatter(X[y==1,0], X[y==1,1], c='#3498db', s=10, label='Class 1') +plt.title("MLP Decision Boundary on Concentric Circles") +plt.legend(); plt.grid(alpha=0.3); plt.show() + +acc = jnp.mean((forward(params, X) > 0.5) == y) +print(f"Accuracy: {acc:.2%}") +``` + +2. 从头实现一维卷积。将简单的边缘检测滤波器应用于信号,并与内置的 `jnp.convolve` 进行比较。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def conv1d(signal, kernel): + """从头实现一维卷积(valid模式)。""" + n, k = len(signal), len(kernel) + output = jnp.zeros(n - k + 1) + for i in range(n - k + 1): + output = output.at[i].set(jnp.sum(signal[i:i+k] * kernel)) + return output + +# 创建一个带有阶跃函数的信号 +t = jnp.linspace(0, 4, 200) +signal = jnp.where(t < 1, 0.0, jnp.where(t < 2, 1.0, jnp.where(t < 3, 0.5, 1.5))) + +# 边缘检测核 +edge_kernel = jnp.array([-1.0, 0.0, 1.0]) + +# 我们的实现 vs 内置函数 +our_output = conv1d(signal, edge_kernel) +jnp_output = jnp.convolve(signal, edge_kernel, mode='valid') + +fig, axes = plt.subplots(3, 1, figsize=(10, 6), sharex=True) +axes[0].plot(t, signal, color='#3498db', linewidth=1.5) +axes[0].set_title("Original Signal"); axes[0].set_ylabel("Value") + +axes[1].plot(t[:len(our_output)], our_output, color='#e74c3c', linewidth=1.5) +axes[1].set_title("After Edge Detection (our conv1d)"); axes[1].set_ylabel("Value") + +axes[2].plot(t[:len(jnp_output)], jnp_output, color='#27ae60', linewidth=1.5, linestyle='--') +axes[2].set_title("After Edge Detection (jnp.convolve)"); axes[2].set_ylabel("Value") +axes[2].set_xlabel("t") + +plt.tight_layout(); plt.show() +print(f"Outputs match: {jnp.allclose(our_output, jnp_output)}") +``` + +3. 从头实现缩放点积注意力。为一个小例子计算注意力权重,并将注意力矩阵可视化为热力图。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def scaled_dot_product_attention(Q, K, V): + """缩放点积注意力。""" + d_k = Q.shape[-1] + scores = Q @ K.T / jnp.sqrt(d_k) + weights = jax.nn.softmax(scores, axis=-1) + output = weights @ V + return output, weights + +# 示例:4个标记,嵌入维度8 +key = jax.random.PRNGKey(42) +k1, k2, k3 = jax.random.split(key, 3) +seq_len, d_model = 4, 8 + +Q = jax.random.normal(k1, (seq_len, d_model)) +K = jax.random.normal(k2, (seq_len, d_model)) +V = jax.random.normal(k3, (seq_len, d_model)) + +output, weights = scaled_dot_product_attention(Q, K, V) + +print(f"Q shape: {Q.shape}") +print(f"Attention weights shape: {weights.shape}") +print(f"Output shape: {output.shape}") +print(f"\nAttention weights (rows sum to 1):") +print(weights) +print(f"Row sums: {weights.sum(axis=-1)}") + +# 可视化注意力 +fig, ax = plt.subplots(figsize=(5, 4)) +im = ax.imshow(weights, cmap='Blues', vmin=0, vmax=1) +ax.set_xlabel("Key position"); ax.set_ylabel("Query position") +ax.set_title("Attention Weights") +tokens = ['tok 0', 'tok 1', 'tok 2', 'tok 3'] +ax.set_xticks(range(4)); ax.set_xticklabels(tokens) +ax.set_yticks(range(4)); ax.set_yticklabels(tokens) +for i in range(4): + for j in range(4): + ax.text(j, i, f"{weights[i,j]:.2f}", ha='center', va='center', fontsize=10) +plt.colorbar(im); plt.tight_layout(); plt.show() +``` + +4. 构建一个简单的自编码器,通过一维瓶颈压缩二维数据并重建。可视化潜空间和重建结果。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +from sklearn.datasets import make_moons + +# 数据 +X, _ = make_moons(n_samples=500, noise=0.05, random_state=42) +X = jnp.array(X) + +# 自编码器:2 -> 8 -> 1 -> 8 -> 2 +def init_ae(key): + k1, k2, k3, k4 = jax.random.split(key, 4) + return { + 'enc_W1': jax.random.normal(k1, (2, 8)) * 0.5, 'enc_b1': jnp.zeros(8), + 'enc_W2': jax.random.normal(k2, (8, 1)) * 0.5, 'enc_b2': jnp.zeros(1), + 'dec_W1': jax.random.normal(k3, (1, 8)) * 0.5, 'dec_b1': jnp.zeros(8), + 'dec_W2': jax.random.normal(k4, (8, 2)) * 0.5, 'dec_b2': jnp.zeros(2), + } + +def encode(p, x): + h = jnp.tanh(x @ p['enc_W1'] + p['enc_b1']) + return h @ p['enc_W2'] + p['enc_b2'] + +def decode(p, z): + h = jnp.tanh(z @ p['dec_W1'] + p['dec_b1']) + return h @ p['dec_W2'] + p['dec_b2'] + +def ae_loss(p, X): + z = encode(p, X) + X_hat = decode(p, z) + return jnp.mean((X - X_hat) ** 2) + +grad_fn = jax.jit(jax.grad(ae_loss)) +params = init_ae(jax.random.PRNGKey(0)) +lr = 0.01 + +for step in range(3000): + grads = grad_fn(params, X) + params = {k: params[k] - lr * grads[k] for k in params} + +z = encode(params, X) +X_hat = decode(params, z) + +fig, axes = plt.subplots(1, 2, figsize=(12, 5)) +axes[0].scatter(X[:,0], X[:,1], c=z.squeeze(), cmap='viridis', s=10) +axes[0].set_title("Original Data (coloured by latent code)") +axes[1].scatter(X_hat[:,0], X_hat[:,1], c=z.squeeze(), cmap='viridis', s=10) +axes[1].set_title("Reconstruction from 1D bottleneck") +for ax in axes: + ax.set_aspect('equal'); ax.grid(alpha=0.3) +plt.tight_layout(); plt.show() + +print(f"Reconstruction MSE: {ae_loss(params, X):.4f}") +``` diff --git a/chapter 06: machine learning/04. reinforcement learning.md b/chapter 06: machine learning/04. reinforcement learning.md new file mode 100644 index 0000000..4197848 --- /dev/null +++ b/chapter 06: machine learning/04. reinforcement learning.md @@ -0,0 +1,353 @@ +# 强化学习 + +*强化学习通过试错法最大化累积奖励来训练智能体做出序列决策。本文件涵盖MDP、价值函数、贝尔曼方程、Q学习、策略梯度、演员-评论家方法、PPO和RLHF——这些是游戏智能体和语言模型对齐背后的框架。* + +- 监督学习需要标注数据。无监督学习在无标注数据中发现模式。**强化学习(RL)** 与两者都不同:智能体通过与环境的交互、采取行动和接收奖励来学习。没有正确的标签;智能体必须通过试错来发现好的行为。 + +- 想象教狗一个新把戏。你不会给它展示一个正确行为的数据集。相反,它尝试各种动作,你对好的行为给予奖励,随着时间的推移它明白了你想要什么。RL将这个形式化。 + +- RL设置包含五个核心组件。**智能体(agent)** 是学习者和决策者。**环境(environment)** 是智能体之外与之交互的一切。在每个时间步,智能体观察一个**状态(state)** $s_t$,选择一个**动作(action)** $a_t$,接收一个**奖励(reward)** $r_t$,并转移到新状态 $s_{t+1}$。智能体的目标是最大化其随时间收集的总奖励。 + +![智能体-环境循环:智能体观察状态,采取动作,接收奖励,环境转移到新状态](../images/mdp_agent_loop.svg) + +- **策略(policy)** $\pi$ 是智能体的策略:从状态到动作的映射。确定性策略对每个状态给出一个动作:$a = \pi(s)$。随机策略给出动作上的概率分布:$\pi(a \mid s)$。RL的目标是找到最优策略,即最大化期望累积奖励的策略。 + +- RL的数学框架是**马尔可夫决策过程(MDP)**,由元组 $(S, A, P, R, \gamma)$ 定义:一组状态 $S$,一组动作 $A$,转移概率 $P(s' \mid s, a)$,奖励函数 $R(s, a)$,以及折扣因子 $\gamma$。 + +- **马尔可夫性质**(来自第05章)指出未来仅取决于当前状态,而不是如何到达那里的历史:$P(s_{t+1} \mid s_t, a_t, s_{t-1}, \ldots) = P(s_{t+1} \mid s_t, a_t)$。这意味着状态包含了做出决策所需的全部信息。 + +- **折扣因子** $\gamma \in [0, 1)$ 决定了智能体对未来奖励相对于即时奖励的重视程度。从时间 $t$ 开始的折扣回报为: + +$$G_t = r_t + \gamma r_{t+1} + \gamma^2 r_{t+2} + \cdots = \sum_{k=0}^{\infty} \gamma^k r_{t+k}$$ + +- 当 $\gamma = 0$ 时,智能体完全短视,只关心下一个奖励。当 $\gamma$ 接近1时,智能体具有长远眼光。折扣因子还确保了求和收敛(如果奖励有界),这对数学上的良定义性很重要。 + +- **价值函数**估计处于某个状态(或在某个状态下采取某个动作)有多好。**状态价值函数** $V^\pi(s)$ 是从状态 $s$ 开始并按照策略 $\pi$ 行动所获得的期望回报: + +$$V^\pi(s) = \mathbb{E}_\pi \left[ G_t \mid s_t = s \right]$$ + +- **动作价值函数** $Q^\pi(s, a)$ 是从状态 $s$ 开始,采取动作 $a$,然后按照 $\pi$ 行动所获得的期望回报: + +$$Q^\pi(s, a) = \mathbb{E}_\pi \left[ G_t \mid s_t = s, a_t = a \right]$$ + +- 两者关系:$V^\pi(s) = \sum_a \pi(a \mid s) \, Q^\pi(s, a)$。状态价值是动作价值按策略加权的平均值。 + +- **贝尔曼方程**表达了递归关系:一个状态的价值等于即时奖励加上下一个状态的折扣价值。对于状态价值函数: + +$$V^\pi(s) = \sum_a \pi(a \mid s) \sum_{s'} P(s' \mid s, a) \left[ R(s, a) + \gamma \, V^\pi(s') \right]$$ + +- 对于最优价值函数 $V^{*}(s)$,智能体总是选择最佳动作: + +$$V^{*}(s) = \max_a \sum_{s'} P(s' \mid s, a) \left[ R(s, a) + \gamma \, V^{*}(s') \right]$$ + +- 类似地,$Q^{*}$ 的**贝尔曼最优方程**为: + +$$Q^{*}(s, a) = \sum_{s'} P(s' \mid s, a) \left[ R(s, a) + \gamma \max_{a'} Q^{*}(s', a') \right]$$ + +- 一旦你有了 $Q^{*}$,最优策略就很简单了:总是选择Q值最高的动作:$\pi^{*}(s) = \arg\max_a Q^{*}(s, a)$。 + +- **动态规划**方法在已知转移概率和奖励(完整模型)时求解MDP。**策略评估**通过迭代应用贝尔曼方程直到收敛来计算给定策略的 $V^\pi$。**策略改进**利用价值函数并通过对最优动作贪心来构建更好的策略:$\pi'(s) = \arg\max_a \sum_{s'} P(s' \mid s, a)[R(s,a) + \gamma V^\pi(s')]$。 + +- **策略迭代**在评估和改进之间交替,直到策略停止变化。它保证收敛到最优策略。 + +- **价值迭代**将两个步骤合并为一个:重复应用贝尔曼最优方程直到 $V^{*}$ 收敛,然后提取策略。 + +$$V(s) \leftarrow \max_a \sum_{s'} P(s' \mid s, a) \left[ R(s, a) + \gamma \, V(s') \right]$$ + +- 动态规划需要知道 $P(s' \mid s, a)$,这通常不可行。在大多数真实问题中,智能体不知道环境的动态;它只能与环境交互。这就是**无模型**方法发挥作用的地方。 + +- **时序差分(TD)学习**在不了解模型的情况下从经验中学习。关键思想是**引导(bootstrapping)**:不等情节结束才计算实际回报 $G_t$,而是使用当前的价值函数对其进行估计: + +$$V(s_t) \leftarrow V(s_t) + \alpha \left[ r_t + \gamma \, V(s_{t+1}) - V(s_t) \right]$$ + +- 括号中的项是**TD误差**:**TD目标**($r_t + \gamma V(s_{t+1})$)与当前估计 $V(s_t)$ 之间的差异。如果TD误差为正,说明该状态比预期好,我们增加其价值。如果为负,则减少其价值。 + +![状态转移展示TD目标:当前价值、奖励以及引导的下一状态价值,附更新公式](../images/td_update.svg) + +- TD学习在每一步之后(而不是完成整个情节后)进行更新,这使其比蒙特卡洛方法高效得多。它也适用于持续(非情节式)环境。 + +- **SARSA**(状态-动作-奖励-状态-动作)是将TD学习应用于Q值。智能体在状态 $s$ 下采取动作 $a$,观察奖励 $r$ 和下一状态 $s'$,然后根据其策略选择下一个动作 $a'$: + +$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \, Q(s', a') - Q(s, a) \right]$$ + +- SARSA是**在策略(on-policy)**:它使用智能体实际采取的动作进行更新,这包括了探索。这使得SARSA更为保守;它学习一个考虑自身探索噪声的策略。 + +- **Q学习**是最著名的RL算法。它类似于SARSA,但不同的是它使用最佳可能动作而非智能体实际采取的动作: + +$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$ + +- Q学习是**离策略(off-policy)**:它学习最优Q值,与正在执行的策略无关。智能体可以随机探索,同时仍然学习最优动作价值。这使得Q学习更具攻击性,通常收敛更快,但可能高估值。 + +- **探索 vs 利用**是基本困境:智能体应该利用已知信息(选择估计价值最高的动作)还是探索未知动作(可能发现更好的)? + +- 最简单的策略是**ε-贪心**:以概率 $\epsilon$ 采取随机动作(探索);以概率 $1 - \epsilon$ 采取贪心动作(利用)。一种常见的时间表是从高 $\epsilon$(大量探索)开始,随时间衰减。 + +- 表格方法(在表中存储每个状态-动作对的价值)适用于小的离散状态空间。对于大或连续的状态空间,需要函数近似。**深度Q网络(DQN)** 使用神经网络来近似 $Q(s, a; \theta)$,其中 $\theta$ 是网络权重。 + +- DQN引入了两个关键的稳定技术。**经验回放**:不是从连续的转移中学习(高度相关),而是将转移存储在回放缓冲区中,并采样随机小批次进行训练。这打破了相关性并高效地重用数据。 + +- **目标网络**:使用一个单独的、缓慢更新的网络副本来计算TD目标。没有这个,每次更新网络时目标都会移动,造成"追自己尾巴"的不稳定性。目标网络定期更新(每 $N$ 步硬更新)或连续更新(软更新:$\theta^{-} \leftarrow \tau\theta + (1-\tau)\theta^{-}$)。 + +- DQN损失只是预测Q值与TD目标之间的均方误差: + +$$\mathcal{L}(\theta) = \mathbb{E} \left[ \left( r + \gamma \max_{a'} Q(s', a'; \theta^{-}) - Q(s, a; \theta) \right)^2 \right]$$ + +- 到目前为止的所有方法都学习价值函数并从中推导策略。**策略梯度**方法采用不同方法:它们直接参数化策略 $\pi(a \mid s; \theta)$ 并通过梯度上升优化期望回报。 + +- **策略梯度定理**给出了期望回报相对于策略参数的梯度: + +$$\nabla_\theta J(\theta) = \mathbb{E}_\pi \left[ \nabla_\theta \log \pi(a \mid s; \theta) \cdot G_t \right]$$ + +- 这说明:增加导致高回报的动作的概率,减少导致低回报的动作的概率。对数概率梯度给出了改变策略的方向,$G_t$ 则缩放改变的程度。 + +- **REINFORCE**是最简单的策略梯度算法。运行一个情节,为每一步计算回报 $G_t$,然后更新: + +$$\theta \leftarrow \theta + \alpha \, \nabla_\theta \log \pi(a_t \mid s_t; \theta) \cdot G_t$$ + +- REINFORCE方差很高,因为 $G_t$ 是期望回报的噪声单样本估计。一个常见修复是减去一个**基线(baseline)**(通常是平均回报或学习到的价值函数)来降低方差而不引入偏差: + +$$\theta \leftarrow \theta + \alpha \, \nabla_\theta \log \pi(a_t \mid s_t; \theta) \cdot (G_t - b)$$ + +- **演员-评论家(Actor-Critic)** 方法使用两个网络。**演员(actor)** 是策略 $\pi(a \mid s; \theta)$。**评论家(critic)** 是价值函数 $V(s; \phi)$,作为基线。优势 $A_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ 替代了 $G_t - b$: + +$$\theta \leftarrow \theta + \alpha \, \nabla_\theta \log \pi(a_t \mid s_t; \theta) \cdot A_t$$ + +- 评论家通过最小化TD误差来更新,与基于价值的方法相同。演员使用策略梯度更新,评论家的优势估计降低了方差。这是两全其美。 + +![双头架构:演员输出动作概率,评论家输出价值估计,优势信号指导演员更新](../images/actor_critic.svg) + +- **PPO**(近端策略优化)是实践中使用最广泛的策略梯度算法。它解决了一个关键问题:如果策略更新过大,性能可能灾难性地崩溃。 + +- PPO使用一个**裁剪的替代目标**。令 $r_t(\theta) = \frac{\pi(a_t | s_t; \theta)}{\pi(a_t | s_t; \theta_{\text{old}})}$ 为新旧策略之间的概率比。损失为: + +$$\mathcal{L}^{\text{CLIP}}(\theta) = \mathbb{E} \left[ \min\!\left( r_t(\theta) A_t, \; \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right]$$ + +- 裁剪(通常 $\epsilon = 0.2$)防止比率远离1,使更新保持小而稳定。如果优势为正(动作好),比率上限为 $1 + \epsilon$。如果为负(动作差),比率下限为 $1 - \epsilon$。这比早期的信任区域方法(TRPO)更简单、更稳定。 + +- PPO被用于通过**RLHF**(基于人类反馈的强化学习)训练ChatGPT风格的模型。在RLHF中,一个奖励模型在人类偏好数据(人类更喜欢两个输出中的哪一个?)上训练,然后PPO优化语言模型策略以最大化这个学习到的奖励。 + +- **DPO**(直接偏好优化)通过完全消除奖励模型来简化RLHF。DPO不训练奖励模型然后运行RL,而是推导出一个闭式损失,直接从偏好数据优化策略: + +$$\mathcal{L}_{\text{DPO}}(\theta) = -\mathbb{E} \left[ \log \sigma\!\left( \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)} \right) \right]$$ + +- 这里 $y_w$ 是偏好的(胜出)回答,$y_l$ 是不被偏好的(失败)回答。DPO增加偏好输出的相对概率,并且比基于PPO的RLHF实现起来简单得多。 + +- RL算法中有两个重要区分。**在策略 vs 离策略**:在策略方法(SARSA, PPO)从当前策略生成的数据中学习;离策略方法(Q学习, DQN)可以从任何策略生成的数据中学习。离策略方法样本效率更高(它们重用旧数据),但可能不那么稳定。 + +- **基于模型 vs 无模型**:无模型方法(到目前为止讨论的所有方法)直接从经验中学习价值或策略。基于模型的方法学习环境的模型($P(s' \mid s, a)$ 和 $R(s, a)$)并用其进行规划(想象未来的轨迹而不实际采取动作)。基于模型的方法样本效率更高,但增加了学习精确模型的复杂性。 + +- 总结RL领域: + +| 方法 | 类型 | 核心思想 | 优势 | +|---|---|---|---| +| 价值迭代 | DP, 基于模型 | 贝尔曼最优性 | 精确解(小MDP) | +| SARSA | TD, 在策略 | 在策略学习Q | 保守、安全 | +| Q学习 | TD, 离策略 | 学习Q*, 贪心目标 | 简单、有效 | +| DQN | 深度, 离策略 | 神经Q + 回放 + 目标网络 | 扩展到高维状态 | +| REINFORCE | 策略梯度 | log-概率 * 回报的梯度 | 简单的策略优化 | +| 演员-评论家 | PG + 价值 | 演员 + 评论家降低方差 | 实用且灵活 | +| PPO | PG, 裁剪 | 信任区域般的稳定性 | 行业标准 | +| DPO | 直接偏好 | 跳过奖励模型 | 更简单的RLHF | + +## 编程任务(使用CoLab或笔记本) + +1. 为简单的网格世界实现价值迭代。计算最优价值函数并提取最优策略。将两者可视化为热力图和箭头图。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 4x4网格世界:目标在(3,3),每步奖励-1,目标处为0 +grid_size = 4 +gamma = 0.99 +goal = (3, 3) + +# 动作:上、下、左、右 +actions = [(-1, 0), (1, 0), (0, -1), (0, 1)] +action_names = ['up', 'down', 'left', 'right'] +action_arrows = ['\u2191', '\u2193', '\u2190', '\u2192'] + +def step(s, a): + """确定性转移。""" + ns = (max(0, min(grid_size-1, s[0]+a[0])), + max(0, min(grid_size-1, s[1]+a[1]))) + return ns + +# 价值迭代 +V = jnp.zeros((grid_size, grid_size)) +for iteration in range(100): + V_new = jnp.array(V) + for i in range(grid_size): + for j in range(grid_size): + if (i, j) == goal: + continue + values = [] + for a in actions: + ns = step((i, j), a) + values.append(-1 + gamma * float(V[ns[0], ns[1]])) + V_new = V_new.at[i, j].set(max(values)) + if jnp.max(jnp.abs(V_new - V)) < 1e-6: + print(f"在{iteration+1}次迭代后收敛") + break + V = V_new + +# 提取策略 +policy = [['' for _ in range(grid_size)] for _ in range(grid_size)] +for i in range(grid_size): + for j in range(grid_size): + if (i, j) == goal: + policy[i][j] = 'G' + continue + best_a = max(range(4), key=lambda a: -1 + gamma * float(V[step((i,j), actions[a])[0], step((i,j), actions[a])[1]])) + policy[i][j] = action_arrows[best_a] + +fig, axes = plt.subplots(1, 2, figsize=(10, 4)) +im = axes[0].imshow(V, cmap='YlOrRd_r') +axes[0].set_title("最优价值函数") +for i in range(grid_size): + for j in range(grid_size): + axes[0].text(j, i, f"{V[i,j]:.1f}", ha='center', va='center', fontsize=10) +plt.colorbar(im, ax=axes[0]) + +axes[1].imshow(jnp.ones((grid_size, grid_size)), cmap='Greys', vmin=0, vmax=2) +axes[1].set_title("最优策略") +for i in range(grid_size): + for j in range(grid_size): + axes[1].text(j, i, policy[i][j], ha='center', va='center', fontsize=18) +plt.tight_layout(); plt.show() +``` + +2. 在简单的网格世界上实现表格Q学习。训练智能体,绘制学习曲线,显示学习到的Q值。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +grid_size = 5 +goal = (4, 4) +actions = [(-1,0), (1,0), (0,-1), (0,1)] + +# Q表 +Q = {} +for i in range(grid_size): + for j in range(grid_size): + Q[(i,j)] = [0.0] * 4 + +alpha = 0.1 +gamma = 0.95 +epsilon = 1.0 +epsilon_decay = 0.995 +min_epsilon = 0.01 + +def step(s, a_idx): + a = actions[a_idx] + ns = (max(0, min(grid_size-1, s[0]+a[0])), + max(0, min(grid_size-1, s[1]+a[1]))) + r = 0.0 if ns == goal else -1.0 + done = ns == goal + return ns, r, done + +key = jax.random.PRNGKey(42) +rewards_per_episode = [] + +for ep in range(500): + s = (0, 0) + total_reward = 0 + for _ in range(100): + key, subkey = jax.random.split(key) + if float(jax.random.uniform(subkey)) < epsilon: + key, subkey = jax.random.split(key) + a = int(jax.random.randint(subkey, (), 0, 4)) + else: + a = max(range(4), key=lambda i: Q[s][i]) + + ns, r, done = step(s, a) + total_reward += r + # Q学习更新 + Q[s][a] += alpha * (r + gamma * max(Q[ns]) - Q[s][a]) + s = ns + if done: + break + rewards_per_episode.append(total_reward) + epsilon = max(min_epsilon, epsilon * epsilon_decay) + +plt.figure(figsize=(8, 4)) +# 平滑曲线 +window = 20 +smoothed = [sum(rewards_per_episode[max(0,i-window):i+1])/min(i+1, window) + for i in range(len(rewards_per_episode))] +plt.plot(smoothed, color='#3498db', linewidth=1.5) +plt.xlabel("Episode"); plt.ylabel("Total Reward (smoothed)") +plt.title("Q-Learning on Gridworld") +plt.grid(alpha=0.3); plt.show() + +# 显示学到的策略 +arrow = ['\u2191', '\u2193', '\u2190', '\u2192'] +print("学到的策略:") +for i in range(grid_size): + row = "" + for j in range(grid_size): + if (i,j) == goal: + row += " G " + else: + row += f" {arrow[max(range(4), key=lambda a: Q[(i,j)][a])]} " + print(row) +``` + +3. 在多臂老虎机问题上实现REINFORCE。展示策略如何随训练演变以偏向最佳臂。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 5臂老虎机,不同期望奖励 +true_rewards = jnp.array([0.2, 0.5, 0.8, 0.3, 0.1]) +n_arms = len(true_rewards) + +# 策略:在logits上的softmax +logits = jnp.zeros(n_arms) +lr = 0.1 +key = jax.random.PRNGKey(42) + +policy_history = [] +reward_history = [] + +for step in range(2000): + probs = jax.nn.softmax(logits) + policy_history.append(probs) + + # 采样动作 + key, subkey = jax.random.split(key) + action = jax.random.choice(subkey, n_arms, p=probs) + + # 获取奖励(伯努利分布) + key, subkey = jax.random.split(key) + reward = float(jax.random.uniform(subkey) < true_rewards[action]) + reward_history.append(reward) + + # REINFORCE更新 + # grad log pi(a) = e_a - probs(对于softmax参数化) + grad_log_pi = -probs.at[action].add(1.0) # one-hot(a) - probs + logits = logits + lr * reward * grad_log_pi + +policy_history = jnp.stack(policy_history) + +fig, axes = plt.subplots(1, 2, figsize=(12, 4)) +colors = ['#3498db', '#e74c3c', '#27ae60', '#9b59b6', '#f39c12'] +for i in range(n_arms): + axes[0].plot(policy_history[:, i], color=colors[i], + label=f'臂{i} (真实={true_rewards[i]:.1f})', linewidth=1.5) +axes[0].set_xlabel("步骤"); axes[0].set_ylabel("P(臂)") +axes[0].set_title("策略演变 (REINFORCE)") +axes[0].legend(fontsize=8); axes[0].grid(alpha=0.3) + +# 平滑奖励 +window = 50 +smoothed = [sum(reward_history[max(0,i-window):i+1])/min(i+1,window) + for i in range(len(reward_history))] +axes[1].plot(smoothed, color='#27ae60', linewidth=1.5) +axes[1].axhline(y=0.8, color='#e74c3c', linestyle='--', alpha=0.5, label='最佳臂') +axes[1].set_xlabel("步骤"); axes[1].set_ylabel("平均奖励") +axes[1].set_title("奖励随时间变化"); axes[1].legend() +axes[1].grid(alpha=0.3) +plt.tight_layout(); plt.show() +``` diff --git a/chapter 06: machine learning/05. distributed deep learning.md b/chapter 06: machine learning/05. distributed deep learning.md new file mode 100644 index 0000000..f39a409 --- /dev/null +++ b/chapter 06: machine learning/05. distributed deep learning.md @@ -0,0 +1,264 @@ +# 分布式深度学习 + +*分布式训练将计算分散到多个GPU和机器上,以训练单个设备无法容纳或训练太慢的模型。本文件涵盖混合精度、数据并行、模型并行、流水线并行、ZeRO、FSDP、张量并行以及全规约等通信原语——这些对于大规模训练LLM至关重要。* + +- 在单个GPU上训练大型神经网络最终会遇到瓶颈。模型可能无法放入内存,或者训练可能需要数月。分布式训练将工作分散到多个设备(GPU、TPU或整台机器)上,以更快地训练和训练更大的模型。本文件涵盖了实现这一目标的技术。 + +- 要理解为何分布式重要,从训练的**计算成本**开始。在一个包含 $d_{\text{in}}$ 个输入和 $d_{\text{out}}$ 个输出的密集层上,对一批 $B$ 个样本进行一次前向传播需要大约 $2 \cdot B \cdot d_{\text{in}} \cdot d_{\text{out}}$ 次FLOP(浮点运算):对输出矩阵的每个元素进行一次乘法和一次加法。反向传播的成本大约是前向传播的两倍(计算相对于输入和权重的梯度),因此一个密集层的一个训练步骤约为 $6 \cdot B \cdot d_{\text{in}} \cdot d_{\text{out}}$ 次FLOP。 + +- 对于隐藏维度为 $d$ 的Transformer层,自注意力块涉及四个投影(Q、K、V和输出),每个的成本为 $O(B \cdot n \cdot d^2)$ 次FLOP(其中 $n$ 是序列长度),加上注意力矩阵计算 $O(B \cdot n^2 \cdot d)$。前馈块有两个密集层,通常扩展到 $4d$ 再回来:$O(B \cdot n \cdot 8d^2)$。每层总计:大约 $O(B \cdot n \cdot 12d^2 + B \cdot n^2 \cdot d)$。乘以层数,你就会明白为什么训练GPT规模的模型需要数千个GPU小时。 + +- **内存墙**通常是更严格的约束。在训练期间,GPU内存必须同时容纳四样东西: + +![堆叠柱状图展示训练内存分解:参数、梯度、优化器状态、激活值](../images/training_memory_breakdown.svg) + +- **参数**:模型权重。一个70亿参数的模型在FP32中(每个参数4字节)仅权重就需要28 GB。 +- **梯度**:与参数大小相同。又是28 GB。 +- **优化器状态**:Adam维护两个额外的缓冲区(一阶和二阶矩估计),每个与参数大小相同。即使模型使用较低精度,这些也以FP32格式保存以确保数值稳定性。对于我们的7B模型,那就是 $2 \times 28 = 56$ GB。 +- **激活值**:在前向传播过程中保存下来供反向传播使用的中间值。大小取决于批量大小、序列长度和模型宽度。这通常是最主要的组成部分,并随批量大小线性增长。 + +- 对于使用FP32 Adam的7B模型:28(参数)+ 28(梯度)+ 56(优化器)= 112 GB,这还没算激活值。单个80 GB的A100 GPU无法容纳。这就是分布式策略至关重要的原因。 + +- **混合精度训练**是第一道防线。不是将所有内容存储在FP32(32位浮点)中,而是使用FP16或BF16(16位)进行前向和反向传播,同时将权重的FP32主副本保留给优化器更新。 + +- **FP16**具有高精度(10位尾数),但范围有限,可能导致上溢/下溢。损失缩放(在反向传播前将损失乘以一个大因子,然后将梯度除以相同因子)缓解了这个问题。 + +- **BF16**(脑浮点)具有与FP32相同的指数范围(8位指数),但精度较低(7位尾数)。它几乎从不溢出,很少需要损失缩放,因此使用更简单。BF16是现代Transformer训练的默认选择。 + +- 混合精度大致将激活值和梯度的内存减半(前向/反向传播期间的主要成本),同时将优化器状态保留在FP32中以确保数值稳定性。 + +- **数据并行**是最简单的分布式策略。你在 $N$ 个GPU上复制整个模型,将每个小批量分成 $N$ 个相等的块,并将一个块发送到每个GPU。每个GPU在其块上独立运行前向和反向传播。然后梯度在所有GPU上平均(使用全规约操作),每个GPU更新其本地模型副本。 + +- 从模型的角度来看,这相当于使用大了 $N$ 倍的小批量进行训练。如果每个GPU处理一个大小为 $B$ 的批次,则有效批量大小为 $N \cdot B$。 + +![并排比较:数据并行复制模型并分割数据,模型并行分割模型并分享数据](../images/data_model_parallelism.svg) + +- 梯度平均可以同步或异步进行。**同步SGD**等待所有GPU完成后再进行平均,确保与使用更大批量的单GPU训练数学上等价。缺点是,最慢的GPU("掉队者")会拖慢所有人。 + +- **异步SGD**让每个GPU独立地更新一个共享的参数服务器,无需等待。这消除了掉队者问题,但引入了"陈旧梯度":一个GPU可能基于略微过时的参数计算梯度。陈旧梯度增加了噪声,可能减缓收敛。在实践中,带高效通信的同步SGD更受青睐。 + +- **梯度累积**是一种软件技巧,用于在有限硬件上模拟更大的批量大小。不必每个小批量做一次更新,而是运行多次前向/反向传播并累积梯度,然后做一次更新。这与更大批量得到相同的结果,而无需更多GPU内存用于激活值(一次只有一个小批量的激活值在内存中)。 + +- 当模型本身太大无法放入单个GPU时,需要**模型并行**。有两种主要形式。 + +- **张量并行**将单个层分割到多个GPU上。一个大的矩阵乘法 $Y = XW$ 可以按列分割:将 $W$ 分区为 $[W_1, W_2]$ 分布在两个GPU上,并行计算 $Y_1 = XW_1$ 和 $Y_2 = XW_2$,然后拼接。这适用于注意力投影和前馈层。它需要GPU之间快速通信(通常是节点内的NVLink),因为每层都必须组合部分结果。 + +- **流水线并行**将不同的层分配到不同的GPU上。GPU 0运行第1-4层,GPU 1运行第5-8层,依此类推。数据像流水线一样流经整个管道。朴素的方法有一个"流水线气泡":当GPU 0处理微批次1的前向传播时,GPU 1-3处于空闲状态。**微批处理**通过将小批量分割成更小的微批次来缓解这个问题,这些微批次按顺序流经流水线,使所有GPU大部分时间保持忙碌。 + +- **混合并行**结合了数据并行、张量并行和流水线并行。一个典型的大模型设置可能使用节点内的张量并行(8个GPU通过快速NVLink连接)、跨节点的流水线并行以及跨节点组的数据并行。这就是GPT-4和Llama等模型的训练方式。 + +- 分布式训练的效率在很大程度上取决于**通信**。关键操作是**全规约(all-reduce)**:给定 $N$ 个GPU上各有一个值,计算总和(或平均值)并将结果分发给所有GPU。 + +- 朴素的全规约将所有数据发送到一个GPU,求和,然后广播回来。通信量为 $O(N)$,并在根节点造成瓶颈。 + +- **环全规约(Ring all-reduce)** 要高效得多。将 $N$ 个GPU排列成一个环。每个GPU将其数据分割成 $N$ 块。在 $N - 1$ 步中,每个GPU向邻居发送一块,并从另一个邻居接收一块,累加部分和。再经过 $N - 1$ 步后,完整的总和传播到所有GPU。每个GPU的总数据传输量:数据大小的 $2(N-1)/N$ 倍,随着 $N$ 的增长趋近于 $2\times$。关键在于,这不随 $N$ 增加,使其带宽最优。 + +![四个GPU排列成环,每个将梯度块传递给邻居,直到所有GPU都得到完整总和](../images/ring_allreduce.svg) + +- **参数服务器**是一种替代架构,其中专用服务器节点保存模型参数。工作节点计算梯度并将其发送到服务器,服务器更新参数并将其发送回来。这更简单,但可能在服务器处造成通信瓶颈。 + +- **NCCL**(NVIDIA集合通信库)是GPU间通信的标准库。它提供了全规约、全收集、广播和其他集合操作的高效实现,自动为网络拓扑选择最佳算法。 + +- **缩放定律**描述了模型性能如何随计算量、数据量和模型大小而提升。原始的Kaplan等人(2020)缩放定律发现,损失随每个因素以幂律方式下降: + +$$L(N) \propto N^{-\alpha_N}, \quad L(D) \propto D^{-\alpha_D}, \quad L(C) \propto C^{-\alpha_C}$$ + +- 其中 $N$ 是参数数量,$D$ 是数据集大小,$C$ 是计算预算。 + +- **Chinchilla缩放定律**(Hoffmann等人,2022)表明大多数模型训练不足:对于给定的计算预算,应该训练一个更小的模型,使用比以前认为的更多的数据。最优比例大约是每参数20个token。一个7B模型应该看到大约140B个token,而不是Llama 1在65B模型上使用的300B个token。这一发现将领域转向了"计算最优"训练。 + +- **混合专家(MoE)** 是一种在不按比例增加计算量的情况下扩展模型容量的架构。每个Transformer层不是使用一个前馈网络,而是有 $N$ 个"专家"网络(每个都是一个标准FFN)。一个**门控网络**(路由器)检查每个token并将其发送到top-$K$个专家(通常 $K = 1$ 或 $K = 2$)。 + +![token通过门控网络路由到选定的专家,使用top-K稀疏路由和加权输出组合](../images/moe_routing.svg) + +- 总参数量要大得多(因为有 $N$ 个专家),但每个token的FLOPs大致保持不变(因为每个token只有 $K$ 个专家激活)。例如,Mixtral 8x7B共有47B个参数,但每次前向传播只用大约13B,以较小模型的代价获得更大模型的性能。 + +- MoE带来了挑战。**负载均衡**:如果路由器将大多数token发送到同一个专家,其他专家就被浪费了。辅助损失鼓励均匀路由。**通信**:不同的专家可能位于不同的GPU上,因此路由token需要全对全通信,这很昂贵。 + +- **容错**在训练运行持续数周或数月、涉及数千个GPU时至关重要。如果单个GPU失效,你不想丢失所有进度。**检查点**定期将模型权重、优化器状态和训练状态(学习率、步数、数据位置)保存到磁盘。如果发生故障,你可以从最近的检查点重新开始。 + +- **梯度检查点**(也称为激活重计算)是一种内存优化,而非容错机制。在前向传播过程中,不是保存所有激活值供反向传播使用,而是只在某些检查点保存激活值。在反向传播过程中,从检查点重新计算缺失的激活值。这以计算换取内存:它使前向传播成本增加约33%,但可以将激活内存减少 $\sqrt{L}$ 倍(其中 $L$ 是层数)。 + +- 综合起来,训练前沿模型结合了所有这些技术:BF16混合精度、使用环全规约在数千个GPU上进行数据并行、节点内的张量并行、跨节点的流水线并行、减少内存的梯度检查点、提高参数效率的MoE,以及用于容错的定期检查点。系统工程与算法设计一样具有挑战性。 + +- 总结分布式训练工具包: + +| 技术 | 作用 | 权衡 | +|---|---|---| +| 混合精度 (BF16) | 将激活值/梯度的内存减半 | 轻微数值差异 | +| 数据并行 | 在GPU间扩展批量大小 | 梯度同步的通信开销 | +| 张量并行 | 在GPU间分割层 | 需要快速互联 | +| 流水线并行 | 在GPU间分割模型阶段 | 流水线气泡(计算浪费) | +| 梯度累积 | 模拟大批量 | 更慢(多次前向/反向传播) | +| 梯度检查点 | 减少激活内存 | 约多33%计算 | +| 环全规约 | 高效的梯度平均 | 大模型受限于带宽 | +| MoE | 更多容量,相同FLOPs | 负载均衡、路由复杂性 | +| 缩放定律 | 指导计算分配 | 经验公式,未必在所有规模都成立 | + +## 编程任务(使用CoLab或笔记本) + +1. 计算Transformer层的FLOPs和内存需求。给定隐藏维度 $d$、序列长度 $n$、批量大小 $B$ 和层数,估计总训练成本。 +```python +import jax.numpy as jnp + +def transformer_layer_flops(d, n, B): + """一个Transformer层前向传播的近似FLOPs。""" + # QKV投影:3 * (B * n * d * d) * 2(乘法-加法) + qkv_flops = 3 * 2 * B * n * d * d + # 注意力:(B * n * n * d) * 2 用于QK^T,(B * n * n * d) * 2 用于attn*V + attn_flops = 2 * 2 * B * n * n * d + # 输出投影:(B * n * d * d) * 2 + out_flops = 2 * B * n * d * d + # FFN:两层,d->4d 和 4d->d:2 * (B * n * d * 4d) * 2 + ffn_flops = 2 * 2 * B * n * d * 4 * d + return qkv_flops + attn_flops + out_flops + ffn_flops + +def transformer_layer_memory(d, n, B, dtype_bytes=2): + """一个层的近似激活内存(字节)。""" + # QKV:3 * B * n * d + qkv_mem = 3 * B * n * d * dtype_bytes + # 注意力权重:B * heads * n * n(近似 B * n * n * sizeof) + attn_mem = B * n * n * dtype_bytes + # FFN中间值:B * n * 4d + ffn_mem = B * n * 4 * d * dtype_bytes + return qkv_mem + attn_mem + ffn_mem + +# 示例:GPT-2规模 +d, n, B, L = 1024, 1024, 8, 24 +fwd_flops = transformer_layer_flops(d, n, B) +total_flops = 3 * L * fwd_flops # 前向+反向的3倍 +act_mem = L * transformer_layer_memory(d, n, B) +param_count = L * (12 * d * d + 13 * d) # 近似 + +print(f"模型:d={d}, n={n}, B={B}, L={L}") +print(f"参数:{param_count / 1e6:.0f}M") +print(f"每步FLOPs:{total_flops / 1e12:.2f} TFLOPs") +print(f"激活内存:{act_mem / 1e9:.2f} GB (BF16)") +print(f"参数内存 (FP32):{param_count * 4 / 1e9:.2f} GB") +print(f"Adam优化器内存:{param_count * 8 / 1e9:.2f} GB") +print(f"总训练内存:{(param_count * 16 + act_mem) / 1e9:.2f} GB") +``` + +2. 模拟数据并行训练。将数据集分割到多个"虚拟GPU"上,独立计算梯度,平均它们,并验证结果与单GPU训练匹配。 +```python +import jax +import jax.numpy as jnp + +# 简单线性模型:y = wx + b +key = jax.random.PRNGKey(0) +X = jax.random.normal(key, (64, 4)) +w_true = jnp.array([1.0, -2.0, 3.0, 0.5]) +y = X @ w_true + 0.1 * jax.random.normal(key, (64,)) + +def loss_fn(w, X, y): + return jnp.mean((X @ w - y) ** 2) + +grad_fn = jax.grad(loss_fn) + +# 单GPU:全批量梯度 +w = jnp.zeros(4) +grad_single = grad_fn(w, X, y) + +# 数据并行:分割到4个"GPU"上 +n_gpus = 4 +chunk_size = len(X) // n_gpus +grads = [] +for i in range(n_gpus): + X_chunk = X[i*chunk_size:(i+1)*chunk_size] + y_chunk = y[i*chunk_size:(i+1)*chunk_size] + grads.append(grad_fn(w, X_chunk, y_chunk)) + +# 全规约:平均梯度 +grad_parallel = jnp.mean(jnp.stack(grads), axis=0) + +print("单GPU梯度:", grad_single) +print("数据并行梯度(平均):", grad_parallel) +print(f"匹配:{jnp.allclose(grad_single, grad_parallel, atol=1e-5)}") + +# 训练两者并比较 +w_single, w_parallel = jnp.zeros(4), jnp.zeros(4) +lr = 0.1 +for step in range(100): + w_single = w_single - lr * grad_fn(w_single, X, y) + + grads = [grad_fn(w_parallel, X[i*chunk_size:(i+1)*chunk_size], + y[i*chunk_size:(i+1)*chunk_size]) for i in range(n_gpus)] + avg_grad = jnp.mean(jnp.stack(grads), axis=0) + w_parallel = w_parallel - lr * avg_grad + +print(f"\n100步之后:") +print(f"单GPU权重:{w_single}") +print(f"数据并行权重:{w_parallel}") +print(f"最大差异:{jnp.max(jnp.abs(w_single - w_parallel)):.2e}") +``` + +3. 实现一个简单的混合专家层。创建一个门控网络,将token路由到top-K个专家并组合它们的输出。 +```python +import jax +import jax.numpy as jnp + +def expert_fn(x, W1, b1, W2, b2): + """简单的2层FFN专家。""" + h = jnp.maximum(0, x @ W1 + b1) # ReLU + return h @ W2 + b2 + +def moe_layer(x, gate_W, experts_params, top_k=2): + """ + MoE前向传播。 + x: (batch, d_model) + gate_W: (d_model, n_experts) + experts_params: 每个专家的 (W1, b1, W2, b2) 列表 + """ + n_experts = len(experts_params) + + # 门控:计算路由分数 + gate_logits = x @ gate_W # (batch, n_experts) + gate_probs = jax.nn.softmax(gate_logits, axis=-1) + + # Top-K选择 + top_k_indices = jnp.argsort(-gate_probs, axis=-1)[:, :top_k] + top_k_probs = jnp.take_along_axis(gate_probs, top_k_indices, axis=-1) + # 重新归一化 + top_k_probs = top_k_probs / jnp.sum(top_k_probs, axis=-1, keepdims=True) + + # 计算专家输出(简化:运行所有专家,稍后掩码) + expert_outputs = jnp.stack([ + expert_fn(x, *experts_params[i]) for i in range(n_experts) + ], axis=1) # (batch, n_experts, d_model) + + # 收集top-K专家输出并加权 + batch_idx = jnp.arange(x.shape[0])[:, None] + selected_outputs = expert_outputs[batch_idx, top_k_indices] # (batch, top_k, d_model) + output = jnp.sum(selected_outputs * top_k_probs[:, :, None], axis=1) + + return output, gate_probs + +# 设置 +key = jax.random.PRNGKey(42) +batch, d_model, d_ff, n_experts = 8, 16, 32, 4 + +# 初始化专家 +experts_params = [] +for i in range(n_experts): + k1, k2, key = jax.random.split(key, 3)[0], jax.random.split(key, 3)[1], jax.random.split(key, 3)[2] + experts_params.append(( + jax.random.normal(k1, (d_model, d_ff)) * 0.1, + jnp.zeros(d_ff), + jax.random.normal(k2, (d_ff, d_model)) * 0.1, + jnp.zeros(d_model), + )) + +key, subkey = jax.random.split(key) +gate_W = jax.random.normal(subkey, (d_model, n_experts)) * 0.1 +x = jax.random.normal(key, (batch, d_model)) + +output, gate_probs = moe_layer(x, gate_W, experts_params, top_k=2) + +print(f"输入形状:{x.shape}") +print(f"输出形状:{output.shape}") +print(f"门控概率(第一个样本):{gate_probs[0]}") +print(f"专家使用率(批量平均):") +for i in range(n_experts): + usage = jnp.mean(gate_probs[:, i]) + print(f" 专家 {i}: {usage:.3f}") +``` diff --git a/chapter 07: computational linguistics/01. linguistic foundations.md b/chapter 07: computational linguistics/01. linguistic foundations.md new file mode 100644 index 0000000..79d8a06 --- /dev/null +++ b/chapter 07: computational linguistics/01. linguistic foundations.md @@ -0,0 +1,303 @@ +# 语言学基础 + +*语言学为NLP系统提供了它们隐式学习并利用的结构化词汇。本文涵盖形态学、句法学、语义学、语用学、音系学、成分句法和依存句法分析,以及分布假设——这些人类语言科学构成了AI中词元化、语法和意义的基础。* + +- 在构建能够理解或生成语言的系统之前,我们需要理解语言本身是如何运作的。 + +- 语言学是对语言的科学研究,它为NLP提供了不断借用的概念性词汇。 + +- 即使是现代神经模型——它们从原始数据中学习语言——也会隐式地重新发现语言学家们几十年来已经编目的许多结构。 + +- 语言在每一层都具有结构:组成单词的声音、组成单词的部件、将单词组合成句子的规则、这些句子所承载的意义,以及语境如何塑造解读。我们将自下而上地逐层探索。 + +- **形态学**是对单词内部结构的研究。单词并非不可分割的原子;它们由更小的有意义的单元构建而成,这些单元称为**语素**。 + +- 单词"unhappiness"包含三个语素:"un-"(前缀,意为"不")、"happy"(词根)和"-ness"(后缀,将形容词转化为名词)。每个语素都对意义有所贡献。 + +- **词根**(或称词干)是承载主要意义的核心语素。"Happy"、"run"、"compute"都是词根。 + +- **词缀**是附加到词根上以修饰其意义或语法功能的语素。 + +- 英语中有**前缀**(位于词根之前:un-、re-、pre-)和**后缀**(位于词根之后:-ing、-ed、-tion)。一些语言还包含中缀(插入词根内部)和环缀(包裹在词根周围)。 + +![语素树:"unhappiness"分解为前缀"un"、词根"happy"、后缀"ness"](../images/morpheme_tree.svg) + +- 形态学对NLP很重要,因为它影响词元化。一个基于词级的词元化器会将"run"、"runs"、"running"和"ran"视为四个互不相关的符号。 + +- 一个具有形态学意识的系统会识别出它们共享同一个词根。子词词元化(BPE、WordPiece)——我们将在文件02中讨论——是形态学分析的统计近似方法。 + +- **句法学**研究单词如何组合成短语和句子。每种语言都有控制词序和结构的规则;违反这些规则会产生无意义的输出。 + +- "The cat sat on the mat"是合乎语法的英语;"Mat the on sat cat the"则不是。 + +- 描述句法结构主要有两种框架。 + +- **短语结构语法**(也称为成分语法)认为句子是通过将一个短语嵌套在另一个短语内部构建而成的。一个句子(S)由一个名词短语(NP)和一个动词短语(VP)组成。 + +- 一个名词短语可能由一个限定词(Det)后跟一个名词(N)组成。一个动词短语可能由一个动词(V)后跟一个名词短语组成。这些规则构建出一棵树: + +![成分树:"the cat sat on the mat":S分支为NP和VP,NP分支为Det"the"和N"cat",VP分支为V"sat"和PP,PP分支为P"on"和NP](../images/constituency_tree.svg) + +- 这棵树称为**成分树**(或分析树)。每个内部节点是一个短语类型,每个叶子节点是一个单词。这棵树捕捉了层次化分组:"on the mat"是一个单元(介词短语),"sat on the mat"是一个单元(动词短语),而整个结构是一个句子。 + +- **上下文无关文法(CFG)**将这些规则形式化。它由一组产生式规则组成,每条规则的形式为 $A \to \alpha$,其中 $A$ 是一个非终结符(如NP或VP这样的短语类型),$\alpha$ 是一个由终结符(单词)和非终结符组成的序列。例如: + +``` +S → NP VP +NP → Det N +NP → Det N PP +VP → V NP +VP → V PP +PP → P NP +Det → "the" | "a" +N → "cat" | "mat" | "dog" +V → "sat" | "chased" +P → "on" | "under" +``` + +- 从S开始,反复应用规则,你可以生成该文法允许的所有句子。分析则是相反的过程:给定一个句子,找出产生它的树(或所有可能的树)。一个有多个有效分析树的句子称为**句法歧义**。"I saw the man with the telescope"有两种分析:我使用望远镜看到了那个男人,或者我看到了一个拿着望远镜的男人。 + +- **依存语法**采取了一种不同的视角。它不依赖短语嵌套,而是描述单词之间的直接关系。句子中的每个单词都恰好依赖于另一个单词(它的**核心词**),除了句子的根节点。结果是一个**依存树**,其中边标有语法关系标签(主语、宾语、修饰语等)。 + +![依存树:"the cat sat on the mat":从"sat"到"cat"的箭头(nsubj)和到"on"的箭头(prep),从"on"到"mat"的箭头(pobj),从"cat"到"the"的箭头(det),从"mat"到"the"的箭头(det)](../images/dependency_tree.svg) + +- 在依存视角下,"sat"是根节点。"Cat"作为主语(nsubj)依赖于"sat"。"On"作为介词修饰语依赖于"sat"。"Mat"作为介词宾语依赖于"on"。每个单词都挂在恰好一个核心词上,形成一棵树。 + +- 依存语法已成为现代NLP中的主导框架,因为依存树更容易用统计分析器生成,而且这些关系更直接地映射到语义角色(谁对谁做了什么)。 + +- **配价**描述一个动词需要多少个论元。"Sleep"是**不及物动词**(一个论元:睡觉者)。"Eat"是**及物动词**(两个:吃者和被吃之物)。"Give"是**双及物动词**(三个:给予者、给予之物和接受者)。了解动词的配价可以约束哪些分析树是有效的。 + +- **语义学**是对意义的研究。句法学告诉你句子是如何结构的;语义学告诉你句子意味着什么。 + +- **词汇语义学**关注单个单词的意义。单词之间以系统性的方式相互关联: + + - **同义关系**:具有(几乎)相同意义的单词。"Big"和"large"是同义词。真正完美的同义词是罕见的;几乎总是存在含义或用法上的细微差别。 + - **反义关系**:具有相反意义的单词。"Hot"和"cold","buy"和"sell"。 + - **上位关系/下位关系**:"是一种"关系。"Dog"是"animal"的下位词(狗是一种动物)。"Animal"是"dog"的上位词。这些关系形成分类层次结构。 + - **部分整体关系**:"组成部分"关系。"Wheel"是"car"的部分词。 + - **多义关系**:一个单词具有多个相关意义。"Bank"可以指金融机构或河岸。语境可以消除歧义。 + +- **词义消歧(WSD)**是根据上下文确定多义词的哪个义项被使用的任务。在"I deposited money at the bank"中,金融义项是正确的。在"We sat by the river bank"中,地理义项是正确的。WSD是早期NLP中的一个核心问题;现代的上下文嵌入(ELMo、BERT)通过为同一个单词的不同用法生成不同的向量表示,在很大程度上解决了这个问题。 + +- **组合语义学**研究单个单词的意义如何组合以形成短语或句子的意义。**组合性原则**(归功于弗雷格)指出,一个复杂表达式的意义由其组成部分的意义以及组合这些部分的规则共同决定。"The cat chased the dog"与"the dog chased the cat"意义不同,因为句法结构(谁是主语、谁是宾语)与单词意义相互作用。 + +- 并非所有意义都是组合性的。**习语**如"kick the bucket"(意为"去世")具有无法从其组成部分推导出的意义。这对任何组合性方法都是一个挑战。 + +- **分布语义学**是支撑现代NLP的计算性意义研究方法。**分布假设**(Firth, 1957)指出:"观其伴,知其意。"(You shall know a word by the company it keeps.)出现在相似语境中的单词往往具有相似的意义。这是词嵌入(Word2Vec、GloVe)的理论基础,我们将在文件03中深入探讨。 + +- **语用学**研究语境如何影响意义。同一个句子根据说话者、时间、地点和原因的不同,可能意味着不同的事情。 + +- "Can you pass the salt?"在句法上是一个关于能力的疑问句。在语用上,它是一个请求。你不会回答"是的,我能"然后坐着不动。理解这一点需要超越字面意义的知识,具体来说,是关于**言语行为**的惯例知识。 + +- **言语行为理论**(Austin, Searle)区分了: + - **言内行为**:字面内容("Can you pass the salt?") + - **言外行为**:意图实现的功能(一个请求) + - **言后行为**:对听者产生的效果(他们递过盐) + +- **隐涵**(Grice)是指被暗示但未明确陈述的意义。如果有人问"Is John a good cook?"而你回答"He's British",你并没有从字面上回答问题,但听者可以推断(通过文化刻板印象,无论公平与否)你的意思是"不好"。Grice的**合作原则**指出,说话者通常会努力做到信息充分、真实、相关和清晰,而听者假定这些准则成立来进行解读。 + +- **共指**是一种语用现象,其中不同的表达指向同一个实体。在"Alice went to the store. She bought milk"中,"she"指代Alice。解决共指问题对于理解多句文本至关重要,是NLP中的一个关键任务。 + +- **篇章结构**描述句子如何连接以形成连贯的文本。叙事有开头、中间和结尾。论证有主张和证据。**修辞结构理论(RST)**将文本分析为篇章关系(阐述、对比、因果等)的树状结构。 + +- 语用学是NLP中最困难的领域。现代语言模型通过训练数据隐式地处理了大部分句法和语义,但语用推理——理解讽刺、隐涵和依赖语境的意义——仍然是一个前沿挑战。 + +- **音系学**研究语言的声音系统。虽然本章主要关注文本,但简要概述可以衔接音频和语音章节(第09章)。 + +- **音位**是区分意义的最小声音单位。英语约有44个音位。单词"bat"和"pat"相差一个音位(/b/ 与 /p/),而意义的改变是完全性的。这被称为**最小对立体**。 + +- **音位变体**是同一个音位的不同物理实现,不改变意义。"pin"中的"p"(送气音,带一股气流)和"spin"中的"p"(不送气音)在英语中是音位/p/的音位变体;母语者将它们视为同一个声音。 + +- **国际音标(IPA)**为所有语言的音位提供了标准化的记法。单词"cat"转录为/kæt/。IPA是书面文本和语音系统之间的桥梁。 + +- **韵律**涵盖语音的节奏、重音和语调。"I didn't say he stole the money"根据重音落在哪个单词上,可以有七种不同的含义。韵律携带了纯文本所丢失的信息,这就是为什么文本转语音系统必须仔细建模韵律的原因。 + +- 在NLP中,音系学知识出现在文本转语音(字形到音位的转换)、语音识别(将声学信号映射到音位),甚至拼写纠正和音译中。 + +## 编程练习(使用CoLab或notebook) + +1. 构建一个简单的形态分析器,使用常见前缀和后缀列表将英语单词分解为可能的语素。 +```python +prefixes = ['un', 're', 'pre', 'dis', 'mis', 'over', 'under', 'out', 'non'] +suffixes = ['ing', 'ed', 'ly', 'ness', 'ment', 'tion', 'able', 'ible', 'er', 'est', 'ful', 'less', 'ous'] + +def analyse_morphemes(word): + """使用已知词缀进行简单的语素分析。""" + parts = [] + remaining = word.lower() + + # 检查前缀 + for p in sorted(prefixes, key=len, reverse=True): + if remaining.startswith(p) and len(remaining) > len(p) + 2: + parts.append(f"[prefix: {p}]") + remaining = remaining[len(p):] + break + + # 检查后缀 + for s in sorted(suffixes, key=len, reverse=True): + if remaining.endswith(s) and len(remaining) > len(s) + 2: + root = remaining[:-len(s)] + parts.append(f"[root: {root}]") + parts.append(f"[suffix: {s}]") + remaining = None + break + + if remaining is not None: + parts.append(f"[root: {remaining}]") + + return parts + +for word in ['unhappiness', 'reusable', 'disconnected', 'overreacting', 'kindness']: + print(f"{word:20s} → {' + '.join(analyse_morphemes(word))}") +``` + +2. 实现一个使用递归下降法的简单上下文无关文法分析器。定义一个小型文法,并将句子分析为成分树。 +```python +class CFGParser: + """用于小型英语文法的递归下降分析器。""" + def __init__(self, tokens): + self.tokens = tokens + self.pos = 0 + + def peek(self): + return self.tokens[self.pos] if self.pos < len(self.tokens) else None + + def consume(self, expected=None): + tok = self.peek() + if expected and tok != expected: + return None + self.pos += 1 + return tok + + def parse_det(self): + if self.peek() in ('the', 'a'): + return ('Det', self.consume()) + return None + + def parse_noun(self): + if self.peek() in ('cat', 'dog', 'mat', 'man'): + return ('N', self.consume()) + return None + + def parse_verb(self): + if self.peek() in ('sat', 'chased', 'saw'): + return ('V', self.consume()) + return None + + def parse_prep(self): + if self.peek() in ('on', 'under', 'with'): + return ('P', self.consume()) + return None + + def parse_np(self): + save = self.pos + det = self.parse_det() + noun = self.parse_noun() + if det and noun: + # 检查可选的PP + pp = self.parse_pp() + if pp: + return ('NP', det, noun, pp) + return ('NP', det, noun) + self.pos = save + return None + + def parse_pp(self): + save = self.pos + prep = self.parse_prep() + np = self.parse_np() + if prep and np: + return ('PP', prep, np) + self.pos = save + return None + + def parse_vp(self): + save = self.pos + verb = self.parse_verb() + if verb: + np = self.parse_np() + if np: + return ('VP', verb, np) + pp = self.parse_pp() + if pp: + return ('VP', verb, pp) + self.pos = save + return None + + def parse_sentence(self): + np = self.parse_np() + vp = self.parse_vp() + if np and vp and self.pos == len(self.tokens): + return ('S', np, vp) + return None + +def print_tree(tree, indent=0): + if isinstance(tree, str): + print(' ' * indent + tree) + elif isinstance(tree, tuple): + print(' ' * indent + tree[0]) + for child in tree[1:]: + print_tree(child, indent + 2) + +sentences = [ + "the cat sat on the mat", + "a dog chased the cat", +] + +for sent in sentences: + tokens = sent.split() + parser = CFGParser(tokens) + tree = parser.parse_sentence() + print(f"\n'{sent}':") + if tree: + print_tree(tree) + else: + print(" (no parse found)") +``` + +3. 通过构建一个简单的词图来探索词汇关系。给定一个包含同义、反义和上位关系的小型词汇表,查找单词之间的路径。 +```python +relations = { + ('big', 'large'): 'synonym', + ('big', 'small'): 'antonym', + ('small', 'tiny'): 'synonym', + ('dog', 'animal'): 'hypernym', + ('cat', 'animal'): 'hypernym', + ('puppy', 'dog'): 'hypernym', + ('happy', 'glad'): 'synonym', + ('happy', 'sad'): 'antonym', + ('hot', 'cold'): 'antonym', + ('hot', 'warm'): 'synonym', +} + +# 构建邻接列表 +from collections import defaultdict, deque + +graph = defaultdict(list) +for (w1, w2), rel in relations.items(): + graph[w1].append((w2, rel)) + graph[w2].append((w1, rel)) + +def find_path(start, end): + """使用BFS在关系图中查找两个单词之间的路径。""" + queue = deque([(start, [(start, None)])]) + visited = {start} + while queue: + node, path = queue.popleft() + if node == end: + return path + for neighbor, rel in graph[node]: + if neighbor not in visited: + visited.add(neighbor) + queue.append((neighbor, path + [(neighbor, rel)])) + return None + +pairs = [('big', 'tiny'), ('puppy', 'cat'), ('happy', 'sad')] +for w1, w2 in pairs: + path = find_path(w1, w2) + if path: + steps = " → ".join(f"{w}({r})" if r else w for w, r in path) + print(f"{w1} → {w2}: {steps}") + else: + print(f"{w1} → {w2}: no path found") +``` diff --git a/chapter 07: computational linguistics/02. text processing and classic NLP.md b/chapter 07: computational linguistics/02. text processing and classic NLP.md new file mode 100644 index 0000000..28b2834 --- /dev/null +++ b/chapter 07: computational linguistics/02. text processing and classic NLP.md @@ -0,0 +1,340 @@ +# 文本处理与经典NLP + +*文本处理将原始字符转换为模型可消费的结构化表示。本文涵盖分词(词级、子词、BPE、WordPiece)、文本规范化、编辑距离、TF-IDF、n元组语言模型、词性标注、命名实体识别和情感分析——这些经典NLP流水线至今仍是现代系统的基础。* + +- 原始文本是混乱的。在任何NLP模型处理语言之前,文本必须经过清洗、规范化并转换为结构化表示。本文涵盖了从原始字符到模型可消费特征的完整流水线,以及深度学习兴起之前主导领域的经典NLP算法。 + +- **文本规范化**将原始文本转换为规范形式。其目标是减少不相关的变异,使"Hello"、"hello"、"HELLO"和"héllo"得到恰当的处理。 + +- **大小写折叠**将文本转换为小写。这将"The"和"the"合并为一个词元。这对大多数任务有帮助,但在某些情况下会破坏有用信息:"US"(国家)vs "us"(代词),或"Apple"(公司)vs "apple"(水果)。 + +- **Unicode规范化**处理同一字符有多种编码方式的问题。字符"é"可以是单个码点(U+00E9),也可以是基础"e"加上组合变音符号(U+0065 + U+0301)。NFC规范化将它们组合成一个码点;NFD则进行分解。如果没有规范化,两个看起来相同的字符串可能无法匹配。 + +- **编辑距离**衡量两个字符串之间的差异程度。**莱文斯坦距离**计算将一个字符串转换为另一个所需的最少单字符插入、删除和替换次数。"kitten" → "sitting"的编辑距离为3(k→s,e→i,插入g)。 + +- 编辑距离使用动态规划计算(我们在算法章节中回顾)。定义 $D[i][j]$ 为字符串 $s$ 的前 $i$ 个字符与字符串 $t$ 的前 $j$ 个字符之间的距离: + +```math +D[i][j] = \begin{cases} j & \text{if } i = 0 \\ i & \text{if } j = 0 \\ D[i{-}1][j{-}1] & \text{if } s[i] = t[j] \\ 1 + \min(D[i{-}1][j], \; D[i][j{-}1], \; D[i{-}1][j{-}1]) & \text{otherwise} \end{cases} +``` + +- 编辑距离支撑着拼写纠正、模糊匹配和DNA序列比对。在NLP中,它用于处理拼写错误和查找相似单词。 + +- **分词**将文本分割成模型可以处理的离散单元(词元)。这是第一个也是最重要的预处理步骤。分词策略的选择深刻影响着模型行为。 + +- **空白分词**以空格分割。简单但幼稚:"New York"变成两个词元,"don't"是一个词元(或根据分割器不同,拆分为"don"和"'t"),而中文和日文等语言在词之间根本没有空格。 + +- **基于规则的分词**使用手工设计的模式(正则表达式)来处理缩写、标点符号和特殊情况。"I'm" → "I" + "'m","U.S.A."保持为一个词元。每种语言都需要自己的规则,这非常耗费人力。 + +- **子词分词**是现代解决方案。它不是在词边界处分割,而是从数据中学习一个高频子词单元的词汇表。这优雅地处理了未知词:如果"unhappiness"不在词汇表中,它可能被拆分为"un" + "happi" + "ness",保留了形态结构。 + +!["unhappiness"和"transformers"的词级、字符级和子词分词对比](../images/tokenisation_comparison.svg) + +- **字节对编码(BPE)**从单个字符作为词汇表开始。它反复查找最频繁的相邻对并将其合并为一个新词元。经过足够次数的合并后,常见词成为单个词元,罕见词则被拆分为高频子词片段。 + +- BPE算法: + 1. 用训练语料中的所有单个字符初始化词汇表 + 2. 统计每个相邻词元对的频率 + 3. 将最频繁的对合并为一个新词元 + 4. 重复步骤2-3,直到达到所需的合并次数(词汇表大小) + +- 例如,从"l o w"(5次)、"l o w e r"(2次)、"n e w e s t"(6次)开始:最频繁的对可能是"e s" → 合并为"es"。然后"es t" → "est"。然后"n e w" → "new"。最终的词汇表同时包含完整单词和子词片段。 + +- **WordPiece**(BERT使用)与BPE类似,但基于似然而非频率来选择合并。它合并能使训练数据的语言模型似然最大化的对。非词首的子词词元以"##"作为前缀(例如,"playing" → "play" + "##ing")。 + +- **Unigram**(SentencePiece使用)采用相反的方法:从一个大型词汇表开始,迭代地移除那些移除后对训练数据似然损失最小的词元。最终的词汇表是最能解释语料库的子词单元集合。 + +- **SentencePiece**是一个语言无关的分词库,它将输入视为原始字节流(不在空格上进行预分词)。这使得它适用于任何语言,包括没有空格的语言。它同时实现了BPE和Unigram算法。 + +- 词汇表大小是一个关键超参数。典型的选择范围从30,000到100,000个词元。更大的词汇表意味着每个序列的词元更少(更高效),但需要更大的嵌入表。更小的词汇表意味着更多的子词分割和更长的序列。 + +- 两种技术都将词汇简化为基本形式,但方法不同。 + +- **词干提取**使用粗略规则切除后缀。波特词干提取器将"running"简化为"run","happiness"简化为"happi","studies"简化为"studi"。它速度快但不精确:"university"和"universe"都被词干化为"univers",尽管它们毫不相关。 + +- **词形还原**使用词汇表和形态学分析来找到真正的词典形式(词元)。"Running" → "run","better" → "good","mice" → "mouse"。它需要知道词性:"saw"作为动词时词形还原为"see",但作为名词时保持为"saw"。 + +- 现代子词分词在很大程度上已取代了神经NLP中的词干提取和词形还原,但它们在信息检索以及处理较小模型或有限数据时仍然有用。 + +- **词性标注**为每个词分配一个语法类别:名词、动词、形容词、限定词等。这是最古老的NLP任务之一,也是句法分析的基础。 + +- 宾州树库标签集是英语中最常用的,包含36个标签(NN表示单数名词,NNS表示复数名词,VB表示动词原形,VBD表示过去式,JJ表示形容词等)。 + +- 词性标注很棘手,因为许多词是有歧义的。"Book"可以是名词("the book")或动词("book a flight")。"Run"在不同词性下有数十种含义。上下文至关重要。 + +- 早期的标注器使用第05章中的**隐马尔可夫模型(HMM)**。隐藏状态是词性标签,观测值是单词。转移概率捕捉标签序列(限定词后面很可能跟名词或形容词),发射概率捕捉哪些词与哪些标签一起出现。维特比算法找出最可能的标签序列。 + +- 用于词性标注的HMM模型: + +$$\\hat{t}_{1:n} = \\arg\\max_{t_{1:n}} \\prod_{i=1}^{n} P(w_i \\mid t_i) \\cdot P(t_i \\mid t_{i-1})$$ + +- 现代词性标注器使用神经网络(双向LSTM或Transformer),在英语上达到超过97%的准确率,接近人类水平。 + +- **命名实体识别(NER)**识别并分类文本中的专有名词和其他特定实体:人物、组织、地点、日期、货币金额等。 + +- 在"Apple CEO Tim Cook announced the event in Cupertino on Monday"中,NER系统应识别出:Apple(ORG组织)、Tim Cook(PER人物)、Cupertino(LOC地点)、Monday(DATE日期)。 + +- NER通常被框架化为**序列标注**,使用**BIO标注**(也称为IOB标注)。每个词元获得一个标签: + - **B-TYPE**:TYPE类型实体的开始 + - **I-TYPE**:TYPE类型实体的内部(延续) + - **O**:实体外部 + +- "Tim Cook visited New York"变为:Tim/B-PER Cook/I-PER visited/O New/B-LOC York/I-LOC。B标签标记新实体的起始位置,这对于两个同类型实体相邻的情况很重要。 + +![带有BIO标签颜色编码的句子:B-PER(红色)、I-PER(红色)、O(灰色)、B-LOC(蓝色)、I-LOC(蓝色)](../images/bio_tagging.svg) + +- 经典NER使用第05章中的**条件随机场(CRF)**,它对给定输入下整个标签序列的条件概率建模。与生成式模型($P(x, y)$)的HMM不同,CRF是判别式模型,直接建模 $P(y \\mid x)$。线性链CRF定义为: + +$$P(y_{1:n} \\mid x_{1:n}) = \\frac{1}{Z(x)} \\exp\\!\\left(\\sum_{i=1}^{n} \\left[\\sum_k \\lambda_k f_k(y_i, x, i) + \\sum_j \\mu_j g_j(y_i, y_{i-1}, x, i)\\right]\\right)$$ + +- 这里 $f_k$ 是**发射特征**(给定位置 $i$ 的输入,标签 $y_i$ 的可能性),$g_j$ 是**转移特征**(给定前一个标签 $y_{i-1}$,当前标签 $y_i$ 的可能性)。 + +- 配分函数 $Z(x) = \\sum_{y'} \\exp(\\ldots)$ 对所有可能的标签序列求和,以归一化分布。训练最大化条件对数似然,这需要使用前向算法(第05章)高效计算 $Z(x)$。 + +- 与独立分类每个词元相比的关键优势:CRF的转移特征强制结构约束(例如,I-PER应该只跟在B-PER或I-PER之后,绝不应出现在O之后)。 + +- 现代NER将CRF堆叠在神经编码器之上(BiLSTM-CRF或BERT-CRF),其中神经网络产生发射分数,CRF层学习转移结构。 + +- **句法分析**将句子转换为其句法结构,可以是成分树或依存树(两者均见文件01)。 + +- **CYK算法**(Cocke-Younger-Kasami)使用动态规划结合上下文无关文法解析句子。 + +- 它要求文法为**乔姆斯基范式**(每条规则的右侧要么有两个非终结符,要么有一个终结符)。它自底向上填充一个三角表格:单元格表示句子的跨度,每个单元格存储可以生成该跨度的非终结符。 + +- CYK的时间复杂度为 $O(n^3 \\cdot |G|)$,其中 $n$ 是句子长度,$|G|$ 是文法规模。这是精确算法,但对于大型文法来说速度较慢。 + +- **移进-归约解析**从左到右处理句子,维护一个栈。在每一步,它要么**移进**(将下一个词压入栈),要么**归约**(从栈中弹出元素并用短语替换)。一个训练好的分类器在每一步决定操作。时间复杂度为 $O(n)$,比CYK快得多。 + +- **依存解析**在实践中比成分解析更为常见。基于转换的依存解析器(如移进-归约)和基于图的解析器(对所有可能的边评分并找到最大生成树)是两种主要方法。使用BiLSTM或Transformer的神经依存解析器取得了最先进的成果。 + +- 在嵌入出现之前,NLP使用简单的计数方法将文档表示为向量。 + +- **词袋模型(BoW)**将文档表示为词频向量,完全忽略词序。如果词汇表有 $V$ 个词,每个文档就是 $\\mathbb{R}^V$ 空间中的一个向量(与第01章的向量空间相联系)。词 $w$ 对应的条目是 $w$ 在文档中出现的次数。 + +![词袋模型:文档转换为词频表,再转换为R^V空间中的稀疏向量,词汇表中每个词对应一个条目](../images/bag_of_words.svg) + +- BoW简单但出奇有效,适用于文档分类和垃圾邮件过滤等任务。其主要缺点是每个词都被同等对待:"the"和"revolutionary"获得相同的权重。 + +- **TF-IDF**(词频-逆文档频率)通过根据词的信息量大小来加权,解决了这个问题。在单个文档中频繁出现但在整个语料库中罕见的词,很可能对该文档很重要。 + +$$\\text{TF-IDF}(t, d) = \\text{TF}(t, d) \\times \\text{IDF}(t)$$ + +- **词频** $\\text{TF}(t, d)$ 通常是词 $t$ 在文档 $d$ 中的原始计数(或其对数形式:$1 + \\log(\\text{count})$)。 + +- **逆文档频率** $\\text{IDF}(t) = \\log\\frac{N}{|\\{d : t \\in d\\}|}$,其中 $N$ 是文档总数。出现在每个文档中的词(如"the")的IDF接近0。罕见词获得高IDF。 + +- TF-IDF向量可以使用余弦相似度(来自第01章)进行比较,以衡量文档相似性。这是经典信息检索和搜索引擎的基础。 + +- **语言模型**为词序列分配概率。它回答的是:这个句子的可能性有多大?语言模型是机器翻译、语音识别、拼写纠正和文本生成的核心。 + +- 句子 $w_1, w_2, \\ldots, w_n$ 的概率,根据概率的链式法则(第05章)为: + +$$P(w_1, w_2, \\ldots, w_n) = \\prod_{i=1}^{n} P(w_i \\mid w_1, \\ldots, w_{i-1})$$ + +- 这是精确的但不实用:你需要为每个可能的历史存储概率。**马尔可夫假设**(第05章)将历史截断到最近 $k-1$ 个词,得到 **n元语法模型**(其中 $n = k$)。 + +- **二元模型**($n = 2$)仅依赖前一个词: + +$$P(w_i \\mid w_1, \\ldots, w_{i-1}) \\approx P(w_i \\mid w_{i-1})$$ + +- **三元模型**($n = 3$)依赖前两个词。n元语法概率通过在语料库中计数来估计: + +$$P(w_i \\mid w_{i-1}) = \\frac{\\text{count}(w_{i-1}, w_i)}{\\text{count}(w_{i-1})}$$ + +- **困惑度**衡量语言模型对测试集的预测能力。它是测试集概率的倒数,按词数归一化: + +$$\\text{PPL} = P(w_1, \\ldots, w_N)^{-1/N} = \\exp\\!\\left(-\\frac{1}{N} \\sum_{i=1}^{N} \\log P(w_i \\mid w_{ 0\\}|}{|\\{(w', w'') : \\text{count}(w', w'') > 0\\}|}$$ + +- 分子统计在语料库中出现在 $w_i$ 之前的不同词的数量。像"Francisco"这样的词出现在很少的上下文中(几乎总是在"San"之后),所以即使"San Francisco"非常频繁,"Francisco"的延续概率也很低,不会在其他上下文中被错误预测。 + +- 相反,像"the"这样的常见词出现在许多不同词之后,获得高延续概率。这体现了这样一种直觉:对于回退估计而言,词的多功能性比其原始频率更重要。 + +- n元语法模型几十年来一直是主流技术。它们速度快、可解释性强,且无需训练(只需计数)。但它们难以处理长距离依赖("The keys that I left on the table **are** missing"需要知道主语"keys"是复数,而它与动词相距甚远)。神经语言模型——从RNN开始到Transformer达到顶峰——解决了这一局限性。 + +## 编程练习(使用CoLab或notebook) + +1. 使用动态规划实现莱文斯坦编辑距离。在词对上测试,并用于简单的拼写纠正。 +```python +import jax.numpy as jnp + +def edit_distance(s, t): + """Compute Levenshtein edit distance using DP.""" + m, n = len(s), len(t) + D = [[0] * (n + 1) for _ in range(m + 1)] + + for i in range(m + 1): + D[i][0] = i + for j in range(n + 1): + D[0][j] = j + + for i in range(1, m + 1): + for j in range(1, n + 1): + if s[i-1] == t[j-1]: + D[i][j] = D[i-1][j-1] + else: + D[i][j] = 1 + min(D[i-1][j], D[i][j-1], D[i-1][j-1]) + + return D[m][n] + +# Test +pairs = [("kitten", "sitting"), ("sunday", "saturday"), ("hello", "hallo")] +for s, t in pairs: + print(f"d('{s}', '{t}') = {edit_distance(s, t)}") + +# Simple spelling correction +dictionary = ["the", "their", "there", "then", "than", "this", "that", "these", "those"] +misspelled = "thier" +corrections = sorted(dictionary, key=lambda w: edit_distance(misspelled, w)) +print(f"\nClosest to '{misspelled}': {corrections[:3]}") +``` + +2. 从头实现BPE分词。从字符级词元开始,迭代地合并最频繁的对。 +```python +from collections import Counter + +def get_pairs(corpus): + """Count adjacent token pairs across all words.""" + pairs = Counter() + for word, freq in corpus.items(): + symbols = word.split() + for i in range(len(symbols) - 1): + pairs[(symbols[i], symbols[i+1])] += freq + return pairs + +def merge_pair(pair, corpus): + """Merge all occurrences of a pair in the corpus.""" + new_corpus = {} + bigram = ' '.join(pair) + replacement = ''.join(pair) + for word, freq in corpus.items(): + new_word = word.replace(bigram, replacement) + new_corpus[new_word] = freq + return new_corpus + +# Training corpus with word frequencies +text = "low low low low low lower lower newest newest newest newest newest newest" +word_freqs = Counter(text.split()) +# Initialise: split each word into characters with end-of-word marker +corpus = {' '.join(word) + ' _': freq for word, freq in word_freqs.items()} + +print("Initial corpus:") +for word, freq in corpus.items(): + print(f" {word}: {freq}") + +# Run BPE for 10 merges +for i in range(10): + pairs = get_pairs(corpus) + if not pairs: + break + best_pair = max(pairs, key=pairs.get) + corpus = merge_pair(best_pair, corpus) + print(f"\nMerge {i+1}: {best_pair} (freq={pairs[best_pair]})") + for word, freq in corpus.items(): + print(f" {word}: {freq}") +``` + +3. 构建一个二元语言模型,并计算测试句子的困惑度。尝试拉普拉斯平滑。 +```python +from collections import Counter, defaultdict +import math + +# Training corpus +train = """the cat sat on the mat . the dog chased the cat . +the cat ran from the dog . a dog sat on a mat .""".split() + +# Count bigrams and unigrams +bigrams = Counter(zip(train[:-1], train[1:])) +unigrams = Counter(train) +vocab_size = len(set(train)) + +def bigram_prob(w2, w1, alpha=0): + """P(w2 | w1) with optional Laplace smoothing.""" + return (bigrams[(w1, w2)] + alpha) / (unigrams[w1] + alpha * vocab_size) + +# Compute perplexity +test = "the cat sat on a mat .".split() + +for alpha in [0, 1, 0.1]: + log_prob = 0 + for w1, w2 in zip(test[:-1], test[1:]): + p = bigram_prob(w2, w1, alpha=alpha) + if p > 0: + log_prob += math.log(p) + else: + log_prob += float('-inf') + + ppl = math.exp(-log_prob / (len(test) - 1)) if log_prob > float('-inf') else float('inf') + print(f"Smoothing α={alpha}: perplexity = {ppl:.2f}") +``` + +4. 从头实现TF-IDF,并使用余弦相似度找到与查询最相似的文档。 +```python +import jax.numpy as jnp +import math +from collections import Counter + +documents = [ + "the cat sat on the mat", + "the dog chased the cat around the park", + "a mat was placed on the floor by the door", + "the quick brown fox jumped over the lazy dog", +] + +# Build vocabulary +vocab = sorted(set(word for doc in documents for word in doc.split())) +word_to_idx = {w: i for i, w in enumerate(vocab)} +V = len(vocab) +N = len(documents) + +# Compute TF-IDF matrix +doc_freq = Counter() +for doc in documents: + for word in set(doc.split()): + doc_freq[word] += 1 + +tfidf_matrix = jnp.zeros((N, V)) +for i, doc in enumerate(documents): + word_counts = Counter(doc.split()) + for word, count in word_counts.items(): + tf = 1 + math.log(count) + idf = math.log(N / doc_freq[word]) + j = word_to_idx[word] + tfidf_matrix = tfidf_matrix.at[i, j].set(tf * idf) + +# Query +query = "cat on the mat" +query_vec = jnp.zeros(V) +query_counts = Counter(query.split()) +for word, count in query_counts.items(): + if word in word_to_idx: + tf = 1 + math.log(count) + idf = math.log(N / doc_freq.get(word, 1)) + query_vec = query_vec.at[word_to_idx[word]].set(tf * idf) + +# Cosine similarity (from chapter 01) +def cosine_sim(a, b): + return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + 1e-8) + +print(f"Query: '{query}'\n") +for i, doc in enumerate(documents): + sim = cosine_sim(query_vec, tfidf_matrix[i]) + print(f" Doc {i} (sim={sim:.3f}): '{doc}'") +``` diff --git a/chapter 07: computational linguistics/03. embeddings and sequence models.md b/chapter 07: computational linguistics/03. embeddings and sequence models.md new file mode 100644 index 0000000..7495fd8 --- /dev/null +++ b/chapter 07: computational linguistics/03. embeddings and sequence models.md @@ -0,0 +1,389 @@ +# 嵌入与序列模型 + +*词嵌入将稀疏的符号化文本压缩到稠密向量空间中,使得语义相似性转化为几何邻近性。本文涵盖 Word2Vec(CBOW、Skip-gram)、GloVe、FastText、RNN、LSTM、GRU、带注意力机制的 seq2seq、编码器-解码器范式,以及从词袋模型到上下文表示的发展历程。* + +- 在文件 01 中,我们介绍了分布假设:出现在相似语境中的词往往具有相似的含义。在文件 02 中,我们使用稀疏的、手工设计的特征(如 TF-IDF 向量)来表示文本。这些向量位于极高维空间中(每个词汇表词占一维),且大部分为零。**词嵌入**将这些信息压缩到稠密的低维向量中,捕捉语义关系,并且直接从数据中学习。 + +- **Word2Vec**(Mikolov et al., 2013)通过在简单的预测任务上训练一个浅层神经网络来学习词嵌入。共有两种架构。 + +- **连续词袋模型(CBOW)**根据目标词周围的上下文词来预测该词。给定一个窗口大小的上下文词(例如,"the cat ___ on the"),模型求它们的嵌入向量的平均值,并将结果通过一个线性层来预测缺失的词("sat")。训练目标最大化: + +$$P(w_t \mid w_{t-k}, \ldots, w_{t-1}, w_{t+1}, \ldots, w_{t+k})$$ + +- **Skip-gram 模型**则反过来:给定一个目标词,预测其周围的上下文词。对于目标词 "sat",模型分别尝试预测 "the"、"cat"、"on"、"the"。目标最大化: + +$$P(w_{t+j} \mid w_t) \quad \text{对于每个 } j \in [-k, k], \; j \neq 0$$ + +![Skip-gram 与 CBOW 架构对比:CBOW 对上下文嵌入求平均来预测中心词,skip-gram 使用中心词嵌入来预测每个上下文词](../images/word2vec_architectures.svg) + +- Skip-gram 通常对罕见词效果更好,因为每个词会产生多个训练样本(每个上下文位置一个)。CBOW 速度更快,对频繁词略优,因为它对多个上下文信号取平均。 + +- 在整个词汇表上训练代价很高,因为 softmax 分母需要对所有 $V$ 个词求和。**负采样**通过将问题转化为二分类来近似这一过程:区分真实的上下文词(正样本)与随机采样的噪声词(负样本)。模型无需计算完整的 softmax,只需更新目标词、真实上下文词以及少数负样本的嵌入: + +$$\mathcal{L} = \log \sigma(v_{w_O}^T v_{w_I}) + \sum_{i=1}^{k} \mathbb{E}_{w_i \sim P_n} [\log \sigma(-v_{w_i}^T v_{w_I})]$$ + +- 这里 $v_{w_I}$ 是输入词嵌入,$v_{w_O}$ 是输出(上下文)词嵌入,$P_n$ 是噪声分布,通常采用词频的 3/4 次方(这会降低"the"这类高频词的权重)。 + +- 为什么这个简单的目标函数能产生有意义的嵌入?Levy 和 Goldberg(2014)证明,带负采样的 skip-gram 实际上是在分解一个**移位点互信息(PMI)**矩阵。在收敛时,两个词向量的点积近似于: + +$$v_w^T v_c \approx \text{PMI}(w, c) - \log k$$ + +- 其中 $\text{PMI}(w, c) = \log \frac{P(w, c)}{P(w) P(c)}$ 衡量词 $w$ 和 $c$ 共现的频率比随机期望高出多少(见第 05 章信息论),$k$ 是负样本数量。共现远高于随机期望的词具有高 PMI,从而具有高点积(相似的嵌入)。共现低于预期的词具有负 PMI 和不相似的嵌入。这表明 Word2Vec 实际上与经典的分布语义学方法(如潜在语义分析,即对共现矩阵做 SVD)在做同样的事情,只是采用了更具扩展性的在线方式。 + +- Word2Vec 嵌入最令人惊讶的特性是它们能通过**向量算术**捕捉**类比关系**。向量 $v_{\text{king}} - v_{\text{man}} + v_{\text{woman}}$ 最接近 $v_{\text{queen}}$。这是因为嵌入空间将语义关系编码为近似线性方向:"王室"方向大致为 $v_{\text{king}} - v_{\text{man}}$,将其加到 $v_{\text{woman}}$ 上就会落在 $v_{\text{queen}}$ 附近。这与第 01 章的线性代数相关联:语义关系就是向量平移。 + +- **GloVe**(Global Vectors for Word Representation,Pennington et al., 2014)采用不同的方法。它不是一次一个地从局部上下文窗口学习,而是构建一个全局的词共现矩阵 $X$,其中 $X_{ij}$ 统计在整个语料库中词 $j$ 出现在词 $i$ 上下文中的次数。然后模型学习嵌入,使其点积近似于对数共现次数: + +$$w_i^T \tilde{w}_j + b_i + \tilde{b}_j = \log X_{ij}$$ + +- 损失函数通过一个截断函数 $f(X_{ij})$ 对每一对加权,防止非常频繁的共现主导训练: + +$$\mathcal{L} = \sum_{i,j=1}^{V} f(X_{ij}) \left(w_i^T \tilde{w}_j + b_i + \tilde{b}_j - \log X_{ij}\right)^2$$ + +- GloVe 结合了全局矩阵分解(如潜在语义分析)和 Word2Vec 的局部上下文学习的优点。在实践中,GloVe 和 Word2Vec 生成的嵌入质量相近。 + +- **FastText**(Bojanowski et al., 2017)扩展了 skip-gram,将每个词表示为一组字符 n-gram 的集合。对于 $n = 3$,词 "where" 变成:"\",加上完整词标记 "\"。该词的嵌入是其所有 n-gram 嵌入之和。 + +- 这有一个关键优势:FastText 能够为训练中从未见过的词生成嵌入。词 "whereabouts" 与 "where" 共享 n-gram,因此即使 "whereabouts" 从未出现在训练数据中,其嵌入也是合理的。这对于形态丰富的语言(文件 01)尤为有用,因为这些语言中的词有许多屈折形式。 + +- **嵌入评估**通常使用两类基准测试。**类比任务**测试 $v_a - v_b + v_c \approx v_d$ 是否成立(例如,"Paris" $-$ "France" $+$ "Italy" $\approx$ "Rome")。**相似性基准**将词对之间的余弦相似度(第 01 章)与人工判断进行比较。常见的数据集包括 WordSim-353、SimLex-999 和 Google 类比测试集。一个实用注意事项:在类比任务上表现出色的嵌入不一定最适合下游任务,如情感分类。最好的评估往往是任务本身。 + +- 在第 06 章中,我们介绍了 RNN、LSTM 和 GRU 作为处理序列数据的架构。这里我们重点讨论它们如何具体应用于语言任务。 + +- **语言模型 RNN** 每次读取一个词元,并在每一步预测下一个词元。隐藏状态 $h_t$ 将整个历史序列 $w_1, \ldots, w_t$ 压缩为一个固定大小的向量,线性层加 softmax 将 $h_t$ 映射到词汇表上的分布。训练使用与真实下一词元的交叉熵损失,这等价于最小化困惑度(文件 02)。关键局限在于:固定大小的隐藏状态必须编码关于历史的所有信息,早期词元的信息会逐渐被覆盖。 + +- **双向 RNN** 从两个方向处理序列:一个 RNN 从左到右读取,另一个从右到左读取。在每个位置 $t$,前向隐藏状态 $\overrightarrow{h}_t$ 和后向隐藏状态 $\overleftarrow{h}_t$ 被拼接起来,形成上下文感知的表示 $h_t = [\overrightarrow{h}_t ; \overleftarrow{h}_t]$。这使模型能够同时访问过去和未来的上下文,对于词性标注和命名实体识别(文件 02)等任务非常有效,因为这些任务中一个词的标签依赖于其前后的词。双向 RNN 不能用于语言建模,因为在预测未来词元时不能窥视它们。 + +![双向 RNN:前向 RNN 从左到右读取产生隐藏状态,后向 RNN 从右到左读取,每个位置的输出拼接在一起](../images/bidirectional_rnn.svg) + +- **深层堆叠 RNN** 将多个 RNN 层叠放在一起。第 $l$ 层所有时间步的隐藏状态成为第 $l+1$ 层的输入序列。堆叠 2-4 层通常能通过构建层次化表示来提升性能,类似于深层 CNN 构建特征层次结构(第 06 章)。超过 4 层时,梯度消失和过拟合会成为问题,除非在层之间添加残差连接。 + +- **序列到序列(seq2seq)**架构(Sutskever et al., 2014)将可变长度的输入序列映射到可变长度的输出序列。它由一个**编码器** RNN(读取输入并将其压缩为上下文向量,即最终的隐藏状态)和一个**解码器** RNN(基于该上下文向量逐步生成输出)组成。 + +![Seq2seq 编码器-解码器:编码器 RNN 从左到右读取输入词元,最终隐藏状态作为解码器 RNN 的初始状态,解码器自回归地生成输出词元](../images/seq2seq_architecture.svg) + +- Seq2seq 是机器翻译的突破性架构。编码器读取法语句子,解码器生成英文翻译。解码器从一个特殊的序列起始词元开始,自回归地生成词元,直到产生序列结束词元。一个实用的技巧:反转输入序列(输入 "chat le" 而不是 "le chat")可以改善结果,因为这使得第一个输入词在计算图中更靠近第一个输出词,缩短了梯度路径。 + +- 瓶颈问题:整个输入必须被压缩到一个固定大小的向量中。对于长句子,这个向量无法捕捉所有信息,性能会下降。这推动了**注意力机制**的发展。 + +- 第 06 章介绍了现代的点积注意力 Q、K、V 形式。NLP 中最早的注意力机制以不同的方式提出,作为编码器和解码器状态之间的对齐模型。 + +- **Bahdanau 注意力**(加性注意力,Bahdanau et al., 2015)使用一个可学习的前馈网络计算解码器隐藏状态 $s_t$ 与每个编码器隐藏状态 $h_i$ 之间的对齐分数: + +$$e_{ti} = v^T \tanh(W_s s_{t-1} + W_h h_i)$$ + +- 分数通过 softmax 归一化为注意力权重,上下文向量是编码器状态的加权和: + +$$\alpha_{ti} = \frac{\exp(e_{ti})}{\sum_j \exp(e_{tj})}, \quad c_t = \sum_i \alpha_{ti} h_i$$ + +- 然后解码器同时使用 $s_{t-1}$ 和 $c_t$ 来生成下一个输出。关键洞察:不是为整个句子使用一个固定的上下文向量,每个解码步骤获得编码器状态的不同加权组合,使模型能够"回顾"输入的相关部分。 + +- **Luong 注意力**(乘性注意力,Luong et al., 2015)简化了分数计算。**点积**变体使用 $e_{ti} = s_t^T h_i$。**通用**变体使用 $e_{ti} = s_t^T W h_i$。这些比 Bahdanau 的加性分数更快,因为它们使用矩阵乘法而非前馈网络。Luong 注意力还从当前解码器状态 $s_t$(而非 $s_{t-1}$)计算上下文向量,这使得它能获取更多信息,但计算方式略有不同。 + +![源句子与其翻译之间的注意力对齐热力图,显示每个目标词关注哪些源词,较亮的单元格表示更高的注意力权重](../images/attention_alignment.svg) + +- 注意力权重通常可视化为热力图,显示解码器在生成每个输出词元时关注哪些输入词元。在翻译中,这些热力图大致勾勒出源语言和目标语言之间的词对齐关系,对角模式会被重排序打破(例如,形容词-名词顺序在法语和英语中有所不同)。 + +- 推理时,解码器每一步必须选择一个词元。**贪心解码**在每个位置选择概率最高的词元,但这可能导致次优序列:一个局部好的选择可能迫使模型进入全局不佳的句子。**束搜索**在每一步维护分数最高的 $k$ 个(束宽)部分序列,对每个序列扩展所有可能的下一词元,并保留总体最好的 $k$ 个。 + +- 当束宽 $k = 1$ 时,束搜索退化为贪心解码。典型值为 $k = 4$ 到 $k = 10$。更大的束能找到更好的序列,但速度会成比例降低。束搜索还需要**长度归一化**,以避免偏向较短的序列(因为较短的序列乘法项更少,自然具有更高的总概率)。归一化后的分数为: + +$$\text{score}(y) = \frac{1}{|y|^\alpha} \sum_{t=1}^{|y|} \log P(y_t \mid y_{ [4, 1, 3]) +vocab_size = 10 # 数字 0-9 +SOS, EOS = 10, 11 # 特殊词元 +total_vocab = 12 +embed_dim, hidden_dim = 16, 32 +max_len = 5 + +key = jax.random.PRNGKey(42) +keys = jax.random.split(key, 8) + +params = { + 'embed': jax.random.normal(keys[0], (total_vocab, embed_dim)) * 0.1, + 'enc_Wx': jax.random.normal(keys[1], (embed_dim, hidden_dim)) * 0.1, + 'enc_Wh': jax.random.normal(keys[2], (hidden_dim, hidden_dim)) * 0.05, + 'dec_Wx': jax.random.normal(keys[3], (embed_dim, hidden_dim)) * 0.1, + 'dec_Wh': jax.random.normal(keys[4], (hidden_dim, hidden_dim)) * 0.05, + # Bahdanau 注意力 + 'Ws': jax.random.normal(keys[5], (hidden_dim, hidden_dim)) * 0.1, + 'Wh_att': jax.random.normal(keys[6], (hidden_dim, hidden_dim)) * 0.1, + 'v_att': jax.random.normal(keys[7], (hidden_dim,)) * 0.1, + # 输出投影(从隐藏状态+上下文到词汇表) + 'Wo': jax.random.normal(keys[0], (hidden_dim * 2, total_vocab)) * 0.1, +} + +def encode(params, seq): + """编码输入序列,返回所有隐藏状态。""" + h = jnp.zeros(hidden_dim) + states = [] + for t in range(len(seq)): + x = params['embed'][seq[t]] + h = jnp.tanh(x @ params['enc_Wx'] + h @ params['enc_Wh']) + states.append(h) + return jnp.stack(states), h + +def bahdanau_attention(params, dec_state, enc_states): + """计算 Bahdanau 注意力权重和上下文向量。""" + scores = jnp.tanh(enc_states @ params['Wh_att'] + dec_state @ params['Ws']) + e = scores @ params['v_att'] # (src_len,) + alpha = jax.nn.softmax(e) + context = alpha @ enc_states + return context, alpha + +def decode_step(params, dec_h, prev_token, enc_states): + x = params['embed'][prev_token] + dec_h = jnp.tanh(x @ params['dec_Wx'] + dec_h @ params['dec_Wh']) + context, alpha = bahdanau_attention(params, dec_h, enc_states) + combined = jnp.concatenate([dec_h, context]) + logits = combined @ params['Wo'] + return dec_h, logits, alpha + +def seq2seq_loss(params, src, tgt): + enc_states, enc_final = encode(params, src) + dec_h = enc_final + loss = 0.0 + prev_token = SOS + for t in range(len(tgt)): + dec_h, logits, _ = decode_step(params, dec_h, prev_token, enc_states) + log_probs = jax.nn.log_softmax(logits) + loss -= log_probs[tgt[t]] + prev_token = tgt[t] + return loss / len(tgt) + +# 生成训练数据:反转序列 +key = jax.random.PRNGKey(0) +train_srcs, train_tgts = [], [] +for _ in range(200): + key, subkey = jax.random.split(key) + length = jax.random.randint(subkey, (), 3, max_len + 1) + key, subkey = jax.random.split(key) + seq = jax.random.randint(subkey, (int(length),), 0, vocab_size) + train_srcs.append(seq) + train_tgts.append(seq[::-1]) # 反转 + +# 训练 +grad_fn = jax.grad(seq2seq_loss) +lr = 0.01 + +for epoch in range(100): + total_loss = 0.0 + for src, tgt in zip(train_srcs, train_tgts): + grads = grad_fn(params, src, tgt) + params = {k: params[k] - lr * grads[k] for k in params} + total_loss += seq2seq_loss(params, src, tgt) + if (epoch + 1) % 20 == 0: + print(f"Epoch {epoch+1}: avg loss = {total_loss / len(train_srcs):.4f}") + +# 可视化一个示例的注意力 +test_src = jnp.array([3, 1, 4, 1, 5]) +test_tgt = test_src[::-1] + +enc_states, enc_final = encode(params, test_src) +dec_h = enc_final +attentions = [] +prev_token = SOS +for t in range(len(test_tgt)): + dec_h, logits, alpha = decode_step(params, dec_h, prev_token, enc_states) + attentions.append(alpha) + prev_token = test_tgt[t] + +att_matrix = jnp.stack(attentions) +fig, ax = plt.subplots(figsize=(6, 5)) +im = ax.imshow(att_matrix, cmap='Blues') +ax.set_xlabel("源位置"); ax.set_ylabel("目标位置") +src_labels = [str(int(x)) for x in test_src] +tgt_labels = [str(int(x)) for x in test_tgt] +ax.set_xticks(range(len(src_labels))); ax.set_xticklabels(src_labels) +ax.set_yticks(range(len(tgt_labels))); ax.set_yticklabels(tgt_labels) +for i in range(len(tgt_labels)): + for j in range(len(src_labels)): + ax.text(j, i, f"{att_matrix[i,j]:.2f}", ha='center', va='center', fontsize=9) +ax.set_title("Bahdanau 注意力对齐(序列反转)") +plt.colorbar(im); plt.tight_layout(); plt.show() +``` diff --git a/chapter 07: computational linguistics/04. transformers and language models.md b/chapter 07: computational linguistics/04. transformers and language models.md new file mode 100644 index 0000000..5fc899f --- /dev/null +++ b/chapter 07: computational linguistics/04. transformers and language models.md @@ -0,0 +1,517 @@ +# Transformer与语言模型 + +*Transformer用自注意力取代了循环结构,成为语言理解和生成的主导架构。本文件涵盖BERT、GPT、T5、位置编码(正弦编码、RoPE)、预训练目标(MLM、CLM)、微调、提示工程和缩放定律——这些是现代大语言模型背后的蓝图。* + +- 在第06章中,我们介绍了Transformer架构:自注意力、多头注意力、位置编码以及编码器-解码器结构。这里我们聚焦于Transformer如何适配特定的NLP范式、定义现代NLP的模型(BERT、GPT、T5),以及让它们在大规模下切实可行的技术。 + +- 回顾核心操作:**缩放点积注意力**计算 $\text{softmax}(QK^T / \sqrt{d_k}) V$,其中查询、键和值都是输入的线性投影。**多头注意力**并行运行 $h$ 个注意力头,每个头使用不同的学习投影,然后将结果拼接起来。Transformer块通过残差连接、层归一化和逐位置前馈网络(第06章)将这一切包裹起来。 + +- 一个微妙但重要的架构选择是**层归一化**的放置位置。原始Transformer使用**后归一化**:残差和归一化在子层之后执行,即 $\text{LayerNorm}(x + \text{Sublayer}(x))$。 + +- 大多数现代模型使用**前归一化**:在子层之前进行归一化,即 $x + \text{Sublayer}(\text{LayerNorm}(x))$。前归一化在训练过程中更加稳定,因为残差连接直接将梯度通过恒等路径传递,不受归一化的影响。这使得训练非常深的模型变得更容易,无需仔细的学习率预热。 + +- 每个Transformer块中的**前馈子层**是一个两层MLP,独立应用于每个标记位置: + +$$\text{FFN}(x) = W_2 \cdot \text{GELU}(W_1 x + b_1) + b_2$$ + +- 内部维度通常是模型维度的4倍(例如,$d_{\text{model}} = 768$,$d_{\text{ff}} = 3072$)。这个FFN约占每个块中参数的三分之二,被认为起到键-值记忆的作用,存储训练过程中学到的事实知识。 + +- **位置编码**为模型提供标记顺序的信息,因为注意力本身是置换等变的。原始的**正弦编码**(第06章)使用不同频率的固定正弦和余弦函数。**可学习位置嵌入**则简单地为每个位置添加一个可训练向量(用于BERT和GPT-2)。两者都是绝对编码:无论上下文如何,位置5总是得到相同的向量。 + +- **旋转位置编码(RoPE)**通过在二维子空间中旋转查询和键向量来编码位置。对于一对维度 $(q_{2i}, q_{2i+1})$,按角度 $m\theta_i$ 的旋转(其中 $m$ 是位置,$\theta_i = 10000^{-2i/d}$)应用如下: + +```math +\begin{bmatrix} q'_{2i} \\ q'_{2i+1} \end{bmatrix} = \begin{bmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{bmatrix} \begin{bmatrix} q_{2i} \\ q_{2i+1} \end{bmatrix} +``` + +![RoPE:每个位置在二维子空间中以不同角度旋转查询和键向量,使注意力分数仅依赖于相对位置](../images/rope_rotation.svg) + +- RoPE的精妙之处在于,旋转后的查询和键之间的点积 $q'^T k'$ 仅依赖于相对位置 $m - n$,而非绝对位置。 + +- 为了理解原因,将旋转写为 $q' = R_m q$ 和 $k' = R_n k$,其中 $R_m$ 是一个块对角旋转矩阵。注意力分数变为: + +$$q'^T k' = (R_m q)^T (R_n k) = q^T R_m^T R_n \, k = q^T R_{n-m} \, k$$ + +- 最后一步利用了旋转群性质:$R_m^T R_n = R_{n-m}$(先向后旋转 $m$ 再向前旋转 $n$,等价于旋转 $n - m$)。 + +- 这意味着注意力分数仅依赖于相对距离 $n - m$,而非绝对位置 $m$ 和 $n$ 本身。 + +- 模型无需任何学习的位置参数就能获得自然的距离概念,并且可以泛化到训练时未见过的序列长度。 + +- **ALiBi**(带线性偏置的注意力)采用了一种更简单的方法:它根据距离向注意力分数添加一个固定的线性惩罚,即 $\text{score}_{ij} = q_i^T k_j - m \cdot |i - j|$,其中 $m$ 是每个头特定的斜率。不同的头使用不同的斜率,使一些头可以关注局部信息,另一些头关注全局信息。ALiBi不需要任何可学习的位置参数,并且能够很好地泛化到比训练时更长的序列。 + +- 基于Transformer的语言模型的三种主导范式是**仅编码器**、**仅解码器**和**编码器-解码器**。它们在模型能看到的范围(注意力掩码)以及训练方式上有所不同。 + +![三种Transformer范式:仅编码器(BERT)使用双向注意力进行分类,仅解码器(GPT)使用因果注意力进行生成,编码器-解码器(T5)结合两者用于序列到序列任务](../images/transformer_paradigms.svg) + +- **BERT**(来自Transformer的双向编码器表示,Devlin等人,2019)是典型的仅编码器模型。它使用完全的双向注意力处理文本:每个标记可以关注所有其他标记,包括左右两侧。这赋予了BERT丰富的上下文表示,但意味着它不能自回归地生成文本。 + +- BERT通过两个目标进行预训练。**掩码语言建模(MLM)**随机遮蔽15%的输入标记,并训练模型去预测它们。在被选中的标记中,80%被替换为[MASK]标记,10%被替换为随机词,10%保持不变(以防止模型只学会在看到[MASK]时才进行预测)。训练目标如下: + +$$\mathcal{L}_{\text{MLM}} = -\sum_{i \in \mathcal{M}} \log P(w_i \mid w_{\backslash \mathcal{M}})$$ + +- 其中 $\mathcal{M}$ 是被遮蔽的位置集合,$w_{\backslash \mathcal{M}}$ 是这些位置被遮蔽后的句子。这是一个**去噪**目标:模型学习重建被破坏的输入。 + +![BERT掩码语言建模:15%的输入标记被遮蔽,双向Transformer在遮蔽位置预测原始标记](../images/bert_mlm.svg) + +- **下一句预测(NSP)**训练BERT预测两个句子在原始文本中是否连续。输入开头的特殊[CLS]标记用于此二分类。NSP的加入是为了帮助理解句子关系的任务(如问答),不过后来的工作(RoBERTa)表明其贡献很小,可以去掉。 + +- BERT的预训练表示通过在其顶部添加特定任务的头部(一个简单的线性层)并微调整个模型来适应下游任务。对于分类任务,使用[CLS]标记的表示。对于标记级任务(命名实体识别、词性标注),使用每个标记的表示。这种**微调**方法将预训练期间学到的语言知识迁移到新任务上,只需相对较少的标注数据。 + +- **GPT**(生成式预训练Transformer,Radford等人,2018)是典型的仅解码器模型。它使用**因果(自回归)注意力**:每个标记只能关注更早位置的标记(以及自身)。这是通过在注意力矩阵中遮蔽未来位置(将其分数设置为 $-\infty$,然后再进行softmax)来实现的。训练目标是简单的**因果语言建模**:根据所有之前的标记预测下一个标记。 + +$$\mathcal{L}_{\text{CLM}} = -\sum_{i=1}^{n} \log P(w_i \mid w_1, \ldots, w_{i-1})$$ + +- 这与文件02中的n-gram语言模型目标相同,但采用了Transformer参数化方式,可以基于整个前文进行条件建模,而不仅仅是最后 $k-1$ 个标记。 + +- **GPT-2**将其规模扩大到15亿参数,并展现了强大的零样本能力:无需任何微调,它就能通过自然语言提示("将英语翻译成法语:……")来执行任务。 + +- **GPT-3**(1750亿参数)表明,仅凭规模就能实现**上下文学习**:通过在提示中提供几个输入-输出示例,模型无需任何梯度更新就能执行新任务。 + +- **编码器-解码器模型**如**T5**(文本到文本迁移Transformer,Raffel等人,2020)将每个NLP任务都视为文本到文本:输入是一个文本字符串(可能带有任务前缀,如"将英语翻译成德语:"),输出也是一个文本字符串。编码器使用双向注意力处理输入,解码器则通过交叉注意力自回归地生成输出。 + +- T5通过**跨度破坏**进行预训练:随机连续标记跨度被替换为哨兵标记,模型需要生成原始标记。例如,"The cat sat on the mat"可能变成输入"The [X] on [Y]",目标输出是"[X] cat sat [Y] the mat"。这是BERT的MLM从单个标记向跨度的泛化。 + +- **BART**(Lewis等人,2020)是另一种编码器-解码器模型,通过去噪目标进行预训练,但它应用了更广泛的破坏策略:标记遮蔽、标记删除、跨度遮蔽、句子置换和文档旋转。多样化的破坏方式迫使模型学习更鲁棒的表示。 + +- 随着语言模型变得越来越大,**全量微调**(更新所有参数)变得不切实际:一个175B参数的模型仅存储优化器状态就需要数百GB。**参数高效微调(PEFT)**方法只调整一小部分参数。 + +- **适配器**在现有Transformer层之间插入小型瓶颈层(通常是两个线性层加一个非线性激活:下投影到小维度,再上投影回来)。只有适配器的权重被训练;原始模型权重被冻结。这增加了不到5%的新参数,同时在大多数任务上匹配全量微调的性能。 + +- **LoRA**(低秩适配)直接修改权重矩阵,而不添加新层。LoRA不更新完整的权重矩阵 $W$,而是学习一个低秩分解的更新:$W' = W + BA$,其中 $B$ 是 $d \times r$ 矩阵,$A$ 是 $r \times d$ 矩阵,且 $r \ll d$(通常 $r = 4$ 到 $r = 64$)。原始 $W$ 被冻结;只训练 $A$ 和 $B$。在推理时,更新可以合并到原始权重中,不会增加额外延迟: + +$$W' = W + BA$$ + +![LoRA:冻结的权重矩阵W被一个通过小矩阵A和B的低秩路径旁路,可训练参数减少32倍,同时匹配全量微调的性能](../images/lora_decomposition.svg) + +- **前缀微调**在每个注意力层的键和值矩阵前添加一串可学习的"虚拟标记"。模型像对待真实标记一样关注这些前缀向量,并且只训练前缀参数。这与提示微调类似,但在激活空间而非嵌入空间中操作。 + +- **提示工程**是设计输入文本的艺术,旨在从预训练模型中引出所需行为,而无需任何参数更新。 + + - **零样本提示**用自然语言描述任务("对以下评论的情感进行分类:")。 + + - **少样本提示**在实际查询之前提供输入-输出示例。 + + - **链式思维(CoT)提示**添加"让我们一步一步地思考"或在示例中包含推理过程,这通过引导模型分解问题,显著提高了算术和逻辑推理任务的性能。 + +- **上下文学习(ICL)**是大语言模型能够从提示中提供的示例学习执行任务的现象,而无需任何梯度更新。模型的权重没有改变;它将示例作为一种隐式规范来使用。 + +- ICL在机制上是如何工作的仍然是一个活跃的研究问题;一种假说是注意力层在前向传播中实现了一种梯度下降形式,实际上是在上下文示例上进行"训练"。 + +- **缩放定律**描述了模型大小、数据大小、计算预算与性能(以损失衡量)之间的可预测关系。Kaplan等人(2020)发现损失在每个变量上都遵循幂律: + +$$L(N) \propto N^{-\alpha_N}, \quad L(D) \propto D^{-\alpha_D}, \quad L(C) \propto C^{-\alpha_C}$$ + +- 其中 $N$ 是参数量,$D$ 是数据集大小,$C$ 是计算预算。这些幂律在多个数量级上成立,表明单纯地扩大规模就能带来可预测的改进。 + +![缩放定律:损失在对数-对数坐标轴上以幂律递减,Kaplan和Chinchilla的研究结果表明随规模扩大有可预期的改进](../images/scaling_laws.svg) + +- **Chinchilla缩放定律**(Hoffmann等人,2022)修正了这一点,指出大多数大型模型都训练不足。对于固定的计算预算 $C$,最优分配是同等规模地扩大模型大小和训练数据: + +$$N_{\text{opt}} \propto C^{0.5}, \quad D_{\text{opt}} \propto C^{0.5}$$ + +- 这意味着如果计算预算翻倍,应该同时将模型大小和数据集大小增加 $\sqrt{2}$ 倍,而不仅仅是让模型变得更大。 + +- Kaplan等人曾建议 $N$ 的缩放速度应快于 $D$,这导致了非常大但训练不足的模型。Chinchilla(70B参数,1.4T标记)在相同的计算预算下匹配了Gopher(280B参数,300B标记)的性能,表明早期模型严重缺乏数据。 + +- 实用的经验法则:大约每个参数训练20个标记。 + +- **混合专家(MoE)**是一种在不成比例增加计算量的情况下扩大模型容量的架构。MoE不采用单一的大型前馈层,而是使用多个**专家**FFN层和一个**门控网络**(路由网络)来选择每个标记应该激活哪些专家。 + +- 门控函数计算每个专家的路由分数,并选择前 $k$ 个(通常 $k = 1$ 或 $k = 2$): + +$$G(x) = \text{TopK}(\text{softmax}(W_g x))$$ + +- 只有被选中的专家处理该标记,因此计算成本随 $k$(活跃专家数)而非总专家数 $E$ 增长。一个有8个专家且采用top-2路由的模型,参数量是稠密模型的4倍,但计算量仅为2倍。 + +![MoE层:输入标记经过路由网络计算每个专家的分数,选择top-2专家,它们的输出按门控分数加权后求和](../images/moe_layer.svg) + +- MoE中一个关键的挑战是**负载均衡**:如果路由网络将大多数标记发送给少数热门专家,其他专家就被浪费了。训练时会添加一个辅助的**负载均衡损失**,鼓励均匀的专家利用率: + +$$\mathcal{L}_{\text{balance}} = E \cdot \sum_{i=1}^{E} f_i \cdot p_i$$ + +- 其中 $f_i$ 是分配给专家 $i$ 的标记比例,$p_i$ 是专家 $i$ 的平均路由概率。当标记比例和概率都均匀(各等于 $1/E$)时,该乘积最小。 + +- **专家并行**将不同的专家分布到不同的加速器上。在前向传播过程中,通过一个全到全的通信步骤将标记路由到其指定专家所在的设备,然后将结果路由回来。这种通信成本是MoE在大规模部署中的主要工程挑战。Switch Transformer、Mixtral和GShard等模型使用MoE来获得强大的性能,同时保持合理的推理成本。 + +- 构建模型只是工作的一半;衡量它们是否有效是另一半。NLP评估特别困难,因为语言是模糊的、主观的和开放式的。 + +- 一个翻译可以有多种正确的表达方式。一个摘要即使与参考摘要没有任何完全相同的词汇,也可能是好的。 + +- 一个聊天机器人的回复可能既有用、又无害、又诚实,但理性的人仍会对此有不同看法。 + +- **精确匹配(EM)**是最简单的指标:模型的输出是否与标准答案完全一致?它用于答案简短且无歧义的任务,如抽取式问答(SQuAD)或封闭式数学问题。 + +- EM是严苛的;"New York City"和"new york city"在不做归一化的情况下无法匹配——但它的简单性使其没有歧义。 + +- **标记级指标**将NLP视为标记级别的分类问题,使用第06章中的精确率、召回率和F1值。 + +- **精确率(Precision)**衡量模型预测的标记中正确部分的比例:$P = \text{TP} / (\text{TP} + \text{FP})$。一个预测很少但全部正确的模型具有高精确率。 + +- **召回率(Recall)**衡量模型找到了多少标准标记:$R = \text{TP} / (\text{TP} + \text{FN})$。一个将所有标记都预测为实体的模型具有完美的召回率但精确率极低。 + +- **F1**是精确率和召回率的调和平均值: + +$$F_1 = \frac{2PR}{P + R}$$ + +- 调和平均值(而非算术平均值)惩罚不均衡:如果 $P$ 或 $R$ 中任何一个较低,F1就会很低。对于命名实体识别(文件02),F1按每个实体类型分别计算,然后跨类型取宏平均。对于词性标注,标记级准确率更常见,因为每个标记都有一个标签。 + +- **跨度级F1**(用于SQuAD)比较预测跨度中的标记集与标准跨度中的标记集。这比精确匹配更宽容:如果标准答案是"the Eiffel Tower"而模型预测的是"Eiffel Tower",跨度F1很高(5个重叠标记中的4个),即使EM为零。 + +- **BLEU**(双语评估替补,Papineni等人,2002)是机器翻译的经典指标。它衡量候选翻译与一个或多个参考翻译之间的n-gram重叠。该评分结合了多个n-gram级别(unigram到4-gram)的精确率和一个简短惩罚: + +$$\text{BLEU} = \text{BP} \cdot \exp\!\left(\sum_{n=1}^{N} w_n \log p_n\right)$$ + +- 其中 $p_n$ 是**修正的n-gram精确率**:候选翻译中每个n-gram的计数被裁剪为其在任何参考翻译中的最大计数,防止像"the the the the"这样的退化候选获得高分。权重 $w_n$ 通常是均匀的($w_n = 1/N$,其中 $N = 4$)。 + +- **简短惩罚** $\text{BP} = \min(1, \exp(1 - r/c))$ 惩罚比参考翻译短的候选($c$ 是候选长度,$r$ 是参考长度)。没有这个惩罚,模型可以通过输出很少但非常安全的词来获得高精确率。 + +- BLEU在语料级别(对多个句子取平均)与人类判断有合理的相关性,但在句子级别相关性较差。 + +- 它奖励精确的n-gram匹配,但会遗漏有效的释义:"the cat is on the mat"和"a feline sits atop the rug"尽管意思相同,但二元组重叠为零。 + +- BLEU也完全忽略了召回率——只输出最常见词汇的候选在精确率上得分很高。 + +- **ROUGE**(面向召回率的摘要评估替补,Lin,2004)是摘要的标准指标。与强调精确率的BLEU不同,ROUGE强调召回率:参考n-gram中有多少比例出现在候选摘要中? + +- **ROUGE-N**计算n-gram的召回率:$\text{ROUGE-N} = \frac{|\text{n-grams}_{\text{ref}} \cap \text{n-grams}_{\text{cand}}|}{|\text{n-grams}_{\text{ref}}|}$。ROUGE-1(unigram)和ROUGE-2(bigram)最为常用。 + +- ROUGE-L使用候选和参考之间的**最长公共子序列(LCS)**,这可以捕捉句子级别的词序信息,而不要求连续匹配。 + +- LCS长度除以参考长度得到召回率,除以候选长度得到精确率,F度量则组合两者。 + +- LCS通过动态规划在 $O(mn)$ 时间内计算(类似于文件02中的编辑距离): + +$$R_{\text{LCS}} = \frac{\text{LCS}(X, Y)}{m}, \quad P_{\text{LCS}} = \frac{\text{LCS}(X, Y)}{n}, \quad F_{\text{LCS}} = \frac{(1 + \beta^2) R_{\text{LCS}} P_{\text{LCS}}}{R_{\text{LCS}} + \beta^2 P_{\text{LCS}}}$$ + +- 其中 $m$ 和 $n$ 分别是参考和候选的长度,$\beta$ 通常设置为偏向召回率($\beta \to \infty$ 给出纯召回率)。 + +- **METEOR**(带显式排序的翻译评估度量,Banerjee和Lavie,2005)通过引入同义词、词干提取和词序来解决BLEU的弱点。 + +- 它首先使用精确匹配、词干匹配(通过文件02中的Porter词干提取算法)和同义词匹配(通过文件01中的WordNet)在候选和参考之间对齐词汇。 + +- 然后计算unigram精确率和召回率的调和平均值(偏向召回率),并应用一个碎片化惩罚,惩罚那些匹配词顺序与参考不同的候选。 + +- **ChrF**(字符n-gram F值)计算字符n-gram而非词汇n-gram的F值。这使其对形态变化具有鲁棒性(对文件01中的黏着语至关重要),并部分处理了分词差异。ChrF++在字符n-gram的基础上增加了词汇二元组。 + +- 它已成为机器翻译中与BLEU一起推荐的度量标准,特别是对于形态丰富的语言。 + +- **困惑度**(文件02)衡量语言模型在保留测试集上的预测效果。这是语言模型的标准内在指标:$\text{PPL} = \exp(-\frac{1}{N} \sum_{i} \log P(w_i \mid w_{ 0$放大了条件的影响。越大的$w$使输出更强烈地遵循提示词,但降低了多样性。 + +- **RLHF**(基于人类反馈的强化学习,Ouyang等人,2022)是对齐语言模型与人类偏好的主流方法。该过程分为三个阶段: + +- 首先,**监督微调(SFT)**:在高质量人工编写的提示-回复数据集上对基础语言模型进行微调。 + +- 其次,**奖励模型训练**:收集人类比较数据(给定提示$x$和两个回复$y_1, y_2$,哪个更好?)并训练一个奖励模型$r_\phi(x, y)$来预测人类偏好。奖励模型使用成对排序损失进行训练: + +$$\mathcal{L}_{\text{RM}} = -\log \sigma(r_\phi(x, y_w) - r_\phi(x, y_l))$$ + +- 其中$y_w$是更受偏好的回复,$y_l$是不受偏好的回复。 + +- 第三,**RL微调**:优化语言模型以最大化奖励,同时保持接近SFT模型(以防止模式崩塌)。这使用带有KL惩罚的PPO(近端策略优化,来自第06章): + +$$\mathcal{L}_{\text{RL}} = -\mathbb{E}\left[r_\phi(x, y) - \beta \, D_{\text{KL}}(\pi_\theta \| \pi_{\text{SFT}})\right]$$ + +- KL项防止模型偏离基础模型太远,并防止模型利用奖励模型的缺陷(\"奖励破解\")。 + +![RLHF 管线](../images/rlhf_pipeline.svg) + +- **DPO**(直接偏好优化,Rafailov等人,2023)通过完全消除奖励模型来简化RLHF。关键的数学洞见是,上述KL约束的RL目标有一个闭式最优策略: + +$$\pi^\ast(y \mid x) = \frac{1}{Z(x)} \pi_{\text{ref}}(y \mid x) \exp\!\left(\frac{r(x, y)}{\beta}\right)$$ + +- 其中$Z(x)$是一个归一化配分函数。整理上式求解奖励得$r(x, y) = \beta \log \frac{\pi^\ast(y \mid x)}{\pi_{\text{ref}}(y \mid x)} + \beta \log Z(x)$。将这个隐式奖励代入Bradley-Terry偏好模型$P(y_w \succ y_l) = \sigma(r(x, y_w) - r(x, y_l))$会导致难以处理的$Z(x)$项相互抵消,直接得到DPO损失: + +$$\mathcal{L}_{\text{DPO}} = -\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)$$ + +- 这在数学上等价于RLHF,但将奖励模型和RL训练合并为一个单一的监督步骤。 + +- sigmoid内部的表达式可以理解为:"增加偏好回复的相对概率,降低不偏好回复的相对概率,这是相对于参考模型而言的。" + +- 参数$\beta$控制策略可以偏离参考模型的程度。在实践中,DPO实现更简单(只需计算当前模型和参考模型对两个完成序列的对数概率),并且避免了PPO训练的不稳定性。 + +- **Constitutional AI**(Bai等人,2022)自动化了对齐过程的某些部分。它不再收集人类比较数据,而是让语言模型本身根据一组原则("宪法")来批评和修订自己的输出,例如"选择危害较小的回复"。然后,AI生成的比较数据被用于偏好训练(RLAIF:基于AI反馈的强化学习)。 + +- **长上下文方法**解决了标准自注意力的$O(n^2)$内存和计算成本问题,这限制了序列长度。当$n$增长到数万或数十万个token时,标准注意力变得不可行。 + +- **稀疏注意力**将稠密的$n \times n$注意力矩阵替换为一种稀疏模式,其中每个token只关注其他token的一个子集。常见的模式包括**局部注意力**(每个token关注一个固定大小的相邻窗口)、**步长注意力**(关注每隔$k$个token)和**随机注意力**(关注一个随机子集)。这些模式的组合(用于BigBird、Longformer)实现了$O(n)$或$O(n \sqrt{n})$的复杂度,同时保持了捕获局部和全局依赖关系的能力。 + +![稀疏注意力模式](../images/sparse_attention_patterns.svg) + +- **滑动窗口注意力**将每个token限制为只关注其之前的$w$个token(其局部窗口)。这是$O(nw)$而不是$O(n^2)$,但长距离信息必须通过跨层的重叠窗口传播。对于$L$层和窗口大小$w$,有效感受野为$L \times w$个token。 + +- **环形注意力**通过将设备排列成环形拓扑结构,将长序列分布到多个设备上。每个设备持有序列的一个块,并为其块计算注意力,同时将键值块发送给环中的下一个设备。这种方式将计算与通信重叠,允许任意长度的序列,仅受所有设备总内存的限制,而不受任何单个设备内存的限制。 + +- **记忆增强模型**通过为Transformer配备一个外部记忆库来扩展上下文。在每个层中,模型可以使用注意力从这个记忆库中读取和写入。Memorizing Transformers缓存来自先前块的键值对,并在后续块中关注它们,从而有效地将上下文扩展到训练窗口之外。检索是近似的(使用缓存键的$k$近邻搜索)以保持高效。 + +- 上述方法是处理长上下文的**架构**解决方案。同样重要的是模型如何被**训练**以有效使用长上下文。 + +- **渐进式上下文扩展**是标准方法。从一开始就在非常长的序列上训练代价高昂($O(n^2)$的注意力成本),因此模型在较短的上下文长度上预训练(通常为4K-8K token),然后通过**继续预训练**分阶段扩展到目标长度。 + +- Llama 3.1从8K扩展到128K,使用了800B token,并逐步增加序列长度。DeepSeek-V3在4K处训练,然后扩展到32K,再到128K。 + +- 每个阶段使用适中的token数量(相对于完整的预训练预算),因为模型只需要学习如何使用更长的位置,而不是重新学习语言本身。 + +- 在扩展过程中,位置编码必须进行调整。**RoPE插值**缩小位置索引,使得模型看到与训练时相同的旋转角度,只是分布在更长的序列上。如果模型在长度$L$上训练,你想要扩展到$L' = 4L$,你可以将所有位置索引除以4。 + +- 这意味着模型永远不会遇到未见过的旋转角度,但相邻位置之间的有效分辨率会下降。 + +- **RoPE外推**保持原始位置索引不变,直接将RoPE应用于超出$L$的位置,依赖模型对未见角度的泛化能力。 + +- 插值要稳定得多;在不调整基频(ABF)的情况下,外推会迅速退化。 + +- **YaRN**(Yet another RoPE extensioN,又一种RoPE扩展)改进了朴素插值,因为它认识到并非所有RoPE维度都应被同等对待。 + +- 高频维度(在$\theta_i = \theta_{\text{base}}^{-2i/d}$中较小的$i$)在训练长度内旋转多次,可以很好地外推。 + +- 低频维度(较大的$i$)旋转缓慢,对长度扩展更敏感。 + +- YaRN只插值低频维度,外推高频维度,并对注意力logits应用温度缩放$t$以补偿分布偏移: + +$$\text{score}'_{ij} = \frac{q_i^T k_j}{t \sqrt{d_k}}$$ + +- 其中$t > 1$展平了注意力分布,防止模型在位置信号被压缩时过于尖锐地关注附近的token。 + +- **长上下文数据策展**是一个关键且常被低估的挑战。大多数预训练语料库由短文档组成(新闻文章、网页、社交媒体帖子)。 + +- 长上下文训练需要实际利用完整上下文窗口的数据组合:书籍、代码仓库、长篇科学文章、多轮对话日志,以及主题相关的拼接文档。 + +- 如果模型仅在填充或打包以填满上下文窗口的短文档上训练,它会学会忽略远处的token,因为它们从来都不相关。 + +- **序列打包**是一种训练效率技术:多个文档拼接成一个训练序列以避免填充浪费,使用注意力掩码防止跨文档的注意力。 + +- 对于长上下文训练,打包策略很重要:打包许多不相关的短文档会教模型将远处的token视为噪声,而打包更少的、真正长的文档则教它使用完整的上下文。 + +- 一个已知的失败模式是**"中间迷失"**现象(Liu等人,2023):语言模型能够有效利用上下文窗口开头和结尾的信息,但在处理位于中间的信息时表现困难。 + +- 这类似于人类记忆中的序列位置效应(首因效应和近因效应)。 + +- 它部分源于训练数据的分布(重要信息通常在文档的开头或结尾),部分源于注意力模式集中于邻近token和初始token。 + +- 通过在不同位置放置关键信息进行长上下文训练可以缓解但无法完全解决这个问题。 + +- **大海捞针**评估测试模型是否能够从长长的干扰上下文("大海")中检索出位于不同位置的特定事实("针")。 + +- 具有真正长上下文能力的模型应该无论针放在哪里都能实现近乎完美的检索。 + +- 这个测试清晰地揭示了"中间迷失"效应,并被用作上下文扩展方法的基准。 + +- **预训练后的长上下文微调**使用有针对性的SFT数据:长多轮对话、证据分散在数千个token中的文档问答、长篇摘要,以及仓库级别的代码理解。 + +- Qwen3在此阶段使用**双块注意力(DCA)**,它将长序列作为成对的块进行处理,其中块内注意力是完整的,块间注意力是高效的,在微调期间实现了4倍的有效序列容量。 + +- **状态空间模型(SSM)**提供了一种根本不同的长序列建模方法。它们不是修改注意力,而是用受连续时间控制理论启发的线性动力系统完全取代注意力。 + +- 一个SSM将输入序列$u(t)$通过一个潜在状态$x(t) \in \mathbb{R}^N$映射到输出$y(t)$,其控制方程为: + +$$x'(t) = Ax(t) + Bu(t), \quad y(t) = Cx(t) + Du(t)$$ + +- 其中$A \in \mathbb{R}^{N \times N}$是状态转移矩阵,$B \in \mathbb{R}^{N \times 1}$是输入投影,$C \in \mathbb{R}^{1 \times N}$是输出投影,$D$是一个跳跃连接。 + +- 为了将其应用于离散序列(token),使用步长$\Delta$对连续系统进行**离散化**。零阶保持离散化给出: + +$$\bar{A} = \exp(\Delta A), \quad \bar{B} = (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B$$ + +- 离散递归变为$x_k = \bar{A} x_{k-1} + \bar{B} u_k$,$y_k = C x_k + D u_k$,这看起来像一个RNN:每次用一个隐藏状态处理一个token。 + +- 与RNN不同,这个递归也可以展开为一个**全局卷积**:因为系统是线性的,输出为$y = \bar{K} \ast u$,其中核$\bar{K} = (C\bar{B}, \, C\bar{A}\bar{B}, \, C\bar{A}^2\bar{B}, \ldots)$仅取决于固定参数。 + +- 这种**双重视角**——用于高效自回归推理的递归(每步$O(1)$)和用于高效并行训练的卷积(通过FFT实现$O(n \log n)$)——是SSM的核心洞见。 + +![SSM双重视角:推理时的递归、训练时的卷积,以及Mamba的选择性扩展](../images/ssm_dual_view.svg) + +- **S4**(序列建模的结构化状态空间,Gu等人,2022)通过解决关键的数值挑战使SSM变得实用:状态矩阵$A$必须捕获长距离依赖关系,但朴素地参数化会导致梯度消失或爆炸(与普通RNN相同的问题)。 + +- S4使用**HiPPO**(高阶多项式投影算子)矩阵初始化$A$,该矩阵来源于连续信号最优多项式逼近的理论。HiPPO矩阵具有特定的结构,被证明能使状态以优雅衰减的方式维持整个输入历史的压缩表示: + +```math +A_{nk} = -\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \text{if } n > k \\ n+1 & \text{if } n = k \\ 0 & \text{if } n < k \end{cases} +``` + +- 这种下三角结构确保状态使用勒让德多项式作为信号的在线逼近器。计算长核的$\bar{A}^k$代价高昂,因此S4利用HiPPO矩阵可以分解为低秩项和对角项之和的事实,实现了$O(n \log n)$的核计算。 + +- **Mamba**(Gu和Dao,2023)引入了**选择性状态空间**这一关键创新:使SSM参数依赖于输入。在S4中,矩阵$A$、$B$、$C$和步长$\Delta$是固定的——无论内容如何,相同的动力学应用于每个token。Mamba使$B$、$C$和$\Delta$成为输入的函数: + +$$B_k = \text{Linear}(u_k), \quad C_k = \text{Linear}(u_k), \quad \Delta_k = \text{softplus}(\text{Linear}(u_k))$$ + +- 这种选择性允许模型在每个位置决定哪些信息存入状态、哪些信息忽略——类似于注意力如何选择相关token,但没有二次成本。步长$\Delta_k$控制着"门":大的$\Delta$导致状态强烈地整合当前输入(连续动力学前进一大步,有效重置状态),而小的$\Delta$则保留现有状态并忽略当前输入。 + +- 权衡之处在于,依赖于输入的参数打破了卷积视角(核不再固定),因此Mamba无法使用基于FFT的训练。相反,它使用一种**硬件感知的并行扫描**算法,利用递归的结合律:状态更新$(x_k, u_k) \mapsto x_{k+1}$可以表示为一串结合性操作,并使用前缀和(扫描)进行并行化,类似于硬件设计中的并行前缀加法。这在GPU上以$O(n)$时间和$O(\log n)$深度运行,几乎与卷积的效率相当。 + +- Mamba实现了真正每token $O(1)$的推理(只需更新固定大小的状态,没有随上下文增长的KV缓存),使其在长序列长度上从根本上比Transformer更节省内存。状态大小$N$(通常为16)远小于Transformer的KV缓存(存储$O(n \cdot d)$个值)。在实践中,在相同的参数量下,Mamba在语言建模基准上的质量达到或超过Transformer,并且在长序列上推理速度显著更快。 + +- **混合架构**将SSM层与注意力层相结合,使用SSM处理大部分层(高效的长距离传播),并穿插少量注意力层(精确的基于内容的检索)。像Jamba和Zamba这样的模型交错了Mamba和Transformer块,在保持大部分推理效率优势的同时,实现了比纯SSM更好的质量。这表明注意力和SSM捕获了互补的能力:SSM擅长平滑的长距离状态传播,而注意力擅长精确的、依赖于内容的查找。 + +- **检索增强生成(RAG)**通过在推理时让语言模型访问外部知识库,来解决语言模型的知识局限性。RAG不是仅依赖于训练期间编码在模型参数中的知识,而是检索相关文档并基于它们进行条件生成。 + +- 经典的**检索器-阅读器架构**有两个组件。**检索器**接收查询并从语料库中获取最相关的top-$k$个段落。**阅读器**(一个语言模型)基于查询和检索到的段落生成答案。检索器可以使用稀疏方法(BM25,它扩展了文件02中的TF-IDF)或稠密方法。 + +- **稠密段落检索(DPR)**使用双编码器架构:一个编码器将问题映射为向量,另一个将段落映射为向量。两者通常都是基于BERT的。在索引时,所有段落被编码并存储。在查询时,问题被编码,使用近似最近邻搜索(如FAISS)找到最近的段落。相似度度量是问题向量和段落向量之间的点积。 + +- **分块策略**显著影响检索质量。文档必须被分割成足够小以使检索器能够处理的段落,但又要足够大以包含完整的思想。固定大小的分块(例如,256个token,50个token重叠)很简单,但可能笨拙地分割句子。语义分块在段落或章节边界处分割。层次化分块在不同粒度上创建一个摘要树。 + +![RAG 架构](../images/rag_architecture.svg) + +- RAG有几个优势:知识库可以更新而无需重新训练模型,模型可以引用来源,并且因为模型可以基于检索到的文本进行回答,幻觉减少了。主要挑战是检索质量(如果检索到错误的段落,模型可能会自信地给出错误答案)和延迟(检索为推理增加了一个步骤)。 + +- **推测性解码**通过使用一个小的、快速的**草稿模型**并行提出多个token,然后由大的**目标模型**在单个前向传播中进行验证,从而加速自回归生成。 + +- 该算法的工作方式如下:草稿模型自回归地生成$k$个候选token(因为草稿模型很小,所以这很快)。 + +- 然后,目标模型在单个前向传播中同时对全部$k$个token进行评分(因为工作被批处理,所以这很高效)。 + +- 对于从草稿分布$p_d(t)$中采样的每个候选token $t$,它以概率$\min(1, \, p_{\text{target}}(t) / p_d(t))$被接受。如果被拒绝,则从**调整后分布**$p_{\text{adj}}(t) = \max(0, \, p_{\text{target}}(t) - p_d(t))$(经归一化)中重新采样一个修正后的token。 + +- 这种接受-拒绝方案保证了输出分布与单独使用目标模型完全相同。 + +- 为了理解原因,考虑生成token $t$的有效概率。它可以直接被接受(概率$p_d(t) \cdot \min(1, p_{\text{target}}(t)/p_d(t))$),或者通过重新采样产生。 + +- 对于$p_{\text{target}}(t) \leq p_d(t)$的token,直接接受贡献$p_{\text{target}}(t)$。对于$p_{\text{target}}(t) > p_d(t)$的token,直接接受贡献$p_d(t)$,重新采样贡献剩余部分$p_{\text{target}}(t) - p_d(t)$(在考虑拒绝概率之后)。 + +- 在这两种情况下,生成$t$的总概率等于$p_{\text{target}}(t)$。草稿模型只影响速度,不影响质量。 + +![推测性解码](../images/speculative_decoding.svg) + +- 加速取决于接受率:如果草稿模型与目标模型对齐良好,大多数token被接受,墙上时钟时间大致等于草稿模型的时间。典型加速为2-3倍,且质量无下降。 + +- **Medusa**(Cai等人,2024)采用不同的方法:不是使用单独的草稿模型,而是在目标模型本身中添加多个轻量级的预测头。每个头同时预测不同的未来token位置(提前$k = 1, 2, 3, \ldots$步)。在每一步,Medusa使用树状结构提出若干候选延续,通过目标模型注意力层的单个前向传播验证哪些候选是一致的。这完全避免了对单独草稿模型的需求。 + +- **并行生成**方法更广泛地旨在打破自回归解码的串行瓶颈。雅可比解码使用猜测初始化所有位置,并并行地迭代精炼直到收敛,将生成视为一个不动点迭代。非自回归模型(NAT)在单个前向传播中同时生成所有token,但通常遭受质量下降的问题,需要像迭代精炼、CTC损失或来自自回归教师的知识蒸馏这样的技术来缩小差距。 + +- 上述技术——对齐、长上下文、检索、高效解码、状态空间模型——在现代生产级LLM中结合在一起。 + +- 本文的其余部分审视了前沿模型的架构创新,展示了文件01-04中的理论思想以及上述方法是如何在实践中结合起来的。 + +- **分组查询注意力(GQA)** 是采用最广泛的注意力效率技术。标准多头注意力(MHA)为每个头维护独立的键和值投影,每个token需要缓存$n_{\text{heads}} \times d_{\text{head}}$个值。GQA将多个查询头分组以共享一个键-值头。 + +- 使用64个查询头和8个KV头(Llama 3、Qwen、Gemma中的常见配置),每个KV头被8个查询头共享,与MHA相比KV缓存减少了8倍。 + +- 输出质量几乎与MHA相同,因为查询仍然可以关注不同的模式,它们只是共享相同的键-值子空间。多查询注意力(MQA)是所有查询使用单个KV头的极端情况,但GQA提供了更好的质量-效率权衡。 + +- **多头潜在注意力(MLA)**,由DeepSeek-V2引入,实现了更激进的KV缓存压缩。MLA不是缓存完整的键-值投影(即使使用GQA),而是将隐藏状态下投影为一个低秩的**潜在向量**$c_t \in \mathbb{R}^{d_c}$,其中$d_c \ll n_{\text{heads}} \times d_{\text{head}}$: + +$$c_t = W_{\text{down}} \, h_t$$ + +- 仅缓存这个压缩向量。在注意力计算时,通过上投影重建完整的键和值表示:$k_t = W_{\text{up}}^K c_t$,$v_t = W_{\text{up}}^V c_t$。在DeepSeek-V3中(671B总参数,37B激活参数),压缩维度为$d_c = 512$,而完整MHA需要$128 \times 128 = 16{,}384$,KV缓存减少了93%。 + +- 一个微妙的点:标准RoPE依赖于位置,与共享压缩不兼容,因此MLA使用**解耦的RoPE**:查询和键的一个小的独立流(每头64维)通过RoPE携带位置信息,而表示的主要部分通过压缩的潜在路径流动。 + +![注意力KV缓存策略:MHA、GQA和MLA比较](../images/mla_vs_gqa.svg) + +- **大规模位置编码**已经从原始的正弦方案显著分化。所有前沿模型都使用**RoPE**(文件04),但针对长上下文有关键修改。原始RoPE公式$\theta_i = \theta_{\text{base}}^{-2i/d}$中的基频$\theta_{\text{base}}$通常为10,000,这限制了超出训练长度的外推能力。 + +- **调整基频(ABF)**只是将$\theta_{\text{base}}$增加到500,000(Llama 3)或1,000,000(Qwen3、Gemma 3),拉伸旋转周期,使得模型在训练期间遇到更少的完整旋转,从而能够外推得更远。 + +- **YaRN**(Yet another RoPE extensioN,又一种RoPE扩展)应用依赖于频率的插值:低频维度被插值(缩小比例),高频维度被外推,同时温度因子调整注意力分布。DeepSeek-V3、Qwen和Kimi K2都使用基于YaRN的扩展,从预训练时的4K-8K上下文达到128K上下文。 + +- **iRoPE**(交错RoPE),由Llama 4引入,采取了更激进的方法:每4个注意力层中有一个**完全不使用位置编码**(NoPE),而其他层使用标准RoPE配合分块注意力。 + +- NoPE层可以在没有任何位置偏差的情况下关注所有位置,而RoPE层提供局部排序。结合推理时的温度缩放,这使得Llama 4 Scout的1000万token上下文窗口成为可能——比任何纯RoPE方法都高出几个数量级。 + +- **大规模混合专家**已成为前沿模型的主导架构(文件04介绍了MoE基础)。关键的设计选择是专家数量、路由稀疏性和负载均衡。 + +- **路由稀疏性**差异显著:DeepSeek-V3使用256个专家,top-8路由(32倍稀疏);Qwen3使用128个专家,top-8路由(16倍稀疏);Mixtral使用8个专家,top-2路由(4倍稀疏);Llama 4 Maverick使用128个专家,top-1加一个共享专家(128倍稀疏)。 + +- 更高的稀疏性意味着在相同激活计算量下拥有更多总参数,但需要更仔细的负载均衡和通信基础设施。 + +- **无辅助损失的负载均衡**(DeepSeek-V3)取代了传统的负载均衡损失(文件04),后者被发现会降低模型质量。每个专家维护一个动态偏置项,在每个训练步骤进行调整:过载的专家其偏置降低(接收更少的token),欠载的专家其偏置增加。这实现了均衡的路由,没有任何辅助损失污染主要训练信号。 + +- **共享专家**出现在大多数MoE设计中:一个或多个专家FFN处理每个token,无论路由结果如何。这些处理所有token都需要的常见模式(基本语法、功能词),使得路由专家可以专注于 specialization。Llama 4使用1个共享专家加每个token 1个路由专家(非常稀疏);DeepSeek-V3使用1个共享加8个路由。 + +- **交替稠密层和MoE层**提供了另一个设计维度。Gemma 2和3交替使用局部/全局注意力层(Gemma 3中比例为5:1,其中局部层使用1024 token的滑动窗口,只有全局层缓存完整的128K上下文)。 + +- Llama 4 Maverick交错使用稠密FFN层和MoE层。Kimi K2使用混合稀疏层(一个稠密层穿插在专家层之间)。这种异构设计允许不同层服务于不同的功能。 + +- **多token预测(MTP)**,用于DeepSeek-V3,训练模型不仅预测下一个token,还预测后面的token。在每个位置,一个次级预测模块(共享主模型的嵌入)预测一个额外的未来token。MTP损失的权重是主下一个token损失的0.1-0.3倍。除了在训练期间改善表示质量外,MTP头还可以在推理时作为推测性解码的草稿头,提供免费的加速。 + +- **知识蒸馏**是一种训练策略,其中大型"教师"模型的输出指导较小"学生"模型的训练。Gemma 2和3广泛使用蒸馏:较小的模型(2B、4B)在计算最优数据量的50倍上训练,使用教师的概率分布作为软目标。这就是为什么Gemma 3-4B在质量上匹配Gemma 2-27B。 + +- 蒸馏损失替代或补充了标准交叉熵:学生最小化其输出分布与教师分布之间的KL散度: + +$$\mathcal{L}_{\text{distill}} = D_{\text{KL}}(p_{\text{teacher}}(\cdot \mid x) \| p_{\text{student}}(\cdot \mid x))$$ + +- DeepSeek-R1将其671B推理模型蒸馏到小至1.5B的稠密模型中,使用了80万条精选的思维链样本,产生了推理能力异常强的小模型。 + +- **基于强化学习的推理**代表了LLM能力中最显著的最新进展。DeepSeek-R1证明,在基础模型上进行纯强化学习(无需监督微调)可以引出思维链推理、自我验证和纠错行为——当模型因给出正确的最终答案而获得奖励时,这些行为会自发涌现。 + +- DeepSeek-R1使用**GRPO**(组相对策略优化),它消除了PPO所需的价值网络。对于每个提示,GRPO采样一组$G$个输出,计算它们的奖励,并在组内归一化优势值: + +$$A_i = \frac{r_i - \text{mean}(r_1, \ldots, r_G)}{\text{std}(r_1, \ldots, r_G)}$$ + +- 然后策略梯度使用这些组相对优势值,配合一个裁剪目标(类似于PPO的裁剪)。 + +- 消除评论家网络将RL训练的内存和计算需求减半,使得在671B参数模型上进行RL训练变得可行。 + +- 一个关键的设计选择:DeepSeek-R1使用**基于规则的奖励**(对照标准答案检查数学答案、运行代码测试用例)而不是神经奖励模型,因为神经奖励模型在此规模下被发现容易受到奖励破解的影响。 + +- **Qwen3的混合思考模式**将推理(使用``标签进行逐步思维链)和快速直接回复整合到一个模型中,允许用户控制一个"思考预算",在延迟和推理深度之间进行权衡。 + +- 这是通过在思考和思考数据上训练实现的,而不是通过单独的模型检查点。 + +- **大规模训练稳定化**需要超越标准实践的新技术。**Logits软裁剪**(Gemma 2)将注意力分数通过$s \cdot \tanh(\text{logits} / s)$处理,软裁剪值$s$(通常为30-50),以防止无界增长。 + +- **QK归一化**(Qwen3)在计算注意力分数之前对查询和键向量应用RMSNorm,取代了对QKV偏置的需求。**QK裁剪**(Kimi K2的MuonClip优化器)在训练期间监控最大注意力logits,当查询-键权重矩阵超过阈值时对其进行重新缩放,使得1T参数模型的预训练能够稳定进行,且没有不稳定事件。 + +- **FP8混合精度训练**(DeepSeek-V3)在前向和反向传播中使用8位浮点数进行计算密集的矩阵乘法,同时将主权重保持在更高精度。 + +- 与BF16/FP16训练相比,这大致将吞吐量提升了一倍,且质量损失可忽略不计。DeepSeek-V3使用仅280万H800 GPU小时训练了其671B参数模型——只是同类模型的一小部分——这主要归功于这一优化和其他工程优化。 + +- **FP8混合精度训练**(DeepSeek-V3)在前向和反向传播中使用8位浮点数进行计算密集的矩阵乘法,同时将主权重保持在更高精度。 + +- 与BF16/FP16训练相比,这大致将吞吐量提升了一倍,且质量损失可忽略不计。DeepSeek-V3使用仅280万H800 GPU小时训练了其671B参数模型——只是同类模型的一小部分——这主要归功于这一优化和其他工程优化。 + +## 编程练习(使用 CoLab 或 notebook) + +1. 从头实现一个简单的检索增强生成管线。使用TF-IDF(文件02)索引一组文档,为查询检索最相关的段落,并将其前置到提示中。 +```python +import jax.numpy as jnp +import math +from collections import Counter + +# 知识库:一组简短段落 +knowledge_base = [ + "The Eiffel Tower is a wrought-iron lattice tower in Paris, France. It was constructed from 1887 to 1889 as the centerpiece of the 1889 World's Fair.", + "The Great Wall of China is a series of fortifications built along the northern borders of China. Construction began in the 7th century BC.", + "Photosynthesis is the process by which plants convert sunlight, water, and carbon dioxide into glucose and oxygen using chlorophyll.", + "The theory of general relativity, published by Albert Einstein in 1915, describes gravity as the curvature of spacetime caused by mass and energy.", + "Python is a high-level programming language known for its simple syntax and readability. It was created by Guido van Rossum and released in 1991.", + "The mitochondria are organelles found in eukaryotic cells. They generate most of the cell's supply of ATP, used as a source of chemical energy.", +] + +# 构建 TF-IDF 索引(重用了文件02中的概念) +def tokenise(text): + return text.lower().split() + +vocab = sorted(set(w for doc in knowledge_base for w in tokenise(doc))) +word2idx = {w: i for i, w in enumerate(vocab)} +V = len(vocab) +N = len(knowledge_base) + +# 文档频率 +doc_freq = Counter() +for doc in knowledge_base: + for w in set(tokenise(doc)): + doc_freq[w] += 1 + +def tfidf_vector(text): + words = tokenise(text) + counts = Counter(words) + vec = jnp.zeros(V) + for w, c in counts.items(): + if w in word2idx: + tf = 1 + math.log(c) + idf = math.log(N / (doc_freq.get(w, 0) + 1)) + vec = vec.at[word2idx[w]].set(tf * idf) + return vec + +# 索引所有文档 +doc_vectors = jnp.stack([tfidf_vector(doc) for doc in knowledge_base]) + +def cosine_sim(a, b): + return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + 1e-8) + +def retrieve(query, top_k=2): + """为查询检索top-k个最相关的段落。""" + q_vec = tfidf_vector(query) + sims = jnp.array([cosine_sim(q_vec, doc_vectors[i]) for i in range(N)]) + top_indices = jnp.argsort(-sims)[:top_k] + return [(int(i), float(sims[i]), knowledge_base[int(i)]) for i in top_indices] + +# 测试检索 +queries = [ + "Who built the Eiffel Tower?", + "How do plants make food?", + "What did Einstein discover?", +] + +for query in queries: + results = retrieve(query, top_k=1) + print(f"\nQuery: '{query}'") + for idx, sim, passage in results: + print(f" Retrieved (sim={sim:.3f}): '{passage[:80]}...'") + + # RAG风格的提示构建 + context = results[0][2] + rag_prompt = f"Context: {context}\n\nQuestion: {query}\nAnswer:" + print(f" RAG prompt:\n {rag_prompt[:120]}...") +``` + +2. 使用玩具草稿模型和目标模型实现推测性解码。展示接受的输出与目标模型的分布一致。 +```python +import jax +import jax.numpy as jnp + +# 模拟草稿模型(快速,不太准确)和目标模型(慢速,准确) +vocab_size = 8 +seq_len = 5 + +key = jax.random.PRNGKey(42) + +# 目标模型:给定序列返回logits +def target_model(seq, key): + """模拟的目标模型:产生token logits(昂贵的)。""" + # 实践中这将是一个大型Transformer前向传播 + k1, k2 = jax.random.split(key) + logits = jax.random.normal(k1, (len(seq), vocab_size)) * 2 + # 使其有些可预测性:偏向于 token (seq[-1] + 1) % vocab_size + for i in range(len(seq)): + logits = logits.at[i, (seq[i] + 1) % vocab_size].add(3.0) + return logits + +def draft_model(seq, key): + """模拟的草稿模型:类似但噪声更大(便宜的)。""" + k1, k2 = jax.random.split(key) + logits = jax.random.normal(k1, (len(seq), vocab_size)) + for i in range(len(seq)): + logits = logits.at[i, (seq[i] + 1) % vocab_size].add(2.0) + return logits + +def sample_token(logits, key): + return jax.random.categorical(key, logits) + +def speculative_decode(prefix, draft_steps=3, key=jax.random.PRNGKey(0)): + """推测性解码:草稿提出,目标验证。""" + seq = list(prefix) + total_accepted = 0 + total_proposed = 0 + + for _ in range(4): # 生成4轮 + key, *subkeys = jax.random.split(key, draft_steps + 3) + + # 草稿模型提出draft_steps个token + draft_tokens = [] + draft_probs = [] + draft_seq = list(seq) + for i in range(draft_steps): + d_logits = draft_model(jnp.array(draft_seq), subkeys[i]) + d_probs = jax.nn.softmax(d_logits[-1]) + tok = sample_token(d_logits[-1], subkeys[i]) + draft_tokens.append(int(tok)) + draft_probs.append(d_probs) + draft_seq.append(int(tok)) + + # 目标模型在一次前向中评估所有草稿token + target_logits = target_model(jnp.array(draft_seq), subkeys[draft_steps]) + target_start = len(seq) - 1 # 最后一个前缀token的位置 + + # 接受/拒绝每个草稿token + accepted = 0 + for i in range(draft_steps): + t_probs = jax.nn.softmax(target_logits[target_start + i]) + d_prob = draft_probs[i][draft_tokens[i]] + t_prob = t_probs[draft_tokens[i]] + + # 以概率 min(1, target_prob / draft_prob) 接受 + accept_prob = jnp.minimum(1.0, t_prob / (d_prob + 1e-10)) + key, accept_key = jax.random.split(key) + if jax.random.uniform(accept_key) < accept_prob: + seq.append(draft_tokens[i]) + accepted += 1 + else: + # 拒绝:从调整后的分布中采样 + key, resample_key = jax.random.split(key) + adjusted = jnp.maximum(0, t_probs - draft_probs[i]) + adjusted = adjusted / (adjusted.sum() + 1e-10) + new_tok = jax.random.categorical(resample_key, jnp.log(adjusted + 1e-10)) + seq.append(int(new_tok)) + break + + total_accepted += accepted + total_proposed += draft_steps + + return seq, total_accepted, total_proposed + +# 运行推测性解码 +prefix = [0, 1] +result_seq, accepted, proposed = speculative_decode(prefix) +acceptance_rate = accepted / proposed if proposed > 0 else 0 + +print(f"Prefix: {prefix}") +print(f"Generated sequence: {result_seq}") +print(f"Draft proposals: {proposed}") +print(f"Accepted: {accepted}") +print(f"Acceptance rate: {acceptance_rate:.1%}") +print(f"Speedup potential: {(accepted + proposed) / proposed:.2f}x") +``` + +3. 构建一个简单的DPO训练循环。给定偏好和不偏好的完成序列对,使用DPO损失更新一个小模型。 +```python +import jax +import jax.numpy as jnp + +# 微型语言模型:从one-hot到logits的线性投影 +vocab_size = 10 +seq_len = 4 + +key = jax.random.PRNGKey(42) +k1, k2 = jax.random.split(key) + +# 当前策略参数(可训练的) +theta = jax.random.normal(k1, (vocab_size, vocab_size)) * 0.1 +# 参考策略参数(theta的冻结副本) +theta_ref = theta.copy() + +def log_prob_sequence(params, sequence): + """计算简单自回归模型下的 log P(sequence)。""" + total = 0.0 + for t in range(1, len(sequence)): + # 简单:位置t处的logits取决于位置t-1处的token + logits = params[sequence[t-1]] + log_probs = jax.nn.log_softmax(logits) + total += log_probs[sequence[t]] + return total + +def dpo_loss(theta, theta_ref, preferred, dispreferred, beta=0.1): + """一对数据的直接偏好优化损失。""" + log_pi_w = log_prob_sequence(theta, preferred) + log_pi_l = log_prob_sequence(theta, dispreferred) + log_ref_w = log_prob_sequence(theta_ref, preferred) + log_ref_l = log_prob_sequence(theta_ref, dispreferred) + + # DPO目标 + return -jax.nn.log_sigmoid( + beta * ((log_pi_w - log_ref_w) - (log_pi_l - log_ref_l)) + ) + +# 偏好数据集:(提示前缀, 偏好完成序列, 不偏好完成序列) +preferences = [ + (jnp.array([1, 3, 5, 7]), jnp.array([1, 3, 5, 2])), # 结尾偏好7而不是2 + (jnp.array([0, 2, 4, 6]), jnp.array([0, 2, 4, 9])), # 偏好6而不是9 + (jnp.array([3, 3, 3, 3]), jnp.array([3, 3, 3, 0])), # 偏好重复而不是0 + (jnp.array([5, 6, 7, 8]), jnp.array([5, 6, 7, 1])), # 偏好8而不是1 +] + +grad_fn = jax.jit(jax.grad(dpo_loss)) +lr = 0.05 + +print("训练 DPO...") +for epoch in range(100): + total_loss = 0.0 + for preferred, dispreferred in preferences: + loss = dpo_loss(theta, theta_ref, preferred, dispreferred) + grads = grad_fn(theta, theta_ref, preferred, dispreferred) + theta = theta - lr * grads + total_loss += loss + if (epoch + 1) % 20 == 0: + avg_loss = total_loss / len(preferences) + print(f" Epoch {epoch+1}: avg DPO loss = {avg_loss:.4f}") + +# 检查:模型现在应该偏好偏好的完成序列 +print("\nDPO训练后的偏好检查:") +for preferred, dispreferred in preferences: + lp_w = log_prob_sequence(theta, preferred) + lp_l = log_prob_sequence(theta, dispreferred) + print(f" Preferred {list(preferred.astype(int))}: logP={lp_w:.3f} " + f"Dispreferred {list(dispreferred.astype(int))}: logP={lp_l:.3f} " + f"{'correct' if lp_w > lp_l else 'WRONG'}") +``` diff --git a/chapter 08: computer vision/01. image fundamentals.md b/chapter 08: computer vision/01. image fundamentals.md new file mode 100644 index 0000000..1911af1 --- /dev/null +++ b/chapter 08: computer vision/01. image fundamentals.md @@ -0,0 +1,363 @@ +# 图像基础 + +*图像基础解释数字图像在被任何模型处理之前如何表示、形成和预处理。本文涵盖像素、色彩空间(RGB、HSV、YCbCr、LAB)、针孔相机模型、卷积、边缘检测(Sobel、Canny)、直方图以及特征描述子(SIFT、ORB),是底层视觉的工具包。* + +- **数字图像**是一个二维数字网格。网格中的每个单元格是一个**像素**(图像元素),其值表示强度或颜色。灰度图像是一个单一的二维矩阵,其中每个像素包含一个亮度值,对于 8 位图像,通常范围从 0(黑色)到 255(白色)。 + +- 彩色图像将此扩展到三个通道。在 **RGB** 色彩空间中,每个像素存储三个值:红色、绿色和蓝色的强度。 + +- 彩色图像是一个形状为 (高度, 宽度, 3) 的三维张量(矩阵)。以不同强度混合这三个通道可以产生完整的可见光谱。 + +![彩色图像分解为红、绿、蓝三个通道,每个通道显示为灰度强度图](../images/rgb_channels.svg) + +- **位深度**决定每个通道可以表示的离散强度级别数量。 + +- 8 位图像每个通道有 $2^8 = 256$ 个级别,总共 $256^3 \approx 1670$ 万种可能的颜色。16 位图像每个通道有 65,536 个级别,用于医学成像和高动态范围摄影等对精细强度差异敏感的场景。 + +- RGB 便于显示,但其他色彩空间更适合不同的任务。 + +- **HSV**(色调、饱和度、明度)将颜色信息与亮度分离。色调是纯色(在色环上 0-360 度),饱和度是颜色的鲜艳程度(0 = 灰色,1 = 纯色),明度是亮度。HSV 适合基于颜色的分割,因为你可以仅根据色调设定阈值,而无需考虑光照条件。在 HSV 中检测"红色物体"比在 RGB 中容易得多。 + +- **YCbCr** 将亮度(Y,感知亮度)与色度(Cb、Cr,颜色差异信号)分离。这是 JPEG 压缩和视频编解码器中使用的色彩空间。人眼对亮度比对颜色更敏感,因此色度可以以较低分辨率存储(色度子采样)而几乎不产生感知损失。 + +- **LAB**(CIELAB)的设计目标是使两种颜色之间的数值距离对应于感知差异。在 LAB 空间中相等的步长对人眼观察者来说看起来也是相等的。L 通道是明度,A 从绿色到红色,B 从蓝色到黄色。当需要感知均匀的颜色比较时,使用 LAB。 + +- **图像形成**描述三维场景如何变成二维图像。最简单的模型是**针孔相机**:来自场景的光线通过一个小孔投射到其后的传感器平面上。世界坐标系中的点 $(X, Y, Z)$ 投影到像素坐标 $(u, v)$: + +```math +\begin{bmatrix} u \\ v \\ 1 \end{bmatrix} = \frac{1}{Z} \begin{bmatrix} f_x & 0 & c_x \\ 0 & f_y & c_y \\ 0 & 0 & 1 \end{bmatrix} \begin{bmatrix} X \\ Y \\ Z \end{bmatrix} +``` + +- 这个 3x3 矩阵是**内参矩阵** $K$。它编码了相机的内部属性:焦距 $f_x, f_y$(透镜会聚光线的强度)和主点 $(c_x, c_y)$(光轴与传感器的交点,通常靠近图像中心)。对于给定的相机和镜头组合,这些参数是固定的。 + +![针孔相机模型:三维点通过光学中心投影到图像平面上,标注了焦距和主点](../images/pinhole_camera.svg) + +- **外参**描述相机在世界中的位置:一个旋转矩阵 $R$(3x3,来自第 02 章)和一个平移向量 $t$(3x1)。它们共同将世界坐标转换为相机坐标。完整的投影是: + +$$\mathbf{p} = K [R \mid t] \mathbf{P}$$ + +- 其中 $\mathbf{P} = [X, Y, Z, 1]^T$ 是齐次坐标下的三维点,$\mathbf{p} = [u, v, 1]^T$ 是投影后的像素。$[R \mid t]$ 矩阵是 3x4,将旋转和平移并排放置。这全是第 02 章中的线性代数。 + +- 真实镜头会引入**畸变**。 + + - **径向畸变**使直线弯曲成曲线(桶形畸变使图像向外凸出;枕形畸变使其向内收缩)。 + **切向畸变**源于镜头未与传感器完全平行。 + +- 相机标定通过拍摄已知图案(如棋盘格)的图像来估计内参和畸变系数,然后校正(去畸变)图像。 + +- **空间滤波**是经典图像处理的基础。一个**滤波器**(或卷积核)是一个小矩阵(通常为 3x3 或 5x5),它在图像上滑动。在每个位置,滤波器的值与重叠的图像块逐元素相乘并求和,产生一个输出像素。这就是**二维卷积**,与驱动 CNN(文件 02)的运算相同,但这里的滤波器权重是手工设计而非学习得到的。 + +$$(\text{图像} * K)[i,j] = \sum_{m} \sum_{n} \text{图像}[i+m, j+n] \cdot K[m, n]$$ + +- 这是第 06 章中一维卷积的二维扩展。滤波器决定了该运算检测的内容:不同的滤波器检测不同的特征。 + +- **模糊**通过对相邻像素取平均来平滑图像。**盒式滤波器**对所有相邻像素赋予相同的权重。 + +- **高斯滤波器**通过二维高斯函数(第 05 章)对相邻像素加权,给相邻像素更大的权重,给远处的像素更小的权重。高斯模糊是最常见的平滑操作,由 $\sigma$ 参数化:$\sigma$ 越大,平滑程度越高。 + +- **中值滤波**用邻域的中值代替每个像素,而非加权平均。它在去除椒盐噪声(随机的黑白像素)方面特别有效,同时保留边缘,因为中值对异常值具有鲁棒性(如第 04 章所讨论的)。 + +- **边缘检测**识别像素强度急剧变化的边界。边缘承载了图像中的大部分结构信息;仅凭边缘就可以识别物体。 + +- **Sobel 算子**使用两个 3x3 滤波器来估计水平方向和垂直方向的梯度: + +```math +G_x = \begin{bmatrix} -1 & 0 & 1 \\ -2 & 0 & 2 \\ -1 & 0 & 1 \end{bmatrix}, \quad G_y = \begin{bmatrix} -1 & -2 & -1 \\ 0 & 0 & 0 \\ 1 & 2 & 1 \end{bmatrix} +``` + +- 将图像与 $G_x$ 卷积得到水平梯度(对垂直边缘响应强烈),与 $G_y$ 卷积得到垂直梯度(对水平边缘响应强烈)。 + +- 梯度幅值 $\sqrt{G_x^2 + G_y^2}$ 和方向 $\arctan(G_y / G_x)$ 共同描述每个像素处的边缘强度和方向。这是第 03 章中梯度在图像域的对应概念。 + +![原始图像、Sobel 水平梯度、Sobel 垂直梯度和组合边缘幅值](../images/sobel_edges.svg) + +- **Canny 边缘检测器**是边缘检测的黄金标准。它包含四个步骤: + 1. 使用高斯滤波器平滑图像以减少噪声 + 2. 计算梯度幅值和方向(使用 Sobel) + 3. **非极大值抑制**:仅保留沿梯度方向为局部最大值的像素,细化边缘 + 4. **滞后阈值处理**:使用两个阈值(高阈值和低阈值)。高于高阈值的像素是确定边缘。介于两个阈值之间的像素仅当连接到确定边缘时才被视为边缘。低于低阈值的像素被舍弃。 + +- Canny 中的双阈值使其比单阈值更鲁棒:强边缘始终被保留,弱边缘仅当属于连续边缘结构时才被保留。 + +- **频域**分析揭示了在空间域难以看到的模式。**二维傅里叶变换**(扩展自第 03 章的一维版本)将图像分解为不同频率和方向的正弦模式之和: + +$$F(u, v) = \sum_{x=0}^{M-1} \sum_{y=0}^{N-1} f(x, y) \cdot e^{-j2\pi(ux/M + vy/N)}$$ + +- 低频对应平滑、缓慢变化的区域(天空、墙壁)。高频对应锐利变化(边缘、纹理、噪声)。**幅度谱**显示每个频率上存在多少能量,**相位谱**编码了空间排列信息。 + +- **低通滤波**去除高频,从而平滑图像(相当于空间域的高斯模糊)。**高通滤波**去除低频,从而强调边缘和细节。**带通滤波**只保留一定范围的频率,用于纹理分析。 + +- 在实践中,对于大尺寸滤波器,频域滤波可能比空间卷积更快,因为空间域中的卷积等价于频域中的逐元素乘法(**卷积定理**)。这直接联系到第 03 章中的傅里叶变换性质。 + +- **直方图**总结像素强度的分布。直方图统计每个强度值有多少像素(对于 8 位图像为 0-255)。这是第 04 章中的频率分布应用于像素值。 + +![图像及其强度直方图:暗图像的直方图偏左,亮图像的直方图偏右](../images/image_histogram.svg) + +- 暗图像的直方图集中在左侧(低值)。亮图像的直方图集中在右侧。低对比度图像的直方图狭窄。高对比度图像的直方图宽而分散。 + +- **直方图均衡化**将直方图拉伸以覆盖整个强度范围,从而改善对比度。其思路是找到一个映射,使像素强度的累积分布函数(CDF)近似为线性。这是第 04 章中 CDF 概念的直接应用。 + +- **Otsu 方法**自动找到将图像分割为前景和背景的最佳阈值。它尝试每个可能的阈值,并选择使类内方差最小(或等价地,使类间方差最大)的阈值。这是第 04 章中方差概念应用于像素强度群体的体现。 + +- **特征提取**识别图像中可用于匹配、识别和三维重建的独特点或区域。好的特征应具有可重复性(在不同视角下能被再次找到)、独特性(可与其他特征区分)和计算高效性。 + +- **角点检测**寻找图像强度在多个方向上显著变化的点。平滑区域在任何方向上的变化都很小。边缘在一个方向上有变化。角点在至少两个方向上都有变化,使其在局部是唯一的,因此是可靠的标志点。 + +- **Harris 角点检测器**分析每个像素处的**结构张量**(也称为二阶矩矩阵): + +```math +M = \sum_{(x,y) \in W} w(x,y) \begin{bmatrix} I_x^2 & I_x I_y \\ I_x I_y & I_y^2 \end{bmatrix} +``` + +- 其中 $I_x$ 和 $I_y$ 是图像梯度(使用 Sobel 计算),$W$ 是局部窗口,$w$ 是高斯加权函数。$M$ 的特征值(来自第 02 章)告诉你特征的类型: + - 两个特征值都很小:平坦区域(无特征) + - 一个很大,一个很小:边缘 + - 两个都很大:角点 + +- Harris 不显式计算特征值,而是使用角点响应函数:$R = \det(M) - k \cdot (\text{tr}(M))^2$,其中 $\det(M) = \lambda_1 \lambda_2$ 且 $\text{tr}(M) = \lambda_1 + \lambda_2$(均来自第 02 章)。$R$ 为正且较大时表示角点。常数 $k$ 通常为 0.04-0.06。 + +- **Shi-Tomasi** 检测器将其简化为 $R = \min(\lambda_1, \lambda_2)$,直接检查较小的特征值是否足够大。这在实际中稍微更稳定。 + +- **斑点检测**寻找与周围环境不同的区域。与角点(属于点特征)不同,斑点具有特征尺寸。 + +- **SIFT**(尺度不变特征变换,Lowe,2004)在多个尺度上检测斑点,并构建对旋转、尺度具有不变性,对光照变化具有部分不变性的描述子。它的工作原理是: + 1. 使用逐渐增大 $\sigma$ 的高斯模糊构建**尺度空间**(见下文) + 2. 在尺度间的 Gaussian 差分(DoG)中寻找极值点 + 3. 精炼关键点位置,去除低对比度点和边缘响应 + 4. 基于局部梯度方向分配主方向 + 5. 从关键点周围 16x16 块中的梯度直方图构建 128 维描述子 + +- **SURF**(加速稳健特征)使用盒式滤波器和积分图像近似 SIFT 以实现更快的计算。**ORB**(定向 FAST 和旋转 BRIEF)是一个快速、开源的替代方案,它将 FAST 角点检测器与 BRIEF 二进制描述子结合,并增加了旋转不变性。 + +- **HOG**(方向梯度直方图)描述子将图像划分为小单元格,计算每个单元格内梯度方向的直方图,并在单元格块间进行归一化。HOG 捕捉边缘方向的分布,这对物体形状具有高度信息量。在深度学习之前,HOG + SVM(第 06 章)是行人检测和物体识别的主流方法。 + +- **图像金字塔**以多种分辨率表示图像。 + - **高斯金字塔**通过重复模糊和下采样(分辨率减半)构建。每一层都是原始图像的粗略版本。 + - **拉普拉斯金字塔**存储连续高斯层之间的差异,捕捉每一步下采样丢失的细节。拉普拉斯金字塔是可逆的:你可以从中重建原始图像。 + +![高斯金字塔:原始图像为全分辨率,然后每层逐步缩小为一半分辨率](../images/image_pyramid.svg) + +- **尺度空间**形式化了物体存在于不同尺度这一概念。一棵树是一个大斑点;树上的一片叶子是一个小斑点。要同时检测两者,你需要跨尺度搜索。图像的尺度空间是通过将图像与逐渐增大 $\sigma$ 的高斯函数卷积得到的图像族: + +$$L(x, y, \sigma) = G(x, y, \sigma) * I(x, y)$$ + +- 其中 $G$ 是标准差为 $\sigma$ 的二维高斯函数。跨多个尺度持续存在的特征更有可能是有意义的结构而非噪声。尺度空间是 SIFT 的理论基础,也是贯穿现代计算机视觉的多尺度处理的基础,包括目标检测中的特征金字塔网络(文件 03)。 + +## 编码任务(使用 CoLab 或 notebook) + +1. 加载图像,将其转换为不同的色彩空间(RGB、HSV、LAB),并可视化各个通道。观察颜色信息在不同空间中的分布差异。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt +from PIL import Image +import numpy as np + +# Create a synthetic test image with distinct colours +H, W = 128, 256 +img = np.zeros((H, W, 3), dtype=np.uint8) +img[:, :64] = [255, 50, 50] # red +img[:, 64:128] = [50, 255, 50] # green +img[:, 128:192] = [50, 50, 255] # blue +img[:, 192:] = [255, 255, 50] # yellow + +# Add a brightness gradient +for y in range(H): + scale = 0.3 + 0.7 * y / H + img[y] = (img[y] * scale).astype(np.uint8) + +img_jnp = jnp.array(img, dtype=jnp.float32) / 255.0 + +# Manual RGB to HSV conversion +def rgb_to_hsv(rgb): + r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2] + maxc = jnp.max(rgb, axis=-1) + minc = jnp.min(rgb, axis=-1) + diff = maxc - minc + 1e-7 + + # Hue + h = jnp.where(maxc == minc, 0.0, + jnp.where(maxc == r, 60 * ((g - b) / diff % 6), + jnp.where(maxc == g, 60 * ((b - r) / diff + 2), + 60 * ((r - g) / diff + 4)))) + s = jnp.where(maxc < 1e-7, 0.0, diff / maxc) + v = maxc + return jnp.stack([h / 360, s, v], axis=-1) + +hsv = rgb_to_hsv(img_jnp) + +fig, axes = plt.subplots(2, 3, figsize=(14, 8)) +for i, (ch, name) in enumerate(zip([img_jnp[...,0], img_jnp[...,1], img_jnp[...,2]], + ['Red', 'Green', 'Blue'])): + axes[0, i].imshow(ch, cmap='gray', vmin=0, vmax=1) + axes[0, i].set_title(f'RGB: {name}'); axes[0, i].axis('off') + +for i, (ch, name) in enumerate(zip([hsv[...,0], hsv[...,1], hsv[...,2]], + ['Hue', 'Saturation', 'Value'])): + axes[1, i].imshow(ch, cmap='gray', vmin=0, vmax=1) + axes[1, i].set_title(f'HSV: {name}'); axes[1, i].axis('off') + +plt.suptitle('RGB vs HSV Channels') +plt.tight_layout(); plt.show() +``` + +2. 使用二维卷积从头实现 Sobel 边缘检测和高斯模糊。将其应用于图像并比较结果。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def conv2d(image, kernel): + """2D convolution (valid mode) from scratch.""" + H, W = image.shape + kH, kW = kernel.shape + out_h, out_w = H - kH + 1, W - kW + 1 + output = jnp.zeros((out_h, out_w)) + for i in range(out_h): + for j in range(out_w): + patch = image[i:i+kH, j:j+kW] + output = output.at[i, j].set(jnp.sum(patch * kernel)) + return output + +# Create a test image: white rectangle on dark background +img = jnp.zeros((64, 64)) +img = img.at[15:50, 20:45].set(1.0) +# Add some noise +key = jax.random.PRNGKey(42) +img = img + jax.random.normal(key, img.shape) * 0.05 + +# Sobel filters +sobel_x = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=jnp.float32) +sobel_y = jnp.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=jnp.float32) + +# Gaussian blur kernel (5x5, sigma=1) +ax = jnp.arange(-2, 3, dtype=jnp.float32) +xx, yy = jnp.meshgrid(ax, ax) +gaussian = jnp.exp(-(xx**2 + yy**2) / (2 * 1.0**2)) +gaussian = gaussian / gaussian.sum() + +# Apply filters +gx = conv2d(img, sobel_x) +gy = conv2d(img, sobel_y) +edges = jnp.sqrt(gx**2 + gy**2) +blurred = conv2d(img, gaussian) + +fig, axes = plt.subplots(1, 4, figsize=(16, 4)) +for ax, data, title in zip(axes, + [img, edges, blurred, gx], + ['Original', 'Edge Magnitude', 'Gaussian Blur', 'Horizontal Gradient']): + ax.imshow(data, cmap='gray') + ax.set_title(title); ax.axis('off') +plt.tight_layout(); plt.show() +``` + +3. 从头实现直方图均衡化,并将其应用于低对比度灰度图像。比较均衡前后的直方图。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# Create a low-contrast image (values clustered in a narrow range) +key = __import__('jax').random.PRNGKey(42) +img = __import__('jax').random.uniform(key, (128, 128)) * 0.3 + 0.3 # values in [0.3, 0.6] + +def histogram_equalise(img, n_bins=256): + """Histogram equalisation for a grayscale image.""" + # Quantise to bins + bins = jnp.linspace(0, 1, n_bins + 1) + hist = jnp.histogram(img, bins=bins)[0] + + # Compute CDF + cdf = jnp.cumsum(hist) + cdf_normalised = (cdf - cdf.min()) / (cdf.max() - cdf.min()) + + # Map each pixel through the CDF + indices = jnp.clip((img * n_bins).astype(jnp.int32), 0, n_bins - 1) + equalised = cdf_normalised[indices] + return equalised + +eq_img = histogram_equalise(img) + +fig, axes = plt.subplots(2, 2, figsize=(12, 10)) +axes[0, 0].imshow(img, cmap='gray', vmin=0, vmax=1) +axes[0, 0].set_title('Original (Low Contrast)'); axes[0, 0].axis('off') +axes[0, 1].imshow(eq_img, cmap='gray', vmin=0, vmax=1) +axes[0, 1].set_title('After Histogram Equalisation'); axes[0, 1].axis('off') + +axes[1, 0].hist(img.ravel(), bins=64, color='#3498db', alpha=0.8) +axes[1, 0].set_title('Histogram Before'); axes[1, 0].set_xlim(0, 1) +axes[1, 1].hist(eq_img.ravel(), bins=64, color='#e74c3c', alpha=0.8) +axes[1, 1].set_title('Histogram After'); axes[1, 1].set_xlim(0, 1) + +plt.tight_layout(); plt.show() +``` + +4. 从头实现 Harris 角点检测器。在简单图像中检测角点并可视化。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def harris_corners(img, k=0.05, threshold=0.01): + """Harris corner detection from scratch.""" + # Compute gradients with Sobel + sobel_x = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=jnp.float32) + sobel_y = jnp.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=jnp.float32) + + # Pad image for valid convolution to preserve size + img_pad = jnp.pad(img, 1, mode='edge') + H, W = img.shape + + Ix = jnp.zeros_like(img) + Iy = jnp.zeros_like(img) + for i in range(H): + for j in range(W): + patch = img_pad[i:i+3, j:j+3] + Ix = Ix.at[i, j].set(jnp.sum(patch * sobel_x)) + Iy = Iy.at[i, j].set(jnp.sum(patch * sobel_y)) + + # Structure tensor components + Ixx = Ix * Ix + Iyy = Iy * Iy + Ixy = Ix * Iy + + # Gaussian smoothing of structure tensor (approximate with window sum) + w = 3 # window half-size + R = jnp.zeros_like(img) + pad_xx = jnp.pad(Ixx, w, mode='constant') + pad_yy = jnp.pad(Iyy, w, mode='constant') + pad_xy = jnp.pad(Ixy, w, mode='constant') + + for i in range(H): + for j in range(W): + sxx = jnp.sum(pad_xx[i:i+2*w+1, j:j+2*w+1]) + syy = jnp.sum(pad_yy[i:i+2*w+1, j:j+2*w+1]) + sxy = jnp.sum(pad_xy[i:i+2*w+1, j:j+2*w+1]) + det = sxx * syy - sxy * sxy + trace = sxx + syy + R = R.at[i, j].set(det - k * trace * trace) + + # Threshold + corners = R > threshold * R.max() + return R, corners + +# Test image: checkerboard pattern (lots of corners) +block = 16 +n = 4 +checker = jnp.zeros((block * n, block * n)) +for i in range(n): + for j in range(n): + if (i + j) % 2 == 0: + checker = checker.at[i*block:(i+1)*block, j*block:(j+1)*block].set(1.0) + +R, corners = harris_corners(checker) +cy, cx = jnp.where(corners) + +fig, axes = plt.subplots(1, 3, figsize=(14, 4)) +axes[0].imshow(checker, cmap='gray') +axes[0].set_title('Checkerboard'); axes[0].axis('off') +axes[1].imshow(R, cmap='hot') +axes[1].set_title('Harris Response'); axes[1].axis('off') +axes[2].imshow(checker, cmap='gray') +axes[2].scatter(cx, cy, c='#e74c3c', s=15, marker='x') +axes[2].set_title(f'Detected Corners ({len(cx)})'); axes[2].axis('off') +plt.tight_layout(); plt.show() +``` diff --git a/chapter 08: computer vision/02. convolutional networks.md b/chapter 08: computer vision/02. convolutional networks.md new file mode 100644 index 0000000..a533bf5 --- /dev/null +++ b/chapter 08: computer vision/02. convolutional networks.md @@ -0,0 +1,382 @@ +# 卷积网络 + +*卷积神经网络直接从像素数据中学习空间特征层级,用梯度优化的滤波器取代人工设计的滤波器。本文涵盖卷积机制、池化、步长、空洞卷积、感受野,以及定义了图像分类的标志性架构(LeNet、AlexNet、VGG、ResNet、Inception、EfficientNet)。* + +- 在文件 01 中,我们手工设计了用于边缘检测、模糊和角点检测的滤波器。一个自然而然的问题是:我们能否从数据中学习最优的滤波器?这正是卷积神经网络(CNN)所做的。 + +- CNN 不是手动选择滤波器权重,而是通过梯度下降(第 06 章)学习它们,发现对当前任务直接有用的特征。 + +- 在第 06 章中,我们介绍了卷积操作、CNN 基础以及滤波器学习的思想。在这里,我们深入探讨使 CNN 在十多年来成为计算机视觉主导范式的架构创新。 + +- 回顾核心的**卷积操作**:一个大小为 $k \times k$ 的滤波器 $K$ 在输入特征图上滑动,在每个位置计算点积(第 06 章)。输出大小由三个超参数控制: + + - **步长**:滤波器在位置之间移动的像素数。步长 1 意味着滤波器每次移动一个像素。步长 2 意味着每次移动两个像素,空间维度减半。步长卷积是下采样时池化的一种替代方案。 + - **填充**:在输入边界周围添加零。"Same"填充($p = \lfloor k/2 \rfloor$)保持空间维度不变。"Valid"填充($p = 0$)会减小空间维度。 + - **空洞卷积**:在滤波器元素之间插入间隙。一个 3x3 的滤波器以空洞率 2 工作,仅用 9 个参数就覆盖了 5x5 的感受野。空洞卷积扩大了感受野而不增加计算量。 + +- 卷积后的输出空间大小: + +$$\text{out} = \left\lfloor \frac{\text{in} - k + 2p}{s} \right\rfloor + 1$$ + +- 其中 $\text{in}$ 是输入大小,$k$ 是卷积核大小,$p$ 是填充,$s$ 是步长。该公式独立地适用于高度和宽度。 + +- **感受野**是指能够影响某个神经元值的原始输入区域。 + - 早期层的感受野较小(它们看到的是边缘等局部模式)。 + - 更深层的感受野较大(它们看到的是物体部件等更大的结构)。 + +- 感受野随着每一层增长:大致每层卷积增加 $k - 1$ 个像素(加入步长或空洞卷积时增长更多)。 + +![感受野逐层增长:第 1 层神经元看到 3x3 的补丁,第 2 层神经元看到 5x5 的补丁,第 3 层神经元看到原始输入中 7x7 的补丁](../images/receptive_field.svg) + +- **池化**层在保留最重要信息的同时降低空间维度。 + - **最大池化**取每个窗口中的最大值,保留最强的激活(最突出的特征)。 + - **平均池化**取均值,平滑特征图。一个 2x2 的池化窗口配合步长 2 会使两个空间维度都减半。 + +- **全局平均池化(GAP)** 将每个通道的整个空间范围平均为单个数值,生成一个长度等于通道数的向量。GAP 取代了许多现代架构末尾的全连接层,大幅减少了参数量,并起到了结构正则化的作用。 + +- **批归一化(BatchNorm)** 将每个小批量内的激活值归一化为零均值和单位方差,然后应用可学习的缩放和平移(第 06 章)。在 CNN 中,批归一化按通道应用:统计量在跨批次和空间维度上为每个通道独立计算。它稳定了训练,允许使用更高的学习率,并起到轻度正则化的作用。 + +- **丢弃法**(第 06 章)在训练期间随机将神经元置零。 + +- 在 CNN 中,**空间丢弃法(Dropout2D)** 丢弃整个特征图通道而非单个像素,这更为有效,因为特征图中相邻像素高度相关。 + +- **数据增广**通过在训练期间对每张图像应用随机变换来人为地扩展训练集:水平翻转、随机裁剪、旋转、颜色抖动(调整亮度、对比度、饱和度、色调)以及 cutout(遮挡随机矩形区域)。网络以多种不同形式看到每张图像,迫使其学习变换不变的特征,而非记忆特定的像素模式。 + +- 高级增广策略包括 **Mixup**(混合两张图像及其标签:$\tilde{x} = \lambda x_i + (1-\lambda) x_j$,$\tilde{y} = \lambda y_i + (1-\lambda) y_j$)、**CutMix**(将一张图像的矩形区域粘贴到另一张图像上,并按面积比例混合标签)以及 **RandAugment**(从一个固定集合中随机采样一系列增广操作,使用单一的强度参数)。 + +- CNN 架构的历史是一个逐步走向更深、更高效设计的故事,每一步都解决了限制前代架构的问题。 + +- **LeNet-5**(LeCun 等人,1998 年)是最早的 CNN,专为手写数字识别设计。两个卷积层后接三个全连接层,使用平均池化和 tanh 激活函数。它证明了学习到的滤波器优于手工设计的特征,但按现代标准来看很小(6 万个参数)。 + +- **AlexNet**(Krizhevsky 等人,2012 年)以巨大优势赢得了 ImageNet 竞赛,引发了深度学习革命。关键创新:ReLU 激活函数(取代了存在梯度消失问题的 tanh)、用于正则化的丢弃法、数据增广以及在 GPU 上训练。五个卷积层,三个全连接层,6000 万个参数。 + +- **VGG**(Simonyan 和 Zisserman,2014 年)证明,仅使用 3x3 滤波器并深层堆叠效果优于更大的滤波器。两个堆叠的 3x3 滤波器具有与一个 5x5 滤波器相同的感受野,但参数更少($2 \times 3^2 = 18$ 对比 $5^2 = 25$)且多了一个非线性层。VGG-16(16 层)和 VGG-19(19 层)至今仍被广泛用作特征提取器。架构非常简单:卷积块通道数递增(64、128、256、512),每个块后接最大池化。 + +![VGG 架构:堆叠的 3x3 卷积块,通道深度递增(64→128→256→512),块之间是最大池化,末端是全连接层](../images/vgg_architecture.svg) + +- **GoogLeNet/Inception**(Szegedy 等人,2014 年)引入了 **Inception 模块**:不是选择单一的滤波器大小,而是并行使用 1x1、3x3 和 5x5 卷积,将它们的输出拼接起来,让网络决定哪个尺度最有用。1x1 卷积在较大滤波器之前用作瓶颈以减少计算量。GoogLeNet 以比 VGG 少 12 倍的参数(680 万对比 1.38 亿)实现了更高的准确率。 + +![Inception 模块:四个并行分支(1×1、3×3、5×5 和池化),带 1×1 瓶颈,沿通道维度拼接](../images/inception_module.svg) + +- Inception 模块同时捕获多个尺度的特征。1x1 滤波器捕获逐点模式,3x3 捕获局部纹理,5x5 捕获更大的结构。拼接将所有视角组合成丰富的表示。 + +- **ResNet**(He 等人,2016 年)解决了**退化问题**:更深的网络表现反而不如较浅的网络,这不是因为过拟合,而是因为更深的网络更难优化。解决方案是**跳跃连接**(残差连接): + +$$\text{output} = F(x) + x$$ + +- 该层学习残差 $F(x) = \text{output} - x$。如果最优变换接近恒等映射(这在深层网络中很常见),学习一个接近零的残差比学习完整的映射要容易得多。跳跃连接还提供了直接的梯度通道,减少了梯度消失问题。ResNet 训练了 152 层的网络,远超此前任何架构。 + +![ResNet 块:输入 x 经过两个卷积层得到 F(x),然后跳跃连接将 x 加回,得到输出 F(x) + x](../images/resnet_block.svg) + +- 当输入和输出维度不同时(由于步长或通道数变化),**投影捷径**会应用一个 1x1 卷积来匹配 $x$ 的维度:$\text{output} = F(x) + W_s x$。 + +- **瓶颈块**(用于 ResNet-50 及更深版本)使用三个卷积:1x1 降通道,3x3 进行空间处理,1x1 再将通道数恢复。这比两个 3x3 卷积计算量更小,允许构建更深的网络。 + +- **DenseNet**(Huang 等人,2017 年)将跳跃连接的思想进一步推进:在一个密集块内,每一层都与所有后续层相连。第 $l$ 层接收前面所有层的特征图作为输入:$x_l = H_l([x_0, x_1, \ldots, x_{l-1}])$,其中 $[\cdot]$ 表示沿通道维度的拼接。这促进了特征复用,增强了梯度流动,并减少了总参数量。 + +![DenseNet 密集块:每一层通过拼接接收前面所有层的特征图,形成密集连接以实现最大程度的特征复用](../images/densenet_block.svg) + +- **高效架构**面向移动设备和边缘硬件上的部署,这些场景下计算、内存和能耗都受到限制。 + +- **MobileNet**(Howard 等人,2017 年)用**深度可分离卷积**取代了标准卷积,将操作分解为两个步骤: + 1. **深度卷积**:每个输入通道应用一个独立的 $k \times k$ 滤波器(不跨通道交互) + 2. **逐点卷积**:应用 1x1 卷积来组合跨通道的信息 + +- 一个标准 $k \times k$ 卷积,输入通道数为 $C_{\text{in}}$,输出通道数为 $C_{\text{out}}$,每个空间位置需要 $k^2 \cdot C_{\text{in}} \cdot C_{\text{out}}$ 次乘法。深度可分离卷积需要 $k^2 \cdot C_{\text{in}} + C_{\text{in}} \cdot C_{\text{out}}$ 次,减少了大约 $k^2$ 倍。对于 3x3 滤波器,这大约便宜 9 倍。 + +![深度可分离卷积:深度步骤对每个通道应用一个 k×k 滤波器,然后逐点 1×1 卷积混合通道——输出形状相同,操作量约减少 9×](../images/depthwise_separable_conv.svg) + +- **MobileNet-V2** 引入了**逆残差块**:先用 1x1 卷积扩展通道,在扩展空间中应用深度卷积,再用 1x1 卷积投影回低维。跳跃连接放置在窄(瓶颈)层上,与 ResNet 的模式相反。扩展率通常为 6。 + +- **EfficientNet**(Tan 和 Le,2019 年)引入了**复合缩放**:不是独立地仅缩放深度、或仅缩放宽度、或仅缩放分辨率,而是使用固定比例同时缩放所有三个维度。给定缩放系数 $\phi$: + +$$\text{depth}: d = \alpha^\phi, \quad \text{width}: w = \beta^\phi, \quad \text{resolution}: r = \gamma^\phi$$ + +- 约束条件为 $\alpha \cdot \beta^2 \cdot \gamma^2 \approx 2$(这样 $\phi$ 每增加一个单位,总计算量大约翻倍)。通过网格搜索得到基线比例 $\alpha = 1.2$,$\beta = 1.1$,$\gamma = 1.15$。EfficientNet-B0 到 B7 逐步放大,以远少于之前模型的参数和 FLOPs 达到了最先进的准确率。 + +![EfficientNet 复合缩放:单独缩放宽度、深度或分辨率,与使用单一系数 φ 同时缩放三者](../images/efficientnet_scaling.svg) + +- **ShuffleNet** 通过使用**分组卷积**后接**通道混洗**来降低 1x1 卷积(在 MobileNet 风格的架构中占主导)的成本。分组卷积将通道分成多个组,在每个组内独立进行卷积,但这阻止了跨组的信息流动。混洗操作在组之间重新排列通道,以可忽略不计的成本恢复了信息混合。 + +- **迁移学习**是将在一个任务上训练好的模型适配到不同任务的实践。在计算机视觉中,这几乎总是意味着从一个在 ImageNet(140 万张图像,1000 个类别)上预训练的模型开始,适配到特定领域的数据集(医学图像、卫星图像、制造缺陷检测)。 + +- **特征提取**:冻结所有卷积层,移除最终的分类头,仅在上面训练一个新的分类头。冻结的层充当通用特征提取器。当目标域与 ImageNet 相似且目标数据集较小时,这种方法效果很好。 + +- **微调**:解冻部分或全部卷积层,以较小的学习率进行训练。预训练的权重作为起点而非固定特征。微调通常先解冻后面的层(这些层捕获高级的、任务特定的特征),再根据需要解冻更早的层。 + +- 迁移学习之所以有效,是因为 CNN 的早期层学习通用特征(边缘、纹理、颜色),这些特征对各种任务都有用,而后面层学习任务特定的特征。一个用于分类动物的网络,其边缘检测器对分类建筑物仍然有用。 + +- **可视化 CNN** 可以揭示网络学到了什么,并帮助调试意外行为。 + +- **激活图**(特征图)展示了给定输入图像下每个滤波器的输出。早期层的激活图看起来像边缘图;更深层的激活图则越来越抽象,空间上越来越粗糙。 + +- **Grad-CAM**(梯度加权类别激活映射,Selvaraju 等人,2017 年)高亮了输入图像中对模型预测最重要的区域。其工作原理是: + 1. 计算目标类别分数相对于最后一个卷积层特征图的梯度(使用第 03 章的链式法则) + 2. 对这些梯度进行全局平均池化,得到每个通道的重要性权重 + 3. 计算特征图的加权组合并应用 ReLU + +$$L_{\text{Grad-CAM}} = \text{ReLU}\!\left(\sum_k \alpha_k A^k\right), \quad \alpha_k = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A^k_{ij}}$$ + +- 其中 $A^k$ 是第 $k$ 个特征图,$\alpha_k$ 是通道 $k$ 的重要性权重,$y^c$ 是类别 $c$ 的分数。结果是一个粗糙的热力图,显示哪些区域驱动了分类。应用 ReLU 是因为我们只对具有正影响分类的特征感兴趣。 + +![Grad-CAM:一张狗的输入图像,最后一个卷积层的特征图,梯度加权组合,以及叠加在原始图像上的热力图,高亮了狗的脸部](../images/grad_cam.svg) + +- **特征反演**通过优化一张随机图像使其匹配目标特征(对像素值进行梯度下降),从特征表示中重建输入图像。这揭示了网络在各层保留了哪些信息。浅层几乎能完美重建图像;深层产生的图像可识别但有所扭曲,这表明精细的空间细节丢失了,而语义内容得以保留。 + +- **Deep Dream** 和**神经风格迁移**是特征可视化的创意应用。Deep Dream 最大化选定层中神经元的激活,产生超现实的、放大模式的图像。神经风格迁移优化目标图像,使其同时匹配一张图像的内容特征(来自深层)和另一张图像的风格特征(滤波器激活的 Gram 矩阵,捕获纹理统计信息)。 + +## 编程任务(使用 CoLab 或 notebook) + +1. 用 JAX 从头实现一个简单的 CNN,包含两个卷积层、最大池化和一个分类头。在一个合成的二维模式分类任务上训练它。 +```python +import jax +import jax.numpy as jnp +import jax.lax as lax +import matplotlib.pyplot as plt + +def conv2d(x, kernel, stride=1): + """简单 2D 卷积,单输入,单滤波器。""" + return lax.conv(x[None, None], kernel[None, None], (stride, stride), 'SAME')[0, 0] + +def max_pool(x, size=2): + """2x2 最大池化。""" + H, W = x.shape + x = x[:H//size*size, :W//size*size] + return x.reshape(H//size, size, W//size, size).max(axis=(1, 3)) + +def init_cnn(key): + k1, k2, k3 = jax.random.split(key, 3) + return { + 'conv1': jax.random.normal(k1, (5, 5)) * 0.3, + 'conv2': jax.random.normal(k2, (3, 3)) * 0.3, + 'fc_w': jax.random.normal(k3, (64, 1)) * 0.1, + 'fc_b': jnp.zeros(1), + } + +def forward_cnn(params, img): + # Conv1 -> ReLU -> Pool + h = jnp.maximum(0, conv2d(img, params['conv1'])) + h = max_pool(h) + # Conv2 -> ReLU -> Pool + h = jnp.maximum(0, conv2d(h, params['conv2'])) + h = max_pool(h) + # Flatten and classify + flat = h.ravel() + # Pad or truncate to fixed size + flat = jnp.pad(flat, (0, max(0, 64 - len(flat))))[:64] + logit = (flat @ params['fc_w'] + params['fc_b']).squeeze() + return jax.nn.sigmoid(logit) + +# Generate synthetic data: class 0 = low-freq pattern, class 1 = high-freq +def make_data(key, n=200): + images, labels = [], [] + for i in range(n): + k1, key = jax.random.split(key) + x, y = jnp.meshgrid(jnp.linspace(0, 4*jnp.pi, 32), jnp.linspace(0, 4*jnp.pi, 32)) + if i < n // 2: + img = jnp.sin(x) + jax.random.normal(k1, (32, 32)) * 0.1 + labels.append(0) + else: + img = jnp.sin(4 * x) * jnp.sin(4 * y) + jax.random.normal(k1, (32, 32)) * 0.1 + labels.append(1) + images.append(img) + return images, jnp.array(labels, dtype=jnp.float32) + +key = jax.random.PRNGKey(42) +images, labels = make_data(key) +params = init_cnn(jax.random.PRNGKey(0)) + +def loss_fn(params, img, label): + pred = forward_cnn(params, img) + return -(label * jnp.log(pred + 1e-7) + (1 - label) * jnp.log(1 - pred + 1e-7)) + +grad_fn = jax.grad(loss_fn) +lr = 0.01 + +for epoch in range(5): + total_loss = 0.0 + for img, label in zip(images, labels): + grads = grad_fn(params, img, label) + params = {k: params[k] - lr * grads[k] for k in params} + total_loss += loss_fn(params, img, label) + print(f"Epoch {epoch}: loss = {total_loss / len(images):.4f}") + +# Test accuracy +preds = jnp.array([forward_cnn(params, img) > 0.5 for img in images]) +acc = jnp.mean(preds == labels) +print(f"Accuracy: {acc:.2%}") +``` + +2. 可视化不同滤波器大小如何影响感受野。展示两个堆叠的 3x3 滤波器与一个 5x5 滤波器覆盖相同的感受野,但参数更少。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def compute_receptive_field(layers): + """从一组 (kernel_size, stride) 元组计算感受野大小。""" + rf = 1 # 从 1 个像素开始 + stride_product = 1 + for k, s in layers: + rf += (k - 1) * stride_product + stride_product *= s + return rf + +# Compare architectures +configs = { + 'Single 5x5': [(5, 1)], + 'Two 3x3': [(3, 1), (3, 1)], + 'Three 3x3': [(3, 1), (3, 1), (3, 1)], + 'Single 7x7': [(7, 1)], + '3x3 stride 2 + 3x3': [(3, 2), (3, 1)], +} + +print(f"{'Config':<25} {'RF':>4} {'Params (per channel)':>20}") +print('-' * 55) +for name, layers in configs.items(): + rf = compute_receptive_field(layers) + # Parameters: sum of k^2 for each layer (per input-output channel pair) + params = sum(k * k for k, s in layers) + print(f"{name:<25} {rf:>4} {params:>20}") + +# Visualise receptive fields +fig, axes = plt.subplots(1, 3, figsize=(14, 4)) +for ax, (name, rf_size) in zip(axes, [('5x5 filter', 5), ('Two 3x3 filters', 5), ('Three 3x3 filters', 7)]): + grid = jnp.zeros((9, 9)) + c = 4 # centre + half = rf_size // 2 + grid = grid.at[c-half:c+half+1, c-half:c+half+1].set(1.0) + ax.imshow(grid, cmap='Blues', vmin=0, vmax=1) + ax.set_title(f'{name}\nRF = {rf_size}x{rf_size}') + ax.set_xticks(range(9)); ax.set_yticks(range(9)) + ax.grid(True, alpha=0.3) +plt.suptitle('Receptive Field Comparison') +plt.tight_layout(); plt.show() +``` + +3. 从头实现 Grad-CAM。给定一个预构建的简单 CNN,计算针对特定类别的梯度加权激活图,并将其可视化为热力图。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def simple_cnn(params, img): + """返回预测和最后一个卷积层激活的简单 CNN。""" + # Conv layer (our "last conv layer" for Grad-CAM) + H, W = img.shape + k = params['conv'].shape[0] + pad = k // 2 + img_pad = jnp.pad(img, pad, mode='edge') + activation_map = jnp.zeros((H, W)) + for i in range(H): + for j in range(W): + activation_map = activation_map.at[i, j].set( + jnp.sum(img_pad[i:i+k, j:j+k] * params['conv']) + ) + activation_map = jnp.maximum(0, activation_map) # ReLU + + # Global average pool -> dense -> output + pooled = activation_map.mean() + logit = pooled * params['w'] + params['b'] + return jax.nn.sigmoid(logit), activation_map + +# Create test image: bright region on the left (class indicator) +img = jnp.zeros((32, 32)) +img = img.at[8:24, 4:16].set(1.0) +img = img.at[5:10, 20:28].set(0.3) + +key = jax.random.PRNGKey(42) +params = { + 'conv': jax.random.normal(key, (5, 5)) * 0.3, + 'w': jnp.array(2.0), + 'b': jnp.array(-0.5), +} + +# Compute Grad-CAM +def class_score(params, img): + pred, _ = simple_cnn(params, img) + return pred + +# Get activation map and gradients +pred, act_map = simple_cnn(params, img) +grad_fn = jax.grad(lambda img: simple_cnn(params, img)[0]) +img_grad = grad_fn(img) + +# Weight = global average of gradients (simplified 1-channel Grad-CAM) +alpha = img_grad.mean() +grad_cam = jnp.maximum(0, alpha * act_map) # ReLU +grad_cam = (grad_cam - grad_cam.min()) / (grad_cam.max() - grad_cam.min() + 1e-8) + +fig, axes = plt.subplots(1, 3, figsize=(14, 4)) +axes[0].imshow(img, cmap='gray'); axes[0].set_title('Input Image'); axes[0].axis('off') +axes[1].imshow(act_map, cmap='viridis'); axes[1].set_title('Activation Map'); axes[1].axis('off') +axes[2].imshow(img, cmap='gray', alpha=0.6) +axes[2].imshow(grad_cam, cmap='jet', alpha=0.4) +axes[2].set_title(f'Grad-CAM (pred={pred:.2f})'); axes[2].axis('off') +plt.tight_layout(); plt.show() +``` + +4. 比较深度可分离卷积与标准卷积。统计两者的参数和 FLOPs,并展示它们在计算量少得多的情况下产生相似的输出。 +```python +import jax +import jax.numpy as jnp + +def standard_conv(x, kernel): + """标准卷积:(H, W, C_in) * (k, k, C_in, C_out) -> (H, W, C_out)。""" + H, W, C_in = x.shape + k, _, _, C_out = kernel.shape + pad = k // 2 + x_pad = jnp.pad(x, ((pad, pad), (pad, pad), (0, 0)), mode='constant') + out = jnp.zeros((H, W, C_out)) + for i in range(H): + for j in range(W): + patch = x_pad[i:i+k, j:j+k, :] # (k, k, C_in) + for c in range(C_out): + out = out.at[i, j, c].set(jnp.sum(patch * kernel[:, :, :, c])) + return out + +def depthwise_separable_conv(x, dw_kernel, pw_kernel): + """深度可分离:深度卷积 (k,k,C_in) 然后逐点卷积 (C_in, C_out)。""" + H, W, C_in = x.shape + k = dw_kernel.shape[0] + pad = k // 2 + x_pad = jnp.pad(x, ((pad, pad), (pad, pad), (0, 0)), mode='constant') + + # Depthwise: one filter per channel + dw_out = jnp.zeros((H, W, C_in)) + for i in range(H): + for j in range(W): + for c in range(C_in): + patch = x_pad[i:i+k, j:j+k, c] + dw_out = dw_out.at[i, j, c].set(jnp.sum(patch * dw_kernel[:, :, c])) + + # Pointwise: 1x1 conv across channels + out = dw_out @ pw_kernel + return out + +# Setup +H, W, C_in, C_out, k = 8, 8, 16, 32, 3 +key = jax.random.PRNGKey(42) +k1, k2, k3, k4 = jax.random.split(key, 4) + +x = jax.random.normal(k1, (H, W, C_in)) +std_kernel = jax.random.normal(k2, (k, k, C_in, C_out)) * 0.1 +dw_kernel = jax.random.normal(k3, (k, k, C_in)) * 0.1 +pw_kernel = jax.random.normal(k4, (C_in, C_out)) * 0.1 + +# Compare +std_params = k * k * C_in * C_out +dw_params = k * k * C_in + C_in * C_out + +std_flops = H * W * k * k * C_in * C_out +dw_flops = H * W * (k * k * C_in + C_in * C_out) + +print(f"Standard conv: {std_params:>8,} params, {std_flops:>10,} FLOPs") +print(f"Depthwise separable conv: {dw_params:>8,} params, {dw_flops:>10,} FLOPs") +print(f"Parameter reduction: {std_params / dw_params:.1f}x") +print(f"FLOP reduction: {std_flops / dw_flops:.1f}x") + +std_out = standard_conv(x, std_kernel) +ds_out = depthwise_separable_conv(x, dw_kernel, pw_kernel) +print(f"\nStandard output shape: {std_out.shape}") +print(f"Depthwise sep output shape: {ds_out.shape}") +``` diff --git a/chapter 08: computer vision/03. object detection and segmentation.md b/chapter 08: computer vision/03. object detection and segmentation.md new file mode 100644 index 0000000..2d02447 --- /dev/null +++ b/chapter 08: computer vision/03. object detection and segmentation.md @@ -0,0 +1,376 @@ +# 目标检测与分割 + +*目标检测定位并分类图像中的每个物体;分割为每个像素分配一个标签。本文件涵盖交并比(IoU)、平均精度均值(mAP)、锚框、R-CNN系列、YOLO、SSD、特征金字塔网络(FPN)、语义/实例/全景分割(U-Net、Mask R-CNN、SAM)以及用于基准测试的评估指标。* + +- 图像分类(文件02)回答了"这张图像里有什么?"目标检测提出了一个更难的问题:"这张图像里有哪些物体,它们在哪里?" + +- 分割则更进一步:"哪些像素属于哪个物体或类别?"这些任务形成了一个空间理解精度逐步提高的层次结构。 + +- **目标检测**模型输出一组**边界框**,每个边界框由四个坐标(左上角 $x, y$、宽度、高度)以及一个带有置信度分数的类别标签定义。一张图像可能包含零个、一个或数百个来自多个类别的物体。 + +![输入图像中包含多个物体,每个物体由一个彩色边界框和带有置信度分数的类别标签包围](../images/detection_boxes.svg) + +- **交并比(IoU)**衡量预测边界框与真实标注的匹配程度。它是重叠面积除以并集面积: + +$$\text{IoU} = \frac{\text{交集面积}}{\text{并集面积}}$$ + +- IoU为1表示完全重叠,IoU为0表示完全不重叠。"正确"检测的标准阈值为IoU $\geq 0.5$,但也使用更严格的阈值(0.75、0.9)。 + +- 如果预测框与真实框的IoU超过阈值且类别正确,则检测结果为**真正例(TP)**。 + +- **假正例(FP)**是未匹配到任何真实标注的预测框。 + +- **假负例(FN)**是没有任何预测框匹配到的真实物体。这些与第06章中的精确率和召回率概念相同。 + +- **平均精度(AP)**总结单个类别的检测质量。对于每个类别,按置信度分数对所有检测结果排序,计算每个排序位置的精确率和召回率,然后计算精确率-召回率曲线下的面积: + +$$\text{AP} = \int_0^1 p(r) \, dr$$ + +- 在实践中,曲线是插值处理的:在每个召回率水平上,精确率被设置为所有召回率 $\geq r$ 处的最大精确率。这使曲线平滑并使其单调递减。 + +- **平均精度均值(mAP)**对所有类别的AP进行平均。"mAP@0.5"使用IoU阈值0.5。"mAP@[.5:.95]"(COCO标准)在从0.5到0.95的十个IoU阈值上(步长0.05)对mAP进行平均,同时奖励检测能力和精确的定位能力。 + +- **非极大值抑制(NMS)**移除重复的检测结果。当模型为同一个物体预测出多个重叠的边界框时,NMS保留置信度最高的框,并移除所有与其重叠超过IoU阈值的其他框。这是在模型生成原始预测之后,按每个类别分别进行的。 + +- **两阶段检测器**首先提出候选区域,然后对每个提案进行分类和精细化调整。 + +- **R-CNN**(Girshick 等人,2014年)是第一个成功的深度学习检测器。它使用选择性搜索(一种经典算法)提出约2,000个候选区域,将每个区域变形为固定尺寸,独立通过CNN运行,并使用SVM(第06章)进行分类。R-CNN准确但极其缓慢:每张图像需要运行CNN 2,000次。 + +- **Fast R-CNN**(Girshick,2015年)解决了冗余问题:它在整张图像上运行一次CNN以生成共享特征图,然后使用**RoI池化**(感兴趣区域池化)从该共享特征图中为每个提案提取特征。 + +- RoI池化从特征图中取出一个可变大小的区域,通过将该区域划分为一个网格并在每个单元格内进行最大池化,生成固定大小的输出。这种方法快得多,因为昂贵的CNN计算只进行一次。 + +- **Faster R-CNN**(Ren 等人,2015年)引入了**区域提议网络(RPN)**,从而消除了外部区域提议算法。RPN是一个小型CNN,运行在共享特征图之上,直接预测提案。RPN在特征图上滑动一个小窗口,在每个位置上预测 $k$ 个提案(每个**锚框**对应一个提案)。 + +![Faster R-CNN流程:输入图像 → 骨干CNN → 共享特征图 → RPN生成提案 → RoI池化 → 分类和边界框回归头](../images/faster_rcnn.svg) + +- **锚框**是特征图上每个空间位置处预定义的边界框,覆盖不同的尺度和长宽比(例如,三个尺度 $\times$ 三个比例 = 每个位置9个锚框)。RPN为每个锚框预测两样东西:物体性分数(物体vs背景)以及用于将锚框精炼为更紧凑提案的坐标偏移量。这种参数化使回归问题更容易:网络不需要预测绝对坐标,只需预测对合理初始框的小幅调整。 + +- 锚框偏移量的参数化公式为: + +$$t_x = \frac{x - x_a}{w_a}, \quad t_y = \frac{y - y_a}{h_a}, \quad t_w = \log\frac{w}{w_a}, \quad t_h = \log\frac{h}{h_a}$$ + +- 其中 $(x, y, w, h)$ 是预测框的中心和尺寸,$(x_a, y_a, w_a, h_a)$ 是锚框。宽度和高度的对数变换确保预测框始终为正数,并使回归具有尺度不变性。 + +- Faster R-CNN使用多任务损失进行训练:类别标签的分类损失(第05章的交叉熵),以及用于边界框回归的**平滑L1损失**。平滑L1对异常值不如L2敏感: + +```math +\text{smooth}_{L1}(x) = \begin{cases} 0.5x^2 & \text{if } |x| < 1 \\ |x| - 0.5 & \text{otherwise} \end{cases} +``` + +- **特征金字塔网络(FPN)**(Lin 等人,2017年)通过构建一个带有侧边连接的自顶向下路径来解决多尺度问题,该路径将高层语义信息与低层空间细节融合。骨干网络生成多个尺度的特征图(每个池化层将分辨率减半)。FPN添加了一个自顶向下的路径,其中每个层级接收来自上一层级的上采样特征,并通过侧边1x1卷积与对应的自底向上层级合并。结果是一个特征图金字塔,每个层级的特征图既具有强语义信息又具有良好的空间分辨率。 + +- 小物体从金字塔的高分辨率层级检测;大物体从低分辨率层级检测。FPN现在已成为大多数现代检测架构的标准组件。 + +- **单阶段检测器**完全跳过了提案步骤,在一次前向传播中直接预测类别标签和边界框。这种方法更快,但在历史上准确率低于两阶段检测器,直到焦点损失(focal loss)缩小了这一差距。 + +- **YOLO**(You Only Look Once,Redmon 等人,2016年)将图像划分为一个 $S \times S$ 的网格。每个网格单元预测 $B$ 个边界框和 $C$ 个类别概率。如果一个物体的中心落在一个网格单元内,该单元负责检测该物体。YOLO极其快速,因为整个检测过程只有一次前向传播,没有提案阶段。 + +- **YOLOv2**添加了锚框、批归一化和多尺度训练。**YOLOv3**使用了特征金字塔网络并在三个尺度上进行预测。**YOLOv4-v8**继续改进,采用了更好的骨干网络、路径聚合网络和马赛克数据增强(在训练中将四张图像拼接在一起以增加上下文多样性)。 + +- **SSD**(Single Shot MultiBox Detector,Liu 等人,2016年)在骨干网络内的多个特征图尺度上进行预测,在每个尺度上使用锚框。早期(高分辨率)特征图检测小物体;后期(低分辨率)特征图检测大物体。SSD比Faster R-CNN更快,且具有竞争力的准确率。 + +- **RetinaNet**(Lin 等人,2017年)指出了单阶段检测器的核心问题:类别不平衡。绝大多数锚框对应的是背景,这产生了大量容易的负样本,它们主导了损失函数并压倒了来自稀有正样本的梯度。 + +- **焦点损失(Focal Loss)**通过降低容易样本的权重来解决这个问题: + +$$\text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)$$ + +- 其中 $p_t$ 是正确类别的预测概率。当模型自信且正确时($p_t$ 很高),$(1 - p_t)^\gamma$ 很小,从而减少了容易负样本对损失的贡献。超参数 $\gamma$ (通常为2)控制降权的强度。当 $\gamma = 0$ 时,焦点损失退化为标准交叉熵。凭借焦点损失,RetinaNet以单阶段的速度实现了与两阶段检测器相当的准确率。 + +- **无锚框检测**完全消除了锚框,减少了超参数调优并简化了流程。 + +- **FCOS**(全卷积单阶段检测器,Tian 等人,2019年)在特征图的每个空间位置预测从该位置到最近边界框四条边(左、上、右、下)的距离以及一个类别标签。**中心性(centerness)**分数降低了远离物体中心的预测的权重,从而提高了质量。FCOS使用FPN来处理多尺度问题。 + +- **CenterNet**(Zhou 等人,2019年)将物体检测为点:它预测一个热力图,其中的峰值对应物体中心,然后在每个峰值处回归宽度和高度。检测变成了关键点估计。这种方法优雅且无需锚框,但需要仔细的热力图后处理。 + +- **CornerNet**将物体检测为一对角点(左上角和右下角)。它预测两个热力图(每个角类型一个),并使用**关联嵌入(associative embedding)**将对应的角点匹配成边界框。这避免了对锚框的需求,并处理了任意形状的物体。 + +- **语义分割**为图像中的每个像素分配一个类别标签。与检测(输出边界框)不同,分割生成密集的像素级映射。一条街景可能会将每个像素标记为道路、人行道、汽车、行人、建筑、天空等。 + +![语义分割:输入街景及其像素级标签图,每种颜色代表一个类别](../images/semantic_segmentation.svg) + +- **全卷积网络(FCN)**(Long 等人,2015年)通过将全连接层替换为卷积层,使分类CNN适用于分割任务,从而使网络能够输出空间映射而非单个类别。上采样(通过转置卷积或双线性插值)将输出恢复到输入分辨率。来自早期层的跳跃连接添加了在下采样过程中丢失的空间细节。 + +- **转置卷积**(有时称为"反卷积")是卷积的上采样对应操作。步幅卷积减少空间维度,而转置卷积增加空间维度。它在输入元素之间插入零,然后应用标准卷积,从而有效地学习如何上采样。 + +- **U-Net**(Ronneberger 等人,2015年)引入了一种对称的编码器-解码器架构,在每一层都有跳跃连接。编码器(收缩路径)在增加通道数的同时降低空间分辨率,与分类CNN完全相同。解码器(扩展路径)将结果上采样回全分辨率。跳跃连接在每一层将编码器特征图与解码器特征图拼接起来,为解码器提供精细的空间细节。这种高层语义与低层细节的结合产生了清晰、准确的分割边界。 + +![U-Net架构:左侧为带下采样的编码器路径,右侧为带下采样的解码器路径,以及连接对应层级的跳跃连接](../images/unet_architecture.svg) + +- U-Net最初是为生物医学图像分割设计的(其中训练数据稀缺),其架构已成为许多后续模型的基础,包括潜在扩散模型中的U-Net(文件04)。 + +- **DeepLab**(Chen 等人,2014-2018年)为分割引入了两个关键创新: + + - **空洞(扩张)卷积**:在滤波器元素之间插入间隙的标准卷积,由扩张率 $r$ 控制。一个扩张率为 $r$ 的3x3滤波器的感受野为 $(2r + 1) \times (2r + 1)$,而仅使用9个参数。这在不进行下采样的情况下捕获多尺度上下文,同时保持空间分辨率。 + + - **空洞空间金字塔池化(ASPP)**:并行应用多个具有不同扩张率的空洞卷积(例如,扩张率1、6、12、18),拼接结果,并通过1x1卷积融合。ASPP同时捕获多个尺度的上下文,其精神类似于Inception模块(文件02),但使用扩张而非不同大小的卷积核。 + +- DeepLab还使用**条件随机场(CRF)**(第05章)作为后处理步骤,通过鼓励空间上相邻且颜色相似的像素共享相同的标签来优化分割边界。 + +- **实例分割**结合了检测和分割:它识别每个单独的物体实例,并为每个实例生成像素级掩码。场景中的两辆车会得到两个独立的掩码,而不仅仅是"车"。 + +- **Mask R-CNN**(He 等人,2017年)通过添加一个小型分割头来扩展Faster R-CNN,该分割头为每个检测到的物体预测一个二值掩码。其架构为Faster R-CNN加上一个掩码分支:掩码分支接收RoI池化后的特征,并为每个类别输出一个 $m \times m$ 的二值掩码。它使用**RoIAlign**代替RoI池化:在精确定位的采样点处进行双线性插值,而非在量化的网格单元格内进行,这避免了量化引起的空间错位。这一小改动显著提高了掩码质量。 + +- Mask R-CNN使用多任务损失进行训练:分类损失 + 边界框回归损失 + 掩码损失(逐像素二值交叉熵)。掩码分支独立地为每个类别预测一个掩码;仅使用与预测类别对应的掩码,这使掩码预测与分类解耦,并同时改进了两者。 + +- **全景分割**将语义分割和实例分割统一为单个任务。每个像素同时获得一个类别标签(语义)和一个实例ID(用于"物体"类别,如汽车和人)。"背景"类别(天空、道路、草地)只获得语义标签,因为它们是无形区域,没有可计数的实例。 + +- 全景质量(PQ)指标通过分解为分割质量(匹配片段的平均IoU)和识别质量(匹配片段的F1分数)来评估: + +$$\text{PQ} = \underbrace{\frac{\sum_{(p,g) \in \text{TP}} \text{IoU}(p,g)}{|\text{TP}|}}_{\text{SQ}} \times \underbrace{\frac{|\text{TP}|}{|\text{TP}| + \frac{1}{2}|\text{FP}| + \frac{1}{2}|\text{FN}|}}_{\text{RQ}}$$ + +- **实时分割**对于自动驾驶和增强现实等应用至关重要,这些应用对延迟预算要求严格(通常每帧不超过30毫秒)。 + +- **BiSeNet**(双边分割网络,Yu 等人,2018年)使用两条并行路径:一条**空间路径**,具有宽而浅的层以保留空间细节;一条**上下文路径**,具有深而窄的层以捕获语义信息。输出被融合,兼顾速度和准确率。 + +- **DDRNet**(深度双分辨率网络,Hong 等人,2021年)在整个网络中以不同分辨率维持两个分支,并在它们之间反复交换信息。高分辨率分支保留空间细节,而低分辨率分支捕获全局上下文。多个双边融合模块在两个方向上合并信息。 + +- 实时分割的总体趋势是避免沉重的编码器-解码器模式,而是通过网络全程维持足够的空间分辨率,以一定的准确率为代价换取显著更低的延迟。 + +## 编程练习(使用CoLab或notebook) + +1. 从头实现IoU计算和非极大值抑制。对一组重叠的边界框应用NMS并可视化结果。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt +import matplotlib.patches as patches + +def compute_iou(box1, box2): + """计算两个框[x1, y1, x2, y2]之间的IoU。""" + x1 = jnp.maximum(box1[0], box2[0]) + y1 = jnp.maximum(box1[1], box2[1]) + x2 = jnp.minimum(box1[2], box2[2]) + y2 = jnp.minimum(box1[3], box2[3]) + + intersection = jnp.maximum(0, x2 - x1) * jnp.maximum(0, y2 - y1) + area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) + area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) + union = area1 + area2 - intersection + + return intersection / (union + 1e-6) + +def nms(boxes, scores, iou_threshold=0.5): + """非极大值抑制。""" + order = jnp.argsort(-scores) # 按置信度降序排列 + keep = [] + + remaining = list(range(len(scores))) + order_list = order.tolist() + + while order_list: + idx = order_list[0] + keep.append(idx) + order_list = order_list[1:] + + new_order = [] + for j in order_list: + iou = compute_iou(boxes[idx], boxes[j]) + if iou < iou_threshold: + new_order.append(j) + order_list = new_order + + return keep + +# 示例:同一物体的重叠检测 +boxes = jnp.array([ + [50, 60, 150, 160], # 高置信度 + [55, 65, 155, 165], # 重叠的重复框 + [52, 58, 148, 158], # 重叠的重复框 + [200, 100, 300, 200], # 不同物体 + [205, 105, 305, 205], # 重叠的重复框 +]) +scores = jnp.array([0.95, 0.80, 0.70, 0.90, 0.60]) + +keep = nms(boxes, scores, iou_threshold=0.5) + +fig, axes = plt.subplots(1, 2, figsize=(14, 5)) +colors = ['#3498db', '#e74c3c', '#27ae60', '#9b59b6', '#f39c12'] + +for ax, title, indices in zip(axes, ['NMS之前', 'NMS之后'], + [range(len(boxes)), keep]): + ax.set_xlim(0, 400); ax.set_ylim(0, 300) + ax.set_aspect('equal'); ax.invert_yaxis() + ax.set_title(title) + for i in indices: + b = boxes[i] + rect = patches.Rectangle((b[0], b[1]), b[2]-b[0], b[3]-b[1], + linewidth=2, edgecolor=colors[i], + facecolor='none') + ax.add_patch(rect) + ax.text(b[0], b[1]-5, f'{scores[i]:.2f}', color=colors[i], fontsize=10) + +plt.tight_layout(); plt.show() +print(f"NMS后保留了{len(keep)}个框,共{len(boxes)}个") +``` + +2. 实现一个简化的区域提议网络(RPN)。给定一个特征图,生成具有多种尺度和长宽比的锚框,并预测物体性分数和边界框偏移量。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import matplotlib.patches as patches + +def generate_anchors(feature_h, feature_w, stride, scales, ratios): + """为特征图上的每个位置生成锚框。""" + anchors = [] + for y in range(feature_h): + for x in range(feature_w): + cx = (x + 0.5) * stride + cy = (y + 0.5) * stride + for s in scales: + for r in ratios: + w = s * jnp.sqrt(r) + h = s / jnp.sqrt(r) + anchors.append([cx - w/2, cy - h/2, cx + w/2, cy + h/2]) + return jnp.array(anchors) + +def rpn_forward(feature_map, params): + """简化版RPN:预测每个锚框的物体性和框偏移量。""" + H, W, C = feature_map.shape + n_anchors = params['cls_w'].shape[1] + + # 在特征图上滑动1x1卷积(简化版) + cls_scores = feature_map.reshape(-1, C) @ params['cls_w'] # (H*W, n_anchors) + box_offsets = feature_map.reshape(-1, C) @ params['reg_w'] # (H*W, n_anchors*4) + + cls_scores = jax.nn.sigmoid(cls_scores) + return cls_scores.ravel(), box_offsets.reshape(-1, 4) + +# 设置 +feature_h, feature_w, channels = 4, 4, 16 +stride = 16 # 每个特征图单元格覆盖16x16像素 +scales = [32, 64, 128] +ratios = [0.5, 1.0, 2.0] +n_anchors_per_pos = len(scales) * len(ratios) + +key = jax.random.PRNGKey(42) +k1, k2, k3 = jax.random.split(key, 3) + +feature_map = jax.random.normal(k1, (feature_h, feature_w, channels)) +params = { + 'cls_w': jax.random.normal(k2, (channels, n_anchors_per_pos)) * 0.01, + 'reg_w': jax.random.normal(k3, (channels, n_anchors_per_pos * 4)) * 0.01, +} + +anchors = generate_anchors(feature_h, feature_w, stride, scales, ratios) +scores, offsets = rpn_forward(feature_map, params) + +print(f"特征图:{feature_h}x{feature_w},步幅={stride}") +print(f"每个位置的锚框数:{n_anchors_per_pos}") +print(f"锚框总数:{len(anchors)}") +print(f"物体性分数形状:{scores.shape}") +print(f"边界框偏移量形状:{offsets.shape}") + +# 可视化一个位置的锚框 +fig, ax = plt.subplots(figsize=(6, 6)) +img_size = feature_h * stride +ax.set_xlim(0, img_size); ax.set_ylim(0, img_size) +ax.invert_yaxis(); ax.set_aspect('equal') + +pos_idx = feature_h // 2 * feature_w + feature_w // 2 # 中心位置 +colors = ['#3498db', '#e74c3c', '#27ae60'] +for i, s in enumerate(scales): + for j, r in enumerate(ratios): + idx = pos_idx * n_anchors_per_pos + i * len(ratios) + j + a = anchors[idx] + rect = patches.Rectangle((a[0], a[1]), a[2]-a[0], a[3]-a[1], + linewidth=1.5, edgecolor=colors[i], + facecolor='none', linestyle=['--', '-', ':'][j]) + ax.add_patch(rect) + +ax.scatter([img_size/2], [img_size/2], c='red', s=50, zorder=5) +ax.set_title(f'中心位置的锚框\n3个尺度 × 3个比例 = {n_anchors_per_pos}') +ax.grid(True, alpha=0.3) +plt.tight_layout(); plt.show() +``` + +3. 实现一个简化版的一维U-Net编码器-解码器,带有跳跃连接,用于一维分割(一维信号的二值标注)。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def conv1d_same(x, kernel): + """具有相同填充的一维卷积。""" + k = len(kernel) + pad = k // 2 + x_pad = jnp.pad(x, pad, mode='edge') + n = len(x) + out = jnp.zeros(n) + for i in range(n): + out = out.at[i].set(jnp.sum(x_pad[i:i+k] * kernel)) + return out + +def downsample(x): + return x[::2] + +def upsample(x, target_len): + return jnp.interp(jnp.linspace(0, 1, target_len), jnp.linspace(0, 1, len(x)), x) + +def unet_1d(x, params): + """简化版一维U-Net,包含2个编码器/解码器层级。""" + # 编码器 + e1 = jnp.maximum(0, conv1d_same(x, params['enc1'])) + e1_down = downsample(e1) + + e2 = jnp.maximum(0, conv1d_same(e1_down, params['enc2'])) + e2_down = downsample(e2) + + # 瓶颈层 + bottleneck = jnp.maximum(0, conv1d_same(e2_down, params['bottleneck'])) + + # 带跳跃连接的解码器 + d2_up = upsample(bottleneck, len(e2)) + d2 = jnp.maximum(0, conv1d_same(d2_up + e2, params['dec2'])) # 跳跃连接 + + d1_up = upsample(d2, len(e1)) + d1 = conv1d_same(d1_up + e1, params['dec1']) # 跳跃连接 + + return jax.nn.sigmoid(d1) + +# 创建带有标注区域的信号 +n = 128 +t = jnp.linspace(0, 4 * jnp.pi, n) +signal = jnp.sin(t) + 0.5 * jnp.sin(3 * t) +labels = (signal > 0.5).astype(jnp.float32) # 二值分割目标 + +key = jax.random.PRNGKey(42) +keys = jax.random.split(key, 5) +params = { + 'enc1': jax.random.normal(keys[0], (5,)) * 0.3, + 'enc2': jax.random.normal(keys[1], (5,)) * 0.3, + 'bottleneck': jax.random.normal(keys[2], (3,)) * 0.3, + 'dec2': jax.random.normal(keys[3], (5,)) * 0.3, + 'dec1': jax.random.normal(keys[4], (5,)) * 0.3, +} + +def loss_fn(params, signal, labels): + pred = unet_1d(signal, params) + return -jnp.mean(labels * jnp.log(pred + 1e-7) + (1 - labels) * jnp.log(1 - pred + 1e-7)) + +grad_fn = jax.jit(jax.grad(loss_fn)) +lr = 0.05 + +for step in range(500): + grads = grad_fn(params, signal, labels) + params = {k: params[k] - lr * grads[k] for k in params} + +pred = unet_1d(signal, params) + +fig, axes = plt.subplots(3, 1, figsize=(12, 7), sharex=True) +axes[0].plot(t, signal, color='#3498db', linewidth=1.5) +axes[0].set_title('输入信号'); axes[0].set_ylabel('值') + +axes[1].fill_between(t, 0, labels, alpha=0.3, color='#27ae60') +axes[1].set_title('真实标注'); axes[1].set_ylabel('标签') + +axes[2].plot(t, pred, color='#e74c3c', linewidth=1.5) +axes[2].fill_between(t, 0, (pred > 0.5).astype(float), alpha=0.2, color='#e74c3c') +axes[2].set_title('U-Net预测'); axes[2].set_ylabel('概率') +axes[2].set_xlabel('t') + +plt.tight_layout(); plt.show() +print(f"最终损失:{loss_fn(params, signal, labels):.4f}") +print(f"像素准确率:{jnp.mean((pred > 0.5) == labels):.2%}") +``` diff --git a/chapter 08: computer vision/04. vision transformers and generation.md b/chapter 08: computer vision/04. vision transformers and generation.md new file mode 100644 index 0000000..bc210e4 --- /dev/null +++ b/chapter 08: computer vision/04. vision transformers and generation.md @@ -0,0 +1,360 @@ +# 视觉Transformer与生成模型 + +*视觉Transformer将自注意力应用于图像块,通过数据驱动的空间学习挑战了CNN的主导地位。本文涵盖ViT、DeiT、Swin Transformer、基于GAN的图像生成(StyleGAN)、VAE和扩散模型(DDPM、Stable Diffusion),以及超分辨率和神经风格迁移。* + +- CNN(文件02)内置了很强的空间归纳偏置:局部连接、权重共享和平移等变性。视觉Transformer(ViT)提出了一个启发性的问题:如果我们完全抛弃这些偏置,仅使用第06章中的注意力机制,让模型从数据中学习空间结构,结果会怎样? + +- **ViT**(Vision Transformer,Dosovitskiy等人,2021)将标准的Transformer编码器直接应用于图像。其核心思想是将图像视为一个图像块序列,就像NLP将文本视为一个词元序列一样。 + +- 其处理流程如下: + 1. 将图像(高度$H$,宽度$W$,通道数$C$)分割成$P \times P$大小的不重叠图像块网格。得到$N = HW / P^2$个图像块。 + 2. 将每个图像块展平成长度为$P^2 \cdot C$的向量,并通过一个可学习的线性嵌入(单个矩阵乘法,第02章)将其投影到模型维度$D$。 + 3. 在前面添加一个可学习的**[CLS]标记**嵌入(类似于BERT的[CLS],第07章)。该标记会关注所有图像块,其最终表示用于分类。 + 4. 添加**位置嵌入**(每个位置一个可学习向量)以提供空间信息,因为注意力是置换等变的。 + 5. 将$(N + 1)$个标记嵌入序列通过标准的Transformer编码器(多头自注意力 + FFN,第06章)。 + 6. [CLS]标记的最终表示通过一个分类头(小型MLP)进行分类。 + +![ViT流程:将图像分割为16x16图像块,每个块展平并线性投影,添加[CLS]标记,加上位置嵌入,然后由Transformer编码器块处理](../images/vit_pipeline.svg) + +- **图像块嵌入**等价于一个卷积核大小为$P$、步长为$P$(不重叠)的卷积操作。ViT将2D图像字面地转换为1D序列,然后用与处理语言相同的架构来处理它。 + +- ViT的归纳偏置比CNN少:它不强制局部连接或平移等变性。这意味着它需要更多的训练数据才能从头学习空间结构。在小型数据集上,CNN优于ViT。但在非常大的数据集(JFT-300M,3亿张图像)上训练时,ViT达到或超过了最佳CNN的性能,这表明CNN的归纳偏置有助于数据效率,但对于最终性能并非必需。 + +- ViT自注意力的复杂度为$O(N^2)$,其中N是图像块数量。对于224x224的图像和16x16的图像块,$N = 196$,这在可控范围内。但对于更高分辨率的图像或更小的图像块,二次成本变得难以承受。 + +- **DeiT**(数据高效的图像Transformer,Touvron等人,2021)表明,仅使用ImageNet(无需庞大的JFT数据集)并借助强数据增强、正则化(随机深度、标签平滑、dropout)和**知识蒸馏**,就可以有效训练ViT:一个预训练的CNN教师提供软标签,ViT学生学习匹配这些标签。DeiT在[CLS]标记旁边添加了一个**蒸馏标记**,训练用于预测教师的输出。 + +- **Swin Transformer**(Liu等人,2021)解决了ViT的两个主要局限:随图像大小呈二次增长的计算成本,以及缺少层次化特征图(检测和分割需要层次化特征)。 + +- Swin引入了**移动窗口**:不再对所有图像块进行全局自注意力,而是在局部窗口内(例如7x7个图像块)计算注意力。这使得计算成本与图像大小呈线性关系:$O(N)$而非$O(N^2)$。但仅靠局部窗口会阻止区域之间的信息流动。 + +- **窗口移动**解决了这个问题:在交替层中,窗口划分会偏移半个窗口大小。这创建了跨窗口连接,使得信息可以在所有图像部分之间流动,而无需全局注意力的成本。 + +![Swin Transformer:第l层在常规窗口内计算注意力,第l+1层将窗口划分偏移一半,创建跨窗口连接](../images/swin_shifted_windows.svg) + +- Swin还通过跨阶段合并图像块来构建**层次化表示**。每个阶段之后,相邻的2x2图像块被拼接并投影,使通道维度加倍、空间分辨率减半。这产生了多尺度特征图,类似于CNN和FPN(文件03)中的特征图,使得Swin可以直接兼容Faster R-CNN等检测头和U-Net等分割头。 + +- **PVT**(金字塔视觉Transformer)采用了类似的层次化方法,具有空间缩减注意力:在每个阶段,键和值在计算注意力之前先进行空间下采样,从而在保持全局感受野的同时降低二次成本。 + +- **自监督视觉学习**从未标注的图像中训练表示。标注成本高,但图像资源丰富。目标是在没有任何人工标注的情况下,学习能很好地迁移到下游任务的特征。 + +- **对比学习**训练模型识别:同一张图像的两个增广视图("正样本对")应具有相似的表示,而不同图像的视图("负样本对")应具有不相似的表示。 + +- **SimCLR**(Chen等人,2020)对一个批次中的每张图像创建两个增广视图,用共享主干网络+投影头对两者进行编码,并应用**NT-Xent损失**(归一化温度标度交叉熵): + +$$\ell_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k \neq i} \exp(\text{sim}(z_i, z_k) / \tau)}$$ + +- 其中$\text{sim}$是余弦相似度(第01章),$\tau$是温度参数。分子将正样本对拉近;分母将负样本对推远。SimCLR需要大批量大小(4,096+)来提供足够的负样本。 + +- **MoCo**(动量对比,He等人,2020)通过维护一个**动量更新的负嵌入队列**来解决大批量需求。查询编码器通过梯度下降更新;键编码器作为查询编码器的指数移动平均(EMA,第04章)进行更新:$\theta_k \leftarrow m \theta_k + (1 - m) \theta_q$,其中$m = 0.999$。队列存储最近的键嵌入,提供了大量且一致的负样本集,无需巨大的批次。 + +- **BYOL**(自举你自己的隐空间,Grill等人,2020)完全消除了负样本对。它使用两个网络:"在线"网络和"目标"网络(在线的EMA)。在线网络预测目标网络对另一增广视图的表示。无需负样本,BYOL通过预测头的不对称性和EMA目标避免了坍塌问题(模型对所有输入输出相同向量)。 + +- **DINO**(无标签自蒸馏,Caron等人,2021)将自蒸馏应用于ViT。学生网络预测教师网络(学生的EMA)在不同增广视图下的输出。教师使用更大的裁剪区域;学生使用更小的裁剪区域。DINO产生的特征包含关于场景布局的显式信息:DINO训练的ViT的自注意力图自然地对物体进行分割,无需任何分割监督。 + +- **掩码图像建模**是BERT掩码语言建模(第07章)在视觉领域的类比。输入图像块的一大部分被掩码,模型学习重建它们。 + +- **MAE**(掩码自编码器,He等人,2022)掩码了75%的图像块,并训练一个ViT编码器-解码器来重建缺失的像素值。只有未掩码的图像块由编码器处理(在预训练期间节省4倍计算量),轻量级解码器从编码后的可见图像块加上可学习的掩码标记重建完整图像。 + +- **BEiT**(图像Transformer的BERT预训练,Bao等人,2022)掩码图像块并预测离散的视觉标记(从预训练的dVAE分词器获得),而不是原始像素。这类似于BERT预测离散词标记,避免了像素重建的低层细节。 + +- **图像生成**旨在生成训练集中不存在的新颖、逼真的图像。核心挑战是对自然图像的高维概率分布进行建模。 + +- **生成对抗网络(GAN)**(Goodfellow等人,2014)使用两个相互竞争的网络:一个**生成器**$G$从随机噪声中创建假图像,和一个**判别器**$D$试图区分真实图像和假图像。它们通过对抗性训练:$G$试图欺骗$D$,而$D$试图抓住$G$。 + +$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))]$$ + +- 生成器接收随机隐向量$z$(从高斯分布等简单分布中采样),通过一系列转置卷积将其映射生成图像。判别器是一个标准的CNN分类器。在均衡状态下,$G$生成的图像与真实数据无法区分,$D$对所有输入输出0.5。 + +- **模式坍塌**是GAN的主要失败模式:生成器学会只生成少数几种能欺骗判别器的图像,忽略了训练数据的多样性。生成器找到一小部分"安全"输出,而不是覆盖完整的数据分布。 + +- 稳定GAN的训练技巧包括:谱归一化(约束判别器的Lipschitz常数)、渐进式增长(先在低分辨率训练,然后逐步提高)、特征匹配(匹配中间判别器特征的统计量而非最终输出),以及使用Wasserstein距离替代原始的JS散度目标。 + +- **StyleGAN**(Karras等人,2019)是最具影响力的高质量图像合成GAN架构。其关键创新是**基于风格的生成器**:不是将隐向量$z$直接输入生成器,而是先通过一个**映射网络**(8层MLP)生成风格向量$w$。该风格向量通过**自适应实例归一化(AdaIN)**注入到生成器的每一层,调节特征图的统计量: + +$$\text{AdaIN}(x, y) = y_{s} \cdot \frac{x - \mu(x)}{\sigma(x)} + y_{b}$$ + +- 其中$y_s$和$y_b$是从$w$推导出的缩放和偏置。不同层控制不同方面:早期层控制粗粒度特征(姿态、脸型),中间层控制中粒度特征(发型、眼睛),后期层控制细粒度细节(雀斑、发质纹理)。StyleGAN能以1024x1024分辨率生成照片级逼真的人脸。 + +- **变分自编码器(VAE)**(第06章)提供了另一种生成方法。与GAN不同,VAE有一个原则性的概率框架,具有清晰的训练目标(ELBO)。它们生成的图像通常比GAN模糊,但提供了更平滑、更结构化的隐空间。VAE是隐扩散模型中用于将图像压缩到隐空间和从隐空间重建的编码器-解码器对。 + +- **扩散模型**已成为图像生成的主导范式,在质量和多样性上都超越了GAN。其思想概念上很简单:逐步向数据添加噪声直到变成纯高斯噪声(**前向过程**),然后学习逐步逆转这一过程(**反向过程**)。 + +- **前向过程**在$T$个时间步中添加高斯噪声: + +$$q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} \, x_{t-1}, \beta_t I)$$ + +- 其中$\beta_t$是一个随时间递增的噪声调度。经过足够多的步骤后,无论原始图像$x_0$如何,$x_T$都近似于纯高斯噪声。利用重参数化技巧(第06章),设$\alpha_t = 1 - \beta_t$,$\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$,我们可以直接从$x_0$采样$x_t$: + +$$x_t = \sqrt{\bar{\alpha}_t} \, x_0 + \sqrt{1 - \bar{\alpha}_t} \, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$ + +- **反向过程**学习去噪:从纯噪声$x_T$开始,模型预测每一步添加的噪声$\epsilon$并将其减去以恢复$x_{t-1}$。这由一个神经网络$\epsilon_\theta$(通常是U-Net,来自文件03)参数化,使用简单的MSE损失训练: + +$$\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon}\left[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\right]$$ + +![扩散前向和反向过程:干净图像在T步中逐渐被噪声破坏(前向),神经网络学习逆转每一步(反向),从纯噪声开始生成干净图像](../images/diffusion_process.svg) + +- **DDPM**(去噪扩散概率模型,Ho等人,2020)建立了这个框架。采样需要迭代所有$T$步(通常为1,000步),这很慢。**DDIM**(去噪扩散隐式模型,Song等人,2021)将采样过程重新表述为确定性映射,允许大跨度跳过(例如50步代替1,000步)且质量损失极小。 + +- **基于分数的模型**(Song和Ermon,2019)提供了另一种视角。该模型不是预测噪声$\epsilon$,而是估计**分数函数**$\nabla_{x_t} \log p(x_t)$,即对数概率相对于含噪图像的梯度。该梯度指向数据分布中更高概率(更干净)的区域。采样使用Langevin动力学沿着该梯度进行。基于分数的模型和DDPM在**随机微分方程(SDE)**的框架下被统一:前向过程是添加噪声的SDE,反向过程是时间反转的SDE。 + +- **无分类器引导**(Ho和Salimans,2022)控制样本质量和多样性之间的权衡。模型同时进行条件训练(使用文本提示或类别标签)和无条件训练(条件随机丢弃)。在采样时,预测是加权组合: + +$$\hat{\epsilon} = \epsilon_\theta(x_t, \varnothing) + s \cdot (\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing))$$ + +- 其中$c$是条件,$\varnothing$是空条件,$s > 1$是引导尺度。$s$越高,生成的图像越符合条件,但多样性越低。$s = 1$是无引导模型;$s = 7.5$是常见的默认值。 + +- **隐扩散**(Rombach等人,2022;Stable Diffusion)将扩散过程从像素空间转移到学习的隐空间中。一个预训练的VAE编码器将图像压缩为较低维度的隐空间表示(通常空间下采样4倍或8倍),扩散在这个压缩空间中进行,VAE解码器从去噪后的隐变量重建像素。这大大提高了效率:在像素空间扩散512x512图像需要处理$512 \times 512 \times 3$的张量,但在隐空间中仅需处理$64 \times 64 \times 4$的张量。 + +- 隐扩散中的去噪U-Net接收含噪隐变量、时间步(编码为正弦嵌入,类似于Transformer中的位置编码)和条件信号(来自冻结的CLIP或T5文本编码器的文本嵌入)。文本条件通过U-Net内的交叉注意力层进入:文本嵌入作为键和值,图像特征作为查询。这使得模型在每个空间位置都能关注文本提示的相关部分。 + +- **流匹配**是扩散模型的一个新兴替代方案,它学习噪声和数据之间的直接传输路径,而不是DDPM的迭代去噪。 + +- **连续归一化流(CNF)**定义了一个时间相关的速度场$v_\theta(x, t)$,沿着平滑轨迹将样本从简单分布$p_0$(噪声)推送到数据分布$p_1$。该变换遵循一个常微分方程(ODE): + +$$\frac{dx}{dt} = v_\theta(x, t), \quad t \in [0, 1]$$ + +- 从$x_0 \sim \mathcal{N}(0, I)$开始,将ODE向前积分到$t = 1$即可得到数据分布中的样本。速度场由神经网络参数化,训练目标是匹配目标条件流。 + +- **最优传输(OT)流匹配**(Lipman等人,2023)使用噪声和数据之间的直线路径作为目标流:从噪声样本$x_0$到数据样本$x_1$的条件路径简单地是$x_t = (1 - t) x_0 + t x_1$,目标速度为$v = x_1 - x_0$。训练损失变为: + +$$\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[\|v_\theta(x_t, t) - (x_1 - x_0)\|^2\right]$$ + +- **整流流**(Liu等人,2022)通过迭代方式拉直学习到的流路径。在初始训练后,模型通过模拟ODE生成(噪声,数据)对。这些比随机配对更紧密对齐的对用于重新训练模型。重复此过程会产生越来越直的路径,可以通过更少的ODE步骤(甚至单步)来遍历,从而实现极快速的生成。 + +- 流匹配相比扩散有几个优势:训练目标更简单(直接的速度回归,无需噪声调度),采样ODE更平滑(需要的积分步骤更少),与最优传输的联系提供了理论依据。Stable Diffusion 3和Flux使用流匹配替代了传统的DDPM。 + +## 编程练习(使用CoLab或notebook) + +1. 从头实现ViT图像块嵌入。将图像分割成图像块,展平,投影到模型维度,添加位置嵌入,并前置[CLS]标记。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def create_patch_embedding(image, patch_size, d_model, params): + """将图像转换为图像块嵌入序列。""" + H, W, C = image.shape + n_patches_h = H // patch_size + n_patches_w = W // patch_size + n_patches = n_patches_h * n_patches_w + + # 提取图像块 + patches = [] + for i in range(n_patches_h): + for j in range(n_patches_w): + patch = image[i*patch_size:(i+1)*patch_size, + j*patch_size:(j+1)*patch_size, :] + patches.append(patch.ravel()) + patches = jnp.stack(patches) # (N, P*P*C) + + # 线性投影到d_model + embeddings = patches @ params['proj_w'] + params['proj_b'] # (N, d_model) + + # 前置CLS标记 + cls_token = params['cls_token'] # (1, d_model) + embeddings = jnp.concatenate([cls_token, embeddings], axis=0) # (N+1, d_model) + + # 添加位置嵌入 + embeddings = embeddings + params['pos_embed'] # (N+1, d_model) + + return embeddings, patches + +# 设置 +H, W, C = 32, 32, 3 +patch_size = 8 +d_model = 64 +n_patches = (H // patch_size) * (W // patch_size) # 16 + +key = jax.random.PRNGKey(42) +keys = jax.random.split(key, 5) + +# 创建具有不同象限的合成图像 +image = jnp.zeros((H, W, C)) +image = image.at[:16, :16, 0].set(1.0) # 红色 左上 +image = image.at[:16, 16:, 1].set(1.0) # 绿色 右上 +image = image.at[16:, :16, 2].set(1.0) # 蓝色 左下 +image = image.at[16:, 16:, :2].set(1.0) # 黄色 右下 + +params = { + 'proj_w': jax.random.normal(keys[0], (patch_size**2 * C, d_model)) * 0.02, + 'proj_b': jnp.zeros(d_model), + 'cls_token': jax.random.normal(keys[1], (1, d_model)) * 0.02, + 'pos_embed': jax.random.normal(keys[2], (n_patches + 1, d_model)) * 0.02, +} + +embeddings, patches = create_patch_embedding(image, patch_size, d_model, params) + +print(f"图像形状: {image.shape}") +print(f"图像块大小: {patch_size}x{patch_size}") +print(f"图像块数量: {n_patches}") +print(f"图像块向量长度: {patch_size**2 * C}") +print(f"嵌入形状: {embeddings.shape} (CLS + {n_patches} 个图像块)") + +# 可视化图像块 +fig, axes = plt.subplots(2, 5, figsize=(14, 6)) +axes[0, 0].imshow(image); axes[0, 0].set_title('完整图像'); axes[0, 0].axis('off') +for idx in range(min(9, n_patches)): + ax = axes[(idx+1) // 5, (idx+1) % 5] + patch_img = patches[idx].reshape(patch_size, patch_size, C) + ax.imshow(patch_img); ax.set_title(f'图像块 {idx}'); ax.axis('off') +plt.suptitle('ViT 图像块分解') +plt.tight_layout(); plt.show() +``` + +2. 实现一个简单的GAN训练循环。在二维数据上训练生成器和判别器,并可视化生成分布逐渐收敛到真实分布。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def generator(z, params): + h = jnp.tanh(z @ params['g_w1'] + params['g_b1']) + h = jnp.tanh(h @ params['g_w2'] + params['g_b2']) + return h @ params['g_w3'] + params['g_b3'] + +def discriminator(x, params): + h = jax.nn.leaky_relu(x @ params['d_w1'] + params['d_b1'], 0.2) + h = jax.nn.leaky_relu(h @ params['d_w2'] + params['d_b2'], 0.2) + return jax.nn.sigmoid(h @ params['d_w3'] + params['d_b3']) + +def init_params(key): + keys = jax.random.split(key, 6) + z_dim, h_dim, data_dim = 2, 32, 2 + scale = 0.1 + return { + 'g_w1': jax.random.normal(keys[0], (z_dim, h_dim)) * scale, + 'g_b1': jnp.zeros(h_dim), + 'g_w2': jax.random.normal(keys[1], (h_dim, h_dim)) * scale, + 'g_b2': jnp.zeros(h_dim), + 'g_w3': jax.random.normal(keys[2], (h_dim, data_dim)) * scale, + 'g_b3': jnp.zeros(data_dim), + 'd_w1': jax.random.normal(keys[3], (data_dim, h_dim)) * scale, + 'd_b1': jnp.zeros(h_dim), + 'd_w2': jax.random.normal(keys[4], (h_dim, h_dim)) * scale, + 'd_b2': jnp.zeros(h_dim), + 'd_w3': jax.random.normal(keys[5], (h_dim, 1)) * scale, + 'd_b3': jnp.zeros(1), + } + +def d_loss(params, real_data, fake_data): + real_score = discriminator(real_data, params) + fake_score = discriminator(fake_data, params) + return -jnp.mean(jnp.log(real_score + 1e-7) + jnp.log(1 - fake_score + 1e-7)) + +def g_loss(params, fake_data): + fake_score = discriminator(fake_data, params) + return -jnp.mean(jnp.log(fake_score + 1e-7)) + +# 真实数据:环形分布 +key = jax.random.PRNGKey(42) +theta = jax.random.uniform(key, (512,)) * 2 * jnp.pi +real_data = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=1) +real_data = real_data + jax.random.normal(key, real_data.shape) * 0.05 + +params = init_params(jax.random.PRNGKey(0)) +d_grad = jax.grad(d_loss) +g_grad = jax.grad(g_loss) +lr = 0.001 + +snapshots = [] +for step in range(3000): + key, k1 = jax.random.split(key) + z = jax.random.normal(k1, (512, 2)) + fake_data = generator(z, params) + + # 更新判别器 + grads = d_grad(params, real_data, fake_data) + for k in ['d_w1', 'd_b1', 'd_w2', 'd_b2', 'd_w3', 'd_b3']: + params[k] = params[k] - lr * grads[k] + + # 更新生成器 + fake_data = generator(z, params) + grads = g_grad(params, fake_data) + for k in ['g_w1', 'g_b1', 'g_w2', 'g_b2', 'g_w3', 'g_b3']: + params[k] = params[k] - lr * grads[k] + + if step in [0, 500, 1500, 2999]: + snapshots.append((step, fake_data.copy())) + +fig, axes = plt.subplots(1, 4, figsize=(16, 4)) +for ax, (step, fake) in zip(axes, snapshots): + ax.scatter(real_data[:, 0], real_data[:, 1], s=5, alpha=0.3, c='#3498db', label='真实') + ax.scatter(fake[:, 0], fake[:, 1], s=5, alpha=0.3, c='#e74c3c', label='生成') + ax.set_title(f'步骤 {step}'); ax.set_xlim(-2, 2); ax.set_ylim(-2, 2) + ax.set_aspect('equal'); ax.legend(markerscale=3) +plt.suptitle('GAN训练:生成器学习环形分布') +plt.tight_layout(); plt.show() +``` + +3. 实现扩散前向过程:在不同时间步向图像添加噪声,并可视化逐步破坏过程。然后实现单步去噪。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def noise_schedule(T, beta_start=0.0001, beta_end=0.02): + """线性噪声调度。""" + betas = jnp.linspace(beta_start, beta_end, T) + alphas = 1.0 - betas + alpha_bars = jnp.cumprod(alphas) + return betas, alphas, alpha_bars + +def forward_diffusion(x0, t, alpha_bars, key): + """在时间步t向x0添加噪声。""" + alpha_bar_t = alpha_bars[t] + noise = jax.random.normal(key, x0.shape) + xt = jnp.sqrt(alpha_bar_t) * x0 + jnp.sqrt(1 - alpha_bar_t) * noise + return xt, noise + +# 创建简单的2D"图像"(棋盘格) +img = jnp.zeros((32, 32)) +for i in range(4): + for j in range(4): + if (i + j) % 2 == 0: + img = img.at[i*8:(i+1)*8, j*8:(j+1)*8].set(1.0) + +T = 1000 +betas, alphas, alpha_bars = noise_schedule(T) + +# 可视化前向过程 +timesteps = [0, 50, 200, 500, 999] +key = jax.random.PRNGKey(42) + +fig, axes = plt.subplots(1, len(timesteps), figsize=(16, 3.5)) +for ax, t in zip(axes, timesteps): + key, subkey = jax.random.split(key) + xt, noise = forward_diffusion(img, t, alpha_bars, subkey) + ax.imshow(xt, cmap='gray', vmin=-2, vmax=2) + ax.set_title(f't={t}\n$\\bar{{\\alpha}}$={alpha_bars[t]:.3f}') + ax.axis('off') +plt.suptitle('扩散前向过程:逐步添加噪声') +plt.tight_layout(); plt.show() + +# 简单去噪:训练小型网络在t=200时预测噪声 +t_denoise = 200 +key, k1 = jax.random.split(key) +xt, true_noise = forward_diffusion(img, t_denoise, alpha_bars, k1) + +# 小型"去噪器":仅学习恒定的噪声估计(用于演示) +noise_estimate = jnp.zeros_like(img) +lr = 0.01 +for step in range(100): + residual = noise_estimate - true_noise + noise_estimate = noise_estimate - lr * residual + +# 反向一步 +alpha_bar_t = alpha_bars[t_denoise] +x_denoised = (xt - jnp.sqrt(1 - alpha_bar_t) * noise_estimate) / jnp.sqrt(alpha_bar_t) + +fig, axes = plt.subplots(1, 3, figsize=(12, 4)) +axes[0].imshow(img, cmap='gray'); axes[0].set_title('原始 $x_0$'); axes[0].axis('off') +axes[1].imshow(xt, cmap='gray', vmin=-2, vmax=2) +axes[1].set_title(f'含噪 $x_{{200}}$'); axes[1].axis('off') +axes[2].imshow(x_denoised, cmap='gray') +axes[2].set_title('去噪后(单步)'); axes[2].axis('off') +plt.tight_layout(); plt.show() + +mse = jnp.mean((x_denoised - img)**2) +print(f"去噪MSE: {mse:.4f}") +``` diff --git a/chapter 08: computer vision/05. video and 3D vision.md b/chapter 08: computer vision/05. video and 3D vision.md new file mode 100644 index 0000000..fbed535 --- /dev/null +++ b/chapter 08: computer vision/05. video and 3D vision.md @@ -0,0 +1,347 @@ +# 视频与3D视觉 + +*视频与3D视觉将图像理解扩展到时间域和空间域。本文涵盖光流、视频分类(3D卷积网络、TimeSformer)、目标跟踪(SORT、DeepSORT)、动作识别、深度估计(单目与立体)、点云、神经辐射场(NeRF)和3D高斯泼溅。* + +- 文件01-04将图像视为孤立快照。但视觉世界是连续的:物体在运动,场景在变化,深度真实存在。本文将计算机视觉扩展到时间域(视频)和空间域(3D),涵盖模型如何理解运动、跟踪目标、估计深度和重建场景。 + +- **视频**是一系列随时间捕获的图像(帧)。以30帧/秒计算,一段10秒的片段包含300帧。关键挑战在于建模**时间维度**:物体如何运动,场景如何演变,以及如何跨帧关联信息。 + +- **光流**估计两帧连续图像之间像素的表观运动。对于帧$t$中的每个像素,光流产生一个二维位移向量$(u, v)$,指向该像素在帧$t+1$中的位置。结果是一个与图像大小相同的稠密运动场。 + +![两帧连续视频帧及其之间的光流场,以彩色箭头可视化显示像素运动方向和大小](../images/optical_flow.svg) + +- 光流在**亮度恒常性假设**下计算:像素的强度在其移动时不变。如果帧$t$中位置$(x, y)$处的像素强度为$I(x, y, t)$,并在小时间间隔$\delta t$内移动了$(u, v)$: + +$$I(x + u\delta t, \, y + v\delta t, \, t + \delta t) = I(x, y, t)$$ + +- 进行一阶泰勒展开(见第03章)并除以$\delta t$: + +$$I_x u + I_y v + I_t = 0$$ + +- 其中$I_x, I_y$是空间梯度(Sobel算子,见文件01),$I_t$是时间梯度(相邻帧的差值)。这就是**光流约束方程**。一个方程,两个未知数$(u, v)$:我们需要额外的约束条件。 + +- **Lucas-Kanade**假设光流在一个小窗口内(例如5x5像素)是恒定的。这给出了一个超定系统(25个方程,2个未知数),通过最小二乘法求解(第06章的正规方程): + +```math +\begin{bmatrix} u \\ v \end{bmatrix} = \begin{bmatrix} \sum I_x^2 & \sum I_x I_y \\ \sum I_x I_y & \sum I_y^2 \end{bmatrix}^{-1} \begin{bmatrix} -\sum I_x I_t \\ -\sum I_y I_t \end{bmatrix} +``` + +- 这个2x2矩阵就是文件01中的结构张量(与Harris角点检测中使用的矩阵相同)。Lucas-Kanade适用于小运动,但当物体在帧间移动超过几个像素时会失效。 + +- **Farneback方法**对每个像素邻域进行多项式展开,并估计最能解释帧间变化的位移场。它产生稠密光流(每个像素一个向量),能处理比Lucas-Kanade更大的运动。 + +- 现代**深度学习光流**方法(FlowNet、RAFT)学习从帧对端到端预测光流。**RAFT**(Recurrent All-Pairs Field Transforms,Teed和Deng,2020)计算两帧中所有像素对之间的4D相关体,并使用基于GRU的更新算子迭代优化光流估计。RAFT达到了最先进的精度,并已成为标准的光流骨干网络。 + +- **双流网络**(Simonyan和Zisserman,2014)是视频理解的早期方法。一个流处理单帧RGB图像(外观),另一个流处理光流帧的堆叠(运动)。两个流在末端融合(通过平均或拼接)。这种架构明确区分了"事物看起来像什么"与"它们如何运动"。 + +- **3D卷积网络**将2D卷积扩展到时间维度。3D卷积使用大小为$k \times k \times k_t$的滤波器,同时跨越空间和时间维度,直接学习时空特征。 + +- **C3D**(Tran等人,2015)堆叠了3x3x3滤波器的3D卷积,展示了时间卷积可以在没有显式光流的情况下学习运动特征。代价是高昂的:3D卷积的参数和计算量是其2D对应物的$k_t$倍。 + +- **I3D**(Inflated 3D,Carreira和Zisserman,2017)采用了一种更实用的方法:从预训练的2D CNN(如Inception或ResNet)开始,将所有2D滤波器沿时间维度"膨胀"为3D,重复权重并除以$k_t$。这将ImageNet预训练迁移到视频,同时增加了时间建模能力。一个2D的$k \times k$滤波器变为$k \times k \times k_t$的滤波器,初始化为$W_{\text{3D}}[:,:,j] = W_{\text{2D}} / k_t$,对所有时间位置$j$。 + +- **SlowFast网络**(Feichtenhofer等人,2019)使用两条并行的路径,以不同的时间分辨率运行: + - **Slow路径**以低帧率(例如每16帧)处理帧,具有高空间分辨率和更多通道,捕获精细的空间细节。 + - **Fast路径**以高帧率(每2帧)处理帧,空间分辨率降低且通道数较少(通常为Slow路径的$1/8$),捕获快速的时间变化。 + - 侧向连接通过步长卷积将信息从Fast融合到Slow。 + +- 其核心洞见是:空间和时间信息具有不同的带宽需求——物体外观变化缓慢,但运动可以很迅速。SlowFast通过设计匹配这种非对称性。 + +- **TimeSformer**(Bertasius等人,2021)将Vision Transformer应用于视频。它将完整的时空注意力(代价过高:$O((T \times N)^2)$,其中$T$为帧数,$N$为每帧的块数)分解为**分块注意力**:每个块在时间注意(每个块在相同空间位置跨时间进行注意力)和空间注意(每个块在同一帧内跨空间进行注意力)之间交替。这使代价从$O(T^2 N^2)$降低到$O(T^2 + N^2)$。 + +- **VideoMAE**(Tong等人,2022)将掩码自编码器思想(见文件04)扩展到视频。使用极高的掩码比例(90-95%),因为视频具有高度的时间冗余性:相邻帧看起来几乎相同,因此掩码大部分块后仍然留有足够的信息进行重建。VideoMAE在无标签视频上预训练ViT骨干网络,并迁移到下游任务。 + +- **动作识别**将视频片段分类为多种动作类别之一(例如"跑步"、"烹饪"、"弹吉他")。它是图像分类的视频对应任务。标准基准数据集包括Kinetics-400(400个动作类别,约30万个片段)、Something-Something(174个需要时间推理的细粒度动作)和ActivityNet(200个类别,包含长时未裁剪视频)。 + +- **时间动作检测**超越了分类:给定一个长段未裁剪的视频,找到每个动作的开始时间、结束时间和类别。这是目标检测的时间对应任务。ActionFormer等方法使用Transformer处理时间特征并预测动作边界。 + +- **视频目标跟踪**在第一帧识别出特定目标后,跨帧跟踪该目标。 + +- **SORT**(Simple Online and Realtime Tracking,Bewley等人,2016)将检测模型(独立检测每帧中的目标)与**卡尔曼滤波器**(用于运动预测)和**匈牙利算法**(用于分配)相结合。 + +- **卡尔曼滤波器**为每个跟踪的目标维护一个状态估计(位置、速度、大小),并使用线性运动模型预测它在下一帧中的位置。当新的检测结果到达时,卡尔曼滤波器通过结合预测值和观测值(按各自的不确定性加权)来更新其估计。这是贝叶斯更新(第05章)在跟踪中的应用。 + +- **匈牙利算法**解决双线性分配问题:给定$M$个已跟踪目标和$N$个新检测结果,找到使总代价最小化的最优一对一匹配(使用文件03中的IoU距离)。未匹配的检测结果开始新的轨迹;未匹配的轨迹在宽限期后被终止。 + +- **DeepSORT**通过添加**深度外观特征**扩展了SORT:每个检测到的目标经过一个小型CNN,产生一个外观嵌入(描述子向量)。匹配代价结合了IoU距离和嵌入空间中的余弦距离(第01章)。这处理了遮挡和重识别:即使一个目标在其他目标后消失数帧,其外观嵌入允许在重新出现时重新匹配。 + +- **ByteTrack**(Zhang等人,2022)通过使用所有检测结果(包括低置信度的)来改进跟踪。大多数跟踪器会丢弃低于置信度阈值的检测结果。ByteTrack首先将高置信度检测结果与现有轨迹匹配,然后将剩余的低置信度检测结果与未匹配的轨迹匹配。这恢复了暂时被遮挡或模糊(因此检测置信度低)的目标。 + +- **3D视觉**恢复在2D图像投影中丢失的第三个空间维度(见文件01)。 + +- **深度估计**预测从相机到场景中每个点的距离。 + +- **立体深度**使用两个相距基线距离$b$的相机。同一个点在左右图像中出现在不同的水平位置(这个偏移称为**视差**$d$)。深度与视差成反比: + +$$Z = \frac{f \cdot b}{d}$$ + +- 其中$f$是焦距,$b$是基线距离。计算视差需要找到两个图像之间的对应点(立体匹配),这是沿水平扫描线的一维搜索(因为相机水平对齐,3D中同一高度的点投影到两幅图像的同一行)。 + +- **单目深度估计**从单张图像预测深度,这本质上是病态问题(无限多个3D场景可以产生相同的2D图像)。然而,人类利用相对大小、纹理梯度、遮挡和大气雾霾等线索毫不费力地做到这一点。深度学习网络从训练数据中学习这些线索。 + +- **MiDaS**和**Depth Anything**等模型从单张图像预测相对深度图(排序哪些物体更近)。它们使用尺度不变损失在各种数据集上训练,尽管理论上存在歧义,但仍能产生非常准确的结果。 + +- **点云**是3D点$(x, y, z)$的集合,可选地带有颜色或其他属性,由LiDAR传感器或立体重建捕获。与图像不同,点云是无序且不规则间隔的。 + +- **PointNet**(Qi等人,2017)通过独立地对每个点应用共享MLP,然后使用最大池化聚合(这是置换不变的,解决了排序问题),直接处理点云。**PointNet++**增加了层次化分组,以捕获多尺度的局部结构。 + +- **神经辐射场(NeRF)**(Mildenhall等人,2020)将3D场景表示为一个连续函数,将3D位置$(x, y, z)$和视角方向$(\theta, \phi)$映射到颜色$(r, g, b)$和密度$\sigma$。该函数由一个MLP参数化: + +$$F_\theta: (x, y, z, \theta, \phi) \to (r, g, b, \sigma)$$ + +- 为了渲染一个像素,从相机穿过该像素向场景投射一条射线。沿射线采样点,MLP预测每个点的颜色和密度。像素颜色通过**体渲染**计算:沿射线按密度加权积分颜色: + +$$C(\mathbf{r}) = \int_{t_n}^{t_f} T(t) \cdot \sigma(\mathbf{r}(t)) \cdot \mathbf{c}(\mathbf{r}(t), \mathbf{d}) \, dt$$ + +- 其中$T(t) = \exp(-\int_{t_n}^{t} \sigma(\mathbf{r}(s)) \, ds)$是累积透射率(已吸收的光总量)。在实际中,该积分通过沿射线采样$N$个点并求和来近似: + +$$\hat{C} = \sum_{i=1}^{N} T_i \cdot (1 - \exp(-\sigma_i \delta_i)) \cdot c_i$$ + +- NeRF通过最小化渲染像素与一组带位姿照片的真实像素之间的MSE来训练。训练完成后,NeRF可以从任何相机位置渲染逼真的新视角。其局限性在于速度:渲染需要对MLP进行数百万次评估(每个像素每个采样点一次),这使得实时渲染变得困难。 + +- **3D高斯泼溅**(Kerbl等人,2023)通过将场景表示为3D高斯原语的集合(而非连续的体积函数)来解决NeRF的速度限制。每个高斯原语有一个3D位置(均值)、一个3D协方差矩阵(控制形状和朝向)、不透明度及颜色(表示为球谐函数以实现视角相关效果)。 + +- 渲染将每个3D高斯投影到图像平面(产生一个2D高斯"泼溅"),按深度排序,并使用alpha混合从前往后合成。这是一个在GPU上实时运行的栅格化过程(100+ FPS),比NeRF的射线步进快几个数量级。高斯泼溅达到或超过NeRF的质量,同时实现实时渲染。 + +- **SLAM**(同时定位与地图构建)是在未知环境中构建地图同时跟踪相机自身位置的问题。这是机器人、自动驾驶和AR的基础。 + +- **视觉里程计**通过跨图像跟踪特征来估计相机从一帧到另一帧的运动。特征点(SIFT、ORB,见文件01)在连续帧之间匹配,并利用这些匹配关系通过**本质矩阵**(编码两视图之间的几何关系,由文件01的内参和外参推导)估计相机的旋转和平移。 + +- **基于特征的SLAM**通过维护持久地图来扩展视觉里程计。**ORB-SLAM**(Mur-Artal等人,2015)是使用最广泛的基于特征的SLAM系统。它有三个并行线程: + 1. **跟踪**:将每帧中的ORB特征与地图匹配,使用PnP(Perspective-n-Point)和RANSAC估计相机位姿。 + 2. **局部建图**:从匹配的特征三角化新的地图点,通过光束法平差(最小化所有观察到每个点的视图的重投影误差)优化其位置。 + 3. **闭环检测**:检测相机何时重新访问先前建图的区域(使用视觉词袋),然后通过全局优化地图来校正累积漂移。 + +- **LiDAR SLAM**使用来自LiDAR传感器的3D点云替代(或补充)相机图像。LiDAR提供直接的深度测量,使几何估计更鲁棒,但硬件成本更高。LOAM(LiDAR Odometry and Mapping)等方法使用迭代最近点(ICP)配准来对齐连续扫描之间的点云。 + +- **视觉-惯性SLAM**融合相机数据与IMU(加速度计+陀螺仪)的测量结果。IMU提供高频的旋转和加速度估计,弥补相机帧之间的间隙,并处理快速运动或临时视觉特征丢失的情况。 + +- **VR/AR**应用是计算机视觉最苛刻的消费者之一。 + +- **姿态估计**从图像中确定人体(或面部、手部)的位置和朝向。**身体姿态**通常表示为一组2D或3D关键点位置(关节点:肩膀、肘部、手腕、髋部、膝盖、脚踝)。**OpenPose**和**MediaPipe**等模型使用热图回归预测这些关键点:对于每个关节点,模型输出一个热图,其中峰值指示关节点的位置。 + +- **自上而下**的方法首先使用边界框检测器(见文件03)检测人物,然后在每个框内估计姿态。**自下而上**的方法首先检测图像中的所有关键点,然后使用部位亲和场(编码连接关节点之间关联的向量场)将它们分组为个体。 + +- **场景重建**从传感器数据构建环境的3D模型。在AR中,这使得可以将虚拟物体放置在真实表面上、遮挡真实物体后面的虚拟物体以及投射虚拟阴影。实时场景重建方法(如ARKit和ARCore中基于深度传感器的系统)构建环境的稀疏网格,并随着用户移动而更新。 + +- **VR中的实时渲染**约束极为苛刻:双眼需要独立渲染90+ FPS(以避免晕动症),从头部位移到显示更新的延迟需低于20毫秒。**注视点渲染**(仅渲染用户注视位置的高分辨率,使用眼动追踪)和**重投影**(基于新头部位姿扭曲上一帧以填补下一帧渲染间隙)等技术对于满足这些约束至关重要。 + +- 实时神经渲染(3D高斯泼溅)、鲁棒跟踪(视觉-惯性SLAM)和高效姿态估计的融合,正使逼真的交互式AR/VR体验变得越来越可行。 + +## 编程任务(使用CoLab或Notebook) + +1. 从头实现Lucas-Kanade光流算法。计算一个方块向右移动的两帧合成图像之间的光流。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def lucas_kanade(frame1, frame2, window_size=5): + """Lucas-Kanade光流。""" + # 计算梯度 + Ix = jnp.zeros_like(frame1) + Iy = jnp.zeros_like(frame1) + It = frame2 - frame1 + + # Sobel风格梯度 + Ix = Ix.at[1:-1, :].set((frame1[2:, :] - frame1[:-2, :]) / 2) + Iy = Iy.at[:, 1:-1].set((frame1[:, 2:] - frame1[:, :-2]) / 2) + + H, W = frame1.shape + half_w = window_size // 2 + u = jnp.zeros_like(frame1) + v = jnp.zeros_like(frame1) + + for i in range(half_w, H - half_w): + for j in range(half_w, W - half_w): + Ix_win = Ix[i-half_w:i+half_w+1, j-half_w:j+half_w+1].ravel() + Iy_win = Iy[i-half_w:i+half_w+1, j-half_w:j+half_w+1].ravel() + It_win = It[i-half_w:i+half_w+1, j-half_w:j+half_w+1].ravel() + + A = jnp.stack([Ix_win, Iy_win], axis=1) + ATA = A.T @ A + ATb = -A.T @ It_win + + # 检查系统是否良态 + det = ATA[0,0] * ATA[1,1] - ATA[0,1] * ATA[1,0] + if jnp.abs(det) > 1e-6: + flow = jnp.linalg.solve(ATA, ATb) + u = u.at[i, j].set(flow[0]) + v = v.at[i, j].set(flow[1]) + + return u, v + +# 创建两帧:一个向右移动的白色方块 +frame1 = jnp.zeros((64, 64)) +frame1 = frame1.at[20:40, 15:35].set(1.0) + +frame2 = jnp.zeros((64, 64)) +frame2 = frame2.at[20:40, 20:40].set(1.0) # 向右移动5个像素 + +u, v = lucas_kanade(frame1, frame2, window_size=7) + +# 可视化 +fig, axes = plt.subplots(1, 3, figsize=(14, 4)) +axes[0].imshow(frame1, cmap='gray'); axes[0].set_title('帧1'); axes[0].axis('off') +axes[1].imshow(frame2, cmap='gray'); axes[1].set_title('帧2'); axes[1].axis('off') + +# 光流的箭矢图(为清晰起见降采样) +step = 4 +Y, X = jnp.mgrid[0:64:step, 0:64:step] +axes[2].imshow(frame1, cmap='gray', alpha=0.5) +axes[2].quiver(X, Y, u[::step, ::step], v[::step, ::step], + color='#e74c3c', scale=50, width=0.005) +axes[2].set_title('光流'); axes[2].axis('off') + +plt.tight_layout(); plt.show() + +# 检查运动区域的平均光流 +region_u = u[20:40, 15:35] +print(f"物体区域的平均水平光流: {region_u[region_u != 0].mean():.2f} 像素") +``` + +2. 实现一个用于2D目标跟踪的简单卡尔曼滤波器。模拟一个带噪声的轨迹,并展示卡尔曼滤波器如何平滑估计。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def kalman_predict(x, P, F, Q): + """卡尔曼滤波器预测步骤。""" + x_pred = F @ x + P_pred = F @ P @ F.T + Q + return x_pred, P_pred + +def kalman_update(x_pred, P_pred, z, H, R): + """卡尔曼滤波器更新步骤。""" + y = z - H @ x_pred # 创新 + S = H @ P_pred @ H.T + R # 创新协方差 + K = P_pred @ H.T @ jnp.linalg.inv(S) # 卡尔曼增益 + x_updated = x_pred + K @ y + P_updated = (jnp.eye(len(x_pred)) - K @ H) @ P_pred + return x_updated, P_updated + +# 状态: [x, y, vx, vy] +dt = 1.0 +F = jnp.array([[1, 0, dt, 0], # 状态转移 + [0, 1, 0, dt], + [0, 0, 1, 0], + [0, 0, 0, 1]]) +H = jnp.array([[1, 0, 0, 0], # 观测:测量 x, y + [0, 1, 0, 0]]) +Q = jnp.eye(4) * 0.01 # 过程噪声 +R = jnp.eye(2) * 4.0 # 测量噪声(有噪声的检测器) + +# 模拟真实轨迹:圆周运动 +n_steps = 50 +t = jnp.linspace(0, 2 * jnp.pi, n_steps) +true_x = 10 * jnp.cos(t) + 20 +true_y = 10 * jnp.sin(t) + 20 + +# 带噪声的观测 +key = jax.random.PRNGKey(42) +noise = jax.random.normal(key, (n_steps, 2)) * 2.0 +obs_x = true_x + noise[:, 0] +obs_y = true_y + noise[:, 1] + +# 运行卡尔曼滤波器 +x = jnp.array([obs_x[0], obs_y[0], 0.0, 0.0]) # 初始状态 +P = jnp.eye(4) * 10.0 # 初始不确定性 + +kalman_x, kalman_y = [], [] +for i in range(n_steps): + x, P = kalman_predict(x, P, F, Q) + z = jnp.array([obs_x[i], obs_y[i]]) + x, P = kalman_update(x, P, z, H, R) + kalman_x.append(x[0]) + kalman_y.append(x[1]) + +kalman_x = jnp.array(kalman_x) +kalman_y = jnp.array(kalman_y) + +# 可视化 +plt.figure(figsize=(8, 8)) +plt.plot(true_x, true_y, 'k-', linewidth=2, label='真实轨迹') +plt.scatter(obs_x, obs_y, c='#e74c3c', s=20, alpha=0.5, label='带噪声的观测') +plt.plot(kalman_x, kalman_y, '#3498db', linewidth=2, label='卡尔曼滤波') +plt.legend(); plt.grid(alpha=0.3) +plt.title('卡尔曼滤波跟踪') +plt.xlabel('x'); plt.ylabel('y') +plt.axis('equal'); plt.show() + +obs_error = jnp.mean(jnp.sqrt((obs_x - true_x)**2 + (obs_y - true_y)**2)) +kalman_error = jnp.mean(jnp.sqrt((kalman_x - true_x)**2 + (kalman_y - true_y)**2)) +print(f"观测RMSE: {obs_error:.2f}") +print(f"卡尔曼滤波RMSE: {kalman_error:.2f}") +print(f"误差降低: {(1 - kalman_error/obs_error) * 100:.1f}%") +``` + +3. 实现一个简化的NeRF风格体渲染管线。通过一个简单的3D场景(已知颜色和密度的球体)投射射线,并沿每条射线积分来渲染图像。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def render_ray(origin, direction, spheres, n_samples=64, t_near=1.0, t_far=6.0): + """穿过球体场景对单条射线进行体渲染。""" + t_vals = jnp.linspace(t_near, t_far, n_samples) + deltas = jnp.concatenate([jnp.diff(t_vals), jnp.array([1e-3])]) + + colour = jnp.zeros(3) + transmittance = 1.0 + + for i in range(n_samples): + point = origin + t_vals[i] * direction + + # 计算该点的密度和颜色 + density = 0.0 + point_colour = jnp.zeros(3) + + for center, radius, col, sigma in spheres: + dist = jnp.linalg.norm(point - center) + # 软球体:密度随距表面的距离指数衰减 + d = jnp.exp(-jnp.maximum(0, dist - radius) * sigma) * sigma + density += d + point_colour += d * jnp.array(col) + + # 按总密度归一化颜色 + point_colour = jnp.where(density > 1e-6, point_colour / density, point_colour) + + # 体渲染方程 + alpha = 1.0 - jnp.exp(-density * deltas[i]) + colour += transmittance * alpha * point_colour + transmittance *= (1.0 - alpha) + + return colour + +# 场景:三个彩色球体 +spheres = [ + (jnp.array([0.0, 0.0, 4.0]), 0.8, [1.0, 0.2, 0.2], 5.0), # 红色 + (jnp.array([1.5, 0.5, 5.0]), 0.6, [0.2, 1.0, 0.2], 5.0), # 绿色 + (jnp.array([-1.0, -0.5, 3.5]), 0.5, [0.2, 0.2, 1.0], 5.0), # 蓝色 +] + +# 相机设置 +img_h, img_w = 64, 64 +focal = 60.0 +origin = jnp.array([0.0, 0.0, 0.0]) + +image = jnp.zeros((img_h, img_w, 3)) +for i in range(img_h): + for j in range(img_w): + # 计算射线方向 + px = (j - img_w / 2) / focal + py = -(i - img_h / 2) / focal + direction = jnp.array([px, py, 1.0]) + direction = direction / jnp.linalg.norm(direction) + + colour = render_ray(origin, direction, spheres) + image = image.at[i, j].set(jnp.clip(colour, 0, 1)) + +plt.figure(figsize=(6, 6)) +plt.imshow(image) +plt.title('NeRF风格体渲染\n(3个球体)') +plt.axis('off') +plt.tight_layout(); plt.show() +print(f"图像形状: {image.shape}") +print(f"渲染了 {img_h * img_w} 条射线,每条 {64} 个采样点") +``` diff --git a/chapter 09: audio and speech/01. digital signal processing.md b/chapter 09: audio and speech/01. digital signal processing.md new file mode 100644 index 0000000..4a5a4a1 --- /dev/null +++ b/chapter 09: audio and speech/01. digital signal processing.md @@ -0,0 +1,498 @@ +# 数字信号处理 + +*数字信号处理将原始音频波形转换为结构化表示,机器学习模型可以从中学习。本文涵盖声音物理学、采样与量化、傅里叶变换(DFT、FFT)、语谱图、梅尔滤波器组、MFCC 和加窗,以及所有语音和音频 AI 所需的特征提取流水线。* + +- **声音**是一种通过介质(空气、水、固体)传播的压力波。振动物体(声带、吉他弦、扬声器纸盆)推拉空气分子,产生交替的高压区域(压缩)和低压区域(稀疏)。 + +- 这些压力变化以大约 343 m/s 的速度在空气中向外传播,到达你的耳朵后,使耳膜振动并转换为神经信号。 + +- 可以把声音想象成向平静的水面投下一块石头:石头是振动源,涟漪是压力波,水面漂浮的软木塞就是麦克风或耳膜,它响应着波的到来。 + +- 软木塞上下浮动的幅度是**振幅**,每秒浮动的次数是**频率**,波到达时软木塞是处于浮动的最高点还是最低点则是**相位**。 + +- **波形**是压力(或电压,在麦克风将声音转换为电信号后)随时间变化的曲线图。最简单的波形是**纯音**,即单一正弦波: + +$$x(t) = A \sin(2\pi f t + \phi)$$ + +- 其中: + - $A$ 是振幅(偏离零点的最大偏差,决定响度), + - $f$ 是以 Hz 为单位的频率(每秒周期数,决定音高), + - $\phi$ 是以弧度为单位的相位(波的时间偏移)。 + +- **周期** $T = 1/f$,是一个完整周期持续的时长。 + +![标注了振幅、周期、频率和相位偏移的正弦波](../images/audio_waveform.svg) + +- **振幅**决定了感知到的响度。振幅加倍,功率变为四倍(因为功率与振幅的平方成正比)。 + +- 人耳的听觉范围覆盖极大的振幅跨度,因此我们使用对数刻度:**分贝**(dB)。声压级的计算方式为: + +$$L = 20 \log_{10}\left(\frac{A}{A_\text{ref}}\right) \text{ dB}$$ + +- 其中 $A_\text{ref}$ 是参考振幅(通常取听阈,$20 \mu\text{Pa}$)。耳语约为 30 dB,正常对话 60 dB,摇滚音乐会 110 dB。每增加 6 dB,振幅大约翻倍;每增加 10 dB,感知响度大约翻倍。此处的对数与第 03 章中的对数函数相同。 + +- **频率**决定音高。低频(20–250 Hz)听起来低沉;高频(2000–20000 Hz)听起来尖锐。人耳听觉范围大致为 20 Hz 到 20 kHz。音乐会标准音 A 为 440 Hz。频率加倍,音高升高一个**八度**。 + +- 大多数自然声音不是纯音,而是许多频率的复杂混合——这就是为什么钢琴和小提琴演奏同一个音符时听起来不同:它们共享相同的**基频**,但**谐波**(基频的整数倍)及其相对振幅(**音色**)不同。 + +- **相位**决定了波从其周期中的哪个起点开始。两个振幅和频率相同但相位不同的波可以发生相长干涉(相位对齐,振幅相加)或相消干涉(相位相反,振幅抵消)。 + +- 相位在立体声音频和波束成形中至关重要,但在许多语音处理流水线中基本上被丢弃,因为人类对音高和音色的感知大多与相位无关。 + +- 现实世界的音频信号是时间的**连续**函数,但计算机处理的是离散数值。**采样**通过以固定间隔测量信号值,将连续信号转换为离散序列。 + +- **采样率** $f_s$ 是每秒的测量次数。CD 音频使用 $f_s = 44{,}100$ Hz;电话通信使用 8000 Hz;现代语音模型通常使用 16000 Hz。 + +- **奈奎斯特-香农采样定理**指出:当且仅当采样率至少是信号中最高频率的两倍时,连续信号才能从其样本中完美重建: + +$$f_s \geq 2 f_\text{max}$$ + +- 频率 $f_s / 2$ 称为**奈奎斯特频率**。如果信号中包含高于奈奎斯特频率的频率成分,这些频率会折叠回有效范围内,表现为虚假的低频成分。这种现象称为**混叠**。混叠是不可逆的:一旦发生,就无法从样本中恢复原始信号。 + +- 混叠的日常类比是电影中的马车轮效应:车轮转速刚好高于帧率时,看起来像是在缓慢地倒转,因为摄像机对旋转的采样不足。在音频中,一个 15 kHz 的音调以 16 kHz 采样($f_\text{奈奎斯特} = 8$ kHz)时,会混叠为 $16 - 15 = 1$ kHz,一个完全不同的音高。 + +![信号的正确采样与因采样率过低导致的混叠采样对比](../images/sampling_aliasing.svg) + +- 为防止混叠,**抗混叠滤波器**(一个低通滤波器)在采样前滤除所有高于 $f_s/2$ 的频率。这一步由模数转换器(ADC)硬件在信号数字化之前完成。 + +- **量化**将每个连续取值的样本映射到有限电平集合中的最近值。一个 $n$ 位量化器有 $2^n$ 个电平。CD 音频使用 16 位量化($2^{16} = 65{,}536$ 个电平);电话通信通常使用 8 位配合 $\mu$ 律或 A 律**压扩**(一种非线性映射,为小振幅分配更多电平,以匹配人类感知)。量化会引入**量化噪声**,这是一种舍入误差,其方差为 $\Delta^2/12$,其中 $\Delta$ 是相邻电平之间的步长。 + +- **时域分析**直接从波形中提取特征,无需变换到其他域。这些特征简单、计算快速,能够捕捉信号的基本性质。 + +- **能量**衡量一帧(共 $N$ 个样本)的整体响度: + +$$E = \sum_{n=0}^{N-1} x[n]^2$$ + +- 语音段能量高;静音段能量低。能量是第 01 章中平方 $\ell_2$ 范数在信号向量上的应用。 + +- **过零率**(ZCR)统计一帧内信号改变符号的次数: + +$$\text{ZCR} = \frac{1}{2(N-1)} \sum_{n=1}^{N-1} |\text{sign}(x[n]) - \text{sign}(x[n-1])|$$ + +- 高 ZCR 表明高频成分或噪声;低 ZCR 表明低频成分或浊音(声带周期性振动时)。ZCR 是一种粗略的频率估计方法:一个 $f$ Hz 的纯音每秒过零 $2f$ 次。 + +- **自相关**衡量信号与其延迟副本之间的相似度: + +$$R[k] = \sum_{n=0}^{N-1-k} x[n] \cdot x[n+k]$$ + +- 在延迟 $k = 0$ 处,自相关等于能量。对于周期信号,自相关在等于周期及其整数倍的延迟处出现峰值。这是**基音检测**的标准技术:找出 $R[k]$ 在 $k=0$ 之后的第一个显著峰值,则基音频率为 $f_s / k_\text{峰值}$。自相关与第 01 章的点积相关:$R[k]$ 是信号与其 $k$ 位移版本的点积。 + +- **频域分析**揭示信号的频谱内容,这些信息在波形中不可见。核心工具是**离散傅里叶变换**(DFT),它将 $N$ 个样本的信号分解为 $N$ 个复数值的频率分量: + +$$X[k] = \sum_{n=0}^{N-1} x[n] \cdot e^{-j 2\pi k n / N}, \quad k = 0, 1, \ldots, N-1$$ + +- 每个 $X[k]$ 是一个复数,其幅度 $|X[k]|$ 给出频率 $f_k = k \cdot f_s / N$ Hz 处的振幅,相位 $\angle X[k]$ 给出相位偏移。DFT 是从时域基(单位脉冲)到频域基(复指数)的基变换,这是第 02 章基概念的直接应用。DFT 可以写为矩阵乘法 $\mathbf{X} = W \mathbf{x}$,其中 $W$ 是 $N \times N$ 的 DFT 矩阵,其元素为 $W_{kn} = e^{-j2\pi kn/N}$。 + +- **快速傅里叶变换**(FFT)是一种以 $O(N \log N)$ 次运算计算 DFT 的算法(而非朴素的 $O(N^2)$),其原理是将问题递归地拆分为偶数索引和奇数索引的子问题(库利-图基算法)。这种加速使得实时频谱分析成为可能。FFT 是整个计算领域最重要的算法之一。 + +- **功率谱** $|X[k]|^2$ 显示能量在各频率上的分布。**幅度谱** $|X[k]|$ 显示振幅。绘制这些谱图可以揭示哪些频率主导了信号:元音在基频的整数倍处有强谐波;擦音(如"s")在宽高频范围内有能量分布。 + +- **语谱图**是信号频率内容随时间变化的可视化表示。它是将信号切分成短的、重叠的帧,对每帧计算 FFT,然后将得到的幅度谱并排放置。横轴是时间,纵轴是频率,每个点的颜色(或亮度)代表幅度。语谱图是音频处理中最重要的单一可视化工具。 + +![语谱图,横轴为时间,纵轴为频率,颜色代表强度](../images/spectrogram_stft.svg) + +- **梅尔刻度**是一种感知频率刻度,反映人类对音高的感知方式。人类将频率的等比率感知为音高的等间隔(正如我们将强度的等比率感知为响度的等间隔)。在约 1000 Hz 以下,梅尔刻度近似线性;在 1000 Hz 以上,它变为近似对数: + +$$m = 2595 \log_{10}\left(1 + \frac{f}{700}\right)$$ + +- 其逆变换为 $f = 700(10^{m/2595} - 1)$。梅尔刻度解释了为什么音乐中的半音在对数频率轴上等间距排列:A4(440 Hz)到 A5(880 Hz)和 A5 到 A6(1760 Hz)听起来都是"向上一个八度",尽管以 Hz 为单位的间隔分别是 440 和 880。 + +- **梅尔滤波器组**是一组在梅尔刻度上均匀分布的三角形带通滤波器。每个滤波器覆盖一个频带,对该频带内的频谱能量进行求和,产生一个数值。典型的语音系统使用 40–80 个梅尔滤波器。低频滤波器窄(在人类感知敏感的频率分辨率高的区域),高频滤波器宽(在人类不敏感的低分辨率区域)。这模仿了人耳耳蜗的频率分辨率。 + +![在频率轴上叠加显示的三角形梅尔刻度滤波器组,低频处滤波器窄、高频处滤波器宽](../images/mel_filterbank.svg) + +- **梅尔频率倒谱系数**(MFCC)是语音和音频的经典特征表示。它们将梅尔谱压缩为少量去相关化的系数,捕捉谱包络的形状(编码声道配置,从而编码语音身份),同时丢弃精细的谱细节(编码音高和相位)。 + +- MFCC 流水线: + 1. **预加重**:应用一阶高通滤波器 $y[n] = x[n] - \alpha x[n-1]$(通常 $\alpha = 0.97$)以提升被声道衰减的高频成分。 + 2. **分帧**:将信号切分为重叠的帧(通常 25 ms 长,步进 10 ms)。 + 3. **加窗**:对每帧乘以窗口函数(汉明窗)以减少频谱泄漏(见下文)。 + 4. **FFT**:计算每帧加窗后的功率谱。 + 5. **梅尔滤波器组**:对功率谱应用三角形梅尔滤波器组,得到梅尔频带能量。 + 6. **对数**:对梅尔频带能量取对数。对数压缩动态范围,并将乘法(频谱分量之间)转换为加法,匹配人类响度感知。 + 7. **DCT**:对对数梅尔能量应用离散余弦变换。DCT 对梅尔频带进行去相关化(因为相邻频带高度相关)并将能量压缩到前几个系数中。保留前 13 个系数(MFCC-0 至 MFCC-12)。 + +![MFCC 流水线,从原始音频经加窗帧、FFT、梅尔滤波器组、对数压缩和 DCT 到最终 MFCC 特征](../images/mfcc_pipeline.svg) + +- 第 7 步中的 DCT 本质上是"频谱的傅里叶变换"(因此得名**倒谱** cepstrum = spectrum 的字母重排)。低阶倒谱系数捕捉宽泛的谱形状(声道谐振,称为**共振峰**),而高阶系数捕捉精细的谱细节(音高谐波)。通过只保留前 13 个系数,我们保留了共振峰信息并丢弃了音高细节。 + +- **Delta** 和 **delta-delta** MFCC(MFCC 的一阶和二阶时间导数,通过相邻帧之间的有限差分计算)捕捉谱形状的动态变化,增加时间上下文。完整的 MFCC 特征向量通常是 39 维的:13 个静态 + 13 个 delta + 13 个 delta-delta。 + +- 现代神经网络模型(第 06 章)已在很大程度上用学习到的特征取代了 MFCC:对数梅尔语谱图(第 6 步的输出,跳过 DCT)是深度学习 ASR 和音频分类的标准输入。模型学习自己的去相关化。尽管如此,MFCC 在低资源场景、经典 ML 流水线以及理解信号处理基础方面仍然很重要。 + +- **加窗**是在计算 FFT 之前对信号帧乘以平滑窗口函数的过程。不加窗时,FFT 假设帧无限重复;帧的突然开始和结束会创建人工的不连续性,使能量扩散到所有频率,这种伪影称为**频谱泄漏**。 + +- **矩形窗** $w[n] = 1$ 对所有 $n$:无渐减,泄漏最大,但主瓣最宽(在给定帧长下频率分辨率最佳)。实践中很少使用。 + +- **汉明窗**:$w[n] = 0.54 - 0.46 \cos(2\pi n / (N-1))$。在边缘处渐减到接近零,大大减少泄漏。是语音处理的标准选择。 + +- **汉宁窗**(也称为 Hanning 窗):$w[n] = 0.5 - 0.5 \cos(2\pi n / (N-1))$。在边缘处精确渐减到零。与汉明窗非常相似,但旁瓣抑制略好。 + +- **布莱克曼窗**:$w[n] = 0.42 - 0.5 \cos(2\pi n / (N-1)) + 0.08 \cos(4\pi n / (N-1))$。旁瓣抑制更好,但主瓣更宽(频率分辨率更差)。当旁瓣伪影特别严重时使用。 + +- 存在一个根本性的权衡:泄漏越少的窗口,主瓣越宽,意味着它们无法分辨两个间隔很近的频率。这就是**频谱分辨率与泄漏的权衡**,是第 03 章不确定原理的结果。 + +- **重叠相加**(OLA)是一种从加窗、处理后的帧重建信号的技术。帧之间有重叠(通常 50–75%),处理后将加窗后的输出相加。如果窗口和重叠选择得当(例如,汉宁窗配合 50% 重叠),重叠的窗口相加为常数,可实现完美重建。这对任何基于帧的音频修改(降噪、变调、变速)都至关重要。 + +- **短时傅里叶变换**(STFT)是语谱图背后的正式框架。它对信号的每个加窗帧应用 DFT: + +```math +\text{STFT}\{x[n]\}(m, k) = \sum_{n=0}^{N-1} x[n + mH] \cdot w[n] \cdot e^{-j 2\pi k n / N} +``` + +- 其中 $m$ 是帧索引,$H$ 是步进大小(连续帧之间的样本数),$w[n]$ 是窗口函数,$N$ 是 FFT 大小。输出是一个二维复数值矩阵:信号的**时频表示**。 + +- STFT 体现了根本的**时频权衡**: + - 长帧(大 $N$):频率分辨率高(能区分间隔很近的频率),但时间分辨率差(无法精确定位频率何时变化)。 + - 短帧(小 $N$):时间分辨率高,但频率分辨率差。 + - 时间分辨率和频率分辨率的乘积有下界:$\Delta t \cdot \Delta f \geq \frac{1}{4\pi}$。这是**加伯极限**,是物理中海森堡不确定原理在信号处理中的类比。 + +- 典型语音 STFT 参数:25 ms 帧长(在 16 kHz 下 $N = 400$),10 ms 步进($H = 160$),汉明窗,512 点 FFT(从 400 进行零填充以提高效率和频谱插值平滑度)。 + +- **滤波**通过放大某些频率和衰减其他频率来修改信号的频率内容。**滤波器**是一个接受输入信号并产生输出信号的系统。滤波器由其**频率响应** $H(f)$ 表征,它描述了每个频率上所施加的增益和相位偏移。 + +- **低通滤波器**:通过低于截止频率 $f_c$ 的频率,衰减高于 $f_c$ 的频率。用于去除高频噪声和细节。采样前的抗混叠滤波器就是低通滤波器。 + +- **高通滤波器**:通过高于 $f_c$ 的频率,衰减低于 $f_c$ 的频率。用于去除低频隆隆声和直流偏移。MFCC 提取中的预加重滤波器($y[n] = x[n] - 0.97 x[n-1]$)就是一个简单的高通滤波器。 + +- **带通滤波器**:通过范围 $[f_1, f_2]$ 内的频率,衰减范围外的频率。梅尔滤波器组中的每个三角形就是一个带通滤波器。 + +- **带阻(陷波)滤波器**:衰减特定的窄频范围。用于去除特定干扰(例如 50/60 Hz 的电源线嗡嗡声)。 + +- **有限冲激响应**(FIR)滤波器将每个输出样本计算为当前和过去输入样本的加权和: + +$$y[n] = \sum_{k=0}^{M} b_k \cdot x[n-k]$$ + +- 权重 $b_k$ 是**滤波器系数**(也称为**抽头**)。滤波器的阶数为 $M$。FIR 滤波器始终稳定(输出不会发散),并且可以设计为具有完美的线性相位(所有频率的延迟相同,从而保持波形形状)。其缺点是实现陡峭的截止需要大量抽头(高 $M$),增加了计算量。输出是输入与系数向量的卷积,正是第 06 章中的一维卷积运算。 + +- **无限冲激响应**(IIR)滤波器使用反馈:输出既依赖于过去的输入,也依赖于过去的输出: + +```math +y[n] = \sum_{k=0}^{M} b_k \cdot x[n-k] - \sum_{k=1}^{L} a_k \cdot y[n-k] +``` + +- 反馈项 $a_k$ 创建了一个递归结构,其冲激响应理论上持续无限长。IIR 滤波器可以用比 FIR 滤波器少得多的系数实现陡峭的截止,但可能不稳定(如果传递函数的极点位于单位圆之外,输出将无界增长——这是 $z$ 变换中的概念)。它们还具有非线性相位,可能使波形形状失真。经典滤波器设计(巴特沃斯、切比雪夫、椭圆滤波器)都是 IIR 的。 + +- **传递函数**通过 $z$ 变换获得: + +$$H(z) = \frac{\sum_{k=0}^{M} b_k z^{-k}}{1 + \sum_{k=1}^{L} a_k z^{-k}}$$ + +- 分子的根称为**零点**,分母的根称为**极点**。极零点图完全刻画了滤波器的行为。单位圆附近的极点放大附近的频率;单位圆附近的零点衰减它们。FIR 滤波器只有零点(分母为 1)。这与第 02 章和第 03 章中的特征值和求根概念相联系。 + +- **卷积定理**:时域中的卷积等于频域中的逐元素乘法。这意味着滤波既可以通过将信号与滤波器的冲激响应直接卷积来实现,也可以通过将它们的傅里叶变换相乘再逆变换来实现。对于长滤波器,频域方法(使用 FFT)更快:$O(N \log N)$ 对比 $O(NM)$。 + +- **逆 STFT**(iSTFT)从其 STFT 表示重建时域信号。这对于任何在频域中修改音频的系统(降噪、源分离、语音转换)都至关重要。重建使用重叠相加: + +```math +x[n] = \frac{\sum_{m} w[n - mH] \cdot \text{IDFT}\{X(m, k)\}[n - mH]}{\sum_{m} w[n - mH]^2} +``` + +- 分母对窗口重叠进行归一化,确保当合成窗口与分析窗口匹配且重叠足够时实现完美重建。 + +- **语音 DSP 流水线总结**:原始音频以 16 kHz 采样、预加重、切分为 25 ms 的汉明窗帧(步进 10 ms),每帧进行 FFT 变换,通过梅尔滤波器组,进行对数压缩,然后要么保留为对数梅尔特征(用于神经网络模型),要么进行 DCT 变换生成 MFCC(用于经典模型)。整个流水线将一维时域信号转换为适合下游机器学习的二维时频表示,这将是文件 02 的主题。 + +## 编程练习(在 CoLab 或 notebook 中完成) + +1. 生成一个正弦波,以不同采样率采样,演示混叠现象。绘制连续信号、正确采样版本和欠采样(混叠)版本的对比图。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 参数 +f_signal = 5.0 # 5 Hz 信号 +duration = 1.0 # 1 秒 + +# "连续"信号(非常高的采样率) +t_cont = jnp.linspace(0, duration, 10000) +x_cont = jnp.sin(2 * jnp.pi * f_signal * t_cont) + +# 正确采样(fs = 50 Hz,远高于奈奎斯特频率 10 Hz) +fs_good = 50 +t_good = jnp.arange(0, duration, 1.0 / fs_good) +x_good = jnp.sin(2 * jnp.pi * f_signal * t_good) + +# 欠采样(fs = 7 Hz,低于奈奎斯特频率 10 Hz)-> 混叠 +fs_bad = 7 +t_bad = jnp.arange(0, duration, 1.0 / fs_bad) +x_bad = jnp.sin(2 * jnp.pi * f_signal * t_bad) + +# 混叠后的频率:|f_signal - fs_bad| = |5 - 7| = 2 Hz +f_alias = abs(f_signal - fs_bad) +x_alias_cont = jnp.sin(2 * jnp.pi * f_alias * t_cont) + +fig, axes = plt.subplots(3, 1, figsize=(12, 9)) + +# 图 1:原始信号 +axes[0].plot(t_cont, x_cont, color='#3498db', linewidth=1.5, label=f'原始 {f_signal} Hz 信号') +axes[0].set_title(f'原始 {f_signal} Hz 信号') +axes[0].set_xlabel('时间 (s)'); axes[0].set_ylabel('振幅') +axes[0].legend(); axes[0].grid(True, alpha=0.3) + +# 图 2:正确采样 +axes[1].plot(t_cont, x_cont, color='#3498db', linewidth=1, alpha=0.4, label='原始信号') +axes[1].stem(t_good, x_good, linefmt='#27ae60', markerfmt='o', basefmt='k-', + label=f'以 {fs_good} Hz 采样(高于奈奎斯特频率)') +axes[1].set_title(f'正确采样:fs = {fs_good} Hz > 2 x {f_signal} Hz') +axes[1].set_xlabel('时间 (s)'); axes[1].set_ylabel('振幅') +axes[1].legend(); axes[1].grid(True, alpha=0.3) + +# 图 3:混叠采样 +axes[2].plot(t_cont, x_cont, color='#3498db', linewidth=1, alpha=0.4, label='原始信号') +axes[2].stem(t_bad, x_bad, linefmt='#e74c3c', markerfmt='o', basefmt='k-', + label=f'以 {fs_bad} Hz 采样(低于奈奎斯特频率)') +axes[2].plot(t_cont, x_alias_cont, color='#f39c12', linewidth=1.5, linestyle='--', + label=f'混叠信号表现为 {f_alias} Hz') +axes[2].set_title(f'混叠采样:fs = {fs_bad} Hz < 2 x {f_signal} Hz') +axes[2].set_xlabel('时间 (s)'); axes[2].set_ylabel('振幅') +axes[2].legend(); axes[2].grid(True, alpha=0.3) + +plt.tight_layout(); plt.show() +``` + +2. 计算并可视化由多个正弦波组成的信号的 FFT。显示幅度谱并识别组成频率。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 创建复合信号:220 Hz + 440 Hz + 880 Hz(A3 + A4 + A5) +fs = 8000 # 8 kHz 采样率 +duration = 0.1 # 100 ms +t = jnp.arange(0, duration, 1.0 / fs) +n_samples = len(t) + +# 三个频率分量,不同振幅 +x = 1.0 * jnp.sin(2 * jnp.pi * 220 * t) + \ + 0.6 * jnp.sin(2 * jnp.pi * 440 * t) + \ + 0.3 * jnp.sin(2 * jnp.pi * 880 * t) + +# 计算 FFT +X = jnp.fft.fft(x) +freqs = jnp.fft.fftfreq(n_samples, d=1.0 / fs) +magnitude = jnp.abs(X) / n_samples # 归一化 + +# 只绘制正频率部分 +pos_mask = freqs >= 0 +freqs_pos = freqs[pos_mask] +mag_pos = magnitude[pos_mask] * 2 # 翻倍以补偿负频率的能量 + +fig, axes = plt.subplots(2, 1, figsize=(12, 7)) + +# 时域 +axes[0].plot(t * 1000, x, color='#3498db', linewidth=1) +axes[0].set_title('复合信号:220 Hz + 440 Hz + 880 Hz') +axes[0].set_xlabel('时间 (ms)'); axes[0].set_ylabel('振幅') +axes[0].grid(True, alpha=0.3) + +# 频域 +axes[1].plot(freqs_pos, mag_pos, color='#e74c3c', linewidth=1.5) +axes[1].set_title('幅度谱(FFT)') +axes[1].set_xlabel('频率 (Hz)'); axes[1].set_ylabel('幅度') +axes[1].set_xlim(0, 1500) +# 标注峰值 +for f_peak, amp in [(220, 1.0), (440, 0.6), (880, 0.3)]: + axes[1].annotate(f'{f_peak} Hz', xy=(f_peak, amp), fontsize=10, + ha='center', va='bottom', color='#9b59b6', + arrowprops=dict(arrowstyle='->', color='#9b59b6')) +axes[1].grid(True, alpha=0.3) + +plt.tight_layout(); plt.show() +``` + +3. 在 JAX 中从头构建完整的 MFCC 流水线:预加重、分帧、加窗、FFT、梅尔滤波器组、对数、DCT。可视化梅尔滤波器组和生成的 MFCC 热力图。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# --- 生成一个合成类语音信号 --- +key = jax.random.PRNGKey(42) +fs = 16000 +duration = 1.0 +t = jnp.arange(0, duration, 1.0 / fs) + +# 模拟浊音语音:基频 + 谐波,振幅衰减 +f0 = 150.0 # 基频 +x = sum(jnp.sin(2 * jnp.pi * f0 * k * t) / k for k in range(1, 8)) +# 添加一些噪声 +x = x + 0.1 * jax.random.normal(key, t.shape) +x = x / jnp.max(jnp.abs(x)) # 归一化 + +# --- 第 1 步:预加重 --- +alpha = 0.97 +x_pre = jnp.concatenate([x[:1], x[1:] - alpha * x[:-1]]) + +# --- 第 2 步:分帧 --- +frame_len = int(0.025 * fs) # 25 ms = 400 个样本 +hop_len = int(0.010 * fs) # 10 ms = 160 个样本 +n_frames = (len(x_pre) - frame_len) // hop_len + 1 +frames = jnp.stack([x_pre[i * hop_len : i * hop_len + frame_len] + for i in range(n_frames)]) + +# --- 第 3 步:汉明窗 --- +hamming = 0.54 - 0.46 * jnp.cos(2 * jnp.pi * jnp.arange(frame_len) / (frame_len - 1)) +windowed = frames * hamming + +# --- 第 4 步:FFT --- +n_fft = 512 +spectra = jnp.fft.rfft(windowed, n=n_fft) +power_spectra = jnp.abs(spectra) ** 2 / n_fft + +# --- 第 5 步:梅尔滤波器组 --- +n_mels = 40 +f_min, f_max = 0.0, fs / 2.0 + +def hz_to_mel(f): + return 2595 * jnp.log10(1 + f / 700) + +def mel_to_hz(m): + return 700 * (10 ** (m / 2595) - 1) + +mel_min = hz_to_mel(f_min) +mel_max = hz_to_mel(f_max) +mel_points = jnp.linspace(mel_min, mel_max, n_mels + 2) +hz_points = mel_to_hz(mel_points) + +freq_bins = jnp.floor((n_fft + 1) * hz_points / fs).astype(jnp.int32) +n_freqs = n_fft // 2 + 1 +filterbank = jnp.zeros((n_mels, n_freqs)) + +for m in range(n_mels): + f_left = freq_bins[m] + f_center = freq_bins[m + 1] + f_right = freq_bins[m + 2] + # 上升沿 + for k in range(int(f_left), int(f_center)): + if f_center != f_left: + filterbank = filterbank.at[m, k].set((k - f_left) / (f_center - f_left)) + # 下降沿 + for k in range(int(f_center), int(f_right)): + if f_right != f_center: + filterbank = filterbank.at[m, k].set((f_right - k) / (f_right - f_center)) + +# 应用滤波器组 +mel_spectra = jnp.dot(power_spectra, filterbank.T) + +# --- 第 6 步:对数 --- +log_mel = jnp.log(mel_spectra + 1e-10) + +# --- 第 7 步:DCT(第二类) --- +n_mfcc = 13 +n_mel_channels = log_mel.shape[1] +dct_matrix = jnp.zeros((n_mfcc, n_mel_channels)) +for i in range(n_mfcc): + for j in range(n_mel_channels): + dct_matrix = dct_matrix.at[i, j].set( + jnp.cos(jnp.pi * i * (j + 0.5) / n_mel_channels) + ) +mfccs = jnp.dot(log_mel, dct_matrix.T) + +# --- 可视化 --- +fig, axes = plt.subplots(3, 1, figsize=(14, 11)) + +# 梅尔滤波器组 +freq_axis = jnp.linspace(0, fs / 2, n_freqs) +for m in range(n_mels): + color = '#3498db' if m % 2 == 0 else '#e74c3c' + axes[0].plot(freq_axis, filterbank[m], color=color, alpha=0.6, linewidth=0.8) +axes[0].set_title(f'梅尔滤波器组({n_mels} 个滤波器)') +axes[0].set_xlabel('频率 (Hz)'); axes[0].set_ylabel('权重') +axes[0].grid(True, alpha=0.3) + +# 对数梅尔语谱图 +im1 = axes[1].imshow(log_mel.T, aspect='auto', origin='lower', + extent=[0, duration, 0, n_mels], cmap='viridis') +axes[1].set_title('对数梅尔语谱图') +axes[1].set_xlabel('时间 (s)'); axes[1].set_ylabel('梅尔频带') +plt.colorbar(im1, ax=axes[1], label='对数能量') + +# MFCC +im2 = axes[2].imshow(mfccs.T, aspect='auto', origin='lower', + extent=[0, duration, 0, n_mfcc], cmap='coolwarm') +axes[2].set_title(f'MFCC(前 {n_mfcc} 个系数)') +axes[2].set_xlabel('时间 (s)'); axes[2].set_ylabel('MFCC 索引') +plt.colorbar(im2, ax=axes[2], label='系数值') + +plt.tight_layout(); plt.show() +``` + +4. 实现 FIR 低通和高通滤波器,并可视化它们对包含低频和高频分量信号的影响。同时显示时域和频域的视图。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 创建包含低频(100 Hz)和高频(2000 Hz)分量的信号 +fs = 8000 +duration = 0.05 # 50 ms,便于清晰显示 +t = jnp.arange(0, duration, 1.0 / fs) + +x_low = jnp.sin(2 * jnp.pi * 100 * t) +x_high = 0.5 * jnp.sin(2 * jnp.pi * 2000 * t) +x = x_low + x_high + +# 使用窗函数法设计简单的 FIR 低通滤波器 +def fir_lowpass(cutoff_hz, fs, n_taps=51): + """使用窗函数法设计 FIR 低通滤波器。""" + fc = cutoff_hz / fs # 归一化截止频率 + n = jnp.arange(n_taps) + mid = (n_taps - 1) / 2.0 + # Sinc 函数(理想低通冲激响应) + h = jnp.where(n == mid, 2 * fc, + jnp.sin(2 * jnp.pi * fc * (n - mid)) / (jnp.pi * (n - mid))) + # 应用汉明窗 + window = 0.54 - 0.46 * jnp.cos(2 * jnp.pi * n / (n_taps - 1)) + h = h * window + h = h / jnp.sum(h) # 归一化到直流增益为 1 + return h + +def apply_filter(x, h): + """通过卷积应用 FIR 滤波器。""" + return jnp.convolve(x, h, mode='same') + +# 500 Hz 低通滤波器(通过 100 Hz,阻塞 2000 Hz) +h_lp = fir_lowpass(500, fs, n_taps=51) +x_lp = apply_filter(x, h_lp) + +# 高通 = 冲激 - 低通(频谱反转) +delta = jnp.zeros(51) +delta = delta.at[25].set(1.0) +h_hp = delta - h_lp +x_hp = apply_filter(x, h_hp) + +# 计算所有信号的频谱 +def compute_spectrum(signal, fs): + X = jnp.fft.rfft(signal) + freqs = jnp.fft.rfftfreq(len(signal), d=1.0 / fs) + mag = jnp.abs(X) / len(signal) * 2 + return freqs, mag + +fig, axes = plt.subplots(3, 2, figsize=(14, 10)) + +# 时域图 +for i, (sig, title, color) in enumerate([ + (x, '原始信号(100 Hz + 2000 Hz)', '#3498db'), + (x_lp, '低通滤波后(< 500 Hz)', '#27ae60'), + (x_hp, '高通滤波后(> 500 Hz)', '#e74c3c') +]): + axes[i, 0].plot(t * 1000, sig[:len(t)], color=color, linewidth=1) + axes[i, 0].set_title(f'时域:{title}') + axes[i, 0].set_xlabel('时间 (ms)'); axes[i, 0].set_ylabel('振幅') + axes[i, 0].grid(True, alpha=0.3) + +# 频域图 +for i, (sig, title, color) in enumerate([ + (x, '原始信号', '#3498db'), + (x_lp, '低通', '#27ae60'), + (x_hp, '高通', '#e74c3c') +]): + freqs, mag = compute_spectrum(sig, fs) + axes[i, 1].plot(freqs, mag, color=color, linewidth=1.5) + axes[i, 1].set_title(f'频谱:{title}') + axes[i, 1].set_xlabel('频率 (Hz)'); axes[i, 1].set_ylabel('幅度') + axes[i, 1].set_xlim(0, 3000) + axes[i, 1].axvline(x=500, color='#f39c12', linestyle='--', alpha=0.7, + label='截止频率(500 Hz)') + axes[i, 1].legend(); axes[i, 1].grid(True, alpha=0.3) + +plt.tight_layout(); plt.show() +``` diff --git a/chapter 09: audio and speech/02. automatic speech recognition.md b/chapter 09: audio and speech/02. automatic speech recognition.md new file mode 100644 index 0000000..4655afe --- /dev/null +++ b/chapter 09: audio and speech/02. automatic speech recognition.md @@ -0,0 +1,592 @@ +# 自动语音识别 + +*自动语音识别将口语音频转换为书面文本,弥合人类语音与机器可读语言之间的鸿沟。本文涵盖 GMM-HMM、CTC 损失、RNN-转导器、基于注意力的编码器-解码器模型(LAS)、Whisper 以及端到端 ASR,从经典流水线到现代神经架构。* + +- **自动语音识别**(ASR)是将口语音频转换为书面文本的任务。它是 AI 领域最古老的问题之一(20 世纪 50 年代的第一批系统就能识别单个数字),也是商业部署最广泛的任务之一(语音助手、转录服务、字幕生成)。 + +- 难点在于语音的巨大变异性:不同的说话人、口音、语速、背景噪声、麦克风特性,以及将连续声学信号映射到离散单词这一根本性歧义问题。 + +- 可以把 ASR 想象成法庭速记员。速记员听到连续的声音流,在心理上将其分割成单词,利用上下文解决歧义(如"they're" vs "their" vs "there"),然后打出结果。ASR 系统做同样的事情,但分阶段进行,每个阶段可以独立或联合优化。 + +- **经典 ASR 流水线**通过一系列不同阶段处理音频:原始音频被转换为特征(MFCC 或对数梅尔频谱图,见文件 01),**声学模型**评估每个特征帧与每个语音单元的匹配程度,**发音模型**(词典)将语音单元映射为单词,**语言模型**评估词序列的合理程度,**解码器**搜索使联合得分最大化的词序列。每个组件分别训练和调优。 + +![ASR 流水线:从原始音频经过特征提取、声学模型、解码器和语言模型到输出文本](../images/asr_pipeline.svg) + +- **音素**是语言中区分单词的最小声音单位。英语大约有 39-44 个音素(具体数量取决于方言和所用音素库)。例如,"bat"和"pat"相差一个音素(/b/ vs /p/)。大多数 ASR 系统建模的是**上下文相关音素**,称为**三音素**:由其左邻和右邻共同定义的音素(例如,"b_t"上下文中的"a"与"c_t"上下文中的"a"是不同的单元),因为音素的声学实现受其邻接音素的强烈影响(这称为**协同发音**)。 + +- 可能的三音素数量巨大(40 个音素的三次方 = 64,000),因此**决策树聚类**将声学上相似的三音素分组为**声学状态**(通常为 2000-10,000 个类别)。每个声学状态拥有自己的声学模型。这种聚类是第 06 章中决策树算法的一种应用形式。 + +- **GMM-HMM**(高斯混合模型-隐马尔可夫模型)是从 20 世纪 80 年代到 21 世纪初主导的声学建模方法。HMM(见第 05 章)对语音的时间结构进行建模:每个音素是一个从左到右的 HMM,有 3-5 个状态,每个状态代表一个子音素段(起始、中间、结束)。状态间的转移隐式地建模时长。 + +- 在每个 HMM 状态,发射概率(给定状态下特定特征向量的可能性)由**高斯混合模型**(GMM)建模:多元高斯分布的加权和(见第 05 章): + +```math +p(\mathbf{x} | s) = \sum_{m=1}^{M} w_m \cdot \mathcal{N}(\mathbf{x} ; \boldsymbol{\mu}_m, \boldsymbol{\Sigma}_m) +``` + +- 其中 $\mathbf{x}$ 是特征向量(例如 39 维 MFCC),$s$ 是 HMM 状态,$M$ 是混合分量数(通常为 8-64),$w_m$ 是混合权重,$\boldsymbol{\mu}_m$ 和 $\boldsymbol{\Sigma}_m$ 是每个高斯分量的均值和协方差。协方差矩阵通常使用对角形式以提高计算效率(假设特征维度独立,对于 MFCC 而言由于 DCT 去相关性,这一假设近似成立)。 + +- 训练使用 **Baum-Welch 算法**(EM 算法的特例,见第 05 章)从有标注的语音数据中迭代估计 GMM 参数和 HMM 转移概率。解码(寻找最可能的状态序列)使用 **Viterbi 算法**(动态规划,见第 05 章): + +```math +\delta_t(j) = \max_{i} \left[ \delta_{t-1}(i) \cdot a_{ij} \right] \cdot b_j(\mathbf{x}_t) +``` + +- 其中 $\delta_t(j)$ 是在时间 $t$ 以状态 $j$ 结束的最佳路径的概率,$a_{ij}$ 是从状态 $i$ 到状态 $j$ 的转移概率,$b_j(\mathbf{x}_t)$ 是在状态 $j$ 下特征 $\mathbf{x}_t$ 的发射概率。 + +- **DNN-HMM**(Hinton 等人,2012)用深度神经网络(DNN,见第 06 章)取代了 GMM 发射模型,从特征帧窗口中预测声学状态后验概率 $p(s | \mathbf{x})$。HMM 仍然处理时间结构和序列化,但神经网络提供了更具判别力的发射分数。这种混合方法相对于 GMM 将词错误率降低了 20-30%,并在 2012-2016 年间占据主导地位。 + +- **WFST 解码**(加权有限状态换导器)是传统 ASR 的标准解码框架。每个组件(HMM 拓扑 H、上下文依赖 C、词典 L、语法/语言模型 G)都表示为加权有限状态换导器,它们被组合成单个搜索图 $H \circ C \circ L \circ G$。然后 Viterbi 搜索在此组合图中寻找最低成本路径。WFST 允许知识源的模块化组合和高效的动态规划搜索。其数学框架来自有限自动机理论(与第 05 章中的状态机相关)。 + +- **端到端 ASR** 消除了独立的组件(发音模型、音素库、WFST 解码器),训练一个直接将音频特征映射到字符或子词的单一神经网络。关键挑战是**对齐问题**:输入(每秒数百个特征帧)和输出(每秒几个字符)的长度相差很大,且训练时它们之间的对齐关系是未知的。 + +- **连接主义时序分类**(CTC)(Graves 等人,2006)通过引入一个特殊的**空白**标记解决了对齐问题,允许网络输出任意长度的字符和空白序列,只要通过合并连续重复和移除空白后能得到正确的转录文本。例如,转录文本"cat"可以由输出序列"--cc-aa-t--"产生(其中"-"是空白)。 + +- 形式上,CTC 定义了一个多对一映射 $\mathcal{B}$,从所有长度为 $T$ 的输出序列(使用字母表加上空白)到标签序列。标签序列 $\mathbf{y}$ 的概率是所有能约简到它的对齐路径的概率之和: + +$$P(\mathbf{y} | \mathbf{x}) = \sum_{\boldsymbol{\pi} \in \mathcal{B}^{-1}(\mathbf{y})} \prod_{t=1}^{T} p(\pi_t | \mathbf{x})$$ + +![CTC 对齐示意图:多条路径通过空白和字符标记,最终约简为相同的输出文本](../images/ctc_alignment.svg) + +- 直接计算此和需要枚举指数数量的对齐路径,但 **CTC 前向-后向算法**使用动态规划在 $O(T \cdot |\mathbf{y}|)$ 时间内高效计算,类似于第 05 章中的 HMM 前向-后向算法。 + +- CTC 做了一个**条件独立性假设**:给定输入,每个时间步的输出独立于所有其他输出。这意味着 CTC 无法建模输出之间的依赖关系(例如,它无法学习到"q"几乎总是后跟"u")。必须使用外部语言模型来处理此类依赖关系。 + +- **CTC 解码**选项: + - **贪婪解码**:在每个时间步取最可能的标记,然后合并。速度快但效果次优。 + - **束搜索**:在每个步骤维护得分最高的 $k$ 个部分假设,合并能约简为相同前缀的假设。可以结合语言模型得分。 + - **前缀束搜索**:一种改进的束搜索,正确处理 CTC 空白合并,确保假设在合并后进行对比。 + +- **RNN-转导器**(RNN-T)(Graves,2012)通过添加一个显式的**预测网络**(类语言模型的 RNN)扩展了 CTC,使每个输出以之前的输出为条件,从而消除了条件独立性假设。RNN-T 有三个组件: + - **编码器**:处理音频特征,生成隐藏表示 $\mathbf{h}_t^\text{enc}$(通常是 LSTM 或 Conformer 层的堆叠)。 + - **预测网络**:自回归 RNN,根据之前发射的标签生成隐藏表示 $\mathbf{h}_u^\text{pred}$。 + - **联合网络**:在每个(时间,标签)位置组合编码器和预测网络的输出,产生下一个标记(包括空白)的分布: + +$$p(y | t, u) = \text{softmax}(W \cdot \text{tanh}(W_\text{enc} \mathbf{h}_t^\text{enc} + W_\text{pred} \mathbf{h}_u^\text{pred} + b))$$ + +- RNN-T 可以在每个时间步发射零个或多个标签(通过先发射非空白标记再前进到下一个时间步,或发射空白前进但不输出)。训练使用二维(时间,标签)网格上的前向-后向算法,复杂度为 $O(T \cdot U)$,其中 $U$ 是输出长度。RNN-T 是设备端流式 ASR 的主导架构(用于 Google Pixel 手机和类似产品),因为它天然支持流式处理:编码器从左到右处理音频,预测网络增量生成输出。 + +- **Listen, Attend and Spell**(LAS)(Chan 等人,2016)是一种基于注意力的编码器-解码器模型(序列到序列架构,见第 06 章)。它有三个组件: + - **Listener**(编码器):金字塔形双向 LSTM,处理完整的输入序列并下采样 8 倍(通过在每层拼接连续隐藏状态对),生成较短的编码器隐藏状态序列。 + - **Attention**(注意力):在每个解码步骤中,计算所有编码器状态上的注意力权重,形成上下文向量(与第 07 章中相同的注意力机制)。 + - **Speller**(解码器):自回归 LSTM,在上下文向量和之前生成的字符的条件下逐字符生成输出转录文本。 + +- LAS 取得了很强的结果,但需要完整的语音片段才能开始解码(因为注意力需要关注所有编码器状态),因此不适合流式应用。此外,它在处理超长语音片段时表现不佳,因为长序列上的注意力会变得弥散。 + +- **Conformer**(Gulati 等人,2020)将卷积的局部模式捕捉能力与自注意力的全局依赖建模能力相结合。每个 Conformer 块以三明治结构包含四个模块: + 1. **前馈模块**(半步):带残差连接的前馈网络,使用一半的残差权重。 + 2. **多头自注意力模块**:标准 Transformer 自注意力(来自第 07 章),使用相对位置编码。 + 3. **卷积模块**:逐点卷积、门控线性单元(GLU)、一维深度可分离卷积、批归一化、Swish 激活函数和另一个逐点卷积。深度可分离卷积捕捉局部上下文(类似于特征序列上的 n-gram)。 + 4. **前馈模块**(半步):与模块 1 相同。 + +- 输出为:$\mathbf{y} = \text{LayerNorm}(\mathbf{x} + \frac{1}{2}\text{FFN}_1 + \text{MHSA} + \text{Conv} + \frac{1}{2}\text{FFN}_2)$。实验证明这种马卡龙式结构(FFN-注意力-卷积-FFN)配合半步残差优于其他排序方式。Conformer 已成为 CTC 和 RNN-T 系统的默认编码器,性能优于纯 Transformer 和纯 LSTM 编码器。 + +![Conformer 块示意图:前馈、自注意力、卷积和前馈模块的三明治结构](../images/conformer_block.svg) + +- **Whisper**(Radford 等人,2023)是 OpenAI 的大规模基于注意力的 ASR 模型。它使用标准的编码器-解码器 Transformer 架构(来自第 07 章),在从互联网抓取的 68 万小时弱监督数据(音频与近似转录文本配对)上进行训练。关键设计选择: + - 输入:80 通道对数梅尔频谱图(来自文件 01),使用 25 ms 窗口和 10 ms 步长,归一化为零均值和单位方差。 + - 编码器:标准 Transformer 编码器,使用正弦位置嵌入和预激活层归一化。 + - 解码器:Transformer 解码器,使用字节级 BPE 分词器(来自第 07 章)自回归生成标记。 + - 多任务:单个模型处理转录、翻译、语言识别和时间戳预测,通过解码器提示中的特殊任务标记进行条件控制。 + - 训练数据的规模(而非架构创新)是 Whisper 在跨领域、跨口音和跨语言上强泛化能力的主要驱动力。 + +- **wav2vec 2.0**(Baevski 等人,2020)是一种用于语音表示的**自监督**预训练框架。核心思想是从大量未标注的音频中学习语音表示,然后用少量标注数据进行微调。这遵循了与 BERT(来自第 07 章)相同的自监督范式,但针对连续音频信号进行了适配。 + +- wav2vec 2.0 架构包含三个部分: + - **特征编码器**:多层一维 CNN,处理原始波形样本,以 20 ms 的帧率(在 16 kHz 下每 320 个样本一个向量)生成潜在表示 $\mathbf{z}_t$。 + - **量化模块**:使用**乘积量化**(将向量分成组,每组独立量化,从 $G$ 个码本中各选 $V$ 个条目)将潜在表示离散化为有限码本。这为对比学习目标产生目标 $\mathbf{q}_t$。 + - **上下文网络**:Transformer 编码器,接收(部分掩码的)潜在表示并生成上下文化的表示 $\mathbf{c}_t$。 + +![wav2vec 2.0 架构:CNN 特征编码器、掩码、Transformer 上下文网络和基于量化目标的对比学习](../images/wav2vec2_pretraining.svg) + +- 在预训练期间,随机跨度内的潜在表示被**掩码**(替换为可学习的掩码嵌入),模型必须从一组干扰项(从同一语音片段的其他位置采样的负样本)中识别出掩码位置的真实量化表示。对比损失为: + +$$\mathcal{L} = -\log \frac{\exp(\text{sim}(\mathbf{c}_t, \mathbf{q}_t) / \kappa)}{\sum_{\tilde{\mathbf{q}} \in Q_t} \exp(\text{sim}(\mathbf{c}_t, \tilde{\mathbf{q}}) / \kappa)}$$ + +- 其中 $\text{sim}$ 是余弦相似度,$\kappa$ 是温度参数,$Q_t$ 包括真实量化目标和干扰项。额外的**多样性损失**鼓励均衡使用所有码本条目。该损失本质上是 InfoNCE 对比损失,与视觉自监督学习中使用的对比目标函数属于同一族。 + +- 预训练后,在其上添加线性投影和 CTC 头部,然后在标注数据上进行微调。wav2vec 2.0 仅使用 10 分钟标注数据(使用 53,000 小时未标注音频进行预训练)即达到了接近最优的结果,展示了自监督学习在低资源语音识别中的强大能力。 + +- **HuBERT**(Hsu 等人,2021)是另一种自监督方法,用**掩码预测**目标(预测掩码帧的离散聚类分配)替代对比目标。目标由离线聚类步骤产生(第一次迭代使用 MFCC 的 k-means,后续迭代使用 HuBERT 特征的 k-means)。与 wav2vec 2.0 相比,HuBERT 简化了训练流程(无需量化模块或对比采样),且达到相当或更好的结果。 + +- **Fast Conformer**(Rekesh 等人,2023,NVIDIA NeMo)用**下采样注意力**机制替代标准 Conformer 中的二次自注意力:输入序列在计算注意力之前被压缩(通常通过步进卷积实现 8 倍压缩),然后再扩展回来。这将注意力成本从 $O(T^2)$ 降低到 $O(T^2/64)$,同时保留全局上下文,使训练超长语音片段(长达几分钟)不会出现内存问题。Fast Conformer 是 NVIDIA NeMo 工具包中的默认编码器,构成了其生产级模型的基础架构。 + +- **Parakeet**(NVIDIA,2024)是一系列基于 Fast Conformer 编码器的高精度英文 ASR 模型,配备 CTC 和 RNN-T 解码器,在 64,000 小时英语语音上训练。Parakeet 模型(0.6B 和 1.1B 参数)在发布时于标准基准上取得了最低的词错误率,在大多数英语测试集上超越了 Whisper large-v3。关键要素是高效的 Fast Conformer 架构、激进的数据增强(SpecAugment、速度扰动、噪声混合)和大规模监督训练数据——这表明对已知组件的精心工程化仍能推动技术前沿。 + +- **Canary**(NVIDIA,2024)将 NeMo 框架扩展到多语言和多任务 ASR。它使用 Fast Conformer 编码器配合基于注意力的解码器(而非 CTC 或 RNN-T),在单个模型中处理多种语言的转录和翻译(类似于 Whisper 的多任务设计,但使用更高效的 Fast Conformer 骨干网络)。Canary 模型支持英语、德语、西班牙语和法语,具有竞争性的准确率。 + +- **Moonshine**(Useful Sensors,2024)是一系列针对**设备端和边缘部署**专门优化的 ASR 模型。编码器使用混合架构,将初始的 Transformer/Conformer 层替换为小型 CNN 后接少量 Transformer 层,大幅缩小了模型体积(基础模型不到 3000 万参数)。Moonshine 面向 CPU 和低功耗设备上的实时流式处理,在这些场景下 Whisper 过大过慢,Moonshine 以少量精度换取 5-10 倍的更低延迟和内存占用。 + +- **Distil-Whisper**(Gandhi 等人,2023)应用**知识蒸馏**(第 06 章)将 Whisper 压缩为更小更快的模型。学生模型仅使用 2 个解码器层(相比之下 Whisper 有 32 层),同时保留完整的编码器,并训练以匹配 Whisper 的输出分布。Distil-Whisper 在 WER 上与教师模型差距在 1% 以内,同时速度快了 6 倍,使其在全尺寸 Whisper 模型过慢的实时应用中变得实用。 + +- **通用语音模型(USM)**(Zhang 等人,2023,Google)将自监督预训练扩展到 1200 万小时跨 300 多种语言的未标注音频,随后进行监督微调。USM 证明了 wav2vec 2.0 / 自监督范式可以扩展到真正大规模的数据范围,在标注数据非常有限的低资源语言上取得了强性能。 + +- **大规模多语言语音(MMS)**(Pratap 等人,2023,Meta)将 wav2vec 2.0 预训练扩展到超过 1,100 种语言,利用宗教录音和其他来源的多语言音频。MMS 覆盖的语言数量远超之前的任何 ASR 系统,首次为许多资源匮乏的语言提供了语音识别能力。 + +- 现代 ASR 的格局正趋于几个主导范式:(1)Conformer 族编码器配合 CTC 或 RNN-T 用于流式处理,(2)编码器-解码器 Transformer 用于离线/多任务,(3)自监督预训练用于低资源场景,(4)规模化——更多的数据和更大的模型持续提升准确率。这些选择取决于部署约束:延迟预算、可用算力、语言数量,以及应用是流式还是批处理。 + +- **语言模型集成**通过引入声学模型无法捕捉的语言知识来改进 ASR。基本思想是在解码时将声学模型得分 $p(\mathbf{x} | \mathbf{y})$(音频与转录文本的匹配程度)与语言模型得分 $p(\mathbf{y})$(转录文本作为句子的合理性)相结合。 + +- **浅融合**在束搜索时结合得分: + +$$\hat{\mathbf{y}} = \arg\max_\mathbf{y} \left[ \log p_\text{AM}(\mathbf{y} | \mathbf{x}) + \lambda \log p_\text{LM}(\mathbf{y}) \right]$$ + +- 其中 $\lambda$ 是可调权重,$p_\text{LM}$ 是外部语言模型(通常是 n-gram 或神经语言模型,来自第 07 章)。这种方法简单有效,但要求 LM 使用与 ASR 模型相同的标记词汇表。 + +- **深度融合**(Gulcehre 等人,2015)将语言模型集成到解码器网络内部:LM 隐藏状态与解码器隐藏状态拼接,通过门控机制后进入输出投影层。整个系统(包括预训练的 LM)被联合微调。这种方法集成更深入,但训练更复杂。 + +- **冷融合**(Sriram 等人,2018)与深度融合类似,但 ASR 解码器从头开始与集成语言模型一起训练,而非微调预训练的解码器。这迫使声学模型学习互补信息,而非重复 LM 已经知道的内容。 + +- **重打分**(N-best 重打分)是一种两遍方法:首先使用束搜索生成 $N$ 个候选转录文本,然后使用更强大的语言模型(例如,大型 Transformer LM)对它们重新排序。这种方法实现简单,且允许使用对第一遍解码来说太慢的非常大的 LM。 + +- **内部语言模型估计**(ILME)解决了一个微妙的问题:端到端模型从训练转录文本中隐式学习了一个内部 LM,这在浅融合时可能与外部 LM 冲突(本质上是对语言先验进行了双重计数)。ILME 估计内部 LM 并在融合时减去其得分: + +$$\hat{\mathbf{y}} = \arg\max_\mathbf{y} \left[ \log p_\text{E2E}(\mathbf{y} | \mathbf{x}) - \beta \log p_\text{ILM}(\mathbf{y}) + \lambda \log p_\text{LM}(\mathbf{y}) \right]$$ + +- **流式 vs. 离线 ASR** 是一个基本的架构选择。离线(或批处理)ASR 在处理完整个语音片段后才产生输出。流式 ASR 在音频到达时增量产生输出,具有有界延迟。 + +- 流式处理对实时应用至关重要:实时字幕、语音助手(用户在说完之前就期望得到响应)、电话通话转录。挑战在于某些未来上下文有助于识别(知道下一个词是"York"有助于消歧"New"),但流式系统不能无限等待未来的上下文。 + +- **单向编码器**(从左到右 LSTM、因果卷积、因果 Transformer)天然支持流式处理,因为每个输出仅依赖于过去和当前的输入。双向编码器(查看未来上下文)不能直接支持流式处理。 + +- **分块注意力**(也称为逐块或分段注意力)将输入划分为固定长度的块,仅在每个块内(以及可选的前面几个块)应用自注意力。这将延迟限制在块大小加上处理时间,同时在每个块内仍允许一定的局部双向上下文。其权衡是:块越小,准确率下降越多。 + +- **前瞻**允许流式编码器在当前帧产生输出之前,窥视少量的未来帧(例如 300-900 ms)。这是通过在单向计算中添加少量右上下文来实现的。前瞻窗口增加了延迟,但显著提升了准确率。 + +- **流式 ASR 中的延迟**包含几个组成部分: + - **算法延迟**:从音频到达到模型能够处理它的延迟(由块大小、前瞻和特征提取决定)。 + - **计算延迟**:运行模型前向传播所需的时间。 + - **端点检测延迟**:检测用户说话完毕的延迟。 + - **首词延迟**:第一个词出现的速度。**最终确认延迟**:最终输出被确认的速度(流式系统通常产生暂定输出,随着更多音频到达而被修正)。 + +- **ASR 的评估指标**: + +- **词错误率**(WER)是主要指标。通过将系统输出(假设)与参考文本(真实转录文本)进行对齐计算,使用编辑距离(将一个转换为另一个所需的最少替换、插入和删除次数),然后: + +$$\text{WER} = \frac{S + D + I}{N}$$ + +- 其中 $S$ 是替换数,$D$ 是删除数,$I$ 是插入数,$N$ 是参考文本中的总词数。如果插入过多,WER 可能超过 100%。5% 的 WER 被认为大致相当于人类在清晰朗读语音上的水平;对话或噪声环境下的语音则困难得多(10-20%+)。 + +- **字符错误率**(CER)是相同的公式应用于字符级别而非词级别。CER 对于没有明确词边界的语言(如中文、日语)以及评估近似正确情况的接近程度("cat" vs "bat" 是 100% WER 但 33% CER)更有参考价值。 + +- **词信息损失**(WIL)和**词信息保留**(WIP)是信息论替代指标,比 WER 更精确地考虑了参考文本与假设之间的相关性,但使用较少。 + +- **实时因子**(RTF)衡量计算效率:处理时间与音频时长的比值。RTF < 1 表示系统运行速度快于实时;RTF > 1 表示系统无法跟上实时音频。流式系统必须保持 RTF < 1。 + +- **数据增强**对鲁棒 ASR 至关重要。常见技术: + - **速度扰动**:以 0.9 倍和 1.1 倍速度对音频进行重采样(改变音高和时长)。 + - **SpecAugment**(Park 等人,2019):掩码频谱图中的随机频率带和时间步。这是音频领域的 dropout 类比,也是 ASR 中最有效的正则化技术之一。无需额外数据。 + - **噪声增强**:将干净语音与录制的噪声以各种信噪比混合。 + - **房间脉冲响应模拟**:将干净语音与模拟的房间声学进行卷积,以模拟混响环境。 + +- **ASR 的分词**决定了模型的输出词汇表。选项包括: + - **字符**:简单,词汇量小(英语约 30 个),但输出序列长且无隐式语言建模。 + - **子词 / BPE**(来自第 07 章):在词汇表大小和序列长度之间取得平衡的子词单元。现代系统的标准(Whisper 使用字节级 BPE,约 50,000 个标记)。 + - **词**:词汇量大(50,000+),输出序列短,但无法处理词表外的词。 + - **音素**:语言上合理,紧凑,但需要发音词典。 + +- ASR 的演进可以概括为:从高度工程化的模块化系统(GMM-HMM + WFST 解码,1990 年代-2010 年代)到混合系统(DNN-HMM,2012-2016),再到将流水线越来越多地吸收到单一神经网络中的端到端系统(CTC、RNN-T、LAS,2016-2020),最后到利用海量未标注或弱标注数据的大型预训练模型(wav2vec 2.0、Whisper,2020 至今)。每一次转变都在提升准确率的同时简化了工程复杂度,遵循了机器学习中从手工设计特征到从数据中学习表示的更广泛趋势(第 06 章中 CNN 替代图像特征、第 07 章中 Transformer 替代 NLP 特征也是如此)。 + +## 编程任务(使用 CoLab 或 notebook) + +1. 在 JAX 中从头实现 CTC 损失。创建一个包含短序列 logits 和目标标签的玩具示例,计算 CTC 前向算法得到总概率,并计算负对数似然损失。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def ctc_forward(log_probs, targets): + """ + CTC 前向算法(对数域,数值稳定性)。 + log_probs: (T, V) 词汇表上的对数概率(索引 0 = 空白) + targets: (U,) 目标标签索引(不含空白) + 返回:目标序列在 CTC 下的对数概率。 + """ + T, V = log_probs.shape + U = len(targets) + + # 构建带有空白的扩展标签序列:[blank, y1, blank, y2, ..., yU, blank] + S = 2 * U + 1 + labels = jnp.zeros(S, dtype=jnp.int32) # 全部为空白 + for i in range(U): + labels = labels.at[2 * i + 1].set(targets[i]) + + # 初始化 alpha(对数域) + NEG_INF = -1e30 + alpha = jnp.full((T, S), NEG_INF) + alpha = alpha.at[0, 0].set(log_probs[0, labels[0]]) # 以空白开始 + alpha = alpha.at[0, 1].set(log_probs[0, labels[1]]) # 或第一个标签 + + # 前向填充 + for t in range(1, T): + for s in range(S): + # 同一状态 + a = alpha[t - 1, s] + # 从前一状态来 + if s > 0: + a = jnp.logaddexp(a, alpha[t - 1, s - 1]) + # 跳过空白(如果当前标签与两步前的标签不同) + if s > 1 and labels[s] != 0 and labels[s] != labels[s - 2]: + a = jnp.logaddexp(a, alpha[t - 1, s - 2]) + alpha = alpha.at[t, s].set(a + log_probs[t, labels[s]]) + + # 总对数概率:最后时间步的最后两个状态之和 + log_prob = jnp.logaddexp(alpha[T - 1, S - 1], alpha[T - 1, S - 2]) + return log_prob, alpha + +# --- 玩具示例 --- +T = 12 # 输入长度(时间步) +V = 5 # 词汇表大小(0=空白,1='c',2='a',3='t',4='x') +targets = jnp.array([1, 2, 3]) # "c", "a", "t" + +# 创建随机 logits 并转换为对数概率 +key = jax.random.PRNGKey(42) +logits = jax.random.normal(key, (T, V)) +log_probs = jax.nn.log_softmax(logits, axis=-1) + +log_prob, alpha = ctc_forward(log_probs, targets) +ctc_loss = -log_prob + +print(f"目标序列: {targets.tolist()} ('c', 'a', 't')") +print(f"输入长度 T={T}, 词汇表大小 V={V}") +print(f"CTC 对数概率: {log_prob:.4f}") +print(f"CTC 损失(负对数概率): {ctc_loss:.4f}") + +# 可视化前向变量(alpha)网格 +fig, ax = plt.subplots(figsize=(12, 5)) +# 将对数转换为线性以便可视化 +alpha_linear = jnp.exp(alpha - jnp.max(alpha)) # 归一化以便观察 +im = ax.imshow(alpha_linear.T, aspect='auto', origin='lower', cmap='viridis') +ax.set_xlabel('时间步 (t)') +ax.set_ylabel('扩展标签索引 (s)') + +label_names = ['_', 'c', '_', 'a', '_', 't', '_'] # _ = 空白 +ax.set_yticks(range(len(label_names))) +ax.set_yticklabels(label_names) +ax.set_title(f'CTC 前向变量(alpha 网格)| 损失 = {ctc_loss:.2f}') +plt.colorbar(im, ax=ax, label='归一化概率') +plt.tight_layout(); plt.show() +``` + +2. 在 JAX 中构建一个简单的编码器-解码器基于注意力的 ASR 模型(最小化的 LAS 类架构)。使用一维卷积编码器和带有点积注意力的单层解码器。在合成数据上运行并可视化注意力权重。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# --- 最小化的基于注意力的编码器-解码器 ASR 模型 --- + +def init_params(key, input_dim, hidden_dim, vocab_size): + """初始化小型 LAS 类模型的参数。""" + keys = jax.random.split(key, 8) + scale = 0.1 + params = { + # 编码器:简单的线性投影(模拟卷积输出) + 'enc_w': jax.random.normal(keys[0], (input_dim, hidden_dim)) * scale, + 'enc_b': jnp.zeros(hidden_dim), + # 注意力:查询、键、值投影 + 'attn_q': jax.random.normal(keys[1], (hidden_dim, hidden_dim)) * scale, + 'attn_k': jax.random.normal(keys[2], (hidden_dim, hidden_dim)) * scale, + 'attn_v': jax.random.normal(keys[3], (hidden_dim, hidden_dim)) * scale, + # 解码器 RNN(为演示使用简单 Elman RNN) + 'dec_wh': jax.random.normal(keys[4], (hidden_dim, hidden_dim)) * scale, + 'dec_wx': jax.random.normal(keys[5], (vocab_size, hidden_dim)) * scale, + 'dec_wc': jax.random.normal(keys[6], (hidden_dim, hidden_dim)) * scale, + 'dec_b': jnp.zeros(hidden_dim), + # 输出投影 + 'out_w': jax.random.normal(keys[7], (hidden_dim, vocab_size)) * scale, + 'out_b': jnp.zeros(vocab_size), + } + return params + +def encode(params, x): + """编码器:线性投影(占位符,代表卷积/LSTM 堆叠)。""" + return jnp.tanh(x @ params['enc_w'] + params['enc_b']) + +def attend(params, query, enc_out): + """在编码器输出上的点积注意力。""" + q = query @ params['attn_q'] # (hidden,) + k = enc_out @ params['attn_k'] # (T_enc, hidden) + v = enc_out @ params['attn_v'] # (T_enc, hidden) + d_k = q.shape[-1] + scores = (k @ q) / jnp.sqrt(d_k) # (T_enc,) + weights = jax.nn.softmax(scores) # (T_enc,) + context = weights @ v # (hidden,) + return context, weights + +def decode_step(params, h_prev, y_prev_onehot, enc_out): + """单步解码:RNN + 注意力。""" + # 嵌入前一个标记 + y_emb = y_prev_onehot @ params['dec_wx'] # (hidden,) + # 注意力到编码器 + context, attn_w = attend(params, h_prev, enc_out) + # RNN 更新 + h = jnp.tanh(h_prev @ params['dec_wh'] + y_emb + context @ params['dec_wc'] + + params['dec_b']) + # 输出 logits + logits = h @ params['out_w'] + params['out_b'] + return h, logits, attn_w + +# --- 设置 --- +key = jax.random.PRNGKey(0) +input_dim = 40 # 例如 40 个梅尔频带 +hidden_dim = 64 +vocab_size = 10 # 用于演示的小词汇表 +T_enc = 30 # 编码器时间步 +T_dec = 8 # 解码器步数 + +params = init_params(key, input_dim, hidden_dim, vocab_size) + +# 合成输入:随机梅尔类特征 +key, subkey = jax.random.split(key) +x = jax.random.normal(subkey, (T_enc, input_dim)) + +# 编码 +enc_out = encode(params, x) + +# 解码(使用随机目标的教师强制) +key, subkey = jax.random.split(key) +targets = jax.random.randint(subkey, (T_dec,), 0, vocab_size) + +h = jnp.zeros(hidden_dim) +all_logits = [] +all_attn = [] + +for t in range(T_dec): + y_prev = jax.nn.one_hot(targets[t] if t > 0 else 0, vocab_size) + h, logits, attn_w = decode_step(params, h, y_prev, enc_out) + all_logits.append(logits) + all_attn.append(attn_w) + +all_attn = jnp.stack(all_attn) # (T_dec, T_enc) +all_logits = jnp.stack(all_logits) # (T_dec, vocab_size) + +# --- 可视化注意力权重 --- +fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + +im = axes[0].imshow(all_attn, aspect='auto', cmap='Blues', origin='lower') +axes[0].set_xlabel('编码器时间步') +axes[0].set_ylabel('解码器步') +axes[0].set_title('注意力权重(解码器 -> 编码器)') +plt.colorbar(im, ax=axes[0]) + +# 显示每个解码步的预测标记分布 +im2 = axes[1].imshow(jax.nn.softmax(all_logits, axis=-1), aspect='auto', + cmap='Oranges', origin='lower') +axes[1].set_xlabel('词汇表索引') +axes[1].set_ylabel('解码器步') +axes[1].set_title('输出标记概率') +plt.colorbar(im2, ax=axes[1]) + +plt.suptitle('最小化的基于注意力的 ASR 模型(未训练)') +plt.tight_layout(); plt.show() +``` + +3. 使用动态规划(编辑距离)从头计算词错误率(WER),并针对一个参考文本评估多个假设。可视化编辑距离矩阵。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np + +def compute_wer(reference, hypothesis): + """ + 使用动态规划(词级别的 Levenshtein 距离)计算 WER。 + 返回 WER、替换数、删除数、插入数和 DP 矩阵。 + """ + ref_words = reference.split() + hyp_words = hypothesis.split() + N = len(ref_words) + M = len(hyp_words) + + # DP 矩阵:d[i][j] = ref[:i] 和 hyp[:j] 之间的编辑距离 + d = np.zeros((N + 1, M + 1), dtype=np.int32) + # 回溯矩阵用于统计 S, D, I + ops = np.zeros((N + 1, M + 1, 3), dtype=np.int32) # [sub, del, ins] + + for i in range(N + 1): + d[i][0] = i # 全部删除 + for j in range(M + 1): + d[0][j] = j # 全部插入 + + for i in range(1, N + 1): + for j in range(1, M + 1): + if ref_words[i - 1] == hyp_words[j - 1]: + sub_cost = d[i - 1][j - 1] # 匹配,无需编辑 + else: + sub_cost = d[i - 1][j - 1] + 1 # 替换 + del_cost = d[i - 1][j] + 1 # 删除 + ins_cost = d[i][j - 1] + 1 # 插入 + + d[i][j] = min(sub_cost, del_cost, ins_cost) + + # 回溯统计操作次数 + i, j = N, M + S, D, I = 0, 0, 0 + while i > 0 or j > 0: + if i > 0 and j > 0 and d[i][j] == d[i-1][j-1] and ref_words[i-1] == hyp_words[j-1]: + i -= 1; j -= 1 # 正确 + elif i > 0 and j > 0 and d[i][j] == d[i-1][j-1] + 1: + S += 1; i -= 1; j -= 1 # 替换 + elif i > 0 and d[i][j] == d[i-1][j] + 1: + D += 1; i -= 1 # 删除 + elif j > 0 and d[i][j] == d[i][j-1] + 1: + I += 1; j -= 1 # 插入 + else: + break + + wer = (S + D + I) / N if N > 0 else 0.0 + return wer, S, D, I, d + +# --- 测试用例 --- +reference = "the cat sat on the mat" +hypotheses = [ + "the cat sat on the mat", # 完美 + "the cat sit on the mat", # 1 次替换 + "the cat on the mat", # 1 次删除 + "the big cat sat on the mat", # 1 次插入 + "a dog sat in a rug", # 多处错误 +] + +print(f"参考文本: '{reference}'\n") +print(f"{'假设':<40s} {'WER':>6s} {'S':>3s} {'D':>3s} {'I':>3s}") +print("-" * 60) +results = [] +for hyp in hypotheses: + wer, S, D, I, dp = compute_wer(reference, hyp) + results.append((hyp, wer, S, D, I, dp)) + print(f"'{hyp}':<40s} {wer:>6.1%} {S:>3d} {D:>3d} {I:>3d}") + +# 可视化最差情况的 DP 矩阵 +worst = results[-1] +hyp_words = worst[0].split() +ref_words = reference.split() +dp_matrix = worst[5] + +fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + +# DP 矩阵 +im = axes[0].imshow(dp_matrix, cmap='YlOrRd', origin='upper') +axes[0].set_xticks(range(len(hyp_words) + 1)) +axes[0].set_xticklabels([''] + hyp_words, rotation=45, ha='right', fontsize=9) +axes[0].set_yticks(range(len(ref_words) + 1)) +axes[0].set_yticklabels([''] + ref_words, fontsize=9) +axes[0].set_xlabel('假设词') +axes[0].set_ylabel('参考词') +axes[0].set_title(f'编辑距离矩阵\nWER = {worst[1]:.1%}') +for i in range(dp_matrix.shape[0]): + for j in range(dp_matrix.shape[1]): + axes[0].text(j, i, str(dp_matrix[i, j]), ha='center', va='center', fontsize=8) +plt.colorbar(im, ax=axes[0]) + +# WER 比较柱状图 +names = [f'Hyp {i+1}' for i in range(len(results))] +wers = [r[1] * 100 for r in results] +colors = ['#27ae60' if w == 0 else '#f39c12' if w < 30 else '#e74c3c' for w in wers] +axes[1].barh(names, wers, color=colors) +axes[1].set_xlabel('WER (%)') +axes[1].set_title('词错误率比较') +for i, (w, r) in enumerate(zip(wers, results)): + axes[1].text(w + 1, i, f'{w:.0f}% (S={r[2]}, D={r[3]}, I={r[4]})', + va='center', fontsize=9) +axes[1].set_xlim(0, max(wers) * 1.4) + +plt.tight_layout(); plt.show() +``` + +4. 在对数梅尔频谱图上实现 SpecAugment(频率掩码和时间掩码),并可视化原始版本与增强版本。从合成信号生成频谱图。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# --- 生成合成对数梅尔频谱图 --- +key = jax.random.PRNGKey(42) +fs = 16000 +duration = 2.0 +t = jnp.arange(0, duration, 1.0 / fs) + +# 模拟语音:带谐波的啁啾信号 +f0 = 120.0 +x = sum(jnp.sin(2 * jnp.pi * f0 * k * t * (1 + 0.1 * t)) / k for k in range(1, 10)) +key, subkey = jax.random.split(key) +x = x + 0.05 * jax.random.normal(subkey, t.shape) + +# 计算对数梅尔频谱图(简化版) +frame_len = 400 # 25 ms +hop_len = 160 # 10 ms +n_fft = 512 +n_mels = 80 + +n_frames = (len(x) - frame_len) // hop_len + 1 +hamming = 0.54 - 0.46 * jnp.cos(2 * jnp.pi * jnp.arange(frame_len) / (frame_len - 1)) + +frames = jnp.stack([x[i * hop_len : i * hop_len + frame_len] for i in range(n_frames)]) +windowed = frames * hamming +spectra = jnp.abs(jnp.fft.rfft(windowed, n=n_fft)) ** 2 + +# 简单的梅尔滤波器组 +def hz_to_mel(f): return 2595 * jnp.log10(1 + f / 700) +def mel_to_hz(m): return 700 * (10 ** (m / 2595) - 1) + +mel_points = jnp.linspace(hz_to_mel(0), hz_to_mel(fs / 2), n_mels + 2) +hz_pts = mel_to_hz(mel_points) +bins = jnp.floor((n_fft + 1) * hz_pts / fs).astype(jnp.int32) + +n_freqs = n_fft // 2 + 1 +fb = jnp.zeros((n_mels, n_freqs)) +for m in range(n_mels): + lo, mid, hi = int(bins[m]), int(bins[m+1]), int(bins[m+2]) + for k in range(lo, mid): + if mid != lo: + fb = fb.at[m, k].set((k - lo) / (mid - lo)) + for k in range(mid, hi): + if hi != mid: + fb = fb.at[m, k].set((hi - k) / (hi - mid)) + +log_mel = jnp.log(spectra @ fb.T + 1e-10) + +# --- SpecAugment --- +def spec_augment(spec, key, n_freq_masks=2, freq_mask_width=15, + n_time_masks=2, time_mask_width=25): + """应用 SpecAugment:频率掩码和时间掩码。""" + augmented = spec.copy() + T, F = spec.shape + + # 频率掩码 + for _ in range(n_freq_masks): + key, k1, k2 = jax.random.split(key, 3) + f_width = jax.random.randint(k1, (), 1, freq_mask_width + 1) + f_start = jax.random.randint(k2, (), 0, max(1, F - freq_mask_width)) + mask = (jnp.arange(F) >= f_start) & (jnp.arange(F) < f_start + f_width) + augmented = jnp.where(mask[None, :], 0.0, augmented) + + # 时间掩码 + for _ in range(n_time_masks): + key, k1, k2 = jax.random.split(key, 3) + t_width = jax.random.randint(k1, (), 1, time_mask_width + 1) + t_start = jax.random.randint(k2, (), 0, max(1, T - time_mask_width)) + mask = (jnp.arange(T) >= t_start) & (jnp.arange(T) < t_start + t_width) + augmented = jnp.where(mask[:, None], 0.0, augmented) + + return augmented + +key, subkey = jax.random.split(key) +log_mel_aug = spec_augment(log_mel, subkey) + +# --- 可视化 --- +fig, axes = plt.subplots(2, 1, figsize=(14, 8)) + +im0 = axes[0].imshow(log_mel.T, aspect='auto', origin='lower', cmap='inferno', + extent=[0, duration, 0, n_mels]) +axes[0].set_title('原始对数梅尔频谱图') +axes[0].set_xlabel('时间 (s)'); axes[0].set_ylabel('梅尔频带') +plt.colorbar(im0, ax=axes[0], label='对数能量') + +im1 = axes[1].imshow(log_mel_aug.T, aspect='auto', origin='lower', cmap='inferno', + extent=[0, duration, 0, n_mels]) +axes[1].set_title('SpecAugment 后(频率 + 时间掩码)') +axes[1].set_xlabel('时间 (s)'); axes[1].set_ylabel('梅尔频带') +plt.colorbar(im1, ax=axes[1], label='对数能量') + +plt.tight_layout(); plt.show() +``` diff --git a/chapter 09: audio and speech/03. text to speech and voice.md b/chapter 09: audio and speech/03. text to speech and voice.md new file mode 100644 index 0000000..16b1a09 --- /dev/null +++ b/chapter 09: audio and speech/03. text to speech and voice.md @@ -0,0 +1,711 @@ +# 语音合成与声音 + +*语音合成(Text-to-Speech Synthesis)逆向执行 ASR 流水线,从书面文本生成自然听感的音频。本文涵盖 TTS 流水线(文本规范化、G2P、声学模型、声码器)、Tacotron、WaveNet、HiFi-GAN、声音克隆、声音转换以及语音活动检测(VAD)。* + +- 在文件 01 中,我们构建了信号处理工具包:波形、语谱图、梅尔滤波器组和 MFCC。在文件 02 中,我们将语音转换为文本。现在我们反方向操作:给定文本,合成自然听感的语音。这就是**语音合成(TTS)**,一个同样通向声音转换、声音克隆和语音活动检测的问题。 + +- 将 TTS 想象成一场舞台表演。剧本就是文本输入。导演(声学模型)决定每句台词应该如何发音——音高、时长、重音。管弦乐队(声码器)随后演奏乐谱,产生听众实际听到的声波。现代神经 TTS 用媲美人类说话者的演绎,取代了基于规则系统那种僵硬、机械的发音。 + +![TTS 流水线:文本被规范化、转换为音素、由声学模型处理生成梅尔语谱图,然后通过声码器生成最终波形](../images/tts_pipeline.svg) + +- **语音合成流水线** 标准 TTS 流水线包含四个阶段:(1) 文本规范化,(2) 音素转换,(3) 声学模型,(4) 声码器。一些现代系统将阶段 3 和 4 合并为一个端到端模型,但这种概念分解仍然有用。 + +- **文本规范化** 将原始文本转换为可发音的形式。缩写展开("Dr."变为"Doctor")、数字变为词语("1984"变为"nineteen eighty-four")、货币符号被口头发音("$5"变为"five dollars"),以及处理 URL 或特殊字符。这一阶段通常基于规则和语言特定文法,不过也存在神经规范化模型。此处的错误会传播到所有下游阶段:如果"St."被读作"saint"而不是"street",整个发音就错了。 + +- **字素到音素(G2P)转换** 将规范化文本映射为音素序列。英语尤其不规则("though"、"through"、"tough"中的"ough"发音各不相同),因此词典查找(CMU 发音词典)处理常见词语,而神经序列到序列模型(第 06 章的编码器-解码器或第 07 章的 Transformer)处理词汇表外的词语。浅层正字法语言(西班牙语、芬兰语)需要更简单的 G2P。输出通常是 IPA(国际音标)序列或等效的内部音素集合。 + +- **声学模型** 接收音素序列并产生中间声学表示,几乎总是**梅尔语谱图**(文件 01)。梅尔语谱图捕获每个时间帧的频谱包络,编码了声码器重构波形所需的感知相关信息。声学模型必须决定时长(每个音素持续多久)、音高(基频 $F_0$)和能量(响度)。 + +- **声码器** 接收梅尔语谱图并产生原始音频波形。这是一个不适定的反演问题:由于相位信息已被丢弃,许多波形可以产生相同的语谱图。经典声码器(Griffin-Lim、WORLD)使用迭代或信号模型方法,但神经声码器现在在质量上占主导地位。 + +- **声码器:WaveNet**(van den Oord 等人,2016)是第一个生成几乎与人类录音无法区分的语音的神经声码器。它自回归地对波形建模,预测每个样本 $x_t$ 的条件概率依赖于所有先前样本: + +$$P(x) = \prod_{t=1}^{T} P(x_t \mid x_1, \ldots, x_{t-1}, c)$$ + +- 其中 $c$ 是条件信号(梅尔语谱图)。每个样本是 16 位,因此对 65536 个值进行朴素 softmax 是不切实际的。WaveNet 使用 **μ-law 压扩** 减少到 256 个量化级别,或者后来的变体使用 logistics 混合分布。 + +- WaveNet 的核心构建模块是**扩张因果卷积**。因果意味着滤波器权重只看过去样本(无未来泄露)。扩张意味着滤波器以指数增长的间隔跳过样本:扩张因子 $1, 2, 4, 8, \ldots, 512$。这提供了指数级大的感受野,同时保持参数量线性增长。 + +- 每层的门控激活函数为: + +$$z = \tanh(W_{f} \ast x) \odot \sigma(W_{g} \ast x)$$ + +- 其中 $W_f$ 和 $W_g$ 是滤波器和门控卷积权重,$\ast$ 表示扩张因果卷积,$\odot$ 是逐元素乘法。这种门控机制(来自第 06 章的 LSTM)允许网络控制信息流。 + +- WaveNet 产生卓越的质量,但推理速度极慢:生成一秒 24 kHz 音频需要 24000 次顺序前向传播。这推动了所有后续声码器研究。 + +- **WaveRNN**(Kalchbrenner 等人,2018)用单层循环网络取代了 WaveNet 的深层卷积堆叠。它将每个 16 位样本拆分为粗(高 8 位)和细(低 8 位)分量,使用 GRU(第 06 章)预测每个分量。这种双 softmax 方法显著减少了计算量,同时保持了高质量。经过精心内核优化后,WaveRNN 在移动 CPU 上足以实现实时运行。 + +- **WaveGlow**(Prenger 等人,2019)是一种基于**流**的声码器,完全避免了自回归生成。它使用一系列可逆变换(仿射耦合层,第 06 章的正则化流)将简单高斯分布映射到波形分布。训练使用变量变换公式最大化精确对数似然: + +$$\log P(x) = \log P(z) + \sum_{i} \log \left| \det \frac{\partial f_i}{\partial f_{i-1}} \right|$$ + +- 其中 $z = f(x)$ 是通过将 $x$ 传递经流得到的潜在变量。推理时,抽取样本 $z \sim \mathcal{N}(0, I)$ 并通过逆流以单次并行前向传播推出。WaveGlow 用模型大小(耦合层的大网络)换取生成速度。 + +- **HiFi-GAN**(Kong 等人,2020)使用**生成对抗网络**从梅尔语谱图合成波形。生成器通过一系列转置卷积对梅尔语谱图进行上采样,每个卷积后跟一个**多感受野融合(MRF)**模块。MRF 模块并行应用多个具有不同核大小和扩张率的残差块,然后将它们的输出求和。这使得生成器能够同时捕获多个时间尺度的模式。 + +![HiFi-GAN 生成器架构:梅尔语谱图输入经过转置卷积上采样层,每层后跟多感受野融合块,这些融合块组合了具有不同扩张模式的并行残差堆叠](../images/hifi_gan_generator.svg) + +- HiFi-GAN 使用两种鉴别器类型。**多周期鉴别器(MPD)**通过以不同周期(2、3、5、7、11)折叠一维波形,将其重塑为二维,然后应用二维卷积。这捕获了不同基频下的周期结构。**多尺度鉴别器(MSD)**在原始波形、2 倍降采样和 4 倍降采样版本上操作,捕获不同时间分辨率下的模式。 + +- 训练目标结合了对抗损失、**梅尔语谱图重构损失**(合成音频与真实音频的梅尔语谱图之间的 L1 距离)和**特征匹配损失**(中间鉴别器特征之间的 L1 距离): + +$$\mathcal{L}_G = \mathcal{L}_{\text{adv}}(G) + \lambda_{\text{mel}} \mathcal{L}_{\text{mel}}(G) + \lambda_{\text{fm}} \mathcal{L}_{\text{fm}}(G)$$ + +- HiFi-GAN 实现了与 WaveNet 相当的合成质量,同时速度提升超过 1000 倍,可在单个 GPU 上实现实时生成。 + +- **神经源-滤波器(NSF)模型**将传统信号处理与神经网络相结合。在经典源-滤波器模型中,浊音由声源激励(基频 $F_0$ 处的周期脉冲序列)通过声道滤波器(频谱包络)产生。NSF 模型用神经网络替代手工设计的滤波器,同时保留显式源信号。输入的 $F_0$ 轮廓提供了纯数据驱动声码器有时难以处理的精细音高控制。 + +- **声学模型:Tacotron**(Wang 等人,2017)是第一个直接将字符序列转换为梅尔语谱图的端到端神经 TTS 系统。它使用带注意力机制的编码器-解码器架构(第 07 章)。编码器使用卷积库、高速网络和双向 GRU 处理字符/音素序列。解码器是一个自回归 GRU,逐个预测梅尔帧,使用前一帧和注意力上下文作为输入。 + +- **Tacotron 2**(Shen 等人,2018)显著改进了架构。编码器是一个 3 层一维卷积堆叠后跟双向 LSTM(第 06 章)。解码器是一个 2 层 LSTM,带**位置敏感注意力**,该注意力机制不仅基于编码器输出和解码器状态,还基于先前步骤累积的注意力权重来条件化。这防止了注意力跳过或重复词语的常见失败模式。 + +![Tacotron 2 架构:字符/音素编码器包含卷积层和 BiLSTM,位置敏感注意力对齐到梅尔语谱图帧,自回归解码器包含停止标记预测](../images/tacotron2_architecture.svg) + +- 解码器步骤 $i$ 下编码器位置 $j$ 的位置敏感注意力能量为: + +$$e_{i,j} = w^T \tanh(W_s s_{i-1} + W_h h_j + W_f f_{i,j} + b)$$ + +- 其中 $s_{i-1}$ 是前一个解码器状态,$h_j$ 是位置 $j$ 处的编码器输出,$f_{i,j}$ 是通过将累积注意力权重 $\sum_{k 通道 + params['input_w'] = jr.normal(keys[0], (7, n_mels, channels)) * 0.02 + params['input_b'] = jnp.zeros(channels) + + # 上采样块(转置卷积) + in_ch = channels + for i, rate in enumerate(upsample_rates): + k_size = rate * 2 + scale = jnp.sqrt(2.0 / (in_ch * k_size)) + out_ch = in_ch // 2 + params[f'up{i}_w'] = jr.normal(keys[i+1], (k_size, in_ch, out_ch)) * scale + params[f'up{i}_b'] = jnp.zeros(out_ch) + # 每个尺度下的残差块 + params[f'res{i}_0'] = init_residual_block(jr.fold_in(keys[i+4], 0), + out_ch, 3, 1) + params[f'res{i}_1'] = init_residual_block(jr.fold_in(keys[i+4], 1), + out_ch, 3, 3) + in_ch = out_ch + + # 输出投影到单声道波形 + params['output_w'] = jr.normal(keys[8], (7, in_ch, 1)) * 0.02 + params['output_b'] = jnp.zeros(1) + params['upsample_rates'] = upsample_rates + + return params + +def generator_forward(params, mel): + """mel: (batch, time, n_mels) -> waveform: (batch, time * prod(rates), 1)。""" + # 输入投影 + h = jax.lax.conv_general_dilated( + mel.transpose(0, 2, 1), + params['input_w'].transpose(2, 1, 0), + window_strides=(1,), padding='SAME' + ).transpose(0, 2, 1) + params['input_b'] + + for i, rate in enumerate(params['upsample_rates']): + h = jax.nn.leaky_relu(h, negative_slope=0.1) + # 通过转置卷积上采样 + k_size = rate * 2 + h = jax.lax.conv_transpose( + h.transpose(0, 2, 1), + params[f'up{i}_w'].transpose(2, 1, 0), + strides=(rate,), + padding='SAME' + ).transpose(0, 2, 1) + params[f'up{i}_b'] + # 残差块 + h = residual_block(params[f'res{i}_0'], h) + h = residual_block(params[f'res{i}_1'], h) + + h = jax.nn.leaky_relu(h, negative_slope=0.1) + out = jax.lax.conv_general_dilated( + h.transpose(0, 2, 1), + params['output_w'].transpose(2, 1, 0), + window_strides=(1,), padding='SAME' + ).transpose(0, 2, 1) + params['output_b'] + + return jnp.tanh(out) + +# 创建一个合成梅尔语谱图(模拟元音) +n_mels = 80 +n_frames = 50 +mel = jnp.zeros((1, n_frames, n_mels)) +# 在低频梅尔频带中添加能量(模拟共振峰) +mel = mel.at[:, :, 5:15].set(1.0) +mel = mel.at[:, :, 20:25].set(0.6) + +# 初始化并运行生成器 +key = jr.PRNGKey(42) +params = init_generator(key, n_mels=n_mels, upsample_rates=(8, 8, 4), + channels=128) +waveform = generator_forward(params, mel) + +print(f"输入梅尔形状:{mel.shape}") +print(f"输出波形形状:{waveform.shape}") +print(f"上采样因子:{8 * 8 * 4} = {8*8*4}x") + +fig, axes = plt.subplots(2, 1, figsize=(12, 6)) + +axes[0].imshow(mel[0].T, aspect='auto', origin='lower', cmap='magma') +axes[0].set_title('输入梅尔语谱图') +axes[0].set_ylabel('梅尔频带') +axes[0].set_xlabel('帧') + +waveform_np = waveform[0, :, 0] +axes[1].plot(waveform_np[:2000], color='#9b59b6', linewidth=0.5) +axes[1].set_title('生成器输出波形(未经训练 - 随机噪声)') +axes[1].set_ylabel('振幅') +axes[1].set_xlabel('样本') + +plt.tight_layout() +plt.show() +print("注意:输出是噪声,因为生成器未经训练。") +print("在实践中,对抗损失 + 梅尔损失训练会将其塑造成语音。") +``` + +- **任务 4:使用简单 RNN 的语音活动检测。** 在合成音频特征上训练一个基于小型 GRU 的 VAD 模型,对帧进行语音或静音分类。 + +```python +import jax +import jax.numpy as jnp +import jax.random as jr +import matplotlib.pyplot as plt + +# 生成合成对数梅尔能量特征及语音/静音标签 +def generate_vad_data(key, n_sequences=100, n_frames=200, n_features=40): + """模拟对数梅尔特征:语音区域能量更高且具有结构。""" + keys = jr.split(key, 5) + all_features = [] + all_labels = [] + + for i in range(n_sequences): + k = jr.fold_in(keys[0], i) + k1, k2, k3 = jr.split(k, 3) + + # 随机语音/静音模式 + label = jnp.zeros(n_frames) + n_segments = jr.randint(k1, (), 2, 6) + for seg in range(int(n_segments)): + start = jr.randint(jr.fold_in(k2, seg), (), 0, n_frames - 20) + length = jr.randint(jr.fold_in(k3, seg), (), 10, 50) + end = jnp.minimum(start + length, n_frames) + label = label.at[int(start):int(end)].set(1.0) + + # 特征:语音帧具有更高能量 + 频谱结构 + noise = jr.normal(jr.fold_in(keys[1], i), (n_frames, n_features)) * 0.3 + speech_pattern = jnp.outer(label, jnp.exp(-jnp.arange(n_features) / 15.0)) + features = speech_pattern * 2.0 + noise + 0.1 + + all_features.append(features) + all_labels.append(label) + + return jnp.stack(all_features), jnp.stack(all_labels) + +key = jr.PRNGKey(123) +features, labels = generate_vad_data(key) +train_features, train_labels = features[:80], labels[:80] +test_features, test_labels = features[80:], labels[80:] + +# 基于 GRU 的简单 VAD 模型 +def init_vad_model(key, input_dim=40, hidden_dim=64): + keys = jr.split(key, 6) + scale_ih = jnp.sqrt(2.0 / input_dim) + scale_hh = jnp.sqrt(2.0 / hidden_dim) + return { + 'W_z': jr.normal(keys[0], (input_dim, hidden_dim)) * scale_ih, + 'U_z': jr.normal(keys[1], (hidden_dim, hidden_dim)) * scale_hh, + 'b_z': jnp.zeros(hidden_dim), + 'W_r': jr.normal(keys[2], (input_dim, hidden_dim)) * scale_ih, + 'U_r': jr.normal(keys[3], (hidden_dim, hidden_dim)) * scale_hh, + 'b_r': jnp.zeros(hidden_dim), + 'W_h': jr.normal(keys[4], (input_dim, hidden_dim)) * scale_ih, + 'U_h': jr.normal(keys[5], (hidden_dim, hidden_dim)) * scale_hh, + 'b_h': jnp.zeros(hidden_dim), + 'W_out': jr.normal(jr.fold_in(keys[0], 99), (hidden_dim, 1)) * 0.1, + 'b_out': jnp.zeros(1), + } + +def gru_step(params, h, x): + """单步 GRU。""" + z = jax.nn.sigmoid(x @ params['W_z'] + h @ params['U_z'] + params['b_z']) + r = jax.nn.sigmoid(x @ params['W_r'] + h @ params['U_r'] + params['b_r']) + h_tilde = jnp.tanh(x @ params['W_h'] + (r * h) @ params['U_h'] + params['b_h']) + h_new = (1 - z) * h + z * h_tilde + return h_new + +def vad_forward(params, x): + """x: (batch, time, features) -> logits: (batch, time)。""" + batch_size, n_frames, _ = x.shape + hidden_dim = params['W_z'].shape[1] + h = jnp.zeros((batch_size, hidden_dim)) + + outputs = [] + for t in range(n_frames): + h = gru_step(params, h, x[:, t, :]) + logit = (h @ params['W_out'] + params['b_out']).squeeze(-1) + outputs.append(logit) + + return jnp.stack(outputs, axis=1) + +def bce_loss(params, features, labels): + """VAD 的二元交叉熵损失。""" + logits = vad_forward(params, features) + probs = jax.nn.sigmoid(logits) + probs = jnp.clip(probs, 1e-7, 1 - 1e-7) + loss = -(labels * jnp.log(probs) + (1 - labels) * jnp.log(1 - probs)) + return jnp.mean(loss) + +grad_fn = jax.jit(jax.value_and_grad(bce_loss)) + +# 训练 +params = init_vad_model(jr.PRNGKey(0)) +lr = 5e-3 +losses = [] + +for epoch in range(200): + loss_val, grads = grad_fn(params, train_features, train_labels) + params = jax.tree.map(lambda p, g: p - lr * g, params, grads) + losses.append(float(loss_val)) + if epoch % 50 == 0: + print(f"轮次 {epoch}:损失 = {loss_val:.4f}") + +# 在测试集上评估 +test_logits = vad_forward(params, test_features) +test_preds = (jax.nn.sigmoid(test_logits) > 0.5).astype(jnp.float32) +accuracy = jnp.mean(test_preds == test_labels) +print(f"\n测试准确率:{accuracy:.4f}") + +# 可视化一个测试示例 +idx = 0 +fig, axes = plt.subplots(3, 1, figsize=(14, 7)) + +axes[0].imshow(test_features[idx].T, aspect='auto', origin='lower', cmap='magma') +axes[0].set_title('对数梅尔能量特征') +axes[0].set_ylabel('梅尔频带') + +axes[1].fill_between(range(200), test_labels[idx], alpha=0.4, color='#27ae60', + label='真实值') +axes[1].plot(jax.nn.sigmoid(test_logits[idx]), color='#e74c3c', + linewidth=1.5, label='预测概率') +axes[1].axhline(0.5, color='gray', linestyle='--', linewidth=0.8) +axes[1].set_ylabel('语音概率') +axes[1].legend() +axes[1].set_title('VAD 预测') + +axes[2].fill_between(range(200), test_labels[idx], alpha=0.4, color='#27ae60', + label='真实值') +axes[2].fill_between(range(200), test_preds[idx], alpha=0.4, color='#f39c12', + label='预测(阈值=0.5)') +axes[2].set_ylabel('语音 / 静音') +axes[2].set_xlabel('帧') +axes[2].legend() +axes[2].set_title('VAD 二值决策') + +plt.tight_layout() +plt.show() +``` diff --git a/chapter 09: audio and speech/04. speaker and audio analysis.md b/chapter 09: audio and speech/04. speaker and audio analysis.md new file mode 100644 index 0000000..0fa15a2 --- /dev/null +++ b/chapter 09: audio and speech/04. speaker and audio analysis.md @@ -0,0 +1,643 @@ +# 说话人与音频分析 + +*说话人与音频分析识别谁在说话、何时说话以及存在哪些非语言声音。本文涵盖说话人确认与识别、i向量、d向量、x向量、说话人日志、音频事件分类、音乐信息检索以及语音情感识别。* + +- 在文件 01 中,我们构建了信号处理基础:语谱图、MFCC 和梅尔滤波器组。在文件 02 中,我们识别了所说的内容。现在我们要问:是谁说的、何时说的、以及音频中还在发生什么。说话人识别、说话人日志、音频分类和音乐分析都共享一条主线:学习能够为当前任务捕捉正确不变性的紧凑嵌入,这与第 06 章中的嵌入思想一脉相承。 + +- 可以把说话人识别想象成在电话中辨认朋友的声音。你不需要理解词汇;某种关于音色、语速和嗓音特质的东西对这个人来说是独一无二的。说话人识别系统学会从原始音频中提取这种"声纹",忽略说的是什么,专注于怎么说的。 + +- **说话人识别**是两类相关任务的总称: + - **说话人确认**(SV):给定一个声明的身份和一段音频片段,判断说话人是否与其声称的身份一致。这是一个二元决策(接受或拒绝),是基于语音的身份验证技术("嘿 Siri,这是我的声音吗?")背后的核心原理。 + - **说话人识别**(SI):给定一段音频片段和一个已知说话人库,判断该片段由哪个说话人产生。这是一个多分类问题。 + +![说话人确认:注册音频被嵌入,测试音频被嵌入,计算嵌入之间的余弦相似度,通过阈值决定接受或拒绝](../images/speaker_verification.svg) + +- 两种任务共享相同的底层表示:一个固定维度的**说话人嵌入**,它捕捉说话人的身份特征而与所说内容无关。区别仅在于决策阶段:确认比较两个嵌入,识别则在候选嵌入中找到最近邻。 + +- **余弦相似度**是比较说话人嵌入的标准度量。给定注册嵌入 $e$ 和测试嵌入 $t$: + +$$s = \frac{e \cdot t}{\|e\| \, \|t\|}$$ + +- 阈值 $\theta$ 决定接受/拒绝决策:若 $s > \theta$,则接受。阈值在**错误接受率(FAR)**和**错误拒绝率(FRR)**之间权衡。**等错误率(EER)**,即 FAR = FRR 时的值,是标准评估指标。EER 越低表示性能越好。最先进的系统在标准基准(VoxCeleb)上可实现低于 1% 的 EER。 + +- **i向量**(Dehak 等人,2010)是深度学习之前主导性的说话人嵌入方法。其思想源于因子分析(第 02 章的矩阵分解和第 04 章的降维)。一个**通用背景模型(UBM)**——基于多样本说话人训练的大型 GMM——定义了一个超向量空间。每条语音的 GMM 超向量被投影到低维的**全可变性空间**: + +$$M = m + Tw$$ + +- 其中 $M$ 是该语音的 GMM 超向量,$m$ 是 UBM 均值超向量,$T$ 是全可变性矩阵(从数据中学习得到),$w$ 是 i 向量,一个低维(通常为 400-600 维)表示,同时捕捉说话人变异和信道变异。 + +- 为了从 i 向量中去除信道变异,**概率线性判别分析(PLDA)**将 i 向量建模为说话人特定潜变量和信道特定潜变量之和。PLDA 为确认任务提供了一个有原则的对数似然比分数: + +$$\text{score}(w_1, w_2) = \log \frac{P(w_1, w_2 \mid \text{同一说话人})}{P(w_1 \mid \text{说话人}_1) \, P(w_2 \mid \text{说话人}_2)}$$ + +- **d向量**(Variani 等人,2014)是第一个神经说话人嵌入。一个为说话人分类训练的 DNN 处理帧级特征,通过对整条语音中最后一层隐藏层激活值求平均,提取出固定维度的表示。虽然简单但有效,d向量证明了神经网络可以在没有 i 向量复杂统计机制的情况下学习到说话人判别性特征。 + +- **x向量**(Snyder 等人,2018)使用**时延神经网络(TDNN)**架构显著推进了神经说话人嵌入。TDNN 是具有特定上下文窗口的 1D 卷积,与文件 03 中 WaveNet 的扩张卷积有关,但应用于帧级特征而非原始波形样本。 + +![x向量架构:TDNN 层以递增的上下文处理帧级特征,统计池化在时间维度上聚合,全连接层产生说话人嵌入](../images/xvector_architecture.svg) + +- x向量架构包含三个阶段: + - **帧级层**:一组 TDNN 层处理 MFCC(来自文件 01),时间上下文逐步扩大。每一层都有一个固定的上下文窗口(例如第一层为 $\{t-2, t-1, t, t+1, t+2\}$,后续层窗口更宽)。 + - **统计池化**:在帧级层之后,计算帧级输出在整个语音上的均值和标准差,产生一个与语音时长无关的固定维度向量: + +```math +\begin{aligned} +\mu &= \frac{1}{T} \sum_{t=1}^{T} h_t \\ +\sigma &= \sqrt{\frac{1}{T} \sum_{t=1}^{T} (h_t - \mu)^2} +\end{aligned} +``` + +- 其中 $h_t$ 是时间 $t$ 的帧级输出。拼接 $[\mu; \sigma]$ 即为池化后的表示。 + - **段级层**:全连接层处理池化后的表示。第一个段级层的输出(softmax 之前)即为 x 向量嵌入。 + +- x向量使用说话人身份上的标准交叉熵损失进行训练。尽管是为分类任务训练的,但学习到的中间表示(x向量)能很好地泛化到未见过的说话人,因为网络学习的是提取说话人判别性特征,而非记忆特定说话人。 + +- **ECAPA-TDNN**(Desplanques 等人,2020)是目前最先进的基于 TDNN 的说话人识别架构。它在 x 向量基础上引入了三项改进: + - **压缩激励(SE)模块**:通道注意力(来自第 08 章的 SENet),根据全局上下文重新加权特征通道,使模型能够强调与说话人相关的通道。 + - **Res2Net 风格的多尺度特征**:在每个 TDNN 模块内,通道被分成若干组,以层级方式处理,在多个时间分辨率上创建特征(类似于第 08 章的多尺度特征提取)。 + - **注意力统计池化**:不再使用等权平均,而是通过注意力机制为每一帧对池化统计量的贡献分配权重。包含更多说话人判别性内容的帧(如元音,承载更多说话人信息)获得更高的注意力权重: + +$$\alpha_t = \frac{\exp(v^T f(h_t))}{\sum_{\tau} \exp(v^T f(h_\tau))}$$ + +- 其中 $f$ 是一个小型神经网络,$v$ 是一个学习到的注意力向量。注意力加权的均值和标准差变为 $\tilde{\mu} = \sum_t \alpha_t h_t$ 和 $\tilde{\sigma} = \sqrt{\sum_t \alpha_t (h_t - \tilde{\mu})^2}$。 + +- ECAPA-TDNN 通常使用 **AAM-Softmax**(附加角度间隔 Softmax)进行训练,它在分类损失中添加了角度间隔惩罚,将同一说话人的嵌入推得更近,不同说话人的嵌入在超球面上推得更远: + +$$L = -\log \frac{e^{s \cos(\theta_{y_i} + m)}}{e^{s \cos(\theta_{y_i} + m)} + \sum_{j \neq y_i} e^{s \cos \theta_j}}$$ + +- 其中 $\theta_{y_i}$ 是嵌入与真实类别权重向量之间的夹角,$m$ 是间隔(通常为 0.2),$s$ 是缩放因子(通常为 30)。该损失函数来自人脸识别(第 08 章的 ArcFace),在说话人确认中非常有效。 + +- **说话人日志**回答了多方录音中"谁在什么时候说话"的问题。可以把这想象成给时间线上色:每种颜色代表一个不同的说话人,系统必须确定每个说话人何时活跃,包括重叠语音的情况。 + +![说话人日志:音频时间线被分割并用说话人身份标注,展示交替说话和重叠区域](../images/speaker_diarisation.svg) + +- **基于聚类的说话人日志**是传统的流水线方法: + - **分割**:将音频划分为短段(通常为 1-2 秒),使用滑动窗口或说话人变化检测。 + - **嵌入提取**:为每个片段提取说话人嵌入(x向量、ECAPA-TDNN)。 + - **聚类**:按说话人对片段进行分组。**凝聚层次聚类(AHC)**是标准方法:开始时每个片段自成一类,然后迭代合并两个最相似的类,直到满足停止条件(基于距离阈值或目标说话人数)。 + - **重分割**:使用基于维特比算法的重对齐来优化边界。 + +- 说话人数量通常事先未知,这使得该问题比标准聚类更困难。使用基于特征值阈值确定 $k$ 的谱聚类是另一种常见方法。 + +- **端到端神经说话人日志(EEND)**(Fujita 等人,2019)将说话人日志框架化为一个多标签分类问题。一个神经网络(通常是基于自注意力的模型,第 07 章的 transformer)将整段录音作为输入,为每一帧输出每个说话人的二元活动标签。这直接处理了重叠语音,而这是基于聚类方法的主要弱点。 + +- EEND 对 $S$ 个说话人在帧 $t$ 的输出为: + +$$\hat{y}_{t,s} = \sigma(f_s(h_t))$$ + +- 其中 $h_t$ 是帧 $t$ 处的 transformer 输出,$f_s$ 是说话人 $s$ 的线性投影。训练损失是在说话人和帧上求和得到的二元交叉熵。一个关键挑战是说话人数量必须固定,或者使用可变输出架构(EEND-EDA 使用带吸引子的编码器-解码器)来处理。 + +- **置换不变训练(PIT)**用于处理说话人日志中的标签歧义问题:由于说话人没有固有顺序,需要对所有可能的说话人到输出分配计算损失,并取最小值(这与文件 05 中源分离使用的 PIT 相同)。 + +- **音频分类**为整段音频片段分配一个标签。与转录语音的 ASR(文件 02)不同,音频分类涵盖更广的范围:环境声音(警笛、雨声、狗吠)、音乐流派(摇滚、爵士、古典)以及一般音频事件。 + +- 标准方法遵循第 08 章的图像分类范式:将音频表示为语谱图(一个二维时间-频率图像),然后应用 CNN 或 transformer 分类器。这种谱图-图像方法利用了计算机视觉几十年来的进展。 + +- **环境声音分类(ESC)**使用 ESC-50(50 类,2000 个片段)和 UrbanSound8K 等数据集。典型架构是应用于对数梅尔语谱图的 CNN(第 06 章)。数据增强至关重要:时间拉伸、音高偏移、添加背景噪声以及 **SpecAugment**(文件 02 的掩码方法应用于语谱图)都能提升泛化能力。 + +- **音频事件检测**(声音事件检测,SED)是分类的时间维度对应任务:不仅仅要知道存在哪些事件,还要知道它们何时开始和结束。**AudioSet**(Gemmeke 等人,2017)是大规模基准,包含 527 个事件类别和超过 200 万个来自 YouTube 的 10 秒片段,每个片段都有弱标注(片段级标签,而非帧级)。 + +- **弱监督 SED** 必须从片段级标签学习帧级预测。标准方法使用 CNN 产生帧级类别概率,然后通过注意力池化聚合成片段级预测: + +$$\hat{Y}_c = \sigma\left(\sum_t \alpha_{t,c} \cdot f_{t,c}\right)$$ + +- 其中 $f_{t,c}$ 是类别 $c$ 在时间 $t$ 的帧级 logit,$\alpha_{t,c}$ 是注意力权重。片段级预测 $\hat{Y}_c$ 根据片段级标签进行训练。 + +- **声学场景分类(ASC)**对整体环境进行分类:"机场"、"公园"、"地铁站"、"办公室"。这是一个整体性任务:模型必须捕捉一般的声学纹理而非特定事件。DCASE 挑战系列每年对 ASC 进行基准测试,获奖系统通常使用多分辨率语谱图上的 CNN 集成。 + +- **音频嵌入**是从大规模音频数据中学习到的通用表示,类似于可迁移到下游任务的词嵌入(第 07 章)或图像特征(第 08 章)。 + +- **VGGish**(Hershey 等人,2017)将 VGG 图像分类网络(第 08 章)适配到音频领域。它通过一个在 AudioSet 上预训练的类 VGG CNN 处理 0.96 秒的对数梅尔语谱图块,每块产生一个 128 维嵌入。VGGish 嵌入可作为下游任务的通用音频特征,类似于 ImageNet 预训练 CNN 提供视觉特征的方式。 + +- **PANNs**(预训练音频神经网络,Kong 等人,2020)是一系列 CNN 架构(CNN6、CNN10、CNN14),在完整的 AudioSet 上为音频标记任务训练。CNN14 使用最广泛,是一个 14 层 CNN,将对数梅尔语谱图作为输入,使用 $3 \times 3$ 卷积。PANNs 产生 2048 维嵌入,在多种音频任务上实现了最先进的迁移学习性能。 + +- **音频语谱图 Transformer(AST)**(Gong 等人,2021)将视觉 Transformer(ViT,第 08 章)架构直接应用于音频语谱图。语谱图被分割成 $16 \times 16$ 的块(就像 ViT 分割图像一样),每个块被线性投影为令牌嵌入,添加位置嵌入,然后由标准 Transformer 编码器(第 07 章)处理序列。[CLS] 令牌的输出用于分类。 + +![音频语谱图 Transformer:梅尔语谱图被分割成块,每个块展平并线性投影为令牌,添加位置嵌入,Transformer 编码器通过 CLS 令牌产生分类输出](../images/audio_spectrogram_transformer.svg) + +- AST 受益于 **ImageNet 预训练**:由于语谱图是 2D 图像,AST 从 ImageNet 图像上预训练的 ViT 初始化,然后在音频上微调。这种跨模态迁移出奇地有效,因为两个域共享低级特征(边缘、纹理),并且位置嵌入可以插值以处理不同大小的语谱图。 + +- **HTS-AT**(Chen 等人,2022)使用分层 Swin Transformer 架构(第 08 章的移位窗口注意力)改进了 AST,在降低计算成本的同时通过多尺度特征提取提升了性能。 + +- **BEATs**(Chen 等人,2023)使用了一种音频特定的预训练策略:使用离散标记器进行迭代掩码预测(类似于文件 02 中 wav2vec 2.0 的方法,但应用于通用音频)。标记器逐步细化,创建越来越具有语义意义的离散音频令牌。 + +- **基于嵌入的说话人日志**结合了说话人嵌入与时序建模。像 Pyannote.audio 这样的现代系统使用三阶段流水线:(1) 检测说话人切换和重叠语音的神经分割模型,(2) 应用于每个检测到的片段的嵌入提取阶段(ECAPA-TDNN),以及 (3) 聚类以在整个录音中分配说话人身份。 + +- **音乐信息检索(MIR)**将音频分析应用于音乐。文件 01 中的谱图表示在这里尤其有用,因为音乐具有丰富的和声结构。 + +- **节拍跟踪**检测音乐的节奏脉冲。标准方法从语谱图计算**起始强度包络**(检测表示音符起始的能量增加),然后使用自相关或节拍图谱找到节奏,最后使用动态规划跟踪单个节拍位置,找到最能匹配起始包络同时保持稳定节奏的节拍时间序列。 + +- **和弦识别**识别随时间变化的和声内容。输入通常是**色度图**(也称为音高类别分布图):一个 12 维表示,将所有八度折叠在一起,显示 12 个音高类别(C、C#、D、…、B)中每个类别的能量。CNN 或 RNN(第 06 章)将每个时间帧分类到标准和弦标签之一(C 大调、A 小调、G7 等)。 + +- 色度图通过将每个频率区间映射到其音高类别,从 STFT(文件 01)计算得到: + +$$\text{chroma}(p) = \sum_{k : \text{pitch}(k) \bmod 12 = p} |X(k)|^2$$ + +- 其中 $p \in \{0, 1, \ldots, 11\}$ 是音高类别,$\text{pitch}(k)$ 将频率区间 $k$ 映射到其 MIDI 音符编号。 + +- **源分离基础**(详见文件 05)将音乐录音分离为单独的乐器(人声、鼓、贝斯、其他)。这是混音、卡拉 OK 和音乐转录等 MIR 应用的核心。像 Demucs(文件 05)这样的模型在标准 MUSDB18 基准上达到了非常好的分离质量。 + +- **音乐标记**为歌曲分配标签(流派、情感、乐器、时代)。它本质上是应用于音乐的音频分类,使用相同的 CNN-语谱图方法。Million Song Dataset 和 MagnaTagATune 是标准基准。 + +- **音频指纹**从短片段中识别特定录音,即使存在噪声、混响或压缩伪影。经典系统是 Shazam,它对星座图(语谱图中的显著峰值)进行哈希处理。神经方法学习对声学退化具有不变性、同时对不同录音保持判别性的鲁棒嵌入,这与第 06 章和第 08 章中的不变特征学习一脉相承。 + +## 编程任务(使用 Colab 或笔记本) + +- **任务 1:带统计池化的说话人嵌入提取。** 构建一个简单的 x向量风格模型,通过 TDNN 层和统计池化处理帧级特征以产生说话人嵌入。 + +```python +import jax +import jax.numpy as jnp +import jax.random as jr +import matplotlib.pyplot as plt + +# Simulate frame-level MFCC features for multiple speakers +def generate_speaker_data(key, n_speakers=5, utterances_per_speaker=20, + n_frames=100, n_features=40): + """Generate synthetic speaker data with speaker-dependent patterns.""" + keys = jr.split(key, 3) + all_features = [] + all_labels = [] + + # Each speaker has a characteristic spectral pattern + speaker_patterns = jr.normal(keys[0], (n_speakers, n_features)) * 0.5 + + for spk in range(n_speakers): + for utt in range(utterances_per_speaker): + k = jr.fold_in(keys[1], spk * utterances_per_speaker + utt) + noise = jr.normal(k, (n_frames, n_features)) * 0.3 + features = speaker_patterns[spk][None, :] + noise + all_features.append(features) + all_labels.append(spk) + + perm = jr.permutation(keys[2], len(all_features)) + features = jnp.stack(all_features)[perm] + labels = jnp.array(all_labels)[perm] + return features, labels + +key = jr.PRNGKey(42) +features, labels = generate_speaker_data(key) +n_speakers = 5 +n_features = 40 + +# x-vector-style model +def init_xvector(key, n_features=40, hidden=128, embed_dim=64, n_speakers=5): + keys = jr.split(key, 8) + params = { + # TDNN layer 1: context [-2, 2] + 'tdnn1_w': jr.normal(keys[0], (5, n_features, hidden)) * jnp.sqrt(2.0 / (5 * n_features)), + 'tdnn1_b': jnp.zeros(hidden), + # TDNN layer 2: context [-2, 2] + 'tdnn2_w': jr.normal(keys[1], (5, hidden, hidden)) * jnp.sqrt(2.0 / (5 * hidden)), + 'tdnn2_b': jnp.zeros(hidden), + # TDNN layer 3: context [-3, 3] + 'tdnn3_w': jr.normal(keys[2], (7, hidden, hidden)) * jnp.sqrt(2.0 / (7 * hidden)), + 'tdnn3_b': jnp.zeros(hidden), + # Segment-level layers (after pooling: 2*hidden -> embed_dim) + 'seg1_w': jr.normal(keys[3], (2 * hidden, embed_dim)) * jnp.sqrt(2.0 / (2 * hidden)), + 'seg1_b': jnp.zeros(embed_dim), + # Classification head + 'cls_w': jr.normal(keys[4], (embed_dim, n_speakers)) * jnp.sqrt(2.0 / embed_dim), + 'cls_b': jnp.zeros(n_speakers), + } + return params + +def xvector_forward(params, x, return_embedding=False): + """x: (batch, frames, features) -> logits or embeddings.""" + # TDNN layers (1D convolutions) + h = jax.lax.conv_general_dilated( + x.transpose(0, 2, 1), params['tdnn1_w'].transpose(2, 1, 0), + window_strides=(1,), padding='SAME' + ).transpose(0, 2, 1) + params['tdnn1_b'] + h = jax.nn.relu(h) + + h = jax.lax.conv_general_dilated( + h.transpose(0, 2, 1), params['tdnn2_w'].transpose(2, 1, 0), + window_strides=(1,), padding='SAME' + ).transpose(0, 2, 1) + params['tdnn2_b'] + h = jax.nn.relu(h) + + h = jax.lax.conv_general_dilated( + h.transpose(0, 2, 1), params['tdnn3_w'].transpose(2, 1, 0), + window_strides=(1,), padding='SAME' + ).transpose(0, 2, 1) + params['tdnn3_b'] + h = jax.nn.relu(h) + + # Statistics pooling: mean and std over time + mu = jnp.mean(h, axis=1) + sigma = jnp.std(h, axis=1) + pooled = jnp.concatenate([mu, sigma], axis=-1) + + # Segment-level layer -> embedding + embedding = jax.nn.relu(pooled @ params['seg1_w'] + params['seg1_b']) + + if return_embedding: + return embedding + + # Classification + logits = embedding @ params['cls_w'] + params['cls_b'] + return logits + +def cross_entropy_loss(params, features, labels): + logits = xvector_forward(params, features) + one_hot = jax.nn.one_hot(labels, n_speakers) + log_probs = jax.nn.log_softmax(logits) + return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1)) + +grad_fn = jax.jit(jax.value_and_grad(cross_entropy_loss)) + +# Train +params = init_xvector(jr.PRNGKey(0)) +lr = 1e-3 +losses = [] + +for epoch in range(300): + loss_val, grads = grad_fn(params, features, labels) + params = jax.tree.map(lambda p, g: p - lr * g, params, grads) + losses.append(float(loss_val)) + +# Extract embeddings and visualise with t-SNE-style 2D projection (using PCA) +embeddings = xvector_forward(params, features, return_embedding=True) + +# Simple PCA to 2D +emb_centered = embeddings - jnp.mean(embeddings, axis=0) +_, _, Vt = jnp.linalg.svd(emb_centered, full_matrices=False) +proj_2d = emb_centered @ Vt[:2].T + +fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + +axes[0].plot(losses, color='#3498db', linewidth=1.5) +axes[0].set_xlabel('Epoch') +axes[0].set_ylabel('Cross-Entropy Loss') +axes[0].set_title('Speaker Classification Training') +axes[0].set_yscale('log') + +colors = ['#3498db', '#e74c3c', '#27ae60', '#f39c12', '#9b59b6'] +for spk in range(n_speakers): + mask = labels == spk + axes[1].scatter(proj_2d[mask, 0], proj_2d[mask, 1], c=colors[spk], + label=f'Speaker {spk}', alpha=0.7, s=30) +axes[1].set_xlabel('PC 1') +axes[1].set_ylabel('PC 2') +axes[1].set_title('Speaker Embeddings (PCA projection)') +axes[1].legend() + +plt.tight_layout() +plt.show() + +# Verification demo: cosine similarity +emb_norm = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True) +sim_matrix = emb_norm @ emb_norm.T +print(f"Embedding shape: {embeddings.shape}") +print(f"Avg same-speaker similarity: {jnp.mean(sim_matrix[labels[:, None] == labels[None, :]]):.4f}") +print(f"Avg diff-speaker similarity: {jnp.mean(sim_matrix[labels[:, None] != labels[None, :]]):.4f}") +``` + +- **任务 2:基于余弦相似度评分的说话人确认。** 给定预计算的说话人嵌入,实现一个计算 EER(等错误率)并绘制 DET 曲线的确认系统。 + +```python +import jax +import jax.numpy as jnp +import jax.random as jr +import matplotlib.pyplot as plt + +def generate_verification_pairs(key, n_speakers=20, dim=64, n_pairs=2000): + """Generate speaker embeddings and verification trial pairs.""" + keys = jr.split(key, 5) + + # Speaker centroids with some variance + centroids = jr.normal(keys[0], (n_speakers, dim)) + centroids = centroids / jnp.linalg.norm(centroids, axis=-1, keepdims=True) + + # Generate enrollment and test embeddings with intra-speaker variance + enroll_embs = [] + test_embs = [] + trial_labels = [] # 1 = same speaker (target), 0 = different (impostor) + + for i in range(n_pairs): + k1, k2, k3 = jr.split(jr.fold_in(keys[1], i), 3) + is_target = jr.bernoulli(k1).astype(int) + + spk1 = jr.randint(k2, (), 0, n_speakers) + emb1 = centroids[spk1] + jr.normal(jr.fold_in(k3, 0), (dim,)) * 0.15 + + if is_target: + spk2 = spk1 + else: + spk2 = (spk1 + jr.randint(jr.fold_in(k3, 1), (), 1, n_speakers)) % n_speakers + + emb2 = centroids[spk2] + jr.normal(jr.fold_in(k3, 2), (dim,)) * 0.15 + + enroll_embs.append(emb1) + test_embs.append(emb2) + trial_labels.append(int(is_target)) + + return (jnp.stack(enroll_embs), jnp.stack(test_embs), + jnp.array(trial_labels)) + +key = jr.PRNGKey(42) +enroll, test, labels = generate_verification_pairs(key) + +# Compute cosine similarity scores +enroll_norm = enroll / jnp.linalg.norm(enroll, axis=-1, keepdims=True) +test_norm = test / jnp.linalg.norm(test, axis=-1, keepdims=True) +scores = jnp.sum(enroll_norm * test_norm, axis=-1) + +# Compute FAR and FRR at various thresholds +thresholds = jnp.linspace(-1.0, 1.0, 500) + +target_scores = scores[labels == 1] +impostor_scores = scores[labels == 0] + +fars = [] +frrs = [] +for thresh in thresholds: + far = jnp.mean(impostor_scores >= thresh) # false accepts + frr = jnp.mean(target_scores < thresh) # false rejects + fars.append(float(far)) + frrs.append(float(frr)) + +fars = jnp.array(fars) +frrs = jnp.array(frrs) + +# Find EER: where FAR ≈ FRR +eer_idx = jnp.argmin(jnp.abs(fars - frrs)) +eer = float((fars[eer_idx] + frrs[eer_idx]) / 2) +eer_threshold = float(thresholds[eer_idx]) + +print(f"Equal Error Rate (EER): {eer:.4f} ({eer*100:.2f}%)") +print(f"EER threshold: {eer_threshold:.4f}") + +fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + +# Score distributions +bins = jnp.linspace(-0.5, 1.0, 60) +axes[0].hist(target_scores, bins=bins, alpha=0.6, color='#27ae60', + label='Target (same speaker)', density=True) +axes[0].hist(impostor_scores, bins=bins, alpha=0.6, color='#e74c3c', + label='Impostor (different speaker)', density=True) +axes[0].axvline(eer_threshold, color='#f39c12', linestyle='--', linewidth=2, + label=f'EER threshold = {eer_threshold:.3f}') +axes[0].set_xlabel('Cosine Similarity Score') +axes[0].set_ylabel('Density') +axes[0].set_title('Score Distributions') +axes[0].legend() + +# FAR vs FRR +axes[1].plot(thresholds, fars, color='#e74c3c', linewidth=2, label='FAR') +axes[1].plot(thresholds, frrs, color='#3498db', linewidth=2, label='FRR') +axes[1].axvline(eer_threshold, color='#f39c12', linestyle='--', linewidth=1.5) +axes[1].scatter([eer_threshold], [eer], color='#f39c12', s=100, zorder=5, + label=f'EER = {eer:.4f}') +axes[1].set_xlabel('Threshold') +axes[1].set_ylabel('Error Rate') +axes[1].set_title('FAR and FRR vs Threshold') +axes[1].legend() + +# DET curve (FAR vs FRR) +axes[2].plot(fars, frrs, color='#9b59b6', linewidth=2) +axes[2].plot([0, 1], [0, 1], 'k--', alpha=0.3) +axes[2].scatter([eer], [eer], color='#f39c12', s=100, zorder=5, + label=f'EER = {eer:.4f}') +axes[2].set_xlabel('False Acceptance Rate') +axes[2].set_ylabel('False Rejection Rate') +axes[2].set_title('DET Curve') +axes[2].set_xlim([0, 0.5]) +axes[2].set_ylim([0, 0.5]) +axes[2].legend() +axes[2].set_aspect('equal') + +plt.tight_layout() +plt.show() +``` + +- **任务 3:音频语谱图块嵌入(AST 风格)。** 实现音频语谱图 Transformer 的块提取和嵌入层,可视化语谱图如何被令牌化。 + +```python +import jax +import jax.numpy as jnp +import jax.random as jr +import matplotlib.pyplot as plt + +# Generate a synthetic spectrogram (harmonic structure + noise) +def generate_spectrogram(key, n_time=128, n_freq=128): + """Create a synthetic spectrogram with harmonic patterns.""" + k1, k2 = jr.split(key) + spec = jr.normal(k1, (n_time, n_freq)) * 0.1 + + # Add harmonic bands (simulating speech formants) + for f0 in [15, 30, 45, 70]: + width = 3 + envelope = jnp.exp(-0.5 * ((jnp.arange(n_freq) - f0) / width) ** 2) + time_mod = 0.5 + 0.5 * jnp.sin(2 * jnp.pi * jnp.arange(n_time) / 40) + spec += jnp.outer(time_mod, envelope) + + return jnp.clip(spec, 0, None) + +key = jr.PRNGKey(42) +spectrogram = generate_spectrogram(key) +n_time, n_freq = spectrogram.shape + +# Patch extraction parameters +patch_h = 16 # time +patch_w = 16 # frequency +stride_h = 16 +stride_w = 16 +embed_dim = 192 # ViT-Small dimension + +n_patches_h = n_time // stride_h +n_patches_w = n_freq // stride_w +n_patches = n_patches_h * n_patches_w + +print(f"Spectrogram: {n_time} x {n_freq}") +print(f"Patch size: {patch_h} x {patch_w}") +print(f"Number of patches: {n_patches_h} x {n_patches_w} = {n_patches}") + +# Extract patches +def extract_patches(spec, patch_h, patch_w, stride_h, stride_w): + """Extract non-overlapping patches from spectrogram.""" + patches = [] + positions = [] + for i in range(0, spec.shape[0] - patch_h + 1, stride_h): + for j in range(0, spec.shape[1] - patch_w + 1, stride_w): + patch = spec[i:i+patch_h, j:j+patch_w] + patches.append(patch.flatten()) + positions.append((i, j)) + return jnp.stack(patches), positions + +patches, positions = extract_patches(spectrogram, patch_h, patch_w, stride_h, stride_w) +print(f"Patches shape: {patches.shape}") # (n_patches, patch_h * patch_w) + +# Linear projection (patch embedding) +patch_dim = patch_h * patch_w +k1, k2 = jr.split(jr.PRNGKey(0)) +W_embed = jr.normal(k1, (patch_dim, embed_dim)) * jnp.sqrt(2.0 / patch_dim) +b_embed = jnp.zeros(embed_dim) + +# Learnable positional embeddings +pos_embed = jr.normal(k2, (n_patches + 1, embed_dim)) * 0.02 # +1 for CLS + +# CLS token +cls_token = jnp.zeros((1, embed_dim)) + +# Forward pass +patch_tokens = patches @ W_embed + b_embed # (n_patches, embed_dim) +tokens = jnp.concatenate([cls_token, patch_tokens], axis=0) # (n_patches+1, embed_dim) +tokens = tokens + pos_embed # Add positional embeddings + +print(f"Token sequence shape: {tokens.shape}") +print(f"Each token has dimension: {embed_dim}") + +# Visualisation +fig, axes = plt.subplots(2, 2, figsize=(14, 10)) + +# Original spectrogram with patch grid +axes[0, 0].imshow(spectrogram.T, aspect='auto', origin='lower', cmap='magma') +for i in range(0, n_time + 1, stride_h): + axes[0, 0].axvline(i - 0.5, color='white', linewidth=0.5, alpha=0.5) +for j in range(0, n_freq + 1, stride_w): + axes[0, 0].axhline(j - 0.5, color='white', linewidth=0.5, alpha=0.5) +axes[0, 0].set_title(f'Spectrogram with {patch_h}x{patch_w} Patch Grid') +axes[0, 0].set_xlabel('Time frame') +axes[0, 0].set_ylabel('Frequency bin') + +# Individual patches visualised +n_show = min(16, n_patches) +patch_grid = patches[:n_show].reshape(n_show, patch_h, patch_w) +combined = jnp.concatenate([patch_grid[i] for i in range(min(8, n_show))], axis=1) +axes[0, 1].imshow(combined.T, aspect='auto', origin='lower', cmap='magma') +axes[0, 1].set_title(f'First {min(8, n_show)} Patches (concatenated)') +axes[0, 1].set_xlabel('Patch index (horizontal)') +axes[0, 1].set_ylabel('Frequency within patch') + +# Token embeddings similarity matrix +token_norms = tokens / jnp.linalg.norm(tokens, axis=-1, keepdims=True) +sim = token_norms @ token_norms.T +im = axes[1, 0].imshow(sim, cmap='RdBu_r', vmin=-1, vmax=1) +axes[1, 0].set_title('Token Similarity Matrix (cosine)') +axes[1, 0].set_xlabel('Token index') +axes[1, 0].set_ylabel('Token index') +plt.colorbar(im, ax=axes[1, 0], fraction=0.046) + +# Positional embedding similarity +pos_norms = pos_embed / jnp.linalg.norm(pos_embed, axis=-1, keepdims=True) +pos_sim = pos_norms @ pos_norms.T +im2 = axes[1, 1].imshow(pos_sim, cmap='RdBu_r', vmin=-1, vmax=1) +axes[1, 1].set_title('Positional Embedding Similarity') +axes[1, 1].set_xlabel('Position index') +axes[1, 1].set_ylabel('Position index') +plt.colorbar(im2, ax=axes[1, 1], fraction=0.046) + +plt.tight_layout() +plt.show() +``` + +- **任务 4:用于和弦分析的简单色度图计算。** 从合成和声信号计算并可视化色度图,展示音乐信息检索中使用的音高类别折叠方法。 + +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# Generate a synthetic musical signal: C major chord -> G major chord +sr = 16000 +duration = 2.0 +t = jnp.linspace(0, duration, int(sr * duration)) + +# C major (C4=261.6, E4=329.6, G4=392.0) for first half +# G major (G3=196.0, B3=246.9, D4=293.7) for second half +half = len(t) // 2 + +c_major = (0.5 * jnp.sin(2 * jnp.pi * 261.63 * t[:half]) + + 0.4 * jnp.sin(2 * jnp.pi * 329.63 * t[:half]) + + 0.3 * jnp.sin(2 * jnp.pi * 392.00 * t[:half])) + +g_major = (0.5 * jnp.sin(2 * jnp.pi * 196.00 * t[:half]) + + 0.4 * jnp.sin(2 * jnp.pi * 246.94 * t[:half]) + + 0.3 * jnp.sin(2 * jnp.pi * 293.66 * t[:half])) + +signal = jnp.concatenate([c_major, g_major]) + +# Compute STFT +n_fft = 4096 # high resolution for pitch accuracy +hop_length = 512 +window = jnp.hanning(n_fft) + +def stft(signal, n_fft, hop_length, window): + n_frames = 1 + (len(signal) - n_fft) // hop_length + frames = jnp.stack([ + signal[i * hop_length : i * hop_length + n_fft] * window + for i in range(n_frames) + ]) + return jnp.fft.rfft(frames, n=n_fft) + +S = stft(signal, n_fft, hop_length, window) +power_spec = jnp.abs(S) ** 2 +freqs = jnp.fft.rfftfreq(n_fft, 1.0 / sr) + +# Compute chromagram by mapping frequency bins to pitch classes +# MIDI note number from frequency: 69 + 12 * log2(f / 440) +note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'] + +def freq_to_chroma(freq): + """Map frequency to pitch class (0-11). Returns -1 for freq <= 0.""" + midi = 69 + 12 * jnp.log2(jnp.clip(freq, 1e-10, None) / 440.0) + return jnp.round(midi).astype(int) % 12 + +# Build chromagram: sum power spectrum energy for each pitch class +chromagram = jnp.zeros((power_spec.shape[0], 12)) +valid_freqs = freqs[1:] # skip DC +valid_power = power_spec[:, 1:] + +for p in range(12): + # Find frequency bins belonging to this pitch class + chroma_bins = freq_to_chroma(valid_freqs) + mask = (chroma_bins == p).astype(jnp.float32) + chromagram = chromagram.at[:, p].set( + jnp.sum(valid_power * mask[None, :], axis=1) + ) + +# Normalise each frame +chromagram = chromagram / (jnp.max(chromagram, axis=1, keepdims=True) + 1e-8) + +# Visualisation +fig, axes = plt.subplots(3, 1, figsize=(14, 10)) + +# Waveform +axes[0].plot(t[:3000], signal[:3000], color='#3498db', linewidth=0.5, + label='C major') +axes[0].plot(t[half:half+3000], signal[half:half+3000], color='#e74c3c', + linewidth=0.5, label='G major') +axes[0].set_title('Waveform: C major → G major') +axes[0].set_ylabel('Amplitude') +axes[0].set_xlabel('Time (s)') +axes[0].legend() + +# Spectrogram (log scale) +time_axis = jnp.arange(power_spec.shape[0]) * hop_length / sr +axes[1].imshow(jnp.log1p(power_spec[:, :500].T), aspect='auto', origin='lower', + cmap='magma', extent=[0, time_axis[-1], 0, freqs[500]]) +axes[1].set_title('Power Spectrogram') +axes[1].set_ylabel('Frequency (Hz)') +axes[1].set_xlabel('Time (s)') + +# Chromagram +im = axes[2].imshow(chromagram.T, aspect='auto', origin='lower', cmap='YlOrRd', + extent=[0, time_axis[-1], -0.5, 11.5]) +axes[2].set_yticks(range(12)) +axes[2].set_yticklabels(note_names) +axes[2].set_title('Chromagram (pitch class energy over time)') +axes[2].set_ylabel('Pitch class') +axes[2].set_xlabel('Time (s)') +plt.colorbar(im, ax=axes[2], fraction=0.046, label='Normalised energy') + +# Mark expected active pitch classes +mid_frame = chromagram.shape[0] // 2 +print(f"C major region - expected: C, E, G") +print(f" Chroma values: {dict(zip(note_names, [f'{v:.2f}' for v in chromagram[mid_frame//2]]))}") +print(f"G major region - expected: G, B, D") +print(f" Chroma values: {dict(zip(note_names, [f'{v:.2f}' for v in chromagram[mid_frame + mid_frame//2]]))}") + +plt.tight_layout() +plt.show() +``` diff --git a/chapter 09: audio and speech/05. source separation and noise.md b/chapter 09: audio and speech/05. source separation and noise.md new file mode 100644 index 0000000..6aa7198 --- /dev/null +++ b/chapter 09: audio and speech/05. source separation and noise.md @@ -0,0 +1,804 @@ +# 源分离与降噪 + +*源分离与降噪从混合音频中恢复单个信号;即计算层面的"鸡尾酒会问题"。本文涵盖ICA、NMF、时频掩蔽、波束成形、深度学习分离网络(Conv-TasNet、SepFormer)、语音增强以及自适应降噪。* + +- 想象一下你站在一个拥挤的鸡尾酒会上。数十人同时在交谈,音乐在播放,酒杯在碰撞,但你却能专注于一段对话并清晰地跟上它。这种非凡的能力被称为**鸡尾酒会问题**(Cherry, 1953),人类听觉系统可以毫不费力地做到,但机器却觉得异常困难。本文涵盖了尝试解决这一问题的算法:分离混合音频源、消除不必要的噪声以及在不利条件下增强语音。 + +- 文件01中的信号处理基础(STFT、语谱图、滤波器组)支撑了这里的每一种方法。第02章中的矩阵分解技术(NMF、ICA、SVD)提供了经典工具集。第06章中的深度学习架构(CNN、RNN、注意力机制)以及第04/05章中的概率论则为现代方法提供了理论基础。 + +![鸡尾酒会问题:多个说话人和声源混合到麦克风阵列中,分离系统必须从混合物中恢复单个源信号](../images/cocktail_party.svg) + +- **问题形式化**:在一个或多个麦克风处观测到混合信号 $x(t)$。在最简单的情况下,混合信号是 $C$ 个源信号的和: + +$$x(t) = \sum_{c=1}^{C} s_c(t) + n(t)$$ + +- 其中 $s_c(t)$ 是第 $c$ 个源信号,$n(t)$ 是背景噪声。目标是从 $x(t)$ 中恢复出各个 $s_c(t)$。在单麦克风情况下,这是一个严重欠定的问题:一个方程,$C$ 个未知数。需要额外的假设(统计独立性、频谱结构、学习先验)才能使问题变得可解。 + +- 在频域中(通过文件01中的STFT),混合信号变为: + +$$X(t, f) = \sum_{c=1}^{C} S_c(t, f) + N(t, f)$$ + +- 许多分离方法在时频域中通过为每个源估计一个**掩蔽** $M_c(t, f) \in [0, 1]$ 来工作,然后通过 $\hat{S}_c(t, f) = M_c(t, f) \cdot X(t, f)$ 恢复源信号。**理想二值掩蔽(IBM)** 设置 $M_c(t, f) = 1$ 如果源 $c$ 在该时频单元中占主导,否则为0。**理想比率掩蔽(IRM)** 是其软版本: + +$$\text{IRM}_c(t, f) = \frac{|S_c(t, f)|^2}{\sum_{j=1}^{C} |S_j(t, f)|^2}$$ + +- **独立成分分析(ICA)** 是麦克风数量等于或超过源数量时的经典方法。ICA(第02章)寻找一个线性解混矩阵 $W$,使得 $\hat{s} = Wx$,其中恢复的源 $\hat{s}$ 在统计上最大限度地独立。关键假设是源信号是非高斯且独立的,这对于语音和音乐通常是成立的。 + +- 对于多麦克风瞬时混叠模型 $x = As$(其中 $A$ 是混叠矩阵),ICA 通过最大化输出的非高斯性(FastICA 使用负熵)或最小化互信息来恢复 $W \approx A^{-1}$。ICA 在受控环境中表现良好,但当混叠涉及卷积(房间混响)、源数量超过麦克风数量或独立性假设被违反时则会失败。 + +- **非负矩阵分解(NMF)** 将幅度语谱图 $V \in \mathbb{R}_+^{F \times T}$ 分解为两个非负矩阵的乘积(第02章): + +$$V \approx WH$$ + +- 其中 $W \in \mathbb{R}_+^{F \times K}$ 是包含 $K$ 个频谱基向量的字典,$H \in \mathbb{R}_+^{K \times T}$ 包含随时间变化的激活系数。非负约束具有物理动机:幅度是非负的,且声音是加性组合的。 + +- 对于源分离,NMF 为每个源学习独立的字典:$W_{\text{语音}}$ 捕捉语音的频谱模式(共振峰结构),而 $W_{\text{噪声}}$ 捕捉噪声模式。混合信号被分解为 $V \approx W_{\text{语音}} H_{\text{语音}} + W_{\text{噪声}} H_{\text{噪声}}$,每个源通过掩蔽来恢复。NMF 使用乘法更新规则进行最小化,代价函数可以是 Frobenius 范数或 KL 散度: + +```math +\begin{aligned} +\text{Frobenius:} \quad D_F(V \| WH) &= \|V - WH\|_F^2 \\ +\text{KL:} \quad D_{KL}(V \| WH) &= \sum_{f,t} \left[ V_{ft} \log \frac{V_{ft}}{(WH)_{ft}} - V_{ft} + (WH)_{ft} \right] +\end{aligned} +``` + +- **波束成形**利用麦克风阵列的空间信息。当一个源信号以不同的延迟到达不同的麦克风(由于空间排列)时,这些延迟可以用来增强来自某个方向的信号,同时抑制其他方向的信号。 + +![波束成形:麦克风阵列从不同方向接收具有不同时间延迟的信号,波束成形器将它们组合以增强目标方向并抑制其他方向](../images/beamforming.svg) + +- **延迟求和波束成形**是最简单的方法。如果目标源相对于阵列的角度为 $\theta$,则在麦克风 $m$ 处的时间延迟为 $\tau_m(\theta) = d_m \sin \theta / c$,其中 $d_m$ 是麦克风位置,$c$ 是声速。波束成形器输出将麦克风信号对齐并求和: + +$$y(t) = \frac{1}{M} \sum_{m=1}^{M} x_m(t - \tau_m(\theta))$$ + +- 来自目标方向的信号相干相加,而来自其他方向的信号非相干相加,从而实现空间滤波。阵列的几何形状决定了空间分辨率:更大的阵列产生更窄的波束。 + +- **最小方差无失真响应(MVDR)** 波束成形优化权重,以最小化总输出功率,同时保证目标方向无失真地通过: + +```math +\begin{aligned} +\min_{\mathbf{w}} \quad & \mathbf{w}^H \Phi_{nn} \mathbf{w} \\ +\text{subject to} \quad & \mathbf{w}^H \mathbf{d}(\theta) = 1 +\end{aligned} +``` + +- 其中 $\Phi_{nn}$ 是噪声空间协方差矩阵,$\mathbf{d}(\theta)$ 是方向 $\theta$ 的导向向量。闭式解为: + +$$\mathbf{w}_{\text{MVDR}} = \frac{\Phi_{nn}^{-1} \mathbf{d}(\theta)}{\mathbf{d}(\theta)^H \Phi_{nn}^{-1} \mathbf{d}(\theta)}$$ + +- MVDR 通过使用估计的噪声协方差自适应地适应噪声环境,比延迟求和提供更好的干扰抑制能力。它广泛用于助听器、智能音箱和远程会议系统。 + +- **深度学习用于源分离**显著提升了性能,特别是在经典方法难以处理的单麦克风情况下。一般范式是:编码混合信号,通过神经网络估计掩蔽或源表示,然后解码以恢复各个源。 + +- **深度聚类**(Hershey 等,2016)将每个时频单元嵌入到一个高维空间中,使得属于同一源的单元彼此靠近,而来自不同源的单元则远离。一个双向 LSTM(第06章)将每个时频单元 $(t, f)$ 映射为一个嵌入向量 $v_{t,f} \in \mathbb{R}^D$。训练目标为: + +$$\mathcal{L} = \|VV^T - YY^T\|_F^2$$ + +- 其中 $V$ 是嵌入矩阵,$Y$ 是源分配的单热矩阵。乘积 $VV^T$ 是一个亲和矩阵(两个单元的嵌入有多相似),而 $YY^T$ 是理想的亲和度(若属于同一源则为1,否则为0)。推理时,对嵌入进行 K-means 聚类产生二值掩蔽。 + +- **Conv-TasNet**(Luo 和 Mesgarani,2019)完全在时域中操作,绕过了 STFT。它包含三个组件: + +![Conv-TasNet 架构:编码器将混合波形转换为潜在表示,时域卷积网络分离器估计源掩蔽,解码器重建各个源波形](../images/conv_tasnet.svg) + +- **编码器**:一个一维卷积将混合波形的短片段映射为潜在表示。对于混合信号 $x \in \mathbb{R}^T$,编码器输出为 $w = \text{ReLU}(U \ast x) \in \mathbb{R}^{N \times L}$,其中 $U$ 是一个可学习的基(类似于 STFT 基但从数据中学习),$N$ 是基函数的数量,$L$ 是片段数。编码器核大小和步长(通常为2ms和1ms)决定了时间分辨率。 + +- **分离器**:一个**时域卷积网络(TCN)**处理编码后的混合信号并输出 $C$ 个掩蔽。TCN 堆叠了扩张一维深度可分离卷积(来自第08章的高效卷积),这些卷积以指数增长的扩张因子 $1, 2, 4, \ldots, 2^{B-1}$ 排列成块,重复 $R$ 次。这提供了非常大的感受野,同时保持计算高效。 + +- **解码器**:一个转置一维卷积(使用可学习基 $V$)将每个掩蔽后的表示转换回时域:$\hat{s}_c = V^T (M_c \odot w)$。 + +- Conv-TasNet 显著优于基于语谱图的方法,因为学习到的编码器-解码器基可以捕捉 STFT 幅度所丢弃的信息(特别是相位)。 + +- **双路径 RNN(DPRNN)**(Luo 等,2020)解决了分离中的长序列建模问题。DPRNN 不是用单个 RNN 或 TCN 处理整个编码序列,而是将序列分割成重叠的块,并沿着两条路径应用 RNN:**块内**路径(对每个块内的局部模式建模)和**块间**路径(对跨块的全局模式建模)。这使 RNN 序列长度从 $L$ 降低到每个维度上的 $\sqrt{L}$: + +```math +\begin{aligned} +\text{块内:} \quad & h_{k,n}^{\text{块内}} = \text{BiLSTM}_{\text{块内}}(z_{k,n}) \\ +\text{块间:} \quad & h_{k,n}^{\text{块间}} = \text{BiLSTM}_{\text{块间}}(h_{k,n}^{\text{块内}}) +\end{aligned} +``` + +- 其中 $k$ 索引块,$n$ 索引块内的位置。块内 LSTM 对固定 $k$ 的各 $n$ 处理;块间 LSTM 对固定 $n$ 的各 $k$ 处理。 + +- **SepFormer**(Subakan 等,2021)用 Transformer(第07章)替换了双路径框架中的 RNN。块内 Transformer 通过自注意力捕捉局部依赖关系,块间 Transformer 捕捉全局依赖关系。多头注意力能够建模长程依赖关系而不会出现梯度消失问题(第06章),这使得 SepFormer 对于长录音特别有效。SepFormer 在 WSJ0-2mix 基准上达到了最先进的结果。 + +- **置换不变训练(PIT)** 解决了监督式源分离中的一个基本问题:标签分配歧义。如果网络有两个输出(对应两个说话人),哪个输出应该对应哪个说话人?没有自然的排序。PIT 计算所有可能分配的损失并取最小值: + +$$\mathcal{L}_{\text{PIT}} = \min_{\pi \in \mathcal{P}} \sum_{c=1}^{C} \ell(\hat{s}_{\pi(c)}, s_c)$$ + +- 其中 $\mathcal{P}$ 是 $\{1, \ldots, C\}$ 的所有排列集合,$\ell$ 是每个源的损失(通常是尺度不变信号失真比 SI-SDR)。对于 $C = 2$ 个源只有2种排列;对于 $C = 3$ 有6种。对于更大的 $C$,可以使用匈牙利算法高效计算。 + +- **尺度不变信号失真比(SI-SDR)** 是源分离的标准评估指标: + +```math +\begin{aligned} +s_{\text{target}} &= \frac{\langle \hat{s}, s \rangle}{\|s\|^2} s \\ +e_{\text{noise}} &= \hat{s} - s_{\text{target}} \\ +\text{SI-SDR} &= 10 \log_{10} \frac{\|s_{\text{target}}\|^2}{\|e_{\text{noise}}\|^2} +\end{aligned} +``` + +- 其中 $\hat{s}$ 是估计的源,$s$ 是真实值。SI-SDR 对估计的总体尺度不变,这是期望的特性,因为绝对音量不如分离质量重要。较高的 SI-SDR(以 dB 为单位)更好。最先进的系统在 WSJ0-2mix 上实现了约 20-22 dB 的 SI-SDR 改进。 + +- **音乐源分离**将音乐录音分离成声部:人声、鼓、贝斯和其他乐器。这实现了卡拉OK(去除人声)、重新混音(调整乐器电平)和转录(一次分析一种乐器)等应用。 + +- **Open-Unmix**(Stoter 等,2019)是一个参考基线,使用三层双向 LSTM 在幅度 STFT 域中为每个源预测软掩蔽。它使用专用模型独立处理每个源。Open-Unmix 虽简单但有效,在 MUSDB18 上建立了可重复的基准。 + +- **Demucs**(Defossez 等,2019;2021年更新为 Hybrid Demucs)使用直接在波形上操作的 U-Net 架构(第08章)。编码器通过步长卷积压缩混合信号,解码器通过转置卷积和跳跃连接将其扩展回来,每个源有各自的解码器头。**Hybrid Demucs** 结合了时域和频域处理:编码器具有并行的时域和 STFT 分支,其特征在解码器之前融合。这同时捕捉了精细的时间细节和频谱结构。 + +- Demucs 在 MUSDB18 上达到了最先进的分离质量,特别是人声分离方面。其 U-Net 架构让人联想到第08章中的图像分割架构,将分离问题视为一种"音频分割"形式。 + +- **主动降噪(ANC)** 通过生成一个与噪声相消干涉的反噪声信号来减少不需要的声音。想象一下降噪耳机:麦克风拾取环境噪声,ANC 系统生成一个反相版本,混合信号(噪声 + 反噪声)理想情况下抵消为静音。 + +- 物理原理很简单:如果噪声是 $n(t)$,在空间同一点生成 $-n(t)$ 则产生静音:$n(t) + (-n(t)) = 0$。挑战在于反噪声必须在时间、幅度和相位上精确对齐。即使很小的误差也会产生残留噪声或伪影。 + +- **前馈式 ANC** 使用一个参考麦克风,在噪声到达听者之前拾取噪声。系统有时间处理噪声并生成反噪声。参考信号通过一个自适应滤波器,其输出在误差麦克风(靠近听者)处从噪声中减去。这适用于可预测的宽带噪声(引擎嗡嗡声、风扇噪声)。 + +- **反馈式 ANC** 仅使用听者耳边的误差麦克风。系统从残余信号(听者实际听到的)中估计噪声并调整反噪声。反馈式 ANC 更简单(不需要参考麦克风),但带宽有限且可能变得不稳定。 + +- **自适应滤波**是 ANC 背后的数学引擎。滤波器系数必须不断适应变化的噪声环境。最常用的算法是**最小均方(LMS)**滤波器。 + +![LMS 自适应滤波器:参考信号通过 FIR 滤波器,输出从期望信号中减去产生误差,误差反馈更新滤波器系数](../images/lms_adaptive_filter.svg) + +- **LMS 算法**:一个具有系数 $\mathbf{w} = [w_0, w_1, \ldots, w_{L-1}]^T$ 的 FIR 滤波器处理参考信号 $\mathbf{x}(n) = [x(n), x(n-1), \ldots, x(n-L+1)]^T$。输出为 $y(n) = \mathbf{w}^T \mathbf{x}(n)$,误差为 $e(n) = d(n) - y(n)$(其中 $d(n)$ 是期望/主信号),权重更新为: + +$$\mathbf{w}(n+1) = \mathbf{w}(n) + \mu \, e(n) \, \mathbf{x}(n)$$ + +- 其中 $\mu$ 是步长(学习率)。这是对均方误差 $E[e^2(n)]$ 的一个随机梯度下降步骤,使用瞬时梯度估计 $-2 e(n) \mathbf{x}(n)$ 代替真实梯度(第03章的梯度下降和第06章的 SGD)。 + +- 步长 $\mu$ 控制收敛速度与稳态误差之间的权衡。过大则滤波器振荡或发散;过小则自适应速度迟缓。稳定条件为 $0 < \mu < 2 / (\lambda_{\max})$,其中 $\lambda_{\max}$ 是输入自相关矩阵 $R = E[\mathbf{x}\mathbf{x}^T]$ 的最大特征值。 + +- **归一化 LMS(NLMS)** 通过输入功率对步长进行归一化,使收敛与信号电平无关: + +$$\mathbf{w}(n+1) = \mathbf{w}(n) + \frac{\mu}{\|\mathbf{x}(n)\|^2 + \epsilon} \, e(n) \, \mathbf{x}(n)$$ + +- 其中 $\epsilon$ 是一个小的正则化常数,以防止除零。NLMS 比 LMS 更可靠地收敛,因为有效步长自适应地适应输入功率。 + +- **递归最小二乘(RLS)** 是一种收敛更快的替代方法,它最小化加权最小二乘代价 $\sum_{k=1}^{n} \lambda^{n-k} e^2(k)$,其中 $\lambda \in (0, 1]$ 是遗忘因子。RLS 维护逆自相关矩阵的估计并递归更新,以每个样本 $O(L^2)$ 的计算成本(相对于 LMS 的 $O(L)$)实现最优收敛。 + +- **降噪与语音增强**旨在提高嘈杂录音中的语音质量和可懂度。与源分离(分离不同的源)不同,语音增强专门针对语音加噪声的情况,从带噪观测中恢复干净的语音。 + +- **谱减法**是最简单的方法。在纯噪声帧(由文件03中的 VAD 检测)期间,估计噪声频谱 $|\hat{N}(f)|^2$。然后将其从每个帧中减去: + +$$|\hat{S}(f)|^2 = \max(|X(f)|^2 - \alpha |\hat{N}(f)|^2, \beta |X(f)|^2)$$ + +- 其中 $\alpha$ 是过减因子(通常为1-4,激进的减法去除更多噪声但引入更多伪影),$\beta$ 是频谱地板,防止出现负值并减少"音乐噪声"伪影(听起来像随机音符的孤立音调残留)。 + +- **维纳滤波**提供了干净语音频谱的最小均方误差估计: + +$$\hat{S}(t, f) = \frac{|S(t,f)|^2}{|S(t,f)|^2 + |N(t,f)|^2} \cdot X(t, f) = G(t, f) \cdot X(t, f)$$ + +- 维纳增益 $G(t, f) = \text{SNR}(t, f) / (1 + \text{SNR}(t, f))$ 的范围从0(纯噪声)到1(纯语音),作为一个软掩蔽。挑战在于估计语音和噪声的功率谱。**先验 SNR** $\xi(t, f) = |S(t,f)|^2 / |N(t,f)|^2$ 使用"决策导向"方法估计:当前帧估计与前一帧维纳滤波输出的平滑组合。 + +- **神经语音增强**使用深度学习来估计掩蔽(如维纳增益)或直接估计干净语谱图。架构从简单的前馈网络到 U-Net(第08章)、CRN(卷积递归网络)和 Transformer。 + +- **DCCRN**(深度复数卷积递归网络)在复数 STFT(幅度和相位)上操作,使用自然处理实部和虚部的复数值卷积。这避免了仅幅度方法所困扰的相位估计问题。 + +- **FullSubNet** 使用双路径架构,包含一个全频带模型(捕捉全局频谱模式)和一个子频带模型(捕捉局部谐波细节)。全频带模型处理整个频谱,而子频带模型处理以每个频率单元为中心的窄频带。它们的输出被组合用于最终的掩蔽估计。 + +- **DNS(深度噪声抑制)挑战赛**由微软每年举办,对语音增强系统进行基准测试。获胜者通常使用大规模训练,包含多种噪声类型、数据增强(以各种 SNR 添加噪声、混响、编解码器伪影)以及支持实时处理的架构。 + +- **回声消除**在双向通信中去除声学回声。当你在电话通话中时,远端说话人的声音通过你的扬声器播放,在房间内反弹,并被你的麦克风拾取,产生远端说话人听到的回声。**声学回声消除(AEC)** 对从扬声器到麦克风的声学路径进行建模并减去预测的回声。 + +- 声学路径被建模为一个自适应 FIR 滤波器(使用 LMS 或 NLMS),以远端信号为输入。滤波器对房间脉冲响应进行建模,包括直达路径、早期反射和晚期混响。房间脉冲响应可能长达数百毫秒,需要数千个抽头的滤波器。 + +- **双讲检测**对 AEC 至关重要:当近端和远端说话人同时说话时,自适应滤波器必须冻结(停止更新),以防止其抵消近端说话人的声音。双讲检测器将误差信号的能量与远端信号能量进行比较;无法用远端信号解释的误差能量突然增加表明存在近端语音。 + +- 远端信号 $x(n)$ 与麦克风信号 $d(n)$ 之间的**归一化互相关**提供了一个双讲指示符: + +$$\xi(n) = \frac{|\sum_{k=0}^{L-1} x(n-k) d(n-k)|}{\sqrt{\sum_{k} x^2(n-k)} \sqrt{\sum_{k} d^2(n-k)}}$$ + +- 在单讲期间(仅远端),$\xi$ 较高,因为 $d$ 主要是 $x$ 的回声。在双讲期间,$\xi$ 下降,因为近端语音与 $x$ 不相关。 + +- 现代 AEC 系统将自适应滤波与神经网络相结合:自适应滤波器提供初始回声估计,神经网络(类似于上述语音增强模型)清理残余回声并处理线性滤波器无法捕捉的非线性(扬声器失真)。 + +- **分离与增强的评估指标**: + - **SI-SDR**(如上定义):源分离的标准指标。 + - **SDR**(信号失真比):来自 BSS Eval,衡量包括伪影和干扰在内的整体分离质量。 + - **PESQ**(语音质量感知评估):ITU 标准,预测主观质量分数。范围:-0.5 至 4.5。 + - **STOI**(短时客观可懂度):预测语音可懂度。范围:0 至 1。 + - **DNSMOS**:微软的深度噪声抑制 MOS 预测器,一个训练用于预测人类 MOS 分数的神经网络,无需干净的参考音频。 + +## 编程任务(使用 CoLab 或 notebook) + +- **任务 1:用于源分离的独立成分分析。** 实现 FastICA 来分离两个混合音频源,演示确定情况(源与麦克风数量相等)下的经典鸡尾酒会解决方案。 + +```python +import jax +import jax.numpy as jnp +import jax.random as jr +import matplotlib.pyplot as plt + +# 生成两个源信号 +sr = 8000 +duration = 1.0 +t = jnp.linspace(0, duration, int(sr * duration)) + +# 源 1:正弦波(类似音调) +s1 = jnp.sin(2 * jnp.pi * 440 * t) + 0.3 * jnp.sin(2 * jnp.pi * 880 * t) + +# 源 2:锯齿波(丰富的谐波) +s2 = 2 * (t * 200 % 1) - 1 # 200 Hz 锯齿波 + +# 归一化源信号 +s1 = s1 / jnp.max(jnp.abs(s1)) +s2 = s2 / jnp.max(jnp.abs(s2)) +sources = jnp.stack([s1, s2]) # (2, T) + +# 混叠矩阵(算法未知) +A = jnp.array([[0.8, 0.4], + [0.3, 0.9]]) +mixtures = A @ sources # (2, T) + +# FastICA 实现 +def whiten(X): + """数据中心化与白化。""" + X_centered = X - jnp.mean(X, axis=1, keepdims=True) + cov = (X_centered @ X_centered.T) / X_centered.shape[1] + eigvals, eigvecs = jnp.linalg.eigh(cov) + D_inv_sqrt = jnp.diag(1.0 / jnp.sqrt(eigvals + 1e-8)) + whitening = D_inv_sqrt @ eigvecs.T + return whitening @ X_centered, whitening + +def fastica(X, n_components=2, max_iter=200, tol=1e-6): + """使用 tanh 非线性的 FastICA(负熵近似)。""" + X_white, whitening = whiten(X) + n, T = X_white.shape + + key = jr.PRNGKey(42) + W = jr.normal(key, (n_components, n)) + # 正交化 W + U, _, Vt = jnp.linalg.svd(W, full_matrices=False) + W = U @ Vt + + for iteration in range(max_iter): + W_old = W.copy() + + # 对每个分量 + for i in range(n_components): + w = W[i] + # w^T X_white: (T,) + wx = w @ X_white # (T,) + + # g(u) = tanh(u), g'(u) = 1 - tanh^2(u) + g_wx = jnp.tanh(wx) + g_prime_wx = 1 - g_wx ** 2 + + # Newton 更新: w_new = E[X * g(w^T X)] - E[g'(w^T X)] * w + w_new = jnp.mean(X_white * g_wx[None, :], axis=1) - \ + jnp.mean(g_prime_wx) * w + + # 与之前的分量去相关(消去法) + for j in range(i): + w_new = w_new - jnp.dot(w_new, W[j]) * W[j] + + w_new = w_new / jnp.linalg.norm(w_new) + W = W.at[i].set(w_new) + + # 检查收敛 + convergence = jnp.min(jnp.abs(jnp.diag(W @ W_old.T))) + if convergence > 1 - tol: + print(f"FastICA 在 {iteration + 1} 次迭代后收敛") + break + + # 解混矩阵 + unmixing = W @ whitening + recovered = unmixing @ X + return recovered, unmixing + +recovered, W_unmix = fastica(mixtures) + +# 修复符号歧义(ICA 可能翻转符号) +for i in range(2): + if jnp.corrcoef(recovered[i], sources[i])[0, 1] < -0.5: + recovered = recovered.at[i].set(-recovered[i]) + +# 如果源被交换,修复排列 +corr_00 = jnp.abs(jnp.corrcoef(recovered[0], sources[0])[0, 1]) +corr_01 = jnp.abs(jnp.corrcoef(recovered[0], sources[1])[0, 1]) +if corr_01 > corr_00: + recovered = recovered[::-1] + +# 归一化以便显示 +recovered = recovered / jnp.max(jnp.abs(recovered), axis=1, keepdims=True) + +fig, axes = plt.subplots(3, 2, figsize=(14, 9)) + +axes[0, 0].plot(t[:1000], s1[:1000], color='#3498db', linewidth=0.8) +axes[0, 0].set_title('源信号 1(原始)') +axes[0, 0].set_ylabel('幅度') + +axes[0, 1].plot(t[:1000], s2[:1000], color='#e74c3c', linewidth=0.8) +axes[0, 1].set_title('源信号 2(原始)') + +axes[1, 0].plot(t[:1000], mixtures[0, :1000], color='#9b59b6', linewidth=0.8) +axes[1, 0].set_title('混合信号 1(麦克风 1)') +axes[1, 0].set_ylabel('幅度') + +axes[1, 1].plot(t[:1000], mixtures[1, :1000], color='#9b59b6', linewidth=0.8) +axes[1, 1].set_title('混合信号 2(麦克风 2)') + +axes[2, 0].plot(t[:1000], recovered[0, :1000], color='#27ae60', linewidth=0.8) +axes[2, 0].set_title('恢复的源信号 1(FastICA)') +axes[2, 0].set_ylabel('幅度') +axes[2, 0].set_xlabel('时间 (s)') + +axes[2, 1].plot(t[:1000], recovered[1, :1000], color='#f39c12', linewidth=0.8) +axes[2, 1].set_title('恢复的源信号 2(FastICA)') +axes[2, 1].set_xlabel('时间 (s)') + +plt.tight_layout() +plt.show() + +# 报告与原始信号的相关性 +for i in range(2): + corr = jnp.corrcoef(recovered[i], sources[i])[0, 1] + print(f"源 {i+1} 恢复相关性: {corr:.4f}") +``` + +- **任务 2:基于 NMF 的语谱图源分离。** 使用非负矩阵分解(第02章)将语谱图分离为两个分量,演示 NMF 如何为每个源学习频谱字典。 + +```python +import jax +import jax.numpy as jnp +import jax.random as jr +import matplotlib.pyplot as plt + +# 生成两个具有不同频谱特征的信号 +sr = 8000 +duration = 1.0 +t = jnp.linspace(0, duration, int(sr * duration)) + +# 源 1:低频谐波(模拟贝斯) +src1 = (jnp.sin(2 * jnp.pi * 100 * t) + + 0.5 * jnp.sin(2 * jnp.pi * 200 * t) + + 0.3 * jnp.sin(2 * jnp.pi * 300 * t)) + +# 源 2:高频谐波(模拟长笛) +src2 = (jnp.sin(2 * jnp.pi * 800 * t) + + 0.4 * jnp.sin(2 * jnp.pi * 1600 * t)) + +# 时变幅度(源在不同时间激活) +env1 = jnp.where(t < 0.5, 1.0, 0.3) +env2 = jnp.where(t > 0.3, 1.0, 0.2) +src1 = src1 * env1 +src2 = src2 * env2 + +mixture = src1 + src2 + +# 计算幅度语谱图(STFT) +n_fft = 512 +hop = 128 +window = jnp.hanning(n_fft) + +def compute_stft(signal, n_fft, hop, window): + n_frames = 1 + (len(signal) - n_fft) // hop + frames = jnp.stack([ + signal[i * hop : i * hop + n_fft] * window + for i in range(n_frames) + ]) + return jnp.fft.rfft(frames, n=n_fft) + +S_mix = compute_stft(mixture, n_fft, hop, window) +V = jnp.abs(S_mix).T # (F, T) - 频率 x 时间 +phase = jnp.angle(S_mix).T + +F, T = V.shape +print(f"语谱图形状: {F} 个频率 bin x {T} 个时间帧") + +# NMF: V ≈ WH 使用乘法更新规则 +def nmf(V, K, n_iter=200, key=jr.PRNGKey(0)): + """使用 Frobenius 范数的非负矩阵分解。""" + k1, k2 = jr.split(key) + W = jnp.abs(jr.normal(k1, (F, K))) * 0.1 + 0.01 # (F, K) + H = jnp.abs(jr.normal(k2, (K, T))) * 0.1 + 0.01 # (K, T) + + costs = [] + for i in range(n_iter): + # H 的乘法更新 + WtV = W.T @ V + WtWH = W.T @ W @ H + 1e-8 + H = H * (WtV / WtWH) + + # W 的乘法更新 + VHt = V @ H.T + WHHt = W @ H @ H.T + 1e-8 + W = W * (VHt / WHHt) + + cost = jnp.sum((V - W @ H) ** 2) + costs.append(float(cost)) + + return W, H, costs + +# 运行 K=2 个分量的 NMF +K = 2 +W, H, costs = nmf(V, K, n_iter=300) + +# 使用软掩蔽重建每个源 +V_hat = W @ H +mask1 = (W[:, 0:1] @ H[0:1, :]) / (V_hat + 1e-8) +mask2 = (W[:, 1:2] @ H[1:2, :]) / (V_hat + 1e-8) + +V_src1 = mask1 * V +V_src2 = mask2 * V + +# 可视化 +fig, axes = plt.subplots(3, 2, figsize=(14, 10)) + +# 混合信号语谱图 +axes[0, 0].imshow(jnp.log1p(V), aspect='auto', origin='lower', cmap='magma') +axes[0, 0].set_title('混合信号语谱图 |X|') +axes[0, 0].set_ylabel('频率 bin') + +# NMF 收敛 +axes[0, 1].plot(costs, color='#3498db', linewidth=1.5) +axes[0, 1].set_title('NMF 收敛曲线') +axes[0, 1].set_xlabel('迭代次数') +axes[0, 1].set_ylabel('Frobenius 代价') +axes[0, 1].set_yscale('log') + +# 频谱基向量 W +freq_hz = jnp.arange(F) * sr / n_fft +axes[1, 0].plot(freq_hz, W[:, 0], color='#27ae60', linewidth=1.5, + label='基 1(低频)') +axes[1, 0].plot(freq_hz, W[:, 1], color='#e74c3c', linewidth=1.5, + label='基 2(高频)') +axes[1, 0].set_title('学习到的频谱基 W') +axes[1, 0].set_xlabel('频率 (Hz)') +axes[1, 0].set_ylabel('幅度') +axes[1, 0].legend() + +# 时域激活 H +time_s = jnp.arange(T) * hop / sr +axes[1, 1].plot(time_s, H[0], color='#27ae60', linewidth=1.5, + label='激活 1') +axes[1, 1].plot(time_s, H[1], color='#e74c3c', linewidth=1.5, + label='激活 2') +axes[1, 1].set_title('时域激活 H') +axes[1, 1].set_xlabel('时间 (s)') +axes[1, 1].set_ylabel('激活值') +axes[1, 1].legend() + +# 分离后的语谱图 +axes[2, 0].imshow(jnp.log1p(V_src1), aspect='auto', origin='lower', cmap='magma') +axes[2, 0].set_title('分离后的源信号 1(低频)') +axes[2, 0].set_ylabel('频率 bin') +axes[2, 0].set_xlabel('时间帧') + +axes[2, 1].imshow(jnp.log1p(V_src2), aspect='auto', origin='lower', cmap='magma') +axes[2, 1].set_title('分离后的源信号 2(高频)') +axes[2, 1].set_xlabel('时间帧') + +plt.tight_layout() +plt.show() + +print(f"重建误差: {jnp.sum((V - W @ H)**2):.2f}") +print(f"NMF 学习到的频谱基能够捕捉每个源的频率特征。") +``` + +- **任务 3:用于降噪的 LMS 自适应滤波器。** 实现 LMS 和 NLMS 算法用于回声/降噪,展示收敛行为及步长的影响。 + +```python +import jax +import jax.numpy as jnp +import jax.random as jr +import matplotlib.pyplot as plt + +# 模拟回声消除场景 +# 远端信号 -> 房间脉冲响应 -> 麦克风处的回声 +# 近端语音是我们希望保留的目标信号 + +sr = 8000 +duration = 2.0 +n_samples = int(sr * duration) +key = jr.PRNGKey(42) +keys = jr.split(key, 5) + +# 远端信号(参考):随机的类语音信号 +far_end = jr.normal(keys[0], (n_samples,)) * 0.5 + +# 房间脉冲响应(算法未知) +rir_length = 64 +rir = jnp.zeros(rir_length) +rir = rir.at[0].set(0.8) # 直达路径 +rir = rir.at[5].set(0.3) # 早期反射 +rir = rir.at[12].set(-0.2) # 反射 +rir = rir.at[25].set(0.1) # 晚期反射 +rir = rir.at[40].set(-0.05) + +# 回声:远端信号与 RIR 的卷积 +echo = jnp.convolve(far_end, rir)[:n_samples] + +# 近端语音(在信号的一部分中活跃) +near_end = jnp.zeros(n_samples) +start, end = n_samples // 3, 2 * n_samples // 3 +near_speech = 0.3 * jnp.sin( + 2 * jnp.pi * 300 * jnp.linspace(0, (end - start) / sr, end - start) +) +near_end = near_end.at[start:end].set(near_speech) + +# 麦克风信号:回声 + 近端 + 噪声 +noise = jr.normal(keys[1], (n_samples,)) * 0.01 +mic_signal = echo + near_end + noise + +# LMS 自适应滤波器 +def lms_filter(reference, desired, filter_length, mu): + """标准 LMS 自适应滤波器。""" + n = len(reference) + w = jnp.zeros(filter_length) + output = jnp.zeros(n) + error = jnp.zeros(n) + w_history = [] + + for i in range(filter_length, n): + x = reference[max(0, i-filter_length+1):i+1][::-1] + + y = jnp.dot(w, x) + e = desired[i] - y + w = w + mu * e * x + + output = output.at[i].set(y) + error = error.at[i].set(e) + + if i % 500 == 0: + w_history.append(w.copy()) + + return output, error, w_history + +# NLMS 自适应滤波器 +def nlms_filter(reference, desired, filter_length, mu, eps=1e-6): + """归一化 LMS 自适应滤波器。""" + n = len(reference) + w = jnp.zeros(filter_length) + output = jnp.zeros(n) + error = jnp.zeros(n) + + for i in range(filter_length, n): + x = reference[max(0, i-filter_length+1):i+1][::-1] + + y = jnp.dot(w, x) + e = desired[i] - y + norm_factor = jnp.dot(x, x) + eps + w = w + (mu / norm_factor) * e * x + + output = output.at[i].set(y) + error = error.at[i].set(e) + + return output, error + +# 使用不同步长运行 LMS +filter_len = 64 +mu_values = [0.001, 0.01, 0.05] +colors_mu = ['#3498db', '#e74c3c', '#27ae60'] + +fig, axes = plt.subplots(2, 2, figsize=(14, 10)) + +# 原始信号 +t = jnp.arange(n_samples) / sr +axes[0, 0].plot(t, mic_signal, color='#9b59b6', linewidth=0.5, alpha=0.7, + label='麦克风(回声 + 近端)') +axes[0, 0].plot(t, echo, color='#e74c3c', linewidth=0.5, alpha=0.7, + label='回声(待消除)') +axes[0, 0].plot(t, near_end, color='#27ae60', linewidth=0.8, + label='近端语音(需保留)') +axes[0, 0].set_title('信号分量') +axes[0, 0].set_xlabel('时间 (s)') +axes[0, 0].set_ylabel('幅度') +axes[0, 0].legend(fontsize=8) + +# 不同步长下的 LMS 收敛 +for mu, color in zip(mu_values, colors_mu): + _, err, _ = lms_filter(far_end, mic_signal, filter_len, mu) + # 平滑后的平方误差 + sq_err = err ** 2 + window_size = 200 + smoothed = jnp.convolve(sq_err, jnp.ones(window_size)/window_size, + mode='valid') + axes[0, 1].plot(smoothed, color=color, linewidth=1.2, + label=f'mu={mu}') + +axes[0, 1].set_title('LMS 收敛曲线(平滑 MSE)') +axes[0, 1].set_xlabel('样本') +axes[0, 1].set_ylabel('平方误差') +axes[0, 1].set_yscale('log') +axes[0, 1].legend() + +# 最佳 LMS 结果 +_, err_lms, w_hist = lms_filter(far_end, mic_signal, filter_len, 0.01) +axes[1, 0].plot(t, mic_signal, color='#9b59b6', linewidth=0.5, alpha=0.4, + label='消除前') +axes[1, 0].plot(t, err_lms, color='#3498db', linewidth=0.5, alpha=0.8, + label='LMS 消除后') +axes[1, 0].plot(t, near_end, color='#27ae60', linewidth=0.8, alpha=0.5, + label='真实近端') +axes[1, 0].set_title('LMS 回声消除结果 (mu=0.01)') +axes[1, 0].set_xlabel('时间 (s)') +axes[1, 0].set_ylabel('幅度') +axes[1, 0].legend(fontsize=8) + +# NLMS 结果 +_, err_nlms = nlms_filter(far_end, mic_signal, filter_len, 0.5) +axes[1, 1].plot(t, mic_signal, color='#9b59b6', linewidth=0.5, alpha=0.4, + label='消除前') +axes[1, 1].plot(t, err_nlms, color='#f39c12', linewidth=0.5, alpha=0.8, + label='NLMS 消除后') +axes[1, 1].plot(t, near_end, color='#27ae60', linewidth=0.8, alpha=0.5, + label='真实近端') +axes[1, 1].set_title('NLMS 回声消除结果 (mu=0.5)') +axes[1, 1].set_xlabel('时间 (s)') +axes[1, 1].set_ylabel('幅度') +axes[1, 1].legend(fontsize=8) + +plt.tight_layout() +plt.show() + +# 测量回声衰减 +echo_power = jnp.mean(echo ** 2) +lms_residual = jnp.mean(err_lms[n_samples//2:] ** 2) # 收敛后 +nlms_residual = jnp.mean(err_nlms[n_samples//2:] ** 2) +print(f"回声功率: {10*jnp.log10(echo_power):.1f} dB") +print(f"LMS 残差: {10*jnp.log10(lms_residual):.1f} dB " + f"(ERLE: {10*jnp.log10(echo_power/lms_residual):.1f} dB)") +print(f"NLMS 残差: {10*jnp.log10(nlms_residual):.1f} dB " + f"(ERLE: {10*jnp.log10(echo_power/nlms_residual):.1f} dB)") +``` + +- **任务 4:用于语音增强的时频掩蔽。** 实现一个简单的频谱掩蔽方法(理想比率掩蔽),并将其与谱减法进行比较,在合成的带噪语音信号上可视化分离质量。 + +```python +import jax +import jax.numpy as jnp +import jax.random as jr +import matplotlib.pyplot as plt + +# 创建合成的"语音"和"噪声"信号 +sr = 8000 +duration = 2.0 +t = jnp.linspace(0, duration, int(sr * duration)) + +# 语音:具有时变幅度的谐波序列(模拟语音) +speech = jnp.zeros_like(t) +for f0 in [150, 300, 450, 600, 900]: + amp_env = 0.5 + 0.5 * jnp.sin(2 * jnp.pi * 2.0 * t) # 2 Hz 调制 + speech = speech + (0.5 / (f0/150)) * amp_env * jnp.sin(2 * jnp.pi * f0 * t) +speech = speech / jnp.max(jnp.abs(speech)) + +# 噪声:限带噪声 +key = jr.PRNGKey(42) +noise_raw = jr.normal(key, t.shape) * 0.4 + +# 在给定 SNR 下混合 +snr_db = 5.0 +speech_power = jnp.mean(speech ** 2) +noise_power = jnp.mean(noise_raw ** 2) +noise_scale = jnp.sqrt(speech_power / (noise_power * 10 ** (snr_db / 10))) +noise = noise_raw * noise_scale +mixture = speech + noise + +# STFT +n_fft = 512 +hop = 128 +window = jnp.hanning(n_fft) + +def stft(signal, n_fft, hop, window): + n_frames = 1 + (len(signal) - n_fft) // hop + frames = jnp.stack([ + signal[i * hop : i * hop + n_fft] * window + for i in range(n_frames) + ]) + return jnp.fft.rfft(frames, n=n_fft) + +def istft(S, hop, window, length): + n_fft = (S.shape[1] - 1) * 2 + n_frames = S.shape[0] + frames = jnp.fft.irfft(S, n=n_fft) * window[None, :] + output = jnp.zeros(length) + window_sum = jnp.zeros(length) + for i in range(n_frames): + start = i * hop + end = start + n_fft + if end <= length: + output = output.at[start:end].add(frames[i]) + window_sum = window_sum.at[start:end].add(window ** 2) + window_sum = jnp.maximum(window_sum, 1e-8) + return output / window_sum + +S_speech = stft(speech, n_fft, hop, window) +S_noise = stft(noise, n_fft, hop, window) +S_mix = stft(mixture, n_fft, hop, window) + +mag_speech = jnp.abs(S_speech) +mag_noise = jnp.abs(S_noise) +mag_mix = jnp.abs(S_mix) +phase_mix = jnp.angle(S_mix) + +# 方法 1:理想比率掩蔽(oracle - 理论上限) +irm = mag_speech ** 2 / (mag_speech ** 2 + mag_noise ** 2 + 1e-8) +S_irm = (irm * mag_mix) * jnp.exp(1j * phase_mix) +enhanced_irm = istft(S_irm, hop, window, len(mixture)) + +# 方法 2:谱减法 +# 从前 0.2s 估计噪声(假设为静音段) +noise_frames = int(0.2 * sr / hop) +noise_est = jnp.mean(mag_mix[:noise_frames] ** 2, axis=0, keepdims=True) +alpha = 2.0 # 过减因子 +beta = 0.02 # 频谱地板 +mag_sub = jnp.maximum(mag_mix ** 2 - alpha * noise_est, beta * mag_mix ** 2) +mag_sub = jnp.sqrt(mag_sub) +S_sub = mag_sub * jnp.exp(1j * phase_mix) +enhanced_sub = istft(S_sub, hop, window, len(mixture)) + +# 方法 3:维纳滤波器 +snr_est = mag_mix ** 2 / (noise_est + 1e-8) +wiener_gain = snr_est / (1 + snr_est) +S_wiener = (wiener_gain * mag_mix) * jnp.exp(1j * phase_mix) +enhanced_wiener = istft(S_wiener, hop, window, len(mixture)) + +# 计算每种方法的 SI-SDR +def si_sdr(estimate, reference): + """尺度不变信号失真比。""" + ref = reference[:len(estimate)] + est = estimate[:len(reference)] + s_target = (jnp.dot(est, ref) / (jnp.dot(ref, ref) + 1e-8)) * ref + e_noise = est - s_target + return 10 * jnp.log10(jnp.dot(s_target, s_target) / + (jnp.dot(e_noise, e_noise) + 1e-8)) + +si_sdr_mix = si_sdr(mixture, speech) +si_sdr_irm_val = si_sdr(enhanced_irm, speech) +si_sdr_sub_val = si_sdr(enhanced_sub, speech) +si_sdr_wiener_val = si_sdr(enhanced_wiener, speech) + +# 可视化 +fig, axes = plt.subplots(3, 2, figsize=(14, 12)) + +# 语谱图 +axes[0, 0].imshow(jnp.log1p(mag_speech.T), aspect='auto', origin='lower', + cmap='magma') +axes[0, 0].set_title('干净语音语谱图') +axes[0, 0].set_ylabel('频率 bin') + +axes[0, 1].imshow(jnp.log1p(mag_mix.T), aspect='auto', origin='lower', + cmap='magma') +axes[0, 1].set_title(f'带噪混合 ({snr_db:.0f} dB SNR)') + +# 掩蔽 +axes[1, 0].imshow(irm.T, aspect='auto', origin='lower', cmap='RdYlGn') +axes[1, 0].set_title('理想比率掩蔽(Oracle)') +axes[1, 0].set_ylabel('频率 bin') + +axes[1, 1].imshow(wiener_gain.T, aspect='auto', origin='lower', cmap='RdYlGn', + vmin=0, vmax=1) +axes[1, 1].set_title('估计的维纳增益') + +# 增强后的波形对比 +n_show = 3000 +axes[2, 0].plot(t[:n_show], speech[:n_show], color='#27ae60', linewidth=0.8, + alpha=0.5, label='干净') +axes[2, 0].plot(t[:n_show], mixture[:n_show], color='#e74c3c', linewidth=0.5, + alpha=0.4, label='带噪') +axes[2, 0].plot(t[:n_show], enhanced_irm[:n_show], color='#3498db', + linewidth=0.8, label='IRM 增强') +axes[2, 0].set_title('波形对比(IRM)') +axes[2, 0].set_xlabel('时间 (s)') +axes[2, 0].set_ylabel('幅度') +axes[2, 0].legend(fontsize=8) + +# SI-SDR 柱状图 +methods = ['混合信号', '谱减法', '维纳滤波器', '理想比率掩蔽'] +sdr_values = [float(si_sdr_mix), float(si_sdr_sub_val), + float(si_sdr_wiener_val), float(si_sdr_irm_val)] +bar_colors = ['#e74c3c', '#f39c12', '#9b59b6', '#27ae60'] +bars = axes[2, 1].bar(methods, sdr_values, color=bar_colors, alpha=0.8) +axes[2, 1].set_ylabel('SI-SDR (dB)') +axes[2, 1].set_title('增强质量对比') +for bar, val in zip(bars, sdr_values): + axes[2, 1].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.3, + f'{val:.1f}', ha='center', fontsize=10) +axes[2, 1].axhline(0, color='gray', linestyle='--', linewidth=0.8) + +plt.tight_layout() +plt.show() + +print(f"SI-SDR(带噪混合): {si_sdr_mix:.2f} dB") +print(f"SI-SDR(谱减法): {si_sdr_sub_val:.2f} dB") +print(f"SI-SDR(维纳滤波器): {si_sdr_wiener_val:.2f} dB") +print(f"SI-SDR(理想比率掩蔽): {si_sdr_irm_val:.2f} dB(oracle 理论上限)") +``` diff --git a/chapter 10: multimodal learning/01. multimodal representations.md b/chapter 10: multimodal learning/01. multimodal representations.md new file mode 100644 index 0000000..c49ae40 --- /dev/null +++ b/chapter 10: multimodal learning/01. multimodal representations.md @@ -0,0 +1,366 @@ +# 多模态表征 + +*多模态表征将视觉、语言和音频桥接到共享嵌入空间中。本文件涵盖融合策略、CLIP、ALIGN、SigLIP、对比损失函数(InfoNCE、NT-Xent)、零样本分类和检索评估。* + +- 想象你坐在一家咖啡馆里。你看到桌上冒热气的水杯,听到陶瓷的叮当声,闻到烘焙咖啡豆的香气,感受到从马克杯传来的暖意。没有哪一种感官能告诉你一切:你的大脑将这些信号融合成一个统一的感知——"热咖啡"。**多模态学习** 对机器做了同样的事:它结合来自多种模态(视觉、语言、音频等)的信息,构建出比任何单一模态单独提供的表征更丰富、更鲁棒的表征。 + +- **模态(modality)** 是一种独特的信息通道。在机器学习中,最常见的模态包括图像(像素网格)、文本(词元序列)、音频(波形或语谱图,如第9章所述)、视频(帧序列)和结构化数据(表格、图)。每种模态都有其自身的统计结构:图像具有空间连贯性,文本是序列化和离散的,音频是时间性的和连续的。多模态学习的挑战在于桥接这些根本不同的数据类型。 + +- 为什么要费心结合多种模态?因为它们提供互补的信息。一张狗的照片告诉你它的品种和颜色,但不会告诉你名字。像"我的金毛犬 Max"这样的描述告诉你名字和品种,但不会告诉你确切姿态。图像和文本结合起来,比任何单独一个给出的画面都更完整。这种互补性是其核心动机:多模态模型可以回答那些单模态模型无法回答的问题、生成内容并做出决策。 + +![多模态学习概览:独立的编码器处理图像、文本和音频输入,它们的表征在共享嵌入空间中汇合](../images/multimodal_overview.svg) + +## 融合策略 + +- 想象一个小组项目。你有两种组合想法的方式:每个人从一开始就在同一个房间里一起工作(共享原始笔记和草稿),或者每个人独立撰写自己的部分,最后合并最终文档。这分别对应于多模态学习中的**早期融合(early fusion)** 和**晚期融合(late fusion)**。 + +- **早期融合**(也称为特征级融合)在任何高级处理之前,对来自不同模态的原始或低级特征进行拼接或混合。例如,你可以将图像的像素特征与文本的词元嵌入拼接起来,将组合后的序列输入到一个单一的 Transformer 中。模型可以从一开始就学习细粒度的跨模态交互,但输入空间很大,且模型必须学会同时处理截然不同的数据类型。 + +- 形式化地,给定来自两种模态的特征向量 $x_{\\text{img}} \\in \\mathbb{R}^{d_1}$ 和 $x_{\\text{txt}} \\in \\mathbb{R}^{d_2}$,早期融合简单地拼接它们: + +$$x_{\\text{fused}} = [x_{\\text{img}}; x_{\\text{txt}}] \\in \\mathbb{R}^{d_1 + d_2}$$ + +- 这个拼接后的向量由共享网络处理。其优势在于模型可以在每一层发现跨模态相关性。缺点是计算成本高,且难以对齐非常不同的特征类型(密集的像素值与稀疏的词元索引)。 + +- **晚期融合**(也称为决策级融合)通过各自的编码器独立处理每种模态,为每种模态生成一个高层表征甚至最终的预测结果。这些输出随后被组合,通常通过平均分数、投票或一个可学习的组合层。晚期融合更简单,且允许你直接复用预训练的单模态模型,但它无法捕捉低层的跨模态交互,因为各模态从未"看到"彼此的原始特征。 + +- 给定模态特定的预测值 $\hat{y}_1$ 和 $\hat{y}_2$,一个简单的晚期融合规则是: + +$$\hat{y} = \\alpha \\hat{y}_1 + (1 - \\alpha) \\hat{y}_2$$ + +- 其中 $\\alpha \\in [0, 1]$ 是一个可学习或手动调节的混合权重。 + +- **中间融合(middle fusion)**(也称为中间融合 intermediate fusion)是大多数现代系统使用的实用折中方案。每种模态先由其自身的编码器处理(提取模态特定的特征),然后在网络中间部分通过跨注意力层等方式组合编码后的表征。这使得每个编码器可以专注于自身的模态,同时仍能实现丰富的跨模态交互。Flamingo、LLaVA 和大多数视觉-语言模型(文件 02)都使用中间融合。 + +![早期、中间和晚期融合策略:早期融合拼接原始输入,中间融合通过跨注意力合并中间表征,晚期融合组合最终预测](../images/fusion_strategies.svg) + +- 融合策略的选择取决于数据可用性、计算预算和任务。早期融合功能强大但数据需求高。晚期融合廉价但受限。带有跨注意力的中间融合已成为大规模多模态模型的主流做法,因为它在表达能力与模块化之间取得了平衡。 + +## 联合嵌入空间 + +- 想象一个通用翻译器,它可以将任何语言的任何句子映射到同一个共享"意义空间"中的同一点。用英语、法语或日语说的"a dog on a beach"都会落在同一个坐标上。**联合嵌入空间** 跨模态做了完全相同的事:一张沙滩上的狗的图像和文本"a dog on a beach"应该映射到同一向量空间中的邻近点。 + +- 形式化地,我们学习两个编码器函数:模态 1(如图像)的 $f_\\theta : \\mathcal{X}_1 \\to \\mathbb{R}^d$ 和模态 2(如文本)的 $g_\\phi : \\mathcal{X}_2 \\to \\mathbb{R}^d$。两者都将输入映射到相同的 $d$ 维空间。训练目标确保语义匹配的对 $(x_1, x_2)$ 的嵌入 $f_\\theta(x_1)$ 和 $g_\\phi(x_2)$ 彼此接近(高余弦相似度),而不匹配的对则相距很远。 + +- 这是第 7 章中词嵌入空间的直接推广。回忆一下,Word2Vec 和 GloVe 将语义相似的词放置在向量空间中彼此靠近。联合嵌入空间将这一思想扩展到跨模态:不是衡量词与词的相似性,而是衡量图像到文本的相似性、音频到文本的相似性,甚至图像到音频的相似性。 + +- 相似度度量几乎总是**余弦相似度**(第 1 章): + +$$\\text{sim}(u, v) = \\frac{u \\cdot v}{\\|u\\| \\|v\\|}$$ + +- 通过将所有嵌入 $L_2$ 归一化到单位超球面上,余弦相似度简化为简单的点积 $u \\cdot v$,计算效率极高,并且可以使用近似最近邻库进行加速。 + +![联合嵌入空间:图像编码器和文本编码器将各自的输入映射到共享向量空间中,匹配的对在该空间中聚集在一起](../images/joint_embedding_space.svg) + +- 联合嵌入空间的强大之处在于它实现了**零样本迁移**。一旦你对齐了图像和文本嵌入,你就可以将从未训练过的类别图像分类:只需将类别名称作为文本嵌入,然后找出与图像嵌入最接近的文本嵌入即可。无需特定任务的微调。这是 CLIP 及其后继模型背后的关键洞察。 + +## 用于多模态对齐的对比学习 + +- 想象一个课堂练习:学生们拿到打乱的照片和描述对,需要将每张照片与其正确的描述配对。要出色地完成这项任务,你需要同时理解视觉内容与语言,并知道它们如何关联。**对比学习** 正是以这种方式训练模型:给定一批 (图像, 文本) 对,模型必须找出哪张图像对应哪段文本。 + +- 正如我们在第 8 章(文件 04)中看到的,单模态环境下的对比学习(SimCLR、MoCo)将同一图像的不同增广视图拉近,将不同图像的视图推远。多模态对比学习将"增广视图"替换为"匹配的模态":图像及其描述构成正样本对;该图像与批次中任何其他描述的配对构成负样本对。 + +### CLIP + +- **CLIP**(Contrastive Language-Image Pre-training,对比语言-图像预训练,Radford 等,2021)是多模态对比学习的基础模型。它在从互联网上抓取的 4 亿个 (图像, 文本) 对上联合训练一个图像编码器(ViT 或 ResNet,第 8 章)和一个文本编码器(Transformer,第 7 章)。 + +- 给定一批 $N$ 个图像-文本对,CLIP 计算所有图像嵌入与所有文本嵌入之间的 $N \\times N$ 余弦相似度矩阵。对角线上的条目是匹配的对(正样本);所有非对角线条目是不匹配的(负样本)。训练损失促使对角线条目升高,非对角线条目降低。 + +- 该损失是对称交叉熵。对于图像 $i$ 与文本 $j = i$ 的配对,图像到文本的损失为: + +$$\\mathcal{L}_{i \\to t} = -\\frac{1}{N} \\sum_{i=1}^{N} \\log \\frac{\\exp(\\text{sim}(z_i^{\\text{img}}, z_i^{\\text{txt}}) / \\tau)}{\\sum_{k=1}^{N} \\exp(\\text{sim}(z_i^{\\text{img}}, z_k^{\\text{txt}}) / \\tau)}$$ + +- 文本到图像的损失与之相同,只是交换了角色: + +$$\\mathcal{L}_{t \\to i} = -\\frac{1}{N} \\sum_{i=1}^{N} \\log \\frac{\\exp(\\text{sim}(z_i^{\\text{txt}}, z_i^{\\text{img}}) / \\tau)}{\\sum_{k=1}^{N} \\exp(\\text{sim}(z_i^{\\text{txt}}, z_k^{\\text{img}}) / \\tau)}$$ + +- 总的 CLIP 损失是平均值: + +$$\\mathcal{L}_{\\text{CLIP}} = \\frac{1}{2}(\\mathcal{L}_{i \\to t} + \\mathcal{L}_{t \\to i})$$ + +- 这里 $\\tau$ 是一个可学习的**温度**参数(初始化为 $\\tau = 0.07$)。温度控制 softmax 分布的尖锐程度:较低的 $\\tau$ 使模型更专注于最接近的匹配,较高的 $\\tau$ 则更均匀地分布概率。CLIP 将 $\\tau$ 与模型权重一起联合学习,而不是将其视为固定的超参数。 + +![CLIP 训练:一批 N 个图像-文本对生成 NxN 相似度矩阵,训练最大化对角线条目并最小化非对角线条目](../images/clip_contrastive_matrix.svg) + +- CLIP 的图像编码器通常是 ViT-L/14(大型 Vision Transformer,14x14 块,第 8 章文件 04)。文本编码器是一个 12 层带有因果掩码的 Transformer(类似 GPT,第 7 章文件 04)。两个编码器都通过一个可学习的线性投影将其输出映射到共享的 512 或 768 维空间,随后进行 $L_2$ 归一化。 + +- CLIP 最引人注目的特性是**零样本图像分类**。要将图像分类到 $K$ 个类别之一,你创建 $K$ 个文本提示,如"a photo of a {class name}",用文本编码器嵌入每个提示,用图像编码器嵌入图像,然后选择文本嵌入与图像嵌入余弦相似度最高的类别。在 ImageNet 上,CLIP 在从未见过任何 ImageNet 训练样本的情况下取得了具有竞争力的准确率。 + +### ALIGN + +- **ALIGN**(Jia 等,2021)将 CLIP 的方法扩展到更大、更嘈杂的数据集:18 亿个图像-文本对,仅极少量过滤。CLIP 精心筛选其数据,而 ALIGN 表明规模可以弥补噪声。ALIGN 使用 EfficientNet 图像编码器和 BERT 文本编码器,并使用相同的对比损失进行训练。关键发现是,只要有足够的数据,就不需要昂贵的数据清洗:对比目标会自然地降低噪声对的权重,因为它们产生不一致的梯度。 + +### SigLIP + +- **SigLIP**(Sigmoid Loss for Language-Image Pre-training,Sigmoid 损失语言-图像预训练,Zhai 等,2023)用更简单的 sigmoid 损失取代了 CLIP 基于 softmax 的对比损失。SigLIP 不将 $N \\times N$ 相似度矩阵视为分类问题(每行是一个列上的 softmax),而是将每个条目独立视为二分类问题:这个 (图像, 文本) 对是否匹配? + +- 单个对 $(i, j)$ 的 SigLIP 损失是: + +$$\\mathcal{L}_{ij} = -y_{ij} \\log \\sigma(z_i^{\\text{img}} \\cdot z_j^{\\text{txt}} / \\tau) - (1 - y_{ij}) \\log(1 - \\sigma(z_i^{\\text{img}} \\cdot z_j^{\\text{txt}} / \\tau))$$ + +- 其中 $y_{ij} = 1$ 如果 $i = j$(匹配),否则 $y_{ij} = 0$,$\\sigma$ 是 sigmoid 函数。 + +- SigLIP 的关键优势在于它消除了跨整个批次进行全局 softmax 归一化的需要。在 CLIP 中,softmax 分母需要收集所有设备上的所有嵌入,这在分布式训练中是一个通信瓶颈。SigLIP 的逐对 sigmoid 损失可以在本地计算,从而能够更高效地扩展到非常大的批次。SigLIP 以更低的训练成本达到了与 CLIP 相当的质量。 + +## 对比损失函数详解 + +- 对比学习中使用的损失函数共享一个共同的结构:它们都试图使正样本对的相似度得分高于负样本对的相似度得分,同时通过某种"间隔"或"温度"控制模型施加的力度。让我们形式化关键变体。 + +### InfoNCE + +- **InfoNCE**(噪声对比估计,van den Oord 等,2018)是 CLIP 损失背后的理论基础。给定一个查询 $q$、一个正样本键 $k^+$ 和 $K$ 个负样本键 $\\{k_1^-, \\ldots, k_K^-\\}$,损失为: + +$$\\mathcal{L}_{\\text{InfoNCE}} = -\\log \\frac{\\exp(q \\cdot k^+ / \\tau)}{\\exp(q \\cdot k^+ / \\tau) + \\sum_{j=1}^{K} \\exp(q \\cdot k_j^- / \\tau)}$$ + +- 这是一个 $(K+1)$ 类分类问题:从 $K+1$ 个候选中识别出正样本。InfoNCE 是查询与正样本键之间互信息的下界,这就是为什么最大化它能够对齐语义匹配输入的表征。随着负样本数量 $K$ 的增加,下界更加紧致,这解释了为什么对比方法受益于大批量大小。 + +### NT-Xent + +- **NT-Xent**(归一化温度标度交叉熵,Chen 等,2020)是 SimCLR(第 8 章文件 04)中使用的损失,本质上是在批次内对称应用的 InfoNCE。对于一批 $N$ 个对,$2N$ 个增广视图为每个锚点产生 $2N - 2$ 个负样本(除自身及其正样本外的所有视图)。正样本对 $(i, j)$ 的损失为: + +$$\\ell_{i,j} = -\\log \\frac{\\exp(\\text{sim}(z_i, z_j) / \\tau)}{\\sum_{k=1}^{2N} \\mathbf{1}_{[k \\neq i]} \\exp(\\text{sim}(z_i, z_k) / \\tau)}$$ + +- NT-Xent 和 InfoNCE 是相同的数学公式;名称不同只是因为它们是在不同的上下文(自监督视觉 vs. 表征学习理论)中引入的。 + +### 温度的作用 + +- **温度** $\\tau$ 是对比学习中最重要的超参数之一。为了建立直觉,可以从物理意义上考虑温度:在高温下,分子随机运动(softmax 是平坦的,所有负样本看起来一样差);在低温下,分子沉降为刚性结构(softmax 是尖锐的,只有最难的负样本才重要)。 + +- 形式化地,当 $\\tau \\to 0$ 时,softmax 趋近于硬 argmax,只选择最单一的困难负样本。当 $\\tau \\to \\infty$ 时,所有负样本的贡献相等。在实践中,$\\tau \\in [0.01, 0.1]$ 对归一化嵌入效果良好。温度过低会导致训练不稳定(困难负样本的梯度变得非常大);温度过高会使损失对违反情况不敏感。 + +- CLIP 初始化 $\\tau = 0.07$ 并将其作为对数参数化的标量 $\\tau = \\exp(t)$ 学习,其中 $t$ 与模型权重一起通过梯度下降更新。这使得模型能够在训练过程中自动调整对比任务的难度。 + +![温度对对比 softmax 的影响:低温产生聚焦于困难负样本的尖锐分布,高温产生平坦分布](../images/contrastive_temperature.svg) + +### 三元组损失和基于间隔的替代方案 + +- 在 InfoNCE 主导之前,**三元组损失(triplet loss)** 是度量学习的标准。给定一个锚点 $a$、一个正样本 $p$ 和一个负样本 $n$: + +$$\\mathcal{L}_{\\text{triplet}} = \\max(0, \\|a - p\\|^2 - \\|a - n\\|^2 + m)$$ + +- 其中 $m$ 是一个间隔,确保正样本至少比负样本近 $m$。三元组损失操作在单个三元组上而非批次上,因此样本效率低于 InfoNCE。它还对挖掘策略敏感:随机负样本通常过于简单(损失为零),因此**困难负样本挖掘**(hard negative mining,选择最接近的不正确匹配)或**半困难挖掘**(semi-hard mining,选择间隔内的负样本)至关重要。 + +- InfoNCE 在整个批次中隐式地执行困难负样本挖掘,这是它在规模上优于三元组损失的原因之一。InfoNCE 中的 softmax 归一化自动提高困难负样本(与锚点相似度高的负样本)的权重,在无需显式挖掘的情况下提供了自然的课程学习。 + +## 图像-文本检索与零样本分类 + +- 一旦你有了训练好的联合嵌入空间,就可以执行**图像-文本检索**:给定一个图像查询,从数据库中找出最相关的文本(图像到文本检索),或者给定一个文本查询,找出最相关的图像(文本到图像检索)。这仅仅是共享嵌入空间中的最近邻搜索。 + +- 想象一个图书管理员,可以即时比较一百万条目录中的任何照片与任何描述。他们不需要事先理解每一个可能的类别;只需测量每张照片与每条描述有多"接近"。这就是 CLIP 风格的模型执行检索和零样本分类的方式。 + +- **零样本分类**是文本到图像检索的一个特例。给定 $K$ 个类别名称,你构建文本提示 $\\{t_1, \\ldots, t_K\\}$(例如,"a photo of a cat"、"a photo of a dog")并对其进行嵌入。对于一张新图像 $x$,预测的类别为: + +$$\\hat{y} = \\arg\\max_{k} \\; \\text{sim}(f_\\theta(x), g_\\phi(t_k))$$ + +- 关键洞察在于,文本编码器充当了一个灵活的分类器头。你不需要为每个下游任务训练新的线性层,只需用自然语言描述任务。这就是 CLIP 泛化能力如此之强的原因:文本编码器在预训练期间见过数百万种不同的描述。 + +- **提示工程(prompt engineering)** 很重要。CLIP 在 ImageNet 上的零样本准确率从 63.2% 提升到 68.4%,仅仅是将提示模板从 "{class name}" 改为 "a photo of a {class name}." 更好的是,**提示集成(prompt ensembling)** 通过平均多个模板的文本嵌入(例如,"a photo of a {class name}"、"a good photo of a {class name}"、"a drawing of a {class name}")来产生更鲁棒的文本表征。 + +![零样本分类:将每个类别的文本提示与图像一起嵌入,选择余弦相似度最高的类别](../images/zero_shot_classification.svg) + +## 音视频对应 + +- 闭上眼睛,听某人拍篮球。你能从节奏性的砰砰声中判断球何时落地。现在睁开眼睛:视觉上的弹跳与每次砰声完美对齐。这种音频与视觉事件之间的紧密对应关系是一种机器可以学习的免费监督信号。**音视频对应学习(audio-visual correspondence learning)** 训练模型将声音与其视觉来源关联起来,无需任何人工标注。 + +- 这个想法与 CLIP 惊人地相似,只是将文本替换为音频。给定配对的视频帧和音频片段,模型学习一个嵌入空间,其中时间上对齐的音视频对彼此接近,而错位的对则相距很远。 + +- **音视频嵌入(Audio-Visual Embedding, AVE)** 方法(Arandjelovic 和 Zisserman,2017)使用对比损失在视频数据上训练一个视觉编码器 $f$ 和一个音频编码器 $g$。正样本对是(视频帧,来自同一时刻的音频片段),负样本是来自不同视频或不同时刻的音频片段。模型学会狗叫声对应狗的图像,吉他声对应吉他的图像,所有这些都不需要标签。 + +- 音频编码器通常使用 CNN 或音频 Transformer 处理**对数梅尔语谱图(log-mel spectrograms)**(第 9 章文件 01),生成固定大小的嵌入。视觉编码器使用标准图像骨干网络(ResNet、ViT)处理视频帧。两者都投影到共享的 $d$ 维空间,训练使用与 CLIP 相同的 InfoNCE 损失: + +$$\\mathcal{L}_{\\text{AV}} = -\\log \\frac{\\exp(\\text{sim}(z^{\\text{vis}}, z^{\\text{aud}}) / \\tau)}{\\sum_{k=1}^{N} \\exp(\\text{sim}(z^{\\text{vis}}, z_k^{\\text{aud}}) / \\tau)}$$ + +![音视频对应:视觉编码器处理视频帧,音频编码器处理语谱图,对比学习对齐时间匹配的对](../images/audio_visual_correspondence.svg) + +- 音视频学习的**应用**包括:声源定位(图像中声音来自何处?)、音视频语音识别(结合嘴唇运动和音频,如第 9 章文件 02)、音视频源分离(通过看着对方的脸来隔离一个人的声音——第 9 章文件 05 中的"鸡尾酒会"问题),以及基于音频的视频生成。 + +- **ImageBind**(Girdhar 等,2023)将其扩展到六种模态:图像、文本、音频、深度、热成像和 IMU 数据。关键洞察在于,你不需要每个组合都有配对数据。通过将每种模态与图像对齐(文本通过图像-文本对,音频通过图像-音频对等),所有模态通过共享的图像嵌入空间隐式对齐。这种通过公共锚点模态的"绑定"产生了涌现式对齐:音频和文本变得相似,即使它们从未被直接一起训练过。 + +## 评估 + +- 评估多模态模型需要能够捕捉跨模态理解的度量指标。两种主流的评估范式是**零样本基准测试**和**检索度量**。 + +### 零样本基准测试 + +- 零样本评估衡量模型是否能够执行从未被明确训练过的任务。最常用的基准是**ImageNet 零样本准确率**:将所有 1,000 个 ImageNet 类别名称作为文本嵌入,嵌入每个测试图像,根据余弦相似度测量 top-1 和 top-5 分类准确率。CLIP ViT-L/14 在零样本下达到 75.5% 的 top-1 准确率,与在 ImageNet 上训练的监督式 ResNet-50 相当。 + +- 其他零样本基准包括:CIFAR-10/100、STL-10、Food-101、Oxford Pets 和 Flowers-102。在多个数据集上评估可以测试模型是否真正具有通用的视觉理解能力,还是仅仅是记住了预训练数据中的模式。 + +- **线性探测(linear probe)** 评估是一种互补的测试。你冻结预训练的图像编码器,为标注数据集提取特征,然后在其上训练一个简单的线性分类器。这独立于零样本检索机制来度量学习到的表征的质量。CLIP 的特征是极好的线性探测特征,通常达到或超过监督预训练。 + +### 检索度量 + +- 对于检索任务(图像到文本和文本到图像),标准度量是 **Recall@K**(R@K):正确匹配出现在前 $K$ 个检索结果中的查询比例。常用的取值为 R@1、R@5 和 R@10。 + +- 形式化地,对于一组 $Q$ 个查询: + +$$\\text{R@}K = \\frac{1}{Q} \\sum_{q=1}^{Q} \\mathbf{1}[\\text{rank}(q) \\leq K]$$ + +- 其中 $\\text{rank}(q)$ 是查询 $q$ 的排序检索列表中正确匹配的位置。 + +- 标准的检索基准包括 **Flickr30K**(31,000 张图像,每张 5 条描述)和 **MS-COCO**(123,000 张图像,每张 5 条描述)。在测试集上评估:给定一张图像,从全部测试集中检索正确的描述,反之亦然。 + +- **中位数排名(Median Rank, MedR)** 是一种补充度量:所有查询中正确匹配的中位数位置。完美模型的 MedR = 1。数值越小越好。 + +- 除了检索,多模态模型还在组合理解基准上进行评估,如 **Winoground**(测试模型能否区分"a mug in a dog"和"a dog in a mug")和 **ARO**(属性、关系、顺序),这些基准测试模型是否真正理解语言的结构,而不仅仅是匹配词袋。CLIP 风格的模型通常在这些任务上表现不佳,这揭示了一个基本的局限:对比预训练对齐了全局语义,但可能无法捕捉细粒度的组合结构。 + +![检索评估:给定一个查询图像,模型按相似度对所有文本候选进行排序,Recall@K 衡量正确描述是否出现在前 K 个结果中](../images/retrieval_recall_at_k.svg) + +## 总结 + +- 本文件涵盖的多模态表征构成了本章后续所有内容的基础。CLIP 及其后继模型训练的联合嵌入空间是连接视觉和语言的"胶水"。文件 02 在此基础之上,构建了超越检索、能够生成关于图像文本的视觉-语言模型。文件 03 探讨了如何在序列模型中对图像和视频进行分词。文件 04 涵盖跨模态生成(文本到图像、文本到视频)。文件 05 研究了在单一模型中处理多种模态的统一架构。 + +- 核心要点:在配对数据上进行对比学习产生了嵌入空间,使得不同模态之间可以互换。图像嵌入和文本嵌入变成了"同一种东西",从而实现零样本分类、检索以及无缝集成到更大的系统中。这个想法——将匹配的对拉近、不匹配的对推远——的简单性掩盖了其非凡的有效性。 + +## 编程任务(使用 CoLab 或 notebook) + +1. 从头实现 CLIP 对比损失。创建随机图像和文本嵌入,计算相似度矩阵,并计算对称交叉熵损失。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def clip_loss(image_embeds, text_embeds, temperature=0.07): + """计算对称 CLIP 对比损失。""" + # L2 归一化嵌入 + image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=1, keepdims=True) + text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=1, keepdims=True) + + # 计算余弦相似度矩阵 (N x N) + logits = image_embeds @ text_embeds.T / temperature # (N, N) + + # 标签:对角线(第 i 张图像匹配第 i 段文本) + N = logits.shape[0] + labels = jnp.arange(N) + + # 对称交叉熵:图像到文本 + 文本到图像 + loss_i2t = -jnp.mean(jax.nn.log_softmax(logits, axis=1)[jnp.arange(N), labels]) + loss_t2i = -jnp.mean(jax.nn.log_softmax(logits, axis=0)[labels, jnp.arange(N)]) + return (loss_i2t + loss_t2i) / 2, logits * temperature + +# 模拟一批 8 个图像-文本对,64 维空间 +key = jax.random.PRNGKey(42) +k1, k2 = jax.random.split(key) +N, D = 8, 64 +image_embeds = jax.random.normal(k1, (N, D)) +text_embeds = jax.random.normal(k2, (N, D)) + +loss, sim_matrix = clip_loss(image_embeds, text_embeds) +print(f"CLIP loss (random embeddings): {loss:.4f}") + +# 可视化相似度矩阵 +fig, ax = plt.subplots(figsize=(6, 5)) +im = ax.imshow(sim_matrix, cmap='coolwarm', vmin=-1, vmax=1) +ax.set_xlabel("Text index"); ax.set_ylabel("Image index") +ax.set_title(f"Cosine Similarity Matrix (loss={loss:.3f})") +plt.colorbar(im); plt.tight_layout(); plt.show() +# 尝试改变温度 (0.01, 0.1, 1.0) 并观察损失如何变化 +# 尝试使匹配对相似:将 text_embeds 设置为 image_embeds + 小噪声 +``` + +2. 构建一个玩具联合嵌入模型,学习使用 InfoNCE 损失和梯度下降来对齐 2D"图像"(随机向量)与"描述"(不同的随机向量)。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def info_nce_loss(img_enc, txt_enc, img_data, txt_data, tau=0.1): + """在一批配对的 (图像, 文本) 数据上计算 InfoNCE。""" + z_img = img_data @ img_enc # (N, D) + z_txt = txt_data @ txt_enc # (N, D) + # L2 归一化 + z_img = z_img / jnp.linalg.norm(z_img, axis=1, keepdims=True) + z_txt = z_txt / jnp.linalg.norm(z_txt, axis=1, keepdims=True) + logits = z_img @ z_txt.T / tau + labels = jnp.arange(logits.shape[0]) + return -jnp.mean(jax.nn.log_softmax(logits, axis=1)[jnp.arange(len(labels)), labels]) + +# 创建 32 个配对样本:图像在 R^8 中,文本在 R^6 中,嵌入到 R^4 +key = jax.random.PRNGKey(0) +k1, k2, k3, k4 = jax.random.split(key, 4) +N, d_img, d_txt, d_embed = 32, 8, 6, 4 + +img_data = jax.random.normal(k1, (N, d_img)) +txt_data = jax.random.normal(k2, (N, d_txt)) + +# 可学习的投影矩阵 +img_enc = jax.random.normal(k3, (d_img, d_embed)) * 0.1 +txt_enc = jax.random.normal(k4, (d_txt, d_embed)) * 0.1 + +grad_fn = jax.jit(jax.grad(info_nce_loss, argnums=(0, 1))) +lr = 0.05 +losses = [] + +for step in range(300): + loss = info_nce_loss(img_enc, txt_enc, img_data, txt_data) + losses.append(float(loss)) + g_img, g_txt = grad_fn(img_enc, txt_enc, img_data, txt_data) + img_enc = img_enc - lr * g_img + txt_enc = txt_enc - lr * g_txt + +print(f"Initial loss: {losses[0]:.3f}, Final loss: {losses[-1]:.3f}") +print(f"Random baseline (log N): {jnp.log(N):.3f}") + +plt.figure(figsize=(8, 4)) +plt.plot(losses, color='#2c3e50') +plt.axhline(y=0, color='green', linestyle='--', alpha=0.5, label='Perfect alignment') +plt.axhline(y=float(jnp.log(N)), color='red', linestyle='--', alpha=0.5, label='Random (log N)') +plt.xlabel("Step"); plt.ylabel("InfoNCE Loss") +plt.title("Learning a Joint Embedding Space") +plt.legend(); plt.grid(alpha=0.3); plt.tight_layout(); plt.show() +# 修改 d_embed(尝试 2, 4, 16)观察嵌入维度如何影响对齐 +``` + +3. 使用预计算的嵌入实现零样本分类。模拟类"原型"作为文本嵌入,通过最近邻查找对新图像进行分类。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 模拟 5 个类,每个类有一个原型文本嵌入在 R^32 中 +key = jax.random.PRNGKey(42) +n_classes, d = 5, 32 +class_names = ["cat", "dog", "car", "plane", "ship"] + +# 类原型(想象这些来自文本编码器) +k1, k2 = jax.random.split(key) +class_prototypes = jax.random.normal(k1, (n_classes, d)) +class_prototypes = class_prototypes / jnp.linalg.norm(class_prototypes, axis=1, keepdims=True) + +# 生成 200 个测试"图像"(在其类原型附近加上噪声的嵌入) +n_per_class = 40 +true_labels = jnp.repeat(jnp.arange(n_classes), n_per_class) +keys = jax.random.split(k2, n_classes * n_per_class) + +image_embeds = [] +for i in range(n_classes): + noise = jax.random.normal(keys[i], (n_per_class, d)) * 0.5 + cluster = class_prototypes[i] + noise + image_embeds.append(cluster) +image_embeds = jnp.concatenate(image_embeds, axis=0) +image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=1, keepdims=True) + +# 零样本分类:与每个原型的余弦相似度 +similarities = image_embeds @ class_prototypes.T # (200, 5) +predicted_labels = jnp.argmax(similarities, axis=1) +accuracy = jnp.mean(predicted_labels == true_labels) +print(f"Zero-shot accuracy: {accuracy:.1%}") + +# 混淆矩阵 +conf = jnp.zeros((n_classes, n_classes), dtype=jnp.int32) +for true, pred in zip(true_labels, predicted_labels): + conf = conf.at[true, pred].add(1) + +fig, ax = plt.subplots(figsize=(6, 5)) +im = ax.imshow(conf, cmap='Blues') +ax.set_xticks(range(n_classes)); ax.set_xticklabels(class_names, rotation=45) +ax.set_yticks(range(n_classes)); ax.set_yticklabels(class_names) +ax.set_xlabel("Predicted"); ax.set_ylabel("True") +for i in range(n_classes): + for j in range(n_classes): + ax.text(j, i, int(conf[i, j]), ha='center', va='center', fontsize=11) +ax.set_title(f"Zero-Shot Confusion Matrix (acc={accuracy:.1%})") +plt.colorbar(im); plt.tight_layout(); plt.show() +# 尝试增加噪声(0.5 -> 1.0 -> 2.0)观察准确率下降 +# 尝试提示集成:平均每个原型的 3 个噪声副本 +``` diff --git a/chapter 10: multimodal learning/02. vision language models.md b/chapter 10: multimodal learning/02. vision language models.md new file mode 100644 index 0000000..db0ef11 --- /dev/null +++ b/chapter 10: multimodal learning/02. vision language models.md @@ -0,0 +1,388 @@ +# 视觉语言模型 + +*视觉语言模型共同理解图像和文本,实现视觉问答、图像描述和视觉推理。本文件涵盖 VQA、图像描述、视觉定位,以及 VisualBERT、BLIP、LLaVA、Flamingo、PaLI 和 Qwen-VL 等将视觉编码器与大型语言模型融合的架构。* + +- 想象一位博物馆导览员,他能看着一幅画并清晰描述画中的一切:有哪些物体、讲述了什么故事、传达了怎样的情感,还能回答参观者的任何问题。**视觉语言模型(VLM)** 就是计算领域的等价物——一个能同时理解图像和文本的系统,能够描述视觉场景、回答相关问题、执行视觉指令,甚至根据自然语言查询在图像中定位特定物体。 + +- VLM 位于你在第 8 章学到的视觉编码器和第 7 章的语言模型的交汇点。核心工程挑战在于桥接两个截然不同的表征世界:视觉骨干网络产生的空间化、连续的 feature map,与语言模型产生的序列化、离散的 token 嵌入。本文件中的每一种架构,本质上都是对同一个问题的不同回答:如何融合视觉和语言? + +![VLM 高层次分类:双编码器、融合编码器和编码器-解码器家族及其输入与输出](../images/vlm_taxonomy.svg) + +## 视觉问答 + +- 想象有人向你展示一张照片并问:"公园里有几只狗?"你毫不费力地解析图像、定位狗、数出数量并给出答案。**视觉问答(VQA)** 将这一过程形式化:给定一张图像 $I$ 和一个自然语言问题 $q$,预测答案 $a$。 + +- 该任务可以有多种定义方式。最常见的方式将 VQA 视为**开放式分类**:模型从最常见的答案构成的固定词汇表中选择(例如 VQA v2 中排名前 3,129 的答案)。另一种方式是**生成式回答**,模型生成自由形式的文本字符串——这是现代 VLM 采用的方法。 + +- 形式上,你需要学习一个最大化正确答案似然的函数 $f(I, q) \to a$。在分类设置中,这变为: + +$$p(a \mid I, q) = \text{softmax}(W \cdot g(v, h))$$ + +- 其中 $v$ 是视觉特征向量(来自 CNN 或 ViT),$h$ 是问题编码(来自 LSTM 或 Transformer),$g$ 是融合函数。$g$ 的设计正是真正的架构创造力所在。 + +- **VQA v1**(Antol 等人,2015)引入了该基准,包含来自 MS COCO 的 204,000 张图像上的 614,000 个问题。研究人员很快发现,模型可以通过利用**语言先验**达到惊人高的准确率——对"多少个"问题回答"2",对"有没有"问题回答"是",甚至不需要看图像。 + +- **VQA v2**(Goyal 等人,2017)通过为每个问题配对不同答案的两张相似图像来解决这个问题。这迫使模型真正将其推理建立在视觉内容之上。平衡配对设置使数据集规模大约翻倍,并使纯语言捷径的效果大打折扣。 + +- 其他重要的 VQA 数据集包括 **GQA**(Hudson & Manning,2019),包含需要多步推理的组合性问题;**OK-VQA**(Marino 等人,2019),需要超出图像范围的外部知识;以及 **TextVQA**(Singh 等人,2019),答案依赖于读取图像中的文字。 + +![VQA 流水线:图像经视觉编码器编码,问题经文本编码器编码,两者表征融合后,融合向量被分类为答案](../images/vqa_pipeline.svg) + +- 早期的 VQA 模型使用简单策略:从预训练 CNN 中提取图像特征(通常是第 8 章中 ResNet 或 VGGNet 的倒数第二层),用 LSTM(第 6 章)对问题进行编码,然后将它们组合。组合函数 $g$ 演变迅速:从简单的逐元素乘法,到双线性池化,再到多模态 Tucker 分解。**双线性注意力**计算 $v^T W h$,其中 $W$ 是可学习的交互矩阵,但完整的双线性形式有 $O(d_v \times d_h)$ 个参数,规模过大。**MLB**(多模态低秩双线性池化)将其分解为两个低秩投影,使其变得可行。 + +- VQA 的突破是注意力机制。**堆叠注意力网络**(Yang 等人,2016)使用问题编码在空间图像区域上施加注意力,迭代式地精炼需要关注的图像部分。这个思想——让问题"关注"相关图像区域——成为了标准做法。 + +## 图像描述 + +- 想象一位朋友看着你的度假照片并叙述他们所看到的:"一只金毛猎犬在阳光明媚的沙滩上接飞盘。"**图像描述**是生成图像的自然语言描述的任务。与 VQA 不同,这里没有提问——模型必须自行决定哪些内容值得描述。 + +- **Show and Tell**(Vinyals 等人,2015)建立了描述任务的标准编码器-解码器架构。CNN 编码器(如 Inception 或 ResNet)生成一个单一图像特征向量 $v$。该向量被用作 LSTM 解码器的初始隐藏状态,然后逐词自回归地生成描述: + +$$p(w_t \mid w_{1:t-1}, I) = \text{LSTM}(w_{t-1}, h_{t-1})$$ + +- 整个模型通过最大化真实描述的对数似然进行端到端训练。推理时使用束搜索(第 7 章)来找到高概率的描述。 + +- Show and Tell 的问题在于整张图像被压缩成一个单一向量。对于复杂场景,单一向量无法捕捉所有相关细节。你会丢失空间信息——模型在生成不同词语时无法"回看"图像的特定区域。 + +- **Show, Attend and Tell**(Xu 等人,2015)通过引入**图像区域上的注意力**解决了这个问题。模型不是将图像编码为一个向量,而是由 CNN 产生一个空间特征网格(例如来自 VGGNet 最后一个卷积层的 $14 \times 14 \times 512$)。在每个解码步骤,模型计算这些空间位置上的注意力权重,生成一个突出当前词语最相关区域的上下文向量。 + +- 回顾第 6 章的注意力机制:解码器隐藏状态充当查询,空间特征充当键和值,注意力权重告诉模型应该看哪里。作者提出了两种变体:**软注意力**(可微分,所有区域的加权平均)和**硬注意力**(对单个区域进行随机采样,使用 REINFORCE 训练)。 + +![基于注意力的描述:在每个解码步骤,模型关注图像的不同空间区域,例如在生成"狗"这个词时聚焦于狗所在区域](../images/attention_captioning.svg) + +- 这些模型产生的注意力图具有显著的可解释性:生成"狗"时,注意力集中在狗的区域;生成"海滩"时,注意力转移到沙子和水面。这是注意力机制提供内置可解释性的最早令人信服的演示之一。 + +- **CIDEr**(Vedantam 等人,2015)、**METEOR**、**BLEU** 和 **SPICE** 是标准描述评估指标。CIDEr 计算生成描述与参考描述之间的 TF-IDF 加权 n-gram 相似度,专门为描述评估设计。现代 VLM 通常在 MS COCO Captions 和 NoCaps 等描述基准上用 CIDEr 进行评估。 + +- 后来的描述模型引入了**自底向上注意力**(Anderson 等人,2018),其中目标检测器(Faster R-CNN,第 8 章)首先提出显著的图像区域,然后描述模型在这些区域特征而非均匀网格上施加注意力。在基于 ViT 的编码器接管之前,这是主导方法。 + +## 架构模式 + +- 每个 VLM 都必须回答一个基本设计问题:视觉和语言在哪个节点交互?答案决定了模型的架构家族。有三种主要模式,各自具有不同的权衡。 + +### 双编码器 + +- 想象两位独立工作的译者——一位读法语文件,另一位读英语文件——他们各自用一种共享的"通用语言"生成摘要。他们在翻译过程中从不交流,但他们的摘要可以直接比较。这就是**双编码器**模式。 + +- 视觉编码器 $f_v$ 和文本编码器 $f_t$ 独立地将各自的输入映射到一个维度为 $d$ 的共享嵌入空间。图像嵌入为 $v = f_v(I) \in \mathbb{R}^d$,文本嵌入为 $t = f_t(q) \in \mathbb{R}^d$。相似度通过点积或余弦相似度计算:$\text{sim}(I, q) = v^T t / (\|v\| \|t\|)$。 + +- **CLIP**(Radford 等人,2021),在前一篇关于多模态表示的文件中已介绍,是典型的双编码器。它在从互联网抓取的 4 亿图像-文本对上使用对比目标函数(InfoNCE)进行训练。由于编码器相互独立,你可以预计算并缓存所有图像嵌入,使检索极其高效——搜索时只需对查询文本进行编码。 + +- 双编码器的缺点在于视觉和语言从未在特征层面进行交互。模型无法进行细粒度的跨模态推理:例如,它无法确定描述中的特定词是否对应图像中的特定区域。这限制了它在 VQA 或 grounded 描述等任务中的实用性。 + +### 融合编码器 + +- 现在想象两位译者共处一室,积极讨论两篇文件。他们可以指向特定段落、互相提问,并建立共同的理解。这就是**融合编码器**模式。 + +- 两种模态都被编码,然后通过**交叉注意力层**进行融合,其中一种模态的 token 关注另一种模态的 token。图像首先由视觉编码器处理为一系列 patch 或区域 token $V = [v_1, \ldots, v_N]$。文本被分词化为 $T = [t_1, \ldots, t_M]$。在融合层中,文本 token 通过交叉注意力关注图像 token: + +$$\text{CrossAttn}(T, V) = \text{softmax}\!\left(\frac{(TW_Q)(VW_K)^T}{\sqrt{d_k}}\right)(VW_V)$$ + +- 这实现了细粒度的交互:每个文本 token 都可以关注其所需的特定图像区域。**VisualBERT**、**VilBERT** 和 **UNITER** 等模型使用这种模式。代价是你无法为检索预计算独立的嵌入——每个图像-文本对都需要通过融合层进行完整的前向传播。 + +![双编码器与融合编码器对比:双编码器计算独立嵌入和相似度得分,而融合编码器通过交叉注意力层合并模态](../images/dual_vs_fusion_encoder.svg) + +### 编码器-解码器 + +- **编码器-解码器**模式将视觉编码器与自回归生成输出 token 的文本解码器相结合,类似于第 7 章中的 seq2seq 模型。视觉编码器产生上下文图像表征,文本解码器在生成输出文本时对其执行交叉注意力。 + +- 这种模式天然支持生成式任务:图像描述、自由形式答案的 VQA 以及视觉对话。**GIT**(Generative Image-to-text Transformer,Wang 等人,2022)、**CoCa**(Contrastive Captioner,Yu 等人,2022)和 **PaLI** 使用这种架构。CoCa 巧妙地将双编码器和编码器-解码器模式结合起来:文本解码器的前半部分作为单模态文本编码器(用于对比学习),而后半部分对图像特征执行交叉注意力(用于生成式描述),兼得两者之优势。 + +- 这三种模式的选择取决于目标任务。双编码器最适合大规模检索。融合编码器最适合细粒度理解任务。编码器-解码器对于生成任务最为通用。现代最先进的 VLM 越来越多地采用编码器-解码器或仅解码器范式,将每项视觉语言任务都视为文本生成。 + +## Flamingo:少样本多模态学习 + +- 想象一位经验丰富的专家,经过多年对艺术和文学的研究,只需要看一两个例子就能优雅地描述一种全新的绘画风格。**Flamingo**(Alonso 等人,2022,DeepMind)基于相同原理构建:它利用强大的预训练语言模型和预训练视觉编码器,通过轻量级架构组件将其连接,实现多模态任务上的少样本学习。 + +- Flamingo 的设计理念保守而有效:保持预训练的视觉编码器(NFNet)和语言模型(Chinchilla)冻结,仅学习连接它们的"胶水"。这种胶水由两个组件组成:**Perceiver 重采样器**和**门控交叉注意力层**。 + +- **Perceiver 重采样器**将视觉编码器的变长输出(取决于图像分辨率)压缩为一组固定数量的 $N$ 个视觉 token(通常 $N = 64$)。它的工作原理是初始化一组 $N$ 个可学习的查询向量,并使用交叉注意力让这些查询关注完整的视觉编码器输出。这本质上是 Perceiver 架构(Jaegle 等人,2021)作为瓶颈的应用——无论输入图像大小如何,它都能生成紧凑的、固定大小的视觉表示。 + +$$z = \text{CrossAttn}(Q_{\text{learned}}, V_{\text{image}}) \in \mathbb{R}^{N \times d}$$ + +- **门控交叉注意力层**交错插入在冻结的语言模型层之间。在每个这样的层中,语言模型的文本 token 对 Perceiver 重采样器产生的视觉 token 执行交叉注意力。关键之处在于,每个门控交叉注意力层包含一个可学习的标量门控 $\alpha$,初始化为零,将交叉注意力输出乘以 $\alpha$ 后再加到残差流中: + +$$\hat{x} = x + \alpha \cdot \text{CrossAttn}(x, z)$$ + +- 初始化 $\alpha = 0$ 意味着训练开始时交叉注意力不贡献任何信息,模型行为与原始的冻结语言模型完全相同。门控在训练过程中逐渐打开,平滑地整合视觉信息,同时不破坏语言模型的预训练表示。 + +![Flamingo 架构:冻结的视觉编码器输入 Perceiver 重采样器,生成固定长度的视觉 token,通过交错在 LM 块之间的门控交叉注意力层注入冻结的 LM](../images/flamingo_architecture.svg) + +- Flamingo 原生支持**交错图像-文本序列**。你可以向它输入包含多张图像穿插文本的提示,例如:"[图像 1] 这是一只猫。[图像 2] 这是一只狗。[图像 3] 这是一个 ___。"模型将每张图像通过视觉编码器和 Perceiver 重采样器处理,得到的视觉 token 插入到文本序列中的对应位置。语言模型的因果注意力掩码确保每个文本 token 只能关注当前及之前图像的视觉 token。 + +- 这种交错机制实现了强大的**少样本多模态学习**。通过在上下文中提供少量图像-文本示例,Flamingo 可以在没有任何梯度更新的情况下执行新任务。在 VQAv2、OK-VQA 和描述等基准上,具有 800 亿参数的 Flamingo 实现了最先进的少样本性能,仅需 4 到 32 个示例即可匹配甚至超越经过微调的专家模型。 + +## LLaVA 与视觉指令微调 + +- 想象你有一位出色的语言专家(一个 LLM)和一位出色的艺术评论家(一个视觉编码器)。如果你能教会艺术评论家"说语言专家的语言",他们就可以无缝协作。**LLaVA**(Large Language and Vision Assistant,Liu 等人,2023)正是这样做的:它使用一个简单的线性层将视觉特征投影到 LLM 的 token 嵌入空间,然后在指令遵循数据上微调整个系统。 + +- LLaVA 的架构出奇地简单。图像由一个预训练的 CLIP ViT-L/14 视觉编码器编码为一个 patch 特征网格 $V \in \mathbb{R}^{N \times d_v}$,其中 $N = 256$ 个 patch(对于 336px 图像和 14px patch)。一个**投影层** $W$ 将这些视觉特征映射到 LLM 的嵌入维度: + +$$H_v = VW, \quad W \in \mathbb{R}^{d_v \times d_{\text{LLM}}}$$ + +- 投影后的视觉 token $H_v$ 直接与文本 token 嵌入拼接,作为一个单一序列输入到 LLM(Vicuna,一个微调后的 LLaMA)。LLM 使用其标准因果自注意力处理它们——没有特殊的交叉注意力层,没有 perceiver,只有拼接。视觉 token 被当作恰好编码了视觉信息的文本 token 来处理。 + +![LLaVA 架构:CLIP ViT 将图像编码为 patch 特征,线性投影将其映射到 LLM 嵌入空间,投影后的视觉 token 拼接在文本 token 之前并输入到 LLM](../images/llava_architecture.svg) + +- **视觉指令微调**是 LLaVA 的关键训练创新。作者使用 GPT-4 从 COCO 图像生成了 158,000 个多模态指令遵循示例。每个示例包含一张图像和一个对话式指令(例如"详细描述这张图像"、"这张图像有什么不寻常之处?"、"如果我是一名游客参观这个地方,我应该知道什么?")。模型接受训练,根据图像和指令生成 GPT-4 撰写的回答。 + +- 训练分为两个阶段。**阶段 1(预训练)**:仅训练投影层 $W$,使用图像-描述对(来自 CC3M 的 595K 数据),视觉编码器和 LLM 都保持冻结。这教会 $W$ 将视觉特征与 LLM 的嵌入空间对齐。**阶段 2(微调)**:投影层和 LLM 在指令遵循数据上联合微调,视觉编码器保持冻结。这教会模型遵循复杂的视觉指令。 + +- **LLaVA-1.5** 通过三项关键更改改进了原始版本:将单层线性投影替换为两层 MLP(更具表现力的映射),使用更高分辨率的图像(336px 而非 224px,产生更多 patch token),以及在训练混合数据中加入学术 VQA 数据集。这些看似细微的修改带来了基准性能的大幅提升。 + +- LLaVA 的方法证明,你不需要像 Flamingo 的 Perceiver 重采样器或门控交叉注意力那样复杂的架构创新。一个简单的线性投影,结合高质量的指令微调数据,就足以有效地将视觉编码器连接到 LLM。这种简洁性使得 LLaVA 极具影响力——后续大多数开源 VLM 都遵循类似的方案。 + +## 扩展视觉语言模型 + +- 该领域从概念验证型 VLM 迅速发展为在数十亿图像-文本对上训练的工业级系统。三个模型家族展示了不同的扩展方法。 + +### PaLI + +- **PaLI**(Pathways Language and Image model,Chen 等人,2022,Google)同时扩展视觉编码器和语言模型。PaLI 使用 ViT-e(40 亿参数)作为视觉编码器,mT5(130 亿参数)作为语言模型,总计 170 亿参数。图像被编码为一系列 patch token,拼接在文本 token 之前,输入到编码器-解码器架构的 mT5。 + +- PaLI 的关键洞见是**扩展视觉编码器与扩展语言模型同样重要**。先前的工作通常使用固定的、中等规模的视觉骨干网络(如 ViT-B 或 ViT-L),将参数预算全部投入 LLM。PaLI 表明,一个 40 亿参数的 ViT-e,在 JFT-4B(40 亿张标注图像)上预训练后,能够显著提升 OCR 和空间推理等细粒度视觉任务的性能。 + +- PaLI 在 WebLI(一个包含 109 种语言、100 亿图像-文本对的数据集)上训练,因此天然具备多语言能力。模型通过混合任务进行预训练:图像描述、VQA 和图像-文本匹配,全部作为文本到文本生成任务(遵循第 7 章的 T5 范式)。**PaLI-X**(550 亿参数)和 **PaLI-3**(50 亿,使用 SigLIP 作为视觉编码器)是后续迭代版本。 + +### Qwen-VL + +- **Qwen-VL**(Bai 等人,2023,阿里巴巴)在 Qwen LLM 基础上增加了一个 ViT 视觉编码器和一个单层交叉注意力模块(类似于 Flamingo 的 Perceiver 重采样器),将视觉编码器的输出压缩为一组固定的 256 个视觉 token。视觉 token 与文本 token 拼接后由 Qwen LLM 处理。 + +- Qwen-VL 的训练采用三阶段方案。阶段 1:在 14 亿个弱监督图像-文本对上预训练,仅解冻视觉编码器。阶段 2:在更高质量的数据上进行多任务预训练,包括 VQA、描述、定位和 OCR 数据集,整个模型解冻。阶段 3:在指令遵循和对话数据上进行监督微调。这种从噪声网络数据到精选指令数据的渐进式精炼,是大多数现代 VLM 共享的模式。 + +- **Qwen2-VL**(2024)引入了**动态分辨率**支持:模型不是将所有图像缩放到固定大小,而是通过动态调整视觉 token 数量以原始分辨率处理图像。更高分辨率的图像产生更多 token,更低分辨率的图像产生更少 token。这在不浪费低分辨率输入计算量的前提下,提升了文档理解和细粒度识别等对细节敏感的任务的性能。 + +### InternVL + +- **InternVL**(Chen 等人,2024,上海人工智能实验室)激进地扩展了视觉编码器,使用 InternViT-6B——一个 60 亿参数的视觉 Transformer——与语言模型配对。关键的架构贡献是**动态高分辨率处理**:图像被分割为 448x448 像素的图块,每个图块由视觉编码器独立处理,得到的图块特征与完整图像的缩略图特征拼接。这使得模型能够处理任意宽高比和分辨率的图像。 + +- InternVL-2 进一步引入了**渐进对齐训练**:首先用对比目标(如 CLIP)对齐视觉编码器,然后通过轻量级 MLP 连接器将其连接到 LLM,最后在指令数据上进行端到端微调。这种渐进策略防止了视觉编码器预训练表示的灾难性遗忘。 + +![扩展 VLM:PaLI、Qwen-VL 和 InternVL 的比较,展示了连接视觉编码器和语言模型的不同方法,包括其训练阶段](../images/scaling_vlms_comparison.svg) + +- 所有三个模型家族的一个共同主题是**训练数据精选**的重要性。从网络抓取的原始图像-文本对是噪声大且常常不对齐的。后续的训练阶段逐步过滤和精炼数据,从数十亿噪声对过渡到数百万高质量指令示例。最终微调数据的质量往往比模型的原始参数数量更为重要。 + +## 定位与指代 + +- 想象你在人群中指着一个人说"戴红帽子的女士"。你在用语言指代一个特定的空间区域。**视觉定位**是相反的过程:给定一张图像和一个自然语言表述,模型必须识别(定位)所指的对象。**指代表达理解**产生边界框;**指代表达分割**产生像素掩码。 + +- 形式上,给定一张图像 $I$ 和一个指代表达 $r$(例如"左边那只大型棕色狗"),模型预测一个边界框 $b = (x, y, w, h)$ 或一组定位所引用对象的坐标。数据集包括 **RefCOCO**、**RefCOCO+** 和 **RefCOCOg**,每个数据集包含具有多个对象的图像以及每个对象的明确指代表达。 + +- 早期的定位模型使用两阶段方法:首先生成区域提议(使用 Faster R-CNN 或类似方法),然后使用融合模型对每个提议与语言查询进行评分。评分最高的区域即为预测结果。这种方法计算代价高昂,且受限于提议的质量。 + +- 现代 VLM 将定位直接整合到生成式框架中。关键思想是将边界框坐标表示为**文本 token**。你将连续的坐标空间离散化为槽位(例如 $x, y, w, h$ 各 1000 个槽位),并向词汇表中添加特殊的位置 token,如 ``。然后模型通过输出一系列位置 token 来生成边界框: + +$$\text{输出: } \texttt{}$$ + +- 这种 token 化技巧使得任何自回归语言模型无需架构更改即可执行定位——它只需学会"说坐标"。**Pix2Seq**(Chen 等人,2022)率先将这种方法用于目标检测,而 Qwen-VL、Ferret 和 Kosmos-2 等模型将其扩展到指代表达理解和短语定位。 + +- **Kosmos-2**(Peng 等人,2023,Microsoft)通过将空间位置表示为嵌入在生成文本中的特殊 token,为多模态 LLM 增加了定位能力。例如,它可以生成:"一只 `` 金毛猎犬 `` `` `` `` `` `` `` 正在接飞盘。"这种文本和空间 token 的交错融合实现了同步描述和定位。 + +![通过坐标 token 化实现定位:模型生成文本 token 与离散化的边界框坐标 token 交错在一起,定位描述中提到的物体](../images/grounding_coordinate_tokens.svg) + +- **定点指向**将定位更进一步:模型不再输出边界框,而是预测一个单一的点(通常是指代物体的中心)。这对于交互式应用非常有用,例如用户问"最近的出口在哪里?",模型返回一个叠加在图像上的坐标。**Shikra** 和 **Ferret** 等模型支持基于点的指代以及基于框的定位。 + +## 免 OCR 文档理解 + +- 传统的文档理解流水线很复杂:首先运行 OCR 引擎提取文本和布局,然后将提取的文本输入语言模型。这种多阶段方法很脆弱——OCR 错误向下游传播,空间布局信息常常丢失或表征不良。如果模型能像人类一样直接从像素中读取信息呢? + +- **Donut**(Document Understanding Transformer,Kim 等人,2022)完全消除了 OCR。它使用 Swin Transformer(第 8 章)作为视觉编码器处理文档图像,并使用 BART 风格的 Transformer 解码器直接从视觉特征生成结构化文本输出。解码器可以根据任务生成 JSON、键值对或纯文本。 + +- Donut 的训练分为两个阶段。**预训练**:模型通过执行合成 OCR 来学习阅读——给定一张文档图像,生成完整的文本内容。这在从文本语料库渲染的数百万张合成文档图像上进行训练,教会视觉编码器识别字符、字体和布局。**微调**:模型通过训练生成特定于任务的结构化输出,适应特定的下游任务,如收据解析、表格理解或文档分类。 + +- Donut 解码器使用特殊的提示方案:任务由提示 token 指定(例如分类用 ``,收据解析用 ``),模型根据此提示生成输出。这种统一接口使得单个模型可以处理多种文档理解任务。 + +- **Pix2Struct**(Lee 等人,2023,Google)将免 OCR 思想应用于网页理解和图表/图形理解。关键的预训练目标是**截图解析**:给定一个网页的带掩码截图,模型生成产生可见区域的底层 HTML。这教会模型理解视觉呈现与结构化标记之间的关系。 + +- Pix2Struct 引入了**可变分辨率输入处理**:它并不是将所有图像缩放到固定大小(这会扭曲宽高比并破坏精细文字),而是在保持原始宽高比的同时将图像打包为固定数量的 patch。一个高而窄的文档产生一个高而窄的 patch 网格。这对于文档理解至关重要,因为宽高比携带着语义信息(收据窄而高;表格宽而短)。 + +![免 OCR 文档理解:Donut 和 Pix2Struct 通过视觉编码器直接处理文档图像,无需任何 OCR 预处理即可生成结构化文本输出](../images/ocr_free_document_understanding.svg) + +- **Nougat**(Blecher 等人,2023,Meta)将 Donut 架构专门应用于学术论文,直接从 PDF 页面图像生成完整的 LaTeX 标记。它可以处理复杂的数学方程、表格和图形——这些任务正是传统 OCR 流水线难以应付的。该模型在 PDF 页面图像及其对应的 LaTeX 源代码对上进行训练。 + +- 免 OCR 模型的成功展示了深度学习中的一个更广泛原则:直接从原始输入(像素)学习的端到端模型通常优于复杂的多阶段流水线,因为它们可以联合优化所有组件,并学习专门针对最终任务定制的表示。中间的 OCR 步骤是一个瓶颈,限制了模型能够学习的内容。 + +## 视觉 Token 流水线 + +- 无论架构家族如何,每个 VLM 都必须将图像转换为语言模型可以处理的一系列 token。理解这一流水线至关重要。不同模型的处理过程有所差异,但总体流程如下: + +- **第 1 步:Patch 提取。** 图像(高度 $H$,宽度 $W$)被划分为不重叠的、大小为 $P \times P$ 的 patch,产生 $N = HW / P^2$ 个 patch。对于 336x336 图像和 14x14 patch,$N = 576$。 + +- **第 2 步:视觉编码。** 每个 patch 经过线性投影并通过视觉编码器(通常是 ViT)。输出是一系列上下文 patch 嵌入 $V = [v_1, \ldots, v_N] \in \mathbb{R}^{N \times d_v}$。这些嵌入既携带局部外观信息,也携带全局上下文(来自自注意力)。 + +- **第 3 步:Token 压缩(可选)。** 一些模型将 $N$ 个视觉 token 压缩为更少的 $M \ll N$ 个 token,以减少语言模型的计算负担。Flamingo 使用 Perceiver 重采样器($M = 64$);Qwen-VL 使用交叉注意力($M = 256$);**Q-Former**(在 BLIP-2 中使用,Li 等人,2023)使用一组 $M = 32$ 个可学习查询 token,对视觉编码器的输出执行交叉注意力。 + +- **第 4 步:投影。** 视觉 token(全部或压缩后的集合)通过线性层或 MLP 投影到语言模型的嵌入空间。投影后,视觉 token 与文本 token 嵌入具有相同维度,可以与它们拼接。 + +- **第 5 步:注入 LLM。** 投影后的视觉 token 在特殊 `` 占位符 token 的位置插入到 token 序列中,组合后的序列由语言模型处理。LLM 的自注意力使文本 token 能够关注视觉 token,反之亦然。 + +![视觉 token 流水线:图像 patch 被提取,由 ViT 编码,可选地由 Perceiver 或 Q-Former 压缩,投影到 LLM 维度,并与文本 token 拼接](../images/visual_token_pipeline.svg) + +- 视觉 token 的数量直接影响计算成本。每个视觉 token 参与 LLM 的自注意力,其复杂度与序列长度的平方成正比。具有多个 patch 的高分辨率图像可能产生数百或数千个视觉 token,占据 LLM 上下文窗口的主导地位。这就是 token 压缩的重要性所在:将 576 个视觉 token 减少到 64 个,可将视觉部分在注意力中的贡献减少约 9 倍。 + +- **BLIP-2**(Li 等人,2023)以其高效的桥接策略而闻名。它引入了一个轻量级的 **Q-Former**(一个带有可学习查询的小型 Transformer),位于冻结的视觉编码器和冻结的 LLM 之间。Q-Former 是唯一可训练的组件——视觉编码器和 LLM 都保持冻结。它的预训练分为两个阶段:首先是图像-文本对比学习、匹配和描述目标(连接视觉编码器),然后是语言生成目标(连接 LLM)。这种模块化设计使得 BLIP-2 可以将任何视觉编码器插入到任何 LLM 中。 + +## 训练目标 + +- VLM 使用多种目标的组合进行训练,具体取决于架构模式: + +- **图像-文本对比损失(ITC):** 在共享嵌入空间中对齐图像和文本表示,如 CLIP 中所示。这是双编码器的主要目标,也常被用作融合模型的预训练目标。该损失就是上一篇文件中的 InfoNCE 损失。 + +- **图像-文本匹配(ITM):** 一个二分类目标——给定图像和文本,预测它们是否匹配。困难负样本(与不同图像配对的相似文本)使这项任务具有挑战性,迫使模型学习细粒度的对齐。 + +- **语言建模(LM):** 标准的自回归语言建模目标——给定之前的所有 token 预测下一个 token。对于 VLM,"之前的 token" 包括视觉 token,因此模型学习在视觉输入条件下生成文本。这是编码器-解码器和仅解码器 VLM 的主要目标。 + +$$\mathcal{L}_{\text{LM}} = -\sum_{t=1}^{T} \log p(w_t \mid w_{, 1="a", 2="red", 3="car", 4= +vocab_size, embed_dim, hidden_dim = 5, 16, 32 +W_embed = jax.random.normal(k2, (vocab_size, embed_dim)) * 0.1 +W_attn_q = jax.random.normal(k3, (hidden_dim, 32)) * 0.1 # 查询投影 + +def attend(h, img_feats, W_q): + """在给定解码器状态 h 的情况下计算图像特征上的软注意力。""" + query = h @ W_q # (32,) + scores = img_feats @ query # (16,) + weights = jax.nn.softmax(scores) # (16,) + context = weights @ img_feats # (32,) + return context, weights + +# 简单的 GRU 风格步骤(为说明目的,仅用线性 + tanh) +W_h = jax.random.normal(jax.random.PRNGKey(0), (embed_dim + 32, hidden_dim)) * 0.1 + +def decode_step(h, word_idx, img_feats): + context, attn_weights = attend(h, img_feats, W_attn_q) + word_emb = W_embed[word_idx] # (16,) + inp = jnp.concatenate([word_emb, context]) # (48,) + h_new = jnp.tanh(inp @ W_h) # (32,) + return h_new, attn_weights + +# 运行解码序列: -> "a" -> "red" -> "car" -> +target_seq = [0, 1, 2, 3, 4] +h = jnp.zeros(hidden_dim) +all_attn = [] +for word_idx in target_seq[:-1]: + h, attn_w = decode_step(h, word_idx, img_features) + all_attn.append(attn_w) + +# 可视化每一步的注意力图(重塑为 4x4 网格) +words = ["", "a", "red", "car"] +fig, axes = plt.subplots(1, 4, figsize=(14, 3)) +for i, (ax, w) in enumerate(zip(axes, words)): + ax.imshow(all_attn[i].reshape(4, 4), cmap='viridis') + ax.set_title(f'生成"{w}"后\n关注的区域') + ax.axis('off') +plt.suptitle('每个解码步骤的图像区域注意力') +plt.tight_layout(); plt.show() +# 尝试修改 img_features,观察注意力模式如何变化! +``` + +2. 模拟视觉 token 流水线:将图像划分为 patch,将 patch 投影到嵌入空间,与文本 token 嵌入拼接,并在组合序列上运行单层自注意力。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +key = jax.random.PRNGKey(7) + +# 创建一个合成的 8x8 "图像",3 个通道 +k1, k2, k3, k4 = jax.random.split(key, 4) +image = jax.random.uniform(k1, (8, 8, 3)) + +# 第 1 步:划分为 4x4 patch -> 4 个 patch +patch_size = 4 +patches = image.reshape(2, patch_size, 2, patch_size, 3) +patches = patches.transpose(0, 2, 1, 3, 4).reshape(4, patch_size * patch_size * 3) # (4, 48) +print(f"Patch 数量: {patches.shape[0]}, Patch 维度: {patches.shape[1]}") + +# 第 2 步:将 patch 投影到嵌入维度 (d=16) +d_model = 16 +W_patch = jax.random.normal(k2, (patches.shape[1], d_model)) * 0.1 +visual_tokens = patches @ W_patch # (4, 16) + +# 第 3 步:创建文本 token 嵌入(模拟 3 个文本 token) +text_tokens = jax.random.normal(k3, (3, d_model)) * 0.1 + +# 第 4 步:拼接视觉 + 文本 token +combined = jnp.concatenate([visual_tokens, text_tokens], axis=0) # (7, 16) +print(f"组合序列长度: {combined.shape[0]} (4 个视觉 + 3 个文本)") + +# 第 5 步:在组合序列上运行单头自注意力 +W_Q = jax.random.normal(k4, (d_model, d_model)) * 0.1 +k5, k6 = jax.random.split(k4) +W_K = jax.random.normal(k5, (d_model, d_model)) * 0.1 +W_V = jax.random.normal(k6, (d_model, d_model)) * 0.1 + +Q = combined @ W_Q +K = combined @ W_K +V = combined @ W_V +attn_scores = (Q @ K.T) / jnp.sqrt(d_model) +attn_weights = jax.nn.softmax(attn_scores, axis=-1) # (7, 7) + +output = attn_weights @ V # (7, 16) + +# 可视化跨模态注意力模式 +labels = ['V1', 'V2', 'V3', 'V4', 'T1', 'T2', 'T3'] +fig, ax = plt.subplots(figsize=(6, 5)) +im = ax.imshow(attn_weights, cmap='Blues') +ax.set_xticks(range(7)); ax.set_xticklabels(labels) +ax.set_yticks(range(7)); ax.set_yticklabels(labels) +ax.set_xlabel('键'); ax.set_ylabel('查询') +ax.set_title('自注意力:视觉(V)和文本(T)Token') +plt.colorbar(im, ax=ax); plt.tight_layout(); plt.show() +# 观察:文本 token 关注视觉 token(跨模态注意力)! +``` + +3. 实现用于视觉定位的坐标 token 化。给定一个边界框,将其转换为离散 token;给定离散 token,重构边界框。在不同槽位分辨率下可视化量化误差。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def encode_bbox(bbox, num_bins=1000): + """将连续的边界框 (x, y, w, h)(在 [0,1] 范围内)转换为离散 token。""" + tokens = jnp.round(jnp.array(bbox) * (num_bins - 1)).astype(jnp.int32) + return tokens + +def decode_bbox(tokens, num_bins=1000): + """将离散 token 转换回连续的边界框。""" + return tokens.astype(jnp.float32) / (num_bins - 1) + +# 真实边界框(归一化到 [0, 1]) +gt_bbox = jnp.array([0.123, 0.456, 0.333, 0.222]) + +# 测试不同槽位分辨率下的量化 +bin_sizes = [10, 50, 100, 500, 1000] +errors = [] +for n_bins in bin_sizes: + tokens = encode_bbox(gt_bbox, n_bins) + reconstructed = decode_bbox(tokens, n_bins) + error = jnp.max(jnp.abs(gt_bbox - reconstructed)) + errors.append(float(error)) + print(f"槽位数={n_bins:>5d} | Token={tokens} | " + f"重构={reconstructed} | 最大误差={error:.6f}") + +fig, ax = plt.subplots(figsize=(8, 4)) +ax.plot(bin_sizes, errors, 'o-', color='#e74c3c', linewidth=2, markersize=8) +ax.set_xlabel('槽位数'); ax.set_ylabel('最大量化误差') +ax.set_title('边界框量化误差 vs 槽位分辨率') +ax.set_xscale('log'); ax.set_yscale('log') +ax.grid(True, alpha=0.3); plt.tight_layout(); plt.show() +# 尝试:槽位非常少时(如 5)会发生什么?误差在何时是可接受的? +``` diff --git a/chapter 10: multimodal learning/03. image and video tokenisation.md b/chapter 10: multimodal learning/03. image and video tokenisation.md new file mode 100644 index 0000000..e3b66e4 --- /dev/null +++ b/chapter 10: multimodal learning/03. image and video tokenisation.md @@ -0,0 +1,419 @@ +# 图像与视频词元化 + +*图像与视频词元化将连续的视觉数据转换为离散的词元序列,使 Transformer 能够像处理文本一样处理它们。本节涵盖 VQ-VAE、VQ-GAN、码本学习、DALL-E 的 dVAE、视频词元化以及免查询词元化。* + +## 为什么要对图像进行词元化 + +- 把语言想象成一个有限的字母表:英语大约有 26 个字母,现代语言模型将文本切分为 30,000 到 100,000 个子词词元。每个句子都变成一串离散符号,Transformer 可以逐个预测。而图像存在于连续的高维空间中:一张 256×256 的 RGB 图像就是 $\mathbb{R}^{256 \times 256 \times 3} \approx \mathbb{R}^{196{,}608}$ 中的一个点。如果你希望语言模型用与说英语同样的机制来"说"图像,就需要将这些连续的像素数组转换为一串可管理的离散词元,这些词元来自一个有限的词汇表。这种转换就是**图像词元化**。 + +- 想象你是一位马赛克艺术家。你没有无限多种瓷砖色调,只有一个固定的调色板,比如说 8192 种不同的瓷砖颜色。要再现一张照片作为马赛克,你必须 (1) 确定每个瓷砖代表照片的哪个区域,(2) 为每个区域选择最接近的瓷砖颜色,(3) 接受一些细节的丢失,但整体画面仍然可辨认。图像词元化做的正是这件事:编码器将空间块压缩为潜在向量,码本将每个向量映射到其最近的条目,结果是一个整数索引网格(每个块对应一个索引),离散模型可以处理它。 + +- 词元化的好处有三方面。首先,它大幅压缩了图像:一张 256×256 的图像可能变成一个 16×16 的词元网格,序列长度从 65,536 个像素减少到 256 个词元,这对于成本随序列长度呈二次方增长的注意力模型来说是可行的。其次,它统一了表示形式:文本词元和图像词元位于同一个离散词汇表中,使得单个自回归 Transformer 可以生成交织的文本和图像。第三,它施加了一个有用的瓶颈,迫使模型学习语义上有意义的编码,而不是记忆像素噪声。 + +![图像词元化流程概览:连续图像经过编码器,潜在向量通过码本进行量化,生成离散词元索引网格](../images/image_tokenisation_overview.svg) + +- 回顾第 8 章中卷积网络如何从图像中提取层次化特征图,以及第 7 章中文本词元化器如何将字符串转换为整数序列。图像词元化正处于两者的交汇点:它使用 CNN 或视觉 Transformer 编码器(第 8 章)产生空间特征,然后借用离散词汇表的思想(第 7 章)将这些特征转换为词元索引。 + +## VQ-VAE:向量量化 + +- 正如我们在第 6 章中看到的,标准**变分自编码器**(VAE)将输入编码为连续潜在分布,并从该分布中采样再解码为重建结果。潜在空间是连续的,这使得将其输入离散序列模型变得困难。**向量量化变分自编码器**(VQ-VAE),由 van den Oord 等人(2017)提出,通过引入一个可学习的嵌入向量码本,并将每个编码器输出映射到其最近的码本条目,用离散潜在表示取代了连续潜在表示。 + +- 想象一个藏书室,里面有恰好 $K$ 个贴有标签的书架。当一本新书(编码器输出)到达时,图书管理员将它放在与其现有书籍(码本向量)最相似的书架上,并记录下书架编号。之后,要取回这本书,你只需要书架编号:那个书架上的码本条目就是一个足够好的替代。这就是向量量化。 + +- 形式上,VQ-VAE 有三个组件: + + - **编码器** $E$,将输入图像 $\mathbf{x} \in \mathbb{R}^{H \times W \times 3}$ 映射到连续潜在向量的空间网格 $\mathbf{z}_e = E(\mathbf{x}) \in \mathbb{R}^{h \times w \times d}$,其中 $h \times w$ 是降采样后的空间分辨率,$d$ 是嵌入维度。 + + - **码本** $\mathcal{C} = \{\mathbf{e}_1, \mathbf{e}_2, \ldots, \mathbf{e}_K\} \subset \mathbb{R}^d$,包含 $K$ 个可学习的嵌入向量。典型码本大小范围为 512 到 16,384 个条目。 + + - **解码器** $D$,从量化后的潜在表示重建图像。 + +- **量化步骤**将每个编码器输出 $\mathbf{z}_e(\mathbf{x})$ 在空间位置 $(i, j)$ 处替换为最近的码本条目: + +$$\mathbf{z}_q(i,j) = \mathbf{e}_{k^\ast} \quad \text{其中} \quad k^\ast = \arg\min_k \|\mathbf{z}_e(i,j) - \mathbf{e}_k\|_2$$ + +- 这是在嵌入空间中的最近邻查找,与 k-means 分配(第 6 章)完全相同。索引 $k^\ast$ 是空间位置 $(i,j)$ 的离散词元,整张图像被表示为一个 $h \times w$ 的整数网格,取值范围为 $\{1, \ldots, K\}$。 + +![VQ-VAE 架构:编码器产生连续潜在向量,每个潜在向量匹配到最近的码本条目,解码器从量化后的编码重建图像](../images/vqvae_architecture.svg) + +- 挑战在于 $\arg\min$ 是不可微的:你无法通过离散选择进行反向传播。VQ-VAE 通过**直通估计器**解决了这个问题:在前向传播过程中,解码器接收 $\mathbf{z}_q$(量化后的向量);在反向传播过程中,重建损失相对于 $\mathbf{z}_q$ 的梯度被直接复制到 $\mathbf{z}_e$,就好像量化步骤是恒等函数一样。这可以简洁地写为: + +$$\mathbf{z}_q = \mathbf{z}_e + \text{sg}(\mathbf{z}_q - \mathbf{z}_e)$$ + +- 其中 $\text{sg}(\cdot)$ 是停止梯度算子。在前向传播中,计算结果为 $\mathbf{z}_q$;在反向传播中,梯度仅流经 $\mathbf{z}_e$ 项。 + +- 完整的 VQ-VAE 损失包含三项: + +$$\mathcal{L} = \underbrace{\|\mathbf{x} - D(\mathbf{z}_q)\|_2^2}_{\text{重建损失}} + \underbrace{\|\text{sg}(\mathbf{z}_e) - \mathbf{e}\|_2^2}_{\text{码本(VQ)损失}} + \underbrace{\beta \|\mathbf{z}_e - \text{sg}(\mathbf{e})\|_2^2}_{\text{承诺损失}}$$ + +- **重建损失**训练编码器和解码器忠实地再现输入。**码本损失**(也称为 VQ 损失)将码本向量拉向编码器输出;注意 $\text{sg}(\mathbf{z}_e)$ 意味着编码器不会从这一项接收梯度,因此它只更新码本。**承诺损失**则相反:它鼓励编码器输出保持接近码本向量,防止编码器"远离"码本。超参数 $\beta$(通常为 0.25)控制码本损失和承诺损失之间的平衡。 + +- 在实践中,码本通常使用**指数移动平均**(EMA)而不是梯度下降来更新,这样更稳定。令 $\mathbf{n}_k$ 为分配给码本条目 $k$ 的编码器输出计数,$\mathbf{s}_k$ 为它们的和。EMA 更新为: + +$$\mathbf{n}_k \leftarrow \gamma \mathbf{n}_k + (1 - \gamma) |\{(i,j) : k^\ast_{ij} = k\}|$$ + +$$\mathbf{s}_k \leftarrow \gamma \mathbf{s}_k + (1 - \gamma) \sum_{(i,j) : k^\ast_{ij} = k} \mathbf{z}_e(i,j)$$ + +$$\mathbf{e}_k \leftarrow \frac{\mathbf{s}_k}{\mathbf{n}_k}$$ + +- 其中 $\gamma$ 是衰减率(通常为 0.99)。这等价于对编码器输出运行在线 k-means 算法。 + +### 码本坍塌 + +- VQ-VAE 一个臭名昭著的失败模式是**码本坍塌**(也称为索引坍塌):模型只学会使用 $K$ 个码本条目中的一小部分,导致大多数条目"死亡"。想象一个图书馆,90% 的书架是空的,因为图书管理员总是把书送到同样的几个热门书架上。这浪费了表示能力。 + +- 码本坍塌的发生是因为编码器、码本和解码器在训练过程中共同适应。如果一个条目在几个批次中都没有被选中,它就会漂离编码器流形,使其更不可能被选中,从而形成正反馈循环。 + +- 缓解码本坍塌的几种技术: + - **码本重置**:定期通过随机采样编码器输出重新初始化死亡条目。这为死亡条目在潜在空间活跃区域附近提供了一个新的起点。 + - **带拉普拉斯平滑的 EMA 更新**:向 $\mathbf{n}_k$ 添加一个小常数,防止任何条目计数为零,确保所有条目都能接收到梯度信号。 + - **承诺损失调优**:增大 $\beta$ 迫使编码器输出更紧密地聚集在码本条目周围,使分配更均匀。 + - **分解编码**:将码本查找分解为多个较小查找的乘积(例如,两个大小各为 $\sqrt{K}$ 的码本),通过减少每次查找的有效码本大小来提高利用率。 + - **熵正则化**:添加一个惩罚项,鼓励码本使用上的均匀分布,最大化熵 $H = -\sum_k p_k \log p_k$,其中 $p_k$ 是经验分配概率。 + +![码本利用率:健康码本具有均匀分布的分配,而坍塌码本中大多数条目未被使用](../images/codebook_collapse.svg) + +## VQ-GAN:对抗训练实现更高保真度 + +- VQ-VAE 能产生不错的重建效果,但像素级的 $\ell_2$ 损失往往会产生模糊的输出,因为它对每个像素偏差都同等惩罚,在合理的细节上取平均而不是选择清晰的细节。想象一下,要求某人画一张脸,使得与所有可能的脸的平均差异最小——他们会画出一张模糊的平均脸,而不是一张清晰的特定人脸。 + +- **VQ-GAN**(Esser 等人,2021)通过将 VQ-VAE 框架与生成对抗网络(第 6 章)中的**判别器**相结合来解决这个问题。判别器是一个基于块的卷积网络,用于判断局部图像块是真(来自训练数据)还是假(来自解码器)。这种对抗损失鼓励解码器产生感知上清晰、逼真的纹理,而不是像素级的平均值。 + +- VQ-GAN 目标函数在 VQ-VAE 损失的基础上增加了两项: + +$$\mathcal{L}_\text{VQ-GAN} = \mathcal{L}_\text{VQ-VAE} + \lambda_\text{adv} \mathcal{L}_\text{adv} + \lambda_\text{perc} \mathcal{L}_\text{perc}$$ + +- **对抗损失** $\mathcal{L}_\text{adv}$ 是应用于解码器输出的标准 GAN 目标。判别器 $\mathcal{D}$ 试图区分真实块和解码块,而解码器(生成器)试图欺骗它。非饱和形式为: + +$$\mathcal{L}_\text{adv} = -\mathbb{E}[\log \mathcal{D}(D(\mathbf{z}_q))]$$ + +- **感知损失** $\mathcal{L}_\text{perc}$ 比较原始图像和重建图像在预训练网络(通常是 VGG 或 LPIPS)中的特征激活: + +$$\mathcal{L}_\text{perc} = \sum_l \|\phi_l(\mathbf{x}) - \phi_l(D(\mathbf{z}_q))\|_2^2$$ + +- 其中 $\phi_l$ 表示预训练网络在第 $l$ 层的特征图。这个损失捕捉的是高层结构相似性,而非像素级精度。 + +- 权重 $\lambda_\text{adv}$ 被自适应地设置,使得对抗梯度和重建梯度保持平衡,防止在训练早期重建效果还很差时对抗损失占主导。 + +![VQ-GAN 训练:编码器和解码器通过量化步骤连接,块判别器对解码输出提供对抗反馈](../images/vqgan_training.svg) + +- 结果是,在相同码本大小下,VQ-GAN 产生的词元化器重建效果远比 VQ-VAE 清晰。VQ-GAN 是许多主要图像生成系统(包括最初的 DALL-E、Parti 以及众多文生图模型)背后的骨干词元化器。它将 256×256 的图像转换为 16×16 或 32×32 的离散词元网格,来源于大小为 1024–16384 的码本,在每个空间维度上实现 16 倍到 64 倍的压缩比。 + +## 残差量化与多尺度码本 + +- 单个码本对重建质量施加了一个硬上限:每个空间位置恰好由一个码本向量表示,任何比码本所能表达的更精细的细节都会丢失。想象用固定调色板中的一个词来描述一种颜色:"青色"很接近但不精确。如果你能添加一个细化描述——"青色,但稍微偏蓝一点,亮一点"——你就能得到更接近的结果。 + +- **残差量化**(RQ)迭代地应用了这一思想。在第一次量化步骤产生 $\mathbf{z}_q^{(1)}$ 之后,计算残差 $\mathbf{r}^{(1)} = \mathbf{z}_e - \mathbf{z}_q^{(1)}$,然后对残差使用第二个码本进行量化得到 $\mathbf{z}_q^{(2)}$,以此类推,共 $T$ 个层级: + +$$\mathbf{r}^{(0)} = \mathbf{z}_e$$ + +$$\mathbf{z}_q^{(t)} = \text{Quantise}(\mathbf{r}^{(t-1)}, \mathcal{C}^{(t)})$$ + +$$\mathbf{r}^{(t)} = \mathbf{r}^{(t-1)} - \mathbf{z}_q^{(t)}$$ + +- 最终的量化表示为 $\hat{\mathbf{z}} = \sum_{t=1}^{T} \mathbf{z}_q^{(t)}$。使用 $T$ 个层级,每个层级码本大小为 $K$,有效词汇表大小为 $K^T$,但你只需要存储 $T \times K$ 个向量,而不是 $K^T$ 个。例如,8 个层级,$K = 1024$,有效条目数为 $1024^8 \approx 10^{24}$,而只存储了 8192 个向量。 + +- 每个后续层级捕捉更精细的细节:第一个码本捕捉粗略结构,第二个捕捉中频修正,依此类推。这类似于 JPEG 中的逐次逼近或网页图像中的渐进式渲染,先出现粗略版本,然后逐步填充细节。 + +![残差量化:原始向量在多个阶段中被逐步逼近,每个阶段量化前一阶段的残差](../images/residual_quantisation.svg) + +- **多尺度码本**通过在不同空间分辨率上操作来扩展这一思想。不是重复量化同一个空间网格,而是在多个尺度上进行量化:粗粒度网格捕捉全局结构,细粒度网格捕捉局部细节。这与第 8 章目标检测部分中的特征金字塔思想相关,其中不同尺度的特征捕捉不同层次的细节。 + +- **乘积量化**是一种相关技术,将 $d$ 维潜在向量拆分为 $M$ 个维度为 $d/M$ 的子向量,每个子向量使用自己的码本独立量化。这使得有效词汇表达到 $K^M$,同时只存储 $M \times K$ 个向量。乘积量化广泛应用于近似最近邻搜索(第 13 章),并已被适配用于图像词元化。 + +- **有限标量量化**(FSQ),由 Mentzer 等人(2023)提出,采取了一种完全不同的方法:不是学习一个码本,而是简单地将潜在向量的每个维度四舍五入到一组固定整数级别中的一个(例如 $\{-2, -1, 0, 1, 2\}$)。每维 $L$ 个级别,$d$ 个维度,隐含码本大小为 $L^d$。FSQ 完全避免了码本坍塌,因为没有可学习的码本向量,只有被确定性四舍五入的可学习编码器输出。直通估计器处理了四舍五入的不可微性。 + +## 实践中的图像词元化器 + +- 从 VQ-VAE 到 VQ-GAN 再到残差量化的演进,催生了一系列实际图像词元化器,用于最先进的生成模型。 + +### DALL-E 词元化器(dVAE) + +- 最初的 **DALL-E**(Ramesh 等人,2021)使用离散 VAE(dVAE)将 256×256 图像词元化为 32×32 的词元网格,码本大小为 8192。dVAE 将硬 $\arg\min$ 量化替换为 Gumbel-Softmax 松弛,使前向传播在训练过程中可微。在推理时,使用 $\arg\max$ 生成硬词元分配。dVAE 使用重建损失、与均匀先验的 KL 散度以及 Gumbel-Softmax 的学习温度调度组合进行训练。然后 DALL-E 训练了一个 120 亿参数的自回归 Transformer 来建模 256 个文本词元和 1024 个图像词元(32×32)的联合分布。 + +### LlamaGen + +- **LlamaGen**(Sun 等人,2024)表明,只要你有一个好的图像词元化器,就可以将标准的 Llama 风格语言模型架构(第 7 章)重新用于自回归图像生成。LlamaGen 使用改进的 VQ-GAN 词元化器,具有大型码本(16,384 个条目),并训练了一个普通的自回归 Transformer(除了词元化器外没有特殊的图像特定修改)以光栅扫描顺序从左到右预测图像词元。关键的见解是,一旦图像被词元化为离散序列,适用于语言的相同下一个词元预测范式也同样适用于图像,这验证了词元化确实弥合了模态鸿沟的观点。 + +### Cosmos 词元化器 + +- **Cosmos 词元化器**(NVIDIA,2024)设计用于在统一框架中处理图像和视频。它使用因果 3D 架构,将图像视为单帧视频,使得同一个词元化器可以处理两种模态。Cosmos 支持连续和离散两种词元化模式:连续模式输出实值潜在向量(用于扩散模型后端),而离散模式应用有限标量量化产生整数词元(用于自回归模型后端)。编码器使用因果 3D 卷积,使得每帧的词元仅依赖于当前帧和之前的帧,从而支持流式视频词元化。 + +![图像词元化器架构对比:带有 Gumbel-Softmax 的 dVAE、带有码本查找的 VQ-GAN、以及带有标量四舍五入的 FSQ](../images/image_tokeniser_comparison.svg) + +## 视频词元化 + +- 视频在图像的二维空间维度上增加了第三个轴——时间。视频是一系列帧,通常为每秒 24–30 帧,相邻帧之间高度冗余,因为在 33 毫秒内视觉世界不会发生剧烈变化。视频词元化利用这种时间冗余来实现比独立词元化每帧高得多的压缩率。 + +- 把视频压缩想象成一幅翻页书。如果每一页都从头画起,你需要数千张精细的绘图。但大多数页面与相邻页面几乎相同,所以你可以每 10 页画一个完整的"关键帧",只记录中间页面上的微小变化。视频词元化器自动学会了这个技巧。 + +### 3D VQ-VAE + +- 将 VQ-VAE 扩展到视频的最直接方式是 **3D VQ-VAE**,它将编码器和解码器中的 2D 卷积替换为同时在空间和时间维度上操作的 3D 卷积。如果编码器在空间上降采样 $f_s$ 倍,在时间上降采样 $f_t$ 倍,则 $T \times H \times W$ 的视频片段变为 $(T/f_t) \times (H/f_s) \times (W/f_s)$ 的词元网格。 + +- 例如,$f_s = 16$ 且 $f_t = 4$ 时,一个 16 帧的 256×256 视频片段变为 $4 \times 16 \times 16 = 1024$ 的词元序列。这对 Transformer 进行自回归建模来说已经足够紧凑,而原始像素数将是 $16 \times 256 \times 256 \times 3 \approx 310$ 万个数值。 + +- 3D 卷积联合学习空间和时间特征。早期层捕捉局部运动(帧间移动的边缘),而更深层捕捉高层动态(物体的出现、消失或形状变化)。这与第 8 章卷积网络中的层次化特征提取原理相同,只是沿时间轴进行了扩展。 + +![用于视频的 3D VQ-VAE:短视频片段通过 3D 卷积编码为潜在向量的时空网格,量化后解码回帧](../images/video_3d_vqvae.svg) + +### 因果视频词元化器 + +- 标准 3D 卷积会同时查看过去、当前和未来的帧,这意味着在词元化任何帧之前需要整个视频片段。**因果视频词元化器**约束时间卷积,使每个输出仅依赖于当前帧和之前的帧,从不依赖于未来的帧。这类似于自回归 Transformer(第 7 章)中的因果掩码:信息在时间上向前流动,但绝不向后。 + +- 因果词元化对于两种使用场景至关重要。首先,**流式处理**:你可以在帧到达时实时词元化视频,而无需缓冲未来的帧。其次,**自回归生成**:当 Transformer 逐帧生成视频时,第 $t$ 帧的词元必须在不知道第 $t+1$ 帧的情况下可计算,因为第 $t+1$ 帧尚未生成。 + +- 因果约束通过非对称填充时间卷积来实现:时间大小为 $k$ 的核在过去一侧填充 $k-1$ 个零,未来一侧填充零个零,确保时间 $t$ 的输出仅依赖于时间 $t-k+1, \ldots, t$ 的输入。 + +- 因果视频词元化器的一个优雅特性是它们可以词元化单张图像("视频"只有一帧)而无需特殊处理。第一帧没有历史上下文,因此其词元仅从该帧本身计算。这种**图像-视频统一**意味着单个词元化器可以服务于两种模态,简化了架构,并使模型能够使用同一个解码器生成图像和视频。 + +### 时间压缩策略 + +- 不同的应用需要不同的时间压缩比。对于动作识别(其中细微运动很重要),温和压缩($f_t = 2$)可以保留时间细节。对于长视频生成(存储数千帧是不可行的),需要激进压缩($f_t = 8$ 或更高)。 + +- 某些词元化器使用**分解压缩**:空间和时间压缩在不同的阶段进行。首先,2D 编码器独立压缩每帧,产生每帧的潜在网格。然后,1D 时间编码器跨时间维度进行压缩。这种分解在计算上比完整的 3D 卷积更便宜,并允许空间和时间采用不同的压缩比。其代价是它不能像联合 3D 编码那样高效地捕捉时空模式(如对角线运动的球)。 + +- **时间插值词元**是一项最近的创新,词元化器仅完整编码关键帧,并将中间帧表示为轻量级的插值编码,描述如何在关键帧之间变形。这类似于经典视频压缩(H.264/HEVC 中的 I 帧和 P 帧),但在学习到的潜在空间中进行。 + +![时间压缩策略:帧独立的空间编码后接时间编码,与联合时空 3D 编码的对比](../images/temporal_compression_strategies.svg) + +## 连续词元与离散词元 + +- 并非每个下游模型都需要离散词元。**扩散模型**(第 10 章,文件 04)原生使用连续值——它们迭代地去噪高斯样本,其损失函数(去噪得分匹配)定义在连续空间上。对于扩散后端,词元化器编码器产生连续潜在向量,从不进行量化。**潜在扩散模型**(Stable Diffusion、DALL-E 3、Flux)使用类似 VQ-GAN 的编码器-解码器,但完全跳过了码本,在连续潜在空间中操作。 + +- 而**自回归模型**(GPT 风格)则使用 $K$ 类上的 softmax 从有限词汇表中预测下一个词元。它们从根本上需要离散词元。每个使用自回归 Transformer 的图像生成系统(DALL-E、Parti、LlamaGen、Chameleon)都依赖离散词元化器。 + +- 因此,连续词元和离散词元之间的选择由生成后端决定: + +- 在以下情况下使用**离散词元**:模型是自回归的(使用交叉熵损失的下一个词元预测),你想与文本词元共享词汇表以实现统一的多模态模型,或者你需要精确的词元级控制(例如,通过词元替换进行检索或编辑)。 + +- 在以下情况下使用**连续词元**:模型是扩散模型或流匹配模型,任务需要非常高的保真度重建(连续潜在表示完全避免了量化误差),或者你想使用作用于实值向量的回归损失。 + +- 一些最近的架构支持两种模式。例如,Cosmos 词元化器可以从同一个编码器输出连续潜在表示(用于其扩散模式)或 FSQ 离散化词元(用于其自回归模式),只需一个可以打开或关闭的轻量级量化头。 + +- **软量化**是一个中间地带:不是硬 $\arg\min$ 分配,而是计算 top-$k$ 最近码本条目的加权平均,权重由负距离上的 softmax 给出。这比硬量化保留了更多信息,同时仍然近似离散。有些系统在训练时使用软量化,在推理时使用硬量化。 + +![根据下游生成模型选择连续词元化与离散词元化的决策树](../images/continuous_vs_discrete_tokens.svg) + +## 应用 + +### 自回归图像生成 + +- 一旦图像变成离散词元序列,你就可以训练标准的自回归 Transformer 来建模它们。图像词元被展平为一维序列(通常按光栅扫描顺序:从左到右、从上到下),Transformer 学习 $p(\text{词元}_i \mid \text{词元}_1, \ldots, \text{词元}_{i-1})$,使用标准交叉熵损失。在生成时,词元被逐个采样,完整的网格通过词元化器的解码器转换为像素。 + +- 文本条件化很简单:在图像词元序列前添加文本词元,使模型学习 $p(\text{图像词元} \mid \text{文本词元})$。这正是 DALL-E、Parti 和 LlamaGen 执行文生图的方式。文本词元和图像词元共享同一个 Transformer、同一个注意力机制,并且通常共享同一个嵌入表(文本词元和图像词元占据不同的索引范围)。 + +- 光栅扫描顺序引入了一种人为的非对称性:图像的左上角是在没有任何关于右下角上下文的情况下首先生成的。一些工作解决了这个问题。**掩码图像建模**(MaskGIT)训练了一个双向 Transformer,同时生成所有词元但置信度不同,迭代地解开最自信的词元。**多尺度生成**首先生成粗粒度词元(捕捉全局构图),然后用残差词元进行细化。这些方法用纯从左到右生成的简单性换取了更好的全局连贯性。 + +### 统一的视觉-语言词元 + +- 图像词元化最深刻的动机是**统一**:将视觉和语言置于相同的表示格式中,使得单个模型架构可以同时处理两者。正如我们在第 7 章中讨论的,语言模型是极其强大的序列到序列机器。通过将图像表示为词元序列,我们免费继承了语言建模的所有基础设施——预训练配方、缩放定律、RLHF、上下文长度扩展。 + +- **Chameleon**(Meta,2024)是一个突出的例子:它使用具有 8192 个码本条目的 VQ-GAN 词元化器将图像转换为词元,这些词元与文本词元交织在一个约 65,000 个条目(文本 + 图像)的单一词汇表中。标准的 Transformer 在混合文本-图像序列上进行训练,使其能够根据图像生成文本、根据文本生成图像或生成交织的文本和图像内容,全部使用同一次前向传播。 + +- **Gemini**(Google,2024)在大规模上采取了类似的方法,原生地在单个 Transformer 中理解并生成图像、音频和文本,由特定模态的词元化器馈送到共享序列中。 + +- 统一模型中的关键工程挑战是**词汇表平衡**:如果 65,000 个词汇表条目中有 8192 个是图像词元,模型可能会分配不足的能力给视觉。解决方案包括为每种模态使用独立的嵌入层(仅在注意力层面共享)、特定模态的损失加权,以及预训练期间仔细的数据混合比例。 + +![统一视觉-语言模型:来自不同词元化器的文本和图像词元交织成单个序列,由一个 Transformer 处理](../images/unified_vision_language_tokens.svg) + +## 编程练习(在 Colab 或笔记本中运行) + +1. 在 JAX 中实现一个最小 VQ 层:给定一批编码器输出向量,执行最近邻码本查找并计算 VQ-VAE 损失(重建 + 码本 + 承诺)。将码本利用率可视化为直方图。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# --- 最小 VQ 层 --- +key = jax.random.PRNGKey(42) +d = 8 # 嵌入维度 +K = 64 # 码本大小 +n_vectors = 256 # 一批编码器输出 + +# 随机编码器输出和码本 +k1, k2 = jax.random.split(key) +z_e = jax.random.normal(k1, (n_vectors, d)) # 编码器输出 +codebook = jax.random.normal(k2, (K, d)) * 0.1 # 码本(小初始化) + +# 最近邻查找:为每个 z_e 找到最近的码本条目 +# distances[i, k] = ||z_e[i] - codebook[k]||^2 +distances = ( + jnp.sum(z_e ** 2, axis=1, keepdims=True) + - 2 * z_e @ codebook.T + + jnp.sum(codebook ** 2, axis=1, keepdims=True).T +) +indices = jnp.argmin(distances, axis=1) # 词元索引 +z_q = codebook[indices] # 量化向量 + +# VQ-VAE 损失项 +beta = 0.25 +loss_codebook = jnp.mean((jax.lax.stop_gradient(z_e) - z_q) ** 2) +loss_commit = jnp.mean((z_e - jax.lax.stop_gradient(z_q)) ** 2) +loss_total = loss_codebook + beta * loss_commit +print(f"码本损失: {loss_codebook:.4f}, 承诺损失: {loss_commit:.4f}") + +# 码本利用率 +unique, counts = jnp.unique(indices, return_counts=True, size=K, fill_value=-1) +plt.figure(figsize=(10, 4)) +plt.bar(range(K), counts, color='#3498db', alpha=0.8) +plt.xlabel('码本索引'); plt.ylabel('分配计数') +plt.title(f'码本利用率(已使用 {jnp.sum(counts > 0)}/{K} 个条目)') +plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show() +# 尝试:将 K 增加到 512 并观察坍塌。然后添加码本重置逻辑。 +``` + +2. 构建一个玩具 2D 向量量化器,学习对 2D 分布进行划分。生成随机 2D 点,通过 EMA 更新学习码本,并将 Voronoi 区域可视化。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 从高斯混合生成 2D 数据 +key = jax.random.PRNGKey(0) +n_points = 2000 +K = 16 # 码本条目数 +gamma = 0.99 # EMA 衰减 + +# 四个簇 +keys = jax.random.split(key, 5) +centres = jnp.array([[2, 2], [-2, 2], [-2, -2], [2, -2]], dtype=jnp.float32) +data = jnp.concatenate([ + jax.random.normal(keys[i], (n_points // 4, 2)) * 0.5 + centres[i] + for i in range(4) +]) + +# 从随机数据点初始化码本 +idx = jax.random.choice(keys[4], n_points, (K,), replace=False) +codebook = data[idx] +ema_count = jnp.ones(K) +ema_sum = codebook.copy() + +# 运行多个 epoch 的基于 EMA 的码本学习 +for epoch in range(30): + # 将每个点分配给最近的码本条目 + dists = jnp.sum((data[:, None, :] - codebook[None, :, :]) ** 2, axis=2) + assignments = jnp.argmin(dists, axis=1) + # EMA 更新 + for k in range(K): + mask = (assignments == k) + count_k = jnp.sum(mask) + ema_count = ema_count.at[k].set(gamma * ema_count[k] + (1 - gamma) * count_k) + if count_k > 0: + sum_k = jnp.sum(data[mask], axis=0) + ema_sum = ema_sum.at[k].set(gamma * ema_sum[k] + (1 - gamma) * sum_k) + codebook = ema_sum / ema_count[:, None] + +# 可视化分配和码本 +fig, ax = plt.subplots(1, 1, figsize=(8, 8)) +colors = plt.cm.tab20(jnp.linspace(0, 1, K)) +for k in range(K): + mask = assignments == k + ax.scatter(data[mask, 0], data[mask, 1], c=[colors[k]], s=5, alpha=0.3) +ax.scatter(codebook[:, 0], codebook[:, 1], c='black', s=120, marker='X', + edgecolors='white', linewidths=1.5, zorder=10, label='码本') +ax.set_title(f'在 2D 数据上学得的 VQ 码本({K} 个条目)') +ax.legend(); ax.set_aspect('equal'); ax.grid(True, alpha=0.3) +plt.tight_layout(); plt.show() +# 尝试:将 K 增加到 64 并观察更精细的划分。减小 gamma 并观察不稳定性。 +``` + +3. 演示残差量化:用 $T$ 个连续的量化阶段对一批向量进行编码,并测量每个层级重建误差的下降。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +key = jax.random.PRNGKey(7) +d = 16 # 嵌入维度 +K = 32 # 每个层级的码本大小 +T = 8 # 残差层级数 +n_vectors = 512 + +# 待量化的随机数据 +k1, *cb_keys = jax.random.split(key, T + 1) +z = jax.random.normal(k1, (n_vectors, d)) + +# 每个层级的独立随机码本 +codebooks = [jax.random.normal(cb_keys[t], (K, d)) * (0.5 ** t) + for t in range(T)] + +# 残差量化循环 +residual = z.copy() +z_hat = jnp.zeros_like(z) +errors = [] + +for t in range(T): + cb = codebooks[t] + dists = (jnp.sum(residual ** 2, axis=1, keepdims=True) + - 2 * residual @ cb.T + + jnp.sum(cb ** 2, axis=1, keepdims=True).T) + indices = jnp.argmin(dists, axis=1) + z_q_t = cb[indices] + z_hat = z_hat + z_q_t + residual = residual - z_q_t + mse = jnp.mean(jnp.sum((z - z_hat) ** 2, axis=1)) + errors.append(float(mse)) + print(f"层级 {t+1}: MSE = {mse:.4f}") + +plt.figure(figsize=(8, 5)) +plt.plot(range(1, T + 1), errors, 'o-', color='#e74c3c', linewidth=2, markersize=8) +plt.xlabel('残差量化层级') +plt.ylabel('重建 MSE') +plt.title('残差量化的误差降低') +plt.xticks(range(1, T + 1)); plt.grid(True, alpha=0.3) +plt.tight_layout(); plt.show() +# 尝试:使用大小为 K*T 的单个码本并与 RQ 比较。哪个更好? +``` + +4. 模拟一个简单的 1D"视频词元化器":生成一系列 1D 信号(模拟视频帧),应用因果时间压缩,并与无因果压缩在重建质量方面进行比较。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +key = jax.random.PRNGKey(99) +n_frames = 16 +frame_len = 64 + +# 生成一个"视频":在帧间缓慢移动的高斯凸起 +x_axis = jnp.linspace(-3, 3, frame_len) +frames = jnp.stack([ + jnp.exp(-0.5 * (x_axis - (-2 + 4 * t / n_frames)) ** 2) + for t in range(n_frames) +]) # 形状: (n_frames, frame_len) + +# 因果时间压缩:每帧的编码仅依赖于过去的帧 +# 简单方法:使用过去帧的指数衰减对当前帧进行平均 +alpha_causal = 0.6 +causal_codes = jnp.zeros_like(frames) +causal_codes = causal_codes.at[0].set(frames[0]) +for t in range(1, n_frames): + causal_codes = causal_codes.at[t].set( + alpha_causal * frames[t] + (1 - alpha_causal) * causal_codes[t - 1] + ) + +# 无因果:同时平均过去和未来(双边平滑) +kernel = jnp.array([0.2, 0.6, 0.2]) # 过去, 当前, 未来 +padded = jnp.concatenate([frames[:1], frames, frames[-1:]], axis=0) +noncausal_codes = jnp.stack([ + kernel[0] * padded[t] + kernel[1] * padded[t+1] + kernel[2] * padded[t+2] + for t in range(n_frames) +]) + +# 重建误差 +mse_causal = jnp.mean((frames - causal_codes) ** 2) +mse_noncausal = jnp.mean((frames - noncausal_codes) ** 2) +print(f"因果 MSE: {mse_causal:.6f}, 无因果 MSE: {mse_noncausal:.6f}") + +fig, axes = plt.subplots(1, 3, figsize=(15, 5)) +for ax, data, title in zip(axes, + [frames, causal_codes, noncausal_codes], + ['原始帧', f'因果 (MSE={mse_causal:.5f})', + f'无因果 (MSE={mse_noncausal:.5f})']): + ax.imshow(data, aspect='auto', cmap='viridis', origin='lower') + ax.set_xlabel('空间位置'); ax.set_ylabel('帧索引') + ax.set_title(title) +plt.tight_layout(); plt.show() +# 尝试:改变 alpha_causal 和核权重。alpha=1.0 时会发生什么? +``` diff --git a/chapter 10: multimodal learning/04. cross-modal generation.md b/chapter 10: multimodal learning/04. cross-modal generation.md new file mode 100644 index 0000000..c798196 --- /dev/null +++ b/chapter 10: multimodal learning/04. cross-modal generation.md @@ -0,0 +1,405 @@ +# 跨模态生成 (Cross-Modal Generation) + +*跨模态生成(cross-modal generation)是指以某一模态的输入为条件,生成另一模态的输出——从文生图、图生文、文生音频,乃至更多。本章涵盖 DALL·E、Stable Diffusion、无分类器引导、ControlNet、图像描述、文生视频(Sora)以及文生音频生成。* + +- 在本章的文件 01-03 中,你已经学习了如何表示、对齐和分词不同模态。现在轮到创造性的环节了:从一个模态生成另一个模态。跨模态生成是文生图工具、视频合成系统、音乐创作模型和图像描述背后的引擎。可以将其理解为教会机器成为多媒体艺术家——你用文字描述你想要的内容,机器则负责绘画、动画或作曲。 + +- 核心思想是**条件生成(conditional generation)**:给定来自模态 $A$(例如文本)的输入,生成模态 $B$(例如图像)的输出。形式上,我们学习模型 $p_\theta(y \mid x)$,其中 $x$ 是条件信号,$y$ 是生成的输出。挑战在于这个条件分布极其复杂且维度极高——一张 512x512 的图像存在于 $\mathbb{R}^{786432}$ 中,而对于同一个文本提示,可能有无数张合理的图像。 + +![跨模态生成概览:文本、图像、音频和视频模态之间以方向箭头连接,展示文生图、图生文、文生音频和文生视频等生成路径](../images/cross_modal_generation_overview.svg) + +## 文生图生成 (Text-to-Image Generation) + +- 想象你向法庭素描师描述一个场景。素描师必须理解你的话,回忆物体长什么样,在空间上排布它们,最后画出最终的图画。文生图模型正是做这件事,但它们必须从数据中学习所有这些技能,而不是经过多年的艺术院校训练。 + +### DALL·E:自回归图像生成 + +- **DALL·E**(Ramesh 等人,2021)将图像生成视为一个序列预测问题——这正是语言模型所采用的范式(见第 07 章)。其关键洞察是:如果你能将图像表示为离散 token(回顾文件 03 中的 VQ-VAE),那么生成图像就只是逐个生成 token 序列的过程。 + +- 其流程分为两个阶段。首先,一个**离散 VAE(dVAE)**将 256x256 的图像压缩成 32x32 的离散 token 网格,码本大小为 8192,将图像简化为 1024 个 token 的序列。其次,一个**Transformer 解码器**被训练来建模 256 个文本 token(BPE 编码)与 1024 个图像 token 拼接后的联合分布,总计 1280 个 token: + +$$p(x_{\text{text}}, x_{\text{img}}) = \prod_{i=1}^{1280} p(x_i \mid x_1, \ldots, x_{i-1})$$ + +- 在生成时,输入文本 token,模型自回归地逐个采样图像 token。这种方法优雅之处在于它复用了语言建模的完整机制——注意力、因果掩码、top-k 采样——来完成图像合成。 + +- 缺点是自回归生成本质上是串行的:逐个生成 1024 个 token 速度很慢,而且序列早期的任何错误都会被放大。DALL·E 通过生成大量候选图像并用 CLIP(来自文件 01)进行重排序来缓解这一问题,以找到与文本提示最匹配的结果。 + +![DALL·E 流程:文本 token 和图像 token 拼接成单一序列,由 Transformer 解码器处理,该解码器以文本 token 为条件自回归地预测图像 token](../images/dalle_autoregressive_pipeline.svg) + +### Stable Diffusion:带文本条件的隐空间扩散 + +- **Stable Diffusion**(Rombach 等人,2022)采用了一种根本不同的方法。它不是逐个预测 token,而是从纯噪声开始,在文本提示的引导下逐步将噪声去噪成图像。回顾第 8 章中的扩散模型——Stable Diffusion 在压缩后的隐空间(latent space)而非像素空间中运行,因此效率大幅提升。 + +- 其架构由三个组件协同工作。**VAE 编码器**将图像从像素空间($512 \times 512 \times 3$)压缩为隐空间表示($64 \times 64 \times 4$),将维度降低了 48 倍。**文本编码器**(通常为 CLIP 或 OpenCLIP)将文本提示转换为嵌入向量序列。**U-Net 去噪器**接收含噪隐变量、时间步和文本嵌入,并预测每一步需要减去的噪声。文本条件通过**交叉注意力(cross-attention)**层进入 U-Net: + +$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$$ + +- 其中 $Q$ 来自含噪图像特征,$K$ 和 $V$ 来自文本嵌入。这使得模型能够在每个空间位置上关注相关的词语——当去噪"红球"应该出现的区域时,模型会关注"红"和"球"这两个 token。 + +- 在推理时,你在隐空间中采样 $z_T \sim \mathcal{N}(0, I)$,利用 U-Net 迭代去噪 $T$ 步(通常使用 DDIM 调度为 20-50 步),然后用 VAE 解码器将干净的隐变量 $z_0$ 解码回像素空间。整个前向过程在消费级 GPU 上仅需数秒即可生成一张 512x512 的图像。 + +![Stable Diffusion 架构:文本提示由 CLIP 编码,隐空间中的随机噪声由带有交叉注意力机制的 U-Net 迭代去噪,文本嵌入作为条件,最后由 VAE 解码生成最终图像](../images/stable_diffusion_architecture.svg) + +### 无分类器引导的实践应用 + +- **无分类器引导(Classifier-Free Guidance,CFG)**是让文生图模型能够生成与提示真正匹配的图像的关键要素。回顾第 8 章,CFG 同时训练条件模型和无条件模型,然后在采样时放大条件信号: + +$$\hat{\epsilon} = \epsilon_\theta(x_t, \varnothing) + s \cdot (\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing))$$ + +- 其中 $s$ 是引导尺度。可以将 $(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing))$ 理解为"朝向提示的方向"——它捕捉了有条件预测与无条件预测之间的差异。乘以 $s > 1$ 会放大这个方向,将图像推近文本描述,但代价是多样性降低。 + +- 在实践中,Stable Diffusion 的常用默认值为 $s = 7.5$。当 $s = 1.0$ 时得到模型的原始输出(多样但仅松散匹配提示)。当 $s \geq 20$ 时图像变得过饱和且重复,但与文本高度一致。最优 $s$ 值取决于应用场景:创意探索倾向于较低的引导值,而精确遵循提示则需要更高的引导值。 + +### Imagen:基于语言理解的级联扩散 + +- **Imagen**(Saharia 等人,2022)证明了强大的文本编码器比更大的图像模型更重要。Imagen 没有使用 CLIP,而是采用一个冻结的 **T5-XXL** 语言模型(来自第 07 章)作为文本编码器,该模型对语言语义、组合性和空间关系(如"红色球体上的蓝色方块")有着更丰富的理解。 + +- Imagen 使用了**级联扩散(cascaded diffusion)**方法:基础扩散模型生成 64x64 的图像,第一个超分辨率模型放大到 256x256,第二个超分辨率模型达到 1024x1024。每个阶段都是独立的扩散模型,以文本和(对于上采样器)低分辨率图像为条件。这种级联方式避免了在基础分辨率上建模精细细节,使基础模型能够专注于构图和语义,而上采样器则负责处理纹理和清晰度。 + +- Imagen 还引入了**动态阈值(dynamic thresholding)**:在每个去噪步骤中,预测的像素值被裁剪到基于百分位数的范围,而不是固定的 $[-1, 1]$ 范围。这可以防止在高引导尺度下出现饱和伪影,这是扩散模型中的常见问题。 + +### Parti:大规模自回归 + +- **Parti**(Pathways Autoregressive Text-to-Image,Yu 等人,2022)以超大尺度复兴了自回归方法。与 DALL·E 类似,它将图像转换为离散 token(使用 ViT-VQGAN),并用 Transformer 顺序生成。但 Parti 使用了 200 亿参数的编码器-解码器 Transformer(基于 Pathways 架构),并证明了自回归模型在充分扩展后可以达到扩散模型的质量。 + +- Parti 的编码器-解码器架构是与 DALL·E 纯解码器设计的关键区别。文本通过编码器处理;解码器在生成图像 token 时,通过交叉注意力关注编码后的文本。这类似于机器翻译(第 07 章)——你从"文本语言"翻译到"图像语言"。 + +### DiT 与基于流匹配的生成 + +- **扩散 Transformer(DiT)**(Peebles 和 Xie,2023)用纯 Transformer 替换了扩散模型中的 U-Net 主干网络。每个含噪隐空间块被当作一个 token(类似于第 8 章中的 ViT),Transformer 通过自注意力和对文本条件的交叉注意力来处理这些 token。DiT 表明,在扩散任务中,Transformer 的可扩展性比 U-Net 更具可预测性——计算量每翻一倍,FID 分数就会可靠地减半。 + +- **流匹配(flow matching)**(回顾第 8 章)已成为扩散噪声预测范式之外的一种替代方案。模型不再预测需要减去的噪声 $\epsilon$,而是预测一个速度场 $v_\theta(x_t, t)$,该速度场沿直线路径将样本从噪声传输到数据。**Stable Diffusion 3** 和 **Flux** 采用流匹配和**多模态 DiT(MM-DiT)**架构,其中文本和图像 token 由 Transformer 块通过双向注意力联合处理——两种模态互相关注,而不是文本仅通过交叉注意力作为图像特征的条件。 + +![DiT 架构:含噪隐空间块像 ViT 一样被 token 化,由带有自适应层归一化的 Transformer 块处理(用于时间步和类别条件),然后解码回空间隐空间](../images/dit_architecture.svg) + +## 文生视频生成 (Text-to-Video Generation) + +- 文生视频相当于文生图再加上一个严苛的额外约束:**时间连贯性(temporal coherence)**。每一帧必须在内部保持一致(是一张合理的图像),但连续帧之间也必须平滑连接——物体应该自然运动,光照应连续变化,"镜头"应遵循物理上合理的轨迹。可以想象一下绘制一幅风景画和导演一部电影之间的区别。 + +### 时间维度的挑战 + +- 视频引入了图像生成之外的三个挑战。**时间一致性(temporal consistency)**要求物体在各帧之间保持身份不变——第 1 帧中的狗在第 100 帧中应该还是同一条狗。**运动建模(motion modeling)**需要学习物理动态:物体如何运动、重力如何作用、流体如何流动。**计算成本**非常高昂:一段 24 fps、512x512 分辨率的 10 秒视频包含 $10 \times 24 \times 512 \times 512 \times 3 \approx 1.88$ 亿个值,大约是单张图像数据量的 240 倍。 + +### Make-A-Video 与延展至视频方法 + +- **Make-A-Video**(Singer 等人,2022)采用了一种务实的方法:从预训练的文生图模型开始,添加时间层。关键洞察是,你已经拥有了基于数十亿图文对训练的强大文生图模型,你只需要从(未标注的)视频数据中学习运动。 + +- Make-A-Video 在预训练的空间 U-Net 中插入了**时间注意力(temporal attention)**和**时间卷积(temporal convolution)**层。空间层(在图像上预训练)负责外观,而新的时间层(在视频上训练)负责运动。空间自注意力在每帧内部操作;时间注意力在每个空间位置上跨帧操作。这种分解是高效的,因为时间和空间模式在很大程度上是可分离的。 + +- 生成流程与 Imagen 的级联方式类似:基础模型生成 64x64 的 16 帧,然后空间和时间超分辨率模型将分辨率升级到最终大小和帧率。帧插值网络用于提高时间平滑性。 + +### VideoPoet 与基于 Token 的视频模型 + +- **VideoPoet**(Kondratyuk 等人,2024)将视频生成统一到语言建模范式之下。所有模态——文本、图像、视频、音频——都被 token 化为离散序列,一个单一的大语言模型(LLM)被训练来跨所有模态自回归地预测 token。这使得零样本能力成为可能:文生视频、图生视频、视频生音频、视频编辑和视频修补都可以从同一个模型中涌现。 + +- VideoPoet 使用 MAGVIT-v2 编码器(一个来自文件 03 的 3D VQ-VAE)对视频进行 token 化,该编码器联合压缩空间和时间维度。音频使用 SoundStream 进行 token 化。LLM 主干在文本上预训练,然后在多模态 token 序列上微调,学习跨模态的联合分布。 + +### Sora 风格的时间扩散 + +- **Sora**(OpenAI,2024)凭借其生成长时间、连贯、物理合理的视频的能力,将时间扩散带入了主流视野。虽然完整的架构细节尚未公开,但其关键思想是将 DiT 扩展到时空领域:视频帧被分解为**时空块(spacetime patches)**(跨越高度、宽度和时间的三维块),这些块被当作大型 Transformer 的 token 来处理。 + +- 时空块方法意味着模型将视频作为原生的 3D 信号来处理,而不是一系列 2D 帧。这使得模型能够捕获长程的时间依赖关系——模型可以"提前规划"整个视频时长,而不是逐帧生成。 + +- Sora 可以通过调整时空块的数量来处理可变的时长、分辨率和宽高比。以数据原生分辨率进行训练(而不是将所有图像裁剪为正方形)可以提高构图和取景质量。 + +### Wan:开源视频生成 + +- **Wan**(Wan 等人,2025)是一个开源视频生成模型系列(1.3B 和 14B 参数),基于 DiT 主干和 3D VAE 时间压缩。Wan 采用**流匹配**而不是传统的 DDPM 风格扩散,学习从噪声到视频隐空间的直线传输路径。3D VAE 在空间和时间上压缩视频(4 倍时间压缩),DiT 以全 3D 注意力处理生成的时空隐空间 token。 + +- Wan 支持文生视频、图生视频(将静态图像动画化)和视频编辑。14B 模型可以生成长达 5 秒、720p 分辨率的连贯视频,表明当架构和训练方案选择恰当时,开源模型可以接近专有系统的质量。 + +![文生视频流程:文本由语言模型编码,时空噪声由关注文本嵌入的时间扩散 Transformer 去噪,由 3D VAE 解码为视频帧](../images/text_to_video_pipeline.svg) + +## 文生音频生成 (Text-to-Audio Generation) + +- 想象一位电影配乐师阅读剧本并为电影配乐。文生音频模型做着类似的事情:给定一段文本描述("伴有大雨和远处雷声的雷暴"),它们生成相应的音频波形。挑战在于弥合文本的离散、符号化本质与声音的连续、时间性本质之间的差距。 + +### AudioLM:音频的语言建模 + +- **AudioLM**(Borsos 等人,2023)通过自回归预测离散音频 token 来生成音频,采用了与 DALL·E 为图像所用的相同语言建模范式。它使用分层 token 结构:**语义 token**(来自自监督模型如 w2v-BERT,回顾第 9 章)捕获高层次内容(说了什么或演奏了什么),而**声学 token**(来自 SoundStream,一种神经音频编解码器)捕获细粒度的声学细节(听起来如何——音色、录音质量)。 + +- 生成分两个阶段进行。首先,一个 Transformer 在给定可选音频提示的情况下预测语义 token,建立高层次的"内容规划"。其次,另一个 Transformer 以语义 token 为条件预测声学 token,填充声学细节。这种层次结构类似于文生语音流程(第 9 章)——语义 token 扮演音素的角色,声学 token 扮演梅尔频谱图帧的角色。 + +- AudioLM 可以生成语音接续(给定 3 秒语音,生成接下来的 10 秒)、音乐接续和音效,所有这些都来自一个仅在音频数据上训练的模型(预训练不需要文本标签)。 + +### MusicLM:文本条件音乐生成 + +- **MusicLM**(Agostinelli 等人,2023)将 AudioLM 扩展到文本条件下的音乐生成。它添加了一个文本-音频联合嵌入(来自 **MuLan**,一个在音乐-文本对上训练的类 CLIP 模型)来条件化生成。MuLan 嵌入捕获文本描述的语义含义("带有萨克斯独奏的欢快爵士乐")并指导分层 token 生成。 + +- MusicLM 以 24 kHz 的频率生成任意时长的音乐,在数分钟长的作品中保持旋律和节奏的连贯性。它还可以用哼唱的旋律(由音高追踪器提取的旋律 token)加上文本描述作为条件,生成完整的编曲,既遵循哼唱的曲调,又符合文本描述的风格。 + +### MusicGen:高效单阶段生成 + +- **MusicGen**(Copet 等人,2023)简化了多阶段方法。MusicGen 不使用独立的语义和声学模型,而是使用一个单一的自回归 Transformer,直接生成来自音频编解码器的多个码本层级。关键创新是**交织码本模式(interleaved codebook pattern)**:MusicGen 并非在进入下一个时间步之前生成该时间步的所有码本层级,而是以某种模式跨码本和时间步交织 token,从而允许对某些码本层级进行并行解码。 + +- 条件化直接明了:文本由 T5 编码器编码,文本嵌入被前置到音频 token 序列之前(像语言模型中的前缀提示)或通过交叉注意力注入。MusicGen 还支持旋律条件化:参考旋律的色谱图(chromagram,来自第 9 章中讨论的频谱图特征)被编码后与文本条件一起使用。 + +$$p(a_1, \ldots, a_T) = \prod_{t=1}^{T} \prod_{k=1}^{K} p(a_{t,k} \mid a_{ ... [/IMAGE] [AUDIO] ... [/AUDIO] +``` + +- Transformer 然后使用其标准因果(或双向)注意力机制处理整个混合序列。模态分隔 token 起到双重作用:它们向模型告知模态边界,并充当"汇聚点",其表示概括了每个模态段。 + +![交错 token 序列的示意图,显示文本 token、离散图像 token 和音频编解码器 token 流经一个带有模态边界标记的单一 Transformer](../images/multimodal_tokenisation_sequence.svg) + +- 一个关键的设计选择是**token 预算**。一张被分词为 256 个 token 的图像加上 50 个 token 的文本描述,意味着图像消耗的上下文窗口是文本的 5 倍。模型必须在分辨率(更多 token = 更多细节)和上下文长度(更多 token = 更高的内存和计算成本)之间取得平衡。**token 合并**(逐渐合并相似 token)和**自适应分词**(对简单区域使用较少的 token,对复杂区域使用更多 token)等技术有助于管理这种权衡。 + +## 训练配方:分阶段预训练与联合微调 + +- 你不会在教孩子算术之前就教他微积分。同样,你不能从随机初始化开始,在所有模态上同时训练一个统一多模态模型,并期望它能很好地收敛。主导方法是**分阶段训练**,其中模型在精心排序的阶段中逐步学习越来越复杂的跨模态能力。 + +- **阶段 1:单模态预训练。** 每个模态编码器在大型单模态数据集上独立训练。文本主干使用标准语言建模目标(下一 token 预测)在数万亿文本 token 上进行预训练,正如第 7 章一样。视觉编码器在图像分类或自监督目标(MAE、DINO)上预训练,如第 8 章所述。音频编码器在语音识别或音频分类数据上预训练,如第 9 章所述。这一阶段产生了强大的单模态特征提取器。 + +- **阶段 2:跨模态对齐。** 预训练的编码器连接到共享主干,模型在成对的多模态数据(图像-描述对、音频-文本对)上使用对比或生成目标进行训练。在此阶段,编码器权重可能被冻结(以保留单模态知识),仅更新投影层和主干。这是来自本章文件 01 的 CLIP 风格对齐被纳入统一模型的阶段。 + +- **阶段 3:联合多模态预训练。** 所有参数(或大部分)被解冻,模型在单模态和多模态数据的混合上训练,使用对所有模态 token 的单一下一 token 预测目标。损失函数为: + +$$\mathcal{L} = -\sum_{t=1}^{T} \log p_\theta(x_t \mid x_{ 0) & (p_img[:, 0] > 0) & (p_img[:, 0] < 640) & \ + (p_img[:, 1] > 0) & (p_img[:, 1] < 480) +depth = p_cam[mask, 2] + +plt.figure(figsize=(8, 5)) +plt.scatter(p_img[mask, 0], p_img[mask, 1], c=depth, cmap="viridis", s=5) +plt.colorbar(label="深度 (米)") +plt.xlim(0, 640); plt.ylim(480, 0) +plt.title("投影到相机图像上的LiDAR点") +plt.xlabel("u (像素)"); plt.ylabel("v (像素)") +plt.show() +``` + +2. 使用贝叶斯对数几率更新构建一个简单的2D占据网格。模拟一个距离传感器扫描环境,观察地图的生成过程。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 网格设置:50x50个单元,每个0.2米 +grid_size = 50 +log_odds = jnp.zeros((grid_size, grid_size)) + +# 传感器模型:对数几率更新值 +l_occ = 0.85 # 命中意味着占据的置信度 +l_free = -0.4 # 穿过意味着空闲的置信度 + +# 模拟障碍物:从(5,20)到(5,30)的墙(网格坐标) +wall_y = jnp.arange(20, 30) + +# 机器人在(25, 25),向外扫描 +robot = jnp.array([25, 25]) + +for angle_deg in range(0, 360, 5): + angle = jnp.radians(angle_deg) + direction = jnp.array([jnp.cos(angle), jnp.sin(angle)]) + + for step in range(1, 25): + cell = (robot + direction * step).astype(int) + r, c = int(cell[0]), int(cell[1]) + if r < 0 or r >= grid_size or c < 0 or c >= grid_size: + break + + # 检查此单元是否为墙 + is_wall = (r == 5) and (c >= 20) and (c < 30) + if is_wall: + log_odds = log_odds.at[r, c].add(l_occ) + break + else: + log_odds = log_odds.at[r, c].add(l_free) + +# 将对数几率转换为概率 +prob = 1.0 / (1.0 + jnp.exp(-log_odds)) + +plt.figure(figsize=(6, 6)) +plt.imshow(prob.T, origin="lower", cmap="RdYlGn_r", vmin=0, vmax=1) +plt.colorbar(label="P(被占据)") +plt.plot(25, 25, "b*", markersize=10, label="机器人") +plt.legend() +plt.title("贝叶斯更新生成的2D占据网格") +plt.show() +``` + +3. 使用视差从立体图像对计算深度。模拟两个相机视角下的3D点,计算视差并恢复深度。 +```python +import jax +import jax.numpy as jnp + +# 相机参数 +f = 500.0 # 焦距(像素) +b = 0.12 # 基线(米,12厘米) + +# 已知深度的3D点 +depths_true = jnp.array([5.0, 10.0, 20.0, 50.0, 100.0]) + +# 视差 = f * b / Z +disparities = f * b / depths_true + +# 从视差恢复深度 +depths_recovered = f * b / disparities + +for z, d, z_r in zip(depths_true, disparities, depths_recovered): + print(f"真实深度: {z:6.1f}米 视差: {d:6.2f}像素 恢复值: {z_r:6.1f}米") + +# 注意:视差与深度成反比 +# 近处物体视差大,远处物体视差小 +# 这就是为什么立体视觉在近距离最准确 +``` diff --git a/chapter 11: autonomous systems/02. robot learning.md b/chapter 11: autonomous systems/02. robot learning.md new file mode 100644 index 0000000..9e52c4d --- /dev/null +++ b/chapter 11: autonomous systems/02. robot learning.md @@ -0,0 +1,298 @@ +# 机器人学习 + +*机器人学习弥合了算法与物理行动之间的鸿沟。本章涵盖运动学、动力学、经典控制、模仿学习、仿真到现实迁移、操作、移动和安全——这些技术赋予机器人在现实世界中移动、抓取、行走和交互的能力。* + +- 在前面的章节中,我们研究了如何感知世界(第8章,第11章文件1)以及如何从数据中学习(第6章)。但感知和学习还不够。机器人必须**行动**:移动手臂抓取杯子、在不平坦的地形上行走、或在仓库中导航。这就是机器人学习的用武之地。 + +- 核心挑战在于物理世界是连续的、高维的、接触丰富的且不宽容的。图像识别中的分类错误只是标签错误,而机器人学中的控制错误则意味着机器人损坏或物体掉落。两者的代价截然不同。 + +## 机器人运动学 + +- **运动学**描述运动的几何关系,不考虑力。机器人手臂是由关节连接的刚性连杆组成的链条。每个关节有一个自由度(DoF):要么旋转(旋转关节),要么滑动(棱柱关节)。 + +- 机器人的**构型**是所有关节角度(或位移)的集合 $\\mathbf{q} = [q_1, q_2, \\ldots, q_n]^T$。这个向量位于**关节空间**(或构型空间)中,这是一个$n$维空间,每个轴对应一个关节。一个6自由度机器人手臂有一个6维构型空间。 + +![2连杆机器人手臂:关节角度q1和q2通过正向运动学确定末端执行器位置](../images/robot_arm_fk.svg) + +- **正向运动学(FK)**根据给定的关节角度计算末端执行器("手")的位置和姿态。这是一个从关节空间映射到**任务空间**(末端执行器的3D位置和姿态,也称为笛卡尔空间)的函数 $\\mathbf{x} = f(\\mathbf{q})$。 + +- 每个关节由一个$4 \\times 4$齐次变换矩阵描述(回顾第2章的仿射变换)。**Denavit-Hartenberg(DH)约定**用四个参数参数化每个关节:连杆长度$a$、连杆扭转角$\\alpha$、连杆偏移$d$和关节角度$\\theta$。关节$i$的变换为: + +$$T_i = \\begin{bmatrix} \\cos\\theta_i & -\\sin\\theta_i \\cos\\alpha_i & \\sin\\theta_i \\sin\\alpha_i & a_i \\cos\\theta_i \\\\ \\sin\\theta_i & \\cos\\theta_i \\cos\\alpha_i & -\\cos\\theta_i \\sin\\alpha_i & a_i \\sin\\theta_i \\\\ 0 & \\sin\\alpha_i & \\cos\\alpha_i & d_i \\\\ 0 & 0 & 0 & 1 \\end{bmatrix}$$ + +- 完整的正向运动学是所有关节变换的乘积:$T_{0 \\to n} = T_1 T_2 \\cdots T_n$。这是矩阵乘法链式变换(第2章):每个关节的变换依次应用,将坐标系从基座旋转和平移到末端执行器。 + +- **逆向运动学(IK)**是反向问题:给定期望的末端执行器姿态$\\mathbf{x}^*$,求关节角度$\\mathbf{q}$使得$f(\\mathbf{q}) = \\mathbf{x}^*$。这要难得多,因为: + + - 映射是非线性的(涉及正弦和余弦)。 + - 可能有多个解(不同的手臂构型可以到达同一点)。 + - 可能没有解(目标超出可达范围)。 + +- 解析解只存在于特定的机器人几何构型中。对于通用机器人,IK使用**雅可比矩阵**迭代求解。雅可比矩阵$J(\\mathbf{q})$将关节角度的微小变化与末端执行器位置的微小变化联系起来(回顾第3章的雅可比矩阵): + +$$\\dot{\\mathbf{x}} = J(\\mathbf{q}) \\dot{\\mathbf{q}}$$ + +- 要将末端执行器移动一个小的量$\\Delta \\mathbf{x}$,我们需要$\\Delta \\mathbf{q} = J^{-1} \\Delta \\mathbf{x}$(当$J$不是方阵时使用伪逆$J^+ \\Delta \\mathbf{x}$)。这个过程迭代进行,直到末端执行器到达目标,本质上就是将牛顿法(第3章)应用于运动学方程。 + +- 在**奇异点**附近,雅可比矩阵的秩下降(某些列变得线性相关,如我们在第2章中研究的)。物理上这意味着机器人失去一个自由度:无论关节移动多快,末端执行器都无法在某些方向上移动。伪逆在奇异点附近会爆炸,因此使用阻尼最小二乘法(加入正则化项$\\lambda^2 I$): + +$$\\Delta \\mathbf{q} = J^T(JJ^T + \\lambda^2 I)^{-1} \\Delta \\mathbf{x}$$ + +## 动力学与控制 + +- **动力学**将力引入画面。机器人手臂的运动方程遵循**操作臂方程**: + +$$M(\\mathbf{q})\\ddot{\\mathbf{q}} + C(\\mathbf{q}, \\dot{\\mathbf{q}})\\dot{\\mathbf{q}} + \\mathbf{g}(\\mathbf{q}) = \\boldsymbol{\\tau}$$ + +- 其中$M(\\mathbf{q})$是质量(惯性)矩阵,$C(\\mathbf{q}, \\dot{\\mathbf{q}})$捕获科里奥利力和离心力效应,$\\mathbf{g}(\\mathbf{q})$是重力向量,$\\boldsymbol{\\tau}$是关节力矩向量(控制输入)。这是一个二阶微分方程组,每个关节一个方程。 + +- 质量矩阵$M$总是对称正定的(回顾第2章,正定矩阵保证唯一最小值,在这里它确保系统对施加的力矩有可预测的响应)。 + +- **PID控制**是机器人学中使用最广泛的控制器。对于每个关节,它根据误差$e(t) = q_{\\text{期望}}(t) - q_{\\text{实际}}(t)$计算力矩: + +$$\\tau(t) = K_p e(t) + K_i \\int_0^t e(s) \\, ds + K_d \\dot{e}(t)$$ + +- 三个项有直观的作用: + - **比例项**($K_p$):与当前误差成比例地校正。误差越大 → 校正越大。就像弹簧将关节拉向目标。 + - **积分项**($K_i$):累积过去的误差以消除稳态偏差。如果关节持续欠调,积分项会积累并提供额外的推力。 + - **微分项**($K_d$):对误差变化率作出反应,提供阻尼。它随着误差减小而减缓响应,防止过冲和震荡。 + +![PID控制器调参:高Kp震荡,高Kd迟钝,调谐好的PID快速到达目标](../images/pid_response.svg) + +- 调整$K_p, K_i, K_d$是一种平衡:$K_p$太大会引起震荡,$K_d$太大会使系统反应迟钝,$K_i$太大会导致积分饱和(在持续误差期间积分无限增长)。 + +- **模型预测控制(MPC)**具有前瞻性。在每个时间步,它求解一个优化问题:找到未来控制序列,在有限时域内最小化代价函数(例如,跟踪误差+控制能量),并满足动力学模型和约束条件。只应用第一个控制量,然后在下一个时间步重复该过程。 + +$$\\min_{\\mathbf{u}_{0:T}} \\sum_{t=0}^{T} \\left[ \\|\\mathbf{x}_t - \\mathbf{x}_t^*\\|_Q^2 + \\|\\mathbf{u}_t\\|_R^2 \\right] \\quad \\text{subject to} \\quad \\mathbf{x}_{t+1} = f(\\mathbf{x}_t, \\mathbf{u}_t)$$ + +- 这里$\\|\\mathbf{x}\\|_Q^2 = \\mathbf{x}^T Q \\mathbf{x}$是使用正定矩阵$Q$(第2章)的加权范数,允许对不同状态误差进行不同惩罚。MPC自然地处理约束(关节限位、力矩限位、避障),因为它们被显式地包含在优化中。 + +- **阻抗控制**调节力与运动之间的关系,而不是跟踪刚性轨迹。它不命令"到达位置$x$",而是命令"表现得像一个以$x$为中心的弹簧-阻尼系统": + +$$F = K_s(\\mathbf{x}^* - \\mathbf{x}) + D(\\dot{\\mathbf{x}}^* - \\dot{\\mathbf{x}})$$ + +- 其中$K_s$是刚度矩阵,$D$是阻尼矩阵。这使得机器人具有柔顺性:如果它接触到障碍物,它会退让而不是强行通过。阻抗控制对于接触密集型任务(如将销钉插入孔中或将物体递给人类)至关重要。 + +## 模仿学习 + +- 我们可以从示范中学习控制策略,而不是手工设计控制器。人类执行任务,机器人观察,学习算法提取策略。这就是**模仿学习**(或从示范中学习)。 + +- **行为克隆(BC)**是最简单的方法:将示范视为监督学习数据集。给定来自专家的观测-动作对$\\{(\\mathbf{o}_t, \\mathbf{a}_t)\\}$,训练策略$\\pi_\\theta(\\mathbf{a} \\mid \\mathbf{o})$从观测中预测专家的动作。这是标准的监督学习(第6章):最小化损失: + +$$\\mathcal{L}(\\theta) = \\mathbb{E}_{(\\mathbf{o}, \\mathbf{a}) \\sim \\mathcal{D}} \\left[ \\| \\pi_\\theta(\\mathbf{o}) - \\mathbf{a} \\|^2 \\right]$$ + +![行为克隆中的分布偏移:微小误差不断累积,导致学习到的策略远离专家轨迹](../images/distribution_shift_bc.svg) + +- 问题是**分布偏移**(也称为**复合误差问题**)。在训练期间,策略看到的是专家的状态。在部署期间,策略自身的小误差将其推入专家从未访问过的状态。这些不熟悉的状态导致更差的动作,进而导致更不熟悉的状态,误差迅速累积放大。 + +- 想象一下通过观看完美驾驶员来学习开车。你从未见过小幅偏移后会发生什么,因为专家从未偏移过。第一次你稍为偏离时,你完全不知道如何恢复。 + +- **DAgger**(数据集聚合)通过迭代解决这个问题: + 1. 在当前数据上训练策略。 + 2. 在环境中运行策略,收集新状态。 + 3. 请专家用正确的动作标注这些新状态。 + 4. 将新数据添加到数据集并重新训练。 + +- 经过多次迭代,数据集覆盖了学习策略实际访问的状态,而不仅仅是专家的轨迹。策略得到了改善,因为它已经看到并学会了从自己的错误中恢复。 + +- **使用Transformer的动作分块(ACT)**是一种现代方法,策略预测一系列未来动作(一个"块"),而不是一次预测一个动作。它使用带有transformer主干的 conditional VAE 实现。预测动作块更鲁棒,因为它捕获了时间相关性:伸手动作的平滑性编码在块中,而不是依赖于可能漂移的自回归单步预测。 + +- **扩散策略**将扩散模型(第8章)应用于动作生成。它不预测单个动作,而是建模以观测为条件的完整动作分布。从噪声开始,它迭代地去噪以生成动作序列。这自然地处理了**多模态性**:当有多个有效方式完成任务时(从左边或右边伸手),扩散模型可以表示两种模式,而回归策略会平均它们(到达中间某处,可能两种都无效)。 + +## 仿真到现实迁移 + +- 在现实世界中训练机器人是昂贵、缓慢且危险的。一个通过试错学习抓取的机器人可能需要数千次尝试,在这个过程中损坏物体和自身。**仿真**提供了无限、安全、快速的体验。但仿真器并非完美:物理近似、视觉合成、接触简化。 + +- **仿真到现实差距**是仿真性能与真实性能之间的差异。在仿真中完美运行的策略可能在真实机器人上完全失败,因为它过度拟合了仿真器的特定细节。 + +![通过域随机化实现仿真到现实迁移:在大量随机化仿真上训练,使现实世界只是另一种变体](../images/sim_to_real.svg) + +- **域随机化**通过在广泛的仿真器设置上进行训练来应对这一问题。不是使用一种仿真,而是使用数千种具有随机化参数的仿真: + - 物理:摩擦系数、质量、阻尼 + - 视觉:光照、纹理、颜色、相机位置 + - 动力学:电机延迟、噪声水平 + +- 其思想是,如果策略在所有这些变化下都能工作,那么现实世界只是分布中的"另一种变化"。策略学习对随机化属性不变的特征,这些不变特征能够迁移。 + +- **系统辨识**采取相反的方法:不是随机化所有内容,而是仔细测量真实系统的物理参数并将仿真器调谐到匹配。这提供了更精确的仿真,但也更脆弱(任何未建模的效应都会导致差距)。 + +- 在实践中,最好的结果是将两者结合:使用系统辨识使仿真器合理接近,然后使用域随机化覆盖剩余的不确定性。 + +- **通过微调的仿真到现实迁移**主要在仿真中训练,然后进行少量的真实世界微调。仿真提供了良好的初始化,真实世界数据纠正了仿真器特定的偏差。这需要的真实世界数据远少于从头训练。 + +## 机器人世界模型 + +- 上述所有强化学习和模仿学习方法都是**无模型**的:策略通过直接交互(或示范)学习行动,而不显式建模世界如何运作。另一种是**基于模型**的学习:首先学习环境动力学模型,然后使用该模型进行规划或生成合成经验。 + +- **世界模型**学习转移函数$p(s_{t+1} \\mid s_t, a_t)$:给定当前状态和动作,预测下一状态(如第10章所述)。在机器人学中,这意味着预测如果机器人采取特定动作会发生什么:"如果我向左推这个方块,它会滑动3厘米,它后面的杯子会倒下。" + +- 其吸引力在于**样本效率**。现实世界的机器人交互成本高昂。如果机器人能从适量的真实数据中学习一个世界模型,它就可以通过在大脑中滚动模型来"想象"数千条轨迹,在不动用物理世界的情况下规划和完善策略。这类似于棋手通过在脑海中模拟走棋来思考。 + +- **DreamerV3**是一个通用的基于模型的强化学习智能体。它联合学习三个组件: + - **表示模型**:将观测编码为紧凑的潜在状态。 + - **转移模型**(世界模型):根据当前状态和动作预测下一潜在状态。 + - **奖励模型**:从潜在状态预测奖励。 + +- 然后智能体通过在潜在空间中展开转移模型多步来进行"做梦",在这些想象的轨迹上训练策略,并将策略转移到真实环境。关键创新在于所有想象都在潜在空间(紧凑的学习表示)中进行,而不是在像素空间中,使其计算可行。 + +$$\\hat{s}_{t+1} = f_\\theta(s_t, a_t), \\quad \\hat{r}_t = g_\\theta(s_t)$$ + +- 转移模型$f_\\theta$和奖励模型$g_\\theta$在真实经验上训练,策略在想象的展开上训练。这将数据收集与策略优化解耦。 + +- 对于机器人操作,世界模型实现了**心理排练**。在尝试抓取之前,机器人可以在其学习模型上模拟多种方法,并选择最可能成功的一种。这对于接触密集型任务尤其有价值,因为在这些任务中现实世界的试错既慢又危险。 + +- 世界模型也自然地与**仿真到现实迁移**相关联:在真实数据上训练的世界模型实际上是一个自动捕获真实世界物理的学习型仿真器,完全绕过了仿真到现实差距。对于理解良好的场景,它可能不如手工构建的仿真器精确,但它捕获了手工仿真器常常出错的效果(摩擦、形变、接触动力学)。 + +- **JEPA**(联合嵌入预测架构,在第10章中介绍)提供了像素级预测的替代方案。JEPA不在像素空间预测精确的未来观测,而是在嵌入空间中预测:"下一状态的潜在表示将接近该向量。"这避免了预测像素级完美未来的困难(既无必要又计算浪费),并专注于预测对决策重要的未来方面。 + +- 世界模型的局限性在于**复合预测误差**。转移模型中的微小不准确性在长程展开中积累,导致想象的轨迹偏离现实。缓解措施包括:短想象时域、集成模型(使用不确定性检测预测何时变得不可靠)、以及定期用新的真实世界数据校准模型。 + +## 操作 + +- **操作**是使用机器人末端执行器与物体交互的艺术:抓取、放置、推、插入、组装。 + +- **抓取**是基础的操作技能。目标是找到一个稳定的抓取姿态:夹爪的位置和方向,能够牢固地抓住物体。 + +- **解析抓取规划**使用物理学。如果接触力能够抵抗外部扳手(力和力矩),则抓取是稳定的。对于平行夹爪,最简单的标准是**力闭合**条件:接触法线必须跨越所有力的方向,使抓取能够抵抗任何扰动。这涉及检查抓取扳手矩阵的秩,是第2章秩概念的直接应用。 + +- **数据驱动的抓取**学习从感官输入预测抓取成功。给定桌子上物体的深度图像,网络预测每个候选夹爪姿态的抓取质量分数。**GraspNet**和类似架构使用点云编码器(PointNet风格,第8章)来预测带有置信度分数的6自由度抓取姿态(位置+方向)。 + +- **灵巧操作**超越了简单的抓取和放置。多指手具有20+自由度,可以执行手中旋转(在手指间旋转笔)、工具使用和精细组装等任务。状态空间巨大且接触复杂,使其成为机器人学中最困难的问题之一。 + +- 学习灵巧操作通常使用带有大量域随机化的仿真中的强化学习(第6章)。OpenAI用Shadow手解决魔方的工作就是在仿真中使用随机化物理训练PPO策略,最终实现了向真实机器人手的迁移。 + +- **接触密集型任务**如销钉入孔或擦拭表面,要求机器人与环境保持受控接触。这些任务需要力传感和柔顺控制(阻抗控制),并且难以准确仿真,因为接触物理众所周知地难以建模。 + +## 移动 + +- 移动是让机器人的身体在世界中移动:行走、奔跑、攀爬、游泳。与操作的关键区别在于机器人必须在移动时保持平衡,并且与地面的接触点随时间变化。 + +- **腿式移动**具有挑战性,因为它本质上是不稳定的。单步站立的双足机器人(类人机器人)就像一个倒立摆。质心必须保持在支撑多边形(与地面接触的脚的凸包)上方,否则机器人会摔倒。 + +- **零力矩点(ZMP)**是地面上重力和惯性力产生的净力矩为零的点。如果ZMP保持在支撑多边形内,机器人就不会翻倒。传统的人形机器人控制器(如本田ASIMO)规划使ZMP保持在边界内的轨迹。 + +- **中央模式发生器(CPG)**是受生物学启发的基于振荡器的控制器。动物使用脊髓中的神经回路产生有节奏的移动模式(行走、小跑、奔跑),无需大脑持续参与。CPG模型使用耦合微分方程: + +$$\\dot{\\phi}_i = \\omega_i + \\sum_j w_{ij} \\sin(\\phi_j - \\phi_i - \\psi_{ij})$$ + +- 其中$\\phi_i$是振荡器$i$的相位,$\\omega_i$是自然频率,$w_{ij}$是耦合强度,$\\psi_{ij}$是期望的相位偏移。不同的相位关系产生不同的步态:所有腿同步(跳跃)、交替配对(小跑)、顺序(行走)。正弦耦合自然地同步振荡器,类似于傅里叶级数(第3章)如何将运动分解为频率分量。 + +- **用于移动的强化学习**已成为敏捷四足和类人机器人的主要方法。机器人在仿真中通过试错学习策略$\\pi(\\mathbf{a} \\mid \\mathbf{o})$(第6章),奖励包括前进速度、稳定性和能效,惩罚包括摔倒、关节限位违规和抖动运动。 + +- 近期工作(如Agility Robotics、Boston Dynamics和学术实验室)的关键洞见是,RL训练的移动策略远优于手工设计的控制器。它们自然学会从推动中恢复、适应地形变化,并处理没有工程师能预料到的情况。训练通常使用PPO(第6章)结合域随机化。 + +- **四足机器人**(如Boston Dynamics Spot或Unitree Go2)已成为腿式机器人的主力。四条腿提供固有稳定性(三条腿的三角支撑总能在一条腿移动时支撑身体)。四足机器人的RL策略实现了令人印象深刻的结果:以3+米/秒奔跑、爬楼梯、在岩石地形上导航以及从踢击中恢复。 + +- **类人机器人移动**更难,因为双足机器人有更小的支撑多边形和更高的质心。最近的进展(Tesla Optimus、Figure、Unitree H1)使用在仿真中训练的RL,配以精心的奖励塑造。类人机器人必须学会的不仅仅是行走,还要协调手臂摆动以保持平衡、在不平坦表面上导航以及从扰动中恢复。 + +## 机器人学习中的安全性 + +- 一个为了学习而随机探索的机器人(如在RL中)可能会损坏自身、环境或附近的人类。**安全的机器人学习**约束探索以避免灾难性后果。 + +- **约束RL**向MDP(第6章)添加安全约束。目标变为:在满足$J_c(\\pi) \\leq d$的条件下最大化奖励,其中$J_c$是期望的累积代价(如碰撞事件),$d$是最大允许代价。像约束策略优化(CPO)这样的算法扩展了PPO以处理这些约束。 + +- **安全包络**定义了机器人绝不能越过的硬边界,无论学习策略如何输出。一个安全控制器监控机器人状态,并在即将违反约束时覆盖学习策略(例如,接近关节限位、在人类附近移动过快、或超过力阈值)。这是一种分层架构:学习算法处理性能,安全层处理约束。 + +- **风险感知规划**显式地建模环境和机器人自身状态估计中的不确定性。它不是为最可能的结果进行规划,而是在置信区间内为最坏情况进行规划。这与条件数概念(第2章)相关:良态系统对扰动具有鲁棒性,风险感知规划寻求在扰动下仍保持安全的控制策略。 + +## 编程任务(使用CoLab或notebook) + +1. 实现一个简单2连杆平面机器人手臂的正向运动学。计算并可视化不同关节角度下的末端执行器位置。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +def forward_kinematics(q1, q2, l1=1.0, l2=0.8): + """计算2连杆手臂的关节和末端执行器位置。""" + x1 = l1 * jnp.cos(q1) + y1 = l1 * jnp.sin(q1) + x2 = x1 + l2 * jnp.cos(q1 + q2) + y2 = y1 + l2 * jnp.sin(q1 + q2) + return jnp.array([0, x1, x2]), jnp.array([0, y1, y2]) + +fig, ax = plt.subplots(figsize=(6, 6)) +configs = [(0.5, 0.3), (1.0, -0.5), (1.5, 1.0), (2.0, -1.5)] +colors = ["#e74c3c", "#3498db", "#27ae60", "#9b59b6"] + +for (q1, q2), c in zip(configs, colors): + xs, ys = forward_kinematics(q1, q2) + ax.plot(xs, ys, "o-", color=c, linewidth=2, markersize=6, + label=f"q=({q1:.1f}, {q2:.1f})") + +ax.set_xlim(-2, 2); ax.set_ylim(-2, 2) +ax.set_aspect("equal"); ax.grid(True); ax.legend() +ax.set_title("2连杆机器人手臂:正向运动学") +plt.show() +``` + +2. 使用雅可比伪逆实现逆向运动学。从随机构型开始,迭代地将末端执行器移动到目标。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +l1, l2 = 1.0, 0.8 + +def end_effector(q): + x = l1 * jnp.cos(q[0]) + l2 * jnp.cos(q[0] + q[1]) + y = l1 * jnp.sin(q[0]) + l2 * jnp.sin(q[0] + q[1]) + return jnp.array([x, y]) + +jacobian_fn = jax.jacobian(end_effector) + +target = jnp.array([0.5, 1.2]) +q = jnp.array([0.1, 0.1]) +trajectory = [end_effector(q)] + +for _ in range(50): + pos = end_effector(q) + error = target - pos + if jnp.linalg.norm(error) < 1e-4: + break + J = jacobian_fn(q) + # 阻尼伪逆处理接近奇异点的情况 + dq = J.T @ jnp.linalg.solve(J @ J.T + 0.01 * jnp.eye(2), error) + q = q + dq + trajectory.append(end_effector(q)) + +traj = jnp.stack(trajectory) +plt.plot(traj[:, 0], traj[:, 1], "b.-", label="末端执行器路径") +plt.plot(*target, "r*", markersize=15, label="目标点") +plt.gca().set_aspect("equal"); plt.grid(True); plt.legend() +plt.title(f"IK在{len(trajectory)-1}步内收敛") +plt.show() +``` + +3. 模拟一个简单的PID控制器跟踪期望的关节轨迹。观察调参对增益的影响。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 期望轨迹:平滑正弦运动 +dt = 0.01 +t = jnp.arange(0, 5, dt) +q_desired = jnp.sin(2 * t) + +# 模拟二阶动力学:m * q_ddot + b * q_dot = tau +m, b_damp = 1.0, 0.5 + +for Kp, Kd, Ki, label in [(10, 5, 0, "仅PD"), (10, 5, 2, "PID"), (50, 10, 2, "激进PID")]: + q, q_dot, integral = 0.0, 0.0, 0.0 + qs = [] + for i in range(len(t)): + error = q_desired[i] - q + integral += error * dt + d_error = -q_dot # 误差导数(此处简化,已知期望速度) + tau = Kp * error + Kd * d_error + Ki * integral + q_ddot = (tau - b_damp * q_dot) / m + q_dot += q_ddot * dt + q += q_dot * dt + qs.append(float(q)) + + plt.plot(t, qs, label=label) + +plt.plot(t, q_desired, "k--", label="期望值", linewidth=2) +plt.xlabel("时间 (秒)"); plt.ylabel("关节角度") +plt.legend(); plt.title("PID控制器跟踪") +plt.show() +``` diff --git a/chapter 11: autonomous systems/03. vision-language-action models.md b/chapter 11: autonomous systems/03. vision-language-action models.md new file mode 100644 index 0000000..68c0d46 --- /dev/null +++ b/chapter 11: autonomous systems/03. vision-language-action models.md @@ -0,0 +1,217 @@ +# 视觉-语言-动作模型 + +*视觉-语言-动作模型(VLA)将视觉理解、语言理解和行动控制统一到单个神经网络中。本章涵盖VLA架构、动作标记化、RT-2、Octo、OpenVLA、预训练策略、泛化能力、与具体形态无关的模型以及基准测试。* + +- 在前面的文件中,我们涵盖了感知(感知世界)和机器人学习(控制身体)。传统上,这些是独立的流程:感知模块检测物体,语言模块解释指令,控制模块生成动作。每个模块独立设计、训练和调试。 + +- **视觉-语言-动作模型(VLA)**将这一流程压缩为单个神经网络。模型接收图像(视觉)和自然语言指令(语言),并输出电机命令(动作)。一个模型,端到端。 + +- 这沿袭了我们在第10章看到的统一趋势:正如多模态模型将视觉和语言理解合并到一个架构中一样,VLA将这一趋势扩展到物理行动。关键洞见在于,语言为指定任务提供了自然、灵活的接口("拿起红色杯子放到架子上"),而大规模预训练的视觉-语言模型已经理解图像和指令。 + +## 从视觉-语言到行动 + +- 回顾第10章,**视觉-语言模型(VLM)**如LLaVA和Flamingo接收图像和文本作为输入,并生成文本作为输出。它们理解场景、回答问题、遵循指令——全部通过语言完成。 + +- VLA提出的问题是:如果输出不是文本而是**机器人动作**呢?模型不再生成"红色杯子在桌子的左侧",而是生成一系列电机命令,驱动手臂去抓取那个杯子。 + +- 关键的架构洞见是,动作可以像单词一样表示为标记。如果VLM使用下一个标记预测逐个生成语言标记,那么VLA以同样的方式生成动作标记。Transformer从根本上并不关心输出标记表示"杯子"还是"将夹爪向前移动2厘米"。 + +- 这重新定义了机器人控制为序列建模问题,这正是transformer擅长的(第7章)。模型学习映射:(图像观测,语言指令)$\\to$(动作标记序列)。 + +## VLA架构 + +![VLA架构:相机图像和语言指令被编码为标记,由LLM主干网络处理,并解码为机器人动作](../images/vla_architecture.svg) + +- 典型的VLA有三个组件: + + - **视觉编码器**:将相机图像处理为视觉标记。通常是预训练的ViT(第8章)或SigLIP编码器(第10章)。图像被分割成块,每个块嵌入为一个标记,与标准视觉transformer完全一样。 + + - **语言模型主干网络**:一个预训练的LLM(例如LLaMA、PaLM),处理交错的视觉标记和语言标记序列。这就是推理发生的地方:模型通过同时关注指令和视觉特征来理解"拿起**红色**杯子"。 + + - **动作头**:将LLM的输出映射到机器人动作。可以是一个简单的MLP,将最后的隐藏状态映射到连续动作值,或者是一种将动作转换为离散标记的方案,由LLM的现有词汇表来预测。 + +- 架构看起来像: + +$$\\text{图像} \\xrightarrow{\\text{ViT}} \\text{视觉标记} \\quad + \\quad \\text{指令} \\xrightarrow{\\text{分词器}} \\text{语言标记} \\quad \\xrightarrow{\\text{LLM}} \\quad \\text{动作标记}$$ + +- 视觉标记和语言标记被拼接(或交错)并输入到transformer主干网络,后者自回归地生成动作标记。这与VLM(第10章)的架构相同,但输出模态是动作而非文本。 + +## 动作标记化 + +- 机器人动作是连续的:关节速度、末端执行器位置、夹爪宽度。这些必须转换为离散标记才能让LLM生成。 + +![动作标记化:连续动作值被分箱为离散索引,LLM将其作为标记生成](../images/action_tokenisation.svg) + +- 最简单的方法是**均匀离散化**。每个动作维度被划分为$N$个箱,覆盖有效值范围。例如,如果x方向速度范围从-0.1到0.1米/秒,使用256个箱,每个箱代表$\\frac{0.2}{256} \\approx 0.8$毫米/秒。动作值被映射到最近的箱索引,该索引成为一个标记。 + +- 对于7个动作维度(6自由度+夹爪)和每个维度256个箱,动作词汇表有$7 \\times 256 = 1792$个标记。这些被添加到LLM现有的文本词汇表中。模型每个维度生成一个动作标记,自回归地,就像生成单词一样。 + +- **动作分块**一次预测多个未来时间步,而不是单个动作。如果块大小为$H$,模型输出$H \\times d$个标记(其中$d$是动作维度)。这对于平滑、时间连贯的运动至关重要。一次预测一步会产生抖动行为,因为每次预测都是独立的。分块迫使模型规划一个短轨迹,捕获时间结构。 + +- 更复杂的方法使用**学习型标记化**,通过VQ-VAE(第10章)。VQ-VAE编码器将连续动作序列映射到离散码本索引序列,解码器从这些索引重建连续动作。LLM然后生成码本索引,而不是均匀分箱的值。这类似于图像分词器(第10章)如何将视觉信息压缩为紧凑的离散编码。 + +## 关键VLA模型 + +- **RT-2**(机器人Transformer 2,Google DeepMind)是第一个大规模VLA。它使用预训练的VLM(PaLM-E或PaLI-X,参数高达55B)并在机器人示范数据上微调。动作表示为文本字符串:标记序列"1 128 91 241 5 101 127"编码了一个7维动作(每个数字是箱索引)。 + +- RT-2展示了一个显著特性:来自VLM主干网络的**涌现能力**迁移到了机器人领域。模型可以遵循涉及从未在机器人数据中见过的概念的指令(例如,"将香蕉移动到以A开头的国家"需要视觉物体识别+世界知识+行动)。VLM的语言理解和视觉推理"免费"获得。 + +- RT-2的局限性在于它是在单个机器人形态(特定的手臂和特定的夹爪)的数据上训练的。它不能泛化到不同的机器人。 + +- **Octo**(加州大学伯克利分校)是一个开源的、**与具体形态无关**的VLA,设计用于跨不同机器人平台工作。关键创新包括: + + - **扩散动作头**,而不是自回归标记预测。动作头获取transformer的输出,并通过去噪扩散过程(第8章)生成动作。这自然地处理了多模态动作分布(见下图),即存在多个有效的任务完成方式。 + +![多模态动作分布:回归将两条有效路径平均为一条穿过障碍物的无效路径](../images/multimodal_action_distribution.svg) + + - **灵活的观测和动作空间**:Octo为不同的机器人配置使用特定于任务的标记化器。它在Open X-Embodiment数据集上预训练,该数据集包含来自22种不同机器人形态的示范。 + + - **高效微调**:Octo只需100个示范就可以微调到新机器人,使其适用于数据有限的实验室。 + +- **OpenVLA**(斯坦福大学、加州大学伯克利分校)采用微调现有开源VLM(基于Llama)用于机器人技术的方法。它使用7B参数主干网络、均匀动作标记化(每个维度256个箱),并在Open X-Embodiment数据上训练。其优势在于简单性:架构是标准的VLM,动作标记被附加到词汇表中,使其易于使用现有的LLM基础设施进行训练和部署。 + +- **$\\pi_0$**(Physical Intelligence)代表了当前最高水平。它使用预训练的VLM主干网络和**流匹配**动作头(第8章)。流匹配通过学习一个速度场将噪声传输到动作分布来生成动作,产生平滑、时间连贯的动作轨迹。$\\pi_0$展示了卓越的通用性,在多种机器人形态(包括双臂操作和灵巧手控制)上执行任务。 + +## 预训练配方 + +- VLA极大地受益于预训练的VLM主干网络,这些网络已经理解视觉场景和语言。训练流程通常分为几个阶段: + + 1. **VLM预训练**:在数十亿来自互联网的图像-文本对(CLIP、SigLIP、LLaVA风格的训练,如第10章所述)上训练(或使用现成的)视觉-语言模型。 + + 2. **机器人数据协同训练**:在互联网数据和机器人示范数据的混合上微调VLM。互联网数据防止视觉和语言理解的灾难性遗忘,而机器人数据教授动作生成。混合比例很重要:机器人数据过多会降低语言理解,过少则无法学习动作。 + + 3. **特定任务微调**:可选地在特定任务或机器人的示范上进行微调,通常使用LoRA(第10章)保持可训练参数数量较少。 + +- 机器人数据的数量比互联网数据少数个数量级。VLM可能在上数十亿张图像上预训练,但最大的机器人数据集(Open X-Embodiment)在所有形态上只有数百万帧。这种数据稀缺性正是从预训练VLM开始至关重要的原因:视觉和语言表示可以迁移,只有动作映射需要从有限的机器人数据中学习。 + +## 泛化能力 + +- VLA的承诺是**泛化**:执行训练中未见的任务,使用未见过的物体,在未见过的环境中,遵循未见过的指令。 + +- VLA沿多个轴进行泛化: + + - **新颖物体**:VLM主干网络从互联网预训练中识别物体。如果模型从网络图像中知道"螺丝刀"长什么样,即使没有机器人示范涉及螺丝刀,它也能操作螺丝刀。 + + - **新颖指令**:组合语言理解使模型能够遵循已知概念的新组合。"将蓝色方块堆叠在绿色方块上"即使训练只展示了堆叠红色方块也能工作,因为模型从语言预训练中理解了颜色形容词。 + + - **新颖环境**:在一定程度上,VLA跨视觉域(不同的桌子、光照、背景)迁移,因为视觉编码器在多样化的网络图像上预训练。但这有局限性:在实验室训练的机器人可能在杂乱厨房中遇到困难。 + + - **新颖形态**:这是最难的轴。不同机器人有不同的动作空间(关节角度 vs. 末端执行器速度)、不同的传感器(腕部相机 vs. 头顶相机)和不同的物理能力。与形态无关的模型如Octo和$\\pi_0$通过灵活的标记化器和跨多种机器人类型的预训练来解决这一问题。 + +- 泛化能力通过**保留任务**进行评估:机器人被要求执行从未训练过的任务。在新颖任务上50-80%的成功率被认为是强劲的结果,而在分布内任务上成功率通常>90%。随着模型规模扩大和机器人数据集增长,这一差距正在缩小。 + +## 与形态无关的模型 + +- 该领域正朝着"一个模型,多种机器人"的方向发展。不再为每个机器人训练单独的策略,而是单个VLA处理多种形态。 + +- 这需要解决**动作空间不匹配**问题。一个7自由度手臂带平行夹爪有7个动作维度。双臂设置是14个。四足机器人有12个。类人机器人有30个以上。动作标记化必须足够灵活以处理所有这些。 + +- 解决方案包括: + - **填充动作向量**:使用最大的动作空间,较小的用零填充。 + - **每种形态的动作头**:共享的transformer主干网络,每种机器人类型有单独的小型MLP。 + - **归一化动作表示**:在共同框架中表示所有动作(如世界坐标系中的末端执行器速度),使产生类似末端执行器运动的不同机器人共享相同的动作标记。 + +- 共享主干网络学习通用的视觉和语言理解,加上通用的操作策略(从上方接近、对齐物体、闭合夹爪)。特定于形态的组件只需要将这些高层策略转化为具体的电机命令。 + +## 基准测试与评估 + +- 评估VLA具有独特的挑战性,因为它需要物理机器人实验(或高保真仿真)。 + +- **SIMPLER**(机器人学习模拟操作策略评估)提供了标准化的仿真环境,无需物理硬件即可比较VLA性能。它与现实世界的成功率相关性良好,并实现了可复现的基准测试。 + +- **现实世界评估**仍然是金标准。典型协议: + 1. 定义一组具有明确成功标准的任务(物体到达目标位置、选择正确物体、在时限内完成任务)。 + 2. 每次任务运行$N$次试验(通常10-50次)。 + 3. 报告成功率及置信区间。 + 4. 包括保留任务(从未训练过的)以衡量泛化能力。 + +- **Open X-Embodiment**数据集和基准测试汇总了来自22个机构、跨越多个机器人平台的机器人数据。它提供了共享示范的标准格式和用于跨形态迁移的通用评估套件。 + +## 编程任务(使用CoLab或notebook) + +1. 实现动作标记化:将连续动作离散化为箱并重建。观察量化误差随箱数量的变化。 +```python +import jax.numpy as jnp + +# 连续动作:7个维度(6自由度+夹爪) +action_true = jnp.array([0.023, -0.051, 0.012, 0.1, -0.03, 0.005, 0.8]) +action_min = jnp.array([-0.1, -0.1, -0.1, -0.5, -0.5, -0.5, 0.0]) +action_max = jnp.array([ 0.1, 0.1, 0.1, 0.5, 0.5, 0.5, 1.0]) + +for n_bins in [16, 64, 256, 1024]: + # 标记化:将连续值映射为箱索引 + normalised = (action_true - action_min) / (action_max - action_min) + tokens = jnp.clip((normalised * n_bins).astype(int), 0, n_bins - 1) + + # 去标记化:将箱索引映射回连续值 + reconstructed = (tokens + 0.5) / n_bins * (action_max - action_min) + action_min + + error = jnp.linalg.norm(action_true - reconstructed) + print(f"箱数={n_bins:4d} 标记={tokens} 误差={error:.6f}") +``` + +2. 模拟动作分块与单步预测的比较。生成平滑轨迹,向单步预测添加噪声,并与分块预测比较。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 真实平滑轨迹(例如,伸手动作) +t = jnp.linspace(0, 2 * jnp.pi, 100) +gt_x = jnp.sin(t) +gt_y = 1 - jnp.cos(t) + +# 单步:每次预测有独立噪声 +rng = jax.random.PRNGKey(42) +noise_ss = jax.random.normal(rng, (100, 2)) * 0.05 +single_step = jnp.stack([gt_x, gt_y], axis=1) + noise_ss +# 单步误差累积漂移 +single_step_cumulative = jnp.cumsum(noise_ss, axis=0) * 0.3 + jnp.stack([gt_x, gt_y], axis=1) + +# 分块(块大小=10):块内噪声关联,更平滑 +chunk_size = 10 +rng2 = jax.random.PRNGKey(7) +chunks = [] +for i in range(0, 100, chunk_size): + chunk_noise = jax.random.normal(jax.random.fold_in(rng2, i), (2,)) * 0.05 + chunk = jnp.stack([gt_x[i:i+chunk_size], gt_y[i:i+chunk_size]], axis=1) + chunks.append(chunk + chunk_noise) +chunked = jnp.concatenate(chunks, axis=0) + +plt.figure(figsize=(8, 4)) +plt.plot(gt_x, gt_y, "k-", linewidth=2, label="真实轨迹") +plt.plot(single_step_cumulative[:, 0], single_step_cumulative[:, 1], + "r-", alpha=0.7, label="单步(漂移)") +plt.plot(chunked[:, 0], chunked[:, 1], "b-", alpha=0.7, label="分块(稳定)") +plt.legend(); plt.axis("equal"); plt.grid(True) +plt.title("动作分块 vs 单步预测") +plt.show() +``` + +3. 可视化VLA动作分布如何是多模态的。使用简单的2D高斯混合来展示为什么扩散/流匹配动作头优于回归。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 绕过障碍物的两种有效方式:左边或右边 +rng = jax.random.PRNGKey(0) +k1, k2 = jax.random.split(rng) + +mode1 = jax.random.normal(k1, (200, 2)) * 0.15 + jnp.array([-1.0, 0.5]) +mode2 = jax.random.normal(k2, (200, 2)) * 0.15 + jnp.array([ 1.0, 0.5]) +samples = jnp.concatenate([mode1, mode2]) + +# 回归预测均值 = 模态的均值(无效!) +mean_pred = samples.mean(axis=0) + +plt.figure(figsize=(6, 5)) +plt.scatter(samples[:, 0], samples[:, 1], s=5, alpha=0.5, label="真实动作分布") +plt.plot(*mean_pred, "rx", markersize=15, markeredgewidth=3, label="回归均值(无效!)") +plt.plot(-1, 0.5, "g^", markersize=12, label="模态1(向左)") +plt.plot(1, 0.5, "b^", markersize=12, label="模态2(向右)") +plt.legend(); plt.grid(True) +plt.title("多模态动作:为什么回归失败") +plt.xlabel("动作维度1"); plt.ylabel("动作维度2") +plt.show() +``` diff --git a/chapter 11: autonomous systems/04. self-driving.md b/chapter 11: autonomous systems/04. self-driving.md new file mode 100644 index 0000000..d1a6e17 --- /dev/null +++ b/chapter 11: autonomous systems/04. self-driving.md @@ -0,0 +1,312 @@ +# 自动驾驶汽车 + +*自动驾驶汽车是商业上最先进的自主系统,将感知、预测、规划和控制集成到单个车辆中。本章涵盖自动驾驶堆栈、高精地图、运动预测、规划、端到端驾驶、仿真、安全标准和自主等级。* + +- 自动驾驶汽车可以说是正在大规模尝试的最困难的机器人问题。与在受控环境中运行的工厂机器人不同,自动驾驶汽车必须处理一个开放世界:不可预测的人类驾驶员、乱穿马路的行人、一夜之间出现的施工区域以及每分钟都在变化的天气。 + +- 其风险也异常之高。自动驾驶汽车在高速公路上行驶,周围是脆弱的道路使用者。对于安全关键的故障,误差容限几乎为零。 + +## 自动驾驶堆栈 + +- 经典的自动驾驶架构是一个**模块化流水线**,包含四个阶段,每个阶段作为下一个阶段的输入: + +$$\\text{感知} \\to \\text{预测} \\to \\text{规划} \\to \\text{控制}$$ + +![自动驾驶堆栈:传感器流入感知,感知流入预测、规划,最后流入控制](../images/autonomous_driving_stack.svg) + +- **感知**(本章文件1中已介绍)将原始传感器数据处理为结构化的场景表示:带有3D位置、速度和类别标签的检测物体;车道标线;交通信号灯;可行驶表面边界。 + +- **预测**预测其他交通参与者(车辆、行人、骑行者)未来将如何移动。给定场景的当前状态,预测模块为每个交通参与者输出未来一段时间(通常3-8秒)的轨迹。 + +- **规划**决定主车应该做什么:走哪条路径、何时变道、何时让行、何时加速或刹车。它接收预测的场景,为主车生成一条安全、舒适且向目的地前进的轨迹。 + +- **控制**将规划的轨迹转化为执行器命令:转向角、油门和刹车。这是最底层,将抽象轨迹转化为物理运动。 + +- 模块化设计有明确的工程优势:每个模块可以独立开发、测试和改进。但它也有弱点:误差向下游传播(漏检对规划器是不可见的),并且信息在每个接口处丢失(规划器看到的是边界框,而不是产生它们的丰富传感器数据)。 + +## 高精地图 + +- **高精(HD)地图**是详细、厘米级精度的数字地图,编码道路结构:车道边界、车道连通性(哪个车道在交叉口连接到哪个)、交通标志位置、限速、人行横道位置和路面高程。 + +- 高精地图为驾驶任务提供了强有力的先验。感知模块不需要每帧从头发现车道边界;它只需要将车辆在地图中进行定位,并验证现实是否与存储的结构匹配。这极大地简化了规划。 + +- 构建高精地图需要配备高端LiDAR、相机和RTK-GPS的专业测量车辆。地图必须随着道路变化而维护和更新。这很昂贵,且不容易扩展到地球上的每条道路。 + +- **无图驾驶**(也称为"在线地图构建")旨在消除对预建高精地图的依赖。车辆从传感器实时构建局部地图。像**MapTR**和**MapTRv2**这样的模型使用transformer架构直接从相机图像预测矢量化地图元素(车道中心线、道路边界、人行横道),将多段线输出为有序点序列。 + +- 无图方法用地图精度换取可扩展性:任何汽车能行驶的道路,它都能建图。但它要求感知系统足够鲁棒,能够实时检测所有相关的道路结构,包括在复杂交叉口、高速公路匝道和施工区域中。 + +- 在实践中,许多系统采用混合方法:轻量级地图包含粗略的道路拓扑(来自现有地图提供商),并通过车辆的传感器实时丰富。 + +## 运动预测 + +- 预测其他道路使用者将去哪里是自动驾驶中最困难的子问题之一。人类不可预测,意图是隐藏的,未来可能性的空间迅速分叉。 + +- 预测模型的输入是**场景上下文**:所有检测到的参与者在近期过去(通常1-2秒的历史)的位置和速度,加上静态上下文(车道几何、交通信号、道路边界)。 + +- 输出是每个参与者的一组**预测轨迹**,通常覆盖未来3-8秒。由于未来是不确定的,好的预测模型输出多条可能的轨迹及其相关概率,而不是单一的点估计。 + +- **轨迹预测**作为一个回归问题:预测每个参与者在离散未来时间步的$(x, y)$坐标。损失通常是$K$条预测轨迹上的最小平均位移误差(minADE): + +$$\\text{minADE}_K = \\min_{k \\in \\{1, \\ldots, K\\}} \\frac{1}{T} \\sum_{t=1}^{T} \\| \\hat{\\mathbf{p}}_t^{(k)} - \\mathbf{p}_t \\|_2$$ + +- 这是一个"最佳$K$个"指标:如果模型的$K$个预测中有一个接近真实值,模型就得分。这鼓励多样化的多模态预测。 + +- **社会力模型**将行人行为建模为动力系统,其中每个人受到吸引力(朝向目标)和排斥力(远离其他行人和障碍物)。行人$i$的加速度为: + +$$\\mathbf{a}_i = \\frac{\\mathbf{v}_i^{\\text{期望}} - \\mathbf{v}_i}{\\tau} + \\sum_{j \\neq i} \\mathbf{f}_{ij}^{\\text{排斥}} + \\sum_{\\text{墙壁}} \\mathbf{f}_{\\text{墙壁}}$$ + +- 这是一个与本章文件2中的机器人动力学方程类似的微分方程组。该模型优雅但依赖于手工调谐的力参数,并且在复杂多智能体交互中表现不佳。 + +- **图神经网络(GNN)**用于预测时将场景建模为图:每个参与者是一个节点,边表示空间关系(邻近度、共享车道)。节点之间的消息传递捕获交互:"这辆车正在给那个行人让行"或"这两辆车正在汇入同一条车道。" + +- 现代预测架构(例如**MTR**、**QCNet**)使用基于transformer的模型,联合推理参与者历史、地图上下文和参与者之间的交互。参与者通过交叉注意力关注相关的地图特征(当前车道、即将到来的交叉口)和其他参与者(前车、人行横道上的行人)。输出是一组通过自回归生成或混合模型产生的轨迹假设。 + +- **目标条件预测**首先预测参与者可能去哪里(一组候选目标点,如车道端点或交叉口出口),然后预测到达每个目标的轨迹。这将问题分解为"去哪里"(离散的、可管理的)和"怎么去"(给定目标的连续路径),使多模态预测问题更加可解。 + +## 规划 + +- 给定预测的场景,规划器必须为主车生成一条轨迹。这是一个约束优化问题:找到一条安全、舒适、高效且合法的轨迹。 + +- **基于规则的规划器**将驾驶行为编码为一组if-then规则:"如果行人在人行横道上,让行"、"如果与前车距离小于2秒,不变道"、"如果接近红灯,减速停在停止线处。"这些规则是可解释和可审计的,但对于复杂场景(数千条规则、许多边缘情况、规则间的交互),它们变得难以管理。 + +- **基于优化的规划器**将驾驶形式化为轨迹优化。主车轨迹被参数化(例如,作为未来时间步的$(x, y, \\theta, v)$状态序列),并最小化一个目标函数: + +$$\\min_{\\boldsymbol{\\xi}} \\underbrace{w_1 \\cdot J_{\\text{进度}}(\\boldsymbol{\\xi})}_{\\text{到达目的地}} + \\underbrace{w_2 \\cdot J_{\\text{舒适}}(\\boldsymbol{\\xi})}_{\\text{平稳行驶}} + \\underbrace{w_3 \\cdot J_{\\text{安全}}(\\boldsymbol{\\xi})}_{\\text{避免碰撞}}$$ + +$$\\text{约束条件:运动学约束、限速、车道边界}$$ + +- 进度项惩罚偏离期望路线。舒适项惩罚高横向加速度、加加速度(加速度的导数)和突然转向,因为乘客能感受到这些。安全项惩罚与其他交通参与者的接近程度,使用预测轨迹评估碰撞风险。 + +- 这是约束优化(第3章):在不等式约束下最小化代价函数。权重$w_1, w_2, w_3$权衡竞争目标(激进驾驶更快但更不舒适且更不安全)。 + +- **基于学习的规划器**使用在人类驾驶数据上训练的神经网络生成轨迹。模型观察场景并直接输出规划的轨迹,从专家人类驾驶示例中隐式学习复杂的权衡。 + +- 优势在于人类驾驶行为被整体捕获,包括那些微妙且难以形式化的方面:何时激进地合流、何时在交叉口前微微前移、给骑行者留出多少空间。缺点是来自模仿学习(文件2)的相同分布偏移问题:模型在训练数据中未充分代表的情况下可能表现不可预测。 + +## 端到端驾驶 + +- **端到端驾驶**完全消除了模块边界。单个神经网络接收原始传感器输入(相机图像、LiDAR点云)并直接输出驾驶命令(转向、油门、刹车)或规划轨迹。没有独立的感知、预测或规划模块。 + +- 其吸引力在于整个系统针对最终任务(安全驾驶)进行联合优化,因此没有信息在模块边界丢失。感知模块学习精确提取规划器所需的特征,而不是通用的目标检测结果,后者可能不捕获任务相关的细节。 + +- **UniAD**(统一自动驾驶)是一个里程碑式的端到端架构。它通过BEV编码器处理多相机图像,然后应用一系列基于transformer的模块:跟踪、在线建图、运动预测、占据预测和规划。虽然它有内部模块,但它们都是可微的,并端到端联合训练,规划损失通过整个网络反向传播。 + +- UniAD中的规划模块通过关注预测的BEV特征、预测的参与者轨迹和预测的占据来生成未来主车路径点。这就是多元链式法则(第3章)的实际应用:梯度从规划损失一直流回图像编码器,告诉感知特征如何对规划更有用。 + +- 更近期端到端方法使用VLA风格的架构(本章文件3)。像**DriveVLM**这样的模型接收相机图像和导航指令(或路线),并使用VLM主干网络产生驾驶动作。这带来了大规模预训练(视觉理解、推理)的好处,直接融入驾驶堆栈。 + +- 端到端驾驶中的张力是**可解释性**。模块化系统可以报告"我检测到行人在(x,y)处,预测他们会横穿"——故障模式是可诊断的。端到端系统是一个产生转向角的黑盒。当它失败时,诊断原因很困难,这对安全认证是一个严重问题。 + +## 驾驶世界模型 + +- **世界模型**学习在给定当前状态和主车动作的情况下预测驾驶场景的未来状态:$p(s_{t+1} \\mid s_t, a_t)$(如第10章所述)。在驾驶中,这意味着生成逼真的未来帧或BEV布局:"如果我加速并左转,3秒后的场景会是这样。" + +- 世界模型为自动驾驶提供了两种强大能力: + + - **基于想象的规划**:规划器不是先执行一个动作再看结果,而是可以通过世界模型"想象"多条候选轨迹,评估每条的安全性和舒适性,然后选择最佳的一条。这是基于模型的RL(本章文件2中介绍)应用于驾驶。 + + - **学习型仿真**:在真实驾驶数据上训练的世界模型实际上是一个数据驱动的仿真器。它生成逼真的场景(包括罕见的边缘情况),无需手工构建仿真器的工作。关键是,它捕获了真实驾驶的统计模式:其他驾驶员实际如何表现、光照如何变化、雨水如何影响可见度。 + +- **GAIA-1**(Wayve)是一个用于驾驶的生成式世界模型。给定过去相机帧和主车动作的序列,它自回归地生成未来视频帧。它使用以动作为条件的视频扩散架构。模型学习生成合理的未来:遵守交通规则的车辆、在人行道上行走的行人以及正确变化的交通信号灯——都从训练数据中涌现,而非编程规则。 + +- **DriveDreamer**和**GenAD**采取类似方法,但在BEV空间而非像素空间中操作。预测未来BEV布局比生成完整视频帧更紧凑(类似于机器人学中的DreamerV3在潜在空间而非像素空间中进行预测,如文件2所述)。BEV世界模型预测所有参与者的位置、道路结构的样子以及自由空间的位置,规划器直接使用这些信息。 + +- **神经闭环仿真**使用世界模型替代手工构建的仿真器进行测试。给定真实驾驶日志作为起点,世界模型生成如果主车采取了不同动作会发生什么。这使得反事实评估成为可能:"如果我刹车晚了0.5秒会怎样?"而无需实际重现场景。 + +- 与**JEPA**框架(第10章)的联系在这里很自然。驾驶世界模型不需要预测像素级完美的未来(每个像素的精确RGB值)。它们需要预测对规划重要的方面:参与者在哪、移动速度多快、自由空间在哪。嵌入空间预测(JEPA风格)捕获这些语义上有意义的属性,而无需浪费容量在无关的视觉细节上,如确切的云纹理。 + +- 主要挑战是**长时程保真度**。世界模型随时间累积误差:第2帧的一个小错误会偏移所有后续帧。对于驾驶,3秒的预测时域对战术决策有用(我应该现在合流吗?),但30秒的时域(用于路线规划等战略决策所需)仍然不可靠。当前工作通过重新锚定(定期用真实观测重置模型)和不确定性估计(在预测变得不可靠时标记)来缓解这一问题。 + +## 仿真 + +- 通过在真实道路上驾驶来测试自动驾驶汽车是必要的,但还不够。危险场景(近碰撞、边缘情况)很少见,因此通过行驶里程来测试效率低下。一辆车需要行驶数亿英里才能以统计学方式证明安全性,这是不可行的。 + +- **仿真**提供了无限、可控且安全的测试。在现实世界中罕见的场景(一个孩子跑上马路、轮胎爆胎、突然的障碍物)可以在仿真中测试数百万次。 + +- **CARLA**是一个基于Unreal Engine构建的开源驾驶仿真器。它提供逼真的城市环境、动态天气、交通参与者以及传感器仿真(相机、LiDAR、雷达)。研究人员使用CARLA训练基于RL的驾驶智能体并评估感知算法。 + +- **nuPlan**(Motional)是一个闭环规划基准测试。与开环评估(重放记录数据,比较规划器的输出与人类驾驶员的实际轨迹)不同,闭环评估允许规划器的决策影响仿真:如果规划器决定变道,仿真会相应地演变。这测试了反应性行为,而不仅仅是轨迹相似性。 + +![开环重放日志,无交互;闭环让模型的动作改变仿真状态](../images/open_vs_closed_loop.svg) + +- **开环**和**闭环**评估之间的区别至关重要: + + - 开环:重放记录的场景,计算模型输出与人类驾驶员动作的相似度。这容易设置但具有误导性:一个总是预测"直行"的模型在高速公路上误差可能很低,但在第一个转弯处就会撞车。 + + - 闭环:模型的动作改变仿真状态,仿真相应地演变。这测试了模型从自身错误中恢复和响应动态情况的能力。它昂贵得多,但更有意义。 + +- **场景生成**创建对系统进行压力测试的测试用例。对抗性场景(车辆突然刹车、行人隐藏在停放的汽车后面)通过优化使自动驾驶系统表现最差的情况来生成。这与ML中的对抗训练(第6章)有关:寻找最大化损失的输入。 + +## 安全性 + +- 自动驾驶中的安全性由工程标准而非仅ML指标来管理。 + +- **ISO 26262**(功能安全)是安全关键电子系统的汽车标准。它根据潜在危害的严重性、暴露度和可控性定义了**汽车安全完整性等级(ASIL)**,从A(最低)到D(最高)。自动驾驶系统的感知和规划组件通常为ASIL-D,即最高等级,需要广泛的验证、冗余和故障安全设计。 + +- **SOTIF**(预期功能安全,ISO 21448)处理另一类危害:不是硬件故障(ISO 26262覆盖的),而是系统按设计工作但仍产生不安全结果的情况。一个将白色卡车误分类为天空的感知模型(真实事件)是SOTIF问题:硬件工作正常,但算法的局限性导致了危害。 + +- **运行设计域(ODD)**定义了自动驾驶系统设计用于运行的条件:特定的地理区域、道路类型(仅高速、城市道路、两者兼有)、天气条件(无大雪)、速度范围和时间段。不允许在ODD之外运行:如果系统不能处理雪,就不能在雪中驾驶。 + +- **故障安全** vs **故障可操作**设计: + - 故障安全:检测到故障时,系统过渡到安全状态(例如,靠边停车)。这是最低要求。 + - 故障可操作:系统在故障情况下仍能安全运行,使用冗余组件。具有冗余转向、制动和计算的自动驾驶汽车可以在单个组件故障后存活并仍然行驶到安全位置。 + +- **冗余**是基础。关键感知传感器被复制:多个相机覆盖重叠视场、LiDAR和雷达同时提供独立的深度测量、双计算平台运行相同的软件。如果任何单个组件发生故障,其他组件提供足够的信息来安全驾驶。 + +## 自动驾驶等级 + +![SAE自动驾驶等级从L0(无自动化)到L5(完全自动化),显示责任从人类转移到系统的过程](../images/sae_autonomy_levels.svg) + +- **SAE J3016**标准定义了六个驾驶自动化等级,从0(无自动化)到5(完全自动化): + + - **等级0(无自动化)**:人类做所有事情。系统可能提供警告(车道偏离警报)但不控制车辆。 + + - **等级1(驾驶辅助)**:系统控制转向或速度,但不能同时控制两者。自适应巡航控制(保持速度和跟车距离)或车道保持辅助(使车辆保持在车道中央)属于等级1。 + + - **等级2(部分自动化)**:系统同时控制转向和速度,但人类必须时刻监控并准备接管。特斯拉Autopilot、GM Super Cruise和大多数当前的"自动驾驶"功能属于等级2。人类仍然是负责的驾驶员。 + + - **等级3(条件自动化)**:系统驾驶并监控环境,但仅在特定条件下(ODD内)。人类可以脱离关注,但必须准备好在系统请求时接管(有时间缓冲,通常10秒以上)。Mercedes Drive Pilot(特定高速公路上,低于60公里/小时)是第一个经认证的等级3系统。 + + - **等级4(高度自动化)**:系统在ODD内驾驶并处理所有情况,无需人类干预。如果遇到ODD之外的情况,它可以安全地自己停车。Waymo的机器人出租车服务在特定地理区域内以等级4运行。 + + - **等级5(完全自动化)**:系统能在人类能去的一切地方、一切条件下驾驶。无需方向盘或踏板。这目前还不存在。 + +- 关键区别在于**谁对安全负责**。在等级0-2,人类负责。在等级3-5,系统负责(在其ODD内)。这具有深远的法律、保险和伦理影响。 + +- 当前行业状态是等级2(广泛部署)、等级3(开始部署)和等级4(有限地理部署)的混合。等级5仍然是一个长期研究目标。 + +## 编程任务(使用CoLab或notebook) + +1. 实现一个简单的轨迹优化规划器。给定起始位置、目标和障碍物,使用梯度下降找到最平滑的无碰撞路径。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 轨迹:N个路径点,每个(x, y) +N = 20 +start = jnp.array([0.0, 0.0]) +goal = jnp.array([10.0, 0.0]) +obstacle = jnp.array([5.0, 0.0]) +obs_radius = 1.5 + +# 初始化:从起点到终点的直线 +waypoints_init = jnp.linspace(start, goal, N) + +def cost(waypoints): + wp = jnp.concatenate([start[None], waypoints, goal[None]], axis=0) + + # 平滑度:惩罚加速度(二阶差分) + accel = wp[2:] - 2 * wp[1:-1] + wp[:-2] + smooth_cost = jnp.sum(accel ** 2) + + # 避障:惩罚接近度 + dists = jnp.linalg.norm(wp - obstacle, axis=1) + collision_cost = jnp.sum(jnp.maximum(0, obs_radius + 0.5 - dists) ** 2) + + return 10 * smooth_cost + 100 * collision_cost + +grad_cost = jax.grad(cost) + +# 优化内部路径点 +waypoints = waypoints_init[1:-1] +lr = 0.01 +for _ in range(500): + g = grad_cost(waypoints) + waypoints = waypoints - lr * g + +# 绘图 +full_path = jnp.concatenate([start[None], waypoints, goal[None]], axis=0) +theta = jnp.linspace(0, 2 * jnp.pi, 100) + +plt.figure(figsize=(10, 4)) +plt.plot(full_path[:, 0], full_path[:, 1], "b.-", label="优化后路径") +plt.plot(waypoints_init[:, 0], waypoints_init[:, 1], "r--", alpha=0.5, label="初始(直线)") +plt.fill(obstacle[0] + obs_radius * jnp.cos(theta), + obstacle[1] + obs_radius * jnp.sin(theta), alpha=0.3, color="red", label="障碍物") +plt.plot(*start, "go", markersize=10); plt.plot(*goal, "g*", markersize=15) +plt.legend(); plt.axis("equal"); plt.grid(True) +plt.title("轨迹优化:平滑无碰撞路径") +plt.show() +``` + +2. 模拟一个匀速运动预测模型,并与转弯车辆的真实值比较。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 真实值:车辆右转 +dt = 0.1 +T = 40 # 4秒 +v = 10.0 # 米/秒 +omega = 0.3 # 弧度/秒(转弯速率) + +# 真实轨迹(恒定转弯速率) +t = jnp.arange(T) * dt +theta = omega * t +gt_x = (v / omega) * jnp.sin(theta) +gt_y = (v / omega) * (1 - jnp.cos(theta)) + +# 从t=0开始的匀速预测 +# 假设车辆沿当前航向继续直行 +obs_steps = 10 # 观察前1秒 +vx0 = v * jnp.cos(theta[obs_steps - 1]) +vy0 = v * jnp.sin(theta[obs_steps - 1]) +pred_t = jnp.arange(T - obs_steps) * dt +pred_x = gt_x[obs_steps - 1] + vx0 * pred_t +pred_y = gt_y[obs_steps - 1] + vy0 * pred_t + +plt.figure(figsize=(8, 6)) +plt.plot(gt_x[:obs_steps], gt_y[:obs_steps], "ko-", label="已观测") +plt.plot(gt_x[obs_steps:], gt_y[obs_steps:], "g-", linewidth=2, label="真实未来") +plt.plot(pred_x, pred_y, "r--", linewidth=2, label="匀速预测") +plt.legend(); plt.axis("equal"); plt.grid(True) +plt.xlabel("x (米)"); plt.ylabel("y (米)") +plt.title("匀速预测 vs 转弯车辆") +plt.show() +``` + +3. 实现一个简单的基于规则的规划器,根据检测到的障碍物决定保持车道还是停车。 +```python +import jax.numpy as jnp + +def rule_based_planner(ego_speed, obstacles, speed_limit=13.9): + """ + 简单的基于规则的规划器。 + ego_speed: 当前速度(米/秒) + obstacles: 前方车辆的(距离,速度)元组列表 + speed_limit: 最高允许速度(米/秒),默认约50公里/小时 + + 返回:(目标速度,动作标签) + """ + min_following_distance = 2.0 * ego_speed # 2秒规则 + emergency_distance = 5.0 # 米 + + if not obstacles: + return speed_limit, "巡航" + + # 找到最近的前方障碍物 + closest_dist, closest_speed = min(obstacles, key=lambda o: o[0]) + + if closest_dist < emergency_distance: + return 0.0, "紧急停车" + elif closest_dist < min_following_distance: + # 匹配前车速度 + target = min(closest_speed, speed_limit) + return target, "跟随" + else: + return speed_limit, "巡航" + +# 测试场景 +scenarios = [ + (13.9, [], "空旷道路"), + (13.9, [(30.0, 10.0)], "前方有较慢车辆"), + (13.9, [(3.0, 0.0)], "前方有停靠车辆,距离极近"), + (13.9, [(50.0, 13.9)], "前方车辆同速行驶"), +] + +for speed, obs, desc in scenarios: + target, action = rule_based_planner(speed, obs) + print(f"{desc:30s} → {action:15s} 目标速度={target:.1f} 米/秒 ({target*3.6:.0f} 公里/小时)") +``` diff --git a/chapter 11: autonomous systems/05. space and extreme robotics.md b/chapter 11: autonomous systems/05. space and extreme robotics.md new file mode 100644 index 0000000..3087394 --- /dev/null +++ b/chapter 11: autonomous systems/05. space and extreme robotics.md @@ -0,0 +1,311 @@ +# 太空与极端环境机器人 + +*太空和极端环境机器人将自主性推向极限——通信延迟、辐射和非结构化地形要求机器人自己思考。本章涵盖行星漫游车、在轨服务、通信受限自主性、抗辐射计算、水下机器人、搜索救援、群体机器人和人机交互。* + +- 在本章中,我们研究了在相对温和环境中运行的自主系统:有车道标线的道路、有平坦地板的地板、有已知物体类别的厨房。但机器人技术的一些最具影响力的应用是在人类无法到达的环境,或者人类存在的成本极高的环境:火星表面、深海海底、核灾难现场和燃烧的建筑。 + +- 这些**极端环境**面临着共同的挑战:通信受限或有延迟、地形非结构化且不可预测、硬件必须在恶劣条件下生存、而且附近没有人能在出现问题时修理。机器人必须真正自主,而不仅仅是"有人在屏幕前监控的自主"。 + +## 太空机器人 + +- 太空是终极的极端环境。没有空气,温度在-170°C到+120°C之间摆动,辐射轰击电子设备,而援助在数百万公里之外。太空机器人必须异常可靠、节能且自主。 + +- **行星漫游车**是在其他世界表面探索的移动机器人。NASA的火星漫游车(勇气号、机遇号、好奇号、毅力号)是最著名的例子。每一代都比上一代更加自主。 + +![地球-火星通信延迟:单向4-24分钟,往返8-48分钟,使得实时控制不可能](../images/earth_mars_delay.svg) + +- 根本限制是**通信延迟**。火星距地球4-24分钟的无线电距离(取决于轨道位置),因此往返通信需要8-48分钟。漫游车不能实时操控。如果遇到岩石,它不能向地球求助并等待回应。它必须自己决定。 + +- 早期的漫游车(勇气号、机遇号)严重依赖地面参与的规划:人类研究图像、规划路径、上传命令,漫游车执行命令。一个驾驶周期需要一个完整的火星日。漫游车每天大约能行进50-100米。 + +- 好奇号和毅力号上的**AutoNav**(自主导航)极大地提高了自主性。漫游车使用立体相机构建局部3D地图(回顾第8章的立体深度),评估地形可通过性(坡度、粗糙度、岩石大小),并使用基于网格的规划器和可通过性代价图规划安全路径。漫游车在人类团队睡眠时自主行驶,将每日行进距离提高到100米以上。 + +- 火星漫游车上的感知流程受到抗辐射处理器的限制,这些处理器比消费级硬件慢几个数量级(下文讨论)。算法必须计算节俭:经典的立体匹配而非深度神经网络,简单的代价图规划器而非学习型策略。 + +- **在轨服务**涉及在轨道上检查、修理、加油或使卫星脱离轨道的机器人。随着太空变得更加拥挤,这是一个不断增长的领域。**OSAM-1**(NASA)和商业企业(Astroscale、Northrop Grumman MEV)等任务使用机械臂和对接机构来服务卫星。 + +- 挑战在于**近距离操作**:服务航天器必须接近目标卫星(可能正在翻滚、不合作且缺乏对接接口),并在微重力下执行精确操作。基于视觉的位姿估计(从相机图像确定目标的3D位置和方向)至关重要。这使用了第8章的技术:特征检测、PnP(透视n点)求解,以及最近基于深度学习的位姿估计器。 + +- **卫星检查**使用小型航天器目视检查其他卫星是否有损坏或异常。检查者必须自主绕目标导航、避免碰撞并从最佳视角捕获高分辨率图像。这是一个规划问题:找到覆盖所有检查点且满足燃料约束、光照条件和避碰要求的轨迹。 + +## 通信约束 + +- 在太空中,通信受到光速、可用带宽和轨道几何的限制(火星背面的漫游车在没有中继卫星的情况下根本无法与地球通信)。 + +- 这些限制从根本上改变了自主性架构。在地球上,机器人可以将高清视频流传输到云服务器,在GPU集群上运行推理,并在毫秒内接收指令。在太空中,机器人必须在飞行器上完成所有工作。 + +- **高延迟**意味着机器人必须在没有实时人类指导的情况下规划和行动。自主软件必须处理常规操作、检测异常并响应危险,而无需等待人类输入。这需要鲁棒的板载状态估计、故障检测和应急规划。 + +- **有限带宽**意味着机器人无法传输原始传感器数据。一张高分辨率图像可能有几兆字节,但火星到地球的数据速率通过直接对地链路只有每秒几千比特(通过轨道中继更高,但仍然有限)。机器人必须积极压缩数据、优先决定发送哪些数据,并在本地做出大部分决策。 + +- **通信窗口**是间歇性的。火星漫游车只能在特定轨道几何形状期间与地球通信,通常每个火星日通过中继卫星只有几小时。在这些窗口之外,漫游车完全靠自己。 + +- 对AI的影响是**板载自主性**必须非常可靠。系统需要检测是否出了问题(轮子卡住了、传感器故障了、前方地形无法通行),决定安全响应,并继续运行直到下一个通信窗口,届时它可以报告并接收更新指令。 + +## 抗辐射计算 + +- 太空中充满了电离辐射:宇宙射线、太阳粒子事件以及行星磁场中的捕获辐射。高能粒子可以翻转存储器中的比特(**单粒子翻转,SEU**),永久损坏晶体管(**总电离剂量,TID**),或在电路中引起破坏性闩锁。 + +- **抗辐射处理器**被设计为承受这种环境。它们使用更大的晶体管几何尺寸、冗余逻辑(三模冗余:每个电路有三个副本对输出进行投票)和专门的制造工艺。代价是性能:最先进的抗辐射处理器可能提供200 MIPS,而消费级GPU每秒可执行数十亿次操作。 + +- **RAD750**(BAE Systems)为好奇号和许多其他航天器提供动力。它以200 MHz运行,约400 MIPS的处理能力,相当于1990年代中期的台式电脑。毅力号使用类似等级的处理器。在现代神经网络上运行(数百万参数、数十亿次乘加运算)在这样的硬件上是不可行的。 + +- **模型压缩**变得至关重要。第6章的技术(量化、剪枝、知识蒸馏)用于缩小神经网络以适应极端的计算预算。在笔记本电脑GPU上毫秒级运行的模型可能在抗辐射处理器上需要数分钟,或者根本无法装入内存。 + +- 另一种方法使用**商用现货**处理器,配合软件中的辐射缓解措施:纠错码、看门狗定时器、定期内存清理和优雅降级策略。一些现代任务使用这种方法以获得更强大的计算能力,代价是增加了软件复杂性和风险。 + +- 未来的行星任务正在探索**FPGA**和专门的AI加速器,它们可以具有耐辐射性,同时提供比传统抗辐射CPU多得多的计算能力,可能首次实现板载深度学习。 + +## 非结构化地形中的自主导航 + +- 在地球上,道路平坦、标记清晰且有地图。在火星、月球或灾难现场,没有道路。地形是非结构化的:岩石、斜坡、沙地、裂缝和可能无法支撑机器人重量的表面。 + +- **地形分类**评估每块地面是否安全通行。特征包括坡度(来自3D重建)、粗糙度(表面法线的方差)、岩石密度和土壤类型。经典方法从立体深度图计算这些特征;现代方法在视觉和几何特征上使用学习型分类器。 + +- **视觉-惯性里程计(VIO)**通过跟踪跨相机帧的视觉特征并与IMU测量融合来估计机器人的运动。这是SLAM的核心组件(第8章),针对极端条件进行了调整。在火星上,VIO必须处理:无特征的沙地地形(几乎没有可跟踪的视觉特征)、强烈的光照(极端阴影)和有限的计算能力。 + +- 估计过程使用**扩展卡尔曼滤波(EKF)**或因子图优化融合视觉和惯性数据。状态向量包括位置、速度、方向和IMU偏差。预测步骤使用IMU积分: + +$$\\mathbf{x}_{t+1} = f(\\mathbf{x}_t, \\mathbf{u}_t)$$ + +- 其中$\\mathbf{u}_t$是IMU测量值(加速度和角速度)。更新步骤使用视觉特征观测来校正预测。这是贝叶斯估计(第5章):IMU提供先验,视觉观测更新信念。 + +- **危险规避**在行星着陆过程中至关重要。当航天器下降向表面时,它必须使用板载相机或LiDAR实时识别安全的着陆区。NASA毅力号上的**地形相对导航(TRN)**系统将板载相机图像与预加载的轨道地图进行比较,以确定下降过程中的位置,然后避开危险地形。这使得在Jezero陨石坑着陆成为可能——一个科学丰富但地形危险的站点,对于以前的 missions 来说风险太大。 + +## 水下机器人 + +- 深海与太空一样陌生:压碎性压力(全海深1000+个大气压)、接近零能见度、无GPS和有限的通信。水下机器人对海洋科学、近海基础设施检查、深海采矿和搜索操作至关重要。 + +- **AUV**(自主水下航行器)无缆运行,携带自己的电力和计算资源。它们遵循预设的测量模式或使用板载智能来适应发现。AUV用于海底测绘、管道检查和环境监测。 + +- **ROV**(遥控水下航行器)通过电缆连接到水面船只,提供电力和通信。它们用于需要实时人类控制的任务:深海操作、建造和修理。缆线消除了通信限制,但限制了范围并增加了操作复杂性。 + +- **声学通信**是主要的水下通信方法(无线电波在水中迅速衰减)。声学调制解调器在几公里范围内达到1-10 kbps的数据速率,而陆地上无线电可达吉比特每秒。这甚至比火星通信更加受限,迫使AUV高度自主。 + +- **水下SLAM**尤其具有挑战性。声纳提供距离测量,但角分辨率差且噪声大(来自海底和水面的多径反射)。相机只能在非常短的距离内工作(清澈水中几米,浑浊条件下更短)。基于特征的可视SLAM(第8章)必须针对水下场景的独特视觉特征进行调整:颜色衰减(红光首先被吸收)、反向散射以及产生亮斑和深影的人工照明。 + +- 无GPS导航使用**航位推算**(积分来自多普勒测速仪DVL的速度,该仪器利用声学多普勒频移测量相对于海底的速度),辅以偶尔浮出水面获取GPS定位或来自水面应答器的声学定位。这与仅IMU导航相同的漂移问题:小的速度误差在长任务中累积。 + +## 搜索救援机器人 + +- 在地震、建筑物倒塌或工业事故后,机器人可以进入对人类救援人员太危险的区域:结构不稳定的建筑物、有毒环境、火场或密闭空间。 + +- 需求是:快速部署(几分钟而非几小时)、在GPS受限环境中运行(建筑物内部、地下)、通过墙壁和瓦砾的鲁棒通信,以及导航高度杂乱、部分坍塌空间的能力,这些空间充满碎片、灰尘和不良照明。 + +- **多机器人协调**在搜索救援中很有价值,因为一支机器人团队可以比单个机器人更快地覆盖大面积。挑战在于协调:机器人必须划分搜索区域、避免重复工作并共享发现。 + +- **前沿探索**将机器人分配到已探索和未探索空间之间的边界("前沿")。每个机器人导航到最近的未探索前沿、绘制地图并继续前进。中央或分布式规划器将前沿分配给机器人以最小化总探索时间。这是一个覆盖优化问题。 + +- 通过瓦砾的通信不可靠。机器人可能失去与控制台和彼此的联系。系统必须对间歇通信具有鲁棒性:每个机器人应能独立运行,构建自己的局部地图并做出自己的决策,然后在通信恢复时合并信息。 + +## 群体机器人 + +- **群体机器人**使用大量简单、低成本的机器人,通过局部交互实现复杂的集体行为。没有单个机器人单独具备能力,但整个群体可以执行单个机器人无法完成的任务。 + +- 灵感来自生物群体:蚂蚁用身体搭桥、蜜蜂集体决定巢穴位置、鱼群通过协调运动躲避捕食者。在每种情况下,简单的局部规则(跟随邻居、避免碰撞、向食物移动)产生复杂的全局行为。 + +- **去中心化控制**意味着没有中央指挥官。每个机器人遵循相同的局部规则,仅对其邻居和即时环境作出反应。全局行为从这些局部交互中**涌现**。这使得群体具有固有的鲁棒性:如果一个机器人失效,群体继续运行。没有单点故障。 + +- **共识算法**使群体能够仅通过局部通信就集体决策达成一致(例如,向哪个方向移动、优先处理哪个任务)。一个简单的共识协议让每个机器人与其邻居平均其值: + +$$x_i(t+1) = \\frac{1}{|N_i| + 1} \\left( x_i(t) + \\sum_{j \\in N_i} x_j(t) \\right)$$ + +![群体共识:机器人从分散开始,迭代地与邻居平均,收敛到共享位置](../images/swarm_consensus.svg) + +- 其中$N_i$是机器人$i$的邻居集合。这一过程迭代直到所有机器人收敛到相同的值(全局平均值)。收敛速度取决于通信图的拓扑结构,特别是其代数连通性(图拉普拉斯矩阵的第二小特征值,与第2章的特征值相关)。 + +![Reynolds三个群集规则:分离避免碰撞,对齐匹配航向,内聚保持群体](../images/reynolds_flocking.svg) + +- **群集算法**(Reynolds规则)通过每个机器人的三个简单规则产生协调的群体运动: + - **分离**:远离太近的邻居(避免碰撞)。 + - **对齐**:朝向邻居的平均航向(朝相同方向移动)。 + - **内聚**:朝向邻居的平均位置(与群体在一起)。 + +- 每个规则是机器人速度的一个向量贡献。这些向量的加权和产生自然主义的群集行为。这是一个向量的线性组合(第1章),其中权重控制每个行为的相对重要性。 + +- 群体机器人的应用包括环境监测(在大范围内分布传感器)、精准农业(协调无人机进行作物喷洒)、建造(机器人集体组装结构)和搜索操作(高效覆盖大面积)。 + +## 人机交互 + +- 大多数真实的自主系统是与人类并肩运行,而非孤立运行。人与机器人之间的交互——他们如何沟通、共享控制和建立信任——与机器人的技术能力同样重要。 + +![共享自主光谱:从完全人类遥控操作(alpha=1)经混合控制到完全机器人自主(alpha=0)](../images/shared_autonomy_spectrum.svg) + +- **共享自主**混合了人和机器人的控制。不是完全遥控操作(人类控制一切)或完全自主(机器人控制一切),而是共享自主让人类提供高层意图,同时机器人处理底层执行。例如,人类可能指向一个物体说"捡起来",然后机器人自主规划抓取和手臂运动。 + +- 数学上,共享自主可以建模为人类输入$\\mathbf{u}_h$和机器人自主动作$\\mathbf{u}_r$的混合: + +$$\\mathbf{u} = \\alpha \\mathbf{u}_h + (1 - \\alpha) \\mathbf{u}_r$$ + +- 其中$\\alpha \\in [0, 1]$是混合参数。当$\\alpha = 1$时,人类完全控制(遥控操作)。当$\\alpha = 0$时,机器人完全自主。自适应共享自主根据情况调整$\\alpha$:机器人在自信时接管更多控制,在不确定或情况新颖时让出控制。 + +- **遥控操作**对于超出当前自主能力的任务仍然很重要。人类操作员通过机器人的相机远程查看场景并控制机器人。挑战是**延迟**:即使100毫秒的延迟也会使遥控操作变得困难,而太空中的多秒延迟使其对精细操作几乎不可能。预测显示(显示机器人预测的未来状态)和虚拟夹具(防止操作员命令危险运动的软件引导)有助于弥补。 + +- **信任校准**是确保人类对机器人有适当信任的问题:不要太多(过度信任导致自满,在需要时未能干预),也不要太少(信任不足导致不必要干预和利用不足)。信任应该校准到机器人的实际能力:在它处理得好的情况下信任它,在接近其能力边缘的情况下保持怀疑。 + +- 研究表明,信任受以下因素的影响:机器人的透明度(它是否解释其决策?)、可靠性(它是可预测地失败还是随机地失败?)以及沟通(它是否表达不确定性?)。一个说"我对此路径只有40%的置信度,是否继续?"的机器人比一个默默向前驾驶的机器人能做出更好的人类决策。 + +- 机器人运动中的**可读性**意味着机器人以传达其意图的方式运动给附近的人类。如果机器人伸手去拿一个物体,它的路径应该使其目标对象显而易见,即使它还未到达。这涉及规划最大化观察者早期推断目标的轨迹,可以形式化为给定观察到的部分轨迹时真实目标的后验概率最大化: + +$$\\pi^* = \\arg\\max_\\pi P(G \\mid \\xi_{0:t})$$ + +- 其中$G$是目标,$\\xi_{0:t}$是到目前为止观察到的轨迹。这使用了贝叶斯推理(第5章):观察者对可能的目标有先验,机器人的轨迹提供了更新此信念的证据。 + +## 编程任务(使用CoLab或notebook) + +1. 模拟机器人群体就目标位置达成一致的共识算法。从随机初始位置开始,观察收敛过程。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +n_robots = 10 +rng = jax.random.PRNGKey(0) +positions = jax.random.uniform(rng, (n_robots, 2), minval=-5, maxval=5) + +# 通信图:每个机器人与最近的3个邻居通信 +def get_neighbours(positions, k=3): + dists = jnp.linalg.norm(positions[:, None] - positions[None, :], axis=-1) + # 对每个机器人,找最近的k个(排除自身) + neighbours = jnp.argsort(dists, axis=1)[:, 1:k+1] + return neighbours + +history = [positions.copy()] + +for step in range(30): + neighbours = get_neighbours(positions) + new_positions = jnp.zeros_like(positions) + for i in range(n_robots): + nbr_pos = positions[neighbours[i]] + new_positions = new_positions.at[i].set( + (positions[i] + nbr_pos.sum(axis=0)) / (len(neighbours[i]) + 1) + ) + positions = new_positions + history.append(positions.copy()) + +# 绘制收敛过程 +fig, axes = plt.subplots(1, 3, figsize=(15, 4)) +for ax, step_idx, title in zip(axes, [0, 10, 29], ["初始", "第10步", "最终"]): + h = history[step_idx] + ax.scatter(h[:, 0], h[:, 1], s=50) + ax.set_xlim(-6, 6); ax.set_ylim(-6, 6) + ax.set_aspect("equal"); ax.grid(True); ax.set_title(title) +plt.suptitle("群体共识:机器人收敛到一致性") +plt.tight_layout() +plt.show() +``` + +2. 实现Reynolds群集规则(分离、对齐、内聚)并模拟一个群体一起移动。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +n = 30 +rng = jax.random.PRNGKey(1) +k1, k2 = jax.random.split(rng) +pos = jax.random.uniform(k1, (n, 2), minval=-5, maxval=5) +vel = jax.random.uniform(k2, (n, 2), minval=-0.5, maxval=0.5) + +dt = 0.1 +separation_radius = 1.0 +neighbour_radius = 3.0 + +trajectories = [pos.copy()] + +for _ in range(200): + new_vel = jnp.zeros_like(vel) + for i in range(n): + diffs = pos - pos[i] + dists = jnp.linalg.norm(diffs, axis=1) + + # 半径内的邻居(排除自身) + nbr_mask = (dists < neighbour_radius) & (dists > 0) + sep_mask = (dists < separation_radius) & (dists > 0) + + # 分离:远离非常近的邻居 + if sep_mask.any(): + sep = -diffs[sep_mask].sum(axis=0) + else: + sep = jnp.zeros(2) + + # 对齐:匹配邻居的平均速度 + if nbr_mask.any(): + align = vel[nbr_mask].mean(axis=0) - vel[i] + else: + align = jnp.zeros(2) + + # 内聚:朝向邻居的平均位置 + if nbr_mask.any(): + cohesion = pos[nbr_mask].mean(axis=0) - pos[i] + else: + cohesion = jnp.zeros(2) + + new_vel = new_vel.at[i].set(vel[i] + 1.5 * sep + 0.5 * align + 0.3 * cohesion) + + # 限制速度 + speeds = jnp.linalg.norm(new_vel, axis=1, keepdims=True) + vel = jnp.where(speeds > 2.0, new_vel / speeds * 2.0, new_vel) + pos = pos + vel * dt + trajectories.append(pos.copy()) + +# 绘制快照 +fig, axes = plt.subplots(1, 3, figsize=(15, 4)) +for ax, idx, title in zip(axes, [0, 50, 199], ["开始", "第50步", "第200步"]): + p = trajectories[idx] + v = vel if idx == 199 else jnp.zeros_like(vel) + ax.scatter(p[:, 0], p[:, 1], s=20, c="blue") + ax.set_aspect("equal"); ax.grid(True); ax.set_title(title) + lim = max(abs(p).max() + 1, 6) + ax.set_xlim(-lim, lim); ax.set_ylim(-lim, lim) +plt.suptitle("Reynolds群集:分离+对齐+内聚") +plt.tight_layout() +plt.show() +``` + +3. 模拟共享自主混合:人类提供带噪声的方向输入,机器人的自主系统提供到目标的平滑路径。用不同的alpha值进行混合。 +```python +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +goal = jnp.array([10.0, 5.0]) +pos = jnp.array([0.0, 0.0]) +dt = 0.1 + +rng = jax.random.PRNGKey(3) + +fig, axes = plt.subplots(1, 3, figsize=(15, 4)) +for ax, alpha in zip(axes, [1.0, 0.5, 0.0]): + pos = jnp.array([0.0, 0.0]) + path = [pos.copy()] + + for step in range(150): + # 机器人自主:到目标的平滑路径 + direction = goal - pos + u_robot = direction / (jnp.linalg.norm(direction) + 1e-6) * 1.0 + + # 人类输入:大致正确的方向但有噪声 + noise = jax.random.normal(jax.random.fold_in(rng, step), (2,)) * 0.5 + u_human = u_robot + noise + + # 混合 + u = alpha * u_human + (1 - alpha) * u_robot + pos = pos + u * dt + path.append(pos.copy()) + + if jnp.linalg.norm(pos - goal) < 0.3: + break + + path = jnp.stack(path) + ax.plot(path[:, 0], path[:, 1], "b-", alpha=0.7) + ax.plot(*goal, "r*", markersize=15) + ax.plot(0, 0, "go", markersize=10) + ax.set_title(f"α={alpha:.1f} ({'人类' if alpha==1 else '机器人' if alpha==0 else '共享'})") + ax.set_xlim(-1, 12); ax.set_ylim(-3, 8) + ax.set_aspect("equal"); ax.grid(True) + +plt.suptitle("共享自主:混合人类与机器人控制") +plt.tight_layout() +plt.show() +``` diff --git a/chapter 12: graph neural networks/01. geometric deep learning.md b/chapter 12: graph neural networks/01. geometric deep learning.md new file mode 100644 index 0000000..81f53e4 --- /dev/null +++ b/chapter 12: graph neural networks/01. geometric deep learning.md @@ -0,0 +1,170 @@ +# 几何深度学习 + +*几何深度学习是揭示CNN、Transformer和GNN皆遵循同一原理——利用对称性——的统一框架。本章涵盖对称群、群作用、不变性、等变性、五个几何域以及尺度分离* + +- 在本书中,我们已经学习了多种架构:图像的CNN(第8章)、语言的Transformer(第7章)以及序列决策的RL策略(第6章)。它们看上去像是为完全不同的问题设计的完全不同的模型。但背后存在一个更深层的模式。 + +- **几何深度学习**揭示出所有这些架构都是同一个思想的实例:构建尊重数据**对称性**的网络。CNN利用图像中的平移对称性。Transformer利用序列中的置换对称性(注意力不依赖于绝对位置)。GNN利用图中的置换对称性。一旦看清这一点,众多架构就变成了一个统一的连贯框架。 + +## 对称性与群 + +- 一个对象的**对称性**是使其保持不变的变换。正方形有8种对称性:4种旋转(0°、90°、180°、270°)和4种反射。圆有无限多种:任何绕其中心的旋转。关键洞察在于,对称性告诉你什么是不重要的,而知道什么不重要的对于学习来说极为强大。 + +- 用机器学习的术语来说:如果一个任务具有对称性,那么无论看到输入的哪种"版本",模型都应给出相同的答案。猫检测器无论猫在图像的左上角还是右下角都应能工作。这就是平移对称性。 + +- 对称性通过**群**来形式化。一个群 $G$ 是一个具有四个性质的变换集合: + + - **封闭性**:两个变换的组合产生集合中的另一个变换。先旋转90°再旋转90°得到180°,也属于该集合。 + - **结合律**:$(g_1 \circ g_2) \circ g_3 = g_1 \circ (g_2 \circ g_3)$。分组的顺序无关紧要(回顾第2章中矩阵乘法的结合律)。 + - **单位元**:存在一个"什么也不做"的变换 $e$,使得 $e \circ g = g \circ e = g$。 + - **逆元**:每个变换都有撤销操作:$g \circ g^{-1} = e$。 + +- 这些公理与向量空间(第1章)的公理相同,但应用于变换而非向量。其联系十分深刻:群作用于向量空间,而神经网络必须尊重这种作用。 + +- 深度学习中出现的关键群: + + - **平移群** $(\mathbb{R}^n, +)$:平移图像或信号。这是CNN利用的对称性。 + - **对称群** $S_n$:$n$ 个元素的所有置换。这是GNN和Transformer利用的对称性(重新排序节点或标记不应改变结果)。 + - **旋转群** $SO(n)$:$n$ 维空间中的所有旋转。$SO(2)$ 是平面旋转,$SO(3)$ 是三维旋转(对分子和3D视觉任务至关重要)。 + - **欧几里得群** $E(n)$:所有旋转、反射和平移。物理空间的对称性。 + - **特殊欧几里得群** $SE(n)$:旋转和平移(不含反射)。刚体运动的对称性。 + +- **群作用**描述了群如何变换数据。如果 $G$ 是一个群,$X$ 是数据空间,则作用 $\rho: G \times X \to X$ 将每个群元素 $g$ 和数据点 $x$ 映射到一个变换后的点 $\rho(g, x)$。对于图像,平移群通过平移像素坐标来作用。对于图,对称群通过重新标记节点来作用。 + +## 不变性与等变性 + +- 给定一个对称群,函数可以通过两种重要方式与之关联: + +- 函数 $f$ 对群 $G$ 是**不变**的,如果输入变换后输出不变: + +$$f(\rho(g, x)) = f(x) \quad \text{对于所有 } g \in G$$ + +- 示例:图像的总体亮度不因平移而改变。图像分类应是平移不变的:"猫"的类别无论猫在何处都是一样的。 + +- 函数 $f$ 对群 $G$ 是**等变**的,如果变换输入会对等地变换输出: + +$$f(\rho_{\text{in}}(g, x)) = \rho_{\text{out}}(g, f(x)) \quad \text{对于所有 } g \in G$$ + +- 示例:如果将图像向右平移5个像素,CNN中的特征图也会向右平移5个像素。卷积操作是平移等变的:它保留了空间关系。目标检测应该是等变的:如果猫移动了,边界框也应随之移动。 + +![不变性:输出不随变换而改变。等变性:输出随之对应变换](../images/invariance_vs_equivariance.svg) + +- 区分两者的重要性在于:**中间层**通常应是等变的(为下游层保留结构),而**最终输出**应是不变的(答案不应依赖于变换)。CNN通过堆叠等变卷积层,然后在末尾应用全局池化(它是不变的)来实现这一点。 + +- 将等变性构建到架构中比从数据中学习它要高效得多。一个具有权重共享的平移等变CNN所需的参数远少于一个必须独立学习"位置(10,10)处的猫"和"位置(200,150)处的猫"的全连接网络。对称性约束指数级地缩小了假设空间。 + +## 五个几何域 + +- 几何深度学习识别出数据的**五个基本域**,每个域都有其自己的对称群。每一个神经网络架构都可以被理解为利用其中某个域的对称性。 + +![五个几何域:网格、集合、序列、图和流形,各有其对称性和架构](../images/five_geometric_domains.svg) + +- **1. 网格(欧几里得数据)**:图像、音频频谱图、体数据。底层结构是具有平移对称性的规则网格。群是平移群(可能再加上旋转和反射)。利用这种对称性的架构是**CNN**:卷积正是平移等变的操作。空间位置上的权重共享就是平移等变性的具体实现。 + +- **2. 集合(无序集合)**:点云、粒子系统。对称性是置换不变性:元素的顺序无关紧要。架构是**DeepSets**(以及第8章的PointNet):对每个元素应用共享函数,然后用置换不变操作(求和、均值或取最大值)进行聚合。形式上,$f(\{x_1, \ldots, x_n\}) = \phi\left(\sum_i \psi(x_i)\right)$。 + +- **3. 序列(有序数据)**:文本、时间序列。序列是一维网格,但有一个微妙之处:对称性更加细致。绝对位置可能重要也可能不重要。RNN以自回归方式处理序列。带位置编码的Transformer可以关注任何位置,其自注意力在加入位置编码之前是置换等变的。这就是Transformer泛化能力如此之强的原因:它们从置换等变开始,然后仅添加必要的位置结构。 + +- **4. 图(关系数据)**:社交网络、分子、知识图谱。对称性是节点的置换:重新标记节点不应改变图的性质。架构是**GNN**:连接节点之间传递消息,使用不依赖于节点顺序的共享函数。这是本章剩余部分的重点。 + +- **5. 流形和网格**:曲面、3D形状。对称性包括微分同胚(光滑变形)。架构使用内在算子(例如拉普拉斯-贝尔特拉米算子),这些算子由曲面几何本身定义,与曲面在空间中的嵌入方式无关。这联系到微分几何,并适用于形状分析、球面上的气候建模和蛋白质表面分析。 + +- 这个框架的强大之处在于其统一性。CNN是网格图上的GNN。Transformer是完全连接图上的GNN。DeepSets是没有边的GNN。将这些视为同一原理的实例,指导着新架构的设计:识别数据的对称性,然后构建一个尊重它的网络。 + +## 尺度分离与粗化 + +- 真实世界的数据具有多尺度结构。一幅图像有细粒度纹理(像素级)、局部模式(边缘、角点)、物体部件(车轮、窗户)和全局结构(整个场景)。一个分子有原子级特征、官能团和整体分子形状。 + +- **尺度分离**是这样一个原理:这些细节层次可以分层处理——先捕获局部结构,然后逐步聚合成更粗粒度的表示。这就是**粗化**或**池化**。 + +- 在CNN中,池化层(最大池化、平均池化)对空间分辨率进行下采样,迫使高层捕获更大尺度的模式。在感受野视角(第8章)中,更深层能"看到"更多的图像。这就是尺度分离的实际应用。 + +- 在图(graph)中,粗化意味着将节点群聚为"超节点",生成一个保留基本结构的更小图。这就是图池化,我们将在文件3中详细讨论。它与图像池化直接类似:降低分辨率的同时保留重要特征。 + +- 在序列中,分层处理(例如句子→段落→文档)在不同时间或语义尺度捕获结构。Swin Transformer(第8章)通过其移位窗口层次结构将这一思想应用于图像。 + +- 数学上,粗化定义了一个**逐渐抽象的表示层次**: + +$$x \xrightarrow{\text{局部特征}} h^{(1)} \xrightarrow{\text{粗化}} h^{(2)} \xrightarrow{\text{粗化}} \cdots \xrightarrow{\text{全局}} y$$ + +- 在每个层次,表示相对于该层次的对称群是等变的。最后的全局表示是不变的,捕获了输入的本质而不受无关变换的影响。 + +- 这就是为什么对于结构化数据,深层网络比浅层网络效果更好:每一层增加一个抽象层次,多个等变层的组合从简单的局部特征构建出复杂的不变特征。 + +## 编程任务(使用CoLab或notebook) + +1. 验证卷积的平移等变性。对图像应用卷积,然后平移图像再次卷积。检查输出是否互为平移版本。 +```python +import jax +import jax.numpy as jnp + +# 一维信号和一个简单滤波器 +signal = jnp.array([0, 0, 0, 1, 2, 3, 2, 1, 0, 0, 0], dtype=float) +kernel = jnp.array([1, 0, -1], dtype=float) + +# 先卷积再平移 +conv_result = jnp.convolve(signal, kernel, mode="same") +shifted_signal = jnp.roll(signal, 3) +conv_shifted = jnp.convolve(shifted_signal, kernel, mode="same") +shifted_conv = jnp.roll(conv_result, 3) + +print(f"先卷积再平移: {shifted_conv}") +print(f"先平移再卷积: {conv_shifted}") +print(f"等变性: {jnp.allclose(shifted_conv, conv_shifted, atol=1e-5)}") +``` + +2. 验证DeepSets风格聚合的置换不变性。对集合中的每个元素应用共享函数,求和结果,并检查输出是否不依赖于元素顺序。 +```python +import jax +import jax.numpy as jnp + +# 4个向量的"集合"(顺序应无关紧要) +x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]) + +# 简单的共享函数:逐元素平方 +psi = lambda v: v ** 2 + +# 通过求和聚合 +def deepsets(points): + return jnp.sum(jax.vmap(psi)(points), axis=0) + +# 原始顺序 +result1 = deepsets(x) + +# 置换后的顺序 +perm = jnp.array([2, 0, 3, 1]) +result2 = deepsets(x[perm]) + +print(f"原始顺序: {result1}") +print(f"置换顺序: {result2}") +print(f"不变性: {jnp.allclose(result1, result2)}") +``` + +3. 探索群结构。通过检查封闭性、结合律、单位元和逆元,验证二维旋转矩阵构成群。 +```python +import jax.numpy as jnp + +def rot2d(theta): + return jnp.array([[jnp.cos(theta), -jnp.sin(theta)], + [jnp.sin(theta), jnp.cos(theta)]]) + +R1 = rot2d(jnp.pi / 6) +R2 = rot2d(jnp.pi / 4) +R3 = rot2d(jnp.pi / 3) + +# 封闭性:两个旋转的乘积还是一个旋转 +R12 = R1 @ R2 +print(f"封闭性 (行列式=1, 正交): det={jnp.linalg.det(R12):.4f}, " + f"R^T R = I: {jnp.allclose(R12.T @ R12, jnp.eye(2), atol=1e-5)}") + +# 结合律 +print(f"结合律: {jnp.allclose((R1 @ R2) @ R3, R1 @ (R2 @ R3), atol=1e-5)}") + +# 单位元 +I = rot2d(0.0) +print(f"单位元: {jnp.allclose(R1 @ I, R1, atol=1e-5)}") + +# 逆元 +R1_inv = rot2d(-jnp.pi / 6) +print(f"逆元: {jnp.allclose(R1 @ R1_inv, jnp.eye(2), atol=1e-5)}") +``` diff --git a/chapter 12: graph neural networks/02. graph theory.md b/chapter 12: graph neural networks/02. graph theory.md new file mode 100644 index 0000000..69128c8 --- /dev/null +++ b/chapter 12: graph neural networks/02. graph theory.md @@ -0,0 +1,236 @@ +# 图论 + +*图论为描述实体间关系提供了数学语言。本章涵盖节点、边、邻接矩阵、图类型、度和连通性、图拉普拉斯算子、谱图理论以及现实世界的图应用。我们将在纯计算机科学章节中更深入地讨论图* + +- 到目前为止,本书中的数据都存在于规则结构上:$\mathbb{R}^n$ 中的向量(第1章)、数字网格形式的矩阵(第2章)、像素网格形式的图像(第8章)、有序列表形式的序列(第7章)。但许多现实世界的系统是**不规则**的:社交网络没有网格结构,分子没有从左到右的顺序,道路网络也不能整齐地平铺成行和列。 + +- **图(Graph)** 是表示这些不规则关系结构的数学工具。图捕获了**实体**(节点)及它们之间的**关系**(边)。一旦数据被表示为图,我们就可以应用文件1中的几何深度学习原理来从中学习。 + +## 节点、边和邻接 + +- 一个**图** $G = (V, E)$ 由一组**节点**(或顶点)$V = \{v_1, v_2, \ldots, v_n\}$ 和一组连接节点对的**边** $E \subseteq V \times V$ 组成。 + +- 节点代表实体:人、原子、城市、网页、神经元。边代表关系:友谊、化学键、道路、超链接、突触。 + +- **邻接矩阵** $A$ 是图的矩阵表示。对于一个有 $n$ 个节点的图,$A$ 是一个 $n \times n$ 矩阵,其中如果存在从节点 $i$ 到节点 $j$ 的边,则 $A_{ij} = 1$,否则 $A_{ij} = 0$。 + +- 例如,一个三角形图(3个节点,全部相连): + +```math +A = \begin{bmatrix} 0 & 1 & 1 \\ 1 & 0 & 1 \\ 1 & 1 & 0 \end{bmatrix} +``` + +![一个三角形图及其邻接矩阵:边存在处为1,否则为0](../images/graph_adjacency_matrix.svg) + +- 对角线为零,因为节点默认不与自身相连(无自环)。邻接矩阵是我们在第2章中研究的布尔矩阵的直接应用:每个条目都是一个二元关系。 + +- 邻接矩阵完整地编码了图的结构。对 $A$ 的矩阵运算揭示了图的性质:$A^2_{ij}$ 计算节点 $i$ 和 $j$ 之间长度为2的路径数量(回顾第2章中的矩阵乘法:每个条目是经过中间节点的乘积之和)。更一般地,$A^k_{ij}$ 计算长度为 $k$ 的路径数量。 + +- 每个节点可以携带一个**特征向量** $\mathbf{x}_i \in \mathbb{R}^d$。对于社交网络,这可能是用户的个人信息。对于分子,它编码原子类型、电荷和其他属性。全部节点特征的集合是一个矩阵 $X \in \mathbb{R}^{n \times d}$,其中每一行是一个节点的特征。 + +- 边也可以携带特征:分子中的键类型、空间图中的距离、知识图谱中的关系类型。边 $(i, j)$ 的**边特征**是一个向量 $\mathbf{e}_{ij} \in \mathbb{R}^{d_e}$。 + +## 图类型 + +- **无向图**具有对称的边:如果 $i$ 连接到 $j$,则 $j$ 也连接到 $i$。邻接矩阵是对称的:$A = A^T$(一个对称矩阵,见第2章)。友谊和化学键是无向的。 + +- **有向图**(digraph)具有带方向的边:从 $i$ 到 $j$ 的边不意味着从 $j$ 到 $i$ 的边。邻接矩阵是非对称的。Twitter关注、网页超链接和引文网络是有向的。 + +- **加权图**为每条边分配一个数值权重。邻接矩阵具有实数值而非二进制值:$A_{ij} = w_{ij}$。道路网络中的距离、大脑连通性中的相关强度以及社交网络中的交互频率是加权的。 + +- **二分图**具有两个不相交的节点集合,边只存在于集合之间(集合内部没有边)。用户和产品构成一个二分图:用户评价产品,但用户之间不相互评价。二分图的邻接矩阵具有块结构: + +```math +A = \begin{bmatrix} 0 & B \\ B^T & 0 \end{bmatrix} +``` + +- 其中 $B$ 是两个节点集之间的二分邻接矩阵。 + +- **多重图**允许同一对节点之间存在多条边和/或自环。知识图谱通常是多重图:两个实体之间可以有多种关系(例如"出生于"、"居住于"、"工作于")。 + +- **超图**将边推广为一次连接两个以上节点。一条**超边**连接一组节点,表示高阶关系。一篇由五人合著的研究论文是一条连接五个作者节点的超边。 + +- **完全图** $K_n$ 在每一对节点之间都有边。这是全连接层的图类比,也是Transformer操作的结构(每个标记关注每个其他标记)。 + +## 度、路径和连通性 + +- 一个**节点**的**度**是与它相连的边的数量。在无向图中,节点 $i$ 的度为 $d_i = \sum_j A_{ij}$。高度节点是拥有大量连接的"枢纽"。 + +- **度矩阵** $D$ 是一个对角线元素为度的对角矩阵:$D_{ii} = d_i$。这个矩阵出现在整个图论和GNN公式中。 + +- 两个节点之间的**路径**是连接它们的边序列。$i$ 和 $j$ 之间的**最短路径**(或测地线)是边数最少(或在加权图中总权重最小)的路径。**迪杰斯特拉算法**(Dijkstra's algorithm)在 $O((|V| + |E|) \log |V|)$ 时间内找到最短路径。 + +- 如果每对节点之间都存在路径,则图是**连通的**。否则,图有多个**连通分量**:相互之间没有边的孤立子图。 + +- 图的**直径**是任意一对节点之间最长最短路径的长度。它衡量图"分散"的程度。社交网络以直径小而闻名("六度分隔")。 + +- **环**是起点和终点在同一节点的路径。没有环的图是**树**。树是最简单的连通图:$n$ 个节点和恰好 $n-1$ 条边。 + +- **中心性**衡量节点的重要性。**度中心性**就是度数。**介数中心性**计算通过一个节点的最短路径数量。**特征向量中心性**根据节点邻居的重要性分配重要性,得到特征向量方程 $A\mathbf{x} = \lambda \mathbf{x}$(第2章)。谷歌的PageRank是特征向量中心性在有向图上的变体。 + +## 图拉普拉斯算子 + +- **图拉普拉斯算子**也许是图论中最重要的矩阵。定义如下: + +$$L = D - A$$ + +- 其中 $D$ 是度矩阵,$A$ 是邻接矩阵。对于我们的三角形示例: + +```math +L = \begin{bmatrix} 2 & 0 & 0 \\ 0 & 2 & 0 \\ 0 & 0 & 2 \end{bmatrix} - \begin{bmatrix} 0 & 1 & 1 \\ 1 & 0 & 1 \\ 1 & 1 & 0 \end{bmatrix} = \begin{bmatrix} 2 & -1 & -1 \\ -1 & 2 & -1 \\ -1 & -1 & 2 \end{bmatrix} +``` + +- 拉普拉斯算子具有显著的性质: + + - 它始终是**对称的**且**半正定的**(回顾第2章:所有特征值 $\geq 0$)。对于任意向量 $\mathbf{x}$: + +$$\mathbf{x}^T L \mathbf{x} = \sum_{(i,j) \in E} (x_i - x_j)^2$$ + +![图拉普拉斯算子度量信号平滑度:平滑信号在连接节点上具有相似值,非平滑信号变化剧烈](../images/graph_laplacian_smoothness.svg) + + - 这个二次形式度量图上的信号 $\mathbf{x}$ 在边上的变化程度。如果相邻节点值相近,则 $\mathbf{x}^T L \mathbf{x}$ 较小。如果它们差异很大,则较大。拉普拉斯算子度量图上信号的**平滑度**。 + + - 最小特征值始终为0,特征向量为 $\mathbf{1} = [1, 1, \ldots, 1]^T$(常数信号没有变化)。零特征值的数量等于连通分量的数量。 + + - 第二小特征值 $\lambda_2$ 是**代数连通度**(Fiedler值)。它衡量图的连通程度:$\lambda_2 = 0$ 表示图不连通,大的 $\lambda_2$ 表示图紧密连通。 + +- **归一化拉普拉斯算子**通过度进行缩放: + +$$\hat{L} = D^{-1/2} L D^{-1/2} = I - D^{-1/2} A D^{-1/2}$$ + +- 这种归一化确保拉普拉斯算子的性质不依赖于节点度的绝对尺度。项 $D^{-1/2} A D^{-1/2}$ 是**对称归一化邻接矩阵**,它直接出现在GCN公式中(文件3)。 + +## 谱图理论 + +- 图拉普拉斯算子的特征值和特征向量定义了图的**谱**,它们充当图上的傅里叶变换的类似物。 + +- 在经典信号处理中,傅里叶变换将信号分解为频率分量(正弦和余弦)。在图上,拉普拉斯算子的特征向量扮演这些频率基的角色。小特征值的特征向量在图上变化缓慢(低频、平滑),而大特征值的特征向量变化迅速(高频、振荡)。 + +- 信号 $\mathbf{x}$ 在图上的**图傅里叶变换(GFT)** 为: + +$$\hat{\mathbf{x}} = U^T \mathbf{x}$$ + +- 其中 $U$ 是拉普拉斯算子特征向量的矩阵(回顾第2章中的特征分解:$L = U \Lambda U^T$)。逆变换为 $\mathbf{x} = U \hat{\mathbf{x}}$。 + +- 谱域中的**图卷积**是频域中的逐点乘法,正如空间域中的卷积对应于傅里叶域中的乘法(卷积定理,见第8章): + +$$g_\theta \star \mathbf{x} = U \left( (U^T g_\theta) \odot (U^T \mathbf{x}) \right) = U \, \text{diag}(\hat{g}_\theta) \, U^T \mathbf{x}$$ + +- 滤波器 $\hat{g}_\theta$ 是特征值的可学习函数。这是谱域GNN的基础,我们将在文件3中将其简化为实用的GCN。 + +- 计算瓶颈是对 $L$ 进行特征分解,对于有 $n$ 个节点的图需要 $O(n^3)$ 时间。这对于大型图(数百万节点)是不可行的。多项式近似(切比雪夫多项式)完全避免了特征分解,而这种近似直接导致了GCN。 + +## 社区检测 + +- 许多现实世界的图具有**社区结构**:紧密连接的节点簇,簇之间连接稀疏。社交网络有好友群组,生物网络有功能模块,引文网络有研究领域。 + +- **谱聚类**使用拉普拉斯算子特征向量来寻找社区。思路:使用 $L$ 的 $k$ 个最小的非平凡特征向量对每个节点进行嵌入,然后在这个嵌入空间中应用k-means(第6章)。同一社区中的节点在谱嵌入中最终彼此靠近。 + +- 这是可行的,因为Fiedler向量($\lambda_2$ 的特征向量)自然地将图分成两组:正值的节点和负值的节点,沿着最稀疏的连接切开。更高的特征向量进一步细分为更多组。 + +- **模块度** $Q$ 衡量社区划分的质量。它将社区内边的数量与随机图中的期望数量进行比较: + +$$Q = \frac{1}{2|E|} \sum_{ij} \left( A_{ij} - \frac{d_i d_j}{2|E|} \right) \delta(c_i, c_j)$$ + +- 其中 $c_i$ 是节点 $i$ 的社区分配,如果节点在同一个社区则 $\delta$ 为1。$Q$ 的范围从 $-0.5$ 到 $1$,值越高表示社区结构越强。 + +## 现实世界中的图 + +- **社交网络**:节点是人,边是友谊或互动。Facebook有数十亿节点和数千亿条边。这些图通常是稀疏的(每个人有几百个朋友,而不是几十亿),具有小世界性质(短的平均路径长度),以及重尾度分布(少数拥有数百万连接的枢纽节点)。 + +- **分子图**:节点是原子,边是化学键。每个原子有特征(元素类型、电荷、杂化方式),每条键有特征(单键、双键、三键、芳香键)。分子图很小(数十到数百个节点)但高度结构化。从图结构预测分子性质是GNN的一个重要应用。 + +- **知识图谱**:节点是实体(人、地点、概念),边是类型化的关系("出生于"、"首都是"、"是……的实例")。知识图谱为搜索引擎、推荐系统和问答系统提供支持。它们通常是具有数百万实体和数十亿关系的有多重图。 + +- **引文网络**:节点是论文,边是引用(有向的)。聚类揭示研究社区。节点特征包括标题、摘要和出版年份。 + +- **蛋白质相互作用网络**:节点是蛋白质,边表示物理相互作用或功能关联。理解这些图有助于识别药物靶点和疾病机制。 + +- **道路网络与交通**:节点是交叉路口,边是具有距离/时间权重的道路段。这些图上的最短路径算法为导航系统提供动力。自动驾驶运动预测(第11章)将智能体交互表示为图。 + +## 编程任务(使用CoLab或notebook) + +1. 构建一个小型图的邻接矩阵,计算基本性质:每个节点的度、长度为2的路径数量以及图是否连通。 +```python +import jax.numpy as jnp + +# 一个简单图:5个节点 +# 0-1, 0-2, 1-2, 2-3, 3-4 +A = jnp.array([[0, 1, 1, 0, 0], + [1, 0, 1, 0, 0], + [1, 1, 0, 1, 0], + [0, 0, 1, 0, 1], + [0, 0, 0, 1, 0]], dtype=float) + +# 度 +degrees = A.sum(axis=1) +print(f"度数: {degrees}") + +# 长度为2的路径 +A2 = A @ A +print(f"长度为2的路径(节点0到3): {int(A2[0, 3])}") + +# 是否连通?检查 A^(n-1) 是否所有条目非零 +An = jnp.linalg.matrix_power(A + jnp.eye(5), 4) # (A+I)^4 用于可达性 +connected = jnp.all(An > 0) +print(f"连通: {connected}") +``` + +2. 计算图拉普拉斯算子及其特征值。验证最小特征值为0且对应的特征向量为常数。 +```python +import jax.numpy as jnp + +A = jnp.array([[0, 1, 1, 0, 0], + [1, 0, 1, 0, 0], + [1, 1, 0, 1, 0], + [0, 0, 1, 0, 1], + [0, 0, 0, 1, 0]], dtype=float) + +D = jnp.diag(A.sum(axis=1)) +L = D - A + +eigenvalues, eigenvectors = jnp.linalg.eigh(L) +print(f"特征值: {eigenvalues}") +print(f"最小特征向量: {eigenvectors[:, 0]}") +print(f"Fiedler值(代数连通度): {eigenvalues[1]:.4f}") + +# 验证: x^T L x 度量平滑度 +x = jnp.array([1.0, 1.0, 1.0, -1.0, -1.0]) # 两个组 +smoothness = x @ L @ x +print(f"两组信号的平滑度: {smoothness:.2f}") +``` + +3. 对具有两个社区的图执行谱聚类。使用Fiedler向量嵌入节点,并按符号分离。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 两个社区,各5个节点,弱连接 +A = jnp.zeros((10, 10)) +# 社区1:节点0-4(密集) +for i in range(5): + for j in range(i+1, 5): + A = A.at[i, j].set(1).at[j, i].set(1) +# 社区2:节点5-9(密集) +for i in range(5, 10): + for j in range(i+1, 10): + A = A.at[i, j].set(1).at[j, i].set(1) +# 一条桥接边 +A = A.at[2, 7].set(1).at[7, 2].set(1) + +D = jnp.diag(A.sum(axis=1)) +L = D - A +eigenvalues, eigenvectors = jnp.linalg.eigh(L) + +# Fiedler向量(第二小特征值) +fiedler = eigenvectors[:, 1] +communities = (fiedler > 0).astype(int) + +print(f"Fiedler向量: {fiedler}") +print(f"聚类: {communities}") + +plt.bar(range(10), fiedler, color=["#3498db" if c == 0 else "#e74c3c" for c in communities]) +plt.xlabel("节点"); plt.ylabel("Fiedler向量值") +plt.title("通过Fiedler向量进行谱聚类") +plt.show() +``` diff --git a/chapter 12: graph neural networks/03. graph neural networks.md b/chapter 12: graph neural networks/03. graph neural networks.md new file mode 100644 index 0000000..694fd37 --- /dev/null +++ b/chapter 12: graph neural networks/03. graph neural networks.md @@ -0,0 +1,271 @@ +# 图神经网络 + +*图神经网络通过在连接节点之间传递消息来学习图结构数据。本章涵盖消息传递框架、GCN、GraphSAGE、GIN、过平滑、图池化以及节点/边/图级别的任务;支撑分子性质预测、社交网络分析和推荐系统的核心架构。* + +- 在前面的文件中,我们建立了数学基础:几何深度学习(文件1)告诉我们利用对称性,图论(文件2)提供了节点、边和邻接的语言。现在我们构建直接在图(graph)上操作的神经网络。 + +- 核心挑战:图数据是**不规则**的。与图像(固定网格)或序列(固定顺序)不同,图具有可变数量的节点、可变的连通性,并且没有规范的节点顺序。用于图的神经网络必须处理所有这些情况,同时保持置换等变性(重新标记节点不应改变输出)。 + +## 消息传递框架 + +- 几乎所有的GNN都遵循同样的模式,称为**消息传递**(也称为邻域聚合)。这个想法简单而优雅:每个节点通过从邻居收集信息来更新其表示。 + +- 在每个层 $l$,每个节点 $i$ 做三件事: + + 1. **消息**:节点 $i$ 的每个邻居 $j$ 基于其当前特征计算一条消息 $\mathbf{m}_{j \to i}$。 + 2. **聚合**:节点 $i$ 收集所有传入消息,并使用置换不变函数(求和、均值或取最大值)将它们组合。 + 3. **更新**:节点 $i$ 将聚合的消息与其自身特征结合,产生一个新的表示。 + +- 形式上: + +$$\mathbf{m}_i^{(l)} = \bigoplus_{j \in \mathcal{N}(i)} \phi^{(l)}\left(\mathbf{h}_i^{(l)}, \mathbf{h}_j^{(l)}, \mathbf{e}_{ij}\right)$$ + +$$\mathbf{h}_i^{(l+1)} = \psi^{(l)}\left(\mathbf{h}_i^{(l)}, \mathbf{m}_i^{(l)}\right)$$ + +- 其中 $\mathcal{N}(i)$ 是节点 $i$ 的邻居集合,$\bigoplus$ 是一个置换不变的聚合操作(求和、均值、取最大值),$\phi$ 是消息函数,$\psi$ 是更新函数,$\mathbf{e}_{ij}$ 是可选的边特征。 + +![消息传递:邻居发送消息,置换不变函数聚合它们,然后节点更新其特征](../images/message_passing_gnn.svg) + +- 聚合操作 $\bigoplus$ 必须是置换不变的(邻居处理的顺序无关紧要),以确保整个函数是置换等变的。这直接实现了文件1中的对称性原理。 + +- 经过 $k$ 层消息传递后,每个节点的表示编码了其 **$k$ 跳邻域**的信息:所有在 $k$ 条边内可达的节点。第1层看到直接邻居,第2层看到邻居的邻居,依此类推。这就是局部信息传播以建立全局理解的方式。 + +- GNN的感受野随深度增长,就像CNN的感受野随层数增长一样(第8章)。但与规则网格上的CNN不同,感受野的形状根据图拓扑结构在每个节点上有所不同。 + +## 图卷积网络(GCN) + +- **GCN**(Kipf & Welling,2017)是基础性的GNN架构。它将谱域图卷积(来自文件2)简化为一个优雅、高效的公式。 + +- 从谱域卷积 $g_\theta \star \mathbf{x} = U \, \text{diag}(\hat{g}_\theta) \, U^T \mathbf{x}$ 出发,Kipf和Welling用一阶切比雪夫多项式近似谱域滤波器,这完全避免了计算特征分解。简化后,逐层更新变为: + +$$H^{(l+1)} = \sigma\left(\hat{A} H^{(l)} W^{(l)}\right)$$ + +- 其中: + - $H^{(l)} \in \mathbb{R}^{n \times d}$ 是第 $l$ 层的节点特征矩阵 + - $W^{(l)} \in \mathbb{R}^{d \times d'}$ 是可学习的权重矩阵 + - $\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$ 是带自环的对称归一化邻接矩阵 + - $\tilde{A} = A + I$ 添加了自环(因此每个节点也接收自己的消息) + - $\tilde{D}$ 是 $\tilde{A}$ 的度矩阵 + - $\sigma$ 是一个非线性激活函数(ReLU,如第6章所述) + +- 矩阵乘法 $\hat{A} H^{(l)}$ 是聚合步骤:对于每个节点,它计算其邻居特征(加上自身特征,通过自环)的加权平均。权重矩阵 $W^{(l)}$ 是可学习的变换,在所有节点间共享。激活函数增加了非线性。 + +- 这非常简单:它只是矩阵乘法后接一个学习到的线性映射和激活函数。整个GCN层可以用一行代码实现。通过 $\tilde{D}^{-1/2}$ 的归一化防止具有许多邻居的节点占主导地位:高度节点的消息被按比例缩小。 + +- 在消息传递框架中,GCN使用: + - 消息:$\phi(\mathbf{h}_j) = \mathbf{h}_j$(只发送你的特征) + - 聚合:归一化和(按度加权) + - 更新:线性变换 + 激活函数 + +## GraphSAGE + +- GCN是**直推式**的:它在训练时需要完整的图,无法处理新出现的未知节点。如果新用户加入社交网络,GCN必须对整个图重新训练。**GraphSAGE**(Hamilton等,2017)通过**归纳式**方法解决了这个问题。 + +- 关键思想是**邻域采样**:不是使用所有邻居,而是采样一个固定大小的子集。这使得计算独立于完整的图结构,并允许推广到未见过的节点和图。 + +- 节点 $i$ 的GraphSAGE更新: + +$$\mathbf{h}_i^{(l+1)} = \sigma\left(W^{(l)} \cdot \text{CONCAT}\left(\mathbf{h}_i^{(l)}, \text{AGG}\left(\{\mathbf{h}_j^{(l)} : j \in \mathcal{S}(i)\}\right)\right)\right)$$ + +- 其中 $\mathcal{S}(i)$ 是一个**采样**的邻居子集(例如,从500个邻居中随机采样10个)。CONCAT操作显式地将节点自身的特征与聚合后的邻居特征分开,让网络学习"自身"和"邻域"的不同变换。 + +- GraphSAGE支持多种聚合函数: + - **均值(Mean)**:$\text{AGG} = \frac{1}{|\mathcal{S}|} \sum_{j \in \mathcal{S}} \mathbf{h}_j$(简单,有效) + - **LSTM**:将采样的邻居通过LSTM(但这引入了顺序依赖,一定程度上违反了置换不变性) + - **池化(Pool)**:$\text{AGG} = \max(\{\sigma(W_{\text{pool}} \mathbf{h}_j + \mathbf{b})\})$(非线性变换后取最大值) + +- 采样策略使GraphSAGE可扩展到非常大的图。训练使用节点的小批量:对于每个目标节点,在第1层采样 $k_1$ 个邻居,然后对于其中每个邻居在第2层采样 $k_2$ 个邻居。使用 $k_1 = k_2 = 10$ 和2层,每个节点的计算树最多有 $10 \times 10 = 100$ 个节点,与图的大小无关。 + +## 图同构网络(GIN) + +- 不同的GNN架构具有不同的**表达能力**:它们区分结构不同之图的能力。GCN和GraphSAGE虽然在实践中有效,但理论上在能区分哪些图结构方面是受限的。 + +- 衡量GNN表达能力的理论工具是**Weisfeiler-Lehman(WL)测试**,这是一个用于测试图同构(两个图是否结构相同)的经典算法。WL测试通过将每个节点的标签与其邻居标签的多重集一起哈希,迭代地精炼节点标签。 + +- **GIN**(Xu等,2019)被设计为具有与WL测试同等的表达能力,使其成为最强大的消息传递GNN(在消息传递的理论限制内)。关键洞察:聚合函数必须在多重集上是**单射**的(不同的邻居特征多重集必须产生不同的聚合值)。 + +- 求和聚合在多重集上是单射的(求和 $\{1, 1, 2\}$ 得到4,而 $\{1, 3\}$ 也得到4,但在具有足够维度的特征向量上,不同多重集的和一般而言是不同的)。均值和取最大值不是单射的:均值无法区分 $\{1, 1\}$ 和 $\{2, 2\}$,取最大值无法区分 $\{1, 2, 3\}$ 和 $\{1, 1, 3\}$。 + +- GIN更新: + +$$\mathbf{h}_i^{(l+1)} = \text{MLP}^{(l)}\left((1 + \epsilon^{(l)}) \cdot \mathbf{h}_i^{(l)} + \sum_{j \in \mathcal{N}(i)} \mathbf{h}_j^{(l)}\right)$$ + +- 其中 $\epsilon$ 是一个可学习的标量(或固定为0),MLP提供非线性、单射的映射。求和聚合保留了多重集结构,MLP可以学会区分任意两个不同的聚合值。 + +## 过平滑 + +- GNN的一个主要挑战是**过平滑**:随着层数增加,所有节点表示收敛到相同的值,失去区分不同节点的能力。 + +![过平滑:在第1层各不相同的节点特征在更深层逐渐融合为统一特征](../images/over_smoothing_gnn.svg) + +- 其机制是直观的。每个消息传递层将节点的特征与其邻居的特征进行平均。经过多轮平均后,每个节点已经"看到"(并混合了)其连通分量中的每个其他节点。这些特征变成了统一的平均值,相当于将图像模糊太多次直到变成纯色的图类比。 + +- 形式上,重复应用归一化邻接矩阵 $\hat{A}$ 收敛到一个秩为1的矩阵(每一行都变得与图上随机游走的平稳分布成正比)。这与幂迭代收敛到主特征向量的过程相同(第2章)。 + +- 过平滑将GNN限制在很浅的深度(通常2-4层),而CNN和Transformer可以从几十或数百层中受益。这意味着每个节点只能看到有限的邻域,这对于需要长距离信息的任务来说是有问题的。 + +- 缓解方法包括: + - **残差连接**(来自ResNet,第8章):$\mathbf{h}_i^{(l+1)} = \mathbf{h}_i^{(l+1)} + \mathbf{h}_i^{(l)}$,保留来自较早层的信息。 + - **跳跃知识(Jumping Knowledge)**:拼接或注意力池化来自所有层的表示,而不仅仅是最后一层。 + - **DropEdge**:训练期间随机移除边,减缓信息传播。 + - **图Transformer(Graph Transformer)**(文件4):用全局注意力绕过局部消息传递的瓶颈。 + +## 图池化 + +- 对于**图级别任务**(预测整个图的属性,如分子的毒性),我们需要将所有节点表示折叠成一个单一的图级别向量。这就是**图池化**,是CNN中全局平均池化的图类比(第8章)。 + +- 最简单的方法是**读出(readout)**:对所有节点特征应用一个置换不变函数: + +$$\mathbf{h}_G = \text{READOUT}(\{\mathbf{h}_i^{(L)} : i \in V\}) = \sum_i \mathbf{h}_i^{(L)} \quad \text{或} \quad \frac{1}{|V|} \sum_i \mathbf{h}_i^{(L)} \quad \text{或} \quad \max_i \mathbf{h}_i^{(L)}$$ + +- 这就是文件1中的DeepSets聚合,应用于最终的GNN层之后。求和保留了大小信息(一个有100个节点的图会比只有10个节点的图具有更大的和),而均值对大小进行了归一化。 + +- **分层池化**逐步粗化图,模仿CNN逐步下采样图像的方式。在每个层级,节点组被合并为"超节点": + +- **DiffPool**(可微分池化)学习一个软分配矩阵 $S^{(l)} \in \mathbb{R}^{n_l \times n_{l+1}}$,将每个节点分配到一个簇: + +$$X^{(l+1)} = S^{(l)T} H^{(l)}, \quad A^{(l+1)} = S^{(l)T} A^{(l)} S^{(l)}$$ + +- 分配矩阵由一个单独的GNN预测,使聚类变得端到端可微分。这创建了一个层次结构:原始图 → 具有较少节点的粗化图 → 更粗的图 → 单个节点(图表示)。 + +- **TopKPool**采用更简单的方法:为每个节点学习一个标量分数,保留得分最高的 top-$k$ 个节点,丢弃其余节点。这是一种硬选择(而非软分配),计算上比DiffPool更廉价。 + +## 异构图 + +- 截至目前的所有GNN都假设一个**同构图**:一种节点类型,一种边类型。但大多数现实世界的图是**异构**的:多种节点类型和多种边类型。知识图谱有人物节点、组织节点和位置节点,由"工作于"、"出生于"和"位于"边连接。推荐系统有用户节点和物品节点,由"已购买"、"已浏览"和"已评价"边连接。 + +- 异构图有一个**模式**(也称为元图),定义了允许的节点类型和边类型。每个边类型连接特定的源类型到特定的目标类型。例如,"工作于"连接 Person → Organisation。 + +- **关系GCN(R-GCN)**(Schlichtkrull等,2018)通过为每种边类型使用单独的权重矩阵来处理异构边: + +$$\mathbf{h}_i^{(l+1)} = \sigma\left(\sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)} \frac{1}{|\mathcal{N}_r(i)|} W_r^{(l)} \mathbf{h}_j^{(l)} + W_0^{(l)} \mathbf{h}_i^{(l)}\right)$$ + +- 其中 $\mathcal{R}$ 是边类型的集合,$\mathcal{N}_r(i)$ 是通过关系 $r$ 连接到节点 $i$ 的邻居集合,$W_r$ 是关系 $r$ 特有的权重矩阵。自连接 $W_0$ 单独处理节点自身的特征。 + +- 问题:当关系类型很多时,参数数量爆炸(每种关系一个 $d \times d$ 矩阵)。R-GCN通过**基分解**缓解这一问题:$W_r = \sum_{b=1}^{B} a_{rb} V_b$,其中 $V_b$ 是共享的基矩阵,$a_{rb}$ 是每个关系的标量系数。这类似于低秩分解(第2章):关系特定的矩阵生活在一个低维子空间中。 + +- **异构图表Transformer(HGT)**(Hu等,2020)将注意力机制应用于异构图。关键洞察:注意力应同时依赖于节点类型和连接它们的边类型。HGT为查询、键和值使用类型特定的投影矩阵: + +$$\text{Attention}(i, j) = \left(W_{\tau(i)}^Q \mathbf{h}_i\right)^T \cdot \frac{W_{\phi(i,j)}^{\text{ATT}}}{\sqrt{d}} \cdot \left(W_{\tau(j)}^K \mathbf{h}_j\right)$$ + +- 其中 $\tau(i)$ 是节点 $i$ 的类型,$\phi(i,j)$ 是它们之间的边类型。这确保了模型对不同的关系类型使用不同的注意力权重:一篇论文关注其作者时,应使用与关注其参考文献时不同的注意力权重。 + +- **基于元路径的方法**定义通过模式的含义路径(例如,作者 → 论文 → 作者表示合著关系),并沿着这些路径聚合信息。**HAN**(异构图注意力网络)在两个层次应用注意力:在每个元路径内(沿此路径哪些邻居重要?)和跨元路径(哪些关系模式重要?)。 + +## 链接预测与知识图谱补全 + +- **链接预测**提出的问题是:给定现有边,哪些缺失的边可能存在?这是知识图谱补全(预测缺失的事实)、推荐(预测用户会喜欢哪些物品)和社交网络分析(预测未来的友谊)的核心任务。 + +- **基于嵌入的方法**为每个实体学习一个向量,为每个关系学习一个变换,然后通过实体和关系的匹配程度对潜在边进行评分: + +- **TransE**将关系建模为嵌入空间中的平移:如果 $(h, r, t)$ 是一个有效的三元组(头实体,关系,尾实体),那么 $\mathbf{h} + \mathbf{r} \approx \mathbf{t}$。评分函数为 $f(h, r, t) = -\|\mathbf{h} + \mathbf{r} - \mathbf{t}\|$。直观地说,关系向量在嵌入空间中将头实体"移动"到尾实体。 + +- **RotatE**将关系建模为复空间中的旋转:$\mathbf{t} = \mathbf{h} \circ \mathbf{r}$,其中 $\circ$ 是逐元素复数乘法,$|\mathbf{r}_i| = 1$(单位复数就是旋转)。这可以建模TransE无法处理的对称性、反对称性、反转和复合模式。 + +- **ComplEx**使用复数值嵌入和埃尔米特点积,使其能够建模非对称关系(如果A是B的老板,B不是A的老板)。 + +- 基于GNN的链接预测通过消息传递计算节点嵌入,然后使用端点嵌入对边进行评分。这结合了GNN的结构推理能力和嵌入方法的关系建模能力。GNN编码器捕获了单嵌入方法所遗漏的多跳邻域结构。 + +## 任务类型 + +- GNN解决三类任务: + +- **节点级别任务**:为每个节点预测一个属性。示例:对社交网络中的用户进行分类(机器人还是人类),预测相互作用网络中每个蛋白质的功能,半监督节点分类(标记少数节点,预测其余节点)。输出是节点嵌入 $\mathbf{h}_i^{(L)}$ 经过一个分类器。 + +- **边级别任务**:为每条边预测一个属性或预测边是否存在。示例:链接预测(这两个用户会成为朋友吗?),知识图谱补全(这个关系在这些实体间成立吗?),药物-药物相互作用预测。输出通常使用两个端点节点的嵌入:$\hat{y}_{ij} = f(\mathbf{h}_i, \mathbf{h}_j)$,其中 $f$ 是点积、拼接+MLP或其他组合。 + +- **图级别任务**:为整个图预测一个属性。示例:分子性质预测(这个分子有毒吗?),图分类(这个社交网络是机器人网络吗?),图生成(设计一个具有期望性质的分子)。输出使用图池化产生 $\mathbf{h}_G$,然后进行分类或回归。 + +## 编程任务(使用CoLab或notebook) + +1. 使用归一化邻接矩阵从头实现一个单层GCN。应用于一个小型图,观察节点特征如何被平滑。 +```python +import jax +import jax.numpy as jnp + +# 图:5个节点,简单链带分支 +A = jnp.array([[0, 1, 0, 0, 0], + [1, 0, 1, 0, 0], + [0, 1, 0, 1, 1], + [0, 0, 1, 0, 0], + [0, 0, 1, 0, 0]], dtype=float) + +# 添加自环 +A_hat = A + jnp.eye(5) +D_hat = jnp.diag(A_hat.sum(axis=1)) +D_inv_sqrt = jnp.diag(1.0 / jnp.sqrt(A_hat.sum(axis=1))) +A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt + +# 节点特征:one-hot 单位阵 +H = jnp.eye(5) + +# 权重矩阵(随机初始化) +rng = jax.random.PRNGKey(0) +W = jax.random.normal(rng, (5, 3)) * 0.5 + +# GCN层:H' = ReLU(A_norm @ H @ W) +H_new = jax.nn.relu(A_norm @ H @ W) + +print("原始特征(one-hot):") +print(H) +print("\n经过GCN层后:") +print(jnp.round(H_new, 3)) +print("\n注意:连接的节点现在具有相似的表示") +``` + +2. 实现具有求和聚合(GIN风格)和均值聚合(GCN风格)的消息传递。展示求和能区分均值无法区分的多重集。 +```python +import jax.numpy as jnp + +# 两个具有相同均值的不同邻居多重集 +# 节点A:邻居特征为 [1, 1, 1, 1] (四个邻居,都是1) +# 节点B:邻居特征为 [2, 2] (两个邻居,都是2) + +neighbours_A = jnp.array([[1.0], [1.0], [1.0], [1.0]]) +neighbours_B = jnp.array([[2.0], [2.0]]) + +# 均值聚合 +mean_A = neighbours_A.mean(axis=0) +mean_B = neighbours_B.mean(axis=0) +print(f"均值 A: {mean_A}, 均值 B: {mean_B}, 相同: {jnp.allclose(mean_A, mean_B)}") + +# 求和聚合 +sum_A = neighbours_A.sum(axis=0) +sum_B = neighbours_B.sum(axis=0) +print(f"求和 A: {sum_A}, 求和 B: {sum_B}, 相同: {jnp.allclose(sum_A, sum_B)}") +print("\n求和能区分这些多重集;均值不能!") +``` + +3. 演示过平滑。重复应用归一化邻接矩阵,观察节点特征收敛。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 随机图 +A = jnp.array([[0,1,1,0,0,0], + [1,0,1,0,0,0], + [1,1,0,1,0,0], + [0,0,1,0,1,1], + [0,0,0,1,0,1], + [0,0,0,1,1,0]], dtype=float) + +A_hat = A + jnp.eye(6) +D_inv_sqrt = jnp.diag(1.0 / jnp.sqrt(A_hat.sum(axis=1))) +A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt + +# 初始特征:每个节点各不相同 +H = jnp.array([[1,0], [0,1], [1,1], [-1,0], [0,-1], [-1,-1]], dtype=float) + +distances = [] +for k in range(20): + H = A_norm @ H + # 衡量特征的区别程度(节点间的标准差) + spread = jnp.std(H, axis=0).mean() + distances.append(float(spread)) + +plt.plot(distances, "o-") +plt.xlabel("消息传递轮数") +plt.ylabel("特征分散度(节点间标准差)") +plt.title("过平滑:特征随深度增加而收敛") +plt.show() +``` diff --git a/chapter 12: graph neural networks/04. graph attention networks.md b/chapter 12: graph neural networks/04. graph attention networks.md new file mode 100644 index 0000000..709d3ac --- /dev/null +++ b/chapter 12: graph neural networks/04. graph attention networks.md @@ -0,0 +1,258 @@ +# 图注意力网络 + +*图注意力网络将均匀的邻居聚合替换为学习到的、依赖数据的加权。本章涵盖GAT、多头图注意力、GATv2、图Transformer、位置和结构编码以及可扩展性* + +- 在GCN(文件3)中,每个节点使用由图结构确定的固定权重(归一化邻接矩阵)聚合其邻居特征。一个有三个邻居的节点会给每个邻居大致相等的权重($\approx 1/3$)。但并非所有邻居都同等重要:来自密切合作者的消息应比来自远方熟人的消息更重要。 + +- **图注意力网络**通过使用与Transformer(第7章)相同的注意力机制来学习**关注哪些邻居**,从而解决了这一问题。与固定的、基于结构的权重不同,每个节点在其邻居上计算动态的、基于内容的注意力分数。 + +## GAT:图注意力网络 + +- **GAT**(Veličković等,2018)计算每个节点与其邻居之间的注意力系数。对于节点 $i$ 和邻居 $j$: + +$$e_{ij} = \text{LeakyReLU}\left(\mathbf{a}^T \left[W\mathbf{h}_i \| W\mathbf{h}_j\right]\right)$$ + +- 其中 $W \in \mathbb{R}^{d' \times d}$ 是共享的线性变换,$\|$ 表示拼接,$\mathbf{a} \in \mathbb{R}^{2d'}$ 是可学习的注意力向量。分数 $e_{ij}$ 衡量节点 $j$ 的特征对节点 $i$ 的重要程度。 + +- 原始分数使用softmax在所有邻居之间进行归一化: + +$$\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}$$ + +- 这确保了每个节点邻域上的注意力权重之和为1,就像Transformer注意力一样(第7章)。节点更新后的特征为: + +$$\mathbf{h}_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W\mathbf{h}_j\right)$$ + +![GCN为所有邻居分配固定的等权重;GAT学习依赖数据的注意力权重](../images/gat_attention_weights.svg) + +- 与GCN的关键区别:权重 $\alpha_{ij}$ 是**从数据中学习**的,而非由图结构固定。节点可以学会关注信息量最大的邻居,同时忽略噪声或无关的邻居。 + +- 注意,注意力仅在边上计算(节点 $i$ 只关注其邻居 $\mathcal{N}(i)$),而不是在所有节点对之间。这使得计算量与边的数量成正比,而不是节点数的平方。 + +## 多头图注意力 + +- 正如在Transformer中(第7章),**多头注意力**并行运行 $K$ 个独立的注意力机制,每个都有自己的参数 $W^k$ 和 $\mathbf{a}^k$。结果在中间层进行拼接,在最终层取平均: + +$$\mathbf{h}_i' = \Big\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k W^k \mathbf{h}_j\right)$$ + +- 每个头可以关注邻域的不同方面:一个头可能关注结构特征,另一个关注语义相似性。这与Transformer中多头注意力的动机相同:不同的头捕获不同类型的关系。 + +- 使用 $K$ 个头和每个头输出维度 $d'$,拼接后的输出维度为 $K \times d'$。最后一层通常使用平均而不是拼接来产生固定大小的输出。 + +## GATv2:修复静态注意力 + +- 原始GAT有一个微妙的限制:其注意力函数是**静态的**(也称为基于排序的)。注意力分数取决于拼接 $[W\mathbf{h}_i \| W\mathbf{h}_j]$,但由于注意力向量 $\mathbf{a}$ 在拼接之后应用,它可以分解为两个独立的分量:$\mathbf{a}^T [W\mathbf{h}_i \| W\mathbf{h}_j] = \mathbf{a}_1^T W\mathbf{h}_i + \mathbf{a}_2^T W\mathbf{h}_j$。 + +- 这意味着对于给定节点 $i$,邻居的排序完全由邻居的特征 $\mathbf{h}_j$ 决定(项 $\mathbf{a}_1^T W\mathbf{h}_i$ 在 $i$ 的所有邻居中是常数)。注意力排名并不真正依赖于查询节点的特征。节点 $i$ 和节点 $k$ 将以完全相同的方式对同一组邻居进行排序,这限制了表达能力。 + +- **GATv2**(Brody等,2022)通过在注意力向量之前应用非线性函数来修复这个问题: + +$$e_{ij} = \mathbf{a}^T \text{LeakyReLU}\left(W \left[\mathbf{h}_i \| \mathbf{h}_j\right]\right)$$ + +- 将LeakyReLU移到计算内部意味着注意力分数是联合特征的非线性函数,不能分解为独立项。这使得注意力变为**动态**:邻居的排序现在依赖于特定的查询节点。GATv2严格比GAT更具表达能力,且没有额外的计算成本。 + +## 图Transformer + +- 标准消息传递GNN受到图拓扑的限制:一个节点只能关注其直接邻居。经过 $k$ 层后,来自 $k$ 跳邻居的信息已通过多个聚合步骤混合,失去了保真度。这种局部瓶颈(再加上文件3中的过平滑)限制了捕获长距离依赖关系的能力。 + +- **图Transformer**通过将**全局自注意力**应用于所有节点对(无论它们之间是否有边)来突破这个瓶颈。每个节点可以在单层中关注每个其他节点,就像标准Transformer一样(第7章)。 + +- 基本思想:将所有节点视为标记(token),应用Transformer自注意力: + +$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$ + +- 其中 $Q = XW_Q$,$K = XW_K$,$V = XW_V$ 是节点特征 $X$ 的查询、键和值投影(与第7章完全相同)。这是完全连接图(完全图 $K_n$,文件2)上的GNN。 + +- 问题:完全连接图忽略了实际的图结构。边信息(谁实际连接到谁)丢失了。两种方法恢复了这一点: + +- **Graphormer**(Ying等,2021)通过注意力分数中的**偏置项**将图结构注入Transformer: + +$$A_{ij} = \frac{(\mathbf{h}_i W_Q)(W_K^T \mathbf{h}_j^T)}{\sqrt{d_k}} + b_{\text{spatial}}(i, j) + b_{\text{edge}}(i, j)$$ + +- 空间偏置 $b_{\text{spatial}}$ 编码节点 $i$ 和 $j$ 之间的最短路径距离。边偏置 $b_{\text{edge}}$ 编码沿最短路径的边特征。此外,Graphormer使用**中心性编码**,将节点的度数添加到其输入嵌入中,为模型提供关于每个节点结构角色的信息。 + +- **GPS**(通用、强大、可扩展的图Transformer,Rampášek等,2022)在每一层中结合了局部消息传递和全局注意力: + +$$\mathbf{h}_i' = \text{MLP}\left(\mathbf{h}_i^{\text{MPNN}} + \mathbf{h}_i^{\text{Attention}}\right)$$ + +- 每一层同时应用标准GNN(用于局部结构)和Transformer(用于全局上下文),然后组合结果。这获得了两个世界的优点:来自消息传递的局部结构和来自注意力的长距离依赖关系。 + +## 位置编码与结构编码 + +- 序列上的Transformer使用位置编码(第7章)来注入顺序信息。图没有规范的顺序,因此需要特定于图的编码。 + +- **拉普拉斯特征向量编码**使用图拉普拉斯算子(文件2)的特征向量作为位置特征。$k$ 个最小的非平凡特征向量提供了图的谱嵌入:在图中"附近"的节点具有相似的特征向量值。这些被拼接到节点特征中。 + +- 一个微妙之处:拉普拉斯特征向量有符号模糊性(如果 $\mathbf{u}$ 是特征向量,$-\mathbf{u}$ 也是)。模型必须对这些符号翻转保持不变。解决方案包括在训练期间使用随机符号翻转作为数据增强,或学习符号不变的变换。 + +- **随机游走编码**计算从节点 $i$ 开始的随机游走经过 $k$ 步后返回节点 $i$ 的概率,对于 $k = 1, 2, \ldots, K$。这些概率编码了局部结构信息:密集簇中的节点具有高的返回概率,而稀疏区域中的节点返回概率低。着陆概率 $p_{ii}^{(k)} = (A_{\text{rw}}^k)_{ii}$,其中 $A_{\text{rw}} = D^{-1}A$ 是随机游走转移矩阵。 + +- **度数编码**简单地将节点度数作为一个特征添加。这出奇地有效,因为度数是一个强大的结构信号:叶节点(度数为1)、桥接节点和枢纽节点的行为不同。 + +- 这些编码提供了普通Transformer所缺乏的结构信息,使图Transformer在需要长距离推理的任务上能够超越标准消息传递GNN。 + +## 可扩展性 + +- GNN的基本可扩展性挑战在于图可能拥有数百万个节点和数十亿条边。在完整图上训练GNN需要将所有节点特征和整个邻接矩阵存储在内存中,这通常是不可行的。 + +- GNN的**小批量训练**比图像或序列更复杂,因为节点之间是相互连接的。朴素地采样一批节点需要它们的邻居(第1层)、邻居的邻居(第2层),依此类推。这种**邻域爆炸**意味着一个包含1000个目标节点的小批量可能需要计算图中数百万个节点。 + +- **邻域采样**(GraphSAGE风格,文件3)通过每层每个节点采样固定数量的邻居来限制爆炸。使用2层和每层15个样本,每个目标节点的子图最多有 $15^2 = 225$ 个节点,与完整图的大小无关。 + +- **Cluster-GCN**(Chiang等,2019)使用图聚类算法(例如METIS)将图划分为簇,然后一次在一个簇上训练。簇内边是密集的(大多数邻居在同一个簇内),因此子图捕获了相关结构。跨簇边通过偶尔包含簇之间的边来处理。 + +- **图Transformer的可扩展性**更困难,因为全局注意力是 $O(n^2)$ 的。对于具有数百万个节点的图,完整的注意力是不可行的。解决方案包括: + - 稀疏注意力模式(只关注图中距离最近的 $k$ 个节点) + - 线性注意力近似 + - 将局部消息传递(廉价,$O(|E|)$)与粗化图上的全局注意力(更少的节点)相结合 + +## 时序图与动态图 + +- 我们迄今为止研究的图是**静态**的:节点、边和特征都是固定的。但许多现实世界的图会**随时间演化**:新用户加入社交网络、金融交易创建边、交通模式全天变化、分子相互作用发生波动。 + +- **时序图**为每条边增加一个时间戳:$(i, j, t)$ 表示节点 $i$ 在时间 $t$ 与节点 $j$ 发生了交互。挑战在于学习同时捕获图结构和时序动态的表示。 + +- 存在两种范式: + +- **离散时间动态图(DTDG)**:图被表示为一系列快照 $G_1, G_2, \ldots, G_T$,每个时间步一个。GNN处理每个快照,RNN或时序注意力机制捕获快照间的演化。这很简单,但丢失了精细的时间信息(快照之间的事件丢失了),并且需要选择快照频率。 + +- **连续时间动态图(CTDG)**:事件被建模为带时间戳的交互流。每个事件 $(i, j, t)$ 在其发生的准确时间更新节点 $i$ 和 $j$ 的表示。这保留了所有时序信息。 + +- **时序图网络(TGN)**(Rossi等,2020)是领先的CTDG架构。每个节点维护一个**记忆状态** $\mathbf{s}_i(t)$,每当节点参与交互时更新: + +$$\mathbf{s}_i(t^+) = \text{GRU}\left(\mathbf{s}_i(t^-), \; \mathbf{m}_i(t)\right)$$ + +- 其中 $\mathbf{m}_i(t)$ 是从交互中计算出的消息(结合了两个节点的特征、边特征和时间编码)。GRU(第6章)选择性地保留和遗忘过去的信息,使记忆能够捕获长期模式,同时适应近期事件。 + +- **时间编码**表示自上次交互以来经过的时间,类似于Transformer中的位置编码(第7章)。常用方法使用可学习的傅里叶特征: + +$$\Phi(t) = \left[\cos(\omega_1 t), \sin(\omega_1 t), \ldots, \cos(\omega_d t), \sin(\omega_d t)\right]$$ + +- 这为模型提供了时间间隔的丰富表示:"该用户上次活跃是5分钟前"与"3个月前"以不同的方式嵌入。 + +- **时序图注意力(TGAT)**在节点的时间邻域上应用自注意力:一组最近的交互,每个交互同时按特征相关性(如GAT)和时间近度加权。来自遥远过去的交互自然地被降低权重。 + +- 应用包括欺诈检测(金融图中的异常交易模式)、交通预测(从历史流量模式预测拥堵)、社交网络动态(预测病毒内容传播)以及随时间推移的药物相互作用预测。 + +## 编程任务(使用CoLab或notebook) + +1. 从头实现一个单头GAT注意力。计算节点与其邻居之间的注意力权重,并验证权重之和为1。 +```python +import jax +import jax.numpy as jnp + +rng = jax.random.PRNGKey(0) +k1, k2, k3 = jax.random.split(rng, 3) + +n_nodes, d_in, d_out = 5, 4, 3 + +# 随机节点特征 +H = jax.random.normal(k1, (n_nodes, d_in)) + +# 可学习参数 +W = jax.random.normal(k2, (d_in, d_out)) * 0.5 +a = jax.random.normal(k3, (2 * d_out,)) * 0.5 + +# 邻接(节点0连接到1, 2, 3) +neighbours_of_0 = [1, 2, 3] + +# 变换特征 +Wh = H @ W # (n_nodes, d_out) + +# 计算节点0的注意力分数 +h_i = Wh[0] +scores = [] +for j in neighbours_of_0: + h_j = Wh[j] + e_ij = jnp.dot(a, jnp.concatenate([h_i, h_j])) + e_ij = jax.nn.leaky_relu(e_ij, negative_slope=0.2) + scores.append(float(e_ij)) + +scores = jnp.array(scores) +alpha = jax.nn.softmax(scores) + +print(f"原始分数: {scores}") +print(f"注意力权重: {alpha}") +print(f"权重之和: {alpha.sum():.4f}") + +# 加权聚合 +h_new = sum(alpha[k] * Wh[neighbours_of_0[k]] for k in range(len(neighbours_of_0))) +print(f"更新后的节点0特征: {h_new}") +``` + +2. 比较GCN(固定权重)和GAT(学习权重)的聚合。展示GAT可以为邻居分配不同的权重,而GCN统一对待它们。 +```python +import jax +import jax.numpy as jnp + +# 4个节点:节点0连接到1, 2, 3 +A = jnp.array([[0,1,1,1], + [1,0,0,0], + [1,0,0,0], + [1,0,0,0]], dtype=float) + +# 特征:节点1非常相关,节点2是噪声,节点3中等 +H = jnp.array([[0.0, 0.0], # 节点0 + [1.0, 0.0], # 节点1(信号) + [0.0, 0.0], # 节点2(噪声) + [0.5, 0.0]]) # 节点3(中等) + +# GCN:归一化邻接权重 +A_hat = A + jnp.eye(4) +D_inv = jnp.diag(1.0 / A_hat.sum(axis=1)) +gcn_weights = (D_inv @ A_hat)[0] # 节点0的权重 +print(f"GCN中节点0的权重: {gcn_weights}") +print(" → 所有邻居获得大致相等的权重") + +# GAT:学习到的注意力(模拟) +# 假设注意力机制学会关注节点1 +gat_weights = jnp.array([0.1, 0.7, 0.05, 0.15]) # 学习到的 +print(f"\nGAT中节点0的权重: {gat_weights}") +print(" → 最具信息量的节点1获得最多关注") + +gcn_output = gcn_weights @ H +gat_output = gat_weights @ H +print(f"\nGCN输出: {gcn_output} (被噪声稀释)") +print(f"GAT输出: {gat_output} (聚焦于信号)") +``` + +3. 演示位置编码的益处。计算图的拉普拉斯特征向量编码,展示结构相似的节点获得相似的编码。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# 杠铃图:两个团由一条桥连接 +n = 10 +A = jnp.zeros((n, n)) +# 团1:节点0-4 +for i in range(5): + for j in range(i+1, 5): + A = A.at[i,j].set(1).at[j,i].set(1) +# 团2:节点5-9 +for i in range(5, 10): + for j in range(i+1, 10): + A = A.at[i,j].set(1).at[j,i].set(1) +# 桥 +A = A.at[4,5].set(1).at[5,4].set(1) + +D = jnp.diag(A.sum(axis=1)) +L = D - A +eigenvalues, eigenvectors = jnp.linalg.eigh(L) + +# 使用前3个非平凡特征向量作为位置编码 +pe = eigenvectors[:, 1:4] + +print("拉普拉斯位置编码:") +for i in range(n): + group = "团1" if i < 5 else "团2" + bridge = " (桥)" if i in [4, 5] else "" + print(f" 节点 {i} ({group}{bridge}): {pe[i]}") + +plt.scatter(pe[:5, 0], pe[:5, 1], c="#3498db", s=80, label="团1") +plt.scatter(pe[5:, 0], pe[5:, 1], c="#e74c3c", s=80, label="团2") +plt.scatter(pe[[4,5], 0], pe[[4,5], 1], c="black", s=120, marker="*", + label="桥节点", zorder=5) +plt.legend(); plt.grid(True) +plt.title("拉普拉斯特征向量位置编码") +plt.xlabel("特征向量 1"); plt.ylabel("特征向量 2") +plt.show() +``` diff --git a/chapter 12: graph neural networks/05. 3d graph networks.md b/chapter 12: graph neural networks/05. 3d graph networks.md new file mode 100644 index 0000000..557f3a6 --- /dev/null +++ b/chapter 12: graph neural networks/05. 3d graph networks.md @@ -0,0 +1,274 @@ +# 3D图网络 + +*3D图网络将GNN扩展到具有空间几何的数据,其中必须正确处理旋转和平移。本章涵盖几何图、SE(3)/E(n)等变性、SchNet、DimeNet、EGNN、张量场网络以及分子性质预测、蛋白质结构、材料科学和药物发现中的应用——从3D物理世界中学习的架构。* + +- 文件3和4中的GNN操作于抽象图:节点有特征,边编码连接性,但没有3D空间的概念。社交网络图没有几何结构。但许多最具影响力的GNN应用涉及存在于**物理3D空间**中的数据:分子、蛋白质、晶体、点云。对于这些数据,节点的空间位置携带了抽象GNN所忽略的关键信息。 + +- 挑战在于3D数据具有**几何对称性**(文件1):旋转分子不会改变其性质,平移也是如此。3D GNN必须尊重这些对称性。一个会在旋转分子时改变的能量预测在物理上是错误的。 + +## 几何图 + +- **几何图**是嵌入在3D空间中的图。每个节点 $i$ 除了其特征向量 $\mathbf{h}_i$ 之外,还有一个位置 $\mathbf{r}_i \in \mathbb{R}^3$。边可以基于空间邻近性(连接距离在 $r_{\text{cut}}$ 内的节点)而不是基于显式的化学键来定义。 + +- 对于分子,几何图以原子为节点(特征包括:元素类型、电荷等),化学键为边。3D位置 $\mathbf{r}_i$ 是原子坐标,由量子力学或实验测量(X射线晶体学、冷冻电镜)确定。 + +- 对于点云(来自LiDAR或3D扫描仪,第8章和第11章),每个点是一个节点,具有位置和可选特征(颜色、强度)。边连接附近的点,形成**k最近邻(kNN)图**或半径图。 + +- 用于消息传递的关键几何量: + + - **原子间距离**:$d_{ij} = \|\mathbf{r}_i - \mathbf{r}_j\|$。距离对旋转和平移保持不变。具有相同原子间距离的两个分子具有相同的形状,无论朝向如何。 + + - **键角**:节点 $i$ 处向量 $\mathbf{r}_j - \mathbf{r}_i$ 和 $\mathbf{r}_k - \mathbf{r}_i$ 之间的角度 $\theta_{ijk}$。角度捕获了超越成对距离的局部几何结构。 + + - **二面角(扭转角)**:由 $(i, j, k)$ 和 $(j, k, l)$ 定义的平面之间的角度 $\phi_{ijkl}$。二面角捕获结构在3D中的扭转方式,对蛋白质主链几何结构至关重要。 + + - **相对位置向量**:$\mathbf{r}_{ij} = \mathbf{r}_j - \mathbf{r}_i$。这些是平移不变的,但不是旋转不变的。使用它们需要等变(而不仅仅是不变)的架构。 + +## SE(3) 和 E(n) 等变性 + +- 3D物理数据的对称群是**欧几里得群** $E(3)$,由所有旋转、反射和平移组成。子群 **$SE(3)$**(特殊欧几里得群)包括旋转和平移,但不包括反射。 + +- 3D GNN应该是: + - 对标量输出(能量、结合亲和力)**平移不变**:将所有原子平移相同向量不应改变预测。 + - 对标量输出**旋转不变**:旋转分子不应改变其能量。 + - 对向量/张量输出(力、偶极矩)**旋转等变**:旋转分子应使预测的力向量按相同旋转旋转。 + +![SE(3)等变性:旋转分子使标量预测(能量)保持不变,但使向量预测(力)相应旋转](../images/se3_equivariance.svg) + +- 形式上,对标量预测 $f$ 和旋转 $R \in SO(3)$: + +$$f(R\mathbf{r}_1, R\mathbf{r}_2, \ldots) = f(\mathbf{r}_1, \mathbf{r}_2, \ldots) \quad \text{(不变性)}$$ + +- 对向量预测 $\mathbf{F}$: + +$$\mathbf{F}(R\mathbf{r}_1, R\mathbf{r}_2, \ldots) = R \cdot \mathbf{F}(\mathbf{r}_1, \mathbf{r}_2, \ldots) \quad \text{(等变性)}$$ + +- 这些约束直接反映了文件1中的不变性/等变性框架,现在专门应用于3D旋转和平移群。 + +- 存在两种设计方法: + 1. **不变架构**:只使用不变几何特征(距离、角度)作为消息传递的输入。内部表示是标量(不变的)。简单高效,但不能在不破坏对称性的情况下产生向量输出。 + 2. **等变架构**:在整个网络中维护向量(以及更高阶张量)表示,确保每一层是等变的。表达能力更强,可以自然地预测向量和张量,但更加复杂。 + +## SchNet:基于距离的消息传递 + +- **SchNet**(Schütt等,2017)是基础性的不变3D GNN。其关键创新是**连续滤波器卷积**:不是使用固定的边类型集合(如分子GNN中的键类型),SchNet直接从原子间距离生成消息滤波器。 + +- 距离 $d_{ij}$ 首先使用**径向基函数(RBF)**扩展为特征向量: + +$$\text{RBF}(d_{ij}) = \left[\exp\left(-\gamma_1 (d_{ij} - \mu_1)^2\right), \ldots, \exp\left(-\gamma_K (d_{ij} - \mu_K)^2\right)\right]$$ + +- 每个基函数是一个以 $\mu_k$ 为中心、宽度为 $\gamma_k$ 的高斯函数。这类似于距离的可学习位置编码:连续距离被映射到一个高维特征空间,网络可以在其中学习距离相关的交互。中心 $\mu_k$ 通常从0到截止半径均匀分布。 + +- SchNet从节点 $j$ 到节点 $i$ 的消息为: + +$$\mathbf{m}_{j \to i} = \mathbf{h}_j \odot W_{\text{filter}}(\text{RBF}(d_{ij}))$$ + +- 其中 $W_{\text{filter}}$ 是一个将RBF扩展映射到滤波器向量的MLP,$\odot$ 是逐元素乘法(Hadamard乘积,第2章)。滤波器依赖于距离,因此附近的原子与远处的原子产生不同的交互。逐元素乘法类似于门控机制(第6章):依赖于距离的滤波器控制每个特征维度有多少通过。 + +- 由于SchNet只使用距离(不变的),整个模型自动对旋转和平移保持不变。除了这个设计选择之外,不需要对对称性进行特殊处理。 + +## DimeNet和SphereNet:角度和二面角 + +- 仅凭距离不能完全确定3D结构。两个不同的分子构象可以具有相同的成对距离但不同的键角(这就是"距离几何歧义"问题)。**DimeNet**(Gasteiger等,2020)将**键角**纳入消息传递。 + +- DimeNet使用**定向消息传递**:消息沿有向边流动,边 $(j \to i)$ 上的消息受边 $(k \to j)$ 和 $(j \to i)$ 之间的角度影响: + +$$\mathbf{m}_{kj \to ji} = f\left(\mathbf{m}_{kj}, d_{ji}, \theta_{kji}\right)$$ + +- 角度 $\theta_{kji}$ 使用球贝塞尔函数和球谐函数(球面上角度信息的自然基,类似于距离的RBF)进行扩展。这使模型在保持不变性的同时能够访问方向信息。 + +- **SphereNet**(Liu等,2022)更进一步,包含**二面角** $\phi_{lkji}$,捕获完整的3D扭转结构。层次结构为: + - 距离 → 捕获成对邻近性 + - 角度 → 捕获局部几何结构(弯曲 vs. 线性) + - 二面角 → 捕获3D扭转(对蛋白质主链、药物结合至关重要) + +- 每个层次增加了几何分辨率,但计算复杂度也随之增加(距离为 $O(|E|)$,角度为 $O(|E| \cdot k)$,二面角为 $O(|E| \cdot k^2)$,其中 $k$ 是平均度数)。 + +## E(n)等变GNN(EGNN) + +- **EGNN**(Satorras等,2021)采用等变方法:它不只使用不变特征,而是在每一层同时更新节点特征**和**节点位置,在整个过程中保持等变性。 + +- 节点 $i$ 的EGNN更新: + +$$\mathbf{m}_{ij} = \phi_e\left(\mathbf{h}_i, \mathbf{h}_j, d_{ij}^2, a_{ij}\right)$$ + +$$\mathbf{r}_i' = \mathbf{r}_i + C \sum_{j \neq i} (\mathbf{r}_i - \mathbf{r}_j) \cdot \phi_r(\mathbf{m}_{ij})$$ + +$$\mathbf{h}_i' = \phi_h\left(\mathbf{h}_i, \sum_j \mathbf{m}_{ij}\right)$$ + +- 关键在于位置更新:节点位置通过相对位置向量 $(\mathbf{r}_i - \mathbf{r}_j)$ 的加权和进行调整。权重来自消息函数 $\phi_r$,该函数仅依赖于不变的量(特征和距离)。这种构造是**可证明等变的**:如果所有输入位置被旋转 $R$,则所有输出位置被相同的 $R$ 旋转。 + +- EGNN的优雅之处在于它不显式使用球谐函数或不可约表示就实现了等变性。相对位置向量携带方向信息,不变的消息函数控制如何使用该方向信息。 + +- 这种简洁性是有代价的:EGNN只使用向量表示(1阶)。它无法在未经扩展的情况下表示更高阶的张量,如四极矩或应力张量。 + +## 张量场网络与高阶表示 + +- **张量场网络**(Thomas等,2018)及其后继者(**SE(3)-Transformers**、**MACE**、**Equiformer**)使用旋转群的**不可约表示**的完整机制来构建等变层。 + +- 在表示论中(联系到第2章的线性代数),3D中的旋转可以分解为以整数阶 $\ell$ 为特征的不可约分量: + - $\ell = 0$:标量(1个分量,不变)。能量、电荷。 + - $\ell = 1$:向量(3个分量,像位置向量一样旋转)。力、偶极矩。 + - $\ell = 2$:秩2对称无迹张量(5个分量)。四极矩、应力张量。 + - 更高的 $\ell$:捕获越来越复杂的角结构。 + +- 这些被称为**球面张量**,它们通过**Wigner-D矩阵** $D^\ell(R)$ 在旋转 $R$ 下变换:标量不变,向量由 $R$ 旋转,秩2张量由更复杂的矩阵旋转。 + +- 使用球面张量的**等变消息传递**使用**Clebsch-Gordan张量积**来组合不同阶的特征: + +$$(\mathbf{f}^{\ell_1} \otimes \mathbf{f}^{\ell_2})^{\ell_{\text{out}}} = \sum_{m_1, m_2} C^{\ell_{\text{out}}, m_{\text{out}}}_{\ell_1, m_1, \ell_2, m_2} \cdot f^{\ell_1}_{m_1} \cdot f^{\ell_2}_{m_2}$$ + +- Clebsch-Gordan系数 $C$ 是固定的数学常数,确保张量积是等变的。这是SO(3)等变版本的矩阵乘法。 + +- **MACE**(Batatia等,2022)使用高阶消息(多个邻居特征的乘积)以更少的消息传递层达到高精度。通过构建体序相互作用(距离的2体、角度的3体、张量积的多体),MACE高效地捕获了复杂的原子间相互作用。 + +- **Equiformer**(Liao & Smidt,2023)将等变球面张量特征与Transformer注意力机制(文件4)相结合,创建了SE(3)等变的图Transformer。注意力分数从不变量特征计算,而值聚合在等变张量特征上进行。 + +## 应用 + +- **分子性质预测**:给定分子的3D结构,预测性质如能量、力、偶极矩、HOMO-LUMO能隙、毒性、溶解度。这是3D GNN最成熟的应用。在量子化学数据集(QM9、OC20)上训练的模型在许多性质上达到了化学精度,实现了对数百万候选分子的虚拟筛选。 + +- **分子动力学加速**:使用量子力学(密度泛函理论,DFT)计算原子间的力极其昂贵(对 $n$ 个电子为 $O(n^3)$)。训练用于预测力的3D GNN可以在分子动力学模拟期间替代DFT,实现 $10^3$–$10^6$ 的加速,同时保持接近DFT的精度。这使得能够模拟更大的系统和更长的时间尺度,揭示传统方法无法观测的现象。 + +- **蛋白质结构**:蛋白质是折叠成复杂3D结构的氨基酸链。蛋白质主链是一个几何图,其中节点是残基,边连接空间上邻近的残基。3D GNN用于蛋白质功能预测、结合位点识别和蛋白质设计(逆折叠:给定期望结构,预测氨基酸序列)。**AlphaFold**使用几何和基于图的推理从序列预测蛋白质结构。 + +- **材料科学与催化**:晶体材料具有周期性的3D结构。GNN对重复晶胞进行建模并预测材料性质:带隙、形成能、机械强度。开放催化剂项目(OC20/OC22)对GNN进行基准测试,预测催化表面上的吸附能,加速寻找用于可再生能源的新型催化剂。 + +- **药物发现**:3D GNN预测药物分子如何与靶蛋白结合。结合亲和力取决于药物与蛋白质结合口袋之间的3D形状互补性和化学相互作用。**DiffDock**等模型使用等变GNN与扩散模型(第8章)来预测结合姿态(药物在蛋白质口袋中的3D朝向)。 + +## 图生成 + +- 上述所有架构**分析**现有图。**图生成**创建新的图:设计具有期望性质的分子、生成用于测试的合成社交网络或提出新的蛋白质结构。这是图级别预测的生成对应任务。 + +- 挑战在于图是离散的、大小可变且组合的。生成图意味着决定要创建多少个节点、它们具有什么特征以及哪些对要连接。可能的图空间随节点数量超指数增长。 + +- **自回归生成**一次构建一个节点(或一条边)。**GraphRNN**(You等,2018)顺序地生成图:RNN维护一个状态,每一步生成一个新节点,并决定将其连接到哪些现有节点。生成顺序为本来无序的图施加了人工序列,但BFS排序通过保持最近生成的节点相关性来帮助解决问题。 + +- **基于VAE的生成**将图编码到连续潜在空间(使用GNN编码器),然后从采样的潜在向量解码新图。**GraphVAE**一次性生成一个概率邻接矩阵 $\hat{A} \in [0, 1]^{n \times n}$,但这需要 $O(n^2)$ 规模并产生需要阈值化的密集输出。潜在空间允许平滑插值:在两个分子嵌入之间移动会产生化学上有效的中间结构。 + +- **基于扩散的生成**将扩散框架(第8章)应用于图。前向过程逐渐向节点特征和边结构添加噪声。反向过程学习去噪,从噪声中生成有效的图。**DiGress**(Vignac等,2023)对节点类型和边类型应用离散扩散,自然地处理图数据的分类性质。 + +- 对于**分子生成**,关键约束是**化学有效性**:生成的分子必须遵守化合价规则(碳形成4个键,氧形成2个,等等)。**Junction Tree VAE(JT-VAE)**等方将分子分解为有效子结构(环、链、官能团),并通过组装这些构建块来生成,通过构造保证有效性。 + +- **目标导向生成**优化特定性质:生成对靶蛋白具有高结合亲和力、低毒性和良好溶解度的分子。这在一个循环中结合了图生成与性质预测(使用3D GNN作为性质评估器):生成 → 评估 → 精炼。强化学习(第6章)或贝叶斯优化指导着化学空间的搜索。 + +- **DiffDock**(Corso等,2023)使用SE(3)等变扩散来预测药物分子如何对接入蛋白质结合口袋。该模型通过从随机位置去噪来生成3D结合姿态(药物相对于蛋白质的位置和朝向),将本文件中的3D等变网络与第8章的扩散框架相结合。 + +## 编程任务(使用CoLab或notebook) + +1. 构建一个使用原子间距离的简单不变3D消息传递层。将其应用于一个小分子(水:H-O-H),并验证输出对旋转是不变的。 +```python +import jax +import jax.numpy as jnp + +# 水分子:O在原点,两个H原子 +positions = jnp.array([[0.0, 0.0, 0.0], # O + [0.96, 0.0, 0.0], # H1 + [-0.24, 0.93, 0.0]]) # H2 + +# 节点特征:[原子序数] +features = jnp.array([[8.0], [1.0], [1.0]]) + +# 计算成对距离(不变的) +def pairwise_distances(pos): + diff = pos[:, None, :] - pos[None, :, :] + return jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-8) + +# 简单的基于距离的消息传递 +def invariant_message_pass(features, positions): + dists = pairwise_distances(positions) + # 具有4个中心的RBF扩展 + centres = jnp.array([0.5, 1.0, 1.5, 2.0]) + rbf = jnp.exp(-5.0 * (dists[:, :, None] - centres[None, None, :]) ** 2) + + # 消息:由距离相关滤波器加权的特征 + messages = jnp.einsum("ij,jd->id", rbf.sum(axis=-1), features) + return messages + +output1 = invariant_message_pass(features, positions) + +# 将分子绕z轴旋转90度 +R = jnp.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]], dtype=float) +rotated_positions = (R @ positions.T).T + +output2 = invariant_message_pass(features, rotated_positions) + +print(f"原始输出:\n{output1}") +print(f"\n旋转后输出:\n{output2}") +print(f"\n不变性: {jnp.allclose(output1, output2, atol=1e-5)}") +``` + +2. 计算三个原子之间的键角,并验证其对旋转不变。 +```python +import jax.numpy as jnp + +def bond_angle(r_i, r_j, r_k): + """节点j处边j->i和j->k之间的角度。""" + v1 = r_i - r_j + v2 = r_k - r_j + cos_angle = jnp.dot(v1, v2) / (jnp.linalg.norm(v1) * jnp.linalg.norm(v2)) + return jnp.arccos(jnp.clip(cos_angle, -1, 1)) + +# 三个原子 +r1 = jnp.array([1.0, 0.0, 0.0]) +r2 = jnp.array([0.0, 0.0, 0.0]) +r3 = jnp.array([0.0, 1.0, 0.0]) + +angle_original = bond_angle(r1, r2, r3) +print(f"原始角度: {jnp.degrees(angle_original):.1f}°") + +# 应用随机旋转 +R = jnp.array([[0.36, 0.48, -0.80], + [-0.80, 0.60, 0.00], + [0.48, 0.64, 0.60]]) +r1_rot, r2_rot, r3_rot = R @ r1, R @ r2, R @ r3 + +angle_rotated = bond_angle(r1_rot, r2_rot, r3_rot) +print(f"旋转后角度: {jnp.degrees(angle_rotated):.1f}°") +print(f"不变性: {jnp.allclose(angle_original, angle_rotated, atol=1e-4)}") +``` + +3. 演示等变位置更新(EGNN风格)。使用距离加权的相对向量更新节点位置,并验证等变性。 +```python +import jax +import jax.numpy as jnp + +def egnn_position_update(positions, features): + """简单的EGNN风格等变位置更新。""" + n = positions.shape[0] + new_positions = jnp.zeros_like(positions) + + for i in range(n): + shift = jnp.zeros(3) + for j in range(n): + if i != j: + r_ij = positions[i] - positions[j] + d_ij = jnp.linalg.norm(r_ij) + # 基于距离的权重(简单:反比距离) + weight = 1.0 / (d_ij + 1.0) + # 按特征相似度缩放 + feat_sim = jnp.dot(features[i], features[j]) + shift = shift + weight * feat_sim * r_ij + new_positions = new_positions.at[i].set(positions[i] + 0.1 * shift) + + return new_positions + +# 3个原子 +pos = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]) +feat = jnp.array([[1.0, 0.5], [0.5, 1.0], [0.8, 0.3]]) + +# 更新位置 +pos_new = egnn_position_update(pos, feat) + +# 现在旋转输入、更新,并检查输出是否一致地旋转 +R = jnp.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) +pos_rot = (R @ pos.T).T +pos_new_from_rot = egnn_position_update(pos_rot, feat) + +# 应与旋转原始输出相同 +pos_new_then_rot = (R @ pos_new.T).T + +print(f"先更新再旋转:\n{jnp.round(pos_new_then_rot, 4)}") +print(f"\n先旋转再更新:\n{jnp.round(pos_new_from_rot, 4)}") +print(f"\n等变性: {jnp.allclose(pos_new_then_rot, pos_new_from_rot, atol=1e-4)}") +``` diff --git a/chapter 13: computing and OS/01. discrete maths.md b/chapter 13: computing and OS/01. discrete maths.md new file mode 100644 index 0000000..f752ca3 --- /dev/null +++ b/chapter 13: computing and OS/01. discrete maths.md @@ -0,0 +1,253 @@ +# 离散数学 + +*离散数学是关于可数、分离结构的数学,是计算构建的基础。本文涵盖命题逻辑与谓词逻辑、证明技巧、集合、关系、函数、图论基础以及递推关系。* + +- 在前面的章节中,我们研究了连续数学:微积分(第3章)、概率分布(第5章)以及实值参数的优化(第6章)。但计算机本质上是**离散**机器。它们存储比特(0或1),处理整数,遵循分支逻辑,并操作有限数据结构。**离散数学**提供了推理这些结构的形式化语言。 + +- 本章所有内容都建立在离散数学之上:处理器逻辑门是布尔代数,调度算法需要正确性证明,内存管理使用集合运算,算法分析需要递推关系。 + +## 命题逻辑 + +- **命题逻辑**是真假语句的代数。一个**命题**是一个要么为真(T)要么为假(F)的陈述,绝不会两者兼有。"天在下雨"是一个命题。"现在几点了?"则不是(它是一个问句,不是具有真值的陈述)。 + +- 命题可以通过**逻辑连接词**进行组合: + + - **与**(合取,$p \wedge q$):仅当$p$和$q$都为真时为真。 + - **或**(析取,$p \vee q$):当$p$或$q$至少一个为真时为真。 + - **非**(否定,$\neg p$):翻转真值。 + - **蕴含**(蕴涵,$p \to q$):仅当$p$为真且$q$为假时为假。"如果下雨,地就是湿的"只有在下了雨而地却是干的时候才被违反。 + - **当且仅当**(双条件,$p \leftrightarrow q$):当两者真值相同时为真。 + +- **真值表**穷举列出所有可能的输入组合及相应的输出。对于$n$个命题,该表有$2^n$行。这就是我们验证逻辑等价性的方式: + +| $p$ | $q$ | $p \wedge q$ | $p \vee q$ | $p \to q$ | +|-----|-----|--------------|------------|-----------| +| T | T | T | T | T | +| T | F | F | T | F | +| F | T | F | T | T | +| F | F | F | F | T | + +- 蕴含行中$p$为假的情况值得关注:$F \to q$无论$q$为何值都为真。这就是**空真**。"如果猪会飞,那我就是英国国王"在逻辑上为真,因为前提为假。这看起来违反直觉,但对数学推理至关重要。 + +- **逻辑等价式**是对所有真值都成立的恒等式: + + - **德摩根定律**:$\neg(p \wedge q) \equiv \neg p \vee \neg q$ 和 $\neg(p \vee q) \equiv \neg p \wedge \neg q$。要否定一个AND,分别否定每个部分并切换为OR(反之亦然)。这些直接出现在编程中:`!(a && b)` 等价于 `(!a || !b)`。 + + - **逆否命题**:$p \to q \equiv \neg q \to \neg p$。"如果下雨,地就是湿的"等价于"如果地不是湿的,那么就没下雨。"这是一个强大的证明技巧。 + + - **双重否定**:$\neg(\neg p) \equiv p$。 + + - **分配律**:$p \wedge (q \vee r) \equiv (p \wedge q) \vee (p \wedge r)$。 + +- 一个总是为真(对所有真值指派)的公式是**重言式**。总是为假的公式是**矛盾式**。有时真有时假的公式是**偶然式**。例如,$p \vee \neg p$是重言式,$p \wedge \neg p$是矛盾式。 + +## 谓词逻辑与量词 + +- 命题逻辑无法表达关于集合中*所有*或*某些*元素的陈述。"每个大于2的素数都是奇数"需要**谓词逻辑**,它用变量、谓词和量词扩展了命题逻辑。 + +- **谓词**是依赖于变量的陈述:$P(x)$ = "$x$是偶数。"当给定$x$一个具体值时,它成为一个命题:$P(4)$为真,$P(7)$为假。 + +- **量词**表达范围: + + - **全称量词**($\forall$):"对于所有。" $\forall x \, P(x)$ 表示"$P(x)$对论域中的每一个$x$成立。" + - **存在量词**($\exists$):"存在。" $\exists x \, P(x)$ 表示"至少存在一个$x$使得$P(x)$为真。" + +- 否定量词会翻转它们:$\neg(\forall x \, P(x)) \equiv \exists x \, \neg P(x)$。"不是所有人都通过了"意味着"有人没通过。"而 $\neg(\exists x \, P(x)) \equiv \forall x \, \neg P(x)$。"没有完美的算法"意味着"每个算法都有缺陷。" + +- 嵌套量词表达复杂关系。$\forall x \, \exists y \, (y > x)$ 表示"对于每个数,都有一个更大的数"(对整数成立)。顺序很重要:$\exists y \, \forall x \, (y > x)$ 表示"存在一个比所有其他数都大的数"(对整数不成立)。 + +- 谓词逻辑是形式化规约的语言。当我们说一个算法是"正确"的,意味着 $\forall \text{输入} \, x, \, \text{输出}(x) = \text{期望输出}(x)$。当我们说它"终止",意味着 $\forall x \, \exists t \, \text{终止}(x, t)$。 + +## 证明技巧 + +- **证明**是确立一个陈述真理性、毫无疑义的逻辑论证。与经验证据(仅展示在某些测试案例下有效)不同,证明保证在所有情况下成立。这是计算机科学中正确性的标准。 + +- **直接证明**:假设前提,通过逻辑步骤推导出结论。要证明"如果$n$是偶数,那么$n^2$是偶数":假设$n = 2k$对于某个整数$k$,则$n^2 = 4k^2 = 2(2k^2)$,这是偶数。 + +- **反证法**:假设该陈述为假,推导出矛盾。要证明$\sqrt{2}$是无理数:假设$\sqrt{2} = a/b$(已约简)。那么$2 = a^2/b^2$,所以$a^2 = 2b^2$,意味着$a^2$是偶数,所以$a$是偶数,设$a = 2c$。那么$4c^2 = 2b^2$,所以$b^2 = 2c^2$,意味着$b$也是偶数。但我们已经假设$a/b$是约简形式——矛盾。 + +- **归纳证明**:通过证明以下两点来证明一个陈述对所有自然数成立:(1)**基础情形**成立(通常$n = 0$或$n = 1$),和(2)**归纳步骤**:如果陈述对$n = k$成立(归纳假设),那么它对$n = k + 1$也成立。 + +- 例如,证明 $\sum_{i=1}^{n} i = \frac{n(n+1)}{2}$: + - 基础情形:$n = 1$:$1 = \frac{1 \cdot 2}{2} = 1$。成立。 + - 归纳步骤:假设 $\sum_{i=1}^{k} i = \frac{k(k+1)}{2}$。那么 $\sum_{i=1}^{k+1} i = \frac{k(k+1)}{2} + (k+1) = \frac{k(k+1) + 2(k+1)}{2} = \frac{(k+1)(k+2)}{2}$。这正是$n = k+1$时的公式。证明完成。 + +- 归纳法是证明递归算法和数据结构性质的主力工具。每个递归算法都暗含一个归纳正确性证明:基础情形是终止条件,归纳步骤是递归调用。 + +- **强归纳法**假设该陈述对所有不大于$k$的值都成立(不仅仅是$k$),然后证明它对$k + 1$成立。当递归依赖于多个之前的值时,这很有用。 + +- **鸽巢原理**:如果把$n+1$个物体放入$n$个盒子中,至少有一个盒子包含两个物体。简单但出奇地强大。它证明了在任何13个人中,至少有两个人出生月份相同。在网络中,它证明了当项目数超过桶数时,哈希冲突是不可避免的。 + +## 集合 + +- **集合**是不同元素的无序收集。集合是数学中最原始的数据结构,支撑着从类型系统到数据库查询的一切。 + +- **集合运算**(联系第5章,我们在那里用这些进行概率计算): + + - **并集** $A \cup B$:在$A$或$B$或两者中的元素。 + - **交集** $A \cap B$:同时在$A$和$B$中的元素。 + - **补集** $\bar{A}$:不在$A$中的元素(相对于一个全集)。 + - **差集** $A \setminus B$:在$A$中但不在$B$中的元素。 + - **笛卡尔积** $A \times B$:所有有序对$(a, b)$,其中$a \in A, b \in B$。 + +- **幂集** $\mathcal{P}(A)$ 是$A$的所有子集构成的集合。如果 $|A| = n$,那么 $|\mathcal{P}(A)| = 2^n$。对于 $A = \{1, 2\}$:$\mathcal{P}(A) = \{\emptyset, \{1\}, \{2\}, \{1, 2\}\}$。 + +- **基数**衡量集合大小。有限集具有整数基数。无限集有不同的大小:自然数$\mathbb{N}$和有理数$\mathbb{Q}$是**可数无穷**(可以列举),而实数$\mathbb{R}$是**不可数无穷**(无法列举,由康托尔的对角线论证证明)。这种区别在可计算性理论中很重要:存在不可数多个函数,但只有可数多个程序,因此大多数函数是不可计算的。 + +## 关系 + +- 集合$A$上的**关系**$R$是$A \times A$的一个子集:指定哪些元素相关联的有序对集合。例如,整数上的$\leq$是集合 $\{(a, b) : a \leq b\}$。 + +- 关系的重要性质: + + - **自反性**:每个元素与自身相关。对所有$a$有$a R a$。例:$\leq$(每个数$\leq$自身)。 + - **对称性**:如果$a R b$则$b R a$。例:"是……的兄弟姐妹。" + - **反对称性**:如果$a R b$且$b R a$则$a = b$。例:$\leq$。 + - **传递性**:如果$a R b$且$b R c$则$a R c$。例:$<$、$\leq$、"是……的祖先。" + +- **等价关系**是自反、对称且传递的。它将集合划分为**等价类**,其中同一类中的所有元素彼此相关,但与不同类中的元素无关。模运算是一个等价关系:$a \equiv b \pmod{n}$ 将整数划分为$n$个类。编程语言中的类型等价是一个等价关系。 + +- **偏序**是自反、反对称且传递的。它定义了一个"小于等于"结构,可能会使某些元素不可比较。文件系统目录构成一个偏序(父-子),但同级目录是不可比较的。**全序**是每一对元素都可比较的偏序(如整数上的$\leq$)。 + +- 偏序在并发中至关重要:事件上的"先于发生"关系是一个偏序。不由先于发生关系排序的事件是并发的,可能以任意相对顺序执行。 + +## 函数 + +- **函数** $f: A \to B$ 将$A$(定义域)中的每个元素映射到$B$(陪域)中的恰好一个元素。函数是确定性计算的数学模型:给定一个输入,恰好有一个输出。 + +- **单射**(一对一):不同的输入总是产生不同的输出。$f(a) = f(b) \implies a = b$。无损压缩是单射的:不同的输入必须压缩成不同的输出(否则无法唯一解压)。 + +- **满射**(到上):$B$中的每个元素都被$A$中的某个元素命中。值域等于陪域。将字符串映射到256位哈希的哈希函数,如果字符串数少于可能的哈希数,则不是满射。 + +- **双射**:既是单射又是满射。$A$和$B$之间的一一对应。双射具有逆函数。加密必须是双射的:每个明文映射到唯一的密文,而解密函数就是逆函数。 + +- **复合** $(g \circ f)(x) = g(f(x))$:先应用$f$,再应用$g$。函数复合是可结合的(第2章:就像矩阵乘法是可结合的一样)。软件中的管道就是函数复合:数据流经一系列变换。 + +## 图论基础 + +- 我们在第12章(图神经网络)中广泛介绍了图,包括邻接矩阵、图类型、拉普拉斯矩阵和谱理论。这里我们专注于与CS相关的**算法**和**结构**性质。 + +- **树**是没有环的连通图。等价地,它有$n$个节点和$n-1$条边。树是文件系统、XML/HTML文档、决策过程和递归分解的结构。**有根树**有一个指定的根节点;每个其他节点恰好有一个父节点。 + +- 图$G$的**生成树**是包含$G$所有节点并使用其边子集的一棵树。**最小生成树(MST)**最小化总边权。Kruskal算法(对边排序,贪心地添加不形成环的最轻边)和Prim算法(从起始节点开始扩展树,总是添加连接到新节点的最轻边)都能在$O(|E| \log |V|)$内找到MST。 + +- **平面性**:如果一个图可以画在平面上而边不相交,则是平面图。根据**欧拉公式**,对于连通平面图:$|V| - |E| + |F| = 2$,其中$|F|$是面的数量(区域,包括外部面)。这意味着平面图的$|E| \leq 3|V| - 6$,因此平面图是稀疏的。电路板布线和地图着色利用了平面性。 + +- **图着色**为节点分配颜色,使得没有两个相邻节点共享相同的颜色。所需的最小颜色数是**色数** $\chi(G)$。**四色定理**指出任何平面图的 $\chi(G) \leq 4$。在CS中,图着色模拟寄存器分配(将变量分配到CPU寄存器,使得同时活跃的变量获得不同的寄存器)和调度(将任务分配到时间槽,使得冲突的任务不重叠)。 + +- **欧拉路径**恰好访问每条边一次。当且仅当图中恰好有0个或2个奇数度节点时,欧拉路径存在。**哈密顿路径**恰好访问每个节点一次。确定哈密顿路径是否存在是NP完全的——这是CS中的经典难题之一。这种对比(欧拉:多项式,哈密顿:NP完全)说明了听起来相似的问题可能具有截然不同的计算复杂度。 + +## 递推关系 + +- **递推关系**定义一个序列,其中每一项依赖于前面的项。它们自然地从递归算法中产生。 + +- 最简单的例子:$T(n) = T(n-1) + 1$,其中 $T(0) = 0$。展开:$T(n) = T(n-1) + 1 = T(n-2) + 2 = \cdots = n$。这是$O(n)$,即简单循环的时间复杂度。 + +- **归并排序**给出 $T(n) = 2T(n/2) + O(n)$:将数组分成两半(两个大小为$n/2$的子问题),递归排序每一半,然后合并($O(n)$工作)。解为 $T(n) = O(n \log n)$。 + +- **主定理**求解形式为 $T(n) = aT(n/b) + O(n^d)$ 的递推式: + + - 如果 $d > \log_b a$:$T(n) = O(n^d)$(每层的工作占主导) + - 如果 $d = \log_b a$:$T(n) = O(n^d \log n)$(工作在各层间平衡) + - 如果 $d < \log_b a$:$T(n) = O(n^{\log_b a})$(子问题的数量占主导) + +- 对于归并排序:$a = 2, b = 2, d = 1$。由于 $d = \log_2 2 = 1$,我们处于平衡情况:$T(n) = O(n \log n)$。 + +- **斐波那契递推** $F(n) = F(n-1) + F(n-2)$,其中 $F(0) = 0, F(1) = 1$,封闭形式解为 $F(n) = \frac{\phi^n - \psi^n}{\sqrt{5}}$,其中 $\phi = \frac{1+\sqrt{5}}{2}$(黄金比例)且 $\psi = \frac{1-\sqrt{5}}{2}$。这表明斐波那契数列以 $O(\phi^n)$ 指数增长,这就是为什么朴素递归斐波那契指数级慢。 + +- **组合数学**(排列、组合、二项式定理和容斥原理)在第5章(概率)中介绍。这些计数技术对算法分析至关重要(有多少种可能的输入?需要多少次比较?),但我们在此不再重复。 + +## 可计算性 + +- 并非所有事情都能被计算。这是整个数学中最深刻的结论之一,它设定了计算机能力的基本极限。 + +- **图灵机**是计算的抽象模型:一条无限长的单元格磁带(每个单元格包含一个符号),一个读写头,以及一组带转移规则的有限状态。尽管简单,图灵机可以计算任何实际计算机能计算的任何东西。这就是**邱奇-图灵论题**:任何有效可计算的函数都可以由图灵机计算。 + +- 每种编程语言(Python、C、Haskell)都是**图灵完备**的:它可以模拟图灵机,从而计算任何可计算的东西。语言之间的区别在于便利性、速度和安全性,而不在于它们根本上能计算什么。 + +- **停机问题**询问:给定一个程序和一个输入,该程序最终会停止,还是永远运行?图灵(1936)证明不存在能普遍解决这个问题的算法。证明采用反证法:假设存在一个停机检测器 $H(P, x)$。构造一个程序 $D$,它运行 $H(D, D)$ 并做与 $H$ 所说的相反的事。如果 $H$ 说 $D$ 停机,$D$ 就永远循环。如果 $H$ 说 $D$ 循环,$D$ 就停机。矛盾。 + +- 这不是当前技术的局限;这是一个数学上的不可能性。无论多少计算、多少聪明才智、或多少人工智能,都无法普遍解决停机问题。它是哥德尔不完备定理在计算机科学中的类比。 + +- 实际后果:你无法编写一个完美的死锁检测器、一个完美的病毒扫描器或一个完美的优化编译器。每一个都需要通用地解决停机问题(或一个等价的不判定问题)。实际工具使用启发式方法和近似方法,在常见情况下有效,但不能保证对所有输入都正确。 + +- 如果一个问题存在一个总是能给出正确是/否答案并终止的算法,则它是**可判定的**。如果不存在这样的算法,则是**不可判定的**。停机问题是不可判定的。素数测试是可判定的。大多数编程语言中的类型检查是可判定的(通过设计)。 + +## 复杂度理论 + +- 即使在可计算的问题中,有些也远比其他的难。**复杂度理论**根据解决问题所需的资源(时间、空间)随输入增长而分类问题。 + +![P、NP和NP完全:P包含在NP中,NP完全位于边界处,P是否等于NP是核心开放问题](../images/p_np_complexity.svg) + +- **P**(多项式时间):能在 $O(n^k)$ 时间内解决的问题,$k$为某个常数。排序($O(n \log n)$)、最短路径($O(|V|^2)$)、矩阵乘法($O(n^3)$)。这些被认为是"高效"或"可处理的。" + +- **NP**(非确定性多项式时间):一个拟议的解答能在多项式时间内**验证**的问题,即使**找到**解答可能需要指数时间。例如,给定一个声称的哈密顿路径,你可以通过检查每条边在 $O(n)$ 时间内验证它。但找到一条可能需尝试指数多个可能性。 + +- P中的每个问题也在NP中(如果你能快速解决它,你当然能快速验证一个解答)。核心问题是 $P = NP$ 是否成立:每个能快速验证解答的问题是否也能快速求解?这是计算机科学中最重要的开放问题,获得克莱数学研究所100万美元的千禧年大奖。 + +- 大多数专家相信 $P \neq NP$,意味着有些问题本质上比验证更难解决。如果 $P = NP$,密码学将崩溃(破解加密属于NP),而优化、调度和药物设计将变得异常简单。 + +- **NP完全**问题是NP中最难的问题。一个问题如果是NP完全的,则:(1)它在NP中,且(2)所有其他NP问题可以在多项式时间内**归约**到它。如果你能高效解决任何一个NP完全问题,你就能解决所有NP完全问题(从而 $P = NP$)。 + +- **归约**将一个问题转换为另一个问题。如果问题A归约到问题B,那么B至少和A一样难。Cook(1971)证明了**SAT**(布尔可满足性:给定一个逻辑公式,是否存在使公式为真的变量赋值?)是NP完全的。Karp(1972)通过将SAT归约到每个问题,证明了其他21个经典问题是NP完全的。 + +- 著名的NP完全问题: + - **旅行商问题(TSP)**:找到访问所有城市恰好一次的最短路线。 + - **图着色**:用$k$种颜色为节点着色,使得没有相邻节点共享同一颜色($k \geq 3$)。 + - **子集和问题**:给定一组整数,是否存在一个子集其和等于目标值? + - **布尔可满足性(SAT)**:是否存在使逻辑公式为真的真值赋值? + - **哈密顿路径**(上文图论中提到的)。 + +- 当你在实践中遇到NP完全问题时,你不会对大规模输入精确求解。相反,你使用:**近似算法**(找到保证在最优解一定倍数范围内的解)、**启发式方法**(贪心、局部搜索、模拟退火)或**特例求解器**(许多NP完全问题对受限输入很容易)。例如,现代SAT求解器尽管在最坏情况下是指数复杂度,但通过利用实际实例中的结构,通常能解决拥有数百万变量的实例。 + +- **NP困难**问题至少和NP完全问题一样难,但可能不在NP中(它们的解甚至可能不能在多项式时间内验证)。NP完全问题的优化版本通常是NP困难的:"找到最短TSP路线"是NP困难的,而"是否存在一条长度小于$k$的TSP路线?"是NP完全的。 + +## 编程任务(使用CoLab或笔记本) + +1. 构建一个真值表生成器。给定一个逻辑表达式,枚举所有输入组合并计算结果。 +```python +import itertools + +def truth_table(n_vars, expr_fn): + """为一个n_vars个变量的布尔函数生成真值表。""" + headers = [f"p{i}" for i in range(n_vars)] + print(" | ".join(headers + ["result"])) + print("-" * (len(headers) * 4 + 10)) + for vals in itertools.product([False, True], repeat=n_vars): + result = expr_fn(*vals) + row = [str(v)[0] for v in vals] + [str(result)[0]] + print(" | ".join(f"{r:>2}" for r in row)) + +# 德摩根定律:NOT(p AND q) == (NOT p) OR (NOT q) +print("德摩根定律验证:") +truth_table(2, lambda p, q: (not (p and q)) == ((not p) or (not q))) +``` + +2. 通过归纳法证明求和公式——对多个值进行数值验证,然后实现封闭形式解。 +```python +import jax.numpy as jnp + +# 验证求和公式:sum(1..n) = n(n+1)/2 +for n in [1, 5, 10, 100, 1000, 10000]: + brute = sum(range(1, n + 1)) + formula = n * (n + 1) // 2 + print(f"n={n:5d} sum={brute:>10d} formula={formula:>10d} match={brute == formula}") +``` + +3. 使用主定理求解归并排序递推关系,并通过计数操作进行经验验证。 +```python +import jax.numpy as jnp + +def merge_sort_ops(n): + """统计归并排序中的比较次数(递推:T(n) = 2T(n/2) + n)。""" + if n <= 1: + return 0 + half = n // 2 + return merge_sort_ops(half) + merge_sort_ops(n - half) + n + +for n in [8, 64, 512, 4096, 32768]: + ops = merge_sort_ops(n) + predicted = n * jnp.log2(n) + ratio = ops / predicted + print(f"n={n:5d} ops={ops:>10d} n log n={int(predicted):>10d} ratio={ratio:.3f}") +``` diff --git a/chapter 13: computing and OS/02. computer architecture.md b/chapter 13: computing and OS/02. computer architecture.md new file mode 100644 index 0000000..10e7c65 --- /dev/null +++ b/chapter 13: computing and OS/02. computer architecture.md @@ -0,0 +1,233 @@ +# 计算机体系结构 + +*计算机体系结构是关于如何构建执行指令的机器。本文涵盖数制、逻辑门、CPU设计、指令集架构、流水线、存储器层次结构和虚拟内存——每个程序、框架和AI模型最终运行其上的硬件基础。* + +- 每个神经网络、每个训练循环、每次推理调用最终都会变成流经晶体管的电信号序列。对于严肃的机器学习从业者来说,理解硬件不是可选的:它解释了为什么矩阵乘法很快,为什么内存是瓶颈,为什么GPU主导AI训练,以及为什么缓存友好的代码可以比朴素代码快100倍。 + +## 数制 + +- 计算机将所有内容表示为**二进制**(基2):0和1的序列。每个数字是一个**比特**。8个比特为一组称为一个**字节**。二进制数 $b_{n-1} b_{n-2} \ldots b_1 b_0$ 的值为 $\sum_{i=0}^{n-1} b_i \cdot 2^i$。 + +- 例如,$1011_2 = 1 \cdot 8 + 0 \cdot 4 + 1 \cdot 2 + 1 \cdot 1 = 11_{10}$。 + +- **十六进制**(基16)是二进制的紧凑表示法。每个十六进制数字代表4个比特:$0\text{-}9$ 映射到 $0000\text{-}1001$,$A\text{-}F$ 映射到 $1010\text{-}1111$。因此 $\text{0xFF} = 1111\,1111_2 = 255_{10}$。内存地址和颜色代码通常用十六进制书写。 + +- **补码**表示有符号整数。对于$n$位数字,最高有效位的权重为 $-2^{n-1}$ 而非 $+2^{n-1}$。8位补码的范围为 $-128$ 到 $+127$。要取一个数的相反数:翻转所有位然后加1。这种表示使加法和减法使用相同的硬件电路,这就是它被普遍采用的原因。 + +- **IEEE 754浮点数**将实数表示为 $(-1)^s \times 1.m \times 2^{e-\text{bias}}$,其中$s$是符号位,$m$是尾数(小数部分),$e$是移码指数。 + +![IEEE 754 float32布局:1个符号位、8个指数位、23个尾数位](../images/ieee754_float.svg) + + - **float32**(单精度):1个符号 + 8个指数 + 23个尾数 = 32位。范围:$\approx \pm 3.4 \times 10^{38}$,精度:$\approx 7$位十进制数字。 + - **float64**(双精度):1个符号 + 11个指数 + 52个尾数 = 64位。范围:$\approx \pm 1.8 \times 10^{308}$,精度:$\approx 15$位十进制数字。 + - **float16**(半精度):1 + 5 + 10 = 16位。范围和精度有限,但使用一半的内存和带宽。广泛用于ML训练(混合精度,第6章)。 + - **bfloat16**:1 + 8 + 7 = 16位。与float32相同的指数范围但精度更低。由Google专门为ML设计:完整的指数范围可防止训练期间溢出,降低的精度对梯度更新是可以接受的。 + +- 浮点算术**不精确**。在float64中,$0.1 + 0.2 \neq 0.3$(它等于 $0.30000000000000004$)。这是因为$0.1$没有精确的二进制表示,就像$1/3$没有精确的十进制表示一样。在数百万次操作(如梯度下降)中积累这些误差可能导致数值不稳定,这就是为什么存在像损失缩放(第6章)和Kahan求和法这样的技术。 + +## 逻辑门 + +- 所有计算都可以归结为**逻辑门**:实现布尔运算(来自文件1的命题逻辑)的物理电路。 + +- 基本门: + - **与门**(AND):仅当两个输入都为1时输出为1。 + - **或门**(OR):至少一个输入为1时输出为1。 + - **非门**(NOT,反相器):翻转输入。 + - **与非门**(NAND,NOT-AND):通用门。任何其他门都可以仅由与非门构建。这就是为什么与非门是数字电路的基本构建块。 + - **异或门**(XOR,异或):输入不同时输出为1。对于加法(二进制加法的和位就是XOR)和加密至关重要。 + +- **半加器**使用XOR(和)和AND(进位)相加两个单比特。**全加器**相加两个比特加上一个进位输入,可以串联起来创建$n$位加法器。这就是CPU执行整数加法的方式:一系列简单逻辑门的级联。 + +- **多路选择器**(MUX)根据控制信号从多个输入中选择一个。使用$n$个控制位,可以从$2^n$个输入中选择。多路选择器是if-else链的硬件等价物,广泛用于CPU数据通路中路由数据。 + +- 现代处理器包含数十亿个晶体管,每个晶体管充当一个微小的开关。晶体管要么导通(导电,表示1),要么不导通(不导电,表示0)。门由晶体管构成,加法器由门构成,ALU由加法器构成,CPU由ALU构成。整个计算层级就建立在这个基础之上。 + +## CPU架构 + +- **中央处理器(CPU)**执行指令。其核心组件: + + - **ALU**(算术逻辑单元):执行整数算术(加、减、乘)和逻辑运算(AND、OR、XOR、移位)。这里是实际计算发生的地方,由上述逻辑门构建而成。 + + - **寄存器**:CPU内部微小、超快的存储位置。现代CPU有数十个通用寄存器,每个寄存器保存一个字(在64位CPU上为64位)。寄存器是系统中速度最快的存储器:访问时间约~0.3纳秒。 + + - **程序计数器(PC)**:保存下一条要执行指令的内存地址。 + + - **控制单元**:解码指令并编排数据通路,告诉ALU执行什么操作以及使用哪些寄存器。 + +- **指令周期**(取指-译码-执行)每秒重复数十亿次: + + 1. **取指**:从PC中的地址读取指令。 + 2. **译码**:确定指令的功能(加法?从内存加载?分支?)及其使用的操作数。 + 3. **执行**:执行操作(ALU计算、内存访问或分支)。 + 4. 增加PC(除非指令是分支/跳转)。 + +- 运行在4 GHz的CPU每秒执行40亿个周期。每个周期耗时0.25纳秒。在这段时间内,光传播约7.5厘米,这就是芯片物理大小重要的原因:信号无法在一个周期内穿过大芯片。 + +## 指令集架构 + +- **指令集架构(ISA)**是硬件和软件之间的契约:它定义了CPU能理解的指令、寄存器集、内存模型和编码格式。 + +- **CISC**(复杂指令集计算机):指令可以复杂、变长,并可以直接访问内存。一条指令可以乘法两个内存值并存储结果。**x86**(Intel/AMD)是占主导地位的CISC ISA,驱动着大多数桌面和服务器。其向后兼容性(现代x86 CPU仍然运行1980年代的代码)既是优势也是负担。 + +- **RISC**(精简指令集计算机):指令简单、定长,且仅操作寄存器。内存访问需要单独的加载/存储指令。更简单的指令可实现更快的时钟速度和更易实现的流水线。 + + - **ARM**:移动设备的主要RISC ISA,并越来越多地用于服务器和笔记本电脑(Apple M系列芯片就是ARM)。ARM的能效使其非常适合电池供电和热受限设备。 + - **RISC-V**:一个开源的RISC ISA。任何人都可以设计RISC-V芯片而无需许可费。在嵌入式系统、研究和AI加速器中的采用正在增长。 + +- CISC与RISC的区别已经模糊:现代x86 CPU内部将复杂的CISC指令解码为更简单的微操作(本质上是内部RISC),从而获得两方面的优势。 + +## 流水线 + +- 没有流水线时,CPU完全完成一条指令后才开始下一条。这会浪费硬件:当ALU执行时,取指和译码单元处于空闲状态。 + +![CPU流水线:指令在取指、译码、执行、访存和写回阶段重叠](../images/cpu_pipeline.svg) + +- **流水线**使指令执行重叠,如同装配线。当指令1在执行时,指令2在译码,指令3在被取指。一个5级流水线(取指、译码、执行、访存、写回)可以同时有5条指令在执行中。 + +- 吞吐量接近每周期一条指令(尽管每条指令需要5个周期才能完成)。这与ML中的流水线原理相同:数据并行性使计算和通信重叠(第6章)。 + +- **冒险**是流水线被破坏的情况: + + - **数据冒险**:指令2需要指令1尚未产生的结果。"Add R1, R2, R3"后跟"Sub R4, R1, R5"——第二条指令需要R1,而第一条指令仍在计算。**转发**(旁路)通过将结果直接从一级流水线路由到另一级,无需等待写回阶段来解决这个问题。 + + - **控制冒险**:分支指令(if-else)意味着CPU在分支解析之前不知道应该取指哪条下一条指令。**分支预测**猜测分支将走哪条路径,并推测性地沿预测路径取指。现代预测器准确率超过95%,使用历史表和类似神经网络的模式匹配。一次预测错误代价约~15个周期(流水线必须被清空并重启)。 + + - **结构冒险**:两条指令同时需要相同的硬件资源(例如,都需要内存端口)。通过复制资源或插入停顿来解决。 + +## 存储器层次结构 + +- 计算机内存中的根本矛盾:快速内存昂贵且容量小,廉价内存缓慢但容量大。**存储器层次结构**通过利用**局部性**来弥合这一差距:程序倾向于重复访问相同的数据(时间局部性)并访问附近的数据(空间局部性)。 + +![存储器层次结构金字塔:寄存器在顶部(快速、小)到HDD在底部(慢、大)](../images/memory_hierarchy.svg) + +- 层次结构,从最快到最慢: + + - **寄存器**:~0.3 ns访问,总容量~KB。位于CPU内。 + - **L1缓存**:~1 ns,每核心32-64 KB。分为指令缓存和数据缓存。 + - **L2缓存**:~4 ns,每核心256 KB-1 MB。 + - **L3缓存**:~10 ns,跨核心共享8-64 MB。 + - **RAM(DRAM)**:~50-100 ns,8-512 GB。主内存。 + - **SSD**:~10-100 μs,256 GB-8 TB。持久存储。 + - **HDD**:~5-10 ms,1-20 TB。机械式,随机访问非常慢。 + +- 寄存器和RAM之间的速度差距约为300倍。寄存器和磁盘之间约为30,000,000倍。缓存层次结构隐藏了这一差距:如果CPU需要的数据在L1缓存中(**缓存命中**),访问很快。如果不在(**缓存未命中**),CPU停顿,同时从更慢的层级获取数据。 + +- **缓存关联度**决定内存地址可以存储在缓存中的位置: + - **直接映射**:每个地址映射到恰好一个缓存行。简单但会导致冲突。 + - **全关联**:任何地址可以放在任何位置。灵活但搜索成本高。 + - **组关联**($k$路):每个地址映射到一组$k$个位置。实际CPU中使用的实用折衷方案(通常为4路或8路)。 + +- **缓存一致性**确保所有CPU核心看到一致的内存视图。当核心1写入一个核心2已缓存的内存地址时,一致性协议(如MESI)会使核心2的副本失效或更新。这对并发编程(文件4)至关重要,也是共享内存并行性困难的原因之一。 + +- 对于ML从业者,存储器层次结构解释了为什么: + - 矩阵运算应按顺序访问内存(行优先与列优先的布局很重要)。 + - 批量大小会影响性能:更大的批次分摊内存延迟。 + - 混合精度(float16/bfloat16)使有效内存带宽翻倍,而内存带宽往往是瓶颈。 + +## 虚拟内存 + +- **虚拟内存**使每个进程仿佛拥有自己独立、连续的大内存空间,即使物理RAM是有限的并在进程间共享。 + +- 地址空间被划分为固定大小的**页**(通常为4 KB)。**页表**将虚拟页号映射到物理帧号。当程序访问虚拟地址0x1234时,CPU通过查找页表将其转换为物理地址。 + +- **转译后备缓冲器(TLB)**是页表项的缓存。由于页表位于RAM中(慢速),TLB在快速硬件中存储最近使用的转译结果。TLB未命中需要遍历内存中的页表,耗费数百个周期。 + +- 当程序访问一个不在物理RAM中的页时,发生**缺页**。OS从磁盘加载该页(交换),耗费数百万个周期。过多的缺页(**系统颠簸**)会严重损害性能。这就是为什么ML训练需要足够的RAM来容纳模型、优化器状态和合理的数据批次。 + +- **页面置换**算法决定当RAM满时应换出哪个页面: + - **LRU**(最近最少使用):换出最长时间未被访问的页面。在实践中对大多数工作负载最优。在硬件中通过**时钟算法**(带引用位的循环链表)近似实现。 + - **FIFO**:换出最旧的页面。简单但可能换出频繁使用的页面。 + - **最优**(Bélády算法):换出将在最长时间内不被使用的页面。无法实现(需要未来知识)但可作为理论基准。 + +- 虚拟内存还提供了**隔离**:每个进程都有自己的虚拟地址空间。一个进程中的错误不会破坏另一个进程的内存,因为它们的虚拟地址映射到不同的物理帧。这是OS安全性和稳定性的基础。 + +## I/O、中断和DMA + +- CPU需要与外部世界通信:磁盘、网卡、键盘、GPU。这就是**I/O子系统**。 + +- **程序控制I/O**(轮询):CPU在一个循环中反复检查设备的状态寄存器,等待数据就绪。简单但浪费CPU周期做空转而不是有用工作。 + +- **中断驱动I/O**:设备在数据就绪时发送一个硬件**中断**。CPU继续正常执行直到中断到达,然后运行一个**中断处理程序**(内核函数)来处理数据。这比轮询高效得多,因为CPU在等待时不会空闲。 + +- 中断机制: + 1. 设备通过硬件线路发出中断信号。 + 2. CPU完成当前指令,将当前状态(寄存器、程序计数器)保存到堆栈。 + 3. CPU在**中断向量表**(每个中断类型对应一个函数指针的表)中查找中断处理程序地址。 + 4. 处理程序在内核模式下运行,处理I/O,然后返回。 + 5. CPU恢复保存的状态并恢复被中断的程序。 + +- 这与上下文切换(文件3)的保存/恢复模式相同,但由硬件而非定时器触发。 + +- **DMA**(直接存储器访问):对于大数据传输(磁盘读取、网络数据包、GPU内存复制),让CPU逐字节复制数据是浪费的。**DMA控制器**直接在设备和RAM之间传输数据,无需CPU参与。CPU设置传输(源地址、目标地址、大小),DMA控制器处理传输,完成后CPU收到一个中断。 + +- DMA对ML至关重要:当你调用 `model.to('cuda')` 时,数据通过PCIe总线上的DMA从系统RAM传输到GPU内存。在训练期间,跨GPU的梯度同步使用基于DMA的RDMA(远程DMA)进行高带宽、低延迟传输(第6章)。 + +- **总线**将CPU连接到内存和I/O设备。现代系统使用**PCIe**(快速外设组件互连)连接高速设备(GPU、NVMe SSD、网卡)。PCIe 4.0在每个x16插槽上提供约~32 GB/s;PCIe 5.0将其翻倍。总线带宽通常是GPU训练的瓶颈:GPU的计算速度可能快于数据送达的速度。 + +- **MMIO**(内存映射I/O):设备寄存器被映射到内存地址。CPU使用普通的加载/存储指令对这些地址进行读写,硬件将访问路由到设备而不是RAM。这统一了内存和I/O访问为一个单一机制,简化了硬件和软件。 + +## 编程任务(使用CoLab或笔记本) + +1. 探索IEEE 754浮点数表示。将浮点数转换为二进制表示,观察符号、指数和尾数字段。 +```python +import struct + +def float_to_bits(f): + """显示float32的IEEE 754二进制表示。""" + packed = struct.pack('>f', f) + bits = ''.join(f'{byte:08b}' for byte in packed) + sign = bits[0] + exponent = bits[1:9] + mantissa = bits[9:] + return sign, exponent, mantissa + +for val in [1.0, -1.0, 0.1, 0.5, 3.14, float('inf'), float('nan')]: + s, e, m = float_to_bits(val) + print(f"{val:>10} sign={s} exp={e} ({int(e, 2) - 127:>4d}) mantissa={m[:10]}...") +``` + +2. 模拟直接映射缓存。跟踪一系列内存访问的命中与未命中。 +```python +def simulate_cache(accesses, cache_size=8, block_size=1): + """模拟直接映射缓存。""" + cache = [None] * cache_size + hits, misses = 0, 0 + + for addr in accesses: + cache_line = addr % cache_size + if cache[cache_line] == addr: + hits += 1 + status = "HIT " + else: + misses += 1 + cache[cache_line] = addr + status = "MISS" + print(f" Access {addr:3d} → line {cache_line}: {status}") + + print(f"\nHits: {hits}, Misses: {misses}, Hit rate: {hits/(hits+misses):.1%}") + +# 顺序访问(良好的局部性) +print("顺序访问:") +simulate_cache([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3]) + +# 跨步访问(冲突未命中) +print("\n跨步访问(stride = cache size):") +simulate_cache([0, 8, 0, 8, 0, 8]) +``` + +3. 演示为什么浮点算术不满足结合律。展示 $(a + b) + c \neq a + (b + c)$ 的情况。 +```python +import jax.numpy as jnp + +a = jnp.float32(1e8) +b = jnp.float32(1.0) +c = jnp.float32(-1e8) + +left = (a + b) + c # (1e8 + 1) + (-1e8) +right = a + (b + c) # 1e8 + (1 + (-1e8)) + +print(f"(a + b) + c = {left}") # 应为 1.0 +print(f"a + (b + c) = {right}") # 可能会丢失 1.0 +print(f"Equal: {left == right}") +print(f"\n当 1.0 加到 1e8 上时被丢失,因为 float32 只有约 7 位精度") +``` diff --git a/chapter 13: computing and OS/03. operating systems.md b/chapter 13: computing and OS/03. operating systems.md new file mode 100644 index 0000000..9591e01 --- /dev/null +++ b/chapter 13: computing and OS/03. operating systems.md @@ -0,0 +1,258 @@ +# 操作系统 + +*操作系统是硬件与应用程序之间的软件层,负责管理资源、提供抽象并实施隔离。本文涵盖操作系统的功能、进程、线程、CPU调度、内存管理、文件系统和系统调用。* + +- 没有操作系统的计算机就像一个没有厨师的厨房:食材(硬件)都在那里,但没有人协调谁使用炉灶、餐具放在哪里、或者如何防止两个人同时抓同一把刀。**OS**就是那个协调者。 + +- 对于ML从业者,操作系统的概念解释了:为什么 `nvidia-smi` 显示每个进程的GPU内存使用量、为什么训练因"内存不足"而崩溃、为什么 `fork()` 会复制你的Python进程、以及为什么Docker容器提供隔离环境。 + +## 操作系统做什么 + +- OS有三个核心职责: + + - **抽象**:将硬件复杂性隐藏在简洁的接口之后。程序读写"文件"而无需知道底层存储是SSD、HDD还是网络驱动器。它们分配"内存"而无需管理物理RAM芯片。它们在"CPU"上运行而无需担心中断和缓存一致性。 + + - **资源管理**:多个程序共享CPU、内存、磁盘和网络。OS决定谁获得什么资源、何时获得、获得多久。公平高效的分配策略保持系统的响应性。 + + - **隔离与保护**:程序之间不得相互干扰。浏览器中的Bug不应导致内核崩溃。恶意程序不应读取另一个程序的密码。OS利用硬件支持(特权级、虚拟内存)强制实施边界。 + +## 进程 + +- **进程**是正在运行的程序。它是OS的基本工作单元。每个进程都有: + + - **代码**(程序指令,只读)。 + - **数据**(全局变量,堆分配)。 + - **堆栈**(函数调用帧,局部变量)。 + - **状态**(寄存器值、程序计数器、打开的文件等)。 + +- **进程控制块(PCB)**是OS用于跟踪进程的数据结构。它存储进程ID(PID)、状态、程序计数器、寄存器内容、内存映射、打开的文件描述符和调度优先级。当OS从一个进程切换到另一个进程时,它将当前进程的状态保存到其PCB中,并加载下一个进程的状态。这就是**上下文切换**。 + +- 上下文切换代价高昂:保存和恢复寄存器、刷新缓存、使TLB项失效需要微秒级时间。在一个运行数千个进程的系统中,开销可能很大。这就是为什么每进程每请求的服务器架构(如老式Apache)被基于线程或事件驱动的架构取代。 + +- Unix中的**进程创建**使用 `fork()` 和 `exec()`: + + - `fork()` 创建当前进程的一个**副本**。子进程获得父进程内存、文件描述符和状态的一份副本。两个进程从同一点继续执行,但 `fork()` 在子进程中返回0,在父进程中返回子进程的PID。 + + - `exec()` 用新程序替换当前进程的代码。在 `fork()` 之后,子进程通常调用 `exec()` 来运行一个不同的程序。 + + - 这种先fork后exec的模型很优雅:创建新进程(fork)和加载新程序(exec)是独立的操作,可以各自定制。在fork和exec之间,子进程可以重定向I/O、更改环境变量或降低权限。 + +![进程状态转换:新建→就绪→运行→阻塞/终止,包含抢占和I/O等待](../images/process_states.svg) + +- **进程状态**:一个进程处于以下几种状态之一: + - **运行**:当前在CPU核心上执行。 + - **就绪**:等待CPU核心(可运行但尚未被调度)。 + - **阻塞**(等待):无法继续,直到某个事件发生(I/O完成、锁获取、定时器到期)。 + - **终止**:执行完毕,等待父进程收集其退出状态。 + +## 线程 + +- **线程**是进程内的轻量级执行单元。进程内的所有线程共享相同的代码、数据和堆,但每个线程有自己的堆栈和寄存器状态。 + +- 与多个进程相比的优势:线程共享内存,因此它们之间的通信很快(只需读写共享变量)。进程需要进程间通信(管道、套接字、共享内存映射),这更慢且更复杂。 + +- 劣势:共享内存是危险的。两个线程同时写入同一变量会导致**竞态条件**(结果取决于哪个线程先运行)。这引导我们进入同步问题,在文件4中介绍。 + +- **内核线程**由OS调度器管理。每个线程独立地被调度到CPU核心上。创建和切换内核线程涉及系统调用,开销与进程上下文切换类似(但更小)。 + +- **用户线程**(绿色线程)由用户空间的运行时库管理,对OS不可见。创建和切换它们的成本更低(无需系统调用),但一个用户线程的阻塞操作会阻塞进程中的所有线程(因为OS只看到一个内核线程)。 + +- 现代系统使用**混合模型**:许多用户线程映射到较少数量的内核线程上(M:N线程)。Go的goroutine和Erlang的进程是由语言运行时调度到OS线程上的用户级线程。 + +- **线程池**预先创建固定数量的线程,等待任务。当任务到达时,分配给一个空闲线程。这避免了为每个任务创建和销毁线程的开销。Web服务器、数据库引擎和ML推理服务器都使用线程池。 + +## CPU调度 + +- **调度器**决定每个时刻哪个进程/线程在哪个CPU核心上运行。目标是:最大化CPU利用率、最小化响应时间(对交互式任务)、最大化吞吐量(对批处理任务)、并确保公平性。 + +- **先来先服务(FCFS)**:进程按到达顺序运行。简单但存在**护航效应**:一个长时间运行的进程阻塞了后面所有较短的进程。 + +- **最短作业优先(SJF)**:运行最短的进程优先。可证明最小化平均等待时间,但需要预先知道作业长度(通常不可能)。其抢占式版本**最短剩余时间优先(SRTF)**,如果出现更短的作业则中断正在运行的作业。 + +- **轮转(RR)**:每个进程获得一个固定的**时间片**(如10 ms),然后被抢占并移到队列末尾。公平且响应性好,但时间片大小很重要:太小会导致过多上下文切换,太大则会退化为FCFS。 + +- **优先级调度**:每个进程有一个优先级。高优先级进程先运行。危险是**饥饿**:如果高优先级进程源源不断到来,低优先级进程可能永远无法运行。**老化**解决这个问题:进程等待时间越长,其优先级就越高。 + +- **多级反馈队列(MLFQ)**:具有不同优先级和时间片的多个队列。新进程从最高优先级队列(短时间片)开始。如果一个进程用完其时间片(CPU密集型),它被降到较低优先级队列(较长时间片)。交互式进程自然停留在高优先级队列中(它们在使用完时间片之前就因I/O阻塞了)。这可以适应工作负载,而无需预先了解作业类型。 + +- **完全公平调度器(CFS)**:Linux调度器。它维护一棵红黑树(平衡二叉搜索树),进程按"虚拟运行时间"(它们已经消耗的CPU时间)排序。具有最小虚拟运行时间的进程接下来运行。这确保了随着时间的推移,每个进程获得其公平份额。CFS每次调度决策运行时间为 $O(\log n)$。 + +## 内存管理 + +- OS管理物理RAM,将其分配给进程并在不再需要时回收。 + +- **分页**(来自文件2)将虚拟内存划分为固定大小的页,物理内存划分为帧。页表将页映射到帧。分页消除了外部碎片(分配之间的浪费空间),因为所有页面大小相同。 + +- **请求分页**仅在首次访问时将页加载到RAM中(而不是在进程启动时)。这节省了内存:一个拥有1 GB代码的程序在典型运行中可能只使用50 MB。其余部分从未被加载。 + +- 当RAM满且需要新页时,OS必须**换出**一个现有页面。**页面置换**算法(LRU、FIFO、时钟,来自文件2)决定换出哪个页面。好的置换最小化缺页次数;坏的置换导致系统颠簸。 + +- **分段**将内存划分为可变大小的段(代码、数据、栈、堆),每个段有自己的基地址和长度。分段提供逻辑组织,而分页提供物理管理。现代系统最小限度地使用分段(主要用于保护),并依赖分页进行内存管理。 + +- **堆**是动态分配内存所在的地方(C中的`malloc`/`free`,Java中的`new`,Python中隐式管理)。OS向进程提供大块内存,**内存分配器**(如 `glibc malloc`、`jemalloc`、`tcmalloc`)将这些大块细分为更小的分配。分配器设计影响性能:碎片浪费空间,线程间的争用浪费时间。 + +## 文件系统 + +- **文件系统**将持久存储(SSD、HDD)上的数据组织为命名的文件和目录层次结构。 + +- **inode**(索引节点)存储文件的元数据:大小、所有权、权限、时间戳以及指向磁盘上数据块的指针。文件名存储在目录中,目录将名称映射到inode编号。这种分离意味着一个文件可以有多个名称(**硬链接**)指向同一个inode。 + +- **FAT**(文件分配表):一种简单的文件系统,用于USB驱动器和SD卡。一个表将每个簇(块)映射到文件中的下一个簇,形成一个链表。简单但不好支持权限、日志记录或大文件。 + +- **ext4**:默认的Linux文件系统。使用带有直接、间接、二级间接和三级间接块指针的inode来处理任何大小的文件。支持**区段**(块的连续范围)以高效处理大文件。最大文件大小:16 TB,最大分区:1 EB。 + +- **日志记录**防止因崩溃而损坏。在修改文件系统结构之前,更改被写入**日志**(journal)。如果系统在操作中间崩溃,重启时会重放日志以完成或撤销该操作。没有日志记录,写入期间的崩溃可能使文件系统处于不一致状态(文件的数据块已更新但其inode未更新,反之亦然)。 + +- **基于B树的文件系统**(Btrfs、ZFS)使用B树(平衡搜索树)来组织数据和元数据,实现高效搜索、写时复制快照以及用于数据完整性的内置校验和。这些与数据库索引中使用的B树相同。 + +## 系统调用与内核模式 + +- **系统调用**是用户程序和OS内核之间的接口。当程序需要做一些特权操作(读取文件、分配内存、创建进程、发送网络数据包)时,它会进行系统调用。 + +- CPU在两种模式下运行: + - **用户模式**:受限制。程序可以执行自己的代码并访问自己的内存,但不能直接访问硬件、其他进程的内存或OS数据结构。 + - **内核模式**:不受限制。OS内核可以访问所有硬件和内存。系统调用是从用户模式到内核模式的受控通道。 + +- 当程序调用 `read()` 时,发生以下过程: + 1. 程序将参数放入寄存器并触发**陷阱**(一种软件中断)。 + 2. CPU切换到内核模式并跳转到系统调用处理程序。 + 3. 内核验证参数,执行I/O操作,将数据复制到用户的缓冲区。 + 4. 内核切换回用户模式并返回结果。 + +- 常见系统调用:`open`、`read`、`write`、`close`(文件),`fork`、`exec`、`wait`、`exit`(进程),`mmap`、`brk`(内存),`socket`、`bind`、`listen`、`accept`(网络)。 + +- **中断**是迫使CPU暂时停止当前操作并运行中断处理程序(在内核中)的硬件信号。一次键盘按键、一个网络数据包到达或一个定时器滴答都会产生中断。定时器中断特别重要:它使OS能够抢占正在运行的进程并切换到另一个(抢占式多任务)。 + +## 网络基础 + +- 网络栈是OS的一个子系统,实现机器之间的通信。理解它解释了分布式训练如何同步梯度、模型服务如何处理请求以及为什么延迟很重要。 + +![TCP/IP栈:应用层、传输层、网络层和链路层,每层添加头部](../images/tcp_ip_layers.svg) + +- **TCP/IP模型**将网络组织为分层结构,每层为上层提供抽象: + + - **链路层**:处理单个物理链路上的通信(以太网、Wi-Fi)。处理MAC地址和帧。 + - **网络层(IP)**:将数据包跨多个网络从源路由到目标。每台机器有一个**IP地址**(例如 IPv4 的 192.168.1.1 或 128位的IPv6地址)。路由器基于目标IP逐跳转发数据包。 + - **传输层(TCP/UDP)**:提供应用程序之间的端到端通信。 + - **应用层**:HTTP、DNS、gRPC等协议,应用程序直接使用。 + +- **TCP**(传输控制协议)提供可靠、有序的交付。它建立一个连接(三次握手:SYN、SYN-ACK、ACK),保证所有数据按序到达(使用序列号和确认),重传丢失的数据包,并控制发送速率以避免网络过载(**拥塞控制**)。代价是延迟:握手增加了一个往返时间,重传增加了延迟。 + +- **UDP**(用户数据报协议)提供不可靠、无序的交付。无需握手、无需重传、无顺序保证。延迟远低于TCP。用于速度比可靠性更重要的场景:视频流、在线游戏、DNS查询。在ML中,一些梯度同步协议使用基于UDP的RDMA以获得更低延迟。 + +- **套接字**是用于网络通信的OS API。一个**套接字**是由(IP地址,端口号)标识的端点。服务器创建一个套接字,将其绑定到一个端口(例如HTTP的80),监听连接,并接受它们。客户端创建一个套接字并连接到服务器的地址:端口。然后通过套接字像文件一样读写数据。 + +- **DNS**(域名系统)将人类可读的名称(google.com)翻译为IP地址(142.250.80.46)。它是一个分布式的、层次化的数据库:你的机器询问本地解析器,后者询问根服务器,根服务器委托给每个域的权威服务器。 + +- **HTTP**(超文本传输协议)是Web的请求-响应协议。客户端发送一个请求(方法 + URL + 头部 + 可选体),服务器发送一个响应(状态码 + 头部 + 体)。ML模型服务(如TensorFlow Serving、Triton)将模型暴露为HTTP或gRPC端点。 + +- **延迟 vs 带宽**:延迟是一个数据包从源到目标所需的时间(由物理距离和网络跳数决定)。带宽是数据传输速率(每秒字节数)。高带宽、高延迟的连接(卫星互联网)可以传输大量数据,但每个字节需要很长时间才能到达。对于分布式训练,**延迟**对同步屏障(所有GPU必须等待最慢的那个)很重要,而**带宽**对传输大的梯度张量很重要(第6章)。 + +## 虚拟化与容器 + +- **虚拟化**在单个物理机上运行多个操作系统。**虚拟机监视器**(VMware、KVM、Xen)创建**虚拟机(VM)**,每个虚拟机有自己的虚拟CPU、内存、磁盘和网络接口。每个虚拟机运行一个完整的操作系统(来宾OS),它认为自己拥有专用硬件。 + +- VM提供强隔离(一个VM崩溃不影响其他VM)和灵活性(在同一台机器上运行Linux和Windows,在物理主机之间迁移VM)。代价是开销:每个VM运行一个完整的OS内核,消耗内存和CPU来执行与宿主机OS冗余的OS操作。 + +![VM在虚拟硬件上运行独立的来宾OS;容器共享宿主机内核,轻量得多](../images/container_vs_vm.svg) + +- **容器**(Docker、Podman)提供了一种更轻量的替代方案。容器不是虚拟化整个硬件,而是共享宿主机OS内核,并使用内核特性来隔离进程: + + - **命名空间**隔离进程可以看到的内容:每个容器拥有自己的进程树视图(PID命名空间)、网络接口(网络命名空间)、文件系统挂载点(挂载命名空间)和主机名(UTS命名空间)。容器内的进程不能看到其他容器中的进程。 + + - **Cgroups**(控制组)限制进程可以使用的内容:CPU时间、内存、磁盘I/O、网络带宽。容器不能消耗超过其cgroup允许的资源,防止一个容器饿死其他容器。 + +- 容器在毫秒内启动(无需OS启动),使用最小开销(共享内核),并通过**Dockerfile**定义,该文件指定基础镜像、依赖项和命令。这使得它们可复现:`docker build` 在任何地方产生相同的环境。 + +- 对于ML,容器解决了"在我机器上能运行"的问题。具有特定版本CUDA、cuDNN、PyTorch和Python的训练环境被打包为容器镜像。任何人都可以在任何机器上复现确切的环境。云训练平台(AWS SageMaker、GCP Vertex AI)在容器中运行训练任务。 + +- **Kubernetes**(K8s)大规模编排容器:它将容器调度到集群中的多台机器上,重启失败的容器,根据负载进行扩缩容,并管理容器之间的网络。大规模ML服务(数千个模型副本处理数百万请求)在Kubernetes上运行。 + +## 安全基础 + +- OS通过多种机制实施安全: + +- **权限**:每个文件有一个所有者、一个组和权限位(拥有者、组和其他人的读/写/执行)。进程以启动它的用户的身份(UID)运行,只能访问权限位允许的文件。**root**用户(UID 0)绕过所有权限检查,这就是为什么以root身份运行是危险的。 + +- **权限分离**:进程以所需的最小权限运行。Web服务器不需要root访问权限;它应该以一个受限用户身份运行,该用户只能读取Web文件并绑定到端口80。如果服务器被攻破,攻击者的访问限制在该受限用户能做的范围内。 + +- **沙箱化**:限制进程在文件权限之外能做的事情。**seccomp**(Linux)限制进程可以进行的系统调用。**AppArmor**和**SELinux**定义强制访问控制策略。容器结合了命名空间、cgroups和seccomp进行多层隔离。 + +- **地址空间布局随机化(ASLR)**:每次程序运行时,随机化堆栈、堆和库的内存位置。这使得攻击者更难利用内存损坏漏洞(缓冲区溢出),因为他们无法预测代码或数据在内存中的位置。 + +- 安全是一个全系统层面的关注:链条的强度取决于最弱的一环。模型服务系统需要安全的网络通信(TLS/HTTPS)、经过身份验证的API访问(API密钥、OAuth)、输入验证(防止对抗性输入)和隔离执行(具有最小权限的容器)。 + +## 编程任务(使用CoLab或笔记本) + +1. 探索进程创建。使用Python的 `os.fork()`(仅Unix)创建一个子进程,并观察父进程和子进程如何从同一点继续执行。 +```python +import os + +pid = os.fork() + +if pid == 0: + # 子进程 + print(f"Child: my PID is {os.getpid()}, parent PID is {os.getppid()}") +else: + # 父进程 + print(f"Parent: my PID is {os.getpid()}, child PID is {pid}") + os.wait() # 等待子进程结束 +``` + +2. 模拟轮转调度。给定一个带有执行时间的进程列表,模拟调度并计算平均等待时间。 +```python +def round_robin(processes, quantum=3): + """模拟轮转调度。 + processes: (name, burst_time) 元组列表。 + """ + queue = [(name, burst, 0) for name, burst in processes] # (name, remaining, wait) + time = 0 + log = [] + + while queue: + name, remaining, waited = queue.pop(0) + waited += (time - waited - (processes[[p[0] for p in processes].index(name)][1] - remaining)) + run_time = min(quantum, remaining) + log.append(f" t={time:3d}: {name} runs for {run_time} (remaining: {remaining - run_time})") + time += run_time + remaining -= run_time + + if remaining > 0: + queue.append((name, remaining, time)) + else: + log.append(f" t={time:3d}: {name} DONE (turnaround: {time})") + + for line in log: + print(line) + +print("轮转调度 (quantum=3):") +round_robin([("P1", 10), ("P2", 4), ("P3", 6)], quantum=3) +``` + +3. 模拟LRU页面置换。给定一个页面访问序列和固定数量的帧,统计缺页次数。 +```python +def lru_page_replacement(pages, n_frames): + """模拟LRU页面置换。""" + frames = [] + faults = 0 + + for page in pages: + if page in frames: + frames.remove(page) + frames.append(page) # 移动到最近使用 + status = "HIT " + else: + faults += 1 + if len(frames) >= n_frames: + evicted = frames.pop(0) # 移除最近最少使用 + status = f"MISS (evict {evicted})" + else: + status = "MISS (cold)" + frames.append(page) + print(f" Page {page}: {status} frames={frames}") + + print(f"\nTotal faults: {faults}/{len(pages)} ({faults/len(pages):.0%})") + +print("LRU with 3 frames:") +lru_page_replacement([1, 2, 3, 4, 1, 2, 5, 1, 2, 3, 4, 5], n_frames=3) +``` diff --git a/chapter 13: computing and OS/04. concurrency and parallelism.md b/chapter 13: computing and OS/04. concurrency and parallelism.md new file mode 100644 index 0000000..7e088dc --- /dev/null +++ b/chapter 13: computing and OS/04. concurrency and parallelism.md @@ -0,0 +1,226 @@ +# 并发与并行 + +*并发与并行是程序同时处理多件事情的方式。本文涵盖并发与并行的区别、同步原语、经典并发问题、死锁、无锁数据结构、并行编程模型、异步编程和扩展定律——这些概念支撑着多线程服务器、分布式训练和每一个现代应用程序。* + +- 单个CPU核心一次执行一条指令。但现代系统有8个、64个甚至数千个核心(GPU)。即使在单核上,我们也希望处理多个任务:一边下载文件一边渲染界面一边处理用户输入。**并发**和**并行**是管理多个活动的两种策略。 + +## 并发 vs 并行 + +![并发在单核上交错执行任务;并行在多核上同时执行任务](../images/concurrency_vs_parallelism.svg) + +- **并发**是关于*管理*多个任务。任务通过交错进行:任务A运行一会儿,然后任务B,然后回到A。在单核上,并发创造了同时执行的假象。这些任务并非真正同时执行;它们轮流进行。 + +- **并行**是关于*执行*多个任务同时进行。有$n$个核心,$n$个任务可以真正同时运行。并行需要多个硬件执行单元。 + +- 类比:并发是一个厨师交替切菜和搅拌锅。并行是两个厨师各自同时做一个任务。一个系统可以是并发但不并行的(单核,任务交错),并行但不并发的(多核运行独立程序,没有交互),或者两者兼有(多核运行互相交错交互的任务)。 + +- 在ML中,并发出现在数据加载中(数据预处理与GPU计算重叠),而并行出现在分布式训练中(多个GPU同时计算梯度,第6章)。 + +## 同步原语 + +- 当多个线程共享数据时,**同步**防止竞态条件。竞态条件发生在结果依赖于线程执行的不可预测顺序时。 + +- 考虑两个线程同时增加一个共享计数器:`counter += 1`。这实际上是三个操作:(1)读取计数器,(2)加1,(3)写入计数器。如果两个线程读取相同的值(比如5),都加1,都写入6,计数器最终为6而不是正确的7。一次增加丢失了。 + +- **互斥锁**(互斥排斥锁)确保一次只有一个线程访问临界区。一个线程在进入临界区前**获取**锁,之后**释放**锁。任何其他试图获取已被持有锁的线程将阻塞直到锁被释放。 + +``` +lock.acquire() +counter += 1 # 一次只有一个线程在此 +lock.release() +``` + +- 互斥锁是正确的,但会引入**争用**:如果许多线程竞争同一个锁,它们花费时间等待而不是计算。这限制了可扩展性。极端情况下,所有线程都想要同一个锁,会使整个程序串行化。 + +- **信号量**泛化了互斥锁。计数信号量维护一个计数器:`wait()` 递减计数器(如果会变负则阻塞),`signal()` 递增计数器。初始化为1的信号量行为类似互斥锁。初始化为$n$的信号量允许最多$n$个线程同时进入临界区(适用于资源池如数据库连接)。 + +- **条件变量**让一个线程等待直到某个特定条件满足。该线程释放一个锁,在条件变量上等待,当另一个线程发出该条件的信号时被唤醒。这避免了忙等待(在一个循环中反复检查条件,浪费CPU)。 + +- **监视器**将互斥锁与条件变量和共享数据捆绑为一个单一抽象。Java的 `synchronized` 关键字和Python的 `threading.Condition` 实现了类似监视器的语义。 + +- **读写锁**区分读线程(可以共享访问,因为读取不会修改数据)和写线程(需要独占访问)。多个读线程可以同时持有锁,但一个写线程会阻塞所有读线程和其他写线程。当读操作远多于写操作时(例如,提供预测的缓存模型),这是最优的。 + +## 经典并发问题 + +- **生产者-消费者**(有界缓冲区):生产者生成项目并将其放入固定大小的缓冲区;消费者移除项目。挑战:缓冲区满时生产者必须等待,缓冲区空时消费者必须等待,且两者必须防止损坏缓冲区。 + +- 解决方案使用两个信号量(一个计数空位,一个计数满位)加上一个用于缓冲区本身的互斥锁。这是大多数消息队列、日志系统和数据管道背后的模式。 + +- **读者-写者**:多个读者可以同时读取,但写者需要独占访问。挑战是公平性:如果读者源源不断地到来,写者可能饥饿(永远得不到访问)。解决方案要么优先考虑读者,要么优先考虑写者,要么公平地交替。 + +- **哲学家就餐问题**:五位哲学家围坐在一张有五个叉子的桌子旁。每人需要两把叉子才能吃饭。如果所有五位同时拿起左边的叉子,没人能拿起右边的叉子,所有人都饿死(死锁)。解决方案包括:同时拿起两把叉子(原子操作),引入不对称性(一位哲学家先拿右边的叉子),或者使用服务员(限制用餐人数为4的信号量)。 + +## 死锁 + +- **死锁**发生在一组线程各自等待集合中另一个线程持有的资源,形成一个依赖循环。没有人能继续。 + +![死锁:线程A持有锁1想要锁2,线程B持有锁2想要锁1——循环等待](../images/deadlock_cycle.svg) + +- 死锁的四个**必要条件**(必须同时满足): + + 1. **互斥**:资源一次只能被一个线程持有。 + 2. **持有并等待**:一个线程持有一个资源的同时等待另一个资源。 + 3. **不可剥夺**:资源不能被强制从线程中拿走。 + 4. **循环等待**:等待图中存在一个循环。 + +- **死锁预防**打破四个条件之一: + - 消除循环等待:对资源施加全序。所有线程以相同的顺序获取资源。如果每个线程总是在获取锁A之后才获取锁B,则不可能有循环。 + - 消除持有并等待:要求线程一次性(原子地)请求所有资源。 + +- **死锁避免**动态决定是否批准一个资源请求可能导致死锁。**银行家算法**维护每个线程的最大可能需求,仅批准使系统保持"安全状态"(所有线程最终都能完成的状态)的请求。该算法每个请求 $O(n^2 m)$($n$个线程,$m$种资源类型),对大多数实际系统来说过于昂贵。 + +- **死锁检测**让死锁发生,然后检测它们(通过在等待图中找到循环)并恢复(通过杀死一个线程或回滚一个事务)。 + +- 在实践中,大多数系统对常见情况使用预防(资源排序),对罕见情况使用检测。数据库系统是经典例子:它们检测事务之间的死锁并中止一个来打破循环。 + +## 无锁和免等待数据结构 + +- 锁引入了争用、优先级反转和死锁风险。**无锁**数据结构完全避免使用锁,使用硬件提供的**原子操作**。 + +- 关键的原子操作是**比较并交换(CAS)**:原子地检查一个内存位置是否具有期望的值,如果是,则将其替换为新值。伪代码: + +``` +CAS(address, expected, new_value): + if *address == expected: + *address = new_value + return true + else: + return false +``` + +- CAS实现为单个硬件指令,因此即使没有锁也是原子的。无锁算法使用重试循环中的CAS:读取当前值,计算新值,尝试CAS。如果另一个线程在此期间修改了该值,CAS失败,线程重试。 + +- **无锁**:至少一个线程在有限步骤内取得进展(不可能死锁,但个别线程在争用下可能无限重试)。 + +- **免等待**:每个线程在有限步骤内取得进展(最强保证,但最难实现)。 + +- 无锁的堆栈、队列和哈希映射广泛用于高性能系统。Java的 `ConcurrentHashMap` 和Go的原子操作都建立在CAS之上。 + +## 并行编程模型 + +- **共享内存**并行:所有线程访问同一内存空间。同步是程序员的责任。**OpenMP**提供编译器指令来并行化循环: + +```c +#pragma omp parallel for +for (int i = 0; i < n; i++) { + result[i] = compute(data[i]); +} +``` + +- 编译器将循环迭代拆分到可用的核心上。OpenMP对数据并行工作负载(对许多数据点执行相同操作)很有效,广泛用于科学计算。 + +- **消息传递**并行:每个进程有自己的内存。通信通过发送和接收消息实现。**MPI**(消息传递接口)是跨节点分布式计算的标准: + +```c +MPI_Send(data, count, MPI_FLOAT, dest, tag, MPI_COMM_WORLD); +MPI_Recv(data, count, MPI_FLOAT, src, tag, MPI_COMM_WORLD, &status); +``` + +- MPI可扩展到数千个节点,因为没有需要同步的共享状态。分布式深度学习(第6章)使用集合操作如 `MPI_AllReduce`(环状 all-reduce)来跨GPU同步梯度。 + +- **GPU并行**遵循**SIMT**(单指令多线程)模型:数千个线程在不同数据上执行相同的指令。这非常适合矩阵运算(第2章),其中相同的乘加操作应用于每个元素。我们将在后续章节中详细介绍GPU编程。 + +## 异步与事件驱动编程 + +- 并非所有并发都需要线程。**异步**编程使用**事件循环**在单个线程中处理许多I/O密集型任务。 + +- 事件循环维护一个任务队列。当一个任务需要等待I/O(网络响应、文件读取)时,它注册一个回调并交出控制权。事件循环选取下一个就绪的任务。当I/O完成时,回调被排队并最终执行。等待期间没有线程被阻塞。 + +- **协程**是可以暂停和恢复的函数。`async/await` 语法(Python、JavaScript、Rust)使协程看起来像常规的顺序代码: + +```python +async def fetch_data(url): + response = await http_get(url) # 在此暂停,事件循环运行其他任务 + return process(response) # 响应到达时恢复 +``` + +- `await` 关键字暂停协程并将控制权返回给事件循环。当等待的操作完成时,协程从中断处恢复。这是协作式多任务:协程自愿放弃控制,不同于抢占式多任务中OS强制切换线程。 + +- 异步适用于具有许多并发连接的**I/O密集型**工作负载(处理数千个客户的Web服务器)。它不适用于**CPU密集型**工作(单线程事件循环无法利用多核)。对于CPU密集型工作,请使用线程或进程。 + +- Python的**全局解释器锁(GIL)**阻止线程真正的并行:一次只有一个线程可以执行Python字节码。这就是为什么Python对CPU并行使用多处理(独立的进程,每个有自己的解释器),对I/O并发使用异步。GIL正在Python 3.13+中被移除(自由线程Python),这将启用真正的多线程并行。 + +## 扩展定律 + +- **阿姆达尔定律**描述了并行化程序的理论加速。如果程序的$p$部分是可并行的,其余 $1-p$ 部分是串行的: + +$$\text{加速比}(n) = \frac{1}{(1-p) + \frac{p}{n}}$$ + +![阿姆达尔定律:串行部分限制了最大加速比——10%串行意味着最大10x,无论多少核心](../images/amdahl_serial_bottleneck.svg) + +- 其中$n$是处理器数量。当 $n \to \infty$ 时,最大加速比趋近于 $\frac{1}{1-p}$。如果95%的程序是并行的,最大加速比为 $\frac{1}{0.05} = 20\times$,无论你添加多少核心。串行部分就是瓶颈。 + +- 这对ML有深远影响:如果数据加载花费训练时间的10%并且是串行的,增加更多GPU最多只能将训练加速10倍。10%的串行瓶颈限制了所有东西(这就是为什么高效的数据管道和I/O与计算重叠很重要,第6章)。 + +- **古斯塔夫森定律**提供了更乐观的视角。它不是在固定问题规模并添加处理器,而是固定总时间并问可以做多少额外工作。如果并行部分随问题规模扩展: + +$$\text{加速比}(n) = 1 - p + p \cdot n$$ + +- 这是关于$n$线性的。论证是:用更多处理器,我们解决更大的问题,而不是更快地解决同一问题。在ML中,这对应于用更多GPU增加批量大小(弱扩展),而不是保持批量大小固定(强扩展)。 + +## 编程任务(使用CoLab或笔记本) + +1. 演示竞态条件。两个线程在没有同步的情况下增加一个共享计数器,观察丢失的更新。 +```python +import threading + +counter = 0 + +def increment(n): + global counter + for _ in range(n): + counter += 1 # 不是原子的:读、加、写 + +threads = [threading.Thread(target=increment, args=(100000,)) for _ in range(4)] +for t in threads: t.start() +for t in threads: t.join() + +print(f"Expected: {4 * 100000}") +print(f"Actual: {counter}") +print(f"Lost updates: {4 * 100000 - counter}") +``` + +2. 用锁修复竞态条件并测量开销。 +```python +import threading +import time + +lock = threading.Lock() +counter = 0 + +def increment_locked(n): + global counter + for _ in range(n): + with lock: + counter += 1 + +start = time.time() +threads = [threading.Thread(target=increment_locked, args=(100000,)) for _ in range(4)] +for t in threads: t.start() +for t in threads: t.join() +elapsed = time.time() - start + +print(f"Counter: {counter} (correct: {4 * 100000})") +print(f"Time with lock: {elapsed:.3f}s") +``` + +3. 可视化阿姆达尔定律。绘制不同并行比例下加速比与处理器数量的关系图。 +```python +import jax.numpy as jnp +import matplotlib.pyplot as plt + +n_procs = jnp.arange(1, 65) + +for p, color in [(0.5, "#e74c3c"), (0.9, "#f39c12"), (0.95, "#27ae60"), (0.99, "#3498db")]: + speedup = 1 / ((1 - p) + p / n_procs) + plt.plot(n_procs, speedup, color=color, linewidth=2, label=f"p={p}") + # 最大加速比线 + plt.axhline(1 / (1 - p), color=color, linestyle="--", alpha=0.3) + +plt.xlabel("处理器数量") +plt.ylabel("加速比") +plt.title("阿姆达尔定律:串行比例限制加速比") +plt.legend() +plt.grid(True) +plt.show() +``` diff --git a/chapter 13: computing and OS/05. programming languages.md b/chapter 13: computing and OS/05. programming languages.md new file mode 100644 index 0000000..e5be6e3 --- /dev/null +++ b/chapter 13: computing and OS/05. programming languages.md @@ -0,0 +1,273 @@ +# 编程语言 + +*编程语言是人类意图与机器执行之间的接口。本文涵盖语言范式、类型系统、内存管理策略、编译流水线、解释与JIT编译、关键语言特性、领域特定语言以及设计权衡。* + +- 每一份软件、每一个ML模型、每一个操作系统都是用编程语言编写的。但存在数百种语言,每种都有不同的优势。为什么?因为语言设计涉及基本的权衡:性能 vs 安全、表现力 vs 简洁性、控制 vs 抽象。理解这些权衡有助于你为工作选择合适的工具,并理解你所处的约束。 + +## 语言范式 + +- **范式**是一种编程风格:一套指导你如何组织代码和思考问题的原则。 + +- **命令式**编程将计算描述为一系列改变状态的命令。"设x为5。将3加到x。如果x > 7,打印它。"C、Python和Java本质上是命令式的。心智模型是一个带有内存的机器,你逐步修改它。 + +- **面向对象(OOP)**编程围绕**对象**组织代码:数据(属性)和行为(方法)的捆绑。对象通过相互发送消息来交互。关键思想是**封装**(将内部状态隐藏在公共接口之后)、**继承**(通过扩展现有类创建新类)和**多态**(通过共享接口统一处理不同类型)。Java、C++和Python支持OOP。 + +- **函数式编程(FP)**将计算视为数学函数的求值。核心原则:**不可变性**(数据一旦创建就不改变)、**纯函数**(输出仅取决于输入,无副作用)和**一等函数**(函数是可以作为参数传递、从其他函数返回和存储在变量中的值)。Haskell是纯函数式的。Python、JavaScript和Scala支持函数式风格。 + +- 纯函数易于推理、测试和并行化(没有共享的可变状态意味着没有竞态条件)。这就是为什么函数式思想越来越多地用于分布式系统和数据管道。JAX(本书中一直在使用)是函数式的:`jax.grad` 之所以有效,是因为JAX函数是纯函数。 + +- **逻辑编程**描述*什么*应该为真,而不是*如何*计算它。你陈述事实和规则,运行时找到解。Prolog是经典例子:给定"苏格拉底是人"和"所有人都是必死的",引擎推导出"苏格拉底是必死的。"逻辑编程用于AI知识库和类型检查。 + +- 大多数现代语言是**多范式**的:Python支持命令式、OOP和函数式风格。Rust支持命令式和函数式。范式是一种工具,不是信仰。 + +## 类型系统 + +- **类型**对值进行分类,并确定哪些操作是有效的。整数3和字符串"3"是不同的类型:你可以对整数进行加法,但不能对字符串(好吧,你可以拼接字符串,但那是不同的操作)。 + +- **静态类型**:类型在**编译时**检查,在程序运行之前。类型错误及早被发现。C、Java、Rust和Go是静态类型的。你必须声明类型(或者编译器推断它们): + +```rust +let x: i32 = 5; // Rust:x是一个32位整数 +let y: f64 = 3.14; // y是一个64位浮点数 +// let z = x + y; // 编译错误:不能加 i32 和 f64 +``` + +- **动态类型**:类型在**运行时**检查,当操作实际执行时。更灵活,但类型错误只有在代码运行时才暴露。Python、JavaScript和Ruby是动态类型的: + +```python +x = 5 # x是一个int(目前) +x = "hello" # 现在x是一个字符串——没有错误 +``` + +- **强类型**:语言阻止隐式类型转换。Python是强类型的:`"3" + 5` 引发TypeError。**弱类型**:语言静默地转换类型。JavaScript是弱类型的:`"3" + 5` 得到 `"35"`(数字被强制转换为字符串)。C是弱类型的:你可以将指针强制转换为整数。 + +- **类型推断**让编译器推导类型而无需显式注解: + +```rust +let x = 5; // 编译器推断:i32 +let y = x + 3.0; // 编译错误:混合类型,即使有推断 +``` + +- **泛型**(参数化多态)让你编写适用于任何类型的代码: + +```rust +fn largest(list: &[T]) -> &T { + let mut max = &list[0]; + for item in &list[1..] { + if item > max { max = item; } + } + max +} +// 适用于整数、浮点数、字符串——任何支持比较的类型 +``` + +- 对于ML:Python的动态类型使实验快速,但隐藏了错误。生产ML系统越来越多地使用类型提示(`def train(model: nn.Module, lr: float) -> float`)和静态分析工具(mypy)以在部署前捕获错误。PyTorch和JAX使用Python以获得灵活性;TensorRT和ONNX Runtime使用C++以获得性能。 + +## 内存管理 + +- 每个程序分配和释放内存。如何管理这是最具影响力的语言设计决策之一。 + +![内存布局:堆栈从高地址向下增长,堆向上增长,代码和数据在底部](../images/stack_vs_heap.svg) + +- **堆栈**存储局部变量和函数调用帧。分配很简单(移动栈指针),释放是自动的(函数返回时弹出帧)。堆栈访问很快,因为它总在缓存中。但堆栈有固定大小(通常1-8 MB),且仅支持LIFO(后进先出)分配。 + +- **堆**存储动态分配的数据(编译时大小未知的对象、数组、字符串)。堆分配较慢(需要找到一个空闲块),需要显式或自动释放。堆可以增长到填满可用内存。 + +- **手动内存管理**(C、C++):程序员显式分配(`malloc`)和释放(`free`)堆内存。最大控制和性能,但极易出错: + - **释放后使用**:访问已被释放的内存。导致崩溃或安全漏洞。 + - **双重释放**:释放同一内存两次。破坏分配器的内部数据结构。 + - **内存泄漏**:分配了内存但从未释放。程序慢慢消耗所有可用RAM。 + +- **垃圾回收(GC)**:运行时自动检测并释放不再可达的内存。程序员从不调用 `free`。 + + - **跟踪GC**(Java、Go、Python的循环收集器):定期从"根"(堆栈变量、全局变量)遍历所有可达对象,释放不可达对象。简单但导致**GC暂停**:收集器运行时程序停止。现代收集器(Go的并发GC、Java的ZGC)将暂停时间最小化到亚毫秒级。 + + - **引用计数**(Python的主要机制、Swift、Objective-C):每个对象跟踪有多少引用指向它。当计数降到0时,对象被立即释放。无暂停,但无法处理**循环**(A引用B,B引用A,两者计数都 > 0 但都不可达)。Python使用单独的循环检测器来处理此问题。 + +- **所有权**(Rust):编译器在编译时强制实施内存安全规则,零运行时开销。 + + - 每个值有且仅有一个**所有者**。当所有者超出作用域时,该值被丢弃(释放)。 + - 值可以被**借用**(引用),但编译器强制:要么一个可变引用,要么任意数量的不可变引用,永远不能同时存在。 + - 这阻止了释放后使用、双重释放、数据竞争和悬垂指针,全部在编译时完成。无需GC,无运行时开销。 + +- **借用检查器**是Rust的杀手级特性,也是其最陡峭的学习曲线。它保证了内存安全和线程安全,且没有垃圾回收,这就是Rust越来越多地用于性能关键系统(OS内核、游戏引擎、ML推理运行时如Candle和Burn)的原因。 + +## 编译流水线 + +- **编译器**在程序运行之前将源代码转换为机器码(或其他目标语言)。该流水线有几个阶段: + +![编译流水线:源代码→词法分析器→解析器→语义分析→优化器→代码生成→机器码](../images/compilation_pipeline.svg) + +1. **词法分析**(分词):将源文本转换为令牌流。`x = 3 + y` 变为 `[IDENT("x"), EQUALS, INT(3), PLUS, IDENT("y")]`。词法分析器去除空白和注释。 + +2. **语法分析**:从令牌流构建**抽象语法树(AST)**。AST表示程序的层次结构。`3 + y * 2` 解析为 `Add(3, Mul(y, 2))`(乘法优先级更高)。解析器检查语法:括号不匹配和缺少分号在此被捕获。 + +3. **语义分析**:检查类型、解析变量名、验证函数调用参数是否正确。静态类型检查在此发生。输出是带类型注解的AST。 + +4. **优化**:在不改变行为的情况下转换程序以使其运行更快。常见优化: + - **常量折叠**:在编译时计算 `3 + 5`,替换为 `8`。 + - **死代码消除**:移除永远无法执行的代码。 + - **循环展开**:用重复的内联代码替换循环以减少分支开销。 + - **内联**:用函数体替换函数调用,消除调用开销。 + +5. **代码生成**:将优化后的表示转换为目标机器码(x86、ARM)或中间表示。 + +- **LLVM**是主流的编译器基础设施。它提供了一个通用中间表示(LLVM IR),许多语言可以编译到该表示上。LLVM的优化器在这个IR上工作,其后端为许多目标生成机器码。Clang(C/C++)、Rust、Swift、Julia和许多其他语言使用LLVM。这意味着LLVM优化器的改进同时惠及所有这些语言。 + +## 解释与JIT编译 + +- **解释器**逐行(或逐语句)执行程序而不产生机器码。这使得启动快速且开发交互式,但执行较慢(每行每次运行时都要重新分析)。 + +- 大多数解释型语言实际上编译为**字节码**:一种比源代码更简单但不特定于机器的中间表示。字节码在**虚拟机(VM)**上运行。 + + - **CPython**(标准Python实现)将Python源代码编译为字节码(`.pyc` 文件),由CPython VM执行。VM逐条指令解释字节码。这就是为什么Python在计算密集型代码上比C慢约~100倍。 + + - **JVM**(Java虚拟机):Java编译为JVM字节码(`.class` 文件)。JVM最初解释字节码,然后**JIT编译**频繁执行的代码路径("热点")为本机机器码。这就是为什么Java启动比C慢(解释开销),但对于长时间运行的程序(JIT优化的热路径)接近C的速度。 + +- **JIT(即时)编译**在运行时将代码编译为机器码,使用仅在执行期间可用的信息。JIT可以根据实际运行时数据进行优化:如果一个函数总是用整数参数调用,JIT生成专门化的仅整数机器码,跳过类型检查。 + +- **PyPy**是另一个带有JIT编译器的Python实现。它通过将热点循环JIT编译为机器码,使大多数Python代码运行速度比CPython快5-10倍。然而,它与C扩展模块(NumPy、PyTorch)的兼容性有限,这限制了它在ML中的使用。 + +- 从解释到编译的范围不是二元的: + - 纯解释:Bash shell脚本。 + - 字节码解释:CPython。 + - 字节码 + JIT:JVM、.NET CLR、LuaJIT、PyPy。 + - 提前(AOT)编译:C、C++、Rust、Go。 + - AOT + 运行时代码生成:JAX的 `jax.jit` 在首次调用时编译Python函数为优化的XLA代码,然后缓存编译后的版本。 + +## 关键语言特性 + +- **闭包**:捕获其包围作用域中变量的函数。该函数"闭合"其定义时的环境: + +```python +def make_adder(n): + def add(x): + return x + n # n 从包围作用域捕获 + return add + +add5 = make_adder(5) +print(add5(3)) # 8 +``` + +- 闭包是回调、装饰器和部分应用背后的机制。它们对函数式编程至关重要。 + +- **模式匹配**:一种强大的控制流机制,解构数据并根据其形状进行分支: + +```rust +match value { + Some(x) if x > 0 => println!("Positive: {}", x), + Some(0) => println!("Zero"), + Some(x) => println!("Negative: {}", x), + None => println!("Nothing"), +} +``` + +- 模式匹配比if-else链更具表现力:它检查数据的结构(是Some还是None?它包含的值是否符合某个条件?),而不仅仅是相等性。Python在3.10中增加了结构模式匹配(`match`/`case`)。 + +- **代数数据类型(ADT)**:可以是多个变体之一的类型,每个变体携带不同的数据。`Result` 类型要么是 `Ok(value)` 要么是 `Err(error)`。`Tree` 要么是 `Leaf(value)` 要么是 `Node(left, right)`。ADT结合模式匹配可以穷尽处理所有情况,消除整类bug(空指针异常、未处理的错误码)。 + +- **特质与接口**:定义一个类型必须实现的一组方法,而不指定如何实现。这实现了多态:一个接受"任何实现了Display特质的类型"的函数可以处理整数、字符串和自定义类型。Rust使用特质,Java使用接口,Go使用隐式接口,Python使用鸭子类型("如果它走路像鸭子……")。 + +## 领域特定语言 + +- **领域特定语言(DSL)**是为特定问题域设计的语言,在该领域内用通用性换取表现力。 + +- **SQL**:关系数据库的语言。`SELECT name FROM users WHERE age > 30` 比等价的命令式循环可读性强得多且更易优化。数据库引擎优化查询执行计划,自动选择连接策略和索引使用。 + +- **正则表达式**:用于文本模式匹配的微型语言。`\d{3}-\d{4}` 匹配像"555-1234"这样的电话号码。正则引擎将模式编译为有限自动机以实现高效匹配。 + +- **着色器语言**(GLSL、HLSL、Metal Shading Language):在GPU核心上运行的程序,用于计算像素颜色、顶点位置或计算操作。着色器是海量并行的:每次调用独立处理一个像素或一个元素。这与CUDA用于ML计算的执行模型相同。 + +- 在ML中,像PyTorch和JAX这样的框架本质上是嵌入在Python中的张量计算DSL。它们提供领域特定的抽象(张量、自动微分、设备放置),同时利用Python的生态系统。 + +## 语言设计权衡 + +- 没有一种语言在所有方面都是最好的。设计是关于选择哪些权衡: + +- **性能 vs 安全**:C提供了原始速度和硬件控制,但会让你破坏内存。Rust以编译时内存安全提供相当的速度。Java提供内存安全但有垃圾回收开销。Python提供最大的安全性和表现力,但执行速度慢100倍。 + +- **表现力 vs 简洁性**:Haskell的类型系统可以表达非常精确的约束,但有陡峭的学习曲线。Go故意省略了泛型(直到最近)、继承和异常以追求简洁性。Python的"应该有一种——最好只有一种——显而易见的做法"哲学保持了语言的可学习性。 + +- **控制 vs 抽象**:C/C++让你控制内存布局、缓存行为和硬件交互。Python隐藏了所有这些。对于ML训练(GPU计算占主导),Python的开销可以忽略不计。对于ML推理(每微秒都很关键),C++或Rust可能是必要的。 + +- **编译速度 vs 运行时速度**:Go在几秒内编译完成(简单的类型系统,最小优化)。Rust需要几分钟编译(复杂的类型系统,激进优化)。权衡的是开发者迭代速度与部署后的性能。 + +- ML生态系统反映了这些权衡:Python用于实验和训练(表现力取胜),C++/CUDA用于内核和推理(性能取胜),Rust用于基础设施和安全关键系统(安全取胜)。 + +## 编程任务(使用CoLab或笔记本) + +1. 探索闭包和高阶函数。实现一个简单的函数工厂,验证闭包捕获其环境。 +```python +def make_multiplier(factor): + """返回一个将输入乘以 factor 的函数。""" + def multiply(x): + return x * factor + return multiply + +double = make_multiplier(2) +triple = make_multiplier(3) + +print(f"double(5) = {double(5)}") # 10 +print(f"triple(5) = {triple(5)}") # 15 + +# 闭包通过引用捕获,而不是通过值 +def make_counter(): + count = [0] # 可变的容器以允许修改 + def increment(): + count[0] += 1 + return count[0] + return increment + +counter = make_counter() +print(f"counter() = {counter()}") # 1 +print(f"counter() = {counter()}") # 2 +print(f"counter() = {counter()}") # 3 +``` + +2. 比较动态与静态类型行为。展示Python的动态类型如何提供灵活性但可能隐藏bug。 +```python +def add(a, b): + return a + b + +# 适用于不同类型——灵活! +print(add(3, 5)) # 8 (int + int) +print(add("hello ", "world")) # "hello world" (str + str) +print(add([1, 2], [3, 4])) # [1, 2, 3, 4] (list + list) + +# 但类型错误仅在运行时暴露: +try: + print(add("hello", 5)) # TypeError!str + int +except TypeError as e: + print(f"运行时错误:{e}") + print("静态类型检查器会在运行前捕获此问题") +``` + +3. 测量解释型Python与编译/JIT方法在计算密集型任务上的性能差异。 +```python +import time +import jax +import jax.numpy as jnp + +n = 1_000_000 + +# 纯Python循环(解释型) +start = time.time() +total = 0.0 +for i in range(n): + total += i * i +python_time = time.time() - start + +# JAX(通过XLA编译) +@jax.jit +def sum_squares_jax(n): + return jnp.sum(jnp.arange(n, dtype=jnp.float32) ** 2) + +_ = sum_squares_jax(10) # 预热JIT +start = time.time() +result = sum_squares_jax(n) +jax_time = time.time() - start + +print(f"Python loop: {python_time:.4f}s") +print(f"JAX (JIT): {jax_time:.6f}s") +print(f"Speedup: {python_time / jax_time:.0f}x") +``` diff --git a/chapter 14: data structures and algorithms/00. foundations.md b/chapter 14: data structures and algorithms/00. foundations.md new file mode 100644 index 0000000..6310c43 --- /dev/null +++ b/chapter 14: data structures and algorithms/00. foundations.md @@ -0,0 +1,500 @@ +# 基础:大O表示法、递归、回溯与动态规划 + +*在深入学习数据结构和算法之前,你需要掌握四个基础概念:衡量效率的大O表示法、将问题分解为子问题的递归、带剪枝的穷举搜索——回溯,以及避免冗余计算的动态规划。本文件从基本原理出发逐一讲解。* + +- 本章后续文件默认你已经熟悉了这四个概念。如果你跳过本文件,那么后面文件中的 $O(n \log n)$ 标注、递归树遍历、回溯模板和 DP 状态转移对你来说就会像是魔法而非工程。 + +## 为什么是模式,而非死记硬背 + +- LeetCode、NeetCode 和 HackerRank 上有成千上万的编程题。没有人能记住全部,试图这么做是注定失败的策略。面试官不会从固定题库中选题——他们会修改、组合、伪装。背下来的"两数之和"解法,当面试官问你一个从未见过的变体时毫无用处。 + +- 好消息是:核心模式大约只有 **15-20 种**(双指针、滑动窗口、BFS/DFS、DP、回溯等)。所有问题,无论表面多新颖,最终都归结为这些模式中的一个或几个组合。面试考的不是你是否见过这道题,而是你是否能**剥离上下文**——故事、具体数据类型、边界情况——识别出底层的模式。 + +- 考虑这三个问题: + - "在数组中找到两个数,使其和等于一个目标值。" + - "找到两个分子,使其结合能之和等于一个阈值。" + - "给定一个账户余额列表,找到两个账户的余额之和等于一笔债务。" + +- 它们看起来截然不同。但它们是同一个问题:**两数之和**。上下文(数字、分子、账户)无关紧要。其结构是:在集合中搜索补数 → 哈希表查找。 + +- 这就是本章通过**直觉教授模式**而非通过重复教授解题方法的原因。对于每个模式,我们都会解释: + - **问题中的什么结构特征**指示了这个模式(输入已排序 → 双指针;子数组约束 → 滑动窗口;最优子结构 + 重叠子问题 → DP)。 + - **为什么这个模式有效**——数学或逻辑推理,而不仅仅是"它能给出正确答案"。 + - **如何适配它**——通过展示简单、中等和困难变体,在这些变体中相同的核心思想应用于不同的上下文。 + +- 当你深入理解*为什么*滑动窗口有效(约束的单调性意味着扩展/收缩就足够了),你就可以将其应用到任何具有该结构的问题上,即使是未曾见过的问题。当你只是背下了"无重复字符的最长子串"的代码,一旦问题发生变化,你就会束手无策。 + +- 实践策略: + 1. **学习模式**(本章)。 + 2. **练习识别模式**,在伪装的问题中(每个文件末尾的 NeetCode 练习题)。 + 3. **练习实现**,在时间压力下。 + 4. 面试中:阅读题目 → 剥离上下文 → 识别模式 → 实现。 + +--- + +## 大O表示法 + +- 当我们说一个算法"快"或"慢"时,需要一种精确的衡量方式。**大O表示法**描述了随着输入规模 $n$ 的增长,算法的运行时间(或空间使用量)如何增长,忽略了常数因子和低阶项。 + +- 形式化定义:$f(n) = O(g(n))$ 意味着存在常数 $c > 0$ 和 $n_0$,使得对所有 $n \geq n_0$ 有 $f(n) \leq c \cdot g(n)$。通俗地说:对于大规模输入,$f$ 的增长速度不超过 $g$。 + +- 为什么要忽略常数?因为 $2n$ 的算法和 $5n$ 的算法都是 $O(n)$:它们的扩展方式相同。在更快的计算机上,常数会变,但扩展性不会。大O表示法捕捉了问题的**内在**难度,与硬件无关。 + +### 增长率层级 + +- 从最快到最慢: + +| 大O | 名称 | 示例 | $n = 10^6$ 次操作 | +|-------|------|---------|----------------------| +| $O(1)$ | 常数级 | 数组访问、哈希查找 | 1 | +| $O(\log n)$ | 对数级 | 二分查找 | 20 | +| $O(n)$ | 线性级 | 线性扫描、单循环 | $10^6$ | +| $O(n \log n)$ | 线性对数级 | 归并排序、高效排序 | $2 \times 10^7$ | +| $O(n^2)$ | 平方级 | 嵌套循环、暴力配对 | $10^{12}$(太慢) | +| $O(n^3)$ | 立方级 | 三层嵌套循环、矩阵乘法 | $10^{18}$(实在太慢) | +| $O(2^n)$ | 指数级 | 所有子集、暴力回溯 | $10^{301030}$(不可能) | +| $O(n!)$ | 阶乘级 | 所有排列 | 荒谬 | + +- **经验法则**:现代计算机每秒执行约 $10^8$–$10^9$ 次简单操作。对于1秒的时间限制: + - $O(n)$ 适用于 $n \leq 10^8$ + - $O(n \log n)$ 适用于 $n \leq 10^7$ + - $O(n^2)$ 适用于 $n \leq 10^4$ + - $O(2^n)$ 适用于 $n \leq 25$ + +- 这张表能立即告诉你当前方法是否足够快。如果 $n = 10^5$ 而你的解法是 $O(n^2)$,那就是 $10^{10}$ 次操作——太慢了。你需要一个更好的算法。 + +### 如何分析大O + +- **单循环**遍历 $n$ 个元素:$O(n)$。 + +```python +total = 0 +for x in arr: # n 次迭代 + total += x # 每次迭代 O(1) +# 总计:O(n) +``` + +- **嵌套循环**:迭代次数相乘。 + +```python +for i in range(n): # n 次迭代 + for j in range(n): # 每次 n 次迭代 + process(i, j) # O(1) +# 总计:O(n^2) +``` + +- **每次减半的循环**:$O(\log n)$。每次迭代将问题规模减半,所以需要 $\log_2 n$ 次迭代。 + +```python +i = n +while i > 0: + process(i) + i //= 2 +# 总计:O(log n) +``` + +- **内循环依赖于外循环的嵌套循环**: + +```python +for i in range(n): + for j in range(i): # j 从 0 到 i-1 + process(i, j) +# 总计:0 + 1 + 2 + ... + (n-1) = n(n-1)/2 = O(n^2) +``` + +- **递归**:写出递推关系并求解(第13章介绍了主定理)。例如,归并排序:$T(n) = 2T(n/2) + O(n) = O(n \log n)$。 + +### 常见陷阱 + +- **隐藏的循环**:Python 中 `x in list` 是 $O(n)$(线性扫描),但 `x in set` 是 $O(1)$。在循环中对列表使用 `in` 会得到 $O(n^2)$,而不是 $O(n)$。 + +```python +# 不好:O(n^2) — 对列表用 "in" 是 O(n) +for x in arr: + if x in another_list: + process(x) + +# 好:O(n) — 先转换为 set +another_set = set(another_list) +for x in arr: + if x in another_set: + process(x) +``` + +- **字符串拼接**:Python 中 `s += c` 每次都会复制整个字符串。在 $n$ 次迭代的循环中:$O(1 + 2 + \cdots + n) = O(n^2)$。 + +- **排序主导**:如果你的算法先排序($O(n \log n)$)然后做线性扫描($O(n)$),总复杂度是 $O(n \log n)$——排序占主导。 + +- **平摊复杂度**:某些操作偶尔很昂贵,但平摊下来很便宜。动态数组的追加操作平摊复杂度为 $O(1)$,因为罕见的 $O(n)$ 扩容被分摊到 $n$ 次便宜的追加操作中。不要混淆平摊 $O(1)$ 和最坏情况 $O(1)$。 + +### 空间复杂度 + +- 空间复杂度遵循同样的大O规则,只是应用于内存使用而非时间。 + +- **原地**算法使用 $O(1)$ 额外空间(不计输入)。快速排序是 $O(\log n)$ 空间(递归栈深度)。归并排序是 $O(n)$(合并时使用的临时数组)。 + +- **递归栈**:每次递归调用都会使用栈空间。深度为 $n$ 的递归使用 $O(n)$ 空间,即使每次调用没有分配额外内存。这就是为什么在具有 $n$ 个节点的图上进行递归 DFS 使用 $O(n)$ 空间。 + +- 面试中,始终同时说明时间和空间复杂度。$O(n)$ 时间、$O(n)$ 空间的解法通常可以接受,但 $O(n)$ 时间、$O(1)$ 空间的解法更好。面试官可能会要求你优化其中一个。 + +--- + +## 递归 + +- **递归**是指函数调用自身来解决同一问题的更小实例。它是处理具有递归结构的问题最自然的方式:树、嵌套数据、分治法和数学序列。 + +- 每个递归函数都有两部分: + 1. **基本情况**:可以直接解决的最小的实例(无需递归)。这是递归停止的条件。 + 2. **递归情况**:将问题分解为更小的子问题,递归求解,然后合并结果。 + +### 示例:阶乘 + +```python +def factorial(n): + if n <= 1: # 基本情况 + return 1 + return n * factorial(n - 1) # 递归情况 +``` + +- `factorial(4)` 的执行过程: + - `factorial(4)` 调用 `factorial(3)` + - `factorial(3)` 调用 `factorial(2)` + - `factorial(2)` 调用 `factorial(1)` + - `factorial(1)` 返回 `1`(基本情况) + - `factorial(2)` 返回 `2 * 1 = 2` + - `factorial(3)` 返回 `3 * 2 = 6` + - `factorial(4)` 返回 `4 * 6 = 24` + +- 每次调用都被压入**调用栈**。栈一直增长直到到达基本情况,然后随着每次调用的返回而展开。如果递归太深(例如 Python 中的 `factorial(1000000)`),栈会溢出(`RecursionError`)。Python 的默认递归限制是 1000。 + +### 如何以递归方式思考 + +- 关键的思维转变是:**信任递归**。在编写递归函数时,假设递归调用已经正确返回了更小子问题的答案。你只需要: + 1. 处理基本情况。 + 2. 将问题分解为更小的部分。 + 3. 合并结果。 + +- 你不需要在脑中跟踪每一次递归调用。这就像试图通过在心里执行每次迭代来理解一个循环。相反,验证:"如果递归调用给了我更小输入的正确结果,那么我的组合步骤是否给出了完整输入的正确结果?" + +### 示例:链表上的递归 + +- 递归反转链表: + +```python +def reverse(head): + if not head or not head.next: # 基本情况:0 或 1 个节点 + return head + + new_head = reverse(head.next) # 反转剩余部分 + head.next.next = head # 将下一个节点指回当前节点 + head.next = None # 当前节点现在成为尾节点 + return new_head +``` + +- **信任递归**:`reverse(head.next)` 正确反转了链表的剩余部分并返回新的头节点。我们只需将当前节点附加到末尾。 + +### 示例:树上的递归 + +- 计算二叉树的高度: + +```python +def height(root): + if not root: # 基本情况:空树高度为 0 + return 0 + left_h = height(root.left) # 左子树高度 + right_h = height(root.right) # 右子树高度 + return 1 + max(left_h, right_h) # 当前节点增加 1 层 +``` + +- 这种模式——"递归左子树,递归右子树,合并结果"——解决了绝大多数树的问题(见文件03)。 + +### 递归 vs 迭代 + +- 每个递归算法都可以转换为迭代算法(使用显式栈或循环)。迭代避免了调用栈开销和栈溢出风险。 + +- **何时优先使用递归**:问题具有自然的递归结构(树、嵌套数据、分治法)。递归解法更简洁、更易于推理。 + +- **何时优先使用迭代**:递归深度可能非常大(例如,处理包含 $10^6$ 个节点的链表)。迭代解法避免了栈溢出。 + +- **尾递归**:如果递归调用是函数中的最后一个操作(递归调用返回后没有后续工作),则该递归调用是"尾递归"的。某些语言(Scheme、Scala)会将尾调用优化为使用常数栈空间。Python **不**优化尾调用,因此 Python 中的尾递归仍然使用 $O(n)$ 栈空间。 + +### 常见陷阱 + +| 陷阱 | 示例 | 修复 | +|---------|---------|-----| +| 缺少基本情况 | 无限递归 → 栈溢出 | 始终定义何时停止 | +| 基本情况错误 | 递归分解中的差一错误 | 用最小的输入测试(0、1、2) | +| 问题规模未减小 | `f(n)` 调用 `f(n)` 而非 `f(n-1)` | 确保子问题严格更小 | +| 冗余计算 | 斐波那契数列:`f(n) = f(n-1) + f(n-2)` 以指数级重复计算 | 使用记忆化(→ DP) | +| Python 递归限制 | `factorial(10000)` 崩溃 | 使用 `sys.setrecursionlimit` 或转为迭代 | + +--- + +## 回溯 + +- **回溯**是一种系统地探索所有可能解法的方法,通过逐步构建解并在发现部分解不可能得到有效答案时立即放弃。 + +- 可以把它想象成走迷宫。在每个岔路口,你选择一条路。如果碰到死胡同,你就回到上一个岔路口尝试不同的路。你不会从头开始——你**回溯**到最近的一个决策点。 + +### 三个步骤 + +每个回溯算法都遵循相同的模式: + +1. **选择**:选择一个候选来扩展当前的部分解。 +2. **探索**:递归地尝试从这个候选构建一个完整的解。 +3. **撤销**:撤销选择(回溯)并尝试下一个候选。 + +```python +def backtrack(state, choices, result): + if is_complete(state): + result.append(state.copy()) + return + + for choice in choices: + if is_valid(choice, state): + state.add(choice) # 1. 选择 + backtrack(state, choices, result) # 2. 探索 + state.remove(choice) # 3. 撤销(回溯) +``` + +- **撤销**步骤是回溯与普通递归的区别所在。没有它,状态会累积所有选择,你就无法探索替代路径。 + +### 何时使用回溯 + +- 问题要求**枚举所有有效配置**:所有排列、所有子集、所有有效排列(如 N 皇后)。 +- 问题要求**寻找任何有效配置**:数独求解、迷宫寻路。 +- 搜索空间很大但可以**剪枝**:大多数部分解可以在完全探索之前被提前拒绝。 + +### 剪枝如何使其变快 + +- 没有剪枝时,回溯会探索所有可能的组合——指数级时间。**剪枝**则提前砍掉分支: + +```python +for choice in choices: + if not is_valid(choice, state): + continue # 剪枝:跳过整个子树 + + state.add(choice) + backtrack(state, choices, result) + state.remove(choice) +``` + +- 在 N 皇后问题(文件05)中,在放置皇后之前检查列和对角线冲突,将搜索树从 $n^n$ 剪枝到大约 $n!$ 个候选。对于 $n = 8$,这是 1600 万 → 40,000。好的剪枝使指数级算法在中等规模的 $n$ 下变得可行。 + +### 生成所有子集(最简单的回溯) + +```python +def subsets(nums): + result = [] + + def backtrack(start, path): + result.append(path[:]) # 每个部分解都是一个有效的子集 + + for i in range(start, len(nums)): + path.append(nums[i]) # 选择 + backtrack(i + 1, path) # 探索(i+1:不允许重复使用) + path.pop() # 撤销 + + backtrack(0, []) + return result +``` + +- 对于 `[1, 2, 3]`,递归树: + - `[]` → `[1]` → `[1,2]` → `[1,2,3]`(回溯)→ `[1,3]`(回溯)→ `[2]` → `[2,3]`(回溯)→ `[3]` + +- 树中的每个节点是一次对 `backtrack` 的调用。每个叶子节点(以及中间节点)产生一个子集。总子集数:$2^n$。 + +### 生成所有排列 + +```python +def permutations(nums): + result = [] + + def backtrack(path, remaining): + if not remaining: + result.append(path[:]) + return + + for i in range(len(remaining)): + path.append(remaining[i]) # 选择 + backtrack(path, remaining[:i] + remaining[i+1:]) # 探索 + path.pop() # 撤销 + + backtrack([], nums) + return result +``` + +- 总排列数:$n!$。每个排列需要 $O(n)$ 工作来构造 `remaining`,所以总复杂度为 $O(n \cdot n!)$。 + +### 常见陷阱 + +| 陷阱 | 示例 | 修复 | +|---------|---------|-----| +| 忘记复制路径 | `result.append(path)` —— 所有条目共享同一个列表 | `result.append(path[:])` 或 `path.copy()` | +| 未回溯(撤销) | 状态不断增长,后面的候选看到过时的状态 | 递归调用后始终执行 `path.pop()` 或 `state.remove()` | +| 循环起始位置错误 | 子集中有重复项,或排列中出现了不应有的重复使用 | 使用 `start` 参数避免重新访问之前的索引 | +| 跳过剪枝 | 探索明显无效的分支 | 在递归调用前添加 `if not is_valid: continue` | + +--- + +## 动态规划 + +- **动态规划(DP)**是一种优化技术,适用于相同子问题被反复求解的情况。DP 不重复计算,而是每个子问题只解一次并存储结果。 + +- DP 适用于具有两个性质的问题: + 1. **最优子结构**:最优解可以由子问题的最优解构建而成。 + 2. **重叠子问题**:相同的子问题在递归中多次出现。 + +### 斐波那契数列的动机 + +- 朴素递归斐波那契数列: + +```python +def fib(n): + if n <= 1: + return n + return fib(n - 1) + fib(n - 2) +``` + +- 对于 `fib(5)`,递归树: + - `fib(5)` 调用 `fib(4)` 和 `fib(3)` + - `fib(4)` 调用 `fib(3)` 和 `fib(2)` + - `fib(3)` 被计算了**两次**,`fib(2)` 被计算了**三次** + +- 这是 $O(2^n)$,因为树在每一层都分支,而且大多数分支重复计算相同的值。对于 `fib(50)`,需要超过 $10^{15}$ 次操作——不可行。 + +- 使用**记忆化**(自顶向下 DP): + +```python +def fib_memo(n, memo={}): + if n in memo: + return memo[n] + if n <= 1: + return n + memo[n] = fib_memo(n - 1, memo) + fib_memo(n - 2, memo) + return memo[n] +``` + +- 现在 `fib(3)` 只计算一次,存储起来,后续调用直接查找。总计:$O(n)$ 时间,$O(n)$ 空间。 + +- 使用**制表法**(自底向上 DP): + +```python +def fib_tab(n): + if n <= 1: + return n + dp = [0] * (n + 1) + dp[1] = 1 + for i in range(2, n + 1): + dp[i] = dp[i - 1] + dp[i - 2] + return dp[n] +``` + +- 同样 $O(n)$ 时间,但自底向上构建解,无需递归。可以进一步优化到 $O(1)$ 空间,因为每个值只依赖于前两个值。 + +### DP 配方 + +对于任何 DP 问题,遵循以下步骤: + +1. **定义状态**:`dp[i]`(或 `dp[i][j]`)代表什么?这是最难的一步。状态必须捕获足够的信息以做出最优决策。 + +2. **写出递推关系**:`dp[i]` 如何与更小的子问题关联?这是转移公式。 + +3. **确定基本情况**:哪些是最小的子问题,可以直接求解? + +4. **确定迭代顺序**:哪些子问题必须先于哪些子问题求解?自底向上:按照确保依赖关系已解决的顺序迭代。自顶向下:递归会自动处理。 + +5. **优化空间**(可选):如果 `dp[i]` 只依赖于前一行或前几个条目,你就不需要完整的表。 + +### 示例:思路过程 + +**问题**:给定一个正整数数组,求不相邻元素的最大和(打家劫舍)。 + +**第1步——定义状态**:`dp[i]` = 考虑元素 `nums[0..i]` 的最大和。 + +**第2步——写出递推关系**:对于元素 $i$,我们要么: +- 跳过它:`dp[i] = dp[i-1]`(不含元素 $i$ 的最佳和)。 +- 取用它:`dp[i] = dp[i-2] + nums[i]`(必须跳过元素 $i-1$,然后加上元素 $i$)。 + +所以:`dp[i] = max(dp[i-1], dp[i-2] + nums[i])`。 + +**第3步——基本情况**:`dp[0] = nums[0]`,`dp[1] = max(nums[0], nums[1])`。 + +**第4步——迭代顺序**:从左到右(每个状态依赖于前两个状态)。 + +**第5步——空间优化**:只需要最后两个值。 + +```python +def rob(nums): + if len(nums) == 1: + return nums[0] + + prev2, prev1 = nums[0], max(nums[0], nums[1]) + + for i in range(2, len(nums)): + curr = max(prev1, prev2 + nums[i]) + prev2, prev1 = prev1, curr + + return prev1 +``` + +### 如何识别 DP 问题 + +- 问题要求**最优值**(最小成本、最大利润、最长序列)或**计数**(方法数)。 +- 问题在每一步都有**选择**(取/跳过、向左/向右、使用这枚硬币与否),并且整体最优答案依赖于子问题的最优答案。 +- 画出递归树会显示**重复的子问题**。 +- 暴力解法是指数级的,但**不同的状态**比递归调用少得多。 + +### DP 的分类 + +- **1D DP**:状态依赖于单个索引。示例:爬楼梯、打家劫舍、最大子数组。 + +- **2D DP**:状态依赖于两个索引。示例:最长公共子序列(`dp[i][j]` 表示字符串1的前 $i$ 个字符和字符串2的前 $j$ 个字符)、编辑距离、网格路径问题。 + +- **区间 DP**:状态是一个区间 `dp[i][j]`,表示 `arr[i..j]` 上的子问题。示例:矩阵链乘法、戳气球。 + +- **背包 DP**:状态是物品索引和容量。示例:0/1 背包、零钱兑换、子集和。 + +- **位掩码 DP**:状态包含一个位掩码,表示哪些元素已被使用。示例:旅行商问题、分配问题。状态空间为 $O(2^n \cdot n)$,对于 $n \leq 20$ 可行。 + +### 自顶向下 vs 自底向上 + +| | 自顶向下(记忆化) | 自底向上(制表法) | +|--|---|---| +| 实现 | 递归 + 缓存 | 迭代 + 表 | +| 计算 | 只计算实际需要的子问题 | 计算直到目标的所有子问题 | +| 栈溢出风险 | 有(深度递归) | 无 | +| 空间优化 | 较难 | 较易(使用滚动数组) | +| 编码难度 | 通常更自然(写递归,加缓存) | 需要考虑迭代顺序 | + +- 在面试中,自顶向下通常编码更快。在生产环境中,自底向上通常更受青睐(无递归开销,缓存行为更好)。 + +### 常见陷阱 + +| 陷阱 | 示例 | 修复 | +|---------|---------|-----| +| 状态定义错误 | `dp[i]` 没有捕获足够信息来做决策 | 增加维度(例如用 `dp[i][j]` 代替 `dp[i]`) | +| 缺少基本情况 | `dp[0]` 错误 → 所有后续值都错 | 手动验证基本情况 | +| 迭代顺序错误 | 在依赖关系未解决之前计算 `dp[i]` | 画出依赖箭头并相应迭代 | +| 未正确初始化 `dp` | 用 0 而应该用无穷大(求最小值时) | 最小化用 `float('inf')`,最大化用 `float('-inf')` | +| 忘记考虑"跳过"选项 | 总是取当前元素 | 递推关系通常有 `max(take, skip)` | +| 可变的默认参数 | `def f(memo={})` 在调用间共享缓存 | `def f(memo=None): if memo is None: memo = {}` | +| 2D DP 中的差一错误 | `dp` 是 1-indexed 时访问 `text1[i]` | `dp` 大小为 `(m+1) x (n+1)`,访问 `text1[i-1]` | + +--- + +## 融会贯通 + +- 这四个概念构成一个递进关系: + 1. **大O表示法**告诉你一个方法是否足够快。 + 2. **递归**将问题分解为子问题。 + 3. **回溯**是递归 + 选择 + 撤销,用于穷举搜索。 + 4. **DP**是递归 + 缓存,用于具有重叠子问题的优化。 + +- 当你遇到一个新问题时: + - 估计输入规模 $n$。什么样的 Big O 是可接受的? + - 如果暴力解法是指数级的,且问题要求枚举/寻找配置:**回溯**(配合剪枝使其可行)。 + - 如果暴力解法是指数级的,且问题要求最优值或计数,并且你看到重叠子问题:**DP**。 + - 如果问题具有减半搜索空间的结构:**二分查找**或**分治法**。 + - 如果问题涉及序列且有子数组约束:**滑动窗口**或**双指针**。 + - 如果问题需要快速查找:**哈希表**。 diff --git a/chapter 14: data structures and algorithms/01. arrays and hashing.md b/chapter 14: data structures and algorithms/01. arrays and hashing.md new file mode 100644 index 0000000..eb203b3 --- /dev/null +++ b/chapter 14: data structures and algorithms/01. arrays and hashing.md @@ -0,0 +1,526 @@ +# 数组与哈希 + +*数组和哈希表是编程中最基础的两种数据结构。本文件涵盖它们底层的运行机制,然后构建关键的问题解决模式:双指针、滑动窗口、前缀和以及基于哈希的查找,通过逐步增加难度的题目,并在每一步指出常见陷阱。* + +- 如果你深入理解数组和哈希表,你可以解决约40%的编码面试题。这两种结构无处不在,因为它们提供了算法最需要的两样东西:**快速索引访问**(数组)和**按键快速查找**(哈希表)。 + +- 本文件教授的是模式,而非解法。目标是当你看到一个新问题时,你能识别出适用哪个模式以及为什么,而不是试图回忆一个背下来的解法。 + +## 数组 + +- **数组**是一片连续的内存块,元素以固定偏移量存储。访问元素 $i$ 的成本是 $O(1)$,因为地址就是 `base + i * element_size`。这是最快的数据访问方式,也是数组成为默认选择的原因。 + +- **动态数组**(Python 的 `list`、Java 的 `ArrayList`、C++ 的 `vector`)在满时自动增长。其策略是**平摊加倍**:当数组满时,分配一个两倍大小的新数组并将所有元素复制过去。复制成本为 $O(n)$,但这种情况很少发生(每 $n$ 次插入一次),所以每次插入的平摊成本是 $O(1)$。 + +- **缓存局部性**是数组在实践中很快的原因,而不仅仅是理论上。因为元素是连续存储的,访问一个元素会将其邻近元素加载到 CPU 缓存中(第13章)。遍历数组是缓存友好的;在链表中跟随指针则不是。这个常数因子差异在实际中可能达到 10-100 倍。 + +| 操作 | 数组 | 动态数组 | +|-----------|-------|---------------| +| 按索引访问 | $O(1)$ | $O(1)$ | +| 追加 | 不适用 | $O(1)$ 平摊 | +| 在位置 $i$ 插入 | $O(n)$ | $O(n)$ | +| 在位置 $i$ 删除 | $O(n)$ | $O(n)$ | +| 搜索(未排序) | $O(n)$ | $O(n)$ | + +- **陷阱**:在数组中间插入或删除是 $O(n)$,因为所有后续元素都必须移动。如果你需要频繁在中间插入,考虑使用链表或其他方法。 + +## 字符串 + +- **字符串**是一个字符数组。在 Python 中,字符串是不可变的:每次拼接都会创建一个新的字符串。在循环中逐字符构建字符串是 $O(n^2)$,因为每次拼接都会复制到目前为止的整个字符串。 + +```python +# 不好:O(n^2) 字符串拼接 +s = "" +for c in characters: + s += c # 每次复制整个字符串 + +# 好:O(n) 使用列表然后 join +parts = [] +for c in characters: + parts.append(c) +s = "".join(parts) +``` + +- **陷阱**:在 Python 中,循环内的 `s += c` 是最常见的性能 bug 之一。始终先收集到列表中再 `.join()`。 + +- **编码**:ASCII 使用 7 位(128 个字符)。**UTF-8** 是可变长度的:ASCII 字符使用 1 字节,带重音字符使用 2 字节,中文/日文字符使用 3 字节,表情符号使用 4 字节。当问题说"小写英文字母"时,字母表大小为 26,这意味着你可以使用固定大小的数组而不是哈希表。 + +## 哈希表 + +- **哈希表**将键映射到值,平均情况下的查找、插入和删除都是 $O(1)$。它通过计算一个**哈希函数** $h(key)$ 将键转换为数组索引来实现。 + +- 哈希函数必须:**确定性的**(相同键总是得到相同哈希值)、**均匀的**(将键均匀分布到各个桶中)且**计算速度快**。 + +- **冲突**发生在两个不同的键哈希到相同的索引时。有两种主要策略: + + - **链地址法**:每个桶存储一个键值对链表。发生冲突时,追加到链表。最坏情况(所有键哈希到同一个桶):$O(n)$。使用好的哈希函数时的平均情况:$O(1)$。 + + - **开放地址法**:发生冲突时,探测下一个空槽。**线性探测**检查下一个槽位,然后再下一个,以此类推。它缓存友好,但会遭受**聚集**问题(长串的已占用槽位)。**罗宾汉哈希**通过将"离家较近"的条目移位来减少方差。 + +- **负载因子** $\alpha = n / m$(元素数 / 桶数)决定了性能。当 $\alpha$ 超过阈值(通常为 0.75)时,表会**重新哈希**:分配一个更大的表并重新插入所有元素。这需要 $O(n)$ 时间,但不常发生。 + +- **哈希映射**(Python 中的 `dict`、Java 中的 `HashMap`)存储键值对。**哈希集合**(Python 中的 `set`、Java 中的 `HashSet`)只存储键(用于快速成员测试)。 + +| 操作 | 平均 | 最坏情况 | +|-----------|---------|------------| +| 查找 | $O(1)$ | $O(n)$ | +| 插入 | $O(1)$ | $O(n)$ | +| 删除 | $O(1)$ | $O(n)$ | + +- **布隆过滤器**是空间高效的概率性集合。它可以告诉你"肯定不在集合中"或"可能在集合中"(具有可调的假阳性率)。它使用 $k$ 个哈希函数和一个位数组。用于数据库(避免对不存在的键进行磁盘读取)、Web 缓存和拼写检查器。 + +- **何时使用哈希表**:每当你需要用 $O(1)$ 的时间回答"我之前见过这个吗?"或"与这个键关联的计数/索引/值是什么?"时。如果你正在反复进行线性扫描寻找某物,哈希表几乎总能使其更快。 + +--- + +## 模式:哈希表查找 + +- 最基本的模式:使用哈希表将 $O(n)$ 扫描替换为 $O(1)$ 查找。 + +### 简单:两数之和 + +- **问题**:给定一个整数数组和一个目标值,返回两个数的索引,使它们的和等于目标值。 + +- **暴力解法** $O(n^2)$:检查每一对。 + +- **模式洞察**:对于每个数字 `num`,需要 `target - num` 存在于数组中的某处。与其扫描数组寻找它,不如将之前见过的数字存储在一个哈希表中。 + +```python +def two_sum(nums, target): + seen = {} # 值 -> 索引 + for i, num in enumerate(nums): + complement = target - num + if complement in seen: + return [seen[complement], i] + seen[num] = i +``` + +- **为什么有效**:一次遍历数组。对于每个元素,哈希表查找是 $O(1)$。总计:$O(n)$ 时间,$O(n)$ 空间。 + +- **陷阱**:在检查补数之前不要将当前数字添加到哈希表,否则可能会让元素与自身匹配。上面代码中的顺序是正确的:先检查,后插入。 + +### 中等:字母异位词分组 + +- **问题**:给定一个字符串列表,将字母异位词分组在一起。("eat"、"tea"、"ate")是一组。 + +- **模式洞察**:异位词具有相同的字符但顺序不同。如果对每个字符串进行排序,异位词会产生相同的排序后键。使用这个排序后的键作为哈希表的键。 + +```python +from collections import defaultdict + +def group_anagrams(strs): + groups = defaultdict(list) + for s in strs: + key = tuple(sorted(s)) # 或使用字符计数元组 + groups[key].append(s) + return list(groups.values()) +``` + +- **优化**:对每个字符串排序需要 $O(k \log k)$,其中 $k$ 是字符串长度。为了更快的键,统计字符频率并使用计数元组作为键: + +```python +def group_anagrams_fast(strs): + groups = defaultdict(list) + for s in strs: + count = [0] * 26 + for c in s: + count[ord(c) - ord('a')] += 1 + groups[tuple(count)].append(s) + return list(groups.values()) +``` + +- 这样每个字符串是 $O(k)$ 而不是 $O(k \log k)$。字符计数元组是一种**规范形式**:对组内所有成员都相同的表示。 + +- **陷阱**:在 Python 中,列表不可哈希(不能用作字典键)。你必须转换为元组。当人们尝试 `groups[count].append(s)` 时就会出错。 + +### 困难:最长连续序列 + +- **问题**:给定一个未排序的数组,找出最长连续序列的长度(例如,[100, 4, 200, 1, 3, 2] → 4,因为 [1, 2, 3, 4])。 + +- **暴力解法** $O(n \log n)$:对数组排序,然后扫描连续段。 + +- **模式洞察**:将所有数字放入哈希集以实现 $O(1)$ 查找。对于每个数字,检查它是否是一个序列的**起点**(即 `num - 1` 不在集合中)。如果是,则计算该序列能延伸多远。 + +```python +def longest_consecutive(nums): + num_set = set(nums) + best = 0 + + for num in num_set: + # 只从序列的开头开始计数 + if num - 1 not in num_set: + length = 1 + while num + length in num_set: + length += 1 + best = max(best, length) + + return best +``` + +- **为什么是 $O(n)$**:内部 `while` 循环在所有迭代中总共最多运行 $n$ 次(每个数字最多被访问两次:一次在外层循环,一次在 `while` 扩展中)。`if num - 1 not in num_set` 守卫确保我们只从序列起点开始计数。 + +- **陷阱**:如果没有 `if num - 1 not in num_set` 检查,你会从每个元素开始计数,在最坏情况下会变成 $O(n^2)$(例如,[1, 2, 3, ..., n] 会从每个起点扫描整个序列)。 + +--- + +## 模式:双指针 + +- **双指针**模式使用两个索引在数组中移动,通常从两端向中间或从同端以不同速度移动。它在数组已排序或需要比较成对元素时有效。 + +- **何时使用**:问题涉及成对、子数组或分区,并且数组已排序(或可在不丢失所需信息的情况下排序)。 + +### 简单:验证回文串 + +- **问题**:判断一个字符串是否是回文串,只考虑字母数字字符并忽略大小写。 + +- **模式**:一个指针在开头,一个在结尾。向中间移动,比较字符。 + +```python +def is_palindrome(s): + left, right = 0, len(s) - 1 + + while left < right: + # 跳过非字母数字字符 + while left < right and not s[left].isalnum(): + left += 1 + while left < right and not s[right].isalnum(): + right -= 1 + + if s[left].lower() != s[right].lower(): + return False + + left += 1 + right -= 1 + + return True +``` + +- **陷阱**:忘记内部 while 循环中的 `left < right` 检查。没有它,在像 "!!!"(全部非字母数字)这样的字符串上指针可能越界。 + +### 中等:三数之和 + +- **问题**:找出数组中所有唯一的三元组,使其和为零。 + +- **模式**:对数组排序。固定一个元素,然后在剩余部分使用双指针找到和为固定元素相反数的对。 + +```python +def three_sum(nums): + nums.sort() + result = [] + + for i in range(len(nums) - 2): + # 跳过重复的固定元素 + if i > 0 and nums[i] == nums[i - 1]: + continue + + left, right = i + 1, len(nums) - 1 + target = -nums[i] + + while left < right: + total = nums[left] + nums[right] + if total < target: + left += 1 + elif total > target: + right -= 1 + else: + result.append([nums[i], nums[left], nums[right]]) + # 跳过重复项 + while left < right and nums[left] == nums[left + 1]: + left += 1 + while left < right and nums[right] == nums[right - 1]: + right -= 1 + left += 1 + right -= 1 + + return result +``` + +- **为什么有效**:排序是 $O(n \log n)$。对于每个固定元素,双指针扫描是 $O(n)$。总计:$O(n^2)$,这是该问题的最优解(你必须考虑所有成对组合)。 + +- **陷阱**:处理重复项是最难的部分。没有跳过重复的逻辑(对固定元素和双指针结果都是如此),你会返回重复的三元组。`if i > 0 and nums[i] == nums[i-1]: continue` 这行至关重要。 + +### 困难:接雨水 + +- **问题**:给定一个高度图(非负整数数组),计算下雨后它能接住多少水。 + +- **模式洞察**:对于每个位置,水位由它左边最大高度和右边最大高度中的最小值减去当前高度决定。从两端开始的双指针跟踪这些运行中的最大值。 + +```python +def trap(height): + left, right = 0, len(height) - 1 + left_max, right_max = 0, 0 + water = 0 + + while left < right: + if height[left] < height[right]: + if height[left] >= left_max: + left_max = height[left] + else: + water += left_max - height[left] + left += 1 + else: + if height[right] >= right_max: + right_max = height[right] + else: + water += right_max - height[right] + right -= 1 + + return water +``` + +- **为什么有效**:关键的洞察是,如果 `height[left] < height[right]`,那么位置 `left` 处的水由 `left_max` 限制(我们知道右边有更高的柱子,所以右边不可能是瓶颈)。我们处理较短的一侧,保证另一侧有更高的柱子。 + +- **陷阱**:很多人试图先预计算 `left_max[i]` 和 `right_max[i]` 数组(这可行但使用 $O(n)$ 空间)。双指针方法实现了 $O(1)$ 空间。另外,在最大值更新中混淆 `>=` 和 `>` 会导致差一错误的水量计算。 + +--- + +## 模式:滑动窗口 + +- **滑动窗口**模式维护一个窗口(连续子数组),随着迭代扩展和收缩。它适用于询问满足某个条件的子数组或子串的问题。 + +- **何时使用**:问题要求满足约束条件的最长/最短子数组或子串,且扩展/收缩窗口是单调的(添加元素只能使约束更难/更容易满足,而不是两者兼有)。 + +- **模板**: + +```python +def sliding_window(arr): + left = 0 + state = ... # 窗口状态(计数、和等) + best = ... + + for right in range(len(arr)): + # 扩展:将 arr[right] 添加到窗口状态 + update_state(state, arr[right]) + + # 收缩:当约束被违反时从左侧缩小 + while constraint_violated(state): + remove_from_state(state, arr[left]) + left += 1 + + # 更新答案 + best = max(best, right - left + 1) # 或 min,取决于问题 + + return best +``` + +### 简单:买卖股票的最佳时机 + +- **问题**:给定每日价格,找出一笔交易(先买后卖)的最大利润。 + +- **模式**:跟踪到目前为止的最小价格(窗口的左边界),并在每一天计算利润。 + +```python +def max_profit(prices): + min_price = float('inf') + max_profit = 0 + + for price in prices: + min_price = min(min_price, price) + max_profit = max(max_profit, price - min_price) + + return max_profit +``` + +- 这是一个退化的滑动窗口:左指针(最低价格)只在找到新最小值时向前移动。$O(n)$ 时间,$O(1)$ 空间。 + +### 中等:无重复字符的最长子串 + +- **问题**:找出不含重复字符的最长子串的长度。 + +- **模式**:通过移动 `right` 扩展窗口。当发现重复时,从左侧收缩直到重复被移除。 + +```python +def length_of_longest_substring(s): + char_index = {} # 字符 -> 它的最近索引 + left = 0 + best = 0 + + for right, char in enumerate(s): + if char in char_index and char_index[char] >= left: + left = char_index[char] + 1 # 跳过重复字符 + + char_index[char] = right + best = max(best, right - left + 1) + + return best +``` + +- **为什么需要 `char_index[char] >= left`**:该字符可能来自当前窗口开始之前的映射。没有这个检查,你会错误地为当前窗口中实际不存在的字符收缩窗口。 + +- **陷阱**:使用集合并从左逐个删除字符是正确的但较慢。哈希表方法直接跳到正确的位置。 + +### 困难:最小覆盖子串 + +- **问题**:给定字符串 `s` 和 `t`,在 `s` 中找到包含 `t` 中所有字符的最小窗口。 + +- **模式**:扩展窗口以包含所有必需的字符,然后从左侧收缩以找到最小有效窗口。 + +```python +from collections import Counter + +def min_window(s, t): + if not t or not s: + return "" + + need = Counter(t) # 我们需要的字符及其计数 + have = 0 # 我们已经拥有足够数量的唯一字符数 + required = len(need) # 我们需要多少种唯一字符 + + left = 0 + best = (float('inf'), 0, 0) # (长度, 左, 右) + + window_counts = {} + + for right in range(len(s)): + char = s[right] + window_counts[char] = window_counts.get(char, 0) + 1 + + # 检查此字符的计数是否满足要求 + if char in need and window_counts[char] == need[char]: + have += 1 + + # 当窗口有效时从左侧收缩 + while have == required: + # 更新最佳值 + if (right - left + 1) < best[0]: + best = (right - left + 1, left, right) + + # 移除最左边的字符 + left_char = s[left] + window_counts[left_char] -= 1 + if left_char in need and window_counts[left_char] < need[left_char]: + have -= 1 + left += 1 + + length, start, end = best + return s[start:end + 1] if length != float('inf') else "" +``` + +- **陷阱**:`have` 计数器是关键优化。没有它,你需要在每一步比较整个 `window_counts` 字典与 `need`,每次比较是 $O(|\text{unique chars}|)$。`have` 计数器使有效性检查变为 $O(1)$。 + +- **陷阱**:检查 `window_counts[char] == need[char]`(而不是 `>=`)确保我们每个字符只递增一次 `have`。如果使用 `>=`,我们会多计数。 + +--- + +## 模式:前缀和 + +- **前缀和**数组存储累积和:`prefix[i] = sum(arr[0:i])`。一旦在 $O(n)$ 时间内构建完成,任何子数组和都可以在 $O(1)$ 时间内计算:`sum(arr[l:r]) = prefix[r] - prefix[l]`。 + +```python +def build_prefix(arr): + prefix = [0] * (len(arr) + 1) + for i in range(len(arr)): + prefix[i + 1] = prefix[i] + arr[i] + return prefix + +# arr[l:r] 的和(包含 l,不包含 r) +def range_sum(prefix, l, r): + return prefix[r] - prefix[l] +``` + +- **何时使用**:问题涉及多个子数组和查询,或寻找具有特定和的子数组。 + +### 简单:区间求和查询 + +- **问题**:给定一个数组,回答多个"从索引 $l$ 到 $r$ 的和是多少?"的查询。 + +- 没有前缀和:每个查询是 $O(n)$。有前缀和:$O(n)$ 预计算,然后每个查询 $O(1)$。 + +### 中等:和为 K 的子数组 + +- **问题**:统计有多少个连续子数组的和等于 $k$。 + +- **模式洞察**:从索引 $l$ 到 $r$ 的子数组和等于 `prefix[r+1] - prefix[l]`。我们希望这个值等于 $k$,所以 `prefix[l] = prefix[r+1] - k`。对于每个位置,使用哈希表统计多少个更早的前缀和等于 `current_prefix - k`。 + +```python +def subarray_sum(nums, k): + count = 0 + prefix = 0 + prefix_counts = {0: 1} # 空前缀和 + + for num in nums: + prefix += num + # 有多少更早的前缀和等于 prefix - k? + count += prefix_counts.get(prefix - k, 0) + prefix_counts[prefix] = prefix_counts.get(prefix, 0) + 1 + + return count +``` + +- 这结合了前缀和与哈希表查找:$O(n)$ 时间,$O(n)$ 空间。 + +- **陷阱**:忘记初始化 `prefix_counts = {0: 1}`。空前缀(在任何元素之前)的和为 0。没有它,你会漏掉从索引 0 开始的子数组。 + +### 困难:除自身以外数组的乘积 + +- **问题**:给定一个数组,返回一个数组,其中每个元素是所有其他元素的乘积。你不能使用除法。 + +- **模式**:从左侧构建前缀乘积,从右侧构建后缀乘积。每个位置的答案是 `left_product * right_product`。 + +```python +def product_except_self(nums): + n = len(nums) + result = [1] * n + + # 左向遍历:result[i] = nums[0..i-1] 的乘积 + prefix = 1 + for i in range(n): + result[i] = prefix + prefix *= nums[i] + + # 右向遍历:乘以 nums[i+1..n-1] 的乘积 + suffix = 1 + for i in range(n - 1, -1, -1): + result[i] *= suffix + suffix *= nums[i] + + return result +``` + +- $O(n)$ 时间,$O(1)$ 额外空间(输出数组不计入)。它使用输出数组本身来存储中间前缀乘积,然后在第二遍遍历中乘入后缀乘积。 + +- **陷阱**:如果数组包含零,基于除法的方法会失败。这种前缀/后缀方法正确处理零,因为它从不做除法。 + +--- + +## 常见陷阱总结 + +| 陷阱 | 示例 | 修复 | +|---------|---------|-----| +| 窗口大小的差一错误 | `right - left` vs `right - left + 1` | 画一个2元素示例 | +| Python 中的可变默认值 | `def f(seen={})` 在调用间共享状态 | 使用 `def f(seen=None)` | +| 循环中的字符串拼接 | `s += c` 在 Python 中是 $O(n^2)$ | 使用 `list.append` + `"".join` | +| 前缀和中忘记 `{0: 1}` | 漏掉从索引 0 开始的子数组 | 始终用空前缀初始化 | +| 检查前添加哈希表 | 两数之和:在检查补数之前添加了 `num` | 先检查,后插入 | +| 未处理重复项 | 三数之和返回重复的三元组 | 跳过连续相等的值 | +| 整数溢出 | C++/Java 中大数组求和 | 使用 `long` 或检查边界 | + +--- + +## 课后练习题(NeetCode) + +按顺序练习。每道题强化本文件中的一个模式。 + +### 哈希表查找 +- [Contains Duplicate](https://neetcode.io/problems/contains-duplicate) — 热身:哈希集判断是否见过 +- [Two Sum](https://neetcode.io/problems/two-sum) — 补数查找 +- [Group Anagrams](https://neetcode.io/problems/anagram-groups) — 规范形式作为键 +- [Top K Frequent Elements](https://neetcode.io/problems/top-k-elements-in-list) — 哈希表 + 桶排序 +- [Longest Consecutive Sequence](https://neetcode.io/problems/longest-consecutive-sequence) — 哈希集配合序列起点技巧 +- [Encode and Decode Strings](https://neetcode.io/problems/string-encode-and-decode) — 设计序列化方案 + +### 双指针 +- [Valid Palindrome](https://neetcode.io/problems/is-palindrome) — 向内指针 +- [Two Sum II (sorted)](https://neetcode.io/problems/two-integer-sum-ii) — 排序数组上的双指针 +- [Three Sum](https://neetcode.io/problems/three-integer-sum) — 固定 + 双指针 + 去重 +- [Container With Most Water](https://neetcode.io/problems/max-water-container) — 贪心双指针 +- [Trapping Rain Water](https://neetcode.io/problems/trapping-rain-water) — 带运行最大值的双指针 + +### 滑动窗口 +- [Best Time to Buy and Sell Stock](https://neetcode.io/problems/buy-and-sell-crypto) — 退化窗口 +- [Longest Substring Without Repeating Characters](https://neetcode.io/problems/longest-substring-without-duplicates) — 扩展/收缩配合哈希表 +- [Longest Repeating Character Replacement](https://neetcode.io/problems/longest-repeating-substring-with-replacement) — 窗口 + 最大频率技巧 +- [Minimum Window Substring](https://neetcode.io/problems/minimum-window-with-characters) — 扩展到有效,收缩到最小 + +### 前缀和 +- [Product of Array Except Self](https://neetcode.io/problems/products-of-array-discluding-self) — 前缀/后缀乘积 diff --git a/chapter 14: data structures and algorithms/02. linked lists, stacks, and queues.md b/chapter 14: data structures and algorithms/02. linked lists, stacks, and queues.md new file mode 100644 index 0000000..2214022 --- /dev/null +++ b/chapter 14: data structures and algorithms/02. linked lists, stacks, and queues.md @@ -0,0 +1,410 @@ +# 链表、栈和队列 + +*链表、栈和队列是更复杂数据结构的构建模块。本文件涵盖它们的运行机制,然后构建关键模式:快慢指针、单调栈和基于堆的优先队列,通过逐步增加难度的题目,并在每一步指出常见陷阱。* + +- 数组提供了快速的随机访问但插入代价高。**链表**提供了快速插入但没有随机访问。**栈**和**队列**将访问限制在一端或两端,而正是这种限制使它们强大:通过限制你能做的事情,它们简化了你需要考虑的事情。 + +## 链表 + +- **单向链表**是一个节点链。每个节点存储一个值和一个指向下一个节点的指针。最后一个节点指向 `null`。 + +```python +class ListNode: + def __init__(self, val=0, next=None): + self.val = val + self.next = next +``` + +- **相对于数组的优势**:在已知位置插入或删除是 $O(1)$(只需重新指向指针)。无需移动元素。 + +- **劣势**:访问元素 $i$ 需要 $O(i)$ 次遍历(无随机访问)。缓存局部性差(节点分散在内存中)。 + +- **双向链表**增加了一个 `prev` 指针,支持向后遍历。用于 LRU 缓存(常数时间删除任何节点)和浏览器历史(前进/后退)。 + +| 操作 | 单向 | 双向 | +|-----------|--------|--------| +| 按索引访问 | $O(n)$ | $O(n)$ | +| 在头部插入 | $O(1)$ | $O(1)$ | +| 在尾部插入 | $O(n)$ 或 $O(1)$* | $O(1)$ | +| 删除给定节点 | $O(n)$** | $O(1)$ | +| 搜索 | $O(n)$ | $O(n)$ | + +*有尾指针时。**需要前驱节点,需要遍历。 + +- **哨兵节点**(虚拟头/尾节点)简化了边界情况。没有虚拟头节点时,在头部插入或删除头部需要特殊代码。有了虚拟节点,每个真实节点都有前驱。 + +```python +# 无虚拟节点:头部删除需要特殊处理 +def delete_head(head): + if not head: + return None + return head.next + +# 有虚拟节点:统一逻辑 +dummy = ListNode(0) +dummy.next = head +# 现在每次删除都是:prev.next = prev.next.next +``` + +- **陷阱**:忘记处理空列表(`head is None`)或单元素列表。始终测试这些边界情况。 + +--- + +## 模式:快慢指针(弗洛伊德算法) + +- 使用两个以不同速度移动的指针来检测链表的属性。**慢**指针一次移动一步;**快**指针一次移动两步。 + +### 简单:环形链表 + +- **问题**:判断一个链表是否有环。 + +- **模式**:如果有环,快指针最终会追上慢指针(它们会相遇)。如果没有环,快指针会到达 `null`。 + +```python +def has_cycle(head): + slow = fast = head + while fast and fast.next: + slow = slow.next + fast = fast.next.next + if slow == fast: + return True + return False +``` + +- **为什么有效**:如果环的长度为 $c$,快指针每步缩小1个节点的差距。它们必在慢指针进入环后的 $c$ 步内相遇。 + +- **陷阱**:检查 `fast and fast.next`(而不仅仅是 `fast.next`)。如果 `fast` 是 `None`,调用 `fast.next` 会崩溃。 + +### 中等:寻找链表的中间节点 + +- **问题**:返回中间节点。 + +- **模式**:当快指针到达末尾时,慢指针在中间。 + +```python +def find_middle(head): + slow = fast = head + while fast and fast.next: + slow = slow.next + fast = fast.next.next + return slow # slow 在中间(偶数长度时为第二个中间节点) +``` + +### 中等:环形链表 II(寻找环的起点) + +- **问题**:返回环开始的节点。 + +- **模式**:在快指针和慢指针相遇后,将一个指针重置到头部。两者以速度1移动。它们在环的起点相遇。 + +```python +def detect_cycle(head): + slow = fast = head + while fast and fast.next: + slow = slow.next + fast = fast.next.next + if slow == fast: + # 将一个指针重置到头部 + slow = head + while slow != fast: + slow = slow.next + fast = fast.next + return slow + return None +``` + +- **为什么有效**:设从头到环起点的距离为 $a$,从环起点到相遇点的距离为 $b$。慢指针走了 $a + b$。快指针走了 $2(a + b)$。差值为一整圈:$a + b = c$(环长)。所以 $a = c - b$:从头到环起点的距离等于从相遇点到环起点的距离(沿环向前走)。 + +### 困难:K个一组反转链表 + +- **问题**:反转链表中每 $k$ 个连续节点。 + +```python +def reverse_k_group(head, k): + # 检查是否还有 k 个节点 + node = head + for _ in range(k): + if not node: + return head + node = node.next + + # 反转 k 个节点 + prev, curr = None, head + for _ in range(k): + nxt = curr.next + curr.next = prev + prev = curr + curr = nxt + + # 当前 head 现在是反转后的组的尾节点 + # 递归处理剩余部分 + head.next = reverse_k_group(curr, k) + return prev # prev 是这组的新头节点 +``` + +- **陷阱**:原地反转模式(`prev, curr, nxt`)值得记住。画出来:每一步,你将 `curr.next` 指回 `prev`,然后推进所有三个指针。顺序搞错会破坏链表。 + +--- + +## 栈 + +- **栈**是 LIFO(后进先出):最近添加的元素最先被移除。想象一堆盘子。 + +- 操作:`push(x)` 添加到顶部,`pop()` 从顶部移除,`peek()` 查看顶部不移除。全部 $O(1)$。 + +- 栈是**递归**(调用栈)、**表达式求值**(中缀转后缀)和**撤销操作**(每个操作被入栈,撤销时弹出最后一个)背后的隐式结构。 + +### 简单:有效的括号 + +- **问题**:给定一个由括号 `()[]{}` 组成的字符串,判断它们是否平衡。 + +- **模式**:将左括号入栈。当看到右括号时,检查栈顶是否匹配。 + +```python +def is_valid(s): + stack = [] + matching = {')': '(', ']': '[', '}': '{'} + + for char in s: + if char in matching: + if not stack or stack[-1] != matching[char]: + return False + stack.pop() + else: + stack.append(char) + + return len(stack) == 0 +``` + +- **陷阱**:忘记最后检查 `len(stack) == 0`。字符串 "(((" 中没有不匹配的情况,但因为没有闭合的括号,它是无效的。 + +--- + +## 模式:单调栈 + +- **单调栈**维护按排序顺序排列的元素(递增或递减)。当新元素会破坏顺序时,你弹出元素直到顺序恢复。 + +- **何时使用**:问题要求"对每个元素,找到下一个/上一个更大/更小的元素。"栈的总时间复杂度为 $O(n)$,因为每个元素最多被入栈和出栈一次。 + +### 中等:每日温度 + +- **问题**:给定每日温度,对于每一天,找到需要等多少天才会升温。 + +- **模式**:使用一个索引栈。当当前温度高于栈顶时,弹出并记录距离。 + +```python +def daily_temperatures(temperatures): + n = len(temperatures) + result = [0] * n + stack = [] # 索引栈,温度递减 + + for i in range(n): + while stack and temperatures[i] > temperatures[stack[-1]]: + prev = stack.pop() + result[prev] = i - prev + stack.append(i) + + return result +``` + +- 每个元素被入栈一次,最多出栈一次:总计 $O(n)$。 + +- **陷阱**:在栈中存储索引(而非值)。你需要索引来计算距离。 + +### 困难:柱状图中最大的矩形 + +- **问题**:给定一个条形高度数组,找出最大矩形的面积。 + +- **模式**:对于每个条形,找出它可以向左右延伸多远(即,每侧最近的更短条形)。单调递增栈高效地追踪这个信息。 + +```python +def largest_rectangle(heights): + stack = [] # 索引栈,高度递增 + max_area = 0 + heights.append(0) # 哨兵,用于最后清空栈 + + for i, h in enumerate(heights): + start = i + while stack and stack[-1][1] > h: + idx, height = stack.pop() + max_area = max(max_area, height * (i - idx)) + start = idx # 当前条形可以延伸到被弹出条形开始的位置 + stack.append((start, h)) + + heights.pop() # 移除哨兵 + return max_area +``` + +- **陷阱**:`start = idx` 这行很微妙。当我们弹出一个比当前条形更高的条形时,当前条形可以向后延伸至被弹出条形开始的位置(因为中间的所有条形至少和被弹出条形一样高)。缺少这行会得到错误的面积。 + +- **陷阱**:哨兵 `heights.append(0)` 确保栈中所有剩余的条形被处理。没有它,那些右侧从未遇到更短条形的条形会被遗漏。 + +--- + +## 队列 + +- **队列**是 FIFO(先进先出):元素从后面添加,从前面移除。想象商店里排队。 + +- **双端队列**(deque)支持在两端 $O(1)$ 插入和删除。Python 的 `collections.deque` 是标准实现。 + +- 队列是 **BFS**(广度优先搜索,第14章文件04)、**任务调度**和**消息传递**背后的结构。 + +### 简单:用栈实现队列 + +- **问题**:仅使用两个栈实现一个队列。 + +- **模式**:使用一个栈进行入队操作,一个栈进行出队操作。当出队栈为空时,将所有元素从入队栈转移到出队栈(反转顺序)。 + +```python +class MyQueue: + def __init__(self): + self.push_stack = [] + self.pop_stack = [] + + def push(self, x): + self.push_stack.append(x) + + def pop(self): + if not self.pop_stack: + while self.push_stack: + self.pop_stack.append(self.push_stack.pop()) + return self.pop_stack.pop() + + def peek(self): + if not self.pop_stack: + while self.push_stack: + self.pop_stack.append(self.push_stack.pop()) + return self.pop_stack[-1] + + def empty(self): + return not self.push_stack and not self.pop_stack +``` + +- 每次操作的平摊复杂度 $O(1)$:每个元素最多在两个栈之间移动一次。 + +--- + +## 优先队列和堆 + +- **优先队列**总是返回最小(或最大)的元素,不论插入顺序。标准实现是**二叉堆**。 + +- **最小堆**是一棵完全二叉树,其中每个父节点都小于其子节点。最小值总是在根节点。以数组形式存储:节点 $i$ 的子节点在位置 $2i + 1$ 和 $2i + 2$。 + +| 操作 | 时间 | +|-----------|------| +| 插入 | $O(\log n)$ | +| 获取最小值 | $O(1)$ | +| 提取最小值 | $O(\log n)$ | +| 从数组构建堆 | $O(n)$ | + +- Python 的 `heapq` 模块提供了最小堆。对于最大堆,将值取反。 + +```python +import heapq + +# 最小堆 +h = [] +heapq.heappush(h, 5) +heapq.heappush(h, 2) +heapq.heappush(h, 8) +print(heapq.heappop(h)) # 2(最小) + +# 最大堆技巧:取反 +heapq.heappush(h, -10) +print(-heapq.heappop(h)) # 10(最大) +``` + +### 中等:数组中的第 K 个最大元素 + +- **问题**:找到第 k 个最大的元素。 + +- **模式**:维护一个大小为 $k$ 的最小堆。堆的根节点就是第 k 大的元素。如果堆有 $k$ 个元素且新元素大于根节点,则替换根节点。 + +```python +import heapq + +def find_kth_largest(nums, k): + heap = nums[:k] + heapq.heapify(heap) # O(k) + + for num in nums[k:]: + if num > heap[0]: + heapq.heapreplace(heap, num) # 弹出最小值,推入 num:O(log k) + + return heap[0] +``` + +- $O(n \log k)$ 时间,$O(k)$ 空间。当 $k \ll n$ 时,这比排序($O(n \log n)$)好得多。 + +- **陷阱**:使用大小为 $n$ 的最大堆并弹出 $k$ 次也可行但较慢:$O(n + k \log n)$。大小为 $k$ 的最小堆是最优方法。 + +### 困难:合并 K 个排序链表 + +- **问题**:合并 $k$ 个已排序链表为一个排序链表。 + +- **模式**:使用一个包含每个链表头节点的最小堆。弹出最小的节点,将其添加到结果中,并将其下一个节点推入堆中。 + +```python +import heapq + +def merge_k_lists(lists): + heap = [] + for i, lst in enumerate(lists): + if lst: + heapq.heappush(heap, (lst.val, i, lst)) + + dummy = ListNode(0) + curr = dummy + + while heap: + val, i, node = heapq.heappop(heap) + curr.next = node + curr = curr.next + if node.next: + heapq.heappush(heap, (node.next.val, i, node.next)) + + return dummy.next +``` + +- $O(n \log k)$,其中 $n$ 是总节点数。堆中最多有 $k$ 个元素。 + +- **陷阱**:堆元组中的 `i`(索引)是用于打破平局的。没有它,当值相等时 Python 会尝试比较 `ListNode` 对象,这会崩溃因为 `ListNode` 不支持 `<`。索确保了一有效的比较。 + +--- + +## 常见陷阱总结 + +| 陷阱 | 示例 | 修复 | +|---------|---------|-----| +| `fast.next` 上的空指针 | 循环检测中使用 `while fast.next` | 检查 `fast and fast.next` | +| 未处理空链表 | 反转 `None` | 添加 `if not head` 守卫 | +| 栈下溢 | 从空栈弹出 | 检查 `len(stack) > 0` 或 `if stack` | +| 忘记哨兵 | 直方图遗漏了最后的条形 | 追加 0 来清空栈 | +| 堆中缺少平局打破 | 比较不可比较的对象 | 向堆元组添加索引 | +| 遍历时修改链表 | 遍历时删除节点 | 使用 prev/curr 模式或虚拟头节点 | + +--- + +## 课后练习题(NeetCode) + +### 链表 +- [反转链表](https://neetcode.io/problems/reverse-a-linked-list) — 基础的原地反转 +- [合并两个有序链表](https://neetcode.io/problems/merge-two-sorted-linked-lists) — 双指针合并 +- [环形链表](https://neetcode.io/problems/linked-list-cycle-detection) — 快慢指针 +- [重排链表](https://neetcode.io/problems/reorder-linked-list) — 找中间 + 反转 + 合并 +- [删除链表的倒数第 N 个节点](https://neetcode.io/problems/remove-node-from-end-of-linked-list) — 间距为 $n$ 的双指针 +- [LRU 缓存](https://neetcode.io/problems/lru-cache) — 哈希表 + 双向链表 + +### 栈 +- [有效的括号](https://neetcode.io/problems/validate-parentheses) — 括号匹配 +- [最小栈](https://neetcode.io/problems/minimum-stack) — 在每层跟踪最小值 +- [逆波兰表达式求值](https://neetcode.io/problems/evaluate-reverse-polish-notation) — 基于栈的求值 +- [每日温度](https://neetcode.io/problems/daily-temperatures) — 单调递减栈 +- [柱状图中最大的矩形](https://neetcode.io/problems/largest-rectangle-in-histogram) — 单调递增栈 +- [车队](https://neetcode.io/problems/car-fleet) — 带到达时间的栈 + +### 堆 / 优先队列 +- [数据流中的第 K 大元素](https://neetcode.io/problems/kth-largest-integer-in-a-stream) — 大小为 $k$ 的最小堆 +- [最后一块石头的重量](https://neetcode.io/problems/last-stone-weight) — 最大堆模拟 +- [最接近原点的 K 个点](https://neetcode.io/problems/k-closest-points-to-origin) — 按距离排序的最小堆 +- [任务调度器](https://neetcode.io/problems/task-scheduler) — 贪心 + 最大堆 + 冷却时间 +- [数据流的中位数](https://neetcode.io/problems/find-median-in-a-data-stream) — 双堆(下半部分用最大堆,上半部分用最小堆) diff --git a/chapter 14: data structures and algorithms/03. trees.md b/chapter 14: data structures and algorithms/03. trees.md new file mode 100644 index 0000000..480a20a --- /dev/null +++ b/chapter 14: data structures and algorithms/03. trees.md @@ -0,0 +1,381 @@ +# 树 + +*树是层次化数据结构,是文件系统、数据库、编译器和无数面试题背后的基础。本文件涵盖二叉树、二叉搜索树、平衡树、前缀树、线段树、树状数组和并查集,包括遍历模式、递归思维以及逐步增加难度的题目。* + +- **树**是一个连通的无环图(第13章)。最重要的变体是**二叉树**:每个节点最多有两个子节点(左和右)。树无处不在:编译器中的解析树、浏览器中的 DOM 树、机器学习中的决策树以及数据库中的 B 树。 + +- 解决树问题的关键洞察:**大多数树问题都可以递归解决**。结构是递归的(树是一个根节点加上两棵子树),因此解法也应是递归的。掌握"解决左子树、解决右子树、合并结果"的模式,你就能解决大多数树问题。 + +## 二叉树遍历 + +- 有四种标准的访问每个节点的方式: + + - **中序遍历**(左、根、右):对于 BST,这会按排序顺序访问节点。 + - **前序遍历**(根、左、右):用于序列化和复制树。 + - **后序遍历**(左、右、根):用于删除和计算大小。 + - **层序遍历**(BFS):使用队列逐层访问节点。 + +```python +class TreeNode: + def __init__(self, val=0, left=None, right=None): + self.val = val + self.left = left + self.right = right + +def inorder(root): + if not root: + return [] + return inorder(root.left) + [root.val] + inorder(root.right) + +def preorder(root): + if not root: + return [] + return [root.val] + preorder(root.left) + preorder(root.right) + +def postorder(root): + if not root: + return [] + return postorder(root.left) + postorder(root.right) + [root.val] + +from collections import deque + +def level_order(root): + if not root: + return [] + result, queue = [], deque([root]) + while queue: + level = [] + for _ in range(len(queue)): + node = queue.popleft() + level.append(node.val) + if node.left: + queue.append(node.left) + if node.right: + queue.append(node.right) + result.append(level) + return result +``` + +- **陷阱**:上面的递归遍历在每一步都创建新列表(由于 `+` 拼接),这是 $O(n^2)$。为了效率,传递一个结果列表并原地追加: + +```python +def inorder_efficient(root, result=None): + if result is None: + result = [] + if root: + inorder_efficient(root.left, result) + result.append(root.val) + inorder_efficient(root.right, result) + return result +``` + +### 简单:二叉树的最大深度 + +```python +def max_depth(root): + if not root: + return 0 + return 1 + max(max_depth(root.left), max_depth(root.right)) +``` + +- **递归模式**:基本情况(null → 0),递归子节点,合并(1 + max)。同样的模式适用于数十种树问题。 + +### 简单:翻转二叉树 + +```python +def invert_tree(root): + if not root: + return None + root.left, root.right = invert_tree(root.right), invert_tree(root.left) + return root +``` + +### 中等:二叉树的最近公共祖先 + +- **问题**:找到既是 $p$ 又是 $q$ 的祖先的最低节点。 + +- **模式**:如果 $p$ 和 $q$ 都在左子树中,则 LCA 在左子树中。如果都在右子树中,则在右子树中。如果它们分开了(一个在左,一个在右),则当前节点就是 LCA。 + +```python +def lowest_common_ancestor(root, p, q): + if not root or root == p or root == q: + return root + + left = lowest_common_ancestor(root.left, p, q) + right = lowest_common_ancestor(root.right, p, q) + + if left and right: + return root # p 和 q 在不同子树中 + return left if left else right +``` + +- **陷阱**:这假设 $p$ 和 $q$ 都在树中。如果它们可能不在,你需要额外的检查。 + +### 困难:二叉树中的最大路径和 + +- **问题**:找出任意两个节点之间的最大路径和(路径不需要经过根节点)。 + +```python +def max_path_sum(root): + best = [float('-inf')] + + def dfs(node): + if not node: + return 0 + left = max(dfs(node.left), 0) # 忽略负路径 + right = max(dfs(node.right), 0) + + # 经过当前节点的路径(可能作为"转弯点") + best[0] = max(best[0], node.val + left + right) + + # 返回到父节点的最大增益 + return node.val + max(left, right) + + dfs(root) + return best[0] +``` + +- **关键洞察**:在每个节点,有两个问题:(1) *经过*这个节点的最佳路径是什么(左 + 节点 + 右)?(2) 这个节点可以贡献给其*父节点*的最佳路径是什么(节点 + max(左, 右),因为路径不能在两个层级分叉)?混淆这两者是最常见的错误。 + +## 二叉搜索树(BST) + +- **BST** 满足:对于每个节点,左子树中的所有值都较小,右子树中的所有值都较大。这实现了 $O(\log n)$ 的搜索、插入和删除(当平衡时)。 + +```python +def search_bst(root, target): + if not root: + return None + if target < root.val: + return search_bst(root.left, target) + elif target > root.val: + return search_bst(root.right, target) + else: + return root + +def insert_bst(root, val): + if not root: + return TreeNode(val) + if val < root.val: + root.left = insert_bst(root.left, val) + else: + root.right = insert_bst(root.right, val) + return root +``` + +- **陷阱**:BST 操作仅在树平衡时才是 $O(\log n)$。由已排序插入构建的 BST 退化为链表:每次操作 $O(n)$。这就是平衡 BST(AVL、红黑树)存在的原因。 + +### 中等:验证二叉搜索树 + +```python +def is_valid_bst(root, lo=float('-inf'), hi=float('inf')): + if not root: + return True + if root.val <= lo or root.val >= hi: + return False + return (is_valid_bst(root.left, lo, root.val) and + is_valid_bst(root.right, root.val, hi)) +``` + +- **陷阱**:只检查 `left.val < root.val < right.val` 是错误的。约束条件是左子树中*所有*节点都更小,而不仅仅是直接子节点。`lo`/`hi` 边界将这个约束向下传递。 + +### 中等:二叉搜索树中第 K 小的元素 + +- **模式**:BST 的中序遍历按排序顺序访问节点。访问的第 $k$ 个节点就是答案。 + +```python +def kth_smallest(root, k): + count = [0] + result = [None] + + def inorder(node): + if not node or result[0] is not None: + return + inorder(node.left) + count[0] += 1 + if count[0] == k: + result[0] = node.val + return + inorder(node.right) + + inorder(root) + return result[0] +``` + +## 前缀树(Trie) + +- **前缀树**逐字符地将字符串存储在树中。每条边代表一个字符,从根到标记节点的路径代表存储的字符串。前缀树实现了 $O(L)$ 的查找,其中 $L$ 是字符串长度,无论存储了多少个字符串。 + +```python +class TrieNode: + def __init__(self): + self.children = {} + self.is_end = False + +class Trie: + def __init__(self): + self.root = TrieNode() + + def insert(self, word): + node = self.root + for char in word: + if char not in node.children: + node.children[char] = TrieNode() + node = node.children[char] + node.is_end = True + + def search(self, word): + node = self.root + for char in word: + if char not in node.children: + return False + node = node.children[char] + return node.is_end + + def starts_with(self, prefix): + node = self.root + for char in prefix: + if char not in node.children: + return False + node = node.children[char] + return True +``` + +- **何时使用**:自动补全、拼写检查、单词游戏、IP 路由表。每当你需要基于前缀的操作时。 + +### 困难:单词搜索 II + +- **问题**:给定一个字符板和一个单词列表,找出所有可以通过遍历相邻单元格形成的单词。 + +- **模式**:从单词列表构建一个前缀树,然后从每个单元格使用前缀树进行 DFS,尽早剪枝分支(如果没有单词以当前前缀开头,则停止)。 + +- **陷阱**:没有前缀树的话,你需要为每个单词单独进行 DFS:$O(w \cdot m \cdot n \cdot 4^L)$。前缀树跨单词共享前缀计算,大幅减少了工作量。 + +## 并查集(不相交集合) + +- **并查集**跟踪一组不相交集合。两个操作:`find(x)` 返回 $x$ 所在集合的代表元,`union(x, y)` 合并包含 $x$ 和 $y$ 的集合。 + +```python +class UnionFind: + def __init__(self, n): + self.parent = list(range(n)) + self.rank = [0] * n + self.count = n # 连通分量数 + + def find(self, x): + if self.parent[x] != x: + self.parent[x] = self.find(self.parent[x]) # 路径压缩 + return self.parent[x] + + def union(self, x, y): + rx, ry = self.find(x), self.find(y) + if rx == ry: + return False # 已经连通 + # 按秩合并 + if self.rank[rx] < self.rank[ry]: + rx, ry = ry, rx + self.parent[ry] = rx + if self.rank[rx] == self.rank[ry]: + self.rank[rx] += 1 + self.count -= 1 + return True +``` + +- 通过路径压缩和按秩合并,两个操作都是平摊 $O(\alpha(n)) \approx O(1)$(反阿克曼函数,实际上是常数)。 + +- **何时使用**:连通分量、无向图中的环检测、Kruskal 最小生成树、分组等价项。 + +### 中等:连通分量数量 + +```python +def count_components(n, edges): + uf = UnionFind(n) + for u, v in edges: + uf.union(u, v) + return uf.count +``` + +### 中等:冗余连接 + +- **问题**:找出从图中移除后使图成为树的那条边(即,创建环的那条边)。 + +- **模式**:逐一处理边。第一条两个端点已经在同一分量中的边就是创建环的边。 + +```python +def find_redundant(edges): + uf = UnionFind(len(edges) + 1) + for u, v in edges: + if not uf.union(u, v): + return [u, v] # 已经连通 → 这条边创建了环 +``` + +## 线段树和树状数组 + +- **线段树**支持区间查询(子数组上的和、最小值、最大值)和单点更新,两者都是 $O(\log n)$。 + +- **树状数组**(二叉索引树)是前缀和查询和单点更新的更简单、更快的替代方案。它使用一种巧妙的位操作技巧:每个位置存储一个部分和,覆盖范围由最低设置位决定。 + +```python +class FenwickTree: + def __init__(self, n): + self.n = n + self.tree = [0] * (n + 1) + + def update(self, i, delta): + i += 1 # 1-indexed + while i <= self.n: + self.tree[i] += delta + i += i & (-i) # 加上最低设置位 + + def prefix_sum(self, i): + i += 1 + total = 0 + while i > 0: + total += self.tree[i] + i -= i & (-i) # 移除最低设置位 + return total + + def range_sum(self, l, r): + return self.prefix_sum(r) - (self.prefix_sum(l - 1) if l > 0 else 0) +``` + +- **何时使用**:需要带更新的重复区间查询的问题。当你只需要前缀和时首选树状数组;当你需要任意区间操作(最小值、最大值、GCD)时使用线段树。 + +--- + +## 常见陷阱总结 + +| 陷阱 | 示例 | 修复 | +|---------|---------|-----| +| BST 只检查直接子节点 | `left.val < root.val` 遗漏了深层违规 | 传递 `lo`/`hi` 边界 | +| 递归中 $O(n^2)$ 列表拼接 | `inorder(left) + [val] + inorder(right)` | 追加到共享列表 | +| 忘记基本情况 | 空树上的无限递归 | `if not root: return` | +| 混淆经过路径和到父节点的路径 | 最大路径和:在两个层级分叉 | 向父节点返回单分支,单独跟踪双分支 | +| 树状数组 1-indexed vs 0-indexed | 树数组中的差一错误 | 入口处始终 `i += 1` | +| 并查集没有路径压缩 | 最坏情况下每次 `find` 是 $O(n)$ | `self.parent[x] = self.find(self.parent[x])` | + +--- + +## 课后练习题(NeetCode) + +### 二叉树模式 +- [翻转二叉树](https://neetcode.io/problems/invert-a-binary-tree) — 基础递归 +- [二叉树的最大深度](https://neetcode.io/problems/depth-of-binary-tree) — 递归深度 +- [相同的树](https://neetcode.io/problems/same-binary-tree) — 同步遍历 +- [另一棵树的子树](https://neetcode.io/problems/subtree-of-a-binary-tree) — 嵌套递归 +- [二叉树的层序遍历](https://neetcode.io/problems/level-order-traversal-of-binary-tree) — 带层级跟踪的 BFS +- [二叉树中的最大路径和](https://neetcode.io/problems/binary-tree-maximum-path-sum) — 带全局最优的 DFS +- [序列化与反序列化二叉树](https://neetcode.io/problems/serialize-and-deserialize-binary-tree) — 前序遍历 + null 标记 + +### BST 模式 +- [验证二叉搜索树](https://neetcode.io/problems/valid-binary-search-tree) — 边界传播 +- [二叉搜索树中第 K 小的元素](https://neetcode.io/problems/kth-smallest-integer-in-bst) — 中序遍历 +- [二叉搜索树的最近公共祖先](https://neetcode.io/problems/lowest-common-ancestor-in-binary-search-tree) — 利用 BST 排序性质 + +### 前缀树 +- [实现 Trie](https://neetcode.io/problems/implement-prefix-tree) — 基础前缀树操作 +- [设计添加和搜索单词](https://neetcode.io/problems/design-word-search-data-structure) — 前缀树 + 带通配符的 DFS +- [单词搜索 II](https://neetcode.io/problems/search-for-word-ii) — 前缀树引导的回溯 + +### 并查集 +- [连通分量数量](https://neetcode.io/problems/count-connected-components) — 基础并查集 +- [冗余连接](https://neetcode.io/problems/redundant-connection) — 通过并查集检测环 diff --git a/chapter 14: data structures and algorithms/04. graphs.md b/chapter 14: data structures and algorithms/04. graphs.md new file mode 100644 index 0000000..a8934cc --- /dev/null +++ b/chapter 14: data structures and algorithms/04. graphs.md @@ -0,0 +1,323 @@ +# 图 + +*图建模了关系和连接——从社交网络到道路地图再到依赖链。本文件涵盖图的表示、BFS、DFS、最短路径、拓扑排序和连通分量,包括遍历和寻路模式,这些是图面试题中的核心。* + +- 我们在第12章(邻接矩阵、拉普拉斯矩阵、谱性质)和第13章(树、平面性、着色)中已经介绍了图论。这里我们专注于**算法模式**:如何在代码中遍历、搜索和优化图。 + +- 两种基本的图算法是 **BFS** 和 **DFS**。几乎所有图问题都可以归结为其中一种,可能带有修改。掌握这两种算法,你就能解决绝大多数图问题。 + +## 图的表示 + +- **邻接表**:对于每个节点,存储一个邻居列表。空间:$O(|V| + |E|)$。最适合稀疏图(大多数现实世界的图)。 + +```python +# 无向图 +graph = { + 0: [1, 2], + 1: [0, 3], + 2: [0, 3], + 3: [1, 2] +} + +# 从边列表构建 +def build_graph(n, edges): + graph = {i: [] for i in range(n)} + for u, v in edges: + graph[u].append(v) + graph[v].append(u) # 有向图省略这一行 + return graph +``` + +- **邻接矩阵**:$n \times n$ 矩阵,其中 $A[i][j] = 1$ 如果边 $(i, j)$ 存在。空间:$O(|V|^2)$。最适合稠密图或需要 $O(1)$ 边查找时。 + +- **何时使用哪种**:绝大多数情况使用邻接表。只有当图很稠密($|E| \approx |V|^2$)或需要常数时间边存在性检查时才使用矩阵。 + +## 模式:BFS(广度优先搜索) + +- BFS 使用队列**逐层**探索节点。它是以下问题的首选算法: + - **无权**图中的最短路径 + - 层序遍历 + - 寻找连通分量 + - 任何询问"最小步数"的问题 + +```python +from collections import deque + +def bfs(graph, start): + visited = {start} + queue = deque([start]) + + while queue: + node = queue.popleft() + for neighbour in graph[node]: + if neighbour not in visited: + visited.add(neighbour) + queue.append(neighbour) +``` + +- **关键**:在**入队时**添加到 `visited`,而不是在出队时。如果你在出队时标记已访问,同一个节点可能被不同前驱多次入队,浪费时间并可能导致错误结果。 + +### 简单:岛屿数量 + +- **问题**:给定一个由 '1'(陆地)和 '0'(水)组成的 2D 网格,计算岛屿的数量。 + +- **模式**:遍历网格。当找到一个 '1' 时,使用 BFS/DFS 将所有连通的陆地单元格标记为已访问。每次开始 BFS 就是一个岛屿。 + +```python +from collections import deque + +def num_islands(grid): + if not grid: + return 0 + + rows, cols = len(grid), len(grid[0]) + count = 0 + + for r in range(rows): + for c in range(cols): + if grid[r][c] == '1': + count += 1 + # BFS 标记整个岛屿 + queue = deque([(r, c)]) + grid[r][c] = '0' # 标记已访问 + while queue: + cr, cc = queue.popleft() + for dr, dc in [(0,1),(0,-1),(1,0),(-1,0)]: + nr, nc = cr + dr, cc + dc + if 0 <= nr < rows and 0 <= nc < cols and grid[nr][nc] == '1': + grid[nr][nc] = '0' + queue.append((nr, nc)) + + return count +``` + +- **陷阱**:`directions = [(0,1),(0,-1),(1,0),(-1,0)]` 模式用于四连通网格邻居,几乎每个网格问题都会用到。记住它。对于八连通,加上对角线。 + +- **陷阱**:修改输入网格(`grid[r][c] = '0'`)避免了需要单独的 `visited` 集合。在面试中这是可以接受的,但要明确说明权衡(改变了输入)。 + +### 中等:腐烂的橘子 + +- **问题**:新鲜橘子如果与腐烂橘子相邻则会腐烂。返回所有橘子都腐烂的最短时间(如果不可能则返回 -1)。 + +- **模式**:多源 BFS。将所有初始腐烂的橘子同时放入队列。每层 BFS 就是一个时间步。 + +```python +from collections import deque + +def oranges_rotting(grid): + rows, cols = len(grid), len(grid[0]) + queue = deque() + fresh = 0 + + for r in range(rows): + for c in range(cols): + if grid[r][c] == 2: + queue.append((r, c)) + elif grid[r][c] == 1: + fresh += 1 + + if fresh == 0: + return 0 + + time = 0 + while queue and fresh > 0: + time += 1 + for _ in range(len(queue)): + cr, cc = queue.popleft() + for dr, dc in [(0,1),(0,-1),(1,0),(-1,0)]: + nr, nc = cr + dr, cc + dc + if 0 <= nr < rows and 0 <= nc < cols and grid[nr][nc] == 1: + grid[nr][nc] = 2 + fresh -= 1 + queue.append((nr, nc)) + + return time if fresh == 0 else -1 +``` + +- **关键洞察**:多源 BFS 同时处理所有源。这给出了从*任何*源的最短距离,这正是"最后一个新鲜橘子腐烂需要多长时间"。 + +## 模式:DFS(深度优先搜索) + +- DFS 尽可能深地探索,然后回溯。它使用栈(显式栈或通过递归使用调用栈)。DFS 是以下问题的首选: + - 环检测 + - 拓扑排序 + - 连通分量 + - 回溯 / 穷举搜索 + - 带约束的寻路 + +```python +def dfs(graph, node, visited=None): + if visited is None: + visited = set() + visited.add(node) + for neighbour in graph[node]: + if neighbour not in visited: + dfs(graph, neighbour, visited) +``` + +### 中等:课程表(环检测) + +- **问题**:给定 $n$ 门课程和先修条件,判断是否能完成所有课程(即,没有循环依赖)。 + +- **模式**:在有向图中检测环。使用带有三种状态的 DFS:未访问、正在进行(在当前 DFS 路径上)、已完成。 + +```python +def can_finish(num_courses, prerequisites): + graph = {i: [] for i in range(num_courses)} + for course, prereq in prerequisites: + graph[course].append(prereq) + + # 0 = 未访问, 1 = 进行中, 2 = 已完成 + state = [0] * num_courses + + def has_cycle(node): + if state[node] == 1: + return True # 回边 → 环 + if state[node] == 2: + return False # 已经完全探索过 + + state[node] = 1 # 标记为进行中 + for neighbour in graph[node]: + if has_cycle(neighbour): + return True + state[node] = 2 # 标记为已完成 + return False + + for course in range(num_courses): + if has_cycle(course): + return False + return True +``` + +- **为什么需要三种状态**:两种状态(已访问/未访问)无法区分"我正在探索这个节点"和"我已完成对这个节点的探索"。找到一个当前正在被探索的节点(状态 = 1)意味着我们发现了环。找到一个已经完全探索的节点(状态 = 2)只是交叉边,不是环。 + +### 中等:课程表 II(拓扑排序) + +- **问题**:返回一个有效的课程顺序(拓扑排序)。 + +- **模式(Kahn 算法——基于 BFS)**:从没有入边的节点(入度为 0)开始。处理它们,减少它们邻居的入度。重复。 + +```python +from collections import deque + +def find_order(num_courses, prerequisites): + graph = {i: [] for i in range(num_courses)} + indegree = [0] * num_courses + + for course, prereq in prerequisites: + graph[prereq].append(course) + indegree[course] += 1 + + queue = deque([i for i in range(num_courses) if indegree[i] == 0]) + order = [] + + while queue: + node = queue.popleft() + order.append(node) + for neighbour in graph[node]: + indegree[neighbour] -= 1 + if indegree[neighbour] == 0: + queue.append(neighbour) + + return order if len(order) == num_courses else [] # 空 = 存在环 +``` + +- **陷阱**:如果结果中的节点数少于图中的节点数,则存在环(某些节点的入度从未降到 0)。 + +## 最短路径 + +### Dijkstra 算法 + +- 在**非负**加权图中从源点找到到所有其他节点的最短路径。使用优先队列(最小堆)。 + +```python +import heapq + +def dijkstra(graph, start): + # graph: {node: [(neighbour, weight), ...]} + dist = {node: float('inf') for node in graph} + dist[start] = 0 + heap = [(0, start)] + + while heap: + d, node = heapq.heappop(heap) + if d > dist[node]: + continue # 过期条目 + + for neighbour, weight in graph[node]: + new_dist = d + weight + if new_dist < dist[neighbour]: + dist[neighbour] = new_dist + heapq.heappush(heap, (new_dist, neighbour)) + + return dist +``` + +- 时间:使用二叉堆为 $O((|V| + |E|) \log |V|)$。 + +- **陷阱**:`if d > dist[node]: continue` 这行是必须的。没有它,你会处理过期的堆条目,可能退化到 $O(|V|^2)$。 + +- **陷阱**:Dijkstra 不适用于负权重。如果一条边有负权重,贪心假设(一旦节点被确定,其距离就是最优的)就不成立了。应改用 Bellman-Ford。 + +### 困难:网络延迟时间 + +- **问题**:给定 $n$ 个节点和加权有向边,找出信号从源点到达所有节点所需的时间。如果并非所有节点都可到达,返回 -1。 + +```python +def network_delay(times, n, k): + graph = {i: [] for i in range(1, n + 1)} + for u, v, w in times: + graph[u].append((v, w)) + + dist = dijkstra(graph, k) + max_time = max(dist.values()) + return max_time if max_time < float('inf') else -1 +``` + +## 强连通分量 + +- 在有向图中,**强连通分量(SCC)**是一个最大节点集合,其中每个节点都能到达其他所有节点。 + +- **Kosaraju 算法**:(1) 在原始图上进行 DFS,记录完成顺序。(2) 转置图(反转所有边)。(3) 按完成顺序的逆序在转置图上进行 DFS。第3步中的每个 DFS 树就是一个 SCC。 + +- **何时使用**:寻找循环依赖、2-SAT、将有向图压缩为 SCC 的 DAG。 + +--- + +## 常见陷阱总结 + +| 陷阱 | 示例 | 修复 | +|---------|---------|-----| +| 在出队时标记已访问 | 同一节点被多次入队 | 在入队时标记已访问 | +| 有向图中只有两种状态 | 无法区分回边和交叉边 | 使用三种状态:未访问/进行中/已完成 | +| Dijkstra 用于负权重 | 错误的最短路径 | 使用 Bellman-Ford | +| 忘记 `if d > dist[node]: continue` | 处理过期堆条目 | 总是跳过当前距离更差的情况 | +| 网格边界检查 | 索引越界 | `0 <= nr < rows and 0 <= nc < cols` | +| 没有考虑 time=0 的边界情况 | 腐烂橘子:没有新鲜橘子 | 在 BFS 之前检查 `fresh == 0` | +| 将有向图构建为无向图 | 先修条件是单向的 | 只在一个方向添加边 | + +--- + +## 课后练习题(NeetCode) + +### BFS 模式 +- [岛屿数量](https://neetcode.io/problems/count-number-of-islands) — 网格 BFS/DFS +- [腐烂的橘子](https://neetcode.io/problems/rotting-fruit) — 多源 BFS +- [克隆图](https://neetcode.io/problems/clone-graph) — BFS + 哈希表克隆 +- [太平洋大西洋水流](https://neetcode.io/problems/pacific-atlantic-water-flow) — 从两个海洋开始的 BFS +- [单词接龙](https://neetcode.io/problems/word-ladder) — 隐式图上的 BFS + +### DFS 模式 +- [岛屿的最大面积](https://neetcode.io/problems/max-area-of-island) — 带面积计数的 DFS +- [课程表](https://neetcode.io/problems/course-schedule) — 有向图中的环检测 +- [课程表 II](https://neetcode.io/problems/course-schedule-ii) — 拓扑排序 +- [连通分量数量](https://neetcode.io/problems/count-connected-components) — DFS 或并查集 +- [图是否是树](https://neetcode.io/problems/valid-tree) — 连通 + 无环 + +### 最短路径 +- [网络延迟时间](https://neetcode.io/problems/network-delay-time) — Dijkstra +- [K 站中转内最便宜的航班](https://neetcode.io/problems/cheapest-flight-path) — 带约束的修改版 BFS/Bellman-Ford +- [上升水温游泳](https://neetcode.io/problems/swim-in-rising-water) — 二分查找 + BFS 或网格上的 Dijkstra + +### 进阶 +- [外星文字典](https://neetcode.io/problems/foreign-dictionary) — 从字符顺序进行拓扑排序 diff --git a/chapter 14: data structures and algorithms/05. sorting and search.md b/chapter 14: data structures and algorithms/05. sorting and search.md new file mode 100644 index 0000000..6663107 --- /dev/null +++ b/chapter 14: data structures and algorithms/05. sorting and search.md @@ -0,0 +1,553 @@ +# 排序、搜索与算法设计 + +*排序和搜索是最基础的算法操作。本文件涵盖排序算法、二分查找模式、分治法、贪心算法、动态规划和回溯。* + +- 每个数据结构都支持算法,每个算法都依赖数据结构。本文件涵盖了**设计范式**:解决问题的高级策略。一旦你识别出适用哪种范式,实现就自然而然地跟进了。 + +## 排序算法 + +- 排序是计算机科学中研究最多的问题。理解这些算法可以建立对递归、分治法和复杂度分析的直觉。 + +| 算法 | 最好 | 平均 | 最坏 | 空间 | 稳定? | +|-----------|------|---------|-------|-------|---------| +| 冒泡排序 | $O(n)$ | $O(n^2)$ | $O(n^2)$ | $O(1)$ | 是 | +| 插入排序 | $O(n)$ | $O(n^2)$ | $O(n^2)$ | $O(1)$ | 是 | +| 归并排序 | $O(n \log n)$ | $O(n \log n)$ | $O(n \log n)$ | $O(n)$ | 是 | +| 快速排序 | $O(n \log n)$ | $O(n \log n)$ | $O(n^2)$ | $O(\log n)$ | 否 | +| 堆排序 | $O(n \log n)$ | $O(n \log n)$ | $O(n \log n)$ | $O(1)$ | 否 | +| 计数排序 | $O(n + k)$ | $O(n + k)$ | $O(n + k)$ | $O(k)$ | 是 | +| 基数排序 | $O(d(n + k))$ | $O(d(n + k))$ | $O(d(n + k))$ | $O(n + k)$ | 是 | + +- **稳定**意味着相等元素保持其相对顺序。这在按多个键排序时很重要。 + +- 基于比较的排序的**下限**是 $\Omega(n \log n)$。证明使用决策树(第13章):任何比较排序必须区分所有 $n!$ 种排列,至少需要 $\log_2(n!) = \Omega(n \log n)$ 次比较。计数排序和基数排序通过不比较元素而超越了这个下限。 + +### 归并排序 + +- 将数组分成两半,递归排序每一半,然后合并已排序的两半。始终为 $O(n \log n)$,$O(n)$ 额外空间。 + +```python +def merge_sort(arr): + if len(arr) <= 1: + return arr + + mid = len(arr) // 2 + left = merge_sort(arr[:mid]) + right = merge_sort(arr[mid:]) + + return merge(left, right) + +def merge(left, right): + result = [] + i = j = 0 + while i < len(left) and j < len(right): + if left[i] <= right[j]: # <= 保证稳定性 + result.append(left[i]) + i += 1 + else: + result.append(right[j]) + j += 1 + result.extend(left[i:]) + result.extend(right[j:]) + return result +``` + +- **陷阱**:在合并中使用 `<` 而不是 `<=` 会破坏稳定性(右半部分的相等元素会排在左半部分之前)。 + +### 快速排序 + +- 选择一个**基准**,将元素分为"小于基准"和"大于基准"两组,递归排序每组。平均 $O(n \log n)$,最坏 $O(n^2)$(当基准总是最小或最大元素时)。 + +```python +def quicksort(arr, lo=0, hi=None): + if hi is None: + hi = len(arr) - 1 + if lo >= hi: + return + + pivot_idx = partition(arr, lo, hi) + quicksort(arr, lo, pivot_idx - 1) + quicksort(arr, pivot_idx + 1, hi) + +def partition(arr, lo, hi): + pivot = arr[hi] # Lomuto 分区:基准是最后一个元素 + i = lo + for j in range(lo, hi): + if arr[j] < pivot: + arr[i], arr[j] = arr[j], arr[i] + i += 1 + arr[i], arr[hi] = arr[hi], arr[i] + return i +``` + +- **基准策略**:最后一个元素(简单,对已排序输入不好)、随机(期望 $O(n \log n)$)、三数取中(实际选择)。在面试中始终优先选择随机基准以避免最坏情况的讨论。 + +- **陷阱**:快速排序的 $O(n^2)$ 最坏情况发生在已排序数组配合首/尾基准时。实践中,随机基准或三数取中消除了这个问题。 + +### 计数排序 + +- 当值在已知范围 $[0, k)$ 内的整数时,统计出现次数并重构:$O(n + k)$ 时间。不是基于比较的,因此可以超越 $O(n \log n)$。 + +```python +def counting_sort(arr, k): + count = [0] * k + for x in arr: + count[x] += 1 + result = [] + for val in range(k): + result.extend([val] * count[val]) + return result +``` + +- **何时使用**:范围 $k$ 不比 $n$ 大很多。如果 $k = O(n)$,这是 $O(n)$。如果 $k \gg n$(例如,对范围 $[0, 10^9]$ 中的 10 个数字排序),计数排序会浪费内存。 + +--- + +## 模式:二分查找 + +- 二分查找通过在已排序数组中反复减半搜索空间来以 $O(\log n)$ 的时间找到目标。但二分查找远不止"在已排序数组中找一个数"。通用模式是:**在单调条件上进行搜索**。 + +- **模板**(避免差一错误的那一个): + +```python +def binary_search(arr, target): + lo, hi = 0, len(arr) - 1 + + while lo <= hi: + mid = lo + (hi - lo) // 2 # 在其他语言中避免溢出 + if arr[mid] == target: + return mid + elif arr[mid] < target: + lo = mid + 1 + else: + hi = mid - 1 + + return -1 # 未找到 +``` + +- **下界**(第一个 $\geq$ target 的元素): + +```python +def lower_bound(arr, target): + lo, hi = 0, len(arr) + while lo < hi: + mid = (lo + hi) // 2 + if arr[mid] < target: + lo = mid + 1 + else: + hi = mid + return lo +``` + +- **陷阱**:`lo <= hi` 和 `lo < hi` 的区别,以及 `hi = mid` 和 `hi = mid - 1` 的区别,决定了你是找到精确匹配还是边界。用一个2元素数组画出来验证。 + +### 简单:二分查找 + +- 标准问题。使用上面的模板。 + +### 中等:搜索旋转排序数组 + +- **问题**:一个排序数组在某个枢轴处被旋转。找到目标值。 + +- **模式**:在每一步,有一半总是有序的。确定哪一半是有序的,并检查目标是否在这一半中。 + +```python +def search_rotated(nums, target): + lo, hi = 0, len(nums) - 1 + + while lo <= hi: + mid = (lo + hi) // 2 + if nums[mid] == target: + return mid + + # 左半部分有序 + if nums[lo] <= nums[mid]: + if nums[lo] <= target < nums[mid]: + hi = mid - 1 + else: + lo = mid + 1 + # 右半部分有序 + else: + if nums[mid] < target <= nums[hi]: + lo = mid + 1 + else: + hi = mid - 1 + + return -1 +``` + +- **陷阱**:`nums[lo] <= nums[mid]` 中的 `<=`(而不是 `<`)至关重要。当 `lo == mid`(只剩2个元素)时,我们必须正确识别有序的一半。 + +### 困难:寻找两个有序数组的中位数 + +- **问题**:在 $O(\log(m + n))$ 时间内找到两个有序数组的中位数。 + +- **模式**:对较小数组的分割点进行二分查找。分割将两个数组分为两部分,使得左侧所有元素都小于右侧所有元素。 + +```python +def find_median(nums1, nums2): + if len(nums1) > len(nums2): + nums1, nums2 = nums2, nums1 # 确保 nums1 较短 + + m, n = len(nums1), len(nums2) + lo, hi = 0, m + half = (m + n + 1) // 2 + + while lo <= hi: + i = (lo + hi) // 2 # nums1 中的分割点 + j = half - i # nums2 中的分割点 + + left1 = nums1[i - 1] if i > 0 else float('-inf') + right1 = nums1[i] if i < m else float('inf') + left2 = nums2[j - 1] if j > 0 else float('-inf') + right2 = nums2[j] if j < n else float('inf') + + if left1 <= right2 and left2 <= right1: + # 正确分割 + if (m + n) % 2 == 1: + return max(left1, left2) + return (max(left1, left2) + min(right1, right2)) / 2 + elif left1 > right2: + hi = i - 1 + else: + lo = i + 1 +``` + +- 这是最难的二分查找问题之一。关键在于你搜索的不是一个值,而是一个**满足条件的分割点**。 + +### 元模式:对答案进行二分查找 + +- 许多看起来不像二分查找的问题可以通过对答案进行二分查找来解决。如果答案是一个数字,并且你可以写一个单调的函数 `is_feasible(x)`(对所有 $x \geq$ 最优值为 True,或对所有 $x \geq$ 最优值为 False),那么就在 $x$ 上进行二分查找。 + +- **示例**:"在 $d$ 天内运送所有包裹所需的最小运力是多少?"对运力进行二分查找。对于每个候选运力,贪心地检查是否可以在 $d$ 天内运送所有包裹。 + +```python +def ship_within_days(weights, days): + lo, hi = max(weights), sum(weights) + + while lo < hi: + mid = (lo + hi) // 2 + # 能否以运力 mid 在 <= days 天内运送完? + current_load, num_days = 0, 1 + for w in weights: + if current_load + w > mid: + num_days += 1 + current_load = 0 + current_load += w + + if num_days <= days: + hi = mid + else: + lo = mid + 1 + + return lo +``` + +--- + +## 模式:贪心算法 + +- **贪心**算法在每一步做出局部最优选择,希望这能导致全局最优解。贪心在问题具有**贪心选择性质**(局部最优导致全局最优)和**最优子结构**(最优解包含子问题的最优解)时有效。 + +### 中等:跳跃游戏 + +- **问题**:给定一个数组,其中 `nums[i]` 是在位置 $i$ 的最大跳跃长度,判断是否能够到达最后一个索引。 + +```python +def can_jump(nums): + max_reach = 0 + for i, jump in enumerate(nums): + if i > max_reach: + return False # 无法到达这个位置 + max_reach = max(max_reach, i + jump) + return True +``` + +- **为什么贪心有效**:我们只需要知道最远可达位置。如果当前位置超过了最远可达位置,我们就卡住了。否则,更新最远可达位置。 + +### 中等:合并区间 + +- **问题**:合并重叠的区间。 + +```python +def merge_intervals(intervals): + intervals.sort(key=lambda x: x[0]) + merged = [intervals[0]] + + for start, end in intervals[1:]: + if start <= merged[-1][1]: + merged[-1][1] = max(merged[-1][1], end) + else: + merged.append([start, end]) + + return merged +``` + +- **模式**:按开始时间排序,然后贪心地合并。如果当前区间与上一个合并的区间重叠,则扩展它。否则,开始一个新的合并区间。 + +- **陷阱**:使用 `merged[-1][1] = end` 而不是 `merged[-1][1] = max(merged[-1][1], end)`。一个区间可能完全包含在另一个区间内(例如 [1, 10] 和 [2, 5])。 + +--- + +## 模式:动态规划 + +- **动态规划(DP)**通过将问题分解为重叠的子问题,每个子问题只解一次并存储结果。它适用于具有**最优子结构**和**重叠子问题**的问题。 + +- **两种方法**: + - **自顶向下(记忆化)**:写出自然的递归解法,然后在字典中缓存结果。 + - **自底向上(制表法)**:从最小的子问题开始向上构建表格。 + +- **如何识别 DP**:问题要求最优值(最小/最大)、计数或存在性,并且当前决策依赖于先前的决策。如果你画出递归树并看到重复的子问题,那就是 DP。 + +### 简单:爬楼梯 + +- **问题**:$n$ 个台阶,每次可以爬 1 或 2 个台阶。有多少种不同的方法? + +- 这就是斐波那契数列:$f(n) = f(n-1) + f(n-2)$。 + +```python +def climb_stairs(n): + if n <= 2: + return n + a, b = 1, 2 + for _ in range(3, n + 1): + a, b = b, a + b + return b +``` + +- $O(n)$ 时间,$O(1)$ 空间。不需要完整的记忆化表,因为每个状态只依赖于前两个。 + +### 中等:零钱兑换 + +- **问题**:给定硬币面额和一个目标金额,找到所需的最少硬币数量。 + +- **状态**:`dp[amount]` = 凑成 `amount` 所需的最小硬币数。 +- **转移**:`dp[amount] = min(dp[amount - coin] + 1)` 对每个硬币。 +- **基本情况**:`dp[0] = 0`。 + +```python +def coin_change(coins, amount): + dp = [float('inf')] * (amount + 1) + dp[0] = 0 + + for a in range(1, amount + 1): + for coin in coins: + if coin <= a and dp[a - coin] + 1 < dp[a]: + dp[a] = dp[a - coin] + 1 + + return dp[amount] if dp[amount] != float('inf') else -1 +``` + +- **陷阱**:用 `float('inf')` 初始化(而不是 0 或 -1)。最小比较只有在不可达状态为无穷大时才有效。 + +### 中等:最长公共子序列 + +- **问题**:给定两个字符串,找出它们的最长公共子序列的长度。 + +- **状态**:`dp[i][j]` = `text1[:i]` 和 `text2[:j]` 的 LCS。 +- **转移**:如果 `text1[i-1] == text2[j-1]`,则 `dp[i][j] = dp[i-1][j-1] + 1`。否则,`dp[i][j] = max(dp[i-1][j], dp[i][j-1])`。 + +```python +def longest_common_subsequence(text1, text2): + m, n = len(text1), len(text2) + dp = [[0] * (n + 1) for _ in range(m + 1)] + + for i in range(1, m + 1): + for j in range(1, n + 1): + if text1[i - 1] == text2[j - 1]: + dp[i][j] = dp[i - 1][j - 1] + 1 + else: + dp[i][j] = max(dp[i - 1][j], dp[i][j - 1]) + + return dp[m][n] +``` + +### 困难:0/1 背包 + +- **问题**:给定具有重量和价值的物品,以及容量 $W$,在不超出 $W$ 的情况下最大化总价值。 + +- **状态**:`dp[i][w]` = 使用前 $i$ 个物品在容量 $w$ 下的最大价值。 +- **转移**:`dp[i][w] = max(dp[i-1][w], dp[i-1][w - weight[i]] + value[i])`(跳过或取用物品 $i$)。 + +```python +def knapsack(weights, values, capacity): + n = len(weights) + dp = [[0] * (capacity + 1) for _ in range(n + 1)] + + for i in range(1, n + 1): + for w in range(capacity + 1): + dp[i][w] = dp[i - 1][w] # 跳过物品 i + if weights[i - 1] <= w: + dp[i][w] = max(dp[i][w], + dp[i - 1][w - weights[i - 1]] + values[i - 1]) + + return dp[n][capacity] +``` + +- **空间优化**:由于每一行只依赖于前一行,使用一维数组并从右向左迭代 $w$: + +```python +def knapsack_optimised(weights, values, capacity): + dp = [0] * (capacity + 1) + for i in range(len(weights)): + for w in range(capacity, weights[i] - 1, -1): # 从右向左! + dp[w] = max(dp[w], dp[w - weights[i]] + values[i]) + return dp[capacity] +``` + +- **陷阱**:在一维版本中从左向右迭代会允许多次使用物品 $i$(无限背包)。从右向左确保每个物品最多使用一次。 + +--- + +## 模式:回溯 + +- **回溯**是带剪枝的穷举搜索。逐步构建解,一旦部分解不可能导致完整的有效解,就立即放弃(回溯)。 + +- **模板**: + +```python +def backtrack(candidates, path, result): + if is_solution(path): + result.append(path[:]) # 复制! + return + + for candidate in get_candidates(path): + if is_valid(candidate, path): + path.append(candidate) # 选择 + backtrack(candidates, path, result) # 探索 + path.pop() # 撤销(回溯) +``` + +### 中等:子集 + +```python +def subsets(nums): + result = [] + def backtrack(start, path): + result.append(path[:]) + for i in range(start, len(nums)): + path.append(nums[i]) + backtrack(i + 1, path) + path.pop() + backtrack(0, []) + return result +``` + +### 中等:组合总和 + +- **问题**:找出所有和为目标值的唯一组合(元素可重复使用)。 + +```python +def combination_sum(candidates, target): + result = [] + def backtrack(start, path, remaining): + if remaining == 0: + result.append(path[:]) + return + for i in range(start, len(candidates)): + if candidates[i] > remaining: + break # 剪枝:已排序,后续候选都太大 + path.append(candidates[i]) + backtrack(i, path, remaining - candidates[i]) # i 而不是 i+1:允许重复使用 + path.pop() + + candidates.sort() # 排序以便剪枝 + backtrack(0, [], target) + return result +``` + +- **陷阱**:`backtrack(i, ...)` 允许重复使用同一元素。`backtrack(i + 1, ...)` 会移动到下一个元素(不可重复使用)。搞错这个是最常见的回溯 bug。 + +### 困难:N 皇后 + +- **问题**:在 $n \times n$ 的棋盘上放置 $n$ 个皇后,使得它们互不攻击。 + +```python +def solve_n_queens(n): + result = [] + cols = set() + pos_diag = set() # (row + col) 在 / 对角线上为常数 + neg_diag = set() # (row - col) 在 \ 对角线上为常数 + + board = [['.' ] * n for _ in range(n)] + + def backtrack(row): + if row == n: + result.append([''.join(r) for r in board]) + return + + for col in range(n): + if col in cols or (row + col) in pos_diag or (row - col) in neg_diag: + continue + + cols.add(col) + pos_diag.add(row + col) + neg_diag.add(row - col) + board[row][col] = 'Q' + + backtrack(row + 1) + + cols.remove(col) + pos_diag.remove(row + col) + neg_diag.remove(row - col) + board[row][col] = '.' + + backtrack(0) + return result +``` + +- **关键洞察**:对角线编码。对于 `/` 对角线,`row + col` 是常数。对于 `\` 对角线,`row - col` 是常数。使用集合跟踪列和对角线使得有效性检查变为 $O(1)$。 + +--- + +## 常见陷阱总结 + +| 陷阱 | 示例 | 修复 | +|---------|---------|-----| +| 二分查找中 `lo <= hi` vs `lo < hi` | 边界差一错误 | 根据 `hi` 是包含还是排除来选择 | +| 从左到右的一维背包 | 物品被多次使用 | 0/1 背包从右向左迭代 | +| 回溯中未复制路径 | `result.append(path)` — 所有条目指向同一列表 | `result.append(path[:])` 或 `path.copy()` | +| `backtrack(i)` vs `backtrack(i+1)` | 重复使用 vs 不重复使用元素 | 匹配问题要求 | +| 排序后的回溯中缺少 `break` | 探索过大的候选 | 排序 + 候选超过剩余时 break | +| DP 初始化 | `dp[0]` 错误 → 所有后续值都错 | 仔细定义基本情况 | +| 未经证明的贪心 | 贪心并不总是有效 | 验证贪心选择性质 | +| 多键排序时不稳定 | 相等元素的相对顺序丢失 | 使用稳定排序(归并排序、Python 的 `sorted`) | + +--- + +## 课后练习题(NeetCode) + +### 二分查找 +- [二分查找](https://neetcode.io/problems/binary-search) — 标准模板 +- [搜索二维矩阵](https://neetcode.io/problems/search-2d-matrix) — 在展平矩阵上二分查找 +- [Koko 吃香蕉](https://neetcode.io/problems/eating-bananas) — 对答案二分查找 +- [搜索旋转排序数组](https://neetcode.io/problems/find-target-in-rotated-sorted-array) — 识别有序的一半 +- [寻找旋转排序数组中的最小值](https://neetcode.io/problems/find-minimum-in-rotated-sorted-array) — 搜索拐点 +- [寻找两个有序数组的中位数](https://neetcode.io/problems/median-of-two-sorted-arrays) — 基于分割的二分查找 + +### 贪心 +- [跳跃游戏](https://neetcode.io/problems/jump-game) — 跟踪最远距离 +- [跳跃游戏 II](https://neetcode.io/problems/jump-game-ii) — BFS 风格的层级跟踪 +- [合并区间](https://neetcode.io/problems/merge-intervals) — 排序 + 合并 +- [插入区间](https://neetcode.io/problems/insert-new-interval) — 寻找重叠区域 +- [无重叠区间](https://neetcode.io/problems/non-overlapping-intervals) — 按结束时间排序 + +### 动态规划 +- [爬楼梯](https://neetcode.io/problems/climbing-stairs) — 斐波那契 DP +- [打家劫舍](https://neetcode.io/problems/house-robber) — 取或不取 DP +- [打家劫舍 II](https://neetcode.io/problems/house-robber-ii) — 环形:运行两次 +- [零钱兑换](https://neetcode.io/problems/coin-change) — 无限背包 +- [最长公共子序列](https://neetcode.io/problems/longest-common-subsequence) — 两个字符串上的 2D DP +- [单词拆分](https://neetcode.io/problems/word-break) — 带集合查找的 DP +- [最长递增子序列](https://neetcode.io/problems/longest-increasing-subsequence) — $O(n^2)$ DP 或带二分查找的 $O(n \log n)$ +- [编辑距离](https://neetcode.io/problems/edit-distance) — 经典 2D DP +- [分割等和子集](https://neetcode.io/problems/partition-equal-subset-sum) — 0/1 背包变体 + +### 回溯 +- [子集](https://neetcode.io/problems/subsets) — 枚举所有子集 +- [组合总和](https://neetcode.io/problems/combination-target-sum) — 允许重复使用的回溯 +- [全排列](https://neetcode.io/problems/permutations) — 带使用集合的回溯 +- [子集 II](https://neetcode.io/problems/subsets-ii) — 跳过重复项 +- [单词搜索](https://neetcode.io/problems/search-for-word) — 网格回溯 +- [分割回文串](https://neetcode.io/problems/palindrome-partitioning) — 回溯 + 回文检查 +- [N 皇后](https://neetcode.io/problems/n-queens) — 约束传播 diff --git a/chapter 15: production software engineering/01. linux and CMD.md b/chapter 15: production software engineering/01. linux and CMD.md new file mode 100644 index 0000000..92b1d93 --- /dev/null +++ b/chapter 15: production software engineering/01. linux and CMD.md @@ -0,0 +1,318 @@ +# Linux 与命令行 + +*命令行是机器学习工程的主要界面:训练任务、服务器管理、数据管道和集群管理都通过终端进行。本文涵盖 Shell、文件系统、权限、进程管理、包管理器、环境变量、SSH 以及每位机器学习工程师日常使用的基本命令。* + +- GUI 适合浏览网页,但在凌晨 2 点在远程 GPU 集群上运行训练任务时却很糟糕。**命令行**(或终端、Shell)是能够扩展的工具:它在任何机器上都能工作,可编写脚本,可组合,并且在你的笔记本电脑、云 VM 和 HPC 集群上完全相同。 + +- 如果你是一名只使用 Jupyter notebook 和 VS Code 按钮的机器学习工程师,你正在浪费巨大的生产力。每个生产级机器学习系统都是通过命令行进行部署、监控和调试的。 + +## Shell + +- **Shell** 是一个读取你的命令并执行它们的程序。它是你和操作系统之间的中介(第 13 章)。最常见的 Shell 是 **bash**(大多数 Linux 系统的默认 Shell)和 **zsh**(macOS 的默认 Shell)。 + +- 命令的格式为:`command [options] [arguments]` + +```bash +ls -la /home/user # 命令=ls, 选项=-la, 参数=/home/user +``` + +- 选项修改行为(通常以 `-` 表示短选项,`--` 表示长选项)。`ls -l` 以长格式列出,`ls --all` 显示隐藏文件。许多选项可以组合:`ls -la` 表示将 `-l` 和 `-a` 一起使用。 + +### 基本导航 + +```bash +pwd # 打印当前工作目录(我在哪?) +ls # 列出当前目录中的文件 +ls -la # 列出所有文件(包括隐藏文件)及详细信息 +cd /path/to/dir # 切换目录 +cd .. # 返回上一级 +cd ~ # 返回用户主目录 +cd - # 返回上一个目录 +``` + +### 文件操作 + +```bash +cp source dest # 复制文件 +cp -r dir1 dir2 # 递归复制目录 +mv old new # 移动/重命名文件 +rm file # 删除文件(没有回收站——永久删除) +rm -rf dir # 递归删除目录(危险——无确认) +mkdir -p a/b/c # 创建嵌套目录 +touch file.txt # 创建空文件(或更新时间戳) +cat file.txt # 打印文件内容 +head -n 20 file # 显示前 20 行 +tail -f logfile # 实时跟踪日志文件(监控训练时非常有用) +``` + +- **陷阱**:`rm -rf` 是计算中最危险的命令。没有撤销操作。按回车前请三次检查路径。切勿运行 `rm -rf /` 或 `rm -rf ~`。 + +### 管道与重定向 + +- Shell 的杀手级特性是**可组合性**:将小命令连接起来完成复杂任务。 + +- **管道**(`|`):将一个命令的输出作为下一个命令的输入。 + +```bash +cat training.log | grep "loss" | tail -5 # 最后5行包含"loss"的内容 +ps aux | grep python # 查找正在运行的 Python 进程 +history | grep "docker" # 查找之前的 docker 命令 +``` + +- **重定向**:将输出发送到文件而不是屏幕。 + +```bash +python train.py > output.log 2>&1 # stdout 和 stderr 都输出到文件 +python train.py >> output.log # 追加(不覆盖) +echo "data" > file.txt # 覆盖文件 +echo "more" >> file.txt # 追加到文件 +``` + +- `2>&1` 将 stderr(文件描述符 2)重定向到 stdout(文件描述符 1)。没有它,错误消息仍会出现在屏幕上,只有正常输出会进入文件。 + +### 文本处理 + +```bash +grep "error" logfile.txt # 查找包含"error"的行 +grep -r "import torch" src/ # 递归搜索目录 +grep -i "warning" log.txt # 不区分大小写搜索 +grep -c "epoch" train.log # 统计匹配行数 + +wc -l file.txt # 统计行数 +wc -w file.txt # 统计单词数 + +sort data.txt # 按字母顺序排序 +sort -n numbers.txt # 按数值排序 +sort -u data.txt # 排序并去重 +uniq -c sorted.txt # 统计连续重复项 + +cut -d',' -f2,3 data.csv # 提取 CSV 的第 2 和第 3 列 +awk '{print $1, $3}' data.txt # 打印第 1 和第 3 个空白分隔字段 +sed 's/old/new/g' file.txt # 将所有"old"替换为"new" +``` + +- 这些命令可以优美地组合: + +```bash +# 查找日志文件中最常见的 10 种错误类型 +grep "ERROR" app.log | awk -F': ' '{print $2}' | sort | uniq -c | sort -rn | head -10 +``` + +### 查找文件 + +```bash +find . -name "*.py" # 查找所有 Python 文件 +find . -name "*.pyc" -delete # 查找并删除编译后的 Python 文件 +find /data -size +100M # 查找大于 100MB 的文件 +find . -mtime -1 # 查找过去 24 小时内修改过的文件 + +which python # python 可执行文件在哪? +locate filename # 快速查找文件(使用预构建索引) +``` + +## 文件系统层次结构 + +- Linux 将所有内容组织在以 `/` 为根的单棵树中: + +| 目录 | 用途 | +|-----------|---------| +| `/` | 整个文件系统的根 | +| `/home/user` | 你的个人文件、配置、项目 | +| `/etc` | 系统级配置文件 | +| `/usr` | 用户程序、库、文档 | +| `/usr/local` | 本地安装的软件(非包管理器安装) | +| `/var` | 可变数据:日志(`/var/log`)、数据库、缓存 | +| `/tmp` | 临时文件(重启后清除) | +| `/opt` | 可选的第三方软件 | +| `/proc` | 暴露内核和进程信息的虚拟文件系统 | +| `/dev` | 设备文件(磁盘、GPU 在这里显示) | + +- 对于机器学习:你的训练数据通常在 `/data` 或 `/home/user/data`,模型在 `/home/user/models`,CUDA 在 `/usr/local/cuda`。GPU 设备显示为 `/dev/nvidia0`、`/dev/nvidia1` 等。 + +## 文件权限 + +- 每个文件和目录有三种用户类别的三种权限类型: + +| 权限 | 文件 | 目录 | +|------------|------|-----------| +| **r**(读) | 查看内容 | 列出内容 | +| **w**(写) | 修改内容 | 在内部创建/删除文件 | +| **x**(执行) | 作为程序运行 | 进入(cd 进入)目录 | + +- 三种用户类别:**所有者**(u)、**组**(g)、**其他人**(o)。 + +```bash +ls -l script.py +# -rwxr-xr-- 1 henry ml_team 2048 Mar 28 script.py +# ^^^ 所有者权限:rwx(读、写、执行) +# ^^^ 组权限:r-x(读、执行,不可写) +# ^^^ 其他人权限:r--(只读) +``` + +```bash +chmod 755 script.py # owner=rwx, group=rx, others=rx +chmod +x script.py # 为所有人添加执行权限 +chmod u+w,g-w file.txt # 为所有者添加写权限,移除组的写权限 +chown henry:ml_team file # 更改所有者和组 +``` + +- **陷阱**:顶部带有 `#!/usr/bin/env python3` 的 Python 脚本需要执行权限(`chmod +x`)才能以 `./script.py` 方式运行。没有它,你必须使用 `python3 script.py`。 + +## 进程管理 + +- **进程**是一个正在运行的程序(第 13 章)。Shell 为你提供了管理它们的工具: + +```bash +ps aux # 列出所有正在运行的进程 +ps aux | grep python # 查找 Python 进程 +top # 实时进程监控(CPU、内存) +htop # top 的增强版(需单独安装) +nvidia-smi # GPU 使用情况(机器学习必备) +watch -n 1 nvidia-smi # 每秒刷新 nvidia-smi + +kill PID # 优雅终止进程 +kill -9 PID # 强制终止(优雅方式失败时使用) +killall python # 终止所有 Python 进程 + +# 后台运行 +python train.py & # 后台运行 +nohup python train.py > log.txt & # 后台运行,退出登录后仍存活 +``` + +- **`nohup`** 对机器学习训练至关重要:没有它,关闭 SSH 连接会终止训练任务。`nohup` 将进程从终端分离出来。 + +- **`screen`** 和 **`tmux`** 是终端复用器,可以创建持久会话。你可以在 tmux 会话中启动训练任务,断开 SSH 连接,稍后重新连接,会话(和训练)仍在运行。 + +```bash +tmux new -s training # 创建命名会话 +# ... 开始训练 ... +# Ctrl+B, 然后 D # 从会话分离 +tmux attach -t training # 稍后重新连接(即使 SSH 重新连接后也可用) +tmux ls # 列出会话 +``` + +## 包管理器 + +- **系统包**(操作系统级软件): + +```bash +# Debian/Ubuntu +sudo apt update # 刷新包列表 +sudo apt install htop # 安装包 +sudo apt upgrade # 升级所有包 + +# macOS +brew install wget # 通过 Homebrew 安装 +``` + +- **Python 包**: + +```bash +pip install torch # 从 PyPI 安装 +pip install -e . # 以可编辑模式安装当前项目 +pip install -r requirements.txt # 从 requirements 文件安装 +pip freeze > requirements.txt # 导出已安装的包 + +# Conda(用于复杂依赖,如 CUDA) +conda create -n myenv python=3.11 +conda activate myenv +conda install pytorch torchvision cudatoolkit=12.1 -c pytorch +``` + +- **陷阱**:永远不要将 `pip install` 安装到系统 Python 中。始终使用虚拟环境(`python -m venv env`、`conda create` 或 `uv venv`)。系统 Python 被操作系统工具共享;破坏它可能导致系统崩溃。 + +## 环境变量 + +- **环境变量**是所有程序都可以访问的键值对。它们在不改变代码的情况下配置行为。 + +```bash +export CUDA_VISIBLE_DEVICES=0,1 # 仅使用 GPU 0 和 1 +export PYTHONPATH=/home/user/src # 添加到 Python 的导入路径 +export WANDB_API_KEY=abc123 # Weights & Biases 的 API 密钥 + +echo $PATH # 查看当前 PATH +export PATH=$PATH:/usr/local/cuda/bin # 将 CUDA 添加到 PATH +``` + +- **`.bashrc`**(或 `.zshrc`):每次打开 Shell 时运行的命令。把你的 `export` 语句放在这里,这样它们就会持久存在。 + +- **`.env` 文件**:由 `python-dotenv` 等工具加载的项目特定变量。将密钥(API 密钥、数据库密码)保存在 `.env` 中,并将 `.env` 添加到 `.gitignore`。切勿将密钥提交到 Git。 + +## SSH(安全外壳协议) + +- **SSH** 通过加密通道将你连接到远程机器。这是你访问云 VM、GPU 服务器和 HPC 集群的方式。 + +```bash +ssh user@hostname # 连接到远程机器 +ssh -i ~/.ssh/key.pem user@ip # 使用特定密钥连接 +ssh -L 8888:localhost:8888 user@server # 端口转发(远程 Jupyter) +``` + +- **SSH 密钥**(公钥/私钥对)替代密码: + +```bash +ssh-keygen -t ed25519 # 生成密钥对 +ssh-copy-id user@server # 将公钥复制到服务器 +# 现在无需输入密码即可 SSH +``` + +- **SSH 配置**(`~/.ssh/config`)保存连接详情: + +``` +Host gpu-server + HostName 10.0.1.42 + User henry + IdentityFile ~/.ssh/gpu_key + LocalForward 8888 localhost:8888 +``` + +- 现在输入 `ssh gpu-server` 即可自动使用所有这些设置进行连接。 + +- **`scp`** 和 **`rsync`** 在机器之间传输文件: + +```bash +scp model.pt user@server:/data/models/ # 将文件复制到远程 +scp -r user@server:/data/results/ ./ # 从远程复制目录 +rsync -avz --progress data/ user@server:/data/ # 带进度同步(比 scp 更智能) +``` + +## 机器学习必备命令速查表 + +```bash +# GPU 监控 +nvidia-smi # GPU 使用快照 +watch -n 1 nvidia-smi # 实时监控 +gpustat # 更清晰的 GPU 概览(pip install gpustat) + +# 训练管理 +nohup python train.py > train.log 2>&1 & # 退出登录后仍存活的后台训练 +tail -f train.log # 监控训练输出 +kill %1 # 终止最后一个后台任务 + +# 磁盘使用(数据集很大) +df -h # 所有挂载点的磁盘空间 +du -sh /data/* # /data 中每个项目的大小 +du -sh --max-depth=1 . # 子目录的大小 + +# 内存 +free -h # RAM 使用情况 +cat /proc/meminfo # 详细内存信息 + +# 网络 +curl -O https://example.com/dataset.tar.gz # 下载文件 +wget https://example.com/model.bin # 替代下载工具 +curl -X POST http://localhost:8080/predict \ + -H "Content-Type: application/json" \ + -d '{"text": "hello"}' # 测试模型推理端点 + +# 归档 +tar -czf archive.tar.gz directory/ # 压缩 +tar -xzf archive.tar.gz # 解压 +zip -r archive.zip directory/ # zip 压缩 +unzip archive.zip # zip 解压 + +# 快速数据检查 +head -5 data.csv # CSV 的前 5 行 +wc -l data.csv # 统计行数 +cut -d',' -f1 data.csv | sort -u | wc -l # 统计第 1 列的唯一值数量 +``` diff --git a/chapter 15: production software engineering/02. git and repository management.md b/chapter 15: production software engineering/02. git and repository management.md new file mode 100644 index 0000000..7b25a1b --- /dev/null +++ b/chapter 15: production software engineering/02. git and repository management.md @@ -0,0 +1,218 @@ +# Git 与版本控制 + +*Git 是软件团队在不相互覆盖工作的情况下进行协作的方式。本文涵盖 Git 的心智模型、分支策略、合并与变基、冲突解决、拉取请求,以及管理机器学习特定挑战(如大文件和实验追踪)的方法。* + +- 每个严肃的软件项目都使用版本控制。**Git** 是主导系统,几乎所有开源项目和公司都在使用。没有 Git,协作就是通过电子邮件发送 zip 文件并祈祷没人覆盖你的更改。有了 Git,每次更改都可追踪、可撤销、可追溯。 + +- 对于机器学习工程师:Git 追踪你的代码、配置和实验脚本。结合实验追踪工具,它能提供可重现性:"是哪个确切的代码和配置产生了这个模型?" + +## 心智模型 + +- Git 追踪项目的**快照**。每次提交都是那一刻所有追踪文件的完整快照,而不是差异(在内部,Git 为效率存储差异,但从概念上讲,每次提交都是一个完整状态)。 + +- 文件的四个"位置": + + 1. **工作目录**:磁盘上的实际文件。你在这里编辑。 + 2. **暂存区**(索引):你标记为下一次提交的文件。`git add` 将更改移到这里。 + 3. **本地仓库**:你的提交历史,存储在 `.git/` 中。`git commit` 将暂存区保存为新的快照。 + 4. **远程仓库**(例如 GitHub):一个共享副本。`git push` 上传你的提交,`git pull` 下载他人的提交。 + +``` +Working Dir → git add → Staging → git commit → Local Repo → git push → Remote + ← git pull ← +``` + +- 暂存区正是 Git 强大之处。你可以编辑 10 个文件,但只提交其中的 3 个,将其他更改保留给另一次提交。这使得清晰的、有重点的提交成为可能。 + +### 基本命令 + +```bash +git init # 创建新仓库 +git clone url # 下载远程仓库 +git status # 有什么变化?(最常用的命令) +git add file.py # 暂存特定文件 +git add . # 暂存所有更改(谨慎使用) +git commit -m "descriptive msg" # 提交暂存的更改 +git push # 将提交上传到远程 +git pull # 下载并合并远程更改 +git log --oneline # 紧凑的提交历史 +git diff # 显示未暂存的更改 +git diff --staged # 显示已暂存的更改 +``` + +## 分支 + +- **分支**是指向一次提交的指针。默认分支是 `main`(或 `master`)。创建分支让你拥有独立的开发线:你可以在不影响 `main` 的情况下进行更改。 + +```bash +git branch feature-x # 创建分支 +git checkout feature-x # 切换到此分支 +git checkout -b feature-x # 创建并切换(一步完成) +git branch -d feature-x # 删除分支(合并后) +git branch -a # 列出所有分支(本地 + 远程) +``` + +- **何时分支**:始终需要。永远不要直接提交到 `main`。每个功能、错误修复或实验都有其自己的分支。这保持了 `main` 的稳定性和可部署性。 + +### 分支策略 + +- **功能分支**(最常见):每个功能/修复从 `main` 创建一个分支。完成后,打开拉取请求(PR)以合并回去。简单,适用于大多数团队。 + +- **主干开发**:开发人员频繁(每天多次)提交到 `main`,使用特性标记隐藏未完成的工作。持续部署的团队(Google、Facebook)更偏好这种方式。需要优秀的 CI/CD。 + +- **Gitflow**:为功能、发布和热修复设置单独的分支。更复杂,适用于有版本化发布的软件(移动应用、打包软件)。对大多数机器学习项目来说过于复杂。 + +- 对于机器学习团队:**功能分支**配合短生命周期的分支(1-3 天内合并)是最佳选择。生命周期长的分支会与 `main` 产生分歧,导致痛苦的合并冲突。 + +## 合并与变基 + +- **合并**创建一个新的"合并提交",将两个分支合并: + +```bash +git checkout main +git merge feature-x +``` + +- 这保留了完整的历史记录:你可以看到工作是在分支上完成的,以及何时合并的。合并提交有两个父节点。 + +- **变基**在你的分支上重放提交到目标分支之上: + +```bash +git checkout feature-x +git rebase main +``` + +- 这会重写历史:你的分支上的提交会获得新的哈希值,就好像你是从 `main` 的当前顶端开始工作一样。结果是线性的历史记录(没有合并提交),阅读起来更清晰。 + +- **何时使用哪种**: + - **变基**用于使用最新的 `main` 更改更新你的功能分支(保持分支整洁和最新)。 + - **合并**用于将你的功能分支集成到 `main`(保留分支历史)。 + - **永远不要变基已经推送并与他人共享**的提交。变基会重写历史;如果其他人已经基于原始提交开展工作,变基会导致混乱。 + +## 解决冲突 + +- **冲突**发生在两个分支修改同一文件的同一行时。Git 无法自动决定保留哪个更改,需要你手动解决。 + +``` +<<<<<<< HEAD +learning_rate = 0.001 +======= +learning_rate = 0.0005 +>>>>>>> feature-x +``` + +- `<<<<<<< HEAD` 和 `=======` 之间是当前分支的版本。`=======` 和 `>>>>>>> feature-x` 之间是传入分支的版本。你决定保留哪个(或组合它们),删除标记,保存,然后运行 `git add` 添加已解决的文件。 + +- **陷阱**:不要在已提交的文件中留下冲突标记。它们是会破坏你代码的字面文本。解决后始终搜索 `<<<<<<<`。 + +- **减少冲突**:保持分支短生命周期,频繁将 `main` 合并到你的分支中,避免多人同时编辑同一个文件。 + +## 编写良好的提交信息 + +- 提交信息是为了未来的你和你的队友。"修复错误"告诉不了你什么。"修复批次大小计算中的差一错误,该错误导致 8-GPU 训练时 OOM"告诉你一切。 + +- **格式**: + +``` +简短摘要(50 字以内,祈使语气) + +如果需要,可附带更长的描述。解释 WHY,而不是 WHAT +(差异显示了什么改变了)。每行不超过 72 个字符。 + +Fixes #123 +``` + +- **祈使语气**:"添加功能"而不是"已添加功能"或"添加了功能"。将其视为完成句子:"如果应用此提交,它将**添加功能**。" + +- **原子提交**:每个提交应做一件事。"添加数据加载器"是一个提交。"添加数据加载器并修复无关的错误并更新 README"应该是三个提交。这使得 `git bisect`(找到哪个提交引入了错误)成为可能。 + +## 拉取请求与代码审查 + +- **拉取请求(PR)**提议将一个分支合并到 `main`。它是代码审查的门户:队友阅读你的更改,提出改进建议,并在合并前批准。 + +- **良好的 PR 实践**: + - 保持 PR 小(少于 400 行更改)。大的 PR 会被敷衍批准,因为没人想审查 2000 行。 + - 编写清晰的描述:更改了什么、为什么以及如何测试。 + - 链接到促使更改的问题或工单。 + - 及时回复审查评论。 + - 在合并前压缩琐碎的提交(这样 `main` 就有干净的历史记录)。 + +- **代码审查不是为了找错误**(测试来做这个)。它的目的是:知识分享(审查者学习代码库)、设计反馈(这是正确的方法吗?)和维护标准(命名、风格、架构)。 + +## .gitignore + +- `.gitignore` 文件告诉 Git 排除哪些文件不被追踪。对于机器学习项目: + +```gitignore +# Python +__pycache__/ +*.pyc +*.egg-info/ +.venv/ +env/ + +# 数据和模型(对 git 来说太大) +data/ +*.csv +*.parquet +models/ +*.pt +*.onnx +*.bin +checkpoints/ + +# 密钥 +.env +*.pem +credentials.json + +# IDE +.vscode/ +.idea/ +*.swp + +# 操作系统 +.DS_Store +Thumbs.db + +# Jupyter +.ipynb_checkpoints/ + +# 实验输出 +wandb/ +mlruns/ +outputs/ +logs/ +``` + +- **陷阱**:在文件已被提交后将文件添加到 `.gitignore` 不会将其从仓库中移除。你还必须使用 `git rm --cached file` 来取消追踪。该文件会永远留在历史中,除非你重写历史(这很麻烦)。 + +## Git 在机器学习中的应用 + +- 机器学习引入了传统软件不面临的挑战: + +- **大文件**:数据集和模型权重可能有数 GB 或更大。Git 是为文本文件(源代码)设计的,而不是二进制 blob。解决方案: + - **Git LFS**(大文件存储):在 Git 中追踪指针,将实际文件存储在单独的服务器上。简单,但在 GitHub 上有限制存储/带宽。 + - **DVC**(数据版本控制):将数据和模型文件与 Git 分开管理,使用远程存储(S3、GCS)。像 Git 一样用于数据:`dvc add data.csv`、`dvc push`、`dvc pull`。 + +- **实验追踪**:哪个提交 + 哪些超参数 + 哪个数据产生了哪些指标?Git 追踪代码,但不追踪完整的实验上下文。 + - **Weights & Biases(W&B)**:记录指标、超参数、系统信息,并链接到 Git 提交。提供用于比较运行结果的仪表板。 + - **MLflow**:开源的实验追踪,带有模型注册表。记录参数、指标和产物。 + - **简单方法**:在你的训练脚本中记录 Git 哈希值:`git_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD']).strip()`。将其与你的结果一起存储。 + +- **可重现性检查清单**(每个实验需要追踪的内容): + - Git 提交哈希值(确切的代码版本) + - 配置文件 / 超参数 + - 随机种子 + - Python 和库版本(`pip freeze`) + - 数据版本(DVC 哈希值或数据集版本标签) + - 硬件(GPU 类型、GPU 数量) + +```bash +# 快速可重现性快照 +echo "Commit: $(git rev-parse HEAD)" > experiment_info.txt +echo "Branch: $(git branch --show-current)" >> experiment_info.txt +echo "Dirty: $(git status --porcelain | wc -l) files" >> experiment_info.txt +pip freeze >> experiment_info.txt +nvidia-smi >> experiment_info.txt +``` diff --git a/chapter 15: production software engineering/03. codebase design.md b/chapter 15: production software engineering/03. codebase design.md new file mode 100644 index 0000000..dca1d5e --- /dev/null +++ b/chapter 15: production software engineering/03. codebase design.md @@ -0,0 +1,383 @@ +# 代码库设计与模式 + +*良好的代码库设计是区分研究原型与生产级软件的关键。本文涵盖项目结构、整洁代码原则、与机器学习相关的设计模式、配置管理、日志、API 设计以及打包分发。* + +- 大多数机器学习代码始于 Jupyter notebook。Notebook 不断增长、被复制、修改、共享,最终变成由全局变量、死单元格和魔数组成的难以维护的混乱。**代码库设计**是一门组织代码的学科,使代码在项目增长过程中保持可理解和可修改。 + +- 这不是为了遵循规则而遵循规则。而是为了减少从"我想改变 X"到"X 已被修改并能正常工作"之间的时间。在精心设计的代码库中,这个时间是几分钟。在设计糟糕的代码库中,则需要几天的时间去考古、翻阅未记录的意大利面条式代码。 + +## 项目结构 + +- 一致的项目布局让任何人(包括未来的你)都能立即浏览代码库。 + +``` +my_project/ +├── src/my_project/ # 源代码(可导入的包) +│ ├── __init__.py +│ ├── data/ # 数据加载和预处理 +│ │ ├── __init__.py +│ │ ├── dataset.py +│ │ └── transforms.py +│ ├── models/ # 模型架构 +│ │ ├── __init__.py +│ │ ├── transformer.py +│ │ └── layers.py +│ ├── training/ # 训练循环、优化器 +│ │ ├── __init__.py +│ │ ├── trainer.py +│ │ └── losses.py +│ └── utils/ # 共享工具 +│ ├── __init__.py +│ └── logging.py +├── configs/ # 配置文件 +│ ├── base.yaml +│ └── experiment_1.yaml +├── scripts/ # 入口点(训练、评估、推理) +│ ├── train.py +│ ├── evaluate.py +│ └── serve.py +├── tests/ # 测试文件(镜像 src/ 结构) +│ ├── test_dataset.py +│ ├── test_model.py +│ └── test_trainer.py +├── notebooks/ # 仅用于探索(非生产代码) +├── pyproject.toml # 项目元数据和依赖 +├── README.md +├── .gitignore +└── Dockerfile +``` + +- **`src/` 布局**:将源代码放在 `src/my_project/` 下可以防止从当前目录意外导入(这会掩盖在生产环境中才会暴露的导入错误)。使用 `pip install -e .` 进行开发安装。 + +- **单仓库 vs 多仓库**:**单仓库**将所有相关项目放在一个仓库中(跨项目更改更容易、CI 共享)。**多仓库**给每个项目自己的仓库(边界更清晰、版本控制独立)。大多数机器学习团队从单仓库开始,必要时再拆分。 + +- **脚本 vs 库**:将入口点(`train.py`、`evaluate.py`)保留在 `scripts/` 中。将可复用的逻辑放在 `src/` 中。训练脚本应约为 50 行:解析配置、构建数据集、构建模型、构建训练器、训练。所有复杂性都在库中。 + +## 整洁代码原则 + +- **命名**:你能做的唯一最有影响力的事情。名为 `x` 的变量需要你阅读周围的代码才能理解。名为 `learning_rate` 的变量是自解释的。 + +```python +# 糟糕 +def proc(d, n, lr): + for i in range(n): + for k, v in d.items(): + v -= lr * g[k] + +# 良好 +def update_parameters(parameters, num_steps, learning_rate): + for step in range(num_steps): + for name, param in parameters.items(): + param -= learning_rate * gradients[name] +``` + +- **单一职责原则**:每个函数/类只做一件事。名为 `load_data_and_train_model` 的函数在做两件事,应该拆分。这使每个部分都可以独立测试、复用和理解。 + +- **DRY(不要重复自己)**——但不要过早抽象。如果你复制粘贴代码三次,将其提取为一个函数。但不要为只使用过一次的代码创建抽象。过早的抽象比重复更糟糕:它增加了复杂性但没有经过验证的好处。 + +```python +# 过早抽象(一个用例,过度设计) +class AbstractDataTransformPipelineFactory: + ... + +# 恰到好处(直接、清晰、在三处使用) +def normalise_image(image, mean, std): + return (image - mean) / std +``` + +- **魔数**:永远不要使用未解释的字面值。 + +```python +# 糟糕 +if len(batch) > 32: + split_batch(batch, 32) + +# 良好 +MAX_BATCH_SIZE = 32 +if len(batch) > MAX_BATCH_SIZE: + split_batch(batch, MAX_BATCH_SIZE) +``` + +- **函数应该简短**:如果一个函数不能在一屏内显示完整(约 30 行),那它可能做得太多了。将逻辑块提取为带有描述性名称的辅助函数。然后函数体读起来就像高级摘要。 + +## 适用于机器学习的设范计式 + +- 设计模式是针对常见问题的可复用解决方案。以下是与机器学习代码库最相关的模式: + +- **工厂模式**:在不指定确切类的情况下创建对象。当你的配置说 `model: "transformer"` 并且你需要实例化正确的类时很有用: + +```python +MODEL_REGISTRY = { + "transformer": TransformerModel, + "cnn": CNNModel, + "mlp": MLPModel, +} + +def build_model(config): + model_cls = MODEL_REGISTRY[config["model"]] + return model_cls(**config["model_params"]) +``` + +- 这使训练脚本与特定的模型实现解耦。添加新模型意味着在注册表中添加一行,而不是修改训练循环。 + +- **策略模式**:在运行时交换算法。适用于损失函数、优化器、调度器: + +```python +LOSS_FUNCTIONS = { + "mse": nn.MSELoss, + "cross_entropy": nn.CrossEntropyLoss, + "focal": FocalLoss, +} + +loss_fn = LOSS_FUNCTIONS[config["loss"]]() +``` + +- **观察者模式**(回调/钩子):让模块响应事件而不紧密耦合。训练框架(PyTorch Lightning、Keras)广泛使用回调: + +```python +class EarlyStopping: + def __init__(self, patience=5): + self.patience = patience + self.best_loss = float('inf') + self.counter = 0 + + def on_epoch_end(self, epoch, val_loss): + if val_loss < self.best_loss: + self.best_loss = val_loss + self.counter = 0 + else: + self.counter += 1 + if self.counter >= self.patience: + return "stop" +``` + +- **依赖注入**:将依赖项传入函数/类,而不是在内部创建。这使得测试变得容易(注入 mock)并且配置灵活: + +```python +# 糟糕:硬编码依赖 +class Trainer: + def __init__(self): + self.logger = WandbLogger() # 没有 W&B 就无法测试 + +# 良好:注入依赖 +class Trainer: + def __init__(self, logger): + self.logger = logger # 可以注入任何记录器,包括 mock +``` + +## 配置管理 + +- 硬编码超参数、文件路径和模型设置使实验无法重现,修改也很痛苦。**将配置外部化**到文件中。 + +- **YAML** 是机器学习配置最常见的格式: + +```yaml +# configs/experiment_1.yaml +model: + name: transformer + d_model: 512 + n_heads: 8 + n_layers: 6 + +training: + batch_size: 64 + learning_rate: 3e-4 + max_epochs: 100 + early_stopping_patience: 10 + +data: + train_path: /data/train.parquet + val_path: /data/val.parquet + max_seq_length: 512 +``` + +- **Hydra**(Facebook)是一个支持组合(将基础配置与实验特定覆盖合并)、命令行覆盖(`python train.py training.lr=1e-3`)和多运行(超参数扫描)的配置框架。 + +- **argparse** 适用于参数较少的脚本: + +```python +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--lr", type=float, default=3e-4) +parser.add_argument("--batch-size", type=int, default=64) +parser.add_argument("--config", type=str, default="configs/base.yaml") +args = parser.parse_args() +``` + +- **最佳实践**:有一个包含所有默认值的基础配置,以及每个实验的配置,只覆盖更改的部分。追踪每个实验的配置及其结果。 + +## 日志与可观测性 + +- `print` 语句用于调试。**日志**用于生产环境: + +```python +import logging + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +logger.debug("Batch loaded: %d samples", len(batch)) # 详细,用于调试 +logger.info("Epoch %d: loss=%.4f, lr=%.6f", epoch, loss, lr) # 正常运行 +logger.warning("GPU memory >90%%, consider reducing batch size") +logger.error("Failed to load checkpoint: %s", path) # 可恢复的错误 +logger.critical("CUDA out of memory, aborting") # 致命错误 +``` + +- **为什么不用 print**:日志支持级别(在生产环境中过滤调试消息)、格式化(时间戳、模块名)和处理程序(写入文件、发送到监控系统),而无需更改日志调用。 + +- **结构化日志**同时输出机器可解析的格式(JSON)和人类可读的消息。这使得可以搜索特定字段并设置告警: + +```python +logger.info("training_step", extra={ + "epoch": 5, "step": 1200, "loss": 0.0342, "lr": 2.1e-4 +}) +``` + +## API 设计 + +- 如果你的模型将被其他服务使用(Web 应用、移动应用、另一个机器学习管道),它需要一个 **API**(应用程序编程接口)。 + +- **REST API** 使用 HTTP 方法:`GET` 用于读取,`POST` 用于创建/预测,`PUT` 用于更新,`DELETE` 用于删除。端点遵循基于资源的命名: + +``` +POST /api/v1/predict # 发送输入,获取预测结果 +GET /api/v1/models # 列出可用模型 +GET /api/v1/models/{id} # 获取模型详情 +POST /api/v1/models/{id}/predict # 使用特定模型进行预测 +``` + +- **FastAPI** 是机器学习推理的首选 Python 框架: + +```python +from fastapi import FastAPI +from pydantic import BaseModel + +app = FastAPI() + +class PredictRequest(BaseModel): + text: str + +class PredictResponse(BaseModel): + label: str + confidence: float + +@app.post("/predict", response_model=PredictResponse) +async def predict(request: PredictRequest): + result = model.predict(request.text) + return PredictResponse(label=result.label, confidence=result.score) +``` + +- FastAPI 自动生成 API 文档(在 `/docs` 的 Swagger UI),使用 Pydantic 模型验证输入/输出,并支持异步以实现高吞吐量。 + +- **gRPC** 在内部服务间通信方面比 REST 更快。它使用 Protocol Buffers(二进制序列化,比 JSON 更小更快)并支持流式传输。TensorFlow Serving、Triton Inference Server 和许多微服务架构都使用它。 + +## 打包与分发 + +- 让你的代码可以作为包安装,使其他人(和你自己的脚本)可以干净地导入: + +```toml +# pyproject.toml +[project] +name = "my-ml-project" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "torch>=2.0", + "jax>=0.4", + "pydantic>=2.0", +] + +[project.optional-dependencies] +dev = ["pytest", "ruff", "mypy"] + +[build-system] +requires = ["setuptools>=64"] +build-backend = "setuptools.backends._legacy:_Backend" +``` + +```bash +pip install -e ".[dev]" # 以可编辑模式安装,包含开发依赖 +``` + +- **可编辑安装**(`-e`):对源代码的更改会立即生效,无需重新安装。开发期间必不可少。 + +- **锁定依赖**:使用确切版本的 `requirements.txt`(`torch==2.2.1`,而不是 `torch>=2.0`)确保可重现性。使用 `pip freeze > requirements.txt` 捕获你当前的环境。对于更复杂的依赖管理,使用 `uv`、`poetry` 或 `pip-tools`。 + +## 使用 AI 编码助手 + +- AI 编码助手(Claude Code、GitHub Copilot、Cursor 等)现在已成为专业工程师工作流程的一部分。使用得当,它们能极大加速开发。使用不当,它们会引入微妙的错误、侵蚀你对代码库的理解,并制造虚假的生产力感。 + +- 正确的心智模型:**AI 助手是一个快速但缺乏经验的结对程序员**。它可以快速编写代码,熟悉语法和标准模式,并且阅读过的文档比你还多。但它不了解你的特定系统、业务约束、边界情况以及设计决策背后的*原因*。你是高级工程师;AI 助手是初级工程师。你来指导、审查并承担责任。 + +### AI 助手擅长之处 + +- **样板代码和脚手架**:生成 Dockerfile、CI 配置、测试夹具、数据类定义、argparse 设置。这些遵循众所周知的模式,手动编写很繁琐。让 AI 生成它们,然后审查正确性。 + +- **编写测试**:描述函数的行为,AI 助手生成测试用例。它通常会捕捉到你可能会遗漏的边界情况(空输入、负值、Unicode)。始终阅读生成的测试——它们验证的是你的假设,而不仅仅是你的代码。 + +- **重构**:"将这个块提取成函数"、"将这个类改为使用 dataclasses"、"给这个模块添加类型提示"。机械性的转换,意图明确,引入细微错误的风险较低。 + +- **探索和原型开发**:"写一个快速脚本来 benchmark 推理延迟"或"展示如何使用 HuggingFace tokeniser API"。AI 助手能比阅读文档更快地给你一个可用的起点。 + +- **文档和 docstrings**:AI 助手可以根据你的代码结构生成文档。你需要审查准确性,但苦力活已经自动化了。 + +- **调试辅助**:粘贴错误回溯信息并请求诊断。AI 助手通常能识别根本原因并提出修复建议,尤其是对于常见问题(形状不匹配、导入错误、CUDA 内存不足)。 + +### 何时不应依赖 AI 助手 + +- **新颖的架构决策**:如果你正在设计一个新的训练管道,AI 助手会给出一个通用的答案。它不了解你的数据约束、延迟要求或团队专业知识。使用 AI 助手来实现你已经深思熟虑的设计。 + +- **安全关键代码**:认证、加密、输入清理。AI 助手可能生成看起来正确但存在细微漏洞的代码(SQL 注入、不安全的默认值、时序攻击)。安全代码应由理解威胁模型的人编写,并由另一个人审查。 + +- **性能关键的内循环**:AI 助手会编写正确但天真的代码。对于 GPU 内核、内存关键的数据结构或延迟敏感的推理路径,你需要理解硬件约束(第 13 章、第 16 章)并有目的地进行优化。 + +- **你不理解的代码**:如果 AI 助手生成了 200 行代码,而你无法解释每一行的作用,那就不要提交。你现在正在维护你不理解的代码,当它出问题时(它会的),你无法调试。这是最常见也最危险的失败模式。 + +### 审查纪律 + +- **在提交前始终逐行阅读**生成的代码。这不是可选的。AI 助手的代码是草稿,不是成品。就像对待同事的拉取请求一样:批判性地审查它。 + +- **检查什么**: + - **正确性**:它是否真的做了你要求的事情?AI 助手经常解决与你意图略有不同的问题。 + - **边界情况**:它是否处理了空输入、None 值、负数、非常大的输入?AI 助手经常省略边界情况处理。 + - **幻想的 API**:AI 助手可能调用不存在函数或使用不存在的参数,尤其是对于较新或较少使用的库。验证每个 API 调用是否真实存在。 + - **过度工程**:AI 助手倾向于产生比必要更多的代码。一个 50 行的解决方案解决一个 10 行的问题,增加了不必要的复杂性。无情地简化。 + - **安全性**:硬编码的密钥、未经清理的用户输入、不安全的默认值。AI 助手不会以对抗性思维思考。 + - **风格一致性**:生成的代码是否与项目的约定一致(命名、模式、错误处理)? + +### 如何编写好的提示词 + +- AI 助手输出的质量直接与你的指令质量成正比。模糊的提示词得到模糊的代码。 + +- **糟糕**:"写一个数据加载器" +- **好**:"为一个包含'text'和'label'列的 CSV 文件编写一个 PyTorch DataLoader。使用 HuggingFace tokeniser 'bert-base-uncased' 对文本进行分词,max_length=512。返回 input_ids、attention_mask 和 label 作为张量。处理 CSV 中标签列有缺失值的情况,跳过那些行。" + +- **提供上下文**:告诉 AI 助手你的项目结构、现有代码、约束和约定。上下文越多,输出越好。 + +- **指定约束**:"只使用标准库"、"必须兼容 Python 3.10"、"不要使用全局变量"、"遵循 `src/models/transformer.py` 中的现有模式"。 + +- **要求解释**:"实现 X 并解释关键的设计决策。"这会迫使 AI 助手阐述其推理,使你更容易发现错误假设。 + +### 使用质量门控来捕捉 AI 助手的错误 + +- 你现有的质量基础设施(文件 04)捕捉 AI 助手的错误与捕捉人类的错误同样有效: + + - **类型检查(mypy)**:捕捉幻想的 API 签名和类型不匹配。 + - **代码检查(ruff)**:捕捉未使用的导入、未定义的变量和风格违规。 + - **测试(pytest)**:如果 AI 助手的代码通过了你的测试套件,它更可能是正确的。如果你还没有测试,在要求 AI 助手实现功能之前*先编写测试*(测试驱动开发与 AI 助手配合得特别好)。 + - **CI 管道**:在每次提交时自动运行上述所有检查。 + +- **"AI 助手写代码" + "质量门控验证"** 的组合比单独使用任何一种都更高效。AI 助手快速但草率;门控工具彻底但不写代码。两者结合,你同时获得速度和正确性。 + +### 生产力陷阱 + +- 使用编码助手的最大风险是**生产力的幻觉**。你可以在 10 分钟内生成 500 行代码。但如果你花 2 小时调试这些你并不理解的 500 行代码,那还不如自己花 30 分钟写 200 行代码来得快。 + +- 使用 AI 助手的真正生产力来自: + 1. **保持控制**:你决定架构,AI 助手填入实现。 + 2. **理解生成的内容**:如果你无法解释它,就重写它或让 AI 助手简化它。 + 3. **投资质量门控**:测试、类型和代码检查的成本通过每次 AI 交互分摊。 + 4. **利用 AI 助手弥补你的弱点**:如果你擅长算法但编写测试很慢,让 AI 助手写测试。如果你对 UI 代码很快但不熟悉数据库查询,让 AI 助手草拟 SQL。发挥你的优势,委托你的短板。 + +- 从编码助手中获益最多的工程师是那些已经擅长编码的人。AI 助手放大你现有的技能;它不会取代你的技能。理解数据结构、算法、系统设计和软件工程(整章的内容)让你能够有效地指导 AI 助手并批判性地评估其输出。 diff --git a/chapter 15: production software engineering/04. testing and quality assurance.md b/chapter 15: production software engineering/04. testing and quality assurance.md new file mode 100644 index 0000000..4e31b0a --- /dev/null +++ b/chapter 15: production software engineering/04. testing and quality assurance.md @@ -0,0 +1,322 @@ +# 测试与质量保障 + +*测试是你如何确保代码正常工作的方法——不仅是现在,而且在每次更改后都能正常工作。本文涵盖测试金字塔、使用 pytest 进行的单元测试、Mock、测试机器学习特定代码、CI/CD 管道、代码检查、格式化和代码审查——这些实践能在错误到达生产环境之前捕获它们。* + +- 机器学习代码以缺乏测试而闻名。"能训练,所以能工作"是普遍态度。这会导致静默错误:一个错误地打乱数据的数据加载器、一个有符号错误的损失函数、一个丢弃 5% 数据的预处理步骤。这些错误不会使你的程序崩溃。它们只是让你的模型悄悄变差,然后你浪费数周时间调试"本应更高"的指标。 + +- 测试不是额外负担。它是快速前进而不破坏东西的最快方式。 + +## 测试金字塔 + +- 测试按层级组织,从快速且狭窄到慢速且广泛: + + - **单元测试**(底层):隔离测试单个函数和类。快速(毫秒级),数量多(数百到数千)。"`normalise_image` 是否产生 [0, 1] 范围内的值?" + + - **集成测试**(中层):测试组件协同工作。较慢(秒级)。"数据加载器是否以模型期望的格式产生批次?" + + - **端到端测试**(顶层):测试从输入到输出的完整管道。较慢(分钟级)。"`python train.py --config test.yaml` 是否无错误完成并产生有效的检查点?" + +- 金字塔形状意味着:编写大量单元测试,较少数量的集成测试,以及少量端到端测试。单元测试捕获大多数错误,并在几秒钟内运行。端到端测试捕获集成问题,但慢且脆弱。 + +## 使用 pytest 进行单元测试 + +- **pytest** 是标准的 Python 测试框架。测试是以 `test_` 开头的函数,放在以 `test_` 开头的文件中: + +```python +# tests/test_utils.py + +def test_normalise_image(): + import numpy as np + image = np.array([0, 128, 255], dtype=np.uint8) + result = normalise_image(image, mean=128, std=128) + assert result.min() >= -1.0 + assert result.max() <= 1.0 + assert abs(result[1]) < 1e-6 # 128 被 mean=128 归一化后应约为 0 + +def test_normalise_empty(): + import numpy as np + image = np.array([], dtype=np.uint8) + result = normalise_image(image, mean=128, std=128) + assert len(result) == 0 +``` + +```bash +pytest tests/ # 运行所有测试 +pytest tests/test_utils.py # 运行一个文件 +pytest -v # 详细输出 +pytest -x # 在第一个失败时停止 +pytest -k "normalise" # 运行匹配名称模式的测试 +pytest --tb=short # 更短的追溯信息 +``` + +### 夹具 + +- **夹具**为测试提供可复用的设置。无需在每个测试中重复设置代码,只需定义一次: + +```python +import pytest + +@pytest.fixture +def sample_dataset(): + """创建一个用于测试的小型数据集。""" + return { + "inputs": torch.randn(10, 3, 32, 32), + "labels": torch.randint(0, 10, (10,)) + } + +@pytest.fixture +def trained_model(): + """加载一个小型预训练模型。""" + model = SmallModel() + model.load_state_dict(torch.load("tests/fixtures/small_model.pt")) + return model + +def test_model_output_shape(trained_model, sample_dataset): + output = trained_model(sample_dataset["inputs"]) + assert output.shape == (10, 10) # batch_size x num_classes +``` + +- 夹具可以有**作用域**:`scope="function"`(默认,每次测试重新创建)、`scope="module"`(每个文件一次)、`scope="session"`(每次测试运行一次)。对于加载模型等昂贵设置,使用 `scope="session"`。 + +### 参数化测试 + +- 使用多个输入测试同一个函数,无需重复代码: + +```python +@pytest.mark.parametrize("input,expected", [ + ([1, 2, 3], 6), + ([], 0), + ([-1, 1], 0), + ([1000000, 1000000], 2000000), +]) +def test_sum(input, expected): + assert sum(input) == expected +``` + +## Mock 与补丁 + +- **Mock** 在测试期间用假依赖替换真实依赖。这让你可以隔离测试函数,而无需数据库、API 或 GPU。 + +```python +from unittest.mock import patch, MagicMock + +def test_training_logs_metrics(): + mock_logger = MagicMock() + + with patch("my_project.training.trainer.wandb") as mock_wandb: + trainer = Trainer(logger=mock_logger) + trainer.train_one_epoch() + + # 验证训练器记录了指标 + mock_logger.log.assert_called() + # 验证它记录了损失值 + call_args = mock_logger.log.call_args + assert "loss" in call_args[1] +``` + +- **何时使用 Mock**:外部服务(API、数据库、云存储)、昂贵操作(GPU 计算、大型文件 I/O)和非确定性行为(随机数生成器、时间戳)。 + +- **何时不要 Mock**:你自己的代码。如果你 Mock 了所有内容,你的测试验证的是 Mock 的行为符合预期,而不是你的代码能工作。在边界处进行 Mock,直接测试你的逻辑。 + +## 测试机器学习代码 + +- 机器学习代码有独特的测试挑战:输出是概率性的,训练很慢,而且"正确"是模糊的。 + +### 确定性种子 + +- 在所有地方设置随机种子,使测试可重现: + +```python +import random +import numpy as np +import torch + +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False +``` + +### 数值容差 + +- 浮点数比较需要容差(第 13 章,IEEE 754): + +```python +# 糟糕:由于浮点数问题,精确比较会失败 +assert model_output == 0.5 + +# 良好:近似比较 +import numpy as np +assert np.isclose(model_output, 0.5, atol=1e-5) + +# 对于张量 +assert torch.allclose(output, expected, atol=1e-4) +``` + +### 机器学习中需要测试什么 + +- **形状测试**:验证输出具有预期的维度。 + +```python +def test_model_output_shape(): + model = MyModel(d_model=256, n_classes=10) + x = torch.randn(8, 32, 256) # batch=8, seq=32, dim=256 + output = model(x) + assert output.shape == (8, 10) +``` + +- **梯度流**:验证可训练参数具有非零梯度。 + +```python +def test_gradients_flow(): + model = MyModel() + x = torch.randn(4, 3, 32, 32) + y = torch.randint(0, 10, (4,)) + + output = model(x) + loss = F.cross_entropy(output, y) + loss.backward() + + for name, param in model.named_parameters(): + assert param.grad is not None, f"没有 {name} 的梯度" + assert param.grad.abs().sum() > 0, f"{name} 的梯度为零" +``` + +- **在一个批次上过拟合**:模型应该能够记忆单个批次。如果不能,说明某处存在根本性问题。 + +```python +def test_overfit_one_batch(): + model = MyModel() + optimiser = torch.optim.Adam(model.parameters(), lr=1e-3) + x, y = get_single_batch() + + for _ in range(100): + loss = F.cross_entropy(model(x), y) + loss.backward() + optimiser.step() + optimiser.zero_grad() + + assert loss.item() < 0.01, f"无法过拟合单个批次:loss={loss.item()}" +``` + +- **数据验证**:验证数据加载产生有效输出。 + +```python +def test_dataset_basics(): + dataset = MyDataset("tests/fixtures/small_data.csv") + assert len(dataset) > 0 + x, y = dataset[0] + assert x.shape == (3, 224, 224) + assert 0 <= y < 10 + assert not torch.isnan(x).any() + assert not torch.isinf(x).any() +``` + +- **确定性**:相同输入 + 相同种子 → 相同输出。 + +```python +def test_determinism(): + set_seed(42) + output1 = model(input_data) + set_seed(42) + output2 = model(input_data) + assert torch.allclose(output1, output2) +``` + +## CI/CD 管道 + +- **持续集成(CI)**:在每次提交或 PR 上自动运行测试。如果测试失败,PR 不能合并。这防止了损坏的代码到达 `main`。 + +- **GitHub Actions** 示例(`.github/workflows/ci.yml`): + +```yaml +name: CI +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - run: pip install -e ".[dev]" + - run: ruff check src/ + - run: mypy src/ + - run: pytest tests/ -v --tb=short +``` + +- **预提交钩子**:在每次提交前(本地)运行检查,在它们到达 CI 之前捕获问题: + +```yaml +# .pre-commit-config.yaml +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.0 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml +``` + +```bash +pip install pre-commit +pre-commit install # 现在每次 git 提交时都会运行钩子 +``` + +## 代码检查与格式化 + +- **代码检查**无需运行代码即可捕获错误和风格问题。**格式化**自动强制执行一致的风格。 + +- **Ruff**:一个快速的 Python 代码检查器和格式化器(在一个工具中替代 flake8、isort 和 black): + +```bash +ruff check src/ # 代码检查 +ruff check --fix src/ # 代码检查并自动修复 +ruff format src/ # 格式化 +``` + +- **mypy**:Python 静态类型检查器。在运行时之前捕获类型错误: + +```bash +mypy src/ +# src/model.py:42: error: Argument 1 to "forward" has incompatible type "int"; expected "Tensor" +``` + +- 类型提示使代码自文档化并捕获错误: + +```python +def train( + model: nn.Module, + dataloader: DataLoader, + optimiser: torch.optim.Optimizer, + num_epochs: int = 10, +) -> float: + """训练模型并返回最终损失。""" + ... +``` + +## 代码审查最佳实践 + +- **对于作者**: + - 在请求审查之前先自我审查你的差异。你会发现明显的问题。 + - 保持 PR 小而专注。一个 PR 聚焦一个问题。 + - 写清晰的描述:什么、为什么、如何测试。 + - 回复每条评论(即使只是"已修改")。 + +- **对于审查者**: + - 保持友善。批评代码,而不是人。"这里可以更清晰"而不是"这很令人困惑。" + - 区分阻塞性问题(错误、安全)和建议(风格、命名)。使用标签:"nit:"、"suggestion:"、"blocking:"。 + - 提问而不是发号施令。"如果这个列表为空会怎样?"比"处理空的情况"更有帮助。 + - 及时批准。等待数天的 PR 会阻塞作者,并鼓励大型、批量的 PR(这些更难审查)。 diff --git a/chapter 15: production software engineering/05. deployment and devops.md b/chapter 15: production software engineering/05. deployment and devops.md new file mode 100644 index 0000000..71de4f1 --- /dev/null +++ b/chapter 15: production software engineering/05. deployment and devops.md @@ -0,0 +1,233 @@ +# 部署与 DevOps + +*部署是你的模型从研究产物变成产品的地方。本文涵盖用于机器学习的 Docker、模型推理、实验追踪、可重现性、生产环境监控、特征存储和管道编排——这些基础设施将一个训练好的模型从 notebook 带到数百万用户面前。* + +- 一个只在你笔记本电脑上运行的模型是原型。一个能够可靠地大规模运行、在毫秒内提供预测结果、能够从故障中恢复并在不中断服务的情况下更新的模型才是产品。两者之间的差距就是**部署与 DevOps**。 + +- 大多数机器学习工程师在部署、监控和调试生产问题上花费的时间比训练模型还多。理解这些基础设施对于任何构建真实 ML 系统的人来说都不是可选项。 + +## 用于机器学习的 Docker + +- 我们在第 13 章(操作系统)中概念性地介绍了容器。这里我们关注实践方面:为机器学习工作负载编写 Dockerfile。 + +- **Dockerfile** 是构建容器镜像的配方: + +```dockerfile +# 从官方的 CUDA 基础镜像开始 +FROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04 + +# 系统依赖 +RUN apt-get update && apt-get install -y \ + python3.11 python3-pip git \ + && rm -rf /var/lib/apt/lists/* + +# Python 依赖(单独安装以利用缓存) +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# 复制源代码(频繁更改,因此此层放在最后) +COPY src/ /app/src/ +COPY configs/ /app/configs/ +WORKDIR /app + +# 入口点 +CMD ["python3", "src/scripts/serve.py", "--config", "configs/serve.yaml"] +``` + +- **层缓存**:Docker 会缓存每一层。如果 `requirements.txt` 没有变化,`pip install` 在重新构建时会被跳过。将不常更改的层(系统包、pip 安装)放在频繁更改的层(源代码)之前。这将 10 分钟的构建变成 10 秒的重新构建。 + +- **GPU 访问**:使用 `nvidia/cuda` 基础镜像,并使用 `docker run --gpus all` 运行。`nvidia-container-toolkit` 提供从宿主机到容器的 GPU 透传。 + +- **多阶段构建**通过将构建环境与运行环境分离来减小镜像大小: + +```dockerfile +# 构建阶段:安装构建工具、编译依赖 +FROM python:3.11 AS builder +COPY requirements.txt . +RUN pip install --user -r requirements.txt + +# 运行阶段:仅运行环境依赖 +FROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04 +COPY --from=builder /root/.local /root/.local +COPY src/ /app/src/ +ENV PATH=/root/.local/bin:$PATH +``` + +- 最终镜像只包含运行时库,不包含编译器、头文件或构建工具。一个 5GB 的构建镜像变成了 2GB 的运行镜像。 + +- **Docker Compose** 运行多容器设置(模型服务器 + 负载均衡器 + 监控): + +```yaml +# docker-compose.yml +services: + model: + build: . + ports: + - "8080:8080" + deploy: + resources: + reservations: + devices: + - capabilities: [gpu] + prometheus: + image: prom/prometheus + ports: + - "9090:9090" +``` + +## 模型推理 + +- **模型推理**是将推理作为服务运行:接收请求、运行模型、返回预测结果。 + +- **FastAPI**(在文件 03 中介绍)适用于低到中等吞吐量的最简单方法。对于高吞吐量和 GPU 优化推理,使用专用工具: + +- **Triton Inference Server**(NVIDIA):以 TensorRT、ONNX、PyTorch 和 TensorFlow 格式提供模型。特性: + - **动态批处理**:收集单个请求并将它们分批处理以提高 GPU 效率。单个请求流被分组为 32 的批次,大幅提高吞吐量。 + - **模型集成**:在单个请求中链式调用多个模型(预处理器 → 模型 → 后处理器)。 + - **多模型推理**:在同一 GPU 上提供多个模型,共享资源。 + - **并发模型执行**:在同一 GPU 上并行运行多个推理请求。 + +- **TorchServe**(PyTorch):以 REST/gRPC API 提供 PyTorch 模型。支持模型版本控制、A/B 测试和自定义处理器。 + +- **vLLM**:专门用于 LLM 推理。实现了 PagedAttention(高效的 KV 缓存管理)、连续批处理和跨 GPU 的张量并行。对于大语言模型,吞吐量比朴素推理高出 10-20 倍。 + +- **Cactus**([github.com/cactus-compute/cactus](https://github.com/cactus-compute/cactus)):一个用于移动端和边缘端设备推理的低延迟 AI 引擎。Cactus 提供**兼容 OpenAI 的 API**(聊天补全、流式传输、工具调用、转录、嵌入、RAG、视觉),完全在设备上运行,当本地模型无法处理请求时自动进行**云回退**。这种混合架构意味着你的应用程序代码使用相同的 API,无论推理是在本地还是在云端运行——引擎根据模型置信度和设备能力来决定。提供 Python、Swift、Kotlin、Flutter、React Native 和 Rust 的 SDK,以及 HuggingFace 上预转换的模型权重。支持多模态推理(LLM、视觉、语音),配备自定义 ARM SIMD 内核以实现 ARM CPU 上的最快推理,以及零拷贝内存映射以实现 10 倍 RAM 使用降低(第 16 章、第 17 章)。 + +- **模型格式优化**: + - **ONNX**:用于互操作性的开放格式。从 PyTorch/TensorFlow 导出,在任何地方运行。 + - **TensorRT**:NVIDIA 的优化器。融合层、选择最佳内核、量化权重。在 NVIDIA GPU 上通常比 PyTorch 快 2-5 倍。 + - **GGUF/GGML**:适用于 CPU 高效推理的格式,在消费级硬件上运行 LLM 时很流行。 + +## 实验追踪 + +- 没有实验追踪,机器学习研究会退化为:"我觉得上周二那个我改了些配置的模型是最好的,但我不记得改了啥。" + +- **Weights & Biases(W&B)**:最流行的实验追踪工具。从你的训练脚本中记录任何内容: + +```python +import wandb + +wandb.init(project="my-project", config={ + "model": "transformer", + "lr": 3e-4, + "batch_size": 64, +}) + +for epoch in range(num_epochs): + train_loss = train_one_epoch() + val_loss = validate() + + wandb.log({ + "train/loss": train_loss, + "val/loss": val_loss, + "epoch": epoch, + }) + + # 将模型记录为产物 + if val_loss < best_loss: + wandb.save("best_model.pt") + +wandb.finish() +``` + +- W&B 提供:用于比较运行的仪表板、超参数扫描工具、模型注册表、数据集版本控制和团队协作。 + +- **MLflow**:开源替代方案。在本地或服务器上运行: + +```python +import mlflow + +mlflow.set_experiment("my-experiment") + +with mlflow.start_run(): + mlflow.log_params({"lr": 3e-4, "batch_size": 64}) + mlflow.log_metric("val_loss", 0.042, step=epoch) + mlflow.pytorch.log_model(model, "model") +``` + +- **模型注册表**:训练模型的中央存储,带版本控制、阶段(开发 → 预发布 → 生产)和元数据。W&B 和 MLflow 都提供注册表。注册表回答:"当前生产环境中的是哪个模型,谁训练的,其验证准确率是多少,以及由哪个代码/数据产生?" + +## 可重现性 + +- 可重现性意味着:给定相同的代码、数据和配置,产生相同的模型。这在机器学习中出奇地困难,因为 GPU 操作的非确定性、数据打乱和浮点数累积。 + +- **可重现性检查清单**: + +| 什么 | 如何做 | +|------|------| +| 代码版本 | Git 提交哈希值 | +| 配置 / 超参数 | 配置文件(在 Git 中版本控制或记录到 W&B) | +| 随机种子 | 设置并记录所有种子(Python、NumPy、PyTorch、CUDA) | +| 数据版本 | DVC 哈希值、数据集版本标签或 S3 对象版本 | +| 依赖项 | `pip freeze`、Docker 镜像哈希值或锁定文件 | +| 硬件 | GPU 类型、GPU 数量、CUDA 版本 | +| 非确定性 | `torch.backends.cudnn.deterministic = True`(较慢但可重现) | + +- **锁定所有内容**:`pip install torch==2.2.1` 而不是 `torch>=2.0`。次版本号升级可能改变数值行为、优化器实现或默认超参数。 + +- **使用 Docker 实现可重现性**:Docker 镜像锁定了操作系统、系统库、Python 版本和 pip 包。镜像哈希值是完整的环境指纹。如果你能重现 Docker 镜像,就能重现训练。 + +## 生产环境监控 + +- 部署模型不是终点——而是一系列新问题的开始。随着现实世界的变化(**概念漂移**)以及输入数据分布的变化(**数据漂移**),模型会随时间推移而退化。 + +- **需要监控的内容**: + + - **延迟**:推理需要多长时间?追踪 p50(中位数)、p95 和 p99。p99 为 500ms 意味着每 100 个用户中有 1 个要等待半秒钟,这可能不可接受。 + + - **吞吐量**:每秒处理多少个请求?系统是否跟得上需求? + + - **错误率**:有多少比例的请求失败(异常、超时、无效输入)? + + - **模型指标**:在验证集上的准确率、精确率、召回率。如果生产环境中存在标注数据(例如用户纠正),追踪在线指标。 + + - **数据漂移**:输入数据的分布是否发生了变化?在白天照片上训练的模型可能在夜间照片上失败。统计检验(KS 检验、PSI)将训练分布与在线分布进行比较。 + + - **特征漂移**:单个特征的分布是否发生了变化?训练时呈正态分布但在生产时呈双峰分布的特征,表明数据管道存在问题。 + +- **工具**: + - **Prometheus** + **Grafana**:基础设施监控的标准方案。Prometheus 收集指标,Grafana 将其可视化为带告警的仪表板。 + - **Evidently AI**:开源机器学习监控。生成关于数据漂移、模型性能和数据质量的报告。 + +- **告警**:不要只放在仪表板上——设置自动告警。"如果 p99 延迟超过 200ms 持续 5 分钟,发送 Slack 通知。""如果数据漂移评分超过阈值,通知值班工程师。" + +## 特征存储 + +- **特征存储**是预计算特征的集中式仓库,在训练和推理之间共享。它解决两个问题: + + - **训练-推理偏差**:训练期间使用的特征必须与推理期间使用的特征完全相同。如果训练使用一种方式计算的 `user_age_at_signup`,而推理使用不同的方式计算,模型的预测结果会静默出错。 + + - **特征复用**:多个模型通常使用相同的特征(用户人口统计、物品嵌入、聚合统计)。计算一次并共享,避免了重复和不一致性。 + +- **Feast** 是最流行的开源特征存储。它管理在线特征(低延迟,从 Redis 或 DynamoDB 提供)和离线特征(批处理,存储在数据仓库中用于训练)。 + +- 特征存储对于推荐系统、欺诈检测以及任何特征从原始数据管道计算而来的应用都至关重要。 + +## 管道编排 + +- 生产级机器学习系统不仅仅是模型。它是一个**管道**:数据采集 → 预处理 → 特征计算 → 训练 → 评估 → 部署 → 监控。每个步骤依赖于前一步骤,可以独立失败,可能需要在不同的时间表上运行。 + +- **编排器**管理这些管道: + +- **Apache Airflow**:数据管道编排的标准方案。DAG(有向无环图)定义任务依赖关系。每个任务独立运行,失败时可以重试,并通过 Web UI 进行监控。 + +```python +# airflow DAG 示例(简化) +from airflow import DAG +from airflow.operators.python import PythonOperator + +dag = DAG("training_pipeline", schedule="@daily") + +preprocess = PythonOperator(task_id="preprocess", python_callable=preprocess_data, dag=dag) +train = PythonOperator(task_id="train", python_callable=train_model, dag=dag) +evaluate = PythonOperator(task_id="evaluate", python_callable=evaluate_model, dag=dag) +deploy = PythonOperator(task_id="deploy", python_callable=deploy_model, dag=dag) + +preprocess >> train >> evaluate >> deploy +``` + +- **Kubeflow Pipelines**:在 Kubernetes 上运行机器学习特定编排。每个步骤在容器中运行,GPU 资源按需分配,实验自动追踪。 + +- **Prefect** 和 **Dagster**:Airflow 的现代替代方案,拥有更好的开发者体验、原生 Python API 和内置数据血缘追踪。 + +- **何时需要编排**:当你的管道有超过 2-3 个步骤、按计划运行、涉及多个团队或服务、或需要自动故障恢复时。单一脚本的训练任务不需要编排器。每天重新训练的管道——从 5 个数据源采集数据、训练 3 个模型、评估它们并部署最佳模型——绝对需要。 diff --git a/chapter 16: SIMD and GPU programming/00. why C++ and how ML frameworks work.md b/chapter 16: SIMD and GPU programming/00. why C++ and how ML frameworks work.md new file mode 100644 index 0000000..4f645af --- /dev/null +++ b/chapter 16: SIMD and GPU programming/00. why C++ and how ML frameworks work.md @@ -0,0 +1,419 @@ +# 为什么是C++以及ML框架如何工作 + +*本书中每一次 `jnp.matmul`、每一次 `torch.nn.Linear`、每一次 `np.dot` 调用,底层都在执行C++和CUDA代码。本文档揭开帷幕:为何ML框架采用这种架构,面向Python工程师的C++快速入门,何时编写自定义C++核函数,以及如何将其绑定到Python——这是连接你所写代码与所运行硬件之间的桥梁。* + +- 你花了15章写Python。你导入了JAX,调用了`jax.grad`,运行了训练循环,构建了模型。一切感觉都像是Python。但事实是:**几乎没有实际计算发生在Python中。** + +- 当你在PyTorch中写 `output = model(input)` 或在JAX中写 `output = jnp.matmul(W, x)` 时,Python几乎什么都不做。它构建一个计算的描述(一个操作图),然后将其交给执行真正工作的C++/CUDA后端。Python是方向盘;C++是引擎。 + +## 为什么Python前端搭配C++后端 + +- 这种双语言架构的存在是因为Python和C++擅长截然不同的事情: + +| | Python | C++ | +|--|--------|-----| +| 开发速度 | 快(动态类型、REPL、无需编译) | 慢(静态类型、头文件、编译时间长) | +| 执行速度 | 比C慢约100倍(解释型、GIL) | 接近硬件速度(编译型、无开销) | +| 内存控制 | 自动(GC),无法控制布局 | 手动,精确控制每一个字节 | +| 硬件访问 | 无(无SIMD、无GPU、无自定义内存) | 全面(内联函数、CUDA、内联汇编) | +| 生态系统 | ML丰富(笔记本、可视化、数据) | 系统丰富(操作系统、驱动、引擎) | + +- 核心见解:**每种语言发挥其优势**。Python处理人力生产力重要的事务(实验设计、超参数调优、数据探索)。C++处理机器性能重要的事务(矩阵乘法、卷积、注意力核函数)。 + +- 一次矩阵乘法 `jnp.matmul(A, B)`,其中 $A$ 为 $4096 \times 4096$,执行约1370亿次浮点运算。在纯Python(嵌套循环)中需要约30分钟。在使用AVX-512 SIMD和多线程优化后的C++中,只需约10毫秒。差距达**180,000倍**。再多的Python技巧也无法弥合这一鸿沟。 + +## ML框架的结构 + +- 每个主流ML框架都遵循相同的架构: + +``` +用户代码(Python) + ↓ +Python API层(torch.nn、jax.numpy、numpy) + ↓ +调度/JIT编译器(torch.compile、XLA、NumPy调度) + ↓ +C++核函数库(ATen/PyTorch、XLA、BLAS/LAPACK) + ↓ +硬件特定后端(CUDA、cuDNN、MKL、oneDNN、Metal) + ↓ +硬件(CPU SIMD单元、GPU核心、TPU MXU) +``` + +### NumPy + +- NumPy的核心用C编写。当你调用 `np.dot(A, B)` 时,Python调用一个C函数,该函数调用BLAS(基本线性代数子程序),通常是Intel MKL或OpenBLAS。BLAS是手工优化的C和Fortran代码,使用SIMD指令、缓存感知的内存访问模式和多线程。数十年优化致力于让矩阵乘法更快。 + +- NumPy仅支持CPU,不使用GPU。但在CPU上,它极其快速,因为它委托给可用的最佳BLAS实现。 + +### PyTorch + +- PyTorch的计算引擎是**ATen**(张量库),用C++编写。ATen实现了约2000个张量操作(add、matmul、conv2d、softmax...),每个都有CPU和CUDA后端。 + +- 当你调用 `torch.matmul(A, B)` 时: + 1. Python调度到ATen的C++函数。 + 2. ATen检查设备(CPU或CUDA)和数据类型。 + 3. 在CPU上:调用MKL/OpenBLAS。在GPU上:调用cuBLAS(NVIDIA的GPU优化BLAS)。 + 4. 结果包装在Python张量对象中并返回。 + +- **torch.compile**(PyTorch 2.0+)更进一步:它追踪你的Python代码,构建计算图,并使用**Triton**(GPU)或**C++/OpenMP**(CPU)编译。编译后的代码融合操作,消除Python开销,可以比即时模式快2-5倍。 + +### JAX + +- JAX将Python函数编译为**XLA**(加速线性代数),Google的ML编译器。当你 `jax.jit` 一个函数时: + 1. JAX追踪函数,将操作捕获为XLA计算图(HLO——高级操作)。 + 2. XLA优化图:融合操作,消除冗余计算,优化内存布局。 + 3. XLA编译为目标后端:CPU(通过LLVM)、GPU(通过CUDA/PTX)或TPU(通过TPU特定指令)。 + 4. 编译后的代码直接在硬件上运行,零Python参与。 + +- 这就是为什么 `jax.jit` 如此重要:没有它,每个操作都是独立的Python→C++往返。有了它,整个函数是一个单一的编译核函数。 + +## 面向Python工程师的C++快速入门 + +- 你不需要成为C++专家。你需要理解足够的知识来阅读核函数代码、编写简单的扩展以及理解性能讨论。以下是精华内容。 + +### 类型和变量 + +```cpp +// C++需要显式类型(不像Python) +int count = 0; // 32位整数 +float loss = 0.5f; // 32位浮点数 +double lr = 3e-4; // 64位浮点数 +bool training = true; // 布尔值 + +// 数组(固定大小,栈分配) +float weights[1024]; // 1024个浮点数,内存中连续 + +// 指针:保存内存地址的变量 +float* ptr = weights; // ptr指向weights的第一个元素 +float val = ptr[42]; // 通过指针运算访问元素42 +// ptr[42] 等价于 *(ptr + 42) +``` + +- **指针**是与Python最大的概念差异。在Python中,一切都是引用,你从不需要思考内存地址。在C++中,指针让你直接访问内存——强大但危险(悬空指针、缓冲区溢出)。 + +### 函数 + +```cpp +// 函数声明:返回类型 名字(参数类型 参数名) +float relu(float x) { + return x > 0.0f ? x : 0.0f; +} + +// 传引用(避免拷贝大对象) +void scale_vector(std::vector& vec, float factor) { + for (size_t i = 0; i < vec.size(); i++) { + vec[i] *= factor; + } +} + +// const引用:只读,无拷贝 +float sum(const std::vector& vec) { + float total = 0.0f; + for (float x : vec) { // 基于范围的for循环(类似Python的for x in vec) + total += x; + } + return total; +} +``` + +### 内存:栈与堆 + +```cpp +// 栈分配:快速,自动生命周期(函数返回时释放) +float buffer[256]; // 栈上的256个浮点数 + +// 堆分配:手动,在函数外仍然存活 +float* data = new float[n]; // 在堆上分配n个浮点数 +// ... 使用data ... +delete[] data; // 必须手动释放(没有垃圾回收器) + +// 现代C++:智能指针(自动清理,类似Python引用) +#include +auto data = std::make_unique(n); // 离开作用域时自动释放 +``` + +- **关键规则**:栈快速但有限(通常1-8 MB)。大数组(张量、特征图)必须放在堆上。在Python中,一切都在堆上,GC处理清理。在C++中,你自行管理(或使用智能指针)。 + +### 模板(泛型) + +```cpp +// 适用于任何数值类型的函数 +template +T add(T a, T b) { + return a + b; +} + +add(1.5f, 2.5f); // 返回 4.0f +add(3, 4); // 返回 7 +``` + +- 模板是C++库(如ATen)编写适用于float16、float32、float64等的代码而不重复实现的方式。 + +### 标准库精华 + +```cpp +#include // 动态数组(类似Python list) +#include // 字符串类型 +#include // 哈希映射(类似Python dict) +#include // sort、find、transform等 +#include // 数学函数 + +std::vector vec = {1.0f, 2.0f, 3.0f}; +vec.push_back(4.0f); // 追加 +float first = vec[0]; // 索引 +size_t len = vec.size(); // 长度 + +std::unordered_map counts; +counts["hello"] = 5; // 插入 +if (counts.count("hello")) { } // 检查存在性 +``` + +## 何时编写自定义C++核函数 + +- 大多数ML工程师从不需要写C++。框架的内置操作覆盖了99%的用例。仅在以下情况考虑自定义C++: + +1. **框架中不存在你的操作**:新颖的激活函数、自定义注意力模式、无法表示为现有操作组合的特殊损失函数。 + +2. **融合操作以提高性能**:你的模型执行 `relu(layernorm(matmul(x, W) + b))`。每个操作启动一个独立的核函数,读写内存,并同步。一个融合核函数在一次遍历中完成所有工作,避免内存往返。这可快2-5倍。 + +3. **减少内存使用**:自定义核函数可以在不存储所有中间激活的情况下计算梯度(核函数级别的梯度检查点)。 + +4. **针对新型硬件**:新的加速器(如Cerebras、Groq)可能没有框架支持。你需要直接编写核函数。 + +- 对于情况1-2,**Triton**(第16章文件05)通常足够且比直接编写CUDA C更简单。只有在Triton无法表达你的需求时才下降到CUDA C。 + +## 如何将C++绑定到Python + +- 编写C++只是工作的一半。你还需要从Python调用它。 + +### pybind11(通用目的) + +- pybind11用最少的样板代码为C++函数创建Python绑定: + +```cpp +// my_ops.cpp +#include +#include +namespace py = pybind11; + +// 一个简单的自定义操作 +py::array_t custom_relu(py::array_t input) { + auto buf = input.request(); + float* ptr = static_cast(buf.ptr); + size_t n = buf.size; + + auto result = py::array_t(n); + float* out = static_cast(result.request().ptr); + + for (size_t i = 0; i < n; i++) { + out[i] = ptr[i] > 0 ? ptr[i] : 0; + } + return result; +} + +PYBIND11_MODULE(my_ops, m) { + m.def("custom_relu", &custom_relu, "自定义ReLU操作"); +} +``` + +```bash +# 编译 +pip install pybind11 +c++ -O3 -shared -std=c++17 -fPIC $(python3 -m pybind11 --includes) my_ops.cpp -o my_ops$(python3-config --extension-suffix) +``` + +```python +# 从Python使用 +import my_ops +import numpy as np + +x = np.array([-1.0, 2.0, -3.0, 4.0], dtype=np.float32) +y = my_ops.custom_relu(x) +print(y) # [0. 2. 0. 4.] +``` + +### PyTorch C++扩展 + +- PyTorch提供了一种简化的方式来添加自定义操作: + +```cpp +// custom_op.cpp +#include + +torch::Tensor custom_gelu(torch::Tensor x) { + return x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0))); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("custom_gelu", &custom_gelu, "自定义GELU激活函数"); +} +``` + +```python +# 动态加载和编译 +from torch.utils.cpp_extension import load + +custom_ops = load( + name="custom_ops", + sources=["custom_op.cpp"], + extra_cflags=["-O3"], +) + +x = torch.randn(1000) +y = custom_ops.custom_gelu(x) +``` + +- `torch.utils.cpp_extension.load` 编译C++代码,创建共享库,并将其作为Python模块加载,全在一个调用中完成。这是在PyTorch中实验自定义C++操作的最简单方式。 + +### JAX自定义调用 + +- JAX使用XLA自定义调用。过程更为复杂(你需要向XLA注册一个C函数),但概念相同:编写C/C++,绑定,从Python调用。 + +- 对于大多数JAX用户,**Pallas**(在文件05中介绍)是更好的选择:它让你用类似Python的语法编写GPU核函数,由XLA编译,无需离开JAX生态系统。 + +## 大局观 + +- 本文解释了Python和硬件之间的层次。本章剩余文件将深入探讨: + - **文件01**:硬件本身(CPU架构、GPU架构、内存系统) + - **文件02-03**:CPU上的SIMD编程(ARM NEON、x86 AVX)——编写使用CPU向量单元的C++代码 + - **文件04**:使用CUDA的GPU编程——编写在数千个GPU核心上运行的C++代码 + - **文件05**:Triton、Pallas和更高级的GPU编程——编写编译为GPU核函数的Python代码 + +- 这种递进反映了抽象阶梯:C++内联函数(最低层、最多控制)→ CUDA(GPU专用)→ Triton/Pallas(Python风格、编译型)→ JAX/PyTorch(最高层、自动)。每一层以控制权换取便利性。理解较低层使你成为较高层的更好使用者。 + +## 编程任务(用g++或clang++编译) + +1. 编写你的第一个C++程序。分配一个数组,填充数据,计算总和,并测量时间。这介绍了编译、数组、指针和计时。 +```cpp +// task1_basics.cpp +// 编译:g++ -O3 -o task1 task1_basics.cpp +// 运行:./task1 + +#include +#include +#include + +int main() { + const int N = 10'000'000; // C++允许'作为数字分隔符 + std::vector data(N); + + // 填充数组 + for (int i = 0; i < N; i++) { + data[i] = static_cast(i) * 0.001f; + } + + // 计算总和 + auto start = std::chrono::high_resolution_clock::now(); + float sum = 0.0f; + for (int i = 0; i < N; i++) { + sum += data[i]; + } + auto end = std::chrono::high_resolution_clock::now(); + double elapsed = std::chrono::duration(end - start).count(); + + std::cout << "总和: " << sum << std::endl; + std::cout << "时间: " << elapsed << " ms" << std::endl; + std::cout << "元素数: " << N << std::endl; + std::cout << "吞吐量: " << (N * sizeof(float)) / elapsed / 1e6 << " GB/s" << std::endl; + + return 0; +} +``` + +2. 编写一个C++函数在数组上计算ReLU,然后使用pybind11构建Python绑定。从Python调用它并与NumPy比较速度。 +```cpp +// task2_relu.cpp +// 编译:c++ -O3 -shared -std=c++17 -fPIC $(python3 -m pybind11 --includes) \ +// task2_relu.cpp -o my_relu$(python3-config --extension-suffix) + +#include +#include +namespace py = pybind11; + +py::array_t cpp_relu(py::array_t input) { + auto buf = input.request(); + float* ptr = static_cast(buf.ptr); + int n = buf.size; + + auto result = py::array_t(n); + float* out = static_cast(result.request().ptr); + + for (int i = 0; i < n; i++) { + out[i] = ptr[i] > 0.0f ? ptr[i] : 0.0f; + } + return result; +} + +PYBIND11_MODULE(my_relu, m) { + m.def("relu", &cpp_relu, "C++ ReLU"); +} +``` +```python +# test_relu.py — 在编译上述C++模块后运行 +import numpy as np +import time +import my_relu # 编译后的C++模块 + +x = np.random.randn(10_000_000).astype(np.float32) + +# C++ ReLU +start = time.time() +for _ in range(100): + y_cpp = my_relu.relu(x) +cpp_time = (time.time() - start) / 100 + +# NumPy ReLU +start = time.time() +for _ in range(100): + y_np = np.maximum(x, 0) +np_time = (time.time() - start) / 100 + +print(f"C++ ReLU: {cpp_time*1000:.2f} ms") +print(f"NumPy ReLU: {np_time*1000:.2f} ms") +print(f"匹配: {np.allclose(y_cpp, y_np)}") +``` + +3. 编写一个C++程序,演示为何内存布局很重要。比较行优先与列优先访问模式并测量性能差异。 +```cpp +// task3_layout.cpp +// 编译:g++ -O3 -o task3 task3_layout.cpp + +#include +#include +#include + +int main() { + const int N = 4096; + std::vector matrix(N * N, 1.0f); + + // 行优先访问:连续内存地址(缓存友好) + auto start = std::chrono::high_resolution_clock::now(); + float sum_row = 0.0f; + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + sum_row += matrix[i * N + j]; // 步长1访问 + } + } + auto end = std::chrono::high_resolution_clock::now(); + double row_ms = std::chrono::duration(end - start).count(); + + // 列优先访问:步长N访问(缓存不友好) + start = std::chrono::high_resolution_clock::now(); + float sum_col = 0.0f; + for (int j = 0; j < N; j++) { + for (int i = 0; i < N; i++) { + sum_col += matrix[i * N + j]; // 步长N访问(缓存缺失!) + } + } + end = std::chrono::high_resolution_clock::now(); + double col_ms = std::chrono::duration(end - start).count(); + + std::cout << "行优先(缓存友好): " << row_ms << " ms" << std::endl; + std::cout << "列优先(缓存不友好): " << col_ms << " ms" << std::endl; + std::cout << "减速比: " << col_ms / row_ms << "x" << std::endl; + std::cout << "(两个和: " << sum_row << ", " << sum_col << ")" << std::endl; + + return 0; +} +``` diff --git a/chapter 16: SIMD and GPU programming/01. hardware fundamentals.md b/chapter 16: SIMD and GPU programming/01. hardware fundamentals.md new file mode 100644 index 0000000..b0cd35c --- /dev/null +++ b/chapter 16: SIMD and GPU programming/01. hardware fundamentals.md @@ -0,0 +1,282 @@ +# 硬件基础 + +*在编写SIMD或GPU代码之前,你需要了解所编程的硬件。本文涵盖为什么并行性取代了时钟速度、现代CPU如何执行指令、什么是SIMD、用于推理性能的屋顶线模型,以及芯片架构的全景* + +- 几十年来,软件免费变快:购买一个时钟频率更高的新CPU,你的程序无需修改一行代码就能运行得更快。这个时代大约在2005年结束。理解它为何结束以及什么替代了它,对任何想编写快速代码的人都至关重要。 + +## 免费性能的终结 + +- **摩尔定律**(1965年)观察到芯片上的晶体管数量大约每两年翻一番。这一规律维持了60年。更多晶体管意味着更小的晶体管,进而意味着更高的时钟频率,从而意味着更快的程序。 + +- 但在2005年左右,时钟频率在大约4 GHz处撞上了墙壁。问题是**功耗**。芯片消耗的功率大约为: + +$$P \propto C \cdot V^2 \cdot f$$ + +- 其中 $C$ 是电容(与晶体管数量成正比),$V$ 是电压,$f$ 是时钟频率。要提高频率,必须提高电压(以使晶体管更快地切换)。但功耗与 $V^2 \cdot f$ 成比例,所以频率的小幅增加会导致功耗(和热量)的大幅增加。在4 GHz时,芯片已经达到100+瓦。达到8 GHz需要不切实际的冷却方案。 + +- 解决方案:不让单个核心更快,而是在同一芯片上放置**多个核心**。一个4核芯片在3 GHz下使用与单个核心在4.5 GHz下相似的功耗,但可以做4倍的并行工作。这就是为什么每个现代CPU都有多个核心,以及为什么并行性(SIMD、多线程、GPU计算)是获得更高性能的唯一途径。 + +- **对ML的影响**:一个在单核上需要10分钟的训练步骤,无法通过购买更快的CPU来加速。只能通过使用更多核心(数据并行性,第6章)、更宽的SIMD单元(本章)或GPU(数千个核心)来加速。 + +## 现代CPU如何执行指令 + +- 现代CPU核心远比第13章中简单的取指-译码-执行模型复杂。它使用几种技巧来每周期执行更多指令: + +- **超标量执行**:CPU有多个执行单元(ALU、FPU、加载/存储单元),可以同时执行多个独立的指令。如果指令不相互依赖,现代核心每周期可能执行4-6条指令。 + +- **乱序执行(OoO)**:CPU不按程序顺序执行指令。它向前看指令流,找到输入已准备好的指令,并立即执行,不论其位置。这隐藏了延迟:当一条指令等待来自内存的数据时(100+周期),CPU执行其他已准备好的指令。 + +- **分支预测**:条件分支(`if`语句、循环条件)造成不确定性:CPU在条件被评估之前不知道走哪条路径。为了避免停顿,CPU**预测**结果并沿预测路径投机执行。如果预测正确(使用现代预测器超过95%),则没有时间损失。如果错误,投机工作被丢弃,执行正确路径(约15周期惩罚)。 + +- **投机执行**:分支预测的延伸。CPU执行*可能*不需要的指令,赌它们会被需要。这填充了流水线并保持执行单元忙碌。 + +- 所有这些都是自动的——CPU无需任何程序员干预即可完成。但它们只帮助**指令级并行性(ILP)**:单条流内相互独立的指令。对于**数据级并行性**(对许多数据元素执行相同操作),我们需要SIMD。 + +## SIMD:单指令多数据 + +- **SIMD**是将一条指令同时应用于多个数据元素的思想。不是将两个数相加,而是在一条指令中将两个4(或8、或16)元素向量相加。 + +- 无SIMD(标量): + +```cpp +// 逐元素相加两数组:4条加法指令 +for (int i = 0; i < 4; i++) { + c[i] = a[i] + b[i]; // 每次迭代一次加法 +} +``` + +- 有SIMD(向量化): + +```cpp +// 两数组相加:1条SIMD指令完成所有4次加法 +#include // x86 SIMD内联函数 + +__m128 va = _mm_load_ps(a); // 加载4个浮点数到128位寄存器 +__m128 vb = _mm_load_ps(b); // 加载4个浮点数到另一个寄存器 +__m128 vc = _mm_add_ps(va, vb); // 同时相加所有4对 +_mm_store_ps(c, vc); // 存储4个结果 +``` + +- SIMD版本用1/4的指令完成相同工作。这是理论上的4倍加速,通过每条指令处理4个浮点数而非1个实现。 + +### 向量寄存器 + +- SIMD指令操作**向量寄存器**:保存多个数据元素的宽寄存器。 + +| 寄存器宽度 | 浮点数(32位) | 双精度浮点数(64位) | 名称 | +|---------------|-----------------|-------------------|------| +| 128位 | 4 | 2 | SSE(x86)、NEON(ARM) | +| 256位 | 8 | 4 | AVX/AVX2(x86) | +| 512位 | 16 | 8 | AVX-512(x86) | +| 可变(128-2048) | 可变 | 可变 | SVE/SVE2(ARM) | + +- 更宽的寄存器 = 更多并行性。一条512位AVX-512指令一次处理16个浮点数,是标量代码理论上的16倍加速。实际上,由于内存带宽限制(计算速度可能超过向CPU输送数据的速度),加速比更低。 + +- 对于ML:float32值的矩阵乘法从SIMD中获益巨大。内循环(两个向量的点积)直接映射到SIMD乘加指令。这就是为什么BLAS库(NumPy和PyTorch调用的)用SIMD进行了如此深度优化。 + +## 屋顶线模型 + +- 你如何知道你的代码是否快速?**屋顶线模型**提供了一个框架,根据两个硬件限制来描述性能: + +1. **峰值计算能力**(FLOPS):每秒最大浮点运算次数。对于一个4 GHz CPU,配备256位AVX(每条指令8个浮点数)和2个FMA单元:$4 \times 10^9 \times 8 \times 2 = 64$ GFLOPS。 + +2. **峰值内存带宽**(字节/秒):数据从内存到CPU的最大传输速度。现代CPU可能有50 GB/s的内存带宽。 + +- 代码的**算术强度**是计算与内存访问的比率: + +$$\text{算术强度} = \frac{\text{FLOPS}}{\text{传输的字节数}}$$ + +- 如果算术强度低(每加载字节的操作数少),你的代码是**内存受限的**:大部分时间花在等待数据上。让计算更快(更宽的SIMD、更高的时钟)不会有帮助。 + +- 如果算术强度高(每字节多次操作),你的代码是**计算受限的**:大部分时间花在计算上。更快的内存不会有帮助。 + +- 屋顶线: + +$$\text{可达FLOPS} = \min\left(\text{峰值FLOPS}, \; \text{带宽} \times \text{算术强度}\right)$$ + +- **矩阵乘法**具有高算术强度:$O(n^3)$ 次操作作用于 $O(n^2)$ 数据,因此强度 $\approx O(n)$。对于大矩阵,它是计算受限的。这就是为什么GPU(高计算能力)主导矩阵密集型的ML工作负载。 + +- **逐元素操作**(ReLU、加法、乘法)具有低算术强度:每加载一个元素1次操作。这些是内存受限的。让GPU更快没有帮助;你需要更快的内存(或者将这些操作与计算密集型操作融合,以避免独立的内存往返)。 + +- 屋顶线模型解释了为什么**核函数融合**如此重要:将matmul与偏置加法和ReLU组合成一个核函数,避免了将中间结果写入内存并重新读取,将三个内存受限操作转化为一个计算受限操作。 + +## 延迟与吞吐量 + +- **延迟**是完成一个操作所需的时间。**吞吐量**是单位时间内完成的操作数量。 + +- 打个比方:公交车延迟高(每站都停),但吞吐量高(一次搭载50人)。出租车延迟低(直达你的目的地),但吞吐量低(搭载1-4人)。 + +- GPU是公交车:每次操作延迟高(每条指令需要许多周期完成),但吞吐量巨大(数千个核心同时处理)。CPU是出租车:延迟低(乱序执行、分支预测、深层缓存最小化延迟),但吞吐量有限(4-64个核心)。 + +- 这就是为什么GPU更适合ML训练(吞吐量重要:处理数百万个样本)而CPU更适合操作系统任务(延迟重要:立即响应按键)。 + +- **流水线**将延迟转化为吞吐量。如果一条指令需要5个周期,但流水线每周期开始一条新指令,则吞吐量是每条指令1周期(即使每条指令需要5个周期完成)。这和第13章的CPU流水线是同一原理,但它适用于每个层面:SIMD单元、内存控制器和GPU核心都是流水线化的。 + +## 芯片架构全景 + +- 你编写代码的硬件决定了哪些SIMD指令可用: + +### x86(Intel, AMD) + +- 主导台式机、笔记本电脑和数据中心CPU。SIMD:SSE(128位)、AVX/AVX2(256位)、AVX-512(512位)。Intel AMX提供专用的矩阵乘法单元用于AI工作负载。 + +- **优势**:最高单核性能、最宽SIMD、成熟的软件生态系统(MKL、oneDNN)。 +- **弱点**:高功耗、复杂指令集、昂贵。 + +### ARM + +- 主导移动设备(每一部智能手机),在服务器(AWS Graviton、Ampere Altra)和笔记本电脑(Apple M系列)中增长。SIMD:NEON(128位)、SVE/SVE2(可伸缩,128-2048位)。 + +- **优势**:出色的功耗效率(每瓦性能)、自定义核心(Apple M4在单核性能上媲美Intel,功耗仅为其一小部分)。 +- **弱点**:较窄的SIMD(NEON仅为128位,虽SVE可更宽)、用于HPC的软件生态系统较小。 + +### Apple Silicon(M1/M2/M3/M4) + +- 基于ARM并带有自定义扩展。包含**AMX**(Apple矩阵扩展)——未公开的矩阵乘法单元,Accelerate框架将其用于BLAS操作。统一内存架构:CPU和GPU共享同一物理内存,消除了CPU↔GPU拷贝的瓶颈。 + +- **对于ML**:Apple的神经网络引擎(16核,专用ML加速器)和统一内存使M系列芯片在本地ML推理和小规模训练方面出奇地强大。不过没有CUDA:你必须使用Metal(Apple的GPU API)或MLX(Apple的ML框架)。 + +### RISC-V + +- 开源ISA。无许可费用(不像ARM)。在嵌入式系统、物联网和研究领域增长。SIMD:"V"(向量)扩展提供类似于ARM SVE的可伸缩向量处理。 + +- **对于ML**:在ML工作负载上尚不能与x86/ARM竞争,但值得关注。几个AI加速器初创公司使用RISC-V核心。 + +### GPU(NVIDIA、AMD、Intel) + +- 在文件04-05中深入介绍。数千个为吞吐量优化的简单核心。NVIDIA以CUDA主导ML;AMD以ROCm竞争;Intel以Arc GPU和Gaudi加速器进入市场。 + +### TPU(Google) + +- 专门为ML设计的自定义ASIC。为矩阵乘法优化的脉动阵列。在文件05中介绍。 + +## 热与功耗约束 + +- 性能最终受限于功耗和散热: + +- **TDP**(热设计功耗):芯片可以持续消耗的最大功率。笔记本电脑CPU可能有15W TDP;服务器CPU 250W;数据中心GPU 700W(NVIDIA B200)。 + +- **暗硅**:在任何给定时刻,为了保持在热预算内,必须关闭芯片的相当大一部分晶体管。理论上芯片可以同时使用所有晶体管,但会熔化。 + +- **功耗效率**(FLOPS/瓦)日益成为重要指标,而非原始FLOPS。这就是为什么: + - ARM正在接管数据中心(相比x86更好的FLOPS/瓦)。 + - TPU尽管峰值FLOPS较低,但仍与GPU竞争(对于ML工作负载,FLOPS/瓦好得多)。 + - 量化(INT8、FP8)不仅关乎内存:它也降低了每次操作的功耗。 + +- 对于大规模ML:训练前沿LLM数月消耗兆瓦级功率。电费可能超过硬件成本。功耗效率直接影响AI研究的经济性。 + +## 实践:在C++中测量性能 + +- 要推理性能,你需要测量它。以下是一个最小的C++基准测试设置: + +```cpp +#include +#include +#include + +// 标量加法 +void add_scalar(const float* a, const float* b, float* c, int n) { + for (int i = 0; i < n; i++) { + c[i] = a[i] + b[i]; + } +} + +int main() { + const int N = 1 << 24; // 约1600万个元素 + std::vector a(N, 1.0f), b(N, 2.0f), c(N); + + // 预热(填充缓存,触发频率缩放) + add_scalar(a.data(), b.data(), c.data(), N); + + // 基准测试 + auto start = std::chrono::high_resolution_clock::now(); + + for (int trial = 0; trial < 100; trial++) { + add_scalar(a.data(), b.data(), c.data(), N); + } + + auto end = std::chrono::high_resolution_clock::now(); + double elapsed = std::chrono::duration(end - start).count(); + + double total_bytes = 3.0 * N * sizeof(float) * 100; // 读a、读b、写c + double bandwidth = total_bytes / elapsed / 1e9; // GB/s + + std::cout << "时间: " << elapsed << " s\n"; + std::cout << "带宽: " << bandwidth << " GB/s\n"; + + return 0; +} +``` + +```bash +# 启用优化编译 +g++ -O3 -march=native -o bench bench.cpp +./bench +``` + +- **这段代码中的关键C++概念**: + - `#include `:动态数组(`std::vector`)——类似Python的`list`但带类型且在内存中连续。 + - `a.data()`:返回底层数组的原始指针(`float*`)——SIMD内联函数需要。 + - `std::chrono`:用于基准测试的高分辨率计时器。 + - `-O3`:最高编译器优化级别。编译器可能自动向量化你的循环(自动使用SIMD)。`-march=native` 启用你的CPU支持的所有SIMD指令。 + +- **为什么需要预热**:首次运行填充缓存并可能触发CPU频率缩放(睿频加速)。后续运行更具代表性。 + +- **为什么测量带宽**:对于内存受限的操作(如逐元素加法),有意义的度量是带宽(GB/s),而不是FLOPS。如果你的测量带宽接近硬件极限(DDR5约50 GB/s),你是内存受限的,SIMD不会有多大帮助(瓶颈是内存,而非计算)。 + +## 编程任务(使用CoLab或笔记本) + +1. 计算常见ML操作的算术强度,并将它们分类为内存受限或计算受限。 +```python +import jax.numpy as jnp + +def arithmetic_intensity(flops, bytes_transferred): + return flops / bytes_transferred + +# 逐元素ReLU:每元素1次比较,读取+写入 +n = 1024 +relu_flops = n # 每元素1次操作 +relu_bytes = 2 * n * 4 # 读取输入+写入输出(float32) +print(f"ReLU: {arithmetic_intensity(relu_flops, relu_bytes):.2f} FLOPS/byte → 内存受限") + +# 矩阵乘法:2*n^3次操作,读取2*n^2 + 写入n^2个浮点数 +matmul_flops = 2 * n**3 +matmul_bytes = 3 * n**2 * 4 # 读取A + 读取B + 写入C +print(f"矩阵乘法 ({n}×{n}): {arithmetic_intensity(matmul_flops, matmul_bytes):.0f} FLOPS/byte → 计算受限") + +# 层归一化:约5n次操作(均值、方差、归一化),读取+写入 +ln_flops = 5 * n +ln_bytes = 2 * n * 4 +print(f"LayerNorm: {arithmetic_intensity(ln_flops, ln_bytes):.2f} FLOPS/byte → 内存受限") + +# 3x3卷积:2*9*C_in*C_out*H*W,读取卷积核+特征图+写入输出 +C_in, C_out, H, W = 64, 128, 32, 32 +conv_flops = 2 * 9 * C_in * C_out * H * W +conv_bytes = (9 * C_in * C_out + C_in * H * W + C_out * H * W) * 4 +print(f"Conv3x3: {arithmetic_intensity(conv_flops, conv_bytes):.0f} FLOPS/byte → 计算受限") +``` + +2. 演示为什么并行性重要。比较顺序执行与并行(NumPy)执行随数据规模增长的表现。 +```python +import numpy as np +import time + +for n in [1000, 10000, 100000, 1000000, 10000000]: + a = np.random.randn(n).astype(np.float32) + b = np.random.randn(n).astype(np.float32) + + # "顺序执行"(Python循环) + start = time.time() + c = [a[i] * b[i] for i in range(min(n, 100000))] # 上限10万以确保合理 + seq_time = time.time() - start + if n > 100000: + seq_time *= n / 100000 # 外推 + + # "并行"(NumPy,内部使用SIMD+多线程) + start = time.time() + c = a * b + par_time = time.time() - start + + print(f"n={n:>10,} 顺序={seq_time:.4f}s 并行={par_time:.6f}s " + f"加速比={seq_time/par_time:.0f}x") +``` diff --git a/chapter 16: SIMD and GPU programming/02. ARM and NEON.md b/chapter 16: SIMD and GPU programming/02. ARM and NEON.md new file mode 100644 index 0000000..24a9247 --- /dev/null +++ b/chapter 16: SIMD and GPU programming/02. ARM and NEON.md @@ -0,0 +1,484 @@ +# ARM与NEON + +*ARM处理器驱动着每一部智能手机、大多数平板电脑、Apple的笔记本电脑以及日益增长的数据中心服务器份额。本文涵盖ARM架构、使用C++内联函数的NEON SIMD编程、用于可伸缩向量处理的SVE/SVE2、Apple Silicon特性以及实际向量化核函数示例* + +- 如果你拥有iPhone、MacBook或使用AWS Graviton实例,你正在运行ARM。ARM的功耗效率使其在移动和嵌入式领域占据主导地位,并在服务器和ML推理方面日益具有竞争力。理解ARM SIMD让你能够编写在大多数人实际使用的硬件上快速运行的代码。 + +- 有关生产中ARM SIMD核函数的实际例子,请参见**Cactus**——面向移动设备和可穿戴设备的低延迟AI引擎:[github.com/cactus-compute/cactus](https://github.com/cactus-compute/cactus)。Cactus实现了自定义ARM NEON和NPU加速的注意机制、KV缓存量化和分块预填充核函数,在ARM CPU上实现了最快的推理,且RAM比其它引擎低10倍。其三层架构(引擎→图→核函数)是本文中SIMD概念如何用于构建生产级ML基础设施的具体实例。 + +## ARM架构基础 + +- ARM是一种**RISC**(精简指令集计算机)架构(第13章)。关键特征: + + - **加载-存储架构**:算术指令只操作寄存器,从不直接操作内存。要对内存中的两个数相加,你必须:(1) 将它们加载到寄存器,(2) 将寄存器相加,(3) 将结果存回内存。这比x86更简单(x86可以在一条指令中加一个寄存器和一个内存位置),但使得流水线更清晰。 + + - **定长指令**:每个ARMv8(AArch64)指令恰好32位。这使得解码快速且可预测(不像x86的可变长指令,长度可以是1-15字节)。 + + - **32个通用寄存器**(x0-x30,每个64位)加上栈指针(sp)和零寄存器(xzr)。相比之下x86有16个通用寄存器。更多寄存器 = 更少内存访问 = 更快代码。 + + - **32个SIMD/浮点寄存器**(v0-v31,每个128位)用于NEON和浮点操作。 + +```cpp +// ARM汇编(仅感受风格——你将使用内联函数,而非汇编) +// 两寄存器相加 +add x0, x1, x2 // x0 = x1 + x2 + +// 从内存加载 +ldr x0, [x1] // x0 = *x1(从x1中的地址加载64位) + +// NEON:加四个浮点数 +fadd v0.4s, v1.4s, v2.4s // v0 = v1 + v2(四个32位浮点数) +``` + +- 你不会写汇编。你将使用**内联函数**:与特定指令一对一映射的C/C++函数。编译器处理寄存器分配、调度和其他底层细节。 + +## NEON:128位SIMD + +- **NEON**是ARM的SIMD扩展。每个NEON寄存器宽128位,可容纳: + +| 数据类型 | 每寄存器元素数 | 表示法 | +|-----------|---------------|----------| +| float32 | 4 | `float32x4_t` | +| float16 | 8 | `float16x8_t` | +| int32 | 4 | `int32x4_t` | +| int16 | 8 | `int16x8_t` | +| int8 | 16 | `int8x16_t` | + +- 128位比x86的AVX(256位)或AVX-512(512位)窄。但ARM以出色的功耗效率和广泛的可用性弥补了这一点。 + +### NEON内联函数:基础 + +- NEON内联函数遵循命名约定:`v[操作][限定符]_[类型]` + +```cpp +#include + +// 从内存加载4个浮点数到NEON寄存器 +float32x4_t a = vld1q_f32(ptr); // vld1q = vector load 1, q = 128位(四字) + +// 从NEON寄存器存储4个浮点数到内存 +vst1q_f32(out_ptr, a); // vst1q = vector store 1, q = 128位 + +// 算术运算 +float32x4_t c = vaddq_f32(a, b); // c = a + b(4个浮点数) +float32x4_t d = vmulq_f32(a, b); // d = a * b(4个浮点数) +float32x4_t e = vfmaq_f32(c, a, b); // e = c + a * b(融合乘加,4个浮点数) + +// 比较(返回掩码:若真则全1,若假则全0) +uint32x4_t mask = vcgtq_f32(a, b); // mask[i] = (a[i] > b[i]) ? 0xFFFFFFFF : 0 + +// 基于掩码选择元素(类似numpy.where) +float32x4_t result = vbslq_f32(mask, a, b); // result[i] = mask[i] ? a[i] : b[i] + +// 归约:将所有4个元素求和为标量 +float total = vaddvq_f32(a); // total = a[0] + a[1] + a[2] + a[3] +``` + +- **`vfmaq_f32`**(融合乘加)是ML最重要的SIMD指令。它用一次舍入步骤计算 $c = c + a \times b$(比分开乘然后加更精确)。点积、矩阵乘法和卷积都由FMA构建。 + +### 实践示例:向量化点积 + +- 点积是矩阵乘法的内循环。让我们先用标量C++编写,然后用NEON向量化。 + +```cpp +#include + +// 标量点积 +float dot_scalar(const float* a, const float* b, int n) { + float sum = 0.0f; + for (int i = 0; i < n; i++) { + sum += a[i] * b[i]; + } + return sum; +} + +// NEON向量化点积 +float dot_neon(const float* a, const float* b, int n) { + float32x4_t sum_vec = vdupq_n_f32(0.0f); // 初始化4个累加器为0 + + int i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t va = vld1q_f32(a + i); // 从a加载4个元素 + float32x4_t vb = vld1q_f32(b + i); // 从b加载4个元素 + sum_vec = vfmaq_f32(sum_vec, va, vb); // sum_vec += va * vb + } + + // 将4个累加器归约为单一标量 + float sum = vaddvq_f32(sum_vec); + + // 处理剩余元素(如果n不是4的倍数) + for (; i < n; i++) { + sum += a[i] * b[i]; + } + + return sum; +} +``` + +- **关键C++概念**: + - `const float*`:指向只读浮点数据的指针。`const` 承诺我们不会通过此指针修改数据。 + - `a + i`:指针运算。`a + i` 指向数组的第 $i$ 个元素(等同于 `&a[i]`)。 + - 末尾的"清理循环"处理 $n$ 不是4的倍数的情况。这是SIMD代码中的通用模式:用向量化块处理主体部分,然后用标量代码处理余数。 + +- **为什么 `sum_vec` 中使用4个累加器**:我们使用4个独立的累加器(每个SIMD通道一个),而不是单个标量累加器。这避免了数据依赖:每次迭代的FMA依赖于 `sum_vec`,但有了4个独立通道,CPU可以对FMAs进行流水线处理。最后,我们将4个部分和归约为一个。 + +### 实践示例:向量化ReLU + +```cpp +#include + +void relu_neon(const float* input, float* output, int n) { + float32x4_t zero = vdupq_n_f32(0.0f); + + int i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t x = vld1q_f32(input + i); + float32x4_t result = vmaxq_f32(x, zero); // max(x, 0) = ReLU + vst1q_f32(output + i, result); + } + + // 标量清理 + for (; i < n; i++) { + output[i] = input[i] > 0 ? input[i] : 0; + } +} +``` + +- `vmaxq_f32` 计算两个向量的逐元素最大值。由于一个向量全为零,这恰好就是ReLU。无需分支,无需比较——仅一条指令。 + +## I8MM:整数矩阵乘法 + +- **I8MM**(Int8矩阵乘法)是ARMv8.6扩展,增加了用于INT8矩阵乘法(INT32累加)的专用指令——这正是量化ML推理所需要的。 + +- 关键指令是 **`SMMLA`**(有符号矩阵乘加):它接受两个8×2块的INT8值,并将结果累加到2×2块的INT32中: + +```cpp +#include + +// I8MM:将两个8元素INT8向量相乘,累加到4个INT32结果中 +// 这从2x8 × 8x2输入块计算输出矩阵的一个2x2瓦片 +void matmul_i8mm_tile(const int8_t* A, const int8_t* B, int32_t* C) { + // 从A加载8字节(2行各4元素,打包) + int8x16_t va = vld1q_s8(A); // 16字节 = 2行 × 8元素 + int8x16_t vb = vld1q_s8(B); // 16字节 = 2行 × 8元素 + + // 加载现有累加器(2x2 = 4个int32值) + int32x4_t acc = vld1q_s32(C); + + // I8MM指令:acc += A_tile × B_tile^T + // 从2×8 × 8×2输入计算2×2输出 + acc = vmmlaq_s32(acc, va, vb); // I8MM指令 + + vst1q_s32(C, acc); +} +``` + +- **为什么I8MM重要**:没有I8MM时,NEON上的INT8矩阵乘法需要加宽乘法(`vmull`)后跟成对加法——每个输出元素需要多条指令。有了I8MM,硬件在一条指令中完成8元素点积(2×8 × 8×2 = 2×2)。对于INT8推理工作负载,这比纯NEON快4-8倍。 + +- **可用性**:Apple M1+(所有Apple Silicon)、ARM Cortex-A510/A710/X2+(ARMv9)、AWS Graviton3+。用 `#ifdef __ARM_FEATURE_MATMUL_INT8` 检查。 + +- 对于ML推理:在ARM服务器(Graviton)或Apple Silicon上运行的INT8量化模型(第18章)从I8MM中获益巨大。ONNX Runtime和llama.cpp等框架在运行时检测I8MM并自动使用优化核函数。 + +## SME和SME2:可伸缩矩阵扩展 + +- **SME**(可伸缩矩阵扩展)是ARM对Intel AMX和NVIDIA张量核心的回应:用于矩阵操作的专用硬件。SME2(ARMv9.2)进一步扩展了它。 + +- SME引入了**ZA瓦片寄存器**:存储在硬件中的2D矩阵,最大可达SVL×SVL字节(其中SVL是流向量长度,通常每维128-512位)。与NEON(1D向量)甚至SVE(1D可伸缩向量)不同,SME原生操作**2D瓦片**。 + +- 编程模型有两种模式: + - **普通模式**:标准ARM执行(NEON、SVE正常工作)。 + - **流SVE模式**:通过 `smstart` 进入,启用SME指令。SVE指令在此模式下也可工作,但可能使用不同的寄存器宽度。 + +```cpp +#include + +// SME2:矩阵乘法的外积累加 +// 将A_col × B_row 累加到ZA瓦片寄存器中 +void sme2_matmul_outer(const float* A_col, const float* B_row, int K) { + // 进入流模式 + // smstart; // (通过编译器内联或内联汇编完成) + + // 清零ZA瓦片累加器 + svzero_za(); + + for (int k = 0; k < K; k++) { + // 将A的一列和B的一行加载到SVE寄存器中 + svfloat32_t a = svld1_f32(svptrue_b32(), &A_col[k * SVL]); + svfloat32_t b = svld1_f32(svptrue_b32(), &B_row[k * SVL]); + + // 外积:ZA += a × b^T + // 这在一个指令中累加一个SVL×SVL瓦片 + svmopa_za32_f32_m(0, svptrue_b32(), svptrue_b32(), a, b); + } + + // 将ZA瓦片存储到内存 + // svst1_za(...); + + // 退出流模式 + // smstop; +} +``` + +- **关键概念**: + - **`svmopa`**(外积累加):核心SME指令。它计算两个向量的完整外积并累加到ZA瓦片中。对于SVL=512位(16个浮点数),这是一个16×16外积——一条指令中256次FMA操作。 + - **ZA瓦片**:在流模式中跨指令持久存在。你将多个外积(每个K迭代一个)累加到同一瓦片中,构建完整的矩阵乘法瓦片。 + - **流模式**:SME指令仅在流模式下工作。进入/退出流模式的开销意味着SME最适合持续的矩阵计算,而非短时爆发。 + +- **SME2新增**:多向量操作(同时处理2或4个SVE向量)、额外的瓦片操作以及与普通模式的改进集成。 + +- **可用性**:ARM Neoverse V2(AWS Graviton4)、一些即将推出的移动芯片。截至2026年尚未出现在Apple Silicon上。SME仍处于早期阶段——大多数ML框架还没有SME优化的核函数。 + +- **演进脉络**:NEON(128位向量,逐元素)→ I8MM(INT8矩阵瓦片)→ SVE(可伸缩向量)→ SME(可伸缩2D矩阵瓦片)。每一代都更接近硬件原生矩阵操作。 + +## SVE和SVE2:可伸缩向量扩展 + +- NEON具有固定的128位宽度。**SVE**(可伸缩向量扩展)引入了**向量长度无关(VLA)编程**:你编写一次代码,它在任何向量宽度(128到2048位)的硬件上运行。硬件在运行时确定宽度。 + +```cpp +#include + +void add_sve(const float* a, const float* b, float* c, int n) { + int i = 0; + svbool_t pred = svwhilelt_b32(i, n); // 谓词:哪些通道是激活的 + + while (svptest_any(svptrue_b32(), pred)) { + svfloat32_t va = svld1(pred, a + i); + svfloat32_t vb = svld1(pred, b + i); + svst1(pred, c + i, svadd_x(pred, va, vb)); + + i += svcntw(); // 按硬件向量宽度前进(以32位元素计) + pred = svwhilelt_b32(i, n); + } +} +``` + +- **谓词寄存器**(`svbool_t`)取代了标量清理循环。每个通道有一个谓词位:激活的通道参与,非激活的被屏蔽。`svwhilelt_b32(i, n)` 指令创建一个谓词,其中对应 `i, i+1, ..., n-1` 的通道被激活。这自动处理了尾部。 + +- **`svcntw()`** 在运行时返回每个向量寄存器中32位元素的数量。在具有256位SVE的CPU上,返回8。在512位SVE上,返回16。你的代码自动适应。 + +- SVE在ARM Neoverse V1/V2上可用(AWS Graviton3/4,一些服务器芯片)。在Apple Silicon上尚不可用。 + +## Apple Silicon特性 + +- Apple的M系列芯片(M1、M2、M3、M4)是基于ARM的自定义微架构: + +- **性能核心和效率核心**:P核心(Firestorm/Avalanche等)用于重型计算,E核心(Icestorm/Blizzard等)用于后台任务。调度器将线程分配给适当的核心类型。 + +- **AMX**(Apple矩阵扩展):专用矩阵乘法单元,独立于NEON。AMX未公开(Apple不发布ISA),但Accelerate框架内部将其用于BLAS操作。当你在Mac上调用 `np.dot` 时,它通过Accelerate,后者使用AMX。你不能直接对AMX编程(除非逆向工程)。 + +- **统一内存**:CPU和GPU共享同一物理RAM。在其他系统上,数据必须从CPU内存拷贝到GPU内存(通过PCIe,约32 GB/s)。在Apple Silicon上,无需拷贝——GPU读取CPU写入的同一内存。这消除了ML工作负载的主要瓶颈。 + +- **神经网络引擎**:一个16核专用ML加速器。INT8推理时达到约30 TOPS(每秒万亿次操作)。Core ML将其用于设备端推理。 + +- **Apple Silicon上的ML**:使用MLX(Apple的ML框架),它专为统一内存架构设计。PyTorch也有MPS(Metal性能着色器)后端支持,尽管不如CUDA成熟。 + +## 自动向量化 + +- 编写SIMD内联函数很繁琐。**编译器**能自动向量化你的代码吗? + +- 可以的,但有限制。现代编译器(GCC、Clang)可以自动向量化简单循环: + +```cpp +// 编译器可以自动向量化此代码(使用 -O3 -march=native) +void add_auto(const float* a, const float* b, float* c, int n) { + for (int i = 0; i < n; i++) { + c[i] = a[i] + b[i]; + } +} +``` + +- **有助于自动向量化的模式**: + - 简单的循环,迭代次数已知。 + - 迭代之间无数据依赖(`c[i]` 不依赖于 `c[i-1]`)。 + - 连续内存访问(无分散/聚集)。 + - `const` 和 `restrict` 指针(告知编译器数组不重叠)。 + +```cpp +// restrict 告诉编译器:a、b、c 指向不重叠的内存 +void add_restrict(const float* __restrict__ a, + const float* __restrict__ b, + float* __restrict__ c, int n) { + for (int i = 0; i < n; i++) { + c[i] = a[i] + b[i]; + } +} +``` + +- 没有 `restrict`,编译器必须假设 `c` 可能与 `a` 或 `b` 重叠(写入 `c[i]` 可能改变 `a[i+1]`),从而阻止向量化。 + +- **阻止自动向量化的模式**: + - 数据依赖:`a[i] = a[i-1] + b[i]`(每次迭代依赖前一次)。 + - 复杂控制流:循环内的 `if` 语句(除非编译器能转换为谓词化)。 + - 循环内的函数调用(除非函数被内联)。 + - 指针别名(数组可能重叠,没有 `restrict`)。 + +- **检查自动向量化**:使用编译器标志查看哪些被向量化了: + +```bash +# GCC:显示向量化决策 +g++ -O3 -march=native -fopt-info-vec-optimized code.cpp + +# Clang:显示向量化报告 +clang++ -O3 -march=native -Rpass=loop-vectorize code.cpp +``` + +- **何时使用内联函数 vs 自动向量化**:从干净的C++和编译器优化开始。如果编译器向量化了你的循环,很好。如果性能仍不足,检查编译器的向量化报告以理解原因,然后仅为关键内循环编写内联函数。过早使用内联函数会让代码难以阅读而没有确定的收益。 + +## 编程任务(在ARM上用g++或clang++编译——Mac M系列或Linux aarch64) + +1. 编写标量点积和NEON向量化点积。对两者进行基准测试并测量加速比。 +```cpp +// task1_neon_dot.cpp +// 编译(Mac/ARM Linux):clang++ -O3 -o task1 task1_neon_dot.cpp +// 注意:NEON在AArch64上默认启用,无需特殊标志 + +#include +#include +#include +#include + +float dot_scalar(const float* a, const float* b, int n) { + float sum = 0.0f; + for (int i = 0; i < n; i++) { + sum += a[i] * b[i]; + } + return sum; +} + +float dot_neon(const float* a, const float* b, int n) { + float32x4_t sum_vec = vdupq_n_f32(0.0f); + int i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t va = vld1q_f32(a + i); + float32x4_t vb = vld1q_f32(b + i); + sum_vec = vfmaq_f32(sum_vec, va, vb); + } + float sum = vaddvq_f32(sum_vec); + for (; i < n; i++) sum += a[i] * b[i]; + return sum; +} + +int main() { + const int N = 10'000'000; + std::vector a(N, 1.0f), b(N, 2.0f); + + // 预热 + volatile float s1 = dot_scalar(a.data(), b.data(), N); + volatile float s2 = dot_neon(a.data(), b.data(), N); + + // 标量基准测试 + auto start = std::chrono::high_resolution_clock::now(); + for (int t = 0; t < 100; t++) { + s1 = dot_scalar(a.data(), b.data(), N); + } + auto end = std::chrono::high_resolution_clock::now(); + double scalar_ms = std::chrono::duration(end - start).count() / 100; + + // NEON基准测试 + start = std::chrono::high_resolution_clock::now(); + for (int t = 0; t < 100; t++) { + s2 = dot_neon(a.data(), b.data(), N); + } + end = std::chrono::high_resolution_clock::now(); + double neon_ms = std::chrono::duration(end - start).count() / 100; + + std::cout << "标量: " << scalar_ms << " ms(结果: " << s1 << ")\n"; + std::cout << "NEON: " << neon_ms << " ms(结果: " << s2 << ")\n"; + std::cout << "加速比: " << scalar_ms / neon_ms << "x\n"; + return 0; +} +``` + +2. 实现NEON ReLU和softmax最大值查找。练习使用不同操作的加载→计算→存储模式。 +```cpp +// task2_neon_ops.cpp +// 编译:clang++ -O3 -o task2 task2_neon_ops.cpp + +#include +#include +#include +#include + +void relu_neon(const float* in, float* out, int n) { + float32x4_t zero = vdupq_n_f32(0.0f); + int i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t x = vld1q_f32(in + i); + vst1q_f32(out + i, vmaxq_f32(x, zero)); + } + for (; i < n; i++) out[i] = in[i] > 0 ? in[i] : 0; +} + +float max_neon(const float* data, int n) { + float32x4_t max_vec = vdupq_n_f32(-INFINITY); + int i = 0; + for (; i + 4 <= n; i += 4) { + max_vec = vmaxq_f32(max_vec, vld1q_f32(data + i)); + } + float result = vmaxvq_f32(max_vec); + for (; i < n; i++) result = result > data[i] ? result : data[i]; + return result; +} + +int main() { + std::vector data = {-3, 1, -1, 4, 2, -5, 0, 7, -2, 3}; + std::vector out(data.size()); + + relu_neon(data.data(), out.data(), data.size()); + std::cout << "ReLU: "; + for (float x : out) std::cout << x << " "; + std::cout << "\n"; + + float mx = max_neon(data.data(), data.size()); + std::cout << "最大值: " << mx << "(期望值: 7)\n"; + return 0; +} +``` + +3. 比较自动向量化代码与手写NEON内联函数。用 `-fopt-info-vec`(GCC)或 `-Rpass=loop-vectorize`(Clang)编译以查看编译器的操作。 +```cpp +// task3_auto_vs_manual.cpp +// 编译:clang++ -O3 -Rpass=loop-vectorize -o task3 task3_auto_vs_manual.cpp +// (或):g++ -O3 -fopt-info-vec-optimized -o task3 task3_auto_vs_manual.cpp + +#include +#include +#include +#include + +// 让编译器自动向量化 +void add_auto(const float* __restrict__ a, const float* __restrict__ b, + float* __restrict__ c, int n) { + for (int i = 0; i < n; i++) { + c[i] = a[i] + b[i]; + } +} + +// 手写NEON +void add_neon(const float* a, const float* b, float* c, int n) { + int i = 0; + for (; i + 4 <= n; i += 4) { + vst1q_f32(c + i, vaddq_f32(vld1q_f32(a + i), vld1q_f32(b + i))); + } + for (; i < n; i++) c[i] = a[i] + b[i]; +} + +int main() { + const int N = 10'000'000; + std::vector a(N, 1.0f), b(N, 2.0f), c(N); + + auto bench = [&](auto fn, const char* name) { + fn(a.data(), b.data(), c.data(), N); // 预热 + auto start = std::chrono::high_resolution_clock::now(); + for (int t = 0; t < 100; t++) fn(a.data(), b.data(), c.data(), N); + auto end = std::chrono::high_resolution_clock::now(); + double ms = std::chrono::duration(end - start).count() / 100; + std::cout << name << ": " << ms << " ms\n"; + }; + + bench(add_auto, "自动向量化"); + bench(add_neon, "手写NEON"); + // 它们应该非常接近——编译器能很好地自动向量化这个简单循环 + return 0; +} +``` diff --git a/chapter 16: SIMD and GPU programming/03. x86 and AVX.md b/chapter 16: SIMD and GPU programming/03. x86 and AVX.md new file mode 100644 index 0000000..31b4792 --- /dev/null +++ b/chapter 16: SIMD and GPU programming/03. x86 and AVX.md @@ -0,0 +1,450 @@ +# x86与AVX + +*x86处理器来自Intel和AMD,主导着大多数ML训练所在的数据中心服务器。本文涵盖x86 SIMD的演进、AVX/AVX2内联函数编程、AVX-512、用于矩阵操作的Intel AMX、内存对齐、性能陷阱以及性能分析——在全球最常见的服务器CPU上榨取最大性能的工具。* + +- 如果你的训练在云虚拟机(AWS、GCP、Azure)上运行,它几乎肯定运行在x86上。即使是GPU密集训练也有CPU瓶颈:数据加载、预处理、梯度聚合和检查点保存都在CPU上运行。使用x86 SIMD优化这些环节可以有意义地减少端到端训练时间。 + +## x86 SIMD演进 + +- x86 SIMD经历了越来越宽的向量寄存器: + +| 代次 | 年份 | 寄存器宽度 | 寄存器数量 | 关键特性 | +|------|------|---------|----------|----------| +| MMX | 1997 | 64位 | 8(mm0-7) | 仅整数,与FPU共享 | +| SSE | 1999 | 128位 | 8(xmm0-7) | 4个浮点数,专用寄存器 | +| SSE2 | 2001 | 128位 | 8/16 | 2个双精度浮点数,整数操作 | +| AVX | 2011 | 256位 | 16(ymm0-15) | 8个浮点数,三操作数指令 | +| AVX2 | 2013 | 256位 | 16 | 整数256位,FMA,收集 | +| AVX-512 | 2017 | 512位 | 32(zmm0-31) | 16个浮点数,掩码寄存器,分散 | +| AMX | 2023 | 瓦片寄存器 | 8个瓦片 | 矩阵乘法(BF16,INT8) | + +- 每一代都将向量化代码的吞吐量翻倍。用SSE内联函数编写的代码可以在2001年以来制造的每一个x86 CPU上运行。AVX2需要2013年以后的CPU。AVX-512需要Intel Xeon和一些消费级芯片。AMX是最新的(Sapphire Rapids及以后)。 + +- **向后兼容性**:x86 SSE寄存器(xmm)是AVX寄存器(ymm)的低128位,后者是AVX-512寄存器(zmm)的低256位。旧的SSE代码无需修改即可在新的CPU上运行。 + +## AVX2编程 + +- AVX2操作256位寄存器(YMM),同时处理8个浮点数或4个双精度浮点数。它是可移植高性能代码的甜点区域:在几乎所有现代x86 CPU(2013+)上可用。 + +### 内联函数命名约定 + +- 所有x86内联函数遵循模式:`_mm[宽度]_[操作]_[类型]` + + - `_mm` = MMX/SSE(128位),`_mm256` = AVX(256位),`_mm512` = AVX-512(512位) + - 操作:`add`、`mul`、`fmadd`、`load`、`store`、`set` 等 + - 类型:`ps` = 打包单精度(float32),`pd` = 打包双精度(float64),`epi32` = 打包int32,`si256` = 256位整数 + +```cpp +#include // 所有x86 SIMD内联函数 + +// 数据类型 +__m256 a; // 256位寄存器,保存8个float32 +__m256d b; // 256位寄存器,保存4个float64 +__m256i c; // 256位寄存器,保存整数(8x32、16x16或32x8) +``` + +### 加载和存储数据 + +```cpp +// 从内存加载8个浮点数 +__m256 v = _mm256_loadu_ps(ptr); // 非对齐加载(适用于任何地址) +__m256 v = _mm256_load_ps(ptr); // 对齐加载(ptr必须32字节对齐,更快) + +// 存储8个浮点数到内存 +_mm256_storeu_ps(out_ptr, v); // 非对齐存储 +_mm256_store_ps(out_ptr, v); // 对齐存储 + +// 将单个值广播到所有8个通道 +__m256 ones = _mm256_set1_ps(1.0f); // [1, 1, 1, 1, 1, 1, 1, 1] + +// 设置各个值(很少需要) +__m256 v = _mm256_set_ps(7,6,5,4,3,2,1,0); // 注意:逆序! + +// 零寄存器 +__m256 z = _mm256_setzero_ps(); +``` + +### 算术运算 + +```cpp +__m256 c = _mm256_add_ps(a, b); // c[i] = a[i] + b[i] +__m256 d = _mm256_mul_ps(a, b); // d[i] = a[i] * b[i] +__m256 e = _mm256_sub_ps(a, b); // e[i] = a[i] - b[i] +__m256 f = _mm256_div_ps(a, b); // f[i] = a[i] / b[i](比乘法慢) + +// 融合乘加:r = a * b + c(一条指令,一次舍入) +__m256 r = _mm256_fmadd_ps(a, b, c); // ML最重要的指令 + +// 最小值和最大值 +__m256 mn = _mm256_min_ps(a, b); // min(a[i], b[i]) — 用于裁剪 +__m256 mx = _mm256_max_ps(a, b); // max(a[i], b[i]) — 用于ReLU +``` + +### 实践示例:AVX2点积 + +```cpp +#include + +float dot_avx2(const float* a, const float* b, int n) { + __m256 sum = _mm256_setzero_ps(); // 8个累加器初始化为0 + + int i = 0; + for (; i + 8 <= n; i += 8) { + __m256 va = _mm256_loadu_ps(a + i); + __m256 vb = _mm256_loadu_ps(b + i); + sum = _mm256_fmadd_ps(va, vb, sum); // sum += va * vb + } + + // 水平归约:将sum的8个元素相加 + // 步骤1:将上128位加到下128位 + __m128 hi = _mm256_extractf128_ps(sum, 1); + __m128 lo = _mm256_castps256_ps128(sum); + __m128 sum128 = _mm_add_ps(hi, lo); // 4个部分和 + + // 步骤2:在128位寄存器内水平相加 + sum128 = _mm_hadd_ps(sum128, sum128); // [a+b, c+d, a+b, c+d] + sum128 = _mm_hadd_ps(sum128, sum128); // [a+b+c+d, ...] + + float result = _mm_cvtss_f32(sum128); // 提取标量 + + // 标量清理 + for (; i < n; i++) { + result += a[i] * b[i]; + } + + return result; +} +``` + +- **为什么水平归约如此丑陋**:SIMD是为垂直操作设计的(通道0与通道0,通道1与通道1)。水平操作(跨通道求和)与硬件对抗。这就是点积在末尾有尴尬归约代码的原因。向量化循环是简洁的;归约是样板代码。 + +- **性能**:与NEON版本(文件02)相比,AVX2每次迭代处理8个浮点数,而NEON处理4个。对于长向量,这比NEON快2倍(忽略内存带宽限制)。 + +### 实践示例:AVX2 Softmax(简化版) + +- Softmax需要:找到最大值,减去它,求指数,求和,除法。以下是最值查找步骤: + +```cpp +float vector_max_avx2(const float* data, int n) { + __m256 max_vec = _mm256_set1_ps(-INFINITY); + + int i = 0; + for (; i + 8 <= n; i += 8) { + __m256 v = _mm256_loadu_ps(data + i); + max_vec = _mm256_max_ps(max_vec, v); + } + + // 将8个最大值归约为1个 + __m128 hi = _mm256_extractf128_ps(max_vec, 1); + __m128 lo = _mm256_castps256_ps128(max_vec); + __m128 max128 = _mm_max_ps(hi, lo); + + // 通过混洗和取最大值找到单一最大值 + max128 = _mm_max_ps(max128, _mm_shuffle_ps(max128, max128, 0b01001110)); + max128 = _mm_max_ps(max128, _mm_shuffle_ps(max128, max128, 0b10110001)); + + float result = _mm_cvtss_f32(max128); + + for (; i < n; i++) { + result = result > data[i] ? result : data[i]; + } + + return result; +} +``` + +- `_mm_shuffle_ps` 指令在寄存器内重排元素。二进制常量 `0b01001110` 控制哪些元素去哪里。这称为**置换**,它直接连接到置换矩阵(第2章):打乱SIMD通道是做硬件级别的乘以置换矩阵。 + +## AVX-512 + +- AVX-512再次加倍宽度:512位寄存器(ZMM),同时处理16个浮点数。 + +```cpp +__m512 a = _mm512_loadu_ps(ptr); // 加载16个浮点数 +__m512 c = _mm512_fmadd_ps(a, b, c); // 16个FMA同时进行 +float sum = _mm512_reduce_add_ps(a); // 内置水平求和(无需手动归约!) + +// 掩码操作:操作通道子集 +__mmask16 mask = _mm512_cmpgt_ps_mask(a, zero); // 哪些通道 > 0? +__m512 relu = _mm512_maskz_mov_ps(mask, a); // 负通道置零 = ReLU +``` + +- **掩码寄存器**(`__mmask16`)是AVX-512最强大的功能。每个位控制一个通道是否参与操作。这取代了标量清理循环:最后一次迭代使用掩码,只有有效通道是激活的,处理任何向量长度而无需单独标量循环。 + +- **AVX-512频率降频**:在许多Intel CPU上,使用AVX-512指令会导致CPU暂时降低时钟频率(以保持在热限制内)。这意味着对于短时爆发,AVX-512并不总是比AVX2快——频率惩罚可能抵消更宽向量的优势。对于持续工作负载(如矩阵乘法),AVX-512胜出。对于混合代码(部分SIMD、部分标量),频率转换可能造成损失。 + +## Intel AMX:矩阵乘法硬件 + +- **AMX**(高级矩阵扩展)增加了专用矩阵乘法单元。AMX操作的不是SIMD向量,而是**瓦片**:2D数据块(最多16行 × 每行64字节)。 + +```cpp +#include + +// AMX瓦片乘法:C += A * B(BF16格式) +// A为16x32 BF16,B为32x16 BF16,C为16x16 FP32 +_tile_loadd(0, a_ptr, stride_a); // 从A加载瓦片0 +_tile_loadd(1, b_ptr, stride_b); // 从B加载瓦片1 +_tile_dpbf16ps(2, 0, 1); // 瓦片2 += 瓦片0 * 瓦片1(BF16矩阵乘法,FP32累加) +_tile_stored(2, c_ptr, stride_c); // 存储瓦片2到C +``` + +- AMX在一条指令中执行完整的16×32 × 32×16矩阵乘法。这是数百次FMA操作同时进行,专门为Transformer推理中主导的小矩阵乘法设计(注意力得分计算、MLP层)。 + +- AMX支持BF16(bfloat16)和INT8,匹配ML推理中使用的精度。结合用于其他操作的AVX-512,配备AMX的CPU(Intel Sapphire Rapids、Emerald Rapids)可以在Transformer推理中与入门级GPU竞争。 + +## 内存对齐 + +- **对齐内存访问**是指数据地址是向量寄存器宽度的倍数(SSE为16字节、AVX为32字节、AVX-512为64字节)。对齐访问在某些CPU上更快,并且是 `_mm256_load_ps`(相对于 `_mm256_loadu_ps`)的要求。 + +```cpp +// 分配对齐内存 +float* data = (float*)aligned_alloc(32, n * sizeof(float)); // AVX的32字节对齐 + +// C++对齐分配 +#include +float* data = new (std::align_val_t(32)) float[n]; + +// 或者使用编译器属性 +alignas(32) float data[1024]; +``` + +- **实际上**:在现代CPU(Haswell及以后)上,当数据不跨越缓存行边界时,非对齐加载(`loadu`)几乎与对齐加载一样快。非对齐访问的性能惩罚已基本消失,但缓存行分割(数据跨越两个64字节缓存行)仍可能使特定加载变慢约2倍。对齐分配完全避免了这种情况。 + +## 性能陷阱 + +- **AVX-SSE转换惩罚**:在较旧的Intel CPU(Skylake之前)上,在AVX(256位)和SSE(128位)指令之间切换会造成惩罚(约70周期)。这就是为什么你应该在从使用AVX的函数返回之前使用 `_mm256_zeroupper()`(或 `vzeroupper` 指令)清除YMM寄存器的上128位。现代CPU(Skylake+)没有此惩罚。 + +- **寄存器压力**:AVX2有16个YMM寄存器。如果你的核函数使用太多变量,编译器会将寄存器溢出到栈(内存),从而破坏性能。保持内循环简单,活变量少。 + +- **数据依赖**:`sum = _mm256_fmadd_ps(a, b, sum)` 对 `sum` 有依赖:每次迭代必须等待前一个FMA完成(约4-5个周期的延迟)。解决方案:使用多个独立累加器并在结束时归约: + +```cpp +// 单累加器:受FMA延迟限制(4-5个周期) +__m256 sum = _mm256_setzero_ps(); +for (...) { + sum = _mm256_fmadd_ps(a, b, sum); // 每个依赖前一个 +} + +// 四个累加器:4倍吞吐量(隐藏延迟) +__m256 sum0 = _mm256_setzero_ps(); +__m256 sum1 = _mm256_setzero_ps(); +__m256 sum2 = _mm256_setzero_ps(); +__m256 sum3 = _mm256_setzero_ps(); +for (...) { + sum0 = _mm256_fmadd_ps(a0, b0, sum0); // 独立 + sum1 = _mm256_fmadd_ps(a1, b1, sum1); // 独立 + sum2 = _mm256_fmadd_ps(a2, b2, sum2); // 独立 + sum3 = _mm256_fmadd_ps(a3, b3, sum3); // 独立 +} +sum0 = _mm256_add_ps(sum0, sum1); +sum2 = _mm256_add_ps(sum2, sum3); +sum0 = _mm256_add_ps(sum0, sum2); +``` + +- 这是**循环展开**以隐藏延迟。CPU可以背靠背发出FMAs,因为它们写入不同的寄存器。这是数值代码中最有影响力的微优化之一。 + +## 性能分析 + +- **性能计数器**提供硬件级测量: + +```bash +# Linux perf(需要内核支持) +perf stat ./my_program # 基本计数器:周期、指令、IPC +perf stat -e cache-misses,cache-references ./my_program # 缓存行为 +perf record -g ./my_program && perf report # 调用图分析 + +# Intel VTune(详细的x86性能分析) +vtune -collect hotspots -- ./my_program +vtune -collect memory-access -- ./my_program # 内存带宽分析 +``` + +- **需要关注什么**: + - **IPC**(每周期指令数):CPU被使用的效率。IPC > 2 良好。IPC < 1 表明内存停顿或分支预测错误。 + - **缓存缺失率**:高L1/L2缺失率表明数据局部性差。需重构数据访问模式。 + - **分支预测错误率**:> 5% 表明分支不可预测。如可能,转换为无分支代码(SIMD比较+混合)。 + - **实际FLOPS vs 屋顶线**:将你的实测FLOPS与屋顶线模型(文件01)比较。如果你低于屋顶线,还有改进空间。 + +## 编程任务(在x86——Intel/AMD上用g++或clang++编译) + +1. 编写标量点积和AVX2点积。对两者进行基准测试并测量8路SIMD带来的加速比。 +```cpp +// task1_avx_dot.cpp +// 编译:g++ -O3 -mavx2 -mfma -o task1 task1_avx_dot.cpp + +#include +#include +#include +#include + +float dot_scalar(const float* a, const float* b, int n) { + float sum = 0.0f; + for (int i = 0; i < n; i++) sum += a[i] * b[i]; + return sum; +} + +float dot_avx2(const float* a, const float* b, int n) { + __m256 sum = _mm256_setzero_ps(); + int i = 0; + for (; i + 8 <= n; i += 8) { + __m256 va = _mm256_loadu_ps(a + i); + __m256 vb = _mm256_loadu_ps(b + i); + sum = _mm256_fmadd_ps(va, vb, sum); + } + // 归约:上128加到下128,然后水平相加 + __m128 hi = _mm256_extractf128_ps(sum, 1); + __m128 lo = _mm256_castps256_ps128(sum); + __m128 r = _mm_add_ps(hi, lo); + r = _mm_hadd_ps(r, r); + r = _mm_hadd_ps(r, r); + float result = _mm_cvtss_f32(r); + for (; i < n; i++) result += a[i] * b[i]; + return result; +} + +int main() { + const int N = 10'000'000; + std::vector a(N, 1.0f), b(N, 2.0f); + + volatile float s1 = dot_scalar(a.data(), b.data(), N); + volatile float s2 = dot_avx2(a.data(), b.data(), N); + + auto bench = [&](auto fn, const char* name) { + auto start = std::chrono::high_resolution_clock::now(); + volatile float s; + for (int t = 0; t < 100; t++) s = fn(a.data(), b.data(), N); + auto end = std::chrono::high_resolution_clock::now(); + double ms = std::chrono::duration(end - start).count() / 100; + std::cout << name << ": " << ms << " ms(结果: " << s << ")\n"; + return ms; + }; + + double t1 = bench(dot_scalar, "标量"); + double t2 = bench(dot_avx2, "AVX2 "); + std::cout << "加速比: " << t1 / t2 << "x\n"; + return 0; +} +``` + +2. 使用 `_mm256_max_ps` 实现AVX2 ReLU并与标量循环比较。然后尝试使用多累加器(循环展开)以隐藏FMA延迟。 +```cpp +// task2_avx_relu.cpp +// 编译:g++ -O3 -mavx2 -o task2 task2_avx_relu.cpp + +#include +#include +#include +#include + +void relu_scalar(const float* in, float* out, int n) { + for (int i = 0; i < n; i++) { + out[i] = in[i] > 0.0f ? in[i] : 0.0f; + } +} + +void relu_avx2(const float* in, float* out, int n) { + __m256 zero = _mm256_setzero_ps(); + int i = 0; + for (; i + 8 <= n; i += 8) { + __m256 x = _mm256_loadu_ps(in + i); + _mm256_storeu_ps(out + i, _mm256_max_ps(x, zero)); + } + for (; i < n; i++) out[i] = in[i] > 0.0f ? in[i] : 0.0f; +} + +// 展开:每次迭代处理32个浮点数(4 x 8) +void relu_avx2_unrolled(const float* in, float* out, int n) { + __m256 zero = _mm256_setzero_ps(); + int i = 0; + for (; i + 32 <= n; i += 32) { + __m256 x0 = _mm256_loadu_ps(in + i); + __m256 x1 = _mm256_loadu_ps(in + i + 8); + __m256 x2 = _mm256_loadu_ps(in + i + 16); + __m256 x3 = _mm256_loadu_ps(in + i + 24); + _mm256_storeu_ps(out + i, _mm256_max_ps(x0, zero)); + _mm256_storeu_ps(out + i + 8, _mm256_max_ps(x1, zero)); + _mm256_storeu_ps(out + i + 16, _mm256_max_ps(x2, zero)); + _mm256_storeu_ps(out + i + 24, _mm256_max_ps(x3, zero)); + } + for (; i + 8 <= n; i += 8) { + _mm256_storeu_ps(out + i, _mm256_max_ps(_mm256_loadu_ps(in + i), zero)); + } + for (; i < n; i++) out[i] = in[i] > 0.0f ? in[i] : 0.0f; +} + +int main() { + const int N = 16'000'000; + std::vector in(N), out(N); + for (int i = 0; i < N; i++) in[i] = (float)(i % 200) - 100.0f; + + auto bench = [&](auto fn, const char* name) { + fn(in.data(), out.data(), N); // 预热 + auto start = std::chrono::high_resolution_clock::now(); + for (int t = 0; t < 100; t++) fn(in.data(), out.data(), N); + auto end = std::chrono::high_resolution_clock::now(); + double ms = std::chrono::duration(end - start).count() / 100; + double bw = 2.0 * N * sizeof(float) / ms / 1e6; // 读取+写入 + std::cout << name << ": " << ms << " ms(" << bw << " GB/s)\n"; + }; + + bench(relu_scalar, "标量 "); + bench(relu_avx2, "AVX2 "); + bench(relu_avx2_unrolled, "AVX2展开 "); + return 0; +} +``` + +3. 测量内存对齐的效果。比较在大数组上的对齐加载与非对齐加载。 +```cpp +// task3_alignment.cpp +// 编译:g++ -O3 -mavx2 -o task3 task3_alignment.cpp + +#include +#include +#include +#include + +int main() { + const int N = 16'000'000; + + // 对齐分配(AVX2为32字节) + float* aligned = (float*)aligned_alloc(32, N * sizeof(float)); + + // 非对齐:从对齐边界偏移4字节(1个浮点数) + float* raw = (float*)malloc((N + 1) * sizeof(float)); + float* unaligned = raw + 1; // 保证未对齐 + + for (int i = 0; i < N; i++) { + aligned[i] = 1.0f; + unaligned[i] = 1.0f; + } + + auto bench = [&](float* ptr, bool use_aligned, const char* name) { + __m256 sum = _mm256_setzero_ps(); + // 预热 + for (int i = 0; i + 8 <= N; i += 8) { + __m256 v = use_aligned ? _mm256_load_ps(ptr + i) : _mm256_loadu_ps(ptr + i); + sum = _mm256_add_ps(sum, v); + } + + auto start = std::chrono::high_resolution_clock::now(); + for (int t = 0; t < 100; t++) { + sum = _mm256_setzero_ps(); + for (int i = 0; i + 8 <= N; i += 8) { + __m256 v = use_aligned ? _mm256_load_ps(ptr + i) : _mm256_loadu_ps(ptr + i); + sum = _mm256_add_ps(sum, v); + } + } + auto end = std::chrono::high_resolution_clock::now(); + double ms = std::chrono::duration(end - start).count() / 100; + double bw = (double)N * sizeof(float) / ms / 1e6; + std::cout << name << ": " << ms << " ms(" << bw << " GB/s)\n"; + }; + + bench(aligned, true, "对齐加载 "); + bench(unaligned, false, "非对齐加载"); + + free(aligned); + free(raw); + return 0; +} +``` diff --git a/chapter 16: SIMD and GPU programming/04. GPU architecture and CUDA.md b/chapter 16: SIMD and GPU programming/04. GPU architecture and CUDA.md new file mode 100644 index 0000000..c835da5 --- /dev/null +++ b/chapter 16: SIMD and GPU programming/04. GPU architecture and CUDA.md @@ -0,0 +1,598 @@ +# GPU架构与CUDA + +*GPU通过提供数千个核心用于大规模并行计算,改变了AI。本文涵盖GPU与CPU的设计哲学对比、GPU存储层次、C++中的CUDA编程、SIMT执行模型、内存访问模式、同步、流、性能分析以及NVIDIA GPU代次——编写和理解GPU核函数所需的知识。* + +- 有关带有完整工作示例的实践CUDA教程,请参见配套仓库:[github.com/HenryNdubuaku/cuda-tutorials](https://github.com/HenryNdubuaku/cuda-tutorials)。 + +- 现代NVIDIA GPU有超过10,000个CUDA核心。CPU有4-128个核心。这100-1000倍的核心优势是GPU主导ML的原因:训练一个Transformer需要数万亿次乘加操作,GPU以CPU无法匹敌的规模并行处理它们。 + +- 即使你从不自己编写CUDA核函数,理解GPU架构也能解释:为什么批次大小很重要(需要足够的工作来饱和GPU),为什么内存通常是瓶颈(而非计算),以及为什么某些操作(分散、条件分支)在GPU上很慢。 + +## GPU vs CPU:根本不同的设计 + +- CPU是为**延迟**设计的:最小化完成一个任务的时间。它将其晶体管预算的大部分用于缓存、分支预测器和乱序执行——所有让单一线程快速运行的技巧。 + +- GPU是为**吞吐量**设计的:最大化每秒完成的任务数量。它将大部分晶体管用于执行单元(ALU)。单个线程很慢,但有数千个。 + +| | CPU | GPU | +|--|-----|-----| +| 核心 | 4-128(复杂、快速) | 1,000-20,000(简单、慢速) | +| 时钟频率 | 3-5 GHz | 1-2.5 GHz | +| 缓存 | 大(32 MB+ L3) | 小(每SM共享内存) | +| 分支预测 | 精密 | 无(所有线程遵循相同路径) | +| 最适合 | 低延迟、复杂控制流 | 高吞吐量、数据并行工作 | +| 典型FLOPS(FP32) | 1-5 TFLOPS | 30-80 TFLOPS | +| 内存带宽 | 50-100 GB/s | 1-3 TB/s | + +- GPU的内存带宽优势(10-30倍)通常比其计算优势更重要。许多ML操作是内存受限的(逐元素操作、归一化、注意力),GPU的带宽使其能够足够快地向核心输送数据。 + +## GPU存储层次 + +- 理解GPU内存至关重要,因为**内存访问是主要瓶颈**,而非计算。 + +| 内存 | 大小 | 延迟 | 带宽 | 作用域 | +|--------|------|---------|-----------|-------| +| 寄存器 | 每SM约256 KB | 0周期 | 最高 | 每线程 | +| 共享内存 | 每SM 48-228 KB | 约5周期 | 约20 TB/s | 每线程块 | +| L1缓存 | 每SM 128-256 KB | 约30周期 | | 每SM | +| L2缓存 | 4-96 MB | 约200周期 | 约6 TB/s | 全局 | +| 全局内存(HBM) | 24-192 GB | 约400周期 | 1-3.3 TB/s | 全局 | + +- **寄存器**是最快但最有限的。每个线程有一组私有寄存器(通常最多255个)。每线程使用过多寄存器会降低**占用率**(可同时运行的线程更少)。 + +- **共享内存**是由程序员管理的缓存,由块中的所有线程共享。它是编写快速CUDA核函数的关键:将数据瓦片从慢速全局内存加载到快速共享内存,然后进行计算。这是主导GPU编程的**分块**模式。 + +- **全局内存(HBM)**:主GPU内存(VRAM)。大但慢(400周期延迟)。所有数据起始和结束于此。核函数优化的目标是尽量减少全局内存访问。 + +## CUDA编程模型 + +- CUDA(统一计算设备架构)是NVIDIA的GPU编程模型。你编写**核函数**:在GPU上运行的函数,由数千个线程同时执行。 + +### 层次结构:网格、块、线程 + +``` +网格(整个启动) +├── 块 (0,0) +│ ├── 线程 (0,0) +│ ├── 线程 (1,0) +│ ├── 线程 (2,0) +│ └── ... (每块最多1024线程) +├── 块 (1,0) +│ ├── 线程 (0,0) +│ └── ... +└── ... (可能有数百万个块) +``` + +- **线程**:最小单位。每个线程在其块内有唯一ID(`threadIdx.x`)。 +- **块**:一组可以共享内存和同步的线程。块ID:`blockIdx.x`。块大小:`blockDim.x`(最多1024线程)。 +- **网格**:单个核函数启动的所有块。可以是1D、2D或3D。 + +- 每个线程计算其全局索引:`int idx = blockIdx.x * blockDim.x + threadIdx.x;` + +### 你的第一个CUDA核函数 + +```cpp +// vector_add.cu — CUDA源文件(.cu扩展名) + +#include + +// __global__ 标记此为GPU核函数(从CPU调用,在GPU上运行) +__global__ void vector_add(const float* a, const float* b, float* c, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { // 边界检查(网格可能大于数据) + c[idx] = a[idx] + b[idx]; + } +} + +int main() { + int n = 1 << 20; // 约100万个元素 + size_t bytes = n * sizeof(float); + + // 分配主机(CPU)内存 + float *h_a = new float[n]; + float *h_b = new float[n]; + float *h_c = new float[n]; + + // 初始化 + for (int i = 0; i < n; i++) { + h_a[i] = 1.0f; + h_b[i] = 2.0f; + } + + // 分配设备(GPU)内存 + float *d_a, *d_b, *d_c; + cudaMalloc(&d_a, bytes); + cudaMalloc(&d_b, bytes); + cudaMalloc(&d_c, bytes); + + // 将数据从CPU拷贝到GPU + cudaMemcpy(d_a, h_a, bytes, cudaMemcpyHostToDevice); + cudaMemcpy(d_b, h_b, bytes, cudaMemcpyHostToDevice); + + // 启动核函数:每块256线程,足够的块覆盖n个元素 + int block_size = 256; + int grid_size = (n + block_size - 1) / block_size; // 上取整除法 + vector_add<<>>(d_a, d_b, d_c, n); + + // 将结果从GPU拷贝到CPU + cudaMemcpy(h_c, d_a, bytes, cudaMemcpyDeviceToHost); + + // 验证 + printf("c[0] = %f(期望值 3.0)\n", h_c[0]); + + // 释放内存 + cudaFree(d_a); cudaFree(d_b); cudaFree(d_c); + delete[] h_a; delete[] h_b; delete[] h_c; + + return 0; +} +``` + +```bash +# 用NVIDIA编译器编译 +nvcc -O3 -o vector_add vector_add.cu +./vector_add +``` + +- **CUDA中的关键C++概念**: + - `__global__`:CUDA关键字,标记核函数。从CPU(主机)调用,在GPU(设备)上运行。 + - `<<>>`:核函数启动语法。指定使用多少块和线程。 + - `cudaMalloc` / `cudaFree`:分配/释放GPU内存(类似于`new`/`delete`,但针对GPU)。 + - `cudaMemcpy`:在CPU和GPU之间拷贝数据。这通常是最大的瓶颈(PCIe带宽约32 GB/s,而GPU内存带宽约3 TB/s)。 + +### 线程束与SIMT + +- GPU以32个为一组称为**线程束**的组执行线程。一个线程束中的所有32个线程同时执行**相同指令**(单指令多线程——SIMT)。这是GPU的SIMD等效,但在线程级别。 + +- **线程束分歧**发生在同一线程束中的线程在`if`语句中走不同分支时。GPU不能在一个线程束中同时执行两条不同指令,因此它顺序执行两个分支,屏蔽掉不应参与的线程。这使性能减半(或更差)。 + +```cpp +// 糟糕:线程束分歧(同一线程束中的线程走不同路径) +if (threadIdx.x % 2 == 0) { + c[idx] = a[idx] + b[idx]; // 偶数线程做这个 +} else { + c[idx] = a[idx] - b[idx]; // 奇数线程做这个(同一线程束,串行化) +} + +// 更好:无分支(无分歧) +float sign = (threadIdx.x % 2 == 0) ? 1.0f : -1.0f; +c[idx] = a[idx] + sign * b[idx]; // 所有线程执行相同指令 +``` + +### 内存合并 + +- **合并访问**:当连续的线程访问连续的内存地址时,GPU将它们组合成单个内存事务。这对性能至关重要。 + +```cpp +// 好:合并——线程0读a[0],线程1读a[1],... +c[idx] = a[idx] + b[idx]; + +// 坏:跨步——线程0读a[0],线程1读a[步长],... +c[idx] = a[idx * stride] + b[idx * stride]; // 步长 > 1 浪费带宽 +``` + +- 对于一个32线程的线程束,合并访问在单次事务中加载128字节(32 × 4字节用于float32)。跨步访问需要多次事务,每次加载128字节但只使用一小部分。步长为32是最坏情况:每次事务加载128字节,但只有一个线程使用4字节(3%的利用率)。 + +### 共享内存与分块 + +- **分块模式**是最重要的GPU优化技术。其想法:将数据块从慢速全局内存加载到快速共享内存,进行计算,然后将结果写回。 + +```cpp +// 使用共享内存分块的矩阵乘法(简化版) +__global__ void matmul_tiled(const float* A, const float* B, float* C, + int M, int N, int K) { + // A的一个瓦片和B的一个瓦片的共享内存 + __shared__ float tile_A[TILE_SIZE][TILE_SIZE]; + __shared__ float tile_B[TILE_SIZE][TILE_SIZE]; + + int row = blockIdx.y * TILE_SIZE + threadIdx.y; + int col = blockIdx.x * TILE_SIZE + threadIdx.x; + float sum = 0.0f; + + // 遍历瓦片 + for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) { + // 将A和B的一个瓦片加载到共享内存 + if (row < M && t * TILE_SIZE + threadIdx.x < K) + tile_A[threadIdx.y][threadIdx.x] = A[row * K + t * TILE_SIZE + threadIdx.x]; + else + tile_A[threadIdx.y][threadIdx.x] = 0.0f; + + if (col < N && t * TILE_SIZE + threadIdx.y < K) + tile_B[threadIdx.y][threadIdx.x] = B[(t * TILE_SIZE + threadIdx.y) * N + col]; + else + tile_B[threadIdx.y][threadIdx.x] = 0.0f; + + __syncthreads(); // 等待所有线程完成加载 + + // 计算此瓦片的部分点积 + for (int k = 0; k < TILE_SIZE; k++) { + sum += tile_A[threadIdx.y][k] * tile_B[k][threadIdx.x]; + } + + __syncthreads(); // 在加载下一个瓦片前等待 + } + + if (row < M && col < N) + C[row * N + col] = sum; +} +``` + +- **`__shared__`**:声明块内所有线程共享的内存(快速、片上)。 +- **`__syncthreads()`**:一个屏障,等待块中所有线程到达此点。在写入共享内存和读取它之间必须使用(否则某些线程读取到过期数据)。 +- **为什么分块有效**:没有它,每个线程每次乘法都从全局内存加载。有了分块,一个TILE_SIZE × TILE_SIZE的数据块被加载到共享内存一次,并被块中所有线程重用。重用因子为TILE_SIZE,将全局内存流量减少该因子。 + +## 流与并发 + +- 默认情况下,CUDA操作是顺序的:CPU启动一个核函数,等待它完成,然后启动下一个。**流**允许重叠: + +```cpp +cudaStream_t stream1, stream2; +cudaStreamCreate(&stream1); +cudaStreamCreate(&stream2); + +// 这些操作可以重叠:不同流并发执行 +cudaMemcpyAsync(d_a, h_a, bytes, cudaMemcpyHostToDevice, stream1); +cudaMemcpyAsync(d_b, h_b, bytes, cudaMemcpyHostToDevice, stream2); + +kernel1<<>>(d_a, d_c); +kernel2<<>>(d_b, d_d); +``` + +- 流将数据传输与计算重叠:当流1的核函数运行时,流2在拷贝数据。这隐藏了PCIe传输延迟,并保持GPU忙碌。 + +## 分析CUDA代码 + +```bash +# NVIDIA Nsight Compute:核函数级分析 +ncu --set full ./my_program + +# NVIDIA Nsight Systems:系统级时间线 +nsys profile ./my_program + +# 快速指标 +ncu --metrics sm__throughput,dram__throughput ./my_program +``` + +- **需要关注什么**: + - **占用率**:SM容量中被使用的比例。低占用率(< 50%)意味着线程太少,无法隐藏内存延迟。原因:每线程寄存器过多、每块共享内存过多。 + - **内存吞吐量**:与峰值带宽比较。如果你达到峰值带宽的50%以下,内存访问模式低效(非合并、存储体冲突)。 + - **计算吞吐量**:与峰值FLOPS比较。如果内存和计算吞吐量都低,核函数是延迟受限的(并行度不够)。 + +## 高级优化技术 + +- 除了合并和共享内存分块的基础知识外,高性能GPU(和CPU)代码还使用几种高级技术: + +### 数据布局:AoS vs SoA + +- **结构体数组(AoS)**:每个元素将所有字段存储在一起。`[{x,y,z}, {x,y,z}, {x,y,z}]`。 +- **数组结构体(SoA)**:每个字段存储在自己的连续数组中。`{[x,x,x], [y,y,y], [z,z,z]}`。 + +```cpp +// AoS:对于SIMD/GPU不好(访问所有x值触及非连续内存) +struct Particle { float x, y, z, mass; }; +Particle particles[N]; +// particles[0].x, particles[1].x 相隔16字节 + +// SoA:对于SIMD/GPU好(所有x值连续) +struct Particles { + float x[N], y[N], z[N], mass[N]; +}; +// x[0], x[1] 相隔4字节——非常适合合并访问和SIMD +``` + +- SoA对于数据并行工作负载(SIMD、GPU)几乎总是更快。AoS在你总是同时访问一个元素的所有字段时更好(在数值代码中很少见)。PyTorch张量本质上是SoA:每个特征是一个连续维度。 + +### 软件预取 + +- 可以告诉CPU在需要之前开始加载数据,隐藏内存延迟: + +```cpp +#include // for _mm_prefetch + +for (int i = 0; i < n; i += 4) { + _mm_prefetch((char*)(a + i + 64), _MM_HINT_T0); // 预取之前64个元素 + // 用SIMD处理 a[i:i+4] + __m128 va = _mm_load_ps(a + i); + // ... +} +``` + +- 预取指令是一个提示:如果数据已在缓存中,它是空操作。如果不是,CPU在执行其他指令的同时开始在后台获取数据。预取距离(此示例中向前64个元素)应根据内存延迟和循环迭代时间进行调整。 + +### 核函数融合 + +- **核函数融合**将多个操作组合成一个核函数,以避免将中间结果写入内存。这是ML中最有影响力的单个GPU优化: + +``` +// 未融合:3次核函数启动,3次全局内存往返 +y = matmul(x, W) // 写y到全局内存 +z = y + bias // 读y,写z +out = relu(z) // 读z,写out + +// 融合:1次核函数启动,1次全局内存写入 +out = fused_matmul_bias_relu(x, W, bias) // y和z永不离开SRAM +``` + +- 对于内存受限操作(偏置加法、ReLU、层归一化),内存流量主导执行时间。融合完全消除了流量。PyTorch的`torch.compile`和Triton可以自动或通过最少努力实现融合。 + +### 混合精度核函数 + +- 使用较低精度(FP16、BF16、INT8)进行计算和较高精度(FP32)进行累加,达到两全其美: + +```cpp +// 张量核心:乘FP16矩阵,在FP32中累加 +// 每条张量核心指令:D(FP32)= A(FP16)× B(FP16)+ C(FP32) +nvcuda::wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); +``` + +- FP16比FP32小2倍,因此它使内存带宽加倍(通常的瓶颈),并在缓存中容纳2倍的数据。张量核心以FP32 CUDA核心8-16倍的速度处理FP16。这就是为什么混合精度训练(第6章)提供2-3倍加速且精度损失最小。 + +### 内存池分配器 + +- `cudaMalloc` 很慢(每次调用约1毫秒),因为它与GPU同步。在每次迭代分配临时缓冲区的训练循环中,这会累积起来。 + +- **内存池**(PyTorch的缓存分配器、CUDA内存池)预先分配一大块GPU内存,并从其中子分配而无需系统调用: + +```python +# PyTorch自动执行此操作——但理解原因很重要 +# 每个 torch.empty() 从池中重用内存,无需cudaMalloc +temp = torch.empty(1024, 1024, device='cuda') # 微秒,而非毫秒 +``` + +- 这就是为什么PyTorch的 `torch.cuda.memory_allocated()` 和 `torch.cuda.max_memory_allocated()` 不同:allocated是当前使用的,max是峰值(池可能持有比当前使用更多的内存)。 + +### 分析指导的优化 + +- 不要盲目优化。**先分析**,识别瓶颈,优化那个,然后重新分析。屋顶线模型(文件01)告诉你瓶颈是内存还是计算: + + - **内存受限**(低算术强度):优化数据布局(SoA)、融合核函数、使用较低精度、预取。 + - **计算受限**(高算术强度):使用张量核心、增加并行度、使用更快指令(FMA)。 + - **延迟受限**(并行度不足):增加占用率、减少寄存器使用、启动更多线程。 + +- 大多数ML工作负载是**内存受限的**。令人惊讶的推论:更快的GPU(更多FLOPS)通常没有帮助。更快的内存(HBM3 vs HBM2e)更有帮助。这就是为什么A100→H100升级不只是关于FLOPS——H100也有2倍的内存带宽。 + +## NVIDIA GPU代次 + +| 代次 | 年份 | 关键创新 | AI相关性 | +|---------|------|-------------|--------------| +| Pascal(P100) | 2016 | HBM2、NVLink | 第一代严肃的深度学习GPU | +| Volta(V100) | 2017 | **张量核心**(混合精度矩阵乘法) | 实现FP16训练,125 TFLOPS TF32 | +| Ampere(A100) | 2020 | TF32、稀疏性、第三代张量核心 | 312 TFLOPS TF32,结构性稀疏2:4 | +| Hopper(H100) | 2022 | **Transformer引擎**(FP8)、HBM3 | 989 TFLOPS FP8,动态精度切换 | +| Blackwell(B200) | 2024 | 第二代Transformer引擎、NVLink 5 | 2.5 PFLOPS FP4,多芯片设计 | + +- **张量核心**是专用的矩阵乘法单元。单个张量核心指令在一个周期内计算4×4矩阵乘法(D = A×B + C)。常规CUDA核心需要64次FMA操作。张量核心就是为什么混合精度训练(float16计算,float32累加)如此快速。 + +- **Transformer引擎**(Hopper+)在单层内动态切换FP8和FP16精度,只在需要时选择更高精度。这最大化吞吐量而不牺牲模型质量。它专为Transformer架构(注意力+MLP)设计,后者主导现代AI。 + +## 编程任务(用nvcc编译) + +1. 编写一个对数组应用ReLU的CUDA核函数。测量包括内存传输在内的时间。这教授核函数编写、cudaMalloc/cudaMemcpy以及主机↔设备传输瓶颈。 +```cpp +// task1_relu.cu +// 编译:nvcc -O3 -o task1_relu task1_relu.cu + +#include +#include +#include + +__global__ void relu_kernel(const float* input, float* output, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + output[idx] = input[idx] > 0.0f ? input[idx] : 0.0f; + } +} + +int main() { + const int N = 1 << 24; // 约1600万元素 + size_t bytes = N * sizeof(float); + + // 分配主机内存 + float* h_input = (float*)malloc(bytes); + float* h_output = (float*)malloc(bytes); + for (int i = 0; i < N; i++) { + h_input[i] = (float)(i % 100) - 50.0f; // 正负混合 + } + + // 分配设备内存 + float *d_input, *d_output; + cudaMalloc(&d_input, bytes); + cudaMalloc(&d_output, bytes); + + // 计时完整流水线:拷贝到GPU、计算、拷贝回 + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + cudaMemcpy(d_input, h_input, bytes, cudaMemcpyHostToDevice); + + int block_size = 256; + int grid_size = (N + block_size - 1) / block_size; + relu_kernel<<>>(d_input, d_output, N); + + cudaMemcpy(h_output, d_output, bytes, cudaMemcpyDeviceToHost); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + float ms = 0; + cudaEventElapsedTime(&ms, start, stop); + + // 验证 + int errors = 0; + for (int i = 0; i < N; i++) { + float expected = h_input[i] > 0.0f ? h_input[i] : 0.0f; + if (h_output[i] != expected) errors++; + } + + printf("时间(含传输): %.2f ms\n", ms); + printf("带宽: %.1f GB/s\n", 2.0 * bytes / ms / 1e6); // 读取+写入 + printf("错误: %d / %d\n", errors, N); + + cudaFree(d_input); cudaFree(d_output); + free(h_input); free(h_output); + return 0; +} +``` + +2. 在CUDA中使用共享内存编写分块矩阵乘法。将性能与朴素(非分块)版本进行比较。这教授共享内存、`__syncthreads`以及为什么分块重要。 +```cpp +// task2_matmul.cu +// 编译:nvcc -O3 -o task2_matmul task2_matmul.cu + +#include +#include + +#define TILE 16 + +// 朴素矩阵乘法:每个线程计算C的一个元素 +__global__ void matmul_naive(const float* A, const float* B, float* C, int N) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (row < N && col < N) { + float sum = 0.0f; + for (int k = 0; k < N; k++) { + sum += A[row * N + k] * B[k * N + col]; + } + C[row * N + col] = sum; + } +} + +// 分块矩阵乘法:使用共享内存减少全局内存访问 +__global__ void matmul_tiled(const float* A, const float* B, float* C, int N) { + __shared__ float sA[TILE][TILE]; + __shared__ float sB[TILE][TILE]; + + int row = blockIdx.y * TILE + threadIdx.y; + int col = blockIdx.x * TILE + threadIdx.x; + float sum = 0.0f; + + for (int t = 0; t < (N + TILE - 1) / TILE; t++) { + sA[threadIdx.y][threadIdx.x] = (row < N && t*TILE+threadIdx.x < N) + ? A[row * N + t*TILE + threadIdx.x] : 0.0f; + sB[threadIdx.y][threadIdx.x] = (t*TILE+threadIdx.y < N && col < N) + ? B[(t*TILE + threadIdx.y) * N + col] : 0.0f; + + __syncthreads(); + for (int k = 0; k < TILE; k++) + sum += sA[threadIdx.y][k] * sB[k][threadIdx.x]; + __syncthreads(); + } + + if (row < N && col < N) + C[row * N + col] = sum; +} + +int main() { + const int N = 1024; + size_t bytes = N * N * sizeof(float); + + float *d_A, *d_B, *d_C; + cudaMalloc(&d_A, bytes); cudaMalloc(&d_B, bytes); cudaMalloc(&d_C, bytes); + + // 初始化为1(容易验证:C应全为N) + float* h_A = new float[N*N]; + for (int i = 0; i < N*N; i++) h_A[i] = 1.0f; + cudaMemcpy(d_A, h_A, bytes, cudaMemcpyHostToDevice); + cudaMemcpy(d_B, h_A, bytes, cudaMemcpyHostToDevice); + + dim3 block(TILE, TILE); + dim3 grid((N+TILE-1)/TILE, (N+TILE-1)/TILE); + + // 基准测试朴素版 + cudaEvent_t start, stop; + cudaEventCreate(&start); cudaEventCreate(&stop); + + cudaEventRecord(start); + for (int i = 0; i < 10; i++) + matmul_naive<<>>(d_A, d_B, d_C, N); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + float naive_ms; cudaEventElapsedTime(&naive_ms, start, stop); + + // 基准测试分块版 + cudaEventRecord(start); + for (int i = 0; i < 10; i++) + matmul_tiled<<>>(d_A, d_B, d_C, N); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + float tiled_ms; cudaEventElapsedTime(&tiled_ms, start, stop); + + double gflops_naive = 2.0 * N * N * N * 10 / naive_ms / 1e6; + double gflops_tiled = 2.0 * N * N * N * 10 / tiled_ms / 1e6; + + printf("朴素版: %.2f ms, %.1f GFLOPS\n", naive_ms/10, gflops_naive); + printf("分块版: %.2f ms, %.1f GFLOPS\n", tiled_ms/10, gflops_tiled); + printf("加速比: %.1fx\n", naive_ms / tiled_ms); + + cudaFree(d_A); cudaFree(d_B); cudaFree(d_C); + delete[] h_A; + return 0; +} +``` + +3. 演示线程束分歧。编写一个核函数,其中同一线程束中的线程走不同分支,并与无分支版本比较。 +```cpp +// task3_divergence.cu +// 编译:nvcc -O3 -o task3_diverge task3_divergence.cu + +#include +#include + +// 糟糕:线程束分歧——偶数/奇数线程走不同路径 +__global__ void divergent_kernel(float* data, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + if (idx % 2 == 0) { + data[idx] = data[idx] * 2.0f + 1.0f; + } else { + data[idx] = data[idx] * 0.5f - 1.0f; + } + } +} + +// 好:无分支——所有线程执行相同指令 +__global__ void branchless_kernel(float* data, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + float scale = (idx % 2 == 0) ? 2.0f : 0.5f; + float offset = (idx % 2 == 0) ? 1.0f : -1.0f; + data[idx] = data[idx] * scale + offset; + } +} + +int main() { + const int N = 1 << 24; + float* d_data; + cudaMalloc(&d_data, N * sizeof(float)); + cudaMemset(d_data, 0, N * sizeof(float)); + + int block = 256, grid = (N + block - 1) / block; + + cudaEvent_t start, stop; + cudaEventCreate(&start); cudaEventCreate(&stop); + + // 分歧版 + cudaEventRecord(start); + for (int i = 0; i < 100; i++) + divergent_kernel<<>>(d_data, N); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + float div_ms; cudaEventElapsedTime(&div_ms, start, stop); + + // 无分支版 + cudaEventRecord(start); + for (int i = 0; i < 100; i++) + branchless_kernel<<>>(d_data, N); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + float nodiv_ms; cudaEventElapsedTime(&nodiv_ms, start, stop); + + printf("分歧版: %.2f ms\n", div_ms / 100); + printf("无分支版: %.2f ms\n", nodiv_ms / 100); + printf("加速比: %.2fx\n", div_ms / nodiv_ms); + + cudaFree(d_data); + return 0; +} +``` diff --git a/chapter 16: SIMD and GPU programming/05. triton, TPUs and pallax.md b/chapter 16: SIMD and GPU programming/05. triton, TPUs and pallax.md new file mode 100644 index 0000000..895f3ff --- /dev/null +++ b/chapter 16: SIMD and GPU programming/05. triton, TPUs and pallax.md @@ -0,0 +1,393 @@ +# Triton与TPU + +*CUDA C功能强大但冗长。Triton让你用Python编写GPU核函数。TPU提供了GPU之外的选择,具有不同的权衡。本文涵盖Triton核函数编程、以Flash Attention为案例研究、TPU架构与JAX/Pallas,以及如何选择合适的工具。关于Vulkan和跨平台GPU计算,请参见文件07。* + +- 上篇文件教授了CUDA C中的GPU编程。本文更上一层抽象阶梯:Triton以20%的工作量提供CUDA 80%的性能,且用Python。TPU和Vulkan为特定用例提供替代硬件目标。 + +## Triton:用Python编写GPU核函数 + +- **Triton**(OpenAI)是一种基于Python的GPU核函数编写语言。你不需要思考单个线程(CUDA),而是思考**块**级数据。Triton的编译器自动处理线程映射、内存合并、共享内存管理和许多优化。 + +- **为什么Triton重要**:CUDA C需要对线程束调度、共享内存存储体冲突、寄存器压力和合并模式有深入理解。Triton抽象了其中大部分内容,使GPU核函数开发对了解Python但不了解系统编程的ML研究人员可及。 + +### 你的第一个Triton核函数 + +```python +import triton +import triton.language as tl +import torch + +@triton.jit +def add_kernel( + x_ptr, y_ptr, output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, # 编译时常量 +): + # 每个程序实例处理一个BLOCK_SIZE元素的块 + pid = tl.program_id(axis=0) # 我是哪个块? + block_start = pid * BLOCK_SIZE + + # 此块的偏移量 + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + # 掩码处理n_elements不是BLOCK_SIZE倍数的情况 + mask = offsets < n_elements + + # 加载数据(带掩码:越界读取返回0) + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + + # 计算 + output = x + y + + # 存储结果 + tl.store(output_ptr + offsets, output, mask=mask) + + +def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + output = torch.empty_like(x) + n_elements = output.numel() + + # 启动:每个块一个程序 + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + + return output + + +# 使用 +x = torch.randn(1000000, device='cuda') +y = torch.randn(1000000, device='cuda') +z = add(x, y) +``` + +- **与CUDA的关键区别**: + - 无需显式线程管理。你思考**块**(程序),而非线程。 + - `tl.arange(0, BLOCK_SIZE)` 为整个块创建一个偏移向量。此向量上的所有操作都隐式向量化。 + - `mask` 处理边界条件(类似于AVX-512掩码寄存器,文件03)。无需标量清理循环。 + - `tl.load` 和 `tl.store` 自动处理合并访问。 + - `@triton.jit` 在首次调用时将函数编译为PTX(GPU汇编),然后缓存编译后的核函数。 + +### Triton Softmax核函数 + +- Softmax是一个很好的Triton示例,因为它需要对数据进行多次遍历(最大值、减去、指数、求和、除法),并且受益于在多次遍历之间将数据保留在SRAM(共享内存)中: + +```python +@triton.jit +def softmax_kernel( + output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, + BLOCK_SIZE: tl.constexpr, +): + # 每个程序处理一行 + row_idx = tl.program_id(0) + row_start = input_ptr + row_idx * input_row_stride + + # 加载该行 + col_offsets = tl.arange(0, BLOCK_SIZE) + mask = col_offsets < n_cols + row = tl.load(row_start + col_offsets, mask=mask, other=-float('inf')) + + # Softmax:为数值稳定性取最大值,然后exp,然后归一化 + row_max = tl.max(row, axis=0) + numerator = tl.exp(row - row_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + # 存储结果 + output_start = output_ptr + row_idx * output_row_stride + tl.store(output_start + col_offsets, softmax_output, mask=mask) +``` + +- 在PyTorch中,`F.softmax(x, dim=-1)` 启动3个独立核函数(最大值、指数-求和、除法),每个都从全局内存读取和写入。Triton版本在一个核函数内完成所有操作,将数据保留在寄存器/SRAM中。这种**核函数融合**就是自定义Triton核函数可以比PyTorch内置操作快2-4倍的原因。 + +### Triton自动调优 + +- Triton支持**自动调优**:尝试多种配置并选择最快的: + +```python +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}), + ], + key=['M', 'N', 'K'], # 当这些变化时重新调优 +) +@triton.jit +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, ...): + ... +``` + +- Triton在实际硬件上对每种配置进行基准测试并选择最快者。最优瓦片大小取决于GPU架构、矩阵维度和内存布局——自动调优无需手动实验即可找到它们。 + +### Triton vs CUDA:何时使用 + +| | Triton | CUDA C | +|--|--------|--------| +| 语言 | Python | C/C++ | +| 抽象层级 | 块级 | 线程级 | +| 开发速度 | 快(每核函数10-50行) | 慢(100-500行) | +| 性能天花板 | 手工调优CUDA的约80-95% | 100%(完全硬件控制) | +| 共享内存 | 自动 | 手动 | +| 合并 | 自动 | 手动 | +| 线程束级原语 | 有限 | 完整(shuffle、vote等) | +| 硬件支持 | 仅NVIDIA(AMD实验性) | 仅NVIDIA | + +- **使用Triton**对于:融合核函数、自定义注意力模式、激活函数、大多数ML研究核函数需求。 +- **使用CUDA C**对于:最高性能(最后5-20%)、线程束级原语、复杂数据相关并行性、当Triton无法表达你的模式。 + +## 案例研究:Flash Attention + +- **Flash Attention**(Dao等人,2022)是近年来最具影响力的自定义核函数。它以 $O(n)$ 内存而非 $O(n^2)$ 计算注意力,使得更长的序列成为可能。 + +- **问题**:标准注意力计算 $\\text{softmax}(QK^T / \\sqrt{d}) \\cdot V$。$QK^T$ 矩阵是 $n \\times n$,其中 $n$ 是序列长度。对于 $n = 128K$,此矩阵为 $128K \\times 128K \\times 4$ 字节 = 64 GB。它无法放入GPU内存。 + +- **关键洞察**:你不需要具体化完整的 $n \\times n$ 矩阵。按**瓦片**计算注意力:加载一组 $Q$、一组 $K$,计算它们的部分注意力得分,累加,然后移动到下一个块。$n \\times n$ 矩阵从未完全具体化——每次只有一块存在于SRAM中。 + +- **在线softmax**:棘手的部分是softmax,它需要知道整个行上的最大值(为数值稳定性)。Flash Attention使用**在线softmax**技巧:维护一个运行中的最大值,当发现新的最大值时重新缩放先前计算的值。这允许softmax以增量方式逐块计算。 + +- 算法: + +``` +对于每个Q行块: + 对于每个K列块: + 1. 将Q_block从HBM加载到SRAM + 2. 将K_block从HBM加载到SRAM + 3. 计算S_block = Q_block @ K_block.T(在SRAM中) + 4. 更新运行中最大值,重新缩放先前结果 + 5. 计算exp(S_block - 运行中最大值) + 6. 更新运行中求和和输出累加器 + 加载V_block并计算最终输出 + 将输出块写回HBM +``` + +- **为什么它快**:内循环完全在SRAM(共享内存)中操作。全局内存(HBM)仅用于加载Q、K、V块和写入最终输出。数据重用因子与SRAM大小成正比,而SRAM比HBM快约100倍。 + +- Flash Attention在Triton和CUDA C中都有实现。CUDA版本更快(效率高约10%),但Triton版本更具可读性和可修改性,这对研究新的注意力变体很重要。 + +## TPU架构 + +- **TPU**(张量处理单元)是Google的自定义ML加速器。它们采用与GPU截然不同的方法: + +- **脉动阵列**:TPU的核心计算单元是**矩阵乘法单元(MXU)**,一个128×128或256×256的脉动阵列,通过让数据流经乘加单元网格来计算矩阵乘法。数据从边缘进入并通过阵列传播,每个单元执行一次乘加并将结果传递给下一个。 + +- 与GPU(调度数千个独立线程)不同,脉动阵列是单一的确定性数据流。没有线程调度、没有线程束分歧、没有分支预测。这种简朴性使MXU在矩阵乘法方面极其能效高效。 + +- **HBM**:TPU使用与GPU相同的高带宽内存。TPU v5e每芯片16 GB HBM2e;TPU v5p每芯片95 GB HBM2e。 + +- **ICI**(芯片间互连):TPU Pod用自定义高速网络连接数百个TPU。JAX原生支持跨TPU Pod的数据并行性和模型并行性(第6章)。 + +- **BFloat16**:TPU是首个使用bfloat16的(第13章文件02)。BF16具有与float32相同的指数范围(防止训练期间溢出),尾数精度较低。这种权衡对ML是理想的,其中梯度值范围广但不需要23位精度。 + +### 编程TPU:JAX与Pallas + +- TPU通过**JAX**和**XLA**编程。你编写Python/JAX代码,`jax.jit` 将其编译为XLA HLO,XLA将HLO编译为TPU特定的指令。无需CUDA,无需C++。 + +```python +import jax +import jax.numpy as jnp + +@jax.jit +def matmul(a, b): + return jnp.dot(a, b) + +# 这将根据设备在CPU、GPU或TPU上运行 +a = jnp.ones((1024, 1024)) +b = jnp.ones((1024, 1024)) +c = matmul(a, b) +``` + +- **Pallas**是JAX的核函数编写API——JAX版的Triton。它让你编写低级核函数,XLA将其编译为GPU或TPU: + +```python +from jax.experimental import pallas as pl +import jax.numpy as jnp + +def add_kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[...] + +def add_pallas(x, y): + return pl.pallas_call( + add_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + grid=(x.shape[0] // 128,), + in_specs=[pl.BlockSpec((128,), lambda i: (i,)), + pl.BlockSpec((128,), lambda i: (i,))], + out_specs=pl.BlockSpec((128,), lambda i: (i,)), + )(x, y) +``` + +- Pallas比Triton更新且不太成熟,但它是为TPU编写自定义核函数的唯一方式(因为TPU不支持CUDA)。 + +### GPU vs TPU + +| | GPU(NVIDIA) | TPU(Google) | +|--|-------------|--------------| +| 可用性 | 任何云、本地部署 | 仅Google Cloud | +| 编程 | CUDA C、Triton、PyTorch | JAX/XLA、Pallas | +| 灵活性 | 通用计算 | 针对矩阵密集型ML优化 | +| 峰值矩阵乘法FLOPS | 非常高(张量核心) | 非常高(MXU) | +| 非矩阵乘法操作 | 好 | 较慢(通过向量单元路由,而非MXU) | +| 多芯片扩展 | NVLink(8个GPU)、InfiniBand | ICI(数千个TPU,更紧密集成) | +| 成本效率 | 有竞争力 | 大规模训练通常更便宜 | +| 生态系统 | 最大(PyTorch、TensorFlow、JAX) | 面向JAX | + +- **使用GPU**对于:大多数ML工作负载、基于PyTorch的研究、推理服务、有大量非矩阵乘法计算的工作负载。 +- **使用TPU**对于:大规模JAX训练(数千芯片)、Google Cloud上的成本敏感训练、以矩阵乘法为主的工作负载。 + +## 选择合适的工具 + +| 工作负载 | 最佳工具 | 为什么 | +|----------|---------|-------| +| ML训练(PyTorch) | NVIDIA GPU + CUDA/Triton | 最大生态系统、最佳工具链 | +| ML训练(JAX,大规模) | TPU或NVIDIA GPU | TPU在Google规模下成本低,GPU灵活 | +| 自定义融合核函数 | Triton(Python)或CUDA C | Triton开发速度快,CUDA峰值性能高 | +| JAX自定义核函数 | Pallas | TPU唯一选项,也可在GPU上工作 | +| 跨平台推理 | Vulkan(文件07)或ONNX Runtime | 运行在任何GPU供应商上 | +| 移动/边缘推理 | Metal(Apple)、Vulkan(Android)、NNAPI | 平台特定的加速器 | +| 浏览器推理 | WebGPU(文件07) | 浏览器中唯一选项 | +| 仅CPU推理 | ONNX Runtime + AVX/NEON | 无需GPU,使用SIMD(文件02-03) | +| 新型硬件 | 供应商专用SDK | 每个加速器有自己的工具链 | + +## 编程任务(使用带GPU运行时的CoLab) + +1. 编写并运行向量加法的Triton核函数。将其性能与PyTorch内置加法比较。 +```python +import triton +import triton.language as tl +import torch +import time + +@triton.jit +def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < n + x = tl.load(x_ptr + offs, mask=mask) + y = tl.load(y_ptr + offs, mask=mask) + tl.store(out_ptr + offs, x + y, mask=mask) + +n = 10_000_000 +x = torch.randn(n, device='cuda') +y = torch.randn(n, device='cuda') + +# Triton +out_triton = torch.empty_like(x) +grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),) +add_kernel[grid](x, y, out_triton, n, BLOCK=1024) + +# PyTorch +out_torch = x + y + +# 验证正确性 +assert torch.allclose(out_triton, out_torch, atol=1e-5) + +# 基准测试 +torch.cuda.synchronize() +start = time.time() +for _ in range(1000): + add_kernel[grid](x, y, out_triton, n, BLOCK=1024) +torch.cuda.synchronize() +triton_time = (time.time() - start) / 1000 + +start = time.time() +for _ in range(1000): + out_torch = x + y +torch.cuda.synchronize() +torch_time = (time.time() - start) / 1000 + +print(f"Triton: {triton_time*1000:.3f} ms") +print(f"PyTorch: {torch_time*1000:.3f} ms") +print(f"比率: {torch_time/triton_time:.2f}x") +``` + +2. 编写一个Triton融合核函数,在单次遍历中执行乘法+加法+ReLU。与三个独立的PyTorch操作比较。 +```python +import triton +import triton.language as tl +import torch +import time + +@triton.jit +def fused_mul_add_relu_kernel(x_ptr, w_ptr, b_ptr, out_ptr, n, BLOCK: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < n + x = tl.load(x_ptr + offs, mask=mask) + w = tl.load(w_ptr + offs, mask=mask) + b = tl.load(b_ptr + offs, mask=mask) + result = tl.maximum(x * w + b, 0.0) # 融合:乘法 + 加法 + relu + tl.store(out_ptr + offs, result, mask=mask) + +n = 10_000_000 +x = torch.randn(n, device='cuda') +w = torch.randn(n, device='cuda') +b = torch.randn(n, device='cuda') + +# 融合(Triton) +out_fused = torch.empty_like(x) +grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),) +fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024) + +# 未融合(PyTorch) +out_unfused = torch.relu(x * w + b) + +assert torch.allclose(out_fused, out_unfused, atol=1e-5) + +# 基准测试 +torch.cuda.synchronize() +start = time.time() +for _ in range(1000): + fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024) +torch.cuda.synchronize() +fused_time = (time.time() - start) / 1000 + +start = time.time() +for _ in range(1000): + out_unfused = torch.relu(x * w + b) +torch.cuda.synchronize() +unfused_time = (time.time() - start) / 1000 + +print(f"融合(Triton): {fused_time*1000:.3f} ms") +print(f"未融合(PyTorch): {unfused_time*1000:.3f} ms") +print(f"加速比: {unfused_time/fused_time:.2f}x") +``` + +3. 测量JAX的XLA编译器如何自动融合操作。比较带和不带jit的操作链。 +```python +import jax +import jax.numpy as jnp +import time + +def chain_ops(x): + x = x * 2.0 + x = x + 1.0 + x = jnp.maximum(x, 0.0) # ReLU + x = x / jnp.sum(x) + return x + +chain_jit = jax.jit(chain_ops) +x = jax.random.normal(jax.random.PRNGKey(0), (10000, 1000)) + +# 预热 +_ = chain_jit(x) +jax.block_until_ready(_) + +# 即时模式(每个操作是独立核函数启动) +start = time.time() +for _ in range(100): + y = chain_ops(x) +jax.block_until_ready(y) +eager_time = (time.time() - start) / 100 + +# JIT(XLA融合操作) +start = time.time() +for _ in range(100): + y = chain_jit(x) +jax.block_until_ready(y) +jit_time = (time.time() - start) / 100 + +print(f"即时: {eager_time*1000:.2f} ms") +print(f"JIT: {jit_time*1000:.2f} ms") +print(f"加速比: {eager_time/jit_time:.1f}x(XLA将4个操作融合为1个核函数)") +``` diff --git a/chapter 16: SIMD and GPU programming/06. RISC-V and embedded systems.md b/chapter 16: SIMD and GPU programming/06. RISC-V and embedded systems.md new file mode 100644 index 0000000..13c32ee --- /dev/null +++ b/chapter 16: SIMD and GPU programming/06. RISC-V and embedded systems.md @@ -0,0 +1,428 @@ +# RISC-V与嵌入式系统 + +*RISC-V是正在重塑芯片行业的开源指令集架构。本文涵盖RISC-V哲学、V向量扩展、嵌入式ML推理、微控制器上的TinyML、AI加速器中的RISC-V以及边缘部署约束* + +- 我们之前讨论的每一种芯片架构(x86、ARM)都需要**许可**。Intel和AMD为x86付费。Apple、Qualcomm以及每一家智能手机厂商每年向ARM支付数十亿美元。**RISC-V**则不同:它是一个开放标准。任何人都可以设计、制造和销售RISC-V芯片,无需向任何人支付版税。这正在改变芯片设计的经济性,特别是对于AI。 + +## RISC-V哲学 + +- RISC-V(发音为"risk five")于2010年在加州大学伯克利分校创建,作为一个简洁、现代的RISC指令集。关键原则: + + - **开放标准**:ISA规范免费提供。你可以在没有许可费、NDA或法律协议的情况下构建RISC-V CPU。这就像Linux之于操作系统——任何人都可以使用、修改和在此基础上构建。 + + - **模块化设计**:基础ISA(RV32I或RV64I)是最小的——仅47条指令。其他一切都是可选的**扩展**:M(乘法/除法)、A(原子操作)、F/D(浮点)、C(压缩指令)、V(向量处理)。你只选择需要的,保持芯片小巧高效。 + + - **无遗留包袱**:x86背负着45年的向后兼容性。ARM背负着35年。RISC-V从零开始,融入了从两者中吸取的经验教训。没有仅为与1980年代软件兼容而存在的晦涩指令。 + +- **谁在使用RISC-V**:SiFive(通用核心)、阿里巴巴(玄铁服务器核心)、西部数据(存储控制器,已出货数十亿)、乐鑫(ESP32-C3,流行IoT芯片),以及数十家使用RISC-V作为管理其自定义计算单元的控制处理器的AI加速器初创公司。 + +## RISC-V基础架构 + +- 基础整数ISA(RV64I用于64位)具有: + - **32个通用寄存器**(x0-x31,每个64位)。x0硬连接为零(用于在没有特殊指令的情况下实现常见模式)。 + - **固定32位指令宽度**(C扩展为代码密度添加了16位压缩指令)。 + - **加载-存储架构**:与ARM一样,算术仅操作寄存器。内存访问通过显式加载/存储指令进行。 + +``` +# RISC-V汇编(感受风格——你将使用C/C++) +add x3, x1, x2 # x3 = x1 + x2 +lw x4, 0(x5) # 从x5中的地址加载字 +sw x4, 8(x5) # 存储字到地址 x5 + 8 +beq x1, x2, label # 如果x1 == x2则分支 +``` + +- ISA的简洁性使RISC-V核心小巧且能效高。最小的RV32I核心可以用约10,000个门实现(ARM Cortex-M0约为12,000)。这对于每一毫瓦和每一平方毫米硅片都至关重要的嵌入式系统很重要。 + +## V扩展:RISC-V向量处理 + +- **V扩展**(RVV)为RISC-V添加了可伸缩向量处理,类似于ARM SVE。向量寄存器具有可配置长度(VLEN),由硬件指定(128到65,536位)。代码编写为**向量长度无关**:无需重新编译可在任何VLEN上工作。 + +```c +#include + +// 使用RVV内联函数进行向量加法 +void vadd_rvv(const float* a, const float* b, float* c, int n) { + while (n > 0) { + // vsetvl:设置向量长度——处理 min(n, 硬件最大值) 个元素 + size_t vl = __riscv_vsetvl_e32m1(n); + + // 加载vl个元素 + vfloat32m1_t va = __riscv_vle32_v_f32m1(a, vl); + vfloat32m1_t vb = __riscv_vle32_v_f32m1(b, vl); + + // 相加 + vfloat32m1_t vc = __riscv_vfadd_vv_f32m1(va, vb, vl); + + // 存储 + __riscv_vse32_v_f32m1(c, vc, vl); + + // 前进指针 + a += vl; b += vl; c += vl; n -= vl; + } +} +``` + +- **`vsetvl`** 是关键指令。它告诉硬件"我想处理这么多元素",硬件回应"我可以处理这么多"(受VLEN限制)。循环自动适应任何向量宽度,无需标量清理(最后一次迭代只处理较少的元素)。 + +- **LMUL**(长度乘数):RVV可以将多个向量寄存器分组在一起(m1、m2、m4、m8),以每条指令处理更多元素,代价是可用的寄存器更少。`m1` 每个向量操作数使用一个寄存器;`m8` 使用八个,处理8倍元素,但只留下4个寄存器组可用。 + +- 与x86 AVX(固定256/512位)和ARM NEON(固定128位)相比,RVV的可伸缩性对于多样化硬件是一个重要优势:相同代码在小型嵌入式核心(VLEN=128)和高性能服务器核心(VLEN=1024+)上运行。 + +## 嵌入式ML:TinyML + +- **TinyML**是微控制器上的机器学习——具有千字节RAM、兆赫级CPU和毫瓦功率预算的设备。想想:检测关键词的传感器("Hey Siri")、分类手势的加速度计、或计数人数的小型摄像头,所有这些都在一个售价0.50美元、无需互联网连接的芯片上运行。 + +- 约束条件极其严苛: + +| 资源 | 服务器GPU | 智能手机 | 微控制器 | +|--------|----------|------------|-----------------| +| RAM | 80 GB | 6 GB | 256 KB | +| 存储 | TB | 128 GB | 1 MB | +| 计算能力 | 1000 TFLOPS | 10 TFLOPS | 0.001 TFLOPS | +| 功耗 | 700 W | 5 W | 0.001 W | +| 成本 | $30,000 | $500 | $1 | + +- 适合服务器GPU的模型($O(10^{10})$ 参数)无法放入微控制器。TinyML模型有 $O(10^4)$–$O(10^6)$ 参数,并使用INT8甚至INT4量化。 + +### TensorFlow Lite Micro(TFLM) + +- **TFLM**是Google的微控制器推理框架。它运行量化的TensorFlow Lite模型,无需动态内存分配、无需操作系统,二进制占用约20 KB。 + +```cpp +// 微控制器上的TinyML推理(简化版) +#include "tensorflow/lite/micro/micro_interpreter.h" +#include "tensorflow/lite/micro/micro_mutable_op_resolver.h" + +// 模型编译为C数组(const unsigned char model_data[]) +const tflite::Model* model = tflite::GetModel(model_data); + +// 分配固定内存缓冲区(无malloc!) +constexpr int kArenaSize = 10 * 1024; // 10 KB +uint8_t tensor_arena[kArenaSize]; + +// 设置解释器 +tflite::MicroInterpreter interpreter(model, resolver, tensor_arena, kArenaSize); +interpreter.AllocateTensors(); + +// 设置输入 +float* input = interpreter.input(0)->data.f; +input[0] = sensor_reading; + +// 运行推理 +interpreter.Invoke(); + +// 读取输出 +float* output = interpreter.output(0)->data.f; +if (output[0] > 0.8f) { + trigger_alert(); +} +``` + +- **此代码中的关键约束**: + - `tensor_arena` 是静态分配的——没有 `malloc`,没有堆。嵌入式系统通常没有动态内存分配器。 + - 模型是一个 `const` 字节数组,存储在闪存(ROM)中,而非从文件系统加载。 + - 整个框架+模型+运行时适合几十KB。 + +### 边缘模型优化 + +- 让模型在微控制器上运行需要激进优化: + + - **量化**(第18章):将float32权重转换为INT8(小4倍,在纯整数硬件上快2-4倍)。训练后量化简单;量化感知训练保留更多精度。 + + - **剪枝**:移除接近零的权重。结构化剪枝(移除整个通道/头)比非结构化剪枝(随机零)对硬件更友好,因为它减少实际计算,而不仅是存储。 + + - **知识蒸馏**(第6章):训练一个小型"学生"模型来模仿大型"教师"模型。学生模型比从头训练获得更高精度,因为它从教师模型的软预测中学习。 + + - **神经架构搜索(NAS)**:自动搜索适合硬件预算(延迟、内存、功耗)的高效架构。**MicroNets**和**MCUNet**为特定微控制器寻找优化架构。 + + - **算子融合**:将卷积+批归一化+ReLU组合成单个融合操作,消除中间内存写入(与GPU核函数融合同一原则,但在只有256 KB RAM时更加关键)。 + +## RISC-V在AI加速器中的应用 + +- 许多AI加速器初创公司使用RISC-V并非直接运行ML模型,而是作为管理自定义计算单元的**控制处理器**: + +``` +┌─────────────────────────────────────────┐ +│ AI加速器 │ +│ │ +│ ┌──────────┐ ┌──────────────────┐ │ +│ │ RISC-V │───→│ 自定义矩阵 │ │ +│ │ 控制 │ │ 乘法单元 │ │ +│ │ 核心 │ │ (脉动阵列、 │ │ +│ │ │ │ 自定义数据流) │ │ +│ └──────────┘ └──────────────────┘ │ +│ │ │ │ +│ ▼ ▼ │ +│ ┌──────────┐ ┌──────────────────┐ │ +│ │ 内存 │ │ 片上SRAM │ │ +│ │ 控制 │ │ (激活缓冲) │ │ +│ │ │ │ │ │ +│ └──────────┘ └──────────────────┘ │ +└─────────────────────────────────────────┘ +``` + +- RISC-V核心处理:从外部内存加载模型权重、调度层执行、管理计算单元之间的数据流以及与主机通信(通过PCIe、USB或SPI)。繁重计算(矩阵乘法、卷积)由自定义硬件完成,而非RISC-V核心。 + +- **为什么用RISC-V做控制**:无需许可费用(对初创公司至关重要)、可定制(添加领域特定指令)、小占用空间(控制核心不需要x86的复杂性),以及开放生态系统支持快速原型开发。 + +- **示例**:Esperanto Technologies(1000+个RISC-V核心用于ML)、Tenstorrent(RISC-V控制+自定义tensix核心)、SiFive(带向量扩展的RISC-V核心用于边缘ML)。 + +## 边缘部署约束 + +- 在边缘部署ML(设备端,而非云端)引入了云端部署不需要的约束: + +- **功耗**:电池供电的设备总功耗预算可能为100 mW。运行消耗50 mW的模型只给系统其余部分(传感器、无线电、显示器)留下50 mW。功耗感知推理调度计算以避免热降频并延长电池寿命。 + +- **延迟**:边缘推理通常必须是实时的。唤醒词检测器("Hey Siri")必须在约200 ms内响应。自动驾驶感知系统(第11章)必须在约30 ms内处理帧。到云端的网络往返(50-200 ms)对这些用例来说太慢了。 + +- **隐私**:在设备上处理数据意味着敏感数据(医学图像、语音记录、个人照片)永远不会离开设备。这在某些司法管辖区是法律要求(GDPR),在所有地方都是用户信任的要求。 + +- **连接性**:边缘设备可能间歇性或完全没有互联网连接。在火星车(第11章)、潜艇或农村农田传感器上运行的模型必须完全离线工作。 + +- **规模成本**:将ML部署到十亿部智能手机每台成本为$0(硬件已经存在)。将ML部署到十亿个IoT传感器意味着每个传感器的ML硬件预算只有几分钱。RISC-V的零许可成本在这个规模下意义重大。 + +## 编程任务(用g++或riscv64-gcc交叉编译器编译) + +1. 编写一个C程序,模拟TinyML推理流水线:静态分配模型缓冲区,运行模拟前向传播,并测量资源使用。这教授嵌入式约束(无malloc、固定内存缓冲区)。 +```cpp +// task1_tinyml_sim.cpp +// 编译:g++ -O2 -o task1 task1_tinyml_sim.cpp + +#include +#include +#include +#include + +// 模拟微控制器:固定内存缓冲区,无动态分配 +static constexpr int ARENA_SIZE = 32 * 1024; // 32 KB总RAM预算 +static uint8_t arena[ARENA_SIZE]; + +// 简单的2层MLP:784 -> 64 -> 10(类似MNIST,INT8权重) +struct TinyModel { + int8_t w1[784 * 64]; // 层1权重:50,176字节 + int8_t b1[64]; // 层1偏置 + int8_t w2[64 * 10]; // 层2权重:640字节 + int8_t b2[10]; // 层2偏置 + // 总计:约51 KB → 必须放在闪存(ROM),而非RAM +}; + +// 检查模型是否适合闪存 +void check_model_fit(int flash_kb) { + int model_bytes = sizeof(TinyModel); + std::cout << "模型大小: " << model_bytes << " 字节(" + << model_bytes / 1024 << " KB)\n"; + std::cout << "闪存: " << flash_kb << " KB → " + << (model_bytes <= flash_kb * 1024 ? "适合" : "太大") << "\n"; +} + +// 使用固定缓冲区进行激活的模拟推理 +void mock_inference(const int8_t* input, int8_t* output) { + // 激活值放在缓冲区(RAM)中,而非动态分配 + int8_t* act1 = (int8_t*)arena; // 层1输出64字节 + int8_t* act2 = (int8_t*)(arena + 64); // 层2输出10字节 + + // 层1:简化版矩阵乘法(不是真正的量化矩阵乘法,仅结构演示) + for (int j = 0; j < 64; j++) { + int32_t sum = 0; // 用int32累加避免溢出 + for (int i = 0; i < 784; i++) { + sum += (int32_t)input[i] * 1; // 模拟:权重=1 + } + act1[j] = (int8_t)std::max(-128, std::min(127, sum / 784)); // 量化回 + act1[j] = act1[j] > 0 ? act1[j] : 0; // ReLU + } + + // 层2 + for (int j = 0; j < 10; j++) { + int32_t sum = 0; + for (int i = 0; i < 64; i++) { + sum += (int32_t)act1[i] * 1; + } + act2[j] = (int8_t)std::max(-128, std::min(127, sum / 64)); + } + + std::memcpy(output, act2, 10); +} + +int main() { + std::cout << "=== TinyML资源预算 ===\n"; + std::cout << "缓冲区(RAM): " << ARENA_SIZE << " 字节(" + << ARENA_SIZE / 1024 << " KB)\n"; + check_model_fit(256); // 典型MCU闪存 + + // 激活内存使用 + int activation_bytes = 64 + 10; // 层1 + 层2输出 + std::cout << "激活内存: " << activation_bytes + << " 字节 / " << ARENA_SIZE << " 可用\n\n"; + + // 基准测试推理 + int8_t input[784]; + int8_t output[10]; + std::memset(input, 1, 784); + + auto start = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < 10000; i++) { + mock_inference(input, output); + } + auto end = std::chrono::high_resolution_clock::now(); + double us = std::chrono::duration(end - start).count() / 10000; + + std::cout << "推理延迟: " << us << " us\n"; + std::cout << "在160 MHz MCU(约6.25 ns/周期)下:约" + << (int)(us * 160) << " 周期\n"; + + std::cout << "输出logits: "; + for (int i = 0; i < 10; i++) std::cout << (int)output[i] << " "; + std::cout << "\n"; + + return 0; +} +``` + +2. 编写一个C++程序,将float32权重量化为INT8,并测量压缩比和量化误差。 +```cpp +// task2_quantise.cpp +// 编译:g++ -O3 -o task2 task2_quantise.cpp + +#include +#include +#include +#include +#include + +// 对称量化:将浮点范围 [-max, +max] 映射到 [-127, +127] +void quantise_symmetric(const float* input, int8_t* output, int n, float& scale) { + float max_val = 0.0f; + for (int i = 0; i < n; i++) { + max_val = std::max(max_val, std::abs(input[i])); + } + scale = max_val / 127.0f; + for (int i = 0; i < n; i++) { + float scaled = input[i] / scale; + output[i] = (int8_t)std::max(-127.0f, std::min(127.0f, std::round(scaled))); + } +} + +// 反量化:INT8转回float +void dequantise(const int8_t* input, float* output, int n, float scale) { + for (int i = 0; i < n; i++) { + output[i] = (float)input[i] * scale; + } +} + +int main() { + const int N = 100000; + + // 模拟随机权重(大致正态分布) + std::vector weights(N); + for (int i = 0; i < N; i++) { + // 简单的伪随机正态值 + float u1 = (float)(i * 7 % 997 + 1) / 998.0f; + float u2 = (float)(i * 13 % 991 + 1) / 992.0f; + weights[i] = std::sqrt(-2.0f * std::log(u1)) * std::cos(6.2832f * u2) * 0.1f; + } + + // 量化 + std::vector quantised(N); + float scale; + quantise_symmetric(weights.data(), quantised.data(), N, scale); + + // 反量化并测量误差 + std::vector reconstructed(N); + dequantise(quantised.data(), reconstructed.data(), N, scale); + + float max_error = 0.0f, total_error = 0.0f; + for (int i = 0; i < N; i++) { + float err = std::abs(weights[i] - reconstructed[i]); + max_error = std::max(max_error, err); + total_error += err; + } + + std::cout << "=== 量化结果 ===\n"; + std::cout << "原始: " << N * 4 << " 字节(float32)\n"; + std::cout << "量化: " << N * 1 << " 字节(int8)+ 4 字节(缩放因子)\n"; + std::cout << "压缩比: " << 4.0f << "x\n"; + std::cout << "缩放因子: " << scale << "\n"; + std::cout << "平均绝对误差: " << total_error / N << "\n"; + std::cout << "最大绝对误差: " << max_error << "\n"; + std::cout << "最大绝对误差/缩放因子: " << max_error / scale + << "(应 <= 0.5 量化级别)\n"; + + return 0; +} +``` + +3. 编写一个C++程序,执行INT8矩阵乘法(INT32累加)——这是在嵌入式ML加速器上运行的实际计算。 +```cpp +// task3_int8_matmul.cpp +// 编译:g++ -O3 -o task3 task3_int8_matmul.cpp + +#include +#include +#include +#include + +// INT8矩阵乘法(INT32累加)——张量核心和MCU加速器的实际工作方式 +void matmul_int8(const int8_t* A, const int8_t* B, int32_t* C, + int M, int N, int K) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + int32_t sum = 0; + for (int k = 0; k < K; k++) { + sum += (int32_t)A[i * K + k] * (int32_t)B[k * N + j]; + } + C[i * N + j] = sum; + } + } +} + +// 用于比较的Float32矩阵乘法 +void matmul_f32(const float* A, const float* B, float* C, + int M, int N, int K) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + float sum = 0.0f; + for (int k = 0; k < K; k++) { + sum += A[i * K + k] * B[k * N + j]; + } + C[i * N + j] = sum; + } + } +} + +int main() { + const int M = 128, N = 128, K = 128; + + std::vector A_i8(M * K, 1), B_i8(K * N, 1); + std::vector C_i32(M * N); + + std::vector A_f32(M * K, 1.0f), B_f32(K * N, 1.0f); + std::vector C_f32(M * N); + + // 基准测试INT8 + auto start = std::chrono::high_resolution_clock::now(); + for (int t = 0; t < 100; t++) { + matmul_int8(A_i8.data(), B_i8.data(), C_i32.data(), M, N, K); + } + auto end = std::chrono::high_resolution_clock::now(); + double i8_ms = std::chrono::duration(end - start).count() / 100; + + // 基准测试FP32 + start = std::chrono::high_resolution_clock::now(); + for (int t = 0; t < 100; t++) { + matmul_f32(A_f32.data(), B_f32.data(), C_f32.data(), M, N, K); + } + end = std::chrono::high_resolution_clock::now(); + double f32_ms = std::chrono::duration(end - start).count() / 100; + + double gops_i8 = 2.0 * M * N * K / i8_ms / 1e6; + double gflops_f32 = 2.0 * M * N * K / f32_ms / 1e6; + + std::cout << "INT8矩阵乘法: " << i8_ms << " ms(" << gops_i8 << " GOPS)\n"; + std::cout << "FP32矩阵乘法: " << f32_ms << " ms(" << gflops_f32 << " GFLOPS)\n"; + std::cout << "INT8加速比: " << f32_ms / i8_ms << "x\n"; + std::cout << "内存: INT8 = " << M*K + K*N << " 字节 vs FP32 = " + << (M*K + K*N) * 4 << " 字节(小4倍)\n"; + + return 0; +} +``` diff --git a/chapter 16: SIMD and GPU programming/07. vulkan compute and cross-platform GPU.md b/chapter 16: SIMD and GPU programming/07. vulkan compute and cross-platform GPU.md new file mode 100644 index 0000000..d2948b5 --- /dev/null +++ b/chapter 16: SIMD and GPU programming/07. vulkan compute and cross-platform GPU.md @@ -0,0 +1,668 @@ +# Vulkan Compute 与跨平台 GPU + +*Vulkan 是唯一能在所有主要平台上运行的 GPU 计算 API:NVIDIA、AMD、Intel、Apple(通过 MoltenVK)、Android,甚至浏览器(通过 WebGPU)。本文涵盖 Vulkan 架构、计算管线、使用 GLSL 编写计算着色器、GPU 计算程序的完整 C++ 设置、共享内存与同步、用于浏览器的 WebGPU,以及实际的机器学习推理示例。* + +- CUDA 在 NVIDIA 硬件上主导着 ML 训练。但并非每个部署目标都有 NVIDIA GPU。移动应用运行在 Qualcomm Adreno 或 ARM Mali GPU 上。Web 应用运行在浏览器中。游戏引擎需要同时支持 AMD、Intel 和 NVIDIA。对于所有这些场景,**Vulkan** 就是答案。 + +- Vulkan 很冗长——一个"hello world"计算程序大约有 300 行 C++ 代码。但这种冗长是 **显式控制** 的代价:你需要自己管理每一个 GPU 资源(内存、管线、命令缓冲区)。这种控制带来了最大性能和可移植性,代价是开发速度。 + +## Vulkan 架构概述 + +- Vulkan 是由 Khronos Group(OpenGL 背后的同一组织)创建的低级 GPU API。与 CUDA(它隐藏了 GPU 资源管理)不同,Vulkan 要求你显式地管理: + + - **实例与设备**:创建 Vulkan 实例,枚举可用 GPU,并选择一个。 + - **内存**:显式分配 GPU 内存,指定内存类型(设备本地内存用于速度,主机可见内存用于 CPU 访问)。 + - **缓冲区**:创建引用已分配内存的缓冲区对象。 + - **描述符集**:将缓冲区绑定到着色器输入(类似于计算着色器的函数参数)。 + - **计算管线**:编译着色器并创建管线对象。 + - **命令缓冲区**:记录一系列 GPU 命令(绑定管线、绑定描述符、调度计算)。 + - **队列提交**:将命令缓冲区提交给 GPU 执行。 + - **同步**:使用栅栏和屏障确保正确的执行顺序。 + +- 这与 CUDA 的 `cudaMalloc` + 内核启动模型截然不同。在 CUDA 中,驱动程序在幕后处理大部分工作。在 Vulkan 中,你需要自己做这一切。 + +### 为什么如此冗长? + +- Vulkan 的显式性存在有两方面原因: + + 1. **驱动简化**:OpenGL 驱动极其复杂(它们必须猜测应用程序的意图并进行相应优化)。Vulkan 将该责任转移给应用程序,使驱动更精简、更可预测,并且更容易在各厂商间正确实现。 + + 2. **性能**:对内存布局、同步和命令批处理的显式控制使应用程序能够做出最优决策。在 CUDA 中,驱动可能会插入不必要的同步。在 Vulkan 中,你只在需要时才进行同步。 + +## GLSL 中的计算着色器 + +- **计算着色器** 是在 GPU 上运行的程序,类似于 CUDA 内核。它使用 **GLSL**(OpenGL 着色语言)编写,并编译为 **SPIR-V** 字节码(一种可移植的二进制格式)。 + +### 向量加法 + +```glsl +// add.comp — 编译命令: glslangValidator -V add.comp -o add.spv +#version 450 + +// 工作组大小:每个工作组有 256 个调用(= CUDA 中每块的线程数) +layout(local_size_x = 256) in; + +// 缓冲区绑定(类似于内核参数) +layout(set = 0, binding = 0) buffer InputA { float a[]; }; +layout(set = 0, binding = 1) buffer InputB { float b[]; }; +layout(set = 0, binding = 2) buffer Output { float c[]; }; + +// 推送常量:小的统一数据(类似于内核参数) +layout(push_constant) uniform PushConstants { + uint n; // 元素数量 +}; + +void main() { + uint idx = gl_GlobalInvocationID.x; // 全局线程索引 + if (idx < n) { + c[idx] = a[idx] + b[idx]; + } +} +``` + +- **与 CUDA 概念的映射**: + +| Vulkan | CUDA | 含义 | +|--------|------|------| +| 工作组 (Workgroup) | 块 (Block) | 可以共享内存的线程组 | +| 调用 (Invocation) | 线程 (Thread) | 单个执行单元 | +| `gl_GlobalInvocationID` | `blockIdx * blockDim + threadIdx` | 全局线程索引 | +| `gl_LocalInvocationID` | `threadIdx` | 工作组内的线程索引 | +| `gl_WorkGroupID` | `blockIdx` | 工作组索引 | +| `local_size_x` | `blockDim.x` | 每工作组的线程数 | +| 存储缓冲区 | 全局内存 | 可读写的 GPU 内存 | +| 共享内存 (`shared`) | `__shared__` | 每工作组的高速内存 | +| 推送常量 | 内核参数 | 小的统一数据 | + +### 使用共享内存的 ReLU + +```glsl +// relu_shared.comp +#version 450 + +layout(local_size_x = 256) in; + +layout(set = 0, binding = 0) buffer Input { float input_data[]; }; +layout(set = 0, binding = 1) buffer Output { float output_data[]; }; + +layout(push_constant) uniform PushConstants { uint n; }; + +// 共享内存(等同于 CUDA 的 __shared__) +shared float tile[256]; + +void main() { + uint gid = gl_GlobalInvocationID.x; + uint lid = gl_LocalInvocationID.x; + + // 加载到共享内存 + if (gid < n) { + tile[lid] = input_data[gid]; + } + + // 屏障:等待工作组中所有调用完成加载 + barrier(); // 等同于 CUDA 的 __syncthreads() + + // 计算 ReLU + if (gid < n) { + output_data[gid] = max(tile[lid], 0.0); + } +} +``` + +- 对于 ReLU,共享内存并非严格必要(该操作是按元素进行的)。但这演示了基本模式:加载到共享内存 → 屏障 → 计算 → 存储。对于需要相邻线程数据的操作(卷积、归约、softmax),共享内存是必不可少的。 + +### 并行归约(求和) + +```glsl +// reduce_sum.comp +#version 450 + +layout(local_size_x = 256) in; + +layout(set = 0, binding = 0) buffer Input { float input_data[]; }; +layout(set = 0, binding = 1) buffer Output { float partial_sums[]; }; + +layout(push_constant) uniform PushConstants { uint n; }; + +shared float sdata[256]; + +void main() { + uint gid = gl_GlobalInvocationID.x; + uint lid = gl_LocalInvocationID.x; + uint wgid = gl_WorkGroupID.x; + + // 加载到共享内存 + sdata[lid] = (gid < n) ? input_data[gid] : 0.0; + barrier(); + + // 工作组内的树形归约 + for (uint stride = 128; stride > 0; stride >>= 1) { + if (lid < stride) { + sdata[lid] += sdata[lid + stride]; + } + barrier(); + } + + // 线程 0 写入工作组的局部和 + if (lid == 0) { + partial_sums[wgid] = sdata[0]; + } +} +``` + +- 这是经典的并行归约模式(与 CUDA 相同)。每个工作组产生一个局部和。第二次调度将这些局部和归约为最终结果。树形归约每一步将活跃线程减半:256 → 128 → 64 → ... → 1。 + +### 使用分块的矩阵乘法 + +```glsl +// matmul_tiled.comp +#version 450 + +#define TILE_SIZE 16 + +layout(local_size_x = TILE_SIZE, local_size_y = TILE_SIZE) in; + +layout(set = 0, binding = 0) buffer MatA { float A[]; }; +layout(set = 0, binding = 1) buffer MatB { float B[]; }; +layout(set = 0, binding = 2) buffer MatC { float C[]; }; + +layout(push_constant) uniform PushConstants { + uint M, N, K; +}; + +shared float tileA[TILE_SIZE][TILE_SIZE]; +shared float tileB[TILE_SIZE][TILE_SIZE]; + +void main() { + uint row = gl_GlobalInvocationID.y; + uint col = gl_GlobalInvocationID.x; + uint lr = gl_LocalInvocationID.y; + uint lc = gl_LocalInvocationID.x; + + float sum = 0.0; + + for (uint t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) { + // 将 A 和 B 的分块加载到共享内存中 + uint aCol = t * TILE_SIZE + lc; + uint bRow = t * TILE_SIZE + lr; + + tileA[lr][lc] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0; + tileB[lr][lc] = (bRow < K && col < N) ? B[bRow * N + col] : 0.0; + + barrier(); + + // 计算部分点积 + for (uint k = 0; k < TILE_SIZE; k++) { + sum += tileA[lr][k] * tileB[k][lc]; + } + + barrier(); + } + + if (row < M && col < N) { + C[row * N + col] = sum; + } +} +``` + +- 这与 CUDA 版本(文件 04)中的分块算法相同,只是用了 GLSL 语法。概念完全一样:将分块加载到共享内存,屏障,计算,屏障,重复。 + +## C++ Vulkan 设置 + +- 计算着色器是简单的部分。困难的部分是创建 Vulkan 实例、分配内存、绑定缓冲区和提交命令的 C++ 样板代码。以下是完整管线的精简版本: + +```cpp +// vulkan_compute.cpp — 一个最小但完整的 Vulkan 计算示例 +// 编译命令: g++ -O3 -o vulkan_compute vulkan_compute.cpp -lvulkan +// 要求: 已安装 Vulkan SDK,已从 add.comp 编译 add.spv + +#include +#include +#include +#include +#include + +// 辅助函数:读取 SPIR-V 文件 +std::vector readSPIRV(const std::string& filename) { + std::ifstream file(filename, std::ios::ate | std::ios::binary); + size_t fileSize = file.tellg(); + std::vector buffer(fileSize / sizeof(uint32_t)); + file.seekg(0); + file.read(reinterpret_cast(buffer.data()), fileSize); + return buffer; +} + +int main() { + const uint32_t N = 1024; + const size_t bufferSize = N * sizeof(float); + + // ========== 1. 创建 Vulkan 实例 ========== + VkApplicationInfo appInfo{}; + appInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + appInfo.apiVersion = VK_API_VERSION_1_2; + + VkInstanceCreateInfo instanceInfo{}; + instanceInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + instanceInfo.pApplicationInfo = &appInfo; + + VkInstance instance; + vkCreateInstance(&instanceInfo, nullptr, &instance); + + // ========== 2. 选择物理设备 (GPU) ========== + uint32_t deviceCount = 0; + vkEnumeratePhysicalDevices(instance, &deviceCount, nullptr); + std::vector devices(deviceCount); + vkEnumeratePhysicalDevices(instance, &deviceCount, devices.data()); + VkPhysicalDevice physicalDevice = devices[0]; // 使用第一个 GPU + + // 打印 GPU 名称 + VkPhysicalDeviceProperties props; + vkGetPhysicalDeviceProperties(physicalDevice, &props); + std::cout << "使用的 GPU: " << props.deviceName << "\n"; + + // ========== 3. 查找计算队列族 ========== + uint32_t queueFamilyCount = 0; + vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, &queueFamilyCount, nullptr); + std::vector queueFamilies(queueFamilyCount); + vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, &queueFamilyCount, queueFamilies.data()); + + uint32_t computeFamily = 0; + for (uint32_t i = 0; i < queueFamilyCount; i++) { + if (queueFamilies[i].queueFlags & VK_QUEUE_COMPUTE_BIT) { + computeFamily = i; + break; + } + } + + // ========== 4. 创建逻辑设备和队列 ========== + float queuePriority = 1.0f; + VkDeviceQueueCreateInfo queueInfo{}; + queueInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; + queueInfo.queueFamilyIndex = computeFamily; + queueInfo.queueCount = 1; + queueInfo.pQueuePriorities = &queuePriority; + + VkDeviceCreateInfo deviceInfo{}; + deviceInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; + deviceInfo.queueCreateInfoCount = 1; + deviceInfo.pQueueCreateInfos = &queueInfo; + + VkDevice device; + vkCreateDevice(physicalDevice, &deviceInfo, nullptr, &device); + + VkQueue computeQueue; + vkGetDeviceQueue(device, computeFamily, 0, &computeQueue); + + // ========== 5. 分配缓冲区 (A, B, C) ========== + // 为简洁起见,这里使用主机可见内存(较慢但更简单) + auto createBuffer = [&](VkBuffer& buffer, VkDeviceMemory& memory) { + VkBufferCreateInfo bufInfo{}; + bufInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; + bufInfo.size = bufferSize; + bufInfo.usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; + vkCreateBuffer(device, &bufInfo, nullptr, &buffer); + + VkMemoryRequirements memReqs; + vkGetBufferMemoryRequirements(device, buffer, &memReqs); + + // 查找主机可见的内存类型 + VkPhysicalDeviceMemoryProperties memProps; + vkGetPhysicalDeviceMemoryProperties(physicalDevice, &memProps); + uint32_t memType = 0; + for (uint32_t i = 0; i < memProps.memoryTypeCount; i++) { + if ((memReqs.memoryTypeBits & (1 << i)) && + (memProps.memoryTypes[i].propertyFlags & + (VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT))) { + memType = i; + break; + } + } + + VkMemoryAllocateInfo allocInfo{}; + allocInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; + allocInfo.allocationSize = memReqs.size; + allocInfo.memoryTypeIndex = memType; + vkAllocateMemory(device, &allocInfo, nullptr, &memory); + vkBindBufferMemory(device, buffer, memory, 0); + }; + + VkBuffer bufA, bufB, bufC; + VkDeviceMemory memA, memB, memC; + createBuffer(bufA, memA); + createBuffer(bufB, memB); + createBuffer(bufC, memC); + + // ========== 6. 填充输入缓冲区 ========== + float* ptrA; + vkMapMemory(device, memA, 0, bufferSize, 0, (void**)&ptrA); + for (uint32_t i = 0; i < N; i++) ptrA[i] = 1.0f; + vkUnmapMemory(device, memA); + + float* ptrB; + vkMapMemory(device, memB, 0, bufferSize, 0, (void**)&ptrB); + for (uint32_t i = 0; i < N; i++) ptrB[i] = 2.0f; + vkUnmapMemory(device, memB); + + // ========== 7. 创建计算管线 ========== + auto spirvCode = readSPIRV("add.spv"); + VkShaderModuleCreateInfo shaderInfo{}; + shaderInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; + shaderInfo.codeSize = spirvCode.size() * sizeof(uint32_t); + shaderInfo.pCode = spirvCode.data(); + VkShaderModule shaderModule; + vkCreateShaderModule(device, &shaderInfo, nullptr, &shaderModule); + + // 描述符集布局(告诉 Vulkan 缓冲区绑定的信息) + VkDescriptorSetLayoutBinding bindings[3] = {}; + for (int i = 0; i < 3; i++) { + bindings[i].binding = i; + bindings[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + bindings[i].descriptorCount = 1; + bindings[i].stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + } + + VkDescriptorSetLayoutCreateInfo layoutInfo{}; + layoutInfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; + layoutInfo.bindingCount = 3; + layoutInfo.pBindings = bindings; + VkDescriptorSetLayout descLayout; + vkCreateDescriptorSetLayout(device, &layoutInfo, nullptr, &descLayout); + + // 推送常量范围 + VkPushConstantRange pushRange{}; + pushRange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + pushRange.offset = 0; + pushRange.size = sizeof(uint32_t); + + // 管线布局 + VkPipelineLayoutCreateInfo pipeLayoutInfo{}; + pipeLayoutInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; + pipeLayoutInfo.setLayoutCount = 1; + pipeLayoutInfo.pSetLayouts = &descLayout; + pipeLayoutInfo.pushConstantRangeCount = 1; + pipeLayoutInfo.pPushConstantRanges = &pushRange; + VkPipelineLayout pipelineLayout; + vkCreatePipelineLayout(device, &pipeLayoutInfo, nullptr, &pipelineLayout); + + // 计算管线 + VkComputePipelineCreateInfo pipeInfo{}; + pipeInfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; + pipeInfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + pipeInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; + pipeInfo.stage.module = shaderModule; + pipeInfo.stage.pName = "main"; + pipeInfo.layout = pipelineLayout; + VkPipeline pipeline; + vkCreateComputePipelines(device, VK_NULL_HANDLE, 1, &pipeInfo, nullptr, &pipeline); + + // ========== 8. 描述符集(将缓冲区绑定到着色器) ========== + VkDescriptorPoolSize poolSize{}; + poolSize.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + poolSize.descriptorCount = 3; + + VkDescriptorPoolCreateInfo poolInfo{}; + poolInfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; + poolInfo.maxSets = 1; + poolInfo.poolSizeCount = 1; + poolInfo.pPoolSizes = &poolSize; + VkDescriptorPool descPool; + vkCreateDescriptorPool(device, &poolInfo, nullptr, &descPool); + + VkDescriptorSetAllocateInfo descAllocInfo{}; + descAllocInfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; + descAllocInfo.descriptorPool = descPool; + descAllocInfo.descriptorSetCount = 1; + descAllocInfo.pSetLayouts = &descLayout; + VkDescriptorSet descSet; + vkAllocateDescriptorSets(device, &descAllocInfo, &descSet); + + // 将缓冲区引用写入描述符集 + VkDescriptorBufferInfo bufInfos[3] = { + {bufA, 0, bufferSize}, {bufB, 0, bufferSize}, {bufC, 0, bufferSize} + }; + VkWriteDescriptorSet writes[3] = {}; + for (int i = 0; i < 3; i++) { + writes[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; + writes[i].dstSet = descSet; + writes[i].dstBinding = i; + writes[i].descriptorCount = 1; + writes[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + writes[i].pBufferInfo = &bufInfos[i]; + } + vkUpdateDescriptorSets(device, 3, writes, 0, nullptr); + + // ========== 9. 记录和提交命令缓冲区 ========== + VkCommandPoolCreateInfo cmdPoolInfo{}; + cmdPoolInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; + cmdPoolInfo.queueFamilyIndex = computeFamily; + VkCommandPool cmdPool; + vkCreateCommandPool(device, &cmdPoolInfo, nullptr, &cmdPool); + + VkCommandBufferAllocateInfo cmdAllocInfo{}; + cmdAllocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; + cmdAllocInfo.commandPool = cmdPool; + cmdAllocInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; + cmdAllocInfo.commandBufferCount = 1; + VkCommandBuffer cmdBuf; + vkAllocateCommandBuffers(device, &cmdAllocInfo, &cmdBuf); + + VkCommandBufferBeginInfo beginInfo{}; + beginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; + vkBeginCommandBuffer(cmdBuf, &beginInfo); + + vkCmdBindPipeline(cmdBuf, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); + vkCmdBindDescriptorSets(cmdBuf, VK_PIPELINE_BIND_POINT_COMPUTE, + pipelineLayout, 0, 1, &descSet, 0, nullptr); + vkCmdPushConstants(cmdBuf, pipelineLayout, VK_SHADER_STAGE_COMPUTE_BIT, + 0, sizeof(uint32_t), &N); + vkCmdDispatch(cmdBuf, (N + 255) / 256, 1, 1); // 启动工作组 + + vkEndCommandBuffer(cmdBuf); + + // 提交 + VkFenceCreateInfo fenceInfo{}; + fenceInfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; + VkFence fence; + vkCreateFence(device, &fenceInfo, nullptr, &fence); + + VkSubmitInfo submitInfo{}; + submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; + submitInfo.commandBufferCount = 1; + submitInfo.pCommandBuffers = &cmdBuf; + vkQueueSubmit(computeQueue, 1, &submitInfo, fence); + vkWaitForFences(device, 1, &fence, VK_TRUE, UINT64_MAX); + + // ========== 10. 读取结果 ========== + float* ptrC; + vkMapMemory(device, memC, 0, bufferSize, 0, (void**)&ptrC); + std::cout << "结果: c[0]=" << ptrC[0] << " c[1]=" << ptrC[1] + << " (期望值 3.0)\n"; + bool correct = true; + for (uint32_t i = 0; i < N; i++) { + if (ptrC[i] != 3.0f) { correct = false; break; } + } + std::cout << (correct ? "全部正确" : "发现错误") << "\n"; + vkUnmapMemory(device, memC); + + // ========== 清理(简写) ========== + vkDestroyFence(device, fence, nullptr); + vkDestroyCommandPool(device, cmdPool, nullptr); + vkDestroyPipeline(device, pipeline, nullptr); + vkDestroyPipelineLayout(device, pipelineLayout, nullptr); + vkDestroyDescriptorPool(device, descPool, nullptr); + vkDestroyDescriptorSetLayout(device, descLayout, nullptr); + vkDestroyShaderModule(device, shaderModule, nullptr); + vkDestroyBuffer(device, bufA, nullptr); vkFreeMemory(device, memA, nullptr); + vkDestroyBuffer(device, bufB, nullptr); vkFreeMemory(device, memB, nullptr); + vkDestroyBuffer(device, bufC, nullptr); vkFreeMemory(device, memC, nullptr); + vkDestroyDevice(device, nullptr); + vkDestroyInstance(instance, nullptr); + + return 0; +} +``` + +- **是的,向量加法就需要大约 200 行代码。** 相比之下 CUDA 只需要大约 30 行。这就是显式性的代价。但请注意:每一行都有其目的。没有隐藏的驱动决策,没有隐式同步,没有意外的内存分配。你控制一切。 + +- 在实践中,你可以将这些样板代码封装到辅助库中(或使用现有的库,如 **vk-bootstrap**、用于内存分配的 **VMA**,或专注于 ML 的 Vulkan 计算库 **kompute**)。 + +## Kompute:为 ML 简化的 Vulkan + +- **Kompute** 是一个开源 C++ 库,封装了 Vulkan 用于 GPU 计算的样板代码。同样的向量加法变成: + +```cpp +#include + +int main() { + kp::Manager mgr; + + auto tensorA = mgr.tensor({1, 1, 1, 1, 1}); + auto tensorB = mgr.tensor({2, 2, 2, 2, 2}); + auto tensorC = mgr.tensor({0, 0, 0, 0, 0}); + + std::string shader = R"( + #version 450 + layout(local_size_x = 1) in; + layout(set=0, binding=0) buffer A { float a[]; }; + layout(set=0, binding=1) buffer B { float b[]; }; + layout(set=0, binding=2) buffer C { float c[]; }; + void main() { + uint i = gl_GlobalInvocationID.x; + c[i] = a[i] + b[i]; + } + )"; + + auto algorithm = mgr.algorithm({tensorA, tensorB, tensorC}, + kompute::Shader::compile_source(shader)); + + mgr.sequence() + ->record({tensorA, tensorB, tensorC}) + ->record(algorithm) + ->record({tensorC}) + ->eval(); + + // tensorC 现在包含 [3, 3, 3, 3, 3] +} +``` + +- 可读性强多了。Kompute 处理实例创建、设备选择、内存分配、描述符集和命令缓冲区管理。你只需关注着色器和数据。 + +## WebGPU:浏览器中的 GPU 计算 + +- **WebGPU** 是 WebGL 的继任者,提供从 JavaScript 访问现代 GPU 的能力。它基于 Vulkan(Linux/Android)、Metal(macOS/iOS)和 DirectX 12(Windows)构建,抽象了平台差异。 + +- WebGPU 使用 **WGSL**(WebGPU 着色语言)而非 GLSL: + +```wgsl +// add.wgsl — WebGPU 计算着色器 +@group(0) @binding(0) var a: array; +@group(0) @binding(1) var b: array; +@group(0) @binding(2) var c: array; + +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) id: vec3) { + let i = id.x; + c[i] = a[i] + b[i]; +} +``` + +- **JavaScript 设置**(精简版): + +```javascript +const adapter = await navigator.gpu.requestAdapter(); +const device = await adapter.requestDevice(); + +// 创建缓冲区 +const bufferA = device.createBuffer({ size: N * 4, usage: GPUBufferUsage.STORAGE, mappedAtCreation: true }); +new Float32Array(bufferA.getMappedRange()).fill(1.0); +bufferA.unmap(); + +// ...(B 和 C 类似) + +// 从 WGSL 着色器创建管线 +const pipeline = device.createComputePipeline({ + layout: 'auto', + compute: { module: device.createShaderModule({ code: wgslSource }), entryPoint: 'main' } +}); + +// 调度 +const encoder = device.createCommandEncoder(); +const pass = encoder.beginComputePass(); +pass.setPipeline(pipeline); +pass.setBindGroup(0, bindGroup); +pass.dispatchWorkgroups(Math.ceil(N / 256)); +pass.end(); +device.queue.submit([encoder.finish()]); +``` + +- **为什么 WebGPU 对 ML 很重要**:在浏览器中运行推理意味着没有服务器成本、没有延迟,且用户数据永远不会离开设备。像 **ONNX Runtime Web** 和 **Transformers.js** 这样的库使用 WebGPU 完全在客户端运行模型(包括小型 LLM)。 + +## 何时使用 Vulkan + +| 场景 | 使用 Vulkan? | 原因 / 替代方案 | +|------|-------------|----------------| +| ML 训练 | 否 | CUDA/Triton 在 NVIDIA 上更简单更快速 | +| NVIDIA GPU 上的推理 | 否 | TensorRT 或 CUDA 更好 | +| AMD/Intel GPU 上的推理 | **是** | 唯一跨厂商的 GPU 计算选项 | +| 移动端推理(Android) | **是** | Vulkan 是 Android 上的标准 GPU API | +| 移动端推理(iOS) | 否 | 直接使用 Metal(MoltenVK 增加开销) | +| 浏览器推理 | **WebGPU** | 基于 Vulkan/Metal/DX12 | +| 游戏引擎 + ML | **是** | 引擎已使用 Vulkan 进行渲染 | +| 跨平台库 | **是** | 一套代码支持所有 GPU 厂商 | +| 学习 GPU 编程 | 视情况而定 | CUDA 更容易上手;Vulkan 能学到更多 | + +## 编码任务(使用 g++ -lvulkan 编译,需要 Vulkan SDK) + +1. 编译并运行上面的向量加法示例。修改着色器以计算 `c[i] = a[i] * b[i] + a[i]`(融合乘加)并验证结果。 + +2. 编写一个计算着色器,使用共享内存对一行数据应用 softmax(包括最大值和求和归约步骤)。用已知值进行测试。 + +```glsl +// softmax.comp — 编译命令: glslangValidator -V softmax.comp -o softmax.spv +#version 450 + +#define WG_SIZE 256 + +layout(local_size_x = WG_SIZE) in; + +layout(set = 0, binding = 0) buffer Input { float input_data[]; }; +layout(set = 0, binding = 1) buffer Output { float output_data[]; }; + +layout(push_constant) uniform PC { uint n; }; + +shared float sdata[WG_SIZE]; + +void main() { + uint gid = gl_GlobalInvocationID.x; + uint lid = gl_LocalInvocationID.x; + + // 步骤 1:找最大值(数值稳定性) + sdata[lid] = (gid < n) ? input_data[gid] : -1e30; + barrier(); + for (uint s = WG_SIZE / 2; s > 0; s >>= 1) { + if (lid < s) sdata[lid] = max(sdata[lid], sdata[lid + s]); + barrier(); + } + float maxVal = sdata[0]; + barrier(); + + // 步骤 2:计算 exp(x - max) + float expVal = (gid < n) ? exp(input_data[gid] - maxVal) : 0.0; + sdata[lid] = expVal; + barrier(); + + // 步骤 3:exp 值求和 + for (uint s = WG_SIZE / 2; s > 0; s >>= 1) { + if (lid < s) sdata[lid] += sdata[lid + s]; + barrier(); + } + float sumExp = sdata[0]; + + // 步骤 4:归一化 + if (gid < n) { + output_data[gid] = expVal / sumExp; + } +} +``` + +3. 修改 C++ 宿主代码以对计算着色器进行基准测试:使用 Vulkan 时间戳查询或 CPU 端栅栏对调度(排除设置阶段)计时,并计算以 GB/s 为单位的实际带宽。 diff --git a/chapter 17: AI inference/01. quantisation.md b/chapter 17: AI inference/01. quantisation.md new file mode 100644 index 0000000..a755b40 --- /dev/null +++ b/chapter 17: AI inference/01. quantisation.md @@ -0,0 +1,339 @@ +# 量化 + +*量化降低模型权重和激活值的精度,使模型更小、更快、运行成本更低。本文涵盖数字格式、训练后量化、量化感知训练、仅权重量化方法(GPTQ、AWQ)、激活值量化、混合精度和KV缓存量化* + +- 一个70B参数的float16模型需要140 GB内存,超过任何单张GPU。量化为INT4后,它可以装入35 GB(一张A100)甚至20 GB(带卸载的消费级RTX 4090)。量化不是一种可有可无的优化;它是让大模型部署在经济上可行的关键。 + +- 基本权衡:低精度意味着更少内存、更高吞吐量和更低功耗,但会引入**量化误差**,可能降低模型质量。量化的艺术在于最小化这种降级。 + +## 为什么要量化 + +- **内存减少**:INT8比FP16小2倍,INT4小4倍。对于LLM,模型权重占主导内存。精度减半意味着内存需求减半。 + +- **吞吐量提升**:低精度意味着每秒更多操作。NVIDIA Tensor Core(第16章)在FP16 vs FP32上实现2倍吞吐量,INT8 vs FP16再实现2倍,INT4 vs INT8再实现2倍。H100在FP8下达到989 TFLOPS,而FP32下只有67 TFLOPS——相差15倍。 + +- **带宽节省**:LLM推理通常是**内存带宽受限**的(第16章,屋顶模型)。瓶颈是从GPU内存加载权重,而不是计算。更小的权重意味着更少的传输字节,直接提高每秒token数。这就是量化通常能为LLM推理带来近乎线性加速的原因。 + +- **节能**:低精度每次操作消耗更少能量。在数据中心规模(数千GPU)下,这转化为显著的电力成本降低。 + +## 数字格式 + +- 我们在第13章(计算机体系结构)中介绍了IEEE 754浮点数。以下是ML的完整精度全景: + +![精度格式位布局:从FP32到三值,展示符号位、指数和尾数位在内存中的排列方式,以及每参数内存对比](../images/precision_formats_memory.svg) + +| 格式 | 位数 | 指数 | 尾数 | 范围 | 用途 | +|--------|------|----------|----------|-------|----------| +| FP32 | 32 | 8 | 23 | ±3.4×10³⁸ | 训练(黄金标准) | +| TF32 | 19 | 8 | 10 | ±3.4×10³⁸ | Tensor Core训练(A100+) | +| FP16 | 16 | 5 | 10 | ±65504 | 混合精度训练 | +| BF16 | 16 | 8 | 7 | ±3.4×10³⁸ | 训练(与FP32相同的范围) | +| FP8 E4M3 | 8 | 4 | 3 | ±448 | 前向传播(Hopper+) | +| FP8 E5M2 | 8 | 5 | 2 | ±57344 | 梯度(更宽范围) | +| INT8 | 8 | — | — | -128 到 127 | PTQ推理 | +| INT4 | 4 | — | — | -8 到 7 | 仅权重量化 | +| INT2/三值 | 2 | — | — | {-1, 0, 1} | 极限压缩 | + +- **FP8**有两种变体:**E4M3**(4位指数,3位尾数,范围较窄但精度更高)用于前向传播,**E5M2**(5位指数,2位尾数,范围更宽但精度较低)用于梯度。Transformer Engine(第16章)在每个张量之间自动切换。 + +- **BF16 vs FP16**:BF16具有与FP32相同的指数范围(无溢出风险),但尾数精度较低。FP16精度更高但范围较窄(最大65504),训练时需要损失缩放。对于推理,两者都表现良好;对于训练,BF16更安全。 + +- **整数格式**没有指数——它们表示定点值。要在浮点和整数之间转换,需要一个**缩放因子**和一个可选的**零点**:$x_{\text{float}} = \text{scale} \times (x_{\text{int}} - \text{zero\_point})$。 + +## 量化方程 + +- 所有量化方法都将浮点值映射到整数并返回: + +$$x_q = \text{clamp}\left(\text{round}\left(\frac{x}{\text{scale}}\right) + \text{zero\_point}, \; q_{\min}, \; q_{\max}\right)$$ + +$$\hat{x} = \text{scale} \times (x_q - \text{zero\_point})$$ + +- **缩放因子**决定分辨率:$\text{scale} = \frac{x_{\max} - x_{\min}}{q_{\max} - q_{\min}}$。对于INT8:$q_{\min} = -128$,$q_{\max} = 127$。 + +- **对称量化**设置$\text{zero\_point} = 0$,因此$\text{scale} = \frac{\max(|x|)}{127}$。更简单、更快(推理时无需减去零点)。 + +- **非对称量化**使用非零$\text{zero\_point}$来处理非对称分布(例如,ReLU输出全为非负)。将$[x_{\min}, x_{\max}]$映射到无符号INT8的$[0, 255]$。 + +![量化粒度:逐张量为整个矩阵使用一个缩放因子,逐通道每列一个,逐组每小块一个](../images/quantisation_granularity.svg) + +- **量化粒度**:多少个值共享同一个缩放因子: + - **逐张量**:整个张量一个缩放因子。最简单但精度最低(一个异常值就会扭曲整个张量的缩放因子)。 + - **逐通道**:每个输出通道(卷积)或每行(线性层)一个缩放因子。精度好得多,开销最小。 + - **逐组**:每$g$个元素一组(例如$g = 128$)一个缩放因子。精度最佳,用于现代仅权重量化(GPTQ、AWQ)。 + - **逐token**:每个token一个缩放因子用于激活值。处理不同token激活值幅度差异很大的情况。 + +## 训练后量化(PTQ) + +- **PTQ**量化预训练模型而不需要重新训练。通过**校准集**(一个小的代表性数据集,通常128-512个样本)输入模型收集激活值统计信息,然后计算最优缩放因子。 + +### 校准方法 + +- **最小-最大**:基于观察到的最小值和最大值设置缩放因子。简单但容易受异常值影响(一个极端值将大部分量化范围浪费在很少使用的值上)。 + +- **百分位数**:使用99.99百分位数而不是绝对最大值。裁剪极端异常值,为大多数值提供更好的分辨率。裁剪后的值饱和到$q_{\min}$或$q_{\max}$。 + +- **MSE最优**:找到最小化原始张量和量化张量之间均方误差的缩放因子。这是一个一维优化(搜索可能的裁剪值),通常给出最好的PTQ精度。 + +- **基于熵**(KL散度):找到最小化原始和量化值分布之间KL散度的缩放因子。用于TensorRT的INT8校准。 + +### PTQ实践 + +```python +# 使用PyTorch的简化PTQ(概念性) +import torch + +def quantise_tensor_symmetric(tensor, bits=8): + qmax = 2 ** (bits - 1) - 1 # INT8的127 + scale = tensor.abs().max() / qmax + quantised = torch.clamp(torch.round(tensor / scale), -qmax, qmax).to(torch.int8) + return quantised, scale + +def dequantise(quantised, scale): + return quantised.float() * scale + +# 量化一个权重矩阵 +weight = torch.randn(512, 512) # 预训练权重 +weight_q, scale = quantise_tensor_symmetric(weight, bits=8) +weight_reconstructed = dequantise(weight_q, scale) + +# 量化误差 +error = (weight - weight_reconstructed).abs().mean() +print(f"平均绝对误差: {error:.6f}") +print(f"压缩比: {weight.numel() * 4 / (weight_q.numel() * 1 + 4):.1f}x") # +4字节用于缩放因子 +``` + +- PTQ在INT8上对大多数模型效果良好,精度下降<1%。对于INT4,PTQ质量显著下降——仅权重量化方法(见下文)处理INT4要好得多。 + +## 量化感知训练(QAT) + +- **QAT**在训练图中插入伪量化操作:在前向传播中,权重和激活值被量化和反量化,但梯度像没有量化一样流过(**直通估计器**)。 + +$$\text{前向: } \hat{W} = \text{反量化}(\text{量化}(W))$$ +$$\text{反向: } \frac{\partial L}{\partial W} \approx \frac{\partial L}{\partial \hat{W}}$$ + +- 模型在训练过程中学会了抵抗量化噪声。QAT通常能恢复PTQ损失的全部或大部分精度,特别是在低位宽(INT4、INT2)下。 + +- **成本**:QAT需要重新训练(或微调)模型,这对大模型来说成本高昂。对于一个70B参数模型,QAT可能需要$10,000-$100,000的计算成本。PTQ基本上零成本(只需校准)。 + +- **何时使用QAT**:当PTQ质量不可接受时(通常是INT4或更低),当部署到有严格延迟预算的边缘设备时,或者当模型将被量化数百万次时(一次性QAT成本被摊销)。 + +## 仅权重量化 + +- 对于LLM推理,瓶颈是从内存加载权重,而不是计算(内存带宽受限模式)。**仅权重量化**将权重量化为INT4或INT3,而保持激活值为FP16。计算在FP16中进行(在运行时反量化权重),但内存消耗和带宽减少了4-8倍。 + +### GPTQ + +- **GPTQ**(Frantar等人,2022)一次量化一列权重,通过调整后续列来补偿每列的误差。它使用**Hessian矩阵**(来自校准集的二阶信息)来确定最优量化顺序和误差补偿: + +$$\hat{W}_{:,j} = \text{quant}(W_{:,j}), \quad W_{:,j+1:} \mathrel{-}= \frac{(\hat{W}_{:,j} - W_{:,j}) \cdot H_{j,j+1:}}{H_{j,j}}$$ + +- 关键洞察:量化第$j$列会引入误差。GPTQ立即通过调整所有剩余列来补偿,使得该层的整体输出($XW$)变化尽可能小。这是应用于Transformer的**最优脑量化**(OBQ)。 + +- 使用4位组量化(组大小128)的GPTQ在大多数LLM上达到<1%的困惑度降级。在单GPU上,70B模型的量化大约需要1小时。 + +### AWQ + +- **AWQ**(激活感知权重量化,Lin等人,2023)观察到一小部分权重通道(1-3%)比其他通道重要得多——它们对应于具有大幅度的激活通道。保护这些显著通道可以大幅降低量化误差。 + +- AWQ在量化前将这些重要通道乘以一个因子$s$(使它们变大,因此受舍入影响更小),并将相应的激活值乘以$1/s$(以保持输出不变)。缩放因子$s$按组优化,以最小化整体量化误差。 + +- AWQ比GPTQ更简单(无需Hessian计算),运行更快,并达到可比较的质量。它已成为许多开源LLM量化流程的默认选择。 + +### GGUF / llama.cpp量化 + +- **GGUF**(GGML通用格式)是llama.cpp用于CPU推理的格式。它支持多种量化方案: + - **Q4_0**:4位,32元素块,对称。 + - **Q4_K_M**:4位,带混合精度重要通道(k-quants)。 + - **Q5_K_M**:5位,带k-quants(更高质量)。 + - **Q8_0**:8位,简单快速。 + +- "K"变体(k-quants)为重要的权重块分配更多位,类似于AWQ的洞察但实现在格式层面。Q4_K_M是大多数模型的最佳选择:平均4位,质量损失最小。 + +### QuIP和QuIP# + +- **QuIP**(Chee等人,2023)引入了**非相干处理**:在量化之前使用随机正交变换旋转权重矩阵。这会将信息分散到所有权重上,防止少数异常权重主导量化误差。 + +- 直觉:如果一个权重是100,其余的大约是1,用相同的缩放因子量化所有权重会浪费INT4的大部分范围在异常值上。经过正交旋转(保持矩阵的数学性质)后,所有权重具有相似幅度,均匀量化效果更好。 + +- **QuIP#** 通过**格点码本**扩展了这一思想:不是映射到均匀整数网格,而是映射到最优格点中的点(8D中的E8格点)。格点编码在相同位数内打包更多量化点,实现了比均匀量化更好的率失真性能。QuIP#在**2位**精度下达到了可用质量——典型INT4方法的一半位数。 + +### SpQR + +- **SpQR**(Dettmers等人,2023)观察到极小一部分权重(0.1-1%)是**异常值**,对输出质量的贡献不成比例。SpQR不是将所有内容量化到相同精度,而是: + + 1. 使用敏感性分析(量化这个权重会改变层输出多少?)识别异常权重。 + 2. 以**全精度**(FP16)的稀疏格式存储异常值。 + 3. 将所有剩余权重量化为INT3或INT4。 + +- 结果:~99%的权重被积极量化(小),而关键的1%保持全精度(准确)。稀疏异常值存储增加的开销最小(占总大小的<5%)。 + +### HQQ + +- **HQQ**(半二次量化,Badri & Shaji,2023)是一种**零样本**权重量化方法,完全不需要校准数据。它将量化表述为一个半二次优化问题,迭代求解最优量化权重和缩放因子。 + +- 优势:无需校准集意味着没有数据依赖,即时量化,也没有校准数据不匹配的风险。HQQ对于无法获得代表性校准数据或数据敏感型的模型特别有用。 + +### AQLM + +- **AQLM**(Egiazarian等人,2024)将**加法量化**(多码本向量量化)应用于LLM。AQLM不是独立量化每个权重,而是将权重分组为向量,并将每个向量表示为来自多个学习到的码本的条目之和: + +$$\mathbf{w} \approx \mathbf{c}_1^{(1)} + \mathbf{c}_2^{(2)} + \cdots + \mathbf{c}_M^{(M)}$$ + +- 其中$\mathbf{c}_i^{(m)}$是来自码本$m$的一个条目。有$M = 2$个码本,每个有256个条目,一个8元素向量被编码为两个8位索引 = 8个权重2字节 = 每个权重有效**2位**。AQLM在2位精度下达到了最先进的质量,在这个极限压缩水平上优于GPTQ和AWQ。 + +### BitNet和1位LLM + +- **BitNet**(Wang等人,2023)将量化推向极致:权重是三值的($\{-1, 0, +1\}$),每个权重仅需约1.58位。矩阵乘法变成**只有加法和减法**——不需要浮点乘法。 + +- **BitNet b1.58**(Ma等人,2024)将每个权重约束为$\{-1, 0, +1\}$。"1.58位"来自$\log_2(3) \approx 1.58$。在这个精度下,一个70B模型适合约15 GB,推理不需要乘法运算——只需加、减和符号翻转。 + +- 矩阵乘法变成: + +$$y_j = \sum_i W_{ij} \cdot x_i = \sum_{i: W_{ij}=+1} x_i - \sum_{i: W_{ij}=-1} x_i$$ + +- 这比在任何硬件上的FP16矩阵乘法都要便宜得多,并且可以在没有浮点单元的设备上实现LLM推理。对于当前模型,质量权衡是显著的,但随着规模和训练时量化感知能力的提高而改善。 + +### 微缩放(MX)格式 + +- **微缩放**(MX)格式是一种新的行业标准(由AMD、Arm、Intel、Meta、Microsoft、NVIDIA、Qualcomm支持),使用**块浮点**:一组元素共享一个指数,每个元素有自己的尾数。 + +| 格式 | 共享指数 | 元素位数 | 总计(每元素) | 等价 | +|--------|----------------|-------------|--------------------|----| +| MXFP8 | 每块8位 | 8(E4M3/E5M2) | ~8 | 类似FP8,范围更好 | +| MXFP6 | 每块8位 | 6 | ~6.5 | 介于FP8和INT4之间 | +| MXFP4 | 每块8位 | 4 | ~4.5 | 类似INT4,但有浮点行为 | +| MXINT8 | 每块8位 | 8(整数) | ~8.5 | INT8,带共享缩放 | + +- 共享指数将指数成本分摊到一个块(通常16-32个元素)。每个元素比单独指数时保留更多尾数位,每位的精度更好。MX格式预计将在未来硬件中替代单独的FP8和INT8格式。 + +### FP8训练 + +- 在FP8中训练(不仅仅是推理)现在在NVIDIA Hopper和Blackwell GPU上可行。方案如下: + + - **前向传播**:权重和激活值使用E4M3(更高精度,更窄范围)。Transformer Engine使用延迟缩放(跟踪上一次迭代的统计信息,应用于当前迭代)动态计算每张量缩放因子。 + + - **反向传播**:梯度使用E5M2(更宽范围,更低精度)。梯度的值范围比权重/激活值更广,因此额外的指数位防止溢出。 + + - **主权重**:以FP32维护,用于优化器状态(就像使用FP16的标准混合精度训练,第6章)。FP8计算仅用于矩阵乘法,不用于权重更新。 + + - **损失缩放**:FP8仍然需要,就像FP16一样。动态损失缩放器调整缩放因子,使梯度值保持在FP8的可表示范围内。 + +- FP8训练在大多数模型规模上达到与BF16训练相当的质量,吞吐量提高约2倍。它是在H100集群上进行新的大规模训练运行的默认选择。 + +## 激活值量化 + +- 激活值(层之间流动的中间张量)也可以量化,实现完全INT8计算(权重和激活值都是INT8,INT32累加)。 + +- **动态量化**:在运行时根据实际激活值计算缩放因子。更准确(适应每个输入),但增加开销(每层计算最小值/最大值或百分位数)。 + +- **静态量化**:在校准期间计算一次缩放因子并固定。推理时更快(无需运行时统计),但如果校准数据不具代表性则精度较低。 + +- **逐token量化**:为序列中的每个token计算单独的缩放因子。对LLM至关重要,因为不同token的激活值幅度可能差异很大(某些token的激活值比其他token大100倍)。 + +- 激活值量化比权重量化更难,因为激活值依赖于数据(它们随每个输入变化),而权重是固定的。"异常值"问题尤其严重:少数激活通道具有极值(平均值的100倍),用与正常通道相同的缩放因子量化它们会浪费精度。 + +- **SmoothQuant**(Xiao等人,2022)通过数学上将量化难度从激活值(由于异常值难以量化)迁移到权重(易于量化)来解决异常值问题:将激活值乘以$1/s$,权重乘以$s$,其中$s$平衡难度。输出$XW = (X \cdot \text{diag}(s^{-1})) \cdot (\text{diag}(s) \cdot W)$保持不变。 + +## 混合精度量化 + +- 并非所有层对量化的敏感度相同。注意力层通常可以容忍INT4,而嵌入层和最终分类器需要更高精度。 + +- **敏感性分析**:逐层量化并测量精度影响。敏感性高的层获得更多位;不敏感的层获得更少位。 + +- Transformer Engine(第16章,NVIDIA Hopper)在操作级别实现动态混合精度:每个矩阵乘法根据张量统计信息在FP8和FP16之间选择,最大化吞吐量同时保持质量。 + +## KV缓存量化 + +- 在LLM生成过程中,**KV缓存**存储所有先前token的键和值张量。对于长序列,这主导了内存: + +$$\text{KV缓存大小} = 2 \times n_{\text{layers}} \times n_{\text{heads}} \times d_{\text{head}} \times \text{seq\_len} \times \text{bytes\_per\_element}$$ + +- 一个70B模型,80层,64头,128维头,序列长度128K,FP16:$2 \times 80 \times 64 \times 128 \times 131072 \times 2 = 330$ GB。这超过了GPU内存。 + +- **KV缓存量化**通过将缓存的键和值以INT8或INT4而不是FP16存储来减少内存。量化误差在序列中累积(每个新token关注所有缓存的K/V),但使用逐通道或逐头量化后,降级是可以接受的。 + +- **KV缓存量化具有乘法级收益**:它支持更长的序列(更多上下文)、更大的批次大小(更多并发用户)和更快的推理(加载缓存所需的内存带宽更少)。这是LLM服务中影响最大的优化之一。 + +## 编程任务(使用CoLab或notebook) + +1. 从头实现对称INT8量化。量化一个权重矩阵,反量化它,并测量作为值分布函数的重建误差。 +```python +import jax.numpy as jnp +import jax + +def quantise_int8(tensor): + scale = jnp.max(jnp.abs(tensor)) / 127.0 + quantised = jnp.clip(jnp.round(tensor / scale), -127, 127).astype(jnp.int8) + return quantised, scale + +def dequantise(quantised, scale): + return quantised.astype(jnp.float32) * scale + +# 正常权重(典型训练模型) +key = jax.random.PRNGKey(0) +weights = jax.random.normal(key, (1024, 1024)) * 0.02 + +q, s = quantise_int8(weights) +recon = dequantise(q, s) + +print(f"原始: {weights.nbytes / 1024:.0f} KB") +print(f"量化后: {q.nbytes / 1024:.0f} KB ({weights.nbytes / q.nbytes:.0f}x 更小)") +print(f"平均绝对误差: {jnp.abs(weights - recon).mean():.6f}") +print(f"最大绝对误差: {jnp.abs(weights - recon).max():.6f}") +print(f"相对误差: {jnp.abs(weights - recon).mean() / jnp.abs(weights).mean():.4%}") +``` + +2. 演示异常值问题。创建具有几个极端通道的激活值,展示逐张量量化失败而逐通道量化成功。 +```python +import jax.numpy as jnp +import jax + +key = jax.random.PRNGKey(42) + +# 激活值:大多数通道正常,2个通道有100x异常值 +activations = jax.random.normal(key, (32, 512)) * 0.1 +activations = activations.at[:, 0].set(activations[:, 0] * 100) # 异常通道 +activations = activations.at[:, 1].set(activations[:, 1] * 50) # 异常通道 + +# 逐张量量化(整个张量一个缩放因子) +scale_tensor = jnp.max(jnp.abs(activations)) / 127.0 +q_tensor = jnp.clip(jnp.round(activations / scale_tensor), -127, 127) +recon_tensor = q_tensor * scale_tensor + +# 逐通道量化(每通道一个缩放因子) +scales_channel = jnp.max(jnp.abs(activations), axis=0) / 127.0 +q_channel = jnp.clip(jnp.round(activations / scales_channel), -127, 127) +recon_channel = q_channel * scales_channel + +err_tensor = jnp.abs(activations - recon_tensor).mean() +err_channel = jnp.abs(activations - recon_channel).mean() + +print(f"逐张量误差: {err_tensor:.6f}") +print(f"逐通道误差: {err_channel:.6f}") +print(f"逐通道好 {err_tensor / err_channel:.1f}x") +print(f"\n异常通道浪费了 {(activations.shape[1] - 2) / activations.shape[1]:.0%} " + f"的量化范围给 {2 / activations.shape[1]:.1%} 的通道") +``` + +3. 计算不同模型大小和序列长度的KV缓存内存。展示为什么KV缓存量化对长上下文模型至关重要。 +```python +def kv_cache_gb(n_layers, n_heads, d_head, seq_len, bytes_per_elem): + return 2 * n_layers * n_heads * d_head * seq_len * bytes_per_elem / 1e9 + +models = [ + ("Llama-7B", 32, 32, 128), + ("Llama-70B", 80, 64, 128), + ("GPT-4 (估计)", 120, 96, 128), +] + +print(f"{'模型':<15} {'序列长度':>8} {'FP16 (GB)':>10} {'INT8 (GB)':>10} {'INT4 (GB)':>10}") +print("-" * 60) + +for name, layers, heads, d_head in models: + for seq_len in [4096, 32768, 131072]: + fp16 = kv_cache_gb(layers, heads, d_head, seq_len, 2) + int8 = kv_cache_gb(layers, heads, d_head, seq_len, 1) + int4 = kv_cache_gb(layers, heads, d_head, seq_len, 0.5) + print(f"{name:<15} {seq_len:>8} {fp16:>9.1f} {int8:>9.1f} {int4:>9.1f}") + print() +``` diff --git a/chapter 17: AI inference/02. efficient architectures.md b/chapter 17: AI inference/02. efficient architectures.md new file mode 100644 index 0000000..9e839d2 --- /dev/null +++ b/chapter 17: AI inference/02. efficient architectures.md @@ -0,0 +1,238 @@ +# 高效架构 + +*让模型更快不仅仅是降低精度,还在于设计更智能的架构,使每个token的计算量更少。本文涵盖StreamingLLM、稀疏和线性注意力、多查询和分组查询注意力、推理时的混合专家、知识蒸馏、剪枝和神经架构搜索* + +- 量化(文件1)使每个操作更廉价。本文从源头上减少操作数量。两者互补:一个架构高效且量化的模型可以比原始模型快10-100倍。 + +## StreamingLLM:无限长度生成 + +- 标准Transformer将所有先前的token存储在KV缓存中,KV缓存随序列长度线性增长。在某一点上,缓存超过GPU内存,生成失败。**StreamingLLM**(Xiao等人,2023)使用固定大小的**滚动KV缓存**解决了这个问题。 + +- 关键观察:序列中的前几个token,无论其内容如何,都获得不成比例的高注意力分数。这些被称为**注意力汇聚点**。如果将它们从缓存中逐出,注意力分布会崩溃,生成质量灾难性下降。 + +- StreamingLLM的解决方案:在缓存中永久保留少量**汇聚token**(前1-4个token),加上最近$w$个token的**滚动窗口**。总缓存大小为$\text{sink} + w$,无论生成了多少token都是固定的。 + +$$\text{缓存} = [\text{token}_0, \text{token}_1, \text{token}_{t-w+1}, \ldots, \text{token}_t]$$ + +- 注意力汇聚点锚定softmax分布,滚动窗口提供最近的上下文。这实现了**无限长度生成**,内存恒定,代价是失去了访问序列中间上下文的能力。 + +- StreamingLLM无需重新训练即可用于自然形成注意力汇聚点的模型(大多数预训练LLM都会)。对于不形成汇聚点的模型,在训练期间添加一个可学习的汇聚token即可解决。 + +## 稀疏注意力 + +- 全自注意力在序列长度$n$上是$O(n^2)$,因为每个token关注所有其他token。对于$n = 128K$,注意力矩阵有$128K^2 = 160$亿个条目。**稀疏注意力**模式通过限制哪些token关注哪些token来减少这个数量。 + +![注意力稀疏模式:全注意力是O(n²),滑动窗口是O(n·w),局部+全局添加长距离token](../images/attention_sparsity_patterns.svg) + +- **滑动窗口注意力**(Mistral、Gemma):每个token只关注之前$w$个token(例如$w = 4096$)。注意力是$O(n \cdot w)$而不是$O(n^2)$。信息通过多层在窗口之外传播:经过$L$层后,有效上下文为$L \times w$。 + +- **局部+全局注意力**(Longformer、BigBird):大多数token使用滑动窗口注意力(局部),但少数指定token(例如[CLS],每512个token)关注所有token(全局)。这同时捕获了局部模式和长距离依赖。 + +- **膨胀注意力**:关注窗口内每第$k$个token,创建一个覆盖更大范围但注意力分数数量相同的稀疏模式。跨层增加膨胀度创建类似于膨胀卷积的层次结构(第8章)。 + +- 现代LLM的实际胜者是**滑动窗口+全注意力交错**:某些层使用滑动窗口(廉价,处理局部上下文),某些层使用全注意力(昂贵,捕获长距离)。Mistral/Mixtral使用这种模式。 + +## 线性注意力和状态空间模型 + +- 我们能完全替换$O(n^2)$的注意力吗?**线性注意力**和**状态空间模型(SSM)**通过避免显式注意力矩阵,以$O(n)$时间处理序列。 + +- **线性注意力**用核近似替换softmax注意力: + +$$\text{标准: } O = \text{softmax}(QK^T / \sqrt{d}) V$$ +$$\text{线性: } O = \phi(Q) (\phi(K)^T V)$$ + +- 通过先关联$K^T V$乘积(这是$d \times d$,与序列长度无关),计算变成$O(n \cdot d^2)$而不是$O(n^2 \cdot d)$。对于$n \gg d$的长序列,这是巨大的节省。 + +- **RWKV**结合了RNN和Transformer的思想。它使用循环公式顺序处理token(像RNN),但可以在训练时并行化(像Transformer)。推理是每个token $O(1)$(常量内存,KV缓存不增长)。 + +- **Mamba**(Gu & Dao,2023)是一种选择性状态空间模型。它通过学习到的状态转换处理序列: + +$$h_t = \bar{A} h_{t-1} + \bar{B} x_t, \quad y_t = C h_t$$ + +- 其中$\bar{A}$和$\bar{B}$是依赖于输入的(选择性),允许Mamba动态关注或忽略输入的部分。与固定SSM不同,选择性使Mamba在语言任务上与Transformer具有竞争力,同时保持$O(n)$的扩展性。 + +- **权衡**:线性注意力和SSM对长序列更快,但对于需要精确长距离检索的任务,通常不如全注意力。混合架构(一些Transformer层+一些Mamba层)通常能提供两全其美的效果。 + +## 多查询和分组查询注意力 + +- 标准多头注意力(MHA,第7章)为每个头使用独立的$K$、$V$投影。对于$h$个head,KV缓存中有$h$个独立的键和值张量。**多查询注意力(MQA)**和**分组查询注意力(GQA)**减少了这个数量。 + +- **MQA**(Shazeer,2019):所有头共享单组$K, V$投影。每个头仍然有自己的$Q$投影。KV缓存缩小了$h$倍(例如,32个头则缩小32倍)。 + +- **GQA**(Ainslie等人,2023):一个中间方案。头被分组,每组共享一组$K, V$投影。有$h = 32$个头和$g = 8$个组,每组4个头共享K/V。KV缓存缩小了$h/g = 4$倍。 + +$$\text{MHA: } h \text{ 个头, } h \text{ 个 K/V 集} \quad \to \quad \text{GQA: } h \text{ 个头, } g \text{ 个 K/V 集} \quad \to \quad \text{MQA: } h \text{ 个头, } 1 \text{ 个 K/V 集}$$ + +![MHA vs GQA vs MQA:MHA给每个头自己的KV,GQA跨组共享KV,MQA为所有头使用单个KV——大幅减少KV缓存大小](../images/mha_gqa_mqa.svg) + +- 大多数现代LLM使用GQA(Llama 2/3、Gemma、Mistral)。它减少了KV缓存内存和推理延迟,与MHA相比质量损失可以忽略不计。 + +### 多头潜在注意力(MLA) + +- **MLA**(DeepSeek-V2,2024)通过将KV缓存压缩为**低秩潜在空间**,比GQA更进一步。MLA不是缓存完整的键和值向量,而是每个token缓存一个压缩后的潜在向量$\mathbf{c}_t$,并在注意力期间动态重构K/V: + +$$\mathbf{c}_t = W_{\text{compress}} \cdot [\mathbf{k}_t; \mathbf{v}_t], \quad \mathbf{k}_t = W_K^{\text{up}} \cdot \mathbf{c}_t, \quad \mathbf{v}_t = W_V^{\text{up}} \cdot \mathbf{c}_t$$ + +- 压缩向量$\mathbf{c}_t$比原始K和V的组合小得多。DeepSeek-V2实现了与MHA相比**93.3%的KV缓存大小减少**,甚至优于MQA,同时保持MHA级别的质量。 + +- 权衡:从潜在向量重构K/V在每个注意力操作中增加了少量计算成本。但由于LLM解码是内存带宽受限的(而非计算受限),这总体上是个净收益:更少的内存加载 > 每token稍多计算。 + +### Flash Attention + +- **Flash Attention**(Dao等人,2022,第16章文件05有详细论述)不是架构变化,而是一种实现优化,在任何高效注意力讨论中都不可或缺。它计算精确的标准注意力,具有以下特点: + + - **O(n)内存**而不是O(n²)(注意力矩阵从未在HBM中具体化)。 + - **比标准注意力快2-4倍**(通过分块和在线softmax将数据保留在SRAM中)。 + - **无质量损失**——输出在数学上与标准注意力完全相同。 + +- Flash Attention现在是PyTorch(`torch.nn.functional.scaled_dot_product_attention`)、JAX和所有主要推理框架中默认的注意力实现。如果你在2024+年运行注意力,你几乎肯定在使用Flash Attention。 + +### Ring Attention + +- **Ring Attention**(Liu等人,2023)将注意力计算分布到多个设备上,用于即使使用Flash Attention也无法装入单GPU内存的长序列。 + +- 思路:将序列分区到$N$个设备上。每个设备持有$n/N$个token的Q、K、V。设备排列成环形。每一步: + 1. 每个设备计算局部注意力(其Q对其局部K/V)。 + 2. 每个设备将其K/V块发送到环中的下一个设备。 + 3. 每个设备从上一个设备接收K/V,并针对这些K/V计算注意力。 + 4. 经过$N$步后,每个设备已经关注过每个K/V块。 + +- 通信与计算**重叠**:在当前K/V块上计算注意力的同时,下一个块正在传输中。这几乎完全隐藏了通信延迟。 + +- Ring Attention通过将KV缓存分布在一圈GPU上,实现了**百万token上下文窗口**。每台设备的内存为O(n/N),使得任意长序列都可行(仅受设备数量限制)。 + +## 推理时的混合专家 + +- MoE模型(第7章)每个token只激活其参数的一小部分(通常8个专家中的2个)。在推理时,独特的挑战是**专家缓存**:所有专家都必须在内存中(因为任何token可能路由到任何专家),但每个token只有2个活跃。 + +- 对于Mixtral 8x7B模型:总参数 = 47B(8 × 7B专家,但有共享组件)。每个token的活跃参数 ≈ 13B(2个专家 + 共享层)。该模型具有LLM-70B级别的质量,但推理成本为LLM-13B级别,不过需要在内存中保留47B参数。 + +- **专家卸载**:对于GPU内存受限的部署,将非活跃专家保留在CPU或SSD上,按需加载。这之所以有效,是因为token路由足够可预测,可以预取可能的专家。 + +- **专家缓存**:在GPU内存中维护最近使用的专家的LRU缓存。如果相同的专家被重复激活(在领域内数据中常见),缓存命中率很高。 + +## 知识蒸馏 + +- **蒸馏**(第6章)训练一个小的"学生"模型来模仿一个大的"教师"。学生从教师的软预测(类上的概率分布)中学习,这比单独的硬标签包含更多信息。 + +$$\mathcal{L} = \alpha \cdot \text{KL}(p_{\text{teacher}}^{T} \| p_{\text{student}}^{T}) + (1 - \alpha) \cdot \mathcal{L}_{\text{CE}}(y, p_{\text{student}})$$ + +- 其中$T$是温度(更高的$T$使分布变软,揭示教师的不确定性),$\alpha$平衡蒸馏损失与标准交叉熵损失。 + +- **对于LLM**:蒸馏用于从大型、能力强的模型创建小型、快速的模型。GPT-4 → 一个7B学生模型,在特定任务上捕获GPT-4的大部分行为。学生模型的推理成本可以低10-100倍。 + +- **任务特定蒸馏**:仅在与部署任务相关的数据上进行蒸馏。从70B教师模型在医疗问答上蒸馏出的7B模型,在该特定任务上可以超越70B模型(因为学生有限的容量完全集中在目标领域上)。 + +## 剪枝 + +- **剪枝**移除不必要的权重(将其设为零),减少模型大小和计算量。 + +- **非结构化剪枝**(基于幅值):移除绝对值最小的单个权重。这创建了稀疏权重矩阵。简单有效用于压缩,但当前硬件(GPU)除非稀疏性遵循特定模式,否则无法高效加速稀疏操作。 + +- **结构化剪枝**:移除整个单元——注意力头、MLP神经元或层。这产生一个更小的稠密模型,可以在标准硬件上直接加速。权衡是粒度更粗(移除一个完整的头可能同时移除了有用和无用的权重)。 + +- **2:4稀疏性**(NVIDIA Ampere+):一种硬件支持的稀疏模式,每4个权重中有2个为零。GPU的稀疏Tensor Core跳过零乘法,实现约2倍加速。这是目前唯一具有实际硬件加速的稀疏模式。 + +- **彩票假说**(Frankle & Carlin,2019):在随机初始化的网络中,存在一个子网络("中奖彩票"),可以单独训练以匹配完整网络的性能。找到这些子网络(通过训练、剪枝和重置)成本高昂,但这个洞察激励了剪枝研究。 + +## 神经架构搜索(NAS) + +- **NAS**通过搜索可能的架构空间来自动化架构设计,找到在硬件约束(延迟、内存、功耗)下最大化精度的架构。 + +- **EfficientNet**(第8章)就是通过NAS找到的:复合缩放规则(平衡深度、宽度、分辨率)是从搜索中涌现的,而非人类直觉。 + +- 对于推理效率,NAS可以找到针对特定硬件目标优化的架构:"找到一个在iPhone神经引擎上延迟<5ms且在ImageNet上精度>80%的模型。"搜索空间包括层类型、宽度、激活函数和注意力模式。 + +- **一次性网络**训练一个单个过参数化网络,为不同的部署目标提取子网络。一次训练运行产生针对云GPU、移动GPU和CPU优化的模型,每个都针对其目标进行了优化。 + +## 编程任务(使用CoLab或notebook) + +1. 实现滑动窗口注意力,并与全注意力比较内存使用。 +```python +import jax +import jax.numpy as jnp + +def full_attention(Q, K, V): + """标准O(n²)注意力。""" + scores = Q @ K.T / jnp.sqrt(Q.shape[-1]) + weights = jax.nn.softmax(scores, axis=-1) + return weights @ V + +def sliding_window_attention(Q, K, V, window_size=128): + """滑动窗口注意力:每个token关注前window_size个token。""" + n = Q.shape[0] + d = Q.shape[-1] + output = jnp.zeros_like(Q) + + for i in range(n): + start = max(0, i - window_size + 1) + k_window = K[start:i+1] + v_window = V[start:i+1] + scores = Q[i] @ k_window.T / jnp.sqrt(d) + weights = jax.nn.softmax(scores) + output = output.at[i].set(weights @ v_window) + + return output + +n, d = 512, 64 +key = jax.random.PRNGKey(0) +Q = jax.random.normal(key, (n, d)) +K = jax.random.normal(jax.random.PRNGKey(1), (n, d)) +V = jax.random.normal(jax.random.PRNGKey(2), (n, d)) + +print(f"全注意力内存: O(n²) = {n*n} 个条目") +print(f"窗口 (w=128) 内存: O(n*w) = {n*128} 个条目") +print(f"减少: {n*n / (n*128):.1f}x") +``` + +2. 比较MHA、GQA和MQA的KV缓存大小。展示为什么GQA是实际的最佳选择。 +```python +def kv_cache_size(n_heads, n_kv_heads, d_head, seq_len, bytes=2): + """KV缓存大小(MB)。""" + return 2 * n_kv_heads * d_head * seq_len * bytes / 1e6 + +n_heads = 32 +d_head = 128 +seq_len = 32768 + +mha = kv_cache_size(n_heads, n_heads, d_head, seq_len) # 32个KV头 +gqa = kv_cache_size(n_heads, 8, d_head, seq_len) # 8个KV头 +mqa = kv_cache_size(n_heads, 1, d_head, seq_len) # 1个KV头 + +print(f"MHA (32个KV头): {mha:.0f} MB 每层") +print(f"GQA (8个KV头): {gqa:.0f} MB 每层 ({mha/gqa:.0f}x 更小)") +print(f"MQA (1个KV头): {mqa:.0f} MB 每层 ({mha/mqa:.0f}x 更小)") +``` + +3. 通过从随机注意力层中移除最不重要的注意力头并测量输出变化来模拟结构化剪枝。 +```python +import jax +import jax.numpy as jnp + +key = jax.random.PRNGKey(0) +n_heads, seq_len, d_head = 8, 64, 32 + +# 随机多头注意力输出(每个头一个) +head_outputs = jax.random.normal(key, (n_heads, seq_len, d_head)) + +# 完整输出:连接所有头 +full_output = head_outputs.reshape(seq_len, n_heads * d_head) + +# 重要性:通过范数度量每个头的贡献 +head_norms = jnp.linalg.norm(head_outputs, axis=(1, 2)) +print("头重要性(按范数):", jnp.round(head_norms, 2)) + +# 剪枝最不重要的头 +for n_keep in [8, 6, 4, 2]: + top_heads = jnp.argsort(head_norms)[-n_keep:] + pruned = head_outputs[top_heads].reshape(seq_len, n_keep * d_head) + + # 填充到原始大小用于比较(将剪掉的头设为零) + full_pruned = jnp.zeros_like(head_outputs) + full_pruned = full_pruned.at[top_heads].set(head_outputs[top_heads]) + full_pruned = full_pruned.reshape(seq_len, n_heads * d_head) + + error = jnp.linalg.norm(full_output - full_pruned) / jnp.linalg.norm(full_output) + print(f"保留 {n_keep}/{n_heads} 个头: 相对误差 = {error:.4f}, " + f"内存 = {n_keep/n_heads:.0%}") +``` diff --git a/chapter 17: AI inference/03. serving and batching.md b/chapter 17: AI inference/03. serving and batching.md new file mode 100644 index 0000000..ddd5d94 --- /dev/null +++ b/chapter 17: AI inference/03. serving and batching.md @@ -0,0 +1,236 @@ +# 服务与批处理 + +*向数千并发用户提供LLM服务需要的不只是加载模型和运行推理。本文涵盖预填充-解码分离、连续批处理、PagedAttention和vLLM、调度策略、分离式服务、多模型和LoRA服务,以及关键指标* + +- 单个LLM推理请求很简单:输入token,生成输出token。但要向10,000个并发用户提供低延迟、高吞吐量的LLM服务,这是一个系统工程问题。朴素方法(一次处理一个请求)浪费了90%以上的GPU容量。智能批处理和调度可以在不增加硬件的情况下将吞吐量提高10-50倍。 + +## 预填充 vs 解码:两个截然不同的阶段 + +- LLM推理有两个不同的阶段,具有根本不同的计算特征: + +- **预填充**(提示处理):同时处理所有输入token。这是一个单次大规模矩阵乘法:$O(\text{prompt\_length} \times d_{\text{model}}^2)$。提示可以并行处理(所有token都已知)。预填充是**计算受限**的:GPU的ALU是瓶颈。 + +- **解码**(token生成):自回归地一次生成一个token。每个新token需要通过KV缓存关注所有先前的token。解码是**内存带宽受限**的:GPU大部分时间花在从内存加载模型权重和KV缓存上,而不是计算。每个解码步骤只产生一个token,但必须加载整个模型(70B FP16模型约140 GB)。 + +- 含义: + +| | 预填充 | 解码 | +|--|---------|--------| +| 处理的token | 一次性全部(并行) | 一次一个(顺序) | +| 瓶颈 | 计算(FLOPS) | 内存带宽 | +| 算术强度 | 高 | 非常低 | +| GPU利用率 | 高(50-80%) | 低(1-10%),无批处理时 | +| 延迟指标 | **首token时间(TTFT)** | **每输出token时间(TPOT)** | + +- TTFT影响用户体验(多久直到响应开始流式传输)。TPOT决定感知的生成速度。用户可以容忍较高的TTFT(1-5秒),但期望快速的TPOT(对话应用每token 30-100毫秒)。 + +## 静态批处理(朴素方法) + +- 最简单的批处理:收集$B$个请求,填充到相同长度,作为单个批次处理。 + +- **问题1**:请求有不同的提示长度,并生成不同数量的输出token。短请求提前完成,但必须等待批次中最长的请求完成后才能开始下一个批次。GPU在为剩余的一个长请求生成token时处于空闲状态。 + +- **问题2**:填充浪费计算。如果最长提示是2000个token,最短是50个,批次被填充到2000。GPU为短请求处理了1950个填充token——纯属浪费。 + +![静态批处理在等待最长请求时浪费GPU槽位;连续批处理立即填充释放的槽位](../images/static_vs_continuous_batching.svg) + +## 连续批处理 + +- **连续批处理**(也称为迭代级批处理)通过在单个解码步骤的粒度上操作来解决这两个问题,而不是整个请求。 + +- 在每个解码步骤: + 1. 所有进行中的请求并行生成一个token(作为一个批次)。 + 2. 完成的请求(生成EOS token)立即从批次中**移除**。 + 3. 队列中的新请求立即**插入**到释放的槽位中。 + +- 批次大小每步动态变化。GPU从不等候落后者,也没有浪费的填充(每个请求只使用它需要的槽位)。 + +- **影响**:连续批处理通常比静态批处理提高吞吐量2-10倍,模型质量不变且延迟无明显增加。 + +## PagedAttention和vLLM + +- KV缓存造成了一个内存管理噩梦。每个请求都有一个随着每个生成的token而增长的KV缓存。不同请求处于不同阶段(不同缓存大小)。为每个请求分配连续内存浪费空间(必须为最大可能长度分配,即使请求只生成几个token)。 + +![PagedAttention将虚拟KV缓存页映射到非连续的物理GPU内存,消除碎片并实现按需分配](../images/paged_attention.svg) + +- **PagedAttention**(Kwon等人,2023)将操作系统虚拟内存的概念(第13章)应用于KV缓存。缓存被划分为固定大小的**页**(token位置的块)。页按需分配,在物理GPU内存中可以是非连续的。 + +- 优势: + - **无碎片**:页大小统一,因此请求之间没有浪费内存的"空洞"。 + - **惰性分配**:仅在token实际生成时分配内存,而不是预分配最大长度。 + - **写时复制**:共享共同前缀(例如系统提示)的请求共享相同的KV缓存页。仅当请求分叉时才复制页。 + +- **vLLM**是基于PagedAttention构建的推理引擎。通过几乎消除KV缓存内存浪费,它实现了比静态分配服务(如没有分页注意力的HuggingFace text-generation-inference)高2-4倍的吞吐量。 + +## 调度策略 + +- 当多个请求在等待且GPU只能处理有限批次时,**调度**决定服务哪些请求: + +- **先来先服务(FCFS)**:按到达顺序处理请求。简单但不公平:一个提交10K-token生成的用户会阻塞所有后面的用户。 + +- **最短作业优先(SJF)**:处理最先完成的请求。最小化平均延迟,但惩罚长时间运行的请求(它们可能被饿死)。在实践中,估计输出长度未知,因此SJF使用启发式方法(提示长度、用户历史)。 + +- **抢占**:如果高优先级请求到达,暂停低优先级的进行中请求(将其KV缓存交换到CPU内存或SSD),服务高优先级请求,然后恢复暂停的请求。vLLM支持此功能。 + +- **基于优先级**:为用户或请求类型分配优先级。实时交互查询比批处理作业获得更高优先级。结合抢占,这确保高优先级流量的延迟SLO。 + +- **Token预算**:限制活跃批次中的总token数。这防止少量长请求独占GPU内存并饿死新请求。 + +## 分离式服务 + +- 预填充和解码具有相反的计算特征。在同一GPU上运行两者意味着GPU在计算受限(预填充)和内存带宽受限(解码)之间交替,从未充分利用任一资源。 + +- **分离式服务**将它们分开: + - **预填充节点**:为计算优化的GPU(高FLOPS,可能内存较少)。处理所有传入提示。 + - **解码节点**:为内存带宽优化的GPU(大KV缓存容量,高内存带宽)。处理所有token生成。 + +- 预填充节点计算初始KV缓存并通过NVLink或网络将其发送到解码节点。解码节点使用接收到的缓存生成token。 + +- 这是**Mooncake**(月之暗面)的架构,并正在被多个LLM服务团队探索。好处:每个GPU类型与其工作负载特征匹配,提高整体利用率。 + +## 多模型和LoRA服务 + +- 在生产中,你通常服务多个模型(不同层级的模型大小不同,不同任务的微调变体不同)。 + +- **模型复用**:在同一GPU上加载多个模型,将请求路由到相应模型。GPU内存共享:一个40 GB GPU可能同时持有一个13B模型(26 GB)和一个7B模型(14 GB)。 + +- **LoRA服务**:不是部署单独的微调模型,而是部署一个基础模型并带有多个**LoRA适配器**(第6章)。每个适配器增加<1%的参数。请求在推理时路由到相应的适配器。 + +- **S-LoRA**(Sheng等人,2023):从一个基础模型服务数千个LoRA适配器。适配器存储在CPU上,按需分页到GPU内存。基础模型的KV缓存和权重被共享;只有小的LoRA矩阵因请求而异。 + +- **Punica**(Chen等人,2023):通过使用自定义CUDA内核在同一批次中为不同请求应用不同的LoRA矩阵,跨不同LoRA适配器对请求进行批处理。这避免了每个请求切换适配器的开销。 + +## 受限和引导生成 + +- 许多应用需要LLM以特定格式产生输出:有效的JSON、SQL查询、特定语言的代码或遵循模式的响应。**受限生成**保证输出符合语法或模式。 + +- **语法受限解码**:在每个解码步骤,屏蔽会违反语法的token。如果到目前为止的输出是`{"name": "Alice", "age":`且语法要求接下来是整数,则屏蔽除数字外的所有token。LLM的概率分布在有效token上重新归一化。 + +- **Outlines**(Willard & Louf,2023):将JSON模式或正则表达式编译成有限状态机(FSM)。在每个解码步骤,FSM确定哪些token是有效的后续。无效token获得概率0。这保证了100%的模式合规,零重试。 + +- **SGLang**原生集成受限生成:你用Python指定输出结构,引擎高效处理token掩码和缓存。这与RadixAttention(前缀缓存)结合,使得结构化输出重用缓存的公共前缀。 + +- **为什么重要**:没有受限生成,你自由生成然后解析输出,失败时重试。对于复杂JSON模式,重试率通常为10-30%,浪费计算。受限生成完全消除了重试。 + +## 请求路由 + +- 并非每个查询都需要最大的模型。**请求路由**根据估计的难度将查询定向到不同的模型: + +- **级联**:先尝试小模型。如果小模型的置信度低于阈值(例如,top token的softmax概率<0.8),则升级到更大的模型。简单查询(80%+的流量)由小模型廉价服务;只有困难查询使用昂贵模型。 + +- **学习型路由**:训练一个轻量级分类器(或使用小模型的困惑度)来预测查询需要哪个模型层级。将"2+2等于多少?"路由到3B模型,将"解释量子纠缠的数学基础"路由到70B模型。 + +- **影响**:如果80%的查询可以由成本低10倍的模型处理,平均每查询成本下降约70%。这是多模型部署中影响最大的成本优化之一。 + +- **设备端+云混合路由**:**Cactus**([github.com/cactus-compute/cactus](https://github.com/cactus-compute/cactus))在设备级别实现请求路由。它通过自定义ARM SIMD内核在设备端(手机、笔记本电脑、可穿戴设备)运行小模型,并在本地模型置信度低或查询超出设备能力时自动路由到云端模型。应用为两条路径使用OpenAI兼容API——路由是透明的。这是在基础设施级别的级联:第一层是免费的(设备端),第二层花钱(云API)。对于大多数查询简单的应用(助手问答、自动补全、转录),设备端处理覆盖70-90%的流量,边际成本为零。 + +## 推理指标 + +- 正确的指标取决于用例: + +| 指标 | 测量内容 | 目标(对话式) | 目标(批处理) | +|--------|-----------------|------------------------|-----------------| +| **TTFT** | 首token时间 | <1 s | 不太重要 | +| **TPOT** | 每输出token时间 | <100 ms | 不太重要 | +| **吞吐量** | token/秒(总计) | 不太重要 | 最大化 | +| **p99延迟** | 最差的1%请求 | <5 s | <30 s | +| **每token成本** | $/100万token | 最小化 | 最小化 | +| **SLO合规率** | 满足延迟目标的请求百分比 | >99% | >95% | + +- **TTFT vs TPOT权衡**:激进的批处理增加吞吐量(总token数/秒更多),但增加TPOT(每个token耗时更长,因为GPU处理更多请求)。调度策略必须平衡吞吐量(收入)与延迟(用户体验)。 + +- **每token成本**是生产的最终指标。它结合了硬件成本(GPU租金)、吞吐量(token/秒)和利用率。运行在50% GPU利用率的系统比100%利用率的系统每token成本高2倍。这就是批处理、调度和PagedAttention如此重要的原因——它们提高了利用率。 + +## 编程任务(使用CoLab或notebook) + +1. 模拟连续vs静态批处理并测量吞吐量差异。 +```python +import random +import time + +def simulate_static_batching(requests, batch_size=8): + """在固定批次中处理请求。等待所有完成。""" + total_tokens = 0 + total_time = 0 + + for i in range(0, len(requests), batch_size): + batch = requests[i:i + batch_size] + max_len = max(r['output_len'] for r in batch) + # 批次中所有请求耗时等于最长请求 + batch_time = max_len * 0.01 # 每token 10ms + total_time += batch_time + total_tokens += sum(r['output_len'] for r in batch) + + return total_tokens / total_time # token/秒 + +def simulate_continuous_batching(requests, max_batch=8): + """使用连续批处理处理。移除完成请求,添加新请求。""" + total_tokens = 0 + total_time = 0 + active = [] + queue = list(requests) + + while active or queue: + # 填充批次 + while len(active) < max_batch and queue: + active.append({'remaining': queue.pop(0)['output_len']}) + + if not active: + break + + # 一个解码步骤:所有活跃请求生成1个token + for req in active: + req['remaining'] -= 1 + total_tokens += len(active) + total_time += 0.01 # 每步10ms + + # 移除完成的请求 + active = [r for r in active if r['remaining'] > 0] + + return total_tokens / total_time + +# 生成具有不同输出长度的请求 +random.seed(42) +requests = [{'output_len': random.randint(10, 500)} for _ in range(100)] + +static_tps = simulate_static_batching(requests) +continuous_tps = simulate_continuous_batching(requests) + +print(f"静态批处理: {static_tps:.0f} tokens/s") +print(f"连续批处理: {continuous_tps:.0f} tokens/s") +print(f"加速比: {continuous_tps / static_tps:.1f}x") +``` + +2. 计算PagedAttention的KV缓存内存节省。比较预分配(最坏情况)vs分页(实际使用)。 +```python +def paged_vs_preallocated(n_requests, max_seq_len, avg_seq_len, page_size, kv_per_token_bytes): + """比较内存使用:预分配vs分页KV缓存。""" + # 预分配:每个请求获得max_seq_len个槽位 + preallocated_gb = n_requests * max_seq_len * kv_per_token_bytes / 1e9 + + # 分页:只分配使用的部分(按页粒度) + import math + avg_pages = math.ceil(avg_seq_len / page_size) + paged_gb = n_requests * avg_pages * page_size * kv_per_token_bytes / 1e9 + + waste_preallocated = (max_seq_len - avg_seq_len) / max_seq_len + waste_paged = (avg_pages * page_size - avg_seq_len) / (avg_pages * page_size) + + print(f"请求数: {n_requests}, 最大序列: {max_seq_len}, 平均序列: {avg_seq_len}") + print(f" 预分配: {preallocated_gb:.1f} GB (浪费: {waste_preallocated:.0%})") + print(f" 分页: {paged_gb:.1f} GB (浪费: {waste_paged:.0%})") + print(f" 节省: {preallocated_gb - paged_gb:.1f} GB ({preallocated_gb/paged_gb:.1f}x)") + print() + +# Llama-70B:每层每token约1.3 KB,80层 = 每token约100 KB总计 +kv_bytes = 100_000 + +# 场景1:短请求,大最大值 +paged_vs_preallocated(256, max_seq_len=4096, avg_seq_len=256, page_size=16, kv_per_token_bytes=kv_bytes) + +# 场景2:不同长度 +paged_vs_preallocated(256, max_seq_len=8192, avg_seq_len=1024, page_size=16, kv_per_token_bytes=kv_bytes) + +# 场景3:长上下文 +paged_vs_preallocated(64, max_seq_len=131072, avg_seq_len=16000, page_size=16, kv_per_token_bytes=kv_bytes) +``` diff --git a/chapter 17: AI inference/04. edge inference.md b/chapter 17: AI inference/04. edge inference.md new file mode 100644 index 0000000..77c04df --- /dev/null +++ b/chapter 17: AI inference/04. edge inference.md @@ -0,0 +1,212 @@ +# 边缘推理 + +*边缘推理在用户设备(手机、笔记本电脑、物联网传感器)上运行模型,无需将数据发送到云端。本文涵盖边缘限制、模型压缩流水线、设备端运行时、编译器栈、硬件目标(NPU、神经引擎)、设备端LLM、联邦学习和延迟优化* + +- 云端推理需要网络连接,增加延迟(50-200毫秒往返),每次请求花费金钱,并将用户数据发送到第三方服务器。**边缘推理**消除了所有四个问题:模型本地运行,即时响应,每次推理零成本,且数据保持私密。 + +- 权衡:边缘设备的计算和内存比数据中心GPU小100-1000倍。使模型在这些约束下运行需要在每个层面进行积极优化。 + +- **Cactus**([github.com/cactus-compute/cactus](https://github.com/cactus-compute/cactus)) 是一个专为移动和可穿戴设备构建的低延迟AI引擎。它在生产中展示了本文涵盖的许多技术:自定义ARM SIMD内核用于注意力和矩阵运算(第16章)、KV缓存量化(第17章文件01)、分块预填充、Apple和Qualcomm芯片上的NPU加速推理、零拷贝内存映射实现10倍更低的RAM使用,以及在设备端计算不足时的自动云回退。Cactus支持跨iOS、Android、macOS和嵌入式Linux的多模态推理(LLM、视觉、语音),并提供Swift、Kotlin、Python、Flutter、React Native和Rust的SDK。其基准测试显示,在M4 Pro上1.2B INT4模型解码达到100 tokens/s,在iPhone 17 Pro上达到48 tokens/s——这是优化边缘推理的具体示例。 + +## 边缘约束 + +| 资源 | 云GPU(H100) | 笔记本电脑(M4) | 手机(Snapdragon 8 Gen 3) | IoT(ESP32) | +|----------|-----------------|-------------|---------------------------|-------------| +| 内存 | 80 GB HBM3 | 16-36 GB 统一内存 | 8-12 GB LPDDR5 | 520 KB | +| 计算 | 989 TFLOPS(FP8) | 38 TOPS(神经引擎) | 45 TOPS(NPU) | 0.001 TOPS | +| 功耗 | 700 W | 15-30 W | 5-10 W | 0.1 W | +| 存储 | TB | 256 GB-2 TB | 128-512 GB | 4 MB | + +- 云GPU和手机NPU之间的计算差距约为20倍。GPU和微控制器之间的差距约为1,000,000倍。不同设备需要不同程度的压缩和不同的模型架构。 + +## 模型压缩流水线 + +- 对于边缘部署,压缩不是单一技术——它是一个按顺序应用的互补技术**流水线**: + +``` +完整模型(FP32,70B参数) + ↓ 知识蒸馏 → 更小模型(7B参数) + ↓ 结构化剪枝 → 移除冗余头/层(4B有效) + ↓ 量化(INT4) → 4倍更小(2 GB) + ↓ 编译器优化 → 融合内核,优化内存布局 + ↓ 运行时 → 设备端执行 +``` + +- 每一步减少大小和延迟。顺序很重要:先蒸馏(减少架构),然后剪枝(移除结构),然后量化(降低精度),最后编译(为目标硬件优化)。在量化之后进行蒸馏会试图压缩已经损失质量的模型。 + +## 设备端运行时 + +- **运行时**加载模型、分配内存并在目标硬件上执行推理。每个平台有其偏好的运行时: + +- **ONNX Runtime**:跨平台(Windows、Linux、macOS、iOS、Android)。支持CPU、GPU(CUDA、DirectML、CoreML、NNAPI)和许多加速器后端。最具可移植性的选项。模型从PyTorch/TensorFlow导出为ONNX格式。 + +- **TensorFlow Lite(TFLite)**:Google的边缘运行时。针对ARM CPU和Android NPU优化。二进制文件小巧(约1 MB)。支持INT8和float16。Android部署的标准。 + +- **Core ML**:Apple的iOS/macOS运行时。根据模型特征自动使用神经引擎、GPU或CPU。模型使用`coremltools`从PyTorch/TensorFlow转换。与Apple硬件紧密集成(统一内存、神经引擎)。 + +- **ExecuTorch**:Meta新推出的设备端PyTorch运行时。专为边缘部署设计,具有提前编译和操作级硬件加速器委派功能。PyTorch Mobile的继任者。 + +- **TensorRT**:NVIDIA的GPU推理优化运行时(第15章)。融合层、选择最优内核并自动量化。在NVIDIA GPU上比PyTorch eager模式快2-5倍。 + +- **llama.cpp**:用于LLM的单文件C++推理引擎。支持GGUF量化(Q4、Q5、Q8)、CPU(AVX/NEON)、Metal(Apple GPU)、CUDA和Vulkan。在消费级硬件上运行LLM的首选方案。 + +## 编译器栈 + +- 在高级模型(PyTorch图)和硬件(NPU指令)之间是**编译器栈**,它为特定目标优化模型: + +``` +PyTorch模型 + ↓ 导出(torch.export、ONNX、TorchScript) +图IR(中间表示) + ↓ 图优化 + - 常量折叠(编译时计算常量表达式) + - 死代码消除(移除未使用的操作) + - 算子融合(conv + bn + relu → 单个融合操作) + - 布局转换(NCHW → NHWC用于ARM,通道最后) + ↓ 降级 +硬件特定IR + ↓ 后端优化 + - 分块和循环排序(缓存友好的访问模式) + - 向量化(SIMD,第16章) + - 内存规划(重用缓冲区以最小化峰值内存) + - 内核选择(为每个操作选择最佳实现) + ↓ 代码生成 +机器代码 / NPU指令 +``` + +- **算子融合**是影响最大的优化。一个Transformer块约有20个操作(矩阵乘、加法、层归一化、softmax等)。没有融合,每个操作将其输出写入内存,下一个操作再读回。有了融合,多个操作组合成一个内核,将数据保留在寄存器/缓存中。这可以使速度快2-5倍(第16章,屋顶模型)。 + +- **内存规划**:编译器分析模型图以确定哪些张量的生命周期重叠,可以共享相同的内存缓冲区。一个有100个中间张量的模型可能只需要10张量的内存,因为大多数张量在其他张量创建之前就被消耗和释放了。这在内存有限的设备上至关重要。 + +## 硬件目标 + +### 移动GPU + +- **Qualcomm Adreno**(Android):支持OpenCL、Vulkan计算(第16章)和Qualcomm专有的SNPE(Snapdragon神经处理引擎)。Adreno GPU具有256-1024个ALU,支持FP16和INT8。 + +- **ARM Mali**(Android):支持OpenCL和Vulkan。Mali GPU使用基于图块的架构(与桌面GPU不同),这影响最优内存访问模式。 + +- **Apple GPU**(iOS/macOS):通过Metal(Apple的GPU API)访问。统一内存架构意味着没有CPU↔GPU复制开销。Metal Performance Shaders(MPS)提供优化的ML原语。 + +### 神经处理单元(NPU) + +- NPU是专门为ML推理设计的固定功能加速器。它们在标准ML操作(矩阵乘、卷积、激活)上比GPU节能得多。 + +- **Apple神经引擎**:16核,约38 TOPS(INT8)。通过Core ML访问。非常适合视觉模型和设备端扩散。不能运行任意代码——只支持Core ML支持的操作。 + +- **Qualcomm Hexagon NPU**:集成到Snapdragon SoC中。支持INT8和INT4推理。通过SNPE或ONNX Runtime(带QNN后端)访问。为设备端功能如背景虚化、语音识别和实时翻译提供支持。 + +- **Google Edge TPU**:云端TPU的小型低功耗版本。4 TOPS,2W。用于Coral设备进行设备端推理。仅支持INT8量化的TFLite模型。 + +- **委派模式**:运行时在NPU(用于支持的操作)和CPU(用于不支持的操作)之间拆分模型图。最大化在NPU上运行的部分是性能和能效的关键。 + +## 设备端LLM + +- 在手机和笔记本电脑上运行LLM已变得可行,得益于小模型和积极的量化: + +| 模型 | 参数 | 量化后大小 | 目标设备 | 性能 | +|-------|--------|---------------|---------------|-------------| +| Phi-3 Mini | 3.8B | ~2 GB(Q4) | 手机/笔记本 | iPhone 15上~15 tokens/s | +| Gemma 2B | 2B | ~1.5 GB(Q4) | 手机 | Pixel 8上~20 tokens/s | +| Llama 3.2 1B | 1B | ~700 MB(Q4) | 手机 | ~30 tokens/s | +| Llama 3.2 3B | 3B | ~2 GB(Q4) | 手机/笔记本 | ~15 tokens/s | +| Llama 3.1 8B | 8B | ~4.5 GB(Q4) | 笔记本 | M2上~20 tokens/s | + +- **挑战**: + - **内存**:3B Q4模型占2 GB,但长对话的KV缓存增加了显著额外内存。手机上的上下文长度通常限制在2-4K token。 + - **热节流**:持续推理使手机发热。连续生成30秒后,SoC会降低时钟速度以防止过热,性能下降30-50%。 + - **电池**:以15 tokens/s运行3B模型消耗约3-5W。30分钟的对话消耗典型手机电池约5%。偶尔使用可以接受,但始终在线应用存在问题。 + +- **llama.cpp**是设备端LLM的标准。它在CPU(AVX2、NEON、I8MM)、Apple GPU(Metal)、NVIDIA GPU(CUDA)、AMD GPU(ROCm/Vulkan)甚至手机上(通过Android上的Termux)运行。 + +## 联邦学习 + +- **联邦学习**在许多设备上训练模型,无需集中数据。每个设备在其本地数据上训练,计算梯度更新,并将只有更新(而非数据)发送到聚合更新的中央服务器。 + +- **算法**(FedAvg): + 1. 服务器将当前模型发送给$K$个选定设备。 + 2. 每个设备在其本地数据上微调模型几步。 + 3. 每个设备将其更新后的模型(或差异)发送回服务器。 + 4. 服务器平均更新:$W_{\text{new}} = \frac{1}{K} \sum_{k=1}^{K} W_k$。 + 5. 重复。 + +- **隐私**:原始数据从不离开设备。服务器只看到聚合的模型更新。**差分隐私**向更新添加噪声,使得无法从梯度中逆向推断单个数据点。 + +- **通信效率**:模型更新很大(与模型相同大小)。压缩技术减少了这一点:**梯度量化**(发送INT8梯度而不是FP32)、**稀疏化**(只发送最大的梯度)和**梯度累积**(做更多本地步骤,发送更少频率)。 + +- **应用**:Google的键盘预测(Gboard)、Apple的语音识别、健康监测(在敏感健康数据上训练而不集中数据)。 + +## 延迟优化 + +- 除了压缩,还有几种技术减少端到端推理延迟: + +- **提前退出**:在中间层添加分类头。如果模型在第6层(共24层)已经自信,则返回预测而不运行第7-24层。简单输入提前退出,困难输入使用完整模型。对于混合简单和困难输入的任务,平均延迟显著下降。 + +- **模型分区**:在NPU(对矩阵乘高效)、GPU(对不规则操作高效)和CPU(处理其他一切)之间拆分模型。编译器根据性能分析决定哪些操作去哪里。 + +- **缓存**:对于具有重复查询的应用(自动补全、代码补全),缓存最近的计算。如果用户输入"How do I"且模型最近生成了"How do I"的补全,可以重用缓存的KV缓存,完全跳过预填充阶段。 + +- **推测性预取**:预测用户下一步将做什么,在用户询问之前开始推理。聊天应用可能在用户阅读当前答案时开始生成可能后续问题的响应。 + +## 编程任务(使用CoLab或notebook) + +1. 模拟模型压缩流水线。从float32模型开始,依次应用蒸馏(模拟)、剪枝和量化,并跟踪每一步的大小。 +```python +def compression_pipeline(original_params_M, original_bits=32): + size_mb = original_params_M * 1e6 * original_bits / 8 / 1e6 + + print(f"原始: {original_params_M}M 参数, {original_bits}-位 → {size_mb:.0f} MB") + + # 步骤1:知识蒸馏(减少参数) + distilled_params = original_params_M * 0.15 # 70B → ~10B 等价 + size_mb = distilled_params * 1e6 * original_bits / 8 / 1e6 + print(f"蒸馏后 ({distilled_params:.0f}M 参数): {size_mb:.0f} MB") + + # 步骤2:结构化剪枝(移除剩余30%) + pruned_params = distilled_params * 0.7 + size_mb = pruned_params * 1e6 * original_bits / 8 / 1e6 + print(f"剪枝后 ({pruned_params:.0f}M 参数): {size_mb:.0f} MB") + + # 步骤3:INT4量化 + size_mb = pruned_params * 1e6 * 4 / 8 / 1e6 + print(f"INT4量化后: {size_mb:.0f} MB") + + print(f"总压缩比: {original_params_M * 1e6 * original_bits / 8 / 1e6 / size_mb:.0f}x") + +print("=== 从70B模型开始 ===") +compression_pipeline(70000) + +print("\n=== 从7B模型开始 ===") +compression_pipeline(7000) +``` + +2. 估计设备端推理延迟。给定模型的操作计数和硬件规格,计算是否满足延迟目标。 +```python +def estimate_latency(model_name, params_M, bits, compute_tops, mem_bw_gbs, seq_len=256): + """估计内存带宽受限模型的token生成延迟。""" + # 模型大小(字节) + model_bytes = params_M * 1e6 * bits / 8 + + # 解码是内存受限的:每token必须加载整个模型 + time_per_token_ms = model_bytes / (mem_bw_gbs * 1e9) * 1000 + + # 每秒token数 + tokens_per_sec = 1000 / time_per_token_ms + + print(f"{model_name}: {params_M/1000:.1f}B 参数 @ {bits}-位 = {model_bytes/1e9:.1f} GB") + print(f" 内存带宽: {mem_bw_gbs} GB/s") + print(f" 每token时间: {time_per_token_ms:.1f} ms") + print(f" Tokens/秒: {tokens_per_sec:.0f}") + print() + +# Apple M2 Pro:200 GB/s 统一内存带宽 +print("=== Apple M2 Pro (200 GB/s) ===") +estimate_latency("Llama-7B Q4", 7000, 4, 15.8, 200) +estimate_latency("Llama-7B Q8", 7000, 8, 15.8, 200) +estimate_latency("Llama-70B Q4", 70000, 4, 15.8, 200) + +# 手机(Snapdragon 8 Gen 3):~50 GB/s LPDDR5 +print("=== Snapdragon 8 Gen 3 (50 GB/s) ===") +estimate_latency("Phi-3 Mini Q4", 3800, 4, 45, 50) +estimate_latency("Llama-3B Q4", 3000, 4, 45, 50) +``` diff --git a/chapter 17: AI inference/05. scaling and deployment.md b/chapter 17: AI inference/05. scaling and deployment.md new file mode 100644 index 0000000..7440348 --- /dev/null +++ b/chapter 17: AI inference/05. scaling and deployment.md @@ -0,0 +1,251 @@ +# 扩缩与部署 + +*向数百万用户提供大模型服务需要跨多个GPU分布推理、在需要之前预测token、缓存共享上下文以及选择合适的框架。本文涵盖推理时的并行性、推测性解码、前缀缓存、推理框架、成本优化和监控* + +- 单个H100 GPU服务一个70B模型可以处理约100个并发用户,交互延迟可接受。服务1000万用户需要100,000个GPU——云计算每年花费约30亿美元。每一个百分点的效率提升就能节省数千万美元。这就是推理优化不是学术问题的原因:它直接决定AI产品的经济性。 + +## 推理时的模型并行 + +- 当模型太大无法装入单张GPU时,必须跨多个GPU拆分。训练时的并行策略(第6章)在推理时适用,但权衡不同。 + +### 张量并行 + +- **张量并行**(Megatron风格,第6章)跨GPU拆分单个权重矩阵。对于线性层$Y = XW$,权重矩阵$W$跨$N$个GPU按列拆分。每个GPU计算部分结果,然后all-reduce聚合: + +$$W = [W_1 | W_2 | \cdots | W_N], \quad Y_i = X W_i, \quad Y = \text{concat}(Y_1, \ldots, Y_N)$$ + +- 在推理时,张量并行是模型无法装入单张GPU时的默认选择。FP16的70B模型需要140 GB——跨2张80 GB GPU使用张量并行拆分。 + +- **延迟影响**:张量并行每层增加一个all-reduce通信步骤。在NVLink(900 GB/s)上,每层增加约0.1 ms。在PCIe(32 GB/s)上,每层增加约3 ms。对于80层的70B模型在2张GPU上:NVLink总增加约8 ms,PCIe总增加约240 ms。这就是NVLink对多GPU推理至关重要的原因。 + +### 流水线并行 + +- **流水线并行**将不同的层分配给不同的GPU。GPU 1处理第0-39层,GPU 2处理第40-79层。token顺序流过流水线。 + +- 在推理时,流水线并行的延迟高于张量并行(每个token必须遍历整个流水线),但通信开销更低(只有激活值在GPU之间传递,无需all-reduce)。当GPU通过慢速互连(不同节点,无NVLink)连接时,更倾向于使用流水线并行。 + +### 序列并行 + +- 对于非常长的序列,即使模型本身适合,KV缓存本身可能无法装入单张GPU。**序列并行**将KV缓存分片到多个GPU上:每个GPU存储序列缓存键和值的一部分。 + +- 在注意力期间,每个GPU在其缓存的段上计算部分注意力分数,然后通过规约合并结果。这用于长上下文推理(128K+ token),其中KV缓存超过单GPU内存。 + +## 推测性解码 + +- **推测性解码**是影响最大的LLM推理优化之一。其洞察:解码速度慢是因为一次只生成一个token,每个token需要大模型的完整前向传播。但小模型可以更快地生成候选token,而大模型可以**验证**多个候选token。 + +![推测性解码:快速草稿模型生成5个候选token,目标模型一次验证所有,接受的token保留,拒绝的重新采样](../images/speculative_decoding.svg) + +- **算法**: + 1. **草稿模型**(小型、快速——例如1B参数)自回归生成$k$个候选token。 + 2. **目标模型**(大型、准确——例如70B)对整个草稿序列运行一次前向传播,计算每个候选token的概率。 + 3. 如果目标模型同意(该token的概率足够高),每个候选被**接受**。被拒绝的候选从目标模型的分布中重新采样。 + 4. 平均每个验证步骤接受多个token,加速比与接受率成正比。 + +$$\text{加速比} \approx \frac{k \times \text{acceptance\_rate}}{\text{cost\_ratio}} \approx 2\text{-}3\times$$ + +- **为什么无质量损失**:拒绝采样方案保证输出分布与目标模型完全匹配。推测性解码是无损的——输出在统计上与单独运行目标模型相同,只是更快。 + +- **变体**: + - **Medusa**(Cai等人,2024):不是独立的草稿模型,而是向目标模型添加多个轻量级"头",同时预测多个未来token。无需独立模型。 + - **EAGLE**(Li等人,2024):训练一个使用目标模型隐藏状态预测未来token的轻量级草稿头。接受率高于独立的草稿模型。 + - **自推测性解码**:目标模型本身使用提前退出生成草稿(仅运行前几层作为草稿,然后用完整模型验证)。 + - **并行解码**:并行生成多个延续(候选树)并一次性验证整棵树。吞吐量更高,但分支KV缓存使用更多内存。 + +## 前缀缓存 + +- 许多请求共享共同的前缀:系统提示、few-shot示例或常见查询模式。**前缀缓存**存储这些前缀的KV缓存并在请求之间重用。 + +- **系统提示缓存**:如果每个请求都以相同的2000-token系统提示开始,这2000个token的KV缓存被计算一次,并在所有请求之间共享。对于80层的70B模型,每次请求节省约200 MB。 + +- **基数树缓存**(SGLang):将缓存的前缀组织在基数树(trie)中。当新请求到达时,找到最长的缓存前缀匹配,并从那里开始生成,跳过匹配前缀的计算。 + +- **影响**:对于具有长共享前缀的应用(带系统提示的聊天机器人、具有常见检索段落的RAG),前缀缓存将TTFT降低50-90%,并节省相应的GPU计算。 + +## KV缓存驱逐 + +- 除了量化KV缓存(文件01)和使用GQA/MLA减小其大小(文件02)之外,**KV缓存驱逐**策略选择性地移除不太可能在未来被关注的缓存token。 + +- **H2O**(重要token识别器,Zhang等人,2023)观察到注意力分数遵循幂律:一小部分token("重要token")获得大部分注意力,而大多数token获得的注意力可以忽略不计。H2O保留: + + 1. **最近token**(最后$w$个token的滑动窗口,类似StreamingLLM)。 + 2. **重要token**(在所有过去解码步骤中累积注意力分数最高的前$k$个token)。 + +- 既不是最近也不是重要token的token被驱逐。这保持固定大小的KV缓存,同时保留实际影响生成的token。H2O仅使用20%的内存就实现了接近完整KV缓存的质量。 + +- **Scissorhands**(Liu等人,2023)采用类似方法,但使用更复杂的度量:在当前**步骤**中获得高注意力的token被保留,而已经$T$步没有被关注的token被驱逐。这适应了生成过程中注意力模式的变化。 + +- **动态驱逐+StreamingLLM**:结合注意力汇聚点(永久保留前几个token)和动态驱逐(保留最近+重要token)。这是最内存高效的方法,适用于非常长的生成,实现了无限长度生成,质量下降有限。 + +- 所有驱逐方法的核心洞察:LLM注意力在实践中是**稀疏的**——尽管架构会对所有缓存的token计算注意力,但实际注意力权重集中在小子集上。驱逐其余部分对输出质量影响极小。 + +## 推理框架 + +- LLM服务生态已收敛到几个主要框架: + +| 框架 | 优势 | 最适合 | +|-----------|-----------|----------| +| **vLLM** | PagedAttention、连续批处理、高吞吐量 | 通用LLM服务,最高吞吐量 | +| **TensorRT-LLM** | NVIDIA优化内核、FP8、飞行中批处理 | NVIDIA GPU上的最大性能 | +| **SGLang** | 前缀缓存(RadixAttention)、快速结构化生成 | 具有共享前缀的应用,受限输出 | +| **llama.cpp** | CPU/Metal/CUDA/Vulkan、GGUF量化、可移植 | 消费级硬件,设备端推理 | +| **TGI**(HuggingFace) | 简单API,易于部署,模型中心集成 | 快速部署,HuggingFace生态 | +| **Ollama** | 一键下载和提供服务 | 个人使用,本地开发 | +| **ExLlamaV2** | 极致量化优化(EXL2格式) | 内存受限的GPU推理 | + +- **vLLM**是生产级LLM服务的默认选择。它支持连续批处理、PagedAttention、张量并行、推测性解码、LoRA服务和大多数开源模型。 + +- **TensorRT-LLM**在NVIDIA硬件上实现最高的原始性能(在相同GPU上比vLLM快10-30%),但灵活性较低且更难以定制。 + +- **SGLang**在应用具有结构化输出(JSON、特定格式的代码)或共享前缀时表现出色,这得益于其基数注意力缓存和受限解码引擎。 + +## 成本优化 + +- 在规模上,推理成本主导ML预算。降低成本的策略: + +- **合理选择GPU**:并非每个模型都需要H100。量化的7B模型在A10G(约$1/小时)上运行良好,而不是H100(约$8/小时)。匹配GPU到工作负载。 + +- **竞价实例**:云提供商提供未使用的GPU容量,折扣60-90%(AWS Spot、GCP Preemptible)。竞价实例可能被中断,因此适用于批处理推理而不是延迟关键型服务。结合抢占处理(保存状态,在新实例上恢复),竞价实例也可以服务交互式流量。 + +- **自动扩缩**:根据流量扩展GPU数量。高峰期扩展,夜间缩减。Kubernetes HPA(水平Pod自动扩缩器)或云原生自动扩缩(AWS SageMaker、GCP Vertex AI)处理此功能。 + +- **批处理+利用率**:30%和90% GPU利用率之间的差异是每token成本3倍。连续批处理、智能调度和PagedAttention都提高了利用率。 + +- **量化**:INT4 vs FP16是4倍更少内存 → 适合更小的GPU → 成本降低2-4倍。此外,更多请求适合同一批次 → 更高吞吐量 → 更低每token成本。 + +- **每token成本基准**(近似值,2026年): + +| 配置 | 每100万token成本 | +|-------|-------------------| +| GPT-4o API | $2.50 | +| Claude 3.5 Sonnet API | $3.00 | +| Llama-70B on H100(vLLM,FP16) | $0.50 | +| Llama-70B on H100(TRT-LLM,INT8) | $0.25 | +| Llama-8B on A10G(vLLM,INT4) | $0.05 | +| Llama-3B 设备端(llama.cpp) | $0(硬件成本摊销) | + +## 监控 + +- 生产推理需要持续监控,以便在用户受到影响之前发现降级: + +- **延迟监控**:跟踪TTFT和TPOT的p50、p95和p99。设置告警,当p99超过SLO时触发。p99的尖峰通常指示:KV缓存内存压力(抖动)、长时间运行的请求垄断批次、或GPU热节流。 + +- **吞吐量监控**:跟踪每GPU每秒token数。下降指示:批次效率降低(许多短请求→批次利用率低)、序列长度增加(每个请求更多KV缓存内存)、或硬件问题(GPU处于ECC纠错模式,运行更慢)。 + +- **GPU利用率**:跟踪SM占用率、内存利用率和内存带宽。低SM占用率+高内存利用率=内存受限(需要更多带宽或量化)。高SM占用率+低内存利用率=计算受限(需要更多FLOPS或更小模型)。 + +- **模型质量监控**:跟踪每请求指标(响应长度、保留集上的困惑度、用户反馈信号)。模型质量可能因以下原因降级:数据漂移(传入请求的分布变化)、KV缓存量化误差在长对话中累积、或服务流水线中的错误。 + +- **成本监控**:跟踪每模型每GPU类型每token成本。如果成本增加而吞吐量没有增加,调查效率回归(新模型版本内存使用更高、批次配置次优、或GPU利用不足)。 + +- **工具**:Prometheus + Grafana(第15章)用于基础设施指标,vLLM/TRT-LLM的内置指标端点,以及用于模型级指标的自定义日志记录。 + +## 编程任务(使用CoLab或notebook) + +1. 模拟推测性解码。使用快速的"草稿"函数和慢速的"目标"函数,测量一次生成和验证多个token的加速比。 +```python +import random +import time + +def target_model(tokens): + """慢但准确的模型。返回每个候选token的概率。""" + time.sleep(0.01) # 模拟每次前向传播10ms + # 用于模拟:接受偶数token + return [0.9 if t % 2 == 0 else 0.1 for t in tokens] + +def draft_model(): + """快但近似的模型。生成一个候选token。""" + time.sleep(0.001) # 模拟每token 1ms + return random.randint(0, 9) + +def standard_decoding(n_tokens): + """一次生成一个token,使用目标模型。""" + tokens = [] + for _ in range(n_tokens): + time.sleep(0.01) # 目标模型生成1个token + tokens.append(random.randint(0, 9)) + return tokens + +def speculative_decoding(n_tokens, k=4): + """生成k个草稿token,用目标模型验证,接受/拒绝。""" + tokens = [] + total_target_calls = 0 + + while len(tokens) < n_tokens: + # 草稿:快速生成k个候选 + candidates = [draft_model() for _ in range(k)] + + # 验证:一次目标模型调用验证所有k个候选 + probs = target_model(candidates) + total_target_calls += 1 + + # 接受token,直到一个被拒绝 + for i, (tok, prob) in enumerate(zip(candidates, probs)): + if random.random() < prob: + tokens.append(tok) + if len(tokens) >= n_tokens: + break + else: + # 从目标分布重新采样 + tokens.append(tok + 1) # 简化重新采样 + break + + return tokens, total_target_calls + +n = 50 + +start = time.time() +_ = standard_decoding(n) +standard_time = time.time() - start + +start = time.time() +_, target_calls = speculative_decoding(n, k=5) +spec_time = time.time() - start + +print(f"标准: {standard_time:.2f}s ({n} 次目标调用)") +print(f"推测性: {spec_time:.2f}s ({target_calls} 次目标调用)") +print(f"加速比: {standard_time / spec_time:.1f}x") +``` + +2. 估计应用于LLM服务部署的不同优化策略的成本节省。 +```python +def serving_cost_analysis( + model_name, params_B, precision_bits, + gpu_name, gpu_mem_gb, gpu_cost_per_hr, + target_throughput_tps, +): + """估计LLM部署的服务成本。""" + model_size_gb = params_B * 1e9 * precision_bits / 8 / 1e9 + gpus_for_model = max(1, int((model_size_gb * 1.2) / gpu_mem_gb + 0.99)) # 1.2x用于KV缓存 + + # 粗略吞吐量估计(内存带宽受限) + tokens_per_gpu = 500 / (params_B * precision_bits / 16) # 归一化到7B FP16的500 tok/s + total_throughput = tokens_per_gpu * gpus_for_model + + replicas = max(1, int(target_throughput_tps / total_throughput + 0.99)) + total_gpus = gpus_for_model * replicas + cost_per_hr = total_gpus * gpu_cost_per_hr + cost_per_1M_tokens = cost_per_hr / (total_throughput * replicas * 3600 / 1e6) + + print(f"{model_name} @ {precision_bits}-位 在 {gpu_name} 上:") + print(f" 模型大小: {model_size_gb:.0f} GB → {gpus_for_model} GPU(s)/副本") + print(f" 吞吐量: {total_throughput:.0f} tok/s/副本") + print(f" 需达到{target_throughput_tps} tok/s的副本数: {replicas}") + print(f" 总GPU数: {total_gpus}") + print(f" 成本: ${cost_per_hr:.0f}/小时, ${cost_per_1M_tokens:.2f}/100万token") + print() + +print("=== 成本比较 ===\n") + +# 基线:H100上的FP16 +serving_cost_analysis("Llama-70B", 70, 16, "H100", 80, 8.0, 1000) + +# 量化后:H100上的INT8 +serving_cost_analysis("Llama-70B", 70, 8, "H100", 80, 8.0, 1000) + +# 量化后:A100上的INT4 +serving_cost_analysis("Llama-70B", 70, 4, "A100", 80, 4.0, 1000) + +# 小模型:A10G上的8B +serving_cost_analysis("Llama-8B", 8, 4, "A10G", 24, 1.0, 1000) +``` diff --git a/chapter 18: ML systems design/01. systems design fundamentals.md b/chapter 18: ML systems design/01. systems design fundamentals.md new file mode 100644 index 0000000..9a40e0a --- /dev/null +++ b/chapter 18: ML systems design/01. systems design fundamentals.md @@ -0,0 +1,176 @@ +# 系统设计基础 + +*系统设计是构建可在大规模下可靠运行的软件的方法。本文件涵盖客户端-服务器架构、网络协议、DNS、代理、负载均衡、缓存、数据库、消息队列、一致性模型和弹性模式* + +- 生产环境中的每一个ML系统都是一个分布式系统。推荐引擎不仅仅是模型——它是一个API服务器、一个特征存储、一个模型注册表、一个缓存层、一个消息队列和一个监控栈,所有这些通过网络进行通信。理解系统设计是区分"我训练了一个模型"和"我构建了一个产品"的关键。 +- 顶级科技公司(Google、Meta、Amazon、OpenAI)的系统设计面试会测试你是否能设计这些系统。本章为你提供基础构建模块(本文件)、云基础设施(文件02)、扩展模式(文件03)、ML特定设计(文件04)和实操示例(文件05)。 + +## 客户端-服务器架构 + +- 基本模式:**客户端**发送请求,**服务器**处理请求并返回响应。你的浏览器(客户端)向google.com(服务器)发送HTTP请求,服务器返回HTML。 +- **请求-响应模式**:同步。客户端等待响应。简单但会产生瓶颈:客户端在等待时空闲,服务器必须在处理完当前请求后才能继续。 +- **无状态服务器**:服务器不记住先前的请求。每个请求包含处理它所需的所有信息。这使得扩展变得容易:任何服务器都可以处理任何请求,因此你可以在负载均衡器后面添加更多服务器。 +- **有状态服务器**:服务器在请求之间维护状态(例如,用户会话)。扩展更困难,因为来自同一用户的请求必须发送到同一台服务器(会话亲和性)。现代系统通过将状态存储在数据库或缓存(Redis)中来避免服务器端状态。 + +## 网络协议 + +- 我们在第13章(TCP/IP层、套接字)中介绍了网络知识。这里我们关注系统设计中使用的应用层协议: +- **HTTP/HTTPS**:Web和大多数API的协议。请求方法:GET(读取)、POST(创建/预测)、PUT(更新)、DELETE(删除)。HTTPS增加了TLS加密(第13章安全部分)。REST API(第15章文件03)基于HTTP构建。 +- **WebSocket**:客户端和服务器之间的持久双向连接。与HTTP(请求→响应→连接关闭)不同,WebSocket保持连接打开,用于实时流式传输。用于:LLM令牌流式传输(生成时发送令牌)、实时仪表盘、聊天应用。 +- **gRPC**:Google的RPC框架。使用Protocol Buffers(二进制序列化,比JSON小约10倍且更快),基于HTTP/2。支持流式传输(服务端、客户端、双向)。用于注重性能的内部服务间通信。Triton推理服务器(第15章)和TensorFlow Serving使用gRPC。 +- **Protocol Buffers**:在`.proto`文件中定义消息模式: + +```protobuf +message PredictRequest { + repeated float features = 1; + string model_version = 2; +} + +message PredictResponse { + float prediction = 1; + float confidence = 2; +} + +service ModelService { + rpc Predict(PredictRequest) returns (PredictResponse); +} +``` + +- 该模式被编译成任何语言(Python、C++、Go、Java)的客户端和服务端代码。类型安全、向后兼容性和高性能都自然具备。 + +## DNS + +- **DNS**(域名系统)将人类可读的名称转换为IP地址(第13章)。对于系统设计,DNS还提供: +- **通过DNS的负载均衡**:为同一域名返回不同的IP地址,将流量分布到多个服务器。简单但粒度较粗(DNS结果会被缓存数分钟到数小时,因此流量不会快速重新平衡)。 +- **地理路由**:根据客户端位置返回最近数据中心的IP。东京的用户获得日本数据中心;伦敦的用户获得欧洲数据中心。 +- **故障转移**:如果服务器宕机,DNS停止返回其IP。新客户端连接到健康的服务器。但缓存的DNS条目意味着某些客户端会继续访问已宕机的服务器持续数分钟(TTL问题)。 + +## 代理 + +- **代理**是客户端和服务器之间的中介: +- **反向代理**(在服务器前面):客户端连接到代理,代理将请求转发给后端服务器。客户端不知道哪个服务器处理了请求。**Nginx**和**HAProxy**是标准的反向代理。它们提供:负载均衡(分发请求)、SSL终止(在代理处解密HTTPS,向后端发送明文HTTP)、缓存、速率限制和压缩。 +- **API网关**:一种专门用于API的反向代理。处理身份验证、速率限制、请求路由(不同路径→不同服务)和API版本管理。**Kong**、**AWS API Gateway**和**Envoy**是常见选择。 +- 对于ML服务:API网关位于模型服务器前面。它验证API密钥、对免费用户进行速率限制、将`/v1/predict`路由到模型服务器A、将`/v2/predict`路由到模型服务器B,并收集使用指标。 + +## 负载均衡 + +- 当你拥有多台服务器时,**负载均衡器**将传入请求分布到它们之间。 + +![负载均衡器将传入请求分布到多个后端服务器](../images/load_balancer.svg) + +- **算法**: + - **轮询**:按顺序发送请求到服务器(1, 2, 3, 1, 2, 3...)。简单、公平,但不考虑服务器负载。 + - **最少连接**:发送到活动连接最少的服务器。适用于处理时间可变的请求(有些LLM请求生成10个令牌,有些生成1000个)。 + - **加权轮询**:容量更大的服务器获得更多请求。拥有80 GB GPU内存的服务器处理的请求量是40 GB的两倍。 + - **一致性哈希**:对请求键进行哈希运算,映射到特定服务器。相同的键始终发送到相同的服务器。适用于:缓存(同一用户的请求命中同一缓存)、会话亲和性和前缀缓存(第17章:具有相同系统提示词的请求发送到具有该提示词KV缓存的服务器)。 +- **L4 vs L7负载均衡**: + - **L4**(传输层):基于IP和端口路由。快速但无法检查请求内容。 + - **L7**(应用层):基于HTTP路径、标头或正文内容路由。可以将`/api/chat`路由到聊天服务器,将`/api/embed`路由到嵌入服务器。较慢但更灵活。 + +## 缓存 + +- **缓存**将频繁访问的数据存储在快速存储层(内存)中,以避免重新计算或重新获取。 + +![缓存旁路模式:先检查缓存,未命中时从数据库获取并存储到缓存中以备下次使用](../images/cache_aside_pattern.svg) + +- **缓存模式**: + - **缓存旁路**(惰性加载):应用程序先检查缓存。未命中时,从数据库获取、存储到缓存并返回。最常见的模式。 + - **直写**:每次写入同时写入缓存和数据库。确保缓存始终是最新的,但会减慢写入速度。 + - **回写**:写入只进入缓存;缓存异步刷新到数据库。写入最快,但若缓存刷新前崩溃则有数据丢失风险。 +- **驱逐策略**(当缓存满时): + - **LRU**(最近最少使用):驱逐最长时间未被访问的条目。最常见的策略。 + - **LFU**(最不频繁使用):驱逐访问次数最少的条目。当某些条目持续受欢迎时效果更好。 + - **TTL**(生存时间):条目在固定时长后过期。用于会过时的数据(模型预测缓存5分钟,特征值缓存1小时)。 +- **CDN**(内容分发网络):用于静态内容(图片、JavaScript、CSS)的全球分布式缓存。遍布100多个地点的服务器从离用户最近的位置提供缓存内容。对于ML:模型权重可以缓存在CDN上以实现快速下载。 +- **Redis**:标准的键值内存缓存/数据库。支持字符串、列表、集合、有序集合、哈希和流。亚毫秒级延迟。用于:缓存模型预测、存储会话数据、速率限制(统计每个用户每分钟的请求数)和实时特征服务。 +- 对于ML服务:缓存重复输入的预测结果。如果很多用户问"法国的首都是什么?",计算一次答案然后提供缓存结果。对于聊天机器人工作负载,缓存命中率通常为20-40%,按比例降低GPU成本。 + +## 数据库 + +### SQL(关系型) + +- **SQL数据库**(PostgreSQL、MySQL)以包含行和列的形式存储数据。表之间的关系通过外键表示。查询使用SQL。**ACID**保证: + - **原子性**:事务要么完全完成,要么完全回滚。没有部分更新。 + - **一致性**:数据库从一个有效状态转换到另一个有效状态。约束条件(唯一键、外键)始终得到满足。 + - **隔离性**:并发事务不互相干扰。 + - **持久性**:已提交的数据在崩溃后仍然存在(在确认前写入磁盘)。 +- SQL数据库擅长:具有关系的有结构数据、复杂查询(联接、聚合)、严格的一致性要求和数据完整性。 + +### NoSQL + +- **NoSQL数据库**为了可扩展性和灵活性而牺牲了一些ACID保证: + - **键值存储**(Redis、DynamoDB):最简单的模型。按键快速查找。用于缓存、会话存储和特征存储。 + - **文档存储**(MongoDB、Firestore):存储类似JSON的文档。灵活的模式(每个文档可以有不同字段)。用于用户资料、产品目录和配置。 + - **列族存储**(Cassandra、HBase):针对写入密集型工作负载和时间序列数据进行了优化。用于事件日志、指标和分析。 + - **图数据库**(Neo4j):存储节点和边。针对遍历查询进行了优化。用于社交网络、知识图谱和推荐系统。 + - **向量数据库**(Pinecone、Milvus、Weaviate、FAISS):存储高维嵌入并支持近似最近邻(ANN)搜索。对于语义搜索、RAG(检索增强生成)和推荐系统至关重要。 + +### CAP定理 + +- 在分布式数据库中,最多只能满足三个属性中的两个: + - **一致性**:每次读取都返回最新的写入。 + - **可用性**:每个请求都会收到响应(即使某些节点宕机)。 + - **分区容忍性**:系统在网络分区(节点无法通信)时仍能继续运行。 + +![CAP定理:由于网络分区不可避免,需要在CP(一致)或AP(可用)之间选择](../images/cap_theorem.svg) + +- 由于分布式系统中网络分区不可避免,真正的选择是**CP**(一致但在分区期间可能不可用——如PostgreSQL) vs **AP**(可用但在分区期间可能返回过期数据——如Cassandra、DynamoDB)。 +- 对于ML:特征存储通常选择AP(稍微过期的特征值也比无法预测要好)。模型注册表选择CP(提供错误的模型版本是灾难性的)。 + +### 分片 + +- **分片**将数据库拆分到多台机器上。每个分片持有数据的一个子集。 +- **哈希分片**:对键进行哈希运算以确定分片。`shard = hash(user_id) % num_shards`。分布均匀但不支持范围查询。 +- **范围分片**:每个分片持有一个键范围(用户A-G在分片1,H-N在分片2)。支持范围查询但可能产生热点(如果很多用户名字以"S"开头)。 +- **重新分片问题**:添加分片会使哈希映射失效。**一致性哈希**最小化数据移动:添加第n个分片时,只有约1/n的键需要移动。 + +### 数据库索引 + +- **索引**是一种加速查询的数据结构,代价是额外的存储空间和较慢的写入速度。没有索引时,查询会扫描每一行($O(n)$)。有索引时,可以在$O(\log n)$时间内找到目标。 +- **B树索引**(默认):一种平衡树(第13章、第14章),其中每个节点包含多个键和指针。B树对缓存友好(宽节点适合缓存行)并支持范围查询(`WHERE age BETWEEN 20 AND 30`)。大多数SQL数据库使用B树。 +- **哈希索引**:使用哈希函数将键映射到行位置。$O(1)$查找但不支持范围查询。用于精确匹配查找(`WHERE id = 12345`)。 +- **复合索引**:对多个列的索引。`CREATE INDEX ON users(country, city)` 加速按国家或按国家+城市筛选的查询,但不能加速仅按城市的查询(最左边的列必须在查询中)。 +- **权衡**:每个索引都会加速读取但减慢写入(每次插入/更新/删除都必须更新索引)并占用存储空间(每个索引约占表大小的10-30%)。不要索引所有内容——只索引你经常查询的列。 +- **对于ML系统**:特征存储的在线数据库需要在实体键(user_id、item_id)上建立索引以实现快速特征查找。实验跟踪数据库需要在(experiment_id、metric_name)上建立索引以实现仪表盘查询。 + +### API设计 + +- 系统通过API进行通信。良好的API设计使系统可用、可进化和可调试: +- **REST约定**:使用名词表示资源(`/users`、`/models`),HTTP方法表示操作(GET=读取、POST=创建、PUT=更新、DELETE=删除),状态码表示结果(200=OK、201=已创建、400=错误请求、404=未找到、429=被限流、500=服务器错误)。 +- **分页**:对于返回列表的端点,永远不要一次返回所有结果。使用基于游标的分页(`GET /items?cursor=abc&limit=50`)或基于偏移量的分页(`GET /items?offset=100&limit=50`)。对于大数据集,基于游标的分页更高效(基于偏移量的分页需要跳过行)。 +- **版本管理**:在API路径前加上版本前缀(`/v1/predict`、`/v2/predict`)。这样可以在不破坏现有客户端的情况下演进API。客户端按照自己的节奏迁移到v2;v1被弃用但在流量下降之前不会删除。 +- **错误响应**:返回结构化的错误信息: + +```json +{ + "error": { + "code": "INVALID_INPUT", + "message": "特征'user_age'必须为正整数", + "details": {"field": "user_age", "value": -5} + } +} +``` + +## 消息队列 + +- **消息队列**将生产者(生成工作的服务)与消费者(处理工作的服务)解耦。生产者将消息发送到队列;消费者在就绪时拉取消息。 +- **队列为什么重要**:没有队列时,如果消费者慢或宕机,生产者会被阻塞。有了队列,生产者发送后就无需等待;队列缓冲消息,直到消费者准备好。 +- **Apache Kafka**:一个分布式、持久化、高吞吐量的消息队列。消息存储在**主题**中,每个主题跨多个代理分区。消费者从分区读取,跟踪其位置(**偏移量**)。Kafka保证分区内的顺序,并可重播消息(日志是持久化的)。 +- **发布/订阅**:发布者将消息发送到主题;该主题的所有订阅者都会收到一份副本。用于事件驱动架构:"新模型已部署"触发监控服务、A/B测试服务和日志服务同时响应。 +- 对于ML:预测请求通过HTTP到达,放入Kafka队列,由GPU工作线程处理,结果通过回调或WebSocket返回。队列缓冲突发的流量,并确保即使GPU工作线程崩溃也不会丢失请求。 + +## 一致性模型 + +- 在分布式系统中,不同节点可能对数据有不同的视图。**一致性模型**定义了系统提供的保证: +- **强一致性**:写操作之后,所有后续读取(从任何节点)都能看到新值。易于推理但速度慢(需要在节点之间协调)。 +- **最终一致性**:写操作之后,读取可能在某段时间内看到过期数据,但最终会看到新值。速度快(无需协调)但需要应用程序处理过期读取。 +- **因果一致性**:如果操作A因果上先于B(例如,"写入X然后读取X"),系统保证B能看到A的结果。但不相关的操作可能以任何顺序被看到。 +- **读写一致性**:用户始终能立即看到自己的写入,即使其他用户看到的是过期数据。大多数应用程序所需的最小一致性。 + +## 弹性模式 + +- **速率限制**:限制每个用户在时间窗口内的请求数。防止滥用并确保公平访问。使用Redis中的令牌桶或滑动窗口计数器实现。 +- **断路器**:如果下游服务开始失败(错误率超过阈值),断路器"断开"并停止向其发送请求(立即返回回退响应)。超时后,"半开"并发送测试请求。如果测试成功,则"闭合"(恢复正常操作)。这防止了级联故障:如果特征存储宕机,模型服务器返回无特征的预测,而不是每次请求都超时。 +- **背压**:当系统过载时,它向上游发出信号要求减速。与其接受请求然后失败,不如尽早拒绝多余的请求(返回429或503状态码)。客户端以指数退避重试。 +- **指数退避重试**:如果请求失败,等待1秒后重试。如果再次失败,等待2秒。然后是4秒、8秒,依此类推。加入随机抖动以防止所有客户端同时重试(惊群问题)。 +- **幂等性**:如果执行两次的效果与执行一次相同,则该操作是幂等的。`PUT /user/123 {"name": "Alice"}`是幂等的(将名称设置为"Alice"两次没问题)。`POST /payments`不是(支付两次很糟糕)。使操作幂等可确保重试是安全的。 diff --git a/chapter 18: ML systems design/02. cloud computing.md b/chapter 18: ML systems design/02. cloud computing.md new file mode 100644 index 0000000..6596768 --- /dev/null +++ b/chapter 18: ML systems design/02. cloud computing.md @@ -0,0 +1,165 @@ +# 云计算 + +*云计算为ML工作负载提供按需基础设施,无需拥有硬件。本文件涵盖服务模型、主要云服务商、容器和Kubernetes、存储、云网络、无服务器计算、成本管理和基础设施即代码* + +- 训练前沿模型需要数千个GPU持续数月。没有初创公司拥有这样的硬件。云计算让你按小时租赁,训练时扩展,推理时缩减,只为使用量付费。理解云基础设施对于任何在笔记本电脑之外构建ML系统的人来说都是必不可少的。 + +## 云服务模型 + +![云服务层:IaaS给你最大控制权,SaaS给你最少控制权](../images/cloud_service_layers.svg) + +- 云服务按提供商管理程度的层叠划分: + +| 模型 | 你管理 | 提供商管理 | 示例 | +|-------|-----------|-----------------|---------| +| **IaaS**(基础设施) | 操作系统、运行时、应用 | 硬件、虚拟化、网络 | AWS EC2、GCP Compute Engine | +| **PaaS**(平台) | 应用、数据 | 操作系统、运行时、扩展、修补 | AWS SageMaker、GCP Vertex AI | +| **SaaS**(软件) | 什么都不用管(只管用) | 一切 | OpenAI API、Weights & Biases | +| **FaaS**(函数) | 单个函数 | 其他所有 | AWS Lambda、GCP Cloud Functions | + +- **对于ML**:大多数团队混合使用。IaaS用于自定义训练(完全控制GPU实例),PaaS用于托管训练和服务(SageMaker、Vertex AI处理编排),SaaS用于工具(W&B用于实验跟踪,OpenAI API用于基线比较)。 + +## 主要云服务商 + +### AWS(亚马逊云服务) + +- 最大的云服务商(约32%市场份额)。关键ML服务: + - **EC2**:虚拟机。GPU实例:p4d(A100)、p5(H100)、g5(A10G用于推理)。 + - **S3**:对象存储。存储数据集和模型权重的标准。几乎无限的容量,约$0.023/GB/月。 + - **SageMaker**:托管ML平台。处理训练、超参数调优、部署和监控。 + - **EKS**:托管Kubernetes。 + - **Lambda**:无服务器函数。不适合GPU工作负载,但适用于预处理和编排。 + +### GCP(谷歌云平台) + +- 谷歌的云(约11%市场份额)。关键ML服务: + - **Compute Engine**:虚拟机。GPU实例提供A100、H100。**TPU VM**用于TPU访问。 + - **GCS**:对象存储(类似S3)。 + - **Vertex AI**:托管ML平台。原生支持JAX/TPU。 + - **GKE**:托管Kubernetes(最成熟的K8s产品,因为谷歌创建了Kubernetes)。 + - **Cloud TPU**:GCP独有。v5e和v5p用于大规模训练。 + +### Azure(微软) + +- 微软的云(约23%市场份额)。关键ML服务: + - **Azure VM**:GPU实例提供A100、H100。 + - **Azure Blob存储**:对象存储。 + - **Azure ML**:托管ML平台。 + - **AKS**:托管Kubernetes。 + - **OpenAI服务**:通过Azure API独家访问OpenAI模型。 + +## 容器和Kubernetes + +- 我们在第13章(操作系统)中概念性地介绍了容器(Docker)和Kubernetes,并在第15章(部署)中进行了实践。这里我们关注**云特定的**模式: + +### Kubernetes用于ML + +- **Kubernetes(K8s)**大规模编排容器。关键概念: + - **Pod**:最小的可部署单元。包含一个或多个共享网络和存储的容器。一个模型服务Pod可能包含:模型服务器容器 + 用于指标收集的边车容器。 + - **Deployment**:管理一组相同的Pod。指定所需的副本数。如果Pod崩溃,K8s会自动创建替代Pod。 + - **Service**:一组Pod的稳定网络端点。客户端连接到Service;K8s路由到健康的Pod。类型:ClusterIP(内部)、NodePort(通过节点端口对外暴露)、LoadBalancer(通过云负载均衡器对外暴露)。 + - **StatefulSet**:类似Deployment但用于有状态工作负载。每个Pod获得持久的身份和稳定的存储。用于数据库和分布式训练(每个工作者需要稳定的身份以便通信)。 + - **DaemonSet**:在每个节点上运行一个Pod。用于:监控代理(Prometheus节点导出器)、日志收集器(Fluentd)、GPU设备插件(NVIDIA设备插件)。 +- **K8s中的GPU调度**:NVIDIA设备插件将GPU暴露为K8s资源。Pod请求GPU: + +```yaml +resources: + limits: + nvidia.com/gpu: 2 # 此Pod需要2个GPU +``` + +- K8s将Pod调度到具有2个可用GPU的节点上。这就是云ML平台为训练和推理分配GPU的方式。 + +### 自动缩放 + +- **水平Pod自动缩放器(HPA)**:基于指标(CPU使用率、请求率、自定义指标如GPU利用率或队列深度)缩放Pod数量。 +- **集群自动缩放器**:缩放节点数量。如果由于没有足够节点而无法调度Pod,集群自动缩放器会从云服务商处配置新的VM。当节点利用不足时,它会排空并终止它们。 +- **KEDA**(Kubernetes事件驱动自动缩放):基于外部事件源(Kafka队列深度、HTTP请求率)进行缩放。非常适合推理:当请求队列增长时扩展模型服务器,当队列为空时缩减。 + +## 存储 + +| 类型 | 特性 | 用途 | 示例 | +|------|----------------|----------|---------| +| **块存储** | 低延迟,附加到单台VM | 操作系统磁盘、数据库 | AWS EBS、GCP Persistent Disk | +| **对象存储** | 无限容量,HTTP访问 | 数据集、模型权重、日志 | AWS S3、GCS、Azure Blob | +| **文件存储** | 跨VM共享,POSIX | 共享训练数据 | AWS EFS、GCP Filestore、NFS | +| **数据湖** | 读取时定义模式,原始数据 | 分析、特征工程 | Delta Lake、Iceberg、Hudi | + +- **对于ML训练**:数据集存储在对象存储(S3/GCS)中。训练脚本从对象存储读取数据到内存。对于快速随机访问(随机数据加载),要么:(1)在训练前将数据集下载到本地SSD,(2)使用高吞吐量文件系统(Lustre、FSx),或(3)使用能高效流式和缓存的数据库加载库(WebDataset、FFCV)。 +- **模型权重**:存储在带版本管理的对象存储中。70B模型在FP16下约140 GB。以1 GB/s的速度从S3加载约需2.5分钟。在本地SSD上缓存可减少推理的冷启动时间。 + +## 云网络 + +- **VPC**(虚拟私有云):云中的隔离网络。你的VM、数据库和服务在VPC内部通信。外部流量通过负载均衡器或网关进入。 +- **子网**:将VPC划分为多个段。公有子网可访问互联网(用于API服务器)。私有子网不可访问(用于数据库、GPU工作线程)。这是最小权限安全原则在网络上的等价物。 +- **安全组**(AWS)/ **防火墙规则**(GCP):控制允许哪些流量。"允许来自任何地方的入站HTTP端口80。仅允许来自我的IP的入站SSH端口22。阻止其他所有流量。"安全组配置错误是云安全事件的首要原因。 +- **服务网格**(Istio、Envoy):管理K8s内部的服务间通信。提供:mTLS加密(每次服务间调用都加密)、流量路由(A/B测试:将10%流量路由到新模型)、重试、超时、断路和可观测性(哪个服务调用了哪个,花了多长时间)。 + +## 无服务器计算 + +- **无服务器**(AWS Lambda、GCP Cloud Functions):你上传一个函数,云服务商在触发时运行它。无需管理服务器,无需配置缩放。按调用次数付费(通常每100万次调用$0.20 + 计算时间)。 +- **冷启动**:一段时间不活动后的第一次调用需要更长时间(服务商必须分配容器并加载你的代码)。冷启动为0.5-5秒,使得无服务器不适合对延迟敏感的ML推理。 +- **对于ML**:无服务器适用于:预处理(发送到模型前调整图像大小)、后处理(格式化模型输出,发送通知)、编排(新数据到达时触发训练流水线)和轻量级推理(能容忍冷启动的小模型)。 +- 无服务器**不**适用于:GPU推理(大多数无服务器平台不支持GPU)、长时间运行的训练作业(Lambda的15分钟超时)或有状态服务(调用之间没有持久状态)。 + +## 成本管理 + +- 云成本是ML团队的首要运营问题。单个H100实例约$8/小时。64-GPU训练运行约$500/小时。一个月的训练运行约$360,000。成本优化是工程问题,不是会计问题。 +- **竞价/抢占式实例**:未使用的云容量以60-90%的折扣出售。服务商可在30秒到2分钟通知后回收。用于:容错训练(经常检查点,在新实例上恢复)、批量推理、数据预处理。不用于:对延迟敏感的服务(中断=停机)。 +- **预留实例**:承诺使用1-3年,享受30-60%折扣。用于:你知道基线负载的稳态推理服务。 +- **自动缩放**:高峰时段扩展,夜间/周末缩减。峰时需要10个GPU、夜间需要2个的模型服务器,通过自动缩放相比24/7运行10个GPU可节省约60%成本。 +- **合理选型**:不要在7B模型上使用H100,如果它在A10G上运行良好。将GPU匹配到工作负载。使用性能分析(第16章)确定最合适的GPU。 +- **存储成本**:对象存储便宜(S3标准约$0.023/GB/月),但会累积。一个团队如果保存每个训练检查点(每个10 GB,每个实验100个,50个实验),累积50 TB = $1,150/月。设置生命周期策略自动删除旧检查点。 + +## 多区域部署 + +- 对于全球ML系统(服务全球用户),在单个区域部署意味着远程用户的高延迟(东京的用户访问美国服务器会增加约150ms的网络往返)和单点故障(如果该区域宕机,整个服务离线)。 +- **多区域模式**: + - **主备模式**:一个主区域处理所有流量。辅助区域有热备(复制数据,准备接收流量)。主区域故障时,DNS切换到辅助区域。故障转移期间的停机时间:30秒到几分钟。 + - **双活模式**:两个区域同时处理流量。用户被路由到最近的区域。两个区域都有最新数据(异步或同步复制)。单区域故障时无停机——流量自动重新路由。 +- **数据复制**:困难的部分。模型权重可以轻松复制(复制到每个区域的S3)。特征存储数据必须以可接受的陈旧度复制。用户数据可能有**数据驻留要求**(GDPR:欧洲用户数据必须留在欧洲)。 +- **GPU云价格比较**(2026年近似值): + +| GPU | AWS | GCP | Azure | 典型用途 | +|-----|-----|-----|-------|-------------| +| A10G(24 GB) | $1.00/小时(g5) | $0.90/小时 | $0.90/小时 | 小模型推理 | +| A100(80 GB) | $4.10/小时(p4d) | $3.70/小时 | $3.40/小时 | 训练、大型推理 | +| H100(80 GB) | $8.00/小时(p5) | $7.50/小时 | $7.00/小时 | 前沿训练 | +| TPU v5e | 无 | $1.20/小时 | 无 | JAX大规模训练 | + +- 竞价/抢占式定价通常比这些价格低60-70%。价格因区域和可用性而异。 + +## 基础设施即代码 + +- **IaC**在版本控制的配置文件中定义基础设施(VM、网络、数据库、K8s集群)。不是在AWS控制台中点击按钮,而是编写代码描述你想要的内容,然后工具创建它。 +- **Terraform**(HashiCorp):标准的IaC工具。适用于所有主要云服务商。声明式:你描述期望状态,Terraform计算需要创建/修改/删除什么以达到该状态。 + +```hcl +# main.tf — 创建用于推理的GPU VM +resource "aws_instance" "model_server" { + ami = "ami-0abcdef1234567890" # 深度学习AMI + instance_type = "g5.xlarge" # A10G GPU + + tags = { + Name = "model-server-prod" + } +} + +resource "aws_s3_bucket" "model_weights" { + bucket = "my-model-weights-prod" + + versioning { + enabled = true + } +} +``` + +```bash +terraform init # 下载提供商插件 +terraform plan # 显示将要更改的内容 +terraform apply # 创建基础设施 +terraform destroy # 全部拆除 +``` + +- **IaC为何重要**:可重现性(从代码重建整个基础设施)、审计(git历史显示谁更改了什么)、灾难恢复(从同一配置在不同区域重建)和环境一致性(开发、预发布和生产使用相同模板,仅参数不同)。 +- **Pulumi**:类似Terraform,但使用真正的编程语言(Python、TypeScript、Go)而不是HCL。当基础设施逻辑复杂时(条件、循环、动态配置)很有用。 diff --git a/chapter 18: ML systems design/03. large scale infrastructure.md b/chapter 18: ML systems design/03. large scale infrastructure.md new file mode 100644 index 0000000..0290ca3 --- /dev/null +++ b/chapter 18: ML systems design/03. large scale infrastructure.md @@ -0,0 +1,200 @@ +# 大规模基础设施 + +*构建服务数百万用户的系统需要的不只是单个服务器。本文件涵盖可扩展性模式、分布式系统基础、微服务、数据流水线、数据库扩展、搜索和向量系统、可观测性、可靠性工程以及CI/CD* + +- 每秒服务1个请求的模型可以在笔记本电脑上运行。每秒服务100,000个请求且可用性达到99.9%需要分布式系统、自动故障转移和精心设计的数据流水线。本文件涵盖弥合这一差距的模式。 + +## 可扩展性 + +- **垂直扩展**(向上扩展):换更大的机器。更多CPU、更多内存、更大的GPU。简单但有硬性限制(最大的可用机器)和单点故障。 +- **水平扩展**(向外扩展):增加更多机器。每台处理一部分流量。没有单机限制,但需要:负载均衡(文件01)、数据分区和处理分布式状态。 +- **无状态服务**默认是可水平扩展的。在负载均衡器后面添加更多实例即可。在启动时加载权重并独立处理请求的模型推理服务器是无状态的——任何实例都可以处理任何请求。 +- **有状态服务**(数据库、KV缓存、特征存储)更难扩展。状态必须在多台机器间分区(分片,文件01)并复制以实现容错。 +- **可扩展性方程**:对于一个有$n$台服务器的系统: + - **理想情况**:吞吐量线性扩展($n$台服务器→$n\times$吞吐量)。 + - **实际情况**:协调、负载均衡和数据传输的开销意味着吞吐量亚线性扩展。阿姆达尔定律(第13章)适用:串行部分(共享状态、协调)限制了加速比。 + +## 分布式系统 + +- **分布式系统**是一组协调提供服务器的机器。基本挑战: +- **网络分区**:机器不能总是通信。网线被切断、交换机故障、数据中心断电。系统必须处理部分故障。 +- **时钟偏差**:机器有不同的时钟。"事件A发生在机器1的10:00:01"和"事件B发生在机器2的10:00:01"并不意味它们同时发生。**逻辑时钟**(Lamport时间戳、向量时钟)建立排序而不依赖物理时钟。 +- **共识**:多台机器如何就某个值达成一致(例如,谁是领导者)?**Raft**是标准的共识算法。一组节点选举一个领导者。领导者处理所有写入。如果领导者失败,剩余节点选举新的领导者。需要多数(5个节点中的3个)才能运行,因此能容忍$\lfloor(n-1)/2\rfloor$个故障。 +- **分布式锁**:确保只有一台机器执行关键操作。**Redlock**(基于Redis)跨多个Redis实例获取锁。如果多数实例授予锁,则获取成功。用于:防止重复的模型部署,确保只有一个训练作业写入检查点。 + +## 微服务 + +![微服务ML架构:API网关路由到特征服务、模型服务和日志服务,每个都有自己的数据库,通过消息队列连接](../images/microservices_architecture.svg) + +- **微服务**将系统分解为小型、独立可部署的服务。每个服务拥有一个领域: + +``` +┌─────────────┐ ┌──────────────┐ ┌─────────────┐ +│ API网关 │→ │ 特征服务 │→ │ 特征数据库 │ +└─────────────┘ └──────────────┘ └─────────────┘ + │ + ├────────→ ┌──────────────┐ ┌─────────────┐ + │ │ 模型服务 │→ │ 模型存储 │ + │ └──────────────┘ └─────────────┘ + │ + └────────→ ┌──────────────┐ ┌─────────────┐ + │ 日志服务 │→ │ 日志存储 │ + └──────────────┘ └─────────────┘ +``` + +- **优点**:独立部署(更新模型服务而不影响特征服务)、独立缩放(根据请求负载缩放模型服务器,根据特征存储读取率缩放特征服务器)、技术自由(模型服务用Python,特征服务用Go)。 +- **缺点**:网络开销(每次服务调用都是网络往返)、复杂性(调试跨越多个服务)、数据一致性(没有跨服务的事务)。 +- **服务发现**:API网关如何找到模型服务?选项:基于DNS(每个服务注册一个DNS名)、K8s服务(内置)或服务注册表(Consul、Eureka)。 +- **Saga模式**:对于跨多个服务的操作(创建用户+分配资源+发送欢迎邮件),使用saga:一系列本地事务,如果任何步骤失败则执行补偿操作。 + +## 数据流水线 + +- ML系统消耗大量数据。**数据流水线**移动、转换和服务这些数据: + +### 批处理 + +- 按固定间隔(每小时、每天)处理大量数据。 +- **MapReduce**:原始的批处理范式。Map(独立转换每条记录)→ Shuffle(按键分组)→ Reduce(按组聚合)。概念上简单但实现繁琐。 +- **Apache Spark**:现代批处理引擎。内存处理(对于迭代算法比MapReduce快100倍)。支持SQL、DataFrame和ML流水线。大规模特征工程的标准。 +- **示例**:为推荐系统计算用户特征。输入:过去30天的10亿用户活动事件。输出:1亿用户特征向量。每天作为Spark作业运行,输出到特征存储。 + +### 流处理 + +- 实时处理到达的数据(亚秒级延迟)。 +- **Apache Flink**:领先的流处理引擎。精确一次处理、事件时间处理(按事件发生时间处理,而非到达时间)、窗口化(滚动、滑动、会话窗口)。 +- **Kafka Streams**:内置于Kafka的轻量级流处理。适用于简单转换(过滤、聚合),无需部署单独的集群。 +- **示例**:实时欺诈检测。每笔信用卡交易是一个Kafka事件。Flink作业计算运行统计(交易频率、位置变化)并在100ms内标记异常。 + +### Lambda架构 + +- 结合批处理和流处理。**批处理层**提供准确、全面的结果(但有延迟)。**速度层**提供近似、实时的结果。**服务层**合并两者。 +- 实际上,许多团队现在使用**Kappa架构**:仅流处理,将流视为事实来源。流是可重播的(Kafka保留事件),因此可以通过重播流来模拟批处理。 + +## ML训练基础设施 + +- 训练前沿模型(100B+参数)是一个大规模基础设施问题:数千个GPU运行数月,消耗兆瓦级电力,生成PB级数据,花费数千万美元。基础设施决定了训练成功还是失败。 + +### GPU集群 + +- 训练集群是由高速网络连接的GPU服务器集合。关键组件: + +![GPU集群:每个节点有8个通过NVLink连接的GPU,节点通过Infiniband以胖树拓扑连接,从64扩展到16,000+个GPU](../images/gpu_cluster_topology.svg) + +- **GPU服务器(节点)**:每台服务器有4-8个GPU。典型配置:8×H100 GPU、2×AMD EPYC CPU、2 TB RAM、30 TB NVMe SSD。节点内的GPU通过**NVLink**连接(H100上每个GPU 900 GB/s),比PCIe快30倍。 +- **集群规模**:小型训练集群有64-256个GPU(8-32个节点)。前沿模型训练集群有4,000-32,000个GPU(500-4000个节点)。Meta的Llama 3使用了16,384个H100 GPU。Google在拥有8,000+个芯片的TPU pod上训练。 +- **粗略估算**:训练70B模型需要约$200万。训练400B+前沿模型需要约$5000万-$1亿。集群硬件本身在H100价格下约$5亿-$10亿($3万/GPU × 16,000 GPU = $4.8亿)。 + +### 网络拓扑 + +- GPU节点之间的网络是最关键的基础设施组件。如果GPU不能足够快地交换梯度,它们就会闲置等待通信完成。 +- **InfiniBand**是GPU集群网络的标准。NVIDIA的**Quantum-2 InfiniBand**提供每个端口400 Gb/s。每个节点通常有8个InfiniBand端口(每个GPU一个),每个节点的总对分带宽为400 GB/s。 +- **RDMA**(远程直接内存访问):InfiniBand支持RDMA,它直接在节点间的GPU内存之间传输数据,无需CPU参与。这将延迟从约100μs(TCP)降低到约1μs,对于高效的梯度全规约(第6章)至关重要。 +- **网络拓扑很重要**:**胖树**(Clos网络)提供全对分带宽(任何GPU可以与其他任何GPU以全速通信)。更便宜的拓扑(**轨道优化**、**3D环面**)提供较少的带宽但成本更低。拓扑必须匹配并行策略: + - **数据并行**:跨所有GPU的全规约→需要高对分带宽(胖树)。 + - **张量并行**:节点内通信→NVLink处理此需求(不需要网络)。 + - **流水线并行**:相邻流水线阶段之间的通信→只需要特定节点对之间的带宽(轨道优化即可)。 +- **以太网替代方案**:**RoCE v2**(融合以太网上的RDMA)在标准以太网基础设施上提供RDMA。比InfiniBand便宜,但延迟更高且更易拥塞。Google在某些TPU pod网络中使用RoCE。超以太网联盟正在开发用于AI工作负载的无损以太网。 + +### 训练存储 + +- 训练需要三个存储层级: + - **数据集存储**:训练语料(1-100 TB文本,或PB级多模态数据)。存储在分布式文件系统或对象存储中。必须支持高吞吐量顺序读取(数据加载器以大批量读取数据)。**Lustre**和**GPFS**是常见的HPC文件系统;云替代方案包括**FSx for Lustre**(AWS)和**Filestore**(GCP)。 + - **检查点存储**:训练状态(模型权重+优化器状态+调度器状态)定期保存。对于使用Adam优化器的混合精度70B模型:每个检查点约560 GB(70B × 4字节 × 2用于优化器)。每小时保存一次,运行3个月=约2000个检查点=1.1 PB。实际上,只保留最新的N个检查点,旧的会被删除。必须足够快,使检查点不会显著拖慢训练。 + - **日志和指标**:实验跟踪数据(损失曲线、学习率计划、梯度范数)。相对较小但必须实时写入。W&B、MLflow或TensorBoard处理此需求。 +- **存储瓶颈**:一个16,000-GPU集群加载一个训练批次需要持续读取约100 GB/s的数据。如果文件系统无法维持此吞吐量,GPU将闲置等待数据。数据流水线优化(预取、缓存、使用WebDataset或Mosaic Streaming进行格式优化)至关重要。 + +### 作业调度 + +- GPU集群服务于多个团队和项目。**作业调度器**将GPU分配给训练作业: +- **SLURM**:标准的HPC作业调度器。用户提交作业,指定GPU数量、内存和时间限制。SLURM分配资源并管理队列。支持基于优先级的调度、抢占和团队间的公平份额分配。 +- **带GPU调度的Kubernetes**(第18章文件02):云原生方法。K8s GPU设备插件将GPU暴露为可调度资源。**Volcano**和**Run:ai**增加了ML特定的调度功能:群体调度(一次为一个作业分配所有GPU,而不是逐个分配)、优先级队列和GPU时间共享。 +- **调度挑战**: + - **碎片化**:一个拥有1000个GPU的集群可能有200个空闲,但分布在50个节点上(每个节点4个空闲)。需要128个连续GPU的作业无法运行,即使有足够的总GPU数。**去碎片化**(迁移作业以合并空闲GPU)或**拓扑感知调度**(分配连接良好的GPU)可以解决此问题。 + - **优先级和抢占**:紧急实验应抢占低优先级作业。但抢占一个已运行2天的训练作业会浪费计算资源。调度器必须在优先级和效率之间取得平衡。 + - **公平份额**:团队应在一段时间内获得其分配的计算份额,即使一个团队提交的作业超过其份额。 + +### 容错 + +- 在数千个GPU运行数月的规模下,硬件故障不是异常——而是常态。16,000-GPU集群的平均故障间隔时间以小时计,而非月。 +- **常见故障**:GPU内存错误(ECC可纠正和不可纠正)、NVLink故障(节点内GPU到GPU通信)、InfiniBand链路故障(节点到节点通信)、节点崩溃(内核恐慌、PSU故障)和存储故障(磁盘或控制器故障)。 +- **检查点**是主要的防御手段。每N步保存完整的训练状态(模型、优化器、数据加载器位置)。故障时:识别故障节点,替换或移除它,从最近的检查点恢复训练。故障的代价是最后一次检查点和故障之间的计算量。 +- **检查点频率权衡**:频繁检查点(每10分钟)在故障时浪费更少的计算,但会减慢训练(保存560 GB需要时间)。不频繁检查点(每2小时)更快,但故障时浪费多达2小时的计算。大多数团队每20-60分钟检查一次。 +- **弹性训练**:现代框架(PyTorch Elastic、DeepSpeed)支持在不重启的情况下调整训练规模。如果500个节点中有2个节点故障,训练继续使用498个节点。故障节点被替换,训练在它们重新上线时自动纳入。 +- **健康监控**:持续监控所有GPU(温度、内存错误、计算吞吐量)、网络链路(丢包、延迟)和存储(吞吐量、错误率)。异常时自动告警。一些集群运行定期GPU健康检查(一个简短的计算测试)以主动识别在故障前性能下降的硬件。 +- **大规模场景**:训练Meta的Llama 3(16,384个H100,54天)经历了约466次作业中断。有效训练时间仅为挂钟时间的约90%——10%损失于故障和恢复。实现90%(而非50%或70%)的基础设施是区分能训练前沿模型的组织和不能训练的组织的关键。 + +### 成本和效率 + +- 训练基础设施成本由GPU小时主导: + +| 组件 | 占总成本百分比 | +|-----------|----------------| +| GPU计算 | 70-80% | +| 网络(InfiniBand) | 10-15% | +| 存储 | 5-10% | +| 冷却和电源 | 5-10% | + +- **GPU利用率**(模型FLOPs利用率,MFU)衡量GPU理论峰值性能中有多少被用于实际有用计算。H100峰值为989 TFLOPS(FP8)。达到40-50% MFU算良好;50-60%算优秀。差距来自:通信开销(全规约、流水线气泡)、内存带宽限制以及检查点和数据加载期间的闲置时间。 +- **提高MFU**:重叠计算和通信(第6章)、使用高效注意力(Flash Attention,第16章)、优化数据加载(防止GPU饥饿)、减少检查点开销(异步检查点,先检查到快速NVMe,然后后台复制到持久存储)。 +- **自建vs租用**:在小规模(<256个GPU)下,云更便宜(无前期成本,按小时付费)。在大规模(>1000个GPU,持续使用6+个月)下,拥有硬件更便宜(3年内TCO低约2-3倍)。大多数AI公司混合使用:自有集群用于持续训练,云用于突发容量和实验。 + +## 数据库扩展 + +- **只读副本**:将读取查询路由到主数据库的副本。主库处理写入,副本处理读取。由于大多数工作负载是读取密集型的(95%+读取),这使读取吞吐量随副本数量线性扩展。 +- **分区**(分片,来自文件01):将数据分割到多个数据库。每个分区是独立的,支持并行读取和写入。挑战是跨分区查询(连接来自不同分片的数据)。 +- **连接池**:数据库有有限的连接容量。连接池(PostgreSQL的PgBouncer)在请求间复用连接,防止当数百个服务实例各自尝试连接时出现连接耗尽。 + +## 搜索和向量系统 + +### 文本搜索 + +- **倒排索引**:文本搜索的基础。对每个单词,存储包含该单词的文档列表。查询对每个查询词的列表求交集。**Elasticsearch**是标准:分布式、实时、支持全文搜索、聚合和地理空间查询。 +- **BM25**:标准文本检索评分函数。根据词频、逆文档频率和文档长度归一化对文档评分。简单而有效——对于关键词密集型查询仍然能与神经方法竞争。 + +### 向量搜索 + +- **向量数据库**存储嵌入(高维向量)并支持快速**近似最近邻(ANN)**搜索。给定一个查询嵌入,找到$k$个最相似的存储嵌入。 +- **FAISS**(Facebook AI相似性搜索):一个用于ANN搜索的库(而非数据库)。支持多种索引类型: + - **Flat**:精确搜索,$O(n)$。用于小数据集或作为基准。 + - **IVF**(倒排文件):将向量分区到簇中,仅搜索最近的簇。每个查询$O(n/k)$。 + - **HNSW**(分层可导航小世界):基于图。构建分层图,从粗到细导航。极快且准确,是大多数应用的默认选择。 + - **乘积量化(PQ)**:将向量压缩为紧凑编码以实现内存高效搜索。用准确度换取内存。 +- **托管向量数据库**:Pinecone、Weaviate、Milvus、Qdrant。它们处理FAISS不具备的扩展、复制和实时更新。 +- **对于RAG**(检索增强生成):用户查询→用文本编码器嵌入→搜索向量数据库以找到相关文档→将检索到的文档前置到LLM提示中。检索质量直接决定LLM响应的质量。 + +## 可观测性 + +- **可观测性**是从系统外部输出理解系统内部状态的能力。三大支柱: + +### 日志 + +- **结构化日志**(JSON)是可搜索和可解析的。非结构化日志("ERROR: something failed")则不是。始终记录:时间戳、服务名、请求ID(用于跨服务追踪)、严重级别和相关上下文。 +- **ELK栈**(Elasticsearch、Logstash、Kibana):标准日志流水线。Logstash收集和转换日志,Elasticsearch建立索引,Kibana可视化和搜索。 + +### 指标 + +- **指标**是随时间变化的数值测量:请求率、错误率、延迟百分位数、GPU利用率、队列深度。**Prometheus**从服务抓取指标;**Grafana**在仪表盘中可视化并设置告警。 +- **服务的RED方法**:**R**ate(请求/秒)、**E**rrors(错误率)、**D**uration(延迟)。为每个服务监控这些指标。 +- **资源的USE方法**:**U**tilisation(使用百分比)、**S**aturation(队列深度)、**E**rrors。为每个资源(CPU、GPU、内存、磁盘、网络)监控这些指标。 + +### 追踪 + +- **分布式追踪**跟踪单个请求跨多个服务的路径。用户请求命中API网关→特征服务→模型服务→后处理。一个**追踪**记录了每次跳转的时序,显示延迟花在哪里。 +- **OpenTelemetry**:追踪、指标和日志的开放标准。一次代码埋点,导出到任何后端(Jaeger、Zipkin、Datadog)。 + +## 可靠性 + +- **SLO**(服务等级目标):目标可靠性。"99.9%的请求在<200ms内完成。"这给出了具体的错误预算:0.1%的请求(每月约43分钟)可以慢或失败。 +- **SLI**(服务等级指标):测量指标。"过去5分钟的第99百分位延迟。" +- **SLA**(服务等级协议):有后果的合同承诺。"如果可用性低于99.95%,客户获得信用额度。" +- **错误预算**:如果你的SLO是99.9%,而你达到了99.99%,你就有进行风险变更(部署新模型、迁移数据库)的预算。如果你只有99.85%,冻结所有变更,专注于可靠性。错误预算将可靠性从抽象目标转化为可衡量的资源。 +- **混沌工程**:故意注入故障(杀死服务器、添加网络延迟、破坏数据)以测试系统是否能正确处理。Netflix的Chaos Monkey随机终止生产实例。如果系统保持运行,它就是有弹性的。如果崩溃了,你在用户之前发现了一个bug。 + +## CI/CD + +- **持续集成**:自动构建和测试每次代码变更。每次推送触发:lint、类型检查、单元测试、集成测试。任何失败,变更被拒绝。这能在bug到达生产之前捕获它们。 +- **持续部署**:自动部署通过CI的变更。部署策略: + - **蓝绿部署**:运行两个相同的环境(蓝色=当前,绿色=新版本)。将流量从蓝色瞬间切换到绿色。如果绿色失败,切换回蓝色(即时回滚)。 + - **金丝雀部署**:将一小部分流量(1-5%)路由到新版本。监控错误。如果指标良好,逐步增加流量。这限制了不良部署的影响范围。 + - **功能标志**:部署新代码但隐藏在标志后面。为部分用户启用该标志(内部测试人员,然后是beta用户,然后是所有用户)。将部署(代码上线)与发布(用户看到功能)解耦。 +- 对于ML:CI/CD包括模型特定的步骤。模型变更触发:单元测试(形状测试、梯度检查)、在保留集上评估(准确率不得下降)、影子部署(新旧模型并行运行,比较输出)和逐步推出(金丝雀从1%→100%)。 diff --git a/chapter 18: ML systems design/04. ML systems design.md b/chapter 18: ML systems design/04. ML systems design.md new file mode 100644 index 0000000..9834ca3 --- /dev/null +++ b/chapter 18: ML systems design/04. ML systems design.md @@ -0,0 +1,183 @@ +# ML系统设计 + +*ML系统设计将文件01-03中的基础设施模式应用于机器学习的特定挑战。本文件涵盖ML生命周期、数据管理、训练基础设施、模型评估、服务策略、特征工程、ML流水线和监控* + +- 像"为YouTube设计一个推荐系统"这样的系统设计面试问题并不是要求你描述推荐算法。它要求你设计**整个系统**:数据流水线、特征工程、模型训练、评估、服务、监控和迭代。本文件提供了框架。 + +## ML系统生命周期 + +![ML系统生命周期:从问题定义到部署和监控,持续迭代](../images/ml_lifecycle.svg) + +- 每个ML系统都遵循相同的生命周期,无论是垃圾邮件分类器还是基础模型: + +``` +问题定义 → 数据 → 特征 → 训练 → 评估 → 部署 → 监控 → 迭代 + ↑ │ + └────────────────────────────────────────────────────────┘ +``` + +### 问题定义 + +- 在接触数据或模型之前,先定义: + - **预测什么?**(点击概率、下一个令牌、目标边界框) + - **用户是谁?**(最终用户、内部分析师、其他ML模型) + - **约束是什么?**(延迟<100ms、离线批量操作可以、必须在设备上运行) + - **业务指标是什么?**(收入、参与度、准确率)以及ML指标如何与之关联? + - **基线是什么?**(启发式方法、基于规则的系统、现有模型)——你必须击败它才能证明ML系统的价值。 +- **常见错误**:在理解问题之前直接跳到模型架构。"我们应该使用Transformer"不是系统设计的答案。"我们需要在200ms内预测1000万个候选的点击概率,因此我们需要一个两阶段系统:快速检索然后一个小型排序模型"才是。 + +## 数据管理 + +### 数据收集和标注 + +- **显式标签**:人类标注数据(点击/不点击、目标边界框、对话质量评分)。昂贵(取决于复杂度,每个标签约$0.02-$10)、缓慢且主观。 +- **隐式标签**:从用户行为中推导标签。点击、停留时间、购买、跳过。廉价且丰富,但有噪声(点击不意味着满意;跳过不意味着不喜欢)。 +- **程序化标注**(Snorkel):编写标注函数(启发式方法、正则表达式、现有模型),对每个样本进行投票。统计汇总投票以产生概率标签。可扩展到数百万样本,具有中等准确度。 +- **主动学习**:模型识别最不确定的样本,并请求人工标注这些样本。这最大限度地提高了标注效率:1000个主动选择的标签可以匹配10000个随机标签的质量。 + +### 数据质量 + +- **数据验证**:检查每批传入数据的模式违反(字段缺失、类型错误)、分布偏移(平均值显著变化)和数量异常(预期100万行,收到50万行)。 +- **Great Expectations**和**TFX Data Validation**是定义数据期望并在违反时发出告警的工具。 +- **数据版本管理**:每次训练运行应该是可重现的。**DVC**(第15章)将数据文件与代码一起追踪。每个数据集版本获得一个哈希值;训练配置引用该哈希值。 + +### 特征存储 + +![特征存储:相同的特征计算同时供给离线存储(用于训练)和在线存储(用于服务),防止训练-服务偏差](../images/feature_store.svg) + +- **特征存储**(第15章)为训练和服务提供一致的特征。关键概念: + - **离线特征**:从批处理流水线(Spark)计算,存储在数据仓库中。在训练和批量推理期间使用。示例:用户过去30天的平均会话时长、商品的总购买次数。 + - **在线特征**:实时计算或预先计算并从低延迟存储(Redis、DynamoDB)提供服务。在实时推理期间使用。示例:用户最近的5个操作、当前购物车内容。 + - **训练-服务偏差**:如果特征计算在训练和服务之间不同,模型在推理时看到的特征值与训练时不同。特征存储通过对两者使用相同的计算来消除此问题。 + +## 训练基础设施 + +- 对于本书的读者,分布式训练在第6章(数据并行、模型并行、混合精度、缩放定律)中已有深入介绍。这里我们关注**系统**方面: +- **实验跟踪**(W&B、MLflow——第15章):每次训练运行记录超参数、指标、git提交、数据版本和硬件。这是模型版本控制的ML等价物。 +- **超参数调优**:自动化搜索超参数。方法:网格搜索(穷尽,昂贵)、随机搜索(出奇地有效)、贝叶斯优化(对目标建模,在改进可能性高的地方采样)和**ASHA**(异步连续减半:启动许多试验,早期淘汰表现不佳的)。 +- **训练流水线编排**(Airflow、Kubeflow——第15章):自动化数据准备→训练→评估→注册的序列。安排每日重新训练。在失败时发出告警。 + +## 模型评估 + +### 离线评估 + +- **保留测试集**:在模型训练时从未见过的数据上评估。标准做法,但如果测试集不代表生产数据,可能会产生误导。 +- **基于分片的评估**:在子组上评估(按用户人口统计、内容类型、语言、时间段)。一个总体准确率95%的模型可能在特定少数群体上的准确率只有70%——不可接受。 +- **回测**:对于时间序列或顺序预测,按时间顺序在历史数据上进行评估。使用截至时间$t$的数据训练,在$t$到$t + \Delta t$的数据上评估。避免使用未来数据进行训练导致的数据泄露。 + +### 在线评估 + +![A/B测试:将用户随机分为对照组(旧模型)和实验组(新模型),以统计显著性比较指标](../images/ab_testing.svg) + +- **A/B测试**:将实时流量随机分为对照组(旧模型)和实验组(新模型)。以统计显著性比较业务指标(收入、参与度、留存率)。评估ML变更的黄金标准。 + - **样本量**:你需要足够的数据来检测预期的效应量。点击率0.1%的改进需要数百万次展示才能以显著性检测到。 + - **时长**:运行至少一个完整周期(大多数产品1-2周)以捕获日-周效应。 + - **护栏指标**:监控不应变化的指标(页面加载时间、错误率、崩溃率)以及目标指标。一个增加收入但同时增加崩溃率的模型是净负面的。 +- **影子部署**:在生产中与新模型并行运行旧模型。两者接收相同的请求,但只有旧模型的预测会提供给用户。比较输出。这能在不影响用户的情况下捕获bug和质量问题。 +- **交错实验**:对于排序问题,将旧模型和新模型的结果交错在一个列表中。用户与交错列表交互,你测量哪个模型的结果获得更多参与。相比A/B测试需要更少的用户即可达到显著性。 + +## 模型服务 + +### 批量vs实时 + +- **批量推理**:预先计算所有可能输入的预测结果。存储在数据库/缓存中。从缓存提供服务。适用于:输入空间有限(每晚为所有用户推荐)、新鲜度不重要(每日预测即可)、延迟容忍度高。 +- **实时推理**:按需为每个请求计算预测结果。适用于:输入空间无限(任何用户查询)、新鲜度重要(立即为此特定查询进行预测)、延迟必须低。 +- 许多系统**两者都用**:批量预计算一组候选结果(便宜,覆盖80%的流量),实时处理其余部分(昂贵,覆盖尾部查询和新用户)。 + +### 模型版本管理和注册表 + +- **模型注册表**(MLflow、W&B、SageMaker)存储训练好的模型及其元数据: + - 版本号和训练日期。 + - 训练配置和数据版本。 + - 评估指标(准确率、延迟、内存使用)。 + - 阶段:开发→预发布→生产→归档。 +- **回滚**:如果新模型在生产中导致指标下降,立即恢复到前一个版本。注册表使这成为一键操作。 + +## 特征工程 + +- **特征工程**将原始数据转换为模型所需的输入。它通常是ML中杠杆率最高的活动:更好的特征能改进每个模型,而更好的模型受限于它们收到的特征。 + +### 在线vs离线特征 + +- **离线特征**是预先计算的,变化缓慢(用户人口统计、30天聚合)。由批处理流水线(Spark)计算,存储在特征存储中。 +- **在线特征**反映当前状态,变化迅速(购物车中的商品、最近操作、当前位置)。从事件流实时计算或从快速存储中查找。 +- **特征新鲜度**:某些特征需要秒级新鲜度(欺诈检测:此交易相对于最近5笔交易是否异常?)。其他的可以容忍小时级陈旧度(推荐:根据用户历史,该用户偏好什么类型?)。更新鲜的特征计算和服务更昂贵。 + +### 常见特征模式 + +- **计数特征**:时间窗口内的事件计数(过去7天的购买次数、过去24小时的登录次数)。 +- **嵌入特征**:分类变量的学习嵌入(用户嵌入、商品嵌入、查询嵌入)。这些是双塔模型和类似架构的输入。 +- **交叉特征**:两个或多个特征的组合(user_age × item_category)。捕获单个特征无法捕获的交互。 +- **时间特征**:自上次操作以来的时间、星期几、一天中的小时。捕获时间模式。 +- **聚合特征**:数值特征在某个组上的均值、中位数、最小值、最大值、标准差(此卖家的商品平均评分)。 + +## ML流水线 + +- ML流水线编排从数据到部署模型的整个工作流程: + +``` +数据摄入 → 验证 → 特征工程 → 训练 → 评估 → 注册 → 部署 → 监控 +``` + +- 每个步骤是编排器(Airflow、Kubeflow、Metaflow——第15章)中的一个任务。流水线: + - 按计划运行(每日重新训练)或触发运行(新数据可用)。 + - 是幂等的(重新运行产生相同结果)。 + - 有重试逻辑(如果特征计算失败,使用退避重试3次)。 + - 产生制品(训练好的模型、评估报告、特征统计),这些制品被版本化管理并存储。 +- **Metaflow**(Netflix/Outerbounds)特别适合ML:它对代码、数据和模型一起进行版本管理,支持相同代码的本地开发和云执行,并与K8s和AWS集成。 + +## 监控 + +- 我们在第15章(Prometheus、Grafana、告警)中介绍了监控基础。这里我们关注**ML特定的监控**: + +### 数据漂移 + +- **数据漂移**发生在传入数据的分布相对于训练数据发生变化时。在夏季数据上训练的模型可能在冬季数据上表现不佳(不同的用户行为、不同的产品可用性)。 +- **检测**:使用统计测试比较传入特征分布与训练分布: + - **KS检验**(Kolmogorov-Smirnov):比较两个经验分布。检验它们是否来自相同的底层分布。 + - **PSI**(总体稳定性指数):衡量分布偏移了多少。PSI < 0.1为稳定,0.1-0.25为中度偏移,> 0.25为显著偏移。 + - **嵌入漂移**:使用质心距离或MMD(最大均值差异)比较传入查询的嵌入分布与训练集。 + +### 概念漂移 + +- **概念漂移**发生在输入和输出之间的关系发生变化时。特征看起来相同,但正确的预测不同。示例:用户偏好在一场文化活动、流行病或产品变更后发生转变。 +- 概念漂移比数据漂移更难检测,因为它需要带标签的数据。监控代理指标:点击率、转化率、用户满意度评分。持续下降可能表明概念漂移。 + +### 模型退化 + +- 模型会因多种原因随时间退化:数据漂移、概念漂移、特征流水线错误(特征开始返回空值)以及上游数据变化(第三方API更改其响应格式)。 +- **响应**:检测到退化时,行动取决于严重程度: + - 轻度:在最近数据上重新训练(定时重新训练可处理此情况)。 + - 中度:调查根本原因(哪个特征发生了变化?哪个用户群体受到影响?)。 + - 严重:立即回滚到以前的模型版本,然后调查。 + +### 反馈循环 + +- ML系统创建**反馈循环**:模型的预测影响用户行为,后者成为下一个模型版本的训练数据。这些循环可能是良性的,也可能是恶性的。 +- **正反馈循环**(危险的):推荐模型主要展示热门商品→用户点击热门商品(因为他们只看到这些)→模型了解到热门商品更受欢迎→多样性崩溃。模型创造了确认其偏见的数据。 +- **负反馈循环**(也危险的):欺诈检测模型捕获了所有A类欺诈→没有A类欺诈进入训练数据→下一个模型未学会检测A类→A类欺诈重新出现。 +- **缓解措施**: + - **探索**:展示一些模型不确定的商品(epsilon-greedy、Thompson采样)。这生成了多样化的训练数据。 + - **反事实日志记录**:记录模型*本会*预测的结果,而不仅仅是用户看到的结果。在反事实数据上训练以消除偏差。 + - **保留集**:随机将一部分流量用于无模型过滤的服务。未经过滤的数据为评估模型质量提供了真实依据。 + - **延迟标签**:在使用数据训练之前等待真实结果。今天点击的推荐可能明天就后悔。欺诈预测必须等待退款窗口(30-90天)。 + +### 嵌入表管理 + +- 大规模ML系统通常有包含1亿+条目的嵌入表(每个用户、商品、广告或实体一个嵌入)。大规模管理这些是系统挑战: +- **存储**:1亿实体×256维×float16 = 50 GB。不适合GPU内存。解决方案:存储在CPU内存中并配合GPU端缓存,跨多台机器分片,或使用**哈希嵌入**(将实体哈希到固定大小的表,接受冲突)。 +- **更新**:嵌入随模型重新训练而变化。向服务部署新的嵌入表需要:在不中断实时流量时加载50 GB到内存,验证正确性,以及在指标下降时回滚。对嵌入表使用蓝绿部署。 +- **陈旧度**:新创建的用户没有嵌入(冷启动问题)。解决方案:使用默认嵌入,通过特征到嵌入模型从用户特征派生嵌入,或回退到非个性化模型。 + +### 公平性和偏见 + +- ML系统可能会系统性地对待不同群体不同,通常反映了训练数据中的偏见。**公平性监控**是一种责任,不是可选功能。 +- **监控指标**: + - **人口统计均等**:不同群体(性别、种族、年龄)的正预测率是否不同? + - **均等机会**:不同群体的真阳性率是否不同?(招聘模型应该同样擅长识别所有群体的合格候选人。) + - **校准**:如果模型说P(合格) = 0.7对于群体A,那么群体A中实际上有70%是合格的吗?对于群体B也是同样? +- **实际步骤**: + - 在分片(子组)上评估模型性能,而不仅仅是总体指标。 + - 在模型评估流水线中纳入公平性指标(一个提高总体准确率但降低特定群体性能的模型未经审查不应部署)。 + - 记录已知的限制和失败模式。 + - 为在敏感领域(招聘、贷款、刑事司法、医疗)部署的模型建立审查流程。 diff --git a/chapter 18: ML systems design/05. ML design examples.md b/chapter 18: ML systems design/05. ML design examples.md new file mode 100644 index 0000000..156a493 --- /dev/null +++ b/chapter 18: ML systems design/05. ML design examples.md @@ -0,0 +1,337 @@ +# ML设计示例 + +*学习ML系统设计的最佳方式是通过实操示例。本文件详细介绍了七个完整的设计:推荐系统、搜索排序、广告点击预测、欺诈检测、内容审核、对话式AI和大规模图像搜索* + +- 每个示例遵循一致的框架: + 1. **问题定义**:我们在构建什么,用户是谁,约束是什么? + 2. **数据**:我们有什么数据,如何收集,如何标注? + 3. **特征**:模型需要什么特征? + 4. **模型**:什么架构和训练方法? + 5. **服务**:模型如何部署和提供服务? + 6. **评估**:我们如何衡量成功? + 7. **迭代**:随着时间的推移,我们会做哪些改进? + +--- + +## 1. 推荐系统(例如YouTube、Netflix、Spotify) + +### 问题定义 + +- **目标**:向用户展示他们会喜欢的内容,最大化参与度(观看时间、收听次数、点击量)。 +- **规模**:10亿+用户,1亿+项目,每秒10K+推荐。 +- **延迟**:完整推荐流水线<200ms。 +- **关键挑战**:候选空间巨大(1亿个项目)。无法为所有用户实时评分所有项目。 + +### 架构:两阶段流水线 + +![推荐流水线:1亿个项目通过候选生成缩小到1000个,排序到100个,重新排序到20个展示项目](../images/recommendation_pipeline.svg) + +``` +1亿个项目 → 候选生成(快速、粗略)→ 1000个候选 + → 排序(缓慢、精确)→ 100个排序项目 + → 重新排序(业务规则)→ 展示给用户的20个 +``` + +### 候选生成 + +- **目标**:将1亿个项目减少到约1000个候选。必须快速(<50ms)。 +- **双塔模型**:将用户和项目编码到相同的嵌入空间。用户嵌入捕获偏好;项目嵌入捕获内容特征。得分 = 用户嵌入和项目嵌入的点积。 +- **训练**:在(用户、正样本、负样本)三元组上进行对比学习。正样本=用户参与过的项目。负样本=随机项目+难负样本(用户未参与过的热门项目)。 +- **服务**:预先计算所有项目嵌入。在请求时:计算用户嵌入,ANN搜索(向量数据库中的HNSW)以找到最近的1000个项目嵌入。 + +### 排序 + +- **目标**:精确评分1000个候选。可以花费约100ms。 +- **模型**:一个深度神经网络(MLP或Transformer),使用丰富特征:用户特征(人口统计、历史、上下文)、项目特征(内容、流行度、新鲜度)和交叉特征(用户-项目交互历史、上下文相关性)。 +- **输出**:预测的参与概率(点击、观看50%+、点赞、分享)。多个目标可以组合:$\text{score} = w_1 \cdot P(\text{点击}) + w_2 \cdot P(\text{观看}) + w_3 \cdot P(\text{点赞})$。 + +### 重新排序 + +- 应用业务规则:多样性(不展示来自同一创作者的5个视频)、新鲜度(提升新内容)、安全(过滤被标记的内容)和个性化探索(展示一些用户可能发现的低排名项目)。 + +### 粗略估算数字 + +- **项目嵌入索引**:1亿个项目×256维×float16 = 50 GB。HNSW索引增加约2倍开销→约100 GB。适合具有128 GB内存的单台机器,或分片到4×32 GB机器。 +- **用户嵌入计算**:每个用户约5ms(小型MLP处理用户特征)。在10K QPS下,需要约50个模型副本处理负载。 +- **ANN搜索**:使用HNSW从1亿个向量中搜索前1000个约需2ms。在10K QPS下,每个索引副本处理约500 QPS→需要20个副本。 +- **排序模型**:1000个候选×每个候选约0.1ms = 每次请求100ms。在10K QPS下,需要每秒1000 GPU秒→仅排序就需要约10个A10G GPU。 +- **总基础设施**:约20个嵌入索引副本+约50个用户嵌入服务器+约10个排序GPU+缓存+负载均衡器。成本:云价格下每月约$5万-$10万。 + +### 冷启动 + +- **新用户**(无历史记录):使用人口统计特征、设备/位置上下文和基于流行度的推荐。经过5-10次交互后,切换到个性化模型。 +- **新项目**(无参与数据):使用基于内容的特征(标题、描述、缩略图嵌入)。分配探索预算:向一部分用户展示新项目以快速收集参与数据。在经过提升期后仍无参与的项目被降级。 +- **冷启动是系统问题**:特征存储必须优雅地处理缺失特征(返回默认值,而不是错误)。模型必须使用缺失特征进行训练(训练期间对用户历史特征进行dropout可以模拟新用户)。 + +### 评估 + +- **离线**:NDCG(归一化折损累计增益)、Recall@K、Precision@K。 +- **在线**:测量观看时间、DAU、留存的A/B测试。长期A/B测试(数周)以捕获短期测试无法观察到的用户留存效应。 + +--- + +## 2. 搜索排序(例如Google、Bing) + +### 问题定义 + +- **目标**:给定用户查询,从数十亿文档的语料库中返回最相关的结果。 +- **延迟**:总计<500ms(检索100ms + 排序200ms + 渲染100ms + 开销)。 + +### 架构:查询理解→检索→排序 + +### 查询理解 + +- 在检索之前,处理原始查询以改进结果: +- **拼写纠正**:"reccomendation systm"→"recommendation system"。使用编辑距离模型或序列到序列模型,在(拼写错误,纠正)对上训练,数据来自搜索日志。 +- **查询扩展**:添加相关术语以提高召回率。"Python ML"→"Python machine learning scikit-learn pytorch。"使用同义词词典、词嵌入或LLM生成扩展。 +- **意图分类**:确定用户想要什么。"buy Nike shoes"是**交易型**(展示产品页面)。"How does backpropagation work"是**信息型**(展示文章)。"facebook.com"是**导航型**(直接转到网站)。不同意图应触发不同的检索策略和结果布局。 +- **实体识别**:从查询中提取实体。"best restaurants near Times Square"→位置:"Times Square",实体类型:"restaurants。"路由到位置感知搜索流水线。 + +### 检索 + +- **BM25**(传统):使用倒排索引进行词匹配检索。快速,对关键词查询有效。没有语义理解("dog food"不匹配"canine nutrition")。 +- **稠密检索**:将查询和文档编码为嵌入(使用如DPR或ColBERT的双编码器)。通过ANN搜索检索。捕获语义相似性("dog food"匹配"canine nutrition")。比BM25慢,但对于自然语言查询更好。 +- **混合检索**:结合BM25和稠密检索。BM25找到精确关键词匹配;稠密检索找到语义匹配。合并并去重。两全其美。 + +### 排序 + +- **学习排序**:一个模型对每个(查询,文档)对评分。三种方法: + - **点式**:独立预测每个文档的相关性分数。简单但忽略相对顺序。 + - **成对式**:预测两个文档中哪个更相关。LambdaMART(梯度提升树)是经典方法。 + - **列表式**:直接针对列表级指标(NDCG)优化整个排序列表。更复杂但结果最好。 +- **交叉编码器**:一个以`[查询,文档]`为输入并输出相关性分数的Transformer。比双编码器更准确(后者独立编码查询和文档),因为它捕获了细粒度的交互。但对于完整语料库来说太慢——仅用于对检索前100-1000个候选进行重新排序。 + +### 特征 + +- **查询特征**:查询长度、语言、意图分类(导航型、信息型、交易型)。 +- **文档特征**:PageRank、新鲜度、内容质量分数、域名权威性。 +- **查询-文档特征**:BM25分数、嵌入相似度、精确匹配数、历史日志中此(查询,文档)对的点击率。 + +--- + +## 3. 广告点击预测 + +### 问题定义 + +- **目标**:预测用户点击广告的概率。这决定在实时拍卖中出价多少。 +- **规模**:每秒100K+次拍卖,每次预测需在10ms内完成。 +- **收入影响**:点击预测准确率提高0.1%就相当于数百万的额外收入。 + +### 架构 + +- **特征工程**是广告系统的核心。特征包括: + - **用户特征**:人口统计、浏览历史、购买历史、设备、位置、一天中的时间。 + - **广告特征**:创意(图片/文字)、广告主、类别、历史CTR、出价金额。 + - **上下文特征**:页面内容、广告位置、设备类型、连接速度。 + - **交叉特征**:user_category×ad_category交互,user_region×ad_campaign交互。 +- **模型**:历史上用逻辑回归(简单、快速、可解释)。现代系统使用深度学习:**DLRM**(深度学习推荐模型),对分类特征使用嵌入表,对稠密特征使用MLP。 +- **校准**:预测概率必须准确(如果模型说P(点击)=0.05,那么实际上应该有5%的展示被点击)。校准至关重要,因为预测概率直接决定出价金额。 +- **探索-利用**:总是展示预测的最佳广告在长期来看是次优的(你永远无法发现新广告可能更好)。Thompson采样或$\epsilon$-greedy探索确保有一部分展示分配给不确定性较高的广告以收集数据。 + +### 实时竞价 + +- 当用户加载页面时,广告拍卖在<100ms内进行: + 1. 发布者向多个广告交易平台发送竞价请求(用户信息、页面上下文)。 + 2. 每个广告主的竞价服务器预测其广告的CTR。 + 3. 出价 = CTR × 每次点击的价值。出价高的赢得拍卖。 + 4. 获胜的广告被展示;如果被点击,广告主付费。 + +--- + +## 4. 欺诈检测 + +### 问题定义 + +- **目标**:实时检测欺诈性交易(信用卡欺诈、账户盗用、虚假评论)。 +- **延迟**:<100ms(交易必须在支付处理前被批准或标记)。 +- **关键挑战**:极端类别不平衡(欺诈率0.1%)。误报会阻止合法用户;漏报会造成金钱损失。 + +### 架构 + +![欺诈检测流水线:交易→实时特征流水线→ML模型→决策引擎→允许/审核/阻止,人工审核将标签反馈用于重新训练](../images/fraud_detection_pipeline.svg) + +### 特征 + +- **交易特征**:金额、货币、商户类别、一天中的时间、是否跨国。 +- **用户特征**:账户年龄、平均交易金额、近期交易次数、设备指纹。 +- **速度特征**(实时,来自流处理流水线):过去5分钟内的交易次数、过去1小时内的不同商户数、与上次交易的地理距离。 +- **图特征**:此商户是否与已知欺诈团伙有关联?此设备是否与被标记账户共享? + +### 模型 + +- **梯度提升树**(XGBoost、LightGBM)是表格数据欺诈检测的标准。它们处理混合特征类型、可解释(特征重要性)且训练快速。 +- **处理不平衡**:对多数类进行欠采样、对少数类进行过采样(SMOTE),或在损失函数中使用类别权重。Focal loss(第8章)降低简单负样本的权重。 +- **成本矩阵**:误报(阻止合法交易)有成本(用户挫败感、销售损失)。漏报(遗漏欺诈)有不同的成本(财务损失)。决策阈值应最小化总预期成本,而非最大化准确率。 + +### 人在回路中 + +- 不确定的预测(模型置信度在0.3和0.7之间)发送给人工审核员。审核员的决策成为重新训练的标签。这创建了一个反馈循环:随着模型看到更多标记的欺诈案例,它随着时间的推移而改进。 + +--- + +## 5. 内容审核 + +### 问题定义 + +- **目标**:自动检测并移除有害内容(仇恨言论、暴力、虚假信息、CSAM)从一个平台。 +- **规模**:每天数十亿条帖子(文本、图片、视频)。 +- **挑战**:上下文依赖(讽刺、戏仿、文化细微差别)。必须在言论自由和安全之间取得平衡。 + +### 架构 + +- **多模态分类**:文本、图片和视频分别使用单独的模型,加上融合层组合它们的信号。 +- **文本审核**:微调的语言模型将文本分类为类别(骚扰、仇恨言论、虚假信息、垃圾信息)。多语言模型处理100+种语言。 +- **图片审核**:视觉模型检测:露骨内容(裸体、暴力)、图片中的文字(OCR+文本分类器)和已知有害内容(哈希匹配与已知CSAM数据库进行比对)。 +- **视频审核**:按固定间隔采样帧,对每帧运行图像分类器,结合音频转录(ASR→文本分类器)。 +- **策略即代码**:审核策略以结构化规则定义,将模型输出映射到操作: + +```python +if text_model.hate_speech_score > 0.9: + action = "remove" +elif text_model.hate_speech_score > 0.7: + action = "human_review" +else: + action = "allow" +``` + +- 策略频繁更改(新法规、不断发展的规范)。将策略与模型分离确保可以在不重新训练的情况下部署更改。 + +### 主动vs被动审核 + +- **主动审核**(发布前):在内容上线前运行分类器。高置信度违规自动阻止。这防止了有害内容被看到,但会增加发布延迟并存在误报风险(阻止合法内容)。 +- **被动审核**(发布后):内容立即上线。用户可以举报违规。举报触发分类器+人工审核。发布者延迟低,但有害内容在检测到之前是可见的。 +- **大多数平台两者都用**:对高严重性类别(CSAM:零容忍,发布前阻止)使用主动审核,对细微类别(虚假信息:需要人工判断,收到举报后审核)使用被动审核。 + +### 哈希匹配 + +- 对于已知有害内容(CSAM、恐怖主义宣传),使用**感知哈希**:计算对微小修改(裁剪、调整大小、压缩)鲁棒的图像/视频哈希值。与已知有害内容数据库(NCMEC的哈希数据库、GIFCT共享哈希数据库)进行比较。匹配→立即移除,无需分类器。 +- **PhotoDNA**(微软)是CSAM检测的标准感知哈希。在许多司法管辖区这不仅是技术选择,更是法律义务。 + +### 粗略估算数字 + +- **规模**:每天10亿条帖子=约12K帖子/秒。每个帖子需要:文本分类(约5ms)、图片分类(约20ms)、哈希匹配(约1ms)。在12K QPS下:需要约60个文本分类器、约240个图片分类器和约12个哈希匹配器(加上冗余)。 +- **人工审核**:如果2%的帖子被标记审核=每天2000万条。以每人每天100条审核计,需要20万审核员(这就是自动化准确率至关重要的原因:误报每降低0.1%就能每天节省100万条审核)。 +- **延迟预算**:主动审核必须在发布流水线内完成(约500ms)。文本(5ms)+ 图片(20ms)+ 哈希(1ms)+ 开销=远在预算之内。视频是例外:即使从10分钟视频中每秒采样1帧,也需要600次分类器调用→异步处理。 + +### 升级工作流 + +- 自动移除→人工审核上诉→专家审核(法律、文化专家)→政策团队处理模糊案例。每个级别处理的案例更少但更细致。 +- **反馈给模型**:人工审核决策是重新训练的最高质量标签。模型和审核员之间的分歧被优先用于主动学习——它们代表了模型处理最差的案例。 + +--- + +## 6. 对话式AI(基于RAG的聊天机器人) + +### 问题定义 + +- **目标**:一个能回答关于公司产品问题的聊天机器人,使用其文档。 +- **要求**:准确(不产生幻觉)、引用来源、处理后续问题、保持在产品领域内。 + +![RAG架构:嵌入查询,搜索向量数据库寻找相关块,重新排序,与原始查询一起输入LLM以生成有依据的响应](../images/rag_architecture.svg) + +### 架构:检索增强生成(RAG) + +``` +用户查询 → 查询嵌入 → 向量搜索(文档)→ Top-K块 + ↓ +用户查询 + 检索到的块 → LLM → 响应(含引用) +``` + +### 组件 + +- **文档摄入**:将文档分块并嵌入。**分块策略**非常重要: + - **固定大小分块**:每N个令牌(如500)分割,M个令牌(如50)重叠。简单,块大小可预测,但可能在句子中间或段落中间分割,丢失上下文。 + - **语义分块**:在段落或章节边界分割。每个块是一个连贯的信息单元。大小可变(有些块100个令牌,其他800个),需要检索系统处理可变长度。 + - **递归分块**:尝试在段落边界分割。如果段落太长,在句子边界分割。如果句子太长,在固定大小分割。连贯性和大小一致性的最佳平衡。 + - **嵌入**:用文本编码器(如E5、BGE、Cohere embed)嵌入每个块。存储在向量数据库中。 +- **检索**:嵌入用户查询,搜索向量数据库中最相似的$k$个块(通常$k = 5$-$10$)。可选地使用交叉编码器重新排序以提高精度。 +- **生成**:构建包含检索块作为上下文的提示: + +``` +系统:你是一个有用的助手。仅基于提供的上下文回答。 +如果答案不在上下文中,请说"我不知道。" + +上下文: +[块 1] +[块 2] +... + +用户:{问题} +``` + +- **护栏**:防止LLM回答产品领域外的问题、生成有害内容或与检索到的上下文相矛盾。实现为:输入过滤(拒绝离题查询)、输出过滤(检查响应是否与检索到的上下文一致)和宪法提示(指示模型拒绝某些请求)。 +- **对话记忆**:维护最近$n$轮对话。将其包含在提示中,使模型能理解后续问题("定价如何?"→需要关于哪个产品的先前上下文)。 + +### 查询重写 + +- 用户经常问模糊的后续问题:"定价如何?"(什么产品的定价?)。**查询重写**使用对话历史生成独立查询: + - 输入:对话历史 + "定价如何?" + - 重写后:"产品X的企业版定价是多少?" +- 这个重写后的查询才是被嵌入并在向量数据库中搜索的。如果没有重写,检索会搜索"定价"而没有上下文,返回不相关的块。 +- 查询重写可以用小型LLM调用(约50ms)或微调的序列到序列模型(约5ms)完成。 + +### 粗略估算数字 + +- **文档语料库**:10K页,每页平均2000个令牌=2000万令牌。以每块500个令牌、50个重叠计=约44K个块。 +- **嵌入索引**:44K块×768维×float16=约65 MB。轻松适合内存。即使1000万个块也仅约15 GB。 +- **延迟分解**:查询嵌入(5ms)+ 向量搜索(2ms)+ 交叉编码器重新排序(前50个20ms)+ LLM生成(500-2000ms)= 总计约600-2100ms。LLM占主导地位。使用流式传输减少感知延迟。 +- **成本**:以$3/100万令牌(Claude/GPT-4 API)计,每天1000次查询、每次约2000个令牌=约$6/天。大规模(每天100万次查询)下,在2个A10G GPU上自托管7B模型(约$50/天)可实现100倍成本降低。 + +### 评估 + +- **检索质量**:Recall@K(前K个块是否包含答案?)、MRR(平均倒数排名)。 +- **生成质量**:事实准确性(响应是否匹配检索到的上下文?)、有依据性(响应是否引用了正确块?)、答案相关性。 +- **端到端**:用户满意度(赞/踩)、转接给人工客服的比率。 + +--- + +## 7. 大规模图像搜索 + +### 问题定义 + +- **目标**:给定一张图像,从10亿+图像的语料库中找到视觉上相似的图像。 +- **应用**:反向图像搜索、产品搜索(照片→匹配的产品)、重复检测。 +- **延迟**:<500ms(包括网络往返时间)。 + +### 架构 + +``` +查询图像 → 嵌入模型(ViT/CLIP)→ 512维向量 → ANN搜索 → Top-K结果 +``` + +### 嵌入提取 + +- **模型**:预训练的视觉编码器(ViT、CLIP的图像编码器、DINOv2)。如果需要,在特定领域(时尚、电商、医学影像)上进行微调。 +- **训练**:对比学习(第10章)。正样本对=同一图像的不同视角(或图像+匹配的文本)。负样本对=随机图像。模型学习为相似图像生成相似嵌入,为不同图像生成不同嵌入。 + +### 索引 + +- **离线**:嵌入所有10亿张图像并构建ANN索引。对于HNSW(文件03),构建索引需要数小时,索引存储在内存中(10亿×512维×float16 + 图开销约128 GB)。 +- **分片**:将索引拆分成跨多台机器。每台机器持有一个分片。查询时,并行搜索所有分片并合并前K个结果。 +- **增量更新**:新图像(上传、新产品)必须添加到索引中。HNSW支持增量插入而无需重建。向量数据库(Milvus、Pinecone)原生处理此需求。 + +### 服务 + +- **嵌入服务**:运行ViT模型的GPU服务器。延迟:每张图像约20ms。批量处理多个查询以提高吞吐量。 +- **搜索服务**:ANN索引服务器。延迟:对于10亿向量中搜索前100个(使用HNSW)约10ms。 +- **缓存**:缓存热门查询的结果。对于重复检测,缓存最近上传的图像的嵌入,在搜索完整索引之前将新上传与缓存进行比较。 + +### 评估 + +- **Precision@K**:前K个结果是否实际相似? +- **Recall@K**:在语料库中所有真正相似的图像中,有多少在前K个中? +- **平均精度均值(mAP)**:精确率-召回率曲线下的面积。 +- **人工评估**:对于主观相似性,人工评分员判断检索到的图像是否相关。 + +--- + +## 面试框架 + +- 当你遇到系统设计问题时,遵循此框架: +1. **澄清需求**(2-3分钟):询问规模、延迟、一致性要求和边缘情况。"多少用户?可接受的延迟是多少?故障时会发生什么?" +2. **高层设计**(5-7分钟):画出主要组件及其交互。从正常路径开始。使用文件01-03中的模式。 +3. **深入探讨**(15-20分钟):选择最有趣/最具挑战性的组件并详细设计。这是你展示深度的地方。对于ML系统,深入探讨通常涉及:模型架构、特征流水线或服务架构。 +4. **评估和监控**(3-5分钟):你如何衡量成功?可能出什么问题?你如何检测和响应问题? +5. **迭代**(2-3分钟):如果有更多时间/资源,你会改进什么?这表明你理解权衡并能设定优先级。 + +- **面试官看中的**:结构化思维(不是直接跳到解决方案)、权衡意识(每个选择都有代价)、实践知识(你确实构建过系统)和沟通能力(你能清晰解释你的设计吗?)。 diff --git a/chapter 19: applied AI/01. AI for finance.md b/chapter 19: applied AI/01. AI for finance.md new file mode 100644 index 0000000..14c9c0d --- /dev/null +++ b/chapter 19: applied AI/01. AI for finance.md @@ -0,0 +1,10 @@ +# AI for Finance + +- 时间序列预测:ARIMA、指数平滑、Prophet、神经网络方法(LSTM、Temporal Fusion Transformer、PatchTST) +- 算法交易:信号生成、执行算法(TWAP、VWAP)、市场微观结构 +- 投资组合优化:均值-方差(Markowitz)、Black-Litterman、基于RL的投资组合管理 +- 风险建模:Value at Risk (VaR)、Expected Shortfall、蒙特卡洛模拟、信用评分 +- 欺诈检测:异常检测、基于图的方法、实时流处理 +- 金融中的NLP:新闻/财报电话会议的情绪分析、金融文档理解 +- 另类数据:卫星图像、社交媒体、网络爬虫 +- 监管与伦理:模型可解释性(SHAP、LIME)、信贷公平性、监管合规 diff --git a/chapter 19: applied AI/02. protein design.md b/chapter 19: applied AI/02. protein design.md new file mode 100644 index 0000000..958aa4c --- /dev/null +++ b/chapter 19: applied AI/02. protein design.md @@ -0,0 +1,10 @@ +# AI for Biology + +- 蛋白质结构预测:AlphaFold 1/2/3、ESMFold、共进化分析、MSA transformers +- 蛋白质设计:逆折叠(ProteinMPNN)、用于蛋白质生成的扩散模型(RFDiffusion)、幻觉(hallucination) +- 药物发现:分子表示(SMILES、图)、分子性质预测、虚拟筛选、分子对接 +- 生成化学:分子生成(VAE、GAN、扩散模型)、逆合成预测 +- 基因组学:DNA序列建模(Enformer、Hyena DNA)、变异效应预测、CRISPR引导设计 +- 单细胞分析:scRNA-seq、细胞类型聚类、轨迹推断 +- 医学影像:放射学(CheXNet)、病理学(全切片图像)、分割(nnU-Net) +- 临床NLP:医学实体提取、临床试验匹配、电子健康记录 diff --git a/chapter 19: applied AI/03. drug discovery.md b/chapter 19: applied AI/03. drug discovery.md new file mode 100644 index 0000000..e69de29 diff --git a/chapter 19: applied AI/04. agentic systems.md b/chapter 19: applied AI/04. agentic systems.md new file mode 100644 index 0000000..e69de29 diff --git a/chapter 19: applied AI/05. healthcare.md b/chapter 19: applied AI/05. healthcare.md new file mode 100644 index 0000000..e69de29 diff --git a/chapter 20: bleeding edge AI/01. quantum machine learning.md b/chapter 20: bleeding edge AI/01. quantum machine learning.md new file mode 100644 index 0000000..4246898 --- /dev/null +++ b/chapter 20: bleeding edge AI/01. quantum machine learning.md @@ -0,0 +1,11 @@ +# 量子机器学习 (Quantum Machine Learning) + +- 量子计算基础:量子比特 (qubit)、叠加 (superposition)、纠缠 (entanglement)、测量 (measurement) +- 量子门:泡利门 (Pauli X, Y, Z)、哈达玛门 (Hadamard)、CNOT 门、托佛利门 (Toffoli)、旋转门 (rotation gates) +- 量子电路:电路模型 (circuit model)、参数化电路 (parameterised circuits)、深度与宽度 (depth and width) +- 变分量子算法:VQE、QAOA、变分分类器 (variational classifiers) +- 量子核方法:量子特征映射 (quantum feature maps)、量子支持向量机 (quantum support vector machines) +- 量子神经网络:作为神经层的参数化量子电路 (parameterised quantum circuits as neural layers) +- 贫瘠高原 (barren plateaus):量子电路中的梯度消失 (vanishing gradients)、可表达性与可训练性 (expressibility vs trainability) +- 量子优势辩论:NISQ 时代局限性 (NISQ era limitations)、容错量子计算时间线 (fault-tolerant quantum computing timeline) +- 混合经典-量子架构:经典流水线中的量子层 (quantum layers in classical pipelines) diff --git a/chapter 20: bleeding edge AI/02. neuromorphic computing.md b/chapter 20: bleeding edge AI/02. neuromorphic computing.md new file mode 100644 index 0000000..3364c12 --- /dev/null +++ b/chapter 20: bleeding edge AI/02. neuromorphic computing.md @@ -0,0 +1,10 @@ +# 神经形态计算 (Neuromorphic Computing) + +- 生物学启发:脉冲神经元 (spiking neurons)、突触可塑性 (synaptic plasticity)、时间编码 (temporal coding) +- 脉冲神经网络 (SNN):整合发放模型——LIF、IF (integrate-and-fire models)、脉冲时序 (spike timing) +- SNN 中的学习:STDP(脉冲时序依赖可塑性)、代理梯度方法 (surrogate gradient methods)、从 ANN 转换 (conversion from ANNs) +- 神经形态硬件:Intel Loihi 2、IBM TrueNorth、SpiNNaker、BrainScaleS +- 事件驱动计算:异步处理 (asynchronous processing)、高能效 (energy efficiency) +- 事件相机 (DVS):神经形态视觉传感器 (neuromorphic vision sensors)、稀疏时序数据 (sparse temporal data) +- 应用:低功耗边缘推理 (low-power edge inference)、机器人 (robotics)、始终在线传感 (always-on sensing) +- 与传统深度学习对比:延迟 (latency)、功耗 (power)、精度 (accuracy) 的权衡 diff --git a/chapter 20: bleeding edge AI/03. datacentres in space.md b/chapter 20: bleeding edge AI/03. datacentres in space.md new file mode 100644 index 0000000..e69de29 diff --git a/chapter 20: bleeding edge AI/04. decentralised AI.md b/chapter 20: bleeding edge AI/04. decentralised AI.md new file mode 100644 index 0000000..e69de29 diff --git a/chapter 20: bleeding edge AI/05. brain machine interfaces.md b/chapter 20: bleeding edge AI/05. brain machine interfaces.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/chapter 01: vectors b/docs/chapter 01: vectors new file mode 120000 index 0000000..0efbeac --- /dev/null +++ b/docs/chapter 01: vectors @@ -0,0 +1 @@ +../chapter 01: vectors \ No newline at end of file diff --git a/docs/chapter 02: matrices b/docs/chapter 02: matrices new file mode 120000 index 0000000..3d2c310 --- /dev/null +++ b/docs/chapter 02: matrices @@ -0,0 +1 @@ +../chapter 02: matrices \ No newline at end of file diff --git a/docs/chapter 03: calculus b/docs/chapter 03: calculus new file mode 120000 index 0000000..2eb20ea --- /dev/null +++ b/docs/chapter 03: calculus @@ -0,0 +1 @@ +../chapter 03: calculus \ No newline at end of file diff --git a/docs/chapter 04: statistics b/docs/chapter 04: statistics new file mode 120000 index 0000000..6b463a1 --- /dev/null +++ b/docs/chapter 04: statistics @@ -0,0 +1 @@ +../chapter 04: statistics \ No newline at end of file diff --git a/docs/chapter 05: probability b/docs/chapter 05: probability new file mode 120000 index 0000000..6d8309a --- /dev/null +++ b/docs/chapter 05: probability @@ -0,0 +1 @@ +../chapter 05: probability \ No newline at end of file diff --git a/docs/chapter 06: machine learning b/docs/chapter 06: machine learning new file mode 120000 index 0000000..cb9aad6 --- /dev/null +++ b/docs/chapter 06: machine learning @@ -0,0 +1 @@ +../chapter 06: machine learning \ No newline at end of file diff --git a/docs/chapter 07: computational linguistics b/docs/chapter 07: computational linguistics new file mode 120000 index 0000000..83cbf9c --- /dev/null +++ b/docs/chapter 07: computational linguistics @@ -0,0 +1 @@ +../chapter 07: computational linguistics \ No newline at end of file diff --git a/docs/chapter 08: computer vision b/docs/chapter 08: computer vision new file mode 120000 index 0000000..a4877cf --- /dev/null +++ b/docs/chapter 08: computer vision @@ -0,0 +1 @@ +../chapter 08: computer vision \ No newline at end of file diff --git a/docs/chapter 09: audio and speech b/docs/chapter 09: audio and speech new file mode 120000 index 0000000..3be2dc4 --- /dev/null +++ b/docs/chapter 09: audio and speech @@ -0,0 +1 @@ +../chapter 09: audio and speech \ No newline at end of file diff --git a/docs/chapter 10: multimodal learning b/docs/chapter 10: multimodal learning new file mode 120000 index 0000000..0557e8a --- /dev/null +++ b/docs/chapter 10: multimodal learning @@ -0,0 +1 @@ +../chapter 10: multimodal learning \ No newline at end of file diff --git a/docs/chapter 11: autonomous systems b/docs/chapter 11: autonomous systems new file mode 120000 index 0000000..d4ffcd0 --- /dev/null +++ b/docs/chapter 11: autonomous systems @@ -0,0 +1 @@ +../chapter 11: autonomous systems \ No newline at end of file diff --git a/docs/chapter 12: graph neural networks b/docs/chapter 12: graph neural networks new file mode 120000 index 0000000..10c2e20 --- /dev/null +++ b/docs/chapter 12: graph neural networks @@ -0,0 +1 @@ +../chapter 12: graph neural networks \ No newline at end of file diff --git a/docs/chapter 13: computing and OS b/docs/chapter 13: computing and OS new file mode 120000 index 0000000..ffa78a3 --- /dev/null +++ b/docs/chapter 13: computing and OS @@ -0,0 +1 @@ +../chapter 13: computing and OS \ No newline at end of file diff --git a/docs/chapter 14: data structures and algorithms b/docs/chapter 14: data structures and algorithms new file mode 120000 index 0000000..e99425a --- /dev/null +++ b/docs/chapter 14: data structures and algorithms @@ -0,0 +1 @@ +../chapter 14: data structures and algorithms \ No newline at end of file diff --git a/docs/chapter 15: production software engineering b/docs/chapter 15: production software engineering new file mode 120000 index 0000000..7dbb50f --- /dev/null +++ b/docs/chapter 15: production software engineering @@ -0,0 +1 @@ +../chapter 15: production software engineering \ No newline at end of file diff --git a/docs/chapter 16: SIMD and GPU programming b/docs/chapter 16: SIMD and GPU programming new file mode 120000 index 0000000..b4a2bd9 --- /dev/null +++ b/docs/chapter 16: SIMD and GPU programming @@ -0,0 +1 @@ +../chapter 16: SIMD and GPU programming \ No newline at end of file diff --git a/docs/chapter 17: AI inference b/docs/chapter 17: AI inference new file mode 120000 index 0000000..abe56ba --- /dev/null +++ b/docs/chapter 17: AI inference @@ -0,0 +1 @@ +../chapter 17: AI inference \ No newline at end of file diff --git a/docs/chapter 18: ML systems design b/docs/chapter 18: ML systems design new file mode 120000 index 0000000..62c2f4a --- /dev/null +++ b/docs/chapter 18: ML systems design @@ -0,0 +1 @@ +../chapter 18: ML systems design \ No newline at end of file diff --git a/docs/chapter 19: applied AI b/docs/chapter 19: applied AI new file mode 120000 index 0000000..443c3cf --- /dev/null +++ b/docs/chapter 19: applied AI @@ -0,0 +1 @@ +../chapter 19: applied AI \ No newline at end of file diff --git a/docs/chapter 20: bleeding edge AI b/docs/chapter 20: bleeding edge AI new file mode 120000 index 0000000..df3fb60 --- /dev/null +++ b/docs/chapter 20: bleeding edge AI @@ -0,0 +1 @@ +../chapter 20: bleeding edge AI \ No newline at end of file diff --git a/docs/images b/docs/images new file mode 120000 index 0000000..5e67573 --- /dev/null +++ b/docs/images @@ -0,0 +1 @@ +../images \ No newline at end of file diff --git a/docs/index.md b/docs/index.md new file mode 120000 index 0000000..32d46ee --- /dev/null +++ b/docs/index.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/docs/javascripts b/docs/javascripts new file mode 120000 index 0000000..194589a --- /dev/null +++ b/docs/javascripts @@ -0,0 +1 @@ +../javascripts \ No newline at end of file diff --git a/images/ab_testing.svg b/images/ab_testing.svg new file mode 100644 index 0000000..7078a98 --- /dev/null +++ b/images/ab_testing.svg @@ -0,0 +1,46 @@ + + + + + + + A/B Testing: Measuring Model Impact + + + + All Users + 100% + + + + + + random + split + + + + Control (50%) + Old Model v1 + CTR = 4.2% + + + + Treatment (50%) + New Model v2 + CTR = 4.5% + + + + + + + Statistical Analysis + Δ CTR = +0.3% (p = 0.02, significant) + Guardrails: latency OK, errors OK + + + → Ship v2 to 100% + + Run for 1-2 weeks to capture day-of-week effects and reach statistical significance + \ No newline at end of file diff --git a/images/action_tokenisation.svg b/images/action_tokenisation.svg new file mode 100644 index 0000000..6ef27e3 --- /dev/null +++ b/images/action_tokenisation.svg @@ -0,0 +1,65 @@ + + + + + + + Action Tokenisation: Continuous → Discrete Bins → Tokens + + + + Continuous Action + Δx = 0.023 m/s + Δy = -0.051 m/s + grip = 0.8 + 7 dimensions + + + + + + Discretise (256 bins) + + + + + + + + + + bin 159 + + each dim → nearest bin index + + + + + + Token Sequence + + + + + 159 + + + 63 + + + 143 + + + ... + + + 204 + + + 91 + + fed to LLM like text tokens + + + Each action dimension is independently binned into 256 discrete tokens → LLM generates them autoregressively + \ No newline at end of file diff --git a/images/activation_functions.svg b/images/activation_functions.svg new file mode 100644 index 0000000..52c3fdf --- /dev/null +++ b/images/activation_functions.svg @@ -0,0 +1,51 @@ + + Common Activation Functions + + + ReLU + + + + + + max(0, x) + + + Sigmoid + + + + 1 + 0 + + + + Tanh + + + + -1 + 1 + + + + + GELU + + + + + x·Φ(x) + + + + Properties + ReLU: sparse, fast, can "die" + Sigmoid: (0,1), vanishing gradients + Tanh: (-1,1), zero-centred + GELU: smooth ReLU, used in GPT + + + + Nonlinearities are essential: without them, stacking layers collapses to one linear map + \ No newline at end of file diff --git a/images/actor_critic.svg b/images/actor_critic.svg new file mode 100644 index 0000000..b61f853 --- /dev/null +++ b/images/actor_critic.svg @@ -0,0 +1,45 @@ + + + + + + + Actor-Critic Architecture + + + + State s_t + + + + + + + + Actor (Policy) + pi(a|s; theta) + + + + Critic (Value) + V(s; phi) + + + + + Action a_t + + + + + Value estimate + + + + advantage = r + gamma*V(s') - V(s) + (guides policy update) + + + + Actor decides what to do. Critic evaluates how good the decision was. + \ No newline at end of file diff --git a/images/additive_inverse.svg b/images/additive_inverse.svg new file mode 100644 index 0000000..d6ba829 --- /dev/null +++ b/images/additive_inverse.svg @@ -0,0 +1,21 @@ + + + + + + + + + + + + v + + + -v + + + = + + 0 + diff --git a/images/amdahl_serial_bottleneck.svg b/images/amdahl_serial_bottleneck.svg new file mode 100644 index 0000000..ed0a9b2 --- /dev/null +++ b/images/amdahl_serial_bottleneck.svg @@ -0,0 +1,33 @@ + + Amdahl's Law: The Serial Bottleneck + + + 1 core: + + serial (10%) + + parallel (90%) + + + 4 cores: + + serial (10%) + + 90% ÷ 4 + → 3.1x speedup + + + 16 cores: + + serial (10%) + + → 6.4x speedup + + + ∞ cores: + + serial (10%) + → 10x max! + + The serial fraction (red) cannot be parallelised — it limits the maximum speedup to 1/(1−p) + \ No newline at end of file diff --git a/images/any_to_any_architectures.svg b/images/any_to_any_architectures.svg new file mode 100644 index 0000000..09c58a1 --- /dev/null +++ b/images/any_to_any_architectures.svg @@ -0,0 +1,187 @@ + + + + + + + + + + + + Any-to-Any Model Architectures + + + CoDi + + + + Image Diff. + + + Text Diff. + + + Audio Diff. + + + + + + + + + Aligned Conditioning (Shared Latent) + + + + + + + + + Image Out + + + Text Out + + + Audio Out + + Composable Diffusion + + + NExT-GPT + + + + Img Enc + * + + + Aud Enc + * + + + Vid Enc + * + + * = frozen + + + + Proj. + + + Proj. + + + Proj. + + + + + + + + + LLM Hub + + + + + + + + + Proj. + + + Proj. + + + Proj. + + + + + + + + + Img Dec + * + + + Aud Dec + * + + + Vid Dec + * + + + + + + + LLM as Hub + + + Gemini-style + + + + Natively Multimodal + Transformer + + + + T + + + I + + + T + + + A + + + I + + + V + + + T + + + I + + + Cross-modal + self-attention + + + + + + + Any modality in → Any out + (single model) + + Native Integration + + + + + + Shallow Integration + Deep Integration + + + + + \ No newline at end of file diff --git a/images/area_under_curve.svg b/images/area_under_curve.svg new file mode 100644 index 0000000..2a15c98 --- /dev/null +++ b/images/area_under_curve.svg @@ -0,0 +1,40 @@ + + + + + + + + + + + x + y + + + + + + + + + + + + + + f(x) + + + + a + + b + + + + Δx + + more rectangles → better approximation → integral + diff --git a/images/asr_pipeline.svg b/images/asr_pipeline.svg new file mode 100644 index 0000000..0c7c2d8 --- /dev/null +++ b/images/asr_pipeline.svg @@ -0,0 +1,65 @@ + + + + + + + + + + + + ASR Pipeline: Audio to Text + + + + Waveform + + + + + + + + + Feature + Extraction + (MFCCs / Mel) + + + + + + + Acoustic + Model + + + + + + + Decoder + + + + + + + Text + "hello world" + + + + Language + Model + + + + P(word seq) + + + + Traditional pipeline: each component trained separately. + End-to-end: single neural network replaces the entire pipeline. + \ No newline at end of file diff --git a/images/attention_alignment.svg b/images/attention_alignment.svg new file mode 100644 index 0000000..46603a9 --- /dev/null +++ b/images/attention_alignment.svg @@ -0,0 +1,79 @@ + + Attention Alignment: French → English Translation + + + le + chat + noir + dort + bien + Source (French) + + + the + black + cat + sleeps + well + Target (English) + + + + .82 + + .05 + + .03 + + .06 + + .04 + + + + .03 + + .06 + + .85 + + .03 + + .03 + + + + .04 + + .80 + + .07 + + .05 + + .04 + + + + .02 + + .05 + + .04 + + .83 + + .06 + + + + .03 + + .02 + + .03 + + .06 + + .86 + \ No newline at end of file diff --git a/images/attention_captioning.svg b/images/attention_captioning.svg new file mode 100644 index 0000000..7793941 --- /dev/null +++ b/images/attention_captioning.svg @@ -0,0 +1,116 @@ + + + + + + + + + + + + Attention-Based Image Captioning + + + Image Patches + + + + sky + + + tree + + + sky + + + + grass + + + dog + + + fence + + + + ball + + + grass + + + grass + + + + + + + + + + + + + strong attention + weak attention + + + Decoder Output + + + + "A" + + + "dog" + + + "plays" + + + + + + + Attention Map + (for "dog") + + + + 0.02 + + + 0.05 + + + 0.02 + + + 0.06 + + + 0.65 + + + 0.04 + + + 0.08 + + + 0.04 + + + 0.04 + + + Generated Caption: + "A dog plays with a ball on the grass" + + + At each decoding step, the model attends to different spatial regions of the image + \ No newline at end of file diff --git a/images/attention_sparsity_patterns.svg b/images/attention_sparsity_patterns.svg new file mode 100644 index 0000000..395382f --- /dev/null +++ b/images/attention_sparsity_patterns.svg @@ -0,0 +1,48 @@ + + Attention Patterns: Full vs Sliding Window vs Sparse + + + Full Attention + O(n²) + + + + + every token attends + to all previous + queries → + keys → + + + Sliding Window + O(n·w) + + + + + + + each token attends + to w previous only + + + Local + Global + O(n·w + n·g) + + + + + + + + + + + + + + + + local window + global tokens + (yellow = global attention) + \ No newline at end of file diff --git a/images/audio_spectrogram_transformer.svg b/images/audio_spectrogram_transformer.svg new file mode 100644 index 0000000..6ebfa82 --- /dev/null +++ b/images/audio_spectrogram_transformer.svg @@ -0,0 +1,96 @@ + + + + + + + + + Audio Spectrogram Transformer (AST) + + + + Mel Spectrogram + + + + + + + + + + + + + + + (patch grid) + + + + + + Flatten + Patches + + + + + + + + + + + Linear + Projection + + pos emb + + + + [CLS] + + prepended + + + + + + Transformer + Encoder + + + + Multi-Head Attn + FFN + + x L layers + + + + + + Classification + Head + (from [CLS]) + + + + + + Predicted + Label + + + 2D input + sequence + embed + encode + classify + + + + AST adapts ViT (chapter 08) directly to audio spectrograms + No convolutions needed. The spectrogram is split into 16x16 patches, linearly projected, and processed + by a standard Transformer encoder. Pre-training on ImageNet transfers surprisingly well to audio. + \ No newline at end of file diff --git a/images/audio_visual_correspondence.svg b/images/audio_visual_correspondence.svg new file mode 100644 index 0000000..c2c04c4 --- /dev/null +++ b/images/audio_visual_correspondence.svg @@ -0,0 +1,95 @@ + + + + + + + + + + + + + + + Audio-Visual Correspondence Learning + + + + + + + + F₃ + F₂ + F₁ + Video frames + + + + + + + Visual Encoder + (CNN / ViT) + + + + + + + v_embed + + + + + + + + + + + + + Spectrogram + + + + + + + Audio Encoder + (Wav2Vec) + + + + + + + a_embed + + + + + + + + Contrastive + Loss + + + + matched pair + + temporally aligned + + + + unmatched pair + + misaligned + + + + + \ No newline at end of file diff --git a/images/audio_waveform.svg b/images/audio_waveform.svg new file mode 100644 index 0000000..607032e --- /dev/null +++ b/images/audio_waveform.svg @@ -0,0 +1,68 @@ + + + + + + + + + + + + + + + + + + + + Audio Waveform: Anatomy of a Sine Wave + + + + + Time (t) + Amplitude + + + + + + + + + + + + + A + + + + + + T = 1/f + + + + phi + + + + sin(2pi f t) + + sin(2pi f t + phi) + + + + Key Parameters + A = amplitude (peak height) + T = period (seconds/cycle) + phi = phase offset (radians) + f = 1/T = frequency (Hz) + + + + Frequency f = cycles per second (Hz). A 440 Hz wave = concert A + \ No newline at end of file diff --git a/images/autonomous_driving_stack.svg b/images/autonomous_driving_stack.svg new file mode 100644 index 0000000..f65dba8 --- /dev/null +++ b/images/autonomous_driving_stack.svg @@ -0,0 +1,54 @@ + + + + + + + The Autonomous Driving Stack + + + + Sensors + cameras, LiDAR + radar, IMU + + + + + + Perception + 3D detection + lane detection + occupancy + + + + + + Prediction + trajectory forecast + intent estimation + 3-8s horizon + + + + + + Planning + route planning + trajectory opt. + decision making + + + + + + Control + steering + throttle + brake + + + Each stage refines raw data into increasingly abstract decisions + errors propagate forward → perception quality is the bottleneck + \ No newline at end of file diff --git a/images/bag_of_words.svg b/images/bag_of_words.svg new file mode 100644 index 0000000..b75cc9c --- /dev/null +++ b/images/bag_of_words.svg @@ -0,0 +1,107 @@ + + + + + + + Bag-of-Words: Document → Word Count Vector + + + + Document: + "the cat sat on the mat" + (word order is discarded) + + + + count + + + Word Counts + + + + Word + + Count + + + the + 2 + + + cat + 1 + + + sat + 1 + + + on + 1 + + + mat + 1 + + + + + + BoW Vector ∈ ℝⱽ + + + a + + 0 + + cat + + 1 + + dog + + 0 + + mat + + 1 + + on + + 1 + + sat + + 1 + + the + + 2 + + + + + ] + [ + + + + Advantage + Simple, fast, effective + for document classification + and spam filtering + + + Limitation + Ignores word order: + "dog bites man" and + "man bites dog" are identical + + + Dimensionality + Vector has V dimensions + (one per vocab word). + Very sparse: mostly zeros. + \ No newline at end of file diff --git a/images/basis_transform.svg b/images/basis_transform.svg new file mode 100644 index 0000000..dd38e71 --- /dev/null +++ b/images/basis_transform.svg @@ -0,0 +1,44 @@ + + + + + + + + + + + + + + + Before + + + + + + i + + + j + + + + A = + [col1 | col2] + + + After + + + + + + Ai + + + Aj + + columns of A = where i and j land + \ No newline at end of file diff --git a/images/bayes_components.svg b/images/bayes_components.svg new file mode 100644 index 0000000..5505c8c --- /dev/null +++ b/images/bayes_components.svg @@ -0,0 +1,41 @@ + + Bayes' Theorem Components + + + + P(A | B) = P(B | A) · P(A) / P(B) + + + + + + Posterior P(A|B) + Updated belief about A + after seeing evidence B + + + + + Likelihood P(B|A) + How probable is the + evidence if A is true? + + + + + Prior P(A) + Initial belief about A + before seeing any evidence + + + + + Evidence P(B) + Total probability of + observing B (normaliser) + + + Intuition: start with a prior belief, update it with new evidence, get a posterior belief. + Strong evidence shifts the posterior far from the prior. + Weak evidence leaves the posterior close to the prior. + \ No newline at end of file diff --git a/images/beamforming.svg b/images/beamforming.svg new file mode 100644 index 0000000..396c721 --- /dev/null +++ b/images/beamforming.svg @@ -0,0 +1,106 @@ + + + + + + + + + + + + Beamforming with Microphone Array + + + + S + Sound + Source + + + + + + + + + + + + + + + + theta + + + + M1 + + + M2 + + + M3 + + + M4 + + + + d + + + + + + + + + + Delay tau_1 + + + Delay tau_2 + + + Delay tau_3 + + + Delay tau_4 + + + tau_i = i*d*sin(theta)/c + + + + + + + + + + + + Sum + + + + + + + 1/N + + + + + + + Enhanced Signal + + + + Delay-and-sum beamforming: + align signals by compensating for propagation delay, then sum. + Signals from the target direction add constructively; off-axis noise partially cancels. The delay for + each mic depends on the angle of arrival and mic spacing. Neural beamformers learn the combining weights. + \ No newline at end of file diff --git a/images/bernoulli_binomial.svg b/images/bernoulli_binomial.svg new file mode 100644 index 0000000..ddfae06 --- /dev/null +++ b/images/bernoulli_binomial.svg @@ -0,0 +1,75 @@ + + Bernoulli (single trial) vs Binomial (n trials) + + + Bernoulli(p=0.7) + + + + + Outcome + P(X=x) + + + + 0.3 + 0 + + + 0.7 + 1 + + One coin flip: fail or success + + + + + + Binomial(n=8, p=0.7) + + + + + Number of successes (k) + P(X=k) + + + + + 0 + + + + 1 + + + + 2 + + + + 3 + + + + 4 + + + + 5 + + + + 0.30 + 6 + + + + 7 + + + + 8 + + 8 coin flips: how many heads? + \ No newline at end of file diff --git a/images/bert_mlm.svg b/images/bert_mlm.svg new file mode 100644 index 0000000..c20812c --- /dev/null +++ b/images/bert_mlm.svg @@ -0,0 +1,81 @@ + + + + + + + BERT: Masked Language Modelling + + + Original: + + + [CLS] + + the + + cat + + sat + + on + + the + + mat + + + + 15% masked + + + + Input: + + + [CLS] + + the + + + [MASK] + + sat + + on + + the + + + [MASK] + + + + + BERT Transformer Encoder (bidirectional attention) + + + + + + + + + + + cat + + + mat + + predict masked tokens + + + + Of 15% selected: + 80% → [MASK] + 10% → random word + 10% → unchanged + + Every token attends to every other token (bidirectional), enabling rich contextual representations. + \ No newline at end of file diff --git a/images/bev_fusion_pipeline.svg b/images/bev_fusion_pipeline.svg new file mode 100644 index 0000000..63d88c6 --- /dev/null +++ b/images/bev_fusion_pipeline.svg @@ -0,0 +1,62 @@ + + + + + + + BEV Fusion: Projecting Sensors into Bird's-Eye-View + + + + Cameras + + + + + 2D Image + Encoder + + + + + Lift to 3D + (depth pred.) + + + + + + LiDAR + + + + + 3D Point + Encoder + + + + + Voxelise + (BEV grid) + + + + + + BEV Fusion + concatenate features + + + + + + Detection + Head + + + 3D boxes, classes, velocities + + + Both sensor streams projected into the same bird's-eye-view grid + \ No newline at end of file diff --git a/images/bidirectional_rnn.svg b/images/bidirectional_rnn.svg new file mode 100644 index 0000000..97c55f4 --- /dev/null +++ b/images/bidirectional_rnn.svg @@ -0,0 +1,91 @@ + + + + + + + + + + + + + Bidirectional RNN + + + x₁ + x₂ + x₃ + x₄ + + + + + + + + + Forward: + + h→₁ + + + h→₂ + + + h→₃ + + + h→₄ + + + + + + + + Backward: + + h←₁ + + + h←₂ + + + h←₃ + + + h←₄ + + + + + + + + Output: + + [h→;h←]₁ + + + [h→;h←]₂ + + + [h→;h←]₃ + + + [h→;h←]₄ + + + + + + + + + + + + + Forward and backward hidden states are concatenated at each position, giving access to full context. + \ No newline at end of file diff --git a/images/bio_tagging.svg b/images/bio_tagging.svg new file mode 100644 index 0000000..f338f47 --- /dev/null +++ b/images/bio_tagging.svg @@ -0,0 +1,43 @@ + + BIO Tagging for Named Entity Recognition + + + Tim + Cook + visited + New + York + + + + + B-PER + + + + I-PER + + + + O + + + + B-LOC + + + + I-LOC + + + + PER (Person) + + + LOC (Location) + + + O (Outside entity) + + B = Begin entity, I = Inside (continuation), O = Outside. The B tag marks where a new entity starts. + \ No newline at end of file diff --git a/images/cache_aside_pattern.svg b/images/cache_aside_pattern.svg new file mode 100644 index 0000000..00e77e7 --- /dev/null +++ b/images/cache_aside_pattern.svg @@ -0,0 +1,41 @@ + + + + + + + Cache-Aside Pattern + + + + Application + + + + Cache (Redis) + + + + Database + + + + 1. Check cache + + + + 2. Miss → query DB + + + + 3. Store result in cache for next time + + + + HIT: ~1ms + + + MISS: ~50ms + + Subsequent requests for the same data hit the cache (1ms) instead of the database (50ms) + \ No newline at end of file diff --git a/images/cap_theorem.svg b/images/cap_theorem.svg new file mode 100644 index 0000000..8e9e9a6 --- /dev/null +++ b/images/cap_theorem.svg @@ -0,0 +1,33 @@ + + CAP Theorem: Pick Two (in practice: CP or AP) + + + + + + + C + Consistency + + + + A + Availability + + + + P + Partition Tolerance + + + CP + PostgreSQL + model registry + + + AP + Cassandra, DynamoDB + feature store + + Network partitions are inevitable → real choice is CP (consistent) vs AP (available) + \ No newline at end of file diff --git a/images/central_limit_theorem.svg b/images/central_limit_theorem.svg new file mode 100644 index 0000000..30fc1e8 --- /dev/null +++ b/images/central_limit_theorem.svg @@ -0,0 +1,72 @@ + + + + + + + + + Population (any shape) + + + + + + + + + + + + + + + take many + samples + + + Sample means (n=30 each) + + + + x̄₁ = 4.2 + + + x̄₂ = 3.8 + + + x̄₃ = 4.1 + + + x̄₄ = 3.9 + + + x̄₅ = 4.3 + + + x̄₆ = 4.0 + + ... hundreds more ... + + + + plot the + means + + + Distribution of x̄ + + + + + + + + + μ + + + + CLT: regardless of population shape, sample means → Normal + with mean μ and standard deviation σ/√n + \ No newline at end of file diff --git a/images/chain_rule.svg b/images/chain_rule.svg new file mode 100644 index 0000000..b593324 --- /dev/null +++ b/images/chain_rule.svg @@ -0,0 +1,32 @@ + + + + + + + + + + x + + + + + + + g(x) + + + + + + + f(g(x)) + + + g'(x) + × + f'(g(x)) + + derivative of outer × derivative of inner + diff --git a/images/clip_contrastive_matrix.svg b/images/clip_contrastive_matrix.svg new file mode 100644 index 0000000..f23e2b9 --- /dev/null +++ b/images/clip_contrastive_matrix.svg @@ -0,0 +1,92 @@ + + + + + + + + + + + + + + + CLIP Contrastive Learning: N×N Similarity Matrix + + + Text₁ + Text₂ + Text₃ + Text₄ + + + Image₁ + Image₂ + Image₃ + Image₄ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + maximise + + + + minimise + + + + Loss = ½ (CE over rows + CE over columns) — symmetric cross-entropy + \ No newline at end of file diff --git a/images/cloud_service_layers.svg b/images/cloud_service_layers.svg new file mode 100644 index 0000000..3a4192b --- /dev/null +++ b/images/cloud_service_layers.svg @@ -0,0 +1,60 @@ + + Cloud Service Models: What You Manage vs Provider Manages + + + IaaS + PaaS + SaaS + FaaS + + + + + + + + + + Application + + + + + + + Data + + + + + + + Runtime + + + + + + + OS + + + + + + + Virtualisation + + + + + + + Hardware + + + + You manage + + Provider manages + \ No newline at end of file diff --git a/images/cnn_convolution.svg b/images/cnn_convolution.svg new file mode 100644 index 0000000..1dd59de --- /dev/null +++ b/images/cnn_convolution.svg @@ -0,0 +1,123 @@ + + 2D Convolution: Filter Slides Over Input + + + Input (5x5) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1 + 0 + 2 + 1 + 0 + 0 + 1 + 3 + 1 + 2 + 1 + 2 + 1 + 0 + 1 + 2 + 0 + 1 + 2 + 0 + 0 + 1 + 0 + 1 + 3 + + + + + + Filter (3x3) + + + + + + + + + + + + 1 + 0 + -1 + 1 + 0 + -1 + 1 + 0 + -1 + + + * + + + = + + + Output (3x3) + + + + + + + + + + + + 2 + + + 1·1 + 3·0 + 1·(-1) + 2·1 + 1·0 + 0·(-1) + 0·1 + 1·0 + 2·(-1) = 2 + + + filter slides right and down + + + + + + Output size = (input - filter + 2·padding) / stride + 1 + Same filter weights applied at every position (weight sharing) + \ No newline at end of file diff --git a/images/cocktail_party.svg b/images/cocktail_party.svg new file mode 100644 index 0000000..f44783c --- /dev/null +++ b/images/cocktail_party.svg @@ -0,0 +1,86 @@ + + + + + + + + + + + + + + + The Cocktail Party Problem + + + + + + Speaker 1 + + + + + + + + + Speaker 2 + + + + + + + + + + + + + + + + + + Mic + + + + + + Mixed Signal + + + + + + + + + Separation + Model + (deep network) + + + + + + + + + Recovered Source 1 + + + + + Recovered Source 2 + + + + The cocktail party problem: + recover individual sources from a mixture. + Humans do this effortlessly via auditory attention. Deep learning approaches learn to separate sources + from mixed signals, even when trained on synthetic mixtures. Key challenge: the permutation problem. + \ No newline at end of file diff --git a/images/codebook_collapse.svg b/images/codebook_collapse.svg new file mode 100644 index 0000000..6e58f9c --- /dev/null +++ b/images/codebook_collapse.svg @@ -0,0 +1,111 @@ + + + Codebook Utilisation: Healthy vs Collapsed + + + Healthy Codebook + + + + + + + + + + + + + + + + + + + + + + + 15/16 entries active + + + + + + + + + + + + + + + + + + + + + + codebook entry index + usage + + + + + + + Collapsed Codebook + + + + + + + + + + + + + + + + + + + + + + + 3/16 entries active + + + + + + + + + + + + + + + + + + + + + + codebook entry index + usage + + + + High reconstruction quality + Poor reconstruction, wasted capacity + \ No newline at end of file diff --git a/images/cofactor.svg b/images/cofactor.svg new file mode 100644 index 0000000..7510791 --- /dev/null +++ b/images/cofactor.svg @@ -0,0 +1,39 @@ + + + + + + + + + + + + + + + 1 + 5 + 3 + 2 + 7 + 4 + 6 + 8 + 9 + + + = + + + + + + + 2 + 4 + 6 + 9 + + Minor M₁₂: delete row 1 and column 2 + \ No newline at end of file diff --git a/images/column_space.svg b/images/column_space.svg new file mode 100644 index 0000000..e95c8dd --- /dev/null +++ b/images/column_space.svg @@ -0,0 +1,46 @@ + + + + + + + + + + + + + + + Column space (2 columns in 3D) + + + + + + + + col 1 + + + col 2 + + + reachable outputs = a plane + + + + + + Dependent columns + + + + + + col 1 + + + col 2 + reachable outputs = only a line + \ No newline at end of file diff --git a/images/common_distributions.svg b/images/common_distributions.svg new file mode 100644 index 0000000..e0d880e --- /dev/null +++ b/images/common_distributions.svg @@ -0,0 +1,66 @@ + + Common Distribution Shapes + + + Uniform(0, 1) + + + + height = 1 + 0 + 1 + + + Exponential(λ=1.5) + + + + + steep decay + 0 + + + Beta(α=2, β=5) + + + + + skewed right + 0 + 1 + + + Poisson(λ=4) + + + + + + 0 + + + 1 + + + 2 + + + 3 + + + 4 + + + 5 + + + 6 + + + 7 + + + 8 + + discrete counts of events + \ No newline at end of file diff --git a/images/commutativity.svg b/images/commutativity.svg new file mode 100644 index 0000000..40cae0e --- /dev/null +++ b/images/commutativity.svg @@ -0,0 +1,29 @@ + + + + + + + + + + + + + + + u + + + v + + + + + + + u + v + + + O + diff --git a/images/compilation_pipeline.svg b/images/compilation_pipeline.svg new file mode 100644 index 0000000..4d9c88c --- /dev/null +++ b/images/compilation_pipeline.svg @@ -0,0 +1,66 @@ + + + + + + + Compilation Pipeline: Source Code → Machine Code + + + + Source + x = 3 + y + + + + + + Lexer + tokens + + + + + + Parser + AST + + + + + + Semantic + Analysis + type check + + + + + + Optimiser + IR (LLVM) + + + + + + Code + Gen + + + + + + Machine + Code + x86 / ARM + + + text → tokens + tokens → tree + types + names + fold, inline, elim + IR → target ISA + + Each stage transforms the program into a lower-level representation + \ No newline at end of file diff --git a/images/concurrency_vs_parallelism.svg b/images/concurrency_vs_parallelism.svg new file mode 100644 index 0000000..e060fd0 --- /dev/null +++ b/images/concurrency_vs_parallelism.svg @@ -0,0 +1,62 @@ + + Concurrency vs Parallelism + + + + + Concurrency (1 core) + interleaved execution + + Core: + + + + A + + + B + + + A + + + C + + + B + + + C + + + + time → + + tasks take turns on one core + not truly simultaneous + + + Parallelism (4 cores) + simultaneous execution + + Core 1: + Core 2: + Core 3: + Core 4: + + + + Task A + + + Task B + + + Task C + + + Task D + + all tasks run at the same time + requires multiple cores + \ No newline at end of file diff --git a/images/conditional_probability.svg b/images/conditional_probability.svg new file mode 100644 index 0000000..a8de426 --- /dev/null +++ b/images/conditional_probability.svg @@ -0,0 +1,47 @@ + + Conditional Probability: P(A | B) + + + Full sample space S + + + + A + + + B + + + + + + A∩B + + + Given B + occurred + + + + + + + + + New sample space = B + + + + B + + + + A ∩ B + + + B \ A + + + P(A | B) = P(A ∩ B) / P(B) + The fraction of B that also belongs to A + \ No newline at end of file diff --git a/images/confidence_interval.svg b/images/confidence_interval.svg new file mode 100644 index 0000000..439a77d --- /dev/null +++ b/images/confidence_interval.svg @@ -0,0 +1,37 @@ + + + + + + + + + + lower bound + (x̄ - ME) + + + + upper bound + (x̄ + ME) + + + + x̄ (point estimate) + + + + + + + ME + ME + + + + μ (true parameter, + hopefully in here!) + + + Confidence Interval = Point Estimate ± Margin of Error + \ No newline at end of file diff --git a/images/conformer_block.svg b/images/conformer_block.svg new file mode 100644 index 0000000..d9babd0 --- /dev/null +++ b/images/conformer_block.svg @@ -0,0 +1,88 @@ + + + + + + + + + Conformer Block (Macaron-Style Sandwich) + + + + Input + + + + + Feed-Forward (1/2 step) + + + + + + + Multi-Head Self-Attention + + + + + + + Convolution Module + + + + + + + Feed-Forward (1/2 step) + + + + + + + LayerNorm + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Residual + connections + + + local context + + + global context + + + + + Conformer combines local (conv) + and global (attention) context + for speech recognition. + \ No newline at end of file diff --git a/images/constituency_tree.svg b/images/constituency_tree.svg new file mode 100644 index 0000000..18ecaae --- /dev/null +++ b/images/constituency_tree.svg @@ -0,0 +1,73 @@ + + Constituency (Parse) Tree: "the cat sat on the mat" + + + + S + + + + + + + + NP + + + + + + + + Det + + the + + + + N + + cat + + + + VP + + + + + + + + V + + sat + + + + PP + + + + + + + + P + + on + + + + NP + + + + + + the + mat + + + Each internal node is a phrase type; each leaf is a word. Phrases nest inside phrases. + \ No newline at end of file diff --git a/images/container_vs_vm.svg b/images/container_vs_vm.svg new file mode 100644 index 0000000..71a9dd7 --- /dev/null +++ b/images/container_vs_vm.svg @@ -0,0 +1,75 @@ + + Virtual Machines vs Containers + + + + + Virtual Machines + + + + Hardware + + + + Host OS / Hypervisor + + + + + App A + + Bins/Libs + + Guest OS + + Virtual HW + + + + + App B + + Bins/Libs + + Guest OS + + Virtual HW + + + Containers + + + + Hardware + + + + Host OS (shared kernel) + namespaces + cgroups + + + + + App A + + Bins/Libs + no guest OS! + + + + + App B + + Bins/Libs + + + + + App C + + Libs + + heavy: each VM runs full OS + light: shared kernel, ms startup + \ No newline at end of file diff --git a/images/continuous_vs_discrete_tokens.svg b/images/continuous_vs_discrete_tokens.svg new file mode 100644 index 0000000..514fc61 --- /dev/null +++ b/images/continuous_vs_discrete_tokens.svg @@ -0,0 +1,103 @@ + + + Choosing Between Continuous and Discrete Tokens + + + + + + + + + + + + + + + + + + + How will you generate? + + + + Autoregressive + (next-token) + + + + Diffusion + (iterative denoising) + + + + Hybrid + + + + Use Discrete Tokens + + + + 42 + + 7 + + 156 + + 89 + + 3 + + VQ-VAE / VQ-GAN + + + + DALL-E, Parti, LlamaGen, + VideoGPT, MAGVIT + + + + Soft Quantisation + + + + 0.7|0.3 + + 0.9|0.1 + + 0.5|0.5 + + Gumbel-Softmax / dVAE + + + + DALL-E (training), + Maskbit + + + + Use Continuous Tokens + + + + 0.73 + + -0.21 + + 1.45 + + -0.58 + + VAE latents (KL-regularised) + + + + Stable Diffusion, DALL-E 3, + Sora, Imagen Video + + + The generation method determines whether discrete or continuous tokenisation is most appropriate + \ No newline at end of file diff --git a/images/contrastive_temperature.svg b/images/contrastive_temperature.svg new file mode 100644 index 0000000..066cde3 --- /dev/null +++ b/images/contrastive_temperature.svg @@ -0,0 +1,87 @@ + + + + Effect of Temperature on Contrastive Softmax + + + + + + Low Temperature (τ → 0) + Sharp / peaked distribution + + + + + Candidates + Softmax prob. + + + + + + + + + correct + + + + + + + 1 + 2 + 3 + 4 + 5 + + + + 0 + + 0.5 + + 1.0 + + + High Temperature (τ → ∞) + Uniform / flat distribution + + + + + Candidates + Softmax prob. + + + + + + + correct + + + + + 1 + 2 + 3 + 4 + 5 + + + + 0 + + 0.5 + + 1.0 + + + + 1/N + + + softmax(sᵢ / τ) — lower τ sharpens the distribution, higher τ flattens it + \ No newline at end of file diff --git a/images/conv_tasnet.svg b/images/conv_tasnet.svg new file mode 100644 index 0000000..2e51bcf --- /dev/null +++ b/images/conv_tasnet.svg @@ -0,0 +1,109 @@ + + + + + + + + + Conv-TasNet Architecture + + + Encoder + Separator (TCN) + Decoder + + + + + Mixture x(t) + + + + + + Encoder + 1D Conv + + ReLU + + + + + + Encoded + Mixture W + + + + + + + + + 1x1 Conv + PReLU + Norm + + + D-Conv (dilation 2^k) + + + 1x1 Conv (skip + res) + + + + + skip + + + + res + + x R repeats + + + + Sigmoid masks (M1, M2) + + + + + + + Apply + Masks + W * M_i + + + + + + + + + + Decoder + Transposed + 1D Conv + + + + + + + + s1(t) + + + + s2(t) + + + + encoded mixture W passed to mask application + + + + Time-domain approach: + works directly on waveforms, no STFT needed. + The encoder learns a task-optimal representation. The TCN separator produces masks in the learned basis, + and the decoder reconstructs the separated waveforms. Achieves strong SI-SDR on speech separation benchmarks. + \ No newline at end of file diff --git a/images/convex_nonconvex.svg b/images/convex_nonconvex.svg new file mode 100644 index 0000000..efa3b8c --- /dev/null +++ b/images/convex_nonconvex.svg @@ -0,0 +1,76 @@ + + + + + + + + + Convex (f'' > 0) + + + + + + + + + + + + line above curve + + + + minimum + curves upward + + + + + + Concave (f'' < 0) + + + + + + + + + + + + line below curve + + + + maximum + curves downward + + + + + + Inflection (f'' = 0) + + + + + + + + + + + + + + + concave + convex + switches here + curvature changes sign + + convex: bowl (unique min) | concave: hill (unique max) | inflection: curvature flips + diff --git a/images/correlation_scatter.svg b/images/correlation_scatter.svg new file mode 100644 index 0000000..98edec1 --- /dev/null +++ b/images/correlation_scatter.svg @@ -0,0 +1,63 @@ + + + Positive (r ≈ +0.9) + + + + + + + + + + + + + + + + + + + + + + No correlation (r ≈ 0) + + + + + + + + + + + + + + + + + + + Negative (r ≈ -0.9) + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/images/counting_outfits.svg b/images/counting_outfits.svg new file mode 100644 index 0000000..9d334a5 --- /dev/null +++ b/images/counting_outfits.svg @@ -0,0 +1,71 @@ + + Multiplication Rule: 3 shirts × 4 pants = 12 outfits + + + Shirts + + S₁ (red) + + S₂ (blue) + + S₃ (green) + + + Pants + + P₁ + + P₂ + + P₃ + + P₄ + + + + + + + + + + + + + + + + + + + + + All 12 outcomes + (S₁,P₁) + (S₁,P₂) + (S₁,P₃) + (S₁,P₄) + + (S₂,P₁) + (S₂,P₂) + (S₂,P₃) + (S₂,P₄) + + (S₃,P₁) + (S₃,P₂) + (S₃,P₃) + (S₃,P₄) + + + + + + + + + + + Each shirt pairs with every pant: 3 × 4 = 12 total combinations + This is the multiplication rule: if choice A has m options and choice B has n options, + the total number of combined outcomes is m × n. + \ No newline at end of file diff --git a/images/cpu_pipeline.svg b/images/cpu_pipeline.svg new file mode 100644 index 0000000..ce472ab --- /dev/null +++ b/images/cpu_pipeline.svg @@ -0,0 +1,62 @@ + + CPU Pipeline: Overlapping Instruction Execution + + + Cycle 1 + Cycle 2 + Cycle 3 + Cycle 4 + Cycle 5 + Cycle 6 + + + Instr 1 + Instr 2 + Instr 3 + Instr 4 + + + + + Fetch + + Decode + + Execute + + Memory + + Write + + + + Fetch + + Decode + + Execute + + Memory + + Write + + + + Fetch + + Decode + + Execute + + Memory + + + + Fetch + + Decode + + Execute + + Each cycle, a new instruction enters the pipeline — throughput approaches 1 instruction/cycle + \ No newline at end of file diff --git a/images/critical_points.svg b/images/critical_points.svg new file mode 100644 index 0000000..1a447b7 --- /dev/null +++ b/images/critical_points.svg @@ -0,0 +1,51 @@ + + + + + + + + Three kinds of critical point (f'(x) = 0) + + + + + + + + + + + + + local max + f'' < 0 (curves down) + + + + + + + + + + + + + local min + f'' > 0 (curves up) + + + + + + + + + + + + + inflection + f'' = 0 (flat but no turn) + diff --git a/images/crnn_ocr_pipeline.svg b/images/crnn_ocr_pipeline.svg new file mode 100644 index 0000000..493af9e --- /dev/null +++ b/images/crnn_ocr_pipeline.svg @@ -0,0 +1,68 @@ + + + + + + + CRNN Text Recognition Pipeline + + + + Hello + (text image) + Input + + + + + + CNN + Extract visual + features + Spatial features + + + + + + + + + + + + + + Column slices + + + + + + BiLSTM + Sequence + modelling + Context-aware + + + + + + CTC + Decode + alignment + No segmentation + + + + + + Hello + Output + + + + CTC: Connectionist Temporal Classification + Handles many-to-one alignment: "HH-ee-ll-ll-oo" → "Hello" + No need for pre-segmented training data. Blank token (–) separates repeated characters. + \ No newline at end of file diff --git a/images/cross_modal_generation_overview.svg b/images/cross_modal_generation_overview.svg new file mode 100644 index 0000000..6434027 --- /dev/null +++ b/images/cross_modal_generation_overview.svg @@ -0,0 +1,92 @@ + + + + + + + + + + + + + + + + + + + + + + + + Cross-Modal Generation Pathways + + + + Generation Model + + + + Text + "a sunset over ocean" + + + + Image + pixels / latents + + + + Audio + waveform / spectrogram + + + + Video + frame sequence + + + + + + + + + + + + + + + + + + + + text-to-image + + + + text-to-audio + + + + text-to-video + + + + image-to-text + + + + image-to-video + + + + audio-to-text + + + Solid lines = primary pathways | Dashed lines = direct cross-modal translation + \ No newline at end of file diff --git a/images/ctc_alignment.svg b/images/ctc_alignment.svg new file mode 100644 index 0000000..d0dda98 --- /dev/null +++ b/images/ctc_alignment.svg @@ -0,0 +1,121 @@ + + + + + + + + + + + + CTC Alignment with Blank Tokens + + + Input + Frames + + + + t1 + + + t2 + + + t3 + + + t4 + + + t5 + + + t6 + + + t7 + + + t8 + + + + + + + + + + + + + CTC + Output + + + + + + + + + h + + + + h + + + + + + + + i + + + + + + + + + + + + ! + + + collapse duplicates & remove blanks + + + + + + + + + Final + Output + + + + h i ! + + + Many CTC paths + map to same output + + + + + = blank token + + = real character + + + + CTC allows the model to output 'blank' between characters -- no forced alignment needed. + \ No newline at end of file diff --git a/images/dalle_autoregressive_pipeline.svg b/images/dalle_autoregressive_pipeline.svg new file mode 100644 index 0000000..f3122aa --- /dev/null +++ b/images/dalle_autoregressive_pipeline.svg @@ -0,0 +1,98 @@ + + + + + + + + + + + + + + + DALL-E: Autoregressive Image Generation + + + Text Tokens + + a + + cat + + on + + mat + + + Image Tokens + + 42 + + 17 + + 91 + + ... + + ? + + + concatenated sequence + + + + + + + + Transformer Decoder + 12B parameters, causal attention + + + + predict next + image token + + + + + + Predicted image token sequence (1024 tokens) + + + + + + + + + + ... + + + + + + + dVAE Decoder + 8192-vocab codebook + + + + + + + + + + + Generated + Image + 256 x 256 + + + Stage 1: Autoregressive token prediction + Stage 2: Token-to-pixel decoding + \ No newline at end of file diff --git a/images/data_model_parallelism.svg b/images/data_model_parallelism.svg new file mode 100644 index 0000000..676bfb2 --- /dev/null +++ b/images/data_model_parallelism.svg @@ -0,0 +1,88 @@ + + Data Parallelism vs Model Parallelism + + + + + + Data Parallelism + + + + Full + Model + + + Full + Model + + + Full + Model + + + GPU 0 + GPU 1 + GPU 2 + + + + Data 1/3 + + + Data 2/3 + + + Data 3/3 + + + + + + + All-Reduce gradients + + Same model on each GPU + Different data on each GPU + Scales batch size + + + Model Parallelism + + + + Layers + 1-4 + + + Layers + 5-8 + + + Layers + 9-12 + + + GPU 0 + GPU 1 + GPU 2 + + + + Full Data (same batch) + + + + + + + pass activations + + Model split across GPUs + Same data on all GPUs + Scales model size + + + + Most large models use both: hybrid parallelism (data + tensor + pipeline) + \ No newline at end of file diff --git a/images/deadlock_cycle.svg b/images/deadlock_cycle.svg new file mode 100644 index 0000000..6400c57 --- /dev/null +++ b/images/deadlock_cycle.svg @@ -0,0 +1,51 @@ + + + + + + + + + + Deadlock: Circular Wait + + + + Thread + A + + + + Thread + B + + + + Lock 1 + + + + Lock 2 + + + + holds + + + + holds + + + + A wants Lock 2 + + + + B wants Lock 1 + + + + DEADLOCK + + Neither thread can proceed — each waits for the other's lock + \ No newline at end of file diff --git a/images/decision_tree_split.svg b/images/decision_tree_split.svg new file mode 100644 index 0000000..dab1625 --- /dev/null +++ b/images/decision_tree_split.svg @@ -0,0 +1,61 @@ + + + + + + + Decision Tree: Splitting on Features + + + + age < 30? + + + + yes + + + + no + + + + income > 50k? + + + + student? + + + + yes + + Buy (85%) + + + + no + + No Buy (70%) + + + + yes + + Buy (75%) + + + + no + + No Buy (90%) + + + Internal nodes: feature tests + Branches: yes/no answers + Leaves: class predictions + + + + Each split chosen to maximise information gain (reduce impurity) + \ No newline at end of file diff --git a/images/deeplab_aspp.svg b/images/deeplab_aspp.svg new file mode 100644 index 0000000..8131465 --- /dev/null +++ b/images/deeplab_aspp.svg @@ -0,0 +1,105 @@ + + + + + + + DeepLab: Atrous Spatial Pyramid Pooling (ASPP) + + + Atrous (Dilated) Convolution + + + rate=1 + + + + + + + + + + + + + + RF: 3×3 + + + rate=2 + + + + + + + + + + + + + + RF: 5×5, 9 params + + + ASPP Module + + + + Features + + + + + + 1×1 (r=1) + + + + + 3×3 (r=6) + + + + + 3×3 (r=12) + + + + + 3×3 (r=18) + + + + + GAP + 1×1 + + + + + + + + + + + Concat + + + + + 1×1 conv + + + + Out + + + + How Atrous Convolution Works + Standard 3×3 filter with gaps of (rate−1) between elements. Rate r gives receptive field (2r+1)×(2r+1). + ASPP applies multiple rates in parallel to capture context at multiple scales — like Inception but with dilation. + Global average pooling branch captures image-level context (what's the overall scene?). + diff --git a/images/densenet_block.svg b/images/densenet_block.svg new file mode 100644 index 0000000..bb5e766 --- /dev/null +++ b/images/densenet_block.svg @@ -0,0 +1,65 @@ + + + + + + + DenseNet: Dense Block with Feature Reuse + + + + x₀ + Input + + + H₁ + BN→ReLU→Conv + + + H₂ + BN→ReLU→Conv + + + H₃ + BN→ReLU→Conv + + + Output + [x₀,x₁,x₂,x₃] + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Layer l receives all preceding feature maps: x_l = H_l([x₀, x₁, ..., x_{l-1}]) + Each layer adds k new channels (growth rate k). Concatenation grows channels linearly, encouraging feature reuse. + diff --git a/images/dependency_tree.svg b/images/dependency_tree.svg new file mode 100644 index 0000000..309040c --- /dev/null +++ b/images/dependency_tree.svg @@ -0,0 +1,43 @@ + + + + + + + Dependency Tree: "the cat sat on the mat" + + + the + cat + sat + on + the + mat + + + ROOT + + + + + nsubj + + + + prep + + + + det + + + + pobj + + + + det + + + Arrows point from head to dependent. "sat" is the root; every other word depends on exactly one head. + \ No newline at end of file diff --git a/images/depthwise_separable_conv.svg b/images/depthwise_separable_conv.svg new file mode 100644 index 0000000..68c53b7 --- /dev/null +++ b/images/depthwise_separable_conv.svg @@ -0,0 +1,84 @@ + + + + + + + Depthwise Separable Convolution (MobileNet) + + + Standard Convolution + + + + + + H×W×C_in + + + + + + + k×k×C_in + × C_out filters + + + + + + + H×W×C_out + + + Cost: k²·C_in·C_out + per spatial position + e.g. 3²·64·128 = 73,728 + + + Depthwise Separable Convolution + + + + + + H×W×C_in + + + + + Depthwise + k×k × 1 per channel + Cost: k²·C_in + + + + + + + H×W×C_in + + + + + Pointwise + 1×1 × C_in × C_out + Cost: C_in·C_out + + + + + + + H×W×C_out + + + + Total cost: + k²·C_in + C_in·C_out + ≈ 9× cheaper (k=3) + + + + Depthwise handles spatial filtering; pointwise handles channel mixing. Same output, far fewer parameters. + diff --git a/images/detection_boxes.svg b/images/detection_boxes.svg new file mode 100644 index 0000000..c7bf948 --- /dev/null +++ b/images/detection_boxes.svg @@ -0,0 +1,54 @@ + + Object Detection: Bounding Boxes with Class Labels + + + + + + + + + + + + + + + + + + + + + + + + + + car 0.97 + + + + + + + + + car 0.94 + + + + + + + + person 0.91 + + + + + + + + Each detection: bounding box (x, y, w, h) + class label + confidence score + \ No newline at end of file diff --git a/images/determinant.svg b/images/determinant.svg new file mode 100644 index 0000000..237a20d --- /dev/null +++ b/images/determinant.svg @@ -0,0 +1,29 @@ + + + + + + + + + Before (unit square) + + + + + + area = 1 + + + + transform + + + After (parallelogram) + + + + + + area = |det(A)| + \ No newline at end of file diff --git a/images/difference_quotient.svg b/images/difference_quotient.svg new file mode 100644 index 0000000..0b017cc --- /dev/null +++ b/images/difference_quotient.svg @@ -0,0 +1,43 @@ + + + + + + + + + + + + + + + + + + a + f(a) + + + + + a+h + + secant + + + + + + + + tangent + + + + + h → 0 + + as h shrinks, secant → tangent + diff --git a/images/diffusion_process.svg b/images/diffusion_process.svg new file mode 100644 index 0000000..198b232 --- /dev/null +++ b/images/diffusion_process.svg @@ -0,0 +1,109 @@ + + + + + + + + + + Diffusion: Forward and Reverse Processes + + + Forward Process: Gradually Add Noise (q) + + + + + + x₀ + clean + + + + + + + + + + + + + + x₁ + + + + + + + + + + + + + + x_t + + + ··· + + + + + + + + + + + + + + + x_{T-1} + + + + + + + + + + + + + + + + + + + + + x_T + ≈ N(0,I) + + + Reverse Process: Learn to Denoise (p_θ) + + + + + ··· + + + + + −ε_θ + −ε_θ + −ε_θ + + + + Neural network ε_θ(x_t, t) predicts the noise added at each step. + Training loss: L = E[ ‖ε − ε_θ(x_t, t)‖² ]. Simple MSE on noise prediction. + \ No newline at end of file diff --git a/images/distribution_shift_bc.svg b/images/distribution_shift_bc.svg new file mode 100644 index 0000000..5719e5f --- /dev/null +++ b/images/distribution_shift_bc.svg @@ -0,0 +1,32 @@ + + + + + + + Behavioural Cloning: Distribution Shift (Compounding Error) + + + + expert + + + + learned + + + + small error + + + error grows + + + catastrophic drift + + + + + + The policy visits states the expert never demonstrated → no training data → errors compound + \ No newline at end of file diff --git a/images/distribution_types.svg b/images/distribution_types.svg new file mode 100644 index 0000000..985de4c --- /dev/null +++ b/images/distribution_types.svg @@ -0,0 +1,37 @@ + + + Frequency Distribution + + + + + Value + Count + + + + + + + + + + + Probability Distribution + + + + + Value + Probability + + + + + + + + + discrete bins + continuous curve + \ No newline at end of file diff --git a/images/distributivity.svg b/images/distributivity.svg new file mode 100644 index 0000000..4f7be8b --- /dev/null +++ b/images/distributivity.svg @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + + + u + + + v + + + u+v + + + + + + + c(u + v) + + cu + cv + + + O + diff --git a/images/dit_architecture.svg b/images/dit_architecture.svg new file mode 100644 index 0000000..6336c84 --- /dev/null +++ b/images/dit_architecture.svg @@ -0,0 +1,119 @@ + + + + + + + + + + + + Diffusion Transformer (DiT) + + + + + + + + + + Noisy + Image + + + + + + + Patchify + + pos embed + + + + + + + + + + + + + + + + + + + Transformer Block + MHSA + FFN + + + + Transformer Block + MHSA + FFN + + + + + + + + + Transformer Block + MHSA + FFN + + + + + + + + + + t + + + + c + + + timestep + class label + + + + + + + + + adaLN-Zero + + + + + + + Unpatchify + reassemble + + + + + + + Denoised + Image + + + + Replaces U-Net + Pure transformer backbone + scales with compute (Chinchilla-optimal) + + + adaLN conditions each block on timestep and class via learned scale/shift parameters + \ No newline at end of file diff --git a/images/dot_product.svg b/images/dot_product.svg new file mode 100644 index 0000000..6cd8aec --- /dev/null +++ b/images/dot_product.svg @@ -0,0 +1,36 @@ + + + + + + + + + + + + + + b + + + + a + + + + + + + proj_b(a) + + + + + + + θ + + + a · b = ‖a‖ ‖b‖ cos(θ) + \ No newline at end of file diff --git a/images/dual_vs_fusion_encoder.svg b/images/dual_vs_fusion_encoder.svg new file mode 100644 index 0000000..27960b6 --- /dev/null +++ b/images/dual_vs_fusion_encoder.svg @@ -0,0 +1,113 @@ + + + + + + + + + Dual Encoder vs Fusion Encoder + + + + + + Dual Encoder + + + + + + Image + + + + Text + Query + + + + + + + Image Encoder + (ViT) + + + + + + + Text Encoder + (Transformer) + + + + + + + + v_i + + + v_t + + + + + cos(v_i, v_t) + + + Independent encoders + Fast retrieval + O(1) similarity computation + + + Fusion Encoder + + + + + + Image + + + + Text + Query + + + + + + + + + + + + + + + + + + + + + + + + Cross-Attention + Transformer + (Deep Fusion) + + + + + + + Token-level interaction + Richer understanding + O(n*m) cross-attention cost + \ No newline at end of file diff --git a/images/earth_mars_delay.svg b/images/earth_mars_delay.svg new file mode 100644 index 0000000..f755c58 --- /dev/null +++ b/images/earth_mars_delay.svg @@ -0,0 +1,34 @@ + + + + + + + + + + Earth-Mars Communication Delay + + + + Earth + + + + Mars + + + + command signal → 4-24 min + + + + ← telemetry response 4-24 min + + + + Round trip: 8-48 minutes + + + No real-time joysticking — the rover must decide for itself + \ No newline at end of file diff --git a/images/efficientnet_scaling.svg b/images/efficientnet_scaling.svg new file mode 100644 index 0000000..24981fa --- /dev/null +++ b/images/efficientnet_scaling.svg @@ -0,0 +1,71 @@ + + EfficientNet: Compound Scaling + + + Baseline + + + + + + + + + d layers + w channels + r resolution + + + Width (w) + + + + + + + + More channels + + + Depth (d) + + + + + + + + + + + More layers + + + Resolution (r) + + + + + + + Higher resolution + + + Compound (φ) + + + + + + + + + + + All three scaled together + + + + d = α^φ, w = β^φ, r = γ^φ subject to α·β²·γ² ≈ 2 + Scaling all dimensions together is more effective than scaling any single one. + diff --git a/images/eigenvector.svg b/images/eigenvector.svg new file mode 100644 index 0000000..fe47f80 --- /dev/null +++ b/images/eigenvector.svg @@ -0,0 +1,47 @@ + + + + + + + + + + + + + + + + + + Eigenvector: same direction + + + + + + x + + + Ax = λx + + + just scaled, not rotated + + + + + + Regular vector: direction changes + + + + + + v + + + Av + scaled AND rotated + \ No newline at end of file diff --git a/images/ensemble_methods.svg b/images/ensemble_methods.svg new file mode 100644 index 0000000..802c168 --- /dev/null +++ b/images/ensemble_methods.svg @@ -0,0 +1,93 @@ + + + + + + + Bagging vs Boosting + + + + + + Bagging (parallel) + + + + Training Data + + + + + + + + + Sample 1 + + Sample 2 + + Sample 3 + + + + Tree 1 + + Tree 2 + + Tree 3 + + + + + + + + + Average + + All learners trained independently + Reduces variance + + + Boosting (sequential) + + + + Weak 1 + + + + errors + + + + Weak 2 + + + + errors + + + + Weak 3 + + + + + + + Weighted Sum + + + w₁ + w₂ + w₃ + + Each learner fixes previous errors + Reduces bias + + + + Random Forest = bagging of trees | AdaBoost, GBM = boosting variants + \ No newline at end of file diff --git a/images/faster_rcnn.svg b/images/faster_rcnn.svg new file mode 100644 index 0000000..9467b36 --- /dev/null +++ b/images/faster_rcnn.svg @@ -0,0 +1,81 @@ + + + + + + + Faster R-CNN Pipeline + + + + + + Input + Image + + + + + + + Backbone + (ResNet) + + + + + + + Shared + Feature + Map + + + + + + + RPN + Region Proposals + + + + proposals + + + + + + + RoI Pooling + fixed-size features + + + + + + + + Classifier + class labels + + + + Regressor + box offsets + + + + + + + + + + detections + + + + Two-stage: RPN proposes candidate regions, then each region is classified and refined. + Backbone runs once (shared features). RPN and detection heads run on the shared feature map. + \ No newline at end of file diff --git a/images/fcos_detection.svg b/images/fcos_detection.svg new file mode 100644 index 0000000..485c329 --- /dev/null +++ b/images/fcos_detection.svg @@ -0,0 +1,78 @@ + + + + + + + FCOS: Anchor-Free Per-Pixel Detection + + + Feature Map + + + + + + + + + + + + + + Ground Truth Box + + + + + + + + l + + + t + + + r + + + b + + + + + + + + Classification Head + C class scores per pixel + + + + Regression Head + (l, t, r, b) distances per pixel + + + + Centerness Head + suppresses low-quality detections + + + + Centerness score: + √(min(l,r)/max(l,r)) + × √(min(t,b)/max(t,b)) + + + + Multi-scale with FPN: + Small objects → high-res levels + Large objects → low-res levels + + + + No anchors needed. Every feature map location inside a ground truth box is a positive training sample. + Centerness down-weights predictions far from object centres, improving precision without NMS changes. + diff --git a/images/feature_store.svg b/images/feature_store.svg new file mode 100644 index 0000000..be7bd3e --- /dev/null +++ b/images/feature_store.svg @@ -0,0 +1,53 @@ + + + + + + + Feature Store: Same Features for Training and Serving + + + + Raw Data + events, logs, DBs + + + + + + Feature + Engineering + same code for both paths + + + + + + + + Offline Store + data warehouse (batch) + + + + Online Store + Redis/DynamoDB (<5ms) + + + + + Training + batch read + + + + + Serving + real-time lookup + + + + Key: Same feature computation → No training-serving skew + Without a feature store: training computes user_age one way, serving computes it differently + → model sees different features at inference → silently wrong predictions + \ No newline at end of file diff --git a/images/five_geometric_domains.svg b/images/five_geometric_domains.svg new file mode 100644 index 0000000..808abb6 --- /dev/null +++ b/images/five_geometric_domains.svg @@ -0,0 +1,83 @@ + + The Five Geometric Domains + + + + Grids + + + + + + + + + + + + + + + + images, audio + symmetry: translation + arch: CNN + + + + Sets + + + + + + + point clouds + symmetry: permutation + arch: DeepSets + + + + Sequences + + + + + + + + + + text, time series + symmetry: (time) translation + arch: RNN, Transformer + + + + Graphs + + + + + + + + + + + molecules, social + symmetry: node perm. + arch: GNN + + + + Manifolds + + + + surfaces, meshes + symmetry: gauge/diffeo. + arch: mesh CNN + + Every neural architecture exploits the symmetry of one of these five domains + \ No newline at end of file diff --git a/images/flamingo_architecture.svg b/images/flamingo_architecture.svg new file mode 100644 index 0000000..0602d24 --- /dev/null +++ b/images/flamingo_architecture.svg @@ -0,0 +1,126 @@ + + + + + + + + + + + + + + + Flamingo Architecture + + + Input: Image₁ Text₁ Image₂ Text₂ ... (interleaved sequence) + + + + Image₁ + + + Text₁ + + + Image₂ + + + Text₂ + + + + Vision Encoder + (NFNet / ViT) + + + + Frozen + + + + + + + Perceiver + Resampler + + + + + + + + + + fixed-length + visual tokens + + + + + LM Block 1 + + + + + Gated Cross-Attention + + + + + + + + + + LM Block 2 + + + + + + + + Gated Cross-Attention + + + + + + + LM Block 3 + + + + + LM Block 4 + + + + + + + + + + + + + Generated Text + + + + Frozen LM blocks + + + Trained xattn layers + + + Trained resampler + + + = Frozen weights + \ No newline at end of file diff --git a/images/focal_loss.svg b/images/focal_loss.svg new file mode 100644 index 0000000..d6a5b15 --- /dev/null +++ b/images/focal_loss.svg @@ -0,0 +1,69 @@ + + Focal Loss: Down-Weighting Easy Examples + + + + + + + Loss + + Probability of correct class (p_t) + + + 0 + 0.25 + 0.5 + 0.75 + 1.0 + + + 0 + 2.5 + 5 + + + + + γ = 0 (CE) + + + + γ = 1 + + + + γ = 2 + + + + γ = 5 + + + + Easy examples + (high p_t): loss → 0 + + + + FL(p_t) = −α_t(1−p_t)^γ log(p_t) + + • γ = 0: standard cross-entropy + • γ = 2: well-classified examples + contribute ~100× less loss + • Focuses training on hard examples + + + + The core one-stage detector problem: + ~100,000 anchors, but only ~10 are objects. + Easy negatives overwhelm the loss → focal loss. + + + + Higher γ → more aggressive down-weighting of easy examples. RetinaNet uses γ=2, α=0.25. + diff --git a/images/fpn_pyramid.svg b/images/fpn_pyramid.svg new file mode 100644 index 0000000..4f96899 --- /dev/null +++ b/images/fpn_pyramid.svg @@ -0,0 +1,96 @@ + + + + + + + + + + + + + Feature Pyramid Network (FPN) + + + Bottom-Up + (backbone) + Lateral + (1×1 conv) + Top-Down + (+ upsample) + + + + + C2 (H/4 × W/4) + + + + + + + C3 (H/8 × W/8) + + + + + + + C4 (H/16) + + + + + + + C5 (H/32) + + + + + + + + + + + P5 + + + + + + + 2× up + + + P4 + + + + + + + 2× up + + + P3 + + + + + + + 2× up + + + P2 + + + ← small objects + ← medium + ← medium + ← large objects + + + + Each pyramid level has strong semantics (from top-down path) and good spatial resolution (from bottom-up). + diff --git a/images/fraud_detection_pipeline.svg b/images/fraud_detection_pipeline.svg new file mode 100644 index 0000000..25dbc81 --- /dev/null +++ b/images/fraud_detection_pipeline.svg @@ -0,0 +1,64 @@ + + + + + + + Fraud Detection: Real-Time Pipeline + + + + Transaction + $249, NYC + + + + + + Feature Pipeline + user history + velocity (5 txns/5min) + device, location delta + ~10ms + + + + + + ML Model + P(fraud) = 0.82 + ~5ms + + + + + + Decision + Engine + >0.9: block + 0.3-0.9: review + + + + + + + + Allow + + + Review + + + Block + + + + Human + + + labels → retrain + + Total latency: ~15ms (well within 100ms payment processing window) + Human review creates feedback loop for continuous model improvement + \ No newline at end of file diff --git a/images/fusion_strategies.svg b/images/fusion_strategies.svg new file mode 100644 index 0000000..7698360 --- /dev/null +++ b/images/fusion_strategies.svg @@ -0,0 +1,124 @@ + + + + + + + + + + + + + + + + + + Fusion Strategies for Multimodal Learning + + + + + + + Early Fusion + + + + Image + + + Text + + + + Concatenate + + + + + + + + Shared + Model + + + + + + + Output + + + + + + Middle Fusion + + + + Image + Encoder + + + Text + Encoder + + + + Cross- + Attention + + + + + + + + Output + + + + + + Late Fusion + + + + Image + Encoder + + + Image + Predictor + + + + + + Text + Encoder + + + Text + Predictor + + + + + + Combine + + + + + + + + Output + + + + \ No newline at end of file diff --git a/images/gat_attention_weights.svg b/images/gat_attention_weights.svg new file mode 100644 index 0000000..aa5268f --- /dev/null +++ b/images/gat_attention_weights.svg @@ -0,0 +1,52 @@ + + + + + + + GCN (Fixed Weights) vs GAT (Learned Attention) + + + + + GCN: Fixed by Degree + + + i + + + + + + + + + + + + + + + all neighbours contribute equally + + + GAT: Learned Attention + + + i + + + + + + + + + + + 0.6 + 0.1 + 0.3 + + important neighbours get more weight + \ No newline at end of file diff --git a/images/gaussian_elimination.svg b/images/gaussian_elimination.svg new file mode 100644 index 0000000..f6df4bd --- /dev/null +++ b/images/gaussian_elimination.svg @@ -0,0 +1,43 @@ + + + + + 2 + 1 + 5 + 4 + 3 + 7 + 6 + 5 + 9 + + + + row ops + + + + + 2 + 1 + 5 + 0 + 1 + -3 + 0 + 0 + -2 + + + + solve ↑ + + + + + x₁ = ? + x₂ = ? + x₃ = 1 + back substitution: solve from bottom up + \ No newline at end of file diff --git a/images/generation_evaluation_metrics.svg b/images/generation_evaluation_metrics.svg new file mode 100644 index 0000000..ae8a410 --- /dev/null +++ b/images/generation_evaluation_metrics.svg @@ -0,0 +1,135 @@ + + + + + + + + + + + + Evaluation Metrics for Generative Models + + + + + + + FID + Frechet Inception Distance + + + + Real + Distribution + + + + Generated + Distribution + + + + + + d_F + + + Compares mean + covariance + of Inception features + + + + Lower FID = better quality + diversity + + + IS + Inception Score + + + + Generated + Images + + + + + + + + + + Quality + p(y|x) is peaked + + + + + confident class + + + + Diversity + p(y) is uniform + + + + + + + + + + + many classes + + + + Higher IS = better (quality x diversity) + + + CLIPScore + Text-Image Alignment + + + + Shared CLIP Embedding Space + + + + Image + embed + + + + Text + embed + + + + + + cos(v, w) + + + Measures semantic alignment + between generated image and prompt + + + + Higher CLIPScore = better alignment + + + + + Gen. Image + + + + Text Prompt + + + + + \ No newline at end of file diff --git a/images/gpu_cluster_topology.svg b/images/gpu_cluster_topology.svg new file mode 100644 index 0000000..c9fa2a1 --- /dev/null +++ b/images/gpu_cluster_topology.svg @@ -0,0 +1,109 @@ + + GPU Cluster: Node Architecture and Network Topology + + + + GPU Server Node + + + + H100 + + + H100 + + + H100 + + + H100 + + + H100 + + + H100 + + + H100 + + + + H100 + + + + + + + + + + NVLink: 900 GB/s per GPU (intra-node) + 8 × InfiniBand ports: 400 Gb/s each (inter-node) + + + Fat Tree Network (Cluster) + + + + + + + spine switches + + + + + + + + + leaf switches + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + GPU nodes (8 GPUs each) + + + + Scale Examples + Research: 64 GPUs (8 nodes) + fine-tune 7B models + Production: 256-1024 GPUs + train 70B models + Frontier: 16,000+ GPUs + Llama 3, GPT-4 class + MTBF at 16K GPUs: ~hours. Checkpointing + elastic training = survival. + \ No newline at end of file diff --git a/images/grad_cam.svg b/images/grad_cam.svg new file mode 100644 index 0000000..08e3dbd --- /dev/null +++ b/images/grad_cam.svg @@ -0,0 +1,79 @@ + + + + + + + Grad-CAM: Gradient-weighted Class Activation Mapping + + + Input Image + + + + + + + + + + + + + CNN + backbone + + + + + + Last Conv Maps + + + + A^k + + + ∂y^c / ∂A^k + + + + + y^c + = "dog" + + + + + + Weights + + α_k + GAP(∇) + + + + + + + ReLU(Σ α_k A^k) + Grad-CAM + heatmap + + + + + + Overlay + + + + + + + + + + Grad-CAM backpropagates the class score to the last conv layer, + computes per-channel importance (α_k), and highlights the most relevant image regions. + \ No newline at end of file diff --git a/images/gradient_contour.svg b/images/gradient_contour.svg new file mode 100644 index 0000000..066f38c --- /dev/null +++ b/images/gradient_contour.svg @@ -0,0 +1,46 @@ + + + + + + + + Gradient vectors on contour lines + + + + c=1 + + c=2 + + c=3 + + + + + + + + + + + + + + + + min + + + + contour lines + + ∇f (gradient) + + gradient is always + ⊥ to contour lines + -∇f = downhill + (gradient descent) + + for f(x,y) = x² + y², contours are circles, gradients point radially out + diff --git a/images/gradient_descent_landscape.svg b/images/gradient_descent_landscape.svg new file mode 100644 index 0000000..2e85eb8 --- /dev/null +++ b/images/gradient_descent_landscape.svg @@ -0,0 +1,50 @@ + + + + + + + Gradient Descent: Rolling Downhill on the Loss Surface + + + + + parameter w + Loss L(w) + + + + + + + global min + + + + local min + + + + + large lr + big step + + + + + + + + good lr + + + + + + small lr + (slow, stuck) + + + + w ← w - η · dL/dw (η = learning rate) + \ No newline at end of file diff --git a/images/graph_adjacency_matrix.svg b/images/graph_adjacency_matrix.svg new file mode 100644 index 0000000..fc9a62a --- /dev/null +++ b/images/graph_adjacency_matrix.svg @@ -0,0 +1,57 @@ + + A Graph and its Adjacency Matrix + + + + 1 + + + 2 + + + 3 + + + + + + + + = + + + A = + + + 1 + 2 + 3 + + + 1 + 2 + 3 + + + [ + ] + + + 0 + 1 + 1 + + 1 + 0 + 1 + + 1 + 1 + 0 + + + A_ij = 1 if edge + A_ij = 0 otherwise + Symmetric (undirected) + Diagonal = 0 (no self-loops) + \ No newline at end of file diff --git a/images/graph_laplacian_smoothness.svg b/images/graph_laplacian_smoothness.svg new file mode 100644 index 0000000..23cd395 --- /dev/null +++ b/images/graph_laplacian_smoothness.svg @@ -0,0 +1,49 @@ + + Graph Laplacian: Measuring Signal Smoothness + + + + + Smooth Signal (low x^T L x) + + + 0.9 + + + 1.0 + + + 0.8 + + + 0.9 + + + + + + + neighbours have similar values + + + Non-Smooth Signal (high x^T L x) + + + 0.9 + + + -0.8 + + + 0.7 + + + -0.5 + + + + + + + neighbours have very different values + \ No newline at end of file diff --git a/images/grounding_coordinate_tokens.svg b/images/grounding_coordinate_tokens.svg new file mode 100644 index 0000000..2304be6 --- /dev/null +++ b/images/grounding_coordinate_tokens.svg @@ -0,0 +1,109 @@ + + + + + + + + + + + + Grounding via Coordinate Tokenisation + + + + + + + + + + + + + dog + + + + bench + + + (x1,y1) + (x2,y2) + + Image with detected objects + + + + + + Generated Token Sequence + + + + A + + + dog + + + <x1> + + + <y1> + + + <x2> + + + <y2> + + + + sitting + + + on + + + a + + + bench + + + <x1> + + + <y1> + + + <x2> + + + <y2> + + + + + + + + + + Text tokens + + + Coordinate tokens + + + Entity tokens + + + Coordinates are discretised into special vocabulary tokens (e.g., bins of 0-999) + enabling grounding as a natural text generation task + + + Models: Kosmos-2, Shikra, Qwen-VL, Ferret + \ No newline at end of file diff --git a/images/hifi_gan_generator.svg b/images/hifi_gan_generator.svg new file mode 100644 index 0000000..05e5713 --- /dev/null +++ b/images/hifi_gan_generator.svg @@ -0,0 +1,127 @@ + + + + + + + + + HiFi-GAN Generator Structure + + + + Mel + Spec + 80 ch + + + + + + + TransConv + Upsample + 8x + + + + MRF + k=3,7,11 + + + + + + + TransConv + Upsample + 8x + + + + MRF + k=3,7,11 + d=1,3,5 + + + + + + + TransConv + Upsample + 2x + + + + MRF + + + k3 + + k7 + + k11 + sum outputs + + + + + + + TransConv + Upsample + 2x + + + + MRF + + k3 + + k7 + + k11 + sum outputs + + + + + + + Conv1d + k=7 + + tanh + + + + + + + Waveform + 24 kHz audio + 1-channel + + + + + + + Multi-Receptive Field Fusion (MRF): + Parallel ResBlocks with different kernel sizes (3, 7, 11) and + dilation rates (1, 3, 5). Outputs summed to capture patterns at + multiple temporal scales simultaneously. + + + + Transposed Convolution Upsampling: + Total upsample factor = 8 x 8 x 2 x 2 = 256 + Matches hop_size=256 of mel transform. + Each stage: LeakyReLU activation. + + + + Note: + Multi-period + multi-scale discriminators train the generator adversarially. MPD captures periodic structures + (pitch harmonics) while MSD evaluates audio quality at different time resolutions. Combined with mel loss for stability. + \ No newline at end of file diff --git a/images/hmm_structure.svg b/images/hmm_structure.svg new file mode 100644 index 0000000..42a4bb2 --- /dev/null +++ b/images/hmm_structure.svg @@ -0,0 +1,68 @@ + + + + + + + + + + + Hidden Markov Model Structure + + + Hidden states: + + + z₁ + + + z₂ + + + z₃ + + ··· + + + + P(z₂|z₁) + + + P(z₃|z₂) + + + + + + + + + P(x₁|z₁) + P(x₂|z₂) + P(x₃|z₃) + + + Observations: + + + x₁ + + + x₂ + + + x₃ + + ··· + + + Hidden states are not + directly observed + + Each hidden state emits + an observable signal + + Markov property: + z_t depends only on z_{t-1} + \ No newline at end of file diff --git a/images/human_vectors.svg b/images/human_vectors.svg new file mode 100644 index 0000000..9c7475a --- /dev/null +++ b/images/human_vectors.svg @@ -0,0 +1,45 @@ + + + + + + + + + height + + + weight + + + age + + + + + + + + + Alice (165, 60, 28) + + + + + + + + Carol (170, 65, 30) + + + + + + + + Bob (185, 85, 45) + + + + O + diff --git a/images/hypothesis_test.svg b/images/hypothesis_test.svg new file mode 100644 index 0000000..9f242ac --- /dev/null +++ b/images/hypothesis_test.svg @@ -0,0 +1,48 @@ + + + + + + + + + + μ₀ (null value) + + + + + + + + + + -z_crit + + +z_crit + + + Reject H₀ + Reject H₀ + + + Fail to reject H₀ + + + + + test statistic + + + + + p-value + (area in tail) + + + α/2 + α/2 + + + Two-Tailed Hypothesis Test + \ No newline at end of file diff --git a/images/ieee754_float.svg b/images/ieee754_float.svg new file mode 100644 index 0000000..39bd7f4 --- /dev/null +++ b/images/ieee754_float.svg @@ -0,0 +1,35 @@ + + IEEE 754 Float32 Bit Layout + + + + + S + 1 bit + sign + + + + Exponent + 8 bits + biased by 127 + + + + Mantissa (fraction) + 23 bits + implicit leading 1 + + + 31 + 30 + 23 + 22 + 0 + + + value = (−1)^S × 1.mantissa × 2^(exponent − 127) + + + float16: 1+5+10 = 16 bits | bfloat16: 1+8+7 = 16 bits | float64: 1+11+52 = 64 bits + \ No newline at end of file diff --git a/images/image_histogram.svg b/images/image_histogram.svg new file mode 100644 index 0000000..965c537 --- /dev/null +++ b/images/image_histogram.svg @@ -0,0 +1,68 @@ + + Image Intensity Histograms + + + Dark Image + + + + + + + + Histogram + + + + + + + + + + 0 + 255 + + + Bright Image + + + + + + + + Histogram + + + + + + + + + + + 0 + 255 + + + Skewed left (low intensities) + Skewed right (high intensities) + + + After + Equalisation + + + + + + + + + + + + Spread uniformly + \ No newline at end of file diff --git a/images/image_pyramid.svg b/images/image_pyramid.svg new file mode 100644 index 0000000..b2744ba --- /dev/null +++ b/images/image_pyramid.svg @@ -0,0 +1,65 @@ + + Gaussian Image Pyramid + + + + Level 0 + 256 × 256 + (original) + + + + + + + + + + + + blur + + ↓ 2× + + + + Level 1 + 128 × 128 + + + + + + + + + blur + + ↓ 2× + + + + Level 2 + 64 × 64 + + + + + + + blur + + ↓ 2× + + + + Level 3 + 32 × 32 + + + + + + L4 + + + + Each level: Gaussian blur then subsample by 2 — captures coarser features at each scale + \ No newline at end of file diff --git a/images/image_tokenisation_overview.svg b/images/image_tokenisation_overview.svg new file mode 100644 index 0000000..6938f08 --- /dev/null +++ b/images/image_tokenisation_overview.svg @@ -0,0 +1,100 @@ + + + Image Tokenisation: From Pixels to Discrete Tokens + + + + + + + + + + + + + + Input Image + (continuous pixels) + + + + + + + Encoder + + + + + + + + + + 0.73 + + -0.21 + + 0.45 + + 1.12 + + -0.58 + + 0.89 + + -0.34 + + 0.67 + + -0.91 + + Latent Vectors + (continuous) + + + + + + + Quantise + + + + + + codebook + + + + + + + + + + 42 + + 7 + + 156 + + 89 + + 3 + + 211 + + 64 + + 501 + + 18 + + Discrete Tokens + (integer indices) + + + Each token index maps to a learned codebook vector + \ No newline at end of file diff --git a/images/image_tokeniser_comparison.svg b/images/image_tokeniser_comparison.svg new file mode 100644 index 0000000..a0959bb --- /dev/null +++ b/images/image_tokeniser_comparison.svg @@ -0,0 +1,115 @@ + + + Image Tokeniser Architectures + + + + + + + + + + dVAE (DALL-E) + + + + Input + + + + + Encoder + + + + + + Gumbel-Softmax + (differentiable sampling) + + + + + + Soft token probs + + + + + + Decoder + + + soft / differentiable + + + + VQ-GAN + + + + Input + + + + + Encoder + + + + + + NN Codebook Lookup + (nearest-neighbour, hard) + + + + + + Discrete tokens + + + + + + Decoder + Discrim. + + + hard quantisation + adversarial + + + + FSQ + + + + Input + + + + + Encoder + + + + + + Round to Fixed Levels + (e.g. [-2, -1, 0, 1, 2] per dim) + + + + + + Discrete tokens + + + + + + Decoder + + + no codebook needed + \ No newline at end of file diff --git a/images/inception_module.svg b/images/inception_module.svg new file mode 100644 index 0000000..f9702b9 --- /dev/null +++ b/images/inception_module.svg @@ -0,0 +1,67 @@ + + + + + + + Inception Module: Multi-Scale Parallel Convolutions + + + + Previous Layer + + + + + + + + + + + + + 1×1 conv + + + + 1×1 conv + + + 3×3 conv + + + + 1×1 conv + + + 5×5 conv + + + + 3×3 pool + + + 1×1 conv + + + + + + + + + + + + + + + + + Concatenate (channel dim) + + + + 1×1 bottlenecks reduce channels before expensive 3×3/5×5 convolutions, cutting computation dramatically. + diff --git a/images/instructpix2pix_pipeline.svg b/images/instructpix2pix_pipeline.svg new file mode 100644 index 0000000..53a1103 --- /dev/null +++ b/images/instructpix2pix_pipeline.svg @@ -0,0 +1,94 @@ + + + + + + + + + + + + + + + InstructPix2Pix: Instruction-Following Image Editing + + + + + + + + + Input Image + (daytime scene) + + + + "Make it sunset" + Text Instruction + + + + + + + + Modified Diffusion Model + Conditioned on both image + text + + + + Image concat + + + Text cross-attention + + + CFG: e = e_u + s_i(e_i - e_u) + s_t(e_it - e_i) + + + Guidance Scales + + + Image s_i + + + + 1.5 + + + Text s_t + + + + 7.5 + + + + + + + + + + + + + + Edited Image + (sunset scene) + + + + Structure Preserved + Same layout, buildings, composition + Only lighting/color changed per instruction + + + + + + No per-example fine-tuning required -- single forward pass at inference + \ No newline at end of file diff --git a/images/invariance_vs_equivariance.svg b/images/invariance_vs_equivariance.svg new file mode 100644 index 0000000..ca6851f --- /dev/null +++ b/images/invariance_vs_equivariance.svg @@ -0,0 +1,85 @@ + + + + + + + Invariance vs Equivariance + + + + + + Invariance + output stays the same + + + + 🐱 + input + + + + + f + + + + + "cat" + + + + 🐱 + shifted + + + + + f + + + + + "cat" + + + } + same + + + Equivariance + output transforms correspondingly + + + + + input + + + + + f + + + + + + features (top-left) + + + + + shifted + + + + + f + + + + + + features (shifted too) + \ No newline at end of file diff --git a/images/joint_embedding_space.svg b/images/joint_embedding_space.svg new file mode 100644 index 0000000..32d916b --- /dev/null +++ b/images/joint_embedding_space.svg @@ -0,0 +1,81 @@ + + + + + + + + + + + + Joint Embedding Space + + + + Image + Encoder + + + + IMG + + + + + Text + Encoder + + + + "text" + + + + + + + + + + + + + Dimension 1 + Dimension 2 + + + + + + + + + + + + + + + + + + + + + + + high similarity + + + + low similarity + + + + + + Image embedding + + Text embedding + \ No newline at end of file diff --git a/images/kl_divergence.svg b/images/kl_divergence.svg new file mode 100644 index 0000000..5ba4af3 --- /dev/null +++ b/images/kl_divergence.svg @@ -0,0 +1,35 @@ + + KL Divergence: Distance Between Distributions + + + + + x + Density + + + + + + + + + + + gap = KL + + + p (true) + q (approx) + + + + D_KL(p ∥ q) = Σ p(x) · log( p(x) / q(x) ) ≥ 0 + + + D_KL = 0 only if + p = q exactly + + Not symmetric: + D_KL(p∥q) ≠ D_KL(q∥p) + \ No newline at end of file diff --git a/images/kmeans_clustering.svg b/images/kmeans_clustering.svg new file mode 100644 index 0000000..0b56871 --- /dev/null +++ b/images/kmeans_clustering.svg @@ -0,0 +1,67 @@ + + K-Means Clustering (K=3) + + + + + Feature 1 + Feature 2 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + C + + + C + + + C + + + + = data point + + C + = centroid + dashed = cluster boundary + \ No newline at end of file diff --git a/images/lagrange_multiplier.svg b/images/lagrange_multiplier.svg new file mode 100644 index 0000000..0447493 --- /dev/null +++ b/images/lagrange_multiplier.svg @@ -0,0 +1,47 @@ + + + + + + + + + + + Lagrange multipliers: gradients are parallel at optimum + + + + + + + f = max + f = c₁ + f = c₂ + + + + g(x,y) = c + (constraint) + + + + optimum + + + + ∇f + + + + ∇g + + + + + + not optimal: + gradients not ∥ + + at the constrained optimum, ∇f = λ∇g (parallel gradients) + diff --git a/images/lidar_time_of_flight.svg b/images/lidar_time_of_flight.svg new file mode 100644 index 0000000..47ae8a6 --- /dev/null +++ b/images/lidar_time_of_flight.svg @@ -0,0 +1,39 @@ + + + + + + + + + + LiDAR: Time-of-Flight Distance Measurement + + + + LiDAR + sensor + + + + emitted laser pulse → + + + + object + + + + ← reflected pulse returns + + + + + + d = c · Δt / 2 + + + Δt = round-trip time + c = speed of light + ÷2 for round trip + \ No newline at end of file diff --git a/images/line_equation.svg b/images/line_equation.svg new file mode 100644 index 0000000..605320b --- /dev/null +++ b/images/line_equation.svg @@ -0,0 +1,44 @@ + + + + + + + + + + + x + y + + + + y = mx + b + + + + + b + (intercept) + + + + + + + + Δx + + + + Δy + + + m = Δy/Δx + (slope = rate of change) + + + 0 + + b = starting value, m = how fast y changes per unit x + diff --git a/images/linear_regression_fit.svg b/images/linear_regression_fit.svg new file mode 100644 index 0000000..368a2e0 --- /dev/null +++ b/images/linear_regression_fit.svg @@ -0,0 +1,52 @@ + + Linear Regression: Best-Fit Line and Residuals + + + + + x + y + + + + y = wx + b + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + data points + + best-fit line + + residuals (errors) + \ No newline at end of file diff --git a/images/llava_architecture.svg b/images/llava_architecture.svg new file mode 100644 index 0000000..990bf97 --- /dev/null +++ b/images/llava_architecture.svg @@ -0,0 +1,96 @@ + + + + + + + + + LLaVA Architecture + + + + + + + + Input Image + + + + + + + CLIP ViT + (Vision Encoder) + + + + + + + + + + + + patch features (N tokens) + + + + + + + Linear Proj. + map to LLM dim + + + + + + + "Describe this image" + Text prompt + + + + + + Combined sequence + + + + + + + + + + | + + + + + + + + visual + text + + + + + + + LLM + (Vicuna / LLaMA) + + + + + + Response + + + Visual tokens are prepended to text tokens and processed jointly by the LLM + \ No newline at end of file diff --git a/images/lms_adaptive_filter.svg b/images/lms_adaptive_filter.svg new file mode 100644 index 0000000..6eee78c --- /dev/null +++ b/images/lms_adaptive_filter.svg @@ -0,0 +1,84 @@ + + + + + + + + + + + + LMS Adaptive Filter + + + + Input x[n] + (reference) + + + + + + Adaptive Filter + FIR: weights w[n] + y[n] = w^T x[n] + + + + y[n] + + + + Output + y[n] + + + + + + + Desired d[n] + + + + + + + + - + + + - + + + + + + + + + Error e[n] + = d[n] - y[n] + + + + + + + + Weight Update Feedback + + + + Weight Update Rule: + w[n+1] = w[n] + mu * e[n] * x[n] + + + mu = step size + (learning rate) + + + + LMS adapts filter weights to minimise error in real time. + Used in echo cancellation (x = far-end signal, d = mic signal) and active noise control. The step size + mu trades convergence speed for stability. Variants: NLMS normalises by input power for robustness. + \ No newline at end of file diff --git a/images/load_balancer.svg b/images/load_balancer.svg new file mode 100644 index 0000000..5655757 --- /dev/null +++ b/images/load_balancer.svg @@ -0,0 +1,49 @@ + + + + + + + Load Balancer: Distributing Requests Across Servers + + + + Clients + + + + + + + + Load + Balancer + + + + + + + + + + Server 1 + + + Server 2 + + + Server 3 + + + Server 4 + + + GPU 0 + GPU 1 + GPU 2 + GPU 3 + + + Algorithms: round robin, least connections, consistent hashing, weighted + \ No newline at end of file diff --git a/images/logo.png b/images/logo.png new file mode 100644 index 0000000000000000000000000000000000000000..86e6ae2cd808d21cf91100fbb25a1c76630c9f1c GIT binary patch literal 802543 zcmbrlc|4R~A2&X-Wu2%jSwG1`~Lm${Qh|^FXNc&oa;L0b3Xg|y!Qt8CP0TwjZBO{OiUmU z6YvAtn_?O?F)*;dYH4m{a@p|jjSdiSv77;cJUk!xSi&wy*xtA)!9MYK#Qr@v!q;p6 z@Bifjo=f_*|8x+j@AUtr&;Q@o98S)@2%y6n@GIp5yd0S95uknK?%%fTe%tYHTYJCl z|G?`3(C6xY+sE3{0BAb_ZP~m3({}uyHp0tizkdSIN87{w;r_e!ul*TwJ9}QU0{$KZ zenmmPAWIMobYXw}!2iI7d<_C=Y=c0|$^ZS%@hJ#Y6$Jtb4gdE$sh1!S_z?(H+4JA; z{!0@txDWi#<5+;t11>Hg&_)3W#B&1#I{Y03;0)|8GwKM(;I(_zoU8#X8H(#0NUS$HdIX zwATuP0E=K{-DkxAT&!#dIhdGP*f|ey0UbOJftXmB4zMsEVCCfH=`s@WmOQ%Pv8Eb-S z1{bXy5f9UH^P2l7g^ZjdBBNfnw634Darq?dcWi1pA9$S@AP+3e|2rRM&;b@UCRX-? z9Kh{cz|{9A!+em9osE?l1QOhz4IhY=|2Uhn{$UGm_9G`$;NdA{XDxkR{a8CHdBG~R zyy>8$Z$#RpfRsV+x+>zebfn-(>vL)sogU%+`+4~)G(FzZJ+@| z6+O-4WsFJ=MsCRLI#>n8w1%t&H?A4ZXtQ@QM;Uc;C{OJDOxfw4xMq*58%-m!Yr*4z@9_Vq-9w^L9vU_`UBL07elWcN6b4k&P zG2bL3nzo9j>AbA!g&{`ZA0!U?TU2=C$kTPGHeVo^X#@X3`OkdyGh{-7dRFjPX+_YU zrXofT^wa*JAJCX>A#Ehp4qais20P$+x0UvHnc|f8v*FFnO>CB_#X&=o)IubfkwZJi z?3xQ5XT%(Men*a-Y(DpaXEO1%AwUXi;vVR$D{w!N~t7*6Ni|f4T^Ovcq$IIZSGypeIC)^qen}iJ5kpF2V_o z7FT~>djv1;Yut$X?S4ZaGE-f|Sp+NdYp8&*F+6;ouYs{GZoAYZgNVN&!aDd4GLUfj?ALFDID*4o6_ zMK*ytqt@rV0sR4YT!!FCNydbh&;6%-A2#-JMKpmb?aLji$e~dqrsZAS_cjE~JW}8bE8*X)J4M zzeCJ>c<&fAhLrgyAT)seli8?mSAPIoet_$Kw?jFnyD2j)NQn<6HU8d})9s`YBjw9J zc?3X@&}5ctZAhUH2liRBdh#boUZun{%QxRMp-C z6(EYcxCe5YKKYL5j2^c@t0XBa_2Py#EIf|!&I`d5^Zj%R>1>QzCiUBLL-A7I~Lci5~Yst296$y)e?i zPV}GMuez*O^j8+huJ#kVQdCK{TVR%bR&5xlca8x@NuJWyJ$u(+bP=^4 z-OxdB!o*gd?&T?3#8V!bcT4UBbmL6-K=&s9zR~RO8)KZHWqlY*arhp{vVISg7Srx= zLi{&`Yn2R74N9DhSa$Nj^kLQy?tvObb#oW;s4(t=T_Hx#wua7xh8jRvAwOv0N&4a* z=yw)vR8!|6_0p!$ea~X(XOt~Ik7!}5Ypje@%@I@F&|F1Cw{5wSjCeG}|8Si)uD=}W z?Mq>gk@yQIXzC3Z?KItv9++rVKun4&M9K8Cs5jdj>@SPFM|0&qN{ zne6}aazyavL3*m2s8vupzfeeN4!?)j5Au4};ji0>GIoyN`z zY^LML^j7WMj7QJJsXdUQXhSJ&bA$(J%?LzBgxvEmw9z##g2b6lZqiRt_CWg423sZ* zsJraXx%S5*S|XP`Lt9VHYbyY6V%}&D<{Me``-r5-prvABHoStcZ_-*rCmdAwKo=XD z2`cChN1ijY@4vJgLj*IXR@aWUq`9zfc5$aYfG;-*NFBtxkyk=87=b@%&js?v?~)Gb zPKNh|a*|ikB9pMjGa>uLDOo;By%x%zkOa_WJ+gx$$_K1a!(c{*onExUWr6y++DqV( zWPWs0nW%u?fh6R{)eXP(K8|7E{<*ukEH8hO35BbI5_fW z=viumEf4eu`0!>kHkP6bFr4^x5*IZNYBUOHzV(Y+KQ2Nx(x6Gk6Wi?D@0U5Dlt{Xl zY^eT}*Xs!(MWFkU>#ue&KbqS`Qr4gfdy^c44G zf*LllZx65Tp8BorByuP?aXmuu;Hp3q+nIxf_aFmjae|c%P&(8_du~GQ-Y zv}Q)1(yUr$EsZg{Hg(0mcMp`JYYTWQRTRu-b5-nM3ZSmrj?SYHTYxKuZNK*O*1{fW z*aoqFB>*kMUMP@8i$%$w5el{01Mx8euC8&500R2>)OQ+wei9B>iY<8a*SpJfhTJV#YiOB^Li2gc{O{|)Vf6Q(_CUKytCXQB57$}) z1`A^1bZmM~9(D@1DXJI{6&C=IN`{q&LFoii$FNIy_GQ@WYsT>~KRnr-jM6yXd3wgA z`1CP5+s3X4oQi=(Z-m9^43>}^oYyToIE+VXBynHvIbI7U{ZM~NJrGnRHj38@J`YdR zlxK@${jv$-8uh5z@onlYmy_GVD`k-_#5EzA#uF!wLbqRwpPpVO!J8dfYt=_Q^~6eu zG+JFklFiuq^RSRweB#I&)m(U4kh}&^lBS$9$XH;k;sx3CrHBohLAn9!d!Tjb&U#WC z8ZFpL(?m|GktY52KvX0d1tms|)W(5->Gl@AM@XAG?UYnB+lgoo4c%5s`&9mIm8?jTB+Vy$x;`k%~%A`r0_+!x~;qqSjNCcVDJzg__?kPHKG>t)}His0*i{ zR(THbhF0kLr^2>cYLWCR2)DHpKZfz1fD(`d?}22Z29;JP8Ilw@Ks?J(LXu7hRECfQ zIOX4spbXSLld6`;6W^UIQxBlttXf+pb6n!< zSZ156ovu1?MoyhLMez19Tt#K>fjpNFU3Q0B)7g*rY%&H*eoEmi5(wC*v4A`z2|D0# zF%-;X)Cgn@I!2k`PP(z(uENAH7uf?HZT^?`XKE4D0xC}QsQqYXmW9gsLQhHdt0{BF zF6VFR^QxgVFK5U;^yt+J!df()h*fyp9FX|IU5Jn}TPgg9sCc_U@jEmTNs09IN6-D% z4JD+C6krm`YP#j@8>@}or$T?n-rAN~9m%@1K}N9e{kW zJy0vk-n{moDG5^dDGY41#+FiET2()Bg4%EyO*7{zwI{UDU~A39xY@(lnCXaYY9PP- z$#PeZ+^9uspQd@TJ&15LpQk~WOSCj7hMw`=GlbbuYWqb;`ES7Y`-2h=w@MI|{gyK8{P@8hbJ zNjF15lDJAdX1WSrh*|rs_f{U^D)}UQx}bXHb=)Z(YEnxm9d!r7Z7Ox)WI{61;R<~K z1TQ!4PS8@fOLF3=T5k`&-rM^~YI?wQ!PKfh8j+rp%r-$af3*$qjaEJ76MU~GZONC= z@rGkR61sPIo#U4rLlCmd;H3 zmHA}IPA#QBr=Uqr_azmN5?bfEvO8Gs7?|^hK8jaxc0LbV%u@s184;@?%nb;im@;9k zk6>O@1J;Bie5{8*_*hT=_m-iC?EC8@GEsrt!WFfS%4H+2bKy^#>Zq==31+-GU= zV{RqqgerLNJ{*uO!elI-e7VR!R`;gm*~g^OU&?ispD(8mOmBVSK418KY3W3TcIAbt zMJcSlx8XHElPAtYFmBU2t#p`X*T%4XK=QTUsjcmDZf>Hv74OduM*B&)394yK-cYD` z+9vJi`n=WmR!f=N@Nmsgx3%*5k5U$jFAwXLhrL?{A#P0%k|2}huewQ#5Jgint^D%? zg+d>amfpl6DlM_itJ(dE)lby&9Q&;x-sf&a%LNp6lqz3xX-~DT{*V)q9z9og3VPJb zEcD6m^yQ}?V3$rJU{9~`8!DU1$G$AiD|F)gP+t79RZso+lZcKyttgF3?5U|mg=a|y z(kGo_OISQ_7#M$ko#Of4F=Bw_m1v1e<#|oRt`!MctDE5Jo31i5Tbus0m=#)Gplvg| zCW5lMHit2aGpWxFt;b?}DoI-&1jx9Z)uw9CMuP_430IGqLUZ4(fYzd87-uZd$m7F- z#Ao-7!V@MoMP0*&{mq9I0PY*lOHPG#_)jFavcv*f9wXteD+tM9R68a8;lHAqQ(bs; zO1qVMuVtSJ<-kJd=}OM;tv;z2^hS|Vp5Q{sfw7j$@+LYJ7iO{r2w25CD@4pC#3{K_ zHuY+d&zys%Bc(nWT_=c`bJW-aZ3Hw$!RYrek&qT)m8z9I-PQI%2ru9$kx?m`9yBjKymZz{TC$)U1}I4*9$WJ#=zx1mD? z*HO?ucc(8fy3Bgjw0qE%ks0!<8VKzmHH7dM##LrdFq|x82#}f4E`n}^ntj`KlrY*b zdpgbK+0XPR$cUw4JdeZ!v0sz?=@(!=#wHTT5kwO_?qXTa@L3nO#_VjO!|9%lGo4O< z%vaH9VtCbUJZ>)jdr96&DeDmC2XmMGtbVJtuyPLNxm#X-w34b7xhh5y562u-z<)}( z9apBNCbd=75y0&%T^>Hv(8zm>N}5Z+#>}FrT7iIa0_@e{b}d#q4cdz4hE^cQpe@Xk zJNwpy!6pe%zgVbuc5aG0YIukg{B%~azF;(x z%Ctr1>0wEdD3?2IDF z{fy1fuh~(`??r6A%@tb=bs?qLb$3`D|62|UC<$ADCg;L^9GT8(DEDj4&#Mkr-VMHtY^HmiG}U>M_uj#n~rQhskOTKS@YI@a}{Gy7v-(yr)HpBVC$=IRj#%wHCIQrq5rV~OW>6>Bo zi5V_%90?&WHdA@R&}o{?o0)qisq&fAtI*T=Iwp$&;;A4xnSQ>7g?4e)esd2YR*Z^a z6?wawWEbKrecam5qddItB^k&YbclX4>Bz^>Z`RYe_wBRSx3)k(z&Omk8T9u=1D=@w zblCHKgIBU|@RR`|fz6I205Q%VALw?O_UmM7%ydF9cZ5l}vJ9kFPZ>MrsVsC{vg`Uo zKH21y%9-$X%cK3clK3aX%SzR=^PBC*EUm}o6kuKukrZT4>_)q}gxavt*<_{s3f??J zxQUh&NS_pTP!fSV*vItk9xPsjN@6^@@W;PwS#O(6{Q})BOD2*XOQA9q;>!OzYr@1m zkg2&B`|k8$$1LQQ3k{uy2k4%#^BIuei|n+obZarS!?5%sz(6nQ_Z6Z{qS4s_SkXpE6b85s=H*6!Ne@M*iNT&sd zcwU&Xc-{UQ9+`6-4`;ThS@OKItYTu4Si&7Kv7#q1m`L8S7nGRs$&TU6O(dMROK8?^ z&sE7&kj|4jEo~lQS#lL+Mz7as=}WvEd@m+1F!lV48ibrC=S%$Y<3=QrL6n~$d9(fQ zPpRC6gc(S*7V|<%%fF(lRRsi7X%ouxPAY(3Z4YL!cyDz;mT;4|BNzx>k5I68S2txp@zyeSPt9->xRs-%aj>h^Ae>v{lTXZ#e@64b6GJMfkBz!cnI>f8+few%Z7CA7W+vVDF$Hoah z*rCWf!&B!OL|y+aeo48o?B?#%hTzv)vYWz==aWk_QxBxo0p}(>vD6sf4X+$l1$+%bntB2}g`|4bP2^nt656uf^zbLGs)#78L0sbS^s7=%izUmT**me1-r7r7$pi_(?cN zColSu00SSU%DFYWu4OXx6kD!%v_f4DCfT&q^dXK zd`0$Qwbf};4l`=uqZeT^h<3{t4@+pR?7eJF&*o4DKYj;t!wVeq{P^RGeuFAUwim*0 zq+IfBIBaj_pS0tRc&X9HE%1xKSMMfA2!E=Y1aZ`VB%uqWh_8&jjZyIfIPDw7;>!hj z60(%}4K-k!T64Ip9sN@Z5M~z@@OU@czLEN`wW|bNq1yGf))^Hr9nzlcRS*NWKq6M(M3}!DP)b7CrAeA{pG^vtnD9%3*EVp)w2&p z$UeBlvY6xy2glT3T7#ifDyczPeuIPiSECbxjESB)6Sky0WaY_HnS3O|iKo|S1dDt9p5XT{cTNp%)sJ)p2?UMW_eo614M*#GXe?Y=S}unw)v{PwMF0XBMwb7K!a; z{xJym&^Qn*S6`3!o)^hN^@GkF^d zXMkX}0~!p2T+U(s#~3=)m(lzFhN|W`5gcS@9kU0@_)jK4Ip-(NR|c>kA5}<~O7qWc zL;VJC`N!BL7;%_bxrvz$TgDm_e9KLU?KVfb^R94wWc-$m)o+z9P7k;&NRHJ!FCFW>0~O>Au&xy!mL6TUjnH{slT_6% zV0nc1>(?;BEi{C-*%aH6524@0L9_#=fv|7FsKDj~Tm4Q@Agb+wV|-lZDS(l~UANP!%*@bVU@o8!d>uO#Idx$BYXN-`WYd zbf_O{2Pr=7=;=>NLW@eumC>yuJyt->p8hq14|fm;9#0vZyK1Xixw7Lfef-zpy~?uB zm2%P(eq>L*Ev@j?kRQyFhZP*>GQPHt3%LPCPg1P5p02jkGaZSJCeR~l_CPQ1?lRbi zlvZ;o*cLF+90*hygN2_%_CNtEfFH$#nN0@59BA;-h95h2yOVw~_l@SpHs(U?(mZ!_ zbur-{DgAwMOIFGxv-zu|MA;Y7bP~(jY^pI7AzTr2)A{u<-secfozXnZ1wR4~nXzf~26n=>t>?s*_LiJc`e3t(GqVBY z9y8wIG$AtF2XHW?ada-Ieu$T12a2RL1tm?=3&>Io2zf+2ss&H&^3?U)7Xsns7!w2~ zh0c3MYCklacE6F~NlBnkj{VT8n8TljEf~-1^em+6iPNKL&L@s#ZI=bLkTB?>W)=5P z4o&ga+s>qCt{=f4+oZ>A$+E=B6W zQrQXx>CP*s`uanrnkZ>?fv_y3C8z|sdC40{X55|MqzD2u6l`S($t7_#Fy`{ee0)S*>qzR_sMijH z4W?7xcKeh|6Le%-;0qzA7px|Xz@-sf-FErB*2_A8O zBat5V9%C=QimW{P9LTy)5iR1?yOI7M{6|w(% zJc9-T`Y9M!IS>}FLz$}p#qx?inlpxzqOm8*&h||fus8aP^?qPnM68ej0mU)xsI4~h z9nd2>j$(!Z`BBY={Os&t^qfwjK9?!Wc2zENz*7W4p` z@+jnXwe~t&2FCT+_~ zXdDen0RNka*ixK$<_$)=K(*;&*VxLX3ZCS6h zEr{eL&6r9ttqwx6QEdksC7wDC5?o|;dzK*BM#7DD%2@hZFo#ut5joFKI{mB2;#lly0J7c z6xrmt&e$AbCfnXa*=vQ749I0@zCBPCE74?D1nW^|ABCmIRM2>oX(a_a0Y(y`c61@VQ5mA#2#{Y2@MeD@`B#GCwG3!L_dPapE$EQa)+)*$kt?)-=E3?c)f zyUcg>W(VfFaLN@o@`b!m@5-C=HXFBX8nbHC(|S?!+Q;toci=1ZgGKm}U06j}!AI`I zA)8^*!574FzV%7&f$AG;fm{NRXM~RxiB5Y@X-?s|P3V!8G#)(T zIv(JPar`>oSpNqdf=W0tT ziw|=*Ss{h5LD1T490&DOm46$C5zx+iAgxLLwdKeK{f87=2lNhmnBOJpJVR!!DcDl8 zeY({V$7@Uv1LE43h(FofEOHsmd@>=A0i z6noh7m(txle`+*e<^cUxbTgxbx(B)}Q>Y}>UJMU6cqk$5V=V@cwjSb2->771;8>8C zvl#JcBD%u$0$X@)rR z`}Gm`yPrl;AAlUq0t1V!6;JG=%LJVYAKd*Y$`J1bBeQ+U8W;Z2qF2v6yMlQoqv-6e zPM%1Lo+aJSzGr1;%(URfTiRuM;q$Ea`-Za%T6u-}dVVY(3svkn5>7G|3GPb12rqX3 zxuI~wlNpuKT4|{0nTgfOfgWKfi5R7Hc`3D>Znm^@s`n09;(WE4V=e%}L4oIwK| zyKjirAsG#OpcY{w7_j%BaL^q*E;?u7DusW7dGQiAS}JX!T}z4nSJNkuu%l@dG=W9+ zIv8JKV9#vV!L;MHDFn}ui24$!cc9?Ai4k3%I2fUYwff34PA}G5xE*gtG6X{Bf$B=? zghi*bgrvwv@!K<4WYucrpr&96c9nH-XBrP-m`t3IKrgn@V7uWF!p2pLGc_deQ_0!K z@p$s!_uH7-8KV)(_mM}AkotHE2I>Wc)E+dxg0q<(+kxZ+kk>}&@+P@8pxx9~6v4kD zw;FPCg!UX!4)h*bfKa^>viV+{!AO2PlhCA_xMOA9kArW3OrE`5QmD8Z^jr4>HeuIT zPAcDE^d$^^2du5vfjpt~I-+%T5!10Euzn({;}Q#^RsU;FSk3GH5~7V(8*JDMXWiy)=fAzKRQ zssA#mZpBYt+Vlv&$xxz>{?K=xoN{HyvZe2uqK23b?sr? zTTPEHPqYHBdikJaiZ8+w%`HyTlq9h&F(RN0$#6gw4DPD}Edy#5MKnmz9D7qMWwR8| zkenwZ_!;B@>KF1I@Hfdwnnul9?U)Re$6Z~--wldH(eof_k4~`A6%0b@lAC=KqAi+J#)EgG2etUwXvb`gUIj{u zKVH7_nTG}rth&V^RZbL1DO`m+nrP#+wi<~v>Lp%kiPIh=v zL?59k&)+eIj&tF5g>YT6-btP4U`Q(il0!Q#2i+o!DqHz6T0>YLsEddVC63I|RYP^> z_RTYF9f~4g{gj7u%8YeJ_drnb3A4cyxiu-r4%|56&kssF2XD7^dKuLn8Y&qGT8a1& zg>Gk`lg^a5^V>G6ndrN%H!6Os?o4a)+b~JL(T%FAv9277#_r1$U`vVQ)qB7ZYL4)J z0=sruA(AFW_G~kR7n#0Tzbv2PUR9$^|{8R19 z>8_14Z+2g00COHA(^-%-bH3W6^ey|gh^PM?@-3)uTw6*SFPOuKhN{4Xmb1t|XvmFb zB4&lKvP#2}M+9|6lJ&=<=l#eZMb=;BqvyyPxBt2s$~pav1)9gB-#QOZstVd2@!L31 ztZ<79o@t7zIsNm9ctKxzvOgDdsSVXDaEg0^RK|6JdBKhpS}khnYO^I2v4Zrd8Ce-& zCs^}`Qi!Xoe&9V&vIf`@3Y$-ucXX+`K^2V~K?ydUJeWDv2W)*bUq%dFP8X6`nvMV9WE6KA!Zg~+`KjFh zgSMY_dX`j?;irejW?lMPvP^@tI^&$93I(4M(KZonv05OtYC-mbIhqw=XPhM=relMte zCQ<~yZj%(%mhk0!#7t_wYPr4HH#M^^X73Ge`8S7da6D7Kmn)tknYt+=9-MQxbxM_g z0lZtKo|4%1zGil1iDJiD^7e*3qua^89x(e;)tG;S{V^P)9}E7vZjovNWX4m#8}|r6 zA}oOuy3`yZ#1QIuo?~IMhR;ngk%&xml5h`x@z-3Rwnye9c<_nO4`aVcE8)m^12}I@#8FBU5lNZ<;!O32tM+3p&9`v01 z^}qGk*|1F+9 z=mW2pB2;JgOa(S@fqM($c0}3GuaP|{u+_XSe9P->HwBx+Sr-%57VxL`xj9%~>uQe{ z@#8((6f}+Q4^)K82Vheg?+PYhxFNaAhp84@UVkcPS^rebF8!^TJ$k#~!&^6$&9xmZ zK=-i357#Ofd^eOxuv=s5{ zJ4K^wZayqOMj7964fHc|4RP&ciQzHY9;m>3tv%=GH|L7OS?+gqO+K9e6yX|W|54?p zja~AE_b<-fxe0PBe_^KS?Y0{?#C=^aeevgzoYuN-%*n53gJR-Oy?p!LVVNW1Z+GarGe5uB6ZXF2&-d zsPj2?ii^Kr!nASoPEHZ&ac4HV7aPp%Gu6(2?jIKkzx_~6{nA8{HuBA+Z27w!$#3_{ zhZGLJ%#&2@PTYB?Z??qEa&zF>L(q%%U+p=HCqus5o#=9VmwO&J=VoyA2dHw)DdTc^ zWWx`9HH#a|q)u^NrmyBV0lp`b&%bWU!NMnoJXoF$e%9=2;O@Yma?cbAsvJ4h?{UG$ z^9HuYLMKzYikI~yt2MP)@&l8EtXZcRz)tstmHW3&D{;f$=-Kwd)2Y-{Jo;bfTm~rBx9M|wyjZ|0DOdf|5md(5F z{?Yf&6W;|vByVgHB$~GGTFF)@DIdF0c(e00#gKI6O(1LHEzTO{ujv6QFG}B5Uo`nX zjLWm~Q-1qCemza6J^e*}ROUL2^=?I=xbyR_Oy|HvXStHR*Coqs*VJB5jw*f0{8IU#+p*NZqM3&g@cdg+MZPY2O=60ZxA;PiG$Wpi z!G9l%(!id6_qI#Z1rxDpl{BAI>(s*Bqxq=)@O`ycamPZQ?SZ1@mTu#nRHAB_Wy+)E z&((d7RZCf_&*z+ak4IT*80AK7ipHEOEK&_@EjU^HZT?l_eBl`S&X4oEE9DsgMm}x)r=DqUV0Q8@zse#su=5RLADmL6B{Yt%rbl* z#@-m3{7pFA&#mH9cn)FY`^4?lPcfM;A)Ukb&MLT>U#xCYilRY0yF9Pi;pfvhFU8$3 z4T;KeIyxj;G!m7Q^z#1eZV8=y>036ecjRoA-nI%PBro;bg#_JI_$s%_`tVYl1xKI7 zC5cn2;Ere;P`Sr^wSbPhv5>#$y|T)tD~snB6%MZ`zhH1olxCae^dXbj<;o!*F@!0o z4lnmp5tZ*d$K26dynhSyzbL$$(7m(t8+>h zaR57yZP~?e3lH$To4}-sBD-_LJ5fJYYkoB|SQ5@dbM2XY2?^hXWA?!%Rzd28rI0Bg zc>(EcncBU$E+ks1T&hEmL@z`NO}hv~Hh_~vChc?k8m=Xo=hOhaX0C4y!Zh|Tz99RA zmc?-BpgD#S31B^*K|fI@tm0e88_oTGsn13pv@e_rUB>@TX(S9FD+p&}V2eu8Dgl3s z;BU_j-#4$5#TKjr`V9!lFtLXqM`LFX;^vN&{-6tGcI^cjG7jL%tC8{0eG2aNe2*AONg~8OcD(Z@+pp zoMDyl_+!00xNaKL+6xoPTCgXIo(Jlxm0vWj(SUkmF2#oa8=IE$Vm`HOWE&PMGc4!Zs#4p;>j7V|32yX+Ov{Us3wTL7#f{L49`X==d= zVHa_t(BsA}_9a=|Qy8axKDxp?Hf*^MzL@mn+h9oKx|g;OeGtK6A~)~tX>WQT=9+=qm-U4TS;GD!^8P*|4Zeu!8P(m#A4tb|Q5?L=o;wWD zKVZqvvE3#ARJ?Pi^&`1&=8IU{kVdxBkzZwEE^T2*Ny>sh3WkKv{sUb2H68Ts`0ANF zO#2ry!;FKkB97D_a;3DB^+!V|wW{s2@+Tb!NHGFHl*l#w7aCCa2O5CjZABvrAn5WA z$Tm<4^+gE^F zLf}<4;w>9ZMv6HJs3Pc8k<6UD#YD!*&CX?s&=%&fo99UJ696{IVZZYEX0jiAbu<%t z%$RXUr4ApvGxx&+inwZ9g|sK{*P5U6X>ErSs-sJydCi}aF5G(Y@!iFK4#@UjJXfbCvUsoM9%F4)0^ulaa?`S+ zP)^?1(mIvXI)kS9DM7O=H@7v&c{5%&@7|+_BL0m>Www*c+)-)Gd<|HSn>(Zl0u?^d+7Mk2-n& z^;^@Qhik08Uc($rZpJ|q*4>}qye`%Xu11)!tsf_N`-YM!++tn|`=#laQ&d*ePka68 zQ{4fWoWIbKcmb*J1BVywf+=020_c3Szfw6=W+c36>kgLoT98u zA_>-V1ApktkJZH9|`2b&22yg)St!W5sA=W zr{SvBfmSYJh#@g6Q8B-UCuXZJS5@ z>c;uJ)n3y|d7blFouX$C&(&Se#4B^XJ!Y7cTOJ(r;p;VFH3{TP$26$;l|$md$y(>p zo+TOoN2P>bP6_XZUt2vD+nUx5=RK{C!>$J&{|y#%&dbn2gwY^R+An}iET6}uTx@cw zOqa@duUE(a;_+vbXUuyF(Em8N;VEkP2eceYhO9SuKA3WNL?XFN)41ALw?| zM%zv+kA|weB&h7WpR5xtWH# ziJd?BRkCd{59^$J@9rpcsJcA3UbA)cL1HvX@CGLzDx^V8?V;-DqBl1rJ#>;5PjZEH zO~_@r`q2D-6r{f7wCWUhmnHO9R5IAMF05ZN#CWxSDn^o#~< zu~G7=4BEAUhr-4IAJt_BoS3P{R95ncQyLB1^#|RfOH4ZGo$^Tw@^A;w7?TKPj zOjhix?U*9x_dtLj5!t{RxsS{<41zqNi#qv=Oa8^bylB(vtuQaja6}U~plN+8UO_MN zq0UeMF%t^e$woDlHDh8a`}nwj@JG$f!2-rt%v!eTK+#64RA^u8o1t)pScMfTyOg1Ntn)r;O%$z;YucqYYdC9M(A~h3&RlwBrVrO1puSW#?B@D z574G&(jf~Xa@~*IT?!=);5SgZ|KfPgFb!?c-!Kr74a^aP{@^fx25QX7m9ixUn~+?< zoqH^`bv+LwBOm$VS4&yvnCW`0Wy+cv*8_?2`;Qz^=C6fQ{z}sU_-BcWWtIC?0kI zO@&LJ_z#+${a-Zui!TkmrquW(^pWTuw#Wtc3ci&S_?5=r<7#50odnrxUh%q^Xkdc@ z^cP$+^VFMYw*~_UrArdvH;>;I>yO$YK!;4K9qT)>wl2w5eivE zsnucio!=A%h-r?@pQH^=hoWsVCI^+*r#t&(udth^#RyKvG5H zpCNMf22*EV!UknXMZ4_G|ARrV{*#Akx%Rv3Ij_B@Uf<@^``>~O3&h-;T3R2L_0GHF zBK)wxc(~+9+kO7+PosVP$z{UHEn>zhPAVD9?F5<5j|#sF%DNVEd?KyOneM%^f3Fbp zLw_KYq%q*4nIT{YcD;VS?ZcGn$qmz%Q013fPqqec(mGnW+C%&_6`yAPm_CL{@RjSL z8<|zw;?D`*-V!<0RPnp_V$^Zvr#7ZhH&0d=FPU7dwJm)qU4ARbPFCy|5*zlB(SD<- z?WlL`*-JcPPF7~G{2i@+Y$L1dcR^-@KU>;v+)65(diz9~>CWpHxe0_!^2yPvPtgMs zj_}trs%HzQX1;%Tn?IfQ4l#P6Ara2E>-tb-`n9#4@h^|hh33~g5~h8GpN%XgEO8kW zdR4y6HB&Z~I`d@IROnS~u3val*@euXg<_9~$-{$Xekt?Bb0sdaHBJ+KMooaQb!A`;`XirLV`kb8VL|HERaELMjyHJ~1{G`Fiex``a<4=0VeMy@K8u z!FC*+S4$r1sl7X`&f#)}^OuWHuaouZ&;2eR4@zYOJqtQ>?1SHm^<#Xp$BOFV_0wRt z{-L0+*}qOV#VDEm2iFD*p`Tr=H>wqwa<<_4>@2QSsb^X2^4$4YvaW`rlhPhwC!l zBww3U?Ht4d&$$`yU${|{Mb8P(R>wQCBrxVslGUfiL$6WoitOL3>T z1b24`?pEBrxVyW%?33>O?)QA(IscQfl8hwlnYUb%XXuMCWMW;o+6qMsUoz7hP3p8_ z95!bR-??1ik1HtpYm6c2FwIJr6`xb<8?8y^+^k&jQb^b>`C0%u?AA9%CXl$oxAndi zAM-4=7AK>mceR?eul3%XEs={!?nPzBEW+PC>F0E|Ew5#tueF9Xyd*o)$>@dKpXs*x ztgTUDFj#-B)miULH~ji>1znIAiD6z+s$5efWTpLEp$`E`*2|v0(ckg4g{ILRxzx;& zGcM1l(#-0Haw&Ls5rRh>f1?jZu67Z30B5{1QkmDHGNVXadb0E==VP92frU}VzMz~w7r(&aI@HoqPV6^By)z^ z6tiipD_Y+qdPZ|c%F11?z!Ark zv3li2p}3iGeyVJa%B`@IM=QR3*-1LBkl-e|%(SVY0)Ne9d1R19cr1m|;Ez}h9A*2F zg@U4qoWk+zrL?4qElnEGj-q3M_eWVzI7PXs9q%cB$n_YiWy(qKIpka=$KgqF66MiB z=6e$b`^Y4QpS3MM`(Fl@NZ`#U+zpcmLcHlQ+nHpp`@2WWh<-4`aUx-_laFy}n&!px zZl&&yUx&W&G(-B{DpR))%l4#xV1Qp-u0{RbkkS}|E+?1H=2ZB}3b&x#!MqUWU8y>% z3P4IN69;#y6O03aE{Pj^VwQ3mpF?hb z?{MeV$wkN3qx8l068wRe@HrAg{mUR>6uXZby-@cJaQn{3+=ztA^_&%?JkAhQzeov8j>O`Bug4v1-KX($f0K2xAT?g>Yt1yA*?rMP$GR>Uk{Q72sYl(+FLqiBie)3+{ud)8oqf)eT|*%8Wv}Q152O zF`3V3#9E8qh*EZA!0lx%r7eY(s3={r3= z09OXo5&!09(6_B@dVEdNv!lu*pHI-=KVvCAvc_sN>g5Fqu&_9$H zZ+!3A$ak0kqhTfGduhq1Arh&V{Hx3AYRw-vHnWQr(u=T4@-OIu4RpIC)E@(NRrXGh zO&1A^e|zdaOD&RQhW3vH6^?)w(1oNo`xVI75x(#leg~sj2H~0V0kwks0wSOI{t;8* zN6)-2+9`Z2y#D2!JQuIs;MU3A7d83LD~5pI4_Lshx(5>cPXQMfdks_$cmzAOsZJq- zx3O;x?*_r0B^xFk64>Uz+d1R+ny9CO%cG#e7*}o*AMkUjmvHBK-l{ybM~UN2mRHJ` z$e8(gdPvzh-j5V10SBz}G@-W!Xj(n%t(?sG|MXxz%E38%S%KSKQR8^YK`&86+DsH!S-lB zdkCe35daBu3DE13%1r?A2hIr1#zDtV>_5CTd*^axxeIxcAV-sKj6Lj%?RsYW z1~ne2DP!Ca*NWrwhTX*ascHHl88oUeW1q%#=evF3w!j}1Vt}vX$~ZbPSGS8;A$8`5CV&fB83;ZO(8zc`q|l%NxrVP`(ylqm)lTDj`6HZJ+=*u{)(^yMZ?|{1waP=V9hsKG;G??_2kGps2}}{y)SQfnDnK2Mj)AX~7Hu}S zJ=%@jkS8Kie=Ff$8*WO8XJyG0sGDh^9!?cGrfxv3U5`$4247w z_DKSzj(y16+>Sl5=VPrt;;Zj)+kJmD`Ap~{9!Yzp5HW?<)Bbd zSSF=uxlrg}X5taAuQAj-ynJoku@unD^7ignMYMT%+kYf9L%@Gw==~(&8nrf*I)j7# zC~bxHu1Wer!|a{Lm&jy~PR>@`DkQ}(ZwtG4AwSV9jQA49WyV{U*~UnaR8*7&T#&Ln zCP$;6Y_CGslWbzO7k^k)I6&Wfvb7UI)aS55w3p#O9fh`AW;05yt5z7a31QSwHp(_X z@wBP8hj@a}>Uz!iUW}bM=+M-KXukD_83p5o@P`>?H1gi9*8mpnh8CmIB8k8d6|(;K z-wW?k7duS$TG2(mU|EHB(;h+858SSO0-`)3BY=M@eP~XUdSOeC!{*)-+Q6RS+@ncq zp0L#3ifp_vYtTBenxe|3r7)2R6DyUljzK?170r4w%zz|$iTz||7RoqjVP__Qmjf1Iyl5Rgx7Rf^Uqo44+gDoo{8iV3W=?CJZMVR| z86@S30hW<7#o6x?ze;sQR12SZL$ys=9K7Vv(fizQ=&V#v?|vs-UMtfosFY4NZa)j7 z#$-js+P@J|>A6L*^K;owNRPCL+9`9Sx6JqY7#Zm2EZTL@v?N(6E2@p`TbN7}3d!RZ zpgs#J&_WF%#C@RZOFD|8R@~wBn_iF_yom&;S1+P3I#M)yMCEb^@db*zk)q>^?Rq>N zhpR0rTe=N8=?i_n^`jW?Ntk0PxV=jVHYQEn9K(!~hC?)%1$Cjr;v9JD!)4AYmjX4Y zzV!6TNFPIRb`~fQpt^9ZBItSPOI-=GPcW#~Q_A zQt$=87G@<)XXVvmwnw9H4%eBZ7H6r*ky2g5K%^g6e@EB&*v!8_zvtIsB1Kmgng^kE zEGe)9Zgrh{H|=t~MYFEsd-NDG26QIEf)I0yXU~;`I?3!7?er8bOK z{1o<9#Uz&DBPUr(5&QFzDj<>lQ2S!|i}gRPR$GRTHG1=BGSKs==uLOzw-4WnMYDLm zB@I~d@DCxoqKwJ|ubX1SBv;#XGi(m#aeni}ys=LkW09yNG<3WNz7+F4oyx*ij?E&c zjWU#b@WhThY$U#2`99-!1sCV}h~M+H5a+g)+eKN$J^pa>))v;D;`Sb}EGjjS8}D@- z7s_!*rjaD={qFTOmLztnfLj2wamOe)8ZKx?u^o+SF#ph~5uAb-%9Q)n+2b>a5bxlnjTq?$`!Q&JWI02zT&0imQVH$=Wg z9O(7&$~)C1kCVK=uC@;kk1UX!VBCFXjG83jhH^C+^t~eN&w1^R`s}7sk@s~8hbT8Z z_l^rzS9S6cJn0!`yhtW)wEbGv?QokLKoJ_(yQ3HUFGg1S_;{1!{9+ViF#WD7Xk=G* z+?rv>^5n%@Q9Nuvq?O@Z#$UX=g~^$CFqMjp*;Ept8h35f#>qC1XsN<&5zz`hh|DRD zv-8PI&U${0a`}vOiW0U#a7?ug?T=4P5u4dfMw=M!3PBVB$+lK;+T01`unwQZvv%#-Ou3JZPQYv`zdC{6?19llP#4Cjrj_qCWE+it;!mS4eZxu@ zf;Gg8>kuC5N-vedj04D8iyQP$aT+_p!xike9?psU9GzCHKICuYccpSpo7{couS0D-G}!$-EDw&X~ta+I?6C-oT7N3`${VBXp&{m{a|rntijMqFqn1sXE~I?A_YftqkadK=ATS*WVdT_A0F*()MP>SuHR}_B4+uyXW55#3$C(>)|F5eyGMAuz zMppX7Aiwca`>YE4lJhBarAe3Pl|wa|sCC6|PHB5TG!w5t zh0*MxQ1XcWUihYMNQ{)+cT7TW{Y)|i=Vq$a5&EE&T#`aQ`k4lA0(g+HmER1_(>#~- zMGE>xbX&bATsEpbwcbL|c|7wHno}v2r*C==M?u+_o^- zR7&JTP_F9t>gsD{k2qgjh?{c8+|P_95*Ll((w8IAkS4ypalHxjil39miD zmJJx#amR7H3$s%)WW!9v^$gIqi_|5^?02agHV8@S^p%lV+BikYc^4g7m!8pe=%sZ! z4YRcB)2uqD)WEkLL!rI)maU%NmXaC^6!m5GNs}(jmFl8_q~-3#fS^H#z7kY#B}su_ zmdiBC3qa8q$G)Dp#W zbo{}V0kPbsj|TN1KvfpZf*=N5CHu%Da8`m{8np0V1);wn z>a!Ov`z-BLJ62GBk|bg2`~!>quU}+~>RMU*lKcSmK3}!BZK7W%b`}JVTcd?0_>^~U zxBn>_YHykK@yYy8B%k;e45$OFcE5Bx3Z7YRf7h!5dsDCZXfz)Kd;fxPe|dRp?3oNr zNd|wmf4m&_BnV61mtr;WSK^=0tsOXYyPAB)bi*X@?xu7}_$P7zer%5~4Z*Yw$rX5d z#I^Dr?)>Dp{LMS+x~r(u)Mj147l=s+0}T$S{z(^s1u6`@_ zCxv!5tKa@aX)jiG{N&S|)D($Le}5u@f@5ezLfG0XDrJ z{NWvM`$#~h&kB&k481_V&k26pG)ukWBK!M-I)7i#D0>OjK{mZ#FpWN4w`9i*uOm#~ zL+gU>k?6HV@mQmWMJs5ybBuYQQen@bR$OH~=*Q^S=`k8uDa{j;ybZ@;7SKq?$Kjpf zf*T&Mo{mbs2Noc8nnppiWB{C~U^M%GM&Ghqu+D8XJ_Fb0Gugm%Z@Cc zf_hXjxs;$(FEQ{Yo+oKi6EYb3xq}#G5~oeB$1&-bI@95(V}2QAD2>wID;`>v@mU(v z(@R$4HrO}nW=JJWtL!(-VKIMjNTfK8VT^&W3{r1-S?Rkfy>Y8N@G+FP*u#|qKU zFpG22CqEsP$j6T;_oe(%$Ur{W6+~an-M=i1FyWq z%L{Lo@ho^PEc!*MIFit@Rn)ba#d$w9bn(M=@?q$wxknt$DS`%al2ET-k7o2VV~tFx zZ6Z^z=ONcMFMj)>j($$4o&LZ(H0H!JHtn-sDJLeLvZGk^B{^^4rwQ@8St(ag+%C6^ zh$M~UqvzK?w&|7yz6;OfB#}1emV1O#h+gEmkX$E*Nvgy{d_a>0c;s?UGExE{s~>_+ z4IFL0ORg(I@#FWtB1TQ$z%FMLkye!uc;Gq!jYN>HxO znY>m|O8?B!8Il}vu1hjqQ=41V&9B{Rm|w%HRD`o+Bf$K{E{ISrk+hq&o)qHLhR&>V4!93`uW{yGBM%Lf5BrSR~-qf}tlR5Oob zL}hVf#rl2luO7H=1txW17nKK@gf)F=H?|d;^8bkIq^5-Hdj2iS;(m4D(OPBrXQmby zM5hzoD-jE&kZIu#VEI+2MDe3>z-^SY#bc#V<;-u!oliX2H)4U@GuTdFj3wSd;~L14 z!cX|41qWkjIV}ba8IUO}*-}J-maI19d&JCk!!a~;VQs!FeF|iu#kU@$(e_s~yqjhA zZ?eu=2Vr5~&bovaDmmUq|=|5Km8fk?TJ*iSz1@1X3R@9uP6K9P_$(7^$hIQc! zmO64M4631}8@AM;VQ9xTLX4zvePH=6p@E(tAlo=N_ z1^TThTC9d0y;>{?1h%P5Wb?QY+3J6klb4Ny#P?y1_0@0{5GJ{mbLEm}$mTLu*s44V zhp*X#>Tu+zi$FOzMWgy0keZ6psZ=rPB1i-IrgkFfc4H^*)4K;`GejcuwswOoY1*Y} zVUb73d7JmodXW~-%wf;bG;Gc1-vi|GV_u823x~2oS4YAof2CqEwapgtEoxL#tTOj0 zs7{;ca6Y4V0MoRNwMtHoX-nyD!qGLrdXUcxEHtPV zXHgXfXNcd#l~V909b=WZ%Fi3M?1dE>)XpF_dTU!WPHfPkCScb0SRjXh<{j>MyOG(!>=Jnol@~f^rj;P58K=q63>)jw% zoek-ugLj}@ z?PgXwev)h%G)LBHCWLSj4y*92k4F^dnObAtb44qYdL@+R%k!~>r7ue(XOsr%g#b1c zVP3LT*InD*2H&sDT4Eh9zh(R&f{Dp1Ue8io5u`;q6ydJU!c||swX3nKdguO$Srtb9 zN~TPykk^M9p^eF?s5S-l*eW{tcN{QqvCN6b?vecKw)vOR>(|LJUcRCV8pfDWdSz}< ziDbKm=MRNxvN(i@)zuwsvE*?QG$mt4Re-iwTyl9K`!~KY=o-@SPsv84{*B@W_Kn&f zJe`r;D#NwKJg4+%wBKfH>pFLL!h3(!03Q^uxQ?c62wXot8F7%l(O^Ej$##R;74Z_o zy#%JV9kF{(vkwL*$I&0I2inKkvnAQKinC^slxKcXHhxeZ8M)ze@vDa^)ncGGEVp^l zrJCMz&vbx{V5}_bWa2DLvwHa;YT#?oNUE~$?u_$Z+W}0r0*CWddXeWl-z0R?H)jCO zbO2{7oRD-2ADakEY!dy~q2G>?9QSSHIjLkF!rStO>=i0uU$b2(iZY@}k$x4jTw^YG zDUkcCDXN*=G%$EmieX8AGz}=C4pvpiQJNaU@*|HW>$8RHJuEw{!^JLRDl4uiaxu{g z*Hm)Tj5)He6(0*MgT#fZS=<7+a1_+NRGW}Ihj2W@K>k@SfRaTa>HVq6VXqVTd9=SS zy*wHl)YT$9JvOhFZK757(^xB_{5e_U(sER~#kreSw|0WDOjmfidFhiFJ|W0M9dg+L z{>Cej@&Qlbhz}-Kh%GVo-ax*A?a6DZTTJTnE$t0n*6%V%dZWmg$uY-{u1KZFeAYOu z!jYprW*Jq4QAOHSHn!O4d92Uqc&Fd^yMS`R>O)lljZkAgP|;S|Y#N0f85Wg_TH_r+ z$CYq-IjrNaD{NG%x^|UM%LY|+Bdos$FrvQ8l|-;KVrirVa?h%sR%kp%k9@aPS9tP~ zqF0_1eYU#6i`OAD;jpyCw->cQgt;>qvRW+JE5k3({`t8g%!*aid^^!uS#e;Q$jB;| zQnaH@dNa7XaV3+g%h((cERL0CImz8Y(!y9|$>#-?^d>*mVc4SXF)%`9C*(gcXn9qkFL>Ea8hsWO#wXWO;rTQjvLY^x=oJ*7}SO|6m(-(KMKfXa2+OhQvk2~)*= zgg6Fshnujmt62JssgSI{lqaiP2#nVq>-5u<9J1`^7D-p}#!06gXAD$Y2Fp*nA5fIj z7A#?YPaEsYtg*};wY1Q(3?JFxNz6md69XVo@;Q`YH;0qDI7;GYMkRiOfOPnr*lD#1 zc35w|Y*$fNzHCg}Jl%1vQG*)KiPj|QtW7V+{`qf>eG92oA4Zng2N#!2@frNH_@xi*&KjPkGH+F%mOA1^;04kQ?n*f~;+oV)&F)n36f`cQ z^OjouAv6-|L-x^@n%X3)d=FdHKrZ&taAL`!AHx#gOYwuapjbrl!SktjjJjseooq)? z;?9}IPuF+}F}GqVNmUb0#yB3%>w71tM~UMRkqJ*sB#z1I$cN#D$x_MFEB5^anu%B9W^~)vwO^h2vHe@kTPq6AI}B9*(eGgiRa-J2X10 zkAAvpqE@oM*jSEtVaM2sUFxhFQw@oJHFmj(3GloNap!PS+As?1%J)+^BQ=v<7Q={< z`@~)Q3y(W~XZcmxTILI?l_V4A=YYnTUyO8%2;CLK(8rTZ2(@6wBlYwFx~TYCRd`U< z71J!~Ev3}L@o@32lqh@2$)>3{aMfEaUyef}X??*g$+B1hKn7y)b@dgB^0gUujrddm z86|urP}#LZ|LVqrmE*j)D5*$Lh~0>eX0?dq&jK^z!rvGA7%@DKAI-K20wuTH*71$j zFaGREpcFqieieselo@dTC?|m~1Q`{YS7f{lw&wj-bZh~PDZjNEwvfi+>{lZg8?Hq6 z|AL?!p*h}@DFODcj(mx;LH8qVyph`tH(?JlAnNjYT@k z_1vU1>sI8@p@SuX`#U=p}U>2LLIP~C7B zf32XnJo>W`rq}S9?S=f`@W7{8XH{@{j3RohrFZ99^u_A67tzVsu2y2uS|jgF4!qa{ z&D-0zK7S!{{I=FR|K{EZ?ayYIP}w(B%ASo~(Y!hHz`EFD|% zKQ@;4@{U*}pvqT0)n}AIM8uI&w{~qP-aw5sW9)?#Dvq);#llJQg?*G8{gqSWFVM4u~6yEjuhh5FttYhaH z&E5lbD(OVdrY=`Nua4@I3S_02V)IQnM^Kx{T{v@dGmMuCEnlWRmReEnPu`*Idy4$3 zMDnO|8}mYpFHlbU;`1Ao3`TsVk0s_2He%DhXv0r^@O68_15Z&?%jJ29p=SwC4u!^S z={LrRi*tKvpGi2#HNCDHH^#_NMmEXg0MCew5>gIxAt!W3su3 zQoJ_Ho~il^;@jhyv93dRx*~&g5acEv^l=xlXy8f_$1izvK3%r@>Wr5pU5%%JipU&F zz%-G1{micbEAj${0UD=J>44Q_2udXPQtL1N9V@XN#n&ghP|TFt{M)uznbG;kPKkxq zXz^3W<{CAMcY&_e)kMrhk9H%@82(i1u?JrR*yWPv#%1U#UqV6T)3|5Vh994nAC*b) zK!_h_2loN-jCS>3{GghUHXHDZLrF#S8;VK07!ELqsGVVS0+*D;-=Kbfzt))y;L7U_ zdd;^twIgjP)nnIbY;v&wNa@&5$`7j`*Nq#bowDB|*C~5B;4NI7F*TJo;(Rg32c1Ja zV`q-LnnDs85<*-QVg-6i?yE^}!dJqOD6y~lN=MOJa9fy+5j)=(nd~tn1*p$@&&tp& z*h2msAyWvY(&vgkErTU$%{!YP*R!$;)zCKZd2oP_m*$HTW`JBTrKqy-kj?NNRU+B_ z=PcJIdck(C*)`r-dY{T(n2GAwR=VFt&B-+xNm*CbnA55?s@B82sF)w%7k@7Lc{=2S z3P8oN{bFpFsC{Vk@<~k%3CJmH)9SzH$kHrZExHvpKxv!@r)m_o)*l9jSH1E>Q_}xP z| z+A#thEPO}_Udy~gcBcQYWk~?A!_x4a`Q{x91{q{j!YSTfdsoFkL4_Q~_dcipDN+2r zxy9;CvTDrqH^VF08fYc_4rFLJY|+ryWPwyp+aUrHp6_z<{=NNa%>quZpCX)zXio=_ zqM$r`*=YY{;-gKc$}W&-?(}628KgP92Bmc)5jo2 zO9@Zp&)kdwEOi6nnbVS%Jt~Qh$l{PVY`j zC}PIU?j=T!izy?NK^9M+77CdHRX+7g)k%!*Q3P~X*Vkt{NGt-#Gsl$fmm(qVDKpTrc+dvj| z$0eiVt*HeCj7jZ6;81MSa0gyR?mkGcaHp9Bt?m&Hp1`Z^Qdkr-&9XS1h?_ZMduL5- z!`y$WhXEuVCv)%B=GRqFcVd{AmeH01ZFodtE?;GqNBg{F6p6=#m9 zjo_p{9gz~=s;-K}E6F?Ffw+q-7gR@BG_u*NW;y7-7NCpBs+QMM%+OC|j`13_fKf|* zy$t2S_lb|b1}T(0am8e1y3P|IWa(EkHO!F$Bd$O?oD(7?tk#s|&pb;6G!R+b`&N4k zFTBQytjhcEs(Y27s8FX~LYg!)%-xcWx)K9urHE_ui2PfFs5ufOcZsTS8pn5};J7T& z*1ZxwWLk9L?3_< z9-$Z8A6dLAEu^HSE~WX{24qcoa(JKMo{ML?ZZ)SeG9baYY_y95r6YVw=w|z7*TRCz z(9%U)OG3--Zd$S_RvGqjba!bE%ZX)iL)0WcrnATUw`As&5NrL4sU~5#&knvJ<5q_=Y+J7@ z)Z4#u7c|D)oS2Kp$wj9%@V5M;WIuJS8Hb-(5g3V}VEv|4+FKguFzVHe=Ljm8ETzmj z!@qx%o%V+gm+y`Y$Biu@Sd_kTGFY0Ya<*=s*^&)AL!%P$E>$#Akt#!QoRlR3pZ!5}uWDAyQ8B%X#^2 zW1ma>Ri^tKY62<$F#kQ#W91bL>fSHbq|wxU$J$<^B(6LW207V@AAy#Asf+XuIco*66C>e_QUaUuaLf zU}>@)XPERrxoS$o1y8#p$Xtww)MKR3erDmA4ynIa3_#TWrCU(44|^Yky`c7=Jz^2( zY>yNJ!L#--Lm}+~PoCfwA1ipW$re;`>dJYh+;_h@I3aceFQ5C9j6*dWBNU_S&*F=} z2ZN((4`Q1@D(6L_5$Ct?rviu24s(ynwiu})%za-;FK>KiHb!1wnMEhS241L<_9=_| zs_uaUd)3vk-vaYsHo) z;YUzHfY3^quEO`7t#cCKE`0+QOX8;m+CmXjdgK*I*ZXY}Sh7Z^j3Mi>SY?T3=n|gc zeMZJ_E3xN&cVnex)z%FHt(g0hhGDT92?6{ZG~UP*QEk=|gke-kg88vCr*%w2JGu4D z{0#$9ws*=mxHGBI^j5E#_%5@YxfJaJ~7HBouM?*Q+xS}