Files
jarvisChat/security.py
gramps d01dd3b761 refactor(arch): modular package structure — split monolithic app.py into config/db/auth/memory/search/rag/gpu + routers/
- config.py: all constants, env vars, limits, skill registry, profiles
- db.py: schema init, connection factory, skill state helpers
- security.py: PIN hashing, audit logging, rate limiting, CSRF, request helpers
- auth.py: session management, PIN verify, auth routes
- memory.py: FTS5 CRUD + remember/forget command processing
- search.py: SearXNG integration, perplexity scoring, refusal/hedge detection
- gpu.py: rocm-smi stats
- rag.py: Qdrant vector search + system prompt assembly
- routers/: conversations, memories, models, presets, profile, settings, skills, chat, search
- app.py: slim entry point, middleware, router registration only

Bumps to v1.9.0
2026-06-16 08:17:46 -07:00

176 lines
5.8 KiB
Python

"""
JarvisChat - Security utilities.
PIN hashing, audit logging, incident tracking, CSRF/origin checks,
rate limiting, request helpers.
"""
import hashlib
import hmac
import json
import logging
import math
import os
import platform
import time
import uuid
from collections import defaultdict, deque
from datetime import datetime, timezone
from threading import Lock
from typing import Optional
from urllib.parse import urlparse
from fastapi import HTTPException, Request
from config import (
ALLOWED_NETWORKS, TRUST_X_FORWARDED_FOR, TRUSTED_ORIGINS,
BODY_LIMIT_DEFAULT_BYTES, BODY_LIMIT_CHAT_BYTES, BODY_LIMIT_PROFILE_BYTES,
RATE_WINDOW_SECONDS, RL_LOGIN_PER_WINDOW, RL_CHAT_PER_WINDOW,
RL_SEARCH_PER_WINDOW, RL_STATS_PER_WINDOW, RL_WRITE_PER_WINDOW,
RL_DEFAULT_PER_WINDOW, VERSION,
)
import ipaddress
# Shared package logger.
log: logging.Logger = logging.getLogger("jarvischat")

# In-memory, per-process state; nothing below is persisted, so a restart clears it.
SESSIONS: dict = dict()  # presumably session id -> session data; populated elsewhere (verify)
PIN_ATTEMPTS: dict = dict()  # presumably per-client PIN attempt tracking; populated elsewhere (verify)
RATE_EVENTS: dict = defaultdict(deque)  # rate-limit key -> deque of event timestamps (see check_rate_limit)

# RATE_LOCK guards RATE_EVENTS; SESSION_LOCK presumably guards SESSIONS (verify at call sites).
SESSION_LOCK, RATE_LOCK = Lock(), Lock()
def hash_pin(pin: str, salt_hex: Optional[str] = None) -> tuple:
    """Hash *pin* with PBKDF2-HMAC-SHA256 (200,000 iterations).

    Args:
        pin: The PIN to hash; encoded as UTF-8.
        salt_hex: Hex-encoded salt to reuse (e.g. a stored one, so the same
            PIN reproduces the same digest). When omitted or empty, a fresh
            16-byte random salt is generated.

    Returns:
        ``(salt_hex, digest_hex)`` — both hex strings.

    Raises:
        ValueError: if *salt_hex* is not valid hex.
    """
    if salt_hex:
        salt_bytes = bytes.fromhex(salt_hex)
    else:
        salt_bytes = os.urandom(16)
    derived = hashlib.pbkdf2_hmac("sha256", pin.encode("utf-8"), salt_bytes, 200_000)
    return (salt_bytes.hex(), derived.hex())
def audit_event(event: str, outcome: str, *, ip: str = "unknown", role: str = "none",
                details: str = "", warning: bool = False) -> None:
    """Write one structured ``AUDIT`` line to the package logger.

    The record is compact JSON with keys event, outcome, ip, role and details
    (details truncated to 300 characters). It is logged at WARNING level when
    *warning* is true, otherwise at INFO.
    """
    record = {
        "event": event,
        "outcome": outcome,
        "ip": ip,
        "role": role,
        "details": details[:300],
    }
    emit = log.warning if warning else log.info
    emit("AUDIT " + json.dumps(record, separators=(",", ":")))
def create_incident_key() -> str:
    """Return a new incident key: ``INC-<UTC YYYYmmdd-HHMMSS>-<8 upper-case hex chars>``."""
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    suffix = uuid.uuid4().hex[:8].upper()
    return "-".join(("INC", stamp, suffix))
def customer_error_envelope(message: str, incident_key: str) -> dict:
    """Build the client-facing error body for an incident.

    The message and key appear both at the top level (``detail`` /
    ``error_key``) and inside a nested ``error`` object that also carries a
    fixed support hint.
    """
    nested = dict(
        message=message,
        incident_key=incident_key,
        support_hint="Share this incident key for exact diagnostics.",
    )
    return dict(detail=message, error_key=incident_key, error=nested)
def log_incident(event: str, *, message: str, request: Optional[Request] = None,
                 exc: Optional[Exception] = None) -> str:
    """Log a structured ``INCIDENT`` record at ERROR level and return its key.

    Args:
        event: Short event name for the incident.
        message: Human-readable description.
        request: Optional request; its method, path and client IP are
            included in the record when given.
        exc: Optional exception; when given, its traceback is attached to the
            log record.

    Returns:
        The newly generated incident key (see ``create_incident_key``).
    """
    incident_key = create_incident_key()
    has_request = request is not None
    payload = {
        "event": event, "incident_key": incident_key, "message": message,
        "app_version": VERSION, "pid": os.getpid(), "python": platform.python_version(),
        "platform": platform.platform(),
        "method": request.method if has_request else "",
        "path": request.url.path if has_request else "",
        "client_ip": get_client_ip(request) if has_request else "",
    }
    # Attach the *given* exception explicitly. log.exception() only captures
    # the exception currently being handled (sys.exc_info()), so an `exc`
    # passed from outside an `except` block would otherwise lose its
    # traceback (or be replaced by an unrelated in-flight one).
    log.error(
        "INCIDENT " + json.dumps(payload, separators=(",", ":")),
        exc_info=exc if exc is not None else False,
    )
    return incident_key
def get_client_ip(request: Request) -> str:
    """Return the client IP address to attribute *request* to.

    When TRUST_X_FORWARDED_FOR is enabled and the X-Forwarded-For header has
    at least one non-empty entry, the right-most entry is returned. Otherwise
    the transport peer address (``request.client.host``) is used, and
    ``"unknown"`` is returned when neither is available.

    NOTE(review): using the right-most entry assumes exactly one trusted
    reverse proxy in front of the app, which appends the peer it saw to the
    end of the header. A chain of several proxies would need a configurable
    hop count — confirm against the deployment.
    """
    if TRUST_X_FORWARDED_FOR:
        # Entries left of the proxy-appended one are client-controlled: behind
        # an appending proxy, "X-Forwarded-For: 127.0.0.1" arrives as
        # "127.0.0.1, <real ip>", so trusting the first entry would let any
        # client choose its own address (allow-list / rate-limit bypass).
        # Empty entries such as in ", 10.0.0.5" are skipped.
        raw = request.headers.get("x-forwarded-for", "")
        hops = [hop.strip() for hop in raw.split(",")]
        hops = [hop for hop in hops if hop]
        if hops:
            return hops[-1]
    if request.client and request.client.host:
        return request.client.host
    return "unknown"
def is_ip_allowed(ip: str, *, networks=None) -> bool:
    """Return True when *ip* lies in one of the allowed networks.

    Args:
        ip: Address string. Surrounding whitespace and case are ignored;
            ``"localhost"`` and ``"testclient"`` (presumably the Starlette
            TestClient placeholder host) are treated as ``127.0.0.1``.
            Anything that does not parse as an IP address is rejected.
        networks: Optional iterable of ``ipaddress`` network objects to check
            against; defaults to ``ALLOWED_NETWORKS``.

    An IPv4-mapped IPv6 address (``::ffff:a.b.c.d``) is also checked as its
    embedded IPv4 address.
    """
    normalized = ip.strip().lower()
    if normalized in {"localhost", "testclient"}:
        normalized = "127.0.0.1"
    try:
        addr = ipaddress.ip_address(normalized)
    except ValueError:
        return False
    candidates = [addr]
    # Dual-stack listeners report IPv4 peers as IPv4-mapped IPv6 addresses.
    # An IPv6Address is never "in" an IPv4Network, so without this those
    # clients would be rejected even when their IPv4 address is allowed.
    mapped = getattr(addr, "ipv4_mapped", None)
    if mapped is not None:
        candidates.append(mapped)
    allowed = ALLOWED_NETWORKS if networks is None else networks
    return any(candidate in network for network in allowed for candidate in candidates)
def request_body_limit(path: str) -> int:
    """Return the maximum accepted request-body size, in bytes, for *path*.

    ``/api/chat`` and ``/api/search`` share the chat limit, ``/api/profile``
    has its own limit, and every other path gets the default. Matching is by
    exact path.
    """
    per_path = {
        "/api/chat": BODY_LIMIT_CHAT_BYTES,
        "/api/search": BODY_LIMIT_CHAT_BYTES,
        "/api/profile": BODY_LIMIT_PROFILE_BYTES,
    }
    return per_path.get(path, BODY_LIMIT_DEFAULT_BYTES)
def rate_policy(path: str, method: str, ip: str, sid: str) -> tuple:
    """Pick the rate-limit bucket key and per-window limit for a request.

    Returns ``(key, limit)`` where *key* is ``"<category>:<who>"``.
    Login and stats requests are always bucketed by client IP; all other
    requests use the session id when present and fall back to the IP.
    Paths without a dedicated policy are split into "write" (state-changing
    methods) and "api" (everything else).
    """
    identity = sid or ip
    dedicated = {
        "/api/auth/login": ("login", ip, RL_LOGIN_PER_WINDOW),
        "/api/chat": ("chat", identity, RL_CHAT_PER_WINDOW),
        "/api/search": ("search", identity, RL_SEARCH_PER_WINDOW),
        "/api/stats": ("stats", ip, RL_STATS_PER_WINDOW),
    }
    if path in dedicated:
        category, who, limit = dedicated[path]
    elif is_state_changing(method):
        category, who, limit = "write", identity, RL_WRITE_PER_WINDOW
    else:
        category, who, limit = "api", identity, RL_DEFAULT_PER_WINDOW
    return f"{category}:{who}", limit
def check_rate_limit(key: str, limit: int, window_seconds: int) -> tuple:
    """Sliding-window rate check for *key*.

    Events older than *window_seconds* are discarded first. Then:

    - if fewer than *limit* events remain, the current time is recorded and
      ``(True, 0)`` is returned;
    - otherwise nothing is recorded and ``(False, retry_after)`` is returned,
      where *retry_after* is the whole number of seconds (at least 1) until
      the oldest remaining event leaves the window.

    A *limit* of 0 or less denies every call.

    NOTE(review): timestamps come from ``time.time()`` (wall clock), so a
    backwards clock step keeps "future" events in the bucket for the size of
    the step. This function also never removes keys from RATE_EVENTS, so it
    retains one bucket per key ever seen unless cleared elsewhere.
    """
    now_ts = time.time()
    cutoff = now_ts - window_seconds
    with RATE_LOCK:
        bucket = RATE_EVENTS[key]
        while bucket and bucket[0] <= cutoff:
            bucket.popleft()
        if len(bucket) >= limit:
            if bucket:
                wait = window_seconds - (now_ts - bucket[0])
            else:
                # Only reachable when limit <= 0: there is no oldest event to
                # age out, so report a full window instead of indexing an
                # empty deque (which would raise IndexError).
                wait = window_seconds
            return False, max(1, int(math.ceil(wait)))
        bucket.append(now_ts)
        return True, 0
def origin_allowed(request: Request) -> bool:
    """Check that a request comes from this site or a trusted origin.

    The expected origin is ``<request URL scheme>://<Host header>``. If an
    Origin header is present it must equal the expected origin or appear in
    TRUSTED_ORIGINS (trailing "/" ignored). Otherwise the scheme and netloc
    of the Referer header are checked the same way. Requests carrying
    neither header are allowed.

    NOTE(review): behind a TLS-terminating proxy that forwards plain HTTP,
    ``request.url.scheme`` may not match a browser's ``https://`` Origin —
    confirm the deployment or list the public origin in TRUSTED_ORIGINS.
    """
    host = request.headers.get("host", "").strip()
    expected_origin = ""
    if host:
        expected_origin = f"{request.url.scheme}://{host}".rstrip("/")

    def _matches(candidate: str) -> bool:
        return candidate == expected_origin or candidate in TRUSTED_ORIGINS

    origin = request.headers.get("origin", "").strip().rstrip("/")
    if origin:
        return _matches(origin)

    referer = request.headers.get("referer", "").strip()
    if not referer:
        return True
    parts = urlparse(referer)
    return _matches(f"{parts.scheme}://{parts.netloc}".rstrip("/"))
_STATE_CHANGING_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})


def is_state_changing(method: str) -> bool:
    """Return True if *method* is POST, PUT, DELETE or PATCH.

    The comparison is case-sensitive, so callers must pass the upper-case
    method name.
    """
    return method in _STATE_CHANGING_METHODS
async def read_json_body(request: Request, max_bytes: int) -> dict:
    """Read, size-check and decode a request body that must be a JSON object.

    Args:
        request: The incoming request.
        max_bytes: Maximum accepted body size in bytes.

    Returns:
        The decoded object, or ``{}`` for an empty body.

    Raises:
        HTTPException: 413 if the declared or actual body size exceeds
            *max_bytes*; 400 if the body is not valid UTF-8 JSON or its
            top-level value is not an object.
    """
    # Reject an oversized declared length before buffering the whole body.
    # A missing or malformed header simply falls through to the check below.
    declared = request.headers.get("content-length")
    if declared is not None:
        try:
            declared_len = int(declared)
        except ValueError:
            declared_len = None
        if declared_len is not None and declared_len > max_bytes:
            raise HTTPException(status_code=413, detail="Request payload too large")

    raw = await request.body()
    if len(raw) > max_bytes:
        raise HTTPException(status_code=413, detail="Request payload too large")
    if not raw:
        return {}
    try:
        payload = json.loads(raw.decode("utf-8"))
    except (ValueError, RecursionError):
        # UnicodeDecodeError and json.JSONDecodeError both subclass
        # ValueError; RecursionError covers pathologically nested input.
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from None
    # Callers rely on the `-> dict` contract; a list or scalar body would
    # otherwise surface later as an AttributeError (HTTP 500).
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="JSON payload must be an object")
    return payload