当前位置: 首页 > news >正文

day37简单的神经网络@浙大疏锦行

day37简单的神经网络@浙大疏锦行

使用 sklearn 的 load_digits 数据集 (8x8 像素的手写数字) 进行 MLP 训练。

importtorchimporttorch.nnasnnimporttorch.optimasoptimfromsklearn.datasetsimportload_digitsfromsklearn.model_selectionimporttrain_test_splitfromsklearn.preprocessingimportMinMaxScalerimportnumpyasnpimportmatplotlib.pyplotasplt# 1. 加载数据digits=load_digits()X=digits.data y=digits.targetprint(f"数据形状:{X.shape}")print(f"标签形状:{y.shape}")# 查看一张图片plt.imshow(digits.images[0],cmap='gray')plt.title(f"Label:{y[0]}")plt.show()

数据形状: (1797, 64) 标签形状: (1797,)

# 2. 数据预处理# 划分训练集和测试集X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.2,random_state=42)# 归一化scaler=MinMaxScaler()X_train=scaler.fit_transform(X_train)X_test=scaler.transform(X_test)# 转换为 TensorX_train=torch.FloatTensor(X_train)y_train=torch.LongTensor(y_train)X_test=torch.FloatTensor(X_test)y_test=torch.LongTensor(y_test)print("训练集 Tensor 形状:",X_train.shape)print("测试集 Tensor 形状:",X_test.shape)

训练集 Tensor 形状: torch.Size([1437, 64])

测试集 Tensor 形状: torch.Size([360, 64])

# 3. 定义模型classMLP(nn.Module):def__init__(self):super(MLP,self).__init__()# 输入层 64 (8*8像素) -> 隐藏层 32 -> 输出层 10 (0-9数字)self.fc1=nn.Linear(64,32)self.relu=nn.ReLU()self.fc2=nn.Linear(32,10)defforward(self,x):out=self.fc1(x)out=self.relu(out)out=self.fc2(out)returnout model=MLP()print(model)

MLP(

(fc1): Linear(in_features=64, out_features=32, bias=True) (relu): ReLU()

(fc2): Linear(in_features=32, out_features=10, bias=True)

)

# 4. 定义损失函数和优化器criterion=nn.CrossEntropyLoss()optimizer=optim.SGD(model.parameters(),lr=0.1)# 学习率稍微调大一点,或者增加epoch
# 5. 训练模型num_epochs=2000losses=[]forepochinrange(num_epochs):# 前向传播outputs=model(X_train)loss=criterion(outputs,y_train)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()losses.append(loss.item())if(epoch+1)%100==0:print(f'Epoch [{epoch+1}/{num_epochs}], Loss:{loss.item():.4f}')

# 6. 可视化损失plt.plot(range(num_epochs),losses)plt.xlabel('Epoch')plt.ylabel('Loss')plt.title('Training Loss')plt.show()

# 7. 模型评估withtorch.no_grad():# 训练集准确率outputs_train=model(X_train)_,predicted_train=torch.max(outputs_train,1)accuracy_train=(predicted_train==y_train).sum().item()/y_train.size(0)# 测试集准确率outputs_test=model(X_test)_,predicted_test=torch.max(outputs_test,1)accuracy_test=(predicted_test==y_test).sum().item()/y_test.size(0)print(f'训练集准确率:{accuracy_train:.4f}')print(f'测试集准确率:{accuracy_test:.4f}')

@浙大疏锦行

http://www.jsqmd.com/news/84640/

相关文章:

  • wpf 类图
  • 基于springboot的运动服服饰销售购买商城系统
  • 基于文献的‘12-文献代码复现‘:非线性模型预测控制(NMPC)多无人船USV编队控制form...
  • 【保姆级教程】爆火开源项目 Next AI Draw.io 上手指南:一句话画流程图
  • 韩语教程资源合集
  • 如何用DSPy优化RAG prompt示例
  • 英语口语资源合集
  • 鸿蒙PC UI控件库 - TextInput 文本输入框详解
  • 链表中的回文判断
  • 鸿蒙PC UI控件库 - PasswordInput 密码输入框详解
  • 【大模型预训练】07-数据处理流程设计:从原始数据到模型输入的端到端处理链路
  • 基于VMD-CPA-KELM-IOWAl-CSA-LSSVM碳排放的混合预测模型研究附Matlab代码
  • 【机器人路径规划】基于6种算法(黑翅鸢优化算法BKA、SSA、MSA、RTH、TROA、COA)求解机器人路径规划研究附Matlab代码
  • 基于6种最新算法(小龙虾优化算法COA、MSA、RTH、NOA、BFO、SWO)求解机器人路径规划研究附Matlab代码
  • 【数据结构】排序
  • 【路径规划】基于RRT快速探索随机树算法在包含圆形障碍物的环境中寻找从起点到目标点的路径附matlab代码
  • 【太阳能学报EI复现】基于粒子群优化算法的风-水电联合优化运行分析附Matlab代码
  • go构建web服务
  • 夜莺监控设计思考(三)时序库、agent 的一些设计考量
  • 系统基础服务
  • 数据结构:二叉排序树,平衡二叉树,红黑树的介绍
  • 软件测试面试题集合
  • AI中的优化5-无约束非线性规划之凸性
  • Go Module构建
  • 【time-rs】Duration 结构体详解
  • 深圳|昆明|广州|东莞-奶茶原料批发供应商|奶茶原料供应商|奶茶原料批发市场|奶茶原料批发|奶茶原料推荐|奶茶原料公司——圣旺水吧 - 老百姓的口碑
  • Python基础知识的总结(2)
  • Go程序的执行顺序
  • Java 大视界 -- 基于 Java 的大数据分布式计算在地球物理勘探数据处理与地质结构建模中的应用
  • TDengine 新性能基准测试工具 taosgen