边缘AI部署:在资源受限环境运行模型
边缘AI部署:在资源受限环境运行模型
前言
我们有一个用户场景:需要在没有网络的工厂环境中使用 AI。传统的云端 AI 方案完全不行,必须在边缘设备上运行模型。
经过几个月的探索,我们成功将模型部署到了树莓派和工业电脑上。今天就来分享我们在边缘 AI 部署方面的经验。
一、边缘AI的特点
1.1 边缘 vs 云端
| 维度 | 边缘部署 | 云端部署 |
|---|---|---|
| 延迟 | 极低 | 取决于网络 |
| 隐私 | 高(数据不离开设备) | 中(数据上传云端) |
| 成本 | 一次性硬件成本 | 按需付费 |
| 网络依赖 | 无 | 必须有网络 |
| 计算能力 | 有限 | 强大 |
| 模型大小 | 受限于设备内存与存储 | 基本不受限 |
1.2 边缘场景
# Typical edge deployment targets, keyed by scenario. Each entry names an
# example device, its RAM range, and the class of model suited to it.
EDGE_SCENARIOS = {
    "iot": {"device": "树莓派", "ram": "1-4GB", "suitable": "轻量模型"},
    "industrial": {"device": "工业PC", "ram": "8-16GB", "suitable": "中量模型"},
    "mobile": {"device": "手机", "ram": "4-8GB", "suitable": "量化模型"},
    "embedded": {"device": "MCU", "ram": "512KB-2MB", "suitable": "Tiny模型"},
}

# Next article section heading (was fused onto this line): 二、模型优化
2.1 模型剪枝
import torch
import torch.nn.utils.prune as prune


class ModelPruner:
    """Apply and finalize L1 unstructured weight pruning on Linear layers."""

    def __init__(self, model):
        """Store the model whose ``torch.nn.Linear`` submodules will be pruned."""
        self.model = model

    def prune_weights(self, amount: float = 0.3):
        """Prune the ``weight`` of every ``torch.nn.Linear`` in the model.

        Args:
            amount: Passed straight to ``prune.l1_unstructured``; with the
                default 0.3 this requests pruning 30% of each layer's weights.
        """
        for module in self.model.modules():
            if isinstance(module, torch.nn.Linear):
                prune.l1_unstructured(module, name='weight', amount=amount)

    def remove_pruning(self):
        """Make pruning permanent by removing the re-parametrization.

        Linear layers that were never pruned are skipped; calling
        ``prune.remove`` on them would raise ``ValueError``.
        """
        for module in self.model.modules():
            # A pruned parameter is re-registered as ``weight_orig``; its
            # presence marks layers that actually have pruning to remove.
            if isinstance(module, torch.nn.Linear) and hasattr(module, 'weight_orig'):
                prune.remove(module, 'weight')

# Next article section heading (was fused onto this line): 2.2 模型量化
import torch


class ModelQuantizer:
    """Post-training quantization helpers (dynamic and static)."""

    def __init__(self):
        # Only "weight_dtype" is consumed (by quantize_dynamic);
        # "compute_dtype" is kept for callers that read this dict.
        self.quantization_config = {
            "compute_dtype": torch.float16,
            "weight_dtype": torch.qint8
        }

    def quantize_dynamic(self, model):
        """Return a dynamically quantized version of ``model``.

        ``torch.nn.Linear`` and ``torch.nn.LSTM`` submodules are quantized
        using ``quantization_config["weight_dtype"]`` (``torch.qint8`` unless
        the caller changed it).
        """
        return torch.quantization.quantize_dynamic(
            model,
            {torch.nn.Linear, torch.nn.LSTM},
            dtype=self.quantization_config["weight_dtype"],
        )

    def quantize_static(self, model, calibration_data, backend: str = 'fbgemm'):
        """Statically quantize ``model`` in place and return it.

        Args:
            model: Module to quantize. It is switched to eval mode and
                modified in place.
            calibration_data: Iterable of inputs; each item is fed through
                the prepared model to collect statistics.
            backend: Name passed to ``get_default_qconfig``. Defaults to
                ``'fbgemm'``.
                NOTE(review): ARM targets normally use ``'qnnpack'``, and
                ``torch.backends.quantized.engine`` should match - confirm
                for the target device.

        NOTE(review): eager-mode static quantization presumably needs
        ``QuantStub``/``DeQuantStub`` in the model - verify against the
        model definition.
        """
        # Calibrate in inference mode so dropout / batch-norm do not behave
        # as in training while statistics are collected.
        model.eval()
        model.qconfig = torch.quantization.get_default_qconfig(backend)
        torch.quantization.prepare(model, inplace=True)

        # Calibrate
        with torch.no_grad():
            for data in calibration_data:
                model(data)

        torch.quantization.convert(model, inplace=True)
        return model

# Next article section heading (was fused onto this line): 三、推理框架
3.1 ONNX Runtime
import onnxruntime as ort


class ONNXInference:
    """Thin wrapper around an ONNX Runtime session for single-input,
    single-output inference."""

    def __init__(self, model_path: str, providers=None):
        """Load the model and cache its first input/output names.

        Args:
            model_path: Path handed to ``ort.InferenceSession``.
            providers: Execution providers for the session. Defaults to
                ``['CPUExecutionProvider']``.
        """
        if providers is None:
            providers = ['CPUExecutionProvider']
        self.session = ort.InferenceSession(model_path, providers=providers)
        # Resolve the tensor names once instead of on every predict() call.
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def predict(self, input_data):
        """Run the session on ``input_data`` and return the first output."""
        result = self.session.run(
            [self.output_name],
            {self.input_name: input_data}
        )
        return result[0]

# Next article section heading (was fused onto this line): 3.2 TensorRT
class TensorRTInference:
    """Run a serialized TensorRT engine with one input and one output binding."""

    def __init__(self, engine_path: str):
        """Deserialize the engine at ``engine_path`` and create an execution context.

        Raises:
            RuntimeError: If the engine file cannot be deserialized.
        """
        # Importing pycuda.autoinit initializes CUDA and creates a context
        # once per process; repeated imports are no-ops.
        import pycuda.autoinit  # noqa: F401
        import tensorrt as trt

        logger = trt.Logger(trt.Logger.WARNING)
        runtime = trt.Runtime(logger)

        with open(engine_path, 'rb') as f:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        if self.engine is None:
            raise RuntimeError(f"Failed to deserialize TensorRT engine: {engine_path}")

        self.context = self.engine.create_execution_context()

    def predict(self, input_data, output_data):
        """Run inference, filling ``output_data`` in place, and return it.

        Args:
            input_data: Host buffer exposing ``nbytes``.
                NOTE(review): presumably a contiguous numpy array matching
                the engine's input binding - confirm against the caller.
            output_data: Pre-allocated host buffer exposing ``nbytes``; it
                receives the result.
        """
        import pycuda.driver as cuda

        # No per-call cuda.init() / cuda.Context(): the CUDA context is set
        # up once in __init__, and PyCUDA contexts come from
        # Device.make_context(), not from calling Context() directly.
        stream = cuda.Stream()

        # Allocate device memory and copy the input over
        d_input = cuda.mem_alloc(input_data.nbytes)
        d_output = None
        try:
            d_output = cuda.mem_alloc(output_data.nbytes)
            cuda.memcpy_htod_async(d_input, input_data, stream)

            # Execute
            # NOTE(review): execute_async_v2 is tied to the TensorRT version
            # in use - confirm it exists in the deployed release.
            self.context.execute_async_v2(
                bindings=[int(d_input), int(d_output)],
                stream_handle=stream.handle
            )

            cuda.memcpy_dtoh_async(output_data, d_output, stream)
            stream.synchronize()
        finally:
            # Release device buffers promptly rather than waiting for
            # garbage collection.
            d_input.free()
            if d_output is not None:
                d_output.free()

        return output_data

# Next article section heading (was fused onto this line): 四、设备适配
4.1 树莓派部署
# requirements.txt for Raspberry Pi
# torch==2.0.0
# torchvision==0.15.0
# onnxruntime==1.15.0
import torch


class RaspberryPiDeployer:
    """Prepare a PyTorch model for a Raspberry Pi: dynamic quantization and
    TorchScript export."""

    def optimize_for_pi(self, model):
        """Switch ``model`` to eval mode and return a quantized version.

        Only ``torch.nn.Linear`` layers are dynamically quantized, to
        ``torch.qint8``.
        """
        # Inference mode before quantizing
        model.eval()

        # Quantize
        model_quantized = torch.quantization.quantize_dynamic(
            model,
            {torch.nn.Linear},
            dtype=torch.qint8
        )

        return model_quantized

    def export_scripted(self, model, input_shape):
        """Trace ``model`` with a random input of ``input_shape``.

        The model is put in eval mode first so the trace does not bake in
        training-time behaviour (dropout, batch-norm updates). Tracing runs
        under ``torch.no_grad()``.

        Returns:
            The traced module from ``torch.jit.trace``.
        """
        model.eval()
        example_input = torch.randn(input_shape)
        with torch.no_grad():
            traced = torch.jit.trace(model, example_input)
        return traced

# Next article section heading (was fused onto this line): 4.2 工业设备部署
class IndustrialDeployer:
    """Route a model to the deployment path for a given industrial device."""

    def deploy(self, model, device_type: str):
        """Deploy ``model`` to the named device type.

        Args:
            model: The model to deploy.
            device_type: One of ``"jetson_nano"``, ``"jetson_xavier"`` or
                ``"industrial_pc"``.

        Raises:
            ValueError: If ``device_type`` is not one of the supported names
                (previously this silently returned ``None``).
        """
        if device_type == "jetson_nano":
            return self._deploy_jetson(model)
        if device_type == "jetson_xavier":
            return self._deploy_jetson(model, use_tensorrt=True)
        if device_type == "industrial_pc":
            return self._deploy_pc(model)
        raise ValueError(f"Unsupported device_type: {device_type!r}")

    def _deploy_jetson(self, model, use_tensorrt=True):
        """Deploy to a Jetson: TensorRT conversion, or ``model.cuda()`` when
        ``use_tensorrt`` is false."""
        if use_tensorrt:
            # Convert to TensorRT
            return self._convert_to_tensorrt(model)
        # Native PyTorch on the GPU
        return model.cuda()

    def _convert_to_tensorrt(self, model):
        """Hook for TensorRT conversion; not provided here.

        Previously undefined, so the call failed with ``AttributeError``.
        Override in a subclass.
        """
        raise NotImplementedError("TensorRT conversion is not implemented")

    def _deploy_pc(self, model):
        """Hook for industrial-PC deployment; not provided here.

        Previously undefined, so the call failed with ``AttributeError``.
        Override in a subclass.
        """
        raise NotImplementedError("Industrial PC deployment is not implemented")

# Next article section heading (was fused onto this line): 五、性能优化
5.1 批处理优化
class BatchOptimizer:
    """Accumulate requests and hand them back in batches of a fixed maximum size."""

    def __init__(self, max_batch_size: int = 8):
        """Create an empty queue.

        Args:
            max_batch_size: Number of pending requests that triggers a batch.

        Raises:
            ValueError: If ``max_batch_size`` is less than 1. With zero or a
                negative value every batch would be empty and the queue would
                never drain.
        """
        if max_batch_size < 1:
            raise ValueError("max_batch_size must be at least 1")
        self.max_batch_size = max_batch_size
        self.pending_requests = []

    def add_request(self, data):
        """Queue ``data``; return a full batch once enough are pending, else ``None``."""
        self.pending_requests.append(data)

        if len(self.pending_requests) >= self.max_batch_size:
            return self._process_batch()

        return None

    def force_process(self):
        """Return the pending requests as a batch now, or ``None`` if the queue is empty."""
        if self.pending_requests:
            return self._process_batch()
        return None

    def _process_batch(self):
        """Remove and return up to ``max_batch_size`` requests from the front of the queue."""
        batch = self.pending_requests[:self.max_batch_size]
        self.pending_requests = self.pending_requests[self.max_batch_size:]
        return batch

# Next article section heading (was fused onto this line): 5.2 缓存优化
import sys
from collections import OrderedDict
from datetime import datetime


class EdgeCache:
    """Size-bounded in-memory cache that evicts the least recently used entry."""

    def __init__(self, max_size_mb: int = 100):
        """Create an empty cache holding at most ``max_size_mb`` mebibytes."""
        self.max_size = max_size_mb * 1024 * 1024
        # Ordered oldest-to-newest by last use, so the first key is the LRU
        # entry. OrderedDict is a dict subclass, so dict-style access works.
        self.cache = OrderedDict()
        self.access_times = {}
        # Per-key sizes and their running sum, so eviction does not have to
        # re-measure every cached value.
        self._sizes = {}
        self._total_size = 0

    def get(self, key):
        """Return the cached value for ``key`` (marking it recently used), or ``None``."""
        if key in self.cache:
            self.cache.move_to_end(key)
            self.access_times[key] = datetime.now()
            return self.cache[key]
        return None

    def set(self, key, value):
        """Store ``value`` under ``key``, evicting LRU entries to make room.

        An existing entry for ``key`` is replaced. A value larger than the
        whole cache is not stored, and nothing else is evicted for it.
        """
        size = self._get_size(value)

        # Drop any previous value first so its size is not counted twice.
        self._remove(key)

        if size > self.max_size:
            return

        while self.cache and self._get_total_size() + size > self.max_size:
            self._evict_lru()

        self.cache[key] = value
        self._sizes[key] = size
        self._total_size += size
        self.access_times[key] = datetime.now()

    def _get_size(self, value):
        """Estimate the size of ``value`` in bytes.

        Uses an integer ``nbytes`` attribute when the value has one, else
        ``sys.getsizeof``, which is shallow and does not count referenced
        objects.
        """
        nbytes = getattr(value, "nbytes", None)
        if isinstance(nbytes, int):
            return nbytes
        return sys.getsizeof(value)

    def _get_total_size(self):
        """Return the summed estimated size of all cached values."""
        return self._total_size

    def _evict_lru(self):
        """Remove the least recently used entry, if any."""
        if self.cache:
            self._remove(next(iter(self.cache)))

    def _remove(self, key):
        """Delete ``key`` from the cache and its bookkeeping; no-op if absent."""
        if key in self.cache:
            del self.cache[key]
        self.access_times.pop(key, None)
        self._total_size -= self._sizes.pop(key, 0)

# Next article section heading (was fused onto this line): 六、监控与维护
6.1 边缘监控
from datetime import datetime


class EdgeMonitor:
    """Collect simple runtime metrics on an edge device and summarize them."""

    # Metrics stored as a growing list of timestamped samples.
    _SERIES_METRICS = ("cpu_usage", "memory_usage", "errors")

    def __init__(self):
        self.metrics = {
            "cpu_usage": [],
            "memory_usage": [],
            "inference_count": 0,
            "errors": []
        }

    def record(self, metric_type: str, value):
        """Record one observation.

        ``cpu_usage``, ``memory_usage`` and ``errors`` are appended as
        ``{"value", "timestamp"}`` samples. Any other metric, such as
        ``inference_count``, is overwritten with ``value``.

        Errors used to be overwritten too, which replaced the list and broke
        ``error_count`` in the health report.
        """
        if metric_type in self._SERIES_METRICS:
            self.metrics[metric_type].append({
                "value": value,
                "timestamp": datetime.now()
            })
        else:
            self.metrics[metric_type] = value

    def get_health_report(self):
        """Return average CPU/memory usage, the inference count and the error count.

        Averages are 0 when no samples have been recorded.
        """
        return {
            "cpu_avg": self._average(self.metrics["cpu_usage"]),
            "memory_avg": self._average(self.metrics["memory_usage"]),
            "total_inferences": self.metrics["inference_count"],
            "error_count": len(self.metrics["errors"])
        }

    @staticmethod
    def _average(samples):
        """Mean of the ``"value"`` field across ``samples``; 0 for an empty list."""
        if not samples:
            return 0
        return sum(sample["value"] for sample in samples) / len(samples)

# Next article section heading (was fused onto this line): 6.2 OTA 更新
import os
import tempfile
from urllib.parse import quote


class EdgeOTA:
    """Over-the-air model update client."""

    def __init__(self, update_server: str = "https://updates.example.com",
                 timeout: float = 30.0):
        """Configure the client.

        Args:
            update_server: Base URL of the update service.
            timeout: Seconds passed as ``timeout`` to every HTTP request, so
                a stalled connection cannot hang the device indefinitely.
        """
        self.update_server = update_server
        self.timeout = timeout

    def check_update(self, current_version: str) -> dict:
        """Ask the server about updates for ``current_version``.

        Returns:
            The decoded JSON body of the response.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        import requests

        response = requests.get(
            f"{self.update_server}/check",
            params={"version": current_version},
            timeout=self.timeout
        )
        response.raise_for_status()
        return response.json()

    def download_update(self, model_id: str, progress_callback=None):
        """Download a model to ``/tmp/model_update.onnx`` and return that path.

        The data is streamed into a temporary file in the same directory and
        moved into place only once the download has finished, so a failed
        transfer never leaves a truncated model at the final path.

        Args:
            model_id: Identifier appended (URL-escaped) to the download URL.
            progress_callback: Optional callable given the completed
                fraction (at most 1.0) after each chunk. It is not called
                when the response has no usable ``content-length``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        import requests

        target_path = '/tmp/model_update.onnx'
        url = f"{self.update_server}/download/{quote(str(model_id), safe='')}"

        with requests.get(url, stream=True, timeout=self.timeout) as response:
            response.raise_for_status()

            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0

            fd, partial_path = tempfile.mkstemp(
                dir=os.path.dirname(target_path),
                prefix='model_update.',
                suffix='.part'
            )
            try:
                with os.fdopen(fd, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if not chunk:
                            continue
                        f.write(chunk)
                        downloaded += len(chunk)

                        # Without a content-length there is no fraction to
                        # report, and dividing would fail.
                        if progress_callback and total_size > 0:
                            progress_callback(min(downloaded / total_size, 1.0))
                os.replace(partial_path, target_path)
            except BaseException:
                # Do not leave partial downloads behind on any failure.
                try:
                    os.unlink(partial_path)
                except FileNotFoundError:
                    pass
                raise

        return target_path

# Next article section heading (was fused onto this line): 七、最佳实践
7.1 部署策略
- ✅渐进更新:先小范围测试,再全量
- ✅版本管理:保持多个版本,可回滚
- ✅监控告警:实时监控设备状态
- ✅自动恢复:异常时自动重启
7.2 性能优化
- ✅模型优化:剪枝、量化、蒸馏
- ✅批处理:提高硬件(CPU/GPU)利用率
- ✅缓存:减少重复计算
- ✅异步:非阻塞推理
八、总结
边缘 AI 让 AI 能力延伸到每一个角落。关键在于:
- 模型优化:适配硬件限制
- 推理框架:选择合适的运行时
- 性能优化:榨干硬件性能
- 运维监控:确保稳定运行
记住:边缘不是将就,而是必然。
