feat(security): add LAN IP allowlist and ingress guardrails

This commit is contained in:
2026-04-27 16:43:21 -07:00
parent 28aa40c42a
commit 76e4461b38
5 changed files with 360 additions and 26 deletions

240
app.py
View File

@@ -27,6 +27,8 @@ import hmac
import time
import uuid
import re
import ipaddress
from collections import defaultdict, deque
from threading import Lock
from datetime import datetime, timezone
from pathlib import Path
@@ -72,6 +74,35 @@ TRUSTED_ORIGINS = {
if origin.strip()
}
DEFAULT_ALLOWED_CIDRS = "127.0.0.0/8,::1/128,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
ALLOWED_CIDRS_RAW = os.getenv("JARVISCHAT_ALLOWED_CIDRS", DEFAULT_ALLOWED_CIDRS)
TRUST_X_FORWARDED_FOR = (
os.getenv("JARVISCHAT_TRUST_X_FORWARDED_FOR", "false").lower() == "true"
)
# --- Rate / Payload Guardrails (home-lab friendly defaults) ---
RATE_WINDOW_SECONDS = 60
RL_LOGIN_PER_WINDOW = 10
RL_CHAT_PER_WINDOW = 24
RL_SEARCH_PER_WINDOW = 16
RL_WRITE_PER_WINDOW = 30
RL_DEFAULT_PER_WINDOW = 240
RL_STATS_PER_WINDOW = 600
BODY_LIMIT_DEFAULT_BYTES = 64 * 1024
BODY_LIMIT_CHAT_BYTES = 128 * 1024
BODY_LIMIT_PROFILE_BYTES = 256 * 1024
MAX_CHAT_MESSAGE_CHARS = 8000
MAX_SEARCH_QUERY_CHARS = 500
MAX_PROFILE_CHARS = 32000
MAX_MEMORY_FACT_CHARS = 2000
MAX_PRESET_NAME_CHARS = 120
MAX_PRESET_PROMPT_CHARS = 12000
MAX_SETTINGS_KEYS = 16
MAX_SETTINGS_VALUE_CHARS = 8000
MAX_CONVERSATION_TITLE_CHARS = 200
# --- Templates and Static Files ---
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
@@ -109,7 +140,9 @@ HEDGE_PATTERNS = [
SESSIONS: dict[str, dict] = {}
PIN_ATTEMPTS: dict[str, dict] = {}
RATE_EVENTS: dict[str, deque[float]] = defaultdict(deque)
SESSION_LOCK = Lock()
RATE_LOCK = Lock()
def hash_pin(pin: str, salt_hex: Optional[str] = None) -> tuple[str, str]:
@@ -143,6 +176,23 @@ def audit_event(
log.info(msg)
def parse_allowed_cidrs(raw: str) -> list[ipaddress._BaseNetwork]:
"""Parse comma-separated CIDR list into validated network objects."""
networks: list[ipaddress._BaseNetwork] = []
for entry in (raw or "").split(","):
value = entry.strip()
if not value:
continue
try:
networks.append(ipaddress.ip_network(value, strict=False))
except ValueError:
log.warning(f"Invalid CIDR ignored: {value}")
return networks
ALLOWED_NETWORKS = parse_allowed_cidrs(ALLOWED_CIDRS_RAW)
def clean_hedging(text: str) -> str:
"""Remove hedging sentences from model response."""
cleaned = text
@@ -175,6 +225,55 @@ def sanitize_outbound_url(url: str) -> str:
return ""
async def read_json_body(request: Request, max_bytes: int) -> dict:
"""Read and parse JSON body with a hard size cap for non-streaming requests."""
raw = await request.body()
if len(raw) > max_bytes:
raise HTTPException(status_code=413, detail="Request payload too large")
if not raw:
return {}
try:
return json.loads(raw.decode("utf-8"))
except Exception:
raise HTTPException(status_code=400, detail="Invalid JSON payload")
def request_body_limit(path: str) -> int:
if path in {"/api/chat", "/api/search"}:
return BODY_LIMIT_CHAT_BYTES
if path == "/api/profile":
return BODY_LIMIT_PROFILE_BYTES
return BODY_LIMIT_DEFAULT_BYTES
def rate_policy(path: str, method: str, ip: str, sid: str) -> tuple[str, int]:
identity = sid or ip
if path == "/api/auth/login":
return f"login:{ip}", RL_LOGIN_PER_WINDOW
if path == "/api/chat":
return f"chat:{identity}", RL_CHAT_PER_WINDOW
if path == "/api/search":
return f"search:{identity}", RL_SEARCH_PER_WINDOW
if path == "/api/stats":
return f"stats:{ip}", RL_STATS_PER_WINDOW
if method in {"POST", "PUT", "DELETE", "PATCH"}:
return f"write:{identity}", RL_WRITE_PER_WINDOW
return f"api:{identity}", RL_DEFAULT_PER_WINDOW
def check_rate_limit(key: str, limit: int, window_seconds: int) -> tuple[bool, int]:
now_ts = time.time()
with RATE_LOCK:
bucket = RATE_EVENTS[key]
while bucket and bucket[0] <= (now_ts - window_seconds):
bucket.popleft()
if len(bucket) >= limit:
retry_after = max(1, int(math.ceil(window_seconds - (now_ts - bucket[0]))))
return False, retry_after
bucket.append(now_ts)
return True, 0
# --- Default Profile ---
DEFAULT_PROFILE = """You are a coding companion running locally on a machine called "jarvis".
@@ -760,13 +859,29 @@ if static_dir.exists():
def get_client_ip(request: Request) -> str:
forwarded = request.headers.get("x-forwarded-for", "").strip()
if forwarded:
if TRUST_X_FORWARDED_FOR and forwarded:
return forwarded.split(",")[0].strip()
if request.client and request.client.host:
return request.client.host
return "unknown"
def is_ip_allowed(ip: str) -> bool:
"""Allow only loopback/private CIDRs by default; env can override CIDR list."""
normalized = ip.strip().lower()
if normalized in {"localhost", "testclient"}:
normalized = "127.0.0.1"
try:
ip_obj = ipaddress.ip_address(normalized)
except ValueError:
return False
for network in ALLOWED_NETWORKS:
if ip_obj in network:
return True
return False
def cleanup_sessions(now_ts: Optional[float] = None) -> None:
now_ts = now_ts or time.time()
with SESSION_LOCK:
@@ -927,9 +1042,53 @@ def origin_allowed(request: Request) -> bool:
async def session_auth_middleware(request: Request, call_next):
path = request.url.path
ip = get_client_ip(request)
sid = request.headers.get("x-session-id", "").strip()
request.state.session_role = "none"
request.state.client_ip = ip
if path.startswith("/api/"):
if not is_ip_allowed(ip):
audit_event(
"ip_allowlist",
"denied",
ip=ip,
role="none",
details=f"{request.method} {path}",
warning=True,
)
return JSONResponse(
status_code=403,
content={"detail": "Client IP not allowed"},
)
if path.startswith("/api/"):
rate_key, rate_limit = rate_policy(path, request.method, ip, sid)
allowed, retry_after = check_rate_limit(
rate_key, rate_limit, RATE_WINDOW_SECONDS
)
if not allowed:
audit_event(
"rate_limit",
"denied",
ip=ip,
role="none",
details=f"{request.method} {path} retry_after={retry_after}",
warning=True,
)
return JSONResponse(
status_code=429,
content={"detail": f"Rate limit exceeded. Retry in {retry_after}s."},
)
if request.method in {"POST", "PUT", "PATCH"}:
max_bytes = request_body_limit(path)
content_length = request.headers.get("content-length", "").strip()
if content_length.isdigit() and int(content_length) > max_bytes:
return JSONResponse(
status_code=413,
content={"detail": "Request payload too large"},
)
unauth_paths = {
"/api/auth/login",
"/api/auth/logout",
@@ -955,7 +1114,6 @@ async def session_auth_middleware(request: Request, call_next):
)
if path.startswith("/api/") and path not in unauth_paths:
sid = request.headers.get("x-session-id", "").strip()
session = get_session(sid, ip, touch=True)
if not session:
audit_event(
@@ -1016,7 +1174,7 @@ async def auth_guest(request: Request):
@app.post("/api/auth/login")
async def auth_login(request: Request):
body = await request.json()
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
pin = str(body.get("pin", ""))
ip = get_client_ip(request)
@@ -1081,7 +1239,7 @@ async def auth_logout(request: Request):
role = session.get("role", "none") if session else "none"
if not sid:
try:
body = await request.json()
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
sid = str(body.get("session_id", "")).strip()
except Exception:
try:
@@ -1125,7 +1283,7 @@ async def running_models():
@app.post("/api/show")
async def show_model(request: Request):
body = await request.json()
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
async with httpx.AsyncClient() as client:
try:
resp = await client.post(f"{OLLAMA_BASE}/api/show", json=body, timeout=10)
@@ -1175,9 +1333,14 @@ async def list_memories(topic: Optional[str] = None):
@app.post("/api/memories")
async def create_memory(request: Request):
body = await request.json()
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
fact = str(body.get("fact", "")).strip()
if not fact:
raise HTTPException(status_code=400, detail="Memory fact is required")
if len(fact) > MAX_MEMORY_FACT_CHARS:
raise HTTPException(status_code=413, detail="Memory fact is too long")
rowid = add_memory(
fact=body["fact"],
fact=fact,
topic=body.get("topic", "general"),
source=body.get("source", "manual"),
)
@@ -1193,8 +1356,13 @@ async def remove_memory(rowid: int):
@app.put("/api/memories/{rowid}")
async def edit_memory(rowid: int, request: Request):
body = await request.json()
if not update_memory(rowid, body["fact"]):
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
fact = str(body.get("fact", "")).strip()
if not fact:
raise HTTPException(status_code=400, detail="Memory fact is required")
if len(fact) > MAX_MEMORY_FACT_CHARS:
raise HTTPException(status_code=413, detail="Memory fact is too long")
if not update_memory(rowid, fact):
raise HTTPException(status_code=404, detail="Memory not found")
return {"status": "ok"}
@@ -1233,12 +1401,15 @@ async def get_profile():
@app.put("/api/profile")
async def update_profile(request: Request):
body = await request.json()
body = await read_json_body(request, BODY_LIMIT_PROFILE_BYTES)
content = str(body.get("content", ""))
if len(content) > MAX_PROFILE_CHARS:
raise HTTPException(status_code=413, detail="Profile content is too long")
now = datetime.now(timezone.utc).isoformat()
db = get_db()
db.execute(
"UPDATE profile SET content = ?, updated_at = ? WHERE id = 1",
(body["content"], now),
(content, now),
)
db.commit()
db.close()
@@ -1263,9 +1434,16 @@ async def get_settings():
@app.put("/api/settings")
async def update_settings(request: Request):
body = await request.json()
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
if not isinstance(body, dict):
raise HTTPException(status_code=400, detail="Settings payload must be an object")
if len(body) > MAX_SETTINGS_KEYS:
raise HTTPException(status_code=413, detail="Too many settings in one request")
db = get_db()
for key, value in body.items():
if len(str(key)) > 80 or len(str(value)) > MAX_SETTINGS_VALUE_CHARS:
db.close()
raise HTTPException(status_code=413, detail="Setting key/value too long")
db.execute(
"INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
(key, str(value)),
@@ -1290,26 +1468,38 @@ async def list_presets():
@app.post("/api/presets")
async def create_preset(request: Request):
body = await request.json()
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
name = str(body.get("name", "")).strip()
prompt = str(body.get("prompt", "")).strip()
if not name or not prompt:
raise HTTPException(status_code=400, detail="Preset name and prompt are required")
if len(name) > MAX_PRESET_NAME_CHARS or len(prompt) > MAX_PRESET_PROMPT_CHARS:
raise HTTPException(status_code=413, detail="Preset fields are too long")
preset_id = str(uuid.uuid4())
now = datetime.now(timezone.utc).isoformat()
db = get_db()
db.execute(
"INSERT INTO system_presets (id, name, prompt, is_default, created_at) VALUES (?, ?, ?, 0, ?)",
(preset_id, body["name"], body["prompt"], now),
(preset_id, name, prompt, now),
)
db.commit()
db.close()
return {"id": preset_id, "name": body["name"], "prompt": body["prompt"]}
return {"id": preset_id, "name": name, "prompt": prompt}
@app.put("/api/presets/{preset_id}")
async def update_preset(preset_id: str, request: Request):
body = await request.json()
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
name = str(body.get("name", "")).strip()
prompt = str(body.get("prompt", "")).strip()
if not name or not prompt:
raise HTTPException(status_code=400, detail="Preset name and prompt are required")
if len(name) > MAX_PRESET_NAME_CHARS or len(prompt) > MAX_PRESET_PROMPT_CHARS:
raise HTTPException(status_code=413, detail="Preset fields are too long")
db = get_db()
db.execute(
"UPDATE system_presets SET name = ?, prompt = ? WHERE id = ?",
(body["name"], body["prompt"], preset_id),
(name, prompt, preset_id),
)
db.commit()
db.close()
@@ -1340,11 +1530,11 @@ async def list_conversations():
@app.post("/api/conversations")
async def create_conversation(request: Request):
body = await request.json()
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
conv_id = str(uuid.uuid4())
now = datetime.now(timezone.utc).isoformat()
model = body.get("model", DEFAULT_MODEL)
title = body.get("title", "New Chat")
title = str(body.get("title", "New Chat"))[:MAX_CONVERSATION_TITLE_CHARS]
db = get_db()
db.execute(
"INSERT INTO conversations (id, title, model, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
@@ -1377,13 +1567,13 @@ async def get_conversation(conv_id: str):
@app.put("/api/conversations/{conv_id}")
async def update_conversation(conv_id: str, request: Request):
body = await request.json()
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
db = get_db()
now = datetime.now(timezone.utc).isoformat()
if "title" in body:
db.execute(
"UPDATE conversations SET title = ?, updated_at = ? WHERE id = ?",
(body["title"], now, conv_id),
(str(body["title"])[:MAX_CONVERSATION_TITLE_CHARS], now, conv_id),
)
if "model" in body:
db.execute(
@@ -1424,8 +1614,10 @@ async def delete_all_conversations():
@app.post("/api/search")
async def explicit_search(request: Request):
"""Explicit web search - bypasses model uncertainty, queries SearXNG directly."""
body = await request.json()
body = await read_json_body(request, BODY_LIMIT_CHAT_BYTES)
query = body.get("query", "").strip()
if len(query) > MAX_SEARCH_QUERY_CHARS:
raise HTTPException(status_code=413, detail="Search query is too long")
conv_id = body.get("conversation_id")
model = body.get("model", DEFAULT_MODEL)
@@ -1570,9 +1762,11 @@ def build_system_prompt(db, extra_prompt="", user_message=""):
@app.post("/api/chat")
async def chat(request: Request):
body = await request.json()
body = await read_json_body(request, BODY_LIMIT_CHAT_BYTES)
conv_id = body.get("conversation_id")
user_message = body.get("message", "").strip()
if len(user_message) > MAX_CHAT_MESSAGE_CHARS:
raise HTTPException(status_code=413, detail="Chat message is too long")
model = body.get("model", DEFAULT_MODEL)
preset_prompt = body.get("system_prompt", "")