feat(security): add LAN IP allowlist and ingress guardrails
This commit is contained in:
240
app.py
240
app.py
@@ -27,6 +27,8 @@ import hmac
|
||||
import time
|
||||
import uuid
|
||||
import re
|
||||
import ipaddress
|
||||
from collections import defaultdict, deque
|
||||
from threading import Lock
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
@@ -72,6 +74,35 @@ TRUSTED_ORIGINS = {
|
||||
if origin.strip()
|
||||
}
|
||||
|
||||
# --- Ingress allowlist configuration ---
# Networks whose clients may call /api/* when JARVISCHAT_ALLOWED_CIDRS is unset.
DEFAULT_ALLOWED_CIDRS = "127.0.0.0/8,::1/128,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
# Comma-separated CIDR list; parsed into network objects by parse_allowed_cidrs().
ALLOWED_CIDRS_RAW = os.getenv("JARVISCHAT_ALLOWED_CIDRS", DEFAULT_ALLOWED_CIDRS)
# Opt-in flag: X-Forwarded-For is ignored unless this is enabled. Accepts the
# usual truthy spellings and tolerates surrounding whitespace so that values
# such as "1" or "True " do not silently leave the flag off.
TRUST_X_FORWARDED_FOR = os.getenv(
    "JARVISCHAT_TRUST_X_FORWARDED_FOR", "false"
).strip().lower() in {"1", "true", "yes", "on"}

# --- Rate / Payload Guardrails (home-lab friendly defaults) ---
# Length of the sliding window, in seconds, that the RL_* quotas apply to.
RATE_WINDOW_SECONDS = 60
# Maximum requests per window for each bucket family chosen by rate_policy().
RL_LOGIN_PER_WINDOW = 10
RL_CHAT_PER_WINDOW = 24
RL_SEARCH_PER_WINDOW = 16
RL_WRITE_PER_WINDOW = 30
RL_DEFAULT_PER_WINDOW = 240
RL_STATS_PER_WINDOW = 600

# Request-body caps in bytes, selected per path by request_body_limit().
BODY_LIMIT_DEFAULT_BYTES = 64 * 1024
BODY_LIMIT_CHAT_BYTES = 128 * 1024
BODY_LIMIT_PROFILE_BYTES = 256 * 1024

# Per-field caps (character counts, or key count for MAX_SETTINGS_KEYS)
# enforced by the individual API handlers.
MAX_CHAT_MESSAGE_CHARS = 8000
MAX_SEARCH_QUERY_CHARS = 500
MAX_PROFILE_CHARS = 32000
MAX_MEMORY_FACT_CHARS = 2000
MAX_PRESET_NAME_CHARS = 120
MAX_PRESET_PROMPT_CHARS = 12000
MAX_SETTINGS_KEYS = 16
MAX_SETTINGS_VALUE_CHARS = 8000
MAX_CONVERSATION_TITLE_CHARS = 200
|
||||
|
||||
# --- Templates and Static Files ---
|
||||
templates = Jinja2Templates(directory=str(BASE_DIR / "templates"))
|
||||
|
||||
@@ -109,7 +140,9 @@ HEDGE_PATTERNS = [
|
||||
|
||||
SESSIONS: dict[str, dict] = {}
|
||||
PIN_ATTEMPTS: dict[str, dict] = {}
|
||||
RATE_EVENTS: dict[str, deque[float]] = defaultdict(deque)
|
||||
SESSION_LOCK = Lock()
|
||||
RATE_LOCK = Lock()
|
||||
|
||||
|
||||
def hash_pin(pin: str, salt_hex: Optional[str] = None) -> tuple[str, str]:
|
||||
@@ -143,6 +176,23 @@ def audit_event(
|
||||
log.info(msg)
|
||||
|
||||
|
||||
def parse_allowed_cidrs(raw: str) -> list[ipaddress._BaseNetwork]:
    """Parse a comma-separated CIDR list into validated network objects.

    Blank entries are skipped. Entries that ``ipaddress.ip_network`` rejects
    are logged and dropped rather than aborting startup. ``strict=False`` is
    used so a host address with a prefix (e.g. ``192.168.1.7/16``) is accepted
    and normalised to its containing network.

    Args:
        raw: Comma-separated CIDR strings; ``None`` or ``""`` yields ``[]``.

    Returns:
        The successfully parsed networks, in input order.
    """
    networks: list[ipaddress._BaseNetwork] = []
    for entry in (raw or "").split(","):
        value = entry.strip()
        if not value:
            continue
        try:
            networks.append(ipaddress.ip_network(value, strict=False))
        except ValueError:
            log.warning("Invalid CIDR ignored: %s", value)
    if not networks:
        # An empty allowlist rejects every client; make that visible in the
        # logs instead of leaving the operator to guess why all calls get 403.
        log.warning("No valid CIDRs configured; every client IP will be denied")
    return networks
|
||||
|
||||
|
||||
ALLOWED_NETWORKS = parse_allowed_cidrs(ALLOWED_CIDRS_RAW)
|
||||
|
||||
|
||||
def clean_hedging(text: str) -> str:
|
||||
"""Remove hedging sentences from model response."""
|
||||
cleaned = text
|
||||
@@ -175,6 +225,55 @@ def sanitize_outbound_url(url: str) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
async def read_json_body(request: Request, max_bytes: int) -> dict:
    """Read the request body and decode it as a JSON object, enforcing a size cap.

    The body is buffered in full by ``request.body()`` before its length is
    checked, so this cap bounds what gets parsed, not what gets received.

    Args:
        request: Incoming request whose body should be read.
        max_bytes: Largest accepted body size in bytes.

    Returns:
        The decoded JSON object, or ``{}`` when the body is empty.

    Raises:
        HTTPException: 413 if the body exceeds ``max_bytes``; 400 if it is not
            valid UTF-8 JSON or its top-level value is not an object.
    """
    raw = await request.body()
    if len(raw) > max_bytes:
        raise HTTPException(status_code=413, detail="Request payload too large")
    if not raw:
        return {}
    try:
        payload = json.loads(raw.decode("utf-8"))
    except (ValueError, RecursionError) as exc:
        # ValueError covers both UnicodeDecodeError and JSONDecodeError;
        # RecursionError is what the decoder raises on pathologically nested input.
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc
    if not isinstance(payload, dict):
        # Callers use body.get(...); a list or scalar would otherwise surface
        # as an AttributeError and a 500 instead of a client error.
        raise HTTPException(status_code=400, detail="JSON payload must be an object")
    return payload
|
||||
|
||||
|
||||
def request_body_limit(path: str) -> int:
    """Return the request-body byte cap that applies to ``path``.

    The profile endpoint gets the largest allowance, chat and search share
    the chat allowance, and every other path uses the default.
    """
    if path == "/api/profile":
        return BODY_LIMIT_PROFILE_BYTES
    chat_sized = path == "/api/chat" or path == "/api/search"
    return BODY_LIMIT_CHAT_BYTES if chat_sized else BODY_LIMIT_DEFAULT_BYTES
|
||||
|
||||
|
||||
def rate_policy(path: str, method: str, ip: str, sid: str) -> tuple[str, int]:
    """Choose the rate-limit bucket key and per-window quota for a request.

    Args:
        path: Request URL path.
        method: HTTP method, upper-case.
        ip: Client IP as resolved by the caller.
        sid: Session id supplied by the client. Accepted for interface
            compatibility but deliberately not used in the bucket key.

    Returns:
        ``(bucket_key, max_requests_per_window)``.
    """
    # The middleware passes the raw x-session-id header here before the
    # session has been validated. Keying a bucket on it would let a client
    # obtain a fresh quota on every request just by sending a new header
    # value, so all buckets are keyed on the client IP instead.
    identity = ip
    if path == "/api/auth/login":
        return f"login:{ip}", RL_LOGIN_PER_WINDOW
    if path == "/api/chat":
        return f"chat:{identity}", RL_CHAT_PER_WINDOW
    if path == "/api/search":
        return f"search:{identity}", RL_SEARCH_PER_WINDOW
    if path == "/api/stats":
        return f"stats:{ip}", RL_STATS_PER_WINDOW
    if method in {"POST", "PUT", "DELETE", "PATCH"}:
        return f"write:{identity}", RL_WRITE_PER_WINDOW
    return f"api:{identity}", RL_DEFAULT_PER_WINDOW
|
||||
|
||||
|
||||
def check_rate_limit(key: str, limit: int, window_seconds: int) -> tuple[bool, int]:
    """Record a hit for ``key`` unless its sliding-window quota is exhausted.

    Args:
        key: Bucket identifier.
        limit: Maximum hits allowed inside the window; a value <= 0 denies
            every request.
        window_seconds: Length of the sliding window in seconds.

    Returns:
        ``(True, 0)`` when the hit was recorded, otherwise
        ``(False, retry_after_seconds)`` with ``retry_after_seconds >= 1``.
    """
    now_ts = time.time()
    cutoff = now_ts - window_seconds
    # Above this many tracked keys, buckets with no hit inside the window are
    # discarded so RATE_EVENTS cannot grow without bound.
    max_tracked_keys = 10_000
    with RATE_LOCK:
        if len(RATE_EVENTS) > max_tracked_keys:
            # events[-1] is the newest hit (hits are appended on the right),
            # so a bucket is idle once even that one has left the window.
            # NOTE(review): uses this call's window for every bucket; fine
            # while all callers share one window length — confirm if that changes.
            idle = [
                other
                for other, events in RATE_EVENTS.items()
                if other != key and (not events or events[-1] <= cutoff)
            ]
            for other in idle:
                del RATE_EVENTS[other]

        bucket = RATE_EVENTS[key]
        while bucket and bucket[0] <= cutoff:
            bucket.popleft()
        if len(bucket) >= limit:
            # With limit <= 0 the bucket can be empty here; fall back to
            # "now" so the caller is told to wait one full window.
            oldest = bucket[0] if bucket else now_ts
            retry_after = max(1, math.ceil(window_seconds - (now_ts - oldest)))
            return False, retry_after
        bucket.append(now_ts)
        return True, 0
|
||||
|
||||
|
||||
# --- Default Profile ---
|
||||
DEFAULT_PROFILE = """You are a coding companion running locally on a machine called "jarvis".
|
||||
|
||||
@@ -760,13 +859,29 @@ if static_dir.exists():
|
||||
|
||||
def get_client_ip(request: Request) -> str:
    """Resolve the client address used for allowlisting, rate limits and audit logs.

    X-Forwarded-For is consulted only when ``TRUST_X_FORWARDED_FOR`` is
    enabled. In that case the *last* non-empty entry is used: a proxy that
    appends to the header puts the address it actually saw at the end, while
    earlier entries are whatever the client chose to send and could be forged
    to impersonate an allowlisted address.

    NOTE(review): this assumes a single trusted proxy directly in front of
    the app; with a longer proxy chain the last entry is the previous proxy.

    Returns:
        The resolved address, the socket peer host as a fallback, or
        ``"unknown"`` when neither is available.
    """
    if TRUST_X_FORWARDED_FOR:
        forwarded = request.headers.get("x-forwarded-for", "")
        hops = [hop.strip() for hop in forwarded.split(",") if hop.strip()]
        if hops:
            return hops[-1]
    if request.client and request.client.host:
        return request.client.host
    return "unknown"
|
||||
|
||||
|
||||
def is_ip_allowed(ip: str) -> bool:
    """Return True when ``ip`` falls inside one of ``ALLOWED_NETWORKS``.

    Args:
        ip: Client address string; ``"localhost"`` and ``"testclient"`` are
            treated as ``127.0.0.1``.

    Returns:
        False for anything that is not a parseable IP address or that matches
        no configured network.
    """
    normalized = ip.strip().lower()
    # NOTE(review): "testclient" is presumably the host reported by the test
    # harness — confirm; it is treated as loopback in production too.
    if normalized in {"localhost", "testclient"}:
        normalized = "127.0.0.1"
    try:
        ip_obj = ipaddress.ip_address(normalized)
    except ValueError:
        return False

    # An IPv4 peer can be reported as an IPv4-mapped IPv6 address
    # (::ffff:a.b.c.d). Compare it as IPv4; a v6 address never matches a v4
    # network, so the private-range entries would otherwise reject it.
    mapped = getattr(ip_obj, "ipv4_mapped", None)
    if mapped is not None:
        ip_obj = mapped

    return any(ip_obj in network for network in ALLOWED_NETWORKS)
|
||||
|
||||
|
||||
def cleanup_sessions(now_ts: Optional[float] = None) -> None:
|
||||
now_ts = now_ts or time.time()
|
||||
with SESSION_LOCK:
|
||||
@@ -927,9 +1042,53 @@ def origin_allowed(request: Request) -> bool:
|
||||
async def session_auth_middleware(request: Request, call_next):
|
||||
path = request.url.path
|
||||
ip = get_client_ip(request)
|
||||
sid = request.headers.get("x-session-id", "").strip()
|
||||
request.state.session_role = "none"
|
||||
request.state.client_ip = ip
|
||||
|
||||
if path.startswith("/api/"):
|
||||
if not is_ip_allowed(ip):
|
||||
audit_event(
|
||||
"ip_allowlist",
|
||||
"denied",
|
||||
ip=ip,
|
||||
role="none",
|
||||
details=f"{request.method} {path}",
|
||||
warning=True,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Client IP not allowed"},
|
||||
)
|
||||
|
||||
if path.startswith("/api/"):
|
||||
rate_key, rate_limit = rate_policy(path, request.method, ip, sid)
|
||||
allowed, retry_after = check_rate_limit(
|
||||
rate_key, rate_limit, RATE_WINDOW_SECONDS
|
||||
)
|
||||
if not allowed:
|
||||
audit_event(
|
||||
"rate_limit",
|
||||
"denied",
|
||||
ip=ip,
|
||||
role="none",
|
||||
details=f"{request.method} {path} retry_after={retry_after}",
|
||||
warning=True,
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": f"Rate limit exceeded. Retry in {retry_after}s."},
|
||||
)
|
||||
|
||||
if request.method in {"POST", "PUT", "PATCH"}:
|
||||
max_bytes = request_body_limit(path)
|
||||
content_length = request.headers.get("content-length", "").strip()
|
||||
if content_length.isdigit() and int(content_length) > max_bytes:
|
||||
return JSONResponse(
|
||||
status_code=413,
|
||||
content={"detail": "Request payload too large"},
|
||||
)
|
||||
|
||||
unauth_paths = {
|
||||
"/api/auth/login",
|
||||
"/api/auth/logout",
|
||||
@@ -955,7 +1114,6 @@ async def session_auth_middleware(request: Request, call_next):
|
||||
)
|
||||
|
||||
if path.startswith("/api/") and path not in unauth_paths:
|
||||
sid = request.headers.get("x-session-id", "").strip()
|
||||
session = get_session(sid, ip, touch=True)
|
||||
if not session:
|
||||
audit_event(
|
||||
@@ -1016,7 +1174,7 @@ async def auth_guest(request: Request):
|
||||
|
||||
@app.post("/api/auth/login")
|
||||
async def auth_login(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
pin = str(body.get("pin", ""))
|
||||
ip = get_client_ip(request)
|
||||
|
||||
@@ -1081,7 +1239,7 @@ async def auth_logout(request: Request):
|
||||
role = session.get("role", "none") if session else "none"
|
||||
if not sid:
|
||||
try:
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
sid = str(body.get("session_id", "")).strip()
|
||||
except Exception:
|
||||
try:
|
||||
@@ -1125,7 +1283,7 @@ async def running_models():
|
||||
|
||||
@app.post("/api/show")
|
||||
async def show_model(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
resp = await client.post(f"{OLLAMA_BASE}/api/show", json=body, timeout=10)
|
||||
@@ -1175,9 +1333,14 @@ async def list_memories(topic: Optional[str] = None):
|
||||
|
||||
@app.post("/api/memories")
|
||||
async def create_memory(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
fact = str(body.get("fact", "")).strip()
|
||||
if not fact:
|
||||
raise HTTPException(status_code=400, detail="Memory fact is required")
|
||||
if len(fact) > MAX_MEMORY_FACT_CHARS:
|
||||
raise HTTPException(status_code=413, detail="Memory fact is too long")
|
||||
rowid = add_memory(
|
||||
fact=body["fact"],
|
||||
fact=fact,
|
||||
topic=body.get("topic", "general"),
|
||||
source=body.get("source", "manual"),
|
||||
)
|
||||
@@ -1193,8 +1356,13 @@ async def remove_memory(rowid: int):
|
||||
|
||||
@app.put("/api/memories/{rowid}")
async def edit_memory(rowid: int, request: Request):
    """Replace the text of the memory identified by ``rowid``.

    Expects a JSON object with a non-empty ``fact`` field.

    Raises:
        HTTPException: 400 if ``fact`` is missing, null or blank (or the body
            is not an object); 413 if it exceeds ``MAX_MEMORY_FACT_CHARS``;
            404 if ``update_memory`` reports no matching row.
    """
    body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Memory fact is required")
    raw_fact = body.get("fact")
    # A JSON null must count as "missing"; str(None) would store "None".
    fact = "" if raw_fact is None else str(raw_fact).strip()
    if not fact:
        raise HTTPException(status_code=400, detail="Memory fact is required")
    if len(fact) > MAX_MEMORY_FACT_CHARS:
        raise HTTPException(status_code=413, detail="Memory fact is too long")
    if not update_memory(rowid, fact):
        raise HTTPException(status_code=404, detail="Memory not found")
    return {"status": "ok"}
|
||||
|
||||
@@ -1233,12 +1401,15 @@ async def get_profile():
|
||||
|
||||
@app.put("/api/profile")
|
||||
async def update_profile(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_PROFILE_BYTES)
|
||||
content = str(body.get("content", ""))
|
||||
if len(content) > MAX_PROFILE_CHARS:
|
||||
raise HTTPException(status_code=413, detail="Profile content is too long")
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
db = get_db()
|
||||
db.execute(
|
||||
"UPDATE profile SET content = ?, updated_at = ? WHERE id = 1",
|
||||
(body["content"], now),
|
||||
(content, now),
|
||||
)
|
||||
db.commit()
|
||||
db.close()
|
||||
@@ -1263,9 +1434,16 @@ async def get_settings():
|
||||
|
||||
@app.put("/api/settings")
|
||||
async def update_settings(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
if not isinstance(body, dict):
|
||||
raise HTTPException(status_code=400, detail="Settings payload must be an object")
|
||||
if len(body) > MAX_SETTINGS_KEYS:
|
||||
raise HTTPException(status_code=413, detail="Too many settings in one request")
|
||||
db = get_db()
|
||||
for key, value in body.items():
|
||||
if len(str(key)) > 80 or len(str(value)) > MAX_SETTINGS_VALUE_CHARS:
|
||||
db.close()
|
||||
raise HTTPException(status_code=413, detail="Setting key/value too long")
|
||||
db.execute(
|
||||
"INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
|
||||
(key, str(value)),
|
||||
@@ -1290,26 +1468,38 @@ async def list_presets():
|
||||
|
||||
@app.post("/api/presets")
async def create_preset(request: Request):
    """Create a non-default system preset from a JSON ``name`` and ``prompt``.

    Returns:
        ``{"id", "name", "prompt"}`` for the stored preset, with the trimmed
        values that were actually written.

    Raises:
        HTTPException: 400 if either field is missing, null or blank (or the
            body is not an object); 413 if either exceeds its length cap.
    """
    body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Preset name and prompt are required")
    raw_name = body.get("name")
    raw_prompt = body.get("prompt")
    # A JSON null must count as "missing"; str(None) would store "None".
    name = "" if raw_name is None else str(raw_name).strip()
    prompt = "" if raw_prompt is None else str(raw_prompt).strip()
    if not name or not prompt:
        raise HTTPException(status_code=400, detail="Preset name and prompt are required")
    if len(name) > MAX_PRESET_NAME_CHARS or len(prompt) > MAX_PRESET_PROMPT_CHARS:
        raise HTTPException(status_code=413, detail="Preset fields are too long")

    preset_id = str(uuid.uuid4())
    now = datetime.now(timezone.utc).isoformat()
    db = get_db()
    try:
        db.execute(
            "INSERT INTO system_presets (id, name, prompt, is_default, created_at) VALUES (?, ?, ?, 0, ?)",
            (preset_id, name, prompt, now),
        )
        db.commit()
    finally:
        # Close the handle even when the insert or commit raises.
        db.close()
    return {"id": preset_id, "name": name, "prompt": prompt}
|
||||
|
||||
|
||||
@app.put("/api/presets/{preset_id}")
|
||||
async def update_preset(preset_id: str, request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
name = str(body.get("name", "")).strip()
|
||||
prompt = str(body.get("prompt", "")).strip()
|
||||
if not name or not prompt:
|
||||
raise HTTPException(status_code=400, detail="Preset name and prompt are required")
|
||||
if len(name) > MAX_PRESET_NAME_CHARS or len(prompt) > MAX_PRESET_PROMPT_CHARS:
|
||||
raise HTTPException(status_code=413, detail="Preset fields are too long")
|
||||
db = get_db()
|
||||
db.execute(
|
||||
"UPDATE system_presets SET name = ?, prompt = ? WHERE id = ?",
|
||||
(body["name"], body["prompt"], preset_id),
|
||||
(name, prompt, preset_id),
|
||||
)
|
||||
db.commit()
|
||||
db.close()
|
||||
@@ -1340,11 +1530,11 @@ async def list_conversations():
|
||||
|
||||
@app.post("/api/conversations")
|
||||
async def create_conversation(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
conv_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
model = body.get("model", DEFAULT_MODEL)
|
||||
title = body.get("title", "New Chat")
|
||||
title = str(body.get("title", "New Chat"))[:MAX_CONVERSATION_TITLE_CHARS]
|
||||
db = get_db()
|
||||
db.execute(
|
||||
"INSERT INTO conversations (id, title, model, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
|
||||
@@ -1377,13 +1567,13 @@ async def get_conversation(conv_id: str):
|
||||
|
||||
@app.put("/api/conversations/{conv_id}")
|
||||
async def update_conversation(conv_id: str, request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||
db = get_db()
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
if "title" in body:
|
||||
db.execute(
|
||||
"UPDATE conversations SET title = ?, updated_at = ? WHERE id = ?",
|
||||
(body["title"], now, conv_id),
|
||||
(str(body["title"])[:MAX_CONVERSATION_TITLE_CHARS], now, conv_id),
|
||||
)
|
||||
if "model" in body:
|
||||
db.execute(
|
||||
@@ -1424,8 +1614,10 @@ async def delete_all_conversations():
|
||||
@app.post("/api/search")
|
||||
async def explicit_search(request: Request):
|
||||
"""Explicit web search - bypasses model uncertainty, queries SearXNG directly."""
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_CHAT_BYTES)
|
||||
query = body.get("query", "").strip()
|
||||
if len(query) > MAX_SEARCH_QUERY_CHARS:
|
||||
raise HTTPException(status_code=413, detail="Search query is too long")
|
||||
conv_id = body.get("conversation_id")
|
||||
model = body.get("model", DEFAULT_MODEL)
|
||||
|
||||
@@ -1570,9 +1762,11 @@ def build_system_prompt(db, extra_prompt="", user_message=""):
|
||||
|
||||
@app.post("/api/chat")
|
||||
async def chat(request: Request):
|
||||
body = await request.json()
|
||||
body = await read_json_body(request, BODY_LIMIT_CHAT_BYTES)
|
||||
conv_id = body.get("conversation_id")
|
||||
user_message = body.get("message", "").strip()
|
||||
if len(user_message) > MAX_CHAT_MESSAGE_CHARS:
|
||||
raise HTTPException(status_code=413, detail="Chat message is too long")
|
||||
model = body.get("model", DEFAULT_MODEL)
|
||||
preset_prompt = body.get("system_prompt", "")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user