test: add unit tests for all 10 routers (92 total)
New test files: - test_conversations.py — list/create/get/update/delete/delete-all, admin enforcement - test_presets.py — list/create/update/delete, default preset protection - test_profile.py — get/update/default, length validation - test_models_router.py — list/ps/show/stats/search-status, connect errors - test_completions.py — API key auth, FIM passthrough, streaming/blocking, errors - test_search_route.py — explicit search flow, no results, stream errors - test_memories.py — edit/search/stats endpoints, validation, admin enforcement Update AGENTS.md with full test file coverage table and README.md
This commit is contained in:
24
AGENTS.md
24
AGENTS.md
@@ -14,6 +14,30 @@
|
||||
|
||||
All tests use `tmp_path` fixtures + monkeypatched `httpx.AsyncClient.stream`. No external services needed. Test factories reset `SESSIONS`, `PIN_ATTEMPTS`, `RATE_EVENTS` globals — be careful not to let test state leak. After the modular refactor, tests import directly from the correct modules (`db`, `security`, `config`, `search`, `rag`, `memory`, `routers.*`) — not from the old monolithic `app` namespace.
|
||||
|
||||
Every router has a dedicated test file:
|
||||
| File | Covers |
|
||||
|------|--------|
|
||||
| `test_auth_capabilities.py` | `auth.py` — guest/admin sessions, origin blocking, logout |
|
||||
| `test_chat_streaming_and_memory_paths.py` | `routers/chat.py` — streaming, auto-search, remember/forget |
|
||||
| `test_completions.py` | `routers/completions.py` — API key auth, FIM, streaming, blocking, errors |
|
||||
| `test_conversations.py` | `routers/conversations.py` — full CRUD, guest admin enforcement |
|
||||
| `test_memories.py` | `routers/memories.py` — edit, search, stats endpoints |
|
||||
| `test_models_router.py` | `routers/models.py` — models list, ps, show, stats, search/status |
|
||||
| `test_presets.py` | `routers/presets.py` — full CRUD, default preset protection |
|
||||
| `test_profile.py` | `routers/profile.py` — get, update, default, length validation |
|
||||
| `test_search_route.py` | `routers/search_route.py` — explicit search flow, no results, errors |
|
||||
| `test_search_url_sanitization.py` | `search.py` URL sanitizer |
|
||||
| `test_settings_allowlist.py` | `routers/settings.py` — allowlisted key enforcement |
|
||||
| `test_skills_framework.py` | `routers/skills.py` — list, toggle, unknown skill, prompt injection |
|
||||
| `test_ip_allowlist.py` | IP allowlist helper + middleware |
|
||||
| `test_rate_and_payload_guardrails.py` | Rate limits + payload size enforcement |
|
||||
| `test_error_envelopes.py` | Global exception handler + stream error incidents |
|
||||
|
||||
Modules that call `httpx.AsyncClient` (chat, completions, models, search_route)
|
||||
are mocked via `monkeypatch.setattr` on `AsyncClient.stream`, `.get`, or `.post`.
|
||||
CPU stats in `models.py` (`api/stats`) use real `psutil`; GPU stats are
|
||||
monkeypatched via `routers.models.get_gpu_stats`.
|
||||
|
||||
## Architecture
|
||||
|
||||
Refactored from single-file (`app.py`) into modules under project root:
|
||||
|
||||
@@ -19,6 +19,7 @@ Developer wiki: [docs/wiki/Home.md](docs/wiki/Home.md)
|
||||
- **All 8 test files fixed** — rewired imports after the modular refactor; all 26 tests pass
|
||||
- **Origin check extended to all API methods** — GET/HEAD/OPTIONS requests no longer bypass origin checking (was limited to POST/PUT/DELETE/PATCH)
|
||||
- **Missing headers now rejected** — `origin_allowed()` returns `False` when both `Origin` and `Referer` are absent, closing the CSRF read gap for script-initiated requests
|
||||
- **Full router test coverage** — 7 new test files added: `test_conversations.py`, `test_presets.py`, `test_profile.py`, `test_models_router.py`, `test_completions.py`, `test_search_route.py`, `test_memories.py`; all 10 routers now have dedicated unit tests (92 total, up from 26)
|
||||
|
||||
## Features
|
||||
|
||||
|
||||
222
tests/test_completions.py
Normal file
222
tests/test_completions.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import app
|
||||
import config
|
||||
import db
|
||||
import routers.completions
|
||||
from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
|
||||
|
||||
|
||||
def make_client(tmp_path: Path) -> TestClient:
|
||||
os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
|
||||
db.DB_PATH = tmp_path / "jarvischat-completions.db"
|
||||
SESSIONS.clear()
|
||||
PIN_ATTEMPTS.clear()
|
||||
RATE_EVENTS.clear()
|
||||
db.init_db()
|
||||
return TestClient(app.app, raise_server_exceptions=False)
|
||||
|
||||
|
||||
TEST_API_KEY = "test-sk-jarvischat-completions"
|
||||
|
||||
|
||||
def _auth_headers(extra: dict = None) -> dict:
|
||||
h = {"Authorization": f"Bearer {TEST_API_KEY}", "Content-Type": "application/json", "Origin": "http://testserver"}
|
||||
if extra:
|
||||
h.update(extra)
|
||||
return h
|
||||
|
||||
|
||||
class _MockStreamResponse:
|
||||
def __init__(self, lines: list[str]):
|
||||
self._lines = lines
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def aiter_lines(self):
|
||||
for line in self._lines:
|
||||
yield line
|
||||
|
||||
|
||||
class _MockAsyncPostResponse:
|
||||
def __init__(self, status_code=200, json_data=None):
|
||||
self.status_code = status_code
|
||||
self._json_data = json_data or {}
|
||||
|
||||
def json(self):
|
||||
return self._json_data
|
||||
|
||||
|
||||
def _stream_json_lines(events: list[dict]) -> list[str]:
|
||||
return [json.dumps(event) for event in events]
|
||||
|
||||
|
||||
def test_completions_missing_api_key(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hi"}]},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
def test_completions_invalid_api_key(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hi"}]},
|
||||
headers={"Authorization": "Bearer wrong-key", "Origin": "http://testserver"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
def test_completions_no_messages(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post("/v1/chat/completions", json={}, headers=_auth_headers())
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_completions_empty_messages(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post("/v1/chat/completions", json={"messages": []}, headers=_auth_headers())
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_completions_no_user_message(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "assistant", "content": "hello"}]},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_completions_streaming(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
|
||||
events = _stream_json_lines([
|
||||
{"choices": [{"delta": {"content": "Hello"}, "logprobs": None}]},
|
||||
{"choices": [{"delta": {"content": " world"}, "logprobs": None}]},
|
||||
{"choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"tokens_per_second": 15.0}},
|
||||
])
|
||||
|
||||
call_count = 0
|
||||
|
||||
def stream_stub(self, method, url, json=None, timeout=None):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return _MockStreamResponse(events)
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)
|
||||
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.text
|
||||
assert "data: [DONE]" in body
|
||||
assert "Hello" in body or "world" in body
|
||||
assert "chatcmpl-" in body
|
||||
|
||||
|
||||
def test_completions_blocking(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
|
||||
events = _stream_json_lines([
|
||||
{"choices": [{"delta": {"content": "Hello world"}, "logprobs": None}]},
|
||||
{"choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {}},
|
||||
])
|
||||
|
||||
def stream_stub(self, method, url, json=None, timeout=None):
|
||||
return _MockStreamResponse(events)
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)
|
||||
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hi"}], "stream": False},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["object"] == "chat.completion"
|
||||
assert data["choices"][0]["message"]["content"] == "Hello world"
|
||||
|
||||
|
||||
def test_completions_fim_passthrough(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
|
||||
fim_data = {"prompt": "def foo():\n ", "suffix": "\n return x", "model": "llama3.1:latest"}
|
||||
|
||||
async def mock_post(self, url, json=None, timeout=None):
|
||||
return _MockAsyncPostResponse(json_data={"choices": [{"text": "pass"}], "usage": {}})
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)
|
||||
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post("/v1/chat/completions", json=fim_data, headers=_auth_headers())
|
||||
assert resp.status_code == 200
|
||||
assert "choices" in resp.json()
|
||||
|
||||
|
||||
def test_completions_connect_error_stream(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
|
||||
|
||||
def broken_stream(self, method, url, json=None, timeout=None):
|
||||
raise httpx.ConnectError("Connection refused")
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "stream", broken_stream)
|
||||
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert "connection_error" in resp.text
|
||||
|
||||
|
||||
def test_completions_connect_error_blocking(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
|
||||
|
||||
def broken_stream(self, method, url, json=None, timeout=None):
|
||||
raise httpx.ConnectError("Connection refused")
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "stream", broken_stream)
|
||||
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hi"}], "stream": False},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 503
|
||||
|
||||
|
||||
def test_completions_fim_connect_error(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
|
||||
fim_data = {"prompt": "def foo():", "model": "llama3.1:latest"}
|
||||
|
||||
def broken_post(self, url, json=None, timeout=None):
|
||||
raise httpx.ConnectError("Connection refused")
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "post", broken_post)
|
||||
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post("/v1/chat/completions", json=fim_data, headers=_auth_headers())
|
||||
assert resp.status_code == 503
|
||||
153
tests/test_conversations.py
Normal file
153
tests/test_conversations.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import app
|
||||
import db
|
||||
from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
|
||||
|
||||
|
||||
def make_client(tmp_path: Path) -> TestClient:
|
||||
os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
|
||||
db.DB_PATH = tmp_path / "jarvischat-conversations.db"
|
||||
SESSIONS.clear()
|
||||
PIN_ATTEMPTS.clear()
|
||||
RATE_EVENTS.clear()
|
||||
db.init_db()
|
||||
return TestClient(app.app, raise_server_exceptions=False)
|
||||
|
||||
|
||||
def _admin_headers(client: TestClient) -> dict:
|
||||
login = client.post("/api/auth/login", json={"pin": "1234"}, headers={"Origin": "http://testserver"})
|
||||
sid = login.json()["session_id"]
|
||||
return {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||
|
||||
|
||||
def _guest_headers(client: TestClient) -> dict:
|
||||
sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
|
||||
return {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||
|
||||
|
||||
def test_list_conversations_empty(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.get("/api/conversations", headers=_guest_headers(client))
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
def test_create_and_list_conversation(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
|
||||
create = client.post("/api/conversations", json={"title": "Test Chat", "model": "llama3.1:latest"}, headers=headers)
|
||||
assert create.status_code == 200
|
||||
data = create.json()
|
||||
assert data["title"] == "Test Chat"
|
||||
assert data["model"] == "llama3.1:latest"
|
||||
|
||||
list_resp = client.get("/api/conversations", headers=headers)
|
||||
assert list_resp.status_code == 200
|
||||
convs = list_resp.json()
|
||||
assert len(convs) == 1
|
||||
assert convs[0]["title"] == "Test Chat"
|
||||
|
||||
|
||||
def test_get_conversation_returns_messages(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
create = client.post("/api/conversations", json={"title": "My Chat"}, headers=headers)
|
||||
conv_id = create.json()["id"]
|
||||
|
||||
resp = client.get(f"/api/conversations/{conv_id}", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["conversation"]["id"] == conv_id
|
||||
assert data["messages"] == []
|
||||
|
||||
|
||||
def test_get_conversation_not_found(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.get("/api/conversations/nope", headers=_guest_headers(client))
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_update_conversation_title(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
create = client.post("/api/conversations", json={"title": "Old"}, headers=headers)
|
||||
conv_id = create.json()["id"]
|
||||
|
||||
update = client.put(f"/api/conversations/{conv_id}", json={"title": "New Title"}, headers=headers)
|
||||
assert update.status_code == 200
|
||||
|
||||
get = client.get(f"/api/conversations/{conv_id}", headers=headers)
|
||||
assert get.json()["conversation"]["title"] == "New Title"
|
||||
|
||||
|
||||
def test_update_conversation_model(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
create = client.post("/api/conversations", json={"title": "Test"}, headers=headers)
|
||||
conv_id = create.json()["id"]
|
||||
|
||||
update = client.put(f"/api/conversations/{conv_id}", json={"model": "qwen2:latest"}, headers=headers)
|
||||
assert update.status_code == 200
|
||||
|
||||
get = client.get(f"/api/conversations/{conv_id}", headers=headers)
|
||||
assert get.json()["conversation"]["model"] == "qwen2:latest"
|
||||
|
||||
|
||||
def test_delete_conversation(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
create = client.post("/api/conversations", json={"title": "Delete Me"}, headers=headers)
|
||||
conv_id = create.json()["id"]
|
||||
|
||||
delete = client.delete(f"/api/conversations/{conv_id}", headers=headers)
|
||||
assert delete.status_code == 200
|
||||
|
||||
get = client.get(f"/api/conversations/{conv_id}", headers=_guest_headers(client))
|
||||
assert get.status_code == 404
|
||||
|
||||
|
||||
def test_delete_all_conversations(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
client.post("/api/conversations", json={"title": "One"}, headers=headers)
|
||||
client.post("/api/conversations", json={"title": "Two"}, headers=headers)
|
||||
|
||||
delete_all = client.delete("/api/conversations", headers=headers)
|
||||
assert delete_all.status_code == 200
|
||||
|
||||
list_resp = client.get("/api/conversations", headers=_guest_headers(client))
|
||||
assert list_resp.json() == []
|
||||
|
||||
|
||||
def test_guest_cannot_create_conversation(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post("/api/conversations", json={"title": "test"}, headers=_guest_headers(client))
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
def test_guest_cannot_update_conversation(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
create = client.post("/api/conversations", json={"title": "Test"}, headers=headers)
|
||||
conv_id = create.json()["id"]
|
||||
|
||||
guest_headers = _guest_headers(client)
|
||||
resp = client.put(f"/api/conversations/{conv_id}", json={"title": "hack"}, headers=guest_headers)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
def test_guest_cannot_delete_conversation(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.delete("/api/conversations/some-id", headers=_guest_headers(client))
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
def test_guest_cannot_delete_all(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.delete("/api/conversations", headers=_guest_headers(client))
|
||||
assert resp.status_code == 403
|
||||
161
tests/test_memories.py
Normal file
161
tests/test_memories.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import app
|
||||
import config
|
||||
import db
|
||||
from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
|
||||
|
||||
|
||||
def make_client(tmp_path: Path) -> TestClient:
|
||||
os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
|
||||
db.DB_PATH = tmp_path / "jarvischat-memories.db"
|
||||
SESSIONS.clear()
|
||||
PIN_ATTEMPTS.clear()
|
||||
RATE_EVENTS.clear()
|
||||
db.init_db()
|
||||
return TestClient(app.app, raise_server_exceptions=False)
|
||||
|
||||
|
||||
def _admin_headers(client: TestClient) -> dict:
|
||||
login = client.post("/api/auth/login", json={"pin": "1234"}, headers={"Origin": "http://testserver"})
|
||||
sid = login.json()["session_id"]
|
||||
return {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||
|
||||
|
||||
def _guest_headers(client: TestClient) -> dict:
|
||||
sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
|
||||
return {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||
|
||||
|
||||
def _create_memory(client: TestClient, headers: dict, fact: str = "test fact", topic: str = "general") -> int:
|
||||
resp = client.post("/api/memories", json={"fact": fact, "topic": topic}, headers=headers)
|
||||
assert resp.status_code == 200
|
||||
return resp.json()["rowid"]
|
||||
|
||||
|
||||
def test_list_memories_empty(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.get("/api/memories", headers=_guest_headers(client))
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["count"] == 0
|
||||
|
||||
|
||||
def test_list_memories_by_topic(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
_create_memory(client, headers, "I like Python", "preference")
|
||||
_create_memory(client, headers, "Building a game", "project")
|
||||
|
||||
general = client.get("/api/memories?topic=preference", headers=_guest_headers(client))
|
||||
assert general.json()["count"] == 1
|
||||
assert general.json()["memories"][0]["topic"] == "preference"
|
||||
|
||||
|
||||
def test_create_memory_requires_fact(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post("/api/memories", json={"fact": ""}, headers=_admin_headers(client))
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_create_memory_too_long(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
long_fact = "x" * (config.MAX_MEMORY_FACT_CHARS + 1)
|
||||
resp = client.post("/api/memories", json={"fact": long_fact}, headers=_admin_headers(client))
|
||||
assert resp.status_code == 413
|
||||
|
||||
|
||||
def test_edit_memory(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
rowid = _create_memory(client, headers, "original fact")
|
||||
|
||||
edit = client.put(f"/api/memories/{rowid}", json={"fact": "updated fact"}, headers=headers)
|
||||
assert edit.status_code == 200
|
||||
|
||||
memories = client.get("/api/memories", headers=_guest_headers(client)).json()
|
||||
assert any(m["fact"] == "updated fact" for m in memories["memories"])
|
||||
|
||||
|
||||
def test_edit_memory_not_found(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.put("/api/memories/99999", json={"fact": "nope"}, headers=_admin_headers(client))
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_edit_memory_empty_fact(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
rowid = _create_memory(client, headers, "some fact")
|
||||
resp = client.put(f"/api/memories/{rowid}", json={"fact": ""}, headers=headers)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_edit_memory_too_long(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
rowid = _create_memory(client, headers, "some fact")
|
||||
long_fact = "x" * (config.MAX_MEMORY_FACT_CHARS + 1)
|
||||
resp = client.put(f"/api/memories/{rowid}", json={"fact": long_fact}, headers=headers)
|
||||
assert resp.status_code == 413
|
||||
|
||||
|
||||
def test_delete_memory_not_found(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.delete("/api/memories/99999", headers=_admin_headers(client))
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_search_memories(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
_create_memory(client, headers, "my favorite color is blue", "preference")
|
||||
_create_memory(client, headers, "running nginx on port 443", "infrastructure")
|
||||
|
||||
resp = client.get("/api/memories/search?q=nginx&limit=5", headers=_guest_headers(client))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["count"] >= 1
|
||||
assert any("nginx" in r["fact"] for r in data["results"])
|
||||
|
||||
|
||||
def test_search_memories_no_results(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.get("/api/memories/search?q=xyznonexistent&limit=5", headers=_guest_headers(client))
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["count"] == 0
|
||||
|
||||
|
||||
def test_memory_stats(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
_create_memory(client, headers, "like rust", "preference")
|
||||
_create_memory(client, headers, "like python", "preference")
|
||||
_create_memory(client, headers, "project game", "project")
|
||||
|
||||
resp = client.get("/api/memories/stats", headers=_guest_headers(client))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total"] == 3
|
||||
assert data["by_topic"]["preference"] == 2
|
||||
assert data["by_topic"]["project"] == 1
|
||||
|
||||
|
||||
def test_guest_cannot_create_memory(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post("/api/memories", json={"fact": "hack"}, headers=_guest_headers(client))
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
def test_guest_cannot_edit_memory(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.put("/api/memories/1", json={"fact": "hack"}, headers=_guest_headers(client))
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
def test_guest_cannot_delete_memory(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.delete("/api/memories/1", headers=_guest_headers(client))
|
||||
assert resp.status_code == 403
|
||||
138
tests/test_models_router.py
Normal file
138
tests/test_models_router.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import app
|
||||
import db
|
||||
import routers.models
|
||||
from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
|
||||
|
||||
|
||||
def make_client(tmp_path: Path) -> TestClient:
|
||||
os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
|
||||
db.DB_PATH = tmp_path / "jarvischat-models.db"
|
||||
SESSIONS.clear()
|
||||
PIN_ATTEMPTS.clear()
|
||||
RATE_EVENTS.clear()
|
||||
db.init_db()
|
||||
return TestClient(app.app, raise_server_exceptions=False)
|
||||
|
||||
|
||||
def _guest_headers(client: TestClient) -> dict:
|
||||
sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
|
||||
return {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||
|
||||
|
||||
class _MockAsyncResponse:
|
||||
"""Mock for httpx.AsyncClient.get/post that returns a JSON response."""
|
||||
def __init__(self, status_code=200, json_data=None):
|
||||
self.status_code = status_code
|
||||
self._json_data = json_data or {}
|
||||
|
||||
def json(self):
|
||||
return self._json_data
|
||||
|
||||
|
||||
async def _mock_get_models(*args, **kwargs):
|
||||
return _MockAsyncResponse(json_data={
|
||||
"data": [{"id": "llama3.1:latest"}, {"id": "qwen2:latest"}]
|
||||
})
|
||||
|
||||
|
||||
async def _mock_get_empty_models(*args, **kwargs):
|
||||
return _MockAsyncResponse(json_data={"data": []})
|
||||
|
||||
|
||||
async def _mock_connect_error(*args, **kwargs):
|
||||
raise httpx.ConnectError("Connection refused")
|
||||
|
||||
|
||||
async def _mock_show_model(*args, **kwargs):
|
||||
return _MockAsyncResponse(json_data={
|
||||
"modelfile": "FROM llama3.1", "parameters": {}
|
||||
})
|
||||
|
||||
|
||||
async def _mock_search_available(*args, **kwargs):
|
||||
return _MockAsyncResponse(status_code=200)
|
||||
|
||||
|
||||
async def _mock_search_unavailable(*args, **kwargs):
|
||||
raise httpx.ConnectError("refused")
|
||||
|
||||
|
||||
def test_list_models(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(httpx.AsyncClient, "get", _mock_get_models)
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.get("/api/models", headers=_guest_headers(client))
|
||||
assert resp.status_code == 200
|
||||
models = resp.json()["models"]
|
||||
assert len(models) == 2
|
||||
assert models[0]["name"] == "llama3.1:latest"
|
||||
|
||||
|
||||
def test_list_models_connect_error(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(httpx.AsyncClient, "get", _mock_connect_error)
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.get("/api/models", headers=_guest_headers(client))
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
def test_running_models(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(httpx.AsyncClient, "get", _mock_get_models)
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.get("/api/ps", headers=_guest_headers(client))
|
||||
assert resp.status_code == 200
|
||||
assert "data" in resp.json()
|
||||
|
||||
|
||||
def test_running_models_connect_error(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(httpx.AsyncClient, "get", _mock_connect_error)
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.get("/api/ps", headers=_guest_headers(client))
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
def test_show_model(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(httpx.AsyncClient, "post", _mock_show_model)
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post("/api/show", json={"model": "llama3.1:latest"}, headers=_guest_headers(client))
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["modelfile"] == "FROM llama3.1"
|
||||
|
||||
|
||||
def test_show_model_connect_error(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(httpx.AsyncClient, "post", _mock_connect_error)
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post("/api/show", json={"model": "llama3.1:latest"}, headers=_guest_headers(client))
|
||||
assert resp.status_code == 502
|
||||
|
||||
|
||||
def test_system_stats(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(routers.models, "get_gpu_stats", lambda: {"gpu_percent": 15, "vram_percent": 30, "available": True})
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.get("/api/stats", headers=_guest_headers(client))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "cpu_percent" in data
|
||||
assert "memory_percent" in data
|
||||
assert data["gpu_percent"] == 15
|
||||
assert data["gpu_available"] is True
|
||||
|
||||
|
||||
def test_search_status_available(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(httpx.AsyncClient, "get", _mock_search_available)
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.get("/api/search/status", headers=_guest_headers(client))
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["available"] is True
|
||||
|
||||
|
||||
def test_search_status_unavailable(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setattr(httpx.AsyncClient, "get", _mock_search_unavailable)
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.get("/api/search/status", headers=_guest_headers(client))
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["available"] is False
|
||||
128
tests/test_presets.py
Normal file
128
tests/test_presets.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import app
|
||||
import db
|
||||
from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
|
||||
|
||||
|
||||
def make_client(tmp_path: Path) -> TestClient:
|
||||
os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
|
||||
db.DB_PATH = tmp_path / "jarvischat-presets.db"
|
||||
SESSIONS.clear()
|
||||
PIN_ATTEMPTS.clear()
|
||||
RATE_EVENTS.clear()
|
||||
db.init_db()
|
||||
return TestClient(app.app, raise_server_exceptions=False)
|
||||
|
||||
|
||||
def _admin_headers(client: TestClient) -> dict:
|
||||
login = client.post("/api/auth/login", json={"pin": "1234"}, headers={"Origin": "http://testserver"})
|
||||
sid = login.json()["session_id"]
|
||||
return {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||
|
||||
|
||||
def _guest_headers(client: TestClient) -> dict:
|
||||
sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
|
||||
return {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||
|
||||
|
||||
def test_list_presets_returns_defaults(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.get("/api/presets", headers=_guest_headers(client))
|
||||
assert resp.status_code == 200
|
||||
presets = resp.json()
|
||||
assert len(presets) >= 3
|
||||
names = [p["name"] for p in presets]
|
||||
assert "Coding Companion" in names
|
||||
|
||||
|
||||
def test_create_preset(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
resp = client.post("/api/presets", json={"name": "My Preset", "prompt": "You are helpful."}, headers=headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["name"] == "My Preset"
|
||||
assert data["prompt"] == "You are helpful."
|
||||
|
||||
presets = client.get("/api/presets", headers=_guest_headers(client)).json()
|
||||
assert any(p["name"] == "My Preset" for p in presets)
|
||||
|
||||
|
||||
def test_create_preset_requires_name_and_prompt(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
resp = client.post("/api/presets", json={"name": "", "prompt": ""}, headers=headers)
|
||||
assert resp.status_code == 400
|
||||
|
||||
resp = client.post("/api/presets", json={"name": "Only Name"}, headers=headers)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_update_preset(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
create = client.post("/api/presets", json={"name": "Old", "prompt": "Old prompt."}, headers=headers)
|
||||
preset_id = create.json()["id"]
|
||||
|
||||
update = client.put(f"/api/presets/{preset_id}", json={"name": "New", "prompt": "New prompt."}, headers=headers)
|
||||
assert update.status_code == 200
|
||||
|
||||
presets = client.get("/api/presets", headers=_guest_headers(client)).json()
|
||||
updated = next(p for p in presets if p["id"] == preset_id)
|
||||
assert updated["name"] == "New"
|
||||
assert updated["prompt"] == "New prompt."
|
||||
|
||||
|
||||
def test_update_preset_requires_fields(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
resp = client.put("/api/presets/nope", json={"name": "", "prompt": ""}, headers=headers)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_delete_preset(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
create = client.post("/api/presets", json={"name": "Temp", "prompt": "Temp."}, headers=headers)
|
||||
preset_id = create.json()["id"]
|
||||
|
||||
delete = client.delete(f"/api/presets/{preset_id}", headers=headers)
|
||||
assert delete.status_code == 200
|
||||
|
||||
presets = client.get("/api/presets", headers=_guest_headers(client)).json()
|
||||
assert not any(p["id"] == preset_id for p in presets)
|
||||
|
||||
|
||||
def test_delete_default_preset_is_noop(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
presets_before = client.get("/api/presets", headers=_guest_headers(client)).json()
|
||||
default = next(p for p in presets_before if p["is_default"])
|
||||
|
||||
delete = client.delete(f"/api/presets/{default['id']}", headers=headers)
|
||||
assert delete.status_code == 200
|
||||
|
||||
presets_after = client.get("/api/presets", headers=_guest_headers(client)).json()
|
||||
assert any(p["id"] == default["id"] for p in presets_after)
|
||||
|
||||
|
||||
def test_guest_cannot_create_preset(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.post("/api/presets", json={"name": "Hack", "prompt": "Hack"}, headers=_guest_headers(client))
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
def test_guest_cannot_update_preset(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.put("/api/presets/some-id", json={"name": "Hack", "prompt": "Hack"}, headers=_guest_headers(client))
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
def test_guest_cannot_delete_preset(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.delete("/api/presets/some-id", headers=_guest_headers(client))
|
||||
assert resp.status_code == 403
|
||||
72
tests/test_profile.py
Normal file
72
tests/test_profile.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import app
|
||||
import config
|
||||
import db
|
||||
from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
|
||||
|
||||
|
||||
def make_client(tmp_path: Path) -> TestClient:
|
||||
os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
|
||||
db.DB_PATH = tmp_path / "jarvischat-profile.db"
|
||||
SESSIONS.clear()
|
||||
PIN_ATTEMPTS.clear()
|
||||
RATE_EVENTS.clear()
|
||||
db.init_db()
|
||||
return TestClient(app.app, raise_server_exceptions=False)
|
||||
|
||||
|
||||
def _admin_headers(client: TestClient) -> dict:
|
||||
login = client.post("/api/auth/login", json={"pin": "1234"}, headers={"Origin": "http://testserver"})
|
||||
sid = login.json()["session_id"]
|
||||
return {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||
|
||||
|
||||
def _guest_headers(client: TestClient) -> dict:
|
||||
sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
|
||||
return {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||
|
||||
|
||||
def test_get_profile_returns_content(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.get("/api/profile", headers=_guest_headers(client))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "content" in data
|
||||
assert "updated_at" in data
|
||||
assert len(data["content"]) > 0
|
||||
|
||||
|
||||
def test_get_default_profile(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.get("/api/profile/default", headers=_guest_headers(client))
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["content"] == config.DEFAULT_PROFILE
|
||||
|
||||
|
||||
def test_update_profile(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
resp = client.put("/api/profile", json={"content": "Custom profile text."}, headers=headers)
|
||||
assert resp.status_code == 200
|
||||
assert "updated_at" in resp.json()
|
||||
|
||||
get_resp = client.get("/api/profile", headers=_guest_headers(client))
|
||||
assert get_resp.json()["content"] == "Custom profile text."
|
||||
|
||||
|
||||
def test_update_profile_too_long(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _admin_headers(client)
|
||||
long_content = "x" * (config.MAX_PROFILE_CHARS + 1)
|
||||
resp = client.put("/api/profile", json={"content": long_content}, headers=headers)
|
||||
assert resp.status_code == 413
|
||||
|
||||
|
||||
def test_guest_cannot_update_profile(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
resp = client.put("/api/profile", json={"content": "hack"}, headers=_guest_headers(client))
|
||||
assert resp.status_code == 403
|
||||
186
tests/test_search_route.py
Normal file
186
tests/test_search_route.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import app
|
||||
import config
|
||||
import db
|
||||
import routers.search_route
|
||||
from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
|
||||
|
||||
|
||||
def make_client(tmp_path: Path) -> TestClient:
|
||||
os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
|
||||
db.DB_PATH = tmp_path / "jarvischat-search-route.db"
|
||||
SESSIONS.clear()
|
||||
PIN_ATTEMPTS.clear()
|
||||
RATE_EVENTS.clear()
|
||||
db.init_db()
|
||||
return TestClient(app.app, raise_server_exceptions=False)
|
||||
|
||||
|
||||
def _guest_headers(client: TestClient) -> dict:
|
||||
sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
|
||||
return {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||
|
||||
|
||||
def parse_sse_payloads(body: str) -> list[dict]:
|
||||
payloads: list[dict] = []
|
||||
for chunk in body.split("\n\n"):
|
||||
chunk = chunk.strip()
|
||||
if not chunk.startswith("data: "):
|
||||
continue
|
||||
raw = chunk[len("data: ") :]
|
||||
payloads.append(json.loads(raw))
|
||||
return payloads
|
||||
|
||||
|
||||
class _MockStreamResponse:
|
||||
def __init__(self, lines: list[str]):
|
||||
self._lines = lines
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def aiter_lines(self):
|
||||
for line in self._lines:
|
||||
yield line
|
||||
|
||||
|
||||
def _stream_json_lines(events: list[dict]) -> list[str]:
|
||||
return [json.dumps(event) for event in events]
|
||||
|
||||
|
||||
def test_explicit_search_with_results(tmp_path: Path, monkeypatch):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _guest_headers(client)
|
||||
|
||||
async def search_stub(query: str, max_results: int = 5):
|
||||
return [
|
||||
{"title": "Result One", "url": "https://example.com/1", "content": "First result content."},
|
||||
{"title": "Result Two", "url": "https://example.com/2", "content": "Second result content."},
|
||||
]
|
||||
|
||||
monkeypatch.setattr(routers.search_route, "query_searxng", search_stub)
|
||||
|
||||
events = _stream_json_lines([
|
||||
{"choices": [{"delta": {"content": "Here's what I found"}, "logprobs": None}]},
|
||||
{"choices": [{"delta": {"content": " about your query."}, "logprobs": None}]},
|
||||
{"choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {}},
|
||||
])
|
||||
|
||||
def stream_stub(self, method, url, json=None, timeout=None):
|
||||
return _MockStreamResponse(events)
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)
|
||||
|
||||
resp = client.post(
|
||||
"/api/search",
|
||||
json={"query": "current events", "model": config.DEFAULT_MODEL},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
payloads = parse_sse_payloads(resp.text)
|
||||
|
||||
assert any(p.get("searching") is True for p in payloads)
|
||||
assert any("search_results" in p for p in payloads)
|
||||
token_text = "".join(p.get("token", "") for p in payloads if "token" in p)
|
||||
assert "found" in token_text.lower()
|
||||
assert any(p.get("done") and p.get("searched") for p in payloads)
|
||||
|
||||
|
||||
def test_explicit_search_no_results(tmp_path: Path, monkeypatch):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _guest_headers(client)
|
||||
|
||||
async def empty_search(query: str, max_results: int = 5):
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(routers.search_route, "query_searxng", empty_search)
|
||||
|
||||
resp = client.post(
|
||||
"/api/search",
|
||||
json={"query": "nothingness", "model": config.DEFAULT_MODEL},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
payloads = parse_sse_payloads(resp.text)
|
||||
|
||||
assert any("No search results found" in p.get("token", "") for p in payloads)
|
||||
assert any(p.get("done") for p in payloads)
|
||||
assert not any("search_results" in p for p in payloads)
|
||||
|
||||
|
||||
def test_explicit_search_new_conversation_created(tmp_path: Path, monkeypatch):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _guest_headers(client)
|
||||
|
||||
async def search_stub(query: str, max_results: int = 5):
|
||||
return [{"title": "T", "url": "https://ex.com", "content": "Content."}]
|
||||
|
||||
monkeypatch.setattr(routers.search_route, "query_searxng", search_stub)
|
||||
|
||||
events = _stream_json_lines([
|
||||
{"choices": [{"delta": {"content": "Answer."}, "logprobs": None}]},
|
||||
{"choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {}},
|
||||
])
|
||||
|
||||
def stream_stub(self, method, url, json=None, timeout=None):
|
||||
return _MockStreamResponse(events)
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)
|
||||
|
||||
resp = client.post(
|
||||
"/api/search",
|
||||
json={"query": "tell me something", "model": config.DEFAULT_MODEL},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
payloads = parse_sse_payloads(resp.text)
|
||||
|
||||
conv_id = None
|
||||
for p in payloads:
|
||||
if "conversation_id" in p:
|
||||
conv_id = p["conversation_id"]
|
||||
break
|
||||
assert conv_id is not None
|
||||
|
||||
conv_resp = client.get(f"/api/conversations/{conv_id}", headers=_guest_headers(client))
|
||||
assert conv_resp.status_code == 200
|
||||
data = conv_resp.json()
|
||||
assert len(data["messages"]) >= 2
|
||||
|
||||
|
||||
def test_explicit_search_stream_error(tmp_path: Path, monkeypatch):
|
||||
with make_client(tmp_path) as client:
|
||||
headers = _guest_headers(client)
|
||||
|
||||
async def search_stub(query: str, max_results: int = 5):
|
||||
return [{"title": "T", "url": "https://ex.com", "content": "Content."}]
|
||||
|
||||
monkeypatch.setattr(routers.search_route, "query_searxng", search_stub)
|
||||
|
||||
def broken_stream(self, method, url, json=None, timeout=None):
|
||||
class BrokenCtx:
|
||||
async def __aenter__(self):
|
||||
raise RuntimeError("summarization failed")
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
return BrokenCtx()
|
||||
|
||||
monkeypatch.setattr(httpx.AsyncClient, "stream", broken_stream)
|
||||
|
||||
resp = client.post(
|
||||
"/api/search",
|
||||
json={"query": "breaking news", "model": config.DEFAULT_MODEL},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert "error_key" in resp.text
|
||||
assert "INC-" in resp.text
|
||||
Reference in New Issue
Block a user