当前位置: 首页 > news >正文

图像分类__半监督

da201ff6879d588a207a552c160fdb08_720
不仅要验证集上的准确率达标,还要求无标签数据集上的概率大于threshold才能打上标签,加进半监督集
并且semiloader的读取也不用每一轮都尝试,这样太浪费时间了,可以每五轮或3轮尝试
semidataset的输入就是 无标签数据集,模型,置信度

no_label_loader = DataLoader(no_label_set, batch_size=16, shuffle=False) 他不能打乱,否则在semi_dataset中无法将索引对应的label正确加入进半监督训练集

class semi_dataset():def __init__(self,no_label_loader , model, device, thres=0.99):x, y = self.get_label(no_label_loader, model, device, thres)if x == []:flag = False #说明没一个无标签数据的可信度达标else:self.flag = Trueself.X = np.array(x)    #数据集之后Dataloader会完成对array矩阵处理成tensorself.Y = torch.LongTensor(y)    #而Y是tensor是因为模型输出也默认是tenosr,这里标签也设为tenosr,方便后面loss处理self.transform = train_transformdef get_label(self, no_label_loader, model, device,thres):model = model.to(device)soft = nn.Softmax()pred_prob = []labels = []x = []y = []with torch.no_grad():for bat_x, _ in no_label_loader:bat_x = bat_x.to(device)pred = model(bat_x)pred_soft = soft(pred)pred_max, pred_value =  pred_soft.max(1) #代表是横着的维度,因为每个pred是tensor(batshcsize,11)#分别承载概率最大值,和最大概率的索引pred_prob.extend(pred_max.cpu().numpy().tolist())labels.extend(pred_value.cpu().numpy().tolist()) #numy之能在cpu上,所以先到cpu上在转为numpy,在转成list,对list只能用extend,对值可以用appenfor index, prob in enumerate(pred_prob):if prob> thres:x.append(no_label_loader.dataset[index][1])       #index调用到dataset的getitem函数的第二个返回结果即没有扩增过的原始数据y.append(labels[index])return x,ydef __getitem__(self,item):return self.transform(self.x[item]), self.y[item]def __len__(self):return len(self.x)

对于semi_dataloader不应该直接定义,因为它不一定存在,其实它必须在模型训练过程中存在,比如模型的acc已经训练到了一个程度

def get_semi_loader(no_label_loader,model,device,thres):semi_set =  semi_dataset(no_label_loader,model,device,thres)if semi_set.flag == False:return Noneelse:semi_loader = DataLoader(semi_set, batch_size=16, shuffle=False)  #这里对于打不打乱随便return semi_loader
http://www.jsqmd.com/news/409384/

相关文章:

  • 从`vector`和`ArrayList`的区别联想到`ArrayList`线程安全问题
  • AI辅助的房地产投资分析
  • 告别反复登录:一文搞定 AWS CLI SSO 凭证自动刷新
  • C++游戏开发之旅 16
  • 大数据领域 Neo4j 与传统数据库的对比分析
  • ArgoCD部署与核心配置详解 - wanghongwei
  • 【Claude Code解惑】源码阅读利器:Claude Code 帮你梳理 Linux 内核模块逻辑
  • ArgoCD部署与核心配置详解及生产最佳实践 - wanghongwei
  • Hadoop与视频流分析:内容推荐系统
  • VsCode插件推荐---Todo Tree
  • OSPF 邻居无法建立的常见原因
  • 408真题解析-2010-41-数据结构-散列表
  • 【CTFshow-pwn系列】03_栈溢出【pwn 053】详解:逐字节爆破!手写 Canary 的终极破解
  • `static`局部变量与全局变量的区别,编译后映射文件是否包含此类变量的地址?
  • 基于SpringBoot的口腔诊所系统的设计与实现_e47798hi
  • Trae AI使用第三方中转API的配置及Anthropic Claude API、gpt - 4o、grok、gemini、deepseek等大模型及 BaseURL指南
  • 基于SpringBoot的房屋租赁系统设计与实现_10h5wcdp
  • 基于springboot的健身房管理系统 _sj44f863
  • Java 集合入门:Collection List 接口超详细讲解
  • 基于springboot的某学院兼职平台设计与实现_ie33fqxq
  • 【每日一题】LeetCode 1022. 从根到叶的二进制数之和
  • 并查集 - How Many Answers Are Wrong
  • 科研前沿篇---论文研究方向
  • 2026.2.24 雅礼总结
  • 基于springboot的城乡商城协作系统(编号:57734107)
  • 科研前沿篇---人工智能网络结构与研究方向
  • 科研前沿篇---网络结构与研究方向
  • MiniMax M2.5模型正式上线,是否真正实现“生产力SOTA ”与“低负担”,如何评价其表现?
  • 莫队学习总结
  • 大数据领域HBase的集群性能调优实战案例