Git-RSCLIP遥感图像分类代码实例:Python调用API实现批量推理
Git-RSCLIP遥感图像分类代码实例:Python调用API实现批量推理
1. 项目概述与核心价值
Git-RSCLIP是北京航空航天大学团队基于SigLIP架构专门为遥感图像场景开发的先进模型。这个模型在Git-10M数据集上进行了大规模预训练,该数据集包含1000万对高质量的遥感图像和文本描述,使其在遥感图像理解方面表现出色。
为什么选择Git-RSCLIP?
- 零样本分类能力:无需额外训练,直接使用自定义标签进行分类
- 遥感场景优化:专门针对卫星图像、航拍图像等遥感数据优化
- 多语言支持:虽然英文效果最佳,但也支持中文描述
- 高效推理:支持GPU加速,处理速度快
在实际应用中,Git-RSCLIP可以帮我们快速实现:
- 批量自动标注遥感图像
- 根据文本描述检索相似图像
- 多类别图像分类排序
- 遥感场景智能分析
2. 环境准备与模型部署
2.1 基础环境要求
确保你的环境满足以下要求:
# Python版本要求 Python >= 3.8 # 主要依赖库 torch >= 1.9.0 transformers >= 4.20.0 Pillow >= 9.0.0 requests >= 2.28.0 numpy >= 1.21.02.2 快速安装依赖
pip install torch transformers Pillow requests numpy2.3 模型加载与初始化
import torch from transformers import AutoProcessor, AutoModel from PIL import Image import requests from io import BytesIO import numpy as np class GitRSCLIP: def __init__(self, device=None): """ 初始化Git-RSCLIP模型 """ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") print(f"使用设备: {self.device}") # 加载模型和处理器 self.model_name = "git_rsclip" self.processor = AutoProcessor.from_pretrained(self.model_name) self.model = AutoModel.from_pretrained(self.model_name).to(self.device) print("模型加载完成!") # 初始化模型 git_rsclip = GitRSCLIP()3. 核心功能实现
3.1 单张图像分类实现
def classify_single_image(image_path, candidate_labels): """ 对单张遥感图像进行分类 参数: image_path: 图像路径或URL candidate_labels: 候选标签列表 返回: 分类结果排序列表 """ # 加载图像 if image_path.startswith('http'): response = requests.get(image_path) image = Image.open(BytesIO(response.content)) else: image = Image.open(image_path) # 准备输入 inputs = processor( text=candidate_labels, images=image, return_tensors="pt", padding=True, truncation=True ).to(device) # 模型推理 with torch.no_grad(): outputs = model(**inputs) # 计算相似度 logits_per_image = outputs.logits_per_image probs = logits_per_image.softmax(dim=1) # 整理结果 results = [] for i, label in enumerate(candidate_labels): results.append({ "label": label, "score": probs[0][i].item(), "confidence": f"{probs[0][i].item()*100:.2f}%" }) # 按置信度排序 results.sort(key=lambda x: x["score"], reverse=True) return results # 使用示例 candidate_labels = [ "a remote sensing image of river", "a remote sensing image of buildings and roads", "a remote sensing image of forest", "a remote sensing image of farmland", "a remote sensing image of airport" ] image_path = "https://example.com/satellite_image.jpg" results = classify_single_image(image_path, candidate_labels) print("分类结果:") for i, result in enumerate(results, 1): print(f"{i}. {result['label']} - {result['confidence']}")3.2 批量图像处理实现
import os from concurrent.futures import ThreadPoolExecutor import pandas as pd def batch_process_images(image_folder, candidate_labels, output_file="results.csv"): """ 批量处理文件夹中的遥感图像 参数: image_folder: 图像文件夹路径 candidate_labels: 候选标签列表 output_file: 输出结果文件 """ # 获取所有图像文件 image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff'] image_files = [] for file in os.listdir(image_folder): if any(file.lower().endswith(ext) for ext in image_extensions): image_files.append(os.path.join(image_folder, file)) print(f"找到 {len(image_files)} 张图像") # 批量处理 all_results = [] def process_single_image(image_path): try: results = classify_single_image(image_path, candidate_labels) best_result = results[0] # 取置信度最高的结果 return { "image_name": os.path.basename(image_path), "predicted_label": best_result["label"], "confidence": best_result["confidence"], "all_scores": {r["label"]: r["score"] for r in results} } except Exception as e: print(f"处理图像 {image_path} 时出错: {e}") return None # 使用多线程加速处理 with ThreadPoolExecutor(max_workers=4) as executor: results = list(executor.map(process_single_image, image_files)) # 过滤掉失败的结果 all_results = [r for r in results if r is not None] # 保存结果到CSV df = pd.DataFrame(all_results) df.to_csv(output_file, index=False, encoding='utf-8-sig') print(f"处理完成!结果已保存到 {output_file}") return df # 使用示例 image_folder = "/path/to/your/satellite/images" batch_results = batch_process_images(image_folder, candidate_labels)3.3 图文相似度计算
def calculate_image_text_similarity(image_path, text_descriptions): """ 计算图像与多个文本描述的相似度 参数: image_path: 图像路径或URL text_descriptions: 文本描述列表 返回: 相似度得分列表 """ # 加载图像 if image_path.startswith('http'): response = requests.get(image_path) image = Image.open(BytesIO(response.content)) else: image = Image.open(image_path) # 准备输入 inputs = processor( text=text_descriptions, images=image, return_tensors="pt", padding=True, truncation=True ).to(device) # 模型推理 with torch.no_grad(): outputs = model(**inputs) # 获取相似度分数 similarity_scores = outputs.logits_per_image.cpu().numpy() results = [] for i, text in enumerate(text_descriptions): results.append({ "text_description": text, "similarity_score": float(similarity_scores[0][i]), "normalized_score": float(torch.sigmoid(torch.tensor(similarity_scores[0][i])).numpy()) }) # 按相似度排序 results.sort(key=lambda x: x["similarity_score"], reverse=True) return results # 使用示例 image_path = "https://example.com/urban_area.jpg" text_descriptions = [ "urban area with high buildings", "rural farmland with crops", "forest area with dense trees", "water body like river or lake" ] similarity_results = calculate_image_text_similarity(image_path, text_descriptions) print("图文相似度结果:") for result in similarity_results: print(f"{result['text_description']}: {result['normalized_score']:.4f}")4. 实战应用案例
4.1 遥感图像自动标注系统
class RemoteSensingAutoLabeler: def __init__(self, predefined_categories): self.predefined_categories = predefined_categories self.git_rsclip = GitRSCLIP() def auto_label_dataset(self, dataset_path, output_dir): """ 自动标注整个数据集 """ if not os.path.exists(output_dir): os.makedirs(output_dir) # 获取所有图像 image_files = [] for root, _, files in os.walk(dataset_path): for file in files: if file.lower().endswith(('.jpg', '.jpeg', '.png')): image_files.append(os.path.join(root, file)) print(f"开始自动标注 {len(image_files)} 张图像...") results = [] for i, image_path in enumerate(image_files): if i % 100 == 0: print(f"已处理 {i}/{len(image_files)} 张图像") try: classification = classify_single_image(image_path, self.predefined_categories) best_label = classification[0]['label'] # 创建标签对应的文件夹 label_dir = os.path.join(output_dir, best_label) if not os.path.exists(label_dir): os.makedirs(label_dir) # 复制图像到对应文件夹 import shutil shutil.copy2(image_path, os.path.join(label_dir, os.path.basename(image_path))) results.append({ 'image_path': image_path, 'predicted_label': best_label, 'confidence': classification[0]['confidence'] }) except Exception as e: print(f"处理 {image_path} 时出错: {e}") # 保存标注结果 df = pd.DataFrame(results) df.to_csv(os.path.join(output_dir, 'labeling_report.csv'), index=False) return df # 使用示例 predefined_categories = [ "urban residential area", "commercial district", "agricultural farmland", "forest vegetation", "water body", "industrial area", "transportation infrastructure" ] labeler = RemoteSensingAutoLabeler(predefined_categories) labeling_results = labeler.auto_label_dataset("/path/to/raw/images", "/path/to/labeled/dataset")4.2 批量图像检索系统
class ImageRetrievalSystem: def __init__(self): self.git_rsclip = GitRSCLIP() self.image_database = {} def build_database(self, image_folder): """ 构建图像特征数据库 """ image_files = [] for file in os.listdir(image_folder): if file.lower().endswith(('.jpg', '.jpeg', '.png')): image_files.append(os.path.join(image_folder, file)) print(f"正在提取 {len(image_files)} 张图像的特征...") for image_path in image_files: try: # 提取图像特征 image = Image.open(image_path) inputs = processor(images=image, return_tensors="pt").to(device) with torch.no_grad(): image_features = model.get_image_features(**inputs) self.image_database[image_path] = { 'features': image_features.cpu().numpy(), 'filename': os.path.basename(image_path) } except Exception as e: print(f"处理 {image_path} 时出错: {e}") def search_similar_images(self, query_image_path, top_k=5): """ 根据查询图像搜索相似图像 """ # 提取查询图像特征 query_image = Image.open(query_image_path) inputs = processor(images=query_image, return_tensors="pt").to(device) with torch.no_grad(): query_features = model.get_image_features(**inputs) # 计算相似度 similarities = [] for img_path, data in self.image_database.items(): similarity = torch.cosine_similarity( query_features.cpu(), torch.tensor(data['features']) ).item() similarities.append((img_path, similarity, data['filename'])) # 按相似度排序 similarities.sort(key=lambda x: x[1], reverse=True) return similarities[:top_k] # 使用示例 retrieval_system = ImageRetrievalSystem() retrieval_system.build_database("/path/to/image/database") # 搜索相似图像 query_image = "/path/to/query/image.jpg" similar_images = retrieval_system.search_similar_images(query_image, top_k=10) print("最相似的图像:") for i, (img_path, similarity, filename) in enumerate(similar_images, 1): print(f"{i}. {filename} - 相似度: {similarity:.4f}")5. 性能优化与最佳实践
5.1 批量处理优化技巧
def optimized_batch_processing(image_paths, candidate_labels, batch_size=8): """ 优化后的批量处理函数,支持批处理加速 """ results = [] # 分批处理 for i in range(0, len(image_paths), batch_size): batch_paths = image_paths[i:i+batch_size] batch_images = [] # 批量加载图像 for path in batch_paths: try: image = Image.open(path) batch_images.append(image) except: batch_images.append(None) # 批量处理 valid_indices = [idx for idx, img in enumerate(batch_images) if img is not None] valid_images = [img for img in batch_images if img is not None] if valid_images: # 准备批量输入 inputs = processor( text=candidate_labels, images=valid_images, return_tensors="pt", padding=True, truncation=True ).to(device) # 批量推理 with torch.no_grad(): outputs = model(**inputs) # 处理结果 logits_per_image = outputs.logits_per_image probs = logits_per_image.softmax(dim=1) for j, idx in enumerate(valid_indices): image_results = [] for k, label in enumerate(candidate_labels): image_results.append({ "label": label, "score": probs[j][k].item() }) image_results.sort(key=lambda x: x["score"], reverse=True) results.append({ "image_path": batch_paths[idx], "predictions": image_results }) print(f"已处理 {min(i+batch_size, len(image_paths))}/{len(image_paths)} 张图像") return results5.2 内存优化策略
class MemoryEfficientGitRSCLIP: def __init__(self, device=None): self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") self.processor = AutoProcessor.from_pretrained("git_rsclip") # 使用低精度推理节省内存 self.model = AutoModel.from_pretrained( "git_rsclip", torch_dtype=torch.float16 if self.device == "cuda" else torch.float32 ).to(self.device) self.model.eval() # 设置为评估模式 def process_with_memory_optimization(self, image_path, candidate_labels): """ 内存优化的处理函数 """ # 使用梯度关闭和推理模式 with torch.no_grad(), torch.inference_mode(): image = Image.open(image_path) # 调整图像大小减少内存占用 if max(image.size) > 512: image.thumbnail((512, 512)) inputs = processor( text=candidate_labels, images=image, return_tensors="pt", padding=True, truncation=True ).to(self.device) # 使用更小的批处理大小 outputs = model(**inputs) logits_per_image = outputs.logits_per_image probs = logits_per_image.softmax(dim=1) return [ {"label": label, "score": probs[0][i].item()} for i, label in enumerate(candidate_labels) ]6. 总结与建议
6.1 技术要点回顾
通过本文的代码实例,我们实现了Git-RSCLIP模型的多种应用方式:
- 单图像分类:快速对单张遥感图像进行分类
- 批量处理:高效处理大量图像数据
- 图文检索:实现图像与文本的相似度计算
- 自动标注:构建完整的自动标注流水线
- 图像检索:建立基于内容的图像检索系统
6.2 最佳实践建议
标签设计技巧:
- 使用英文描述,效果优于中文
- 描述要具体明确,如"a remote sensing image of dense urban area"比"city"更好
- 包含场景上下文信息,如时间、季节、地理特征等
性能优化建议:
- 使用批处理提高GPU利用率
- 对大量数据采用多线程处理
- 调整图像尺寸平衡精度和速度
- 使用混合精度训练节省内存
应用场景扩展:
- 结合GIS系统进行空间分析
- 与时间序列数据结合进行变化检测
- 集成到自动化工作流中实现端到端处理
Git-RSCLIP为遥感图像分析提供了强大的零样本分类能力,通过合理的代码实现和优化,可以在实际项目中发挥重要作用。建议根据具体需求选择合适的实现方式,并不断优化调整以获得最佳效果。
获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。
