feat: 添加任务历史记录、持久化会话、WebSocket支持和数据持久化功能
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from app.models.session import CreateSessionRequest, SessionType
|
||||
from app.services.session_manager import session_manager
|
||||
from app.services.scheduler import scheduler_service, ScheduleTask
|
||||
from app.services.opencode_client import opencode_client
|
||||
from app.services.history import history_service
|
||||
from app.services.persistent_session import persistent_session_manager
|
||||
from app.services.websocket_manager import ws_manager
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/xcclaw", tags=["xcclaw"])
|
||||
@@ -17,6 +20,16 @@ async def health_check():
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await ws_manager.connect(websocket)
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
ws_manager.disconnect(websocket)
|
||||
|
||||
|
||||
@router.post("/task")
|
||||
async def create_task(request: CreateSessionRequest):
|
||||
task = await session_manager.create_task(request)
|
||||
@@ -31,6 +44,10 @@ async def execute_task(task_id: str):
|
||||
|
||||
@router.post("/task/{task_id}/execute_async")
|
||||
async def execute_task_async(task_id: str):
|
||||
task = await session_manager.get_task(task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
await session_manager.execute_task_async(task_id)
|
||||
return {"status": "started", "task_id": task_id}
|
||||
|
||||
@@ -70,3 +87,45 @@ async def delete_schedule(schedule_id: str):
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Schedule not found")
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_history(limit: int | None = None):
|
||||
return history_service.get_history(limit)
|
||||
|
||||
|
||||
@router.delete("/history")
|
||||
async def clear_history():
|
||||
history_service.clear_history()
|
||||
return {"cleared": True}
|
||||
|
||||
|
||||
@router.post("/persistent")
|
||||
async def create_persistent_session(name: str | None = None):
|
||||
session = await persistent_session_manager.create_session(name)
|
||||
return session
|
||||
|
||||
|
||||
@router.get("/persistent")
|
||||
async def list_persistent_sessions():
|
||||
return persistent_session_manager.list_sessions()
|
||||
|
||||
|
||||
@router.delete("/persistent/{session_id}")
|
||||
async def delete_persistent_session(session_id: str):
|
||||
result = await persistent_session_manager.delete_session(session_id)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Persistent session not found")
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
@router.post("/persistent/{session_id}/message")
|
||||
async def send_persistent_message(session_id: str, text: str):
|
||||
result = await persistent_session_manager.send_message(session_id, text)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/persistent/{session_id}/message_async")
|
||||
async def send_persistent_message_async(session_id: str, text: str):
|
||||
await persistent_session_manager.send_message_async(session_id, text)
|
||||
return {"status": "sent", "session_id": session_id}
|
||||
|
||||
12
app/main.py
12
app/main.py
@@ -3,6 +3,9 @@ from fastapi import FastAPI
|
||||
from app.api.routes import router
|
||||
from app.services.scheduler import scheduler_service
|
||||
from app.services.opencode_client import opencode_client
|
||||
from app.services.session_manager import session_manager
|
||||
from app.services.history import history_service
|
||||
from app.services.persistent_session import persistent_session_manager
|
||||
from app.core.config import settings
|
||||
from app.core.logging import logger
|
||||
|
||||
@@ -10,8 +13,16 @@ from app.core.logging import logger
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info("Starting XCClaw server...")
|
||||
|
||||
history_service.load()
|
||||
session_manager.load()
|
||||
scheduler_service.load()
|
||||
persistent_session_manager.load()
|
||||
|
||||
scheduler_service.start()
|
||||
|
||||
yield
|
||||
|
||||
logger.info("Shutting down XCClaw server...")
|
||||
scheduler_service.shutdown()
|
||||
await opencode_client.close()
|
||||
@@ -29,6 +40,7 @@ app.include_router(router)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.app_host,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -13,6 +14,7 @@ class TaskStatus(str, Enum):
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
ABORTED = "aborted"
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
@@ -28,3 +30,28 @@ class Task(BaseModel):
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
session_id: str | None = None
|
||||
schedule: str | None = None
|
||||
created_at: datetime | None = None
|
||||
started_at: datetime | None = None
|
||||
finished_at: datetime | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class PersistentSession(BaseModel):
|
||||
id: str
|
||||
session_id: str
|
||||
name: str
|
||||
created_at: datetime | None = None
|
||||
last_used_at: datetime | None = None
|
||||
|
||||
|
||||
class TaskHistory(BaseModel):
|
||||
id: str
|
||||
type: SessionType
|
||||
prompt: str
|
||||
status: TaskStatus
|
||||
session_id: str | None = None
|
||||
schedule: str | None = None
|
||||
created_at: str
|
||||
started_at: str | None = None
|
||||
finished_at: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
52
app/services/history.py
Normal file
52
app/services/history.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from datetime import datetime
|
||||
from app.models.session import TaskHistory, Task
|
||||
from app.services.storage import storage
|
||||
from app.core.logging import logger
|
||||
|
||||
HISTORY_KEY = "task_history"
|
||||
|
||||
|
||||
class HistoryService:
|
||||
def __init__(self):
|
||||
self._history: list[dict] = []
|
||||
|
||||
def load(self):
|
||||
data = storage.load(HISTORY_KEY, [])
|
||||
self._history = data if isinstance(data, list) else []
|
||||
logger.info(f"Loaded {len(self._history)} history records")
|
||||
|
||||
def save(self):
|
||||
storage.save(HISTORY_KEY, self._history)
|
||||
|
||||
def add_task_history(self, task: Task):
|
||||
history = TaskHistory(
|
||||
id=task.id,
|
||||
type=task.type,
|
||||
prompt=task.prompt,
|
||||
status=task.status,
|
||||
session_id=task.session_id,
|
||||
schedule=task.schedule,
|
||||
created_at=task.created_at.isoformat()
|
||||
if task.created_at
|
||||
else datetime.now().isoformat(),
|
||||
started_at=task.started_at.isoformat() if task.started_at else None,
|
||||
finished_at=task.finished_at.isoformat() if task.finished_at else None,
|
||||
error=task.error,
|
||||
)
|
||||
self._history.append(history.model_dump())
|
||||
self.save()
|
||||
logger.info(f"Added task {task.id} to history")
|
||||
|
||||
def get_history(self, limit: int | None = None) -> list[TaskHistory]:
|
||||
history = [TaskHistory(**h) for h in self._history]
|
||||
if limit:
|
||||
return history[-limit:]
|
||||
return history
|
||||
|
||||
def clear_history(self):
|
||||
self._history = []
|
||||
self.save()
|
||||
logger.info("Cleared task history")
|
||||
|
||||
|
||||
history_service = HistoryService()
|
||||
74
app/services/persistent_session.py
Normal file
74
app/services/persistent_session.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from datetime import datetime
|
||||
from app.models.session import PersistentSession
|
||||
from app.services.opencode_client import opencode_client
|
||||
from app.services.storage import storage
|
||||
from app.core.logging import logger
|
||||
|
||||
PERSISTENT_KEY = "persistent_sessions"
|
||||
|
||||
|
||||
class PersistentSessionManager:
|
||||
def __init__(self):
|
||||
self._sessions: dict[str, PersistentSession] = {}
|
||||
|
||||
def load(self):
|
||||
data = storage.load(PERSISTENT_KEY, {})
|
||||
if isinstance(data, dict):
|
||||
for k, v in data.items():
|
||||
self._sessions[k] = PersistentSession(**v)
|
||||
logger.info(f"Loaded {len(self._sessions)} persistent sessions")
|
||||
|
||||
def save(self):
|
||||
data = {k: v.model_dump() for k, v in self._sessions.items()}
|
||||
storage.save(PERSISTENT_KEY, data)
|
||||
|
||||
async def create_session(self, name: str | None = None) -> PersistentSession:
|
||||
session = await opencode_client.create_session(
|
||||
title=name or "Persistent Session"
|
||||
)
|
||||
ps = PersistentSession(
|
||||
id=session["id"],
|
||||
session_id=session["id"],
|
||||
name=name or f"Session {session['id'][:8]}",
|
||||
created_at=datetime.now(),
|
||||
last_used_at=datetime.now(),
|
||||
)
|
||||
self._sessions[ps.id] = ps
|
||||
self.save()
|
||||
logger.info(f"Created persistent session: {ps.id}")
|
||||
return ps
|
||||
|
||||
async def delete_session(self, session_id: str) -> bool:
|
||||
ps = self._sessions.get(session_id)
|
||||
if not ps:
|
||||
return False
|
||||
await opencode_client.delete_session(ps.session_id)
|
||||
del self._sessions[session_id]
|
||||
self.save()
|
||||
logger.info(f"Deleted persistent session: {session_id}")
|
||||
return True
|
||||
|
||||
async def send_message(self, session_id: str, text: str) -> dict:
|
||||
ps = self._sessions.get(session_id)
|
||||
if not ps:
|
||||
raise ValueError(f"Persistent session {session_id} not found")
|
||||
ps.last_used_at = datetime.now()
|
||||
self.save()
|
||||
return await opencode_client.send_message(ps.session_id, text)
|
||||
|
||||
async def send_message_async(self, session_id: str, text: str):
|
||||
ps = self._sessions.get(session_id)
|
||||
if not ps:
|
||||
raise ValueError(f"Persistent session {session_id} not found")
|
||||
ps.last_used_at = datetime.now()
|
||||
self.save()
|
||||
await opencode_client.send_message_async(ps.session_id, text)
|
||||
|
||||
def get_session(self, session_id: str) -> PersistentSession | None:
|
||||
return self._sessions.get(session_id)
|
||||
|
||||
def list_sessions(self) -> list[PersistentSession]:
|
||||
return list(self._sessions.values())
|
||||
|
||||
|
||||
persistent_session_manager = PersistentSessionManager()
|
||||
@@ -1,10 +1,12 @@
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from app.services.session_manager import session_manager
|
||||
from app.services.storage import storage
|
||||
from app.core.logging import logger
|
||||
|
||||
SCHEDULE_KEY = "schedules"
|
||||
|
||||
|
||||
class ScheduleTask(BaseModel):
|
||||
id: str
|
||||
@@ -19,8 +21,29 @@ class SchedulerService:
|
||||
self.scheduler = AsyncIOScheduler()
|
||||
self.scheduled_tasks: dict[str, ScheduleTask] = {}
|
||||
|
||||
def load(self):
|
||||
data = storage.load(SCHEDULE_KEY, {})
|
||||
if isinstance(data, dict):
|
||||
for k, v in data.items():
|
||||
self.scheduled_tasks[k] = ScheduleTask(**v)
|
||||
logger.info(f"Loaded {len(self.scheduled_tasks)} scheduled tasks")
|
||||
|
||||
def save(self):
|
||||
data = {k: v.model_dump() for k, v in self.scheduled_tasks.items()}
|
||||
storage.save(SCHEDULE_KEY, data)
|
||||
|
||||
def start(self):
|
||||
self.scheduler.start()
|
||||
for task in self.scheduled_tasks.values():
|
||||
if task.enabled:
|
||||
trigger = CronTrigger.from_crontab(task.cron)
|
||||
self.scheduler.add_job(
|
||||
self._run_scheduled_task,
|
||||
trigger,
|
||||
args=[task.id],
|
||||
id=task.id,
|
||||
)
|
||||
logger.info(f"Loaded schedule: {task.name} ({task.cron})")
|
||||
logger.info("Scheduler started")
|
||||
|
||||
def shutdown(self):
|
||||
@@ -29,6 +52,7 @@ class SchedulerService:
|
||||
|
||||
async def add_schedule(self, task: ScheduleTask) -> ScheduleTask:
|
||||
self.scheduled_tasks[task.id] = task
|
||||
self.save()
|
||||
|
||||
if task.enabled:
|
||||
trigger = CronTrigger.from_crontab(task.cron)
|
||||
@@ -47,7 +71,12 @@ class SchedulerService:
|
||||
if not task:
|
||||
return False
|
||||
|
||||
self.scheduler.remove_job(task_id)
|
||||
try:
|
||||
self.scheduler.remove_job(task_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.save()
|
||||
logger.info(f"Removed schedule: {task.name}")
|
||||
return True
|
||||
|
||||
|
||||
@@ -1,21 +1,39 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from app.models.session import SessionType, Task, TaskStatus, CreateSessionRequest
|
||||
from app.services.opencode_client import opencode_client
|
||||
from app.services.storage import storage
|
||||
from app.services.history import history_service
|
||||
from app.core.logging import logger
|
||||
|
||||
TASKS_KEY = "tasks"
|
||||
|
||||
|
||||
class SessionManager:
|
||||
def __init__(self):
|
||||
self.tasks: dict[str, Task] = {}
|
||||
|
||||
def load(self):
|
||||
data = storage.load(TASKS_KEY, {})
|
||||
if isinstance(data, dict):
|
||||
for k, v in data.items():
|
||||
self.tasks[k] = Task(**v)
|
||||
logger.info(f"Loaded {len(self.tasks)} tasks")
|
||||
|
||||
def save(self):
|
||||
data = {k: v.model_dump() for k, v in self.tasks.items()}
|
||||
storage.save(TASKS_KEY, data)
|
||||
|
||||
async def create_task(self, request: CreateSessionRequest) -> Task:
|
||||
task_id = str(uuid.uuid4())
|
||||
task = Task(
|
||||
id=task_id,
|
||||
type=request.type,
|
||||
prompt=request.prompt or "",
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
self.tasks[task_id] = task
|
||||
self.save()
|
||||
logger.info(f"Created task: {task_id}")
|
||||
return task
|
||||
|
||||
@@ -25,20 +43,31 @@ class SessionManager:
|
||||
raise ValueError(f"Task {task_id} not found")
|
||||
|
||||
task.status = TaskStatus.RUNNING
|
||||
task.started_at = datetime.now()
|
||||
self.save()
|
||||
|
||||
session = await opencode_client.create_session(
|
||||
title=task.prompt[:50] if task.prompt else None
|
||||
)
|
||||
session_id = session["id"]
|
||||
task.session_id = session_id
|
||||
self.save()
|
||||
|
||||
logger.info(f"Executing task {task_id} with session {session_id}")
|
||||
|
||||
try:
|
||||
result = await opencode_client.send_message(session_id, task.prompt)
|
||||
task.status = TaskStatus.COMPLETED
|
||||
task.finished_at = datetime.now()
|
||||
self.save()
|
||||
history_service.add_task_history(task)
|
||||
return result
|
||||
except Exception as e:
|
||||
task.status = TaskStatus.FAILED
|
||||
task.finished_at = datetime.now()
|
||||
task.error = str(e)
|
||||
self.save()
|
||||
history_service.add_task_history(task)
|
||||
logger.error(f"Task {task_id} failed: {e}")
|
||||
raise
|
||||
|
||||
@@ -48,23 +77,47 @@ class SessionManager:
|
||||
raise ValueError(f"Task {task_id} not found")
|
||||
|
||||
task.status = TaskStatus.RUNNING
|
||||
task.started_at = datetime.now()
|
||||
self.save()
|
||||
|
||||
session = await opencode_client.create_session(
|
||||
title=task.prompt[:50] if task.prompt else None
|
||||
)
|
||||
session_id = session["id"]
|
||||
task.session_id = session_id
|
||||
self.save()
|
||||
|
||||
logger.info(f"Executing async task {task_id} with session {session_id}")
|
||||
|
||||
await opencode_client.send_message_async(session_id, task.prompt)
|
||||
|
||||
from app.services.websocket_manager import ws_manager
|
||||
|
||||
await ws_manager.start_poll_task(task_id, session_id)
|
||||
|
||||
async def update_task_status(
|
||||
self, task_id: str, status: TaskStatus, error: str | None = None
|
||||
):
|
||||
task = self.tasks.get(task_id)
|
||||
if task:
|
||||
task.status = status
|
||||
if status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.ABORTED):
|
||||
task.finished_at = datetime.now()
|
||||
if error:
|
||||
task.error = error
|
||||
self.save()
|
||||
history_service.add_task_history(task)
|
||||
|
||||
async def abort_task(self, task_id: str) -> bool:
|
||||
task = self.tasks.get(task_id)
|
||||
if not task or not task.session_id:
|
||||
return False
|
||||
|
||||
result = await opencode_client.abort_session(task.session_id)
|
||||
task.status = TaskStatus.FAILED
|
||||
task.status = TaskStatus.ABORTED
|
||||
task.finished_at = datetime.now()
|
||||
self.save()
|
||||
history_service.add_task_history(task)
|
||||
return result
|
||||
|
||||
async def get_task(self, task_id: str) -> Task | None:
|
||||
|
||||
40
app/services/storage.py
Normal file
40
app/services/storage.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import json
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from app.core.config import settings
|
||||
from app.core.logging import logger
|
||||
|
||||
|
||||
class JSONStorage:
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self.data_dir = settings.data_dir
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _get_file_path(self, key: str) -> Path:
|
||||
return self.data_dir / f"{key}.json"
|
||||
|
||||
def load(self, key: str, default: Any = None) -> Any:
|
||||
file_path = self._get_file_path(key)
|
||||
with self._lock:
|
||||
if not file_path.exists():
|
||||
return default
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load {key}: {e}")
|
||||
return default
|
||||
|
||||
def save(self, key: str, data: Any) -> None:
|
||||
file_path = self._get_file_path(key)
|
||||
with self._lock:
|
||||
try:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save {key}: {e}")
|
||||
|
||||
|
||||
storage = JSONStorage()
|
||||
106
app/services/websocket_manager.py
Normal file
106
app/services/websocket_manager.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import asyncio
|
||||
from typing import Set
|
||||
from fastapi import WebSocket
|
||||
from app.core.logging import logger
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
def __init__(self):
|
||||
self._clients: Set[WebSocket] = set()
|
||||
self._poll_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
self._clients.add(websocket)
|
||||
logger.info(f"WebSocket client connected. Total: {len(self._clients)}")
|
||||
|
||||
def disconnect(self, websocket: WebSocket):
|
||||
self._clients.discard(websocket)
|
||||
logger.info(f"WebSocket client disconnected. Total: {len(self._clients)}")
|
||||
|
||||
async def broadcast(self, message: dict):
|
||||
if not self._clients:
|
||||
return
|
||||
disconnected = set()
|
||||
for client in self._clients:
|
||||
try:
|
||||
await client.send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send WebSocket message: {e}")
|
||||
disconnected.add(client)
|
||||
for client in disconnected:
|
||||
self._clients.discard(client)
|
||||
|
||||
async def start_poll_task(self, task_id: str, session_id: str):
|
||||
if task_id in self._poll_tasks:
|
||||
return
|
||||
task = asyncio.create_task(self._poll_session_status(task_id, session_id))
|
||||
self._poll_tasks[task_id] = task
|
||||
|
||||
async def stop_poll_task(self, task_id: str):
|
||||
task = self._poll_tasks.pop(task_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
|
||||
async def _poll_session_status(self, task_id: str, session_id: str):
|
||||
from app.services.opencode_client import opencode_client
|
||||
from app.services.session_manager import session_manager
|
||||
from app.models.session import TaskStatus
|
||||
|
||||
poll_interval = 3
|
||||
max_attempts = 60
|
||||
|
||||
for _ in range(max_attempts):
|
||||
try:
|
||||
status = await opencode_client.get_session_status()
|
||||
sessions = status.get("sessions", {})
|
||||
session_info = sessions.get(session_id)
|
||||
|
||||
if session_info:
|
||||
current_status = session_info.get("status", "")
|
||||
if current_status == "finished":
|
||||
await session_manager.update_task_status(
|
||||
task_id, TaskStatus.COMPLETED
|
||||
)
|
||||
await self.broadcast(
|
||||
{
|
||||
"type": "task_status",
|
||||
"task_id": task_id,
|
||||
"status": TaskStatus.COMPLETED.value,
|
||||
}
|
||||
)
|
||||
break
|
||||
elif current_status == "error":
|
||||
error_msg = session_info.get("error", "Unknown error")
|
||||
await session_manager.update_task_status(
|
||||
task_id, TaskStatus.FAILED, error_msg
|
||||
)
|
||||
await self.broadcast(
|
||||
{
|
||||
"type": "task_status",
|
||||
"task_id": task_id,
|
||||
"status": TaskStatus.FAILED.value,
|
||||
"error": error_msg,
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
await self.broadcast(
|
||||
{
|
||||
"type": "task_status",
|
||||
"task_id": task_id,
|
||||
"status": TaskStatus.RUNNING.value,
|
||||
}
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error polling session status: {e}")
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
self._poll_tasks.pop(task_id, None)
|
||||
|
||||
|
||||
ws_manager = WebSocketManager()
|
||||
Reference in New Issue
Block a user