fastAPI+pgvector搭建向量搜索
一、 fastApi嵌入bert-base-chinese模型,提供接口服务
1、下载bert-base-chinese模型
这里我在魔搭下载模型,它是一个基于Transformer架构的中文预训练模型,使用了大量的中文语料进行训练。它在多个中文自然语言处理任务上表现出色,如文本分类、命名实体识别和情感分析等。可以生成768维的向量。
https://www.modelscope.cn/models/tiansz/bert-base-chinese/files
2、python实现将字符串转为向量
需要以下类库支持
pip3 install numpy pip3 install transformers pip3 install pydantic pip3 install torchpython 代码实现
from pydantic import BaseModel from transformers import AutoModel, AutoTokenizer import torch import numpy as np from typing import List, Optional # 模型路径(自动下载/本地路径均可) MODEL_NAME = "bert-base-chinese" DEVICE = 'cpu' # 预加载模型和分词器(只加载一次,提升接口速度) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE).eval() # eval模式关闭训练层,提速 # ===================== 数据模型定义 ===================== class TextRequest(BaseModel): """单文本请求模型""" text: str # 输入文本(如FAQ问题) # ===================== 核心向量生成函数 ===================== def get_embedding(text: str) -> List[float]: """ 单文本生成768维向量(适配pgvector) :param text: 输入文本 :return: 768维向量列表(float类型) """ if not text or text.strip() == "": raise ValueError("输入文本不能为空") # 1. 文本分词(BERT标准处理) inputs = tokenizer( text.strip(), return_tensors="pt", padding=True, truncation=True, max_length=512 # BERT最大长度 ).to(DEVICE) # 2. 模型推理(无梯度计算,提速) with torch.no_grad(): outputs = model(**inputs) # 3. 提取CLS token的向量(BERT语义核心,768维) cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy() # 归一化(提升pgvector余弦相似度计算精度) cls_embedding = cls_embedding / np.linalg.norm(cls_embedding) # 转换为List[float],适配pgvector return cls_embedding.tolist() if __name__ == "__main__": result = get_embedding('健康检查') print(result)3、fastAPI配置接口服务
from fastapi import FastAPI, Request, WebSocket from fastapi.staticfiles import StaticFiles from pydantic import BaseModel import uvicorn import json from bert_embedding_service import get_embedding app = FastAPI() app.mount("/static", StaticFiles(directory="static"), name="static") class Item(BaseModel): inputs: object = None sql: str = None db: str = None sid: str = None passwd: str = None text: str = None @app.post("/embedding", summary="单文本生成向量") def single_embedding(item: Item): return get_embedding(text=item.text) if __name__ == "__main__": uvicorn.run(app) # app.run(host='0.0.0.0',port=8080)以下是启动fastAPI的命令:
nohup uvicorn main:app --reload --workers 4 --host 0.0.0.0 --port=8080 &二、 docker搭建postgrep+pgvector数据库,提供存储数据服务
1、docker compose脚本
version: '3.8' services: pgvector: image: ankane/pgvector:latest # 预装pgvector的PostgreSQL镜像 container_name: pgvector-db restart: always # 容器异常退出自动重启 ports: - "5432:5432" # 宿主机端口:容器端口 environment: POSTGRES_USER: postgres # 用户名 POSTGRES_PASSWORD: Abc1234% # 密码(生产环境建议改为复杂密码 POSTGRES_DB: vectordb # 初始数据库 PGDATA: /var/lib/postgresql/data/pgdata # 数据存储路径 volumes: - pgvector-data:/var/lib/postgresql/data # 数据卷(持久化) # 可选:挂载自定义postgresql.conf配置文件 # - ./postgresql.conf:/var/lib/postgresql/data/postgresql.conf networks: - pgvector-network volumes: pgvector-data: # 命名卷,数据持久化 networks: pgvector-network: # 自定义网络,便于多容器通信2、docker启/停命令
# 启动容器(后台运行) docker-compose up -d # 查看容器状态 docker-compose ps # 停止容器(保留数据) docker-compose down # 停止并删除数据卷(谨慎!) docker-compose down -v3、验证部署
# 进入容器连接数据库 docker exec -it pgvector-db psql -U postgres -d vectordb # 验证pgvector CREATE EXTENSION IF NOT EXISTS vector; -- 创建测试表(向量维度为3) CREATE TABLE test_vectors (id SERIAL PRIMARY KEY, vec vector(3)); -- 插入测试数据 INSERT INTO test_vectors (vec) VALUES ('[1,2,3]'), ('[4,5,6]'); -- 查询向量(验证功能) SELECT * FROM test_vectors WHERE vec <-> '[1,2,3]' < 1; # 向量距离查询三、 核心 SQL 语句(向量存储 + 相似度查询)
1. 基础操作:创建扩展 + 表sql
-- 1. 启用 pgvector 扩展(首次使用必须执行) CREATE EXTENSION IF NOT EXISTS vector; -- 2. 创建带向量字段的表(核心) -- 示例:存储文本嵌入向量(维度 1536,对应 OpenAI text-embedding-ada-002) CREATE TABLE document_vectors ( id SERIAL PRIMARY KEY, -- 主键 document_id INT NOT NULL, -- 关联业务表的ID content TEXT, -- 原始文本(可选,便于排查) embedding vector(1536), -- 向量字段,指定维度(必填) created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP -- 创建时间 ); -- 可选:添加业务索引(非向量) CREATE INDEX idx_document_vectors_document_id ON document_vectors(document_id);2. 向量插入 / 更新 / 删除sql
-- 1. 插入向量(示例:1536维向量,实际替换为你的嵌入向量) INSERT INTO document_vectors (document_id, content, embedding) VALUES ( 1001, 'pgvector 向量数据库使用教程', '[0.123, 0.456, ..., 0.789]' -- 实际为1536个数值的向量 ); -- 2. 批量插入(推荐,效率更高) INSERT INTO document_vectors (document_id, content, embedding) VALUES (1002, 'Docker 部署 PostgreSQL', '[0.234, 0.567, ..., 0.890]'), (1003, 'pgvector 相似度查询调优', '[0.345, 0.678, ..., 0.901]'); -- 3. 更新向量 UPDATE document_vectors SET embedding = '[0.111, 0.222, ..., 0.333]' WHERE id = 1; -- 4. 删除向量 DELETE FROM document_vectors WHERE document_id = 1001;3. 相似度查询(核心)sql
-- 场景1:查询与目标向量最相似的Top10文档(L2距离) -- 目标向量:替换为你的查询嵌入向量 SELECT id, document_id, content, -- 计算L2距离(越小越相似) embedding <-> '[0.123, 0.456, ..., 0.789]' AS distance FROM document_vectors -- 按距离升序排序(最相似的在前) ORDER BY distance LIMIT 10; -- 场景2:余弦相似度查询(需先归一化向量) -- 步骤1:归一化向量(插入前处理,或查询时归一化) UPDATE document_vectors SET embedding = embedding / sqrt(embedding <#> embedding); -- 步骤2:余弦相似度查询 SELECT id, document_id, content, -- 余弦相似度值 = 1 - 余弦距离(越接近1越相似) 1 - (embedding <=> '[0.123, 0.456, ..., 0.789]') AS cosine_similarity FROM document_vectors ORDER BY embedding <=> '[0.123, 0.456, ..., 0.789]' LIMIT 10; -- 场景3:过滤后再查相似度(提升效率) -- 示例:只查近7天的文档,再找相似Top5 SELECT id, document_id, content, embedding <-> '[0.123, 0.456, ..., 0.789]' AS distance FROM document_vectors WHERE created_at >= NOW() - INTERVAL '7 days' ORDER BY distance LIMIT 5; -- 场景4:距离阈值过滤(只返回相似度足够高的结果) SELECT id, document_id, content FROM document_vectors -- L2距离小于0.5(阈值根据业务调整) WHERE embedding <-> '[0.123, 0.456, ..., 0.789]' < 0.5 ORDER BY distance LIMIT 10;四、Java 操作 PostgreSQL + pgvector
解决方案(两种常用方案,按需选择)
前置准备
确保已引入 pgvector 的 Java 依赖(Maven/Gradle),这是操作 vector 类型的基础:
xml
<!-- Maven 依赖 --> <dependency> <groupId>com.github.pgvector</groupId> <artifactId>pgvector-java</artifactId> <version>0.10.0</version> </dependency> <!-- PostgreSQL JDBC 驱动 --> <dependency> <groupId>org.postgresql</groupId> <artifactId>postgresql</artifactId> <version>42.7.3</version> </dependency>方案 1:使用 pgvector-java 库(推荐,简洁)
pgvector-java 库封装了Vector类型,可直接与 PostgreSQL 的vector类型映射,避免手动处理类型转换:
java
运行
import com.github.pgvector.Vector; import org.postgresql.PGConnection; import java.sql.Connection; import java.sql.DriverManager; import java.sql.PreparedStatement; import java.util.ArrayList; import java.util.List; public class PgVectorExample { public static void main(String[] args) throws Exception { // 1. 数据库连接配置 String url = "jdbc:postgresql://localhost:5432/vectordb"; String user = "postgres"; String password = "mysecretpassword"; // 2. 示例向量数据(ArrayList 存储数值) List<Double> vectorList = new ArrayList<>(); vectorList.add(0.123); vectorList.add(0.456); vectorList.add(0.789); // 3. 转换为 pgvector-java 的 Vector 类型 double[] vectorArray = vectorList.stream().mapToDouble(Double::doubleValue).toArray(); Vector pgVector = new Vector(vectorArray); // 4. 连接数据库并插入向量(自动映射类型) try (Connection conn = DriverManager.getConnection(url, user, password)) { // 注册 vector 类型(仅首次连接时执行) PGConnection pgConn = conn.unwrap(PGConnection.class); pgConn.addDataType("vector", Vector.class); // 插入向量(使用 pgVector 对象,无需手动指定类型) String sql = "INSERT INTO test_vectors (vec) VALUES (?)"; try (PreparedStatement pstmt = conn.prepareStatement(sql)) { pstmt.setObject(1, pgVector); // 自动识别为 vector 类型 pstmt.executeUpdate(); System.out.println("向量插入成功!"); } } } }方案 2:手动指定类型 + 转换为字符串格式(无依赖,通用)
若不想引入 pgvector-java 库,可将 ArrayList 转换为 pgvector 识别的字符串格式(如[0.123,0.456,0.789]),并显式指定 SQL 类型:
java
运行
import org.postgresql.util.PGobject; import java.sql.Connection; import java.sql.DriverManager; import java.sql.PreparedStatement; import java.sql.Types; import java.util.ArrayList; import java.util.List; public class PgVectorRawExample { public static void main(String[] args) throws Exception { // 1. 数据库连接配置 String url = "jdbc:postgresql://localhost:5432/vectordb"; String user = "postgres"; String password = "mysecretpassword"; // 2. 示例向量数据 List<Double> vectorList = new ArrayList<>(); vectorList.add(0.123); vectorList.add(0.456); vectorList.add(0.789); // 3. 转换为 pgvector 字符串格式([数值1,数值2,...]) String vectorStr = "[" + String.join(",", vectorList.stream().map(String::valueOf).toArray(String[]::new)) + "]"; // 4. 手动封装为 PGobject 并指定类型 PGobject pgObject = new PGobject(); pgObject.setType("vector"); // 显式指定 PostgreSQL 类型为 vector pgObject.setValue(vectorStr); // 5. 插入向量(显式指定 Types.OTHER) try (Connection conn = DriverManager.getConnection(url, user, password)) { String sql = "INSERT INTO test_vectors (vec) VALUES (?)"; try (PreparedStatement pstmt = conn.prepareStatement(sql)) { // 关键:用 setObject 并指定 Types.OTHER,告诉 JDBC 这是自定义类型 pstmt.setObject(1, pgObject, Types.OTHER); pstmt.executeUpdate(); System.out.println("向量插入成功!"); } } } }方案 3:查询向量时的类型处理(补充)
若查询时也遇到类型推断问题,同样用 pgvector-java 的Vector接收结果:
java
运行
import java.sql.ResultSet; import java.sql.Statement; // 接方案1的连接代码 try (Statement stmt = conn.createStatement()) { String sql = "SELECT vec FROM test_vectors LIMIT 1"; try (ResultSet rs = stmt.executeQuery(sql)) { if (rs.next()) { // 直接获取为 Vector 类型 Vector resultVector = (Vector) rs.getObject("vec"); double[] resultArray = resultVector.getValues(); System.out.println("查询到的向量:" + resultVector); } } }关键注意事项
- 向量维度匹配:插入的向量维度必须与表中
vector(n)定义的维度一致(如vector(3)必须传 3 个数值),否则会报维度不匹配错误。 - 空值处理:若向量可能为空,需先判断
vectorList是否为空,避免生成空的[]字符串。 - 性能优化:批量插入时,建议使用
addBatch()+executeBatch(),减少数据库交互次数:java
运行
// 批量插入示例 String sql = "INSERT INTO test_vectors (vec) VALUES (?)"; try (PreparedStatement pstmt = conn.prepareStatement(sql)) { for (List<Double> vecList : batchVectorList) { double[] arr = vecList.stream().mapToDouble(Double::doubleValue).toArray(); pstmt.setObject(1, new Vector(arr)); pstmt.addBatch(); } pstmt.executeBatch(); }
