# 代码库设计与模式 *良好的代码库设计是区分研究原型与生产级软件的关键。本文涵盖项目结构、整洁代码原则、与机器学习相关的设计模式、配置管理、日志、API 设计以及打包分发。* - 大多数机器学习代码始于 Jupyter notebook。Notebook 不断增长、被复制、修改、共享,最终变成由全局变量、死单元格和魔数组成的难以维护的混乱。**代码库设计**是一门组织代码的学科,使代码在项目增长过程中保持可理解和可修改。 - 这不是为了遵循规则而遵循规则。而是为了减少从"我想改变 X"到"X 已被修改并能正常工作"之间的时间。在精心设计的代码库中,这个时间是几分钟。在设计糟糕的代码库中,则需要几天的时间去考古、翻阅未记录的意大利面条式代码。 ## 项目结构 - 一致的项目布局让任何人(包括未来的你)都能立即浏览代码库。 ``` my_project/ ├── src/my_project/ # 源代码(可导入的包) │ ├── __init__.py │ ├── data/ # 数据加载和预处理 │ │ ├── __init__.py │ │ ├── dataset.py │ │ └── transforms.py │ ├── models/ # 模型架构 │ │ ├── __init__.py │ │ ├── transformer.py │ │ └── layers.py │ ├── training/ # 训练循环、优化器 │ │ ├── __init__.py │ │ ├── trainer.py │ │ └── losses.py │ └── utils/ # 共享工具 │ ├── __init__.py │ └── logging.py ├── configs/ # 配置文件 │ ├── base.yaml │ └── experiment_1.yaml ├── scripts/ # 入口点(训练、评估、推理) │ ├── train.py │ ├── evaluate.py │ └── serve.py ├── tests/ # 测试文件(镜像 src/ 结构) │ ├── test_dataset.py │ ├── test_model.py │ └── test_trainer.py ├── notebooks/ # 仅用于探索(非生产代码) ├── pyproject.toml # 项目元数据和依赖 ├── README.md ├── .gitignore └── Dockerfile ``` - **`src/` 布局**:将源代码放在 `src/my_project/` 下可以防止从当前目录意外导入(这会掩盖在生产环境中才会暴露的导入错误)。使用 `pip install -e .` 进行开发安装。 - **单仓库 vs 多仓库**:**单仓库**将所有相关项目放在一个仓库中(跨项目更改更容易、CI 共享)。**多仓库**给每个项目自己的仓库(边界更清晰、版本控制独立)。大多数机器学习团队从单仓库开始,必要时再拆分。 - **脚本 vs 库**:将入口点(`train.py`、`evaluate.py`)保留在 `scripts/` 中。将可复用的逻辑放在 `src/` 中。训练脚本应约为 50 行:解析配置、构建数据集、构建模型、构建训练器、训练。所有复杂性都在库中。 ## 整洁代码原则 - **命名**:你能做的唯一最有影响力的事情。名为 `x` 的变量需要你阅读周围的代码才能理解。名为 `learning_rate` 的变量是自解释的。 ```python # 糟糕 def proc(d, n, lr): for i in range(n): for k, v in d.items(): v -= lr * g[k] # 良好 def update_parameters(parameters, num_steps, learning_rate): for step in range(num_steps): for name, param in parameters.items(): param -= learning_rate * gradients[name] ``` - **单一职责原则**:每个函数/类只做一件事。名为 `load_data_and_train_model` 的函数在做两件事,应该拆分。这使每个部分都可以独立测试、复用和理解。 - **DRY(不要重复自己)**——但不要过早抽象。如果你复制粘贴代码三次,将其提取为一个函数。但不要为只使用过一次的代码创建抽象。过早的抽象比重复更糟糕:它增加了复杂性但没有经过验证的好处。 ```python # 过早抽象(一个用例,过度设计) class AbstractDataTransformPipelineFactory: ... # 恰到好处(直接、清晰、在三处使用) def normalise_image(image, mean, std): return (image - mean) / std ``` - **魔数**:永远不要使用未解释的字面值。 ```python # 糟糕 if len(batch) > 32: split_batch(batch, 32) # 良好 MAX_BATCH_SIZE = 32 if len(batch) > MAX_BATCH_SIZE: split_batch(batch, MAX_BATCH_SIZE) ``` - **函数应该简短**:如果一个函数不能在一屏内显示完整(约 30 行),那它可能做得太多了。将逻辑块提取为带有描述性名称的辅助函数。然后函数体读起来就像高级摘要。 ## 适用于机器学习的设范计式 - 设计模式是针对常见问题的可复用解决方案。以下是与机器学习代码库最相关的模式: - **工厂模式**:在不指定确切类的情况下创建对象。当你的配置说 `model: "transformer"` 并且你需要实例化正确的类时很有用: ```python MODEL_REGISTRY = { "transformer": TransformerModel, "cnn": CNNModel, "mlp": MLPModel, } def build_model(config): model_cls = MODEL_REGISTRY[config["model"]] return model_cls(**config["model_params"]) ``` - 这使训练脚本与特定的模型实现解耦。添加新模型意味着在注册表中添加一行,而不是修改训练循环。 - **策略模式**:在运行时交换算法。适用于损失函数、优化器、调度器: ```python LOSS_FUNCTIONS = { "mse": nn.MSELoss, "cross_entropy": nn.CrossEntropyLoss, "focal": FocalLoss, } loss_fn = LOSS_FUNCTIONS[config["loss"]]() ``` - **观察者模式**(回调/钩子):让模块响应事件而不紧密耦合。训练框架(PyTorch Lightning、Keras)广泛使用回调: ```python class EarlyStopping: def __init__(self, patience=5): self.patience = patience self.best_loss = float('inf') self.counter = 0 def on_epoch_end(self, epoch, val_loss): if val_loss < self.best_loss: self.best_loss = val_loss self.counter = 0 else: self.counter += 1 if self.counter >= self.patience: return "stop" ``` - **依赖注入**:将依赖项传入函数/类,而不是在内部创建。这使得测试变得容易(注入 mock)并且配置灵活: ```python # 糟糕:硬编码依赖 class Trainer: def __init__(self): self.logger = WandbLogger() # 没有 W&B 就无法测试 # 良好:注入依赖 class Trainer: def __init__(self, logger): self.logger = logger # 可以注入任何记录器,包括 mock ``` ## 配置管理 - 硬编码超参数、文件路径和模型设置使实验无法重现,修改也很痛苦。**将配置外部化**到文件中。 - **YAML** 是机器学习配置最常见的格式: ```yaml # configs/experiment_1.yaml model: name: transformer d_model: 512 n_heads: 8 n_layers: 6 training: batch_size: 64 learning_rate: 3e-4 max_epochs: 100 early_stopping_patience: 10 data: train_path: /data/train.parquet val_path: /data/val.parquet max_seq_length: 512 ``` - **Hydra**(Facebook)是一个支持组合(将基础配置与实验特定覆盖合并)、命令行覆盖(`python train.py training.lr=1e-3`)和多运行(超参数扫描)的配置框架。 - **argparse** 适用于参数较少的脚本: ```python import argparse parser = argparse.ArgumentParser() parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--config", type=str, default="configs/base.yaml") args = parser.parse_args() ``` - **最佳实践**:有一个包含所有默认值的基础配置,以及每个实验的配置,只覆盖更改的部分。追踪每个实验的配置及其结果。 ## 日志与可观测性 - `print` 语句用于调试。**日志**用于生产环境: ```python import logging logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) logger.debug("Batch loaded: %d samples", len(batch)) # 详细,用于调试 logger.info("Epoch %d: loss=%.4f, lr=%.6f", epoch, loss, lr) # 正常运行 logger.warning("GPU memory >90%%, consider reducing batch size") logger.error("Failed to load checkpoint: %s", path) # 可恢复的错误 logger.critical("CUDA out of memory, aborting") # 致命错误 ``` - **为什么不用 print**:日志支持级别(在生产环境中过滤调试消息)、格式化(时间戳、模块名)和处理程序(写入文件、发送到监控系统),而无需更改日志调用。 - **结构化日志**同时输出机器可解析的格式(JSON)和人类可读的消息。这使得可以搜索特定字段并设置告警: ```python logger.info("training_step", extra={ "epoch": 5, "step": 1200, "loss": 0.0342, "lr": 2.1e-4 }) ``` ## API 设计 - 如果你的模型将被其他服务使用(Web 应用、移动应用、另一个机器学习管道),它需要一个 **API**(应用程序编程接口)。 - **REST API** 使用 HTTP 方法:`GET` 用于读取,`POST` 用于创建/预测,`PUT` 用于更新,`DELETE` 用于删除。端点遵循基于资源的命名: ``` POST /api/v1/predict # 发送输入,获取预测结果 GET /api/v1/models # 列出可用模型 GET /api/v1/models/{id} # 获取模型详情 POST /api/v1/models/{id}/predict # 使用特定模型进行预测 ``` - **FastAPI** 是机器学习推理的首选 Python 框架: ```python from fastapi import FastAPI from pydantic import BaseModel app = FastAPI() class PredictRequest(BaseModel): text: str class PredictResponse(BaseModel): label: str confidence: float @app.post("/predict", response_model=PredictResponse) async def predict(request: PredictRequest): result = model.predict(request.text) return PredictResponse(label=result.label, confidence=result.score) ``` - FastAPI 自动生成 API 文档(在 `/docs` 的 Swagger UI),使用 Pydantic 模型验证输入/输出,并支持异步以实现高吞吐量。 - **gRPC** 在内部服务间通信方面比 REST 更快。它使用 Protocol Buffers(二进制序列化,比 JSON 更小更快)并支持流式传输。TensorFlow Serving、Triton Inference Server 和许多微服务架构都使用它。 ## 打包与分发 - 让你的代码可以作为包安装,使其他人(和你自己的脚本)可以干净地导入: ```toml # pyproject.toml [project] name = "my-ml-project" version = "0.1.0" requires-python = ">=3.10" dependencies = [ "torch>=2.0", "jax>=0.4", "pydantic>=2.0", ] [project.optional-dependencies] dev = ["pytest", "ruff", "mypy"] [build-system] requires = ["setuptools>=64"] build-backend = "setuptools.backends._legacy:_Backend" ``` ```bash pip install -e ".[dev]" # 以可编辑模式安装,包含开发依赖 ``` - **可编辑安装**(`-e`):对源代码的更改会立即生效,无需重新安装。开发期间必不可少。 - **锁定依赖**:使用确切版本的 `requirements.txt`(`torch==2.2.1`,而不是 `torch>=2.0`)确保可重现性。使用 `pip freeze > requirements.txt` 捕获你当前的环境。对于更复杂的依赖管理,使用 `uv`、`poetry` 或 `pip-tools`。 ## 使用 AI 编码助手 - AI 编码助手(Claude Code、GitHub Copilot、Cursor 等)现在已成为专业工程师工作流程的一部分。使用得当,它们能极大加速开发。使用不当,它们会引入微妙的错误、侵蚀你对代码库的理解,并制造虚假的生产力感。 - 正确的心智模型:**AI 助手是一个快速但缺乏经验的结对程序员**。它可以快速编写代码,熟悉语法和标准模式,并且阅读过的文档比你还多。但它不了解你的特定系统、业务约束、边界情况以及设计决策背后的*原因*。你是高级工程师;AI 助手是初级工程师。你来指导、审查并承担责任。 ### AI 助手擅长之处 - **样板代码和脚手架**:生成 Dockerfile、CI 配置、测试夹具、数据类定义、argparse 设置。这些遵循众所周知的模式,手动编写很繁琐。让 AI 生成它们,然后审查正确性。 - **编写测试**:描述函数的行为,AI 助手生成测试用例。它通常会捕捉到你可能会遗漏的边界情况(空输入、负值、Unicode)。始终阅读生成的测试——它们验证的是你的假设,而不仅仅是你的代码。 - **重构**:"将这个块提取成函数"、"将这个类改为使用 dataclasses"、"给这个模块添加类型提示"。机械性的转换,意图明确,引入细微错误的风险较低。 - **探索和原型开发**:"写一个快速脚本来 benchmark 推理延迟"或"展示如何使用 HuggingFace tokeniser API"。AI 助手能比阅读文档更快地给你一个可用的起点。 - **文档和 docstrings**:AI 助手可以根据你的代码结构生成文档。你需要审查准确性,但苦力活已经自动化了。 - **调试辅助**:粘贴错误回溯信息并请求诊断。AI 助手通常能识别根本原因并提出修复建议,尤其是对于常见问题(形状不匹配、导入错误、CUDA 内存不足)。 ### 何时不应依赖 AI 助手 - **新颖的架构决策**:如果你正在设计一个新的训练管道,AI 助手会给出一个通用的答案。它不了解你的数据约束、延迟要求或团队专业知识。使用 AI 助手来实现你已经深思熟虑的设计。 - **安全关键代码**:认证、加密、输入清理。AI 助手可能生成看起来正确但存在细微漏洞的代码(SQL 注入、不安全的默认值、时序攻击)。安全代码应由理解威胁模型的人编写,并由另一个人审查。 - **性能关键的内循环**:AI 助手会编写正确但天真的代码。对于 GPU 内核、内存关键的数据结构或延迟敏感的推理路径,你需要理解硬件约束(第 13 章、第 16 章)并有目的地进行优化。 - **你不理解的代码**:如果 AI 助手生成了 200 行代码,而你无法解释每一行的作用,那就不要提交。你现在正在维护你不理解的代码,当它出问题时(它会的),你无法调试。这是最常见也最危险的失败模式。 ### 审查纪律 - **在提交前始终逐行阅读**生成的代码。这不是可选的。AI 助手的代码是草稿,不是成品。就像对待同事的拉取请求一样:批判性地审查它。 - **检查什么**: - **正确性**:它是否真的做了你要求的事情?AI 助手经常解决与你意图略有不同的问题。 - **边界情况**:它是否处理了空输入、None 值、负数、非常大的输入?AI 助手经常省略边界情况处理。 - **幻想的 API**:AI 助手可能调用不存在函数或使用不存在的参数,尤其是对于较新或较少使用的库。验证每个 API 调用是否真实存在。 - **过度工程**:AI 助手倾向于产生比必要更多的代码。一个 50 行的解决方案解决一个 10 行的问题,增加了不必要的复杂性。无情地简化。 - **安全性**:硬编码的密钥、未经清理的用户输入、不安全的默认值。AI 助手不会以对抗性思维思考。 - **风格一致性**:生成的代码是否与项目的约定一致(命名、模式、错误处理)? ### 如何编写好的提示词 - AI 助手输出的质量直接与你的指令质量成正比。模糊的提示词得到模糊的代码。 - **糟糕**:"写一个数据加载器" - **好**:"为一个包含'text'和'label'列的 CSV 文件编写一个 PyTorch DataLoader。使用 HuggingFace tokeniser 'bert-base-uncased' 对文本进行分词,max_length=512。返回 input_ids、attention_mask 和 label 作为张量。处理 CSV 中标签列有缺失值的情况,跳过那些行。" - **提供上下文**:告诉 AI 助手你的项目结构、现有代码、约束和约定。上下文越多,输出越好。 - **指定约束**:"只使用标准库"、"必须兼容 Python 3.10"、"不要使用全局变量"、"遵循 `src/models/transformer.py` 中的现有模式"。 - **要求解释**:"实现 X 并解释关键的设计决策。"这会迫使 AI 助手阐述其推理,使你更容易发现错误假设。 ### 使用质量门控来捕捉 AI 助手的错误 - 你现有的质量基础设施(文件 04)捕捉 AI 助手的错误与捕捉人类的错误同样有效: - **类型检查(mypy)**:捕捉幻想的 API 签名和类型不匹配。 - **代码检查(ruff)**:捕捉未使用的导入、未定义的变量和风格违规。 - **测试(pytest)**:如果 AI 助手的代码通过了你的测试套件,它更可能是正确的。如果你还没有测试,在要求 AI 助手实现功能之前*先编写测试*(测试驱动开发与 AI 助手配合得特别好)。 - **CI 管道**:在每次提交时自动运行上述所有检查。 - **"AI 助手写代码" + "质量门控验证"** 的组合比单独使用任何一种都更高效。AI 助手快速但草率;门控工具彻底但不写代码。两者结合,你同时获得速度和正确性。 ### 生产力陷阱 - 使用编码助手的最大风险是**生产力的幻觉**。你可以在 10 分钟内生成 500 行代码。但如果你花 2 小时调试这些你并不理解的 500 行代码,那还不如自己花 30 分钟写 200 行代码来得快。 - 使用 AI 助手的真正生产力来自: 1. **保持控制**:你决定架构,AI 助手填入实现。 2. **理解生成的内容**:如果你无法解释它,就重写它或让 AI 助手简化它。 3. **投资质量门控**:测试、类型和代码检查的成本通过每次 AI 交互分摊。 4. **利用 AI 助手弥补你的弱点**:如果你擅长算法但编写测试很慢,让 AI 助手写测试。如果你对 UI 代码很快但不熟悉数据库查询,让 AI 助手草拟 SQL。发挥你的优势,委托你的短板。 - 从编码助手中获益最多的工程师是那些已经擅长编码的人。AI 助手放大你现有的技能;它不会取代你的技能。理解数据结构、算法、系统设计和软件工程(整章的内容)让你能够有效地指导 AI 助手并批判性地评估其输出。