当前位置: 首页 > news >正文

为什么显卡明明可以发下0.5B、1.5B甚至3B的大模型参数,但是训练的时候就会报显存不足的错误呢?

为什么显卡明明可以发下0.5B、1.5B甚至3B的大模型参数,但是训练的时候就会报显存不足的错误呢?

前几天跑了一个大语言模型的代码,自家的电脑显卡本身本身是可以放下模型的参数和优化器参数的,但是训练的时候就报错,当时没有多想,就直接在云服务器上租了一个A800的显卡,不过今天突然想到了这个问题了。


看到了这么一个项目:

https://github.com/zhongzhengli13/MobileNetV3-for-leaf


其中的模型定义代码:

import torch
import torch.nn as nn
from torchvision.models import mobilenet_v3_small
from torchsummary import summaryclass PlantDiseaseClassifier(nn.Module):def __init__(self, num_classes=3):super(PlantDiseaseClassifier, self).__init__()self.base_model = mobilenet_v3_small(pretrained=False)in_features = self.base_model.classifier[3].in_featuresprint(self.base_model.classifier)  # 打印原始分类器结构self.base_model.classifier[3] = nn.Linear(in_features, num_classes)def forward(self, x):return self.base_model(x)if __name__ == "__main__":device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = PlantDiseaseClassifier(num_classes=3).to(device)# ✅ 把 dummy_input 移动到相同 device 上dummy_input = torch.randn(4, 3, 224, 224).to(device)# 测试 forwardoutput = model(dummy_input)print("\n✅ 输出形状:", output.shape)  # 应该是 [4, 3]# 模型结构 summarysummary(model, (3, 4000, 2672), device=str(device))