当前位置: 首页 > news >正文

`train_test_split` 是什么?

一、函数基础:train_test_split 是什么?

train_test_split 是 sklearn 提供的数据集划分工具,核心功能是:

  • 随机打乱原始数据(避免数据有序性导致的偏差);
  • 按指定比例拆分数据为「两部分」(默认是训练集和测试集);
  • 保证拆分后,xy 的对应关系不混乱(即某个样本的特征和标签不会被分到不同集合)。

语法格式(简化版):

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

返回值是 4个变量(固定顺序):

  1. X_train:拆分后的「训练集特征」;
  2. X_test:拆分后的「测试集特征」;
  3. y_train:拆分后的「训练集标签」(与 X_train 一一对应);
  4. y_test:拆分后的「测试集标签」(与 X_test 一一对应)。

二、逐参数解释(结合你的代码)

train_test_split(x, y, test_size=0.2, random_state=42)

每个参数的作用和细节如下:

参数名 代码中的值 核心作用 关键细节
第一个位置参数 x 输入的「特征数据」(如之前的100个 x 值,形状 (100,1)) 必须是2D数组(sklearn 要求),每行是一个样本,每列是一个特征
第二个位置参数 y 输入的「标签数据」(如之前的带噪声 y 值,形状 (100,1)) x 的样本数量必须一致(100个样本对应100个标签),1D/2D数组都可
test_size 0.2 测试集占「原始数据总比例」(默认0.25,即25%) 这里 0.2 表示:测试集占20%,剩余80%是「训练集+验证集」(后续再拆分)
random_state 42 控制数据划分的「随机种子」(和之前 np.random.seed(42) 作用一致) 固定后,每次运行划分结果完全相同(保证可复现);不设则每次划分都不同

三、关键细节:划分逻辑(为什么这么拆?)

结合之前的多项式回归流程,这行代码的划分逻辑是「两步拆分法」的第一步:

1. 第一步:先拆「总数据」为「80%训练+验证集」和「20%测试集」

  • 原始数据:x(100个样本)、y(100个样本);
  • test_size=0.2:从100个样本中随机选20个作为「测试集」(x_test, y_test);
  • 剩余80个样本作为「训练集+验证集」(x_train_val, y_train_val);
  • random_state=42:固定随机选择的结果(比如第3、15、27个样本是测试集,每次运行都一样)。

2. 第二步:再拆「80%训练+验证集」为「60%训练集」和「20%验证集」

这就是之前代码中紧接着的第二行划分:

x_train, x_val, y_train, y_val = train_test_split(x_train_val, y_train_val, test_size=0.25, random_state=42
)
  • 这里 test_size=0.25 是「相对于 x_train_val(80个样本)的比例」;
  • 80 * 0.25 = 20 个样本作为「验证集」,剩余 60 个作为「训练集」;
  • 最终总比例:训练集60%、验证集20%、测试集20%(符合标准划分)。

四、参数拓展:其他常用设置(可选)

除了代码中的参数,train_test_split 还有两个实用参数,适合不同场景:

1. train_size:直接指定训练集比例(与 test_size 互斥)

比如想让训练集占70%、测试集30%,可以写:

x_train_val, x_test, y_train_val, y_test = train_test_split(x, y, train_size=0.7, random_state=42  # 无需再写 test_size
)

2. stratify:分层抽样(分类任务必用,回归任务不用)

如果是 分类问题(如预测「是否患病」),为了保证训练集和测试集的「类别比例一致」(比如原始数据中患病占30%,测试集也应占30%),需要设置 stratify=y

# 分类任务示例(回归任务如多项式回归不需要这参数)
x_train_val, x_test, y_train_val, y_test = train_test_split(x, y, test_size=0.2, random_state=42, stratify=y
)
  • 回归任务(预测连续值,如之前的 y)不需要 stratify,因为标签是连续的,没有「类别比例」可言。

五、和之前流程的核心关联

这行代码是「多项式次数选择」的基础,没有合理的数据划分,后续的模型选择就会失真:

  1. 测试集的独立性test_size=0.2 拆分出的测试集,全程不参与「模型训练」和「次数选择」,仅用于最终评估最优模型的泛化能力(避免测试集“泄露”导致的结果造假);
  2. 随机种子的一致性random_state=42 与之前的 np.random.seed(42) 保持一致,确保「数据划分」和「噪声生成」的随机性都被固定,结果可复现;
  3. 后续步骤的依赖:拆分出的 x_train 用于训练不同次数的多项式模型,x_val 用于筛选最优次数,x_test 用于最终评估——三者各司其职,是避免过拟合、准确评估模型的关键。

六、常见误区提醒

  1. 不要重复划分测试集:测试集只能划分一次,不能在模型选择后重新划分(否则会用测试集的信息调整模型,导致泛化能力评估失真);
  2. random_state 不是必须的,但强烈推荐:不设置 random_state 会导致每次运行的训练/测试集不同,无法复现结果,不利于调试和对比;
  3. test_size 的取值范围:必须是 (0,1) 之间的小数(比例)或正整数(样本数量),比如 test_size=20 表示直接取20个样本作为测试集。

总结

train_test_split(x, y, test_size=0.2, random_state=42) 的核心作用是:
按2:8比例(测试集:训练+验证集)随机拆分数据,固定随机种子保证可复现,同时维持特征和标签的对应关系——这是后续训练模型、用验证集选最优次数、用测试集评估泛化能力的前提,是机器学习流程的“基础操作”。

http://www.jsqmd.com/news/54389/

相关文章:

  • 解决LVGL与FATFS编码格式冲突及外挂字库方案
  • 我是如何用浏览器插件轻松抓取抖音评论并实现精准搜索分析的
  • 重练算法(代码随想录版) day24 - 回溯part3
  • useEffect详解
  • 详解np.random.normal(0, 3, size=x.shape)
  • 代码随想录Day23_回溯_组合.md
  • 详细介绍:【JUnit实战3_21】第十二章:JUnit 5 与主流 IDE 的集成 + 第十三章:用 JUnit 5 做持续集成(上):在本地安装 Jenkins
  • 代码随想录Day24_回溯_复原IP.md
  • 何以为生
  • GraphRAG进阶:基于Neo4j与LlamaIndex的DRIFT搜索实现详解
  • Gemini3疯了!0.09接入Nano Banana Pro 4k画质API(附实战教程)
  • 11/28
  • noip板子
  • Webstorm常用配置
  • 东方博宜OJ 1119:求各位数字之和 ← 循环结构
  • 2025.11.28
  • 10个免费查重降重工具分享,降AIGC率工具
  • Linux_Socket_浅谈UDP - 教程
  • Jetlinks 物联网平台 开源版学习源码分析
  • Java 线程池深度解析:原理、策略与生产环境调优指南
  • Tita CRM一体化平台:破解销售管理五大痛点,实现业绩可持续增长
  • NOIP 算法合集
  • 会赢吗
  • 直接通过electron创建项目
  • 东方博宜OJ 1246:请输出n行的9*9乘法表 ← 嵌套循环
  • 使用cnpm(中国镜像源的npm客户端)来安装electron
  • 2025年11月电动叉车销售企业避坑指南:市场主流品牌横向对比
  • 2025年11月中国电动叉车销售公司推荐榜单:主流品牌综合对比分析
  • 详细介绍:Qt样式深度解析
  • 文档抽取科技:利用自然语言处理技术自动识别和提取合同、判决书等法律文书中的关键信息,并将其转化为结构化数据