跳转至

代码库设计与模式

良好的代码库设计是区分研究原型与生产级软件的关键。本文涵盖项目结构、整洁代码原则、与机器学习相关的设计模式、配置管理、日志、API 设计以及打包分发。

  • 大多数机器学习代码始于 Jupyter notebook。Notebook 不断增长、被复制、修改、共享,最终变成由全局变量、死单元格和魔数组成的难以维护的混乱。代码库设计是一门组织代码的学科,使代码在项目增长过程中保持可理解和可修改。

  • 这不是为了遵循规则而遵循规则。而是为了减少从"我想改变 X"到"X 已被修改并能正常工作"之间的时间。在精心设计的代码库中,这个时间是几分钟。在设计糟糕的代码库中,则需要几天的时间去考古、翻阅未记录的意大利面条式代码。

项目结构

  • 一致的项目布局让任何人(包括未来的你)都能立即浏览代码库。
my_project/
├── src/my_project/       # 源代码(可导入的包)
│   ├── __init__.py
│   ├── data/             # 数据加载和预处理
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── transforms.py
│   ├── models/           # 模型架构
│   │   ├── __init__.py
│   │   ├── transformer.py
│   │   └── layers.py
│   ├── training/         # 训练循环、优化器
│   │   ├── __init__.py
│   │   ├── trainer.py
│   │   └── losses.py
│   └── utils/            # 共享工具
│       ├── __init__.py
│       └── logging.py
├── configs/              # 配置文件
│   ├── base.yaml
│   └── experiment_1.yaml
├── scripts/              # 入口点(训练、评估、推理)
│   ├── train.py
│   ├── evaluate.py
│   └── serve.py
├── tests/                # 测试文件(镜像 src/ 结构)
│   ├── test_dataset.py
│   ├── test_model.py
│   └── test_trainer.py
├── notebooks/            # 仅用于探索(非生产代码)
├── pyproject.toml        # 项目元数据和依赖
├── README.md
├── .gitignore
└── Dockerfile
  • src/ 布局:将源代码放在 src/my_project/ 下可以防止从当前目录意外导入(这会掩盖在生产环境中才会暴露的导入错误)。使用 pip install -e . 进行开发安装。

  • 单仓库 vs 多仓库单仓库将所有相关项目放在一个仓库中(跨项目更改更容易、CI 共享)。多仓库给每个项目自己的仓库(边界更清晰、版本控制独立)。大多数机器学习团队从单仓库开始,必要时再拆分。

  • 脚本 vs 库:将入口点(train.pyevaluate.py)保留在 scripts/ 中。将可复用的逻辑放在 src/ 中。训练脚本应约为 50 行:解析配置、构建数据集、构建模型、构建训练器、训练。所有复杂性都在库中。

整洁代码原则

  • 命名:你能做的唯一最有影响力的事情。名为 x 的变量需要你阅读周围的代码才能理解。名为 learning_rate 的变量是自解释的。
# 糟糕
def proc(d, n, lr):
    for i in range(n):
        for k, v in d.items():
            v -= lr * g[k]

# 良好
def update_parameters(parameters, num_steps, learning_rate):
    for step in range(num_steps):
        for name, param in parameters.items():
            param -= learning_rate * gradients[name]
  • 单一职责原则:每个函数/类只做一件事。名为 load_data_and_train_model 的函数在做两件事,应该拆分。这使每个部分都可以独立测试、复用和理解。

  • DRY(不要重复自己)——但不要过早抽象。如果你复制粘贴代码三次,将其提取为一个函数。但不要为只使用过一次的代码创建抽象。过早的抽象比重复更糟糕:它增加了复杂性但没有经过验证的好处。

# 过早抽象(一个用例,过度设计)
class AbstractDataTransformPipelineFactory:
    ...

# 恰到好处(直接、清晰、在三处使用)
def normalise_image(image, mean, std):
    return (image - mean) / std
  • 魔数:永远不要使用未解释的字面值。
# 糟糕
if len(batch) > 32:
    split_batch(batch, 32)

# 良好
MAX_BATCH_SIZE = 32
if len(batch) > MAX_BATCH_SIZE:
    split_batch(batch, MAX_BATCH_SIZE)
  • 函数应该简短:如果一个函数不能在一屏内显示完整(约 30 行),那它可能做得太多了。将逻辑块提取为带有描述性名称的辅助函数。然后函数体读起来就像高级摘要。

适用于机器学习的设范计式

  • 设计模式是针对常见问题的可复用解决方案。以下是与机器学习代码库最相关的模式:

  • 工厂模式:在不指定确切类的情况下创建对象。当你的配置说 model: "transformer" 并且你需要实例化正确的类时很有用:

MODEL_REGISTRY = {
    "transformer": TransformerModel,
    "cnn": CNNModel,
    "mlp": MLPModel,
}

def build_model(config):
    model_cls = MODEL_REGISTRY[config["model"]]
    return model_cls(**config["model_params"])
  • 这使训练脚本与特定的模型实现解耦。添加新模型意味着在注册表中添加一行,而不是修改训练循环。

  • 策略模式:在运行时交换算法。适用于损失函数、优化器、调度器:

LOSS_FUNCTIONS = {
    "mse": nn.MSELoss,
    "cross_entropy": nn.CrossEntropyLoss,
    "focal": FocalLoss,
}

loss_fn = LOSS_FUNCTIONS[config["loss"]]()
  • 观察者模式(回调/钩子):让模块响应事件而不紧密耦合。训练框架(PyTorch Lightning、Keras)广泛使用回调:
class EarlyStopping:
    def __init__(self, patience=5):
        self.patience = patience
        self.best_loss = float('inf')
        self.counter = 0

    def on_epoch_end(self, epoch, val_loss):
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return "stop"
  • 依赖注入:将依赖项传入函数/类,而不是在内部创建。这使得测试变得容易(注入 mock)并且配置灵活:
# 糟糕:硬编码依赖
class Trainer:
    def __init__(self):
        self.logger = WandbLogger()  # 没有 W&B 就无法测试

# 良好:注入依赖
class Trainer:
    def __init__(self, logger):
        self.logger = logger  # 可以注入任何记录器,包括 mock

配置管理

  • 硬编码超参数、文件路径和模型设置使实验无法重现,修改也很痛苦。将配置外部化到文件中。

  • YAML 是机器学习配置最常见的格式:

# configs/experiment_1.yaml
model:
  name: transformer
  d_model: 512
  n_heads: 8
  n_layers: 6

training:
  batch_size: 64
  learning_rate: 3e-4
  max_epochs: 100
  early_stopping_patience: 10

data:
  train_path: /data/train.parquet
  val_path: /data/val.parquet
  max_seq_length: 512
  • Hydra(Facebook)是一个支持组合(将基础配置与实验特定覆盖合并)、命令行覆盖(python train.py training.lr=1e-3)和多运行(超参数扫描)的配置框架。

  • argparse 适用于参数较少的脚本:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=3e-4)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--config", type=str, default="configs/base.yaml")
args = parser.parse_args()
  • 最佳实践:有一个包含所有默认值的基础配置,以及每个实验的配置,只覆盖更改的部分。追踪每个实验的配置及其结果。

日志与可观测性

  • print 语句用于调试。日志用于生产环境:
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

logger.debug("Batch loaded: %d samples", len(batch))     # 详细,用于调试
logger.info("Epoch %d: loss=%.4f, lr=%.6f", epoch, loss, lr)  # 正常运行
logger.warning("GPU memory >90%%, consider reducing batch size")
logger.error("Failed to load checkpoint: %s", path)       # 可恢复的错误
logger.critical("CUDA out of memory, aborting")            # 致命错误
  • 为什么不用 print:日志支持级别(在生产环境中过滤调试消息)、格式化(时间戳、模块名)和处理程序(写入文件、发送到监控系统),而无需更改日志调用。

  • 结构化日志同时输出机器可解析的格式(JSON)和人类可读的消息。这使得可以搜索特定字段并设置告警:

logger.info("training_step", extra={
    "epoch": 5, "step": 1200, "loss": 0.0342, "lr": 2.1e-4
})

API 设计

  • 如果你的模型将被其他服务使用(Web 应用、移动应用、另一个机器学习管道),它需要一个 API(应用程序编程接口)。

  • REST API 使用 HTTP 方法:GET 用于读取,POST 用于创建/预测,PUT 用于更新,DELETE 用于删除。端点遵循基于资源的命名:

POST /api/v1/predict          # 发送输入,获取预测结果
GET  /api/v1/models           # 列出可用模型
GET  /api/v1/models/{id}      # 获取模型详情
POST /api/v1/models/{id}/predict  # 使用特定模型进行预测
  • FastAPI 是机器学习推理的首选 Python 框架:
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class PredictRequest(BaseModel):
    text: str

class PredictResponse(BaseModel):
    label: str
    confidence: float

@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    result = model.predict(request.text)
    return PredictResponse(label=result.label, confidence=result.score)
  • FastAPI 自动生成 API 文档(在 /docs 的 Swagger UI),使用 Pydantic 模型验证输入/输出,并支持异步以实现高吞吐量。

  • gRPC 在内部服务间通信方面比 REST 更快。它使用 Protocol Buffers(二进制序列化,比 JSON 更小更快)并支持流式传输。TensorFlow Serving、Triton Inference Server 和许多微服务架构都使用它。

打包与分发

  • 让你的代码可以作为包安装,使其他人(和你自己的脚本)可以干净地导入:
# pyproject.toml
[project]
name = "my-ml-project"
version = "0.1.0"
requires-python = ">=3.10"
dependencies = [
    "torch>=2.0",
    "jax>=0.4",
    "pydantic>=2.0",
]

[project.optional-dependencies]
dev = ["pytest", "ruff", "mypy"]

[build-system]
requires = ["setuptools>=64"]
build-backend = "setuptools.backends._legacy:_Backend"
pip install -e ".[dev]"    # 以可编辑模式安装,包含开发依赖
  • 可编辑安装-e):对源代码的更改会立即生效,无需重新安装。开发期间必不可少。

  • 锁定依赖:使用确切版本的 requirements.txttorch==2.2.1,而不是 torch>=2.0)确保可重现性。使用 pip freeze > requirements.txt 捕获你当前的环境。对于更复杂的依赖管理,使用 uvpoetrypip-tools

使用 AI 编码助手

  • AI 编码助手(Claude Code、GitHub Copilot、Cursor 等)现在已成为专业工程师工作流程的一部分。使用得当,它们能极大加速开发。使用不当,它们会引入微妙的错误、侵蚀你对代码库的理解,并制造虚假的生产力感。

  • 正确的心智模型:AI 助手是一个快速但缺乏经验的结对程序员。它可以快速编写代码,熟悉语法和标准模式,并且阅读过的文档比你还多。但它不了解你的特定系统、业务约束、边界情况以及设计决策背后的原因。你是高级工程师;AI 助手是初级工程师。你来指导、审查并承担责任。

AI 助手擅长之处

  • 样板代码和脚手架:生成 Dockerfile、CI 配置、测试夹具、数据类定义、argparse 设置。这些遵循众所周知的模式,手动编写很繁琐。让 AI 生成它们,然后审查正确性。

  • 编写测试:描述函数的行为,AI 助手生成测试用例。它通常会捕捉到你可能会遗漏的边界情况(空输入、负值、Unicode)。始终阅读生成的测试——它们验证的是你的假设,而不仅仅是你的代码。

  • 重构:"将这个块提取成函数"、"将这个类改为使用 dataclasses"、"给这个模块添加类型提示"。机械性的转换,意图明确,引入细微错误的风险较低。

  • 探索和原型开发:"写一个快速脚本来 benchmark 推理延迟"或"展示如何使用 HuggingFace tokeniser API"。AI 助手能比阅读文档更快地给你一个可用的起点。

  • 文档和 docstrings:AI 助手可以根据你的代码结构生成文档。你需要审查准确性,但苦力活已经自动化了。

  • 调试辅助:粘贴错误回溯信息并请求诊断。AI 助手通常能识别根本原因并提出修复建议,尤其是对于常见问题(形状不匹配、导入错误、CUDA 内存不足)。

何时不应依赖 AI 助手

  • 新颖的架构决策:如果你正在设计一个新的训练管道,AI 助手会给出一个通用的答案。它不了解你的数据约束、延迟要求或团队专业知识。使用 AI 助手来实现你已经深思熟虑的设计。

  • 安全关键代码:认证、加密、输入清理。AI 助手可能生成看起来正确但存在细微漏洞的代码(SQL 注入、不安全的默认值、时序攻击)。安全代码应由理解威胁模型的人编写,并由另一个人审查。

  • 性能关键的内循环:AI 助手会编写正确但天真的代码。对于 GPU 内核、内存关键的数据结构或延迟敏感的推理路径,你需要理解硬件约束(第 13 章、第 16 章)并有目的地进行优化。

  • 你不理解的代码:如果 AI 助手生成了 200 行代码,而你无法解释每一行的作用,那就不要提交。你现在正在维护你不理解的代码,当它出问题时(它会的),你无法调试。这是最常见也最危险的失败模式。

审查纪律

  • 在提交前始终逐行阅读生成的代码。这不是可选的。AI 助手的代码是草稿,不是成品。就像对待同事的拉取请求一样:批判性地审查它。

  • 检查什么

    • 正确性:它是否真的做了你要求的事情?AI 助手经常解决与你意图略有不同的问题。
    • 边界情况:它是否处理了空输入、None 值、负数、非常大的输入?AI 助手经常省略边界情况处理。
    • 幻想的 API:AI 助手可能调用不存在函数或使用不存在的参数,尤其是对于较新或较少使用的库。验证每个 API 调用是否真实存在。
    • 过度工程:AI 助手倾向于产生比必要更多的代码。一个 50 行的解决方案解决一个 10 行的问题,增加了不必要的复杂性。无情地简化。
    • 安全性:硬编码的密钥、未经清理的用户输入、不安全的默认值。AI 助手不会以对抗性思维思考。
    • 风格一致性:生成的代码是否与项目的约定一致(命名、模式、错误处理)?

如何编写好的提示词

  • AI 助手输出的质量直接与你的指令质量成正比。模糊的提示词得到模糊的代码。

  • 糟糕:"写一个数据加载器"

  • :"为一个包含'text'和'label'列的 CSV 文件编写一个 PyTorch DataLoader。使用 HuggingFace tokeniser 'bert-base-uncased' 对文本进行分词,max_length=512。返回 input_ids、attention_mask 和 label 作为张量。处理 CSV 中标签列有缺失值的情况,跳过那些行。"

  • 提供上下文:告诉 AI 助手你的项目结构、现有代码、约束和约定。上下文越多,输出越好。

  • 指定约束:"只使用标准库"、"必须兼容 Python 3.10"、"不要使用全局变量"、"遵循 src/models/transformer.py 中的现有模式"。

  • 要求解释:"实现 X 并解释关键的设计决策。"这会迫使 AI 助手阐述其推理,使你更容易发现错误假设。

使用质量门控来捕捉 AI 助手的错误

  • 你现有的质量基础设施(文件 04)捕捉 AI 助手的错误与捕捉人类的错误同样有效:

    • 类型检查(mypy):捕捉幻想的 API 签名和类型不匹配。
    • 代码检查(ruff):捕捉未使用的导入、未定义的变量和风格违规。
    • 测试(pytest):如果 AI 助手的代码通过了你的测试套件,它更可能是正确的。如果你还没有测试,在要求 AI 助手实现功能之前先编写测试(测试驱动开发与 AI 助手配合得特别好)。
    • CI 管道:在每次提交时自动运行上述所有检查。
  • "AI 助手写代码" + "质量门控验证" 的组合比单独使用任何一种都更高效。AI 助手快速但草率;门控工具彻底但不写代码。两者结合,你同时获得速度和正确性。

生产力陷阱

  • 使用编码助手的最大风险是生产力的幻觉。你可以在 10 分钟内生成 500 行代码。但如果你花 2 小时调试这些你并不理解的 500 行代码,那还不如自己花 30 分钟写 200 行代码来得快。

  • 使用 AI 助手的真正生产力来自:

    1. 保持控制:你决定架构,AI 助手填入实现。
    2. 理解生成的内容:如果你无法解释它,就重写它或让 AI 助手简化它。
    3. 投资质量门控:测试、类型和代码检查的成本通过每次 AI 交互分摊。
    4. 利用 AI 助手弥补你的弱点:如果你擅长算法但编写测试很慢,让 AI 助手写测试。如果你对 UI 代码很快但不熟悉数据库查询,让 AI 助手草拟 SQL。发挥你的优势,委托你的短板。
  • 从编码助手中获益最多的工程师是那些已经擅长编码的人。AI 助手放大你现有的技能;它不会取代你的技能。理解数据结构、算法、系统设计和软件工程(整章的内容)让你能够有效地指导 AI 助手并批判性地评估其输出。