import asyncio
from typing import Set

from fastapi import WebSocket

from app.core.logging import logger


class WebSocketManager:
    """Track connected WebSocket clients and per-task session-status pollers.

    Connected clients receive every message passed to :meth:`broadcast`.
    Pollers are background ``asyncio`` tasks, keyed by ``task_id``. Each one
    periodically queries the OpenCode session status and broadcasts
    ``task_status`` updates. It stops when the session finishes or errors,
    or when its attempt budget runs out.
    """

    def __init__(self):
        self._clients: Set[WebSocket] = set()
        # Storing the Task objects keeps strong references to running
        # pollers and lets stop_poll_task() find them by task_id.
        self._poll_tasks: dict[str, asyncio.Task] = {}

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake and register the client for broadcasts."""
        await websocket.accept()
        self._clients.add(websocket)
        logger.info(f"WebSocket client connected. Total: {len(self._clients)}")

    def disconnect(self, websocket: WebSocket):
        """Unregister *websocket*; does nothing if it is not registered."""
        self._clients.discard(websocket)
        logger.info(f"WebSocket client disconnected. Total: {len(self._clients)}")

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every connected client.

        Clients whose send raises are logged and removed from the registry.
        """
        if not self._clients:
            return
        disconnected = set()
        # Iterate over a snapshot. Each send_json() awaits, and a concurrent
        # connect()/disconnect() mutating the live set would otherwise raise
        # "Set changed size during iteration".
        for client in list(self._clients):
            try:
                await client.send_json(message)
            except Exception as e:
                logger.error(f"Failed to send WebSocket message: {e}")
                disconnected.add(client)
        for client in disconnected:
            self._clients.discard(client)

    async def start_poll_task(
        self,
        task_id: str,
        session_id: str,
        poll_interval: float = 3,
        max_attempts: int = 60,
    ):
        """Start a background status poller for *task_id* unless one is running.

        Args:
            task_id: Registry key for the poller; also sent in broadcasts.
            session_id: Session whose status is looked up on each poll.
            poll_interval: Seconds to sleep between polls.
            max_attempts: Maximum number of polls before giving up.
        """
        existing = self._poll_tasks.get(task_id)
        # A finished task left in the registry must not block a new poller.
        if existing is not None and not existing.done():
            return
        task = asyncio.create_task(
            self._poll_session_status(
                task_id, session_id, poll_interval, max_attempts
            )
        )
        self._poll_tasks[task_id] = task

    async def stop_poll_task(self, task_id: str):
        """Cancel the poller for *task_id*, if any, without waiting for it to exit."""
        task = self._poll_tasks.pop(task_id, None)
        if task:
            task.cancel()

    async def _poll_session_status(
        self,
        task_id: str,
        session_id: str,
        poll_interval: float = 3,
        max_attempts: int = 60,
    ):
        """Poll the session status for *session_id* and broadcast task updates.

        Status handling on each poll:

        - ``"finished"``: the task is marked COMPLETED and broadcast, and
          polling stops.
        - ``"error"``: the task is marked FAILED with the session's ``error``
          text and broadcast, and polling stops.
        - Otherwise a RUNNING status is broadcast.

        Non-cancellation errors raised during a poll are logged and polling
        continues. Cancellation propagates to the caller. On exit the task
        removes its own registry entry, but only if that entry still refers
        to this task.
        """
        # Function-scope imports, presumably to avoid import cycles with these
        # service modules — TODO confirm.
        from app.services.opencode_client import opencode_client
        from app.services.session_manager import session_manager
        from app.models.session import TaskStatus

        try:
            for _ in range(max_attempts):
                try:
                    status = await opencode_client.get_session_status()
                    sessions = status.get("sessions", {})
                    session_info = sessions.get(session_id)
                    if session_info:
                        current_status = session_info.get("status", "")
                        if current_status == "finished":
                            await session_manager.update_task_status(
                                task_id, TaskStatus.COMPLETED
                            )
                            await self.broadcast(
                                {
                                    "type": "task_status",
                                    "task_id": task_id,
                                    "status": TaskStatus.COMPLETED.value,
                                }
                            )
                            break
                        elif current_status == "error":
                            error_msg = session_info.get("error", "Unknown error")
                            await session_manager.update_task_status(
                                task_id, TaskStatus.FAILED, error_msg
                            )
                            await self.broadcast(
                                {
                                    "type": "task_status",
                                    "task_id": task_id,
                                    "status": TaskStatus.FAILED.value,
                                    "error": error_msg,
                                }
                            )
                            break
                    await self.broadcast(
                        {
                            "type": "task_status",
                            "task_id": task_id,
                            "status": TaskStatus.RUNNING.value,
                        }
                    )
                except asyncio.CancelledError:
                    # Never swallow cancellation: stop_poll_task() relies on it.
                    raise
                except Exception as e:
                    logger.error(f"Error polling session status: {e}")
                await asyncio.sleep(poll_interval)
            else:
                # Attempt budget exhausted without a terminal status.
                logger.warning(
                    f"Stopped polling session {session_id} for task {task_id} "
                    f"after {max_attempts} attempts without a final status"
                )
        finally:
            # Runs on normal exit and on cancellation (including during sleep).
            # Deregister only if the registry still points at this task. After
            # stop_poll_task(), a newer poller may own the same task_id.
            if self._poll_tasks.get(task_id) is asyncio.current_task():
                del self._poll_tasks[task_id]


ws_manager = WebSocketManager()