一文看明白PyTorch 模型设计训练保存加载预测
需求
代码样例
包含训练 → 保存 → 加载 → 预测,代码可以直接运行:
importtorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.dataimportDataLoader,TensorDataset# -----------------------------# 1. 定义模型# -----------------------------classSimpleModel(nn.Module):def__init__(self):super(SimpleModel,self).__init__()self.fc1=nn.Linear(128,96)self.fc2=nn.Linear(96,64)self.fc3=nn.Linear(64,32)self.relu=nn.ReLU()self.dropout=nn.Dropout(0.2)defforward(self,x):x=self.relu(self.fc1(x))x=self.dropout(x)x=self.relu(self.fc2(x))x=self.dropout(x)out=self.fc3(x)returnout# -----------------------------# 2. 准备数据 (示例随机数据)# -----------------------------X=torch.randn(1000,128)y=torch.randn(1000,32)dataset=TensorDataset(X,y)batch_size=32dataloader=DataLoader(dataset,batch_size=batch_size,shuffle=True)# -----------------------------# 3. 定义损失函数和优化器# MSELoss Mean Squared Error(均方误差)# -----------------------------model=SimpleModel()criterion=nn.MSELoss()optimizer=optim.Adam(model.parameters(),lr=0.001)# -----------------------------# 4. 训练循环# -----------------------------num_epochs=20forepochinrange(num_epochs):model.train()# 训练模式epoch_loss=0forbatch_X,batch_yindataloader:optimizer.zero_grad()outputs=model(batch_X)loss=criterion(outputs,batch_y)loss.backward()optimizer.step()epoch_loss+=loss.item()*batch_X.size(0)epoch_loss/=len(dataset)print(f"Epoch{epoch+1}/{num_epochs}, Loss:{epoch_loss:.4f}")# -----------------------------# 5. 保存训练好的模型参数# -----------------------------torch.save(model.state_dict(),"simple_model.pth")print("模型参数已保存到 simple_model.pth")# -----------------------------# 6. 加载模型进行预测# -----------------------------# 重新创建模型对象model_loaded=SimpleModel()# 加载保存的参数model_loaded.load_state_dict(torch.load("simple_model.pth"))# 切换到评估模式model_loaded.eval()# 假设有新样本 x_newx_new=torch.randn(5,128)withtorch.no_grad():# 推理时禁用梯度y_pred=model_loaded(x_new)print("加载模型预测结果形状:",y_pred.shape)# [5, 32]✅ 特点
- 训练完成后保存权重,
simple_model.pth可以随时加载。 - 加载模型时必须重新创建类,然后
load_state_dict。 - 推理时切换到
eval()模式,保证 Dropout 不随机失活。 - 使用
torch.no_grad()提升预测效率,减少显存占用。
