当前位置: 首页 > news >正文

这个 Python 泛型仓库让你少写 80% 重复代码(附代码)

本文约4000字,建议阅读5分钟本文介绍了用 Python 泛型和 SQLAlchemy 实现通用仓库,告别重复 CRUD。

你还在为每个实体手写CRUD?这个Python泛型仓库模式让你一次编写,随处复用


一个真实场景:刚接手一个FastAPI项目,打开代码库,UserRepository、ProductRepository、OrderRepository……每个文件都在重复同样的save、get、update、delete逻辑。复制粘贴了8次之后,我开始怀疑人生——我们真的需要为每个数据表写一遍相同的代码吗?

如果你也有同样的困惑,今天这篇文章会给你一个答案。我将带你用Python泛型和SQLAlchemy,实现一个类型安全、可扩展、可复用的通用仓库模式,让你从此告别重复的CRUD代码。

重复的代码,重复的痛苦


在大多数FastAPI或SQLAlchemy项目中,仓库层(Repository)长这样:

class UserRepository: def __init__(self, session: AsyncSession): self._session = session asyncdef save(self, user: User) -> User: model = UserModel(name=user.name, email=user.email) self._session.add(model) await self._session.flush() await self._session.refresh(model) return self._to_entity(model) asyncdef get(self, user_id: UUID) -> User | None: result = await self._session.scalar( select(UserModel).where(UserModel.id == user_id) ) return self._to_entity(result) if result elseNone # ... 更多方法

然后你创建ProductRepository——复制粘贴。


OrderRepository——再次复制粘贴。

每个仓库都包含:

  • 相同的CRUD逻辑

  • 相同的分页逻辑

  • 相同的错误处理

  • 相同的SQLAlchemy操作模式

唯一变化的只有三样东西:

  • 实体类型(如User)

  • ORM模型类型(如UserModel)

  • 实体与模型之间的映射

⚠️ 注意:这种重复代码是“复制粘贴综合症”的典型表现,90%的团队在这里踩坑——当业务逻辑需要修改时,你要在8个仓库里改8遍,漏改一个就是Bug。

解决方案:一个通用的抽象仓库


一个设计良好的通用仓库应该做到:

  • 实现所有常见CRUD操作

  • 支持分页、排序、存在性检查、计数

  • 通过Python泛型保证类型安全

  • 允许自定义实体与模型的映射

  • 允许每个仓库自定义过滤条件

  • 保持代码整洁、可扩展、易测试

下面是一份生产级的实现代码。

核心组件:实体基类


首先,需要一个所有领域实体共享的基类,保证统一的结构:

from dataclasses import dataclass, fieldfrom datetime import datetime, timezonefrom uuid import UUID@dataclass(kw_only=True)class EntityBase: id: UUID | None = None created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))


辅助工具:异常与排序


class DatabaseException(Exception): """数据库操作异常的统一包装""" passfrom enum import StrEnumclass Ordering(StrEnum): """排序方向,类型安全""" asc = "asc" desc = "desc"

通用仓库实现


这是整个模式的核心。我把它拆成两部分讲解,但你可以直接复制使用。

from abc import ABC, abstractmethodfrom typing import Any, Generic, List, TypeVarimport sqlalchemyfrom sqlalchemy import asc, desc, func, selectfrom sqlalchemy.exc import IntegrityError, SQLAlchemyErrorfrom sqlalchemy.ext.asyncio import AsyncSession# 假设你的Base类在这里定义from .... import Basefrom domain.value_objects.ordering import Orderingfrom domain.entities.base import EntityBasefrom domain.exceptions.common import DatabaseExceptionEntity = TypeVar("Entity", bound=EntityBase)SqlAlchemyModel = TypeVar("SqlAlchemyModel", bound=Base)class SqlAlchemyAbstractRepository(ABC, Generic[Entity, SqlAlchemyModel]): # 子类必须指定具体的ORM模型类 model: type[SqlAlchemyModel] def __init__(self, session: AsyncSession) -> None: self._session = session asyncdef save(self, entity: Entity) -> Entity: """保存实体,返回包含数据库生成字段(如ID)的完整实体""" model = self._entity_to_model(entity) self._session.add(model) await self._session.flush() await self._session.refresh(model) return self._model_to_entity(model) asyncdef update( self, fields_to_update: dict[str, Any], **filters, ) -> int: """根据过滤条件更新字段,返回受影响的行数""" try: filter_conditions = self._get_filters(**filters) query = ( sqlalchemy.update(self.model) .where(*filter_conditions) .values(fields_to_update) ) result = await self._session.execute(query) await self._session.flush() return result.rowcount # type: ignore[attr-defined] except IntegrityError as exception: await self._session.rollback() raise exception except SQLAlchemyError as exception: await self._session.rollback() raise DatabaseException from exception asyncdef list_all( self, page: int = 1, limit: int = 10, order_by: str = "created_at", ordering: Ordering = Ordering.asc, **filters, ) -> List[Entity]: """分页列表查询,支持排序和过滤""" query = select(self.model) filter_conditions = self._get_filters(**filters) query = query.where(*filter_conditions) # 排序 query = query.order_by( self._get_order_expression(order_by=order_by, ordering=ordering) ) # 分页 offset = (page - 1) * limit query = query.offset(offset).limit(limit) result = await self._session.execute(query) models = result.scalars().all() return [self._model_to_entity(model) for model in models] asyncdef get( self, **filters, ) -> Entity | None: """根据过滤条件获取单个实体""" query = select(self.model) filter_conditions = self._get_filters(**filters) query = query.where(*filter_conditions) model = await self._session.scalar(query) return self._model_to_entity(model) if model elseNone asyncdef exists( self, **filters, ) -> bool: """检查是否存在满足条件的记录""" query = select(self.model) filter_conditions = self._get_filters(**filters) query = query.where(*filter_conditions) result = await self._session.scalar(query) return result isnotNone asyncdef delete( self, **filters, ) -> int: """根据过滤条件删除记录,返回删除的行数""" try: query = sqlalchemy.delete(self.model) filter_conditions = self._get_filters(**filters) query = query.where(*filter_conditions) result = await self._session.execute(query) await self._session.flush() return result.rowcount # type: ignore[attr-defined] except SQLAlchemyError as e: await self._session.rollback() raise DatabaseException from e asyncdef count( self, **filters, ) -> int: """统计满足条件的记录数""" filter_conditions = self._get_filters(**filters) return ( await self._session.scalar( select(func.count()).select_from(self.model).where(*filter_conditions) ) or0 ) @staticmethod @abstractmethod def _model_to_entity(model: SqlAlchemyModel) -> Entity: """将ORM模型转换为领域实体——子类必须实现""" raise NotImplementedError("Subclasses must implement _model_to_entity") @staticmethod @abstractmethod def _entity_to_model(entity: Entity) -> SqlAlchemyModel: """将领域实体转换为ORM模型——子类必须实现""" raise NotImplementedError("Subclasses must implement _entity_to_model") @abstractmethod def _get_filters(self, **filters) -> List[Any]: """将业务层过滤条件转换为SQLAlchemy查询条件——子类可重写""" return [] @staticmethod def _get_order_expression( order_by: str, ordering: Ordering ) -> sqlalchemy.UnaryExpression[str]: """生成排序表达式""" if ordering == Ordering.asc: return asc(order_by) return desc(order_by)

泛型解析:用生活化类比理解


如果上面这段代码让你有点晕,我用一个类比帮你理清:

泛型就像订餐平台的模板:

  • Entity = TypeVar("Entity", bound=EntityBase) —— 这就像“我要一份饭”,但具体是盖浇饭还是炒饭,后面再定

  • Model = TypeVar("Model", bound=Base) —— 这就像“我要一个餐具”,具体是碗还是盘子,也后面再定

  • SqlAlchemyAbstractRepository[Entity, Model] —— 这个组合就像“我要一份(某种饭)搭配(某种餐具)的套餐”

当你创建具体仓库时:

class UserRepository(SqlAlchemyAbstractRepository[User, UserModel]): ...

就相当于说:“我要一份User饭装在UserModel餐具里。”

IDE现在就能准确知道:

  • save() 接收User,返回User

  • _model_to_entity() 必须把UserModel映射成User

  • 过滤条件只接受对User有效的字段

⚠️ 关键点:Python虽然是动态语言,但通过类型提示和泛型,你可以获得编译时类型检查的能力。这在多人协作时,能避免无数“不小心传错参数”的Bug。

实战:创建具体的UserRepository


现在创建一个用户仓库,你会发现只需要写三件事:

  1. 指定model类

  2. 实现映射逻辑

  3. 定义支持的过滤条件

class SqlAlchemyUserRepository( SqlAlchemyAbstractRepository[User, UserModel],): model = UserModel def _entity_to_model(self, entity: User) -> UserModel: model = UserModel( name=entity.name, email=entity.email, role=entity.role, ) # 如果实体已有ID(更新场景),保持ID if entity.id: model.id = entity.id return model def _model_to_entity(self, model: UserModel) -> User: return User( id=model.id, name=model.name, email=model.email, role=model.role, created_at=model.created_at, updated_at=model.updated_at, ) def _get_filters(self, **filters): """支持三种过滤条件:id、email、role""" conditions = [] if"id_filter"in filters: conditions.append(UserModel.id == filters["id_filter"]) if"email_filter"in filters: conditions.append(UserModel.email == filters["email_filter"]) if"role_filter"in filters: conditions.append(UserModel.role == filters["role_filter"]) return conditions

看到没? 整个仓库就这么点代码。

  • CRUD?已经处理好了

  • 分页?已经处理好了

  • 错误处理?已经处理好了

你的仓库只需要关注领域特有的逻辑。

为什么_get_filters这么重要?


它让你的查询API既干净又灵活:

# 查询管理员admins = await user_repo.list_all( role_filter="admin", page=1, limit=20)# 按邮箱查找单个用户user = await user_repo.get(email_filter="john@example.com")# 检查用户是否存在exists = await user_repo.exists(email_filter="john@example.com")不需要为每个查询写单独的SQL,所有过滤条件统一通过_get_filters转换为查询条件。

自定义错误处理:保留灵活扩展的空间


需要处理特定业务的数据库错误?只需覆盖方法:

class SqlAlchemyUserRepository(...): # ... 前面的代码 asyncdef save(self, entity: User) -> User: try: returnawait super().save(entity) except IntegrityError as e: await self._session.rollback() # 检查是否是邮箱重复 if"ix_users_email"in str(e): raise UserAlreadyExistsError(entity.email) raise

⚠️ 注意:这里的关键是await self._session.rollback()——忘记回滚会让session处于异常状态,后续操作都会失败。这是90%的人踩过的坑。

添加自定义方法:通用 ≠ 不能定制


通用仓库不代表不能添加特定查询:

class SqlAlchemyUserRepository(...): # ... 前面的代码 asyncdef get_by_email(self, email: str) -> User | None: """按邮箱获取用户(业务常用)""" returnawait self.get(email_filter=email) asyncdef get_active_admins(self) -> List[User]: """获取活跃管理员(业务特定)""" returnawait self.list_all( role_filter="admin", status_filter="active" )

通用 ≠ 限制,而是从强大的基础上开始。

真实项目效果对比


在重构一个中等规模的FastAPI项目后,数据是这样的:

维度

重构前

重构后

仓库数量

8个

8个

单个仓库代码量

250-400行

30-50行

CRUD重复代码

每个仓库重复

0(全部复用)

修改分页逻辑

改8个地方

改1个地方

类型安全

❌ 随意传参

✅ 编译时检查

核心洞察:这种模式不仅减少了代码量,更重要的是——逻辑集中在一处,修改一次生效全局,Bug率显著下降。

为什么这个模式值得你采用?


1. DRY原则落地
写一次,修一次,处处生效。

2. 一致性保障
所有仓库行为统一,新人上手零学习成本。

3. 类型安全
告别Any和随意传递的字典,IDE能给你准确的代码补全。

4. 可测试性
测试一次基类,所有仓库都得到测试覆盖。

5. 可维护性
想加软删除?在基类改一次,所有仓库自动支持。

6. 灵活性
需要特殊行为?覆盖方法即可,基类不限制你。

写在最后


从复制粘贴8个仓库,到用泛型基类一行行抽象出来,这个过程让我意识到一件事:

好的抽象不是炫技,而是当你需要修改代码时,发现只需要改一个地方。

通用仓库模式在Python生态中并不算新,但它结合async、SQLAlchemy和泛型后,能给你的代码质量带来质的飞跃。下次你再新建一个实体时,不用再写那300行CRUD,只需30行映射和过滤逻辑。

如果你正在维护一个数据访问层臃肿的项目,建议逐个仓库迁移,而不是一次性全量替换。先迁移一个非核心的仓库,验证无误后再逐步推进。

核心内容

  1. 原理:泛型+抽象基类,让CRUD逻辑一次性实现,类型安全有保障

  2. 实践:子类只需实现映射和过滤,所有操作自动获得

  3. 避坑:记得处理事务回滚,自定义过滤用_get_filters统一入口

编辑:于腾凯

校对:孙英杰

关于我们

数据派THU作为数据科学类公众号,背靠清华大学大数据研究中心,分享前沿数据科学与大数据技术创新研究动态、持续传播数据科学知识,努力建设数据人才聚集平台、打造中国大数据最强集团军。

新浪微博:@数据派THU

微信视频号:数据派THU

今日头条:数据派THU

http://www.jsqmd.com/news/775730/

相关文章:

  • 收藏 | 程序员小白必看:揭秘大模型 KVCache 的演进与优化秘籍
  • 亲身感受:我眼中的壹肆叁叁教育咨询(山东)有限公司 - 速递信息
  • 手把手教你为i.MX6ULL开发板驱动1.3寸ST7789 TFT屏(含完整设备树与驱动代码)
  • 在树莓派4B(ARM64)上源码编译PyQt5完整流程:从Python3.7到解决Qt::ItemDataRole编译错误
  • 程序员提效神器:Gemini3.1Pro自动生成代码注释与文档
  • 透明背景图片制作方法大全:从零基础到高效批量处理
  • 【AISMM+ESG融合实践手册】:全球仅12家通过奇点认证的企业都在用的6步嵌入法(附ISO/IEC 42001映射表)
  • 如何为每个Android应用独立设置虚拟位置?FakeLocation精准位置控制方案
  • Qdrant向量数据库MCP服务器:AI智能体标准化工具集成指南
  • CoPaw:开源个人AI工作站部署与实战指南
  • 百度网盘解析工具完整指南:告别限速下载的终极方案
  • ARM调试器在SoC开发中的核心价值与应用实践
  • 如何在Zotero中实现文献阅读进度可视化和智能管理?终极指南
  • 解锁碧蓝航线全自动游戏体验:你的智能航海助手
  • 科研图表数据提取终极指南:如何用WebPlotDigitizer高效获取隐藏数据?
  • SynthID-Image:不可见数字水印技术解析与实践
  • 多终端命令历史实时同步工具multicli的设计与部署指南
  • 为什么92%的AI厂商误读AISMM?奇点大会闭门报告泄露:市场定位错配导致ROI下降47%的实证数据
  • WarcraftHelper完整指南:魔兽争霸III游戏优化终极教程
  • 终极跨平台硬件调优指南:Universal x86 Tuning Utility如何释放你的Intel/AMD设备全部潜力
  • 多智能体协作平台AgentLayer:从架构设计到工程实践
  • Scroll Reverser终极指南:揭秘macOS滚动方向深度定制技术
  • PotPlayer字幕翻译终极指南:免费实现实时双语字幕的完整教程
  • GDScript代码质量工具链:从格式化到静态分析的工程实践
  • Windows全局钩子与透明窗口实现鼠标光标高亮器技术解析
  • 如何快速掌握Jasminum:面向中文研究者的Zotero终极解决方案
  • Sorbetto:为Ruby开发者打造的VS Code增强插件,提升Sorbet开发体验
  • XXMI启动器:一站式二次元游戏模组管理终极指南,告别繁琐手动配置
  • ClipTalk:基于Go的短视频去水印与语音转文字API服务实战
  • 开源工具token-usage-ui:可视化监控LLM API Token用量与成本