当前位置: 首页 > news >正文

python实现skip-gram(跳词)示例

文章目录

      • 示例

什么是跳词?
一句话,就是用中心词,去预测它周围的词。它是 Word2Vec 里最常用的一种训练方式。

示例

1、安装依赖

pip install matplotlib# 其他torch等依赖早就安装了

2、创建python文件skip_gram_demo.py,代码:

importtorchimporttorch.nnasnnimporttorch.optimasoptimimportmatplotlib.pyplotaspltfromcollectionsimportCounter# ==========================================# 1. 数据准备与预处理# ==========================================# 一个简单的微型语料库corpus=""" deep learning is powerful machine learning is a subset of artificial intelligence deep learning models are inspired by the brain natural language processing uses deep learning """# 文本清洗与分词words=corpus.lower().split()# 构建词汇表 (Word -> Index)vocab=list(set(words))word_to_idx={w:ifori,winenumerate(vocab)}idx_to_word={i:wfori,winenumerate(vocab)}vocab_size=len(vocab)print(f"词汇表大小:{vocab_size}")print(f"词汇表:{vocab}")# 生成训练数据 (Skip-gram: 输入中心词 -> 输出上下文词)defcreate_dataloader(words,word_to_idx,window_size=2):inputs=[]targets=[]foriinrange(1,len(words)-1):center_word=words[i]center_idx=word_to_idx[center_word]# 获取上下文窗口# 比如 window_size=2,则取前后各2个词forjinrange(i-window_size,i+window_size+1):ifj!=iand0<=j<len(words):context_word=words[j]context_idx=word_to_idx[context_word]inputs.append(center_idx)targets.append(context_idx)returntorch.tensor(inputs,dtype=torch.long),torch.tensor(targets,dtype=torch.long)inputs,targets=create_dataloader(words,word_to_idx,window_size=2)# ==========================================# 2. 定义 Skip-gram 模型# ==========================================classSkipGramModel(nn.Module):def__init__(self,vocab_size,embedding_dim):super(SkipGramModel,self).__init__()# 中心词嵌入层 (W)self.w_in=nn.Embedding(vocab_size,embedding_dim)# 上下文词嵌入层 (W')self.w_out=nn.Embedding(vocab_size,embedding_dim)# 初始化权重nn.init.xavier_uniform_(self.w_in.weight)nn.init.xavier_uniform_(self.w_out.weight)defforward(self,x):# x: (batch_size,)# 获取中心词的向量embeds=self.w_in(x)# (batch_size, embedding_dim)returnembedsdefloss(self,x,y):# x: 中心词索引, y: 上下文词索引# 1. 获取中心词向量v_center=self.w_in(x)# (batch_size, dim)# 2. 获取上下文词向量v_context=self.w_out(y)# (batch_size, dim)# 3. 计算点积 (相似度)# 这里的逻辑是:点积越大,概率越大score=torch.sum(torch.mul(v_center,v_context),dim=1)# (batch_size,)# 4. 使用负对数似然损失 (简化版,未包含负采样)# 实际大规模训练中通常配合 Negative Sampling 使用# 这里为了演示简单,直接最大化目标词的概率loss=-torch.mean(score)returnloss# ==========================================# 3. 训练模型# ==========================================embedding_dim=10# 词向量维度learning_rate=0.01epochs=1000model=SkipGramModel(vocab_size,embedding_dim)optimizer=optim.SGD(model.parameters(),lr=learning_rate)print("\n开始训练...")forepochinrange(epochs):optimizer.zero_grad()# 前向传播loss=model.loss(inputs,targets)# 反向传播loss.backward()optimizer.step()if(epoch+1)%200==0:print(f"Epoch{epoch+1}, Loss:{loss.item():.4f}")# ==========================================# 4. 结果可视化与测试# ==========================================print("\n训练完成!查看词向量相似度...")# 获取嵌入权重embeddings=model.w_in.weight.data.numpy()# 简单的余弦相似度计算defcosine_similarity(w1,w2):returnnp.dot(w1,w2)/(np.linalg.norm(w1)*np.linalg.norm(w2))# 测试几个词test_words=["learning","deep","artificial","brain"]importnumpyasnpforw1intest_words:ifw1inword_to_idx:vec1=embeddings[word_to_idx[w1]]print(f"\n与 '{w1}' 最相似的词:")similarities=[]forw2invocab:ifw1!=w2:vec2=embeddings[word_to_idx[w2]]sim=cosine_similarity(vec1,vec2)similarities.append((w2,sim))# 排序并打印前3个similarities.sort(key=lambdax:x[1],reverse=True)forword,scoreinsimilarities[:3]:print(f"{word}:{score:.4f}")# 2D 可视化 (PCA 降维)fromsklearn.decompositionimportPCA pca=PCA(n_components=2)reduced_embeds=pca.fit_transform(embeddings)plt.figure(figsize=(10,8))fori,wordinenumerate(vocab):plt.scatter(reduced_embeds[i,0],reduced_embeds[i,1])plt.annotate(word,(reduced_embeds[i,0],reduced_embeds[i,1]))plt.title("Word Embeddings Visualization (PCA)")plt.xlabel("PC1")plt.ylabel("PC2")plt.grid(True)plt.show()

输出结果:

词汇表大小:20词汇表:['artificial','inspired','brain','natural','is','are','learning','by','machine','powerful','processing','language','a','intelligence','uses','subset','deep','models','the','of']开始训练...Epoch200,Loss:-0.0312Epoch400,Loss:-0.0661Epoch600,Loss:-0.1041Epoch800,Loss:-0.1467Epoch1000,Loss:-0.1957训练完成!查看词向量相似度...'learning'最相似的词:inspired:0.6657are:0.4793is:0.4745'deep'最相似的词:machine:0.6026intelligence:0.5229processing:0.4629'artificial'最相似的词:is:0.5218by:0.5195the:0.5013'brain'最相似的词:subset:0.2076powerful:0.1457language:0.0755

解读:
给了一堆杂乱的文字,它居然将这些词分出了远近关系。
成功了。

http://www.jsqmd.com/news/577496/

相关文章:

  • Agent的LLM+RPA模式有什么优势?——深度拆解2026年企业智能自动化新范式
  • 无线网络实战:从零配置AP与SSID,打通设备互联
  • 【龙虾系列】OpenClaw究竟为什么火?用最简单的话讲清楚
  • UVM sequence机制实战:从入门到精通(附6种仲裁算法详解)
  • 从参考到专题:14类地图的现代应用与数据叙事
  • SEO_为什么你的网站需要持续进行SEO优化?
  • YimMenu:GTA V体验增强工具的全方位应用指南
  • MATLAB图像锐化避坑指南:为什么你的拉普拉斯算子效果总是不对?
  • 终极免费音源解决方案:LXMusic如何实现高效音乐资源获取
  • 大模型压测全攻略:从指标解读到工具选型(含EvalScope实战)
  • 新手入门:借助快马AI生成lostlife交互示例学习前端开发
  • 【STM32】STM32F103C8T6结合编码器实现电机速度闭环控制的两种方法对比
  • 如何免费获取NVIDIA的1000次DeepSeek API调用权限
  • OpenCV图像锐化实战:用Laplacian算子让模糊照片瞬间变清晰的3种方法(附Python代码)
  • 运维系列【仅供参考】:【Docker】容器生命周期管理:从优雅停止到高效清理的实战技巧
  • SEO优化如何优化网站页面
  • 城市内涝预警新思路:如何用YOLO实例分割模型+监控视频流实时监测路面积水?
  • 电力负荷预测实战:用HuggingFace上的Timer模型,15分钟搞定一个地区的未来24小时预测
  • 5个高效步骤:直链技术让网盘用户实现下载速度跃升
  • 告别重复造轮子,用快马ai一键生成rabbitmq多模式高效代码模板
  • ArduRemoteID:开源无人机远程识别技术的合规解决方案
  • 【WGC开发】Windows.Graphics.Capture API在Windows10下的窗体捕获实战:开发环境与模板配置详解
  • 5个核心技术模块构建现代化智能Agent系统:fast-agent框架深度解析
  • Vue3+TS+Vite项目实战:5分钟搞定Mock数据接入(附完整代码)
  • 实战指南:用快马平台生成基于openclaw的mac数据清洗工具
  • 基于Python的个性化电影推荐系统毕业设计
  • Your build is currently configured to use incompatible Java 26 and Gradle 8.13. Cannot sync the proj
  • 破局双系统文件壁垒:WinBtrfs驱动终极应用指南
  • 2026年 江苏厂房装修设计公司推荐榜:专业工厂/办公楼/写字楼装修,打造高效办公与生产空间 - 品牌企业推荐师(官方)
  • 新手福音:在快马平台交互式学习openclaw更新命令语法与参数