feat(security): add LAN IP allowlist and ingress guardrails
This commit is contained in:
240
app.py
240
app.py
@@ -27,6 +27,8 @@ import hmac
|
||||
import time
|
||||
import uuid
|
||||
import re
|
||||
import ipaddress
|
||||
from collections import defaultdict, deque
|
||||
from threading import Lock
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
@@ -72,6 +74,35 @@ TRUSTED_ORIGINS = {
|
||||
if origin.strip()
|
||||
}
|
||||
|
||||
DEFAULT_ALLOWED_CIDRS = "127.0.0.0/8,::1/128,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
|
||||
ALLOWED_CIDRS_RAW = os.getenv("JARVISCHAT_ALLOWED_CIDRS", DEFAULT_ALLOWED_CIDRS)
|
||||
TRUST_X_FORWARDED_FOR = (
|
||||
os.getenv("JARVISCHAT_TRUST_X_FORWARDED_FOR", "false").lower() == "true"
|
||||
)
|
||||
|
||||
# --- Rate / Payload Guardrails (home-lab friendly defaults) ---
|
||||
RATE_WINDOW_SECONDS = 60
|
||||
RL_LOGIN_PER_WINDOW = 10
|
||||
RL_CHAT_PER_WINDOW = 24
|
||||
RL_SEARCH_PER_WINDOW = 16
|
||||
RL_WRITE_PER_WINDOW = 30
|
||||
RL_DEFAULT_PER_WINDOW = 240
|
||||
RL_STATS_PER_WINDOW = 600
|
||||
|
||||
BODY_LIMIT_DEFAULT_BYTES = 64 * 1024
|
||||
BODY_LIMIT_CHAT_BYTES = 128 * 1024
|
||||
BODY_LIMIT_PROFILE_BYTES = 256 * 1024
|
||||
|
||||
MAX_CHAT_MESSAGE_CHARS = 8000
|
||||
MAX_SEARCH_QUERY_CHARS = 500
|
||||
MAX_PROFILE_CHARS = 32000
|
||||
MAX_MEMORY_FACT_CHARS = 2000
|
||||
MAX_PRESET_NAME_CHARS = 120
|
||||
MAX_PRESET_PROMPT_CHARS = 12000
|
||||
MAX_SETTINGS_KEYS = 16
|
||||
MAX_SETTINGS_VALUE_CHARS = 8000
|
||||
MAX_CONVERSATION_TITLE_CHARS = 200
|
||||
|
||||
# --- Templates and Static Files ---
|
||||
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
|
||||
|
||||
@@ -109,7 +140,9 @@ HEDGE_PATTERNS = [
|
||||
|
||||
SESSIONS: dict[str, dict] = {}
|
||||
PIN_ATTEMPTS: dict[str, dict] = {}
|
||||
RATE_EVENTS: dict[str, deque[float]] = defaultdict(deque)
|
||||
SESSION_LOCK = Lock()
|
||||
RATE_LOCK = Lock()
|
||||
|
||||
|
||||
def hash_pin(pin: str, salt_hex: Optional[str] = None) -> tuple[str, str]:
|
||||
@@ -143,6 +176,23 @@ def audit_event(
|
||||
log.info(msg)
|
||||
|
||||
|
||||
def parse_allowed_cidrs(raw: str) -> list[ipaddress._BaseNetwork]:
|
||||
"""Parse comma-separated CIDR list into validated network objects."""
|
||||
networks: list[ipaddress._BaseNetwork] = []
|
||||
for entry in (raw or "").split(","):
|
||||
value = entry.strip()
|
||||
if not value:
|
||||
continue
|
||||
try:
|
||||
networks.append(ipaddress.ip_network(value, strict=False))
|
||||
except ValueError:
|
||||
log.warning(f"Invalid CIDR ignored: {value}")
|
||||
return networks
|
||||
|
||||
|
||||
ALLOWED_NETWORKS = parse_allowed_cidrs(ALLOWED_CIDRS_RAW)
|
||||
|
||||
|
||||
def clean_hedging(text: str) -> str:
|
||||
"""Remove hedging sentences from model response."""
|
||||
cleaned = text
|
||||
@@ -175,6 +225,55 @@ def sanitize_outbound_url(url: str) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
async def read_json_body(request: Request, max_bytes: int) -> dict:
|
||||
"""Read and parse JSON body with a hard size cap for non-streaming requests."""
|
||||
raw = await request.body()
|
||||
if len(raw) > max_bytes:
|
||||
raise HTTPException(status_code=413, detail="Request payload too large")
|
||||
if not raw:
|
||||
return {}
|
||||
try:
|
||||
return json.loads(raw.decode("utf-8"))
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON payload")
|
||||
|
||||
|
||||
def request_body_limit(path: str) -> int:
|
||||
if path in {"/api/chat", "/api/search"}:
|
||||
return BODY_LIMIT_CHAT_BYTES
|
||||
if path == "/api/profile":
|
||||
return BODY_LIMIT_PROFILE_BYTES
|
||||
return BODY_LIMIT_DEFAULT_BYTES
|
||||
|
||||
|
||||
def rate_policy(path: str, method: str, ip: str, sid: str) -> tuple[str, int]:
|
||||
identity = sid or ip
|
||||
if path == "/api/auth/login":
|
||||
return f"login:{ip}", RL_LOGIN_PER_WINDOW
|
||||
if path == "/api/chat":
|
||||
return f"chat:{identity}", RL_CHAT_PER_WINDOW
|
||||
if path == "/api/search":
|
||||
return f"search:{identity}", RL_SEARCH_PER_WINDOW
|
||||
if path == "/api/stats":
|
||||
return f"stats:{ip}", RL_STATS_PER_WINDOW
|
||||
if method in {"POST", "PUT", "DELETE", "PATCH"}:
|
||||
return f"write:{identity}", RL_WRITE_PER_WINDOW
|
||||
return f"api:{identity}", RL_DEFAULT_PER_WINDOW
|
||||
|
||||
|
||||
def check_rate_limit(key: str, limit: int, window_seconds: int) -> tuple[bool, int]:
|
||||
now_ts = time.time()
|
||||
with RATE_LOCK:
|
||||
bucket = RATE_EVENTS[key]
|
||||
while bucket and bucket[0] <= (now_ts - window_seconds):
|
||||
bucket.popleft()
|
||||
if len(bucket) >= limit:
|
||||
retry_after = max(1, int(math.ceil(window_seconds - (now_ts - bucket[0]))))
|
||||
return False, retry_after
|
||||
bucket.append(now_ts)
|
||||
return True, 0
|
||||
|
||||
|
||||
# --- Default Profile ---
|
||||
DEFAULT_PROFILE = """You are a coding companion running locally on a machine called "jarvis".
|
||||
|
||||
@@ -760,13 +859,29 @@ if static_dir.exists():
|
||||
|
||||
def get_client_ip(request: Request) -> str:
|
||||
forwarded = request.headers.get("x-forwarded-for", "").strip()
|
||||
if forwarded:
|
||||
if TRUST_X_FORWARDED_FOR and forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
if request.client and request.client.host:
|
||||
return request.client.host
|
||||
return "unknown"
|
||||
|
||||
|
||||
def is_ip_allowed(ip: str) -> bool:
|
||||
"""Allow only loopback/private CIDRs by default; env can override CIDR list."""
|
||||
normalized = ip.strip().lower()
|
||||
if normalized in {"localhost", "testclient"}:
|
||||
normalized = "127.0.0.1"
|
||||
try:
|
||||
ip_obj = ipaddress.ip_address(normalized)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
for network in ALLOWED_NETWORKS:
|
||||
if ip_obj in network:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def cleanup_sessions(now_ts: Optional[float] = None) -> None:
|
||||
now_ts = now_ts or time.time()
|
||||
with SESSION_LOCK:
|
||||
@@ -927,9 +1042,53 @@ def origin_allowed(request: Request) -> bool:
|
||||
async def session_auth_middleware(request: Request, call_next):
|
||||
path = request.url.path
|
||||
ip = get_client_ip(request)
|
||||
sid = request.headers.get("x-session-id", "").strip()
|
||||
request.state.session_role = "none"
|
||||
request.state.client_ip = ip
|
||||
|
||||
if path.startswith("/api/"):
|
||||
if not is_ip_allowed(ip):
|
||||
audit_event(
|
||||
"ip_allowlist",
|
||||
"denied",
|
||||
ip=ip,
|
||||
role="none",
|
||||
details=f"{request.method} {path}",
|
||||
warning=True,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Client IP not allowed"},
|
||||
)
|
||||
|
||||
if path.startswith("/api/"):
|
||||
rate_key, rate_limit = rate_policy(path, request.method, ip, sid)
|
||||
allowed, retry_after = check_rate_limit(
|
||||
rate_key, rate_limit, RATE_WINDOW_SECONDS
|
||||
)
|
||||
if not allowed:
|
||||
audit_event(
|
||||
"rate_limit",
|
||||
"denied",
|
||||
ip=ip,
|
||||
role="none",
|
||||
details=f"{request.method} {path} retry_after={retry_after}",
|
||||
warning=True,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": f"Rate limit exceeded. Retry in {retry_after}s."},
|
||||
)
|
||||
|
||||
if request.method in {"POST", "PUT", "PATCH"}:
|
||||
max_bytes = request_body_limit(path)
|
||||
content_length = request.headers.get("content-length", "").strip()
|
||||
if content_length.isdigit() and int(content_length) > max_bytes:
|
||||
return JSONResponse(
|
||||
status_code=413,
|
||||
content={"detail": "Request payload too large"},
|
||||
)
|
||||
|
||||
unauth_paths = {
|
||||
"/api/auth/login",
|
||||
"/api/auth/logout",
|
||||
@@ -955,7 +1114,6 @@ async def session_auth_middleware(request: Request, call_next):
|
||||
)
|
||||
|
||||
if path.startswith("/api/") and path not in unauth_paths:
|
||||
sid = request.headers.get("x-session-id", "").strip()
|
||||
session = get_session(sid, ip, touch=True)
|
||||
if not session:
|
||||
audit_event(
|
||||
@@ -1016,7 +1174,7 @@ async def auth_guest(request: Request):
|
||||
|
||||
@app.post("/api/auth/login")
|
||||
async def auth_login(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
pin = str(body.get("pin", ""))
|
||||
ip = get_client_ip(request)
|
||||
|
||||
@@ -1081,7 +1239,7 @@ async def auth_logout(request: Request):
|
||||
role = session.get("role", "none") if session else "none"
|
||||
if not sid:
|
||||
try:
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
sid = str(body.get("session_id", "")).strip()
|
||||
except Exception:
|
||||
try:
|
||||
@@ -1125,7 +1283,7 @@ async def running_models():
|
||||
|
||||
@app.post("/api/show")
|
||||
async def show_model(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
resp = await client.post(f"{OLLAMA_BASE}/api/show", json=body, timeout=10)
|
||||
@@ -1175,9 +1333,14 @@ async def list_memories(topic: Optional[str] = None):
|
||||
|
||||
@app.post("/api/memories")
|
||||
async def create_memory(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
fact = str(body.get("fact", "")).strip()
|
||||
if not fact:
|
||||
raise HTTPException(status_code=400, detail="Memory fact is required")
|
||||
if len(fact) > MAX_MEMORY_FACT_CHARS:
|
||||
raise HTTPException(status_code=413, detail="Memory fact is too long")
|
||||
rowid = add_memory(
|
||||
fact=body["fact"],
|
||||
fact=fact,
|
||||
topic=body.get("topic", "general"),
|
||||
source=body.get("source", "manual"),
|
||||
)
|
||||
@@ -1193,8 +1356,13 @@ async def remove_memory(rowid: int):
|
||||
|
||||
@app.put("/api/memories/{rowid}")
|
||||
async def edit_memory(rowid: int, request: Request):
|
||||
body = await request.json()
|
||||
if not update_memory(rowid, body["fact"]):
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
fact = str(body.get("fact", "")).strip()
|
||||
if not fact:
|
||||
raise HTTPException(status_code=400, detail="Memory fact is required")
|
||||
if len(fact) > MAX_MEMORY_FACT_CHARS:
|
||||
raise HTTPException(status_code=413, detail="Memory fact is too long")
|
||||
if not update_memory(rowid, fact):
|
||||
raise HTTPException(status_code=404, detail="Memory not found")
|
||||
return {"status": "ok"}
|
||||
|
||||
@@ -1233,12 +1401,15 @@ async def get_profile():
|
||||
|
||||
@app.put("/api/profile")
|
||||
async def update_profile(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_PROFILE_BYTES)
|
||||
content = str(body.get("content", ""))
|
||||
if len(content) > MAX_PROFILE_CHARS:
|
||||
raise HTTPException(status_code=413, detail="Profile content is too long")
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
db = get_db()
|
||||
db.execute(
|
||||
"UPDATE profile SET content = ?, updated_at = ? WHERE id = 1",
|
||||
(body["content"], now),
|
||||
(content, now),
|
||||
)
|
||||
db.commit()
|
||||
db.close()
|
||||
@@ -1263,9 +1434,16 @@ async def get_settings():
|
||||
|
||||
@app.put("/api/settings")
|
||||
async def update_settings(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
if not isinstance(body, dict):
|
||||
raise HTTPException(status_code=400, detail="Settings payload must be an object")
|
||||
if len(body) > MAX_SETTINGS_KEYS:
|
||||
raise HTTPException(status_code=413, detail="Too many settings in one request")
|
||||
db = get_db()
|
||||
for key, value in body.items():
|
||||
if len(str(key)) > 80 or len(str(value)) > MAX_SETTINGS_VALUE_CHARS:
|
||||
db.close()
|
||||
raise HTTPException(status_code=413, detail="Setting key/value too long")
|
||||
db.execute(
|
||||
"INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
|
||||
(key, str(value)),
|
||||
@@ -1290,26 +1468,38 @@ async def list_presets():
|
||||
|
||||
@app.post("/api/presets")
|
||||
async def create_preset(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
name = str(body.get("name", "")).strip()
|
||||
prompt = str(body.get("prompt", "")).strip()
|
||||
if not name or not prompt:
|
||||
raise HTTPException(status_code=400, detail="Preset name and prompt are required")
|
||||
if len(name) > MAX_PRESET_NAME_CHARS or len(prompt) > MAX_PRESET_PROMPT_CHARS:
|
||||
raise HTTPException(status_code=413, detail="Preset fields are too long")
|
||||
preset_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
db = get_db()
|
||||
db.execute(
|
||||
"INSERT INTO system_presets (id, name, prompt, is_default, created_at) VALUES (?, ?, ?, 0, ?)",
|
||||
(preset_id, body["name"], body["prompt"], now),
|
||||
(preset_id, name, prompt, now),
|
||||
)
|
||||
db.commit()
|
||||
db.close()
|
||||
return {"id": preset_id, "name": body["name"], "prompt": body["prompt"]}
|
||||
return {"id": preset_id, "name": name, "prompt": prompt}
|
||||
|
||||
|
||||
@app.put("/api/presets/{preset_id}")
|
||||
async def update_preset(preset_id: str, request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
name = str(body.get("name", "")).strip()
|
||||
prompt = str(body.get("prompt", "")).strip()
|
||||
if not name or not prompt:
|
||||
raise HTTPException(status_code=400, detail="Preset name and prompt are required")
|
||||
if len(name) > MAX_PRESET_NAME_CHARS or len(prompt) > MAX_PRESET_PROMPT_CHARS:
|
||||
raise HTTPException(status_code=413, detail="Preset fields are too long")
|
||||
db = get_db()
|
||||
db.execute(
|
||||
"UPDATE system_presets SET name = ?, prompt = ? WHERE id = ?",
|
||||
(body["name"], body["prompt"], preset_id),
|
||||
(name, prompt, preset_id),
|
||||
)
|
||||
db.commit()
|
||||
db.close()
|
||||
@@ -1340,11 +1530,11 @@ async def list_conversations():
|
||||
|
||||
@app.post("/api/conversations")
|
||||
async def create_conversation(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
conv_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
model = body.get("model", DEFAULT_MODEL)
|
||||
title = body.get("title", "New Chat")
|
||||
title = str(body.get("title", "New Chat"))[:MAX_CONVERSATION_TITLE_CHARS]
|
||||
db = get_db()
|
||||
db.execute(
|
||||
"INSERT INTO conversations (id, title, model, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
|
||||
@@ -1377,13 +1567,13 @@ async def get_conversation(conv_id: str):
|
||||
|
||||
@app.put("/api/conversations/{conv_id}")
|
||||
async def update_conversation(conv_id: str, request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
db = get_db()
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
if "title" in body:
|
||||
db.execute(
|
||||
"UPDATE conversations SET title = ?, updated_at = ? WHERE id = ?",
|
||||
(body["title"], now, conv_id),
|
||||
(str(body["title"])[:MAX_CONVERSATION_TITLE_CHARS], now, conv_id),
|
||||
)
|
||||
if "model" in body:
|
||||
db.execute(
|
||||
@@ -1424,8 +1614,10 @@ async def delete_all_conversations():
|
||||
@app.post("/api/search")
|
||||
async def explicit_search(request: Request):
|
||||
"""Explicit web search - bypasses model uncertainty, queries SearXNG directly."""
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_CHAT_BYTES)
|
||||
query = body.get("query", "").strip()
|
||||
if len(query) > MAX_SEARCH_QUERY_CHARS:
|
||||
raise HTTPException(status_code=413, detail="Search query is too long")
|
||||
conv_id = body.get("conversation_id")
|
||||
model = body.get("model", DEFAULT_MODEL)
|
||||
|
||||
@@ -1570,9 +1762,11 @@ def build_system_prompt(db, extra_prompt="", user_message=""):
|
||||
|
||||
@app.post("/api/chat")
|
||||
async def chat(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_CHAT_BYTES)
|
||||
conv_id = body.get("conversation_id")
|
||||
user_message = body.get("message", "").strip()
|
||||
if len(user_message) > MAX_CHAT_MESSAGE_CHARS:
|
||||
raise HTTPException(status_code=413, detail="Chat message is too long")
|
||||
model = body.get("model", DEFAULT_MODEL)
|
||||
preset_prompt = body.get("system_prompt", "")
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ Total identified items: 26
|
||||
1. [P0][DONE] Add authentication/authorization for all write and admin endpoints.
|
||||
2. [P0][DONE] Add CSRF/origin protection for browser-initiated state-changing requests.
|
||||
3. [P0][DONE] Block unsafe URL schemes in rendered search-result links (e.g., javascript:).
|
||||
4. [P0] Add rate limiting and request body size limits for chat/search/profile APIs.
|
||||
4. [P0][DONE] Add rate limiting and request body size limits for chat/search/profile APIs.
|
||||
5. [P1] Restrict settings updates to an allowlist of valid keys.
|
||||
6. [P1] Add pagination + hard caps on list endpoints (memories, conversations, message history).
|
||||
7. [P1] Stop returning raw exception text to clients; use safe error envelopes.
|
||||
|
||||
18
readme.md
18
readme.md
@@ -6,6 +6,20 @@
|
||||
|
||||
Built with FastAPI + SQLite + Jinja2. Runs on Python 3.13. No Docker required.
|
||||
|
||||
## Security Scope Disclaimer
|
||||
|
||||
JarvisChat is designed for local and home-lab use (same host or trusted LAN).
|
||||
|
||||
JarvisChat may technically work with frontier or commercial AI endpoints, but the author does not recommend or support that usage.
|
||||
|
||||
Supported deployments are contained local/home-lab environments.
|
||||
|
||||
By default, API access is limited to loopback + private LAN CIDRs. You can override with `JARVISCHAT_ALLOWED_CIDRS` (comma-separated CIDRs) and optionally trust reverse-proxy forwarding with `JARVISCHAT_TRUST_X_FORWARDED_FOR=true`.
|
||||
|
||||
If you deploy outside a trusted local subnet, your risk profile changes significantly and the default protections here may be insufficient.
|
||||
|
||||
Use at your own risk. No warranty is provided for Internet-exposed deployments.
|
||||
|
||||
## What's New in v1.5.0
|
||||
|
||||
- **Explicit Web Search Button** — 🔍 button next to SEND forces a web search, bypassing model uncertainty detection
|
||||
@@ -47,7 +61,7 @@ Top 10 (brief):
|
||||
1. P0 [DONE]: Add auth for write/admin endpoints
|
||||
2. P0 [DONE]: Add CSRF/origin protection for state-changing requests
|
||||
3. P0 [DONE]: Block unsafe URL schemes in rendered links
|
||||
4. P0: Add rate limiting and request size limits
|
||||
4. P0 [DONE]: Add rate limiting and request size limits
|
||||
5. P1: Restrict `/api/settings` updates to allowlisted keys
|
||||
6. P1: Add pagination + hard caps for list APIs
|
||||
7. P1: Replace raw exception leakage with safe client errors
|
||||
@@ -57,7 +71,7 @@ Top 10 (brief):
|
||||
|
||||
Item 1 executive summary: keep guest mode for conversational chat, require 4-digit admin PIN for advanced/destructive actions, and enforce local/LAN-only backend policy by default.
|
||||
|
||||
Implementation status: complete (guest session by default + admin unlock + admin-only write enforcement + origin checks + safe-link sanitization + audit logging + capability tests).
|
||||
Implementation status: complete (guest session by default + admin unlock + admin-only write enforcement + origin checks + safe-link sanitization + audit logging + rate/payload guardrails + capability tests).
|
||||
|
||||
## TODO
|
||||
|
||||
|
||||
50
tests/test_ip_allowlist.py
Normal file
50
tests/test_ip_allowlist.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import app as app_module
|
||||
|
||||
|
||||
def make_client(tmp_path: Path) -> TestClient:
|
||||
os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
|
||||
app_module.DB_PATH = tmp_path / "jarvischat-ip.db"
|
||||
app_module.SESSIONS.clear()
|
||||
app_module.PIN_ATTEMPTS.clear()
|
||||
app_module.RATE_EVENTS.clear()
|
||||
app_module.init_db()
|
||||
return TestClient(app_module.app)
|
||||
|
||||
|
||||
def test_ip_helper_allows_local_defaults():
|
||||
assert app_module.is_ip_allowed("127.0.0.1")
|
||||
assert app_module.is_ip_allowed("192.168.1.10")
|
||||
assert app_module.is_ip_allowed("10.0.0.42")
|
||||
assert app_module.is_ip_allowed("172.16.1.2")
|
||||
assert app_module.is_ip_allowed("testclient")
|
||||
|
||||
|
||||
def test_ip_helper_blocks_public_ip():
|
||||
assert not app_module.is_ip_allowed("8.8.8.8")
|
||||
|
||||
|
||||
def test_middleware_blocks_disallowed_ip(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
original_get_client_ip = app_module.get_client_ip
|
||||
try:
|
||||
app_module.get_client_ip = lambda _req: "8.8.8.8"
|
||||
resp = client.post("/api/auth/guest")
|
||||
assert resp.status_code == 403
|
||||
finally:
|
||||
app_module.get_client_ip = original_get_client_ip
|
||||
|
||||
|
||||
def test_middleware_allows_local_ip(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
original_get_client_ip = app_module.get_client_ip
|
||||
try:
|
||||
app_module.get_client_ip = lambda _req: "192.168.50.109"
|
||||
resp = client.post("/api/auth/guest")
|
||||
assert resp.status_code == 200
|
||||
finally:
|
||||
app_module.get_client_ip = original_get_client_ip
|
||||
76
tests/test_rate_and_payload_guardrails.py
Normal file
76
tests/test_rate_and_payload_guardrails.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import app as app_module
|
||||
|
||||
|
||||
def make_client(tmp_path: Path) -> TestClient:
|
||||
os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
|
||||
app_module.DB_PATH = tmp_path / "jarvischat-rate.db"
|
||||
app_module.SESSIONS.clear()
|
||||
app_module.PIN_ATTEMPTS.clear()
|
||||
app_module.RATE_EVENTS.clear()
|
||||
app_module.init_db()
|
||||
return TestClient(app_module.app)
|
||||
|
||||
|
||||
def test_stats_rate_limit_hits_429(tmp_path: Path):
|
||||
old_limit = app_module.RL_STATS_PER_WINDOW
|
||||
old_window = app_module.RATE_WINDOW_SECONDS
|
||||
app_module.RL_STATS_PER_WINDOW = 2
|
||||
app_module.RATE_WINDOW_SECONDS = 60
|
||||
try:
|
||||
with make_client(tmp_path) as client:
|
||||
sid = client.post("/api/auth/guest").json()["session_id"]
|
||||
headers = {"X-Session-ID": sid}
|
||||
|
||||
r1 = client.get("/api/stats", headers=headers)
|
||||
r2 = client.get("/api/stats", headers=headers)
|
||||
r3 = client.get("/api/stats", headers=headers)
|
||||
|
||||
assert r1.status_code == 200
|
||||
assert r2.status_code == 200
|
||||
assert r3.status_code == 429
|
||||
finally:
|
||||
app_module.RL_STATS_PER_WINDOW = old_limit
|
||||
app_module.RATE_WINDOW_SECONDS = old_window
|
||||
|
||||
|
||||
def test_large_login_payload_rejected_413(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
huge_pin = "1" * (app_module.BODY_LIMIT_DEFAULT_BYTES + 100)
|
||||
resp = client.post(
|
||||
"/api/auth/login",
|
||||
data=json.dumps({"pin": huge_pin}),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert resp.status_code == 413
|
||||
|
||||
|
||||
def test_chat_message_length_rejected_413(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
sid = client.post("/api/auth/guest").json()["session_id"]
|
||||
headers = {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||
message = "x" * (app_module.MAX_CHAT_MESSAGE_CHARS + 1)
|
||||
resp = client.post(
|
||||
"/api/chat",
|
||||
json={"message": message, "model": app_module.DEFAULT_MODEL},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 413
|
||||
|
||||
|
||||
def test_search_query_length_rejected_413(tmp_path: Path):
|
||||
with make_client(tmp_path) as client:
|
||||
sid = client.post("/api/auth/guest").json()["session_id"]
|
||||
headers = {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||
query = "q" * (app_module.MAX_SEARCH_QUERY_CHARS + 1)
|
||||
resp = client.post(
|
||||
"/api/search",
|
||||
json={"query": query, "model": app_module.DEFAULT_MODEL},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 413
|
||||
Reference in New Issue
Block a user