
LaSt-ViT: Vision Transformers Need More Than Registers (CVPR 2026)

Introduction

Although Vision Transformers (ViTs) have achieved great success in image classification and beyond, much about their inner workings remains unresolved. Recent studies have found that in downstream tasks requiring dense features, ViTs exhibit a variety of puzzling artifacts, and these problems appear across different training paradigms:

  • Fully-supervised: a pronounced "attention deficit", where the produced feature maps fail to focus on the main object, limiting local feature extraction.
  • Text-supervised: dense features align poorly with text prompts, making precise pixel-level semantic matching difficult.
  • Self-supervised: "high-norm tokens" emerge in the model and act as distractors, severely impairing accurate object localization.

Is there a common cause behind all of these problems? Today we look at LaSt-ViT (LazyStrike ViT), a piece of cutting-edge research from CVPR 2026.

Paper: Vision Transformers Need More Than Registers

Code: https://github.com/ChengShiest/LAST-ViT

Lazy Aggregation

To explain these phenomena, the paper proposes the "lazy aggregation hypothesis". The hypothesis holds that ViT artifacts stem from a kind of "laziness": "under the combined effect of coarse-grained semantic supervision and the global attention mechanism, ViTs tend to exploit semantically irrelevant background patches as a shortcut for encoding global semantics, instead of focusing on the foreground object."

Coarse-grained supervision means the model only receives image-level class labels, with no precise supervision of individual patches. The model can solve the task with any feature that suffices to tell the classes apart, without relying on the most representative foreground features. The global attention mechanism, meanwhile, lets information flow freely and efficiently between all patches. If certain background features correlate strongly with the class, the model quickly "diffuses" foreground semantics onto the background, forming a training shortcut.

The model thus discovers that, rather than laboring to learn foreground features, it can exploit the ubiquitous background as a shortcut to solve the classification task. This is the root of ViT artifacts.

Validating the Hypothesis

Two metrics are defined for this:

  • Patch Score: the cosine similarity between each patch feature and the CLS token. A high score means the patch is strongly correlated with the image's overall semantics; this is the core measure of a local feature's importance.
  • Point-in-Box (PiB): the fraction of images whose highest-scoring patch falls inside the annotated foreground box. The higher the PiB, the better the model aligns its "global semantic center of mass" with the visual foreground.

In short, Patch Score measures how relevant each patch is to the global semantics, and PiB measures whether the top-scoring patch lands on the foreground; a rough sketch of both is given below.
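
Both metrics take only a few lines to compute. The helper below is illustrative, not the paper's evaluation code: it assumes patch features and a CLS token as produced by a standard ViT, and a ground-truth box already mapped to patch-grid coordinates (the names patch_scores and point_in_box, and the 14×14 grid, are assumptions).

import torch
import torch.nn.functional as F

def patch_scores(cls_token, patch_tokens):
    """Patch Score: cosine similarity of each patch feature to the CLS token.
    cls_token: [B, D], patch_tokens: [B, N, D] -> scores: [B, N]"""
    return F.cosine_similarity(cls_token.unsqueeze(1), patch_tokens, dim=-1)

def point_in_box(scores, boxes, grid_size=14):
    """PiB: fraction of images whose top-scoring patch lies inside the
    foreground box. boxes: [B, 4] = (row_min, col_min, row_max, col_max)
    in patch-grid coordinates."""
    top_idx = scores.argmax(dim=1)                      # [B]
    rows, cols = top_idx // grid_size, top_idx % grid_size
    inside = ((rows >= boxes[:, 0]) & (rows <= boxes[:, 2]) &
              (cols >= boxes[:, 1]) & (cols <= boxes[:, 3]))
    return inside.float().mean().item()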

Surprisingly, it is the background patches that get the higher Patch Scores. More interesting still, masking out the top-scoring half of the patches (mostly background) barely changes classification accuracy. This directly demonstrates that the ViT really is "being lazy" and relying on the background to solve the task. The study backs this up with three pieces of evidence:

Laziness is present from the start of training

Tracking the entire training run reveals a striking contrast: the ViT's classification accuracy curve looks perfect, climbing steadily as training proceeds, yet the PiB score, which measures its ability to localize the foreground, hovers around a low 42% throughout, far below that of a convolution-based ResNet.

This is strong evidence that the background-shortcut "artifact" is not a side effect that appears late in training; it forms at the very beginning and persists throughout. From the first epoch, the model opts for the easier background features while learning the classification task, rather than learning genuine foreground features.

Coarse-grained supervision steers the ViT toward laziness

Classification labels only tell the model what is in the image, not where it is. Since background patches vastly outnumber foreground ones, the model exploits background cues for fast convergence instead of learning object features. The authors also ran an experiment changing the patch size from 16×16 to 28×28: PiB rose somewhat, but accuracy fell.

This confirms the point: enlarging the patch size to reduce background interference improves localization (PiB) but lowers classification accuracy, indicating the model would rather sacrifice localization ability than give up its background "shortcut".

The effect of global attention

To verify whether global attention aggravates lazy aggregation by allowing foreground semantics to propagate into background regions, the authors progressively restricted long-range dependencies, replacing global self-attention with window-based attention at different depths of the network.

The experiments show that as global attention is restricted, the PiB score steadily rises, peaking at 59.8 when every layer uses window attention. Accuracy, however, correspondingly drops to 63.9, indicating that while global context benefits classification, it also promotes the diffusion of semantics into background regions.
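
For intuition only, here is a minimal sketch of what "replacing global self-attention with window attention" means: each patch may only attend to the other patches inside its own non-overlapping window on the token grid. This is a generic window-attention block under assumed shapes (14×14 patch grid, no CLS token handling), not the paper's ablation code.

import torch
import torch.nn as nn

class WindowAttention(nn.Module):
    """Self-attention restricted to non-overlapping windows of the patch grid."""
    def __init__(self, dim=768, num_heads=12, window=7):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.window = window

    def forward(self, x, grid=14):          # x: [B, grid*grid, D]
        B, N, D = x.shape
        w = self.window
        # Partition the grid into (grid/w)^2 windows of w*w tokens each
        x = x.view(B, grid // w, w, grid // w, w, D).permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(-1, w * w, D)
        x, _ = self.attn(x, x, x)           # attention never crosses a window
        # Undo the partitioning back to [B, N, D]
        x = x.view(B, grid // w, grid // w, w, w, D).permute(0, 1, 3, 2, 4, 5)
        return x.reshape(B, N, D)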

LaSt-ViT

To resolve the "lazy aggregation" problem at its root, LaSt-ViT (LazyStrike ViT) adopts a simple yet effective frequency-aware selective aggregation scheme.

It thoroughly reworks the CLS token's aggregation logic: instead of indiscriminately aggregating all patch features as a traditional ViT does, the CLS token selectively aggregates useful features only from foreground patches, treating the background as interference to be filtered out at the aggregation stage.

It separates foreground from background via frequency analysis, exploiting how deep feature maps vary along the channel dimension:

  • Foreground: semantically uniform → small feature variation → low-frequency signal
  • Background: semantically cluttered → large feature variation → high-frequency signal

Exploiting this property, frequency analysis can pick out the patches whose features are stable; a toy illustration follows.
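
A small numeric illustration of this intuition, using the same channel-wise FFT plus Gaussian low-pass filtering as the implementation below (the two synthetic vectors merely stand in for a "foreground-like" and a "background-like" patch feature; they are not real model outputs):

import torch

D = 768
t = torch.linspace(0, 1, D)
smooth = torch.sin(2 * torch.pi * t)        # foreground-like: slow channel variation
noisy = smooth + 0.5 * torch.randn(D)       # background-like: extra high-frequency energy

def highfreq_energy(v, sigma=D ** 0.5):
    """Norm of what a Gaussian low-pass filter removes along the channel axis."""
    pos = torch.arange(-D // 2 + 1, D // 2 + 1).float()
    kernel = torch.exp(-0.5 * (pos / sigma) ** 2)        # centered Gaussian, peak 1
    spec = torch.fft.fftshift(torch.fft.fft(v))
    filtered = torch.fft.ifft(torch.fft.ifftshift(spec * kernel)).real
    return (v - filtered).norm().item()

print(highfreq_energy(smooth))   # small: barely touched by low-pass filtering
print(highfreq_energy(noisy))    # much larger: loses its high-frequency content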

Core Implementation

The core of LazyStrike ViT consists of three steps:

(1) Stability Score

First, a one-dimensional Fourier transform is applied to each patch's feature vector along the channel dimension, a Gaussian low-pass filter is applied, and the result is transformed back. The stability score is then computed from the original features and the filtering residual (in the reference code below, score = x / (|x_lp − x| + ε)); the higher the score, the more stable the semantics and the more likely the patch is foreground.

(2) Channel-wise Top-K Pooling

For each channel, select the K patches with the highest stability scores, average-pool their features, and fold the result into the CLS token, aggregating the most informative local features of every channel.

(3) Vote Count

Count the number of times each patch is selected across all channels; the more often a patch is chosen, the more important it is. This voting mechanism further strengthens the representation of foreground regions in the image.

Through these three steps, the CLS token can be anchored precisely to the foreground.
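
All three steps fit in a few lines of tensor code. A minimal sketch, with shapes matching the reference implementation below; the per-channel score tensor and K are simplified stand-ins, not the paper's exact formulation:

import torch

def lazy_strike_aggregate(patch_tokens, scores, K=5):
    """patch_tokens, scores: [B, N, D];
    scores[b, n, d] = stability of patch n on channel d."""
    B, N, D = patch_tokens.shape
    # (2) Channel-wise top-K: every channel picks its K most stable patches
    _, idx = torch.topk(scores, k=K, dim=1)              # [B, K, D]
    selected = torch.gather(patch_tokens, 1, idx)        # [B, K, D]
    cls_token = selected.mean(dim=1)                     # [B, D]
    # (3) Vote count: how many channels selected each patch
    votes = torch.zeros(B, N, device=patch_tokens.device)
    votes.scatter_add_(1, idx.reshape(B, -1),
                       torch.ones(B, K * D, device=patch_tokens.device))
    return cls_token, votes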

A reference implementation follows:

""" LAST-ViT: Vision Transformers Need More Than Registers Core implementation of the frequency-domain token selection mechanism. Reference: https://arxiv.org/abs/2602.22394 """ import torch import torch.nn as nn from torchvision.models.vision_transformer import VisionTransformer class LASTViT(VisionTransformer): """ LAST-ViT replaces the standard CLS token with a frequency-domain token selection mechanism. Key idea: 1. Apply FFT + Gaussian low-pass filter to patch tokens 2. Compute stability scores: diff = original / |filtered - original| 3. Select the most stable patch token (top-k) 4. Average selected tokens as the new CLS token """ def __init__(self, image_size=224, patch_size=16, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, num_classes=1000, top_k=1, sigma=None, **kwargs): super().__init__( image_size=image_size, patch_size=patch_size, num_layers=num_layers, num_heads=num_heads, hidden_dim=hidden_dim, mlp_dim=mlp_dim, **kwargs ) self.top_k = top_k self.sigma = sigma if sigma is not None else hidden_dim ** 0.5 self.cached_kernel = None # Replace classification head (ViT_B_16_Weights has 1000 classes by default) self.heads = nn.Linear(hidden_dim, num_classes) if num_classes != 1000 else self.heads def gaussian_kernel_1d(self, kernel_size: int, sigma: float) -> torch.Tensor: """Create a 1D Gaussian kernel for frequency-domain filtering.""" positions = torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1).float() kernel = torch.exp(-0.5 * (positions / sigma) ** 2) kernel = kernel / kernel.max() return kernel def low_pass_filter(self, patch_tokens: torch.Tensor) -> torch.Tensor: """ Apply frequency-domain low-pass filter (Gaussian in frequency domain). Args: patch_tokens: [B, N, D] patch embeddings Returns: Filtered patch tokens [B, N, D] """ original_dtype = patch_tokens.dtype # Use float for FFT to avoid precision issues if patch_tokens.dtype in {torch.float16, torch.bfloat16}: patch_tokens = patch_tokens.float() # Lazy initialization of Gaussian kernel if self.cached_kernel is None or self.cached_kernel.shape[-1] != patch_tokens.shape[-1]: kernel = self.gaussian_kernel_1d(patch_tokens.shape[-1], self.sigma) self.cached_kernel = kernel.view(1, 1, -1).to(patch_tokens.device) # FFT-based filtering spectrum = torch.fft.fft(patch_tokens, dim=-1) spectrum = torch.fft.fftshift(spectrum, dim=-1) spectrum = spectrum * self.cached_kernel spectrum = torch.fft.ifftshift(spectrum, dim=-1) filtered = torch.fft.ifft(spectrum, dim=-1).real return filtered.to(dtype=original_dtype) def stability_score(self, original: torch.Tensor, filtered: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: """ Compute token stability scores. Higher score = more stable token (less affected by high-frequency removal) Formula from the paper: score = original / |filtered - original| """ diff = filtered - original return original / (torch.abs(diff) + eps) def forward_features(self, x: torch.Tensor): """ Forward pass with token selection. 
Args: x: Input images [B, 3, H, W] Returns: logits: Classification logits [B, num_classes] cls_token: The aggregated CLS token [B, hidden_dim] (optional) """ # Standard ViT preprocessing and encoding x = self._process_input(x) n = x.shape[0] # Add class token batch_class_token = self.class_token.expand(n, -1, -1) x = torch.cat([batch_class_token, x], dim=1) x = self.encoder(x) # Patch tokens only (drop CLS token) patch_tokens = x[:, 1:] # [B, N, D] # Apply low-pass filtering filtered_tokens = self.low_pass_filter(patch_tokens) # Compute stability scores scores = self.stability_score(patch_tokens, filtered_tokens) # Select top-k most stable tokens top_k = min(self.top_k, patch_tokens.shape[1]) _, indices = torch.topk(scores, k=top_k, dim=1, largest=True) # Gather selected tokens selected_tokens = torch.gather(patch_tokens, 1, indices) # [B, k, D] # Average to form new CLS token cls_token = torch.mean(selected_tokens, dim=1) # [B, D] return cls_token, patch_tokens def forward(self, x: torch.Tensor): cls_token, _ = self.forward_features(x) # Classification head logits = self.heads(cls_token) return logits, cls_token def create_last_vit( pretrained_path: str = None, top_k: int = 1, num_classes: int = 1000, device: str = 'cuda' if torch.cuda.is_available() else 'cpu' ) -> LASTViT: """ Create a LAST-ViT model with optional pretrained weights. Args: pretrained_path: Path to pretrained checkpoint (from GitHub releases) top_k: Number of tokens to select (k=1 for standard LAST-ViT) num_classes: Number of output classes device: Device to load model on Returns: LASTViT model """ model = LASTViT( image_size=224, patch_size=16, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, num_classes=num_classes, top_k=top_k ) if pretrained_path: checkpoint = torch.load(pretrained_path, map_location='cpu') # Handle different checkpoint formats if isinstance(checkpoint, dict): state_dict = checkpoint.get('model', checkpoint) else: state_dict = checkpoint # Remove common prefixes new_state_dict = {} for key, value in state_dict.items(): new_key = key for prefix in ['module.', 'model.']: if new_key.startswith(prefix): new_key = new_key[len(prefix):] new_state_dict[new_key] = value # Load weights missing, unexpected = model.load_state_dict(new_state_dict, strict=False) print(f"Loaded pretrained weights from {pretrained_path}") if missing: print(f" Missing keys: {missing[:5]}..." if len(missing) > 5 else f" Missing keys: {missing}") if unexpected: print(f" Unexpected keys: {unexpected[:5]}..." if len( unexpected) > 5 else f" Unexpected keys: {unexpected}") model = model.to(device) return model if __name__ == "__main__": # Test the model print("LAST-ViT Model Test") device = 'cuda' if torch.cuda.is_available() else 'cpu' print(f"Device: {device}") # Create model model = create_last_vit('ViT_190k.pth', top_k=1) model.eval() print(f"Model created with {sum(p.numel() for p in model.parameters()):,} parameters") # Test forward pass dummy_input = torch.randn(2, 3, 224, 224).to(device) with torch.no_grad(): logits, cls_token = model(dummy_input) features = model.forward_features(dummy_input) print(f"Input shape: {dummy_input.shape}") print(f" Logits shape: {logits.shape}") print(f" CLS token shape: {cls_token.shape}") print(f"Expected: logits [2, 1000], cls_token [2, 768]") if isinstance(features, tuple): cls_feat, patch_feat = features print(f"\nFeature extraction:") print(f" CLS feature: {cls_feat.shape}") print(f" Patch feature: {patch_feat.shape}") print("\n✓ Model works correctly!")

Note that feature extraction is decoupled from the classification head here, so this component could plausibly be used as a standalone feature extractor; label-supervised pretrained weights are available for download from the repository.
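
A hedged usage sketch of that idea, reusing create_last_vit from the code above (the checkpoint name ViT_190k.pth is taken from the test code; everything else is illustrative):

import torch

model = create_last_vit(pretrained_path="ViT_190k.pth", top_k=1)
model.eval()

images = torch.randn(4, 3, 224, 224).to(next(model.parameters()).device)
with torch.no_grad():
    cls_token, patch_tokens = model.forward_features(images)

print(cls_token.shape)     # [4, 768]: global descriptor for retrieval / linear probes
print(patch_tokens.shape)  # [4, 196, 768]: dense features for segmentation-style heads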

Visualizing the Effect

The paper's visualizations show the effect directly. Comparing ConvNet, ViT, and LaSt-ViT, the high-vote regions of LaSt-ViT cover the foreground objects, the food and animals, very precisely. A standard ViT's Patch Scores tend to scatter across cluttered background regions and are easily disturbed by noise, whereas LaSt-ViT's voted regions stay focused on the subject, markedly improving recognition robustness.

Summary of Advantages

LazyStrike needs no extra annotations, leaves the original ViT architecture untouched, and takes effect only during pretraining. It works across all three paradigms, fully supervised, text-supervised, and self-supervised, and adds zero overhead at inference. One change yields stable gains on every task.

Experimental Results: Gains Across All 12 Benchmarks

LaSt-ViT is validated comprehensively across three mainstream training paradigms (fully supervised, text-supervised, self-supervised), several mainstream models (DeiT, ViT, CLIP, DINO), and 12 different downstream benchmarks, achieving significant and consistent performance gains.

On fully supervised ViT, coarse segmentation on VOC12 leaps from 22.3% to 32.8%, effectively conferring an object-localization ability the baseline lacked;

On text-supervised CLIP, semantic segmentation on VOC20 jumps from 49.0 to 75.0, going from merely usable to SOTA-level;

On self-supervised DINO, unsupervised object discovery reaches 67.6 without using any registers.

A PCA visualization contrasts a vanilla ViT with a LazyStrike-equipped one: the former cannot cleanly separate foreground from background, while the latter not only separates them but also distinguishes the individual parts of an object.
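
A minimal sketch of how such a PCA visualization can be produced from patch tokens (the 3-component RGB mapping is the common convention for this kind of figure, not necessarily the paper's exact script; pca_rgb is a hypothetical helper):

import torch
import matplotlib.pyplot as plt

def pca_rgb(patch_tokens, grid=14):
    """Project [N, D] patch features onto their top-3 principal components
    and render them as an RGB image over the patch grid."""
    x = patch_tokens - patch_tokens.mean(dim=0, keepdim=True)
    _, _, v = torch.pca_lowrank(x, q=3)                  # top-3 principal directions
    rgb = x @ v                                          # [N, 3]
    rgb = (rgb - rgb.min(0).values) / (rgb.max(0).values - rgb.min(0).values + 1e-6)
    return rgb.reshape(grid, grid, 3).cpu().numpy()

# e.g. with the LASTViT model defined earlier:
#   _, patch_tokens = model.forward_features(images)
#   plt.imshow(pca_rgb(patch_tokens[0])); plt.axis("off"); plt.savefig("pca.png")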

Visualizing Spatial Selection Patterns

""" Visualize which token positions are most frequently selected by the LAST-ViT frequency-domain token selection mechanism across different k values. Use your own image folder instead of ImageNet. """ import os from collections import defaultdict import torch import torch.nn as nn import numpy as np import matplotlib matplotlib.use("Agg") # 保存图片用,避免 TkAgg 报错 import matplotlib.pyplot as plt from PIL import Image from scipy import ndimage from torchvision.models.vision_transformer import VisionTransformer from torchvision.transforms import transforms as T from torch.utils.data import DataLoader, Dataset from tqdm import tqdm class CustomImageDataset(Dataset): """Dataset for user's own image folder.""" def __init__(self, root_dir, transform=None, max_samples=None): self.root_dir = root_dir self.transform = transform self.image_paths = [] valid_exts = (".jpg", ".jpeg", ".png", ".bmp", ".webp") for dirpath, _, filenames in os.walk(root_dir): for name in filenames: if name.lower().endswith(valid_exts): self.image_paths.append(os.path.join(dirpath, name)) self.image_paths = sorted(self.image_paths) if max_samples is not None: self.image_paths = self.image_paths[:max_samples] if len(self.image_paths) == 0: raise RuntimeError(f"No images found in: {root_dir}") print(f"Found {len(self.image_paths)} images in: {root_dir}") def __len__(self): return len(self.image_paths) def __getitem__(self, idx): path = self.image_paths[idx] img = Image.open(path).convert("RGB") if self.transform is not None: img = self.transform(img) label = 0 return img, label, path class DenseViTWithTracking(VisionTransformer): """ViT with token selection tracking for visualization.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.cached_kernel = None self.token_selection_counts = defaultdict(lambda: defaultdict(int)) self.enable_tracking = False self.k_values = [1, 5, 10, 20] def gaussian_kernel_1d(self, kernel_size, sigma): device = self.class_token.device x = torch.arange( -kernel_size // 2 + 1, kernel_size // 2 + 1, device=device, dtype=torch.float32, ) kernel = torch.exp(-0.5 * (x / sigma) ** 2) kernel = kernel / torch.max(kernel) return kernel def compute_token_scores(self, images): """ Return: diff: [B, N], token frequency-domain difference score x_detach: [B, N, D], patch token embeddings """ x = self._process_input(images) n = x.shape[0] batch_class_token = self.class_token.expand(n, -1, -1) x = torch.cat([batch_class_token, x], dim=1) x = self.encoder(x) x_detach = x[:, 1:] # [B, 196, 768] hidden_dim = x_detach.shape[-1] if self.cached_kernel is None or self.cached_kernel.shape[-1] != hidden_dim: self.cached_kernel = ( self.gaussian_kernel_1d(hidden_dim, hidden_dim ** 0.5) .to(x.device) .unsqueeze(0) .unsqueeze(0) ) x_fft = torch.fft.fft(x_detach, dim=-1) x_fft = torch.fft.fftshift(x_fft, dim=-1) x_fft = x_fft * self.cached_kernel.to(x.device) x_fft = torch.fft.ifftshift(x_fft, dim=-1) x_filtered = torch.fft.ifft(x_fft, dim=-1).real diff = torch.norm(x_detach - x_filtered, dim=-1) # [B, N] return diff, x_detach def forward(self, x: torch.Tensor): diff, x_detach = self.compute_token_scores(x) if self.enable_tracking: for k in self.k_values: if k <= diff.shape[1]: _, indices = torch.topk(diff, k=k, dim=1, largest=True) for b in range(indices.shape[0]): selected_tokens = indices[b].detach().cpu().numpy() for token_idx in selected_tokens: self.token_selection_counts[k][int(token_idx)] += 1 _, indices = torch.topk(diff, k=1, dim=1, largest=True) selected_tokens = torch.gather( x_detach, dim=1, 
index=indices.unsqueeze(-1).expand(-1, -1, x_detach.shape[-1]), ) cls_token = torch.mean(selected_tokens, dim=1) return cls_token, None def get_topk_mask_for_image(self, image_tensor, k, device): """ Get current image's selected token mask. image_tensor: [3, 224, 224] """ self.eval() with torch.no_grad(): if image_tensor.dim() == 3: image_tensor = image_tensor.unsqueeze(0) image_tensor = image_tensor.to(device) diff, _ = self.compute_token_scores(image_tensor) k = min(k, diff.shape[1]) _, indices = torch.topk(diff, k=k, dim=1, largest=True) mask = np.zeros(diff.shape[1], dtype=np.float32) for token_idx in indices[0].detach().cpu().numpy(): mask[int(token_idx)] = 1.0 return mask def load_model_and_data(data_root, num_samples=1000, batch_size=32, checkpoint_path=None): """Load model and user's custom image data.""" model = DenseViTWithTracking( image_size=224, patch_size=16, num_layers=12, num_heads=12, hidden_dim=768, mlp_dim=3072, ) model.eval() if checkpoint_path and os.path.exists(checkpoint_path): print(f"Loading pretrained weights: {checkpoint_path}") checkpoint = torch.load(checkpoint_path, map_location="cpu") state_dict = checkpoint if isinstance(checkpoint, dict): if "model" in checkpoint: state_dict = checkpoint["model"] elif "state_dict" in checkpoint: state_dict = checkpoint["state_dict"] new_state_dict = {} for key, value in state_dict.items(): if key.startswith("model."): new_state_dict[key[6:]] = value elif key.startswith("module."): new_state_dict[key[7:]] = value else: new_state_dict[key] = value try: model.load_state_dict(new_state_dict, strict=True) print("Pretrained weights loaded successfully.") except Exception as e: print(f"Strict loading failed: {e}") print("Trying partial loading...") model_dict = model.state_dict() matched_dict = { k: v for k, v in new_state_dict.items() if k in model_dict and model_dict[k].shape == v.shape } model_dict.update(matched_dict) model.load_state_dict(model_dict) print(f"Partial load succeeded: {len(matched_dict)} / {len(model_dict)} parameters matched.") else: raise FileNotFoundError( f"Checkpoint not found: {checkpoint_path}\n" f"请确认 --checkpoint 路径正确。为了避免噪声,这里不再使用随机权重。" ) transform = T.Compose( [ T.Resize(256), T.CenterCrop(224), T.ToTensor(), T.Normalize( mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), ), ] ) dataset = CustomImageDataset( root_dir=data_root, transform=transform, max_samples=num_samples, ) dataloader = DataLoader( dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, ) return model, dataloader def visualize_token_selection(token_counts_dict, num_tokens=196, save_path="token_selection_heatmap.png"): """Visualize global token selection patterns for different k values.""" k_values = sorted(token_counts_dict.keys()) num_k = len(k_values) if num_k == 0: print("No token selection counts found.") return cols = min(3, num_k) rows = (num_k + cols - 1) // cols fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows)) if num_k == 1: axes = [axes] else: axes = np.array(axes).reshape(-1) grid_size = int(np.sqrt(num_tokens)) for idx, k in enumerate(k_values): counts = token_counts_dict[k] token_freq = np.zeros(num_tokens, dtype=np.float32) for token_idx, count in counts.items(): if 0 <= token_idx < num_tokens: token_freq[token_idx] = count if token_freq.max() > 0: token_freq_normalized = token_freq / token_freq.max() else: token_freq_normalized = token_freq token_grid = token_freq_normalized.reshape(grid_size, grid_size) im = axes[idx].imshow( token_grid, cmap="hot", interpolation="nearest", 
vmin=0, vmax=1, ) axes[idx].set_title( f"k={k} Token Selection Frequency\nTotal selections: {int(token_freq.sum())}", fontsize=12, fontweight="bold", ) axes[idx].set_xlabel("Patch Column") axes[idx].set_ylabel("Patch Row") plt.colorbar(im, ax=axes[idx], label="Normalized count") for idx in range(num_k, len(axes)): axes[idx].axis("off") plt.tight_layout() plt.savefig(save_path, dpi=300, bbox_inches="tight") plt.close() print(f"Global heatmap saved to: {save_path}") stats_path = save_path.replace(".png", "_stats.png") fig, ax = plt.subplots(figsize=(14, 7)) top_n = 30 for k in k_values: counts = token_counts_dict[k] sorted_tokens = sorted(counts.items(), key=lambda x: x[1], reverse=True)[:top_n] if len(sorted_tokens) == 0: continue token_indices = [t[0] for t in sorted_tokens] token_counts_vals = [t[1] for t in sorted_tokens] ax.plot( token_indices, token_counts_vals, marker="o", label=f"k={k}", linewidth=2, markersize=5, ) ax.set_xlabel("Token Index") ax.set_ylabel("Selection Count") ax.set_title(f"Top {top_n} Most Selected Tokens") ax.legend() ax.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(stats_path, dpi=300, bbox_inches="tight") plt.close() print(f"Statistics plot saved to: {stats_path}") def denormalize_image(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): mean = torch.tensor(mean).view(3, 1, 1) std = torch.tensor(std).view(3, 1, 1) return tensor * std + mean def visualize_mask_on_image(image_tensor, mask, image_size=224): """Overlay selected patch mask on original image.""" img = denormalize_image(image_tensor.clone()) img = img.clamp(0, 1) img_np = img.permute(1, 2, 0).cpu().numpy() grid_size = int(np.sqrt(mask.size)) mask_2d = mask.reshape(grid_size, grid_size) mask_img = np.zeros((image_size, image_size), dtype=bool) patch_h = image_size // grid_size patch_w = image_size // grid_size for i in range(grid_size): for j in range(grid_size): if mask_2d[i, j] > 0: h_start = i * patch_h h_end = min((i + 1) * patch_h, image_size) w_start = j * patch_w w_end = min((j + 1) * patch_w, image_size) mask_img[h_start:h_end, w_start:w_end] = True result = img_np.copy() mask_3d = mask_img[:, :, np.newaxis] red = np.array([1.0, 0.25, 0.25]) fill_alpha = 0.45 result = result * (1 - fill_alpha * mask_3d) + red * fill_alpha * mask_3d edges = ndimage.sobel(mask_img.astype(float)) edge_mask = np.abs(edges) > 0.1 result[edge_mask, 0] = 1.0 result[edge_mask, 1] = 0.0 result[edge_mask, 2] = 0.0 return np.clip(result, 0, 1) def save_sample_visualizations(model, images, paths, output_dir, device, start_index, max_save=10): """Save per-image token masks using current image top-k tokens.""" saved = 0 num_tokens = (224 // 16) ** 2 for img_idx in range(images.shape[0]): if start_index + saved >= max_save: break sample_img = images[img_idx].cpu() image_path = paths[img_idx] fig, axes = plt.subplots( 1, len(model.k_values), figsize=(4 * len(model.k_values), 4), ) if len(model.k_values) == 1: axes = [axes] for idx, k in enumerate(model.k_values): mask = model.get_topk_mask_for_image(sample_img, k, device) assert mask.size == num_tokens img_with_mask = visualize_mask_on_image(sample_img, mask) axes[idx].imshow(img_with_mask) axes[idx].set_title(f"k={k}", fontsize=12, fontweight="bold") axes[idx].axis("off") basename = os.path.splitext(os.path.basename(image_path))[0] save_name = f"sample_{start_index + saved + 1:03d}_{basename}.png" save_path = os.path.join(output_dir, save_name) plt.tight_layout() plt.savefig(save_path, dpi=150, bbox_inches="tight") plt.close() print(f"Saved sample mask: 
{save_path}") saved += 1 return saved def main(): import argparse parser = argparse.ArgumentParser(description="Visualize LAST-ViT token selection patterns on custom images") parser.add_argument( "--data-root", type=str, default=r"E:\PythonProject\YoloProject\data\test_coco8\images", help="Path to your own image folder", ) parser.add_argument( "--num-samples", type=int, default=100, help="Number of images to use", ) parser.add_argument( "--batch-size", type=int, default=8, help="Batch size", ) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device: cuda or cpu", ) parser.add_argument( "--output-dir", type=str, default="./visualize", help="Output directory", ) parser.add_argument( "--checkpoint", type=str, default=r"E:\PythonProject\LAST_ViT\ViT_190k.pth", help="Path to pretrained weights", ) parser.add_argument( "--max-sample-vis", type=int, default=10, help="Number of individual images to visualize", ) args = parser.parse_args() os.makedirs(args.output_dir, exist_ok=True) print("=" * 70) print("LAST-ViT Token Selection Visualization") print("=" * 70) print(f"Data root: {args.data_root}") print(f"Checkpoint: {args.checkpoint}") print(f"Device: {args.device}") print(f"Num samples: {args.num_samples}") print(f"Batch size: {args.batch_size}") print(f"Output dir: {args.output_dir}") print("=" * 70) print("\nLoading model and custom data...") model, dataloader = load_model_and_data( data_root=args.data_root, num_samples=args.num_samples, batch_size=args.batch_size, checkpoint_path=args.checkpoint, ) model = model.to(args.device) model.eval() model.enable_tracking = True num_tokens = (224 // 16) ** 2 print(f"\nModel loaded.") print(f"Patch tokens: {num_tokens}") print(f"k values: {model.k_values}") print("\nRunning inference and tracking token selections...") saved_sample_count = 0 with torch.no_grad(): for batch_idx, batch in enumerate(tqdm(dataloader, desc="Processing")): images, labels, paths = batch images = images.to(args.device, non_blocking=True) try: _ = model(images) if saved_sample_count < args.max_sample_vis: saved_now = save_sample_visualizations( model=model, images=images, paths=paths, output_dir=args.output_dir, device=args.device, start_index=saved_sample_count, max_save=args.max_sample_vis, ) saved_sample_count += saved_now except Exception as e: print(f"Error processing batch {batch_idx}: {e}") continue print("\nToken Selection Statistics:") print("-" * 70) for k in sorted(model.token_selection_counts.keys()): counts = model.token_selection_counts[k] total_selections = sum(counts.values()) unique_tokens = len(counts) print(f"k={k:2d}: total selections={total_selections:8d}, unique tokens={unique_tokens:3d}") top_5 = sorted(counts.items(), key=lambda x: x[1], reverse=True)[:5] print(f" top 5 tokens: {top_5}") print("\nGenerating global statistics visualization...") save_path = os.path.join(args.output_dir, "token_selection_heatmap.png") visualize_token_selection( model.token_selection_counts, num_tokens=num_tokens, save_path=save_path, ) print("\n" + "=" * 70) print("Done.") print("=" * 70) if __name__ == "__main__": main()

Conclusion

Through systematic analysis, this work pinpoints the root cause of ViT artifacts, namely "lazy aggregation" that relies on background patches as a semantic shortcut, and proposes a novel method called LaSt-ViT. With its frequency-aware selective aggregation strategy, the method removes ViT feature artifacts at the source and achieves consistent performance improvements across multiple training paradigms and downstream tasks.
