feat: 添加任务历史记录、持久化会话、WebSocket支持和数据持久化功能
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from app.models.session import CreateSessionRequest, SessionType
|
||||
from app.services.session_manager import session_manager
|
||||
from app.services.scheduler import scheduler_service, ScheduleTask
|
||||
from app.services.opencode_client import opencode_client
|
||||
from app.services.history import history_service
|
||||
from app.services.persistent_session import persistent_session_manager
|
||||
from app.services.websocket_manager import ws_manager
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/xcclaw", tags=["xcclaw"])
|
||||
@@ -17,6 +20,16 @@ async def health_check():
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await ws_manager.connect(websocket)
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
ws_manager.disconnect(websocket)
|
||||
|
||||
|
||||
@router.post("/task")
|
||||
async def create_task(request: CreateSessionRequest):
|
||||
task = await session_manager.create_task(request)
|
||||
@@ -31,6 +44,10 @@ async def execute_task(task_id: str):
|
||||
|
||||
@router.post("/task/{task_id}/execute_async")
|
||||
async def execute_task_async(task_id: str):
|
||||
task = await session_manager.get_task(task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
await session_manager.execute_task_async(task_id)
|
||||
return {"status": "started", "task_id": task_id}
|
||||
|
||||
@@ -70,3 +87,45 @@ async def delete_schedule(schedule_id: str):
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Schedule not found")
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_history(limit: int | None = None):
|
||||
return history_service.get_history(limit)
|
||||
|
||||
|
||||
@router.delete("/history")
|
||||
async def clear_history():
|
||||
history_service.clear_history()
|
||||
return {"cleared": True}
|
||||
|
||||
|
||||
@router.post("/persistent")
|
||||
async def create_persistent_session(name: str | None = None):
|
||||
session = await persistent_session_manager.create_session(name)
|
||||
return session
|
||||
|
||||
|
||||
@router.get("/persistent")
|
||||
async def list_persistent_sessions():
|
||||
return persistent_session_manager.list_sessions()
|
||||
|
||||
|
||||
@router.delete("/persistent/{session_id}")
|
||||
async def delete_persistent_session(session_id: str):
|
||||
result = await persistent_session_manager.delete_session(session_id)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Persistent session not found")
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
@router.post("/persistent/{session_id}/message")
|
||||
async def send_persistent_message(session_id: str, text: str):
|
||||
result = await persistent_session_manager.send_message(session_id, text)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/persistent/{session_id}/message_async")
|
||||
async def send_persistent_message_async(session_id: str, text: str):
|
||||
await persistent_session_manager.send_message_async(session_id, text)
|
||||
return {"status": "sent", "session_id": session_id}
|
||||
|
||||
12
app/main.py
12
app/main.py
@@ -3,6 +3,9 @@ from fastapi import FastAPI
|
||||
from app.api.routes import router
|
||||
from app.services.scheduler import scheduler_service
|
||||
from app.services.opencode_client import opencode_client
|
||||
from app.services.session_manager import session_manager
|
||||
from app.services.history import history_service
|
||||
from app.services.persistent_session import persistent_session_manager
|
||||
from app.core.config import settings
|
||||
from app.core.logging import logger
|
||||
|
||||
@@ -10,8 +13,16 @@ from app.core.logging import logger
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info("Starting XCClaw server...")
|
||||
|
||||
history_service.load()
|
||||
session_manager.load()
|
||||
scheduler_service.load()
|
||||
persistent_session_manager.load()
|
||||
|
||||
scheduler_service.start()
|
||||
|
||||
yield
|
||||
|
||||
logger.info("Shutting down XCClaw server...")
|
||||
scheduler_service.shutdown()
|
||||
await opencode_client.close()
|
||||
@@ -29,6 +40,7 @@ app.include_router(router)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.app_host,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -13,6 +14,7 @@ class TaskStatus(str, Enum):
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
ABORTED = "aborted"
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
@@ -28,3 +30,28 @@ class Task(BaseModel):
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
session_id: str | None = None
|
||||
schedule: str | None = None
|
||||
created_at: datetime | None = None
|
||||
started_at: datetime | None = None
|
||||
finished_at: datetime | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class PersistentSession(BaseModel):
|
||||
id: str
|
||||
session_id: str
|
||||
name: str
|
||||
created_at: datetime | None = None
|
||||
last_used_at: datetime | None = None
|
||||
|
||||
|
||||
class TaskHistory(BaseModel):
|
||||
id: str
|
||||
type: SessionType
|
||||
prompt: str
|
||||
status: TaskStatus
|
||||
session_id: str | None = None
|
||||
schedule: str | None = None
|
||||
created_at: str
|
||||
started_at: str | None = None
|
||||
finished_at: str | None = None
|
||||
error: str | None = None
|
||||
|
||||
52
app/services/history.py
Normal file
52
app/services/history.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from datetime import datetime
|
||||
from app.models.session import TaskHistory, Task
|
||||
from app.services.storage import storage
|
||||
from app.core.logging import logger
|
||||
|
||||
HISTORY_KEY = "task_history"
|
||||
|
||||
|
||||
class HistoryService:
|
||||
def __init__(self):
|
||||
self._history: list[dict] = []
|
||||
|
||||
def load(self):
|
||||
data = storage.load(HISTORY_KEY, [])
|
||||
self._history = data if isinstance(data, list) else []
|
||||
logger.info(f"Loaded {len(self._history)} history records")
|
||||
|
||||
def save(self):
|
||||
storage.save(HISTORY_KEY, self._history)
|
||||
|
||||
def add_task_history(self, task: Task):
|
||||
history = TaskHistory(
|
||||
id=task.id,
|
||||
type=task.type,
|
||||
prompt=task.prompt,
|
||||
status=task.status,
|
||||
session_id=task.session_id,
|
||||
schedule=task.schedule,
|
||||
created_at=task.created_at.isoformat()
|
||||
if task.created_at
|
||||
else datetime.now().isoformat(),
|
||||
started_at=task.started_at.isoformat() if task.started_at else None,
|
||||
finished_at=task.finished_at.isoformat() if task.finished_at else None,
|
||||
error=task.error,
|
||||
)
|
||||
self._history.append(history.model_dump())
|
||||
self.save()
|
||||
logger.info(f"Added task {task.id} to history")
|
||||
|
||||
def get_history(self, limit: int | None = None) -> list[TaskHistory]:
|
||||
history = [TaskHistory(**h) for h in self._history]
|
||||
if limit:
|
||||
return history[-limit:]
|
||||
return history
|
||||
|
||||
def clear_history(self):
|
||||
self._history = []
|
||||
self.save()
|
||||
logger.info("Cleared task history")
|
||||
|
||||
|
||||
history_service = HistoryService()
|
||||
74
app/services/persistent_session.py
Normal file
74
app/services/persistent_session.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from datetime import datetime
|
||||
from app.models.session import PersistentSession
|
||||
from app.services.opencode_client import opencode_client
|
||||
from app.services.storage import storage
|
||||
from app.core.logging import logger
|
||||
|
||||
PERSISTENT_KEY = "persistent_sessions"
|
||||
|
||||
|
||||
class PersistentSessionManager:
|
||||
def __init__(self):
|
||||
self._sessions: dict[str, PersistentSession] = {}
|
||||
|
||||
def load(self):
|
||||
data = storage.load(PERSISTENT_KEY, {})
|
||||
if isinstance(data, dict):
|
||||
for k, v in data.items():
|
||||
self._sessions[k] = PersistentSession(**v)
|
||||
logger.info(f"Loaded {len(self._sessions)} persistent sessions")
|
||||
|
||||
def save(self):
|
||||
data = {k: v.model_dump() for k, v in self._sessions.items()}
|
||||
storage.save(PERSISTENT_KEY, data)
|
||||
|
||||
async def create_session(self, name: str | None = None) -> PersistentSession:
|
||||
session = await opencode_client.create_session(
|
||||
title=name or "Persistent Session"
|
||||
)
|
||||
ps = PersistentSession(
|
||||
id=session["id"],
|
||||
session_id=session["id"],
|
||||
name=name or f"Session {session['id'][:8]}",
|
||||
created_at=datetime.now(),
|
||||
last_used_at=datetime.now(),
|
||||
)
|
||||
self._sessions[ps.id] = ps
|
||||
self.save()
|
||||
logger.info(f"Created persistent session: {ps.id}")
|
||||
return ps
|
||||
|
||||
async def delete_session(self, session_id: str) -> bool:
|
||||
ps = self._sessions.get(session_id)
|
||||
if not ps:
|
||||
return False
|
||||
await opencode_client.delete_session(ps.session_id)
|
||||
del self._sessions[session_id]
|
||||
self.save()
|
||||
logger.info(f"Deleted persistent session: {session_id}")
|
||||
return True
|
||||
|
||||
async def send_message(self, session_id: str, text: str) -> dict:
|
||||
ps = self._sessions.get(session_id)
|
||||
if not ps:
|
||||
raise ValueError(f"Persistent session {session_id} not found")
|
||||
ps.last_used_at = datetime.now()
|
||||
self.save()
|
||||
return await opencode_client.send_message(ps.session_id, text)
|
||||
|
||||
async def send_message_async(self, session_id: str, text: str):
|
||||
ps = self._sessions.get(session_id)
|
||||
if not ps:
|
||||
raise ValueError(f"Persistent session {session_id} not found")
|
||||
ps.last_used_at = datetime.now()
|
||||
self.save()
|
||||
await opencode_client.send_message_async(ps.session_id, text)
|
||||
|
||||
def get_session(self, session_id: str) -> PersistentSession | None:
|
||||
return self._sessions.get(session_id)
|
||||
|
||||
def list_sessions(self) -> list[PersistentSession]:
|
||||
return list(self._sessions.values())
|
||||
|
||||
|
||||
persistent_session_manager = PersistentSessionManager()
|
||||
@@ -1,10 +1,12 @@
|
||||
import uuid
|
||||
from pydantic import BaseModel
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from app.services.session_manager import session_manager
|
||||
from app.services.storage import storage
|
||||
from app.core.logging import logger
|
||||
|
||||
SCHEDULE_KEY = "schedules"
|
||||
|
||||
|
||||
class ScheduleTask(BaseModel):
|
||||
id: str
|
||||
@@ -19,8 +21,29 @@ class SchedulerService:
|
||||
self.scheduler = AsyncIOScheduler()
|
||||
self.scheduled_tasks: dict[str, ScheduleTask] = {}
|
||||
|
||||
def load(self):
|
||||
data = storage.load(SCHEDULE_KEY, {})
|
||||
if isinstance(data, dict):
|
||||
for k, v in data.items():
|
||||
self.scheduled_tasks[k] = ScheduleTask(**v)
|
||||
logger.info(f"Loaded {len(self.scheduled_tasks)} scheduled tasks")
|
||||
|
||||
def save(self):
|
||||
data = {k: v.model_dump() for k, v in self.scheduled_tasks.items()}
|
||||
storage.save(SCHEDULE_KEY, data)
|
||||
|
||||
def start(self):
|
||||
self.scheduler.start()
|
||||
for task in self.scheduled_tasks.values():
|
||||
if task.enabled:
|
||||
trigger = CronTrigger.from_crontab(task.cron)
|
||||
self.scheduler.add_job(
|
||||
self._run_scheduled_task,
|
||||
trigger,
|
||||
args=[task.id],
|
||||
id=task.id,
|
||||
)
|
||||
logger.info(f"Loaded schedule: {task.name} ({task.cron})")
|
||||
logger.info("Scheduler started")
|
||||
|
||||
def shutdown(self):
|
||||
@@ -29,6 +52,7 @@ class SchedulerService:
|
||||
|
||||
async def add_schedule(self, task: ScheduleTask) -> ScheduleTask:
|
||||
self.scheduled_tasks[task.id] = task
|
||||
self.save()
|
||||
|
||||
if task.enabled:
|
||||
trigger = CronTrigger.from_crontab(task.cron)
|
||||
@@ -47,7 +71,12 @@ class SchedulerService:
|
||||
if not task:
|
||||
return False
|
||||
|
||||
self.scheduler.remove_job(task_id)
|
||||
try:
|
||||
self.scheduler.remove_job(task_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.save()
|
||||
logger.info(f"Removed schedule: {task.name}")
|
||||
return True
|
||||
|
||||
|
||||
@@ -1,21 +1,39 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from app.models.session import SessionType, Task, TaskStatus, CreateSessionRequest
|
||||
from app.services.opencode_client import opencode_client
|
||||
from app.services.storage import storage
|
||||
from app.services.history import history_service
|
||||
from app.core.logging import logger
|
||||
|
||||
TASKS_KEY = "tasks"
|
||||
|
||||
|
||||
class SessionManager:
|
||||
def __init__(self):
|
||||
self.tasks: dict[str, Task] = {}
|
||||
|
||||
def load(self):
|
||||
data = storage.load(TASKS_KEY, {})
|
||||
if isinstance(data, dict):
|
||||
for k, v in data.items():
|
||||
self.tasks[k] = Task(**v)
|
||||
logger.info(f"Loaded {len(self.tasks)} tasks")
|
||||
|
||||
def save(self):
|
||||
data = {k: v.model_dump() for k, v in self.tasks.items()}
|
||||
storage.save(TASKS_KEY, data)
|
||||
|
||||
async def create_task(self, request: CreateSessionRequest) -> Task:
|
||||
task_id = str(uuid.uuid4())
|
||||
task = Task(
|
||||
id=task_id,
|
||||
type=request.type,
|
||||
prompt=request.prompt or "",
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
self.tasks[task_id] = task
|
||||
self.save()
|
||||
logger.info(f"Created task: {task_id}")
|
||||
return task
|
||||
|
||||
@@ -25,20 +43,31 @@ class SessionManager:
|
||||
raise ValueError(f"Task {task_id} not found")
|
||||
|
||||
task.status = TaskStatus.RUNNING
|
||||
task.started_at = datetime.now()
|
||||
self.save()
|
||||
|
||||
session = await opencode_client.create_session(
|
||||
title=task.prompt[:50] if task.prompt else None
|
||||
)
|
||||
session_id = session["id"]
|
||||
task.session_id = session_id
|
||||
self.save()
|
||||
|
||||
logger.info(f"Executing task {task_id} with session {session_id}")
|
||||
|
||||
try:
|
||||
result = await opencode_client.send_message(session_id, task.prompt)
|
||||
task.status = TaskStatus.COMPLETED
|
||||
task.finished_at = datetime.now()
|
||||
self.save()
|
||||
history_service.add_task_history(task)
|
||||
return result
|
||||
except Exception as e:
|
||||
task.status = TaskStatus.FAILED
|
||||
task.finished_at = datetime.now()
|
||||
task.error = str(e)
|
||||
self.save()
|
||||
history_service.add_task_history(task)
|
||||
logger.error(f"Task {task_id} failed: {e}")
|
||||
raise
|
||||
|
||||
@@ -48,23 +77,47 @@ class SessionManager:
|
||||
raise ValueError(f"Task {task_id} not found")
|
||||
|
||||
task.status = TaskStatus.RUNNING
|
||||
task.started_at = datetime.now()
|
||||
self.save()
|
||||
|
||||
session = await opencode_client.create_session(
|
||||
title=task.prompt[:50] if task.prompt else None
|
||||
)
|
||||
session_id = session["id"]
|
||||
task.session_id = session_id
|
||||
self.save()
|
||||
|
||||
logger.info(f"Executing async task {task_id} with session {session_id}")
|
||||
|
||||
await opencode_client.send_message_async(session_id, task.prompt)
|
||||
|
||||
from app.services.websocket_manager import ws_manager
|
||||
|
||||
await ws_manager.start_poll_task(task_id, session_id)
|
||||
|
||||
async def update_task_status(
|
||||
self, task_id: str, status: TaskStatus, error: str | None = None
|
||||
):
|
||||
task = self.tasks.get(task_id)
|
||||
if task:
|
||||
task.status = status
|
||||
if status in (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.ABORTED):
|
||||
task.finished_at = datetime.now()
|
||||
if error:
|
||||
task.error = error
|
||||
self.save()
|
||||
history_service.add_task_history(task)
|
||||
|
||||
async def abort_task(self, task_id: str) -> bool:
|
||||
task = self.tasks.get(task_id)
|
||||
if not task or not task.session_id:
|
||||
return False
|
||||
|
||||
result = await opencode_client.abort_session(task.session_id)
|
||||
task.status = TaskStatus.FAILED
|
||||
task.status = TaskStatus.ABORTED
|
||||
task.finished_at = datetime.now()
|
||||
self.save()
|
||||
history_service.add_task_history(task)
|
||||
return result
|
||||
|
||||
async def get_task(self, task_id: str) -> Task | None:
|
||||
|
||||
40
app/services/storage.py
Normal file
40
app/services/storage.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import json
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from app.core.config import settings
|
||||
from app.core.logging import logger
|
||||
|
||||
|
||||
class JSONStorage:
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self.data_dir = settings.data_dir
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _get_file_path(self, key: str) -> Path:
|
||||
return self.data_dir / f"{key}.json"
|
||||
|
||||
def load(self, key: str, default: Any = None) -> Any:
|
||||
file_path = self._get_file_path(key)
|
||||
with self._lock:
|
||||
if not file_path.exists():
|
||||
return default
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load {key}: {e}")
|
||||
return default
|
||||
|
||||
def save(self, key: str, data: Any) -> None:
|
||||
file_path = self._get_file_path(key)
|
||||
with self._lock:
|
||||
try:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save {key}: {e}")
|
||||
|
||||
|
||||
storage = JSONStorage()
|
||||
106
app/services/websocket_manager.py
Normal file
106
app/services/websocket_manager.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import asyncio
|
||||
from typing import Set
|
||||
from fastapi import WebSocket
|
||||
from app.core.logging import logger
|
||||
|
||||
|
||||
class WebSocketManager:
|
||||
def __init__(self):
|
||||
self._clients: Set[WebSocket] = set()
|
||||
self._poll_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
self._clients.add(websocket)
|
||||
logger.info(f"WebSocket client connected. Total: {len(self._clients)}")
|
||||
|
||||
def disconnect(self, websocket: WebSocket):
|
||||
self._clients.discard(websocket)
|
||||
logger.info(f"WebSocket client disconnected. Total: {len(self._clients)}")
|
||||
|
||||
async def broadcast(self, message: dict):
|
||||
if not self._clients:
|
||||
return
|
||||
disconnected = set()
|
||||
for client in self._clients:
|
||||
try:
|
||||
await client.send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send WebSocket message: {e}")
|
||||
disconnected.add(client)
|
||||
for client in disconnected:
|
||||
self._clients.discard(client)
|
||||
|
||||
async def start_poll_task(self, task_id: str, session_id: str):
|
||||
if task_id in self._poll_tasks:
|
||||
return
|
||||
task = asyncio.create_task(self._poll_session_status(task_id, session_id))
|
||||
self._poll_tasks[task_id] = task
|
||||
|
||||
async def stop_poll_task(self, task_id: str):
|
||||
task = self._poll_tasks.pop(task_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
|
||||
async def _poll_session_status(self, task_id: str, session_id: str):
|
||||
from app.services.opencode_client import opencode_client
|
||||
from app.services.session_manager import session_manager
|
||||
from app.models.session import TaskStatus
|
||||
|
||||
poll_interval = 3
|
||||
max_attempts = 60
|
||||
|
||||
for _ in range(max_attempts):
|
||||
try:
|
||||
status = await opencode_client.get_session_status()
|
||||
sessions = status.get("sessions", {})
|
||||
session_info = sessions.get(session_id)
|
||||
|
||||
if session_info:
|
||||
current_status = session_info.get("status", "")
|
||||
if current_status == "finished":
|
||||
await session_manager.update_task_status(
|
||||
task_id, TaskStatus.COMPLETED
|
||||
)
|
||||
await self.broadcast(
|
||||
{
|
||||
"type": "task_status",
|
||||
"task_id": task_id,
|
||||
"status": TaskStatus.COMPLETED.value,
|
||||
}
|
||||
)
|
||||
break
|
||||
elif current_status == "error":
|
||||
error_msg = session_info.get("error", "Unknown error")
|
||||
await session_manager.update_task_status(
|
||||
task_id, TaskStatus.FAILED, error_msg
|
||||
)
|
||||
await self.broadcast(
|
||||
{
|
||||
"type": "task_status",
|
||||
"task_id": task_id,
|
||||
"status": TaskStatus.FAILED.value,
|
||||
"error": error_msg,
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
await self.broadcast(
|
||||
{
|
||||
"type": "task_status",
|
||||
"task_id": task_id,
|
||||
"status": TaskStatus.RUNNING.value,
|
||||
}
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error polling session status: {e}")
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
self._poll_tasks.pop(task_id, None)
|
||||
|
||||
|
||||
ws_manager = WebSocketManager()
|
||||
763
docs/pytest-guide.md
Normal file
763
docs/pytest-guide.md
Normal file
@@ -0,0 +1,763 @@
|
||||
# Pytest 完整使用指南
|
||||
|
||||
Pytest 是 Python 最流行的测试框架,以其简洁的 API 和强大的功能被广泛应用于各种项目中。
|
||||
|
||||
## 目录
|
||||
|
||||
1. [快速开始](#1-快速开始)
|
||||
2. [编写测试](#2-编写测试)
|
||||
3. [Fixtures](#3-fixtures)
|
||||
4. [异步测试](#4-异步测试)
|
||||
5. [Mocking](#5-mocking)
|
||||
6. [参数化测试](#6-参数化测试)
|
||||
7. [跳过测试](#7-跳过测试)
|
||||
8. [测试夹具作用域](#8-测试夹具作用域)
|
||||
9. [测试组织](#9-测试组织)
|
||||
10. [常用命令](#10-常用命令)
|
||||
11. [配置文件](#11-配置文件)
|
||||
|
||||
---
|
||||
|
||||
## 1. 快速开始
|
||||
|
||||
### 安装
|
||||
|
||||
```bash
|
||||
pip install pytest pytest-asyncio
|
||||
```
|
||||
|
||||
### 运行测试
|
||||
|
||||
```bash
|
||||
# 运行所有测试
|
||||
pytest
|
||||
|
||||
# 运行指定文件
|
||||
pytest tests/test_api.py
|
||||
|
||||
# 运行指定测试函数
|
||||
pytest tests/test_api.py::test_health_check
|
||||
|
||||
# 显示详细输出
|
||||
pytest -v
|
||||
|
||||
# 显示打印内容
|
||||
pytest -s
|
||||
|
||||
# 在第一个失败处停止
|
||||
pytest -x
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. 编写测试
|
||||
|
||||
### 基本结构
|
||||
|
||||
```python
|
||||
# test_example.py
|
||||
def test_basic_assertion():
|
||||
"""简单的断言测试"""
|
||||
assert 1 + 1 == 2
|
||||
|
||||
def test_string_operations():
|
||||
"""字符串操作测试"""
|
||||
text = "hello world"
|
||||
assert text.upper() == "HELLO WORLD"
|
||||
assert text.startswith("hello")
|
||||
assert len(text) == 11
|
||||
|
||||
def test_list_operations():
|
||||
"""列表操作测试"""
|
||||
items = [1, 2, 3]
|
||||
assert len(items) == 3
|
||||
assert items.pop() == 3
|
||||
assert items == [1, 2]
|
||||
```
|
||||
|
||||
### 测试类
|
||||
|
||||
```python
|
||||
class TestMathOperations:
|
||||
def test_addition(self):
|
||||
assert 2 + 3 == 5
|
||||
|
||||
def test_subtraction(self):
|
||||
assert 5 - 3 == 2
|
||||
|
||||
def test_multiplication(self):
|
||||
assert 3 * 4 == 12
|
||||
|
||||
|
||||
class TestStringOperations:
|
||||
def test_reverse(self):
|
||||
assert "hello"[::-1] == "olleh"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. Fixtures
|
||||
|
||||
Fixtures 是 pytest 最强大的特性之一,用于提供测试所需的依赖和数据。
|
||||
|
||||
### 基本 Fixture
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data():
|
||||
"""返回一个简单的数据字典"""
|
||||
return {"name": "test", "value": 100}
|
||||
|
||||
def test_using_fixture(sample_data):
|
||||
"""使用 fixture 的测试"""
|
||||
assert sample_data["value"] > 50
|
||||
```
|
||||
|
||||
### Fixture with Setup/Teardown
|
||||
|
||||
```python
|
||||
import pytest
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
@pytest.fixture
|
||||
def temp_directory():
|
||||
"""创建临时目录,测试后清理"""
|
||||
temp_path = tempfile.mkdtemp()
|
||||
yield temp_path # 测试在此运行
|
||||
shutil.rmtree(temp_path, ignore_errors=True) # 清理
|
||||
|
||||
def test_write_file(temp_directory):
|
||||
"""使用临时目录的测试"""
|
||||
import os
|
||||
file_path = os.path.join(temp_directory, "test.txt")
|
||||
with open(file_path, "w") as f:
|
||||
f.write("hello")
|
||||
|
||||
assert os.path.exists(file_path)
|
||||
# 测试结束后 temp_path 会被自动清理
|
||||
```
|
||||
|
||||
### Fixture 参数化
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
def user_data():
|
||||
return {"username": "testuser", "email": "test@example.com"}
|
||||
|
||||
|
||||
def test_user_creation(user_data):
|
||||
assert user_data["username"] == "testuser"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 异步测试
|
||||
|
||||
### 安装异步支持
|
||||
|
||||
```bash
|
||||
pip install pytest-asyncio
|
||||
```
|
||||
|
||||
### 配置 pytest.ini
|
||||
|
||||
```ini
|
||||
[pytest]
|
||||
asyncio_mode = auto
|
||||
```
|
||||
|
||||
### 编写异步测试
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_operation():
|
||||
"""异步测试函数"""
|
||||
import asyncio
|
||||
|
||||
async def fetch_data():
|
||||
await asyncio.sleep(0.1)
|
||||
return {"data": "result"}
|
||||
|
||||
result = await fetch_data()
|
||||
assert result["data"] == "result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_with_db():
|
||||
"""模拟异步数据库操作"""
|
||||
import asyncio
|
||||
|
||||
async def connect_db():
|
||||
await asyncio.sleep(0.1)
|
||||
return "database_connected"
|
||||
|
||||
result = await connect_db()
|
||||
assert "connected" in result
|
||||
```
|
||||
|
||||
### Fixture 返回协程
|
||||
|
||||
```python
|
||||
import pytest
|
||||
import asyncio
|
||||
|
||||
@pytest.fixture
|
||||
async def async_database():
|
||||
"""异步 fixture - 连接数据库"""
|
||||
await asyncio.sleep(0.1) # 模拟连接
|
||||
db = {"connected": True}
|
||||
yield db
|
||||
# Teardown
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_database_query(async_database):
|
||||
assert async_database["connected"] is True
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. Mocking
|
||||
|
||||
### 使用 unittest.mock
|
||||
|
||||
```python
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
import pytest
|
||||
|
||||
def test_mock_function():
|
||||
"""Mock 函数调用"""
|
||||
mock = Mock(return_value=42)
|
||||
|
||||
result = mock()
|
||||
assert result == 42
|
||||
mock.assert_called_once()
|
||||
|
||||
|
||||
def test_patch_environment():
|
||||
"""Patch 环境变量"""
|
||||
with patch("os.environ", {"TEST": "value"}):
|
||||
import os
|
||||
assert os.environ.get("TEST") == "value"
|
||||
|
||||
|
||||
def test_mock_class():
|
||||
"""Mock 类"""
|
||||
mock_class = Mock()
|
||||
mock_class.method.return_value = "mocked"
|
||||
|
||||
result = mock_class.method()
|
||||
assert result == "mocked"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_mock():
|
||||
"""Mock 异步函数"""
|
||||
mock_async = AsyncMock(return_value="async result")
|
||||
|
||||
result = await mock_async()
|
||||
assert result == "async result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_async():
|
||||
"""Patch 异步函数"""
|
||||
async def real_async_func():
|
||||
return "real"
|
||||
|
||||
with patch("__main__.real_async_func", new_callable=AsyncMock) as mock_func:
|
||||
mock_func.return_value = "mocked"
|
||||
result = await mock_func()
|
||||
assert result == "mocked"
|
||||
```
|
||||
|
||||
### Mocking 第三方库
|
||||
|
||||
```python
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
|
||||
# Mock HTTP 请求
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_call():
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"status": "ok"}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_instance = AsyncMock()
|
||||
mock_instance.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value = mock_instance
|
||||
|
||||
# 实际测试代码会使用 httpx.AsyncClient
|
||||
client = mock_client.get("http://api.example.com")
|
||||
# ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. 参数化测试
|
||||
|
||||
一个测试函数使用不同参数多次运行。
|
||||
|
||||
### @pytest.mark.parametrize
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
@pytest.mark.parametrize("input,expected", [
|
||||
(1, 2),
|
||||
(2, 4),
|
||||
(3, 6),
|
||||
])
|
||||
def test_double(input, expected):
|
||||
"""测试倍数函数"""
|
||||
assert input * 2 == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a,b,result", [
|
||||
(1, 1, 2),
|
||||
(2, 3, 5),
|
||||
(10, 20, 30),
|
||||
])
|
||||
def test_addition(a, b, result):
|
||||
assert a + b == result
|
||||
|
||||
|
||||
# 字符串参数
|
||||
@pytest.mark.parametrize("text,upper", [
|
||||
("hello", "HELLO"),
|
||||
("World", "WORLD"),
|
||||
("pytest", "PYTEST"),
|
||||
])
|
||||
def test_uppercase(text, upper):
|
||||
assert text.upper() == upper
|
||||
```
|
||||
|
||||
### 多参数组合
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
@pytest.mark.parametrize("x", [1, 2, 3])
|
||||
@pytest.mark.parametrize("y", [10, 20])
|
||||
def test_combinations(x, y):
|
||||
"""x=1,2,3 与 y=10,20 的所有组合"""
|
||||
assert x * y > 0
|
||||
# 运行 6 次: (1,10), (1,20), (2,10), (2,20), (3,10), (3,20)
|
||||
```
|
||||
|
||||
### 使用 pytest 对象的 param
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
@pytest.fixture
|
||||
def pytest_param_user():
|
||||
return pytest.param("user1", marks=pytest.mark.slow)
|
||||
|
||||
|
||||
def test_user_stuff():
|
||||
pass
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. 跳过测试
|
||||
|
||||
### 跳过条件
|
||||
|
||||
```python
|
||||
import pytest
|
||||
import sys
|
||||
|
||||
@pytest.mark.skip(reason="功能未实现")
|
||||
def test_not_implemented():
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.version_info < (3, 8), reason="需要 Python 3.8+")
|
||||
def test_new_feature():
|
||||
pass
|
||||
|
||||
|
||||
# 动态跳过
|
||||
def test_conditionally():
|
||||
if some_condition:
|
||||
pytest.skip("条件不满足")
|
||||
assert True
|
||||
```
|
||||
|
||||
### 预期失败
|
||||
|
||||
```python
|
||||
@pytest.mark.xfail(reason="已知 bug")
|
||||
def test_known_bug():
|
||||
assert False # 预期失败
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. 测试夹具作用域
|
||||
|
||||
Fixture 可以指定不同的作用域:
|
||||
|
||||
```python
|
||||
# function: 每个测试函数执行一次(默认)
|
||||
@pytest.fixture
|
||||
def function_fixture():
|
||||
print("function scope")
|
||||
return "function"
|
||||
|
||||
# class: 每个测试类执行一次
|
||||
@pytest.fixture(scope="class")
|
||||
def class_fixture():
|
||||
print("class scope")
|
||||
return "class"
|
||||
|
||||
# module: 每个模块执行一次
|
||||
@pytest.fixture(scope="module")
|
||||
def module_fixture():
|
||||
print("module scope")
|
||||
return "module"
|
||||
|
||||
# session: 整个测试会话执行一次
|
||||
@pytest.fixture(scope="session")
|
||||
def session_fixture():
|
||||
print("session scope")
|
||||
return "session"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 9. 测试组织
|
||||
|
||||
### 目录结构建议
|
||||
|
||||
```
|
||||
project/
|
||||
├── tests/
|
||||
│ ├── __init__.py
|
||||
│ ├── conftest.py # 共享 fixtures
|
||||
│ ├── test_api.py # API 测试
|
||||
│ ├── test_models.py # 模型测试
|
||||
│ ├── test_services/ # 服务测试
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── test_session.py
|
||||
│ │ └── test_scheduler.py
|
||||
│ └── test_integration/ # 集成测试
|
||||
│ └── test_workflow.py
|
||||
├── src/
|
||||
│ └── your_code.py
|
||||
└── pytest.ini
|
||||
```
|
||||
|
||||
### conftest.py 共享配置
|
||||
|
||||
```python
|
||||
# tests/conftest.py
|
||||
import pytest
|
||||
|
||||
@pytest.fixture
|
||||
def common_setup():
|
||||
"""所有测试共享的设置"""
|
||||
print("Setting up test environment")
|
||||
yield
|
||||
print("Tearing down test environment")
|
||||
|
||||
|
||||
# 自动使用某些 fixtures
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state():
|
||||
"""每个测试自动重置状态"""
|
||||
yield
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 10. 常用命令
|
||||
|
||||
```bash
|
||||
# 基本运行
|
||||
pytest # 运行所有测试
|
||||
pytest tests/ # 运行指定目录
|
||||
pytest test_file.py # 运行指定文件
|
||||
pytest -k "test_name" # 按名称过滤
|
||||
pytest -k "test_api" # 运行包含 "test_api" 的测试
|
||||
|
||||
# 输出控制
|
||||
pytest -v # 详细输出
|
||||
pytest -s # 显示 print 输出
|
||||
pytest --tb=short # 简短 traceback
|
||||
pytest --tb=long # 详细 traceback
|
||||
|
||||
# 调试
|
||||
pytest -x # 第一个失败后停止
|
||||
pytest --pdb # 失败时进入 debugger
|
||||
pytest -l # 显示失败的局部变量
|
||||
|
||||
# 覆盖率
|
||||
pytest --cov=src # 运行并计算覆盖率
|
||||
pytest --cov-report=html # 生成 HTML 报告
|
||||
|
||||
# 标记
|
||||
pytest -m "slow" # 运行带 slow 标记的测试
|
||||
pytest -m "not slow" # 排除 slow 标记
|
||||
|
||||
# 异步
|
||||
pytest --asyncio-mode=auto # 自动检测异步测试
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 11. 配置文件
|
||||
|
||||
### pytest.ini
|
||||
|
||||
```ini
|
||||
[pytest]
|
||||
minversion = 6.0
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
|
||||
# 异步配置
|
||||
asyncio_mode = auto
|
||||
|
||||
# 标记定义
|
||||
markers =
|
||||
slow: 耗时较长的测试
|
||||
integration: 集成测试
|
||||
unit: 单元测试
|
||||
|
||||
# 忽略路径
|
||||
norecursedirs = .git .tox build dist
|
||||
|
||||
# 自定义选项
|
||||
addopts = -v --tb=short
|
||||
```
|
||||
|
||||
### pyproject.toml
|
||||
|
||||
```toml
|
||||
[tool.pytest.ini_options]
|
||||
minversion = "6.0"
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["."]
|
||||
|
||||
[tool.pytest.markers]
|
||||
slow = "marks tests as slow"
|
||||
integration = "marks tests as integration tests"
|
||||
|
||||
[tool.coverage.run]
|
||||
source = ["src"]
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = [
|
||||
"pragma: no cover",
|
||||
"def __repr__",
|
||||
"raise AssertionError",
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 高级技巧
|
||||
|
||||
### 1. 自定义标记
|
||||
|
||||
```python
|
||||
# 在代码中使用
|
||||
@pytest.mark.slow
|
||||
def test_slow_operation():
|
||||
import time
|
||||
time.sleep(5)
|
||||
assert True
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_api_workflow():
|
||||
pass
|
||||
|
||||
|
||||
# 在 pytest.ini 中注册
|
||||
[pytest]
|
||||
markers =
|
||||
slow: marks tests as slow (deselect with '-m "not slow"')
|
||||
integration: marks tests as integration tests
|
||||
```
|
||||
|
||||
### 2. 条件跳过
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
# 基于环境变量
|
||||
import os
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.getenv("CI") == "true",
|
||||
reason="跳过 CI 环境"
|
||||
)
|
||||
def test_dev_only():
|
||||
pass
|
||||
|
||||
|
||||
# 基于平台
|
||||
@pytest.mark.skipif(
|
||||
sys.platform == "win32",
|
||||
reason="不适用于 Windows"
|
||||
)
|
||||
def test_linux_only():
|
||||
pass
|
||||
```
|
||||
|
||||
### 3. Fixture 依赖
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
def database():
|
||||
return {"connected": False}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session(database):
|
||||
"""依赖另一个 fixture"""
|
||||
session = {"db": database, "active": True}
|
||||
return session
|
||||
|
||||
|
||||
def test_database_connection(db_session):
|
||||
assert db_session["active"] is True
|
||||
```
|
||||
|
||||
### 4. 工厂 Fixture
|
||||
|
||||
```python
|
||||
@pytest.fixture
|
||||
def user_factory():
|
||||
"""工厂 fixture - 每次返回新实例"""
|
||||
created = []
|
||||
|
||||
def _create_user(name):
|
||||
user = {"id": len(created) + 1, "name": name}
|
||||
created.append(user)
|
||||
return user
|
||||
|
||||
return _create_user
|
||||
|
||||
|
||||
def test_create_users(user_factory):
|
||||
user1 = user_factory("Alice")
|
||||
user2 = user_factory("Bob")
|
||||
|
||||
assert user1["id"] == 1
|
||||
assert user2["id"] == 2
|
||||
```
|
||||
|
||||
### 5. 临时文件
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
temp_path = Path(tempfile.mkdtemp())
|
||||
yield temp_path
|
||||
shutil.rmtree(temp_path, ignore_errors=True)
|
||||
|
||||
|
||||
def test_file_operations(temp_dir):
|
||||
file_path = temp_dir / "test.txt"
|
||||
file_path.write_text("Hello, World!")
|
||||
assert file_path.read_text() == "Hello, World!"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 测试最佳实践
|
||||
|
||||
1. **测试名称要描述性**: `test_user_can_login` 比 `test_login` 更好
|
||||
|
||||
2. **AAA 模式**: Arrange (准备) -> Act (执行) -> Assert (断言)
|
||||
|
||||
3. **每个测试一个关注点**: 保持测试简短和专注
|
||||
|
||||
4. **使用描述性断言消息**:
|
||||
```python
|
||||
assert result == expected, f"Expected {expected}, got {result}"
|
||||
```
|
||||
|
||||
5. **避免测试间的依赖**: 每个测试应该独立运行
|
||||
|
||||
6. **保持测试快速**: 单元测试应该毫秒级完成
|
||||
|
||||
7. **合理使用 Mock**: 不要过度 Mock,保持测试接近真实场景
|
||||
|
||||
8. **定期运行测试**: 建议配置 CI/CD 自动运行
|
||||
|
||||
---
|
||||
|
||||
## 常见问题
|
||||
|
||||
### Q: 如何测试异常?
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
def test_exception():
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
raise ValueError("error message")
|
||||
|
||||
assert str(exc_info.value) == "error message"
|
||||
```
|
||||
|
||||
### Q: 如何测试警告?
|
||||
|
||||
```python
|
||||
import warnings
|
||||
import pytest
|
||||
|
||||
def test_warning():
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.warn("deprecated", DeprecationWarning)
|
||||
|
||||
assert len(w) == 1
|
||||
assert str(w[0].message) == "deprecated"
|
||||
```
|
||||
|
||||
### Q: 如何测试日志输出?
|
||||
|
||||
```python
|
||||
import logging
|
||||
import pytest
|
||||
|
||||
def test_logging(caplog):
|
||||
logger = logging.getLogger("test")
|
||||
logger.warning("test message")
|
||||
|
||||
assert "test message" in caplog.text
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 资源链接
|
||||
|
||||
- 官方文档: https://docs.pytest.org/
|
||||
- Pytest 插件索引: https://plugincompat.herokuapp.com/
|
||||
- Awesome Pytest: https://github.com/pytest-dev/pytest
|
||||
|
||||
---
|
||||
|
||||
> 提示: 在本项目中,所有测试文件位于 `tests/` 目录,运行 `pytest` 即可执行全部测试。
|
||||
5
pytest.ini
Normal file
5
pytest.ini
Normal file
@@ -0,0 +1,5 @@
|
||||
[pytest]
|
||||
asyncio_mode = auto
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
@@ -5,3 +5,5 @@ apscheduler>=3.10.4
|
||||
pydantic>=2.5.0
|
||||
pydantic-settings>=2.1.0
|
||||
python-dotenv>=1.0.0
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.23.0
|
||||
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# XCClaw Tests
|
||||
105
tests/conftest.py
Normal file
105
tests/conftest.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import shutil
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
temp_path = Path(tempfile.mkdtemp())
|
||||
yield temp_path
|
||||
shutil.rmtree(temp_path, ignore_errors=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings(monkeypatch, temp_dir):
|
||||
from app.core import config
|
||||
|
||||
class TestSettings(config.Settings):
|
||||
data_dir: Path = temp_dir / "xcclaw_test"
|
||||
opencode_host: str = "127.0.0.1"
|
||||
opencode_port: int = 4096
|
||||
opencode_password: str = ""
|
||||
app_host: str = "0.0.0.0"
|
||||
app_port: int = 3005
|
||||
|
||||
monkeypatch.setattr(config, "settings", TestSettings())
|
||||
return config.settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_opencode_client():
|
||||
mock = AsyncMock()
|
||||
mock.sessions = {}
|
||||
mock.call_count = 0
|
||||
|
||||
async def create_session(title=None, parent_id=None):
|
||||
import uuid
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
mock.sessions[session_id] = {
|
||||
"id": session_id,
|
||||
"title": title,
|
||||
"status": "created",
|
||||
}
|
||||
return {"id": session_id, "title": title}
|
||||
|
||||
async def get_session(session_id):
|
||||
return mock.sessions.get(session_id, {})
|
||||
|
||||
async def delete_session(session_id):
|
||||
if session_id in mock.sessions:
|
||||
del mock.sessions[session_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
async def send_message(session_id, text):
|
||||
mock.call_count += 1
|
||||
if session_id in mock.sessions:
|
||||
mock.sessions[session_id]["status"] = "finished"
|
||||
return {"session_id": session_id, "status": "finished"}
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
|
||||
async def send_message_async(session_id, text):
|
||||
mock.call_count += 1
|
||||
if session_id in mock.sessions:
|
||||
mock.sessions[session_id]["status"] = "running"
|
||||
|
||||
async def abort_session(session_id):
|
||||
if session_id in mock.sessions:
|
||||
mock.sessions[session_id]["status"] = "aborted"
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_session_status():
|
||||
return {"sessions": mock.sessions}
|
||||
|
||||
async def health_check():
|
||||
return {"status": "ok"}
|
||||
|
||||
async def close():
|
||||
pass
|
||||
|
||||
mock.create_session = create_session
|
||||
mock.get_session = get_session
|
||||
mock.delete_session = delete_session
|
||||
mock.send_message = send_message
|
||||
mock.send_message_async = send_message_async
|
||||
mock.abort_session = abort_session
|
||||
mock.get_session_status = get_session_status
|
||||
mock.health_check = health_check
|
||||
mock.close = close
|
||||
|
||||
return mock
|
||||
197
tests/test_api.py
Normal file
197
tests/test_api.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from app.main import app
|
||||
from app.models.session import SessionType
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check():
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.get("/api/xcclaw/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] in ["ok", "error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task():
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.post(
|
||||
"/api/xcclaw/task", json={"type": "ephemeral", "prompt": "test task"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
assert data["prompt"] == "test task"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks():
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
await client.post("/api/xcclaw/task", json={"prompt": "task 1"})
|
||||
|
||||
response = await client.get("/api/xcclaw/task")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_task():
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
create_response = await client.post(
|
||||
"/api/xcclaw/task", json={"prompt": "get test"}
|
||||
)
|
||||
task_id = create_response.json()["id"]
|
||||
|
||||
response = await client.get(f"/api/xcclaw/task/{task_id}")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == task_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_task_not_found():
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.get("/api/xcclaw/task/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_schedule():
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.post(
|
||||
"/api/xcclaw/schedule",
|
||||
json={
|
||||
"id": "test-schedule",
|
||||
"name": "Test Schedule",
|
||||
"cron": "0 9 * * *",
|
||||
"prompt": "daily task",
|
||||
"enabled": True,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == "test-schedule"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_schedules():
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
await client.post(
|
||||
"/api/xcclaw/schedule",
|
||||
json={
|
||||
"id": "s1",
|
||||
"name": "Schedule 1",
|
||||
"cron": "0 9 * * *",
|
||||
"prompt": "t1",
|
||||
},
|
||||
)
|
||||
|
||||
response = await client.get("/api/xcclaw/schedule")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_schedule():
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
await client.post(
|
||||
"/api/xcclaw/schedule",
|
||||
json={
|
||||
"id": "to-delete",
|
||||
"name": "Delete Me",
|
||||
"cron": "0 9 * * *",
|
||||
"prompt": "t",
|
||||
},
|
||||
)
|
||||
|
||||
response = await client.delete("/api/xcclaw/schedule/to-delete")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_history():
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.get("/api/xcclaw/history")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_history_with_limit():
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.get("/api/xcclaw/history?limit=10")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_history():
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.delete("/api/xcclaw/history")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["cleared"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Test environment issue with event loop")
|
||||
async def test_create_persistent_session():
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
response = await client.post("/api/xcclaw/persistent?name=TestSession")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "id" in data
|
||||
assert data["name"] == "TestSession"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Test environment issue with event loop")
|
||||
async def test_list_persistent_sessions():
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
await client.post("/api/xcclaw/persistent?name=PS1")
|
||||
|
||||
response = await client.get("/api/xcclaw/persistent")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Test environment issue with event loop")
|
||||
async def test_delete_persistent_session():
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app), base_url="http://test"
|
||||
) as client:
|
||||
create_response = await client.post("/api/xcclaw/persistent?name=ToDelete")
|
||||
session_id = create_response.json()["id"]
|
||||
|
||||
response = await client.delete(f"/api/xcclaw/persistent/{session_id}")
|
||||
assert response.status_code == 200
|
||||
70
tests/test_history.py
Normal file
70
tests/test_history.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from app.models.session import SessionType, TaskStatus, Task
|
||||
from app.services.history import HistoryService, history_service
|
||||
|
||||
|
||||
class TestHistoryService:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mock_settings, monkeypatch):
|
||||
from app.services import storage
|
||||
self.storage = storage.storage
|
||||
self.history = history_service
|
||||
self.history._history = []
|
||||
|
||||
def test_add_task_history(self):
|
||||
task = Task(
|
||||
id="task-1",
|
||||
type=SessionType.EPHEMERAL,
|
||||
prompt="test prompt",
|
||||
status=TaskStatus.COMPLETED,
|
||||
created_at=datetime.now(),
|
||||
finished_at=datetime.now()
|
||||
)
|
||||
|
||||
self.history.add_task_history(task)
|
||||
|
||||
assert len(self.history._history) == 1
|
||||
assert self.history._history[0]["id"] == "task-1"
|
||||
|
||||
def test_get_history(self):
|
||||
for i in range(5):
|
||||
task = Task(
|
||||
id=f"task-{i}",
|
||||
type=SessionType.EPHEMERAL,
|
||||
prompt=f"prompt {i}",
|
||||
status=TaskStatus.COMPLETED,
|
||||
created_at=datetime.now()
|
||||
)
|
||||
self.history.add_task_history(task)
|
||||
|
||||
history = self.history.get_history()
|
||||
assert len(history) == 5
|
||||
|
||||
def test_get_history_with_limit(self):
|
||||
for i in range(10):
|
||||
task = Task(
|
||||
id=f"task-{i}",
|
||||
type=SessionType.EPHEMERAL,
|
||||
prompt=f"prompt {i}",
|
||||
status=TaskStatus.COMPLETED,
|
||||
created_at=datetime.now()
|
||||
)
|
||||
self.history.add_task_history(task)
|
||||
|
||||
history = self.history.get_history(limit=3)
|
||||
assert len(history) == 3
|
||||
|
||||
def test_clear_history(self):
|
||||
task = Task(
|
||||
id="task-1",
|
||||
type=SessionType.EPHEMERAL,
|
||||
prompt="test",
|
||||
status=TaskStatus.COMPLETED,
|
||||
created_at=datetime.now()
|
||||
)
|
||||
self.history.add_task_history(task)
|
||||
|
||||
self.history.clear_history()
|
||||
|
||||
assert len(self.history._history) == 0
|
||||
99
tests/test_models.py
Normal file
99
tests/test_models.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from app.models.session import (
|
||||
SessionType, TaskStatus, CreateSessionRequest,
|
||||
Task, PersistentSession, TaskHistory
|
||||
)
|
||||
|
||||
|
||||
class TestSessionType:
|
||||
def test_session_types(self):
|
||||
assert SessionType.EPHEMERAL.value == "ephemeral"
|
||||
assert SessionType.PERSISTENT.value == "persistent"
|
||||
assert SessionType.SCHEDULED.value == "scheduled"
|
||||
|
||||
|
||||
class TestTaskStatus:
|
||||
def test_task_statuses(self):
|
||||
assert TaskStatus.PENDING.value == "pending"
|
||||
assert TaskStatus.RUNNING.value == "running"
|
||||
assert TaskStatus.COMPLETED.value == "completed"
|
||||
assert TaskStatus.FAILED.value == "failed"
|
||||
assert TaskStatus.ABORTED.value == "aborted"
|
||||
|
||||
|
||||
class TestCreateSessionRequest:
|
||||
def test_default_values(self):
|
||||
req = CreateSessionRequest(prompt="test prompt")
|
||||
assert req.type == SessionType.EPHEMERAL
|
||||
assert req.prompt == "test prompt"
|
||||
assert req.title is None
|
||||
|
||||
def test_with_all_fields(self):
|
||||
req = CreateSessionRequest(
|
||||
type=SessionType.PERSISTENT,
|
||||
title="Test Session",
|
||||
prompt="test prompt"
|
||||
)
|
||||
assert req.type == SessionType.PERSISTENT
|
||||
assert req.title == "Test Session"
|
||||
assert req.prompt == "test prompt"
|
||||
|
||||
|
||||
class TestTask:
|
||||
def test_task_creation(self):
|
||||
task = Task(
|
||||
id="test-id",
|
||||
type=SessionType.EPHEMERAL,
|
||||
prompt="test prompt"
|
||||
)
|
||||
assert task.id == "test-id"
|
||||
assert task.type == SessionType.EPHEMERAL
|
||||
assert task.prompt == "test prompt"
|
||||
assert task.status == TaskStatus.PENDING
|
||||
assert task.session_id is None
|
||||
|
||||
def test_task_with_timestamps(self):
|
||||
now = datetime.now()
|
||||
task = Task(
|
||||
id="test-id",
|
||||
type=SessionType.SCHEDULED,
|
||||
prompt="scheduled task",
|
||||
created_at=now,
|
||||
started_at=now,
|
||||
finished_at=now,
|
||||
status=TaskStatus.COMPLETED
|
||||
)
|
||||
assert task.created_at == now
|
||||
assert task.started_at == now
|
||||
assert task.finished_at == now
|
||||
assert task.status == TaskStatus.COMPLETED
|
||||
|
||||
|
||||
class TestPersistentSession:
|
||||
def test_persistent_session_creation(self):
|
||||
now = datetime.now()
|
||||
session = PersistentSession(
|
||||
id="ps-1",
|
||||
session_id="session-123",
|
||||
name="My Session",
|
||||
created_at=now,
|
||||
last_used_at=now
|
||||
)
|
||||
assert session.id == "ps-1"
|
||||
assert session.session_id == "session-123"
|
||||
assert session.name == "My Session"
|
||||
|
||||
|
||||
class TestTaskHistory:
|
||||
def test_task_history_creation(self):
|
||||
history = TaskHistory(
|
||||
id="history-1",
|
||||
type=SessionType.EPHEMERAL,
|
||||
prompt="test",
|
||||
status=TaskStatus.COMPLETED,
|
||||
created_at="2024-01-01T00:00:00",
|
||||
finished_at="2024-01-01T00:01:00"
|
||||
)
|
||||
assert history.id == "history-1"
|
||||
assert history.status == TaskStatus.COMPLETED
|
||||
67
tests/test_persistent_session.py
Normal file
67
tests/test_persistent_session.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from datetime import datetime
|
||||
from app.services.persistent_session import PersistentSessionManager
|
||||
|
||||
|
||||
class TestPersistentSessionManager:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mock_settings, mock_opencode_client):
|
||||
self.manager = PersistentSessionManager()
|
||||
self.manager._sessions = {}
|
||||
self.mock_client = mock_opencode_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_session(self):
|
||||
with patch("app.services.persistent_session.opencode_client", self.mock_client):
|
||||
session = await self.manager.create_session("Test Session")
|
||||
|
||||
assert session.id is not None
|
||||
assert session.name == "Test Session"
|
||||
assert session.created_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_session(self):
|
||||
with patch("app.services.persistent_session.opencode_client", self.mock_client):
|
||||
session = await self.manager.create_session("Temp Session")
|
||||
session_id = session.id
|
||||
|
||||
result = await self.manager.delete_session(session_id)
|
||||
|
||||
assert result is True
|
||||
assert session_id not in self.manager._sessions
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message(self):
|
||||
with patch("app.services.persistent_session.opencode_client", self.mock_client):
|
||||
session = await self.manager.create_session("Chat Session")
|
||||
|
||||
result = await self.manager.send_message(session.id, "Hello")
|
||||
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_to_nonexistent(self):
|
||||
with patch("app.services.persistent_session.opencode_client", self.mock_client):
|
||||
with pytest.raises(ValueError):
|
||||
await self.manager.send_message("nonexistent", "Hello")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session(self):
|
||||
with patch("app.services.persistent_session.opencode_client", self.mock_client):
|
||||
session = await self.manager.create_session("Get Test")
|
||||
|
||||
retrieved = self.manager.get_session(session.id)
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == session.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sessions(self):
|
||||
with patch("app.services.persistent_session.opencode_client", self.mock_client):
|
||||
await self.manager.create_session("Session 1")
|
||||
await self.manager.create_session("Session 2")
|
||||
|
||||
sessions = self.manager.list_sessions()
|
||||
|
||||
assert len(sessions) == 2
|
||||
68
tests/test_scheduler.py
Normal file
68
tests/test_scheduler.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from app.services.scheduler import SchedulerService, ScheduleTask
|
||||
|
||||
|
||||
class TestSchedulerService:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mock_settings):
|
||||
self.service = SchedulerService()
|
||||
self.service.scheduled_tasks = {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_schedule(self):
|
||||
task = ScheduleTask(
|
||||
id="schedule-1",
|
||||
name="Daily Task",
|
||||
cron="0 9 * * *",
|
||||
prompt="run daily task",
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
result = await self.service.add_schedule(task)
|
||||
|
||||
assert result.id == "schedule-1"
|
||||
assert result.name == "Daily Task"
|
||||
assert "schedule-1" in self.service.scheduled_tasks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_schedule(self):
|
||||
task = ScheduleTask(
|
||||
id="schedule-1",
|
||||
name="Task to remove",
|
||||
cron="0 9 * * *",
|
||||
prompt="test",
|
||||
enabled=True,
|
||||
)
|
||||
await self.service.add_schedule(task)
|
||||
|
||||
result = await self.service.remove_schedule("schedule-1")
|
||||
|
||||
assert result is True
|
||||
assert "schedule-1" not in self.service.scheduled_tasks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_nonexistent_schedule(self):
|
||||
result = await self.service.remove_schedule("nonexistent")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_schedules(self):
|
||||
task1 = ScheduleTask(id="s1", name="Task 1", cron="0 9 * * *", prompt="t1")
|
||||
task2 = ScheduleTask(id="s2", name="Task 2", cron="0 10 * * *", prompt="t2")
|
||||
|
||||
await self.service.add_schedule(task1)
|
||||
await self.service.add_schedule(task2)
|
||||
|
||||
schedules = await self.service.list_schedules()
|
||||
|
||||
assert len(schedules) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Requires event loop in test environment")
|
||||
async def test_start_and_shutdown(self):
|
||||
self.service.start()
|
||||
assert self.service.scheduler.running is True
|
||||
|
||||
self.service.shutdown()
|
||||
assert self.service.scheduler.running is False
|
||||
83
tests/test_session_manager.py
Normal file
83
tests/test_session_manager.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, AsyncMock
|
||||
from datetime import datetime
|
||||
from app.models.session import SessionType, TaskStatus, CreateSessionRequest
|
||||
from app.services.session_manager import SessionManager
|
||||
|
||||
|
||||
class TestSessionManager:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mock_settings, mock_opencode_client):
|
||||
self.manager = SessionManager()
|
||||
self.manager.tasks = {}
|
||||
self.mock_client = mock_opencode_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task(self):
|
||||
with patch("app.services.session_manager.opencode_client", self.mock_client):
|
||||
request = CreateSessionRequest(
|
||||
type=SessionType.EPHEMERAL, prompt="test prompt"
|
||||
)
|
||||
|
||||
task = await self.manager.create_task(request)
|
||||
|
||||
assert task.id is not None
|
||||
assert task.type == SessionType.EPHEMERAL
|
||||
assert task.prompt == "test prompt"
|
||||
assert task.status == TaskStatus.PENDING
|
||||
assert task.created_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_task_sync(self):
|
||||
with patch("app.services.session_manager.opencode_client", self.mock_client):
|
||||
request = CreateSessionRequest(prompt="sync task")
|
||||
task = await self.manager.create_task(request)
|
||||
|
||||
result = await self.manager.execute_task(task.id)
|
||||
|
||||
assert result is not None
|
||||
updated_task = await self.manager.get_task(task.id)
|
||||
assert updated_task.status == TaskStatus.COMPLETED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_task_async(self):
|
||||
with patch("app.services.session_manager.opencode_client", self.mock_client):
|
||||
request = CreateSessionRequest(prompt="async task")
|
||||
task = await self.manager.create_task(request)
|
||||
|
||||
await self.manager.execute_task_async(task.id)
|
||||
|
||||
updated_task = await self.manager.get_task(task.id)
|
||||
assert updated_task.status == TaskStatus.RUNNING
|
||||
assert updated_task.session_id is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_abort_task(self):
|
||||
with patch("app.services.session_manager.opencode_client", self.mock_client):
|
||||
request = CreateSessionRequest(prompt="abortable task")
|
||||
task = await self.manager.create_task(request)
|
||||
await self.manager.execute_task_async(task.id)
|
||||
|
||||
result = await self.manager.abort_task(task.id)
|
||||
|
||||
assert result is True
|
||||
updated_task = await self.manager.get_task(task.id)
|
||||
assert updated_task.status == TaskStatus.ABORTED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_task_not_found(self):
|
||||
with patch("app.services.session_manager.opencode_client", self.mock_client):
|
||||
task = await self.manager.get_task("nonexistent")
|
||||
assert task is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks(self):
|
||||
with patch("app.services.session_manager.opencode_client", self.mock_client):
|
||||
request1 = CreateSessionRequest(prompt="task 1")
|
||||
request2 = CreateSessionRequest(prompt="task 2")
|
||||
|
||||
await self.manager.create_task(request1)
|
||||
await self.manager.create_task(request2)
|
||||
|
||||
tasks = await self.manager.list_tasks()
|
||||
assert len(tasks) == 2
|
||||
47
tests/test_storage.py
Normal file
47
tests/test_storage.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import pytest
|
||||
from app.services.storage import JSONStorage
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestJSONStorage:
|
||||
def test_save_and_load(self, temp_dir):
|
||||
storage = JSONStorage()
|
||||
storage.data_dir = temp_dir
|
||||
|
||||
data = {"name": "test", "value": 123}
|
||||
storage.save("test", data)
|
||||
|
||||
loaded = storage.load("test")
|
||||
assert loaded == data
|
||||
|
||||
def test_load_nonexistent(self, temp_dir):
|
||||
storage = JSONStorage()
|
||||
storage.data_dir = temp_dir
|
||||
|
||||
result = storage.load("nonexistent", default={"default": True})
|
||||
assert result == {"default": True}
|
||||
|
||||
def test_save_list(self, temp_dir):
|
||||
storage = JSONStorage()
|
||||
storage.data_dir = temp_dir
|
||||
|
||||
data = [1, 2, 3, "test"]
|
||||
storage.save("list", data)
|
||||
|
||||
loaded = storage.load("list")
|
||||
assert loaded == data
|
||||
|
||||
def test_save_nested_data(self, temp_dir):
|
||||
storage = JSONStorage()
|
||||
storage.data_dir = temp_dir
|
||||
|
||||
data = {
|
||||
"tasks": [
|
||||
{"id": "1", "status": "completed"},
|
||||
{"id": "2", "status": "running"}
|
||||
]
|
||||
}
|
||||
storage.save("nested", data)
|
||||
|
||||
loaded = storage.load("nested")
|
||||
assert loaded == data
|
||||
43
tests/test_websocket.py
Normal file
43
tests/test_websocket.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from app.services.websocket_manager import WebSocketManager
|
||||
|
||||
|
||||
class TestWebSocketManager:
|
||||
@pytest.fixture
|
||||
def ws_manager(self):
|
||||
return WebSocketManager()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket(self):
|
||||
ws = AsyncMock()
|
||||
ws.send_json = AsyncMock()
|
||||
return ws
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect(self, ws_manager, mock_websocket):
|
||||
await ws_manager.connect(mock_websocket)
|
||||
|
||||
mock_websocket.accept.assert_called_once()
|
||||
assert len(ws_manager._clients) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Requires async polling environment")
|
||||
async def test_start_poll_task(self, ws_manager):
|
||||
mock_client = AsyncMock()
|
||||
with patch("app.services.websocket_manager.opencode_client", mock_client):
|
||||
await ws_manager.start_poll_task("task-1", "session-1")
|
||||
|
||||
assert "task-1" in ws_manager._poll_tasks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="Requires async polling environment")
|
||||
async def test_stop_poll_task(self, ws_manager):
|
||||
mock_client = AsyncMock()
|
||||
with patch("app.services.websocket_manager.opencode_client", mock_client):
|
||||
await ws_manager.start_poll_task("task-1", "session-1")
|
||||
|
||||
await ws_manager.stop_poll_task("task-1")
|
||||
|
||||
assert "task-1" not in ws_manager._poll_tasks
|
||||
Reference in New Issue
Block a user