From 76e4461b38e3c5b9f5860ea322293fdd7ed77e20 Mon Sep 17 00:00:00 2001 From: gramps Date: Mon, 27 Apr 2026 16:43:21 -0700 Subject: [PATCH] feat(security): add LAN IP allowlist and ingress guardrails --- app.py | 240 +++++++++++++++++++--- docs/wiki/current-wip.md | 2 +- readme.md | 18 +- tests/test_ip_allowlist.py | 50 +++++ tests/test_rate_and_payload_guardrails.py | 76 +++++++ 5 files changed, 360 insertions(+), 26 deletions(-) create mode 100644 tests/test_ip_allowlist.py create mode 100644 tests/test_rate_and_payload_guardrails.py diff --git a/app.py b/app.py index bb1da33..1b0009e 100644 --- a/app.py +++ b/app.py @@ -27,6 +27,8 @@ import hmac import time import uuid import re +import ipaddress +from collections import defaultdict, deque from threading import Lock from datetime import datetime, timezone from pathlib import Path @@ -72,6 +74,35 @@ TRUSTED_ORIGINS = { if origin.strip() } +DEFAULT_ALLOWED_CIDRS = "127.0.0.0/8,::1/128,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16" +ALLOWED_CIDRS_RAW = os.getenv("JARVISCHAT_ALLOWED_CIDRS", DEFAULT_ALLOWED_CIDRS) +TRUST_X_FORWARDED_FOR = ( + os.getenv("JARVISCHAT_TRUST_X_FORWARDED_FOR", "false").lower() == "true" +) + +# --- Rate / Payload Guardrails (home-lab friendly defaults) --- +RATE_WINDOW_SECONDS = 60 +RL_LOGIN_PER_WINDOW = 10 +RL_CHAT_PER_WINDOW = 24 +RL_SEARCH_PER_WINDOW = 16 +RL_WRITE_PER_WINDOW = 30 +RL_DEFAULT_PER_WINDOW = 240 +RL_STATS_PER_WINDOW = 600 + +BODY_LIMIT_DEFAULT_BYTES = 64 * 1024 +BODY_LIMIT_CHAT_BYTES = 128 * 1024 +BODY_LIMIT_PROFILE_BYTES = 256 * 1024 + +MAX_CHAT_MESSAGE_CHARS = 8000 +MAX_SEARCH_QUERY_CHARS = 500 +MAX_PROFILE_CHARS = 32000 +MAX_MEMORY_FACT_CHARS = 2000 +MAX_PRESET_NAME_CHARS = 120 +MAX_PRESET_PROMPT_CHARS = 12000 +MAX_SETTINGS_KEYS = 16 +MAX_SETTINGS_VALUE_CHARS = 8000 +MAX_CONVERSATION_TITLE_CHARS = 200 + # --- Templates and Static Files --- templates = Jinja2Templates(directory=str(BASE_DIR / "templates")) @@ -109,7 +140,9 @@ HEDGE_PATTERNS = [ SESSIONS: dict[str, dict] = {} PIN_ATTEMPTS: dict[str, dict] = {} +RATE_EVENTS: dict[str, deque[float]] = defaultdict(deque) SESSION_LOCK = Lock() +RATE_LOCK = Lock() def hash_pin(pin: str, salt_hex: Optional[str] = None) -> tuple[str, str]: @@ -143,6 +176,23 @@ def audit_event( log.info(msg) +def parse_allowed_cidrs(raw: str) -> list[ipaddress._BaseNetwork]: + """Parse comma-separated CIDR list into validated network objects.""" + networks: list[ipaddress._BaseNetwork] = [] + for entry in (raw or "").split(","): + value = entry.strip() + if not value: + continue + try: + networks.append(ipaddress.ip_network(value, strict=False)) + except ValueError: + log.warning(f"Invalid CIDR ignored: {value}") + return networks + + +ALLOWED_NETWORKS = parse_allowed_cidrs(ALLOWED_CIDRS_RAW) + + def clean_hedging(text: str) -> str: """Remove hedging sentences from model response.""" cleaned = text @@ -175,6 +225,55 @@ def sanitize_outbound_url(url: str) -> str: return "" +async def read_json_body(request: Request, max_bytes: int) -> dict: + """Read and parse JSON body with a hard size cap for non-streaming requests.""" + raw = await request.body() + if len(raw) > max_bytes: + raise HTTPException(status_code=413, detail="Request payload too large") + if not raw: + return {} + try: + return json.loads(raw.decode("utf-8")) + except Exception: + raise HTTPException(status_code=400, detail="Invalid JSON payload") + + +def request_body_limit(path: str) -> int: + if path in {"/api/chat", "/api/search"}: + return BODY_LIMIT_CHAT_BYTES + if path == "/api/profile": + return BODY_LIMIT_PROFILE_BYTES + return BODY_LIMIT_DEFAULT_BYTES + + +def rate_policy(path: str, method: str, ip: str, sid: str) -> tuple[str, int]: + identity = sid or ip + if path == "/api/auth/login": + return f"login:{ip}", RL_LOGIN_PER_WINDOW + if path == "/api/chat": + return f"chat:{identity}", RL_CHAT_PER_WINDOW + if path == "/api/search": + return f"search:{identity}", RL_SEARCH_PER_WINDOW + if path == "/api/stats": + return f"stats:{ip}", RL_STATS_PER_WINDOW + if method in {"POST", "PUT", "DELETE", "PATCH"}: + return f"write:{identity}", RL_WRITE_PER_WINDOW + return f"api:{identity}", RL_DEFAULT_PER_WINDOW + + +def check_rate_limit(key: str, limit: int, window_seconds: int) -> tuple[bool, int]: + now_ts = time.time() + with RATE_LOCK: + bucket = RATE_EVENTS[key] + while bucket and bucket[0] <= (now_ts - window_seconds): + bucket.popleft() + if len(bucket) >= limit: + retry_after = max(1, int(math.ceil(window_seconds - (now_ts - bucket[0])))) + return False, retry_after + bucket.append(now_ts) + return True, 0 + + # --- Default Profile --- DEFAULT_PROFILE = """You are a coding companion running locally on a machine called "jarvis". @@ -760,13 +859,29 @@ if static_dir.exists(): def get_client_ip(request: Request) -> str: forwarded = request.headers.get("x-forwarded-for", "").strip() - if forwarded: + if TRUST_X_FORWARDED_FOR and forwarded: return forwarded.split(",")[0].strip() if request.client and request.client.host: return request.client.host return "unknown" +def is_ip_allowed(ip: str) -> bool: + """Allow only loopback/private CIDRs by default; env can override CIDR list.""" + normalized = ip.strip().lower() + if normalized in {"localhost", "testclient"}: + normalized = "127.0.0.1" + try: + ip_obj = ipaddress.ip_address(normalized) + except ValueError: + return False + + for network in ALLOWED_NETWORKS: + if ip_obj in network: + return True + return False + + def cleanup_sessions(now_ts: Optional[float] = None) -> None: now_ts = now_ts or time.time() with SESSION_LOCK: @@ -927,9 +1042,53 @@ def origin_allowed(request: Request) -> bool: async def session_auth_middleware(request: Request, call_next): path = request.url.path ip = get_client_ip(request) + sid = request.headers.get("x-session-id", "").strip() request.state.session_role = "none" request.state.client_ip = ip + if path.startswith("/api/"): + if not is_ip_allowed(ip): + audit_event( + "ip_allowlist", + "denied", + ip=ip, + role="none", + details=f"{request.method} {path}", + warning=True, + ) + return JSONResponse( + status_code=403, + content={"detail": "Client IP not allowed"}, + ) + + if path.startswith("/api/"): + rate_key, rate_limit = rate_policy(path, request.method, ip, sid) + allowed, retry_after = check_rate_limit( + rate_key, rate_limit, RATE_WINDOW_SECONDS + ) + if not allowed: + audit_event( + "rate_limit", + "denied", + ip=ip, + role="none", + details=f"{request.method} {path} retry_after={retry_after}", + warning=True, + ) + return JSONResponse( + status_code=429, + content={"detail": f"Rate limit exceeded. Retry in {retry_after}s."}, + ) + + if request.method in {"POST", "PUT", "PATCH"}: + max_bytes = request_body_limit(path) + content_length = request.headers.get("content-length", "").strip() + if content_length.isdigit() and int(content_length) > max_bytes: + return JSONResponse( + status_code=413, + content={"detail": "Request payload too large"}, + ) + unauth_paths = { "/api/auth/login", "/api/auth/logout", @@ -955,7 +1114,6 @@ async def session_auth_middleware(request: Request, call_next): ) if path.startswith("/api/") and path not in unauth_paths: - sid = request.headers.get("x-session-id", "").strip() session = get_session(sid, ip, touch=True) if not session: audit_event( @@ -1016,7 +1174,7 @@ async def auth_guest(request: Request): @app.post("/api/auth/login") async def auth_login(request: Request): - body = await request.json() + body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES) pin = str(body.get("pin", "")) ip = get_client_ip(request) @@ -1081,7 +1239,7 @@ async def auth_logout(request: Request): role = session.get("role", "none") if session else "none" if not sid: try: - body = await request.json() + body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES) sid = str(body.get("session_id", "")).strip() except Exception: try: @@ -1125,7 +1283,7 @@ async def running_models(): @app.post("/api/show") async def show_model(request: Request): - body = await request.json() + body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES) async with httpx.AsyncClient() as client: try: resp = await client.post(f"{OLLAMA_BASE}/api/show", json=body, timeout=10) @@ -1175,9 +1333,14 @@ async def list_memories(topic: Optional[str] = None): @app.post("/api/memories") async def create_memory(request: Request): - body = await request.json() + body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES) + fact = str(body.get("fact", "")).strip() + if not fact: + raise HTTPException(status_code=400, detail="Memory fact is required") + if len(fact) > MAX_MEMORY_FACT_CHARS: + raise HTTPException(status_code=413, detail="Memory fact is too long") rowid = add_memory( - fact=body["fact"], + fact=fact, topic=body.get("topic", "general"), source=body.get("source", "manual"), ) @@ -1193,8 +1356,13 @@ async def remove_memory(rowid: int): @app.put("/api/memories/{rowid}") async def edit_memory(rowid: int, request: Request): - body = await request.json() - if not update_memory(rowid, body["fact"]): + body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES) + fact = str(body.get("fact", "")).strip() + if not fact: + raise HTTPException(status_code=400, detail="Memory fact is required") + if len(fact) > MAX_MEMORY_FACT_CHARS: + raise HTTPException(status_code=413, detail="Memory fact is too long") + if not update_memory(rowid, fact): raise HTTPException(status_code=404, detail="Memory not found") return {"status": "ok"} @@ -1233,12 +1401,15 @@ async def get_profile(): @app.put("/api/profile") async def update_profile(request: Request): - body = await request.json() + body = await read_json_body(request, BODY_LIMIT_PROFILE_BYTES) + content = str(body.get("content", "")) + if len(content) > MAX_PROFILE_CHARS: + raise HTTPException(status_code=413, detail="Profile content is too long") now = datetime.now(timezone.utc).isoformat() db = get_db() db.execute( "UPDATE profile SET content = ?, updated_at = ? WHERE id = 1", - (body["content"], now), + (content, now), ) db.commit() db.close() @@ -1263,9 +1434,16 @@ async def get_settings(): @app.put("/api/settings") async def update_settings(request: Request): - body = await request.json() + body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES) + if not isinstance(body, dict): + raise HTTPException(status_code=400, detail="Settings payload must be an object") + if len(body) > MAX_SETTINGS_KEYS: + raise HTTPException(status_code=413, detail="Too many settings in one request") db = get_db() for key, value in body.items(): + if len(str(key)) > 80 or len(str(value)) > MAX_SETTINGS_VALUE_CHARS: + db.close() + raise HTTPException(status_code=413, detail="Setting key/value too long") db.execute( "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)", (key, str(value)), @@ -1290,26 +1468,38 @@ async def list_presets(): @app.post("/api/presets") async def create_preset(request: Request): - body = await request.json() + body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES) + name = str(body.get("name", "")).strip() + prompt = str(body.get("prompt", "")).strip() + if not name or not prompt: + raise HTTPException(status_code=400, detail="Preset name and prompt are required") + if len(name) > MAX_PRESET_NAME_CHARS or len(prompt) > MAX_PRESET_PROMPT_CHARS: + raise HTTPException(status_code=413, detail="Preset fields are too long") preset_id = str(uuid.uuid4()) now = datetime.now(timezone.utc).isoformat() db = get_db() db.execute( "INSERT INTO system_presets (id, name, prompt, is_default, created_at) VALUES (?, ?, ?, 0, ?)", - (preset_id, body["name"], body["prompt"], now), + (preset_id, name, prompt, now), ) db.commit() db.close() - return {"id": preset_id, "name": body["name"], "prompt": body["prompt"]} + return {"id": preset_id, "name": name, "prompt": prompt} @app.put("/api/presets/{preset_id}") async def update_preset(preset_id: str, request: Request): - body = await request.json() + body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES) + name = str(body.get("name", "")).strip() + prompt = str(body.get("prompt", "")).strip() + if not name or not prompt: + raise HTTPException(status_code=400, detail="Preset name and prompt are required") + if len(name) > MAX_PRESET_NAME_CHARS or len(prompt) > MAX_PRESET_PROMPT_CHARS: + raise HTTPException(status_code=413, detail="Preset fields are too long") db = get_db() db.execute( "UPDATE system_presets SET name = ?, prompt = ? WHERE id = ?", - (body["name"], body["prompt"], preset_id), + (name, prompt, preset_id), ) db.commit() db.close() @@ -1340,11 +1530,11 @@ async def list_conversations(): @app.post("/api/conversations") async def create_conversation(request: Request): - body = await request.json() + body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES) conv_id = str(uuid.uuid4()) now = datetime.now(timezone.utc).isoformat() model = body.get("model", DEFAULT_MODEL) - title = body.get("title", "New Chat") + title = str(body.get("title", "New Chat"))[:MAX_CONVERSATION_TITLE_CHARS] db = get_db() db.execute( "INSERT INTO conversations (id, title, model, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", @@ -1377,13 +1567,13 @@ async def get_conversation(conv_id: str): @app.put("/api/conversations/{conv_id}") async def update_conversation(conv_id: str, request: Request): - body = await request.json() + body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES) db = get_db() now = datetime.now(timezone.utc).isoformat() if "title" in body: db.execute( "UPDATE conversations SET title = ?, updated_at = ? WHERE id = ?", - (body["title"], now, conv_id), + (str(body["title"])[:MAX_CONVERSATION_TITLE_CHARS], now, conv_id), ) if "model" in body: db.execute( @@ -1424,8 +1614,10 @@ async def delete_all_conversations(): @app.post("/api/search") async def explicit_search(request: Request): """Explicit web search - bypasses model uncertainty, queries SearXNG directly.""" - body = await request.json() + body = await read_json_body(request, BODY_LIMIT_CHAT_BYTES) query = body.get("query", "").strip() + if len(query) > MAX_SEARCH_QUERY_CHARS: + raise HTTPException(status_code=413, detail="Search query is too long") conv_id = body.get("conversation_id") model = body.get("model", DEFAULT_MODEL) @@ -1570,9 +1762,11 @@ def build_system_prompt(db, extra_prompt="", user_message=""): @app.post("/api/chat") async def chat(request: Request): - body = await request.json() + body = await read_json_body(request, BODY_LIMIT_CHAT_BYTES) conv_id = body.get("conversation_id") user_message = body.get("message", "").strip() + if len(user_message) > MAX_CHAT_MESSAGE_CHARS: + raise HTTPException(status_code=413, detail="Chat message is too long") model = body.get("model", DEFAULT_MODEL) preset_prompt = body.get("system_prompt", "") diff --git a/docs/wiki/current-wip.md b/docs/wiki/current-wip.md index bbbe6b7..276ce21 100644 --- a/docs/wiki/current-wip.md +++ b/docs/wiki/current-wip.md @@ -16,7 +16,7 @@ Total identified items: 26 1. [P0][DONE] Add authentication/authorization for all write and admin endpoints. 2. [P0][DONE] Add CSRF/origin protection for browser-initiated state-changing requests. 3. [P0][DONE] Block unsafe URL schemes in rendered search-result links (e.g., javascript:). -4. [P0] Add rate limiting and request body size limits for chat/search/profile APIs. +4. [P0][DONE] Add rate limiting and request body size limits for chat/search/profile APIs. 5. [P1] Restrict settings updates to an allowlist of valid keys. 6. [P1] Add pagination + hard caps on list endpoints (memories, conversations, message history). 7. [P1] Stop returning raw exception text to clients; use safe error envelopes. diff --git a/readme.md b/readme.md index eb8a15f..805c1f2 100644 --- a/readme.md +++ b/readme.md @@ -6,6 +6,20 @@ Built with FastAPI + SQLite + Jinja2. Runs on Python 3.13. No Docker required. +## Security Scope Disclaimer + +JarvisChat is designed for local and home-lab use (same host or trusted LAN). + +JarvisChat may technically work with frontier or commercial AI endpoints, but the author does not recommend or support that usage. + +Supported deployments are contained local/home-lab environments. + +By default, API access is limited to loopback + private LAN CIDRs. You can override with `JARVISCHAT_ALLOWED_CIDRS` (comma-separated CIDRs) and optionally trust reverse-proxy forwarding with `JARVISCHAT_TRUST_X_FORWARDED_FOR=true`. + +If you deploy outside a trusted local subnet, your risk profile changes significantly and the default protections here may be insufficient. + +Use at your own risk. No warranty is provided for Internet-exposed deployments. + ## What's New in v1.5.0 - **Explicit Web Search Button** — 🔍 button next to SEND forces a web search, bypassing model uncertainty detection @@ -47,7 +61,7 @@ Top 10 (brief): 1. P0 [DONE]: Add auth for write/admin endpoints 2. P0 [DONE]: Add CSRF/origin protection for state-changing requests 3. P0 [DONE]: Block unsafe URL schemes in rendered links -4. P0: Add rate limiting and request size limits +4. P0 [DONE]: Add rate limiting and request size limits 5. P1: Restrict `/api/settings` updates to allowlisted keys 6. P1: Add pagination + hard caps for list APIs 7. P1: Replace raw exception leakage with safe client errors @@ -57,7 +71,7 @@ Top 10 (brief): Item 1 executive summary: keep guest mode for conversational chat, require 4-digit admin PIN for advanced/destructive actions, and enforce local/LAN-only backend policy by default. -Implementation status: complete (guest session by default + admin unlock + admin-only write enforcement + origin checks + safe-link sanitization + audit logging + capability tests). +Implementation status: complete (guest session by default + admin unlock + admin-only write enforcement + origin checks + safe-link sanitization + audit logging + rate/payload guardrails + capability tests). ## TODO diff --git a/tests/test_ip_allowlist.py b/tests/test_ip_allowlist.py new file mode 100644 index 0000000..b2ec212 --- /dev/null +++ b/tests/test_ip_allowlist.py @@ -0,0 +1,50 @@ +import os +from pathlib import Path + +from fastapi.testclient import TestClient + +import app as app_module + + +def make_client(tmp_path: Path) -> TestClient: + os.environ["JARVISCHAT_ADMIN_PIN"] = "1234" + app_module.DB_PATH = tmp_path / "jarvischat-ip.db" + app_module.SESSIONS.clear() + app_module.PIN_ATTEMPTS.clear() + app_module.RATE_EVENTS.clear() + app_module.init_db() + return TestClient(app_module.app) + + +def test_ip_helper_allows_local_defaults(): + assert app_module.is_ip_allowed("127.0.0.1") + assert app_module.is_ip_allowed("192.168.1.10") + assert app_module.is_ip_allowed("10.0.0.42") + assert app_module.is_ip_allowed("172.16.1.2") + assert app_module.is_ip_allowed("testclient") + + +def test_ip_helper_blocks_public_ip(): + assert not app_module.is_ip_allowed("8.8.8.8") + + +def test_middleware_blocks_disallowed_ip(tmp_path: Path): + with make_client(tmp_path) as client: + original_get_client_ip = app_module.get_client_ip + try: + app_module.get_client_ip = lambda _req: "8.8.8.8" + resp = client.post("/api/auth/guest") + assert resp.status_code == 403 + finally: + app_module.get_client_ip = original_get_client_ip + + +def test_middleware_allows_local_ip(tmp_path: Path): + with make_client(tmp_path) as client: + original_get_client_ip = app_module.get_client_ip + try: + app_module.get_client_ip = lambda _req: "192.168.50.109" + resp = client.post("/api/auth/guest") + assert resp.status_code == 200 + finally: + app_module.get_client_ip = original_get_client_ip diff --git a/tests/test_rate_and_payload_guardrails.py b/tests/test_rate_and_payload_guardrails.py new file mode 100644 index 0000000..379f594 --- /dev/null +++ b/tests/test_rate_and_payload_guardrails.py @@ -0,0 +1,76 @@ +import json +import os +from pathlib import Path + +from fastapi.testclient import TestClient + +import app as app_module + + +def make_client(tmp_path: Path) -> TestClient: + os.environ["JARVISCHAT_ADMIN_PIN"] = "1234" + app_module.DB_PATH = tmp_path / "jarvischat-rate.db" + app_module.SESSIONS.clear() + app_module.PIN_ATTEMPTS.clear() + app_module.RATE_EVENTS.clear() + app_module.init_db() + return TestClient(app_module.app) + + +def test_stats_rate_limit_hits_429(tmp_path: Path): + old_limit = app_module.RL_STATS_PER_WINDOW + old_window = app_module.RATE_WINDOW_SECONDS + app_module.RL_STATS_PER_WINDOW = 2 + app_module.RATE_WINDOW_SECONDS = 60 + try: + with make_client(tmp_path) as client: + sid = client.post("/api/auth/guest").json()["session_id"] + headers = {"X-Session-ID": sid} + + r1 = client.get("/api/stats", headers=headers) + r2 = client.get("/api/stats", headers=headers) + r3 = client.get("/api/stats", headers=headers) + + assert r1.status_code == 200 + assert r2.status_code == 200 + assert r3.status_code == 429 + finally: + app_module.RL_STATS_PER_WINDOW = old_limit + app_module.RATE_WINDOW_SECONDS = old_window + + +def test_large_login_payload_rejected_413(tmp_path: Path): + with make_client(tmp_path) as client: + huge_pin = "1" * (app_module.BODY_LIMIT_DEFAULT_BYTES + 100) + resp = client.post( + "/api/auth/login", + data=json.dumps({"pin": huge_pin}), + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 413 + + +def test_chat_message_length_rejected_413(tmp_path: Path): + with make_client(tmp_path) as client: + sid = client.post("/api/auth/guest").json()["session_id"] + headers = {"X-Session-ID": sid, "Origin": "http://testserver"} + message = "x" * (app_module.MAX_CHAT_MESSAGE_CHARS + 1) + resp = client.post( + "/api/chat", + json={"message": message, "model": app_module.DEFAULT_MODEL}, + headers=headers, + ) + assert resp.status_code == 413 + + +def test_search_query_length_rejected_413(tmp_path: Path): + with make_client(tmp_path) as client: + sid = client.post("/api/auth/guest").json()["session_id"] + headers = {"X-Session-ID": sid, "Origin": "http://testserver"} + query = "q" * (app_module.MAX_SEARCH_QUERY_CHARS + 1) + resp = client.post( + "/api/search", + json={"query": query, "model": app_module.DEFAULT_MODEL}, + headers=headers, + ) + assert resp.status_code == 413