Files
XCClaw/app/services/websocket_manager.py

107 lines
3.8 KiB
Python

import asyncio
from typing import Set
from fastapi import WebSocket
from app.core.logging import logger
class WebSocketManager:
    """Registry of connected WebSocket clients plus per-task session pollers.

    ``broadcast`` fans a JSON message out to every registered client. For each
    ``task_id`` at most one background poll runs at a time. It queries the
    opencode session status and broadcasts ``task_status`` messages until the
    session reports ``"finished"`` or ``"error"``, or until the poll budget
    (``max_poll_attempts`` polls, ``poll_interval`` seconds apart) runs out.
    """

    # Seconds to wait between two polls of the opencode session status.
    poll_interval: float = 3
    # Maximum number of polls per task before giving up (~3 minutes of waiting
    # with the defaults).
    max_poll_attempts: int = 60

    def __init__(self):
        self._clients: Set[WebSocket] = set()
        # task_id -> running poll task. Each poll removes its own entry when it ends.
        self._poll_tasks: dict[str, asyncio.Task] = {}

    async def connect(self, websocket: WebSocket):
        """Accept the handshake for *websocket* and register it for broadcasts."""
        await websocket.accept()
        self._clients.add(websocket)
        logger.info(f"WebSocket client connected. Total: {len(self._clients)}")

    def disconnect(self, websocket: WebSocket):
        """Remove *websocket* from the broadcast set (no-op if it is not registered)."""
        self._clients.discard(websocket)
        logger.info(f"WebSocket client disconnected. Total: {len(self._clients)}")

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every registered client, one after another.

        A client whose send raises is logged and dropped from the registry.
        The remaining clients still receive the message.
        """
        if not self._clients:
            return
        disconnected = set()
        # Iterate over a snapshot. Every ``await`` yields to the event loop,
        # where connect()/disconnect() can mutate ``_clients``. Iterating the
        # live set would then raise "Set changed size during iteration".
        for client in tuple(self._clients):
            try:
                await client.send_json(message)
            except Exception as e:
                logger.error(f"Failed to send WebSocket message: {e}")
                disconnected.add(client)
        self._clients.difference_update(disconnected)

    async def start_poll_task(self, task_id: str, session_id: str):
        """Start a background poll of *session_id* that reports as *task_id*.

        Does nothing if a poll for *task_id* is still running. An entry whose
        task is already done is replaced, so it cannot block new polls forever.
        """
        existing = self._poll_tasks.get(task_id)
        if existing is not None and not existing.done():
            return
        self._poll_tasks[task_id] = asyncio.create_task(
            self._poll_session_status(task_id, session_id)
        )

    async def stop_poll_task(self, task_id: str):
        """Request cancellation of the poll for *task_id*, if one is registered.

        Returns immediately. It does not wait for the task to finish unwinding.
        """
        task = self._poll_tasks.pop(task_id, None)
        if task:
            task.cancel()

    async def _poll_session_status(self, task_id: str, session_id: str):
        """Poll opencode for *session_id* and broadcast status updates for *task_id*.

        Outcomes per poll:

        * ``"finished"``: the task is marked COMPLETED, broadcast, and polling ends.
        * ``"error"``: the task is marked FAILED with the session's ``error``
          text, broadcast, and polling ends.
        * Anything else, including the session being absent from the response:
          RUNNING is broadcast.

        Errors raised while polling are logged and the next poll is attempted.
        Cancellation propagates. The task's registry entry is removed when the
        poll ends, however it ends.
        """
        try:
            # Local imports, presumably to avoid circular imports at module load.
            from app.services.opencode_client import opencode_client
            from app.services.session_manager import session_manager
            from app.models.session import TaskStatus

            for _ in range(self.max_poll_attempts):
                try:
                    status = await opencode_client.get_session_status()
                    session_info = status.get("sessions", {}).get(session_id)
                    if session_info:
                        current_status = session_info.get("status", "")
                        if current_status == "finished":
                            await session_manager.update_task_status(
                                task_id, TaskStatus.COMPLETED
                            )
                            await self.broadcast(
                                {
                                    "type": "task_status",
                                    "task_id": task_id,
                                    "status": TaskStatus.COMPLETED.value,
                                }
                            )
                            return
                        if current_status == "error":
                            error_msg = session_info.get("error", "Unknown error")
                            await session_manager.update_task_status(
                                task_id, TaskStatus.FAILED, error_msg
                            )
                            await self.broadcast(
                                {
                                    "type": "task_status",
                                    "task_id": task_id,
                                    "status": TaskStatus.FAILED.value,
                                    "error": error_msg,
                                }
                            )
                            return
                    await self.broadcast(
                        {
                            "type": "task_status",
                            "task_id": task_id,
                            "status": TaskStatus.RUNNING.value,
                        }
                    )
                except asyncio.CancelledError:
                    # Let cancellation propagate so the task ends as cancelled
                    # instead of completing normally. Before Python 3.8,
                    # CancelledError was an Exception subclass, so catch it
                    # explicitly ahead of the generic handler.
                    raise
                except Exception as e:
                    logger.error(f"Error polling session status: {e}")
                await asyncio.sleep(self.poll_interval)
            # NOTE(review): the task's stored status is left as-is on timeout;
            # decide whether a timed-out poll should mark it FAILED.
            logger.warning(
                f"Gave up polling session {session_id} for task {task_id} "
                f"after {self.max_poll_attempts} attempts"
            )
        finally:
            # Only drop our own entry. After stop_poll_task(), a new poll for
            # the same task_id may already be registered under that key.
            if self._poll_tasks.get(task_id) is asyncio.current_task():
                del self._poll_tasks[task_id]
# Module-level shared WebSocketManager instance.
ws_manager: WebSocketManager = WebSocketManager()