一、为什么需要推理缓存
1.1 重复请求的浪费
现实中的推理服务有一个被忽视的事实: 大量请求是重复的。 以智慧安防为例: - 同一个摄像头,24 小时监控同一个入口 - 每秒 25 帧,大部分帧的画面几乎一样 - 对同一张人脸的识别结果,每秒被计算 25 次 - 实际上,每秒只需要计算 1-2 次,其余 23 次是浪费 以电商搜索为例: - 用户搜索"手机壳",系统返回推荐结果 - 10 秒后,另一个用户也搜索"手机壳" - 两次搜索的输入完全相同,但第二次重新计算了整个模型 量化分析: 假设: 推理延迟: 10ms QPS: 1000 重复请求比例: 30% 不用缓存: 每秒计算量: 1000 次 总计算时间: 1000 × 10ms = 10 秒 (需要多个 NPU 并行) 用缓存: 每秒实际计算: 700 次 (300 次命中缓存) 缓存命中延迟: < 0.1ms (内存查找) 节省: 30% 的 NPU 算力
1.2 缓存的代价
缓存不是免费的,需要权衡: 收益: ✓ 减少 NPU 计算量 ✓ 降低延迟(缓存命中 < 0.1ms vs 推理 10ms) ✓ 提高吞吐量 ✓ 节省算力成本 代价: ✗ 内存占用(缓存需要存储结果) ✗ 一致性风险(模型更新后缓存可能过期) ✗ 实现复杂度(缓存键生成、失效策略、并发控制) ✗ 首次请求仍然慢(缓存未命中) 什么时候适合用缓存: ✓ 输入空间有限(如固定类别分类) ✓ 重复请求比例高 ✓ 对延迟敏感 ✓ 模型更新不频繁 什么时候不适合: ✗ 输入几乎不重复(如随机生成的内容) ✗ 内存非常紧张 ✗ 模型频繁更新(缓存频繁失效)
二、缓存策略
2.1 LRU 缓存
importhashlibimporttimeimportthreadingfromcollectionsimportOrderedDictclassLRUCache:"""LRU (Least Recently Used) 缓存 核心思想: 当缓存满了,淘汰最久没有被访问的条目。 为什么选 LRU? - 实现简单,性能好 - 符合"时间局部性"原理: 最近被访问的数据,很可能很快再次被访问 - 适用于大多数推理场景 数据结构: 使用 OrderedDict 维护访问顺序。 每次访问时,将条目移到末尾(最新)。 淘汰时,删除头部(最旧)的条目。 时间复杂度: - 查找: O(1) - 插入: O(1) - 淘汰: O(1) 为什么不用 list? list 的删除是 O(n),当缓存很大时会变慢。 OrderedDict 的删除是 O(1)。 """def__init__(self,max_size=1000,ttl_seconds=300):""" 参数: max_size: 缓存最大条目数 太小: 命中率低 太大: 内存占用高 经验值: 预估重复请求数的 2-5 倍 ttl_seconds: 缓存条目过期时间(秒) 为什么要过期? 1. 防止内存无限增长 2. 模型更新后旧结果自动失效 3. 控制缓存的新鲜度 TTL 的选择: - 模型不更新: 可以设很长(如 1 小时) - 模型频繁更新: 设很短(如 10 秒) - 一般场景: 5-10 分钟 """self.max_size=max_size self.ttl_seconds=ttl_seconds self.cache=OrderedDict()# {key: (value, timestamp)}self.lock=threading.RLock()# 线程安全锁self.hits=0self.misses=0defget(self,key):"""获取缓存 查找流程: 1. 加锁(多线程安全) 2. 检查 key 是否存在 3. 检查是否过期 4. 如果命中,将条目移到末尾(标记为"最近使用") 5. 返回结果或 None """withself.lock:ifkeyinself.cache:value,timestamp=self.cache[key]# 检查是否过期iftime.time()-timestamp<self.ttl_seconds:# 命中: 移到末尾(标记为最近使用)self.cache.move_to_end(key)self.hits+=1returnvalueelse:# 过期: 删除delself.cache[key]self.misses+=1returnNonedefput(self,key,value):"""放入缓存 插入流程: 1. 如果 key 已存在,更新值并移到末尾 2. 如果缓存已满,删除头部(最久未使用) 3. 插入新条目到末尾 """withself.lock:ifkeyinself.cache:# 更新已有条目self.cache.move_to_end(key)self.cache[key]=(value,time.time())else:# 检查是否需要淘汰iflen(self.cache)>=self.max_size:# 删除最久未使用的(头部)self.cache.popitem(last=False)# 插入新条目self.cache[key]=(value,time.time())defclear(self):"""清空缓存"""withself.lock:self.cache.clear()self.hits=0self.misses=0defstats(self):"""获取缓存统计"""total=self.hits+self.misses hit_rate=self.hits/totaliftotal>0else0return{'size':len(self.cache),'max_size':self.max_size,'hits':self.hits,'misses':self.misses,'hit_rate':hit_rate}# 使用示例cache=LRUCache(max_size=1000,ttl_seconds=300)# 模拟推理请求definfer_with_cache(model,input_tensor):"""带缓存的推理"""# 生成缓存键(基于输入内容的哈希)cache_key=hashlib.md5(input_tensor.numpy().tobytes()).hexdigest()# 尝试从缓存获取cached_result=cache.get(cache_key)ifcached_resultisnotNone:returncached_result,"cache_hit"# 缓存未命中,执行推理withtorch.no_grad():result=model(input_tensor.npu()).cpu()# 存入缓存cache.put(cache_key,result)returnresult,"cache_miss"# 模拟 1000 个请求(30% 重复)importrandom inputs=[torch.randn(1,3,224,224)for_inrange(700)]# 加入 300 个重复请求for_inrange(300):inputs.append(random.choice(inputs))random.shuffle(inputs)forinpininputs:result,status=infer_with_cache(model,inp)print(f"缓存统计:{cache.stats()}")# 期望命中率: ~30%
2.2 LFU 缓存
classLFUCache:"""LFU (Least Frequently Used) 缓存 核心思想: 当缓存满了,淘汰访问频率最低的条目。 与 LRU 的区别: - LRU: 淘汰"最久没访问的" - LFU: 淘汰"访问次数最少的" 适用场景: - 某些输入被反复请求(如热门商品的识别) - 希望保留高频条目,淘汰低频条目 实现: 使用两个数据结构: 1. freq_map: {频率: OrderedDict(该频率的所有key)} 2. key_map: {key: (value, 频率)} """def__init__(self,max_size=1000,ttl_seconds=300):self.max_size=max_size self.ttl_seconds=ttl_seconds self.key_map={}# {key: (value, freq, timestamp)}self.freq_map={}# {freq: OrderedDict({key: None})}self.min_freq=0self.lock=threading.RLock()self.hits=0self.misses=0defget(self,key):"""获取缓存并增加频率"""withself.lock:ifkeyinself.key_map:value,freq,timestamp=self.key_map[key]iftime.time()-timestamp<self.ttl_seconds:# 更新频率self._increase_freq(key,freq)self.hits+=1returnvalueelse:# 过期self._remove(key)self.misses+=1returnNonedefput(self,key,value):"""放入缓存"""withself.lock:ifkeyinself.key_map:_,freq,_=self.key_map[key]self.key_map[key]=(value,freq+1,time.time())self._increase_freq(key,freq)else:iflen(self.key_map)>=self.max_size:# 淘汰频率最低的self._evict()self.key_map[key]=(value,1,time.time())self.min_freq=1if1notinself.freq_map:self.freq_map[1]=OrderedDict()self.freq_map[1][key]=Nonedef_increase_freq(self,key,old_freq):"""增加 key 的频率"""# 从旧频率组移除delself.freq_map[old_freq][key]ifnotself.freq_map[old_freq]:delself.freq_map[old_freq]ifself.min_freq==old_freq:self.min_freq=old_freq+1# 加入新频率组new_freq=old_freq+1ifnew_freqnotinself.freq_map:self.freq_map[new_freq]=OrderedDict()self.freq_map[new_freq][key]=None# 更新 key_mapvalue,_,timestamp=self.key_map[key]self.key_map[key]=(value,new_freq,timestamp)def_evict(self):"""淘汰频率最低的条目"""ifself.min_freqinself.freq_mapandself.freq_map[self.min_freq]:# 删除该频率组中最旧的条目key,_=self.freq_map[self.min_freq].popitem(last=False)ifnotself.freq_map[self.min_freq]:delself.freq_map[self.min_freq]delself.key_map[key]def_remove(self,key):"""删除指定 key"""ifkeyinself.key_map:_,freq,_=self.key_map[key]delself.key_map[key]iffreqinself.freq_map:delself.freq_map[freq][key]ifnotself.freq_map[freq]:delself.freq_map[freq]defstats(self):"""获取统计"""total=self.hits+self.missesreturn{'size':len(self.key_map),'max_size':self.max_size,'hits':self.hits,'misses':self.misses,'hit_rate':self.hits/totaliftotal>0else0}
三、缓存键生成
3.1 内容哈希
classCacheKeyGenerator:"""缓存键生成器 缓存键的质量直接决定缓存的命中率。 好的缓存键: - 相同输入 → 相同键(确定性) - 不同输入 → 不同键(唯一性) - 生成速度快(不能成为瓶颈) 坏的缓存键: - 用时间戳作为键 → 永远不会命中 - 用随机数作为键 → 永远不会命中 - 用文件路径作为键 → 文件内容变了但键没变 """@staticmethoddeftensor_hash(tensor):"""张量内容哈希 直接对张量的字节内容做哈希。 最简单、最可靠的方式。 注意事项: - 确保 tensor 在 CPU 上(.cpu()) - 确保 tensor 是 contiguous 的 - 对于浮点数,微小的精度差异会导致不同的哈希 """returnhashlib.sha256(tensor.detach().cpu().contiguous().numpy().tobytes()).hexdigest()@staticmethoddefbatch_hash(tensors):"""批量张量哈希"""combined=b""fortintensors:combined+=t.detach().cpu().contiguous().numpy().tobytes()returnhashlib.sha256(combined).hexdigest()@staticmethoddeftext_hash(text):"""文本哈希"""returnhashlib.sha256(text.encode('utf-8')).hexdigest()@staticmethoddefparams_hash(**kwargs):"""参数组合哈希"""sorted_params=sorted(kwargs.items())param_str=str(sorted_params)returnhashlib.sha256(param_str.encode()).hexdigest()# 使用示例key_gen=CacheKeyGenerator()# 图像推理缓存键image=torch.randn(1,3,224,224)key=key_gen.tensor_hash(image)print(f"图像缓存键:{key[:16]}...")# 文本推理缓存键text="今天天气真好"key=key_gen.text_hash(text)print(f"文本缓存键:{key[:16]}...")# 带参数的缓存键key=key_gen.params_hash(model="resnet50",image_hash=key_gen.tensor_hash(image),threshold=0.5)print(f"带参数缓存键:{key[:16]}...")
四、完整推理缓存系统
classInferenceCacheSystem:"""完整的推理缓存系统 架构: ┌──────────────────────────────────────────┐ │ 推理请求 │ │ ↓ │ │ ┌─────────────────┐ │ │ │ 缓存键生成 │ │ │ └────────┬────────┘ │ │ ↓ │ │ ┌─────────────────┐ 命中 │ │ │ LRU 缓存 │ ──────→ 返回结果 │ │ └────────┬────────┘ │ │ ↓ 未命中 │ │ ┌─────────────────┐ │ │ │ NPU 推理 │ │ │ └────────┬────────┘ │ │ ↓ │ │ ┌─────────────────┐ │ │ │ 结果存入缓存 │ │ │ └─────────────────┘ │ └──────────────────────────────────────────┘ """def__init__(self,model,cache_size=1000,ttl_seconds=300):self.model=model self.cache=LRUCache(max_size=cache_size,ttl_seconds=ttl_seconds)self.key_gen=CacheKeyGenerator()self.stats_log=[]definfer(self,input_tensor,model_params=None):"""带缓存的推理 参数: input_tensor: 输入张量 model_params: 模型参数(用于区分不同版本的模型) """# 1. 生成缓存键ifmodel_params:cache_key=self.key_gen.params_hash(input_hash=self.key_gen.tensor_hash(input_tensor),model_hash=self.key_gen.text_hash(str(model_params)))else:cache_key=self.key_gen.tensor_hash(input_tensor)# 2. 查缓存cached=self.cache.get(cache_key)ifcachedisnotNone:returncached,"hit"# 3. 执行推理start_time=time.time()withtorch.no_grad():result=self.model(input_tensor.npu()).cpu()infer_time=(time.time()-start_time)*1000# 4. 存入缓存self.cache.put(cache_key,result)# 5. 记录统计self.stats_log.append({'cache_key':cache_key[:16],'status':'miss','infer_time_ms':infer_time,'timestamp':time.time()})returnresult,"miss"defbatch_infer(self,input_tensors):"""批量推理(自动缓存)"""results=[]hit_count=0miss_count=0forinpininput_tensors:result,status=self.infer(inp)results.append(result)ifstatus=="hit":hit_count+=1else:miss_count+=1returnresults,{'total':len(input_tensors),'hits':hit_count,'misses':miss_count,'hit_rate':hit_count/len(input_tensors)}definvalidate(self,pattern=None):"""使缓存失效 用途: - 模型更新后,清除旧缓存 - 数据变化后,清除相关缓存 """ifpattern:# 按模式清除(简化实现)keys_to_remove=[kforkinself.cache.cache.keys()ifpatternink]forkeyinkeys_to_remove:delself.cache.cache[key]else:self.cache.clear()defprint_stats(self):"""打印统计"""stats=self.cache.stats()print(f"\n缓存统计:")print(f" 大小:{stats['size']}/{stats['max_size']}")print(f" 命中:{stats['hits']}")print(f" 未命中:{stats['misses']}")print(f" 命中率:{stats['hit_rate']:.1%}")# 使用示例# cache_system = InferenceCacheSystem(model, cache_size=1000, ttl_seconds=300)# result, status = cache_system.infer(test_input)# print(f"推理状态: {status}")
五、常见问题
| 问题 | 原因 | 解决方案 |
|---|
| 命中率低 | 缓存键太精确或输入几乎不重复 | 放宽缓存键(如量化后再哈希) |
| 内存占用高 | 缓存太大或结果太大 | 减小缓存大小、压缩缓存结果 |
| 缓存不一致 | 模型更新后旧缓存未失效 | 设置 TTL、版本化缓存键 |
| 并发性能差 | 锁竞争激烈 | 分片缓存、无锁数据结构 |
| 缓存雪崩 | 大量缓存同时过期 | 添加随机 TTL 偏移 |
相关仓库
- cachetools- Python 缓存库 https://github.com/tkem/cachetools
- redis-py- Redis 缓存 https://github.com/redis/redis-py
- ascend-cl- 推理接口 https://gitee.com/ascend/ascend-cl