refactor(arch): modular package structure — split monolithic app.py into config/db/auth/memory/search/rag/gpu + routers/
- config.py: all constants, env vars, limits, skill registry, profiles - db.py: schema init, connection factory, skill state helpers - security.py: PIN hashing, audit logging, rate limiting, CSRF, request helpers - auth.py: session management, PIN verify, auth routes - memory.py: FTS5 CRUD + remember/forget command processing - search.py: SearXNG integration, perplexity scoring, refusal/hedge detection - gpu.py: rocm-smi stats - rag.py: Qdrant vector search + system prompt assembly - routers/: conversations, memories, models, presets, profile, settings, skills, chat, search - app.py: slim entry point, middleware, router registration only Bumps to v1.9.0
This commit is contained in:
175
security.py
Normal file
175
security.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
JarvisChat - Security utilities.
|
||||
PIN hashing, audit logging, incident tracking, CSRF/origin checks,
|
||||
rate limiting, request helpers.
|
||||
"""
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict, deque
|
||||
from datetime import datetime, timezone
|
||||
from threading import Lock
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from config import (
|
||||
ALLOWED_NETWORKS, TRUST_X_FORWARDED_FOR, TRUSTED_ORIGINS,
|
||||
BODY_LIMIT_DEFAULT_BYTES, BODY_LIMIT_CHAT_BYTES, BODY_LIMIT_PROFILE_BYTES,
|
||||
RATE_WINDOW_SECONDS, RL_LOGIN_PER_WINDOW, RL_CHAT_PER_WINDOW,
|
||||
RL_SEARCH_PER_WINDOW, RL_STATS_PER_WINDOW, RL_WRITE_PER_WINDOW,
|
||||
RL_DEFAULT_PER_WINDOW, VERSION,
|
||||
)
|
||||
|
||||
import ipaddress
|
||||
|
||||
log = logging.getLogger("jarvischat")
|
||||
|
||||
SESSIONS: dict = {}
|
||||
PIN_ATTEMPTS: dict = {}
|
||||
RATE_EVENTS: dict = defaultdict(deque)
|
||||
SESSION_LOCK = Lock()
|
||||
RATE_LOCK = Lock()
|
||||
|
||||
|
||||
def hash_pin(pin: str, salt_hex: Optional[str] = None) -> tuple:
|
||||
salt = bytes.fromhex(salt_hex) if salt_hex else os.urandom(16)
|
||||
digest = hashlib.pbkdf2_hmac("sha256", pin.encode("utf-8"), salt, 200_000)
|
||||
return salt.hex(), digest.hex()
|
||||
|
||||
|
||||
def audit_event(event: str, outcome: str, *, ip: str = "unknown", role: str = "none",
|
||||
details: str = "", warning: bool = False) -> None:
|
||||
payload = {"event": event, "outcome": outcome, "ip": ip, "role": role, "details": details[:300]}
|
||||
msg = "AUDIT " + json.dumps(payload, separators=(",", ":"))
|
||||
if warning:
|
||||
log.warning(msg)
|
||||
else:
|
||||
log.info(msg)
|
||||
|
||||
|
||||
def create_incident_key() -> str:
|
||||
ts = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
|
||||
return f"INC-{ts}-{uuid.uuid4().hex[:8].upper()}"
|
||||
|
||||
|
||||
def customer_error_envelope(message: str, incident_key: str) -> dict:
|
||||
return {
|
||||
"detail": message, "error_key": incident_key,
|
||||
"error": {"message": message, "incident_key": incident_key,
|
||||
"support_hint": "Share this incident key for exact diagnostics."},
|
||||
}
|
||||
|
||||
|
||||
def log_incident(event: str, *, message: str, request: Optional[Request] = None,
|
||||
exc: Optional[Exception] = None) -> str:
|
||||
incident_key = create_incident_key()
|
||||
payload = {
|
||||
"event": event, "incident_key": incident_key, "message": message,
|
||||
"app_version": VERSION, "pid": os.getpid(), "python": platform.python_version(),
|
||||
"platform": platform.platform(),
|
||||
"method": request.method if request else "",
|
||||
"path": request.url.path if request else "",
|
||||
"client_ip": get_client_ip(request) if request else "",
|
||||
}
|
||||
if exc:
|
||||
log.exception("INCIDENT " + json.dumps(payload, separators=(",", ":")))
|
||||
else:
|
||||
log.error("INCIDENT " + json.dumps(payload, separators=(",", ":")))
|
||||
return incident_key
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> str:
|
||||
forwarded = request.headers.get("x-forwarded-for", "").strip()
|
||||
if TRUST_X_FORWARDED_FOR and forwarded:
|
||||
return forwarded.split(",")[0].strip()
|
||||
if request.client and request.client.host:
|
||||
return request.client.host
|
||||
return "unknown"
|
||||
|
||||
|
||||
def is_ip_allowed(ip: str) -> bool:
|
||||
normalized = ip.strip().lower()
|
||||
if normalized in {"localhost", "testclient"}:
|
||||
normalized = "127.0.0.1"
|
||||
try:
|
||||
ip_obj = ipaddress.ip_address(normalized)
|
||||
except ValueError:
|
||||
return False
|
||||
for network in ALLOWED_NETWORKS:
|
||||
if ip_obj in network:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def request_body_limit(path: str) -> int:
|
||||
if path in {"/api/chat", "/api/search"}:
|
||||
return BODY_LIMIT_CHAT_BYTES
|
||||
if path == "/api/profile":
|
||||
return BODY_LIMIT_PROFILE_BYTES
|
||||
return BODY_LIMIT_DEFAULT_BYTES
|
||||
|
||||
|
||||
def rate_policy(path: str, method: str, ip: str, sid: str) -> tuple:
|
||||
identity = sid or ip
|
||||
if path == "/api/auth/login":
|
||||
return f"login:{ip}", RL_LOGIN_PER_WINDOW
|
||||
if path == "/api/chat":
|
||||
return f"chat:{identity}", RL_CHAT_PER_WINDOW
|
||||
if path == "/api/search":
|
||||
return f"search:{identity}", RL_SEARCH_PER_WINDOW
|
||||
if path == "/api/stats":
|
||||
return f"stats:{ip}", RL_STATS_PER_WINDOW
|
||||
if method in {"POST", "PUT", "DELETE", "PATCH"}:
|
||||
return f"write:{identity}", RL_WRITE_PER_WINDOW
|
||||
return f"api:{identity}", RL_DEFAULT_PER_WINDOW
|
||||
|
||||
|
||||
def check_rate_limit(key: str, limit: int, window_seconds: int) -> tuple:
|
||||
now_ts = time.time()
|
||||
with RATE_LOCK:
|
||||
bucket = RATE_EVENTS[key]
|
||||
while bucket and bucket[0] <= (now_ts - window_seconds):
|
||||
bucket.popleft()
|
||||
if len(bucket) >= limit:
|
||||
retry_after = max(1, int(math.ceil(window_seconds - (now_ts - bucket[0]))))
|
||||
return False, retry_after
|
||||
bucket.append(now_ts)
|
||||
return True, 0
|
||||
|
||||
|
||||
def origin_allowed(request: Request) -> bool:
|
||||
host = request.headers.get("host", "").strip()
|
||||
expected_origin = f"{request.url.scheme}://{host}".rstrip("/") if host else ""
|
||||
origin = request.headers.get("origin", "").strip().rstrip("/")
|
||||
referer = request.headers.get("referer", "").strip()
|
||||
if origin:
|
||||
return origin == expected_origin or origin in TRUSTED_ORIGINS
|
||||
if referer:
|
||||
parsed = urlparse(referer)
|
||||
ref_origin = f"{parsed.scheme}://{parsed.netloc}".rstrip("/")
|
||||
return ref_origin == expected_origin or ref_origin in TRUSTED_ORIGINS
|
||||
return True
|
||||
|
||||
|
||||
def is_state_changing(method: str) -> bool:
|
||||
return method in {"POST", "PUT", "DELETE", "PATCH"}
|
||||
|
||||
|
||||
async def read_json_body(request: Request, max_bytes: int) -> dict:
|
||||
raw = await request.body()
|
||||
if len(raw) > max_bytes:
|
||||
raise HTTPException(status_code=413, detail="Request payload too large")
|
||||
if not raw:
|
||||
return {}
|
||||
try:
|
||||
return json.loads(raw.decode("utf-8"))
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Invalid JSON payload")
|
||||
Reference in New Issue
Block a user