一,代码:
异常处理
# 2. 捕获 FastAPI 参数校验异常 (422 Unprocessable Entity)@app.exception_handler(RequestValidationError)async def validation_exception_handler(request: Request, exc: RequestValidationError):# 获取当前请求路由的匹配结果route = request.scope.get("route")custom_errors = []str_errors = ''# 获取异常自带的错误详情errors = exc.errors()for error in errors:# loc 包含了出错字段的路径,例如 ("body", "username") 或 ("query", "age")print("error:", error )loc = error.get("loc", [])print("loc:", loc)field_name = loc[-1] if loc else Nonelocation_type = loc[0] if len(loc) >= 1 else Noneprint("field_name:", field_name)description = "未定义描述"print("location_type:", location_type)# 如果是请求体(Body)校验失败,尝试从路由函数签名中找 Pydantic Model 并提取 descriptionis_base = is_basemodel(route, field_name) # 判断字段是否是基于BaseModel的子类print("is_base:", is_base)if route and hasattr(route, "dependant") and field_name:print("开始查找数据:", field_name)# 遍历路由中所有的 Query 参数定义all_flat_params = (route.dependant.query_params +route.dependant.path_params +route.dependant.header_params +route.dependant.cookie_params# route.dependant.body_params)for param in all_flat_params:print("dir_param:", dir(param))# 匹配当前报错的字段名(比如 "q")print("param.name:", param.name)if param.name == field_name:# 关键点:FastAPI 把 Query(description="...") 存放在了 param.field_info 中if hasattr(param, "field_info") and getattr(param.field_info, "description", None):description = param.field_info.descriptionprint("description1:",description)error_msg = f"{description}"if str_errors == '':str_errors = error_msgelse:str_errors += ";" + error_msgbreakbody_flat_params = (route.dependant.body_params)for param in body_flat_params:print("dir_param:", dir(param))# 匹配当前报错的字段名(比如 "q")print("param.name:", param.name)# if param.name == field_name:# 关键点:FastAPI 把 Query(description="...") 存放在了 param.field_info 中if hasattr(param, "field_info") and getattr(param.field_info, "description", None):description = param.field_info.descriptionprint("description1:",description)error_msg = f"{description}"if str_errors == '':str_errors = error_msgelse:str_errors += ";" + error_msgbreakif (location_type == "body" or location_type == "query") and route and field_name:endpoint = route.endpoint# 遍历路由函数的参数,寻找继承自 BaseModel 的参数import inspectsig = inspect.signature(endpoint)print(sig.parameters.items())field_info = ''for param_name, param in sig.parameters.items():model_cls = param.annotationprint("model_cls:",model_cls)print(dir(model_cls))if hasattr(model_cls, 'model_fields'):field_info = model_cls.model_fields.get(field_name)print("field_info:", field_info)if field_info and field_info.description:error_msg = f"{field_info.description}"print("error_msg:", error_msg)if str_errors == '':str_errors = error_msgelse:str_errors += ";" + error_msgbreakelse:print("没有model_fields属性")if str_errors == '':one_str_error = field_name + ":" + error['msg']str_errors = one_str_errorcontent = BaseResponse(code=422, msg=str_errors, data=custom_errors).model_dump()return JSONResponse(status_code=422,content=content)
调用:使用Annotated参数
@router.post("/onepost")
def get_one_post(# 校验 Query (GET) 参数:长度必须在 3 到 10 之间# q: Annotated[str, Query(min_length=3, max_length=10)],# 校验 Body (POST) 参数:必须大于 0# product_id: Annotated[float, Body(gt=0.0)],product_id: Annotated[int, Body(gt=0,description="商品id,需要是大于0的整数")],
):print("测试参数检测:",)return {"id": product_id, "name": f"未知"}@router.get("/oneprod")
def get_one_product(# 校验 Query (GET) 参数:长度必须在 3 到 10 之间# q: Annotated[str, Query(min_length=3, max_length=10)],# 校验 Body (POST) 参数:必须大于 0# product_id: Annotated[float, Body(gt=0.0)],product_id: Annotated[int, Query(gt=0,description="商品id,需要是大于0的整数")],
):print("测试参数检测:",)return {"id": product_id, "name": f"未知"}
基于basemodel类:get/post
# 1. 定义请求体数据模型
class ProductPage(BaseModel):category: str = Field(..., min_length=3, max_length=20,description="分类名长度是3-20个字符")page: int = Field(..., gt=0, description="页数需要大于0")per_page: int = Field(..., gt=0, description="每页数量需要大于0")@router.get("/all")
async def get_all_products(prod_data: Annotated[ProductPage, Query()],db: AsyncSession = Depends(get_db)):page = prod_data.pageper_page = prod_data.per_page# 异步写法不能直接用 db.query(),必须用 selectskip = (page-1) * per_pagelimit = per_pagestmt = select(Product).offset(skip).limit(limit)result = await db.execute(stmt) # 必须 await 数据库 I/Oprods = result.scalars().all()# 异步查询总条数 (注意 await)count_stmt = select(func.count()).select_from(Product)total_items = await db.scalar(count_stmt)# 计算分页元数据total_pages = math.ceil(total_items / per_page) if total_items > 0 else 1has_prev = page > 1has_next = page < total_pagesres = {"list": prods,"meta": {"total_items": total_items,"total_pages": total_pages,"current_page": page,"per_page": per_page,"has_prev": has_prev,"has_next": has_next,"prev_page": page - 1 if has_prev else None,"next_page": page + 1 if has_next else None}}return success(data=res)# 1. 定义请求体数据模型
class ProductCreate(BaseModel):name: str = Field(..., min_length=3, max_length=20,description="商品名称,长度是3-20个字符")price: int = Field(..., gt=0, description="价格需要大于0")stock: int = Field(..., ge=0, description="库存不能小于0")# 如果请求方式是get,使用: params: Annotated[FilterParams, Query()]@router.post("/add")
async def create_user(prod_data: Annotated[ProductCreate, Form()],db: AsyncSession = Depends(get_db)):# 检查用户名是否已存在result = await db.execute(select(Product).where(Product.name == prod_data.name))existing_prod = result.scalar_one_or_none()if existing_prod:raise HTTPException(status_code=400, detail="产品名已被使用")# 创建新产品new_prod = Product(name=prod_data.name,stock=prod_data.stock,price = prod_data.price,)# 添加到数据库db.add(new_prod)await db.commit() # 提交事务await db.refresh(new_prod) # 刷新,获取数据库生成的字段(如id)return success(data=new_prod)
基于basemodel类:json
# 1. 定义请求体数据模型
class UserLogin(BaseModel):username: str = Field(..., min_length=3, max_length=20,description="用户名,长度是3-20个字符")age: int = Field(..., ge=18, le=100, description="年龄必须在18到100岁之间")password: str = Field(..., min_length=6, description="密码长度最少6个字符")@router.post("/login")
async def account_login(user: UserLogin):user_info = {"age": user.age, "name": f"用户{user.username}"}return success(data=user_info)
二,测试效果:
get

get

post

post:

post:

