- config.py: all constants, env vars, limits, skill registry, profiles - db.py: schema init, connection factory, skill state helpers - security.py: PIN hashing, audit logging, rate limiting, CSRF, request helpers - auth.py: session management, PIN verify, auth routes - memory.py: FTS5 CRUD + remember/forget command processing - search.py: SearXNG integration, perplexity scoring, refusal/hedge detection - gpu.py: rocm-smi stats - rag.py: Qdrant vector search + system prompt assembly - routers/: conversations, memories, models, presets, profile, settings, skills, chat, search - app.py: slim entry point, middleware, router registration only Bumps to v1.9.0
176 lines
5.8 KiB
Python
176 lines
5.8 KiB
Python
"""
JarvisChat - Security utilities.

PIN hashing, audit logging, incident tracking, CSRF/origin checks,
rate limiting, request helpers.

The module also holds process-local, in-memory state used by these helpers:
the SESSIONS and PIN_ATTEMPTS dicts, the per-key RATE_EVENTS timestamp
buckets, and the SESSION_LOCK / RATE_LOCK locks.
"""
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import logging
|
|
import math
|
|
import os
|
|
import platform
|
|
import time
|
|
import uuid
|
|
from collections import defaultdict, deque
|
|
from datetime import datetime, timezone
|
|
from threading import Lock
|
|
from typing import Optional
|
|
from urllib.parse import urlparse
|
|
|
|
from fastapi import HTTPException, Request
|
|
|
|
from config import (
|
|
ALLOWED_NETWORKS, TRUST_X_FORWARDED_FOR, TRUSTED_ORIGINS,
|
|
BODY_LIMIT_DEFAULT_BYTES, BODY_LIMIT_CHAT_BYTES, BODY_LIMIT_PROFILE_BYTES,
|
|
RATE_WINDOW_SECONDS, RL_LOGIN_PER_WINDOW, RL_CHAT_PER_WINDOW,
|
|
RL_SEARCH_PER_WINDOW, RL_STATS_PER_WINDOW, RL_WRITE_PER_WINDOW,
|
|
RL_DEFAULT_PER_WINDOW, VERSION,
|
|
)
|
|
|
|
import ipaddress
|
|
|
|
# Logger shared by every helper in this module.
log = logging.getLogger("jarvischat")

# Process-local mutable state; it lives only in memory and starts empty on
# every process start.
# NOTE(review): presumably session-id -> session record, managed by auth.py — confirm.
SESSIONS: dict = dict()
# NOTE(review): presumably per-client PIN failure tracking, managed by auth.py — confirm.
PIN_ATTEMPTS: dict = dict()
# Rate-limit key -> deque of hit timestamps (appended by check_rate_limit).
RATE_EVENTS: dict = defaultdict(deque)

# RATE_LOCK serializes access to RATE_EVENTS in check_rate_limit.
# NOTE(review): SESSION_LOCK presumably guards SESSIONS/PIN_ATTEMPTS — confirm at call sites.
SESSION_LOCK, RATE_LOCK = Lock(), Lock()
|
|
|
|
|
|
def hash_pin(pin: str, salt_hex: Optional[str] = None) -> tuple:
    """Derive a PBKDF2-HMAC-SHA256 hash of *pin*.

    Args:
        pin: The PIN to hash; it is UTF-8 encoded before derivation.
        salt_hex: Hex-encoded salt to reuse, so the same PIN reproduces the
            same digest. When omitted or empty, a fresh random 16-byte salt
            is generated.

    Returns:
        A ``(salt_hex, digest_hex)`` pair of hex strings. The digest is the
        32-byte SHA-256 PBKDF2 output after 200,000 iterations.

    Raises:
        ValueError: if *salt_hex* is not valid hexadecimal.
    """
    if salt_hex:
        salt_bytes = bytes.fromhex(salt_hex)
    else:
        salt_bytes = os.urandom(16)
    pin_bytes = pin.encode("utf-8")
    derived = hashlib.pbkdf2_hmac("sha256", pin_bytes, salt_bytes, 200_000)
    return (salt_bytes.hex(), derived.hex())
|
|
|
|
|
|
def audit_event(event: str, outcome: str, *, ip: str = "unknown", role: str = "none",
                details: str = "", warning: bool = False) -> None:
    """Write one structured audit record to the module logger.

    The record is a compact JSON object with the keys event, outcome, ip,
    role and details, prefixed with ``AUDIT ``. *details* is cut to its first
    300 characters. Because the record is JSON-encoded, control characters
    such as newlines in the inputs are escaped and cannot break the log line.

    Args:
        event: Name of the audited action.
        outcome: Result of the action.
        ip: Client address to record.
        role: Role of the acting principal.
        details: Free-form extra context (truncated to 300 characters).
        warning: Log at WARNING level when true, otherwise at INFO.
    """
    record = {
        "event": event,
        "outcome": outcome,
        "ip": ip,
        "role": role,
        "details": details[:300],
    }
    line = "AUDIT " + json.dumps(record, separators=(",", ":"))
    emit = log.warning if warning else log.info
    emit(line)
|
|
|
|
|
|
def create_incident_key() -> str:
    """Return a new incident key of the form ``INC-YYYYMMDD-HHMMSS-XXXXXXXX``.

    The timestamp is the current UTC time. The suffix is the first eight hex
    digits of a random UUID4, upper-cased, so keys minted within the same
    second still differ with high probability.
    """
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    suffix = uuid.uuid4().hex[:8].upper()
    return f"INC-{stamp}-{suffix}"
|
|
|
|
|
|
def customer_error_envelope(message: str, incident_key: str) -> dict:
    """Build the client-facing error body for an incident.

    The message and key appear both at the top level (``detail`` /
    ``error_key``) and inside a nested ``error`` object, which also carries a
    fixed support hint.

    Args:
        message: Human-readable error text shown to the client.
        incident_key: Incident identifier the client can quote to support.

    Returns:
        A new dict with the keys ``detail``, ``error_key`` and ``error``.
    """
    nested = {
        "message": message,
        "incident_key": incident_key,
        "support_hint": "Share this incident key for exact diagnostics.",
    }
    envelope = {"detail": message, "error_key": incident_key}
    envelope["error"] = nested
    return envelope
|
|
|
|
|
|
def log_incident(event: str, *, message: str, request: Optional[Request] = None,
                 exc: Optional[Exception] = None) -> str:
    """Log an ERROR-level incident record and return its newly minted key.

    The record is a compact JSON object prefixed with ``INCIDENT `` containing
    the event, key, message, app version, process id, Python version and
    platform, plus the request method, path and client IP when *request* is
    given (empty strings otherwise).

    Args:
        event: Short name of the failing operation.
        message: Human-readable description of the incident.
        request: Optional request whose method, path and client IP are recorded.
        exc: Optional exception whose traceback is attached to the log record.

    Returns:
        The incident key, which also appears in the logged record.
    """
    incident_key = create_incident_key()
    has_request = request is not None
    payload = {
        "event": event, "incident_key": incident_key, "message": message,
        "app_version": VERSION, "pid": os.getpid(), "python": platform.python_version(),
        "platform": platform.platform(),
        "method": request.method if has_request else "",
        "path": request.url.path if has_request else "",
        "client_ip": get_client_ip(request) if has_request else "",
    }
    line = "INCIDENT " + json.dumps(payload, separators=(",", ":"))
    # Pass the exception object explicitly. log.exception() only attaches
    # sys.exc_info(), which is empty (or a different exception) when the caller
    # hands over *exc* outside of its ``except`` block. exc_info=None logs
    # without a traceback, matching the no-exception path.
    log.error(line, exc_info=exc)
    return incident_key
|
|
|
|
|
|
def get_client_ip(request: Request) -> str:
    """Return the best-known client IP address for *request*.

    When TRUST_X_FORWARDED_FOR is enabled and the request carries a non-empty
    ``X-Forwarded-For`` header, the right-most non-empty entry is returned.
    That entry is the address appended by the nearest proxy. Left-most entries
    are whatever the client chose to send, and could be forged to slip past the
    IP allow-list or per-IP rate limits. Otherwise the socket peer address is
    used.

    Returns:
        The address string, or ``"unknown"`` when no address is available.
    """
    if TRUST_X_FORWARDED_FOR:
        header = request.headers.get("x-forwarded-for", "")
        hops = [hop.strip() for hop in header.split(",") if hop.strip()]
        if hops:
            # NOTE: assumes exactly one trusted reverse proxy in front of the app.
            # With more hops this yields an inner proxy's address, which fails
            # closed against the allow-list rather than trusting client input.
            return hops[-1]
    client = request.client
    if client and client.host:
        return client.host
    return "unknown"
|
|
|
|
|
|
def is_ip_allowed(ip: str, networks=None) -> bool:
    """Return True if *ip* lies inside one of the allowed networks.

    Args:
        ip: Client address as a string. Surrounding whitespace and case are
            ignored; ``localhost`` and ``testclient`` are treated as 127.0.0.1.
        networks: Optional iterable of network objects supporting
            ``address in network`` (e.g. ``ipaddress.ip_network``). Defaults to
            ALLOWED_NETWORKS from config.

    Returns:
        False for anything that is not a parseable IP address, otherwise
        whether the address (or, for IPv4-mapped IPv6 addresses, the embedded
        IPv4 address) is contained in any network.
    """
    normalized = ip.strip().lower()
    if normalized in {"localhost", "testclient"}:
        normalized = "127.0.0.1"
    try:
        addr = ipaddress.ip_address(normalized)
    except ValueError:
        return False
    # Dual-stack listeners report IPv4 peers as IPv4-mapped IPv6
    # (::ffff:a.b.c.d). Also test the embedded IPv4 address so IPv4 allow-list
    # entries still match such peers.
    candidates = [addr]
    mapped = getattr(addr, "ipv4_mapped", None)
    if mapped is not None:
        candidates.append(mapped)
    allowed = ALLOWED_NETWORKS if networks is None else networks
    return any(candidate in net for net in allowed for candidate in candidates)
|
|
|
|
|
|
def request_body_limit(path: str) -> int:
    """Return the maximum accepted request-body size, in bytes, for *path*.

    ``/api/chat`` and ``/api/search`` share the chat limit, ``/api/profile``
    has its own limit, and every other path gets the default. Matching is by
    exact path string.
    """
    limits_by_path = {
        "/api/chat": BODY_LIMIT_CHAT_BYTES,
        "/api/search": BODY_LIMIT_CHAT_BYTES,
        "/api/profile": BODY_LIMIT_PROFILE_BYTES,
    }
    return limits_by_path.get(path, BODY_LIMIT_DEFAULT_BYTES)
|
|
|
|
|
|
def rate_policy(path: str, method: str, ip: str, sid: str) -> tuple:
    """Map a request to its rate-limit bucket key and per-window limit.

    Args:
        path: Exact request path.
        method: HTTP method (compared case-sensitively against upper-case names).
        ip: Client IP address.
        sid: Session id; an empty/falsy value means "no session".

    Returns:
        A ``(bucket_key, limit)`` tuple. Login and stats are keyed by IP only.
        Chat, search, write methods and the catch-all are keyed by session id
        when present, otherwise by IP.
    """
    # Endpoints always throttled per client IP.
    if path == "/api/auth/login":
        return f"login:{ip}", RL_LOGIN_PER_WINDOW
    if path == "/api/stats":
        return f"stats:{ip}", RL_STATS_PER_WINDOW

    # Everything else prefers the session id and falls back to the IP.
    who = sid if sid else ip
    if path == "/api/chat":
        return f"chat:{who}", RL_CHAT_PER_WINDOW
    if path == "/api/search":
        return f"search:{who}", RL_SEARCH_PER_WINDOW
    if method in ("POST", "PUT", "DELETE", "PATCH"):
        return f"write:{who}", RL_WRITE_PER_WINDOW
    return f"api:{who}", RL_DEFAULT_PER_WINDOW
|
|
|
|
|
|
def check_rate_limit(key: str, limit: int, window_seconds: int) -> tuple:
    """Apply a sliding-window limit to *key* and record the hit if allowed.

    Timestamps older than *window_seconds* are dropped from the key's bucket
    in RATE_EVENTS before counting. All bucket access happens under RATE_LOCK.

    Args:
        key: Bucket identifier (e.g. as produced by rate_policy).
        limit: Maximum hits allowed within the window. A value <= 0 rejects
            every hit.
        window_seconds: Length of the sliding window in seconds.

    Returns:
        ``(True, 0)`` when the hit is accepted and recorded, or
        ``(False, retry_after)``, where *retry_after* is a whole number of
        seconds (at least 1) until a slot frees up. Rejected hits are not
        recorded.
    """
    # NOTE(review): wall-clock time.time() is kept for compatibility; a clock
    # jump shifts every window. Switch to time.monotonic() only if no other
    # code reads these timestamps.
    now_ts = time.time()
    cutoff = now_ts - window_seconds
    with RATE_LOCK:
        bucket = RATE_EVENTS[key]
        while bucket and bucket[0] <= cutoff:
            bucket.popleft()
        if len(bucket) >= limit:
            if bucket:
                # The oldest recorded hit leaves the window first.
                wait = window_seconds - (now_ts - bucket[0])
            else:
                # Only reachable when limit <= 0: there is no hit to age out,
                # so ask the client to wait a full window instead of raising
                # IndexError on bucket[0].
                wait = window_seconds
            return False, max(1, math.ceil(wait))
        bucket.append(now_ts)
        return True, 0
|
|
|
|
|
|
def origin_allowed(request: Request) -> bool:
    """CSRF origin check: decide whether the request's declared origin is acceptable.

    The server's own origin is built from ``request.url.scheme`` and the
    ``Host`` header. If an ``Origin`` header is present it must equal that
    origin or be listed in TRUSTED_ORIGINS. Otherwise the scheme://netloc of
    the ``Referer`` header is checked the same way. Requests carrying neither
    header are allowed. Trailing slashes are ignored.

    NOTE(review): behind a TLS-terminating proxy ``request.url.scheme`` may be
    ``http`` while browsers send ``https`` origins — confirm such deployments
    list their public origin in TRUSTED_ORIGINS.
    """
    headers = request.headers
    host = headers.get("host", "").strip()
    expected = f"{request.url.scheme}://{host}".rstrip("/") if host else ""

    def _matches(candidate: str) -> bool:
        return candidate == expected or candidate in TRUSTED_ORIGINS

    origin = headers.get("origin", "").strip().rstrip("/")
    if origin:
        return _matches(origin)

    referer = headers.get("referer", "").strip()
    if not referer:
        return True
    pieces = urlparse(referer)
    return _matches(f"{pieces.scheme}://{pieces.netloc}".rstrip("/"))
|
|
|
|
|
|
def is_state_changing(method: str) -> bool:
    """Return True for the HTTP methods treated as mutating: POST, PUT, DELETE, PATCH.

    The comparison is exact and case-sensitive.
    """
    mutating_methods = frozenset(("POST", "PUT", "DELETE", "PATCH"))
    return method in mutating_methods
|
|
|
|
|
|
async def read_json_body(request: Request, max_bytes: int) -> dict:
    """Read the full request body and decode it as UTF-8 JSON.

    Args:
        request: The incoming request; its body is read in full.
        max_bytes: Maximum accepted body size in bytes.

    Returns:
        The decoded JSON value, or ``{}`` when the body is empty.
        NOTE(review): json.loads returns a list/str/number for non-object
        payloads despite the ``dict`` annotation; callers presumably expect an
        object — confirm.

    Raises:
        HTTPException: 413 when the body is larger than *max_bytes*; 400 when
            it is not valid UTF-8, not valid JSON, or nested too deeply to parse.
    """
    raw = await request.body()
    if len(raw) > max_bytes:
        raise HTTPException(status_code=413, detail="Request payload too large")
    if not raw:
        return {}
    try:
        return json.loads(raw.decode("utf-8"))
    except (ValueError, RecursionError) as exc:
        # UnicodeDecodeError and json.JSONDecodeError both subclass ValueError.
        # RecursionError is raised for pathologically deep nesting. Chain the
        # cause so it stays visible in server-side tracebacks.
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from exc
|