086、FastAPI 生产实践:数据库会话管理、JWT 认证、Rate Limiting
从一次线上事故说起
上周三凌晨两点,我被报警电话吵醒。生产环境的FastAPI服务突然大面积返回500错误,用户登录全部失败。我睡眼惺忪地打开日志,发现数据库连接池耗尽,所有请求都在等待数据库会话释放。更诡异的是,JWT token验证偶尔成功偶尔失败,同一个token在不同请求中表现不一致。
排查到最后,问题出在三个地方:数据库会话没有正确关闭、JWT密钥在多个worker间不一致、某个爬虫在疯狂刷接口。这三个问题恰好对应了今天要聊的三个主题——数据库会话管理、JWT认证、Rate Limiting。如果你正在用FastAPI做生产项目,这篇文章能帮你少踩我踩过的坑。
数据库会话管理:别让连接池变成定时炸弹
会话工厂的正确姿势
很多新手喜欢在每个请求里手动创建数据库连接,这是灾难的开始。正确的做法是用sessionmaker创建会话工厂,然后通过依赖注入管理生命周期。
# 别这样写:每次请求都新建引擎
# engine = create_engine(DATABASE_URL)
# Session = sessionmaker(bind=engine)
# 正确做法:全局一个引擎,按需创建会话
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from fastapi import Depends, FastAPI
engine = create_engine(
DATABASE_URL,
pool_size=20, # 这里踩过坑,太小了并发一高就死
max_overflow=10,
pool_pre_ping=True, # 这个参数救过我的命,自动检测连接是否有效
pool_recycle=3600 # 连接超过1小时自动回收,防止MySQL主动断开
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close() # 这里一定要close,不然连接池会泄漏
注意那个finally块,很多教程只写yield db,忘了关闭。生产环境跑几天就会发现连接数暴涨,直到数据库拒绝连接。
异步会话的坑
如果你用异步数据库驱动(比如asyncpg),事情会复杂一些。FastAPI的异步依赖注入和同步的不太一样:
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
async_engine = create_async_engine(
ASYNC_DATABASE_URL,
pool_size=20,
max_overflow=10
)
AsyncSessionLocal = async_sessionmaker(async_engine, expire_on_commit=False)
async def get_async_db():
async with AsyncSessionLocal() as session:
yield session
# 这里不用手动close,async with会自动处理
# 但注意:如果yield之后有异常,session可能不会正确回滚
这里有个隐藏问题:expire_on_commit=False。默认情况下,commit之后所有对象会过期,下次访问会重新查询。如果你在commit之后还想用对象,记得设置这个参数为False,否则会触发额外的数据库查询,性能下降不说,还可能因为session已关闭而报错。
事务管理:别让脏数据污染你的数据库
我见过最离谱的代码是在视图函数里手动调用db.commit()和db.rollback()。正确的做法是用依赖注入统一管理事务:
from contextlib import contextmanager
@contextmanager
def transaction(db: Session):
try:
yield
db.commit()
except Exception:
db.rollback()
raise
# 在路由中使用
@app.post("/users/")
def create_user(user: UserCreate, db: Session = Depends(get_db)):
with transaction(db):
db_user = User(**user.dict())
db.add(db_user)
# 这里如果抛出异常,事务会自动回滚
return db_user
这样写的好处是:所有数据库操作都在一个事务上下文中,要么全部成功,要么全部失败。别在多个函数里分散commit,否则一个函数成功另一个失败,数据就不一致了。
JWT认证:从入门到生产级
基础实现:别把密钥写死在代码里
JWT认证的核心是签名和验证。很多教程把SECRET_KEY直接写在代码里,这是生产环境的大忌。
from datetime import datetime, timedelta
from jose import JWTError, jwt
from passlib.context import CryptContext
# 别这样写:SECRET_KEY = "my-secret-key"
# 正确做法:从环境变量读取
import os
SECRET_KEY = os.getenv("JWT_SECRET_KEY", "fallback-dev-key") # 生产环境必须设置
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def create_access_token(data: dict, expires_delta: timedelta = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
这里有个细节:exp字段是Unix时间戳,别传datetime对象,jose库会自动处理。但如果你用其他库,可能需要手动转换。
依赖注入:让每个路由都能验证用户
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
async def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无法验证凭据",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
if user_id is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user = db.query(User).filter(User.id == user_id).first()
if user is None:
raise credentials_exception
return user
注意:这里每次请求都会查数据库。如果用户量很大,可以考虑用Redis缓存用户信息,但要注意缓存失效问题。我踩过的坑是:用户被禁用后,缓存里的token还能用,直到过期。解决方案是在token里加入用户状态版本号。
生产级优化:Token黑名单和刷新机制
JWT一旦签发就无法撤销,这是它最大的痛点。生产环境需要实现token黑名单:
# 用Redis存储黑名单
import redis
r = redis.Redis(host='localhost', port=6379, db=0)
def revoke_token(token: str, expire_in: int = 3600):
# 将token加入黑名单,过期时间设为token的剩余有效期
r.setex(f"blacklist:{token}", expire_in, "revoked")
def is_token_revoked(token: str) -> bool:
return r.exists(f"blacklist:{token}")
# 在验证时检查
async def get_current_user(token: str = Depends(oauth2_scheme)):
if is_token_revoked(token):
raise HTTPException(status_code=401, detail="Token已失效")
# 继续验证...
刷新token的机制也很重要。别让用户频繁登录,但也不能让token永不过期。我一般用双token方案:access token有效期15分钟,refresh token有效期7天。
Rate Limiting:别让爬虫打垮你的服务
基础实现:内存中的简单限流
FastAPI没有内置的限流中间件,但实现起来不难。最简单的方案是用内存字典:
from collections import defaultdict
import time
class MemoryRateLimiter:
def __init__(self):
self.requests = defaultdict(list)
def is_allowed(self, key: str, max_requests: int, window_seconds: int) -> bool:
now = time.time()
window_start = now - window_seconds
# 清理过期记录
self.requests[key] = [t for t in self.requests[key] if t > window_start]
if len(self.requests[key]) >= max_requests:
return False
self.requests[key].append(now)
return True
limiter = MemoryRateLimiter()
@app.get("/api/data")
def get_data(user: User = Depends(get_current_user)):
if not limiter.is_allowed(f"user:{user.id}", max_requests=10, window_seconds=60):
raise HTTPException(status_code=429, detail="请求过于频繁")
return {"data": "some data"}
这个方案的问题很明显:重启服务后限流数据丢失,多worker环境下不共享。生产环境必须用Redis。
生产级方案:基于Redis的滑动窗口限流
import redis
import time
class RedisRateLimiter:
def __init__(self, redis_client: redis.Redis):
self.redis = redis_client
def is_allowed(self, key: str, max_requests: int, window_seconds: int) -> bool:
now = int(time.time())
window_start = now - window_seconds
# 使用有序集合存储时间戳
pipeline = self.redis.pipeline()
pipeline.zadd(key, {now: now})
pipeline.zremrangebyscore(key, 0, window_start)
pipeline.zcard(key)
pipeline.expire(key, window_seconds + 1)
results = pipeline.execute()
current_count = results[2] # zcard的结果
return current_count <= max_requests
这里用Redis的有序集合实现滑动窗口,比固定窗口更精确。固定窗口的问题在于:如果用户在窗口边界集中请求,可能瞬间打垮服务。滑动窗口能平滑流量。
按用户和接口分别限流
不同接口的限流策略应该不同。登录接口要严格限制(防止暴力破解),普通查询接口可以宽松一些:
def rate_limit(key_prefix: str, max_requests: int, window_seconds: int):
def decorator(func):
async def wrapper(*args, **kwargs):
# 从依赖注入中获取当前用户
request = kwargs.get('request')
user = kwargs.get('current_user')
if user:
key = f"{key_prefix}:user:{user.id}"
else:
# 未登录用户用IP限流
client_ip = request.client.host
key = f"{key_prefix}:ip:{client_ip}"
if not redis_limiter.is_allowed(key, max_requests, window_seconds):
raise HTTPException(status_code=429, detail="请求过于频繁")
return await func(*args, **kwargs)
return wrapper
return decorator
@app.post("/login")
@rate_limit("login", max_requests=5, window_seconds=60)
async def login(request: Request):
# 登录逻辑
pass
注意:登录接口的限流要基于IP,因为用户还没登录。但IP可能被代理隐藏,所以最好结合User-Agent等其他信息。
个人经验性建议
-
数据库会话管理:永远不要在视图函数里手动管理会话。用依赖注入统一管理,配合
contextmanager处理事务。如果遇到连接池耗尽,先检查是不是有会话没关闭,再考虑调整连接池大小。 -
JWT认证:密钥一定要从环境变量读取,别写死在代码里。生产环境用RS256代替HS256,这样密钥对可以分开管理。token里只放必要信息(用户ID、角色),别把密码等敏感信息放进去。
-
Rate Limiting:别用内存限流,除非你的服务只有一个worker。Redis是标准方案,但要注意Redis本身也可能成为瓶颈。如果流量特别大,考虑用Nginx层限流作为第一道防线。
-
调试技巧:遇到认证问题,先检查token是否过期,再检查密钥是否一致。多worker环境下,确保所有worker使用相同的密钥和Redis实例。数据库会话问题,先看连接池状态,再看是否有未关闭的会话。
-
监控告警:这三个组件都要加监控。数据库连接池使用率超过80%要报警,JWT验证失败率突然升高要报警,某个接口的请求量异常增长要报警。没有监控的生产环境,就像闭着眼睛开车。
最后说一句:这些实践不是银弹。每个项目都有自己的特殊性,但理解了原理,遇到问题就能快速定位。我踩过的坑,希望你能绕过去。
596

被折叠的 条评论
为什么被折叠?



