当前位置: 首页 > news >正文

猫狗识别大模型——基于python语言

目录

1.猫狗识别

2.数据集介绍

3.猫狗识别核心原理

4.程序思路

4.1数据文件框架

4.2 训练模型

4.3 模型使用

4.4 识别结果

5.总结


1.猫狗识别

人可以直接分辨出图片里的动物是猫还是狗,但是电脑不可以,要想让电脑也分辨出图片里的动物是猫还是小狗,就要使用到深度学习,电脑学习提取图片特征,进而学习区分图片里的是猫还是狗。

2.数据集介绍

程序用到的训练数据集是猫狗图像数据集,数据格式jpg格式,猫狗数据集:

https://www.kaggle.com/datasets/shaunthesheep/microsoft-catsvsdogs-dataset

3.猫狗识别核心原理

猫狗识别大模型是一种深度学习架构,主要用于图像分类任务,用来区分猫和狗这两种常见的宠物动物。

该模型基于卷积神经网络(CNN),它们通过学习大量的猫和狗图像数据集中的特征来进行训练,使其能够识别出输入图片中动物的种类。

训练过程中,模型会对猫的特有纹理、颜色模式、耳朵形状等特征进行学习,并形成区分猫狗的关键特征模板。一旦模型经过充分训练并优化,它可以准确地判断新的未知图片是属于猫还是狗。

应用此类模型的方式通常是将其部署到移动设备或者云端服务器上,用户上传一张照片后,模型会返回一个预测结果,指示图像中动物的类别。

4.程序思路

基于tensorflow模型框架以及卷积神经网络还有其他各种模块,划分训练集,微调集和测试机,对猫狗图片文件进行训练。

4.1数据文件框架

4.2 训练模型

import tensorflow as tf from tensorflow.keras.preprocessing.image import ImageDataGenerator from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout from tensorflow.keras.optimizers import Adam import matplotlib.pyplot as plt import os # 获取所有的GPU设备 gpus = tf.config.list_physical_devices('GPU') # 检查是否有两个以上的GPU if gpus and len(gpus) > 1: try: # 假设GPU1是独立GPU,设置可见设备为GPU1 tf.config.set_visible_devices(gpus[1], 'GPU') tf.config.experimental.set_memory_growth(gpus[1], True) except RuntimeError as e: print(e) else: print("没有检测到多个GPU,或者系统只存在一个GPU。") # 定义数据目录 data_dir = './pythonProject/ai_modle_win/cats vs dogs/dataset' # 请替换为你的数据集路径 train_dir = os.path.join(data_dir, 'train') validation_dir = os.path.join(data_dir, 'validation') test_dir = os.path.join(data_dir, 'test') # 图像数据生成器 train_datagen = ImageDataGenerator( rescale=1./255, rotation_range=20, width_shift_range=0.2, height_shift_range=0.2, shear_range=0.2, zoom_range=0.2, horizontal_flip=True ) validation_datagen = ImageDataGenerator(rescale=1./255) test_datagen = ImageDataGenerator(rescale=1./255) # 计算样本数量 def count_files(directory): total_files = 0 for root, dirs, files in os.walk(directory): total_files += len(files) return total_files train_samples = count_files(train_dir) validation_samples = count_files(validation_dir) test_samples = count_files(test_dir) # 数据生成器 def create_generator(datagen, directory, target_size, batch_size, class_mode): generator = datagen.flow_from_directory( directory, target_size=target_size, batch_size=batch_size, class_mode=class_mode ) # 包装生成器以处理损坏的图像文件 while True: try: yield next(generator) except (OSError, StopIteration) as e: print(f"跳过无法读取的图像文件:{e}") continue train_generator = create_generator(train_datagen, train_dir, (150, 150), 32, 'binary') validation_generator = create_generator(validation_datagen, validation_dir, (150, 150), 32, 'binary') test_generator = create_generator(test_datagen, test_dir, (150, 150), 32, 'binary') # 定义模型 model = Sequential([ Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)), MaxPooling2D(2, 2), Conv2D(64, (3, 3), activation='relu'), MaxPooling2D(2, 2), Conv2D(128, (3, 3), activation='relu'), MaxPooling2D(2, 2), Conv2D(128, (3, 3), activation='relu'), MaxPooling2D(2, 2), Flatten(), Dropout(0.5), Dense(512, activation='relu'), Dense(1, activation='sigmoid') ]) model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.001), metrics=['accuracy']) # 训练模型 history = model.fit( train_generator, steps_per_epoch=train_samples // 32, # 将结果转换为整数 validation_data=validation_generator, validation_steps=validation_samples // 32, # 将结果转换为整数 epochs=5 ) # 保存模型 model.save('./pythonProject/ai_modle_win/cats vs dogs/cat_dog.h5') # 评估模型 test_loss, test_acc = model.evaluate(test_generator, steps=test_samples // 32) print(f'Test accuracy: {test_acc:.2f}') # 可视化训练结果 acc = history.history['accuracy'] val_acc = history.history['val_accuracy'] loss = history.history['loss'] val_loss = history.history['val_loss'] epochs = range(len(acc)) plt.figure(figsize=(12, 9)) plt.subplot(1, 2, 1) plt.plot(epochs, acc, 'b', label='Training accuracy') plt.plot(epochs, val_acc, 'r', label='Validation accuracy') plt.title('Training and validation accuracy') plt.legend() plt.subplot(1, 2, 2) plt.plot(epochs, loss, 'b', label='Training loss') plt.plot(epochs, val_loss, 'r', label='Validation loss') plt.title('Training and validation loss') plt.legend() plt.show()

注意更改文件路径!!!

4.3 模型使用

import tensorflow as tf from tensorflow.keras.models import load_model from tensorflow.keras.preprocessing import image import numpy as np import os # 加载已保存的模型 model = load_model('./pythonProject/ai_modle_win/cats vs dogs/cat_dog.h5') # 预测函数 def predict_image(img_path): img = image.load_img(img_path, target_size=(150, 150)) img_array = image.img_to_array(img) img_array = np.expand_dims(img_array, axis=0) img_array /= 255.0 prediction = model.predict(img_array) if prediction[0] > 0.5: print(f"The image at {img_path} is a Dog") else: print(f"The image at {img_path} is a Cat") # 示例用法 test_image_path = './pythonProject/ai_modle_win/cats vs dogs/30.jpg' # 替换为你的测试图片路径 predict_image(test_image_path)

使用上述训练的模型进行图片识别,注意文件路径。

4.4 识别结果

5.总结

通过构造猫狗图片数据集,然后使用深度学习训练一个猫狗识别大模型,你也快来试一试吧。

http://www.jsqmd.com/news/475360/

相关文章:

  • iwebsec靶场多平台搭建对比:虚拟机 vs Docker,哪种更适合你?
  • 华为鸿蒙系统借助GBox沙箱生态,无缝畅享谷歌全家桶应用
  • Maven下载配置
  • linux设置常见开机自启动命令(一)
  • Python实战:用ncnn验证模型转换成功的3种方法(附完整代码)
  • 终极指南:Zelda64Recomp跨平台兼容性详解 - Windows与Linux系统的完美适配方案
  • 三明市商用车主的2026年展望:如何定义可靠的尿素后处理品牌 - 2026年企业推荐榜
  • 从NCDC到本地分析:一站式气象数据获取与Python自动化处理指南
  • 2025年中科院预警期刊全解析:科研小白如何避开论文工厂陷阱?
  • Zotero插件:Green Frog(绿青蛙)与easyScholar联动配置全攻略
  • AE函数讲解大全 附带下载链接
  • Traceroute结果解读指南:如何从毫秒数和星号中找出网络瓶颈
  • 五大主流Web GIS框架深度对比:从Leaflet到Cesium的实战选型指南
  • 分组密码设计实战:为什么AES选择SPN而DES用Feistel?从硬件到安全的深度解析
  • 红队工具实测:用Fenjing一键搞定Jinja2 SSTI漏洞(含自定义WAF绕过脚本编写)
  • 使用Marqo构建多语言法律数据库的技术实践
  • 基于TLS协议与多特征融合的恶意加密流量智能检测实战
  • 2023最新测评:5款网页版PostgreSQL管理工具横向对比(含TeamPostgreSQL实战)
  • Marqo语音搜索系统:解锁音频内容的信息价值
  • 2026年酱香果酒性价比之选:专业公司深度评测 - 2026年企业推荐榜
  • LiveCharts2 核心架构与工作原理深度解析
  • Depth Anything 3实战:如何用DINOv2 Transformer一键生成3D高斯点云?
  • 安卓逆向实战:从脱壳到签名算法还原——以某新闻App为例
  • 构建AI Agent驱动的自动化测试设计流水线
  • ImGui字体控制避坑指南:为什么SetWindowFontScale会影响其他窗口?
  • Java安全实战:手把手教你复现CC1链漏洞(附完整代码)
  • 国内开发者福音:5个无需魔法快速下载HuggingFace大模型的镜像站(附实测速度对比)
  • 从LAN8742A到YT8512H:手把手教你移植PHY驱动到STM32F407(含避坑指南)
  • GESP C++编程题实战:小杨购物问题解析与优化思路(附完整代码)
  • Windows 10/11网络设置全攻略:如何手动配置IPv4地址和子网掩码(附常见问题解决)