| | from email import message |
| | from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request |
| | from fastapi.staticfiles import StaticFiles |
| | from fastapi.responses import HTMLResponse, FileResponse |
| | import uvicorn |
| | import json |
| | import asyncio |
| | import os |
| | from pathlib import Path |
| | from datetime import datetime |
| | from bw_utils import get_grandchild_folders, is_image, load_json_file |
| | from BookWorld import BookWorld |
# Resolve this file's directory once, *before* changing into it: if __file__ is
# relative, re-resolving it after os.chdir() would point at the wrong directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Relative paths used below and elsewhere (config.json, index.html,
# default_icon_path) are resolved against this directory.
os.chdir(BASE_DIR)


app = FastAPI()
default_icon_path = './frontend/assets/images/default-icon.jpg'
config = load_json_file('config.json')
# Preset file name without directory or extension, e.g. "x/foo.json" -> "foo".
# "preset_path" may be missing or empty when "genre" is configured instead
# (ConnectionManager supports that), so fall back to "" instead of failing at
# import time.
experiment_name = (config.get("preset_path") or "").split("/")[-1].split(".")[0]


# Export every non-empty config entry whose key contains "API_KEY" to the
# process environment (presumably consumed by the LLM clients — not visible here).
for key in config:
    if "API_KEY" in key and config[key]:
        os.environ[key] = config[key]


# Serve the frontend assets (JS/CSS/images) under /frontend.
static_file_abspath = os.path.join(BASE_DIR, 'frontend')
app.mount("/frontend", StaticFiles(directory=static_file_abspath), name="frontend")

# Directory scanned by /api/list-presets and used by /api/load-preset.
PRESETS_DIR = os.path.join(BASE_DIR, 'experiment_presets')
| |
|
class ConnectionManager:
    """Track WebSocket clients and run at most one story task per client.

    All clients share a single BookWorld instance, ``self.bw``, built from the
    module-level ``config`` when the manager is constructed.
    """

    def __init__(self):
        """Build the shared BookWorld and configure its generator from ``config``.

        The preset comes from ``config["preset_path"]`` when it is set, otherwise
        from ``config["genre"]`` (``config/experiment_<genre>.json`` next to this
        file).

        Raises:
            ValueError: if ``preset_path`` is set but does not exist, or if
                neither ``preset_path`` nor ``genre`` is set.
        """
        # client_id -> connected WebSocket
        self.active_connections: dict[str, WebSocket] = {}
        # client_id -> background task running generate_story(client_id)
        self.story_tasks: dict[str, asyncio.Task] = {}

        if config.get("preset_path"):
            preset_path = config["preset_path"]
            if not os.path.exists(preset_path):
                raise ValueError(f"The preset path {config['preset_path']} does not exist.")
        elif config.get("genre"):
            genre = config["genre"]
            preset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), f"./config/experiment_{genre}.json")
        else:
            raise ValueError("Please set the preset_path in `config.json`.")

        self.bw = BookWorld(preset_path=preset_path,
                            world_llm_name=config["world_llm_name"],
                            role_llm_name=config["role_llm_name"],
                            embedding_name=config["embedding_model_name"])
        self.bw.set_generator(rounds=config["rounds"],
                              save_dir=config["save_dir"],
                              if_save=config["if_save"],
                              mode=config["mode"],
                              scene_mode=config["scene_mode"])

    async def connect(self, websocket: "WebSocket", client_id: str):
        """Accept the WebSocket handshake and register it under ``client_id``."""
        await websocket.accept()
        self.active_connections[client_id] = websocket

    def disconnect(self, client_id: str):
        """Forget ``client_id`` and cancel its story task; unknown ids are ignored."""
        self.active_connections.pop(client_id, None)
        self.stop_story(client_id)

    def stop_story(self, client_id: str):
        """Cancel and forget the story task of ``client_id``, if it has one."""
        task = self.story_tasks.pop(client_id, None)
        if task is not None:
            task.cancel()

    async def start_story(self, client_id: str):
        """(Re)start background story generation for ``client_id``.

        Any task already running for this client is cancelled first, so each
        client has at most one generator task.
        """
        self.stop_story(client_id)
        self.story_tasks[client_id] = asyncio.create_task(
            self.generate_story(client_id)
        )

    async def generate_story(self, client_id: str):
        """Coroutine that keeps generating story messages for ``client_id``.

        Each round sends a ``message`` frame followed by a ``status_update``
        frame, then waits 2 seconds. The loop ends when the client is no longer
        connected. Cancellation is logged and re-raised so the task is really
        marked as cancelled; any other error is logged and ends the loop.
        """
        try:
            while client_id in self.active_connections:
                message, status = await self.get_next_message()
                # Look the socket up again: the client may have left meanwhile.
                websocket = self.active_connections.get(client_id)
                if websocket is None:
                    break
                await websocket.send_json({
                    'type': 'message',
                    'data': message
                })
                await websocket.send_json({
                    'type': 'status_update',
                    'data': status
                })
                await asyncio.sleep(2)
        except asyncio.CancelledError:
            print(f"Story generation cancelled for client {client_id}")
            # Propagate so the task ends in the "cancelled" state.
            raise
        except Exception as e:
            print(f"Error in generate_story: {e}")

    async def get_initial_data(self):
        """Collect the initial UI payload from BookWorld.

        Returns:
            dict: ``characters``, ``map``, ``settings`` and ``status`` from
            ``self.bw``, plus an empty ``history_messages`` list.
        """
        return {
            'characters': self.bw.get_characters_info(),
            'map': self.bw.get_map_info(),
            'settings': self.bw.get_settings_info(),
            'status': self.bw.get_current_status(),
            'history_messages': [],
        }

    async def get_next_message(self):
        """Get the next message from BookWorld together with its current status.

        The message's ``icon`` is replaced with ``default_icon_path`` when it
        does not name an existing image file.

        Returns:
            tuple: ``(message, status)``.
        """
        # NOTE(review): generate_next_message() is a synchronous call made on the
        # event loop; if it is slow (e.g. LLM requests) it stalls every other
        # client until it returns — consider offloading it to a thread.
        message = self.bw.generate_next_message()
        if not os.path.exists(message["icon"]) or not is_image(message["icon"]):
            message["icon"] = default_icon_path
        status = self.bw.get_current_status()
        return message, status
| |
|
# Single process-wide manager; every HTTP and WebSocket handler below uses it.
manager: ConnectionManager = ConnectionManager()
| |
|
@app.get("/")
async def get():
    """Serve the frontend page ``index.html`` (read as UTF-8 from the working directory)."""
    page = Path("index.html").read_text(encoding="utf-8")
    return HTMLResponse(page)
| |
|
@app.get("/data/{full_path:path}")
async def get_file(full_path: str):
    """Serve a file stored under a data root, falling back to the default icon.

    ``full_path`` comes straight from the URL. Each candidate is therefore
    resolved and must stay inside its base directory; otherwise requests such as
    ``/data/../etc/passwd`` or ``/data//etc/passwd`` could read arbitrary files.
    If no base directory contains the file (or the path is rejected), the
    default icon is returned.
    """
    base_paths = [
        Path("/data/")
    ]

    for base_path in base_paths:
        base = base_path.resolve()
        try:
            file_path = (base / full_path).resolve()
            # Raises ValueError when the resolved path lies outside `base`.
            file_path.relative_to(base)
        except (OSError, ValueError):
            # Unresolvable, or escapes the base directory: never serve it.
            continue
        if file_path.is_file():
            return FileResponse(file_path)

    # Only reached after every base directory has been checked.
    return FileResponse(default_icon_path)
| |
|
@app.get("/api/list-presets")
async def list_presets():
    """List the ``.json`` preset files found in ``PRESETS_DIR``.

    Any failure while reading the directory becomes an HTTP 500 whose detail
    is the exception message.
    """
    try:
        entries = os.listdir(PRESETS_DIR)
        json_files = list(filter(lambda name: name.endswith('.json'), entries))
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))
    return {"presets": json_files}
| |
|
@app.post("/api/load-preset")
async def load_preset(request: Request):
    """Load a preset from ``PRESETS_DIR`` and replace the shared BookWorld.

    Expects a JSON body ``{"preset": "<file name>"}``. On success the new world
    is installed as ``manager.bw``, and ``config["preset_path"]`` and the
    module-level ``experiment_name`` are updated. Returns
    ``{"success": True, "data": <initial data>}``.

    Raises:
        HTTPException: 400 if no preset is given or the name escapes
            ``PRESETS_DIR``; 404 if the preset file does not exist; 500 if the
            body cannot be read or BookWorld fails to initialise.
    """
    global experiment_name
    try:
        data = await request.json()
        preset_name = data.get('preset')

        if not preset_name:
            raise HTTPException(status_code=400, detail="No preset specified")

        preset_path = os.path.join(PRESETS_DIR, preset_name)
        # preset_name is client-controlled: reject absolute paths and ".."
        # segments that would load a file outside PRESETS_DIR.
        presets_root = os.path.realpath(PRESETS_DIR)
        try:
            inside = os.path.commonpath([presets_root, os.path.realpath(preset_path)]) == presets_root
        except ValueError:  # e.g. paths on different drives (Windows)
            inside = False
        if not inside:
            raise HTTPException(status_code=400, detail="Invalid preset name")
        print(f"Loading preset from: {preset_path}")

        if not os.path.exists(preset_path):
            raise HTTPException(status_code=404, detail=f"Preset not found: {preset_path}")

        try:
            # Build and configure the new world completely before publishing it,
            # so a failure leaves the previous manager.bw and config untouched.
            new_bw = BookWorld(
                preset_path=preset_path,
                world_llm_name=config["world_llm_name"],
                role_llm_name=config["role_llm_name"],
                embedding_name=config["embedding_model_name"]
            )
            new_bw.set_generator(
                rounds=config["rounds"],
                save_dir=config["save_dir"],
                if_save=config["if_save"],
                mode=config["mode"],
                scene_mode=config["scene_mode"]
            )
            manager.bw = new_bw
            config["preset_path"] = preset_path
            # basename() handles both "/" and the OS separator used by os.path.join.
            experiment_name = os.path.basename(preset_path).split(".")[0]

            initial_data = await manager.get_initial_data()

            return {
                "success": True,
                "data": initial_data
            }
        except Exception as e:
            print(f"Error initializing BookWorld: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error initializing BookWorld: {str(e)}")

    except HTTPException:
        # Deliberate HTTP errors (400/404/500 above) must keep their status code
        # instead of being re-wrapped as a generic 500 below.
        raise
    except Exception as e:
        print(f"Error in load_preset: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
| |
|
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    """Run one client's WebSocket session.

    Sends an ``initial_data`` frame on connect, then dispatches each incoming
    JSON message on its ``type``:

    * ``user_message`` – echo the text back as a message from "User".
    * ``control`` – ``start`` launches the story task; ``pause`` and ``stop``
      both cancel it.
    * ``edit_message`` – forward the edit to BookWorld (no reply is sent).
    * ``request_scene_characters`` – select the scene and reply with its
      characters.
    * ``generate_story`` – reply with the story text from BookWorld.

    The client is always unregistered (and its story task cancelled) on exit.
    """
    await manager.connect(websocket, client_id)
    try:
        initial_data = await manager.get_initial_data()
        await websocket.send_json({
            'type': 'initial_data',
            'data': initial_data
        })

        while True:
            data = await websocket.receive_text()
            message = json.loads(data)
            msg_type = message['type']

            if msg_type == 'user_message':
                # Echo the user's text back so it shows up in the chat log.
                await websocket.send_json({
                    'type': 'message',
                    'data': {
                        'username': 'User',
                        'timestamp': message['timestamp'],
                        'text': message['text'],
                        'icon': default_icon_path,
                    }
                })

            elif msg_type == 'control':
                action = message['action']
                if action == 'start':
                    await manager.start_story(client_id)
                elif action in ('pause', 'stop'):
                    # 'pause' and 'stop' are handled identically: both cancel
                    # this client's story task.
                    manager.stop_story(client_id)

            elif msg_type == 'edit_message':
                edit_data = message['data']
                manager.bw.handle_message_edit(
                    record_id=edit_data['uuid'],
                    new_text=edit_data['text']
                )

            elif msg_type == 'request_scene_characters':
                manager.bw.select_scene(message['scene'])
                scene_characters = manager.bw.get_characters_info()
                await websocket.send_json({
                    'type': 'scene_characters',
                    'data': scene_characters
                })

            elif msg_type == 'generate_story':
                # NOTE(review): this BookWorld call is synchronous and runs on the
                # event loop, so a slow generation stalls other clients — confirm.
                story_text = manager.bw.generate_story()
                await websocket.send_json({
                    'type': 'message',
                    'data': {
                        'username': 'System',
                        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                        'text': story_text,
                        'icon': default_icon_path,
                        'type': 'story'
                    }
                })

    except WebSocketDisconnect:
        # The client closed the socket: the normal end of a session, not an error.
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        manager.disconnect(client_id)
| |
|
@app.post("/api/save-config")
async def save_config(request: Request):
    """Switch both LLM roles to a new model and export the provider's API key.

    Expects a JSON body with ``provider``, ``model`` and ``apiKey``. The model
    becomes ``role_llm_name`` and ``world_llm_name`` in ``config`` and is applied
    to the running world via ``manager.bw.server.reset_llm``. The key is written
    to the environment variable of the first provider whose name appears
    (case-insensitively) in ``provider``; for an unrecognised provider, no key
    is stored.

    Raises:
        HTTPException: 400 if a required field is missing; 500 on any other
            failure.
    """
    # Ordered: the first substring match wins.
    provider_env_vars = (
        ('openai', 'OPENAI_API_KEY'),
        ('anthropic', 'ANTHROPIC_API_KEY'),
        ('alibaba', 'DASHSCOPE_API_KEY'),
        ('openrouter', 'OPENROUTER_API_KEY'),
    )
    try:
        config_data = await request.json()

        if not all(field in config_data for field in ('provider', 'model', 'apiKey')):
            raise HTTPException(status_code=400, detail="缺少必要的字段")

        llm_provider = config_data['provider']
        model = config_data['model']
        api_key = config_data['apiKey']
        config['role_llm_name'] = model
        config['world_llm_name'] = model

        provider_lower = llm_provider.lower()
        for needle, env_var in provider_env_vars:
            if needle in provider_lower:
                os.environ[env_var] = api_key
                break

        manager.bw.server.reset_llm(model, model)
        return {"status": "success", "message": llm_provider + " 配置已保存"}

    except HTTPException:
        # Keep the 400 for missing fields instead of turning it into a 500.
        raise
    except Exception as e:
        print(f"保存配置失败: {e}")
        raise HTTPException(status_code=500, detail="保存配置失败")
| |
|
if __name__ == "__main__":
    # "module:attribute" import string; assumes this file is saved as app.py —
    # NOTE(review): confirm the module name if the file is renamed.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
    )
| |
|