当前位置: 首页 > news >正文

SAM2:使用mask作为提示输入,实现VOS视频分割

8k50o45u_seg

目录

1. 引言

2. 使用SAM2实现VOS任务

2.1 数据集

2.2 主要函数

2.3 主要代码

3. 结果展示


1. 引言

本文尝试使用SAM2模型来实现VOS任务。由于在官方的github代码中只找到了point或者box作为提示,但是论文中却说明是可以输入mask作为提示的,所以打算自己尝试一下。

官方github:https://github.com/facebookresearch/sam2

2. 使用SAM2实现VOS任务

2.1 数据集

本文是基于MOSE数据集来进行实验。MOSE 数据集是一个用于视频目标分割(Video Object Segmentation, VOS)的数据集,主要用于评估像 Segment Anything Model 2 这类模型在复杂视频场景中的分割与跟踪能力

MOSE数据集:https://mose.video/

数据集的格式如下,每一个视频都处理一个帧列表:

train/valid.tar.gz │ ├── Annotations │ ├── video_name_1 │ │ ├── 00000.png │ │ ├── 00001.png │ │ └── ... │ └── video_name_... │ └── ... │ └── JPEGImages ├── video_name_1 │ ├── 00000.png │ ├── 00001.png │ └── ... └── video_name_... └── ...
2.2 主要函数

在VOS任务中,通常会使用第一帧的mask图像作为提示输入,然而,MOSE数据中的标注是彩色的mask(每一种颜色代表一种类别,如下图所示),而SAM2中所需要的mask提示需要为 Binary Mask,所以需要对MOSE数据集中的mask进行处理,提取出每一个类别的掩码图。

将每种颜色拆分成为单独的 Binary mask图像,并保存到列表中。

def split_multiclass_mask(mask_path, ignore_background=True): """ 读取多类别 mask,并将每种颜色拆分成单独的二值 mask 参数: mask_path: mask图像路径 ignore_background: 是否忽略背景(默认忽略黑色) 返回: binary_masks: list,每个元素是一个 H×W 的二值 mask colors: 每个 mask 对应的颜色 """ mask = cv2.imread(mask_path) if mask is None: raise ValueError(f"无法读取 mask: {mask_path}") # 获取所有颜色 colors = np.unique(mask.reshape(-1, 3), axis=0) binary_masks = [] valid_colors = [] for color in colors: # 跳过背景 if ignore_background and np.all(color == [0, 0, 0]): continue # 找到该颜色的位置 m = np.all(mask == color, axis=-1) binary_mask = (m * 255).astype(np.uint8) binary_masks.append(binary_mask) valid_colors.append(color) return binary_masks, valid_colors

虽然输入需要binary mask,但是最终生成的掩码图应该是彩色的mask,所以使用generate_mask函数生成彩色mask,为每一个id生成一种颜色。

def generate_mask(obj_id, mask): mask = np.squeeze(mask) H, W = mask.shape # 生成颜色 cmap = plt.get_cmap("tab10") color = np.array(cmap(obj_id % 10)[:3]) * 255 color = color.astype(np.uint8) # 初始化黑图 mask_img = np.zeros((H, W, 3), dtype=np.uint8) mask_bool = mask.astype(bool) # 填充颜色 mask_img[mask_bool] = color return mask_img

本文中希望最终生成的结果是带有掩码的视频和彩色mask图像

def merge_masks_and_overlay(frame, all_masks, alpha=0.5): """ frame: 原图 (H,W,3), dtype=uint8 all_masks: list,每个元素是 (H,W,3) 彩色mask, dtype=uint8 alpha: 叠加透明度 return: merged_mask : 合并后的mask图 (H,W,3) overlay_img : 原图+mask叠加 (H,W,3) """ H, W = frame.shape[:2] # 1 初始化 mask merged_mask = np.zeros((H, W, 3), dtype=np.uint8) # 2 合并所有mask for m in all_masks: mask_bool = np.any(m > 0, axis=2) merged_mask[mask_bool] = m[mask_bool] # 3 叠加到原图 overlay = frame.copy() mask_area = np.any(merged_mask > 0, axis=2) # 使用 numpy 做叠加 overlay[mask_area] = ( alpha * merged_mask[mask_area].astype(np.float32) + (1 - alpha) * overlay[mask_area].astype(np.float32) ).astype(np.uint8) # cv2.imwrite('frame.png', overlay) return merged_mask, overlay
2.3 主要代码

配置模型与保存路径

if torch.cuda.is_available(): device = torch.device("cuda") elif torch.backends.mps.is_available(): device = torch.device("mps") else: device = torch.device("cpu") print(f"using device: {device}") if device.type == "cuda": # use bfloat16 for the entire notebook torch.autocast("cuda", dtype=torch.bfloat16).__enter__() # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices) if torch.cuda.get_device_properties(0).major >= 8: torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True elif device.type == "mps": print( "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might " "give numerically different outputs and sometimes degraded performance on MPS. " "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion." ) if __name__ == "__main__": sam2_checkpoint = "sam2/checkpoints/sam2.1_hiera_large.pt" model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml" video_path = "sam2/MOSE_val/JPEGImages" anno_path = "sam2/MOSE_val/Annotations" output_dir = "output/MOSE" obj_num = 1 os.makedirs(output_dir, exist_ok=True) shutil.rmtree(output_dir) file_count = 1 predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device) for file in os.listdir(video_path): if file_count > 1: predictor.reset_state(inference_state) torch.cuda.empty_cache() frames_path = os.path.join(video_path, file) mask_path = os.path.join(anno_path, file, os.listdir(os.path.join(anno_path, file))[0]) video_save_path = os.path.join(output_dir, 'video',file) mask_save_path = os.path.join(output_dir, 'masks',file) os.makedirs(video_save_path, exist_ok=True) os.makedirs(mask_save_path, exist_ok=True)

接下来对彩色mask图像进行处理,并且将视频帧按顺序排列(这一步非常重要,不然可能会导致输入的提示mask与输入的帧不一致,也是为了保证后续生成符合顺序的掩码视频)

# 处理mask图像 binary_masks, valid_colors = split_multiclass_mask(mask_path) # 对帧按顺序排列 frame_names = [ p for p in os.listdir(frames_path) if os.path.splitext(p)[-1].lower() in [".jpg", ".jpeg", ".png"] ] frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

初始化predictor后,输入提示掩码,注意,每一种类别对应着一种obj_id,每一个类别对应着的binary mask都需要输入一次,不能一次性输入。

inference_state = predictor.init_state(video_path=frames_path) # take a look the first video frame frame_idx = 0 obj_id = 1 for mask in binary_masks: _, out_obj_ids, out_mask_logits = predictor.add_new_mask( inference_state, frame_idx, obj_id, mask, ) obj_id += 1

用 SAM2 的视频传播(propagation)功能,对视频每一帧生成目标的分割 mask,并把结果保存到video_segments字典中。

video_segments = {} # video_segments contains the per-frame segmentation results for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state): video_segments[out_frame_idx] = { out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids) }

最后生成结果

output_video = os.path.join(video_save_path, f"{file}_seg.mp4") # 读取第一帧获取尺寸 first_img = cv2.imread(os.path.join(frames_path, frame_names[0])) h, w, _ = first_img.shape # 创建 VideoWriter out = iio.get_writer( output_video, fps=5, # 保持 RGB macro_block_size=1 ) for frame_idx in range(len(frame_names)): img_path = os.path.join(frames_path, frame_names[frame_idx]) frame = cv2.imread(img_path) all_masks = [] if frame_idx in video_segments: for out_obj_id, out_mask in video_segments[frame_idx].items(): obj_mask_img = generate_mask(out_obj_id, out_mask) all_masks.append(obj_mask_img) merge_mask, overlay_frame = merge_masks_and_overlay(frame, all_masks) cv2.imwrite(os.path.join(mask_save_path, frame_names[frame_idx]), merge_mask) cv2.imwrite('overlay.png', overlay_frame) overlay_frame = cv2.cvtColor(overlay_frame, cv2.COLOR_BGR2RGB) out.append_data(overlay_frame) out.close() print("Saved video:", output_video) file_count += 1

3. 结果展示

抽取了一个视频中的前几帧mask结果展示

jxmcdk8k_seg

http://www.jsqmd.com/news/485282/

相关文章:

  • Meta甩出4款推理芯片,软硬协同两年算力暴涨25倍
  • 笨鸟先飞之python基础总结
  • AI大模型教程(2026最新)从零基础入门到精通,一篇收藏全掌握!
  • 测试文章发布
  • MATLAB R2018A环境下基于基尼相关性的频域地震盲反褶积方法
  • 小程序毕业设计-基于微信小程序的乡村治理数字化平台的设计与实现
  • 政府科技管理部门如何高效整合区域创新资源?
  • 面试官最爱问的设计题:动态支付系统设计(策略模式 + 工厂模式 + Spring自动注册)
  • Python每日一题:四道易错题深度解析(变量作用域、逻辑运算、lambda、Py2/3区别)
  • OpenClaw玩转有道云笔记
  • 超越 Transformer 的架构前瞻
  • 2026年手机摄像头测试方案厂商技术强的品牌推荐 - mypinpai
  • 网络安全向日葵漏洞
  • 学长亲荐 8个降AIGC软件:全行业通用测评,帮你高效降AI率
  • java从头开始-苍穹外卖-day11-数据统计与展示
  • Argo CD 的核心架构组件与作用
  • js 从入门到放弃 3/15
  • 语音算法面试复习系列2——语音信号处理基础(下)
  • Vue案例——面经
  • 图解C语言侵入式双向循环链表与 container_of 宏底层原理
  • 百度文心搜索4.0+C# RAG实战:打造支持实时问答与长文档总结的智能客服
  • 计算机毕业设计springboot基于Spark的用户行为数据挖掘与分析解决方案 SpringBoot框架下融合Spark的用户行为模式识别与智能分析平台 基于SpringBoot与Spark的用户行
  • lossless-claw vs mem0:别再把上下文管理和长期记忆混为一谈
  • JAVA面试题速记-分布式架构知识点-元一软件
  • 2.创建你的第一个FreeRTOS任务(动态与静态)
  • 项目实训开题
  • Three.js制作的3D魔方。
  • 0612-出租车(调价+昼夜)-系统设计(51+SEG+DS1302)
  • TimeLine如何自定义轨道
  • 035-spiderbuf第C12题