Stable-Diffusion-V1-5 实战:基于JavaScript的实时交互式Web画板
Stable-Diffusion-V1-5 实战:基于JavaScript的实时交互式Web画板
想象一下,你正在为一个创意项目构思草图,或者只是想随手涂鸦。传统的数字绘画工具能帮你记录想法,但如果能让你的草图瞬间变成一幅充满细节和艺术感的完整画作呢?这听起来像是未来科技,但今天,我们就能用Stable Diffusion和几行JavaScript代码把它变成现实。
很多开发者对AI绘画模型望而却步,觉得它复杂、笨重,只能在后端服务器上运行。但事实是,通过巧妙的架构设计,我们可以把Stable Diffusion这样的强大模型,变成一个轻快、实时、能与用户直接对话的Web应用。用户在前端画布上随意勾勒几笔,或者输入几个关键词,浏览器背后就能实时生成一幅令人惊艳的图像并立刻呈现出来。
这篇文章,我就带你一步步搭建这样一个“魔法画板”。我们将从前端交互开始,打通实时通信链路,最后接入AI模型,完成一个完整的、可交互的AI绘画Web应用。整个过程,你会看到JavaScript如何成为连接创意与智能的桥梁。
1. 场景与核心思路:让AI实时响应你的画笔
在开始写代码之前,我们先搞清楚要做一个什么东西,以及为什么这么做。
核心场景:用户打开一个网页,看到一个空白的画布。他可以用鼠标(或触控笔)在上面自由绘画,比如画一个简单的房子轮廓、一棵树的形状,或者一些抽象的线条。同时,旁边有一个输入框,可以输入一些描述性文字,比如“阳光下的童话小屋”、“赛博朋克风格的城市”。当用户点击“生成”按钮,或者甚至每画完一笔,应用就能近乎实时地将草图与文字结合,通过Stable Diffusion模型生成一张精美的图片,并显示在画布旁边。
技术挑战与思路:
- 实时性:传统的“上传草图->等待处理->下载结果”流程体验割裂。我们需要实现近乎实时的反馈。解决方案是使用WebSocket建立前后端的持久连接,让草图数据和生成进度可以双向、低延迟地流动。
- 轻量前端与重型后端:Stable Diffusion模型推理是计算密集型任务,不可能在用户浏览器中运行。因此,架构必然是“轻前端 + 重后端”。前端(HTML/JS)负责交互和展示,后端(Python等)负责调用模型。
- 数据流转:草图是像素数据,需要高效地从前端传送到后端。我们将使用Canvas的API获取图像数据,并将其转换为适合网络传输和模型处理的格式(如Base64编码)。
- 用户体验:生成过程需要时间(几秒到几十秒)。我们需要提供明确的等待状态(如进度条、加载动画),并在生成完成后无缝更新画面。
整个应用的数据流可以概括为:用户绘制/输入 -> 前端捕获并编码 -> WebSocket发送 -> 后端接收并调用SD模型 -> 模型生成图像 -> 后端编码图像 -> WebSocket回传 -> 前端解码并渲染。
2. 搭建交互式前端画板
我们的前端需要三块核心区域:一个用于绘制的画布,一个用于显示生成结果的画布,以及一些控制元素。
2.1 创建基础HTML结构
我们先从最简单的HTML骨架开始,引入必要的样式和脚本。
<!DOCTYPE html> <html lang="zh-CN"> <head> <meta charset="UTF-8"> <meta name="viewport" content="width=device-width, initial-scale=1.0"> <title>实时AI魔法画板</title> <style> body { font-family: sans-serif; margin: 20px; background-color: #f5f5f5; } .container { display: flex; flex-wrap: wrap; gap: 20px; max-width: 1200px; margin: 0 auto; } .canvas-container { border: 2px solid #ccc; border-radius: 8px; background-color: white; box-shadow: 0 4px 6px rgba(0,0,0,0.1); } canvas { display: block; /* 消除canvas底部的间隙 */ } #drawingCanvas { cursor: crosshair; background-color: #fff; } #outputCanvas { background-color: #f0f0f0; } .controls { display: flex; flex-direction: column; gap: 15px; min-width: 250px; } .control-group { background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05); } label { display: block; margin-bottom: 5px; font-weight: bold; color: #333; } input, textarea, button { width: 100%; padding: 10px; box-sizing: border-box; border: 1px solid #ddd; border-radius: 4px; font-size: 14px; } textarea { height: 80px; resize: vertical; font-family: monospace; } button { background-color: #4a6fa5; color: white; border: none; cursor: pointer; font-weight: bold; transition: background-color 0.2s; } button:hover { background-color: #385d8a; } button:disabled { background-color: #aaa; cursor: not-allowed; } #status { padding: 10px; border-radius: 4px; margin-top: 10px; min-height: 20px; } .status-waiting { background-color: #fff3cd; color: #856404; } .status-processing { background-color: #cce5ff; color: #004085; } .status-success { background-color: #d4edda; color: #155724; } .status-error { background-color: #f8d7da; color: #721c24; } </style> </head> <body> <div class="container"> <!-- 左侧:绘制区 --> <div class="canvas-container"> <h3>绘制你的草图</h3> <canvas id="drawingCanvas" width="512" height="512"></canvas> <div style="text-align: center; margin-top: 10px;"> <button id="clearBtn">清空画布</button> <button id="undoBtn">撤销一步</button> </div> </div> <!-- 中间:控制区 --> <div class="controls"> <div class="control-group"> <h3>AI 生成控制</h3> <label for="promptInput">文字描述 (Prompt):</label> <textarea id="promptInput" placeholder="例如:一座被星空环绕的灯塔,油画风格"></textarea> <label for="negativePromptInput">不想出现的内容 (Negative Prompt):</label> <input type="text" id="negativePromptInput" placeholder="例如:模糊,丑陋,多只手"> <label for="stepsInput">生成步数 (Steps):</label> <input type="number" id="stepsInput" value="20" min="1" max="50"> <label for="cfgScaleInput">文本遵循度 (CFG Scale):</label> <input type="number" id="cfgScaleInput" value="7.5" min="1" max="20" step="0.5"> <div style="margin-top: 15px;"> <button id="generateBtn">生成图像</button> <button id="autoGenerateBtn" style="margin-top: 5px;">开启实时生成模式</button> </div> </div> <div class="control-group"> <h3>画笔设置</h3> <label for="brushSize">画笔大小:</label> <input type="range" id="brushSize" min="1" max="50" value="5"> <label for="brushColor">画笔颜色:</label> <input type="color" id="brushColor" value="#000000"> </div> <div class="control-group"> <h3>系统状态</h3> <div id="status" class="status-waiting">准备就绪。请开始绘制或输入描述。</div> <div id="progressContainer" style="display:none; margin-top:10px;"> <label>生成进度:</label> <progress id="generationProgress" value="0" max="100"></progress> <span id="progressText">0%</span> </div> </div> </div> <!-- 右侧:结果展示区 --> <div class="canvas-container"> <h3>AI 生成结果</h3> <canvas id="outputCanvas" width="512" height="512"></canvas> <div style="text-align: center; margin-top: 10px;"> <button id="downloadBtn" disabled>下载图片</button> </div> </div> </div> <script src="app.js"></script> </body> </html>这个页面布局清晰,分为绘制区、控制区和结果区。我们预留了所有交互元素,并添加了一些基础样式。
2.2 实现画布绘制逻辑
接下来是核心的app.js,我们先实现画布的绘制功能。
// app.js document.addEventListener('DOMContentLoaded', function() { // 获取DOM元素 const drawingCanvas = document.getElementById('drawingCanvas'); const outputCanvas = document.getElementById('outputCanvas'); const ctx = drawingCanvas.getContext('2d'); const outputCtx = outputCanvas.getContext('2d'); const clearBtn = document.getElementById('clearBtn'); const undoBtn = document.getElementById('undoBtn'); const generateBtn = document.getElementById('generateBtn'); const autoGenerateBtn = document.getElementById('autoGenerateBtn'); const downloadBtn = document.getElementById('downloadBtn'); const promptInput = document.getElementById('promptInput'); const negativePromptInput = document.getElementById('negativePromptInput'); const stepsInput = document.getElementById('stepsInput'); const cfgScaleInput = document.getElementById('cfgScaleInput'); const brushSizeInput = document.getElementById('brushSize'); const brushColorInput = document.getElementById('brushColor'); const statusDiv = document.getElementById('status'); const progressContainer = document.getElementById('progressContainer'); const generationProgress = document.getElementById('generationProgress'); const progressText = document.getElementById('progressText'); // 绘图状态变量 let isDrawing = false; let lastX = 0; let lastY = 0; let drawingHistory = []; const HISTORY_LIMIT = 20; // 初始化画布 ctx.fillStyle = 'white'; ctx.fillRect(0, 0, drawingCanvas.width, drawingCanvas.height); outputCtx.fillStyle = '#f0f0f0'; outputCtx.fillRect(0, 0, outputCanvas.width, outputCanvas.height); outputCtx.font = '16px sans-serif'; outputCtx.fillStyle = '#999'; outputCtx.textAlign = 'center'; outputCtx.fillText('等待生成图像...', outputCanvas.width/2, outputCanvas.height/2); // 画笔设置 let currentBrushSize = parseInt(brushSizeInput.value); let currentBrushColor = brushColorInput.value; brushSizeInput.addEventListener('input', (e) => { currentBrushSize = parseInt(e.target.value); }); brushColorInput.addEventListener('input', (e) => { currentBrushColor = e.target.value; }); // 绘图函数 function draw(e) { if (!isDrawing) return; ctx.strokeStyle = currentBrushColor; ctx.lineWidth = currentBrushSize; ctx.lineCap = 'round'; ctx.lineJoin = 'round'; // 获取画布上的坐标(考虑滚动和偏移) const rect = drawingCanvas.getBoundingClientRect(); const scaleX = drawingCanvas.width / rect.width; const scaleY = drawingCanvas.height / rect.height; const x = (e.clientX - rect.left) * scaleX; const y = (e.clientY - rect.top) * scaleY; ctx.beginPath(); ctx.moveTo(lastX, lastY); ctx.lineTo(x, y); ctx.stroke(); [lastX, lastY] = [x, y]; } // 事件监听 drawingCanvas.addEventListener('mousedown', (e) => { isDrawing = true; const rect = drawingCanvas.getBoundingClientRect(); const scaleX = drawingCanvas.width / rect.width; const scaleY = drawingCanvas.height / rect.height; lastX = (e.clientX - rect.left) * scaleX; lastY = (e.clientY - rect.top) * scaleY; // 开始新的笔画时,保存当前状态到历史记录 saveDrawingState(); }); drawingCanvas.addEventListener('mousemove', draw); drawingCanvas.addEventListener('mouseup', () => { isDrawing = false; ctx.beginPath(); // 结束当前路径 }); drawingCanvas.addEventListener('mouseout', () => { isDrawing = false; }); // 触摸屏支持 drawingCanvas.addEventListener('touchstart', (e) => { e.preventDefault(); const touch = e.touches[0]; const mouseEvent = new MouseEvent('mousedown', { clientX: touch.clientX, clientY: touch.clientY }); drawingCanvas.dispatchEvent(mouseEvent); }); drawingCanvas.addEventListener('touchmove', (e) => { e.preventDefault(); const touch = e.touches[0]; const mouseEvent = new MouseEvent('mousemove', { clientX: touch.clientX, clientY: touch.clientY }); drawingCanvas.dispatchEvent(mouseEvent); }); drawingCanvas.addEventListener('touchend', (e) => { e.preventDefault(); const mouseEvent = new MouseEvent('mouseup', {}); drawingCanvas.dispatchEvent(mouseEvent); }); // 历史记录功能 function saveDrawingState() { const imageData = ctx.getImageData(0, 0, drawingCanvas.width, drawingCanvas.height); drawingHistory.push(imageData); if (drawingHistory.length > HISTORY_LIMIT) { drawingHistory.shift(); // 移除最旧的历史记录 } undoBtn.disabled = drawingHistory.length === 0; } clearBtn.addEventListener('click', () => { ctx.fillStyle = 'white'; ctx.fillRect(0, 0, drawingCanvas.width, drawingCanvas.height); drawingHistory = []; saveDrawingState(); // 清空后保存空白状态 updateStatus('画布已清空', 'status-success'); }); undoBtn.addEventListener('click', () => { if (drawingHistory.length > 0) { // 恢复到上一个状态 const prevState = drawingHistory.pop(); ctx.putImageData(prevState, 0, 0); undoBtn.disabled = drawingHistory.length === 0; updateStatus('已撤销一步', 'status-success'); } }); // 状态更新函数 function updateStatus(message, cssClass = 'status-waiting') { statusDiv.textContent = message; statusDiv.className = 'status-waiting'; // 重置基础类 statusDiv.classList.add(cssClass); } // 更新进度条 function updateProgress(percent, message = '') { if (percent === 0) { progressContainer.style.display = 'block'; } generationProgress.value = percent; progressText.textContent = `${percent}%`; if (message) { updateStatus(message, 'status-processing'); } if (percent >= 100) { setTimeout(() => { progressContainer.style.display = 'none'; generationProgress.value = 0; }, 1000); } } // 初始化 saveDrawingState(); // 保存初始空白状态 undoBtn.disabled = true; downloadBtn.disabled = true; });现在,我们已经有了一个功能完整的画板。用户可以调整画笔大小和颜色,进行绘制、清空和撤销操作。接下来,我们需要让它“活”起来,连接到后端的AI模型。
3. 建立实时通信与后端连接
为了让前端画板能与后端的Stable Diffusion模型对话,我们需要建立实时通信。这里我们选择WebSocket,因为它支持全双工、低延迟的通信,非常适合实时更新生成进度。
3.1 前端WebSocket客户端
我们在app.js中继续添加WebSocket连接和通信逻辑。
// app.js (续) // WebSocket 连接 let socket = null; let isAutoGenerateMode = false; let autoGenerateTimeout = null; // 初始化WebSocket连接 function initWebSocket() { // 注意:这里需要替换成你实际的后端WebSocket地址 const wsUrl = 'ws://localhost:8765'; // 示例地址 socket = new WebSocket(wsUrl); socket.onopen = function() { updateStatus('已连接到AI服务器', 'status-success'); console.log('WebSocket连接已建立'); }; socket.onmessage = function(event) { try { const data = JSON.parse(event.data); handleServerMessage(data); } catch (e) { console.error('解析服务器消息失败:', e); } }; socket.onerror = function(error) { console.error('WebSocket错误:', error); updateStatus('连接服务器时出错', 'status-error'); }; socket.onclose = function() { console.log('WebSocket连接已关闭'); updateStatus('与服务器连接断开', 'status-error'); // 尝试重连 setTimeout(() => { updateStatus('正在尝试重新连接...', 'status-processing'); initWebSocket(); }, 3000); }; } // 处理服务器消息 function handleServerMessage(data) { switch (data.type) { case 'progress': // 更新生成进度 const percent = Math.round((data.step / data.total_steps) * 100); updateProgress(percent, `AI正在绘制... (${data.step}/${data.total_steps})`); break; case 'result': // 收到生成的图像 updateProgress(100, '图像生成完成!'); displayGeneratedImage(data.image_data); updateStatus('图像生成成功!', 'status-success'); downloadBtn.disabled = false; break; case 'error': updateStatus(`生成失败: ${data.message}`, 'status-error'); updateProgress(0); break; default: console.log('收到未知消息类型:', data.type); } } // 将生成的Base64图像显示在输出画布上 function displayGeneratedImage(base64Data) { const img = new Image(); img.onload = function() { outputCtx.clearRect(0, 0, outputCanvas.width, outputCanvas.height); // 将图像绘制到输出画布,保持宽高比 const scale = Math.min(outputCanvas.width / img.width, outputCanvas.height / img.height); const x = (outputCanvas.width - img.width * scale) / 2; const y = (outputCanvas.height - img.height * scale) / 2; outputCtx.drawImage(img, x, y, img.width * scale, img.height * scale); }; img.src = `data:image/png;base64,${base64Data}`; } // 从画布获取图像数据并发送到服务器 function sendDrawingToServer() { if (!socket || socket.readyState !== WebSocket.OPEN) { updateStatus('未连接到服务器,请检查连接', 'status-error'); return; } // 1. 获取画布数据(Base64格式) const sketchDataUrl = drawingCanvas.toDataURL('image/png'); // 去掉 data URL 前缀 const base64Sketch = sketchDataUrl.replace(/^data:image\/\w+;base64,/, ''); // 2. 获取用户输入参数 const prompt = promptInput.value.trim(); if (!prompt) { updateStatus('请输入文字描述', 'status-error'); return; } const requestData = { type: 'generate', sketch_image: base64Sketch, prompt: prompt, negative_prompt: negativePromptInput.value.trim(), steps: parseInt(stepsInput.value), cfg_scale: parseFloat(cfgScaleInput.value), width: outputCanvas.width, height: outputCanvas.height }; // 3. 发送请求 socket.send(JSON.stringify(requestData)); // 4. 更新UI状态 updateStatus('已发送生成请求,等待AI处理...', 'status-processing'); updateProgress(0); downloadBtn.disabled = true; // 5. 在输出画布显示等待信息 outputCtx.clearRect(0, 0, outputCanvas.width, outputCanvas.height); outputCtx.fillStyle = '#f0f0f0'; outputCtx.fillRect(0, 0, outputCanvas.width, outputCanvas.height); outputCtx.fillStyle = '#666'; outputCtx.font = '18px sans-serif'; outputCtx.textAlign = 'center'; outputCtx.fillText('AI正在创作中...', outputCanvas.width/2, outputCanvas.height/2 - 20); outputCtx.font = '14px sans-serif'; outputCtx.fillText('这可能需要几秒到几十秒', outputCanvas.width/2, outputCanvas.height/2 + 10); } // 下载生成的图片 downloadBtn.addEventListener('click', () => { const dataUrl = outputCanvas.toDataURL('image/png'); const link = document.createElement('a'); link.download = `ai_generated_${Date.now()}.png`; link.href = dataUrl; link.click(); }); // 手动生成按钮事件 generateBtn.addEventListener('click', sendDrawingToServer); // 自动生成模式 autoGenerateBtn.addEventListener('click', function() { isAutoGenerateMode = !isAutoGenerateMode; if (isAutoGenerateMode) { this.textContent = '关闭实时生成模式'; this.style.backgroundColor = '#dc3545'; updateStatus('实时生成模式已开启。绘制时,松开鼠标后将自动生成。', 'status-processing'); // 开启模式时立即生成一次 if (promptInput.value.trim()) { sendDrawingToServer(); } } else { this.textContent = '开启实时生成模式'; this.style.backgroundColor = ''; updateStatus('实时生成模式已关闭', 'status-success'); if (autoGenerateTimeout) { clearTimeout(autoGenerateTimeout); } } }); // 在绘制结束时,如果是自动模式,则触发生成(防抖处理) drawingCanvas.addEventListener('mouseup', () => { if (isAutoGenerateMode && promptInput.value.trim()) { if (autoGenerateTimeout) { clearTimeout(autoGenerateTimeout); } // 延迟500毫秒发送,避免频繁请求 autoGenerateTimeout = setTimeout(() => { sendDrawingToServer(); }, 500); } }); // 初始化WebSocket连接 initWebSocket();前端部分的核心通信逻辑就完成了。现在,当用户点击“生成”或处于“自动生成模式”时,草图和数据会被打包成JSON,通过WebSocket发送给后端。
3.2 后端WebSocket服务器与模型调用
后端我们需要一个WebSocket服务器来接收前端的请求,调用Stable Diffusion模型,并返回结果。这里我们用Python的websockets库和diffusers库来搭建一个简单的示例。请注意,运行此后端需要具备Python环境和一定的GPU资源。
# server.py import asyncio import websockets import json import base64 from io import BytesIO from PIL import Image import torch from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline import logging # 配置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # 全局变量,用于缓存模型(避免每次请求都加载) sd_pipe = None img2img_pipe = None device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"使用设备: {device}") async def load_models(): """加载Stable Diffusion模型""" global sd_pipe, img2img_pipe model_id = "runwayml/stable-diffusion-v1-5" try: logger.info("正在加载文生图模型...") sd_pipe = StableDiffusionPipeline.from_pretrained( model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32, safety_checker=None, # 为简化示例,禁用安全检查器(生产环境请谨慎) requires_safety_checker=False ) sd_pipe = sd_pipe.to(device) if device == "cuda": sd_pipe.enable_attention_slicing() # 减少显存占用 logger.info("正在加载图生图模型...") img2img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained( model_id, torch_dtype=torch.float16 if device == "cuda" else torch.float32, safety_checker=None, requires_safety_checker=False, vae=sd_pipe.vae, text_encoder=sd_pipe.text_encoder, tokenizer=sd_pipe.tokenizer, unet=sd_pipe.unet, scheduler=sd_pipe.scheduler, ) img2img_pipe = img2img_pipe.to(device) logger.info("模型加载完成!") return True except Exception as e: logger.error(f"加载模型失败: {e}") return False def pil_to_base64(image): """将PIL图像转换为Base64字符串""" buffered = BytesIO() image.save(buffered, format="PNG") return base64.b64encode(buffered.getvalue()).decode('utf-8') def base64_to_pil(image_base64): """将Base64字符串转换为PIL图像""" image_data = base64.b64decode(image_base64) image = Image.open(BytesIO(image_data)) return image async def handle_generation_request(data, websocket): """处理生成请求""" try: sketch_image = data.get("sketch_image") prompt = data.get("prompt", "") negative_prompt = data.get("negative_prompt", "") steps = data.get("steps", 20) cfg_scale = data.get("cfg_scale", 7.5) width = data.get("width", 512) height = data.get("height", 512) # 验证必要参数 if not prompt: await websocket.send(json.dumps({ "type": "error", "message": "提示词(prompt)不能为空" })) return # 发送进度更新 async def callback(step, timestep, latents): progress_data = { "type": "progress", "step": step, "total_steps": steps } try: await websocket.send(json.dumps(progress_data)) except: pass # 如果连接断开,忽略发送错误 logger.info(f"开始生成: prompt='{prompt[:50]}...', steps={steps}") if sketch_image: # 图生图模式:使用草图作为初始图像 init_image = base64_to_pil(sketch_image) # 调整草图尺寸到目标尺寸 init_image = init_image.resize((width, height), Image.Resampling.LANCZOS) # 将草图转换为RGB模式(如果是RGBA,去除Alpha通道) if init_image.mode != "RGB": init_image = init_image.convert("RGB") # 使用图生图管道 result = img2img_pipe( prompt=prompt, negative_prompt=negative_prompt, image=init_image, strength=0.75, # 控制草图的影响程度,0.75是个不错的起点 num_inference_steps=steps, guidance_scale=cfg_scale, callback=callback, callback_steps=1 ) else: # 文生图模式:仅使用文字描述 result = sd_pipe( prompt=prompt, negative_prompt=negative_prompt, width=width, height=height, num_inference_steps=steps, guidance_scale=cfg_scale, callback=callback, callback_steps=1 ) # 获取生成的图像 generated_image = result.images[0] # 转换为Base64并发送 image_base64 = pil_to_base64(generated_image) await websocket.send(json.dumps({ "type": "result", "image_data": image_base64 })) logger.info("生成完成并已发送") except Exception as e: logger.error(f"生成过程中出错: {e}") await websocket.send(json.dumps({ "type": "error", "message": f"生成失败: {str(e)}" })) async def handle_client(websocket, path): """处理WebSocket客户端连接""" client_ip = websocket.remote_address[0] logger.info(f"客户端连接: {client_ip}") try: async for message in websocket: try: data = json.loads(message) msg_type = data.get("type") if msg_type == "generate": await handle_generation_request(data, websocket) else: await websocket.send(json.dumps({ "type": "error", "message": f"未知的请求类型: {msg_type}" })) except json.JSONDecodeError: await websocket.send(json.dumps({ "type": "error", "message": "无效的JSON格式" })) except Exception as e: logger.error(f"处理消息时出错: {e}") await websocket.send(json.dumps({ "type": "error", "message": f"服务器内部错误: {str(e)}" })) except websockets.exceptions.ConnectionClosed: logger.info(f"客户端断开连接: {client_ip}") except Exception as e: logger.error(f"连接处理异常: {e}") async def main(): # 先加载模型 if not await load_models(): logger.error("模型加载失败,服务器启动中止") return # 启动WebSocket服务器 server = await websockets.serve( handle_client, "0.0.0.0", # 监听所有网络接口 8765, # 端口号,需与前端的 ws://localhost:8765 对应 ping_interval=20, ping_timeout=40 ) logger.info("WebSocket服务器已启动,监听端口 8765") logger.info("前端地址: http://localhost:8000 (假设前端运行在8000端口)") await server.wait_closed() if __name__ == "__main__": asyncio.run(main())这个后端服务器做了以下几件事:
- 加载模型:启动时加载Stable Diffusion v1-5的文生图和图生图管道。
- 处理WebSocket连接:监听来自前端的连接。
- 处理生成请求:接收包含草图(Base64)、提示词等参数的JSON数据。
- 调用模型:根据是否有草图,选择使用图生图或文生图模式。
- 进度回调:在生成过程中,定期向前端发送进度更新。
- 返回结果:将生成的图像转换为Base64格式,通过WebSocket发送回前端。
运行后端:
- 安装依赖:
pip install websockets diffusers torch pillow - 确保有足够的GPU内存(至少4-6GB用于FP16推理)。
- 运行脚本:
python server.py
运行前端:由于前端使用了WebSocket,直接通过file://协议打开HTML文件可能会遇到跨域问题。建议使用一个简单的HTTP服务器来提供前端文件。在项目目录下运行:python -m http.server 8000,然后在浏览器中访问http://localhost:8000。
4. 优化、调试与扩展思路
一个基础版本已经完成了。但在实际使用中,你可能会遇到一些问题,或者想让它变得更好。这里分享一些优化和扩展的思路。
4.1 性能与体验优化
- 前端防抖与节流:在“自动生成模式”下,我们使用了简单的延时。可以引入更完善的防抖函数,避免在用户快速连续绘画时发送过多请求。
- 草图预处理:后端的
base64_to_pil函数可以加入更多的图像预处理,比如自动裁剪空白边缘、调整对比度,让草图对模型的引导更有效。 - 生成队列:如果有多用户同时使用,后端需要实现一个请求队列,避免GPU内存溢出。可以使用
asyncio.Queue来管理。 - 连接稳定性:前端WebSocket断线重连的逻辑可以更健壮,比如指数退避重试。
- 生成中断:允许用户在生成过程中取消任务。这需要后端支持中断扩散过程,并向前端发送一个特殊的控制消息。
4.2 功能扩展
- 更多控制参数:在前端添加更多的Stable Diffusion参数控制,如
seed(随机种子)、sampler(采样器选择)。 - 草图强度调节:添加一个滑块,让用户实时调整
strength参数,控制草图对最终成图的影响程度。 - 风格预设:提供一些预设的提示词模板或风格(如“卡通风格”、“水墨画风”、“科幻场景”),方便用户快速选择。
- 历史记录:在本地存储(
localStorage)或后端数据库保存用户的生成记录,方便回溯和对比。 - 批量生成:允许用户一次生成多张图,并选择最满意的一张。
- 高级编辑:在输出画布上提供简单的后期编辑功能,如裁剪、滤镜、添加文字等。
4.3 部署与安全考虑
- 跨域问题:生产环境中,前端和后端可能部署在不同域名下。需要在后端设置正确的CORS头部,或者使用反向代理(如Nginx)将前后端统一到一个域名下。
- 身份验证:为WebSocket连接添加简单的Token验证,防止服务被滥用。
- 输入验证与过滤:后端必须对前端传来的提示词进行严格的审核和过滤,防止生成不当内容。
- 资源隔离:考虑使用容器化技术(如Docker)来部署模型服务,实现资源隔离和弹性伸缩。
- 使用更高效的推理后端:对于生产环境,可以考虑使用更专业的推理服务器,如
Triton Inference Server,或者调用云服务商提供的AI模型API。
5. 总结
走完这一趟,你会发现,将Stable Diffusion这样的“庞然大物”集成到一个轻快的Web应用中,并没有想象中那么困难。核心在于清晰的架构设计:交互前端负责捕获用户意图和展示结果,实时通信层(WebSocket)负责高效的数据穿梭,而模型后端则专注于重型计算。
我们构建的这个“魔法画板”,只是一个起点。它的价值在于展示了一种可能性——AI能力可以变得如此触手可及和互动自然。你可以基于这个骨架,把它变得更强大、更智能。比如,结合ControlNet让草图控制力更强,或者集成LoRA模型来固定某种画风。
技术最终要服务于创造。这个项目最大的乐趣,或许不在于代码本身,而在于它打开了一扇门:让不擅长绘画的人也能通过简单的线条和文字,召唤出脑海中的画面。当你看到自己随手画的一个圆圈,在AI的加持下变成一颗精致的星球时,那种感觉是非常奇妙的。
如果你对其中某个部分特别感兴趣,比如如何优化WebSocket在大流量下的表现,或者如何用ControlNet实现更精准的草图控制,那又是一个可以深入探索的新世界了。希望这个实战项目能给你带来一些启发,动手试试,把你的想法也变成可交互的现实吧。
获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。
