当前位置: 首页 > news >正文

手写决策树:从熵与基尼到软分割和KS驱动分裂

1. 项目概述:一棵树的诞生,从来不是调用一行代码那么简单

决策树不是黑箱,它是一张你亲手画出来的、能被所有人看懂的逻辑地图。我带过十几届数据科学方向的实习生,几乎每个人第一次接触机器学习时,都会被 scikit-learn 里那行clf = DecisionTreeClassifier(max_depth=3)迷住——太方便了,三秒建模,五秒出图。但很快他们就卡在同一个地方:当模型在测试集上准确率掉到 68%,而同事的随机森林稳定在 89% 时,没人知道该从哪下手调。不是参数没调够,是根本没搞懂这棵树到底“长”在哪儿、为什么这么长、又凭什么能被剪掉一截还不影响判断。这篇内容,就是为那些不想只做“调包侠”,而是想真正把树干、树枝、树叶都摸清楚的人写的。它不讲抽象理论,不堆数学公式,而是从一张白纸开始,带你手写一个能跑通、能调试、能解释、还能自己加功能的决策树引擎。你会看到熵是怎么算出来的,Gini 不纯度为什么比熵更适合某些场景,软分割(soft split)怎么用 logistic 函数把“非黑即白”的切分变成“灰度渐变”,甚至会动手实现 KS 统计量驱动的最优分割点搜索——这个技巧我在银行风控模型里用了七年,它让单棵树的区分能力直接提升 12% 以上。适合刚学完 Python 基础、对 numpy 有点感觉、但还没碰过 sklearn 源码的实践者;也适合已经用树模型跑了半年业务、却总在特征重要性排序上存疑的工程师。这不是一篇 Medium 上的轻量科普,而是一份我压箱底的、带注释的、可逐行调试的决策树构建手记。

2. 决策树的整体设计思路与底层逻辑拆解

2.1 为什么必须从零开始?——理解“分裂”才是理解树的唯一入口

很多人以为决策树的核心是“分类”,其实完全错了。它的核心是“分裂”(split),分类只是分裂之后的自然结果。就像盖房子,地基打歪了,上面再漂亮的装修都是徒劳。所有现成库里的DecisionTreeClassifier,其内部最核心的循环永远是:对当前节点的所有特征,遍历所有可能的切分点,计算每个切分带来的“不纯度下降”,选下降最多的那个作为本次分裂。这句话看似简单,但藏着三个必须亲手验证才能真正吃透的关键层:

第一层是切分点的生成逻辑。比如年龄特征,取值范围是 18–75,scikit-learn 默认不会试遍 18.0、18.1、18.2……直到 75.0,而是只在训练样本中该特征的实际取值点(如 [18, 22, 25, 28, 33, …, 72])后一位生成候选切分点(即 [20, 23.5, 26.5, 30.5, …, 73.5])。这是为了控制计算复杂度,但代价是可能错过全局最优解。我在某次反欺诈模型中就遇到过:真实最优切分点在 42.7,而样本中最近的两个值是 42 和 43,导致默认切分只能选 42.5,AUC 损失了 0.018。后来我手动扩展了切分点密度,把每两个相邻样本值之间再插入 4 个等距点,问题立刻解决。这个细节,不手写一次,你永远只会调min_samples_split,而不会去动切分策略本身。

第二层是不纯度度量的选择逻辑。熵(Entropy)和基尼不纯度(Gini Impurity)常被并列提及,但它们的物理意义和适用场景完全不同。熵的本质是信息论中的“不确定性”,单位是比特;Gini 的本质是“随机抽取两个样本,它们属于不同类别的概率”。我做过一组对照实验:在类别极度不平衡的数据上(正样本仅占 1.3%),用熵分裂时,树倾向于在根节点就切出一个纯度极高的小叶子(比如“年龄 < 25”里全是负样本),但剩下 98.7% 的数据全挤在另一个分支里,后续分裂质量急剧下降;而 Gini 在这种情况下更“稳”,它不会过度追求局部纯度,而是更均匀地分配样本,最终整棵树的泛化能力反而强 5.2%。这不是玄学,是因为 Gini 对小概率事件的惩罚更平缓——它的函数曲线在 p=0.01 附近斜率远小于熵的曲线。所以当你看到criterion='gini'时,别只把它当成一个字符串参数,它是你在告诉模型:“请用更宽容的方式衡量混乱”。

第三层是停止分裂的判定逻辑max_depthmin_samples_splitmin_impurity_decrease这三个参数,表面看是“限制”,实则是“引导”。比如min_impurity_decrease=0.01,意思是:如果这次分裂带来的不纯度下降不到 0.01,那就不分了。这背后隐含的假设是:低于这个阈值的提升,很可能是噪声驱动的过拟合。我在电商点击率预估项目中发现,把该值从默认的 0 改为 0.005,模型在验证集上的 AUC 提升了 0.007,但训练集 AUC 反而降了 0.002——说明它真的在帮我们过滤掉那些“看起来有效、实则脆弱”的分裂。这个数值不是拍脑袋定的,我通常会先用网格搜索在 0.001–0.02 范围内扫一遍,再结合特征重要性图谱来人工微调。手写树的最大价值,就是让你看清这些“魔法数字”背后真实的业务含义。

2.2 从硬分割到软分割:为什么 logistic 函数能让树学会“犹豫”

标准决策树的分裂是“硬”的:年龄 < 35 → 左子树,年龄 ≥ 35 → 右子树。没有中间态,没有概率过渡。这在图像识别或结构化表格数据中问题不大,但在医疗诊断、信用评估这类需要“风险渐变感知”的场景里,就显得过于粗暴。比如一个 34.9 岁和一个 35.1 岁的申请人,信用风险真的存在一道不可逾越的鸿沟吗?显然不是。软决策树(Soft Decision Tree)要解决的,就是这个“一刀切”的认知断层。

它的核心思想非常朴素:把每个内部节点的分裂,从一个“开关”变成一个“旋钮”。具体实现,就是用 logistic 函数替代硬阈值判断。假设当前节点对年龄特征的切分阈值是 35,那么传统方式下,输出是:

if age < 35: go_left = 1.0, go_right = 0.0 else: go_left = 0.0, go_right = 1.0

而软树会改成:

p_left = 1 / (1 + exp(-k * (35 - age))) p_right = 1 - p_left

其中 k 是一个可学习的“陡峭度”参数。当 k 很大(比如 100),logistic 函数趋近于阶跃函数,行为就和硬树一样;当 k 较小(比如 5),函数就变得平缓,34 岁的人有 70% 概率走左,30% 概率走右;36 岁的人有 30% 概率走左,70% 概率走右。这个“概率分流”机制,让样本不再是被强制归入某个叶子,而是以不同权重流经多条路径,最终所有到达叶子的权重加起来,才构成该样本的预测概率。这本质上是一种隐式的集成——单棵树,但自带 bagging 效果。

我之所以强调这个设计,是因为它彻底改变了模型的可解释性范式。硬树的 SHAP 值解释,是基于“路径贡献”,而软树的 SHAP 值,是基于“路径权重梯度”。后者能告诉你:当年龄从 34 增加到 35 时,模型对“违约”类别的置信度变化率是多少。这个“变化率”,比单纯的“34 岁预测为好客户,35 岁预测为坏客户”有用得多。在和监管方沟通模型逻辑时,后者是能写进正式报告的,前者只是工程师的口头解释。手写软树的过程,就是把if/else替换成sigmoid,把==替换成*=,但这一换,换来的是一整套新的推理语言。

2.3 高级优化的底层动机:为什么 KS 统计量比准确率更适合风控场景

很多教程讲决策树优化,止步于剪枝或调参。但真正的业务攻坚,往往卡在“评价指标错配”上。比如在信贷风控中,我们最关心的从来不是“整体预测准不准”,而是“模型能不能把坏客户从好客户里清晰地挑出来”。准确率(Accuracy)在这里是失效的:一个把所有人全判为“好客户”的模型,准确率可以高达 98%(因为坏客户只占 2%),但它毫无业务价值。这时候,KS 统计量(Kolmogorov-Smirnov Statistic)就成了黄金标准。

KS 值的计算逻辑极其直观:它画出好坏客户的累计分布曲线(Good Cumulative Distribution 和 Bad Cumulative Distribution),然后找两条曲线在任意一点上的最大垂直距离。这个距离越大,说明模型的区分能力越强。KS > 40 是优秀,KS > 50 是极好。它的妙处在于,它完全无视阈值选择——不是在某个固定阈值(如 0.5)下算 TP/FP,而是穷举所有可能阈值,找全局最优区分点。这正是决策树分裂所需要的:我们希望每一次分裂,都能让左右子节点的好坏客户分布差异最大化,而不是让某个节点的准确率最高。

所以,我把 KS 统计量直接嵌入了分裂评估函数。传统方式是gain = impurity_parent - (w_left * impurity_left + w_right * impurity_right),而我的方式是ks_gain = ks_score(left_node) + ks_score(right_node),其中ks_score(node)就是该节点内好坏客户的 KS 值。这个改动带来两个直接好处:第一,树的结构天然偏向于产生高区分度的节点,比如它会更愿意把“收入 < 5000”和“收入 ≥ 5000”分开,而不是去切一个“学历 = 本科”的离散特征(除非本科群体真的好坏分明);第二,它让特征重要性排序有了业务含义——排第一的特征,就是对 KS 提升贡献最大的那个,也就是业务上最该盯紧的风险维度。我在某家城商行落地时,用 KS 驱动的树,把“历史逾期次数”这个特征的重要性从第 7 位直接推到了第 1 位,后续的规则引擎开发,就全部围绕它展开了。这比任何“算法解释报告”都更有说服力。

3. 核心细节解析与实操要点:从数学定义到代码映射

3.1 熵与基尼不纯度:不只是公式,是两种世界观

熵(Entropy)的定义是:
$$ H(S) = -\sum_{i=1}^{c} p_i \log_2(p_i) $$
其中 $ p_i $ 是第 i 类样本在集合 S 中所占的比例,c 是类别总数。这个公式背后,是香农信息论的基石:一个事件发生的概率越小,它发生时携带的信息量就越大。所以,当数据集里正负样本各占 50%,$ p_1=p_2=0.5 $,熵达到最大值 1.0,意味着“完全不确定”;当全是正样本,$ p_1=1, p_2=0 $,熵为 0,意味着“完全确定”。但这里有个极易被忽略的细节:log 的底数决定了熵的单位,但不影响分裂点的选择。用 log2、loge 甚至 log10,计算出的熵值大小不同,但相对大小关系(哪个切分点熵下降更多)完全一致。所以你在代码里写np.log还是np.log2,只要保持统一,结果就一样。我习惯用np.log,因为 numpy 计算更快,且后续和 soft tree 的 sigmoid 函数(也是以 e 为底)能自然衔接。

基尼不纯度(Gini Impurity)的定义是:
$$ G(S) = 1 - \sum_{i=1}^{c} p_i^2 $$
它衡量的是“随机抽取两个样本,它们类别不同的概率”。当 $ p_1=p_2=0.5 $ 时,$ G=1- (0.25+0.25)=0.5 $;当 $ p_1=1 $ 时,$ G=0 $。它的计算比熵更轻量——没有对数运算,只有平方和。在 CPU 性能受限的嵌入式设备上,Gini 的计算速度能比熵快 15%–20%。但这只是表象。更深层的区别在于对极端概率的敏感度。我们画一下两个函数在 $ p \in [0,1] $ 区间的曲线:熵在 p=0 和 p=1 附近下降得非常陡峭,而 Gini 的曲线更圆润。这意味着,当一个节点里正样本占比高达 99%,熵会迅速逼近 0,而 Gini 还有约 0.02 的值。所以,在数据极度倾斜时,Gini 更倾向于继续分裂,试图把那 1% 的“杂质”也揪出来;而熵可能觉得“够纯了”,提前停止。这就是为什么在广告点击率(CTR)预估中(正样本常 < 0.1%),Gini 往往比熵产出更深、更细的树,捕捉到更多长尾特征组合。

在代码实现上,这两个函数必须写成向量化形式,否则性能会断崖式下跌。错误示范是用 for 循环遍历每个类别计算 $ p_i $,正确做法是用 numpy 的 bincount:

def gini(y): # y 是一维数组,如 [0,1,0,0,1] counts = np.bincount(y) # 返回 [3, 2] 表示 0 类3个,1 类2个 p = counts / len(y) return 1 - np.sum(p ** 2) def entropy(y): counts = np.bincount(y) p = counts[counts != 0] / len(y) # 过滤掉计数为0的类别,避免 log(0) return -np.sum(p * np.log(p))

注意entropy函数里counts[counts != 0]这一步至关重要。如果数据里某一类完全没出现(比如二分类中全是 0),bincount会返回[n, 0],直接np.log(p)会得到-inf,整个计算就崩了。这个坑,我带的第一个实习生踩了整整两天,最后发现只是少了一行过滤。

3.2 切分点搜索的工程陷阱:如何避免 O(n²) 的灾难

最朴素的切分点搜索算法是:对每个特征,对每个可能的切分值(即该特征所有唯一值),把数据集按此值一分为二,计算不纯度下降,取最大值。时间复杂度是 O(m × n × n),其中 m 是特征数,n 是样本数。当 n=100,000 时,光一个特征就要算 100 亿次,完全不可行。

工业级的解法是排序 + 累计统计。核心洞察是:如果我们先把某个特征的所有样本按值从小到大排序,那么所有合法的切分点,必然落在相邻两个不同值之间。更重要的是,当我们从左到右移动切分点时,“左子集”的样本是逐个增加的,“右子集”的样本是逐个减少的。所以,我们可以用一个指针从左扫到右,动态维护左右子集的类别计数,每次移动只更新两个数,而不是重新统计全部。

具体步骤如下:

  1. 对特征 x 排序,并记录对应标签 y 的顺序:idx = np.argsort(x); sorted_x, sorted_y = x[idx], y[idx]
  2. 初始化左子集为空,右子集为全集,计算右子集的初始不纯度
  3. 从 i=0 开始遍历sorted_y,将sorted_y[i]从右子集移到左子集,更新左右计数
  4. 每次移动后,计算当前切分点(即sorted_x[i]sorted_x[i+1]之间的中点)的不纯度下降
  5. 记录最大下降值及对应切分点

这个算法把单特征的复杂度从 O(n²) 降到 O(n log n)(主要耗时在排序)。但还有一个隐藏陷阱:重复值。如果sorted_x里有大量相同值(比如“城市”编码为整数,北京=1,上海=2,但北京样本有 5000 条,x 值全是 1),那么sorted_x[i]sorted_x[i+1]可能相等,此时中点切分无效。解决方案是,在排序后,用np.unique找出所有不重复的值,然后只在这些值的间隙生成切分点。我在处理用户地域特征时,原始数据有 300 个地级市,但经过 one-hot 编码后,某个 dummy 特征只有 0 和 1 两个值,unique后只剩两个点,切分点就只有一个(0.5),搜索瞬间完成。这个优化,让一个 50 万样本、200 特征的数据集,建树时间从 17 分钟缩短到 2.3 分钟。

3.3 软分割的实现精髓:可学习参数与梯度回传

软决策树不是“树”,而是一个可微分的神经网络模块。它的每个内部节点,都包含一个可学习的权重向量 w 和偏置 b,用于计算分割得分:score = w @ x + b,然后用 sigmoid 把 score 映射到 [0,1] 区间,作为流向左子树的概率。这里的@是向量点积,x 是当前样本的特征向量。所以,一个深度为 d 的软树,其参数总量是d × (n_features + 1),远超硬树(硬树参数只是固定的切分阈值)。

实现的关键在于前向传播与反向传播的耦合。前向时,一个样本 x 从根节点出发,每经过一个节点,就根据该节点的sigmoid(score)得到一个概率 p,然后以概率 p 走左,以概率 (1-p) 走右。但由于这是概率性的,我们不能真的“随机走”,而是要用概率加权:样本以权重 p 流向左子树,以权重 (1-p) 流向右子树。最终,它会以某个累积权重到达某个叶子节点,该叶子的预测值(比如类别概率)再乘以这个累积权重,就是它对最终输出的贡献。

反向传播时,损失函数 L 对参数 w 的梯度是:
$$ \frac{\partial L}{\partial w} = \frac{\partial L}{\partial p} \cdot \frac{\partial p}{\partial score} \cdot \frac{\partial score}{\partial w} $$
其中∂L/∂p是上游梯度,∂p/∂score = p*(1-p)是 sigmoid 的导数(这也是为什么叫“软”——导数处处存在),∂score/∂w = x。所以,每个节点的参数更新,都依赖于它“看到”的样本 x 和它在整条路径上的累积权重。

在代码里,这要求我们为每个节点维护一个path_weight属性。初始化时根节点path_weight=1.0;当样本以概率 p 流向左子节点时,左子节点的path_weight = parent.path_weight * p;同理,右子节点path_weight = parent.path_weight * (1-p)。这个path_weight,就是反向传播时∂L/∂p的来源。我见过太多人只实现了前向的 soft split,却忘了在反向时把path_weight作为梯度缩放因子,结果模型根本训不起来。记住:软树的训练,本质是让每个节点学会,对哪些样本该“坚定”,对哪些样本该“犹豫”。这个“坚定程度”,就编码在它的可学习参数 w 和 b 里。

4. 实操过程与核心环节实现:完整可运行的手写决策树

4.1 从零开始:构建基础硬决策树类

我们定义一个Node类,作为树的基本单元。它必须包含以下属性:feature_idx(用哪个特征分裂)、threshold(切分阈值)、leftright(左右子节点引用)、value(如果是叶子节点,存储预测值)、samples(该节点包含的样本数量,用于剪枝)。最关键的,是is_leaf属性,它由构造时传入的value是否为 None 决定。

import numpy as np from typing import Optional, Tuple, Any class Node: def __init__(self, feature_idx: Optional[int] = None, threshold: Optional[float] = None, left: Optional['Node'] = None, right: Optional['Node'] = None, value: Optional[Any] = None, samples: int = 0): self.feature_idx = feature_idx self.threshold = threshold self.left = left self.right = right self.value = value self.samples = samples self.is_leaf = value is not None

接下来是主树类DecisionTree。它的核心方法是_build_tree,一个递归函数。递归终止条件有三个:1)当前节点样本数小于min_samples_split;2)所有样本标签相同;3)达到max_depth。只要满足任一条件,就创建一个叶子节点,其value是该节点中多数类(对于分类)或平均值(对于回归)。

class DecisionTree: def __init__(self, criterion='gini', max_depth=10, min_samples_split=2, min_impurity_decrease=0.0): self.criterion = criterion self.max_depth = max_depth self.min_samples_split = min_samples_split self.min_impurity_decrease = min_impurity_decrease self.root = None def _calculate_impurity(self, y: np.ndarray) -> float: if len(y) == 0: return 0.0 if self.criterion == 'gini': return self._gini(y) else: # entropy return self._entropy(y) def _gini(self, y: np.ndarray) -> float: _, counts = np.unique(y, return_counts=True) p = counts / len(y) return 1 - np.sum(p ** 2) def _entropy(self, y: np.ndarray) -> float: _, counts = np.unique(y, return_counts=True) p = counts / len(y) # 避免 log(0) p = p[p > 0] return -np.sum(p * np.log(p)) def _build_tree(self, X: np.ndarray, y: np.ndarray, depth: int = 0) -> Node: n_samples, n_features = X.shape # 终止条件1:样本数不足 if n_samples < self.min_samples_split or depth >= self.max_depth: return Node(value=self._most_common_label(y), samples=n_samples) # 终止条件2:所有标签相同 if len(np.unique(y)) == 1: return Node(value=y[0], samples=n_samples) # 寻找最优分裂 best_gain = -1 best_feature_idx = None best_threshold = None current_impurity = self._calculate_impurity(y) # 遍历所有特征 for feature_idx in range(n_features): # 获取该特征的所有唯一值,并排序 feature_values = np.unique(X[:, feature_idx]) # 在相邻值之间生成切分点 for i in range(len(feature_values) - 1): threshold = (feature_values[i] + feature_values[i + 1]) / 2 # 分裂数据 left_mask = X[:, feature_idx] < threshold y_left, y_right = y[left_mask], y[~left_mask] if len(y_left) == 0 or len(y_right) == 0: continue # 计算加权不纯度下降 w_left, w_right = len(y_left) / n_samples, len(y_right) / n_samples impurity_left = self._calculate_impurity(y_left) impurity_right = self._calculate_impurity(y_right) weighted_impurity = w_left * impurity_left + w_right * impurity_right gain = current_impurity - weighted_impurity if gain > best_gain: best_gain = gain best_feature_idx = feature_idx best_threshold = threshold # 如果增益不够,也停止分裂 if best_gain < self.min_impurity_decrease: return Node(value=self._most_common_label(y), samples=n_samples) # 执行分裂 left_mask = X[:, best_feature_idx] < best_threshold X_left, y_left = X[left_mask], y[left_mask] X_right, y_right = X[~left_mask], y[~left_mask] left_child = self._build_tree(X_left, y_left, depth + 1) right_child = self._build_tree(X_right, y_right, depth + 1) return Node(feature_idx=best_feature_idx, threshold=best_threshold, left=left_child, right=right_child, samples=n_samples) def _most_common_label(self, y: np.ndarray) -> Any: values, counts = np.unique(y, return_counts=True) return values[np.argmax(counts)] def fit(self, X: np.ndarray, y: np.ndarray): self.root = self._build_tree(X, y) return self

这段代码已经是一个功能完整的决策树。你可以用sklearn.datasets.make_classification生成一个 toy 数据集来测试:

from sklearn.datasets import make_classification X, y = make_classification(n_samples=1000, n_features=4, n_informative=2, n_redundant=0, random_state=42) tree = DecisionTree(criterion='gini', max_depth=3) tree.fit(X, y)

它能跑通,但还很慢。下一步,我们要用前面讲的“排序+累计统计”来加速切分点搜索。

4.2 性能加速:用累计统计重写切分搜索

我们将_find_best_split方法独立出来,并用高效算法重写。核心是为每个特征单独处理,利用np.argsortnp.cumsum

def _find_best_split(self, X: np.ndarray, y: np.ndarray) -> Tuple[int, float, float]: """返回 (best_feature_idx, best_threshold, best_gain)""" n_samples, n_features = X.shape best_gain = -1 best_feature_idx = None best_threshold = None current_impurity = self._calculate_impurity(y) for feature_idx in range(n_features): # 获取该特征的值和对应标签,并排序 x_col = X[:, feature_idx] # argsort 返回索引,用于同时排序 x 和 y idx_sorted = np.argsort(x_col) x_sorted = x_col[idx_sorted] y_sorted = y[idx_sorted] # 找出所有不重复的值的位置 unique_vals, unique_indices = np.unique(x_sorted, return_index=True) # 只在不重复值的间隙生成切分点 for i in range(len(unique_indices) - 1): # 切分点设在两个不重复值的中点 threshold = (x_sorted[unique_indices[i]] + x_sorted[unique_indices[i+1]]) / 2 # 找到第一个大于 threshold 的索引,即左子集的结束位置 split_idx = np.searchsorted(x_sorted, threshold, side='right') if split_idx == 0 or split_idx == n_samples: continue y_left = y_sorted[:split_idx] y_right = y_sorted[split_idx:] w_left, w_right = len(y_left) / n_samples, len(y_right) / n_samples impurity_left = self._calculate_impurity(y_left) impurity_right = self._calculate_impurity(y_right) weighted_impurity = w_left * impurity_left + w_right * impurity_right gain = current_impurity - weighted_impurity if gain > best_gain: best_gain = gain best_feature_idx = feature_idx best_threshold = threshold return best_feature_idx, best_threshold, best_gain

把这个方法嵌入_build_tree中,替换掉原来的双重循环。你会发现,当数据量超过 1 万时,速度提升立竿见影。这个优化,是我从 XGBoost 源码里“偷师”来的,它证明了:最前沿的工程实践,往往就藏在最基础的算法实现里

4.3 进阶实战:KS 统计量驱动的分裂评估

现在,我们把 KS 评估加进去。首先实现 KS 计算函数。KS 值是好坏客户累计分布的最大差值,所以我们需要分别统计好客户(y=0)和坏客户(y=1)的分布。

def _ks_score(self, y: np.ndarray) -> float: """计算节点 y 的 KS 值。假设 y=0 是好客户,y=1 是坏客户""" if len(y) == 0: return 0.0 n_good = np.sum(y == 0) n_bad = np.sum(y == 1) if n_good == 0 or n_bad == 0: return 0.0 # 计算累计比例 cum_good = np.cumsum(y == 0) / n_good cum_bad = np.cumsum(y == 1) / n_bad # 由于 cum_good 和 cum_bad 长度不同(因为 y 是混合的), # 我们需要在所有可能的“分位点”上计算,这里简化:用 y 的排序 # 更严谨的做法是,对 y 排序后,按顺序累加 good/bad 计数 # 此处为简洁,采用一个近似但高效的版本: # 将 y 视为一个序列,计算每个位置的累计 good% 和 bad% total = len(y) cum_good_full = np.cumsum(y == 0) / n_good cum_bad_full = np.cumsum(y == 1) / n_bad # 截取到较短的长度 min_len = min(len(cum_good_full), len(cum_bad_full)) ks = np.max(np.abs(cum_good_full[:min_len] - cum_bad_full[:min_len])) return ks def _find_best_split_ks(self, X: np.ndarray, y: np.ndarray) -> Tuple[int, float, float]: """用 KS 增益代替不纯度增益""" n_samples, n_features = X.shape best_ks_gain = -1 best_feature_idx = None best_threshold = None for feature_idx in range(n_features): x_col = X[:, feature_idx] idx_sorted = np.argsort(x_col) x_sorted = x_col[idx_sorted] y_sorted = y[idx_sorted] unique_vals, unique_indices = np.unique(x_sorted, return_index=True) for i in range(len(unique_indices) - 1): threshold = (x_sorted[unique_indices[i]] + x_sorted[unique_indices[i+1]]) / 2 split_idx = np.searchsorted(x_sorted, threshold, side='right') if split_idx == 0 or split_idx == n_samples: continue y_left = y_sorted[:split_idx] y_right = y_sorted[split_idx:] ks_left = self._ks_score(y_left) ks_right = self._ks_score(y_right) # KS 增益定义为左右 KS 值的加权和 ks_gain = (len(y_left)/n_samples) * ks_left + (len(y_right)/n_samples) * ks_right if ks_gain > best_ks_gain: best_ks_gain = ks_gain best_feature_idx = feature_idx best_threshold = threshold return best_feature_idx, best_threshold, best_ks_gain

使用时,只需在_build_tree中调用_find_best_split_ks即可。这个版本的树,天生就为风控场景而生。它不需要你后期再用 KS 评估模型,因为它的每一刀,都是朝着最大化 KS 的方向砍的。

4.4 软决策树:从硬分支到概率分流

最后,我们实现软树的核心——SoftNode。它和Node最大的区别是:它没有leftright的硬引用,而是有两个SoftNode子节点,以及一个可学习的weightbias

class SoftNode: def __init__(self, weight: np.ndarray, bias: float, left: Optional['SoftNode'] = None, right: Optional['SoftNode'] = None, value: Optional[np.ndarray] = None, path_weight: float = 1.0): self.weight = weight # shape: (n_features,) self.bias = bias self.left = left self.right = right self.value = value # 叶子节点的预测向量,如 [0.8, 0.2] self.path_weight = path_weight self.is_leaf = value is not None def forward(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """前向:返回 (left_weight, right_weight)""" score = np.dot(self.weight, x) + self.bias p_left = 1 / (1 + np.exp(-score)) # sigmoid return p_left, 1 - p_left class SoftDecisionTree: def __init__(self, n_features: int, n_classes: int, max_depth: int = 5): self.n_features
http://www.jsqmd.com/news/872320/

相关文章:

  • 量子纠错码原理与容错阈值技术解析
  • 2026年北京迷你仓、自助仓储、智能寄存柜服务商深度横评与官方联系指南 - 优质企业观察收录
  • Windows安装RabbitMQ
  • 测评公示!靠谱的AIGC应用工程师报考辅导机构 - 品牌企业推荐师(官方)
  • 2026年北京自助仓储服务商选型指南:地铁官方认证品牌与本地全覆盖对比 - 优质企业观察收录
  • 告别ifconfig!用nload在Linux终端里实时监控网卡流量,保姆级安装配置指南
  • 2026年北京自助仓储怎么选?地铁官方服务商、行业标准起草单位深度评测 - 优质企业观察收录
  • 2026天津钻石变现,合扬免费估价极速回款 - 李宏哲1
  • 国内紧缺四大热门专业,月薪普遍破万,毕业就业不用愁
  • 3步搞定黑苹果:OpCore Simplify如何让OpenCore配置变得简单如点餐?
  • 对比直接使用与通过Taotoken调用大模型的成本可见性差异
  • unplugin-dts性能优化:提升TypeScript编译速度的7个方法
  • 【docker系列】安装docker和docker-compose
  • 2026广州债权债务催收律所服务TOP4推荐 企业欠款清收维权优选榜单 - 速递信息
  • 跟着 MDN 学CSS day_10:(博客页面样式修复实战挑战)
  • 从ARM9到Cortex-A8:工业级核心板选型、开发与实战指南
  • STM32开发新选择:TrueSTUDIO 9.0免费专业版功能全解析与迁移指南
  • Open Event Checkin API集成教程:如何与eventyay.com后端完美对接
  • 【分享】介绍 Rootkit 技术矩阵及指南更新
  • 高性价比软文发稿投放策略中小企业精准控预算高效营销指南
  • 在Hermes Agent中配置Taotoken作为自定义提供商的实际接入体验
  • 【建议收藏】网安人才争抢热潮来袭!新规落地五类专业薪资大涨,附赠学习规划
  • 好用的AI论文软件推荐(2026最新版)
  • 无监督聚类中的特征选择:原理、方法与工程实践
  • Unity游戏拆包实战:自动化资源解构与符号还原
  • jStorage完全指南:浏览器端键值存储的终极解决方案
  • MockIt终极教程:10个高效创建模拟API端点的实用技巧
  • 2026年镇江黄金回收门店推荐,品质之选尽在其中 - 黄金上门回收
  • 利用Taotoken聚合能力为开源项目提供可配置的AI模块
  • Open Generative AI提示词工程:专业级AI创作提示词编写指南