{"config":{"lang":["en"],"separator":"[\\s\\u200b\\u3000\\-\u3001\u3002\uff0c\uff0e\uff1f\uff01\uff1b]+","pipeline":["stemmer"],"fields":{"title":{"boost":1000.0},"text":{"boost":1.0},"tags":{"boost":1000000.0}}},"docs":[{"location":"","title":"\u6570\u5b66\u3001\u8ba1\u7b97\u673a\u79d1\u5b66\u4e0e\u4eba\u5de5\u667a\u80fd\u7eb2\u8981","text":"

\u5728\u7ebf\u9605\u8bfb: henryndubuaku.github.io/maths-cs-ai-compendium

"},{"location":"#_2","title":"\u6982\u8ff0","text":"

\u5927\u591a\u6570\u6559\u79d1\u4e66\u5c06\u597d\u7684\u601d\u60f3\u57cb\u6ca1\u5728\u5bc6\u96c6\u7684\u7b26\u53f7\u4e4b\u4e0b\uff0c\u8df3\u8fc7\u76f4\u89c9\uff0c\u5047\u8bbe\u4f60\u5df2\u7ecf\u638c\u63e1\u4e86\u4e00\u534a\u7684\u5185\u5bb9\uff0c\u5e76\u4e14\u5728\u4eba\u5de5\u667a\u80fd\u7b49\u5feb\u901f\u53d1\u5c55\u7684\u9886\u57df\u5f88\u5feb\u8fc7\u65f6\u3002\u8fd9\u662f\u4e00\u672c\u5f00\u653e\u3001\u975e\u4f20\u7edf\u7684\u6559\u79d1\u4e66\uff0c\u4ece\u96f6\u5f00\u59cb\u6db5\u76d6\u6570\u5b66\u3001\u8ba1\u7b97\u673a\u79d1\u5b66\u548c\u4eba\u5de5\u667a\u80fd\u3002\u4e3a\u90a3\u4e9b\u5e0c\u671b\u6df1\u5165\u7406\u89e3\u77e5\u8bc6\u3001\u800c\u4e0d\u4ec5\u4ec5\u662f\u4e3a\u4e86\u901a\u8fc7\u8003\u8bd5\u6216\u9762\u8bd5\u7684\u597d\u5947\u5b9e\u8df5\u8005\u800c\u7f16\u5199\u3002

"},{"location":"#_3","title":"\u80cc\u666f","text":"

\u5728\u8fc7\u53bb\u51e0\u5e74\u4ece\u4e8bAI/ML\u5de5\u4f5c\u7684\u8fc7\u7a0b\u4e2d\uff0c\u6211\u7528\u7b14\u8bb0\u672c\u8bb0\u5f55\u4e86\u6570\u5b66\u3001\u8ba1\u7b97\u673a\u79d1\u5b66\u548c\u4eba\u5de5\u667a\u80fd\u6982\u5ff5\u7684\u76f4\u89c9\u4f18\u5148\u3001\u7ed3\u5408\u5b9e\u9645\u3001\u4e0d\u6253\u9a6c\u864e\u773c\u7684\u89e3\u91ca\u30022025\u5e74\uff0c\u51e0\u4f4d\u670b\u53cb\u7528\u8fd9\u4e9b\u7b14\u8bb0\u51c6\u5907DeepMind\u3001OpenAI\u3001Nvidia\u7b49\u516c\u53f8\u7684\u9762\u8bd5\u3002\u4ed6\u4eec\u5168\u90e8\u88ab\u5f55\u7528\uff0c\u76ee\u524d\u5728\u5de5\u4f5c\u4e2d\u8868\u73b0\u51fa\u8272\u3002\u800c\u6211\u53bb\u5e74\u4e5f\u8fdb\u5165\u4e86Y Combinator\u3002\u6240\u4ee5\u73b0\u5728\u6211\u628a\u8fd9\u4e9b\u5206\u4eab\u7ed9\u6240\u6709\u4eba\u3002

"},{"location":"#mcp","title":"MCP \u670d\u52a1\u5668","text":"

\u672c\u4ed3\u5e93\u5305\u542b\u4e00\u4e2aMCP\u670d\u52a1\u5668\uff0c\u5141\u8bb8\u4efb\u4f55AI\u52a9\u624b\uff08Claude Code\u3001Cursor\u3001VS Code\u7b49\uff09\u5c06\u8fd9\u672c\u7eb2\u8981\u4f5c\u4e3a\u77e5\u8bc6\u5e93\u4f7f\u7528\u3002\u5b83\u9700\u8981\u672c\u5730\u514b\u9686\u8be5\u4ed3\u5e93\u3002\u5185\u7f6e\u6559\u80b2\u7528\u9014\u7684\u5de5\u5177\u548c\u793a\u4f8b\u5b9e\u73b0\u3002

"},{"location":"#_4","title":"\u5185\u5bb9\u5927\u7eb2","text":"# \u7ae0\u8282 \u7b80\u4ecb \u72b6\u6001 01 \u5411\u91cf \u7a7a\u95f4\u3001\u6a21\u957f\u3001\u65b9\u5411\u3001\u8303\u6570\u3001\u5ea6\u91cf\u3001\u70b9\u79ef/\u53c9\u79ef/\u5916\u79ef\u3001\u57fa\u3001\u5bf9\u5076\u6027 \u5df2\u5b8c\u6210 02 \u77e9\u9635 \u6027\u8d28\u3001\u7279\u6b8a\u7c7b\u578b\u3001\u8fd0\u7b97\u3001\u7ebf\u6027\u53d8\u6362\u3001\u5206\u89e3\uff08LU\u3001QR\u3001SVD\uff09 \u5df2\u5b8c\u6210 03 \u5fae\u79ef\u5206 \u5bfc\u6570\u3001\u79ef\u5206\u3001\u591a\u5143\u5fae\u79ef\u5206\u3001\u6cf0\u52d2\u8fd1\u4f3c\u3001\u4f18\u5316\u4e0e\u68af\u5ea6\u4e0b\u964d \u5df2\u5b8c\u6210 04 \u7edf\u8ba1\u5b66 \u63cf\u8ff0\u6027\u5ea6\u91cf\u3001\u62bd\u6837\u3001\u4e2d\u5fc3\u6781\u9650\u5b9a\u7406\u3001\u5047\u8bbe\u68c0\u9a8c\u3001\u7f6e\u4fe1\u533a\u95f4 \u5df2\u5b8c\u6210 05 \u6982\u7387\u8bba \u8ba1\u6570\u3001\u6761\u4ef6\u6982\u7387\u3001\u5206\u5e03\u3001\u8d1d\u53f6\u65af\u65b9\u6cd5\u3001\u4fe1\u606f\u8bba \u5df2\u5b8c\u6210 06 \u673a\u5668\u5b66\u4e60 \u7ecf\u5178\u673a\u5668\u5b66\u4e60\u3001\u68af\u5ea6\u65b9\u6cd5\u3001\u6df1\u5ea6\u5b66\u4e60\u3001\u5f3a\u5316\u5b66\u4e60\u3001\u5206\u5e03\u5f0f\u8bad\u7ec3 \u5df2\u5b8c\u6210 07 \u8ba1\u7b97\u8bed\u8a00\u5b66 \u53e5\u6cd5\u5b66\u3001\u8bed\u4e49\u5b66\u3001\u8bed\u7528\u5b66\u3001\u81ea\u7136\u8bed\u8a00\u5904\u7406\u3001\u8bed\u8a00\u6a21\u578b\u3001RNN\u3001CNN\u3001\u6ce8\u610f\u529b\u673a\u5236\u3001Transformer\u3001\u6587\u672c\u6269\u6563\u3001\u6587\u672cOCR\u3001MoE\u3001SSM\u3001\u73b0\u4ee3LLM\u67b6\u6784\u3001\u81ea\u7136\u8bed\u8a00\u5904\u7406\u8bc4\u4f30 \u5df2\u5b8c\u6210 08 \u8ba1\u7b97\u673a\u89c6\u89c9 \u56fe\u50cf\u5904\u7406\u3001\u76ee\u6807\u68c0\u6d4b\u3001\u5206\u5272\u3001\u89c6\u9891\u5904\u7406\u3001SLAM\u3001CNN\u3001\u89c6\u89c9Transformer\u3001\u6269\u6563\u6a21\u578b\u3001\u6d41\u5339\u914d\u3001VR/AR \u5df2\u5b8c\u6210 09 \u97f3\u9891\u4e0e\u8bed\u97f3 
\u6570\u5b57\u4fe1\u53f7\u5904\u7406\u3001\u81ea\u52a8\u8bed\u97f3\u8bc6\u522b\u3001\u6587\u672c\u8f6c\u8bed\u97f3\u3001\u8bed\u97f3\u4e0e\u58f0\u5b66\u6d3b\u52a8\u68c0\u6d4b\u3001\u8bf4\u8bdd\u4eba\u5206\u79bb\u3001\u6e90\u5206\u79bb\u3001\u4e3b\u52a8\u964d\u566a\u3001WaveNet\u3001Conformer \u5df2\u5b8c\u6210 10 \u591a\u6a21\u6001\u5b66\u4e60 \u878d\u5408\u7b56\u7565\u3001\u5bf9\u6bd4\u5b66\u4e60\u3001CLIP\u3001\u89c6\u89c9\u8bed\u8a00\u6a21\u578b\u3001\u56fe\u50cf/\u89c6\u9891\u5206\u8bcd\u3001\u8de8\u6a21\u6001\u751f\u6210\u3001\u7edf\u4e00\u67b6\u6784\u3001\u4e16\u754c\u6a21\u578b \u5df2\u5b8c\u6210 11 \u81ea\u4e3b\u7cfb\u7edf \u611f\u77e5\u3001\u673a\u5668\u4eba\u5b66\u4e60\u3001\u89c6\u89c9-\u8bed\u8a00-\u52a8\u4f5c\u6a21\u578b\u3001\u81ea\u52a8\u9a7e\u9a76\u3001\u592a\u7a7a\u673a\u5668\u4eba \u5df2\u5b8c\u6210 12 \u56fe\u795e\u7ecf\u7f51\u7edc \u51e0\u4f55\u6df1\u5ea6\u5b66\u4e60\u3001\u56fe\u8bba\u3001GNN\u3001\u56fe\u6ce8\u610f\u529b\u673a\u5236\u3001\u56feTransformer\u3001\u4e09\u7ef4\u7b49\u53d8\u7f51\u7edc \u5df2\u5b8c\u6210 13 \u8ba1\u7b97\u4e0e\u64cd\u4f5c\u7cfb\u7edf \u79bb\u6563\u6570\u5b66\u3001\u8ba1\u7b97\u673a\u4f53\u7cfb\u7ed3\u6784\u3001\u64cd\u4f5c\u7cfb\u7edf\u3001\u5e76\u53d1\u3001\u5e76\u884c\u3001\u7f16\u7a0b\u8bed\u8a00 \u5df2\u5b8c\u6210 14 \u6570\u636e\u7ed3\u6784\u4e0e\u7b97\u6cd5 \u5927O\u8868\u793a\u6cd5\u3001\u9012\u5f52\u3001\u56de\u6eaf\u3001\u52a8\u6001\u89c4\u5212\u3001\u6570\u7ec4\u3001\u54c8\u5e0c\u3001\u94fe\u8868\u3001\u6808\u3001\u6811\u3001\u56fe\u3001\u6392\u5e8f\u3001\u4e8c\u5206\u67e5\u627e \u5df2\u5b8c\u6210 15 \u751f\u4ea7\u7ea7\u8f6f\u4ef6\u5de5\u7a0b Linux\u3001Git\u3001\u4ee3\u7801\u5e93\u8bbe\u8ba1\u3001\u6d4b\u8bd5\u3001CI/CD\u3001Docker\u3001\u6a21\u578b\u670d\u52a1\u3001MLOps\u3001\u76d1\u63a7\u3001\u4f7f\u7528\u7f16\u7801\u4ee3\u7406\u7684\u6700\u4f73\u5b9e\u8df5 \u5df2\u5b8c\u6210 16 SIMD\u4e0eGPU\u7f16\u7a0b 
\u9762\u5411\u673a\u5668\u5b66\u4e60\u7684C++\u3001\u6846\u67b6\u5de5\u4f5c\u539f\u7406\u3001\u786c\u4ef6\u57fa\u7840\u3001ARM NEON/I8MM/SME2\u3001x86 AVX\u3001GPU/CUDA\u3001Triton\u3001TPU\u3001RISC-V\u3001Vulkan\u3001WebGPU \u5df2\u5b8c\u6210 17 AI\u63a8\u7406 \u91cf\u5316\u3001\u9ad8\u6548\u67b6\u6784\u3001\u670d\u52a1\u4e0e\u6279\u5904\u7406\u3001\u8fb9\u7f18\u63a8\u7406\u3001\u63a8\u6d4b\u89e3\u7801\u3001\u6210\u672c\u4f18\u5316 \u5df2\u5b8c\u6210 18 ML\u7cfb\u7edf\u8bbe\u8ba1 \u7cfb\u7edf\u57fa\u7840\u3001\u4e91\u8ba1\u7b97\u3001\u5206\u5e03\u5f0f\u7cfb\u7edf\u3001ML\u751f\u547d\u5468\u671f\u3001\u7279\u5f81\u5b58\u50a8\u3001A/B\u6d4b\u8bd5\u3001\u63a8\u8350/\u641c\u7d22/\u5e7f\u544a/\u6b3a\u8bc8\u8bbe\u8ba1\u5b9e\u4f8b \u5df2\u5b8c\u6210 19 \u5e94\u7528\u4eba\u5de5\u667a\u80fd \u91d1\u878d\u3001\u533b\u7597\u5065\u5eb7\u3001\u86cb\u767d\u8d28\u3001\u836f\u7269\u53d1\u73b0\u4e2d\u7684\u4eba\u5de5\u667a\u80fd \u5f85\u5b8c\u6210 20 \u524d\u6cbf\u4eba\u5de5\u667a\u80fd \u91cf\u5b50\u673a\u5668\u5b66\u4e60\u3001\u795e\u7ecf\u5f62\u6001\u673a\u5668\u5b66\u4e60\u3001\u53bb\u4e2d\u5fc3\u5316\u4eba\u5de5\u667a\u80fd\u3001\u592a\u7a7a\u6570\u636e\u4e2d\u5fc3\u3001\u8111\u673a\u63a5\u53e3 \u5f85\u5b8c\u6210"},{"location":"#_5","title":"\u524d\u8a00","text":"

\u65b0\u751f\u5a74\u513f\u7684\u5927\u8111\u662f\u4e00\u4e2a\u65b0\u521d\u59cb\u5316\u7684\u795e\u7ecf\u7f51\u7edc\uff0c\u901a\u8fc7\u73b0\u5b9e\u4e16\u754c\u7684\u6570\u636e\u548c\u7ecf\u9a8c\u8bad\u7ec3\u76f4\u81f3\u6210\u5e74\u2026\u2026\u76f4\u81f3\u6c38\u8fdc\u3002\u80fd\u591f\u7528\u6cd5\u8bed\u6d41\u5229\u4ea4\u6d41\u5e76\u62e5\u6709\u5b8c\u7f8e\u53e3\u97f3\uff0c\u610f\u5473\u7740\u63a5\u89e6\u5230\u4e86\u4f18\u79c0\u7684\u6cd5\u8bed\u548c\u5b8c\u7f8e\u53e3\u97f3\u3002\u540c\u6837\uff0c\u4f18\u79c0\u7684\u4eba\u5de5\u667a\u80fd\u7814\u7a76\u5458\u548c\u5de5\u7a0b\u5e08\u5177\u5907\u51fa\u8272\u7684\u95ee\u9898\u89e3\u51b3\u80fd\u529b\uff0c\u610f\u5473\u7740\u4ed6\u4eec\u5438\u6536\u4e86\u9ad8\u8d28\u91cf\u7684\u77e5\u8bc6\u5e76\u62e5\u6709\u4e30\u5bcc\u7684\u7ecf\u9a8c\u3002

\u79d1\u74e6\u820d\u592b\u5b9e\u9a8c\u662f\u4e00\u9879\u957f\u671f\u7684\u585e\u5c14\u7ef4\u4e9a\u7814\u7a76\uff0c\u8868\u660e\u4e3a\u671f\u4e09\u5e74\u7684\u9ad8\u5f3a\u5ea6\u521b\u9020\u6027\u95ee\u9898\u89e3\u51b3\u8bad\u7ec3\u53ef\u4ee5\u663e\u8457\u63d0\u9ad8\u667a\u529b\uff0c\u5c24\u5176\u662f\u6d41\u4f53\u667a\u529b\uff0c\u63d0\u534710-15\u4e2aIQ\u70b9\u3002\u5f53\u7136\uff0c\u5929\u751f\u9ad8IQ\u662f\u771f\u5b9e\u5b58\u5728\u7684\uff0c\u5c31\u50cf\u4f18\u8d28\u7684\u6743\u91cd\u521d\u59cb\u5316\u80fd\u5e26\u6765\u66f4\u597d\u7684\u8bad\u7ec3\u6548\u679c\u4e00\u6837\u2014\u2014\u5148\u5929\u4e0e\u540e\u5929\u4e4b\u4e89\u7684\u5b9e\u9a8c\u7ed3\u679c\u4e5f\u8bc1\u660e\u4e86\u8fd9\u4e00\u70b9\u3002

\u7136\u800c\uff0c\u9ad8IQ\u4e2a\u4f53\u7684\u771f\u6b63\u4f18\u52bf\u4ec5\u5728\u4e8e\u80fd\u66f4\u5feb\u5730\u5b66\u4e60\u548c\u8bc6\u522b\u6a21\u5f0f\u3002\u4f46\u91cd\u590d\u4f7f\u7528\u4e00\u79cd\u6a21\u5f0f\u53ef\u4ee5\u4f7f\u4efb\u4f55\u6982\u5ff5\u90fd\u53d8\u5f97\u7edd\u5bf9\u53ef\u5b66\u3002\u67e5\u5c14\u65af\u00b7\u8fbe\u5c14\u6587\u88ab\u4ed6\u7684\u8001\u5e08\u548c\u7236\u4eb2\u8ba4\u4e3a\u662f\u4e00\u4e2a\u975e\u5e38\u666e\u901a\u3001\u751a\u81f3\u4f4e\u4e8e\u5e73\u5747\u6c34\u5e73\u7684\u5b66\u751f\u3002\u4ed6\u81ea\u79f0\u5e76\u4e0d\u673a\u667a\uff0c\u611f\u89c9\u81ea\u5df1\u50cf\u4e00\u4e2a\"\u6162\u5904\u7406\u5668\"\uff0c\u9700\u8981\u65f6\u95f4\u6765\u5438\u6536\u6570\u636e\u3002

\u57283\u523010\u5c81\u4e4b\u95f4\uff0c\u6211\u7684\u5b66\u4e60\u6210\u7ee9\u5f88\u597d\uff0c\u81ea\u7136\u800c\u7136\u5730\u7406\u89e3\u6982\u5ff5\uff0c\u4ece\u4e0d\u505a\u7b14\u8bb0\u6216\u590d\u4e60\u300211\u523013\u5c81\u4e4b\u95f4\u6211\u6709\u70b9\u81ea\u5927\uff0c\u7528\u8fd9\u79cd\u65b9\u5f0f\u5728\u4e00\u4e2a80\u4eba\u7684\u73ed\u7ea7\u4e2d\u8dcc\u5230\u4e86\u4e0b\u534a\u90e8\u5206\u300214\u523015\u5c81\u4e4b\u95f4\uff0c\u6211\u5f00\u59cb\u50cf\u666e\u901a\u5b66\u751f\u4e00\u6837\u8bfb\u4e66\uff0c\u5728\u4e2d\u5b66\u6700\u540e\u4e00\u4e2a\u5b66\u671f\u53d6\u5f97\u4e86\u7b2c\u4e00\u540d\u3002\u65e9\u671f\u5b66\u6821\u8bfe\u7a0b\u4e0e\u81ea\u7136IQ\u914d\u5408\u5f97\u5f88\u597d\uff0c\u4f46\u73b0\u5b9e\u4e16\u754c\u7684\u624d\u534e\u6e90\u4e8e\u9ad8\u8d28\u91cf\u7684\u77e5\u8bc6\u6444\u5165\u548c\u6267\u884c\u529b\u5ea6\u3002

\u4e8b\u5b9e\u4e0a\uff0c\u5927\u591a\u6570\u5b66\u4e60\u6210\u7ee9\u597d\u7684\u5b66\u751f\u53ea\u662f\u66f4\u52e4\u594b\uff0c\u4f46\u5b66\u672f\u7cfb\u7edf\u662f\u4e3a\u5feb\u901f\u5b66\u4e60\u8005\u8bbe\u8ba1\u7684\u3002\u8fd9\u672c\u7eb2\u8981\u63d0\u4f9b\u4e86\u4e00\u4e2a\u5168\u9762\u4e14\u76f8\u4e92\u5173\u8054\u7684\u77e5\u8bc6\u6d41\uff0c\u4ee5\u5e2e\u52a9\u4e16\u754c\u4e0a\u90a3\u4e9b\"\u8fbe\u5c14\u6587\u4eec\"\u66f4\u597d\u5730\u5b66\u4e60\u3002\u4f60\u53ea\u9700\u8981\u521d\u7b49\u6570\u5b66\u57fa\u7840\u548c\u57fa\u672c\u7684Python\u7f16\u7a0b\u77e5\u8bc6\uff0c\u5176\u4ed6\u4e00\u5207\u90fd\u4f1a\u9010\u6b65\u638c\u63e1\u2014\u2014\u53ea\u9700\u9605\u8bfb\u5e76\u76f8\u4fe1\u8fd9\u4e2a\u8fc7\u7a0b\uff01

"},{"location":"#_6","title":"\u5982\u4f55\u66f4\u597d\u5730\u5b66\u4e60","text":"

\u5927\u5b66\u7b2c\u4e00\u5b66\u671f\uff0c\u6211\u540c\u65f6\u9009\u4e8617\u95e8\u8bfe\uff0c\u6210\u7ee9\u5e76\u4e0d\u7406\u60f3\uff0c\u4e8e\u662f\u6211\u91c7\u7528\u4e86\u4e00\u4e2a\u6280\u5de7\uff1a

\u7b2c\u4e00\u9636\u6bb5\uff1a\u8bfe\u540e\u7d2f\u79ef\u9605\u8bfb \u53ea\u9605\u8bfb\u6bcf\u5f20\u5e7b\u706f\u7247/\u7b14\u8bb0\u7684\u6807\u9898/\u5927\u6807\u9898\uff0c\u5408\u4e0a\u4e66\uff0c\u7136\u540e\u5728\u8111\u6d77\u4e2d\u53ef\u89c6\u5316\u5e76\u5199\u51fa\u5bf9\u8be5\u6982\u5ff5\u7684\u89e3\u91ca\u3002\u53ea\u91cd\u8bfb\u4f60\u9057\u6f0f\u7684\u90e8\u5206\uff0c\u7c7b\u4f3c\u4e8e\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u63a9\u7801\u8bed\u8a00\u5efa\u6a21\u3002\u91cd\u8bfb\u4e4b\u540e\uff0c\u6700\u7ec8\u5c06\u6982\u5ff5\u7528\u4ee3\u7801\u5b9e\u73b0\u3002\u8fd9\u6837\u4f60\u5c31\u80fd\u5bf9\u6bcf\u4e2a\u6982\u5ff5\u5f62\u6210\u808c\u8089\u8bb0\u5fc6\u3002

\u7b2c\u4e8c\u9636\u6bb5\uff1a\u8003\u524d\u5f71\u5b50\u9605\u8bfb \u9605\u8bfb\u6bcf\u5f20\u5e7b\u706f\u7247/\u7b14\u8bb0\u7684\u526f\u6807\u9898\uff0c\u5408\u4e0a\u4e66\uff0c\u7136\u540e\u5728\u8111\u6d77\u4e2d\u53ef\u89c6\u5316\u5e76\u5199\u51fa\u5bf9\u8be5\u6982\u5ff5\u7684\u89e3\u91ca\u3002\u53ea\u91cd\u8bfb\u4f60\u9057\u6f0f\u7684\u90e8\u5206\uff0c\u7c7b\u4f3c\u4e8e\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u63a9\u7801\u8bed\u8a00\u5efa\u6a21\u3002\u91cd\u8bfb\u4e4b\u540e\uff0c\u6700\u7ec8\u5c06\u6982\u5ff5\u7528\u4ee3\u7801\u5b9e\u73b0\u3002\u8fd9\u6837\u4f60\u5c31\u80fd\u5bf9\u6bcf\u4e2a\u6982\u5ff5\u5f62\u6210\u808c\u8089\u8bb0\u5fc6\u3002

\u8fd9\u4e2a\u65b9\u6cd5\u5bf9\u6211\u4e0d\u592a\u81ea\u4fe1\u7684\u670b\u53cb\u4eec\u975e\u5e38\u6709\u6548\u3002\u4e8b\u5b9e\u4e0a\uff0c\u5176\u4e2d\u4e00\u4f4d\u670b\u53cb\u5728\u9ad8\u7b49\u5de5\u7a0b\u6570\u5b66\uff08\u6db5\u76d6\u6d77\u68ee\u77e9\u9635\u548c\u4f18\u5316\uff09\u8fd9\u95e8\u8bfe\u4e0a\u8d85\u8fc7\u4e86\u6211\u3002\u5979\u73b0\u5728\u5728\u4e00\u5bb6\u5927\u578b\u77f3\u6cb9\u5929\u7136\u6c14\u516c\u53f8\u5de5\u4f5c\u3002\u7075\u9b42\u7684\u610f\u613f\u6bd4\u6211\u4eec\u4e0e\u4e4b\u5de5\u4f5c\u7684\u8eab\u4f53\u66f4\u91cd\u8981\uff08\u7f57\u68ee\u5854\u5c14\u5b9e\u9a8c\uff09\u3002

"},{"location":"#_7","title":"\u5173\u4e8e\u4f5c\u8005","text":"

\u67e5\u770bGitHub\u4e2a\u4eba\u8d44\u6599\uff01

"},{"location":"#_8","title":"\u5f15\u7528","text":"
@book{ndubuaku2025compendium,\n  title     = {Maths, CS & AI Compendium},\n  author    = {Henry Ndubuaku},\n  year      = {2026},\n  publisher = {GitHub},\n  url       = {https://github.com/HenryNdubuaku/maths-cs-ai-compendium}\n}\n
"},{"location":"chapter%2001%3A%20vectors/01.%20vector%20spaces/","title":"\u5411\u91cf\u7a7a\u95f4","text":"

\u5411\u91cf\u7a7a\u95f4\u6784\u6210\u4e86\u673a\u5668\u5b66\u4e60\u7684\u6570\u5b66\u821e\u53f0\u3002\u672c\u6587\u6db5\u76d6\u5411\u91cf\u52a0\u6cd5\u3001\u6807\u91cf\u4e58\u6cd5\u3001\u5c01\u95ed\u6027\u516c\u7406\u3001\u5b50\u7a7a\u95f4\uff0c\u4ee5\u53ca\u4e3a\u4ec0\u4e48AI\u4e2d\u51e0\u4e4e\u6240\u6709\u4e1c\u897f\u90fd\u8868\u793a\u4e3a\u5411\u91cf\u3002

\\[\\mathbf{a} = [a_1, a_2, a_3]\\]

\\[\\mathbf{h} = [185, 75, 30]\\]

"},{"location":"chapter%2001%3A%20vectors/01.%20vector%20spaces/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u8fd0\u884c\u4ee3\u7801\u9a8c\u8bc1\u5206\u914d\u5f8b\u6027\u8d28\uff0c\u7136\u540e\u4fee\u6539\u5e76\u5c1d\u8bd5\u6d4b\u8bd5\u5176\u4ed6\u89c4\u5219\uff01

    import jax.numpy as jnp\n\nu = jnp.array([1, 2])\nv = jnp.array([3, 0])\nc = 2\n\nlhs = c * (u + v)\nrhs = c*u + c*v\n\nprint(f\"LHS: {lhs}\")\nprint(f\"RHS: {rhs}\")\n

  2. \u8fd0\u884c\u4ee3\u7801\u53ef\u89c6\u5316\u4e0d\u540c\u7684\u5411\u91cf\uff0c\u7136\u540e\u4fee\u6539\u4e0d\u540c\u5750\u6807\u7684\u503c\u4ee5\u7406\u89e3\u6bcf\u4e2a\u8f74\u5982\u4f55\u5f71\u54cd\u4f4d\u7f6e\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u5c1d\u8bd5\u4fee\u6539\u8fd9\u4e9b\u5411\u91cf\uff01\na = jnp.array([3, 2, 4])\nb = jnp.array([1, 4, 2])\nc = jnp.array([4, 1, 3])\n\nfig = plt.figure()\nax = fig.add_subplot(111, projection=\"3d\")\n\nfor vec, name, color in [(a, \"a\", \"red\"), (b, \"b\", \"blue\"), (c, \"c\", \"green\")]:\n    ax.quiver(0, 0, 0, *vec, color=color, arrow_length_ratio=0.1, linewidth=2, label=name)\n\nlim = int(jnp.abs(jnp.stack([a, b, c])).max()) + 1\nax.set_xlim([0, lim]); ax.set_ylim([0, lim]); ax.set_zlim([0, lim])\nax.set_xlabel(\"X\"); ax.set_ylabel(\"Y\"); ax.set_zlabel(\"Z\")\nax.legend()\nplt.show()\n

"},{"location":"chapter%2001%3A%20vectors/02.%20vector%20properties/","title":"\u5411\u91cf\u6027\u8d28","text":"

\u5411\u91cf\u6027\u8d28\u63cf\u8ff0\u4e86\u5b9a\u4e49\u5411\u91cf\u884c\u4e3a\u7684\u51e0\u4f55\u548c\u4ee3\u6570\u7279\u5f81\u3002\u672c\u6587\u6db5\u76d6\u6a21\u957f\u3001\u65b9\u5411\u3001\u5355\u4f4d\u5411\u91cf\u3001\u76f8\u7b49\u6027\u3001\u5e73\u884c\u6027\u3001\u6b63\u4ea4\u6027\u548c\u7ebf\u6027\u65e0\u5173\u6027\uff0c\u5b83\u4eec\u662f\u6bcf\u4e2a ML \u7279\u5f81\u7a7a\u95f4\u7684\u57fa\u77f3\u3002

\\[\\|\\mathbf{a}\\| = \\sqrt{a_1^2 + a_2^2 + a_3^2}\\]

\\[\\mathbf{a} = \\mathbf{b} \\iff a_i = b_i \\text{ \u5bf9\u6240\u6709 } i\\] \\[\\mathbf{a} \\parallel \\mathbf{b} \\iff \\mathbf{a} = k\\mathbf{b} \\text{ \u5bf9\u4e8e\u67d0\u4e2a\u6807\u91cf } k \\neq 0\\]

\\[\\mathbf{s} = [0, 0, 3, 0, 0, 0, 1, 0, 0, 0]\\] \\[\\hat{\\mathbf{a}} = \\frac{\\mathbf{a}}{\\|\\mathbf{a}\\|}\\] "},{"location":"chapter%2001%3A%20vectors/02.%20vector%20properties/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u8ba1\u7b97\u5411\u91cf\u7684\u6a21\u957f\u5e76\u9a8c\u8bc1\u5b83\u7b26\u5408\u52fe\u80a1\u5b9a\u7406\uff0c\u7136\u540e\u4fee\u6539\u4ee3\u7801\u8ba1\u7b97\u5355\u4f4d\u5411\u91cf\u3002

    import jax.numpy as jnp\n\na = jnp.array([3.0, 4.0])\n\nmagnitude = jnp.sqrt(jnp.sum(a ** 2))\nprint(f\"Magnitude of a: {magnitude}\") \n

  2. \u901a\u8fc7\u6d4b\u8bd5\u4e00\u4e2a\u5411\u91cf\u662f\u5426\u662f\u53e6\u4e00\u4e2a\u7684\u6807\u91cf\u500d\u6570\u6765\u68c0\u67e5\u4e24\u4e2a\u5411\u91cf\u662f\u5426\u5e73\u884c\u3002

    import jax.numpy as jnp\n\na = jnp.array([2, 4, 6])\nb = jnp.array([1, 2, 3])\n\nratios = a / b\nprint(f\"Ratios: {ratios}\")\nprint(f\"Parallel: {jnp.allclose(ratios, ratios[0])}\")\n

"},{"location":"chapter%2001%3A%20vectors/03.%20norms%20and%20metrics/","title":"\u5ea6\u91cf\u4e0e\u8303\u6570","text":"

\u8303\u6570\u8861\u91cf\u5355\u4e2a\u5411\u91cf\u7684\u5927\u5c0f\uff1b\u5ea6\u91cf\u8861\u91cf\u4e24\u4e2a\u5411\u91cf\u4e4b\u95f4\u7684\u8ddd\u79bb\u3002\u672c\u6587\u6db5\u76d6 L1\u3001L2 \u548c L-\u65e0\u7a77\u8303\u6570\u3001\u6b27\u51e0\u91cc\u5f97\u8ddd\u79bb\u548c\u4f59\u5f26\u8ddd\u79bb\uff0c\u4ee5\u53ca\u4e3a\u4ec0\u4e48\u4e3a kNN\u3001\u805a\u7c7b\u548c ML \u4e2d\u7684\u68c0\u7d22\u9009\u62e9\u5408\u9002\u7684\u8ddd\u79bb\u51fd\u6570\u81f3\u5173\u91cd\u8981\u3002

\\[\\|\\mathbf{v}\\|_2 = \\sqrt{v_1^2 + v_2^2 + \\cdots + v_n^2}\\] \\[\\|\\mathbf{v}\\|_1 = |v_1| + |v_2| + \\cdots + |v_n|\\] \\[\\|\\mathbf{v}\\|_\\infty = \\max(|v_1|, |v_2|, \\ldots, |v_n|)\\] \\[\\|\\mathbf{v}\\|_p = (|v_1|^p + |v_2|^p + \\cdots + |v_n|^p)^{1/p}\\] \\[d(\\mathbf{u}, \\mathbf{v}) = \\sqrt{(u_1 - v_1)^2 + (u_2 - v_2)^2 + \\cdots + (u_n - v_n)^2}\\] "},{"location":"chapter%2001%3A%20vectors/03.%20norms%20and%20metrics/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u8ba1\u7b97\u540c\u4e00\u5411\u91cf\u7684 L1 \u548c L2 \u8303\u6570\u3002\u5c1d\u8bd5\u66f4\u6539\u503c\uff0c\u6ce8\u610f\u54ea\u4e2a\u8303\u6570\u5bf9\u5927\u7684\u5206\u91cf\u6700\u654f\u611f\uff0c\u54ea\u4e2a\u5bf9\u8bb8\u591a\u5c0f\u5206\u91cf\u6700\u654f\u611f\u3002\u7136\u540e\u5c1d\u8bd5\u8ba1\u7b97 p \u503c\u9012\u589e\uff08\u4f8b\u5982 1\u30012\u30015\u300110\u300150\u3001100\uff09\u65f6\u7684 Lp \u8303\u6570\uff0c\u89c2\u5bdf\u5b83\u5982\u4f55\u6536\u655b\u5230 L-\u65e0\u7a77\u503c\u3002

    import jax.numpy as jnp\n\nv = jnp.array([3.0, -4.0, 1.0])\n\nl1 = jnp.sum(jnp.abs(v))\nl2 = jnp.sqrt(jnp.sum(v ** 2))\n\nprint(f\"L1: {l1}, L2: {l2:.2f}\")\n

  2. \u8ba1\u7b97\u4e24\u4e2a\u5411\u91cf\u4e4b\u95f4\u7684\u6b27\u51e0\u91cc\u5f97\u8ddd\u79bb\u548c\u66fc\u54c8\u987f\u8ddd\u79bb\u3002\u5c1d\u8bd5\u8ba9\u5411\u91cf\u5f7c\u6b64\u9760\u8fd1\u6216\u8fdc\u79bb\uff0c\u89c2\u5bdf\u6bcf\u79cd\u8ddd\u79bb\u5982\u4f55\u4e0d\u540c\u5730\u54cd\u5e94\u3002

    import jax.numpy as jnp\n\nu = jnp.array([1.0, 2.0, 3.0])\nv = jnp.array([4.0, 0.0, 1.0])\n\neuclidean = jnp.sqrt(jnp.sum((u - v) ** 2))\nmanhattan = jnp.sum(jnp.abs(u - v))\n\nprint(f\"Euclidean: {euclidean:.2f}, Manhattan: {manhattan}\")\n

"},{"location":"chapter%2001%3A%20vectors/04.%20products/","title":"\u5411\u91cf\u79ef","text":"

\u5411\u91cf\u79ef\u662f\u8861\u91cf\u76f8\u4f3c\u6027\u548c\u8ba1\u7b97\u6295\u5f71\u7684\u57fa\u672c\u8fd0\u7b97\u3002\u672c\u6587\u6db5\u76d6\u5185\u79ef\u3001\u70b9\u79ef\u3001\u4f59\u5f26\u76f8\u4f3c\u5ea6\u3001\u53c9\u79ef\u548c\u5916\u79ef\uff0c\u8fd9\u4e9b\u8fd0\u7b97\u652f\u6491\u4e86 AI \u4e2d\u7684\u6ce8\u610f\u529b\u673a\u5236\u3001\u5d4c\u5165\u548c\u51e0\u4f55\u63a8\u7406\u3002

\\[\\mathbf{a} \\cdot \\mathbf{b} = a_1 b_1 + a_2 b_2 + \\cdots + a_n b_n\\] \\[\\mathbf{a} \\cdot \\mathbf{b} = \\|\\mathbf{a}\\| \\, \\|\\mathbf{b}\\| \\cos(\\theta)\\]

\\[\\text{proj}_{\\mathbf{b}}(\\mathbf{a}) = \\frac{\\mathbf{a} \\cdot \\mathbf{b}}{\\|\\mathbf{b}\\|^2} \\, \\mathbf{b}\\] \\[\\cos(\\theta) = \\frac{\\mathbf{a} \\cdot \\mathbf{b}}{\\|\\mathbf{a}\\| \\, \\|\\mathbf{b}\\|}\\] \\[\\mathbf{a} \\times \\mathbf{b} = (a_2 b_3 - a_3 b_2, \\; a_3 b_1 - a_1 b_3, \\; a_1 b_2 - a_2 b_1)\\] \\[\\|\\mathbf{a} \\times \\mathbf{b}\\| = \\|\\mathbf{a}\\| \\, \\|\\mathbf{b}\\| \\sin(\\theta)\\] \\[\\mathbf{a} \\times (\\mathbf{b} \\times \\mathbf{c}) = (\\mathbf{a} \\cdot \\mathbf{c})\\mathbf{b} - (\\mathbf{a} \\cdot \\mathbf{b})\\mathbf{c}\\] "},{"location":"chapter%2001%3A%20vectors/04.%20products/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u8ba1\u7b97\u4e24\u4e2a\u5411\u91cf\u7684\u70b9\u79ef\u5e76\u7528\u5b83\u6c42\u51fa\u5b83\u4eec\u4e4b\u95f4\u7684\u89d2\u5ea6\u3002\u5c1d\u8bd5\u8ba9\u5b83\u4eec\u6b63\u4ea4\u3001\u5e73\u884c\u6216\u53cd\u5411\uff0c\u89c2\u5bdf\u89d2\u5ea6\u5982\u4f55\u53d8\u5316\u3002

    import jax.numpy as jnp\n\na = jnp.array([1.0, 2.0, 3.0])\nb = jnp.array([4.0, -1.0, 2.0])\n\ndot = jnp.dot(a, b)\nangle = jnp.arccos(dot / (jnp.linalg.norm(a) * jnp.linalg.norm(b)))\n\nprint(f\"Dot product: {dot}\")\nprint(f\"Angle: {jnp.degrees(angle):.1f}\u00b0\")\n

  2. \u8ba1\u7b97\u4e24\u4e2a\u4e09\u7ef4\u5411\u91cf\u7684\u53c9\u79ef\uff0c\u5e76\u901a\u8fc7\u68c0\u67e5\u7ed3\u679c\u4e0e\u6bcf\u4e2a\u539f\u59cb\u5411\u91cf\u7684\u70b9\u79ef\u4e3a\u96f6\u6765\u9a8c\u8bc1\u7ed3\u679c\u5782\u76f4\u4e8e\u4e24\u8005\u3002

    import jax.numpy as jnp\n\na = jnp.array([1.0, 0.0, 0.0])\nb = jnp.array([0.0, 1.0, 0.0])\n\ncross = jnp.cross(a, b)\n\nprint(f\"a x b = {cross}\")\nprint(f\"Perpendicular to a: {jnp.dot(cross, a) == 0}\")\nprint(f\"Perpendicular to b: {jnp.dot(cross, b) == 0}\")\n

"},{"location":"chapter%2001%3A%20vectors/05.%20basis%20and%20duality/","title":"\u57fa\u4e0e\u5bf9\u5076\u6027","text":"

\u57fa\u5b9a\u4e49\u4e86\u5411\u91cf\u7a7a\u95f4\u7684\u5750\u6807\u7cfb\uff0c\u800c\u5bf9\u5076\u6027\u63ed\u793a\u4e86\u7ebf\u6027\u51fd\u6570\u5982\u4f55\u4f5c\u7528\u4e8e\u5411\u91cf\u3002\u672c\u6587\u6db5\u76d6\u7ebf\u6027\u65e0\u5173\u6027\u3001\u751f\u6210\u96c6\u3001\u57fa\u53d8\u6362\u3001\u5bf9\u5076\u7a7a\u95f4\u548c\u4f59\u5411\u91cf\uff0c\u8fd9\u4e9b\u6982\u5ff5\u652f\u6491\u4e86 ML \u4e2d\u7684 PCA\u3001\u7279\u5f81\u53d8\u6362\u548c\u6ce8\u610f\u529b\u67e5\u8be2\u3002

\\[ \\mathbf{e}_i^\\ast(\\mathbf{e}_j) = \\delta_{ij} = \\begin{cases} 1 & \\text{if } i = j \\\\ 0 & \\text{if } i \\neq j \\end{cases} \\] "},{"location":"chapter%2001%3A%20vectors/05.%20basis%20and%20duality/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u5728\u4e24\u4e2a\u4e0d\u540c\u7684\u57fa\u4e2d\u8868\u8fbe\u4e00\u4e2a\u5411\u91cf\uff0c\u5e76\u9a8c\u8bc1\u5b83\u4eec\u4ee3\u8868\u540c\u4e00\u4e2a\u70b9\u3002\u5c1d\u8bd5\u521b\u5efa\u4f60\u81ea\u5df1\u7684\u57fa\uff0c\u89c2\u5bdf\u5411\u91cf\u5f97\u5230\u4ec0\u4e48\u5750\u6807\u3002

    import jax.numpy as jnp\n\nv = jnp.array([3.0, 2.0])\n\n# \u6807\u51c6\u57fa\uff1a\u5750\u6807\u5c31\u662f\u5206\u91cf\u672c\u8eab\nprint(f\"Standard basis coords: {v}\")\n\n# \u65b0\u57fa\uff1a(1,1) \u548c (-1,1)\nP = jnp.array([[1.0, -1.0],\n               [1.0,  1.0]])\nnew_coords = jnp.linalg.solve(P, v)\nprint(f\"New basis coords: {new_coords}\")\n\n# \u9a8c\u8bc1\uff1a\u4ece\u65b0\u5750\u6807\u91cd\u5efa\nreconstructed = new_coords[0] * P[:, 0] + new_coords[1] * P[:, 1]\nprint(f\"Reconstructed: {reconstructed}\")\n

  2. \u9a8c\u8bc1\u5bf9\u5076\u57fa\u6027\u8d28\uff1a\u6bcf\u4e2a\u5bf9\u5076\u57fa\u5411\u91cf\u6070\u597d\u63d0\u53d6\u4e00\u4e2a\u5750\u6807\uff0c\u5bf9\u5176\u4ed6\u5411\u91cf\u8fd4\u56de\u96f6\u3002

    import jax.numpy as jnp\n\n# R3 \u4e2d\u7684\u6807\u51c6\u57fa\ne1 = jnp.array([1.0, 0.0, 0.0])\ne2 = jnp.array([0.0, 1.0, 0.0])\ne3 = jnp.array([0.0, 0.0, 1.0])\n\nv = jnp.array([5.0, 3.0, 7.0])\n\n# \u6bcf\u4e2a\u70b9\u79ef\u63d0\u53d6\u4e00\u4e2a\u5750\u6807\nprint(f\"e1 \u00b7 v = {jnp.dot(e1, v)}\")\nprint(f\"e2 \u00b7 v = {jnp.dot(e2, v)}\")\nprint(f\"e3 \u00b7 v = {jnp.dot(e3, v)}\")\n

"},{"location":"chapter%2002%3A%20matrices/01.%20matrix%20properties/","title":"\u77e9\u9635\u6027\u8d28","text":"

\u77e9\u9635\u662f\u5b58\u50a8\u6570\u636e\u96c6\u3001\u7f16\u7801\u53d8\u6362\u548c\u5b9a\u4e49\u6bcf\u4e2a\u795e\u7ecf\u7f51\u7edc\u5c42\u7684\u6570\u636e\u7ed3\u6784\u3002\u672c\u6587\u6db5\u76d6\u77e9\u9635\u7ef4\u5ea6\u3001\u5143\u7d20\u3001\u8f6c\u7f6e\u3001\u8ff9\u3001\u884c\u5217\u5f0f\u3001\u9006\u3001\u79e9\u548c\u96f6\u7a7a\u95f4\uff0c\u8fd9\u4e9b\u662f\u8d2f\u7a7f\u7ebf\u6027\u4ee3\u6570\u548c ML \u7684\u57fa\u7840\u6027\u8d28\u3002

\\[ A = \\begin{bmatrix} 1 & 2 & 3 \\\\ 4 & 5 & 6 \\end{bmatrix} \\] \\[ \\begin{bmatrix} 25 & 170 & 65 \\\\ 30 & 180 & 80 \\\\ 22 & 160 & 55 \\end{bmatrix} \\] \\[ A = \\begin{bmatrix} 1 & 2 & 3 \\\\ 4 & 5 & 6 \\end{bmatrix} \\quad \\Rightarrow \\quad A^T = \\begin{bmatrix} 1 & 4 \\\\ 2 & 5 \\\\ 3 & 6 \\end{bmatrix} \\]

\\[ \\begin{bmatrix} 1 & 2 \\\\ 3 & 4 \\end{bmatrix} \\]

\u4f46\u4ee5\u4e0b\u77e9\u9635\u7684\u79e9\u4e3a 1\uff0c\u56e0\u4e3a\u7b2c\u4e8c\u884c\u53ea\u662f\u7b2c\u4e00\u884c\u7684\u4e24\u500d\uff0c\u6240\u4ee5\u5b83\u6ca1\u6709\u589e\u52a0\u65b0\u4fe1\u606f\uff1a

\\[ \\begin{bmatrix} 1 & 2 \\\\ 2 & 4 \\end{bmatrix} \\]

\\[ \\det\\begin{bmatrix} a & b \\\\ c & d \\end{bmatrix} = ad - bc \\]

\\[ \\det\\begin{bmatrix} 2 & 1 \\\\ 0 & 3 \\end{bmatrix} = 2 \\cdot 3 - 1 \\cdot 0 = 6 \\]

\u8fd9\u4e2a\u53d8\u6362\u5c06\u5355\u4f4d\u6b63\u65b9\u5f62\u62c9\u4f38\u6210\u4e00\u4e2a\u9762\u79ef\u4e3a 6 \u7684\u5e73\u884c\u56db\u8fb9\u5f62\u3002

\\[ \\begin{bmatrix} a & b \\\\ c & d \\end{bmatrix}^{-1} = \\frac{1}{ad - bc}\\begin{bmatrix} d & -b \\\\ -c & a \\end{bmatrix} \\]

\u6ce8\u610f\u5206\u6bcd\u4e2d\u7684\u884c\u5217\u5f0f\uff0c\u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u5947\u5f02\u77e9\u9635\uff08\u884c\u5217\u5f0f\u4e3a\u96f6\uff09\u6ca1\u6709\u9006\u3002

\\[ \\begin{bmatrix} 1 & 0 \\\\ 0 & 10^{-8} \\end{bmatrix} \\] \\[ \\|A\\|_F = \\sqrt{\\sum_{i}\\sum_{j} A_{ij}^2} \\] \\[ \\left\\|\\begin{bmatrix} 1 & 2 \\\\ 3 & 4 \\end{bmatrix}\\right\\|_F = \\sqrt{1 + 4 + 9 + 16} = \\sqrt{30} \\approx 5.48 \\] \\[ A = \\begin{bmatrix} 2 & 1 \\\\ 1 & 3 \\end{bmatrix} \\]

\u53d6\u4efb\u610f\u5411\u91cf\uff0c\u6bd4\u5982 \\(\\mathbf{x} = [1, -1]^T\\)\uff1a\\(\\mathbf{x}^T A \\mathbf{x} = 2 - 1 - 1 + 3 = 3 > 0\\)\u3002\u65e0\u8bba\u4f60\u5c1d\u8bd5\u54ea\u4e2a\u975e\u96f6 \\(\\mathbf{x}\\)\uff0c\u4f60\u603b\u662f\u5f97\u5230\u6b63\u7684\u7ed3\u679c\u3002

"},{"location":"chapter%2002%3A%20matrices/01.%20matrix%20properties/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u8ba1\u7b97\u77e9\u9635\u7684\u8ff9\u3001\u79e9\u548c\u884c\u5217\u5f0f\u3002\u5c1d\u8bd5\u4f7f\u4e00\u884c\u6210\u4e3a\u53e6\u4e00\u884c\u7684\u500d\u6570\uff0c\u89c2\u5bdf\u79e9\u548c\u884c\u5217\u5f0f\u5982\u4f55\u53d8\u5316\u3002

    import jax.numpy as jnp\n\nA = jnp.array([[1.0, 2.0],\n               [3.0, 4.0]])\n\nprint(f\"Trace: {jnp.trace(A)}\")\nprint(f\"Rank: {jnp.linalg.matrix_rank(A)}\")\nprint(f\"Determinant: {jnp.linalg.det(A):.2f}\")\n

  2. \u8ba1\u7b97\u77e9\u9635\u7684\u9006\uff0c\u5c06\u5176\u4e58\u4ee5\u539f\u77e9\u9635\uff0c\u9a8c\u8bc1\u5f97\u5230\u5355\u4f4d\u77e9\u9635\u3002\u7136\u540e\u5c1d\u8bd5\u5947\u5f02\u77e9\u9635\u5e76\u89c2\u5bdf\u4f1a\u53d1\u751f\u4ec0\u4e48\u3002

    import jax.numpy as jnp\n\nA = jnp.array([[1.0, 2.0],\n               [3.0, 4.0]])\n\nA_inv = jnp.linalg.inv(A)\nprint(f\"A * A_inv:\\n{A @ A_inv}\")\n

"},{"location":"chapter%2002%3A%20matrices/02.%20matrix%20types/","title":"\u77e9\u9635\u7c7b\u578b","text":"

\u7279\u6b8a\u7684\u77e9\u9635\u7ed3\u6784\u80fd\u591f\u89e3\u9501\u8ba1\u7b97\u6377\u5f84\u548c\u6570\u5b66\u4fdd\u8bc1\u3002\u672c\u6587\u6db5\u76d6\u5355\u4f4d\u77e9\u9635\u3001\u5bf9\u89d2\u77e9\u9635\u3001\u5bf9\u79f0\u77e9\u9635\u3001\u4e09\u89d2\u77e9\u9635\u3001\u6b63\u4ea4\u77e9\u9635\u3001\u6b63\u5b9a\u77e9\u9635\u3001\u7a00\u758f\u77e9\u9635\u548c\u968f\u673a\u77e9\u9635\uff0c\u8fd9\u4e9b\u7c7b\u578b\u51fa\u73b0\u5728\u534f\u65b9\u5dee\u4f30\u8ba1\u3001\u56fe\u7b97\u6cd5\u3001\u6b63\u5219\u5316\u548c\u9a6c\u5c14\u53ef\u592b\u94fe\u4e2d\u3002

\\[ I = \\begin{bmatrix} 1 & 0 & 0 \\\\ 0 & 1 & 0 \\\\ 0 & 0 & 1 \\end{bmatrix} \\] \\[ D = \\begin{bmatrix} 3 & 0 \\\\ 0 & 7 \\end{bmatrix} \\] \\[ S = \\begin{bmatrix} 3 & -1 \\\\ -1 & 6 \\end{bmatrix} \\] \\[ L = \\begin{bmatrix} 2 & 0 & 0 \\\\ 1 & 3 & 0 \\\\ -1 & 2 & 4 \\end{bmatrix} \\qquad U = \\begin{bmatrix} 5 & -1 & 2 \\\\ 0 & 1 & 3 \\\\ 0 & 0 & -2 \\end{bmatrix} \\]

\\[ P = \\begin{bmatrix} 0 & 0 & 1 \\\\ 1 & 0 & 0 \\\\ 0 & 1 & 0 \\end{bmatrix} \\] \\[ T = \\begin{bmatrix} a & b & c \\\\ d & a & b \\\\ e & d & a \\end{bmatrix} \\] \\[ C = \\begin{bmatrix} 1 & 3 & 2 \\\\ 2 & 1 & 3 \\\\ 3 & 2 & 1 \\end{bmatrix} \\] \\[ \\begin{bmatrix} 0 & 1 \\\\ 0 & 0 \\end{bmatrix}^2 = \\begin{bmatrix} 0 & 0 \\\\ 0 & 0 \\end{bmatrix} \\] \\[ B = \\begin{bmatrix} 0 & 1 & 1 \\\\ 1 & 0 & 0 \\\\ 1 & 0 & 0 \\end{bmatrix} \\] \\[ V = \\begin{bmatrix} 1 & x_1 & x_1^2 \\\\ 1 & x_2 & x_2^2 \\\\ 1 & x_3 & x_3^2 \\end{bmatrix} \\] \\[ H = \\begin{bmatrix} 4 & 2 & 1 \\\\ 3 & 5 & -1 \\\\ 0 & 1 & 6 \\end{bmatrix} \\] "},{"location":"chapter%2002%3A%20matrices/02.%20matrix%20types/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u521b\u5efa\u4e00\u4e2a\u6b63\u4ea4\u77e9\u9635\uff08\u65cb\u8f6c\u77e9\u9635\uff09\uff0c\u4e58\u4ee5\u5176\u8f6c\u7f6e\uff0c\u9a8c\u8bc1\u5f97\u5230\u5355\u4f4d\u77e9\u9635\u3002\u5c1d\u8bd5\u4e0d\u540c\u7684\u89d2\u5ea6\u3002

    import jax.numpy as jnp\n\ntheta = jnp.pi / 4\nQ = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],\n               [jnp.sin(theta),  jnp.cos(theta)]])\n\nprint(f\"Q @ Q.T:\\n{Q @ Q.T}\")\nprint(f\"Determinant: {jnp.linalg.det(Q):.2f}\")\n

  2. \u521b\u5efa\u4e00\u4e2a\u5bf9\u79f0\u77e9\u9635\u5e76\u9a8c\u8bc1\u5b83\u7b49\u4e8e\u5176\u8f6c\u7f6e\u3002\u7136\u540e\u8ba1\u7b97\u5176\u7279\u5f81\u503c\u5e76\u68c0\u67e5\u7279\u5f81\u5411\u91cf\u662f\u5426\u5782\u76f4\u3002

    import jax.numpy as jnp\n\nS = jnp.array([[4.0, 2.0],\n               [2.0, 3.0]])\n\nprint(f\"Symmetric: {jnp.allclose(S, S.T)}\")\n\neigenvalues, eigenvectors = jnp.linalg.eigh(S)\nprint(f\"Eigenvalues: {eigenvalues}\")\nprint(f\"Dot product of eigenvectors: {jnp.dot(eigenvectors[:, 0], eigenvectors[:, 1]):.6f}\")\n

"},{"location":"chapter%2002%3A%20matrices/03.%20operations/","title":"\u77e9\u9635\u8fd0\u7b97","text":"

\u77e9\u9635\u8fd0\u7b97\u662f\u6df1\u5ea6\u5b66\u4e60\u7684\u8ba1\u7b97\u5f15\u64ce\u3002\u672c\u6587\u6db5\u76d6\u77e9\u9635\u52a0\u6cd5\u3001\u6807\u91cf\u4e58\u6cd5\u3001\u77e9\u9635-\u5411\u91cf\u79ef\u3001\u77e9\u9635\u4e58\u6cd5\u3001\u9010\u5143\u7d20\u8fd0\u7b97\u3001Kronecker\u79ef\u548c\u5e7f\u64ad\u2014\u2014\u652f\u6491\u6bcf\u4e00\u6b21\u524d\u5411\u4f20\u64ad\u548c\u68af\u5ea6\u66f4\u65b0\u7684\u8fd0\u7b97\u3002

\\[ \\begin{bmatrix} 1 & 2 \\\\ 3 & 4 \\end{bmatrix} + \\begin{bmatrix} 5 & 6 \\\\ 7 & 8 \\end{bmatrix} = \\begin{bmatrix} 6 & 8 \\\\ 10 & 12 \\end{bmatrix} \\] \\[ 3 \\times \\begin{bmatrix} 1 & 2 \\\\ 3 & 4 \\end{bmatrix} = \\begin{bmatrix} 3 & 6 \\\\ 9 & 12 \\end{bmatrix} \\] \\[ \\begin{bmatrix} 1 & 2 \\\\ 3 & 4 \\end{bmatrix} \\begin{bmatrix} 5 \\\\ 6 \\end{bmatrix} = 5 \\begin{bmatrix} 1 \\\\ 3 \\end{bmatrix} + 6 \\begin{bmatrix} 2 \\\\ 4 \\end{bmatrix} = \\begin{bmatrix} 17 \\\\ 39 \\end{bmatrix} \\] \\[C_{ij} = \\sum_{k=1}^{n} A_{ik} B_{kj}\\] \\[ \\begin{bmatrix} 1 & 2 & 3 \\\\ 4 & 5 & 6 \\end{bmatrix} \\begin{bmatrix} 1 & 4 \\\\ 2 & 5 \\\\ 3 & 6 \\end{bmatrix} = \\begin{bmatrix} 14 & 32 \\\\ 32 & 77 \\end{bmatrix} \\] \\[ \\begin{bmatrix} 1 & 2 \\\\ 3 & 4 \\end{bmatrix} \\odot \\begin{bmatrix} 5 & 6 \\\\ 7 & 8 \\end{bmatrix} = \\begin{bmatrix} 5 & 12 \\\\ 21 & 32 \\end{bmatrix} \\] \\[ \\begin{bmatrix} 1 \\\\ 2 \\\\ 3 \\end{bmatrix} \\begin{bmatrix} 4 & 5 \\end{bmatrix} = \\begin{bmatrix} 4 & 5 \\\\ 8 & 10 \\\\ 12 & 15 \\end{bmatrix} \\] \\[ A = \\begin{bmatrix} 5 & 0 & 0 & 2 \\\\ 0 & 0 & 3 & 0 \\\\ 0 & 0 & 0 & -1 \\end{bmatrix} \\] \\[ \\begin{bmatrix} 2 & 1 \\\\ 1 & 3 \\end{bmatrix} \\begin{bmatrix} x_1 \\\\ x_2 \\end{bmatrix} = \\begin{bmatrix} 5 \\\\ 10 \\end{bmatrix} \\] \\[2x_1 + 1x_2 = 5 \\qquad \\text{(\u7b2c1\u884c)} \\qquad \\qquad x_1 + 3x_2 = 10 \\qquad \\text{(\u7b2c2\u884c)}\\] \\[ \\begin{bmatrix} 2 & 1 \\\\ 1 & 3 \\end{bmatrix} \\begin{bmatrix} 1 \\\\ 3 \\end{bmatrix} = \\begin{bmatrix} 2 + 3 \\\\ 1 + 9 \\end{bmatrix} = \\begin{bmatrix} 5 \\\\ 10 \\end{bmatrix} \\] \\[A^+ = (A^TA)^{-1}A^T\\] "},{"location":"chapter%2002%3A%20matrices/03.%20operations/#colabjupyter-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528CoLab\u6216Jupyter Notebook\uff09","text":"
  1. \u5c06\u4e24\u4e2a\u77e9\u9635\u76f8\u4e58\u5e76\u9a8c\u8bc1\u7ef4\u5ea6\u3002\u7136\u540e\u4ea4\u6362\u987a\u5e8f\uff0c\u89c2\u5bdf\u7ed3\u679c\u5982\u4f55\u53d8\u5316\uff08\u6216\u8005\uff0c\u5982\u679c\u7ef4\u5ea6\u4e0d\u5339\u914d\uff0c\u8fd0\u7b97\u5931\u8d25\uff09\u3002
import jax.numpy as jnp\n\nA = jnp.array([[1.0, 2.0],\n               [3.0, 4.0]])\nB = jnp.array([[5.0, 6.0],\n               [7.0, 8.0]])\n\nprint(f\"A @ B:\\n{A @ B}\")\nprint(f\"B @ A:\\n{B @ A}\")\nprint(f\"Equal: {jnp.allclose(A @ B, B @ A)}\")\n
  2. \u6c42\u89e3\u7ebf\u6027\u65b9\u7a0b\u7ec4 \\(A\\mathbf{x} = \\mathbf{b}\\)\uff0c\u5e76\u901a\u8fc7\u56de\u4ee3\uff08\u8ba1\u7b97 \\(A\\mathbf{x}\\)\uff09\u9a8c\u8bc1\u89e3\u3002\u5c1d\u8bd5\u6539\u53d8 \\(\\mathbf{b}\\)\uff0c\u89c2\u5bdf\u89e3\u5982\u4f55\u53d8\u5316\u3002
import jax.numpy as jnp\n\nA = jnp.array([[2.0, 1.0],\n               [5.0, 3.0]])\nb = jnp.array([4.0, 7.0])\n\nx = jnp.linalg.solve(A, b)\nprint(f\"Solution x: {x}\")\nprint(f\"A @ x: {A @ x}\")\n
"},{"location":"chapter%2002%3A%20matrices/04.%20linear%20transformations/","title":"\u7ebf\u6027\u53d8\u6362","text":"

\u6bcf\u4e2a\u77e9\u9635\u4e58\u6cd5\u90fd\u662f\u4e00\u4e2a\u7ebf\u6027\u53d8\u6362\u2014\u2014\u4e00\u4e2a\u5728\u4fdd\u6301\u7ebf\u6027\u6027\u8d28\u7684\u540c\u65f6\u91cd\u5851\u3001\u65cb\u8f6c\u6216\u6295\u5f71\u5411\u91cf\u7684\u51fd\u6570\u3002\u672c\u6587\u6db5\u76d6\u65cb\u8f6c\u3001\u53cd\u5c04\u3001\u7f29\u653e\u3001\u526a\u5207\u3001\u6295\u5f71\u3001\u6620\u5c04\u7684\u6838\u4e0e\u50cf\uff0c\u4ee5\u53ca\u795e\u7ecf\u7f51\u7edc\u5c42\u5982\u4f55\u4e32\u8054\u8fd9\u4e9b\u53d8\u6362\u3002

\\[ A = \\begin{bmatrix} 2 & 1 \\\\ 1 & 2 \\end{bmatrix} \\]

\u90a3\u4e48 \\(\\hat{\\mathbf{i}} = [1, 0]^T\\) \u843d\u5728 \\([2, 1]^T\\)\uff08\u7b2c1\u5217\uff09\uff0c\\(\\hat{\\mathbf{j}} = [0, 1]^T\\) \u843d\u5728 \\([1, 2]^T\\)\uff08\u7b2c2\u5217\uff09\u3002\u5176\u4ed6\u6240\u6709\u5411\u91cf\u90fd\u662f\u8fd9\u4e24\u4e2a\u57fa\u5411\u91cf\u7684\u7ebf\u6027\u7ec4\u5408\uff0c\u56e0\u6b64\u5b83\u4eec\u7684\u8f93\u51fa\u4e5f\u968f\u4e4b\u786e\u5b9a\u3002

\\[ R(\\theta) = \\begin{bmatrix} \\cos\\theta & -\\sin\\theta \\\\ \\sin\\theta & \\cos\\theta \\end{bmatrix} \\] \\[ R = \\begin{bmatrix} 0 & -1 \\\\ 1 & 0 \\end{bmatrix} \\]

\u56e0\u6b64 \\([1, 0]^T\\) \u53d8\u6210 \\([0, 1]^T\\)\u3002\u539f\u6765\u6307\u5411\u53f3\u4fa7\u7684\u5411\u91cf\u73b0\u5728\u6307\u5411\u4e0a\u65b9\u3002\u65cb\u8f6c\u77e9\u9635\u662f\u6b63\u4ea4\u7684\uff0c\u4e14\u884c\u5217\u5f0f\u59cb\u7ec8\u4e3a1\u3002\u5f53\u4f60\u5728\u624b\u673a\u4e0a\u65cb\u8f6c\u7167\u7247\u65f6\uff0c\u5c31\u662f\u5bf9\u6bcf\u4e2a\u50cf\u7d20\u5750\u6807\u5e94\u7528\u8fd9\u4e2a\u77e9\u9635\u3002

\\[ R_z(\\theta) = \\begin{bmatrix} \\cos\\theta & -\\sin\\theta & 0 \\\\ \\sin\\theta & \\cos\\theta & 0 \\\\ 0 & 0 & 1 \\end{bmatrix} \\] \\[ S(s_x, s_y) = \\begin{bmatrix} s_x & 0 \\\\ 0 & s_y \\end{bmatrix} \\]

\\[ \\text{Ref}_x = \\begin{bmatrix} 1 & 0 \\\\ 0 & -1 \\end{bmatrix} \\]

\\[ \\text{Ref}_{y=x} = \\begin{bmatrix} 0 & 1 \\\\ 1 & 0 \\end{bmatrix} \\] \\[ \\text{Sh}_x(k) = \\begin{bmatrix} 1 & k \\\\ 0 & 1 \\end{bmatrix} \\]

\\[\\mathbf{y} = A\\mathbf{x} + \\mathbf{t}\\] \\[ \\begin{bmatrix} A & \\mathbf{t} \\\\ \\mathbf{0}^T & 1 \\end{bmatrix} \\begin{bmatrix} \\mathbf{x} \\\\ 1 \\end{bmatrix} = \\begin{bmatrix} A\\mathbf{x} + \\mathbf{t} \\\\ 1 \\end{bmatrix} \\] \\[ \\begin{bmatrix} 1 & 2 \\\\ 2 & 4 \\end{bmatrix} \\]

\u5c06\u6bcf\u4e2a\u4e8c\u7ef4\u5411\u91cf\u6620\u5c04\u5230\u4e00\u6761\u76f4\u7ebf\u4e0a\uff0c\u56e0\u4e3a\u4e24\u5217\u6307\u5411\u540c\u4e00\u65b9\u5411\u3002\u884c\u5217\u5f0f\u4e3a\u96f6\uff0c\u4fe1\u606f\u4e22\u5931\uff0c\u4e14\u8be5\u53d8\u6362\u4e0d\u53ef\u9006\u3002

"},{"location":"chapter%2002%3A%20matrices/04.%20linear%20transformations/#colabjupyter-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528CoLab\u6216Jupyter Notebook\uff09","text":"
  1. \u5bf9\u5411\u91cf\u5e94\u7528\u65cb\u8f6c\u77e9\u9635\uff0c\u5e76\u7ed8\u5236\u539f\u59cb\u5411\u91cf\u548c\u65cb\u8f6c\u540e\u7684\u5411\u91cf\u3002\u5c1d\u8bd5\u4e0d\u540c\u7684\u89d2\u5ea6\u3002
import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ntheta = jnp.pi / 3\nR = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],\n               [jnp.sin(theta),  jnp.cos(theta)]])\n\nv = jnp.array([1.0, 0.0])\nv_rot = R @ v\n\nplt.figure(figsize=(5, 5))\nplt.quiver(0, 0, v[0], v[1], angles='xy', scale_units='xy', scale=1, color='red', label='original')\nplt.quiver(0, 0, v_rot[0], v_rot[1], angles='xy', scale_units='xy', scale=1, color='blue', label='rotated')\nplt.xlim(-1.5, 1.5); plt.ylim(-1.5, 1.5)\nplt.grid(True); plt.legend(); plt.gca().set_aspect('equal')\nplt.show()\n
  2. \u5bf9\u6784\u6210\u6b63\u65b9\u5f62\u7684\u4e00\u7ec4\u70b9\u5e94\u7528\u526a\u5207\u53d8\u6362\uff0c\u5e76\u53ef\u89c6\u5316\u53d8\u5f62\u540e\u7684\u5f62\u72b6\u3002
import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nsquare = jnp.array([[0,0],[1,0],[1,1],[0,1],[0,0]]).T\n\nk = 0.5\nshear = jnp.array([[1, k],\n                    [0, 1]])\nsheared = shear @ square\n\nplt.figure(figsize=(6, 4))\nplt.plot(square[0], square[1], 'r-o', label='original')\nplt.plot(sheared[0], sheared[1], 'b-o', label='sheared')\nplt.grid(True); plt.legend(); plt.gca().set_aspect('equal')\nplt.show()\n
"},{"location":"chapter%2002%3A%20matrices/05.%20decompositions/","title":"\u77e9\u9635\u5206\u89e3","text":"

\u77e9\u9635\u5206\u89e3\u5c06\u590d\u6742\u77e9\u9635\u62c6\u5206\u4e3a\u66f4\u7b80\u5355\u7684\u56e0\u5b50\uff0c\u7528\u4e8e\u6c42\u89e3\u65b9\u7a0b\u7ec4\u3001\u8ba1\u7b97\u9006\u77e9\u9635\u548c\u6570\u636e\u538b\u7f29\u3002\u672c\u6587\u6db5\u76d6\u9ad8\u65af\u6d88\u5143\u3001LU\u3001QR\u3001Cholesky\u3001\u7279\u5f81\u5206\u89e3\u548cSVD\u2014\u2014\u8fd9\u4e9b\u7b97\u6cd5\u662fPCA\u3001\u63a8\u8350\u7cfb\u7edf\u548c\u673a\u5668\u5b66\u4e60\u6570\u503c\u7a33\u5b9a\u6027\u7684\u57fa\u77f3\u3002

\\[ \\begin{bmatrix} 2 & 1 & 5 \\\\ 4 & 3 & 7 \\\\ 6 & 5 & 9 \\end{bmatrix} \\xrightarrow{R_2 - 2R_1} \\begin{bmatrix} 2 & 1 & 5 \\\\ 0 & 1 & -3 \\\\ 6 & 5 & 9 \\end{bmatrix} \\xrightarrow{R_3 - 3R_1} \\begin{bmatrix} 2 & 1 & 5 \\\\ 0 & 1 & -3 \\\\ 0 & 2 & -6 \\end{bmatrix} \\]

\\[ \\begin{bmatrix} 4 & 2 \\\\ 2 & 5 \\end{bmatrix} = \\begin{bmatrix} 2 & 0 \\\\ 1 & 2 \\end{bmatrix} \\begin{bmatrix} 2 & 1 \\\\ 0 & 2 \\end{bmatrix} \\] \\[A\\mathbf{x} = \\lambda\\mathbf{x}\\]

\\[ A = \\begin{bmatrix} 3 & 1 \\\\ 0 & 2 \\end{bmatrix} \\]

\u5411\u91cf \\([1, 0]^T\\) \u662f\u7279\u5f81\u5411\u91cf\uff0c\\(\\lambda = 3\\)\uff0c\u56e0\u4e3a \\(A[1, 0]^T = [3, 0]^T = 3[1, 0]^T\\)\u3002

"},{"location":"chapter%2002%3A%20matrices/05.%20decompositions/#colabjupyter-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528CoLab\u6216Jupyter Notebook\uff09","text":"
  1. \u8ba1\u7b97\u5bf9\u79f0\u77e9\u9635\u7684\u7279\u5f81\u503c\u548c\u7279\u5f81\u5411\u91cf\u3002\u9a8c\u8bc1\u7279\u5f81\u5411\u91cf\u4e92\u76f8\u5782\u76f4\uff0c\u5e76\u4ece\u7279\u5f81\u5206\u89e3\u91cd\u5efa\u77e9\u9635\u3002
import jax.numpy as jnp\n\nA = jnp.array([[4.0, 2.0],\n               [2.0, 3.0]])\n\neigenvalues, eigenvectors = jnp.linalg.eigh(A)\nprint(f\"Eigenvalues: {eigenvalues}\")\nprint(f\"Eigenvectors orthogonal: {jnp.dot(eigenvectors[:,0], eigenvectors[:,1]):.6f}\")\n\n# Reconstruct: A = P D P^T\nD = jnp.diag(eigenvalues)\nA_reconstructed = eigenvectors @ D @ eigenvectors.T\nprint(f\"Reconstruction matches: {jnp.allclose(A, A_reconstructed)}\")\n
  2. \u5b9e\u73b0\u5e42\u8fed\u4ee3\u6c42\u6700\u5927\u7279\u5f81\u503c\uff0c\u4ee5\u53ca\u53cd\u8fed\u4ee3\u6c42\u6700\u5c0f\u7279\u5f81\u503c\u3002\u4e0e jnp.linalg.eigh \u6bd4\u8f83\u3002\u7136\u540e\u5c1d\u8bd5\u81ea\u5df1\u5b9e\u73b0QR\u7b97\u6cd5\u3002
import jax.numpy as jnp\n\nA = jnp.array([[4.0, 2.0],\n               [2.0, 3.0]])\n\n# Power iteration: finds the LARGEST eigenvalue\nv = jnp.array([1.0, 0.0])\nfor _ in range(20):\n    v = A @ v\n    v = v / jnp.linalg.norm(v)\nprint(f\"Largest eigenvalue:  {v @ A @ v:.4f}\")\n\n# Inverse iteration: multiply by A^{-1} instead of A, finds the SMALLEST eigenvalue\nv = jnp.array([1.0, 0.0])\nfor _ in range(20):\n    v = jnp.linalg.solve(A, v)\n    v = v / jnp.linalg.norm(v)\nprint(f\"Smallest eigenvalue: {1.0 / (v @ jnp.linalg.solve(A, v)):.4f}\")\n\nprint(f\"jnp.linalg.eigh:    {jnp.linalg.eigh(A)[0]}\")\n
  3. \u8ba1\u7b97\u77e9\u9635\u7684SVD\uff0c\u7136\u540e\u4ec5\u4f7f\u7528\u524dk\u4e2a\u5947\u5f02\u503c\u91cd\u5efa\u77e9\u9635\uff0c\u89c2\u5bdf\u8fd1\u4f3c\u8d28\u91cf\u968fk\u7684\u53d8\u5316\u3002
import jax.numpy as jnp\n\nA = jnp.array([[1.0, 2.0, 3.0],\n               [4.0, 5.0, 6.0],\n               [7.0, 8.0, 9.0]])\n\nU, S, Vt = jnp.linalg.svd(A)\n\nfor k in [1, 2, 3]:\n    approx = U[:, :k] @ jnp.diag(S[:k]) @ Vt[:k, :]\n    error = jnp.linalg.norm(A - approx)\n    print(f\"k={k}, reconstruction error: {error:.4f}\")\n
"},{"location":"chapter%2003%3A%20calculus/01.%20differential%20calculus/","title":"\u5fae\u5206","text":"

\u5fae\u5206\u5b66\u7814\u7a76\u77ac\u65f6\u53d8\u5316\u7387\u3002\u672c\u8282\u6db5\u76d6\u6781\u9650\u3001\u5bfc\u6570\u3001\u5fae\u5206\u6cd5\u5219\u3001\u94fe\u5f0f\u6cd5\u5219\uff08\u53cd\u5411\u4f20\u64ad\u7684\u57fa\u7840\uff09\uff0c\u4ee5\u53ca\u673a\u5668\u5b66\u4e60\u4e2d\u5e38\u7528\u7684\u5bfc\u6570\u3002

\\[\\lim_{x \\to a} f(x) = L\\]

\\[f'(a) = \\lim_{h \\to 0} \\frac{f(a + h) - f(a)}{h}\\]

\\[f'(3) = \\lim_{h \\to 0} \\frac{(3+h)^2 - 9}{h} = \\lim_{h \\to 0} \\frac{9 + 6h + h^2 - 9}{h} = \\lim_{h \\to 0} (6 + h) = 6\\] \\[\\frac{d}{dx} x^n = n x^{n-1}\\] \\[\\frac{d}{dx}[f(x) \\pm g(x)] = f'(x) \\pm g'(x)\\] \\[\\frac{d}{dx}[f(x) \\cdot g(x)] = f'(x)g(x) + f(x)g'(x)\\] \\[\\frac{d}{dx}\\left[\\frac{f(x)}{g(x)}\\right] = \\frac{f'(x)g(x) - f(x)g'(x)}{[g(x)]^2}\\] \\[\\frac{d}{dx} f(g(x)) = f'(g(x)) \\cdot g'(x)\\]

\u51fd\u6570 \u5bfc\u6570 \u5907\u6ce8 \\(e^x\\) \\(e^x\\) \u552f\u4e00\u4e00\u4e2a\u5bfc\u6570\u7b49\u4e8e\u81ea\u8eab\u7684\u51fd\u6570 \\(a^x\\) \\(a^x \\ln a\\) \u6307\u6570\u51fd\u6570\u7684\u4e00\u822c\u5f62\u5f0f \\(\\ln x\\) \\(\\frac{1}{x}\\) \u81ea\u7136\u5bf9\u6570 \\(\\log_a x\\) \\(\\frac{1}{x \\ln a}\\) \u4e00\u822c\u5bf9\u6570 \\(\\sin x\\) \\(\\cos x\\) \\(\\cos x\\) \\(-\\sin x\\) \u6ce8\u610f\u8d1f\u53f7 \\(\\tan x\\) \\(\\sec^2 x\\) \\[\\lim_{x \\to a} \\frac{f(x)}{g(x)} = \\lim_{x \\to a} \\frac{f'(x)}{g'(x)}\\] "},{"location":"chapter%2003%3A%20calculus/01.%20differential%20calculus/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u53ef\u89c6\u5316\u5e38\u89c1\u51fd\u6570\u3002\u5728\u540c\u4e00\u5f20\u56fe\u4e2d\u7ed8\u5236 \\(x^2\\)\u3001\\(\\sin(x)\\) \u548c \\(e^x\\)\uff0c\u5efa\u7acb\u5bf9\u4e0d\u540c\u516c\u5f0f\u4ea7\u751f\u4e0d\u540c\u5f62\u72b6\u7684\u76f4\u89c2\u611f\u53d7\u3002\u5c1d\u8bd5\u4fee\u6539\u53c2\u6570\uff08\u4f8b\u5982 \\(2x^2\\)\u3001\\(\\sin(2x)\\)\uff09\uff0c\u89c2\u5bdf\u66f2\u7ebf\u5982\u4f55\u53d8\u5316\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nx = jnp.linspace(-3, 3, 300)\n\nfig, axes = plt.subplots(1, 3, figsize=(12, 3))\naxes[0].plot(x, x**2, color=\"#e74c3c\")\naxes[0].set_title(\"x\u00b2  (\u629b\u7269\u7ebf)\")\naxes[1].plot(x, jnp.sin(x), color=\"#3498db\")\naxes[1].set_title(\"sin(x)  (\u6ce2\u5f62)\")\naxes[2].plot(x, jnp.exp(x), color=\"#27ae60\")\naxes[2].set_title(\"e\u02e3  (\u6307\u6570\u51fd\u6570)\")\nfor ax in axes:\n    ax.axhline(0, color=\"gray\", linewidth=0.5)\n    ax.axvline(0, color=\"gray\", linewidth=0.5)\nplt.tight_layout()\nplt.show()\n

  2. \u4f7f\u7528 JAX \u7684\u81ea\u52a8\u5fae\u5206\u8ba1\u7b97 \\(f(x) = x^3 - 2x + 1\\) \u5728\u82e5\u5e72\u70b9\u5904\u7684\u5bfc\u6570\uff0c\u5e76\u4e0e\u89e3\u6790\u5bfc\u6570 \\(f'(x) = 3x^2 - 2\\) \u8fdb\u884c\u6bd4\u8f83\u3002

    import jax\nimport jax.numpy as jnp\n\nf = lambda x: x**3 - 2*x + 1\ndf = jax.grad(f)\n\nfor x in [0.0, 1.0, 2.0, -1.0]:\n    print(f\"x={x:5.1f}  \u81ea\u52a8\u5fae\u5206: {df(x):.4f}  \u89e3\u6790\u89e3: {3*x**2 - 2:.4f}\")\n

  3. \u6570\u503c\u9a8c\u8bc1\u94fe\u5f0f\u6cd5\u5219\u3002\u5b9a\u4e49 \\(f(x) = \\sin(x^2)\\)\uff0c\u901a\u8fc7 jax.grad \u8ba1\u7b97\u5176\u5bfc\u6570\uff0c\u5e76\u4e0e\u89e3\u6790\u7ed3\u679c \\(2x\\cos(x^2)\\) \u8fdb\u884c\u6bd4\u8f83\u3002

    import jax\nimport jax.numpy as jnp\n\nf = lambda x: jnp.sin(x**2)\ndf = jax.grad(f)\n\nfor x in [0.5, 1.0, 2.0]:\n    auto = df(x)\n    analytical = 2*x * jnp.cos(x**2)\n    print(f\"x={x:.1f}  \u81ea\u52a8\u5fae\u5206: {auto:.6f}  \u89e3\u6790\u89e3: {analytical:.6f}\")\n

  4. \u53ef\u89c6\u5316\u5bfc\u6570\u3002\u5c06 \\(f(x) = x^3 - 3x\\) \u4e0e\u5176\u5bfc\u6570 \\(f'(x) = 3x^2 - 3\\) \u7ed8\u5236\u5728\u540c\u4e00\u5f20\u56fe\u4e0a\u3002\u89c2\u5bdf \\(f'(x) = 0\\) \u7684\u4f4d\u7f6e\u4e0e \\(f\\) \u7684\u5cf0\u8c37\u4e4b\u95f4\u7684\u5bf9\u5e94\u5173\u7cfb\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nf = lambda x: x**3 - 3*x\n# jax.grad \u7528\u4e8e\u6807\u91cf\uff1bjax.vmap \u5c06\u5176\u5411\u91cf\u5316\uff0c\u53ef\u540c\u65f6\u5904\u7406\u4e00\u7ec4\u8f93\u5165\ndf = jax.vmap(jax.grad(f))\n\nx = jnp.linspace(-2.5, 2.5, 200)\nplt.plot(x, jax.vmap(f)(x), label=\"f(x)\")\nplt.plot(x, df(x), label=\"f'(x)\", linestyle=\"--\")\nplt.axhline(0, color=\"gray\", linewidth=0.5)\nplt.legend()\nplt.title(\"\u51fd\u6570\u53ca\u5176\u5bfc\u6570\")\nplt.show()\n

"},{"location":"chapter%2003%3A%20calculus/02.%20integral%20calculus/","title":"\u79ef\u5206\u5b66","text":"

\u79ef\u5206\u5b66\u7814\u7a76\u91cf\u5728\u533a\u95f4\u4e0a\u7684\u7d2f\u79ef\uff0c\u5c06\u5c40\u90e8\u53d8\u5316\u7387\u8fd8\u539f\u4e3a\u603b\u91cf\u3002\u672c\u6587\u6db5\u76d6\u5b9a\u79ef\u5206\u4e0e\u4e0d\u5b9a\u79ef\u5206\u3001\u5fae\u79ef\u5206\u57fa\u672c\u5b9a\u7406\u3001\u79ef\u5206\u6280\u5de7\uff0c\u4ee5\u53ca\u79ef\u5206\u5728\u673a\u5668\u5b66\u4e60\u4e2d\u7528\u4e8e\u6982\u7387\u5bc6\u5ea6\u548c\u671f\u671b\u503c\u7684\u5e94\u7528\u3002

\\[\\text{\u9762\u79ef} \\approx \\sum_{i=1}^{n} f(x_i^\\ast) \\, \\Delta x\\] \\[\\int_a^b f(x)\\, dx = \\lim_{n \\to \\infty} \\sum_{i=1}^{n} f(x_i^\\ast) \\, \\Delta x\\] \\[\\int f(x)\\, dx = F(x) + C\\] \\[\\int_a^b f(x)\\, dx = F(b) - F(a)\\] \u51fd\u6570 \u79ef\u5206 \u6761\u4ef6 \\(x^n\\) \\(\\frac{x^{n+1}}{n+1} + C\\) \\(n \\neq -1\\) \\(\\frac{1}{x}\\) \\(\\ln|x| + C\\) \\(e^x\\) \\(e^x + C\\) \\(a^x\\) \\(\\frac{a^x}{\\ln a} + C\\) \\(\\sin x\\) \\(-\\cos x + C\\) \\(\\cos x\\) \\(\\sin x + C\\) \\(k\\)\uff08\u5e38\u6570\uff09 \\(kx + C\\) \\[\\int u\\, dv = uv - \\int v\\, du\\] "},{"location":"chapter%2003%3A%20calculus/02.%20integral%20calculus/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u4f7f\u7528\u9ece\u66fc\u548c\uff0c\u7528\u4e0d\u65ad\u589e\u52a0\u6570\u91cf\u7684\u77e9\u5f62\u6765\u6570\u503c\u903c\u8fd1 \\(\\int_0^1 x^2\\, dx\\)\u3002\u4e0e\u7cbe\u786e\u7b54\u6848 \\(\\frac{1}{3}\\) \u8fdb\u884c\u6bd4\u8f83\u3002

    import jax.numpy as jnp\n\nfor n in [10, 100, 1000, 10000]:\n    x = jnp.linspace(0, 1, n, endpoint=False)\n    dx = 1.0 / n\n    area = jnp.sum(x**2 * dx)\n    print(f\"n={n:5d}  approx: {area:.6f}  exact: {1/3:.6f}\")\n

  2. \u6570\u503c\u9a8c\u8bc1\u5fae\u79ef\u5206\u57fa\u672c\u5b9a\u7406\u3002\u5b9a\u4e49 \\(F(x) = \\int_0^x t^2\\, dt = \\frac{x^3}{3}\\)\uff0c\u5e76\u9a8c\u8bc1\u5176\u5bfc\u6570\uff08\u901a\u8fc7 jax.grad \u8ba1\u7b97\uff09\u7b49\u4e8e \\(x^2\\)\u3002

    import jax\nimport jax.numpy as jnp\n\nF = lambda x: x**3 / 3\ndF = jax.grad(F)\n\nfor x in [0.5, 1.0, 2.0, 3.0]:\n    print(f\"x={x:.1f}  F'(x)={dF(x):.4f}  x^2={x**2:.4f}\")\n

  3. \u53ef\u89c6\u5316 \\(f(x) = \\sin(x)\\) \u4ece \\(0\\) \u5230 \\(\\pi\\) \u7684\u66f2\u7ebf\u4e0b\u9762\u79ef\u3002\u4f7f\u7528 plt.fill_between \u586b\u5145\u8be5\u533a\u57df\uff0c\u5e76\u7528\u9ece\u66fc\u548c\u6570\u503c\u8ba1\u7b97\u9762\u79ef\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nx = jnp.linspace(0, jnp.pi, 500)\ny = jnp.sin(x)\n\nplt.plot(x, y, color=\"purple\", linewidth=2)\nplt.fill_between(x, y, alpha=0.2, color=\"purple\")\nplt.title(f\"Area = {jnp.sum(jnp.sin(x) * (jnp.pi / 500)):.4f}  (exact: 2.0)\")\nplt.show()\n

"},{"location":"chapter%2003%3A%20calculus/03.%20multivariate%20calculus/","title":"\u591a\u5143\u5fae\u79ef\u5206","text":"

\u591a\u5143\u5fae\u79ef\u5206\u5c06\u5bfc\u6570\u548c\u79ef\u5206\u6269\u5c55\u5230\u591a\u53d8\u91cf\u51fd\u6570\uff0c\u8fd9\u5bf9\u4e8e\u673a\u5668\u5b66\u4e60\u6a21\u578b\u62e5\u6709\u6570\u767e\u4e07\u53c2\u6570\u7684\u60c5\u5f62\u81f3\u5173\u91cd\u8981\u3002\u672c\u7ae0\u6db5\u76d6\u504f\u5bfc\u6570\u3001\u68af\u5ea6\u3001\u96c5\u53ef\u6bd4\u77e9\u9635\u3001\u6d77\u68ee\u77e9\u9635\u4ee5\u53ca\u4f7f\u53cd\u5411\u4f20\u64ad\u6210\u4e3a\u53ef\u80fd\u7684\u591a\u53d8\u91cf\u94fe\u5f0f\u6cd5\u5219\u3002

\\[\\frac{\\partial f}{\\partial x} = 2xy + 3 \\qquad \\frac{\\partial f}{\\partial y} = x^2 - 2\\]

\\[\\nabla f = \\left(\\frac{\\partial f}{\\partial x_1}, \\frac{\\partial f}{\\partial x_2}, \\ldots, \\frac{\\partial f}{\\partial x_n}\\right)\\]

\\[D_{\\mathbf{u}} f = \\nabla f \\cdot \\mathbf{u}\\] \\[ J = \\begin{bmatrix} \\frac{\\partial f_1}{\\partial x_1} & \\cdots & \\frac{\\partial f_1}{\\partial x_n} \\\\ \\vdots & \\ddots & \\vdots \\\\ \\frac{\\partial f_m}{\\partial x_1} & \\cdots & \\frac{\\partial f_m}{\\partial x_n} \\end{bmatrix} \\] \\[ H = \\begin{bmatrix} \\frac{\\partial^2 f}{\\partial x_1^2} & \\frac{\\partial^2 f}{\\partial x_1 \\partial x_2} & \\cdots \\\\ \\frac{\\partial^2 f}{\\partial x_2 \\partial x_1} & \\frac{\\partial^2 f}{\\partial x_2^2} & \\cdots \\\\ \\vdots & \\vdots & \\ddots \\end{bmatrix} \\] \\[ H = \\begin{bmatrix} 6x & 4y \\\\ 4y & 4x - 6y \\end{bmatrix} \\] \\[\\frac{dz}{dt} = \\frac{\\partial f}{\\partial x}\\frac{dx}{dt} + \\frac{\\partial f}{\\partial y}\\frac{dy}{dt}\\] \\[\\frac{dz}{dt} = (2xy + 3)(-\\sin t) + (x^2 - 2)(\\cos t)\\] "},{"location":"chapter%2003%3A%20calculus/03.%20multivariate%20calculus/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u4f7f\u7528 jax.grad \u8ba1\u7b97\u51fd\u6570 \\(f(x, y) = x^2 y + 3x - 2y\\) \u5728\u70b9 \\((1, 2)\\) \u5904\u7684\u68af\u5ea6\u3002\u7531\u4e8e \\(f\\) \u63a5\u6536\u4e24\u4e2a\u6807\u91cf\u53c2\u6570\uff0c\u8bf7\u4f7f\u7528\u5e26 argnums \u53c2\u6570\u7684 jax.grad \u5206\u522b\u5bf9\u6bcf\u4e2a\u53c2\u6570\u6c42\u5bfc\u3002

    import jax\nimport jax.numpy as jnp\n\ndef f(x, y):\n    return x**2 * y + 3*x - 2*y\n\ndf_dx = jax.grad(f, argnums=0)\ndf_dy = jax.grad(f, argnums=1)\n\nx, y = 1.0, 2.0\nprint(f\"\u2202f/\u2202x = {df_dx(x, y):.4f}  (\u671f\u671b: {2*x*y + 3:.4f})\")\nprint(f\"\u2202f/\u2202y = {df_dy(x, y):.4f}  (\u671f\u671b: {x**2 - 2:.4f})\")\n

  2. \u4f7f\u7528 jax.jacobian \u8ba1\u7b97\u5411\u91cf\u503c\u51fd\u6570\u7684\u96c5\u53ef\u6bd4\u77e9\u9635\uff0c\u5e76\u4e0e\u624b\u52a8\u8ba1\u7b97\u7ed3\u679c\u8fdb\u884c\u6bd4\u8f83\u3002

    import jax\nimport jax.numpy as jnp\n\ndef F(x):\n    return jnp.array([x[0]**2 + x[1], x[0] * x[1]**2])\n\nJ = jax.jacobian(F)\nx = jnp.array([1.0, 2.0])\nprint(f\"\u5728 (1,2) \u5904\u7684\u96c5\u53ef\u6bd4\u77e9\u9635:\\n{J(x)}\")\n# \u671f\u671b: [[2*x[0], 1], [x[1]**2, 2*x[0]*x[1]]] = [[2, 1], [4, 4]]\n

  3. \u4f7f\u7528 jax.hessian \u8ba1\u7b97 \\(f(x, y) = x^3 + 2xy^2 - y^3\\) \u7684\u6d77\u68ee\u77e9\u9635\uff0c\u5e76\u9a8c\u8bc1\u5176\u5bf9\u79f0\u6027\u3002

    import jax\nimport jax.numpy as jnp\n\ndef f(xy):\n    x, y = xy[0], xy[1]\n    return x**3 + 2*x*y**2 - y**3\n\nH = jax.hessian(f)\npoint = jnp.array([1.0, 2.0])\nhess = H(point)\nprint(f\"\u6d77\u68ee\u77e9\u9635:\\n{hess}\")\nprint(f\"\u662f\u5426\u5bf9\u79f0: {jnp.allclose(hess, hess.T)}\")\n# \u671f\u671b: [[6x, 4y], [4y, 4x-6y]] = [[6, 8], [8, -8]]\n

  4. \u4ece\u5934\u6784\u5efa\u4e00\u4e2a\u6781\u7b80\u7684\u81ea\u52a8\u5fae\u5206\u5f15\u64ce\u3002

"},{"location":"chapter%2003%3A%20calculus/04.%20function%20approximation/","title":"\u51fd\u6570\u903c\u8fd1","text":"

\u51fd\u6570\u903c\u8fd1\u7528\u8db3\u591f\u63a5\u8fd1\u539f\u51fd\u6570\u7684\u7b80\u5355\u51fd\u6570\u6765\u66ff\u4ee3\u590d\u6742\u51fd\u6570\u3002\u672c\u6587\u6db5\u76d6\u7ebf\u6027\u5316\u3001\u6cf0\u52d2\u7ea7\u6570\u3001\u591a\u9879\u5f0f\u903c\u8fd1\u3001\u5085\u91cc\u53f6\u7ea7\u6570\u4ee5\u53ca\u901a\u7528\u903c\u8fd1\u5b9a\u7406\u2014\u2014\u8fd9\u4e9b\u662f\u795e\u7ecf\u7f51\u7edc\u80fd\u591f\u5b66\u4e60\u4efb\u610f\u6620\u5c04\u7684\u7406\u8bba\u57fa\u7840\u3002

\\[L(x) = f(a) + f'(a)(x - a)\\] \\[f(x) = \\sum_{n=0}^{\\infty} \\frac{f^{(n)}(a)}{n!}(x - a)^n = f(a) + f'(a)(x-a) + \\frac{f''(a)}{2!}(x-a)^2 + \\frac{f'''(a)}{3!}(x-a)^3 + \\cdots\\]

\\[f(x) = \\sum_{n=0}^{\\infty} \\frac{f^{(n)}(0)}{n!} x^n\\] \\[e^x = 1 + x + \\frac{x^2}{2!} + \\frac{x^3}{3!} + \\cdots\\] \\[\\sin x = x - \\frac{x^3}{3!} + \\frac{x^5}{5!} - \\frac{x^7}{7!} + \\cdots\\] \\[\\cos x = 1 - \\frac{x^2}{2!} + \\frac{x^4}{4!} - \\frac{x^6}{6!} + \\cdots\\] \\[R_n(x) = \\frac{f^{(n+1)}(c)}{(n+1)!}(x-a)^{n+1}\\] \\[f(\\mathbf{x}) \\approx f(\\mathbf{a}) + \\nabla f(\\mathbf{a})^T (\\mathbf{x} - \\mathbf{a}) + \\frac{1}{2} (\\mathbf{x} - \\mathbf{a})^T H(\\mathbf{a}) (\\mathbf{x} - \\mathbf{a})\\] "},{"location":"chapter%2003%3A%20calculus/04.%20function%20approximation/#colab-jupyter-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 Jupyter Notebook\uff09","text":"
  1. \u7528\u9012\u589e\u6570\u91cf\u7684\u6cf0\u52d2\u9879\u903c\u8fd1 \\(e^x\\)\uff0c\u5e76\u53ef\u89c6\u5316\u903c\u8fd1\u6548\u679c\u5982\u4f55\u6539\u5584\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nx = jnp.linspace(-2, 3, 300)\nplt.plot(x, jnp.exp(x), \"k-\", linewidth=2, label=\"e\u02e3 (\u7cbe\u786e\u503c)\")\n\ncolors = [\"#e74c3c\", \"#3498db\", \"#27ae60\", \"#9b59b6\"]\nfor n, color in zip([1, 2, 4, 8], colors):\n    approx = sum(x**k / jnp.array(float(jnp.prod(jnp.arange(1, k+1)) if k > 0 else 1))\n                 for k in range(n+1))\n    plt.plot(x, approx, color=color, linestyle=\"--\", label=f\"{n} \u9879\")\n\nplt.ylim(-2, 15)\nplt.legend()\nplt.title(\"e\u02e3 \u7684\u6cf0\u52d2\u903c\u8fd1\")\nplt.show()\n

  2. \u8ba1\u7b97\u62c9\u683c\u6717\u65e5\u4f59\u9879\uff0c\u4ee5\u9650\u5b9a\u7528\u4e0d\u540c\u6570\u91cf\u7684\u6cf0\u52d2\u9879\u903c\u8fd1 \\(\\sin(1)\\) \u65f6\u7684\u8bef\u5dee\u3002

    import jax.numpy as jnp\n\nx = 1.0\nexact = jnp.sin(x)\n\ntaylor = 0.0\nfor n in range(8):\n    sign = (-1)**n\n    factorial = float(jnp.prod(jnp.arange(1, 2*n+2)))\n    taylor += sign * x**(2*n+1) / factorial\n    error = abs(exact - taylor)\n    bound = x**(2*n+3) / float(jnp.prod(jnp.arange(1, 2*n+4)))\n    print(f\"\u9879\u6570={n+1}  \u8fd1\u4f3c\u503c={taylor:.10f}  \u8bef\u5dee={error:.2e}  \u754c\u9650={bound:.2e}\")\n

  3. \u6bd4\u8f83\u5728 \\(x=0\\) \u9644\u8fd1 \\(\\cos(x)\\) \u7684\u7ebf\u6027\u5316\u903c\u8fd1\u4e0e\u4e8c\u6b21\u6cf0\u52d2\u903c\u8fd1\u3002\u5728\u540c\u4e00\u5f20\u56fe\u4e0a\u7ed8\u5236\u4e24\u4e2a\u903c\u8fd1\u548c\u771f\u5b9e\u51fd\u6570\uff0c\u89c2\u5bdf\u5404\u81ea\u7cbe\u786e\u7684\u8303\u56f4\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nx = jnp.linspace(-3, 3, 300)\nplt.plot(x, jnp.cos(x), \"k-\", linewidth=2, label=\"cos(x)\")\nplt.plot(x, jnp.ones_like(x), \"--\", color=\"#e74c3c\", label=\"\u7ebf\u6027: 1\")\nplt.plot(x, 1 - x**2/2, \"--\", color=\"#3498db\", label=\"\u4e8c\u6b21: 1 - x\u00b2/2\")\nplt.plot(x, 1 - x**2/2 + x**4/24, \"--\", color=\"#27ae60\", label=\"\u56db\u9636\")\nplt.ylim(-2, 2)\nplt.legend()\nplt.title(\"cos(x) \u7684\u6cf0\u52d2\u903c\u8fd1\")\nplt.show()\n

"},{"location":"chapter%2003%3A%20calculus/05.%20optimisation/","title":"\u4f18\u5316","text":"

\u4f18\u5316\u662f\u6a21\u578b\u8bad\u7ec3\u7684\u6570\u5b66\u6838\u5fc3\u2014\u2014\u5bfb\u627e\u4f7f\u635f\u5931\u51fd\u6570\u6700\u5c0f\u7684\u53c2\u6570\u3002\u672c\u6587\u6db5\u76d6\u9a7b\u70b9\u3001\u51f8\u6027\u3001\u68af\u5ea6\u4e0b\u964d\u3001\u725b\u987f\u6cd5\u3001\u5e26\u62c9\u683c\u6717\u65e5\u4e58\u6570\u7684\u7ea6\u675f\u4f18\u5316\uff0c\u4ee5\u53ca\u9a71\u52a8\u73b0\u4ee3\u6df1\u5ea6\u5b66\u4e60\u7684\u4e3b\u6d41\u4f18\u5316\u5668\uff08SGD\u3001Adam\uff09\u3002

\\[x_{n+1} = x_n - \\frac{f(x_n)}{f'(x_n)}\\]

\\[x_{n+1} = x_n - \\frac{f'(x_n)}{f''(x_n)}\\] \\[\\mathcal{L}(x, y, \\lambda) = f(x, y) - \\lambda(g(x, y) - c)\\] \\[\\frac{\\partial \\mathcal{L}}{\\partial x} = 0, \\quad \\frac{\\partial \\mathcal{L}}{\\partial y} = 0, \\quad \\frac{\\partial \\mathcal{L}}{\\partial \\lambda} = 0\\]

\\[2xy - 2\\lambda x = 0, \\quad x^2 - 2\\lambda y = 0, \\quad x^2 + y^2 = 1\\] "},{"location":"chapter%2003%3A%20calculus/05.%20optimisation/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u5728 CoLab \u6216 notebook \u4e2d\u5b8c\u6210\uff09","text":"
  1. \u5b9e\u73b0\u725b\u987f\u6cd5\u6c42 \\(\\sqrt{7}\\)\uff08\u5373 \\(f(x) = x^2 - 7\\) \u7684\u96f6\u70b9\uff09\u3002\u89c2\u5bdf\u5176\u5feb\u901f\u6536\u655b\u3002

    import jax.numpy as jnp\n\nf = lambda x: x**2 - 7\ndf = lambda x: 2*x\n\nx = 3.0  # \u521d\u59cb\u731c\u6d4b\nfor i in range(6):\n    x = x - f(x) / df(x)\n    print(f\"step {i+1}: x = {x:.10f}  (error: {abs(x - jnp.sqrt(7.0)):.2e})\")\n

  2. \u4f7f\u7528\u68af\u5ea6\u4e0b\u964d\u6700\u5c0f\u5316 \\(f(x, y) = (x - 3)^2 + (y + 1)^2\\)\u3002\u6700\u5c0f\u503c\u5728 \\((3, -1)\\) \u5904\u3002\u5c1d\u8bd5\u4e0d\u540c\u7684\u5b66\u4e60\u7387\u3002

    import jax\nimport jax.numpy as jnp\n\ndef f(params):\n    x, y = params\n    return (x - 3)**2 + (y + 1)**2\n\ngrad_f = jax.grad(f)\nparams = jnp.array([0.0, 0.0])\nlr = 0.1\n\nfor i in range(20):\n    g = grad_f(params)\n    params = params - lr * g\n    if i % 5 == 0 or i == 19:\n        print(f\"step {i:2d}: ({params[0]:.4f}, {params[1]:.4f})  loss={f(params):.6f}\")\n

  3. \u6570\u503c\u6c42\u89e3\u7ea6\u675f\u4f18\u5316\u95ee\u9898\u3002\u5728 \\(x + y = 10\\) \u7684\u7ea6\u675f\u4e0b\u6700\u5927\u5316 \\(f(x,y) = xy\\)\uff0c\u901a\u8fc7\u53c2\u6570\u5316 \\(y = 10 - x\\) \u5e76\u6c42\u5355\u53d8\u91cf\u51fd\u6570\u7684\u6700\u4f18\u503c\u3002

    import jax\nimport jax.numpy as jnp\n\n# \u4ee3\u5165\u7ea6\u675f\u6761\u4ef6\uff1ay = 10 - x\uff0c\u6240\u4ee5 f = x(10 - x) = 10x - x\u00b2\nf = lambda x: x * (10 - x)\ndf = jax.grad(f)\n\n# \u68af\u5ea6\u4e0a\u5347\uff08\u6211\u4eec\u8981\u6c42\u6700\u5927\u503c\uff0c\u6240\u4ee5\u52a0\u4e0a\u68af\u5ea6\uff09\nx = 1.0\nlr = 0.1\nfor i in range(20):\n    x = x + lr * df(x)\nprint(f\"x={x:.4f}, y={10-x:.4f}, f={f(x):.4f}\")  # \u5e94\u4e3a x=5, y=5, f=25\n

"},{"location":"chapter%2004%3A%20statistics/01.%20fundamentals/","title":"\u7edf\u8ba1\u5b66\u57fa\u7840","text":"

\u7edf\u8ba1\u5b66\u63d0\u4f9b\u4e86\u63cf\u8ff0\u6570\u636e\u548c\u91cf\u5316\u4e0d\u786e\u5b9a\u6027\u7684\u8bed\u8a00\u3002\u672c\u8282\u6db5\u76d6\u5206\u5e03\u3001\u968f\u673a\u53d8\u91cf\u3001PMF\u3001PDF\u3001CDF\u3001\u671f\u671b\u3001\u65b9\u5dee\u3001\u77e9\u4ee5\u53ca\u4e2d\u5fc3\u6781\u9650\u5b9a\u7406\u2014\u2014\u8fd9\u4e9b\u6982\u5ff5\u652f\u6491\u7740\u6bcf\u4e00\u4e2a\u673a\u5668\u5b66\u4e60\u8bc4\u4f30\u6307\u6807\u548c\u635f\u5931\u51fd\u6570\u3002

\\[P(X = x) = p(x), \\quad \\text{\u5176\u4e2d } \\sum_{x} p(x) = 1\\] \\[P(a \\le X \\le b) = \\int_a^b f(x)\\, dx, \\quad \\text{\u5176\u4e2d } \\int_{-\\infty}^{\\infty} f(x)\\, dx = 1\\] \\[E[X] = \\sum_{x} x \\cdot p(x)\\] \\[E[X] = \\int_{-\\infty}^{\\infty} x \\cdot f(x)\\, dx\\] \\[E[X] = 1 \\cdot \\tfrac{1}{6} + 2 \\cdot \\tfrac{1}{6} + 3 \\cdot \\tfrac{1}{6} + 4 \\cdot \\tfrac{1}{6} + 5 \\cdot \\tfrac{1}{6} + 6 \\cdot \\tfrac{1}{6} = \\frac{21}{6} = 3.5\\] \\[\\mu_k' = E[X^k]\\] \\[\\mu_k = E[(X - \\mu)^k]\\] \\[\\tilde{\\mu}_k = \\frac{\\mu_k}{\\sigma^k}\\]

\\[\\mu = \\frac{2 + 4 + 4 + 4 + 5 + 5 + 7 + 9}{8} = \\frac{40}{8} = 5\\] \\[\\sigma^2 = \\frac{(2{-}5)^2 + (4{-}5)^2 + (4{-}5)^2 + (4{-}5)^2 + (5{-}5)^2 + (5{-}5)^2 + (7{-}5)^2 + (9{-}5)^2}{8}\\] \\[= \\frac{9 + 1 + 1 + 1 + 0 + 0 + 4 + 16}{8} = \\frac{32}{8} = 4\\] \\[\\tilde{\\mu}_3 = \\frac{1}{8} \\cdot \\frac{(-3)^3 + (-1)^3 + (-1)^3 + (-1)^3 + 0^3 + 0^3 + 2^3 + 4^3}{2^3}\\] \\[= \\frac{1}{8} \\cdot \\frac{-27 -1 -1 -1 + 0 + 0 + 8 + 64}{8} = \\frac{42}{64} = 0.656\\] \\[\\tilde{\\mu}_4 = \\frac{1}{8} \\cdot \\frac{(-3)^4 + (-1)^4 + (-1)^4 + (-1)^4 + 0^4 + 0^4 + 2^4 + 4^4}{2^4}\\] \\[= \\frac{1}{8} \\cdot \\frac{81 + 1 + 1 + 1 + 0 + 0 + 16 + 256}{16} = \\frac{356}{128} = 2.781\\] "},{"location":"chapter%2004%3A%20statistics/01.%20fundamentals/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u8ba1\u7b97\u4e00\u4e2a\u4e0d\u516c\u5e73\u9ab0\u5b50\u7684\u671f\u671b\u503c\uff0c\u5176\u4e2d\u9762 6 \u7684\u6982\u7387\u4e3a 0.3\uff0c\u5176\u4f59\u9762\u5747\u5206\u5269\u4f59\u6982\u7387\u3002\u901a\u8fc7\u6a21\u62df 100,000 \u6b21\u6295\u63b7\u8fdb\u884c\u9a8c\u8bc1\u3002

    import jax\nimport jax.numpy as jnp\n\n# \u52a0\u8f7d\u9ab0\u5b50\uff1a\u9762 6 \u7684 p=0.3\uff0c\u5176\u4f59\u9762\u5747\u5206 0.7\nprobs = jnp.array([0.14, 0.14, 0.14, 0.14, 0.14, 0.30])\nfaces = jnp.array([1, 2, 3, 4, 5, 6])\n\n# \u89e3\u6790\u6cd5\u8ba1\u7b97\u671f\u671b\u503c\nev = jnp.sum(faces * probs)\nprint(f\"\u671f\u671b\u503c\uff08\u516c\u5f0f\u6cd5\uff09: {ev:.4f}\")\n\n# \u6a21\u62df\nkey = jax.random.PRNGKey(42)\nrolls = jax.random.choice(key, faces, shape=(100_000,), p=probs)\nprint(f\"\u671f\u671b\u503c\uff08\u6a21\u62df\u6cd5\uff09: {rolls.mean():.4f}\")\n

  2. \u8ba1\u7b97\u793a\u4f8b\u6570\u636e\u96c6\u7684\u6240\u6709\u56db\u4e2a\u77e9\uff08\u5747\u503c\u3001\u65b9\u5dee\u3001\u504f\u5ea6\u3001\u5cf0\u5ea6\uff09\uff0c\u7136\u540e\u4fee\u6539\u6570\u636e\u5e76\u89c2\u5bdf\u6bcf\u4e2a\u77e9\u5982\u4f55\u53d8\u5316\u3002

    import jax.numpy as jnp\n\nx = jnp.array([2, 4, 4, 4, 5, 5, 7, 9], dtype=jnp.float32)\n\nmean = jnp.mean(x)\nvariance = jnp.mean((x - mean) ** 2)\nstd = jnp.sqrt(variance)\nskewness = jnp.mean(((x - mean) / std) ** 3)\nkurtosis = jnp.mean(((x - mean) / std) ** 4)\n\nprint(f\"\u5747\u503c:     {mean:.3f}\")\nprint(f\"\u65b9\u5dee:     {variance:.3f}\")\nprint(f\"\u6807\u51c6\u5dee:   {std:.3f}\")\nprint(f\"\u504f\u5ea6:     {skewness:.3f}\")\nprint(f\"\u5cf0\u5ea6:     {kurtosis:.3f}\")\nprint(f\"\u8d85\u503c\u5cf0\u5ea6: {kurtosis - 3:.3f}\")\n

  3. \u5e76\u6392\u53ef\u89c6\u5316\u516c\u5e73\u9ab0\u5b50\u7684 PMF \u548c CDF\u3002\u5c1d\u8bd5\u4fee\u6539\u6982\u7387\u4ee5\u89c2\u5bdf\u5f62\u72b6\u5982\u4f55\u53d8\u5316\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nfaces = jnp.array([1, 2, 3, 4, 5, 6])\npmf = jnp.ones(6) / 6  # \u516c\u5e73\u9ab0\u5b50\uff1b\u8bd5\u8bd5\u4fee\u6539\u8fd9\u4e9b\u503c\uff01\ncdf = jnp.cumsum(pmf)\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))\n\nax1.bar(faces, pmf, color=\"#3498db\", alpha=0.8)\nax1.set_title(\"PMF\")\nax1.set_xlabel(\"\u9762\u503c\")\nax1.set_ylabel(\"P(X = x)\")\nax1.set_ylim(0, 0.5)\n\nax2.step(faces, cdf, where=\"mid\", color=\"#e74c3c\", linewidth=2)\nax2.set_title(\"CDF\")\nax2.set_xlabel(\"\u9762\u503c\")\nax2.set_ylabel(\"P(X \u2264 x)\")\nax2.set_ylim(0, 1.1)\n\nplt.tight_layout()\nplt.show()\n

"},{"location":"chapter%2004%3A%20statistics/02.%20measures/","title":"\u7edf\u8ba1\u91cf","text":"

\u7edf\u8ba1\u91cf\u7528\u5355\u4e2a\u6570\u503c\u6982\u62ec\u6570\u636e\uff0c\u6355\u6349\u5176\u79bb\u6563\u7a0b\u5ea6\u3001\u4f4d\u7f6e\u3001\u5f62\u72b6\u548c\u5173\u8054\u3002\u672c\u8282\u6db5\u76d6\u65b9\u5dee\u3001\u6807\u51c6\u5dee\u3001\u56db\u5206\u4f4d\u6570\u3001\u504f\u5ea6\u3001\u5cf0\u5ea6\u3001\u534f\u65b9\u5dee\u3001\u76f8\u5173\u548c z \u5206\u6570\u2014\u2014\u8fd9\u662f\u63a2\u7d22\u6027\u6570\u636e\u5206\u6790\u548c\u673a\u5668\u5b66\u4e60\u7279\u5f81\u5de5\u7a0b\u7684\u57fa\u7840\u5de5\u5177\u96c6\u3002

\\[\\sigma^2 = \\frac{1}{N} \\sum_{i=1}^{N} (x_i - \\mu)^2\\] \\[s^2 = \\frac{1}{N-1} \\sum_{i=1}^{N} (x_i - \\bar{x})^2\\] \\[\\text{MAD} = \\frac{1}{N} \\sum_{i=1}^{N} |x_i - \\mu|\\]

\\[z = \\frac{x - \\mu}{\\sigma}\\] \\[\\text{\u504f\u5ea6} = \\frac{1}{N} \\sum_{i=1}^{N} \\left(\\frac{x_i - \\mu}{\\sigma}\\right)^3\\] \\[\\text{\u5cf0\u5ea6} = \\frac{1}{N} \\sum_{i=1}^{N} \\left(\\frac{x_i - \\mu}{\\sigma}\\right)^4\\]

\\[r = \\frac{\\sum_{i=1}^{N} (x_i - \\bar{x})(y_i - \\bar{y})}{\\sqrt{\\sum (x_i - \\bar{x})^2} \\cdot \\sqrt{\\sum (y_i - \\bar{y})^2}}\\] \\[\\bar{x}_{\\text{geo}} = \\left(\\prod_{i=1}^{N} x_i\\right)^{1/N}\\] \\[\\text{EMA}_t = \\alpha \\cdot x_t + (1 - \\alpha) \\cdot \\text{EMA}_{t-1}\\] "},{"location":"chapter%2004%3A%20statistics/02.%20measures/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u8ba1\u7b97\u6570\u636e\u96c6\u7684\u65b9\u5dee\u3001\u6807\u51c6\u5dee\u548c MAD\uff0c\u5e76\u8fdb\u884c\u6bd4\u8f83\u3002\u89c2\u5bdf\u6dfb\u52a0\u6781\u7aef\u5f02\u5e38\u503c\u65f6\u53d1\u751f\u7684\u53d8\u5316\u3002

    import jax.numpy as jnp\n\ndata = jnp.array([4, 8, 6, 5, 3, 7, 9, 5, 6, 7], dtype=jnp.float32)\n\nmean = jnp.mean(data)\nvariance = jnp.var(data)\nstd = jnp.std(data)\nmad = jnp.mean(jnp.abs(data - mean))\n\nprint(\"\u539f\u59cb\u6570\u636e\uff1a\")\nprint(f\"  \u65b9\u5dee\uff1a{variance:.3f}\uff0c\u6807\u51c6\u5dee\uff1a{std:.3f}\uff0cMAD\uff1a{mad:.3f}\")\n\n# \u6dfb\u52a0\u4e00\u4e2a\u5f02\u5e38\u503c\u5e76\u91cd\u65b0\u8ba1\u7b97\ndata_outlier = jnp.append(data, 100.0)\nmean2 = jnp.mean(data_outlier)\nprint(f\"\\n\u6dfb\u52a0\u5f02\u5e38\u503c\uff08100\uff09\u540e\uff1a\")\nprint(f\"  \u65b9\u5dee\uff1a{jnp.var(data_outlier):.3f}\uff0c\u6807\u51c6\u5dee\uff1a{jnp.std(data_outlier):.3f}\uff0cMAD\uff1a{jnp.mean(jnp.abs(data_outlier - mean2)):.3f}\")\n

  2. \u8ba1\u7b97\u4e24\u4e2a\u53d8\u91cf\u4e4b\u95f4\u7684\u76ae\u5c14\u68ee\u76f8\u5173\u548c\u65af\u76ae\u5c14\u66fc\u76f8\u5173\u3002\u5c1d\u8bd5\u4e0d\u540c\u7684\u5173\u7cfb\u3002

    import jax\nimport jax.numpy as jnp\n\n# \u5b8c\u5168\u7ebf\u6027\u5173\u7cfb\nx = jnp.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=jnp.float32)\ny = 2 * x + 1  # \u8bd5\u8bd5\u4fee\u6539\u8fd9\u4e2a\uff01\n\ndef pearson(a, b):\n    a_c = a - jnp.mean(a)\n    b_c = b - jnp.mean(b)\n    return jnp.sum(a_c * b_c) / (jnp.sqrt(jnp.sum(a_c**2)) * jnp.sqrt(jnp.sum(b_c**2)))\n\ndef spearman(a, b):\n    rank_a = jnp.argsort(jnp.argsort(a)).astype(jnp.float32)\n    rank_b = jnp.argsort(jnp.argsort(b)).astype(jnp.float32)\n    return pearson(rank_a, rank_b)\n\nprint(f\"\u76ae\u5c14\u68ee r\uff1a  {pearson(x, y):.4f}\")\nprint(f\"\u65af\u76ae\u5c14\u66fc \u03c1\uff1a{spearman(x, y):.4f}\")\n

  3. \u5206\u522b\u4f7f\u7528 IQR \u548c Z \u5206\u6570\u65b9\u6cd5\u5b9e\u73b0\u5f02\u5e38\u503c\u68c0\u6d4b\uff0c\u7136\u540e\u6bd4\u8f83\u5b83\u4eec\u5728\u504f\u659c\u6570\u636e\u4e0a\u7684\u7ed3\u679c\u3002

    import jax.numpy as jnp\n\ndata = jnp.array([2, 3, 3, 4, 5, 5, 5, 6, 6, 7, 50], dtype=jnp.float32)\n\n# IQR \u65b9\u6cd5\nq1, q3 = jnp.percentile(data, 25), jnp.percentile(data, 75)\niqr = q3 - q1\nlower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr\niqr_outliers = data[(data < lower) | (data > upper)]\nprint(f\"IQR \u8fb9\u754c\uff1a[{lower:.1f}, {upper:.1f}]\")\nprint(f\"IQR \u5f02\u5e38\u503c\uff1a{iqr_outliers}\")\n\n# Z \u5206\u6570\u65b9\u6cd5\nz_scores = (data - jnp.mean(data)) / jnp.std(data)\nz_outliers = data[jnp.abs(z_scores) > 3]\nprint(f\"\\nZ \u5206\u6570\uff1a{z_scores}\")\nprint(f\"Z \u5206\u6570\u5f02\u5e38\u503c\uff08|z| > 3\uff09\uff1a{z_outliers}\")\n

  4. \u5728\u4e0d\u540c\u5e73\u6ed1\u56e0\u5b50\u4e0b\u8ba1\u7b97\u5e76\u7ed8\u5236\u5e26\u566a\u58f0\u6570\u636e\u7684\u6307\u6570\u79fb\u52a8\u5e73\u5747\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u751f\u6210\u5e26\u566a\u58f0\u7684\u6570\u636e\nkey = __import__(\"jax\").random.PRNGKey(0)\nnoise = __import__(\"jax\").random.normal(key, shape=(50,))\nsignal = jnp.linspace(0, 5, 50) + noise\n\ndef ema(data, alpha):\n    result = jnp.zeros_like(data)\n    result = result.at[0].set(data[0])\n    for t in range(1, len(data)):\n        result = result.at[t].set(alpha * data[t] + (1 - alpha) * result[t - 1])\n    return result\n\nplt.figure(figsize=(10, 4))\nplt.plot(signal, \"o\", alpha=0.3, label=\"\u539f\u59cb\u6570\u636e\", color=\"#999\")\nfor alpha, color in [(0.1, \"#e74c3c\"), (0.3, \"#3498db\"), (0.7, \"#27ae60\")]:\n    plt.plot(ema(signal, alpha), label=f\"\u03b1={alpha}\", color=color, linewidth=2)\nplt.legend()\nplt.title(\"\u4e0d\u540c\u5e73\u6ed1\u56e0\u5b50\u4e0b\u7684 EMA\")\nplt.show()\n

"},{"location":"chapter%2004%3A%20statistics/03.%20sampling/","title":"\u62bd\u6837","text":"

\u62bd\u6837\u51b3\u5b9a\u4e86\u6211\u4eec\u5982\u4f55\u6536\u96c6\u6570\u636e\uff0c\u5e76\u76f4\u63a5\u63a7\u5236\u7740\u6211\u4eec\u6240\u505a\u6bcf\u9879\u7ed3\u8bba\u7684\u8d28\u91cf\u3002\u672c\u6587\u6db5\u76d6\u968f\u673a\u62bd\u6837\u3001\u5206\u5c42\u62bd\u6837\u3001\u6574\u7fa4\u62bd\u6837\u4e0e\u7cfb\u7edf\u62bd\u6837\u3001\u62bd\u6837\u5206\u5e03\u3001\u5927\u6570\u5b9a\u5f8b\u4ee5\u53ca\u81ea\u52a9\u6cd5\u2014\u2014\u8fd9\u4e9b\u65b9\u6cd5\u5bf9\u4e8e\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u8bad\u7ec3/\u6d4b\u8bd5\u5212\u5206\u548c\u6570\u636e\u96c6\u6574\u7406\u81f3\u5173\u91cd\u8981\u3002

\\[SE = \\frac{\\sigma}{\\sqrt{n}}\\]

\\[\\bar{X} \\approx \\text{Normal}\\!\\left(\\mu, \\frac{\\sigma^2}{n}\\right)\\] "},{"location":"chapter%2004%3A%20statistics/03.%20sampling/#colab-notebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u53ef\u89c6\u5316\u6f14\u793a CLT\uff1a\u4ece\u9ad8\u5ea6\u504f\u6001\u7684\u5206\u5e03\u4e2d\u62bd\u53d6\u6837\u672c\uff0c\u8ba1\u7b97\u6837\u672c\u5747\u503c\uff0c\u89c2\u5bdf\u5747\u503c\u76f4\u65b9\u56fe\u5982\u4f55\u53d8\u6210\u949f\u5f62\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nkey = jax.random.PRNGKey(0)\n\n# \u6307\u6570\u5206\u5e03\uff08\u9ad8\u5ea6\u504f\u6001\uff09\npopulation = jax.random.exponential(key, shape=(100_000,))\n\nfig, axes = plt.subplots(1, 4, figsize=(14, 3))\nsample_sizes = [1, 5, 30, 100]\n\nfor ax, n in zip(axes, sample_sizes):\n    keys = jax.random.split(key, 2000)\n    means = jnp.array([jax.random.choice(k, population, shape=(n,)).mean() for k in keys])\n    ax.hist(means, bins=40, color=\"#3498db\", alpha=0.7, density=True)\n    ax.set_title(f\"n = {n}\")\n    ax.set_xlim(0, 4)\n\nfig.suptitle(\"CLT\uff1a\u968f\u7740 n \u589e\u5927\uff0c\u6837\u672c\u5747\u503c\u8d8b\u8fd1\u6b63\u6001\u5206\u5e03\", fontsize=13)\nplt.tight_layout()\nplt.show()\n

  2. \u6bd4\u8f83\u7b80\u5355\u968f\u673a\u62bd\u6837\u4e0e\u5206\u5c42\u62bd\u6837\u3002\u521b\u5efa\u4e00\u4e2a\u5177\u6709\u4e0d\u540c\u5206\u7ec4\u7684\u603b\u4f53\uff0c\u5c55\u793a\u5206\u5c42\u62bd\u6837\u80fd\u7ed9\u51fa\u66f4\u4f4e\u7684\u4f30\u8ba1\u65b9\u5dee\u3002

    import jax\nimport jax.numpy as jnp\n\nkey = jax.random.PRNGKey(42)\n\n# \u603b\u4f53\uff1a\u4e24\u4e2a\u4e0d\u540c\u7684\u7ec4\ngroup_a = jax.random.normal(key, shape=(500,)) + 10   # \u5747\u503c ~10\nkey, subkey = jax.random.split(key)\ngroup_b = jax.random.normal(subkey, shape=(500,)) + 20  # \u5747\u503c ~20\npopulation = jnp.concatenate([group_a, group_b])\n\n# \u7b80\u5355\u968f\u673a\u62bd\u6837\uff1a1000 \u6b21\u8bd5\u9a8c\uff0c\u6837\u672c\u91cf 20\nsrs_means = []\nfor i in range(1000):\n    key, subkey = jax.random.split(key)\n    sample = jax.random.choice(subkey, population, shape=(20,), replace=False)\n    srs_means.append(sample.mean())\nsrs_means = jnp.array(srs_means)\n\n# \u5206\u5c42\u62bd\u6837\uff1a\u6bcf\u7ec4\u5404\u53d6 10 \u4e2a\nstrat_means = []\nfor i in range(1000):\n    key, k1, k2 = jax.random.split(key, 3)\n    s_a = jax.random.choice(k1, group_a, shape=(10,), replace=False)\n    s_b = jax.random.choice(k2, group_b, shape=(10,), replace=False)\n    strat_means.append(jnp.concatenate([s_a, s_b]).mean())\nstrat_means = jnp.array(strat_means)\n\nprint(f\"\u7b80\u5355\u968f\u673a - \u5747\u503c: {srs_means.mean():.3f}, \u6807\u51c6\u5dee: {srs_means.std():.3f}\")\nprint(f\"\u5206\u5c42\u62bd\u6837 - \u5747\u503c: {strat_means.mean():.3f}, \u6807\u51c6\u5dee: {strat_means.std():.3f}\")\nprint(f\"\u5206\u5c42\u62bd\u6837\u964d\u4f4e\u4e86\u65b9\u5dee {(1 - strat_means.var()/srs_means.var())*100:.1f}%\")\n

  3. \u63a2\u7d22\u6837\u672c\u91cf\u5982\u4f55\u5f71\u54cd\u6807\u51c6\u8bef\u3002\u7ed8\u5236\u6807\u51c6\u8bef\u968f\u6837\u672c\u91cf\u53d8\u5316\u7684\u66f2\u7ebf\uff0c\u9a8c\u8bc1 \\(1/\\sqrt{n}\\) \u7684\u5173\u7cfb\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nkey = jax.random.PRNGKey(7)\npopulation = jax.random.normal(key, shape=(50_000,)) * 10 + 50\n\nsample_sizes = [5, 10, 20, 50, 100, 200, 500, 1000]\nstd_errors = []\n\nfor n in sample_sizes:\n    means = []\n    for _ in range(500):\n        key, subkey = jax.random.split(key)\n        sample = jax.random.choice(subkey, population, shape=(n,))\n        means.append(sample.mean())\n    std_errors.append(jnp.array(means).std())\n\nplt.figure(figsize=(8, 4))\nplt.plot(sample_sizes, std_errors, \"o-\", color=\"#e74c3c\", label=\"\u89c2\u6d4b\u5230\u7684 SE\")\ntheoretical = population.std() / jnp.sqrt(jnp.array(sample_sizes, dtype=jnp.float32))\nplt.plot(sample_sizes, theoretical, \"--\", color=\"#3498db\", label=\"\u03c3/\u221an\uff08\u7406\u8bba\u503c\uff09\")\nplt.xlabel(\"\u6837\u672c\u91cf (n)\")\nplt.ylabel(\"\u6807\u51c6\u8bef\")\nplt.legend()\nplt.title(\"\u6807\u51c6\u8bef\u968f\u6837\u672c\u91cf\u589e\u5927\u800c\u7f29\u5c0f\")\nplt.show()\n

"},{"location":"chapter%2004%3A%20statistics/04.%20hypothesis%20testing/","title":"\u5047\u8bbe\u68c0\u9a8c","text":"

\u5047\u8bbe\u68c0\u9a8c\u63d0\u4f9b\u4e86\u4e00\u4e2a\u4e25\u8c28\u7684\u6846\u67b6\uff0c\u7528\u4e8e\u5224\u65ad\u89c2\u6d4b\u5230\u7684\u6548\u5e94\u662f\u771f\u5b9e\u5b58\u5728\u7684\u8fd8\u662f\u7531\u968f\u673a\u56e0\u7d20\u9020\u6210\u7684\u3002\u672c\u6587\u6db5\u76d6\u539f\u5047\u8bbe\u4e0e\u5907\u62e9\u5047\u8bbe\u3001p\u503c\u3001\u663e\u8457\u6027\u6c34\u5e73\u3001t\u68c0\u9a8c\u3001\u5361\u65b9\u68c0\u9a8c\u3001\u65b9\u5dee\u5206\u6790\u4ee5\u53ca\u7b2c\u4e00\u7c7b/\u7b2c\u4e8c\u7c7b\u9519\u8bef\u2014\u2014\u8fd9\u4e9b\u903b\u8f91\u540c\u6837\u5e94\u7528\u4e8eA/B\u6d4b\u8bd5\u3001\u6a21\u578b\u6bd4\u8f83\u548c\u7814\u7a76\u4e2d\u3002

\\[z = \\frac{\\bar{x} - \\mu_0}{\\sigma / \\sqrt{n}} = \\frac{10.3 - 10}{0.9 / \\sqrt{36}} = \\frac{0.3}{0.15} = 2.0\\]

\\[z = \\frac{\\bar{x} - \\mu_0}{\\sigma / \\sqrt{n}}\\] \\[t = \\frac{\\bar{x} - \\mu_0}{s / \\sqrt{n}}\\] \\[F = \\frac{\\text{\u7ec4\u95f4\u65b9\u5dee}}{\\text{\u7ec4\u5185\u65b9\u5dee}}\\] \\[\\chi^2 = \\sum \\frac{(O_i - E_i)^2}{E_i}\\] "},{"location":"chapter%2004%3A%20statistics/04.%20hypothesis%20testing/#colabnotebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u5bf9\u6587\u4e2d\u7684\u87ba\u6813\u5de5\u5382\u793a\u4f8b\u6267\u884cz\u68c0\u9a8c\u3002\u8ba1\u7b97\u68c0\u9a8c\u7edf\u8ba1\u91cf\u3001p\u503c\u5e76\u505a\u51fa\u51b3\u7b56\u3002

    import jax.numpy as jnp\n\nx_bar = 10.3    # \u6837\u672c\u5747\u503c\nmu_0 = 10.0     # \u539f\u5047\u8bbe\u503c\nsigma = 0.9     # \u5df2\u77e5\u603b\u4f53\u6807\u51c6\u5dee\nn = 36           # \u6837\u672c\u91cf\nalpha = 0.05\n\n# \u68c0\u9a8c\u7edf\u8ba1\u91cf\nz = (x_bar - mu_0) / (sigma / jnp.sqrt(n))\nprint(f\"z = {z:.4f}\")\n\n# p\u503c\uff08\u53cc\u4fa7\u68c0\u9a8c\uff09\u4f7f\u7528\u6b63\u6001CDF\u8fd1\u4f3c\n# \u5bf9\u4e8e |z| = 2.0\uff0cp \u2248 0.0456\nfrom jax.scipy.stats import norm\np_value = 2 * (1 - norm.cdf(jnp.abs(z)))\nprint(f\"p\u503c = {p_value:.4f}\")\nprint(f\"\u62d2\u7eddH\u2080\uff1f{p_value <= alpha}\")\n

  2. \u6a21\u62df\u7b2c\u4e00\u7c7b\u9519\u8bef\uff1a\u5f53 \\(H_0\\) \u4e3a\u771f\u65f6\uff0c\u6211\u4eec\u72af\u9519\u8bef\u7684\u9891\u7387\u6709\u591a\u9ad8\uff1f\u8fd0\u884c10,000\u6b21\u5b9e\u9a8c\uff0c\u68c0\u9a8c\u62d2\u7edd\u7387\u662f\u5426\u4e0e \\(\\alpha\\) \u4e00\u81f4\u3002

    import jax\nimport jax.numpy as jnp\n\nkey = jax.random.PRNGKey(0)\nmu_0 = 50.0\nsigma = 10.0\nn = 30\nalpha = 0.05\nn_experiments = 10_000\n\nrejections = 0\nfor i in range(n_experiments):\n    key, subkey = jax.random.split(key)\n    sample = mu_0 + sigma * jax.random.normal(subkey, shape=(n,))\n    z = (sample.mean() - mu_0) / (sigma / jnp.sqrt(n))\n    p_value = 2 * (1 - __import__(\"jax\").scipy.stats.norm.cdf(jnp.abs(z)))\n    if p_value <= alpha:\n        rejections += 1\n\nprint(f\"\u62d2\u7edd\u7387\uff1a{rejections/n_experiments:.4f}\")\nprint(f\"\u671f\u671b\u503c\uff08\u03b1\uff09\uff1a  {alpha}\")\n

  3. \u5bf9\u4e24\u7ec4\u6570\u636e\u5206\u522b\u8fd0\u884ct\u68c0\u9a8c\u548cMann-Whitney U\u68c0\u9a8c\u3002\u751f\u6210\u4e00\u7ec4\u5747\u503c\u7565\u9ad8\u4e8e\u53e6\u4e00\u7ec4\u7684\u6570\u636e\uff0c\u89c2\u5bdf\u54ea\u79cd\u68c0\u9a8c\u80fd\u68c0\u6d4b\u51fa\u5dee\u5f02\u3002

    import jax\nimport jax.numpy as jnp\n\nkey = jax.random.PRNGKey(99)\nk1, k2 = jax.random.split(key)\n\ngroup_a = jax.random.normal(k1, shape=(25,)) * 5 + 100\ngroup_b = jax.random.normal(k2, shape=(25,)) * 5 + 103  # \u5747\u503c\u7565\u9ad8\n\n# \u53cc\u6837\u672ct\u68c0\u9a8c\uff08\u5047\u8bbe\u65b9\u5dee\u76f8\u7b49\uff09\nn_a, n_b = len(group_a), len(group_b)\nmean_a, mean_b = group_a.mean(), group_b.mean()\npooled_var = ((n_a - 1) * group_a.var() + (n_b - 1) * group_b.var()) / (n_a + n_b - 2)\nse = jnp.sqrt(pooled_var * (1/n_a + 1/n_b))\nt_stat = (mean_a - mean_b) / se\nprint(f\"t\u68c0\u9a8c\u7edf\u8ba1\u91cf\uff1a{t_stat:.4f}\")\n\n# Mann-Whitney\uff1a\u7edf\u8ba1group_a\u7684\u503c\u5c0f\u4e8egroup_b\u503c\u7684\u6b21\u6570\nu_stat = jnp.sum(group_a[:, None] < group_b[None, :])\nprint(f\"Mann-Whitney U\uff1a  {u_stat}\")\nprint(f\"\\nA\u7ec4\u5747\u503c\uff1a{mean_a:.2f}\uff0cB\u7ec4\u5747\u503c\uff1a{mean_b:.2f}\")\n

"},{"location":"chapter%2004%3A%20statistics/05.%20inference/","title":"\u7edf\u8ba1\u63a8\u65ad","text":"

\u7edf\u8ba1\u63a8\u65ad\u8d85\u8d8a\u4e86\u7b80\u5355\u7684\"\u662f/\u5426\"\u51b3\u7b56\uff0c\u4ee5\u91cf\u5316\u7684\u4e0d\u786e\u5b9a\u6027\u6765\u4f30\u8ba1\u603b\u4f53\u53c2\u6570\u3002\u672c\u8282\u6db5\u76d6\u7f6e\u4fe1\u533a\u95f4\u3001\u70b9\u4f30\u8ba1\u4e0e\u533a\u95f4\u4f30\u8ba1\u3001\u6781\u5927\u4f3c\u7136\u4f30\u8ba1\u3001\u77e9\u6cd5\u4ee5\u53ca\u56de\u5f52\u5206\u6790\u2014\u2014\u8fd9\u662f\u8fde\u63a5\u539f\u59cb\u6570\u636e\u4e0e\u673a\u5668\u5b66\u4e60\u9884\u6d4b\u6a21\u578b\u7684\u6865\u6881\u3002

\\[\\text{CI} = \\bar{x} \\pm \\text{ME}\\] \\[\\text{ME} = z^\\ast \\cdot \\frac{\\sigma}{\\sqrt{n}}\\]

\\[\\text{ME} = 1.96 \\cdot \\frac{8}{\\sqrt{50}} = 1.96 \\cdot 1.131 = 2.22 \\text{ cm}\\] \\[\\text{CI} = [170 - 2.22, \\; 170 + 2.22] = [167.78, \\; 172.22]\\] \\[\\text{CI} = \\bar{x} \\pm t^\\ast_{n-1} \\cdot \\frac{s}{\\sqrt{n}}\\] \\[n = \\left(\\frac{(z_{\\alpha/2} + z_{\\beta}) \\cdot \\sigma}{\\delta}\\right)^2\\] \\[n = \\left(\\frac{(1.96 + 0.84) \\cdot 8}{2}\\right)^2 = \\left(\\frac{22.4}{2}\\right)^2 = 11.2^2 \\approx 126\\]

\\[\\pi \\approx 4 \\times \\frac{\\text{\u5706\u5185\u70b9\u6570}}{\\text{\u603b\u70b9\u6570}}\\] \\[x_i = \\lambda_{i1} f_1 + \\lambda_{i2} f_2 + \\ldots + \\lambda_{ik} f_k + \\epsilon_i\\] "},{"location":"chapter%2004%3A%20statistics/05.%20inference/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u5728 CoLab \u6216 notebook \u4e2d\u5b8c\u6210\uff09","text":"
  1. \u4e3a\u8eab\u9ad8\u793a\u4f8b\u6784\u5efa\u4e00\u4e2a 95% \u7f6e\u4fe1\u533a\u95f4\uff0c\u7136\u540e\u5c1d\u8bd5\u4e0d\u540c\u7684\u7f6e\u4fe1\u6c34\u5e73\u548c\u6837\u672c\u91cf\u3002

    import jax.numpy as jnp\n\nx_bar = 170.0    # \u6837\u672c\u5747\u503c\nsigma = 8.0      # \u603b\u4f53\u6807\u51c6\u5dee\uff08\u5df2\u77e5\uff09\nn = 50           # \u6837\u672c\u91cf\n\n# \u5e38\u7528\u7f6e\u4fe1\u6c34\u5e73\u7684\u4e34\u754c\u503c\nz_stars = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576}\n\nfor conf, z_star in z_stars.items():\n    me = z_star * (sigma / jnp.sqrt(n))\n    lower, upper = x_bar - me, x_bar + me\n    print(f\"{conf*100:.0f}% CI: [{lower:.2f}, {upper:.2f}]  (ME = {me:.2f})\")\n

  2. \u4f7f\u7528\u8499\u7279\u5361\u6d1b\u6a21\u62df\u4f30\u7b97 \\(\\pi\\)\u3002\u7ed8\u5236\u968f\u7740\u70b9\u6570\u589e\u52a0\u4f30\u7b97\u503c\u6536\u655b\u7684\u66f2\u7ebf\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nkey = jax.random.PRNGKey(42)\n\n# \u5728 [-1, 1] x [-1, 1] \u5185\u751f\u6210\u968f\u673a\u70b9\nn_points = 100_000\nk1, k2 = jax.random.split(key)\nx = jax.random.uniform(k1, shape=(n_points,), minval=-1, maxval=1)\ny = jax.random.uniform(k2, shape=(n_points,), minval=-1, maxval=1)\n\n# \u68c0\u67e5\u54ea\u4e9b\u70b9\u5728\u5355\u4f4d\u5706\u5185\ninside = (x**2 + y**2) <= 1.0\ncumulative_inside = jnp.cumsum(inside)\ncounts = jnp.arange(1, n_points + 1)\npi_estimates = 4.0 * cumulative_inside / counts\n\nplt.figure(figsize=(10, 4))\nplt.plot(pi_estimates, color=\"#3498db\", alpha=0.7, linewidth=0.5)\nplt.axhline(y=jnp.pi, color=\"#e74c3c\", linestyle=\"--\", label=f\"\u03c0 = {jnp.pi:.6f}\")\nplt.xlabel(\"\u70b9\u6570\")\nplt.ylabel(\"\u03c0 \u7684\u4f30\u7b97\u503c\")\nplt.title(\"\u8499\u7279\u5361\u6d1b\u4f30\u7b97 \u03c0\")\nplt.legend()\nplt.ylim(2.8, 3.5)\nplt.show()\n\nprint(f\"\u6700\u7ec8\u4f30\u7b97\u503c: {pi_estimates[-1]:.6f}\")\nprint(f\"\u771f\u5b9e\u503c:     {jnp.pi:.6f}\")\nprint(f\"\u8bef\u5dee:       {abs(pi_estimates[-1] - jnp.pi):.6f}\")\n

  3. \u6267\u884c\u4e00\u4e2a\u7b80\u5355\u7684\u529f\u6548\u5206\u6790\uff1a\u7ed9\u5b9a\u6548\u5e94\u5927\u5c0f\u548c\u6807\u51c6\u5dee\uff0c\u8ba1\u7b97\u6240\u9700\u6837\u672c\u91cf\u5e76\u901a\u8fc7\u6a21\u62df\u9a8c\u8bc1\u3002

    import jax\nimport jax.numpy as jnp\n\n# \u53c2\u6570\ndelta = 2.0      # \u6548\u5e94\u5927\u5c0f\uff08\u5747\u503c\u5dee\uff09\nsigma = 8.0      # \u603b\u4f53\u6807\u51c6\u5dee\nalpha = 0.05\npower_target = 0.80\n\n# \u89e3\u6790\u8ba1\u7b97\u7684\u6837\u672c\u91cf\nz_alpha = 1.96   # \u53cc\u5c3e\uff0calpha=0.05\nz_beta = 0.84    # power=0.80\nn_required = ((z_alpha + z_beta) * sigma / delta) ** 2\nprint(f\"\u6bcf\u7ec4\u6240\u9700\u6837\u672c\u91cf: {n_required:.0f}\")\n\n# \u901a\u8fc7\u6a21\u62df\u9a8c\u8bc1\nkey = jax.random.PRNGKey(7)\nn = int(jnp.ceil(n_required))\nn_sims = 5000\nrejections = 0\n\nfor _ in range(n_sims):\n    key, k1, k2 = jax.random.split(key, 3)\n    group_a = jax.random.normal(k1, shape=(n,)) * sigma + 50\n    group_b = jax.random.normal(k2, shape=(n,)) * sigma + 50 + delta\n    pooled_se = jnp.sqrt(2 * sigma**2 / n)\n    z = (group_b.mean() - group_a.mean()) / pooled_se\n    p = 2 * (1 - __import__(\"jax\").scipy.stats.norm.cdf(jnp.abs(z)))\n    if p <= alpha:\n        rejections += 1\n\nprint(f\"\u6a21\u62df\u529f\u6548: {rejections/n_sims:.3f}\")\nprint(f\"\u76ee\u6807\u529f\u6548: {power_target:.3f}\")\n

  4. \u53ef\u89c6\u5316\u7f6e\u4fe1\u533a\u95f4\u5bbd\u5ea6\u968f\u6837\u672c\u91cf\u7684\u53d8\u5316\u3002\u8fd9\u5c55\u793a\u4e86\u4e3a\u4ec0\u4e48\u6536\u96c6\u66f4\u591a\u6570\u636e\u53ef\u4ee5\u5f97\u5230\u66f4\u7cbe\u786e\u7684\u4f30\u8ba1\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nsigma = 8.0\nz_star = 1.96  # 95% \u7f6e\u4fe1\u5ea6\n\nsample_sizes = jnp.array([10, 20, 30, 50, 100, 200, 500, 1000], dtype=jnp.float32)\nmargins = z_star * sigma / jnp.sqrt(sample_sizes)\n\nplt.figure(figsize=(8, 4))\nplt.bar([str(int(n)) for n in sample_sizes], margins, color=\"#3498db\", alpha=0.7)\nplt.xlabel(\"\u6837\u672c\u91cf\")\nplt.ylabel(\"\u8bef\u5dee\u8303\u56f4 (cm)\")\nplt.title(\"95% CI \u8bef\u5dee\u8303\u56f4\u968f\u6837\u672c\u91cf\u589e\u5927\u800c\u7f29\u5c0f\")\nplt.show()\n

"},{"location":"chapter%2005%3A%20probability/01.%20counting/","title":"\u8ba1\u6570","text":"

\u8ba1\u6570\u662f\u8ba1\u7b97\u6982\u7387\u7684\u524d\u63d0\u2014\u2014\u5728\u5206\u914d\u53ef\u80fd\u6027\u4e4b\u524d\uff0c\u4f60\u5fc5\u987b\u5148\u77e5\u9053\u6709\u591a\u5c11\u79cd\u7ed3\u679c\u3002\u672c\u6587\u6db5\u76d6\u4e58\u6cd5\u4e0e\u52a0\u6cd5\u89c4\u5219\u3001\u9636\u4e58\u3001\u6392\u5217\u3001\u7ec4\u5408\u3001\u4e8c\u9879\u5f0f\u7cfb\u6570\uff0c\u4ee5\u53ca\u652f\u6491\u673a\u5668\u5b66\u4e60\u4e2d\u91c7\u6837\u3001\u54c8\u5e0c\u548c\u6982\u7387\u5206\u6790\u7684\u57fa\u672c\u7ec4\u5408\u5de5\u5177\u3002

\\[n! = n \\times (n-1) \\times (n-2) \\times \\cdots \\times 2 \\times 1\\] \\[P(n, r) = \\frac{n!}{(n - r)!}\\] \\[C(n, r) = \\binom{n}{r} = \\frac{n!}{r!(n - r)!}\\]

\\[\\binom{10}{3} = \\frac{10!}{3! \\cdot 7!} = \\frac{10 \\times 9 \\times 8}{3 \\times 2 \\times 1} = 120\\] \\[\\binom{8}{3} = \\frac{8!}{3! \\cdot 5!} = \\frac{8 \\times 7 \\times 6}{3 \\times 2 \\times 1} = 56\\] \\[\\binom{6}{2} = \\frac{6!}{2! \\cdot 4!} = \\frac{6 \\times 5}{2 \\times 1} = 15\\] \\[56 \\times 15 = 840 \\text{ \u4e2a\u59d4\u5458\u4f1a}\\] \\[\\binom{n + r - 1}{r} = \\frac{(n + r - 1)!}{r!(n - 1)!}\\] \u573a\u666f \u516c\u5f0f \u6709\u5e8f\uff0c\u65e0\u91cd\u590d\uff08\u6392\u5217\uff09 \\(P(n,r) = \\frac{n!}{(n-r)!}\\) \u65e0\u5e8f\uff0c\u65e0\u91cd\u590d\uff08\u7ec4\u5408\uff09 \\(\\binom{n}{r} = \\frac{n!}{r!(n-r)!}\\) \u6709\u5e8f\uff0c\u53ef\u91cd\u590d \\(n^r\\) \u65e0\u5e8f\uff0c\u53ef\u91cd\u590d \\(\\binom{n+r-1}{r}\\) "},{"location":"chapter%2005%3A%20probability/01.%20counting/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u5728 CoLab \u6216 notebook \u4e2d\u5b8c\u6210\uff09","text":"
  1. \u4f7f\u7528\u9636\u4e58\u516c\u5f0f\u548c\u76f4\u63a5\u8ba1\u7b97\u4e24\u79cd\u65b9\u5f0f\u8ba1\u7b97 \\(P(10, 3)\\) \u548c \\(\\binom{10}{3}\\)\u3002\u9a8c\u8bc1\u6392\u5217\u6570\u603b\u662f\u7ec4\u5408\u6570\u7684 \\(r!\\) \u500d\u3002

    import jax.numpy as jnp\nfrom math import factorial\n\nn, r = 10, 3\n\nperm = factorial(n) // factorial(n - r)\ncomb = factorial(n) // (factorial(r) * factorial(n - r))\n\nprint(f\"P({n},{r}) = {perm}\")\nprint(f\"C({n},{r}) = {comb}\")\nprint(f\"P / C = {perm // comb} (\u5e94\u7b49\u4e8e {r}! = {factorial(r)})\")\n

  2. \u901a\u8fc7\u7a0b\u5e8f\u89e3\u51b3\u59d4\u5458\u4f1a\u95ee\u9898\uff088 \u4eba\u4e2d\u9009 3 \u540d\u7537\u6027\uff0c6 \u4eba\u4e2d\u9009 2 \u540d\u5973\u6027\uff09\uff0c\u5e76\u901a\u8fc7\u679a\u4e3e\u6240\u6709\u6709\u6548\u59d4\u5458\u4f1a\u6765\u9a8c\u8bc1\u3002

    from itertools import combinations\nfrom math import factorial\n\ndef comb_count(n, r):\n    return factorial(n) // (factorial(r) * factorial(n - r))\n\n# \u516c\u5f0f\u6cd5\nmen_ways = comb_count(8, 3)\nwomen_ways = comb_count(6, 2)\nprint(f\"\u516c\u5f0f\u6cd5: {men_ways} \u00d7 {women_ways} = {men_ways * women_ways}\")\n\n# \u679a\u4e3e\u6cd5\nmen = [f\"M{i}\" for i in range(1, 9)]\nwomen = [f\"W{i}\" for i in range(1, 7)]\ncount = sum(1 for _ in combinations(men, 3) for _ in combinations(women, 2))\nprint(f\"\u679a\u4e3e\u6cd5: {count}\")\n

  3. \u7edf\u8ba1\u7531 26 \u4e2a\u5c0f\u5199\u5b57\u6bcd\u7ec4\u6210\u7684 4 \u4f4d\u5bc6\u7801\u6709\u591a\u5c11\u79cd\uff08\u5141\u8bb8\u91cd\u590d\uff09\u3002\u7136\u540e\u7edf\u8ba1\u6ca1\u6709\u91cd\u590d\u5b57\u6bcd\u7684\u5bc6\u7801\u6709\u591a\u5c11\u79cd\u3002

    from math import factorial\n\nn = 26\nr = 4\n\nwith_rep = n ** r\nwithout_rep = factorial(n) // factorial(n - r)\n\nprint(f\"\u5141\u8bb8\u91cd\u590d:    {with_rep:>10,}\")\nprint(f\"\u4e0d\u5141\u8bb8\u91cd\u590d: {without_rep:>10,}\")\nprint(f\"\u542b\u91cd\u590d\u7684\u6bd4\u4f8b: {1 - without_rep/with_rep:.2%}\")\n

  4. \u6a21\u62df\u751f\u65e5\u95ee\u9898\uff1a\u5728 \\(k\\) \u4eba\u7684\u7fa4\u4f53\u4e2d\uff0c\u81f3\u5c11\u4e24\u4eba\u5171\u4eab\u751f\u65e5\u7684\u6982\u7387\u662f\u591a\u5c11\uff1f\u7ed8\u5236 \\(k = 1\\) \u5230 \\(60\\) \u7684\u6982\u7387\u66f2\u7ebf\uff0c\u5e76\u627e\u51fa\u6982\u7387\u8d85\u8fc7 50% \u7684\u4f4d\u7f6e\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef birthday_prob_exact(k):\n    \"\"\"k \u4eba\u7fa4\u4f53\u4e2d\u81f3\u5c11\u6709\u4e00\u5bf9\u5171\u4eab\u751f\u65e5\u7684\u6982\u7387\u3002\"\"\"\n    p_no_match = 1.0\n    for i in range(k):\n        p_no_match *= (365 - i) / 365\n    return 1 - p_no_match\n\nks = list(range(1, 61))\nprobs = [birthday_prob_exact(k) for k in ks]\n\nplt.figure(figsize=(8, 4))\nplt.plot(ks, probs, color=\"#3498db\", linewidth=2)\nplt.axhline(y=0.5, color=\"#e74c3c\", linestyle=\"--\", alpha=0.7, label=\"50%\")\ncross = next(k for k, p in zip(ks, probs) if p >= 0.5)\nplt.axvline(x=cross, color=\"#e74c3c\", linestyle=\"--\", alpha=0.7)\nplt.xlabel(\"\u7fa4\u4f53\u5927\u5c0f (k)\")\nplt.ylabel(\"P(\u81f3\u5c11\u4e24\u4eba\u5171\u4eab\u751f\u65e5)\")\nplt.title(f\"\u751f\u65e5\u95ee\u9898\uff08\u5728 k={cross} \u65f6\u8d85\u8fc7 50%\uff09\")\nplt.legend()\nplt.grid(alpha=0.3)\nplt.show()\n

"},{"location":"chapter%2005%3A%20probability/02.%20probability%20concepts/","title":"\u6982\u7387\u6982\u5ff5","text":"

\u6982\u7387\u8bba\u5f62\u5f0f\u5316\u4e86\u4e0d\u786e\u5b9a\u6027\uff0c\u5e76\u63d0\u4f9b\u4e86\u5728\u6b64\u6846\u67b6\u4e0b\u8fdb\u884c\u63a8\u7406\u7684\u89c4\u5219\u3002\u672c\u6587\u6db5\u76d6\u6837\u672c\u7a7a\u95f4\u3001\u4e8b\u4ef6\u3001\u6982\u7387\u516c\u7406\u3001\u6761\u4ef6\u6982\u7387\u3001\u72ec\u7acb\u6027\u3001\u8d1d\u53f6\u65af\u5b9a\u7406\u3001\u9891\u7387\u6d3e\u4e0e\u8d1d\u53f6\u65af\u6d3e\u89e3\u91ca\uff0c\u8fd9\u662f\u673a\u5668\u5b66\u4e60\u4e2d\u6bcf\u4e2a\u751f\u6210\u6a21\u578b\u548c\u5224\u522b\u6a21\u578b\u80cc\u540e\u7684\u6570\u5b66\u6846\u67b6\u3002

\\[P(A) = \\frac{|A|}{|S|} = \\frac{\\text{\u6709\u5229\u7ed3\u679c}}{\\text{\u603b\u7ed3\u679c}}\\]

\\[P(A') = 1 - P(A)\\] \\[P(A \\cup B) = P(A) + P(B) \\quad \\text{(\u82e5 } A \\cap B = \\emptyset\\text{)}\\] \\[P(A \\cup B) = P(A) + P(B) - P(A \\cap B)\\] \\[P(A | B) = \\frac{P(A \\cap B)}{P(B)}, \\quad P(B) > 0\\]

\\[P(A \\cap B) = P(A) \\cdot P(B)\\] \\[P(A \\cap B) = P(A | B) \\cdot P(B) = P(B | A) \\cdot P(A)\\] \\[P(A | B) = \\frac{P(B | A) \\cdot P(A)}{P(B)}\\]

\\[P(+) = P(+ | D) \\cdot P(D) + P(+ | D') \\cdot P(D') = 0.95 \\times 0.01 + 0.10 \\times 0.99 = 0.0095 + 0.099 = 0.1085\\] \\[P(D | +) = \\frac{P(+ | D) \\cdot P(D)}{P(+)} = \\frac{0.95 \\times 0.01}{0.1085} \\approx 0.088\\] \\[P(A) = \\sum_{i=1}^{n} P(A | B_i) \\cdot P(B_i)\\] \\[P(A_1 \\cap A_2 \\cap \\cdots \\cap A_n) = P(A_1) \\cdot P(A_2 | A_1) \\cdot P(A_3 | A_1 \\cap A_2) \\cdots P(A_n | A_1 \\cap \\cdots \\cap A_{n-1})\\] \\[P(A \\cap B | C) = P(A | C) \\cdot P(B | C)\\] "},{"location":"chapter%2005%3A%20probability/02.%20probability%20concepts/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u6a21\u62df\u533b\u5b66\u8bca\u65ad\u95ee\u9898\u3002\u751f\u6210 100,000 \u4eba\u7684\u603b\u4f53\uff0c\u5e94\u7528\u75be\u75c5\u60a3\u75c5\u7387\u548c\u68c0\u6d4b\u51c6\u786e\u7387\uff0c\u9a8c\u8bc1\u8d1d\u53f6\u65af\u5b9a\u7406\u7ed9\u51fa\u6b63\u786e\u7684\u540e\u9a8c\u6982\u7387\u3002

    import jax\nimport jax.numpy as jnp\n\nkey = jax.random.PRNGKey(42)\nn = 100_000\n\n# \u751f\u6210\u603b\u4f53\nk1, k2 = jax.random.split(key)\nhas_disease = jax.random.bernoulli(k1, p=0.01, shape=(n,))\n\n# \u751f\u6210\u68c0\u6d4b\u7ed3\u679c\nk3, k4 = jax.random.split(k2)\n# \u7075\u654f\u5ea6\uff1aP(+|D) = 0.95\uff0c\u7279\u5f02\u5ea6\uff1aP(-|D') = 0.90\ntest_positive = jnp.where(\n    has_disease,\n    jax.random.bernoulli(k3, p=0.95, shape=(n,)),\n    jax.random.bernoulli(k4, p=0.10, shape=(n,))\n)\n\n# \u5728\u68c0\u6d4b\u9633\u6027\u7684\u4eba\u7fa4\u4e2d\uff0c\u5b9e\u9645\u60a3\u75c5\u7684\u6bd4\u4f8b\u662f\u591a\u5c11\uff1f\npositives = test_positive.astype(bool)\ntrue_positives = (has_disease & positives).sum()\ntotal_positives = positives.sum()\n\nprint(f\"\u68c0\u6d4b\u9633\u6027\u603b\u4eba\u6570: {total_positives}\")\nprint(f\"\u771f\u9633\u6027\u4eba\u6570: {true_positives}\")\nprint(f\"P(\u60a3\u75c5 | \u9633\u6027) = {true_positives / total_positives:.4f}\")\nprint(f\"\u8d1d\u53f6\u65af\u516c\u5f0f:         {0.95 * 0.01 / 0.1085:.4f}\")\n

  2. \u901a\u8fc7\u6a21\u62df\u9a8c\u8bc1\u52a0\u6cd5\u6cd5\u5219\u3002\u751f\u6210\u5177\u6709\u5df2\u77e5\u6982\u7387\u548c\u91cd\u53e0\u7a0b\u5ea6\u7684\u968f\u673a\u4e8b\u4ef6 A \u548c B\uff0c\u7136\u540e\u9a8c\u8bc1 \\(P(A \\cup B) = P(A) + P(B) - P(A \\cap B)\\)\u3002

    import jax\nimport jax.numpy as jnp\n\nkey = jax.random.PRNGKey(0)\nn = 200_000\nk1, k2 = jax.random.split(key)\n\n# \u4e8b\u4ef6\uff1aA = \u503c < 0.4\uff0cB = \u503c < 0.6\uff08A \u4e0e B \u76f8\u4e92\u72ec\u7acb\uff0c\u91cd\u53e0\u90e8\u5206\u4e3a A \u2229 B\uff0c\u7406\u8bba\u6982\u7387 0.24\uff09\nvals_a = jax.random.uniform(k1, shape=(n,))\nvals_b = jax.random.uniform(k2, shape=(n,))\n\nA = vals_a < 0.4\nB = vals_b < 0.6\n\np_a = A.mean()\np_b = B.mean()\np_a_and_b = (A & B).mean()\np_a_or_b = (A | B).mean()\n\nprint(f\"P(A) = {p_a:.4f}\")\nprint(f\"P(B) = {p_b:.4f}\")\nprint(f\"P(A \u2229 B) = {p_a_and_b:.4f}\")\nprint(f\"P(A \u222a B) \u6a21\u62df\u503c = {p_a_or_b:.4f}\")\nprint(f\"P(A) + P(B) - P(A\u2229B) = {p_a + p_b - p_a_and_b:.4f}\")\n

  3. \u6f14\u793a\u6761\u4ef6\u6982\u7387\u968f\u8bc1\u636e\u53d8\u5316\u3002\u6a21\u62df\u63b7\u4e24\u4e2a\u9ab0\u5b50\uff0c\u8ba1\u7b97 \\(P(\\text{\u548c} = 7)\\)\uff0c\u7136\u540e\u8ba1\u7b97 \\(P(\\text{\u548c} = 7 | \\text{\u7b2c\u4e00\u4e2a\u9ab0\u5b50} = 3)\\)\u3002

    import jax\nimport jax.numpy as jnp\n\nkey = jax.random.PRNGKey(1)\nn = 500_000\nk1, k2 = jax.random.split(key)\n\nd1 = jax.random.randint(k1, shape=(n,), minval=1, maxval=7)\nd2 = jax.random.randint(k2, shape=(n,), minval=1, maxval=7)\ntotal = d1 + d2\n\n# \u65e0\u6761\u4ef6\u6982\u7387\np_sum7 = (total == 7).mean()\nprint(f\"P(\u548c=7) = {p_sum7:.4f} (\u7cbe\u786e\u503c: {6/36:.4f})\")\n\n# \u6761\u4ef6\u4e8e\u7b2c\u4e00\u4e2a\u9ab0\u5b50 = 3\nmask = d1 == 3\np_sum7_given_d1_3 = (total[mask] == 7).mean()\nprint(f\"P(\u548c=7 | d1=3) = {p_sum7_given_d1_3:.4f} (\u7cbe\u786e\u503c: {1/6:.4f})\")\n

  4. \u5c06\u8d1d\u53f6\u65af\u5b9a\u7406\u5b9e\u73b0\u4e3a\u4e00\u4e2a\u51fd\u6570\uff0c\u5e76\u7528\u5b83\u8fed\u4ee3\u66f4\u65b0\u4fe1\u5ff5\u3002\u4ece\u786c\u5e01\u504f\u5411\u7684\u5747\u5300\u5148\u9a8c\u5f00\u59cb\uff0c\u5728\u89c2\u5bdf\u5230\u6bcf\u6b21\u629b\u63b7\u540e\u66f4\u65b0\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef bayes_update(prior, likelihood):\n    \"\"\"\u5c06\u5148\u9a8c\u4e58\u4ee5\u4f3c\u7136\u5e76\u5f52\u4e00\u5316\u3002\"\"\"\n    posterior = prior * likelihood\n    return posterior / posterior.sum()\n\n# \u79bb\u6563\u5316\u53ef\u80fd\u7684\u504f\u5411\u503c\ntheta = jnp.linspace(0, 1, 200)\nprior = jnp.ones_like(theta)  # \u5747\u5300\u5148\u9a8c\nprior = prior / prior.sum()\n\n# \u89c2\u6d4b\u5230\u7684\u629b\u63b7\u7ed3\u679c\uff1a1=\u6b63\u9762\uff0c0=\u53cd\u9762\nflips = [1, 1, 0, 1, 1, 1, 0, 1, 0, 1]\n\nplt.figure(figsize=(10, 5))\nplt.plot(theta, prior, \"--\", color=\"#999\", label=\"\u5148\u9a8c\")\n\nfor i, flip in enumerate(flips):\n    likelihood = theta if flip == 1 else (1 - theta)\n    prior = bayes_update(prior, likelihood)\n    if i in [0, 2, 4, 9]:\n        plt.plot(theta, prior, label=f\"\u7ecf\u8fc7 {i+1} \u6b21\u629b\u63b7\u540e\", linewidth=2)\n\nplt.xlabel(\"\u786c\u5e01\u504f\u5411 \u03b8\")\nplt.ylabel(\"\u4fe1\u5ff5\uff08\u5f52\u4e00\u5316\uff09\")\nplt.title(\"\u8d1d\u53f6\u65af\u66f4\u65b0\uff1a\u5173\u4e8e\u786c\u5e01\u504f\u5411\u7684\u4fe1\u5ff5\")\nplt.legend()\nplt.grid(alpha=0.3)\nplt.show()\n

"},{"location":"chapter%2005%3A%20probability/03.%20distributions/","title":"\u6982\u7387\u5206\u5e03","text":"

\u6982\u7387\u5206\u5e03\u63cf\u8ff0\u4e86\u968f\u673a\u7ed3\u679c\u5982\u4f55\u5728\u53ef\u80fd\u53d6\u503c\u4e0a\u5206\u5e03\u3002\u672c\u6587\u6863\u6574\u7406\u4e86\u5173\u952e\u7684\u79bb\u6563\u548c\u8fde\u7eed\u5206\u5e03\uff1a\u4f2f\u52aa\u5229\u5206\u5e03\u3001\u4e8c\u9879\u5206\u5e03\u3001\u6cca\u677e\u5206\u5e03\u3001\u9ad8\u65af\u5206\u5e03\u3001\u6307\u6570\u5206\u5e03\u3001\u8d1d\u5854\u5206\u5e03\u7b49\uff0c\u7ed9\u51fa\u4e86\u5404\u81ea\u7684\u516c\u5f0f\u3001\u76f4\u89c2\u7406\u89e3\u53ca\u5176\u5728\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u5e94\u7528\uff08\u635f\u5931\u51fd\u6570\u3001\u5148\u9a8c\u3001\u566a\u58f0\u6a21\u578b\uff09\u3002

\\[P(X = x) = p^x (1 - p)^{1-x}, \\quad x \\in \\{0, 1\\}\\] \\[P(X = k) = \\binom{n}{k} p^k (1-p)^{n-k}, \\quad k = 0, 1, \\ldots, n\\]

\\[P(X = k) = \\frac{\\lambda^k e^{-\\lambda}}{k!}, \\quad k = 0, 1, 2, \\ldots\\] \\[P(X = k) = (1-p)^{k-1} p, \\quad k = 1, 2, 3, \\ldots\\] \\[P(X = k) = \\binom{k-1}{r-1} p^r (1-p)^{k-r}, \\quad k = r, r+1, r+2, \\ldots\\] \\[f(x) = \\frac{1}{b - a}, \\quad a \\le x \\le b\\] \\[f(x) = \\frac{1}{\\sigma\\sqrt{2\\pi}} \\exp\\!\\left(-\\frac{(x - \\mu)^2}{2\\sigma^2}\\right)\\]

\\[f(x) = \\lambda e^{-\\lambda x}, \\quad x \\ge 0\\] \\[f(x) = \\frac{\\beta^\\alpha}{\\Gamma(\\alpha)} x^{\\alpha - 1} e^{-\\beta x}, \\quad x > 0\\] \\[f(x) = \\frac{x^{\\alpha - 1}(1 - x)^{\\beta - 1}}{B(\\alpha, \\beta)}, \\quad 0 \\le x \\le 1\\]

\\[f(x) = \\frac{1}{2^{k/2}\\Gamma(k/2)} x^{k/2 - 1} e^{-x/2}, \\quad x > 0\\] \\[f(x) = \\frac{\\Gamma\\!\\left(\\frac{\\nu+1}{2}\\right)}{\\sqrt{\\nu\\pi}\\,\\Gamma\\!\\left(\\frac{\\nu}{2}\\right)} \\left(1 + \\frac{x^2}{\\nu}\\right)^{-(\\nu+1)/2}\\] \u5206\u5e03 \u7c7b\u578b \u652f\u6491\u96c6 \u5747\u503c \u65b9\u5dee Bernoulli\\((p)\\) \u79bb\u6563 \\(\\{0,1\\}\\) \\(p\\) \\(p(1-p)\\) Binomial\\((n,p)\\) \u79bb\u6563 \\(\\{0,\\ldots,n\\}\\) \\(np\\) \\(np(1-p)\\) Poisson\\((\\lambda)\\) \u79bb\u6563 \\(\\{0,1,2,\\ldots\\}\\) \\(\\lambda\\) \\(\\lambda\\) Geometric\\((p)\\) \u79bb\u6563 \\(\\{1,2,3,\\ldots\\}\\) \\(1/p\\) \\((1-p)/p^2\\) Uniform\\((a,b)\\) \u8fde\u7eed \\([a,b]\\) \\((a+b)/2\\) \\((b-a)^2/12\\) Normal\\((\\mu,\\sigma^2)\\) \u8fde\u7eed \\((-\\infty,\\infty)\\) \\(\\mu\\) \\(\\sigma^2\\) Exponential\\((\\lambda)\\) \u8fde\u7eed \\([0,\\infty)\\) \\(1/\\lambda\\) \\(1/\\lambda^2\\) Gamma\\((\\alpha,\\beta)\\) \u8fde\u7eed \\((0,\\infty)\\) \\(\\alpha/\\beta\\) \\(\\alpha/\\beta^2\\) Beta\\((\\alpha,\\beta)\\) \u8fde\u7eed \\([0,1]\\) \\(\\alpha/(\\alpha+\\beta)\\) \u89c1\u4e0a\u6587 \\(\\chi^2(k)\\) \u8fde\u7eed \\((0,\\infty)\\) \\(k\\) \\(2k\\) Student's \\(t(\\nu)\\) \u8fde\u7eed \\((-\\infty,\\infty)\\) \\(0\\) \\(\\nu/(\\nu-2)\\)"},{"location":"chapter%2005%3A%20probability/03.%20distributions/#colab","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528CoLab\u6216\u7b14\u8bb0\u672c\uff09","text":"
  1. \u7ed8\u5236 \\(n=20\\) \u65f6\u4e8c\u9879\u5206\u5e03PMF\u5728\u4e0d\u540c \\(p\\) \u53d6\u503c\u4e0b\u7684\u56fe\u50cf\u3002\u89c2\u5bdf\u968f\u7740 \\(p\\) \u589e\u5927\uff0c\u5f62\u72b6\u5982\u4f55\u4ece\u53f3\u504f\u53d8\u4e3a\u5bf9\u79f0\u518d\u53d8\u4e3a\u5de6\u504f\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\nfrom math import comb\n\nn = 20\nks = jnp.arange(0, n + 1)\n\nfig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)\nfor ax, p, color in zip(axes, [0.2, 0.5, 0.8], [\"#e74c3c\", \"#3498db\", \"#27ae60\"]):\n    pmf = jnp.array([comb(n, int(k)) * p**k * (1-p)**(n-k) for k in ks])\n    ax.bar(ks, pmf, color=color, alpha=0.7)\n    ax.set_title(f\"Binomial(n={n}, p={p})\")\n    ax.set_xlabel(\"k\")\naxes[0].set_ylabel(\"P(X = k)\")\nplt.tight_layout()\nplt.show()\n

  2. \u9a8c\u8bc1\u6cca\u677e\u5206\u5e03\u5bf9\u4e8c\u9879\u5206\u5e03\u7684\u8fd1\u4f3c\u3002\u8bbe \\(n = 1000\\)\uff0c\\(p = 0.003\\)\uff0c\u6bd4\u8f83\u4e8c\u9879\u5206\u5e03 Binomial\\((n, p)\\) \u548c\u6cca\u677e\u5206\u5e03 Poisson\\((\\lambda = np)\\)\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\nfrom math import comb, factorial, exp\n\nn, p = 1000, 0.003\nlam = n * p\nks = jnp.arange(0, 15)\n\nbinom_pmf = jnp.array([comb(n, int(k)) * p**k * (1-p)**(n-k) for k in ks])\npoisson_pmf = jnp.array([lam**k * exp(-lam) / factorial(int(k)) for k in ks])\n\nplt.figure(figsize=(8, 4))\nplt.bar(ks - 0.15, binom_pmf, width=0.3, color=\"#3498db\", alpha=0.7, label=f\"Binomial({n},{p})\")\nplt.bar(ks + 0.15, poisson_pmf, width=0.3, color=\"#e74c3c\", alpha=0.7, label=f\"Poisson({lam})\")\nplt.xlabel(\"k\")\nplt.ylabel(\"P(X = k)\")\nplt.title(\"\u6cca\u677e\u5206\u5e03\u5bf9\u4e8c\u9879\u5206\u5e03\u7684\u8fd1\u4f3c\")\nplt.legend()\nplt.show()\n

  3. \u4ece\u6b63\u6001\u5206\u5e03\u4e2d\u91c7\u6837\u5e76\u9a8c\u8bc1\u7ecf\u9a8c\u6cd5\u5219\u3002\u8ba1\u7b97\u843d\u57281\u30012\u548c3\u4e2a\u6807\u51c6\u5dee\u5185\u7684\u6837\u672c\u6bd4\u4f8b\u3002

    import jax\nimport jax.numpy as jnp\n\nkey = jax.random.PRNGKey(42)\nmu, sigma = 5.0, 2.0\nsamples = mu + sigma * jax.random.normal(key, shape=(100_000,))\n\nfor k in [1, 2, 3]:\n    within = jnp.abs(samples - mu) <= k * sigma\n    print(f\"Within {k}\u03c3: {within.mean():.4f} (expected: {[0.6827, 0.9545, 0.9973][k-1]:.4f})\")\n

  4. \u901a\u8fc7\u6539\u53d8 \\(\\alpha\\) \u548c \\(\\beta\\) \u63a2\u7d22\u8d1d\u5854\u5206\u5e03\u3002\u7ed8\u5236\u51e0\u79cd\u5f62\u72b6\uff0c\u89c2\u5bdf\u5206\u5e03\u5982\u4f55\u4ece\u5747\u5300\u53d8\u4e3a\u504f\u659c\u518d\u53d8\u4e3a\u96c6\u4e2d\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nx = jnp.linspace(0.01, 0.99, 200)\n\ndef beta_pdf(x, a, b):\n    # \u672a\u5f52\u4e00\u5316\uff0c\u7528\u4e8e\u5f62\u72b6\u6bd4\u8f83\n    return x**(a-1) * (1-x)**(b-1)\n\nplt.figure(figsize=(10, 5))\nparams = [(1,1,\"\u5747\u5300\"), (2,5,\"\u5de6\u504f\"), (5,2,\"\u53f3\u504f\"),\n          (5,5,\"\u5bf9\u79f0\"), (0.5,0.5,\"U\u5f62\")]\ncolors = [\"#999\", \"#e74c3c\", \"#3498db\", \"#27ae60\", \"#9b59b6\"]\n\nfor (a, b, label), color in zip(params, colors):\n    y = beta_pdf(x, a, b)\n    y = y / jnp.trapezoid(y, x)  # \u5f52\u4e00\u5316\n    plt.plot(x, y, label=f\"\u03b1={a}, \u03b2={b} ({label})\", color=color, linewidth=2)\n\nplt.xlabel(\"x\")\nplt.ylabel(\"\u5bc6\u5ea6\")\nplt.title(\"\u8d1d\u5854\u5206\u5e03\u5f62\u72b6\")\nplt.legend()\nplt.grid(alpha=0.3)\nplt.show()\n

"},{"location":"chapter%2005%3A%20probability/04.%20bayesian/","title":"\u8d1d\u53f6\u65af\u65b9\u6cd5\u4e0e\u5e8f\u5217\u6a21\u578b","text":"

\u8d1d\u53f6\u65af\u65b9\u6cd5\u5c06\u5148\u9a8c\u4fe1\u5ff5\u4e0e\u89c2\u6d4b\u6570\u636e\u76f8\u7ed3\u5408\uff0c\u751f\u6210\u6a21\u578b\u53c2\u6570\u7684\u540e\u9a8c\u5206\u5e03\u3002\u672c\u6587\u6db5\u76d6\u6700\u5927\u4f3c\u7136\u4f30\u8ba1\u3001\u6700\u5927\u540e\u9a8c\u4f30\u8ba1\u3001\u5171\u8f6d\u5148\u9a8c\u3001\u8d1d\u53f6\u65af\u63a8\u65ad\u3001\u9690\u9a6c\u5c14\u53ef\u592b\u6a21\u578b\u548cEM\u7b97\u6cd5\u2014\u2014\u8fd9\u4e9b\u6280\u672f\u662f\u5783\u573e\u90ae\u4ef6\u8fc7\u6ee4\u5668\u3001\u8bed\u8a00\u6a21\u578b\u548c\u4e0d\u786e\u5b9a\u6027\u611f\u77e5\u673a\u5668\u5b66\u4e60\u7684\u57fa\u7840\u3002

\\[L(\\theta | D) = P(D | \\theta) = \\prod_{i=1}^{n} P(x_i | \\theta)\\] \\[\\hat{\\theta}_{\\text{MLE}} = \\arg\\max_\\theta L(\\theta | D)\\] \\[\\ell(\\theta) = \\log L(\\theta | D) = \\sum_{i=1}^{n} \\log P(x_i | \\theta)\\] \\[L(p) = \\binom{10}{7} p^7 (1-p)^3\\] \\[\\hat{\\theta}_{\\text{MAP}} = \\arg\\max_\\theta P(\\theta | D) = \\arg\\max_\\theta P(D | \\theta) \\cdot P(\\theta)\\]

\\[\\hat{p}_{\\text{MAP}} = \\frac{\\alpha + h - 1}{\\alpha + \\beta + h + t - 2}\\] \\[P(X_{t+1} | X_t, X_{t-1}, \\ldots, X_1) = P(X_{t+1} | X_t)\\]

\\[ T = \\begin{pmatrix} 0.3 & 0.4 & 0.3 \\\\ 0.2 & 0.5 & 0.3 \\\\ 0.4 & 0.3 & 0.3 \\end{pmatrix} \\]

\\[P(\\mathbf{y} | \\mathbf{x}) = \\frac{1}{Z(\\mathbf{x})} \\exp\\!\\left(\\sum_t \\left[\\sum_k \\lambda_k f_k(y_t, y_{t-1}, \\mathbf{x}, t)\\right]\\right)\\] "},{"location":"chapter%2005%3A%20probability/04.%20bayesian/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u5b9e\u73b0\u629b\u786c\u5e01\u5b9e\u9a8c\u7684MLE\u548cMAP\u3002\u89c2\u5bdfMAP\u4f30\u8ba1\u5982\u4f55\u968f\u4e0d\u540c\u7684\u5148\u9a8c\u548c\u4e0d\u540c\u7684\u6570\u636e\u91cf\u800c\u53d8\u5316\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u6570\u636e\uff1a\u89c2\u6d4b\u5230\u7684\u786c\u5e01\u629b\u63b7\u7ed3\u679c\nheads, tails = 7, 3\n\n# MLE\np_mle = heads / (heads + tails)\nprint(f\"MLE: {p_mle:.4f}\")\n\n# \u4f7f\u7528 Beta \u5148\u9a8c\u7684 MAP\nfor alpha, beta in [(1,1), (2,2), (5,5), (10,10)]:\n    p_map = (alpha + heads - 1) / (alpha + beta + heads + tails - 2)\n    print(f\"MAP (Beta({alpha},{beta})): {p_map:.4f}\")\n\n# \u53ef\u89c6\u5316 Beta(2,2) \u5148\u9a8c\u4e0b\u7684\u540e\u9a8c\ntheta = jnp.linspace(0.01, 0.99, 200)\n# \u540e\u9a8c\u4e3a Beta(alpha+heads, beta+tails)\na_post, b_post = 2 + heads, 2 + tails\nposterior = theta**(a_post-1) * (1-theta)**(b_post-1)\nposterior = posterior / jnp.trapezoid(posterior, theta)\n\nplt.figure(figsize=(8, 4))\nplt.plot(theta, posterior, color=\"#e74c3c\", linewidth=2, label=f\"\u540e\u9a8c Beta({a_post},{b_post})\")\nplt.axvline(p_mle, color=\"#3498db\", linestyle=\"--\", label=f\"MLE = {p_mle:.2f}\")\nplt.axvline((a_post-1)/(a_post+b_post-2), color=\"#e74c3c\", linestyle=\"--\", label=f\"MAP = {(a_post-1)/(a_post+b_post-2):.3f}\")\nplt.xlabel(\"\u03b8 (\u786c\u5e01\u504f\u7f6e)\")\nplt.ylabel(\"\u5bc6\u5ea6\")\nplt.title(\"7\u6b21\u6b63\u9762\u30013\u6b21\u53cd\u9762\u540e Beta(2,2) \u5148\u9a8c\u4e0b\u7684\u540e\u9a8c\u5206\u5e03\")\nplt.legend()\nplt.grid(alpha=0.3)\nplt.show()\n

  2. \u4e3a\u5929\u6c14\u6a21\u578b\u6784\u5efa\u4e00\u4e2a\u9a6c\u5c14\u53ef\u592b\u94fe\u5e76\u8fdb\u884c\u6a21\u62df\u3002\u5206\u522b\u901a\u8fc7\u6a21\u62df\u548c\u6c42\u89e3 \\(\\pi T = \\pi\\) \u8ba1\u7b97\u5e73\u7a33\u5206\u5e03\u3002

    import jax\nimport jax.numpy as jnp\n\n# \u8f6c\u79fb\u77e9\u9635\uff1aR\uff08\u96e8\u5929\uff09, S\uff08\u6674\u5929\uff09, C\uff08\u591a\u4e91\uff09\nT = jnp.array([\n    [0.3, 0.4, 0.3],\n    [0.2, 0.5, 0.3],\n    [0.4, 0.3, 0.3]\n])\nstates = [\"\u96e8\u5929\", \"\u6674\u5929\", \"\u591a\u4e91\"]\n\n# \u6a21\u62df 100,000 \u6b65\nkey = jax.random.PRNGKey(42)\nn_steps = 100_000\nstate = 0  # \u4ece\u96e8\u5929\u5f00\u59cb\ncounts = jnp.zeros(3)\n\nfor i in range(n_steps):\n    key, subkey = jax.random.split(key)\n    state = jax.random.choice(subkey, 3, p=T[state])\n    counts = counts.at[state].add(1)\n\nsim_stationary = counts / n_steps\nprint(\"\u6a21\u62df\u5f97\u5230\u7684\u5e73\u7a33\u5206\u5e03\uff1a\")\nfor s, p in zip(states, sim_stationary):\n    print(f\"  {s}: {p:.4f}\")\n\n# \u89e3\u6790\u6cd5\uff1a\u627e\u5230\u7279\u5f81\u503c\u4e3a1\u7684\u5de6\u7279\u5f81\u5411\u91cf\neigenvalues, eigenvectors = jnp.linalg.eig(T.T)\nidx = jnp.argmin(jnp.abs(eigenvalues - 1.0))\npi = jnp.real(eigenvectors[:, idx])\npi = pi / pi.sum()\nprint(\"\\n\u89e3\u6790\u5f97\u5230\u7684\u5e73\u7a33\u5206\u5e03\uff1a\")\nfor s, p in zip(states, pi):\n    print(f\"  {s}: {p:.4f}\")\n

  3. \u4e3a\u96e8\u4f1eHMM\u5b9e\u73b0\u7ef4\u7279\u6bd4\u7b97\u6cd5\uff0c\u5e76\u89e3\u7801\u4e00\u4e2a\u89c2\u6d4b\u5e8f\u5217\u3002

    import jax.numpy as jnp\n\n# HMM \u53c2\u6570\nstates = [\"\u96e8\u5929\", \"\u6674\u5929\"]\nobs_names = [\"\u5e26\u4f1e\", \"\u4e0d\u5e26\u4f1e\"]\n\ntrans = jnp.array([[0.7, 0.3],   # R->R, R->S\n                    [0.4, 0.6]])  # S->R, S->S\n\nemit = jnp.array([[0.9, 0.1],    # R->\u5e26\u4f1e, R->\u4e0d\u5e26\u4f1e\n                   [0.2, 0.8]])   # S->\u5e26\u4f1e, S->\u4e0d\u5e26\u4f1e\n\ninit = jnp.array([0.5, 0.5])\n\n# \u89c2\u6d4b\uff1a\u5e26\u4f1e=0\uff0c\u4e0d\u5e26\u4f1e=1\nobservations = [0, 0, 1]  # \u5e26\u4f1e, \u5e26\u4f1e, \u4e0d\u5e26\u4f1e\n\ndef viterbi(obs, init, trans, emit):\n    n_states = len(init)\n    T = len(obs)\n    V = jnp.zeros((T, n_states))\n    path = jnp.zeros((T, n_states), dtype=int)\n\n    # \u521d\u59cb\u5316\n    V = V.at[0].set(init * emit[:, obs[0]])\n\n    # \u9012\u63a8\n    for t in range(1, T):\n        for j in range(n_states):\n            probs = V[t-1] * trans[:, j]\n            V = V.at[t, j].set(jnp.max(probs) * emit[j, obs[t]])\n            path = path.at[t, j].set(jnp.argmax(probs))\n\n    # \u56de\u6eaf\n    best = [int(jnp.argmax(V[-1]))]\n    for t in range(T-1, 0, -1):\n        best.insert(0, int(path[t, best[0]]))\n    return best, V\n\ndecoded, scores = viterbi(observations, init, trans, emit)\nprint(\"\u89c2\u6d4b\u5e8f\u5217\uff1a\", [obs_names[o] for o in observations])\nprint(\"\u89e3\u7801\u7ed3\u679c\uff1a\", [states[s] for s in decoded])\n

  4. \u53ef\u89c6\u5316\u968f\u7740\u89c2\u6d4b\u66f4\u591a\u629b\u786c\u5e01\u7ed3\u679c\uff0c\u540e\u9a8c\u5982\u4f55\u6f14\u5316\u3002\u4ece Beta(1,1) \u5148\u9a8c\uff08\u5747\u5300\u5206\u5e03\uff09\u5f00\u59cb\uff0c\u6bcf\u6b21\u629b\u63b7\u540e\u66f4\u65b0\u540e\u9a8c\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ntheta = jnp.linspace(0.01, 0.99, 300)\nkey = jax.random.PRNGKey(7)\n\n# \u771f\u5b9e\u504f\u7f6e = 0.65\nflips = jax.random.bernoulli(key, p=0.65, shape=(50,))\n\nplt.figure(figsize=(10, 5))\na, b = 1, 1  # Beta(1,1) = \u5747\u5300\u5206\u5e03\n\nfor n_obs in [0, 1, 5, 10, 25, 50]:\n    h = int(flips[:n_obs].sum())\n    t = n_obs - h\n    a_post = a + h\n    b_post = b + t\n    y = theta**(a_post-1) * (1-theta)**(b_post-1)\n    y = y / jnp.trapezoid(y, theta)\n    plt.plot(theta, y, linewidth=2, label=f\"n={n_obs} (h={h})\")\n\nplt.axvline(0.65, color=\"black\", linestyle=\":\", alpha=0.5, label=\"\u771f\u5b9e p=0.65\")\nplt.xlabel(\"\u03b8\")\nplt.ylabel(\"\u5bc6\u5ea6\")\nplt.title(\"\u8d1d\u53f6\u65af\u66f4\u65b0\uff1a\u6570\u636e\u8d8a\u591a\u540e\u9a8c\u8d8a\u7a84\")\nplt.legend()\nplt.grid(alpha=0.3)\nplt.show()\n

"},{"location":"chapter%2005%3A%20probability/05.%20information%20theory/","title":"\u4fe1\u606f\u8bba","text":"

\u4fe1\u606f\u8bba\u91cf\u5316\u4e86\u4fe1\u606f\u3001\u60ca\u5947\u5ea6\u4ee5\u53ca\u6982\u7387\u5206\u5e03\u4e4b\u95f4\u7684\u5dee\u5f02\u3002\u672c\u6587\u6db5\u76d6\u71b5\u3001\u4ea4\u53c9\u71b5\u3001KL\u6563\u5ea6\u3001\u4e92\u4fe1\u606f\u548c\u81ea\u4fe1\u606f\u2014\u2014\u8fd9\u4e9b\u6982\u5ff5\u662f\u673a\u5668\u5b66\u4e60\u4e2d\u6bcf\u4e00\u4e2a\u5206\u7c7b\u635f\u5931\u51fd\u6570\u3001VAE\u76ee\u6807\u548c\u6570\u636e\u538b\u7f29\u65b9\u6848\u80cc\u540e\u7684\u7406\u8bba\u57fa\u7840\u3002

\\[I(x) = \\log_2 \\frac{1}{p(x)} = -\\log_2 p(x)\\] \\[H(X) = E[I(X)] = -\\sum_{x} p(x) \\log_2 p(x)\\]

\\[h(X) = -\\int_{-\\infty}^{\\infty} f(x) \\log f(x)\\, dx\\] \\[I(X; Y) = H(X) - H(X|Y) = H(Y) - H(Y|X)\\] \\[I(X; Y) = \\sum_{x,y} p(x,y) \\log_2 \\frac{p(x,y)}{p(x) p(y)}\\] \\[H(p, q) = -\\sum_{x} p(x) \\log_2 q(x)\\] \\[\\mathcal{L} = -\\sum_{c} y_c \\log \\hat{y}_c\\] \\[D_{\\text{KL}}(p \\| q) = \\sum_{x} p(x) \\log \\frac{p(x)}{q(x)} = H(p, q) - H(p)\\]

"},{"location":"chapter%2005%3A%20probability/05.%20information%20theory/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u8ba1\u7b97\u5404\u79cd\u5206\u5e03\u7684\u71b5\uff0c\u5e76\u9a8c\u8bc1\u5728\u7ed9\u5b9a\u7ed3\u679c\u6570\u91cf\u4e0b\uff0c\u5747\u5300\u5206\u5e03\u7684\u71b5\u6700\u5927\u3002

    import jax.numpy as jnp\n\ndef entropy(p):\n    \"\"\"\u4ee5\u6bd4\u7279\u4e3a\u5355\u4f4d\u8ba1\u7b97\u71b5\u3002\u8fc7\u6ee4\u6389\u6982\u7387\u4e3a\u96f6\u7684\u4e8b\u4ef6\u3002\"\"\"\n    p = p[p > 0]\n    return -jnp.sum(p * jnp.log2(p))\n\n# \u516c\u5e73\u9ab0\u5b50\nfair = jnp.ones(6) / 6\nprint(f\"\u516c\u5e73\u9ab0\u5b50\u71b5:   {entropy(fair):.4f} \u6bd4\u7279 (\u6700\u5927 = log2(6) = {jnp.log2(6.):.4f})\")\n\n# \u704c\u94c5\u9ab0\u5b50\nloaded = jnp.array([0.1, 0.1, 0.1, 0.1, 0.1, 0.5])\nprint(f\"\u704c\u94c5\u9ab0\u5b50\u71b5: {entropy(loaded):.4f} \u6bd4\u7279\")\n\n# \u786e\u5b9a\u6027\ndet = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 1.0])\nprint(f\"\u786e\u5b9a\u6027:      {entropy(det):.4f} \u6bd4\u7279\")\n\n# \u516c\u5e73\u786c\u5e01\ncoin = jnp.array([0.5, 0.5])\nprint(f\"\u516c\u5e73\u786c\u5e01\u71b5:  {entropy(coin):.4f} \u6bd4\u7279\")\n

  2. \u8ba1\u7b97\u771f\u5b9e\u5206\u5e03\u4e0e\u591a\u4e2a\u8fd1\u4f3c\u5206\u5e03\u4e4b\u95f4\u7684\u4ea4\u53c9\u71b5\u548c KL \u6563\u5ea6\u3002\u9a8c\u8bc1 \\(D_{\\text{KL}}(p \\| q) = H(p, q) - H(p)\\)\u3002

    import jax.numpy as jnp\n\ndef cross_entropy(p, q):\n    return -jnp.sum(p * jnp.log2(jnp.clip(q, 1e-10, 1.0)))\n\ndef kl_divergence(p, q):\n    mask = p > 0\n    return jnp.sum(jnp.where(mask, p * jnp.log2(p / jnp.clip(q, 1e-10, 1.0)), 0.0))\n\ndef entropy(p):\n    p = p[p > 0]\n    return -jnp.sum(p * jnp.log2(p))\n\np = jnp.array([0.4, 0.3, 0.2, 0.1])  # \u771f\u5b9e\u5206\u5e03\n\nfor name, q in [(\"\u5b8c\u5168\u5339\u914d\", p),\n                (\"\u8f7b\u5fae\u504f\u5dee\", jnp.array([0.35, 0.30, 0.25, 0.10])),\n                (\"\u4e25\u91cd\u504f\u5dee\", jnp.array([0.1, 0.1, 0.1, 0.7]))]:\n    h_p = entropy(p)\n    h_pq = cross_entropy(p, q)\n    kl = kl_divergence(p, q)\n    print(f\"{name:20s}: H(p)={h_p:.4f}, H(p,q)={h_pq:.4f}, \"\n          f\"KL={kl:.4f}, H(p,q)-H(p)={h_pq-h_p:.4f}\")\n

  3. \u901a\u8fc7\u8ba1\u7b97\u4e24\u4e2a\u4e0d\u540c\u5206\u5e03\u4e4b\u95f4\u7684 \\(D_{\\text{KL}}(p \\| q)\\) \u548c \\(D_{\\text{KL}}(q \\| p)\\)\uff0c\u8bc1\u660e KL \u6563\u5ea6\u4e0d\u662f\u5bf9\u79f0\u7684\u3002

    import jax.numpy as jnp\n\ndef kl_div(p, q):\n    mask = p > 0\n    return float(jnp.sum(jnp.where(mask, p * jnp.log2(p / jnp.clip(q, 1e-10, 1.0)), 0.0)))\n\np = jnp.array([0.9, 0.1])\nq = jnp.array([0.5, 0.5])\n\nprint(f\"D_KL(p || q) = {kl_div(p, q):.4f}\")\nprint(f\"D_KL(q || p) = {kl_div(q, p):.4f}\")\nprint(\"\u4e0d\u76f8\u540c\uff01KL \u6563\u5ea6\u662f\u4e0d\u5bf9\u79f0\u7684\u3002\")\n

  4. \u6a21\u62df\u8bad\u7ec3\u8fc7\u7a0b\u4e2d\u4ea4\u53c9\u71b5\u635f\u5931\u7684\u53d8\u5316\u3002\u521b\u5efa\u4e00\u4e2a\"\u771f\u5b9e\"\u7684 one-hot \u6807\u7b7e\uff0c\u5c55\u793a\u968f\u7740\u6a21\u578b\u9884\u6d4b\u6982\u7387\u7684\u6539\u5584\uff0c\u635f\u5931\u5982\u4f55\u4e0b\u964d\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u771f\u5b9e\u6807\u7b7e\uff1a4 \u4e2a\u7c7b\u522b\u4e2d\u7684\u7b2c 2 \u7c7b\ntrue_label = jnp.array([0, 0, 1, 0])\n\n# \u6a21\u62df\u9884\u6d4b\u9010\u6b65\u6539\u5584\nsteps = []\nlosses = []\nfor confidence in jnp.linspace(0.25, 0.99, 50):\n    # \u6a21\u578b\u5bf9\u7c7b\u522b 2 \u7684\u7f6e\u4fe1\u5ea6\u9010\u6e10\u63d0\u9ad8\n    remaining = (1 - confidence) / 3\n    pred = jnp.array([remaining, remaining, confidence, remaining])\n    loss = -jnp.sum(true_label * jnp.log(jnp.clip(pred, 1e-10, 1.0)))\n    steps.append(float(confidence))\n    losses.append(float(loss))\n\nplt.figure(figsize=(8, 4))\nplt.plot(steps, losses, color=\"#e74c3c\", linewidth=2)\nplt.xlabel(\"\u6a21\u578b\u5bf9\u771f\u5b9e\u7c7b\u522b\u7684\u7f6e\u4fe1\u5ea6\")\nplt.ylabel(\"\u4ea4\u53c9\u71b5\u635f\u5931\")\nplt.title(\"\u4ea4\u53c9\u71b5\u635f\u5931\u968f\u9884\u6d4b\u6539\u5584\u800c\u4e0b\u964d\")\nplt.grid(alpha=0.3)\nplt.show()\n

"},{"location":"chapter%2006%3A%20machine%20learning/01.%20classical%20machine%20learning/","title":"\u7ecf\u5178\u673a\u5668\u5b66\u4e60","text":"

\u7ecf\u5178\u673a\u5668\u5b66\u4e60\u7b97\u6cd5\u901a\u8fc7\u6570\u636e\u5b66\u4e60\u6a21\u5f0f\u800c\u65e0\u9700\u663e\u5f0f\u7f16\u7a0b\uff0c\u4f7f\u7528\u95ed\u5f0f\u89e3\u6216\u542f\u53d1\u5f0f\u641c\u7d22\u800c\u975e\u68af\u5ea6\u4e0b\u964d\u3002\u672c\u6587\u6db5\u76d6\u6734\u7d20\u8d1d\u53f6\u65af\u3001k-NN\u3001\u51b3\u7b56\u6811\u3001\u968f\u673a\u68ee\u6797\u3001\u652f\u6301\u5411\u91cf\u673a\u3001k-means\u805a\u7c7b\u548c\u4e3b\u6210\u5206\u5206\u6790\u3002

\\[P(C_k \\mid x) = \\frac{P(x \\mid C_k) \\, P(C_k)}{P(x)}\\] \\[\\hat{y} = \\arg\\max_{k} \\; P(C_k) \\prod_{i=1}^{n} P(x_i \\mid C_k)\\] \\[P(x_i \\mid C_k) = \\frac{1}{\\sqrt{2\\pi\\sigma_{ik}^2}} \\exp\\!\\left(-\\frac{(x_i - \\mu_{ik})^2}{2\\sigma_{ik}^2}\\right)\\]

\\[P(x_i \\mid C_k) = \\frac{\\text{count}(x_i, C_k) + \\alpha}{\\text{count}(C_k) + \\alpha \\cdot V}\\]

\\[\\text{Gini}(S) = 1 - \\sum_{k=1}^{K} p_k^2\\] \\[H(S) = -\\sum_{k=1}^{K} p_k \\log_2 p_k\\] \\[\\text{IG}(S, \\text{split}) = H(S) - \\frac{|S_L|}{|S|} H(S_L) - \\frac{|S_R|}{|S|} H(S_R)\\]

\\[H(x) = \\text{sign}\\!\\left(\\sum_{t=1}^{T} \\alpha_t \\, h_t(x)\\right)\\] \\[\\alpha_t = \\frac{1}{2} \\ln\\!\\left(\\frac{1 - \\epsilon_t}{\\epsilon_t}\\right)\\]

\\[J = \\sum_{k=1}^{K} \\sum_{x \\in C_k} \\|x - \\mu_k\\|^2\\] \\[P(x) = \\sum_{k=1}^{K} \\pi_k \\, \\mathcal{N}(x \\mid \\mu_k, \\Sigma_k)\\]

\\[\\min_{w, b} \\; \\frac{1}{2}\\|w\\|^2 \\quad \\text{subject to} \\quad y_i(w \\cdot x_i + b) \\geq 1 \\; \\text{for all } i\\] \\[\\min_{w, b, \\xi} \\; \\frac{1}{2}\\|w\\|^2 + C \\sum_{i=1}^{n} \\xi_i \\quad \\text{subject to} \\quad y_i(w \\cdot x_i + b) \\geq 1 - \\xi_i\\] \\[K(x_i, x_j) = \\exp\\!\\left(-\\gamma \\|x_i - x_j\\|^2\\right)\\] \u7b97\u6cd5 \u7c7b\u578b \u5173\u952e\u4f18\u52bf \u5173\u952e\u52a3\u52bf \u6734\u7d20\u8d1d\u53f6\u65af \u76d1\u7763\uff08\u751f\u6210\u5f0f\uff09 \u5feb\u901f\uff0c\u5c11\u91cf\u6570\u636e\u5373\u53ef\u5de5\u4f5c \u72ec\u7acb\u6027\u5047\u8bbe \u51b3\u7b56\u6811 \u76d1\u7763 \u53ef\u89e3\u91ca \u5bb9\u6613\u8fc7\u62df\u5408 \u968f\u673a\u68ee\u6797 \u76d1\u7763\uff08\u96c6\u6210\uff09 \u7a33\u5065\uff0c\u8d85\u53c2\u6570\u5c11 \u53ef\u89e3\u91ca\u6027\u8f83\u5dee \u68af\u5ea6\u63d0\u5347 \u76d1\u7763\uff08\u96c6\u6210\uff09 \u8868\u683c\u6570\u636e\u4e0a\u7684\u6700\u4f18\u6c34\u5e73 \u8f83\u6162\uff0c\u8c03\u53c2\u66f4\u591a K-Means \u65e0\u76d1\u7763\uff08\u805a\u7c7b\uff09 \u7b80\u5355\uff0c\u53ef\u6269\u5c55 \u5047\u8bbe\u7403\u5f62\u7c07 GMM \u65e0\u76d1\u7763\uff08\u805a\u7c7b\uff09 \u8f6f\u5206\u914d\uff0c\u5f62\u72b6\u7075\u6d3b \u5bf9\u521d\u59cb\u5316\u654f\u611f SVM \u76d1\u7763 \u9ad8\u7ef4\u6709\u6548 \u5927\u6570\u636e\u96c6\u4e0a\u6162"},{"location":"chapter%2006%3A%20machine%20learning/01.%20classical%20machine%20learning/#colab","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u5728CoLab\u6216\u7b14\u8bb0\u672c\u4e2d\u5b8c\u6210\uff09","text":"
  1. \u4ece\u5934\u5b9e\u73b0\u9ad8\u65af\u6734\u7d20\u8d1d\u53f6\u65af\u3002\u5728\u5408\u6210\u4e8c\u7ef4\u6570\u636e\uff08\u4e24\u4e2a\u7c7b\u522b\uff09\u4e0a\u8bad\u7ec3\u5e76\u53ef\u89c6\u5316\u51b3\u7b56\u8fb9\u754c\u3002\u4e0escikit-learn\u7684\u5b9e\u73b0\u8fdb\u884c\u6bd4\u8f83\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\nfrom sklearn.datasets import make_classification\n\n# \u751f\u6210\u5408\u6210\u6570\u636e\nX, y = make_classification(n_samples=300, n_features=2, n_redundant=0,\n                           n_informative=2, n_clusters_per_class=1, random_state=42)\nX, y = jnp.array(X), jnp.array(y)\n\n# \u4ece\u5934\u62df\u5408\u9ad8\u65af\u6734\u7d20\u8d1d\u53f6\u65af\nclasses = jnp.unique(y)\nparams = {}\nfor c in classes:\n    c = int(c)\n    mask = y == c\n    X_c = X[mask]\n    params[c] = {\n        'mean': jnp.mean(X_c, axis=0),\n        'var': jnp.var(X_c, axis=0),\n        'prior': jnp.sum(mask) / len(y)\n    }\n\ndef gaussian_log_likelihood(x, mean, var):\n    return -0.5 * jnp.sum(jnp.log(2 * jnp.pi * var) + (x - mean)**2 / var)\n\ndef predict(X):\n    preds = []\n    for x in X:\n        log_posts = []\n        for c in [0, 1]:\n            log_post = jnp.log(params[c]['prior']) + gaussian_log_likelihood(\n                x, params[c]['mean'], params[c]['var'])\n            log_posts.append(log_post)\n        preds.append(jnp.argmax(jnp.array(log_posts)))\n    return jnp.array(preds)\n\n# \u51b3\u7b56\u8fb9\u754c\u53ef\u89c6\u5316\nxx, yy = jnp.meshgrid(jnp.linspace(X[:,0].min()-1, X[:,0].max()+1, 200),\n                       jnp.linspace(X[:,1].min()-1, X[:,1].max()+1, 200))\ngrid = jnp.column_stack([xx.ravel(), yy.ravel()])\nzz = predict(grid).reshape(xx.shape)\n\nplt.figure(figsize=(8, 6))\nplt.contourf(xx, yy, zz, alpha=0.3, cmap='coolwarm')\nplt.scatter(X[y==0, 0], X[y==0, 1], c='#3498db', label='Class 0', edgecolors='k', s=20)\nplt.scatter(X[y==1, 0], X[y==1, 1], c='#e74c3c', label='Class 1', edgecolors='k', s=20)\nplt.title(\"Gaussian Naive Bayes Decision Boundary\")\nplt.legend()\nplt.grid(alpha=0.3)\nplt.show()\n\naccuracy = jnp.mean(predict(X) == y)\nprint(f\"Training accuracy: {accuracy:.2%}\")\n

  2. \u6784\u5efa\u4e00\u4e2a\u4f7f\u7528\u57fa\u5c3c\u4e0d\u7eaf\u5ea6\u8fdb\u884c\u5206\u88c2\u7684\u51b3\u7b56\u6811\u3002\u5b9e\u73b0\u5355\u4e2a\u8282\u70b9\u7684\u5206\u88c2\u903b\u8f91\uff0c\u5e76\u5c55\u793a\u4fe1\u606f\u589e\u76ca\u5982\u4f55\u9009\u62e9\u6700\u4f73\u7279\u5f81\u548c\u9608\u503c\u3002

    import jax.numpy as jnp\n\ndef gini_impurity(y):\n    \"\"\"\u8ba1\u7b97\u6807\u7b7e\u6570\u7ec4\u7684\u57fa\u5c3c\u4e0d\u7eaf\u5ea6\u3002\"\"\"\n    classes, counts = jnp.unique(y, return_counts=True)\n    probs = counts / len(y)\n    return 1.0 - jnp.sum(probs ** 2)\n\ndef information_gain(y, left_mask):\n    \"\"\"\u901a\u8fc7\u5e03\u5c14\u63a9\u7801\u5c06y\u5206\u88c2\u4e3a\u5de6/\u53f3\u540e\u7684\u4fe1\u606f\u589e\u76ca\u3002\"\"\"\n    parent_gini = gini_impurity(y)\n    left_y, right_y = y[left_mask], y[~left_mask]\n    n = len(y)\n    if len(left_y) == 0 or len(right_y) == 0:\n        return 0.0\n    child_gini = (len(left_y)/n) * gini_impurity(left_y) + \\\n                 (len(right_y)/n) * gini_impurity(right_y)\n    return float(parent_gini - child_gini)\n\ndef best_split(X, y):\n    \"\"\"\u627e\u5230\u6700\u5927\u5316\u4fe1\u606f\u589e\u76ca\u7684\u7279\u5f81\u548c\u9608\u503c\u3002\"\"\"\n    best_ig, best_feat, best_thresh = -1, None, None\n    for feat in range(X.shape[1]):\n        thresholds = jnp.unique(X[:, feat])\n        for thresh in thresholds:\n            mask = X[:, feat] <= float(thresh)\n            ig = information_gain(y, mask)\n            if ig > best_ig:\n                best_ig, best_feat, best_thresh = ig, feat, float(thresh)\n    return best_feat, best_thresh, best_ig\n\n# \u793a\u4f8b\uff1a\u5408\u6210\u6570\u636e\nfrom sklearn.datasets import make_classification\nX, y = make_classification(n_samples=100, n_features=4, n_redundant=0, random_state=0)\nX, y = jnp.array(X), jnp.array(y)\n\nfeat, thresh, ig = best_split(X, y)\nprint(f\"Best split: feature {feat}, threshold {thresh:.3f}, info gain {ig:.4f}\")\nprint(f\"Parent Gini: {gini_impurity(y):.4f}\")\nmask = X[:, feat] <= thresh\nprint(f\"Left Gini:   {gini_impurity(y[mask]):.4f} ({int(jnp.sum(mask))} samples)\")\nprint(f\"Right Gini:  {gini_impurity(y[~mask]):.4f} ({int(jnp.sum(~mask))} samples)\")\n

  3. \u4ece\u5934\u5b9e\u73b0\u5e26K-Means++\u521d\u59cb\u5316\u7684K-Means\u3002\u5bf9\u5408\u6210\u6570\u636e\u96c6\u8fdb\u884c\u805a\u7c7b\u5e76\u53ef\u89c6\u5316\u6bcf\u6b21\u8fed\u4ee3\u7684\u7c07\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\nfrom sklearn.datasets import make_blobs\n\n# \u751f\u6210\u5408\u6210\u7c07\nX, y_true = make_blobs(n_samples=300, centers=4, cluster_std=0.8, random_state=42)\nX = jnp.array(X)\n\ndef kmeans_plus_plus_init(X, K, key):\n    \"\"\"K-Means++\u521d\u59cb\u5316\u3002\"\"\"\n    n = X.shape[0]\n    idx = jax.random.randint(key, (), 0, n)\n    centroids = [X[idx]]\n    for _ in range(1, K):\n        dists = jnp.min(jnp.stack([jnp.sum((X - c)**2, axis=1) for c in centroids]), axis=0)\n        probs = dists / jnp.sum(dists)\n        key, subkey = jax.random.split(key)\n        idx = jax.random.choice(subkey, n, p=probs)\n        centroids.append(X[idx])\n    return jnp.stack(centroids)\n\ndef kmeans(X, K, max_iters=20, key=jax.random.PRNGKey(0)):\n    centroids = kmeans_plus_plus_init(X, K, key)\n    history = [centroids]\n    for _ in range(max_iters):\n        # \u5206\u914d\u6b65\u9aa4\n        dists = jnp.stack([jnp.sum((X - c)**2, axis=1) for c in centroids])\n        labels = jnp.argmin(dists, axis=0)\n        # \u66f4\u65b0\u6b65\u9aa4\n        new_centroids = jnp.stack([\n            jnp.mean(X[labels == k], axis=0) for k in range(K)\n        ])\n        history.append(new_centroids)\n        if jnp.allclose(centroids, new_centroids):\n            break\n        centroids = new_centroids\n    return labels, centroids, history\n\nK = 4\nlabels, centroids, history = kmeans(X, K)\n\n# \u7ed8\u5236\u6700\u7ec8\u7ed3\u679c\ncolors = ['#3498db', '#e74c3c', '#27ae60', '#9b59b6']\nplt.figure(figsize=(8, 6))\nfor k in range(K):\n    mask = labels == k\n    plt.scatter(X[mask, 0], X[mask, 1], c=colors[k], s=20, alpha=0.6)\n    plt.scatter(centroids[k, 0], centroids[k, 1], c=colors[k], marker='X',\n                s=200, edgecolors='k', linewidths=1.5)\nplt.title(f\"K-Means Clustering (K={K}, {len(history)-1} iterations)\")\nplt.grid(alpha=0.3)\nplt.show()\n\n# \u8ba1\u7b97\u60ef\u6027\ninertia = 
sum(jnp.sum((X[labels == k] - centroids[k])**2) for k in range(K))\nprint(f\"Final inertia: {inertia:.2f}\")\n

  4. \u6f14\u793a\u6838\u6280\u5de7\u3002\u901a\u8fc7\u6bd4\u8f83\u6838\u77e9\u9635\u4e0e\u591a\u9879\u5f0f\u6838\u7684\u663e\u5f0f\u7279\u5f81\u6620\u5c04\uff0c\u5c55\u793aRBF\u6838\u5982\u4f55\u5728\u9ad8\u7ef4\u7a7a\u95f4\u4e2d\u8ba1\u7b97\u70b9\u79ef\u3002

    import jax.numpy as jnp\n\n# \u7b80\u53552D\u6570\u636e\nX = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])\n\n# \u591a\u9879\u5f0f\u6838\uff1aK(x,y) = (x\u00b7y + 1)^2\ndef poly_kernel(X, degree=2, c=1.0):\n    return (X @ X.T + c) ** degree\n\n# 2D\u7684\u663e\u5f0f\u4e8c\u6b21\u7279\u5f81\u6620\u5c04\uff1a(1, sqrt(2)*x1, sqrt(2)*x2, x1^2, x2^2, sqrt(2)*x1*x2)\ndef poly_features(X):\n    x1, x2 = X[:, 0], X[:, 1]\n    return jnp.column_stack([\n        jnp.ones(len(X)),\n        jnp.sqrt(2) * x1,\n        jnp.sqrt(2) * x2,\n        x1 ** 2,\n        x2 ** 2,\n        jnp.sqrt(2) * x1 * x2\n    ])\n\nK_trick = poly_kernel(X)\nphi = poly_features(X)\nK_explicit = phi @ phi.T\n\nprint(\"Kernel trick (polynomial degree 2):\")\nprint(K_trick)\nprint(\"\\nExplicit feature map dot products:\")\nprint(K_explicit)\nprint(f\"\\nMatrices match: {jnp.allclose(K_trick, K_explicit)}\")\n\n# RBF\u6838\uff1a\u4e0d\u5b58\u5728\u6709\u9650\u7684\u663e\u5f0f\u6620\u5c04\ndef rbf_kernel(X, gamma=0.5):\n    sq_dists = jnp.sum(X**2, axis=1, keepdims=True) + \\\n               jnp.sum(X**2, axis=1) - 2 * X @ X.T\n    return jnp.exp(-gamma * sq_dists)\n\nK_rbf = rbf_kernel(X)\nprint(\"\\nRBF kernel matrix:\")\nprint(K_rbf)\nprint(\"Diagonal is always 1 (a point is identical to itself)\")\nprint(\"Off-diagonal entries decay with distance\")\n

"},{"location":"chapter%2006%3A%20machine%20learning/02.%20gradient%20machine%20learning/","title":"\u68af\u5ea6\u673a\u5668\u5b66\u4e60","text":"

\u57fa\u4e8e\u68af\u5ea6\u7684\u5b66\u4e60\u901a\u8fc7\u6cbf\u7740\u635f\u5931\u66f2\u9762\u7684\u659c\u7387\u8fed\u4ee3\u4f18\u5316\u6a21\u578b\u53c2\u6570\u3002\u672c\u6587\u6db5\u76d6\u7ebf\u6027\u56de\u5f52\u3001\u903b\u8f91\u56de\u5f52\u3001softmax\u5206\u7c7b\u3001\u68af\u5ea6\u4e0b\u964d\u53d8\u4f53\u3001\u6b63\u5219\u5316\uff08L1/L2\uff09\u548c\u504f\u5dee-\u65b9\u5dee\u6743\u8861\u3002

\\[\\hat{y} = w \\cdot x + b = \\sum_{i=1}^{d} w_i x_i + b\\] \\[\\mathcal{L}(w) = \\frac{1}{n} \\sum_{i=1}^{n} (y_i - \\hat{y}_i)^2 = \\frac{1}{n} \\|y - Xw\\|^2\\]

\\[w^{*} = (X^T X)^{-1} X^T y\\] \\[\\sigma(z) = \\frac{1}{1 + e^{-z}}\\]

\\[\\mathcal{L} = -\\frac{1}{n} \\sum_{i=1}^{n} \\left[ y_i \\log(\\hat{y}_i) + (1 - y_i) \\log(1 - \\hat{y}_i) \\right]\\] \\[w \\leftarrow w - \\eta \\frac{\\partial \\mathcal{L}}{\\partial w}\\]

\\[\\frac{\\partial L}{\\partial w} = \\frac{\\partial L}{\\partial z} \\cdot \\frac{\\partial z}{\\partial w}\\] \\[v_t = \\beta v_{t-1} + (1 - \\beta) \\nabla \\mathcal{L}\\] \\[w \\leftarrow w - \\eta \\, v_t\\] \\[v_t = \\beta \\, v_{t-1} + \\nabla \\mathcal{L}(w - \\eta \\beta \\, v_{t-1})\\] \\[w \\leftarrow w - \\eta \\, v_t\\] \\[G_t = G_{t-1} + g_t^2, \\quad w \\leftarrow w - \\frac{\\eta}{\\sqrt{G_t + \\epsilon}} g_t\\] \\[s_t = \\beta \\, s_{t-1} + (1 - \\beta) g_t^2, \\quad w \\leftarrow w - \\frac{\\eta}{\\sqrt{s_t + \\epsilon}} g_t\\] \\[m_t = \\beta_1 m_{t-1} + (1 - \\beta_1) g_t\\] \\[v_t = \\beta_2 v_{t-1} + (1 - \\beta_2) g_t^2\\] \\[\\hat{m}_t = \\frac{m_t}{1 - \\beta_1^t}, \\quad \\hat{v}_t = \\frac{v_t}{1 - \\beta_2^t}\\] \\[w \\leftarrow w - \\frac{\\eta}{\\sqrt{\\hat{v}_t} + \\epsilon} \\hat{m}_t\\]

\\[w \\leftarrow w - \\eta \\left( \\frac{\\hat{m}_t}{\\sqrt{\\hat{v}_t} + \\epsilon} + \\lambda \\, w \\right)\\] \\[w \\leftarrow w - \\eta \\cdot \\text{sign}(\\beta_1 \\, m_{t-1} + (1 - \\beta_1) \\, g_t)\\] \\[m_t = \\beta_2 \\, m_{t-1} + (1 - \\beta_2) \\, g_t\\] \\[G_t = \\text{NesterovMomentum}(\\nabla \\mathcal{L})\\] \\[U_t = \\text{NewtonSchulz}(G_t) \\approx G_t (G_t^T G_t)^{-1/2}\\] \\[W \\leftarrow W - \\eta \\, U_t\\]

\\[\\mathcal{L} = -\\log(\\hat{y}_c)\\] \\[\\mathcal{L}_{\\text{reg}} = \\mathcal{L}_{\\text{data}} + \\lambda \\, R(w)\\] \\[\\text{Error} = \\text{Bias}^2 + \\text{Variance} + \\text{Irreducible Noise}\\] "},{"location":"chapter%2006%3A%20machine%20learning/02.%20gradient%20machine%20learning/#colab","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u5728CoLab\u6216\u7b14\u8bb0\u672c\u4e2d\u5b8c\u6210\uff09","text":"
  1. \u4f7f\u7528\u6b63\u89c4\u65b9\u7a0b\u548c\u68af\u5ea6\u4e0b\u964d\u4e24\u79cd\u65b9\u6cd5\u5b9e\u73b0\u7ebf\u6027\u56de\u5f52\u3002\u6bd4\u8f83\u6c42\u89e3\u7ed3\u679c\uff0c\u5e76\u7ed8\u5236GD\u635f\u5931\u968f\u8fed\u4ee3\u7684\u6536\u655b\u66f2\u7ebf\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u751f\u6210\u5408\u6210\u6570\u636e\uff1ay = 3x + 2 + noise\nkey = jax.random.PRNGKey(42)\nn = 100\nX = jax.random.uniform(key, (n, 1), minval=0, maxval=10)\ny = 3 * X[:, 0] + 2 + jax.random.normal(key, (n,)) * 1.5\n\n# \u6dfb\u52a0\u504f\u7f6e\u5217\nX_b = jnp.column_stack([X, jnp.ones(n)])\n\n# \u6b63\u89c4\u65b9\u7a0b\nw_exact = jnp.linalg.solve(X_b.T @ X_b, X_b.T @ y)\nprint(f\"Normal equation: w={w_exact[0]:.4f}, b={w_exact[1]:.4f}\")\n\n# \u68af\u5ea6\u4e0b\u964d\nw_gd = jnp.zeros(2)\nlr = 0.005\nlosses = []\nfor step in range(500):\n    pred = X_b @ w_gd\n    error = pred - y\n    loss = jnp.mean(error ** 2)\n    losses.append(float(loss))\n    grad = (2 / n) * X_b.T @ error\n    w_gd = w_gd - lr * grad\n\nprint(f\"Gradient descent: w={w_gd[0]:.4f}, b={w_gd[1]:.4f}\")\n\nfig, axes = plt.subplots(1, 2, figsize=(12, 4))\naxes[0].scatter(X[:, 0], y, s=15, alpha=0.5, color='#3498db')\naxes[0].plot([0, 10], [w_exact[1], w_exact[0]*10 + w_exact[1]], color='#e74c3c', linewidth=2)\naxes[0].set_title(\"Linear Regression Fit\")\naxes[0].set_xlabel(\"x\"); axes[0].set_ylabel(\"y\")\n\naxes[1].plot(losses, color='#27ae60', linewidth=1.5)\naxes[1].set_title(\"GD Loss Convergence\")\naxes[1].set_xlabel(\"Step\"); axes[1].set_ylabel(\"MSE\")\naxes[1].set_yscale('log')\nplt.tight_layout()\nplt.show()\n

  2. \u4ece\u5934\u5b9e\u73b0\u5e26\u68af\u5ea6\u4e0b\u964d\u7684\u903b\u8f91\u56de\u5f52\u3002\u5728\u4e8c\u7ef4\u6570\u636e\u96c6\u4e0a\u8bad\u7ec3\u5e76\u53ef\u89c6\u5316\u5b66\u4e60\u5230\u7684\u51b3\u7b56\u8fb9\u754c\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\nfrom sklearn.datasets import make_moons\n\n# \u751f\u6210\u6570\u636e\nX, y = make_moons(n_samples=300, noise=0.2, random_state=42)\nX, y = jnp.array(X), jnp.array(y, dtype=jnp.float32)\n\ndef sigmoid(z):\n    return 1 / (1 + jnp.exp(-z))\n\n# \u6dfb\u52a0\u504f\u7f6e\u5217\nX_b = jnp.column_stack([X, jnp.ones(len(X))])\nw = jnp.zeros(3)\nlr = 0.5\nlosses = []\n\nfor step in range(2000):\n    z = X_b @ w\n    pred = sigmoid(z)\n    # BCE\u635f\u5931\n    loss = -jnp.mean(y * jnp.log(pred + 1e-8) + (1 - y) * jnp.log(1 - pred + 1e-8))\n    losses.append(float(loss))\n    # \u68af\u5ea6\n    grad = X_b.T @ (pred - y) / len(y)\n    w = w - lr * grad\n\n# \u51b3\u7b56\u8fb9\u754c\nxx, yy = jnp.meshgrid(jnp.linspace(-2, 3, 200), jnp.linspace(-1.5, 2, 200))\ngrid = jnp.column_stack([xx.ravel(), yy.ravel(), jnp.ones(xx.size)])\nzz = sigmoid(grid @ w).reshape(xx.shape)\n\nplt.figure(figsize=(8, 6))\nplt.contourf(xx, yy, zz, levels=[0, 0.5, 1], alpha=0.3, colors=['#e74c3c', '#3498db'])\nplt.contour(xx, yy, zz, levels=[0.5], colors='#9b59b6', linewidths=2)\nplt.scatter(X[y==0, 0], X[y==0, 1], c='#e74c3c', s=15, label='Class 0')\nplt.scatter(X[y==1, 0], X[y==1, 1], c='#3498db', s=15, label='Class 1')\nplt.title(\"Logistic Regression Decision Boundary\")\nplt.legend()\nplt.grid(alpha=0.3)\nplt.show()\n

  3. \u5728\u4e8c\u7ef4\u4e8c\u6b21\u66f2\u9762\u4e0a\u6bd4\u8f83\u4f18\u5316\u5668\u7684\u8f68\u8ff9\u3002\u4ece\u76f8\u540c\u7684\u8d77\u70b9\u8fd0\u884cSGD\u3001SGD+Momentum\u548cAdam\uff0c\u7ed8\u5236\u5b83\u4eec\u7684\u8def\u5f84\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u62c9\u957f\u7684\u4e8c\u6b21\u66f2\u9762\uff1aL(w1, w2) = 0.5*w1^2 + 10*w2^2\ndef loss_fn(w):\n    return 0.5 * w[0]**2 + 10 * w[1]**2\n\ngrad_fn = jax.grad(loss_fn)\n\ndef run_sgd(w0, lr=0.05, steps=80):\n    w = w0.copy()\n    path = [w.copy()]\n    for _ in range(steps):\n        g = grad_fn(w)\n        w = w - lr * g\n        path.append(w.copy())\n    return jnp.stack(path)\n\ndef run_momentum(w0, lr=0.05, beta=0.9, steps=80):\n    w, v = w0.copy(), jnp.zeros(2)\n    path = [w.copy()]\n    for _ in range(steps):\n        g = grad_fn(w)\n        v = beta * v + (1 - beta) * g\n        w = w - lr * v\n        path.append(w.copy())\n    return jnp.stack(path)\n\ndef run_adam(w0, lr=0.05, b1=0.9, b2=0.999, eps=1e-8, steps=80):\n    w, m, v = w0.copy(), jnp.zeros(2), jnp.zeros(2)\n    path = [w.copy()]\n    for t in range(1, steps + 1):\n        g = grad_fn(w)\n        m = b1 * m + (1 - b1) * g\n        v = b2 * v + (1 - b2) * g**2\n        m_hat = m / (1 - b1**t)\n        v_hat = v / (1 - b2**t)\n        w = w - lr * m_hat / (jnp.sqrt(v_hat) + eps)\n        path.append(w.copy())\n    return jnp.stack(path)\n\nw0 = jnp.array([8.0, 3.0])\nsgd_path = run_sgd(w0)\nmom_path = run_momentum(w0)\nadam_path = run_adam(w0)\n\n# \u7ed8\u56fe\nfig, ax = plt.subplots(figsize=(8, 6))\nw1 = jnp.linspace(-10, 10, 100)\nw2 = jnp.linspace(-4, 4, 100)\nW1, W2 = jnp.meshgrid(w1, w2)\nL = 0.5 * W1**2 + 10 * W2**2\nax.contour(W1, W2, L, levels=20, cmap='Greys', alpha=0.4)\nax.plot(sgd_path[:,0], sgd_path[:,1], 'o-', color='#3498db', markersize=2, linewidth=1, label='SGD')\nax.plot(mom_path[:,0], mom_path[:,1], 'o-', color='#27ae60', markersize=2, linewidth=1, label='Momentum')\nax.plot(adam_path[:,0], adam_path[:,1], 'o-', color='#e74c3c', markersize=2, linewidth=1, label='Adam')\nax.plot(0, 0, 'k*', markersize=15, label='Minimum')\nax.set_xlabel('w\u2081'); ax.set_ylabel('w\u2082')\nax.set_title(\"Optimizer 
Trajectories on Elongated Quadratic\")\nax.legend()\nplt.grid(alpha=0.3)\nplt.show()\n

  4. \u5c55\u793aL1\u4e0eL2\u6b63\u5219\u5316\u5bf9\u6743\u91cd\u7a00\u758f\u6027\u7684\u5f71\u54cd\u3002\u4f7f\u7528\u4e24\u79cd\u60e9\u7f5a\u8bad\u7ec3\u7ebf\u6027\u56de\u5f52\uff0c\u5e76\u6bd4\u8f83\u5f97\u5230\u7684\u6743\u91cd\u5411\u91cf\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u5408\u6210\u6570\u636e\uff1a20\u4e2a\u7279\u5f81\u4e2d\u53ea\u6709\u524d3\u4e2a\u662f\u76f8\u5173\u7684\nkey = jax.random.PRNGKey(0)\nn, d = 200, 20\nw_true = jnp.zeros(d).at[:3].set(jnp.array([3.0, -2.0, 1.5]))\nX = jax.random.normal(key, (n, d))\ny = X @ w_true + 0.5 * jax.random.normal(key, (n,))\n\ndef train_ridge(X, y, lam=1.0, lr=0.01, steps=2000):\n    \"\"\"\u901a\u8fc7GD\u8fdb\u884cL2\u6b63\u5219\u5316\u7ebf\u6027\u56de\u5f52\u3002\"\"\"\n    w = jnp.zeros(X.shape[1])\n    for _ in range(steps):\n        pred = X @ w\n        grad = (2/len(y)) * X.T @ (pred - y) + 2 * lam * w\n        w = w - lr * grad\n    return w\n\ndef train_lasso(X, y, lam=1.0, lr=0.01, steps=2000):\n    \"\"\"\u901a\u8fc7\u8fd1\u7aefGD\u8fdb\u884cL1\u6b63\u5219\u5316\u7ebf\u6027\u56de\u5f52\u3002\"\"\"\n    w = jnp.zeros(X.shape[1])\n    for _ in range(steps):\n        pred = X @ w\n        grad = (2/len(y)) * X.T @ (pred - y)\n        w = w - lr * grad\n        # \u8f6f\u9608\u503c\uff08L1\u7684\u8fd1\u7aef\u7b97\u5b50\uff09\n        w = jnp.sign(w) * jnp.maximum(jnp.abs(w) - lr * lam, 0)\n    return w\n\nw_l2 = train_ridge(X, y, lam=0.1)\nw_l1 = train_lasso(X, y, lam=0.1)\n\nfig, axes = plt.subplots(1, 3, figsize=(14, 4))\naxes[0].bar(range(d), w_true, color='#333', alpha=0.7)\naxes[0].set_title(\"True Weights\"); axes[0].set_xlabel(\"Feature\")\naxes[1].bar(range(d), w_l2, color='#3498db', alpha=0.7)\naxes[1].set_title(\"L2 (Ridge): shrinks all\"); axes[1].set_xlabel(\"Feature\")\naxes[2].bar(range(d), w_l1, color='#e74c3c', alpha=0.7)\naxes[2].set_title(\"L1 (Lasso): zeros out irrelevant\"); axes[2].set_xlabel(\"Feature\")\nplt.tight_layout()\nplt.show()\n\nprint(f\"L2 non-zero weights: {int(jnp.sum(jnp.abs(w_l2) > 0.01))}/{d}\")\nprint(f\"L1 non-zero weights: {int(jnp.sum(jnp.abs(w_l1) > 0.01))}/{d}\")\n

"},{"location":"chapter%2006%3A%20machine%20learning/03.%20deep%20learning/","title":"\u6df1\u5ea6\u5b66\u4e60","text":"

\u6df1\u5ea6\u5b66\u4e60\u5806\u53e0\u975e\u7ebf\u6027\u5c42\u6765\u6784\u5efa\u5c42\u6b21\u5316\u8868\u793a\uff0c\u81ea\u52a8\u5c06\u539f\u59cb\u8f93\u5165\u8f6c\u6362\u4e3a\u6709\u7528\u7684\u7279\u5f81\u3002\u672c\u6587\u6db5\u76d6MLP\u3001\u6fc0\u6d3b\u51fd\u6570\u3001\u53cd\u5411\u4f20\u64ad\u3001CNN\u3001RNN\u3001LSTM\u3001\u6ce8\u610f\u529b\u673a\u5236\u3001Transformer\u3001GAN\u3001VAE\u3001\u6269\u6563\u6a21\u578b\u548c\u5f52\u4e00\u5316\u6280\u672f\u3002

\\[h = \\sigma(Wx + b)\\]

\\[\\hat{x} = \\frac{x - \\mu_B}{\\sqrt{\\sigma_B^2 + \\epsilon}}, \\quad y = \\gamma \\hat{x} + \\beta\\]

\\[(\\text{input} * K)[i,j] = \\sum_{m=0}^{k-1} \\sum_{n=0}^{k-1} \\text{input}[i+m, j+n] \\cdot K[m, n]\\]

\\[h_t = \\tanh(W_h h_{t-1} + W_x x_t + b)\\]

\\[\\text{Attention}(Q, K, V) = \\text{softmax}\\!\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right) V\\] \\[\\text{MultiHead}(Q, K, V) = \\text{Concat}(\\text{head}_1, \\ldots, \\text{head}_h) W^O\\]

\\[PE_{(pos, 2i)} = \\sin\\!\\left(\\frac{pos}{10000^{2i/d}}\\right), \\quad PE_{(pos, 2i+1)} = \\cos\\!\\left(\\frac{pos}{10000^{2i/d}}\\right)\\] \\[z = f_{\\text{enc}}(x), \\quad \\hat{x} = f_{\\text{dec}}(z), \\quad \\mathcal{L} = \\|x - \\hat{x}\\|^2\\] \\[\\mathcal{L} = \\underbrace{\\|x - \\hat{x}\\|^2}_{\\text{reconstruction}} + \\underbrace{D_{\\text{KL}}(q(z|x) \\| p(z))}_{\\text{regularisation}}\\] "},{"location":"chapter%2006%3A%20machine%20learning/03.%20deep%20learning/#colab","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u5728CoLab\u6216\u7b14\u8bb0\u672c\u4e2d\u5b8c\u6210\uff09","text":"
  1. \u5728JAX\u4e2d\u4ece\u5934\u6784\u5efa\u4e00\u4e2a\u7b80\u5355\u7684MLP\u3002\u5728\u4e8c\u7ef4\u5206\u7c7b\u95ee\u9898\uff08\u5982\u540c\u5fc3\u5706\uff09\u4e0a\u8bad\u7ec3\u5e76\u53ef\u89c6\u5316\u51b3\u7b56\u8fb9\u754c\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\nfrom sklearn.datasets import make_circles\n\n# \u6570\u636e\nX, y = make_circles(n_samples=500, noise=0.1, factor=0.5, random_state=42)\nX, y = jnp.array(X), jnp.array(y, dtype=jnp.float32)\n\n# \u521d\u59cb\u5316\u4e00\u4e2a2\u5c42MLP\uff1a2 -> 16 -> 16 -> 1\ndef init_params(key):\n    k1, k2, k3 = jax.random.split(key, 3)\n    return {\n        'W1': jax.random.normal(k1, (2, 16)) * 0.5,\n        'b1': jnp.zeros(16),\n        'W2': jax.random.normal(k2, (16, 16)) * 0.5,\n        'b2': jnp.zeros(16),\n        'W3': jax.random.normal(k3, (16, 1)) * 0.5,\n        'b3': jnp.zeros(1),\n    }\n\ndef forward(params, x):\n    h = jnp.maximum(0, x @ params['W1'] + params['b1'])  # ReLU\n    h = jnp.maximum(0, h @ params['W2'] + params['b2'])   # ReLU\n    logit = (h @ params['W3'] + params['b3']).squeeze()\n    return jax.nn.sigmoid(logit)\n\ndef loss_fn(params, X, y):\n    pred = forward(params, X)\n    return -jnp.mean(y * jnp.log(pred + 1e-7) + (1 - y) * jnp.log(1 - pred + 1e-7))\n\ngrad_fn = jax.jit(jax.grad(loss_fn))\nparams = init_params(jax.random.PRNGKey(0))\nlr = 0.1\n\nfor step in range(2000):\n    grads = grad_fn(params, X, y)\n    params = {k: params[k] - lr * grads[k] for k in params}\n\n# \u7ed8\u5236\u51b3\u7b56\u8fb9\u754c\nxx, yy = jnp.meshgrid(jnp.linspace(-2, 2, 200), jnp.linspace(-2, 2, 200))\ngrid = jnp.column_stack([xx.ravel(), yy.ravel()])\nzz = forward(params, grid).reshape(xx.shape)\n\nplt.figure(figsize=(7, 6))\nplt.contourf(xx, yy, zz, levels=[0, 0.5, 1], alpha=0.3, colors=['#e74c3c', '#3498db'])\nplt.scatter(X[y==0,0], X[y==0,1], c='#e74c3c', s=10, label='Class 0')\nplt.scatter(X[y==1,0], X[y==1,1], c='#3498db', s=10, label='Class 1')\nplt.title(\"MLP Decision Boundary on Concentric Circles\")\nplt.legend(); plt.grid(alpha=0.3); plt.show()\n\nacc = jnp.mean((forward(params, X) > 0.5) == y)\nprint(f\"Accuracy: {acc:.2%}\")\n

  2. \u4ece\u5934\u5b9e\u73b0\u4e00\u7ef4\u5377\u79ef\u3002\u5c06\u7b80\u5355\u7684\u8fb9\u7f18\u68c0\u6d4b\u6ee4\u6ce2\u5668\u5e94\u7528\u4e8e\u4fe1\u53f7\uff0c\u5e76\u4e0e\u5185\u7f6e\u7684 jnp.convolve \u8fdb\u884c\u6bd4\u8f83\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef conv1d(signal, kernel):\n    \"\"\"\u4ece\u5934\u5b9e\u73b0\u4e00\u7ef4\u5377\u79ef\uff08valid\u6a21\u5f0f\uff09\u3002\"\"\"\n    n, k = len(signal), len(kernel)\n    output = jnp.zeros(n - k + 1)\n    for i in range(n - k + 1):\n        output = output.at[i].set(jnp.sum(signal[i:i+k] * kernel))\n    return output\n\n# \u521b\u5efa\u4e00\u4e2a\u5e26\u6709\u9636\u8dc3\u51fd\u6570\u7684\u4fe1\u53f7\nt = jnp.linspace(0, 4, 200)\nsignal = jnp.where(t < 1, 0.0, jnp.where(t < 2, 1.0, jnp.where(t < 3, 0.5, 1.5)))\n\n# \u8fb9\u7f18\u68c0\u6d4b\u6838\nedge_kernel = jnp.array([-1.0, 0.0, 1.0])\n\n# \u6211\u4eec\u7684\u5b9e\u73b0 vs \u5185\u7f6e\u51fd\u6570\nour_output = conv1d(signal, edge_kernel)\n# jnp.convolve\u4f1a\u53cd\u8f6c\u6838\uff0c\u56e0\u6b64\u4f20\u5165\u53cd\u8f6c\u540e\u7684\u6838\njnp_output = jnp.convolve(signal, edge_kernel[::-1], mode='valid')\n\nfig, axes = plt.subplots(3, 1, figsize=(10, 6), sharex=True)\naxes[0].plot(t, signal, color='#3498db', linewidth=1.5)\naxes[0].set_title(\"Original Signal\"); axes[0].set_ylabel(\"Value\")\n\naxes[1].plot(t[:len(our_output)], our_output, color='#e74c3c', linewidth=1.5)\naxes[1].set_title(\"After Edge Detection (our conv1d)\"); axes[1].set_ylabel(\"Value\")\n\naxes[2].plot(t[:len(jnp_output)], jnp_output, color='#27ae60', linewidth=1.5, linestyle='--')\naxes[2].set_title(\"After Edge Detection (jnp.convolve)\"); axes[2].set_ylabel(\"Value\")\naxes[2].set_xlabel(\"t\")\n\nplt.tight_layout(); plt.show()\nprint(f\"Outputs match: {jnp.allclose(our_output, jnp_output)}\")\n

  3. \u4ece\u5934\u5b9e\u73b0\u7f29\u653e\u70b9\u79ef\u6ce8\u610f\u529b\u3002\u4e3a\u4e00\u4e2a\u5c0f\u4f8b\u5b50\u8ba1\u7b97\u6ce8\u610f\u529b\u6743\u91cd\uff0c\u5e76\u5c06\u6ce8\u610f\u529b\u77e9\u9635\u53ef\u89c6\u5316\u4e3a\u70ed\u529b\u56fe\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef scaled_dot_product_attention(Q, K, V):\n    \"\"\"\u7f29\u653e\u70b9\u79ef\u6ce8\u610f\u529b\u3002\"\"\"\n    d_k = Q.shape[-1]\n    scores = Q @ K.T / jnp.sqrt(d_k)\n    weights = jax.nn.softmax(scores, axis=-1)\n    output = weights @ V\n    return output, weights\n\n# \u793a\u4f8b\uff1a4\u4e2a\u6807\u8bb0\uff0c\u5d4c\u5165\u7ef4\u5ea68\nkey = jax.random.PRNGKey(42)\nk1, k2, k3 = jax.random.split(key, 3)\nseq_len, d_model = 4, 8\n\nQ = jax.random.normal(k1, (seq_len, d_model))\nK = jax.random.normal(k2, (seq_len, d_model))\nV = jax.random.normal(k3, (seq_len, d_model))\n\noutput, weights = scaled_dot_product_attention(Q, K, V)\n\nprint(f\"Q shape: {Q.shape}\")\nprint(f\"Attention weights shape: {weights.shape}\")\nprint(f\"Output shape: {output.shape}\")\nprint(f\"\\nAttention weights (rows sum to 1):\")\nprint(weights)\nprint(f\"Row sums: {weights.sum(axis=-1)}\")\n\n# \u53ef\u89c6\u5316\u6ce8\u610f\u529b\nfig, ax = plt.subplots(figsize=(5, 4))\nim = ax.imshow(weights, cmap='Blues', vmin=0, vmax=1)\nax.set_xlabel(\"Key position\"); ax.set_ylabel(\"Query position\")\nax.set_title(\"Attention Weights\")\ntokens = ['tok 0', 'tok 1', 'tok 2', 'tok 3']\nax.set_xticks(range(4)); ax.set_xticklabels(tokens)\nax.set_yticks(range(4)); ax.set_yticklabels(tokens)\nfor i in range(4):\n    for j in range(4):\n        ax.text(j, i, f\"{weights[i,j]:.2f}\", ha='center', va='center', fontsize=10)\nplt.colorbar(im); plt.tight_layout(); plt.show()\n

  4. \u6784\u5efa\u4e00\u4e2a\u7b80\u5355\u7684\u81ea\u7f16\u7801\u5668\uff0c\u901a\u8fc7\u4e00\u7ef4\u74f6\u9888\u538b\u7f29\u4e8c\u7ef4\u6570\u636e\u5e76\u91cd\u5efa\u3002\u53ef\u89c6\u5316\u6f5c\u7a7a\u95f4\u548c\u91cd\u5efa\u7ed3\u679c\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\nfrom sklearn.datasets import make_moons\n\n# \u6570\u636e\nX, _ = make_moons(n_samples=500, noise=0.05, random_state=42)\nX = jnp.array(X)\n\n# \u81ea\u7f16\u7801\u5668\uff1a2 -> 8 -> 1 -> 8 -> 2\ndef init_ae(key):\n    k1, k2, k3, k4 = jax.random.split(key, 4)\n    return {\n        'enc_W1': jax.random.normal(k1, (2, 8)) * 0.5, 'enc_b1': jnp.zeros(8),\n        'enc_W2': jax.random.normal(k2, (8, 1)) * 0.5, 'enc_b2': jnp.zeros(1),\n        'dec_W1': jax.random.normal(k3, (1, 8)) * 0.5, 'dec_b1': jnp.zeros(8),\n        'dec_W2': jax.random.normal(k4, (8, 2)) * 0.5, 'dec_b2': jnp.zeros(2),\n    }\n\ndef encode(p, x):\n    h = jnp.tanh(x @ p['enc_W1'] + p['enc_b1'])\n    return h @ p['enc_W2'] + p['enc_b2']\n\ndef decode(p, z):\n    h = jnp.tanh(z @ p['dec_W1'] + p['dec_b1'])\n    return h @ p['dec_W2'] + p['dec_b2']\n\ndef ae_loss(p, X):\n    z = encode(p, X)\n    X_hat = decode(p, z)\n    return jnp.mean((X - X_hat) ** 2)\n\ngrad_fn = jax.jit(jax.grad(ae_loss))\nparams = init_ae(jax.random.PRNGKey(0))\nlr = 0.01\n\nfor step in range(3000):\n    grads = grad_fn(params, X)\n    params = {k: params[k] - lr * grads[k] for k in params}\n\nz = encode(params, X)\nX_hat = decode(params, z)\n\nfig, axes = plt.subplots(1, 2, figsize=(12, 5))\naxes[0].scatter(X[:,0], X[:,1], c=z.squeeze(), cmap='viridis', s=10)\naxes[0].set_title(\"Original Data (coloured by latent code)\")\naxes[1].scatter(X_hat[:,0], X_hat[:,1], c=z.squeeze(), cmap='viridis', s=10)\naxes[1].set_title(\"Reconstruction from 1D bottleneck\")\nfor ax in axes:\n    ax.set_aspect('equal'); ax.grid(alpha=0.3)\nplt.tight_layout(); plt.show()\n\nprint(f\"Reconstruction MSE: {ae_loss(params, X):.4f}\")\n

"},{"location":"chapter%2006%3A%20machine%20learning/04.%20reinforcement%20learning/","title":"\u5f3a\u5316\u5b66\u4e60","text":"

\u5f3a\u5316\u5b66\u4e60\u901a\u8fc7\u8bd5\u9519\u6cd5\u6700\u5927\u5316\u7d2f\u79ef\u5956\u52b1\u6765\u8bad\u7ec3\u667a\u80fd\u4f53\u505a\u51fa\u5e8f\u5217\u51b3\u7b56\u3002\u672c\u6587\u4ef6\u6db5\u76d6MDP\u3001\u4ef7\u503c\u51fd\u6570\u3001\u8d1d\u5c14\u66fc\u65b9\u7a0b\u3001Q\u5b66\u4e60\u3001\u7b56\u7565\u68af\u5ea6\u3001\u6f14\u5458-\u8bc4\u8bba\u5bb6\u65b9\u6cd5\u3001PPO\u548cRLHF\u2014\u2014\u8fd9\u4e9b\u662f\u6e38\u620f\u667a\u80fd\u4f53\u548c\u8bed\u8a00\u6a21\u578b\u5bf9\u9f50\u80cc\u540e\u7684\u6846\u67b6\u3002

\\[G_t = r_t + \\gamma r_{t+1} + \\gamma^2 r_{t+2} + \\cdots = \\sum_{k=0}^{\\infty} \\gamma^k r_{t+k}\\] \\[V^\\pi(s) = \\mathbb{E}_\\pi \\left[ G_t \\mid s_t = s \\right]\\] \\[Q^\\pi(s, a) = \\mathbb{E}_\\pi \\left[ G_t \\mid s_t = s, a_t = a \\right]\\] \\[V^\\pi(s) = \\sum_a \\pi(a \\mid s) \\sum_{s'} P(s' \\mid s, a) \\left[ R(s, a) + \\gamma \\, V^\\pi(s') \\right]\\] \\[V^{*}(s) = \\max_a \\sum_{s'} P(s' \\mid s, a) \\left[ R(s, a) + \\gamma \\, V^{*}(s') \\right]\\] \\[Q^{*}(s, a) = \\sum_{s'} P(s' \\mid s, a) \\left[ R(s, a) + \\gamma \\max_{a'} Q^{*}(s', a') \\right]\\] \\[V(s) \\leftarrow \\max_a \\sum_{s'} P(s' \\mid s, a) \\left[ R(s, a) + \\gamma \\, V(s') \\right]\\] \\[V(s_t) \\leftarrow V(s_t) + \\alpha \\left[ r_t + \\gamma \\, V(s_{t+1}) - V(s_t) \\right]\\]

\\[Q(s, a) \\leftarrow Q(s, a) + \\alpha \\left[ r + \\gamma \\, Q(s', a') - Q(s, a) \\right]\\] \\[Q(s, a) \\leftarrow Q(s, a) + \\alpha \\left[ r + \\gamma \\max_{a'} Q(s', a') - Q(s, a) \\right]\\] \\[\\mathcal{L}(\\theta) = \\mathbb{E} \\left[ \\left( r + \\gamma \\max_{a'} Q(s', a'; \\theta^{-}) - Q(s, a; \\theta) \\right)^2 \\right]\\] \\[\\nabla_\\theta J(\\theta) = \\mathbb{E}_\\pi \\left[ \\nabla_\\theta \\log \\pi(a \\mid s; \\theta) \\cdot G_t \\right]\\] \\[\\theta \\leftarrow \\theta + \\alpha \\, \\nabla_\\theta \\log \\pi(a_t \\mid s_t; \\theta) \\cdot G_t\\] \\[\\theta \\leftarrow \\theta + \\alpha \\, \\nabla_\\theta \\log \\pi(a_t \\mid s_t; \\theta) \\cdot (G_t - b)\\] \\[\\theta \\leftarrow \\theta + \\alpha \\, \\nabla_\\theta \\log \\pi(a_t \\mid s_t; \\theta) \\cdot A_t\\]

\\[\\mathcal{L}^{\\text{CLIP}}(\\theta) = \\mathbb{E} \\left[ \\min\\!\\left( r_t(\\theta) A_t, \\; \\text{clip}(r_t(\\theta), 1-\\epsilon, 1+\\epsilon) A_t \\right) \\right]\\] \\[\\mathcal{L}_{\\text{DPO}}(\\theta) = -\\mathbb{E} \\left[ \\log \\sigma\\!\\left( \\beta \\log \\frac{\\pi_\\theta(y_w \\mid x)}{\\pi_{\\text{ref}}(y_w \\mid x)} - \\beta \\log \\frac{\\pi_\\theta(y_l \\mid x)}{\\pi_{\\text{ref}}(y_l \\mid x)} \\right) \\right]\\] \u65b9\u6cd5 \u7c7b\u578b \u6838\u5fc3\u601d\u60f3 \u4f18\u52bf \u4ef7\u503c\u8fed\u4ee3 DP, \u57fa\u4e8e\u6a21\u578b \u8d1d\u5c14\u66fc\u6700\u4f18\u6027 \u7cbe\u786e\u89e3\uff08\u5c0fMDP\uff09 SARSA TD, \u5728\u7b56\u7565 \u5728\u7b56\u7565\u5b66\u4e60Q \u4fdd\u5b88\u3001\u5b89\u5168 Q\u5b66\u4e60 TD, \u79bb\u7b56\u7565 \u5b66\u4e60Q*, \u8d2a\u5fc3\u76ee\u6807 \u7b80\u5355\u3001\u6709\u6548 DQN \u6df1\u5ea6, \u79bb\u7b56\u7565 \u795e\u7ecfQ + \u56de\u653e + \u76ee\u6807\u7f51\u7edc \u6269\u5c55\u5230\u9ad8\u7ef4\u72b6\u6001 REINFORCE \u7b56\u7565\u68af\u5ea6 log-\u6982\u7387 * \u56de\u62a5\u7684\u68af\u5ea6 \u7b80\u5355\u7684\u7b56\u7565\u4f18\u5316 \u6f14\u5458-\u8bc4\u8bba\u5bb6 PG + \u4ef7\u503c \u6f14\u5458 + \u8bc4\u8bba\u5bb6\u964d\u4f4e\u65b9\u5dee \u5b9e\u7528\u4e14\u7075\u6d3b PPO PG, \u88c1\u526a \u4fe1\u4efb\u533a\u57df\u822c\u7684\u7a33\u5b9a\u6027 \u884c\u4e1a\u6807\u51c6 DPO \u76f4\u63a5\u504f\u597d \u8df3\u8fc7\u5956\u52b1\u6a21\u578b \u66f4\u7b80\u5355\u7684RLHF"},{"location":"chapter%2006%3A%20machine%20learning/04.%20reinforcement%20learning/#colab","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216\u7b14\u8bb0\u672c\uff09","text":"
  1. \u4e3a\u7b80\u5355\u7684\u7f51\u683c\u4e16\u754c\u5b9e\u73b0\u4ef7\u503c\u8fed\u4ee3\u3002\u8ba1\u7b97\u6700\u4f18\u4ef7\u503c\u51fd\u6570\u5e76\u63d0\u53d6\u6700\u4f18\u7b56\u7565\u3002\u5c06\u4e24\u8005\u53ef\u89c6\u5316\u4e3a\u70ed\u529b\u56fe\u548c\u7bad\u5934\u56fe\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# 4x4\u7f51\u683c\u4e16\u754c\uff1a\u76ee\u6807\u5728(3,3)\uff0c\u6bcf\u6b65\u5956\u52b1-1\uff0c\u76ee\u6807\u5904\u4e3a0\ngrid_size = 4\ngamma = 0.99\ngoal = (3, 3)\n\n# \u52a8\u4f5c\uff1a\u4e0a\u3001\u4e0b\u3001\u5de6\u3001\u53f3\nactions = [(-1, 0), (1, 0), (0, -1), (0, 1)]\naction_names = ['up', 'down', 'left', 'right']\naction_arrows = ['\\u2191', '\\u2193', '\\u2190', '\\u2192']\n\ndef step(s, a):\n    \"\"\"\u786e\u5b9a\u6027\u8f6c\u79fb\u3002\"\"\"\n    ns = (max(0, min(grid_size-1, s[0]+a[0])),\n          max(0, min(grid_size-1, s[1]+a[1])))\n    return ns\n\n# \u4ef7\u503c\u8fed\u4ee3\nV = jnp.zeros((grid_size, grid_size))\nfor iteration in range(100):\n    V_new = jnp.array(V)\n    for i in range(grid_size):\n        for j in range(grid_size):\n            if (i, j) == goal:\n                continue\n            values = []\n            for a in actions:\n                ns = step((i, j), a)\n                values.append(-1 + gamma * float(V[ns[0], ns[1]]))\n            V_new = V_new.at[i, j].set(max(values))\n    if jnp.max(jnp.abs(V_new - V)) < 1e-6:\n        print(f\"\u5728{iteration+1}\u6b21\u8fed\u4ee3\u540e\u6536\u655b\")\n        break\n    V = V_new\n\n# \u63d0\u53d6\u7b56\u7565\npolicy = [['' for _ in range(grid_size)] for _ in range(grid_size)]\nfor i in range(grid_size):\n    for j in range(grid_size):\n        if (i, j) == goal:\n            policy[i][j] = 'G'\n            continue\n        best_a = max(range(4), key=lambda a: -1 + gamma * float(V[step((i,j), actions[a])[0], step((i,j), actions[a])[1]]))\n        policy[i][j] = action_arrows[best_a]\n\nfig, axes = plt.subplots(1, 2, figsize=(10, 4))\nim = axes[0].imshow(V, cmap='YlOrRd_r')\naxes[0].set_title(\"\u6700\u4f18\u4ef7\u503c\u51fd\u6570\")\nfor i in range(grid_size):\n    for j in range(grid_size):\n        axes[0].text(j, i, f\"{V[i,j]:.1f}\", ha='center', va='center', fontsize=10)\nplt.colorbar(im, 
ax=axes[0])\n\naxes[1].imshow(jnp.ones((grid_size, grid_size)), cmap='Greys', vmin=0, vmax=2)\naxes[1].set_title(\"\u6700\u4f18\u7b56\u7565\")\nfor i in range(grid_size):\n    for j in range(grid_size):\n        axes[1].text(j, i, policy[i][j], ha='center', va='center', fontsize=18)\nplt.tight_layout(); plt.show()\n

  2. \u5728\u7b80\u5355\u7684\u7f51\u683c\u4e16\u754c\u4e0a\u5b9e\u73b0\u8868\u683cQ\u5b66\u4e60\u3002\u8bad\u7ec3\u667a\u80fd\u4f53\uff0c\u7ed8\u5236\u5b66\u4e60\u66f2\u7ebf\uff0c\u663e\u793a\u5b66\u4e60\u5230\u7684Q\u503c\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ngrid_size = 5\ngoal = (4, 4)\nactions = [(-1,0), (1,0), (0,-1), (0,1)]\n\n# Q\u8868\nQ = {}\nfor i in range(grid_size):\n    for j in range(grid_size):\n        Q[(i,j)] = [0.0] * 4\n\nalpha = 0.1\ngamma = 0.95\nepsilon = 1.0\nepsilon_decay = 0.995\nmin_epsilon = 0.01\n\ndef step(s, a_idx):\n    a = actions[a_idx]\n    ns = (max(0, min(grid_size-1, s[0]+a[0])),\n          max(0, min(grid_size-1, s[1]+a[1])))\n    r = 0.0 if ns == goal else -1.0\n    done = ns == goal\n    return ns, r, done\n\nkey = jax.random.PRNGKey(42)\nrewards_per_episode = []\n\nfor ep in range(500):\n    s = (0, 0)\n    total_reward = 0\n    for _ in range(100):\n        key, subkey = jax.random.split(key)\n        if float(jax.random.uniform(subkey)) < epsilon:\n            key, subkey = jax.random.split(key)\n            a = int(jax.random.randint(subkey, (), 0, 4))\n        else:\n            a = max(range(4), key=lambda i: Q[s][i])\n\n        ns, r, done = step(s, a)\n        total_reward += r\n        # Q\u5b66\u4e60\u66f4\u65b0\n        Q[s][a] += alpha * (r + gamma * max(Q[ns]) - Q[s][a])\n        s = ns\n        if done:\n            break\n    rewards_per_episode.append(total_reward)\n    epsilon = max(min_epsilon, epsilon * epsilon_decay)\n\nplt.figure(figsize=(8, 4))\n# \u5e73\u6ed1\u66f2\u7ebf\nwindow = 20\nsmoothed = [sum(rewards_per_episode[max(0,i-window):i+1])/min(i+1, window)\n            for i in range(len(rewards_per_episode))]\nplt.plot(smoothed, color='#3498db', linewidth=1.5)\nplt.xlabel(\"Episode\"); plt.ylabel(\"Total Reward (smoothed)\")\nplt.title(\"Q-Learning on Gridworld\")\nplt.grid(alpha=0.3); plt.show()\n\n# \u663e\u793a\u5b66\u5230\u7684\u7b56\u7565\narrow = ['\\u2191', '\\u2193', '\\u2190', '\\u2192']\nprint(\"\u5b66\u5230\u7684\u7b56\u7565:\")\nfor i in range(grid_size):\n    row = \"\"\n    for j in range(grid_size):\n        if (i,j) == goal:\n            row += \" G \"\n        
else:\n            row += f\" {arrow[max(range(4), key=lambda a: Q[(i,j)][a])]} \"\n    print(row)\n

  3. \u5728\u591a\u81c2\u8001\u864e\u673a\u95ee\u9898\u4e0a\u5b9e\u73b0REINFORCE\u3002\u5c55\u793a\u7b56\u7565\u5982\u4f55\u968f\u8bad\u7ec3\u6f14\u53d8\u4ee5\u504f\u5411\u6700\u4f73\u81c2\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# 5\u81c2\u8001\u864e\u673a\uff0c\u4e0d\u540c\u671f\u671b\u5956\u52b1\ntrue_rewards = jnp.array([0.2, 0.5, 0.8, 0.3, 0.1])\nn_arms = len(true_rewards)\n\n# \u7b56\u7565\uff1a\u5728logits\u4e0a\u7684softmax\nlogits = jnp.zeros(n_arms)\nlr = 0.1\nkey = jax.random.PRNGKey(42)\n\npolicy_history = []\nreward_history = []\n\nfor step in range(2000):\n    probs = jax.nn.softmax(logits)\n    policy_history.append(probs)\n\n    # \u91c7\u6837\u52a8\u4f5c\n    key, subkey = jax.random.split(key)\n    action = jax.random.choice(subkey, n_arms, p=probs)\n\n    # \u83b7\u53d6\u5956\u52b1\uff08\u4f2f\u52aa\u5229\u5206\u5e03\uff09\n    key, subkey = jax.random.split(key)\n    reward = float(jax.random.uniform(subkey) < true_rewards[action])\n    reward_history.append(reward)\n\n    # REINFORCE\u66f4\u65b0\n    # grad log pi(a) = e_a - probs\uff08\u5bf9\u4e8esoftmax\u53c2\u6570\u5316\uff09\n    grad_log_pi = (-probs).at[action].add(1.0)  # one-hot(a) - probs\n    logits = logits + lr * reward * grad_log_pi\n\npolicy_history = jnp.stack(policy_history)\n\nfig, axes = plt.subplots(1, 2, figsize=(12, 4))\ncolors = ['#3498db', '#e74c3c', '#27ae60', '#9b59b6', '#f39c12']\nfor i in range(n_arms):\n    axes[0].plot(policy_history[:, i], color=colors[i],\n                 label=f'\u81c2{i} (\u771f\u5b9e={true_rewards[i]:.1f})', linewidth=1.5)\naxes[0].set_xlabel(\"\u6b65\u9aa4\"); axes[0].set_ylabel(\"P(\u81c2)\")\naxes[0].set_title(\"\u7b56\u7565\u6f14\u53d8 (REINFORCE)\")\naxes[0].legend(fontsize=8); axes[0].grid(alpha=0.3)\n\n# \u5e73\u6ed1\u5956\u52b1\nwindow = 50\nsmoothed = [sum(reward_history[max(0,i-window):i+1])/min(i+1,window)\n            for i in range(len(reward_history))]\naxes[1].plot(smoothed, color='#27ae60', linewidth=1.5)\naxes[1].axhline(y=0.8, color='#e74c3c', linestyle='--', alpha=0.5, label='\u6700\u4f73\u81c2')\naxes[1].set_xlabel(\"\u6b65\u9aa4\"); 
axes[1].set_ylabel(\"\u5e73\u5747\u5956\u52b1\")\naxes[1].set_title(\"\u5956\u52b1\u968f\u65f6\u95f4\u53d8\u5316\"); axes[1].legend()\naxes[1].grid(alpha=0.3)\nplt.tight_layout(); plt.show()\n

"},{"location":"chapter%2006%3A%20machine%20learning/05.%20distributed%20deep%20learning/","title":"\u5206\u5e03\u5f0f\u6df1\u5ea6\u5b66\u4e60","text":"

\u5206\u5e03\u5f0f\u8bad\u7ec3\u5c06\u8ba1\u7b97\u5206\u6563\u5230\u591a\u4e2aGPU\u548c\u673a\u5668\u4e0a\uff0c\u4ee5\u8bad\u7ec3\u5355\u4e2a\u8bbe\u5907\u65e0\u6cd5\u5bb9\u7eb3\u6216\u8bad\u7ec3\u592a\u6162\u7684\u6a21\u578b\u3002\u672c\u6587\u4ef6\u6db5\u76d6\u6df7\u5408\u7cbe\u5ea6\u3001\u6570\u636e\u5e76\u884c\u3001\u6a21\u578b\u5e76\u884c\u3001\u6d41\u6c34\u7ebf\u5e76\u884c\u3001ZeRO\u3001FSDP\u3001\u5f20\u91cf\u5e76\u884c\u4ee5\u53ca\u5168\u89c4\u7ea6\u7b49\u901a\u4fe1\u539f\u8bed\u2014\u2014\u8fd9\u4e9b\u5bf9\u4e8e\u5927\u89c4\u6a21\u8bad\u7ec3LLM\u81f3\u5173\u91cd\u8981\u3002

\\[L(N) \\propto N^{-\\alpha_N}, \\quad L(D) \\propto D^{-\\alpha_D}, \\quad L(C) \\propto C^{-\\alpha_C}\\]

\u6280\u672f \u4f5c\u7528 \u6743\u8861 \u6df7\u5408\u7cbe\u5ea6 (BF16) \u5c06\u6fc0\u6d3b\u503c/\u68af\u5ea6\u7684\u5185\u5b58\u51cf\u534a \u8f7b\u5fae\u6570\u503c\u5dee\u5f02 \u6570\u636e\u5e76\u884c \u5728GPU\u95f4\u6269\u5c55\u6279\u91cf\u5927\u5c0f \u68af\u5ea6\u540c\u6b65\u7684\u901a\u4fe1\u5f00\u9500 \u5f20\u91cf\u5e76\u884c \u5728GPU\u95f4\u5206\u5272\u5c42 \u9700\u8981\u5feb\u901f\u4e92\u8054 \u6d41\u6c34\u7ebf\u5e76\u884c \u5728GPU\u95f4\u5206\u5272\u6a21\u578b\u9636\u6bb5 \u6d41\u6c34\u7ebf\u6c14\u6ce1\uff08\u8ba1\u7b97\u6d6a\u8d39\uff09 \u68af\u5ea6\u7d2f\u79ef \u6a21\u62df\u5927\u6279\u91cf \u66f4\u6162\uff08\u591a\u6b21\u524d\u5411/\u53cd\u5411\u4f20\u64ad\uff09 \u68af\u5ea6\u68c0\u67e5\u70b9 \u51cf\u5c11\u6fc0\u6d3b\u5185\u5b58 \u7ea6\u591a33%\u8ba1\u7b97 \u73af\u5168\u89c4\u7ea6 \u9ad8\u6548\u7684\u68af\u5ea6\u5e73\u5747 \u5927\u6a21\u578b\u53d7\u9650\u4e8e\u5e26\u5bbd MoE \u66f4\u591a\u5bb9\u91cf\uff0c\u76f8\u540cFLOPs \u8d1f\u8f7d\u5747\u8861\u3001\u8def\u7531\u590d\u6742\u6027 \u7f29\u653e\u5b9a\u5f8b \u6307\u5bfc\u8ba1\u7b97\u5206\u914d \u7ecf\u9a8c\u516c\u5f0f\uff0c\u672a\u5fc5\u5728\u6240\u6709\u89c4\u6a21\u90fd\u6210\u7acb"},{"location":"chapter%2006%3A%20machine%20learning/05.%20distributed%20deep%20learning/#colab","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216\u7b14\u8bb0\u672c\uff09","text":"
  1. \u8ba1\u7b97Transformer\u5c42\u7684FLOPs\u548c\u5185\u5b58\u9700\u6c42\u3002\u7ed9\u5b9a\u9690\u85cf\u7ef4\u5ea6 \\(d\\)\u3001\u5e8f\u5217\u957f\u5ea6 \\(n\\)\u3001\u6279\u91cf\u5927\u5c0f \\(B\\) \u548c\u5c42\u6570\uff0c\u4f30\u8ba1\u603b\u8bad\u7ec3\u6210\u672c\u3002

    import jax.numpy as jnp\n\ndef transformer_layer_flops(d, n, B):\n    \"\"\"\u4e00\u4e2aTransformer\u5c42\u524d\u5411\u4f20\u64ad\u7684\u8fd1\u4f3cFLOPs\u3002\"\"\"\n    # QKV\u6295\u5f71\uff1a3 * (B * n * d * d) * 2\uff08\u4e58\u6cd5-\u52a0\u6cd5\uff09\n    qkv_flops = 3 * 2 * B * n * d * d\n    # \u6ce8\u610f\u529b\uff1a(B * n * n * d) * 2 \u7528\u4e8eQK^T\uff0c(B * n * n * d) * 2 \u7528\u4e8eattn*V\n    attn_flops = 2 * 2 * B * n * n * d\n    # \u8f93\u51fa\u6295\u5f71\uff1a(B * n * d * d) * 2\n    out_flops = 2 * B * n * d * d\n    # FFN\uff1a\u4e24\u5c42\uff0cd->4d \u548c 4d->d\uff1a2 * (B * n * d * 4d) * 2\n    ffn_flops = 2 * 2 * B * n * d * 4 * d\n    return qkv_flops + attn_flops + out_flops + ffn_flops\n\ndef transformer_layer_memory(d, n, B, dtype_bytes=2):\n    \"\"\"\u4e00\u4e2a\u5c42\u7684\u8fd1\u4f3c\u6fc0\u6d3b\u5185\u5b58\uff08\u5b57\u8282\uff09\u3002\"\"\"\n    # QKV\uff1a3 * B * n * d\n    qkv_mem = 3 * B * n * d * dtype_bytes\n    # \u6ce8\u610f\u529b\u6743\u91cd\uff1aB * heads * n * n\uff08\u8fd1\u4f3c B * n * n * sizeof\uff09\n    attn_mem = B * n * n * dtype_bytes\n    # FFN\u4e2d\u95f4\u503c\uff1aB * n * 4d\n    ffn_mem = B * n * 4 * d * dtype_bytes\n    return qkv_mem + attn_mem + ffn_mem\n\n# \u793a\u4f8b\uff1aGPT-2\u89c4\u6a21\nd, n, B, L = 1024, 1024, 8, 24\nfwd_flops = transformer_layer_flops(d, n, B)\ntotal_flops = 3 * L * fwd_flops  # \u524d\u5411+\u53cd\u5411\u76843\u500d\nact_mem = L * transformer_layer_memory(d, n, B)\nparam_count = L * (12 * d * d + 13 * d)  # \u8fd1\u4f3c\n\nprint(f\"\u6a21\u578b\uff1ad={d}, n={n}, B={B}, L={L}\")\nprint(f\"\u53c2\u6570\uff1a{param_count / 1e6:.0f}M\")\nprint(f\"\u6bcf\u6b65FLOPs\uff1a{total_flops / 1e12:.2f} TFLOPs\")\nprint(f\"\u6fc0\u6d3b\u5185\u5b58\uff1a{act_mem / 1e9:.2f} GB (BF16)\")\nprint(f\"\u53c2\u6570\u5185\u5b58 (FP32)\uff1a{param_count * 4 / 1e9:.2f} GB\")\nprint(f\"Adam\u4f18\u5316\u5668\u5185\u5b58\uff1a{param_count * 8 / 1e9:.2f} 
GB\")\nprint(f\"\u603b\u8bad\u7ec3\u5185\u5b58\uff1a{(param_count * 16 + act_mem) / 1e9:.2f} GB\")\n

  2. \u6a21\u62df\u6570\u636e\u5e76\u884c\u8bad\u7ec3\u3002\u5c06\u6570\u636e\u96c6\u5206\u5272\u5230\u591a\u4e2a\"\u865a\u62dfGPU\"\u4e0a\uff0c\u72ec\u7acb\u8ba1\u7b97\u68af\u5ea6\uff0c\u5e73\u5747\u5b83\u4eec\uff0c\u5e76\u9a8c\u8bc1\u7ed3\u679c\u4e0e\u5355GPU\u8bad\u7ec3\u5339\u914d\u3002

    import jax\nimport jax.numpy as jnp\n\n# \u7b80\u5355\u7ebf\u6027\u6a21\u578b\uff1ay = wx + b\nkey = jax.random.PRNGKey(0)\nX = jax.random.normal(key, (64, 4))\nw_true = jnp.array([1.0, -2.0, 3.0, 0.5])\ny = X @ w_true + 0.1 * jax.random.normal(key, (64,))\n\ndef loss_fn(w, X, y):\n    return jnp.mean((X @ w - y) ** 2)\n\ngrad_fn = jax.grad(loss_fn)\n\n# \u5355GPU\uff1a\u5168\u6279\u91cf\u68af\u5ea6\nw = jnp.zeros(4)\ngrad_single = grad_fn(w, X, y)\n\n# \u6570\u636e\u5e76\u884c\uff1a\u5206\u5272\u52304\u4e2a\"GPU\"\u4e0a\nn_gpus = 4\nchunk_size = len(X) // n_gpus\ngrads = []\nfor i in range(n_gpus):\n    X_chunk = X[i*chunk_size:(i+1)*chunk_size]\n    y_chunk = y[i*chunk_size:(i+1)*chunk_size]\n    grads.append(grad_fn(w, X_chunk, y_chunk))\n\n# \u5168\u89c4\u7ea6\uff1a\u5e73\u5747\u68af\u5ea6\ngrad_parallel = jnp.mean(jnp.stack(grads), axis=0)\n\nprint(\"\u5355GPU\u68af\u5ea6\uff1a\", grad_single)\nprint(\"\u6570\u636e\u5e76\u884c\u68af\u5ea6\uff08\u5e73\u5747\uff09\uff1a\", grad_parallel)\nprint(f\"\u5339\u914d\uff1a{jnp.allclose(grad_single, grad_parallel, atol=1e-5)}\")\n\n# \u8bad\u7ec3\u4e24\u8005\u5e76\u6bd4\u8f83\nw_single, w_parallel = jnp.zeros(4), jnp.zeros(4)\nlr = 0.1\nfor step in range(100):\n    w_single = w_single - lr * grad_fn(w_single, X, y)\n\n    grads = [grad_fn(w_parallel, X[i*chunk_size:(i+1)*chunk_size],\n                     y[i*chunk_size:(i+1)*chunk_size]) for i in range(n_gpus)]\n    avg_grad = jnp.mean(jnp.stack(grads), axis=0)\n    w_parallel = w_parallel - lr * avg_grad\n\nprint(f\"\\n100\u6b65\u4e4b\u540e\uff1a\")\nprint(f\"\u5355GPU\u6743\u91cd\uff1a{w_single}\")\nprint(f\"\u6570\u636e\u5e76\u884c\u6743\u91cd\uff1a{w_parallel}\")\nprint(f\"\u6700\u5927\u5dee\u5f02\uff1a{jnp.max(jnp.abs(w_single - w_parallel)):.2e}\")\n

  3. \u5b9e\u73b0\u4e00\u4e2a\u7b80\u5355\u7684\u6df7\u5408\u4e13\u5bb6\u5c42\u3002\u521b\u5efa\u4e00\u4e2a\u95e8\u63a7\u7f51\u7edc\uff0c\u5c06token\u8def\u7531\u5230top-K\u4e2a\u4e13\u5bb6\u5e76\u7ec4\u5408\u5b83\u4eec\u7684\u8f93\u51fa\u3002

    import jax\nimport jax.numpy as jnp\n\ndef expert_fn(x, W1, b1, W2, b2):\n    \"\"\"\u7b80\u5355\u76842\u5c42FFN\u4e13\u5bb6\u3002\"\"\"\n    h = jnp.maximum(0, x @ W1 + b1)  # ReLU\n    return h @ W2 + b2\n\ndef moe_layer(x, gate_W, experts_params, top_k=2):\n    \"\"\"\n    MoE\u524d\u5411\u4f20\u64ad\u3002\n    x: (batch, d_model)\n    gate_W: (d_model, n_experts)\n    experts_params: \u6bcf\u4e2a\u4e13\u5bb6\u7684 (W1, b1, W2, b2) \u5217\u8868\n    \"\"\"\n    n_experts = len(experts_params)\n\n    # \u95e8\u63a7\uff1a\u8ba1\u7b97\u8def\u7531\u5206\u6570\n    gate_logits = x @ gate_W  # (batch, n_experts)\n    gate_probs = jax.nn.softmax(gate_logits, axis=-1)\n\n    # Top-K\u9009\u62e9\n    top_k_indices = jnp.argsort(-gate_probs, axis=-1)[:, :top_k]\n    top_k_probs = jnp.take_along_axis(gate_probs, top_k_indices, axis=-1)\n    # \u91cd\u65b0\u5f52\u4e00\u5316\n    top_k_probs = top_k_probs / jnp.sum(top_k_probs, axis=-1, keepdims=True)\n\n    # \u8ba1\u7b97\u4e13\u5bb6\u8f93\u51fa\uff08\u7b80\u5316\uff1a\u8fd0\u884c\u6240\u6709\u4e13\u5bb6\uff0c\u7a0d\u540e\u63a9\u7801\uff09\n    expert_outputs = jnp.stack([\n        expert_fn(x, *experts_params[i]) for i in range(n_experts)\n    ], axis=1)  # (batch, n_experts, d_model)\n\n    # \u6536\u96c6top-K\u4e13\u5bb6\u8f93\u51fa\u5e76\u52a0\u6743\n    batch_idx = jnp.arange(x.shape[0])[:, None]\n    selected_outputs = expert_outputs[batch_idx, top_k_indices]  # (batch, top_k, d_model)\n    output = jnp.sum(selected_outputs * top_k_probs[:, :, None], axis=1)\n\n    return output, gate_probs\n\n# \u8bbe\u7f6e\nkey = jax.random.PRNGKey(42)\nbatch, d_model, d_ff, n_experts = 8, 16, 32, 4\n\n# \u521d\u59cb\u5316\u4e13\u5bb6\nexperts_params = []\nfor i in range(n_experts):\n    k1, k2, key = jax.random.split(key, 3)[0], jax.random.split(key, 3)[1], jax.random.split(key, 3)[2]\n    experts_params.append((\n        jax.random.normal(k1, (d_model, d_ff)) * 0.1,\n        jnp.zeros(d_ff),\n        jax.random.normal(k2, 
(d_ff, d_model)) * 0.1,\n        jnp.zeros(d_model),\n    ))\n\nkey, subkey = jax.random.split(key)\ngate_W = jax.random.normal(subkey, (d_model, n_experts)) * 0.1\nx = jax.random.normal(key, (batch, d_model))\n\noutput, gate_probs = moe_layer(x, gate_W, experts_params, top_k=2)\n\nprint(f\"\u8f93\u5165\u5f62\u72b6\uff1a{x.shape}\")\nprint(f\"\u8f93\u51fa\u5f62\u72b6\uff1a{output.shape}\")\nprint(f\"\u95e8\u63a7\u6982\u7387\uff08\u7b2c\u4e00\u4e2a\u6837\u672c\uff09\uff1a{gate_probs[0]}\")\nprint(f\"\u4e13\u5bb6\u4f7f\u7528\u7387\uff08\u6279\u91cf\u5e73\u5747\uff09\uff1a\")\nfor i in range(n_experts):\n    usage = jnp.mean(gate_probs[:, i])\n    print(f\"  \u4e13\u5bb6 {i}: {usage:.3f}\")\n

"},{"location":"chapter%2007%3A%20computational%20linguistics/01.%20linguistic%20foundations/","title":"\u8bed\u8a00\u5b66\u57fa\u7840","text":"

\u8bed\u8a00\u5b66\u4e3aNLP\u7cfb\u7edf\u63d0\u4f9b\u4e86\u5b83\u4eec\u9690\u5f0f\u5b66\u4e60\u5e76\u5229\u7528\u7684\u7ed3\u6784\u5316\u8bcd\u6c47\u3002\u672c\u6587\u6db5\u76d6\u5f62\u6001\u5b66\u3001\u53e5\u6cd5\u5b66\u3001\u8bed\u4e49\u5b66\u3001\u8bed\u7528\u5b66\u3001\u97f3\u7cfb\u5b66\u3001\u6210\u5206\u53e5\u6cd5\u548c\u4f9d\u5b58\u53e5\u6cd5\u5206\u6790\uff0c\u4ee5\u53ca\u5206\u5e03\u5047\u8bbe\u2014\u2014\u8fd9\u4e9b\u4eba\u7c7b\u8bed\u8a00\u79d1\u5b66\u6784\u6210\u4e86AI\u4e2d\u8bcd\u5143\u5316\u3001\u8bed\u6cd5\u548c\u610f\u4e49\u7684\u57fa\u7840\u3002

S  \u2192 NP VP\nNP \u2192 Det N\nNP \u2192 Det N PP\nVP \u2192 V NP\nVP \u2192 V PP\nPP \u2192 P NP\nDet \u2192 \"the\" | \"a\"\nN  \u2192 \"cat\" | \"mat\" | \"dog\"\nV  \u2192 \"sat\" | \"chased\"\nP  \u2192 \"on\" | \"under\"\n

"},{"location":"chapter%2007%3A%20computational%20linguistics/01.%20linguistic%20foundations/#colabnotebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u6784\u5efa\u4e00\u4e2a\u7b80\u5355\u7684\u5f62\u6001\u5206\u6790\u5668\uff0c\u4f7f\u7528\u5e38\u89c1\u524d\u7f00\u548c\u540e\u7f00\u5217\u8868\u5c06\u82f1\u8bed\u5355\u8bcd\u5206\u89e3\u4e3a\u53ef\u80fd\u7684\u8bed\u7d20\u3002

    prefixes = ['un', 're', 'pre', 'dis', 'mis', 'over', 'under', 'out', 'non']\nsuffixes = ['ing', 'ed', 'ly', 'ness', 'ment', 'tion', 'able', 'ible', 'er', 'est', 'ful', 'less', 'ous']\n\ndef analyse_morphemes(word):\n    \"\"\"\u4f7f\u7528\u5df2\u77e5\u8bcd\u7f00\u8fdb\u884c\u7b80\u5355\u7684\u8bed\u7d20\u5206\u6790\u3002\"\"\"\n    parts = []\n    remaining = word.lower()\n\n    # \u68c0\u67e5\u524d\u7f00\n    for p in sorted(prefixes, key=len, reverse=True):\n        if remaining.startswith(p) and len(remaining) > len(p) + 2:\n            parts.append(f\"[prefix: {p}]\")\n            remaining = remaining[len(p):]\n            break\n\n    # \u68c0\u67e5\u540e\u7f00\n    for s in sorted(suffixes, key=len, reverse=True):\n        if remaining.endswith(s) and len(remaining) > len(s) + 2:\n            root = remaining[:-len(s)]\n            parts.append(f\"[root: {root}]\")\n            parts.append(f\"[suffix: {s}]\")\n            remaining = None\n            break\n\n    if remaining is not None:\n        parts.append(f\"[root: {remaining}]\")\n\n    return parts\n\nfor word in ['unhappiness', 'reusable', 'disconnected', 'overreacting', 'kindness']:\n    print(f\"{word:20s} \u2192 {' + '.join(analyse_morphemes(word))}\")\n

  2. \u5b9e\u73b0\u4e00\u4e2a\u4f7f\u7528\u9012\u5f52\u4e0b\u964d\u6cd5\u7684\u7b80\u5355\u4e0a\u4e0b\u6587\u65e0\u5173\u6587\u6cd5\u5206\u6790\u5668\u3002\u5b9a\u4e49\u4e00\u4e2a\u5c0f\u578b\u6587\u6cd5\uff0c\u5e76\u5c06\u53e5\u5b50\u5206\u6790\u4e3a\u6210\u5206\u6811\u3002

    class CFGParser:\n    \"\"\"\u7528\u4e8e\u5c0f\u578b\u82f1\u8bed\u6587\u6cd5\u7684\u9012\u5f52\u4e0b\u964d\u5206\u6790\u5668\u3002\"\"\"\n    def __init__(self, tokens):\n        self.tokens = tokens\n        self.pos = 0\n\n    def peek(self):\n        return self.tokens[self.pos] if self.pos < len(self.tokens) else None\n\n    def consume(self, expected=None):\n        tok = self.peek()\n        if expected and tok != expected:\n            return None\n        self.pos += 1\n        return tok\n\n    def parse_det(self):\n        if self.peek() in ('the', 'a'):\n            return ('Det', self.consume())\n        return None\n\n    def parse_noun(self):\n        if self.peek() in ('cat', 'dog', 'mat', 'man'):\n            return ('N', self.consume())\n        return None\n\n    def parse_verb(self):\n        if self.peek() in ('sat', 'chased', 'saw'):\n            return ('V', self.consume())\n        return None\n\n    def parse_prep(self):\n        if self.peek() in ('on', 'under', 'with'):\n            return ('P', self.consume())\n        return None\n\n    def parse_np(self):\n        save = self.pos\n        det = self.parse_det()\n        noun = self.parse_noun()\n        if det and noun:\n            # \u68c0\u67e5\u53ef\u9009\u7684PP\n            pp = self.parse_pp()\n            if pp:\n                return ('NP', det, noun, pp)\n            return ('NP', det, noun)\n        self.pos = save\n        return None\n\n    def parse_pp(self):\n        save = self.pos\n        prep = self.parse_prep()\n        np = self.parse_np()\n        if prep and np:\n            return ('PP', prep, np)\n        self.pos = save\n        return None\n\n    def parse_vp(self):\n        save = self.pos\n        verb = self.parse_verb()\n        if verb:\n            np = self.parse_np()\n            if np:\n                return ('VP', verb, np)\n            pp = self.parse_pp()\n            if pp:\n                return ('VP', verb, pp)\n        self.pos = save\n  
      return None\n\n    def parse_sentence(self):\n        np = self.parse_np()\n        vp = self.parse_vp()\n        if np and vp and self.pos == len(self.tokens):\n            return ('S', np, vp)\n        return None\n\ndef print_tree(tree, indent=0):\n    if isinstance(tree, str):\n        print(' ' * indent + tree)\n    elif isinstance(tree, tuple):\n        print(' ' * indent + tree[0])\n        for child in tree[1:]:\n            print_tree(child, indent + 2)\n\nsentences = [\n    \"the cat sat on the mat\",\n    \"a dog chased the cat\",\n]\n\nfor sent in sentences:\n    tokens = sent.split()\n    parser = CFGParser(tokens)\n    tree = parser.parse_sentence()\n    print(f\"\\n'{sent}':\")\n    if tree:\n        print_tree(tree)\n    else:\n        print(\"  (no parse found)\")\n

  3. \u901a\u8fc7\u6784\u5efa\u4e00\u4e2a\u7b80\u5355\u7684\u8bcd\u56fe\u6765\u63a2\u7d22\u8bcd\u6c47\u5173\u7cfb\u3002\u7ed9\u5b9a\u4e00\u4e2a\u5305\u542b\u540c\u4e49\u3001\u53cd\u4e49\u548c\u4e0a\u4f4d\u5173\u7cfb\u7684\u5c0f\u578b\u8bcd\u6c47\u8868\uff0c\u67e5\u627e\u5355\u8bcd\u4e4b\u95f4\u7684\u8def\u5f84\u3002

    relations = {\n    ('big', 'large'): 'synonym',\n    ('big', 'small'): 'antonym',\n    ('small', 'tiny'): 'synonym',\n    ('dog', 'animal'): 'hypernym',\n    ('cat', 'animal'): 'hypernym',\n    ('puppy', 'dog'): 'hypernym',\n    ('happy', 'glad'): 'synonym',\n    ('happy', 'sad'): 'antonym',\n    ('hot', 'cold'): 'antonym',\n    ('hot', 'warm'): 'synonym',\n}\n\n# \u6784\u5efa\u90bb\u63a5\u5217\u8868\nfrom collections import defaultdict, deque\n\ngraph = defaultdict(list)\nfor (w1, w2), rel in relations.items():\n    graph[w1].append((w2, rel))\n    graph[w2].append((w1, rel))\n\ndef find_path(start, end):\n    \"\"\"\u4f7f\u7528BFS\u5728\u5173\u7cfb\u56fe\u4e2d\u67e5\u627e\u4e24\u4e2a\u5355\u8bcd\u4e4b\u95f4\u7684\u8def\u5f84\u3002\"\"\"\n    queue = deque([(start, [(start, None)])])\n    visited = {start}\n    while queue:\n        node, path = queue.popleft()\n        if node == end:\n            return path\n        for neighbor, rel in graph[node]:\n            if neighbor not in visited:\n                visited.add(neighbor)\n                queue.append((neighbor, path + [(neighbor, rel)]))\n    return None\n\npairs = [('big', 'tiny'), ('puppy', 'cat'), ('happy', 'sad')]\nfor w1, w2 in pairs:\n    path = find_path(w1, w2)\n    if path:\n        steps = \" \u2192 \".join(f\"{w}({r})\" if r else w for w, r in path)\n        print(f\"{w1} \u2192 {w2}: {steps}\")\n    else:\n        print(f\"{w1} \u2192 {w2}: no path found\")\n

"},{"location":"chapter%2007%3A%20computational%20linguistics/02.%20text%20processing%20and%20classic%20NLP/","title":"\u6587\u672c\u5904\u7406\u4e0e\u7ecf\u5178NLP","text":"

\u6587\u672c\u5904\u7406\u5c06\u539f\u59cb\u5b57\u7b26\u8f6c\u6362\u4e3a\u6a21\u578b\u53ef\u6d88\u8d39\u7684\u7ed3\u6784\u5316\u8868\u793a\u3002\u672c\u6587\u6db5\u76d6\u5206\u8bcd\uff08\u8bcd\u7ea7\u3001\u5b50\u8bcd\u3001BPE\u3001WordPiece\uff09\u3001\u6587\u672c\u89c4\u8303\u5316\u3001\u7f16\u8f91\u8ddd\u79bb\u3001TF-IDF\u3001n\u5143\u7ec4\u8bed\u8a00\u6a21\u578b\u3001\u8bcd\u6027\u6807\u6ce8\u3001\u547d\u540d\u5b9e\u4f53\u8bc6\u522b\u548c\u60c5\u611f\u5206\u6790\u2014\u2014\u8fd9\u4e9b\u7ecf\u5178NLP\u6d41\u6c34\u7ebf\u81f3\u4eca\u4ecd\u662f\u73b0\u4ee3\u7cfb\u7edf\u7684\u57fa\u7840\u3002

\\[ D[i][j] = \\begin{cases} j & \\text{if } i = 0 \\\\ i & \\text{if } j = 0 \\\\ D[i{-}1][j{-}1] & \\text{if } s[i] = t[j] \\\\ 1 + \\min(D[i{-}1][j], \\; D[i][j{-}1], \\; D[i{-}1][j{-}1]) & \\text{otherwise} \\end{cases} \\]

\\[\\hat{t}_{1:n} = \\arg\\max_{t_{1:n}} \\prod_{i=1}^{n} P(w_i \\mid t_i) \\cdot P(t_i \\mid t_{i-1})\\]

\\[P(y_{1:n} \\mid x_{1:n}) = \\frac{1}{Z(x)} \\exp\\!\\left(\\sum_{i=1}^{n} \\left[\\sum_k \\lambda_k f_k(y_i, x, i) + \\sum_j \\mu_j g_j(y_i, y_{i-1}, x, i)\\right]\\right)\\]

\\[\\text{TF-IDF}(t, d) = \\text{TF}(t, d) \\times \\text{IDF}(t)\\] \\[P(w_1, w_2, \\ldots, w_n) = \\prod_{i=1}^{n} P(w_i \\mid w_1, \\ldots, w_{i-1})\\] \\[P(w_i \\mid w_1, \\ldots, w_{i-1}) \\approx P(w_i \\mid w_{i-1})\\] \\[P(w_i \\mid w_{i-1}) = \\frac{\\text{count}(w_{i-1}, w_i)}{\\text{count}(w_{i-1})}\\] \\[\\text{PPL} = P(w_1, \\ldots, w_N)^{-1/N} = \\exp\\!\\left(-\\frac{1}{N} \\sum_{i=1}^{N} \\log P(w_i \\mid w_{<i})\\right)\\] \\[P_{\\text{Laplace}}(w_i \\mid w_{i-1}) = \\frac{\\text{count}(w_{i-1}, w_i) + 1}{\\text{count}(w_{i-1}) + V}\\] \\[P_{\\text{KN}}(w_i \\mid w_{i-1}) = \\frac{\\max(\\text{count}(w_{i-1}, w_i) - d, \\; 0)}{\\text{count}(w_{i-1})} + \\lambda(w_{i-1}) \\cdot P_{\\text{cont}}(w_i)\\] \\[P_{\\text{cont}}(w_i) = \\frac{|\\{w' : \\text{count}(w', w_i) > 0\\}|}{|\\{(w', w'') : \\text{count}(w', w'') > 0\\}|}\\] "},{"location":"chapter%2007%3A%20computational%20linguistics/02.%20text%20processing%20and%20classic%20NLP/#colabnotebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u4f7f\u7528\u52a8\u6001\u89c4\u5212\u5b9e\u73b0\u83b1\u6587\u65af\u5766\u7f16\u8f91\u8ddd\u79bb\u3002\u5728\u8bcd\u5bf9\u4e0a\u6d4b\u8bd5\uff0c\u5e76\u7528\u4e8e\u7b80\u5355\u7684\u62fc\u5199\u7ea0\u6b63\u3002

    import jax.numpy as jnp\n\ndef edit_distance(s, t):\n    \"\"\"Compute Levenshtein edit distance using DP.\"\"\"\n    m, n = len(s), len(t)\n    D = [[0] * (n + 1) for _ in range(m + 1)]\n\n    for i in range(m + 1):\n        D[i][0] = i\n    for j in range(n + 1):\n        D[0][j] = j\n\n    for i in range(1, m + 1):\n        for j in range(1, n + 1):\n            if s[i-1] == t[j-1]:\n                D[i][j] = D[i-1][j-1]\n            else:\n                D[i][j] = 1 + min(D[i-1][j], D[i][j-1], D[i-1][j-1])\n\n    return D[m][n]\n\n# Test\npairs = [(\"kitten\", \"sitting\"), (\"sunday\", \"saturday\"), (\"hello\", \"hallo\")]\nfor s, t in pairs:\n    print(f\"d('{s}', '{t}') = {edit_distance(s, t)}\")\n\n# Simple spelling correction\ndictionary = [\"the\", \"their\", \"there\", \"then\", \"than\", \"this\", \"that\", \"these\", \"those\"]\nmisspelled = \"thier\"\ncorrections = sorted(dictionary, key=lambda w: edit_distance(misspelled, w))\nprint(f\"\\nClosest to '{misspelled}': {corrections[:3]}\")\n

  2. \u4ece\u5934\u5b9e\u73b0BPE\u5206\u8bcd\u3002\u4ece\u5b57\u7b26\u7ea7\u8bcd\u5143\u5f00\u59cb\uff0c\u8fed\u4ee3\u5730\u5408\u5e76\u6700\u9891\u7e41\u7684\u5bf9\u3002

    from collections import Counter\n\ndef get_pairs(corpus):\n    \"\"\"Count adjacent token pairs across all words.\"\"\"\n    pairs = Counter()\n    for word, freq in corpus.items():\n        symbols = word.split()\n        for i in range(len(symbols) - 1):\n            pairs[(symbols[i], symbols[i+1])] += freq\n    return pairs\n\ndef merge_pair(pair, corpus):\n    \"\"\"Merge all occurrences of a pair in the corpus.\"\"\"\n    new_corpus = {}\n    bigram = ' '.join(pair)\n    replacement = ''.join(pair)\n    for word, freq in corpus.items():\n        new_word = word.replace(bigram, replacement)\n        new_corpus[new_word] = freq\n    return new_corpus\n\n# Training corpus with word frequencies\ntext = \"low low low low low lower lower newest newest newest newest newest newest\"\nword_freqs = Counter(text.split())\n# Initialise: split each word into characters with end-of-word marker\ncorpus = {' '.join(word) + ' _': freq for word, freq in word_freqs.items()}\n\nprint(\"Initial corpus:\")\nfor word, freq in corpus.items():\n    print(f\"  {word}: {freq}\")\n\n# Run BPE for 10 merges\nfor i in range(10):\n    pairs = get_pairs(corpus)\n    if not pairs:\n        break\n    best_pair = max(pairs, key=pairs.get)\n    corpus = merge_pair(best_pair, corpus)\n    print(f\"\\nMerge {i+1}: {best_pair} (freq={pairs[best_pair]})\")\n    for word, freq in corpus.items():\n        print(f\"  {word}: {freq}\")\n

  3. \u6784\u5efa\u4e00\u4e2a\u4e8c\u5143\u8bed\u8a00\u6a21\u578b\uff0c\u5e76\u8ba1\u7b97\u6d4b\u8bd5\u53e5\u5b50\u7684\u56f0\u60d1\u5ea6\u3002\u5c1d\u8bd5\u62c9\u666e\u62c9\u65af\u5e73\u6ed1\u3002

    from collections import Counter, defaultdict\nimport math\n\n# Training corpus\ntrain = \"\"\"the cat sat on the mat . the dog chased the cat .\nthe cat ran from the dog . a dog sat on a mat .\"\"\".split()\n\n# Count bigrams and unigrams\nbigrams = Counter(zip(train[:-1], train[1:]))\nunigrams = Counter(train)\nvocab_size = len(set(train))\n\ndef bigram_prob(w2, w1, alpha=0):\n    \"\"\"P(w2 | w1) with optional Laplace smoothing.\"\"\"\n    return (bigrams[(w1, w2)] + alpha) / (unigrams[w1] + alpha * vocab_size)\n\n# Compute perplexity\ntest = \"the cat sat on a mat .\".split()\n\nfor alpha in [0, 1, 0.1]:\n    log_prob = 0\n    for w1, w2 in zip(test[:-1], test[1:]):\n        p = bigram_prob(w2, w1, alpha=alpha)\n        if p > 0:\n            log_prob += math.log(p)\n        else:\n            log_prob += float('-inf')\n\n    ppl = math.exp(-log_prob / (len(test) - 1)) if log_prob > float('-inf') else float('inf')\n    print(f\"Smoothing \u03b1={alpha}: perplexity = {ppl:.2f}\")\n

  4. \u4ece\u5934\u5b9e\u73b0TF-IDF\uff0c\u5e76\u4f7f\u7528\u4f59\u5f26\u76f8\u4f3c\u5ea6\u627e\u5230\u4e0e\u67e5\u8be2\u6700\u76f8\u4f3c\u7684\u6587\u6863\u3002

    import jax.numpy as jnp\nimport math\nfrom collections import Counter\n\ndocuments = [\n    \"the cat sat on the mat\",\n    \"the dog chased the cat around the park\",\n    \"a mat was placed on the floor by the door\",\n    \"the quick brown fox jumped over the lazy dog\",\n]\n\n# Build vocabulary\nvocab = sorted(set(word for doc in documents for word in doc.split()))\nword_to_idx = {w: i for i, w in enumerate(vocab)}\nV = len(vocab)\nN = len(documents)\n\n# Compute TF-IDF matrix\ndoc_freq = Counter()\nfor doc in documents:\n    for word in set(doc.split()):\n        doc_freq[word] += 1\n\ntfidf_matrix = jnp.zeros((N, V))\nfor i, doc in enumerate(documents):\n    word_counts = Counter(doc.split())\n    for word, count in word_counts.items():\n        tf = 1 + math.log(count)\n        idf = math.log(N / doc_freq[word])\n        j = word_to_idx[word]\n        tfidf_matrix = tfidf_matrix.at[i, j].set(tf * idf)\n\n# Query\nquery = \"cat on the mat\"\nquery_vec = jnp.zeros(V)\nquery_counts = Counter(query.split())\nfor word, count in query_counts.items():\n    if word in word_to_idx:\n        tf = 1 + math.log(count)\n        idf = math.log(N / doc_freq.get(word, 1))\n        query_vec = query_vec.at[word_to_idx[word]].set(tf * idf)\n\n# Cosine similarity (from chapter 01)\ndef cosine_sim(a, b):\n    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + 1e-8)\n\nprint(f\"Query: '{query}'\\n\")\nfor i, doc in enumerate(documents):\n    sim = cosine_sim(query_vec, tfidf_matrix[i])\n    print(f\"  Doc {i} (sim={sim:.3f}): '{doc}'\")\n

"},{"location":"chapter%2007%3A%20computational%20linguistics/03.%20embeddings%20and%20sequence%20models/","title":"\u5d4c\u5165\u4e0e\u5e8f\u5217\u6a21\u578b","text":"

\u8bcd\u5d4c\u5165\u5c06\u7a00\u758f\u7684\u7b26\u53f7\u5316\u6587\u672c\u538b\u7f29\u5230\u7a20\u5bc6\u5411\u91cf\u7a7a\u95f4\u4e2d\uff0c\u4f7f\u5f97\u8bed\u4e49\u76f8\u4f3c\u6027\u8f6c\u5316\u4e3a\u51e0\u4f55\u90bb\u8fd1\u6027\u3002\u672c\u6587\u6db5\u76d6 Word2Vec\uff08CBOW\u3001Skip-gram\uff09\u3001GloVe\u3001FastText\u3001RNN\u3001LSTM\u3001GRU\u3001\u5e26\u6ce8\u610f\u529b\u673a\u5236\u7684 seq2seq\u3001\u7f16\u7801\u5668-\u89e3\u7801\u5668\u8303\u5f0f\uff0c\u4ee5\u53ca\u4ece\u8bcd\u888b\u6a21\u578b\u5230\u4e0a\u4e0b\u6587\u8868\u793a\u7684\u53d1\u5c55\u5386\u7a0b\u3002

\\[P(w_t \\mid w_{t-k}, \\ldots, w_{t-1}, w_{t+1}, \\ldots, w_{t+k})\\] \\[P(w_{t+j} \\mid w_t) \\quad \\text{\u5bf9\u4e8e\u6bcf\u4e2a } j \\in [-k, k], \\; j \\neq 0\\]

\\[\\mathcal{L} = \\log \\sigma(v_{w_O}^T v_{w_I}) + \\sum_{i=1}^{k} \\mathbb{E}_{w_i \\sim P_n} [\\log \\sigma(-v_{w_i}^T v_{w_I})]\\] \\[v_w^T v_c \\approx \\text{PMI}(w, c) - \\log k\\] \\[w_i^T \\tilde{w}_j + b_i + \\tilde{b}_j = \\log X_{ij}\\] \\[\\mathcal{L} = \\sum_{i,j=1}^{V} f(X_{ij}) \\left(w_i^T \\tilde{w}_j + b_i + \\tilde{b}_j - \\log X_{ij}\\right)^2\\]

\\[e_{ti} = v^T \\tanh(W_s s_{t-1} + W_h h_i)\\] \\[\\alpha_{ti} = \\frac{\\exp(e_{ti})}{\\sum_j \\exp(e_{tj})}, \\quad c_t = \\sum_i \\alpha_{ti} h_i\\]

\\[\\text{score}(y) = \\frac{1}{|y|^\\alpha} \\sum_{t=1}^{|y|} \\log P(y_t \\mid y_{<t})\\]

\\[\\text{ELMo}_k = \\gamma \\sum_{j=0}^{L} s_j \\, h_{k,j}\\] "},{"location":"chapter%2007%3A%20computational%20linguistics/03.%20embeddings%20and%20sequence%20models/#colab-notebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u4ece\u5934\u5b9e\u73b0\u5e26\u8d1f\u91c7\u6837\u7684 Word2Vec skip-gram\u3002\u5728\u5c0f\u578b\u8bed\u6599\u5e93\u4e0a\u8bad\u7ec3\uff0c\u5e76\u4f7f\u7528 PCA \u53ef\u89c6\u5316\u5b66\u4e60\u5230\u7684\u5d4c\u5165\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u5c0f\u578b\u8bed\u6599\u5e93\ncorpus = \"\"\"the king ruled the kingdom . the queen ruled the kingdom .\nthe prince is the son of the king . the princess is the daughter of the queen .\na man worked in the castle . a woman worked in the castle .\nthe king and queen lived in the castle . the prince and princess played outside .\"\"\".lower().split()\n\nvocab = sorted(set(corpus))\nword2idx = {w: i for i, w in enumerate(vocab)}\nidx2word = {i: w for w, i in word2idx.items()}\nV = len(vocab)\n\n# \u751f\u6210 skip-gram \u5bf9\uff0c\u7a97\u53e3\u5927\u5c0f\u4e3a 2\nwindow = 2\npairs = []\nfor i, word in enumerate(corpus):\n    for j in range(max(0, i - window), min(len(corpus), i + window + 1)):\n        if i != j:\n            pairs.append((word2idx[word], word2idx[corpus[j]]))\n\npairs = jnp.array(pairs)\nprint(f\"\u8bcd\u6c47\u8868\u5927\u5c0f: {V} \u4e2a\u8bcd, \u8bad\u7ec3\u6837\u672c\u6570: {len(pairs)}\")\n\n# \u6a21\u578b\u53c2\u6570\nembed_dim = 16\nkey = jax.random.PRNGKey(42)\nk1, k2 = jax.random.split(key)\nW_in = jax.random.normal(k1, (V, embed_dim)) * 0.1    # \u8f93\u5165\u5d4c\u5165\nW_out = jax.random.normal(k2, (V, embed_dim)) * 0.1   # \u8f93\u51fa\u5d4c\u5165\n\n# \u5355\u4e2a\u6837\u672c\u5bf9\u7684\u8d1f\u91c7\u6837\u635f\u5931\ndef neg_sampling_loss(W_in, W_out, target, context, neg_ids):\n    v_in = W_in[target]      # (embed_dim,)\n    v_out = W_out[context]   # (embed_dim,)\n    v_neg = W_out[neg_ids]   # (k, embed_dim)\n\n    pos_loss = -jax.nn.log_sigmoid(jnp.dot(v_in, v_out))\n    neg_loss = -jnp.sum(jax.nn.log_sigmoid(-v_neg @ v_in))\n    return pos_loss + neg_loss\n\n# \u8bad\u7ec3\u5faa\u73af\nnum_neg = 5\nlr = 0.05\n\n@jax.jit\ndef train_step(W_in, W_out, target, context, neg_ids):\n    loss, (g_in, g_out) = jax.value_and_grad(neg_sampling_loss, argnums=(0, 1))(\n        W_in, W_out, target, context, neg_ids)\n    return loss, W_in - lr * g_in, W_out - lr * 
g_out\n\nkey = jax.random.PRNGKey(0)\nfor epoch in range(50):\n    total_loss = 0.0\n    for i in range(len(pairs)):\n        key, subkey = jax.random.split(key)\n        neg_ids = jax.random.randint(subkey, (num_neg,), 0, V)\n        loss, W_in, W_out = train_step(W_in, W_out, pairs[i, 0], pairs[i, 1], neg_ids)\n        total_loss += loss\n    if (epoch + 1) % 10 == 0:\n        print(f\"Epoch {epoch+1}: avg loss = {total_loss / len(pairs):.4f}\")\n\n# \u4f7f\u7528 PCA \u53ef\u89c6\u5316\uff08\u7b2c 01 \u7ae0\uff09\nembeddings = W_in\nmean = embeddings.mean(axis=0)\ncentered = embeddings - mean\nU, S, Vt = jnp.linalg.svd(centered, full_matrices=False)\ncoords = centered @ Vt[:2].T  # \u6295\u5f71\u5230\u524d\u4e24\u4e2a\u4e3b\u6210\u5206\n\nplt.figure(figsize=(10, 8))\nfor i, word in idx2word.items():\n    plt.scatter(coords[i, 0], coords[i, 1], c='#3498db', s=40)\n    plt.annotate(word, (coords[i, 0] + 0.02, coords[i, 1] + 0.02), fontsize=9)\nplt.title(\"Word2Vec Skip-gram \u5d4c\u5165\uff08PCA \u6295\u5f71\uff09\")\nplt.grid(alpha=0.3); plt.show()\n

  2. \u6784\u5efa\u4e00\u4e2a\u5b57\u7b26\u7ea7 RNN \u8bed\u8a00\u6a21\u578b\uff0c\u4ece\u4e00\u5c0f\u6bb5\u8bad\u7ec3\u6587\u672c\u4e2d\u5b66\u4e60\u751f\u6210\u6587\u672c\u3002

    import jax\nimport jax.numpy as jnp\n\n# \u5c0f\u578b\u8bad\u7ec3\u6587\u672c\ntext = \"to be or not to be that is the question \"\nchars = sorted(set(text))\nchar2idx = {c: i for i, c in enumerate(chars)}\nidx2char = {i: c for c, i in char2idx.items()}\nV = len(chars)\ndata = jnp.array([char2idx[c] for c in text])\n\n# RNN \u53c2\u6570\nhidden_dim = 64\nkey = jax.random.PRNGKey(0)\nk1, k2, k3, k4, k5 = jax.random.split(key, 5)\n\nparams = {\n    'Wx': jax.random.normal(k1, (V, hidden_dim)) * 0.1,\n    'Wh': jax.random.normal(k2, (hidden_dim, hidden_dim)) * 0.05,\n    'bh': jnp.zeros(hidden_dim),\n    'Wy': jax.random.normal(k3, (hidden_dim, V)) * 0.1,\n    'by': jnp.zeros(V),\n}\n\ndef rnn_step(params, h, x_idx):\n    x = jnp.eye(V)[x_idx]  # one-hot \u7f16\u7801\n    h = jnp.tanh(x @ params['Wx'] + h @ params['Wh'] + params['bh'])\n    logits = h @ params['Wy'] + params['by']\n    return h, logits\n\ndef loss_fn(params, inputs, targets):\n    h = jnp.zeros(hidden_dim)\n    total_loss = 0.0\n    for t in range(len(inputs)):\n        h, logits = rnn_step(params, h, inputs[t])\n        log_probs = jax.nn.log_softmax(logits)\n        total_loss -= log_probs[targets[t]]\n    return total_loss / len(inputs)\n\ngrad_fn = jax.jit(jax.grad(loss_fn))\n\n# \u8bad\u7ec3\ninputs = data[:-1]\ntargets = data[1:]\nlr = 0.01\n\nfor step in range(500):\n    grads = grad_fn(params, inputs, targets)\n    params = {k: params[k] - lr * grads[k] for k in params}\n    if (step + 1) % 100 == 0:\n        l = loss_fn(params, inputs, targets)\n        print(f\"Step {step+1}: loss = {l:.4f}\")\n\n# \u751f\u6210\u6587\u672c\ndef generate(params, seed_char, length=60):\n    h = jnp.zeros(hidden_dim)\n    idx = char2idx[seed_char]\n    result = [seed_char]\n    key = jax.random.PRNGKey(42)\n    for _ in range(length):\n        h, logits = rnn_step(params, h, idx)\n        key, subkey = jax.random.split(key)\n        idx = jax.random.categorical(subkey, logits)\n        
result.append(idx2char[int(idx)])\n    return ''.join(result)\n\nprint(f\"\\n\u751f\u6210\u6587\u672c: {generate(params, 't')}\")\n

  3. \u5b9e\u73b0\u4e00\u4e2a\u5e26 Bahdanau \u6ce8\u610f\u529b\u7684\u7b80\u6613 seq2seq \u6a21\u578b\uff0c\u7528\u4e8e\u5e8f\u5217\u53cd\u8f6c\u3002\u53ef\u89c6\u5316\u6ce8\u610f\u529b\u5bf9\u9f50\u77e9\u9635\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u4efb\u52a1\uff1a\u53cd\u8f6c\u6570\u5b57\u5e8f\u5217\uff08\u4f8b\u5982\uff0c[3, 1, 4] -> [4, 1, 3]\uff09\nvocab_size = 10  # \u6570\u5b57 0-9\nSOS, EOS = 10, 11  # \u7279\u6b8a\u8bcd\u5143\ntotal_vocab = 12\nembed_dim, hidden_dim = 16, 32\nmax_len = 5\n\nkey = jax.random.PRNGKey(42)\nkeys = jax.random.split(key, 8)\n\nparams = {\n    'embed': jax.random.normal(keys[0], (total_vocab, embed_dim)) * 0.1,\n    'enc_Wx': jax.random.normal(keys[1], (embed_dim, hidden_dim)) * 0.1,\n    'enc_Wh': jax.random.normal(keys[2], (hidden_dim, hidden_dim)) * 0.05,\n    'dec_Wx': jax.random.normal(keys[3], (embed_dim, hidden_dim)) * 0.1,\n    'dec_Wh': jax.random.normal(keys[4], (hidden_dim, hidden_dim)) * 0.05,\n    # Bahdanau \u6ce8\u610f\u529b\n    'Ws': jax.random.normal(keys[5], (hidden_dim, hidden_dim)) * 0.1,\n    'Wh_att': jax.random.normal(keys[6], (hidden_dim, hidden_dim)) * 0.1,\n    'v_att': jax.random.normal(keys[7], (hidden_dim,)) * 0.1,\n    # \u8f93\u51fa\u6295\u5f71\uff08\u4ece\u9690\u85cf\u72b6\u6001+\u4e0a\u4e0b\u6587\u5230\u8bcd\u6c47\u8868\uff09\n    'Wo': jax.random.normal(keys[0], (hidden_dim * 2, total_vocab)) * 0.1,\n}\n\ndef encode(params, seq):\n    \"\"\"\u7f16\u7801\u8f93\u5165\u5e8f\u5217\uff0c\u8fd4\u56de\u6240\u6709\u9690\u85cf\u72b6\u6001\u3002\"\"\"\n    h = jnp.zeros(hidden_dim)\n    states = []\n    for t in range(len(seq)):\n        x = params['embed'][seq[t]]\n        h = jnp.tanh(x @ params['enc_Wx'] + h @ params['enc_Wh'])\n        states.append(h)\n    return jnp.stack(states), h\n\ndef bahdanau_attention(params, dec_state, enc_states):\n    \"\"\"\u8ba1\u7b97 Bahdanau \u6ce8\u610f\u529b\u6743\u91cd\u548c\u4e0a\u4e0b\u6587\u5411\u91cf\u3002\"\"\"\n    scores = jnp.tanh(enc_states @ params['Wh_att'] + dec_state @ params['Ws'])\n    e = scores @ params['v_att']  # (src_len,)\n    alpha = jax.nn.softmax(e)\n    context = alpha @ enc_states\n    return context, 
alpha\n\ndef decode_step(params, dec_h, prev_token, enc_states):\n    x = params['embed'][prev_token]\n    dec_h = jnp.tanh(x @ params['dec_Wx'] + dec_h @ params['dec_Wh'])\n    context, alpha = bahdanau_attention(params, dec_h, enc_states)\n    combined = jnp.concatenate([dec_h, context])\n    logits = combined @ params['Wo']\n    return dec_h, logits, alpha\n\ndef seq2seq_loss(params, src, tgt):\n    enc_states, enc_final = encode(params, src)\n    dec_h = enc_final\n    loss = 0.0\n    prev_token = SOS\n    for t in range(len(tgt)):\n        dec_h, logits, _ = decode_step(params, dec_h, prev_token, enc_states)\n        log_probs = jax.nn.log_softmax(logits)\n        loss -= log_probs[tgt[t]]\n        prev_token = tgt[t]\n    return loss / len(tgt)\n\n# \u751f\u6210\u8bad\u7ec3\u6570\u636e\uff1a\u53cd\u8f6c\u5e8f\u5217\nkey = jax.random.PRNGKey(0)\ntrain_srcs, train_tgts = [], []\nfor _ in range(200):\n    key, subkey = jax.random.split(key)\n    length = jax.random.randint(subkey, (), 3, max_len + 1)\n    key, subkey = jax.random.split(key)\n    seq = jax.random.randint(subkey, (int(length),), 0, vocab_size)\n    train_srcs.append(seq)\n    train_tgts.append(seq[::-1])  # \u53cd\u8f6c\n\n# \u8bad\u7ec3\ngrad_fn = jax.grad(seq2seq_loss)\nlr = 0.01\n\nfor epoch in range(100):\n    total_loss = 0.0\n    for src, tgt in zip(train_srcs, train_tgts):\n        grads = grad_fn(params, src, tgt)\n        params = {k: params[k] - lr * grads[k] for k in params}\n        total_loss += seq2seq_loss(params, src, tgt)\n    if (epoch + 1) % 20 == 0:\n        print(f\"Epoch {epoch+1}: avg loss = {total_loss / len(train_srcs):.4f}\")\n\n# \u53ef\u89c6\u5316\u4e00\u4e2a\u793a\u4f8b\u7684\u6ce8\u610f\u529b\ntest_src = jnp.array([3, 1, 4, 1, 5])\ntest_tgt = test_src[::-1]\n\nenc_states, enc_final = encode(params, test_src)\ndec_h = enc_final\nattentions = []\nprev_token = SOS\nfor t in range(len(test_tgt)):\n    dec_h, logits, alpha = decode_step(params, dec_h, prev_token, 
enc_states)\n    attentions.append(alpha)\n    prev_token = test_tgt[t]\n\natt_matrix = jnp.stack(attentions)\nfig, ax = plt.subplots(figsize=(6, 5))\nim = ax.imshow(att_matrix, cmap='Blues')\nax.set_xlabel(\"\u6e90\u4f4d\u7f6e\"); ax.set_ylabel(\"\u76ee\u6807\u4f4d\u7f6e\")\nsrc_labels = [str(int(x)) for x in test_src]\ntgt_labels = [str(int(x)) for x in test_tgt]\nax.set_xticks(range(len(src_labels))); ax.set_xticklabels(src_labels)\nax.set_yticks(range(len(tgt_labels))); ax.set_yticklabels(tgt_labels)\nfor i in range(len(tgt_labels)):\n    for j in range(len(src_labels)):\n        ax.text(j, i, f\"{att_matrix[i,j]:.2f}\", ha='center', va='center', fontsize=9)\nax.set_title(\"Bahdanau \u6ce8\u610f\u529b\u5bf9\u9f50\uff08\u5e8f\u5217\u53cd\u8f6c\uff09\")\nplt.colorbar(im); plt.tight_layout(); plt.show()\n

"},{"location":"chapter%2007%3A%20computational%20linguistics/04.%20transformers%20and%20language%20models/","title":"Transformer\u4e0e\u8bed\u8a00\u6a21\u578b","text":"

Transformer\u7528\u81ea\u6ce8\u610f\u529b\u53d6\u4ee3\u4e86\u5faa\u73af\u7ed3\u6784\uff0c\u6210\u4e3a\u8bed\u8a00\u7406\u89e3\u548c\u751f\u6210\u7684\u4e3b\u5bfc\u67b6\u6784\u3002\u672c\u6587\u4ef6\u6db5\u76d6BERT\u3001GPT\u3001T5\u3001\u4f4d\u7f6e\u7f16\u7801\uff08\u6b63\u5f26\u7f16\u7801\u3001RoPE\uff09\u3001\u9884\u8bad\u7ec3\u76ee\u6807\uff08MLM\u3001CLM\uff09\u3001\u5fae\u8c03\u3001\u63d0\u793a\u5de5\u7a0b\u548c\u7f29\u653e\u5b9a\u5f8b\u2014\u2014\u8fd9\u4e9b\u662f\u73b0\u4ee3\u5927\u8bed\u8a00\u6a21\u578b\u80cc\u540e\u7684\u84dd\u56fe\u3002

\\[\\text{FFN}(x) = W_2 \\cdot \\text{GELU}(W_1 x + b_1) + b_2\\] \\[ \\begin{bmatrix} q'_{2i} \\\\ q'_{2i+1} \\end{bmatrix} = \\begin{bmatrix} \\cos m\\theta_i & -\\sin m\\theta_i \\\\ \\sin m\\theta_i & \\cos m\\theta_i \\end{bmatrix} \\begin{bmatrix} q_{2i} \\\\ q_{2i+1} \\end{bmatrix} \\]

\\[q'^T k' = (R_m q)^T (R_n k) = q^T R_m^T R_n \\, k = q^T R_{n-m} \\, k\\]

\\[\\mathcal{L}_{\\text{MLM}} = -\\sum_{i \\in \\mathcal{M}} \\log P(w_i \\mid w_{\\backslash \\mathcal{M}})\\]

\\[\\mathcal{L}_{\\text{CLM}} = -\\sum_{i=1}^{n} \\log P(w_i \\mid w_1, \\ldots, w_{i-1})\\] \\[W' = W + BA\\]

\\[L(N) \\propto N^{-\\alpha_N}, \\quad L(D) \\propto D^{-\\alpha_D}, \\quad L(C) \\propto C^{-\\alpha_C}\\]

\\[N_{\\text{opt}} \\propto C^{0.5}, \\quad D_{\\text{opt}} \\propto C^{0.5}\\] \\[G(x) = \\text{TopK}(\\text{softmax}(W_g x))\\]

\\[\\mathcal{L}_{\\text{balance}} = E \\cdot \\sum_{i=1}^{E} f_i \\cdot p_i\\] \\[F_1 = \\frac{2PR}{P + R}\\] \\[\\text{BLEU} = \\text{BP} \\cdot \\exp\\!\\left(\\sum_{n=1}^{N} w_n \\log p_n\\right)\\] \\[R_{\\text{LCS}} = \\frac{\\text{LCS}(X, Y)}{m}, \\quad P_{\\text{LCS}} = \\frac{\\text{LCS}(X, Y)}{n}, \\quad F_{\\text{LCS}} = \\frac{(1 + \\beta^2) R_{\\text{LCS}} P_{\\text{LCS}}}{R_{\\text{LCS}} + \\beta^2 P_{\\text{LCS}}}\\] \\[ \\text{BPB} = \\frac{-\\sum_{i} \\log_2 P(w_i \\mid w_{\\lt i})}{\\text{total bytes}} \\] \\[R_{\\text{BERT}} = \\frac{1}{|r|} \\sum_{r_i \\in r} \\max_{c_j \\in c} \\cos(r_i, c_j), \\quad P_{\\text{BERT}} = \\frac{1}{|c|} \\sum_{c_j \\in c} \\max_{r_i \\in r} \\cos(c_j, r_i)\\] \\[P(A \\succ B) = \\frac{1}{1 + 10^{(R_B - R_A) / 400}}\\] \\[\\text{pass@}k = 1 - \\frac{\\binom{n-c}{k}}{\\binom{n}{k}}\\] "},{"location":"chapter%2007%3A%20computational%20linguistics/04.%20transformers%20and%20language%20models/#colab","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216\u7b14\u8bb0\u672c\uff09","text":"
  1. \u4ece\u5934\u5b9e\u73b0\u4e00\u4e2a\u5b8c\u6574\u7684Transformer\u7f16\u7801\u5668\u5757\uff08\u591a\u5934\u6ce8\u610f\u529b\u3001\u524d\u9988\u7f51\u7edc\u3001\u6b8b\u5dee\u8fde\u63a5\u3001\u5c42\u5f52\u4e00\u5316\uff09\u3002\u5c06\u5176\u5e94\u7528\u4e8e\u4e00\u4e2a\u7b80\u5355\u7684\u5e8f\u5217\u5206\u7c7b\u4efb\u52a1\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef layer_norm(x, gamma, beta, eps=1e-5):\n    mean = x.mean(axis=-1, keepdims=True)\n    var = x.var(axis=-1, keepdims=True)\n    return gamma * (x - mean) / jnp.sqrt(var + eps) + beta\n\ndef multi_head_attention(Q, K, V, W_q, W_k, W_v, W_o, n_heads):\n    B, T, D = Q.shape\n    head_dim = D // n_heads\n\n    q = Q @ W_q  # (B, T, D)\n    k = K @ W_k\n    v = V @ W_v\n\n    # Reshape to (B, n_heads, T, head_dim)\n    q = q.reshape(B, T, n_heads, head_dim).transpose(0, 2, 1, 3)\n    k = k.reshape(B, T, n_heads, head_dim).transpose(0, 2, 1, 3)\n    v = v.reshape(B, T, n_heads, head_dim).transpose(0, 2, 1, 3)\n\n    scores = q @ k.transpose(0, 1, 3, 2) / jnp.sqrt(head_dim)\n    weights = jax.nn.softmax(scores, axis=-1)\n    out = (weights @ v).transpose(0, 2, 1, 3).reshape(B, T, D)\n    return out @ W_o, weights\n\ndef transformer_block(x, params):\n    # Pre-norm multi-head self-attention\n    normed = layer_norm(x, params['ln1_g'], params['ln1_b'])\n    attn_out, weights = multi_head_attention(\n        normed, normed, normed,\n        params['W_q'], params['W_k'], params['W_v'], params['W_o'],\n        n_heads=4\n    )\n    x = x + attn_out\n\n    # Pre-norm feed-forward\n    normed = layer_norm(x, params['ln2_g'], params['ln2_b'])\n    ff = jax.nn.gelu(normed @ params['W1'] + params['b1'])\n    ff = ff @ params['W2'] + params['b2']\n    x = x + ff\n    return x, weights\n\n# Initialise parameters\nd_model, d_ff, n_heads = 32, 128, 4\nkey = jax.random.PRNGKey(42)\nkeys = jax.random.split(key, 10)\n\nparams = {\n    'W_q': jax.random.normal(keys[0], (d_model, d_model)) * 0.05,\n    'W_k': jax.random.normal(keys[1], (d_model, d_model)) * 0.05,\n    'W_v': jax.random.normal(keys[2], (d_model, d_model)) * 0.05,\n    'W_o': jax.random.normal(keys[3], (d_model, d_model)) * 0.05,\n    'ln1_g': jnp.ones(d_model), 'ln1_b': jnp.zeros(d_model),\n    'ln2_g': jnp.ones(d_model), 'ln2_b': 
jnp.zeros(d_model),\n    'W1': jax.random.normal(keys[4], (d_model, d_ff)) * 0.05,\n    'b1': jnp.zeros(d_ff),\n    'W2': jax.random.normal(keys[5], (d_ff, d_model)) * 0.05,\n    'b2': jnp.zeros(d_model),\n}\n\n# Test with random input\nx = jax.random.normal(keys[6], (2, 8, d_model))  # batch=2, seq_len=8\nout, attn_weights = transformer_block(x, params)\nprint(f\"Input shape:  {x.shape}\")\nprint(f\"Output shape: {out.shape}\")\nprint(f\"Attention weights shape: {attn_weights.shape}\")  # (B, n_heads, T, T)\n\n# Visualise attention patterns for each head\nfig, axes = plt.subplots(1, 4, figsize=(16, 3.5))\nfor h in range(4):\n    im = axes[h].imshow(attn_weights[0, h], cmap='Blues', vmin=0)\n    axes[h].set_title(f\"Head {h}\")\n    axes[h].set_xlabel(\"Key pos\"); axes[h].set_ylabel(\"Query pos\")\nplt.suptitle(\"Multi-Head Attention Patterns\")\nplt.tight_layout(); plt.show()\n

  2. \u5b9e\u73b0\u56e0\u679c\uff08\u81ea\u56de\u5f52\uff09\u6ce8\u610f\u529b\u63a9\u7801\uff0c\u5e76\u4e0e\u53cc\u5411\u6ce8\u610f\u529b\u8fdb\u884c\u6bd4\u8f83\u3002\u5c55\u793a\u63a9\u7801\u5982\u4f55\u9632\u6b62\u4fe1\u606f\u4ece\u672a\u6765\u6d41\u5411\u8fc7\u53bb\u7684\u6807\u8bb0\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef attention(Q, K, V, mask=None):\n    d_k = Q.shape[-1]\n    scores = Q @ K.T / jnp.sqrt(d_k)\n    if mask is not None:\n        scores = jnp.where(mask, scores, -1e9)\n    weights = jax.nn.softmax(scores, axis=-1)\n    return weights @ V, weights\n\nseq_len, d_model = 6, 8\nkey = jax.random.PRNGKey(0)\nk1, k2, k3 = jax.random.split(key, 3)\nQ = jax.random.normal(k1, (seq_len, d_model))\nK = jax.random.normal(k2, (seq_len, d_model))\nV = jax.random.normal(k3, (seq_len, d_model))\n\n# Bidirectional (encoder-style): all positions visible\nbidir_mask = jnp.ones((seq_len, seq_len), dtype=bool)\nbidir_out, bidir_weights = attention(Q, K, V, bidir_mask)\n\n# Causal (decoder-style): only past and current positions visible\ncausal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))\ncausal_out, causal_weights = attention(Q, K, V, causal_mask)\n\nfig, axes = plt.subplots(1, 3, figsize=(14, 4))\ntokens = [f\"t{i}\" for i in range(seq_len)]\n\naxes[0].imshow(bidir_weights, cmap='Blues', vmin=0, vmax=0.5)\naxes[0].set_title(\"Bidirectional Attention\\n(BERT-style)\")\naxes[0].set_xticks(range(seq_len)); axes[0].set_xticklabels(tokens)\naxes[0].set_yticks(range(seq_len)); axes[0].set_yticklabels(tokens)\n\naxes[1].imshow(causal_mask.astype(float), cmap='Greys', vmin=0, vmax=1)\naxes[1].set_title(\"Causal Mask\\n(1 = allowed, 0 = blocked)\")\naxes[1].set_xticks(range(seq_len)); axes[1].set_xticklabels(tokens)\naxes[1].set_yticks(range(seq_len)); axes[1].set_yticklabels(tokens)\n\naxes[2].imshow(causal_weights, cmap='Blues', vmin=0, vmax=0.5)\naxes[2].set_title(\"Causal Attention\\n(GPT-style)\")\naxes[2].set_xticks(range(seq_len)); axes[2].set_xticklabels(tokens)\naxes[2].set_yticks(range(seq_len)); axes[2].set_yticklabels(tokens)\n\nfor ax in axes:\n    ax.set_xlabel(\"Key\"); ax.set_ylabel(\"Query\")\nplt.tight_layout(); plt.show()\n\n# Verify: in causal attention, output at position i depends 
only on positions <= i\nprint(\"Causal attention weight at position 2 (should only attend to 0, 1, 2):\")\nprint(f\"  Weights: {causal_weights[2]}\")\nprint(f\"  Sum of future weights (should be ~0): {causal_weights[2, 3:].sum():.6f}\")\n

  3. \u5b9e\u73b0LoRA\uff08\u4f4e\u79e9\u9002\u914d\uff09\uff0c\u5e76\u5c55\u793a\u5b83\u5982\u4f55\u4ee5\u8fdc\u5c11\u4e8e\u5168\u91cf\u5fae\u8c03\u7684\u53ef\u8bad\u7ec3\u53c2\u6570\u6765\u4fee\u6539\u6743\u91cd\u77e9\u9635\u3002

    import jax\nimport jax.numpy as jnp\n\nd_model = 256\nrank = 4  # LoRA rank (much smaller than d_model)\n\nkey = jax.random.PRNGKey(42)\nk1, k2, k3 = jax.random.split(key, 3)\n\n# Original frozen weight matrix\nW_frozen = jax.random.normal(k1, (d_model, d_model)) * 0.02\n\n# LoRA matrices (only these are trainable)\nB = jnp.zeros((d_model, rank))       # initialised to zero\nA = jax.random.normal(k2, (rank, d_model)) * 0.01  # random init\n\n# Forward pass: W_effective = W_frozen + B @ A\nx = jax.random.normal(k3, (8, d_model))\n\n# Without LoRA\ny_original = x @ W_frozen.T\n\n# With LoRA\nW_effective = W_frozen + B @ A\ny_lora = x @ W_effective.T\n\n# Parameter counts\nfull_params = d_model * d_model\nlora_params = d_model * rank + rank * d_model  # B + A\n\nprint(f\"Model dimension: {d_model}\")\nprint(f\"LoRA rank: {rank}\")\nprint(f\"Full fine-tuning parameters: {full_params:,}\")\nprint(f\"LoRA parameters: {lora_params:,}\")\nprint(f\"Parameter reduction: {full_params / lora_params:.1f}x\")\nprint(f\"\\nSince B is initialised to zeros, initial LoRA output matches original:\")\nprint(f\"  Max difference: {jnp.abs(y_original - y_lora).max():.2e}\")\n\n# Simulate training: only update A and B\ndef lora_forward(A, B, W_frozen, x):\n    return x @ (W_frozen + B @ A).T\n\ndef dummy_loss(A, B, W_frozen, x, target):\n    pred = lora_forward(A, B, W_frozen, x)\n    return jnp.mean((pred - target) ** 2)\n\n# Target: some transformation of x\ntarget = x @ jax.random.normal(jax.random.PRNGKey(99), (d_model, d_model)).T * 0.02\n\ngrad_fn = jax.jit(jax.grad(dummy_loss, argnums=(0, 1)))\nlr = 0.01\n\nfor step in range(200):\n    gA, gB = grad_fn(A, B, W_frozen, x, target)\n    A = A - lr * gA\n    B = B - lr * gB\n\nloss_before = dummy_loss(jnp.zeros_like(A), jnp.zeros_like(B), W_frozen, x, target)\nloss_after = dummy_loss(A, B, W_frozen, x, target)\nprint(f\"\\nLoss before LoRA: {loss_before:.6f}\")\nprint(f\"Loss after LoRA:  {loss_after:.6f}\")\nprint(f\"Effective 
weight change rank: {jnp.linalg.matrix_rank(B @ A)}\")\n

"},{"location":"chapter%2007%3A%20computational%20linguistics/05.%20advanced%20text%20generation/","title":"\u9ad8\u7ea7\u6587\u672c\u751f\u6210","text":"

\u9ad8\u7ea7\u6587\u672c\u751f\u6210\u8d85\u8d8a\u4e86\u666e\u901a\u7684\u81ea\u56de\u5f52\u89e3\u7801\uff0c\u65e8\u5728\u63d0\u5347\u8d28\u91cf\u3001\u53ef\u63a7\u6027\u548c\u901f\u5ea6\u3002\u672c\u6587\u6db5\u76d6\u6587\u672c\u6269\u6563\u6a21\u578b\uff08D3PM\u3001MDLM\uff09\u3001OCR\u3001\u7528\u4e8e\u5bf9\u9f50\u7684RLHF\u4e0eDPO\u3001\u957f\u4e0a\u4e0b\u6587\u65b9\u6cd5\uff08RoPE\u7f29\u653e\u3001\u73af\u5f62\u6ce8\u610f\u529b\uff09\u3001\u68c0\u7d22\u589e\u5f3a\u751f\u6210\uff0c\u4ee5\u53ca\u7528\u4e8e\u52a0\u901f\u63a8\u7406\u7684\u63a8\u6d4b\u6027\u89e3\u7801\u3002

\\[q(x_t \\mid x_{t-1}) = \\text{Cat}(x_t ; \\, x_{t-1} Q_t)\\] \\[\\mathcal{L}_{\\text{D3PM}} = D_{\\text{KL}}(q(x_T \\mid x_0) \\| p(x_T)) + \\sum_{t=2}^{T} D_{\\text{KL}}(q(x_{t-1} \\mid x_t, x_0) \\| p_\\theta(x_{t-1} \\mid x_t)) - \\log p_\\theta(x_0 \\mid x_1)\\]

\\[P(y \\mid x) = \\sum_{\\pi \\in \\mathcal{B}^{-1}(y)} \\prod_{t=1}^{T} P(\\pi_t \\mid x)\\]

\\[\\text{logits}_{\\text{guided}} = (1 + w) \\cdot \\text{logits}_{\\text{conditional}} - w \\cdot \\text{logits}_{\\text{unconditional}}\\] \\[\\mathcal{L}_{\\text{RM}} = -\\log \\sigma(r_\\phi(x, y_w) - r_\\phi(x, y_l))\\] \\[\\mathcal{L}_{\\text{RL}} = -\\mathbb{E}\\left[r_\\phi(x, y) - \\beta \\, D_{\\text{KL}}(\\pi_\\theta \\| \\pi_{\\text{SFT}})\\right]\\]

\\[\\pi^\\ast(y \\mid x) = \\frac{1}{Z(x)} \\pi_{\\text{ref}}(y \\mid x) \\exp\\!\\left(\\frac{r(x, y)}{\\beta}\\right)\\] \\[\\mathcal{L}_{\\text{DPO}} = -\\log \\sigma\\!\\left(\\beta \\log \\frac{\\pi_\\theta(y_w \\mid x)}{\\pi_{\\text{ref}}(y_w \\mid x)} - \\beta \\log \\frac{\\pi_\\theta(y_l \\mid x)}{\\pi_{\\text{ref}}(y_l \\mid x)}\\right)\\]

\\[\\text{score}'_{ij} = \\frac{q_i^T k_j}{t \\sqrt{d_k}}\\] \\[x'(t) = Ax(t) + Bu(t), \\quad y(t) = Cx(t) + Du(t)\\] \\[\\bar{A} = \\exp(\\Delta A), \\quad \\bar{B} = (\\Delta A)^{-1}(\\exp(\\Delta A) - I) \\cdot \\Delta B\\]

\\[ A_{nk} = -\\begin{cases} (2n+1)^{1/2}(2k+1)^{1/2} & \\text{if } n > k \\\\ n+1 & \\text{if } n = k \\\\ 0 & \\text{if } n < k \\end{cases} \\] \\[B_k = \\text{Linear}(u_k), \\quad C_k = \\text{Linear}(u_k), \\quad \\Delta_k = \\text{softplus}(\\text{Linear}(u_k))\\]

\\[c_t = W_{\\text{down}} \\, h_t\\]

\\[\\mathcal{L}_{\\text{distill}} = D_{\\text{KL}}(p_{\\text{teacher}}(\\cdot \\mid x) \\| p_{\\text{student}}(\\cdot \\mid x))\\] \\[A_i = \\frac{r_i - \\text{mean}(r_1, \\ldots, r_G)}{\\text{std}(r_1, \\ldots, r_G)}\\] "},{"location":"chapter%2007%3A%20computational%20linguistics/05.%20advanced%20text%20generation/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u4ece\u5934\u5b9e\u73b0\u4e00\u4e2a\u7b80\u5355\u7684\u68c0\u7d22\u589e\u5f3a\u751f\u6210\u7ba1\u7ebf\u3002\u4f7f\u7528TF-IDF\uff08\u6587\u4ef602\uff09\u7d22\u5f15\u4e00\u7ec4\u6587\u6863\uff0c\u4e3a\u67e5\u8be2\u68c0\u7d22\u6700\u76f8\u5173\u7684\u6bb5\u843d\uff0c\u5e76\u5c06\u5176\u524d\u7f6e\u5230\u63d0\u793a\u4e2d\u3002

    import jax.numpy as jnp\nimport math\nfrom collections import Counter\n\n# \u77e5\u8bc6\u5e93\uff1a\u4e00\u7ec4\u7b80\u77ed\u6bb5\u843d\nknowledge_base = [\n    \"The Eiffel Tower is a wrought-iron lattice tower in Paris, France. It was constructed from 1887 to 1889 as the centerpiece of the 1889 World's Fair.\",\n    \"The Great Wall of China is a series of fortifications built along the northern borders of China. Construction began in the 7th century BC.\",\n    \"Photosynthesis is the process by which plants convert sunlight, water, and carbon dioxide into glucose and oxygen using chlorophyll.\",\n    \"The theory of general relativity, published by Albert Einstein in 1915, describes gravity as the curvature of spacetime caused by mass and energy.\",\n    \"Python is a high-level programming language known for its simple syntax and readability. It was created by Guido van Rossum and released in 1991.\",\n    \"The mitochondria are organelles found in eukaryotic cells. They generate most of the cell's supply of ATP, used as a source of chemical energy.\",\n]\n\n# \u6784\u5efa TF-IDF \u7d22\u5f15\uff08\u91cd\u7528\u4e86\u6587\u4ef602\u4e2d\u7684\u6982\u5ff5\uff09\ndef tokenise(text):\n    return text.lower().split()\n\nvocab = sorted(set(w for doc in knowledge_base for w in tokenise(doc)))\nword2idx = {w: i for i, w in enumerate(vocab)}\nV = len(vocab)\nN = len(knowledge_base)\n\n# \u6587\u6863\u9891\u7387\ndoc_freq = Counter()\nfor doc in knowledge_base:\n    for w in set(tokenise(doc)):\n        doc_freq[w] += 1\n\ndef tfidf_vector(text):\n    words = tokenise(text)\n    counts = Counter(words)\n    vec = jnp.zeros(V)\n    for w, c in counts.items():\n        if w in word2idx:\n            tf = 1 + math.log(c)\n            idf = math.log(N / (doc_freq.get(w, 0) + 1))\n            vec = vec.at[word2idx[w]].set(tf * idf)\n    return vec\n\n# \u7d22\u5f15\u6240\u6709\u6587\u6863\ndoc_vectors = jnp.stack([tfidf_vector(doc) for doc in knowledge_base])\n\ndef 
cosine_sim(a, b):\n    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b) + 1e-8)\n\ndef retrieve(query, top_k=2):\n    \"\"\"\u4e3a\u67e5\u8be2\u68c0\u7d22top-k\u4e2a\u6700\u76f8\u5173\u7684\u6bb5\u843d\u3002\"\"\"\n    q_vec = tfidf_vector(query)\n    sims = jnp.array([cosine_sim(q_vec, doc_vectors[i]) for i in range(N)])\n    top_indices = jnp.argsort(-sims)[:top_k]\n    return [(int(i), float(sims[i]), knowledge_base[int(i)]) for i in top_indices]\n\n# \u6d4b\u8bd5\u68c0\u7d22\nqueries = [\n    \"Who built the Eiffel Tower?\",\n    \"How do plants make food?\",\n    \"What did Einstein discover?\",\n]\n\nfor query in queries:\n    results = retrieve(query, top_k=1)\n    print(f\"\\nQuery: '{query}'\")\n    for idx, sim, passage in results:\n        print(f\"  Retrieved (sim={sim:.3f}): '{passage[:80]}...'\")\n\n    # RAG\u98ce\u683c\u7684\u63d0\u793a\u6784\u5efa\n    context = results[0][2]\n    rag_prompt = f\"Context: {context}\\n\\nQuestion: {query}\\nAnswer:\"\n    print(f\"  RAG prompt:\\n    {rag_prompt[:120]}...\")\n

  2. \u4f7f\u7528\u73a9\u5177\u8349\u7a3f\u6a21\u578b\u548c\u76ee\u6807\u6a21\u578b\u5b9e\u73b0\u63a8\u6d4b\u6027\u89e3\u7801\u3002\u5c55\u793a\u63a5\u53d7\u7684\u8f93\u51fa\u4e0e\u76ee\u6807\u6a21\u578b\u7684\u5206\u5e03\u4e00\u81f4\u3002

    import jax\nimport jax.numpy as jnp\n\n# \u6a21\u62df\u8349\u7a3f\u6a21\u578b\uff08\u5feb\u901f\uff0c\u4e0d\u592a\u51c6\u786e\uff09\u548c\u76ee\u6807\u6a21\u578b\uff08\u6162\u901f\uff0c\u51c6\u786e\uff09\nvocab_size = 8\nseq_len = 5\n\nkey = jax.random.PRNGKey(42)\n\n# \u76ee\u6807\u6a21\u578b\uff1a\u7ed9\u5b9a\u5e8f\u5217\u8fd4\u56delogits\ndef target_model(seq, key):\n    \"\"\"\u6a21\u62df\u7684\u76ee\u6807\u6a21\u578b\uff1a\u4ea7\u751ftoken logits\uff08\u6602\u8d35\u7684\uff09\u3002\"\"\"\n    # \u5b9e\u8df5\u4e2d\u8fd9\u5c06\u662f\u4e00\u4e2a\u5927\u578bTransformer\u524d\u5411\u4f20\u64ad\n    k1, k2 = jax.random.split(key)\n    logits = jax.random.normal(k1, (len(seq), vocab_size)) * 2\n    # \u4f7f\u5176\u6709\u4e9b\u53ef\u9884\u6d4b\u6027\uff1a\u504f\u5411\u4e8e token (seq[-1] + 1) % vocab_size\n    for i in range(len(seq)):\n        logits = logits.at[i, (seq[i] + 1) % vocab_size].add(3.0)\n    return logits\n\ndef draft_model(seq, key):\n    \"\"\"\u6a21\u62df\u7684\u8349\u7a3f\u6a21\u578b\uff1a\u7c7b\u4f3c\u4f46\u566a\u58f0\u66f4\u5927\uff08\u4fbf\u5b9c\u7684\uff09\u3002\"\"\"\n    k1, k2 = jax.random.split(key)\n    logits = jax.random.normal(k1, (len(seq), vocab_size))\n    for i in range(len(seq)):\n        logits = logits.at[i, (seq[i] + 1) % vocab_size].add(2.0)\n    return logits\n\ndef sample_token(logits, key):\n    return jax.random.categorical(key, logits)\n\ndef speculative_decode(prefix, draft_steps=3, key=jax.random.PRNGKey(0)):\n    \"\"\"\u63a8\u6d4b\u6027\u89e3\u7801\uff1a\u8349\u7a3f\u63d0\u51fa\uff0c\u76ee\u6807\u9a8c\u8bc1\u3002\"\"\"\n    seq = list(prefix)\n    total_accepted = 0\n    total_proposed = 0\n\n    for _ in range(4):  # \u751f\u62104\u8f6e\n        key, *subkeys = jax.random.split(key, draft_steps + 3)\n\n        # \u8349\u7a3f\u6a21\u578b\u63d0\u51fadraft_steps\u4e2atoken\n        draft_tokens = []\n        draft_probs = []\n        draft_seq = list(seq)\n        for i in range(draft_steps):\n            d_logits = 
draft_model(jnp.array(draft_seq), subkeys[i])\n            d_probs = jax.nn.softmax(d_logits[-1])\n            tok = sample_token(d_logits[-1], subkeys[i])\n            draft_tokens.append(int(tok))\n            draft_probs.append(d_probs)\n            draft_seq.append(int(tok))\n\n        # \u76ee\u6807\u6a21\u578b\u5728\u4e00\u6b21\u524d\u5411\u4e2d\u8bc4\u4f30\u6240\u6709\u8349\u7a3ftoken\n        target_logits = target_model(jnp.array(draft_seq), subkeys[draft_steps])\n        target_start = len(seq) - 1  # \u6700\u540e\u4e00\u4e2a\u524d\u7f00token\u7684\u4f4d\u7f6e\n\n        # \u63a5\u53d7/\u62d2\u7edd\u6bcf\u4e2a\u8349\u7a3ftoken\n        accepted = 0\n        for i in range(draft_steps):\n            t_probs = jax.nn.softmax(target_logits[target_start + i])\n            d_prob = draft_probs[i][draft_tokens[i]]\n            t_prob = t_probs[draft_tokens[i]]\n\n            # \u4ee5\u6982\u7387 min(1, target_prob / draft_prob) \u63a5\u53d7\n            accept_prob = jnp.minimum(1.0, t_prob / (d_prob + 1e-10))\n            key, accept_key = jax.random.split(key)\n            if jax.random.uniform(accept_key) < accept_prob:\n                seq.append(draft_tokens[i])\n                accepted += 1\n            else:\n                # \u62d2\u7edd\uff1a\u4ece\u8c03\u6574\u540e\u7684\u5206\u5e03\u4e2d\u91c7\u6837\n                key, resample_key = jax.random.split(key)\n                adjusted = jnp.maximum(0, t_probs - draft_probs[i])\n                adjusted = adjusted / (adjusted.sum() + 1e-10)\n                new_tok = jax.random.categorical(resample_key, jnp.log(adjusted + 1e-10))\n                seq.append(int(new_tok))\n                break\n\n        total_accepted += accepted\n        total_proposed += draft_steps\n\n    return seq, total_accepted, total_proposed\n\n# \u8fd0\u884c\u63a8\u6d4b\u6027\u89e3\u7801\nprefix = [0, 1]\nresult_seq, accepted, proposed = speculative_decode(prefix)\nacceptance_rate = accepted / proposed if proposed > 0 else 
0\n\nprint(f\"Prefix: {prefix}\")\nprint(f\"Generated sequence: {result_seq}\")\nprint(f\"Draft proposals: {proposed}\")\nprint(f\"Accepted: {accepted}\")\nprint(f\"Acceptance rate: {acceptance_rate:.1%}\")\nprint(f\"Speedup potential: {(accepted + proposed) / proposed:.2f}x\")\n

  3. \u6784\u5efa\u4e00\u4e2a\u7b80\u5355\u7684DPO\u8bad\u7ec3\u5faa\u73af\u3002\u7ed9\u5b9a\u504f\u597d\u548c\u4e0d\u504f\u597d\u7684\u5b8c\u6210\u5e8f\u5217\u5bf9\uff0c\u4f7f\u7528DPO\u635f\u5931\u66f4\u65b0\u4e00\u4e2a\u5c0f\u6a21\u578b\u3002

    import jax\nimport jax.numpy as jnp\n\n# \u5fae\u578b\u8bed\u8a00\u6a21\u578b\uff1a\u4eceone-hot\u5230logits\u7684\u7ebf\u6027\u6295\u5f71\nvocab_size = 10\nseq_len = 4\n\nkey = jax.random.PRNGKey(42)\nk1, k2 = jax.random.split(key)\n\n# \u5f53\u524d\u7b56\u7565\u53c2\u6570\uff08\u53ef\u8bad\u7ec3\u7684\uff09\ntheta = jax.random.normal(k1, (vocab_size, vocab_size)) * 0.1\n# \u53c2\u8003\u7b56\u7565\u53c2\u6570\uff08theta\u7684\u51bb\u7ed3\u526f\u672c\uff09\ntheta_ref = theta.copy()\n\ndef log_prob_sequence(params, sequence):\n    \"\"\"\u8ba1\u7b97\u7b80\u5355\u81ea\u56de\u5f52\u6a21\u578b\u4e0b\u7684 log P(sequence)\u3002\"\"\"\n    total = 0.0\n    for t in range(1, len(sequence)):\n        # \u7b80\u5355\uff1a\u4f4d\u7f6et\u5904\u7684logits\u53d6\u51b3\u4e8e\u4f4d\u7f6et-1\u5904\u7684token\n        logits = params[sequence[t-1]]\n        log_probs = jax.nn.log_softmax(logits)\n        total += log_probs[sequence[t]]\n    return total\n\ndef dpo_loss(theta, theta_ref, preferred, dispreferred, beta=0.1):\n    \"\"\"\u4e00\u5bf9\u6570\u636e\u7684\u76f4\u63a5\u504f\u597d\u4f18\u5316\u635f\u5931\u3002\"\"\"\n    log_pi_w = log_prob_sequence(theta, preferred)\n    log_pi_l = log_prob_sequence(theta, dispreferred)\n    log_ref_w = log_prob_sequence(theta_ref, preferred)\n    log_ref_l = log_prob_sequence(theta_ref, dispreferred)\n\n    # DPO\u76ee\u6807\n    return -jax.nn.log_sigmoid(\n        beta * ((log_pi_w - log_ref_w) - (log_pi_l - log_ref_l))\n    )\n\n# \u504f\u597d\u6570\u636e\u96c6\uff1a(\u63d0\u793a\u524d\u7f00, \u504f\u597d\u5b8c\u6210\u5e8f\u5217, \u4e0d\u504f\u597d\u5b8c\u6210\u5e8f\u5217)\npreferences = [\n    (jnp.array([1, 3, 5, 7]), jnp.array([1, 3, 5, 2])),  # \u7ed3\u5c3e\u504f\u597d7\u800c\u4e0d\u662f2\n    (jnp.array([0, 2, 4, 6]), jnp.array([0, 2, 4, 9])),  # \u504f\u597d6\u800c\u4e0d\u662f9\n    (jnp.array([3, 3, 3, 3]), jnp.array([3, 3, 3, 0])),  # \u504f\u597d\u91cd\u590d\u800c\u4e0d\u662f0\n    (jnp.array([5, 6, 7, 8]), jnp.array([5, 6, 
7, 1])),  # \u504f\u597d8\u800c\u4e0d\u662f1\n]\n\ngrad_fn = jax.jit(jax.grad(dpo_loss))\nlr = 0.05\n\nprint(\"\u8bad\u7ec3 DPO...\")\nfor epoch in range(100):\n    total_loss = 0.0\n    for preferred, dispreferred in preferences:\n        loss = dpo_loss(theta, theta_ref, preferred, dispreferred)\n        grads = grad_fn(theta, theta_ref, preferred, dispreferred)\n        theta = theta - lr * grads\n        total_loss += loss\n    if (epoch + 1) % 20 == 0:\n        avg_loss = total_loss / len(preferences)\n        print(f\"  Epoch {epoch+1}: avg DPO loss = {avg_loss:.4f}\")\n\n# \u68c0\u67e5\uff1a\u6a21\u578b\u73b0\u5728\u5e94\u8be5\u504f\u597d\u504f\u597d\u7684\u5b8c\u6210\u5e8f\u5217\nprint(\"\\nDPO\u8bad\u7ec3\u540e\u7684\u504f\u597d\u68c0\u67e5:\")\nfor preferred, dispreferred in preferences:\n    lp_w = log_prob_sequence(theta, preferred)\n    lp_l = log_prob_sequence(theta, dispreferred)\n    print(f\"  Preferred {list(preferred.astype(int))}: logP={lp_w:.3f}  \"\n          f\"Dispreferred {list(dispreferred.astype(int))}: logP={lp_l:.3f}  \"\n          f\"{'correct' if lp_w > lp_l else 'WRONG'}\")\n

"},{"location":"chapter%2008%3A%20computer%20vision/01.%20image%20fundamentals/","title":"\u56fe\u50cf\u57fa\u7840","text":"

\u56fe\u50cf\u57fa\u7840\u89e3\u91ca\u6570\u5b57\u56fe\u50cf\u5728\u88ab\u4efb\u4f55\u6a21\u578b\u5904\u7406\u4e4b\u524d\u5982\u4f55\u8868\u793a\u3001\u5f62\u6210\u548c\u9884\u5904\u7406\u3002\u672c\u6587\u6db5\u76d6\u50cf\u7d20\u3001\u8272\u5f69\u7a7a\u95f4\uff08RGB\u3001HSV\u3001YCbCr\u3001LAB\uff09\u3001\u9488\u5b54\u76f8\u673a\u6a21\u578b\u3001\u5377\u79ef\u3001\u8fb9\u7f18\u68c0\u6d4b\uff08Sobel\u3001Canny\uff09\u3001\u76f4\u65b9\u56fe\u4ee5\u53ca\u7279\u5f81\u63cf\u8ff0\u5b50\uff08SIFT\u3001ORB\uff09\uff0c\u662f\u5e95\u5c42\u89c6\u89c9\u7684\u5de5\u5177\u5305\u3002

\\[ \\begin{bmatrix} u \\\\ v \\\\ 1 \\end{bmatrix} = \\frac{1}{Z} \\begin{bmatrix} f_x & 0 & c_x \\\\ 0 & f_y & c_y \\\\ 0 & 0 & 1 \\end{bmatrix} \\begin{bmatrix} X \\\\ Y \\\\ Z \\end{bmatrix} \\]

\\[\\mathbf{p} = K [R \\mid t] \\mathbf{P}\\] \\[(\\text{\u56fe\u50cf} * K)[i,j] = \\sum_{m} \\sum_{n} \\text{\u56fe\u50cf}[i+m, j+n] \\cdot K[m, n]\\] \\[ G_x = \\begin{bmatrix} -1 & 0 & 1 \\\\ -2 & 0 & 2 \\\\ -1 & 0 & 1 \\end{bmatrix}, \\quad G_y = \\begin{bmatrix} -1 & -2 & -1 \\\\ 0 & 0 & 0 \\\\ 1 & 2 & 1 \\end{bmatrix} \\]

\\[F(u, v) = \\sum_{x=0}^{M-1} \\sum_{y=0}^{N-1} f(x, y) \\cdot e^{-j2\\pi(ux/M + vy/N)}\\]

\\[ M = \\sum_{(x,y) \\in W} w(x,y) \\begin{bmatrix} I_x^2 & I_x I_y \\\\ I_x I_y & I_y^2 \\end{bmatrix} \\]

\\[L(x, y, \\sigma) = G(x, y, \\sigma) * I(x, y)\\] "},{"location":"chapter%2008%3A%20computer%20vision/01.%20image%20fundamentals/#colab-notebook","title":"\u7f16\u7801\u4efb\u52a1\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u52a0\u8f7d\u56fe\u50cf\uff0c\u5c06\u5176\u8f6c\u6362\u4e3a\u4e0d\u540c\u7684\u8272\u5f69\u7a7a\u95f4\uff08RGB\u3001HSV\u3001LAB\uff09\uff0c\u5e76\u53ef\u89c6\u5316\u5404\u4e2a\u901a\u9053\u3002\u89c2\u5bdf\u989c\u8272\u4fe1\u606f\u5728\u4e0d\u540c\u7a7a\u95f4\u4e2d\u7684\u5206\u5e03\u5dee\u5f02\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\nfrom PIL import Image\nimport numpy as np\n\n# Create a synthetic test image with distinct colours\nH, W = 128, 256\nimg = np.zeros((H, W, 3), dtype=np.uint8)\nimg[:, :64] = [255, 50, 50]     # red\nimg[:, 64:128] = [50, 255, 50]  # green\nimg[:, 128:192] = [50, 50, 255] # blue\nimg[:, 192:] = [255, 255, 50]   # yellow\n\n# Add a brightness gradient\nfor y in range(H):\n    scale = 0.3 + 0.7 * y / H\n    img[y] = (img[y] * scale).astype(np.uint8)\n\nimg_jnp = jnp.array(img, dtype=jnp.float32) / 255.0\n\n# Manual RGB to HSV conversion\ndef rgb_to_hsv(rgb):\n    r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2]\n    maxc = jnp.max(rgb, axis=-1)\n    minc = jnp.min(rgb, axis=-1)\n    diff = maxc - minc + 1e-7\n\n    # Hue\n    h = jnp.where(maxc == minc, 0.0,\n        jnp.where(maxc == r, 60 * ((g - b) / diff % 6),\n        jnp.where(maxc == g, 60 * ((b - r) / diff + 2),\n                              60 * ((r - g) / diff + 4))))\n    s = jnp.where(maxc < 1e-7, 0.0, diff / maxc)\n    v = maxc\n    return jnp.stack([h / 360, s, v], axis=-1)\n\nhsv = rgb_to_hsv(img_jnp)\n\nfig, axes = plt.subplots(2, 3, figsize=(14, 8))\nfor i, (ch, name) in enumerate(zip([img_jnp[...,0], img_jnp[...,1], img_jnp[...,2]],\n                                     ['Red', 'Green', 'Blue'])):\n    axes[0, i].imshow(ch, cmap='gray', vmin=0, vmax=1)\n    axes[0, i].set_title(f'RGB: {name}'); axes[0, i].axis('off')\n\nfor i, (ch, name) in enumerate(zip([hsv[...,0], hsv[...,1], hsv[...,2]],\n                                     ['Hue', 'Saturation', 'Value'])):\n    axes[1, i].imshow(ch, cmap='gray', vmin=0, vmax=1)\n    axes[1, i].set_title(f'HSV: {name}'); axes[1, i].axis('off')\n\nplt.suptitle('RGB vs HSV Channels')\nplt.tight_layout(); plt.show()\n

  2. \u4f7f\u7528\u4e8c\u7ef4\u5377\u79ef\u4ece\u5934\u5b9e\u73b0 Sobel \u8fb9\u7f18\u68c0\u6d4b\u548c\u9ad8\u65af\u6a21\u7cca\u3002\u5c06\u5176\u5e94\u7528\u4e8e\u56fe\u50cf\u5e76\u6bd4\u8f83\u7ed3\u679c\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef conv2d(image, kernel):\n    \"\"\"2D convolution (valid mode) from scratch.\"\"\"\n    H, W = image.shape\n    kH, kW = kernel.shape\n    out_h, out_w = H - kH + 1, W - kW + 1\n    output = jnp.zeros((out_h, out_w))\n    for i in range(out_h):\n        for j in range(out_w):\n            patch = image[i:i+kH, j:j+kW]\n            output = output.at[i, j].set(jnp.sum(patch * kernel))\n    return output\n\n# Create a test image: white rectangle on dark background\nimg = jnp.zeros((64, 64))\nimg = img.at[15:50, 20:45].set(1.0)\n# Add some noise\nkey = jax.random.PRNGKey(42)\nimg = img + jax.random.normal(key, img.shape) * 0.05\n\n# Sobel filters\nsobel_x = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=jnp.float32)\nsobel_y = jnp.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=jnp.float32)\n\n# Gaussian blur kernel (5x5, sigma=1)\nax = jnp.arange(-2, 3, dtype=jnp.float32)\nxx, yy = jnp.meshgrid(ax, ax)\ngaussian = jnp.exp(-(xx**2 + yy**2) / (2 * 1.0**2))\ngaussian = gaussian / gaussian.sum()\n\n# Apply filters\ngx = conv2d(img, sobel_x)\ngy = conv2d(img, sobel_y)\nedges = jnp.sqrt(gx**2 + gy**2)\nblurred = conv2d(img, gaussian)\n\nfig, axes = plt.subplots(1, 4, figsize=(16, 4))\nfor ax, data, title in zip(axes,\n    [img, edges, blurred, gx],\n    ['Original', 'Edge Magnitude', 'Gaussian Blur', 'Horizontal Gradient']):\n    ax.imshow(data, cmap='gray')\n    ax.set_title(title); ax.axis('off')\nplt.tight_layout(); plt.show()\n

  3. \u4ece\u5934\u5b9e\u73b0\u76f4\u65b9\u56fe\u5747\u8861\u5316\uff0c\u5e76\u5c06\u5176\u5e94\u7528\u4e8e\u4f4e\u5bf9\u6bd4\u5ea6\u7070\u5ea6\u56fe\u50cf\u3002\u6bd4\u8f83\u5747\u8861\u524d\u540e\u7684\u76f4\u65b9\u56fe\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# Create a low-contrast image (values clustered in a narrow range)\nkey = __import__('jax').random.PRNGKey(42)\nimg = __import__('jax').random.uniform(key, (128, 128)) * 0.3 + 0.3  # values in [0.3, 0.6]\n\ndef histogram_equalise(img, n_bins=256):\n    \"\"\"Histogram equalisation for a grayscale image.\"\"\"\n    # Quantise to bins\n    bins = jnp.linspace(0, 1, n_bins + 1)\n    hist = jnp.histogram(img, bins=bins)[0]\n\n    # Compute CDF\n    cdf = jnp.cumsum(hist)\n    cdf_normalised = (cdf - cdf.min()) / (cdf.max() - cdf.min())\n\n    # Map each pixel through the CDF\n    indices = jnp.clip((img * n_bins).astype(jnp.int32), 0, n_bins - 1)\n    equalised = cdf_normalised[indices]\n    return equalised\n\neq_img = histogram_equalise(img)\n\nfig, axes = plt.subplots(2, 2, figsize=(12, 10))\naxes[0, 0].imshow(img, cmap='gray', vmin=0, vmax=1)\naxes[0, 0].set_title('Original (Low Contrast)'); axes[0, 0].axis('off')\naxes[0, 1].imshow(eq_img, cmap='gray', vmin=0, vmax=1)\naxes[0, 1].set_title('After Histogram Equalisation'); axes[0, 1].axis('off')\n\naxes[1, 0].hist(img.ravel(), bins=64, color='#3498db', alpha=0.8)\naxes[1, 0].set_title('Histogram Before'); axes[1, 0].set_xlim(0, 1)\naxes[1, 1].hist(eq_img.ravel(), bins=64, color='#e74c3c', alpha=0.8)\naxes[1, 1].set_title('Histogram After'); axes[1, 1].set_xlim(0, 1)\n\nplt.tight_layout(); plt.show()\n

  4. \u4ece\u5934\u5b9e\u73b0 Harris \u89d2\u70b9\u68c0\u6d4b\u5668\u3002\u5728\u7b80\u5355\u56fe\u50cf\u4e2d\u68c0\u6d4b\u89d2\u70b9\u5e76\u53ef\u89c6\u5316\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef harris_corners(img, k=0.05, threshold=0.01):\n    \"\"\"Harris corner detection from scratch.\"\"\"\n    # Compute gradients with Sobel\n    sobel_x = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=jnp.float32)\n    sobel_y = jnp.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=jnp.float32)\n\n    # Pad image for valid convolution to preserve size\n    img_pad = jnp.pad(img, 1, mode='edge')\n    H, W = img.shape\n\n    Ix = jnp.zeros_like(img)\n    Iy = jnp.zeros_like(img)\n    for i in range(H):\n        for j in range(W):\n            patch = img_pad[i:i+3, j:j+3]\n            Ix = Ix.at[i, j].set(jnp.sum(patch * sobel_x))\n            Iy = Iy.at[i, j].set(jnp.sum(patch * sobel_y))\n\n    # Structure tensor components\n    Ixx = Ix * Ix\n    Iyy = Iy * Iy\n    Ixy = Ix * Iy\n\n    # Gaussian smoothing of structure tensor (approximate with window sum)\n    w = 3  # window half-size\n    R = jnp.zeros_like(img)\n    pad_xx = jnp.pad(Ixx, w, mode='constant')\n    pad_yy = jnp.pad(Iyy, w, mode='constant')\n    pad_xy = jnp.pad(Ixy, w, mode='constant')\n\n    for i in range(H):\n        for j in range(W):\n            sxx = jnp.sum(pad_xx[i:i+2*w+1, j:j+2*w+1])\n            syy = jnp.sum(pad_yy[i:i+2*w+1, j:j+2*w+1])\n            sxy = jnp.sum(pad_xy[i:i+2*w+1, j:j+2*w+1])\n            det = sxx * syy - sxy * sxy\n            trace = sxx + syy\n            R = R.at[i, j].set(det - k * trace * trace)\n\n    # Threshold\n    corners = R > threshold * R.max()\n    return R, corners\n\n# Test image: checkerboard pattern (lots of corners)\nblock = 16\nn = 4\nchecker = jnp.zeros((block * n, block * n))\nfor i in range(n):\n    for j in range(n):\n        if (i + j) % 2 == 0:\n            checker = checker.at[i*block:(i+1)*block, j*block:(j+1)*block].set(1.0)\n\nR, corners = harris_corners(checker)\ncy, cx = jnp.where(corners)\n\nfig, axes = plt.subplots(1, 3, figsize=(14, 
4))\naxes[0].imshow(checker, cmap='gray')\naxes[0].set_title('Checkerboard'); axes[0].axis('off')\naxes[1].imshow(R, cmap='hot')\naxes[1].set_title('Harris Response'); axes[1].axis('off')\naxes[2].imshow(checker, cmap='gray')\naxes[2].scatter(cx, cy, c='#e74c3c', s=15, marker='x')\naxes[2].set_title(f'Detected Corners ({len(cx)})'); axes[2].axis('off')\nplt.tight_layout(); plt.show()\n

"},{"location":"chapter%2008%3A%20computer%20vision/02.%20convolutional%20networks/","title":"\u5377\u79ef\u7f51\u7edc","text":"

\u5377\u79ef\u795e\u7ecf\u7f51\u7edc\u76f4\u63a5\u4ece\u50cf\u7d20\u6570\u636e\u4e2d\u5b66\u4e60\u7a7a\u95f4\u7279\u5f81\u5c42\u7ea7\uff0c\u7528\u68af\u5ea6\u4f18\u5316\u7684\u6ee4\u6ce2\u5668\u53d6\u4ee3\u4eba\u5de5\u8bbe\u8ba1\u7684\u6ee4\u6ce2\u5668\u3002\u672c\u6587\u6db5\u76d6\u5377\u79ef\u673a\u5236\u3001\u6c60\u5316\u3001\u6b65\u957f\u3001\u7a7a\u6d1e\u5377\u79ef\u3001\u611f\u53d7\u91ce\uff0c\u4ee5\u53ca\u5b9a\u4e49\u4e86\u56fe\u50cf\u5206\u7c7b\u7684\u6807\u5fd7\u6027\u67b6\u6784\uff08LeNet\u3001AlexNet\u3001VGG\u3001ResNet\u3001Inception\u3001EfficientNet\uff09\u3002

\\[\\text{out} = \\left\\lfloor \\frac{\\text{in} - k + 2p}{s} \\right\\rfloor + 1\\]

\\[\\text{output} = F(x) + x\\]

\\[\\text{depth}: d = \\alpha^\\phi, \\quad \\text{width}: w = \\beta^\\phi, \\quad \\text{resolution}: r = \\gamma^\\phi\\]

\\[L_{\\text{Grad-CAM}} = \\text{ReLU}\\!\\left(\\sum_k \\alpha_k A^k\\right), \\quad \\alpha_k = \\frac{1}{Z} \\sum_i \\sum_j \\frac{\\partial y^c}{\\partial A^k_{ij}}\\]

"},{"location":"chapter%2008%3A%20computer%20vision/02.%20convolutional%20networks/#colab-notebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u7528 JAX \u4ece\u5934\u5b9e\u73b0\u4e00\u4e2a\u7b80\u5355\u7684 CNN\uff0c\u5305\u542b\u4e24\u4e2a\u5377\u79ef\u5c42\u3001\u6700\u5927\u6c60\u5316\u548c\u4e00\u4e2a\u5206\u7c7b\u5934\u3002\u5728\u4e00\u4e2a\u5408\u6210\u7684\u4e8c\u7ef4\u6a21\u5f0f\u5206\u7c7b\u4efb\u52a1\u4e0a\u8bad\u7ec3\u5b83\u3002

    import jax\nimport jax.numpy as jnp\nimport jax.lax as lax\nimport matplotlib.pyplot as plt\n\ndef conv2d(x, kernel, stride=1):\n    \"\"\"\u7b80\u5355 2D \u5377\u79ef\uff0c\u5355\u8f93\u5165\uff0c\u5355\u6ee4\u6ce2\u5668\u3002\"\"\"\n    return lax.conv(x[None, None], kernel[None, None], (stride, stride), 'SAME')[0, 0]\n\ndef max_pool(x, size=2):\n    \"\"\"2x2 \u6700\u5927\u6c60\u5316\u3002\"\"\"\n    H, W = x.shape\n    x = x[:H//size*size, :W//size*size]\n    return x.reshape(H//size, size, W//size, size).max(axis=(1, 3))\n\ndef init_cnn(key):\n    k1, k2, k3 = jax.random.split(key, 3)\n    return {\n        'conv1': jax.random.normal(k1, (5, 5)) * 0.3,\n        'conv2': jax.random.normal(k2, (3, 3)) * 0.3,\n        'fc_w': jax.random.normal(k3, (64, 1)) * 0.1,\n        'fc_b': jnp.zeros(1),\n    }\n\ndef forward_cnn(params, img):\n    # Conv1 -> ReLU -> Pool\n    h = jnp.maximum(0, conv2d(img, params['conv1']))\n    h = max_pool(h)\n    # Conv2 -> ReLU -> Pool\n    h = jnp.maximum(0, conv2d(h, params['conv2']))\n    h = max_pool(h)\n    # Flatten and classify\n    flat = h.ravel()\n    # Pad or truncate to fixed size\n    flat = jnp.pad(flat, (0, max(0, 64 - len(flat))))[:64]\n    logit = (flat @ params['fc_w'] + params['fc_b']).squeeze()\n    return jax.nn.sigmoid(logit)\n\n# Generate synthetic data: class 0 = low-freq pattern, class 1 = high-freq\ndef make_data(key, n=200):\n    images, labels = [], []\n    for i in range(n):\n        k1, key = jax.random.split(key)\n        x, y = jnp.meshgrid(jnp.linspace(0, 4*jnp.pi, 32), jnp.linspace(0, 4*jnp.pi, 32))\n        if i < n // 2:\n            img = jnp.sin(x) + jax.random.normal(k1, (32, 32)) * 0.1\n            labels.append(0)\n        else:\n            img = jnp.sin(4 * x) * jnp.sin(4 * y) + jax.random.normal(k1, (32, 32)) * 0.1\n            labels.append(1)\n        images.append(img)\n    return images, jnp.array(labels, dtype=jnp.float32)\n\nkey = jax.random.PRNGKey(42)\nimages, labels = 
make_data(key)\nparams = init_cnn(jax.random.PRNGKey(0))\n\ndef loss_fn(params, img, label):\n    pred = forward_cnn(params, img)\n    return -(label * jnp.log(pred + 1e-7) + (1 - label) * jnp.log(1 - pred + 1e-7))\n\ngrad_fn = jax.grad(loss_fn)\nlr = 0.01\n\nfor epoch in range(5):\n    total_loss = 0.0\n    for img, label in zip(images, labels):\n        grads = grad_fn(params, img, label)\n        params = {k: params[k] - lr * grads[k] for k in params}\n        total_loss += loss_fn(params, img, label)\n    print(f\"Epoch {epoch}: loss = {total_loss / len(images):.4f}\")\n\n# Test accuracy\npreds = jnp.array([forward_cnn(params, img) > 0.5 for img in images])\nacc = jnp.mean(preds == labels)\nprint(f\"Accuracy: {acc:.2%}\")\n

  2. \u53ef\u89c6\u5316\u4e0d\u540c\u6ee4\u6ce2\u5668\u5927\u5c0f\u5982\u4f55\u5f71\u54cd\u611f\u53d7\u91ce\u3002\u5c55\u793a\u4e24\u4e2a\u5806\u53e0\u7684 3x3 \u6ee4\u6ce2\u5668\u4e0e\u4e00\u4e2a 5x5 \u6ee4\u6ce2\u5668\u8986\u76d6\u76f8\u540c\u7684\u611f\u53d7\u91ce\uff0c\u4f46\u53c2\u6570\u66f4\u5c11\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef compute_receptive_field(layers):\n    \"\"\"\u4ece\u4e00\u7ec4 (kernel_size, stride) \u5143\u7ec4\u8ba1\u7b97\u611f\u53d7\u91ce\u5927\u5c0f\u3002\"\"\"\n    rf = 1  # \u4ece 1 \u4e2a\u50cf\u7d20\u5f00\u59cb\n    stride_product = 1\n    for k, s in layers:\n        rf += (k - 1) * stride_product\n        stride_product *= s\n    return rf\n\n# Compare architectures\nconfigs = {\n    'Single 5x5': [(5, 1)],\n    'Two 3x3':    [(3, 1), (3, 1)],\n    'Three 3x3':  [(3, 1), (3, 1), (3, 1)],\n    'Single 7x7': [(7, 1)],\n    '3x3 stride 2 + 3x3': [(3, 2), (3, 1)],\n}\n\nprint(f\"{'Config':<25} {'RF':>4} {'Params (per channel)':>20}\")\nprint('-' * 55)\nfor name, layers in configs.items():\n    rf = compute_receptive_field(layers)\n    # Parameters: sum of k^2 for each layer (per input-output channel pair)\n    params = sum(k * k for k, s in layers)\n    print(f\"{name:<25} {rf:>4} {params:>20}\")\n\n# Visualise receptive fields\nfig, axes = plt.subplots(1, 3, figsize=(14, 4))\nfor ax, (name, rf_size) in zip(axes, [('5x5 filter', 5), ('Two 3x3 filters', 5), ('Three 3x3 filters', 7)]):\n    grid = jnp.zeros((9, 9))\n    c = 4  # centre\n    half = rf_size // 2\n    grid = grid.at[c-half:c+half+1, c-half:c+half+1].set(1.0)\n    ax.imshow(grid, cmap='Blues', vmin=0, vmax=1)\n    ax.set_title(f'{name}\\nRF = {rf_size}x{rf_size}')\n    ax.set_xticks(range(9)); ax.set_yticks(range(9))\n    ax.grid(True, alpha=0.3)\nplt.suptitle('Receptive Field Comparison')\nplt.tight_layout(); plt.show()\n

  3. \u4ece\u5934\u5b9e\u73b0 Grad-CAM\u3002\u7ed9\u5b9a\u4e00\u4e2a\u9884\u6784\u5efa\u7684\u7b80\u5355 CNN\uff0c\u8ba1\u7b97\u9488\u5bf9\u7279\u5b9a\u7c7b\u522b\u7684\u68af\u5ea6\u52a0\u6743\u6fc0\u6d3b\u56fe\uff0c\u5e76\u5c06\u5176\u53ef\u89c6\u5316\u4e3a\u70ed\u529b\u56fe\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef simple_cnn(params, img):\n    \"\"\"\u8fd4\u56de\u9884\u6d4b\u548c\u6700\u540e\u4e00\u4e2a\u5377\u79ef\u5c42\u6fc0\u6d3b\u7684\u7b80\u5355 CNN\u3002\"\"\"\n    # Conv layer (our \"last conv layer\" for Grad-CAM)\n    H, W = img.shape\n    k = params['conv'].shape[0]\n    pad = k // 2\n    img_pad = jnp.pad(img, pad, mode='edge')\n    activation_map = jnp.zeros((H, W))\n    for i in range(H):\n        for j in range(W):\n            activation_map = activation_map.at[i, j].set(\n                jnp.sum(img_pad[i:i+k, j:j+k] * params['conv'])\n            )\n    activation_map = jnp.maximum(0, activation_map)  # ReLU\n\n    # Global average pool -> dense -> output\n    pooled = activation_map.mean()\n    logit = pooled * params['w'] + params['b']\n    return jax.nn.sigmoid(logit), activation_map\n\n# Create test image: bright region on the left (class indicator)\nimg = jnp.zeros((32, 32))\nimg = img.at[8:24, 4:16].set(1.0)\nimg = img.at[5:10, 20:28].set(0.3)\n\nkey = jax.random.PRNGKey(42)\nparams = {\n    'conv': jax.random.normal(key, (5, 5)) * 0.3,\n    'w': jnp.array(2.0),\n    'b': jnp.array(-0.5),\n}\n\n# Compute Grad-CAM\ndef class_score(params, img):\n    pred, _ = simple_cnn(params, img)\n    return pred\n\n# Get activation map and gradients\npred, act_map = simple_cnn(params, img)\n# Grad-CAM needs d(score)/d(activation map), not d(score)/d(input image),\n# so differentiate the pooling + classifier head with respect to act_map\nhead = lambda act: jax.nn.sigmoid(act.mean() * params['w'] + params['b'])\nact_grad = jax.grad(head)(act_map)\n\n# Weight = global average of gradients (simplified 1-channel Grad-CAM)\nalpha = act_grad.mean()\ngrad_cam = jnp.maximum(0, alpha * act_map)  # ReLU\ngrad_cam = (grad_cam - grad_cam.min()) / (grad_cam.max() - grad_cam.min() + 1e-8)\n\nfig, axes = plt.subplots(1, 3, figsize=(14, 4))\naxes[0].imshow(img, cmap='gray'); axes[0].set_title('Input Image'); axes[0].axis('off')\naxes[1].imshow(act_map, cmap='viridis'); axes[1].set_title('Activation Map'); axes[1].axis('off')\naxes[2].imshow(img, cmap='gray', alpha=0.6)\naxes[2].imshow(grad_cam, cmap='jet', alpha=0.4)\naxes[2].set_title(f'Grad-CAM (pred={pred:.2f})'); axes[2].axis('off')\nplt.tight_layout(); plt.show()\n

  4. \u6bd4\u8f83\u6df1\u5ea6\u53ef\u5206\u79bb\u5377\u79ef\u4e0e\u6807\u51c6\u5377\u79ef\u3002\u7edf\u8ba1\u4e24\u8005\u7684\u53c2\u6570\u548c FLOPs\uff0c\u5e76\u5c55\u793a\u5b83\u4eec\u5728\u8ba1\u7b97\u91cf\u5c11\u5f97\u591a\u7684\u60c5\u51b5\u4e0b\u4ea7\u751f\u76f8\u4f3c\u7684\u8f93\u51fa\u3002

    import jax\nimport jax.numpy as jnp\n\ndef standard_conv(x, kernel):\n    \"\"\"\u6807\u51c6\u5377\u79ef\uff1a(H, W, C_in) * (k, k, C_in, C_out) -> (H, W, C_out)\u3002\"\"\"\n    H, W, C_in = x.shape\n    k, _, _, C_out = kernel.shape\n    pad = k // 2\n    x_pad = jnp.pad(x, ((pad, pad), (pad, pad), (0, 0)), mode='constant')\n    out = jnp.zeros((H, W, C_out))\n    for i in range(H):\n        for j in range(W):\n            patch = x_pad[i:i+k, j:j+k, :]  # (k, k, C_in)\n            for c in range(C_out):\n                out = out.at[i, j, c].set(jnp.sum(patch * kernel[:, :, :, c]))\n    return out\n\ndef depthwise_separable_conv(x, dw_kernel, pw_kernel):\n    \"\"\"\u6df1\u5ea6\u53ef\u5206\u79bb\uff1a\u6df1\u5ea6\u5377\u79ef (k,k,C_in) \u7136\u540e\u9010\u70b9\u5377\u79ef (C_in, C_out)\u3002\"\"\"\n    H, W, C_in = x.shape\n    k = dw_kernel.shape[0]\n    pad = k // 2\n    x_pad = jnp.pad(x, ((pad, pad), (pad, pad), (0, 0)), mode='constant')\n\n    # Depthwise: one filter per channel\n    dw_out = jnp.zeros((H, W, C_in))\n    for i in range(H):\n        for j in range(W):\n            for c in range(C_in):\n                patch = x_pad[i:i+k, j:j+k, c]\n                dw_out = dw_out.at[i, j, c].set(jnp.sum(patch * dw_kernel[:, :, c]))\n\n    # Pointwise: 1x1 conv across channels\n    out = dw_out @ pw_kernel\n    return out\n\n# Setup\nH, W, C_in, C_out, k = 8, 8, 16, 32, 3\nkey = jax.random.PRNGKey(42)\nk1, k2, k3, k4 = jax.random.split(key, 4)\n\nx = jax.random.normal(k1, (H, W, C_in))\nstd_kernel = jax.random.normal(k2, (k, k, C_in, C_out)) * 0.1\ndw_kernel = jax.random.normal(k3, (k, k, C_in)) * 0.1\npw_kernel = jax.random.normal(k4, (C_in, C_out)) * 0.1\n\n# Compare\nstd_params = k * k * C_in * C_out\ndw_params = k * k * C_in + C_in * C_out\n\nstd_flops = H * W * k * k * C_in * C_out\ndw_flops = H * W * (k * k * C_in + C_in * C_out)\n\nprint(f\"Standard conv:            {std_params:>8,} params,  {std_flops:>10,} FLOPs\")\nprint(f\"Depthwise separable 
conv: {dw_params:>8,} params,  {dw_flops:>10,} FLOPs\")\nprint(f\"Parameter reduction:      {std_params / dw_params:.1f}x\")\nprint(f\"FLOP reduction:           {std_flops / dw_flops:.1f}x\")\n\nstd_out = standard_conv(x, std_kernel)\nds_out = depthwise_separable_conv(x, dw_kernel, pw_kernel)\nprint(f\"\\nStandard output shape:    {std_out.shape}\")\nprint(f\"Depthwise sep output shape: {ds_out.shape}\")\n

"},{"location":"chapter%2008%3A%20computer%20vision/03.%20object%20detection%20and%20segmentation/","title":"\u76ee\u6807\u68c0\u6d4b\u4e0e\u5206\u5272","text":"

\u76ee\u6807\u68c0\u6d4b\u5b9a\u4f4d\u5e76\u5206\u7c7b\u56fe\u50cf\u4e2d\u7684\u6bcf\u4e2a\u7269\u4f53\uff1b\u5206\u5272\u4e3a\u6bcf\u4e2a\u50cf\u7d20\u5206\u914d\u4e00\u4e2a\u6807\u7b7e\u3002\u672c\u6587\u6db5\u76d6\u4ea4\u5e76\u6bd4\uff08IoU\uff09\u3001\u5e73\u5747\u7cbe\u5ea6\u5747\u503c\uff08mAP\uff09\u3001\u951a\u6846\u3001R-CNN\u7cfb\u5217\u3001YOLO\u3001SSD\u3001\u7279\u5f81\u91d1\u5b57\u5854\u7f51\u7edc\uff08FPN\uff09\u3001\u8bed\u4e49/\u5b9e\u4f8b/\u5168\u666f\u5206\u5272\uff08U-Net\u3001Mask R-CNN\u3001SAM\uff09\u4ee5\u53ca\u7528\u4e8e\u57fa\u51c6\u6d4b\u8bd5\u7684\u8bc4\u4f30\u6307\u6807\u3002

\\[\\text{IoU} = \\frac{\\text{\u4ea4\u96c6\u9762\u79ef}}{\\text{\u5e76\u96c6\u9762\u79ef}}\\] \\[\\text{AP} = \\int_0^1 p(r) \\, dr\\]

\\[t_x = \\frac{x - x_a}{w_a}, \\quad t_y = \\frac{y - y_a}{h_a}, \\quad t_w = \\log\\frac{w}{w_a}, \\quad t_h = \\log\\frac{h}{h_a}\\] \\[ \\text{smooth}_{L1}(x) = \\begin{cases} 0.5x^2 & \\text{if } |x| < 1 \\\\ |x| - 0.5 & \\text{otherwise} \\end{cases} \\] \\[\\text{FL}(p_t) = -\\alpha_t (1 - p_t)^\\gamma \\log(p_t)\\]

\\[\\text{PQ} = \\underbrace{\\frac{\\sum_{(p,g) \\in \\text{TP}} \\text{IoU}(p,g)}{|\\text{TP}|}}_{\\text{SQ}} \\times \\underbrace{\\frac{|\\text{TP}|}{|\\text{TP}| + \\frac{1}{2}|\\text{FP}| + \\frac{1}{2}|\\text{FN}|}}_{\\text{RQ}}\\] "},{"location":"chapter%2008%3A%20computer%20vision/03.%20object%20detection%20and%20segmentation/#colabnotebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u4ece\u5934\u5b9e\u73b0IoU\u8ba1\u7b97\u548c\u975e\u6781\u5927\u503c\u6291\u5236\u3002\u5bf9\u4e00\u7ec4\u91cd\u53e0\u7684\u8fb9\u754c\u6846\u5e94\u7528NMS\u5e76\u53ef\u89c6\u5316\u7ed3\u679c\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\nimport matplotlib.patches as patches\n\ndef compute_iou(box1, box2):\n    \"\"\"\u8ba1\u7b97\u4e24\u4e2a\u6846[x1, y1, x2, y2]\u4e4b\u95f4\u7684IoU\u3002\"\"\"\n    x1 = jnp.maximum(box1[0], box2[0])\n    y1 = jnp.maximum(box1[1], box2[1])\n    x2 = jnp.minimum(box1[2], box2[2])\n    y2 = jnp.minimum(box1[3], box2[3])\n\n    intersection = jnp.maximum(0, x2 - x1) * jnp.maximum(0, y2 - y1)\n    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])\n    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])\n    union = area1 + area2 - intersection\n\n    return intersection / (union + 1e-6)\n\ndef nms(boxes, scores, iou_threshold=0.5):\n    \"\"\"\u975e\u6781\u5927\u503c\u6291\u5236\u3002\"\"\"\n    order = jnp.argsort(-scores)  # \u6309\u7f6e\u4fe1\u5ea6\u964d\u5e8f\u6392\u5217\n    keep = []\n\n    remaining = list(range(len(scores)))\n    order_list = order.tolist()\n\n    while order_list:\n        idx = order_list[0]\n        keep.append(idx)\n        order_list = order_list[1:]\n\n        new_order = []\n        for j in order_list:\n            iou = compute_iou(boxes[idx], boxes[j])\n            if iou < iou_threshold:\n                new_order.append(j)\n        order_list = new_order\n\n    return keep\n\n# \u793a\u4f8b\uff1a\u540c\u4e00\u7269\u4f53\u7684\u91cd\u53e0\u68c0\u6d4b\nboxes = jnp.array([\n    [50, 60, 150, 160],   # \u9ad8\u7f6e\u4fe1\u5ea6\n    [55, 65, 155, 165],   # \u91cd\u53e0\u7684\u91cd\u590d\u6846\n    [52, 58, 148, 158],   # \u91cd\u53e0\u7684\u91cd\u590d\u6846\n    [200, 100, 300, 200], # \u4e0d\u540c\u7269\u4f53\n    [205, 105, 305, 205], # \u91cd\u53e0\u7684\u91cd\u590d\u6846\n])\nscores = jnp.array([0.95, 0.80, 0.70, 0.90, 0.60])\n\nkeep = nms(boxes, scores, iou_threshold=0.5)\n\nfig, axes = plt.subplots(1, 2, figsize=(14, 5))\ncolors = ['#3498db', '#e74c3c', '#27ae60', '#9b59b6', '#f39c12']\n\nfor ax, title, indices in zip(axes, ['NMS\u4e4b\u524d', 'NMS\u4e4b\u540e'],\n         
                      [range(len(boxes)), keep]):\n    ax.set_xlim(0, 400); ax.set_ylim(0, 300)\n    ax.set_aspect('equal'); ax.invert_yaxis()\n    ax.set_title(title)\n    for i in indices:\n        b = boxes[i]\n        rect = patches.Rectangle((b[0], b[1]), b[2]-b[0], b[3]-b[1],\n                                  linewidth=2, edgecolor=colors[i],\n                                  facecolor='none')\n        ax.add_patch(rect)\n        ax.text(b[0], b[1]-5, f'{scores[i]:.2f}', color=colors[i], fontsize=10)\n\nplt.tight_layout(); plt.show()\nprint(f\"NMS\u540e\u4fdd\u7559\u4e86{len(keep)}\u4e2a\u6846\uff0c\u5171{len(boxes)}\u4e2a\")\n

  2. \u5b9e\u73b0\u4e00\u4e2a\u7b80\u5316\u7684\u533a\u57df\u63d0\u8bae\u7f51\u7edc\uff08RPN\uff09\u3002\u7ed9\u5b9a\u4e00\u4e2a\u7279\u5f81\u56fe\uff0c\u751f\u6210\u5177\u6709\u591a\u79cd\u5c3a\u5ea6\u548c\u957f\u5bbd\u6bd4\u7684\u951a\u6846\uff0c\u5e76\u9884\u6d4b\u7269\u4f53\u6027\u5206\u6570\u548c\u8fb9\u754c\u6846\u504f\u79fb\u91cf\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\nimport matplotlib.patches as patches\n\ndef generate_anchors(feature_h, feature_w, stride, scales, ratios):\n    \"\"\"\u4e3a\u7279\u5f81\u56fe\u4e0a\u7684\u6bcf\u4e2a\u4f4d\u7f6e\u751f\u6210\u951a\u6846\u3002\"\"\"\n    anchors = []\n    for y in range(feature_h):\n        for x in range(feature_w):\n            cx = (x + 0.5) * stride\n            cy = (y + 0.5) * stride\n            for s in scales:\n                for r in ratios:\n                    w = s * jnp.sqrt(r)\n                    h = s / jnp.sqrt(r)\n                    anchors.append([cx - w/2, cy - h/2, cx + w/2, cy + h/2])\n    return jnp.array(anchors)\n\ndef rpn_forward(feature_map, params):\n    \"\"\"\u7b80\u5316\u7248RPN\uff1a\u9884\u6d4b\u6bcf\u4e2a\u951a\u6846\u7684\u7269\u4f53\u6027\u548c\u6846\u504f\u79fb\u91cf\u3002\"\"\"\n    H, W, C = feature_map.shape\n    n_anchors = params['cls_w'].shape[1]\n\n    # \u5728\u7279\u5f81\u56fe\u4e0a\u6ed1\u52a81x1\u5377\u79ef\uff08\u7b80\u5316\u7248\uff09\n    cls_scores = feature_map.reshape(-1, C) @ params['cls_w']  # (H*W, n_anchors)\n    box_offsets = feature_map.reshape(-1, C) @ params['reg_w']  # (H*W, n_anchors*4)\n\n    cls_scores = jax.nn.sigmoid(cls_scores)\n    return cls_scores.ravel(), box_offsets.reshape(-1, 4)\n\n# \u8bbe\u7f6e\nfeature_h, feature_w, channels = 4, 4, 16\nstride = 16  # \u6bcf\u4e2a\u7279\u5f81\u56fe\u5355\u5143\u683c\u8986\u76d616x16\u50cf\u7d20\nscales = [32, 64, 128]\nratios = [0.5, 1.0, 2.0]\nn_anchors_per_pos = len(scales) * len(ratios)\n\nkey = jax.random.PRNGKey(42)\nk1, k2, k3 = jax.random.split(key, 3)\n\nfeature_map = jax.random.normal(k1, (feature_h, feature_w, channels))\nparams = {\n    'cls_w': jax.random.normal(k2, (channels, n_anchors_per_pos)) * 0.01,\n    'reg_w': jax.random.normal(k3, (channels, n_anchors_per_pos * 4)) * 0.01,\n}\n\nanchors = generate_anchors(feature_h, feature_w, stride, scales, ratios)\nscores, offsets = 
rpn_forward(feature_map, params)\n\nprint(f\"\u7279\u5f81\u56fe\uff1a{feature_h}x{feature_w}\uff0c\u6b65\u5e45={stride}\")\nprint(f\"\u6bcf\u4e2a\u4f4d\u7f6e\u7684\u951a\u6846\u6570\uff1a{n_anchors_per_pos}\")\nprint(f\"\u951a\u6846\u603b\u6570\uff1a{len(anchors)}\")\nprint(f\"\u7269\u4f53\u6027\u5206\u6570\u5f62\u72b6\uff1a{scores.shape}\")\nprint(f\"\u8fb9\u754c\u6846\u504f\u79fb\u91cf\u5f62\u72b6\uff1a{offsets.shape}\")\n\n# \u53ef\u89c6\u5316\u4e00\u4e2a\u4f4d\u7f6e\u7684\u951a\u6846\nfig, ax = plt.subplots(figsize=(6, 6))\nimg_size = feature_h * stride\nax.set_xlim(0, img_size); ax.set_ylim(0, img_size)\nax.invert_yaxis(); ax.set_aspect('equal')\n\npos_idx = feature_h // 2 * feature_w + feature_w // 2  # \u4e2d\u5fc3\u4f4d\u7f6e\ncolors = ['#3498db', '#e74c3c', '#27ae60']\nfor i, s in enumerate(scales):\n    for j, r in enumerate(ratios):\n        idx = pos_idx * n_anchors_per_pos + i * len(ratios) + j\n        a = anchors[idx]\n        rect = patches.Rectangle((a[0], a[1]), a[2]-a[0], a[3]-a[1],\n                                  linewidth=1.5, edgecolor=colors[i],\n                                  facecolor='none', linestyle=['--', '-', ':'][j])\n        ax.add_patch(rect)\n\nax.scatter([img_size/2], [img_size/2], c='red', s=50, zorder=5)\nax.set_title(f'\u4e2d\u5fc3\u4f4d\u7f6e\u7684\u951a\u6846\\n3\u4e2a\u5c3a\u5ea6 \u00d7 3\u4e2a\u6bd4\u4f8b = {n_anchors_per_pos}')\nax.grid(True, alpha=0.3)\nplt.tight_layout(); plt.show()\n

  3. \u5b9e\u73b0\u4e00\u4e2a\u7b80\u5316\u7248\u7684\u4e00\u7ef4U-Net\u7f16\u7801\u5668-\u89e3\u7801\u5668\uff0c\u5e26\u6709\u8df3\u8dc3\u8fde\u63a5\uff0c\u7528\u4e8e\u4e00\u7ef4\u5206\u5272\uff08\u4e00\u7ef4\u4fe1\u53f7\u7684\u4e8c\u503c\u6807\u6ce8\uff09\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef conv1d_same(x, kernel):\n    \"\"\"\u5177\u6709\u76f8\u540c\u586b\u5145\u7684\u4e00\u7ef4\u5377\u79ef\u3002\"\"\"\n    k = len(kernel)\n    pad = k // 2\n    x_pad = jnp.pad(x, pad, mode='edge')\n    n = len(x)\n    out = jnp.zeros(n)\n    for i in range(n):\n        out = out.at[i].set(jnp.sum(x_pad[i:i+k] * kernel))\n    return out\n\ndef downsample(x):\n    return x[::2]\n\ndef upsample(x, target_len):\n    return jnp.interp(jnp.linspace(0, 1, target_len), jnp.linspace(0, 1, len(x)), x)\n\ndef unet_1d(x, params):\n    \"\"\"\u7b80\u5316\u7248\u4e00\u7ef4U-Net\uff0c\u5305\u542b2\u4e2a\u7f16\u7801\u5668/\u89e3\u7801\u5668\u5c42\u7ea7\u3002\"\"\"\n    # \u7f16\u7801\u5668\n    e1 = jnp.maximum(0, conv1d_same(x, params['enc1']))\n    e1_down = downsample(e1)\n\n    e2 = jnp.maximum(0, conv1d_same(e1_down, params['enc2']))\n    e2_down = downsample(e2)\n\n    # \u74f6\u9888\u5c42\n    bottleneck = jnp.maximum(0, conv1d_same(e2_down, params['bottleneck']))\n\n    # \u5e26\u8df3\u8dc3\u8fde\u63a5\u7684\u89e3\u7801\u5668\n    d2_up = upsample(bottleneck, len(e2))\n    d2 = jnp.maximum(0, conv1d_same(d2_up + e2, params['dec2']))  # \u8df3\u8dc3\u8fde\u63a5\n\n    d1_up = upsample(d2, len(e1))\n    d1 = conv1d_same(d1_up + e1, params['dec1'])  # \u8df3\u8dc3\u8fde\u63a5\n\n    return jax.nn.sigmoid(d1)\n\n# \u521b\u5efa\u5e26\u6709\u6807\u6ce8\u533a\u57df\u7684\u4fe1\u53f7\nn = 128\nt = jnp.linspace(0, 4 * jnp.pi, n)\nsignal = jnp.sin(t) + 0.5 * jnp.sin(3 * t)\nlabels = (signal > 0.5).astype(jnp.float32)  # \u4e8c\u503c\u5206\u5272\u76ee\u6807\n\nkey = jax.random.PRNGKey(42)\nkeys = jax.random.split(key, 5)\nparams = {\n    'enc1': jax.random.normal(keys[0], (5,)) * 0.3,\n    'enc2': jax.random.normal(keys[1], (5,)) * 0.3,\n    'bottleneck': jax.random.normal(keys[2], (3,)) * 0.3,\n    'dec2': jax.random.normal(keys[3], (5,)) * 0.3,\n    'dec1': jax.random.normal(keys[4], (5,)) * 
0.3,\n}\n\ndef loss_fn(params, signal, labels):\n    pred = unet_1d(signal, params)\n    return -jnp.mean(labels * jnp.log(pred + 1e-7) + (1 - labels) * jnp.log(1 - pred + 1e-7))\n\ngrad_fn = jax.jit(jax.grad(loss_fn))\nlr = 0.05\n\nfor step in range(500):\n    grads = grad_fn(params, signal, labels)\n    params = {k: params[k] - lr * grads[k] for k in params}\n\npred = unet_1d(signal, params)\n\nfig, axes = plt.subplots(3, 1, figsize=(12, 7), sharex=True)\naxes[0].plot(t, signal, color='#3498db', linewidth=1.5)\naxes[0].set_title('\u8f93\u5165\u4fe1\u53f7'); axes[0].set_ylabel('\u503c')\n\naxes[1].fill_between(t, 0, labels, alpha=0.3, color='#27ae60')\naxes[1].set_title('\u771f\u5b9e\u6807\u6ce8'); axes[1].set_ylabel('\u6807\u7b7e')\n\naxes[2].plot(t, pred, color='#e74c3c', linewidth=1.5)\naxes[2].fill_between(t, 0, (pred > 0.5).astype(float), alpha=0.2, color='#e74c3c')\naxes[2].set_title('U-Net\u9884\u6d4b'); axes[2].set_ylabel('\u6982\u7387')\naxes[2].set_xlabel('t')\n\nplt.tight_layout(); plt.show()\nprint(f\"\u6700\u7ec8\u635f\u5931\uff1a{loss_fn(params, signal, labels):.4f}\")\nprint(f\"\u50cf\u7d20\u51c6\u786e\u7387\uff1a{jnp.mean((pred > 0.5) == labels):.2%}\")\n

"},{"location":"chapter%2008%3A%20computer%20vision/04.%20vision%20transformers%20and%20generation/","title":"\u89c6\u89c9Transformer\u4e0e\u751f\u6210\u6a21\u578b","text":"

\u89c6\u89c9Transformer\u5c06\u81ea\u6ce8\u610f\u529b\u5e94\u7528\u4e8e\u56fe\u50cf\u5757\uff0c\u901a\u8fc7\u6570\u636e\u9a71\u52a8\u7684\u7a7a\u95f4\u5b66\u4e60\u6311\u6218\u4e86CNN\u7684\u4e3b\u5bfc\u5730\u4f4d\u3002\u672c\u6587\u6db5\u76d6ViT\u3001DeiT\u3001Swin Transformer\u3001\u57fa\u4e8eGAN\u7684\u56fe\u50cf\u751f\u6210\uff08StyleGAN\uff09\u3001VAE\u548c\u6269\u6563\u6a21\u578b\uff08DDPM\u3001Stable Diffusion\uff09\uff0c\u4ee5\u53ca\u8d85\u5206\u8fa8\u7387\u548c\u795e\u7ecf\u98ce\u683c\u8fc1\u79fb\u3002

\\[\\ell_{i,j} = -\\log \\frac{\\exp(\\text{sim}(z_i, z_j) / \\tau)}{\\sum_{k \\neq i} \\exp(\\text{sim}(z_i, z_k) / \\tau)}\\] \\[\\min_G \\max_D \\; \\mathbb{E}_{x \\sim p_{\\text{data}}}[\\log D(x)] + \\mathbb{E}_{z \\sim p(z)}[\\log(1 - D(G(z)))]\\] \\[\\text{AdaIN}(x, y) = y_{s} \\cdot \\frac{x - \\mu(x)}{\\sigma(x)} + y_{b}\\] \\[q(x_t | x_{t-1}) = \\mathcal{N}(x_t; \\sqrt{1 - \\beta_t} \\, x_{t-1}, \\beta_t I)\\] \\[x_t = \\sqrt{\\bar{\\alpha}_t} \\, x_0 + \\sqrt{1 - \\bar{\\alpha}_t} \\, \\epsilon, \\quad \\epsilon \\sim \\mathcal{N}(0, I)\\] \\[\\mathcal{L} = \\mathbb{E}_{t, x_0, \\epsilon}\\left[\\|\\epsilon - \\epsilon_\\theta(x_t, t)\\|^2\\right]\\]

\\[\\hat{\\epsilon} = \\epsilon_\\theta(x_t, \\varnothing) + s \\cdot (\\epsilon_\\theta(x_t, c) - \\epsilon_\\theta(x_t, \\varnothing))\\] \\[\\frac{dx}{dt} = v_\\theta(x, t), \\quad t \\in [0, 1]\\] \\[\\mathcal{L} = \\mathbb{E}_{t, x_0, x_1} \\left[\\|v_\\theta(x_t, t) - (x_1 - x_0)\\|^2\\right]\\] "},{"location":"chapter%2008%3A%20computer%20vision/04.%20vision%20transformers%20and%20generation/#colabnotebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u4ece\u5934\u5b9e\u73b0ViT\u56fe\u50cf\u5757\u5d4c\u5165\u3002\u5c06\u56fe\u50cf\u5206\u5272\u6210\u56fe\u50cf\u5757\uff0c\u5c55\u5e73\uff0c\u6295\u5f71\u5230\u6a21\u578b\u7ef4\u5ea6\uff0c\u6dfb\u52a0\u4f4d\u7f6e\u5d4c\u5165\uff0c\u5e76\u524d\u7f6e[CLS]\u6807\u8bb0\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef create_patch_embedding(image, patch_size, d_model, params):\n    \"\"\"\u5c06\u56fe\u50cf\u8f6c\u6362\u4e3a\u56fe\u50cf\u5757\u5d4c\u5165\u5e8f\u5217\u3002\"\"\"\n    H, W, C = image.shape\n    n_patches_h = H // patch_size\n    n_patches_w = W // patch_size\n    n_patches = n_patches_h * n_patches_w\n\n    # \u63d0\u53d6\u56fe\u50cf\u5757\n    patches = []\n    for i in range(n_patches_h):\n        for j in range(n_patches_w):\n            patch = image[i*patch_size:(i+1)*patch_size,\n                          j*patch_size:(j+1)*patch_size, :]\n            patches.append(patch.ravel())\n    patches = jnp.stack(patches)  # (N, P*P*C)\n\n    # \u7ebf\u6027\u6295\u5f71\u5230d_model\n    embeddings = patches @ params['proj_w'] + params['proj_b']  # (N, d_model)\n\n    # \u524d\u7f6eCLS\u6807\u8bb0\n    cls_token = params['cls_token']  # (1, d_model)\n    embeddings = jnp.concatenate([cls_token, embeddings], axis=0)  # (N+1, d_model)\n\n    # \u6dfb\u52a0\u4f4d\u7f6e\u5d4c\u5165\n    embeddings = embeddings + params['pos_embed']  # (N+1, d_model)\n\n    return embeddings, patches\n\n# \u8bbe\u7f6e\nH, W, C = 32, 32, 3\npatch_size = 8\nd_model = 64\nn_patches = (H // patch_size) * (W // patch_size)  # 16\n\nkey = jax.random.PRNGKey(42)\nkeys = jax.random.split(key, 5)\n\n# \u521b\u5efa\u5177\u6709\u4e0d\u540c\u8c61\u9650\u7684\u5408\u6210\u56fe\u50cf\nimage = jnp.zeros((H, W, C))\nimage = image.at[:16, :16, 0].set(1.0)   # \u7ea2\u8272 \u5de6\u4e0a\nimage = image.at[:16, 16:, 1].set(1.0)   # \u7eff\u8272 \u53f3\u4e0a\nimage = image.at[16:, :16, 2].set(1.0)   # \u84dd\u8272 \u5de6\u4e0b\nimage = image.at[16:, 16:, :2].set(1.0)  # \u9ec4\u8272 \u53f3\u4e0b\n\nparams = {\n    'proj_w': jax.random.normal(keys[0], (patch_size**2 * C, d_model)) * 0.02,\n    'proj_b': jnp.zeros(d_model),\n    'cls_token': jax.random.normal(keys[1], (1, d_model)) * 0.02,\n    'pos_embed': jax.random.normal(keys[2], 
(n_patches + 1, d_model)) * 0.02,\n}\n\nembeddings, patches = create_patch_embedding(image, patch_size, d_model, params)\n\nprint(f\"\u56fe\u50cf\u5f62\u72b6: {image.shape}\")\nprint(f\"\u56fe\u50cf\u5757\u5927\u5c0f: {patch_size}x{patch_size}\")\nprint(f\"\u56fe\u50cf\u5757\u6570\u91cf: {n_patches}\")\nprint(f\"\u56fe\u50cf\u5757\u5411\u91cf\u957f\u5ea6: {patch_size**2 * C}\")\nprint(f\"\u5d4c\u5165\u5f62\u72b6: {embeddings.shape}  (CLS + {n_patches} \u4e2a\u56fe\u50cf\u5757)\")\n\n# \u53ef\u89c6\u5316\u56fe\u50cf\u5757\nfig, axes = plt.subplots(2, 5, figsize=(14, 6))\naxes[0, 0].imshow(image); axes[0, 0].set_title('\u5b8c\u6574\u56fe\u50cf'); axes[0, 0].axis('off')\nfor idx in range(min(9, n_patches)):\n    ax = axes[(idx+1) // 5, (idx+1) % 5]\n    patch_img = patches[idx].reshape(patch_size, patch_size, C)\n    ax.imshow(patch_img); ax.set_title(f'\u56fe\u50cf\u5757 {idx}'); ax.axis('off')\nplt.suptitle('ViT \u56fe\u50cf\u5757\u5206\u89e3')\nplt.tight_layout(); plt.show()\n

  2. \u5b9e\u73b0\u4e00\u4e2a\u7b80\u5355\u7684GAN\u8bad\u7ec3\u5faa\u73af\u3002\u5728\u4e8c\u7ef4\u6570\u636e\u4e0a\u8bad\u7ec3\u751f\u6210\u5668\u548c\u5224\u522b\u5668\uff0c\u5e76\u53ef\u89c6\u5316\u751f\u6210\u5206\u5e03\u9010\u6e10\u6536\u655b\u5230\u771f\u5b9e\u5206\u5e03\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef generator(z, params):\n    h = jnp.tanh(z @ params['g_w1'] + params['g_b1'])\n    h = jnp.tanh(h @ params['g_w2'] + params['g_b2'])\n    return h @ params['g_w3'] + params['g_b3']\n\ndef discriminator(x, params):\n    h = jax.nn.leaky_relu(x @ params['d_w1'] + params['d_b1'], 0.2)\n    h = jax.nn.leaky_relu(h @ params['d_w2'] + params['d_b2'], 0.2)\n    return jax.nn.sigmoid(h @ params['d_w3'] + params['d_b3'])\n\ndef init_params(key):\n    keys = jax.random.split(key, 6)\n    z_dim, h_dim, data_dim = 2, 32, 2\n    scale = 0.1\n    return {\n        'g_w1': jax.random.normal(keys[0], (z_dim, h_dim)) * scale,\n        'g_b1': jnp.zeros(h_dim),\n        'g_w2': jax.random.normal(keys[1], (h_dim, h_dim)) * scale,\n        'g_b2': jnp.zeros(h_dim),\n        'g_w3': jax.random.normal(keys[2], (h_dim, data_dim)) * scale,\n        'g_b3': jnp.zeros(data_dim),\n        'd_w1': jax.random.normal(keys[3], (data_dim, h_dim)) * scale,\n        'd_b1': jnp.zeros(h_dim),\n        'd_w2': jax.random.normal(keys[4], (h_dim, h_dim)) * scale,\n        'd_b2': jnp.zeros(h_dim),\n        'd_w3': jax.random.normal(keys[5], (h_dim, 1)) * scale,\n        'd_b3': jnp.zeros(1),\n    }\n\ndef d_loss(params, real_data, fake_data):\n    real_score = discriminator(real_data, params)\n    fake_score = discriminator(fake_data, params)\n    return -jnp.mean(jnp.log(real_score + 1e-7) + jnp.log(1 - fake_score + 1e-7))\n\ndef g_loss(params, z):\n    # Generate inside the loss so gradients flow back into the generator weights\n    fake_score = discriminator(generator(z, params), params)\n    return -jnp.mean(jnp.log(fake_score + 1e-7))\n\n# \u771f\u5b9e\u6570\u636e\uff1a\u73af\u5f62\u5206\u5e03\nkey = jax.random.PRNGKey(42)\ntheta = jax.random.uniform(key, (512,)) * 2 * jnp.pi\nreal_data = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=1)\nreal_data = real_data + jax.random.normal(key, real_data.shape) * 0.05\n\nparams = init_params(jax.random.PRNGKey(0))\nd_grad = jax.grad(d_loss)\ng_grad = jax.grad(g_loss)\nlr = 0.001\n\nsnapshots = []\nfor step in range(3000):\n    key, k1 = jax.random.split(key)\n    z = jax.random.normal(k1, (512, 2))\n    fake_data = generator(z, params)\n\n    # \u66f4\u65b0\u5224\u522b\u5668\n    grads = d_grad(params, real_data, fake_data)\n    for k in ['d_w1', 'd_b1', 'd_w2', 'd_b2', 'd_w3', 'd_b3']:\n        params[k] = params[k] - lr * grads[k]\n\n    # \u66f4\u65b0\u751f\u6210\u5668\n    fake_data = generator(z, params)\n    grads = g_grad(params, z)\n    for k in ['g_w1', 'g_b1', 'g_w2', 'g_b2', 'g_w3', 'g_b3']:\n        params[k] = params[k] - lr * grads[k]\n\n    if step in [0, 500, 1500, 2999]:\n        snapshots.append((step, fake_data.copy()))\n\nfig, axes = plt.subplots(1, 4, figsize=(16, 4))\nfor ax, (step, fake) in zip(axes, snapshots):\n    ax.scatter(real_data[:, 0], real_data[:, 1], s=5, alpha=0.3, c='#3498db', label='\u771f\u5b9e')\n    ax.scatter(fake[:, 0], fake[:, 1], s=5, alpha=0.3, c='#e74c3c', label='\u751f\u6210')\n    ax.set_title(f'\u6b65\u9aa4 {step}'); ax.set_xlim(-2, 2); ax.set_ylim(-2, 2)\n    ax.set_aspect('equal'); ax.legend(markerscale=3)\nplt.suptitle('GAN\u8bad\u7ec3\uff1a\u751f\u6210\u5668\u5b66\u4e60\u73af\u5f62\u5206\u5e03')\nplt.tight_layout(); plt.show()\n

  3. \u5b9e\u73b0\u6269\u6563\u524d\u5411\u8fc7\u7a0b\uff1a\u5728\u4e0d\u540c\u65f6\u95f4\u6b65\u5411\u56fe\u50cf\u6dfb\u52a0\u566a\u58f0\uff0c\u5e76\u53ef\u89c6\u5316\u9010\u6b65\u7834\u574f\u8fc7\u7a0b\u3002\u7136\u540e\u5b9e\u73b0\u5355\u6b65\u53bb\u566a\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef noise_schedule(T, beta_start=0.0001, beta_end=0.02):\n    \"\"\"\u7ebf\u6027\u566a\u58f0\u8c03\u5ea6\u3002\"\"\"\n    betas = jnp.linspace(beta_start, beta_end, T)\n    alphas = 1.0 - betas\n    alpha_bars = jnp.cumprod(alphas)\n    return betas, alphas, alpha_bars\n\ndef forward_diffusion(x0, t, alpha_bars, key):\n    \"\"\"\u5728\u65f6\u95f4\u6b65t\u5411x0\u6dfb\u52a0\u566a\u58f0\u3002\"\"\"\n    alpha_bar_t = alpha_bars[t]\n    noise = jax.random.normal(key, x0.shape)\n    xt = jnp.sqrt(alpha_bar_t) * x0 + jnp.sqrt(1 - alpha_bar_t) * noise\n    return xt, noise\n\n# \u521b\u5efa\u7b80\u5355\u76842D\"\u56fe\u50cf\"\uff08\u68cb\u76d8\u683c\uff09\nimg = jnp.zeros((32, 32))\nfor i in range(4):\n    for j in range(4):\n        if (i + j) % 2 == 0:\n            img = img.at[i*8:(i+1)*8, j*8:(j+1)*8].set(1.0)\n\nT = 1000\nbetas, alphas, alpha_bars = noise_schedule(T)\n\n# \u53ef\u89c6\u5316\u524d\u5411\u8fc7\u7a0b\ntimesteps = [0, 50, 200, 500, 999]\nkey = jax.random.PRNGKey(42)\n\nfig, axes = plt.subplots(1, len(timesteps), figsize=(16, 3.5))\nfor ax, t in zip(axes, timesteps):\n    key, subkey = jax.random.split(key)\n    xt, noise = forward_diffusion(img, t, alpha_bars, subkey)\n    ax.imshow(xt, cmap='gray', vmin=-2, vmax=2)\n    ax.set_title(f't={t}\\n$\\\\bar{{\\\\alpha}}$={alpha_bars[t]:.3f}')\n    ax.axis('off')\nplt.suptitle('\u6269\u6563\u524d\u5411\u8fc7\u7a0b\uff1a\u9010\u6b65\u6dfb\u52a0\u566a\u58f0')\nplt.tight_layout(); plt.show()\n\n# \u7b80\u5355\u53bb\u566a\uff1a\u8bad\u7ec3\u5c0f\u578b\u7f51\u7edc\u5728t=200\u65f6\u9884\u6d4b\u566a\u58f0\nt_denoise = 200\nkey, k1 = jax.random.split(key)\nxt, true_noise = forward_diffusion(img, t_denoise, alpha_bars, k1)\n\n# \u5c0f\u578b\"\u53bb\u566a\u5668\"\uff1a\u4ec5\u5b66\u4e60\u6052\u5b9a\u7684\u566a\u58f0\u4f30\u8ba1\uff08\u7528\u4e8e\u6f14\u793a\uff09\nnoise_estimate = jnp.zeros_like(img)\nlr = 0.01\nfor step in range(100):\n    
residual = noise_estimate - true_noise\n    noise_estimate = noise_estimate - lr * residual\n\n# \u53cd\u5411\u4e00\u6b65\nalpha_bar_t = alpha_bars[t_denoise]\nx_denoised = (xt - jnp.sqrt(1 - alpha_bar_t) * noise_estimate) / jnp.sqrt(alpha_bar_t)\n\nfig, axes = plt.subplots(1, 3, figsize=(12, 4))\naxes[0].imshow(img, cmap='gray'); axes[0].set_title('\u539f\u59cb $x_0$'); axes[0].axis('off')\naxes[1].imshow(xt, cmap='gray', vmin=-2, vmax=2)\naxes[1].set_title(f'\u542b\u566a $x_{{200}}$'); axes[1].axis('off')\naxes[2].imshow(x_denoised, cmap='gray')\naxes[2].set_title('\u53bb\u566a\u540e\uff08\u5355\u6b65\uff09'); axes[2].axis('off')\nplt.tight_layout(); plt.show()\n\nmse = jnp.mean((x_denoised - img)**2)\nprint(f\"\u53bb\u566aMSE: {mse:.4f}\")\n

"},{"location":"chapter%2008%3A%20computer%20vision/05.%20video%20and%203D%20vision/","title":"\u89c6\u9891\u4e0e3D\u89c6\u89c9","text":"

\u89c6\u9891\u4e0e3D\u89c6\u89c9\u5c06\u56fe\u50cf\u7406\u89e3\u6269\u5c55\u5230\u65f6\u95f4\u57df\u548c\u7a7a\u95f4\u57df\u3002\u672c\u6587\u6db5\u76d6\u5149\u6d41\u3001\u89c6\u9891\u5206\u7c7b\uff083D\u5377\u79ef\u7f51\u7edc\u3001TimeSformer\uff09\u3001\u76ee\u6807\u8ddf\u8e2a\uff08SORT\u3001DeepSORT\uff09\u3001\u52a8\u4f5c\u8bc6\u522b\u3001\u6df1\u5ea6\u4f30\u8ba1\uff08\u5355\u76ee\u4e0e\u7acb\u4f53\uff09\u3001\u70b9\u4e91\u3001\u795e\u7ecf\u8f90\u5c04\u573a\uff08NeRF\uff09\u548c3D\u9ad8\u65af\u6cfc\u6e85\u3002

\\[I(x + u\\delta t, \\, y + v\\delta t, \\, t + \\delta t) = I(x, y, t)\\] \\[I_x u + I_y v + I_t = 0\\] \\[ \\begin{bmatrix} u \\\\ v \\end{bmatrix} = \\begin{bmatrix} \\sum I_x^2 & \\sum I_x I_y \\\\ \\sum I_x I_y & \\sum I_y^2 \\end{bmatrix}^{-1} \\begin{bmatrix} -\\sum I_x I_t \\\\ -\\sum I_y I_t \\end{bmatrix} \\] \\[Z = \\frac{f \\cdot b}{d}\\] \\[F_\\theta: (x, y, z, \\theta, \\phi) \\to (r, g, b, \\sigma)\\] \\[C(\\mathbf{r}) = \\int_{t_n}^{t_f} T(t) \\cdot \\sigma(\\mathbf{r}(t)) \\cdot \\mathbf{c}(\\mathbf{r}(t), \\mathbf{d}) \\, dt\\] \\[\\hat{C} = \\sum_{i=1}^{N} T_i \\cdot (1 - \\exp(-\\sigma_i \\delta_i)) \\cdot c_i\\] "},{"location":"chapter%2008%3A%20computer%20vision/05.%20video%20and%203D%20vision/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216Notebook\uff09","text":"
  1. \u4ece\u5934\u5b9e\u73b0Lucas-Kanade\u5149\u6d41\u7b97\u6cd5\u3002\u8ba1\u7b97\u4e00\u4e2a\u65b9\u5757\u5411\u53f3\u79fb\u52a8\u7684\u4e24\u5e27\u5408\u6210\u56fe\u50cf\u4e4b\u95f4\u7684\u5149\u6d41\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef lucas_kanade(frame1, frame2, window_size=5):\n    \"\"\"Lucas-Kanade\u5149\u6d41\u3002\"\"\"\n    # \u8ba1\u7b97\u68af\u5ea6\n    Ix = jnp.zeros_like(frame1)\n    Iy = jnp.zeros_like(frame1)\n    It = frame2 - frame1\n\n    # Sobel\u98ce\u683c\u68af\u5ea6\n    Ix = Ix.at[1:-1, :].set((frame1[2:, :] - frame1[:-2, :]) / 2)\n    Iy = Iy.at[:, 1:-1].set((frame1[:, 2:] - frame1[:, :-2]) / 2)\n\n    H, W = frame1.shape\n    half_w = window_size // 2\n    u = jnp.zeros_like(frame1)\n    v = jnp.zeros_like(frame1)\n\n    for i in range(half_w, H - half_w):\n        for j in range(half_w, W - half_w):\n            Ix_win = Ix[i-half_w:i+half_w+1, j-half_w:j+half_w+1].ravel()\n            Iy_win = Iy[i-half_w:i+half_w+1, j-half_w:j+half_w+1].ravel()\n            It_win = It[i-half_w:i+half_w+1, j-half_w:j+half_w+1].ravel()\n\n            A = jnp.stack([Ix_win, Iy_win], axis=1)\n            ATA = A.T @ A\n            ATb = -A.T @ It_win\n\n            # \u68c0\u67e5\u7cfb\u7edf\u662f\u5426\u826f\u6001\n            det = ATA[0,0] * ATA[1,1] - ATA[0,1] * ATA[1,0]\n            if jnp.abs(det) > 1e-6:\n                flow = jnp.linalg.solve(ATA, ATb)\n                u = u.at[i, j].set(flow[0])\n                v = v.at[i, j].set(flow[1])\n\n    return u, v\n\n# \u521b\u5efa\u4e24\u5e27\uff1a\u4e00\u4e2a\u5411\u53f3\u79fb\u52a8\u7684\u767d\u8272\u65b9\u5757\nframe1 = jnp.zeros((64, 64))\nframe1 = frame1.at[20:40, 15:35].set(1.0)\n\nframe2 = jnp.zeros((64, 64))\nframe2 = frame2.at[20:40, 20:40].set(1.0)  # \u5411\u53f3\u79fb\u52a85\u4e2a\u50cf\u7d20\n\nu, v = lucas_kanade(frame1, frame2, window_size=7)\n\n# \u53ef\u89c6\u5316\nfig, axes = plt.subplots(1, 3, figsize=(14, 4))\naxes[0].imshow(frame1, cmap='gray'); axes[0].set_title('\u5e271'); axes[0].axis('off')\naxes[1].imshow(frame2, cmap='gray'); axes[1].set_title('\u5e272'); axes[1].axis('off')\n\n# 
\u5149\u6d41\u7684\u7bad\u77e2\u56fe\uff08\u4e3a\u6e05\u6670\u8d77\u89c1\u964d\u91c7\u6837\uff09\nstep = 4\nY, X = jnp.mgrid[0:64:step, 0:64:step]\naxes[2].imshow(frame1, cmap='gray', alpha=0.5)\naxes[2].quiver(X, Y, u[::step, ::step], v[::step, ::step],\n               color='#e74c3c', scale=50, width=0.005)\naxes[2].set_title('\u5149\u6d41'); axes[2].axis('off')\n\nplt.tight_layout(); plt.show()\n\n# \u68c0\u67e5\u8fd0\u52a8\u533a\u57df\u7684\u5e73\u5747\u5149\u6d41\nregion_u = u[20:40, 15:35]\nprint(f\"\u7269\u4f53\u533a\u57df\u7684\u5e73\u5747\u6c34\u5e73\u5149\u6d41: {region_u[region_u != 0].mean():.2f} \u50cf\u7d20\")\n

  2. \u5b9e\u73b0\u4e00\u4e2a\u7528\u4e8e2D\u76ee\u6807\u8ddf\u8e2a\u7684\u7b80\u5355\u5361\u5c14\u66fc\u6ee4\u6ce2\u5668\u3002\u6a21\u62df\u4e00\u4e2a\u5e26\u566a\u58f0\u7684\u8f68\u8ff9\uff0c\u5e76\u5c55\u793a\u5361\u5c14\u66fc\u6ee4\u6ce2\u5668\u5982\u4f55\u5e73\u6ed1\u4f30\u8ba1\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef kalman_predict(x, P, F, Q):\n    \"\"\"\u5361\u5c14\u66fc\u6ee4\u6ce2\u5668\u9884\u6d4b\u6b65\u9aa4\u3002\"\"\"\n    x_pred = F @ x\n    P_pred = F @ P @ F.T + Q\n    return x_pred, P_pred\n\ndef kalman_update(x_pred, P_pred, z, H, R):\n    \"\"\"\u5361\u5c14\u66fc\u6ee4\u6ce2\u5668\u66f4\u65b0\u6b65\u9aa4\u3002\"\"\"\n    y = z - H @ x_pred                        # \u521b\u65b0\n    S = H @ P_pred @ H.T + R                  # \u521b\u65b0\u534f\u65b9\u5dee\n    K = P_pred @ H.T @ jnp.linalg.inv(S)      # \u5361\u5c14\u66fc\u589e\u76ca\n    x_updated = x_pred + K @ y\n    P_updated = (jnp.eye(len(x_pred)) - K @ H) @ P_pred\n    return x_updated, P_updated\n\n# \u72b6\u6001: [x, y, vx, vy]\ndt = 1.0\nF = jnp.array([[1, 0, dt, 0],    # \u72b6\u6001\u8f6c\u79fb\n                [0, 1, 0, dt],\n                [0, 0, 1, 0],\n                [0, 0, 0, 1]])\nH = jnp.array([[1, 0, 0, 0],     # \u89c2\u6d4b\uff1a\u6d4b\u91cf x, y\n                [0, 1, 0, 0]])\nQ = jnp.eye(4) * 0.01            # \u8fc7\u7a0b\u566a\u58f0\nR = jnp.eye(2) * 4.0             # \u6d4b\u91cf\u566a\u58f0\uff08\u6709\u566a\u58f0\u7684\u68c0\u6d4b\u5668\uff09\n\n# \u6a21\u62df\u771f\u5b9e\u8f68\u8ff9\uff1a\u5706\u5468\u8fd0\u52a8\nn_steps = 50\nt = jnp.linspace(0, 2 * jnp.pi, n_steps)\ntrue_x = 10 * jnp.cos(t) + 20\ntrue_y = 10 * jnp.sin(t) + 20\n\n# \u5e26\u566a\u58f0\u7684\u89c2\u6d4b\nkey = jax.random.PRNGKey(42)\nnoise = jax.random.normal(key, (n_steps, 2)) * 2.0\nobs_x = true_x + noise[:, 0]\nobs_y = true_y + noise[:, 1]\n\n# \u8fd0\u884c\u5361\u5c14\u66fc\u6ee4\u6ce2\u5668\nx = jnp.array([obs_x[0], obs_y[0], 0.0, 0.0])  # \u521d\u59cb\u72b6\u6001\nP = jnp.eye(4) * 10.0                             # \u521d\u59cb\u4e0d\u786e\u5b9a\u6027\n\nkalman_x, kalman_y = [], []\nfor i in range(n_steps):\n    x, P = kalman_predict(x, P, F, Q)\n    z = jnp.array([obs_x[i], obs_y[i]])\n    x, P = kalman_update(x, P, z, H, R)\n    
kalman_x.append(x[0])\n    kalman_y.append(x[1])\n\nkalman_x = jnp.array(kalman_x)\nkalman_y = jnp.array(kalman_y)\n\n# \u53ef\u89c6\u5316\nplt.figure(figsize=(8, 8))\nplt.plot(true_x, true_y, 'k-', linewidth=2, label='\u771f\u5b9e\u8f68\u8ff9')\nplt.scatter(obs_x, obs_y, c='#e74c3c', s=20, alpha=0.5, label='\u5e26\u566a\u58f0\u7684\u89c2\u6d4b')\nplt.plot(kalman_x, kalman_y, '#3498db', linewidth=2, label='\u5361\u5c14\u66fc\u6ee4\u6ce2')\nplt.legend(); plt.grid(alpha=0.3)\nplt.title('\u5361\u5c14\u66fc\u6ee4\u6ce2\u8ddf\u8e2a')\nplt.xlabel('x'); plt.ylabel('y')\nplt.axis('equal'); plt.show()\n\nobs_error = jnp.mean(jnp.sqrt((obs_x - true_x)**2 + (obs_y - true_y)**2))\nkalman_error = jnp.mean(jnp.sqrt((kalman_x - true_x)**2 + (kalman_y - true_y)**2))\nprint(f\"\u89c2\u6d4bRMSE: {obs_error:.2f}\")\nprint(f\"\u5361\u5c14\u66fc\u6ee4\u6ce2RMSE: {kalman_error:.2f}\")\nprint(f\"\u8bef\u5dee\u964d\u4f4e: {(1 - kalman_error/obs_error) * 100:.1f}%\")\n

  3. \u5b9e\u73b0\u4e00\u4e2a\u7b80\u5316\u7684NeRF\u98ce\u683c\u4f53\u6e32\u67d3\u7ba1\u7ebf\u3002\u901a\u8fc7\u4e00\u4e2a\u7b80\u5355\u76843D\u573a\u666f\uff08\u5df2\u77e5\u989c\u8272\u548c\u5bc6\u5ea6\u7684\u7403\u4f53\uff09\u6295\u5c04\u5c04\u7ebf\uff0c\u5e76\u6cbf\u6bcf\u6761\u5c04\u7ebf\u79ef\u5206\u6765\u6e32\u67d3\u56fe\u50cf\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef render_ray(origin, direction, spheres, n_samples=64, t_near=1.0, t_far=6.0):\n    \"\"\"\u7a7f\u8fc7\u7403\u4f53\u573a\u666f\u5bf9\u5355\u6761\u5c04\u7ebf\u8fdb\u884c\u4f53\u6e32\u67d3\u3002\"\"\"\n    t_vals = jnp.linspace(t_near, t_far, n_samples)\n    deltas = jnp.concatenate([jnp.diff(t_vals), jnp.array([1e-3])])\n\n    colour = jnp.zeros(3)\n    transmittance = 1.0\n\n    for i in range(n_samples):\n        point = origin + t_vals[i] * direction\n\n        # \u8ba1\u7b97\u8be5\u70b9\u7684\u5bc6\u5ea6\u548c\u989c\u8272\n        density = 0.0\n        point_colour = jnp.zeros(3)\n\n        for center, radius, col, sigma in spheres:\n            dist = jnp.linalg.norm(point - center)\n            # \u8f6f\u7403\u4f53\uff1a\u5bc6\u5ea6\u968f\u8ddd\u8868\u9762\u7684\u8ddd\u79bb\u6307\u6570\u8870\u51cf\n            d = jnp.exp(-jnp.maximum(0, dist - radius) * sigma) * sigma\n            density += d\n            point_colour += d * jnp.array(col)\n\n        # \u6309\u603b\u5bc6\u5ea6\u5f52\u4e00\u5316\u989c\u8272\n        point_colour = jnp.where(density > 1e-6, point_colour / density, point_colour)\n\n        # \u4f53\u6e32\u67d3\u65b9\u7a0b\n        alpha = 1.0 - jnp.exp(-density * deltas[i])\n        colour += transmittance * alpha * point_colour\n        transmittance *= (1.0 - alpha)\n\n    return colour\n\n# \u573a\u666f\uff1a\u4e09\u4e2a\u5f69\u8272\u7403\u4f53\nspheres = [\n    (jnp.array([0.0, 0.0, 4.0]), 0.8, [1.0, 0.2, 0.2], 5.0),   # \u7ea2\u8272\n    (jnp.array([1.5, 0.5, 5.0]), 0.6, [0.2, 1.0, 0.2], 5.0),   # \u7eff\u8272\n    (jnp.array([-1.0, -0.5, 3.5]), 0.5, [0.2, 0.2, 1.0], 5.0), # \u84dd\u8272\n]\n\n# \u76f8\u673a\u8bbe\u7f6e\nimg_h, img_w = 64, 64\nfocal = 60.0\norigin = jnp.array([0.0, 0.0, 0.0])\n\nimage = jnp.zeros((img_h, img_w, 3))\nfor i in range(img_h):\n    for j in range(img_w):\n        # \u8ba1\u7b97\u5c04\u7ebf\u65b9\u5411\n        px = (j - img_w / 2) / 
focal\n        py = -(i - img_h / 2) / focal\n        direction = jnp.array([px, py, 1.0])\n        direction = direction / jnp.linalg.norm(direction)\n\n        colour = render_ray(origin, direction, spheres)\n        image = image.at[i, j].set(jnp.clip(colour, 0, 1))\n\nplt.figure(figsize=(6, 6))\nplt.imshow(image)\nplt.title('NeRF\u98ce\u683c\u4f53\u6e32\u67d3\\n(3\u4e2a\u7403\u4f53)')\nplt.axis('off')\nplt.tight_layout(); plt.show()\nprint(f\"\u56fe\u50cf\u5f62\u72b6: {image.shape}\")\nprint(f\"\u6e32\u67d3\u4e86 {img_h * img_w} \u6761\u5c04\u7ebf\uff0c\u6bcf\u6761 {64} \u4e2a\u91c7\u6837\u70b9\")\n

"},{"location":"chapter%2009%3A%20audio%20and%20speech/01.%20digital%20signal%20processing/","title":"\u6570\u5b57\u4fe1\u53f7\u5904\u7406","text":"

\u6570\u5b57\u4fe1\u53f7\u5904\u7406\u5c06\u539f\u59cb\u97f3\u9891\u6ce2\u5f62\u8f6c\u6362\u4e3a\u7ed3\u6784\u5316\u8868\u793a\uff0c\u673a\u5668\u5b66\u4e60\u6a21\u578b\u53ef\u4ee5\u4ece\u4e2d\u5b66\u4e60\u3002\u672c\u6587\u6db5\u76d6\u58f0\u97f3\u7269\u7406\u5b66\u3001\u91c7\u6837\u4e0e\u91cf\u5316\u3001\u5085\u91cc\u53f6\u53d8\u6362\uff08DFT\u3001FFT\uff09\u3001\u8bed\u8c31\u56fe\u3001\u6885\u5c14\u6ee4\u6ce2\u5668\u7ec4\u3001MFCC \u548c\u52a0\u7a97\uff0c\u4ee5\u53ca\u6240\u6709\u8bed\u97f3\u548c\u97f3\u9891 AI \u6240\u9700\u7684\u7279\u5f81\u63d0\u53d6\u6d41\u6c34\u7ebf\u3002

\\[x(t) = A \\sin(2\\pi f t + \\phi)\\]

\\[L = 20 \\log_{10}\\left(\\frac{A}{A_\\text{ref}}\\right) \\text{ dB}\\] \\[f_s \\geq 2 f_\\text{max}\\]

\\[E = \\sum_{n=0}^{N-1} x[n]^2\\] \\[\\text{ZCR} = \\frac{1}{2(N-1)} \\sum_{n=1}^{N-1} |\\text{sign}(x[n]) - \\text{sign}(x[n-1])|\\] \\[R[k] = \\sum_{n=0}^{N-1-k} x[n] \\cdot x[n+k]\\] \\[X[k] = \\sum_{n=0}^{N-1} x[n] \\cdot e^{-j 2\\pi k n / N}, \\quad k = 0, 1, \\ldots, N-1\\]

\\[m = 2595 \\log_{10}\\left(1 + \\frac{f}{700}\\right)\\]

\\[ \\text{STFT}\\{x[n]\\}(m, k) = \\sum_{n=0}^{N-1} x[n + mH] \\cdot w[n] \\cdot e^{-j 2\\pi k n / N} \\] \\[y[n] = \\sum_{k=0}^{M} b_k \\cdot x[n-k]\\] \\[ y[n] = \\sum_{k=0}^{M} b_k \\cdot x[n-k] - \\sum_{k=1}^{L} a_k \\cdot y[n-k] \\] \\[H(z) = \\frac{\\sum_{k=0}^{M} b_k z^{-k}}{1 + \\sum_{k=1}^{L} a_k z^{-k}}\\] \\[ x[n] = \\frac{\\sum_{m} w[n - mH] \\cdot \\text{IDFT}\\{X(m, k)\\}[n - mH]}{\\sum_{m} w[n - mH]^2} \\] "},{"location":"chapter%2009%3A%20audio%20and%20speech/01.%20digital%20signal%20processing/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u5728 CoLab \u6216 notebook \u4e2d\u5b8c\u6210\uff09","text":"
  1. \u751f\u6210\u4e00\u4e2a\u6b63\u5f26\u6ce2\uff0c\u4ee5\u4e0d\u540c\u91c7\u6837\u7387\u91c7\u6837\uff0c\u6f14\u793a\u6df7\u53e0\u73b0\u8c61\u3002\u7ed8\u5236\u8fde\u7eed\u4fe1\u53f7\u3001\u6b63\u786e\u91c7\u6837\u7248\u672c\u548c\u6b20\u91c7\u6837\uff08\u6df7\u53e0\uff09\u7248\u672c\u7684\u5bf9\u6bd4\u56fe\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u53c2\u6570\nf_signal = 5.0  # 5 Hz \u4fe1\u53f7\nduration = 1.0  # 1 \u79d2\n\n# \"\u8fde\u7eed\"\u4fe1\u53f7\uff08\u975e\u5e38\u9ad8\u7684\u91c7\u6837\u7387\uff09\nt_cont = jnp.linspace(0, duration, 10000)\nx_cont = jnp.sin(2 * jnp.pi * f_signal * t_cont)\n\n# \u6b63\u786e\u91c7\u6837\uff08fs = 50 Hz\uff0c\u8fdc\u9ad8\u4e8e\u5948\u594e\u65af\u7279\u9891\u7387 10 Hz\uff09\nfs_good = 50\nt_good = jnp.arange(0, duration, 1.0 / fs_good)\nx_good = jnp.sin(2 * jnp.pi * f_signal * t_good)\n\n# \u6b20\u91c7\u6837\uff08fs = 7 Hz\uff0c\u4f4e\u4e8e\u5948\u594e\u65af\u7279\u9891\u7387 10 Hz\uff09-> \u6df7\u53e0\nfs_bad = 7\nt_bad = jnp.arange(0, duration, 1.0 / fs_bad)\nx_bad = jnp.sin(2 * jnp.pi * f_signal * t_bad)\n\n# \u6df7\u53e0\u540e\u7684\u9891\u7387\uff1a|f_signal - fs_bad| = |5 - 7| = 2 Hz\nf_alias = abs(f_signal - fs_bad)\nx_alias_cont = jnp.sin(2 * jnp.pi * f_alias * t_cont)\n\nfig, axes = plt.subplots(3, 1, figsize=(12, 9))\n\n# \u56fe 1\uff1a\u539f\u59cb\u4fe1\u53f7\naxes[0].plot(t_cont, x_cont, color='#3498db', linewidth=1.5, label=f'\u539f\u59cb {f_signal} Hz \u4fe1\u53f7')\naxes[0].set_title(f'\u539f\u59cb {f_signal} Hz \u4fe1\u53f7')\naxes[0].set_xlabel('\u65f6\u95f4 (s)'); axes[0].set_ylabel('\u632f\u5e45')\naxes[0].legend(); axes[0].grid(True, alpha=0.3)\n\n# \u56fe 2\uff1a\u6b63\u786e\u91c7\u6837\naxes[1].plot(t_cont, x_cont, color='#3498db', linewidth=1, alpha=0.4, label='\u539f\u59cb\u4fe1\u53f7')\naxes[1].stem(t_good, x_good, linefmt='#27ae60', markerfmt='o', basefmt='k-',\n             label=f'\u4ee5 {fs_good} Hz \u91c7\u6837\uff08\u9ad8\u4e8e\u5948\u594e\u65af\u7279\u9891\u7387\uff09')\naxes[1].set_title(f'\u6b63\u786e\u91c7\u6837\uff1afs = {fs_good} Hz > 2 x {f_signal} Hz')\naxes[1].set_xlabel('\u65f6\u95f4 (s)'); axes[1].set_ylabel('\u632f\u5e45')\naxes[1].legend(); axes[1].grid(True, alpha=0.3)\n\n# \u56fe 3\uff1a\u6df7\u53e0\u91c7\u6837\naxes[2].plot(t_cont, x_cont, 
color='#3498db', linewidth=1, alpha=0.4, label='\u539f\u59cb\u4fe1\u53f7')\naxes[2].stem(t_bad, x_bad, linefmt='#e74c3c', markerfmt='o', basefmt='k-',\n             label=f'\u4ee5 {fs_bad} Hz \u91c7\u6837\uff08\u4f4e\u4e8e\u5948\u594e\u65af\u7279\u9891\u7387\uff09')\naxes[2].plot(t_cont, x_alias_cont, color='#f39c12', linewidth=1.5, linestyle='--',\n             label=f'\u6df7\u53e0\u4fe1\u53f7\u8868\u73b0\u4e3a {f_alias} Hz')\naxes[2].set_title(f'\u6df7\u53e0\u91c7\u6837\uff1afs = {fs_bad} Hz < 2 x {f_signal} Hz')\naxes[2].set_xlabel('\u65f6\u95f4 (s)'); axes[2].set_ylabel('\u632f\u5e45')\naxes[2].legend(); axes[2].grid(True, alpha=0.3)\n\nplt.tight_layout(); plt.show()\n

  2. \u8ba1\u7b97\u5e76\u53ef\u89c6\u5316\u7531\u591a\u4e2a\u6b63\u5f26\u6ce2\u7ec4\u6210\u7684\u4fe1\u53f7\u7684 FFT\u3002\u663e\u793a\u5e45\u5ea6\u8c31\u5e76\u8bc6\u522b\u7ec4\u6210\u9891\u7387\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u521b\u5efa\u590d\u5408\u4fe1\u53f7\uff1a220 Hz + 440 Hz + 880 Hz\uff08A3 + A4 + A5\uff09\nfs = 8000  # 8 kHz \u91c7\u6837\u7387\nduration = 0.1  # 100 ms\nt = jnp.arange(0, duration, 1.0 / fs)\nn_samples = len(t)\n\n# \u4e09\u4e2a\u9891\u7387\u5206\u91cf\uff0c\u4e0d\u540c\u632f\u5e45\nx = 1.0 * jnp.sin(2 * jnp.pi * 220 * t) + \\\n    0.6 * jnp.sin(2 * jnp.pi * 440 * t) + \\\n    0.3 * jnp.sin(2 * jnp.pi * 880 * t)\n\n# \u8ba1\u7b97 FFT\nX = jnp.fft.fft(x)\nfreqs = jnp.fft.fftfreq(n_samples, d=1.0 / fs)\nmagnitude = jnp.abs(X) / n_samples  # \u5f52\u4e00\u5316\n\n# \u53ea\u7ed8\u5236\u6b63\u9891\u7387\u90e8\u5206\npos_mask = freqs >= 0\nfreqs_pos = freqs[pos_mask]\nmag_pos = magnitude[pos_mask] * 2  # \u7ffb\u500d\u4ee5\u8865\u507f\u8d1f\u9891\u7387\u7684\u80fd\u91cf\n\nfig, axes = plt.subplots(2, 1, figsize=(12, 7))\n\n# \u65f6\u57df\naxes[0].plot(t * 1000, x, color='#3498db', linewidth=1)\naxes[0].set_title('\u590d\u5408\u4fe1\u53f7\uff1a220 Hz + 440 Hz + 880 Hz')\naxes[0].set_xlabel('\u65f6\u95f4 (ms)'); axes[0].set_ylabel('\u632f\u5e45')\naxes[0].grid(True, alpha=0.3)\n\n# \u9891\u57df\naxes[1].plot(freqs_pos, mag_pos, color='#e74c3c', linewidth=1.5)\naxes[1].set_title('\u5e45\u5ea6\u8c31\uff08FFT\uff09')\naxes[1].set_xlabel('\u9891\u7387 (Hz)'); axes[1].set_ylabel('\u5e45\u5ea6')\naxes[1].set_xlim(0, 1500)\n# \u6807\u6ce8\u5cf0\u503c\nfor f_peak, amp in [(220, 1.0), (440, 0.6), (880, 0.3)]:\n    axes[1].annotate(f'{f_peak} Hz', xy=(f_peak, amp), fontsize=10,\n                     ha='center', va='bottom', color='#9b59b6',\n                     arrowprops=dict(arrowstyle='->', color='#9b59b6'))\naxes[1].grid(True, alpha=0.3)\n\nplt.tight_layout(); plt.show()\n

  3. \u5728 JAX \u4e2d\u4ece\u5934\u6784\u5efa\u5b8c\u6574\u7684 MFCC \u6d41\u6c34\u7ebf\uff1a\u9884\u52a0\u91cd\u3001\u5206\u5e27\u3001\u52a0\u7a97\u3001FFT\u3001\u6885\u5c14\u6ee4\u6ce2\u5668\u7ec4\u3001\u5bf9\u6570\u3001DCT\u3002\u53ef\u89c6\u5316\u6885\u5c14\u6ee4\u6ce2\u5668\u7ec4\u548c\u751f\u6210\u7684 MFCC \u70ed\u529b\u56fe\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# --- \u751f\u6210\u4e00\u4e2a\u5408\u6210\u7c7b\u8bed\u97f3\u4fe1\u53f7 ---\nkey = jax.random.PRNGKey(42)\nfs = 16000\nduration = 1.0\nt = jnp.arange(0, duration, 1.0 / fs)\n\n# \u6a21\u62df\u6d4a\u97f3\u8bed\u97f3\uff1a\u57fa\u9891 + \u8c10\u6ce2\uff0c\u632f\u5e45\u8870\u51cf\nf0 = 150.0  # \u57fa\u9891\nx = sum(jnp.sin(2 * jnp.pi * f0 * k * t) / k for k in range(1, 8))\n# \u6dfb\u52a0\u4e00\u4e9b\u566a\u58f0\nx = x + 0.1 * jax.random.normal(key, t.shape)\nx = x / jnp.max(jnp.abs(x))  # \u5f52\u4e00\u5316\n\n# --- \u7b2c 1 \u6b65\uff1a\u9884\u52a0\u91cd ---\nalpha = 0.97\nx_pre = jnp.concatenate([x[:1], x[1:] - alpha * x[:-1]])\n\n# --- \u7b2c 2 \u6b65\uff1a\u5206\u5e27 ---\nframe_len = int(0.025 * fs)   # 25 ms = 400 \u4e2a\u6837\u672c\nhop_len = int(0.010 * fs)     # 10 ms = 160 \u4e2a\u6837\u672c\nn_frames = (len(x_pre) - frame_len) // hop_len + 1\nframes = jnp.stack([x_pre[i * hop_len : i * hop_len + frame_len]\n                     for i in range(n_frames)])\n\n# --- \u7b2c 3 \u6b65\uff1a\u6c49\u660e\u7a97 ---\nhamming = 0.54 - 0.46 * jnp.cos(2 * jnp.pi * jnp.arange(frame_len) / (frame_len - 1))\nwindowed = frames * hamming\n\n# --- \u7b2c 4 \u6b65\uff1aFFT ---\nn_fft = 512\nspectra = jnp.fft.rfft(windowed, n=n_fft)\npower_spectra = jnp.abs(spectra) ** 2 / n_fft\n\n# --- \u7b2c 5 \u6b65\uff1a\u6885\u5c14\u6ee4\u6ce2\u5668\u7ec4 ---\nn_mels = 40\nf_min, f_max = 0.0, fs / 2.0\n\ndef hz_to_mel(f):\n    return 2595 * jnp.log10(1 + f / 700)\n\ndef mel_to_hz(m):\n    return 700 * (10 ** (m / 2595) - 1)\n\nmel_min = hz_to_mel(f_min)\nmel_max = hz_to_mel(f_max)\nmel_points = jnp.linspace(mel_min, mel_max, n_mels + 2)\nhz_points = mel_to_hz(mel_points)\n\nfreq_bins = jnp.floor((n_fft + 1) * hz_points / fs).astype(jnp.int32)\nn_freqs = n_fft // 2 + 1\nfilterbank = jnp.zeros((n_mels, n_freqs))\n\nfor m in range(n_mels):\n    f_left = freq_bins[m]\n    f_center = freq_bins[m + 1]\n    f_right = 
freq_bins[m + 2]\n    # \u4e0a\u5347\u6cbf\n    for k in range(int(f_left), int(f_center)):\n        if f_center != f_left:\n            filterbank = filterbank.at[m, k].set((k - f_left) / (f_center - f_left))\n    # \u4e0b\u964d\u6cbf\n    for k in range(int(f_center), int(f_right)):\n        if f_right != f_center:\n            filterbank = filterbank.at[m, k].set((f_right - k) / (f_right - f_center))\n\n# \u5e94\u7528\u6ee4\u6ce2\u5668\u7ec4\nmel_spectra = jnp.dot(power_spectra, filterbank.T)\n\n# --- \u7b2c 6 \u6b65\uff1a\u5bf9\u6570 ---\nlog_mel = jnp.log(mel_spectra + 1e-10)\n\n# --- \u7b2c 7 \u6b65\uff1aDCT\uff08\u7b2c\u4e8c\u7c7b\uff09 ---\nn_mfcc = 13\nn_mel_channels = log_mel.shape[1]\ndct_matrix = jnp.zeros((n_mfcc, n_mel_channels))\nfor i in range(n_mfcc):\n    for j in range(n_mel_channels):\n        dct_matrix = dct_matrix.at[i, j].set(\n            jnp.cos(jnp.pi * i * (j + 0.5) / n_mel_channels)\n        )\nmfccs = jnp.dot(log_mel, dct_matrix.T)\n\n# --- \u53ef\u89c6\u5316 ---\nfig, axes = plt.subplots(3, 1, figsize=(14, 11))\n\n# \u6885\u5c14\u6ee4\u6ce2\u5668\u7ec4\nfreq_axis = jnp.linspace(0, fs / 2, n_freqs)\nfor m in range(n_mels):\n    color = '#3498db' if m % 2 == 0 else '#e74c3c'\n    axes[0].plot(freq_axis, filterbank[m], color=color, alpha=0.6, linewidth=0.8)\naxes[0].set_title(f'\u6885\u5c14\u6ee4\u6ce2\u5668\u7ec4\uff08{n_mels} \u4e2a\u6ee4\u6ce2\u5668\uff09')\naxes[0].set_xlabel('\u9891\u7387 (Hz)'); axes[0].set_ylabel('\u6743\u91cd')\naxes[0].grid(True, alpha=0.3)\n\n# \u5bf9\u6570\u6885\u5c14\u8bed\u8c31\u56fe\nim1 = axes[1].imshow(log_mel.T, aspect='auto', origin='lower',\n                      extent=[0, duration, 0, n_mels], cmap='viridis')\naxes[1].set_title('\u5bf9\u6570\u6885\u5c14\u8bed\u8c31\u56fe')\naxes[1].set_xlabel('\u65f6\u95f4 (s)'); axes[1].set_ylabel('\u6885\u5c14\u9891\u5e26')\nplt.colorbar(im1, ax=axes[1], label='\u5bf9\u6570\u80fd\u91cf')\n\n# MFCC\nim2 = axes[2].imshow(mfccs.T, aspect='auto', origin='lower',\n      
                extent=[0, duration, 0, n_mfcc], cmap='coolwarm')\naxes[2].set_title(f'MFCC\uff08\u524d {n_mfcc} \u4e2a\u7cfb\u6570\uff09')\naxes[2].set_xlabel('\u65f6\u95f4 (s)'); axes[2].set_ylabel('MFCC \u7d22\u5f15')\nplt.colorbar(im2, ax=axes[2], label='\u7cfb\u6570\u503c')\n\nplt.tight_layout(); plt.show()\n

  4. \u5b9e\u73b0 FIR \u4f4e\u901a\u548c\u9ad8\u901a\u6ee4\u6ce2\u5668\uff0c\u5e76\u53ef\u89c6\u5316\u5b83\u4eec\u5bf9\u5305\u542b\u4f4e\u9891\u548c\u9ad8\u9891\u5206\u91cf\u4fe1\u53f7\u7684\u5f71\u54cd\u3002\u540c\u65f6\u663e\u793a\u65f6\u57df\u548c\u9891\u57df\u7684\u89c6\u56fe\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u521b\u5efa\u5305\u542b\u4f4e\u9891\uff08100 Hz\uff09\u548c\u9ad8\u9891\uff082000 Hz\uff09\u5206\u91cf\u7684\u4fe1\u53f7\nfs = 8000\nduration = 0.05  # 50 ms\uff0c\u4fbf\u4e8e\u6e05\u6670\u663e\u793a\nt = jnp.arange(0, duration, 1.0 / fs)\n\nx_low = jnp.sin(2 * jnp.pi * 100 * t)\nx_high = 0.5 * jnp.sin(2 * jnp.pi * 2000 * t)\nx = x_low + x_high\n\n# \u4f7f\u7528\u7a97\u51fd\u6570\u6cd5\u8bbe\u8ba1\u7b80\u5355\u7684 FIR \u4f4e\u901a\u6ee4\u6ce2\u5668\ndef fir_lowpass(cutoff_hz, fs, n_taps=51):\n    \"\"\"\u4f7f\u7528\u7a97\u51fd\u6570\u6cd5\u8bbe\u8ba1 FIR \u4f4e\u901a\u6ee4\u6ce2\u5668\u3002\"\"\"\n    fc = cutoff_hz / fs  # \u5f52\u4e00\u5316\u622a\u6b62\u9891\u7387\n    n = jnp.arange(n_taps)\n    mid = (n_taps - 1) / 2.0\n    # Sinc \u51fd\u6570\uff08\u7406\u60f3\u4f4e\u901a\u51b2\u6fc0\u54cd\u5e94\uff09\n    h = jnp.where(n == mid, 2 * fc,\n                  jnp.sin(2 * jnp.pi * fc * (n - mid)) / (jnp.pi * (n - mid)))\n    # \u5e94\u7528\u6c49\u660e\u7a97\n    window = 0.54 - 0.46 * jnp.cos(2 * jnp.pi * n / (n_taps - 1))\n    h = h * window\n    h = h / jnp.sum(h)  # \u5f52\u4e00\u5316\u5230\u76f4\u6d41\u589e\u76ca\u4e3a 1\n    return h\n\ndef apply_filter(x, h):\n    \"\"\"\u901a\u8fc7\u5377\u79ef\u5e94\u7528 FIR \u6ee4\u6ce2\u5668\u3002\"\"\"\n    return jnp.convolve(x, h, mode='same')\n\n# 500 Hz \u4f4e\u901a\u6ee4\u6ce2\u5668\uff08\u901a\u8fc7 100 Hz\uff0c\u963b\u585e 2000 Hz\uff09\nh_lp = fir_lowpass(500, fs, n_taps=51)\nx_lp = apply_filter(x, h_lp)\n\n# \u9ad8\u901a = \u51b2\u6fc0 - \u4f4e\u901a\uff08\u9891\u8c31\u53cd\u8f6c\uff09\ndelta = jnp.zeros(51)\ndelta = delta.at[25].set(1.0)\nh_hp = delta - h_lp\nx_hp = apply_filter(x, h_hp)\n\n# \u8ba1\u7b97\u6240\u6709\u4fe1\u53f7\u7684\u9891\u8c31\ndef compute_spectrum(signal, fs):\n    X = jnp.fft.rfft(signal)\n    freqs = jnp.fft.rfftfreq(len(signal), d=1.0 / fs)\n    mag = jnp.abs(X) / len(signal) * 2\n    return freqs, 
mag\n\nfig, axes = plt.subplots(3, 2, figsize=(14, 10))\n\n# \u65f6\u57df\u56fe\nfor i, (sig, title, color) in enumerate([\n    (x, '\u539f\u59cb\u4fe1\u53f7\uff08100 Hz + 2000 Hz\uff09', '#3498db'),\n    (x_lp, '\u4f4e\u901a\u6ee4\u6ce2\u540e\uff08< 500 Hz\uff09', '#27ae60'),\n    (x_hp, '\u9ad8\u901a\u6ee4\u6ce2\u540e\uff08> 500 Hz\uff09', '#e74c3c')\n]):\n    axes[i, 0].plot(t * 1000, sig[:len(t)], color=color, linewidth=1)\n    axes[i, 0].set_title(f'\u65f6\u57df\uff1a{title}')\n    axes[i, 0].set_xlabel('\u65f6\u95f4 (ms)'); axes[i, 0].set_ylabel('\u632f\u5e45')\n    axes[i, 0].grid(True, alpha=0.3)\n\n# \u9891\u57df\u56fe\nfor i, (sig, title, color) in enumerate([\n    (x, '\u539f\u59cb\u4fe1\u53f7', '#3498db'),\n    (x_lp, '\u4f4e\u901a', '#27ae60'),\n    (x_hp, '\u9ad8\u901a', '#e74c3c')\n]):\n    freqs, mag = compute_spectrum(sig, fs)\n    axes[i, 1].plot(freqs, mag, color=color, linewidth=1.5)\n    axes[i, 1].set_title(f'\u9891\u8c31\uff1a{title}')\n    axes[i, 1].set_xlabel('\u9891\u7387 (Hz)'); axes[i, 1].set_ylabel('\u5e45\u5ea6')\n    axes[i, 1].set_xlim(0, 3000)\n    axes[i, 1].axvline(x=500, color='#f39c12', linestyle='--', alpha=0.7,\n                        label='\u622a\u6b62\u9891\u7387\uff08500 Hz\uff09')\n    axes[i, 1].legend(); axes[i, 1].grid(True, alpha=0.3)\n\nplt.tight_layout(); plt.show()\n

"},{"location":"chapter%2009%3A%20audio%20and%20speech/02.%20automatic%20speech%20recognition/","title":"\u81ea\u52a8\u8bed\u97f3\u8bc6\u522b","text":"

\u81ea\u52a8\u8bed\u97f3\u8bc6\u522b\u5c06\u53e3\u8bed\u97f3\u9891\u8f6c\u6362\u4e3a\u4e66\u9762\u6587\u672c\uff0c\u5f25\u5408\u4eba\u7c7b\u8bed\u97f3\u4e0e\u673a\u5668\u53ef\u8bfb\u8bed\u8a00\u4e4b\u95f4\u7684\u9e3f\u6c9f\u3002\u672c\u6587\u6db5\u76d6 GMM-HMM\u3001CTC \u635f\u5931\u3001RNN-\u8f6c\u5bfc\u5668\u3001\u57fa\u4e8e\u6ce8\u610f\u529b\u7684\u7f16\u7801\u5668-\u89e3\u7801\u5668\u6a21\u578b\uff08LAS\uff09\u3001Whisper \u4ee5\u53ca\u7aef\u5230\u7aef ASR\uff0c\u4ece\u7ecf\u5178\u6d41\u6c34\u7ebf\u5230\u73b0\u4ee3\u795e\u7ecf\u67b6\u6784\u3002

\\[ p(\\mathbf{x} | s) = \\sum_{m=1}^{M} w_m \\cdot \\mathcal{N}(\\mathbf{x} ; \\boldsymbol{\\mu}_m, \\boldsymbol{\\Sigma}_m) \\] \\[ \\delta_t(j) = \\max_{i} \\left[ \\delta_{t-1}(i) \\cdot a_{ij} \\right] \\cdot b_j(\\mathbf{x}_t) \\] \\[P(\\mathbf{y} | \\mathbf{x}) = \\sum_{\\boldsymbol{\\pi} \\in \\mathcal{B}^{-1}(\\mathbf{y})} \\prod_{t=1}^{T} p(\\pi_t | \\mathbf{x})\\]

\\[p(y | t, u) = \\text{softmax}(W \\cdot \\text{tanh}(W_\\text{enc} \\mathbf{h}_t^\\text{enc} + W_\\text{pred} \\mathbf{h}_u^\\text{pred} + b))\\]

\\[\\mathcal{L} = -\\log \\frac{\\exp(\\text{sim}(\\mathbf{c}_t, \\mathbf{q}_t) / \\kappa)}{\\sum_{\\tilde{\\mathbf{q}} \\in Q_t} \\exp(\\text{sim}(\\mathbf{c}_t, \\tilde{\\mathbf{q}}) / \\kappa)}\\] \\[\\hat{\\mathbf{y}} = \\arg\\max_\\mathbf{y} \\left[ \\log p_\\text{AM}(\\mathbf{y} | \\mathbf{x}) + \\lambda \\log p_\\text{LM}(\\mathbf{y}) \\right]\\] \\[\\hat{\\mathbf{y}} = \\arg\\max_\\mathbf{y} \\left[ \\log p_\\text{E2E}(\\mathbf{y} | \\mathbf{x}) - \\beta \\log p_\\text{ILM}(\\mathbf{y}) + \\lambda \\log p_\\text{LM}(\\mathbf{y}) \\right]\\] \\[\\text{WER} = \\frac{S + D + I}{N}\\] "},{"location":"chapter%2009%3A%20audio%20and%20speech/02.%20automatic%20speech%20recognition/#colab-notebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u5728 JAX \u4e2d\u4ece\u5934\u5b9e\u73b0 CTC \u635f\u5931\u3002\u521b\u5efa\u4e00\u4e2a\u5305\u542b\u77ed\u5e8f\u5217 logits \u548c\u76ee\u6807\u6807\u7b7e\u7684\u73a9\u5177\u793a\u4f8b\uff0c\u8ba1\u7b97 CTC \u524d\u5411\u7b97\u6cd5\u5f97\u5230\u603b\u6982\u7387\uff0c\u5e76\u8ba1\u7b97\u8d1f\u5bf9\u6570\u4f3c\u7136\u635f\u5931\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef ctc_forward(log_probs, targets):\n    \"\"\"\n    CTC \u524d\u5411\u7b97\u6cd5\uff08\u5bf9\u6570\u57df\uff0c\u6570\u503c\u7a33\u5b9a\u6027\uff09\u3002\n    log_probs: (T, V) \u8bcd\u6c47\u8868\u4e0a\u7684\u5bf9\u6570\u6982\u7387\uff08\u7d22\u5f15 0 = \u7a7a\u767d\uff09\n    targets: (U,) \u76ee\u6807\u6807\u7b7e\u7d22\u5f15\uff08\u4e0d\u542b\u7a7a\u767d\uff09\n    \u8fd4\u56de\uff1a\u76ee\u6807\u5e8f\u5217\u5728 CTC \u4e0b\u7684\u5bf9\u6570\u6982\u7387\u3002\n    \"\"\"\n    T, V = log_probs.shape\n    U = len(targets)\n\n    # \u6784\u5efa\u5e26\u6709\u7a7a\u767d\u7684\u6269\u5c55\u6807\u7b7e\u5e8f\u5217\uff1a[blank, y1, blank, y2, ..., yU, blank]\n    S = 2 * U + 1\n    labels = jnp.zeros(S, dtype=jnp.int32)  # \u5168\u90e8\u4e3a\u7a7a\u767d\n    for i in range(U):\n        labels = labels.at[2 * i + 1].set(targets[i])\n\n    # \u521d\u59cb\u5316 alpha\uff08\u5bf9\u6570\u57df\uff09\n    NEG_INF = -1e30\n    alpha = jnp.full((T, S), NEG_INF)\n    alpha = alpha.at[0, 0].set(log_probs[0, labels[0]])        # \u4ee5\u7a7a\u767d\u5f00\u59cb\n    alpha = alpha.at[0, 1].set(log_probs[0, labels[1]])        # \u6216\u7b2c\u4e00\u4e2a\u6807\u7b7e\n\n    # \u524d\u5411\u586b\u5145\n    for t in range(1, T):\n        for s in range(S):\n            # \u540c\u4e00\u72b6\u6001\n            a = alpha[t - 1, s]\n            # \u4ece\u524d\u4e00\u72b6\u6001\u6765\n            if s > 0:\n                a = jnp.logaddexp(a, alpha[t - 1, s - 1])\n            # \u8df3\u8fc7\u7a7a\u767d\uff08\u5982\u679c\u5f53\u524d\u6807\u7b7e\u4e0e\u4e24\u6b65\u524d\u7684\u6807\u7b7e\u4e0d\u540c\uff09\n            if s > 1 and labels[s] != 0 and labels[s] != labels[s - 2]:\n                a = jnp.logaddexp(a, alpha[t - 1, s - 2])\n            alpha = alpha.at[t, s].set(a + log_probs[t, labels[s]])\n\n    # 
\u603b\u5bf9\u6570\u6982\u7387\uff1a\u6700\u540e\u65f6\u95f4\u6b65\u7684\u6700\u540e\u4e24\u4e2a\u72b6\u6001\u4e4b\u548c\n    log_prob = jnp.logaddexp(alpha[T - 1, S - 1], alpha[T - 1, S - 2])\n    return log_prob, alpha\n\n# --- \u73a9\u5177\u793a\u4f8b ---\nT = 12   # \u8f93\u5165\u957f\u5ea6\uff08\u65f6\u95f4\u6b65\uff09\nV = 5    # \u8bcd\u6c47\u8868\u5927\u5c0f\uff080=\u7a7a\u767d\uff0c1='c'\uff0c2='a'\uff0c3='t'\uff0c4='x'\uff09\ntargets = jnp.array([1, 2, 3])  # \"c\", \"a\", \"t\"\n\n# \u521b\u5efa\u968f\u673a logits \u5e76\u8f6c\u6362\u4e3a\u5bf9\u6570\u6982\u7387\nkey = jax.random.PRNGKey(42)\nlogits = jax.random.normal(key, (T, V))\nlog_probs = jax.nn.log_softmax(logits, axis=-1)\n\nlog_prob, alpha = ctc_forward(log_probs, targets)\nctc_loss = -log_prob\n\nprint(f\"\u76ee\u6807\u5e8f\u5217: {targets.tolist()} ('c', 'a', 't')\")\nprint(f\"\u8f93\u5165\u957f\u5ea6 T={T}, \u8bcd\u6c47\u8868\u5927\u5c0f V={V}\")\nprint(f\"CTC \u5bf9\u6570\u6982\u7387: {log_prob:.4f}\")\nprint(f\"CTC \u635f\u5931\uff08\u8d1f\u5bf9\u6570\u6982\u7387\uff09: {ctc_loss:.4f}\")\n\n# \u53ef\u89c6\u5316\u524d\u5411\u53d8\u91cf\uff08alpha\uff09\u7f51\u683c\nfig, ax = plt.subplots(figsize=(12, 5))\n# \u5c06\u5bf9\u6570\u8f6c\u6362\u4e3a\u7ebf\u6027\u4ee5\u4fbf\u53ef\u89c6\u5316\nalpha_linear = jnp.exp(alpha - jnp.max(alpha))  # \u5f52\u4e00\u5316\u4ee5\u4fbf\u89c2\u5bdf\nim = ax.imshow(alpha_linear.T, aspect='auto', origin='lower', cmap='viridis')\nax.set_xlabel('\u65f6\u95f4\u6b65 (t)')\nax.set_ylabel('\u6269\u5c55\u6807\u7b7e\u7d22\u5f15 (s)')\n\nlabel_names = ['_', 'c', '_', 'a', '_', 't', '_']  # _ = \u7a7a\u767d\nax.set_yticks(range(len(label_names)))\nax.set_yticklabels(label_names)\nax.set_title(f'CTC \u524d\u5411\u53d8\u91cf\uff08alpha \u7f51\u683c\uff09| \u635f\u5931 = {ctc_loss:.2f}')\nplt.colorbar(im, ax=ax, label='\u5f52\u4e00\u5316\u6982\u7387')\nplt.tight_layout(); plt.show()\n

  2. \u5728 JAX \u4e2d\u6784\u5efa\u4e00\u4e2a\u7b80\u5355\u7684\u57fa\u4e8e\u6ce8\u610f\u529b\u7684\u7f16\u7801\u5668-\u89e3\u7801\u5668 ASR \u6a21\u578b\uff08\u6781\u7b80\u7684\u7c7b LAS \u67b6\u6784\uff09\u3002\u4f7f\u7528\u4e00\u7ef4\u5377\u79ef\u7f16\u7801\u5668\u548c\u5e26\u6709\u70b9\u79ef\u6ce8\u610f\u529b\u7684\u5355\u5c42\u89e3\u7801\u5668\u3002\u5728\u5408\u6210\u6570\u636e\u4e0a\u8fd0\u884c\u5e76\u53ef\u89c6\u5316\u6ce8\u610f\u529b\u6743\u91cd\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# --- \u6700\u5c0f\u5316\u7684\u57fa\u4e8e\u6ce8\u610f\u529b\u7684\u7f16\u7801\u5668-\u89e3\u7801\u5668 ASR \u6a21\u578b ---\n\ndef init_params(key, input_dim, hidden_dim, vocab_size):\n    \"\"\"\u521d\u59cb\u5316\u5c0f\u578b LAS \u7c7b\u6a21\u578b\u7684\u53c2\u6570\u3002\"\"\"\n    keys = jax.random.split(key, 8)\n    scale = 0.1\n    params = {\n        # \u7f16\u7801\u5668\uff1a\u7b80\u5355\u7684\u7ebf\u6027\u6295\u5f71\uff08\u6a21\u62df\u5377\u79ef\u8f93\u51fa\uff09\n        'enc_w': jax.random.normal(keys[0], (input_dim, hidden_dim)) * scale,\n        'enc_b': jnp.zeros(hidden_dim),\n        # \u6ce8\u610f\u529b\uff1a\u67e5\u8be2\u3001\u952e\u3001\u503c\u6295\u5f71\n        'attn_q': jax.random.normal(keys[1], (hidden_dim, hidden_dim)) * scale,\n        'attn_k': jax.random.normal(keys[2], (hidden_dim, hidden_dim)) * scale,\n        'attn_v': jax.random.normal(keys[3], (hidden_dim, hidden_dim)) * scale,\n        # \u89e3\u7801\u5668 RNN\uff08\u4e3a\u6f14\u793a\u4f7f\u7528\u7b80\u5355 Elman RNN\uff09\n        'dec_wh': jax.random.normal(keys[4], (hidden_dim, hidden_dim)) * scale,\n        'dec_wx': jax.random.normal(keys[5], (vocab_size, hidden_dim)) * scale,\n        'dec_wc': jax.random.normal(keys[6], (hidden_dim, hidden_dim)) * scale,\n        'dec_b': jnp.zeros(hidden_dim),\n        # \u8f93\u51fa\u6295\u5f71\n        'out_w': jax.random.normal(keys[7], (hidden_dim, vocab_size)) * scale,\n        'out_b': jnp.zeros(vocab_size),\n    }\n    return params\n\ndef encode(params, x):\n    \"\"\"\u7f16\u7801\u5668\uff1a\u7ebf\u6027\u6295\u5f71\uff08\u5360\u4f4d\u7b26\uff0c\u4ee3\u8868\u5377\u79ef/LSTM \u5806\u53e0\uff09\u3002\"\"\"\n    return jnp.tanh(x @ params['enc_w'] + params['enc_b'])\n\ndef attend(params, query, enc_out):\n    \"\"\"\u5728\u7f16\u7801\u5668\u8f93\u51fa\u4e0a\u7684\u70b9\u79ef\u6ce8\u610f\u529b\u3002\"\"\"\n    q = query @ params['attn_q']                   # 
(hidden,)\n    k = enc_out @ params['attn_k']                 # (T_enc, hidden)\n    v = enc_out @ params['attn_v']                 # (T_enc, hidden)\n    d_k = q.shape[-1]\n    scores = (k @ q) / jnp.sqrt(d_k)              # (T_enc,)\n    weights = jax.nn.softmax(scores)               # (T_enc,)\n    context = weights @ v                          # (hidden,)\n    return context, weights\n\ndef decode_step(params, h_prev, y_prev_onehot, enc_out):\n    \"\"\"\u5355\u6b65\u89e3\u7801\uff1aRNN + \u6ce8\u610f\u529b\u3002\"\"\"\n    # \u5d4c\u5165\u524d\u4e00\u4e2a\u6807\u8bb0\n    y_emb = y_prev_onehot @ params['dec_wx']       # (hidden,)\n    # \u6ce8\u610f\u529b\u5230\u7f16\u7801\u5668\n    context, attn_w = attend(params, h_prev, enc_out)\n    # RNN \u66f4\u65b0\n    h = jnp.tanh(h_prev @ params['dec_wh'] + y_emb + context @ params['dec_wc']\n                  + params['dec_b'])\n    # \u8f93\u51fa logits\n    logits = h @ params['out_w'] + params['out_b']\n    return h, logits, attn_w\n\n# --- \u8bbe\u7f6e ---\nkey = jax.random.PRNGKey(0)\ninput_dim = 40       # \u4f8b\u5982 40 \u4e2a\u6885\u5c14\u9891\u5e26\nhidden_dim = 64\nvocab_size = 10      # \u7528\u4e8e\u6f14\u793a\u7684\u5c0f\u8bcd\u6c47\u8868\nT_enc = 30           # \u7f16\u7801\u5668\u65f6\u95f4\u6b65\nT_dec = 8            # \u89e3\u7801\u5668\u6b65\u6570\n\nparams = init_params(key, input_dim, hidden_dim, vocab_size)\n\n# \u5408\u6210\u8f93\u5165\uff1a\u968f\u673a\u6885\u5c14\u7c7b\u7279\u5f81\nkey, subkey = jax.random.split(key)\nx = jax.random.normal(subkey, (T_enc, input_dim))\n\n# \u7f16\u7801\nenc_out = encode(params, x)\n\n# \u89e3\u7801\uff08\u4f7f\u7528\u968f\u673a\u76ee\u6807\u7684\u6559\u5e08\u5f3a\u5236\uff09\nkey, subkey = jax.random.split(key)\ntargets = jax.random.randint(subkey, (T_dec,), 0, vocab_size)\n\nh = jnp.zeros(hidden_dim)\nall_logits = []\nall_attn = []\n\nfor t in range(T_dec):\n    y_prev = jax.nn.one_hot(targets[t] if t > 0 else 0, vocab_size)\n    h, logits, attn_w = 
decode_step(params, h, y_prev, enc_out)\n    all_logits.append(logits)\n    all_attn.append(attn_w)\n\nall_attn = jnp.stack(all_attn)  # (T_dec, T_enc)\nall_logits = jnp.stack(all_logits)  # (T_dec, vocab_size)\n\n# --- \u53ef\u89c6\u5316\u6ce8\u610f\u529b\u6743\u91cd ---\nfig, axes = plt.subplots(1, 2, figsize=(14, 5))\n\nim = axes[0].imshow(all_attn, aspect='auto', cmap='Blues', origin='lower')\naxes[0].set_xlabel('\u7f16\u7801\u5668\u65f6\u95f4\u6b65')\naxes[0].set_ylabel('\u89e3\u7801\u5668\u6b65')\naxes[0].set_title('\u6ce8\u610f\u529b\u6743\u91cd\uff08\u89e3\u7801\u5668 -> \u7f16\u7801\u5668\uff09')\nplt.colorbar(im, ax=axes[0])\n\n# \u663e\u793a\u6bcf\u4e2a\u89e3\u7801\u6b65\u7684\u9884\u6d4b\u6807\u8bb0\u5206\u5e03\nim2 = axes[1].imshow(jax.nn.softmax(all_logits, axis=-1), aspect='auto',\n                      cmap='Oranges', origin='lower')\naxes[1].set_xlabel('\u8bcd\u6c47\u8868\u7d22\u5f15')\naxes[1].set_ylabel('\u89e3\u7801\u5668\u6b65')\naxes[1].set_title('\u8f93\u51fa\u6807\u8bb0\u6982\u7387')\nplt.colorbar(im2, ax=axes[1])\n\nplt.suptitle('\u6700\u5c0f\u5316\u7684\u57fa\u4e8e\u6ce8\u610f\u529b\u7684 ASR \u6a21\u578b\uff08\u672a\u8bad\u7ec3\uff09')\nplt.tight_layout(); plt.show()\n

  3. \u4f7f\u7528\u52a8\u6001\u89c4\u5212\uff08\u7f16\u8f91\u8ddd\u79bb\uff09\u4ece\u5934\u8ba1\u7b97\u8bcd\u9519\u8bef\u7387\uff08WER\uff09\uff0c\u5e76\u9488\u5bf9\u4e00\u4e2a\u53c2\u8003\u6587\u672c\u8bc4\u4f30\u591a\u4e2a\u5047\u8bbe\u3002\u53ef\u89c6\u5316\u7f16\u8f91\u8ddd\u79bb\u77e9\u9635\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\nimport numpy as np\n\ndef compute_wer(reference, hypothesis):\n    \"\"\"\n    \u4f7f\u7528\u52a8\u6001\u89c4\u5212\uff08\u8bcd\u7ea7\u522b\u7684 Levenshtein \u8ddd\u79bb\uff09\u8ba1\u7b97 WER\u3002\n    \u8fd4\u56de WER\u3001\u66ff\u6362\u6570\u3001\u5220\u9664\u6570\u3001\u63d2\u5165\u6570\u548c DP \u77e9\u9635\u3002\n    \"\"\"\n    ref_words = reference.split()\n    hyp_words = hypothesis.split()\n    N = len(ref_words)\n    M = len(hyp_words)\n\n    # DP \u77e9\u9635\uff1ad[i][j] = ref[:i] \u548c hyp[:j] \u4e4b\u95f4\u7684\u7f16\u8f91\u8ddd\u79bb\n    d = np.zeros((N + 1, M + 1), dtype=np.int32)\n    # \u56de\u6eaf\u77e9\u9635\u7528\u4e8e\u7edf\u8ba1 S, D, I\n    ops = np.zeros((N + 1, M + 1, 3), dtype=np.int32)  # [sub, del, ins]\n\n    for i in range(N + 1):\n        d[i][0] = i  # \u5168\u90e8\u5220\u9664\n    for j in range(M + 1):\n        d[0][j] = j  # \u5168\u90e8\u63d2\u5165\n\n    for i in range(1, N + 1):\n        for j in range(1, M + 1):\n            if ref_words[i - 1] == hyp_words[j - 1]:\n                sub_cost = d[i - 1][j - 1]  # \u5339\u914d\uff0c\u65e0\u9700\u7f16\u8f91\n            else:\n                sub_cost = d[i - 1][j - 1] + 1  # \u66ff\u6362\n            del_cost = d[i - 1][j] + 1      # \u5220\u9664\n            ins_cost = d[i][j - 1] + 1      # \u63d2\u5165\n\n            d[i][j] = min(sub_cost, del_cost, ins_cost)\n\n    # \u56de\u6eaf\u7edf\u8ba1\u64cd\u4f5c\u6b21\u6570\n    i, j = N, M\n    S, D, I = 0, 0, 0\n    while i > 0 or j > 0:\n        if i > 0 and j > 0 and d[i][j] == d[i-1][j-1] and ref_words[i-1] == hyp_words[j-1]:\n            i -= 1; j -= 1  # \u6b63\u786e\n        elif i > 0 and j > 0 and d[i][j] == d[i-1][j-1] + 1:\n            S += 1; i -= 1; j -= 1  # \u66ff\u6362\n        elif i > 0 and d[i][j] == d[i-1][j] + 1:\n            D += 1; i -= 1  # \u5220\u9664\n        elif j > 0 and d[i][j] == d[i][j-1] + 1:\n            I += 1; j -= 1  # 
\u63d2\u5165\n        else:\n            break\n\n    wer = (S + D + I) / N if N > 0 else 0.0\n    return wer, S, D, I, d\n\n# --- \u6d4b\u8bd5\u7528\u4f8b ---\nreference = \"the cat sat on the mat\"\nhypotheses = [\n    \"the cat sat on the mat\",          # \u5b8c\u7f8e\n    \"the cat sit on the mat\",          # 1 \u6b21\u66ff\u6362\n    \"the cat on the mat\",              # 1 \u6b21\u5220\u9664\n    \"the big cat sat on the mat\",      # 1 \u6b21\u63d2\u5165\n    \"a dog sat in a rug\",              # \u591a\u5904\u9519\u8bef\n]\n\nprint(f\"\u53c2\u8003\u6587\u672c: '{reference}'\\n\")\nprint(f\"{'\u5047\u8bbe':<40s} {'WER':>6s} {'S':>3s} {'D':>3s} {'I':>3s}\")\nprint(\"-\" * 60)\nresults = []\nfor hyp in hypotheses:\n    wer, S, D, I, dp = compute_wer(reference, hyp)\n    results.append((hyp, wer, S, D, I, dp))\n    print(f\"'{hyp}':<40s} {wer:>6.1%} {S:>3d} {D:>3d} {I:>3d}\")\n\n# \u53ef\u89c6\u5316\u6700\u5dee\u60c5\u51b5\u7684 DP \u77e9\u9635\nworst = results[-1]\nhyp_words = worst[0].split()\nref_words = reference.split()\ndp_matrix = worst[5]\n\nfig, axes = plt.subplots(1, 2, figsize=(14, 5))\n\n# DP \u77e9\u9635\nim = axes[0].imshow(dp_matrix, cmap='YlOrRd', origin='upper')\naxes[0].set_xticks(range(len(hyp_words) + 1))\naxes[0].set_xticklabels([''] + hyp_words, rotation=45, ha='right', fontsize=9)\naxes[0].set_yticks(range(len(ref_words) + 1))\naxes[0].set_yticklabels([''] + ref_words, fontsize=9)\naxes[0].set_xlabel('\u5047\u8bbe\u8bcd')\naxes[0].set_ylabel('\u53c2\u8003\u8bcd')\naxes[0].set_title(f'\u7f16\u8f91\u8ddd\u79bb\u77e9\u9635\\nWER = {worst[1]:.1%}')\nfor i in range(dp_matrix.shape[0]):\n    for j in range(dp_matrix.shape[1]):\n        axes[0].text(j, i, str(dp_matrix[i, j]), ha='center', va='center', fontsize=8)\nplt.colorbar(im, ax=axes[0])\n\n# WER \u6bd4\u8f83\u67f1\u72b6\u56fe\nnames = [f'Hyp {i+1}' for i in range(len(results))]\nwers = [r[1] * 100 for r in results]\ncolors = ['#27ae60' if w == 0 else '#f39c12' if w < 30 else '#e74c3c' 
for w in wers]\naxes[1].barh(names, wers, color=colors)\naxes[1].set_xlabel('WER (%)')\naxes[1].set_title('\u8bcd\u9519\u8bef\u7387\u6bd4\u8f83')\nfor i, (w, r) in enumerate(zip(wers, results)):\n    axes[1].text(w + 1, i, f'{w:.0f}% (S={r[2]}, D={r[3]}, I={r[4]})',\n                 va='center', fontsize=9)\naxes[1].set_xlim(0, max(wers) * 1.4)\n\nplt.tight_layout(); plt.show()\n

  4. \u5728\u5bf9\u6570\u6885\u5c14\u9891\u8c31\u56fe\u4e0a\u5b9e\u73b0 SpecAugment\uff08\u9891\u7387\u63a9\u7801\u548c\u65f6\u95f4\u63a9\u7801\uff09\uff0c\u5e76\u53ef\u89c6\u5316\u539f\u59cb\u7248\u672c\u4e0e\u589e\u5f3a\u7248\u672c\u3002\u4ece\u5408\u6210\u4fe1\u53f7\u751f\u6210\u9891\u8c31\u56fe\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# --- \u751f\u6210\u5408\u6210\u5bf9\u6570\u6885\u5c14\u9891\u8c31\u56fe ---\nkey = jax.random.PRNGKey(42)\nfs = 16000\nduration = 2.0\nt = jnp.arange(0, duration, 1.0 / fs)\n\n# \u6a21\u62df\u8bed\u97f3\uff1a\u5e26\u8c10\u6ce2\u7684\u5541\u557e\u4fe1\u53f7\nf0 = 120.0\nx = sum(jnp.sin(2 * jnp.pi * f0 * k * t * (1 + 0.1 * t)) / k for k in range(1, 10))\nkey, subkey = jax.random.split(key)\nx = x + 0.05 * jax.random.normal(subkey, t.shape)\n\n# \u8ba1\u7b97\u5bf9\u6570\u6885\u5c14\u9891\u8c31\u56fe\uff08\u7b80\u5316\u7248\uff09\nframe_len = 400  # 25 ms\nhop_len = 160    # 10 ms\nn_fft = 512\nn_mels = 80\n\nn_frames = (len(x) - frame_len) // hop_len + 1\nhamming = 0.54 - 0.46 * jnp.cos(2 * jnp.pi * jnp.arange(frame_len) / (frame_len - 1))\n\nframes = jnp.stack([x[i * hop_len : i * hop_len + frame_len] for i in range(n_frames)])\nwindowed = frames * hamming\nspectra = jnp.abs(jnp.fft.rfft(windowed, n=n_fft)) ** 2\n\n# \u7b80\u5355\u7684\u6885\u5c14\u6ee4\u6ce2\u5668\u7ec4\ndef hz_to_mel(f): return 2595 * jnp.log10(1 + f / 700)\ndef mel_to_hz(m): return 700 * (10 ** (m / 2595) - 1)\n\nmel_points = jnp.linspace(hz_to_mel(0), hz_to_mel(fs / 2), n_mels + 2)\nhz_pts = mel_to_hz(mel_points)\nbins = jnp.floor((n_fft + 1) * hz_pts / fs).astype(jnp.int32)\n\nn_freqs = n_fft // 2 + 1\nfb = jnp.zeros((n_mels, n_freqs))\nfor m in range(n_mels):\n    lo, mid, hi = int(bins[m]), int(bins[m+1]), int(bins[m+2])\n    for k in range(lo, mid):\n        if mid != lo:\n            fb = fb.at[m, k].set((k - lo) / (mid - lo))\n    for k in range(mid, hi):\n        if hi != mid:\n            fb = fb.at[m, k].set((hi - k) / (hi - mid))\n\nlog_mel = jnp.log(spectra @ fb.T + 1e-10)\n\n# --- SpecAugment ---\ndef spec_augment(spec, key, n_freq_masks=2, freq_mask_width=15,\n                 n_time_masks=2, time_mask_width=25):\n    \"\"\"\u5e94\u7528 
SpecAugment\uff1a\u9891\u7387\u63a9\u7801\u548c\u65f6\u95f4\u63a9\u7801\u3002\"\"\"\n    augmented = spec.copy()\n    T, F = spec.shape\n\n    # \u9891\u7387\u63a9\u7801\n    for _ in range(n_freq_masks):\n        key, k1, k2 = jax.random.split(key, 3)\n        f_width = jax.random.randint(k1, (), 1, freq_mask_width + 1)\n        f_start = jax.random.randint(k2, (), 0, max(1, F - freq_mask_width))\n        mask = (jnp.arange(F) >= f_start) & (jnp.arange(F) < f_start + f_width)\n        augmented = jnp.where(mask[None, :], 0.0, augmented)\n\n    # \u65f6\u95f4\u63a9\u7801\n    for _ in range(n_time_masks):\n        key, k1, k2 = jax.random.split(key, 3)\n        t_width = jax.random.randint(k1, (), 1, time_mask_width + 1)\n        t_start = jax.random.randint(k2, (), 0, max(1, T - time_mask_width))\n        mask = (jnp.arange(T) >= t_start) & (jnp.arange(T) < t_start + t_width)\n        augmented = jnp.where(mask[:, None], 0.0, augmented)\n\n    return augmented\n\nkey, subkey = jax.random.split(key)\nlog_mel_aug = spec_augment(log_mel, subkey)\n\n# --- \u53ef\u89c6\u5316 ---\nfig, axes = plt.subplots(2, 1, figsize=(14, 8))\n\nim0 = axes[0].imshow(log_mel.T, aspect='auto', origin='lower', cmap='inferno',\n                       extent=[0, duration, 0, n_mels])\naxes[0].set_title('\u539f\u59cb\u5bf9\u6570\u6885\u5c14\u9891\u8c31\u56fe')\naxes[0].set_xlabel('\u65f6\u95f4 (s)'); axes[0].set_ylabel('\u6885\u5c14\u9891\u5e26')\nplt.colorbar(im0, ax=axes[0], label='\u5bf9\u6570\u80fd\u91cf')\n\nim1 = axes[1].imshow(log_mel_aug.T, aspect='auto', origin='lower', cmap='inferno',\n                       extent=[0, duration, 0, n_mels])\naxes[1].set_title('SpecAugment \u540e\uff08\u9891\u7387 + \u65f6\u95f4\u63a9\u7801\uff09')\naxes[1].set_xlabel('\u65f6\u95f4 (s)'); axes[1].set_ylabel('\u6885\u5c14\u9891\u5e26')\nplt.colorbar(im1, ax=axes[1], label='\u5bf9\u6570\u80fd\u91cf')\n\nplt.tight_layout(); plt.show()\n

"},{"location":"chapter%2009%3A%20audio%20and%20speech/03.%20text%20to%20speech%20and%20voice/","title":"\u8bed\u97f3\u5408\u6210\u4e0e\u58f0\u97f3","text":"

\u8bed\u97f3\u5408\u6210\uff08Text-to-Speech Synthesis\uff09\u9006\u5411\u6267\u884c ASR \u6d41\u6c34\u7ebf\uff0c\u4ece\u4e66\u9762\u6587\u672c\u751f\u6210\u81ea\u7136\u542c\u611f\u7684\u97f3\u9891\u3002\u672c\u6587\u6db5\u76d6 TTS \u6d41\u6c34\u7ebf\uff08\u6587\u672c\u89c4\u8303\u5316\u3001G2P\u3001\u58f0\u5b66\u6a21\u578b\u3001\u58f0\u7801\u5668\uff09\u3001Tacotron\u3001WaveNet\u3001HiFi-GAN\u3001\u58f0\u97f3\u514b\u9686\u3001\u58f0\u97f3\u8f6c\u6362\u4ee5\u53ca\u8bed\u97f3\u6d3b\u52a8\u68c0\u6d4b\uff08VAD\uff09\u3002

\\[P(x) = \\prod_{t=1}^{T} P(x_t \\mid x_1, \\ldots, x_{t-1}, c)\\] \\[z = \\tanh(W_{f} \\ast x) \\odot \\sigma(W_{g} \\ast x)\\] \\[\\log P(x) = \\log P(z) + \\sum_{i} \\log \\left| \\det \\frac{\\partial f_i}{\\partial f_{i-1}} \\right|\\]

\\[\\mathcal{L}_G = \\mathcal{L}_{\\text{adv}}(G) + \\lambda_{\\text{mel}} \\mathcal{L}_{\\text{mel}}(G) + \\lambda_{\\text{fm}} \\mathcal{L}_{\\text{fm}}(G)\\]

\\[e_{i,j} = w^T \\tanh(W_s s_{i-1} + W_h h_j + W_f f_{i,j} + b)\\] \\[ \\begin{aligned} \\hat{d}_i &= \\text{DurationPredictor}(h_i) \\\\ \\hat{p}_i &= \\text{PitchPredictor}(h_i) \\\\ \\hat{e}_i &= \\text{EnergyPredictor}(h_i) \\end{aligned} \\]

"},{"location":"chapter%2009%3A%20audio%20and%20speech/03.%20text%20to%20speech%20and%20voice/#colab-notebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u751f\u6210\u5408\u6210\u6ce2\u5f62\uff08\u6a21\u62df\u5143\u97f3\u7684\u8c10\u6ce2\u4e4b\u548c\uff09\nsr = 16000\nduration = 1.0\nt = jnp.linspace(0, duration, int(sr * duration))\nf0 = 220.0  # \u57fa\u9891\nwaveform = (\n    0.6 * jnp.sin(2 * jnp.pi * f0 * t) +\n    0.3 * jnp.sin(2 * jnp.pi * 2 * f0 * t) +\n    0.1 * jnp.sin(2 * jnp.pi * 3 * f0 * t)\n)\n\n# \u8ba1\u7b97 STFT\nn_fft = 1024\nhop_length = 256\nwindow = jnp.hanning(n_fft)\n\ndef stft(signal, n_fft, hop_length, window):\n    \"\"\"\u8ba1\u7b97\u77ed\u65f6\u5085\u91cc\u53f6\u53d8\u6362\u3002\"\"\"\n    n_frames = 1 + (len(signal) - n_fft) // hop_length\n    frames = jnp.stack([\n        signal[i * hop_length : i * hop_length + n_fft] * window\n        for i in range(n_frames)\n    ])\n    return jnp.fft.rfft(frames, n=n_fft)\n\ndef istft(stft_matrix, hop_length, window, length):\n    \"\"\"\u4f7f\u7528\u91cd\u53e0\u76f8\u52a0\u6cd5\u8ba1\u7b97\u9006 STFT\u3002\"\"\"\n    n_fft = (stft_matrix.shape[1] - 1) * 2\n    n_frames = stft_matrix.shape[0]\n    frames = jnp.fft.irfft(stft_matrix, n=n_fft)\n    frames = frames * window[None, :]\n    output = jnp.zeros(length)\n    for i in range(n_frames):\n        start = i * hop_length\n        end = start + n_fft\n        if end <= length:\n            output = output.at[start:end].add(frames[i])\n    return output\n\n# \u6b63\u5411 STFT\nS = stft(waveform, n_fft, hop_length, window)\nmagnitude = jnp.abs(S)\n\n# \u6885\u5c14\u6ee4\u6ce2\u5668\u7ec4\nn_mels = 80\nmel_low = 0.0\nmel_high = 2595 * jnp.log10(1 + (sr / 2) / 700)\nmel_points = jnp.linspace(mel_low, mel_high, n_mels + 2)\nhz_points = 700 * (10 ** (mel_points / 2595) - 1)\nfreq_bins = jnp.floor((n_fft + 1) * hz_points / sr).astype(int)\n\nmel_filterbank = jnp.zeros((n_mels, n_fft // 2 + 1))\nfor m in range(n_mels):\n    f_left = freq_bins[m]\n    f_center = freq_bins[m + 1]\n    f_right = freq_bins[m + 2]\n    for k in 
range(f_left, f_center):\n        mel_filterbank = mel_filterbank.at[m, k].set(\n            (k - f_left) / max(f_center - f_left, 1)\n        )\n    for k in range(f_center, f_right):\n        mel_filterbank = mel_filterbank.at[m, k].set(\n            (f_right - k) / max(f_right - f_center, 1)\n        )\n\n# \u8f6c\u5230\u6885\u5c14\u5e76\u8fd4\u56de\uff08\u4f2a\u9006\uff09\nmel_spec = magnitude @ mel_filterbank.T\nmagnitude_reconstructed = mel_spec @ jnp.linalg.pinv(mel_filterbank.T)\nmagnitude_reconstructed = jnp.maximum(magnitude_reconstructed, 1e-7)\n\n# Griffin-Lim \u7b97\u6cd5\ndef griffin_lim(magnitude, n_iter, hop_length, window, signal_length):\n    \"\"\"\u8fed\u4ee3\u76f8\u4f4d\u91cd\u6784\u3002\"\"\"\n    n_fft = (magnitude.shape[1] - 1) * 2\n    key = jax.random.PRNGKey(42)\n    phase = jax.random.uniform(key, magnitude.shape, minval=-jnp.pi, maxval=jnp.pi)\n\n    for _ in range(n_iter):\n        complex_spec = magnitude * jnp.exp(1j * phase)\n        signal = istft(complex_spec, hop_length, window, signal_length)\n        reanalysis = stft(signal, n_fft, hop_length, window)\n        phase = jnp.angle(reanalysis)\n\n    complex_spec = magnitude * jnp.exp(1j * phase)\n    return istft(complex_spec, hop_length, window, signal_length)\n\nreconstructed = griffin_lim(magnitude_reconstructed, n_iter=60, hop_length=hop_length,\n                            window=window, signal_length=len(waveform))\n\n# \u7ed8\u5236\u5bf9\u6bd4\u56fe\nfig, axes = plt.subplots(3, 1, figsize=(12, 8))\n\naxes[0].plot(t[:1000], waveform[:1000], color='#3498db', linewidth=0.8)\naxes[0].set_title('\u539f\u59cb\u6ce2\u5f62')\naxes[0].set_ylabel('\u632f\u5e45')\n\naxes[1].imshow(jnp.log1p(mel_spec.T), aspect='auto', origin='lower', cmap='magma')\naxes[1].set_title('\u6885\u5c14\u8bed\u8c31\u56fe\uff08\u4e2d\u95f4\u8868\u793a\uff09')\naxes[1].set_ylabel('\u6885\u5c14\u9891\u5e26')\n\naxes[2].plot(t[:1000], reconstructed[:1000], color='#e74c3c', 
linewidth=0.8)\naxes[2].set_title('Griffin-Lim \u91cd\u6784\u6ce2\u5f62\uff0860 \u6b21\u8fed\u4ee3\uff09')\naxes[2].set_xlabel('\u65f6\u95f4 (\u79d2)')\naxes[2].set_ylabel('\u632f\u5e45')\n\nplt.tight_layout()\nplt.show()\n\n# \u6d4b\u91cf\u91cd\u6784\u8bef\u5dee\nmse = jnp.mean((waveform[:len(reconstructed)] - reconstructed[:len(waveform)]) ** 2)\nprint(f\"\u539f\u59cb\u4e0e\u91cd\u6784\u4e4b\u95f4\u7684 MSE\uff1a{mse:.6f}\")\nprint(\"\u6ce8\u610f\uff1a\u901a\u8fc7\u6885\u5c14\u53cd\u6f14\u5bfc\u81f4\u7684\u76f8\u4f4d\u4fe1\u606f\u4e22\u5931\u4f1a\u5f15\u8d77\u4f2a\u5f71\u3002\")\n
import jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport matplotlib.pyplot as plt\n\n# \u6a21\u62df\u5e26\u771f\u5b9e\u65f6\u957f\u7684\u97f3\u7d20\u5e8f\u5217\n# \u5728\u771f\u5b9e TTS \u4e2d\uff0c\u65f6\u957f\u6765\u81ea\u5f3a\u5236\u5bf9\u9f50\u6216\u6559\u5e08\u6ce8\u610f\u529b\ndef generate_synthetic_data(key, n_samples=200, max_phonemes=30, embed_dim=64):\n    \"\"\"\u751f\u6210\u5408\u6210\u97f3\u7d20\u5d4c\u5165\u548c\u65f6\u957f\u3002\"\"\"\n    keys = jr.split(key, 4)\n    lengths = jr.randint(keys[0], (n_samples,), 5, max_phonemes)\n\n    all_embeddings = []\n    all_durations = []\n    all_masks = []\n\n    for i in range(n_samples):\n        L = int(lengths[i])\n        emb = jr.normal(keys[1], (max_phonemes, embed_dim))\n        # \u65f6\u957f\uff1a\u5143\u97f3\uff08\u5076\u6570\u7d22\u5f15\uff09\u8f83\u957f\uff0c\u8f85\u97f3\u8f83\u77ed\n        base_dur = jnp.where(jnp.arange(max_phonemes) % 2 == 0, 8.0, 4.0)\n        noise = jr.normal(jr.fold_in(keys[2], i), (max_phonemes,)) * 1.5\n        dur = jnp.clip(base_dur + noise, 1.0, 20.0).astype(jnp.float32)\n        mask = (jnp.arange(max_phonemes) < L).astype(jnp.float32)\n\n        all_embeddings.append(emb)\n        all_durations.append(dur * mask)\n        all_masks.append(mask)\n\n    return (jnp.stack(all_embeddings), jnp.stack(all_durations),\n            jnp.stack(all_masks))\n\nkey = jr.PRNGKey(42)\nembeddings, durations, masks = generate_synthetic_data(key)\n\n# \u65f6\u957f\u9884\u6d4b\u5668\uff1a2 \u5c42\u4e00\u7ef4\u5377\u79ef + \u7ebf\u6027\u6295\u5f71\ndef init_duration_predictor(key, embed_dim=64, hidden_dim=128, kernel_size=3):\n    \"\"\"\u521d\u59cb\u5316\u65f6\u957f\u9884\u6d4b\u5668\u6743\u91cd\u3002\"\"\"\n    keys = jr.split(key, 4)\n    scale1 = jnp.sqrt(2.0 / (embed_dim * kernel_size))\n    scale2 = jnp.sqrt(2.0 / (hidden_dim * kernel_size))\n    params = {\n        'conv1_w': jr.normal(keys[0], (kernel_size, embed_dim, hidden_dim)) * scale1,\n        'conv1_b': 
jnp.zeros(hidden_dim),\n        'conv2_w': jr.normal(keys[1], (kernel_size, hidden_dim, hidden_dim)) * scale2,\n        'conv2_b': jnp.zeros(hidden_dim),\n        'linear_w': jr.normal(keys[2], (hidden_dim, 1)) * jnp.sqrt(2.0 / hidden_dim),\n        'linear_b': jnp.zeros(1),\n    }\n    return params\n\ndef duration_predictor(params, x):\n    \"\"\"\u4ece\u97f3\u7d20\u5d4c\u5165\u9884\u6d4b\u5bf9\u6570\u65f6\u957f\u3002x: (batch, seq, embed)\u3002\"\"\"\n    # \u5377\u79ef\u5c42 1 \u52a0 ReLU\n    h = jax.lax.conv_general_dilated(\n        x.transpose(0, 2, 1),  # (batch, embed, seq)\n        params['conv1_w'].transpose(2, 1, 0),  # (out, in, kernel)\n        window_strides=(1,), padding='SAME'\n    ).transpose(0, 2, 1) + params['conv1_b']  # \u56de\u5230 (batch, seq, hidden)\n    h = jax.nn.relu(h)\n\n    # \u5377\u79ef\u5c42 2 \u52a0 ReLU\n    h = jax.lax.conv_general_dilated(\n        h.transpose(0, 2, 1),\n        params['conv2_w'].transpose(2, 1, 0),\n        window_strides=(1,), padding='SAME'\n    ).transpose(0, 2, 1) + params['conv2_b']\n    h = jax.nn.relu(h)\n\n    # \u7ebf\u6027\u6295\u5f71\u5230\u6807\u91cf\n    log_dur = (h @ params['linear_w'] + params['linear_b']).squeeze(-1)\n    return log_dur\n\n# \u635f\u5931\uff1a\u5bf9\u6570\u65f6\u957f\u7684 MSE\uff08FastSpeech \u4e2d\u7684\u6807\u51c6\u505a\u6cd5\uff09\ndef loss_fn(params, embeddings, durations, masks):\n    log_dur_pred = duration_predictor(params, embeddings)\n    log_dur_true = jnp.log(jnp.clip(durations, 1.0, None))\n    sq_err = (log_dur_pred - log_dur_true) ** 2 * masks\n    return jnp.sum(sq_err) / jnp.sum(masks)\n\ngrad_fn = jax.jit(jax.value_and_grad(loss_fn))\n\n# \u8bad\u7ec3\u5faa\u73af\nparams = init_duration_predictor(jr.PRNGKey(0))\nlr = 1e-3\nlosses = []\n\nfor epoch in range(300):\n    loss_val, grads = grad_fn(params, embeddings, durations, masks)\n    params = jax.tree.map(lambda p, g: p - lr * g, params, grads)\n    losses.append(float(loss_val))\n\n# 
\u5728\u4e00\u4e2a\u6837\u672c\u4e0a\u8bc4\u4f30\nlog_dur_pred = duration_predictor(params, embeddings[:1])\ndur_pred = jnp.exp(log_dur_pred[0])\ndur_true = durations[0]\nmask = masks[0]\nvalid_len = int(jnp.sum(mask))\n\nfig, axes = plt.subplots(1, 2, figsize=(14, 5))\n\naxes[0].plot(losses, color='#3498db', linewidth=1.5)\naxes[0].set_xlabel('\u8f6e\u6b21')\naxes[0].set_ylabel('MSE \u635f\u5931\uff08\u5bf9\u6570\u65f6\u957f\uff09')\naxes[0].set_title('\u65f6\u957f\u9884\u6d4b\u5668\u8bad\u7ec3')\naxes[0].set_yscale('log')\n\nx_pos = jnp.arange(valid_len)\nwidth = 0.35\naxes[1].bar(x_pos - width/2, dur_true[:valid_len], width, color='#27ae60',\n            label='\u771f\u5b9e\u503c', alpha=0.8)\naxes[1].bar(x_pos + width/2, dur_pred[:valid_len], width, color='#e74c3c',\n            label='\u9884\u6d4b\u503c', alpha=0.8)\naxes[1].set_xlabel('\u97f3\u7d20\u7d22\u5f15')\naxes[1].set_ylabel('\u65f6\u957f\uff08\u5e27\uff09')\naxes[1].set_title('\u65f6\u957f\u9884\u6d4b\u4e0e\u771f\u5b9e\u503c\u5bf9\u6bd4')\naxes[1].legend()\n\nplt.tight_layout()\nplt.show()\n
import jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport matplotlib.pyplot as plt\n\ndef init_residual_block(key, channels, kernel_size, dilation):\n    \"\"\"\u521d\u59cb\u5316\u6269\u5f20\u6b8b\u5dee\u5377\u79ef\u5757\u3002\"\"\"\n    k1, k2 = jr.split(key)\n    scale = jnp.sqrt(2.0 / (channels * kernel_size))\n    return {\n        'conv1_w': jr.normal(k1, (kernel_size, channels, channels)) * scale,\n        'conv1_b': jnp.zeros(channels),\n        'conv2_w': jr.normal(k2, (kernel_size, channels, channels)) * scale,\n        'conv2_b': jnp.zeros(channels),\n        'dilation': dilation\n    }\n\ndef residual_block(params, x):\n    \"\"\"x: (batch, time, channels)\u3002\u5e26 LeakyReLU \u7684\u6269\u5f20\u5377\u79ef\u6b8b\u5dee\u5757\u3002\"\"\"\n    h = jax.nn.leaky_relu(x, negative_slope=0.1)\n    # \u7b80\u5316\uff1a\u4f7f\u7528\u6807\u51c6\u5377\u79ef\uff08\u6269\u5f20\u5728\u6982\u5ff5\u4e0a\u5904\u7406\uff09\n    h = jax.lax.conv_general_dilated(\n        h.transpose(0, 2, 1),\n        params['conv1_w'].transpose(2, 1, 0),\n        window_strides=(1,),\n        padding='SAME',\n        rhs_dilation=(params['dilation'],)\n    ).transpose(0, 2, 1) + params['conv1_b']\n    h = jax.nn.leaky_relu(h, negative_slope=0.1)\n    h = jax.lax.conv_general_dilated(\n        h.transpose(0, 2, 1),\n        params['conv2_w'].transpose(2, 1, 0),\n        window_strides=(1,),\n        padding='SAME'\n    ).transpose(0, 2, 1) + params['conv2_b']\n    return x + h\n\ndef init_generator(key, n_mels=80, upsample_rates=(8, 8, 4),\n                   channels=128):\n    \"\"\"\u521d\u59cb\u5316\u6700\u5c0f\u5316\u7684 HiFi-GAN \u98ce\u683c\u751f\u6210\u5668\u3002\"\"\"\n    keys = jr.split(key, 10)\n    params = {}\n\n    # \u8f93\u5165\u6295\u5f71\uff1a\u6885\u5c14\u9891\u5e26 -> \u901a\u9053\n    params['input_w'] = jr.normal(keys[0], (7, n_mels, channels)) * 0.02\n    params['input_b'] = jnp.zeros(channels)\n\n    # 
\u4e0a\u91c7\u6837\u5757\uff08\u8f6c\u7f6e\u5377\u79ef\uff09\n    in_ch = channels\n    for i, rate in enumerate(upsample_rates):\n        k_size = rate * 2\n        scale = jnp.sqrt(2.0 / (in_ch * k_size))\n        out_ch = in_ch // 2\n        params[f'up{i}_w'] = jr.normal(keys[i+1], (k_size, in_ch, out_ch)) * scale\n        params[f'up{i}_b'] = jnp.zeros(out_ch)\n        # \u6bcf\u4e2a\u5c3a\u5ea6\u4e0b\u7684\u6b8b\u5dee\u5757\n        params[f'res{i}_0'] = init_residual_block(jr.fold_in(keys[i+4], 0),\n                                                    out_ch, 3, 1)\n        params[f'res{i}_1'] = init_residual_block(jr.fold_in(keys[i+4], 1),\n                                                    out_ch, 3, 3)\n        in_ch = out_ch\n\n    # \u8f93\u51fa\u6295\u5f71\u5230\u5355\u58f0\u9053\u6ce2\u5f62\n    params['output_w'] = jr.normal(keys[8], (7, in_ch, 1)) * 0.02\n    params['output_b'] = jnp.zeros(1)\n    params['upsample_rates'] = upsample_rates\n\n    return params\n\ndef generator_forward(params, mel):\n    \"\"\"mel: (batch, time, n_mels) -> waveform: (batch, time * prod(rates), 1)\u3002\"\"\"\n    # \u8f93\u5165\u6295\u5f71\n    h = jax.lax.conv_general_dilated(\n        mel.transpose(0, 2, 1),\n        params['input_w'].transpose(2, 1, 0),\n        window_strides=(1,), padding='SAME'\n    ).transpose(0, 2, 1) + params['input_b']\n\n    for i, rate in enumerate(params['upsample_rates']):\n        h = jax.nn.leaky_relu(h, negative_slope=0.1)\n        # \u901a\u8fc7\u8f6c\u7f6e\u5377\u79ef\u4e0a\u91c7\u6837\n        k_size = rate * 2\n        h = jax.lax.conv_transpose(\n            h.transpose(0, 2, 1),\n            params[f'up{i}_w'].transpose(2, 1, 0),\n            strides=(rate,),\n            padding='SAME'\n        ).transpose(0, 2, 1) + params[f'up{i}_b']\n        # \u6b8b\u5dee\u5757\n        h = residual_block(params[f'res{i}_0'], h)\n        h = residual_block(params[f'res{i}_1'], h)\n\n    h = jax.nn.leaky_relu(h, negative_slope=0.1)\n    out 
= jax.lax.conv_general_dilated(\n        h.transpose(0, 2, 1),\n        params['output_w'].transpose(2, 1, 0),\n        window_strides=(1,), padding='SAME'\n    ).transpose(0, 2, 1) + params['output_b']\n\n    return jnp.tanh(out)\n\n# \u521b\u5efa\u4e00\u4e2a\u5408\u6210\u6885\u5c14\u8bed\u8c31\u56fe\uff08\u6a21\u62df\u5143\u97f3\uff09\nn_mels = 80\nn_frames = 50\nmel = jnp.zeros((1, n_frames, n_mels))\n# \u5728\u4f4e\u9891\u6885\u5c14\u9891\u5e26\u4e2d\u6dfb\u52a0\u80fd\u91cf\uff08\u6a21\u62df\u5171\u632f\u5cf0\uff09\nmel = mel.at[:, :, 5:15].set(1.0)\nmel = mel.at[:, :, 20:25].set(0.6)\n\n# \u521d\u59cb\u5316\u5e76\u8fd0\u884c\u751f\u6210\u5668\nkey = jr.PRNGKey(42)\nparams = init_generator(key, n_mels=n_mels, upsample_rates=(8, 8, 4),\n                         channels=128)\nwaveform = generator_forward(params, mel)\n\nprint(f\"\u8f93\u5165\u6885\u5c14\u5f62\u72b6\uff1a{mel.shape}\")\nprint(f\"\u8f93\u51fa\u6ce2\u5f62\u5f62\u72b6\uff1a{waveform.shape}\")\nprint(f\"\u4e0a\u91c7\u6837\u56e0\u5b50\uff1a{8 * 8 * 4} = {8*8*4}x\")\n\nfig, axes = plt.subplots(2, 1, figsize=(12, 6))\n\naxes[0].imshow(mel[0].T, aspect='auto', origin='lower', cmap='magma')\naxes[0].set_title('\u8f93\u5165\u6885\u5c14\u8bed\u8c31\u56fe')\naxes[0].set_ylabel('\u6885\u5c14\u9891\u5e26')\naxes[0].set_xlabel('\u5e27')\n\nwaveform_np = waveform[0, :, 0]\naxes[1].plot(waveform_np[:2000], color='#9b59b6', linewidth=0.5)\naxes[1].set_title('\u751f\u6210\u5668\u8f93\u51fa\u6ce2\u5f62\uff08\u672a\u7ecf\u8bad\u7ec3 - \u968f\u673a\u566a\u58f0\uff09')\naxes[1].set_ylabel('\u632f\u5e45')\naxes[1].set_xlabel('\u6837\u672c')\n\nplt.tight_layout()\nplt.show()\nprint(\"\u6ce8\u610f\uff1a\u8f93\u51fa\u662f\u566a\u58f0\uff0c\u56e0\u4e3a\u751f\u6210\u5668\u672a\u7ecf\u8bad\u7ec3\u3002\")\nprint(\"\u5728\u5b9e\u8df5\u4e2d\uff0c\u5bf9\u6297\u635f\u5931 + \u6885\u5c14\u635f\u5931\u8bad\u7ec3\u4f1a\u5c06\u5176\u5851\u9020\u6210\u8bed\u97f3\u3002\")\n
import jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport matplotlib.pyplot as plt\n\n# \u751f\u6210\u5408\u6210\u5bf9\u6570\u6885\u5c14\u80fd\u91cf\u7279\u5f81\u53ca\u8bed\u97f3/\u9759\u97f3\u6807\u7b7e\ndef generate_vad_data(key, n_sequences=100, n_frames=200, n_features=40):\n    \"\"\"\u6a21\u62df\u5bf9\u6570\u6885\u5c14\u7279\u5f81\uff1a\u8bed\u97f3\u533a\u57df\u80fd\u91cf\u66f4\u9ad8\u4e14\u5177\u6709\u7ed3\u6784\u3002\"\"\"\n    keys = jr.split(key, 5)\n    all_features = []\n    all_labels = []\n\n    for i in range(n_sequences):\n        k = jr.fold_in(keys[0], i)\n        k1, k2, k3 = jr.split(k, 3)\n\n        # \u968f\u673a\u8bed\u97f3/\u9759\u97f3\u6a21\u5f0f\n        label = jnp.zeros(n_frames)\n        n_segments = jr.randint(k1, (), 2, 6)\n        for seg in range(int(n_segments)):\n            start = jr.randint(jr.fold_in(k2, seg), (), 0, n_frames - 20)\n            length = jr.randint(jr.fold_in(k3, seg), (), 10, 50)\n            end = jnp.minimum(start + length, n_frames)\n            label = label.at[int(start):int(end)].set(1.0)\n\n        # \u7279\u5f81\uff1a\u8bed\u97f3\u5e27\u5177\u6709\u66f4\u9ad8\u80fd\u91cf + \u9891\u8c31\u7ed3\u6784\n        noise = jr.normal(jr.fold_in(keys[1], i), (n_frames, n_features)) * 0.3\n        speech_pattern = jnp.outer(label, jnp.exp(-jnp.arange(n_features) / 15.0))\n        features = speech_pattern * 2.0 + noise + 0.1\n\n        all_features.append(features)\n        all_labels.append(label)\n\n    return jnp.stack(all_features), jnp.stack(all_labels)\n\nkey = jr.PRNGKey(123)\nfeatures, labels = generate_vad_data(key)\ntrain_features, train_labels = features[:80], labels[:80]\ntest_features, test_labels = features[80:], labels[80:]\n\n# \u57fa\u4e8e GRU \u7684\u7b80\u5355 VAD \u6a21\u578b\ndef init_vad_model(key, input_dim=40, hidden_dim=64):\n    keys = jr.split(key, 6)\n    scale_ih = jnp.sqrt(2.0 / input_dim)\n    scale_hh = jnp.sqrt(2.0 / hidden_dim)\n    return {\n        'W_z': 
jr.normal(keys[0], (input_dim, hidden_dim)) * scale_ih,\n        'U_z': jr.normal(keys[1], (hidden_dim, hidden_dim)) * scale_hh,\n        'b_z': jnp.zeros(hidden_dim),\n        'W_r': jr.normal(keys[2], (input_dim, hidden_dim)) * scale_ih,\n        'U_r': jr.normal(keys[3], (hidden_dim, hidden_dim)) * scale_hh,\n        'b_r': jnp.zeros(hidden_dim),\n        'W_h': jr.normal(keys[4], (input_dim, hidden_dim)) * scale_ih,\n        'U_h': jr.normal(keys[5], (hidden_dim, hidden_dim)) * scale_hh,\n        'b_h': jnp.zeros(hidden_dim),\n        'W_out': jr.normal(jr.fold_in(keys[0], 99), (hidden_dim, 1)) * 0.1,\n        'b_out': jnp.zeros(1),\n    }\n\ndef gru_step(params, h, x):\n    \"\"\"\u5355\u6b65 GRU\u3002\"\"\"\n    z = jax.nn.sigmoid(x @ params['W_z'] + h @ params['U_z'] + params['b_z'])\n    r = jax.nn.sigmoid(x @ params['W_r'] + h @ params['U_r'] + params['b_r'])\n    h_tilde = jnp.tanh(x @ params['W_h'] + (r * h) @ params['U_h'] + params['b_h'])\n    h_new = (1 - z) * h + z * h_tilde\n    return h_new\n\ndef vad_forward(params, x):\n    \"\"\"x: (batch, time, features) -> logits: (batch, time)\u3002\"\"\"\n    batch_size, n_frames, _ = x.shape\n    hidden_dim = params['W_z'].shape[1]\n    h = jnp.zeros((batch_size, hidden_dim))\n\n    outputs = []\n    for t in range(n_frames):\n        h = gru_step(params, h, x[:, t, :])\n        logit = (h @ params['W_out'] + params['b_out']).squeeze(-1)\n        outputs.append(logit)\n\n    return jnp.stack(outputs, axis=1)\n\ndef bce_loss(params, features, labels):\n    \"\"\"VAD \u7684\u4e8c\u5143\u4ea4\u53c9\u71b5\u635f\u5931\u3002\"\"\"\n    logits = vad_forward(params, features)\n    probs = jax.nn.sigmoid(logits)\n    probs = jnp.clip(probs, 1e-7, 1 - 1e-7)\n    loss = -(labels * jnp.log(probs) + (1 - labels) * jnp.log(1 - probs))\n    return jnp.mean(loss)\n\ngrad_fn = jax.jit(jax.value_and_grad(bce_loss))\n\n# \u8bad\u7ec3\nparams = init_vad_model(jr.PRNGKey(0))\nlr = 5e-3\nlosses = []\n\nfor epoch in range(200):\n 
   loss_val, grads = grad_fn(params, train_features, train_labels)\n    params = jax.tree.map(lambda p, g: p - lr * g, params, grads)\n    losses.append(float(loss_val))\n    if epoch % 50 == 0:\n        print(f\"\u8f6e\u6b21 {epoch}\uff1a\u635f\u5931 = {loss_val:.4f}\")\n\n# \u5728\u6d4b\u8bd5\u96c6\u4e0a\u8bc4\u4f30\ntest_logits = vad_forward(params, test_features)\ntest_preds = (jax.nn.sigmoid(test_logits) > 0.5).astype(jnp.float32)\naccuracy = jnp.mean(test_preds == test_labels)\nprint(f\"\\n\u6d4b\u8bd5\u51c6\u786e\u7387\uff1a{accuracy:.4f}\")\n\n# \u53ef\u89c6\u5316\u4e00\u4e2a\u6d4b\u8bd5\u793a\u4f8b\nidx = 0\nfig, axes = plt.subplots(3, 1, figsize=(14, 7))\n\naxes[0].imshow(test_features[idx].T, aspect='auto', origin='lower', cmap='magma')\naxes[0].set_title('\u5bf9\u6570\u6885\u5c14\u80fd\u91cf\u7279\u5f81')\naxes[0].set_ylabel('\u6885\u5c14\u9891\u5e26')\n\naxes[1].fill_between(range(200), test_labels[idx], alpha=0.4, color='#27ae60',\n                     label='\u771f\u5b9e\u503c')\naxes[1].plot(jax.nn.sigmoid(test_logits[idx]), color='#e74c3c',\n             linewidth=1.5, label='\u9884\u6d4b\u6982\u7387')\naxes[1].axhline(0.5, color='gray', linestyle='--', linewidth=0.8)\naxes[1].set_ylabel('\u8bed\u97f3\u6982\u7387')\naxes[1].legend()\naxes[1].set_title('VAD \u9884\u6d4b')\n\naxes[2].fill_between(range(200), test_labels[idx], alpha=0.4, color='#27ae60',\n                     label='\u771f\u5b9e\u503c')\naxes[2].fill_between(range(200), test_preds[idx], alpha=0.4, color='#f39c12',\n                     label='\u9884\u6d4b\uff08\u9608\u503c=0.5\uff09')\naxes[2].set_ylabel('\u8bed\u97f3 / \u9759\u97f3')\naxes[2].set_xlabel('\u5e27')\naxes[2].legend()\naxes[2].set_title('VAD \u4e8c\u503c\u51b3\u7b56')\n\nplt.tight_layout()\nplt.show()\n
"},{"location":"chapter%2009%3A%20audio%20and%20speech/04.%20speaker%20and%20audio%20analysis/","title":"\u8bf4\u8bdd\u4eba\u4e0e\u97f3\u9891\u5206\u6790","text":"

\u8bf4\u8bdd\u4eba\u4e0e\u97f3\u9891\u5206\u6790\u65e8\u5728\u8bc6\u522b\u8c01\u5728\u8bf4\u8bdd\u3001\u4f55\u65f6\u8bf4\u8bdd\u4ee5\u53ca\u5b58\u5728\u54ea\u4e9b\u975e\u8bed\u8a00\u58f0\u97f3\u3002\u672c\u6587\u6db5\u76d6\u8bf4\u8bdd\u4eba\u786e\u8ba4\u4e0e\u8bc6\u522b\u3001i\u5411\u91cf\u3001d\u5411\u91cf\u3001x\u5411\u91cf\u3001\u8bf4\u8bdd\u4eba\u65e5\u5fd7\u3001\u97f3\u9891\u4e8b\u4ef6\u5206\u7c7b\u3001\u97f3\u4e50\u4fe1\u606f\u68c0\u7d22\u4ee5\u53ca\u8bed\u97f3\u60c5\u611f\u8bc6\u522b\u3002

\\[s = \\frac{e \\cdot t}{\\|e\\| \\, \\|t\\|}\\] \\[M = m + Tw\\] \\[\\text{score}(w_1, w_2) = \\log \\frac{P(w_1, w_2 \\mid \\text{\u540c\u4e00\u8bf4\u8bdd\u4eba})}{P(w_1 \\mid \\text{\u8bf4\u8bdd\u4eba}_1) \\, P(w_2 \\mid \\text{\u8bf4\u8bdd\u4eba}_2)}\\]

\\[ \\begin{aligned} \\mu &= \\frac{1}{T} \\sum_{t=1}^{T} h_t \\\\ \\sigma &= \\sqrt{\\frac{1}{T} \\sum_{t=1}^{T} (h_t - \\mu)^2} \\end{aligned} \\] \\[\\alpha_t = \\frac{\\exp(v^T f(h_t))}{\\sum_{\\tau} \\exp(v^T f(h_\\tau))}\\] \\[L = -\\log \\frac{e^{s \\cos(\\theta_{y_i} + m)}}{e^{s \\cos(\\theta_{y_i} + m)} + \\sum_{j \\neq y_i} e^{s \\cos \\theta_j}}\\]

\\[\\hat{y}_{t,s} = \\sigma(f_s(h_t))\\] \\[\\hat{Y}_c = \\sigma\\left(\\sum_t \\alpha_{t,c} \\cdot f_{t,c}\\right)\\]

\\[\\text{chroma}(p) = \\sum_{k : \\text{pitch}(k) \\bmod 12 = p} |X(k)|^2\\] "},{"location":"chapter%2009%3A%20audio%20and%20speech/04.%20speaker%20and%20audio%20analysis/#colab","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528 Colab \u6216\u7b14\u8bb0\u672c\uff09","text":"
import jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport matplotlib.pyplot as plt\n\n# Simulate frame-level MFCC features for multiple speakers\ndef generate_speaker_data(key, n_speakers=5, utterances_per_speaker=20,\n                          n_frames=100, n_features=40):\n    \"\"\"Generate synthetic speaker data with speaker-dependent patterns.\"\"\"\n    keys = jr.split(key, 3)\n    all_features = []\n    all_labels = []\n\n    # Each speaker has a characteristic spectral pattern\n    speaker_patterns = jr.normal(keys[0], (n_speakers, n_features)) * 0.5\n\n    for spk in range(n_speakers):\n        for utt in range(utterances_per_speaker):\n            k = jr.fold_in(keys[1], spk * utterances_per_speaker + utt)\n            noise = jr.normal(k, (n_frames, n_features)) * 0.3\n            features = speaker_patterns[spk][None, :] + noise\n            all_features.append(features)\n            all_labels.append(spk)\n\n    perm = jr.permutation(keys[2], len(all_features))\n    features = jnp.stack(all_features)[perm]\n    labels = jnp.array(all_labels)[perm]\n    return features, labels\n\nkey = jr.PRNGKey(42)\nfeatures, labels = generate_speaker_data(key)\nn_speakers = 5\nn_features = 40\n\n# x-vector-style model\ndef init_xvector(key, n_features=40, hidden=128, embed_dim=64, n_speakers=5):\n    keys = jr.split(key, 8)\n    params = {\n        # TDNN layer 1: context [-2, 2]\n        'tdnn1_w': jr.normal(keys[0], (5, n_features, hidden)) * jnp.sqrt(2.0 / (5 * n_features)),\n        'tdnn1_b': jnp.zeros(hidden),\n        # TDNN layer 2: context [-2, 2]\n        'tdnn2_w': jr.normal(keys[1], (5, hidden, hidden)) * jnp.sqrt(2.0 / (5 * hidden)),\n        'tdnn2_b': jnp.zeros(hidden),\n        # TDNN layer 3: context [-3, 3]\n        'tdnn3_w': jr.normal(keys[2], (7, hidden, hidden)) * jnp.sqrt(2.0 / (7 * hidden)),\n        'tdnn3_b': jnp.zeros(hidden),\n        # Segment-level layers (after pooling: 2*hidden -> embed_dim)\n        'seg1_w': 
jr.normal(keys[3], (2 * hidden, embed_dim)) * jnp.sqrt(2.0 / (2 * hidden)),\n        'seg1_b': jnp.zeros(embed_dim),\n        # Classification head\n        'cls_w': jr.normal(keys[4], (embed_dim, n_speakers)) * jnp.sqrt(2.0 / embed_dim),\n        'cls_b': jnp.zeros(n_speakers),\n    }\n    return params\n\ndef xvector_forward(params, x, return_embedding=False):\n    \"\"\"x: (batch, frames, features) -> logits or embeddings.\"\"\"\n    # TDNN layers (1D convolutions)\n    h = jax.lax.conv_general_dilated(\n        x.transpose(0, 2, 1), params['tdnn1_w'].transpose(2, 1, 0),\n        window_strides=(1,), padding='SAME'\n    ).transpose(0, 2, 1) + params['tdnn1_b']\n    h = jax.nn.relu(h)\n\n    h = jax.lax.conv_general_dilated(\n        h.transpose(0, 2, 1), params['tdnn2_w'].transpose(2, 1, 0),\n        window_strides=(1,), padding='SAME'\n    ).transpose(0, 2, 1) + params['tdnn2_b']\n    h = jax.nn.relu(h)\n\n    h = jax.lax.conv_general_dilated(\n        h.transpose(0, 2, 1), params['tdnn3_w'].transpose(2, 1, 0),\n        window_strides=(1,), padding='SAME'\n    ).transpose(0, 2, 1) + params['tdnn3_b']\n    h = jax.nn.relu(h)\n\n    # Statistics pooling: mean and std over time\n    mu = jnp.mean(h, axis=1)\n    sigma = jnp.std(h, axis=1)\n    pooled = jnp.concatenate([mu, sigma], axis=-1)\n\n    # Segment-level layer -> embedding\n    embedding = jax.nn.relu(pooled @ params['seg1_w'] + params['seg1_b'])\n\n    if return_embedding:\n        return embedding\n\n    # Classification\n    logits = embedding @ params['cls_w'] + params['cls_b']\n    return logits\n\ndef cross_entropy_loss(params, features, labels):\n    logits = xvector_forward(params, features)\n    one_hot = jax.nn.one_hot(labels, n_speakers)\n    log_probs = jax.nn.log_softmax(logits)\n    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))\n\ngrad_fn = jax.jit(jax.value_and_grad(cross_entropy_loss))\n\n# Train\nparams = init_xvector(jr.PRNGKey(0))\nlr = 1e-3\nlosses = []\n\nfor epoch in 
range(300):\n    loss_val, grads = grad_fn(params, features, labels)\n    params = jax.tree.map(lambda p, g: p - lr * g, params, grads)\n    losses.append(float(loss_val))\n\n# Extract embeddings and visualise with t-SNE-style 2D projection (using PCA)\nembeddings = xvector_forward(params, features, return_embedding=True)\n\n# Simple PCA to 2D\nemb_centered = embeddings - jnp.mean(embeddings, axis=0)\n_, _, Vt = jnp.linalg.svd(emb_centered, full_matrices=False)\nproj_2d = emb_centered @ Vt[:2].T\n\nfig, axes = plt.subplots(1, 2, figsize=(14, 5))\n\naxes[0].plot(losses, color='#3498db', linewidth=1.5)\naxes[0].set_xlabel('Epoch')\naxes[0].set_ylabel('Cross-Entropy Loss')\naxes[0].set_title('Speaker Classification Training')\naxes[0].set_yscale('log')\n\ncolors = ['#3498db', '#e74c3c', '#27ae60', '#f39c12', '#9b59b6']\nfor spk in range(n_speakers):\n    mask = labels == spk\n    axes[1].scatter(proj_2d[mask, 0], proj_2d[mask, 1], c=colors[spk],\n                    label=f'Speaker {spk}', alpha=0.7, s=30)\naxes[1].set_xlabel('PC 1')\naxes[1].set_ylabel('PC 2')\naxes[1].set_title('Speaker Embeddings (PCA projection)')\naxes[1].legend()\n\nplt.tight_layout()\nplt.show()\n\n# Verification demo: cosine similarity\nemb_norm = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)\nsim_matrix = emb_norm @ emb_norm.T\nprint(f\"Embedding shape: {embeddings.shape}\")\nprint(f\"Avg same-speaker similarity: {jnp.mean(sim_matrix[labels[:, None] == labels[None, :]]):.4f}\")\nprint(f\"Avg diff-speaker similarity: {jnp.mean(sim_matrix[labels[:, None] != labels[None, :]]):.4f}\")\n
import jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport matplotlib.pyplot as plt\n\ndef generate_verification_pairs(key, n_speakers=20, dim=64, n_pairs=2000):\n    \"\"\"Generate speaker embeddings and verification trial pairs.\"\"\"\n    keys = jr.split(key, 5)\n\n    # Speaker centroids with some variance\n    centroids = jr.normal(keys[0], (n_speakers, dim))\n    centroids = centroids / jnp.linalg.norm(centroids, axis=-1, keepdims=True)\n\n    # Generate enrollment and test embeddings with intra-speaker variance\n    enroll_embs = []\n    test_embs = []\n    trial_labels = []  # 1 = same speaker (target), 0 = different (impostor)\n\n    for i in range(n_pairs):\n        k1, k2, k3 = jr.split(jr.fold_in(keys[1], i), 3)\n        is_target = jr.bernoulli(k1).astype(int)\n\n        spk1 = jr.randint(k2, (), 0, n_speakers)\n        emb1 = centroids[spk1] + jr.normal(jr.fold_in(k3, 0), (dim,)) * 0.15\n\n        if is_target:\n            spk2 = spk1\n        else:\n            spk2 = (spk1 + jr.randint(jr.fold_in(k3, 1), (), 1, n_speakers)) % n_speakers\n\n        emb2 = centroids[spk2] + jr.normal(jr.fold_in(k3, 2), (dim,)) * 0.15\n\n        enroll_embs.append(emb1)\n        test_embs.append(emb2)\n        trial_labels.append(int(is_target))\n\n    return (jnp.stack(enroll_embs), jnp.stack(test_embs),\n            jnp.array(trial_labels))\n\nkey = jr.PRNGKey(42)\nenroll, test, labels = generate_verification_pairs(key)\n\n# Compute cosine similarity scores\nenroll_norm = enroll / jnp.linalg.norm(enroll, axis=-1, keepdims=True)\ntest_norm = test / jnp.linalg.norm(test, axis=-1, keepdims=True)\nscores = jnp.sum(enroll_norm * test_norm, axis=-1)\n\n# Compute FAR and FRR at various thresholds\nthresholds = jnp.linspace(-1.0, 1.0, 500)\n\ntarget_scores = scores[labels == 1]\nimpostor_scores = scores[labels == 0]\n\nfars = []\nfrrs = []\nfor thresh in thresholds:\n    far = jnp.mean(impostor_scores >= thresh)  # false accepts\n    frr = jnp.mean(target_scores < 
thresh)     # false rejects\n    fars.append(float(far))\n    frrs.append(float(frr))\n\nfars = jnp.array(fars)\nfrrs = jnp.array(frrs)\n\n# Find EER: where FAR \u2248 FRR\neer_idx = jnp.argmin(jnp.abs(fars - frrs))\neer = float((fars[eer_idx] + frrs[eer_idx]) / 2)\neer_threshold = float(thresholds[eer_idx])\n\nprint(f\"Equal Error Rate (EER): {eer:.4f} ({eer*100:.2f}%)\")\nprint(f\"EER threshold: {eer_threshold:.4f}\")\n\nfig, axes = plt.subplots(1, 3, figsize=(18, 5))\n\n# Score distributions\nbins = jnp.linspace(-0.5, 1.0, 60)\naxes[0].hist(target_scores, bins=bins, alpha=0.6, color='#27ae60',\n             label='Target (same speaker)', density=True)\naxes[0].hist(impostor_scores, bins=bins, alpha=0.6, color='#e74c3c',\n             label='Impostor (different speaker)', density=True)\naxes[0].axvline(eer_threshold, color='#f39c12', linestyle='--', linewidth=2,\n                label=f'EER threshold = {eer_threshold:.3f}')\naxes[0].set_xlabel('Cosine Similarity Score')\naxes[0].set_ylabel('Density')\naxes[0].set_title('Score Distributions')\naxes[0].legend()\n\n# FAR vs FRR\naxes[1].plot(thresholds, fars, color='#e74c3c', linewidth=2, label='FAR')\naxes[1].plot(thresholds, frrs, color='#3498db', linewidth=2, label='FRR')\naxes[1].axvline(eer_threshold, color='#f39c12', linestyle='--', linewidth=1.5)\naxes[1].scatter([eer_threshold], [eer], color='#f39c12', s=100, zorder=5,\n                label=f'EER = {eer:.4f}')\naxes[1].set_xlabel('Threshold')\naxes[1].set_ylabel('Error Rate')\naxes[1].set_title('FAR and FRR vs Threshold')\naxes[1].legend()\n\n# DET curve (FAR vs FRR)\naxes[2].plot(fars, frrs, color='#9b59b6', linewidth=2)\naxes[2].plot([0, 1], [0, 1], 'k--', alpha=0.3)\naxes[2].scatter([eer], [eer], color='#f39c12', s=100, zorder=5,\n                label=f'EER = {eer:.4f}')\naxes[2].set_xlabel('False Acceptance Rate')\naxes[2].set_ylabel('False Rejection Rate')\naxes[2].set_title('DET Curve')\naxes[2].set_xlim([0, 0.5])\naxes[2].set_ylim([0, 
0.5])\naxes[2].legend()\naxes[2].set_aspect('equal')\n\nplt.tight_layout()\nplt.show()\n
import jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport matplotlib.pyplot as plt\n\n# Generate a synthetic spectrogram (harmonic structure + noise)\ndef generate_spectrogram(key, n_time=128, n_freq=128):\n    \"\"\"Create a synthetic spectrogram with harmonic patterns.\"\"\"\n    k1, k2 = jr.split(key)\n    spec = jr.normal(k1, (n_time, n_freq)) * 0.1\n\n    # Add harmonic bands (simulating speech formants)\n    for f0 in [15, 30, 45, 70]:\n        width = 3\n        envelope = jnp.exp(-0.5 * ((jnp.arange(n_freq) - f0) / width) ** 2)\n        time_mod = 0.5 + 0.5 * jnp.sin(2 * jnp.pi * jnp.arange(n_time) / 40)\n        spec += jnp.outer(time_mod, envelope)\n\n    return jnp.clip(spec, 0, None)\n\nkey = jr.PRNGKey(42)\nspectrogram = generate_spectrogram(key)\nn_time, n_freq = spectrogram.shape\n\n# Patch extraction parameters\npatch_h = 16  # time\npatch_w = 16  # frequency\nstride_h = 16\nstride_w = 16\nembed_dim = 192  # ViT-Small dimension\n\nn_patches_h = n_time // stride_h\nn_patches_w = n_freq // stride_w\nn_patches = n_patches_h * n_patches_w\n\nprint(f\"Spectrogram: {n_time} x {n_freq}\")\nprint(f\"Patch size: {patch_h} x {patch_w}\")\nprint(f\"Number of patches: {n_patches_h} x {n_patches_w} = {n_patches}\")\n\n# Extract patches\ndef extract_patches(spec, patch_h, patch_w, stride_h, stride_w):\n    \"\"\"Extract non-overlapping patches from spectrogram.\"\"\"\n    patches = []\n    positions = []\n    for i in range(0, spec.shape[0] - patch_h + 1, stride_h):\n        for j in range(0, spec.shape[1] - patch_w + 1, stride_w):\n            patch = spec[i:i+patch_h, j:j+patch_w]\n            patches.append(patch.flatten())\n            positions.append((i, j))\n    return jnp.stack(patches), positions\n\npatches, positions = extract_patches(spectrogram, patch_h, patch_w, stride_h, stride_w)\nprint(f\"Patches shape: {patches.shape}\")  # (n_patches, patch_h * patch_w)\n\n# Linear projection (patch embedding)\npatch_dim = patch_h * patch_w\nk1, k2 = 
jr.split(jr.PRNGKey(0))\nW_embed = jr.normal(k1, (patch_dim, embed_dim)) * jnp.sqrt(2.0 / patch_dim)\nb_embed = jnp.zeros(embed_dim)\n\n# Learnable positional embeddings\npos_embed = jr.normal(k2, (n_patches + 1, embed_dim)) * 0.02  # +1 for CLS\n\n# CLS token\ncls_token = jnp.zeros((1, embed_dim))\n\n# Forward pass\npatch_tokens = patches @ W_embed + b_embed  # (n_patches, embed_dim)\ntokens = jnp.concatenate([cls_token, patch_tokens], axis=0)  # (n_patches+1, embed_dim)\ntokens = tokens + pos_embed  # Add positional embeddings\n\nprint(f\"Token sequence shape: {tokens.shape}\")\nprint(f\"Each token has dimension: {embed_dim}\")\n\n# Visualisation\nfig, axes = plt.subplots(2, 2, figsize=(14, 10))\n\n# Original spectrogram with patch grid\naxes[0, 0].imshow(spectrogram.T, aspect='auto', origin='lower', cmap='magma')\nfor i in range(0, n_time + 1, stride_h):\n    axes[0, 0].axvline(i - 0.5, color='white', linewidth=0.5, alpha=0.5)\nfor j in range(0, n_freq + 1, stride_w):\n    axes[0, 0].axhline(j - 0.5, color='white', linewidth=0.5, alpha=0.5)\naxes[0, 0].set_title(f'Spectrogram with {patch_h}x{patch_w} Patch Grid')\naxes[0, 0].set_xlabel('Time frame')\naxes[0, 0].set_ylabel('Frequency bin')\n\n# Individual patches visualised\nn_show = min(16, n_patches)\npatch_grid = patches[:n_show].reshape(n_show, patch_h, patch_w)\ncombined = jnp.concatenate([patch_grid[i] for i in range(min(8, n_show))], axis=1)\naxes[0, 1].imshow(combined.T, aspect='auto', origin='lower', cmap='magma')\naxes[0, 1].set_title(f'First {min(8, n_show)} Patches (concatenated)')\naxes[0, 1].set_xlabel('Patch index (horizontal)')\naxes[0, 1].set_ylabel('Frequency within patch')\n\n# Token embeddings similarity matrix\ntoken_norms = tokens / jnp.linalg.norm(tokens, axis=-1, keepdims=True)\nsim = token_norms @ token_norms.T\nim = axes[1, 0].imshow(sim, cmap='RdBu_r', vmin=-1, vmax=1)\naxes[1, 0].set_title('Token Similarity Matrix (cosine)')\naxes[1, 0].set_xlabel('Token index')\naxes[1, 
0].set_ylabel('Token index')\nplt.colorbar(im, ax=axes[1, 0], fraction=0.046)\n\n# Positional embedding similarity\npos_norms = pos_embed / jnp.linalg.norm(pos_embed, axis=-1, keepdims=True)\npos_sim = pos_norms @ pos_norms.T\nim2 = axes[1, 1].imshow(pos_sim, cmap='RdBu_r', vmin=-1, vmax=1)\naxes[1, 1].set_title('Positional Embedding Similarity')\naxes[1, 1].set_xlabel('Position index')\naxes[1, 1].set_ylabel('Position index')\nplt.colorbar(im2, ax=axes[1, 1], fraction=0.046)\n\nplt.tight_layout()\nplt.show()\n
import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# Generate a synthetic musical signal: C major chord -> G major chord\nsr = 16000\nduration = 2.0\nt = jnp.linspace(0, duration, int(sr * duration))\n\n# C major (C4=261.6, E4=329.6, G4=392.0) for first half\n# G major (G3=196.0, B3=246.9, D4=293.7) for second half\nhalf = len(t) // 2\n\nc_major = (0.5 * jnp.sin(2 * jnp.pi * 261.63 * t[:half]) +\n           0.4 * jnp.sin(2 * jnp.pi * 329.63 * t[:half]) +\n           0.3 * jnp.sin(2 * jnp.pi * 392.00 * t[:half]))\n\ng_major = (0.5 * jnp.sin(2 * jnp.pi * 196.00 * t[:half]) +\n           0.4 * jnp.sin(2 * jnp.pi * 246.94 * t[:half]) +\n           0.3 * jnp.sin(2 * jnp.pi * 293.66 * t[:half]))\n\nsignal = jnp.concatenate([c_major, g_major])\n\n# Compute STFT\nn_fft = 4096  # high resolution for pitch accuracy\nhop_length = 512\nwindow = jnp.hanning(n_fft)\n\ndef stft(signal, n_fft, hop_length, window):\n    n_frames = 1 + (len(signal) - n_fft) // hop_length\n    frames = jnp.stack([\n        signal[i * hop_length : i * hop_length + n_fft] * window\n        for i in range(n_frames)\n    ])\n    return jnp.fft.rfft(frames, n=n_fft)\n\nS = stft(signal, n_fft, hop_length, window)\npower_spec = jnp.abs(S) ** 2\nfreqs = jnp.fft.rfftfreq(n_fft, 1.0 / sr)\n\n# Compute chromagram by mapping frequency bins to pitch classes\n# MIDI note number from frequency: 69 + 12 * log2(f / 440)\nnote_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']\n\ndef freq_to_chroma(freq):\n    \"\"\"Map frequency to pitch class (0-11). 
Returns -1 for freq <= 0.\"\"\"\n    midi = 69 + 12 * jnp.log2(jnp.clip(freq, 1e-10, None) / 440.0)\n    return jnp.round(midi).astype(int) % 12\n\n# Build chromagram: sum power spectrum energy for each pitch class\nchromagram = jnp.zeros((power_spec.shape[0], 12))\nvalid_freqs = freqs[1:]  # skip DC\nvalid_power = power_spec[:, 1:]\n\nfor p in range(12):\n    # Find frequency bins belonging to this pitch class\n    chroma_bins = freq_to_chroma(valid_freqs)\n    mask = (chroma_bins == p).astype(jnp.float32)\n    chromagram = chromagram.at[:, p].set(\n        jnp.sum(valid_power * mask[None, :], axis=1)\n    )\n\n# Normalise each frame\nchromagram = chromagram / (jnp.max(chromagram, axis=1, keepdims=True) + 1e-8)\n\n# Visualisation\nfig, axes = plt.subplots(3, 1, figsize=(14, 10))\n\n# Waveform\naxes[0].plot(t[:3000], signal[:3000], color='#3498db', linewidth=0.5,\n             label='C major')\naxes[0].plot(t[half:half+3000], signal[half:half+3000], color='#e74c3c',\n             linewidth=0.5, label='G major')\naxes[0].set_title('Waveform: C major \u2192 G major')\naxes[0].set_ylabel('Amplitude')\naxes[0].set_xlabel('Time (s)')\naxes[0].legend()\n\n# Spectrogram (log scale)\ntime_axis = jnp.arange(power_spec.shape[0]) * hop_length / sr\naxes[1].imshow(jnp.log1p(power_spec[:, :500].T), aspect='auto', origin='lower',\n               cmap='magma', extent=[0, time_axis[-1], 0, freqs[500]])\naxes[1].set_title('Power Spectrogram')\naxes[1].set_ylabel('Frequency (Hz)')\naxes[1].set_xlabel('Time (s)')\n\n# Chromagram\nim = axes[2].imshow(chromagram.T, aspect='auto', origin='lower', cmap='YlOrRd',\n                     extent=[0, time_axis[-1], -0.5, 11.5])\naxes[2].set_yticks(range(12))\naxes[2].set_yticklabels(note_names)\naxes[2].set_title('Chromagram (pitch class energy over time)')\naxes[2].set_ylabel('Pitch class')\naxes[2].set_xlabel('Time (s)')\nplt.colorbar(im, ax=axes[2], fraction=0.046, label='Normalised energy')\n\n# Mark expected active pitch 
classes\nmid_frame = chromagram.shape[0] // 2\nprint(f\"C major region - expected: C, E, G\")\nprint(f\"  Chroma values: {dict(zip(note_names, [f'{v:.2f}' for v in chromagram[mid_frame//2]]))}\")\nprint(f\"G major region - expected: G, B, D\")\nprint(f\"  Chroma values: {dict(zip(note_names, [f'{v:.2f}' for v in chromagram[mid_frame + mid_frame//2]]))}\")\n\nplt.tight_layout()\nplt.show()\n
"},{"location":"chapter%2009%3A%20audio%20and%20speech/05.%20source%20separation%20and%20noise/","title":"\u6e90\u5206\u79bb\u4e0e\u964d\u566a","text":"

\u6e90\u5206\u79bb\u4e0e\u964d\u566a\u65e8\u5728\u4ece\u6df7\u5408\u97f3\u9891\u4e2d\u6062\u590d\u5404\u4e2a\u4fe1\u53f7\uff1b\u5373\u8ba1\u7b97\u5c42\u9762\u7684\"\u9e21\u5c3e\u9152\u4f1a\u95ee\u9898\"\u3002\u672c\u6587\u6db5\u76d6ICA\u3001NMF\u3001\u65f6\u9891\u63a9\u853d\u3001\u6ce2\u675f\u6210\u5f62\u3001\u6df1\u5ea6\u5b66\u4e60\u5206\u79bb\u7f51\u7edc\uff08Conv-TasNet\u3001SepFormer\uff09\u3001\u8bed\u97f3\u589e\u5f3a\u4ee5\u53ca\u81ea\u9002\u5e94\u964d\u566a\u3002

\\[x(t) = \\sum_{c=1}^{C} s_c(t) + n(t)\\] \\[X(t, f) = \\sum_{c=1}^{C} S_c(t, f) + N(t, f)\\] \\[\\text{IRM}_c(t, f) = \\frac{|S_c(t, f)|^2}{\\sum_{j=1}^{C} |S_j(t, f)|^2}\\] \\[V \\approx WH\\] \\[ \\begin{aligned} \\text{Frobenius:} \\quad D_F(V \\| WH) &= \\|V - WH\\|_F^2 \\\\ \\text{KL:} \\quad D_{KL}(V \\| WH) &= \\sum_{f,t} \\left[ V_{ft} \\log \\frac{V_{ft}}{(WH)_{ft}} - V_{ft} + (WH)_{ft} \\right] \\end{aligned} \\]

\\[y(t) = \\frac{1}{M} \\sum_{m=1}^{M} x_m(t - \\tau_m(\\theta))\\] \\[ \\begin{aligned} \\min_{\\mathbf{w}} \\quad & \\mathbf{w}^H \\Phi_{nn} \\mathbf{w} \\\\ \\text{subject to} \\quad & \\mathbf{w}^H \\mathbf{d}(\\theta) = 1 \\end{aligned} \\] \\[\\mathbf{w}_{\\text{MVDR}} = \\frac{\\Phi_{nn}^{-1} \\mathbf{d}(\\theta)}{\\mathbf{d}(\\theta)^H \\Phi_{nn}^{-1} \\mathbf{d}(\\theta)}\\] \\[\\mathcal{L} = \\|VV^T - YY^T\\|_F^2\\]

\\[ \\begin{aligned} \\text{\u5757\u5185:} \\quad & h_{k,n}^{\\text{\u5757\u5185}} = \\text{BiLSTM}_{\\text{\u5757\u5185}}(z_{k,n}) \\\\ \\text{\u5757\u95f4:} \\quad & h_{k,n}^{\\text{\u5757\u95f4}} = \\text{BiLSTM}_{\\text{\u5757\u95f4}}(h_{k,n}^{\\text{\u5757\u5185}}) \\end{aligned} \\] \\[\\mathcal{L}_{\\text{PIT}} = \\min_{\\pi \\in \\mathcal{P}} \\sum_{c=1}^{C} \\ell(\\hat{s}_{\\pi(c)}, s_c)\\] \\[ \\begin{aligned} s_{\\text{target}} &= \\frac{\\langle \\hat{s}, s \\rangle}{\\|s\\|^2} s \\\\ e_{\\text{noise}} &= \\hat{s} - s_{\\text{target}} \\\\ \\text{SI-SDR} &= 10 \\log_{10} \\frac{\\|s_{\\text{target}}\\|^2}{\\|e_{\\text{noise}}\\|^2} \\end{aligned} \\]

\\[\\mathbf{w}(n+1) = \\mathbf{w}(n) + \\mu \\, e(n) \\, \\mathbf{x}(n)\\] \\[\\mathbf{w}(n+1) = \\mathbf{w}(n) + \\frac{\\mu}{\\|\\mathbf{x}(n)\\|^2 + \\epsilon} \\, e(n) \\, \\mathbf{x}(n)\\] \\[|\\hat{S}(f)|^2 = \\max(|X(f)|^2 - \\alpha |\\hat{N}(f)|^2, \\beta |X(f)|^2)\\] \\[\\hat{S}(t, f) = \\frac{|S(t,f)|^2}{|S(t,f)|^2 + |N(t,f)|^2} \\cdot X(t, f) = G(t, f) \\cdot X(t, f)\\] \\[\\xi(n) = \\frac{|\\sum_{k=0}^{L-1} x(n-k) d(n-k)|}{\\sqrt{\\sum_{k=0}^{L-1} x^2(n-k)} \\sqrt{\\sum_{k=0}^{L-1} d^2(n-k)}}\\] "},{"location":"chapter%2009%3A%20audio%20and%20speech/05.%20source%20separation%20and%20noise/#colab-notebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528 Colab \u6216\u7b14\u8bb0\u672c\uff09","text":"
import jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport matplotlib.pyplot as plt\n\n# \u751f\u6210\u4e24\u4e2a\u6e90\u4fe1\u53f7\nsr = 8000\nduration = 1.0\nt = jnp.linspace(0, duration, int(sr * duration))\n\n# \u6e90 1\uff1a\u6b63\u5f26\u6ce2\uff08\u7c7b\u4f3c\u97f3\u8c03\uff09\ns1 = jnp.sin(2 * jnp.pi * 440 * t) + 0.3 * jnp.sin(2 * jnp.pi * 880 * t)\n\n# \u6e90 2\uff1a\u952f\u9f7f\u6ce2\uff08\u4e30\u5bcc\u7684\u8c10\u6ce2\uff09\ns2 = 2 * (t * 200 % 1) - 1  # 200 Hz \u952f\u9f7f\u6ce2\n\n# \u5f52\u4e00\u5316\u6e90\u4fe1\u53f7\ns1 = s1 / jnp.max(jnp.abs(s1))\ns2 = s2 / jnp.max(jnp.abs(s2))\nsources = jnp.stack([s1, s2])  # (2, T)\n\n# \u6df7\u53e0\u77e9\u9635\uff08\u7b97\u6cd5\u672a\u77e5\uff09\nA = jnp.array([[0.8, 0.4],\n               [0.3, 0.9]])\nmixtures = A @ sources  # (2, T)\n\n# FastICA \u5b9e\u73b0\ndef whiten(X):\n    \"\"\"\u6570\u636e\u4e2d\u5fc3\u5316\u4e0e\u767d\u5316\u3002\"\"\"\n    X_centered = X - jnp.mean(X, axis=1, keepdims=True)\n    cov = (X_centered @ X_centered.T) / X_centered.shape[1]\n    eigvals, eigvecs = jnp.linalg.eigh(cov)\n    D_inv_sqrt = jnp.diag(1.0 / jnp.sqrt(eigvals + 1e-8))\n    whitening = D_inv_sqrt @ eigvecs.T\n    return whitening @ X_centered, whitening\n\ndef fastica(X, n_components=2, max_iter=200, tol=1e-6):\n    \"\"\"\u4f7f\u7528 tanh \u975e\u7ebf\u6027\u7684 FastICA\uff08\u8d1f\u71b5\u8fd1\u4f3c\uff09\u3002\"\"\"\n    X_white, whitening = whiten(X)\n    n, T = X_white.shape\n\n    key = jr.PRNGKey(42)\n    W = jr.normal(key, (n_components, n))\n    # \u6b63\u4ea4\u5316 W\n    U, _, Vt = jnp.linalg.svd(W, full_matrices=False)\n    W = U @ Vt\n\n    for iteration in range(max_iter):\n        W_old = W.copy()\n\n        # \u5bf9\u6bcf\u4e2a\u5206\u91cf\n        for i in range(n_components):\n            w = W[i]\n            # w^T X_white: (T,)\n            wx = w @ X_white  # (T,)\n\n            # g(u) = tanh(u), g'(u) = 1 - tanh^2(u)\n            g_wx = jnp.tanh(wx)\n            g_prime_wx = 1 - g_wx 
** 2\n\n            # Newton \u66f4\u65b0: w_new = E[X * g(w^T X)] - E[g'(w^T X)] * w\n            w_new = jnp.mean(X_white * g_wx[None, :], axis=1) - \\\n                    jnp.mean(g_prime_wx) * w\n\n            # \u4e0e\u4e4b\u524d\u7684\u5206\u91cf\u53bb\u76f8\u5173\uff08\u6d88\u53bb\u6cd5\uff09\n            for j in range(i):\n                w_new = w_new - jnp.dot(w_new, W[j]) * W[j]\n\n            w_new = w_new / jnp.linalg.norm(w_new)\n            W = W.at[i].set(w_new)\n\n        # \u68c0\u67e5\u6536\u655b\n        convergence = jnp.min(jnp.abs(jnp.diag(W @ W_old.T)))\n        if convergence > 1 - tol:\n            print(f\"FastICA \u5728 {iteration + 1} \u6b21\u8fed\u4ee3\u540e\u6536\u655b\")\n            break\n\n    # \u89e3\u6df7\u77e9\u9635\n    unmixing = W @ whitening\n    recovered = unmixing @ X\n    return recovered, unmixing\n\nrecovered, W_unmix = fastica(mixtures)\n\n# \u4fee\u590d\u7b26\u53f7\u6b67\u4e49\uff08ICA \u53ef\u80fd\u7ffb\u8f6c\u7b26\u53f7\uff09\nfor i in range(2):\n    if jnp.corrcoef(recovered[i], sources[i])[0, 1] < -0.5:\n        recovered = recovered.at[i].set(-recovered[i])\n\n# \u5982\u679c\u6e90\u88ab\u4ea4\u6362\uff0c\u4fee\u590d\u6392\u5217\ncorr_00 = jnp.abs(jnp.corrcoef(recovered[0], sources[0])[0, 1])\ncorr_01 = jnp.abs(jnp.corrcoef(recovered[0], sources[1])[0, 1])\nif corr_01 > corr_00:\n    recovered = recovered[::-1]\n\n# \u5f52\u4e00\u5316\u4ee5\u4fbf\u663e\u793a\nrecovered = recovered / jnp.max(jnp.abs(recovered), axis=1, keepdims=True)\n\nfig, axes = plt.subplots(3, 2, figsize=(14, 9))\n\naxes[0, 0].plot(t[:1000], s1[:1000], color='#3498db', linewidth=0.8)\naxes[0, 0].set_title('\u6e90\u4fe1\u53f7 1\uff08\u539f\u59cb\uff09')\naxes[0, 0].set_ylabel('\u5e45\u5ea6')\n\naxes[0, 1].plot(t[:1000], s2[:1000], color='#e74c3c', linewidth=0.8)\naxes[0, 1].set_title('\u6e90\u4fe1\u53f7 2\uff08\u539f\u59cb\uff09')\n\naxes[1, 0].plot(t[:1000], mixtures[0, :1000], color='#9b59b6', linewidth=0.8)\naxes[1, 
0].set_title('\u6df7\u5408\u4fe1\u53f7 1\uff08\u9ea6\u514b\u98ce 1\uff09')\naxes[1, 0].set_ylabel('\u5e45\u5ea6')\n\naxes[1, 1].plot(t[:1000], mixtures[1, :1000], color='#9b59b6', linewidth=0.8)\naxes[1, 1].set_title('\u6df7\u5408\u4fe1\u53f7 2\uff08\u9ea6\u514b\u98ce 2\uff09')\n\naxes[2, 0].plot(t[:1000], recovered[0, :1000], color='#27ae60', linewidth=0.8)\naxes[2, 0].set_title('\u6062\u590d\u7684\u6e90\u4fe1\u53f7 1\uff08FastICA\uff09')\naxes[2, 0].set_ylabel('\u5e45\u5ea6')\naxes[2, 0].set_xlabel('\u65f6\u95f4 (s)')\n\naxes[2, 1].plot(t[:1000], recovered[1, :1000], color='#f39c12', linewidth=0.8)\naxes[2, 1].set_title('\u6062\u590d\u7684\u6e90\u4fe1\u53f7 2\uff08FastICA\uff09')\naxes[2, 1].set_xlabel('\u65f6\u95f4 (s)')\n\nplt.tight_layout()\nplt.show()\n\n# \u62a5\u544a\u4e0e\u539f\u59cb\u4fe1\u53f7\u7684\u76f8\u5173\u6027\nfor i in range(2):\n    corr = jnp.corrcoef(recovered[i], sources[i])[0, 1]\n    print(f\"\u6e90 {i+1} \u6062\u590d\u76f8\u5173\u6027: {corr:.4f}\")\n
import jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport matplotlib.pyplot as plt\n\n# \u751f\u6210\u4e24\u4e2a\u5177\u6709\u4e0d\u540c\u9891\u8c31\u7279\u5f81\u7684\u4fe1\u53f7\nsr = 8000\nduration = 1.0\nt = jnp.linspace(0, duration, int(sr * duration))\n\n# \u6e90 1\uff1a\u4f4e\u9891\u8c10\u6ce2\uff08\u6a21\u62df\u8d1d\u65af\uff09\nsrc1 = (jnp.sin(2 * jnp.pi * 100 * t) +\n        0.5 * jnp.sin(2 * jnp.pi * 200 * t) +\n        0.3 * jnp.sin(2 * jnp.pi * 300 * t))\n\n# \u6e90 2\uff1a\u9ad8\u9891\u8c10\u6ce2\uff08\u6a21\u62df\u957f\u7b1b\uff09\nsrc2 = (jnp.sin(2 * jnp.pi * 800 * t) +\n        0.4 * jnp.sin(2 * jnp.pi * 1600 * t))\n\n# \u65f6\u53d8\u5e45\u5ea6\uff08\u6e90\u5728\u4e0d\u540c\u65f6\u95f4\u6fc0\u6d3b\uff09\nenv1 = jnp.where(t < 0.5, 1.0, 0.3)\nenv2 = jnp.where(t > 0.3, 1.0, 0.2)\nsrc1 = src1 * env1\nsrc2 = src2 * env2\n\nmixture = src1 + src2\n\n# \u8ba1\u7b97\u5e45\u5ea6\u8bed\u8c31\u56fe\uff08STFT\uff09\nn_fft = 512\nhop = 128\nwindow = jnp.hanning(n_fft)\n\ndef compute_stft(signal, n_fft, hop, window):\n    n_frames = 1 + (len(signal) - n_fft) // hop\n    frames = jnp.stack([\n        signal[i * hop : i * hop + n_fft] * window\n        for i in range(n_frames)\n    ])\n    return jnp.fft.rfft(frames, n=n_fft)\n\nS_mix = compute_stft(mixture, n_fft, hop, window)\nV = jnp.abs(S_mix).T  # (F, T) - \u9891\u7387 x \u65f6\u95f4\nphase = jnp.angle(S_mix).T\n\nF, T = V.shape\nprint(f\"\u8bed\u8c31\u56fe\u5f62\u72b6: {F} \u4e2a\u9891\u7387 bin x {T} \u4e2a\u65f6\u95f4\u5e27\")\n\n# NMF: V \u2248 WH \u4f7f\u7528\u4e58\u6cd5\u66f4\u65b0\u89c4\u5219\ndef nmf(V, K, n_iter=200, key=jr.PRNGKey(0)):\n    \"\"\"\u4f7f\u7528 Frobenius \u8303\u6570\u7684\u975e\u8d1f\u77e9\u9635\u5206\u89e3\u3002\"\"\"\n    k1, k2 = jr.split(key)\n    W = jnp.abs(jr.normal(k1, (F, K))) * 0.1 + 0.01  # (F, K)\n    H = jnp.abs(jr.normal(k2, (K, T))) * 0.1 + 0.01  # (K, T)\n\n    costs = []\n    for i in range(n_iter):\n        # H \u7684\u4e58\u6cd5\u66f4\u65b0\n        WtV = 
W.T @ V\n        WtWH = W.T @ W @ H + 1e-8\n        H = H * (WtV / WtWH)\n\n        # W \u7684\u4e58\u6cd5\u66f4\u65b0\n        VHt = V @ H.T\n        WHHt = W @ H @ H.T + 1e-8\n        W = W * (VHt / WHHt)\n\n        cost = jnp.sum((V - W @ H) ** 2)\n        costs.append(float(cost))\n\n    return W, H, costs\n\n# \u8fd0\u884c K=2 \u4e2a\u5206\u91cf\u7684 NMF\nK = 2\nW, H, costs = nmf(V, K, n_iter=300)\n\n# \u4f7f\u7528\u8f6f\u63a9\u853d\u91cd\u5efa\u6bcf\u4e2a\u6e90\nV_hat = W @ H\nmask1 = (W[:, 0:1] @ H[0:1, :]) / (V_hat + 1e-8)\nmask2 = (W[:, 1:2] @ H[1:2, :]) / (V_hat + 1e-8)\n\nV_src1 = mask1 * V\nV_src2 = mask2 * V\n\n# \u53ef\u89c6\u5316\nfig, axes = plt.subplots(3, 2, figsize=(14, 10))\n\n# \u6df7\u5408\u4fe1\u53f7\u8bed\u8c31\u56fe\naxes[0, 0].imshow(jnp.log1p(V), aspect='auto', origin='lower', cmap='magma')\naxes[0, 0].set_title('\u6df7\u5408\u4fe1\u53f7\u8bed\u8c31\u56fe |X|')\naxes[0, 0].set_ylabel('\u9891\u7387 bin')\n\n# NMF \u6536\u655b\naxes[0, 1].plot(costs, color='#3498db', linewidth=1.5)\naxes[0, 1].set_title('NMF \u6536\u655b\u66f2\u7ebf')\naxes[0, 1].set_xlabel('\u8fed\u4ee3\u6b21\u6570')\naxes[0, 1].set_ylabel('Frobenius \u4ee3\u4ef7')\naxes[0, 1].set_yscale('log')\n\n# \u9891\u8c31\u57fa\u5411\u91cf W\nfreq_hz = jnp.arange(F) * sr / n_fft\naxes[1, 0].plot(freq_hz, W[:, 0], color='#27ae60', linewidth=1.5,\n                label='\u57fa 1\uff08\u4f4e\u9891\uff09')\naxes[1, 0].plot(freq_hz, W[:, 1], color='#e74c3c', linewidth=1.5,\n                label='\u57fa 2\uff08\u9ad8\u9891\uff09')\naxes[1, 0].set_title('\u5b66\u4e60\u5230\u7684\u9891\u8c31\u57fa W')\naxes[1, 0].set_xlabel('\u9891\u7387 (Hz)')\naxes[1, 0].set_ylabel('\u5e45\u5ea6')\naxes[1, 0].legend()\n\n# \u65f6\u57df\u6fc0\u6d3b H\ntime_s = jnp.arange(T) * hop / sr\naxes[1, 1].plot(time_s, H[0], color='#27ae60', linewidth=1.5,\n                label='\u6fc0\u6d3b 1')\naxes[1, 1].plot(time_s, H[1], color='#e74c3c', linewidth=1.5,\n                label='\u6fc0\u6d3b 2')\naxes[1, 
1].set_title('\u65f6\u57df\u6fc0\u6d3b H')\naxes[1, 1].set_xlabel('\u65f6\u95f4 (s)')\naxes[1, 1].set_ylabel('\u6fc0\u6d3b\u503c')\naxes[1, 1].legend()\n\n# \u5206\u79bb\u540e\u7684\u8bed\u8c31\u56fe\naxes[2, 0].imshow(jnp.log1p(V_src1), aspect='auto', origin='lower', cmap='magma')\naxes[2, 0].set_title('\u5206\u79bb\u540e\u7684\u6e90\u4fe1\u53f7 1\uff08\u4f4e\u9891\uff09')\naxes[2, 0].set_ylabel('\u9891\u7387 bin')\naxes[2, 0].set_xlabel('\u65f6\u95f4\u5e27')\n\naxes[2, 1].imshow(jnp.log1p(V_src2), aspect='auto', origin='lower', cmap='magma')\naxes[2, 1].set_title('\u5206\u79bb\u540e\u7684\u6e90\u4fe1\u53f7 2\uff08\u9ad8\u9891\uff09')\naxes[2, 1].set_xlabel('\u65f6\u95f4\u5e27')\n\nplt.tight_layout()\nplt.show()\n\nprint(f\"\u91cd\u5efa\u8bef\u5dee: {jnp.sum((V - W @ H)**2):.2f}\")\nprint(f\"NMF \u5b66\u4e60\u5230\u7684\u9891\u8c31\u57fa\u80fd\u591f\u6355\u6349\u6bcf\u4e2a\u6e90\u7684\u9891\u7387\u7279\u5f81\u3002\")\n
import jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport matplotlib.pyplot as plt\n\n# \u6a21\u62df\u56de\u58f0\u6d88\u9664\u573a\u666f\n# \u8fdc\u7aef\u4fe1\u53f7 -> \u623f\u95f4\u8109\u51b2\u54cd\u5e94 -> \u9ea6\u514b\u98ce\u5904\u7684\u56de\u58f0\n# \u8fd1\u7aef\u8bed\u97f3\u662f\u6211\u4eec\u5e0c\u671b\u4fdd\u7559\u7684\u76ee\u6807\u4fe1\u53f7\n\nsr = 8000\nduration = 2.0\nn_samples = int(sr * duration)\nkey = jr.PRNGKey(42)\nkeys = jr.split(key, 5)\n\n# \u8fdc\u7aef\u4fe1\u53f7\uff08\u53c2\u8003\uff09\uff1a\u968f\u673a\u7684\u7c7b\u8bed\u97f3\u4fe1\u53f7\nfar_end = jr.normal(keys[0], (n_samples,)) * 0.5\n\n# \u623f\u95f4\u8109\u51b2\u54cd\u5e94\uff08\u7b97\u6cd5\u672a\u77e5\uff09\nrir_length = 64\nrir = jnp.zeros(rir_length)\nrir = rir.at[0].set(0.8)   # \u76f4\u8fbe\u8def\u5f84\nrir = rir.at[5].set(0.3)   # \u65e9\u671f\u53cd\u5c04\nrir = rir.at[12].set(-0.2) # \u53cd\u5c04\nrir = rir.at[25].set(0.1)  # \u665a\u671f\u53cd\u5c04\nrir = rir.at[40].set(-0.05)\n\n# \u56de\u58f0\uff1a\u8fdc\u7aef\u4fe1\u53f7\u4e0e RIR \u7684\u5377\u79ef\necho = jnp.convolve(far_end, rir)[:n_samples]\n\n# \u8fd1\u7aef\u8bed\u97f3\uff08\u5728\u4fe1\u53f7\u7684\u4e00\u90e8\u5206\u4e2d\u6d3b\u8dc3\uff09\nnear_end = jnp.zeros(n_samples)\nstart, end = n_samples // 3, 2 * n_samples // 3\nnear_speech = 0.3 * jnp.sin(\n    2 * jnp.pi * 300 * jnp.linspace(0, (end - start) / sr, end - start)\n)\nnear_end = near_end.at[start:end].set(near_speech)\n\n# \u9ea6\u514b\u98ce\u4fe1\u53f7\uff1a\u56de\u58f0 + \u8fd1\u7aef + \u566a\u58f0\nnoise = jr.normal(keys[1], (n_samples,)) * 0.01\nmic_signal = echo + near_end + noise\n\n# LMS \u81ea\u9002\u5e94\u6ee4\u6ce2\u5668\ndef lms_filter(reference, desired, filter_length, mu):\n    \"\"\"\u6807\u51c6 LMS \u81ea\u9002\u5e94\u6ee4\u6ce2\u5668\u3002\"\"\"\n    n = len(reference)\n    w = jnp.zeros(filter_length)\n    output = jnp.zeros(n)\n    error = jnp.zeros(n)\n    w_history = []\n\n    for i in range(filter_length, n):\n        x = 
reference[max(0, i-filter_length+1):i+1][::-1]\n\n        y = jnp.dot(w, x)\n        e = desired[i] - y\n        w = w + mu * e * x\n\n        output = output.at[i].set(y)\n        error = error.at[i].set(e)\n\n        if i % 500 == 0:\n            w_history.append(w.copy())\n\n    return output, error, w_history\n\n# NLMS \u81ea\u9002\u5e94\u6ee4\u6ce2\u5668\ndef nlms_filter(reference, desired, filter_length, mu, eps=1e-6):\n    \"\"\"\u5f52\u4e00\u5316 LMS \u81ea\u9002\u5e94\u6ee4\u6ce2\u5668\u3002\"\"\"\n    n = len(reference)\n    w = jnp.zeros(filter_length)\n    output = jnp.zeros(n)\n    error = jnp.zeros(n)\n\n    for i in range(filter_length, n):\n        x = reference[max(0, i-filter_length+1):i+1][::-1]\n\n        y = jnp.dot(w, x)\n        e = desired[i] - y\n        norm_factor = jnp.dot(x, x) + eps\n        w = w + (mu / norm_factor) * e * x\n\n        output = output.at[i].set(y)\n        error = error.at[i].set(e)\n\n    return output, error\n\n# \u4f7f\u7528\u4e0d\u540c\u6b65\u957f\u8fd0\u884c LMS\nfilter_len = 64\nmu_values = [0.001, 0.01, 0.05]\ncolors_mu = ['#3498db', '#e74c3c', '#27ae60']\n\nfig, axes = plt.subplots(2, 2, figsize=(14, 10))\n\n# \u539f\u59cb\u4fe1\u53f7\nt = jnp.arange(n_samples) / sr\naxes[0, 0].plot(t, mic_signal, color='#9b59b6', linewidth=0.5, alpha=0.7,\n                label='\u9ea6\u514b\u98ce\uff08\u56de\u58f0 + \u8fd1\u7aef\uff09')\naxes[0, 0].plot(t, echo, color='#e74c3c', linewidth=0.5, alpha=0.7,\n                label='\u56de\u58f0\uff08\u5f85\u6d88\u9664\uff09')\naxes[0, 0].plot(t, near_end, color='#27ae60', linewidth=0.8,\n                label='\u8fd1\u7aef\u8bed\u97f3\uff08\u9700\u4fdd\u7559\uff09')\naxes[0, 0].set_title('\u4fe1\u53f7\u5206\u91cf')\naxes[0, 0].set_xlabel('\u65f6\u95f4 (s)')\naxes[0, 0].set_ylabel('\u5e45\u5ea6')\naxes[0, 0].legend(fontsize=8)\n\n# \u4e0d\u540c\u6b65\u957f\u4e0b\u7684 LMS \u6536\u655b\nfor mu, color in zip(mu_values, colors_mu):\n    _, err, _ = lms_filter(far_end, mic_signal, 
filter_len, mu)\n    # \u5e73\u6ed1\u540e\u7684\u5e73\u65b9\u8bef\u5dee\n    sq_err = err ** 2\n    window_size = 200\n    smoothed = jnp.convolve(sq_err, jnp.ones(window_size)/window_size,\n                             mode='valid')\n    axes[0, 1].plot(smoothed, color=color, linewidth=1.2,\n                    label=f'mu={mu}')\n\naxes[0, 1].set_title('LMS \u6536\u655b\u66f2\u7ebf\uff08\u5e73\u6ed1 MSE\uff09')\naxes[0, 1].set_xlabel('\u6837\u672c')\naxes[0, 1].set_ylabel('\u5e73\u65b9\u8bef\u5dee')\naxes[0, 1].set_yscale('log')\naxes[0, 1].legend()\n\n# \u6700\u4f73 LMS \u7ed3\u679c\n_, err_lms, w_hist = lms_filter(far_end, mic_signal, filter_len, 0.01)\naxes[1, 0].plot(t, mic_signal, color='#9b59b6', linewidth=0.5, alpha=0.4,\n                label='\u6d88\u9664\u524d')\naxes[1, 0].plot(t, err_lms, color='#3498db', linewidth=0.5, alpha=0.8,\n                label='LMS \u6d88\u9664\u540e')\naxes[1, 0].plot(t, near_end, color='#27ae60', linewidth=0.8, alpha=0.5,\n                label='\u771f\u5b9e\u8fd1\u7aef')\naxes[1, 0].set_title('LMS \u56de\u58f0\u6d88\u9664\u7ed3\u679c (mu=0.01)')\naxes[1, 0].set_xlabel('\u65f6\u95f4 (s)')\naxes[1, 0].set_ylabel('\u5e45\u5ea6')\naxes[1, 0].legend(fontsize=8)\n\n# NLMS \u7ed3\u679c\n_, err_nlms = nlms_filter(far_end, mic_signal, filter_len, 0.5)\naxes[1, 1].plot(t, mic_signal, color='#9b59b6', linewidth=0.5, alpha=0.4,\n                label='\u6d88\u9664\u524d')\naxes[1, 1].plot(t, err_nlms, color='#f39c12', linewidth=0.5, alpha=0.8,\n                label='NLMS \u6d88\u9664\u540e')\naxes[1, 1].plot(t, near_end, color='#27ae60', linewidth=0.8, alpha=0.5,\n                label='\u771f\u5b9e\u8fd1\u7aef')\naxes[1, 1].set_title('NLMS \u56de\u58f0\u6d88\u9664\u7ed3\u679c (mu=0.5)')\naxes[1, 1].set_xlabel('\u65f6\u95f4 (s)')\naxes[1, 1].set_ylabel('\u5e45\u5ea6')\naxes[1, 1].legend(fontsize=8)\n\nplt.tight_layout()\nplt.show()\n\n# \u6d4b\u91cf\u56de\u58f0\u8870\u51cf\necho_power = jnp.mean(echo ** 2)\nlms_residual = 
jnp.mean(err_lms[n_samples//2:] ** 2)  # \u6536\u655b\u540e\nnlms_residual = jnp.mean(err_nlms[n_samples//2:] ** 2)\nprint(f\"\u56de\u58f0\u529f\u7387: {10*jnp.log10(echo_power):.1f} dB\")\nprint(f\"LMS \u6b8b\u5dee: {10*jnp.log10(lms_residual):.1f} dB \"\n      f\"(ERLE: {10*jnp.log10(echo_power/lms_residual):.1f} dB)\")\nprint(f\"NLMS \u6b8b\u5dee: {10*jnp.log10(nlms_residual):.1f} dB \"\n      f\"(ERLE: {10*jnp.log10(echo_power/nlms_residual):.1f} dB)\")\n
import jax\nimport jax.numpy as jnp\nimport jax.random as jr\nimport matplotlib.pyplot as plt\n\n# \u521b\u5efa\u5408\u6210\u7684\"\u8bed\u97f3\"\u548c\"\u566a\u58f0\"\u4fe1\u53f7\nsr = 8000\nduration = 2.0\nt = jnp.linspace(0, duration, int(sr * duration))\n\n# \u8bed\u97f3\uff1a\u5177\u6709\u65f6\u53d8\u5e45\u5ea6\u7684\u8c10\u6ce2\u5e8f\u5217\uff08\u6a21\u62df\u8bed\u97f3\uff09\nspeech = jnp.zeros_like(t)\nfor f0 in [150, 300, 450, 600, 900]:\n    amp_env = 0.5 + 0.5 * jnp.sin(2 * jnp.pi * 2.0 * t)  # 2 Hz \u8c03\u5236\n    speech = speech + (0.5 / (f0/150)) * amp_env * jnp.sin(2 * jnp.pi * f0 * t)\nspeech = speech / jnp.max(jnp.abs(speech))\n\n# \u566a\u58f0\uff1a\u9650\u5e26\u566a\u58f0\nkey = jr.PRNGKey(42)\nnoise_raw = jr.normal(key, t.shape) * 0.4\n\n# \u5728\u7ed9\u5b9a SNR \u4e0b\u6df7\u5408\nsnr_db = 5.0\nspeech_power = jnp.mean(speech ** 2)\nnoise_power = jnp.mean(noise_raw ** 2)\nnoise_scale = jnp.sqrt(speech_power / (noise_power * 10 ** (snr_db / 10)))\nnoise = noise_raw * noise_scale\nmixture = speech + noise\n\n# STFT\nn_fft = 512\nhop = 128\nwindow = jnp.hanning(n_fft)\n\ndef stft(signal, n_fft, hop, window):\n    n_frames = 1 + (len(signal) - n_fft) // hop\n    frames = jnp.stack([\n        signal[i * hop : i * hop + n_fft] * window\n        for i in range(n_frames)\n    ])\n    return jnp.fft.rfft(frames, n=n_fft)\n\ndef istft(S, hop, window, length):\n    n_fft = (S.shape[1] - 1) * 2\n    n_frames = S.shape[0]\n    frames = jnp.fft.irfft(S, n=n_fft) * window[None, :]\n    output = jnp.zeros(length)\n    window_sum = jnp.zeros(length)\n    for i in range(n_frames):\n        start = i * hop\n        end = start + n_fft\n        if end <= length:\n            output = output.at[start:end].add(frames[i])\n            window_sum = window_sum.at[start:end].add(window ** 2)\n    window_sum = jnp.maximum(window_sum, 1e-8)\n    return output / window_sum\n\nS_speech = stft(speech, n_fft, hop, window)\nS_noise = stft(noise, n_fft, hop, window)\nS_mix = 
stft(mixture, n_fft, hop, window)\n\nmag_speech = jnp.abs(S_speech)\nmag_noise = jnp.abs(S_noise)\nmag_mix = jnp.abs(S_mix)\nphase_mix = jnp.angle(S_mix)\n\n# \u65b9\u6cd5 1\uff1a\u7406\u60f3\u6bd4\u7387\u63a9\u853d\uff08oracle - \u7406\u8bba\u4e0a\u9650\uff09\nirm = mag_speech ** 2 / (mag_speech ** 2 + mag_noise ** 2 + 1e-8)\nS_irm = (irm * mag_mix) * jnp.exp(1j * phase_mix)\nenhanced_irm = istft(S_irm, hop, window, len(mixture))\n\n# \u65b9\u6cd5 2\uff1a\u8c31\u51cf\u6cd5\n# \u4ece\u524d 0.2s \u4f30\u8ba1\u566a\u58f0\uff08\u5047\u8bbe\u4e3a\u9759\u97f3\u6bb5\uff09\nnoise_frames = int(0.2 * sr / hop)\nnoise_est = jnp.mean(mag_mix[:noise_frames] ** 2, axis=0, keepdims=True)\nalpha = 2.0  # \u8fc7\u51cf\u56e0\u5b50\nbeta = 0.02  # \u9891\u8c31\u5730\u677f\nmag_sub = jnp.maximum(mag_mix ** 2 - alpha * noise_est, beta * mag_mix ** 2)\nmag_sub = jnp.sqrt(mag_sub)\nS_sub = mag_sub * jnp.exp(1j * phase_mix)\nenhanced_sub = istft(S_sub, hop, window, len(mixture))\n\n# \u65b9\u6cd5 3\uff1a\u7ef4\u7eb3\u6ee4\u6ce2\u5668\nsnr_est = mag_mix ** 2 / (noise_est + 1e-8)\nwiener_gain = snr_est / (1 + snr_est)\nS_wiener = (wiener_gain * mag_mix) * jnp.exp(1j * phase_mix)\nenhanced_wiener = istft(S_wiener, hop, window, len(mixture))\n\n# \u8ba1\u7b97\u6bcf\u79cd\u65b9\u6cd5\u7684 SI-SDR\ndef si_sdr(estimate, reference):\n    \"\"\"\u5c3a\u5ea6\u4e0d\u53d8\u4fe1\u53f7\u5931\u771f\u6bd4\u3002\"\"\"\n    ref = reference[:len(estimate)]\n    est = estimate[:len(reference)]\n    s_target = (jnp.dot(est, ref) / (jnp.dot(ref, ref) + 1e-8)) * ref\n    e_noise = est - s_target\n    return 10 * jnp.log10(jnp.dot(s_target, s_target) /\n                           (jnp.dot(e_noise, e_noise) + 1e-8))\n\nsi_sdr_mix = si_sdr(mixture, speech)\nsi_sdr_irm_val = si_sdr(enhanced_irm, speech)\nsi_sdr_sub_val = si_sdr(enhanced_sub, speech)\nsi_sdr_wiener_val = si_sdr(enhanced_wiener, speech)\n\n# \u53ef\u89c6\u5316\nfig, axes = plt.subplots(3, 2, figsize=(14, 12))\n\n# \u8bed\u8c31\u56fe\naxes[0, 
0].imshow(jnp.log1p(mag_speech.T), aspect='auto', origin='lower',\n                   cmap='magma')\naxes[0, 0].set_title('\u5e72\u51c0\u8bed\u97f3\u8bed\u8c31\u56fe')\naxes[0, 0].set_ylabel('\u9891\u7387 bin')\n\naxes[0, 1].imshow(jnp.log1p(mag_mix.T), aspect='auto', origin='lower',\n                   cmap='magma')\naxes[0, 1].set_title(f'\u5e26\u566a\u6df7\u5408 ({snr_db:.0f} dB SNR)')\n\n# \u63a9\u853d\naxes[1, 0].imshow(irm.T, aspect='auto', origin='lower', cmap='RdYlGn')\naxes[1, 0].set_title('\u7406\u60f3\u6bd4\u7387\u63a9\u853d\uff08Oracle\uff09')\naxes[1, 0].set_ylabel('\u9891\u7387 bin')\n\naxes[1, 1].imshow(wiener_gain.T, aspect='auto', origin='lower', cmap='RdYlGn',\n                   vmin=0, vmax=1)\naxes[1, 1].set_title('\u4f30\u8ba1\u7684\u7ef4\u7eb3\u589e\u76ca')\n\n# \u589e\u5f3a\u540e\u7684\u6ce2\u5f62\u5bf9\u6bd4\nn_show = 3000\naxes[2, 0].plot(t[:n_show], speech[:n_show], color='#27ae60', linewidth=0.8,\n                alpha=0.5, label='\u5e72\u51c0')\naxes[2, 0].plot(t[:n_show], mixture[:n_show], color='#e74c3c', linewidth=0.5,\n                alpha=0.4, label='\u5e26\u566a')\naxes[2, 0].plot(t[:n_show], enhanced_irm[:n_show], color='#3498db',\n                linewidth=0.8, label='IRM \u589e\u5f3a')\naxes[2, 0].set_title('\u6ce2\u5f62\u5bf9\u6bd4\uff08IRM\uff09')\naxes[2, 0].set_xlabel('\u65f6\u95f4 (s)')\naxes[2, 0].set_ylabel('\u5e45\u5ea6')\naxes[2, 0].legend(fontsize=8)\n\n# SI-SDR \u67f1\u72b6\u56fe\nmethods = ['\u6df7\u5408\u4fe1\u53f7', '\u8c31\u51cf\u6cd5', '\u7ef4\u7eb3\u6ee4\u6ce2\u5668', '\u7406\u60f3\u6bd4\u7387\u63a9\u853d']\nsdr_values = [float(si_sdr_mix), float(si_sdr_sub_val),\n              float(si_sdr_wiener_val), float(si_sdr_irm_val)]\nbar_colors = ['#e74c3c', '#f39c12', '#9b59b6', '#27ae60']\nbars = axes[2, 1].bar(methods, sdr_values, color=bar_colors, alpha=0.8)\naxes[2, 1].set_ylabel('SI-SDR (dB)')\naxes[2, 1].set_title('\u589e\u5f3a\u8d28\u91cf\u5bf9\u6bd4')\nfor bar, val in zip(bars, sdr_values):\n    axes[2, 
1].text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.3,\n                    f'{val:.1f}', ha='center', fontsize=10)\naxes[2, 1].axhline(0, color='gray', linestyle='--', linewidth=0.8)\n\nplt.tight_layout()\nplt.show()\n\nprint(f\"SI-SDR\uff08\u5e26\u566a\u6df7\u5408\uff09:        {si_sdr_mix:.2f} dB\")\nprint(f\"SI-SDR\uff08\u8c31\u51cf\u6cd5\uff09:          {si_sdr_sub_val:.2f} dB\")\nprint(f\"SI-SDR\uff08\u7ef4\u7eb3\u6ee4\u6ce2\u5668\uff09:      {si_sdr_wiener_val:.2f} dB\")\nprint(f\"SI-SDR\uff08\u7406\u60f3\u6bd4\u7387\u63a9\u853d\uff09:    {si_sdr_irm_val:.2f} dB\uff08oracle \u7406\u8bba\u4e0a\u9650\uff09\")\n
"},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/","title":"\u591a\u6a21\u6001\u8868\u5f81","text":"

\u591a\u6a21\u6001\u8868\u5f81\u5c06\u89c6\u89c9\u3001\u8bed\u8a00\u548c\u97f3\u9891\u6865\u63a5\u5230\u5171\u4eab\u5d4c\u5165\u7a7a\u95f4\u4e2d\u3002\u672c\u6587\u4ef6\u6db5\u76d6\u878d\u5408\u7b56\u7565\u3001CLIP\u3001ALIGN\u3001SigLIP\u3001\u5bf9\u6bd4\u635f\u5931\u51fd\u6570\uff08InfoNCE\u3001NT-Xent\uff09\u3001\u96f6\u6837\u672c\u5206\u7c7b\u548c\u68c0\u7d22\u8bc4\u4f30\u3002

"},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#_2","title":"\u878d\u5408\u7b56\u7565","text":" \\[x_{\\text{fused}} = [x_{\\text{img}}; x_{\\text{txt}}] \\in \\mathbb{R}^{d_1 + d_2}\\] \\[\\hat{y} = \\alpha \\hat{y}_1 + (1 - \\alpha) \\hat{y}_2\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#_3","title":"\u8054\u5408\u5d4c\u5165\u7a7a\u95f4","text":" \\[\\text{sim}(u, v) = \\frac{u \\cdot v}{\\|u\\| \\|v\\|}\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#_4","title":"\u7528\u4e8e\u591a\u6a21\u6001\u5bf9\u9f50\u7684\u5bf9\u6bd4\u5b66\u4e60","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#clip","title":"CLIP","text":" \\[\\mathcal{L}_{i \\to t} = -\\frac{1}{N} \\sum_{i=1}^{N} \\log \\frac{\\exp(\\text{sim}(z_i^{\\text{img}}, z_i^{\\text{txt}}) / \\tau)}{\\sum_{k=1}^{N} \\exp(\\text{sim}(z_i^{\\text{img}}, z_k^{\\text{txt}}) / \\tau)}\\] \\[\\mathcal{L}_{t \\to i} = -\\frac{1}{N} \\sum_{i=1}^{N} \\log \\frac{\\exp(\\text{sim}(z_i^{\\text{txt}}, z_i^{\\text{img}}) / \\tau)}{\\sum_{k=1}^{N} \\exp(\\text{sim}(z_i^{\\text{txt}}, z_k^{\\text{img}}) / \\tau)}\\] \\[\\mathcal{L}_{\\text{CLIP}} = \\frac{1}{2}(\\mathcal{L}_{i \\to t} + \\mathcal{L}_{t \\to i})\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#align","title":"ALIGN","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#siglip","title":"SigLIP","text":" \\[\\mathcal{L}_{ij} = -y_{ij} \\log \\sigma(z_i^{\\text{img}} \\cdot z_j^{\\text{txt}} / \\tau) - (1 - y_{ij}) \\log(1 - \\sigma(z_i^{\\text{img}} \\cdot z_j^{\\text{txt}} / \\tau))\\] "},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#_5","title":"\u5bf9\u6bd4\u635f\u5931\u51fd\u6570\u8be6\u89e3","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#infonce","title":"InfoNCE","text":" \\[\\mathcal{L}_{\\text{InfoNCE}} = -\\log \\frac{\\exp(q \\cdot k^+ / \\tau)}{\\exp(q \\cdot k^+ / \\tau) + \\sum_{j=1}^{K} \\exp(q \\cdot k_j^- / \\tau)}\\] "},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#nt-xent","title":"NT-Xent","text":" \\[\\ell_{i,j} = -\\log \\frac{\\exp(\\text{sim}(z_i, z_j) / \\tau)}{\\sum_{k=1}^{2N} \\mathbf{1}_{[k \\neq i]} \\exp(\\text{sim}(z_i, z_k) / \\tau)}\\] "},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#_6","title":"\u6e29\u5ea6\u7684\u4f5c\u7528","text":"

"},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#_7","title":"\u4e09\u5143\u7ec4\u635f\u5931\u548c\u57fa\u4e8e\u95f4\u9694\u7684\u66ff\u4ee3\u65b9\u6848","text":" \\[\\mathcal{L}_{\\text{triplet}} = \\max(0, \\|a - p\\|^2 - \\|a - n\\|^2 + m)\\] "},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#-","title":"\u56fe\u50cf-\u6587\u672c\u68c0\u7d22\u4e0e\u96f6\u6837\u672c\u5206\u7c7b","text":" \\[\\hat{y} = \\arg\\max_{k} \\; \\text{sim}(f_\\theta(x), g_\\phi(t_k))\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#_8","title":"\u97f3\u89c6\u9891\u5bf9\u5e94","text":" \\[\\mathcal{L}_{\\text{AV}} = -\\log \\frac{\\exp(\\text{sim}(z^{\\text{vis}}, z^{\\text{aud}}) / \\tau)}{\\sum_{k=1}^{N} \\exp(\\text{sim}(z^{\\text{vis}}, z_k^{\\text{aud}}) / \\tau)}\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#_9","title":"\u8bc4\u4f30","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#_10","title":"\u96f6\u6837\u672c\u57fa\u51c6\u6d4b\u8bd5","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#_11","title":"\u68c0\u7d22\u5ea6\u91cf","text":" \\[\\text{R@}K = \\frac{1}{Q} \\sum_{q=1}^{Q} \\mathbf{1}[\\text{rank}(q) \\leq K]\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#_12","title":"\u603b\u7ed3","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/01.%20multimodal%20representations/#colab-notebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u4ece\u5934\u5b9e\u73b0 CLIP \u5bf9\u6bd4\u635f\u5931\u3002\u521b\u5efa\u968f\u673a\u56fe\u50cf\u548c\u6587\u672c\u5d4c\u5165\uff0c\u8ba1\u7b97\u76f8\u4f3c\u5ea6\u77e9\u9635\uff0c\u5e76\u8ba1\u7b97\u5bf9\u79f0\u4ea4\u53c9\u71b5\u635f\u5931\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef clip_loss(image_embeds, text_embeds, temperature=0.07):\n    \"\"\"\u8ba1\u7b97\u5bf9\u79f0 CLIP \u5bf9\u6bd4\u635f\u5931\u3002\"\"\"\n    # L2 \u5f52\u4e00\u5316\u5d4c\u5165\n    image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=1, keepdims=True)\n    text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=1, keepdims=True)\n\n    # \u8ba1\u7b97\u4f59\u5f26\u76f8\u4f3c\u5ea6\u77e9\u9635 (N x N)\n    logits = image_embeds @ text_embeds.T / temperature  # (N, N)\n\n    # \u6807\u7b7e\uff1a\u5bf9\u89d2\u7ebf\uff08\u7b2c i \u5f20\u56fe\u50cf\u5339\u914d\u7b2c i \u6bb5\u6587\u672c\uff09\n    N = logits.shape[0]\n    labels = jnp.arange(N)\n\n    # \u5bf9\u79f0\u4ea4\u53c9\u71b5\uff1a\u56fe\u50cf\u5230\u6587\u672c + \u6587\u672c\u5230\u56fe\u50cf\n    loss_i2t = -jnp.mean(jax.nn.log_softmax(logits, axis=1)[jnp.arange(N), labels])\n    loss_t2i = -jnp.mean(jax.nn.log_softmax(logits, axis=0)[labels, jnp.arange(N)])\n    return (loss_i2t + loss_t2i) / 2, logits * temperature\n\n# \u6a21\u62df\u4e00\u6279 8 \u4e2a\u56fe\u50cf-\u6587\u672c\u5bf9\uff0c64 \u7ef4\u7a7a\u95f4\nkey = jax.random.PRNGKey(42)\nk1, k2 = jax.random.split(key)\nN, D = 8, 64\nimage_embeds = jax.random.normal(k1, (N, D))\ntext_embeds = jax.random.normal(k2, (N, D))\n\nloss, sim_matrix = clip_loss(image_embeds, text_embeds)\nprint(f\"CLIP loss (random embeddings): {loss:.4f}\")\n\n# \u53ef\u89c6\u5316\u76f8\u4f3c\u5ea6\u77e9\u9635\nfig, ax = plt.subplots(figsize=(6, 5))\nim = ax.imshow(sim_matrix, cmap='coolwarm', vmin=-1, vmax=1)\nax.set_xlabel(\"Text index\"); ax.set_ylabel(\"Image index\")\nax.set_title(f\"Cosine Similarity Matrix (loss={loss:.3f})\")\nplt.colorbar(im); plt.tight_layout(); plt.show()\n# \u5c1d\u8bd5\u6539\u53d8\u6e29\u5ea6 (0.01, 0.1, 1.0) \u5e76\u89c2\u5bdf\u635f\u5931\u5982\u4f55\u53d8\u5316\n# \u5c1d\u8bd5\u4f7f\u5339\u914d\u5bf9\u76f8\u4f3c\uff1a\u5c06 text_embeds \u8bbe\u7f6e\u4e3a 
image_embeds + \u5c0f\u566a\u58f0\n

  2. \u6784\u5efa\u4e00\u4e2a\u73a9\u5177\u8054\u5408\u5d4c\u5165\u6a21\u578b\uff0c\u5b66\u4e60\u4f7f\u7528 InfoNCE \u635f\u5931\u548c\u68af\u5ea6\u4e0b\u964d\u6765\u5bf9\u9f50 2D\"\u56fe\u50cf\"\uff08\u968f\u673a\u5411\u91cf\uff09\u4e0e\"\u63cf\u8ff0\"\uff08\u4e0d\u540c\u7684\u968f\u673a\u5411\u91cf\uff09\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef info_nce_loss(img_enc, txt_enc, img_data, txt_data, tau=0.1):\n    \"\"\"\u5728\u4e00\u6279\u914d\u5bf9\u7684 (\u56fe\u50cf, \u6587\u672c) \u6570\u636e\u4e0a\u8ba1\u7b97 InfoNCE\u3002\"\"\"\n    z_img = img_data @ img_enc  # (N, D)\n    z_txt = txt_data @ txt_enc  # (N, D)\n    # L2 \u5f52\u4e00\u5316\n    z_img = z_img / jnp.linalg.norm(z_img, axis=1, keepdims=True)\n    z_txt = z_txt / jnp.linalg.norm(z_txt, axis=1, keepdims=True)\n    logits = z_img @ z_txt.T / tau\n    labels = jnp.arange(logits.shape[0])\n    return -jnp.mean(jax.nn.log_softmax(logits, axis=1)[jnp.arange(len(labels)), labels])\n\n# \u521b\u5efa 32 \u4e2a\u914d\u5bf9\u6837\u672c\uff1a\u56fe\u50cf\u5728 R^8 \u4e2d\uff0c\u6587\u672c\u5728 R^6 \u4e2d\uff0c\u5d4c\u5165\u5230 R^4\nkey = jax.random.PRNGKey(0)\nk1, k2, k3, k4 = jax.random.split(key, 4)\nN, d_img, d_txt, d_embed = 32, 8, 6, 4\n\nimg_data = jax.random.normal(k1, (N, d_img))\ntxt_data = jax.random.normal(k2, (N, d_txt))\n\n# \u53ef\u5b66\u4e60\u7684\u6295\u5f71\u77e9\u9635\nimg_enc = jax.random.normal(k3, (d_img, d_embed)) * 0.1\ntxt_enc = jax.random.normal(k4, (d_txt, d_embed)) * 0.1\n\ngrad_fn = jax.jit(jax.grad(info_nce_loss, argnums=(0, 1)))\nlr = 0.05\nlosses = []\n\nfor step in range(300):\n    loss = info_nce_loss(img_enc, txt_enc, img_data, txt_data)\n    losses.append(float(loss))\n    g_img, g_txt = grad_fn(img_enc, txt_enc, img_data, txt_data)\n    img_enc = img_enc - lr * g_img\n    txt_enc = txt_enc - lr * g_txt\n\nprint(f\"Initial loss: {losses[0]:.3f}, Final loss: {losses[-1]:.3f}\")\nprint(f\"Random baseline (log N): {jnp.log(N):.3f}\")\n\nplt.figure(figsize=(8, 4))\nplt.plot(losses, color='#2c3e50')\nplt.axhline(y=0, color='green', linestyle='--', alpha=0.5, label='Perfect alignment')\nplt.axhline(y=float(jnp.log(N)), color='red', linestyle='--', alpha=0.5, label='Random (log N)')\nplt.xlabel(\"Step\"); plt.ylabel(\"InfoNCE 
Loss\")\nplt.title(\"Learning a Joint Embedding Space\")\nplt.legend(); plt.grid(alpha=0.3); plt.tight_layout(); plt.show()\n# \u4fee\u6539 d_embed\uff08\u5c1d\u8bd5 2, 4, 16\uff09\u89c2\u5bdf\u5d4c\u5165\u7ef4\u5ea6\u5982\u4f55\u5f71\u54cd\u5bf9\u9f50\n

  3. \u4f7f\u7528\u9884\u8ba1\u7b97\u7684\u5d4c\u5165\u5b9e\u73b0\u96f6\u6837\u672c\u5206\u7c7b\u3002\u6a21\u62df\u7c7b\"\u539f\u578b\"\u4f5c\u4e3a\u6587\u672c\u5d4c\u5165\uff0c\u901a\u8fc7\u6700\u8fd1\u90bb\u67e5\u627e\u5bf9\u65b0\u56fe\u50cf\u8fdb\u884c\u5206\u7c7b\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u6a21\u62df 5 \u4e2a\u7c7b\uff0c\u6bcf\u4e2a\u7c7b\u6709\u4e00\u4e2a\u539f\u578b\u6587\u672c\u5d4c\u5165\u5728 R^32 \u4e2d\nkey = jax.random.PRNGKey(42)\nn_classes, d = 5, 32\nclass_names = [\"cat\", \"dog\", \"car\", \"plane\", \"ship\"]\n\n# \u7c7b\u539f\u578b\uff08\u60f3\u8c61\u8fd9\u4e9b\u6765\u81ea\u6587\u672c\u7f16\u7801\u5668\uff09\nk1, k2 = jax.random.split(key)\nclass_prototypes = jax.random.normal(k1, (n_classes, d))\nclass_prototypes = class_prototypes / jnp.linalg.norm(class_prototypes, axis=1, keepdims=True)\n\n# \u751f\u6210 200 \u4e2a\u6d4b\u8bd5\"\u56fe\u50cf\"\uff08\u5728\u5176\u7c7b\u539f\u578b\u9644\u8fd1\u52a0\u4e0a\u566a\u58f0\u7684\u5d4c\u5165\uff09\nn_per_class = 40\ntrue_labels = jnp.repeat(jnp.arange(n_classes), n_per_class)\nkeys = jax.random.split(k2, n_classes * n_per_class)\n\nimage_embeds = []\nfor i in range(n_classes):\n    noise = jax.random.normal(keys[i], (n_per_class, d)) * 0.5\n    cluster = class_prototypes[i] + noise\n    image_embeds.append(cluster)\nimage_embeds = jnp.concatenate(image_embeds, axis=0)\nimage_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=1, keepdims=True)\n\n# \u96f6\u6837\u672c\u5206\u7c7b\uff1a\u4e0e\u6bcf\u4e2a\u539f\u578b\u7684\u4f59\u5f26\u76f8\u4f3c\u5ea6\nsimilarities = image_embeds @ class_prototypes.T  # (200, 5)\npredicted_labels = jnp.argmax(similarities, axis=1)\naccuracy = jnp.mean(predicted_labels == true_labels)\nprint(f\"Zero-shot accuracy: {accuracy:.1%}\")\n\n# \u6df7\u6dc6\u77e9\u9635\nconf = jnp.zeros((n_classes, n_classes), dtype=jnp.int32)\nfor true, pred in zip(true_labels, predicted_labels):\n    conf = conf.at[true, pred].add(1)\n\nfig, ax = plt.subplots(figsize=(6, 5))\nim = ax.imshow(conf, cmap='Blues')\nax.set_xticks(range(n_classes)); ax.set_xticklabels(class_names, rotation=45)\nax.set_yticks(range(n_classes)); ax.set_yticklabels(class_names)\nax.set_xlabel(\"Predicted\"); 
ax.set_ylabel(\"True\")\nfor i in range(n_classes):\n    for j in range(n_classes):\n        ax.text(j, i, int(conf[i, j]), ha='center', va='center', fontsize=11)\nax.set_title(f\"Zero-Shot Confusion Matrix (acc={accuracy:.1%})\")\nplt.colorbar(im); plt.tight_layout(); plt.show()\n# \u5c1d\u8bd5\u589e\u52a0\u566a\u58f0\uff080.5 -> 1.0 -> 2.0\uff09\u89c2\u5bdf\u51c6\u786e\u7387\u4e0b\u964d\n# \u5c1d\u8bd5\u63d0\u793a\u96c6\u6210\uff1a\u5e73\u5747\u6bcf\u4e2a\u539f\u578b\u7684 3 \u4e2a\u566a\u58f0\u526f\u672c\n

"},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/","title":"\u89c6\u89c9\u8bed\u8a00\u6a21\u578b","text":"

\u89c6\u89c9\u8bed\u8a00\u6a21\u578b\u5171\u540c\u7406\u89e3\u56fe\u50cf\u548c\u6587\u672c\uff0c\u5b9e\u73b0\u89c6\u89c9\u95ee\u7b54\u3001\u56fe\u50cf\u63cf\u8ff0\u548c\u89c6\u89c9\u63a8\u7406\u3002\u672c\u6587\u4ef6\u6db5\u76d6 VQA\u3001\u56fe\u50cf\u63cf\u8ff0\u3001\u89c6\u89c9\u5b9a\u4f4d\uff0c\u4ee5\u53ca VisualBERT\u3001BLIP\u3001LLaVA\u3001Flamingo\u3001PaLI \u548c Qwen-VL \u7b49\u5c06\u89c6\u89c9\u7f16\u7801\u5668\u4e0e\u5927\u578b\u8bed\u8a00\u6a21\u578b\u878d\u5408\u7684\u67b6\u6784\u3002

"},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#_2","title":"\u89c6\u89c9\u95ee\u7b54","text":" \\[p(a \\mid I, q) = \\text{softmax}(W \\cdot g(v, h))\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#_3","title":"\u56fe\u50cf\u63cf\u8ff0","text":" \\[p(w_t \\mid w_{1:t-1}, I) = \\text{LSTM}(w_{t-1}, h_{t-1})\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#_4","title":"\u67b6\u6784\u6a21\u5f0f","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#_5","title":"\u53cc\u7f16\u7801\u5668","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#_6","title":"\u878d\u5408\u7f16\u7801\u5668","text":" \\[\\text{CrossAttn}(T, V) = \\text{softmax}\\!\\left(\\frac{(TW_Q)(VW_K)^T}{\\sqrt{d_k}}\\right)(VW_V)\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#-","title":"\u7f16\u7801\u5668-\u89e3\u7801\u5668","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#flamingo","title":"Flamingo\uff1a\u5c11\u6837\u672c\u591a\u6a21\u6001\u5b66\u4e60","text":" \\[z = \\text{CrossAttn}(Q_{\\text{learned}}, V_{\\text{image}}) \\in \\mathbb{R}^{N \\times d}\\] \\[\\hat{x} = x + \\alpha \\cdot \\text{CrossAttn}(x, z)\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#llava","title":"LLaVA \u4e0e\u89c6\u89c9\u6307\u4ee4\u5fae\u8c03","text":" \\[H_v = VW, \\quad W \\in \\mathbb{R}^{d_v \\times d_{\\text{LLM}}}\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#_7","title":"\u6269\u5c55\u89c6\u89c9\u8bed\u8a00\u6a21\u578b","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#pali","title":"PaLI","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#qwen-vl","title":"Qwen-VL","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#internvl","title":"InternVL","text":"

"},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#_8","title":"\u5b9a\u4f4d\u4e0e\u6307\u4ee3","text":" \\[\\text{\u8f93\u51fa: } \\texttt{<loc\\_102><loc\\_215><loc\\_487><loc\\_398>}\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#ocr","title":"\u514d OCR \u6587\u6863\u7406\u89e3","text":"

"},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#token","title":"\u89c6\u89c9 Token \u6d41\u6c34\u7ebf","text":"

"},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#_9","title":"\u8bad\u7ec3\u76ee\u6807","text":" \\[\\mathcal{L}_{\\text{LM}} = -\\sum_{t=1}^{T} \\log p(w_t \\mid w_{<t}, V)\\] "},{"location":"chapter%2010%3A%20multimodal%20learning/02.%20vision%20language%20models/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u5b9e\u73b0\u4e00\u4e2a\u7b80\u5355\u7684\u57fa\u4e8e\u6ce8\u610f\u529b\u7684\u56fe\u50cf\u63cf\u8ff0\u89e3\u7801\u5668\u3002\u4f7f\u7528\u968f\u673a\u7684\"\u56fe\u50cf\u7279\u5f81\"\u4f5c\u4e3a\u7f16\u7801\u5668\u8f93\u51fa\uff0c\u8bad\u7ec3\u89e3\u7801\u5668\u751f\u6210\u56fa\u5b9a\u7684\u63cf\u8ff0\uff0c\u89c2\u5bdf\u6ce8\u610f\u529b\u6743\u91cd\u5728\u6bcf\u4e2a\u89e3\u7801\u6b65\u9aa4\u5982\u4f55\u8de8\u7a7a\u95f4\u4f4d\u7f6e\u79fb\u52a8\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u6a21\u62df 4x4 \u7a7a\u95f4\u7f51\u683c\u7684\u56fe\u50cf\u7279\u5f81\uff0816 \u4e2a\u533a\u57df\uff0cdim=32\uff09\nkey = jax.random.PRNGKey(42)\nk1, k2, k3 = jax.random.split(key, 3)\nimg_features = jax.random.normal(k1, (16, 32))  # 16 \u4e2a\u7a7a\u95f4\u533a\u57df\uff0c32 \u7ef4\n\n# \u8bcd\u6c47\u8868\uff1a0=<start>, 1=\"a\", 2=\"red\", 3=\"car\", 4=<end>\nvocab_size, embed_dim, hidden_dim = 5, 16, 32\nW_embed = jax.random.normal(k2, (vocab_size, embed_dim)) * 0.1\nW_attn_q = jax.random.normal(k3, (hidden_dim, 32)) * 0.1  # \u67e5\u8be2\u6295\u5f71\n\ndef attend(h, img_feats, W_q):\n    \"\"\"\u5728\u7ed9\u5b9a\u89e3\u7801\u5668\u72b6\u6001 h \u7684\u60c5\u51b5\u4e0b\u8ba1\u7b97\u56fe\u50cf\u7279\u5f81\u4e0a\u7684\u8f6f\u6ce8\u610f\u529b\u3002\"\"\"\n    query = h @ W_q  # (32,)\n    scores = img_feats @ query  # (16,)\n    weights = jax.nn.softmax(scores)  # (16,)\n    context = weights @ img_feats  # (32,)\n    return context, weights\n\n# \u7b80\u5355\u7684 GRU \u98ce\u683c\u6b65\u9aa4\uff08\u4e3a\u8bf4\u660e\u76ee\u7684\uff0c\u4ec5\u7528\u7ebf\u6027 + tanh\uff09\nW_h = jax.random.normal(jax.random.PRNGKey(0), (embed_dim + 32, hidden_dim)) * 0.1\n\ndef decode_step(h, word_idx, img_feats):\n    context, attn_weights = attend(h, img_feats, W_attn_q)\n    word_emb = W_embed[word_idx]  # (16,)\n    inp = jnp.concatenate([word_emb, context])  # (48,)\n    h_new = jnp.tanh(inp @ W_h)  # (32,)\n    return h_new, attn_weights\n\n# \u8fd0\u884c\u89e3\u7801\u5e8f\u5217\uff1a<start> -> \"a\" -> \"red\" -> \"car\" -> <end>\ntarget_seq = [0, 1, 2, 3, 4]\nh = jnp.zeros(hidden_dim)\nall_attn = []\nfor word_idx in target_seq[:-1]:\n    h, attn_w = decode_step(h, word_idx, img_features)\n    all_attn.append(attn_w)\n\n# \u53ef\u89c6\u5316\u6bcf\u4e00\u6b65\u7684\u6ce8\u610f\u529b\u56fe\uff08\u91cd\u5851\u4e3a 4x4 \u7f51\u683c\uff09\nwords = [\"<start>\", \"a\", \"red\", \"car\"]\nfig, axes = 
plt.subplots(1, 4, figsize=(14, 3))\nfor i, (ax, w) in enumerate(zip(axes, words)):\n    ax.imshow(all_attn[i].reshape(4, 4), cmap='viridis')\n    ax.set_title(f'\u751f\u6210\"{w}\"\u540e\\n\u5173\u6ce8\u7684\u533a\u57df')\n    ax.axis('off')\nplt.suptitle('\u6bcf\u4e2a\u89e3\u7801\u6b65\u9aa4\u7684\u56fe\u50cf\u533a\u57df\u6ce8\u610f\u529b')\nplt.tight_layout(); plt.show()\n# \u5c1d\u8bd5\u4fee\u6539 img_features\uff0c\u89c2\u5bdf\u6ce8\u610f\u529b\u6a21\u5f0f\u5982\u4f55\u53d8\u5316\uff01\n

  2. \u6a21\u62df\u89c6\u89c9 token \u6d41\u6c34\u7ebf\uff1a\u5c06\u56fe\u50cf\u5212\u5206\u4e3a patch\uff0c\u5c06 patch \u6295\u5f71\u5230\u5d4c\u5165\u7a7a\u95f4\uff0c\u4e0e\u6587\u672c token \u5d4c\u5165\u62fc\u63a5\uff0c\u5e76\u5728\u7ec4\u5408\u5e8f\u5217\u4e0a\u8fd0\u884c\u5355\u5c42\u81ea\u6ce8\u610f\u529b\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nkey = jax.random.PRNGKey(7)\n\n# \u521b\u5efa\u4e00\u4e2a\u5408\u6210\u7684 8x8 \"\u56fe\u50cf\"\uff0c3 \u4e2a\u901a\u9053\nk1, k2, k3, k4 = jax.random.split(key, 4)\nimage = jax.random.uniform(k1, (8, 8, 3))\n\n# \u7b2c 1 \u6b65\uff1a\u5212\u5206\u4e3a 4x4 patch -> 4 \u4e2a patch\npatch_size = 4\npatches = image.reshape(2, patch_size, 2, patch_size, 3)\npatches = patches.transpose(0, 2, 1, 3, 4).reshape(4, patch_size * patch_size * 3)  # (4, 48)\nprint(f\"Patch \u6570\u91cf: {patches.shape[0]}, Patch \u7ef4\u5ea6: {patches.shape[1]}\")\n\n# \u7b2c 2 \u6b65\uff1a\u5c06 patch \u6295\u5f71\u5230\u5d4c\u5165\u7ef4\u5ea6 (d=16)\nd_model = 16\nW_patch = jax.random.normal(k2, (patches.shape[1], d_model)) * 0.1\nvisual_tokens = patches @ W_patch  # (4, 16)\n\n# \u7b2c 3 \u6b65\uff1a\u521b\u5efa\u6587\u672c token \u5d4c\u5165\uff08\u6a21\u62df 3 \u4e2a\u6587\u672c token\uff09\ntext_tokens = jax.random.normal(k3, (3, d_model)) * 0.1\n\n# \u7b2c 4 \u6b65\uff1a\u62fc\u63a5\u89c6\u89c9 + \u6587\u672c token\ncombined = jnp.concatenate([visual_tokens, text_tokens], axis=0)  # (7, 16)\nprint(f\"\u7ec4\u5408\u5e8f\u5217\u957f\u5ea6: {combined.shape[0]} (4 \u4e2a\u89c6\u89c9 + 3 \u4e2a\u6587\u672c)\")\n\n# \u7b2c 5 \u6b65\uff1a\u5728\u7ec4\u5408\u5e8f\u5217\u4e0a\u8fd0\u884c\u5355\u5934\u81ea\u6ce8\u610f\u529b\nW_Q = jax.random.normal(k4, (d_model, d_model)) * 0.1\nk5, k6 = jax.random.split(k4)\nW_K = jax.random.normal(k5, (d_model, d_model)) * 0.1\nW_V = jax.random.normal(k6, (d_model, d_model)) * 0.1\n\nQ = combined @ W_Q\nK = combined @ W_K\nV = combined @ W_V\nattn_scores = (Q @ K.T) / jnp.sqrt(d_model)\nattn_weights = jax.nn.softmax(attn_scores, axis=-1)  # (7, 7)\n\noutput = attn_weights @ V  # (7, 16)\n\n# \u53ef\u89c6\u5316\u8de8\u6a21\u6001\u6ce8\u610f\u529b\u6a21\u5f0f\nlabels = ['V1', 'V2', 'V3', 'V4', 'T1', 'T2', 'T3']\nfig, ax = plt.subplots(figsize=(6, 5))\nim = ax.imshow(attn_weights, 
cmap='Blues')\nax.set_xticks(range(7)); ax.set_xticklabels(labels)\nax.set_yticks(range(7)); ax.set_yticklabels(labels)\nax.set_xlabel('\u952e'); ax.set_ylabel('\u67e5\u8be2')\nax.set_title('\u81ea\u6ce8\u610f\u529b\uff1a\u89c6\u89c9\uff08V\uff09\u548c\u6587\u672c\uff08T\uff09Token')\nplt.colorbar(im, ax=ax); plt.tight_layout(); plt.show()\n# \u89c2\u5bdf\uff1a\u6587\u672c token \u5173\u6ce8\u89c6\u89c9 token\uff08\u8de8\u6a21\u6001\u6ce8\u610f\u529b\uff09\uff01\n

  3. \u5b9e\u73b0\u7528\u4e8e\u89c6\u89c9\u5b9a\u4f4d\u7684\u5750\u6807 token \u5316\u3002\u7ed9\u5b9a\u4e00\u4e2a\u8fb9\u754c\u6846\uff0c\u5c06\u5176\u8f6c\u6362\u4e3a\u79bb\u6563 token\uff1b\u7ed9\u5b9a\u79bb\u6563 token\uff0c\u91cd\u6784\u8fb9\u754c\u6846\u3002\u5728\u4e0d\u540c\u69fd\u4f4d\u5206\u8fa8\u7387\u4e0b\u53ef\u89c6\u5316\u91cf\u5316\u8bef\u5dee\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef encode_bbox(bbox, num_bins=1000):\n    \"\"\"\u5c06\u8fde\u7eed\u7684\u8fb9\u754c\u6846 (x, y, w, h)\uff08\u5728 [0,1] \u8303\u56f4\u5185\uff09\u8f6c\u6362\u4e3a\u79bb\u6563 token\u3002\"\"\"\n    tokens = jnp.round(jnp.array(bbox) * (num_bins - 1)).astype(jnp.int32)\n    return tokens\n\ndef decode_bbox(tokens, num_bins=1000):\n    \"\"\"\u5c06\u79bb\u6563 token \u8f6c\u6362\u56de\u8fde\u7eed\u7684\u8fb9\u754c\u6846\u3002\"\"\"\n    return tokens.astype(jnp.float32) / (num_bins - 1)\n\n# \u771f\u5b9e\u8fb9\u754c\u6846\uff08\u5f52\u4e00\u5316\u5230 [0, 1]\uff09\ngt_bbox = jnp.array([0.123, 0.456, 0.333, 0.222])\n\n# \u6d4b\u8bd5\u4e0d\u540c\u69fd\u4f4d\u5206\u8fa8\u7387\u4e0b\u7684\u91cf\u5316\nbin_sizes = [10, 50, 100, 500, 1000]\nerrors = []\nfor n_bins in bin_sizes:\n    tokens = encode_bbox(gt_bbox, n_bins)\n    reconstructed = decode_bbox(tokens, n_bins)\n    error = jnp.max(jnp.abs(gt_bbox - reconstructed))\n    errors.append(float(error))\n    print(f\"\u69fd\u4f4d\u6570={n_bins:>5d} | Token={tokens} | \"\n          f\"\u91cd\u6784={reconstructed} | \u6700\u5927\u8bef\u5dee={error:.6f}\")\n\nfig, ax = plt.subplots(figsize=(8, 4))\nax.plot(bin_sizes, errors, 'o-', color='#e74c3c', linewidth=2, markersize=8)\nax.set_xlabel('\u69fd\u4f4d\u6570'); ax.set_ylabel('\u6700\u5927\u91cf\u5316\u8bef\u5dee')\nax.set_title('\u8fb9\u754c\u6846\u91cf\u5316\u8bef\u5dee vs \u69fd\u4f4d\u5206\u8fa8\u7387')\nax.set_xscale('log'); ax.set_yscale('log')\nax.grid(True, alpha=0.3); plt.tight_layout(); plt.show()\n# \u5c1d\u8bd5\uff1a\u69fd\u4f4d\u975e\u5e38\u5c11\u65f6\uff08\u5982 5\uff09\u4f1a\u53d1\u751f\u4ec0\u4e48\uff1f\u8bef\u5dee\u5728\u4f55\u65f6\u662f\u53ef\u63a5\u53d7\u7684\uff1f\n

"},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/","title":"\u56fe\u50cf\u4e0e\u89c6\u9891\u8bcd\u5143\u5316","text":"

\u56fe\u50cf\u4e0e\u89c6\u9891\u8bcd\u5143\u5316\u5c06\u8fde\u7eed\u7684\u89c6\u89c9\u6570\u636e\u8f6c\u6362\u4e3a\u79bb\u6563\u7684\u8bcd\u5143\u5e8f\u5217\uff0c\u4f7f Transformer \u80fd\u591f\u50cf\u5904\u7406\u6587\u672c\u4e00\u6837\u5904\u7406\u5b83\u4eec\u3002\u672c\u8282\u6db5\u76d6 VQ-VAE\u3001VQ-GAN\u3001\u7801\u672c\u5b66\u4e60\u3001DALL-E \u7684 dVAE\u3001\u89c6\u9891\u8bcd\u5143\u5316\u4ee5\u53ca\u514d\u67e5\u8be2\u8bcd\u5143\u5316\u3002

"},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#_2","title":"\u4e3a\u4ec0\u4e48\u8981\u5bf9\u56fe\u50cf\u8fdb\u884c\u8bcd\u5143\u5316","text":"

"},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#vq-vae","title":"VQ-VAE\uff1a\u5411\u91cf\u91cf\u5316","text":" \\[\\mathbf{z}_q(i,j) = \\mathbf{e}_{k^\\ast} \\quad \\text{\u5176\u4e2d} \\quad k^\\ast = \\arg\\min_k \\|\\mathbf{z}_e(i,j) - \\mathbf{e}_k\\|_2\\]

\\[\\mathbf{z}_q = \\mathbf{z}_e + \\text{sg}(\\mathbf{z}_q - \\mathbf{z}_e)\\] \\[\\mathcal{L} = \\underbrace{\\|\\mathbf{x} - D(\\mathbf{z}_q)\\|_2^2}_{\\text{\u91cd\u5efa\u635f\u5931}} + \\underbrace{\\|\\text{sg}(\\mathbf{z}_e) - \\mathbf{e}\\|_2^2}_{\\text{\u7801\u672c\uff08VQ\uff09\u635f\u5931}} + \\underbrace{\\beta \\|\\mathbf{z}_e - \\text{sg}(\\mathbf{e})\\|_2^2}_{\\text{\u627f\u8bfa\u635f\u5931}}\\] \\[\\mathbf{n}_k \\leftarrow \\gamma \\mathbf{n}_k + (1 - \\gamma) |\\{(i,j) : k^\\ast_{ij} = k\\}|\\] \\[\\mathbf{s}_k \\leftarrow \\gamma \\mathbf{s}_k + (1 - \\gamma) \\sum_{(i,j) : k^\\ast_{ij} = k} \\mathbf{z}_e(i,j)\\] \\[\\mathbf{e}_k \\leftarrow \\frac{\\mathbf{s}_k}{\\mathbf{n}_k}\\] "},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#_3","title":"\u7801\u672c\u574d\u584c","text":"

"},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#vq-gan","title":"VQ-GAN\uff1a\u5bf9\u6297\u8bad\u7ec3\u5b9e\u73b0\u66f4\u9ad8\u4fdd\u771f\u5ea6","text":" \\[\\mathcal{L}_\\text{VQ-GAN} = \\mathcal{L}_\\text{VQ-VAE} + \\lambda_\\text{adv} \\mathcal{L}_\\text{adv} + \\lambda_\\text{perc} \\mathcal{L}_\\text{perc}\\] \\[\\mathcal{L}_\\text{adv} = -\\mathbb{E}[\\log \\mathcal{D}(D(\\mathbf{z}_q))]\\] \\[\\mathcal{L}_\\text{perc} = \\sum_l \\|\\phi_l(\\mathbf{x}) - \\phi_l(D(\\mathbf{z}_q))\\|_2^2\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#_4","title":"\u6b8b\u5dee\u91cf\u5316\u4e0e\u591a\u5c3a\u5ea6\u7801\u672c","text":" \\[\\mathbf{r}^{(0)} = \\mathbf{z}_e\\] \\[\\mathbf{z}_q^{(t)} = \\text{Quantise}(\\mathbf{r}^{(t-1)}, \\mathcal{C}^{(t)})\\] \\[\\mathbf{r}^{(t)} = \\mathbf{r}^{(t-1)} - \\mathbf{z}_q^{(t)}\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#_5","title":"\u5b9e\u8df5\u4e2d\u7684\u56fe\u50cf\u8bcd\u5143\u5316\u5668","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#dall-e-dvae","title":"DALL-E \u8bcd\u5143\u5316\u5668\uff08dVAE\uff09","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#llamagen","title":"LlamaGen","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#cosmos","title":"Cosmos \u8bcd\u5143\u5316\u5668","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#_6","title":"\u89c6\u9891\u8bcd\u5143\u5316","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#3d-vq-vae","title":"3D VQ-VAE","text":"

"},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#_7","title":"\u56e0\u679c\u89c6\u9891\u8bcd\u5143\u5316\u5668","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#_8","title":"\u65f6\u95f4\u538b\u7f29\u7b56\u7565","text":"

"},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#_9","title":"\u8fde\u7eed\u8bcd\u5143\u4e0e\u79bb\u6563\u8bcd\u5143","text":"

"},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#_10","title":"\u5e94\u7528","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#_11","title":"\u81ea\u56de\u5f52\u56fe\u50cf\u751f\u6210","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#-","title":"\u7edf\u4e00\u7684\u89c6\u89c9-\u8bed\u8a00\u8bcd\u5143","text":"

"},{"location":"chapter%2010%3A%20multimodal%20learning/03.%20image%20and%20video%20tokenisation/#colab","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u5728 Colab \u6216\u7b14\u8bb0\u672c\u4e2d\u8fd0\u884c\uff09","text":"
  1. \u5728 JAX \u4e2d\u5b9e\u73b0\u4e00\u4e2a\u6700\u5c0f VQ \u5c42\uff1a\u7ed9\u5b9a\u4e00\u6279\u7f16\u7801\u5668\u8f93\u51fa\u5411\u91cf\uff0c\u6267\u884c\u6700\u8fd1\u90bb\u7801\u672c\u67e5\u627e\u5e76\u8ba1\u7b97 VQ-VAE \u635f\u5931\uff08\u7801\u672c + \u627f\u8bfa\uff1b\u672c\u4f8b\u6ca1\u6709\u89e3\u7801\u5668\uff0c\u56e0\u6b64\u7701\u7565\u91cd\u5efa\u9879\uff09\u3002\u5c06\u7801\u672c\u5229\u7528\u7387\u53ef\u89c6\u5316\u4e3a\u76f4\u65b9\u56fe\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# --- \u6700\u5c0f VQ \u5c42 ---\nkey = jax.random.PRNGKey(42)\nd = 8          # \u5d4c\u5165\u7ef4\u5ea6\nK = 64         # \u7801\u672c\u5927\u5c0f\nn_vectors = 256  # \u4e00\u6279\u7f16\u7801\u5668\u8f93\u51fa\n\n# \u968f\u673a\u7f16\u7801\u5668\u8f93\u51fa\u548c\u7801\u672c\nk1, k2 = jax.random.split(key)\nz_e = jax.random.normal(k1, (n_vectors, d))       # \u7f16\u7801\u5668\u8f93\u51fa\ncodebook = jax.random.normal(k2, (K, d)) * 0.1     # \u7801\u672c\uff08\u5c0f\u521d\u59cb\u5316\uff09\n\n# \u6700\u8fd1\u90bb\u67e5\u627e\uff1a\u4e3a\u6bcf\u4e2a z_e \u627e\u5230\u6700\u8fd1\u7684\u7801\u672c\u6761\u76ee\n# distances[i, k] = ||z_e[i] - codebook[k]||^2\ndistances = (\n    jnp.sum(z_e ** 2, axis=1, keepdims=True)\n    - 2 * z_e @ codebook.T\n    + jnp.sum(codebook ** 2, axis=1, keepdims=True).T\n)\nindices = jnp.argmin(distances, axis=1)       # \u8bcd\u5143\u7d22\u5f15\nz_q = codebook[indices]                        # \u91cf\u5316\u5411\u91cf\n\n# VQ-VAE \u635f\u5931\u9879\nbeta = 0.25\nloss_codebook = jnp.mean((jax.lax.stop_gradient(z_e) - z_q) ** 2)\nloss_commit   = jnp.mean((z_e - jax.lax.stop_gradient(z_q)) ** 2)\nloss_total    = loss_codebook + beta * loss_commit\nprint(f\"\u7801\u672c\u635f\u5931: {loss_codebook:.4f}, \u627f\u8bfa\u635f\u5931: {loss_commit:.4f}\")\n\n# \u7801\u672c\u5229\u7528\u7387\ncounts = jnp.bincount(indices, length=K)\nplt.figure(figsize=(10, 4))\nplt.bar(range(K), counts, color='#3498db', alpha=0.8)\nplt.xlabel('\u7801\u672c\u7d22\u5f15'); plt.ylabel('\u5206\u914d\u8ba1\u6570')\nplt.title(f'\u7801\u672c\u5229\u7528\u7387\uff08\u5df2\u4f7f\u7528 {jnp.sum(counts > 0)}/{K} \u4e2a\u6761\u76ee\uff09')\nplt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()\n# \u5c1d\u8bd5\uff1a\u5c06 K \u589e\u52a0\u5230 512 \u5e76\u89c2\u5bdf\u574d\u584c\u3002\u7136\u540e\u6dfb\u52a0\u7801\u672c\u91cd\u7f6e\u903b\u8f91\u3002\n

  2. \u6784\u5efa\u4e00\u4e2a\u73a9\u5177 2D \u5411\u91cf\u91cf\u5316\u5668\uff0c\u5b66\u4e60\u5bf9 2D \u5206\u5e03\u8fdb\u884c\u5212\u5206\u3002\u751f\u6210\u968f\u673a 2D \u70b9\uff0c\u901a\u8fc7 EMA \u66f4\u65b0\u5b66\u4e60\u7801\u672c\uff0c\u5e76\u5c06 Voronoi \u533a\u57df\u53ef\u89c6\u5316\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u4ece\u9ad8\u65af\u6df7\u5408\u751f\u6210 2D \u6570\u636e\nkey = jax.random.PRNGKey(0)\nn_points = 2000\nK = 16  # \u7801\u672c\u6761\u76ee\u6570\ngamma = 0.99  # EMA \u8870\u51cf\n\n# \u56db\u4e2a\u7c07\nkeys = jax.random.split(key, 5)\ncentres = jnp.array([[2, 2], [-2, 2], [-2, -2], [2, -2]], dtype=jnp.float32)\ndata = jnp.concatenate([\n    jax.random.normal(keys[i], (n_points // 4, 2)) * 0.5 + centres[i]\n    for i in range(4)\n])\n\n# \u4ece\u968f\u673a\u6570\u636e\u70b9\u521d\u59cb\u5316\u7801\u672c\nidx = jax.random.choice(keys[4], n_points, (K,), replace=False)\ncodebook = data[idx]\nema_count = jnp.ones(K)\nema_sum = codebook.copy()\n\n# \u8fd0\u884c\u591a\u4e2a epoch \u7684\u57fa\u4e8e EMA \u7684\u7801\u672c\u5b66\u4e60\nfor epoch in range(30):\n    # \u5c06\u6bcf\u4e2a\u70b9\u5206\u914d\u7ed9\u6700\u8fd1\u7684\u7801\u672c\u6761\u76ee\n    dists = jnp.sum((data[:, None, :] - codebook[None, :, :]) ** 2, axis=2)\n    assignments = jnp.argmin(dists, axis=1)\n    # EMA \u66f4\u65b0\n    for k in range(K):\n        mask = (assignments == k)\n        count_k = jnp.sum(mask)\n        ema_count = ema_count.at[k].set(gamma * ema_count[k] + (1 - gamma) * count_k)\n        if count_k > 0:\n            sum_k = jnp.sum(data[mask], axis=0)\n            ema_sum = ema_sum.at[k].set(gamma * ema_sum[k] + (1 - gamma) * sum_k)\n    codebook = ema_sum / ema_count[:, None]\n\n# \u53ef\u89c6\u5316\u5206\u914d\u548c\u7801\u672c\nfig, ax = plt.subplots(1, 1, figsize=(8, 8))\ncolors = plt.cm.tab20(jnp.linspace(0, 1, K))\nfor k in range(K):\n    mask = assignments == k\n    ax.scatter(data[mask, 0], data[mask, 1], c=[colors[k]], s=5, alpha=0.3)\nax.scatter(codebook[:, 0], codebook[:, 1], c='black', s=120, marker='X',\n           edgecolors='white', linewidths=1.5, zorder=10, label='\u7801\u672c')\nax.set_title(f'\u5728 2D \u6570\u636e\u4e0a\u5b66\u5f97\u7684 VQ \u7801\u672c\uff08{K} 
\u4e2a\u6761\u76ee\uff09')\nax.legend(); ax.set_aspect('equal'); ax.grid(True, alpha=0.3)\nplt.tight_layout(); plt.show()\n# \u5c1d\u8bd5\uff1a\u5c06 K \u589e\u52a0\u5230 64 \u5e76\u89c2\u5bdf\u66f4\u7cbe\u7ec6\u7684\u5212\u5206\u3002\u51cf\u5c0f gamma \u5e76\u89c2\u5bdf\u4e0d\u7a33\u5b9a\u6027\u3002\n

  3. \u6f14\u793a\u6b8b\u5dee\u91cf\u5316\uff1a\u7528 \\(T\\) \u4e2a\u8fde\u7eed\u7684\u91cf\u5316\u9636\u6bb5\u5bf9\u4e00\u6279\u5411\u91cf\u8fdb\u884c\u7f16\u7801\uff0c\u5e76\u6d4b\u91cf\u6bcf\u4e2a\u5c42\u7ea7\u91cd\u5efa\u8bef\u5dee\u7684\u4e0b\u964d\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nkey = jax.random.PRNGKey(7)\nd = 16         # \u5d4c\u5165\u7ef4\u5ea6\nK = 32         # \u6bcf\u4e2a\u5c42\u7ea7\u7684\u7801\u672c\u5927\u5c0f\nT = 8          # \u6b8b\u5dee\u5c42\u7ea7\u6570\nn_vectors = 512\n\n# \u5f85\u91cf\u5316\u7684\u968f\u673a\u6570\u636e\nk1, *cb_keys = jax.random.split(key, T + 1)\nz = jax.random.normal(k1, (n_vectors, d))\n\n# \u6bcf\u4e2a\u5c42\u7ea7\u7684\u72ec\u7acb\u968f\u673a\u7801\u672c\ncodebooks = [jax.random.normal(cb_keys[t], (K, d)) * (0.5 ** t)\n             for t in range(T)]\n\n# \u6b8b\u5dee\u91cf\u5316\u5faa\u73af\nresidual = z.copy()\nz_hat = jnp.zeros_like(z)\nerrors = []\n\nfor t in range(T):\n    cb = codebooks[t]\n    dists = (jnp.sum(residual ** 2, axis=1, keepdims=True)\n             - 2 * residual @ cb.T\n             + jnp.sum(cb ** 2, axis=1, keepdims=True).T)\n    indices = jnp.argmin(dists, axis=1)\n    z_q_t = cb[indices]\n    z_hat = z_hat + z_q_t\n    residual = residual - z_q_t\n    mse = jnp.mean(jnp.sum((z - z_hat) ** 2, axis=1))\n    errors.append(float(mse))\n    print(f\"\u5c42\u7ea7 {t+1}: MSE = {mse:.4f}\")\n\nplt.figure(figsize=(8, 5))\nplt.plot(range(1, T + 1), errors, 'o-', color='#e74c3c', linewidth=2, markersize=8)\nplt.xlabel('\u6b8b\u5dee\u91cf\u5316\u5c42\u7ea7')\nplt.ylabel('\u91cd\u5efa MSE')\nplt.title('\u6b8b\u5dee\u91cf\u5316\u7684\u8bef\u5dee\u964d\u4f4e')\nplt.xticks(range(1, T + 1)); plt.grid(True, alpha=0.3)\nplt.tight_layout(); plt.show()\n# \u5c1d\u8bd5\uff1a\u4f7f\u7528\u5927\u5c0f\u4e3a K*T \u7684\u5355\u4e2a\u7801\u672c\u5e76\u4e0e RQ \u6bd4\u8f83\u3002\u54ea\u4e2a\u66f4\u597d\uff1f\n

  4. \u6a21\u62df\u4e00\u4e2a\u7b80\u5355\u7684 1D\"\u89c6\u9891\u8bcd\u5143\u5316\u5668\"\uff1a\u751f\u6210\u4e00\u7cfb\u5217 1D \u4fe1\u53f7\uff08\u6a21\u62df\u89c6\u9891\u5e27\uff09\uff0c\u5e94\u7528\u56e0\u679c\u65f6\u95f4\u538b\u7f29\uff0c\u5e76\u4e0e\u65e0\u56e0\u679c\u538b\u7f29\u5728\u91cd\u5efa\u8d28\u91cf\u65b9\u9762\u8fdb\u884c\u6bd4\u8f83\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nkey = jax.random.PRNGKey(99)\nn_frames = 16\nframe_len = 64\n\n# \u751f\u6210\u4e00\u4e2a\"\u89c6\u9891\"\uff1a\u5728\u5e27\u95f4\u7f13\u6162\u79fb\u52a8\u7684\u9ad8\u65af\u51f8\u8d77\nx_axis = jnp.linspace(-3, 3, frame_len)\nframes = jnp.stack([\n    jnp.exp(-0.5 * (x_axis - (-2 + 4 * t / n_frames)) ** 2)\n    for t in range(n_frames)\n])  # \u5f62\u72b6: (n_frames, frame_len)\n\n# \u56e0\u679c\u65f6\u95f4\u538b\u7f29\uff1a\u6bcf\u5e27\u7684\u7f16\u7801\u4ec5\u4f9d\u8d56\u4e8e\u8fc7\u53bb\u7684\u5e27\n# \u7b80\u5355\u65b9\u6cd5\uff1a\u4f7f\u7528\u8fc7\u53bb\u5e27\u7684\u6307\u6570\u8870\u51cf\u5bf9\u5f53\u524d\u5e27\u8fdb\u884c\u5e73\u5747\nalpha_causal = 0.6\ncausal_codes = jnp.zeros_like(frames)\ncausal_codes = causal_codes.at[0].set(frames[0])\nfor t in range(1, n_frames):\n    causal_codes = causal_codes.at[t].set(\n        alpha_causal * frames[t] + (1 - alpha_causal) * causal_codes[t - 1]\n    )\n\n# \u65e0\u56e0\u679c\uff1a\u540c\u65f6\u5e73\u5747\u8fc7\u53bb\u548c\u672a\u6765\uff08\u53cc\u8fb9\u5e73\u6ed1\uff09\nkernel = jnp.array([0.2, 0.6, 0.2])  # \u8fc7\u53bb, \u5f53\u524d, \u672a\u6765\npadded = jnp.concatenate([frames[:1], frames, frames[-1:]], axis=0)\nnoncausal_codes = jnp.stack([\n    kernel[0] * padded[t] + kernel[1] * padded[t+1] + kernel[2] * padded[t+2]\n    for t in range(n_frames)\n])\n\n# \u91cd\u5efa\u8bef\u5dee\nmse_causal = jnp.mean((frames - causal_codes) ** 2)\nmse_noncausal = jnp.mean((frames - noncausal_codes) ** 2)\nprint(f\"\u56e0\u679c MSE: {mse_causal:.6f}, \u65e0\u56e0\u679c MSE: {mse_noncausal:.6f}\")\n\nfig, axes = plt.subplots(1, 3, figsize=(15, 5))\nfor ax, data, title in zip(axes,\n    [frames, causal_codes, noncausal_codes],\n    ['\u539f\u59cb\u5e27', f'\u56e0\u679c (MSE={mse_causal:.5f})',\n     f'\u65e0\u56e0\u679c (MSE={mse_noncausal:.5f})']):\n    ax.imshow(data, aspect='auto', cmap='viridis', origin='lower')\n    
ax.set_xlabel('\u7a7a\u95f4\u4f4d\u7f6e'); ax.set_ylabel('\u5e27\u7d22\u5f15')\n    ax.set_title(title)\nplt.tight_layout(); plt.show()\n# \u5c1d\u8bd5\uff1a\u6539\u53d8 alpha_causal \u548c\u6838\u6743\u91cd\u3002alpha=1.0 \u65f6\u4f1a\u53d1\u751f\u4ec0\u4e48\uff1f\n

"},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/","title":"\u8de8\u6a21\u6001\u751f\u6210 (Cross-Modal Generation)","text":"

\u8de8\u6a21\u6001\u751f\u6210\uff08cross-modal generation\uff09\u662f\u6307\u4ee5\u67d0\u4e00\u6a21\u6001\u7684\u8f93\u5165\u4e3a\u6761\u4ef6\uff0c\u751f\u6210\u53e6\u4e00\u6a21\u6001\u7684\u8f93\u51fa\u2014\u2014\u4ece\u6587\u751f\u56fe\u3001\u56fe\u751f\u6587\u3001\u6587\u751f\u97f3\u9891\uff0c\u4e43\u81f3\u66f4\u591a\u3002\u672c\u7ae0\u6db5\u76d6 DALL\u00b7E\u3001Stable Diffusion\u3001\u65e0\u5206\u7c7b\u5668\u5f15\u5bfc\u3001ControlNet\u3001\u56fe\u50cf\u63cf\u8ff0\u3001\u6587\u751f\u89c6\u9891\uff08Sora\uff09\u4ee5\u53ca\u6587\u751f\u97f3\u9891\u751f\u6210\u3002

"},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#text-to-image-generation","title":"\u6587\u751f\u56fe\u751f\u6210 (Text-to-Image Generation)","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#dalle","title":"DALL\u00b7E\uff1a\u81ea\u56de\u5f52\u56fe\u50cf\u751f\u6210","text":" \\[p(x_{\\text{text}}, x_{\\text{img}}) = \\prod_{i=1}^{1280} p(x_i \\mid x_1, \\ldots, x_{i-1})\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#stable-diffusion","title":"Stable Diffusion\uff1a\u5e26\u6587\u672c\u6761\u4ef6\u7684\u9690\u7a7a\u95f4\u6269\u6563","text":" \\[\\text{Attention}(Q, K, V) = \\text{softmax}\\left(\\frac{QK^T}{\\sqrt{d}}\\right)V\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#_1","title":"\u65e0\u5206\u7c7b\u5668\u5f15\u5bfc\u7684\u5b9e\u8df5\u5e94\u7528","text":" \\[\\hat{\\epsilon} = \\epsilon_\\theta(x_t, \\varnothing) + s \\cdot (\\epsilon_\\theta(x_t, c) - \\epsilon_\\theta(x_t, \\varnothing))\\] "},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#imagen","title":"Imagen\uff1a\u57fa\u4e8e\u8bed\u8a00\u7406\u89e3\u7684\u7ea7\u8054\u6269\u6563","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#parti","title":"Parti\uff1a\u5927\u89c4\u6a21\u81ea\u56de\u5f52","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#dit","title":"DiT \u4e0e\u57fa\u4e8e\u6d41\u5339\u914d\u7684\u751f\u6210","text":"

"},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#text-to-video-generation","title":"\u6587\u751f\u89c6\u9891\u751f\u6210 (Text-to-Video Generation)","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#_2","title":"\u65f6\u95f4\u7ef4\u5ea6\u7684\u6311\u6218","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#make-a-video","title":"Make-A-Video \u4e0e\u5ef6\u5c55\u81f3\u89c6\u9891\u65b9\u6cd5","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#videopoet-token","title":"VideoPoet \u4e0e\u57fa\u4e8e Token \u7684\u89c6\u9891\u6a21\u578b","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#sora","title":"Sora \u98ce\u683c\u7684\u65f6\u95f4\u6269\u6563","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#wan","title":"Wan\uff1a\u5f00\u6e90\u89c6\u9891\u751f\u6210","text":"

"},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#text-to-audio-generation","title":"\u6587\u751f\u97f3\u9891\u751f\u6210 (Text-to-Audio Generation)","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#audiolm","title":"AudioLM\uff1a\u97f3\u9891\u7684\u8bed\u8a00\u5efa\u6a21","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#musiclm","title":"MusicLM\uff1a\u6587\u672c\u6761\u4ef6\u97f3\u4e50\u751f\u6210","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#musicgen","title":"MusicGen\uff1a\u9ad8\u6548\u5355\u9636\u6bb5\u751f\u6210","text":" \\[p(a_1, \\ldots, a_T) = \\prod_{t=1}^{T} \\prod_{k=1}^{K} p(a_{t,k} \\mid a_{<t}, c_{\\text{text}})\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#image-to-text-generation","title":"\u56fe\u751f\u6587\u751f\u6210 (Image-to-Text Generation)","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#_3","title":"\u4f5c\u4e3a\u6761\u4ef6\u751f\u6210\u7684\u56fe\u50cf\u63cf\u8ff0","text":" \\[p(w_1, \\ldots, w_L \\mid I) = \\prod_{l=1}^{L} p(w_l \\mid w_1, \\ldots, w_{l-1}, I)\\] "},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#_4","title":"\u73b0\u4ee3\u89c6\u89c9\u8bed\u8a00\u63cf\u8ff0","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#-video-audio-co-generation","title":"\u89c6\u9891-\u97f3\u9891\u8054\u5408\u751f\u6210 (Video-Audio Co-Generation)","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#_5","title":"\u8054\u5408\u65f6\u95f4\u5efa\u6a21","text":" \\[\\mathcal{L}_{\\text{sync}} = -\\mathbb{E}_t \\left[\\log \\frac{\\exp(\\text{sim}(v_t, a_t) / \\tau)}{\\sum_{t'} \\exp(\\text{sim}(v_t, a_{t'}) / \\tau)}\\right]\\] "},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#instruction-following-generation","title":"\u6307\u4ee4\u9075\u5faa\u5f0f\u751f\u6210 (Instruction-Following Generation)","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#instructpix2pix","title":"InstructPix2Pix\uff1a\u901a\u8fc7\u63cf\u8ff0\u8fdb\u884c\u7f16\u8f91","text":" \\[\\hat{\\epsilon} = \\epsilon_\\theta(x_t, \\varnothing, \\varnothing) + s_I \\cdot (\\epsilon_\\theta(x_t, c_I, \\varnothing) - \\epsilon_\\theta(x_t, \\varnothing, \\varnothing)) + s_T \\cdot (\\epsilon_\\theta(x_t, c_I, c_T) - \\epsilon_\\theta(x_t, c_I, \\varnothing))\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#sdedit","title":"SDEdit \u4e0e\u57fa\u4e8e\u566a\u58f0\u7684\u7f16\u8f91","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#controlnet","title":"ControlNet\uff1a\u7a7a\u95f4\u6761\u4ef6\u63a7\u5236","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#consistency-and-alignment-metrics","title":"\u4e00\u81f4\u6027\u4e0e\u5bf9\u9f50\u6307\u6807 (Consistency and Alignment Metrics)","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#frechet-inception-distance-fid","title":"Frechet Inception Distance (FID)","text":" \\[\\text{FID} = \\|\\mu_r - \\mu_g\\|^2 + \\text{Tr}\\left(\\Sigma_r + \\Sigma_g - 2(\\Sigma_r \\Sigma_g)^{1/2}\\right)\\] "},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#inception-score-is","title":"Inception Score (IS)","text":" \\[\\text{IS} = \\exp\\left(\\mathbb{E}_x \\left[D_{\\text{KL}}(p(y \\mid x) \\| p(y))\\right]\\right)\\] "},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#clipscore-","title":"CLIPScore\uff1a\u8861\u91cf\u6587\u672c-\u56fe\u50cf\u5bf9\u9f50\u5ea6","text":" \\[\\text{CLIPScore}(I, T) = \\max(0, \\cos(E_I(I), E_T(T)))\\] \\[\\text{RefCLIPScore} = \\text{HarmonicMean}(\\text{CLIPScore}(I, T), \\max(0, \\cos(E_I(I), E_I(I_{\\text{ref}}))))\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#_6","title":"\u4eba\u5de5\u8bc4\u4f30","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#ethical-considerations","title":"\u4f26\u7406\u8003\u91cf (Ethical Considerations)","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#_7","title":"\u6df1\u5ea6\u4f2a\u9020\u4e0e\u865a\u5047\u4fe1\u606f","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#_8","title":"\u751f\u6210\u4e2d\u7684\u504f\u5dee","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#_9","title":"\u5185\u5bb9\u8fc7\u6ee4\u4e0e\u5b89\u5168\u6027","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#_10","title":"\u77e5\u8bc6\u4ea7\u6743\u4e0e\u77e5\u60c5\u540c\u610f","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/04.%20cross-modal%20generation/#colab-notebook","title":"\u7f16\u7a0b\u7ec3\u4e60\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"
  1. \u4e3a\u4e00\u4e2a\u73a9\u5177 2D \u6269\u6563\u6a21\u578b\u5b9e\u73b0\u65e0\u5206\u7c7b\u5668\u5f15\u5bfc\u3002\u5728 2D \u6570\u636e\u96c6\uff08\u4f8b\u5982\u6807\u6ce8\u7684\u805a\u7c7b\uff09\u4e0a\u8bad\u7ec3\u4e00\u4e2a\u6761\u4ef6\u6269\u6563\u6a21\u578b\uff0c\u7136\u540e\u4f7f\u7528\u4e0d\u540c\u7684\u5f15\u5bfc\u5c3a\u5ea6\u8fdb\u884c\u91c7\u6837\uff0c\u89c2\u5bdf\u8d28\u91cf\u4e0e\u591a\u6837\u6027\u7684\u6743\u8861\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# Toy 2D conditional diffusion with classifier-free guidance\ndef noise_schedule(T):\n    betas = jnp.linspace(1e-4, 0.02, T)\n    alphas = 1.0 - betas\n    return jnp.cumprod(alphas)\n\ndef forward_diffuse(x0, t, alpha_bars, key):\n    noise = jax.random.normal(key, x0.shape)\n    return jnp.sqrt(alpha_bars[t]) * x0 + jnp.sqrt(1 - alpha_bars[t]) * noise, noise\n\n# Generate labelled 2D data: class 0 = ring, class 1 = cluster\nkey = jax.random.PRNGKey(42)\nk1, k2, k3 = jax.random.split(key, 3)\ntheta = jax.random.uniform(k1, (200,)) * 2 * jnp.pi\nring = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=1) * 2\nring += jax.random.normal(k2, ring.shape) * 0.1\ncluster = jax.random.normal(k3, (200, 2)) * 0.3\n\ndata = jnp.concatenate([ring, cluster])\nlabels = jnp.concatenate([jnp.zeros(200), jnp.ones(200)])\n\n# Simulate CFG: show how guidance pushes samples toward class-conditional modes\n# Try varying guidance_scale from 0.0 to 5.0 and observe results\nguidance_scales = [0.0, 1.0, 3.0, 7.0]\nfig, axes = plt.subplots(1, 4, figsize=(16, 4))\nfor ax, s in zip(axes, guidance_scales):\n    ax.scatter(ring[:, 0], ring[:, 1], s=8, alpha=0.4, label='Ring (c=0)')\n    ax.scatter(cluster[:, 0], cluster[:, 1], s=8, alpha=0.4, label='Cluster (c=1)')\n    ax.set_title(f'Guidance scale s={s}')\n    ax.set_xlim(-4, 4); ax.set_ylim(-4, 4)\n    ax.set_aspect('equal'); ax.legend(fontsize=7)\nplt.suptitle('Experiment: vary guidance scale and observe quality vs diversity')\nplt.tight_layout(); plt.show()\n# Exercise: train a small MLP denoiser with class conditioning,\n# then implement the CFG formula to sample with different s values.\n

  2. \u4f7f\u7528\u5b8c\u6574\u7684 Frechet \u8ddd\u79bb\u516c\u5f0f\u8ba1\u7b97\u4e24\u7ec4 2D \u6837\u672c\u4e4b\u95f4\u7684 FID\u3002\u6539\u53d8\u751f\u6210\u5206\u5e03\uff0c\u89c2\u5bdf FID \u5982\u4f55\u53d8\u5316\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef compute_fid(real, generated):\n    \"\"\"Compute Frechet distance between two 2D sample sets.\"\"\"\n    mu_r, mu_g = jnp.mean(real, axis=0), jnp.mean(generated, axis=0)\n    sigma_r = jnp.cov(real.T)\n    sigma_g = jnp.cov(generated.T)\n    diff = mu_r - mu_g\n    # Matrix square root via eigendecomposition\n    product = sigma_r @ sigma_g\n    eigvals, eigvecs = jnp.linalg.eigh(product)\n    sqrt_product = eigvecs @ jnp.diag(jnp.sqrt(jnp.maximum(eigvals, 0))) @ eigvecs.T\n    fid = jnp.sum(diff ** 2) + jnp.trace(sigma_r + sigma_g - 2 * sqrt_product)\n    return fid\n\nkey = jax.random.PRNGKey(0)\nk1, k2, k3, k4 = jax.random.split(key, 4)\n\n# Real distribution: standard 2D Gaussian\nreal = jax.random.normal(k1, (1000, 2))\n\n# Generated distributions with increasing divergence\nshifts = [0.0, 0.5, 1.0, 2.0, 4.0]\nfig, axes = plt.subplots(1, len(shifts), figsize=(18, 3.5))\nfor ax, shift in zip(axes, shifts):\n    gen = jax.random.normal(k2, (1000, 2)) * (1 + shift * 0.2) + shift\n    fid = compute_fid(real, gen)\n    ax.scatter(real[:, 0], real[:, 1], s=3, alpha=0.3, label='Real')\n    ax.scatter(gen[:, 0], gen[:, 1], s=3, alpha=0.3, label='Generated')\n    ax.set_title(f'Shift={shift}\\nFID={fid:.2f}')\n    ax.set_xlim(-5, 8); ax.set_ylim(-5, 8)\n    ax.set_aspect('equal'); ax.legend(fontsize=7)\nplt.suptitle('FID increases as generated distribution diverges from real')\nplt.tight_layout(); plt.show()\n# Try: change the variance of generated samples without shifting the mean.\n# How does FID respond to a diversity mismatch vs a location mismatch?\n

  3. \u4f7f\u7528\u968f\u673a\u6295\u5f71\u4f5c\u4e3a CLIP \u7684\u66ff\u4ee3\uff0c\u5b9e\u73b0\u6587\u672c\u548c\u56fe\u50cf\u5d4c\u5165\u4e4b\u95f4\u7684 CLIPScore \u8ba1\u7b97\u3002\u89c2\u5bdf\u5f53\u4f60\u6539\u53d8\u6a21\u6001\u4e4b\u95f4\u7684\"\u5bf9\u9f50\u5ea6\"\u65f6\uff0c\u4f59\u5f26\u76f8\u4f3c\u5ea6\u5982\u4f55\u53d8\u5316\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef cosine_similarity(a, b):\n    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))\n\ndef clip_score(img_emb, txt_emb):\n    \"\"\"CLIPScore: clamped cosine similarity.\"\"\"\n    return jnp.maximum(0.0, cosine_similarity(img_emb, txt_emb))\n\nkey = jax.random.PRNGKey(42)\ndim = 512  # CLIP embedding dimension\n\n# Simulate aligned and misaligned pairs\n# Aligned: image and text embeddings share a component\nk1, k2, k3 = jax.random.split(key, 3)\nshared = jax.random.normal(k1, (dim,))\nshared = shared / jnp.linalg.norm(shared)\n\nnoise_levels = jnp.linspace(0, 5, 20)\nscores = []\nfor noise in noise_levels:\n    noise_vec = jax.random.normal(k2, (dim,)) * noise\n    img_emb = shared + noise_vec * 0.3\n    txt_emb = shared + jax.random.normal(k3, (dim,)) * noise * 0.3\n    scores.append(float(clip_score(img_emb, txt_emb)))\n\nplt.figure(figsize=(8, 4))\nplt.plot(noise_levels, scores, 'o-', color='#2c3e50')\nplt.xlabel('Noise level (misalignment)')\nplt.ylabel('CLIPScore')\nplt.title('CLIPScore decreases as text-image alignment degrades')\nplt.grid(True, alpha=0.3)\nplt.tight_layout(); plt.show()\n# Experiment: what happens if you normalise embeddings before adding noise?\n# How does dimensionality affect the score distribution?\n

"},{"location":"chapter%2010%3A%20multimodal%20learning/05.%20unified%20multimodal%20architectures/","title":"\u7edf\u4e00\u591a\u6a21\u6001\u67b6\u6784","text":"

\u7edf\u4e00\u591a\u6a21\u6001\u67b6\u6784\u7528\u5355\u4e00\u7cfb\u7edf\u53d6\u4ee3\u4e86\u5404\u81ea\u4e3a\u653f\u7684\u4e13\u5bb6\u6a21\u578b\uff0c\u8fd9\u4e2a\u7cfb\u7edf\u80fd\u591f\u8de8\u8d8a\u6587\u672c\u3001\u56fe\u50cf\u3001\u97f3\u9891\u548c\u89c6\u9891\u8fdb\u884c\u8bfb\u53d6\u3001\u63a8\u7406\u548c\u751f\u6210\u3002\u672c\u6587\u6db5\u76d6\u4e86\u4efb\u610f\u5230\u4efb\u610f\u6a21\u578b\uff08CoDi\u3001NExT-GPT\uff09\u3001\u539f\u751f\u591a\u6a21\u6001\u5927\u8bed\u8a00\u6a21\u578b\uff08Gemini\u3001GPT-4o\uff09\u3001\u591a\u6a21\u6001\u5206\u8bcd\u7b56\u7565\uff0c\u4ee5\u53ca\u7edf\u4e00\u5316\u6240\u5e26\u6765\u7684\u67b6\u6784\u6743\u8861\u3002

"},{"location":"chapter%2010%3A%20multimodal%20learning/05.%20unified%20multimodal%20architectures/#_2","title":"\u7edf\u4e00\u5316\u7684\u7406\u7531","text":" \\[f_\\theta : \\mathcal{P}(\\mathcal{M}) \\rightarrow \\mathcal{P}(\\mathcal{M})\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/05.%20unified%20multimodal%20architectures/#_3","title":"\u4efb\u610f\u5230\u4efb\u610f\u6a21\u578b","text":"

"},{"location":"chapter%2010%3A%20multimodal%20learning/05.%20unified%20multimodal%20architectures/#_4","title":"\u5171\u4eab\u4e3b\u5e72\u4e0a\u7684\u6a21\u6001\u7279\u5b9a\u7f16\u7801\u5668\u548c\u89e3\u7801\u5668","text":" \\[\\text{Attention}(\\mathbf{Q}, \\mathbf{K}, \\mathbf{V}) = \\text{softmax}\\left(\\frac{\\mathbf{Q}\\mathbf{K}^\\top}{\\sqrt{d_k}}\\right)\\mathbf{V}\\] \\[\\tilde{\\mathbf{h}}_i^m = \\mathbf{h}_i^m + \\mathbf{e}_m + \\mathbf{p}_i\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/05.%20unified%20multimodal%20architectures/#_5","title":"\u591a\u6a21\u6001\u5206\u8bcd","text":"
[TEXT] \u732b\u5750\u5728\u57ab\u5b50\u4e0a [/TEXT] [IMAGE] <img_tok_1> <img_tok_2> ... <img_tok_n> [/IMAGE] [AUDIO] <aud_tok_1> ... <aud_tok_m> [/AUDIO]\n

"},{"location":"chapter%2010%3A%20multimodal%20learning/05.%20unified%20multimodal%20architectures/#_6","title":"\u8bad\u7ec3\u914d\u65b9\uff1a\u5206\u9636\u6bb5\u9884\u8bad\u7ec3\u4e0e\u8054\u5408\u5fae\u8c03","text":" \\[\\mathcal{L} = -\\sum_{t=1}^{T} \\log p_\\theta(x_t \\mid x_{<t})\\]

"},{"location":"chapter%2010%3A%20multimodal%20learning/05.%20unified%20multimodal%20architectures/#_7","title":"\u591a\u6a21\u6001\u601d\u7ef4\u94fe\u63a8\u7406","text":" \\[p(y \\mid \\mathbf{x}) = \\sum_{\\mathbf{r}} p(y \\mid \\mathbf{r}, \\mathbf{x}) \\cdot p(\\mathbf{r} \\mid \\mathbf{x})\\] "},{"location":"chapter%2010%3A%20multimodal%20learning/05.%20unified%20multimodal%20architectures/#_8","title":"\u591a\u6a21\u6001\u667a\u80fd\u4f53","text":"

"},{"location":"chapter%2010%3A%20multimodal%20learning/05.%20unified%20multimodal%20architectures/#_9","title":"\u57fa\u51c6\u6d4b\u8bd5\u4e0e\u8bc4\u4f30","text":""},{"location":"chapter%2010%3A%20multimodal%20learning/05.%20unified%20multimodal%20architectures/#_10","title":"\u4e16\u754c\u6a21\u578b","text":" \\[\\hat{s}_{t+1} = g_\\phi(s_t, a_t)\\] \\[\\mathcal{L}_\\text{world} = \\mathbb{E}\\left[\\sum_{m \\in \\mathcal{M}} \\lambda_m \\| s_{t+1}^m - g_\\phi^m(s_t, a_t) \\|^2 \\right]\\]

\\[\\hat{\\mathbf{z}}_{t+1} = h_\\psi(\\mathbf{z}_t, a_t), \\quad \\mathbf{z}_t = \\text{Enc}(s_t)\\] "},{"location":"chapter%2010%3A%20multimodal%20learning/05.%20unified%20multimodal%20architectures/#colab-notebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528 CoLab \u6216 notebook\uff09","text":"

\u4efb\u52a1 1\uff1a\u6784\u5efa\u4e00\u4e2a\u6781\u7b80\u7684\u591a\u6a21\u6001 token \u4ea4\u9519\u5668

import jax\nimport jax.numpy as jnp\n\n# \u6a21\u62df\u591a\u6a21\u6001\u5206\u8bcd\uff1a\u6587\u672c token + \"\u56fe\u50cf\u5757\" token\ndef interleave_modalities(text_tokens, image_patches, embed_dim=32, key=jax.random.PRNGKey(0)):\n    \"\"\"\u5c06\u6587\u672c\u548c\u56fe\u50cf token \u4e0e\u5b66\u4e60\u5230\u7684\u6a21\u6001\u5d4c\u5165\u4ea4\u9519\u3002\"\"\"\n    k1, k2, k3 = jax.random.split(key, 3)\n    n_text = text_tokens.shape[0]\n    n_img = image_patches.shape[0]\n    # \u968f\u673a\u6295\u5f71\u77e9\u9635\uff08\u66ff\u4ee3\u771f\u5b9e\u7f16\u7801\u5668\uff09\n    W_text = jax.random.normal(k1, (text_tokens.shape[-1], embed_dim)) * 0.02\n    W_img = jax.random.normal(k2, (image_patches.shape[-1], embed_dim)) * 0.02\n    # \u6a21\u6001\u5d4c\u5165\uff1a\u4e00\u4e2a\u7528\u4e8e\u6587\u672c\uff0c\u4e00\u4e2a\u7528\u4e8e\u56fe\u50cf\n    mod_emb = jax.random.normal(k3, (2, embed_dim)) * 0.02\n    text_embs = text_tokens @ W_text + mod_emb[0]  # (n_text, embed_dim)\n    img_embs = image_patches @ W_img + mod_emb[1]   # (n_img, embed_dim)\n    # \u4ea4\u9519\uff1a[IMG] token \u5728\u524d\uff0c\u7136\u540e\u662f [TEXT] token\uff08\u50cf LLaVA\uff09\n    combined = jnp.concatenate([img_embs, text_embs], axis=0)\n    print(f\"\u7ec4\u5408\u5e8f\u5217: {n_img} \u56fe\u50cf + {n_text} \u6587\u672c = {combined.shape[0]} tokens\")\n    return combined\n\n# \u5c1d\u8bd5\uff1a5 \u4e2a\u6587\u672c token\uff08dim 16\uff09\u548c 4 \u4e2a\u56fe\u50cf\u5757\uff08dim 64\uff09\ntext = jax.random.normal(jax.random.PRNGKey(1), (5, 16))\nimage = jax.random.normal(jax.random.PRNGKey(2), (4, 64))\nseq = interleave_modalities(text, image)\n# \u5b9e\u9a8c\uff1a\u6539\u53d8 embed_dim\uff0c\u4ea4\u6362\u4ea4\u9519\u987a\u5e8f\uff0c\u6dfb\u52a0\u7b2c\u4e09\u4e2a\u6a21\u6001\n

\u4efb\u52a1 2\uff1a\u53ef\u89c6\u5316\u8de8\u6a21\u6001\u6ce8\u610f\u529b\u6a21\u5f0f

import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef cross_modal_attention(n_text=6, n_img=4, d=32, key=jax.random.PRNGKey(42)):\n    \"\"\"\u8ba1\u7b97\u5e76\u53ef\u89c6\u5316\u6587\u672c\u548c\u56fe\u50cf token \u4e4b\u95f4\u7684\u6ce8\u610f\u529b\u3002\"\"\"\n    k1, k2, k3 = jax.random.split(key, 3)\n    # \u6a21\u62df\u4e24\u79cd\u6a21\u6001\u7684 token \u5d4c\u5165\n    text_embs = jax.random.normal(k1, (n_text, d))\n    img_embs = jax.random.normal(k2, (n_img, d))\n    seq = jnp.concatenate([img_embs, text_embs], axis=0)  # (n_img+n_text, d)\n    # \u5b66\u4e60\u5230\u7684 Q, K \u6295\u5f71\n    Wq = jax.random.normal(k3, (d, d)) * 0.1\n    Wk = jax.random.normal(jax.random.PRNGKey(99), (d, d)) * 0.1\n    Q, K = seq @ Wq, seq @ Wk\n    scores = Q @ K.T / jnp.sqrt(d)\n    attn = jax.nn.softmax(scores, axis=-1)\n    # \u7ed8\u56fe\n    labels = [f\"img_{i}\" for i in range(n_img)] + [f\"txt_{i}\" for i in range(n_text)]\n    fig, ax = plt.subplots(figsize=(7, 6))\n    ax.imshow(attn, cmap=\"viridis\")\n    ax.set_xticks(range(len(labels))); ax.set_xticklabels(labels, rotation=45, fontsize=8)\n    ax.set_yticks(range(len(labels))); ax.set_yticklabels(labels, fontsize=8)\n    ax.set_xlabel(\"Key\uff08\u88ab\u5173\u6ce8\u7684\uff09\"); ax.set_ylabel(\"Query\uff08\u53d1\u8d77\u7684\uff09\")\n    ax.set_title(\"\u8de8\u6a21\u6001\u81ea\u6ce8\u610f\u529b\u56fe\")\n    plt.colorbar(ax.images[0], ax=ax, shrink=0.8)\n    plt.tight_layout(); plt.show()\n\ncross_modal_attention()\n# \u5b9e\u9a8c\uff1a\u589e\u5927 d\uff0c\u6dfb\u52a0\u56e0\u679c\u63a9\u7801\uff0c\u89c2\u5bdf\u6ce8\u610f\u529b\u6a21\u5f0f\u5982\u4f55\u53d8\u5316\n

\u4efb\u52a1 3\uff1a\u6a21\u62df\u5e26\u6709\u6a21\u6001\u7279\u5b9a\u635f\u5931\u6743\u91cd\u7684\u5206\u9636\u6bb5\u8bad\u7ec3

import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef staged_training_sim(steps=200, key=jax.random.PRNGKey(7)):\n    \"\"\"\u6a21\u62df\u5177\u6709\u53ef\u8c03\u8282\u6a21\u6001\u635f\u5931\u6743\u91cd\u7684\u591a\u6a21\u6001\u8bad\u7ec3\u3002\"\"\"\n    # \u4e24\u79cd\"\u6a21\u6001\"\uff0c\u635f\u5931\u5c3a\u5ea6\u4e0d\u540c\uff08\u6587\u672c\u635f\u5931\u6bd4\u56fe\u50cf\u635f\u5931\u5927\u7ea6 10 \u500d\uff09\n    losses_text, losses_img = [], []\n    param = jnp.array([0.0, 0.0])  # \u4e24\u79cd\u6a21\u6001\u635f\u5931\u5171\u540c\u66f4\u65b0\u7684\u5171\u4eab\u53c2\u6570\n    lr = 0.05\n    # \u5c1d\u8bd5\u66f4\u6539\u8fd9\u4e9b\u6743\u91cd\u4ee5\u89c2\u5bdf\u5bf9\u6536\u655b\u5e73\u8861\u7684\u5f71\u54cd\n    lambda_text, lambda_img = 1.0, 5.0  # \u5bf9\u8f83\u5f31\u6a21\u6001\u52a0\u5927\u6743\u91cd\n\n    for step in range(steps):\n        k1, k2, key = jax.random.split(key, 3)\n        noise_t = jax.random.normal(k1, ()) * 0.3\n        noise_i = jax.random.normal(k2, ()) * 0.1\n        loss_t = (param[0] - 3.0) ** 2 + noise_t  # \u6587\u672c\u76ee\u6807 = 3.0\n        loss_i = 0.1 * (param[1] - 1.0) ** 2 + noise_i  # \u56fe\u50cf\u76ee\u6807 = 1.0\uff08\u5c3a\u5ea6\u66f4\u5c0f\uff09\n        # \u52a0\u6743\u7ec4\u5408\u68af\u5ea6\n        grad_t = lambda_text * 2 * (param[0] - 3.0)\n        grad_i = lambda_img * 0.2 * (param[1] - 1.0)\n        param = param - lr * jnp.array([grad_t, grad_i])\n        losses_text.append(float(loss_t)); losses_img.append(float(loss_i))\n\n    fig, ax = plt.subplots(figsize=(8, 4))\n    ax.plot(losses_text, label=f\"\u6587\u672c\u635f\u5931 (\u6743\u91cd={lambda_text})\", alpha=0.7)\n    ax.plot(losses_img, label=f\"\u56fe\u50cf\u635f\u5931 (\u6743\u91cd={lambda_img})\", alpha=0.7)\n    ax.set_xlabel(\"\u8bad\u7ec3\u6b65\u6570\"); ax.set_ylabel(\"\u635f\u5931\"); ax.legend()\n    ax.set_title(\"\u5206\u9636\u6bb5\u8bad\u7ec3\u4e2d\u7684\u6a21\u6001\u635f\u5931\u5e73\u8861\")\n    plt.tight_layout(); 
plt.show()\n\nstaged_training_sim()\n# \u5b9e\u9a8c\uff1a\u8bbe\u7f6e lambda_img=1.0\uff0c\u89c2\u5bdf\u56fe\u50cf\u635f\u5931\u6536\u655b\u6162\u5f97\u591a\n
"},{"location":"chapter%2011%3A%20autonomous%20systems/01.%20perception/","title":"\u611f\u77e5","text":"

\u611f\u77e5\u662f\u81ea\u4e3b\u7cfb\u7edf\u611f\u77e5\u548c\u89e3\u91ca\u7269\u7406\u4e16\u754c\u7684\u65b9\u5f0f\u3002\u672c\u7ae0\u6db5\u76d6\u4f20\u611f\u5668\u6a21\u6001\u3001\u6807\u5b9a\u3001\u4f20\u611f\u5668\u878d\u5408\u30013D\u76ee\u6807\u68c0\u6d4b\u3001\u6df1\u5ea6\u4f30\u8ba1\u3001\u5360\u636e\u7f51\u7edc\u3001\u8f66\u9053\u68c0\u6d4b\u548c\u8bed\u4e49\u5efa\u56fe\u2014\u2014\u8fd9\u662f\u6bcf\u4e2a\u673a\u5668\u4eba\u3001\u65e0\u4eba\u673a\u548c\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66\u8d56\u4ee5\u6784\u5efa\u7684\u611f\u77e5\u57fa\u7840\u3002

"},{"location":"chapter%2011%3A%20autonomous%20systems/01.%20perception/#_2","title":"\u4f20\u611f\u5668\u6a21\u6001","text":" \\[\\begin{bmatrix} u \\\\ v \\\\ 1 \\end{bmatrix} = \\frac{1}{Z} K \\begin{bmatrix} X \\\\ Y \\\\ Z \\end{bmatrix}\\] \\[d = \\frac{c \\cdot \\Delta t}{2}\\]

\\[v = \\frac{\\Delta f \\cdot c}{2 f_0}\\] "},{"location":"chapter%2011%3A%20autonomous%20systems/01.%20perception/#_3","title":"\u4f20\u611f\u5668\u6807\u5b9a","text":" \\[\\mathbf{p}_{\\text{\u76f8\u673a}} = \\begin{bmatrix} R & \\mathbf{t} \\\\ \\mathbf{0}^T & 1 \\end{bmatrix} \\mathbf{p}_{\\text{\u6fc0\u5149\u96f7\u8fbe}}\\] "},{"location":"chapter%2011%3A%20autonomous%20systems/01.%20perception/#_4","title":"\u4f20\u611f\u5668\u878d\u5408","text":"

"},{"location":"chapter%2011%3A%20autonomous%20systems/01.%20perception/#3d","title":"3D\u76ee\u6807\u68c0\u6d4b","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/01.%20perception/#_5","title":"\u6df1\u5ea6\u4f30\u8ba1","text":" \\[Z = \\frac{f \\cdot b}{d}\\] "},{"location":"chapter%2011%3A%20autonomous%20systems/01.%20perception/#_6","title":"\u5360\u636e\u7f51\u7edc","text":" "},{"location":"chapter%2011%3A%20autonomous%20systems/01.%20perception/#_7","title":"\u8f66\u9053\u68c0\u6d4b\u4e0e\u9053\u8def\u62d3\u6251","text":" \\[x(y) = a_0 + a_1 y + a_2 y^2 + a_3 y^3\\] "},{"location":"chapter%2011%3A%20autonomous%20systems/01.%20perception/#_8","title":"\u8bed\u4e49\u5efa\u56fe","text":" \\[P(\\text{\u5360\u636e} \\mid z_{1:t}) = \\frac{P(z_t \\mid \\text{\u5360\u636e}) \\cdot P(\\text{\u5360\u636e} \\mid z_{1:t-1})}{P(z_t)}\\] \\[l_t = l_{t-1} + \\log \\frac{P(z_t \\mid \\text{\u5360\u636e})}{P(z_t \\mid \\text{\u7a7a\u95f2})}\\] "},{"location":"chapter%2011%3A%20autonomous%20systems/01.%20perception/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u4f7f\u7528\u6295\u5f71\u77e9\u9635\u5c063D LiDAR\u70b9\u6295\u5f71\u52302D\u76f8\u673a\u56fe\u50cf\u4e0a\u3002\u53ef\u89c6\u5316\u54ea\u4e9b\u70b9\u843d\u5728\u56fe\u50cf\u8fb9\u754c\u5185\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u6a21\u62df3D LiDAR\u70b9\uff08x=\u5411\u524d\uff0cy=\u5411\u5de6\uff0cz=\u5411\u4e0a\uff09\nrng = jax.random.PRNGKey(0)\npoints_3d = jax.random.uniform(rng, (200, 3), minval=jnp.array([5, -10, -2]),\n                                maxval=jnp.array([50, 10, 3]))\n\n# \u76f8\u673a\u5185\u53c2\u77e9\u9635\uff08\u7126\u8ddd500\uff0c\u56fe\u50cf\u4e2d\u5fc3320x240\uff09\nK = jnp.array([[500, 0, 320],\n               [0, 500, 240],\n               [0,   0,   1.0]])\n\n# \u5916\u53c2\uff1aLiDAR\u5230\u76f8\u673a\uff08\u5355\u4f4d\u65cb\u8f6c\uff0c\u5c0f\u5e73\u79fb\uff09\nR = jnp.eye(3)\nt = jnp.array([0.0, 0.0, -0.5])\n\n# \u6295\u5f71\uff1ap_cam = K @ (R @ p_lidar + t)\np_cam = (R @ points_3d.T).T + t\np_img = (K @ p_cam.T).T\np_img = p_img[:, :2] / p_img[:, 2:3]  # \u9664\u4ee5Z\n\n# \u8fc7\u6ee4\u76f8\u673a\u524d\u65b9\u4e14\u5728\u56fe\u50cf\u5185\u7684\u70b9\nmask = (p_cam[:, 2] > 0) & (p_img[:, 0] > 0) & (p_img[:, 0] < 640) & \\\n       (p_img[:, 1] > 0) & (p_img[:, 1] < 480)\ndepth = p_cam[mask, 2]\n\nplt.figure(figsize=(8, 5))\nplt.scatter(p_img[mask, 0], p_img[mask, 1], c=depth, cmap=\"viridis\", s=5)\nplt.colorbar(label=\"\u6df1\u5ea6 (\u7c73)\")\nplt.xlim(0, 640); plt.ylim(480, 0)\nplt.title(\"\u6295\u5f71\u5230\u76f8\u673a\u56fe\u50cf\u4e0a\u7684LiDAR\u70b9\")\nplt.xlabel(\"u (\u50cf\u7d20)\"); plt.ylabel(\"v (\u50cf\u7d20)\")\nplt.show()\n

  2. \u4f7f\u7528\u8d1d\u53f6\u65af\u5bf9\u6570\u51e0\u7387\u66f4\u65b0\u6784\u5efa\u4e00\u4e2a\u7b80\u5355\u76842D\u5360\u636e\u7f51\u683c\u3002\u6a21\u62df\u4e00\u4e2a\u8ddd\u79bb\u4f20\u611f\u5668\u626b\u63cf\u73af\u5883\uff0c\u89c2\u5bdf\u5730\u56fe\u7684\u751f\u6210\u8fc7\u7a0b\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u7f51\u683c\u8bbe\u7f6e\uff1a50x50\u4e2a\u5355\u5143\uff0c\u6bcf\u4e2a0.2\u7c73\ngrid_size = 50\nlog_odds = jnp.zeros((grid_size, grid_size))\n\n# \u4f20\u611f\u5668\u6a21\u578b\uff1a\u5bf9\u6570\u51e0\u7387\u66f4\u65b0\u503c\nl_occ = 0.85   # \u547d\u4e2d\u610f\u5473\u7740\u5360\u636e\u7684\u7f6e\u4fe1\u5ea6\nl_free = -0.4  # \u7a7f\u8fc7\u610f\u5473\u7740\u7a7a\u95f2\u7684\u7f6e\u4fe1\u5ea6\n\n# \u6a21\u62df\u969c\u788d\u7269\uff1a\u4ece(5,20)\u5230(5,30)\u7684\u5899\uff08\u7f51\u683c\u5750\u6807\uff09\nwall_y = jnp.arange(20, 30)\n\n# \u673a\u5668\u4eba\u5728(25, 25)\uff0c\u5411\u5916\u626b\u63cf\nrobot = jnp.array([25, 25])\n\nfor angle_deg in range(0, 360, 5):\n    angle = jnp.radians(angle_deg)\n    direction = jnp.array([jnp.cos(angle), jnp.sin(angle)])\n\n    for step in range(1, 25):\n        cell = (robot + direction * step).astype(int)\n        r, c = int(cell[0]), int(cell[1])\n        if r < 0 or r >= grid_size or c < 0 or c >= grid_size:\n            break\n\n        # \u68c0\u67e5\u6b64\u5355\u5143\u662f\u5426\u4e3a\u5899\n        is_wall = (r == 5) and (c >= 20) and (c < 30)\n        if is_wall:\n            log_odds = log_odds.at[r, c].add(l_occ)\n            break\n        else:\n            log_odds = log_odds.at[r, c].add(l_free)\n\n# \u5c06\u5bf9\u6570\u51e0\u7387\u8f6c\u6362\u4e3a\u6982\u7387\nprob = 1.0 / (1.0 + jnp.exp(-log_odds))\n\nplt.figure(figsize=(6, 6))\nplt.imshow(prob.T, origin=\"lower\", cmap=\"RdYlGn_r\", vmin=0, vmax=1)\nplt.colorbar(label=\"P(\u88ab\u5360\u636e)\")\nplt.plot(25, 25, \"b*\", markersize=10, label=\"\u673a\u5668\u4eba\")\nplt.legend()\nplt.title(\"\u8d1d\u53f6\u65af\u66f4\u65b0\u751f\u6210\u76842D\u5360\u636e\u7f51\u683c\")\nplt.show()\n

  3. \u4f7f\u7528\u89c6\u5dee\u4ece\u7acb\u4f53\u56fe\u50cf\u5bf9\u8ba1\u7b97\u6df1\u5ea6\u3002\u6a21\u62df\u4e24\u4e2a\u76f8\u673a\u89c6\u89d2\u4e0b\u76843D\u70b9\uff0c\u8ba1\u7b97\u89c6\u5dee\u5e76\u6062\u590d\u6df1\u5ea6\u3002

    import jax\nimport jax.numpy as jnp\n\n# \u76f8\u673a\u53c2\u6570\nf = 500.0     # \u7126\u8ddd\uff08\u50cf\u7d20\uff09\nb = 0.12      # \u57fa\u7ebf\uff08\u7c73\uff0c12\u5398\u7c73\uff09\n\n# \u5df2\u77e5\u6df1\u5ea6\u76843D\u70b9\ndepths_true = jnp.array([5.0, 10.0, 20.0, 50.0, 100.0])\n\n# \u89c6\u5dee = f * b / Z\ndisparities = f * b / depths_true\n\n# \u4ece\u89c6\u5dee\u6062\u590d\u6df1\u5ea6\ndepths_recovered = f * b / disparities\n\nfor z, d, z_r in zip(depths_true, disparities, depths_recovered):\n    print(f\"\u771f\u5b9e\u6df1\u5ea6: {z:6.1f}\u7c73  \u89c6\u5dee: {d:6.2f}\u50cf\u7d20  \u6062\u590d\u503c: {z_r:6.1f}\u7c73\")\n\n# \u6ce8\u610f\uff1a\u89c6\u5dee\u4e0e\u6df1\u5ea6\u6210\u53cd\u6bd4\n# \u8fd1\u5904\u7269\u4f53\u89c6\u5dee\u5927\uff0c\u8fdc\u5904\u7269\u4f53\u89c6\u5dee\u5c0f\n# \u8fd9\u5c31\u662f\u4e3a\u4ec0\u4e48\u7acb\u4f53\u89c6\u89c9\u5728\u8fd1\u8ddd\u79bb\u6700\u51c6\u786e\n

"},{"location":"chapter%2011%3A%20autonomous%20systems/02.%20robot%20learning/","title":"\u673a\u5668\u4eba\u5b66\u4e60","text":"

\u673a\u5668\u4eba\u5b66\u4e60\u5f25\u5408\u4e86\u7b97\u6cd5\u4e0e\u7269\u7406\u884c\u52a8\u4e4b\u95f4\u7684\u9e3f\u6c9f\u3002\u672c\u7ae0\u6db5\u76d6\u8fd0\u52a8\u5b66\u3001\u52a8\u529b\u5b66\u3001\u7ecf\u5178\u63a7\u5236\u3001\u6a21\u4eff\u5b66\u4e60\u3001\u4eff\u771f\u5230\u73b0\u5b9e\u8fc1\u79fb\u3001\u64cd\u4f5c\u3001\u79fb\u52a8\u548c\u5b89\u5168\u2014\u2014\u8fd9\u4e9b\u6280\u672f\u8d4b\u4e88\u673a\u5668\u4eba\u5728\u73b0\u5b9e\u4e16\u754c\u4e2d\u79fb\u52a8\u3001\u6293\u53d6\u3001\u884c\u8d70\u548c\u4ea4\u4e92\u7684\u80fd\u529b\u3002

"},{"location":"chapter%2011%3A%20autonomous%20systems/02.%20robot%20learning/#_2","title":"\u673a\u5668\u4eba\u8fd0\u52a8\u5b66","text":"

\\[T_i = \\begin{bmatrix} \\cos\\theta_i & -\\sin\\theta_i \\cos\\alpha_i & \\sin\\theta_i \\sin\\alpha_i & a_i \\cos\\theta_i \\\\ \\sin\\theta_i & \\cos\\theta_i \\cos\\alpha_i & -\\cos\\theta_i \\sin\\alpha_i & a_i \\sin\\theta_i \\\\ 0 & \\sin\\alpha_i & \\cos\\alpha_i & d_i \\\\ 0 & 0 & 0 & 1 \\end{bmatrix}\\] \\[\\dot{\\mathbf{x}} = J(\\mathbf{q}) \\dot{\\mathbf{q}}\\] \\[\\Delta \\mathbf{q} = J^T(JJ^T + \\lambda^2 I)^{-1} \\Delta \\mathbf{x}\\]"},{"location":"chapter%2011%3A%20autonomous%20systems/02.%20robot%20learning/#_3","title":"\u52a8\u529b\u5b66\u4e0e\u63a7\u5236","text":" \\[M(\\mathbf{q})\\ddot{\\mathbf{q}} + C(\\mathbf{q}, \\dot{\\mathbf{q}})\\dot{\\mathbf{q}} + \\mathbf{g}(\\mathbf{q}) = \\boldsymbol{\\tau}\\] \\[\\tau(t) = K_p e(t) + K_i \\int_0^t e(s) \\, ds + K_d \\dot{e}(t)\\]

\\[\\min_{\\mathbf{u}_{0:T}} \\sum_{t=0}^{T} \\left[ \\|\\mathbf{x}_t - \\mathbf{x}_t^*\\|_Q^2 + \\|\\mathbf{u}_t\\|_R^2 \\right] \\quad \\text{subject to} \\quad \\mathbf{x}_{t+1} = f(\\mathbf{x}_t, \\mathbf{u}_t)\\] \\[F = K_s(\\mathbf{x}^* - \\mathbf{x}) + D(\\dot{\\mathbf{x}}^* - \\dot{\\mathbf{x}})\\] "},{"location":"chapter%2011%3A%20autonomous%20systems/02.%20robot%20learning/#_4","title":"\u6a21\u4eff\u5b66\u4e60","text":" \\[\\mathcal{L}(\\theta) = \\mathbb{E}_{(\\mathbf{o}, \\mathbf{a}) \\sim \\mathcal{D}} \\left[ \\| \\pi_\\theta(\\mathbf{o}) - \\mathbf{a} \\|^2 \\right]\\]

"},{"location":"chapter%2011%3A%20autonomous%20systems/02.%20robot%20learning/#_5","title":"\u4eff\u771f\u5230\u73b0\u5b9e\u8fc1\u79fb","text":"

"},{"location":"chapter%2011%3A%20autonomous%20systems/02.%20robot%20learning/#_6","title":"\u673a\u5668\u4eba\u4e16\u754c\u6a21\u578b","text":" \\[\\hat{s}_{t+1} = f_\\theta(s_t, a_t), \\quad \\hat{r}_t = g_\\theta(s_t)\\] "},{"location":"chapter%2011%3A%20autonomous%20systems/02.%20robot%20learning/#_7","title":"\u64cd\u4f5c","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/02.%20robot%20learning/#_8","title":"\u79fb\u52a8","text":" \\[\\dot{\\phi}_i = \\omega_i + \\sum_j w_{ij} \\sin(\\phi_j - \\phi_i - \\psi_{ij})\\] "},{"location":"chapter%2011%3A%20autonomous%20systems/02.%20robot%20learning/#_9","title":"\u673a\u5668\u4eba\u5b66\u4e60\u4e2d\u7684\u5b89\u5168\u6027","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/02.%20robot%20learning/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u5b9e\u73b0\u4e00\u4e2a\u7b80\u53552\u8fde\u6746\u5e73\u9762\u673a\u5668\u4eba\u624b\u81c2\u7684\u6b63\u5411\u8fd0\u52a8\u5b66\u3002\u8ba1\u7b97\u5e76\u53ef\u89c6\u5316\u4e0d\u540c\u5173\u8282\u89d2\u5ea6\u4e0b\u7684\u672b\u7aef\u6267\u884c\u5668\u4f4d\u7f6e\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ndef forward_kinematics(q1, q2, l1=1.0, l2=0.8):\n    \"\"\"\u8ba1\u7b972\u8fde\u6746\u624b\u81c2\u7684\u5173\u8282\u548c\u672b\u7aef\u6267\u884c\u5668\u4f4d\u7f6e\u3002\"\"\"\n    x1 = l1 * jnp.cos(q1)\n    y1 = l1 * jnp.sin(q1)\n    x2 = x1 + l2 * jnp.cos(q1 + q2)\n    y2 = y1 + l2 * jnp.sin(q1 + q2)\n    return jnp.array([0, x1, x2]), jnp.array([0, y1, y2])\n\nfig, ax = plt.subplots(figsize=(6, 6))\nconfigs = [(0.5, 0.3), (1.0, -0.5), (1.5, 1.0), (2.0, -1.5)]\ncolors = [\"#e74c3c\", \"#3498db\", \"#27ae60\", \"#9b59b6\"]\n\nfor (q1, q2), c in zip(configs, colors):\n    xs, ys = forward_kinematics(q1, q2)\n    ax.plot(xs, ys, \"o-\", color=c, linewidth=2, markersize=6,\n            label=f\"q=({q1:.1f}, {q2:.1f})\")\n\nax.set_xlim(-2, 2); ax.set_ylim(-2, 2)\nax.set_aspect(\"equal\"); ax.grid(True); ax.legend()\nax.set_title(\"2\u8fde\u6746\u673a\u5668\u4eba\u624b\u81c2\uff1a\u6b63\u5411\u8fd0\u52a8\u5b66\")\nplt.show()\n

  2. \u4f7f\u7528\u96c5\u53ef\u6bd4\u4f2a\u9006\u5b9e\u73b0\u9006\u5411\u8fd0\u52a8\u5b66\u3002\u4ece\u968f\u673a\u6784\u578b\u5f00\u59cb\uff0c\u8fed\u4ee3\u5730\u5c06\u672b\u7aef\u6267\u884c\u5668\u79fb\u52a8\u5230\u76ee\u6807\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nl1, l2 = 1.0, 0.8\n\ndef end_effector(q):\n    x = l1 * jnp.cos(q[0]) + l2 * jnp.cos(q[0] + q[1])\n    y = l1 * jnp.sin(q[0]) + l2 * jnp.sin(q[0] + q[1])\n    return jnp.array([x, y])\n\njacobian_fn = jax.jacobian(end_effector)\n\ntarget = jnp.array([0.5, 1.2])\nq = jnp.array([0.1, 0.1])\ntrajectory = [end_effector(q)]\n\nfor _ in range(50):\n    pos = end_effector(q)\n    error = target - pos\n    if jnp.linalg.norm(error) < 1e-4:\n        break\n    J = jacobian_fn(q)\n    # \u963b\u5c3c\u4f2a\u9006\u5904\u7406\u63a5\u8fd1\u5947\u5f02\u70b9\u7684\u60c5\u51b5\n    dq = J.T @ jnp.linalg.solve(J @ J.T + 0.01 * jnp.eye(2), error)\n    q = q + dq\n    trajectory.append(end_effector(q))\n\ntraj = jnp.stack(trajectory)\nplt.plot(traj[:, 0], traj[:, 1], \"b.-\", label=\"\u672b\u7aef\u6267\u884c\u5668\u8def\u5f84\")\nplt.plot(*target, \"r*\", markersize=15, label=\"\u76ee\u6807\u70b9\")\nplt.gca().set_aspect(\"equal\"); plt.grid(True); plt.legend()\nplt.title(f\"IK\u5728{len(trajectory)-1}\u6b65\u5185\u6536\u655b\")\nplt.show()\n

  3. \u6a21\u62df\u4e00\u4e2a\u7b80\u5355\u7684PID\u63a7\u5236\u5668\u8ddf\u8e2a\u671f\u671b\u7684\u5173\u8282\u8f68\u8ff9\u3002\u89c2\u5bdf\u8c03\u8282\u589e\u76ca\u7684\u5f71\u54cd\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u671f\u671b\u8f68\u8ff9\uff1a\u5e73\u6ed1\u6b63\u5f26\u8fd0\u52a8\ndt = 0.01\nt = jnp.arange(0, 5, dt)\nq_desired = jnp.sin(2 * t)\n\n# \u6a21\u62df\u4e8c\u9636\u52a8\u529b\u5b66\uff1am * q_ddot + b * q_dot = tau\nm, b_damp = 1.0, 0.5\n\nfor Kp, Kd, Ki, label in [(10, 5, 0, \"\u4ec5PD\"), (10, 5, 2, \"PID\"), (50, 10, 2, \"\u6fc0\u8fdbPID\")]:\n    q, q_dot, integral = 0.0, 0.0, 0.0\n    qs = []\n    for i in range(len(t)):\n        error = q_desired[i] - q\n        integral += error * dt\n        d_error = -q_dot  # \u8bef\u5dee\u5bfc\u6570\uff08\u6b64\u5904\u7b80\u5316\uff0c\u5df2\u77e5\u671f\u671b\u901f\u5ea6\uff09\n        tau = Kp * error + Kd * d_error + Ki * integral\n        q_ddot = (tau - b_damp * q_dot) / m\n        q_dot += q_ddot * dt\n        q += q_dot * dt\n        qs.append(float(q))\n\n    plt.plot(t, qs, label=label)\n\nplt.plot(t, q_desired, \"k--\", label=\"\u671f\u671b\u503c\", linewidth=2)\nplt.xlabel(\"\u65f6\u95f4 (\u79d2)\"); plt.ylabel(\"\u5173\u8282\u89d2\u5ea6\")\nplt.legend(); plt.title(\"PID\u63a7\u5236\u5668\u8ddf\u8e2a\")\nplt.show()\n

"},{"location":"chapter%2011%3A%20autonomous%20systems/03.%20vision-language-action%20models/","title":"\u89c6\u89c9-\u8bed\u8a00-\u52a8\u4f5c\u6a21\u578b","text":"

\u89c6\u89c9-\u8bed\u8a00-\u52a8\u4f5c\u6a21\u578b\uff08VLA\uff09\u5c06\u89c6\u89c9\u7406\u89e3\u3001\u8bed\u8a00\u7406\u89e3\u548c\u884c\u52a8\u63a7\u5236\u7edf\u4e00\u5230\u5355\u4e2a\u795e\u7ecf\u7f51\u7edc\u4e2d\u3002\u672c\u7ae0\u6db5\u76d6VLA\u67b6\u6784\u3001\u52a8\u4f5c\u6807\u8bb0\u5316\u3001RT-2\u3001Octo\u3001OpenVLA\u3001\u9884\u8bad\u7ec3\u7b56\u7565\u3001\u6cdb\u5316\u80fd\u529b\u3001\u4e0e\u5177\u4f53\u5f62\u6001\u65e0\u5173\u7684\u6a21\u578b\u4ee5\u53ca\u57fa\u51c6\u6d4b\u8bd5\u3002

"},{"location":"chapter%2011%3A%20autonomous%20systems/03.%20vision-language-action%20models/#-_1","title":"\u4ece\u89c6\u89c9-\u8bed\u8a00\u5230\u884c\u52a8","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/03.%20vision-language-action%20models/#vla","title":"VLA\u67b6\u6784","text":" \\[\\text{\u56fe\u50cf} \\xrightarrow{\\text{ViT}} \\text{\u89c6\u89c9\u6807\u8bb0} \\quad + \\quad \\text{\u6307\u4ee4} \\xrightarrow{\\text{\u5206\u8bcd\u5668}} \\text{\u8bed\u8a00\u6807\u8bb0} \\quad \\xrightarrow{\\text{LLM}} \\quad \\text{\u52a8\u4f5c\u6807\u8bb0}\\] "},{"location":"chapter%2011%3A%20autonomous%20systems/03.%20vision-language-action%20models/#_1","title":"\u52a8\u4f5c\u6807\u8bb0\u5316","text":" "},{"location":"chapter%2011%3A%20autonomous%20systems/03.%20vision-language-action%20models/#vla_1","title":"\u5173\u952eVLA\u6a21\u578b","text":"

- **\u7075\u6d3b\u7684\u89c2\u6d4b\u548c\u52a8\u4f5c\u7a7a\u95f4**\uff1aOcto\u4e3a\u4e0d\u540c\u7684\u673a\u5668\u4eba\u914d\u7f6e\u4f7f\u7528\u7279\u5b9a\u4e8e\u4efb\u52a1\u7684\u6807\u8bb0\u5316\u5668\u3002\u5b83\u5728Open X-Embodiment\u6570\u636e\u96c6\u4e0a\u9884\u8bad\u7ec3\uff0c\u8be5\u6570\u636e\u96c6\u5305\u542b\u6765\u81ea22\u79cd\u4e0d\u540c\u673a\u5668\u4eba\u5f62\u6001\u7684\u793a\u8303\u3002\n\n- **\u9ad8\u6548\u5fae\u8c03**\uff1aOcto\u53ea\u9700100\u4e2a\u793a\u8303\u5c31\u53ef\u4ee5\u5fae\u8c03\u5230\u65b0\u673a\u5668\u4eba\uff0c\u4f7f\u5176\u9002\u7528\u4e8e\u6570\u636e\u6709\u9650\u7684\u5b9e\u9a8c\u5ba4\u3002\n
"},{"location":"chapter%2011%3A%20autonomous%20systems/03.%20vision-language-action%20models/#_2","title":"\u9884\u8bad\u7ec3\u914d\u65b9","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/03.%20vision-language-action%20models/#_3","title":"\u6cdb\u5316\u80fd\u529b","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/03.%20vision-language-action%20models/#_4","title":"\u4e0e\u5f62\u6001\u65e0\u5173\u7684\u6a21\u578b","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/03.%20vision-language-action%20models/#_5","title":"\u57fa\u51c6\u6d4b\u8bd5\u4e0e\u8bc4\u4f30","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/03.%20vision-language-action%20models/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u5b9e\u73b0\u52a8\u4f5c\u6807\u8bb0\u5316\uff1a\u5c06\u8fde\u7eed\u52a8\u4f5c\u79bb\u6563\u5316\u4e3a\u7bb1\u5e76\u91cd\u5efa\u3002\u89c2\u5bdf\u91cf\u5316\u8bef\u5dee\u968f\u7bb1\u6570\u91cf\u7684\u53d8\u5316\u3002

    import jax.numpy as jnp\n\n# \u8fde\u7eed\u52a8\u4f5c\uff1a7\u4e2a\u7ef4\u5ea6\uff086\u81ea\u7531\u5ea6+\u5939\u722a\uff09\naction_true = jnp.array([0.023, -0.051, 0.012, 0.1, -0.03, 0.005, 0.8])\naction_min = jnp.array([-0.1, -0.1, -0.1, -0.5, -0.5, -0.5, 0.0])\naction_max = jnp.array([ 0.1,  0.1,  0.1,  0.5,  0.5,  0.5, 1.0])\n\nfor n_bins in [16, 64, 256, 1024]:\n    # \u6807\u8bb0\u5316\uff1a\u5c06\u8fde\u7eed\u503c\u6620\u5c04\u4e3a\u7bb1\u7d22\u5f15\n    normalised = (action_true - action_min) / (action_max - action_min)\n    tokens = jnp.clip((normalised * n_bins).astype(int), 0, n_bins - 1)\n\n    # \u53bb\u6807\u8bb0\u5316\uff1a\u5c06\u7bb1\u7d22\u5f15\u6620\u5c04\u56de\u8fde\u7eed\u503c\n    reconstructed = (tokens + 0.5) / n_bins * (action_max - action_min) + action_min\n\n    error = jnp.linalg.norm(action_true - reconstructed)\n    print(f\"\u7bb1\u6570={n_bins:4d}  \u6807\u8bb0={tokens}  \u8bef\u5dee={error:.6f}\")\n

  2. \u6a21\u62df\u52a8\u4f5c\u5206\u5757\u4e0e\u5355\u6b65\u9884\u6d4b\u7684\u6bd4\u8f83\u3002\u751f\u6210\u5e73\u6ed1\u8f68\u8ff9\uff0c\u5411\u5355\u6b65\u9884\u6d4b\u6dfb\u52a0\u566a\u58f0\uff0c\u5e76\u4e0e\u5206\u5757\u9884\u6d4b\u6bd4\u8f83\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u771f\u5b9e\u5e73\u6ed1\u8f68\u8ff9\uff08\u4f8b\u5982\uff0c\u4f38\u624b\u52a8\u4f5c\uff09\nt = jnp.linspace(0, 2 * jnp.pi, 100)\ngt_x = jnp.sin(t)\ngt_y = 1 - jnp.cos(t)\n\n# \u5355\u6b65\uff1a\u6bcf\u6b21\u9884\u6d4b\u6709\u72ec\u7acb\u566a\u58f0\nrng = jax.random.PRNGKey(42)\nnoise_ss = jax.random.normal(rng, (100, 2)) * 0.05\nsingle_step = jnp.stack([gt_x, gt_y], axis=1) + noise_ss\n# \u5355\u6b65\u8bef\u5dee\u7d2f\u79ef\u6f02\u79fb\nsingle_step_cumulative = jnp.cumsum(noise_ss, axis=0) * 0.3 + jnp.stack([gt_x, gt_y], axis=1)\n\n# \u5206\u5757\uff08\u5757\u5927\u5c0f=10\uff09\uff1a\u5757\u5185\u566a\u58f0\u5173\u8054\uff0c\u66f4\u5e73\u6ed1\nchunk_size = 10\nrng2 = jax.random.PRNGKey(7)\nchunks = []\nfor i in range(0, 100, chunk_size):\n    chunk_noise = jax.random.normal(jax.random.fold_in(rng2, i), (2,)) * 0.05\n    chunk = jnp.stack([gt_x[i:i+chunk_size], gt_y[i:i+chunk_size]], axis=1)\n    chunks.append(chunk + chunk_noise)\nchunked = jnp.concatenate(chunks, axis=0)\n\nplt.figure(figsize=(8, 4))\nplt.plot(gt_x, gt_y, \"k-\", linewidth=2, label=\"\u771f\u5b9e\u8f68\u8ff9\")\nplt.plot(single_step_cumulative[:, 0], single_step_cumulative[:, 1],\n         \"r-\", alpha=0.7, label=\"\u5355\u6b65\uff08\u6f02\u79fb\uff09\")\nplt.plot(chunked[:, 0], chunked[:, 1], \"b-\", alpha=0.7, label=\"\u5206\u5757\uff08\u7a33\u5b9a\uff09\")\nplt.legend(); plt.axis(\"equal\"); plt.grid(True)\nplt.title(\"\u52a8\u4f5c\u5206\u5757 vs \u5355\u6b65\u9884\u6d4b\")\nplt.show()\n

  3. \u53ef\u89c6\u5316VLA\u52a8\u4f5c\u5206\u5e03\u5982\u4f55\u662f\u591a\u6a21\u6001\u7684\u3002\u4f7f\u7528\u7b80\u5355\u76842D\u9ad8\u65af\u6df7\u5408\u6765\u5c55\u793a\u4e3a\u4ec0\u4e48\u6269\u6563/\u6d41\u5339\u914d\u52a8\u4f5c\u5934\u4f18\u4e8e\u56de\u5f52\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u7ed5\u8fc7\u969c\u788d\u7269\u7684\u4e24\u79cd\u6709\u6548\u65b9\u5f0f\uff1a\u5de6\u8fb9\u6216\u53f3\u8fb9\nrng = jax.random.PRNGKey(0)\nk1, k2 = jax.random.split(rng)\n\nmode1 = jax.random.normal(k1, (200, 2)) * 0.15 + jnp.array([-1.0, 0.5])\nmode2 = jax.random.normal(k2, (200, 2)) * 0.15 + jnp.array([ 1.0, 0.5])\nsamples = jnp.concatenate([mode1, mode2])\n\n# \u56de\u5f52\u9884\u6d4b\u5747\u503c = \u6a21\u6001\u7684\u5747\u503c\uff08\u65e0\u6548\uff01\uff09\nmean_pred = samples.mean(axis=0)\n\nplt.figure(figsize=(6, 5))\nplt.scatter(samples[:, 0], samples[:, 1], s=5, alpha=0.5, label=\"\u771f\u5b9e\u52a8\u4f5c\u5206\u5e03\")\nplt.plot(*mean_pred, \"rx\", markersize=15, markeredgewidth=3, label=\"\u56de\u5f52\u5747\u503c\uff08\u65e0\u6548\uff01\uff09\")\nplt.plot(-1, 0.5, \"g^\", markersize=12, label=\"\u6a21\u60011\uff08\u5411\u5de6\uff09\")\nplt.plot(1, 0.5, \"b^\", markersize=12, label=\"\u6a21\u60012\uff08\u5411\u53f3\uff09\")\nplt.legend(); plt.grid(True)\nplt.title(\"\u591a\u6a21\u6001\u52a8\u4f5c\uff1a\u4e3a\u4ec0\u4e48\u56de\u5f52\u5931\u8d25\")\nplt.xlabel(\"\u52a8\u4f5c\u7ef4\u5ea61\"); plt.ylabel(\"\u52a8\u4f5c\u7ef4\u5ea62\")\nplt.show()\n

"},{"location":"chapter%2011%3A%20autonomous%20systems/04.%20self-driving/","title":"\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66","text":"

\u81ea\u52a8\u9a7e\u9a76\u6c7d\u8f66\u662f\u5546\u4e1a\u4e0a\u6700\u5148\u8fdb\u7684\u81ea\u4e3b\u7cfb\u7edf\uff0c\u5c06\u611f\u77e5\u3001\u9884\u6d4b\u3001\u89c4\u5212\u548c\u63a7\u5236\u96c6\u6210\u5230\u5355\u4e2a\u8f66\u8f86\u4e2d\u3002\u672c\u7ae0\u6db5\u76d6\u81ea\u52a8\u9a7e\u9a76\u5806\u6808\u3001\u9ad8\u7cbe\u5730\u56fe\u3001\u8fd0\u52a8\u9884\u6d4b\u3001\u89c4\u5212\u3001\u7aef\u5230\u7aef\u9a7e\u9a76\u3001\u4eff\u771f\u3001\u5b89\u5168\u6807\u51c6\u548c\u81ea\u4e3b\u7b49\u7ea7\u3002

"},{"location":"chapter%2011%3A%20autonomous%20systems/04.%20self-driving/#_2","title":"\u81ea\u52a8\u9a7e\u9a76\u5806\u6808","text":" \\[\\text{\u611f\u77e5} \\to \\text{\u9884\u6d4b} \\to \\text{\u89c4\u5212} \\to \\text{\u63a7\u5236}\\] "},{"location":"chapter%2011%3A%20autonomous%20systems/04.%20self-driving/#_3","title":"\u9ad8\u7cbe\u5730\u56fe","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/04.%20self-driving/#_4","title":"\u8fd0\u52a8\u9884\u6d4b","text":" \\[\\text{minADE}_K = \\min_{k \\in \\{1, \\ldots, K\\}} \\frac{1}{T} \\sum_{t=1}^{T} \\| \\hat{\\mathbf{p}}_t^{(k)} - \\mathbf{p}_t \\|_2\\] \\[\\mathbf{a}_i = \\frac{\\mathbf{v}_i^{\\text{\u671f\u671b}} - \\mathbf{v}_i}{\\tau} + \\sum_{j \\neq i} \\mathbf{f}_{ij}^{\\text{\u6392\u65a5}} + \\sum_{\\text{\u5899\u58c1}} \\mathbf{f}_{\\text{\u5899\u58c1}}\\] "},{"location":"chapter%2011%3A%20autonomous%20systems/04.%20self-driving/#_5","title":"\u89c4\u5212","text":" \\[\\min_{\\boldsymbol{\\xi}} \\underbrace{w_1 \\cdot J_{\\text{\u8fdb\u5ea6}}(\\boldsymbol{\\xi})}_{\\text{\u5230\u8fbe\u76ee\u7684\u5730}} + \\underbrace{w_2 \\cdot J_{\\text{\u8212\u9002}}(\\boldsymbol{\\xi})}_{\\text{\u5e73\u7a33\u884c\u9a76}} + \\underbrace{w_3 \\cdot J_{\\text{\u5b89\u5168}}(\\boldsymbol{\\xi})}_{\\text{\u907f\u514d\u78b0\u649e}}\\] \\[\\text{\u7ea6\u675f\u6761\u4ef6\uff1a\u8fd0\u52a8\u5b66\u7ea6\u675f\u3001\u9650\u901f\u3001\u8f66\u9053\u8fb9\u754c}\\] "},{"location":"chapter%2011%3A%20autonomous%20systems/04.%20self-driving/#_6","title":"\u7aef\u5230\u7aef\u9a7e\u9a76","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/04.%20self-driving/#_7","title":"\u9a7e\u9a76\u4e16\u754c\u6a21\u578b","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/04.%20self-driving/#_8","title":"\u4eff\u771f","text":"

"},{"location":"chapter%2011%3A%20autonomous%20systems/04.%20self-driving/#_9","title":"\u5b89\u5168\u6027","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/04.%20self-driving/#_10","title":"\u81ea\u52a8\u9a7e\u9a76\u7b49\u7ea7","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/04.%20self-driving/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u5b9e\u73b0\u4e00\u4e2a\u7b80\u5355\u7684\u8f68\u8ff9\u4f18\u5316\u89c4\u5212\u5668\u3002\u7ed9\u5b9a\u8d77\u59cb\u4f4d\u7f6e\u3001\u76ee\u6807\u548c\u969c\u788d\u7269\uff0c\u4f7f\u7528\u68af\u5ea6\u4e0b\u964d\u627e\u5230\u6700\u5e73\u6ed1\u7684\u65e0\u78b0\u649e\u8def\u5f84\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u8f68\u8ff9\uff1aN\u4e2a\u8def\u5f84\u70b9\uff0c\u6bcf\u4e2a(x, y)\nN = 20\nstart = jnp.array([0.0, 0.0])\ngoal = jnp.array([10.0, 0.0])\nobstacle = jnp.array([5.0, 0.0])\nobs_radius = 1.5\n\n# \u521d\u59cb\u5316\uff1a\u4ece\u8d77\u70b9\u5230\u7ec8\u70b9\u7684\u76f4\u7ebf\nwaypoints_init = jnp.linspace(start, goal, N)\n\ndef cost(waypoints):\n    wp = jnp.concatenate([start[None], waypoints, goal[None]], axis=0)\n\n    # \u5e73\u6ed1\u5ea6\uff1a\u60e9\u7f5a\u52a0\u901f\u5ea6\uff08\u4e8c\u9636\u5dee\u5206\uff09\n    accel = wp[2:] - 2 * wp[1:-1] + wp[:-2]\n    smooth_cost = jnp.sum(accel ** 2)\n\n    # \u907f\u969c\uff1a\u60e9\u7f5a\u63a5\u8fd1\u5ea6\n    dists = jnp.linalg.norm(wp - obstacle, axis=1)\n    collision_cost = jnp.sum(jnp.maximum(0, obs_radius + 0.5 - dists) ** 2)\n\n    return 10 * smooth_cost + 100 * collision_cost\n\ngrad_cost = jax.grad(cost)\n\n# \u4f18\u5316\u5185\u90e8\u8def\u5f84\u70b9\nwaypoints = waypoints_init[1:-1]\nlr = 0.01\nfor _ in range(500):\n    g = grad_cost(waypoints)\n    waypoints = waypoints - lr * g\n\n# \u7ed8\u56fe\nfull_path = jnp.concatenate([start[None], waypoints, goal[None]], axis=0)\ntheta = jnp.linspace(0, 2 * jnp.pi, 100)\n\nplt.figure(figsize=(10, 4))\nplt.plot(full_path[:, 0], full_path[:, 1], \"b.-\", label=\"\u4f18\u5316\u540e\u8def\u5f84\")\nplt.plot(waypoints_init[:, 0], waypoints_init[:, 1], \"r--\", alpha=0.5, label=\"\u521d\u59cb\uff08\u76f4\u7ebf\uff09\")\nplt.fill(obstacle[0] + obs_radius * jnp.cos(theta),\n         obstacle[1] + obs_radius * jnp.sin(theta), alpha=0.3, color=\"red\", label=\"\u969c\u788d\u7269\")\nplt.plot(*start, \"go\", markersize=10); plt.plot(*goal, \"g*\", markersize=15)\nplt.legend(); plt.axis(\"equal\"); plt.grid(True)\nplt.title(\"\u8f68\u8ff9\u4f18\u5316\uff1a\u5e73\u6ed1\u65e0\u78b0\u649e\u8def\u5f84\")\nplt.show()\n

  2. \u6a21\u62df\u4e00\u4e2a\u5300\u901f\u8fd0\u52a8\u9884\u6d4b\u6a21\u578b\uff0c\u5e76\u4e0e\u8f6c\u5f2f\u8f66\u8f86\u7684\u771f\u5b9e\u503c\u6bd4\u8f83\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u771f\u5b9e\u503c\uff1a\u8f66\u8f86\u53f3\u8f6c\ndt = 0.1\nT = 40  # 4\u79d2\nv = 10.0  # \u7c73/\u79d2\nomega = 0.3  # \u5f27\u5ea6/\u79d2\uff08\u8f6c\u5f2f\u901f\u7387\uff09\n\n# \u771f\u5b9e\u8f68\u8ff9\uff08\u6052\u5b9a\u8f6c\u5f2f\u901f\u7387\uff09\nt = jnp.arange(T) * dt\ntheta = omega * t\ngt_x = (v / omega) * jnp.sin(theta)\ngt_y = (v / omega) * (1 - jnp.cos(theta))\n\n# \u4ecet=0\u5f00\u59cb\u7684\u5300\u901f\u9884\u6d4b\n# \u5047\u8bbe\u8f66\u8f86\u6cbf\u5f53\u524d\u822a\u5411\u7ee7\u7eed\u76f4\u884c\nobs_steps = 10  # \u89c2\u5bdf\u524d1\u79d2\nvx0 = v * jnp.cos(theta[obs_steps - 1])\nvy0 = v * jnp.sin(theta[obs_steps - 1])\npred_t = jnp.arange(T - obs_steps) * dt\npred_x = gt_x[obs_steps - 1] + vx0 * pred_t\npred_y = gt_y[obs_steps - 1] + vy0 * pred_t\n\nplt.figure(figsize=(8, 6))\nplt.plot(gt_x[:obs_steps], gt_y[:obs_steps], \"ko-\", label=\"\u5df2\u89c2\u6d4b\")\nplt.plot(gt_x[obs_steps:], gt_y[obs_steps:], \"g-\", linewidth=2, label=\"\u771f\u5b9e\u672a\u6765\")\nplt.plot(pred_x, pred_y, \"r--\", linewidth=2, label=\"\u5300\u901f\u9884\u6d4b\")\nplt.legend(); plt.axis(\"equal\"); plt.grid(True)\nplt.xlabel(\"x (\u7c73)\"); plt.ylabel(\"y (\u7c73)\")\nplt.title(\"\u5300\u901f\u9884\u6d4b vs \u8f6c\u5f2f\u8f66\u8f86\")\nplt.show()\n

  3. \u5b9e\u73b0\u4e00\u4e2a\u7b80\u5355\u7684\u57fa\u4e8e\u89c4\u5219\u7684\u89c4\u5212\u5668\uff0c\u6839\u636e\u68c0\u6d4b\u5230\u7684\u969c\u788d\u7269\u51b3\u5b9a\u4fdd\u6301\u8f66\u9053\u8fd8\u662f\u505c\u8f66\u3002

    import jax.numpy as jnp\n\ndef rule_based_planner(ego_speed, obstacles, speed_limit=13.9):\n    \"\"\"\n    \u7b80\u5355\u7684\u57fa\u4e8e\u89c4\u5219\u7684\u89c4\u5212\u5668\u3002\n    ego_speed: \u5f53\u524d\u901f\u5ea6\uff08\u7c73/\u79d2\uff09\n    obstacles: \u524d\u65b9\u8f66\u8f86\u7684\uff08\u8ddd\u79bb\uff0c\u901f\u5ea6\uff09\u5143\u7ec4\u5217\u8868\n    speed_limit: \u6700\u9ad8\u5141\u8bb8\u901f\u5ea6\uff08\u7c73/\u79d2\uff09\uff0c\u9ed8\u8ba4\u7ea650\u516c\u91cc/\u5c0f\u65f6\n\n    \u8fd4\u56de\uff1a(\u76ee\u6807\u901f\u5ea6\uff0c\u52a8\u4f5c\u6807\u7b7e)\n    \"\"\"\n    min_following_distance = 2.0 * ego_speed  # 2\u79d2\u89c4\u5219\n    emergency_distance = 5.0  # \u7c73\n\n    if not obstacles:\n        return speed_limit, \"\u5de1\u822a\"\n\n    # \u627e\u5230\u6700\u8fd1\u7684\u524d\u65b9\u969c\u788d\u7269\n    closest_dist, closest_speed = min(obstacles, key=lambda o: o[0])\n\n    if closest_dist < emergency_distance:\n        return 0.0, \"\u7d27\u6025\u505c\u8f66\"\n    elif closest_dist < min_following_distance:\n        # \u5339\u914d\u524d\u8f66\u901f\u5ea6\n        target = min(closest_speed, speed_limit)\n        return target, \"\u8ddf\u968f\"\n    else:\n        return speed_limit, \"\u5de1\u822a\"\n\n# \u6d4b\u8bd5\u573a\u666f\nscenarios = [\n    (13.9, [], \"\u7a7a\u65f7\u9053\u8def\"),\n    (13.9, [(30.0, 10.0)], \"\u524d\u65b9\u6709\u8f83\u6162\u8f66\u8f86\"),\n    (13.9, [(3.0, 0.0)], \"\u524d\u65b9\u6709\u505c\u9760\u8f66\u8f86\uff0c\u8ddd\u79bb\u6781\u8fd1\"),\n    (13.9, [(50.0, 13.9)], \"\u524d\u65b9\u8f66\u8f86\u540c\u901f\u884c\u9a76\"),\n]\n\nfor speed, obs, desc in scenarios:\n    target, action = rule_based_planner(speed, obs)\n    print(f\"{desc:30s}  \u2192  {action:15s} \u76ee\u6807\u901f\u5ea6={target:.1f} \u7c73/\u79d2 ({target*3.6:.0f} \u516c\u91cc/\u5c0f\u65f6)\")\n

"},{"location":"chapter%2011%3A%20autonomous%20systems/05.%20space%20and%20extreme%20robotics/","title":"\u592a\u7a7a\u4e0e\u6781\u7aef\u73af\u5883\u673a\u5668\u4eba","text":"

\u592a\u7a7a\u548c\u6781\u7aef\u73af\u5883\u673a\u5668\u4eba\u5c06\u81ea\u4e3b\u6027\u63a8\u5411\u6781\u9650\u2014\u2014\u901a\u4fe1\u5ef6\u8fdf\u3001\u8f90\u5c04\u548c\u975e\u7ed3\u6784\u5316\u5730\u5f62\u8981\u6c42\u673a\u5668\u4eba\u81ea\u5df1\u601d\u8003\u3002\u672c\u7ae0\u6db5\u76d6\u884c\u661f\u6f2b\u6e38\u8f66\u3001\u5728\u8f68\u670d\u52a1\u3001\u901a\u4fe1\u53d7\u9650\u81ea\u4e3b\u6027\u3001\u6297\u8f90\u5c04\u8ba1\u7b97\u3001\u6c34\u4e0b\u673a\u5668\u4eba\u3001\u641c\u7d22\u6551\u63f4\u3001\u7fa4\u4f53\u673a\u5668\u4eba\u548c\u4eba\u673a\u4ea4\u4e92\u3002

"},{"location":"chapter%2011%3A%20autonomous%20systems/05.%20space%20and%20extreme%20robotics/#_2","title":"\u592a\u7a7a\u673a\u5668\u4eba","text":"

"},{"location":"chapter%2011%3A%20autonomous%20systems/05.%20space%20and%20extreme%20robotics/#_3","title":"\u901a\u4fe1\u7ea6\u675f","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/05.%20space%20and%20extreme%20robotics/#_4","title":"\u6297\u8f90\u5c04\u8ba1\u7b97","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/05.%20space%20and%20extreme%20robotics/#_5","title":"\u975e\u7ed3\u6784\u5316\u5730\u5f62\u4e2d\u7684\u81ea\u4e3b\u5bfc\u822a","text":" \\[\\mathbf{x}_{t+1} = f(\\mathbf{x}_t, \\mathbf{u}_t)\\] "},{"location":"chapter%2011%3A%20autonomous%20systems/05.%20space%20and%20extreme%20robotics/#_6","title":"\u6c34\u4e0b\u673a\u5668\u4eba","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/05.%20space%20and%20extreme%20robotics/#_7","title":"\u641c\u7d22\u6551\u63f4\u673a\u5668\u4eba","text":""},{"location":"chapter%2011%3A%20autonomous%20systems/05.%20space%20and%20extreme%20robotics/#_8","title":"\u7fa4\u4f53\u673a\u5668\u4eba","text":" \\[x_i(t+1) = \\frac{1}{|N_i| + 1} \\left( x_i(t) + \\sum_{j \\in N_i} x_j(t) \\right)\\]

"},{"location":"chapter%2011%3A%20autonomous%20systems/05.%20space%20and%20extreme%20robotics/#_9","title":"\u4eba\u673a\u4ea4\u4e92","text":" \\[\\mathbf{u} = \\alpha \\mathbf{u}_h + (1 - \\alpha) \\mathbf{u}_r\\] \\[\\pi^* = \\arg\\max_\\pi P(G \\mid \\xi_{0:t})\\] "},{"location":"chapter%2011%3A%20autonomous%20systems/05.%20space%20and%20extreme%20robotics/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u6a21\u62df\u673a\u5668\u4eba\u7fa4\u4f53\u5c31\u76ee\u6807\u4f4d\u7f6e\u8fbe\u6210\u4e00\u81f4\u7684\u5171\u8bc6\u7b97\u6cd5\u3002\u4ece\u968f\u673a\u521d\u59cb\u4f4d\u7f6e\u5f00\u59cb\uff0c\u89c2\u5bdf\u6536\u655b\u8fc7\u7a0b\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nn_robots = 10\nrng = jax.random.PRNGKey(0)\npositions = jax.random.uniform(rng, (n_robots, 2), minval=-5, maxval=5)\n\n# \u901a\u4fe1\u56fe\uff1a\u6bcf\u4e2a\u673a\u5668\u4eba\u4e0e\u6700\u8fd1\u76843\u4e2a\u90bb\u5c45\u901a\u4fe1\ndef get_neighbours(positions, k=3):\n    dists = jnp.linalg.norm(positions[:, None] - positions[None, :], axis=-1)\n    # \u5bf9\u6bcf\u4e2a\u673a\u5668\u4eba\uff0c\u627e\u6700\u8fd1\u7684k\u4e2a\uff08\u6392\u9664\u81ea\u8eab\uff09\n    neighbours = jnp.argsort(dists, axis=1)[:, 1:k+1]\n    return neighbours\n\nhistory = [positions.copy()]\n\nfor step in range(30):\n    neighbours = get_neighbours(positions)\n    new_positions = jnp.zeros_like(positions)\n    for i in range(n_robots):\n        nbr_pos = positions[neighbours[i]]\n        new_positions = new_positions.at[i].set(\n            (positions[i] + nbr_pos.sum(axis=0)) / (len(neighbours[i]) + 1)\n        )\n    positions = new_positions\n    history.append(positions.copy())\n\n# \u7ed8\u5236\u6536\u655b\u8fc7\u7a0b\nfig, axes = plt.subplots(1, 3, figsize=(15, 4))\nfor ax, step_idx, title in zip(axes, [0, 10, 29], [\"\u521d\u59cb\", \"\u7b2c10\u6b65\", \"\u6700\u7ec8\"]):\n    h = history[step_idx]\n    ax.scatter(h[:, 0], h[:, 1], s=50)\n    ax.set_xlim(-6, 6); ax.set_ylim(-6, 6)\n    ax.set_aspect(\"equal\"); ax.grid(True); ax.set_title(title)\nplt.suptitle(\"\u7fa4\u4f53\u5171\u8bc6\uff1a\u673a\u5668\u4eba\u6536\u655b\u5230\u4e00\u81f4\u6027\")\nplt.tight_layout()\nplt.show()\n

  2. \u5b9e\u73b0Reynolds\u7fa4\u96c6\u89c4\u5219\uff08\u5206\u79bb\u3001\u5bf9\u9f50\u3001\u5185\u805a\uff09\u5e76\u6a21\u62df\u4e00\u4e2a\u7fa4\u4f53\u4e00\u8d77\u79fb\u52a8\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nn = 30\nrng = jax.random.PRNGKey(1)\nk1, k2 = jax.random.split(rng)\npos = jax.random.uniform(k1, (n, 2), minval=-5, maxval=5)\nvel = jax.random.uniform(k2, (n, 2), minval=-0.5, maxval=0.5)\n\ndt = 0.1\nseparation_radius = 1.0\nneighbour_radius = 3.0\n\ntrajectories = [pos.copy()]\n\nfor _ in range(200):\n    new_vel = jnp.zeros_like(vel)\n    for i in range(n):\n        diffs = pos - pos[i]\n        dists = jnp.linalg.norm(diffs, axis=1)\n\n        # \u534a\u5f84\u5185\u7684\u90bb\u5c45\uff08\u6392\u9664\u81ea\u8eab\uff09\n        nbr_mask = (dists < neighbour_radius) & (dists > 0)\n        sep_mask = (dists < separation_radius) & (dists > 0)\n\n        # \u5206\u79bb\uff1a\u8fdc\u79bb\u975e\u5e38\u8fd1\u7684\u90bb\u5c45\n        if sep_mask.any():\n            sep = -diffs[sep_mask].sum(axis=0)\n        else:\n            sep = jnp.zeros(2)\n\n        # \u5bf9\u9f50\uff1a\u5339\u914d\u90bb\u5c45\u7684\u5e73\u5747\u901f\u5ea6\n        if nbr_mask.any():\n            align = vel[nbr_mask].mean(axis=0) - vel[i]\n        else:\n            align = jnp.zeros(2)\n\n        # \u5185\u805a\uff1a\u671d\u5411\u90bb\u5c45\u7684\u5e73\u5747\u4f4d\u7f6e\n        if nbr_mask.any():\n            cohesion = pos[nbr_mask].mean(axis=0) - pos[i]\n        else:\n            cohesion = jnp.zeros(2)\n\n        new_vel = new_vel.at[i].set(vel[i] + 1.5 * sep + 0.5 * align + 0.3 * cohesion)\n\n    # \u9650\u5236\u901f\u5ea6\n    speeds = jnp.linalg.norm(new_vel, axis=1, keepdims=True)\n    vel = jnp.where(speeds > 2.0, new_vel / speeds * 2.0, new_vel)\n    pos = pos + vel * dt\n    trajectories.append(pos.copy())\n\n# \u7ed8\u5236\u5feb\u7167\nfig, axes = plt.subplots(1, 3, figsize=(15, 4))\nfor ax, idx, title in zip(axes, [0, 50, 199], [\"\u5f00\u59cb\", \"\u7b2c50\u6b65\", \"\u7b2c200\u6b65\"]):\n    p = trajectories[idx]\n    v = vel if idx == 199 else jnp.zeros_like(vel)\n    ax.scatter(p[:, 0], p[:, 1], 
s=20, c=\"blue\")\n    ax.set_aspect(\"equal\"); ax.grid(True); ax.set_title(title)\n    lim = max(abs(p).max() + 1, 6)\n    ax.set_xlim(-lim, lim); ax.set_ylim(-lim, lim)\nplt.suptitle(\"Reynolds\u7fa4\u96c6\uff1a\u5206\u79bb+\u5bf9\u9f50+\u5185\u805a\")\nplt.tight_layout()\nplt.show()\n

  3. \u6a21\u62df\u5171\u4eab\u81ea\u4e3b\u6df7\u5408\uff1a\u4eba\u7c7b\u63d0\u4f9b\u5e26\u566a\u58f0\u7684\u65b9\u5411\u8f93\u5165\uff0c\u673a\u5668\u4eba\u7684\u81ea\u4e3b\u7cfb\u7edf\u63d0\u4f9b\u5230\u76ee\u6807\u7684\u5e73\u6ed1\u8def\u5f84\u3002\u7528\u4e0d\u540c\u7684alpha\u503c\u8fdb\u884c\u6df7\u5408\u3002

    import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\ngoal = jnp.array([10.0, 5.0])\npos = jnp.array([0.0, 0.0])\ndt = 0.1\n\nrng = jax.random.PRNGKey(3)\n\nfig, axes = plt.subplots(1, 3, figsize=(15, 4))\nfor ax, alpha in zip(axes, [1.0, 0.5, 0.0]):\n    pos = jnp.array([0.0, 0.0])\n    path = [pos.copy()]\n\n    for step in range(150):\n        # \u673a\u5668\u4eba\u81ea\u4e3b\uff1a\u5230\u76ee\u6807\u7684\u5e73\u6ed1\u8def\u5f84\n        direction = goal - pos\n        u_robot = direction / (jnp.linalg.norm(direction) + 1e-6) * 1.0\n\n        # \u4eba\u7c7b\u8f93\u5165\uff1a\u5927\u81f4\u6b63\u786e\u7684\u65b9\u5411\u4f46\u6709\u566a\u58f0\n        noise = jax.random.normal(jax.random.fold_in(rng, step), (2,)) * 0.5\n        u_human = u_robot + noise\n\n        # \u6df7\u5408\n        u = alpha * u_human + (1 - alpha) * u_robot\n        pos = pos + u * dt\n        path.append(pos.copy())\n\n        if jnp.linalg.norm(pos - goal) < 0.3:\n            break\n\n    path = jnp.stack(path)\n    ax.plot(path[:, 0], path[:, 1], \"b-\", alpha=0.7)\n    ax.plot(*goal, \"r*\", markersize=15)\n    ax.plot(0, 0, \"go\", markersize=10)\n    ax.set_title(f\"\u03b1={alpha:.1f} ({'\u4eba\u7c7b' if alpha==1 else '\u673a\u5668\u4eba' if alpha==0 else '\u5171\u4eab'})\")\n    ax.set_xlim(-1, 12); ax.set_ylim(-3, 8)\n    ax.set_aspect(\"equal\"); ax.grid(True)\n\nplt.suptitle(\"\u5171\u4eab\u81ea\u4e3b\uff1a\u6df7\u5408\u4eba\u7c7b\u4e0e\u673a\u5668\u4eba\u63a7\u5236\")\nplt.tight_layout()\nplt.show()\n

"},{"location":"chapter%2012%3A%20graph%20neural%20networks/01.%20geometric%20deep%20learning/","title":"\u51e0\u4f55\u6df1\u5ea6\u5b66\u4e60","text":"

\u51e0\u4f55\u6df1\u5ea6\u5b66\u4e60\u662f\u63ed\u793aCNN\u3001Transformer\u548cGNN\u7686\u9075\u5faa\u540c\u4e00\u539f\u7406\u2014\u2014\u5229\u7528\u5bf9\u79f0\u6027\u2014\u2014\u7684\u7edf\u4e00\u6846\u67b6\u3002\u672c\u7ae0\u6db5\u76d6\u5bf9\u79f0\u7fa4\u3001\u7fa4\u4f5c\u7528\u3001\u4e0d\u53d8\u6027\u3001\u7b49\u53d8\u6027\u3001\u4e94\u4e2a\u51e0\u4f55\u57df\u4ee5\u53ca\u5c3a\u5ea6\u5206\u79bb\u3002

"},{"location":"chapter%2012%3A%20graph%20neural%20networks/01.%20geometric%20deep%20learning/#_2","title":"\u5bf9\u79f0\u6027\u4e0e\u7fa4","text":""},{"location":"chapter%2012%3A%20graph%20neural%20networks/01.%20geometric%20deep%20learning/#_3","title":"\u4e0d\u53d8\u6027\u4e0e\u7b49\u53d8\u6027","text":" \\[f(\\rho(g, x)) = f(x) \\quad \\text{\u5bf9\u4e8e\u6240\u6709 } g \\in G\\] \\[f(\\rho_{\\text{in}}(g, x)) = \\rho_{\\text{out}}(g, f(x)) \\quad \\text{\u5bf9\u4e8e\u6240\u6709 } g \\in G\\]

"},{"location":"chapter%2012%3A%20graph%20neural%20networks/01.%20geometric%20deep%20learning/#_4","title":"\u4e94\u4e2a\u51e0\u4f55\u57df","text":" "},{"location":"chapter%2012%3A%20graph%20neural%20networks/01.%20geometric%20deep%20learning/#_5","title":"\u5c3a\u5ea6\u5206\u79bb\u4e0e\u7c97\u5316","text":" \\[x \\xrightarrow{\\text{\u5c40\u90e8\u7279\u5f81}} h^{(1)} \\xrightarrow{\\text{\u7c97\u5316}} h^{(2)} \\xrightarrow{\\text{\u7c97\u5316}} \\cdots \\xrightarrow{\\text{\u5168\u5c40}} y\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/01.%20geometric%20deep%20learning/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u9a8c\u8bc1\u5377\u79ef\u7684\u5e73\u79fb\u7b49\u53d8\u6027\u3002\u5bf9\u4e00\u7ef4\u4fe1\u53f7\u5e94\u7528\u5377\u79ef\uff0c\u7136\u540e\u5e73\u79fb\u4fe1\u53f7\u518d\u6b21\u5377\u79ef\u3002\u68c0\u67e5\u8f93\u51fa\u662f\u5426\u4e92\u4e3a\u5e73\u79fb\u7248\u672c\u3002

    import jax\nimport jax.numpy as jnp\n\n# \u4e00\u7ef4\u4fe1\u53f7\u548c\u4e00\u4e2a\u7b80\u5355\u6ee4\u6ce2\u5668\nsignal = jnp.array([0, 0, 0, 1, 2, 3, 2, 1, 0, 0, 0], dtype=float)\nkernel = jnp.array([1, 0, -1], dtype=float)\n\n# \u5148\u5377\u79ef\u518d\u5e73\u79fb\nconv_result = jnp.convolve(signal, kernel, mode=\"same\")\nshifted_signal = jnp.roll(signal, 3)\nconv_shifted = jnp.convolve(shifted_signal, kernel, mode=\"same\")\nshifted_conv = jnp.roll(conv_result, 3)\n\nprint(f\"\u5148\u5377\u79ef\u518d\u5e73\u79fb:  {shifted_conv}\")\nprint(f\"\u5148\u5e73\u79fb\u518d\u5377\u79ef:  {conv_shifted}\")\nprint(f\"\u7b49\u53d8\u6027: {jnp.allclose(shifted_conv, conv_shifted, atol=1e-5)}\")\n

  2. \u9a8c\u8bc1DeepSets\u98ce\u683c\u805a\u5408\u7684\u7f6e\u6362\u4e0d\u53d8\u6027\u3002\u5bf9\u96c6\u5408\u4e2d\u7684\u6bcf\u4e2a\u5143\u7d20\u5e94\u7528\u5171\u4eab\u51fd\u6570\uff0c\u6c42\u548c\u7ed3\u679c\uff0c\u5e76\u68c0\u67e5\u8f93\u51fa\u662f\u5426\u4e0d\u4f9d\u8d56\u4e8e\u5143\u7d20\u987a\u5e8f\u3002

    import jax\nimport jax.numpy as jnp\n\n# 4\u4e2a\u5411\u91cf\u7684\"\u96c6\u5408\"\uff08\u987a\u5e8f\u5e94\u65e0\u5173\u7d27\u8981\uff09\nx = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])\n\n# \u7b80\u5355\u7684\u5171\u4eab\u51fd\u6570\uff1a\u9010\u5143\u7d20\u5e73\u65b9\npsi = lambda v: v ** 2\n\n# \u901a\u8fc7\u6c42\u548c\u805a\u5408\ndef deepsets(points):\n    return jnp.sum(jax.vmap(psi)(points), axis=0)\n\n# \u539f\u59cb\u987a\u5e8f\nresult1 = deepsets(x)\n\n# \u7f6e\u6362\u540e\u7684\u987a\u5e8f\nperm = jnp.array([2, 0, 3, 1])\nresult2 = deepsets(x[perm])\n\nprint(f\"\u539f\u59cb\u987a\u5e8f:  {result1}\")\nprint(f\"\u7f6e\u6362\u987a\u5e8f:  {result2}\")\nprint(f\"\u4e0d\u53d8\u6027: {jnp.allclose(result1, result2)}\")\n

  3. \u63a2\u7d22\u7fa4\u7ed3\u6784\u3002\u901a\u8fc7\u68c0\u67e5\u5c01\u95ed\u6027\u3001\u7ed3\u5408\u5f8b\u3001\u5355\u4f4d\u5143\u548c\u9006\u5143\uff0c\u9a8c\u8bc1\u4e8c\u7ef4\u65cb\u8f6c\u77e9\u9635\u6784\u6210\u7fa4\u3002

    import jax.numpy as jnp\n\ndef rot2d(theta):\n    return jnp.array([[jnp.cos(theta), -jnp.sin(theta)],\n                       [jnp.sin(theta),  jnp.cos(theta)]])\n\nR1 = rot2d(jnp.pi / 6)\nR2 = rot2d(jnp.pi / 4)\nR3 = rot2d(jnp.pi / 3)\n\n# \u5c01\u95ed\u6027\uff1a\u4e24\u4e2a\u65cb\u8f6c\u7684\u4e58\u79ef\u8fd8\u662f\u4e00\u4e2a\u65cb\u8f6c\nR12 = R1 @ R2\nprint(f\"\u5c01\u95ed\u6027 (\u884c\u5217\u5f0f=1, \u6b63\u4ea4): det={jnp.linalg.det(R12):.4f}, \"\n      f\"R^T R = I: {jnp.allclose(R12.T @ R12, jnp.eye(2), atol=1e-5)}\")\n\n# \u7ed3\u5408\u5f8b\nprint(f\"\u7ed3\u5408\u5f8b: {jnp.allclose((R1 @ R2) @ R3, R1 @ (R2 @ R3), atol=1e-5)}\")\n\n# \u5355\u4f4d\u5143\nI = rot2d(0.0)\nprint(f\"\u5355\u4f4d\u5143: {jnp.allclose(R1 @ I, R1, atol=1e-5)}\")\n\n# \u9006\u5143\nR1_inv = rot2d(-jnp.pi / 6)\nprint(f\"\u9006\u5143: {jnp.allclose(R1 @ R1_inv, jnp.eye(2), atol=1e-5)}\")\n

"},{"location":"chapter%2012%3A%20graph%20neural%20networks/02.%20graph%20theory/","title":"\u56fe\u8bba","text":"

\u56fe\u8bba\u4e3a\u63cf\u8ff0\u5b9e\u4f53\u95f4\u5173\u7cfb\u63d0\u4f9b\u4e86\u6570\u5b66\u8bed\u8a00\u3002\u672c\u7ae0\u6db5\u76d6\u8282\u70b9\u3001\u8fb9\u3001\u90bb\u63a5\u77e9\u9635\u3001\u56fe\u7c7b\u578b\u3001\u5ea6\u548c\u8fde\u901a\u6027\u3001\u56fe\u62c9\u666e\u62c9\u65af\u7b97\u5b50\u3001\u8c31\u56fe\u7406\u8bba\u4ee5\u53ca\u73b0\u5b9e\u4e16\u754c\u7684\u56fe\u5e94\u7528\u3002\u6211\u4eec\u5c06\u5728\u7eaf\u8ba1\u7b97\u673a\u79d1\u5b66\u7ae0\u8282\u4e2d\u66f4\u6df1\u5165\u5730\u8ba8\u8bba\u56fe\u3002

"},{"location":"chapter%2012%3A%20graph%20neural%20networks/02.%20graph%20theory/#_2","title":"\u8282\u70b9\u3001\u8fb9\u548c\u90bb\u63a5","text":" \\[ A = \\begin{bmatrix} 0 & 1 & 1 \\\\ 1 & 0 & 1 \\\\ 1 & 1 & 0 \\end{bmatrix} \\]

"},{"location":"chapter%2012%3A%20graph%20neural%20networks/02.%20graph%20theory/#_3","title":"\u56fe\u7c7b\u578b","text":" \\[ A = \\begin{bmatrix} 0 & B \\\\ B^T & 0 \\end{bmatrix} \\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/02.%20graph%20theory/#_4","title":"\u5ea6\u3001\u8def\u5f84\u548c\u8fde\u901a\u6027","text":""},{"location":"chapter%2012%3A%20graph%20neural%20networks/02.%20graph%20theory/#_5","title":"\u56fe\u62c9\u666e\u62c9\u65af\u7b97\u5b50","text":" \\[L = D - A\\] \\[ L = \\begin{bmatrix} 2 & 0 & 0 \\\\ 0 & 2 & 0 \\\\ 0 & 0 & 2 \\end{bmatrix} - \\begin{bmatrix} 0 & 1 & 1 \\\\ 1 & 0 & 1 \\\\ 1 & 1 & 0 \\end{bmatrix} = \\begin{bmatrix} 2 & -1 & -1 \\\\ -1 & 2 & -1 \\\\ -1 & -1 & 2 \\end{bmatrix} \\] \\[\\mathbf{x}^T L \\mathbf{x} = \\sum_{(i,j) \\in E} (x_i - x_j)^2\\]

- \u8fd9\u4e2a\u4e8c\u6b21\u5f62\u5f0f\u5ea6\u91cf\u56fe\u4e0a\u7684\u4fe1\u53f7 $\\mathbf{x}$ \u5728\u8fb9\u4e0a\u7684\u53d8\u5316\u7a0b\u5ea6\u3002\u5982\u679c\u76f8\u90bb\u8282\u70b9\u503c\u76f8\u8fd1\uff0c\u5219 $\\mathbf{x}^T L \\mathbf{x}$ \u8f83\u5c0f\u3002\u5982\u679c\u5b83\u4eec\u5dee\u5f02\u5f88\u5927\uff0c\u5219\u8f83\u5927\u3002\u62c9\u666e\u62c9\u65af\u7b97\u5b50\u5ea6\u91cf\u56fe\u4e0a\u4fe1\u53f7\u7684**\u5e73\u6ed1\u5ea6**\u3002\n\n- \u6700\u5c0f\u7279\u5f81\u503c\u59cb\u7ec8\u4e3a0\uff0c\u7279\u5f81\u5411\u91cf\u4e3a $\\mathbf{1} = [1, 1, \\ldots, 1]^T$\uff08\u5e38\u6570\u4fe1\u53f7\u6ca1\u6709\u53d8\u5316\uff09\u3002\u96f6\u7279\u5f81\u503c\u7684\u6570\u91cf\u7b49\u4e8e\u8fde\u901a\u5206\u91cf\u7684\u6570\u91cf\u3002\n\n- \u7b2c\u4e8c\u5c0f\u7279\u5f81\u503c $\\lambda_2$ \u662f**\u4ee3\u6570\u8fde\u901a\u5ea6**\uff08Fiedler\u503c\uff09\u3002\u5b83\u8861\u91cf\u56fe\u7684\u8fde\u901a\u7a0b\u5ea6\uff1a$\\lambda_2 = 0$ \u8868\u793a\u56fe\u4e0d\u8fde\u901a\uff0c\u5927\u7684 $\\lambda_2$ \u8868\u793a\u56fe\u7d27\u5bc6\u8fde\u901a\u3002\n
\\[\\hat{L} = D^{-1/2} L D^{-1/2} = I - D^{-1/2} A D^{-1/2}\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/02.%20graph%20theory/#_6","title":"\u8c31\u56fe\u7406\u8bba","text":" \\[\\hat{\\mathbf{x}} = U^T \\mathbf{x}\\] \\[g_\\theta \\star \\mathbf{x} = U \\left( (U^T g_\\theta) \\odot (U^T \\mathbf{x}) \\right) = U \\, \\text{diag}(\\hat{g}_\\theta) \\, U^T \\mathbf{x}\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/02.%20graph%20theory/#_7","title":"\u793e\u533a\u68c0\u6d4b","text":" \\[Q = \\frac{1}{2|E|} \\sum_{ij} \\left( A_{ij} - \\frac{d_i d_j}{2|E|} \\right) \\delta(c_i, c_j)\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/02.%20graph%20theory/#_8","title":"\u73b0\u5b9e\u4e16\u754c\u4e2d\u7684\u56fe","text":""},{"location":"chapter%2012%3A%20graph%20neural%20networks/02.%20graph%20theory/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u6784\u5efa\u4e00\u4e2a\u5c0f\u578b\u56fe\u7684\u90bb\u63a5\u77e9\u9635\uff0c\u8ba1\u7b97\u57fa\u672c\u6027\u8d28\uff1a\u6bcf\u4e2a\u8282\u70b9\u7684\u5ea6\u3001\u957f\u5ea6\u4e3a2\u7684\u8def\u5f84\u6570\u91cf\u4ee5\u53ca\u56fe\u662f\u5426\u8fde\u901a\u3002

    import jax.numpy as jnp\n\n# \u4e00\u4e2a\u7b80\u5355\u56fe\uff1a5\u4e2a\u8282\u70b9\n# 0-1, 0-2, 1-2, 2-3, 3-4\nA = jnp.array([[0, 1, 1, 0, 0],\n               [1, 0, 1, 0, 0],\n               [1, 1, 0, 1, 0],\n               [0, 0, 1, 0, 1],\n               [0, 0, 0, 1, 0]], dtype=float)\n\n# \u5ea6\ndegrees = A.sum(axis=1)\nprint(f\"\u5ea6\u6570: {degrees}\")\n\n# \u957f\u5ea6\u4e3a2\u7684\u8def\u5f84\nA2 = A @ A\nprint(f\"\u957f\u5ea6\u4e3a2\u7684\u8def\u5f84\uff08\u8282\u70b90\u52303\uff09: {int(A2[0, 3])}\")\n\n# \u662f\u5426\u8fde\u901a\uff1f\u68c0\u67e5 (A+I)^(n-1) \u662f\u5426\u6240\u6709\u6761\u76ee\u975e\u96f6\nAn = jnp.linalg.matrix_power(A + jnp.eye(5), 4)  # (A+I)^4 \u7528\u4e8e\u53ef\u8fbe\u6027\nconnected = jnp.all(An > 0)\nprint(f\"\u8fde\u901a: {connected}\")\n

  2. \u8ba1\u7b97\u56fe\u62c9\u666e\u62c9\u65af\u7b97\u5b50\u53ca\u5176\u7279\u5f81\u503c\u3002\u9a8c\u8bc1\u6700\u5c0f\u7279\u5f81\u503c\u4e3a0\u4e14\u5bf9\u5e94\u7684\u7279\u5f81\u5411\u91cf\u4e3a\u5e38\u6570\u3002

    import jax.numpy as jnp\n\nA = jnp.array([[0, 1, 1, 0, 0],\n               [1, 0, 1, 0, 0],\n               [1, 1, 0, 1, 0],\n               [0, 0, 1, 0, 1],\n               [0, 0, 0, 1, 0]], dtype=float)\n\nD = jnp.diag(A.sum(axis=1))\nL = D - A\n\neigenvalues, eigenvectors = jnp.linalg.eigh(L)\nprint(f\"\u7279\u5f81\u503c: {eigenvalues}\")\nprint(f\"\u6700\u5c0f\u7279\u5f81\u5411\u91cf: {eigenvectors[:, 0]}\")\nprint(f\"Fiedler\u503c\uff08\u4ee3\u6570\u8fde\u901a\u5ea6\uff09: {eigenvalues[1]:.4f}\")\n\n# \u9a8c\u8bc1: x^T L x \u5ea6\u91cf\u5e73\u6ed1\u5ea6\nx = jnp.array([1.0, 1.0, 1.0, -1.0, -1.0])  # \u4e24\u4e2a\u7ec4\nsmoothness = x @ L @ x\nprint(f\"\u4e24\u7ec4\u4fe1\u53f7\u7684\u5e73\u6ed1\u5ea6: {smoothness:.2f}\")\n

  3. \u5bf9\u5177\u6709\u4e24\u4e2a\u793e\u533a\u7684\u56fe\u6267\u884c\u8c31\u805a\u7c7b\u3002\u4f7f\u7528Fiedler\u5411\u91cf\u5d4c\u5165\u8282\u70b9\uff0c\u5e76\u6309\u7b26\u53f7\u5206\u79bb\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u4e24\u4e2a\u793e\u533a\uff0c\u54045\u4e2a\u8282\u70b9\uff0c\u5f31\u8fde\u63a5\nA = jnp.zeros((10, 10))\n# \u793e\u533a1\uff1a\u8282\u70b90-4\uff08\u5bc6\u96c6\uff09\nfor i in range(5):\n    for j in range(i+1, 5):\n        A = A.at[i, j].set(1).at[j, i].set(1)\n# \u793e\u533a2\uff1a\u8282\u70b95-9\uff08\u5bc6\u96c6\uff09\nfor i in range(5, 10):\n    for j in range(i+1, 10):\n        A = A.at[i, j].set(1).at[j, i].set(1)\n# \u4e00\u6761\u6865\u63a5\u8fb9\nA = A.at[2, 7].set(1).at[7, 2].set(1)\n\nD = jnp.diag(A.sum(axis=1))\nL = D - A\neigenvalues, eigenvectors = jnp.linalg.eigh(L)\n\n# Fiedler\u5411\u91cf\uff08\u7b2c\u4e8c\u5c0f\u7279\u5f81\u503c\uff09\nfiedler = eigenvectors[:, 1]\ncommunities = (fiedler > 0).astype(int)\n\nprint(f\"Fiedler\u5411\u91cf: {fiedler}\")\nprint(f\"\u805a\u7c7b: {communities}\")\n\nplt.bar(range(10), fiedler, color=[\"#3498db\" if c == 0 else \"#e74c3c\" for c in communities])\nplt.xlabel(\"\u8282\u70b9\"); plt.ylabel(\"Fiedler\u5411\u91cf\u503c\")\nplt.title(\"\u901a\u8fc7Fiedler\u5411\u91cf\u8fdb\u884c\u8c31\u805a\u7c7b\")\nplt.show()\n

"},{"location":"chapter%2012%3A%20graph%20neural%20networks/03.%20graph%20neural%20networks/","title":"\u56fe\u795e\u7ecf\u7f51\u7edc","text":"

\u56fe\u795e\u7ecf\u7f51\u7edc\u901a\u8fc7\u5728\u8fde\u63a5\u8282\u70b9\u4e4b\u95f4\u4f20\u9012\u6d88\u606f\u6765\u5b66\u4e60\u56fe\u7ed3\u6784\u6570\u636e\u3002\u672c\u7ae0\u6db5\u76d6\u6d88\u606f\u4f20\u9012\u6846\u67b6\u3001GCN\u3001GraphSAGE\u3001GIN\u3001\u8fc7\u5e73\u6ed1\u3001\u56fe\u6c60\u5316\u4ee5\u53ca\u8282\u70b9/\u8fb9/\u56fe\u7ea7\u522b\u7684\u4efb\u52a1\uff1b\u652f\u6491\u5206\u5b50\u6027\u8d28\u9884\u6d4b\u3001\u793e\u4ea4\u7f51\u7edc\u5206\u6790\u548c\u63a8\u8350\u7cfb\u7edf\u7684\u6838\u5fc3\u67b6\u6784\u3002

"},{"location":"chapter%2012%3A%20graph%20neural%20networks/03.%20graph%20neural%20networks/#_2","title":"\u6d88\u606f\u4f20\u9012\u6846\u67b6","text":" \\[\\mathbf{m}_i^{(l)} = \\bigoplus_{j \\in \\mathcal{N}(i)} \\phi^{(l)}\\left(\\mathbf{h}_i^{(l)}, \\mathbf{h}_j^{(l)}, \\mathbf{e}_{ij}\\right)\\] \\[\\mathbf{h}_i^{(l+1)} = \\psi^{(l)}\\left(\\mathbf{h}_i^{(l)}, \\mathbf{m}_i^{(l)}\\right)\\]

"},{"location":"chapter%2012%3A%20graph%20neural%20networks/03.%20graph%20neural%20networks/#gcn","title":"\u56fe\u5377\u79ef\u7f51\u7edc\uff08GCN\uff09","text":" \\[H^{(l+1)} = \\sigma\\left(\\hat{A} H^{(l)} W^{(l)}\\right)\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/03.%20graph%20neural%20networks/#graphsage","title":"GraphSAGE","text":" \\[\\mathbf{h}_i^{(l+1)} = \\sigma\\left(W^{(l)} \\cdot \\text{CONCAT}\\left(\\mathbf{h}_i^{(l)}, \\text{AGG}\\left(\\{\\mathbf{h}_j^{(l)} : j \\in \\mathcal{S}(i)\\}\\right)\\right)\\right)\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/03.%20graph%20neural%20networks/#gin","title":"\u56fe\u540c\u6784\u7f51\u7edc\uff08GIN\uff09","text":" \\[\\mathbf{h}_i^{(l+1)} = \\text{MLP}^{(l)}\\left((1 + \\epsilon^{(l)}) \\cdot \\mathbf{h}_i^{(l)} + \\sum_{j \\in \\mathcal{N}(i)} \\mathbf{h}_j^{(l)}\\right)\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/03.%20graph%20neural%20networks/#_3","title":"\u8fc7\u5e73\u6ed1","text":" "},{"location":"chapter%2012%3A%20graph%20neural%20networks/03.%20graph%20neural%20networks/#_4","title":"\u56fe\u6c60\u5316","text":" \\[\\mathbf{h}_G = \\text{READOUT}(\\{\\mathbf{h}_i^{(L)} : i \\in V\\}) = \\sum_i \\mathbf{h}_i^{(L)} \\quad \\text{\u6216} \\quad \\frac{1}{|V|} \\sum_i \\mathbf{h}_i^{(L)} \\quad \\text{\u6216} \\quad \\max_i \\mathbf{h}_i^{(L)}\\] \\[X^{(l+1)} = S^{(l)T} H^{(l)}, \\quad A^{(l+1)} = S^{(l)T} A^{(l)} S^{(l)}\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/03.%20graph%20neural%20networks/#_5","title":"\u5f02\u6784\u56fe","text":" \\[\\mathbf{h}_i^{(l+1)} = \\sigma\\left(\\sum_{r \\in \\mathcal{R}} \\sum_{j \\in \\mathcal{N}_r(i)} \\frac{1}{|\\mathcal{N}_r(i)|} W_r^{(l)} \\mathbf{h}_j^{(l)} + W_0^{(l)} \\mathbf{h}_i^{(l)}\\right)\\] \\[\\text{Attention}(i, j) = \\left(W_{\\tau(i)}^Q \\mathbf{h}_i\\right)^T \\cdot \\frac{W_{\\phi(i,j)}^{\\text{ATT}}}{\\sqrt{d}} \\cdot \\left(W_{\\tau(j)}^K \\mathbf{h}_j\\right)\\] 
"},{"location":"chapter%2012%3A%20graph%20neural%20networks/03.%20graph%20neural%20networks/#_6","title":"\u94fe\u63a5\u9884\u6d4b\u4e0e\u77e5\u8bc6\u56fe\u8c31\u8865\u5168","text":""},{"location":"chapter%2012%3A%20graph%20neural%20networks/03.%20graph%20neural%20networks/#_7","title":"\u4efb\u52a1\u7c7b\u578b","text":""},{"location":"chapter%2012%3A%20graph%20neural%20networks/03.%20graph%20neural%20networks/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u4f7f\u7528\u5f52\u4e00\u5316\u90bb\u63a5\u77e9\u9635\u4ece\u5934\u5b9e\u73b0\u4e00\u4e2a\u5355\u5c42GCN\u3002\u5e94\u7528\u4e8e\u4e00\u4e2a\u5c0f\u578b\u56fe\uff0c\u89c2\u5bdf\u8282\u70b9\u7279\u5f81\u5982\u4f55\u88ab\u5e73\u6ed1\u3002

    import jax\nimport jax.numpy as jnp\n\n# \u56fe\uff1a5\u4e2a\u8282\u70b9\uff0c\u7b80\u5355\u94fe\u5e26\u5206\u652f\nA = jnp.array([[0, 1, 0, 0, 0],\n               [1, 0, 1, 0, 0],\n               [0, 1, 0, 1, 1],\n               [0, 0, 1, 0, 0],\n               [0, 0, 1, 0, 0]], dtype=float)\n\n# \u6dfb\u52a0\u81ea\u73af\nA_hat = A + jnp.eye(5)\nD_hat = jnp.diag(A_hat.sum(axis=1))\nD_inv_sqrt = jnp.diag(1.0 / jnp.sqrt(A_hat.sum(axis=1)))\nA_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt\n\n# \u8282\u70b9\u7279\u5f81\uff1aone-hot \u5355\u4f4d\u9635\nH = jnp.eye(5)\n\n# \u6743\u91cd\u77e9\u9635\uff08\u968f\u673a\u521d\u59cb\u5316\uff09\nrng = jax.random.PRNGKey(0)\nW = jax.random.normal(rng, (5, 3)) * 0.5\n\n# GCN\u5c42\uff1aH' = ReLU(A_norm @ H @ W)\nH_new = jax.nn.relu(A_norm @ H @ W)\n\nprint(\"\u539f\u59cb\u7279\u5f81\uff08one-hot\uff09:\")\nprint(H)\nprint(\"\\n\u7ecf\u8fc7GCN\u5c42\u540e:\")\nprint(jnp.round(H_new, 3))\nprint(\"\\n\u6ce8\u610f\uff1a\u8fde\u63a5\u7684\u8282\u70b9\u73b0\u5728\u5177\u6709\u76f8\u4f3c\u7684\u8868\u793a\")\n

  2. \u5b9e\u73b0\u5177\u6709\u6c42\u548c\u805a\u5408\uff08GIN\u98ce\u683c\uff09\u548c\u5747\u503c\u805a\u5408\uff08GCN\u98ce\u683c\uff09\u7684\u6d88\u606f\u4f20\u9012\u3002\u5c55\u793a\u6c42\u548c\u80fd\u533a\u5206\u5747\u503c\u65e0\u6cd5\u533a\u5206\u7684\u591a\u91cd\u96c6\u3002

    import jax.numpy as jnp\n\n# \u4e24\u4e2a\u5177\u6709\u76f8\u540c\u5747\u503c\u7684\u4e0d\u540c\u90bb\u5c45\u591a\u91cd\u96c6\n# \u8282\u70b9A\uff1a\u90bb\u5c45\u7279\u5f81\u4e3a [1, 1, 1, 1]  \uff08\u56db\u4e2a\u90bb\u5c45\uff0c\u90fd\u662f1\uff09\n# \u8282\u70b9B\uff1a\u90bb\u5c45\u7279\u5f81\u4e3a [2, 2]          \uff08\u4e24\u4e2a\u90bb\u5c45\uff0c\u90fd\u662f2\uff09\n\nneighbours_A = jnp.array([[1.0], [1.0], [1.0], [1.0]])\nneighbours_B = jnp.array([[2.0], [2.0]])\n\n# \u5747\u503c\u805a\u5408\nmean_A = neighbours_A.mean(axis=0)\nmean_B = neighbours_B.mean(axis=0)\nprint(f\"\u5747\u503c A: {mean_A}, \u5747\u503c B: {mean_B}, \u76f8\u540c: {jnp.allclose(mean_A, mean_B)}\")\n\n# \u6c42\u548c\u805a\u5408\nsum_A = neighbours_A.sum(axis=0)\nsum_B = neighbours_B.sum(axis=0)\nprint(f\"\u6c42\u548c A:  {sum_A},  \u6c42\u548c B:  {sum_B},  \u76f8\u540c: {jnp.allclose(sum_A, sum_B)}\")\nprint(\"\\n\u6c42\u548c\u80fd\u533a\u5206\u8fd9\u4e9b\u591a\u91cd\u96c6\uff1b\u5747\u503c\u4e0d\u80fd\uff01\")\n

  3. \u6f14\u793a\u8fc7\u5e73\u6ed1\u3002\u91cd\u590d\u5e94\u7528\u5f52\u4e00\u5316\u90bb\u63a5\u77e9\u9635\uff0c\u89c2\u5bdf\u8282\u70b9\u7279\u5f81\u6536\u655b\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u793a\u4f8b\u56fe\uff1a\u4e24\u4e2a\u4e09\u89d2\u5f62\u7531\u4e00\u6761\u8fb9\uff082-3\uff09\u76f8\u8fde\nA = jnp.array([[0,1,1,0,0,0],\n               [1,0,1,0,0,0],\n               [1,1,0,1,0,0],\n               [0,0,1,0,1,1],\n               [0,0,0,1,0,1],\n               [0,0,0,1,1,0]], dtype=float)\n\nA_hat = A + jnp.eye(6)\nD_inv_sqrt = jnp.diag(1.0 / jnp.sqrt(A_hat.sum(axis=1)))\nA_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt\n\n# \u521d\u59cb\u7279\u5f81\uff1a\u6bcf\u4e2a\u8282\u70b9\u5404\u4e0d\u76f8\u540c\nH = jnp.array([[1,0], [0,1], [1,1], [-1,0], [0,-1], [-1,-1]], dtype=float)\n\ndistances = []\nfor k in range(20):\n    H = A_norm @ H\n    # \u8861\u91cf\u7279\u5f81\u7684\u533a\u522b\u7a0b\u5ea6\uff08\u8282\u70b9\u95f4\u7684\u6807\u51c6\u5dee\uff09\n    spread = jnp.std(H, axis=0).mean()\n    distances.append(float(spread))\n\nplt.plot(distances, \"o-\")\nplt.xlabel(\"\u6d88\u606f\u4f20\u9012\u8f6e\u6570\")\nplt.ylabel(\"\u7279\u5f81\u5206\u6563\u5ea6\uff08\u8282\u70b9\u95f4\u6807\u51c6\u5dee\uff09\")\nplt.title(\"\u8fc7\u5e73\u6ed1\uff1a\u7279\u5f81\u968f\u6df1\u5ea6\u589e\u52a0\u800c\u6536\u655b\")\nplt.show()\n

"},{"location":"chapter%2012%3A%20graph%20neural%20networks/04.%20graph%20attention%20networks/","title":"\u56fe\u6ce8\u610f\u529b\u7f51\u7edc","text":"

\u56fe\u6ce8\u610f\u529b\u7f51\u7edc\u5c06\u5747\u5300\u7684\u90bb\u5c45\u805a\u5408\u66ff\u6362\u4e3a\u5b66\u4e60\u5230\u7684\u3001\u4f9d\u8d56\u6570\u636e\u7684\u52a0\u6743\u3002\u672c\u7ae0\u6db5\u76d6GAT\u3001\u591a\u5934\u56fe\u6ce8\u610f\u529b\u3001GATv2\u3001\u56feTransformer\u3001\u4f4d\u7f6e\u548c\u7ed3\u6784\u7f16\u7801\u4ee5\u53ca\u53ef\u6269\u5c55\u6027\u3002

"},{"location":"chapter%2012%3A%20graph%20neural%20networks/04.%20graph%20attention%20networks/#gat","title":"GAT\uff1a\u56fe\u6ce8\u610f\u529b\u7f51\u7edc","text":" \\[e_{ij} = \\text{LeakyReLU}\\left(\\mathbf{a}^T \\left[W\\mathbf{h}_i \\| W\\mathbf{h}_j\\right]\\right)\\] \\[\\alpha_{ij} = \\text{softmax}_j(e_{ij}) = \\frac{\\exp(e_{ij})}{\\sum_{k \\in \\mathcal{N}(i)} \\exp(e_{ik})}\\] \\[\\mathbf{h}_i' = \\sigma\\left(\\sum_{j \\in \\mathcal{N}(i)} \\alpha_{ij} W\\mathbf{h}_j\\right)\\]

"},{"location":"chapter%2012%3A%20graph%20neural%20networks/04.%20graph%20attention%20networks/#_2","title":"\u591a\u5934\u56fe\u6ce8\u610f\u529b","text":" \\[\\mathbf{h}_i' = \\Big\\|_{k=1}^{K} \\sigma\\left(\\sum_{j \\in \\mathcal{N}(i)} \\alpha_{ij}^k W^k \\mathbf{h}_j\\right)\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/04.%20graph%20attention%20networks/#gatv2","title":"GATv2\uff1a\u4fee\u590d\u9759\u6001\u6ce8\u610f\u529b","text":" \\[e_{ij} = \\mathbf{a}^T \\text{LeakyReLU}\\left(W \\left[\\mathbf{h}_i \\| \\mathbf{h}_j\\right]\\right)\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/04.%20graph%20attention%20networks/#transformer","title":"\u56feTransformer","text":" \\[\\text{Attention}(Q, K, V) = \\text{softmax}\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V\\] \\[A_{ij} = \\frac{(\\mathbf{h}_i W_Q)(W_K^T \\mathbf{h}_j^T)}{\\sqrt{d_k}} + b_{\\text{spatial}}(i, j) + b_{\\text{edge}}(i, j)\\] \\[\\mathbf{h}_i' = \\text{MLP}\\left(\\mathbf{h}_i^{\\text{MPNN}} + \\mathbf{h}_i^{\\text{Attention}}\\right)\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/04.%20graph%20attention%20networks/#_3","title":"\u4f4d\u7f6e\u7f16\u7801\u4e0e\u7ed3\u6784\u7f16\u7801","text":""},{"location":"chapter%2012%3A%20graph%20neural%20networks/04.%20graph%20attention%20networks/#_4","title":"\u53ef\u6269\u5c55\u6027","text":""},{"location":"chapter%2012%3A%20graph%20neural%20networks/04.%20graph%20attention%20networks/#_5","title":"\u65f6\u5e8f\u56fe\u4e0e\u52a8\u6001\u56fe","text":" \\[\\mathbf{s}_i(t^+) = \\text{GRU}\\left(\\mathbf{s}_i(t^-), \\; \\mathbf{m}_i(t)\\right)\\] \\[\\Phi(t) = \\left[\\cos(\\omega_1 t), \\sin(\\omega_1 t), \\ldots, \\cos(\\omega_d t), \\sin(\\omega_d t)\\right]\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/04.%20graph%20attention%20networks/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u4ece\u5934\u5b9e\u73b0\u4e00\u4e2a\u5355\u5934GAT\u6ce8\u610f\u529b\u3002\u8ba1\u7b97\u8282\u70b9\u4e0e\u5176\u90bb\u5c45\u4e4b\u95f4\u7684\u6ce8\u610f\u529b\u6743\u91cd\uff0c\u5e76\u9a8c\u8bc1\u6743\u91cd\u4e4b\u548c\u4e3a1\u3002

    import jax\nimport jax.numpy as jnp\n\nrng = jax.random.PRNGKey(0)\nk1, k2, k3 = jax.random.split(rng, 3)\n\nn_nodes, d_in, d_out = 5, 4, 3\n\n# \u968f\u673a\u8282\u70b9\u7279\u5f81\nH = jax.random.normal(k1, (n_nodes, d_in))\n\n# \u53ef\u5b66\u4e60\u53c2\u6570\nW = jax.random.normal(k2, (d_in, d_out)) * 0.5\na = jax.random.normal(k3, (2 * d_out,)) * 0.5\n\n# \u90bb\u63a5\uff08\u8282\u70b90\u8fde\u63a5\u52301, 2, 3\uff09\nneighbours_of_0 = [1, 2, 3]\n\n# \u53d8\u6362\u7279\u5f81\nWh = H @ W  # (n_nodes, d_out)\n\n# \u8ba1\u7b97\u8282\u70b90\u7684\u6ce8\u610f\u529b\u5206\u6570\nh_i = Wh[0]\nscores = []\nfor j in neighbours_of_0:\n    h_j = Wh[j]\n    e_ij = jnp.dot(a, jnp.concatenate([h_i, h_j]))\n    e_ij = jax.nn.leaky_relu(e_ij, negative_slope=0.2)\n    scores.append(float(e_ij))\n\nscores = jnp.array(scores)\nalpha = jax.nn.softmax(scores)\n\nprint(f\"\u539f\u59cb\u5206\u6570: {scores}\")\nprint(f\"\u6ce8\u610f\u529b\u6743\u91cd: {alpha}\")\nprint(f\"\u6743\u91cd\u4e4b\u548c: {alpha.sum():.4f}\")\n\n# \u52a0\u6743\u805a\u5408\nh_new = sum(alpha[k] * Wh[neighbours_of_0[k]] for k in range(len(neighbours_of_0)))\nprint(f\"\u66f4\u65b0\u540e\u7684\u8282\u70b90\u7279\u5f81: {h_new}\")\n

  2. \u6bd4\u8f83GCN\uff08\u56fa\u5b9a\u6743\u91cd\uff09\u548cGAT\uff08\u5b66\u4e60\u6743\u91cd\uff09\u7684\u805a\u5408\u3002\u5c55\u793aGAT\u53ef\u4ee5\u4e3a\u90bb\u5c45\u5206\u914d\u4e0d\u540c\u7684\u6743\u91cd\uff0c\u800cGCN\u7edf\u4e00\u5bf9\u5f85\u5b83\u4eec\u3002

    import jax\nimport jax.numpy as jnp\n\n# 4\u4e2a\u8282\u70b9\uff1a\u8282\u70b90\u8fde\u63a5\u52301, 2, 3\nA = jnp.array([[0,1,1,1],\n               [1,0,0,0],\n               [1,0,0,0],\n               [1,0,0,0]], dtype=float)\n\n# \u7279\u5f81\uff1a\u8282\u70b91\u975e\u5e38\u76f8\u5173\uff0c\u8282\u70b92\u662f\u566a\u58f0\uff0c\u8282\u70b93\u4e2d\u7b49\nH = jnp.array([[0.0, 0.0],   # \u8282\u70b90\n               [1.0, 0.0],   # \u8282\u70b91\uff08\u4fe1\u53f7\uff09\n               [0.0, 0.0],   # \u8282\u70b92\uff08\u566a\u58f0\uff09\n               [0.5, 0.0]])  # \u8282\u70b93\uff08\u4e2d\u7b49\uff09\n\n# GCN\uff1a\u5f52\u4e00\u5316\u90bb\u63a5\u6743\u91cd\nA_hat = A + jnp.eye(4)\nD_inv = jnp.diag(1.0 / A_hat.sum(axis=1))\ngcn_weights = (D_inv @ A_hat)[0]  # \u8282\u70b90\u7684\u6743\u91cd\nprint(f\"GCN\u4e2d\u8282\u70b90\u7684\u6743\u91cd: {gcn_weights}\")\nprint(\"  \u2192 \u6240\u6709\u90bb\u5c45\u83b7\u5f97\u5927\u81f4\u76f8\u7b49\u7684\u6743\u91cd\")\n\n# GAT\uff1a\u5b66\u4e60\u5230\u7684\u6ce8\u610f\u529b\uff08\u6a21\u62df\uff09\n# \u5047\u8bbe\u6ce8\u610f\u529b\u673a\u5236\u5b66\u4f1a\u5173\u6ce8\u8282\u70b91\ngat_weights = jnp.array([0.1, 0.7, 0.05, 0.15])  # \u5b66\u4e60\u5230\u7684\nprint(f\"\\nGAT\u4e2d\u8282\u70b90\u7684\u6743\u91cd: {gat_weights}\")\nprint(\"  \u2192 \u6700\u5177\u4fe1\u606f\u91cf\u7684\u8282\u70b91\u83b7\u5f97\u6700\u591a\u5173\u6ce8\")\n\ngcn_output = gcn_weights @ H\ngat_output = gat_weights @ H\nprint(f\"\\nGCN\u8f93\u51fa: {gcn_output}  \uff08\u88ab\u566a\u58f0\u7a00\u91ca\uff09\")\nprint(f\"GAT\u8f93\u51fa: {gat_output}  \uff08\u805a\u7126\u4e8e\u4fe1\u53f7\uff09\")\n

  3. \u6f14\u793a\u4f4d\u7f6e\u7f16\u7801\u7684\u76ca\u5904\u3002\u8ba1\u7b97\u56fe\u7684\u62c9\u666e\u62c9\u65af\u7279\u5f81\u5411\u91cf\u7f16\u7801\uff0c\u5c55\u793a\u7ed3\u6784\u76f8\u4f3c\u7684\u8282\u70b9\u83b7\u5f97\u76f8\u4f3c\u7684\u7f16\u7801\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\n# \u6760\u94c3\u56fe\uff1a\u4e24\u4e2a\u56e2\u7531\u4e00\u6761\u6865\u8fde\u63a5\nn = 10\nA = jnp.zeros((n, n))\n# \u56e21\uff1a\u8282\u70b90-4\nfor i in range(5):\n    for j in range(i+1, 5):\n        A = A.at[i,j].set(1).at[j,i].set(1)\n# \u56e22\uff1a\u8282\u70b95-9\nfor i in range(5, 10):\n    for j in range(i+1, 10):\n        A = A.at[i,j].set(1).at[j,i].set(1)\n# \u6865\nA = A.at[4,5].set(1).at[5,4].set(1)\n\nD = jnp.diag(A.sum(axis=1))\nL = D - A\neigenvalues, eigenvectors = jnp.linalg.eigh(L)\n\n# \u4f7f\u7528\u524d3\u4e2a\u975e\u5e73\u51e1\u7279\u5f81\u5411\u91cf\u4f5c\u4e3a\u4f4d\u7f6e\u7f16\u7801\npe = eigenvectors[:, 1:4]\n\nprint(\"\u62c9\u666e\u62c9\u65af\u4f4d\u7f6e\u7f16\u7801:\")\nfor i in range(n):\n    group = \"\u56e21\" if i < 5 else \"\u56e22\"\n    bridge = \" (\u6865)\" if i in [4, 5] else \"\"\n    print(f\"  \u8282\u70b9 {i} ({group}{bridge}): {pe[i]}\")\n\nplt.scatter(pe[:5, 0], pe[:5, 1], c=\"#3498db\", s=80, label=\"\u56e21\")\nplt.scatter(pe[5:, 0], pe[5:, 1], c=\"#e74c3c\", s=80, label=\"\u56e22\")\nplt.scatter(pe[[4,5], 0], pe[[4,5], 1], c=\"black\", s=120, marker=\"*\",\n            label=\"\u6865\u8282\u70b9\", zorder=5)\nplt.legend(); plt.grid(True)\nplt.title(\"\u62c9\u666e\u62c9\u65af\u7279\u5f81\u5411\u91cf\u4f4d\u7f6e\u7f16\u7801\")\nplt.xlabel(\"\u7279\u5f81\u5411\u91cf 1\"); plt.ylabel(\"\u7279\u5f81\u5411\u91cf 2\")\nplt.show()\n

"},{"location":"chapter%2012%3A%20graph%20neural%20networks/05.%203d%20graph%20networks/","title":"3D\u56fe\u7f51\u7edc","text":"

3D\u56fe\u7f51\u7edc\u5c06GNN\u6269\u5c55\u5230\u5177\u6709\u7a7a\u95f4\u51e0\u4f55\u7684\u6570\u636e\uff0c\u5176\u4e2d\u5fc5\u987b\u6b63\u786e\u5904\u7406\u65cb\u8f6c\u548c\u5e73\u79fb\u3002\u672c\u7ae0\u6db5\u76d6\u51e0\u4f55\u56fe\u3001SE(3)/E(n)\u7b49\u53d8\u6027\u3001SchNet\u3001DimeNet\u3001EGNN\u3001\u5f20\u91cf\u573a\u7f51\u7edc\u4ee5\u53ca\u5206\u5b50\u6027\u8d28\u9884\u6d4b\u3001\u86cb\u767d\u8d28\u7ed3\u6784\u3001\u6750\u6599\u79d1\u5b66\u548c\u836f\u7269\u53d1\u73b0\u4e2d\u7684\u5e94\u7528\u2014\u2014\u4ece3D\u7269\u7406\u4e16\u754c\u4e2d\u5b66\u4e60\u7684\u67b6\u6784\u3002

"},{"location":"chapter%2012%3A%20graph%20neural%20networks/05.%203d%20graph%20networks/#_1","title":"\u51e0\u4f55\u56fe","text":""},{"location":"chapter%2012%3A%20graph%20neural%20networks/05.%203d%20graph%20networks/#se3-en","title":"SE(3) \u548c E(n) \u7b49\u53d8\u6027","text":"

\\[f(R\\mathbf{r}_1, R\\mathbf{r}_2, \\ldots) = f(\\mathbf{r}_1, \\mathbf{r}_2, \\ldots) \\quad \\text{\uff08\u4e0d\u53d8\u6027\uff09}\\] \\[\\mathbf{F}(R\\mathbf{r}_1, R\\mathbf{r}_2, \\ldots) = R \\cdot \\mathbf{F}(\\mathbf{r}_1, \\mathbf{r}_2, \\ldots) \\quad \\text{\uff08\u7b49\u53d8\u6027\uff09}\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/05.%203d%20graph%20networks/#schnet","title":"SchNet\uff1a\u57fa\u4e8e\u8ddd\u79bb\u7684\u6d88\u606f\u4f20\u9012","text":" \\[\\text{RBF}(d_{ij}) = \\left[\\exp\\left(-\\gamma_1 (d_{ij} - \\mu_1)^2\\right), \\ldots, \\exp\\left(-\\gamma_K (d_{ij} - \\mu_K)^2\\right)\\right]\\] \\[\\mathbf{m}_{j \\to i} = \\mathbf{h}_j \\odot W_{\\text{filter}}(\\text{RBF}(d_{ij}))\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/05.%203d%20graph%20networks/#dimenetspherenet","title":"DimeNet\u548cSphereNet\uff1a\u89d2\u5ea6\u548c\u4e8c\u9762\u89d2","text":" \\[\\mathbf{m}_{kj \\to ji} = f\\left(\\mathbf{m}_{kj}, d_{ji}, \\theta_{kji}\\right)\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/05.%203d%20graph%20networks/#engnnegnn","title":"E(n)\u7b49\u53d8GNN\uff08EGNN\uff09","text":" \\[\\mathbf{m}_{ij} = \\phi_e\\left(\\mathbf{h}_i, \\mathbf{h}_j, d_{ij}^2, a_{ij}\\right)\\] \\[\\mathbf{r}_i' = \\mathbf{r}_i + C \\sum_{j \\neq i} (\\mathbf{r}_i - \\mathbf{r}_j) \\cdot \\phi_r(\\mathbf{m}_{ij})\\] \\[\\mathbf{h}_i' = \\phi_h\\left(\\mathbf{h}_i, \\sum_j \\mathbf{m}_{ij}\\right)\\] "},{"location":"chapter%2012%3A%20graph%20neural%20networks/05.%203d%20graph%20networks/#_2","title":"\u5f20\u91cf\u573a\u7f51\u7edc\u4e0e\u9ad8\u9636\u8868\u793a","text":" \\[(\\mathbf{f}^{\\ell_1} \\otimes \\mathbf{f}^{\\ell_2})^{\\ell_{\\text{out}}} = \\sum_{m_1, m_2} C^{\\ell_{\\text{out}}, m_{\\text{out}}}_{\\ell_1, m_1, \\ell_2, m_2} \\cdot f^{\\ell_1}_{m_1} \\cdot f^{\\ell_2}_{m_2}\\] 
"},{"location":"chapter%2012%3A%20graph%20neural%20networks/05.%203d%20graph%20networks/#_3","title":"\u5e94\u7528","text":""},{"location":"chapter%2012%3A%20graph%20neural%20networks/05.%203d%20graph%20networks/#_4","title":"\u56fe\u751f\u6210","text":""},{"location":"chapter%2012%3A%20graph%20neural%20networks/05.%203d%20graph%20networks/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u6784\u5efa\u4e00\u4e2a\u4f7f\u7528\u539f\u5b50\u95f4\u8ddd\u79bb\u7684\u7b80\u5355\u4e0d\u53d83D\u6d88\u606f\u4f20\u9012\u5c42\u3002\u5c06\u5176\u5e94\u7528\u4e8e\u4e00\u4e2a\u5c0f\u5206\u5b50\uff08\u6c34\uff1aH-O-H\uff09\uff0c\u5e76\u9a8c\u8bc1\u8f93\u51fa\u5bf9\u65cb\u8f6c\u662f\u4e0d\u53d8\u7684\u3002

    import jax\nimport jax.numpy as jnp\n\n# \u6c34\u5206\u5b50\uff1aO\u5728\u539f\u70b9\uff0c\u4e24\u4e2aH\u539f\u5b50\npositions = jnp.array([[0.0, 0.0, 0.0],     # O\n                        [0.96, 0.0, 0.0],    # H1\n                        [-0.24, 0.93, 0.0]])  # H2\n\n# \u8282\u70b9\u7279\u5f81\uff1a[\u539f\u5b50\u5e8f\u6570]\nfeatures = jnp.array([[8.0], [1.0], [1.0]])\n\n# \u8ba1\u7b97\u6210\u5bf9\u8ddd\u79bb\uff08\u4e0d\u53d8\u7684\uff09\ndef pairwise_distances(pos):\n    diff = pos[:, None, :] - pos[None, :, :]\n    return jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-8)\n\n# \u7b80\u5355\u7684\u57fa\u4e8e\u8ddd\u79bb\u7684\u6d88\u606f\u4f20\u9012\ndef invariant_message_pass(features, positions):\n    dists = pairwise_distances(positions)\n    # \u5177\u67094\u4e2a\u4e2d\u5fc3\u7684RBF\u6269\u5c55\n    centres = jnp.array([0.5, 1.0, 1.5, 2.0])\n    rbf = jnp.exp(-5.0 * (dists[:, :, None] - centres[None, None, :]) ** 2)\n\n    # \u6d88\u606f\uff1a\u7531\u8ddd\u79bb\u76f8\u5173\u6ee4\u6ce2\u5668\u52a0\u6743\u7684\u7279\u5f81\n    messages = jnp.einsum(\"ij,jd->id\", rbf.sum(axis=-1), features)\n    return messages\n\noutput1 = invariant_message_pass(features, positions)\n\n# \u5c06\u5206\u5b50\u7ed5z\u8f74\u65cb\u8f6c90\u5ea6\nR = jnp.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]], dtype=float)\nrotated_positions = (R @ positions.T).T\n\noutput2 = invariant_message_pass(features, rotated_positions)\n\nprint(f\"\u539f\u59cb\u8f93\u51fa:\\n{output1}\")\nprint(f\"\\n\u65cb\u8f6c\u540e\u8f93\u51fa:\\n{output2}\")\nprint(f\"\\n\u4e0d\u53d8\u6027: {jnp.allclose(output1, output2, atol=1e-5)}\")\n

  2. \u8ba1\u7b97\u4e09\u4e2a\u539f\u5b50\u4e4b\u95f4\u7684\u952e\u89d2\uff0c\u5e76\u9a8c\u8bc1\u5176\u5bf9\u65cb\u8f6c\u4e0d\u53d8\u3002

    import jax.numpy as jnp\n\ndef bond_angle(r_i, r_j, r_k):\n    \"\"\"\u8282\u70b9j\u5904\u8fb9j->i\u548cj->k\u4e4b\u95f4\u7684\u89d2\u5ea6\u3002\"\"\"\n    v1 = r_i - r_j\n    v2 = r_k - r_j\n    cos_angle = jnp.dot(v1, v2) / (jnp.linalg.norm(v1) * jnp.linalg.norm(v2))\n    return jnp.arccos(jnp.clip(cos_angle, -1, 1))\n\n# \u4e09\u4e2a\u539f\u5b50\nr1 = jnp.array([1.0, 0.0, 0.0])\nr2 = jnp.array([0.0, 0.0, 0.0])\nr3 = jnp.array([0.0, 1.0, 0.0])\n\nangle_original = bond_angle(r1, r2, r3)\nprint(f\"\u539f\u59cb\u89d2\u5ea6: {jnp.degrees(angle_original):.1f}\u00b0\")\n\n# \u5e94\u7528\u968f\u673a\u65cb\u8f6c\nR = jnp.array([[0.36, 0.48, -0.80],\n               [-0.80, 0.60, 0.00],\n               [0.48, 0.64, 0.60]])\nr1_rot, r2_rot, r3_rot = R @ r1, R @ r2, R @ r3\n\nangle_rotated = bond_angle(r1_rot, r2_rot, r3_rot)\nprint(f\"\u65cb\u8f6c\u540e\u89d2\u5ea6:  {jnp.degrees(angle_rotated):.1f}\u00b0\")\nprint(f\"\u4e0d\u53d8\u6027: {jnp.allclose(angle_original, angle_rotated, atol=1e-4)}\")\n

  3. \u6f14\u793a\u7b49\u53d8\u4f4d\u7f6e\u66f4\u65b0\uff08EGNN\u98ce\u683c\uff09\u3002\u4f7f\u7528\u8ddd\u79bb\u52a0\u6743\u7684\u76f8\u5bf9\u5411\u91cf\u66f4\u65b0\u8282\u70b9\u4f4d\u7f6e\uff0c\u5e76\u9a8c\u8bc1\u7b49\u53d8\u6027\u3002

    import jax\nimport jax.numpy as jnp\n\ndef egnn_position_update(positions, features):\n    \"\"\"\u7b80\u5355\u7684EGNN\u98ce\u683c\u7b49\u53d8\u4f4d\u7f6e\u66f4\u65b0\u3002\"\"\"\n    n = positions.shape[0]\n    new_positions = jnp.zeros_like(positions)\n\n    for i in range(n):\n        shift = jnp.zeros(3)\n        for j in range(n):\n            if i != j:\n                r_ij = positions[i] - positions[j]\n                d_ij = jnp.linalg.norm(r_ij)\n                # \u57fa\u4e8e\u8ddd\u79bb\u7684\u6743\u91cd\uff08\u7b80\u5355\uff1a\u53cd\u6bd4\u8ddd\u79bb\uff09\n                weight = 1.0 / (d_ij + 1.0)\n                # \u6309\u7279\u5f81\u76f8\u4f3c\u5ea6\u7f29\u653e\n                feat_sim = jnp.dot(features[i], features[j])\n                shift = shift + weight * feat_sim * r_ij\n        new_positions = new_positions.at[i].set(positions[i] + 0.1 * shift)\n\n    return new_positions\n\n# 3\u4e2a\u539f\u5b50\npos = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])\nfeat = jnp.array([[1.0, 0.5], [0.5, 1.0], [0.8, 0.3]])\n\n# \u66f4\u65b0\u4f4d\u7f6e\npos_new = egnn_position_update(pos, feat)\n\n# \u73b0\u5728\u65cb\u8f6c\u8f93\u5165\u3001\u66f4\u65b0\uff0c\u5e76\u68c0\u67e5\u8f93\u51fa\u662f\u5426\u4e00\u81f4\u5730\u65cb\u8f6c\nR = jnp.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])\npos_rot = (R @ pos.T).T\npos_new_from_rot = egnn_position_update(pos_rot, feat)\n\n# \u5e94\u4e0e\u65cb\u8f6c\u539f\u59cb\u8f93\u51fa\u76f8\u540c\npos_new_then_rot = (R @ pos_new.T).T\n\nprint(f\"\u5148\u66f4\u65b0\u518d\u65cb\u8f6c:\\n{jnp.round(pos_new_then_rot, 4)}\")\nprint(f\"\\n\u5148\u65cb\u8f6c\u518d\u66f4\u65b0:\\n{jnp.round(pos_new_from_rot, 4)}\")\nprint(f\"\\n\u7b49\u53d8\u6027: {jnp.allclose(pos_new_then_rot, pos_new_from_rot, atol=1e-4)}\")\n

"},{"location":"chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/","title":"\u79bb\u6563\u6570\u5b66","text":"

\u79bb\u6563\u6570\u5b66\u662f\u5173\u4e8e\u53ef\u6570\u3001\u5206\u79bb\u7ed3\u6784\u7684\u6570\u5b66\uff0c\u662f\u8ba1\u7b97\u6784\u5efa\u7684\u57fa\u7840\u3002\u672c\u6587\u6db5\u76d6\u547d\u9898\u903b\u8f91\u4e0e\u8c13\u8bcd\u903b\u8f91\u3001\u8bc1\u660e\u6280\u5de7\u3001\u96c6\u5408\u3001\u5173\u7cfb\u3001\u51fd\u6570\u3001\u56fe\u8bba\u57fa\u7840\u4ee5\u53ca\u9012\u63a8\u5173\u7cfb\u3002

"},{"location":"chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/#_2","title":"\u547d\u9898\u903b\u8f91","text":" \\(p\\) \\(q\\) \\(p \\wedge q\\) \\(p \\vee q\\) \\(p \\to q\\) T T T T T T F F T F F T F T T F F F F T "},{"location":"chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/#_3","title":"\u8c13\u8bcd\u903b\u8f91\u4e0e\u91cf\u8bcd","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/#_4","title":"\u8bc1\u660e\u6280\u5de7","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/#_5","title":"\u96c6\u5408","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/#_6","title":"\u5173\u7cfb","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/#_7","title":"\u51fd\u6570","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/#_8","title":"\u56fe\u8bba\u57fa\u7840","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/#_9","title":"\u9012\u63a8\u5173\u7cfb","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/#_10","title":"\u53ef\u8ba1\u7b97\u6027","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/#_11","title":"\u590d\u6742\u5ea6\u7406\u8bba","text":" "},{"location":"chapter%2013%3A%20computing%20and%20OS/01.%20discrete%20maths/#colab","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216\u7b14\u8bb0\u672c\uff09","text":"
  1. \u6784\u5efa\u4e00\u4e2a\u771f\u503c\u8868\u751f\u6210\u5668\u3002\u7ed9\u5b9a\u4e00\u4e2a\u903b\u8f91\u8868\u8fbe\u5f0f\uff0c\u679a\u4e3e\u6240\u6709\u8f93\u5165\u7ec4\u5408\u5e76\u8ba1\u7b97\u7ed3\u679c\u3002

    import itertools\n\ndef truth_table(n_vars, expr_fn):\n    \"\"\"\u4e3a\u4e00\u4e2an_vars\u4e2a\u53d8\u91cf\u7684\u5e03\u5c14\u51fd\u6570\u751f\u6210\u771f\u503c\u8868\u3002\"\"\"\n    headers = [f\"p{i}\" for i in range(n_vars)]\n    print(\" | \".join(headers + [\"result\"]))\n    print(\"-\" * (len(headers) * 4 + 10))\n    for vals in itertools.product([False, True], repeat=n_vars):\n        result = expr_fn(*vals)\n        row = [str(v)[0] for v in vals] + [str(result)[0]]\n        print(\" | \".join(f\"{r:>2}\" for r in row))\n\n# \u5fb7\u6469\u6839\u5b9a\u5f8b\uff1aNOT(p AND q) == (NOT p) OR (NOT q)\nprint(\"\u5fb7\u6469\u6839\u5b9a\u5f8b\u9a8c\u8bc1\uff1a\")\ntruth_table(2, lambda p, q: (not (p and q)) == ((not p) or (not q)))\n

  2. \u901a\u8fc7\u5f52\u7eb3\u6cd5\u8bc1\u660e\u6c42\u548c\u516c\u5f0f\u2014\u2014\u5bf9\u591a\u4e2a\u503c\u8fdb\u884c\u6570\u503c\u9a8c\u8bc1\uff0c\u7136\u540e\u5b9e\u73b0\u5c01\u95ed\u5f62\u5f0f\u89e3\u3002

    import jax.numpy as jnp\n\n# \u9a8c\u8bc1\u6c42\u548c\u516c\u5f0f\uff1asum(1..n) = n(n+1)/2\nfor n in [1, 5, 10, 100, 1000, 10000]:\n    brute = sum(range(1, n + 1))\n    formula = n * (n + 1) // 2\n    print(f\"n={n:5d}  sum={brute:>10d}  formula={formula:>10d}  match={brute == formula}\")\n

  3. \u4f7f\u7528\u4e3b\u5b9a\u7406\u6c42\u89e3\u5f52\u5e76\u6392\u5e8f\u9012\u63a8\u5173\u7cfb\uff0c\u5e76\u901a\u8fc7\u8ba1\u6570\u64cd\u4f5c\u8fdb\u884c\u7ecf\u9a8c\u9a8c\u8bc1\u3002

    import jax.numpy as jnp\n\ndef merge_sort_ops(n):\n    \"\"\"\u7edf\u8ba1\u5f52\u5e76\u6392\u5e8f\u4e2d\u7684\u6bd4\u8f83\u6b21\u6570\uff08\u9012\u63a8\uff1aT(n) = 2T(n/2) + n\uff09\u3002\"\"\"\n    if n <= 1:\n        return 0\n    half = n // 2\n    return merge_sort_ops(half) + merge_sort_ops(n - half) + n\n\nfor n in [8, 64, 512, 4096, 32768]:\n    ops = merge_sort_ops(n)\n    predicted = n * jnp.log2(n)\n    ratio = ops / predicted\n    print(f\"n={n:5d}  ops={ops:>10d}  n log n={int(predicted):>10d}  ratio={ratio:.3f}\")\n

"},{"location":"chapter%2013%3A%20computing%20and%20OS/02.%20computer%20architecture/","title":"\u8ba1\u7b97\u673a\u4f53\u7cfb\u7ed3\u6784","text":"

\u8ba1\u7b97\u673a\u4f53\u7cfb\u7ed3\u6784\u662f\u5173\u4e8e\u5982\u4f55\u6784\u5efa\u6267\u884c\u6307\u4ee4\u7684\u673a\u5668\u3002\u672c\u6587\u6db5\u76d6\u6570\u5236\u3001\u903b\u8f91\u95e8\u3001CPU\u8bbe\u8ba1\u3001\u6307\u4ee4\u96c6\u67b6\u6784\u3001\u6d41\u6c34\u7ebf\u3001\u5b58\u50a8\u5668\u5c42\u6b21\u7ed3\u6784\u548c\u865a\u62df\u5185\u5b58\u2014\u2014\u6bcf\u4e2a\u7a0b\u5e8f\u3001\u6846\u67b6\u548cAI\u6a21\u578b\u6700\u7ec8\u8fd0\u884c\u5176\u4e0a\u7684\u786c\u4ef6\u57fa\u7840\u3002

"},{"location":"chapter%2013%3A%20computing%20and%20OS/02.%20computer%20architecture/#_2","title":"\u6570\u5236","text":"

- **float32**\uff08\u5355\u7cbe\u5ea6\uff09\uff1a1\u4e2a\u7b26\u53f7 + 8\u4e2a\u6307\u6570 + 23\u4e2a\u5c3e\u6570 = 32\u4f4d\u3002\u8303\u56f4\uff1a$\\approx \\pm 3.4 \\times 10^{38}$\uff0c\u7cbe\u5ea6\uff1a$\\approx 7$\u4f4d\u5341\u8fdb\u5236\u6570\u5b57\u3002\n- **float64**\uff08\u53cc\u7cbe\u5ea6\uff09\uff1a1\u4e2a\u7b26\u53f7 + 11\u4e2a\u6307\u6570 + 52\u4e2a\u5c3e\u6570 = 64\u4f4d\u3002\u8303\u56f4\uff1a$\\approx \\pm 1.8 \\times 10^{308}$\uff0c\u7cbe\u5ea6\uff1a$\\approx 15$\u4f4d\u5341\u8fdb\u5236\u6570\u5b57\u3002\n- **float16**\uff08\u534a\u7cbe\u5ea6\uff09\uff1a1 + 5 + 10 = 16\u4f4d\u3002\u8303\u56f4\u548c\u7cbe\u5ea6\u6709\u9650\uff0c\u4f46\u4f7f\u7528\u4e00\u534a\u7684\u5185\u5b58\u548c\u5e26\u5bbd\u3002\u5e7f\u6cdb\u7528\u4e8eML\u8bad\u7ec3\uff08\u6df7\u5408\u7cbe\u5ea6\uff0c\u7b2c6\u7ae0\uff09\u3002\n- **bfloat16**\uff1a1 + 8 + 7 = 16\u4f4d\u3002\u4e0efloat32\u76f8\u540c\u7684\u6307\u6570\u8303\u56f4\u4f46\u7cbe\u5ea6\u66f4\u4f4e\u3002\u7531Google\u4e13\u95e8\u4e3aML\u8bbe\u8ba1\uff1a\u5b8c\u6574\u7684\u6307\u6570\u8303\u56f4\u53ef\u9632\u6b62\u8bad\u7ec3\u671f\u95f4\u6ea2\u51fa\uff0c\u964d\u4f4e\u7684\u7cbe\u5ea6\u5bf9\u68af\u5ea6\u66f4\u65b0\u662f\u53ef\u4ee5\u63a5\u53d7\u7684\u3002\n
"},{"location":"chapter%2013%3A%20computing%20and%20OS/02.%20computer%20architecture/#_3","title":"\u903b\u8f91\u95e8","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/02.%20computer%20architecture/#cpu","title":"CPU\u67b6\u6784","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/02.%20computer%20architecture/#_4","title":"\u6307\u4ee4\u96c6\u67b6\u6784","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/02.%20computer%20architecture/#_5","title":"\u6d41\u6c34\u7ebf","text":" "},{"location":"chapter%2013%3A%20computing%20and%20OS/02.%20computer%20architecture/#_6","title":"\u5b58\u50a8\u5668\u5c42\u6b21\u7ed3\u6784","text":" "},{"location":"chapter%2013%3A%20computing%20and%20OS/02.%20computer%20architecture/#_7","title":"\u865a\u62df\u5185\u5b58","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/02.%20computer%20architecture/#iodma","title":"I/O\u3001\u4e2d\u65ad\u548cDMA","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/02.%20computer%20architecture/#colab","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216\u7b14\u8bb0\u672c\uff09","text":"
  1. \u63a2\u7d22IEEE 754\u6d6e\u70b9\u6570\u8868\u793a\u3002\u5c06\u6d6e\u70b9\u6570\u8f6c\u6362\u4e3a\u4e8c\u8fdb\u5236\u8868\u793a\uff0c\u89c2\u5bdf\u7b26\u53f7\u3001\u6307\u6570\u548c\u5c3e\u6570\u5b57\u6bb5\u3002

    import struct\n\ndef float_to_bits(f):\n    \"\"\"\u663e\u793afloat32\u7684IEEE 754\u4e8c\u8fdb\u5236\u8868\u793a\u3002\"\"\"\n    packed = struct.pack('>f', f)\n    bits = ''.join(f'{byte:08b}' for byte in packed)\n    sign = bits[0]\n    exponent = bits[1:9]\n    mantissa = bits[9:]\n    return sign, exponent, mantissa\n\nfor val in [1.0, -1.0, 0.1, 0.5, 3.14, float('inf'), float('nan')]:\n    s, e, m = float_to_bits(val)\n    print(f\"{val:>10}  sign={s}  exp={e} ({int(e, 2) - 127:>4d})  mantissa={m[:10]}...\")\n

  2. \u6a21\u62df\u76f4\u63a5\u6620\u5c04\u7f13\u5b58\u3002\u8ddf\u8e2a\u4e00\u7cfb\u5217\u5185\u5b58\u8bbf\u95ee\u7684\u547d\u4e2d\u4e0e\u672a\u547d\u4e2d\u3002

    def simulate_cache(accesses, cache_size=8, block_size=1):\n    \"\"\"\u6a21\u62df\u76f4\u63a5\u6620\u5c04\u7f13\u5b58\u3002\"\"\"\n    cache = [None] * cache_size\n    hits, misses = 0, 0\n\n    for addr in accesses:\n        cache_line = addr % cache_size\n        if cache[cache_line] == addr:\n            hits += 1\n            status = \"HIT \"\n        else:\n            misses += 1\n            cache[cache_line] = addr\n            status = \"MISS\"\n        print(f\"  Access {addr:3d} \u2192 line {cache_line}: {status}\")\n\n    print(f\"\\nHits: {hits}, Misses: {misses}, Hit rate: {hits/(hits+misses):.1%}\")\n\n# \u987a\u5e8f\u8bbf\u95ee\uff08\u826f\u597d\u7684\u5c40\u90e8\u6027\uff09\nprint(\"\u987a\u5e8f\u8bbf\u95ee\uff1a\")\nsimulate_cache([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3])\n\n# \u8de8\u6b65\u8bbf\u95ee\uff08\u51b2\u7a81\u672a\u547d\u4e2d\uff09\nprint(\"\\n\u8de8\u6b65\u8bbf\u95ee\uff08stride = cache size\uff09\uff1a\")\nsimulate_cache([0, 8, 0, 8, 0, 8])\n

  3. \u6f14\u793a\u4e3a\u4ec0\u4e48\u6d6e\u70b9\u7b97\u672f\u4e0d\u6ee1\u8db3\u7ed3\u5408\u5f8b\u3002\u5c55\u793a \\((a + b) + c \\neq a + (b + c)\\) \u7684\u60c5\u51b5\u3002

    import jax.numpy as jnp\n\na = jnp.float32(1e8)\nb = jnp.float32(-1e8)\nc = jnp.float32(1.0)\n\nleft = (a + b) + c   # (1e8 + (-1e8)) + 1\nright = a + (b + c)  # 1e8 + ((-1e8) + 1)\n\nprint(f\"(a + b) + c = {left}\")   # \u5e94\u4e3a 1.0\nprint(f\"a + (b + c) = {right}\")  # \u4e22\u5931\u4e86 1.0\uff0c\u7ed3\u679c\u4e3a 0.0\nprint(f\"Equal: {left == right}\")\nprint(f\"\\n\u5f53 1.0 \u52a0\u5230 -1e8 \u4e0a\u65f6\u88ab\u4e22\u5931\uff0c\u56e0\u4e3a float32 \u53ea\u6709\u7ea6 7 \u4f4d\u7cbe\u5ea6\")\n

"},{"location":"chapter%2013%3A%20computing%20and%20OS/03.%20operating%20systems/","title":"\u64cd\u4f5c\u7cfb\u7edf","text":"

\u64cd\u4f5c\u7cfb\u7edf\u662f\u786c\u4ef6\u4e0e\u5e94\u7528\u7a0b\u5e8f\u4e4b\u95f4\u7684\u8f6f\u4ef6\u5c42\uff0c\u8d1f\u8d23\u7ba1\u7406\u8d44\u6e90\u3001\u63d0\u4f9b\u62bd\u8c61\u5e76\u5b9e\u65bd\u9694\u79bb\u3002\u672c\u6587\u6db5\u76d6\u64cd\u4f5c\u7cfb\u7edf\u7684\u529f\u80fd\u3001\u8fdb\u7a0b\u3001\u7ebf\u7a0b\u3001CPU\u8c03\u5ea6\u3001\u5185\u5b58\u7ba1\u7406\u3001\u6587\u4ef6\u7cfb\u7edf\u548c\u7cfb\u7edf\u8c03\u7528\u3002

"},{"location":"chapter%2013%3A%20computing%20and%20OS/03.%20operating%20systems/#_2","title":"\u64cd\u4f5c\u7cfb\u7edf\u505a\u4ec0\u4e48","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/03.%20operating%20systems/#_3","title":"\u8fdb\u7a0b","text":"

"},{"location":"chapter%2013%3A%20computing%20and%20OS/03.%20operating%20systems/#_4","title":"\u7ebf\u7a0b","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/03.%20operating%20systems/#cpu","title":"CPU\u8c03\u5ea6","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/03.%20operating%20systems/#_5","title":"\u5185\u5b58\u7ba1\u7406","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/03.%20operating%20systems/#_6","title":"\u6587\u4ef6\u7cfb\u7edf","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/03.%20operating%20systems/#_7","title":"\u7cfb\u7edf\u8c03\u7528\u4e0e\u5185\u6838\u6a21\u5f0f","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/03.%20operating%20systems/#_8","title":"\u7f51\u7edc\u57fa\u7840","text":" "},{"location":"chapter%2013%3A%20computing%20and%20OS/03.%20operating%20systems/#_9","title":"\u865a\u62df\u5316\u4e0e\u5bb9\u5668","text":"

"},{"location":"chapter%2013%3A%20computing%20and%20OS/03.%20operating%20systems/#_10","title":"\u5b89\u5168\u57fa\u7840","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/03.%20operating%20systems/#colab","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216\u7b14\u8bb0\u672c\uff09","text":"
  1. \u63a2\u7d22\u8fdb\u7a0b\u521b\u5efa\u3002\u4f7f\u7528Python\u7684 os.fork()\uff08\u4ec5Unix\uff09\u521b\u5efa\u4e00\u4e2a\u5b50\u8fdb\u7a0b\uff0c\u5e76\u89c2\u5bdf\u7236\u8fdb\u7a0b\u548c\u5b50\u8fdb\u7a0b\u5982\u4f55\u4ece\u540c\u4e00\u70b9\u7ee7\u7eed\u6267\u884c\u3002

    import os\n\npid = os.fork()\n\nif pid == 0:\n    # \u5b50\u8fdb\u7a0b\n    print(f\"Child: my PID is {os.getpid()}, parent PID is {os.getppid()}\")\nelse:\n    # \u7236\u8fdb\u7a0b\n    print(f\"Parent: my PID is {os.getpid()}, child PID is {pid}\")\n    os.wait()  # \u7b49\u5f85\u5b50\u8fdb\u7a0b\u7ed3\u675f\n

  2. \u6a21\u62df\u8f6e\u8f6c\u8c03\u5ea6\u3002\u7ed9\u5b9a\u4e00\u4e2a\u5e26\u6709\u6267\u884c\u65f6\u95f4\u7684\u8fdb\u7a0b\u5217\u8868\uff0c\u6a21\u62df\u8c03\u5ea6\u5e76\u8ba1\u7b97\u5e73\u5747\u7b49\u5f85\u65f6\u95f4\u3002

    def round_robin(processes, quantum=3):\n    \"\"\"\u6a21\u62df\u8f6e\u8f6c\u8c03\u5ea6\u3002\n    processes: (name, burst_time) \u5143\u7ec4\u5217\u8868\u3002\n    \"\"\"\n    queue = [(name, burst) for name, burst in processes]  # (name, remaining)\n    bursts = dict(processes)\n    waits = {}\n    time = 0\n    log = []\n\n    while queue:\n        name, remaining = queue.pop(0)\n        run_time = min(quantum, remaining)\n        log.append(f\"  t={time:3d}: {name} runs for {run_time} (remaining: {remaining - run_time})\")\n        time += run_time\n        remaining -= run_time\n\n        if remaining > 0:\n            queue.append((name, remaining))\n        else:\n            waits[name] = time - bursts[name]  # \u7b49\u5f85\u65f6\u95f4 = \u5468\u8f6c\u65f6\u95f4 - \u6267\u884c\u65f6\u95f4\n            log.append(f\"  t={time:3d}: {name} DONE (turnaround: {time})\")\n\n    for line in log:\n        print(line)\n    if waits:\n        print(f\"\\nAverage waiting time: {sum(waits.values()) / len(waits):.2f}\")\n\nprint(\"\u8f6e\u8f6c\u8c03\u5ea6 (quantum=3)\uff1a\")\nround_robin([(\"P1\", 10), (\"P2\", 4), (\"P3\", 6)], quantum=3)\n

  3. \u6a21\u62dfLRU\u9875\u9762\u7f6e\u6362\u3002\u7ed9\u5b9a\u4e00\u4e2a\u9875\u9762\u8bbf\u95ee\u5e8f\u5217\u548c\u56fa\u5b9a\u6570\u91cf\u7684\u5e27\uff0c\u7edf\u8ba1\u7f3a\u9875\u6b21\u6570\u3002

    def lru_page_replacement(pages, n_frames):\n    \"\"\"\u6a21\u62dfLRU\u9875\u9762\u7f6e\u6362\u3002\"\"\"\n    frames = []\n    faults = 0\n\n    for page in pages:\n        if page in frames:\n            frames.remove(page)\n            frames.append(page)  # \u79fb\u52a8\u5230\u6700\u8fd1\u4f7f\u7528\n            status = \"HIT \"\n        else:\n            faults += 1\n            if len(frames) >= n_frames:\n                evicted = frames.pop(0)  # \u79fb\u9664\u6700\u8fd1\u6700\u5c11\u4f7f\u7528\n                status = f\"MISS (evict {evicted})\"\n            else:\n                status = \"MISS (cold)\"\n            frames.append(page)\n        print(f\"  Page {page}: {status}  frames={frames}\")\n\n    print(f\"\\nTotal faults: {faults}/{len(pages)} ({faults/len(pages):.0%})\")\n\nprint(\"LRU with 3 frames:\")\nlru_page_replacement([1, 2, 3, 4, 1, 2, 5, 1, 2, 3, 4, 5], n_frames=3)\n

"},{"location":"chapter%2013%3A%20computing%20and%20OS/04.%20concurrency%20and%20parallelism/","title":"\u5e76\u53d1\u4e0e\u5e76\u884c","text":"

\u5e76\u53d1\u4e0e\u5e76\u884c\u662f\u7a0b\u5e8f\u540c\u65f6\u5904\u7406\u591a\u4ef6\u4e8b\u60c5\u7684\u65b9\u5f0f\u3002\u672c\u6587\u6db5\u76d6\u5e76\u53d1\u4e0e\u5e76\u884c\u7684\u533a\u522b\u3001\u540c\u6b65\u539f\u8bed\u3001\u7ecf\u5178\u5e76\u53d1\u95ee\u9898\u3001\u6b7b\u9501\u3001\u65e0\u9501\u6570\u636e\u7ed3\u6784\u3001\u5e76\u884c\u7f16\u7a0b\u6a21\u578b\u3001\u5f02\u6b65\u7f16\u7a0b\u548c\u6269\u5c55\u5b9a\u5f8b\u2014\u2014\u8fd9\u4e9b\u6982\u5ff5\u652f\u6491\u7740\u591a\u7ebf\u7a0b\u670d\u52a1\u5668\u3001\u5206\u5e03\u5f0f\u8bad\u7ec3\u548c\u6bcf\u4e00\u4e2a\u73b0\u4ee3\u5e94\u7528\u7a0b\u5e8f\u3002

"},{"location":"chapter%2013%3A%20computing%20and%20OS/04.%20concurrency%20and%20parallelism/#vs","title":"\u5e76\u53d1 vs \u5e76\u884c","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/04.%20concurrency%20and%20parallelism/#_2","title":"\u540c\u6b65\u539f\u8bed","text":"
lock.acquire()\ncounter += 1      # \u4e00\u6b21\u53ea\u6709\u4e00\u4e2a\u7ebf\u7a0b\u5728\u6b64\nlock.release()\n
"},{"location":"chapter%2013%3A%20computing%20and%20OS/04.%20concurrency%20and%20parallelism/#_3","title":"\u7ecf\u5178\u5e76\u53d1\u95ee\u9898","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/04.%20concurrency%20and%20parallelism/#_4","title":"\u6b7b\u9501","text":" "},{"location":"chapter%2013%3A%20computing%20and%20OS/04.%20concurrency%20and%20parallelism/#_5","title":"\u65e0\u9501\u548c\u514d\u7b49\u5f85\u6570\u636e\u7ed3\u6784","text":"
CAS(address, expected, new_value):\n    if *address == expected:\n        *address = new_value\n        return true\n    else:\n        return false\n
"},{"location":"chapter%2013%3A%20computing%20and%20OS/04.%20concurrency%20and%20parallelism/#_6","title":"\u5e76\u884c\u7f16\u7a0b\u6a21\u578b","text":"
#pragma omp parallel for\nfor (int i = 0; i < n; i++) {\n    result[i] = compute(data[i]);\n}\n
MPI_Send(data, count, MPI_FLOAT, dest, tag, MPI_COMM_WORLD);\nMPI_Recv(data, count, MPI_FLOAT, src, tag, MPI_COMM_WORLD, &status);\n
"},{"location":"chapter%2013%3A%20computing%20and%20OS/04.%20concurrency%20and%20parallelism/#_7","title":"\u5f02\u6b65\u4e0e\u4e8b\u4ef6\u9a71\u52a8\u7f16\u7a0b","text":"
async def fetch_data(url):\n    response = await http_get(url)  # \u5728\u6b64\u6682\u505c\uff0c\u4e8b\u4ef6\u5faa\u73af\u8fd0\u884c\u5176\u4ed6\u4efb\u52a1\n    return process(response)         # \u54cd\u5e94\u5230\u8fbe\u65f6\u6062\u590d\n
"},{"location":"chapter%2013%3A%20computing%20and%20OS/04.%20concurrency%20and%20parallelism/#_8","title":"\u6269\u5c55\u5b9a\u5f8b","text":" \\[\\text{\u52a0\u901f\u6bd4}(n) = \\frac{1}{(1-p) + \\frac{p}{n}}\\] \\[\\text{\u52a0\u901f\u6bd4}(n) = 1 - p + p \\cdot n\\] "},{"location":"chapter%2013%3A%20computing%20and%20OS/04.%20concurrency%20and%20parallelism/#colab","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216\u7b14\u8bb0\u672c\uff09","text":"
  1. \u6f14\u793a\u7ade\u6001\u6761\u4ef6\u3002\u4e24\u4e2a\u7ebf\u7a0b\u5728\u6ca1\u6709\u540c\u6b65\u7684\u60c5\u51b5\u4e0b\u589e\u52a0\u4e00\u4e2a\u5171\u4eab\u8ba1\u6570\u5668\uff0c\u89c2\u5bdf\u4e22\u5931\u7684\u66f4\u65b0\u3002

    import threading\n\ncounter = 0\n\ndef increment(n):\n    global counter\n    for _ in range(n):\n        counter += 1  # \u4e0d\u662f\u539f\u5b50\u7684\uff1a\u8bfb\u3001\u52a0\u3001\u5199\n\nthreads = [threading.Thread(target=increment, args=(100000,)) for _ in range(4)]\nfor t in threads: t.start()\nfor t in threads: t.join()\n\nprint(f\"Expected: {4 * 100000}\")\nprint(f\"Actual:   {counter}\")\nprint(f\"Lost updates: {4 * 100000 - counter}\")\n

  2. \u7528\u9501\u4fee\u590d\u7ade\u6001\u6761\u4ef6\u5e76\u6d4b\u91cf\u5f00\u9500\u3002

    import threading\nimport time\n\nlock = threading.Lock()\ncounter = 0\n\ndef increment_locked(n):\n    global counter\n    for _ in range(n):\n        with lock:\n            counter += 1\n\nstart = time.time()\nthreads = [threading.Thread(target=increment_locked, args=(100000,)) for _ in range(4)]\nfor t in threads: t.start()\nfor t in threads: t.join()\nelapsed = time.time() - start\n\nprint(f\"Counter: {counter} (correct: {4 * 100000})\")\nprint(f\"Time with lock: {elapsed:.3f}s\")\n

  3. \u53ef\u89c6\u5316\u963f\u59c6\u8fbe\u5c14\u5b9a\u5f8b\u3002\u7ed8\u5236\u4e0d\u540c\u5e76\u884c\u6bd4\u4f8b\u4e0b\u52a0\u901f\u6bd4\u4e0e\u5904\u7406\u5668\u6570\u91cf\u7684\u5173\u7cfb\u56fe\u3002

    import jax.numpy as jnp\nimport matplotlib.pyplot as plt\n\nn_procs = jnp.arange(1, 65)\n\nfor p, color in [(0.5, \"#e74c3c\"), (0.9, \"#f39c12\"), (0.95, \"#27ae60\"), (0.99, \"#3498db\")]:\n    speedup = 1 / ((1 - p) + p / n_procs)\n    plt.plot(n_procs, speedup, color=color, linewidth=2, label=f\"p={p}\")\n    # \u6700\u5927\u52a0\u901f\u6bd4\u7ebf\n    plt.axhline(1 / (1 - p), color=color, linestyle=\"--\", alpha=0.3)\n\nplt.xlabel(\"\u5904\u7406\u5668\u6570\u91cf\")\nplt.ylabel(\"\u52a0\u901f\u6bd4\")\nplt.title(\"\u963f\u59c6\u8fbe\u5c14\u5b9a\u5f8b\uff1a\u4e32\u884c\u6bd4\u4f8b\u9650\u5236\u52a0\u901f\u6bd4\")\nplt.legend()\nplt.grid(True)\nplt.show()\n

"},{"location":"chapter%2013%3A%20computing%20and%20OS/05.%20programming%20languages/","title":"\u7f16\u7a0b\u8bed\u8a00","text":"

\u7f16\u7a0b\u8bed\u8a00\u662f\u4eba\u7c7b\u610f\u56fe\u4e0e\u673a\u5668\u6267\u884c\u4e4b\u95f4\u7684\u63a5\u53e3\u3002\u672c\u6587\u6db5\u76d6\u8bed\u8a00\u8303\u5f0f\u3001\u7c7b\u578b\u7cfb\u7edf\u3001\u5185\u5b58\u7ba1\u7406\u7b56\u7565\u3001\u7f16\u8bd1\u6d41\u6c34\u7ebf\u3001\u89e3\u91ca\u4e0eJIT\u7f16\u8bd1\u3001\u5173\u952e\u8bed\u8a00\u7279\u6027\u3001\u9886\u57df\u7279\u5b9a\u8bed\u8a00\u4ee5\u53ca\u8bbe\u8ba1\u6743\u8861\u3002

"},{"location":"chapter%2013%3A%20computing%20and%20OS/05.%20programming%20languages/#_2","title":"\u8bed\u8a00\u8303\u5f0f","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/05.%20programming%20languages/#_3","title":"\u7c7b\u578b\u7cfb\u7edf","text":"
let x: i32 = 5;     // Rust\uff1ax\u662f\u4e00\u4e2a32\u4f4d\u6574\u6570\nlet y: f64 = 3.14;  // y\u662f\u4e00\u4e2a64\u4f4d\u6d6e\u70b9\u6570\n// let z = x + y;    // \u7f16\u8bd1\u9519\u8bef\uff1a\u4e0d\u80fd\u52a0 i32 \u548c f64\n
x = 5       # x\u662f\u4e00\u4e2aint\uff08\u76ee\u524d\uff09\nx = \"hello\" # \u73b0\u5728x\u662f\u4e00\u4e2a\u5b57\u7b26\u4e32\u2014\u2014\u6ca1\u6709\u9519\u8bef\n
let x = 5;        // \u7f16\u8bd1\u5668\u63a8\u65ad\uff1ai32\nlet y = x + 3.0;  // \u7f16\u8bd1\u9519\u8bef\uff1a\u6df7\u5408\u7c7b\u578b\uff0c\u5373\u4f7f\u6709\u63a8\u65ad\n
fn largest<T: PartialOrd>(list: &[T]) -> &T {\n    let mut max = &list[0];\n    for item in &list[1..] {\n        if item > max { max = item; }\n    }\n    max\n}\n// \u9002\u7528\u4e8e\u6574\u6570\u3001\u6d6e\u70b9\u6570\u3001\u5b57\u7b26\u4e32\u2014\u2014\u4efb\u4f55\u652f\u6301\u6bd4\u8f83\u7684\u7c7b\u578b\n
"},{"location":"chapter%2013%3A%20computing%20and%20OS/05.%20programming%20languages/#_4","title":"\u5185\u5b58\u7ba1\u7406","text":" "},{"location":"chapter%2013%3A%20computing%20and%20OS/05.%20programming%20languages/#_5","title":"\u7f16\u8bd1\u6d41\u6c34\u7ebf","text":"
  1. \u8bcd\u6cd5\u5206\u6790\uff08\u5206\u8bcd\uff09\uff1a\u5c06\u6e90\u6587\u672c\u8f6c\u6362\u4e3a\u4ee4\u724c\u6d41\u3002x = 3 + y \u53d8\u4e3a [IDENT(\"x\"), EQUALS, INT(3), PLUS, IDENT(\"y\")]\u3002\u8bcd\u6cd5\u5206\u6790\u5668\u53bb\u9664\u7a7a\u767d\u548c\u6ce8\u91ca\u3002

  2. \u8bed\u6cd5\u5206\u6790\uff1a\u4ece\u4ee4\u724c\u6d41\u6784\u5efa\u62bd\u8c61\u8bed\u6cd5\u6811\uff08AST\uff09\u3002AST\u8868\u793a\u7a0b\u5e8f\u7684\u5c42\u6b21\u7ed3\u6784\u30023 + y * 2 \u89e3\u6790\u4e3a Add(3, Mul(y, 2))\uff08\u4e58\u6cd5\u4f18\u5148\u7ea7\u66f4\u9ad8\uff09\u3002\u89e3\u6790\u5668\u68c0\u67e5\u8bed\u6cd5\uff1a\u62ec\u53f7\u4e0d\u5339\u914d\u548c\u7f3a\u5c11\u5206\u53f7\u5728\u6b64\u88ab\u6355\u83b7\u3002

  3. \u8bed\u4e49\u5206\u6790\uff1a\u68c0\u67e5\u7c7b\u578b\u3001\u89e3\u6790\u53d8\u91cf\u540d\u3001\u9a8c\u8bc1\u51fd\u6570\u8c03\u7528\u53c2\u6570\u662f\u5426\u6b63\u786e\u3002\u9759\u6001\u7c7b\u578b\u68c0\u67e5\u5728\u6b64\u53d1\u751f\u3002\u8f93\u51fa\u662f\u5e26\u7c7b\u578b\u6ce8\u89e3\u7684AST\u3002

  4. \u4f18\u5316\uff1a\u5728\u4e0d\u6539\u53d8\u884c\u4e3a\u7684\u60c5\u51b5\u4e0b\u8f6c\u6362\u7a0b\u5e8f\u4ee5\u4f7f\u5176\u8fd0\u884c\u66f4\u5feb\u3002\u5e38\u89c1\u4f18\u5316\uff1a

  5. \u4ee3\u7801\u751f\u6210\uff1a\u5c06\u4f18\u5316\u540e\u7684\u8868\u793a\u8f6c\u6362\u4e3a\u76ee\u6807\u673a\u5668\u7801\uff08x86\u3001ARM\uff09\u6216\u4e2d\u95f4\u8868\u793a\u3002

  6. LLVM\u662f\u4e3b\u6d41\u7684\u7f16\u8bd1\u5668\u57fa\u7840\u8bbe\u65bd\u3002\u5b83\u63d0\u4f9b\u4e86\u4e00\u4e2a\u901a\u7528\u4e2d\u95f4\u8868\u793a\uff08LLVM IR\uff09\uff0c\u8bb8\u591a\u8bed\u8a00\u53ef\u4ee5\u7f16\u8bd1\u5230\u8be5\u8868\u793a\u4e0a\u3002LLVM\u7684\u4f18\u5316\u5668\u5728\u8fd9\u4e2aIR\u4e0a\u5de5\u4f5c\uff0c\u5176\u540e\u7aef\u4e3a\u8bb8\u591a\u76ee\u6807\u751f\u6210\u673a\u5668\u7801\u3002Clang\uff08C/C++\uff09\u3001Rust\u3001Swift\u3001Julia\u548c\u8bb8\u591a\u5176\u4ed6\u8bed\u8a00\u4f7f\u7528LLVM\u3002\u8fd9\u610f\u5473\u7740LLVM\u4f18\u5316\u5668\u7684\u6539\u8fdb\u540c\u65f6\u60e0\u53ca\u6240\u6709\u8fd9\u4e9b\u8bed\u8a00\u3002

"},{"location":"chapter%2013%3A%20computing%20and%20OS/05.%20programming%20languages/#jit","title":"\u89e3\u91ca\u4e0eJIT\u7f16\u8bd1","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/05.%20programming%20languages/#_6","title":"\u5173\u952e\u8bed\u8a00\u7279\u6027","text":"
def make_adder(n):\n    def add(x):\n        return x + n  # n \u4ece\u5305\u56f4\u4f5c\u7528\u57df\u6355\u83b7\n    return add\n\nadd5 = make_adder(5)\nprint(add5(3))  # 8\n
match value {\n    Some(x) if x > 0 => println!(\"Positive: {}\", x),\n    Some(0)           => println!(\"Zero\"),\n    Some(x)           => println!(\"Negative: {}\", x),\n    None              => println!(\"Nothing\"),\n}\n
"},{"location":"chapter%2013%3A%20computing%20and%20OS/05.%20programming%20languages/#_7","title":"\u9886\u57df\u7279\u5b9a\u8bed\u8a00","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/05.%20programming%20languages/#_8","title":"\u8bed\u8a00\u8bbe\u8ba1\u6743\u8861","text":""},{"location":"chapter%2013%3A%20computing%20and%20OS/05.%20programming%20languages/#colab","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216\u7b14\u8bb0\u672c\uff09","text":"
  1. \u63a2\u7d22\u95ed\u5305\u548c\u9ad8\u9636\u51fd\u6570\u3002\u5b9e\u73b0\u4e00\u4e2a\u7b80\u5355\u7684\u51fd\u6570\u5de5\u5382\uff0c\u9a8c\u8bc1\u95ed\u5305\u6355\u83b7\u5176\u73af\u5883\u3002

    def make_multiplier(factor):\n    \"\"\"\u8fd4\u56de\u4e00\u4e2a\u5c06\u8f93\u5165\u4e58\u4ee5 factor \u7684\u51fd\u6570\u3002\"\"\"\n    def multiply(x):\n        return x * factor\n    return multiply\n\ndouble = make_multiplier(2)\ntriple = make_multiplier(3)\n\nprint(f\"double(5) = {double(5)}\")  # 10\nprint(f\"triple(5) = {triple(5)}\")  # 15\n\n# \u95ed\u5305\u901a\u8fc7\u5f15\u7528\u6355\u83b7\uff0c\u800c\u4e0d\u662f\u901a\u8fc7\u503c\ndef make_counter():\n    count = [0]  # \u53ef\u53d8\u7684\u5bb9\u5668\u4ee5\u5141\u8bb8\u4fee\u6539\n    def increment():\n        count[0] += 1\n        return count[0]\n    return increment\n\ncounter = make_counter()\nprint(f\"counter() = {counter()}\")  # 1\nprint(f\"counter() = {counter()}\")  # 2\nprint(f\"counter() = {counter()}\")  # 3\n

  2. \u6bd4\u8f83\u52a8\u6001\u4e0e\u9759\u6001\u7c7b\u578b\u884c\u4e3a\u3002\u5c55\u793aPython\u7684\u52a8\u6001\u7c7b\u578b\u5982\u4f55\u63d0\u4f9b\u7075\u6d3b\u6027\u4f46\u53ef\u80fd\u9690\u85cfbug\u3002

    def add(a, b):\n    return a + b\n\n# \u9002\u7528\u4e8e\u4e0d\u540c\u7c7b\u578b\u2014\u2014\u7075\u6d3b\uff01\nprint(add(3, 5))           # 8 (int + int)\nprint(add(\"hello \", \"world\"))  # \"hello world\" (str + str)\nprint(add([1, 2], [3, 4]))    # [1, 2, 3, 4] (list + list)\n\n# \u4f46\u7c7b\u578b\u9519\u8bef\u4ec5\u5728\u8fd0\u884c\u65f6\u66b4\u9732\uff1a\ntry:\n    print(add(\"hello\", 5))  # TypeError\uff01str + int\nexcept TypeError as e:\n    print(f\"\u8fd0\u884c\u65f6\u9519\u8bef\uff1a{e}\")\n    print(\"\u9759\u6001\u7c7b\u578b\u68c0\u67e5\u5668\u4f1a\u5728\u8fd0\u884c\u524d\u6355\u83b7\u6b64\u95ee\u9898\")\n

  3. \u6d4b\u91cf\u89e3\u91ca\u578bPython\u4e0e\u7f16\u8bd1/JIT\u65b9\u6cd5\u5728\u8ba1\u7b97\u5bc6\u96c6\u578b\u4efb\u52a1\u4e0a\u7684\u6027\u80fd\u5dee\u5f02\u3002

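    import time\nfrom functools import partial\n\nimport jax\nimport jax.numpy as jnp\n\nn = 1_000_000\n\n# \u7eafPython\u5faa\u73af\uff08\u89e3\u91ca\u578b\uff09\nstart = time.time()\ntotal = 0.0\nfor i in range(n):\n    total += i * i\npython_time = time.time() - start\n\n# JAX\uff08\u901a\u8fc7XLA\u7f16\u8bd1\uff09\n# jnp.arange \u9700\u8981\u5177\u4f53\u7684\u957f\u5ea6\uff0c\u56e0\u6b64\u5c06 n \u6807\u8bb0\u4e3a\u9759\u6001\u53c2\u6570\n@partial(jax.jit, static_argnums=0)\ndef sum_squares_jax(n):\n    return jnp.sum(jnp.arange(n, dtype=jnp.float32) ** 2)\n\nsum_squares_jax(n).block_until_ready()  # \u9884\u70edJIT\uff08\u6bcf\u4e2a\u9759\u6001 n \u4f1a\u5355\u72ec\u7f16\u8bd1\uff09\nstart = time.time()\nresult = sum_squares_jax(n).block_until_ready()  # \u7b49\u5f85\u5f02\u6b65\u8ba1\u7b97\u5b8c\u6210\njax_time = time.time() - start\n\nprint(f\"Python loop: {python_time:.4f}s\")\nprint(f\"JAX (JIT):   {jax_time:.6f}s\")\nprint(f\"Speedup:     {python_time / jax_time:.0f}x\")\n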

"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/","title":"\u57fa\u7840\uff1a\u5927O\u8868\u793a\u6cd5\u3001\u9012\u5f52\u3001\u56de\u6eaf\u4e0e\u52a8\u6001\u89c4\u5212","text":"

\u5728\u6df1\u5165\u5b66\u4e60\u6570\u636e\u7ed3\u6784\u548c\u7b97\u6cd5\u4e4b\u524d\uff0c\u4f60\u9700\u8981\u638c\u63e1\u56db\u4e2a\u57fa\u7840\u6982\u5ff5\uff1a\u8861\u91cf\u6548\u7387\u7684\u5927O\u8868\u793a\u6cd5\u3001\u5c06\u95ee\u9898\u5206\u89e3\u4e3a\u5b50\u95ee\u9898\u7684\u9012\u5f52\u3001\u5e26\u526a\u679d\u7684\u7a77\u4e3e\u641c\u7d22\u2014\u2014\u56de\u6eaf\uff0c\u4ee5\u53ca\u907f\u514d\u5197\u4f59\u8ba1\u7b97\u7684\u52a8\u6001\u89c4\u5212\u3002\u672c\u6587\u4ef6\u4ece\u57fa\u672c\u539f\u7406\u51fa\u53d1\u9010\u4e00\u8bb2\u89e3\u3002

"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_1","title":"\u4e3a\u4ec0\u4e48\u662f\u6a21\u5f0f\uff0c\u800c\u975e\u6b7b\u8bb0\u786c\u80cc","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#o_1","title":"\u5927O\u8868\u793a\u6cd5","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_2","title":"\u589e\u957f\u7387\u5c42\u7ea7","text":" \u5927O \u540d\u79f0 \u793a\u4f8b \\(n = 10^6\\) \u6b21\u64cd\u4f5c \\(O(1)\\) \u5e38\u6570\u7ea7 \u6570\u7ec4\u8bbf\u95ee\u3001\u54c8\u5e0c\u67e5\u627e 1 \\(O(\\log n)\\) \u5bf9\u6570\u7ea7 \u4e8c\u5206\u67e5\u627e 20 \\(O(n)\\) \u7ebf\u6027\u7ea7 \u7ebf\u6027\u626b\u63cf\u3001\u5355\u5faa\u73af \\(10^6\\) \\(O(n \\log n)\\) \u7ebf\u6027\u5bf9\u6570\u7ea7 \u5f52\u5e76\u6392\u5e8f\u3001\u9ad8\u6548\u6392\u5e8f \\(2 \\times 10^7\\) \\(O(n^2)\\) \u5e73\u65b9\u7ea7 \u5d4c\u5957\u5faa\u73af\u3001\u66b4\u529b\u914d\u5bf9 \\(10^{12}\\)\uff08\u592a\u6162\uff09 \\(O(n^3)\\) \u7acb\u65b9\u7ea7 \u4e09\u5c42\u5d4c\u5957\u5faa\u73af\u3001\u77e9\u9635\u4e58\u6cd5 \\(10^{18}\\)\uff08\u5b9e\u5728\u592a\u6162\uff09 \\(O(2^n)\\) \u6307\u6570\u7ea7 \u6240\u6709\u5b50\u96c6\u3001\u66b4\u529b\u56de\u6eaf \\(10^{301030}\\)\uff08\u4e0d\u53ef\u80fd\uff09 \\(O(n!)\\) \u9636\u4e58\u7ea7 \u6240\u6709\u6392\u5217 \u8352\u8c2c "},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#o_2","title":"\u5982\u4f55\u5206\u6790\u5927O","text":"
total = 0\nfor x in arr:   # n \u6b21\u8fed\u4ee3\n    total += x   # \u6bcf\u6b21\u8fed\u4ee3 O(1)\n# \u603b\u8ba1\uff1aO(n)\n
for i in range(n):       # n \u6b21\u8fed\u4ee3\n    for j in range(n):   # \u6bcf\u6b21 n \u6b21\u8fed\u4ee3\n        process(i, j)    # O(1)\n# \u603b\u8ba1\uff1aO(n^2)\n
i = n\nwhile i > 0:\n    process(i)\n    i //= 2\n# \u603b\u8ba1\uff1aO(log n)\n
for i in range(n):\n    for j in range(i):   # j \u4ece 0 \u5230 i-1\n        process(i, j)\n# \u603b\u8ba1\uff1a0 + 1 + 2 + ... + (n-1) = n(n-1)/2 = O(n^2)\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_3","title":"\u5e38\u89c1\u9677\u9631","text":"
# \u4e0d\u597d\uff1aO(n^2) \u2014 \u5bf9\u5217\u8868\u7528 \"in\" \u662f O(n)\nfor x in arr:\n    if x in another_list:\n        process(x)\n\n# \u597d\uff1aO(n) \u2014 \u5148\u8f6c\u6362\u4e3a set\nanother_set = set(another_list)\nfor x in arr:\n    if x in another_set:\n        process(x)\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_4","title":"\u7a7a\u95f4\u590d\u6742\u5ea6","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_5","title":"\u9012\u5f52","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_6","title":"\u793a\u4f8b\uff1a\u9636\u4e58","text":"
def factorial(n):\n    if n <= 1:        # \u57fa\u672c\u60c5\u51b5\n        return 1\n    return n * factorial(n - 1)  # \u9012\u5f52\u60c5\u51b5\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_7","title":"\u5982\u4f55\u4ee5\u9012\u5f52\u65b9\u5f0f\u601d\u8003","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_8","title":"\u793a\u4f8b\uff1a\u94fe\u8868\u4e0a\u7684\u9012\u5f52","text":"
def reverse(head):\n    if not head or not head.next:   # \u57fa\u672c\u60c5\u51b5\uff1a0 \u6216 1 \u4e2a\u8282\u70b9\n        return head\n\n    new_head = reverse(head.next)   # \u53cd\u8f6c\u5269\u4f59\u90e8\u5206\n    head.next.next = head           # \u5c06\u4e0b\u4e00\u4e2a\u8282\u70b9\u6307\u56de\u5f53\u524d\u8282\u70b9\n    head.next = None                # \u5f53\u524d\u8282\u70b9\u73b0\u5728\u6210\u4e3a\u5c3e\u8282\u70b9\n    return new_head\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_9","title":"\u793a\u4f8b\uff1a\u6811\u4e0a\u7684\u9012\u5f52","text":"
def height(root):\n    if not root:           # \u57fa\u672c\u60c5\u51b5\uff1a\u7a7a\u6811\u9ad8\u5ea6\u4e3a 0\n        return 0\n    left_h = height(root.left)    # \u5de6\u5b50\u6811\u9ad8\u5ea6\n    right_h = height(root.right)  # \u53f3\u5b50\u6811\u9ad8\u5ea6\n    return 1 + max(left_h, right_h)  # \u5f53\u524d\u8282\u70b9\u589e\u52a0 1 \u5c42\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#vs","title":"\u9012\u5f52 vs \u8fed\u4ee3","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_10","title":"\u5e38\u89c1\u9677\u9631","text":"\u9677\u9631 \u793a\u4f8b \u4fee\u590d \u7f3a\u5c11\u57fa\u672c\u60c5\u51b5 \u65e0\u9650\u9012\u5f52 \u2192 \u6808\u6ea2\u51fa \u59cb\u7ec8\u5b9a\u4e49\u4f55\u65f6\u505c\u6b62 \u57fa\u672c\u60c5\u51b5\u9519\u8bef \u9012\u5f52\u5206\u89e3\u4e2d\u7684\u5dee\u4e00\u9519\u8bef \u7528\u6700\u5c0f\u7684\u8f93\u5165\u6d4b\u8bd5\uff080\u30011\u30012\uff09 \u95ee\u9898\u89c4\u6a21\u672a\u51cf\u5c0f f(n) \u8c03\u7528 f(n) \u800c\u975e f(n-1) \u786e\u4fdd\u5b50\u95ee\u9898\u4e25\u683c\u66f4\u5c0f \u5197\u4f59\u8ba1\u7b97 \u6590\u6ce2\u90a3\u5951\u6570\u5217\uff1af(n) = f(n-1) + f(n-2) \u4ee5\u6307\u6570\u7ea7\u91cd\u590d\u8ba1\u7b97 \u4f7f\u7528\u8bb0\u5fc6\u5316\uff08\u2192 DP\uff09 Python \u9012\u5f52\u9650\u5236 factorial(10000) \u5d29\u6e83 \u4f7f\u7528 sys.setrecursionlimit \u6216\u8f6c\u4e3a\u8fed\u4ee3"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_11","title":"\u56de\u6eaf","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_12","title":"\u4e09\u4e2a\u6b65\u9aa4","text":"

\u6bcf\u4e2a\u56de\u6eaf\u7b97\u6cd5\u90fd\u9075\u5faa\u76f8\u540c\u7684\u6a21\u5f0f\uff1a

  1. \u9009\u62e9\uff1a\u9009\u62e9\u4e00\u4e2a\u5019\u9009\u6765\u6269\u5c55\u5f53\u524d\u7684\u90e8\u5206\u89e3\u3002
  2. \u63a2\u7d22\uff1a\u9012\u5f52\u5730\u5c1d\u8bd5\u4ece\u8fd9\u4e2a\u5019\u9009\u6784\u5efa\u4e00\u4e2a\u5b8c\u6574\u7684\u89e3\u3002
  3. \u64a4\u9500\uff1a\u64a4\u9500\u9009\u62e9\uff08\u56de\u6eaf\uff09\u5e76\u5c1d\u8bd5\u4e0b\u4e00\u4e2a\u5019\u9009\u3002
def backtrack(state, choices, result):\n    if is_complete(state):\n        result.append(state.copy())\n        return\n\n    for choice in choices:\n        if is_valid(choice, state):\n            state.add(choice)           # 1. \u9009\u62e9\n            backtrack(state, choices, result)  # 2. \u63a2\u7d22\n            state.remove(choice)        # 3. \u64a4\u9500\uff08\u56de\u6eaf\uff09\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_13","title":"\u4f55\u65f6\u4f7f\u7528\u56de\u6eaf","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_14","title":"\u526a\u679d\u5982\u4f55\u4f7f\u5176\u53d8\u5feb","text":"
for choice in choices:\n    if not is_valid(choice, state):\n        continue  # \u526a\u679d\uff1a\u8df3\u8fc7\u6574\u4e2a\u5b50\u6811\n\n    state.add(choice)\n    backtrack(state, choices, result)\n    state.remove(choice)\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_15","title":"\u751f\u6210\u6240\u6709\u5b50\u96c6\uff08\u6700\u7b80\u5355\u7684\u56de\u6eaf\uff09","text":"
def subsets(nums):\n    result = []\n\n    def backtrack(start, path):\n        result.append(path[:])  # \u6bcf\u4e2a\u90e8\u5206\u89e3\u90fd\u662f\u4e00\u4e2a\u6709\u6548\u7684\u5b50\u96c6\n\n        for i in range(start, len(nums)):\n            path.append(nums[i])        # \u9009\u62e9\n            backtrack(i + 1, path)       # \u63a2\u7d22\uff08i+1\uff1a\u4e0d\u5141\u8bb8\u91cd\u590d\u4f7f\u7528\uff09\n            path.pop()                   # \u64a4\u9500\n\n    backtrack(0, [])\n    return result\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_16","title":"\u751f\u6210\u6240\u6709\u6392\u5217","text":"
def permutations(nums):\n    result = []\n\n    def backtrack(path, remaining):\n        if not remaining:\n            result.append(path[:])\n            return\n\n        for i in range(len(remaining)):\n            path.append(remaining[i])                    # \u9009\u62e9\n            backtrack(path, remaining[:i] + remaining[i+1:])  # \u63a2\u7d22\n            path.pop()                                   # \u64a4\u9500\n\n    backtrack([], nums)\n    return result\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_17","title":"\u5e38\u89c1\u9677\u9631","text":"\u9677\u9631 \u793a\u4f8b \u4fee\u590d \u5fd8\u8bb0\u590d\u5236\u8def\u5f84 result.append(path) \u2014\u2014 \u6240\u6709\u6761\u76ee\u5171\u4eab\u540c\u4e00\u4e2a\u5217\u8868 result.append(path[:]) \u6216 path.copy() \u672a\u56de\u6eaf\uff08\u64a4\u9500\uff09 \u72b6\u6001\u4e0d\u65ad\u589e\u957f\uff0c\u540e\u9762\u7684\u5019\u9009\u770b\u5230\u8fc7\u65f6\u7684\u72b6\u6001 \u9012\u5f52\u8c03\u7528\u540e\u59cb\u7ec8\u6267\u884c path.pop() \u6216 state.remove() \u5faa\u73af\u8d77\u59cb\u4f4d\u7f6e\u9519\u8bef \u5b50\u96c6\u4e2d\u6709\u91cd\u590d\u9879\uff0c\u6216\u6392\u5217\u4e2d\u51fa\u73b0\u4e86\u4e0d\u5e94\u6709\u7684\u91cd\u590d\u4f7f\u7528 \u4f7f\u7528 start \u53c2\u6570\u907f\u514d\u91cd\u65b0\u8bbf\u95ee\u4e4b\u524d\u7684\u7d22\u5f15 \u8df3\u8fc7\u526a\u679d \u63a2\u7d22\u660e\u663e\u65e0\u6548\u7684\u5206\u652f \u5728\u9012\u5f52\u8c03\u7528\u524d\u6dfb\u52a0 if not is_valid: continue"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_18","title":"\u52a8\u6001\u89c4\u5212","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_19","title":"\u6590\u6ce2\u90a3\u5951\u6570\u5217\u7684\u52a8\u673a","text":"
def fib(n):\n    if n <= 1:\n        return n\n    return fib(n - 1) + fib(n - 2)\n
def fib_memo(n, memo={}):\n    if n in memo:\n        return memo[n]\n    if n <= 1:\n        return n\n    memo[n] = fib_memo(n - 1, memo) + fib_memo(n - 2, memo)\n    return memo[n]\n
def fib_tab(n):\n    if n <= 1:\n        return n\n    dp = [0] * (n + 1)\n    dp[1] = 1\n    for i in range(2, n + 1):\n        dp[i] = dp[i - 1] + dp[i - 2]\n    return dp[n]\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#dp","title":"DP \u914d\u65b9","text":"

\u5bf9\u4e8e\u4efb\u4f55 DP \u95ee\u9898\uff0c\u9075\u5faa\u4ee5\u4e0b\u6b65\u9aa4\uff1a

  1. \u5b9a\u4e49\u72b6\u6001\uff1adp[i]\uff08\u6216 dp[i][j]\uff09\u4ee3\u8868\u4ec0\u4e48\uff1f\u8fd9\u662f\u6700\u96be\u7684\u4e00\u6b65\u3002\u72b6\u6001\u5fc5\u987b\u6355\u83b7\u8db3\u591f\u7684\u4fe1\u606f\u4ee5\u505a\u51fa\u6700\u4f18\u51b3\u7b56\u3002

  2. \u5199\u51fa\u9012\u63a8\u5173\u7cfb\uff1adp[i] \u5982\u4f55\u4e0e\u66f4\u5c0f\u7684\u5b50\u95ee\u9898\u5173\u8054\uff1f\u8fd9\u662f\u8f6c\u79fb\u516c\u5f0f\u3002

  3. \u786e\u5b9a\u57fa\u672c\u60c5\u51b5\uff1a\u54ea\u4e9b\u662f\u6700\u5c0f\u7684\u5b50\u95ee\u9898\uff0c\u53ef\u4ee5\u76f4\u63a5\u6c42\u89e3\uff1f

  4. \u786e\u5b9a\u8fed\u4ee3\u987a\u5e8f\uff1a\u54ea\u4e9b\u5b50\u95ee\u9898\u5fc5\u987b\u5148\u4e8e\u54ea\u4e9b\u5b50\u95ee\u9898\u6c42\u89e3\uff1f\u81ea\u5e95\u5411\u4e0a\uff1a\u6309\u7167\u786e\u4fdd\u4f9d\u8d56\u5173\u7cfb\u5df2\u89e3\u51b3\u7684\u987a\u5e8f\u8fed\u4ee3\u3002\u81ea\u9876\u5411\u4e0b\uff1a\u9012\u5f52\u4f1a\u81ea\u52a8\u5904\u7406\u3002

  5. \u4f18\u5316\u7a7a\u95f4\uff08\u53ef\u9009\uff09\uff1a\u5982\u679c dp[i] \u53ea\u4f9d\u8d56\u4e8e\u524d\u4e00\u884c\u6216\u524d\u51e0\u4e2a\u6761\u76ee\uff0c\u4f60\u5c31\u4e0d\u9700\u8981\u5b8c\u6574\u7684\u8868\u3002

"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_20","title":"\u793a\u4f8b\uff1a\u601d\u8def\u8fc7\u7a0b","text":"

\u95ee\u9898\uff1a\u7ed9\u5b9a\u4e00\u4e2a\u6b63\u6574\u6570\u6570\u7ec4\uff0c\u6c42\u4e0d\u76f8\u90bb\u5143\u7d20\u7684\u6700\u5927\u548c\uff08\u6253\u5bb6\u52ab\u820d\uff09\u3002

\u7b2c1\u6b65\u2014\u2014\u5b9a\u4e49\u72b6\u6001\uff1adp[i] = \u8003\u8651\u5143\u7d20 nums[0..i] \u7684\u6700\u5927\u548c\u3002

\u7b2c2\u6b65\u2014\u2014\u5199\u51fa\u9012\u63a8\u5173\u7cfb\uff1a\u5bf9\u4e8e\u5143\u7d20 \\(i\\)\uff0c\u6211\u4eec\u8981\u4e48\uff1a - \u8df3\u8fc7\u5b83\uff1adp[i] = dp[i-1]\uff08\u4e0d\u542b\u5143\u7d20 \\(i\\) \u7684\u6700\u4f73\u548c\uff09\u3002 - \u53d6\u7528\u5b83\uff1adp[i] = dp[i-2] + nums[i]\uff08\u5fc5\u987b\u8df3\u8fc7\u5143\u7d20 \\(i-1\\)\uff0c\u7136\u540e\u52a0\u4e0a\u5143\u7d20 \\(i\\)\uff09\u3002

\u6240\u4ee5\uff1adp[i] = max(dp[i-1], dp[i-2] + nums[i])\u3002

\u7b2c3\u6b65\u2014\u2014\u57fa\u672c\u60c5\u51b5\uff1adp[0] = nums[0]\uff0cdp[1] = max(nums[0], nums[1])\u3002

\u7b2c4\u6b65\u2014\u2014\u8fed\u4ee3\u987a\u5e8f\uff1a\u4ece\u5de6\u5230\u53f3\uff08\u6bcf\u4e2a\u72b6\u6001\u4f9d\u8d56\u4e8e\u524d\u4e24\u4e2a\u72b6\u6001\uff09\u3002

\u7b2c5\u6b65\u2014\u2014\u7a7a\u95f4\u4f18\u5316\uff1a\u53ea\u9700\u8981\u524d\u4e24\u4e2a\u72b6\u6001\u7684\u503c\u3002

def rob(nums):\n    if len(nums) == 1:\n        return nums[0]\n\n    prev2, prev1 = nums[0], max(nums[0], nums[1])\n\n    for i in range(2, len(nums)):\n        curr = max(prev1, prev2 + nums[i])\n        prev2, prev1 = prev1, curr\n\n    return prev1\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#dp_1","title":"\u5982\u4f55\u8bc6\u522b DP \u95ee\u9898","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#dp_2","title":"DP \u7684\u5206\u7c7b","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#vs_1","title":"\u81ea\u9876\u5411\u4e0b vs \u81ea\u5e95\u5411\u4e0a","text":"\u81ea\u9876\u5411\u4e0b\uff08\u8bb0\u5fc6\u5316\uff09 \u81ea\u5e95\u5411\u4e0a\uff08\u5236\u8868\u6cd5\uff09 \u5b9e\u73b0 \u9012\u5f52 + \u7f13\u5b58 \u8fed\u4ee3 + \u8868 \u8ba1\u7b97 \u53ea\u8ba1\u7b97\u5b9e\u9645\u9700\u8981\u7684\u5b50\u95ee\u9898 \u8ba1\u7b97\u76f4\u5230\u76ee\u6807\u7684\u6240\u6709\u5b50\u95ee\u9898 \u6808\u6ea2\u51fa\u98ce\u9669 \u6709\uff08\u6df1\u5ea6\u9012\u5f52\uff09 \u65e0 \u7a7a\u95f4\u4f18\u5316 \u8f83\u96be \u8f83\u6613\uff08\u4f7f\u7528\u6eda\u52a8\u6570\u7ec4\uff09 \u7f16\u7801\u96be\u5ea6 \u901a\u5e38\u66f4\u81ea\u7136\uff08\u5199\u9012\u5f52\uff0c\u52a0\u7f13\u5b58\uff09 \u9700\u8981\u8003\u8651\u8fed\u4ee3\u987a\u5e8f "},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_21","title":"\u5e38\u89c1\u9677\u9631","text":"\u9677\u9631 \u793a\u4f8b \u4fee\u590d \u72b6\u6001\u5b9a\u4e49\u9519\u8bef dp[i] \u6ca1\u6709\u6355\u83b7\u8db3\u591f\u4fe1\u606f\u6765\u505a\u51b3\u7b56 \u589e\u52a0\u7ef4\u5ea6\uff08\u4f8b\u5982\u7528 dp[i][j] \u4ee3\u66ff dp[i]\uff09 \u7f3a\u5c11\u57fa\u672c\u60c5\u51b5 dp[0] \u9519\u8bef \u2192 \u6240\u6709\u540e\u7eed\u503c\u90fd\u9519 \u624b\u52a8\u9a8c\u8bc1\u57fa\u672c\u60c5\u51b5 \u8fed\u4ee3\u987a\u5e8f\u9519\u8bef \u5728\u4f9d\u8d56\u5173\u7cfb\u672a\u89e3\u51b3\u4e4b\u524d\u8ba1\u7b97 dp[i] \u753b\u51fa\u4f9d\u8d56\u7bad\u5934\u5e76\u76f8\u5e94\u8fed\u4ee3 \u672a\u6b63\u786e\u521d\u59cb\u5316 dp \u7528 0 \u800c\u5e94\u8be5\u7528\u65e0\u7a77\u5927\uff08\u6c42\u6700\u5c0f\u503c\u65f6\uff09 \u6700\u5c0f\u5316\u7528 
float('inf')\uff0c\u6700\u5927\u5316\u7528 float('-inf') \u5fd8\u8bb0\u8003\u8651\"\u8df3\u8fc7\"\u9009\u9879 \u603b\u662f\u53d6\u5f53\u524d\u5143\u7d20 \u9012\u63a8\u5173\u7cfb\u901a\u5e38\u6709 max(take, skip) \u53ef\u53d8\u7684\u9ed8\u8ba4\u53c2\u6570 def f(memo={}) \u5728\u8c03\u7528\u95f4\u5171\u4eab\u7f13\u5b58 def f(memo=None): if memo is None: memo = {} 2D DP \u4e2d\u7684\u5dee\u4e00\u9519\u8bef dp \u662f 1-indexed \u65f6\u8bbf\u95ee text1[i] dp \u5927\u5c0f\u4e3a (m+1) x (n+1)\uff0c\u8bbf\u95ee text1[i-1]"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/00.%20foundations/#_22","title":"\u878d\u4f1a\u8d2f\u901a","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/","title":"\u6570\u7ec4\u4e0e\u54c8\u5e0c","text":"

\u6570\u7ec4\u548c\u54c8\u5e0c\u8868\u662f\u7f16\u7a0b\u4e2d\u6700\u57fa\u7840\u7684\u4e24\u79cd\u6570\u636e\u7ed3\u6784\u3002\u672c\u6587\u4ef6\u6db5\u76d6\u5b83\u4eec\u5e95\u5c42\u7684\u8fd0\u884c\u673a\u5236\uff0c\u7136\u540e\u6784\u5efa\u5173\u952e\u7684\u95ee\u9898\u89e3\u51b3\u6a21\u5f0f\uff1a\u53cc\u6307\u9488\u3001\u6ed1\u52a8\u7a97\u53e3\u3001\u524d\u7f00\u548c\u4ee5\u53ca\u57fa\u4e8e\u54c8\u5e0c\u7684\u67e5\u627e\uff0c\u901a\u8fc7\u9010\u6b65\u589e\u52a0\u96be\u5ea6\u7684\u9898\u76ee\uff0c\u5e76\u5728\u6bcf\u4e00\u6b65\u6307\u51fa\u5e38\u89c1\u9677\u9631\u3002

"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_2","title":"\u6570\u7ec4","text":" \u64cd\u4f5c \u6570\u7ec4 \u52a8\u6001\u6570\u7ec4 \u6309\u7d22\u5f15\u8bbf\u95ee \\(O(1)\\) \\(O(1)\\) \u8ffd\u52a0 \u4e0d\u9002\u7528 \\(O(1)\\) \u5e73\u644a \u5728\u4f4d\u7f6e \\(i\\) \u63d2\u5165 \\(O(n)\\) \\(O(n)\\) \u5728\u4f4d\u7f6e \\(i\\) \u5220\u9664 \\(O(n)\\) \\(O(n)\\) \u641c\u7d22\uff08\u672a\u6392\u5e8f\uff09 \\(O(n)\\) \\(O(n)\\) "},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_3","title":"\u5b57\u7b26\u4e32","text":"
# \u4e0d\u597d\uff1aO(n^2) \u5b57\u7b26\u4e32\u62fc\u63a5\ns = \"\"\nfor c in characters:\n    s += c  # \u6bcf\u6b21\u590d\u5236\u6574\u4e2a\u5b57\u7b26\u4e32\n\n# \u597d\uff1aO(n) \u4f7f\u7528\u5217\u8868\u7136\u540e join\nparts = []\nfor c in characters:\n    parts.append(c)\ns = \"\".join(parts)\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_4","title":"\u54c8\u5e0c\u8868","text":" \u64cd\u4f5c \u5e73\u5747 \u6700\u574f\u60c5\u51b5 \u67e5\u627e \\(O(1)\\) \\(O(n)\\) \u63d2\u5165 \\(O(1)\\) \\(O(n)\\) \u5220\u9664 \\(O(1)\\) \\(O(n)\\) "},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_5","title":"\u6a21\u5f0f\uff1a\u54c8\u5e0c\u8868\u67e5\u627e","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_6","title":"\u7b80\u5355\uff1a\u4e24\u6570\u4e4b\u548c","text":"
def two_sum(nums, target):\n    seen = {}  # \u503c -> \u7d22\u5f15\n    for i, num in enumerate(nums):\n        complement = target - num\n        if complement in seen:\n            return [seen[complement], i]\n        seen[num] = i\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_7","title":"\u4e2d\u7b49\uff1a\u5b57\u6bcd\u5f02\u4f4d\u8bcd\u5206\u7ec4","text":"
from collections import defaultdict\n\ndef group_anagrams(strs):\n    groups = defaultdict(list)\n    for s in strs:\n        key = tuple(sorted(s))  # \u6216\u4f7f\u7528\u5b57\u7b26\u8ba1\u6570\u5143\u7ec4\n        groups[key].append(s)\n    return list(groups.values())\n
def group_anagrams_fast(strs):\n    groups = defaultdict(list)\n    for s in strs:\n        count = [0] * 26\n        for c in s:\n            count[ord(c) - ord('a')] += 1\n        groups[tuple(count)].append(s)\n    return list(groups.values())\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_8","title":"\u56f0\u96be\uff1a\u6700\u957f\u8fde\u7eed\u5e8f\u5217","text":"
def longest_consecutive(nums):\n    num_set = set(nums)\n    best = 0\n\n    for num in num_set:\n        # \u53ea\u4ece\u5e8f\u5217\u7684\u5f00\u5934\u5f00\u59cb\u8ba1\u6570\n        if num - 1 not in num_set:\n            length = 1\n            while num + length in num_set:\n                length += 1\n            best = max(best, length)\n\n    return best\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_9","title":"\u6a21\u5f0f\uff1a\u53cc\u6307\u9488","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_10","title":"\u7b80\u5355\uff1a\u9a8c\u8bc1\u56de\u6587\u4e32","text":"
def is_palindrome(s):\n    left, right = 0, len(s) - 1\n\n    while left < right:\n        # \u8df3\u8fc7\u975e\u5b57\u6bcd\u6570\u5b57\u5b57\u7b26\n        while left < right and not s[left].isalnum():\n            left += 1\n        while left < right and not s[right].isalnum():\n            right -= 1\n\n        if s[left].lower() != s[right].lower():\n            return False\n\n        left += 1\n        right -= 1\n\n    return True\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_11","title":"\u4e2d\u7b49\uff1a\u4e09\u6570\u4e4b\u548c","text":"
def three_sum(nums):\n    nums.sort()\n    result = []\n\n    for i in range(len(nums) - 2):\n        # \u8df3\u8fc7\u91cd\u590d\u7684\u56fa\u5b9a\u5143\u7d20\n        if i > 0 and nums[i] == nums[i - 1]:\n            continue\n\n        left, right = i + 1, len(nums) - 1\n        target = -nums[i]\n\n        while left < right:\n            total = nums[left] + nums[right]\n            if total < target:\n                left += 1\n            elif total > target:\n                right -= 1\n            else:\n                result.append([nums[i], nums[left], nums[right]])\n                # \u8df3\u8fc7\u91cd\u590d\u9879\n                while left < right and nums[left] == nums[left + 1]:\n                    left += 1\n                while left < right and nums[right] == nums[right - 1]:\n                    right -= 1\n                left += 1\n                right -= 1\n\n    return result\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_12","title":"\u56f0\u96be\uff1a\u63a5\u96e8\u6c34","text":"
def trap(height):\n    left, right = 0, len(height) - 1\n    left_max, right_max = 0, 0\n    water = 0\n\n    while left < right:\n        if height[left] < height[right]:\n            if height[left] >= left_max:\n                left_max = height[left]\n            else:\n                water += left_max - height[left]\n            left += 1\n        else:\n            if height[right] >= right_max:\n                right_max = height[right]\n            else:\n                water += right_max - height[right]\n            right -= 1\n\n    return water\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_13","title":"\u6a21\u5f0f\uff1a\u6ed1\u52a8\u7a97\u53e3","text":"
def sliding_window(arr):\n    left = 0\n    state = ...  # \u7a97\u53e3\u72b6\u6001\uff08\u8ba1\u6570\u3001\u548c\u7b49\uff09\n    best = ...\n\n    for right in range(len(arr)):\n        # \u6269\u5c55\uff1a\u5c06 arr[right] \u6dfb\u52a0\u5230\u7a97\u53e3\u72b6\u6001\n        update_state(state, arr[right])\n\n        # \u6536\u7f29\uff1a\u5f53\u7ea6\u675f\u88ab\u8fdd\u53cd\u65f6\u4ece\u5de6\u4fa7\u7f29\u5c0f\n        while constraint_violated(state):\n            remove_from_state(state, arr[left])\n            left += 1\n\n        # \u66f4\u65b0\u7b54\u6848\n        best = max(best, right - left + 1)  # \u6216 min\uff0c\u53d6\u51b3\u4e8e\u95ee\u9898\n\n    return best\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_14","title":"\u7b80\u5355\uff1a\u4e70\u5356\u80a1\u7968\u7684\u6700\u4f73\u65f6\u673a","text":"
def max_profit(prices):\n    min_price = float('inf')\n    max_profit = 0\n\n    for price in prices:\n        min_price = min(min_price, price)\n        max_profit = max(max_profit, price - min_price)\n\n    return max_profit\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_15","title":"\u4e2d\u7b49\uff1a\u65e0\u91cd\u590d\u5b57\u7b26\u7684\u6700\u957f\u5b50\u4e32","text":"
def length_of_longest_substring(s):\n    char_index = {}  # \u5b57\u7b26 -> \u5b83\u7684\u6700\u8fd1\u7d22\u5f15\n    left = 0\n    best = 0\n\n    for right, char in enumerate(s):\n        if char in char_index and char_index[char] >= left:\n            left = char_index[char] + 1  # \u8df3\u8fc7\u91cd\u590d\u5b57\u7b26\n\n        char_index[char] = right\n        best = max(best, right - left + 1)\n\n    return best\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_16","title":"\u56f0\u96be\uff1a\u6700\u5c0f\u8986\u76d6\u5b50\u4e32","text":"
from collections import Counter\n\ndef min_window(s, t):\n    if not t or not s:\n        return \"\"\n\n    need = Counter(t)       # \u6211\u4eec\u9700\u8981\u7684\u5b57\u7b26\u53ca\u5176\u8ba1\u6570\n    have = 0                # \u6211\u4eec\u5df2\u7ecf\u62e5\u6709\u8db3\u591f\u6570\u91cf\u7684\u552f\u4e00\u5b57\u7b26\u6570\n    required = len(need)    # \u6211\u4eec\u9700\u8981\u591a\u5c11\u79cd\u552f\u4e00\u5b57\u7b26\n\n    left = 0\n    best = (float('inf'), 0, 0)  # (\u957f\u5ea6, \u5de6, \u53f3)\n\n    window_counts = {}\n\n    for right in range(len(s)):\n        char = s[right]\n        window_counts[char] = window_counts.get(char, 0) + 1\n\n        # \u68c0\u67e5\u6b64\u5b57\u7b26\u7684\u8ba1\u6570\u662f\u5426\u6ee1\u8db3\u8981\u6c42\n        if char in need and window_counts[char] == need[char]:\n            have += 1\n\n        # \u5f53\u7a97\u53e3\u6709\u6548\u65f6\u4ece\u5de6\u4fa7\u6536\u7f29\n        while have == required:\n            # \u66f4\u65b0\u6700\u4f73\u503c\n            if (right - left + 1) < best[0]:\n                best = (right - left + 1, left, right)\n\n            # \u79fb\u9664\u6700\u5de6\u8fb9\u7684\u5b57\u7b26\n            left_char = s[left]\n            window_counts[left_char] -= 1\n            if left_char in need and window_counts[left_char] < need[left_char]:\n                have -= 1\n            left += 1\n\n    length, start, end = best\n    return s[start:end + 1] if length != float('inf') else \"\"\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_17","title":"\u6a21\u5f0f\uff1a\u524d\u7f00\u548c","text":"
def build_prefix(arr):\n    prefix = [0] * (len(arr) + 1)\n    for i in range(len(arr)):\n        prefix[i + 1] = prefix[i] + arr[i]\n    return prefix\n\n# arr[l:r] \u7684\u548c\uff08\u5305\u542b l\uff0c\u4e0d\u5305\u542b r\uff09\ndef range_sum(prefix, l, r):\n    return prefix[r] - prefix[l]\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_18","title":"\u7b80\u5355\uff1a\u533a\u95f4\u6c42\u548c\u67e5\u8be2","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#k","title":"\u4e2d\u7b49\uff1a\u548c\u4e3a K \u7684\u5b50\u6570\u7ec4","text":"
def subarray_sum(nums, k):\n    count = 0\n    prefix = 0\n    prefix_counts = {0: 1}  # \u7a7a\u524d\u7f00\u548c\n\n    for num in nums:\n        prefix += num\n        # \u6709\u591a\u5c11\u66f4\u65e9\u7684\u524d\u7f00\u548c\u7b49\u4e8e prefix - k\uff1f\n        count += prefix_counts.get(prefix - k, 0)\n        prefix_counts[prefix] = prefix_counts.get(prefix, 0) + 1\n\n    return count\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_19","title":"\u56f0\u96be\uff1a\u9664\u81ea\u8eab\u4ee5\u5916\u6570\u7ec4\u7684\u4e58\u79ef","text":"
def product_except_self(nums):\n    n = len(nums)\n    result = [1] * n\n\n    # \u5de6\u5411\u904d\u5386\uff1aresult[i] = nums[0..i-1] \u7684\u4e58\u79ef\n    prefix = 1\n    for i in range(n):\n        result[i] = prefix\n        prefix *= nums[i]\n\n    # \u53f3\u5411\u904d\u5386\uff1a\u4e58\u4ee5 nums[i+1..n-1] \u7684\u4e58\u79ef\n    suffix = 1\n    for i in range(n - 1, -1, -1):\n        result[i] *= suffix\n        suffix *= nums[i]\n\n    return result\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_20","title":"\u5e38\u89c1\u9677\u9631\u603b\u7ed3","text":"\u9677\u9631 \u793a\u4f8b \u4fee\u590d \u7a97\u53e3\u5927\u5c0f\u7684\u5dee\u4e00\u9519\u8bef right - left vs right - left + 1 \u753b\u4e00\u4e2a2\u5143\u7d20\u793a\u4f8b Python \u4e2d\u7684\u53ef\u53d8\u9ed8\u8ba4\u503c def f(seen={}) \u5728\u8c03\u7528\u95f4\u5171\u4eab\u72b6\u6001 \u4f7f\u7528 def f(seen=None) \u5faa\u73af\u4e2d\u7684\u5b57\u7b26\u4e32\u62fc\u63a5 s += c \u5728 Python \u4e2d\u662f \\(O(n^2)\\) \u4f7f\u7528 list.append + \"\".join \u524d\u7f00\u548c\u4e2d\u5fd8\u8bb0 {0: 1} \u6f0f\u6389\u4ece\u7d22\u5f15 0 \u5f00\u59cb\u7684\u5b50\u6570\u7ec4 \u59cb\u7ec8\u7528\u7a7a\u524d\u7f00\u521d\u59cb\u5316 \u68c0\u67e5\u524d\u6dfb\u52a0\u54c8\u5e0c\u8868 \u4e24\u6570\u4e4b\u548c\uff1a\u5728\u68c0\u67e5\u8865\u6570\u4e4b\u524d\u6dfb\u52a0\u4e86 num \u5148\u68c0\u67e5\uff0c\u540e\u63d2\u5165 \u672a\u5904\u7406\u91cd\u590d\u9879 \u4e09\u6570\u4e4b\u548c\u8fd4\u56de\u91cd\u590d\u7684\u4e09\u5143\u7ec4 \u8df3\u8fc7\u8fde\u7eed\u76f8\u7b49\u7684\u503c \u6574\u6570\u6ea2\u51fa C++/Java \u4e2d\u5927\u6570\u7ec4\u6c42\u548c \u4f7f\u7528 long \u6216\u68c0\u67e5\u8fb9\u754c"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#neetcode","title":"\u8bfe\u540e\u7ec3\u4e60\u9898\uff08NeetCode\uff09","text":"

\u6309\u987a\u5e8f\u7ec3\u4e60\u3002\u6bcf\u9053\u9898\u5f3a\u5316\u672c\u6587\u4ef6\u4e2d\u7684\u4e00\u4e2a\u6a21\u5f0f\u3002

"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_21","title":"\u54c8\u5e0c\u8868\u67e5\u627e","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_22","title":"\u53cc\u6307\u9488","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_23","title":"\u6ed1\u52a8\u7a97\u53e3","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/01.%20arrays%20and%20hashing/#_24","title":"\u524d\u7f00\u548c","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/","title":"\u94fe\u8868\u3001\u6808\u548c\u961f\u5217","text":"

\u94fe\u8868\u3001\u6808\u548c\u961f\u5217\u662f\u66f4\u590d\u6742\u6570\u636e\u7ed3\u6784\u7684\u6784\u5efa\u6a21\u5757\u3002\u672c\u6587\u4ef6\u6db5\u76d6\u5b83\u4eec\u7684\u8fd0\u884c\u673a\u5236\uff0c\u7136\u540e\u6784\u5efa\u5173\u952e\u6a21\u5f0f\uff1a\u5feb\u6162\u6307\u9488\u3001\u5355\u8c03\u6808\u548c\u57fa\u4e8e\u5806\u7684\u4f18\u5148\u961f\u5217\uff0c\u901a\u8fc7\u9010\u6b65\u589e\u52a0\u96be\u5ea6\u7684\u9898\u76ee\uff0c\u5e76\u5728\u6bcf\u4e00\u6b65\u6307\u51fa\u5e38\u89c1\u9677\u9631\u3002

"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_2","title":"\u94fe\u8868","text":"
class ListNode:\n    def __init__(self, val=0, next=None):\n        self.val = val\n        self.next = next\n
\u64cd\u4f5c \u5355\u5411 \u53cc\u5411 \u6309\u7d22\u5f15\u8bbf\u95ee \\(O(n)\\) \\(O(n)\\) \u5728\u5934\u90e8\u63d2\u5165 \\(O(1)\\) \\(O(1)\\) \u5728\u5c3e\u90e8\u63d2\u5165 \\(O(n)\\) \u6216 \\(O(1)\\)* \\(O(1)\\) \u5220\u9664\u7ed9\u5b9a\u8282\u70b9 \\(O(n)\\)** \\(O(1)\\) \u641c\u7d22 \\(O(n)\\) \\(O(n)\\)

* \u6709\u5c3e\u6307\u9488\u65f6\u3002** \u9700\u8981\u524d\u9a71\u8282\u70b9\uff0c\u9700\u8981\u904d\u5386\u3002

# \u65e0\u865a\u62df\u8282\u70b9\uff1a\u5934\u90e8\u5220\u9664\u9700\u8981\u7279\u6b8a\u5904\u7406\ndef delete_head(head):\n    if not head:\n        return None\n    return head.next\n\n# \u6709\u865a\u62df\u8282\u70b9\uff1a\u7edf\u4e00\u903b\u8f91\ndummy = ListNode(0)\ndummy.next = head\n# \u73b0\u5728\u6bcf\u6b21\u5220\u9664\u90fd\u662f\uff1aprev.next = prev.next.next\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_3","title":"\u6a21\u5f0f\uff1a\u5feb\u6162\u6307\u9488\uff08\u5f17\u6d1b\u4f0a\u5fb7\u7b97\u6cd5\uff09","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_4","title":"\u7b80\u5355\uff1a\u73af\u5f62\u94fe\u8868","text":"
def has_cycle(head):\n    slow = fast = head\n    while fast and fast.next:\n        slow = slow.next\n        fast = fast.next.next\n        if slow == fast:\n            return True\n    return False\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_5","title":"\u4e2d\u7b49\uff1a\u5bfb\u627e\u94fe\u8868\u7684\u4e2d\u95f4\u8282\u70b9","text":"
def find_middle(head):\n    slow = fast = head\n    while fast and fast.next:\n        slow = slow.next\n        fast = fast.next.next\n    return slow  # slow \u5728\u4e2d\u95f4\uff08\u5076\u6570\u957f\u5ea6\u65f6\u4e3a\u7b2c\u4e8c\u4e2a\u4e2d\u95f4\u8282\u70b9\uff09\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#ii","title":"\u4e2d\u7b49\uff1a\u73af\u5f62\u94fe\u8868 II\uff08\u5bfb\u627e\u73af\u7684\u8d77\u70b9\uff09","text":"
def detect_cycle(head):\n    slow = fast = head\n    while fast and fast.next:\n        slow = slow.next\n        fast = fast.next.next\n        if slow == fast:\n            # \u5c06\u4e00\u4e2a\u6307\u9488\u91cd\u7f6e\u5230\u5934\u90e8\n            slow = head\n            while slow != fast:\n                slow = slow.next\n                fast = fast.next\n            return slow\n    return None\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#k","title":"\u56f0\u96be\uff1aK\u4e2a\u4e00\u7ec4\u53cd\u8f6c\u94fe\u8868","text":"
def reverse_k_group(head, k):\n    # \u68c0\u67e5\u662f\u5426\u8fd8\u6709 k \u4e2a\u8282\u70b9\n    node = head\n    for _ in range(k):\n        if not node:\n            return head\n        node = node.next\n\n    # \u53cd\u8f6c k \u4e2a\u8282\u70b9\n    prev, curr = None, head\n    for _ in range(k):\n        nxt = curr.next\n        curr.next = prev\n        prev = curr\n        curr = nxt\n\n    # \u5f53\u524d head \u73b0\u5728\u662f\u53cd\u8f6c\u540e\u7684\u7ec4\u7684\u5c3e\u8282\u70b9\n    # \u9012\u5f52\u5904\u7406\u5269\u4f59\u90e8\u5206\n    head.next = reverse_k_group(curr, k)\n    return prev  # prev \u662f\u8fd9\u7ec4\u7684\u65b0\u5934\u8282\u70b9\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_6","title":"\u6808","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_7","title":"\u7b80\u5355\uff1a\u6709\u6548\u7684\u62ec\u53f7","text":"
def is_valid(s):\n    stack = []\n    matching = {')': '(', ']': '[', '}': '{'}\n\n    for char in s:\n        if char in matching:\n            if not stack or stack[-1] != matching[char]:\n                return False\n            stack.pop()\n        else:\n            stack.append(char)\n\n    return len(stack) == 0\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_8","title":"\u6a21\u5f0f\uff1a\u5355\u8c03\u6808","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_9","title":"\u4e2d\u7b49\uff1a\u6bcf\u65e5\u6e29\u5ea6","text":"
def daily_temperatures(temperatures):\n    n = len(temperatures)\n    result = [0] * n\n    stack = []  # \u7d22\u5f15\u6808\uff0c\u6e29\u5ea6\u9012\u51cf\n\n    for i in range(n):\n        while stack and temperatures[i] > temperatures[stack[-1]]:\n            prev = stack.pop()\n            result[prev] = i - prev\n        stack.append(i)\n\n    return result\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_10","title":"\u56f0\u96be\uff1a\u67f1\u72b6\u56fe\u4e2d\u6700\u5927\u7684\u77e9\u5f62","text":"
def largest_rectangle(heights):\n    stack = []  # \u7d22\u5f15\u6808\uff0c\u9ad8\u5ea6\u9012\u589e\n    max_area = 0\n    heights.append(0)  # \u54e8\u5175\uff0c\u7528\u4e8e\u6700\u540e\u6e05\u7a7a\u6808\n\n    for i, h in enumerate(heights):\n        start = i\n        while stack and stack[-1][1] > h:\n            idx, height = stack.pop()\n            max_area = max(max_area, height * (i - idx))\n            start = idx  # \u5f53\u524d\u6761\u5f62\u53ef\u4ee5\u5ef6\u4f38\u5230\u88ab\u5f39\u51fa\u6761\u5f62\u5f00\u59cb\u7684\u4f4d\u7f6e\n        stack.append((start, h))\n\n    heights.pop()  # \u79fb\u9664\u54e8\u5175\n    return max_area\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_11","title":"\u961f\u5217","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_12","title":"\u7b80\u5355\uff1a\u7528\u6808\u5b9e\u73b0\u961f\u5217","text":"
class MyQueue:\n    def __init__(self):\n        self.push_stack = []\n        self.pop_stack = []\n\n    def push(self, x):\n        self.push_stack.append(x)\n\n    def pop(self):\n        if not self.pop_stack:\n            while self.push_stack:\n                self.pop_stack.append(self.push_stack.pop())\n        return self.pop_stack.pop()\n\n    def peek(self):\n        if not self.pop_stack:\n            while self.push_stack:\n                self.pop_stack.append(self.push_stack.pop())\n        return self.pop_stack[-1]\n\n    def empty(self):\n        return not self.push_stack and not self.pop_stack\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_13","title":"\u4f18\u5148\u961f\u5217\u548c\u5806","text":" \u64cd\u4f5c \u65f6\u95f4 \u63d2\u5165 \\(O(\\log n)\\) \u83b7\u53d6\u6700\u5c0f\u503c \\(O(1)\\) \u63d0\u53d6\u6700\u5c0f\u503c \\(O(\\log n)\\) \u4ece\u6570\u7ec4\u6784\u5efa\u5806 \\(O(n)\\)
import heapq\n\n# \u6700\u5c0f\u5806\nh = []\nheapq.heappush(h, 5)\nheapq.heappush(h, 2)\nheapq.heappush(h, 8)\nprint(heapq.heappop(h))  # 2\uff08\u6700\u5c0f\uff09\n\n# \u6700\u5927\u5806\u6280\u5de7\uff1a\u53d6\u53cd\nheapq.heappush(h, -10)\nprint(-heapq.heappop(h))  # 10\uff08\u6700\u5927\uff09\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#k_1","title":"\u4e2d\u7b49\uff1a\u6570\u7ec4\u4e2d\u7684\u7b2c K \u4e2a\u6700\u5927\u5143\u7d20","text":"
import heapq\n\ndef find_kth_largest(nums, k):\n    heap = nums[:k]\n    heapq.heapify(heap)  # O(k)\n\n    for num in nums[k:]:\n        if num > heap[0]:\n            heapq.heapreplace(heap, num)  # \u5f39\u51fa\u6700\u5c0f\u503c\uff0c\u63a8\u5165 num\uff1aO(log k)\n\n    return heap[0]\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#k_2","title":"\u56f0\u96be\uff1a\u5408\u5e76 K \u4e2a\u6392\u5e8f\u94fe\u8868","text":"
import heapq\n\ndef merge_k_lists(lists):\n    heap = []\n    for i, lst in enumerate(lists):\n        if lst:\n            heapq.heappush(heap, (lst.val, i, lst))\n\n    dummy = ListNode(0)\n    curr = dummy\n\n    while heap:\n        val, i, node = heapq.heappop(heap)\n        curr.next = node\n        curr = curr.next\n        if node.next:\n            heapq.heappush(heap, (node.next.val, i, node.next))\n\n    return dummy.next\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_14","title":"\u5e38\u89c1\u9677\u9631\u603b\u7ed3","text":"\u9677\u9631 \u793a\u4f8b \u4fee\u590d fast.next \u4e0a\u7684\u7a7a\u6307\u9488 \u5faa\u73af\u68c0\u6d4b\u4e2d\u4f7f\u7528 while fast.next \u68c0\u67e5 fast and fast.next \u672a\u5904\u7406\u7a7a\u94fe\u8868 \u53cd\u8f6c None \u6dfb\u52a0 if not head \u5b88\u536b \u6808\u4e0b\u6ea2 \u4ece\u7a7a\u6808\u5f39\u51fa \u68c0\u67e5 len(stack) > 0 \u6216 if stack \u5fd8\u8bb0\u54e8\u5175 \u76f4\u65b9\u56fe\u9057\u6f0f\u4e86\u6700\u540e\u7684\u6761\u5f62 \u8ffd\u52a0 0 \u6765\u6e05\u7a7a\u6808 \u5806\u4e2d\u7f3a\u5c11\u5e73\u5c40\u6253\u7834 \u6bd4\u8f83\u4e0d\u53ef\u6bd4\u8f83\u7684\u5bf9\u8c61 \u5411\u5806\u5143\u7ec4\u6dfb\u52a0\u7d22\u5f15 \u904d\u5386\u65f6\u4fee\u6539\u94fe\u8868 \u904d\u5386\u65f6\u5220\u9664\u8282\u70b9 \u4f7f\u7528 prev/curr \u6a21\u5f0f\u6216\u865a\u62df\u5934\u8282\u70b9"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#neetcode","title":"\u8bfe\u540e\u7ec3\u4e60\u9898\uff08NeetCode\uff09","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_15","title":"\u94fe\u8868","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_16","title":"\u6808","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/02.%20linked%20lists%2C%20stacks%2C%20and%20queues/#_17","title":"\u5806 / \u4f18\u5148\u961f\u5217","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/","title":"\u6811","text":"

\u6811\u662f\u5c42\u6b21\u5316\u6570\u636e\u7ed3\u6784\uff0c\u662f\u6587\u4ef6\u7cfb\u7edf\u3001\u6570\u636e\u5e93\u3001\u7f16\u8bd1\u5668\u548c\u65e0\u6570\u9762\u8bd5\u9898\u80cc\u540e\u7684\u57fa\u7840\u3002\u672c\u6587\u4ef6\u6db5\u76d6\u4e8c\u53c9\u6811\u3001\u4e8c\u53c9\u641c\u7d22\u6811\u3001\u5e73\u8861\u6811\u3001\u524d\u7f00\u6811\u3001\u7ebf\u6bb5\u6811\u3001\u6811\u72b6\u6570\u7ec4\u548c\u5e76\u67e5\u96c6\uff0c\u5305\u62ec\u904d\u5386\u6a21\u5f0f\u3001\u9012\u5f52\u601d\u7ef4\u4ee5\u53ca\u9010\u6b65\u589e\u52a0\u96be\u5ea6\u7684\u9898\u76ee\u3002

"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#_2","title":"\u4e8c\u53c9\u6811\u904d\u5386","text":"
class TreeNode:\n    def __init__(self, val=0, left=None, right=None):\n        self.val = val\n        self.left = left\n        self.right = right\n\ndef inorder(root):\n    if not root:\n        return []\n    return inorder(root.left) + [root.val] + inorder(root.right)\n\ndef preorder(root):\n    if not root:\n        return []\n    return [root.val] + preorder(root.left) + preorder(root.right)\n\ndef postorder(root):\n    if not root:\n        return []\n    return postorder(root.left) + postorder(root.right) + [root.val]\n\nfrom collections import deque\n\ndef level_order(root):\n    if not root:\n        return []\n    result, queue = [], deque([root])\n    while queue:\n        level = []\n        for _ in range(len(queue)):\n            node = queue.popleft()\n            level.append(node.val)\n            if node.left:\n                queue.append(node.left)\n            if node.right:\n                queue.append(node.right)\n        result.append(level)\n    return result\n
def inorder_efficient(root, result=None):\n    if result is None:\n        result = []\n    if root:\n        inorder_efficient(root.left, result)\n        result.append(root.val)\n        inorder_efficient(root.right, result)\n    return result\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#_3","title":"\u7b80\u5355\uff1a\u4e8c\u53c9\u6811\u7684\u6700\u5927\u6df1\u5ea6","text":"
def max_depth(root):\n    if not root:\n        return 0\n    return 1 + max(max_depth(root.left), max_depth(root.right))\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#_4","title":"\u7b80\u5355\uff1a\u7ffb\u8f6c\u4e8c\u53c9\u6811","text":"
def invert_tree(root):\n    if not root:\n        return None\n    root.left, root.right = invert_tree(root.right), invert_tree(root.left)\n    return root\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#_5","title":"\u4e2d\u7b49\uff1a\u4e8c\u53c9\u6811\u7684\u6700\u8fd1\u516c\u5171\u7956\u5148","text":"
def lowest_common_ancestor(root, p, q):\n    if not root or root == p or root == q:\n        return root\n\n    left = lowest_common_ancestor(root.left, p, q)\n    right = lowest_common_ancestor(root.right, p, q)\n\n    if left and right:\n        return root  # p \u548c q \u5728\u4e0d\u540c\u5b50\u6811\u4e2d\n    return left if left else right\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#_6","title":"\u56f0\u96be\uff1a\u4e8c\u53c9\u6811\u4e2d\u7684\u6700\u5927\u8def\u5f84\u548c","text":"
def max_path_sum(root):\n    best = [float('-inf')]\n\n    def dfs(node):\n        if not node:\n            return 0\n        left = max(dfs(node.left), 0)   # \u5ffd\u7565\u8d1f\u8def\u5f84\n        right = max(dfs(node.right), 0)\n\n        # \u7ecf\u8fc7\u5f53\u524d\u8282\u70b9\u7684\u8def\u5f84\uff08\u53ef\u80fd\u4f5c\u4e3a\"\u8f6c\u5f2f\u70b9\"\uff09\n        best[0] = max(best[0], node.val + left + right)\n\n        # \u8fd4\u56de\u5230\u7236\u8282\u70b9\u7684\u6700\u5927\u589e\u76ca\n        return node.val + max(left, right)\n\n    dfs(root)\n    return best[0]\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#bst","title":"\u4e8c\u53c9\u641c\u7d22\u6811\uff08BST\uff09","text":"
def search_bst(root, target):\n    if not root:\n        return None\n    if target < root.val:\n        return search_bst(root.left, target)\n    elif target > root.val:\n        return search_bst(root.right, target)\n    else:\n        return root\n\ndef insert_bst(root, val):\n    if not root:\n        return TreeNode(val)\n    if val < root.val:\n        root.left = insert_bst(root.left, val)\n    else:\n        root.right = insert_bst(root.right, val)\n    return root\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#_7","title":"\u4e2d\u7b49\uff1a\u9a8c\u8bc1\u4e8c\u53c9\u641c\u7d22\u6811","text":"
def is_valid_bst(root, lo=float('-inf'), hi=float('inf')):\n    if not root:\n        return True\n    if root.val <= lo or root.val >= hi:\n        return False\n    return (is_valid_bst(root.left, lo, root.val) and\n            is_valid_bst(root.right, root.val, hi))\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#k","title":"\u4e2d\u7b49\uff1a\u4e8c\u53c9\u641c\u7d22\u6811\u4e2d\u7b2c K \u5c0f\u7684\u5143\u7d20","text":"
def kth_smallest(root, k):\n    count = [0]\n    result = [None]\n\n    def inorder(node):\n        if not node or result[0] is not None:\n            return\n        inorder(node.left)\n        count[0] += 1\n        if count[0] == k:\n            result[0] = node.val\n            return\n        inorder(node.right)\n\n    inorder(root)\n    return result[0]\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#trie","title":"\u524d\u7f00\u6811\uff08Trie\uff09","text":"
class TrieNode:\n    def __init__(self):\n        self.children = {}\n        self.is_end = False\n\nclass Trie:\n    def __init__(self):\n        self.root = TrieNode()\n\n    def insert(self, word):\n        node = self.root\n        for char in word:\n            if char not in node.children:\n                node.children[char] = TrieNode()\n            node = node.children[char]\n        node.is_end = True\n\n    def search(self, word):\n        node = self.root\n        for char in word:\n            if char not in node.children:\n                return False\n            node = node.children[char]\n        return node.is_end\n\n    def starts_with(self, prefix):\n        node = self.root\n        for char in prefix:\n            if char not in node.children:\n                return False\n            node = node.children[char]\n        return True\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#ii","title":"\u56f0\u96be\uff1a\u5355\u8bcd\u641c\u7d22 II","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#_8","title":"\u5e76\u67e5\u96c6\uff08\u4e0d\u76f8\u4ea4\u96c6\u5408\uff09","text":"
class UnionFind:\n    def __init__(self, n):\n        self.parent = list(range(n))\n        self.rank = [0] * n\n        self.count = n  # \u8fde\u901a\u5206\u91cf\u6570\n\n    def find(self, x):\n        if self.parent[x] != x:\n            self.parent[x] = self.find(self.parent[x])  # \u8def\u5f84\u538b\u7f29\n        return self.parent[x]\n\n    def union(self, x, y):\n        rx, ry = self.find(x), self.find(y)\n        if rx == ry:\n            return False  # \u5df2\u7ecf\u8fde\u901a\n        # \u6309\u79e9\u5408\u5e76\n        if self.rank[rx] < self.rank[ry]:\n            rx, ry = ry, rx\n        self.parent[ry] = rx\n        if self.rank[rx] == self.rank[ry]:\n            self.rank[rx] += 1\n        self.count -= 1\n        return True\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#_9","title":"\u4e2d\u7b49\uff1a\u8fde\u901a\u5206\u91cf\u6570\u91cf","text":"
def count_components(n, edges):\n    uf = UnionFind(n)\n    for u, v in edges:\n        uf.union(u, v)\n    return uf.count\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#_10","title":"\u4e2d\u7b49\uff1a\u5197\u4f59\u8fde\u63a5","text":"
def find_redundant(edges):\n    uf = UnionFind(len(edges) + 1)\n    for u, v in edges:\n        if not uf.union(u, v):\n            return [u, v]  # \u5df2\u7ecf\u8fde\u901a \u2192 \u8fd9\u6761\u8fb9\u521b\u5efa\u4e86\u73af\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#_11","title":"\u7ebf\u6bb5\u6811\u548c\u6811\u72b6\u6570\u7ec4","text":"
class FenwickTree:\n    def __init__(self, n):\n        self.n = n\n        self.tree = [0] * (n + 1)\n\n    def update(self, i, delta):\n        i += 1  # 1-indexed\n        while i <= self.n:\n            self.tree[i] += delta\n            i += i & (-i)  # \u52a0\u4e0a\u6700\u4f4e\u8bbe\u7f6e\u4f4d\n\n    def prefix_sum(self, i):\n        i += 1\n        total = 0\n        while i > 0:\n            total += self.tree[i]\n            i -= i & (-i)  # \u79fb\u9664\u6700\u4f4e\u8bbe\u7f6e\u4f4d\n        return total\n\n    def range_sum(self, l, r):\n        return self.prefix_sum(r) - (self.prefix_sum(l - 1) if l > 0 else 0)\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#_12","title":"\u5e38\u89c1\u9677\u9631\u603b\u7ed3","text":"\u9677\u9631 \u793a\u4f8b \u4fee\u590d BST \u53ea\u68c0\u67e5\u76f4\u63a5\u5b50\u8282\u70b9 left.val < root.val \u9057\u6f0f\u4e86\u6df1\u5c42\u8fdd\u89c4 \u4f20\u9012 lo/hi \u8fb9\u754c \u9012\u5f52\u4e2d \\(O(n^2)\\) \u5217\u8868\u62fc\u63a5 inorder(left) + [val] + inorder(right) \u8ffd\u52a0\u5230\u5171\u4eab\u5217\u8868 \u5fd8\u8bb0\u57fa\u672c\u60c5\u51b5 \u7a7a\u6811\u4e0a\u7684\u65e0\u9650\u9012\u5f52 if not root: return \u6df7\u6dc6\u7ecf\u8fc7\u8def\u5f84\u548c\u5230\u7236\u8282\u70b9\u7684\u8def\u5f84 \u6700\u5927\u8def\u5f84\u548c\uff1a\u5728\u4e24\u4e2a\u5c42\u7ea7\u5206\u53c9 \u5411\u7236\u8282\u70b9\u8fd4\u56de\u5355\u5206\u652f\uff0c\u5355\u72ec\u8ddf\u8e2a\u53cc\u5206\u652f \u6811\u72b6\u6570\u7ec4 1-indexed vs 0-indexed \u6811\u6570\u7ec4\u4e2d\u7684\u5dee\u4e00\u9519\u8bef \u5165\u53e3\u5904\u59cb\u7ec8 i += 1 \u5e76\u67e5\u96c6\u6ca1\u6709\u8def\u5f84\u538b\u7f29 \u6700\u574f\u60c5\u51b5\u4e0b\u6bcf\u6b21 find \u662f \\(O(n)\\) self.parent[x] = self.find(self.parent[x])"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#neetcode","title":"\u8bfe\u540e\u7ec3\u4e60\u9898\uff08NeetCode\uff09","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#_13","title":"\u4e8c\u53c9\u6811\u6a21\u5f0f","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#bst_1","title":"BST \u6a21\u5f0f","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#_14","title":"\u524d\u7f00\u6811","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/03.%20trees/#_15","title":"\u5e76\u67e5\u96c6","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/","title":"\u56fe","text":"

\u56fe\u5efa\u6a21\u4e86\u5173\u7cfb\u548c\u8fde\u63a5\u2014\u2014\u4ece\u793e\u4ea4\u7f51\u7edc\u5230\u9053\u8def\u5730\u56fe\u518d\u5230\u4f9d\u8d56\u94fe\u3002\u672c\u6587\u4ef6\u6db5\u76d6\u56fe\u7684\u8868\u793a\u3001BFS\u3001DFS\u3001\u6700\u77ed\u8def\u5f84\u3001\u62d3\u6251\u6392\u5e8f\u548c\u8fde\u901a\u5206\u91cf\uff0c\u5305\u62ec\u904d\u5386\u548c\u5bfb\u8def\u6a21\u5f0f\uff0c\u8fd9\u4e9b\u662f\u56fe\u9762\u8bd5\u9898\u4e2d\u7684\u6838\u5fc3\u3002

"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#_2","title":"\u56fe\u7684\u8868\u793a","text":"
# \u65e0\u5411\u56fe\ngraph = {\n    0: [1, 2],\n    1: [0, 3],\n    2: [0, 3],\n    3: [1, 2]\n}\n\n# \u4ece\u8fb9\u5217\u8868\u6784\u5efa\ndef build_graph(n, edges):\n    graph = {i: [] for i in range(n)}\n    for u, v in edges:\n        graph[u].append(v)\n        graph[v].append(u)  # \u6709\u5411\u56fe\u7701\u7565\u8fd9\u4e00\u884c\n    return graph\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#bfs","title":"\u6a21\u5f0f\uff1aBFS\uff08\u5e7f\u5ea6\u4f18\u5148\u641c\u7d22\uff09","text":"
from collections import deque\n\ndef bfs(graph, start):\n    visited = {start}\n    queue = deque([start])\n\n    while queue:\n        node = queue.popleft()\n        for neighbour in graph[node]:\n            if neighbour not in visited:\n                visited.add(neighbour)\n                queue.append(neighbour)\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#_3","title":"\u7b80\u5355\uff1a\u5c9b\u5c7f\u6570\u91cf","text":"
from collections import deque\n\ndef num_islands(grid):\n    if not grid:\n        return 0\n\n    rows, cols = len(grid), len(grid[0])\n    count = 0\n\n    for r in range(rows):\n        for c in range(cols):\n            if grid[r][c] == '1':\n                count += 1\n                # BFS \u6807\u8bb0\u6574\u4e2a\u5c9b\u5c7f\n                queue = deque([(r, c)])\n                grid[r][c] = '0'  # \u6807\u8bb0\u5df2\u8bbf\u95ee\n                while queue:\n                    cr, cc = queue.popleft()\n                    for dr, dc in [(0,1),(0,-1),(1,0),(-1,0)]:\n                        nr, nc = cr + dr, cc + dc\n                        if 0 <= nr < rows and 0 <= nc < cols and grid[nr][nc] == '1':\n                            grid[nr][nc] = '0'\n                            queue.append((nr, nc))\n\n    return count\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#_4","title":"\u4e2d\u7b49\uff1a\u8150\u70c2\u7684\u6a58\u5b50","text":"
from collections import deque\n\ndef oranges_rotting(grid):\n    rows, cols = len(grid), len(grid[0])\n    queue = deque()\n    fresh = 0\n\n    for r in range(rows):\n        for c in range(cols):\n            if grid[r][c] == 2:\n                queue.append((r, c))\n            elif grid[r][c] == 1:\n                fresh += 1\n\n    if fresh == 0:\n        return 0\n\n    time = 0\n    while queue and fresh > 0:\n        time += 1\n        for _ in range(len(queue)):\n            cr, cc = queue.popleft()\n            for dr, dc in [(0,1),(0,-1),(1,0),(-1,0)]:\n                nr, nc = cr + dr, cc + dc\n                if 0 <= nr < rows and 0 <= nc < cols and grid[nr][nc] == 1:\n                    grid[nr][nc] = 2\n                    fresh -= 1\n                    queue.append((nr, nc))\n\n    return time if fresh == 0 else -1\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#dfs","title":"\u6a21\u5f0f\uff1aDFS\uff08\u6df1\u5ea6\u4f18\u5148\u641c\u7d22\uff09","text":"
def dfs(graph, node, visited=None):\n    if visited is None:\n        visited = set()\n    visited.add(node)\n    for neighbour in graph[node]:\n        if neighbour not in visited:\n            dfs(graph, neighbour, visited)\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#_5","title":"\u4e2d\u7b49\uff1a\u8bfe\u7a0b\u8868\uff08\u73af\u68c0\u6d4b\uff09","text":"
def can_finish(num_courses, prerequisites):\n    graph = {i: [] for i in range(num_courses)}\n    for course, prereq in prerequisites:\n        graph[course].append(prereq)\n\n    # 0 = \u672a\u8bbf\u95ee, 1 = \u8fdb\u884c\u4e2d, 2 = \u5df2\u5b8c\u6210\n    state = [0] * num_courses\n\n    def has_cycle(node):\n        if state[node] == 1:\n            return True   # \u56de\u8fb9 \u2192 \u73af\n        if state[node] == 2:\n            return False  # \u5df2\u7ecf\u5b8c\u5168\u63a2\u7d22\u8fc7\n\n        state[node] = 1  # \u6807\u8bb0\u4e3a\u8fdb\u884c\u4e2d\n        for neighbour in graph[node]:\n            if has_cycle(neighbour):\n                return True\n        state[node] = 2  # \u6807\u8bb0\u4e3a\u5df2\u5b8c\u6210\n        return False\n\n    for course in range(num_courses):\n        if has_cycle(course):\n            return False\n    return True\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#ii","title":"\u4e2d\u7b49\uff1a\u8bfe\u7a0b\u8868 II\uff08\u62d3\u6251\u6392\u5e8f\uff09","text":"
from collections import deque\n\ndef find_order(num_courses, prerequisites):\n    graph = {i: [] for i in range(num_courses)}\n    indegree = [0] * num_courses\n\n    for course, prereq in prerequisites:\n        graph[prereq].append(course)\n        indegree[course] += 1\n\n    queue = deque([i for i in range(num_courses) if indegree[i] == 0])\n    order = []\n\n    while queue:\n        node = queue.popleft()\n        order.append(node)\n        for neighbour in graph[node]:\n            indegree[neighbour] -= 1\n            if indegree[neighbour] == 0:\n                queue.append(neighbour)\n\n    return order if len(order) == num_courses else []  # \u7a7a = \u5b58\u5728\u73af\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#_6","title":"\u6700\u77ed\u8def\u5f84","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#dijkstra","title":"Dijkstra \u7b97\u6cd5","text":"
import heapq\n\ndef dijkstra(graph, start):\n    # graph: {node: [(neighbour, weight), ...]}\n    dist = {node: float('inf') for node in graph}\n    dist[start] = 0\n    heap = [(0, start)]\n\n    while heap:\n        d, node = heapq.heappop(heap)\n        if d > dist[node]:\n            continue  # \u8fc7\u671f\u6761\u76ee\n\n        for neighbour, weight in graph[node]:\n            new_dist = d + weight\n            if new_dist < dist[neighbour]:\n                dist[neighbour] = new_dist\n                heapq.heappush(heap, (new_dist, neighbour))\n\n    return dist\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#_7","title":"\u56f0\u96be\uff1a\u7f51\u7edc\u5ef6\u8fdf\u65f6\u95f4","text":"
def network_delay(times, n, k):\n    graph = {i: [] for i in range(1, n + 1)}\n    for u, v, w in times:\n        graph[u].append((v, w))\n\n    dist = dijkstra(graph, k)\n    max_time = max(dist.values())\n    return max_time if max_time < float('inf') else -1\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#_8","title":"\u5f3a\u8fde\u901a\u5206\u91cf","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#_9","title":"\u5e38\u89c1\u9677\u9631\u603b\u7ed3","text":"\u9677\u9631 \u793a\u4f8b \u4fee\u590d \u5728\u51fa\u961f\u65f6\u6807\u8bb0\u5df2\u8bbf\u95ee \u540c\u4e00\u8282\u70b9\u88ab\u591a\u6b21\u5165\u961f \u5728\u5165\u961f\u65f6\u6807\u8bb0\u5df2\u8bbf\u95ee \u6709\u5411\u56fe\u4e2d\u53ea\u6709\u4e24\u79cd\u72b6\u6001 \u65e0\u6cd5\u533a\u5206\u56de\u8fb9\u548c\u4ea4\u53c9\u8fb9 \u4f7f\u7528\u4e09\u79cd\u72b6\u6001\uff1a\u672a\u8bbf\u95ee/\u8fdb\u884c\u4e2d/\u5df2\u5b8c\u6210 Dijkstra \u7528\u4e8e\u8d1f\u6743\u91cd \u9519\u8bef\u7684\u6700\u77ed\u8def\u5f84 \u4f7f\u7528 Bellman-Ford \u5fd8\u8bb0 if d > dist[node]: continue \u5904\u7406\u8fc7\u671f\u5806\u6761\u76ee \u603b\u662f\u8df3\u8fc7\u5f53\u524d\u8ddd\u79bb\u66f4\u5dee\u7684\u60c5\u51b5 \u7f51\u683c\u8fb9\u754c\u68c0\u67e5 \u7d22\u5f15\u8d8a\u754c 0 <= nr < rows and 0 <= nc < cols \u6ca1\u6709\u8003\u8651 time=0 \u7684\u8fb9\u754c\u60c5\u51b5 \u8150\u70c2\u6a58\u5b50\uff1a\u6ca1\u6709\u65b0\u9c9c\u6a58\u5b50 \u5728 BFS \u4e4b\u524d\u68c0\u67e5 fresh == 0 \u5c06\u6709\u5411\u56fe\u6784\u5efa\u4e3a\u65e0\u5411\u56fe \u5148\u4fee\u6761\u4ef6\u662f\u5355\u5411\u7684 \u53ea\u5728\u4e00\u4e2a\u65b9\u5411\u6dfb\u52a0\u8fb9"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#neetcode","title":"\u8bfe\u540e\u7ec3\u4e60\u9898\uff08NeetCode\uff09","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#bfs_1","title":"BFS \u6a21\u5f0f","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#dfs_1","title":"DFS 
\u6a21\u5f0f","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#_10","title":"\u6700\u77ed\u8def\u5f84","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/04.%20graphs/#_11","title":"\u8fdb\u9636","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/","title":"\u6392\u5e8f\u3001\u641c\u7d22\u4e0e\u7b97\u6cd5\u8bbe\u8ba1","text":"

\u6392\u5e8f\u548c\u641c\u7d22\u662f\u6700\u57fa\u7840\u7684\u7b97\u6cd5\u64cd\u4f5c\u3002\u672c\u6587\u4ef6\u6db5\u76d6\u6392\u5e8f\u7b97\u6cd5\u3001\u4e8c\u5206\u67e5\u627e\u6a21\u5f0f\u3001\u5206\u6cbb\u6cd5\u3001\u8d2a\u5fc3\u7b97\u6cd5\u3001\u52a8\u6001\u89c4\u5212\u548c\u56de\u6eaf\u3002

"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_2","title":"\u6392\u5e8f\u7b97\u6cd5","text":" \u7b97\u6cd5 \u6700\u597d \u5e73\u5747 \u6700\u574f \u7a7a\u95f4 \u7a33\u5b9a\uff1f \u5192\u6ce1\u6392\u5e8f \\(O(n)\\) \\(O(n^2)\\) \\(O(n^2)\\) \\(O(1)\\) \u662f \u63d2\u5165\u6392\u5e8f \\(O(n)\\) \\(O(n^2)\\) \\(O(n^2)\\) \\(O(1)\\) \u662f \u5f52\u5e76\u6392\u5e8f \\(O(n \\log n)\\) \\(O(n \\log n)\\) \\(O(n \\log n)\\) \\(O(n)\\) \u662f \u5feb\u901f\u6392\u5e8f \\(O(n \\log n)\\) \\(O(n \\log n)\\) \\(O(n^2)\\) \\(O(\\log n)\\) \u5426 \u5806\u6392\u5e8f \\(O(n \\log n)\\) \\(O(n \\log n)\\) \\(O(n \\log n)\\) \\(O(1)\\) \u5426 \u8ba1\u6570\u6392\u5e8f \\(O(n + k)\\) \\(O(n + k)\\) \\(O(n + k)\\) \\(O(k)\\) \u662f \u57fa\u6570\u6392\u5e8f \\(O(d(n + k))\\) \\(O(d(n + k))\\) \\(O(d(n + k))\\) \\(O(n + k)\\) \u662f "},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_3","title":"\u5f52\u5e76\u6392\u5e8f","text":"
def merge_sort(arr):\n    if len(arr) <= 1:\n        return arr\n\n    mid = len(arr) // 2\n    left = merge_sort(arr[:mid])\n    right = merge_sort(arr[mid:])\n\n    return merge(left, right)\n\ndef merge(left, right):\n    result = []\n    i = j = 0\n    while i < len(left) and j < len(right):\n        if left[i] <= right[j]:  # <= \u4fdd\u8bc1\u7a33\u5b9a\u6027\n            result.append(left[i])\n            i += 1\n        else:\n            result.append(right[j])\n            j += 1\n    result.extend(left[i:])\n    result.extend(right[j:])\n    return result\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_4","title":"\u5feb\u901f\u6392\u5e8f","text":"
def quicksort(arr, lo=0, hi=None):\n    if hi is None:\n        hi = len(arr) - 1\n    if lo >= hi:\n        return\n\n    pivot_idx = partition(arr, lo, hi)\n    quicksort(arr, lo, pivot_idx - 1)\n    quicksort(arr, pivot_idx + 1, hi)\n\ndef partition(arr, lo, hi):\n    pivot = arr[hi]  # Lomuto \u5206\u533a\uff1a\u57fa\u51c6\u662f\u6700\u540e\u4e00\u4e2a\u5143\u7d20\n    i = lo\n    for j in range(lo, hi):\n        if arr[j] < pivot:\n            arr[i], arr[j] = arr[j], arr[i]\n            i += 1\n    arr[i], arr[hi] = arr[hi], arr[i]\n    return i\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_5","title":"\u8ba1\u6570\u6392\u5e8f","text":"
def counting_sort(arr, k):\n    count = [0] * k\n    for x in arr:\n        count[x] += 1\n    result = []\n    for val in range(k):\n        result.extend([val] * count[val])\n    return result\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_6","title":"\u6a21\u5f0f\uff1a\u4e8c\u5206\u67e5\u627e","text":"
def binary_search(arr, target):\n    lo, hi = 0, len(arr) - 1\n\n    while lo <= hi:\n        mid = lo + (hi - lo) // 2  # \u5728\u5176\u4ed6\u8bed\u8a00\u4e2d\u907f\u514d\u6ea2\u51fa\n        if arr[mid] == target:\n            return mid\n        elif arr[mid] < target:\n            lo = mid + 1\n        else:\n            hi = mid - 1\n\n    return -1  # \u672a\u627e\u5230\n
def lower_bound(arr, target):\n    lo, hi = 0, len(arr)\n    while lo < hi:\n        mid = (lo + hi) // 2\n        if arr[mid] < target:\n            lo = mid + 1\n        else:\n            hi = mid\n    return lo\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_7","title":"\u7b80\u5355\uff1a\u4e8c\u5206\u67e5\u627e","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_8","title":"\u4e2d\u7b49\uff1a\u641c\u7d22\u65cb\u8f6c\u6392\u5e8f\u6570\u7ec4","text":"
def search_rotated(nums, target):\n    lo, hi = 0, len(nums) - 1\n\n    while lo <= hi:\n        mid = (lo + hi) // 2\n        if nums[mid] == target:\n            return mid\n\n        # \u5de6\u534a\u90e8\u5206\u6709\u5e8f\n        if nums[lo] <= nums[mid]:\n            if nums[lo] <= target < nums[mid]:\n                hi = mid - 1\n            else:\n                lo = mid + 1\n        # \u53f3\u534a\u90e8\u5206\u6709\u5e8f\n        else:\n            if nums[mid] < target <= nums[hi]:\n                lo = mid + 1\n            else:\n                hi = mid - 1\n\n    return -1\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_9","title":"\u56f0\u96be\uff1a\u5bfb\u627e\u4e24\u4e2a\u6709\u5e8f\u6570\u7ec4\u7684\u4e2d\u4f4d\u6570","text":"
def find_median(nums1, nums2):\n    if len(nums1) > len(nums2):\n        nums1, nums2 = nums2, nums1  # \u786e\u4fdd nums1 \u8f83\u77ed\n\n    m, n = len(nums1), len(nums2)\n    lo, hi = 0, m\n    half = (m + n + 1) // 2\n\n    while lo <= hi:\n        i = (lo + hi) // 2          # nums1 \u4e2d\u7684\u5206\u5272\u70b9\n        j = half - i                 # nums2 \u4e2d\u7684\u5206\u5272\u70b9\n\n        left1 = nums1[i - 1] if i > 0 else float('-inf')\n        right1 = nums1[i] if i < m else float('inf')\n        left2 = nums2[j - 1] if j > 0 else float('-inf')\n        right2 = nums2[j] if j < n else float('inf')\n\n        if left1 <= right2 and left2 <= right1:\n            # \u6b63\u786e\u5206\u5272\n            if (m + n) % 2 == 1:\n                return max(left1, left2)\n            return (max(left1, left2) + min(right1, right2)) / 2\n        elif left1 > right2:\n            hi = i - 1\n        else:\n            lo = i + 1\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_10","title":"\u5143\u6a21\u5f0f\uff1a\u5bf9\u7b54\u6848\u8fdb\u884c\u4e8c\u5206\u67e5\u627e","text":"
def ship_within_days(weights, days):\n    lo, hi = max(weights), sum(weights)\n\n    while lo < hi:\n        mid = (lo + hi) // 2\n        # \u80fd\u5426\u4ee5\u8fd0\u529b mid \u5728 <= days \u5929\u5185\u8fd0\u9001\u5b8c\uff1f\n        current_load, num_days = 0, 1\n        for w in weights:\n            if current_load + w > mid:\n                num_days += 1\n                current_load = 0\n            current_load += w\n\n        if num_days <= days:\n            hi = mid\n        else:\n            lo = mid + 1\n\n    return lo\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_11","title":"\u6a21\u5f0f\uff1a\u8d2a\u5fc3\u7b97\u6cd5","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_12","title":"\u4e2d\u7b49\uff1a\u8df3\u8dc3\u6e38\u620f","text":"
def can_jump(nums):\n    max_reach = 0\n    for i, jump in enumerate(nums):\n        if i > max_reach:\n            return False  # \u65e0\u6cd5\u5230\u8fbe\u8fd9\u4e2a\u4f4d\u7f6e\n        max_reach = max(max_reach, i + jump)\n    return True\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_13","title":"\u4e2d\u7b49\uff1a\u5408\u5e76\u533a\u95f4","text":"
def merge_intervals(intervals):\n    intervals.sort(key=lambda x: x[0])\n    merged = [intervals[0]]\n\n    for start, end in intervals[1:]:\n        if start <= merged[-1][1]:\n            merged[-1][1] = max(merged[-1][1], end)\n        else:\n            merged.append([start, end])\n\n    return merged\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_14","title":"\u6a21\u5f0f\uff1a\u52a8\u6001\u89c4\u5212","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_15","title":"\u7b80\u5355\uff1a\u722c\u697c\u68af","text":"
def climb_stairs(n):\n    if n <= 2:\n        return n\n    a, b = 1, 2\n    for _ in range(3, n + 1):\n        a, b = b, a + b\n    return b\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_16","title":"\u4e2d\u7b49\uff1a\u96f6\u94b1\u5151\u6362","text":"
def coin_change(coins, amount):\n    dp = [float('inf')] * (amount + 1)\n    dp[0] = 0\n\n    for a in range(1, amount + 1):\n        for coin in coins:\n            if coin <= a and dp[a - coin] + 1 < dp[a]:\n                dp[a] = dp[a - coin] + 1\n\n    return dp[amount] if dp[amount] != float('inf') else -1\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_17","title":"\u4e2d\u7b49\uff1a\u6700\u957f\u516c\u5171\u5b50\u5e8f\u5217","text":"
def longest_common_subsequence(text1, text2):\n    m, n = len(text1), len(text2)\n    dp = [[0] * (n + 1) for _ in range(m + 1)]\n\n    for i in range(1, m + 1):\n        for j in range(1, n + 1):\n            if text1[i - 1] == text2[j - 1]:\n                dp[i][j] = dp[i - 1][j - 1] + 1\n            else:\n                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])\n\n    return dp[m][n]\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#01","title":"\u56f0\u96be\uff1a0/1 \u80cc\u5305","text":"
def knapsack(weights, values, capacity):\n    n = len(weights)\n    dp = [[0] * (capacity + 1) for _ in range(n + 1)]\n\n    for i in range(1, n + 1):\n        for w in range(capacity + 1):\n            dp[i][w] = dp[i - 1][w]  # \u8df3\u8fc7\u7269\u54c1 i\n            if weights[i - 1] <= w:\n                dp[i][w] = max(dp[i][w],\n                               dp[i - 1][w - weights[i - 1]] + values[i - 1])\n\n    return dp[n][capacity]\n
def knapsack_optimised(weights, values, capacity):\n    dp = [0] * (capacity + 1)\n    for i in range(len(weights)):\n        for w in range(capacity, weights[i] - 1, -1):  # \u4ece\u53f3\u5411\u5de6\uff01\n            dp[w] = max(dp[w], dp[w - weights[i]] + values[i])\n    return dp[capacity]\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_18","title":"\u6a21\u5f0f\uff1a\u56de\u6eaf","text":"
def backtrack(candidates, path, result):\n    if is_solution(path):\n        result.append(path[:])  # \u590d\u5236\uff01\n        return\n\n    for candidate in get_candidates(path):\n        if is_valid(candidate, path):\n            path.append(candidate)     # \u9009\u62e9\n            backtrack(candidates, path, result)  # \u63a2\u7d22\n            path.pop()                 # \u64a4\u9500\uff08\u56de\u6eaf\uff09\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_19","title":"\u4e2d\u7b49\uff1a\u5b50\u96c6","text":"
def subsets(nums):\n    result = []\n    def backtrack(start, path):\n        result.append(path[:])\n        for i in range(start, len(nums)):\n            path.append(nums[i])\n            backtrack(i + 1, path)\n            path.pop()\n    backtrack(0, [])\n    return result\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_20","title":"\u4e2d\u7b49\uff1a\u7ec4\u5408\u603b\u548c","text":"
def combination_sum(candidates, target):\n    result = []\n    def backtrack(start, path, remaining):\n        if remaining == 0:\n            result.append(path[:])\n            return\n        for i in range(start, len(candidates)):\n            if candidates[i] > remaining:\n                break  # \u526a\u679d\uff1a\u5df2\u6392\u5e8f\uff0c\u540e\u7eed\u5019\u9009\u90fd\u592a\u5927\n            path.append(candidates[i])\n            backtrack(i, path, remaining - candidates[i])  # i \u800c\u4e0d\u662f i+1\uff1a\u5141\u8bb8\u91cd\u590d\u4f7f\u7528\n            path.pop()\n\n    candidates.sort()  # \u6392\u5e8f\u4ee5\u4fbf\u526a\u679d\n    backtrack(0, [], target)\n    return result\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#n","title":"\u56f0\u96be\uff1aN \u7687\u540e","text":"
def solve_n_queens(n):\n    result = []\n    cols = set()\n    pos_diag = set()  # (row + col) \u5728 / \u5bf9\u89d2\u7ebf\u4e0a\u4e3a\u5e38\u6570\n    neg_diag = set()  # (row - col) \u5728 \\ \u5bf9\u89d2\u7ebf\u4e0a\u4e3a\u5e38\u6570\n\n    board = [['.' ] * n for _ in range(n)]\n\n    def backtrack(row):\n        if row == n:\n            result.append([''.join(r) for r in board])\n            return\n\n        for col in range(n):\n            if col in cols or (row + col) in pos_diag or (row - col) in neg_diag:\n                continue\n\n            cols.add(col)\n            pos_diag.add(row + col)\n            neg_diag.add(row - col)\n            board[row][col] = 'Q'\n\n            backtrack(row + 1)\n\n            cols.remove(col)\n            pos_diag.remove(row + col)\n            neg_diag.remove(row - col)\n            board[row][col] = '.'\n\n    backtrack(0)\n    return result\n
"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_21","title":"\u5e38\u89c1\u9677\u9631\u603b\u7ed3","text":"\u9677\u9631 \u793a\u4f8b \u4fee\u590d \u4e8c\u5206\u67e5\u627e\u4e2d lo <= hi vs lo < hi \u8fb9\u754c\u5dee\u4e00\u9519\u8bef \u6839\u636e hi \u662f\u5305\u542b\u8fd8\u662f\u6392\u9664\u6765\u9009\u62e9 \u4ece\u5de6\u5230\u53f3\u7684\u4e00\u7ef4\u80cc\u5305 \u7269\u54c1\u88ab\u591a\u6b21\u4f7f\u7528 0/1 \u80cc\u5305\u4ece\u53f3\u5411\u5de6\u8fed\u4ee3 \u56de\u6eaf\u4e2d\u672a\u590d\u5236\u8def\u5f84 result.append(path) \u2014 \u6240\u6709\u6761\u76ee\u6307\u5411\u540c\u4e00\u5217\u8868 result.append(path[:]) \u6216 path.copy() backtrack(i) vs backtrack(i+1) \u91cd\u590d\u4f7f\u7528 vs \u4e0d\u91cd\u590d\u4f7f\u7528\u5143\u7d20 \u5339\u914d\u95ee\u9898\u8981\u6c42 \u6392\u5e8f\u540e\u7684\u56de\u6eaf\u4e2d\u7f3a\u5c11 break \u63a2\u7d22\u8fc7\u5927\u7684\u5019\u9009 \u6392\u5e8f + \u5019\u9009\u8d85\u8fc7\u5269\u4f59\u65f6 break DP \u521d\u59cb\u5316 dp[0] \u9519\u8bef \u2192 \u6240\u6709\u540e\u7eed\u503c\u90fd\u9519 \u4ed4\u7ec6\u5b9a\u4e49\u57fa\u672c\u60c5\u51b5 \u672a\u7ecf\u8bc1\u660e\u7684\u8d2a\u5fc3 \u8d2a\u5fc3\u5e76\u4e0d\u603b\u662f\u6709\u6548 \u9a8c\u8bc1\u8d2a\u5fc3\u9009\u62e9\u6027\u8d28 \u591a\u952e\u6392\u5e8f\u65f6\u4e0d\u7a33\u5b9a \u76f8\u7b49\u5143\u7d20\u7684\u76f8\u5bf9\u987a\u5e8f\u4e22\u5931 \u4f7f\u7528\u7a33\u5b9a\u6392\u5e8f\uff08\u5f52\u5e76\u6392\u5e8f\u3001Python \u7684 
sorted\uff09"},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#neetcode","title":"\u8bfe\u540e\u7ec3\u4e60\u9898\uff08NeetCode\uff09","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_22","title":"\u4e8c\u5206\u67e5\u627e","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_23","title":"\u8d2a\u5fc3","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_24","title":"\u52a8\u6001\u89c4\u5212","text":""},{"location":"chapter%2014%3A%20data%20structures%20and%20algorithms/05.%20sorting%20and%20search/#_25","title":"\u56de\u6eaf","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/","title":"Linux \u4e0e\u547d\u4ee4\u884c","text":"

\u547d\u4ee4\u884c\u662f\u673a\u5668\u5b66\u4e60\u5de5\u7a0b\u7684\u4e3b\u8981\u754c\u9762\uff1a\u8bad\u7ec3\u4efb\u52a1\u3001\u670d\u52a1\u5668\u7ba1\u7406\u3001\u6570\u636e\u7ba1\u9053\u548c\u96c6\u7fa4\u7ba1\u7406\u90fd\u901a\u8fc7\u7ec8\u7aef\u8fdb\u884c\u3002\u672c\u6587\u6db5\u76d6 Shell\u3001\u6587\u4ef6\u7cfb\u7edf\u3001\u6743\u9650\u3001\u8fdb\u7a0b\u7ba1\u7406\u3001\u5305\u7ba1\u7406\u5668\u3001\u73af\u5883\u53d8\u91cf\u3001SSH \u4ee5\u53ca\u6bcf\u4f4d\u673a\u5668\u5b66\u4e60\u5de5\u7a0b\u5e08\u65e5\u5e38\u4f7f\u7528\u7684\u57fa\u672c\u547d\u4ee4\u3002

"},{"location":"chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/#shell","title":"Shell","text":"
ls -la /home/user    # \u547d\u4ee4=ls, \u9009\u9879=-la, \u53c2\u6570=/home/user\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/#_1","title":"\u57fa\u672c\u5bfc\u822a","text":"
pwd                 # \u6253\u5370\u5f53\u524d\u5de5\u4f5c\u76ee\u5f55\uff08\u6211\u5728\u54ea\uff1f\uff09\nls                  # \u5217\u51fa\u5f53\u524d\u76ee\u5f55\u4e2d\u7684\u6587\u4ef6\nls -la              # \u5217\u51fa\u6240\u6709\u6587\u4ef6\uff08\u5305\u62ec\u9690\u85cf\u6587\u4ef6\uff09\u53ca\u8be6\u7ec6\u4fe1\u606f\ncd /path/to/dir     # \u5207\u6362\u76ee\u5f55\ncd ..               # \u8fd4\u56de\u4e0a\u4e00\u7ea7\ncd ~                # \u8fd4\u56de\u7528\u6237\u4e3b\u76ee\u5f55\ncd -                # \u8fd4\u56de\u4e0a\u4e00\u4e2a\u76ee\u5f55\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/#_2","title":"\u6587\u4ef6\u64cd\u4f5c","text":"
cp source dest      # \u590d\u5236\u6587\u4ef6\ncp -r dir1 dir2     # \u9012\u5f52\u590d\u5236\u76ee\u5f55\nmv old new          # \u79fb\u52a8/\u91cd\u547d\u540d\u6587\u4ef6\nrm file             # \u5220\u9664\u6587\u4ef6\uff08\u6ca1\u6709\u56de\u6536\u7ad9\u2014\u2014\u6c38\u4e45\u5220\u9664\uff09\nrm -rf dir          # \u9012\u5f52\u5220\u9664\u76ee\u5f55\uff08\u5371\u9669\u2014\u2014\u65e0\u786e\u8ba4\uff09\nmkdir -p a/b/c      # \u521b\u5efa\u5d4c\u5957\u76ee\u5f55\ntouch file.txt      # \u521b\u5efa\u7a7a\u6587\u4ef6\uff08\u6216\u66f4\u65b0\u65f6\u95f4\u6233\uff09\ncat file.txt        # \u6253\u5370\u6587\u4ef6\u5185\u5bb9\nhead -n 20 file     # \u663e\u793a\u524d 20 \u884c\ntail -f logfile     # \u5b9e\u65f6\u8ddf\u8e2a\u65e5\u5fd7\u6587\u4ef6\uff08\u76d1\u63a7\u8bad\u7ec3\u65f6\u975e\u5e38\u6709\u7528\uff09\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/#_3","title":"\u7ba1\u9053\u4e0e\u91cd\u5b9a\u5411","text":"
cat training.log | grep \"loss\" | tail -5    # \u6700\u540e5\u884c\u5305\u542b\"loss\"\u7684\u5185\u5bb9\nps aux | grep python                        # \u67e5\u627e\u6b63\u5728\u8fd0\u884c\u7684 Python \u8fdb\u7a0b\nhistory | grep \"docker\"                     # \u67e5\u627e\u4e4b\u524d\u7684 docker \u547d\u4ee4\n
python train.py > output.log 2>&1    # stdout \u548c stderr \u90fd\u8f93\u51fa\u5230\u6587\u4ef6\npython train.py >> output.log        # \u8ffd\u52a0\uff08\u4e0d\u8986\u76d6\uff09\necho \"data\" > file.txt               # \u8986\u76d6\u6587\u4ef6\necho \"more\" >> file.txt              # \u8ffd\u52a0\u5230\u6587\u4ef6\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/#_4","title":"\u6587\u672c\u5904\u7406","text":"
grep \"error\" logfile.txt             # \u67e5\u627e\u5305\u542b\"error\"\u7684\u884c\ngrep -r \"import torch\" src/          # \u9012\u5f52\u641c\u7d22\u76ee\u5f55\ngrep -i \"warning\" log.txt            # \u4e0d\u533a\u5206\u5927\u5c0f\u5199\u641c\u7d22\ngrep -c \"epoch\" train.log            # \u7edf\u8ba1\u5339\u914d\u884c\u6570\n\nwc -l file.txt                       # \u7edf\u8ba1\u884c\u6570\nwc -w file.txt                       # \u7edf\u8ba1\u5355\u8bcd\u6570\n\nsort data.txt                        # \u6309\u5b57\u6bcd\u987a\u5e8f\u6392\u5e8f\nsort -n numbers.txt                  # \u6309\u6570\u503c\u6392\u5e8f\nsort -u data.txt                     # \u6392\u5e8f\u5e76\u53bb\u91cd\nuniq -c sorted.txt                   # \u7edf\u8ba1\u8fde\u7eed\u91cd\u590d\u9879\n\ncut -d',' -f2,3 data.csv            # \u63d0\u53d6 CSV \u7684\u7b2c 2 \u548c\u7b2c 3 \u5217\nawk '{print $1, $3}' data.txt       # \u6253\u5370\u7b2c 1 \u548c\u7b2c 3 \u4e2a\u7a7a\u767d\u5206\u9694\u5b57\u6bb5\nsed 's/old/new/g' file.txt          # \u5c06\u6240\u6709\"old\"\u66ff\u6362\u4e3a\"new\"\n
# \u67e5\u627e\u65e5\u5fd7\u6587\u4ef6\u4e2d\u6700\u5e38\u89c1\u7684 10 \u79cd\u9519\u8bef\u7c7b\u578b\ngrep \"ERROR\" app.log | awk -F': ' '{print $2}' | sort | uniq -c | sort -rn | head -10\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/#_5","title":"\u67e5\u627e\u6587\u4ef6","text":"
find . -name \"*.py\"                  # \u67e5\u627e\u6240\u6709 Python \u6587\u4ef6\nfind . -name \"*.pyc\" -delete         # \u67e5\u627e\u5e76\u5220\u9664\u7f16\u8bd1\u540e\u7684 Python \u6587\u4ef6\nfind /data -size +100M               # \u67e5\u627e\u5927\u4e8e 100MB \u7684\u6587\u4ef6\nfind . -mtime -1                     # \u67e5\u627e\u8fc7\u53bb 24 \u5c0f\u65f6\u5185\u4fee\u6539\u8fc7\u7684\u6587\u4ef6\n\nwhich python                        # python \u53ef\u6267\u884c\u6587\u4ef6\u5728\u54ea\uff1f\nlocate filename                      # \u5feb\u901f\u67e5\u627e\u6587\u4ef6\uff08\u4f7f\u7528\u9884\u6784\u5efa\u7d22\u5f15\uff09\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/#_6","title":"\u6587\u4ef6\u7cfb\u7edf\u5c42\u6b21\u7ed3\u6784","text":" \u76ee\u5f55 \u7528\u9014 / \u6574\u4e2a\u6587\u4ef6\u7cfb\u7edf\u7684\u6839 /home/user \u4f60\u7684\u4e2a\u4eba\u6587\u4ef6\u3001\u914d\u7f6e\u3001\u9879\u76ee /etc \u7cfb\u7edf\u7ea7\u914d\u7f6e\u6587\u4ef6 /usr \u7528\u6237\u7a0b\u5e8f\u3001\u5e93\u3001\u6587\u6863 /usr/local \u672c\u5730\u5b89\u88c5\u7684\u8f6f\u4ef6\uff08\u975e\u5305\u7ba1\u7406\u5668\u5b89\u88c5\uff09 /var \u53ef\u53d8\u6570\u636e\uff1a\u65e5\u5fd7\uff08/var/log\uff09\u3001\u6570\u636e\u5e93\u3001\u7f13\u5b58 /tmp \u4e34\u65f6\u6587\u4ef6\uff08\u91cd\u542f\u540e\u6e05\u9664\uff09 /opt \u53ef\u9009\u7684\u7b2c\u4e09\u65b9\u8f6f\u4ef6 /proc \u66b4\u9732\u5185\u6838\u548c\u8fdb\u7a0b\u4fe1\u606f\u7684\u865a\u62df\u6587\u4ef6\u7cfb\u7edf /dev \u8bbe\u5907\u6587\u4ef6\uff08\u78c1\u76d8\u3001GPU \u5728\u8fd9\u91cc\u663e\u793a\uff09 "},{"location":"chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/#_7","title":"\u6587\u4ef6\u6743\u9650","text":" \u6743\u9650 \u6587\u4ef6 \u76ee\u5f55 r\uff08\u8bfb\uff09 \u67e5\u770b\u5185\u5bb9 \u5217\u51fa\u5185\u5bb9 w\uff08\u5199\uff09 \u4fee\u6539\u5185\u5bb9 \u5728\u5185\u90e8\u521b\u5efa/\u5220\u9664\u6587\u4ef6 x\uff08\u6267\u884c\uff09 \u4f5c\u4e3a\u7a0b\u5e8f\u8fd0\u884c \u8fdb\u5165\uff08cd \u8fdb\u5165\uff09\u76ee\u5f55
ls -l script.py\n# -rwxr-xr-- 1 henry ml_team 2048 Mar 28 script.py\n#  ^^^         \u6240\u6709\u8005\u6743\u9650\uff1arwx\uff08\u8bfb\u3001\u5199\u3001\u6267\u884c\uff09\n#     ^^^      \u7ec4\u6743\u9650\uff1ar-x\uff08\u8bfb\u3001\u6267\u884c\uff0c\u4e0d\u53ef\u5199\uff09\n#        ^^^   \u5176\u4ed6\u4eba\u6743\u9650\uff1ar--\uff08\u53ea\u8bfb\uff09\n
chmod 755 script.py       # owner=rwx, group=rx, others=rx\nchmod +x script.py        # \u4e3a\u6240\u6709\u4eba\u6dfb\u52a0\u6267\u884c\u6743\u9650\nchmod u+w,g-w file.txt    # \u4e3a\u6240\u6709\u8005\u6dfb\u52a0\u5199\u6743\u9650\uff0c\u79fb\u9664\u7ec4\u7684\u5199\u6743\u9650\nchown henry:ml_team file  # \u66f4\u6539\u6240\u6709\u8005\u548c\u7ec4\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/#_8","title":"\u8fdb\u7a0b\u7ba1\u7406","text":"
ps aux                    # \u5217\u51fa\u6240\u6709\u6b63\u5728\u8fd0\u884c\u7684\u8fdb\u7a0b\nps aux | grep python      # \u67e5\u627e Python \u8fdb\u7a0b\ntop                       # \u5b9e\u65f6\u8fdb\u7a0b\u76d1\u63a7\uff08CPU\u3001\u5185\u5b58\uff09\nhtop                      # top \u7684\u589e\u5f3a\u7248\uff08\u9700\u5355\u72ec\u5b89\u88c5\uff09\nnvidia-smi                # GPU \u4f7f\u7528\u60c5\u51b5\uff08\u673a\u5668\u5b66\u4e60\u5fc5\u5907\uff09\nwatch -n 1 nvidia-smi     # \u6bcf\u79d2\u5237\u65b0 nvidia-smi\n\nkill PID                  # \u4f18\u96c5\u7ec8\u6b62\u8fdb\u7a0b\nkill -9 PID               # \u5f3a\u5236\u7ec8\u6b62\uff08\u4f18\u96c5\u65b9\u5f0f\u5931\u8d25\u65f6\u4f7f\u7528\uff09\nkillall python            # \u7ec8\u6b62\u6240\u6709 Python \u8fdb\u7a0b\n\n# \u540e\u53f0\u8fd0\u884c\npython train.py &                    # \u540e\u53f0\u8fd0\u884c\nnohup python train.py > log.txt &    # \u540e\u53f0\u8fd0\u884c\uff0c\u9000\u51fa\u767b\u5f55\u540e\u4ecd\u5b58\u6d3b\n
tmux new -s training          # \u521b\u5efa\u547d\u540d\u4f1a\u8bdd\n# ... \u5f00\u59cb\u8bad\u7ec3 ...\n# Ctrl+B, \u7136\u540e D              # \u4ece\u4f1a\u8bdd\u5206\u79bb\ntmux attach -t training       # \u7a0d\u540e\u91cd\u65b0\u8fde\u63a5\uff08\u5373\u4f7f SSH \u91cd\u65b0\u8fde\u63a5\u540e\u4e5f\u53ef\u7528\uff09\ntmux ls                       # \u5217\u51fa\u4f1a\u8bdd\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/#_9","title":"\u5305\u7ba1\u7406\u5668","text":"
# Debian/Ubuntu\nsudo apt update               # \u5237\u65b0\u5305\u5217\u8868\nsudo apt install htop         # \u5b89\u88c5\u5305\nsudo apt upgrade              # \u5347\u7ea7\u6240\u6709\u5305\n\n# macOS\nbrew install wget             # \u901a\u8fc7 Homebrew \u5b89\u88c5\n
pip install torch             # \u4ece PyPI \u5b89\u88c5\npip install -e .              # \u4ee5\u53ef\u7f16\u8f91\u6a21\u5f0f\u5b89\u88c5\u5f53\u524d\u9879\u76ee\npip install -r requirements.txt  # \u4ece requirements \u6587\u4ef6\u5b89\u88c5\npip freeze > requirements.txt    # \u5bfc\u51fa\u5df2\u5b89\u88c5\u7684\u5305\n\n# Conda\uff08\u7528\u4e8e\u590d\u6742\u4f9d\u8d56\uff0c\u5982 CUDA\uff09\nconda create -n myenv python=3.11\nconda activate myenv\nconda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/#_10","title":"\u73af\u5883\u53d8\u91cf","text":"
export CUDA_VISIBLE_DEVICES=0,1    # \u4ec5\u4f7f\u7528 GPU 0 \u548c 1\nexport PYTHONPATH=/home/user/src   # \u6dfb\u52a0\u5230 Python \u7684\u5bfc\u5165\u8def\u5f84\nexport WANDB_API_KEY=abc123        # Weights & Biases \u7684 API \u5bc6\u94a5\n\necho $PATH                         # \u67e5\u770b\u5f53\u524d PATH\nexport PATH=$PATH:/usr/local/cuda/bin  # \u5c06 CUDA \u6dfb\u52a0\u5230 PATH\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/#ssh","title":"SSH\uff08\u5b89\u5168\u5916\u58f3\u534f\u8bae\uff09","text":"
ssh user@hostname              # \u8fde\u63a5\u5230\u8fdc\u7a0b\u673a\u5668\nssh -i ~/.ssh/key.pem user@ip  # \u4f7f\u7528\u7279\u5b9a\u5bc6\u94a5\u8fde\u63a5\nssh -L 8888:localhost:8888 user@server  # \u7aef\u53e3\u8f6c\u53d1\uff08\u8fdc\u7a0b Jupyter\uff09\n
ssh-keygen -t ed25519          # \u751f\u6210\u5bc6\u94a5\u5bf9\nssh-copy-id user@server        # \u5c06\u516c\u94a5\u590d\u5236\u5230\u670d\u52a1\u5668\n# \u73b0\u5728\u65e0\u9700\u8f93\u5165\u5bc6\u7801\u5373\u53ef SSH\n
Host gpu-server\n    HostName 10.0.1.42\n    User henry\n    IdentityFile ~/.ssh/gpu_key\n    LocalForward 8888 localhost:8888\n
scp model.pt user@server:/data/models/     # \u5c06\u6587\u4ef6\u590d\u5236\u5230\u8fdc\u7a0b\nscp -r user@server:/data/results/ ./       # \u4ece\u8fdc\u7a0b\u590d\u5236\u76ee\u5f55\nrsync -avz --progress data/ user@server:/data/  # \u5e26\u8fdb\u5ea6\u540c\u6b65\uff08\u6bd4 scp \u66f4\u667a\u80fd\uff09\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/01.%20linux%20and%20CMD/#_11","title":"\u673a\u5668\u5b66\u4e60\u5fc5\u5907\u547d\u4ee4\u901f\u67e5\u8868","text":"
# GPU \u76d1\u63a7\nnvidia-smi                                   # GPU \u4f7f\u7528\u5feb\u7167\nwatch -n 1 nvidia-smi                        # \u5b9e\u65f6\u76d1\u63a7\ngpustat                                      # \u66f4\u6e05\u6670\u7684 GPU \u6982\u89c8\uff08pip install gpustat\uff09\n\n# \u8bad\u7ec3\u7ba1\u7406\nnohup python train.py > train.log 2>&1 &     # \u9000\u51fa\u767b\u5f55\u540e\u4ecd\u5b58\u6d3b\u7684\u540e\u53f0\u8bad\u7ec3\ntail -f train.log                            # \u76d1\u63a7\u8bad\u7ec3\u8f93\u51fa\nkill %1                                      # \u7ec8\u6b62\u6700\u540e\u4e00\u4e2a\u540e\u53f0\u4efb\u52a1\n\n# \u78c1\u76d8\u4f7f\u7528\uff08\u6570\u636e\u96c6\u5f88\u5927\uff09\ndf -h                                        # \u6240\u6709\u6302\u8f7d\u70b9\u7684\u78c1\u76d8\u7a7a\u95f4\ndu -sh /data/*                               # /data \u4e2d\u6bcf\u4e2a\u9879\u76ee\u7684\u5927\u5c0f\ndu -h --max-depth=1 .                        # \u5b50\u76ee\u5f55\u7684\u5927\u5c0f\n\n# \u5185\u5b58\nfree -h                                      # RAM \u4f7f\u7528\u60c5\u51b5\ncat /proc/meminfo                            # \u8be6\u7ec6\u5185\u5b58\u4fe1\u606f\n\n# \u7f51\u7edc\ncurl -O https://example.com/dataset.tar.gz   # \u4e0b\u8f7d\u6587\u4ef6\nwget https://example.com/model.bin           # \u66ff\u4ee3\u4e0b\u8f7d\u5de5\u5177\ncurl -X POST http://localhost:8080/predict \\\n    -H \"Content-Type: application/json\" \\\n    -d '{\"text\": \"hello\"}'                   # \u6d4b\u8bd5\u6a21\u578b\u63a8\u7406\u7aef\u70b9\n\n# \u5f52\u6863\ntar -czf archive.tar.gz directory/           # \u538b\u7f29\ntar -xzf archive.tar.gz                      # \u89e3\u538b\nzip -r archive.zip directory/                # zip \u538b\u7f29\nunzip archive.zip                            # zip \u89e3\u538b\n\n# \u5feb\u901f\u6570\u636e\u68c0\u67e5\nhead -5 data.csv                             # CSV \u7684\u524d 5 \u884c\nwc -l data.csv                               # 
\u7edf\u8ba1\u884c\u6570\ncut -d',' -f1 data.csv | sort -u | wc -l    # \u7edf\u8ba1\u7b2c 1 \u5217\u7684\u552f\u4e00\u503c\u6570\u91cf\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/02.%20git%20and%20repository%20management/","title":"Git \u4e0e\u7248\u672c\u63a7\u5236","text":"

Git \u662f\u8f6f\u4ef6\u56e2\u961f\u5728\u4e0d\u76f8\u4e92\u8986\u76d6\u5de5\u4f5c\u7684\u60c5\u51b5\u4e0b\u8fdb\u884c\u534f\u4f5c\u7684\u65b9\u5f0f\u3002\u672c\u6587\u6db5\u76d6 Git \u7684\u5fc3\u667a\u6a21\u578b\u3001\u5206\u652f\u7b56\u7565\u3001\u5408\u5e76\u4e0e\u53d8\u57fa\u3001\u51b2\u7a81\u89e3\u51b3\u3001\u62c9\u53d6\u8bf7\u6c42\uff0c\u4ee5\u53ca\u7ba1\u7406\u673a\u5668\u5b66\u4e60\u7279\u5b9a\u6311\u6218\uff08\u5982\u5927\u6587\u4ef6\u548c\u5b9e\u9a8c\u8ffd\u8e2a\uff09\u7684\u65b9\u6cd5\u3002

"},{"location":"chapter%2015%3A%20production%20software%20engineering/02.%20git%20and%20repository%20management/#_1","title":"\u5fc3\u667a\u6a21\u578b","text":"
Working Dir  \u2192  git add  \u2192  Staging  \u2192  git commit  \u2192  Local Repo  \u2192  git push  \u2192  Remote\n                                                        \u2190  git pull  \u2190\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/02.%20git%20and%20repository%20management/#_2","title":"\u57fa\u672c\u547d\u4ee4","text":"
git init                          # \u521b\u5efa\u65b0\u4ed3\u5e93\ngit clone url                     # \u4e0b\u8f7d\u8fdc\u7a0b\u4ed3\u5e93\ngit status                        # \u6709\u4ec0\u4e48\u53d8\u5316\uff1f\uff08\u6700\u5e38\u7528\u7684\u547d\u4ee4\uff09\ngit add file.py                   # \u6682\u5b58\u7279\u5b9a\u6587\u4ef6\ngit add .                         # \u6682\u5b58\u6240\u6709\u66f4\u6539\uff08\u8c28\u614e\u4f7f\u7528\uff09\ngit commit -m \"descriptive msg\"   # \u63d0\u4ea4\u6682\u5b58\u7684\u66f4\u6539\ngit push                          # \u5c06\u63d0\u4ea4\u4e0a\u4f20\u5230\u8fdc\u7a0b\ngit pull                          # \u4e0b\u8f7d\u5e76\u5408\u5e76\u8fdc\u7a0b\u66f4\u6539\ngit log --oneline                 # \u7d27\u51d1\u7684\u63d0\u4ea4\u5386\u53f2\ngit diff                          # \u663e\u793a\u672a\u6682\u5b58\u7684\u66f4\u6539\ngit diff --staged                 # \u663e\u793a\u5df2\u6682\u5b58\u7684\u66f4\u6539\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/02.%20git%20and%20repository%20management/#_3","title":"\u5206\u652f","text":"
git branch feature-x              # \u521b\u5efa\u5206\u652f\ngit checkout feature-x            # \u5207\u6362\u5230\u6b64\u5206\u652f\ngit checkout -b feature-x         # \u521b\u5efa\u5e76\u5207\u6362\uff08\u4e00\u6b65\u5b8c\u6210\uff09\ngit branch -d feature-x           # \u5220\u9664\u5206\u652f\uff08\u5408\u5e76\u540e\uff09\ngit branch -a                     # \u5217\u51fa\u6240\u6709\u5206\u652f\uff08\u672c\u5730 + \u8fdc\u7a0b\uff09\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/02.%20git%20and%20repository%20management/#_4","title":"\u5206\u652f\u7b56\u7565","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/02.%20git%20and%20repository%20management/#_5","title":"\u5408\u5e76\u4e0e\u53d8\u57fa","text":"
git checkout main\ngit merge feature-x\n
git checkout feature-x\ngit rebase main\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/02.%20git%20and%20repository%20management/#_6","title":"\u89e3\u51b3\u51b2\u7a81","text":"
<<<<<<< HEAD\nlearning_rate = 0.001\n=======\nlearning_rate = 0.0005\n>>>>>>> feature-x\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/02.%20git%20and%20repository%20management/#_7","title":"\u7f16\u5199\u826f\u597d\u7684\u63d0\u4ea4\u4fe1\u606f","text":"
\u7b80\u77ed\u6458\u8981\uff0850 \u5b57\u4ee5\u5185\uff0c\u7948\u4f7f\u8bed\u6c14\uff09\n\n\u5982\u679c\u9700\u8981\uff0c\u53ef\u9644\u5e26\u66f4\u957f\u7684\u63cf\u8ff0\u3002\u89e3\u91ca WHY\uff0c\u800c\u4e0d\u662f WHAT\n\uff08\u5dee\u5f02\u663e\u793a\u4e86\u4ec0\u4e48\u6539\u53d8\u4e86\uff09\u3002\u6bcf\u884c\u4e0d\u8d85\u8fc7 72 \u4e2a\u5b57\u7b26\u3002\n\nFixes #123\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/02.%20git%20and%20repository%20management/#_8","title":"\u62c9\u53d6\u8bf7\u6c42\u4e0e\u4ee3\u7801\u5ba1\u67e5","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/02.%20git%20and%20repository%20management/#gitignore","title":".gitignore","text":"
# Python\n__pycache__/\n*.pyc\n*.egg-info/\n.venv/\nenv/\n\n# \u6570\u636e\u548c\u6a21\u578b\uff08\u5bf9 git \u6765\u8bf4\u592a\u5927\uff09\ndata/\n*.csv\n*.parquet\nmodels/\n*.pt\n*.onnx\n*.bin\ncheckpoints/\n\n# \u5bc6\u94a5\n.env\n*.pem\ncredentials.json\n\n# IDE\n.vscode/\n.idea/\n*.swp\n\n# \u64cd\u4f5c\u7cfb\u7edf\n.DS_Store\nThumbs.db\n\n# Jupyter\n.ipynb_checkpoints/\n\n# \u5b9e\u9a8c\u8f93\u51fa\nwandb/\nmlruns/\noutputs/\nlogs/\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/02.%20git%20and%20repository%20management/#git_1","title":"Git \u5728\u673a\u5668\u5b66\u4e60\u4e2d\u7684\u5e94\u7528","text":"
# \u5feb\u901f\u53ef\u91cd\u73b0\u6027\u5feb\u7167\necho \"Commit: $(git rev-parse HEAD)\" > experiment_info.txt\necho \"Branch: $(git branch --show-current)\" >> experiment_info.txt\necho \"Dirty: $(git status --porcelain | wc -l) files\" >> experiment_info.txt\npip freeze >> experiment_info.txt\nnvidia-smi >> experiment_info.txt\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/","title":"\u4ee3\u7801\u5e93\u8bbe\u8ba1\u4e0e\u6a21\u5f0f","text":"

\u826f\u597d\u7684\u4ee3\u7801\u5e93\u8bbe\u8ba1\u662f\u533a\u5206\u7814\u7a76\u539f\u578b\u4e0e\u751f\u4ea7\u7ea7\u8f6f\u4ef6\u7684\u5173\u952e\u3002\u672c\u6587\u6db5\u76d6\u9879\u76ee\u7ed3\u6784\u3001\u6574\u6d01\u4ee3\u7801\u539f\u5219\u3001\u4e0e\u673a\u5668\u5b66\u4e60\u76f8\u5173\u7684\u8bbe\u8ba1\u6a21\u5f0f\u3001\u914d\u7f6e\u7ba1\u7406\u3001\u65e5\u5fd7\u3001API \u8bbe\u8ba1\u4ee5\u53ca\u6253\u5305\u5206\u53d1\u3002

"},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/#_2","title":"\u9879\u76ee\u7ed3\u6784","text":"
my_project/\n\u251c\u2500\u2500 src/my_project/       # \u6e90\u4ee3\u7801\uff08\u53ef\u5bfc\u5165\u7684\u5305\uff09\n\u2502   \u251c\u2500\u2500 __init__.py\n\u2502   \u251c\u2500\u2500 data/             # \u6570\u636e\u52a0\u8f7d\u548c\u9884\u5904\u7406\n\u2502   \u2502   \u251c\u2500\u2500 __init__.py\n\u2502   \u2502   \u251c\u2500\u2500 dataset.py\n\u2502   \u2502   \u2514\u2500\u2500 transforms.py\n\u2502   \u251c\u2500\u2500 models/           # \u6a21\u578b\u67b6\u6784\n\u2502   \u2502   \u251c\u2500\u2500 __init__.py\n\u2502   \u2502   \u251c\u2500\u2500 transformer.py\n\u2502   \u2502   \u2514\u2500\u2500 layers.py\n\u2502   \u251c\u2500\u2500 training/         # \u8bad\u7ec3\u5faa\u73af\u3001\u4f18\u5316\u5668\n\u2502   \u2502   \u251c\u2500\u2500 __init__.py\n\u2502   \u2502   \u251c\u2500\u2500 trainer.py\n\u2502   \u2502   \u2514\u2500\u2500 losses.py\n\u2502   \u2514\u2500\u2500 utils/            # \u5171\u4eab\u5de5\u5177\n\u2502       \u251c\u2500\u2500 __init__.py\n\u2502       \u2514\u2500\u2500 logging.py\n\u251c\u2500\u2500 configs/              # \u914d\u7f6e\u6587\u4ef6\n\u2502   \u251c\u2500\u2500 base.yaml\n\u2502   \u2514\u2500\u2500 experiment_1.yaml\n\u251c\u2500\u2500 scripts/              # \u5165\u53e3\u70b9\uff08\u8bad\u7ec3\u3001\u8bc4\u4f30\u3001\u63a8\u7406\uff09\n\u2502   \u251c\u2500\u2500 train.py\n\u2502   \u251c\u2500\u2500 evaluate.py\n\u2502   \u2514\u2500\u2500 serve.py\n\u251c\u2500\u2500 tests/                # \u6d4b\u8bd5\u6587\u4ef6\uff08\u955c\u50cf src/ \u7ed3\u6784\uff09\n\u2502   \u251c\u2500\u2500 test_dataset.py\n\u2502   \u251c\u2500\u2500 test_model.py\n\u2502   \u2514\u2500\u2500 test_trainer.py\n\u251c\u2500\u2500 notebooks/            # \u4ec5\u7528\u4e8e\u63a2\u7d22\uff08\u975e\u751f\u4ea7\u4ee3\u7801\uff09\n\u251c\u2500\u2500 pyproject.toml        # \u9879\u76ee\u5143\u6570\u636e\u548c\u4f9d\u8d56\n\u251c\u2500\u2500 README.md\n\u251c\u2500\u2500 .gitignore\n\u2514\u2500\u2500 Dockerfile\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/#_3","title":"\u6574\u6d01\u4ee3\u7801\u539f\u5219","text":"
# \u7cdf\u7cd5\ndef proc(d, n, lr):\n    for i in range(n):\n        for k, v in d.items():\n            v -= lr * g[k]\n\n# \u826f\u597d\ndef update_parameters(parameters, num_steps, learning_rate):\n    for step in range(num_steps):\n        for name, param in parameters.items():\n            param -= learning_rate * gradients[name]\n
# \u8fc7\u65e9\u62bd\u8c61\uff08\u4e00\u4e2a\u7528\u4f8b\uff0c\u8fc7\u5ea6\u8bbe\u8ba1\uff09\nclass AbstractDataTransformPipelineFactory:\n    ...\n\n# \u6070\u5230\u597d\u5904\uff08\u76f4\u63a5\u3001\u6e05\u6670\u3001\u5728\u4e09\u5904\u4f7f\u7528\uff09\ndef normalise_image(image, mean, std):\n    return (image - mean) / std\n
# \u7cdf\u7cd5\nif len(batch) > 32:\n    split_batch(batch, 32)\n\n# \u826f\u597d\nMAX_BATCH_SIZE = 32\nif len(batch) > MAX_BATCH_SIZE:\n    split_batch(batch, MAX_BATCH_SIZE)\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/#_4","title":"\u9002\u7528\u4e8e\u673a\u5668\u5b66\u4e60\u7684\u8bbe\u8ba1\u6a21\u5f0f","text":"
MODEL_REGISTRY = {\n    \"transformer\": TransformerModel,\n    \"cnn\": CNNModel,\n    \"mlp\": MLPModel,\n}\n\ndef build_model(config):\n    model_cls = MODEL_REGISTRY[config[\"model\"]]\n    return model_cls(**config[\"model_params\"])\n
LOSS_FUNCTIONS = {\n    \"mse\": nn.MSELoss,\n    \"cross_entropy\": nn.CrossEntropyLoss,\n    \"focal\": FocalLoss,\n}\n\nloss_fn = LOSS_FUNCTIONS[config[\"loss\"]]()\n
class EarlyStopping:\n    def __init__(self, patience=5):\n        self.patience = patience\n        self.best_loss = float('inf')\n        self.counter = 0\n\n    def on_epoch_end(self, epoch, val_loss):\n        if val_loss < self.best_loss:\n            self.best_loss = val_loss\n            self.counter = 0\n        else:\n            self.counter += 1\n            if self.counter >= self.patience:\n                return \"stop\"\n
# \u7cdf\u7cd5\uff1a\u786c\u7f16\u7801\u4f9d\u8d56\nclass Trainer:\n    def __init__(self):\n        self.logger = WandbLogger()  # \u6ca1\u6709 W&B \u5c31\u65e0\u6cd5\u6d4b\u8bd5\n\n# \u826f\u597d\uff1a\u6ce8\u5165\u4f9d\u8d56\nclass Trainer:\n    def __init__(self, logger):\n        self.logger = logger  # \u53ef\u4ee5\u6ce8\u5165\u4efb\u4f55\u8bb0\u5f55\u5668\uff0c\u5305\u62ec mock\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/#_5","title":"\u914d\u7f6e\u7ba1\u7406","text":"
# configs/experiment_1.yaml\nmodel:\n  name: transformer\n  d_model: 512\n  n_heads: 8\n  n_layers: 6\n\ntraining:\n  batch_size: 64\n  learning_rate: 3e-4\n  max_epochs: 100\n  early_stopping_patience: 10\n\ndata:\n  train_path: /data/train.parquet\n  val_path: /data/val.parquet\n  max_seq_length: 512\n
import argparse\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--lr\", type=float, default=3e-4)\nparser.add_argument(\"--batch-size\", type=int, default=64)\nparser.add_argument(\"--config\", type=str, default=\"configs/base.yaml\")\nargs = parser.parse_args()\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/#_6","title":"\u65e5\u5fd7\u4e0e\u53ef\u89c2\u6d4b\u6027","text":"
import logging\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\nlogger.debug(\"Batch loaded: %d samples\", len(batch))     # \u8be6\u7ec6\uff0c\u7528\u4e8e\u8c03\u8bd5\nlogger.info(\"Epoch %d: loss=%.4f, lr=%.6f\", epoch, loss, lr)  # \u6b63\u5e38\u8fd0\u884c\nlogger.warning(\"GPU memory >90%, consider reducing batch size\")\nlogger.error(\"Failed to load checkpoint: %s\", path)       # \u53ef\u6062\u590d\u7684\u9519\u8bef\nlogger.critical(\"CUDA out of memory, aborting\")            # \u81f4\u547d\u9519\u8bef\n
logger.info(\"training_step\", extra={\n    \"epoch\": 5, \"step\": 1200, \"loss\": 0.0342, \"lr\": 2.1e-4\n})\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/#api","title":"API \u8bbe\u8ba1","text":"
POST /api/v1/predict          # \u53d1\u9001\u8f93\u5165\uff0c\u83b7\u53d6\u9884\u6d4b\u7ed3\u679c\nGET  /api/v1/models           # \u5217\u51fa\u53ef\u7528\u6a21\u578b\nGET  /api/v1/models/{id}      # \u83b7\u53d6\u6a21\u578b\u8be6\u60c5\nPOST /api/v1/models/{id}/predict  # \u4f7f\u7528\u7279\u5b9a\u6a21\u578b\u8fdb\u884c\u9884\u6d4b\n
from fastapi import FastAPI\nfrom pydantic import BaseModel\n\napp = FastAPI()\n\nclass PredictRequest(BaseModel):\n    text: str\n\nclass PredictResponse(BaseModel):\n    label: str\n    confidence: float\n\n@app.post(\"/predict\", response_model=PredictResponse)\nasync def predict(request: PredictRequest):\n    result = model.predict(request.text)\n    return PredictResponse(label=result.label, confidence=result.score)\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/#_7","title":"\u6253\u5305\u4e0e\u5206\u53d1","text":"
# pyproject.toml\n[project]\nname = \"my-ml-project\"\nversion = \"0.1.0\"\nrequires-python = \">=3.10\"\ndependencies = [\n    \"torch>=2.0\",\n    \"jax>=0.4\",\n    \"pydantic>=2.0\",\n]\n\n[project.optional-dependencies]\ndev = [\"pytest\", \"ruff\", \"mypy\"]\n\n[build-system]\nrequires = [\"setuptools>=64\"]\nbuild-backend = \"setuptools.build_meta\"\n
pip install -e \".[dev]\"    # \u4ee5\u53ef\u7f16\u8f91\u6a21\u5f0f\u5b89\u88c5\uff0c\u5305\u542b\u5f00\u53d1\u4f9d\u8d56\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/#ai","title":"\u4f7f\u7528 AI \u7f16\u7801\u52a9\u624b","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/#ai_1","title":"AI \u52a9\u624b\u64c5\u957f\u4e4b\u5904","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/#ai_2","title":"\u4f55\u65f6\u4e0d\u5e94\u4f9d\u8d56 AI \u52a9\u624b","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/#_8","title":"\u5ba1\u67e5\u7eaa\u5f8b","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/#_9","title":"\u5982\u4f55\u7f16\u5199\u597d\u7684\u63d0\u793a\u8bcd","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/#ai_3","title":"\u4f7f\u7528\u8d28\u91cf\u95e8\u63a7\u6765\u6355\u6349 AI \u52a9\u624b\u7684\u9519\u8bef","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/03.%20codebase%20design/#_10","title":"\u751f\u4ea7\u529b\u9677\u9631","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/","title":"\u6d4b\u8bd5\u4e0e\u8d28\u91cf\u4fdd\u969c","text":"

\u6d4b\u8bd5\u662f\u4f60\u786e\u4fdd\u4ee3\u7801\u6b63\u5e38\u5de5\u4f5c\u7684\u65b9\u6cd5\u2014\u2014\u4e0d\u4ec5\u662f\u73b0\u5728\uff0c\u800c\u4e14\u5728\u6bcf\u6b21\u66f4\u6539\u540e\u90fd\u80fd\u6b63\u5e38\u5de5\u4f5c\u3002\u672c\u6587\u6db5\u76d6\u6d4b\u8bd5\u91d1\u5b57\u5854\u3001\u4f7f\u7528 pytest \u8fdb\u884c\u7684\u5355\u5143\u6d4b\u8bd5\u3001Mock\u3001\u6d4b\u8bd5\u673a\u5668\u5b66\u4e60\u7279\u5b9a\u4ee3\u7801\u3001CI/CD \u7ba1\u9053\u3001\u4ee3\u7801\u68c0\u67e5\u3001\u683c\u5f0f\u5316\u548c\u4ee3\u7801\u5ba1\u67e5\u2014\u2014\u8fd9\u4e9b\u5b9e\u8df5\u80fd\u5728\u9519\u8bef\u5230\u8fbe\u751f\u4ea7\u73af\u5883\u4e4b\u524d\u6355\u83b7\u5b83\u4eec\u3002

"},{"location":"chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/#_2","title":"\u6d4b\u8bd5\u91d1\u5b57\u5854","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/#pytest","title":"\u4f7f\u7528 pytest \u8fdb\u884c\u5355\u5143\u6d4b\u8bd5","text":"
# tests/test_utils.py\n\ndef test_normalise_image():\n    import numpy as np\n    image = np.array([0, 128, 255], dtype=np.uint8)\n    result = normalise_image(image, mean=128, std=128)\n    assert result.min() >= -1.0\n    assert result.max() <= 1.0\n    assert abs(result[1]) < 1e-6  # 128 \u88ab mean=128 \u5f52\u4e00\u5316\u540e\u5e94\u7ea6\u4e3a 0\n\ndef test_normalise_empty():\n    import numpy as np\n    image = np.array([], dtype=np.uint8)\n    result = normalise_image(image, mean=128, std=128)\n    assert len(result) == 0\n
pytest tests/                     # \u8fd0\u884c\u6240\u6709\u6d4b\u8bd5\npytest tests/test_utils.py        # \u8fd0\u884c\u4e00\u4e2a\u6587\u4ef6\npytest -v                         # \u8be6\u7ec6\u8f93\u51fa\npytest -x                         # \u5728\u7b2c\u4e00\u4e2a\u5931\u8d25\u65f6\u505c\u6b62\npytest -k \"normalise\"             # \u8fd0\u884c\u5339\u914d\u540d\u79f0\u6a21\u5f0f\u7684\u6d4b\u8bd5\npytest --tb=short                 # \u66f4\u77ed\u7684\u8ffd\u6eaf\u4fe1\u606f\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/#_3","title":"\u5939\u5177","text":"
import pytest\n\n@pytest.fixture\ndef sample_dataset():\n    \"\"\"\u521b\u5efa\u4e00\u4e2a\u7528\u4e8e\u6d4b\u8bd5\u7684\u5c0f\u578b\u6570\u636e\u96c6\u3002\"\"\"\n    return {\n        \"inputs\": torch.randn(10, 3, 32, 32),\n        \"labels\": torch.randint(0, 10, (10,))\n    }\n\n@pytest.fixture\ndef trained_model():\n    \"\"\"\u52a0\u8f7d\u4e00\u4e2a\u5c0f\u578b\u9884\u8bad\u7ec3\u6a21\u578b\u3002\"\"\"\n    model = SmallModel()\n    model.load_state_dict(torch.load(\"tests/fixtures/small_model.pt\"))\n    return model\n\ndef test_model_output_shape(trained_model, sample_dataset):\n    output = trained_model(sample_dataset[\"inputs\"])\n    assert output.shape == (10, 10)  # batch_size x num_classes\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/#_4","title":"\u53c2\u6570\u5316\u6d4b\u8bd5","text":"
@pytest.mark.parametrize(\"input,expected\", [\n    ([1, 2, 3], 6),\n    ([], 0),\n    ([-1, 1], 0),\n    ([1000000, 1000000], 2000000),\n])\ndef test_sum(input, expected):\n    assert sum(input) == expected\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/#mock","title":"Mock \u4e0e\u8865\u4e01","text":"
from unittest.mock import patch, MagicMock\n\ndef test_training_logs_metrics():\n    mock_logger = MagicMock()\n\n    with patch(\"my_project.training.trainer.wandb\") as mock_wandb:\n        trainer = Trainer(logger=mock_logger)\n        trainer.train_one_epoch()\n\n        # \u9a8c\u8bc1\u8bad\u7ec3\u5668\u8bb0\u5f55\u4e86\u6307\u6807\n        mock_logger.log.assert_called()\n        # \u9a8c\u8bc1\u5b83\u8bb0\u5f55\u4e86\u635f\u5931\u503c\n        call_args = mock_logger.log.call_args\n        assert \"loss\" in call_args[1]\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/#_5","title":"\u6d4b\u8bd5\u673a\u5668\u5b66\u4e60\u4ee3\u7801","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/#_6","title":"\u786e\u5b9a\u6027\u79cd\u5b50","text":"
import random\nimport numpy as np\nimport torch\n\ndef set_seed(seed=42):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/#_7","title":"\u6570\u503c\u5bb9\u5dee","text":"
# \u7cdf\u7cd5\uff1a\u7531\u4e8e\u6d6e\u70b9\u6570\u95ee\u9898\uff0c\u7cbe\u786e\u6bd4\u8f83\u4f1a\u5931\u8d25\nassert model_output == 0.5\n\n# \u826f\u597d\uff1a\u8fd1\u4f3c\u6bd4\u8f83\nimport numpy as np\nassert np.isclose(model_output, 0.5, atol=1e-5)\n\n# \u5bf9\u4e8e\u5f20\u91cf\nassert torch.allclose(output, expected, atol=1e-4)\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/#_8","title":"\u673a\u5668\u5b66\u4e60\u4e2d\u9700\u8981\u6d4b\u8bd5\u4ec0\u4e48","text":"
def test_model_output_shape():\n    model = MyModel(d_model=256, n_classes=10)\n    x = torch.randn(8, 32, 256)  # batch=8, seq=32, dim=256\n    output = model(x)\n    assert output.shape == (8, 10)\n
def test_gradients_flow():\n    model = MyModel()\n    x = torch.randn(4, 3, 32, 32)\n    y = torch.randint(0, 10, (4,))\n\n    output = model(x)\n    loss = F.cross_entropy(output, y)\n    loss.backward()\n\n    for name, param in model.named_parameters():\n        assert param.grad is not None, f\"\u6ca1\u6709 {name} \u7684\u68af\u5ea6\"\n        assert param.grad.abs().sum() > 0, f\"{name} \u7684\u68af\u5ea6\u4e3a\u96f6\"\n
def test_overfit_one_batch():\n    model = MyModel()\n    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)\n    x, y = get_single_batch()\n\n    for _ in range(100):\n        loss = F.cross_entropy(model(x), y)\n        loss.backward()\n        optimiser.step()\n        optimiser.zero_grad()\n\n    assert loss.item() < 0.01, f\"\u65e0\u6cd5\u8fc7\u62df\u5408\u5355\u4e2a\u6279\u6b21\uff1aloss={loss.item()}\"\n
def test_dataset_basics():\n    dataset = MyDataset(\"tests/fixtures/small_data.csv\")\n    assert len(dataset) > 0\n    x, y = dataset[0]\n    assert x.shape == (3, 224, 224)\n    assert 0 <= y < 10\n    assert not torch.isnan(x).any()\n    assert not torch.isinf(x).any()\n
def test_determinism():\n    set_seed(42)\n    output1 = model(input_data)\n    set_seed(42)\n    output2 = model(input_data)\n    assert torch.allclose(output1, output2)\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/#cicd","title":"CI/CD \u7ba1\u9053","text":"
name: CI\non: [push, pull_request]\n\njobs:\n  test:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: \"3.11\"\n      - run: pip install -e \".[dev]\"\n      - run: ruff check src/\n      - run: mypy src/\n      - run: pytest tests/ -v --tb=short\n
# .pre-commit-config.yaml\nrepos:\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.3.0\n    hooks:\n      - id: ruff\n        args: [--fix]\n      - id: ruff-format\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v4.5.0\n    hooks:\n      - id: trailing-whitespace\n      - id: end-of-file-fixer\n      - id: check-yaml\n
pip install pre-commit\npre-commit install    # \u73b0\u5728\u6bcf\u6b21 git \u63d0\u4ea4\u65f6\u90fd\u4f1a\u8fd0\u884c\u94a9\u5b50\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/#_9","title":"\u4ee3\u7801\u68c0\u67e5\u4e0e\u683c\u5f0f\u5316","text":"
ruff check src/          # \u4ee3\u7801\u68c0\u67e5\nruff check --fix src/    # \u4ee3\u7801\u68c0\u67e5\u5e76\u81ea\u52a8\u4fee\u590d\nruff format src/         # \u683c\u5f0f\u5316\n
mypy src/\n# src/model.py:42: error: Argument 1 to \"forward\" has incompatible type \"int\"; expected \"Tensor\"\n
def train(\n    model: nn.Module,\n    dataloader: DataLoader,\n    optimiser: torch.optim.Optimizer,\n    num_epochs: int = 10,\n) -> float:\n    \"\"\"\u8bad\u7ec3\u6a21\u578b\u5e76\u8fd4\u56de\u6700\u7ec8\u635f\u5931\u3002\"\"\"\n    ...\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/04.%20testing%20and%20quality%20assurance/#_10","title":"\u4ee3\u7801\u5ba1\u67e5\u6700\u4f73\u5b9e\u8df5","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/05.%20deployment%20and%20devops/","title":"\u90e8\u7f72\u4e0e DevOps","text":"

\u90e8\u7f72\u662f\u4f60\u7684\u6a21\u578b\u4ece\u7814\u7a76\u4ea7\u7269\u53d8\u6210\u4ea7\u54c1\u7684\u5730\u65b9\u3002\u672c\u6587\u6db5\u76d6\u7528\u4e8e\u673a\u5668\u5b66\u4e60\u7684 Docker\u3001\u6a21\u578b\u63a8\u7406\u3001\u5b9e\u9a8c\u8ffd\u8e2a\u3001\u53ef\u91cd\u73b0\u6027\u3001\u751f\u4ea7\u73af\u5883\u76d1\u63a7\u3001\u7279\u5f81\u5b58\u50a8\u548c\u7ba1\u9053\u7f16\u6392\u2014\u2014\u8fd9\u4e9b\u57fa\u7840\u8bbe\u65bd\u5c06\u4e00\u4e2a\u8bad\u7ec3\u597d\u7684\u6a21\u578b\u4ece notebook \u5e26\u5230\u6570\u767e\u4e07\u7528\u6237\u9762\u524d\u3002

"},{"location":"chapter%2015%3A%20production%20software%20engineering/05.%20deployment%20and%20devops/#docker","title":"\u7528\u4e8e\u673a\u5668\u5b66\u4e60\u7684 Docker","text":"
# \u4ece\u5b98\u65b9\u7684 CUDA \u57fa\u7840\u955c\u50cf\u5f00\u59cb\nFROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04\n\n# \u7cfb\u7edf\u4f9d\u8d56\nRUN apt-get update && apt-get install -y \\\n    python3.11 python3-pip git \\\n    && rm -rf /var/lib/apt/lists/*\n\n# Python \u4f9d\u8d56\uff08\u5355\u72ec\u5b89\u88c5\u4ee5\u5229\u7528\u7f13\u5b58\uff09\nCOPY requirements.txt .\nRUN pip install --no-cache-dir -r requirements.txt\n\n# \u590d\u5236\u6e90\u4ee3\u7801\uff08\u9891\u7e41\u66f4\u6539\uff0c\u56e0\u6b64\u6b64\u5c42\u653e\u5728\u6700\u540e\uff09\nCOPY src/ /app/src/\nCOPY configs/ /app/configs/\nWORKDIR /app\n\n# \u5165\u53e3\u70b9\nCMD [\"python3\", \"src/scripts/serve.py\", \"--config\", \"configs/serve.yaml\"]\n
# \u6784\u5efa\u9636\u6bb5\uff1a\u5b89\u88c5\u6784\u5efa\u5de5\u5177\u3001\u7f16\u8bd1\u4f9d\u8d56\nFROM python:3.11 AS builder\nCOPY requirements.txt .\nRUN pip install --user -r requirements.txt\n\n# \u8fd0\u884c\u9636\u6bb5\uff1a\u4ec5\u8fd0\u884c\u73af\u5883\u4f9d\u8d56\nFROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04\nCOPY --from=builder /root/.local /root/.local\nCOPY src/ /app/src/\nENV PATH=/root/.local/bin:$PATH\n
# docker-compose.yml\nservices:\n  model:\n    build: .\n    ports:\n      - \"8080:8080\"\n    deploy:\n      resources:\n        reservations:\n          devices:\n            - capabilities: [gpu]\n  prometheus:\n    image: prom/prometheus\n    ports:\n      - \"9090:9090\"\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/05.%20deployment%20and%20devops/#_1","title":"\u6a21\u578b\u63a8\u7406","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/05.%20deployment%20and%20devops/#_2","title":"\u5b9e\u9a8c\u8ffd\u8e2a","text":"
import wandb\n\nwandb.init(project=\"my-project\", config={\n    \"model\": \"transformer\",\n    \"lr\": 3e-4,\n    \"batch_size\": 64,\n})\n\nfor epoch in range(num_epochs):\n    train_loss = train_one_epoch()\n    val_loss = validate()\n\n    wandb.log({\n        \"train/loss\": train_loss,\n        \"val/loss\": val_loss,\n        \"epoch\": epoch,\n    })\n\n    # \u5c06\u6a21\u578b\u8bb0\u5f55\u4e3a\u4ea7\u7269\n    if val_loss < best_loss:\n        wandb.save(\"best_model.pt\")\n\nwandb.finish()\n
import mlflow\n\nmlflow.set_experiment(\"my-experiment\")\n\nwith mlflow.start_run():\n    mlflow.log_params({\"lr\": 3e-4, \"batch_size\": 64})\n    mlflow.log_metric(\"val_loss\", 0.042, step=epoch)\n    mlflow.pytorch.log_model(model, \"model\")\n
"},{"location":"chapter%2015%3A%20production%20software%20engineering/05.%20deployment%20and%20devops/#_3","title":"\u53ef\u91cd\u73b0\u6027","text":" \u4ec0\u4e48 \u5982\u4f55\u505a \u4ee3\u7801\u7248\u672c Git \u63d0\u4ea4\u54c8\u5e0c\u503c \u914d\u7f6e / \u8d85\u53c2\u6570 \u914d\u7f6e\u6587\u4ef6\uff08\u5728 Git \u4e2d\u7248\u672c\u63a7\u5236\u6216\u8bb0\u5f55\u5230 W&B\uff09 \u968f\u673a\u79cd\u5b50 \u8bbe\u7f6e\u5e76\u8bb0\u5f55\u6240\u6709\u79cd\u5b50\uff08Python\u3001NumPy\u3001PyTorch\u3001CUDA\uff09 \u6570\u636e\u7248\u672c DVC \u54c8\u5e0c\u503c\u3001\u6570\u636e\u96c6\u7248\u672c\u6807\u7b7e\u6216 S3 \u5bf9\u8c61\u7248\u672c \u4f9d\u8d56\u9879 pip freeze\u3001Docker \u955c\u50cf\u54c8\u5e0c\u503c\u6216\u9501\u5b9a\u6587\u4ef6 \u786c\u4ef6 GPU \u7c7b\u578b\u3001GPU \u6570\u91cf\u3001CUDA \u7248\u672c \u975e\u786e\u5b9a\u6027 torch.backends.cudnn.deterministic = True\uff08\u8f83\u6162\u4f46\u53ef\u91cd\u73b0\uff09 "},{"location":"chapter%2015%3A%20production%20software%20engineering/05.%20deployment%20and%20devops/#_4","title":"\u751f\u4ea7\u73af\u5883\u76d1\u63a7","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/05.%20deployment%20and%20devops/#_5","title":"\u7279\u5f81\u5b58\u50a8","text":""},{"location":"chapter%2015%3A%20production%20software%20engineering/05.%20deployment%20and%20devops/#_6","title":"\u7ba1\u9053\u7f16\u6392","text":"
# airflow DAG \u793a\u4f8b\uff08\u7b80\u5316\uff09\nfrom airflow import DAG\nfrom airflow.operators.python import PythonOperator\n\ndag = DAG(\"training_pipeline\", schedule=\"@daily\")\n\npreprocess = PythonOperator(task_id=\"preprocess\", python_callable=preprocess_data, dag=dag)\ntrain = PythonOperator(task_id=\"train\", python_callable=train_model, dag=dag)\nevaluate = PythonOperator(task_id=\"evaluate\", python_callable=evaluate_model, dag=dag)\ndeploy = PythonOperator(task_id=\"deploy\", python_callable=deploy_model, dag=dag)\n\npreprocess >> train >> evaluate >> deploy\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/","title":"\u4e3a\u4ec0\u4e48\u662fC++\u4ee5\u53caML\u6846\u67b6\u5982\u4f55\u5de5\u4f5c","text":"

\u672c\u4e66\u4e2d\u6bcf\u4e00\u6b21 jnp.matmul\u3001\u6bcf\u4e00\u6b21 torch.nn.Linear\u3001\u6bcf\u4e00\u6b21 np.dot \u8c03\u7528\uff0c\u5e95\u5c42\u90fd\u5728\u6267\u884cC++\u548cCUDA\u4ee3\u7801\u3002\u672c\u6587\u6863\u63ed\u5f00\u5e37\u5e55\uff1a\u4e3a\u4f55ML\u6846\u67b6\u91c7\u7528\u8fd9\u79cd\u67b6\u6784\uff0c\u9762\u5411Python\u5de5\u7a0b\u5e08\u7684C++\u5feb\u901f\u5165\u95e8\uff0c\u4f55\u65f6\u7f16\u5199\u81ea\u5b9a\u4e49C++\u6838\u51fd\u6570\uff0c\u4ee5\u53ca\u5982\u4f55\u5c06\u5176\u7ed1\u5b9a\u5230Python\u2014\u2014\u8fd9\u662f\u8fde\u63a5\u4f60\u6240\u5199\u4ee3\u7801\u4e0e\u6240\u8fd0\u884c\u786c\u4ef6\u4e4b\u95f4\u7684\u6865\u6881\u3002

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#pythonc","title":"\u4e3a\u4ec0\u4e48Python\u524d\u7aef\u642d\u914dC++\u540e\u7aef","text":" Python C++ \u5f00\u53d1\u901f\u5ea6 \u5feb\uff08\u52a8\u6001\u7c7b\u578b\u3001REPL\u3001\u65e0\u9700\u7f16\u8bd1\uff09 \u6162\uff08\u9759\u6001\u7c7b\u578b\u3001\u5934\u6587\u4ef6\u3001\u7f16\u8bd1\u65f6\u95f4\u957f\uff09 \u6267\u884c\u901f\u5ea6 \u6bd4C\u6162\u7ea6100\u500d\uff08\u89e3\u91ca\u578b\u3001GIL\uff09 \u63a5\u8fd1\u786c\u4ef6\u901f\u5ea6\uff08\u7f16\u8bd1\u578b\u3001\u65e0\u5f00\u9500\uff09 \u5185\u5b58\u63a7\u5236 \u81ea\u52a8\uff08GC\uff09\uff0c\u65e0\u6cd5\u63a7\u5236\u5e03\u5c40 \u624b\u52a8\uff0c\u7cbe\u786e\u63a7\u5236\u6bcf\u4e00\u4e2a\u5b57\u8282 \u786c\u4ef6\u8bbf\u95ee \u65e0\uff08\u65e0SIMD\u3001\u65e0GPU\u3001\u65e0\u81ea\u5b9a\u4e49\u5185\u5b58\uff09 \u5168\u9762\uff08\u5185\u8054\u51fd\u6570\u3001CUDA\u3001\u5185\u8054\u6c47\u7f16\uff09 \u751f\u6001\u7cfb\u7edf ML\u4e30\u5bcc\uff08\u7b14\u8bb0\u672c\u3001\u53ef\u89c6\u5316\u3001\u6570\u636e\uff09 \u7cfb\u7edf\u4e30\u5bcc\uff08\u64cd\u4f5c\u7cfb\u7edf\u3001\u9a71\u52a8\u3001\u5f15\u64ce\uff09 "},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#ml","title":"ML\u6846\u67b6\u7684\u7ed3\u6784","text":"
\u7528\u6237\u4ee3\u7801\uff08Python\uff09\n    \u2193\nPython API\u5c42\uff08torch.nn\u3001jax.numpy\u3001numpy\uff09\n    \u2193\n\u8c03\u5ea6/JIT\u7f16\u8bd1\u5668\uff08torch.compile\u3001XLA\u3001NumPy\u8c03\u5ea6\uff09\n    \u2193\nC++\u6838\u51fd\u6570\u5e93\uff08ATen/PyTorch\u3001XLA\u3001BLAS/LAPACK\uff09\n    \u2193\n\u786c\u4ef6\u7279\u5b9a\u540e\u7aef\uff08CUDA\u3001cuDNN\u3001MKL\u3001oneDNN\u3001Metal\uff09\n    \u2193\n\u786c\u4ef6\uff08CPU SIMD\u5355\u5143\u3001GPU\u6838\u5fc3\u3001TPU MXU\uff09\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#numpy","title":"NumPy","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#pytorch","title":"PyTorch","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#jax","title":"JAX","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#pythonc_1","title":"\u9762\u5411Python\u5de5\u7a0b\u5e08\u7684C++\u5feb\u901f\u5165\u95e8","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#_1","title":"\u7c7b\u578b\u548c\u53d8\u91cf","text":"
// C++\u9700\u8981\u663e\u5f0f\u7c7b\u578b\uff08\u4e0d\u50cfPython\uff09\nint count = 0;           // 32\u4f4d\u6574\u6570\nfloat loss = 0.5f;       // 32\u4f4d\u6d6e\u70b9\u6570\ndouble lr = 3e-4;        // 64\u4f4d\u6d6e\u70b9\u6570\nbool training = true;    // \u5e03\u5c14\u503c\n\n// \u6570\u7ec4\uff08\u56fa\u5b9a\u5927\u5c0f\uff0c\u6808\u5206\u914d\uff09\nfloat weights[1024];     // 1024\u4e2a\u6d6e\u70b9\u6570\uff0c\u5185\u5b58\u4e2d\u8fde\u7eed\n\n// \u6307\u9488\uff1a\u4fdd\u5b58\u5185\u5b58\u5730\u5740\u7684\u53d8\u91cf\nfloat* ptr = weights;    // ptr\u6307\u5411weights\u7684\u7b2c\u4e00\u4e2a\u5143\u7d20\nfloat val = ptr[42];     // \u901a\u8fc7\u6307\u9488\u8fd0\u7b97\u8bbf\u95ee\u5143\u7d2042\n// ptr[42] \u7b49\u4ef7\u4e8e *(ptr + 42)\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#_2","title":"\u51fd\u6570","text":"
// \u51fd\u6570\u58f0\u660e\uff1a\u8fd4\u56de\u7c7b\u578b \u540d\u5b57(\u53c2\u6570\u7c7b\u578b \u53c2\u6570\u540d)\nfloat relu(float x) {\n    return x > 0.0f ? x : 0.0f;\n}\n\n// \u4f20\u5f15\u7528\uff08\u907f\u514d\u62f7\u8d1d\u5927\u5bf9\u8c61\uff09\nvoid scale_vector(std::vector<float>& vec, float factor) {\n    for (size_t i = 0; i < vec.size(); i++) {\n        vec[i] *= factor;\n    }\n}\n\n// const\u5f15\u7528\uff1a\u53ea\u8bfb\uff0c\u65e0\u62f7\u8d1d\nfloat sum(const std::vector<float>& vec) {\n    float total = 0.0f;\n    for (float x : vec) {  // \u57fa\u4e8e\u8303\u56f4\u7684for\u5faa\u73af\uff08\u7c7b\u4f3cPython\u7684for x in vec\uff09\n        total += x;\n    }\n    return total;\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#_3","title":"\u5185\u5b58\uff1a\u6808\u4e0e\u5806","text":"
// \u6808\u5206\u914d\uff1a\u5feb\u901f\uff0c\u81ea\u52a8\u751f\u547d\u5468\u671f\uff08\u51fd\u6570\u8fd4\u56de\u65f6\u91ca\u653e\uff09\nfloat buffer[256];   // \u6808\u4e0a\u7684256\u4e2a\u6d6e\u70b9\u6570\n\n// \u5806\u5206\u914d\uff1a\u624b\u52a8\uff0c\u5728\u51fd\u6570\u5916\u4ecd\u7136\u5b58\u6d3b\nfloat* data = new float[n];   // \u5728\u5806\u4e0a\u5206\u914dn\u4e2a\u6d6e\u70b9\u6570\n// ... \u4f7f\u7528data ...\ndelete[] data;                 // \u5fc5\u987b\u624b\u52a8\u91ca\u653e\uff08\u6ca1\u6709\u5783\u573e\u56de\u6536\u5668\uff09\n\n// \u73b0\u4ee3C++\uff1a\u667a\u80fd\u6307\u9488\uff08\u81ea\u52a8\u6e05\u7406\uff0c\u7c7b\u4f3cPython\u5f15\u7528\uff09\n#include <memory>\nauto data = std::make_unique<float[]>(n);  // \u79bb\u5f00\u4f5c\u7528\u57df\u65f6\u81ea\u52a8\u91ca\u653e\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#_4","title":"\u6a21\u677f\uff08\u6cdb\u578b\uff09","text":"
// \u9002\u7528\u4e8e\u4efb\u4f55\u6570\u503c\u7c7b\u578b\u7684\u51fd\u6570\ntemplate <typename T>\nT add(T a, T b) {\n    return a + b;\n}\n\nadd<float>(1.5f, 2.5f);   // \u8fd4\u56de 4.0f\nadd<int>(3, 4);             // \u8fd4\u56de 7\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#_5","title":"\u6807\u51c6\u5e93\u7cbe\u534e","text":"
#include <vector>      // \u52a8\u6001\u6570\u7ec4\uff08\u7c7b\u4f3cPython list\uff09\n#include <string>      // \u5b57\u7b26\u4e32\u7c7b\u578b\n#include <unordered_map>  // \u54c8\u5e0c\u6620\u5c04\uff08\u7c7b\u4f3cPython dict\uff09\n#include <algorithm>   // sort\u3001find\u3001transform\u7b49\n#include <cmath>       // \u6570\u5b66\u51fd\u6570\n\nstd::vector<float> vec = {1.0f, 2.0f, 3.0f};\nvec.push_back(4.0f);            // \u8ffd\u52a0\nfloat first = vec[0];           // \u7d22\u5f15\nsize_t len = vec.size();        // \u957f\u5ea6\n\nstd::unordered_map<std::string, int> counts;\ncounts[\"hello\"] = 5;            // \u63d2\u5165\nif (counts.count(\"hello\")) { }  // \u68c0\u67e5\u5b58\u5728\u6027\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#c","title":"\u4f55\u65f6\u7f16\u5199\u81ea\u5b9a\u4e49C++\u6838\u51fd\u6570","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#cpython","title":"\u5982\u4f55\u5c06C++\u7ed1\u5b9a\u5230Python","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#pybind11","title":"pybind11\uff08\u901a\u7528\u76ee\u7684\uff09","text":"
// my_ops.cpp\n#include <pybind11/pybind11.h>\n#include <pybind11/numpy.h>\nnamespace py = pybind11;\n\n// \u4e00\u4e2a\u7b80\u5355\u7684\u81ea\u5b9a\u4e49\u64cd\u4f5c\npy::array_t<float> custom_relu(py::array_t<float> input) {\n    auto buf = input.request();\n    float* ptr = static_cast<float*>(buf.ptr);\n    size_t n = buf.size;\n\n    auto result = py::array_t<float>(n);\n    float* out = static_cast<float*>(result.request().ptr);\n\n    for (size_t i = 0; i < n; i++) {\n        out[i] = ptr[i] > 0 ? ptr[i] : 0;\n    }\n    return result;\n}\n\nPYBIND11_MODULE(my_ops, m) {\n    m.def(\"custom_relu\", &custom_relu, \"\u81ea\u5b9a\u4e49ReLU\u64cd\u4f5c\");\n}\n
# \u7f16\u8bd1\npip install pybind11\nc++ -O3 -shared -std=c++17 -fPIC $(python3 -m pybind11 --includes) my_ops.cpp -o my_ops$(python3-config --extension-suffix)\n
# \u4ecePython\u4f7f\u7528\nimport my_ops\nimport numpy as np\n\nx = np.array([-1.0, 2.0, -3.0, 4.0], dtype=np.float32)\ny = my_ops.custom_relu(x)\nprint(y)  # [0. 2. 0. 4.]\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#pytorch-c","title":"PyTorch C++\u6269\u5c55","text":"
// custom_op.cpp\n#include <torch/extension.h>\n\ntorch::Tensor custom_gelu(torch::Tensor x) {\n    return x * 0.5 * (1.0 + torch::erf(x / std::sqrt(2.0)));\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n    m.def(\"custom_gelu\", &custom_gelu, \"\u81ea\u5b9a\u4e49GELU\u6fc0\u6d3b\u51fd\u6570\");\n}\n
# \u52a8\u6001\u52a0\u8f7d\u548c\u7f16\u8bd1\nfrom torch.utils.cpp_extension import load\n\ncustom_ops = load(\n    name=\"custom_ops\",\n    sources=[\"custom_op.cpp\"],\n    extra_cflags=[\"-O3\"],\n)\n\nx = torch.randn(1000)\ny = custom_ops.custom_gelu(x)\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#jax_1","title":"JAX\u81ea\u5b9a\u4e49\u8c03\u7528","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#_6","title":"\u5927\u5c40\u89c2","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/00.%20why%20C%2B%2B%20and%20how%20ML%20frameworks%20work/#gclang","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u7528g++\u6216clang++\u7f16\u8bd1\uff09","text":"
  1. \u7f16\u5199\u4f60\u7684\u7b2c\u4e00\u4e2aC++\u7a0b\u5e8f\u3002\u5206\u914d\u4e00\u4e2a\u6570\u7ec4\uff0c\u586b\u5145\u6570\u636e\uff0c\u8ba1\u7b97\u603b\u548c\uff0c\u5e76\u6d4b\u91cf\u65f6\u95f4\u3002\u8fd9\u4ecb\u7ecd\u4e86\u7f16\u8bd1\u3001\u6570\u7ec4\u3001\u6307\u9488\u548c\u8ba1\u65f6\u3002

    // task1_basics.cpp\n// \u7f16\u8bd1\uff1ag++ -O3 -o task1 task1_basics.cpp\n// \u8fd0\u884c\uff1a./task1\n\n#include <iostream>\n#include <chrono>\n#include <vector>\n\nint main() {\n    const int N = 10'000'000;  // C++\u5141\u8bb8'\u4f5c\u4e3a\u6570\u5b57\u5206\u9694\u7b26\n    std::vector<float> data(N);\n\n    // \u586b\u5145\u6570\u7ec4\n    for (int i = 0; i < N; i++) {\n        data[i] = static_cast<float>(i) * 0.001f;\n    }\n\n    // \u8ba1\u7b97\u603b\u548c\n    auto start = std::chrono::high_resolution_clock::now();\n    float sum = 0.0f;\n    for (int i = 0; i < N; i++) {\n        sum += data[i];\n    }\n    auto end = std::chrono::high_resolution_clock::now();\n    double elapsed = std::chrono::duration<double, std::milli>(end - start).count();\n\n    std::cout << \"\u603b\u548c: \" << sum << std::endl;\n    std::cout << \"\u65f6\u95f4: \" << elapsed << \" ms\" << std::endl;\n    std::cout << \"\u5143\u7d20\u6570: \" << N << std::endl;\n    std::cout << \"\u541e\u5410\u91cf: \" << (N * sizeof(float)) / elapsed / 1e6 << \" GB/s\" << std::endl;\n\n    return 0;\n}\n

  2. \u7f16\u5199\u4e00\u4e2aC++\u51fd\u6570\u5728\u6570\u7ec4\u4e0a\u8ba1\u7b97ReLU\uff0c\u7136\u540e\u4f7f\u7528pybind11\u6784\u5efaPython\u7ed1\u5b9a\u3002\u4ecePython\u8c03\u7528\u5b83\u5e76\u4e0eNumPy\u6bd4\u8f83\u901f\u5ea6\u3002

    // task2_relu.cpp\n// \u7f16\u8bd1\uff1ac++ -O3 -shared -std=c++17 -fPIC $(python3 -m pybind11 --includes) \\\n//          task2_relu.cpp -o my_relu$(python3-config --extension-suffix)\n\n#include <pybind11/pybind11.h>\n#include <pybind11/numpy.h>\nnamespace py = pybind11;\n\npy::array_t<float> cpp_relu(py::array_t<float> input) {\n    auto buf = input.request();\n    float* ptr = static_cast<float*>(buf.ptr);\n    int n = buf.size;\n\n    auto result = py::array_t<float>(n);\n    float* out = static_cast<float*>(result.request().ptr);\n\n    for (int i = 0; i < n; i++) {\n        out[i] = ptr[i] > 0.0f ? ptr[i] : 0.0f;\n    }\n    return result;\n}\n\nPYBIND11_MODULE(my_relu, m) {\n    m.def(\"relu\", &cpp_relu, \"C++ ReLU\");\n}\n
    # test_relu.py \u2014 \u5728\u7f16\u8bd1\u4e0a\u8ff0C++\u6a21\u5757\u540e\u8fd0\u884c\nimport numpy as np\nimport time\nimport my_relu  # \u7f16\u8bd1\u540e\u7684C++\u6a21\u5757\n\nx = np.random.randn(10_000_000).astype(np.float32)\n\n# C++ ReLU\nstart = time.time()\nfor _ in range(100):\n    y_cpp = my_relu.relu(x)\ncpp_time = (time.time() - start) / 100\n\n# NumPy ReLU\nstart = time.time()\nfor _ in range(100):\n    y_np = np.maximum(x, 0)\nnp_time = (time.time() - start) / 100\n\nprint(f\"C++ ReLU:   {cpp_time*1000:.2f} ms\")\nprint(f\"NumPy ReLU: {np_time*1000:.2f} ms\")\nprint(f\"\u5339\u914d: {np.allclose(y_cpp, y_np)}\")\n

  3. \u7f16\u5199\u4e00\u4e2aC++\u7a0b\u5e8f\uff0c\u6f14\u793a\u4e3a\u4f55\u5185\u5b58\u5e03\u5c40\u5f88\u91cd\u8981\u3002\u6bd4\u8f83\u884c\u4f18\u5148\u4e0e\u5217\u4f18\u5148\u8bbf\u95ee\u6a21\u5f0f\u5e76\u6d4b\u91cf\u6027\u80fd\u5dee\u5f02\u3002

    // task3_layout.cpp\n// \u7f16\u8bd1\uff1ag++ -O3 -o task3 task3_layout.cpp\n\n#include <iostream>\n#include <chrono>\n#include <vector>\n\nint main() {\n    const int N = 4096;\n    std::vector<float> matrix(N * N, 1.0f);\n\n    // \u884c\u4f18\u5148\u8bbf\u95ee\uff1a\u8fde\u7eed\u5185\u5b58\u5730\u5740\uff08\u7f13\u5b58\u53cb\u597d\uff09\n    auto start = std::chrono::high_resolution_clock::now();\n    float sum_row = 0.0f;\n    for (int i = 0; i < N; i++) {\n        for (int j = 0; j < N; j++) {\n            sum_row += matrix[i * N + j];  // \u6b65\u957f1\u8bbf\u95ee\n        }\n    }\n    auto end = std::chrono::high_resolution_clock::now();\n    double row_ms = std::chrono::duration<double, std::milli>(end - start).count();\n\n    // \u5217\u4f18\u5148\u8bbf\u95ee\uff1a\u6b65\u957fN\u8bbf\u95ee\uff08\u7f13\u5b58\u4e0d\u53cb\u597d\uff09\n    start = std::chrono::high_resolution_clock::now();\n    float sum_col = 0.0f;\n    for (int j = 0; j < N; j++) {\n        for (int i = 0; i < N; i++) {\n            sum_col += matrix[i * N + j];  // \u6b65\u957fN\u8bbf\u95ee\uff08\u7f13\u5b58\u7f3a\u5931\uff01\uff09\n        }\n    }\n    end = std::chrono::high_resolution_clock::now();\n    double col_ms = std::chrono::duration<double, std::milli>(end - start).count();\n\n    std::cout << \"\u884c\u4f18\u5148\uff08\u7f13\u5b58\u53cb\u597d\uff09: \" << row_ms << \" ms\" << std::endl;\n    std::cout << \"\u5217\u4f18\u5148\uff08\u7f13\u5b58\u4e0d\u53cb\u597d\uff09: \" << col_ms << \" ms\" << std::endl;\n    std::cout << \"\u51cf\u901f\u6bd4: \" << col_ms / row_ms << \"x\" << std::endl;\n    std::cout << \"\uff08\u4e24\u4e2a\u548c: \" << sum_row << \", \" << sum_col << \"\uff09\" << std::endl;\n\n    return 0;\n}\n

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/","title":"\u786c\u4ef6\u57fa\u7840","text":"

\u5728\u7f16\u5199SIMD\u6216GPU\u4ee3\u7801\u4e4b\u524d\uff0c\u4f60\u9700\u8981\u4e86\u89e3\u6240\u7f16\u7a0b\u7684\u786c\u4ef6\u3002\u672c\u6587\u6db5\u76d6\u4e3a\u4ec0\u4e48\u5e76\u884c\u6027\u53d6\u4ee3\u4e86\u65f6\u949f\u901f\u5ea6\u3001\u73b0\u4ee3CPU\u5982\u4f55\u6267\u884c\u6307\u4ee4\u3001\u4ec0\u4e48\u662fSIMD\u3001\u7528\u4e8e\u5206\u6790\u6027\u80fd\u7684\u5c4b\u9876\u7ebf\u6a21\u578b\uff0c\u4ee5\u53ca\u82af\u7247\u67b6\u6784\u7684\u5168\u666f\u3002

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#_2","title":"\u514d\u8d39\u6027\u80fd\u7684\u7ec8\u7ed3","text":" \\[P \\propto C \\cdot V^2 \\cdot f\\] "},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#cpu","title":"\u73b0\u4ee3CPU\u5982\u4f55\u6267\u884c\u6307\u4ee4","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#simd","title":"SIMD\uff1a\u5355\u6307\u4ee4\u591a\u6570\u636e","text":"
// \u9010\u5143\u7d20\u76f8\u52a0\u4e24\u6570\u7ec4\uff1a4\u6761\u52a0\u6cd5\u6307\u4ee4\nfor (int i = 0; i < 4; i++) {\n    c[i] = a[i] + b[i];  // \u6bcf\u6b21\u8fed\u4ee3\u4e00\u6b21\u52a0\u6cd5\n}\n
// \u4e24\u6570\u7ec4\u76f8\u52a0\uff1a1\u6761SIMD\u6307\u4ee4\u5b8c\u6210\u6240\u67094\u6b21\u52a0\u6cd5\n#include <immintrin.h>  // x86 SIMD\u5185\u8054\u51fd\u6570\n\n__m128 va = _mm_load_ps(a);    // \u52a0\u8f7d4\u4e2a\u6d6e\u70b9\u6570\u5230128\u4f4d\u5bc4\u5b58\u5668\n__m128 vb = _mm_load_ps(b);    // \u52a0\u8f7d4\u4e2a\u6d6e\u70b9\u6570\u5230\u53e6\u4e00\u4e2a\u5bc4\u5b58\u5668\n__m128 vc = _mm_add_ps(va, vb); // \u540c\u65f6\u76f8\u52a0\u6240\u67094\u5bf9\n_mm_store_ps(c, vc);            // \u5b58\u50a84\u4e2a\u7ed3\u679c\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#_3","title":"\u5411\u91cf\u5bc4\u5b58\u5668","text":" \u5bc4\u5b58\u5668\u5bbd\u5ea6 \u6d6e\u70b9\u6570\uff0832\u4f4d\uff09 \u53cc\u7cbe\u5ea6\u6d6e\u70b9\u6570\uff0864\u4f4d\uff09 \u540d\u79f0 128\u4f4d 4 2 SSE\uff08x86\uff09\u3001NEON\uff08ARM\uff09 256\u4f4d 8 4 AVX/AVX2\uff08x86\uff09 512\u4f4d 16 8 AVX-512\uff08x86\uff09 \u53ef\u53d8\uff08128-2048\uff09 \u53ef\u53d8 \u53ef\u53d8 SVE/SVE2\uff08ARM\uff09 "},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#_4","title":"\u5c4b\u9876\u7ebf\u6a21\u578b","text":" \\[\\text{\u7b97\u672f\u5f3a\u5ea6} = \\frac{\\text{FLOPS}}{\\text{\u4f20\u8f93\u7684\u5b57\u8282\u6570}}\\] \\[\\text{\u53ef\u8fbeFLOPS} = \\min\\left(\\text{\u5cf0\u503cFLOPS}, \\; \\text{\u5e26\u5bbd} \\times \\text{\u7b97\u672f\u5f3a\u5ea6}\\right)\\] "},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#_5","title":"\u5ef6\u8fdf\u4e0e\u541e\u5410\u91cf","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#_6","title":"\u82af\u7247\u67b6\u6784\u5168\u666f","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#x86intel-amd","title":"x86\uff08Intel, AMD\uff09","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#arm","title":"ARM","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#apple-siliconm1m2m3m4","title":"Apple 
Silicon\uff08M1/M2/M3/M4\uff09","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#risc-v","title":"RISC-V","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#gpunvidiaamdintel","title":"GPU\uff08NVIDIA\u3001AMD\u3001Intel\uff09","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#tpugoogle","title":"TPU\uff08Google\uff09","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#_7","title":"\u70ed\u4e0e\u529f\u8017\u7ea6\u675f","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#c","title":"\u5b9e\u8df5\uff1a\u5728C++\u4e2d\u6d4b\u91cf\u6027\u80fd","text":"
#include <iostream>\n#include <chrono>\n#include <vector>\n\n// \u6807\u91cf\u52a0\u6cd5\nvoid add_scalar(const float* a, const float* b, float* c, int n) {\n    for (int i = 0; i < n; i++) {\n        c[i] = a[i] + b[i];\n    }\n}\n\nint main() {\n    const int N = 1 << 24;  // \u7ea61600\u4e07\u4e2a\u5143\u7d20\n    std::vector<float> a(N, 1.0f), b(N, 2.0f), c(N);\n\n    // \u9884\u70ed\uff08\u586b\u5145\u7f13\u5b58\uff0c\u89e6\u53d1\u9891\u7387\u7f29\u653e\uff09\n    add_scalar(a.data(), b.data(), c.data(), N);\n\n    // \u57fa\u51c6\u6d4b\u8bd5\n    auto start = std::chrono::high_resolution_clock::now();\n\n    for (int trial = 0; trial < 100; trial++) {\n        add_scalar(a.data(), b.data(), c.data(), N);\n    }\n\n    auto end = std::chrono::high_resolution_clock::now();\n    double elapsed = std::chrono::duration<double>(end - start).count();\n\n    double total_bytes = 3.0 * N * sizeof(float) * 100;  // \u8bfba\u3001\u8bfbb\u3001\u5199c\n    double bandwidth = total_bytes / elapsed / 1e9;        // GB/s\n\n    std::cout << \"\u65f6\u95f4: \" << elapsed << \" s\\n\";\n    std::cout << \"\u5e26\u5bbd: \" << bandwidth << \" GB/s\\n\";\n\n    return 0;\n}\n
# \u542f\u7528\u4f18\u5316\u7f16\u8bd1\ng++ -O3 -march=native -o bench bench.cpp\n./bench\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/01.%20hardware%20fundamentals/#colab","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216\u7b14\u8bb0\u672c\uff09","text":"
  1. \u8ba1\u7b97\u5e38\u89c1ML\u64cd\u4f5c\u7684\u7b97\u672f\u5f3a\u5ea6\uff0c\u5e76\u5c06\u5b83\u4eec\u5206\u7c7b\u4e3a\u5185\u5b58\u53d7\u9650\u6216\u8ba1\u7b97\u53d7\u9650\u3002

    import jax.numpy as jnp\n\ndef arithmetic_intensity(flops, bytes_transferred):\n    return flops / bytes_transferred\n\n# \u9010\u5143\u7d20ReLU\uff1a\u6bcf\u5143\u7d201\u6b21\u6bd4\u8f83\uff0c\u8bfb\u53d6+\u5199\u5165\nn = 1024\nrelu_flops = n  # \u6bcf\u5143\u7d201\u6b21\u64cd\u4f5c\nrelu_bytes = 2 * n * 4  # \u8bfb\u53d6\u8f93\u5165+\u5199\u5165\u8f93\u51fa\uff08float32\uff09\nprint(f\"ReLU: {arithmetic_intensity(relu_flops, relu_bytes):.2f} FLOPS/byte \u2192 \u5185\u5b58\u53d7\u9650\")\n\n# \u77e9\u9635\u4e58\u6cd5\uff1a2*n^3\u6b21\u64cd\u4f5c\uff0c\u8bfb\u53d62*n^2 + \u5199\u5165n^2\u4e2a\u6d6e\u70b9\u6570\nmatmul_flops = 2 * n**3\nmatmul_bytes = 3 * n**2 * 4  # \u8bfb\u53d6A + \u8bfb\u53d6B + \u5199\u5165C\nprint(f\"\u77e9\u9635\u4e58\u6cd5 ({n}\u00d7{n}): {arithmetic_intensity(matmul_flops, matmul_bytes):.0f} FLOPS/byte \u2192 \u8ba1\u7b97\u53d7\u9650\")\n\n# \u5c42\u5f52\u4e00\u5316\uff1a\u7ea65n\u6b21\u64cd\u4f5c\uff08\u5747\u503c\u3001\u65b9\u5dee\u3001\u5f52\u4e00\u5316\uff09\uff0c\u8bfb\u53d6+\u5199\u5165\nln_flops = 5 * n\nln_bytes = 2 * n * 4\nprint(f\"LayerNorm: {arithmetic_intensity(ln_flops, ln_bytes):.2f} FLOPS/byte \u2192 \u5185\u5b58\u53d7\u9650\")\n\n# 3x3\u5377\u79ef\uff1a2*9*C_in*C_out*H*W\uff0c\u8bfb\u53d6\u5377\u79ef\u6838+\u7279\u5f81\u56fe+\u5199\u5165\u8f93\u51fa\nC_in, C_out, H, W = 64, 128, 32, 32\nconv_flops = 2 * 9 * C_in * C_out * H * W\nconv_bytes = (9 * C_in * C_out + C_in * H * W + C_out * H * W) * 4\nprint(f\"Conv3x3: {arithmetic_intensity(conv_flops, conv_bytes):.0f} FLOPS/byte \u2192 \u8ba1\u7b97\u53d7\u9650\")\n

  2. \u6f14\u793a\u4e3a\u4ec0\u4e48\u5e76\u884c\u6027\u91cd\u8981\u3002\u6bd4\u8f83\u987a\u5e8f\u6267\u884c\u4e0e\u5e76\u884c\uff08NumPy\uff09\u6267\u884c\u968f\u6570\u636e\u89c4\u6a21\u589e\u957f\u7684\u8868\u73b0\u3002

    import numpy as np\nimport time\n\nfor n in [1000, 10000, 100000, 1000000, 10000000]:\n    a = np.random.randn(n).astype(np.float32)\n    b = np.random.randn(n).astype(np.float32)\n\n    # \"\u987a\u5e8f\u6267\u884c\"\uff08Python\u5faa\u73af\uff09\n    start = time.time()\n    c = [a[i] * b[i] for i in range(min(n, 100000))]  # \u4e0a\u965010\u4e07\u4ee5\u786e\u4fdd\u5408\u7406\n    seq_time = time.time() - start\n    if n > 100000:\n        seq_time *= n / 100000  # \u5916\u63a8\n\n    # \"\u5e76\u884c\"\uff08NumPy\uff0c\u5185\u90e8\u4f7f\u7528SIMD+\u591a\u7ebf\u7a0b\uff09\n    start = time.time()\n    c = a * b\n    par_time = time.time() - start\n\n    print(f\"n={n:>10,}  \u987a\u5e8f={seq_time:.4f}s  \u5e76\u884c={par_time:.6f}s  \"\n          f\"\u52a0\u901f\u6bd4={seq_time/par_time:.0f}x\")\n

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/02.%20ARM%20and%20NEON/","title":"ARM\u4e0eNEON","text":"

ARM\u5904\u7406\u5668\u9a71\u52a8\u7740\u6bcf\u4e00\u90e8\u667a\u80fd\u624b\u673a\u3001\u5927\u591a\u6570\u5e73\u677f\u7535\u8111\u3001Apple\u7684\u7b14\u8bb0\u672c\u7535\u8111\u4ee5\u53ca\u65e5\u76ca\u589e\u957f\u7684\u6570\u636e\u4e2d\u5fc3\u670d\u52a1\u5668\u4efd\u989d\u3002\u672c\u6587\u6db5\u76d6ARM\u67b6\u6784\u3001\u4f7f\u7528C++\u5185\u8054\u51fd\u6570\u7684NEON SIMD\u7f16\u7a0b\u3001\u7528\u4e8e\u53ef\u4f38\u7f29\u5411\u91cf\u5904\u7406\u7684SVE/SVE2\u3001Apple Silicon\u7279\u6027\u4ee5\u53ca\u5b9e\u9645\u5411\u91cf\u5316\u6838\u51fd\u6570\u793a\u4f8b\u3002

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/02.%20ARM%20and%20NEON/#arm","title":"ARM\u67b6\u6784\u57fa\u7840","text":"
// ARM\u6c47\u7f16\uff08\u4ec5\u611f\u53d7\u98ce\u683c\u2014\u2014\u4f60\u5c06\u4f7f\u7528\u5185\u8054\u51fd\u6570\uff0c\u800c\u975e\u6c47\u7f16\uff09\n// \u4e24\u5bc4\u5b58\u5668\u76f8\u52a0\nadd x0, x1, x2    // x0 = x1 + x2\n\n// \u4ece\u5185\u5b58\u52a0\u8f7d\nldr x0, [x1]      // x0 = *x1\uff08\u4ecex1\u4e2d\u7684\u5730\u5740\u52a0\u8f7d64\u4f4d\uff09\n\n// NEON\uff1a\u52a0\u56db\u4e2a\u6d6e\u70b9\u6570\nfadd v0.4s, v1.4s, v2.4s  // v0 = v1 + v2\uff08\u56db\u4e2a32\u4f4d\u6d6e\u70b9\u6570\uff09\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/02.%20ARM%20and%20NEON/#neon128simd","title":"NEON\uff1a128\u4f4dSIMD","text":" \u6570\u636e\u7c7b\u578b \u6bcf\u5bc4\u5b58\u5668\u5143\u7d20\u6570 \u8868\u793a\u6cd5 float32 4 float32x4_t float16 8 float16x8_t int32 4 int32x4_t int16 8 int16x8_t int8 16 int8x16_t "},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/02.%20ARM%20and%20NEON/#neon","title":"NEON\u5185\u8054\u51fd\u6570\uff1a\u57fa\u7840","text":"
#include <arm_neon.h>\n\n// \u4ece\u5185\u5b58\u52a0\u8f7d4\u4e2a\u6d6e\u70b9\u6570\u5230NEON\u5bc4\u5b58\u5668\nfloat32x4_t a = vld1q_f32(ptr);        // vld1q = vector load 1, q = 128\u4f4d\uff08\u56db\u5b57\uff09\n\n// \u4eceNEON\u5bc4\u5b58\u5668\u5b58\u50a84\u4e2a\u6d6e\u70b9\u6570\u5230\u5185\u5b58\nvst1q_f32(out_ptr, a);                   // vst1q = vector store 1, q = 128\u4f4d\n\n// \u7b97\u672f\u8fd0\u7b97\nfloat32x4_t c = vaddq_f32(a, b);        // c = a + b\uff084\u4e2a\u6d6e\u70b9\u6570\uff09\nfloat32x4_t d = vmulq_f32(a, b);        // d = a * b\uff084\u4e2a\u6d6e\u70b9\u6570\uff09\nfloat32x4_t e = vfmaq_f32(c, a, b);     // e = c + a * b\uff08\u878d\u5408\u4e58\u52a0\uff0c4\u4e2a\u6d6e\u70b9\u6570\uff09\n\n// \u6bd4\u8f83\uff08\u8fd4\u56de\u63a9\u7801\uff1a\u82e5\u771f\u5219\u51681\uff0c\u82e5\u5047\u5219\u51680\uff09\nuint32x4_t mask = vcgtq_f32(a, b);      // mask[i] = (a[i] > b[i]) ? 0xFFFFFFFF : 0\n\n// \u57fa\u4e8e\u63a9\u7801\u9009\u62e9\u5143\u7d20\uff08\u7c7b\u4f3cnumpy.where\uff09\nfloat32x4_t result = vbslq_f32(mask, a, b);  // result[i] = mask[i] ? a[i] : b[i]\n\n// \u5f52\u7ea6\uff1a\u5c06\u6240\u67094\u4e2a\u5143\u7d20\u6c42\u548c\u4e3a\u6807\u91cf\nfloat total = vaddvq_f32(a);             // total = a[0] + a[1] + a[2] + a[3]\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/02.%20ARM%20and%20NEON/#_1","title":"\u5b9e\u8df5\u793a\u4f8b\uff1a\u5411\u91cf\u5316\u70b9\u79ef","text":"
#include <arm_neon.h>\n\n// \u6807\u91cf\u70b9\u79ef\nfloat dot_scalar(const float* a, const float* b, int n) {\n    float sum = 0.0f;\n    for (int i = 0; i < n; i++) {\n        sum += a[i] * b[i];\n    }\n    return sum;\n}\n\n// NEON\u5411\u91cf\u5316\u70b9\u79ef\nfloat dot_neon(const float* a, const float* b, int n) {\n    float32x4_t sum_vec = vdupq_n_f32(0.0f);  // \u521d\u59cb\u53164\u4e2a\u7d2f\u52a0\u5668\u4e3a0\n\n    int i = 0;\n    for (; i + 4 <= n; i += 4) {\n        float32x4_t va = vld1q_f32(a + i);     // \u4ecea\u52a0\u8f7d4\u4e2a\u5143\u7d20\n        float32x4_t vb = vld1q_f32(b + i);     // \u4eceb\u52a0\u8f7d4\u4e2a\u5143\u7d20\n        sum_vec = vfmaq_f32(sum_vec, va, vb);   // sum_vec += va * vb\n    }\n\n    // \u5c064\u4e2a\u7d2f\u52a0\u5668\u5f52\u7ea6\u4e3a\u5355\u4e00\u6807\u91cf\n    float sum = vaddvq_f32(sum_vec);\n\n    // \u5904\u7406\u5269\u4f59\u5143\u7d20\uff08\u5982\u679cn\u4e0d\u662f4\u7684\u500d\u6570\uff09\n    for (; i < n; i++) {\n        sum += a[i] * b[i];\n    }\n\n    return sum;\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/02.%20ARM%20and%20NEON/#relu","title":"\u5b9e\u8df5\u793a\u4f8b\uff1a\u5411\u91cf\u5316ReLU","text":"
#include <arm_neon.h>\n\nvoid relu_neon(const float* input, float* output, int n) {\n    float32x4_t zero = vdupq_n_f32(0.0f);\n\n    int i = 0;\n    for (; i + 4 <= n; i += 4) {\n        float32x4_t x = vld1q_f32(input + i);\n        float32x4_t result = vmaxq_f32(x, zero);  // max(x, 0) = ReLU\n        vst1q_f32(output + i, result);\n    }\n\n    // \u6807\u91cf\u6e05\u7406\n    for (; i < n; i++) {\n        output[i] = input[i] > 0 ? input[i] : 0;\n    }\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/02.%20ARM%20and%20NEON/#i8mm","title":"I8MM\uff1a\u6574\u6570\u77e9\u9635\u4e58\u6cd5","text":"
#include <arm_neon.h>\n\n// I8MM\uff1a\u5c06\u4e24\u4e2a8\u5143\u7d20INT8\u5411\u91cf\u76f8\u4e58\uff0c\u7d2f\u52a0\u52304\u4e2aINT32\u7ed3\u679c\u4e2d\n// \u8fd9\u4ece2x8 \u00d7 8x2\u8f93\u5165\u5757\u8ba1\u7b97\u8f93\u51fa\u77e9\u9635\u7684\u4e00\u4e2a2x2\u74e6\u7247\nvoid matmul_i8mm_tile(const int8_t* A, const int8_t* B, int32_t* C) {\n    // \u4eceA\u548cB\u5404\u52a0\u8f7d16\u5b57\u8282\uff082\u884c\u54048\u5143\u7d20\uff0c\u6253\u5305\uff09\n    int8x16_t va = vld1q_s8(A);   // 16\u5b57\u8282 = 2\u884c \u00d7 8\u5143\u7d20\n    int8x16_t vb = vld1q_s8(B);   // 16\u5b57\u8282 = 2\u884c \u00d7 8\u5143\u7d20\n\n    // \u52a0\u8f7d\u73b0\u6709\u7d2f\u52a0\u5668\uff082x2 = 4\u4e2aint32\u503c\uff09\n    int32x4_t acc = vld1q_s32(C);\n\n    // I8MM\u6307\u4ee4\uff1aacc += A_tile \u00d7 B_tile^T\n    // \u4ece2\u00d78 \u00d7 8\u00d72\u8f93\u5165\u8ba1\u7b972\u00d72\u8f93\u51fa\n    acc = vmmlaq_s32(acc, va, vb);  // I8MM\u6307\u4ee4\n\n    vst1q_s32(C, acc);\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/02.%20ARM%20and%20NEON/#smesme2","title":"SME\u548cSME2\uff1a\u53ef\u4f38\u7f29\u77e9\u9635\u6269\u5c55","text":"
#include <arm_sme.h>\n\n// SME2\uff1a\u77e9\u9635\u4e58\u6cd5\u7684\u5916\u79ef\u7d2f\u52a0\n// \u5c06A_col \u00d7 B_row \u7d2f\u52a0\u5230ZA\u74e6\u7247\u5bc4\u5b58\u5668\u4e2d\nvoid sme2_matmul_outer(const float* A_col, const float* B_row, int K) {\n    // \u8fdb\u5165\u6d41\u6a21\u5f0f\n    // smstart;  // \uff08\u901a\u8fc7\u7f16\u8bd1\u5668\u5185\u8054\u6216\u5185\u8054\u6c47\u7f16\u5b8c\u6210\uff09\n\n    // \u6e05\u96f6ZA\u74e6\u7247\u7d2f\u52a0\u5668\n    svzero_za();\n\n    for (int k = 0; k < K; k++) {\n        // \u5c06A\u7684\u4e00\u5217\u548cB\u7684\u4e00\u884c\u52a0\u8f7d\u5230SVE\u5bc4\u5b58\u5668\u4e2d\n        svfloat32_t a = svld1_f32(svptrue_b32(), &A_col[k * SVL]);\n        svfloat32_t b = svld1_f32(svptrue_b32(), &B_row[k * SVL]);\n\n        // \u5916\u79ef\uff1aZA += a \u00d7 b^T\n        // \u8fd9\u5728\u4e00\u4e2a\u6307\u4ee4\u4e2d\u7d2f\u52a0\u4e00\u4e2aSVL\u00d7SVL\u74e6\u7247\n        svmopa_za32_f32_m(0, svptrue_b32(), svptrue_b32(), a, b);\n    }\n\n    // \u5c06ZA\u74e6\u7247\u5b58\u50a8\u5230\u5185\u5b58\n    // svst1_za(...);\n\n    // \u9000\u51fa\u6d41\u6a21\u5f0f\n    // smstop;\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/02.%20ARM%20and%20NEON/#svesve2","title":"SVE\u548cSVE2\uff1a\u53ef\u4f38\u7f29\u5411\u91cf\u6269\u5c55","text":"
#include <arm_sve.h>\n\nvoid add_sve(const float* a, const float* b, float* c, int n) {\n    int i = 0;\n    svbool_t pred = svwhilelt_b32(i, n);  // \u8c13\u8bcd\uff1a\u54ea\u4e9b\u901a\u9053\u662f\u6fc0\u6d3b\u7684\n\n    while (svptest_any(svptrue_b32(), pred)) {\n        svfloat32_t va = svld1(pred, a + i);\n        svfloat32_t vb = svld1(pred, b + i);\n        svst1(pred, c + i, svadd_x(pred, va, vb));\n\n        i += svcntw();  // \u6309\u786c\u4ef6\u5411\u91cf\u5bbd\u5ea6\u524d\u8fdb\uff08\u4ee532\u4f4d\u5143\u7d20\u8ba1\uff09\n        pred = svwhilelt_b32(i, n);\n    }\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/02.%20ARM%20and%20NEON/#apple-silicon","title":"Apple Silicon\u7279\u6027","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/02.%20ARM%20and%20NEON/#_2","title":"\u81ea\u52a8\u5411\u91cf\u5316","text":"
// \u7f16\u8bd1\u5668\u53ef\u4ee5\u81ea\u52a8\u5411\u91cf\u5316\u6b64\u4ee3\u7801\uff08\u4f7f\u7528 -O3 -march=native\uff09\nvoid add_auto(const float* a, const float* b, float* c, int n) {\n    for (int i = 0; i < n; i++) {\n        c[i] = a[i] + b[i];\n    }\n}\n
// restrict \u544a\u8bc9\u7f16\u8bd1\u5668\uff1aa\u3001b\u3001c \u6307\u5411\u4e0d\u91cd\u53e0\u7684\u5185\u5b58\nvoid add_restrict(const float* __restrict__ a,\n                  const float* __restrict__ b,\n                  float* __restrict__ c, int n) {\n    for (int i = 0; i < n; i++) {\n        c[i] = a[i] + b[i];\n    }\n}\n
# GCC\uff1a\u663e\u793a\u5411\u91cf\u5316\u51b3\u7b56\ng++ -O3 -march=native -fopt-info-vec-optimized code.cpp\n\n# Clang\uff1a\u663e\u793a\u5411\u91cf\u5316\u62a5\u544a\nclang++ -O3 -march=native -Rpass=loop-vectorize code.cpp\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/02.%20ARM%20and%20NEON/#armgclangmac-mlinux-aarch64","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u5728ARM\u4e0a\u7528g++\u6216clang++\u7f16\u8bd1\u2014\u2014Mac M\u7cfb\u5217\u6216Linux aarch64\uff09","text":"
  1. \u7f16\u5199\u6807\u91cf\u70b9\u79ef\u548cNEON\u5411\u91cf\u5316\u70b9\u79ef\u3002\u5bf9\u4e24\u8005\u8fdb\u884c\u57fa\u51c6\u6d4b\u8bd5\u5e76\u6d4b\u91cf\u52a0\u901f\u6bd4\u3002

    // task1_neon_dot.cpp\n// \u7f16\u8bd1\uff08Mac/ARM Linux\uff09\uff1aclang++ -O3 -o task1 task1_neon_dot.cpp\n// \u6ce8\u610f\uff1aNEON\u5728AArch64\u4e0a\u9ed8\u8ba4\u542f\u7528\uff0c\u65e0\u9700\u7279\u6b8a\u6807\u5fd7\n\n#include <iostream>\n#include <chrono>\n#include <vector>\n#include <arm_neon.h>\n\nfloat dot_scalar(const float* a, const float* b, int n) {\n    float sum = 0.0f;\n    for (int i = 0; i < n; i++) {\n        sum += a[i] * b[i];\n    }\n    return sum;\n}\n\nfloat dot_neon(const float* a, const float* b, int n) {\n    float32x4_t sum_vec = vdupq_n_f32(0.0f);\n    int i = 0;\n    for (; i + 4 <= n; i += 4) {\n        float32x4_t va = vld1q_f32(a + i);\n        float32x4_t vb = vld1q_f32(b + i);\n        sum_vec = vfmaq_f32(sum_vec, va, vb);\n    }\n    float sum = vaddvq_f32(sum_vec);\n    for (; i < n; i++) sum += a[i] * b[i];\n    return sum;\n}\n\nint main() {\n    const int N = 10'000'000;\n    std::vector<float> a(N, 1.0f), b(N, 2.0f);\n\n    // \u9884\u70ed\n    volatile float s1 = dot_scalar(a.data(), b.data(), N);\n    volatile float s2 = dot_neon(a.data(), b.data(), N);\n\n    // \u6807\u91cf\u57fa\u51c6\u6d4b\u8bd5\n    auto start = std::chrono::high_resolution_clock::now();\n    for (int t = 0; t < 100; t++) {\n        s1 = dot_scalar(a.data(), b.data(), N);\n    }\n    auto end = std::chrono::high_resolution_clock::now();\n    double scalar_ms = std::chrono::duration<double, std::milli>(end - start).count() / 100;\n\n    // NEON\u57fa\u51c6\u6d4b\u8bd5\n    start = std::chrono::high_resolution_clock::now();\n    for (int t = 0; t < 100; t++) {\n        s2 = dot_neon(a.data(), b.data(), N);\n    }\n    end = std::chrono::high_resolution_clock::now();\n    double neon_ms = std::chrono::duration<double, std::milli>(end - start).count() / 100;\n\n    std::cout << \"\u6807\u91cf: \" << scalar_ms << \" ms\uff08\u7ed3\u679c: \" << s1 << \"\uff09\\n\";\n    std::cout << \"NEON: \" << neon_ms << \" ms\uff08\u7ed3\u679c: \" << s2 << 
\"\uff09\\n\";\n    std::cout << \"\u52a0\u901f\u6bd4: \" << scalar_ms / neon_ms << \"x\\n\";\n    return 0;\n}\n

  2. \u5b9e\u73b0NEON ReLU\u548csoftmax\u6700\u5927\u503c\u67e5\u627e\u3002\u7ec3\u4e60\u4f7f\u7528\u4e0d\u540c\u64cd\u4f5c\u7684\u52a0\u8f7d\u2192\u8ba1\u7b97\u2192\u5b58\u50a8\u6a21\u5f0f\u3002

    // task2_neon_ops.cpp\n// \u7f16\u8bd1\uff1aclang++ -O3 -o task2 task2_neon_ops.cpp\n\n#include <iostream>\n#include <vector>\n#include <cmath>\n#include <arm_neon.h>\n\nvoid relu_neon(const float* in, float* out, int n) {\n    float32x4_t zero = vdupq_n_f32(0.0f);\n    int i = 0;\n    for (; i + 4 <= n; i += 4) {\n        float32x4_t x = vld1q_f32(in + i);\n        vst1q_f32(out + i, vmaxq_f32(x, zero));\n    }\n    for (; i < n; i++) out[i] = in[i] > 0 ? in[i] : 0;\n}\n\nfloat max_neon(const float* data, int n) {\n    float32x4_t max_vec = vdupq_n_f32(-INFINITY);\n    int i = 0;\n    for (; i + 4 <= n; i += 4) {\n        max_vec = vmaxq_f32(max_vec, vld1q_f32(data + i));\n    }\n    float result = vmaxvq_f32(max_vec);\n    for (; i < n; i++) result = result > data[i] ? result : data[i];\n    return result;\n}\n\nint main() {\n    std::vector<float> data = {-3, 1, -1, 4, 2, -5, 0, 7, -2, 3};\n    std::vector<float> out(data.size());\n\n    relu_neon(data.data(), out.data(), data.size());\n    std::cout << \"ReLU: \";\n    for (float x : out) std::cout << x << \" \";\n    std::cout << \"\\n\";\n\n    float mx = max_neon(data.data(), data.size());\n    std::cout << \"\u6700\u5927\u503c: \" << mx << \"\uff08\u671f\u671b\u503c: 7\uff09\\n\";\n    return 0;\n}\n

  3. \u6bd4\u8f83\u81ea\u52a8\u5411\u91cf\u5316\u4ee3\u7801\u4e0e\u624b\u5199NEON\u5185\u8054\u51fd\u6570\u3002\u7528 -fopt-info-vec\uff08GCC\uff09\u6216 -Rpass=loop-vectorize\uff08Clang\uff09\u7f16\u8bd1\u4ee5\u67e5\u770b\u7f16\u8bd1\u5668\u7684\u64cd\u4f5c\u3002

    // task3_auto_vs_manual.cpp\n// \u7f16\u8bd1\uff1aclang++ -O3 -Rpass=loop-vectorize -o task3 task3_auto_vs_manual.cpp\n//    \uff08\u6216\uff09\uff1ag++ -O3 -fopt-info-vec-optimized -o task3 task3_auto_vs_manual.cpp\n\n#include <iostream>\n#include <chrono>\n#include <vector>\n#include <arm_neon.h>\n\n// \u8ba9\u7f16\u8bd1\u5668\u81ea\u52a8\u5411\u91cf\u5316\nvoid add_auto(const float* __restrict__ a, const float* __restrict__ b,\n              float* __restrict__ c, int n) {\n    for (int i = 0; i < n; i++) {\n        c[i] = a[i] + b[i];\n    }\n}\n\n// \u624b\u5199NEON\nvoid add_neon(const float* a, const float* b, float* c, int n) {\n    int i = 0;\n    for (; i + 4 <= n; i += 4) {\n        vst1q_f32(c + i, vaddq_f32(vld1q_f32(a + i), vld1q_f32(b + i)));\n    }\n    for (; i < n; i++) c[i] = a[i] + b[i];\n}\n\nint main() {\n    const int N = 10'000'000;\n    std::vector<float> a(N, 1.0f), b(N, 2.0f), c(N);\n\n    auto bench = [&](auto fn, const char* name) {\n        fn(a.data(), b.data(), c.data(), N);  // \u9884\u70ed\n        auto start = std::chrono::high_resolution_clock::now();\n        for (int t = 0; t < 100; t++) fn(a.data(), b.data(), c.data(), N);\n        auto end = std::chrono::high_resolution_clock::now();\n        double ms = std::chrono::duration<double, std::milli>(end - start).count() / 100;\n        std::cout << name << \": \" << ms << \" ms\\n\";\n    };\n\n    bench(add_auto, \"\u81ea\u52a8\u5411\u91cf\u5316\");\n    bench(add_neon, \"\u624b\u5199NEON\");\n    // \u5b83\u4eec\u5e94\u8be5\u975e\u5e38\u63a5\u8fd1\u2014\u2014\u7f16\u8bd1\u5668\u80fd\u5f88\u597d\u5730\u81ea\u52a8\u5411\u91cf\u5316\u8fd9\u4e2a\u7b80\u5355\u5faa\u73af\n    return 0;\n}\n

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/","title":"x86\u4e0eAVX","text":"

x86\u5904\u7406\u5668\u6765\u81eaIntel\u548cAMD\uff0c\u4e3b\u5bfc\u7740\u5927\u591a\u6570ML\u8bad\u7ec3\u6240\u5728\u7684\u6570\u636e\u4e2d\u5fc3\u670d\u52a1\u5668\u3002\u672c\u6587\u6db5\u76d6x86 SIMD\u7684\u6f14\u8fdb\u3001AVX/AVX2\u5185\u8054\u51fd\u6570\u7f16\u7a0b\u3001AVX-512\u3001\u7528\u4e8e\u77e9\u9635\u64cd\u4f5c\u7684Intel AMX\u3001\u5185\u5b58\u5bf9\u9f50\u3001\u6027\u80fd\u9677\u9631\u4ee5\u53ca\u6027\u80fd\u5206\u6790\u2014\u2014\u5728\u5168\u7403\u6700\u5e38\u89c1\u7684\u670d\u52a1\u5668CPU\u4e0a\u69a8\u53d6\u6700\u5927\u6027\u80fd\u7684\u5de5\u5177\u3002

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/#x86-simd","title":"x86 SIMD\u6f14\u8fdb","text":" \u4ee3\u6b21 \u5e74\u4efd \u5bc4\u5b58\u5668\u5bbd\u5ea6 \u5bc4\u5b58\u5668\u6570\u91cf \u5173\u952e\u7279\u6027 MMX 1997 64\u4f4d 8\uff08mm0-7\uff09 \u4ec5\u6574\u6570\uff0c\u4e0eFPU\u5171\u4eab SSE 1999 128\u4f4d 8\uff08xmm0-7\uff09 4\u4e2a\u6d6e\u70b9\u6570\uff0c\u4e13\u7528\u5bc4\u5b58\u5668 SSE2 2001 128\u4f4d 8/16 2\u4e2a\u53cc\u7cbe\u5ea6\u6d6e\u70b9\u6570\uff0c\u6574\u6570\u64cd\u4f5c AVX 2011 256\u4f4d 16\uff08ymm0-15\uff09 8\u4e2a\u6d6e\u70b9\u6570\uff0c\u4e09\u64cd\u4f5c\u6570\u6307\u4ee4 AVX2 2013 256\u4f4d 16 \u6574\u6570256\u4f4d\uff0cFMA\uff0c\u6536\u96c6 AVX-512 2017 512\u4f4d 32\uff08zmm0-31\uff09 16\u4e2a\u6d6e\u70b9\u6570\uff0c\u63a9\u7801\u5bc4\u5b58\u5668\uff0c\u5206\u6563 AMX 2023 \u74e6\u7247\u5bc4\u5b58\u5668 8\u4e2a\u74e6\u7247 \u77e9\u9635\u4e58\u6cd5\uff08BF16\uff0cINT8\uff09 "},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/#avx2","title":"AVX2\u7f16\u7a0b","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/#_1","title":"\u5185\u8054\u51fd\u6570\u547d\u540d\u7ea6\u5b9a","text":"
#include <immintrin.h>  // \u6240\u6709x86 SIMD\u5185\u8054\u51fd\u6570\n\n// \u6570\u636e\u7c7b\u578b\n__m256  a;   // 256\u4f4d\u5bc4\u5b58\u5668\uff0c\u4fdd\u5b588\u4e2afloat32\n__m256d b;   // 256\u4f4d\u5bc4\u5b58\u5668\uff0c\u4fdd\u5b584\u4e2afloat64\n__m256i c;   // 256\u4f4d\u5bc4\u5b58\u5668\uff0c\u4fdd\u5b58\u6574\u6570\uff088x32\u300116x16\u621632x8\uff09\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/#_2","title":"\u52a0\u8f7d\u548c\u5b58\u50a8\u6570\u636e","text":"
// \u4ece\u5185\u5b58\u52a0\u8f7d8\u4e2a\u6d6e\u70b9\u6570\n__m256 v = _mm256_loadu_ps(ptr);      // \u975e\u5bf9\u9f50\u52a0\u8f7d\uff08\u9002\u7528\u4e8e\u4efb\u4f55\u5730\u5740\uff09\n__m256 v = _mm256_load_ps(ptr);       // \u5bf9\u9f50\u52a0\u8f7d\uff08ptr\u5fc5\u987b32\u5b57\u8282\u5bf9\u9f50\uff0c\u66f4\u5feb\uff09\n\n// \u5b58\u50a88\u4e2a\u6d6e\u70b9\u6570\u5230\u5185\u5b58\n_mm256_storeu_ps(out_ptr, v);          // \u975e\u5bf9\u9f50\u5b58\u50a8\n_mm256_store_ps(out_ptr, v);           // \u5bf9\u9f50\u5b58\u50a8\n\n// \u5c06\u5355\u4e2a\u503c\u5e7f\u64ad\u5230\u6240\u67098\u4e2a\u901a\u9053\n__m256 ones = _mm256_set1_ps(1.0f);    // [1, 1, 1, 1, 1, 1, 1, 1]\n\n// \u8bbe\u7f6e\u5404\u4e2a\u503c\uff08\u5f88\u5c11\u9700\u8981\uff09\n__m256 v = _mm256_set_ps(7,6,5,4,3,2,1,0);  // \u6ce8\u610f\uff1a\u9006\u5e8f\uff01\n\n// \u96f6\u5bc4\u5b58\u5668\n__m256 z = _mm256_setzero_ps();\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/#_3","title":"\u7b97\u672f\u8fd0\u7b97","text":"
__m256 c = _mm256_add_ps(a, b);        // c[i] = a[i] + b[i]\n__m256 d = _mm256_mul_ps(a, b);        // d[i] = a[i] * b[i]\n__m256 e = _mm256_sub_ps(a, b);        // e[i] = a[i] - b[i]\n__m256 f = _mm256_div_ps(a, b);        // f[i] = a[i] / b[i]\uff08\u6bd4\u4e58\u6cd5\u6162\uff09\n\n// \u878d\u5408\u4e58\u52a0\uff1ar = a * b + c\uff08\u4e00\u6761\u6307\u4ee4\uff0c\u4e00\u6b21\u820d\u5165\uff09\n__m256 r = _mm256_fmadd_ps(a, b, c);   // ML\u6700\u91cd\u8981\u7684\u6307\u4ee4\n\n// \u6700\u5c0f\u503c\u548c\u6700\u5927\u503c\n__m256 mn = _mm256_min_ps(a, b);       // min(a[i], b[i]) \u2014 \u7528\u4e8e\u88c1\u526a\n__m256 mx = _mm256_max_ps(a, b);       // max(a[i], b[i]) \u2014 \u7528\u4e8eReLU\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/#avx2_1","title":"\u5b9e\u8df5\u793a\u4f8b\uff1aAVX2\u70b9\u79ef","text":"
#include <immintrin.h>\n\nfloat dot_avx2(const float* a, const float* b, int n) {\n    __m256 sum = _mm256_setzero_ps();  // 8\u4e2a\u7d2f\u52a0\u5668\u521d\u59cb\u5316\u4e3a0\n\n    int i = 0;\n    for (; i + 8 <= n; i += 8) {\n        __m256 va = _mm256_loadu_ps(a + i);\n        __m256 vb = _mm256_loadu_ps(b + i);\n        sum = _mm256_fmadd_ps(va, vb, sum);  // sum += va * vb\n    }\n\n    // \u6c34\u5e73\u5f52\u7ea6\uff1a\u5c06sum\u76848\u4e2a\u5143\u7d20\u76f8\u52a0\n    // \u6b65\u9aa41\uff1a\u5c06\u4e0a128\u4f4d\u52a0\u5230\u4e0b128\u4f4d\n    __m128 hi = _mm256_extractf128_ps(sum, 1);\n    __m128 lo = _mm256_castps256_ps128(sum);\n    __m128 sum128 = _mm_add_ps(hi, lo);        // 4\u4e2a\u90e8\u5206\u548c\n\n    // \u6b65\u9aa42\uff1a\u5728128\u4f4d\u5bc4\u5b58\u5668\u5185\u6c34\u5e73\u76f8\u52a0\n    sum128 = _mm_hadd_ps(sum128, sum128);       // [a+b, c+d, a+b, c+d]\n    sum128 = _mm_hadd_ps(sum128, sum128);       // [a+b+c+d, ...]\n\n    float result = _mm_cvtss_f32(sum128);       // \u63d0\u53d6\u6807\u91cf\n\n    // \u6807\u91cf\u6e05\u7406\n    for (; i < n; i++) {\n        result += a[i] * b[i];\n    }\n\n    return result;\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/#avx2-softmax","title":"\u5b9e\u8df5\u793a\u4f8b\uff1aAVX2 Softmax\uff08\u7b80\u5316\u7248\uff09","text":"
float vector_max_avx2(const float* data, int n) {\n    __m256 max_vec = _mm256_set1_ps(-INFINITY);\n\n    int i = 0;\n    for (; i + 8 <= n; i += 8) {\n        __m256 v = _mm256_loadu_ps(data + i);\n        max_vec = _mm256_max_ps(max_vec, v);\n    }\n\n    // \u5c068\u4e2a\u6700\u5927\u503c\u5f52\u7ea6\u4e3a1\u4e2a\n    __m128 hi = _mm256_extractf128_ps(max_vec, 1);\n    __m128 lo = _mm256_castps256_ps128(max_vec);\n    __m128 max128 = _mm_max_ps(hi, lo);\n\n    // \u901a\u8fc7\u6df7\u6d17\u548c\u53d6\u6700\u5927\u503c\u627e\u5230\u5355\u4e00\u6700\u5927\u503c\n    max128 = _mm_max_ps(max128, _mm_shuffle_ps(max128, max128, 0b01001110));\n    max128 = _mm_max_ps(max128, _mm_shuffle_ps(max128, max128, 0b10110001));\n\n    float result = _mm_cvtss_f32(max128);\n\n    for (; i < n; i++) {\n        result = result > data[i] ? result : data[i];\n    }\n\n    return result;\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/#avx-512","title":"AVX-512","text":"
__m512 a = _mm512_loadu_ps(ptr);                // \u52a0\u8f7d16\u4e2a\u6d6e\u70b9\u6570\n__m512 c = _mm512_fmadd_ps(a, b, c);            // 16\u4e2aFMA\u540c\u65f6\u8fdb\u884c\nfloat sum = _mm512_reduce_add_ps(a);             // \u5185\u7f6e\u6c34\u5e73\u6c42\u548c\uff08\u65e0\u9700\u624b\u52a8\u5f52\u7ea6\uff01\uff09\n\n// \u63a9\u7801\u64cd\u4f5c\uff1a\u64cd\u4f5c\u901a\u9053\u5b50\u96c6\n__mmask16 mask = _mm512_cmpgt_ps_mask(a, zero);  // \u54ea\u4e9b\u901a\u9053 > 0\uff1f\n__m512 relu = _mm512_maskz_mov_ps(mask, a);       // \u8d1f\u901a\u9053\u7f6e\u96f6 = ReLU\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/#intel-amx","title":"Intel AMX\uff1a\u77e9\u9635\u4e58\u6cd5\u786c\u4ef6","text":"
#include <immintrin.h>\n\n// AMX\u74e6\u7247\u4e58\u6cd5\uff1aC += A * B\uff08BF16\u683c\u5f0f\uff09\n// A\u4e3a16x32 BF16\uff0cB\u4e3a32x16 BF16\uff0cC\u4e3a16x16 FP32\n_tile_loadd(0, a_ptr, stride_a);   // \u4eceA\u52a0\u8f7d\u74e6\u72470\n_tile_loadd(1, b_ptr, stride_b);   // \u4eceB\u52a0\u8f7d\u74e6\u72471\n_tile_dpbf16ps(2, 0, 1);           // \u74e6\u72472 += \u74e6\u72470 * \u74e6\u72471\uff08BF16\u77e9\u9635\u4e58\u6cd5\uff0cFP32\u7d2f\u52a0\uff09\n_tile_stored(2, c_ptr, stride_c);  // \u5b58\u50a8\u74e6\u72472\u5230C\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/#_4","title":"\u5185\u5b58\u5bf9\u9f50","text":"
// \u5206\u914d\u5bf9\u9f50\u5185\u5b58\nfloat* data = (float*)aligned_alloc(32, n * sizeof(float));  // AVX\u768432\u5b57\u8282\u5bf9\u9f50\n\n// C++\u5bf9\u9f50\u5206\u914d\n#include <new>\nfloat* data = new (std::align_val_t(32)) float[n];\n\n// \u6216\u8005\u4f7f\u7528\u7f16\u8bd1\u5668\u5c5e\u6027\nalignas(32) float data[1024];\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/#_5","title":"\u6027\u80fd\u9677\u9631","text":"
// \u5355\u7d2f\u52a0\u5668\uff1a\u53d7FMA\u5ef6\u8fdf\u9650\u5236\uff084-5\u4e2a\u5468\u671f\uff09\n__m256 sum = _mm256_setzero_ps();\nfor (...) {\n    sum = _mm256_fmadd_ps(a, b, sum);  // \u6bcf\u4e2a\u4f9d\u8d56\u524d\u4e00\u4e2a\n}\n\n// \u56db\u4e2a\u7d2f\u52a0\u5668\uff1a4\u500d\u541e\u5410\u91cf\uff08\u9690\u85cf\u5ef6\u8fdf\uff09\n__m256 sum0 = _mm256_setzero_ps();\n__m256 sum1 = _mm256_setzero_ps();\n__m256 sum2 = _mm256_setzero_ps();\n__m256 sum3 = _mm256_setzero_ps();\nfor (...) {\n    sum0 = _mm256_fmadd_ps(a0, b0, sum0);  // \u72ec\u7acb\n    sum1 = _mm256_fmadd_ps(a1, b1, sum1);  // \u72ec\u7acb\n    sum2 = _mm256_fmadd_ps(a2, b2, sum2);  // \u72ec\u7acb\n    sum3 = _mm256_fmadd_ps(a3, b3, sum3);  // \u72ec\u7acb\n}\nsum0 = _mm256_add_ps(sum0, sum1);\nsum2 = _mm256_add_ps(sum2, sum3);\nsum0 = _mm256_add_ps(sum0, sum2);\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/#_6","title":"\u6027\u80fd\u5206\u6790","text":"
# Linux perf\uff08\u9700\u8981\u5185\u6838\u652f\u6301\uff09\nperf stat ./my_program                    # \u57fa\u672c\u8ba1\u6570\u5668\uff1a\u5468\u671f\u3001\u6307\u4ee4\u3001IPC\nperf stat -e cache-misses,cache-references ./my_program  # \u7f13\u5b58\u884c\u4e3a\nperf record -g ./my_program && perf report              # \u8c03\u7528\u56fe\u5206\u6790\n\n# Intel VTune\uff08\u8be6\u7ec6\u7684x86\u6027\u80fd\u5206\u6790\uff09\nvtune -collect hotspots -- ./my_program\nvtune -collect memory-access -- ./my_program   # \u5185\u5b58\u5e26\u5bbd\u5206\u6790\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/03.%20x86%20and%20AVX/#x86intelamdgclang","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u5728x86\u2014\u2014Intel/AMD\u4e0a\u7528g++\u6216clang++\u7f16\u8bd1\uff09","text":"
  1. \u7f16\u5199\u6807\u91cf\u70b9\u79ef\u548cAVX2\u70b9\u79ef\u3002\u5bf9\u4e24\u8005\u8fdb\u884c\u57fa\u51c6\u6d4b\u8bd5\u5e76\u6d4b\u91cf8\u8defSIMD\u5e26\u6765\u7684\u52a0\u901f\u6bd4\u3002

    // task1_avx_dot.cpp\n// \u7f16\u8bd1\uff1ag++ -O3 -mavx2 -mfma -o task1 task1_avx_dot.cpp\n\n#include <iostream>\n#include <chrono>\n#include <vector>\n#include <immintrin.h>\n\nfloat dot_scalar(const float* a, const float* b, int n) {\n    float sum = 0.0f;\n    for (int i = 0; i < n; i++) sum += a[i] * b[i];\n    return sum;\n}\n\nfloat dot_avx2(const float* a, const float* b, int n) {\n    __m256 sum = _mm256_setzero_ps();\n    int i = 0;\n    for (; i + 8 <= n; i += 8) {\n        __m256 va = _mm256_loadu_ps(a + i);\n        __m256 vb = _mm256_loadu_ps(b + i);\n        sum = _mm256_fmadd_ps(va, vb, sum);\n    }\n    // \u5f52\u7ea6\uff1a\u4e0a128\u52a0\u5230\u4e0b128\uff0c\u7136\u540e\u6c34\u5e73\u76f8\u52a0\n    __m128 hi = _mm256_extractf128_ps(sum, 1);\n    __m128 lo = _mm256_castps256_ps128(sum);\n    __m128 r = _mm_add_ps(hi, lo);\n    r = _mm_hadd_ps(r, r);\n    r = _mm_hadd_ps(r, r);\n    float result = _mm_cvtss_f32(r);\n    for (; i < n; i++) result += a[i] * b[i];\n    return result;\n}\n\nint main() {\n    const int N = 10'000'000;\n    std::vector<float> a(N, 1.0f), b(N, 2.0f);\n\n    volatile float s1 = dot_scalar(a.data(), b.data(), N);\n    volatile float s2 = dot_avx2(a.data(), b.data(), N);\n\n    auto bench = [&](auto fn, const char* name) {\n        auto start = std::chrono::high_resolution_clock::now();\n        volatile float s;\n        for (int t = 0; t < 100; t++) s = fn(a.data(), b.data(), N);\n        auto end = std::chrono::high_resolution_clock::now();\n        double ms = std::chrono::duration<double, std::milli>(end - start).count() / 100;\n        std::cout << name << \": \" << ms << \" ms\uff08\u7ed3\u679c: \" << s << \"\uff09\\n\";\n        return ms;\n    };\n\n    double t1 = bench(dot_scalar, \"\u6807\u91cf\");\n    double t2 = bench(dot_avx2,   \"AVX2  \");\n    std::cout << \"\u52a0\u901f\u6bd4: \" << t1 / t2 << \"x\\n\";\n    return 0;\n}\n

  2. \u4f7f\u7528 _mm256_max_ps \u5b9e\u73b0AVX2 ReLU\u5e76\u4e0e\u6807\u91cf\u5faa\u73af\u6bd4\u8f83\u3002\u7136\u540e\u5c1d\u8bd5\u4f7f\u7528\u591a\u7d2f\u52a0\u5668\uff08\u5faa\u73af\u5c55\u5f00\uff09\u4ee5\u9690\u85cfFMA\u5ef6\u8fdf\u3002

    // task2_avx_relu.cpp\n// \u7f16\u8bd1\uff1ag++ -O3 -mavx2 -o task2 task2_avx_relu.cpp\n\n#include <iostream>\n#include <chrono>\n#include <vector>\n#include <immintrin.h>\n\nvoid relu_scalar(const float* in, float* out, int n) {\n    for (int i = 0; i < n; i++) {\n        out[i] = in[i] > 0.0f ? in[i] : 0.0f;\n    }\n}\n\nvoid relu_avx2(const float* in, float* out, int n) {\n    __m256 zero = _mm256_setzero_ps();\n    int i = 0;\n    for (; i + 8 <= n; i += 8) {\n        __m256 x = _mm256_loadu_ps(in + i);\n        _mm256_storeu_ps(out + i, _mm256_max_ps(x, zero));\n    }\n    for (; i < n; i++) out[i] = in[i] > 0.0f ? in[i] : 0.0f;\n}\n\n// \u5c55\u5f00\uff1a\u6bcf\u6b21\u8fed\u4ee3\u5904\u740632\u4e2a\u6d6e\u70b9\u6570\uff084 x 8\uff09\nvoid relu_avx2_unrolled(const float* in, float* out, int n) {\n    __m256 zero = _mm256_setzero_ps();\n    int i = 0;\n    for (; i + 32 <= n; i += 32) {\n        __m256 x0 = _mm256_loadu_ps(in + i);\n        __m256 x1 = _mm256_loadu_ps(in + i + 8);\n        __m256 x2 = _mm256_loadu_ps(in + i + 16);\n        __m256 x3 = _mm256_loadu_ps(in + i + 24);\n        _mm256_storeu_ps(out + i,      _mm256_max_ps(x0, zero));\n        _mm256_storeu_ps(out + i + 8,  _mm256_max_ps(x1, zero));\n        _mm256_storeu_ps(out + i + 16, _mm256_max_ps(x2, zero));\n        _mm256_storeu_ps(out + i + 24, _mm256_max_ps(x3, zero));\n    }\n    for (; i + 8 <= n; i += 8) {\n        _mm256_storeu_ps(out + i, _mm256_max_ps(_mm256_loadu_ps(in + i), zero));\n    }\n    for (; i < n; i++) out[i] = in[i] > 0.0f ? 
in[i] : 0.0f;\n}\n\nint main() {\n    const int N = 16'000'000;\n    std::vector<float> in(N), out(N);\n    for (int i = 0; i < N; i++) in[i] = (float)(i % 200) - 100.0f;\n\n    auto bench = [&](auto fn, const char* name) {\n        fn(in.data(), out.data(), N);  // \u9884\u70ed\n        auto start = std::chrono::high_resolution_clock::now();\n        for (int t = 0; t < 100; t++) fn(in.data(), out.data(), N);\n        auto end = std::chrono::high_resolution_clock::now();\n        double ms = std::chrono::duration<double, std::milli>(end - start).count() / 100;\n        double bw = 2.0 * N * sizeof(float) / ms / 1e6;  // \u8bfb\u53d6+\u5199\u5165\n        std::cout << name << \": \" << ms << \" ms\uff08\" << bw << \" GB/s\uff09\\n\";\n    };\n\n    bench(relu_scalar,        \"\u6807\u91cf          \");\n    bench(relu_avx2,          \"AVX2          \");\n    bench(relu_avx2_unrolled, \"AVX2\u5c55\u5f00      \");\n    return 0;\n}\n

  3. \u6d4b\u91cf\u5185\u5b58\u5bf9\u9f50\u7684\u6548\u679c\u3002\u6bd4\u8f83\u5728\u5927\u6570\u7ec4\u4e0a\u7684\u5bf9\u9f50\u52a0\u8f7d\u4e0e\u975e\u5bf9\u9f50\u52a0\u8f7d\u3002

    // task3_alignment.cpp\n// \u7f16\u8bd1\uff1ag++ -O3 -mavx2 -o task3 task3_alignment.cpp\n\n#include <iostream>\n#include <chrono>\n#include <cstdlib>\n#include <immintrin.h>\n\nint main() {\n    const int N = 16'000'000;\n\n    // \u5bf9\u9f50\u5206\u914d\uff08AVX2\u4e3a32\u5b57\u8282\uff09\n    float* aligned = (float*)aligned_alloc(32, N * sizeof(float));\n\n    // \u975e\u5bf9\u9f50\uff1a\u4ece\u5bf9\u9f50\u8fb9\u754c\u504f\u79fb4\u5b57\u8282\uff081\u4e2a\u6d6e\u70b9\u6570\uff09\n    float* raw = (float*)malloc((N + 1) * sizeof(float));\n    float* unaligned = raw + 1;  // \u4fdd\u8bc1\u672a\u5bf9\u9f50\n\n    for (int i = 0; i < N; i++) {\n        aligned[i] = 1.0f;\n        unaligned[i] = 1.0f;\n    }\n\n    auto bench = [&](float* ptr, bool use_aligned, const char* name) {\n        __m256 sum = _mm256_setzero_ps();\n        // \u9884\u70ed\n        for (int i = 0; i + 8 <= N; i += 8) {\n            __m256 v = use_aligned ? _mm256_load_ps(ptr + i) : _mm256_loadu_ps(ptr + i);\n            sum = _mm256_add_ps(sum, v);\n        }\n\n        auto start = std::chrono::high_resolution_clock::now();\n        for (int t = 0; t < 100; t++) {\n            sum = _mm256_setzero_ps();\n            for (int i = 0; i + 8 <= N; i += 8) {\n                __m256 v = use_aligned ? _mm256_load_ps(ptr + i) : _mm256_loadu_ps(ptr + i);\n                sum = _mm256_add_ps(sum, v);\n            }\n        }\n        auto end = std::chrono::high_resolution_clock::now();\n        double ms = std::chrono::duration<double, std::milli>(end - start).count() / 100;\n        double bw = (double)N * sizeof(float) / ms / 1e6;\n        std::cout << name << \": \" << ms << \" ms\uff08\" << bw << \" GB/s\uff09\\n\";\n    };\n\n    bench(aligned,   true,  \"\u5bf9\u9f50\u52a0\u8f7d  \");\n    bench(unaligned, false, \"\u975e\u5bf9\u9f50\u52a0\u8f7d\");\n\n    free(aligned);\n    free(raw);\n    return 0;\n}\n

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/","title":"GPU\u67b6\u6784\u4e0eCUDA","text":"

GPU\u901a\u8fc7\u63d0\u4f9b\u6570\u5343\u4e2a\u6838\u5fc3\u7528\u4e8e\u5927\u89c4\u6a21\u5e76\u884c\u8ba1\u7b97\uff0c\u6539\u53d8\u4e86AI\u3002\u672c\u6587\u6db5\u76d6GPU\u4e0eCPU\u7684\u8bbe\u8ba1\u54f2\u5b66\u5bf9\u6bd4\u3001GPU\u5b58\u50a8\u5c42\u6b21\u3001C++\u4e2d\u7684CUDA\u7f16\u7a0b\u3001SIMT\u6267\u884c\u6a21\u578b\u3001\u5185\u5b58\u8bbf\u95ee\u6a21\u5f0f\u3001\u540c\u6b65\u3001\u6d41\u3001\u6027\u80fd\u5206\u6790\u4ee5\u53caNVIDIA GPU\u4ee3\u6b21\u2014\u2014\u7f16\u5199\u548c\u7406\u89e3GPU\u6838\u51fd\u6570\u6240\u9700\u7684\u77e5\u8bc6\u3002

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#gpu-vs-cpu","title":"GPU vs CPU\uff1a\u6839\u672c\u4e0d\u540c\u7684\u8bbe\u8ba1","text":" CPU GPU \u6838\u5fc3 4-128\uff08\u590d\u6742\u3001\u5feb\u901f\uff09 1,000-20,000\uff08\u7b80\u5355\u3001\u6162\u901f\uff09 \u65f6\u949f\u9891\u7387 3-5 GHz 1-2.5 GHz \u7f13\u5b58 \u5927\uff0832 MB+ L3\uff09 \u5c0f\uff08\u6bcfSM\u5171\u4eab\u5185\u5b58\uff09 \u5206\u652f\u9884\u6d4b \u7cbe\u5bc6 \u65e0\uff08\u6240\u6709\u7ebf\u7a0b\u9075\u5faa\u76f8\u540c\u8def\u5f84\uff09 \u6700\u9002\u5408 \u4f4e\u5ef6\u8fdf\u3001\u590d\u6742\u63a7\u5236\u6d41 \u9ad8\u541e\u5410\u91cf\u3001\u6570\u636e\u5e76\u884c\u5de5\u4f5c \u5178\u578bFLOPS\uff08FP32\uff09 1-5 TFLOPS 30-80 TFLOPS \u5185\u5b58\u5e26\u5bbd 50-100 GB/s 1-3 TB/s "},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#gpu","title":"GPU\u5b58\u50a8\u5c42\u6b21","text":" \u5185\u5b58 \u5927\u5c0f \u5ef6\u8fdf \u5e26\u5bbd \u4f5c\u7528\u57df \u5bc4\u5b58\u5668 \u6bcfSM\u7ea6256 KB 0\u5468\u671f \u6700\u9ad8 \u6bcf\u7ebf\u7a0b \u5171\u4eab\u5185\u5b58 \u6bcfSM 48-228 KB \u7ea65\u5468\u671f \u7ea620 TB/s \u6bcf\u7ebf\u7a0b\u5757 L1\u7f13\u5b58 \u6bcfSM 128-256 KB \u7ea630\u5468\u671f \u6bcfSM L2\u7f13\u5b58 4-96 MB \u7ea6200\u5468\u671f \u7ea66 TB/s \u5168\u5c40 \u5168\u5c40\u5185\u5b58\uff08HBM\uff09 24-192 GB \u7ea6400\u5468\u671f 1-3.3 TB/s \u5168\u5c40 "},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#cuda","title":"CUDA\u7f16\u7a0b\u6a21\u578b","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#_1","title":"\u5c42\u6b21\u7ed3\u6784\uff1a\u7f51\u683c\u3001\u5757\u3001\u7ebf\u7a0b","text":"
\u7f51\u683c\uff08\u6574\u4e2a\u542f\u52a8\uff09\n\u251c\u2500\u2500 \u5757 (0,0)\n\u2502   \u251c\u2500\u2500 \u7ebf\u7a0b (0,0)\n\u2502   \u251c\u2500\u2500 \u7ebf\u7a0b (1,0)\n\u2502   \u251c\u2500\u2500 \u7ebf\u7a0b (2,0)\n\u2502   \u2514\u2500\u2500 ... \uff08\u6bcf\u5757\u6700\u591a1024\u7ebf\u7a0b\uff09\n\u251c\u2500\u2500 \u5757 (1,0)\n\u2502   \u251c\u2500\u2500 \u7ebf\u7a0b (0,0)\n\u2502   \u2514\u2500\u2500 ...\n\u2514\u2500\u2500 ... \uff08\u53ef\u80fd\u6709\u6570\u767e\u4e07\u4e2a\u5757\uff09\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#cuda_1","title":"\u4f60\u7684\u7b2c\u4e00\u4e2aCUDA\u6838\u51fd\u6570","text":"
// vector_add.cu \u2014 CUDA\u6e90\u6587\u4ef6\uff08.cu\u6269\u5c55\u540d\uff09\n\n#include <stdio.h>\n\n// __global__ \u6807\u8bb0\u6b64\u4e3aGPU\u6838\u51fd\u6570\uff08\u4eceCPU\u8c03\u7528\uff0c\u5728GPU\u4e0a\u8fd0\u884c\uff09\n__global__ void vector_add(const float* a, const float* b, float* c, int n) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < n) {           // \u8fb9\u754c\u68c0\u67e5\uff08\u7f51\u683c\u53ef\u80fd\u5927\u4e8e\u6570\u636e\uff09\n        c[idx] = a[idx] + b[idx];\n    }\n}\n\nint main() {\n    int n = 1 << 20;  // \u7ea6100\u4e07\u4e2a\u5143\u7d20\n    size_t bytes = n * sizeof(float);\n\n    // \u5206\u914d\u4e3b\u673a\uff08CPU\uff09\u5185\u5b58\n    float *h_a = new float[n];\n    float *h_b = new float[n];\n    float *h_c = new float[n];\n\n    // \u521d\u59cb\u5316\n    for (int i = 0; i < n; i++) {\n        h_a[i] = 1.0f;\n        h_b[i] = 2.0f;\n    }\n\n    // \u5206\u914d\u8bbe\u5907\uff08GPU\uff09\u5185\u5b58\n    float *d_a, *d_b, *d_c;\n    cudaMalloc(&d_a, bytes);\n    cudaMalloc(&d_b, bytes);\n    cudaMalloc(&d_c, bytes);\n\n    // \u5c06\u6570\u636e\u4eceCPU\u62f7\u8d1d\u5230GPU\n    cudaMemcpy(d_a, h_a, bytes, cudaMemcpyHostToDevice);\n    cudaMemcpy(d_b, h_b, bytes, cudaMemcpyHostToDevice);\n\n    // \u542f\u52a8\u6838\u51fd\u6570\uff1a\u6bcf\u5757256\u7ebf\u7a0b\uff0c\u8db3\u591f\u7684\u5757\u8986\u76d6n\u4e2a\u5143\u7d20\n    int block_size = 256;\n    int grid_size = (n + block_size - 1) / block_size;  // \u4e0a\u53d6\u6574\u9664\u6cd5\n    vector_add<<<grid_size, block_size>>>(d_a, d_b, d_c, n);\n\n    // \u5c06\u7ed3\u679c\u4eceGPU\u62f7\u8d1d\u5230CPU\n    cudaMemcpy(h_c, d_c, bytes, cudaMemcpyDeviceToHost);\n\n    // \u9a8c\u8bc1\n    printf(\"c[0] = %f\uff08\u671f\u671b\u503c 3.0\uff09\\n\", h_c[0]);\n\n    // \u91ca\u653e\u5185\u5b58\n    cudaFree(d_a); cudaFree(d_b); cudaFree(d_c);\n    delete[] h_a; delete[] h_b; delete[] h_c;\n\n    return 0;\n}\n
# \u7528NVIDIA\u7f16\u8bd1\u5668\u7f16\u8bd1\nnvcc -O3 -o vector_add vector_add.cu\n./vector_add\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#simt","title":"\u7ebf\u7a0b\u675f\u4e0eSIMT","text":"
// \u7cdf\u7cd5\uff1a\u7ebf\u7a0b\u675f\u5206\u6b67\uff08\u540c\u4e00\u7ebf\u7a0b\u675f\u4e2d\u7684\u7ebf\u7a0b\u8d70\u4e0d\u540c\u8def\u5f84\uff09\nif (threadIdx.x % 2 == 0) {\n    c[idx] = a[idx] + b[idx];    // \u5076\u6570\u7ebf\u7a0b\u505a\u8fd9\u4e2a\n} else {\n    c[idx] = a[idx] - b[idx];    // \u5947\u6570\u7ebf\u7a0b\u505a\u8fd9\u4e2a\uff08\u540c\u4e00\u7ebf\u7a0b\u675f\uff0c\u4e32\u884c\u5316\uff09\n}\n\n// \u66f4\u597d\uff1a\u65e0\u5206\u652f\uff08\u65e0\u5206\u6b67\uff09\nfloat sign = (threadIdx.x % 2 == 0) ? 1.0f : -1.0f;\nc[idx] = a[idx] + sign * b[idx];  // \u6240\u6709\u7ebf\u7a0b\u6267\u884c\u76f8\u540c\u6307\u4ee4\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#_2","title":"\u5185\u5b58\u5408\u5e76","text":"
// \u597d\uff1a\u5408\u5e76\u2014\u2014\u7ebf\u7a0b0\u8bfba[0]\uff0c\u7ebf\u7a0b1\u8bfba[1]\uff0c...\nc[idx] = a[idx] + b[idx];\n\n// \u574f\uff1a\u8de8\u6b65\u2014\u2014\u7ebf\u7a0b0\u8bfba[0]\uff0c\u7ebf\u7a0b1\u8bfba[\u6b65\u957f]\uff0c...\nc[idx] = a[idx * stride] + b[idx * stride];  // \u6b65\u957f > 1 \u6d6a\u8d39\u5e26\u5bbd\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#_3","title":"\u5171\u4eab\u5185\u5b58\u4e0e\u5206\u5757","text":"
// \u4f7f\u7528\u5171\u4eab\u5185\u5b58\u5206\u5757\u7684\u77e9\u9635\u4e58\u6cd5\uff08\u7b80\u5316\u7248\uff09\n__global__ void matmul_tiled(const float* A, const float* B, float* C,\n                              int M, int N, int K) {\n    // A\u7684\u4e00\u4e2a\u74e6\u7247\u548cB\u7684\u4e00\u4e2a\u74e6\u7247\u7684\u5171\u4eab\u5185\u5b58\n    __shared__ float tile_A[TILE_SIZE][TILE_SIZE];\n    __shared__ float tile_B[TILE_SIZE][TILE_SIZE];\n\n    int row = blockIdx.y * TILE_SIZE + threadIdx.y;\n    int col = blockIdx.x * TILE_SIZE + threadIdx.x;\n    float sum = 0.0f;\n\n    // \u904d\u5386\u74e6\u7247\n    for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {\n        // \u5c06A\u548cB\u7684\u4e00\u4e2a\u74e6\u7247\u52a0\u8f7d\u5230\u5171\u4eab\u5185\u5b58\n        if (row < M && t * TILE_SIZE + threadIdx.x < K)\n            tile_A[threadIdx.y][threadIdx.x] = A[row * K + t * TILE_SIZE + threadIdx.x];\n        else\n            tile_A[threadIdx.y][threadIdx.x] = 0.0f;\n\n        if (col < N && t * TILE_SIZE + threadIdx.y < K)\n            tile_B[threadIdx.y][threadIdx.x] = B[(t * TILE_SIZE + threadIdx.y) * N + col];\n        else\n            tile_B[threadIdx.y][threadIdx.x] = 0.0f;\n\n        __syncthreads();  // \u7b49\u5f85\u6240\u6709\u7ebf\u7a0b\u5b8c\u6210\u52a0\u8f7d\n\n        // \u8ba1\u7b97\u6b64\u74e6\u7247\u7684\u90e8\u5206\u70b9\u79ef\n        for (int k = 0; k < TILE_SIZE; k++) {\n            sum += tile_A[threadIdx.y][k] * tile_B[k][threadIdx.x];\n        }\n\n        __syncthreads();  // \u5728\u52a0\u8f7d\u4e0b\u4e00\u4e2a\u74e6\u7247\u524d\u7b49\u5f85\n    }\n\n    if (row < M && col < N)\n        C[row * N + col] = sum;\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#_4","title":"\u6d41\u4e0e\u5e76\u53d1","text":"
cudaStream_t stream1, stream2;\ncudaStreamCreate(&stream1);\ncudaStreamCreate(&stream2);\n\n// \u8fd9\u4e9b\u64cd\u4f5c\u53ef\u4ee5\u91cd\u53e0\uff1a\u4e0d\u540c\u6d41\u5e76\u53d1\u6267\u884c\ncudaMemcpyAsync(d_a, h_a, bytes, cudaMemcpyHostToDevice, stream1);\ncudaMemcpyAsync(d_b, h_b, bytes, cudaMemcpyHostToDevice, stream2);\n\nkernel1<<<grid, block, 0, stream1>>>(d_a, d_c);\nkernel2<<<grid, block, 0, stream2>>>(d_b, d_d);\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#cuda_2","title":"\u5206\u6790CUDA\u4ee3\u7801","text":"
# NVIDIA Nsight Compute\uff1a\u6838\u51fd\u6570\u7ea7\u5206\u6790\nncu --set full ./my_program\n\n# NVIDIA Nsight Systems\uff1a\u7cfb\u7edf\u7ea7\u65f6\u95f4\u7ebf\nnsys profile ./my_program\n\n# \u5feb\u901f\u6307\u6807\nncu --metrics sm__throughput,dram__throughput ./my_program\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#_5","title":"\u9ad8\u7ea7\u4f18\u5316\u6280\u672f","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#aos-vs-soa","title":"\u6570\u636e\u5e03\u5c40\uff1aAoS vs SoA","text":"
// AoS\uff1a\u5bf9\u4e8eSIMD/GPU\u4e0d\u597d\uff08\u8bbf\u95ee\u6240\u6709x\u503c\u89e6\u53ca\u975e\u8fde\u7eed\u5185\u5b58\uff09\nstruct Particle { float x, y, z, mass; };\nParticle particles[N];\n// particles[0].x, particles[1].x \u76f8\u969416\u5b57\u8282\n\n// SoA\uff1a\u5bf9\u4e8eSIMD/GPU\u597d\uff08\u6240\u6709x\u503c\u8fde\u7eed\uff09\nstruct Particles {\n    float x[N], y[N], z[N], mass[N];\n};\n// x[0], x[1] \u76f8\u96944\u5b57\u8282\u2014\u2014\u975e\u5e38\u9002\u5408\u5408\u5e76\u8bbf\u95ee\u548cSIMD\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#_6","title":"\u8f6f\u4ef6\u9884\u53d6","text":"
#include <xmmintrin.h>  // for _mm_prefetch\n\nfor (int i = 0; i < n; i += 4) {\n    _mm_prefetch((char*)(a + i + 64), _MM_HINT_T0);  // \u63d0\u524d64\u4e2a\u5143\u7d20\u9884\u53d6\n    // \u7528SIMD\u5904\u7406 a[i:i+4]\n    __m128 va = _mm_load_ps(a + i);\n    // ...\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#_7","title":"\u6838\u51fd\u6570\u878d\u5408","text":"
// \u672a\u878d\u5408\uff1a3\u6b21\u6838\u51fd\u6570\u542f\u52a8\uff0c3\u6b21\u5168\u5c40\u5185\u5b58\u5f80\u8fd4\ny = matmul(x, W)     // \u5199y\u5230\u5168\u5c40\u5185\u5b58\nz = y + bias          // \u8bfby\uff0c\u5199z\nout = relu(z)         // \u8bfbz\uff0c\u5199out\n\n// \u878d\u5408\uff1a1\u6b21\u6838\u51fd\u6570\u542f\u52a8\uff0c1\u6b21\u5168\u5c40\u5185\u5b58\u5199\u5165\nout = fused_matmul_bias_relu(x, W, bias)  // y\u548cz\u6c38\u4e0d\u79bb\u5f00SRAM\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#_8","title":"\u6df7\u5408\u7cbe\u5ea6\u6838\u51fd\u6570","text":"
// \u5f20\u91cf\u6838\u5fc3\uff1a\u4e58FP16\u77e9\u9635\uff0c\u5728FP32\u4e2d\u7d2f\u52a0\n// \u6bcf\u6761\u5f20\u91cf\u6838\u5fc3\u6307\u4ee4\uff1aD\uff08FP32\uff09= A\uff08FP16\uff09\u00d7 B\uff08FP16\uff09+ C\uff08FP32\uff09\nnvcuda::wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#_9","title":"\u5185\u5b58\u6c60\u5206\u914d\u5668","text":"
# PyTorch\u81ea\u52a8\u6267\u884c\u6b64\u64cd\u4f5c\u2014\u2014\u4f46\u7406\u89e3\u539f\u56e0\u5f88\u91cd\u8981\n# \u6bcf\u4e2a torch.empty() \u4ece\u6c60\u4e2d\u91cd\u7528\u5185\u5b58\uff0c\u65e0\u9700cudaMalloc\ntemp = torch.empty(1024, 1024, device='cuda')  # \u5fae\u79d2\uff0c\u800c\u975e\u6beb\u79d2\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#_10","title":"\u5206\u6790\u6307\u5bfc\u7684\u4f18\u5316","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#nvidia-gpu","title":"NVIDIA GPU\u4ee3\u6b21","text":"\u4ee3\u6b21 \u5e74\u4efd \u5173\u952e\u521b\u65b0 AI\u76f8\u5173\u6027 Pascal\uff08P100\uff09 2016 HBM2\u3001NVLink \u7b2c\u4e00\u4ee3\u4e25\u8083\u7684\u6df1\u5ea6\u5b66\u4e60GPU Volta\uff08V100\uff09 2017 \u5f20\u91cf\u6838\u5fc3\uff08\u6df7\u5408\u7cbe\u5ea6\u77e9\u9635\u4e58\u6cd5\uff09 \u5b9e\u73b0FP16\u8bad\u7ec3\uff0c125 TFLOPS TF32 Ampere\uff08A100\uff09 2020 TF32\u3001\u7a00\u758f\u6027\u3001\u7b2c\u4e09\u4ee3\u5f20\u91cf\u6838\u5fc3 312 TFLOPS TF32\uff0c\u7ed3\u6784\u6027\u7a00\u758f2:4 Hopper\uff08H100\uff09 2022 Transformer\u5f15\u64ce\uff08FP8\uff09\u3001HBM3 989 TFLOPS FP8\uff0c\u52a8\u6001\u7cbe\u5ea6\u5207\u6362 Blackwell\uff08B200\uff09 2024 \u7b2c\u4e8c\u4ee3Transformer\u5f15\u64ce\u3001NVLink 5 2.5 PFLOPS FP4\uff0c\u591a\u82af\u7247\u8bbe\u8ba1 "},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/04.%20GPU%20architecture%20and%20CUDA/#nvcc","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u7528nvcc\u7f16\u8bd1\uff09","text":"
  1. \u7f16\u5199\u4e00\u4e2a\u5bf9\u6570\u7ec4\u5e94\u7528ReLU\u7684CUDA\u6838\u51fd\u6570\u3002\u6d4b\u91cf\u5305\u62ec\u5185\u5b58\u4f20\u8f93\u5728\u5185\u7684\u65f6\u95f4\u3002\u8fd9\u6559\u6388\u6838\u51fd\u6570\u7f16\u5199\u3001cudaMalloc/cudaMemcpy\u4ee5\u53ca\u4e3b\u673a\u2194\u8bbe\u5907\u4f20\u8f93\u74f6\u9888\u3002

    // task1_relu.cu\n// \u7f16\u8bd1\uff1anvcc -O3 -o task1_relu task1_relu.cu\n\n#include <stdio.h>\n#include <stdlib.h>\n#include <cuda_runtime.h>\n\n__global__ void relu_kernel(const float* input, float* output, int n) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < n) {\n        output[idx] = input[idx] > 0.0f ? input[idx] : 0.0f;\n    }\n}\n\nint main() {\n    const int N = 1 << 24;  // \u7ea61600\u4e07\u5143\u7d20\n    size_t bytes = N * sizeof(float);\n\n    // \u5206\u914d\u4e3b\u673a\u5185\u5b58\n    float* h_input  = (float*)malloc(bytes);\n    float* h_output = (float*)malloc(bytes);\n    for (int i = 0; i < N; i++) {\n        h_input[i] = (float)(i % 100) - 50.0f;  // \u6b63\u8d1f\u6df7\u5408\n    }\n\n    // \u5206\u914d\u8bbe\u5907\u5185\u5b58\n    float *d_input, *d_output;\n    cudaMalloc(&d_input, bytes);\n    cudaMalloc(&d_output, bytes);\n\n    // \u8ba1\u65f6\u5b8c\u6574\u6d41\u6c34\u7ebf\uff1a\u62f7\u8d1d\u5230GPU\u3001\u8ba1\u7b97\u3001\u62f7\u8d1d\u56de\n    cudaEvent_t start, stop;\n    cudaEventCreate(&start);\n    cudaEventCreate(&stop);\n\n    cudaEventRecord(start);\n    cudaMemcpy(d_input, h_input, bytes, cudaMemcpyHostToDevice);\n\n    int block_size = 256;\n    int grid_size = (N + block_size - 1) / block_size;\n    relu_kernel<<<grid_size, block_size>>>(d_input, d_output, N);\n\n    cudaMemcpy(h_output, d_output, bytes, cudaMemcpyDeviceToHost);\n    cudaEventRecord(stop);\n    cudaEventSynchronize(stop);\n\n    float ms = 0;\n    cudaEventElapsedTime(&ms, start, stop);\n\n    // \u9a8c\u8bc1\n    int errors = 0;\n    for (int i = 0; i < N; i++) {\n        float expected = h_input[i] > 0.0f ? 
h_input[i] : 0.0f;\n        if (h_output[i] != expected) errors++;\n    }\n\n    printf(\"\u65f6\u95f4\uff08\u542b\u4f20\u8f93\uff09: %.2f ms\\n\", ms);\n    printf(\"\u5e26\u5bbd: %.1f GB/s\\n\", 2.0 * bytes / ms / 1e6);  // \u8bfb\u53d6+\u5199\u5165\n    printf(\"\u9519\u8bef: %d / %d\\n\", errors, N);\n\n    cudaFree(d_input); cudaFree(d_output);\n    free(h_input); free(h_output);\n    return 0;\n}\n

  2. \u5728CUDA\u4e2d\u4f7f\u7528\u5171\u4eab\u5185\u5b58\u7f16\u5199\u5206\u5757\u77e9\u9635\u4e58\u6cd5\u3002\u5c06\u6027\u80fd\u4e0e\u6734\u7d20\uff08\u975e\u5206\u5757\uff09\u7248\u672c\u8fdb\u884c\u6bd4\u8f83\u3002\u8fd9\u6559\u6388\u5171\u4eab\u5185\u5b58\u3001__syncthreads\u4ee5\u53ca\u4e3a\u4ec0\u4e48\u5206\u5757\u91cd\u8981\u3002

    // task2_matmul.cu\n// \u7f16\u8bd1\uff1anvcc -O3 -o task2_matmul task2_matmul.cu\n\n#include <stdio.h>\n#include <cuda_runtime.h>\n\n#define TILE 16\n\n// \u6734\u7d20\u77e9\u9635\u4e58\u6cd5\uff1a\u6bcf\u4e2a\u7ebf\u7a0b\u8ba1\u7b97C\u7684\u4e00\u4e2a\u5143\u7d20\n__global__ void matmul_naive(const float* A, const float* B, float* C, int N) {\n    int row = blockIdx.y * blockDim.y + threadIdx.y;\n    int col = blockIdx.x * blockDim.x + threadIdx.x;\n    if (row < N && col < N) {\n        float sum = 0.0f;\n        for (int k = 0; k < N; k++) {\n            sum += A[row * N + k] * B[k * N + col];\n        }\n        C[row * N + col] = sum;\n    }\n}\n\n// \u5206\u5757\u77e9\u9635\u4e58\u6cd5\uff1a\u4f7f\u7528\u5171\u4eab\u5185\u5b58\u51cf\u5c11\u5168\u5c40\u5185\u5b58\u8bbf\u95ee\n__global__ void matmul_tiled(const float* A, const float* B, float* C, int N) {\n    __shared__ float sA[TILE][TILE];\n    __shared__ float sB[TILE][TILE];\n\n    int row = blockIdx.y * TILE + threadIdx.y;\n    int col = blockIdx.x * TILE + threadIdx.x;\n    float sum = 0.0f;\n\n    for (int t = 0; t < (N + TILE - 1) / TILE; t++) {\n        sA[threadIdx.y][threadIdx.x] = (row < N && t*TILE+threadIdx.x < N)\n            ? A[row * N + t*TILE + threadIdx.x] : 0.0f;\n        sB[threadIdx.y][threadIdx.x] = (t*TILE+threadIdx.y < N && col < N)\n            ? 
B[(t*TILE + threadIdx.y) * N + col] : 0.0f;\n\n        __syncthreads();\n        for (int k = 0; k < TILE; k++)\n            sum += sA[threadIdx.y][k] * sB[k][threadIdx.x];\n        __syncthreads();\n    }\n\n    if (row < N && col < N)\n        C[row * N + col] = sum;\n}\n\nint main() {\n    const int N = 1024;\n    size_t bytes = N * N * sizeof(float);\n\n    float *d_A, *d_B, *d_C;\n    cudaMalloc(&d_A, bytes); cudaMalloc(&d_B, bytes); cudaMalloc(&d_C, bytes);\n\n    // \u521d\u59cb\u5316\u4e3a1\uff08\u5bb9\u6613\u9a8c\u8bc1\uff1aC\u5e94\u5168\u4e3aN\uff09\n    float* h_A = new float[N*N];\n    for (int i = 0; i < N*N; i++) h_A[i] = 1.0f;\n    cudaMemcpy(d_A, h_A, bytes, cudaMemcpyHostToDevice);\n    cudaMemcpy(d_B, h_A, bytes, cudaMemcpyHostToDevice);\n\n    dim3 block(TILE, TILE);\n    dim3 grid((N+TILE-1)/TILE, (N+TILE-1)/TILE);\n\n    // \u57fa\u51c6\u6d4b\u8bd5\u6734\u7d20\u7248\n    cudaEvent_t start, stop;\n    cudaEventCreate(&start); cudaEventCreate(&stop);\n\n    cudaEventRecord(start);\n    for (int i = 0; i < 10; i++)\n        matmul_naive<<<grid, block>>>(d_A, d_B, d_C, N);\n    cudaEventRecord(stop);\n    cudaEventSynchronize(stop);\n    float naive_ms; cudaEventElapsedTime(&naive_ms, start, stop);\n\n    // \u57fa\u51c6\u6d4b\u8bd5\u5206\u5757\u7248\n    cudaEventRecord(start);\n    for (int i = 0; i < 10; i++)\n        matmul_tiled<<<grid, block>>>(d_A, d_B, d_C, N);\n    cudaEventRecord(stop);\n    cudaEventSynchronize(stop);\n    float tiled_ms; cudaEventElapsedTime(&tiled_ms, start, stop);\n\n    double gflops_naive = 2.0 * N * N * N * 10 / naive_ms / 1e6;\n    double gflops_tiled = 2.0 * N * N * N * 10 / tiled_ms / 1e6;\n\n    printf(\"\u6734\u7d20\u7248:  %.2f ms, %.1f GFLOPS\\n\", naive_ms/10, gflops_naive);\n    printf(\"\u5206\u5757\u7248:  %.2f ms, %.1f GFLOPS\\n\", tiled_ms/10, gflops_tiled);\n    printf(\"\u52a0\u901f\u6bd4: %.1fx\\n\", naive_ms / tiled_ms);\n\n    cudaFree(d_A); cudaFree(d_B); cudaFree(d_C);\n    delete[] h_A;\n    
return 0;\n}\n

  3. \u6f14\u793a\u7ebf\u7a0b\u675f\u5206\u6b67\u3002\u7f16\u5199\u4e00\u4e2a\u6838\u51fd\u6570\uff0c\u5176\u4e2d\u540c\u4e00\u7ebf\u7a0b\u675f\u4e2d\u7684\u7ebf\u7a0b\u8d70\u4e0d\u540c\u5206\u652f\uff0c\u5e76\u4e0e\u65e0\u5206\u652f\u7248\u672c\u6bd4\u8f83\u3002

    // task3_divergence.cu\n// \u7f16\u8bd1\uff1anvcc -O3 -o task3_diverge task3_divergence.cu\n\n#include <stdio.h>\n#include <cuda_runtime.h>\n\n// \u7cdf\u7cd5\uff1a\u7ebf\u7a0b\u675f\u5206\u6b67\u2014\u2014\u5076\u6570/\u5947\u6570\u7ebf\u7a0b\u8d70\u4e0d\u540c\u8def\u5f84\n__global__ void divergent_kernel(float* data, int n) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < n) {\n        if (idx % 2 == 0) {\n            data[idx] = data[idx] * 2.0f + 1.0f;\n        } else {\n            data[idx] = data[idx] * 0.5f - 1.0f;\n        }\n    }\n}\n\n// \u597d\uff1a\u65e0\u5206\u652f\u2014\u2014\u6240\u6709\u7ebf\u7a0b\u6267\u884c\u76f8\u540c\u6307\u4ee4\n__global__ void branchless_kernel(float* data, int n) {\n    int idx = blockIdx.x * blockDim.x + threadIdx.x;\n    if (idx < n) {\n        float scale = (idx % 2 == 0) ? 2.0f : 0.5f;\n        float offset = (idx % 2 == 0) ? 1.0f : -1.0f;\n        data[idx] = data[idx] * scale + offset;\n    }\n}\n\nint main() {\n    const int N = 1 << 24;\n    float* d_data;\n    cudaMalloc(&d_data, N * sizeof(float));\n    cudaMemset(d_data, 0, N * sizeof(float));\n\n    int block = 256, grid = (N + block - 1) / block;\n\n    cudaEvent_t start, stop;\n    cudaEventCreate(&start); cudaEventCreate(&stop);\n\n    // \u5206\u6b67\u7248\n    cudaEventRecord(start);\n    for (int i = 0; i < 100; i++)\n        divergent_kernel<<<grid, block>>>(d_data, N);\n    cudaEventRecord(stop);\n    cudaEventSynchronize(stop);\n    float div_ms; cudaEventElapsedTime(&div_ms, start, stop);\n\n    // \u65e0\u5206\u652f\u7248\n    cudaEventRecord(start);\n    for (int i = 0; i < 100; i++)\n        branchless_kernel<<<grid, block>>>(d_data, N);\n    cudaEventRecord(stop);\n    cudaEventSynchronize(stop);\n    float nodiv_ms; cudaEventElapsedTime(&nodiv_ms, start, stop);\n\n    printf(\"\u5206\u6b67\u7248:  %.2f ms\\n\", div_ms / 100);\n    printf(\"\u65e0\u5206\u652f\u7248: %.2f ms\\n\", nodiv_ms / 100);\n    
printf(\"\u52a0\u901f\u6bd4:    %.2fx\\n\", div_ms / nodiv_ms);\n\n    cudaFree(d_data);\n    return 0;\n}\n

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/05.%20triton%2C%20TPUs%20and%20pallax/","title":"Triton\u4e0eTPU","text":"

CUDA C\u529f\u80fd\u5f3a\u5927\u4f46\u5197\u957f\u3002Triton\u8ba9\u4f60\u7528Python\u7f16\u5199GPU\u6838\u51fd\u6570\u3002TPU\u63d0\u4f9b\u4e86GPU\u4e4b\u5916\u7684\u9009\u62e9\uff0c\u5177\u6709\u4e0d\u540c\u7684\u6743\u8861\u3002\u672c\u6587\u6db5\u76d6Triton\u6838\u51fd\u6570\u7f16\u7a0b\u3001\u4ee5Flash Attention\u4e3a\u6848\u4f8b\u7814\u7a76\u3001TPU\u67b6\u6784\u4e0eJAX/Pallas\uff0c\u4ee5\u53ca\u5982\u4f55\u9009\u62e9\u5408\u9002\u7684\u5de5\u5177\u3002\u5173\u4e8eVulkan\u548c\u8de8\u5e73\u53f0GPU\u8ba1\u7b97\uff0c\u8bf7\u53c2\u89c1\u6587\u4ef607\u3002

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/05.%20triton%2C%20TPUs%20and%20pallax/#tritonpythongpu","title":"Triton\uff1a\u7528Python\u7f16\u5199GPU\u6838\u51fd\u6570","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/05.%20triton%2C%20TPUs%20and%20pallax/#triton","title":"\u4f60\u7684\u7b2c\u4e00\u4e2aTriton\u6838\u51fd\u6570","text":"
import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef add_kernel(\n    x_ptr, y_ptr, output_ptr,\n    n_elements,\n    BLOCK_SIZE: tl.constexpr,  # \u7f16\u8bd1\u65f6\u5e38\u91cf\n):\n    # \u6bcf\u4e2a\u7a0b\u5e8f\u5b9e\u4f8b\u5904\u7406\u4e00\u4e2aBLOCK_SIZE\u5143\u7d20\u7684\u5757\n    pid = tl.program_id(axis=0)  # \u6211\u662f\u54ea\u4e2a\u5757\uff1f\n    block_start = pid * BLOCK_SIZE\n\n    # \u6b64\u5757\u7684\u504f\u79fb\u91cf\n    offsets = block_start + tl.arange(0, BLOCK_SIZE)\n\n    # \u63a9\u7801\u5904\u7406n_elements\u4e0d\u662fBLOCK_SIZE\u500d\u6570\u7684\u60c5\u51b5\n    mask = offsets < n_elements\n\n    # \u52a0\u8f7d\u6570\u636e\uff08\u5e26\u63a9\u7801\uff1a\u8d8a\u754c\u8bfb\u53d6\u8fd4\u56de0\uff09\n    x = tl.load(x_ptr + offsets, mask=mask)\n    y = tl.load(y_ptr + offsets, mask=mask)\n\n    # \u8ba1\u7b97\n    output = x + y\n\n    # \u5b58\u50a8\u7ed3\u679c\n    tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:\n    output = torch.empty_like(x)\n    n_elements = output.numel()\n\n    # \u542f\u52a8\uff1a\u6bcf\u4e2a\u5757\u4e00\u4e2a\u7a0b\u5e8f\n    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n\n    return output\n\n\n# \u4f7f\u7528\nx = torch.randn(1000000, device='cuda')\ny = torch.randn(1000000, device='cuda')\nz = add(x, y)\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/05.%20triton%2C%20TPUs%20and%20pallax/#triton-softmax","title":"Triton Softmax\u6838\u51fd\u6570","text":"
@triton.jit\ndef softmax_kernel(\n    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,\n    BLOCK_SIZE: tl.constexpr,\n):\n    # \u6bcf\u4e2a\u7a0b\u5e8f\u5904\u7406\u4e00\u884c\n    row_idx = tl.program_id(0)\n    row_start = input_ptr + row_idx * input_row_stride\n\n    # \u52a0\u8f7d\u8be5\u884c\n    col_offsets = tl.arange(0, BLOCK_SIZE)\n    mask = col_offsets < n_cols\n    row = tl.load(row_start + col_offsets, mask=mask, other=-float('inf'))\n\n    # Softmax\uff1a\u4e3a\u6570\u503c\u7a33\u5b9a\u6027\u53d6\u6700\u5927\u503c\uff0c\u7136\u540eexp\uff0c\u7136\u540e\u5f52\u4e00\u5316\n    row_max = tl.max(row, axis=0)\n    numerator = tl.exp(row - row_max)\n    denominator = tl.sum(numerator, axis=0)\n    softmax_output = numerator / denominator\n\n    # \u5b58\u50a8\u7ed3\u679c\n    output_start = output_ptr + row_idx * output_row_stride\n    tl.store(output_start + col_offsets, softmax_output, mask=mask)\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/05.%20triton%2C%20TPUs%20and%20pallax/#triton_1","title":"Triton\u81ea\u52a8\u8c03\u4f18","text":"
@triton.autotune(\n    configs=[\n        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}),\n        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}),\n        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}),\n    ],\n    key=['M', 'N', 'K'],  # \u5f53\u8fd9\u4e9b\u53d8\u5316\u65f6\u91cd\u65b0\u8c03\u4f18\n)\n@triton.jit\ndef matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, ...):\n    ...\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/05.%20triton%2C%20TPUs%20and%20pallax/#triton-vs-cuda","title":"Triton vs CUDA\uff1a\u4f55\u65f6\u4f7f\u7528","text":"Triton CUDA C \u8bed\u8a00 Python C/C++ \u62bd\u8c61\u5c42\u7ea7 \u5757\u7ea7 \u7ebf\u7a0b\u7ea7 \u5f00\u53d1\u901f\u5ea6 \u5feb\uff08\u6bcf\u6838\u51fd\u657010-50\u884c\uff09 \u6162\uff08100-500\u884c\uff09 \u6027\u80fd\u5929\u82b1\u677f \u624b\u5de5\u8c03\u4f18CUDA\u7684\u7ea680-95% 100%\uff08\u5b8c\u5168\u786c\u4ef6\u63a7\u5236\uff09 \u5171\u4eab\u5185\u5b58 \u81ea\u52a8 \u624b\u52a8 \u5408\u5e76 \u81ea\u52a8 \u624b\u52a8 \u7ebf\u7a0b\u675f\u7ea7\u539f\u8bed \u6709\u9650 \u5b8c\u6574\uff08shuffle\u3001vote\u7b49\uff09 \u786c\u4ef6\u652f\u6301 \u4ec5NVIDIA\uff08AMD\u5b9e\u9a8c\u6027\uff09 \u4ec5NVIDIA "},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/05.%20triton%2C%20TPUs%20and%20pallax/#flash-attention","title":"\u6848\u4f8b\u7814\u7a76\uff1aFlash Attention","text":"
\u5bf9\u4e8e\u6bcf\u4e2aQ\u884c\u5757\uff1a\n    \u5bf9\u4e8e\u6bcf\u4e2aK\u5217\u5757\uff1a\n        1. \u5c06Q_block\u4eceHBM\u52a0\u8f7d\u5230SRAM\n        2. \u5c06K_block\u4eceHBM\u52a0\u8f7d\u5230SRAM\n        3. \u8ba1\u7b97S_block = Q_block @ K_block.T\uff08\u5728SRAM\u4e2d\uff09\n        4. \u66f4\u65b0\u8fd0\u884c\u4e2d\u6700\u5927\u503c\uff0c\u91cd\u65b0\u7f29\u653e\u5148\u524d\u7ed3\u679c\n        5. \u8ba1\u7b97exp(S_block - \u8fd0\u884c\u4e2d\u6700\u5927\u503c)\n        6. \u66f4\u65b0\u8fd0\u884c\u4e2d\u6c42\u548c\u548c\u8f93\u51fa\u7d2f\u52a0\u5668\n    \u52a0\u8f7dV_block\u5e76\u8ba1\u7b97\u6700\u7ec8\u8f93\u51fa\n    \u5c06\u8f93\u51fa\u5757\u5199\u56deHBM\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/05.%20triton%2C%20TPUs%20and%20pallax/#tpu","title":"TPU\u67b6\u6784","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/05.%20triton%2C%20TPUs%20and%20pallax/#tpujaxpallas","title":"\u7f16\u7a0bTPU\uff1aJAX\u4e0ePallas","text":"
import jax\nimport jax.numpy as jnp\n\n@jax.jit\ndef matmul(a, b):\n    return jnp.dot(a, b)\n\n# \u8fd9\u5c06\u6839\u636e\u8bbe\u5907\u5728CPU\u3001GPU\u6216TPU\u4e0a\u8fd0\u884c\na = jnp.ones((1024, 1024))\nb = jnp.ones((1024, 1024))\nc = matmul(a, b)\n
from jax.experimental import pallas as pl\nimport jax.numpy as jnp\n\ndef add_kernel(x_ref, y_ref, o_ref):\n    o_ref[...] = x_ref[...] + y_ref[...]\n\ndef add_pallas(x, y):\n    return pl.pallas_call(\n        add_kernel,\n        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),\n        grid=(x.shape[0] // 128,),\n        in_specs=[pl.BlockSpec((128,), lambda i: (i,)),\n                  pl.BlockSpec((128,), lambda i: (i,))],\n        out_specs=pl.BlockSpec((128,), lambda i: (i,)),\n    )(x, y)\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/05.%20triton%2C%20TPUs%20and%20pallax/#gpu-vs-tpu","title":"GPU vs TPU","text":"GPU\uff08NVIDIA\uff09 TPU\uff08Google\uff09 \u53ef\u7528\u6027 \u4efb\u4f55\u4e91\u3001\u672c\u5730\u90e8\u7f72 \u4ec5Google Cloud \u7f16\u7a0b CUDA C\u3001Triton\u3001PyTorch JAX/XLA\u3001Pallas \u7075\u6d3b\u6027 \u901a\u7528\u8ba1\u7b97 \u9488\u5bf9\u77e9\u9635\u5bc6\u96c6\u578bML\u4f18\u5316 \u5cf0\u503c\u77e9\u9635\u4e58\u6cd5FLOPS \u975e\u5e38\u9ad8\uff08\u5f20\u91cf\u6838\u5fc3\uff09 \u975e\u5e38\u9ad8\uff08MXU\uff09 \u975e\u77e9\u9635\u4e58\u6cd5\u64cd\u4f5c \u597d \u8f83\u6162\uff08\u901a\u8fc7\u5411\u91cf\u5355\u5143\u8def\u7531\uff0c\u800c\u975eMXU\uff09 \u591a\u82af\u7247\u6269\u5c55 NVLink\uff088\u4e2aGPU\uff09\u3001InfiniBand ICI\uff08\u6570\u5343\u4e2aTPU\uff0c\u66f4\u7d27\u5bc6\u96c6\u6210\uff09 \u6210\u672c\u6548\u7387 \u6709\u7ade\u4e89\u529b \u5927\u89c4\u6a21\u8bad\u7ec3\u901a\u5e38\u66f4\u4fbf\u5b9c \u751f\u6001\u7cfb\u7edf \u6700\u5927\uff08PyTorch\u3001TensorFlow\u3001JAX\uff09 \u9762\u5411JAX "},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/05.%20triton%2C%20TPUs%20and%20pallax/#_1","title":"\u9009\u62e9\u5408\u9002\u7684\u5de5\u5177","text":"\u5de5\u4f5c\u8d1f\u8f7d \u6700\u4f73\u5de5\u5177 \u4e3a\u4ec0\u4e48 ML\u8bad\u7ec3\uff08PyTorch\uff09 NVIDIA GPU + CUDA/Triton \u6700\u5927\u751f\u6001\u7cfb\u7edf\u3001\u6700\u4f73\u5de5\u5177\u94fe ML\u8bad\u7ec3\uff08JAX\uff0c\u5927\u89c4\u6a21\uff09 TPU\u6216NVIDIA GPU TPU\u5728Google\u89c4\u6a21\u4e0b\u6210\u672c\u4f4e\uff0cGPU\u7075\u6d3b \u81ea\u5b9a\u4e49\u878d\u5408\u6838\u51fd\u6570 Triton\uff08Python\uff09\u6216CUDA C Triton\u5f00\u53d1\u901f\u5ea6\u5feb\uff0cCUDA\u5cf0\u503c\u6027\u80fd\u9ad8 JAX\u81ea\u5b9a\u4e49\u6838\u51fd\u6570 Pallas TPU\u552f\u4e00\u9009\u9879\uff0c\u4e5f\u53ef\u5728GPU\u4e0a\u5de5\u4f5c \u8de8\u5e73\u53f0\u63a8\u7406 Vulkan\uff08\u6587\u4ef607\uff09\u6216ONNX Runtime 
\u8fd0\u884c\u5728\u4efb\u4f55GPU\u4f9b\u5e94\u5546\u4e0a \u79fb\u52a8/\u8fb9\u7f18\u63a8\u7406 Metal\uff08Apple\uff09\u3001Vulkan\uff08Android\uff09\u3001NNAPI \u5e73\u53f0\u7279\u5b9a\u7684\u52a0\u901f\u5668 \u6d4f\u89c8\u5668\u63a8\u7406 WebGPU\uff08\u6587\u4ef607\uff09 \u6d4f\u89c8\u5668\u4e2d\u552f\u4e00\u9009\u9879 \u4ec5CPU\u63a8\u7406 ONNX Runtime + AVX/NEON \u65e0\u9700GPU\uff0c\u4f7f\u7528SIMD\uff08\u6587\u4ef602-03\uff09 \u65b0\u578b\u786c\u4ef6 \u4f9b\u5e94\u5546\u4e13\u7528SDK \u6bcf\u4e2a\u52a0\u901f\u5668\u6709\u81ea\u5df1\u7684\u5de5\u5177\u94fe"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/05.%20triton%2C%20TPUs%20and%20pallax/#gpucolab","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528\u5e26GPU\u8fd0\u884c\u65f6\u7684CoLab\uff09","text":"
  1. \u7f16\u5199\u5e76\u8fd0\u884c\u5411\u91cf\u52a0\u6cd5\u7684Triton\u6838\u51fd\u6570\u3002\u5c06\u5176\u6027\u80fd\u4e0ePyTorch\u5185\u7f6e\u52a0\u6cd5\u6bd4\u8f83\u3002

    import triton\nimport triton.language as tl\nimport torch\nimport time\n\n@triton.jit\ndef add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):\n    pid = tl.program_id(0)\n    offs = pid * BLOCK + tl.arange(0, BLOCK)\n    mask = offs < n\n    x = tl.load(x_ptr + offs, mask=mask)\n    y = tl.load(y_ptr + offs, mask=mask)\n    tl.store(out_ptr + offs, x + y, mask=mask)\n\nn = 10_000_000\nx = torch.randn(n, device='cuda')\ny = torch.randn(n, device='cuda')\n\n# Triton\nout_triton = torch.empty_like(x)\ngrid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)\nadd_kernel[grid](x, y, out_triton, n, BLOCK=1024)\n\n# PyTorch\nout_torch = x + y\n\n# \u9a8c\u8bc1\u6b63\u786e\u6027\nassert torch.allclose(out_triton, out_torch, atol=1e-5)\n\n# \u57fa\u51c6\u6d4b\u8bd5\ntorch.cuda.synchronize()\nstart = time.time()\nfor _ in range(1000):\n    add_kernel[grid](x, y, out_triton, n, BLOCK=1024)\ntorch.cuda.synchronize()\ntriton_time = (time.time() - start) / 1000\n\nstart = time.time()\nfor _ in range(1000):\n    out_torch = x + y\ntorch.cuda.synchronize()\ntorch_time = (time.time() - start) / 1000\n\nprint(f\"Triton:  {triton_time*1000:.3f} ms\")\nprint(f\"PyTorch: {torch_time*1000:.3f} ms\")\nprint(f\"\u6bd4\u7387:   {torch_time/triton_time:.2f}x\")\n

  2. \u7f16\u5199\u4e00\u4e2aTriton\u878d\u5408\u6838\u51fd\u6570\uff0c\u5728\u5355\u6b21\u904d\u5386\u4e2d\u6267\u884c\u4e58\u6cd5+\u52a0\u6cd5+ReLU\u3002\u4e0e\u4e09\u4e2a\u72ec\u7acb\u7684PyTorch\u64cd\u4f5c\u6bd4\u8f83\u3002

    import triton\nimport triton.language as tl\nimport torch\nimport time\n\n@triton.jit\ndef fused_mul_add_relu_kernel(x_ptr, w_ptr, b_ptr, out_ptr, n, BLOCK: tl.constexpr):\n    pid = tl.program_id(0)\n    offs = pid * BLOCK + tl.arange(0, BLOCK)\n    mask = offs < n\n    x = tl.load(x_ptr + offs, mask=mask)\n    w = tl.load(w_ptr + offs, mask=mask)\n    b = tl.load(b_ptr + offs, mask=mask)\n    result = tl.maximum(x * w + b, 0.0)  # \u878d\u5408\uff1a\u4e58\u6cd5 + \u52a0\u6cd5 + relu\n    tl.store(out_ptr + offs, result, mask=mask)\n\nn = 10_000_000\nx = torch.randn(n, device='cuda')\nw = torch.randn(n, device='cuda')\nb = torch.randn(n, device='cuda')\n\n# \u878d\u5408\uff08Triton\uff09\nout_fused = torch.empty_like(x)\ngrid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)\nfused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024)\n\n# \u672a\u878d\u5408\uff08PyTorch\uff09\nout_unfused = torch.relu(x * w + b)\n\nassert torch.allclose(out_fused, out_unfused, atol=1e-5)\n\n# \u57fa\u51c6\u6d4b\u8bd5\ntorch.cuda.synchronize()\nstart = time.time()\nfor _ in range(1000):\n    fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024)\ntorch.cuda.synchronize()\nfused_time = (time.time() - start) / 1000\n\nstart = time.time()\nfor _ in range(1000):\n    out_unfused = torch.relu(x * w + b)\ntorch.cuda.synchronize()\nunfused_time = (time.time() - start) / 1000\n\nprint(f\"\u878d\u5408\uff08Triton\uff09:    {fused_time*1000:.3f} ms\")\nprint(f\"\u672a\u878d\u5408\uff08PyTorch\uff09: {unfused_time*1000:.3f} ms\")\nprint(f\"\u52a0\u901f\u6bd4:           {unfused_time/fused_time:.2f}x\")\n

  3. \u6d4b\u91cfJAX\u7684XLA\u7f16\u8bd1\u5668\u5982\u4f55\u81ea\u52a8\u878d\u5408\u64cd\u4f5c\u3002\u6bd4\u8f83\u5e26\u548c\u4e0d\u5e26jit\u7684\u64cd\u4f5c\u94fe\u3002

    import jax\nimport jax.numpy as jnp\nimport time\n\ndef chain_ops(x):\n    x = x * 2.0\n    x = x + 1.0\n    x = jnp.maximum(x, 0.0)  # ReLU\n    x = x / jnp.sum(x)\n    return x\n\nchain_jit = jax.jit(chain_ops)\nx = jax.random.normal(jax.random.PRNGKey(0), (10000, 1000))\n\n# \u9884\u70ed\n_ = chain_jit(x)\njax.block_until_ready(_)\n\n# \u5373\u65f6\u6a21\u5f0f\uff08\u6bcf\u4e2a\u64cd\u4f5c\u662f\u72ec\u7acb\u6838\u51fd\u6570\u542f\u52a8\uff09\nstart = time.time()\nfor _ in range(100):\n    y = chain_ops(x)\njax.block_until_ready(y)\neager_time = (time.time() - start) / 100\n\n# JIT\uff08XLA\u878d\u5408\u64cd\u4f5c\uff09\nstart = time.time()\nfor _ in range(100):\n    y = chain_jit(x)\njax.block_until_ready(y)\njit_time = (time.time() - start) / 100\n\nprint(f\"\u5373\u65f6: {eager_time*1000:.2f} ms\")\nprint(f\"JIT:   {jit_time*1000:.2f} ms\")\nprint(f\"\u52a0\u901f\u6bd4: {eager_time/jit_time:.1f}x\uff08XLA\u5c064\u4e2a\u64cd\u4f5c\u878d\u5408\u4e3a1\u4e2a\u6838\u51fd\u6570\uff09\")\n

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/06.%20RISC-V%20and%20embedded%20systems/","title":"RISC-V\u4e0e\u5d4c\u5165\u5f0f\u7cfb\u7edf","text":"

RISC-V\u662f\u6b63\u5728\u91cd\u5851\u82af\u7247\u884c\u4e1a\u7684\u5f00\u6e90\u6307\u4ee4\u96c6\u67b6\u6784\u3002\u672c\u6587\u6db5\u76d6RISC-V\u54f2\u5b66\u3001V\u5411\u91cf\u6269\u5c55\u3001\u5d4c\u5165\u5f0fML\u63a8\u7406\u3001\u5fae\u63a7\u5236\u5668\u4e0a\u7684TinyML\u3001AI\u52a0\u901f\u5668\u4e2d\u7684RISC-V\u4ee5\u53ca\u8fb9\u7f18\u90e8\u7f72\u7ea6\u675f\u3002

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/06.%20RISC-V%20and%20embedded%20systems/#risc-v_1","title":"RISC-V\u54f2\u5b66","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/06.%20RISC-V%20and%20embedded%20systems/#risc-v_2","title":"RISC-V\u57fa\u7840\u67b6\u6784","text":"
# RISC-V\u6c47\u7f16\uff08\u611f\u53d7\u98ce\u683c\u2014\u2014\u4f60\u5c06\u4f7f\u7528C/C++\uff09\nadd  x3, x1, x2      # x3 = x1 + x2\nlw   x4, 0(x5)       # \u4ecex5\u4e2d\u7684\u5730\u5740\u52a0\u8f7d\u5b57\nsw   x4, 8(x5)       # \u5b58\u50a8\u5b57\u5230\u5730\u5740 x5 + 8\nbeq  x1, x2, label   # \u5982\u679cx1 == x2\u5219\u5206\u652f\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/06.%20RISC-V%20and%20embedded%20systems/#vrisc-v","title":"V\u6269\u5c55\uff1aRISC-V\u5411\u91cf\u5904\u7406","text":"
#include <riscv_vector.h>\n\n// \u4f7f\u7528RVV\u5185\u8054\u51fd\u6570\u8fdb\u884c\u5411\u91cf\u52a0\u6cd5\nvoid vadd_rvv(const float* a, const float* b, float* c, int n) {\n    while (n > 0) {\n        // vsetvl\uff1a\u8bbe\u7f6e\u5411\u91cf\u957f\u5ea6\u2014\u2014\u5904\u7406 min(n, \u786c\u4ef6\u6700\u5927\u503c) \u4e2a\u5143\u7d20\n        size_t vl = __riscv_vsetvl_e32m1(n);\n\n        // \u52a0\u8f7dvl\u4e2a\u5143\u7d20\n        vfloat32m1_t va = __riscv_vle32_v_f32m1(a, vl);\n        vfloat32m1_t vb = __riscv_vle32_v_f32m1(b, vl);\n\n        // \u76f8\u52a0\n        vfloat32m1_t vc = __riscv_vfadd_vv_f32m1(va, vb, vl);\n\n        // \u5b58\u50a8\n        __riscv_vse32_v_f32m1(c, vc, vl);\n\n        // \u524d\u8fdb\u6307\u9488\n        a += vl; b += vl; c += vl; n -= vl;\n    }\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/06.%20RISC-V%20and%20embedded%20systems/#mltinyml","title":"\u5d4c\u5165\u5f0fML\uff1aTinyML","text":" \u8d44\u6e90 \u670d\u52a1\u5668GPU \u667a\u80fd\u624b\u673a \u5fae\u63a7\u5236\u5668 RAM 80 GB 6 GB 256 KB \u5b58\u50a8 TB 128 GB 1 MB \u8ba1\u7b97\u80fd\u529b 1000 TFLOPS 10 TFLOPS 0.001 TFLOPS \u529f\u8017 700 W 5 W 0.001 W \u6210\u672c $30,000 $500 $1 "},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/06.%20RISC-V%20and%20embedded%20systems/#tensorflow-lite-microtflm","title":"TensorFlow Lite Micro\uff08TFLM\uff09","text":"
// \u5fae\u63a7\u5236\u5668\u4e0a\u7684TinyML\u63a8\u7406\uff08\u7b80\u5316\u7248\uff09\n#include \"tensorflow/lite/micro/micro_interpreter.h\"\n#include \"tensorflow/lite/micro/micro_mutable_op_resolver.h\"\n\n// \u6a21\u578b\u7f16\u8bd1\u4e3aC\u6570\u7ec4\uff08const unsigned char model_data[]\uff09\nconst tflite::Model* model = tflite::GetModel(model_data);\n\n// \u5206\u914d\u56fa\u5b9a\u5185\u5b58\u7f13\u51b2\u533a\uff08\u65e0malloc\uff01\uff09\nconstexpr int kArenaSize = 10 * 1024;  // 10 KB\nuint8_t tensor_arena[kArenaSize];\n\n// \u8bbe\u7f6e\u89e3\u91ca\u5668\ntflite::MicroInterpreter interpreter(model, resolver, tensor_arena, kArenaSize);\ninterpreter.AllocateTensors();\n\n// \u8bbe\u7f6e\u8f93\u5165\nfloat* input = interpreter.input(0)->data.f;\ninput[0] = sensor_reading;\n\n// \u8fd0\u884c\u63a8\u7406\ninterpreter.Invoke();\n\n// \u8bfb\u53d6\u8f93\u51fa\nfloat* output = interpreter.output(0)->data.f;\nif (output[0] > 0.8f) {\n    trigger_alert();\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/06.%20RISC-V%20and%20embedded%20systems/#_1","title":"\u8fb9\u7f18\u6a21\u578b\u4f18\u5316","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/06.%20RISC-V%20and%20embedded%20systems/#risc-vai","title":"RISC-V\u5728AI\u52a0\u901f\u5668\u4e2d\u7684\u5e94\u7528","text":"
\u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n\u2502              AI\u52a0\u901f\u5668                    \u2502\n\u2502                                         \u2502\n\u2502  \u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510    \u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510   \u2502\n\u2502  \u2502  RISC-V  \u2502\u2500\u2500\u2500\u2192\u2502  \u81ea\u5b9a\u4e49\u77e9\u9635       \u2502   \u2502\n\u2502  \u2502  \u63a7\u5236     \u2502    \u2502  \u4e58\u6cd5\u5355\u5143         \u2502   \u2502\n\u2502  \u2502  \u6838\u5fc3     \u2502    \u2502 \uff08\u8109\u52a8\u9635\u5217\u3001      \u2502   \u2502\n\u2502  \u2502          \u2502    \u2502  \u81ea\u5b9a\u4e49\u6570\u636e\u6d41\uff09    \u2502   \u2502\n\u2502  \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518    \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518   \u2502\n\u2502       \u2502                    \u2502            \u2502\n\u2502       \u25bc                    \u25bc            \u2502\n\u2502  \u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510    \u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510   \u2502\n\u2502  \u2502  \u5185\u5b58     \u2502    \u2502  \u7247\u4e0aSRAM        \u2502   \u2502\n\u2502  \u2502  \u63a7\u5236     \u2502    \u2502 \uff08\u6fc0\u6d3b\u7f13\u51b2\uff09      \u2502   \u2502\n\u2502  \u2502          \u2502    \u2502                  \u2502   \u2502\n\u2502  \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518    
\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518   \u2502\n\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/06.%20RISC-V%20and%20embedded%20systems/#_2","title":"\u8fb9\u7f18\u90e8\u7f72\u7ea6\u675f","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/06.%20RISC-V%20and%20embedded%20systems/#griscv64-gcc","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u7528g++\u6216riscv64-gcc\u4ea4\u53c9\u7f16\u8bd1\u5668\u7f16\u8bd1\uff09","text":"
  1. \u7f16\u5199\u4e00\u4e2aC++\u7a0b\u5e8f\uff0c\u6a21\u62dfTinyML\u63a8\u7406\u6d41\u6c34\u7ebf\uff1a\u9759\u6001\u5206\u914d\u6a21\u578b\u7f13\u51b2\u533a\uff0c\u8fd0\u884c\u6a21\u62df\u524d\u5411\u4f20\u64ad\uff0c\u5e76\u6d4b\u91cf\u8d44\u6e90\u4f7f\u7528\u3002\u8fd9\u6709\u52a9\u4e8e\u7406\u89e3\u5d4c\u5165\u5f0f\u7ea6\u675f\uff08\u65e0malloc\u3001\u56fa\u5b9a\u5185\u5b58\u7f13\u51b2\u533a\uff09\u3002

    // task1_tinyml_sim.cpp\n// \u7f16\u8bd1\uff1ag++ -O2 -o task1 task1_tinyml_sim.cpp\n\n#include <iostream>\n#include <chrono>\n#include <cmath>\n#include <cstring>\n\n// \u6a21\u62df\u5fae\u63a7\u5236\u5668\uff1a\u56fa\u5b9a\u5185\u5b58\u7f13\u51b2\u533a\uff0c\u65e0\u52a8\u6001\u5206\u914d\nstatic constexpr int ARENA_SIZE = 32 * 1024;  // 32 KB\u603bRAM\u9884\u7b97\nstatic uint8_t arena[ARENA_SIZE];\n\n// \u7b80\u5355\u76842\u5c42MLP\uff1a784 -> 64 -> 10\uff08\u7c7b\u4f3cMNIST\uff0cINT8\u6743\u91cd\uff09\nstruct TinyModel {\n    int8_t w1[784 * 64];      // \u5c421\u6743\u91cd\uff1a50,176\u5b57\u8282\n    int8_t b1[64];             // \u5c421\u504f\u7f6e\n    int8_t w2[64 * 10];       // \u5c422\u6743\u91cd\uff1a640\u5b57\u8282\n    int8_t b2[10];             // \u5c422\u504f\u7f6e\n    // \u603b\u8ba1\uff1a\u7ea651 KB \u2192 \u5fc5\u987b\u653e\u5728\u95ea\u5b58\uff08ROM\uff09\uff0c\u800c\u975eRAM\n};\n\n// \u68c0\u67e5\u6a21\u578b\u662f\u5426\u9002\u5408\u95ea\u5b58\nvoid check_model_fit(int flash_kb) {\n    int model_bytes = sizeof(TinyModel);\n    std::cout << \"\u6a21\u578b\u5927\u5c0f: \" << model_bytes << \" \u5b57\u8282\uff08\"\n              << model_bytes / 1024 << \" KB\uff09\\n\";\n    std::cout << \"\u95ea\u5b58: \" << flash_kb << \" KB \u2192 \"\n              << (model_bytes <= flash_kb * 1024 ? 
\"\u9002\u5408\" : \"\u592a\u5927\") << \"\\n\";\n}\n\n// \u4f7f\u7528\u56fa\u5b9a\u7f13\u51b2\u533a\u8fdb\u884c\u6fc0\u6d3b\u7684\u6a21\u62df\u63a8\u7406\nvoid mock_inference(const int8_t* input, int8_t* output) {\n    // \u6fc0\u6d3b\u503c\u653e\u5728\u7f13\u51b2\u533a\uff08RAM\uff09\u4e2d\uff0c\u800c\u975e\u52a8\u6001\u5206\u914d\n    int8_t* act1 = (int8_t*)arena;            // \u5c421\u8f93\u51fa64\u5b57\u8282\n    int8_t* act2 = (int8_t*)(arena + 64);     // \u5c422\u8f93\u51fa10\u5b57\u8282\n\n    // \u5c421\uff1a\u7b80\u5316\u7248\u77e9\u9635\u4e58\u6cd5\uff08\u4e0d\u662f\u771f\u6b63\u7684\u91cf\u5316\u77e9\u9635\u4e58\u6cd5\uff0c\u4ec5\u7ed3\u6784\u6f14\u793a\uff09\n    for (int j = 0; j < 64; j++) {\n        int32_t sum = 0;  // \u7528int32\u7d2f\u52a0\u907f\u514d\u6ea2\u51fa\n        for (int i = 0; i < 784; i++) {\n            sum += (int32_t)input[i] * 1;  // \u6a21\u62df\uff1a\u6743\u91cd=1\n        }\n        act1[j] = (int8_t)std::max(-128, std::min(127, sum / 784));  // \u91cf\u5316\u56de\n        act1[j] = act1[j] > 0 ? 
act1[j] : 0;  // ReLU\n    }\n\n    // \u5c422\n    for (int j = 0; j < 10; j++) {\n        int32_t sum = 0;\n        for (int i = 0; i < 64; i++) {\n            sum += (int32_t)act1[i] * 1;\n        }\n        act2[j] = (int8_t)std::max(-128, std::min(127, sum / 64));\n    }\n\n    std::memcpy(output, act2, 10);\n}\n\nint main() {\n    std::cout << \"=== TinyML\u8d44\u6e90\u9884\u7b97 ===\\n\";\n    std::cout << \"\u7f13\u51b2\u533a\uff08RAM\uff09: \" << ARENA_SIZE << \" \u5b57\u8282\uff08\"\n              << ARENA_SIZE / 1024 << \" KB\uff09\\n\";\n    check_model_fit(256);  // \u5178\u578bMCU\u95ea\u5b58\n\n    // \u6fc0\u6d3b\u5185\u5b58\u4f7f\u7528\n    int activation_bytes = 64 + 10;  // \u5c421 + \u5c422\u8f93\u51fa\n    std::cout << \"\u6fc0\u6d3b\u5185\u5b58: \" << activation_bytes\n              << \" \u5b57\u8282 / \" << ARENA_SIZE << \" \u53ef\u7528\\n\\n\";\n\n    // \u57fa\u51c6\u6d4b\u8bd5\u63a8\u7406\n    int8_t input[784];\n    int8_t output[10];\n    std::memset(input, 1, 784);\n\n    auto start = std::chrono::high_resolution_clock::now();\n    for (int i = 0; i < 10000; i++) {\n        mock_inference(input, output);\n    }\n    auto end = std::chrono::high_resolution_clock::now();\n    double us = std::chrono::duration<double, std::micro>(end - start).count() / 10000;\n\n    std::cout << \"\u63a8\u7406\u5ef6\u8fdf: \" << us << \" us\\n\";\n    std::cout << \"\u5728160 MHz MCU\uff08\u7ea66.25 ns/\u5468\u671f\uff09\u4e0b\uff1a\u7ea6\"\n              << (int)(us * 160) << \" \u5468\u671f\\n\";\n\n    std::cout << \"\u8f93\u51falogits: \";\n    for (int i = 0; i < 10; i++) std::cout << (int)output[i] << \" \";\n    std::cout << \"\\n\";\n\n    return 0;\n}\n

  2. \u7f16\u5199\u4e00\u4e2aC++\u7a0b\u5e8f\uff0c\u5c06float32\u6743\u91cd\u91cf\u5316\u4e3aINT8\uff0c\u5e76\u6d4b\u91cf\u538b\u7f29\u6bd4\u548c\u91cf\u5316\u8bef\u5dee\u3002

    // task2_quantise.cpp\n// \u7f16\u8bd1\uff1ag++ -O3 -o task2 task2_quantise.cpp\n\n#include <iostream>\n#include <vector>\n#include <cmath>\n#include <algorithm>\n#include <numeric>\n\n// \u5bf9\u79f0\u91cf\u5316\uff1a\u5c06\u6d6e\u70b9\u8303\u56f4 [-max, +max] \u6620\u5c04\u5230 [-127, +127]\nvoid quantise_symmetric(const float* input, int8_t* output, int n, float& scale) {\n    float max_val = 0.0f;\n    for (int i = 0; i < n; i++) {\n        max_val = std::max(max_val, std::abs(input[i]));\n    }\n    scale = max_val / 127.0f;\n    for (int i = 0; i < n; i++) {\n        float scaled = input[i] / scale;\n        output[i] = (int8_t)std::max(-127.0f, std::min(127.0f, std::round(scaled)));\n    }\n}\n\n// \u53cd\u91cf\u5316\uff1aINT8\u8f6c\u56defloat\nvoid dequantise(const int8_t* input, float* output, int n, float scale) {\n    for (int i = 0; i < n; i++) {\n        output[i] = (float)input[i] * scale;\n    }\n}\n\nint main() {\n    const int N = 100000;\n\n    // \u6a21\u62df\u968f\u673a\u6743\u91cd\uff08\u5927\u81f4\u6b63\u6001\u5206\u5e03\uff09\n    std::vector<float> weights(N);\n    for (int i = 0; i < N; i++) {\n        // \u7b80\u5355\u7684\u4f2a\u968f\u673a\u6b63\u6001\u503c\n        float u1 = (float)(i * 7 % 997 + 1) / 998.0f;\n        float u2 = (float)(i * 13 % 991 + 1) / 992.0f;\n        weights[i] = std::sqrt(-2.0f * std::log(u1)) * std::cos(6.2832f * u2) * 0.1f;\n    }\n\n    // \u91cf\u5316\n    std::vector<int8_t> quantised(N);\n    float scale;\n    quantise_symmetric(weights.data(), quantised.data(), N, scale);\n\n    // \u53cd\u91cf\u5316\u5e76\u6d4b\u91cf\u8bef\u5dee\n    std::vector<float> reconstructed(N);\n    dequantise(quantised.data(), reconstructed.data(), N, scale);\n\n    float max_error = 0.0f, total_error = 0.0f;\n    for (int i = 0; i < N; i++) {\n        float err = std::abs(weights[i] - reconstructed[i]);\n        max_error = std::max(max_error, err);\n        total_error += err;\n    }\n\n    std::cout << \"=== 
\u91cf\u5316\u7ed3\u679c ===\\n\";\n    std::cout << \"\u539f\u59cb:    \" << N * 4 << \" \u5b57\u8282\uff08float32\uff09\\n\";\n    std::cout << \"\u91cf\u5316:   \" << N * 1 << \" \u5b57\u8282\uff08int8\uff09+ 4 \u5b57\u8282\uff08\u7f29\u653e\u56e0\u5b50\uff09\\n\";\n    std::cout << \"\u538b\u7f29\u6bd4: \" << 4.0f << \"x\\n\";\n    std::cout << \"\u7f29\u653e\u56e0\u5b50: \" << scale << \"\\n\";\n    std::cout << \"\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee: \" << total_error / N << \"\\n\";\n    std::cout << \"\u6700\u5927\u7edd\u5bf9\u8bef\u5dee:  \" << max_error << \"\\n\";\n    std::cout << \"\u6700\u5927\u7edd\u5bf9\u8bef\u5dee/\u7f29\u653e\u56e0\u5b50: \" << max_error / scale\n              << \"\uff08\u5e94 <= 0.5 \u91cf\u5316\u7ea7\u522b\uff09\\n\";\n\n    return 0;\n}\n

  3. \u7f16\u5199\u4e00\u4e2aC++\u7a0b\u5e8f\uff0c\u6267\u884cINT8\u77e9\u9635\u4e58\u6cd5\uff08INT32\u7d2f\u52a0\uff09\u2014\u2014\u8fd9\u662f\u5728\u5d4c\u5165\u5f0fML\u52a0\u901f\u5668\u4e0a\u8fd0\u884c\u7684\u5b9e\u9645\u8ba1\u7b97\u3002

    // task3_int8_matmul.cpp\n// \u7f16\u8bd1\uff1ag++ -O3 -o task3 task3_int8_matmul.cpp\n\n#include <iostream>\n#include <chrono>\n#include <vector>\n#include <cstdint>\n\n// INT8\u77e9\u9635\u4e58\u6cd5\uff08INT32\u7d2f\u52a0\uff09\u2014\u2014\u5f20\u91cf\u6838\u5fc3\u548cMCU\u52a0\u901f\u5668\u7684\u5b9e\u9645\u5de5\u4f5c\u65b9\u5f0f\nvoid matmul_int8(const int8_t* A, const int8_t* B, int32_t* C,\n                 int M, int N, int K) {\n    for (int i = 0; i < M; i++) {\n        for (int j = 0; j < N; j++) {\n            int32_t sum = 0;\n            for (int k = 0; k < K; k++) {\n                sum += (int32_t)A[i * K + k] * (int32_t)B[k * N + j];\n            }\n            C[i * N + j] = sum;\n        }\n    }\n}\n\n// \u7528\u4e8e\u6bd4\u8f83\u7684Float32\u77e9\u9635\u4e58\u6cd5\nvoid matmul_f32(const float* A, const float* B, float* C,\n                int M, int N, int K) {\n    for (int i = 0; i < M; i++) {\n        for (int j = 0; j < N; j++) {\n            float sum = 0.0f;\n            for (int k = 0; k < K; k++) {\n                sum += A[i * K + k] * B[k * N + j];\n            }\n            C[i * N + j] = sum;\n        }\n    }\n}\n\nint main() {\n    const int M = 128, N = 128, K = 128;\n\n    std::vector<int8_t> A_i8(M * K, 1), B_i8(K * N, 1);\n    std::vector<int32_t> C_i32(M * N);\n\n    std::vector<float> A_f32(M * K, 1.0f), B_f32(K * N, 1.0f);\n    std::vector<float> C_f32(M * N);\n\n    // \u57fa\u51c6\u6d4b\u8bd5INT8\n    auto start = std::chrono::high_resolution_clock::now();\n    for (int t = 0; t < 100; t++) {\n        matmul_int8(A_i8.data(), B_i8.data(), C_i32.data(), M, N, K);\n    }\n    auto end = std::chrono::high_resolution_clock::now();\n    double i8_ms = std::chrono::duration<double, std::milli>(end - start).count() / 100;\n\n    // \u57fa\u51c6\u6d4b\u8bd5FP32\n    start = std::chrono::high_resolution_clock::now();\n    for (int t = 0; t < 100; t++) {\n        matmul_f32(A_f32.data(), B_f32.data(), C_f32.data(), M, N, K);\n 
   }\n    end = std::chrono::high_resolution_clock::now();\n    double f32_ms = std::chrono::duration<double, std::milli>(end - start).count() / 100;\n\n    double gops_i8 = 2.0 * M * N * K / i8_ms / 1e6;\n    double gflops_f32 = 2.0 * M * N * K / f32_ms / 1e6;\n\n    std::cout << \"INT8\u77e9\u9635\u4e58\u6cd5:  \" << i8_ms << \" ms\uff08\" << gops_i8 << \" GOPS\uff09\\n\";\n    std::cout << \"FP32\u77e9\u9635\u4e58\u6cd5:  \" << f32_ms << \" ms\uff08\" << gflops_f32 << \" GFLOPS\uff09\\n\";\n    std::cout << \"INT8\u52a0\u901f\u6bd4: \" << f32_ms / i8_ms << \"x\\n\";\n    std::cout << \"\u5185\u5b58: INT8 = \" << M*K + K*N << \" \u5b57\u8282 vs FP32 = \"\n              << (M*K + K*N) * 4 << \" \u5b57\u8282\uff08\u5c0f4\u500d\uff09\\n\";\n\n    return 0;\n}\n

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/","title":"Vulkan Compute \u4e0e\u8de8\u5e73\u53f0 GPU","text":"

Vulkan \u662f\u552f\u4e00\u80fd\u5728\u6240\u6709\u4e3b\u8981\u5e73\u53f0\u4e0a\u8fd0\u884c\u7684 GPU \u8ba1\u7b97 API\uff1aNVIDIA\u3001AMD\u3001Intel\u3001Apple\uff08\u901a\u8fc7 MoltenVK\uff09\u3001Android\uff0c\u751a\u81f3\u6d4f\u89c8\u5668\uff08\u901a\u8fc7 WebGPU\uff09\u3002\u672c\u6587\u6db5\u76d6 Vulkan \u67b6\u6784\u3001\u8ba1\u7b97\u7ba1\u7ebf\u3001\u4f7f\u7528 GLSL \u7f16\u5199\u8ba1\u7b97\u7740\u8272\u5668\u3001GPU \u8ba1\u7b97\u7a0b\u5e8f\u7684\u5b8c\u6574 C++ \u8bbe\u7f6e\u3001\u5171\u4eab\u5185\u5b58\u4e0e\u540c\u6b65\u3001\u7528\u4e8e\u6d4f\u89c8\u5668\u7684 WebGPU\uff0c\u4ee5\u53ca\u5b9e\u9645\u7684\u673a\u5668\u5b66\u4e60\u63a8\u7406\u793a\u4f8b\u3002

"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/#vulkan","title":"Vulkan \u67b6\u6784\u6982\u8ff0","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/#_1","title":"\u4e3a\u4ec0\u4e48\u5982\u6b64\u5197\u957f\uff1f","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/#glsl","title":"GLSL \u4e2d\u7684\u8ba1\u7b97\u7740\u8272\u5668","text":""},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/#_2","title":"\u5411\u91cf\u52a0\u6cd5","text":"
// add.comp \u2014 \u7f16\u8bd1\u547d\u4ee4: glslangValidator -V add.comp -o add.spv\n#version 450\n\n// \u5de5\u4f5c\u7ec4\u5927\u5c0f\uff1a\u6bcf\u4e2a\u5de5\u4f5c\u7ec4\u6709 256 \u4e2a\u8c03\u7528\uff08= CUDA \u4e2d\u6bcf\u5757\u7684\u7ebf\u7a0b\u6570\uff09\nlayout(local_size_x = 256) in;\n\n// \u7f13\u51b2\u533a\u7ed1\u5b9a\uff08\u7c7b\u4f3c\u4e8e\u5185\u6838\u53c2\u6570\uff09\nlayout(set = 0, binding = 0) buffer InputA { float a[]; };\nlayout(set = 0, binding = 1) buffer InputB { float b[]; };\nlayout(set = 0, binding = 2) buffer Output { float c[]; };\n\n// \u63a8\u9001\u5e38\u91cf\uff1a\u5c0f\u7684\u7edf\u4e00\u6570\u636e\uff08\u7c7b\u4f3c\u4e8e\u5185\u6838\u53c2\u6570\uff09\nlayout(push_constant) uniform PushConstants {\n    uint n;  // \u5143\u7d20\u6570\u91cf\n};\n\nvoid main() {\n    uint idx = gl_GlobalInvocationID.x;  // \u5168\u5c40\u7ebf\u7a0b\u7d22\u5f15\n    if (idx < n) {\n        c[idx] = a[idx] + b[idx];\n    }\n}\n
Vulkan CUDA \u542b\u4e49 \u5de5\u4f5c\u7ec4 (Workgroup) \u5757 (Block) \u53ef\u4ee5\u5171\u4eab\u5185\u5b58\u7684\u7ebf\u7a0b\u7ec4 \u8c03\u7528 (Invocation) \u7ebf\u7a0b (Thread) \u5355\u4e2a\u6267\u884c\u5355\u5143 gl_GlobalInvocationID blockIdx * blockDim + threadIdx \u5168\u5c40\u7ebf\u7a0b\u7d22\u5f15 gl_LocalInvocationID threadIdx \u5de5\u4f5c\u7ec4\u5185\u7684\u7ebf\u7a0b\u7d22\u5f15 gl_WorkGroupID blockIdx \u5de5\u4f5c\u7ec4\u7d22\u5f15 local_size_x blockDim.x \u6bcf\u5de5\u4f5c\u7ec4\u7684\u7ebf\u7a0b\u6570 \u5b58\u50a8\u7f13\u51b2\u533a \u5168\u5c40\u5185\u5b58 \u53ef\u8bfb\u5199\u7684 GPU \u5185\u5b58 \u5171\u4eab\u5185\u5b58 (shared) __shared__ \u6bcf\u5de5\u4f5c\u7ec4\u7684\u9ad8\u901f\u5185\u5b58 \u63a8\u9001\u5e38\u91cf \u5185\u6838\u53c2\u6570 \u5c0f\u7684\u7edf\u4e00\u6570\u636e"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/#relu","title":"\u4f7f\u7528\u5171\u4eab\u5185\u5b58\u7684 ReLU","text":"
// relu_shared.comp\n#version 450\n\nlayout(local_size_x = 256) in;\n\nlayout(set = 0, binding = 0) buffer Input  { float input_data[]; };\nlayout(set = 0, binding = 1) buffer Output { float output_data[]; };\n\nlayout(push_constant) uniform PushConstants { uint n; };\n\n// \u5171\u4eab\u5185\u5b58\uff08\u7b49\u540c\u4e8e CUDA \u7684 __shared__\uff09\nshared float tile[256];\n\nvoid main() {\n    uint gid = gl_GlobalInvocationID.x;\n    uint lid = gl_LocalInvocationID.x;\n\n    // \u52a0\u8f7d\u5230\u5171\u4eab\u5185\u5b58\n    if (gid < n) {\n        tile[lid] = input_data[gid];\n    }\n\n    // \u5c4f\u969c\uff1a\u7b49\u5f85\u5de5\u4f5c\u7ec4\u4e2d\u6240\u6709\u8c03\u7528\u5b8c\u6210\u52a0\u8f7d\n    barrier();  // \u7b49\u540c\u4e8e CUDA \u7684 __syncthreads()\n\n    // \u8ba1\u7b97 ReLU\n    if (gid < n) {\n        output_data[gid] = max(tile[lid], 0.0);\n    }\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/#_3","title":"\u5e76\u884c\u5f52\u7ea6\uff08\u6c42\u548c\uff09","text":"
// reduce_sum.comp\n#version 450\n\nlayout(local_size_x = 256) in;\n\nlayout(set = 0, binding = 0) buffer Input  { float input_data[]; };\nlayout(set = 0, binding = 1) buffer Output { float partial_sums[]; };\n\nlayout(push_constant) uniform PushConstants { uint n; };\n\nshared float sdata[256];\n\nvoid main() {\n    uint gid = gl_GlobalInvocationID.x;\n    uint lid = gl_LocalInvocationID.x;\n    uint wgid = gl_WorkGroupID.x;\n\n    // \u52a0\u8f7d\u5230\u5171\u4eab\u5185\u5b58\n    sdata[lid] = (gid < n) ? input_data[gid] : 0.0;\n    barrier();\n\n    // \u5de5\u4f5c\u7ec4\u5185\u7684\u6811\u5f62\u5f52\u7ea6\n    for (uint stride = 128; stride > 0; stride >>= 1) {\n        if (lid < stride) {\n            sdata[lid] += sdata[lid + stride];\n        }\n        barrier();\n    }\n\n    // \u7ebf\u7a0b 0 \u5199\u5165\u5de5\u4f5c\u7ec4\u7684\u5c40\u90e8\u548c\n    if (lid == 0) {\n        partial_sums[wgid] = sdata[0];\n    }\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/#_4","title":"\u4f7f\u7528\u5206\u5757\u7684\u77e9\u9635\u4e58\u6cd5","text":"
// matmul_tiled.comp\n#version 450\n\n#define TILE_SIZE 16\n\nlayout(local_size_x = TILE_SIZE, local_size_y = TILE_SIZE) in;\n\nlayout(set = 0, binding = 0) buffer MatA { float A[]; };\nlayout(set = 0, binding = 1) buffer MatB { float B[]; };\nlayout(set = 0, binding = 2) buffer MatC { float C[]; };\n\nlayout(push_constant) uniform PushConstants {\n    uint M, N, K;\n};\n\nshared float tileA[TILE_SIZE][TILE_SIZE];\nshared float tileB[TILE_SIZE][TILE_SIZE];\n\nvoid main() {\n    uint row = gl_GlobalInvocationID.y;\n    uint col = gl_GlobalInvocationID.x;\n    uint lr = gl_LocalInvocationID.y;\n    uint lc = gl_LocalInvocationID.x;\n\n    float sum = 0.0;\n\n    for (uint t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {\n        // \u5c06 A \u548c B \u7684\u5206\u5757\u52a0\u8f7d\u5230\u5171\u4eab\u5185\u5b58\u4e2d\n        uint aCol = t * TILE_SIZE + lc;\n        uint bRow = t * TILE_SIZE + lr;\n\n        tileA[lr][lc] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0;\n        tileB[lr][lc] = (bRow < K && col < N) ? B[bRow * N + col] : 0.0;\n\n        barrier();\n\n        // \u8ba1\u7b97\u90e8\u5206\u70b9\u79ef\n        for (uint k = 0; k < TILE_SIZE; k++) {\n            sum += tileA[lr][k] * tileB[k][lc];\n        }\n\n        barrier();\n    }\n\n    if (row < M && col < N) {\n        C[row * N + col] = sum;\n    }\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/#c-vulkan","title":"C++ Vulkan \u8bbe\u7f6e","text":"
// vulkan_compute.cpp \u2014 \u4e00\u4e2a\u6700\u5c0f\u4f46\u5b8c\u6574\u7684 Vulkan \u8ba1\u7b97\u793a\u4f8b\n// \u7f16\u8bd1\u547d\u4ee4: g++ -O3 -o vulkan_compute vulkan_compute.cpp -lvulkan\n// \u8981\u6c42: \u5df2\u5b89\u88c5 Vulkan SDK\uff0c\u5df2\u4ece add.comp \u7f16\u8bd1 add.spv\n\n#include <vulkan/vulkan.h>\n#include <iostream>\n#include <vector>\n#include <fstream>\n#include <cassert>\n\n// \u8f85\u52a9\u51fd\u6570\uff1a\u8bfb\u53d6 SPIR-V \u6587\u4ef6\nstd::vector<uint32_t> readSPIRV(const std::string& filename) {\n    std::ifstream file(filename, std::ios::ate | std::ios::binary);\n    size_t fileSize = file.tellg();\n    std::vector<uint32_t> buffer(fileSize / sizeof(uint32_t));\n    file.seekg(0);\n    file.read(reinterpret_cast<char*>(buffer.data()), fileSize);\n    return buffer;\n}\n\nint main() {\n    const uint32_t N = 1024;\n    const size_t bufferSize = N * sizeof(float);\n\n    // ========== 1. \u521b\u5efa Vulkan \u5b9e\u4f8b ==========\n    VkApplicationInfo appInfo{};\n    appInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;\n    appInfo.apiVersion = VK_API_VERSION_1_2;\n\n    VkInstanceCreateInfo instanceInfo{};\n    instanceInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;\n    instanceInfo.pApplicationInfo = &appInfo;\n\n    VkInstance instance;\n    vkCreateInstance(&instanceInfo, nullptr, &instance);\n\n    // ========== 2. \u9009\u62e9\u7269\u7406\u8bbe\u5907 (GPU) ==========\n    uint32_t deviceCount = 0;\n    vkEnumeratePhysicalDevices(instance, &deviceCount, nullptr);\n    std::vector<VkPhysicalDevice> devices(deviceCount);\n    vkEnumeratePhysicalDevices(instance, &deviceCount, devices.data());\n    VkPhysicalDevice physicalDevice = devices[0];  // \u4f7f\u7528\u7b2c\u4e00\u4e2a GPU\n\n    // \u6253\u5370 GPU \u540d\u79f0\n    VkPhysicalDeviceProperties props;\n    vkGetPhysicalDeviceProperties(physicalDevice, &props);\n    std::cout << \"\u4f7f\u7528\u7684 GPU: \" << props.deviceName << \"\\n\";\n\n    // ========== 3. 
\u67e5\u627e\u8ba1\u7b97\u961f\u5217\u65cf ==========\n    uint32_t queueFamilyCount = 0;\n    vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, &queueFamilyCount, nullptr);\n    std::vector<VkQueueFamilyProperties> queueFamilies(queueFamilyCount);\n    vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, &queueFamilyCount, queueFamilies.data());\n\n    uint32_t computeFamily = 0;\n    for (uint32_t i = 0; i < queueFamilyCount; i++) {\n        if (queueFamilies[i].queueFlags & VK_QUEUE_COMPUTE_BIT) {\n            computeFamily = i;\n            break;\n        }\n    }\n\n    // ========== 4. \u521b\u5efa\u903b\u8f91\u8bbe\u5907\u548c\u961f\u5217 ==========\n    float queuePriority = 1.0f;\n    VkDeviceQueueCreateInfo queueInfo{};\n    queueInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;\n    queueInfo.queueFamilyIndex = computeFamily;\n    queueInfo.queueCount = 1;\n    queueInfo.pQueuePriorities = &queuePriority;\n\n    VkDeviceCreateInfo deviceInfo{};\n    deviceInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;\n    deviceInfo.queueCreateInfoCount = 1;\n    deviceInfo.pQueueCreateInfos = &queueInfo;\n\n    VkDevice device;\n    vkCreateDevice(physicalDevice, &deviceInfo, nullptr, &device);\n\n    VkQueue computeQueue;\n    vkGetDeviceQueue(device, computeFamily, 0, &computeQueue);\n\n    // ========== 5. 
\u5206\u914d\u7f13\u51b2\u533a (A, B, C) ==========\n    // \u4e3a\u7b80\u6d01\u8d77\u89c1\uff0c\u8fd9\u91cc\u4f7f\u7528\u4e3b\u673a\u53ef\u89c1\u5185\u5b58\uff08\u8f83\u6162\u4f46\u66f4\u7b80\u5355\uff09\n    auto createBuffer = [&](VkBuffer& buffer, VkDeviceMemory& memory) {\n        VkBufferCreateInfo bufInfo{};\n        bufInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;\n        bufInfo.size = bufferSize;\n        bufInfo.usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;\n        vkCreateBuffer(device, &bufInfo, nullptr, &buffer);\n\n        VkMemoryRequirements memReqs;\n        vkGetBufferMemoryRequirements(device, buffer, &memReqs);\n\n        // \u67e5\u627e\u4e3b\u673a\u53ef\u89c1\u7684\u5185\u5b58\u7c7b\u578b\n        VkPhysicalDeviceMemoryProperties memProps;\n        vkGetPhysicalDeviceMemoryProperties(physicalDevice, &memProps);\n        uint32_t memType = 0;\n        for (uint32_t i = 0; i < memProps.memoryTypeCount; i++) {\n            if ((memReqs.memoryTypeBits & (1 << i)) &&\n                (memProps.memoryTypes[i].propertyFlags &\n                 (VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT))) {\n                memType = i;\n                break;\n            }\n        }\n\n        VkMemoryAllocateInfo allocInfo{};\n        allocInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;\n        allocInfo.allocationSize = memReqs.size;\n        allocInfo.memoryTypeIndex = memType;\n        vkAllocateMemory(device, &allocInfo, nullptr, &memory);\n        vkBindBufferMemory(device, buffer, memory, 0);\n    };\n\n    VkBuffer bufA, bufB, bufC;\n    VkDeviceMemory memA, memB, memC;\n    createBuffer(bufA, memA);\n    createBuffer(bufB, memB);\n    createBuffer(bufC, memC);\n\n    // ========== 6. 
\u586b\u5145\u8f93\u5165\u7f13\u51b2\u533a ==========\n    float* ptrA;\n    vkMapMemory(device, memA, 0, bufferSize, 0, (void**)&ptrA);\n    for (uint32_t i = 0; i < N; i++) ptrA[i] = 1.0f;\n    vkUnmapMemory(device, memA);\n\n    float* ptrB;\n    vkMapMemory(device, memB, 0, bufferSize, 0, (void**)&ptrB);\n    for (uint32_t i = 0; i < N; i++) ptrB[i] = 2.0f;\n    vkUnmapMemory(device, memB);\n\n    // ========== 7. \u521b\u5efa\u8ba1\u7b97\u7ba1\u7ebf ==========\n    auto spirvCode = readSPIRV(\"add.spv\");\n    VkShaderModuleCreateInfo shaderInfo{};\n    shaderInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;\n    shaderInfo.codeSize = spirvCode.size() * sizeof(uint32_t);\n    shaderInfo.pCode = spirvCode.data();\n    VkShaderModule shaderModule;\n    vkCreateShaderModule(device, &shaderInfo, nullptr, &shaderModule);\n\n    // \u63cf\u8ff0\u7b26\u96c6\u5e03\u5c40\uff08\u544a\u8bc9 Vulkan \u7f13\u51b2\u533a\u7ed1\u5b9a\u7684\u4fe1\u606f\uff09\n    VkDescriptorSetLayoutBinding bindings[3] = {};\n    for (int i = 0; i < 3; i++) {\n        bindings[i].binding = i;\n        bindings[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;\n        bindings[i].descriptorCount = 1;\n        bindings[i].stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;\n    }\n\n    VkDescriptorSetLayoutCreateInfo layoutInfo{};\n    layoutInfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;\n    layoutInfo.bindingCount = 3;\n    layoutInfo.pBindings = bindings;\n    VkDescriptorSetLayout descLayout;\n    vkCreateDescriptorSetLayout(device, &layoutInfo, nullptr, &descLayout);\n\n    // \u63a8\u9001\u5e38\u91cf\u8303\u56f4\n    VkPushConstantRange pushRange{};\n    pushRange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;\n    pushRange.offset = 0;\n    pushRange.size = sizeof(uint32_t);\n\n    // \u7ba1\u7ebf\u5e03\u5c40\n    VkPipelineLayoutCreateInfo pipeLayoutInfo{};\n    pipeLayoutInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;\n    pipeLayoutInfo.setLayoutCount = 
1;\n    pipeLayoutInfo.pSetLayouts = &descLayout;\n    pipeLayoutInfo.pushConstantRangeCount = 1;\n    pipeLayoutInfo.pPushConstantRanges = &pushRange;\n    VkPipelineLayout pipelineLayout;\n    vkCreatePipelineLayout(device, &pipeLayoutInfo, nullptr, &pipelineLayout);\n\n    // \u8ba1\u7b97\u7ba1\u7ebf\n    VkComputePipelineCreateInfo pipeInfo{};\n    pipeInfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;\n    pipeInfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;\n    pipeInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;\n    pipeInfo.stage.module = shaderModule;\n    pipeInfo.stage.pName = \"main\";\n    pipeInfo.layout = pipelineLayout;\n    VkPipeline pipeline;\n    vkCreateComputePipelines(device, VK_NULL_HANDLE, 1, &pipeInfo, nullptr, &pipeline);\n\n    // ========== 8. \u63cf\u8ff0\u7b26\u96c6\uff08\u5c06\u7f13\u51b2\u533a\u7ed1\u5b9a\u5230\u7740\u8272\u5668\uff09 ==========\n    VkDescriptorPoolSize poolSize{};\n    poolSize.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;\n    poolSize.descriptorCount = 3;\n\n    VkDescriptorPoolCreateInfo poolInfo{};\n    poolInfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;\n    poolInfo.maxSets = 1;\n    poolInfo.poolSizeCount = 1;\n    poolInfo.pPoolSizes = &poolSize;\n    VkDescriptorPool descPool;\n    vkCreateDescriptorPool(device, &poolInfo, nullptr, &descPool);\n\n    VkDescriptorSetAllocateInfo descAllocInfo{};\n    descAllocInfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;\n    descAllocInfo.descriptorPool = descPool;\n    descAllocInfo.descriptorSetCount = 1;\n    descAllocInfo.pSetLayouts = &descLayout;\n    VkDescriptorSet descSet;\n    vkAllocateDescriptorSets(device, &descAllocInfo, &descSet);\n\n    // \u5c06\u7f13\u51b2\u533a\u5f15\u7528\u5199\u5165\u63cf\u8ff0\u7b26\u96c6\n    VkDescriptorBufferInfo bufInfos[3] = {\n        {bufA, 0, bufferSize}, {bufB, 0, bufferSize}, {bufC, 0, bufferSize}\n    };\n    VkWriteDescriptorSet writes[3] = {};\n    for (int i = 0; 
i < 3; i++) {\n        writes[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;\n        writes[i].dstSet = descSet;\n        writes[i].dstBinding = i;\n        writes[i].descriptorCount = 1;\n        writes[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;\n        writes[i].pBufferInfo = &bufInfos[i];\n    }\n    vkUpdateDescriptorSets(device, 3, writes, 0, nullptr);\n\n    // ========== 9. \u8bb0\u5f55\u548c\u63d0\u4ea4\u547d\u4ee4\u7f13\u51b2\u533a ==========\n    VkCommandPoolCreateInfo cmdPoolInfo{};\n    cmdPoolInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;\n    cmdPoolInfo.queueFamilyIndex = computeFamily;\n    VkCommandPool cmdPool;\n    vkCreateCommandPool(device, &cmdPoolInfo, nullptr, &cmdPool);\n\n    VkCommandBufferAllocateInfo cmdAllocInfo{};\n    cmdAllocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;\n    cmdAllocInfo.commandPool = cmdPool;\n    cmdAllocInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;\n    cmdAllocInfo.commandBufferCount = 1;\n    VkCommandBuffer cmdBuf;\n    vkAllocateCommandBuffers(device, &cmdAllocInfo, &cmdBuf);\n\n    VkCommandBufferBeginInfo beginInfo{};\n    beginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;\n    vkBeginCommandBuffer(cmdBuf, &beginInfo);\n\n    vkCmdBindPipeline(cmdBuf, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);\n    vkCmdBindDescriptorSets(cmdBuf, VK_PIPELINE_BIND_POINT_COMPUTE,\n                            pipelineLayout, 0, 1, &descSet, 0, nullptr);\n    vkCmdPushConstants(cmdBuf, pipelineLayout, VK_SHADER_STAGE_COMPUTE_BIT,\n                       0, sizeof(uint32_t), &N);\n    vkCmdDispatch(cmdBuf, (N + 255) / 256, 1, 1);  // \u542f\u52a8\u5de5\u4f5c\u7ec4\n\n    vkEndCommandBuffer(cmdBuf);\n\n    // \u63d0\u4ea4\n    VkFenceCreateInfo fenceInfo{};\n    fenceInfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;\n    VkFence fence;\n    vkCreateFence(device, &fenceInfo, nullptr, &fence);\n\n    VkSubmitInfo submitInfo{};\n    submitInfo.sType = 
VK_STRUCTURE_TYPE_SUBMIT_INFO;\n    submitInfo.commandBufferCount = 1;\n    submitInfo.pCommandBuffers = &cmdBuf;\n    vkQueueSubmit(computeQueue, 1, &submitInfo, fence);\n    vkWaitForFences(device, 1, &fence, VK_TRUE, UINT64_MAX);\n\n    // ========== 10. \u8bfb\u53d6\u7ed3\u679c ==========\n    float* ptrC;\n    vkMapMemory(device, memC, 0, bufferSize, 0, (void**)&ptrC);\n    std::cout << \"\u7ed3\u679c: c[0]=\" << ptrC[0] << \" c[1]=\" << ptrC[1]\n              << \" (\u671f\u671b\u503c 3.0)\\n\";\n    bool correct = true;\n    for (uint32_t i = 0; i < N; i++) {\n        if (ptrC[i] != 3.0f) { correct = false; break; }\n    }\n    std::cout << (correct ? \"\u5168\u90e8\u6b63\u786e\" : \"\u53d1\u73b0\u9519\u8bef\") << \"\\n\";\n    vkUnmapMemory(device, memC);\n\n    // ========== \u6e05\u7406\uff08\u7b80\u5199\uff09 ==========\n    vkDestroyFence(device, fence, nullptr);\n    vkDestroyCommandPool(device, cmdPool, nullptr);\n    vkDestroyPipeline(device, pipeline, nullptr);\n    vkDestroyPipelineLayout(device, pipelineLayout, nullptr);\n    vkDestroyDescriptorPool(device, descPool, nullptr);\n    vkDestroyDescriptorSetLayout(device, descLayout, nullptr);\n    vkDestroyShaderModule(device, shaderModule, nullptr);\n    vkDestroyBuffer(device, bufA, nullptr); vkFreeMemory(device, memA, nullptr);\n    vkDestroyBuffer(device, bufB, nullptr); vkFreeMemory(device, memB, nullptr);\n    vkDestroyBuffer(device, bufC, nullptr); vkFreeMemory(device, memC, nullptr);\n    vkDestroyDevice(device, nullptr);\n    vkDestroyInstance(instance, nullptr);\n\n    return 0;\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/#kompute-ml-vulkan","title":"Kompute\uff1a\u4e3a ML \u7b80\u5316\u7684 Vulkan","text":"
#include <kompute/Kompute.hpp>\n\nint main() {\n    kp::Manager mgr;\n\n    auto tensorA = mgr.tensor({1, 1, 1, 1, 1});\n    auto tensorB = mgr.tensor({2, 2, 2, 2, 2});\n    auto tensorC = mgr.tensor({0, 0, 0, 0, 0});\n\n    std::string shader = R\"(\n        #version 450\n        layout(local_size_x = 1) in;\n        layout(set=0, binding=0) buffer A { float a[]; };\n        layout(set=0, binding=1) buffer B { float b[]; };\n        layout(set=0, binding=2) buffer C { float c[]; };\n        void main() {\n            uint i = gl_GlobalInvocationID.x;\n            c[i] = a[i] + b[i];\n        }\n    )\";\n\n    auto algorithm = mgr.algorithm({tensorA, tensorB, tensorC},\n                                     kompute::Shader::compile_source(shader));\n\n    mgr.sequence()\n        ->record<kp::OpTensorSyncDevice>({tensorA, tensorB, tensorC})\n        ->record<kp::OpAlgoDispatch>(algorithm)\n        ->record<kp::OpTensorSyncLocal>({tensorC})\n        ->eval();\n\n    // tensorC \u73b0\u5728\u5305\u542b [3, 3, 3, 3, 3]\n}\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/#webgpu-gpu","title":"WebGPU\uff1a\u6d4f\u89c8\u5668\u4e2d\u7684 GPU \u8ba1\u7b97","text":"
// add.wgsl \u2014 WebGPU \u8ba1\u7b97\u7740\u8272\u5668\n@group(0) @binding(0) var<storage, read> a: array<f32>;\n@group(0) @binding(1) var<storage, read> b: array<f32>;\n@group(0) @binding(2) var<storage, read_write> c: array<f32>;\n\n@compute @workgroup_size(256)\nfn main(@builtin(global_invocation_id) id: vec3<u32>) {\n    let i = id.x;\n    c[i] = a[i] + b[i];\n}\n
const adapter = await navigator.gpu.requestAdapter();\nconst device = await adapter.requestDevice();\n\n// \u521b\u5efa\u7f13\u51b2\u533a\nconst bufferA = device.createBuffer({ size: N * 4, usage: GPUBufferUsage.STORAGE, mappedAtCreation: true });\nnew Float32Array(bufferA.getMappedRange()).fill(1.0);\nbufferA.unmap();\n\n// ...\uff08B \u548c C \u7c7b\u4f3c\uff09\n\n// \u4ece WGSL \u7740\u8272\u5668\u521b\u5efa\u7ba1\u7ebf\nconst pipeline = device.createComputePipeline({\n    layout: 'auto',\n    compute: { module: device.createShaderModule({ code: wgslSource }), entryPoint: 'main' }\n});\n\n// \u8c03\u5ea6\nconst encoder = device.createCommandEncoder();\nconst pass = encoder.beginComputePass();\npass.setPipeline(pipeline);\npass.setBindGroup(0, bindGroup);\npass.dispatchWorkgroups(Math.ceil(N / 256));\npass.end();\ndevice.queue.submit([encoder.finish()]);\n
"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/#vulkan_1","title":"\u4f55\u65f6\u4f7f\u7528 Vulkan","text":"\u573a\u666f \u4f7f\u7528 Vulkan\uff1f \u539f\u56e0 / \u66ff\u4ee3\u65b9\u6848 ML \u8bad\u7ec3 \u5426 CUDA/Triton \u5728 NVIDIA \u4e0a\u66f4\u7b80\u5355\u66f4\u5feb\u901f NVIDIA GPU \u4e0a\u7684\u63a8\u7406 \u5426 TensorRT \u6216 CUDA \u66f4\u597d AMD/Intel GPU \u4e0a\u7684\u63a8\u7406 \u662f \u552f\u4e00\u8de8\u5382\u5546\u7684 GPU \u8ba1\u7b97\u9009\u9879 \u79fb\u52a8\u7aef\u63a8\u7406\uff08Android\uff09 \u662f Vulkan \u662f Android \u4e0a\u7684\u6807\u51c6 GPU API \u79fb\u52a8\u7aef\u63a8\u7406\uff08iOS\uff09 \u5426 \u76f4\u63a5\u4f7f\u7528 Metal\uff08MoltenVK \u589e\u52a0\u5f00\u9500\uff09 \u6d4f\u89c8\u5668\u63a8\u7406 WebGPU \u57fa\u4e8e Vulkan/Metal/DX12 \u6e38\u620f\u5f15\u64ce + ML \u662f \u5f15\u64ce\u5df2\u4f7f\u7528 Vulkan \u8fdb\u884c\u6e32\u67d3 \u8de8\u5e73\u53f0\u5e93 \u662f \u4e00\u5957\u4ee3\u7801\u652f\u6301\u6240\u6709 GPU \u5382\u5546 \u5b66\u4e60 GPU \u7f16\u7a0b \u89c6\u60c5\u51b5\u800c\u5b9a CUDA \u66f4\u5bb9\u6613\u4e0a\u624b\uff1bVulkan \u80fd\u5b66\u5230\u66f4\u591a"},{"location":"chapter%2016%3A%20SIMD%20and%20GPU%20programming/07.%20vulkan%20compute%20and%20cross-platform%20GPU/#g-lvulkan-vulkan-sdk","title":"\u7f16\u7801\u4efb\u52a1\uff08\u4f7f\u7528 g++ -lvulkan \u7f16\u8bd1\uff0c\u9700\u8981 Vulkan SDK\uff09","text":"
  1. \u7f16\u8bd1\u5e76\u8fd0\u884c\u4e0a\u9762\u7684\u5411\u91cf\u52a0\u6cd5\u793a\u4f8b\u3002\u4fee\u6539\u7740\u8272\u5668\u4ee5\u8ba1\u7b97 c[i] = a[i] * b[i] + a[i]\uff08\u878d\u5408\u4e58\u52a0\uff09\u5e76\u9a8c\u8bc1\u7ed3\u679c\u3002

  2. \u7f16\u5199\u4e00\u4e2a\u8ba1\u7b97\u7740\u8272\u5668\uff0c\u4f7f\u7528\u5171\u4eab\u5185\u5b58\u5bf9\u4e00\u884c\u6570\u636e\u5e94\u7528 softmax\uff08\u5305\u62ec\u6700\u5927\u503c\u548c\u6c42\u548c\u5f52\u7ea6\u6b65\u9aa4\uff09\u3002\u7528\u5df2\u77e5\u503c\u8fdb\u884c\u6d4b\u8bd5\u3002

// softmax.comp \u2014 \u7f16\u8bd1\u547d\u4ee4: glslangValidator -V softmax.comp -o softmax.spv\n#version 450\n\n#define WG_SIZE 256\n\nlayout(local_size_x = WG_SIZE) in;\n\nlayout(set = 0, binding = 0) buffer Input  { float input_data[]; };\nlayout(set = 0, binding = 1) buffer Output { float output_data[]; };\n\nlayout(push_constant) uniform PC { uint n; };\n\nshared float sdata[WG_SIZE];\n\nvoid main() {\n    uint gid = gl_GlobalInvocationID.x;\n    uint lid = gl_LocalInvocationID.x;\n\n    // \u6b65\u9aa4 1\uff1a\u627e\u6700\u5927\u503c\uff08\u6570\u503c\u7a33\u5b9a\u6027\uff09\n    sdata[lid] = (gid < n) ? input_data[gid] : -1e30;\n    barrier();\n    for (uint s = WG_SIZE / 2; s > 0; s >>= 1) {\n        if (lid < s) sdata[lid] = max(sdata[lid], sdata[lid + s]);\n        barrier();\n    }\n    float maxVal = sdata[0];\n    barrier();\n\n    // \u6b65\u9aa4 2\uff1a\u8ba1\u7b97 exp(x - max)\n    float expVal = (gid < n) ? exp(input_data[gid] - maxVal) : 0.0;\n    sdata[lid] = expVal;\n    barrier();\n\n    // \u6b65\u9aa4 3\uff1aexp \u503c\u6c42\u548c\n    for (uint s = WG_SIZE / 2; s > 0; s >>= 1) {\n        if (lid < s) sdata[lid] += sdata[lid + s];\n        barrier();\n    }\n    float sumExp = sdata[0];\n\n    // \u6b65\u9aa4 4\uff1a\u5f52\u4e00\u5316\n    if (gid < n) {\n        output_data[gid] = expVal / sumExp;\n    }\n}\n
  3. \u4fee\u6539 C++ \u5bbf\u4e3b\u4ee3\u7801\u4ee5\u5bf9\u8ba1\u7b97\u7740\u8272\u5668\u8fdb\u884c\u57fa\u51c6\u6d4b\u8bd5\uff1a\u4f7f\u7528 Vulkan \u65f6\u95f4\u6233\u67e5\u8be2\u6216 CPU \u7aef\u6805\u680f\u5bf9\u8c03\u5ea6\uff08\u6392\u9664\u8bbe\u7f6e\u9636\u6bb5\uff09\u8ba1\u65f6\uff0c\u5e76\u8ba1\u7b97\u4ee5 GB/s \u4e3a\u5355\u4f4d\u7684\u5b9e\u9645\u5e26\u5bbd\u3002
"},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/","title":"\u91cf\u5316","text":"

\u91cf\u5316\u964d\u4f4e\u6a21\u578b\u6743\u91cd\u548c\u6fc0\u6d3b\u503c\u7684\u7cbe\u5ea6\uff0c\u4f7f\u6a21\u578b\u66f4\u5c0f\u3001\u66f4\u5feb\u3001\u8fd0\u884c\u6210\u672c\u66f4\u4f4e\u3002\u672c\u6587\u6db5\u76d6\u6570\u5b57\u683c\u5f0f\u3001\u8bad\u7ec3\u540e\u91cf\u5316\u3001\u91cf\u5316\u611f\u77e5\u8bad\u7ec3\u3001\u4ec5\u6743\u91cd\u91cf\u5316\u65b9\u6cd5\uff08GPTQ\u3001AWQ\uff09\u3001\u6fc0\u6d3b\u503c\u91cf\u5316\u3001\u6df7\u5408\u7cbe\u5ea6\u548cKV\u7f13\u5b58\u91cf\u5316\u3002

"},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#_2","title":"\u4e3a\u4ec0\u4e48\u8981\u91cf\u5316","text":""},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#_3","title":"\u6570\u5b57\u683c\u5f0f","text":" \u683c\u5f0f \u4f4d\u6570 \u6307\u6570 \u5c3e\u6570 \u8303\u56f4 \u7528\u9014 FP32 32 8 23 \u00b13.4\u00d710\u00b3\u2078 \u8bad\u7ec3\uff08\u9ec4\u91d1\u6807\u51c6\uff09 TF32 19 8 10 \u00b13.4\u00d710\u00b3\u2078 Tensor Core\u8bad\u7ec3\uff08A100+\uff09 FP16 16 5 10 \u00b165504 \u6df7\u5408\u7cbe\u5ea6\u8bad\u7ec3 BF16 16 8 7 \u00b13.4\u00d710\u00b3\u2078 \u8bad\u7ec3\uff08\u4e0eFP32\u76f8\u540c\u7684\u8303\u56f4\uff09 FP8 E4M3 8 4 3 \u00b1448 \u524d\u5411\u4f20\u64ad\uff08Hopper+\uff09 FP8 E5M2 8 5 2 \u00b157344 \u68af\u5ea6\uff08\u66f4\u5bbd\u8303\u56f4\uff09 INT8 8 \u2014 \u2014 -128 \u5230 127 PTQ\u63a8\u7406 INT4 4 \u2014 \u2014 -8 \u5230 7 \u4ec5\u6743\u91cd\u91cf\u5316 INT2/\u4e09\u503c 2 \u2014 \u2014 {-1, 0, 1} \u6781\u9650\u538b\u7f29 "},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#_4","title":"\u91cf\u5316\u65b9\u7a0b","text":" \\[x_q = \\text{clamp}\\left(\\text{round}\\left(\\frac{x}{\\text{scale}}\\right) + \\text{zero\\_point}, \\; q_{\\min}, \\; q_{\\max}\\right)\\] \\[\\hat{x} = \\text{scale} \\times (x_q - \\text{zero\\_point})\\]

"},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#ptq","title":"\u8bad\u7ec3\u540e\u91cf\u5316\uff08PTQ\uff09","text":""},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#_5","title":"\u6821\u51c6\u65b9\u6cd5","text":""},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#ptq_1","title":"PTQ\u5b9e\u8df5","text":"
# \u4f7f\u7528PyTorch\u7684\u7b80\u5316PTQ\uff08\u6982\u5ff5\u6027\uff09\nimport torch\n\ndef quantise_tensor_symmetric(tensor, bits=8):\n    qmax = 2 ** (bits - 1) - 1  # INT8\u7684127\n    scale = tensor.abs().max() / qmax\n    quantised = torch.clamp(torch.round(tensor / scale), -qmax, qmax).to(torch.int8)\n    return quantised, scale\n\ndef dequantise(quantised, scale):\n    return quantised.float() * scale\n\n# \u91cf\u5316\u4e00\u4e2a\u6743\u91cd\u77e9\u9635\nweight = torch.randn(512, 512)  # \u9884\u8bad\u7ec3\u6743\u91cd\nweight_q, scale = quantise_tensor_symmetric(weight, bits=8)\nweight_reconstructed = dequantise(weight_q, scale)\n\n# \u91cf\u5316\u8bef\u5dee\nerror = (weight - weight_reconstructed).abs().mean()\nprint(f\"\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee: {error:.6f}\")\nprint(f\"\u538b\u7f29\u6bd4: {weight.numel() * 4 / (weight_q.numel() * 1 + 4):.1f}x\")  # +4\u5b57\u8282\u7528\u4e8e\u7f29\u653e\u56e0\u5b50\n
"},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#qat","title":"\u91cf\u5316\u611f\u77e5\u8bad\u7ec3\uff08QAT\uff09","text":" \\[\\text{\u524d\u5411: } \\hat{W} = \\text{\u53cd\u91cf\u5316}(\\text{\u91cf\u5316}(W))$$ $$\\text{\u53cd\u5411: } \\frac{\\partial L}{\\partial W} \\approx \\frac{\\partial L}{\\partial \\hat{W}}\\] "},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#_6","title":"\u4ec5\u6743\u91cd\u91cf\u5316","text":""},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#gptq","title":"GPTQ","text":" \\[\\hat{W}_{:,j} = \\text{quant}(W_{:,j}), \\quad W_{:,j+1:} \\mathrel{-}= \\frac{(\\hat{W}_{:,j} - W_{:,j}) \\cdot H_{j,j+1:}}{H_{j,j}}\\] "},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#awq","title":"AWQ","text":""},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#gguf-llamacpp","title":"GGUF / llama.cpp\u91cf\u5316","text":""},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#quipquip","title":"QuIP\u548cQuIP","text":""},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#spqr","title":"SpQR","text":""},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#hqq","title":"HQQ","text":""},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#aqlm","title":"AQLM","text":" \\[\\mathbf{w} \\approx \\mathbf{c}_1^{(1)} + \\mathbf{c}_2^{(2)} + \\cdots + \\mathbf{c}_M^{(M)}\\] "},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#bitnet1llm","title":"BitNet\u548c1\u4f4dLLM","text":" \\[y_j = \\sum_i W_{ij} \\cdot x_i = \\sum_{i: W_{ij}=+1} x_i - \\sum_{i: W_{ij}=-1} x_i\\] "},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#mx","title":"\u5fae\u7f29\u653e\uff08MX\uff09\u683c\u5f0f","text":" \u683c\u5f0f \u5171\u4eab\u6307\u6570 \u5143\u7d20\u4f4d\u6570 \u603b\u8ba1\uff08\u6bcf\u5143\u7d20\uff09 \u7b49\u4ef7 MXFP8 \u6bcf\u57578\u4f4d 8\uff08E4M3/E5M2\uff09 ~8 
\u7c7b\u4f3cFP8\uff0c\u8303\u56f4\u66f4\u597d MXFP6 \u6bcf\u57578\u4f4d 6 ~6.5 \u4ecb\u4e8eFP8\u548cINT4\u4e4b\u95f4 MXFP4 \u6bcf\u57578\u4f4d 4 ~4.5 \u7c7b\u4f3cINT4\uff0c\u4f46\u6709\u6d6e\u70b9\u884c\u4e3a MXINT8 \u6bcf\u57578\u4f4d 8\uff08\u6574\u6570\uff09 ~8.5 INT8\uff0c\u5e26\u5171\u4eab\u7f29\u653e "},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#fp8","title":"FP8\u8bad\u7ec3","text":""},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#_7","title":"\u6fc0\u6d3b\u503c\u91cf\u5316","text":""},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#_8","title":"\u6df7\u5408\u7cbe\u5ea6\u91cf\u5316","text":""},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#kv","title":"KV\u7f13\u5b58\u91cf\u5316","text":" \\[\\text{KV\u7f13\u5b58\u5927\u5c0f} = 2 \\times n_{\\text{layers}} \\times n_{\\text{heads}} \\times d_{\\text{head}} \\times \\text{seq\\_len} \\times \\text{bytes\\_per\\_element}\\] "},{"location":"chapter%2017%3A%20AI%20inference/01.%20quantisation/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u4ece\u5934\u5b9e\u73b0\u5bf9\u79f0INT8\u91cf\u5316\u3002\u91cf\u5316\u4e00\u4e2a\u6743\u91cd\u77e9\u9635\uff0c\u53cd\u91cf\u5316\u5b83\uff0c\u5e76\u6d4b\u91cf\u4f5c\u4e3a\u503c\u5206\u5e03\u51fd\u6570\u7684\u91cd\u5efa\u8bef\u5dee\u3002

    import jax.numpy as jnp\nimport jax\n\ndef quantise_int8(tensor):\n    scale = jnp.max(jnp.abs(tensor)) / 127.0\n    quantised = jnp.clip(jnp.round(tensor / scale), -127, 127).astype(jnp.int8)\n    return quantised, scale\n\ndef dequantise(quantised, scale):\n    return quantised.astype(jnp.float32) * scale\n\n# \u6b63\u5e38\u6743\u91cd\uff08\u5178\u578b\u8bad\u7ec3\u6a21\u578b\uff09\nkey = jax.random.PRNGKey(0)\nweights = jax.random.normal(key, (1024, 1024)) * 0.02\n\nq, s = quantise_int8(weights)\nrecon = dequantise(q, s)\n\nprint(f\"\u539f\u59cb:     {weights.nbytes / 1024:.0f} KB\")\nprint(f\"\u91cf\u5316\u540e:    {q.nbytes / 1024:.0f} KB ({weights.nbytes / q.nbytes:.0f}x \u66f4\u5c0f)\")\nprint(f\"\u5e73\u5747\u7edd\u5bf9\u8bef\u5dee: {jnp.abs(weights - recon).mean():.6f}\")\nprint(f\"\u6700\u5927\u7edd\u5bf9\u8bef\u5dee:  {jnp.abs(weights - recon).max():.6f}\")\nprint(f\"\u76f8\u5bf9\u8bef\u5dee: {jnp.abs(weights - recon).mean() / jnp.abs(weights).mean():.4%}\")\n

  2. \u6f14\u793a\u5f02\u5e38\u503c\u95ee\u9898\u3002\u521b\u5efa\u5177\u6709\u51e0\u4e2a\u6781\u7aef\u901a\u9053\u7684\u6fc0\u6d3b\u503c\uff0c\u5c55\u793a\u9010\u5f20\u91cf\u91cf\u5316\u5931\u8d25\u800c\u9010\u901a\u9053\u91cf\u5316\u6210\u529f\u3002

    import jax.numpy as jnp\nimport jax\n\nkey = jax.random.PRNGKey(42)\n\n# \u6fc0\u6d3b\u503c\uff1a\u5927\u591a\u6570\u901a\u9053\u6b63\u5e38\uff0c2\u4e2a\u901a\u9053\u6709100x\u5f02\u5e38\u503c\nactivations = jax.random.normal(key, (32, 512)) * 0.1\nactivations = activations.at[:, 0].set(activations[:, 0] * 100)   # \u5f02\u5e38\u901a\u9053\nactivations = activations.at[:, 1].set(activations[:, 1] * 50)    # \u5f02\u5e38\u901a\u9053\n\n# \u9010\u5f20\u91cf\u91cf\u5316\uff08\u6574\u4e2a\u5f20\u91cf\u4e00\u4e2a\u7f29\u653e\u56e0\u5b50\uff09\nscale_tensor = jnp.max(jnp.abs(activations)) / 127.0\nq_tensor = jnp.clip(jnp.round(activations / scale_tensor), -127, 127)\nrecon_tensor = q_tensor * scale_tensor\n\n# \u9010\u901a\u9053\u91cf\u5316\uff08\u6bcf\u901a\u9053\u4e00\u4e2a\u7f29\u653e\u56e0\u5b50\uff09\nscales_channel = jnp.max(jnp.abs(activations), axis=0) / 127.0\nq_channel = jnp.clip(jnp.round(activations / scales_channel), -127, 127)\nrecon_channel = q_channel * scales_channel\n\nerr_tensor = jnp.abs(activations - recon_tensor).mean()\nerr_channel = jnp.abs(activations - recon_channel).mean()\n\nprint(f\"\u9010\u5f20\u91cf\u8bef\u5dee: {err_tensor:.6f}\")\nprint(f\"\u9010\u901a\u9053\u8bef\u5dee: {err_channel:.6f}\")\nprint(f\"\u9010\u901a\u9053\u597d {err_tensor / err_channel:.1f}x\")\nprint(f\"\\n\u5f02\u5e38\u901a\u9053\u6d6a\u8d39\u4e86 {(activations.shape[1] - 2) / activations.shape[1]:.0%} \"\n      f\"\u7684\u91cf\u5316\u8303\u56f4\u7ed9 {2 / activations.shape[1]:.1%} \u7684\u901a\u9053\")\n

  3. \u8ba1\u7b97\u4e0d\u540c\u6a21\u578b\u5927\u5c0f\u548c\u5e8f\u5217\u957f\u5ea6\u7684KV\u7f13\u5b58\u5185\u5b58\u3002\u5c55\u793a\u4e3a\u4ec0\u4e48KV\u7f13\u5b58\u91cf\u5316\u5bf9\u957f\u4e0a\u4e0b\u6587\u6a21\u578b\u81f3\u5173\u91cd\u8981\u3002

    def kv_cache_gb(n_layers, n_heads, d_head, seq_len, bytes_per_elem):\n    return 2 * n_layers * n_heads * d_head * seq_len * bytes_per_elem / 1e9\n\nmodels = [\n    (\"Llama-7B\",  32, 32, 128),\n    (\"Llama-70B\", 80, 64, 128),\n    (\"GPT-4 (\u4f30\u8ba1)\", 120, 96, 128),\n]\n\nprint(f\"{'\u6a21\u578b':<15} {'\u5e8f\u5217\u957f\u5ea6':>8} {'FP16 (GB)':>10} {'INT8 (GB)':>10} {'INT4 (GB)':>10}\")\nprint(\"-\" * 60)\n\nfor name, layers, heads, d_head in models:\n    for seq_len in [4096, 32768, 131072]:\n        fp16 = kv_cache_gb(layers, heads, d_head, seq_len, 2)\n        int8 = kv_cache_gb(layers, heads, d_head, seq_len, 1)\n        int4 = kv_cache_gb(layers, heads, d_head, seq_len, 0.5)\n        print(f\"{name:<15} {seq_len:>8} {fp16:>9.1f}  {int8:>9.1f}  {int4:>9.1f}\")\n    print()\n

"},{"location":"chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/","title":"\u9ad8\u6548\u67b6\u6784","text":"

\u8ba9\u6a21\u578b\u66f4\u5feb\u4e0d\u4ec5\u4ec5\u662f\u964d\u4f4e\u7cbe\u5ea6\uff0c\u8fd8\u5728\u4e8e\u8bbe\u8ba1\u66f4\u667a\u80fd\u7684\u67b6\u6784\uff0c\u4f7f\u6bcf\u4e2atoken\u7684\u8ba1\u7b97\u91cf\u66f4\u5c11\u3002\u672c\u6587\u6db5\u76d6StreamingLLM\u3001\u7a00\u758f\u548c\u7ebf\u6027\u6ce8\u610f\u529b\u3001\u591a\u67e5\u8be2\u548c\u5206\u7ec4\u67e5\u8be2\u6ce8\u610f\u529b\u3001\u63a8\u7406\u65f6\u7684\u6df7\u5408\u4e13\u5bb6\u3001\u77e5\u8bc6\u84b8\u998f\u3001\u526a\u679d\u548c\u795e\u7ecf\u67b6\u6784\u641c\u7d22\u3002

"},{"location":"chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/#streamingllm","title":"StreamingLLM\uff1a\u65e0\u9650\u957f\u5ea6\u751f\u6210","text":" \\[\\text{\u7f13\u5b58} = [\\text{token}_0, \\text{token}_1, \\text{token}_{t-w+1}, \\ldots, \\text{token}_t]\\] "},{"location":"chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/#_2","title":"\u7a00\u758f\u6ce8\u610f\u529b","text":" "},{"location":"chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/#_3","title":"\u7ebf\u6027\u6ce8\u610f\u529b\u548c\u72b6\u6001\u7a7a\u95f4\u6a21\u578b","text":" \\[\\text{\u6807\u51c6: } O = \\text{softmax}(QK^T / \\sqrt{d}) V$$ $$\\text{\u7ebf\u6027: } O = \\phi(Q) (\\phi(K)^T V)\\] \\[h_t = \\bar{A} h_{t-1} + \\bar{B} x_t, \\quad y_t = C h_t\\] "},{"location":"chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/#_4","title":"\u591a\u67e5\u8be2\u548c\u5206\u7ec4\u67e5\u8be2\u6ce8\u610f\u529b","text":" \\[\\text{MHA: } h \\text{ \u4e2a\u5934, } h \\text{ \u4e2a K/V \u96c6} \\quad \\to \\quad \\text{GQA: } h \\text{ \u4e2a\u5934, } g \\text{ \u4e2a K/V \u96c6} \\quad \\to \\quad \\text{MQA: } h \\text{ \u4e2a\u5934, } 1 \\text{ \u4e2a K/V \u96c6}\\]

"},{"location":"chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/#mla","title":"\u591a\u5934\u6f5c\u5728\u6ce8\u610f\u529b\uff08MLA\uff09","text":" \\[\\mathbf{c}_t = W_{\\text{compress}} \\cdot [\\mathbf{k}_t; \\mathbf{v}_t], \\quad \\mathbf{k}_t = W_K^{\\text{up}} \\cdot \\mathbf{c}_t, \\quad \\mathbf{v}_t = W_V^{\\text{up}} \\cdot \\mathbf{c}_t\\] "},{"location":"chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/#flash-attention","title":"Flash Attention","text":""},{"location":"chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/#ring-attention","title":"Ring Attention","text":""},{"location":"chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/#_5","title":"\u63a8\u7406\u65f6\u7684\u6df7\u5408\u4e13\u5bb6","text":""},{"location":"chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/#_6","title":"\u77e5\u8bc6\u84b8\u998f","text":" \\[\\mathcal{L} = \\alpha \\cdot \\text{KL}(p_{\\text{teacher}}^{T} \\| p_{\\text{student}}^{T}) + (1 - \\alpha) \\cdot \\mathcal{L}_{\\text{CE}}(y, p_{\\text{student}})\\] "},{"location":"chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/#_7","title":"\u526a\u679d","text":""},{"location":"chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/#nas","title":"\u795e\u7ecf\u67b6\u6784\u641c\u7d22\uff08NAS\uff09","text":""},{"location":"chapter%2017%3A%20AI%20inference/02.%20efficient%20architectures/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u5b9e\u73b0\u6ed1\u52a8\u7a97\u53e3\u6ce8\u610f\u529b\uff0c\u5e76\u4e0e\u5168\u6ce8\u610f\u529b\u6bd4\u8f83\u5185\u5b58\u4f7f\u7528\u3002

    import jax\nimport jax.numpy as jnp\n\ndef full_attention(Q, K, V):\n    \"\"\"\u6807\u51c6O(n\u00b2)\u6ce8\u610f\u529b\u3002\"\"\"\n    scores = Q @ K.T / jnp.sqrt(Q.shape[-1])\n    weights = jax.nn.softmax(scores, axis=-1)\n    return weights @ V\n\ndef sliding_window_attention(Q, K, V, window_size=128):\n    \"\"\"\u6ed1\u52a8\u7a97\u53e3\u6ce8\u610f\u529b\uff1a\u6bcf\u4e2atoken\u5173\u6ce8\u524dwindow_size\u4e2atoken\u3002\"\"\"\n    n = Q.shape[0]\n    d = Q.shape[-1]\n    output = jnp.zeros_like(Q)\n\n    for i in range(n):\n        start = max(0, i - window_size + 1)\n        k_window = K[start:i+1]\n        v_window = V[start:i+1]\n        scores = Q[i] @ k_window.T / jnp.sqrt(d)\n        weights = jax.nn.softmax(scores)\n        output = output.at[i].set(weights @ v_window)\n\n    return output\n\nn, d = 512, 64\nkey = jax.random.PRNGKey(0)\nQ = jax.random.normal(key, (n, d))\nK = jax.random.normal(jax.random.PRNGKey(1), (n, d))\nV = jax.random.normal(jax.random.PRNGKey(2), (n, d))\n\nprint(f\"\u5168\u6ce8\u610f\u529b\u5185\u5b58:    O(n\u00b2) = {n*n} \u4e2a\u6761\u76ee\")\nprint(f\"\u7a97\u53e3 (w=128) \u5185\u5b58:   O(n*w) = {n*128} \u4e2a\u6761\u76ee\")\nprint(f\"\u51cf\u5c11: {n*n / (n*128):.1f}x\")\n

  2. \u6bd4\u8f83MHA\u3001GQA\u548cMQA\u7684KV\u7f13\u5b58\u5927\u5c0f\u3002\u5c55\u793a\u4e3a\u4ec0\u4e48GQA\u662f\u5b9e\u9645\u7684\u6700\u4f73\u9009\u62e9\u3002

    def kv_cache_size(n_heads, n_kv_heads, d_head, seq_len, bytes=2):\n    \"\"\"KV\u7f13\u5b58\u5927\u5c0f\uff08MB\uff09\u3002\"\"\"\n    return 2 * n_kv_heads * d_head * seq_len * bytes / 1e6\n\nn_heads = 32\nd_head = 128\nseq_len = 32768\n\nmha = kv_cache_size(n_heads, n_heads, d_head, seq_len)       # 32\u4e2aKV\u5934\ngqa = kv_cache_size(n_heads, 8, d_head, seq_len)              # 8\u4e2aKV\u5934\nmqa = kv_cache_size(n_heads, 1, d_head, seq_len)              # 1\u4e2aKV\u5934\n\nprint(f\"MHA (32\u4e2aKV\u5934): {mha:.0f} MB \u6bcf\u5c42\")\nprint(f\"GQA (8\u4e2aKV\u5934):  {gqa:.0f} MB \u6bcf\u5c42 ({mha/gqa:.0f}x \u66f4\u5c0f)\")\nprint(f\"MQA (1\u4e2aKV\u5934):   {mqa:.0f} MB \u6bcf\u5c42 ({mha/mqa:.0f}x \u66f4\u5c0f)\")\n

  3. \u901a\u8fc7\u4ece\u968f\u673a\u6ce8\u610f\u529b\u5c42\u4e2d\u79fb\u9664\u6700\u4e0d\u91cd\u8981\u7684\u6ce8\u610f\u529b\u5934\u5e76\u6d4b\u91cf\u8f93\u51fa\u53d8\u5316\u6765\u6a21\u62df\u7ed3\u6784\u5316\u526a\u679d\u3002

    import jax\nimport jax.numpy as jnp\n\nkey = jax.random.PRNGKey(0)\nn_heads, seq_len, d_head = 8, 64, 32\n\n# \u968f\u673a\u591a\u5934\u6ce8\u610f\u529b\u8f93\u51fa\uff08\u6bcf\u4e2a\u5934\u4e00\u4e2a\uff09\nhead_outputs = jax.random.normal(key, (n_heads, seq_len, d_head))\n\n# \u5b8c\u6574\u8f93\u51fa\uff1a\u8fde\u63a5\u6240\u6709\u5934\nfull_output = head_outputs.reshape(seq_len, n_heads * d_head)\n\n# \u91cd\u8981\u6027\uff1a\u901a\u8fc7\u8303\u6570\u5ea6\u91cf\u6bcf\u4e2a\u5934\u7684\u8d21\u732e\nhead_norms = jnp.linalg.norm(head_outputs, axis=(1, 2))\nprint(\"\u5934\u91cd\u8981\u6027\uff08\u6309\u8303\u6570\uff09:\", jnp.round(head_norms, 2))\n\n# \u526a\u679d\u6700\u4e0d\u91cd\u8981\u7684\u5934\nfor n_keep in [8, 6, 4, 2]:\n    top_heads = jnp.argsort(head_norms)[-n_keep:]\n    pruned = head_outputs[top_heads].reshape(seq_len, n_keep * d_head)\n\n    # \u586b\u5145\u5230\u539f\u59cb\u5927\u5c0f\u7528\u4e8e\u6bd4\u8f83\uff08\u5c06\u526a\u6389\u7684\u5934\u8bbe\u4e3a\u96f6\uff09\n    full_pruned = jnp.zeros_like(head_outputs)\n    full_pruned = full_pruned.at[top_heads].set(head_outputs[top_heads])\n    full_pruned = full_pruned.reshape(seq_len, n_heads * d_head)\n\n    error = jnp.linalg.norm(full_output - full_pruned) / jnp.linalg.norm(full_output)\n    print(f\"\u4fdd\u7559 {n_keep}/{n_heads} \u4e2a\u5934: \u76f8\u5bf9\u8bef\u5dee = {error:.4f}, \"\n          f\"\u5185\u5b58 = {n_keep/n_heads:.0%}\")\n

"},{"location":"chapter%2017%3A%20AI%20inference/03.%20serving%20and%20batching/","title":"\u670d\u52a1\u4e0e\u6279\u5904\u7406","text":"

\u5411\u6570\u5343\u5e76\u53d1\u7528\u6237\u63d0\u4f9bLLM\u670d\u52a1\u9700\u8981\u7684\u4e0d\u53ea\u662f\u52a0\u8f7d\u6a21\u578b\u548c\u8fd0\u884c\u63a8\u7406\u3002\u672c\u6587\u6db5\u76d6\u9884\u586b\u5145-\u89e3\u7801\u5206\u79bb\u3001\u8fde\u7eed\u6279\u5904\u7406\u3001PagedAttention\u548cvLLM\u3001\u8c03\u5ea6\u7b56\u7565\u3001\u5206\u79bb\u5f0f\u670d\u52a1\u3001\u591a\u6a21\u578b\u548cLoRA\u670d\u52a1\uff0c\u4ee5\u53ca\u5173\u952e\u6307\u6807\u3002

"},{"location":"chapter%2017%3A%20AI%20inference/03.%20serving%20and%20batching/#vs","title":"\u9884\u586b\u5145 vs \u89e3\u7801\uff1a\u4e24\u4e2a\u622a\u7136\u4e0d\u540c\u7684\u9636\u6bb5","text":" \u9884\u586b\u5145 \u89e3\u7801 \u5904\u7406\u7684token \u4e00\u6b21\u6027\u5168\u90e8\uff08\u5e76\u884c\uff09 \u4e00\u6b21\u4e00\u4e2a\uff08\u987a\u5e8f\uff09 \u74f6\u9888 \u8ba1\u7b97\uff08FLOPS\uff09 \u5185\u5b58\u5e26\u5bbd \u7b97\u672f\u5f3a\u5ea6 \u9ad8 \u975e\u5e38\u4f4e GPU\u5229\u7528\u7387 \u9ad8\uff0850-80%\uff09 \u4f4e\uff081-10%\uff09\uff0c\u65e0\u6279\u5904\u7406\u65f6 \u5ef6\u8fdf\u6307\u6807 \u9996token\u65f6\u95f4\uff08TTFT\uff09 \u6bcf\u8f93\u51fatoken\u65f6\u95f4\uff08TPOT\uff09 "},{"location":"chapter%2017%3A%20AI%20inference/03.%20serving%20and%20batching/#_2","title":"\u9759\u6001\u6279\u5904\u7406\uff08\u6734\u7d20\u65b9\u6cd5\uff09","text":"

"},{"location":"chapter%2017%3A%20AI%20inference/03.%20serving%20and%20batching/#_3","title":"\u8fde\u7eed\u6279\u5904\u7406","text":""},{"location":"chapter%2017%3A%20AI%20inference/03.%20serving%20and%20batching/#pagedattentionvllm","title":"PagedAttention\u548cvLLM","text":" "},{"location":"chapter%2017%3A%20AI%20inference/03.%20serving%20and%20batching/#_4","title":"\u8c03\u5ea6\u7b56\u7565","text":""},{"location":"chapter%2017%3A%20AI%20inference/03.%20serving%20and%20batching/#_5","title":"\u5206\u79bb\u5f0f\u670d\u52a1","text":""},{"location":"chapter%2017%3A%20AI%20inference/03.%20serving%20and%20batching/#lora","title":"\u591a\u6a21\u578b\u548cLoRA\u670d\u52a1","text":""},{"location":"chapter%2017%3A%20AI%20inference/03.%20serving%20and%20batching/#_6","title":"\u53d7\u9650\u548c\u5f15\u5bfc\u751f\u6210","text":""},{"location":"chapter%2017%3A%20AI%20inference/03.%20serving%20and%20batching/#_7","title":"\u8bf7\u6c42\u8def\u7531","text":""},{"location":"chapter%2017%3A%20AI%20inference/03.%20serving%20and%20batching/#_8","title":"\u63a8\u7406\u6307\u6807","text":" \u6307\u6807 \u6d4b\u91cf\u5185\u5bb9 \u76ee\u6807\uff08\u5bf9\u8bdd\u5f0f\uff09 \u76ee\u6807\uff08\u6279\u5904\u7406\uff09 TTFT \u9996token\u65f6\u95f4 <1 s \u4e0d\u592a\u91cd\u8981 TPOT \u6bcf\u8f93\u51fatoken\u65f6\u95f4 <100 ms \u4e0d\u592a\u91cd\u8981 \u541e\u5410\u91cf token/\u79d2\uff08\u603b\u8ba1\uff09 \u4e0d\u592a\u91cd\u8981 \u6700\u5927\u5316 p99\u5ef6\u8fdf \u6700\u5dee\u76841%\u8bf7\u6c42 <5 s <30 s \u6bcftoken\u6210\u672c $/100\u4e07token \u6700\u5c0f\u5316 \u6700\u5c0f\u5316 SLO\u5408\u89c4\u7387 \u6ee1\u8db3\u5ef6\u8fdf\u76ee\u6807\u7684\u8bf7\u6c42\u767e\u5206\u6bd4 >99% >95% "},{"location":"chapter%2017%3A%20AI%20inference/03.%20serving%20and%20batching/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u6a21\u62df\u8fde\u7eedvs\u9759\u6001\u6279\u5904\u7406\u5e76\u6d4b\u91cf\u541e\u5410\u91cf\u5dee\u5f02\u3002

    import random\nimport time\n\ndef simulate_static_batching(requests, batch_size=8):\n    \"\"\"\u5728\u56fa\u5b9a\u6279\u6b21\u4e2d\u5904\u7406\u8bf7\u6c42\u3002\u7b49\u5f85\u6240\u6709\u5b8c\u6210\u3002\"\"\"\n    total_tokens = 0\n    total_time = 0\n\n    for i in range(0, len(requests), batch_size):\n        batch = requests[i:i + batch_size]\n        max_len = max(r['output_len'] for r in batch)\n        # \u6279\u6b21\u4e2d\u6240\u6709\u8bf7\u6c42\u8017\u65f6\u7b49\u4e8e\u6700\u957f\u8bf7\u6c42\n        batch_time = max_len * 0.01  # \u6bcftoken 10ms\n        total_time += batch_time\n        total_tokens += sum(r['output_len'] for r in batch)\n\n    return total_tokens / total_time  # token/\u79d2\n\ndef simulate_continuous_batching(requests, max_batch=8):\n    \"\"\"\u4f7f\u7528\u8fde\u7eed\u6279\u5904\u7406\u5904\u7406\u3002\u79fb\u9664\u5b8c\u6210\u8bf7\u6c42\uff0c\u6dfb\u52a0\u65b0\u8bf7\u6c42\u3002\"\"\"\n    total_tokens = 0\n    total_time = 0\n    active = []\n    queue = list(requests)\n\n    while active or queue:\n        # \u586b\u5145\u6279\u6b21\n        while len(active) < max_batch and queue:\n            active.append({'remaining': queue.pop(0)['output_len']})\n\n        if not active:\n            break\n\n        # \u4e00\u4e2a\u89e3\u7801\u6b65\u9aa4\uff1a\u6240\u6709\u6d3b\u8dc3\u8bf7\u6c42\u751f\u62101\u4e2atoken\n        for req in active:\n            req['remaining'] -= 1\n        total_tokens += len(active)\n        total_time += 0.01  # \u6bcf\u6b6510ms\n\n        # \u79fb\u9664\u5b8c\u6210\u7684\u8bf7\u6c42\n        active = [r for r in active if r['remaining'] > 0]\n\n    return total_tokens / total_time\n\n# \u751f\u6210\u5177\u6709\u4e0d\u540c\u8f93\u51fa\u957f\u5ea6\u7684\u8bf7\u6c42\nrandom.seed(42)\nrequests = [{'output_len': random.randint(10, 500)} for _ in range(100)]\n\nstatic_tps = simulate_static_batching(requests)\ncontinuous_tps = simulate_continuous_batching(requests)\n\nprint(f\"\u9759\u6001\u6279\u5904\u7406:  
   {static_tps:.0f} tokens/s\")\nprint(f\"\u8fde\u7eed\u6279\u5904\u7406: {continuous_tps:.0f} tokens/s\")\nprint(f\"\u52a0\u901f\u6bd4: {continuous_tps / static_tps:.1f}x\")\n

  2. \u8ba1\u7b97PagedAttention\u7684KV\u7f13\u5b58\u5185\u5b58\u8282\u7701\u3002\u6bd4\u8f83\u9884\u5206\u914d\uff08\u6700\u574f\u60c5\u51b5\uff09vs\u5206\u9875\uff08\u5b9e\u9645\u4f7f\u7528\uff09\u3002

    def paged_vs_preallocated(n_requests, max_seq_len, avg_seq_len, page_size, kv_per_token_bytes):\n    \"\"\"\u6bd4\u8f83\u5185\u5b58\u4f7f\u7528\uff1a\u9884\u5206\u914dvs\u5206\u9875KV\u7f13\u5b58\u3002\"\"\"\n    # \u9884\u5206\u914d\uff1a\u6bcf\u4e2a\u8bf7\u6c42\u83b7\u5f97max_seq_len\u4e2a\u69fd\u4f4d\n    preallocated_gb = n_requests * max_seq_len * kv_per_token_bytes / 1e9\n\n    # \u5206\u9875\uff1a\u53ea\u5206\u914d\u4f7f\u7528\u7684\u90e8\u5206\uff08\u6309\u9875\u7c92\u5ea6\uff09\n    import math\n    avg_pages = math.ceil(avg_seq_len / page_size)\n    paged_gb = n_requests * avg_pages * page_size * kv_per_token_bytes / 1e9\n\n    waste_preallocated = (max_seq_len - avg_seq_len) / max_seq_len\n    waste_paged = (avg_pages * page_size - avg_seq_len) / (avg_pages * page_size)\n\n    print(f\"\u8bf7\u6c42\u6570: {n_requests}, \u6700\u5927\u5e8f\u5217: {max_seq_len}, \u5e73\u5747\u5e8f\u5217: {avg_seq_len}\")\n    print(f\"  \u9884\u5206\u914d: {preallocated_gb:.1f} GB (\u6d6a\u8d39: {waste_preallocated:.0%})\")\n    print(f\"  \u5206\u9875:        {paged_gb:.1f} GB (\u6d6a\u8d39: {waste_paged:.0%})\")\n    print(f\"  \u8282\u7701:      {preallocated_gb - paged_gb:.1f} GB ({preallocated_gb/paged_gb:.1f}x)\")\n    print()\n\n# Llama-70B\uff1a\u6bcf\u5c42\u6bcftoken\u7ea61.3 KB\uff0c80\u5c42 = \u6bcftoken\u7ea6100 KB\u603b\u8ba1\nkv_bytes = 100_000\n\n# \u573a\u666f1\uff1a\u77ed\u8bf7\u6c42\uff0c\u5927\u6700\u5927\u503c\npaged_vs_preallocated(256, max_seq_len=4096, avg_seq_len=256, page_size=16, kv_per_token_bytes=kv_bytes)\n\n# \u573a\u666f2\uff1a\u4e0d\u540c\u957f\u5ea6\npaged_vs_preallocated(256, max_seq_len=8192, avg_seq_len=1024, page_size=16, kv_per_token_bytes=kv_bytes)\n\n# \u573a\u666f3\uff1a\u957f\u4e0a\u4e0b\u6587\npaged_vs_preallocated(64, max_seq_len=131072, avg_seq_len=16000, page_size=16, kv_per_token_bytes=kv_bytes)\n

"},{"location":"chapter%2017%3A%20AI%20inference/04.%20edge%20inference/","title":"\u8fb9\u7f18\u63a8\u7406","text":"

\u8fb9\u7f18\u63a8\u7406\u5728\u7528\u6237\u8bbe\u5907\uff08\u624b\u673a\u3001\u7b14\u8bb0\u672c\u7535\u8111\u3001\u7269\u8054\u7f51\u4f20\u611f\u5668\uff09\u4e0a\u8fd0\u884c\u6a21\u578b\uff0c\u65e0\u9700\u5c06\u6570\u636e\u53d1\u9001\u5230\u4e91\u7aef\u3002\u672c\u6587\u6db5\u76d6\u8fb9\u7f18\u7ea6\u675f\u3001\u6a21\u578b\u538b\u7f29\u6d41\u6c34\u7ebf\u3001\u8bbe\u5907\u7aef\u8fd0\u884c\u65f6\u3001\u7f16\u8bd1\u5668\u6808\u3001\u786c\u4ef6\u76ee\u6807\uff08NPU\u3001\u795e\u7ecf\u5f15\u64ce\uff09\u3001\u8bbe\u5907\u7aefLLM\u3001\u8054\u90a6\u5b66\u4e60\u548c\u5ef6\u8fdf\u4f18\u5316\u3002

"},{"location":"chapter%2017%3A%20AI%20inference/04.%20edge%20inference/#_2","title":"\u8fb9\u7f18\u7ea6\u675f","text":"\u8d44\u6e90 \u4e91GPU\uff08H100\uff09 \u7b14\u8bb0\u672c\u7535\u8111\uff08M4\uff09 \u624b\u673a\uff08Snapdragon 8 Gen 3\uff09 IoT\uff08ESP32\uff09 \u5185\u5b58 80 GB HBM3 16-36 GB \u7edf\u4e00\u5185\u5b58 8-12 GB LPDDR5 520 KB \u8ba1\u7b97 989 TFLOPS\uff08FP8\uff09 38 TOPS\uff08\u795e\u7ecf\u5f15\u64ce\uff09 45 TOPS\uff08NPU\uff09 0.001 TOPS \u529f\u8017 700 W 15-30 W 5-10 W 0.1 W \u5b58\u50a8 TB 256 GB-2 TB 128-512 GB 4 MB "},{"location":"chapter%2017%3A%20AI%20inference/04.%20edge%20inference/#_3","title":"\u6a21\u578b\u538b\u7f29\u6d41\u6c34\u7ebf","text":"
\u5b8c\u6574\u6a21\u578b\uff08FP32\uff0c70B\u53c2\u6570\uff09\n    \u2193 \u77e5\u8bc6\u84b8\u998f \u2192 \u66f4\u5c0f\u6a21\u578b\uff087B\u53c2\u6570\uff09\n    \u2193 \u7ed3\u6784\u5316\u526a\u679d \u2192 \u79fb\u9664\u5197\u4f59\u5934/\u5c42\uff084B\u6709\u6548\uff09\n    \u2193 \u91cf\u5316\uff08INT4\uff09 \u2192 4\u500d\u66f4\u5c0f\uff082 GB\uff09\n    \u2193 \u7f16\u8bd1\u5668\u4f18\u5316 \u2192 \u878d\u5408\u5185\u6838\uff0c\u4f18\u5316\u5185\u5b58\u5e03\u5c40\n    \u2193 \u8fd0\u884c\u65f6 \u2192 \u8bbe\u5907\u7aef\u6267\u884c\n
"},{"location":"chapter%2017%3A%20AI%20inference/04.%20edge%20inference/#_4","title":"\u8bbe\u5907\u7aef\u8fd0\u884c\u65f6","text":""},{"location":"chapter%2017%3A%20AI%20inference/04.%20edge%20inference/#_5","title":"\u7f16\u8bd1\u5668\u6808","text":"
PyTorch\u6a21\u578b\n    \u2193 \u5bfc\u51fa\uff08torch.export\u3001ONNX\u3001TorchScript\uff09\n\u56feIR\uff08\u4e2d\u95f4\u8868\u793a\uff09\n    \u2193 \u56fe\u4f18\u5316\n        - \u5e38\u91cf\u6298\u53e0\uff08\u7f16\u8bd1\u65f6\u8ba1\u7b97\u5e38\u91cf\u8868\u8fbe\u5f0f\uff09\n        - \u6b7b\u4ee3\u7801\u6d88\u9664\uff08\u79fb\u9664\u672a\u4f7f\u7528\u7684\u64cd\u4f5c\uff09\n        - \u7b97\u5b50\u878d\u5408\uff08conv + bn + relu \u2192 \u5355\u4e2a\u878d\u5408\u64cd\u4f5c\uff09\n        - \u5e03\u5c40\u8f6c\u6362\uff08NCHW \u2192 NHWC\u7528\u4e8eARM\uff0c\u901a\u9053\u6700\u540e\uff09\n    \u2193 \u964d\u7ea7\n\u786c\u4ef6\u7279\u5b9aIR\n    \u2193 \u540e\u7aef\u4f18\u5316\n        - \u5206\u5757\u548c\u5faa\u73af\u6392\u5e8f\uff08\u7f13\u5b58\u53cb\u597d\u7684\u8bbf\u95ee\u6a21\u5f0f\uff09\n        - \u5411\u91cf\u5316\uff08SIMD\uff0c\u7b2c16\u7ae0\uff09\n        - \u5185\u5b58\u89c4\u5212\uff08\u91cd\u7528\u7f13\u51b2\u533a\u4ee5\u6700\u5c0f\u5316\u5cf0\u503c\u5185\u5b58\uff09\n        - \u5185\u6838\u9009\u62e9\uff08\u4e3a\u6bcf\u4e2a\u64cd\u4f5c\u9009\u62e9\u6700\u4f73\u5b9e\u73b0\uff09\n    \u2193 \u4ee3\u7801\u751f\u6210\n\u673a\u5668\u4ee3\u7801 / NPU\u6307\u4ee4\n
"},{"location":"chapter%2017%3A%20AI%20inference/04.%20edge%20inference/#_6","title":"\u786c\u4ef6\u76ee\u6807","text":""},{"location":"chapter%2017%3A%20AI%20inference/04.%20edge%20inference/#gpu","title":"\u79fb\u52a8GPU","text":""},{"location":"chapter%2017%3A%20AI%20inference/04.%20edge%20inference/#npu","title":"\u795e\u7ecf\u5904\u7406\u5355\u5143\uff08NPU\uff09","text":""},{"location":"chapter%2017%3A%20AI%20inference/04.%20edge%20inference/#llm","title":"\u8bbe\u5907\u7aefLLM","text":" \u6a21\u578b \u53c2\u6570 \u91cf\u5316\u540e\u5927\u5c0f \u76ee\u6807\u8bbe\u5907 \u6027\u80fd Phi-3 Mini 3.8B ~2 GB\uff08Q4\uff09 \u624b\u673a/\u7b14\u8bb0\u672c iPhone 15\u4e0a~15 tokens/s Gemma 2B 2B ~1.5 GB\uff08Q4\uff09 \u624b\u673a Pixel 8\u4e0a~20 tokens/s Llama 3.2 1B 1B ~700 MB\uff08Q4\uff09 \u624b\u673a ~30 tokens/s Llama 3.2 3B 3B ~2 GB\uff08Q4\uff09 \u624b\u673a/\u7b14\u8bb0\u672c ~15 tokens/s Llama 3.1 8B 8B ~4.5 GB\uff08Q4\uff09 \u7b14\u8bb0\u672c M2\u4e0a~20 tokens/s "},{"location":"chapter%2017%3A%20AI%20inference/04.%20edge%20inference/#_7","title":"\u8054\u90a6\u5b66\u4e60","text":""},{"location":"chapter%2017%3A%20AI%20inference/04.%20edge%20inference/#_8","title":"\u5ef6\u8fdf\u4f18\u5316","text":""},{"location":"chapter%2017%3A%20AI%20inference/04.%20edge%20inference/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u6a21\u62df\u6a21\u578b\u538b\u7f29\u6d41\u6c34\u7ebf\u3002\u4ecefloat32\u6a21\u578b\u5f00\u59cb\uff0c\u4f9d\u6b21\u5e94\u7528\u84b8\u998f\uff08\u6a21\u62df\uff09\u3001\u526a\u679d\u548c\u91cf\u5316\uff0c\u5e76\u8ddf\u8e2a\u6bcf\u4e00\u6b65\u7684\u5927\u5c0f\u3002

    def compression_pipeline(original_params_M, original_bits=32):\n    size_mb = original_params_M * 1e6 * original_bits / 8 / 1e6\n\n    print(f\"\u539f\u59cb: {original_params_M}M \u53c2\u6570, {original_bits}-\u4f4d \u2192 {size_mb:.0f} MB\")\n\n    # \u6b65\u9aa41\uff1a\u77e5\u8bc6\u84b8\u998f\uff08\u51cf\u5c11\u53c2\u6570\uff09\n    distilled_params = original_params_M * 0.15  # 70B \u2192 ~10B \u7b49\u4ef7\n    size_mb = distilled_params * 1e6 * original_bits / 8 / 1e6\n    print(f\"\u84b8\u998f\u540e ({distilled_params:.0f}M \u53c2\u6570): {size_mb:.0f} MB\")\n\n    # \u6b65\u9aa42\uff1a\u7ed3\u6784\u5316\u526a\u679d\uff08\u79fb\u9664\u5269\u4f5930%\uff09\n    pruned_params = distilled_params * 0.7\n    size_mb = pruned_params * 1e6 * original_bits / 8 / 1e6\n    print(f\"\u526a\u679d\u540e ({pruned_params:.0f}M \u53c2\u6570): {size_mb:.0f} MB\")\n\n    # \u6b65\u9aa43\uff1aINT4\u91cf\u5316\n    size_mb = pruned_params * 1e6 * 4 / 8 / 1e6\n    print(f\"INT4\u91cf\u5316\u540e: {size_mb:.0f} MB\")\n\n    print(f\"\u603b\u538b\u7f29\u6bd4: {original_params_M * 1e6 * original_bits / 8 / 1e6 / size_mb:.0f}x\")\n\nprint(\"=== \u4ece70B\u6a21\u578b\u5f00\u59cb ===\")\ncompression_pipeline(70000)\n\nprint(\"\\n=== \u4ece7B\u6a21\u578b\u5f00\u59cb ===\")\ncompression_pipeline(7000)\n

  2. \u4f30\u8ba1\u8bbe\u5907\u7aef\u63a8\u7406\u5ef6\u8fdf\u3002\u7ed9\u5b9a\u6a21\u578b\u7684\u64cd\u4f5c\u8ba1\u6570\u548c\u786c\u4ef6\u89c4\u683c\uff0c\u8ba1\u7b97\u662f\u5426\u6ee1\u8db3\u5ef6\u8fdf\u76ee\u6807\u3002

    def estimate_latency(model_name, params_M, bits, compute_tops, mem_bw_gbs, seq_len=256):\n    \"\"\"\u4f30\u8ba1\u5185\u5b58\u5e26\u5bbd\u53d7\u9650\u6a21\u578b\u7684token\u751f\u6210\u5ef6\u8fdf\u3002\"\"\"\n    # \u6a21\u578b\u5927\u5c0f\uff08\u5b57\u8282\uff09\n    model_bytes = params_M * 1e6 * bits / 8\n\n    # \u89e3\u7801\u662f\u5185\u5b58\u53d7\u9650\u7684\uff1a\u6bcftoken\u5fc5\u987b\u52a0\u8f7d\u6574\u4e2a\u6a21\u578b\n    time_per_token_ms = model_bytes / (mem_bw_gbs * 1e9) * 1000\n\n    # \u6bcf\u79d2token\u6570\n    tokens_per_sec = 1000 / time_per_token_ms\n\n    print(f\"{model_name}: {params_M/1000:.1f}B \u53c2\u6570 @ {bits}-\u4f4d = {model_bytes/1e9:.1f} GB\")\n    print(f\"  \u5185\u5b58\u5e26\u5bbd: {mem_bw_gbs} GB/s\")\n    print(f\"  \u6bcftoken\u65f6\u95f4: {time_per_token_ms:.1f} ms\")\n    print(f\"  Tokens/\u79d2: {tokens_per_sec:.0f}\")\n    print()\n\n# Apple M2 Pro\uff1a200 GB/s \u7edf\u4e00\u5185\u5b58\u5e26\u5bbd\nprint(\"=== Apple M2 Pro (200 GB/s) ===\")\nestimate_latency(\"Llama-7B Q4\", 7000, 4, 15.8, 200)\nestimate_latency(\"Llama-7B Q8\", 7000, 8, 15.8, 200)\nestimate_latency(\"Llama-70B Q4\", 70000, 4, 15.8, 200)\n\n# \u624b\u673a\uff08Snapdragon 8 Gen 3\uff09\uff1a~50 GB/s LPDDR5\nprint(\"=== Snapdragon 8 Gen 3 (50 GB/s) ===\")\nestimate_latency(\"Phi-3 Mini Q4\", 3800, 4, 45, 50)\nestimate_latency(\"Llama-3B Q4\", 3000, 4, 45, 50)\n

"},{"location":"chapter%2017%3A%20AI%20inference/05.%20scaling%20and%20deployment/","title":"\u6269\u7f29\u4e0e\u90e8\u7f72","text":"

\u5411\u6570\u767e\u4e07\u7528\u6237\u63d0\u4f9b\u5927\u6a21\u578b\u670d\u52a1\u9700\u8981\u8de8\u591a\u4e2aGPU\u5206\u5e03\u63a8\u7406\u3001\u5728\u9700\u8981\u4e4b\u524d\u9884\u6d4btoken\u3001\u7f13\u5b58\u5171\u4eab\u4e0a\u4e0b\u6587\u4ee5\u53ca\u9009\u62e9\u5408\u9002\u7684\u6846\u67b6\u3002\u672c\u6587\u6db5\u76d6\u63a8\u7406\u65f6\u7684\u5e76\u884c\u6027\u3001\u63a8\u6d4b\u6027\u89e3\u7801\u3001\u524d\u7f00\u7f13\u5b58\u3001\u63a8\u7406\u6846\u67b6\u3001\u6210\u672c\u4f18\u5316\u548c\u76d1\u63a7\u3002

"},{"location":"chapter%2017%3A%20AI%20inference/05.%20scaling%20and%20deployment/#_2","title":"\u63a8\u7406\u65f6\u7684\u6a21\u578b\u5e76\u884c","text":""},{"location":"chapter%2017%3A%20AI%20inference/05.%20scaling%20and%20deployment/#_3","title":"\u5f20\u91cf\u5e76\u884c","text":" \\[W = [W_1 | W_2 | \\cdots | W_N], \\quad Y_i = X W_i, \\quad Y = \\text{concat}(Y_1, \\ldots, Y_N)\\] "},{"location":"chapter%2017%3A%20AI%20inference/05.%20scaling%20and%20deployment/#_4","title":"\u6d41\u6c34\u7ebf\u5e76\u884c","text":""},{"location":"chapter%2017%3A%20AI%20inference/05.%20scaling%20and%20deployment/#_5","title":"\u5e8f\u5217\u5e76\u884c","text":""},{"location":"chapter%2017%3A%20AI%20inference/05.%20scaling%20and%20deployment/#_6","title":"\u63a8\u6d4b\u6027\u89e3\u7801","text":" \\[\\text{\u52a0\u901f\u6bd4} \\approx \\frac{k \\times \\text{acceptance\\_rate}}{\\text{cost\\_ratio}} \\approx 2\\text{-}3\\times\\] "},{"location":"chapter%2017%3A%20AI%20inference/05.%20scaling%20and%20deployment/#_7","title":"\u524d\u7f00\u7f13\u5b58","text":""},{"location":"chapter%2017%3A%20AI%20inference/05.%20scaling%20and%20deployment/#kv","title":"KV\u7f13\u5b58\u9a71\u9010","text":""},{"location":"chapter%2017%3A%20AI%20inference/05.%20scaling%20and%20deployment/#_8","title":"\u63a8\u7406\u6846\u67b6","text":" \u6846\u67b6 \u4f18\u52bf \u6700\u9002\u5408 vLLM PagedAttention\u3001\u8fde\u7eed\u6279\u5904\u7406\u3001\u9ad8\u541e\u5410\u91cf \u901a\u7528LLM\u670d\u52a1\uff0c\u6700\u9ad8\u541e\u5410\u91cf TensorRT-LLM NVIDIA\u4f18\u5316\u5185\u6838\u3001FP8\u3001\u98de\u884c\u4e2d\u6279\u5904\u7406 NVIDIA GPU\u4e0a\u7684\u6700\u5927\u6027\u80fd SGLang \u524d\u7f00\u7f13\u5b58\uff08RadixAttention\uff09\u3001\u5feb\u901f\u7ed3\u6784\u5316\u751f\u6210 \u5177\u6709\u5171\u4eab\u524d\u7f00\u7684\u5e94\u7528\uff0c\u53d7\u9650\u8f93\u51fa llama.cpp CPU/Metal/CUDA/Vulkan\u3001GGUF\u91cf\u5316\u3001\u53ef\u79fb\u690d \u6d88\u8d39\u7ea7\u786c\u4ef6\uff0c\u8bbe\u5907\u7aef\u63a8\u7406 
TGI\uff08HuggingFace\uff09 \u7b80\u5355API\uff0c\u6613\u4e8e\u90e8\u7f72\uff0c\u6a21\u578b\u4e2d\u5fc3\u96c6\u6210 \u5feb\u901f\u90e8\u7f72\uff0cHuggingFace\u751f\u6001 Ollama \u4e00\u952e\u4e0b\u8f7d\u548c\u63d0\u4f9b\u670d\u52a1 \u4e2a\u4eba\u4f7f\u7528\uff0c\u672c\u5730\u5f00\u53d1 ExLlamaV2 \u6781\u81f4\u91cf\u5316\u4f18\u5316\uff08EXL2\u683c\u5f0f\uff09 \u5185\u5b58\u53d7\u9650\u7684GPU\u63a8\u7406 "},{"location":"chapter%2017%3A%20AI%20inference/05.%20scaling%20and%20deployment/#_9","title":"\u6210\u672c\u4f18\u5316","text":" \u914d\u7f6e \u6bcf100\u4e07token\u6210\u672c GPT-4o API $2.50 Claude 3.5 Sonnet API $3.00 Llama-70B on H100\uff08vLLM\uff0cFP16\uff09 $0.50 Llama-70B on H100\uff08TRT-LLM\uff0cINT8\uff09 $0.25 Llama-8B on A10G\uff08vLLM\uff0cINT4\uff09 $0.05 Llama-3B \u8bbe\u5907\u7aef\uff08llama.cpp\uff09 $0\uff08\u786c\u4ef6\u6210\u672c\u644a\u9500\uff09"},{"location":"chapter%2017%3A%20AI%20inference/05.%20scaling%20and%20deployment/#_10","title":"\u76d1\u63a7","text":""},{"location":"chapter%2017%3A%20AI%20inference/05.%20scaling%20and%20deployment/#colabnotebook","title":"\u7f16\u7a0b\u4efb\u52a1\uff08\u4f7f\u7528CoLab\u6216notebook\uff09","text":"
  1. \u6a21\u62df\u63a8\u6d4b\u6027\u89e3\u7801\u3002\u4f7f\u7528\u5feb\u901f\u7684\"\u8349\u7a3f\"\u51fd\u6570\u548c\u6162\u901f\u7684\"\u76ee\u6807\"\u51fd\u6570\uff0c\u6d4b\u91cf\u4e00\u6b21\u751f\u6210\u548c\u9a8c\u8bc1\u591a\u4e2atoken\u7684\u52a0\u901f\u6bd4\u3002

    import random\nimport time\n\ndef target_model(tokens):\n    \"\"\"\u6162\u4f46\u51c6\u786e\u7684\u6a21\u578b\u3002\u8fd4\u56de\u6bcf\u4e2a\u5019\u9009token\u7684\u6982\u7387\u3002\"\"\"\n    time.sleep(0.01)  # \u6a21\u62df\u6bcf\u6b21\u524d\u5411\u4f20\u64ad10ms\n    # \u7528\u4e8e\u6a21\u62df\uff1a\u63a5\u53d7\u5076\u6570token\n    return [0.9 if t % 2 == 0 else 0.1 for t in tokens]\n\ndef draft_model():\n    \"\"\"\u5feb\u4f46\u8fd1\u4f3c\u7684\u6a21\u578b\u3002\u751f\u6210\u4e00\u4e2a\u5019\u9009token\u3002\"\"\"\n    time.sleep(0.001)  # \u6a21\u62df\u6bcftoken 1ms\n    return random.randint(0, 9)\n\ndef standard_decoding(n_tokens):\n    \"\"\"\u4e00\u6b21\u751f\u6210\u4e00\u4e2atoken\uff0c\u4f7f\u7528\u76ee\u6807\u6a21\u578b\u3002\"\"\"\n    tokens = []\n    for _ in range(n_tokens):\n        time.sleep(0.01)  # \u76ee\u6807\u6a21\u578b\u751f\u62101\u4e2atoken\n        tokens.append(random.randint(0, 9))\n    return tokens\n\ndef speculative_decoding(n_tokens, k=4):\n    \"\"\"\u751f\u6210k\u4e2a\u8349\u7a3ftoken\uff0c\u7528\u76ee\u6807\u6a21\u578b\u9a8c\u8bc1\uff0c\u63a5\u53d7/\u62d2\u7edd\u3002\"\"\"\n    tokens = []\n    total_target_calls = 0\n\n    while len(tokens) < n_tokens:\n        # \u8349\u7a3f\uff1a\u5feb\u901f\u751f\u6210k\u4e2a\u5019\u9009\n        candidates = [draft_model() for _ in range(k)]\n\n        # \u9a8c\u8bc1\uff1a\u4e00\u6b21\u76ee\u6807\u6a21\u578b\u8c03\u7528\u9a8c\u8bc1\u6240\u6709k\u4e2a\u5019\u9009\n        probs = target_model(candidates)\n        total_target_calls += 1\n\n        # \u63a5\u53d7token\uff0c\u76f4\u5230\u4e00\u4e2a\u88ab\u62d2\u7edd\n        for i, (tok, prob) in enumerate(zip(candidates, probs)):\n            if random.random() < prob:\n                tokens.append(tok)\n                if len(tokens) >= n_tokens:\n                    break\n            else:\n                # \u4ece\u76ee\u6807\u5206\u5e03\u91cd\u65b0\u91c7\u6837\n                tokens.append(tok + 1)  # 
\u7b80\u5316\u91cd\u65b0\u91c7\u6837\n                break\n\n    return tokens, total_target_calls\n\nn = 50\n\nstart = time.time()\n_ = standard_decoding(n)\nstandard_time = time.time() - start\n\nstart = time.time()\n_, target_calls = speculative_decoding(n, k=5)\nspec_time = time.time() - start\n\nprint(f\"\u6807\u51c6:    {standard_time:.2f}s ({n} \u6b21\u76ee\u6807\u8c03\u7528)\")\nprint(f\"\u63a8\u6d4b\u6027: {spec_time:.2f}s ({target_calls} \u6b21\u76ee\u6807\u8c03\u7528)\")\nprint(f\"\u52a0\u901f\u6bd4:     {standard_time / spec_time:.1f}x\")\n

  2. \u4f30\u8ba1\u5e94\u7528\u4e8eLLM\u670d\u52a1\u90e8\u7f72\u7684\u4e0d\u540c\u4f18\u5316\u7b56\u7565\u7684\u6210\u672c\u8282\u7701\u3002

    def serving_cost_analysis(\n    model_name, params_B, precision_bits,\n    gpu_name, gpu_mem_gb, gpu_cost_per_hr,\n    target_throughput_tps,\n):\n    \"\"\"\u4f30\u8ba1LLM\u90e8\u7f72\u7684\u670d\u52a1\u6210\u672c\u3002\"\"\"\n    model_size_gb = params_B * 1e9 * precision_bits / 8 / 1e9\n    gpus_for_model = max(1, int((model_size_gb * 1.2) / gpu_mem_gb + 0.99))  # 1.2x\u7528\u4e8eKV\u7f13\u5b58\n\n    # \u7c97\u7565\u541e\u5410\u91cf\u4f30\u8ba1\uff08\u5185\u5b58\u5e26\u5bbd\u53d7\u9650\uff09\n    tokens_per_gpu = 500 / (params_B * precision_bits / 16)  # \u5f52\u4e00\u5316\u52307B FP16\u7684500 tok/s\n    total_throughput = tokens_per_gpu * gpus_for_model\n\n    replicas = max(1, int(target_throughput_tps / total_throughput + 0.99))\n    total_gpus = gpus_for_model * replicas\n    cost_per_hr = total_gpus * gpu_cost_per_hr\n    cost_per_1M_tokens = cost_per_hr / (total_throughput * replicas * 3600 / 1e6)\n\n    print(f\"{model_name} @ {precision_bits}-\u4f4d \u5728 {gpu_name} \u4e0a:\")\n    print(f\"  \u6a21\u578b\u5927\u5c0f: {model_size_gb:.0f} GB \u2192 {gpus_for_model} GPU(s)/\u526f\u672c\")\n    print(f\"  \u541e\u5410\u91cf: {total_throughput:.0f} tok/s/\u526f\u672c\")\n    print(f\"  \u9700\u8fbe\u5230{target_throughput_tps} tok/s\u7684\u526f\u672c\u6570: {replicas}\")\n    print(f\"  \u603bGPU\u6570: {total_gpus}\")\n    print(f\"  \u6210\u672c: ${cost_per_hr:.0f}/\u5c0f\u65f6, ${cost_per_1M_tokens:.2f}/100\u4e07token\")\n    print()\n\nprint(\"=== \u6210\u672c\u6bd4\u8f83 ===\\n\")\n\n# \u57fa\u7ebf\uff1aH100\u4e0a\u7684FP16\nserving_cost_analysis(\"Llama-70B\", 70, 16, \"H100\", 80, 8.0, 1000)\n\n# \u91cf\u5316\u540e\uff1aH100\u4e0a\u7684INT8\nserving_cost_analysis(\"Llama-70B\", 70, 8, \"H100\", 80, 8.0, 1000)\n\n# \u91cf\u5316\u540e\uff1aA100\u4e0a\u7684INT4\nserving_cost_analysis(\"Llama-70B\", 70, 4, \"A100\", 80, 4.0, 1000)\n\n# \u5c0f\u6a21\u578b\uff1aA10G\u4e0a\u76848B\nserving_cost_analysis(\"Llama-8B\", 8, 4, \"A10G\", 24, 1.0, 1000)\n

"},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/","title":"\u7cfb\u7edf\u8bbe\u8ba1\u57fa\u7840","text":"

\u7cfb\u7edf\u8bbe\u8ba1\u662f\u6784\u5efa\u53ef\u5728\u5927\u89c4\u6a21\u4e0b\u53ef\u9760\u8fd0\u884c\u7684\u8f6f\u4ef6\u7684\u65b9\u6cd5\u3002\u672c\u6587\u4ef6\u6db5\u76d6\u5ba2\u6237\u7aef-\u670d\u52a1\u5668\u67b6\u6784\u3001\u7f51\u7edc\u534f\u8bae\u3001DNS\u3001\u4ee3\u7406\u3001\u8d1f\u8f7d\u5747\u8861\u3001\u7f13\u5b58\u3001\u6570\u636e\u5e93\u3001\u6d88\u606f\u961f\u5217\u3001\u4e00\u81f4\u6027\u6a21\u578b\u548c\u5f39\u6027\u6a21\u5f0f\u3002

"},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#-","title":"\u5ba2\u6237\u7aef-\u670d\u52a1\u5668\u67b6\u6784","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#_2","title":"\u7f51\u7edc\u534f\u8bae","text":"
message PredictRequest {\n    repeated float features = 1;\n    string model_version = 2;\n}\n\nmessage PredictResponse {\n    float prediction = 1;\n    float confidence = 2;\n}\n\nservice ModelService {\n    rpc Predict(PredictRequest) returns (PredictResponse);\n}\n
"},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#dns","title":"DNS","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#_3","title":"\u4ee3\u7406","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#_4","title":"\u8d1f\u8f7d\u5747\u8861","text":" "},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#_5","title":"\u7f13\u5b58","text":" "},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#_6","title":"\u6570\u636e\u5e93","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#sql","title":"SQL\uff08\u5173\u7cfb\u578b\uff09","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#nosql","title":"NoSQL","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#cap","title":"CAP\u5b9a\u7406","text":" "},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#_7","title":"\u5206\u7247","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#_8","title":"\u6570\u636e\u5e93\u7d22\u5f15","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#api","title":"API\u8bbe\u8ba1","text":"
{\n    \"error\": {\n        \"code\": \"INVALID_INPUT\",\n        \"message\": \"\u7279\u5f81'user_age'\u5fc5\u987b\u4e3a\u6b63\u6574\u6570\",\n        \"details\": {\"field\": \"user_age\", \"value\": -5}\n    }\n}\n
"},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#_9","title":"\u6d88\u606f\u961f\u5217","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#_10","title":"\u4e00\u81f4\u6027\u6a21\u578b","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/01.%20systems%20design%20fundamentals/#_11","title":"\u5f39\u6027\u6a21\u5f0f","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/","title":"\u4e91\u8ba1\u7b97","text":"

\u4e91\u8ba1\u7b97\u4e3aML\u5de5\u4f5c\u8d1f\u8f7d\u63d0\u4f9b\u6309\u9700\u57fa\u7840\u8bbe\u65bd\uff0c\u65e0\u9700\u62e5\u6709\u786c\u4ef6\u3002\u672c\u6587\u4ef6\u6db5\u76d6\u670d\u52a1\u6a21\u578b\u3001\u4e3b\u8981\u4e91\u670d\u52a1\u5546\u3001\u5bb9\u5668\u548cKubernetes\u3001\u5b58\u50a8\u3001\u4e91\u7f51\u7edc\u3001\u65e0\u670d\u52a1\u5668\u8ba1\u7b97\u3001\u6210\u672c\u7ba1\u7406\u548c\u57fa\u7840\u8bbe\u65bd\u5373\u4ee3\u7801\u3002

"},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/#_2","title":"\u4e91\u670d\u52a1\u6a21\u578b","text":" \u6a21\u578b \u4f60\u7ba1\u7406 \u63d0\u4f9b\u5546\u7ba1\u7406 \u793a\u4f8b IaaS\uff08\u57fa\u7840\u8bbe\u65bd\uff09 \u64cd\u4f5c\u7cfb\u7edf\u3001\u8fd0\u884c\u65f6\u3001\u5e94\u7528 \u786c\u4ef6\u3001\u865a\u62df\u5316\u3001\u7f51\u7edc AWS EC2\u3001GCP Compute Engine PaaS\uff08\u5e73\u53f0\uff09 \u5e94\u7528\u3001\u6570\u636e \u64cd\u4f5c\u7cfb\u7edf\u3001\u8fd0\u884c\u65f6\u3001\u6269\u5c55\u3001\u4fee\u8865 AWS SageMaker\u3001GCP Vertex AI SaaS\uff08\u8f6f\u4ef6\uff09 \u4ec0\u4e48\u90fd\u4e0d\u7528\u7ba1\uff08\u53ea\u7ba1\u7528\uff09 \u4e00\u5207 OpenAI API\u3001Weights & Biases FaaS\uff08\u51fd\u6570\uff09 \u5355\u4e2a\u51fd\u6570 \u5176\u4ed6\u6240\u6709 AWS Lambda\u3001GCP Cloud Functions "},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/#_3","title":"\u4e3b\u8981\u4e91\u670d\u52a1\u5546","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/#aws","title":"AWS\uff08\u4e9a\u9a6c\u900a\u4e91\u670d\u52a1\uff09","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/#gcp","title":"GCP\uff08\u8c37\u6b4c\u4e91\u5e73\u53f0\uff09","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/#azure","title":"Azure\uff08\u5fae\u8f6f\uff09","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/#kubernetes","title":"\u5bb9\u5668\u548cKubernetes","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/#kubernetesml","title":"Kubernetes\u7528\u4e8eML","text":"
resources:\n  limits:\n    nvidia.com/gpu: 2  # \u6b64Pod\u9700\u89812\u4e2aGPU\n
"},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/#_4","title":"\u81ea\u52a8\u7f29\u653e","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/#_5","title":"\u5b58\u50a8","text":"\u7c7b\u578b \u7279\u6027 \u7528\u9014 \u793a\u4f8b \u5757\u5b58\u50a8 \u4f4e\u5ef6\u8fdf\uff0c\u9644\u52a0\u5230\u5355\u53f0VM \u64cd\u4f5c\u7cfb\u7edf\u78c1\u76d8\u3001\u6570\u636e\u5e93 AWS EBS\u3001GCP Persistent Disk \u5bf9\u8c61\u5b58\u50a8 \u65e0\u9650\u5bb9\u91cf\uff0cHTTP\u8bbf\u95ee \u6570\u636e\u96c6\u3001\u6a21\u578b\u6743\u91cd\u3001\u65e5\u5fd7 AWS S3\u3001GCS\u3001Azure Blob \u6587\u4ef6\u5b58\u50a8 \u8de8VM\u5171\u4eab\uff0cPOSIX \u5171\u4eab\u8bad\u7ec3\u6570\u636e AWS EFS\u3001GCP Filestore\u3001NFS \u6570\u636e\u6e56 \u8bfb\u53d6\u65f6\u5b9a\u4e49\u6a21\u5f0f\uff0c\u539f\u59cb\u6570\u636e \u5206\u6790\u3001\u7279\u5f81\u5de5\u7a0b Delta Lake\u3001Iceberg\u3001Hudi "},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/#_6","title":"\u4e91\u7f51\u7edc","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/#_7","title":"\u65e0\u670d\u52a1\u5668\u8ba1\u7b97","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/#_8","title":"\u6210\u672c\u7ba1\u7406","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/#_9","title":"\u591a\u533a\u57df\u90e8\u7f72","text":" GPU AWS GCP Azure \u5178\u578b\u7528\u9014 A10G\uff0824 GB\uff09 $1.00/\u5c0f\u65f6\uff08g5\uff09 $0.90/\u5c0f\u65f6 $0.90/\u5c0f\u65f6 \u5c0f\u6a21\u578b\u63a8\u7406 A100\uff0880 GB\uff09 $4.10/\u5c0f\u65f6\uff08p4d\uff09 $3.70/\u5c0f\u65f6 $3.40/\u5c0f\u65f6 \u8bad\u7ec3\u3001\u5927\u578b\u63a8\u7406 H100\uff0880 GB\uff09 $8.00/\u5c0f\u65f6\uff08p5\uff09 $7.50/\u5c0f\u65f6 $7.00/\u5c0f\u65f6 \u524d\u6cbf\u8bad\u7ec3 TPU v5e \u65e0 $1.20/\u5c0f\u65f6 \u65e0 JAX\u5927\u89c4\u6a21\u8bad\u7ec3 
"},{"location":"chapter%2018%3A%20ML%20systems%20design/02.%20cloud%20computing/#_10","title":"\u57fa\u7840\u8bbe\u65bd\u5373\u4ee3\u7801","text":"
# main.tf \u2014 \u521b\u5efa\u7528\u4e8e\u63a8\u7406\u7684GPU VM\nresource \"aws_instance\" \"model_server\" {\n  ami           = \"ami-0abcdef1234567890\"  # \u6df1\u5ea6\u5b66\u4e60AMI\n  instance_type = \"g5.xlarge\"               # A10G GPU\n\n  tags = {\n    Name = \"model-server-prod\"\n  }\n}\n\nresource \"aws_s3_bucket\" \"model_weights\" {\n  bucket = \"my-model-weights-prod\"\n\n  versioning {\n    enabled = true\n  }\n}\n
terraform init      # \u4e0b\u8f7d\u63d0\u4f9b\u5546\u63d2\u4ef6\nterraform plan      # \u663e\u793a\u5c06\u8981\u66f4\u6539\u7684\u5185\u5bb9\nterraform apply     # \u521b\u5efa\u57fa\u7840\u8bbe\u65bd\nterraform destroy   # \u5168\u90e8\u62c6\u9664\n
"},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/","title":"\u5927\u89c4\u6a21\u57fa\u7840\u8bbe\u65bd","text":"

\u6784\u5efa\u670d\u52a1\u6570\u767e\u4e07\u7528\u6237\u7684\u7cfb\u7edf\u9700\u8981\u7684\u4e0d\u53ea\u662f\u5355\u4e2a\u670d\u52a1\u5668\u3002\u672c\u6587\u4ef6\u6db5\u76d6\u53ef\u6269\u5c55\u6027\u6a21\u5f0f\u3001\u5206\u5e03\u5f0f\u7cfb\u7edf\u57fa\u7840\u3001\u5fae\u670d\u52a1\u3001\u6570\u636e\u6d41\u6c34\u7ebf\u3001\u6570\u636e\u5e93\u6269\u5c55\u3001\u641c\u7d22\u548c\u5411\u91cf\u7cfb\u7edf\u3001\u53ef\u89c2\u6d4b\u6027\u3001\u53ef\u9760\u6027\u5de5\u7a0b\u4ee5\u53caCI/CD\u3002

"},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_2","title":"\u53ef\u6269\u5c55\u6027","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_3","title":"\u5206\u5e03\u5f0f\u7cfb\u7edf","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_4","title":"\u5fae\u670d\u52a1","text":"
\u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510  \u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510  \u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n\u2502 API\u7f51\u5173     \u2502\u2192 \u2502 \u7279\u5f81\u670d\u52a1     \u2502\u2192 \u2502 \u7279\u5f81\u6570\u636e\u5e93   \u2502\n\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518  \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518  \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n       \u2502\n       \u251c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2192 \u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510  \u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n       \u2502          \u2502 \u6a21\u578b\u670d\u52a1     \u2502\u2192 \u2502 \u6a21\u578b\u5b58\u50a8     \u2502\n       \u2502          \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518  \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n       \u2502\n       \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2192 \u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510  \u250c\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2510\n                  \u2502 \u65e5\u5fd7\u670d\u52a1     \u2502\u2192 \u2502 \u65e5\u5fd7\u5b58\u50a8     \u2502\n                  \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518  \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n
"},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_5","title":"\u6570\u636e\u6d41\u6c34\u7ebf","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_6","title":"\u6279\u5904\u7406","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_7","title":"\u6d41\u5904\u7406","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#lambda","title":"Lambda\u67b6\u6784","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#ml","title":"ML\u8bad\u7ec3\u57fa\u7840\u8bbe\u65bd","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#gpu","title":"GPU\u96c6\u7fa4","text":" "},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_8","title":"\u7f51\u7edc\u62d3\u6251","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_9","title":"\u8bad\u7ec3\u5b58\u50a8","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_10","title":"\u4f5c\u4e1a\u8c03\u5ea6","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_11","title":"\u5bb9\u9519","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_12","title":"\u6210\u672c\u548c\u6548\u7387","text":" \u7ec4\u4ef6 \u5360\u603b\u6210\u672c\u767e\u5206\u6bd4 GPU\u8ba1\u7b97 70-80% \u7f51\u7edc\uff08InfiniBand\uff09 10-15% \u5b58\u50a8 5-10% \u51b7\u5374\u548c\u7535\u6e90 5-10% 
"},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_13","title":"\u6570\u636e\u5e93\u6269\u5c55","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_14","title":"\u641c\u7d22\u548c\u5411\u91cf\u7cfb\u7edf","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_15","title":"\u6587\u672c\u641c\u7d22","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_16","title":"\u5411\u91cf\u641c\u7d22","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_17","title":"\u53ef\u89c2\u6d4b\u6027","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_18","title":"\u65e5\u5fd7","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_19","title":"\u6307\u6807","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_20","title":"\u8ffd\u8e2a","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#_21","title":"\u53ef\u9760\u6027","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/03.%20large%20scale%20infrastructure/#cicd","title":"CI/CD","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/","title":"ML\u7cfb\u7edf\u8bbe\u8ba1","text":"

ML\u7cfb\u7edf\u8bbe\u8ba1\u5c06\u6587\u4ef601-03\u4e2d\u7684\u57fa\u7840\u8bbe\u65bd\u6a21\u5f0f\u5e94\u7528\u4e8e\u673a\u5668\u5b66\u4e60\u7684\u7279\u5b9a\u6311\u6218\u3002\u672c\u6587\u4ef6\u6db5\u76d6ML\u751f\u547d\u5468\u671f\u3001\u6570\u636e\u7ba1\u7406\u3001\u8bad\u7ec3\u57fa\u7840\u8bbe\u65bd\u3001\u6a21\u578b\u8bc4\u4f30\u3001\u670d\u52a1\u7b56\u7565\u3001\u7279\u5f81\u5de5\u7a0b\u3001ML\u6d41\u6c34\u7ebf\u548c\u76d1\u63a7\u3002

"},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#ml_1","title":"ML\u7cfb\u7edf\u751f\u547d\u5468\u671f","text":"
\u95ee\u9898\u5b9a\u4e49 \u2192 \u6570\u636e \u2192 \u7279\u5f81 \u2192 \u8bad\u7ec3 \u2192 \u8bc4\u4f30 \u2192 \u90e8\u7f72 \u2192 \u76d1\u63a7 \u2192 \u8fed\u4ee3\n       \u2191                                                        \u2502\n       \u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n
"},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_1","title":"\u95ee\u9898\u5b9a\u4e49","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_2","title":"\u6570\u636e\u7ba1\u7406","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_3","title":"\u6570\u636e\u6536\u96c6\u548c\u6807\u6ce8","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_4","title":"\u6570\u636e\u8d28\u91cf","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_5","title":"\u7279\u5f81\u5b58\u50a8","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_6","title":"\u8bad\u7ec3\u57fa\u7840\u8bbe\u65bd","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_7","title":"\u6a21\u578b\u8bc4\u4f30","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_8","title":"\u79bb\u7ebf\u8bc4\u4f30","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_9","title":"\u5728\u7ebf\u8bc4\u4f30","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_10","title":"\u6a21\u578b\u670d\u52a1","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#vs","title":"\u6279\u91cfvs\u5b9e\u65f6","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_11","title":"\u6a21\u578b\u7248\u672c\u7ba1\u7406\u548c\u6ce8\u518c\u8868","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_12","title":"\u7279\u5f81\u5de5\u7a0b","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#vs_1","title":"\u5728\u7ebfvs\u79bb\u7ebf\u7279\u5f81","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%
20systems%20design/#_13","title":"\u5e38\u89c1\u7279\u5f81\u6a21\u5f0f","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#ml_2","title":"ML\u6d41\u6c34\u7ebf","text":"
\u6570\u636e\u6444\u5165 \u2192 \u9a8c\u8bc1 \u2192 \u7279\u5f81\u5de5\u7a0b \u2192 \u8bad\u7ec3 \u2192 \u8bc4\u4f30 \u2192 \u6ce8\u518c \u2192 \u90e8\u7f72 \u2192 \u76d1\u63a7\n
"},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_14","title":"\u76d1\u63a7","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_15","title":"\u6570\u636e\u6f02\u79fb","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_16","title":"\u6982\u5ff5\u6f02\u79fb","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_17","title":"\u6a21\u578b\u9000\u5316","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_18","title":"\u53cd\u9988\u5faa\u73af","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_19","title":"\u5d4c\u5165\u8868\u7ba1\u7406","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/04.%20ML%20systems%20design/#_20","title":"\u516c\u5e73\u6027\u548c\u504f\u89c1","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/","title":"ML\u8bbe\u8ba1\u793a\u4f8b","text":"

\u5b66\u4e60ML\u7cfb\u7edf\u8bbe\u8ba1\u7684\u6700\u4f73\u65b9\u5f0f\u662f\u901a\u8fc7\u5b9e\u64cd\u793a\u4f8b\u3002\u672c\u6587\u4ef6\u8be6\u7ec6\u4ecb\u7ecd\u4e86\u4e03\u4e2a\u5b8c\u6574\u7684\u8bbe\u8ba1\uff1a\u63a8\u8350\u7cfb\u7edf\u3001\u641c\u7d22\u6392\u5e8f\u3001\u5e7f\u544a\u70b9\u51fb\u9884\u6d4b\u3001\u6b3a\u8bc8\u68c0\u6d4b\u3001\u5185\u5bb9\u5ba1\u6838\u3001\u5bf9\u8bdd\u5f0fAI\u548c\u5927\u89c4\u6a21\u56fe\u50cf\u641c\u7d22\u3002

"},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#1-youtubenetflixspotify","title":"1. \u63a8\u8350\u7cfb\u7edf\uff08\u4f8b\u5982YouTube\u3001Netflix\u3001Spotify\uff09","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_1","title":"\u95ee\u9898\u5b9a\u4e49","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_2","title":"\u67b6\u6784\uff1a\u4e24\u9636\u6bb5\u6d41\u6c34\u7ebf","text":"
1\u4ebf\u4e2a\u9879\u76ee \u2192 \u5019\u9009\u751f\u6210\uff08\u5feb\u901f\u3001\u7c97\u7565\uff09\u2192 1000\u4e2a\u5019\u9009\n          \u2192 \u6392\u5e8f\uff08\u7f13\u6162\u3001\u7cbe\u786e\uff09\u2192 100\u4e2a\u6392\u5e8f\u9879\u76ee\n          \u2192 \u91cd\u65b0\u6392\u5e8f\uff08\u4e1a\u52a1\u89c4\u5219\uff09\u2192 \u5c55\u793a\u7ed9\u7528\u6237\u768420\u4e2a\n
"},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_3","title":"\u5019\u9009\u751f\u6210","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_4","title":"\u6392\u5e8f","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_5","title":"\u91cd\u65b0\u6392\u5e8f","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_6","title":"\u7c97\u7565\u4f30\u7b97\u6570\u5b57","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_7","title":"\u51b7\u542f\u52a8","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_8","title":"\u8bc4\u4f30","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#2-googlebing","title":"2. \u641c\u7d22\u6392\u5e8f\uff08\u4f8b\u5982Google\u3001Bing\uff09","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_9","title":"\u95ee\u9898\u5b9a\u4e49","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_10","title":"\u67b6\u6784\uff1a\u67e5\u8be2\u7406\u89e3\u2192\u68c0\u7d22\u2192\u6392\u5e8f","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_11","title":"\u67e5\u8be2\u7406\u89e3","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_12","title":"\u68c0\u7d22","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_13","title":"\u6392\u5e8f","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_14","title":"\u7279\u5f81","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#3","title":"3. 
\u5e7f\u544a\u70b9\u51fb\u9884\u6d4b","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_15","title":"\u95ee\u9898\u5b9a\u4e49","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_16","title":"\u67b6\u6784","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_17","title":"\u5b9e\u65f6\u7ade\u4ef7","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#4","title":"4. \u6b3a\u8bc8\u68c0\u6d4b","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_18","title":"\u95ee\u9898\u5b9a\u4e49","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_19","title":"\u67b6\u6784","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_20","title":"\u7279\u5f81","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_21","title":"\u6a21\u578b","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_22","title":"\u4eba\u5728\u56de\u8def\u4e2d","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#5","title":"5. \u5185\u5bb9\u5ba1\u6838","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_23","title":"\u95ee\u9898\u5b9a\u4e49","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_24","title":"\u67b6\u6784","text":"
if text_model.hate_speech_score > 0.9:\n    action = \"remove\"\nelif text_model.hate_speech_score > 0.7:\n    action = \"human_review\"\nelse:\n    action = \"allow\"\n
"},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#vs","title":"\u4e3b\u52a8vs\u88ab\u52a8\u5ba1\u6838","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_25","title":"\u54c8\u5e0c\u5339\u914d","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_26","title":"\u7c97\u7565\u4f30\u7b97\u6570\u5b57","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_27","title":"\u5347\u7ea7\u5de5\u4f5c\u6d41","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#6-airag","title":"6. \u5bf9\u8bdd\u5f0fAI\uff08\u57fa\u4e8eRAG\u7684\u804a\u5929\u673a\u5668\u4eba\uff09","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_28","title":"\u95ee\u9898\u5b9a\u4e49","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#rag","title":"\u67b6\u6784\uff1a\u68c0\u7d22\u589e\u5f3a\u751f\u6210\uff08RAG\uff09","text":"
\u7528\u6237\u67e5\u8be2 \u2192 \u67e5\u8be2\u5d4c\u5165 \u2192 \u5411\u91cf\u641c\u7d22\uff08\u6587\u6863\uff09\u2192 Top-K\u5757\n                                                      \u2193\n\u7528\u6237\u67e5\u8be2 + \u68c0\u7d22\u5230\u7684\u5757 \u2192 LLM \u2192 \u54cd\u5e94\uff08\u542b\u5f15\u7528\uff09\n
"},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_29","title":"\u7ec4\u4ef6","text":"
\u7cfb\u7edf\uff1a\u4f60\u662f\u4e00\u4e2a\u6709\u7528\u7684\u52a9\u624b\u3002\u4ec5\u57fa\u4e8e\u63d0\u4f9b\u7684\u4e0a\u4e0b\u6587\u56de\u7b54\u3002\n\u5982\u679c\u7b54\u6848\u4e0d\u5728\u4e0a\u4e0b\u6587\u4e2d\uff0c\u8bf7\u8bf4\"\u6211\u4e0d\u77e5\u9053\u3002\"\n\n\u4e0a\u4e0b\u6587\uff1a\n[\u5757 1]\n[\u5757 2]\n...\n\n\u7528\u6237\uff1a{\u95ee\u9898}\n
"},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_30","title":"\u67e5\u8be2\u91cd\u5199","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_31","title":"\u7c97\u7565\u4f30\u7b97\u6570\u5b57","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_32","title":"\u8bc4\u4f30","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#7","title":"7. \u5927\u89c4\u6a21\u56fe\u50cf\u641c\u7d22","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_33","title":"\u95ee\u9898\u5b9a\u4e49","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_34","title":"\u67b6\u6784","text":"
\u67e5\u8be2\u56fe\u50cf \u2192 \u5d4c\u5165\u6a21\u578b\uff08ViT/CLIP\uff09\u2192 512\u7ef4\u5411\u91cf \u2192 ANN\u641c\u7d22 \u2192 Top-K\u7ed3\u679c\n
"},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_35","title":"\u5d4c\u5165\u63d0\u53d6","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_36","title":"\u7d22\u5f15","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_37","title":"\u670d\u52a1","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_38","title":"\u8bc4\u4f30","text":""},{"location":"chapter%2018%3A%20ML%20systems%20design/05.%20ML%20design%20examples/#_39","title":"\u9762\u8bd5\u6846\u67b6","text":""},{"location":"chapter%2019%3A%20applied%20AI/01.%20AI%20for%20finance/","title":"AI for Finance","text":""},{"location":"chapter%2019%3A%20applied%20AI/02.%20protein%20design/","title":"AI for Biology","text":""},{"location":"chapter%2020%3A%20bleeding%20edge%20AI/01.%20quantum%20machine%20learning/","title":"\u91cf\u5b50\u673a\u5668\u5b66\u4e60 (Quantum Machine Learning)","text":""},{"location":"chapter%2020%3A%20bleeding%20edge%20AI/02.%20neuromorphic%20computing/","title":"\u795e\u7ecf\u5f62\u6001\u8ba1\u7b97 (Neuromorphic Computing)","text":""}]}