K-Means案例实际讲解,适合大学生突击期末
太棒了!既然你兴致这么高,我们就把 sklearn 的拐杖扔掉,用刚才学过的 NumPy 核心知识(广播、向量化、布尔索引),手搓一个简易版的 K-Means 算法!
你会发现,剥去框架的外衣,K-Means 的核心逻辑其实只有 3 步,而且用 NumPy 写出来极其优雅。
K-Means 的核心思想(一句话):
“找几个中心点,把离谁近的点归给谁,然后重新算中心点,直到中心点不再移动。”
完整代码实现(带详细注释)
import numpy as np
class SimpleKMeans:
definit(self, n_clusters=3, max_iters=100):
self.n_clusters = n_clusters
self.max_iters = max_iters
def fit(self, X): # X 的形状是 (样本数, 特征数),比如鸢尾花是 (150, 2) n_samples, n_features = X.shape # 【初始化】:随机挑选 n_clusters 个点作为初始中心 # 这里用整数列表索引(花式索引)随机抽取 random_indices = np.random.choice(n_samples, self.n_clusters, replace=False) self.centroids = X[random_indices, :] # 【迭代开始】 for _ in range(self.max_iters): # === 第 1 步:计算距离并分配簇 === # 利用广播计算每个点到每个中心的距离 # X 形状: (150, 2), centroids 形状: (3, 2) # 我们想让 X 减去每一个中心,得到 (150, 3, 2) 的三维数组 # 技巧:把 centroids 变成 (1, 3, 2),X 变成 (150, 1, 2) distances = np.sqrt(((X[:, np.newaxis, :] - self.centroids[np.newaxis, :, :]) ** 2).sum(axis=2)) # 找出每个点距离最近的那个中心的索引(形状变为 150) labels = np.argmin(distances, axis=1) # === 第 2 步:重新计算中心点 === new_centroids = np.zeros((self.n_clusters, n_features)) for k in range(self.n_clusters): # 【布尔索引】:把属于第 k 类的点全部挑出来,求平均 cluster_points = X[labels == k] # 如果某个簇没有分到点,保持原中心不变 if len(cluster_points) > 0: new_centroids[k] = cluster_points.mean(axis=0) # === 第 3 步:判断是否收敛 === # 如果新旧中心点的距离小于极小值,说明不再移动,提前结束! if np.allclose(self.centroids, new_centroids): print(f"算法在第 {_} 次迭代时收敛!") break self.centroids = new_centroids return labels, self.centroids代码里的“高光时刻”解析(重点看!)
- 距离计算的“降维打击”
distances = np.sqrt(((X[:, np.newaxis, :] - self.centroids[np.newaxis, :, :]) ** 2).sum(axis=2))
这一行是整个算法最核心的向量化操作!
- X[:, np.newaxis, :] 把形状从 (150, 2) 变成了 (150, 1, 2)。
- centroids[np.newaxis, :, :] 把形状从 (3, 2) 变成了 (1, 3, 2)。
- 两者相减时,NumPy 触发终极广播,直接生成一个 (150, 3, 2) 的立体矩阵。
- .sum(axis=2) 沿着特征维度求和,得到 (150, 3) 的距离矩阵。
- np.argmin(axis=1) 直接找出每一行(每个样本)距离最小的中心点编号。
没有写一行 for 循环,150 个点的距离瞬间算完!
- 更新中心点的“布尔索引”
cluster_points = X[labels == k]
new_centroids[k] = cluster_points.mean(axis=0)
- labels == k 生成一个布尔掩码(Mask)。
- 直接通过掩码把属于第 k 类的样本捞出来。
- mean(axis=0) 沿着行(样本)的方向求平均,瞬间算出新的中心点。
怎么跑起来?
造一点假数据测试一下
np.random.seed(42)
生成 3 个簇的假数据
cluster_1 = np.random.randn(50, 2) + [2, 2]
cluster_2 = np.random.randn(50, 2) + [-2, -2]
cluster_3 = np.random.randn(50, 2) + [2, -2]
fake_data = np.vstack([cluster_1, cluster_2, cluster_3])
运行我们的手搓算法
kmeans = SimpleKMeans(n_clusters=3)
labels, centroids = kmeans.fit(fake_data)
print(“找到的中心点:n”, centroids)
学习小结
到这里,你不仅学会了:
- K-Means 的底层原理(分配 -> 更新 -> 收敛)。
- NumPy 的进阶玩法(np.newaxis 制造维度、三维数组广播、布尔索引)。
你现在已经具备了阅读和理解大部分传统机器学习算法源码的能力了!
感觉怎么样?是不是有种打通任督二脉的爽快感?接下来你是想看看怎么把这个结果画成漂亮的散点图,还是想了解一下K-Means 算法的致命缺陷(比如怎么确定 K 值)?
