别再死记硬背了!用Python手撸一个ID3决策树,从熵到分类器一次搞懂
别再死记硬背了!用Python手撸一个ID3决策树,从熵到分类器一次搞懂
决策树算法作为机器学习中最直观的模型之一,常被比作"机器学习界的Hello World"。但很多初学者在学习时容易陷入两个极端:要么被数学公式吓退,要么只会调用sklearn的DecisionTreeClassifier而不知其所以然。今天,我们将打破这种局面——用不到200行Python代码,从熵的计算开始,一步步构建完整的ID3决策树分类器。
1. 决策树与信息熵:从概念到代码
1.1 为什么需要信息熵
想象你在玩20个问题的游戏:每次提问都希望能最大程度地缩小答案范围。信息熵正是量化这种"不确定性"的数学工具。在决策树中,我们通过计算信息增益来决定每个节点的分裂特征,而这一切都建立在信息熵的基础上。
熵的计算公式看似简单:
H(X) = -Σ p(x)log₂p(x)但如何将其转化为可运行的Python代码?下面是一个直观的实现:
import numpy as np def calc_entropy(labels): """计算信息熵""" unique_labels, counts = np.unique(labels, return_counts=True) probabilities = counts / len(labels) return -np.sum(probabilities * np.log2(probabilities))这个函数的核心逻辑:
- 使用
np.unique统计每个类别出现的次数 - 计算每个类别的概率
- 应用熵公式求和
注意:在实际应用中,当概率为0时,log2(0)无定义,但NumPy会正确处理这种情况返回0
1.2 信息增益的计算
信息增益衡量的是特征对分类不确定性的减少程度。计算步骤分为:
- 计算原始数据集的信息熵(父节点熵)
- 按特征分割数据集后,计算加权平均熵(子节点熵)
- 两者相减得到信息增益
对应的Python实现:
def information_gain(data, labels, feature_idx): """计算指定特征的信息增益""" parent_entropy = calc_entropy(labels) # 获取该特征的所有唯一值 feature_values = np.unique(data[:, feature_idx]) # 计算加权子节点熵 child_entropy = 0 for value in feature_values: mask = data[:, feature_idx] == value subset_labels = labels[mask] weight = len(subset_labels) / len(labels) child_entropy += weight * calc_entropy(subset_labels) return parent_entropy - child_entropy2. 构建决策树的核心逻辑
2.1 递归构建树结构
决策树的构建本质上是一个递归过程,包含三个关键步骤:
- 选择最佳分裂特征:计算所有特征的信息增益,选择增益最大的
- 创建分支节点:根据选定特征的不同取值创建分支
- 递归处理子集:对每个分支对应的数据子集重复上述过程
实现这一逻辑的Python代码如下:
def build_tree(data, labels, feature_names): # 终止条件1:所有样本属于同一类别 if len(np.unique(labels)) == 1: return labels[0] # 终止条件2:没有更多特征可用于分裂 if data.shape[1] == 0: return np.bincount(labels).argmax() # 选择最佳分裂特征 best_feature = select_best_feature(data, labels) best_feature_name = feature_names[best_feature] # 创建树节点 tree = {best_feature_name: {}} # 获取该特征的所有唯一值 feature_values = np.unique(data[:, best_feature]) # 递归构建子树 for value in feature_values: mask = data[:, best_feature] == value subset_data = np.delete(data[mask], best_feature, axis=1) subset_labels = labels[mask] subset_feature_names = np.delete(feature_names, best_feature) tree[best_feature_name][value] = build_tree( subset_data, subset_labels, subset_feature_names) return tree2.2 处理边界情况
在实际编码中,我们需要处理几种特殊情况:
- 连续特征处理:ID3算法原本只处理离散特征,可通过二分法扩展
- 缺失值处理:可采用常见值填充或概率分配
- 过拟合预防:设置最大深度或最小样本数限制
以下是增强版的终止条件判断:
def should_stop(data, labels, max_depth, current_depth, min_samples): # 所有样本属于同一类别 if len(np.unique(labels)) == 1: return True # 达到最大深度限制 if current_depth >= max_depth: return True # 样本数小于最小限制 if len(labels) < min_samples: return True # 没有更多特征可用于分裂 if data.shape[1] == 0: return True return False3. 完整ID3决策树实现
3.1 决策树分类器类
将上述功能封装成一个完整的类,提高代码复用性:
class ID3DecisionTree: def __init__(self, max_depth=None, min_samples_split=2): self.max_depth = max_depth self.min_samples_split = min_samples_split self.tree = None self.feature_names = None def fit(self, data, labels, feature_names): self.feature_names = feature_names self.tree = self._build_tree( data, labels, feature_names, current_depth=0) def _build_tree(self, data, labels, feature_names, current_depth): # 终止条件判断 if self._should_stop(data, labels, current_depth): return self._make_leaf_node(labels) # 选择最佳分裂特征 best_idx = self._select_best_feature(data, labels) best_name = feature_names[best_idx] # 创建树节点 node = {best_name: {}} # 获取特征唯一值并递归构建子树 feature_values = np.unique(data[:, best_idx]) for value in feature_values: mask = data[:, best_idx] == value subset_data = np.delete(data[mask], best_idx, axis=1) subset_labels = labels[mask] subset_features = np.delete(feature_names, best_idx) node[best_name][value] = self._build_tree( subset_data, subset_labels, subset_features, current_depth+1) return node def predict(self, X): return np.array([self._traverse_tree(x, self.tree) for x in X]) def _traverse_tree(self, sample, node): if not isinstance(node, dict): return node feature_name = next(iter(node)) feature_idx = np.where(self.feature_names == feature_name)[0][0] feature_value = sample[feature_idx] if feature_value in node[feature_name]: return self._traverse_tree(sample, node[feature_name][feature_value]) else: # 处理未见过的特征值 return self._handle_unknown_value(node[feature_name])3.2 可视化决策树
理解决策树的最好方式就是可视化其结构。我们可以使用简单的文本缩进来展示:
def print_tree(node, indent=""): if not isinstance(node, dict): print(indent + "预测: " + str(node)) return feature_name = next(iter(node)) print(indent + feature_name) for value, subtree in node[feature_name].items(): print(indent + "├── " + str(value) + ":") print_tree(subtree, indent + "│ ")4. 实战:用自制决策树解决真实问题
4.1 准备示例数据集
让我们创建一个简单的贷款审批数据集:
# 特征:年龄(0:青年,1:中年,2:老年),有工作(0:否,1:是),有房子(0:否,1:是) data = np.array([ [0, 0, 0], [0, 0, 1], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 1], [2, 0, 0], [2, 0, 1], [2, 1, 1] ]) # 标签:0:拒绝贷款,1:批准贷款 labels = np.array([0, 1, 1, 0, 1, 1, 0, 1, 1]) feature_names = np.array(['年龄', '有工作', '有房子'])4.2 训练并测试模型
# 初始化并训练模型 tree = ID3DecisionTree(max_depth=3) tree.fit(data, labels, feature_names) # 打印树结构 print_tree(tree.tree) # 预测新样本 test_samples = np.array([ [0, 1, 0], # 青年,有工作,没房子 [1, 0, 1], # 中年,没工作,有房子 [2, 1, 0] # 老年,有工作,没房子 ]) predictions = tree.predict(test_samples) print("预测结果:", predictions)4.3 性能优化技巧
- 特征预排序:对连续特征提前排序,加速最优分割点查找
- 并行计算:在多核CPU上并行计算不同特征的信息增益
- 缓存中间结果:避免重复计算相同子集的信息熵
- 使用Cython加速:将计算密集型部分用Cython重写
优化后的信息增益计算示例:
from joblib import Parallel, delayed def parallel_information_gain(data, labels, feature_idx): # ...同前... def select_best_feature_parallel(data, labels, n_jobs=-1): n_features = data.shape[1] gains = Parallel(n_jobs=n_jobs)( delayed(parallel_information_gain)(data, labels, i) for i in range(n_features) ) return np.argmax(gains)5. 从ID3到C4.5:决策树算法的演进
虽然我们实现了基础的ID3算法,但现代决策树通常使用其改进版本C4.5。主要改进包括:
| 特性 | ID3算法 | C4.5算法 |
|---|---|---|
| 分裂标准 | 信息增益 | 信息增益比 |
| 处理连续特征 | 不支持 | 支持 |
| 缺失值处理 | 不支持 | 支持 |
| 剪枝策略 | 无 | 悲观剪枝 |
| 多叉树 | 是 | 是 |
实现信息增益比的Python代码:
def gain_ratio(data, labels, feature_idx): info_gain = information_gain(data, labels, feature_idx) # 计算分裂信息(Split Information) feature_values = data[:, feature_idx] _, counts = np.unique(feature_values, return_counts=True) probabilities = counts / len(feature_values) split_info = -np.sum(probabilities * np.log2(probabilities)) # 避免除以0 if split_info == 0: return 0 return info_gain / split_info决策树算法的魅力在于其直观性和可解释性。通过这次从零实现,我深刻体会到:在机器学习中,真正理解一个算法的最佳方式就是亲手实现它。当看到自己编写的树结构能够正确分类样本时,那种成就感远胜过调用现成库函数。
