Python数据库编程与ORM
Python数据库编程与ORM
一、数据库连接基础
Python通过DB-API 2.0规范(PEP 249)统一了数据库接口。不同数据库使用不同的驱动,但连接、游标、执行、提交等API基本一致;需要注意的是,参数占位符的风格因驱动而异(例如sqlite3使用?,而pymysql和psycopg2使用%s)。
import sqlite3
# SQLite ships with the standard library, so no extra installation is needed.
conn = sqlite3.connect('example.db')
try:
    cursor = conn.cursor()

    # Create the table on first run; later runs leave it untouched.
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            email TEXT UNIQUE,
            age INTEGER,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    ''')

    # Parameterized statement: values are bound by the driver, never
    # interpolated into the SQL text, which prevents SQL injection.
    # OR IGNORE skips rows whose e-mail already exists, so re-running the
    # script does not fail on the UNIQUE(email) constraint.
    INSERT_USER_SQL = (
        "INSERT OR IGNORE INTO users (name, email, age) VALUES (?, ?, ?)"
    )

    # Single-row insert.
    cursor.execute(INSERT_USER_SQL, ("Alice", "alice@example.com", 30))

    # Batch insert: one statement executed once per parameter tuple.
    users = [
        ("Bob", "bob@example.com", 25),
        ("Charlie", "charlie@example.com", 35),
        ("Diana", "diana@example.com", 28),
    ]
    cursor.executemany(INSERT_USER_SQL, users)
    conn.commit()

    # Query with a bound parameter and print every matching row.
    cursor.execute("SELECT * FROM users WHERE age > ?", (26,))
    rows = cursor.fetchall()
    for row in rows:
        print(row)
finally:
    # Always release the connection, even if a statement above raised.
    conn.close()
二、上下文管理器与连接池
2.1 安全的数据库操作
from contextlib import contextmanager
@contextmanager
def get_db_connection(db_path):
    """Yield a SQLite connection wrapped in a commit-or-rollback transaction.

    The connection uses ``sqlite3.Row`` as its row factory so columns can be
    read by name. When the ``with`` body finishes normally the transaction is
    committed; when it raises an ``Exception`` the transaction is rolled back
    and the exception propagates. The connection is closed in every case.
    """
    connection = sqlite3.connect(db_path)
    connection.row_factory = sqlite3.Row  # rows support access by column name
    try:
        try:
            yield connection
            connection.commit()
        except Exception:
            connection.rollback()
            raise
    finally:
        connection.close()
@contextmanager
def get_cursor(conn):
    """Yield a cursor created from ``conn`` and close it when the block exits."""
    cur = conn.cursor()
    try:
        yield cur
    finally:
        # Runs on both normal exit and exceptions.
        cur.close()
# Usage: both context managers are entered in a single ``with`` statement;
# the cursor is closed first, then the connection is committed and closed.
with get_db_connection('example.db') as conn, get_cursor(conn) as cursor:
    cursor.execute("SELECT * FROM users")
    for row in cursor.fetchall():
        # Rows come back as sqlite3.Row, so they convert cleanly to dicts.
        print(dict(row))
2.2 简单连接池
import queue
import threading
class ConnectionPool:
    """A small bounded pool of SQLite connections shared between threads.

    Connections are created lazily, up to ``max_connections``. Idle
    connections wait in a queue; when the pool is exhausted, callers block
    until another thread returns one.
    """

    def __init__(self, db_path, max_connections=5):
        """Create an empty pool.

        Args:
            db_path: Path passed to ``sqlite3.connect``.
            max_connections: Upper bound on connections the pool will open.
        """
        self.db_path = db_path
        self.pool = queue.Queue(maxsize=max_connections)
        self.lock = threading.Lock()  # guards _size
        self._size = 0  # number of connections created so far
        self._max = max_connections

    def get_connection(self):
        """Return an idle connection, opening a new one if capacity allows.

        Raises:
            queue.Empty: If the pool is exhausted and no connection is
                returned within 10 seconds.
            sqlite3.Error: If opening a new connection fails.
        """
        try:
            return self.pool.get_nowait()
        except queue.Empty:
            pass

        # Reserve a slot under the lock, but open the connection outside it
        # so other threads are not blocked while the database is opened.
        with self.lock:
            may_create = self._size < self._max
            if may_create:
                self._size += 1

        if not may_create:
            # Pool is at capacity: wait for another thread to return one.
            return self.pool.get(timeout=10)

        try:
            # check_same_thread=False: pooled connections are handed from one
            # thread to another. The pool gives each connection to a single
            # borrower at a time, so it is never used concurrently.
            conn = sqlite3.connect(self.db_path, check_same_thread=False)
        except Exception:
            # Give the reserved slot back so a failed connect does not
            # permanently shrink the pool.
            with self.lock:
                self._size -= 1
            raise
        conn.row_factory = sqlite3.Row
        return conn

    def return_connection(self, conn):
        """Put ``conn`` back into the pool so other callers can reuse it."""
        self.pool.put(conn)

    @contextmanager
    def connection(self):
        """Borrow a connection for the duration of a ``with`` block.

        Commits when the block finishes normally, rolls back and re-raises
        when it raises an ``Exception``, and always returns the connection
        to the pool.
        """
        conn = self.get_connection()
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            self.return_connection(conn)
# Usage: borrow a pooled connection, count the users, and print the total.
pool = ConnectionPool('example.db', max_connections=10)
with pool.connection() as conn:
    cursor = conn.cursor()
    cursor.execute("SELECT COUNT(*) FROM users")
    (total,) = cursor.fetchone()
    print(total)
三、MySQL与PostgreSQL
3.1 MySQL(使用pymysql)
import pymysql
# Connection settings for the MySQL server. DictCursor is selected as the
# cursor class so fetched rows are keyed by column name.
config = dict(
    host='localhost',
    port=3306,
    user='root',
    password='password',
    database='mydb',
    charset='utf8mb4',
    cursorclass=pymysql.cursors.DictCursor,
)

conn = pymysql.connect(**config)
try:
    with conn.cursor() as cursor:
        # pymysql uses %s placeholders; the values are bound by the driver.
        query = "SELECT * FROM users WHERE status = %s LIMIT %s"
        cursor.execute(query, ('active', 10))
        results = cursor.fetchall()
finally:
    # Close the connection whether or not the query succeeded.
    conn.close()
3.2 PostgreSQL(使用psycopg2)
import psycopg2
from psycopg2.extras import RealDictCursor
conn = psycopg2.connect(
    host='localhost',
    port=5432,
    dbname='mydb',
    user='postgres',
    password='password'
)
try:
    # RealDictCursor is passed as the cursor factory so rows can be read by
    # column name.
    with conn.cursor(cursor_factory=RealDictCursor) as cursor:
        # psycopg2 uses %s placeholders; the values are bound by the driver.
        cursor.execute(
            "SELECT * FROM users WHERE age BETWEEN %s AND %s",
            (20, 30)
        )
        users = cursor.fetchall()
    # End the transaction that the SELECT implicitly opened.
    conn.commit()
finally:
    # Close the connection even when the query or commit raises, matching
    # the try/finally used in the MySQL example.
    conn.close()
四、SQLAlchemy Core
SQLAlchemy提供两层API:Core(SQL表达式)和ORM(对象关系映射)。
from sqlalchemy import create_engine, MetaData, Table, Column
from sqlalchemy import Integer, String, Float, DateTime, ForeignKey
from sqlalchemy import select, insert, update, delete, func, and_, or_
# Engine for the SQLite file, configured with explicit pool limits.
engine = create_engine(
    'sqlite:///example.db',
    echo=False,  # set to True to log every emitted SQL statement
    pool_size=5,
    max_overflow=10,
)

# Registry that collects the Table definitions below.
metadata = MetaData()

users = Table(
    'users',
    metadata,
    Column('id', Integer, primary_key=True),
    Column('name', String(50), nullable=False),
    Column('email', String(100), unique=True),
    Column('age', Integer),
)

orders = Table(
    'orders',
    metadata,
    Column('id', Integer, primary_key=True),
    Column('user_id', Integer, ForeignKey('users.id')),  # references users.id
    Column('amount', Float),
    Column('status', String(20)),
)

# Emit CREATE TABLE for every table registered on ``metadata`` that does not
# exist yet.
metadata.create_all(engine)
# Core API walkthrough: insert, filtered select, aggregate, and join.
with engine.connect() as conn:
    # Insert one row.
    conn.execute(
        users.insert().values(name='Eve', email='eve@example.com', age=22)
    )

    # Filtered, ordered select.
    stmt = users.select().where(users.c.age > 25).order_by(users.c.name)
    result = conn.execute(stmt)
    for row in result:
        print(row.name, row.age)

    # Aggregate: row count and average age in a single query.
    stmt = select(func.count(), func.avg(users.c.age)).select_from(users)
    count, avg_age = conn.execute(stmt).fetchone()

    # Join users to their completed orders.
    stmt = (
        select(users.c.name, orders.c.amount)
        .join(orders, users.c.id == orders.c.user_id)
        .where(orders.c.status == 'completed')
    )
    result = conn.execute(stmt)

    # Connections from engine.connect() do not autocommit; persist the insert.
    conn.commit()
五、SQLAlchemy ORM
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.orm import Session, sessionmaker
from datetime import datetime
from typing import Optional
class Base(DeclarativeBase):
    """Declarative base class that the ORM models in this file inherit from."""
class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = 'users'

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(50))
    email: Mapped[str] = mapped_column(String(100), unique=True)
    age: Mapped[Optional[int]] = mapped_column(default=None)
    # Use ``default`` with a callable: it is invoked at INSERT time.
    # ``default_factory`` is only valid for MappedAsDataclass models and is
    # rejected on a plain DeclarativeBase class such as this one.
    created_at: Mapped[datetime] = mapped_column(default=datetime.now)

    # One-to-many relationship; paired with ``Order.user``.
    orders: Mapped[list["Order"]] = relationship(back_populates="user")

    def __repr__(self):
        return f"User(id={self.id}, name={self.name})"
class Order(Base):
    """ORM model mapped to the ``orders`` table."""

    __tablename__ = 'orders'

    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[int] = mapped_column(ForeignKey('users.id'))
    amount: Mapped[float] = mapped_column()
    status: Mapped[str] = mapped_column(String(20), default='pending')

    # Many-to-one side of the relationship; paired with ``User.orders``.
    user: Mapped["User"] = relationship(back_populates="orders")

    def __repr__(self):
        return f"Order(id={self.id}, amount={self.amount})"
# Statement constructors used below, imported here so the snippet is
# self-contained.
from sqlalchemy import delete, select, update

# Create the tables for all mapped models that do not exist yet.
Base.metadata.create_all(engine)

# Session factory bound to the engine.
SessionLocal = sessionmaker(bind=engine)

with SessionLocal() as session:
    # Create a user and an order that belongs to it.
    user = User(name="Frank", email="frank@example.com", age=32)
    session.add(user)
    session.flush()  # emit the INSERT now so user.id is populated
    order = Order(user_id=user.id, amount=99.99)
    session.add(order)

    # Legacy Query API (still supported, but no longer the recommended style).
    users = session.query(User).filter(User.age > 25).all()

    # 2.0-style select (recommended).
    stmt = select(User).where(User.age > 25).order_by(User.name)
    users = session.scalars(stmt).all()

    # Join through the relationship: users having at least one order > 50.
    stmt = (
        select(User)
        .join(User.orders)
        .where(Order.amount > 50)
        .distinct()
    )
    users_with_big_orders = session.scalars(stmt).all()

    # Update one object: attribute changes are flushed on commit.
    user = session.get(User, 1)
    if user:
        user.name = "Updated Name"

    # Bulk update. This targets Order.status because the User model has no
    # ``status`` column; setting one on User fails when the statement is
    # compiled.
    session.execute(
        update(Order)
        .where(Order.status == 'pending', Order.amount <= 0)
        .values(status='cancelled')
    )

    # Bulk delete of the cancelled orders.
    session.execute(delete(Order).where(Order.status == 'cancelled'))

    session.commit()
六、数据库迁移(Alembic)
# Install: pip install alembic
# Initialize: alembic init migrations
# alembic.ini configuration
# sqlalchemy.url = sqlite:///example.db
# In migrations/env.py, set target_metadata
# target_metadata = Base.metadata
# Create a migration script
# alembic revision --autogenerate -m "add phone column"
# Example of a generated migration file (kept as a string literal, not executed)
"""
def upgrade():
op.add_column('users', sa.Column('phone', sa.String(20)))
def downgrade():
op.drop_column('users', 'phone')
"""
# Run migrations
# alembic upgrade head
# alembic downgrade -1
七、Repository模式
from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Optional
# Entity type handled by a concrete repository.
T = TypeVar('T')


class Repository(ABC, Generic[T]):
    """Abstract CRUD interface for a repository of entities of type ``T``.

    Concrete subclasses must implement every method declared here.
    """

    @abstractmethod
    def get_by_id(self, id: int) -> Optional[T]:
        """Look up one entity by ``id``; may return ``None``."""

    @abstractmethod
    def get_all(self) -> list[T]:
        """Return the stored entities as a list."""

    @abstractmethod
    def add(self, entity: T) -> T:
        """Add ``entity`` to the repository and return an entity."""

    @abstractmethod
    def update(self, entity: T) -> T:
        """Update ``entity`` in the repository and return an entity."""

    @abstractmethod
    def delete(self, id: int) -> None:
        """Delete the entity identified by ``id``."""
class UserRepository(Repository[User]):
    """Repository for ``User`` entities backed by a SQLAlchemy ``Session``.

    The repository never commits; the caller that owns the session controls
    the transaction.
    """

    def __init__(self, session: Session):
        self.session = session

    def get_by_id(self, id: int) -> Optional[User]:
        """Return the user with primary key ``id``, or ``None`` if absent."""
        return self.session.get(User, id)

    def get_all(self) -> list[User]:
        """Return every user as a list."""
        # .all() returns a Sequence; convert to honour the declared type.
        return list(self.session.scalars(select(User)).all())

    def get_by_email(self, email: str) -> Optional[User]:
        """Return the first user whose e-mail equals ``email``, or ``None``."""
        stmt = select(User).where(User.email == email)
        return self.session.scalars(stmt).first()

    def add(self, user: User) -> User:
        """Add ``user`` to the session and flush so its ``id`` is populated."""
        self.session.add(user)
        self.session.flush()
        return user

    def update(self, user: User) -> User:
        """Merge ``user`` into the session and return the merged instance.

        ``Session.merge`` copies state onto the session's own instance and
        returns that instance. For a detached ``user`` it is a different
        object, so the merged instance is the one callers should keep using.
        """
        return self.session.merge(user)

    def delete(self, id: int) -> None:
        """Delete the user with primary key ``id``; do nothing if absent."""
        user = self.get_by_id(id)
        if user:
            self.session.delete(user)

    def search(self, name_query: str, min_age: int = 0) -> list[User]:
        """Return users whose name contains ``name_query``, ordered by name.

        The match is case-insensitive and literal: ``%`` and ``_`` in
        ``name_query`` are escaped rather than treated as LIKE wildcards.

        Args:
            name_query: Substring to look for in the user name.
            min_age: Minimum age, inclusive. Values <= 0 mean "no age
                filter", so users whose age is NULL are still returned.
        """
        escaped = (
            name_query
            .replace('\\', '\\\\')
            .replace('%', '\\%')
            .replace('_', '\\_')
        )
        conditions = [User.name.ilike(f'%{escaped}%', escape='\\')]
        if min_age > 0:
            # Applied only when requested: comparing a NULL age yields NULL,
            # which would otherwise filter those users out.
            conditions.append(User.age >= min_age)
        stmt = select(User).where(*conditions).order_by(User.name)
        return list(self.session.scalars(stmt).all())
八、异步数据库操作
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.ext.asyncio import async_sessionmaker
# Async engine; the URL selects the aiosqlite driver for SQLite.
async_engine = create_async_engine('sqlite+aiosqlite:///example.db')

# Factory producing AsyncSession objects bound to the async engine.
AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
)
async def get_users():
    """Return all users older than 25, loaded through a short-lived async session."""
    older_than_25 = select(User).where(User.age > 25)
    async with AsyncSessionLocal() as session:
        return (await session.scalars(older_than_25)).all()
async def create_user(name: str, email: str):
    """Insert a new user, commit, and return the refreshed instance."""
    new_user = User(name=name, email=email)
    async with AsyncSessionLocal() as session:
        session.add(new_user)
        await session.commit()
        # Reload the row's state from the database after the commit.
        await session.refresh(new_user)
    return new_user
九、查询优化
# 1. Eager-load related rows to avoid the N+1 query problem.
from sqlalchemy.orm import joinedload, selectinload

# Lazy loading (the default): the orders are queried only when user.orders is
# first accessed, which costs one extra query per user (N+1).
users = session.scalars(select(User)).all()

# joinedload: users and their orders come back from a single JOIN query.
# unique() collapses the duplicate User rows that the JOIN produces.
stmt = select(User).options(joinedload(User.orders))
users = (
    session.scalars(stmt)
    .unique()
    .all()
)

# selectinload: two queries in total; the second loads the orders for all
# returned users with a SELECT ... IN.
stmt = select(User).options(selectinload(User.orders))
users = session.scalars(stmt).all()

# 2. Select only the columns that are needed.
stmt = select(User.name, User.email).where(User.age > 25)

# 3. Pagination: skip 20 rows and return the next 10.
stmt = select(User).offset(20).limit(10)
# 4. Declare indexes for columns that are filtered or sorted on.
from sqlalchemy import Index  # Index is not imported anywhere else in this file


class User(Base):
    # NOTE(review): this snippet shows only the index declarations; the
    # column definitions are omitted. In a real project, put __table_args__
    # on the existing User model rather than declaring a second class for
    # the same table.
    __tablename__ = 'users'
    __table_args__ = (
        Index('idx_user_email', 'email'),
        # Composite index covering age plus name.
        Index('idx_user_age_name', 'age', 'name'),
    )
十、总结
数据库编程要点:
1. 始终使用参数化查询防止SQL注入
2. 使用连接池管理数据库连接
3. 合理使用事务保证数据一致性
4. ORM适合业务逻辑复杂的场景,Core适合性能敏感的批量操作
5. 注意N+1查询问题,合理使用预加载
6. 使用Alembic管理数据库迁移
