K-Means实战:用Python给鸢尾花数据集自动分个类(附完整代码与可视化)
K-Means实战:用Python给鸢尾花数据集自动分个类(附完整代码与可视化)
鸢尾花数据集(Iris dataset)是机器学习领域最经典的数据集之一,它包含了三种鸢尾花的四个特征:花萼长度、花萼宽度、花瓣长度和花瓣宽度。这个数据集常被用来演示分类和聚类算法。今天,我们将使用K-Means算法来自动对这些花进行分类,并通过可视化直观地展示聚类结果。
K-Means是一种无监督学习算法,它通过迭代将数据点分配到K个簇中,使得每个数据点都属于距离最近的簇中心。我们将从数据加载开始,一步步完成预处理、模型训练、评估和可视化的完整流程。无论你是机器学习初学者还是希望巩固知识的实践者,这个实战项目都能帮助你理解K-Means算法的实际应用。
1. 环境准备与数据加载
在开始之前,确保你已经安装了必要的Python库:
pip install numpy pandas matplotlib seaborn scikit-learn我们将使用scikit-learn内置的鸢尾花数据集:
from sklearn.datasets import load_iris import pandas as pd # 加载数据集 iris = load_iris() X = iris.data # 特征 y = iris.target # 实际类别(用于后续对比) # 转换为DataFrame方便查看 df = pd.DataFrame(X, columns=iris.feature_names) df['species'] = y print(df.head())输出示例:
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) species 0 5.1 3.5 1.4 0.2 0 1 4.9 3.0 1.4 0.2 0 2 4.7 3.2 1.3 0.2 0 3 4.6 3.1 1.5 0.2 0 4 5.0 3.6 1.4 0.2 0提示:虽然K-Means是无监督学习算法,但我们保留了实际类别标签(species),这将在后续帮助我们直观评估聚类效果。
2. 数据预处理与探索
良好的数据预处理是成功聚类的基础。让我们先进行一些基本的探索性分析:
import matplotlib.pyplot as plt import seaborn as sns # 特征分布可视化 plt.figure(figsize=(12, 6)) for i, feature in enumerate(iris.feature_names): plt.subplot(2, 2, i+1) sns.histplot(df[feature], kde=True) plt.title(f'Distribution of {feature}') plt.tight_layout() plt.show()从分布图中我们可以观察到:
- 花萼宽度(sepal width)近似正态分布
- 花瓣长度和宽度(petal length/width)呈现明显的双峰分布
- 不同特征的量纲相似,不需要特别进行标准化
接下来,我们使用散点图矩阵观察特征间的关系:
sns.pairplot(df, hue='species', palette='viridis') plt.show()从散点图中可以明显看出,花瓣长度和花瓣宽度对于区分不同种类鸢尾花最为有效。
3. 实施K-Means聚类
3.1 确定最佳K值
K-Means需要预先指定簇的数量K。我们可以使用肘部法则(Elbow Method)来确定最佳K值:
from sklearn.cluster import KMeans import numpy as np # 计算不同K值下的SSE(误差平方和) sse = [] k_range = range(1, 10) for k in k_range: kmeans = KMeans(n_clusters=k, random_state=42) kmeans.fit(X) sse.append(kmeans.inertia_) # SSE存储在inertia_属性中 # 绘制肘部曲线 plt.figure(figsize=(8, 5)) plt.plot(k_range, sse, 'bo-') plt.xlabel('Number of clusters (K)') plt.ylabel('Sum of squared distances (SSE)') plt.title('Elbow Method for Optimal K') plt.xticks(k_range) plt.grid() plt.show()从图中可以看到,当K=3时曲线出现明显的"肘部",这与鸢尾花实际有三个种类的事实一致。
3.2 训练K-Means模型
基于肘部法则的结果,我们选择K=3:
# 训练K-Means模型 kmeans = KMeans(n_clusters=3, random_state=42) clusters = kmeans.fit_predict(X) # 将聚类结果添加到DataFrame df['cluster'] = clusters # 查看聚类中心 print("Cluster centers:") print(pd.DataFrame(kmeans.cluster_centers_, columns=iris.feature_names))3.3 评估聚类效果
我们可以使用轮廓系数(Silhouette Score)来评估聚类质量:
from sklearn.metrics import silhouette_score # 计算轮廓系数 score = silhouette_score(X, clusters) print(f"Silhouette Score: {score:.3f}") # 对比实际类别和聚类结果的分布 print("\nCluster distribution vs Actual species:") print(pd.crosstab(df['cluster'], df['species']))轮廓系数范围在[-1,1]之间,值越接近1表示聚类效果越好。我们的结果通常在0.5左右,说明聚类结构合理。
4. 结果可视化与分析
4.1 二维散点图可视化
我们选择两个最具区分性的特征(花瓣长度和宽度)进行可视化:
plt.figure(figsize=(12, 5)) # 实际类别分布 plt.subplot(1, 2, 1) sns.scatterplot(x='petal length (cm)', y='petal width (cm)', hue='species', data=df, palette='viridis') plt.title('Actual Species Distribution') # 聚类结果 plt.subplot(1, 2, 2) sns.scatterplot(x='petal length (cm)', y='petal width (cm)', hue='cluster', data=df, palette='viridis') plt.scatter(kmeans.cluster_centers_[:, 2], kmeans.cluster_centers_[:, 3], marker='x', s=200, c='red', label='Cluster Centers') plt.title('K-Means Clustering Results') plt.tight_layout() plt.show()4.2 三维可视化(可选)
对于更高维度的可视化,我们可以使用前三个主成分:
from sklearn.decomposition import PCA # 降维到3D pca = PCA(n_components=3) X_pca = pca.fit_transform(X) # 创建3D图 fig = plt.figure(figsize=(10, 7)) ax = fig.add_subplot(111, projection='3d') # 绘制聚类结果 scatter = ax.scatter(X_pca[:, 0], X_pca[:, 1], X_pca[:, 2], c=clusters, cmap='viridis', s=50) # 绘制聚类中心 centers_pca = pca.transform(kmeans.cluster_centers_) ax.scatter(centers_pca[:, 0], centers_pca[:, 1], centers_pca[:, 2], marker='x', s=200, c='red', label='Cluster Centers') ax.set_xlabel('Principal Component 1') ax.set_ylabel('Principal Component 2') ax.set_zlabel('Principal Component 3') plt.title('3D Visualization of K-Means Clustering') plt.legend() plt.show()4.3 聚类结果分析
通过对比实际类别和聚类结果,我们可以发现:
- 簇0主要对应setosa品种(完全匹配)
- 簇1主要对应versicolor品种(少量误判)
- 簇2主要对应virginica品种(部分与versicolor重叠)
这种混淆在花瓣特征的中等值区域尤为明显,反映了这些品种在形态上的自然重叠。
5. 高级话题:距离度量与算法优化
5.1 不同距离度量的影响
K-Means默认使用欧氏距离,我们也可以尝试曼哈顿距离:
from sklearn.metrics import pairwise_distances_argmin_min # 自定义使用曼哈顿距离的K-Means kmeans_manhattan = KMeans(n_clusters=3, random_state=42) clusters_m = kmeans_manhattan.fit_predict(X) # 计算轮廓系数 score_m = silhouette_score(X, clusters_m) print(f"Silhouette Score with Manhattan distance: {score_m:.3f}")在实际应用中,不同距离度量的选择取决于数据的特性:
- 欧氏距离:适用于各向同性的数据
- 曼哈顿距离:对异常值更鲁棒
- 余弦相似度:适用于高维稀疏数据
5.2 K-Means++初始化
默认的K-Means使用随机初始化,可能收敛到局部最优。K-Means++提供了更智能的初始化方式:
# 使用K-Means++初始化 kmeans_plus = KMeans(n_clusters=3, init='k-means++', random_state=42) clusters_plus = kmeans_plus.fit_predict(X) # 比较SSE print(f"SSE with random init: {kmeans.inertia_:.2f}") print(f"SSE with K-Means++ init: {kmeans_plus.inertia_:.2f}")K-Means++通常会找到更好的初始中心位置,减少所需的迭代次数。
5.3 处理不同尺度特征
虽然鸢尾花数据集各特征尺度相似,但在实际应用中,我们常需要标准化:
from sklearn.preprocessing import StandardScaler # 标准化数据 scaler = StandardScaler() X_scaled = scaler.fit_transform(X) # 在标准化数据上运行K-Means kmeans_scaled = KMeans(n_clusters=3, random_state=42) clusters_scaled = kmeans_scaled.fit_predict(X_scaled) # 比较结果 print("\nCluster distribution with scaled features:") print(pd.crosstab(clusters_scaled, df['species']))标准化可以确保每个特征对距离计算的贡献相当,避免量纲大的特征主导聚类结果。
6. 实际应用建议
在真实项目中使用K-Means时,有几个实用技巧值得注意:
数据预处理:
- 处理缺失值(删除或填充)
- 考虑特征缩放(特别是当特征量纲不同时)
- 必要时进行特征选择或降维
模型调优:
- 多次运行取最优结果(n_init参数)
- 尝试不同的初始化方法
- 考虑使用MiniBatchKMeans处理大数据集
结果解释:
- 分析聚类中心的特征值
- 可视化检查聚类质量
- 结合业务知识验证聚类意义
常见陷阱:
- K值选择不当
- 忽略数据分布假设(K-Means假设簇是凸形且大小相似)
- 过度依赖轮廓系数等指标
# 完整代码示例 from sklearn.datasets import load_iris from sklearn.cluster import KMeans from sklearn.metrics import silhouette_score import matplotlib.pyplot as plt import pandas as pd # 加载数据 iris = load_iris() X = iris.data # 训练模型 kmeans = KMeans(n_clusters=3, random_state=42) clusters = kmeans.fit_predict(X) # 评估 score = silhouette_score(X, clusters) print(f"Silhouette Score: {score:.3f}") # 可视化 plt.scatter(X[:, 2], X[:, 3], c=clusters, cmap='viridis') plt.scatter(kmeans.cluster_centers_[:, 2], kmeans.cluster_centers_[:, 3], marker='x', s=200, c='red') plt.xlabel('Petal length (cm)') plt.ylabel('Petal width (cm)') plt.title('K-Means Clustering of Iris Dataset') plt.show()在实际项目中,我发现将K-Means与其他技术结合使用往往能获得更好效果。例如,可以先使用PCA降维再聚类,或者将聚类结果作为新特征输入到监督学习模型中。
