问答系统:从检索到生成式模型
1. 技术分析
1.1 问答系统类型
问答系统可分为多种类型:
问答系统分类:
- 检索式: 从知识库中检索答案
- 抽取式: 从文本中抽取答案片段
- 生成式: 直接生成答案
- 多模态: 结合文本和视觉
1.2 问答系统架构对比
| 类型 | 架构 | 特点 | 代表模型 |
|---|---|---|---|
| 检索式 | TF-IDF/BM25 | 简单快速 | Elasticsearch |
| 抽取式 | BERT | 准确 | BERT-QA |
| 生成式 | T5/GPT | 灵活 | T5-QA |
| 多模态 | ViLT | 多模态 | ViLT-QA |
1.3 QA 任务类型
QA 任务分类:
- SQuAD: 抽取式问答
- HotpotQA: 多跳问答
- TriviaQA: 开放域问答
- VQA: 视觉问答
2. 核心功能实现
2.1 检索式问答
import torch
import torch.nn as nn
import numpy as np
from rank_bm25 import BM25Okapi


class RetrievalQA:
    """Sparse retrieval-based QA: ranks a fixed document list with BM25."""

    def __init__(self, documents):
        self.documents = documents
        # Whitespace tokenization, lowercased so query/document matching
        # is case-insensitive.
        self.tokenized_docs = [doc.lower().split() for doc in documents]
        self.bm25 = BM25Okapi(self.tokenized_docs)

    def retrieve(self, query, top_k=5):
        """Return up to ``top_k`` ``(document, score)`` pairs, best first."""
        tokenized_query = query.lower().split()
        scores = self.bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]
        return [(self.documents[i], scores[i]) for i in top_indices]

    def answer(self, query, top_k=1):
        """Return the single best-matching document, or ``None`` if the
        corpus is empty."""
        results = self.retrieve(query, top_k)
        return results[0][0] if results else None


class DenseRetrieval(nn.Module):
    """Dense retrieval scoring documents by dot product of BERT [CLS]
    embeddings."""

    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()
        from transformers import BertModel, BertTokenizer
        self.model = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def encode(self, texts):
        """Encode a list of strings into [CLS] embeddings of shape
        ``(len(texts), hidden_size)``.

        Fix: the original ran the encoder in training mode with autograd
        enabled, so dropout made embeddings nondeterministic and a gradient
        graph was retained needlessly.
        """
        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        )
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(**inputs)
        # [CLS] token embedding as the sentence representation.
        return outputs.last_hidden_state[:, 0, :]

    def retrieve(self, query, documents, top_k=5):
        """Return up to ``top_k`` ``(document, score)`` pairs ranked by
        (unnormalized) dot-product similarity.

        NOTE(review): documents are re-encoded on every call; callers with a
        static corpus should cache ``encode(documents)`` — TODO confirm usage.
        """
        query_embedding = self.encode([query])
        doc_embeddings = self.encode(documents)
        scores = torch.matmul(query_embedding, doc_embeddings.T).squeeze(0)
        top_indices = torch.argsort(scores, descending=True)[:top_k]
        return [(documents[i], scores[i].item()) for i in top_indices]
2.2 抽取式问答
class ExtractiveQA(nn.Module):
    """Extractive QA: predicts an answer span inside a given context with a
    BERT QA head."""

    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()
        from transformers import BertForQuestionAnswering, BertTokenizer
        self.model = BertForQuestionAnswering.from_pretrained(model_name)
        # Fix: build the tokenizer once here instead of on every predict()
        # call, and keep it consistent with model_name — the original
        # hard-coded 'bert-base-uncased' inside predict() even when a
        # different model was loaded.
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Return ``(start_logits, end_logits)`` over the input tokens."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.start_logits, outputs.end_logits

    def predict(self, question, context):
        """Return the highest-scoring answer span from ``context`` as a
        string (may be empty if the model picks a degenerate span)."""
        inputs = self.tokenizer(
            question,
            context,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        )
        self.model.eval()
        with torch.no_grad():
            start_logits, end_logits = self.forward(
                inputs['input_ids'], inputs['attention_mask']
            )
        start_idx = int(torch.argmax(start_logits))
        # Fix: constrain the end index to start_idx or later. The original
        # took an independent argmax, which could produce end < start and an
        # invalid (empty) span.
        end_idx = start_idx + int(torch.argmax(end_logits[0, start_idx:]))
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
        return self.tokenizer.convert_tokens_to_string(tokens[start_idx:end_idx + 1])


class QAWithRetrieval:
    """Pipeline: dense retrieval narrows to candidate documents, then an
    extractive reader pulls a span from each until one is non-empty."""

    def __init__(self, documents):
        self.retriever = DenseRetrieval()
        self.extractor = ExtractiveQA()
        self.documents = documents

    def answer(self, question):
        """Return the first non-empty extracted answer from the top-3
        retrieved documents, or a fallback message."""
        candidates = self.retriever.retrieve(question, self.documents, top_k=3)
        for doc, _ in candidates:
            answer = self.extractor.predict(question, doc)
            if answer.strip():
                return answer
        return "No answer found"
2.3 生成式问答
class GenerativeQA(nn.Module):
    """Generative QA: produces a free-form answer with a T5
    sequence-to-sequence model."""

    def __init__(self, model_name='t5-base'):
        super().__init__()
        from transformers import T5ForConditionalGeneration, T5Tokenizer
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)

    def generate(self, question, context=None):
        """Generate an answer string; ``context`` is optional supporting
        text prepended in T5's "question: ... context: ..." format."""
        if context:
            input_text = f"question: {question} context: {context}"
        else:
            input_text = f"question: {question}"
        inputs = self.tokenizer(
            input_text,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt',
        )
        # Fix: switch to eval mode — the original generated with dropout
        # active (training mode), making answers nondeterministic.
        self.model.eval()
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=100,
                num_beams=5,
                early_stopping=True,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)


class OpenDomainQA:
    """Retrieve-then-generate pipeline: top documents are concatenated and
    fed to the generator as context."""

    def __init__(self, retriever, generator):
        self.retriever = retriever
        self.generator = generator

    def answer(self, question, documents):
        """Answer ``question`` using the top-3 retrieved documents as
        generation context."""
        candidates = self.retriever.retrieve(question, documents, top_k=3)
        context = "\n".join([doc for doc, _ in candidates])
        return self.generator.generate(question, context)
3. 性能对比
3.1 问答系统类型对比
| 类型 | 准确率 | 灵活性 | 训练数据 | 推理速度 |
|---|---|---|---|---|
| 检索式 | 中 | 低 | 无 | 很快 |
| 抽取式 | 高 | 中 | 中 | 中 |
| 生成式 | 高 | 很高 | 高 | 慢 |
3.2 不同 QA 数据集表现
| 数据集 | 抽取式 | 生成式 | 检索+生成 |
|---|---|---|---|
| SQuAD v1 | 92% | 88% | 90% |
| SQuAD v2 | 83% | 79% | 81% |
| HotpotQA | 75% | 72% | 78% |
3.3 模型大小影响
| 模型 | 参数 | F1 | 推理时间(ms) |
|---|---|---|---|
| BERT-base | 110M | 89% | 50 |
| BERT-large | 340M | 93% | 150 |
| T5-base | 220M | 87% | 100 |
| T5-large | 770M | 91% | 300 |
4. 最佳实践
4.1 问答系统选择
def select_qa_system(task_type, data_size):
    """Return a QA system instance for ``task_type``.

    NOTE(review): ``data_size`` is accepted but never used — presumably
    intended to steer model-size selection; confirm with callers before
    removing or wiring it in.
    """
    if task_type == 'retrieval':
        return RetrievalQA([])
    elif task_type == 'extractive':
        return ExtractiveQA()
    elif task_type == 'generative':
        return GenerativeQA()
    else:
        return QAWithRetrieval([])


class QASystemFactory:
    """Builds a QA system from a config dict with keys ``type`` and the
    type-specific fields used below."""

    @staticmethod
    def create(config):
        """Instantiate the QA system described by ``config``.

        Raises:
            ValueError: for an unrecognized ``config['type']``. (Fix: the
            original fell through and silently returned ``None``, deferring
            the failure to an opaque AttributeError at the call site.)
        """
        if config['type'] == 'retrieval':
            return RetrievalQA(config['documents'])
        elif config['type'] == 'extractive':
            return ExtractiveQA(config['model_name'])
        elif config['type'] == 'generative':
            return GenerativeQA(config['model_name'])
        elif config['type'] == 'hybrid':
            return OpenDomainQA(
                DenseRetrieval(config['retriever_model']),
                GenerativeQA(config['generator_model']),
            )
        raise ValueError(f"unknown QA system type: {config['type']!r}")
4.2 QA 系统训练流程
class QATrainer:
    """Training/evaluation loop for an extractive (span-prediction) QA model.

    ``model(input_ids, attention_mask)`` must return
    ``(start_logits, end_logits)``; ``loss_fn`` is applied to each head
    (e.g. cross-entropy over token positions).
    """

    def __init__(self, model, optimizer, scheduler, loss_fn):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_fn = loss_fn

    def train_step(self, batch):
        """Run one optimization step on ``batch``; return the scalar loss.

        The loss is the mean of the start-position and end-position losses.
        """
        self.optimizer.zero_grad()
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        start_positions = batch['start_positions']
        end_positions = batch['end_positions']
        start_logits, end_logits = self.model(input_ids, attention_mask)
        loss = (self.loss_fn(start_logits, start_positions)
                + self.loss_fn(end_logits, end_positions)) / 2
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()
        return loss.item()

    def evaluate(self, dataloader):
        """Return the mean token-overlap F1 of predicted spans over all
        examples (the standard SQuAD-style span F1).

        Fix: the original computed tp/fp/fn from scalar boolean comparisons
        that do not measure span overlap, and divided the per-example F1 sum
        by the number of *batches* instead of the number of examples.
        """
        self.model.eval()
        total_f1 = 0.0
        num_examples = 0
        with torch.no_grad():
            for batch in dataloader:
                start_logits, end_logits = self.model(
                    batch['input_ids'], batch['attention_mask']
                )
                start_pred = torch.argmax(start_logits, dim=1)
                end_pred = torch.argmax(end_logits, dim=1)
                for i in range(len(start_pred)):
                    # Token-index sets for predicted and gold spans
                    # (inclusive boundaries). An inverted prediction
                    # (end < start) yields an empty set, i.e. F1 = 0.
                    pred = set(range(int(start_pred[i]), int(end_pred[i]) + 1))
                    gold = set(range(int(batch['start_positions'][i]),
                                     int(batch['end_positions'][i]) + 1))
                    overlap = len(pred & gold)
                    precision = overlap / len(pred) if pred else 0.0
                    recall = overlap / len(gold) if gold else 0.0
                    if precision + recall > 0:
                        total_f1 += 2 * precision * recall / (precision + recall)
                    num_examples += 1
        return total_f1 / num_examples if num_examples else 0.0
5. 总结
问答系统是 NLP 的重要应用:
- 检索式:简单快速,适合小规模知识库
- 抽取式:准确,适合有上下文的场景
- 生成式:灵活,可生成自然语言答案
- 混合式:结合检索和生成,效果最佳
对比数据如下:
- 生成式在开放域问答中表现更好
- 抽取式在限定域问答中更准确
- 推荐使用混合架构
- 预训练模型是 QA 系统的基础