当前位置: 首页 > news >正文

用Java手写kNN和朴素贝叶斯:从鸢尾花数据集到电影推荐,一次搞定两个经典算法

Java实战:从kNN到朴素贝叶斯,双算法实现分类与推荐系统

在机器学习领域,k近邻(kNN)和朴素贝叶斯(NB)是两个经典且实用的算法。本文将带你用Java实现这两个算法,并应用于鸢尾花分类和电影推荐两个不同场景。通过对比实现,你会发现虽然两者都基于"相似性"概念,但在思想和代码实现上有着显著差异。

1. 环境准备与数据加载

在开始编码前,我们需要准备好开发环境和数据集。这里使用Weka库来处理ARFF格式的数据文件,它提供了方便的API来操作机器学习数据集。

首先创建Maven项目并添加Weka依赖:

<dependency> <groupId>nz.ac.waikato.cms.weka</groupId> <artifactId>weka-stable</artifactId> <version>3.8.6</version> </dependency>

对于鸢尾花数据集,我们可以直接从网络获取:

public KnnClassification(String paraFilename) { try { FileReader fileReader = new FileReader(paraFilename); dataset = new Instances(fileReader); dataset.setClassIndex(dataset.numAttributes() - 1); fileReader.close(); } catch (Exception ee) { System.out.println("Error reading file: " + ee); System.exit(0); } }

电影评分数据则需要特殊处理,因为它是用户-物品-评分的三元组形式:

public MBR(String paraFilename, int paraNumUsers, int paraNumItems) throws Exception { // 初始化数组 BufferedReader tempBufReader = new BufferedReader(new FileReader(tempFile)); String tempString; while ((tempString = tempBufReader.readLine()) != null) { String[] tempStrArray = tempString.split(","); // 处理每行数据 } }

2. kNN算法实现与优化

kNN算法的核心思想是"物以类聚"——通过计算待分类样本与训练集中各样本的距离,找出最近的k个邻居,然后根据这些邻居的类别进行投票决定待分类样本的类别。

距离度量是kNN的关键,常见的有:

  • 曼哈顿距离:各维度绝对差之和
  • 欧氏距离:平方差之和开方
public double distance(int paraI, int paraJ) { double resultDistance = 0; switch (distanceMeasure) { case MANHATTAN: // 曼哈顿距离计算 break; case EUCLIDEAN: // 欧氏距离计算 break; } return resultDistance; }

邻居查找的原始实现需要多次扫描训练集,效率较低。我们可以优化为单次扫描:

public int[] computeNearests(int paraCurrent) { int[] resultNearests = new int[numNeighbors]; double[] tempDistances = new double[trainingSet.length]; // 一次计算所有距离 for (int i = 0; i < trainingSet.length; i++) { tempDistances[i] = distance(paraCurrent, trainingSet[i]); } // 然后找出最小的k个 }

kNN在鸢尾花分类上的典型准确率能达到95%左右,但计算复杂度随数据量线性增长,这是其主要缺点。

3. 基于M-distance的推荐系统

推荐系统与分类问题看似不同,但核心思想也是"找相似"。这里我们实现基于物品的协同过滤推荐,使用M-distance衡量物品相似度。

M-distance认为两个物品的相似度取决于它们的平均评分是否接近:

if (Math.abs(itemAvgRating1 - itemAvgRating2) < radius) { // 视为邻居物品 }

预测用户对某物品的评分时,取所有相似物品评分的平均值:

public double predictRating(int userId, int itemId) { double sum = 0; int count = 0; for (Item neighbor : findSimilarItems(itemId)) { if (userRatedItem(userId, neighbor.id)) { sum += getUserRating(userId, neighbor.id); count++; } } return count > 0 ? sum / count : DEFAULT_RATING; }

这种方法的MAE(平均绝对误差)通常在0.7-0.8之间,调整radius参数可以平衡推荐覆盖率与准确率。

4. 朴素贝叶斯分类器实现

朴素贝叶斯基于贝叶斯定理,假设特征之间相互独立。我们先实现处理符号型数据的版本:

public void calculateConditionalProbabilities() { // 初始化三维数组 conditionalProbabilitiesLaplacian = new double[numClasses][numConditions][]; // 统计每个特征值在每个类别下的出现次数 for (Instance instance : dataset) { int cls = (int)instance.classValue(); for (int attr = 0; attr < numConditions; attr++) { int val = (int)instance.value(attr); conditionalCounts[cls][attr][val]++; } } // 计算拉普拉斯平滑后的概率 for (int cls = 0; cls < numClasses; cls++) { for (int attr = 0; attr < numConditions; attr++) { int numValues = dataset.attribute(attr).numValues(); for (int val = 0; val < numValues; val++) { conditionalProbabilitiesLaplacian[cls][attr][val] = (conditionalCounts[cls][attr][val] + 1) / (classCounts[cls] + numValues); } } } }

分类时计算后验概率的对数,避免浮点数下溢:

public int classifyNominal(Instance instance) { double maxLogProbability = Double.NEGATIVE_INFINITY; int bestClass = -1; for (int cls = 0; cls < numClasses; cls++) { double logProb = Math.log(classDistributionLaplacian[cls]); for (int attr = 0; attr < numConditions; attr++) { int val = (int)instance.value(attr); logProb += Math.log(conditionalProbabilitiesLaplacian[cls][attr][val]); } if (logProb > maxLogProbability) { maxLogProbability = logProb; bestClass = cls; } } return bestClass; }

5. 数值型数据的朴素贝叶斯

对于数值型特征,我们通常假设其服从高斯分布。需要为每个特征在每个类别下计算均值和标准差:

public void calculateGaussianParameters() { gaussianParameters = new GaussianParamters[numClasses][numConditions]; for (int cls = 0; cls < numClasses; cls++) { for (int attr = 0; attr < numConditions; attr++) { double sum = 0, sumSquares = 0; int count = 0; for (Instance instance : dataset) { if ((int)instance.classValue() == cls) { double val = instance.value(attr); sum += val; sumSquares += val * val; count++; } } double mean = sum / count; double stddev = Math.sqrt((sumSquares - sum*sum/count) / count); gaussianParameters[cls][attr] = new GaussianParamters(mean, stddev); } } }

分类时使用高斯概率密度函数:

public int classifyNumerical(Instance instance) { double maxLogProbability = Double.NEGATIVE_INFINITY; int bestClass = -1; for (int cls = 0; cls < numClasses; cls++) { double logProb = Math.log(classDistributionLaplacian[cls]); for (int attr = 0; attr < numConditions; attr++) { double x = instance.value(attr); double mu = gaussianParameters[cls][attr].mu; double sigma = gaussianParameters[cls][attr].sigma; // 高斯分布的概率密度对数 logProb += -Math.log(sigma) - (x-mu)*(x-mu)/(2*sigma*sigma); } if (logProb > maxLogProbability) { maxLogProbability = logProb; bestClass = cls; } } return bestClass; }

6. 算法对比与应用场景

虽然kNN和朴素贝叶斯都可用于分类,但它们的适用场景有所不同:

特性kNN朴素贝叶斯
训练速度快(惰性学习)
预测速度
内存需求高(存储全部数据)低(存储参数)
特征相关性可处理相关特征假设特征独立
数据规模适合中小规模适合大规模
特征类型数值型表现好数值型、类别型均可

在推荐系统场景中,基于用户的协同过滤(UserCF)类似于kNN,而基于物品的协同过滤(ItemCF)更接近朴素贝叶斯的思想。选择哪种算法取决于具体需求和数据特性。

7. 工程实践建议

在实际项目中应用这些算法时,有几个关键点需要注意:

  1. 数据预处理

    • 归一化:kNN对特征尺度敏感,需做归一化
    • 缺失值处理:朴素贝叶斯需要处理缺失值
    // 归一化示例 public void normalize() { for (int attr = 0; attr < numConditions; attr++) { double min = Double.MAX_VALUE; double max = Double.MIN_VALUE; // 找出最小最大值 // 归一化每个特征值 } }
  2. 参数调优

    • kNN中的k值
    • M-distance中的radius阈值
    • 朴素贝叶斯的平滑参数
  3. 性能优化

    • kNN可以使用KD树等数据结构加速搜索
    • 对于大规模数据,可以考虑近似算法
  4. 评估指标

    • 分类问题:准确率、精确率、召回率、F1值
    • 推荐系统:MAE、RMSE、覆盖率、多样性
// 评估指标计算示例 public void evaluate() { int[][] confusionMatrix = new int[numClasses][numClasses]; for (int i = 0; i < numInstances; i++) { int actual = (int)dataset.instance(i).classValue(); int predicted = predicts[i]; confusionMatrix[actual][predicted]++; } // 计算各项指标... }

8. 扩展与进阶

掌握了基础实现后,可以考虑以下扩展方向:

  1. 加权kNN:给更近的邻居更高权重

    // 加权投票 public int weightedVoting(int[] neighbors) { double[] weights = new double[numClasses]; for (int i = 0; i < neighbors.length; i++) { double distance = distances[i]; double weight = 1.0 / (distance + 1e-5); // 避免除零 int cls = (int)dataset.instance(neighbors[i]).classValue(); weights[cls] += weight; } // 返回权重最大的类别 }
  2. 核密度估计:改进朴素贝叶斯对数值型特征的建模

  3. 混合型数据:同时处理数值型和类别型特征

  4. 增量学习:支持在线更新模型

  5. 分布式实现:使用Spark等框架处理大数据

通过本项目的实践,你不仅学会了两种重要算法的实现,还掌握了如何将机器学习应用于不同场景的关键技能。建议尝试将这些算法应用到自己的项目中,或者参加Kaggle等数据科学竞赛来进一步磨练技能。

http://www.jsqmd.com/news/664363/

相关文章:

  • RWKV7-1.5B-G1A开源协作:在GitHub Actions中集成模型自动化代码审查
  • LFM2.5-1.2B-Thinking-GGUF零基础部署:5分钟在CSDN星图一键启动轻量文本生成模型
  • 别再死记硬背了!用PyTorch和TensorFlow动手搭建你的第一个自编码器(附完整代码)
  • 大模型---exploit and explore
  • 嘎嘎降AI和去AIGC哪个更适合理工科论文:2026年最新对比
  • Graphormer镜像免配置亮点:内置SMILES示例库与一键测试功能快速验证
  • internlm2-chat-1.8b效果惊艳:中文古籍标点自动添加+白话翻译对比展示
  • Phi-4-mini-reasoning推理模型企业级部署实录:Docker Compose+Nginx,稳定运行128K长文本
  • Fish Speech 1.5教育场景应用:制作多语言教学音频教程
  • 如何快速配置 Ultimate ASI Loader:游戏插件加载完整指南
  • 智能代码生成≠自动交付(重构才是最后一道防火墙):金融级系统落地的6项重构准入标准
  • jQuery 选择器
  • Qwen3-14B低代码开发应用:基于Dify快速构建AI智能体(Agent)
  • 别再死记硬背了!用这个“资本家模型”5分钟搞懂三极管饱和与截止
  • HeyGem数字人系统批量处理教程:高效制作企业宣传视频
  • 创维E900V22E刷机后必做的6项优化:从三网通吃到存储空间清理(S905L3固件实测)
  • Calibre中文路径保护插件:终极解决方案告别拼音路径困扰
  • WAN2.2+SDXL_Prompt风格效果展示:‘未来科技发布会’提示词生成专业级视频
  • GESP2023年12月认证C++三级( 第三部分编程题(1、小猫分鱼))
  • 工业路由器能用多久
  • Phi-3 Forest Lab部署教程:Kubernetes集群中水平扩展Phi-3服务
  • 从混合信号中精准剥离生命体征:基于HHT与自适应滤波的心率呼吸率分离实践
  • 网络协议分析助手:Phi-4-mini-reasoning解读抓包数据与故障诊断
  • 次元画室Python入门实践:用10行代码实现你的第一张AI绘画
  • KICS(Kucius Inverse Capability Score)完整体系:从元推理量化到去中心化共识治理
  • 如何在5分钟内免费部署本地AI写作助手:KoboldAI完全指南
  • LeetCode 3783. 整数的镜像距离 技术解析
  • 【计算机网络 实验报告4】虚拟局域网与ARP协议
  • 用ESP32+Arduino搞定VESC双轮毂电机同步控制(附完整代码)
  • 告别死板界面!Nanbeige 4.1-3B Streamlit WebUI极简版,一键搭建二次元对话助手