Files
jarvisChat/tests/test_chat_streaming_and_memory_paths.py
gramps 5986c4ad86 fix: close two CSRF origin-check security gaps
- Extend origin check to all /api/ requests (not just state-changing methods),
  closing the GET/HEAD/OPTIONS bypass that allowed cross-origin reads
- origin_allowed() now returns False when both Origin and Referer headers
  are absent, preventing script-initiated requests from bypassing the check
- Update AGENTS.md and README.md to document the changes
2026-06-27 15:20:02 -07:00

194 lines
6.6 KiB
Python

import json
import os
from pathlib import Path
import httpx
from fastapi.testclient import TestClient
import app
import config
import db
import routers.chat
from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
def make_client(tmp_path: Path) -> TestClient:
os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
db.DB_PATH = tmp_path / "jarvischat-streaming.db"
SESSIONS.clear()
PIN_ATTEMPTS.clear()
RATE_EVENTS.clear()
db.init_db()
return TestClient(app.app, raise_server_exceptions=False)
def parse_sse_payloads(body: str) -> list[dict]:
payloads: list[dict] = []
for chunk in body.split("\n\n"):
chunk = chunk.strip()
if not chunk.startswith("data: "):
continue
raw = chunk[len("data: ") :]
payloads.append(json.loads(raw))
return payloads
class _MockStreamResponse:
def __init__(self, lines: list[str]):
self._lines = lines
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
async def aiter_lines(self):
for line in self._lines:
yield line
def _stream_json_lines(events: list[dict]) -> list[str]:
return [json.dumps(event) for event in events]
def test_chat_stream_emits_tokens_and_done(tmp_path: Path, monkeypatch):
with make_client(tmp_path) as client:
sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()[
"session_id"
]
headers = {"X-Session-ID": sid, "Origin": "http://testserver"}
events = _stream_json_lines(
[
{"message": {"content": "Hel"}, "logprobs": [{"logprob": -0.01}]},
{"message": {"content": "lo"}, "logprobs": [{"logprob": -0.01}]},
{"done": True, "eval_count": 2, "eval_duration": 1000000000},
]
)
def stream_stub(self, method, url, json=None, timeout=None):
return _MockStreamResponse(events)
monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)
resp = client.post(
"/api/chat",
json={"message": "hello", "model": config.DEFAULT_MODEL},
headers=headers,
)
assert resp.status_code == 200
payloads = parse_sse_payloads(resp.text)
token_text = "".join(p.get("token", "") for p in payloads if "token" in p)
assert token_text == "Hello"
done_events = [p for p in payloads if p.get("done")]
assert done_events
assert "searched" not in done_events[-1]
def test_chat_auto_search_trigger_emits_search_events(tmp_path: Path, monkeypatch):
with make_client(tmp_path) as client:
sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()[
"session_id"
]
headers = {"X-Session-ID": sid, "Origin": "http://testserver"}
first_stream = _stream_json_lines(
[
{
"message": {"content": "I don't have current data on that question."},
"logprobs": [{"logprob": -5.0}],
},
{"done": True, "eval_count": 2, "eval_duration": 1000000000},
]
)
second_stream = _stream_json_lines(
[
{"message": {"content": "Based on current data: 42."}},
{"done": True},
]
)
stream_batches = [first_stream, second_stream]
def stream_stub(self, method, url, json=None, timeout=None):
return _MockStreamResponse(stream_batches.pop(0))
async def search_stub(query: str, max_results: int = 5):
return [
{
"title": "Answer",
"url": "https://example.com",
"content": "The value is 42.",
}
]
monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)
monkeypatch.setattr(routers.chat, "query_searxng", search_stub)
resp = client.post(
"/api/chat",
json={"message": "what is the latest value", "model": config.DEFAULT_MODEL},
headers=headers,
)
assert resp.status_code == 200
payloads = parse_sse_payloads(resp.text)
assert any(p.get("searching") is True for p in payloads)
assert any("search_results" in p for p in payloads)
assert any(p.get("augmented") is True for p in payloads)
done_events = [p for p in payloads if p.get("done")]
assert done_events and done_events[-1].get("searched") is True
def test_memory_command_paths_remember_and_forget(tmp_path: Path, monkeypatch):
with make_client(tmp_path) as client:
sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()[
"session_id"
]
headers = {"X-Session-ID": sid, "Origin": "http://testserver"}
base_stream = _stream_json_lines(
[
{"message": {"content": "ok"}, "logprobs": [{"logprob": -0.01}]},
{"done": True, "eval_count": 1, "eval_duration": 1000000000},
]
)
def stream_stub(self, method, url, json=None, timeout=None):
return _MockStreamResponse(base_stream)
monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)
remember_resp = client.post(
"/api/chat",
json={
"message": "remember that my favorite language is rust",
"model": config.DEFAULT_MODEL,
},
headers=headers,
)
assert remember_resp.status_code == 200
remember_events = parse_sse_payloads(remember_resp.text)
assert any("Remembered" in p.get("token", "") for p in remember_events)
memories_after_add = client.get("/api/memories", headers={"X-Session-ID": sid, "Origin": "http://testserver"})
assert memories_after_add.status_code == 200
assert memories_after_add.json().get("count", 0) >= 1
forget_resp = client.post(
"/api/chat",
json={
"message": "forget about my favorite language",
"model": config.DEFAULT_MODEL,
},
headers=headers,
)
assert forget_resp.status_code == 200
forget_events = parse_sse_payloads(forget_resp.text)
assert any("Forgot" in p.get("token", "") for p in forget_events)
memories_after_forget = client.get("/api/memories", headers={"X-Session-ID": sid, "Origin": "http://testserver"})
assert memories_after_forget.status_code == 200
assert memories_after_forget.json().get("count", 0) == 0