【实战指南】从零部署垃圾分类AI应用:TensorFlow 2.3模型训练与PyQt5界面开发全流程
1. 环境配置与数据准备
第一次接触垃圾分类AI项目时,我被8万张图片的数据集吓到了。但实际操作后发现,只要环境搭对了,后面的流程就像搭积木一样简单。建议使用Anaconda创建独立环境,避免版本冲突。我常用的配置是Python 3.7 + TensorFlow 2.3,这个组合在CUDA 10.1环境下最稳定。
数据集处理有个小技巧:用pathlib库批量检查图片完整性。曾经有次训练到一半报错,发现是损坏的JPEG文件导致的。后来我养成了预处理时先跑这个脚本的习惯:
from pathlib import Path from PIL import Image def verify_images(data_dir): broken_files = [] for img_path in Path(data_dir).rglob('*.jpg'): try: with Image.open(img_path) as img: img.verify() except (IOError, SyntaxError) as e: broken_files.append(str(img_path)) return broken_files数据集目录结构要保持一致,建议按这个格式组织:
trash_jpg/ ├─ 厨余垃圾_苹果 ├─ 可回收物_塑料瓶 ├─ 有害垃圾_电池 └─ 其他垃圾_卫生纸2. 模型选型实战对比
测试了LeNet和MobileNetV2后,发现准确率相差15%。但别急着选MobileNet,要考虑部署场景。在树莓派上测试时,LeNet的推理速度比MobileNet快3倍,虽然准确率低些,但对实时性要求高的场景反而更合适。
LeNet的魔改版我增加了Dropout层,防止过拟合:
def build_lenet(input_shape=(224,224,3), num_classes=245): model = Sequential([ Rescaling(1./255, input_shape=input_shape), Conv2D(32, (3,3), activation='relu', padding='same'), MaxPooling2D(), Dropout(0.3), # 新增的Dropout层 Conv2D(64, (3,3), activation='relu', padding='same'), MaxPooling2D(), Flatten(), Dense(128, activation='relu'), Dense(num_classes, activation='softmax') ]) return modelMobileNet的迁移学习有个坑要注意:默认输入是224x224,但如果你的图片长宽比异常,resize时会变形。后来我改成先等比例缩放再中心裁剪的方式:
def preprocess_image(image_path): img = tf.io.read_file(image_path) img = tf.image.decode_jpeg(img, channels=3) # 保持比例的resize img = tf.image.resize_with_pad(img, 224, 224) return img3. 训练调参技巧
学习率设置是门艺术。我的经验是先用LearningRateScheduler找合适范围:
def lr_schedule(epoch): if epoch < 5: return 1e-3 elif epoch < 15: return 1e-4 else: return 1e-5 callbacks = [ LearningRateScheduler(lr_schedule), EarlyStopping(patience=5) ]数据增强要适度,过度增强反而会降低准确率。我常用的组合是:
data_augmentation = Sequential([ RandomFlip("horizontal"), RandomRotation(0.1), RandomZoom(0.1), ])遇到类别不平衡问题时,可以用class_weight参数。计算权重的公式:
import numpy as np from sklearn.utils import class_weight class_weights = class_weight.compute_class_weight( 'balanced', classes=np.unique(train_labels), y=train_labels )4. PyQt5界面开发细节
Qt的信号槽机制初学容易懵,这里有个实用模板:
class MainWindow(QMainWindow): def __init__(self): super().__init__() self.model = load_model() # 提前加载模型 self.init_ui() def init_ui(self): self.btn_load = QPushButton('选择图片') self.btn_load.clicked.connect(self.load_image) # 信号槽连接 def load_image(self): fname = QFileDialog.getOpenFileName(self, '打开图片')[0] if fname: # 处理图片逻辑图片显示优化有个小技巧:用QLabel显示图片时,先转换成QPixmap并保持宽高比:
def display_image(self, img_path): pixmap = QPixmap(img_path) if not pixmap.isNull(): scaled = pixmap.scaled( self.label.width(), self.label.height(), Qt.KeepAspectRatio ) self.label.setPixmap(scaled)打包成exe时,记得把模型文件一起打包。用PyInstaller时要加这个参数:
pyinstaller --add-data "model.h5;." window.py5. 部署中的性能优化
模型量化能让体积缩小4倍。用TFLite转换时开启优化:
converter = tf.lite.TFLiteConverter.from_keras_model(model) converter.optimizations = [tf.lite.Optimize.DEFAULT] tflite_model = converter.convert()在树莓派上部署时,启用多线程推理:
interpreter = tf.lite.Interpreter( model_path="model.tflite", num_threads=4 # 根据CPU核心数调整 )内存管理很重要,特别是处理大图时。我用这个上下文管理器避免内存泄漏:
from contextlib import contextmanager @contextmanager def open_image(path): try: img = Image.open(path) yield img finally: img.close()6. 常见问题解决方案
遇到"CUDA out of memory"错误时,可以尝试这三步:
- 减小batch_size(我一般从32开始试)
- 在代码开头设置GPU内存增长
gpus = tf.config.experimental.list_physical_devices('GPU') for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) - 使用混合精度训练
policy = tf.keras.mixed_precision.Policy('mixed_float16') tf.keras.mixed_precision.set_global_policy(policy)
标签错乱的问题可以通过检查class_indices文件解决。建议训练前先保存类别映射关系:
import json with open('class_indices.json', 'w') as f: json.dump(train_ds.class_indices, f)7. 扩展功能实现
添加摄像头实时检测功能时,用OpenCV的VideoCapture:
cap = cv2.VideoCapture(0) while True: ret, frame = cap.read() if not ret: break # 预处理帧 input_tensor = preprocess(frame) # 推理 predictions = model.predict(input_tensor) # 显示结果 cv2.imshow('Detection', frame) if cv2.waitKey(1) == ord('q'): break实现批量预测功能时,建议用多进程加速:
from multiprocessing import Pool def predict_single(img_path): # 单张图片预测逻辑 return result with Pool(processes=4) as pool: results = pool.map(predict_single, img_paths)日志记录功能对调试很有帮助。我习惯用logging模块:
import logging logging.basicConfig( filename='app.log', level=logging.INFO, format='%(asctime)s - %(message)s' )8. 项目优化方向
模型层面可以尝试EfficientNet,我在测试集上能达到89%准确率。但要注意计算资源消耗会增加约40%。
界面美化可以用Qt的样式表:
self.setStyleSheet(""" QMainWindow { background: #f0f0f0; } QPushButton { background: #4CAF50; color: white; border: none; padding: 8px; } """)加入语音提示功能会让体验更友好:
from gtts import gTTS import os def text_to_speech(text): tts = gTTS(text=text, lang='zh-cn') tts.save("output.mp3") os.system("start output.mp3")