当前位置: 首页 > news >正文

LLM·minimind-预训练

文章目录

  • 预训练
    • 初始化模型和分词器
      • 初始化配置文件 AutoConfig
      • 从配置文件初始化 AutoModel
      • 加载 AutoTokenizer
    • 预训练数据集
      • 加载数据集
      • DataDict
      • Dataset
      • 数据预处理
        • 数据预先处理函数
        • 1.数据集编码为tokens
        • 2.数据集分块,获得特定长度的`input_ids`和`labels`
    • 训练器
      • TrainingArguments
      • Trainer

所有的代码都基于transformers库。

预训练

初始化模型和分词器

初始化配置文件 AutoConfig

  • 类似transformers中的AutoModel一样,都需要先下载配置文件model.info,然后读取该文件夹获得配置信息
from transformersimportAutoConfig,AutoModelForCausalLM model_path="../model/Qwen2.5-1.5B"config=AutoConfig.from_pretrained(model_path)config
  • 配置信息包含模型的架构:
Qwen2Config{"architectures":["Qwen2ForCausalLM"],"attention_dropout":0.0,

从配置文件初始化 AutoModel

  • AutoModelForCausalLM.from_config:这一步构造初始化需要很长的时间
  • 可能涉及与远程仓库中的定义进行对齐,甚至下载远程仓库的代码,没有科学上网会卡死!
model=AutoModelForCausalLM.from_config(config,trust_remote_code=False)model.to("cuda")
  • model的具体架构:
modelQwen2ForCausalLM((model):Qwen2Model((embed_tokens):Embedding(151936,1536)(layers):ModuleList((0-27):28xQwen2DecoderLayer((self_attn):Qwen2SdpaAttention((q_proj):Linear(in_features=1536,out_features=1536,bias=True)(k_proj):Linear(in_features=1536,out_features=256,bias=True)(v_proj):Linear(in_features=1536,out_features=256,bias=True)(o_proj):Linear(in_features=1536,out_features=1536,bias=False)(rotary_emb):Qwen2RotaryEmbedding()

加载 AutoTokenizer

from transformersimportAutoTokenizertokenizer=AutoTokenizer.from_pretrained(model_path)tokenizer

预训练数据集

  • 数据集格式:必须是token化的序列
  • 最大长度必须一致
  • 构造出labels,该labels与input_ids一致,模型会处理移位

加载数据集

参考文献

  • path:表示数据集的名称monkey-gen,如果只有当前参数则会自动下载到缓存;数据集的格式,例如json,csv
  • data_dir:数据集所在的本地目录
  • data_file:数据集本身,例如xxx.jsonl.
dataset=load_dataset("csv",data_files="./ChnSentiCorp_htl_all.csv",split="train")dataset=load_dataset("json",data_files="./cmrc2018_trial.json",field="data")

DataDict

  • 类型为Dict[str,Dataset]
  • 把他当成一个字典来理解,用于获得train或者test字段的Dataset。
  • 不支持直接索引
DatasetDict({train:Dataset({features:['input_ids','attention_mask'],num_rows:100001})})

Dataset

  • 数组和字典的混合体
  • 可以理解为List[Dict]或者Dict[List]的形式,支持下标索引和键值对索引。
Dataset({features:['input_ids','attention_mask','labels'],num_rows:1370})

数据预处理

  • 我们期望的预训练格式如下:首先是将原始文本str转换为input_ids:List[int]
数据预先处理函数
  • 输入参数:batched=True时为examples:类型为Dict[str,List[Any]]
examples={"text":["今天天气不错。","我在学预训练语言模型。","DeepSpeed 加速训练。"]}
  • 返回参数:batched=True时返回类型为Dict[str,List[Any]]
  • 注意tokenizer处理batch时会返回字典Dict[str,List[Any]]
{"input_ids":[[...,...],[...,...],[...,...]],"attention_mask":[[...,...],[...,...],[...,...]],# 其他字段(如 token_type_ids 等)}
1.数据集编码为tokens
deftokenize_function(examples:Dict[str,List[Any]]# 列名 对应一个列表/):returntokenizer([textfortext in examples['text']])

examples的数据类型

  • examples的类型为:Dict[str,List]
  • 例如:‘text’:[1,2,3]
tokenized_ds=ds.map(tokenize_function,batched=True,# 打包为列名:值/列表'text':['文本1','文本2',...]num_proc=10,remove_columns=column_names,load_from_cache_file=True)
  • 输出结果,将删除当前列,并且返回input_idsattention_mask组成的字典。
DatasetDict({train:Dataset({features:['input_ids','attention_mask'],num_rows:5001})})
2.数据集分块,获得特定长度的input_idslabels
defgroup_texts(examples:Dict[str,List[str]]):# 拼接所有可迭代对象 concat_examples:Dict[str:List]={k:list(chain(*examples[k]))# iter->listfork in examples.keys()# List[tensor]}# 计算总长度 seq=mask total_length=len(concat_examples[list(examples.keys())[0]])num_block=total_length// block_sizeresult={#list->list[tensor]k:[concat_examples[k][i*block_size:(i+1)*block_size]fori inrange(num_block)]fork in concat_examples.keys()}result['labels']=result['input_ids'].copy()returnresult
lm_ds=tokenized_ds.map(group_texts,batched=True,num_proc=10,load_from_cache_file=True,batch_size=1000,)

chain:合并迭代器

  • 拼接两个迭代器,返回一个更长的迭代器
  • 可以通过list转换为数组。
from itertoolsimportchainblock_size=2048# 首位拼接 可迭代对象->返回长迭代器list(chain([1,2],[3,4]))#[1,2,3,4]list(chain(*[[1,2],[3,4]]))#[1,2,3,4]

训练器

  • 训练器包括优化器模型本身,分词器等等,数据集加粗样式

TrainingArguments

  • 规定了一些重要的超参数
  • 包括训练参数:epoch数,梯度累积更新数,评估参数等等
from transformersimportTrainingArgumentstraining_args=TrainingArguments(output_dir="output/",per_device_train_batch_size=1,gradient_accumulation_steps=4,logging_steps=4,num_train_epochs=1,save_steps=500,learning_rate=1e-4,save_on_each_node=True,gradient_checkpointing=True,)

Trainer

  • 数据集使用default_data_collator进行封装为batch
  • IterableWrapper(train_dataset):支持将训练集包裹为可迭代对象,可以直接传入Dataset类型
from transformersimportTrainer,default_data_collator from torchdata.datapipes.iterimportIterableWrapper# 训练器 trainer=Trainer(model=model,args=training_args,#Dataset传入也可以,本身就是mmap,不会节省太多内存train_dataset=IterableWrapper(train_dataset),# 将Dataset类型包裹为迭代器 eval_dataset=None,#tokenizer=tokenizer,# 默认为 MLM 的 collator,使用 CLM 的 collater#CLM:因果语言建模,输入和输出标签一致,不会随机掩码data_collator=default_data_collator,# MLM:掩码语言建模,完型填空,不会随机掩码;)
http://www.jsqmd.com/news/530447/

相关文章:

  • 洞见2026:玄奘之路戈壁徒步专业服务商全景解析与适配建议 - 2026年企业推荐榜
  • AcousticSense AI真实案例:民谣与乡村音乐在ViT-B/16特征空间中的聚类效果
  • 基于PHP、asp.net、java、Springboot、SSM、vue3的技术博客系统的设计与实现
  • Tinke终极指南:NDS游戏文件编辑与资源提取的完整解决方案
  • 基于脉振高频电压注入法的永磁同步电机PMSM矢量控制模型 在d轴注入旋转高频电压信号,在q轴进...
  • 代码遗产规划师:在技术断代潮收割焦虑税
  • 终极指南:如何用DiffSynth Studio实现视频到3D骨架的智能转换
  • Chord视频时空分析工具效果展示:动态目标跨帧跟踪可视化案例
  • FigmaCN 技术架构深度解析:现代浏览器扩展本地化方案的设计与实现
  • AI原生应用领域:文本生成的前沿技术揭秘
  • BLE调试工具大比拼:nRF Connect vs BLE调试助手 vs LightBlue,哪个更适合你的项目?
  • OpenClaw七大配置:从SOUL、USER、AGENTS到MEMORY
  • AI审核驱动的IACheck:适老化改造工程检测报告如何实现更细致与可靠的质量把控
  • YapDatabase并发性能优化:如何在多线程环境中实现零阻塞
  • 风速仿真模型中的Sumlink仿真:风机仿真、风电机组模型、变桨控制与最大功率追踪控制,包含四...
  • 打卡信奥刷题(3006)用C++实现信奥题 P6225 [eJOI 2019] 异或橙子
  • 激光雕刻机未来几年,年复合增长率(CAGR)高达12.9%
  • GME-Qwen2-VL-2B-Instruct实操手册:电商详情页首图与卖点文案语义一致性检测
  • AppleRa1n:iOS 15-16设备iCloud激活锁一键绕过工具,让解锁更简单
  • Icarus Verilog仿真器完整指南:从零开始的数字电路设计终极教程
  • 圣女司幼幽-造相Z-Turbo入门必读:从CSDN博客获取文档、镜像与问题支持全链路
  • 告别混乱代码!Arduino IDE多文件开发避坑指南(从ino到h/cpp的平滑迁移)
  • Onekey:Steam Depot清单自动化获取的一站式解决方案
  • Fish-Speech-1.5实时语音合成展示:对话系统的流畅交互体验
  • BM25S4021-1 TDS水质传感器嵌入式驱动开发指南
  • 2026年评价高的反光膜公司推荐:包装袋/反光膜/塑料膜/塑料袋/大棚膜/气泡膜/气泡袋/珍珠棉定位/缠绕膜/选择指南 - 优质品牌商家
  • Icalingua++插件开发终极指南:打造专属聊天功能
  • NVIDIA DIGITS终极指南:如何快速构建深度学习视觉训练系统 [特殊字符]
  • Axure RP界面异常深度修复指南:从问题诊断到系统化解法
  • 从点云到3D框:CenterPoint实战教程(附Waymo数据集测试结果)