107 lines
3.8 KiB
Python
107 lines
3.8 KiB
Python
|
|
import asyncio
|
||
|
|
from typing import Set
|
||
|
|
from fastapi import WebSocket
|
||
|
|
from app.core.logging import logger
|
||
|
|
|
||
|
|
|
||
|
|
class WebSocketManager:
    """Track connected WebSocket clients and background task-status pollers.

    Clients registered through :meth:`connect` receive every message passed to
    :meth:`broadcast`. :meth:`start_poll_task` spawns one background asyncio
    task per ``task_id``. That task polls ``opencode_client`` for the matching
    session's status and broadcasts ``task_status`` messages to all clients.
    """

    # Seconds to wait between two consecutive status polls of one session.
    POLL_INTERVAL_SECONDS = 3
    # Maximum number of polls per task before giving up (60 * 3 s = 3 minutes
    # with the defaults). Both constants may be overridden per instance.
    MAX_POLL_ATTEMPTS = 60

    def __init__(self):
        # Currently connected clients.
        self._clients: Set[WebSocket] = set()
        # Poller tasks keyed by task_id. Holding the reference here also keeps
        # each asyncio.Task from being garbage-collected while it runs.
        self._poll_tasks: dict[str, asyncio.Task] = {}

    async def connect(self, websocket: WebSocket):
        """Accept the handshake for *websocket* and register it for broadcasts."""
        await websocket.accept()
        self._clients.add(websocket)
        logger.info(f"WebSocket client connected. Total: {len(self._clients)}")

    def disconnect(self, websocket: WebSocket):
        """Unregister *websocket*; a no-op if it is not registered."""
        self._clients.discard(websocket)
        logger.info(f"WebSocket client disconnected. Total: {len(self._clients)}")

    async def broadcast(self, message: dict):
        """Send *message* as JSON to every registered client.

        A client whose ``send_json`` raises is logged and dropped from the
        registry; the remaining clients still receive the message.
        """
        if not self._clients:
            return
        disconnected = set()
        # Iterate over a snapshot: every ``await`` yields to the event loop, and
        # a concurrent connect()/disconnect() mutating the live set would
        # otherwise raise "RuntimeError: Set changed size during iteration".
        for client in list(self._clients):
            try:
                await client.send_json(message)
            except Exception as e:
                logger.error(f"Failed to send WebSocket message: {e}")
                disconnected.add(client)
        self._clients.difference_update(disconnected)

    async def start_poll_task(self, task_id: str, session_id: str):
        """Start polling *session_id* in the background on behalf of *task_id*.

        Does nothing if a still-running poller is already registered for
        *task_id*. A stale entry whose task has already finished is replaced.
        """
        existing = self._poll_tasks.get(task_id)
        if existing is not None and not existing.done():
            return
        self._poll_tasks[task_id] = asyncio.create_task(
            self._poll_session_status(task_id, session_id)
        )

    async def stop_poll_task(self, task_id: str):
        """Unregister and cancel the poller for *task_id*, if there is one.

        Cancellation is only requested; this method does not wait for the
        task to finish.
        """
        task = self._poll_tasks.pop(task_id, None)
        if task:
            task.cancel()

    async def _broadcast_task_status(self, task_id: str, status: str, **extra):
        """Broadcast a ``task_status`` message; *extra* keys are appended as-is."""
        await self.broadcast(
            {"type": "task_status", "task_id": task_id, "status": status, **extra}
        )

    async def _poll_session_status(self, task_id: str, session_id: str):
        """Poll opencode until *session_id* finishes or fails, broadcasting progress.

        Performs at most ``MAX_POLL_ATTEMPTS`` polls, ``POLL_INTERVAL_SECONDS``
        apart. A session status of ``"finished"`` or ``"error"`` is persisted
        through ``session_manager`` and broadcast, and polling stops. While the
        session is listed but not in either state, a RUNNING update is
        broadcast after each poll. Errors from a single poll are logged and
        polling continues. Cancellation propagates to the caller.
        """
        # Imported lazily, presumably to avoid circular imports — TODO confirm.
        from app.services.opencode_client import opencode_client
        from app.services.session_manager import session_manager
        from app.models.session import TaskStatus

        this_task = asyncio.current_task()
        try:
            for _ in range(self.MAX_POLL_ATTEMPTS):
                try:
                    status = await opencode_client.get_session_status()
                    session_info = status.get("sessions", {}).get(session_id)
                    if session_info:
                        current_status = session_info.get("status", "")
                        if current_status == "finished":
                            await session_manager.update_task_status(
                                task_id, TaskStatus.COMPLETED
                            )
                            await self._broadcast_task_status(
                                task_id, TaskStatus.COMPLETED.value
                            )
                            return
                        if current_status == "error":
                            error_msg = session_info.get("error", "Unknown error")
                            await session_manager.update_task_status(
                                task_id, TaskStatus.FAILED, error_msg
                            )
                            await self._broadcast_task_status(
                                task_id, TaskStatus.FAILED.value, error=error_msg
                            )
                            return
                        # Session is listed but not terminal: report it as running.
                        await self._broadcast_task_status(
                            task_id, TaskStatus.RUNNING.value
                        )
                except asyncio.CancelledError:
                    # Never swallow cancellation. This is explicit for Python 3.7,
                    # where CancelledError is still an Exception subclass.
                    raise
                except Exception as e:
                    logger.error(f"Error polling session status: {e}")

                await asyncio.sleep(self.POLL_INTERVAL_SECONDS)

            # The task's stored status is left unchanged; only report the give-up.
            logger.warning(
                f"Stopped polling session {session_id} for task {task_id} after "
                f"{self.MAX_POLL_ATTEMPTS} attempts without a final status"
            )
        finally:
            # Unregister only if the entry still refers to this task. After
            # stop_poll_task() + start_poll_task() with the same id, the entry
            # belongs to a newer poller and must be left alone.
            if self._poll_tasks.get(task_id) is this_task:
                del self._poll_tasks[task_id]
|
||
|
|
|
||
|
|
|
||
|
|
# Single module-level instance of the WebSocket manager.
ws_manager: WebSocketManager = WebSocketManager()
|