diff --git a/AGENTS.md b/AGENTS.md index 18b2f00..3dc5f4b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -14,6 +14,30 @@ All tests use `tmp_path` fixtures + monkeypatched `httpx.AsyncClient.stream`. No external services needed. Test factories reset `SESSIONS`, `PIN_ATTEMPTS`, `RATE_EVENTS` globals — be careful not to let test state leak. After the modular refactor, tests import directly from the correct modules (`db`, `security`, `config`, `search`, `rag`, `memory`, `routers.*`) — not from the old monolithic `app` namespace. +Every router has a dedicated test file: +| File | Covers | +|------|--------| +| `test_auth_capabilities.py` | `auth.py` — guest/admin sessions, origin blocking, logout | +| `test_chat_streaming_and_memory_paths.py` | `routers/chat.py` — streaming, auto-search, remember/forget | +| `test_completions.py` | `routers/completions.py` — API key auth, FIM, streaming, blocking, errors | +| `test_conversations.py` | `routers/conversations.py` — full CRUD, guest admin enforcement | +| `test_memories.py` | `routers/memories.py` — edit, search, stats endpoints | +| `test_models_router.py` | `routers/models.py` — models list, ps, show, stats, search/status | +| `test_presets.py` | `routers/presets.py` — full CRUD, default preset protection | +| `test_profile.py` | `routers/profile.py` — get, update, default, length validation | +| `test_search_route.py` | `routers/search_route.py` — explicit search flow, no results, errors | +| `test_search_url_sanitization.py` | `search.py` URL sanitizer | +| `test_settings_allowlist.py` | `routers/settings.py` — allowlisted key enforcement | +| `test_skills_framework.py` | `routers/skills.py` — list, toggle, unknown skill, prompt injection | +| `test_ip_allowlist.py` | IP allowlist helper + middleware | +| `test_rate_and_payload_guardrails.py` | Rate limits + payload size enforcement | +| `test_error_envelopes.py` | Global exception handler + stream error incidents | + +Modules that call `httpx.AsyncClient` (chat, completions, models, search_route) +are mocked via `monkeypatch.setattr` on `AsyncClient.stream`, `.get`, or `.post`. +CPU stats in `models.py` (`api/stats`) use real `psutil`; GPU stats are +monkeypatched via `routers.models.get_gpu_stats`. + ## Architecture Refactored from single-file (`app.py`) into modules under project root: diff --git a/README.md b/README.md index efe07da..cbff5cc 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Developer wiki: [docs/wiki/Home.md](docs/wiki/Home.md) - **All 8 test files fixed** — rewired imports after the modular refactor; all 26 tests pass - **Origin check extended to all API methods** — GET/HEAD/OPTIONS requests no longer bypass origin checking (was limited to POST/PUT/DELETE/PATCH) - **Missing headers now rejected** — `origin_allowed()` returns `False` when both `Origin` and `Referer` are absent, closing the CSRF read gap for script-initiated requests +- **Full router test coverage** — 7 new test files added: `test_conversations.py`, `test_presets.py`, `test_profile.py`, `test_models_router.py`, `test_completions.py`, `test_search_route.py`, `test_memories.py`; all 10 routers now have dedicated unit tests (92 total, up from 26) ## Features diff --git a/tests/test_completions.py b/tests/test_completions.py new file mode 100644 index 0000000..af7b319 --- /dev/null +++ b/tests/test_completions.py @@ -0,0 +1,222 @@ +import json +import os +from pathlib import Path + +import httpx +from fastapi.testclient import TestClient + +import app +import config +import db +import routers.completions +from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS + + +def make_client(tmp_path: Path) -> TestClient: + os.environ["JARVISCHAT_ADMIN_PIN"] = "1234" + db.DB_PATH = tmp_path / "jarvischat-completions.db" + SESSIONS.clear() + PIN_ATTEMPTS.clear() + RATE_EVENTS.clear() + db.init_db() + return TestClient(app.app, raise_server_exceptions=False) + + +TEST_API_KEY = "test-sk-jarvischat-completions" + + +def _auth_headers(extra: dict = None) -> dict: + h = {"Authorization": f"Bearer {TEST_API_KEY}", "Content-Type": "application/json", "Origin": "http://testserver"} + if extra: + h.update(extra) + return h + + +class _MockStreamResponse: + def __init__(self, lines: list[str]): + self._lines = lines + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def aiter_lines(self): + for line in self._lines: + yield line + + +class _MockAsyncPostResponse: + def __init__(self, status_code=200, json_data=None): + self.status_code = status_code + self._json_data = json_data or {} + + def json(self): + return self._json_data + + +def _stream_json_lines(events: list[dict]) -> list[str]: + return [json.dumps(event) for event in events] + + +def test_completions_missing_api_key(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hi"}]}, + ) + assert resp.status_code == 401 + + +def test_completions_invalid_api_key(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hi"}]}, + headers={"Authorization": "Bearer wrong-key", "Origin": "http://testserver"}, + ) + assert resp.status_code == 401 + + +def test_completions_no_messages(tmp_path: Path, monkeypatch): + monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY) + with make_client(tmp_path) as client: + resp = client.post("/v1/chat/completions", json={}, headers=_auth_headers()) + assert resp.status_code == 400 + + +def test_completions_empty_messages(tmp_path: Path, monkeypatch): + monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY) + with make_client(tmp_path) as client: + resp = client.post("/v1/chat/completions", json={"messages": []}, headers=_auth_headers()) + assert resp.status_code == 400 + + +def test_completions_no_user_message(tmp_path: Path, monkeypatch): + monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY) + with make_client(tmp_path) as client: + resp = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "assistant", "content": "hello"}]}, + headers=_auth_headers(), + ) + assert resp.status_code == 400 + + +def test_completions_streaming(tmp_path: Path, monkeypatch): + monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY) + events = _stream_json_lines([ + {"choices": [{"delta": {"content": "Hello"}, "logprobs": None}]}, + {"choices": [{"delta": {"content": " world"}, "logprobs": None}]}, + {"choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"tokens_per_second": 15.0}}, + ]) + + call_count = 0 + + def stream_stub(self, method, url, json=None, timeout=None): + nonlocal call_count + call_count += 1 + return _MockStreamResponse(events) + + monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub) + + with make_client(tmp_path) as client: + resp = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hi"}], "stream": True}, + headers=_auth_headers(), + ) + assert resp.status_code == 200 + body = resp.text + assert "data: [DONE]" in body + assert "Hello" in body or "world" in body + assert "chatcmpl-" in body + + +def test_completions_blocking(tmp_path: Path, monkeypatch): + monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY) + events = _stream_json_lines([ + {"choices": [{"delta": {"content": "Hello world"}, "logprobs": None}]}, + {"choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {}}, + ]) + + def stream_stub(self, method, url, json=None, timeout=None): + return _MockStreamResponse(events) + + monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub) + + with make_client(tmp_path) as client: + resp = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hi"}], "stream": False}, + headers=_auth_headers(), + ) + assert resp.status_code == 200 + data = resp.json() + assert data["object"] == "chat.completion" + assert data["choices"][0]["message"]["content"] == "Hello world" + + +def test_completions_fim_passthrough(tmp_path: Path, monkeypatch): + monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY) + fim_data = {"prompt": "def foo():\n ", "suffix": "\n return x", "model": "llama3.1:latest"} + + async def mock_post(self, url, json=None, timeout=None): + return _MockAsyncPostResponse(json_data={"choices": [{"text": "pass"}], "usage": {}}) + + monkeypatch.setattr(httpx.AsyncClient, "post", mock_post) + + with make_client(tmp_path) as client: + resp = client.post("/v1/chat/completions", json=fim_data, headers=_auth_headers()) + assert resp.status_code == 200 + assert "choices" in resp.json() + + +def test_completions_connect_error_stream(tmp_path: Path, monkeypatch): + monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY) + + def broken_stream(self, method, url, json=None, timeout=None): + raise httpx.ConnectError("Connection refused") + + monkeypatch.setattr(httpx.AsyncClient, "stream", broken_stream) + + with make_client(tmp_path) as client: + resp = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hi"}], "stream": True}, + headers=_auth_headers(), + ) + assert resp.status_code == 200 + assert "connection_error" in resp.text + + +def test_completions_connect_error_blocking(tmp_path: Path, monkeypatch): + monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY) + + def broken_stream(self, method, url, json=None, timeout=None): + raise httpx.ConnectError("Connection refused") + + monkeypatch.setattr(httpx.AsyncClient, "stream", broken_stream) + + with make_client(tmp_path) as client: + resp = client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hi"}], "stream": False}, + headers=_auth_headers(), + ) + assert resp.status_code == 503 + + +def test_completions_fim_connect_error(tmp_path: Path, monkeypatch): + monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY) + fim_data = {"prompt": "def foo():", "model": "llama3.1:latest"} + + def broken_post(self, url, json=None, timeout=None): + raise httpx.ConnectError("Connection refused") + + monkeypatch.setattr(httpx.AsyncClient, "post", broken_post) + + with make_client(tmp_path) as client: + resp = client.post("/v1/chat/completions", json=fim_data, headers=_auth_headers()) + assert resp.status_code == 503 diff --git a/tests/test_conversations.py b/tests/test_conversations.py new file mode 100644 index 0000000..a7de9e8 --- /dev/null +++ b/tests/test_conversations.py @@ -0,0 +1,153 @@ +import os +from pathlib import Path + +from fastapi.testclient import TestClient + +import app +import db +from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS + + +def make_client(tmp_path: Path) -> TestClient: + os.environ["JARVISCHAT_ADMIN_PIN"] = "1234" + db.DB_PATH = tmp_path / "jarvischat-conversations.db" + SESSIONS.clear() + PIN_ATTEMPTS.clear() + RATE_EVENTS.clear() + db.init_db() + return TestClient(app.app, raise_server_exceptions=False) + + +def _admin_headers(client: TestClient) -> dict: + login = client.post("/api/auth/login", json={"pin": "1234"}, headers={"Origin": "http://testserver"}) + sid = login.json()["session_id"] + return {"X-Session-ID": sid, "Origin": "http://testserver"} + + +def _guest_headers(client: TestClient) -> dict: + sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"] + return {"X-Session-ID": sid, "Origin": "http://testserver"} + + +def test_list_conversations_empty(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.get("/api/conversations", headers=_guest_headers(client)) + assert resp.status_code == 200 + assert resp.json() == [] + + +def test_create_and_list_conversation(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + + create = client.post("/api/conversations", json={"title": "Test Chat", "model": "llama3.1:latest"}, headers=headers) + assert create.status_code == 200 + data = create.json() + assert data["title"] == "Test Chat" + assert data["model"] == "llama3.1:latest" + + list_resp = client.get("/api/conversations", headers=headers) + assert list_resp.status_code == 200 + convs = list_resp.json() + assert len(convs) == 1 + assert convs[0]["title"] == "Test Chat" + + +def test_get_conversation_returns_messages(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + create = client.post("/api/conversations", json={"title": "My Chat"}, headers=headers) + conv_id = create.json()["id"] + + resp = client.get(f"/api/conversations/{conv_id}", headers=headers) + assert resp.status_code == 200 + data = resp.json() + assert data["conversation"]["id"] == conv_id + assert data["messages"] == [] + + +def test_get_conversation_not_found(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.get("/api/conversations/nope", headers=_guest_headers(client)) + assert resp.status_code == 404 + + +def test_update_conversation_title(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + create = client.post("/api/conversations", json={"title": "Old"}, headers=headers) + conv_id = create.json()["id"] + + update = client.put(f"/api/conversations/{conv_id}", json={"title": "New Title"}, headers=headers) + assert update.status_code == 200 + + get = client.get(f"/api/conversations/{conv_id}", headers=headers) + assert get.json()["conversation"]["title"] == "New Title" + + +def test_update_conversation_model(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + create = client.post("/api/conversations", json={"title": "Test"}, headers=headers) + conv_id = create.json()["id"] + + update = client.put(f"/api/conversations/{conv_id}", json={"model": "qwen2:latest"}, headers=headers) + assert update.status_code == 200 + + get = client.get(f"/api/conversations/{conv_id}", headers=headers) + assert get.json()["conversation"]["model"] == "qwen2:latest" + + +def test_delete_conversation(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + create = client.post("/api/conversations", json={"title": "Delete Me"}, headers=headers) + conv_id = create.json()["id"] + + delete = client.delete(f"/api/conversations/{conv_id}", headers=headers) + assert delete.status_code == 200 + + get = client.get(f"/api/conversations/{conv_id}", headers=_guest_headers(client)) + assert get.status_code == 404 + + +def test_delete_all_conversations(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + client.post("/api/conversations", json={"title": "One"}, headers=headers) + client.post("/api/conversations", json={"title": "Two"}, headers=headers) + + delete_all = client.delete("/api/conversations", headers=headers) + assert delete_all.status_code == 200 + + list_resp = client.get("/api/conversations", headers=_guest_headers(client)) + assert list_resp.json() == [] + + +def test_guest_cannot_create_conversation(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.post("/api/conversations", json={"title": "test"}, headers=_guest_headers(client)) + assert resp.status_code == 403 + + +def test_guest_cannot_update_conversation(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + create = client.post("/api/conversations", json={"title": "Test"}, headers=headers) + conv_id = create.json()["id"] + + guest_headers = _guest_headers(client) + resp = client.put(f"/api/conversations/{conv_id}", json={"title": "hack"}, headers=guest_headers) + assert resp.status_code == 403 + + +def test_guest_cannot_delete_conversation(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.delete("/api/conversations/some-id", headers=_guest_headers(client)) + assert resp.status_code == 403 + + +def test_guest_cannot_delete_all(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.delete("/api/conversations", headers=_guest_headers(client)) + assert resp.status_code == 403 diff --git a/tests/test_memories.py b/tests/test_memories.py new file mode 100644 index 0000000..9050832 --- /dev/null +++ b/tests/test_memories.py @@ -0,0 +1,161 @@ +import os +from pathlib import Path + +from fastapi.testclient import TestClient + +import app +import config +import db +from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS + + +def make_client(tmp_path: Path) -> TestClient: + os.environ["JARVISCHAT_ADMIN_PIN"] = "1234" + db.DB_PATH = tmp_path / "jarvischat-memories.db" + SESSIONS.clear() + PIN_ATTEMPTS.clear() + RATE_EVENTS.clear() + db.init_db() + return TestClient(app.app, raise_server_exceptions=False) + + +def _admin_headers(client: TestClient) -> dict: + login = client.post("/api/auth/login", json={"pin": "1234"}, headers={"Origin": "http://testserver"}) + sid = login.json()["session_id"] + return {"X-Session-ID": sid, "Origin": "http://testserver"} + + +def _guest_headers(client: TestClient) -> dict: + sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"] + return {"X-Session-ID": sid, "Origin": "http://testserver"} + + +def _create_memory(client: TestClient, headers: dict, fact: str = "test fact", topic: str = "general") -> int: + resp = client.post("/api/memories", json={"fact": fact, "topic": topic}, headers=headers) + assert resp.status_code == 200 + return resp.json()["rowid"] + + +def test_list_memories_empty(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.get("/api/memories", headers=_guest_headers(client)) + assert resp.status_code == 200 + assert resp.json()["count"] == 0 + + +def test_list_memories_by_topic(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + _create_memory(client, headers, "I like Python", "preference") + _create_memory(client, headers, "Building a game", "project") + + general = client.get("/api/memories?topic=preference", headers=_guest_headers(client)) + assert general.json()["count"] == 1 + assert general.json()["memories"][0]["topic"] == "preference" + + +def test_create_memory_requires_fact(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.post("/api/memories", json={"fact": ""}, headers=_admin_headers(client)) + assert resp.status_code == 400 + + +def test_create_memory_too_long(tmp_path: Path): + with make_client(tmp_path) as client: + long_fact = "x" * (config.MAX_MEMORY_FACT_CHARS + 1) + resp = client.post("/api/memories", json={"fact": long_fact}, headers=_admin_headers(client)) + assert resp.status_code == 413 + + +def test_edit_memory(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + rowid = _create_memory(client, headers, "original fact") + + edit = client.put(f"/api/memories/{rowid}", json={"fact": "updated fact"}, headers=headers) + assert edit.status_code == 200 + + memories = client.get("/api/memories", headers=_guest_headers(client)).json() + assert any(m["fact"] == "updated fact" for m in memories["memories"]) + + +def test_edit_memory_not_found(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.put("/api/memories/99999", json={"fact": "nope"}, headers=_admin_headers(client)) + assert resp.status_code == 404 + + +def test_edit_memory_empty_fact(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + rowid = _create_memory(client, headers, "some fact") + resp = client.put(f"/api/memories/{rowid}", json={"fact": ""}, headers=headers) + assert resp.status_code == 400 + + +def test_edit_memory_too_long(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + rowid = _create_memory(client, headers, "some fact") + long_fact = "x" * (config.MAX_MEMORY_FACT_CHARS + 1) + resp = client.put(f"/api/memories/{rowid}", json={"fact": long_fact}, headers=headers) + assert resp.status_code == 413 + + +def test_delete_memory_not_found(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.delete("/api/memories/99999", headers=_admin_headers(client)) + assert resp.status_code == 404 + + +def test_search_memories(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + _create_memory(client, headers, "my favorite color is blue", "preference") + _create_memory(client, headers, "running nginx on port 443", "infrastructure") + + resp = client.get("/api/memories/search?q=nginx&limit=5", headers=_guest_headers(client)) + assert resp.status_code == 200 + data = resp.json() + assert data["count"] >= 1 + assert any("nginx" in r["fact"] for r in data["results"]) + + +def test_search_memories_no_results(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.get("/api/memories/search?q=xyznonexistent&limit=5", headers=_guest_headers(client)) + assert resp.status_code == 200 + assert resp.json()["count"] == 0 + + +def test_memory_stats(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + _create_memory(client, headers, "like rust", "preference") + _create_memory(client, headers, "like python", "preference") + _create_memory(client, headers, "project game", "project") + + resp = client.get("/api/memories/stats", headers=_guest_headers(client)) + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 3 + assert data["by_topic"]["preference"] == 2 + assert data["by_topic"]["project"] == 1 + + +def test_guest_cannot_create_memory(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.post("/api/memories", json={"fact": "hack"}, headers=_guest_headers(client)) + assert resp.status_code == 403 + + +def test_guest_cannot_edit_memory(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.put("/api/memories/1", json={"fact": "hack"}, headers=_guest_headers(client)) + assert resp.status_code == 403 + + +def test_guest_cannot_delete_memory(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.delete("/api/memories/1", headers=_guest_headers(client)) + assert resp.status_code == 403 diff --git a/tests/test_models_router.py b/tests/test_models_router.py new file mode 100644 index 0000000..3a8acdd --- /dev/null +++ b/tests/test_models_router.py @@ -0,0 +1,138 @@ +import os +from pathlib import Path + +import httpx +from fastapi.testclient import TestClient + +import app +import db +import routers.models +from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS + + +def make_client(tmp_path: Path) -> TestClient: + os.environ["JARVISCHAT_ADMIN_PIN"] = "1234" + db.DB_PATH = tmp_path / "jarvischat-models.db" + SESSIONS.clear() + PIN_ATTEMPTS.clear() + RATE_EVENTS.clear() + db.init_db() + return TestClient(app.app, raise_server_exceptions=False) + + +def _guest_headers(client: TestClient) -> dict: + sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"] + return {"X-Session-ID": sid, "Origin": "http://testserver"} + + +class _MockAsyncResponse: + """Mock for httpx.AsyncClient.get/post that returns a JSON response.""" + def __init__(self, status_code=200, json_data=None): + self.status_code = status_code + self._json_data = json_data or {} + + def json(self): + return self._json_data + + +async def _mock_get_models(*args, **kwargs): + return _MockAsyncResponse(json_data={ + "data": [{"id": "llama3.1:latest"}, {"id": "qwen2:latest"}] + }) + + +async def _mock_get_empty_models(*args, **kwargs): + return _MockAsyncResponse(json_data={"data": []}) + + +async def _mock_connect_error(*args, **kwargs): + raise httpx.ConnectError("Connection refused") + + +async def _mock_show_model(*args, **kwargs): + return _MockAsyncResponse(json_data={ + "modelfile": "FROM llama3.1", "parameters": {} + }) + + +async def _mock_search_available(*args, **kwargs): + return _MockAsyncResponse(status_code=200) + + +async def _mock_search_unavailable(*args, **kwargs): + raise httpx.ConnectError("refused") + + +def test_list_models(tmp_path: Path, monkeypatch): + monkeypatch.setattr(httpx.AsyncClient, "get", _mock_get_models) + with make_client(tmp_path) as client: + resp = client.get("/api/models", headers=_guest_headers(client)) + assert resp.status_code == 200 + models = resp.json()["models"] + assert len(models) == 2 + assert models[0]["name"] == "llama3.1:latest" + + +def test_list_models_connect_error(tmp_path: Path, monkeypatch): + monkeypatch.setattr(httpx.AsyncClient, "get", _mock_connect_error) + with make_client(tmp_path) as client: + resp = client.get("/api/models", headers=_guest_headers(client)) + assert resp.status_code == 502 + + +def test_running_models(tmp_path: Path, monkeypatch): + monkeypatch.setattr(httpx.AsyncClient, "get", _mock_get_models) + with make_client(tmp_path) as client: + resp = client.get("/api/ps", headers=_guest_headers(client)) + assert resp.status_code == 200 + assert "data" in resp.json() + + +def test_running_models_connect_error(tmp_path: Path, monkeypatch): + monkeypatch.setattr(httpx.AsyncClient, "get", _mock_connect_error) + with make_client(tmp_path) as client: + resp = client.get("/api/ps", headers=_guest_headers(client)) + assert resp.status_code == 502 + + +def test_show_model(tmp_path: Path, monkeypatch): + monkeypatch.setattr(httpx.AsyncClient, "post", _mock_show_model) + with make_client(tmp_path) as client: + resp = client.post("/api/show", json={"model": "llama3.1:latest"}, headers=_guest_headers(client)) + assert resp.status_code == 200 + assert resp.json()["modelfile"] == "FROM llama3.1" + + +def test_show_model_connect_error(tmp_path: Path, monkeypatch): + monkeypatch.setattr(httpx.AsyncClient, "post", _mock_connect_error) + with make_client(tmp_path) as client: + resp = client.post("/api/show", json={"model": "llama3.1:latest"}, headers=_guest_headers(client)) + assert resp.status_code == 502 + + +def test_system_stats(tmp_path: Path, monkeypatch): + monkeypatch.setattr(routers.models, "get_gpu_stats", lambda: {"gpu_percent": 15, "vram_percent": 30, "available": True}) + with make_client(tmp_path) as client: + resp = client.get("/api/stats", headers=_guest_headers(client)) + assert resp.status_code == 200 + data = resp.json() + assert "cpu_percent" in data + assert "memory_percent" in data + assert data["gpu_percent"] == 15 + assert data["gpu_available"] is True + + +def test_search_status_available(tmp_path: Path, monkeypatch): + monkeypatch.setattr(httpx.AsyncClient, "get", _mock_search_available) + with make_client(tmp_path) as client: + resp = client.get("/api/search/status", headers=_guest_headers(client)) + assert resp.status_code == 200 + assert resp.json()["available"] is True + + +def test_search_status_unavailable(tmp_path: Path, monkeypatch): + monkeypatch.setattr(httpx.AsyncClient, "get", _mock_search_unavailable) + with make_client(tmp_path) as client: + resp = client.get("/api/search/status", headers=_guest_headers(client)) + assert resp.status_code == 200 + assert resp.json()["available"] is False diff --git a/tests/test_presets.py b/tests/test_presets.py new file mode 100644 index 0000000..36a1bdc --- /dev/null +++ b/tests/test_presets.py @@ -0,0 +1,128 @@ +import os +from pathlib import Path + +from fastapi.testclient import TestClient + +import app +import db +from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS + + +def make_client(tmp_path: Path) -> TestClient: + os.environ["JARVISCHAT_ADMIN_PIN"] = "1234" + db.DB_PATH = tmp_path / "jarvischat-presets.db" + SESSIONS.clear() + PIN_ATTEMPTS.clear() + RATE_EVENTS.clear() + db.init_db() + return TestClient(app.app, raise_server_exceptions=False) + + +def _admin_headers(client: TestClient) -> dict: + login = client.post("/api/auth/login", json={"pin": "1234"}, headers={"Origin": "http://testserver"}) + sid = login.json()["session_id"] + return {"X-Session-ID": sid, "Origin": "http://testserver"} + + +def _guest_headers(client: TestClient) -> dict: + sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"] + return {"X-Session-ID": sid, "Origin": "http://testserver"} + + +def test_list_presets_returns_defaults(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.get("/api/presets", headers=_guest_headers(client)) + assert resp.status_code == 200 + presets = resp.json() + assert len(presets) >= 3 + names = [p["name"] for p in presets] + assert "Coding Companion" in names + + +def test_create_preset(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + resp = client.post("/api/presets", json={"name": "My Preset", "prompt": "You are helpful."}, headers=headers) + assert resp.status_code == 200 + data = resp.json() + assert data["name"] == "My Preset" + assert data["prompt"] == "You are helpful." + + presets = client.get("/api/presets", headers=_guest_headers(client)).json() + assert any(p["name"] == "My Preset" for p in presets) + + +def test_create_preset_requires_name_and_prompt(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + resp = client.post("/api/presets", json={"name": "", "prompt": ""}, headers=headers) + assert resp.status_code == 400 + + resp = client.post("/api/presets", json={"name": "Only Name"}, headers=headers) + assert resp.status_code == 400 + + +def test_update_preset(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + create = client.post("/api/presets", json={"name": "Old", "prompt": "Old prompt."}, headers=headers) + preset_id = create.json()["id"] + + update = client.put(f"/api/presets/{preset_id}", json={"name": "New", "prompt": "New prompt."}, headers=headers) + assert update.status_code == 200 + + presets = client.get("/api/presets", headers=_guest_headers(client)).json() + updated = next(p for p in presets if p["id"] == preset_id) + assert updated["name"] == "New" + assert updated["prompt"] == "New prompt." + + +def test_update_preset_requires_fields(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + resp = client.put("/api/presets/nope", json={"name": "", "prompt": ""}, headers=headers) + assert resp.status_code == 400 + + +def test_delete_preset(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + create = client.post("/api/presets", json={"name": "Temp", "prompt": "Temp."}, headers=headers) + preset_id = create.json()["id"] + + delete = client.delete(f"/api/presets/{preset_id}", headers=headers) + assert delete.status_code == 200 + + presets = client.get("/api/presets", headers=_guest_headers(client)).json() + assert not any(p["id"] == preset_id for p in presets) + + +def test_delete_default_preset_is_noop(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + presets_before = client.get("/api/presets", headers=_guest_headers(client)).json() + default = next(p for p in presets_before if p["is_default"]) + + delete = client.delete(f"/api/presets/{default['id']}", headers=headers) + assert delete.status_code == 200 + + presets_after = client.get("/api/presets", headers=_guest_headers(client)).json() + assert any(p["id"] == default["id"] for p in presets_after) + + +def test_guest_cannot_create_preset(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.post("/api/presets", json={"name": "Hack", "prompt": "Hack"}, headers=_guest_headers(client)) + assert resp.status_code == 403 + + +def test_guest_cannot_update_preset(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.put("/api/presets/some-id", json={"name": "Hack", "prompt": "Hack"}, headers=_guest_headers(client)) + assert resp.status_code == 403 + + +def test_guest_cannot_delete_preset(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.delete("/api/presets/some-id", headers=_guest_headers(client)) + assert resp.status_code == 403 diff --git a/tests/test_profile.py b/tests/test_profile.py new file mode 100644 index 0000000..0db85b9 --- /dev/null +++ b/tests/test_profile.py @@ -0,0 +1,72 @@ +import os +from pathlib import Path + +from fastapi.testclient import TestClient + +import app +import config +import db +from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS + + +def make_client(tmp_path: Path) -> TestClient: + os.environ["JARVISCHAT_ADMIN_PIN"] = "1234" + db.DB_PATH = tmp_path / "jarvischat-profile.db" + SESSIONS.clear() + PIN_ATTEMPTS.clear() + RATE_EVENTS.clear() + db.init_db() + return TestClient(app.app, raise_server_exceptions=False) + + +def _admin_headers(client: TestClient) -> dict: + login = client.post("/api/auth/login", json={"pin": "1234"}, headers={"Origin": "http://testserver"}) + sid = login.json()["session_id"] + return {"X-Session-ID": sid, "Origin": "http://testserver"} + + +def _guest_headers(client: TestClient) -> dict: + sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"] + return {"X-Session-ID": sid, "Origin": "http://testserver"} + + +def test_get_profile_returns_content(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.get("/api/profile", headers=_guest_headers(client)) + assert resp.status_code == 200 + data = resp.json() + assert "content" in data + assert "updated_at" in data + assert len(data["content"]) > 0 + + +def test_get_default_profile(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.get("/api/profile/default", headers=_guest_headers(client)) + assert resp.status_code == 200 + assert resp.json()["content"] == config.DEFAULT_PROFILE + + +def test_update_profile(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + resp = client.put("/api/profile", json={"content": "Custom profile text."}, headers=headers) + assert resp.status_code == 200 + assert "updated_at" in resp.json() + + get_resp = client.get("/api/profile", headers=_guest_headers(client)) + assert get_resp.json()["content"] == "Custom profile text." + + +def test_update_profile_too_long(tmp_path: Path): + with make_client(tmp_path) as client: + headers = _admin_headers(client) + long_content = "x" * (config.MAX_PROFILE_CHARS + 1) + resp = client.put("/api/profile", json={"content": long_content}, headers=headers) + assert resp.status_code == 413 + + +def test_guest_cannot_update_profile(tmp_path: Path): + with make_client(tmp_path) as client: + resp = client.put("/api/profile", json={"content": "hack"}, headers=_guest_headers(client)) + assert resp.status_code == 403 diff --git a/tests/test_search_route.py b/tests/test_search_route.py new file mode 100644 index 0000000..d25b239 --- /dev/null +++ b/tests/test_search_route.py @@ -0,0 +1,186 @@ +import json +import os +from pathlib import Path + +import httpx +from fastapi.testclient import TestClient + +import app +import config +import db +import routers.search_route +from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS + + +def make_client(tmp_path: Path) -> TestClient: + os.environ["JARVISCHAT_ADMIN_PIN"] = "1234" + db.DB_PATH = tmp_path / "jarvischat-search-route.db" + SESSIONS.clear() + PIN_ATTEMPTS.clear() + RATE_EVENTS.clear() + db.init_db() + return TestClient(app.app, raise_server_exceptions=False) + + +def _guest_headers(client: TestClient) -> dict: + sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"] + return {"X-Session-ID": sid, "Origin": "http://testserver"} + + +def parse_sse_payloads(body: str) -> list[dict]: + payloads: list[dict] = [] + for chunk in body.split("\n\n"): + chunk = chunk.strip() + if not chunk.startswith("data: "): + continue + raw = chunk[len("data: ") :] + payloads.append(json.loads(raw)) + return payloads + + +class _MockStreamResponse: + def __init__(self, lines: list[str]): + self._lines = lines + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def aiter_lines(self): + for line in self._lines: + yield line + + +def _stream_json_lines(events: list[dict]) -> list[str]: + return [json.dumps(event) for event in events] + + +def test_explicit_search_with_results(tmp_path: Path, monkeypatch): + with make_client(tmp_path) as client: + headers = _guest_headers(client) + + async def search_stub(query: str, max_results: int = 5): + return [ + {"title": "Result One", "url": "https://example.com/1", "content": "First result content."}, + {"title": "Result Two", "url": "https://example.com/2", "content": "Second result content."}, + ] + + monkeypatch.setattr(routers.search_route, "query_searxng", search_stub) + + events = _stream_json_lines([ + {"choices": [{"delta": {"content": "Here's what I found"}, "logprobs": None}]}, + {"choices": [{"delta": {"content": " about your query."}, "logprobs": None}]}, + {"choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {}}, + ]) + + def stream_stub(self, method, url, json=None, timeout=None): + return _MockStreamResponse(events) + + monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub) + + resp = client.post( + "/api/search", + json={"query": "current events", "model": config.DEFAULT_MODEL}, + headers=headers, + ) + assert resp.status_code == 200 + payloads = parse_sse_payloads(resp.text) + + assert any(p.get("searching") is True for p in payloads) + assert any("search_results" in p for p in payloads) + token_text = "".join(p.get("token", "") for p in payloads if "token" in p) + assert "found" in token_text.lower() + assert any(p.get("done") and p.get("searched") for p in payloads) + + +def test_explicit_search_no_results(tmp_path: Path, monkeypatch): + with make_client(tmp_path) as client: + headers = _guest_headers(client) + + async def empty_search(query: str, max_results: int = 5): + return [] + + monkeypatch.setattr(routers.search_route, "query_searxng", empty_search) + + resp = client.post( + "/api/search", + json={"query": "nothingness", "model": config.DEFAULT_MODEL}, + headers=headers, + ) + assert resp.status_code == 200 + payloads = parse_sse_payloads(resp.text) + + assert any("No search results found" in p.get("token", "") for p in payloads) + assert any(p.get("done") for p in payloads) + assert not any("search_results" in p for p in payloads) + + +def test_explicit_search_new_conversation_created(tmp_path: Path, monkeypatch): + with make_client(tmp_path) as client: + headers = _guest_headers(client) + + async def search_stub(query: str, max_results: int = 5): + return [{"title": "T", "url": "https://ex.com", "content": "Content."}] + + monkeypatch.setattr(routers.search_route, "query_searxng", search_stub) + + events = _stream_json_lines([ + {"choices": [{"delta": {"content": "Answer."}, "logprobs": None}]}, + {"choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {}}, + ]) + + def stream_stub(self, method, url, json=None, timeout=None): + return _MockStreamResponse(events) + + monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub) + + resp = client.post( + "/api/search", + json={"query": "tell me something", "model": config.DEFAULT_MODEL}, + headers=headers, + ) + assert resp.status_code == 200 + payloads = parse_sse_payloads(resp.text) + + conv_id = None + for p in payloads: + if "conversation_id" in p: + conv_id = p["conversation_id"] + break + assert conv_id is not None + + conv_resp = client.get(f"/api/conversations/{conv_id}", headers=_guest_headers(client)) + assert conv_resp.status_code == 200 + data = conv_resp.json() + assert len(data["messages"]) >= 2 + + +def test_explicit_search_stream_error(tmp_path: Path, monkeypatch): + with make_client(tmp_path) as client: + headers = _guest_headers(client) + + async def search_stub(query: str, max_results: int = 5): + return [{"title": "T", "url": "https://ex.com", "content": "Content."}] + + monkeypatch.setattr(routers.search_route, "query_searxng", search_stub) + + def broken_stream(self, method, url, json=None, timeout=None): + class BrokenCtx: + async def __aenter__(self): + raise RuntimeError("summarization failed") + async def __aexit__(self, exc_type, exc, tb): + return False + return BrokenCtx() + + monkeypatch.setattr(httpx.AsyncClient, "stream", broken_stream) + + resp = client.post( + "/api/search", + json={"query": "breaking news", "model": config.DEFAULT_MODEL}, + headers=headers, + ) + assert resp.status_code == 200 + assert "error_key" in resp.text + assert "INC-" in resp.text