用Java手写kNN和朴素贝叶斯:从鸢尾花数据集到电影推荐,一次搞定两个经典算法
Java实战:从kNN到朴素贝叶斯,双算法实现分类与推荐系统
在机器学习领域,k近邻(kNN)和朴素贝叶斯(NB)是两个经典且实用的算法。本文将带你用Java实现这两个算法,并应用于鸢尾花分类和电影推荐两个不同场景。通过对比实现,你会发现虽然两者都基于"相似性"概念,但在思想和代码实现上有着显著差异。
1. 环境准备与数据加载
在开始编码前,我们需要准备好开发环境和数据集。这里使用Weka库来处理ARFF格式的数据文件,它提供了方便的API来操作机器学习数据集。
首先创建Maven项目并添加Weka依赖:
<dependency> <groupId>nz.ac.waikato.cms.weka</groupId> <artifactId>weka-stable</artifactId> <version>3.8.6</version> </dependency>对于鸢尾花数据集,我们可以直接从网络获取:
public KnnClassification(String paraFilename) { try { FileReader fileReader = new FileReader(paraFilename); dataset = new Instances(fileReader); dataset.setClassIndex(dataset.numAttributes() - 1); fileReader.close(); } catch (Exception ee) { System.out.println("Error reading file: " + ee); System.exit(0); } }电影评分数据则需要特殊处理,因为它是用户-物品-评分的三元组形式:
public MBR(String paraFilename, int paraNumUsers, int paraNumItems) throws Exception { // 初始化数组 BufferedReader tempBufReader = new BufferedReader(new FileReader(tempFile)); String tempString; while ((tempString = tempBufReader.readLine()) != null) { String[] tempStrArray = tempString.split(","); // 处理每行数据 } }2. kNN算法实现与优化
kNN算法的核心思想是"物以类聚"——通过计算待分类样本与训练集中各样本的距离,找出最近的k个邻居,然后根据这些邻居的类别进行投票决定待分类样本的类别。
距离度量是kNN的关键,常见的有:
- 曼哈顿距离:各维度绝对差之和
- 欧氏距离:平方差之和开方
public double distance(int paraI, int paraJ) { double resultDistance = 0; switch (distanceMeasure) { case MANHATTAN: // 曼哈顿距离计算 break; case EUCLIDEAN: // 欧氏距离计算 break; } return resultDistance; }邻居查找的原始实现需要多次扫描训练集,效率较低。我们可以优化为单次扫描:
public int[] computeNearests(int paraCurrent) { int[] resultNearests = new int[numNeighbors]; double[] tempDistances = new double[trainingSet.length]; // 一次计算所有距离 for (int i = 0; i < trainingSet.length; i++) { tempDistances[i] = distance(paraCurrent, trainingSet[i]); } // 然后找出最小的k个 }kNN在鸢尾花分类上的典型准确率能达到95%左右,但计算复杂度随数据量线性增长,这是其主要缺点。
3. 基于M-distance的推荐系统
推荐系统与分类问题看似不同,但核心思想也是"找相似"。这里我们实现基于物品的协同过滤推荐,使用M-distance衡量物品相似度。
M-distance认为两个物品的相似度取决于它们的平均评分是否接近:
if (Math.abs(itemAvgRating1 - itemAvgRating2) < radius) { // 视为邻居物品 }预测用户对某物品的评分时,取所有相似物品评分的平均值:
public double predictRating(int userId, int itemId) { double sum = 0; int count = 0; for (Item neighbor : findSimilarItems(itemId)) { if (userRatedItem(userId, neighbor.id)) { sum += getUserRating(userId, neighbor.id); count++; } } return count > 0 ? sum / count : DEFAULT_RATING; }这种方法的MAE(平均绝对误差)通常在0.7-0.8之间,调整radius参数可以平衡推荐覆盖率与准确率。
4. 朴素贝叶斯分类器实现
朴素贝叶斯基于贝叶斯定理,假设特征之间相互独立。我们先实现处理符号型数据的版本:
public void calculateConditionalProbabilities() { // 初始化三维数组 conditionalProbabilitiesLaplacian = new double[numClasses][numConditions][]; // 统计每个特征值在每个类别下的出现次数 for (Instance instance : dataset) { int cls = (int)instance.classValue(); for (int attr = 0; attr < numConditions; attr++) { int val = (int)instance.value(attr); conditionalCounts[cls][attr][val]++; } } // 计算拉普拉斯平滑后的概率 for (int cls = 0; cls < numClasses; cls++) { for (int attr = 0; attr < numConditions; attr++) { int numValues = dataset.attribute(attr).numValues(); for (int val = 0; val < numValues; val++) { conditionalProbabilitiesLaplacian[cls][attr][val] = (conditionalCounts[cls][attr][val] + 1) / (classCounts[cls] + numValues); } } } }分类时计算后验概率的对数,避免浮点数下溢:
public int classifyNominal(Instance instance) { double maxLogProbability = Double.NEGATIVE_INFINITY; int bestClass = -1; for (int cls = 0; cls < numClasses; cls++) { double logProb = Math.log(classDistributionLaplacian[cls]); for (int attr = 0; attr < numConditions; attr++) { int val = (int)instance.value(attr); logProb += Math.log(conditionalProbabilitiesLaplacian[cls][attr][val]); } if (logProb > maxLogProbability) { maxLogProbability = logProb; bestClass = cls; } } return bestClass; }5. 数值型数据的朴素贝叶斯
对于数值型特征,我们通常假设其服从高斯分布。需要为每个特征在每个类别下计算均值和标准差:
public void calculateGaussianParameters() { gaussianParameters = new GaussianParamters[numClasses][numConditions]; for (int cls = 0; cls < numClasses; cls++) { for (int attr = 0; attr < numConditions; attr++) { double sum = 0, sumSquares = 0; int count = 0; for (Instance instance : dataset) { if ((int)instance.classValue() == cls) { double val = instance.value(attr); sum += val; sumSquares += val * val; count++; } } double mean = sum / count; double stddev = Math.sqrt((sumSquares - sum*sum/count) / count); gaussianParameters[cls][attr] = new GaussianParamters(mean, stddev); } } }分类时使用高斯概率密度函数:
public int classifyNumerical(Instance instance) { double maxLogProbability = Double.NEGATIVE_INFINITY; int bestClass = -1; for (int cls = 0; cls < numClasses; cls++) { double logProb = Math.log(classDistributionLaplacian[cls]); for (int attr = 0; attr < numConditions; attr++) { double x = instance.value(attr); double mu = gaussianParameters[cls][attr].mu; double sigma = gaussianParameters[cls][attr].sigma; // 高斯分布的概率密度对数 logProb += -Math.log(sigma) - (x-mu)*(x-mu)/(2*sigma*sigma); } if (logProb > maxLogProbability) { maxLogProbability = logProb; bestClass = cls; } } return bestClass; }6. 算法对比与应用场景
虽然kNN和朴素贝叶斯都可用于分类,但它们的适用场景有所不同:
| 特性 | kNN | 朴素贝叶斯 |
|---|---|---|
| 训练速度 | 快(惰性学习) | 快 |
| 预测速度 | 慢 | 快 |
| 内存需求 | 高(存储全部数据) | 低(存储参数) |
| 特征相关性 | 可处理相关特征 | 假设特征独立 |
| 数据规模 | 适合中小规模 | 适合大规模 |
| 特征类型 | 数值型表现好 | 数值型、类别型均可 |
在推荐系统场景中,基于用户的协同过滤(UserCF)类似于kNN,而基于物品的协同过滤(ItemCF)更接近朴素贝叶斯的思想。选择哪种算法取决于具体需求和数据特性。
7. 工程实践建议
在实际项目中应用这些算法时,有几个关键点需要注意:
数据预处理:
- 归一化:kNN对特征尺度敏感,需做归一化
- 缺失值处理:朴素贝叶斯需要处理缺失值
// 归一化示例 public void normalize() { for (int attr = 0; attr < numConditions; attr++) { double min = Double.MAX_VALUE; double max = Double.MIN_VALUE; // 找出最小最大值 // 归一化每个特征值 } }参数调优:
- kNN中的k值
- M-distance中的radius阈值
- 朴素贝叶斯的平滑参数
性能优化:
- kNN可以使用KD树等数据结构加速搜索
- 对于大规模数据,可以考虑近似算法
评估指标:
- 分类问题:准确率、精确率、召回率、F1值
- 推荐系统:MAE、RMSE、覆盖率、多样性
// 评估指标计算示例 public void evaluate() { int[][] confusionMatrix = new int[numClasses][numClasses]; for (int i = 0; i < numInstances; i++) { int actual = (int)dataset.instance(i).classValue(); int predicted = predicts[i]; confusionMatrix[actual][predicted]++; } // 计算各项指标... }8. 扩展与进阶
掌握了基础实现后,可以考虑以下扩展方向:
加权kNN:给更近的邻居更高权重
// 加权投票 public int weightedVoting(int[] neighbors) { double[] weights = new double[numClasses]; for (int i = 0; i < neighbors.length; i++) { double distance = distances[i]; double weight = 1.0 / (distance + 1e-5); // 避免除零 int cls = (int)dataset.instance(neighbors[i]).classValue(); weights[cls] += weight; } // 返回权重最大的类别 }核密度估计:改进朴素贝叶斯对数值型特征的建模
混合型数据:同时处理数值型和类别型特征
增量学习:支持在线更新模型
分布式实现:使用Spark等框架处理大数据
通过本项目的实践,你不仅学会了两种重要算法的实现,还掌握了如何将机器学习应用于不同场景的关键技能。建议尝试将这些算法应用到自己的项目中,或者参加Kaggle等数据科学竞赛来进一步磨练技能。
