Compare commits
20 Commits
18bca027de...main

| Author | SHA1 | Date | |
|---|---|---|---|
| | b8405b8d76 | | |
| | e3b1780292 | | |
| | 66b086c3f3 | | |
| | 4b36fd315a | | |
| | fcc0605a4a | | |
| | 091e2ad2e3 | | |
| | 5986c4ad86 | | |
| | cc1efa7a21 | | |
| | 41a8708c0d | | |
| | ec2f4c0332 | | |
| | f691787037 | | |
| | 56919965e1 | | |
| | f1fbc24c94 | | |
| | 8d3cf5d478 | | |
| | d01dd3b761 | | |
| | 5075a6bc55 | | |
| | 970abc8957 | | |
| | dd475a6f2d | | |
| | 6de3a1e154 | | |
| | 5a652c1b74 | | |
1
.gitignore
vendored
@@ -4,3 +4,4 @@
 __pycache__/
 venv/
 readme.md-
+*.bak
114
AGENTS.md
Normal file
@@ -0,0 +1,114 @@
# JarvisChat — Agents Guide

## Run

```bash
./venv/bin/uvicorn app:app --host 0.0.0.0 --port 8080 --reload
```

## Tests

```bash
./venv/bin/python -m pytest tests/ -v
```

All tests use `tmp_path` fixtures + monkeypatched `httpx.AsyncClient.stream`. No external services needed. Test factories reset `SESSIONS`, `PIN_ATTEMPTS`, `RATE_EVENTS` globals — be careful not to let test state leak. After the modular refactor, tests import directly from the correct modules (`db`, `security`, `config`, `search`, `rag`, `memory`, `routers.*`) — not from the old monolithic `app` namespace.

Every router has a dedicated test file:
| File | Covers |
|------|--------|
| `test_auth_capabilities.py` | `auth.py` — guest/admin sessions, origin blocking, logout |
| `test_chat_streaming_and_memory_paths.py` | `routers/chat.py` — streaming, auto-search, remember/forget |
| `test_completions.py` | `routers/completions.py` — API key auth, FIM, streaming, blocking, errors |
| `test_conversations.py` | `routers/conversations.py` — full CRUD, guest admin enforcement |
| `test_memories.py` | `routers/memories.py` — edit, search, stats endpoints |
| `test_models_router.py` | `routers/models.py` — models list, ps, show, stats, search/status |
| `test_presets.py` | `routers/presets.py` — full CRUD, default preset protection |
| `test_profile.py` | `routers/profile.py` — get, update, default, length validation |
| `test_search_route.py` | `routers/search_route.py` — explicit search flow, no results, errors |
| `test_search_url_sanitization.py` | `search.py` URL sanitizer |
| `test_settings_allowlist.py` | `routers/settings.py` — allowlisted key enforcement |
| `test_skills_framework.py` | `routers/skills.py` — list, toggle, unknown skill, prompt injection |
| `test_ip_allowlist.py` | IP allowlist helper + middleware |
| `test_rate_and_payload_guardrails.py` | Rate limits + payload size enforcement |
| `test_error_envelopes.py` | Global exception handler + stream error incidents |

Modules that call `httpx.AsyncClient` (chat, completions, models, search_route)
are mocked via `monkeypatch.setattr` on `AsyncClient.stream`, `.get`, or `.post`.
CPU stats in `models.py` (`api/stats`) use real `psutil`; GPU stats are
monkeypatched via `routers.models.get_gpu_stats`.

## Architecture

Refactored from single-file (`app.py`) into modules under project root:

| File | Role |
|------|------|
| `app.py` | FastAPI app, middleware, router registration |
| `config.py` | Constants, env vars, rate/payload limits, built-in skills registry |
| `db.py` | SQLite schema, connection factory, settings helpers |
| `auth.py` | PIN-based guest/admin sessions, auth routes |
| `security.py` | Rate limiting, origin checks, IP allowlist, audit/incident logging |
| `memory.py` | FTS5 memory CRUD, remember/forget command parsing |
| `search.py` | SearXNG integration, perplexity scoring, refusal detection |
| `rag.py` | Qdrant vector search + system prompt assembly |
| `gpu.py` | AMD GPU stats via `rocm-smi` |
| `routers/` | One module per endpoint group (chat, search, skills, completions, etc.) |

### Entrypoint / API keys

- `app.py` line 148: `uvicorn.run(app, ...)` when called directly
- `config.py` line 14: `LLAMA_SERVER_BASE` defaults to `http://192.168.50.108:8081` — llama-server, **not** standard Ollama port, used by all model endpoints
- `config.py` line 17: `COMPLETIONS_API_KEY` read from `JARVISCHAT_COMPLETIONS_API_KEY` env var or auto-generates a random key — no longer a missing import
- `config.py` line 13: `OLLAMA_BASE` is legacy/unused — all endpoints now use `LLAMA_SERVER_BASE`

### Key flows

1. **`/api/chat`** → `process_remember_command()` intercepts "remember that..." / "forget about..." first → else `build_system_prompt()` (profile + FTS5 memory + Qdrant RAG + preset + skills) → stream from llama-server with `logprobs: true` → if perplexity > 15.0 OR `REFUSAL_PATTERNS` match, re-query with SearXNG results
2. **`/api/search`** → bypasses perplexity/refusal, queries SearXNG directly → summarizes via llama-server (no raw results leaked in SSE)
3. **`/v1/chat/completions`** → OpenAI-compatible for Continue.dev/IDE integration; FIM requests proxied without persistence

### Perplexity / auto-search

The upstream request includes `"logprobs": true`. `parse_llama_stream_chunk()` extracts per-token logprobs from each chunk's `choices[0].logprobs.content[].logprob`. The `all_logprobs` list is populated during streaming, so `calculate_perplexity()` and `is_uncertain()` work correctly — auto-search on high perplexity is no longer dead code.
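
As a rough illustration of the calculation described above, perplexity can be derived from per-token log-probabilities as the exponential of the negative mean logprob. This is a sketch under assumptions, not the project's actual code: the `_sketch` function names, `extract_logprobs`, and the empty-list convention are hypothetical. Only the chunk path `choices[0].logprobs.content[].logprob` and the 15.0 threshold come from this guide.

```python
import math

PERPLEXITY_THRESHOLD = 15.0  # threshold quoted in the key flows above


def extract_logprobs(chunk: dict) -> list[float]:
    """Collect logprobs from choices[0].logprobs.content[].logprob, tolerating missing keys."""
    choices = chunk.get("choices") or [{}]
    logprobs = choices[0].get("logprobs") or {}
    content = logprobs.get("content") or []
    return [item["logprob"] for item in content if "logprob" in item]


def perplexity_sketch(logprobs: list[float]) -> float:
    """Return exp(-mean(logprob)); 0.0 for an empty list (an assumed convention)."""
    if not logprobs:
        return 0.0
    return math.exp(-sum(logprobs) / len(logprobs))


def is_uncertain_sketch(logprobs: list[float], threshold: float = PERPLEXITY_THRESHOLD) -> bool:
    """True when the perplexity of the collected tokens exceeds the threshold."""
    return perplexity_sketch(logprobs) > threshold
```

With this definition, a response whose tokens average a logprob of -3.0 has a perplexity of about 20 and would trigger a search, while an average of -2.0 (about 7.4) would not.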

### Auth / lockdown

- Guest session by default (`POST /api/auth/guest`), admin unlock via 4-digit PIN (`POST /api/auth/login`)
- Admin required for PUT/DELETE/PATCH + all POST except allowlist (`/api/chat`, `/api/search`, `/api/show`, `/api/auth/*`)
- IP allowlist, rate limiting, origin checking, payload size limits — all enforced in `app.py` middleware
- Origin check applies to **all** `/api/` requests (not just state-changing methods); `origin_allowed()` returns `False` when both `Origin` and `Referer` headers are absent, closing CSRF read gap
- `JARVISCHAT_ADMIN_PIN` env var required on first boot (or `JARVISCHAT_ALLOW_DEFAULT_PIN=true`)
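
The origin rule in the list above can be sketched as follows. `security.py` is not shown here, so this is only an assumed shape: `origin_allowed_sketch`, its parameters, and the Referer fallback are hypothetical, and the real `origin_allowed()` may accept additional cases (for example same-host requests). The one behaviour taken from this guide is that a request with neither `Origin` nor `Referer` is rejected.

```python
from urllib.parse import urlsplit


def origin_allowed_sketch(headers: dict, trusted_origins: set[str]) -> bool:
    """Allow a request only if its Origin (or the origin of its Referer) is trusted."""
    origin = (headers.get("origin") or "").rstrip("/")
    referer = headers.get("referer") or ""
    if not origin and not referer:
        # Both headers absent: reject, as described for origin_allowed().
        return False
    if not origin:
        # Fall back to scheme://host[:port] taken from the Referer URL.
        parts = urlsplit(referer)
        origin = f"{parts.scheme}://{parts.netloc}"
    return origin in trusted_origins
```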

### Database

- SQLite at `jarvischat.db`, auto-created by `init_db()` on startup via FastAPI `lifespan`
- `get_db()` opens new connection per request (no pool). Close after use.
- FTS5 virtual table `memories` for full-text search with BM25 ranking. FTS5 operator keywords (`AND`, `OR`, `NOT`, `NEAR`) are double-quoted to prevent parse errors.
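
To make the FTS5 quoting rule concrete, here is a minimal sketch of one way to neutralise operator keywords before building a `MATCH` query. The helper name `quote_fts5_operators` is hypothetical and the project's own implementation in `memory.py` may differ, for instance by quoting every term. FTS5 only treats the uppercase spellings as operators, so lowercase words pass through unchanged.

```python
FTS5_OPERATORS = {"AND", "OR", "NOT", "NEAR"}


def quote_fts5_operators(query: str) -> str:
    """Wrap bare FTS5 operator keywords in double quotes so they match as plain terms."""
    terms = []
    for term in query.split():
        if term in FTS5_OPERATORS:
            terms.append(f'"{term}"')  # "OR" is now a string token, not an operator
        else:
            terms.append(term)
    return " ".join(terms)
```

A user message such as `rust OR go` would then be searched as the three terms `rust`, `"OR"`, and `go` rather than parsed as a boolean expression.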

### External services

| Service | Required | Port / notes |
|---------|----------|--------------|
| llama-server (OpenAI-compat API) | Yes | 8081 (ultron) or env `LLAMA_SERVER_BASE` |
| SearXNG | No | 8888 |
| wttr.in | No | weather shortcut bypasses SearXNG; curl UA for plain-text output |
| rocm-smi | No | AMD GPU stats |
| Qdrant | No | 6333 (ultron) — RAG vector search |

### Config quirks

- Rate limits and payload caps in `config.py` — tweak for testing by monkeypatching module attributes (note: patch `security.RL_*` not `config.RL_*` since `security` imports bindings separately)
- `ALLOWED_SETTINGS_KEYS` in `config.py` controls which keys the UI can write via `/api/settings`
- Settings table seeded with defaults (`profile_enabled`, `search_enabled`, `memory_enabled`, `skills_enabled`, `default_model`) — never overwritten by `init_db()`
- Profile table uses singleton row `id=1`
- RAG embedding requests go to `EMBED_URL` at `/api/embeddings` (separate Ollama instance)
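
The note about patching `security.RL_*` rather than `config.RL_*` follows from how `from config import NAME` copies a binding into the importing module. The sketch below reproduces the effect with two in-memory stand-in modules; the module objects and the `chat_limit` helper are hypothetical and exist only for the demonstration.

```python
import types

# Stand-ins for config.py and security.py, built in memory for illustration.
config = types.ModuleType("config")
config.RL_CHAT_PER_WINDOW = 24

security = types.ModuleType("security")
# Equivalent of `from config import RL_CHAT_PER_WINDOW` inside security.py:
security.RL_CHAT_PER_WINDOW = config.RL_CHAT_PER_WINDOW


def chat_limit() -> int:
    """Read the limit the way code living in security.py would."""
    return security.RL_CHAT_PER_WINDOW


config.RL_CHAT_PER_WINDOW = 2    # patching config: security keeps its own binding
limit_after_config_patch = chat_limit()

security.RL_CHAT_PER_WINDOW = 2  # patching security: the change takes effect
limit_after_security_patch = chat_limit()
```

In a test, the equivalent call would be `monkeypatch.setattr(security, "RL_CHAT_PER_WINDOW", 2)`, which pytest also restores automatically at teardown.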

### SSE Protocol

All streaming endpoints yield `data: {json}\n\n`. Key shapes:
- `{token, conversation_id}` — streaming token
- `{searching: true}` — web search triggered
- `{search_results: N}` — N results (no raw_results payload)
- `{done: true, perplexity, tokens_per_sec, searched?}` — terminal
- `{error: "...", error_key: "..."}` — error with incident key
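
The framing above can be exercised with a small encoder and decoder. This is a sketch, not code from the project: `sse_event` and `parse_sse` are hypothetical helper names, and the sample payload values are made up. Only the `data: {json}\n\n` framing and the field names come from the list above.

```python
import json


def sse_event(payload: dict) -> str:
    """Serialise one event in the `data: {json}` + blank-line framing."""
    return f"data: {json.dumps(payload)}\n\n"


def parse_sse(stream: str) -> list[dict]:
    """Split a buffered stream on blank lines and decode each data frame."""
    events = []
    for frame in stream.split("\n\n"):
        frame = frame.strip()
        if frame.startswith("data: "):
            events.append(json.loads(frame[len("data: "):]))
    return events


# Example round trip with illustrative values.
stream = (
    sse_event({"token": "Hi", "conversation_id": 1})
    + sse_event({"done": True, "perplexity": 3.2, "tokens_per_sec": 41.5})
)
events = parse_sse(stream)
```

A client would typically append tokens until it sees an event with `done` set, and surface `error_key` to the user when an `error` event arrives.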
263
README.md
Normal file
@@ -0,0 +1,263 @@
# JarvisChat v1.8.5

**A lightweight local inference coding companion with persistent memory, web search, and real-time system monitoring.**

Built with FastAPI + SQLite + Jinja2. Runs on Python 3.13. No Docker required.

Developer wiki: [docs/wiki/Home.md](docs/wiki/Home.md)

## What's New in v1.8.0

- **Modular refactor completed** — single-file `app.py` split into `config.py`, `db.py`, `auth.py`, `security.py`, `memory.py`, `search.py`, `rag.py`, `gpu.py`, and `routers/` package
- **`COMPLETIONS_API_KEY`** — auto-generated secret key for the OpenAI-compatible endpoint, overridable via `JARVISCHAT_COMPLETIONS_API_KEY` env var
- **Perplexity auto-search fixed** — upstream request now sends `"logprobs": true`, `parse_llama_stream_chunk()` extracts per-token logprobs, so `calculate_perplexity()` and `is_uncertain()` work correctly (was dead code)
- **All `/api/models` endpoints** — now correctly target `LLAMA_SERVER_BASE` (llama-server on port 8081) instead of the old Ollama port; `/api/ps` uses `/v1/models` endpoint
- **RAG embedding endpoint fixed** — `EMBED_URL` changed from old server `:8081` to correct host/port `http://192.168.50.210:11434` (Ollama on new machine)
- **Error messages corrected** — all user-facing errors say "inference server" instead of "Ollama" or "llama-server"
- **Secure SSE protocol** — raw search results are no longer leaked in the SSE event stream
- **FTS5 query safety** — operator keywords (`AND`, `OR`, `NOT`, `NEAR`) are double-quoted to prevent parse errors
- **All 8 test files fixed** — rewired imports after the modular refactor; all 26 tests pass
- **Origin check extended to all API methods** — GET/HEAD/OPTIONS requests no longer bypass origin checking (was limited to POST/PUT/DELETE/PATCH)
- **Missing headers now rejected** — `origin_allowed()` returns `False` when both `Origin` and `Referer` are absent, closing the CSRF read gap for script-initiated requests
- **Full router test coverage** — 7 new test files added: `test_conversations.py`, `test_presets.py`, `test_profile.py`, `test_models_router.py`, `test_completions.py`, `test_search_route.py`, `test_memories.py`; all 10 routers now have dedicated unit tests (92 total, up from 26)

## Features

- **Persistent Memory** — SQLite FTS5 full-text search for fast, relevant memory retrieval
- **Web Search** — SearXNG integration for automatic web lookups when the model is uncertain
- **Explicit Search** — Search button to force web search without waiting for model uncertainty
- **Profile Injection** — Custom system prompt injected into every conversation
- **System Presets** — Save and switch between different system prompts
- **Real-time Stats** — CPU, RAM, GPU, VRAM monitoring in sidebar
- **Token Thermometer** — Visual context window usage indicator
- **Streaming Responses** — Server-sent events for real-time token display
- **Conversation History** — SQLite-backed chat persistence with mass-delete option
- **Model Switching** — Change inference models on the fly
- **Skills Framework** — Built-in skill registry with per-skill enable/disable controls

## File Structure

```
/opt/jarvischat/
├── app.py               # FastAPI app entry point
├── config.py            # Constants, env vars, limits, skill registry
├── db.py                # SQLite schema, connection factory
├── auth.py              # PIN-based guest/admin sessions, auth routes
├── security.py          # Rate limiting, origin checks, IP allowlist, audit
├── memory.py            # FTS5 memory CRUD, remember/forget commands
├── search.py            # SearXNG integration, perplexity, refusal detection
├── rag.py               # Qdrant vector search + system prompt assembly
├── gpu.py               # AMD GPU stats via rocm-smi
├── routers/
│   ├── chat.py          # /api/chat streaming endpoint
│   ├── search_route.py  # /api/search explicit search endpoint
│   ├── completions.py   # /v1/chat/completions OpenAI-compat endpoint
│   ├── conversations.py # Conversation CRUD
│   ├── memories.py      # Memory CRUD API
│   ├── models.py        # Model listing, system stats
│   ├── presets.py       # System prompt presets
│   ├── profile.py       # User profile
│   ├── settings.py      # Runtime settings
│   └── skills.py        # Skills management
├── static/
│   └── logo.png         # Logo image (optional)
├── templates/
│   └── index.html       # Frontend
└── tests/               # pytest tests (92 total)
```

## Requirements

- Python 3.11+ (tested on 3.13)
- llama-server running locally or on network (OpenAI-compatible API on port 8081)
- SearXNG (optional, for web search)

## Installation

### Fresh Install

```bash
# Create directory and venv
sudo mkdir -p /opt/jarvischat
sudo chown $USER:$USER /opt/jarvischat
cd /opt/jarvischat
python3 -m venv venv

# Install dependencies
./venv/bin/pip install fastapi uvicorn httpx psutil jinja2 python-multipart

# Set admin PIN before first startup (4 digits)
export JARVISCHAT_ADMIN_PIN=4827

# Create subdirectories
mkdir -p templates static

# Copy files
# (copy all .py files to /opt/jarvischat/)
# (copy routers/ directory to /opt/jarvischat/)
# (copy templates/index.html to /opt/jarvischat/templates/)
```

WARNING: Do not use `1234` as your admin PIN unless you accept weak local security.

NOTE: First boot requires `JARVISCHAT_ADMIN_PIN` unless you explicitly opt into insecure fallback with `JARVISCHAT_ALLOW_DEFAULT_PIN=true`.

## Systemd Service

Create `/etc/systemd/system/jarvischat.service`:

```ini
[Unit]
Description=JarvisChat - Local Inference Web Interface
After=network.target

[Service]
Type=simple
User=jarvischat
Group=jarvischat
WorkingDirectory=/opt/jarvischat
ExecStart=/opt/jarvischat/venv/bin/uvicorn app:app --host 0.0.0.0 --port 8080
Restart=always
RestartSec=5

[Install]
WantedBy=multi-user.target
```

```bash
sudo systemctl daemon-reload
sudo systemctl enable jarvischat
sudo systemctl start jarvischat
```

## Memory Commands

In chat, natural language triggers memory operations:

| You say | What happens |
|---------|--------------|
| "remember that I prefer Rust over Go" | Stores as `preference` |
| "remember that JarvisChat runs on port 8080" | Stores as `infrastructure` |
| "note that the deadline is Friday" | Stores as `general` |
| "forget about the deadline" | Removes matching memories |

Memories are automatically searched based on your message content and injected into the system prompt when relevant.
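
To illustrate how the trigger phrases in the table might be detected, here is a minimal sketch. It is an assumption about the general approach, not the code in `memory.py`: the regular expressions and the `parse_memory_command` name are hypothetical, and topic classification is deliberately left out because its rules are not described here.

```python
import re
from typing import Optional

# Hypothetical patterns covering the phrases shown in the table above.
REMEMBER_RE = re.compile(r"^\s*(?:remember|note)\s+that\s+(.+)$", re.IGNORECASE | re.DOTALL)
FORGET_RE = re.compile(r"^\s*forget\s+about\s+(.+)$", re.IGNORECASE | re.DOTALL)


def parse_memory_command(message: str) -> Optional[tuple[str, str]]:
    """Return (action, text) for a memory command, or None for ordinary chat."""
    match = REMEMBER_RE.match(message)
    if match:
        return ("remember", match.group(1).strip())
    match = FORGET_RE.match(message)
    if match:
        return ("forget", match.group(1).strip())
    return None
```

A message that returns `None` would continue through the normal chat flow.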

### Memory Topics

Memories are auto-categorized:
- `preference` — likes, dislikes, choices
- `project` — active work, repos, tasks
- `infrastructure` — servers, services, configs
- `personal` — name, location, background
- `general` — everything else

## API Endpoints

### Completions (OpenAI-compatible)

| Method | Endpoint | Description |
|--------|----------|-------------|
| POST | `/v1/chat/completions` | OpenAI-compatible chat (requires Bearer API key) |
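
A request to this endpoint might be built as shown below. This sketch uses only the standard library and constructs the request without sending it. The base URL, the placeholder key, and the exact body fields are assumptions: the body follows the usual OpenAI chat format, and the key should be whatever `JARVISCHAT_COMPLETIONS_API_KEY` is set to on the server.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8080"  # assumed; matches the uvicorn port used in this README
API_KEY = "jc-sk-example"           # placeholder, not a real key


def build_completion_request(prompt: str, model: str = "llama3.1:latest") -> urllib.request.Request:
    """Build (but do not send) a POST request carrying the Bearer API key."""
    body = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "stream": False,
    }
    return urllib.request.Request(
        f"{BASE_URL}/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
        },
        method="POST",
    )
```

Passing the returned object to `urllib.request.urlopen()` would perform the call against a running instance.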

### Chat & Search

| Method | Endpoint | Description |
|--------|----------|-------------|
| POST | `/api/chat` | Send message (streaming SSE) |
| POST | `/api/search` | Explicit web search (streaming SSE) |

### Memory

| Method | Endpoint | Description |
|--------|----------|-------------|
| GET | `/api/memories` | List all memories |
| POST | `/api/memories` | Add memory |
| PUT | `/api/memories/{rowid}` | Update memory |
| DELETE | `/api/memories/{rowid}` | Delete memory |
| GET | `/api/memories/search?q=term` | Search memories |
| GET | `/api/memories/stats` | Get counts by topic |

### Models & System

| Method | Endpoint | Description |
|--------|----------|-------------|
| GET | `/api/models` | List available models |
| GET | `/api/ps` | List loaded models |
| POST | `/api/show` | Get model info |
| GET | `/api/stats` | CPU, RAM, GPU, VRAM stats |
| GET | `/api/search/status` | SearXNG availability |

### Settings & Profile

| Method | Endpoint | Description |
|--------|----------|-------------|
| GET | `/api/profile` | Get profile content |
| PUT | `/api/profile` | Update profile (admin) |
| GET | `/api/profile/default` | Get default profile |
| GET | `/api/settings` | Get settings |
| PUT | `/api/settings` | Update settings (admin) |

### Conversations

| Method | Endpoint | Description |
|--------|----------|-------------|
| GET | `/api/conversations` | List conversations |
| POST | `/api/conversations` | Create conversation |
| GET | `/api/conversations/{id}` | Get conversation with messages |
| PUT | `/api/conversations/{id}` | Update conversation title/model |
| DELETE | `/api/conversations/{id}` | Delete conversation |
| DELETE | `/api/conversations` | Delete ALL conversations |

### Presets

| Method | Endpoint | Description |
|--------|----------|-------------|
| GET | `/api/presets` | List presets |
| POST | `/api/presets` | Create preset |
| PUT | `/api/presets/{id}` | Update preset |
| DELETE | `/api/presets/{id}` | Delete preset |

### Skills

| Method | Endpoint | Description |
|--------|----------|-------------|
| GET | `/api/skills` | List all skills with state |
| GET | `/api/skills/active` | List active skills |
| PUT | `/api/skills/{key}` | Toggle skill enabled (admin) |

### Auth

| Method | Endpoint | Description |
|--------|----------|-------------|
| POST | `/api/auth/guest` | Create guest session |
| POST | `/api/auth/login` | Admin PIN login |
| POST | `/api/auth/logout` | Revoke session |
| GET | `/api/auth/session` | Check session validity |
| POST | `/api/auth/heartbeat` | Extend session TTL |

## Configuration

Settings are stored in the `settings` table and include:

- `profile_enabled` — Inject profile into chats (true/false)
- `search_enabled` — Auto web search (true/false)
- `memory_enabled` — Memory injection (true/false)
- `skills_enabled` — Skills framework (true/false)
- `default_model` — Default inference model

## Testing

```bash
./venv/bin/python -m pytest tests/ -v
```

All tests use `tmp_path` fixtures + monkeypatched `httpx.AsyncClient.stream`. No external services needed.

## License

MIT

## Repository

Gitea: `ssh://gitea@llgit.llamachile.tube:1319/gramps/jarvisChat.git`
2334
app.py.bak
Normal file
File diff suppressed because it is too large
2334
app.py.pre-refactor-20260616-081744
Normal file
File diff suppressed because it is too large
202
auth.py
Normal file
@@ -0,0 +1,202 @@
"""
JarvisChat - Auth: session management, PIN verification, middleware, auth routes.
"""
import hashlib
import hmac
import logging
import re
import time
import uuid
from typing import Optional

from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import JSONResponse

from config import SESSION_TIMEOUT_SECONDS, MAX_PIN_ATTEMPTS, PIN_LOCKOUT_SECONDS, RATE_WINDOW_SECONDS
from db import get_db, get_setting
from security import (
    SESSIONS, PIN_ATTEMPTS, SESSION_LOCK, BODY_LIMIT_DEFAULT_BYTES,
    audit_event, get_client_ip, is_ip_allowed, check_rate_limit,
    rate_policy, origin_allowed, is_state_changing, request_body_limit,
    read_json_body, hash_pin, customer_error_envelope, log_incident,
)

log = logging.getLogger("jarvischat")
router = APIRouter()


def verify_admin_pin(pin: str) -> bool:
    if not re.fullmatch(r"\d{4}", pin or ""):
        return False
    db = get_db()
    pin_hash = get_setting(db, "admin_pin_hash", "")
    pin_salt = get_setting(db, "admin_pin_salt", "")
    db.close()
    if not pin_hash or not pin_salt:
        return False
    _, candidate_hash = hash_pin(pin, salt_hex=pin_salt)
    return hmac.compare_digest(candidate_hash, pin_hash)


def is_ip_locked(ip: str) -> tuple:
    now_ts = time.time()
    with SESSION_LOCK:
        state = PIN_ATTEMPTS.get(ip)
        if not state:
            return False, 0
        locked_until = state.get("locked_until", 0)
        if locked_until > now_ts:
            return True, int(locked_until - now_ts)
        if locked_until:
            PIN_ATTEMPTS.pop(ip, None)
        return False, 0


def record_pin_failure(ip: str) -> None:
    now_ts = time.time()
    with SESSION_LOCK:
        state = PIN_ATTEMPTS.get(ip, {"fail_count": 0, "locked_until": 0})
        state["fail_count"] = int(state.get("fail_count", 0)) + 1
        if state["fail_count"] >= MAX_PIN_ATTEMPTS:
            state["locked_until"] = now_ts + PIN_LOCKOUT_SECONDS
            state["fail_count"] = 0
        PIN_ATTEMPTS[ip] = state


def clear_pin_failures(ip: str) -> None:
    with SESSION_LOCK:
        PIN_ATTEMPTS.pop(ip, None)


def cleanup_sessions(now_ts: Optional[float] = None) -> None:
    now_ts = now_ts or time.time()
    with SESSION_LOCK:
        expired = [
            sid for sid, meta in SESSIONS.items()
            if (now_ts - meta.get("last_seen", 0)) > SESSION_TIMEOUT_SECONDS
        ]
        for sid in expired:
            del SESSIONS[sid]


def create_session(ip: str, role: str) -> str:
    now_ts = time.time()
    sid = uuid.uuid4().hex
    with SESSION_LOCK:
        SESSIONS[sid] = {"ip": ip, "role": role, "created_at": now_ts, "last_seen": now_ts}
    return sid


def validate_session(sid: str, ip: str, touch: bool = True) -> bool:
    if not sid:
        return False
    now_ts = time.time()
    cleanup_sessions(now_ts)
    with SESSION_LOCK:
        session = SESSIONS.get(sid)
        if not session or session.get("ip") != ip:
            return False
        if touch:
            session["last_seen"] = now_ts
        return True


def get_session(sid: str, ip: str, touch: bool = True) -> Optional[dict]:
    if not sid:
        return None
    now_ts = time.time()
    cleanup_sessions(now_ts)
    with SESSION_LOCK:
        session = SESSIONS.get(sid)
        if not session or session.get("ip") != ip:
            return None
        if touch:
            session["last_seen"] = now_ts
        return dict(session)


def revoke_session(sid: str) -> None:
    if not sid:
        return
    with SESSION_LOCK:
        SESSIONS.pop(sid, None)


def is_admin_only(path: str, method: str) -> bool:
    if method in {"PUT", "DELETE", "PATCH"}:
        return True
    if method != "POST":
        return False
    guest_allowed_posts = {
        "/api/chat", "/api/search", "/api/show", "/api/auth/login",
        "/api/auth/logout", "/api/auth/session", "/api/auth/heartbeat", "/api/auth/guest",
    }
    return path not in guest_allowed_posts


# --- Auth routes ---

@router.post("/api/auth/guest")
async def auth_guest(request: Request):
    ip = get_client_ip(request)
    sid = create_session(ip, role="guest")
    audit_event("guest_session", "success", ip=ip, role="guest")
    return {"status": "ok", "session_id": sid, "role": "guest", "timeout_seconds": SESSION_TIMEOUT_SECONDS}


@router.post("/api/auth/login")
async def auth_login(request: Request):
    body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
    pin = str(body.get("pin", ""))
    ip = get_client_ip(request)
    locked, retry_after = is_ip_locked(ip)
    if locked:
        audit_event("admin_login", "locked", ip=ip, role="none", details=f"retry_after={retry_after}", warning=True)
        raise HTTPException(status_code=429, detail=f"Too many failed PIN attempts. Retry in {retry_after}s.")
    if not verify_admin_pin(pin):
        record_pin_failure(ip)
        audit_event("admin_login", "failed", ip=ip, role="none", warning=True)
        raise HTTPException(status_code=401, detail="Invalid PIN")
    clear_pin_failures(ip)
    sid = create_session(ip, role="admin")
    audit_event("admin_login", "success", ip=ip, role="admin")
    return {"status": "ok", "session_id": sid, "role": "admin", "timeout_seconds": SESSION_TIMEOUT_SECONDS}


@router.get("/api/auth/session")
async def auth_session(request: Request):
    sid = request.headers.get("x-session-id", "").strip()
    ip = get_client_ip(request)
    session = get_session(sid, ip, touch=True)
    return {"authenticated": bool(session), "role": session.get("role") if session else "none"}


@router.post("/api/auth/heartbeat")
async def auth_heartbeat(request: Request):
    sid = request.headers.get("x-session-id", "").strip()
    ip = get_client_ip(request)
    if not sid or not validate_session(sid, ip, touch=True):
        raise HTTPException(status_code=401, detail="Authentication required")
    return {"status": "ok"}


@router.post("/api/auth/logout")
async def auth_logout(request: Request):
    ip = get_client_ip(request)
    sid = request.headers.get("x-session-id", "").strip()
    role = "none"
    if sid:
        session = get_session(sid, ip, touch=False)
        role = session.get("role", "none") if session else "none"
    if not sid:
        try:
            body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
            sid = str(body.get("session_id", "")).strip()
        except Exception:
            try:
                sid = (await request.body()).decode("utf-8", errors="ignore").strip()
            except Exception:
                sid = ""
    revoke_session(sid)
    audit_event("logout", "success", ip=ip, role=role)
    return {"status": "ok"}
160
config.py
Normal file
@@ -0,0 +1,160 @@
"""
JarvisChat - Central configuration.
All constants, environment variables, limits, and skill registry live here.
"""
import os
import re
import ipaddress
import logging

log = logging.getLogger("jarvischat")

VERSION = "v1.8.5"
OLLAMA_BASE = os.environ.get("OLLAMA_BASE", "http://localhost:11434")
LLAMA_SERVER_BASE = os.environ.get("LLAMA_SERVER_BASE", "http://192.168.50.108:8081")
SEARXNG_BASE = "http://localhost:8888"
DEFAULT_MODEL = "llama3.1:latest"
COMPLETIONS_API_KEY = os.environ.get("JARVISCHAT_COMPLETIONS_API_KEY", "jc-sk-" + os.urandom(24).hex())

# --- Auth ---
SESSION_TIMEOUT_SECONDS = 90
MAX_PIN_ATTEMPTS = 5
PIN_LOCKOUT_SECONDS = 300
ALLOW_DEFAULT_PIN = os.getenv("JARVISCHAT_ALLOW_DEFAULT_PIN", "false").lower() == "true"
TRUSTED_ORIGINS = {
    origin.strip().rstrip("/")
    for origin in os.getenv("JARVISCHAT_TRUSTED_ORIGINS", "").split(",")
    if origin.strip()
}
DEFAULT_ALLOWED_CIDRS = "127.0.0.0/8,::1/128,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"
ALLOWED_CIDRS_RAW = os.getenv("JARVISCHAT_ALLOWED_CIDRS", DEFAULT_ALLOWED_CIDRS)
TRUST_X_FORWARDED_FOR = (
    os.getenv("JARVISCHAT_TRUST_X_FORWARDED_FOR", "false").lower() == "true"
)

# --- Rate limits ---
RATE_WINDOW_SECONDS = 60
RL_LOGIN_PER_WINDOW = 10
RL_CHAT_PER_WINDOW = 24
RL_SEARCH_PER_WINDOW = 16
RL_WRITE_PER_WINDOW = 30
RL_DEFAULT_PER_WINDOW = 240
RL_STATS_PER_WINDOW = 600

# --- Payload limits ---
BODY_LIMIT_DEFAULT_BYTES = 64 * 1024
BODY_LIMIT_CHAT_BYTES = 128 * 1024
BODY_LIMIT_PROFILE_BYTES = 256 * 1024

MAX_CHAT_MESSAGE_CHARS = 8000
MAX_SEARCH_QUERY_CHARS = 500
MAX_PROFILE_CHARS = 32000
MAX_MEMORY_FACT_CHARS = 2000
MAX_PRESET_NAME_CHARS = 120
MAX_PRESET_PROMPT_CHARS = 12000
MAX_SETTINGS_KEYS = 16
MAX_SETTINGS_VALUE_CHARS = 8000
MAX_CONVERSATION_TITLE_CHARS = 200
MAX_SKILL_KEY_CHARS = 120
MAX_SKILL_PROMPT_CHARS = 1600

ALLOWED_SETTINGS_KEYS = {
    "profile_enabled",
    "default_model",
    "search_enabled",
    "memory_enabled",
    "skills_enabled",
}
|
|
||||||
|
# --- Perplexity ---
|
||||||
|
PERPLEXITY_THRESHOLD = 15.0
|
||||||
|
|
||||||
|
# --- Refusal / hedge patterns ---
|
||||||
|
REFUSAL_PATTERNS = re.compile(
|
||||||
|
r"|".join([
|
||||||
|
r"i don'?t have (?:real-?time|current|live)",
|
||||||
|
r"i (?:can'?t|cannot) provide (?:current|real-?time|live)",
|
||||||
|
r"i don'?t have access to (?:current|real-?time|live)",
|
||||||
|
r"(?:current|live|real-?time) (?:data|information|prices?|weather)",
|
||||||
|
r"my (?:knowledge|training) (?:cutoff|only goes|ends)",
|
||||||
|
r"as of my (?:knowledge|training) cutoff",
|
||||||
|
r"i'?m not able to (?:access|provide|browse)",
|
||||||
|
r"(?:check|visit|use) a (?:website|financial|news)",
|
||||||
|
r"as an ai model",
|
||||||
|
r"based on my training data",
|
||||||
|
r"i don'?t have the capability",
|
||||||
|
]),
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
HEDGE_PATTERNS = [
|
||||||
|
r"^I'?m sorry,?\s*but\s*I\s*(?:can'?t|cannot)\s*assist\s*with\s*that[^.]*\.\s*",
|
||||||
|
r"^I'?m sorry,?\s*but[^.]*(?:previous|incorrect)[^.]*\.\s*",
|
||||||
|
r"(?:But\s+)?[Pp]lease\s+(?:make\s+sure\s+to\s+)?verify\s+(?:the\s+)?(?:data|information|this)\s+(?:from\s+)?(?:reliable\s+)?sources[^.]*\.\s*",
|
||||||
|
r"[Pp]lease\s+verify[^.]*(?:accurate|reliability)[^.]*\.\s*",
|
||||||
|
r"[Bb]ut\s+please\s+(?:make\s+sure|verify|check)[^.]*\.\s*",
|
||||||
|
]
|
||||||
|
|
||||||
# --- Built-in skills registry ---
BUILTIN_SKILLS = [
    {"key": "memory.search", "name": "Memory Search", "category": "memory", "risk": "low", "description": "Search stored memory facts relevant to the current prompt."},
    {"key": "memory.add", "name": "Memory Add", "category": "memory", "risk": "medium", "description": "Store a new memory fact with topic tagging."},
    {"key": "memory.forget", "name": "Memory Forget", "category": "memory", "risk": "high", "description": "Delete matching memories when asked to forget information."},
    {"key": "conversation.list", "name": "Conversation List", "category": "conversation", "risk": "low", "description": "List existing conversations with metadata."},
    {"key": "conversation.get", "name": "Conversation Get", "category": "conversation", "risk": "low", "description": "Read a conversation and its message history."},
    {"key": "conversation.delete", "name": "Conversation Delete", "category": "conversation", "risk": "high", "description": "Delete a single conversation thread."},
    {"key": "conversation.delete_all", "name": "Conversation Delete All", "category": "conversation", "risk": "high", "description": "Delete all conversations and messages."},
    {"key": "search.web", "name": "Web Search", "category": "search", "risk": "low", "description": "Run explicit web search and summarize results."},
    {"key": "settings.get", "name": "Settings Get", "category": "settings", "risk": "low", "description": "Read current runtime settings."},
    {"key": "settings.update", "name": "Settings Update", "category": "settings", "risk": "high", "description": "Update allowlisted runtime settings keys."},
]

SKILLS_BY_KEY = {s["key"]: s for s in BUILTIN_SKILLS}


def parse_allowed_cidrs(raw: str) -> list:
    networks = []
    for entry in (raw or "").split(","):
        value = entry.strip()
        if not value:
            continue
        try:
            networks.append(ipaddress.ip_network(value, strict=False))
        except ValueError:
            log.warning(f"Invalid CIDR ignored: {value}")
    return networks


ALLOWED_NETWORKS = parse_allowed_cidrs(ALLOWED_CIDRS_RAW)

DEFAULT_PROFILE = """You are a coding companion running locally on a machine called "jarvis".

## Environment
- jarvis: Debian 13 (trixie) x86_64, AMD Ryzen 5 5600X, 16GB RAM, AMD RX 6600 XT (8GB VRAM)
- ultron: Debian 13, Ryzen 7 7840HS, 16GB RAM, primary AI inference node, IP 192.168.50.108
- Corsair: Windows 11, gaming/streaming rig, RTX 5070 Ti
- pivault: RPi 5, 8GB RAM, Debian 13, 11TB RAID5 NAS at /mnt/pivault, IP 192.168.50.158
- Router: ASUS ROG Rapture GT-BE98 Pro "BigBlinkyRouter" at 192.168.50.1
- llama-server on ultron:8081 (OpenAI-compat API), Qdrant on ultron:6333

## About the User
- Experienced developer, BS in Computer Science (Oklahoma State), coding since 1981 (TRS-80)
- Deep Unix/Linux background — wrote device drivers at SCO during Xenix era (1990s)
- Currently learning Rust, transitioning from decades of PHP
- Building a WW2 mobile game in Godot Engine for Android
- Veteran on fixed income — prefers free/open-source solutions
- Home lab enthusiast with Zigbee, Z-Wave and Tapo smart home devices

## How to Respond
- Be direct and concise — no hand-holding, this user knows what they're doing
- When showing code, prefer complete working examples over snippets
- Default to command-line solutions over GUI when possible
- Consider resource constraints (fixed income, specific hardware limits)
- Use Rust, Python, or bash unless another language is specifically needed
- Explain trade-offs when multiple approaches exist"""

DEFAULT_PRESETS = [
    {"name": "Coding Companion", "prompt": "You are a senior software engineer and coding companion. Focus on writing clean, efficient, well-documented code. Provide complete working examples. Explain architectural decisions and trade-offs. Prefer Rust, Python, and bash."},
    {"name": "Linux Sysadmin", "prompt": "You are an experienced Linux systems administrator. Focus on command-line solutions, systemd services, networking, storage, and security. Prefer Debian/Ubuntu conventions. Be concise and direct."},
    {"name": "General Assistant", "prompt": "You are a helpful general-purpose assistant. Be clear and concise."},
]
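The CIDR allowlist in `config.py` is only parsed there; how a request's client address is checked against `ALLOWED_NETWORKS` is not shown in this diff. A minimal sketch of one way that check could look, assuming a hypothetical helper named `is_client_allowed` (not part of the commit) and reusing the default CIDR string from above:

```python
import ipaddress

# Copied from DEFAULT_ALLOWED_CIDRS in config.py.
DEFAULT_ALLOWED_CIDRS = "127.0.0.0/8,::1/128,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16"


def parse_allowed_cidrs(raw: str) -> list:
    """Parse a comma-separated CIDR string, skipping blank or invalid entries."""
    networks = []
    for entry in (raw or "").split(","):
        value = entry.strip()
        if not value:
            continue
        try:
            networks.append(ipaddress.ip_network(value, strict=False))
        except ValueError:
            continue  # config.py logs a warning at this point
    return networks


def is_client_allowed(client_ip: str, networks: list) -> bool:
    """Hypothetical helper: True when client_ip falls inside any allowed network."""
    try:
        addr = ipaddress.ip_address(client_ip)
    except ValueError:
        return False  # unparseable addresses are rejected
    # Compare only against networks of the same IP version.
    return any(addr.version == net.version and addr in net for net in networks)


networks = parse_allowed_cidrs(DEFAULT_ALLOWED_CIDRS)
print(is_client_allowed("192.168.50.108", networks))
print(is_client_allowed("8.8.8.8", networks))
```

Rejecting unparseable addresses outright keeps the gate fail-closed, which seems consistent with the local/LAN-only default described in the README.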
160
db.py
Normal file
@@ -0,0 +1,160 @@
"""
JarvisChat - Database layer.
Schema init, connection factory, settings helpers, skill state management.
"""
import logging
import os
import re
import sqlite3
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from config import (
    BUILTIN_SKILLS, DEFAULT_MODEL, DEFAULT_PRESETS, DEFAULT_PROFILE,
    MAX_SKILL_PROMPT_CHARS, ALLOWED_NETWORKS,
)

log = logging.getLogger("jarvischat")

BASE_DIR = Path(__file__).parent
DB_PATH = BASE_DIR / "jarvischat.db"


def get_db():
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA foreign_keys = ON")
    return conn


def get_setting(db, key: str, default: str = "") -> str:
    row = db.execute("SELECT value FROM settings WHERE key = ?", (key,)).fetchone()
    return row["value"] if row else default


def list_skills_with_state(db) -> list:
    rows = db.execute("SELECT skill_key, enabled, updated_at FROM skills").fetchall()
    state_by_key = {
        row["skill_key"]: {"enabled": bool(row["enabled"]), "updated_at": row["updated_at"]}
        for row in rows
    }
    merged = []
    for skill in BUILTIN_SKILLS:
        state = state_by_key.get(skill["key"], {"enabled": True, "updated_at": ""})
        merged.append({**skill, "enabled": state["enabled"], "updated_at": state["updated_at"]})
    return sorted(merged, key=lambda s: (s["category"], s["name"]))


def set_skill_enabled(db, skill_key: str, enabled: bool) -> None:
    now = datetime.now(timezone.utc).isoformat()
    db.execute(
        "INSERT OR REPLACE INTO skills (skill_key, enabled, updated_at) VALUES (?, ?, ?)",
        (skill_key, 1 if enabled else 0, now),
    )


def format_active_skills_prompt(skills: list) -> str:
    lines = [
        "## Active Skills",
        "Use these skills only when needed. Prefer concise answers over unnecessary tool usage.",
    ]
    for skill in skills:
        lines.append(f"- {skill['key']} ({skill['risk']} risk): {skill['description']}")
    text = "\n".join(lines)
    if len(text) > MAX_SKILL_PROMPT_CHARS:
        return text[:MAX_SKILL_PROMPT_CHARS - 3] + "..."
    return text


def init_db():
    from security import hash_pin
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row

    conn.execute("""
        CREATE TABLE IF NOT EXISTS conversations (
            id TEXT PRIMARY KEY, title TEXT NOT NULL DEFAULT 'New Chat',
            model TEXT NOT NULL, created_at TEXT NOT NULL, updated_at TEXT NOT NULL
        )
    """)
    conn.execute("""
        CREATE TABLE IF NOT EXISTS messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT, conversation_id TEXT NOT NULL,
            role TEXT NOT NULL, content TEXT NOT NULL, created_at TEXT NOT NULL,
            FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
        )
    """)
    conn.execute("""
        CREATE TABLE IF NOT EXISTS system_presets (
            id TEXT PRIMARY KEY, name TEXT NOT NULL, prompt TEXT NOT NULL,
            is_default INTEGER NOT NULL DEFAULT 0, created_at TEXT NOT NULL
        )
    """)
    conn.execute("""
        CREATE TABLE IF NOT EXISTS profile (
            id INTEGER PRIMARY KEY CHECK (id = 1), content TEXT NOT NULL, updated_at TEXT NOT NULL
        )
    """)
    conn.execute("CREATE TABLE IF NOT EXISTS settings (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
    conn.execute("""
        CREATE TABLE IF NOT EXISTS skills (
            skill_key TEXT PRIMARY KEY, enabled INTEGER NOT NULL DEFAULT 1, updated_at TEXT NOT NULL
        )
    """)
    conn.execute("""
        CREATE VIRTUAL TABLE IF NOT EXISTS memories USING fts5(
            fact, topic, source, created_at UNINDEXED
        )
    """)

    if not conn.execute("SELECT id FROM profile WHERE id = 1").fetchone():
        now = datetime.now(timezone.utc).isoformat()
        conn.execute("INSERT INTO profile (id, content, updated_at) VALUES (1, ?, ?)", (DEFAULT_PROFILE, now))

    if conn.execute("SELECT COUNT(*) as c FROM system_presets").fetchone()["c"] == 0:
        now = datetime.now(timezone.utc).isoformat()
        for preset in DEFAULT_PRESETS:
            conn.execute(
                "INSERT INTO system_presets (id, name, prompt, is_default, created_at) VALUES (?, ?, ?, 1, ?)",
                (str(uuid.uuid4()), preset["name"], preset["prompt"], now),
            )

    defaults = {
        "profile_enabled": "true", "default_model": DEFAULT_MODEL,
        "search_enabled": "true", "memory_enabled": "true", "skills_enabled": "true",
    }
    for key, value in defaults.items():
        if not conn.execute("SELECT key FROM settings WHERE key = ?", (key,)).fetchone():
            conn.execute("INSERT INTO settings (key, value) VALUES (?, ?)", (key, value))

    now = datetime.now(timezone.utc).isoformat()
    for skill in BUILTIN_SKILLS:
        if not conn.execute("SELECT skill_key FROM skills WHERE skill_key = ?", (skill["key"],)).fetchone():
            conn.execute("INSERT INTO skills (skill_key, enabled, updated_at) VALUES (?, 1, ?)", (skill["key"], now))

    existing_pin_hash = conn.execute("SELECT value FROM settings WHERE key = 'admin_pin_hash'").fetchone()
    existing_pin_salt = conn.execute("SELECT value FROM settings WHERE key = 'admin_pin_salt'").fetchone()
    if not existing_pin_hash or not existing_pin_salt:
        from config import ALLOW_DEFAULT_PIN
        configured_pin = os.getenv("JARVISCHAT_ADMIN_PIN", "").strip()
        if re.fullmatch(r"\d{4}", configured_pin):
            seed_pin, pin_source = configured_pin, "env"
        elif ALLOW_DEFAULT_PIN:
            seed_pin, pin_source = "1234", "default"
        else:
            raise RuntimeError(
                "Admin PIN bootstrap blocked: set JARVISCHAT_ADMIN_PIN to a 4-digit PIN "
                "or set JARVISCHAT_ALLOW_DEFAULT_PIN=true."
            )
        salt_hex, pin_hash_hex = hash_pin(seed_pin)
        conn.execute("INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)", ("admin_pin_hash", pin_hash_hex))
        conn.execute("INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)", ("admin_pin_salt", salt_hex))
        if pin_source == "default":
            log.warning("Admin PIN seeded from insecure default 1234 (override enabled).")
        else:
            log.info("Admin PIN hash seeded from configured environment PIN.")

    conn.commit()
    conn.close()
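The first-boot PIN policy at the end of `init_db()` can be read as a small decision function. The sketch below isolates just that decision so it can be exercised without a database; `choose_seed_pin` is a hypothetical name introduced for illustration, and hashing (done by `hash_pin` in `security.py`, which this diff does not show) is deliberately left out:

```python
import re
from typing import Tuple


def choose_seed_pin(configured_pin: str, allow_default: bool) -> Tuple[str, str]:
    """Hypothetical helper mirroring the PIN bootstrap branches in init_db().

    Returns (pin, source) where source is "env" or "default".
    """
    pin = (configured_pin or "").strip()
    if re.fullmatch(r"\d{4}", pin):
        # A valid 4-digit PIN from the environment always wins.
        return pin, "env"
    if allow_default:
        # Insecure fallback, only when explicitly enabled.
        return "1234", "default"
    raise RuntimeError(
        "Admin PIN bootstrap blocked: set JARVISCHAT_ADMIN_PIN to a 4-digit PIN "
        "or set JARVISCHAT_ALLOW_DEFAULT_PIN=true."
    )


print(choose_seed_pin("4821", allow_default=False))
print(choose_seed_pin("", allow_default=True))
```

Note that an invalid configured PIN (for example five digits) falls through to the default branch when the override is enabled, which matches the branch order in `init_db()`.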
31
gpu.py
Normal file
@@ -0,0 +1,31 @@
"""
JarvisChat - AMD GPU stats via rocm-smi.
"""
import json
import logging
import subprocess

log = logging.getLogger("jarvischat")


def get_gpu_stats() -> dict:
    try:
        result = subprocess.run(
            ["rocm-smi", "--showuse", "--showmemuse", "--json"],
            capture_output=True, text=True, timeout=5,
        )
        if result.returncode == 0:
            data = json.loads(result.stdout)
            gpu_info = data.get("card0", {})
            gpu_use = gpu_info.get("GPU use (%)", 0)
            vram_use = gpu_info.get("GPU Memory Allocated (VRAM%)", 0)
            if isinstance(gpu_use, str):
                gpu_use = int(gpu_use.replace("%", "").strip() or 0)
            if isinstance(vram_use, str):
                vram_use = int(vram_use.replace("%", "").strip() or 0)
            return {"gpu_percent": gpu_use, "vram_percent": vram_use, "available": True}
    except (subprocess.TimeoutExpired, FileNotFoundError, json.JSONDecodeError):
        pass
    except Exception as e:
        log.warning(f"GPU stats error: {e}")
    return {"gpu_percent": 0, "vram_percent": 0, "available": False}
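`get_gpu_stats()` converts the percentage fields with `int(...)`, so a fractional or non-numeric string would raise and be reported as "unavailable" through the generic exception branch. A possible, more tolerant parsing step is sketched below. The helper names `parse_percent` and `stats_from_json` and the sample payload are assumptions for illustration; only the two JSON keys are taken from `gpu.py`, and real `rocm-smi` output may differ between versions:

```python
import json


def parse_percent(value) -> int:
    """Hypothetical helper: coerce '37', '12.6 %', 37 or 12.6 to an int, else 0."""
    if isinstance(value, (int, float)):
        return int(value)
    text = str(value).replace("%", "").strip()
    if not text:
        return 0
    try:
        return int(float(text))  # tolerate fractional strings
    except ValueError:
        return 0  # e.g. "N/A"


def stats_from_json(raw: str) -> dict:
    """Extract GPU and VRAM percentages for card0 from a rocm-smi style JSON string."""
    info = json.loads(raw).get("card0", {})
    return {
        "gpu_percent": parse_percent(info.get("GPU use (%)", 0)),
        "vram_percent": parse_percent(info.get("GPU Memory Allocated (VRAM%)", 0)),
        "available": True,
    }


# Assumed sample payload, shaped after the keys read in gpu.py.
sample = '{"card0": {"GPU use (%)": "37", "GPU Memory Allocated (VRAM%)": "12.6"}}'
print(stats_from_json(sample))
```

Keeping the parsing separate from the `subprocess.run` call also makes it testable on machines without an AMD GPU.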
2174
jarvischat_refactor.sh
Normal file
File diff suppressed because it is too large
145
memory.py
Normal file
@@ -0,0 +1,145 @@
"""
JarvisChat - FTS5 memory system.
CRUD, search, remember/forget command processing, topic detection.
"""
import logging
import re
from datetime import datetime, timezone
from typing import Optional

from db import get_db
from config import MAX_MEMORY_FACT_CHARS

log = logging.getLogger("jarvischat")

REMEMBER_PATTERNS = [
    (r"remember that (.+)", "explicit"),
    (r"please remember (.+)", "explicit"),
    (r"don'?t forget (.+)", "explicit"),
    (r"note that (.+)", "explicit"),
    (r"keep in mind (?:that )?(.+)", "explicit"),
]

FORGET_PATTERNS = [
    r"forget (?:that )?(.+)",
    r"don'?t remember (.+)",
    r"remove (?:the )?memory (?:about |that )?(.+)",
]


def detect_topic(fact: str) -> str:
    fact_lower = fact.lower()
    if any(w in fact_lower for w in ["prefer", "like", "hate", "always", "never", "favorite"]):
        return "preference"
    elif any(w in fact_lower for w in ["working on", "building", "project", "developing"]):
        return "project"
    elif any(w in fact_lower for w in ["run", "install", "server", "ip", "port", "service", "docker", "systemd"]):
        return "infrastructure"
    elif any(w in fact_lower for w in ["my name", "i am", "i'm a", "i live", "my wife", "my partner"]):
        return "personal"
    return "general"


def add_memory(fact: str, topic: str = "general", source: str = "explicit") -> Optional[int]:
    db = get_db()
    now = datetime.now(timezone.utc).isoformat()
    cur = db.execute(
        "INSERT INTO memories (fact, topic, source, created_at) VALUES (?, ?, ?, ?)",
        (fact, topic, source, now),
    )
    db.commit()
    rowid = cur.lastrowid
    db.close()
    log.info(f"Memory added [{topic}]: {fact[:50]}...")
    return rowid


def search_memories(query: str, limit: int = 5) -> list:
    if not query.strip():
        return []
    db = get_db()
    words = re.findall(r"[A-Za-z0-9_]+", query)
    if not words:
        db.close()
        return []
    escaped = []
    for word in words[:10]:
        if word.upper() in {"AND", "OR", "NOT", "NEAR"}:
            escaped.append(f'"{word}"*')
        else:
            escaped.append(word + "*")
    safe_query = " OR ".join(escaped)
    try:
        rows = db.execute(
            "SELECT rowid, fact, topic, source, created_at, bm25(memories) AS rank "
            "FROM memories WHERE memories MATCH ? ORDER BY rank LIMIT ?",
            (safe_query, limit),
        ).fetchall()
        results = [dict(row) for row in rows]
        log.debug(f"Memory search '{query}' returned {len(results)} results")
    except Exception as e:
        log.warning(f"Memory search error: {e}")
        results = []
    db.close()
    return results


def get_all_memories(topic: Optional[str] = None) -> list:
    db = get_db()
    if topic:
        rows = db.execute(
            "SELECT rowid, * FROM memories WHERE topic = ? ORDER BY created_at DESC", (topic,)
        ).fetchall()
    else:
        rows = db.execute("SELECT rowid, * FROM memories ORDER BY created_at DESC").fetchall()
    db.close()
    return [dict(row) for row in rows]


def delete_memory(rowid: int) -> bool:
    db = get_db()
    cur = db.execute("DELETE FROM memories WHERE rowid = ?", (rowid,))
    db.commit()
    deleted = cur.rowcount > 0
    db.close()
    if deleted:
        log.info(f"Memory deleted: rowid={rowid}")
    return deleted


def update_memory(rowid: int, fact: str) -> bool:
    db = get_db()
    cur = db.execute("UPDATE memories SET fact = ? WHERE rowid = ?", (fact, rowid))
    db.commit()
    updated = cur.rowcount > 0
    db.close()
    return updated


def get_memory_count() -> int:
    db = get_db()
    count = db.execute("SELECT COUNT(*) as c FROM memories").fetchone()["c"]
    db.close()
    return count


def process_remember_command(user_message: str) -> Optional[str]:
    for pattern, source in REMEMBER_PATTERNS:
        match = re.search(pattern, user_message, re.IGNORECASE)
        if match:
            fact = match.group(1).strip().rstrip(".")
            topic = detect_topic(fact)
            add_memory(fact, topic=topic, source=source)
            return f"✓ Remembered [{topic}]: {fact}"
    for pattern in FORGET_PATTERNS:
        match = re.search(pattern, user_message, re.IGNORECASE)
        if match:
            search_term = match.group(1).strip().rstrip(".")
            memories = search_memories(search_term, limit=3)
            if memories:
                for m in memories:
                    delete_memory(m["rowid"])
                return f"✓ Forgot {len(memories)} memory/memories about: {search_term}"
            else:
                return f"✗ No memories found about: {search_term}"
    return None
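The query construction inside `search_memories()` is the part most sensitive to punctuation and FTS5 keywords, so it may help to look at it in isolation. The sketch below extracts the same tokenization and quoting steps into a standalone function; `build_fts_query` and `FTS5_KEYWORDS` are hypothetical names, not identifiers from the commit:

```python
import re

# Bare words that FTS5 would otherwise parse as query operators.
FTS5_KEYWORDS = {"AND", "OR", "NOT", "NEAR"}


def build_fts_query(query: str, max_terms: int = 10) -> str:
    """Hypothetical helper: turn free text into an OR-joined FTS5 prefix query."""
    # Keep only alphanumeric/underscore runs, so punctuation never reaches MATCH.
    words = re.findall(r"[A-Za-z0-9_]+", query)
    terms = []
    for word in words[:max_terms]:
        if word.upper() in FTS5_KEYWORDS:
            terms.append(f'"{word}"*')  # quote operator look-alikes
        else:
            terms.append(word + "*")    # prefix match on ordinary words
    return " OR ".join(terms)


print(build_fts_query("What's my router IP?"))
print(build_fts_query("cats and dogs"))
```

Because the result is passed as a bound parameter to `MATCH ?`, the quoting here guards against FTS5 syntax errors rather than SQL injection. An empty string comes back when the input has no usable tokens, and the caller is expected to skip the search in that case.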
80
rag.py
Normal file
@@ -0,0 +1,80 @@
"""
JarvisChat - RAG pipeline: Qdrant vector search + system prompt assembly.
"""
import logging

import httpx

from db import get_db, get_setting, list_skills_with_state, format_active_skills_prompt
from memory import search_memories
from config import MAX_SKILL_PROMPT_CHARS

log = logging.getLogger("jarvischat")

QDRANT_URL = "http://192.168.50.108:6333"
EMBED_URL = "http://192.168.50.210:11434"
EMBED_MODEL = "mxbai-embed-large"
RAG_COLLECTION = "jarvis_rag"
RAG_SCORE_THRESHOLD = 0.25


async def query_rag(query: str, limit: int = 3) -> list:
    try:
        async with httpx.AsyncClient() as client:
            embed_resp = await client.post(
                f"{EMBED_URL}/api/embeddings",
                json={"model": EMBED_MODEL, "prompt": query},
                timeout=10.0,
            )
            if embed_resp.status_code != 200:
                return []
            vector = embed_resp.json()["embedding"]
            search_resp = await client.post(
                f"{QDRANT_URL}/collections/{RAG_COLLECTION}/points/search",
                json={"vector": vector, "limit": limit, "with_payload": True},
                timeout=10.0,
            )
            if search_resp.status_code != 200:
                return []
            return search_resp.json().get("result", [])
    except Exception as e:
        log.warning(f"RAG query error: {e}")
        return []


async def build_system_prompt(db, extra_prompt: str = "", user_message: str = "") -> str:
    parts = []
    settings = {row["key"]: row["value"] for row in db.execute("SELECT key, value FROM settings").fetchall()}

    if settings.get("profile_enabled", "true") == "true":
        profile = db.execute("SELECT content FROM profile WHERE id = 1").fetchone()
        if profile and profile["content"].strip():
            parts.append(profile["content"].strip())

    if settings.get("memory_enabled", "true") == "true" and user_message:
        memories = search_memories(user_message, limit=5)
        if memories:
            memory_lines = [f"- {m['fact']}" for m in memories]
            parts.append("## Relevant Context from Memory\n" + "\n".join(memory_lines))
            log.debug(f"Injected {len(memories)} memories into context")

    if user_message:
        try:
            rag_results = await query_rag(user_message)
            if rag_results:
                rag_lines = [r["payload"]["text"] for r in rag_results if r["score"] > RAG_SCORE_THRESHOLD]
                if rag_lines:
                    parts.append("## Retrieved Context\n" + "\n\n---\n\n".join(rag_lines))
                    log.info(f"RAG injected {len(rag_lines)} chunks into context")
        except Exception as e:
            log.warning(f"RAG injection error: {e}")

    if settings.get("skills_enabled", "true") == "true":
        active_skills = [s for s in list_skills_with_state(db) if s["enabled"]]
        if active_skills:
            parts.append(format_active_skills_prompt(active_skills))

    if extra_prompt and extra_prompt.strip():
        parts.append(extra_prompt.strip())

    return "\n\n---\n\n".join(parts) if parts else ""
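`build_system_prompt()` mixes I/O (SQLite, Qdrant, the embedding endpoint) with two pure steps: filtering RAG hits by score and joining the prompt sections. A rough sketch of those two steps on their own, assuming hypothetical helper names `select_rag_chunks` and `join_prompt_parts` and made-up sample hits shaped like the `result` entries the code reads (`score` plus `payload.text`):

```python
RAG_SCORE_THRESHOLD = 0.25  # same value as in rag.py
SECTION_SEPARATOR = "\n\n---\n\n"


def select_rag_chunks(results: list, threshold: float = RAG_SCORE_THRESHOLD) -> list:
    """Hypothetical helper: keep payload text of hits scoring above the threshold."""
    chunks = []
    for hit in results:
        text = (hit.get("payload") or {}).get("text")
        if text and hit.get("score", 0.0) > threshold:
            chunks.append(text)
    return chunks


def join_prompt_parts(parts: list) -> str:
    """Join non-empty prompt sections with the separator used in rag.py."""
    return SECTION_SEPARATOR.join(p.strip() for p in parts if p and p.strip())


# Made-up sample hits for illustration only.
hits = [
    {"score": 0.81, "payload": {"text": "ultron runs llama-server on port 8081."}},
    {"score": 0.10, "payload": {"text": "Unrelated chunk."}},
    {"score": 0.40, "payload": {}},
]
chunks = select_rag_chunks(hits)
print(join_prompt_parts(["Profile text", "## Retrieved Context\n" + "\n".join(chunks), ""]))
```

Unlike the original list comprehension, this version skips hits that lack `payload.text` instead of raising, which is a design choice of the sketch rather than behavior of the commit.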
355
readme.md
@@ -1,355 +0,0 @@
# ⚡ JarvisChat v1.7.8

![JarvisChat Logo](static/logo.png)

**A lightweight Ollama coding companion with persistent memory, web search, and real-time system monitoring.**

Built with FastAPI + SQLite + Jinja2. Runs on Python 3.13. No Docker required.

Developer wiki: [docs/wiki/Home.md](docs/wiki/Home.md)

Core architecture deep-dive: [docs/wiki/Developer-Architecture.md](docs/wiki/Developer-Architecture.md)

## Security Scope Disclaimer

JarvisChat is designed for local and home-lab use (same host or trusted LAN).

JarvisChat may technically work with frontier or commercial AI endpoints, but the author does not recommend or support that usage.

Supported deployments are contained local/home-lab environments.

By default, API access is limited to loopback + private LAN CIDRs. You can override with `JARVISCHAT_ALLOWED_CIDRS` (comma-separated CIDRs) and optionally trust reverse-proxy forwarding with `JARVISCHAT_TRUST_X_FORWARDED_FOR=true`.

If you deploy outside a trusted local subnet, your risk profile changes significantly and the default protections here may be insufficient.

Use at your own risk. No warranty is provided for Internet-exposed deployments.

## What's New in v1.7.x

- **Security hardening suite completed** - request rate limits, payload caps, settings allowlist, safe error envelopes, and LAN CIDR gate controls
- **Customer-safe incident handling** - client-facing errors include support-friendly incident keys while full traces remain in server logs
- **Streaming and regression test expansion** - automated coverage for SSE chat/search paths, memory remember/forget command handling, and auth/guardrail behavior
- **Skills framework (Phase 1)** - built-in local skill registry with per-skill enable controls, API endpoints, and bounded prompt injection
- **Skills WebUX controls** - Settings modal now includes a master skills toggle and per-skill toggles for admin users

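The request rate limits mentioned in the list above are configured through `RATE_WINDOW_SECONDS` and the `RL_*_PER_WINDOW` constants, but the limiter itself is not part of this excerpt. One common way to enforce such limits is a per-key sliding window; the sketch below shows that technique under the assumption that it resembles what JarvisChat does. The class name `SlidingWindowLimiter` and the key format are hypothetical:

```python
import time
from collections import defaultdict, deque
from typing import Optional

RATE_WINDOW_SECONDS = 60   # same value as in config.py
RL_LOGIN_PER_WINDOW = 10   # same value as in config.py


class SlidingWindowLimiter:
    """Hypothetical sliding-window limiter: at most `limit` events per `window` seconds per key."""

    def __init__(self, limit: int, window: float = RATE_WINDOW_SECONDS):
        self.limit = limit
        self.window = window
        self._events = defaultdict(deque)  # key -> timestamps of accepted events

    def allow(self, key: str, now: Optional[float] = None) -> bool:
        now = time.monotonic() if now is None else now
        events = self._events[key]
        # Drop timestamps that have aged out of the window.
        while events and now - events[0] >= self.window:
            events.popleft()
        if len(events) >= self.limit:
            return False
        events.append(now)
        return True


login_limiter = SlidingWindowLimiter(limit=RL_LOGIN_PER_WINDOW)
print(login_limiter.allow("login:192.168.50.20"))
```

Passing `now` explicitly keeps the limiter deterministic in tests; in a request handler the default monotonic clock would normally be used.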
## What's New in v1.6.x

- **Guest/admin capability split** - guest chat by default with 4-digit admin PIN for advanced or destructive operations
- **Session + lockout controls** - session lifecycle endpoints, heartbeat, logout/revoke behavior, failed PIN lockout protections, and auth audit events
- **Browser request protections** - strict origin checks for state-changing requests and admin-only write enforcement
- **Unsafe link protection** - outbound search links sanitized to allow only http/https absolute URLs
- **Operational stability fixes** - safer first-boot PIN policy handling and memory-search tokenization fix for punctuation/FTS edge cases

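The unsafe link protection described above (only absolute http/https URLs survive) can be approximated with the standard library. This is a minimal sketch of such a check, not the project's actual sanitizer; `is_safe_link` is a hypothetical name:

```python
from urllib.parse import urlsplit


def is_safe_link(url: str) -> bool:
    """Hypothetical check: accept only absolute http/https URLs with a host."""
    try:
        parts = urlsplit((url or "").strip())
    except ValueError:
        return False  # malformed URLs are rejected
    # urlsplit lowercases the scheme, so "HTTP://..." is accepted as http.
    return parts.scheme in ("http", "https") and bool(parts.netloc)


for candidate in ["https://example.com/a", "javascript:alert(1)", "//example.com/x", "/relative"]:
    print(candidate, is_safe_link(candidate))
```

Requiring a non-empty host rejects scheme-relative and path-only links as well as `javascript:` and `data:` URLs, which matches the "absolute URLs only" wording.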
## What's New in v1.5.0

- **Explicit Web Search Button** — 🔍 button next to SEND forces a web search, bypassing model uncertainty detection
- **Orange Search Styling** — Search results, WEB badge, and search button share consistent orange color scheme
- **Expanded Refusal Patterns** — Added "As an AI model", "based on my training data", "I don't have the capability"
- **Code cleanup** — Removed unused `JSONResponse` import and dead `raw_results_md` variable
- **Bug fixes** — Replaced bare `except` clauses with `except Exception`; corrected `add_memory()` return type to `int | None`; updated `TemplateResponse` call to Starlette's current API signature

## What's New in v1.4.0

- **FTS5 Memory System**: Say "remember that..." to store facts — they're automatically retrieved by relevance and injected into context
- **Forget Command**: Say "forget about..." to remove memories
- **Memory Toggle**: Enable/disable memory injection from topbar or settings
- **Multi-file Structure**: Backend and frontend separated for easier maintenance

## Features

- **Persistent Memory** — SQLite FTS5 full-text search for fast, relevant memory retrieval
- **Web Search** — SearXNG integration for automatic web lookups when the model is uncertain
- **Explicit Search** — 🔍 button to force web search without waiting for model uncertainty
- **Profile Injection** — Custom system prompt injected into every conversation
- **System Presets** — Save and switch between different system prompts
- **Real-time Stats** — CPU, RAM, GPU, VRAM monitoring in sidebar
- **Token Thermometer** — Visual context window usage indicator
- **Streaming Responses** — Server-sent events for real-time token display
- **Conversation History** — SQLite-backed chat persistence with mass-delete option
- **Model Switching** — Change Ollama models on the fly

## Current WiP (Prioritized)

Canonical backlog: [docs/wiki/current-wip.md](docs/wiki/current-wip.md)

Scope boundary: local-first (same-host Ollama), optional RFC1918 LAN endpoints, no public Internet AI endpoints by default.

Total identified items: 27

Top 10 (brief):

1. P0 [DONE]: Add auth for write/admin endpoints
2. P0 [DONE]: Add CSRF/origin protection for state-changing requests
3. P0 [DONE]: Block unsafe URL schemes in rendered links
4. P0 [DONE]: Add rate limiting and request size limits
5. P1 [DONE]: Restrict `/api/settings` updates to allowlisted keys
6. P1: Add pagination + hard caps for list APIs
7. P1 [DONE]: Replace raw exception leakage with safe client errors
8. P1 [DONE]: Add automated tests for streaming/search/memory paths
9. P2 [DONE]: Implement MCP-style skills/tool-call framework
10. P2: Implement heartbeat/check-in scheduler + summary endpoint

Item 1 executive summary: keep guest mode for conversational chat, require 4-digit admin PIN for advanced/destructive actions, and enforce local/LAN-only backend policy by default.

Implementation status: complete (guest session by default + admin unlock + admin-only write enforcement + origin checks + safe-link sanitization + audit logging + rate/payload guardrails + capability tests).

## TODO

1. ~~Verify SearXNG and Docker services persist across reboots~~
2. Conversation search/filter by keyword
3. Export conversation to markdown/text
4. Keyboard shortcuts (Ctrl+N new chat, Ctrl+Enter send)
5. Retry button on assistant messages
6. Source links — clickable links when search used
7. Allow conversation renaming
8. Multiple profiles — coding/sysadmin/general
9. Auto-generate conversation tags (client-side KWIC, top 5, filterable badges)
10. Image input support — pull vision model, file input/drag-drop, base64 encode, pass `images` array to Ollama `/api/chat`
11. Split-screen option for btop display
12. Skills as markdown files — `/opt/jarvischat/skills/`, YAML frontmatter + instructions, injected into context for tool calls
13. Heartbeats / proactive check-ins — cron + endpoint for daily briefings, HA anomaly alerts
14. Model info button — (i) icon next to Model dropdown, shows div with model description, last updated date, best-use purpose
15. Set default model — toggle any model as the default selection
16. Hide/remove model from list — exclude models from dropdown
17. Update model function — trigger `ollama pull` for selected model from UI
18. Add mouseover tooltip to SEND button
19. Add preflight validation for required model/preset selection and show a clear warning before send to prevent avoidable timeout loops

## File Structure
|
|
||||||
|
|
||||||
```
/opt/jarvischat/
├── app.py              # FastAPI backend
├── jarvischat.db       # SQLite database (auto-created)
├── static/
│   └── logo.png        # Logo image (optional)
└── templates/
    └── index.html      # Frontend
```
|
|
||||||
|
|
||||||
## Requirements
|
|
||||||
|
|
||||||
- Python 3.11+ (tested on 3.13)
|
|
||||||
- Ollama running locally or on network
|
|
||||||
- SearXNG (optional, for web search)
|
|
||||||
|
|
||||||
## Installation
|
|
||||||
|
|
||||||
### Fresh Install
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Create directory and venv
|
|
||||||
sudo mkdir -p /opt/jarvischat
|
|
||||||
sudo chown $USER:$USER /opt/jarvischat
|
|
||||||
cd /opt/jarvischat
|
|
||||||
python3 -m venv venv
|
|
||||||
|
|
||||||
# Install dependencies
|
|
||||||
./venv/bin/pip install fastapi uvicorn httpx psutil jinja2 python-multipart
|
|
||||||
|
|
||||||
# Set admin PIN before first startup (4 digits)
|
|
||||||
export JARVISCHAT_ADMIN_PIN=4827
|
|
||||||
|
|
||||||
# Create subdirectories
|
|
||||||
mkdir -p templates static
|
|
||||||
|
|
||||||
# Copy files
|
|
||||||
# (copy app.py to /opt/jarvischat/)
|
|
||||||
# (copy index.html to /opt/jarvischat/templates/)
|
|
||||||
# (copy logo.png to /opt/jarvischat/static/ — optional)
|
|
||||||
```
|
|
||||||
|
|
||||||
WARNING: Do not use `1234` as your admin PIN unless you accept weak local security.
|
|
||||||
|
|
||||||
NOTE: First boot now requires `JARVISCHAT_ADMIN_PIN` unless you explicitly opt into insecure fallback with `JARVISCHAT_ALLOW_DEFAULT_PIN=true`.
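
The first-boot rule described in the two notes above could look roughly like the sketch below. The function name `resolve_admin_pin` and the fallback value `1234` are assumptions inferred from the warning text; only the two environment variable names come from this README.

```python
import os
import re

DEFAULT_PIN = "1234"  # assumed fallback value, inferred from the warning above


def resolve_admin_pin(env=None) -> str:
    """Hypothetical helper: pick the admin PIN from the environment or fail loudly."""
    env = os.environ if env is None else env
    pin = env.get("JARVISCHAT_ADMIN_PIN", "").strip()
    if pin:
        # Only plain ASCII digits count; \d would also accept other Unicode digits.
        if not re.fullmatch(r"[0-9]{4}", pin):
            raise ValueError("JARVISCHAT_ADMIN_PIN must be exactly 4 digits")
        return pin
    if env.get("JARVISCHAT_ALLOW_DEFAULT_PIN", "").strip().lower() == "true":
        return DEFAULT_PIN
    raise RuntimeError(
        "Set JARVISCHAT_ADMIN_PIN or opt in with JARVISCHAT_ALLOW_DEFAULT_PIN=true"
    )
```

Passing a plain dict as `env` makes the rule easy to exercise in tests without touching the real environment.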
|
|
||||||
|
|
||||||
### Upgrading from v1.4.x
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd /opt/jarvischat
|
|
||||||
|
|
||||||
# Backup
|
|
||||||
cp app.py app.py.bak
|
|
||||||
cp templates/index.html templates/index.html.bak
|
|
||||||
|
|
||||||
# Copy new files
|
|
||||||
# (copy app.py, replacing old version)
|
|
||||||
# (copy index.html to templates/)
|
|
||||||
|
|
||||||
# Restart
|
|
||||||
sudo systemctl restart jarvischat
|
|
||||||
```
|
|
||||||
|
|
||||||
## Systemd Service
|
|
||||||
|
|
||||||
Create `/etc/systemd/system/jarvischat.service`:
|
|
||||||
|
|
||||||
```ini
|
|
||||||
[Unit]
|
|
||||||
Description=JarvisChat - Local Ollama Web Interface
|
|
||||||
After=network.target
|
|
||||||
|
|
||||||
[Service]
|
|
||||||
Type=simple
|
|
||||||
User=jarvischat
|
|
||||||
Group=jarvischat
|
|
||||||
WorkingDirectory=/opt/jarvischat
|
|
||||||
ExecStart=/opt/jarvischat/venv/bin/uvicorn app:app --host 0.0.0.0 --port 8080
|
|
||||||
Restart=always
|
|
||||||
RestartSec=5
|
|
||||||
|
|
||||||
[Install]
|
|
||||||
WantedBy=multi-user.target
|
|
||||||
```
|
|
||||||
|
|
||||||
```bash
|
|
||||||
sudo systemctl daemon-reload
|
|
||||||
sudo systemctl enable jarvischat
|
|
||||||
sudo systemctl start jarvischat
|
|
||||||
```
|
|
||||||
|
|
||||||
## Memory Commands
|
|
||||||
|
|
||||||
In chat, natural language triggers memory operations:
|
|
||||||
|
|
||||||
| You say | What happens |
|---------|--------------|
| "remember that I prefer Rust over Go" | Stores as `preference` |
| "remember that JarvisChat runs on port 8080" | Stores as `infrastructure` |
| "note that the deadline is Friday" | Stores as `general` |
| "forget about the deadline" | Removes matching memories |
|
|
||||||
|
|
||||||
Memories are automatically searched based on your message content and injected into the system prompt when relevant.
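
As a rough illustration of keyword-based memory lookup, the sketch below queries an FTS5 table named `memories_fts` (the name used in the troubleshooting section). The two-column schema (`fact`, `topic`), the helper name `search_memories`, and the word-filtering rule are assumptions; the real retrieval logic may differ.

```python
import re
import sqlite3


def search_memories(db: sqlite3.Connection, message: str, limit: int = 5) -> list[tuple[str, str]]:
    """Hypothetical helper: return (fact, topic) rows sharing a word with the message."""
    # Keep alphanumeric words longer than two characters and quote each one,
    # so FTS5 treats it as a literal term rather than query syntax.
    terms = [w for w in re.findall(r"[A-Za-z0-9]+", message) if len(w) > 2]
    if not terms:
        return []
    query = " OR ".join(f'"{t}"' for t in terms)
    rows = db.execute(
        "SELECT fact, topic FROM memories_fts WHERE memories_fts MATCH ? ORDER BY rank LIMIT ?",
        (query, limit),
    ).fetchall()
    return [(row[0], row[1]) for row in rows]


# Demo against an in-memory database with the assumed two-column schema.
demo = sqlite3.connect(":memory:")
demo.execute("CREATE VIRTUAL TABLE memories_fts USING fts5(fact, topic)")
demo.execute(
    "INSERT INTO memories_fts VALUES (?, ?)",
    ("User prefers native installs over Docker", "preference"),
)
demo.execute(
    "INSERT INTO memories_fts VALUES (?, ?)",
    ("JarvisChat runs on port 8080", "infrastructure"),
)
print(search_memories(demo, "Should I use Docker for this?"))
```

Plain keyword matching only finds memories that share a word with the message, so it requires an SQLite build with FTS5 enabled and will miss purely semantic matches.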
|
|
||||||
|
|
||||||
### Memory Topics
|
|
||||||
|
|
||||||
Memories are auto-categorized:
|
|
||||||
- `preference` — likes, dislikes, choices
|
|
||||||
- `project` — active work, repos, tasks
|
|
||||||
- `infrastructure` — servers, services, configs
|
|
||||||
- `personal` — name, location, background
|
|
||||||
- `general` — everything else
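
One way such auto-categorization could work is a simple keyword lookup, sketched below. The keyword sets and the function name `categorize_fact` are invented for illustration and are not taken from the JarvisChat source; only the five topic names come from the list above.

```python
import re

# Hypothetical keyword sets, checked in this order; the first match wins.
TOPIC_KEYWORDS = {
    "preference": {"prefer", "like", "love", "hate", "dislike", "favorite"},
    "project": {"project", "repo", "repository", "task", "branch"},
    "infrastructure": {"server", "service", "port", "docker", "config", "host"},
    "personal": {"name", "live", "born", "birthday"},
}


def categorize_fact(fact: str) -> str:
    """Hypothetical helper: map a remembered fact to one of the five topics."""
    # Compare whole words so that "port" does not match inside "important".
    words = set(re.findall(r"[a-z0-9]+", fact.lower()))
    for topic, keywords in TOPIC_KEYWORDS.items():
        if words & keywords:
            return topic
    return "general"
```

With these sets, the three "remember"/"note" examples from the Memory Commands table land in `preference`, `infrastructure`, and `general` respectively.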
|
|
||||||
|
|
||||||
## API Endpoints
|
|
||||||
|
|
||||||
### Memory
|
|
||||||
|
|
||||||
| Method | Endpoint | Description |
|--------|----------|-------------|
| GET | `/api/memories` | List all memories |
| POST | `/api/memories` | Add memory `{"fact": "...", "topic": "general"}` |
| DELETE | `/api/memories/{rowid}` | Delete memory by ID |
| GET | `/api/memories/search?q=term` | Search memories |
| GET | `/api/memories/stats` | Get counts by topic |
|
|
||||||
|
|
||||||
### Chat & Models
|
|
||||||
|
|
||||||
| Method | Endpoint | Description |
|--------|----------|-------------|
| GET | `/api/models` | List available Ollama models |
| POST | `/api/chat` | Send message (streaming SSE) |
| POST | `/api/search` | Explicit web search (streaming SSE) |
| POST | `/api/show` | Get model info (context size) |
| GET | `/api/ps` | Get running models |
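
For the streaming `/api/chat` endpoint, a client has to fold the `data: {...}` lines into a final answer. The sketch below does that over an iterable of already-received lines. The event fields (`token`, `conversation_id`, `augmented`, `done`, `error`) match what `routers/chat.py` emits; the helper name `collect_chat_stream` and the choice to let an `augmented` token replace earlier text are assumptions about client behaviour.

```python
import json


def collect_chat_stream(lines) -> dict:
    """Hypothetical helper: fold SSE lines from /api/chat into one result dict."""
    tokens = []
    result = {"text": "", "conversation_id": None, "done": False, "error": None}
    for line in lines:
        if not line.startswith("data: "):
            continue  # skip blank separator lines between events
        event = json.loads(line[len("data: "):])
        result["conversation_id"] = event.get("conversation_id", result["conversation_id"])
        if "error" in event:
            result["error"] = event["error"]
            break
        if event.get("augmented"):
            # Assumption: the search-augmented answer supersedes the first pass.
            tokens = [event["token"]]
        elif "token" in event:
            tokens.append(event["token"])
        if event.get("done"):
            result["done"] = True
            break
    result["text"] = "".join(tokens)
    return result
```

In a real client the `lines` argument would come from a streaming HTTP response, for example the line iterator of an `httpx` stream.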
|
|
||||||
|
|
||||||
### Settings & Profile
|
|
||||||
|
|
||||||
| Method | Endpoint | Description |
|--------|----------|-------------|
| GET | `/api/profile` | Get profile content |
| PUT | `/api/profile` | Update profile |
| GET | `/api/profile/default` | Get default profile |
| GET | `/api/settings` | Get settings |
| PUT | `/api/settings` | Update settings |
|
|
||||||
|
|
||||||
### Conversations
|
|
||||||
|
|
||||||
| Method | Endpoint | Description |
|--------|----------|-------------|
| GET | `/api/conversations` | List conversations |
| POST | `/api/conversations` | Create conversation |
| GET | `/api/conversations/{id}` | Get conversation with messages |
| PUT | `/api/conversations/{id}` | Update conversation (for example, its title) |
| DELETE | `/api/conversations/{id}` | Delete conversation |
| DELETE | `/api/conversations` | Delete ALL conversations |
|
|
||||||
|
|
||||||
### Presets
|
|
||||||
|
|
||||||
| Method | Endpoint | Description |
|--------|----------|-------------|
| GET | `/api/presets` | List presets |
| POST | `/api/presets` | Create preset |
| PUT | `/api/presets/{id}` | Update preset |
| DELETE | `/api/presets/{id}` | Delete preset |
|
|
||||||
|
|
||||||
### System
|
|
||||||
|
|
||||||
| Method | Endpoint | Description |
|--------|----------|-------------|
| GET | `/api/stats` | CPU, RAM, GPU, VRAM stats |
| GET | `/api/search/status` | SearXNG availability |
|
|
||||||
|
|
||||||
## Configuration
|
|
||||||
|
|
||||||
Settings are stored in the `settings` table and include:
|
|
||||||
|
|
||||||
- `profile_enabled` — Inject profile into chats (true/false)
|
|
||||||
- `search_enabled` — Auto web search (true/false)
|
|
||||||
- `memory_enabled` — Memory injection (true/false)
|
|
||||||
- `default_model` — Default Ollama model
|
|
||||||
- `searxng_url` — SearXNG instance URL (default: `http://localhost:8888`)
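
Since `/api/settings` updates are restricted to allowlisted keys, the filtering step might look like the sketch below. It assumes the allowlist is exactly the five keys listed above and that booleans are stored as the strings `"true"` / `"false"` (which matches how `routers/chat.py` reads `search_enabled`); the function name `filter_settings_update` is hypothetical.

```python
# Assumed allowlist: the five settings documented above.
ALLOWED_SETTINGS = {
    "profile_enabled",
    "search_enabled",
    "memory_enabled",
    "default_model",
    "searxng_url",
}
BOOLEAN_SETTINGS = {"profile_enabled", "search_enabled", "memory_enabled"}


def filter_settings_update(payload: dict) -> dict:
    """Hypothetical helper: keep only allowlisted keys, normalised to strings."""
    clean = {}
    for key, value in payload.items():
        if key not in ALLOWED_SETTINGS:
            continue  # drop anything outside the allowlist
        if key in BOOLEAN_SETTINGS:
            if isinstance(value, bool):
                value = "true" if value else "false"
            if value not in ("true", "false"):
                raise ValueError(f"{key} must be true or false")
        clean[key] = str(value)
    return clean
```

Dropping unknown keys silently is one design choice; rejecting the whole request with an error is an equally reasonable alternative.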
|
|
||||||
|
|
||||||
## Testing Memory
|
|
||||||
|
|
||||||
```bash
# Add a memory via API
curl -X POST http://jarvis:8080/api/memories \
  -H "Content-Type: application/json" \
  -d '{"fact": "User prefers native installs over Docker", "topic": "preference"}'

# Search memories
curl "http://jarvis:8080/api/memories/search?q=docker"

# List all memories
curl http://jarvis:8080/api/memories

# Get stats
curl http://jarvis:8080/api/memories/stats
```
|
|
||||||
|
|
||||||
Or in chat:
|
|
||||||
1. Say "remember that I hate YAML"
|
|
||||||
2. Later ask "what markup languages should I avoid?"
|
|
||||||
3. JarvisChat will inject the YAML preference into context
|
|
||||||
|
|
||||||
## Troubleshooting
|
|
||||||
|
|
||||||
### Service won't start
|
|
||||||
|
|
||||||
Check logs:
|
|
||||||
```bash
|
|
||||||
journalctl -u jarvischat -n 50 --no-pager
|
|
||||||
```
|
|
||||||
|
|
||||||
Common issues:
|
|
||||||
- Missing `jinja2`: `./venv/bin/pip install jinja2`
|
|
||||||
- Missing `templates/` directory
|
|
||||||
- Wrong permissions on `/opt/jarvischat`
|
|
||||||
|
|
||||||
### Memory not working
|
|
||||||
|
|
||||||
1. Check memory is enabled (🧠 MEM ON in topbar)
|
|
||||||
2. Verify memories exist: `curl http://jarvis:8080/api/memories`
|
|
||||||
3. Check FTS5 table: `sqlite3 jarvischat.db "SELECT * FROM memories_fts;"`
|
|
||||||
|
|
||||||
### Web search not working
|
|
||||||
|
|
||||||
1. Verify SearXNG is running: `curl "http://localhost:8888/search?q=test&format=json"` (quote the URL so the shell does not treat `&` as a background operator)
|
|
||||||
2. Check search status: `curl http://jarvis:8080/api/search/status`
|
|
||||||
3. Ensure JSON format is enabled in SearXNG settings
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
MIT
|
|
||||||
|
|
||||||
## Repository
|
|
||||||
|
|
||||||
Gitea: `ssh://gitea@llgit.llamachile.tube:1319/gramps/jarvisChat.git`
|
|
||||||
0
routers/__init__.py
Normal file
212
routers/chat.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
"""JarvisChat routers - /api/chat streaming endpoint."""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
from config import DEFAULT_MODEL, LLAMA_SERVER_BASE
|
||||||
|
from db import get_db
|
||||||
|
from memory import process_remember_command
|
||||||
|
from rag import build_system_prompt
|
||||||
|
from search import (calculate_perplexity, is_uncertain, is_refusal,
|
||||||
|
clean_hedging, format_search_results, format_direct_answer,
|
||||||
|
extract_search_query, query_searxng)
|
||||||
|
from security import read_json_body, log_incident, BODY_LIMIT_CHAT_BYTES
|
||||||
|
from config import MAX_CHAT_MESSAGE_CHARS
|
||||||
|
|
||||||
|
log = logging.getLogger("jarvischat")
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_llama_stream_chunk(line: str) -> tuple:
    if line.startswith("data: "):
        line = line[6:]
    if line.strip() == "[DONE]":
        return None, True, {}, []
    try:
        chunk = json.loads(line)
        choices = chunk.get("choices", [])
        if choices:
            delta = choices[0].get("delta", {})
            token = delta.get("content")
            finish = choices[0].get("finish_reason")
            stats = {}
            logprobs_list = []
            logprobs_info = choices[0].get("logprobs")
            if logprobs_info:
                content_logprobs = logprobs_info.get("content", [])
                for entry in content_logprobs:
                    if "logprob" in entry:
                        logprobs_list.append({"logprob": entry["logprob"]})
            if finish == "stop":
                usage = chunk.get("usage", {})
                stats["tokens_per_sec"] = usage.get("tokens_per_second", 0.0)
            return token, finish == "stop", stats, logprobs_list
        if "message" in chunk and "content" in chunk["message"]:
            token = chunk["message"]["content"]
            done = chunk.get("done", False)
            stats = {}
            if done:
                eval_count = chunk.get("eval_count", 0)
                eval_duration = chunk.get("eval_duration", 0)
                stats["tokens_per_sec"] = (eval_count / (eval_duration / 1e9)) if eval_duration > 0 else 0
            return token, done, stats, []
    except json.JSONDecodeError:
        pass
    return None, False, {}, []
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/api/chat")
|
||||||
|
async def chat(request: Request):
|
||||||
|
body = await read_json_body(request, BODY_LIMIT_CHAT_BYTES)
|
||||||
|
conv_id = body.get("conversation_id")
|
||||||
|
user_message = body.get("message", "").strip()
|
||||||
|
if len(user_message) > MAX_CHAT_MESSAGE_CHARS:
|
||||||
|
raise HTTPException(status_code=413, detail="Chat message is too long")
|
||||||
|
model = body.get("model", DEFAULT_MODEL)
|
||||||
|
preset_prompt = body.get("system_prompt", "")
|
||||||
|
|
||||||
|
if not user_message:
|
||||||
|
raise HTTPException(status_code=400, detail="Empty message")
|
||||||
|
|
||||||
|
db = get_db()
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
settings = {row["key"]: row["value"] for row in db.execute("SELECT key, value FROM settings").fetchall()}
|
||||||
|
search_enabled = settings.get("search_enabled", "true") == "true"
|
||||||
|
|
||||||
|
remember_response = process_remember_command(user_message)
|
||||||
|
|
||||||
|
if not conv_id:
|
||||||
|
conv_id = str(uuid.uuid4())
|
||||||
|
title = user_message[:80] + ("..." if len(user_message) > 80 else "")
|
||||||
|
db.execute("INSERT INTO conversations (id, title, model, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
|
||||||
|
(conv_id, title, model, now, now))
|
||||||
|
else:
|
||||||
|
db.execute("UPDATE conversations SET updated_at = ? WHERE id = ?", (now, conv_id))
|
||||||
|
|
||||||
|
db.execute("INSERT INTO messages (conversation_id, role, content, created_at) VALUES (?, ?, ?, ?)",
|
||||||
|
(conv_id, "user", user_message, now))
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
history_rows = db.execute(
|
||||||
|
"SELECT role, content FROM messages WHERE conversation_id = ? ORDER BY id ASC", (conv_id,)
|
||||||
|
).fetchall()
|
||||||
|
system_prompt = await build_system_prompt(db, preset_prompt, user_message)
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
|
for row in history_rows:
|
||||||
|
messages.append({"role": row["role"], "content": row["content"]})
|
||||||
|
|
||||||
|
upstream_payload = {"model": model, "messages": messages, "stream": True, "logprobs": True}
|
||||||
|
|
||||||
|
async def stream_response():
|
||||||
|
full_response = []
|
||||||
|
all_logprobs = []
|
||||||
|
tokens_per_sec = 0.0
|
||||||
|
|
||||||
|
if remember_response:
|
||||||
|
yield f"data: {json.dumps({'token': remember_response + chr(10) + chr(10), 'conversation_id': conv_id})}\n\n"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
async with client.stream(
|
||||||
|
"POST", f"{LLAMA_SERVER_BASE}/v1/chat/completions",
|
||||||
|
json=upstream_payload,
|
||||||
|
timeout=httpx.Timeout(300.0, connect=10.0),
|
||||||
|
) as resp:
|
||||||
|
async for line in resp.aiter_lines():
|
||||||
|
if line.strip():
|
||||||
|
token, done, stats, chunk_logprobs = parse_llama_stream_chunk(line)
|
||||||
|
if chunk_logprobs:
|
||||||
|
all_logprobs.extend(chunk_logprobs)
|
||||||
|
if token:
|
||||||
|
full_response.append(token)
|
||||||
|
yield f"data: {json.dumps({'token': token, 'conversation_id': conv_id})}\n\n"
|
||||||
|
if done:
|
||||||
|
tokens_per_sec = stats.get("tokens_per_sec", 0.0)
|
||||||
|
|
||||||
|
assistant_msg = "".join(full_response)
|
||||||
|
perplexity = calculate_perplexity(all_logprobs) if all_logprobs else 0.0
|
||||||
|
should_search = is_uncertain(all_logprobs) or is_refusal(assistant_msg)
|
||||||
|
|
||||||
|
if search_enabled and should_search:
|
||||||
|
yield f"data: {json.dumps({'searching': True, 'conversation_id': conv_id})}\n\n"
|
||||||
|
search_query = extract_search_query(user_message)
|
||||||
|
search_results = await query_searxng(search_query)
|
||||||
|
|
||||||
|
if search_results:
|
||||||
|
search_context = format_search_results(search_results)
|
||||||
|
augmented_messages = []
|
||||||
|
if system_prompt:
|
||||||
|
augmented_messages.append({"role": "system", "content": system_prompt + "\n\n" + search_context})
|
||||||
|
else:
|
||||||
|
augmented_messages.append({"role": "system", "content": search_context})
|
||||||
|
for row in history_rows[:-1]:
|
||||||
|
augmented_messages.append({"role": row["role"], "content": row["content"]})
|
||||||
|
augmented_messages.append({"role": "user", "content": user_message})
|
||||||
|
|
||||||
|
yield f"data: {json.dumps({'search_results': len(search_results), 'conversation_id': conv_id})}\n\n"
|
||||||
|
|
||||||
|
augmented_response = []
|
||||||
|
async with client.stream(
|
||||||
|
"POST", f"{LLAMA_SERVER_BASE}/v1/chat/completions",
|
||||||
|
json={"model": model, "messages": augmented_messages, "stream": True},
|
||||||
|
timeout=httpx.Timeout(300.0, connect=10.0),
|
||||||
|
) as resp2:
|
||||||
|
async for line in resp2.aiter_lines():
|
||||||
|
if line.strip():
|
||||||
|
token2, done2, _, _ = parse_llama_stream_chunk(line)
|
||||||
|
if token2:
|
||||||
|
augmented_response.append(token2)
|
||||||
|
if done2:
|
||||||
|
break
|
||||||
|
|
||||||
|
raw_response = "".join(augmented_response) or assistant_msg
|
||||||
|
cleaned_response = clean_hedging(raw_response)
|
||||||
|
if is_refusal(cleaned_response) or len(cleaned_response) < 20:
|
||||||
|
cleaned_response = format_direct_answer(user_message, search_results)
|
||||||
|
|
||||||
|
yield f"data: {json.dumps({'token': cleaned_response, 'conversation_id': conv_id, 'augmented': True})}\n\n"
|
||||||
|
|
||||||
|
saved_msg = cleaned_response + "\n\n---\n*🔍 Enhanced with web search results*"
|
||||||
|
if remember_response:
|
||||||
|
saved_msg = remember_response + "\n\n" + saved_msg
|
||||||
|
|
||||||
|
db2 = get_db()
|
||||||
|
db2.execute("INSERT INTO messages (conversation_id, role, content, created_at) VALUES (?, ?, ?, ?)",
|
||||||
|
(conv_id, "assistant", saved_msg, datetime.now(timezone.utc).isoformat()))
|
||||||
|
db2.commit()
|
||||||
|
db2.close()
|
||||||
|
|
||||||
|
yield f"data: {json.dumps({'done': True, 'conversation_id': conv_id, 'searched': True, 'perplexity': round(perplexity, 2), 'tokens_per_sec': round(tokens_per_sec, 1)})}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
saved_msg = assistant_msg
|
||||||
|
if remember_response:
|
||||||
|
saved_msg = remember_response + "\n\n" + saved_msg
|
||||||
|
|
||||||
|
db2 = get_db()
|
||||||
|
db2.execute("INSERT INTO messages (conversation_id, role, content, created_at) VALUES (?, ?, ?, ?)",
|
||||||
|
(conv_id, "assistant", saved_msg, datetime.now(timezone.utc).isoformat()))
|
||||||
|
db2.commit()
|
||||||
|
db2.close()
|
||||||
|
|
||||||
|
yield f"data: {json.dumps({'done': True, 'conversation_id': conv_id, 'perplexity': round(perplexity, 2), 'tokens_per_sec': round(tokens_per_sec, 1)})}\n\n"
|
||||||
|
|
||||||
|
except httpx.RemoteProtocolError:
|
||||||
|
pass
|
||||||
|
except httpx.ConnectError:
|
||||||
|
yield f"data: {json.dumps({'error': 'Cannot connect to inference server. Is it running?'})}\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
incident_key = log_incident("chat_stream", message="Inference stream failure during chat response",
|
||||||
|
request=request, exc=e)
|
||||||
|
yield f"data: {json.dumps({'error': 'Chat response generation failed before completion. Use the incident key for support lookup.', 'error_key': incident_key})}\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(stream_response(), media_type="text/event-stream")
|
||||||
267
routers/completions.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""
|
||||||
|
JarvisChat - /v1/chat/completions router.
|
||||||
|
OpenAI-compatible endpoint for IDE integration (Continue.dev, etc.).
|
||||||
|
Runs all requests through the full jC pipeline: profile + RAG + memory injection.
|
||||||
|
FIM (fill-in-the-middle) requests are proxied directly — not persisted.
|
||||||
|
Chat-style requests are persisted to conversation history.
|
||||||
|
Auth: static Bearer token via COMPLETIONS_API_KEY in config.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
|
|
||||||
|
from config import DEFAULT_MODEL, LLAMA_SERVER_BASE, COMPLETIONS_API_KEY
|
||||||
|
from db import get_db
|
||||||
|
from rag import build_system_prompt
|
||||||
|
from routers.chat import parse_llama_stream_chunk
|
||||||
|
|
||||||
|
log = logging.getLogger("jarvischat")
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
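def _check_api_key(request: Request):
    """Reject the request unless it carries the configured Bearer token."""
    import hmac  # local import keeps this check self-contained

    auth = request.headers.get("Authorization", "")
    if not auth.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Missing Bearer token")
    token = auth[7:].strip()
    # An unset or empty key must never match an empty token; compare in constant time.
    if not COMPLETIONS_API_KEY or not hmac.compare_digest(
        token.encode("utf-8"), COMPLETIONS_API_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")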
|
||||||
|
|
||||||
|
|
||||||
|
def _is_fim_request(body: dict) -> bool:
    """
    FIM (fill-in-the-middle) requests use a 'prompt' + optional 'suffix' structure
    rather than a 'messages' array. Continue.dev sends these for inline autocomplete.
    We proxy them directly without pipeline injection or persistence.
    """
    return "prompt" in body and "messages" not in body
|
||||||
|
|
||||||
|
|
||||||
|
def _build_openai_chunk(token: str, model: str, conv_id: str) -> str:
    chunk = {
        "id": f"chatcmpl-{conv_id}",
        "object": "chat.completion.chunk",
        "model": model,
        "choices": [{
            "index": 0,
            "delta": {"content": token},
            "finish_reason": None,
        }],
    }
    return f"data: {json.dumps(chunk)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_openai_stop_chunk(model: str, conv_id: str) -> str:
    chunk = {
        "id": f"chatcmpl-{conv_id}",
        "object": "chat.completion.chunk",
        "model": model,
        "choices": [{
            "index": 0,
            "delta": {},
            "finish_reason": "stop",
        }],
    }
    return f"data: {json.dumps(chunk)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_openai_response(content: str, model: str, conv_id: str) -> dict:
    """Non-streaming response envelope."""
    return {
        "id": f"chatcmpl-{conv_id}",
        "object": "chat.completion",
        "model": model,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": content},
            "finish_reason": "stop",
        }],
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/v1/chat/completions")
|
||||||
|
async def chat_completions(request: Request):
|
||||||
|
_check_api_key(request)
|
||||||
|
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(status_code=400, detail="Invalid JSON body")
|
||||||
|
|
||||||
|
# --- FIM passthrough ---
|
||||||
|
if _is_fim_request(body):
|
||||||
|
return await _fim_passthrough(body)
|
||||||
|
|
||||||
|
# --- Chat completion ---
|
||||||
|
messages = body.get("messages", [])
|
||||||
|
if not messages:
|
||||||
|
raise HTTPException(status_code=400, detail="No messages provided")
|
||||||
|
|
||||||
|
model = body.get("model", DEFAULT_MODEL)
|
||||||
|
stream = body.get("stream", True)
|
||||||
|
|
||||||
|
# Extract the latest user message for RAG + conversation title
|
||||||
|
user_message = ""
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if msg.get("role") == "user":
|
||||||
|
user_message = msg.get("content", "").strip()
|
||||||
|
break
|
||||||
|
|
||||||
|
if not user_message:
|
||||||
|
raise HTTPException(status_code=400, detail="No user message found")
|
||||||
|
|
||||||
|
# --- Persist conversation ---
|
||||||
|
db = get_db()
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
conv_id = str(uuid.uuid4())
|
||||||
|
title = f"[IDE] {user_message[:72]}{'...' if len(user_message) > 72 else ''}"
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO conversations (id, title, model, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
|
||||||
|
(conv_id, title, model, now, now),
|
||||||
|
)
|
||||||
|
for msg in messages:
|
||||||
|
role = msg.get("role")
|
||||||
|
content = msg.get("content", "")
|
||||||
|
if role in ("user", "assistant"):
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO messages (conversation_id, role, content, created_at) VALUES (?, ?, ?, ?)",
|
||||||
|
(conv_id, role, content, now),
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
# --- Build system prompt through full jC pipeline ---
|
||||||
|
system_prompt = await build_system_prompt(db, "", user_message)
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
# Assemble messages for upstream: inject jC system prompt, preserve history
|
||||||
|
upstream_messages = []
|
||||||
|
if system_prompt:
|
||||||
|
upstream_messages.append({"role": "system", "content": system_prompt})
|
||||||
|
|
||||||
|
# Strip any system messages from the incoming payload — jC owns the system prompt
|
||||||
|
for msg in messages:
|
||||||
|
if msg.get("role") != "system":
|
||||||
|
upstream_messages.append(msg)
|
||||||
|
|
||||||
|
upstream_payload = {
|
||||||
|
"model": model,
|
||||||
|
"messages": upstream_messages,
|
||||||
|
"stream": True, # always stream from upstream; we buffer if client wants non-stream
|
||||||
|
}
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return StreamingResponse(
|
||||||
|
_stream_chat(upstream_payload, model, conv_id, request),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return await _blocking_chat(upstream_payload, model, conv_id, request)
|
||||||
|
|
||||||
|
|
||||||
|
async def _stream_chat(payload: dict, model: str, conv_id: str, request: Request):
|
||||||
|
"""Stream tokens to client in OpenAI SSE format, persist assistant response."""
|
||||||
|
full_response = []
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
async with client.stream(
|
||||||
|
"POST", f"{LLAMA_SERVER_BASE}/v1/chat/completions",
|
||||||
|
json=payload,
|
||||||
|
timeout=httpx.Timeout(300.0, connect=10.0),
|
||||||
|
) as resp:
|
||||||
|
async for line in resp.aiter_lines():
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
token, done, _, _ = parse_llama_stream_chunk(line)
|
||||||
|
if token:
|
||||||
|
full_response.append(token)
|
||||||
|
yield _build_openai_chunk(token, model, conv_id)
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
|
yield _build_openai_stop_chunk(model, conv_id)
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
# Persist assistant response
|
||||||
|
assistant_msg = "".join(full_response)
|
||||||
|
if assistant_msg:
|
||||||
|
db = get_db()
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO messages (conversation_id, role, content, created_at) VALUES (?, ?, ?, ?)",
|
||||||
|
(conv_id, "assistant", assistant_msg, datetime.now(timezone.utc).isoformat()),
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
except httpx.ConnectError:
|
||||||
|
err = {"error": {"message": "Cannot connect to inference server", "type": "connection_error"}}
|
||||||
|
yield f"data: {json.dumps(err)}\n\n"
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"completions stream error: {e}")
|
||||||
|
err = {"error": {"message": "Stream failed", "type": "server_error"}}
|
||||||
|
yield f"data: {json.dumps(err)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
async def _blocking_chat(payload: dict, model: str, conv_id: str, request: Request) -> JSONResponse:
|
||||||
|
"""Accumulate full response, return as standard OpenAI JSON object."""
|
||||||
|
full_response = []
|
||||||
|
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
try:
|
||||||
|
async with client.stream(
|
||||||
|
"POST", f"{LLAMA_SERVER_BASE}/v1/chat/completions",
|
||||||
|
json=payload,
|
||||||
|
timeout=httpx.Timeout(300.0, connect=10.0),
|
||||||
|
) as resp:
|
||||||
|
async for line in resp.aiter_lines():
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
token, done, _, _ = parse_llama_stream_chunk(line)
|
||||||
|
if token:
|
||||||
|
full_response.append(token)
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
except httpx.ConnectError:
|
||||||
|
raise HTTPException(status_code=503, detail="Cannot connect to inference server")
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"completions blocking error: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="Inference request failed")
|
||||||
|
|
||||||
|
assistant_msg = "".join(full_response)
|
||||||
|
|
||||||
|
if assistant_msg:
|
||||||
|
db = get_db()
|
||||||
|
db.execute(
|
||||||
|
"INSERT INTO messages (conversation_id, role, content, created_at) VALUES (?, ?, ?, ?)",
|
||||||
|
(conv_id, "assistant", assistant_msg, datetime.now(timezone.utc).isoformat()),
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
return JSONResponse(content=_build_openai_response(assistant_msg, model, conv_id))
|
||||||
|
|
||||||
|
|
||||||
|
async def _fim_passthrough(body: dict) -> JSONResponse:
    """
    Proxy FIM requests directly to llama-server without pipeline injection.
    Not persisted — autocomplete noise has no RAG value.
    """
    async with httpx.AsyncClient() as client:
        try:
            resp = await client.post(
                f"{LLAMA_SERVER_BASE}/v1/completions",
                json=body,
                timeout=httpx.Timeout(30.0, connect=5.0),
            )
            return JSONResponse(content=resp.json(), status_code=resp.status_code)
        except httpx.ConnectError:
            raise HTTPException(status_code=503, detail="Cannot connect to inference server")
        except Exception as e:
            log.error(f"FIM passthrough error: {e}")
            raise HTTPException(status_code=500, detail="FIM request failed")
|
||||||
83
routers/conversations.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""JarvisChat routers - Conversation CRUD."""
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from db import get_db
|
||||||
|
from security import read_json_body, BODY_LIMIT_DEFAULT_BYTES
|
||||||
|
from config import DEFAULT_MODEL, MAX_CONVERSATION_TITLE_CHARS
|
||||||
|
|
||||||
|
log = logging.getLogger("jarvischat")
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/api/conversations")
async def list_conversations():
    db = get_db()
    rows = db.execute("SELECT * FROM conversations ORDER BY updated_at DESC").fetchall()
    db.close()
    return [dict(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/api/conversations")
|
||||||
|
async def create_conversation(request: Request):
|
||||||
|
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||||
|
conv_id = str(uuid.uuid4())
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
model = body.get("model", DEFAULT_MODEL)
|
||||||
|
title = str(body.get("title", "New Chat"))[:MAX_CONVERSATION_TITLE_CHARS]
|
||||||
|
db = get_db()
|
||||||
|
db.execute("INSERT INTO conversations (id, title, model, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
|
||||||
|
(conv_id, title, model, now, now))
|
||||||
|
db.commit()
|
||||||
|
db.close()
|
||||||
|
return {"id": conv_id, "title": title, "model": model, "created_at": now, "updated_at": now}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/api/conversations/{conv_id}")
|
||||||
|
async def get_conversation(conv_id: str):
|
||||||
|
db = get_db()
|
||||||
|
conv = db.execute("SELECT * FROM conversations WHERE id = ?", (conv_id,)).fetchone()
|
||||||
|
if not conv:
|
||||||
|
db.close()
|
||||||
|
raise HTTPException(status_code=404, detail="Conversation not found")
|
||||||
|
messages = db.execute("SELECT * FROM messages WHERE conversation_id = ? ORDER BY id ASC", (conv_id,)).fetchall()
|
||||||
|
db.close()
|
||||||
|
return {"conversation": dict(conv), "messages": [dict(m) for m in messages]}
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/api/conversations/{conv_id}")
|
||||||
|
async def update_conversation(conv_id: str, request: Request):
|
||||||
|
body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
|
||||||
|
db = get_db()
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
if "title" in body:
|
||||||
|
db.execute("UPDATE conversations SET title = ?, updated_at = ? WHERE id = ?",
|
||||||
|
(str(body["title"])[:MAX_CONVERSATION_TITLE_CHARS], now, conv_id))
|
||||||
|
if "model" in body:
|
||||||
|
db.execute("UPDATE conversations SET model = ?, updated_at = ? WHERE id = ?",
|
||||||
|
(body["model"], now, conv_id))
|
||||||
|
db.commit()
|
||||||
|
db.close()
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/api/conversations/{conv_id}")
|
||||||
|
async def delete_conversation(conv_id: str):
|
||||||
|
db = get_db()
|
||||||
|
db.execute("DELETE FROM messages WHERE conversation_id = ?", (conv_id,))
|
||||||
|
db.execute("DELETE FROM conversations WHERE id = ?", (conv_id,))
|
||||||
|
db.commit()
|
||||||
|
db.close()
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/api/conversations")
|
||||||
|
async def delete_all_conversations():
|
||||||
|
db = get_db()
|
||||||
|
db.execute("DELETE FROM messages")
|
||||||
|
db.execute("DELETE FROM conversations")
|
||||||
|
db.commit()
|
||||||
|
db.close()
|
||||||
|
log.info("Deleted all conversations")
|
||||||
|
return {"status": "ok"}
|
||||||
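The conversation handlers open a connection with `get_db()` and close it by hand, so an exception raised between the two calls would skip `db.close()`. Below is a minimal sketch of one way to guarantee cleanup. It assumes `get_db()` returns a plain `sqlite3.Connection` with `sqlite3.Row` rows; `db.py` is not shown in this part of the diff, so `open_db`, `fetch_conversation`, `demo`, and the trimmed schema are hypothetical stand-ins, not the project's code.

```python
import sqlite3
from contextlib import closing
from typing import Optional


def open_db(path: str = ":memory:") -> sqlite3.Connection:
    """Hypothetical stand-in for get_db(); rows are addressable by column name."""
    conn = sqlite3.connect(path)
    conn.row_factory = sqlite3.Row
    return conn


def fetch_conversation(conn: sqlite3.Connection, conv_id: str) -> Optional[dict]:
    """Return the conversation and its messages, or None if the id is unknown."""
    conv = conn.execute("SELECT * FROM conversations WHERE id = ?", (conv_id,)).fetchone()
    if conv is None:
        return None
    messages = conn.execute(
        "SELECT * FROM messages WHERE conversation_id = ? ORDER BY id ASC", (conv_id,)
    ).fetchall()
    return {"conversation": dict(conv), "messages": [dict(m) for m in messages]}


def demo() -> Optional[dict]:
    # closing() calls conn.close() even if a statement inside the block raises.
    with closing(open_db()) as conn:
        conn.execute("CREATE TABLE conversations (id TEXT PRIMARY KEY, title TEXT)")
        conn.execute(
            "CREATE TABLE messages "
            "(id INTEGER PRIMARY KEY, conversation_id TEXT, role TEXT, content TEXT)"
        )
        conn.execute("INSERT INTO conversations VALUES (?, ?)", ("c1", "New Chat"))
        conn.execute(
            "INSERT INTO messages (conversation_id, role, content) VALUES (?, ?, ?)",
            ("c1", "user", "hello"),
        )
        conn.commit()
        return fetch_conversation(conn, "c1")
```

A handler written this way could map a `None` result to the same 404 response without needing a separate `db.close()` on the early-exit path.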
63
routers/memories.py
Normal file
@@ -0,0 +1,63 @@
"""JarvisChat routers - Memory CRUD API."""
from fastapi import APIRouter, HTTPException, Request
from typing import Optional

from db import get_db
from memory import add_memory, delete_memory, update_memory, get_all_memories, search_memories
from security import read_json_body, BODY_LIMIT_DEFAULT_BYTES
from config import MAX_MEMORY_FACT_CHARS

router = APIRouter()


@router.get("/api/memories")
async def list_memories(topic: Optional[str] = None):
    memories = get_all_memories(topic)
    return {"memories": memories, "count": len(memories)}


@router.post("/api/memories")
async def create_memory(request: Request):
    body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
    fact = str(body.get("fact", "")).strip()
    if not fact:
        raise HTTPException(status_code=400, detail="Memory fact is required")
    if len(fact) > MAX_MEMORY_FACT_CHARS:
        raise HTTPException(status_code=413, detail="Memory fact is too long")
    rowid = add_memory(fact=fact, topic=body.get("topic", "general"), source=body.get("source", "manual"))
    return {"rowid": rowid, "status": "ok"}


@router.delete("/api/memories/{rowid}")
async def remove_memory(rowid: int):
    if not delete_memory(rowid):
        raise HTTPException(status_code=404, detail="Memory not found")
    return {"status": "ok"}


@router.put("/api/memories/{rowid}")
async def edit_memory(rowid: int, request: Request):
    body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
    fact = str(body.get("fact", "")).strip()
    if not fact:
        raise HTTPException(status_code=400, detail="Memory fact is required")
    if len(fact) > MAX_MEMORY_FACT_CHARS:
        raise HTTPException(status_code=413, detail="Memory fact is too long")
    if not update_memory(rowid, fact):
        raise HTTPException(status_code=404, detail="Memory not found")
    return {"status": "ok"}


@router.get("/api/memories/search")
async def search_memories_api(q: str, limit: int = 10):
    results = search_memories(q, limit=limit)
    return {"results": results, "count": len(results)}


@router.get("/api/memories/stats")
async def memory_stats():
    db = get_db()
    total = db.execute("SELECT COUNT(*) as c FROM memories").fetchone()["c"]
    topics = db.execute("SELECT topic, COUNT(*) as c FROM memories GROUP BY topic ORDER BY c DESC").fetchall()
    db.close()
    return {"total": total, "by_topic": {row["topic"]: row["c"] for row in topics}}
77
routers/models.py
Normal file
@@ -0,0 +1,77 @@
"""
JarvisChat routers - Model listing, system stats.
"""
import logging
from typing import Optional

import httpx
import psutil
from fastapi import APIRouter, HTTPException, Request

from config import LLAMA_SERVER_BASE
from gpu import get_gpu_stats
from security import read_json_body, BODY_LIMIT_DEFAULT_BYTES

log = logging.getLogger("jarvischat")
router = APIRouter()


@router.get("/api/models")
async def list_models():
    async with httpx.AsyncClient() as client:
        try:
            resp = await client.get(f"{LLAMA_SERVER_BASE}/v1/models", timeout=10)
            data = resp.json()
            models = [{"name": m["id"], "model": m["id"]} for m in data.get("data", [])]
            return {"models": models}
        except httpx.ConnectError:
            raise HTTPException(status_code=502, detail="Cannot connect to inference server.")


@router.get("/api/ps")
async def running_models():
    async with httpx.AsyncClient() as client:
        try:
            resp = await client.get(f"{LLAMA_SERVER_BASE}/v1/models", timeout=10)
            return resp.json()
        except httpx.ConnectError:
            raise HTTPException(status_code=502, detail="Cannot connect to inference server.")


@router.post("/api/show")
async def show_model(request: Request):
    body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
    async with httpx.AsyncClient() as client:
        try:
            resp = await client.post(f"{LLAMA_SERVER_BASE}/api/show", json=body, timeout=10)
            return resp.json()
        except httpx.ConnectError:
            raise HTTPException(status_code=502, detail="Cannot connect to inference server.")


@router.get("/api/stats")
async def system_stats():
    cpu_percent = psutil.cpu_percent(interval=0.1)
    memory = psutil.virtual_memory()
    gpu = get_gpu_stats()
    return {
        "cpu_percent": round(cpu_percent, 1),
        "memory_percent": round(memory.percent, 1),
        "memory_used_gb": round(memory.used / (1024**3), 1),
        "memory_total_gb": round(memory.total / (1024**3), 1),
        "gpu_percent": gpu["gpu_percent"],
        "vram_percent": gpu["vram_percent"],
        "gpu_available": gpu["available"],
    }


@router.get("/api/search/status")
async def search_status():
    from config import SEARXNG_BASE
    async with httpx.AsyncClient() as client:
        try:
            resp = await client.get(f"{SEARXNG_BASE}/search",
                                    params={"q": "test", "format": "json"}, timeout=5)
            return {"available": resp.status_code == 200}
        except Exception:
            return {"available": False}
61
routers/presets.py
Normal file
@@ -0,0 +1,61 @@
"""JarvisChat routers - System prompt presets."""
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, HTTPException, Request
from db import get_db
from security import read_json_body, BODY_LIMIT_DEFAULT_BYTES
from config import MAX_PRESET_NAME_CHARS, MAX_PRESET_PROMPT_CHARS

router = APIRouter()


@router.get("/api/presets")
async def list_presets():
    db = get_db()
    rows = db.execute("SELECT * FROM system_presets ORDER BY is_default DESC, name ASC").fetchall()
    db.close()
    return [dict(r) for r in rows]


@router.post("/api/presets")
async def create_preset(request: Request):
    body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
    name = str(body.get("name", "")).strip()
    prompt = str(body.get("prompt", "")).strip()
    if not name or not prompt:
        raise HTTPException(status_code=400, detail="Preset name and prompt are required")
    if len(name) > MAX_PRESET_NAME_CHARS or len(prompt) > MAX_PRESET_PROMPT_CHARS:
        raise HTTPException(status_code=413, detail="Preset fields are too long")
    preset_id = str(uuid.uuid4())
    now = datetime.now(timezone.utc).isoformat()
    db = get_db()
    db.execute("INSERT INTO system_presets (id, name, prompt, is_default, created_at) VALUES (?, ?, ?, 0, ?)",
               (preset_id, name, prompt, now))
    db.commit()
    db.close()
    return {"id": preset_id, "name": name, "prompt": prompt}


@router.put("/api/presets/{preset_id}")
async def update_preset(preset_id: str, request: Request):
    body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
    name = str(body.get("name", "")).strip()
    prompt = str(body.get("prompt", "")).strip()
    if not name or not prompt:
        raise HTTPException(status_code=400, detail="Preset name and prompt are required")
    if len(name) > MAX_PRESET_NAME_CHARS or len(prompt) > MAX_PRESET_PROMPT_CHARS:
        raise HTTPException(status_code=413, detail="Preset fields are too long")
    db = get_db()
    db.execute("UPDATE system_presets SET name = ?, prompt = ? WHERE id = ?", (name, prompt, preset_id))
    db.commit()
    db.close()
    return {"status": "ok"}


@router.delete("/api/presets/{preset_id}")
async def delete_preset(preset_id: str):
    db = get_db()
    db.execute("DELETE FROM system_presets WHERE id = ? AND is_default = 0", (preset_id,))
    db.commit()
    db.close()
    return {"status": "ok"}
36
routers/profile.py
Normal file
@@ -0,0 +1,36 @@
"""JarvisChat routers - Profile."""
from datetime import datetime, timezone
from fastapi import APIRouter, HTTPException, Request
from db import get_db
from security import read_json_body, BODY_LIMIT_PROFILE_BYTES
from config import MAX_PROFILE_CHARS, DEFAULT_PROFILE

router = APIRouter()


@router.get("/api/profile")
async def get_profile():
    db = get_db()
    row = db.execute("SELECT content, updated_at FROM profile WHERE id = 1").fetchone()
    db.close()
    return ({"content": row["content"], "updated_at": row["updated_at"]} if row
            else {"content": "", "updated_at": ""})


@router.put("/api/profile")
async def update_profile(request: Request):
    body = await read_json_body(request, BODY_LIMIT_PROFILE_BYTES)
    content = str(body.get("content", ""))
    if len(content) > MAX_PROFILE_CHARS:
        raise HTTPException(status_code=413, detail="Profile content is too long")
    now = datetime.now(timezone.utc).isoformat()
    db = get_db()
    db.execute("UPDATE profile SET content = ?, updated_at = ? WHERE id = 1", (content, now))
    db.commit()
    db.close()
    return {"status": "ok", "updated_at": now}


@router.get("/api/profile/default")
async def get_default_profile():
    return {"content": DEFAULT_PROFILE}
107
routers/search_route.py
Normal file
@@ -0,0 +1,107 @@
"""JarvisChat routers - /api/search explicit search endpoint."""
import json
import logging
import uuid
from datetime import datetime, timezone

import httpx
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse

from config import DEFAULT_MODEL, LLAMA_SERVER_BASE, MAX_SEARCH_QUERY_CHARS
from db import get_db
from search import query_searxng, format_search_results
from routers.chat import parse_llama_stream_chunk
from security import read_json_body, log_incident, BODY_LIMIT_CHAT_BYTES

log = logging.getLogger("jarvischat")
router = APIRouter()


@router.post("/api/search")
async def explicit_search(request: Request):
    body = await read_json_body(request, BODY_LIMIT_CHAT_BYTES)
    query = body.get("query", "").strip()
    if len(query) > MAX_SEARCH_QUERY_CHARS:
        raise HTTPException(status_code=413, detail="Search query is too long")
    conv_id = body.get("conversation_id")
    model = body.get("model", DEFAULT_MODEL)

    if not query:
        raise HTTPException(status_code=400, detail="Empty query")

    db = get_db()
    now = datetime.now(timezone.utc).isoformat()

    if not conv_id:
        conv_id = str(uuid.uuid4())
        title = query[:70] + "..." if len(query) > 70 else query
        db.execute("INSERT INTO conversations (id, title, model, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
                   (conv_id, title, model, now, now))
    else:
        db.execute("UPDATE conversations SET updated_at = ? WHERE id = ?", (now, conv_id))

    db.execute("INSERT INTO messages (conversation_id, role, content, created_at) VALUES (?, ?, ?, ?)",
               (conv_id, "user", query, now))
    db.commit()
    db.close()

    async def stream_search():
        yield f"data: {json.dumps({'conversation_id': conv_id, 'searching': True})}\n\n"

        results = await query_searxng(query, max_results=5)

        if not results:
            error_msg = "No search results found."
            yield f"data: {json.dumps({'token': error_msg, 'conversation_id': conv_id})}\n\n"
            db2 = get_db()
            db2.execute("INSERT INTO messages (conversation_id, role, content, created_at) VALUES (?, ?, ?, ?)",
                        (conv_id, "assistant", error_msg, datetime.now(timezone.utc).isoformat()))
            db2.commit()
            db2.close()
            yield f"data: {json.dumps({'done': True, 'conversation_id': conv_id})}\n\n"
            return

        yield f"data: {json.dumps({'search_results': len(results), 'conversation_id': conv_id})}\n\n"

        search_context = format_search_results(results)
        messages = [
            {"role": "system", "content": f"You have access to current web data. Answer directly using ONLY the data below. Be concise. No apologies. No disclaimers.\n\n{search_context}"},
            {"role": "user", "content": query},
        ]

        full_response = []
        async with httpx.AsyncClient() as client:
            try:
                async with client.stream(
                    "POST", f"{LLAMA_SERVER_BASE}/v1/chat/completions",
                    json={"model": model, "messages": messages, "stream": True},
                    timeout=httpx.Timeout(300.0, connect=10.0),
                ) as resp:
                    async for line in resp.aiter_lines():
                        if line.strip():
                            token, done, _, _ = parse_llama_stream_chunk(line)
                            if token:
                                full_response.append(token)
                                yield f"data: {json.dumps({'token': token, 'conversation_id': conv_id})}\n\n"
                            if done:
                                break
            except Exception as e:
                incident_key = log_incident("search_summarization_stream",
                                            message="Stream failure during explicit search summarization",
                                            request=request, exc=e)
                yield f"data: {json.dumps({'error': 'Search summarization could not complete right now.', 'error_key': incident_key})}\n\n"
                return

        summary = "".join(full_response)
        saved_msg = f"{summary}\n\n---\n*🔍 Web search results*"

        db2 = get_db()
        db2.execute("INSERT INTO messages (conversation_id, role, content, created_at) VALUES (?, ?, ?, ?)",
                    (conv_id, "assistant", saved_msg, datetime.now(timezone.utc).isoformat()))
        db2.commit()
        db2.close()

        yield f"data: {json.dumps({'done': True, 'conversation_id': conv_id, 'searched': True})}\n\n"

    return StreamingResponse(stream_search(), media_type="text/event-stream")
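The `/api/search` endpoint streams server-sent events: each event is one `data: <json>` line followed by a blank line, and a final event carries `done: True`. As a rough illustration of that framing and of how a client might reassemble the tokens, here is a small sketch. `sse_event`, `parse_sse`, and `collect_tokens` are hypothetical helper names and are not part of the diff.

```python
import json
from typing import Iterable, Iterator


def sse_event(payload: dict) -> str:
    """Frame one JSON payload as a server-sent event (same shape as the endpoint's yields)."""
    return f"data: {json.dumps(payload)}\n\n"


def parse_sse(stream: Iterable[str]) -> Iterator[dict]:
    """Hypothetical client-side helper: decode `data:` lines back into dicts."""
    for chunk in stream:
        for line in chunk.splitlines():
            if line.startswith("data: "):
                yield json.loads(line[len("data: "):])


def collect_tokens(stream: Iterable[str]) -> str:
    """Join token events until a `done` event arrives."""
    tokens = []
    for event in parse_sse(stream):
        if event.get("done"):
            break
        if "token" in event:
            tokens.append(event["token"])
    return "".join(tokens)
```

Because `json.dumps` escapes newlines inside token text, each event stays on a single `data:` line, which is what keeps this line-based parsing safe.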
36
routers/settings.py
Normal file
@@ -0,0 +1,36 @@
"""JarvisChat routers - Settings."""
from fastapi import APIRouter, HTTPException, Request
from db import get_db
from security import read_json_body, BODY_LIMIT_DEFAULT_BYTES
from config import MAX_SETTINGS_KEYS, MAX_SETTINGS_VALUE_CHARS, ALLOWED_SETTINGS_KEYS

router = APIRouter()


@router.get("/api/settings")
async def get_settings():
    db = get_db()
    rows = db.execute("SELECT key, value FROM settings").fetchall()
    db.close()
    return {row["key"]: row["value"] for row in rows}


@router.put("/api/settings")
async def update_settings(request: Request):
    body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Settings payload must be an object")
    if len(body) > MAX_SETTINGS_KEYS:
        raise HTTPException(status_code=413, detail="Too many settings in one request")
    unknown_keys = sorted(key for key in body.keys() if str(key) not in ALLOWED_SETTINGS_KEYS)
    if unknown_keys:
        raise HTTPException(status_code=400, detail=f"Unknown setting key(s): {', '.join(unknown_keys)}")
    db = get_db()
    for key, value in body.items():
        if len(str(key)) > 80 or len(str(value)) > MAX_SETTINGS_VALUE_CHARS:
            db.close()
            raise HTTPException(status_code=413, detail="Setting key/value too long")
        db.execute("INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)", (key, str(value)))
    db.commit()
    db.close()
    return {"status": "ok"}
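The settings endpoint applies three checks: a cap on the number of keys, an allowlist of key names, and a per-value length limit. One possible way to express those checks as a pure function that runs before any database write is sketched below. `validate_settings` and the example allowlist are hypothetical; the real `ALLOWED_SETTINGS_KEYS` and limits live in `config.py`, which is not shown here.

```python
from typing import List, Set


def validate_settings(body: object, allowed: Set[str], max_keys: int,
                      max_value_chars: int) -> List[str]:
    """Return a list of problems; an empty list means the payload may be written."""
    if not isinstance(body, dict):
        return ["payload must be an object"]
    problems = []
    if len(body) > max_keys:
        problems.append("too many settings")
    # Reject anything outside the allowlist, reported in a stable order.
    unknown = sorted(str(k) for k in body if str(k) not in allowed)
    if unknown:
        problems.append("unknown key(s): " + ", ".join(unknown))
    # Values are stored as strings, so measure their string form.
    too_long = sorted(str(k) for k, v in body.items() if len(str(v)) > max_value_chars)
    if too_long:
        problems.append("value too long: " + ", ".join(too_long))
    return problems
```

Checking the whole payload up front means a rejected request never opens a connection, and the function can be unit-tested without FastAPI or SQLite.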
42
routers/skills.py
Normal file
@@ -0,0 +1,42 @@
"""JarvisChat routers - Skills."""
from fastapi import APIRouter, HTTPException, Request
from db import get_db, get_setting, list_skills_with_state, set_skill_enabled
from security import read_json_body, BODY_LIMIT_DEFAULT_BYTES
from config import MAX_SKILL_KEY_CHARS, SKILLS_BY_KEY

router = APIRouter()


@router.get("/api/skills")
async def list_skills():
    db = get_db()
    skills = list_skills_with_state(db)
    db.close()
    return {"skills": skills, "count": len(skills)}


@router.get("/api/skills/active")
async def list_active_skills():
    db = get_db()
    skills_enabled = get_setting(db, "skills_enabled", "true") == "true"
    skills = list_skills_with_state(db)
    db.close()
    active = [s for s in skills if s["enabled"]] if skills_enabled else []
    return {"skills": active, "count": len(active), "skills_enabled": skills_enabled}


@router.put("/api/skills/{skill_key}")
async def update_skill(skill_key: str, request: Request):
    skill_key = skill_key.strip()
    if len(skill_key) > MAX_SKILL_KEY_CHARS or skill_key not in SKILLS_BY_KEY:
        raise HTTPException(status_code=404, detail="Unknown skill")
    body = await read_json_body(request, BODY_LIMIT_DEFAULT_BYTES)
    if "enabled" not in body or not isinstance(body.get("enabled"), bool):
        raise HTTPException(status_code=400, detail="Field 'enabled' (boolean) is required")
    db = get_db()
    set_skill_enabled(db, skill_key, bool(body["enabled"]))
    db.commit()
    skills = list_skills_with_state(db)
    db.close()
    updated = next((s for s in skills if s["key"] == skill_key), None)
    return {"status": "ok", "skill": updated}
141
search.py
Normal file
@@ -0,0 +1,141 @@
"""
JarvisChat - SearXNG integration, perplexity scoring, refusal/hedge detection.
"""
import logging
import math
import re
from urllib.parse import quote, urlparse

import httpx

from config import SEARXNG_BASE, PERPLEXITY_THRESHOLD, REFUSAL_PATTERNS, HEDGE_PATTERNS

log = logging.getLogger("jarvischat")


def sanitize_outbound_url(url: str) -> str:
    if not url:
        return ""
    candidate = url.strip()
    parsed = urlparse(candidate)
    if parsed.scheme.lower() in {"http", "https"} and parsed.netloc:
        return candidate
    return ""


def calculate_perplexity(logprobs: list) -> float:
    if not logprobs:
        return 0.0
    avg_logprob = sum(lp["logprob"] for lp in logprobs) / len(logprobs)
    return math.exp(-avg_logprob)


def is_uncertain(logprobs: list, threshold: float = PERPLEXITY_THRESHOLD) -> bool:
    if not logprobs:
        return False
    perplexity = calculate_perplexity(logprobs)
    log.info(f"Perplexity: {perplexity:.2f} (threshold: {threshold})")
    return perplexity > threshold


def is_refusal(text: str) -> bool:
    match = REFUSAL_PATTERNS.search(text)
    if match:
        log.info(f"Refusal detected: '{match.group()}'")
        return True
    return False


def clean_hedging(text: str) -> str:
    cleaned = text
    for pattern in HEDGE_PATTERNS:
        cleaned = re.sub(pattern, "", cleaned, flags=re.IGNORECASE)
    return cleaned.strip()


def format_search_results(results: list) -> str:
    if not results:
        return ""
    lines = ["[LIVE WEB DATA]\n"]
    for i, r in enumerate(results, 1):
        lines.append(f"{i}. {r['title']}")
        if r["content"]:
            lines.append(f" {r['content']}")
        lines.append("")
    lines.append("\nAnswer directly using the data above. No apologies. No disclaimers. Just answer.")
    return "\n".join(lines)


def format_direct_answer(question: str, results: list) -> str:
    if not results:
        return "No search results found."
    lines = ["Here's what I found:\n"]
    for r in results[:3]:
        lines.append(f"**{r['title']}**")
        if r["content"]:
            lines.append(f"{r['content']}")
        lines.append("")
    return "\n".join(lines).strip()


def extract_search_query(user_message: str) -> str:
    query = user_message.strip()
    weather_lead = re.match(r"^(?:what('?s| is) the\s+)?(?:weather|temperature|forecast)\s+(?:in\s+|for\s+)?(.+)", query, re.IGNORECASE)
    if weather_lead:
        return (weather_lead.group(2) + " weather").strip()[:100]
    price_lead = re.match(r"^(?:what('?s| is| are)\s+)?(?:the\s+)?(?:price|spot price)\s+(?:of\s+|for\s+)?(.+)", query, re.IGNORECASE)
    if price_lead:
        return (price_lead.group(2) + " price today USD").strip()[:100]
    return query[:100]


async def query_searxng(query: str, max_results: int = 5) -> list:
    log.info(f"Querying SearXNG: '{query}'")
    async with httpx.AsyncClient() as client:
        weather_match = re.search(
            r"(?:weather|temperature|forecast)\s+(?:in\s+)?(.+?)(?:\s+right now|\s+today|\s+degrees)?$",
            query, re.IGNORECASE,
        )
        if weather_match or "weather" in query.lower() or "temperature" in query.lower():
            location = (
                weather_match.group(1) if weather_match
                else re.sub(r"(weather|temperature|forecast|right now|today|degrees)", "", query, flags=re.IGNORECASE).strip()
            )
            if location:
                # Percent-encode the location so characters such as "?", "#" or "/"
                # cannot change the path or query string of the wttr.in URL.
                encoded_location = quote(location, safe="")
                try:
                    resp = await client.get(f"https://wttr.in/{encoded_location}?format=3", timeout=10.0,
                                            headers={"User-Agent": "curl/7.68.0"})
                    if resp.status_code == 200:
                        return [{"title": "Current Weather",
                                 "url": sanitize_outbound_url(f"https://wttr.in/{encoded_location}"),
                                 "content": resp.text.strip()}]
                except Exception as e:
                    log.warning(f"wttr.in error: {e}")

        try:
            resp = await client.get(
                f"{SEARXNG_BASE}/search",
                params={"q": query, "format": "json", "categories": "general"},
                timeout=10.0,
            )
            if resp.status_code == 200:
                data = resp.json()
                results = []
                for answer in data.get("answers", []):
                    results.append({"title": "Direct Answer", "url": "", "content": answer})
                for box in data.get("infoboxes", []):
                    content = box.get("content", "")
                    if not content and box.get("attributes"):
                        content = " | ".join([f"{a.get('label','')}: {a.get('value','')}" for a in box["attributes"]])
                    results.append({
                        "title": box.get("infobox", "Info"),
                        "url": sanitize_outbound_url(box.get("urls", [{}])[0].get("url", "") if box.get("urls") else ""),
                        "content": content,
                    })
                for r in data.get("results", [])[:max_results]:
                    results.append({"title": r.get("title", ""), "url": sanitize_outbound_url(r.get("url", "")), "content": r.get("content", "")})
                log.info(f"SearXNG returned {len(results)} results")
                return results
        except Exception as e:
            log.error(f"SearXNG error: {e}")
    return []
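`calculate_perplexity` computes the exponential of the negative mean log-probability, so a response whose tokens were each chosen with probability p has perplexity 1/p. A short worked example of that relationship follows. It assumes each entry is a dict whose `"logprob"` key holds a natural-log probability, which is how the function reads its input; the sample values and the standalone `perplexity` name are made up for illustration.

```python
import math
from typing import Dict, List


def perplexity(logprobs: List[Dict[str, float]]) -> float:
    """exp of the negative mean log-probability; 0.0 for an empty list."""
    if not logprobs:
        return 0.0
    mean_logprob = sum(lp["logprob"] for lp in logprobs) / len(logprobs)
    return math.exp(-mean_logprob)


# Illustrative inputs only: four tokens, each picked with the same probability.
confident = [{"logprob": math.log(0.9)}] * 4   # p = 0.9 per token
unsure = [{"logprob": math.log(0.25)}] * 4     # p = 0.25 per token

print(round(perplexity(confident), 3))  # close to 1 / 0.9
print(round(perplexity(unsure), 3))     # close to 1 / 0.25
```

Higher values mean the model spread its probability over more alternatives, which is why `is_uncertain` compares the result against `PERPLEXITY_THRESHOLD`.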
175
security.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
"""
|
||||||
|
JarvisChat - Security utilities.
|
||||||
|
PIN hashing, audit logging, incident tracking, CSRF/origin checks,
|
||||||
|
rate limiting, request helpers.
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
from config import (
|
||||||
|
ALLOWED_NETWORKS, TRUST_X_FORWARDED_FOR, TRUSTED_ORIGINS,
|
||||||
|
BODY_LIMIT_DEFAULT_BYTES, BODY_LIMIT_CHAT_BYTES, BODY_LIMIT_PROFILE_BYTES,
|
||||||
|
RATE_WINDOW_SECONDS, RL_LOGIN_PER_WINDOW, RL_CHAT_PER_WINDOW,
|
||||||
|
RL_SEARCH_PER_WINDOW, RL_STATS_PER_WINDOW, RL_WRITE_PER_WINDOW,
|
||||||
|
RL_DEFAULT_PER_WINDOW, VERSION,
|
||||||
|
)
|
||||||
|
|
||||||
|
import ipaddress

log = logging.getLogger("jarvischat")

SESSIONS: dict = {}
PIN_ATTEMPTS: dict = {}
RATE_EVENTS: dict = defaultdict(deque)
SESSION_LOCK = Lock()
RATE_LOCK = Lock()


def hash_pin(pin: str, salt_hex: Optional[str] = None) -> tuple:
    salt = bytes.fromhex(salt_hex) if salt_hex else os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", pin.encode("utf-8"), salt, 200_000)
    return salt.hex(), digest.hex()


def audit_event(event: str, outcome: str, *, ip: str = "unknown", role: str = "none",
                details: str = "", warning: bool = False) -> None:
    payload = {"event": event, "outcome": outcome, "ip": ip, "role": role, "details": details[:300]}
    msg = "AUDIT " + json.dumps(payload, separators=(",", ":"))
    if warning:
        log.warning(msg)
    else:
        log.info(msg)


def create_incident_key() -> str:
    ts = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    return f"INC-{ts}-{uuid.uuid4().hex[:8].upper()}"


def customer_error_envelope(message: str, incident_key: str) -> dict:
    return {
        "detail": message, "error_key": incident_key,
        "error": {"message": message, "incident_key": incident_key,
                  "support_hint": "Share this incident key for exact diagnostics."},
    }


def log_incident(event: str, *, message: str, request: Optional[Request] = None,
                 exc: Optional[Exception] = None) -> str:
    incident_key = create_incident_key()
    payload = {
        "event": event, "incident_key": incident_key, "message": message,
        "app_version": VERSION, "pid": os.getpid(), "python": platform.python_version(),
        "platform": platform.platform(),
        "method": request.method if request else "",
        "path": request.url.path if request else "",
        "client_ip": get_client_ip(request) if request else "",
    }
    if exc:
        log.exception("INCIDENT " + json.dumps(payload, separators=(",", ":")))
    else:
        log.error("INCIDENT " + json.dumps(payload, separators=(",", ":")))
    return incident_key


def get_client_ip(request: Request) -> str:
    forwarded = request.headers.get("x-forwarded-for", "").strip()
    if TRUST_X_FORWARDED_FOR and forwarded:
        return forwarded.split(",")[0].strip()
    if request.client and request.client.host:
        return request.client.host
    return "unknown"


def is_ip_allowed(ip: str) -> bool:
    normalized = ip.strip().lower()
    if normalized in {"localhost", "testclient"}:
        normalized = "127.0.0.1"
    try:
        ip_obj = ipaddress.ip_address(normalized)
    except ValueError:
        return False
    for network in ALLOWED_NETWORKS:
        if ip_obj in network:
            return True
    return False


def request_body_limit(path: str) -> int:
    if path in {"/api/chat", "/api/search"}:
        return BODY_LIMIT_CHAT_BYTES
    if path == "/api/profile":
        return BODY_LIMIT_PROFILE_BYTES
    return BODY_LIMIT_DEFAULT_BYTES


def rate_policy(path: str, method: str, ip: str, sid: str) -> tuple:
    identity = sid or ip
    if path == "/api/auth/login":
        return f"login:{ip}", RL_LOGIN_PER_WINDOW
    if path == "/api/chat":
        return f"chat:{identity}", RL_CHAT_PER_WINDOW
    if path == "/api/search":
        return f"search:{identity}", RL_SEARCH_PER_WINDOW
    if path == "/api/stats":
        return f"stats:{ip}", RL_STATS_PER_WINDOW
    if method in {"POST", "PUT", "DELETE", "PATCH"}:
        return f"write:{identity}", RL_WRITE_PER_WINDOW
    return f"api:{identity}", RL_DEFAULT_PER_WINDOW


def check_rate_limit(key: str, limit: int, window_seconds: int) -> tuple:
    now_ts = time.time()
    with RATE_LOCK:
        bucket = RATE_EVENTS[key]
        while bucket and bucket[0] <= (now_ts - window_seconds):
            bucket.popleft()
        if len(bucket) >= limit:
            retry_after = max(1, int(math.ceil(window_seconds - (now_ts - bucket[0]))))
            return False, retry_after
        bucket.append(now_ts)
        return True, 0


def origin_allowed(request: Request) -> bool:
    host = request.headers.get("host", "").strip()
    expected_origin = f"{request.url.scheme}://{host}".rstrip("/") if host else ""
    origin = request.headers.get("origin", "").strip().rstrip("/")
    referer = request.headers.get("referer", "").strip()
    if origin:
        return origin == expected_origin or origin in TRUSTED_ORIGINS
    if referer:
        parsed = urlparse(referer)
        ref_origin = f"{parsed.scheme}://{parsed.netloc}".rstrip("/")
        return ref_origin == expected_origin or ref_origin in TRUSTED_ORIGINS
    return False


def is_state_changing(method: str) -> bool:
    return method in {"POST", "PUT", "DELETE", "PATCH"}


async def read_json_body(request: Request, max_bytes: int) -> dict:
    raw = await request.body()
    if len(raw) > max_bytes:
        raise HTTPException(status_code=413, detail="Request payload too large")
    if not raw:
        return {}
    try:
        return json.loads(raw.decode("utf-8"))
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON payload")
|
||||||
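The sliding-window logic in `check_rate_limit` above can be exercised in isolation. The following is a minimal sketch, not the project's code: the function name `check_rate_limit_sketch`, the module-level `_EVENTS` store, and the optional `now` parameter are assumptions added here so the behavior can be checked without waiting on the real clock.

```python
import math
import time
from collections import defaultdict, deque
from threading import Lock

# Hypothetical stand-ins for RATE_EVENTS and RATE_LOCK.
_EVENTS = defaultdict(deque)
_LOCK = Lock()


def check_rate_limit_sketch(key, limit, window_seconds, now=None):
    """Return (allowed, retry_after_seconds) for one request under `key`."""
    # `now` is an assumption for deterministic testing; the diff uses time.time().
    now_ts = time.time() if now is None else now
    with _LOCK:
        bucket = _EVENTS[key]
        # Drop timestamps that have aged out of the window.
        while bucket and bucket[0] <= now_ts - window_seconds:
            bucket.popleft()
        if len(bucket) >= limit:
            # Seconds until the oldest event leaves the window, at least 1.
            retry_after = max(1, int(math.ceil(window_seconds - (now_ts - bucket[0]))))
            return False, retry_after
        bucket.append(now_ts)
        return True, 0
```

With a limit of 2 per 10 seconds, a third request two seconds after the first is rejected with a retry hint of 8 seconds, and a request at exactly the 10-second mark is accepted because the oldest timestamp is evicted first.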
BIN
static/jcscreenie.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 270 KiB |
@@ -47,13 +47,12 @@ body { font-family: var(--font-body); background: var(--bg-primary); color: var(
|
|||||||
.delete-all-btn { padding: 10px 12px; background: transparent; border: 1px solid var(--danger); border-radius: var(--radius); color: var(--danger); font-size: 14px; cursor: pointer; transition: all 0.2s; }
|
.delete-all-btn { padding: 10px 12px; background: transparent; border: 1px solid var(--danger); border-radius: var(--radius); color: var(--danger); font-size: 14px; cursor: pointer; transition: all 0.2s; }
|
||||||
.delete-all-btn:hover { background: var(--danger); color: #fff; }
|
.delete-all-btn:hover { background: var(--danger); color: #fff; }
|
||||||
.conversation-list { flex: 1; overflow-y: auto; padding: 8px; }
|
.conversation-list { flex: 1; overflow-y: auto; padding: 8px; }
|
||||||
.conv-item { padding: 10px 12px; border-radius: var(--radius); cursor: pointer; margin-bottom: 2px; display: flex; justify-content: space-between; align-items: center; transition: background 0.15s; font-size: 13px; color: var(--text-secondary); }
|
.conv-item { padding: 10px 12px; border-radius: var(--radius); cursor: pointer; margin-bottom: 2px; display: flex; align-items: center; gap: 8px; transition: background 0.15s; font-size: 13px; color: var(--text-secondary); }
|
||||||
.conv-item:hover { background: var(--bg-hover); color: var(--text-primary); }
|
.conv-item:hover { background: var(--bg-hover); color: var(--text-primary); }
|
||||||
.conv-item.active { background: var(--bg-tertiary); color: var(--text-primary); }
|
.conv-item.active { background: var(--bg-tertiary); color: var(--text-primary); }
|
||||||
.conv-item .conv-title { overflow: hidden; text-overflow: ellipsis; white-space: nowrap; flex: 1; }
|
.conv-item .conv-trash { color: var(--text-muted); cursor: pointer; padding: 2px 2px; font-size: 15px; flex-shrink: 0; transition: color 0.15s; }
|
||||||
.conv-item .conv-delete { opacity: 0; color: var(--danger); cursor: pointer; padding: 2px 6px; font-size: 16px; }
|
.conv-item .conv-trash:hover { opacity: 1; color: var(--danger); }
|
||||||
.conv-item:hover .conv-delete { opacity: 0.7; }
|
.conv-item .conv-title { overflow: hidden; text-overflow: ellipsis; white-space: nowrap; flex: 1; min-width: 0; }
|
||||||
.conv-item .conv-delete:hover { opacity: 1; }
|
|
||||||
.sidebar-footer { padding: 12px 16px; border-top: 1px solid var(--border); font-size: 11px; color: var(--text-muted); font-family: var(--font-mono); }
|
.sidebar-footer { padding: 12px 16px; border-top: 1px solid var(--border); font-size: 11px; color: var(--text-muted); font-family: var(--font-mono); }
|
||||||
.sidebar-footer .status-row { display: flex; align-items: center; gap: 8px; margin-bottom: 4px; }
|
.sidebar-footer .status-row { display: flex; align-items: center; gap: 8px; margin-bottom: 4px; }
|
||||||
.stats-panel { margin-top: 10px; padding-top: 10px; border-top: 1px solid var(--border); }
|
.stats-panel { margin-top: 10px; padding-top: 10px; border-top: 1px solid var(--border); }
|
||||||
@@ -227,7 +226,7 @@ body { font-family: var(--font-body); background: var(--bg-primary); color: var(
|
|||||||
|
|
||||||
<aside class="sidebar" id="sidebar">
|
<aside class="sidebar" id="sidebar">
|
||||||
<div class="sidebar-header">
|
<div class="sidebar-header">
|
||||||
<img class="logo" src="/static/logo.jpg" alt="JarvisChat Logo" onerror="this.style.display='none'">
|
<img class="logo" src="/static/logo.png" alt="JarvisChat Logo" onerror="this.style.display='none'">
|
||||||
<h1>⚡ JarvisChat {{ version }}</h1>
|
<h1>⚡ JarvisChat {{ version }}</h1>
|
||||||
<div class="subtitle">🦙 local coding companion</div>
|
<div class="subtitle">🦙 local coding companion</div>
|
||||||
<div class="btn-row">
|
<div class="btn-row">
|
||||||
@@ -983,8 +982,7 @@ async function loadConversations() {
|
|||||||
convs.forEach(c => {
|
convs.forEach(c => {
|
||||||
const div = document.createElement('div');
|
const div = document.createElement('div');
|
||||||
div.className = 'conv-item' + (c.id === currentConvId ? ' active' : '');
|
div.className = 'conv-item' + (c.id === currentConvId ? ' active' : '');
|
||||||
const delBtn = currentRole === 'admin' ? `<span class="conv-delete" onclick="event.stopPropagation();deleteConversation('${c.id}')">×</span>` : '';
|
div.innerHTML = `<span class="conv-trash" onclick="event.stopPropagation();deleteConversation('${c.id}')" title="Delete conversation">🗑</span><span class="conv-title" onclick="loadConversation('${c.id}')">${c.title}</span>`;
|
||||||
div.innerHTML = `<span class="conv-title" onclick="loadConversation('${c.id}')">${c.title}</span>${delBtn}`;
|
|
||||||
list.appendChild(div);
|
list.appendChild(div);
|
||||||
});
|
});
|
||||||
} catch(e) {}
|
} catch(e) {}
|
||||||
|
|||||||
@@ -3,16 +3,18 @@ from pathlib import Path
|
|||||||
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
import app as app_module
|
import app
|
||||||
|
import db
|
||||||
|
from security import SESSIONS, PIN_ATTEMPTS
|
||||||
|
|
||||||
|
|
||||||
def make_client(tmp_path: Path) -> TestClient:
|
def make_client(tmp_path: Path) -> TestClient:
|
||||||
os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
|
os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
|
||||||
app_module.DB_PATH = tmp_path / "jarvischat-test.db"
|
db.DB_PATH = tmp_path / "jarvischat-test.db"
|
||||||
app_module.SESSIONS.clear()
|
SESSIONS.clear()
|
||||||
app_module.PIN_ATTEMPTS.clear()
|
PIN_ATTEMPTS.clear()
|
||||||
app_module.init_db()
|
db.init_db()
|
||||||
return TestClient(app_module.app)
|
return TestClient(app.app)
|
||||||
|
|
||||||
|
|
||||||
def test_guest_read_only_admin_write_blocked(tmp_path: Path):
|
def test_guest_read_only_admin_write_blocked(tmp_path: Path):
|
||||||
@@ -20,7 +22,7 @@ def test_guest_read_only_admin_write_blocked(tmp_path: Path):
|
|||||||
guest = client.post("/api/auth/guest", headers={"Origin": "http://testserver"})
|
guest = client.post("/api/auth/guest", headers={"Origin": "http://testserver"})
|
||||||
assert guest.status_code == 200
|
assert guest.status_code == 200
|
||||||
sid = guest.json()["session_id"]
|
sid = guest.json()["session_id"]
|
||||||
headers = {"X-Session-ID": sid}
|
headers = {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||||
|
|
||||||
read_resp = client.get("/api/memories", headers=headers)
|
read_resp = client.get("/api/memories", headers=headers)
|
||||||
assert read_resp.status_code == 200
|
assert read_resp.status_code == 200
|
||||||
@@ -74,5 +76,5 @@ def test_logout_revokes_session(tmp_path: Path):
|
|||||||
logout = client.post("/api/auth/logout", headers=headers)
|
logout = client.post("/api/auth/logout", headers=headers)
|
||||||
assert logout.status_code == 200
|
assert logout.status_code == 200
|
||||||
|
|
||||||
after = client.get("/api/memories", headers={"X-Session-ID": sid})
|
after = client.get("/api/memories", headers={"X-Session-ID": sid, "Origin": "http://testserver"})
|
||||||
assert after.status_code == 401
|
assert after.status_code == 401
|
||||||
|
|||||||
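The tests in this hunk now send an `Origin` header on every request, which lines up with the `origin_allowed` check in `security.py`. The comparison can be sketched with plain strings in place of the FastAPI `Request` object. The function name `origin_allowed_sketch` and its parameters are assumptions for illustration, as is the empty default for the trusted-origin set.

```python
from urllib.parse import urlparse


def origin_allowed_sketch(scheme, host, origin="", referer="", trusted=frozenset()):
    """Mirror the Origin/Referer comparison using plain strings (hypothetical helper)."""
    expected = f"{scheme}://{host}".rstrip("/") if host else ""
    origin = origin.strip().rstrip("/")
    if origin:
        # An explicit Origin header takes precedence over Referer.
        return origin == expected or origin in trusted
    if referer:
        parsed = urlparse(referer.strip())
        ref_origin = f"{parsed.scheme}://{parsed.netloc}".rstrip("/")
        return ref_origin == expected or ref_origin in trusted
    # Neither header present: reject.
    return False
```

Under this sketch a request with neither header is rejected, which would explain why the test headers gained `"Origin": "http://testserver"`.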
@@ -2,19 +2,24 @@ import json
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import httpx
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
import app as app_module
|
import app
|
||||||
|
import config
|
||||||
|
import db
|
||||||
|
import routers.chat
|
||||||
|
from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
|
||||||
|
|
||||||
|
|
||||||
def make_client(tmp_path: Path) -> TestClient:
|
def make_client(tmp_path: Path) -> TestClient:
|
||||||
os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
|
os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
|
||||||
app_module.DB_PATH = tmp_path / "jarvischat-streaming.db"
|
db.DB_PATH = tmp_path / "jarvischat-streaming.db"
|
||||||
app_module.SESSIONS.clear()
|
SESSIONS.clear()
|
||||||
app_module.PIN_ATTEMPTS.clear()
|
PIN_ATTEMPTS.clear()
|
||||||
app_module.RATE_EVENTS.clear()
|
RATE_EVENTS.clear()
|
||||||
app_module.init_db()
|
db.init_db()
|
||||||
return TestClient(app_module.app, raise_server_exceptions=False)
|
return TestClient(app.app, raise_server_exceptions=False)
|
||||||
|
|
||||||
|
|
||||||
def parse_sse_payloads(body: str) -> list[dict]:
|
def parse_sse_payloads(body: str) -> list[dict]:
|
||||||
@@ -65,11 +70,11 @@ def test_chat_stream_emits_tokens_and_done(tmp_path: Path, monkeypatch):
|
|||||||
def stream_stub(self, method, url, json=None, timeout=None):
|
def stream_stub(self, method, url, json=None, timeout=None):
|
||||||
return _MockStreamResponse(events)
|
return _MockStreamResponse(events)
|
||||||
|
|
||||||
monkeypatch.setattr(app_module.httpx.AsyncClient, "stream", stream_stub)
|
monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)
|
||||||
|
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
"/api/chat",
|
"/api/chat",
|
||||||
json={"message": "hello", "model": app_module.DEFAULT_MODEL},
|
json={"message": "hello", "model": config.DEFAULT_MODEL},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
@@ -92,7 +97,7 @@ def test_chat_auto_search_trigger_emits_search_events(tmp_path: Path, monkeypatc
|
|||||||
first_stream = _stream_json_lines(
|
first_stream = _stream_json_lines(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"message": {"content": "I am uncertain."},
|
"message": {"content": "I don't have current data on that question."},
|
||||||
"logprobs": [{"logprob": -5.0}],
|
"logprobs": [{"logprob": -5.0}],
|
||||||
},
|
},
|
||||||
{"done": True, "eval_count": 2, "eval_duration": 1000000000},
|
{"done": True, "eval_count": 2, "eval_duration": 1000000000},
|
||||||
@@ -118,12 +123,12 @@ def test_chat_auto_search_trigger_emits_search_events(tmp_path: Path, monkeypatc
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
monkeypatch.setattr(app_module.httpx.AsyncClient, "stream", stream_stub)
|
monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)
|
||||||
monkeypatch.setattr(app_module, "query_searxng", search_stub)
|
monkeypatch.setattr(routers.chat, "query_searxng", search_stub)
|
||||||
|
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
"/api/chat",
|
"/api/chat",
|
||||||
json={"message": "what is the latest value", "model": app_module.DEFAULT_MODEL},
|
json={"message": "what is the latest value", "model": config.DEFAULT_MODEL},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
@@ -153,13 +158,13 @@ def test_memory_command_paths_remember_and_forget(tmp_path: Path, monkeypatch):
|
|||||||
def stream_stub(self, method, url, json=None, timeout=None):
|
def stream_stub(self, method, url, json=None, timeout=None):
|
||||||
return _MockStreamResponse(base_stream)
|
return _MockStreamResponse(base_stream)
|
||||||
|
|
||||||
monkeypatch.setattr(app_module.httpx.AsyncClient, "stream", stream_stub)
|
monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)
|
||||||
|
|
||||||
remember_resp = client.post(
|
remember_resp = client.post(
|
||||||
"/api/chat",
|
"/api/chat",
|
||||||
json={
|
json={
|
||||||
"message": "remember that my favorite language is rust",
|
"message": "remember that my favorite language is rust",
|
||||||
"model": app_module.DEFAULT_MODEL,
|
"model": config.DEFAULT_MODEL,
|
||||||
},
|
},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
@@ -167,7 +172,7 @@ def test_memory_command_paths_remember_and_forget(tmp_path: Path, monkeypatch):
|
|||||||
remember_events = parse_sse_payloads(remember_resp.text)
|
remember_events = parse_sse_payloads(remember_resp.text)
|
||||||
assert any("Remembered" in p.get("token", "") for p in remember_events)
|
assert any("Remembered" in p.get("token", "") for p in remember_events)
|
||||||
|
|
||||||
memories_after_add = client.get("/api/memories", headers={"X-Session-ID": sid})
|
memories_after_add = client.get("/api/memories", headers={"X-Session-ID": sid, "Origin": "http://testserver"})
|
||||||
assert memories_after_add.status_code == 200
|
assert memories_after_add.status_code == 200
|
||||||
assert memories_after_add.json().get("count", 0) >= 1
|
assert memories_after_add.json().get("count", 0) >= 1
|
||||||
|
|
||||||
@@ -175,7 +180,7 @@ def test_memory_command_paths_remember_and_forget(tmp_path: Path, monkeypatch):
|
|||||||
"/api/chat",
|
"/api/chat",
|
||||||
json={
|
json={
|
||||||
"message": "forget about my favorite language",
|
"message": "forget about my favorite language",
|
||||||
"model": app_module.DEFAULT_MODEL,
|
"model": config.DEFAULT_MODEL,
|
||||||
},
|
},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
@@ -183,6 +188,6 @@ def test_memory_command_paths_remember_and_forget(tmp_path: Path, monkeypatch):
|
|||||||
forget_events = parse_sse_payloads(forget_resp.text)
|
forget_events = parse_sse_payloads(forget_resp.text)
|
||||||
assert any("Forgot" in p.get("token", "") for p in forget_events)
|
assert any("Forgot" in p.get("token", "") for p in forget_events)
|
||||||
|
|
||||||
memories_after_forget = client.get("/api/memories", headers={"X-Session-ID": sid})
|
memories_after_forget = client.get("/api/memories", headers={"X-Session-ID": sid, "Origin": "http://testserver"})
|
||||||
assert memories_after_forget.status_code == 200
|
assert memories_after_forget.status_code == 200
|
||||||
assert memories_after_forget.json().get("count", 0) == 0
|
assert memories_after_forget.json().get("count", 0) == 0
|
||||||
|
|||||||
222
tests/test_completions.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
import json
import os
from pathlib import Path

import httpx
from fastapi.testclient import TestClient

import app
import config
import db
import routers.completions
from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS


def make_client(tmp_path: Path) -> TestClient:
    os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
    db.DB_PATH = tmp_path / "jarvischat-completions.db"
    SESSIONS.clear()
    PIN_ATTEMPTS.clear()
    RATE_EVENTS.clear()
    db.init_db()
    return TestClient(app.app, raise_server_exceptions=False)


TEST_API_KEY = "test-sk-jarvischat-completions"


def _auth_headers(extra: dict = None) -> dict:
    h = {"Authorization": f"Bearer {TEST_API_KEY}", "Content-Type": "application/json", "Origin": "http://testserver"}
    if extra:
        h.update(extra)
    return h


class _MockStreamResponse:
    def __init__(self, lines: list[str]):
        self._lines = lines

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        return False

    async def aiter_lines(self):
        for line in self._lines:
            yield line


class _MockAsyncPostResponse:
    def __init__(self, status_code=200, json_data=None):
        self.status_code = status_code
        self._json_data = json_data or {}

    def json(self):
        return self._json_data


def _stream_json_lines(events: list[dict]) -> list[str]:
    return [json.dumps(event) for event in events]


def test_completions_missing_api_key(tmp_path: Path):
    with make_client(tmp_path) as client:
        resp = client.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "user", "content": "hi"}]},
        )
        assert resp.status_code == 401


def test_completions_invalid_api_key(tmp_path: Path):
    with make_client(tmp_path) as client:
        resp = client.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "user", "content": "hi"}]},
            headers={"Authorization": "Bearer wrong-key", "Origin": "http://testserver"},
        )
        assert resp.status_code == 401


def test_completions_no_messages(tmp_path: Path, monkeypatch):
    monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
    with make_client(tmp_path) as client:
        resp = client.post("/v1/chat/completions", json={}, headers=_auth_headers())
        assert resp.status_code == 400


def test_completions_empty_messages(tmp_path: Path, monkeypatch):
    monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
    with make_client(tmp_path) as client:
        resp = client.post("/v1/chat/completions", json={"messages": []}, headers=_auth_headers())
        assert resp.status_code == 400


def test_completions_no_user_message(tmp_path: Path, monkeypatch):
    monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
    with make_client(tmp_path) as client:
        resp = client.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "assistant", "content": "hello"}]},
            headers=_auth_headers(),
        )
        assert resp.status_code == 400


def test_completions_streaming(tmp_path: Path, monkeypatch):
    monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
    events = _stream_json_lines([
        {"choices": [{"delta": {"content": "Hello"}, "logprobs": None}]},
        {"choices": [{"delta": {"content": " world"}, "logprobs": None}]},
        {"choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {"tokens_per_second": 15.0}},
    ])

    call_count = 0

    def stream_stub(self, method, url, json=None, timeout=None):
        nonlocal call_count
        call_count += 1
        return _MockStreamResponse(events)

    monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)

    with make_client(tmp_path) as client:
        resp = client.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
            headers=_auth_headers(),
        )
        assert resp.status_code == 200
        body = resp.text
        assert "data: [DONE]" in body
        assert "Hello" in body or "world" in body
        assert "chatcmpl-" in body


def test_completions_blocking(tmp_path: Path, monkeypatch):
    monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
    events = _stream_json_lines([
        {"choices": [{"delta": {"content": "Hello world"}, "logprobs": None}]},
        {"choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {}},
    ])

    def stream_stub(self, method, url, json=None, timeout=None):
        return _MockStreamResponse(events)

    monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)

    with make_client(tmp_path) as client:
        resp = client.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "user", "content": "hi"}], "stream": False},
            headers=_auth_headers(),
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["object"] == "chat.completion"
        assert data["choices"][0]["message"]["content"] == "Hello world"


def test_completions_fim_passthrough(tmp_path: Path, monkeypatch):
    monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
    fim_data = {"prompt": "def foo():\n ", "suffix": "\n return x", "model": "llama3.1:latest"}

    async def mock_post(self, url, json=None, timeout=None):
        return _MockAsyncPostResponse(json_data={"choices": [{"text": "pass"}], "usage": {}})

    monkeypatch.setattr(httpx.AsyncClient, "post", mock_post)

    with make_client(tmp_path) as client:
        resp = client.post("/v1/chat/completions", json=fim_data, headers=_auth_headers())
        assert resp.status_code == 200
        assert "choices" in resp.json()


def test_completions_connect_error_stream(tmp_path: Path, monkeypatch):
    monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)

    def broken_stream(self, method, url, json=None, timeout=None):
        raise httpx.ConnectError("Connection refused")

    monkeypatch.setattr(httpx.AsyncClient, "stream", broken_stream)

    with make_client(tmp_path) as client:
        resp = client.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
            headers=_auth_headers(),
        )
        assert resp.status_code == 200
        assert "connection_error" in resp.text


def test_completions_connect_error_blocking(tmp_path: Path, monkeypatch):
    monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)

    def broken_stream(self, method, url, json=None, timeout=None):
        raise httpx.ConnectError("Connection refused")

    monkeypatch.setattr(httpx.AsyncClient, "stream", broken_stream)

    with make_client(tmp_path) as client:
        resp = client.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "user", "content": "hi"}], "stream": False},
            headers=_auth_headers(),
        )
        assert resp.status_code == 503


def test_completions_fim_connect_error(tmp_path: Path, monkeypatch):
    monkeypatch.setattr(routers.completions, "COMPLETIONS_API_KEY", TEST_API_KEY)
    fim_data = {"prompt": "def foo():", "model": "llama3.1:latest"}

    def broken_post(self, url, json=None, timeout=None):
        raise httpx.ConnectError("Connection refused")

    monkeypatch.setattr(httpx.AsyncClient, "post", broken_post)

    with make_client(tmp_path) as client:
        resp = client.post("/v1/chat/completions", json=fim_data, headers=_auth_headers())
        assert resp.status_code == 503
|
||||||
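The streaming tests above rely on a stub that behaves like the object returned by `httpx.AsyncClient.stream`: it must work under `async with` and expose `aiter_lines()`. The sketch below shows that pattern on its own, without httpx or the app. The consumer `collect_content` and the driver `run_demo` are hypothetical names added for illustration, and the way the consumer joins `delta.content` fields is an assumption about how a caller might read the events, not a description of `routers/completions.py`.

```python
import asyncio
import json


class MockStreamResponse:
    """Minimal async context manager that yields pre-baked lines."""

    def __init__(self, lines):
        self._lines = lines

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        return False  # do not swallow exceptions

    async def aiter_lines(self):
        for line in self._lines:
            yield line


async def collect_content(response_cm):
    """Hypothetical consumer: join the delta content from each JSON line."""
    parts = []
    async with response_cm as resp:
        async for line in resp.aiter_lines():
            event = json.loads(line)
            for choice in event.get("choices", []):
                parts.append(choice.get("delta", {}).get("content", ""))
    return "".join(parts)


def run_demo():
    events = [
        {"choices": [{"delta": {"content": "Hello"}}]},
        {"choices": [{"delta": {"content": " world"}}]},
        {"choices": [{"delta": {}, "finish_reason": "stop"}]},
    ]
    lines = [json.dumps(e) for e in events]
    return asyncio.run(collect_content(MockStreamResponse(lines)))
```

Because the stub is a plain function returning this object, `monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)` can swap it in without any network access.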
153
tests/test_conversations.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
import os
from pathlib import Path

from fastapi.testclient import TestClient

import app
import db
from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS


def make_client(tmp_path: Path) -> TestClient:
    os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
    db.DB_PATH = tmp_path / "jarvischat-conversations.db"
    SESSIONS.clear()
    PIN_ATTEMPTS.clear()
    RATE_EVENTS.clear()
    db.init_db()
    return TestClient(app.app, raise_server_exceptions=False)


def _admin_headers(client: TestClient) -> dict:
    login = client.post("/api/auth/login", json={"pin": "1234"}, headers={"Origin": "http://testserver"})
    sid = login.json()["session_id"]
    return {"X-Session-ID": sid, "Origin": "http://testserver"}


def _guest_headers(client: TestClient) -> dict:
    sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
    return {"X-Session-ID": sid, "Origin": "http://testserver"}


def test_list_conversations_empty(tmp_path: Path):
    with make_client(tmp_path) as client:
        resp = client.get("/api/conversations", headers=_guest_headers(client))
        assert resp.status_code == 200
        assert resp.json() == []


def test_create_and_list_conversation(tmp_path: Path):
    with make_client(tmp_path) as client:
        headers = _admin_headers(client)

        create = client.post("/api/conversations", json={"title": "Test Chat", "model": "llama3.1:latest"}, headers=headers)
        assert create.status_code == 200
        data = create.json()
        assert data["title"] == "Test Chat"
        assert data["model"] == "llama3.1:latest"

        list_resp = client.get("/api/conversations", headers=headers)
        assert list_resp.status_code == 200
        convs = list_resp.json()
        assert len(convs) == 1
        assert convs[0]["title"] == "Test Chat"


def test_get_conversation_returns_messages(tmp_path: Path):
    with make_client(tmp_path) as client:
        headers = _admin_headers(client)
        create = client.post("/api/conversations", json={"title": "My Chat"}, headers=headers)
        conv_id = create.json()["id"]

        resp = client.get(f"/api/conversations/{conv_id}", headers=headers)
        assert resp.status_code == 200
        data = resp.json()
        assert data["conversation"]["id"] == conv_id
        assert data["messages"] == []


def test_get_conversation_not_found(tmp_path: Path):
    with make_client(tmp_path) as client:
        resp = client.get("/api/conversations/nope", headers=_guest_headers(client))
        assert resp.status_code == 404


def test_update_conversation_title(tmp_path: Path):
    with make_client(tmp_path) as client:
        headers = _admin_headers(client)
        create = client.post("/api/conversations", json={"title": "Old"}, headers=headers)
        conv_id = create.json()["id"]

        update = client.put(f"/api/conversations/{conv_id}", json={"title": "New Title"}, headers=headers)
        assert update.status_code == 200

        get = client.get(f"/api/conversations/{conv_id}", headers=headers)
        assert get.json()["conversation"]["title"] == "New Title"


def test_update_conversation_model(tmp_path: Path):
    with make_client(tmp_path) as client:
        headers = _admin_headers(client)
        create = client.post("/api/conversations", json={"title": "Test"}, headers=headers)
        conv_id = create.json()["id"]

        update = client.put(f"/api/conversations/{conv_id}", json={"model": "qwen2:latest"}, headers=headers)
        assert update.status_code == 200

        get = client.get(f"/api/conversations/{conv_id}", headers=headers)
        assert get.json()["conversation"]["model"] == "qwen2:latest"


def test_delete_conversation(tmp_path: Path):
    with make_client(tmp_path) as client:
        headers = _admin_headers(client)
        create = client.post("/api/conversations", json={"title": "Delete Me"}, headers=headers)
        conv_id = create.json()["id"]

        delete = client.delete(f"/api/conversations/{conv_id}", headers=headers)
        assert delete.status_code == 200

        get = client.get(f"/api/conversations/{conv_id}", headers=_guest_headers(client))
        assert get.status_code == 404


def test_delete_all_conversations(tmp_path: Path):
    with make_client(tmp_path) as client:
        headers = _admin_headers(client)
        client.post("/api/conversations", json={"title": "One"}, headers=headers)
        client.post("/api/conversations", json={"title": "Two"}, headers=headers)

        delete_all = client.delete("/api/conversations", headers=headers)
        assert delete_all.status_code == 200

        list_resp = client.get("/api/conversations", headers=_guest_headers(client))
        assert list_resp.json() == []


def test_guest_cannot_create_conversation(tmp_path: Path):
    with make_client(tmp_path) as client:
        resp = client.post("/api/conversations", json={"title": "test"}, headers=_guest_headers(client))
        assert resp.status_code == 403


def test_guest_cannot_update_conversation(tmp_path: Path):
    with make_client(tmp_path) as client:
        headers = _admin_headers(client)
        create = client.post("/api/conversations", json={"title": "Test"}, headers=headers)
        conv_id = create.json()["id"]

        guest_headers = _guest_headers(client)
        resp = client.put(f"/api/conversations/{conv_id}", json={"title": "hack"}, headers=guest_headers)
        assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
def test_guest_cannot_delete_conversation(tmp_path: Path):
|
||||||
|
with make_client(tmp_path) as client:
|
||||||
|
resp = client.delete("/api/conversations/some-id", headers=_guest_headers(client))
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
def test_guest_cannot_delete_all(tmp_path: Path):
|
||||||
|
with make_client(tmp_path) as client:
|
||||||
|
resp = client.delete("/api/conversations", headers=_guest_headers(client))
|
||||||
|
assert resp.status_code == 403
|
||||||
@@ -1,19 +1,24 @@
 import os
 from pathlib import Path
 
+import httpx
 from fastapi.testclient import TestClient
 
-import app as app_module
+import app
+import config
+import db
+import routers.memories
+from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
 
 
 def make_client(tmp_path: Path) -> TestClient:
     os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
-    app_module.DB_PATH = tmp_path / "jarvischat-errors.db"
-    app_module.SESSIONS.clear()
-    app_module.PIN_ATTEMPTS.clear()
-    app_module.RATE_EVENTS.clear()
-    app_module.init_db()
-    return TestClient(app_module.app, raise_server_exceptions=False)
+    db.DB_PATH = tmp_path / "jarvischat-errors.db"
+    SESSIONS.clear()
+    PIN_ATTEMPTS.clear()
+    RATE_EVENTS.clear()
+    db.init_db()
+    return TestClient(app.app, raise_server_exceptions=False)
 
 
 def test_unhandled_api_exception_returns_friendly_error_with_incident_key(
@@ -23,12 +28,12 @@ def test_unhandled_api_exception_returns_friendly_error_with_incident_key(
         sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()[
             "session_id"
         ]
-        headers = {"X-Session-ID": sid}
+        headers = {"X-Session-ID": sid, "Origin": "http://testserver"}
 
         def boom(_topic=None):
             raise RuntimeError("super secret db internals")
 
-        monkeypatch.setattr(app_module, "get_all_memories", boom)
+        monkeypatch.setattr(routers.memories, "get_all_memories", boom)
 
         resp = client.get("/api/memories", headers=headers)
         assert resp.status_code == 500
@@ -57,11 +62,11 @@ def test_chat_stream_error_hides_internal_exception_and_emits_incident_key(
         def broken_stream(*args, **kwargs):
             return BrokenStreamContext()
 
-        monkeypatch.setattr(app_module.httpx.AsyncClient, "stream", broken_stream)
+        monkeypatch.setattr(httpx.AsyncClient, "stream", broken_stream)
 
         resp = client.post(
             "/api/chat",
-            json={"message": "hello", "model": app_module.DEFAULT_MODEL},
+            json={"message": "hello", "model": config.DEFAULT_MODEL},
             headers=headers,
         )
 
@@ -3,48 +3,42 @@ from pathlib import Path
 
 from fastapi.testclient import TestClient
 
-import app as app_module
+import app
+import db
+from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS, is_ip_allowed
 
 
 def make_client(tmp_path: Path) -> TestClient:
     os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
-    app_module.DB_PATH = tmp_path / "jarvischat-ip.db"
-    app_module.SESSIONS.clear()
-    app_module.PIN_ATTEMPTS.clear()
-    app_module.RATE_EVENTS.clear()
-    app_module.init_db()
-    return TestClient(app_module.app)
+    db.DB_PATH = tmp_path / "jarvischat-ip.db"
+    SESSIONS.clear()
+    PIN_ATTEMPTS.clear()
+    RATE_EVENTS.clear()
+    db.init_db()
+    return TestClient(app.app)
 
 
 def test_ip_helper_allows_local_defaults():
-    assert app_module.is_ip_allowed("127.0.0.1")
-    assert app_module.is_ip_allowed("192.168.1.10")
-    assert app_module.is_ip_allowed("10.0.0.42")
-    assert app_module.is_ip_allowed("172.16.1.2")
-    assert app_module.is_ip_allowed("testclient")
+    assert is_ip_allowed("127.0.0.1")
+    assert is_ip_allowed("192.168.1.10")
+    assert is_ip_allowed("10.0.0.42")
+    assert is_ip_allowed("172.16.1.2")
+    assert is_ip_allowed("testclient")
 
 
 def test_ip_helper_blocks_public_ip():
-    assert not app_module.is_ip_allowed("8.8.8.8")
+    assert not is_ip_allowed("8.8.8.8")
 
 
-def test_middleware_blocks_disallowed_ip(tmp_path: Path):
+def test_middleware_blocks_disallowed_ip(tmp_path: Path, monkeypatch):
+    monkeypatch.setattr(app, "get_client_ip", lambda _req: "8.8.8.8")
     with make_client(tmp_path) as client:
-        original_get_client_ip = app_module.get_client_ip
-        try:
-            app_module.get_client_ip = lambda _req: "8.8.8.8"
-            resp = client.post("/api/auth/guest")
-            assert resp.status_code == 403
-        finally:
-            app_module.get_client_ip = original_get_client_ip
+        resp = client.post("/api/auth/guest")
+        assert resp.status_code == 403
 
 
-def test_middleware_allows_local_ip(tmp_path: Path):
+def test_middleware_allows_local_ip(tmp_path: Path, monkeypatch):
+    monkeypatch.setattr(app, "get_client_ip", lambda _req: "192.168.50.109")
     with make_client(tmp_path) as client:
-        original_get_client_ip = app_module.get_client_ip
-        try:
-            app_module.get_client_ip = lambda _req: "192.168.50.109"
-            resp = client.post("/api/auth/guest")
-            assert resp.status_code == 200
-        finally:
-            app_module.get_client_ip = original_get_client_ip
+        resp = client.post("/api/auth/guest", headers={"Origin": "http://testserver"})
+        assert resp.status_code == 200
161
tests/test_memories.py
Normal file
@@ -0,0 +1,161 @@
+import os
+from pathlib import Path
+
+from fastapi.testclient import TestClient
+
+import app
+import config
+import db
+from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
+
+
+def make_client(tmp_path: Path) -> TestClient:
+    os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
+    db.DB_PATH = tmp_path / "jarvischat-memories.db"
+    SESSIONS.clear()
+    PIN_ATTEMPTS.clear()
+    RATE_EVENTS.clear()
+    db.init_db()
+    return TestClient(app.app, raise_server_exceptions=False)
+
+
+def _admin_headers(client: TestClient) -> dict:
+    login = client.post("/api/auth/login", json={"pin": "1234"}, headers={"Origin": "http://testserver"})
+    sid = login.json()["session_id"]
+    return {"X-Session-ID": sid, "Origin": "http://testserver"}
+
+
+def _guest_headers(client: TestClient) -> dict:
+    sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
+    return {"X-Session-ID": sid, "Origin": "http://testserver"}
+
+
+def _create_memory(client: TestClient, headers: dict, fact: str = "test fact", topic: str = "general") -> int:
+    resp = client.post("/api/memories", json={"fact": fact, "topic": topic}, headers=headers)
+    assert resp.status_code == 200
+    return resp.json()["rowid"]
+
+
+def test_list_memories_empty(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.get("/api/memories", headers=_guest_headers(client))
+        assert resp.status_code == 200
+        assert resp.json()["count"] == 0
+
+
+def test_list_memories_by_topic(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        headers = _admin_headers(client)
+        _create_memory(client, headers, "I like Python", "preference")
+        _create_memory(client, headers, "Building a game", "project")
+
+        general = client.get("/api/memories?topic=preference", headers=_guest_headers(client))
+        assert general.json()["count"] == 1
+        assert general.json()["memories"][0]["topic"] == "preference"
+
+
+def test_create_memory_requires_fact(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.post("/api/memories", json={"fact": ""}, headers=_admin_headers(client))
+        assert resp.status_code == 400
+
+
+def test_create_memory_too_long(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        long_fact = "x" * (config.MAX_MEMORY_FACT_CHARS + 1)
+        resp = client.post("/api/memories", json={"fact": long_fact}, headers=_admin_headers(client))
+        assert resp.status_code == 413
+
+
+def test_edit_memory(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        headers = _admin_headers(client)
+        rowid = _create_memory(client, headers, "original fact")
+
+        edit = client.put(f"/api/memories/{rowid}", json={"fact": "updated fact"}, headers=headers)
+        assert edit.status_code == 200
+
+        memories = client.get("/api/memories", headers=_guest_headers(client)).json()
+        assert any(m["fact"] == "updated fact" for m in memories["memories"])
+
+
+def test_edit_memory_not_found(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.put("/api/memories/99999", json={"fact": "nope"}, headers=_admin_headers(client))
+        assert resp.status_code == 404
+
+
+def test_edit_memory_empty_fact(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        headers = _admin_headers(client)
+        rowid = _create_memory(client, headers, "some fact")
+        resp = client.put(f"/api/memories/{rowid}", json={"fact": ""}, headers=headers)
+        assert resp.status_code == 400
+
+
+def test_edit_memory_too_long(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        headers = _admin_headers(client)
+        rowid = _create_memory(client, headers, "some fact")
+        long_fact = "x" * (config.MAX_MEMORY_FACT_CHARS + 1)
+        resp = client.put(f"/api/memories/{rowid}", json={"fact": long_fact}, headers=headers)
+        assert resp.status_code == 413
+
+
+def test_delete_memory_not_found(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.delete("/api/memories/99999", headers=_admin_headers(client))
+        assert resp.status_code == 404
+
+
+def test_search_memories(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        headers = _admin_headers(client)
+        _create_memory(client, headers, "my favorite color is blue", "preference")
+        _create_memory(client, headers, "running nginx on port 443", "infrastructure")
+
+        resp = client.get("/api/memories/search?q=nginx&limit=5", headers=_guest_headers(client))
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["count"] >= 1
+        assert any("nginx" in r["fact"] for r in data["results"])
+
+
+def test_search_memories_no_results(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.get("/api/memories/search?q=xyznonexistent&limit=5", headers=_guest_headers(client))
+        assert resp.status_code == 200
+        assert resp.json()["count"] == 0
+
+
+def test_memory_stats(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        headers = _admin_headers(client)
+        _create_memory(client, headers, "like rust", "preference")
+        _create_memory(client, headers, "like python", "preference")
+        _create_memory(client, headers, "project game", "project")
+
+        resp = client.get("/api/memories/stats", headers=_guest_headers(client))
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["total"] == 3
+        assert data["by_topic"]["preference"] == 2
+        assert data["by_topic"]["project"] == 1
+
+
+def test_guest_cannot_create_memory(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.post("/api/memories", json={"fact": "hack"}, headers=_guest_headers(client))
+        assert resp.status_code == 403
+
+
+def test_guest_cannot_edit_memory(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.put("/api/memories/1", json={"fact": "hack"}, headers=_guest_headers(client))
+        assert resp.status_code == 403
+
+
+def test_guest_cannot_delete_memory(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.delete("/api/memories/1", headers=_guest_headers(client))
+        assert resp.status_code == 403
138
tests/test_models_router.py
Normal file
@@ -0,0 +1,138 @@
+import os
+from pathlib import Path
+
+import httpx
+from fastapi.testclient import TestClient
+
+import app
+import db
+import routers.models
+from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
+
+
+def make_client(tmp_path: Path) -> TestClient:
+    os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
+    db.DB_PATH = tmp_path / "jarvischat-models.db"
+    SESSIONS.clear()
+    PIN_ATTEMPTS.clear()
+    RATE_EVENTS.clear()
+    db.init_db()
+    return TestClient(app.app, raise_server_exceptions=False)
+
+
+def _guest_headers(client: TestClient) -> dict:
+    sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
+    return {"X-Session-ID": sid, "Origin": "http://testserver"}
+
+
+class _MockAsyncResponse:
+    """Mock for httpx.AsyncClient.get/post that returns a JSON response."""
+    def __init__(self, status_code=200, json_data=None):
+        self.status_code = status_code
+        self._json_data = json_data or {}
+
+    def json(self):
+        return self._json_data
+
+
+async def _mock_get_models(*args, **kwargs):
+    return _MockAsyncResponse(json_data={
+        "data": [{"id": "llama3.1:latest"}, {"id": "qwen2:latest"}]
+    })
+
+
+async def _mock_get_empty_models(*args, **kwargs):
+    return _MockAsyncResponse(json_data={"data": []})
+
+
+async def _mock_connect_error(*args, **kwargs):
+    raise httpx.ConnectError("Connection refused")
+
+
+async def _mock_show_model(*args, **kwargs):
+    return _MockAsyncResponse(json_data={
+        "modelfile": "FROM llama3.1", "parameters": {}
+    })
+
+
+async def _mock_search_available(*args, **kwargs):
+    return _MockAsyncResponse(status_code=200)
+
+
+async def _mock_search_unavailable(*args, **kwargs):
+    raise httpx.ConnectError("refused")
+
+
+def test_list_models(tmp_path: Path, monkeypatch):
+    monkeypatch.setattr(httpx.AsyncClient, "get", _mock_get_models)
+    with make_client(tmp_path) as client:
+        resp = client.get("/api/models", headers=_guest_headers(client))
+        assert resp.status_code == 200
+        models = resp.json()["models"]
+        assert len(models) == 2
+        assert models[0]["name"] == "llama3.1:latest"
+
+
+def test_list_models_connect_error(tmp_path: Path, monkeypatch):
+    monkeypatch.setattr(httpx.AsyncClient, "get", _mock_connect_error)
+    with make_client(tmp_path) as client:
+        resp = client.get("/api/models", headers=_guest_headers(client))
+        assert resp.status_code == 502
+
+
+def test_running_models(tmp_path: Path, monkeypatch):
+    monkeypatch.setattr(httpx.AsyncClient, "get", _mock_get_models)
+    with make_client(tmp_path) as client:
+        resp = client.get("/api/ps", headers=_guest_headers(client))
+        assert resp.status_code == 200
+        assert "data" in resp.json()
+
+
+def test_running_models_connect_error(tmp_path: Path, monkeypatch):
+    monkeypatch.setattr(httpx.AsyncClient, "get", _mock_connect_error)
+    with make_client(tmp_path) as client:
+        resp = client.get("/api/ps", headers=_guest_headers(client))
+        assert resp.status_code == 502
+
+
+def test_show_model(tmp_path: Path, monkeypatch):
+    monkeypatch.setattr(httpx.AsyncClient, "post", _mock_show_model)
+    with make_client(tmp_path) as client:
+        resp = client.post("/api/show", json={"model": "llama3.1:latest"}, headers=_guest_headers(client))
+        assert resp.status_code == 200
+        assert resp.json()["modelfile"] == "FROM llama3.1"
+
+
+def test_show_model_connect_error(tmp_path: Path, monkeypatch):
+    monkeypatch.setattr(httpx.AsyncClient, "post", _mock_connect_error)
+    with make_client(tmp_path) as client:
+        resp = client.post("/api/show", json={"model": "llama3.1:latest"}, headers=_guest_headers(client))
+        assert resp.status_code == 502
+
+
+def test_system_stats(tmp_path: Path, monkeypatch):
+    monkeypatch.setattr(routers.models, "get_gpu_stats", lambda: {"gpu_percent": 15, "vram_percent": 30, "available": True})
+    with make_client(tmp_path) as client:
+        resp = client.get("/api/stats", headers=_guest_headers(client))
+        assert resp.status_code == 200
+        data = resp.json()
+        assert "cpu_percent" in data
+        assert "memory_percent" in data
+        assert data["gpu_percent"] == 15
+        assert data["gpu_available"] is True
+
+
+def test_search_status_available(tmp_path: Path, monkeypatch):
+    monkeypatch.setattr(httpx.AsyncClient, "get", _mock_search_available)
+    with make_client(tmp_path) as client:
+        resp = client.get("/api/search/status", headers=_guest_headers(client))
+        assert resp.status_code == 200
+        assert resp.json()["available"] is True
+
+
+def test_search_status_unavailable(tmp_path: Path, monkeypatch):
+    monkeypatch.setattr(httpx.AsyncClient, "get", _mock_search_unavailable)
+    with make_client(tmp_path) as client:
+        resp = client.get("/api/search/status", headers=_guest_headers(client))
+        assert resp.status_code == 200
+        assert resp.json()["available"] is False
128
tests/test_presets.py
Normal file
@@ -0,0 +1,128 @@
+import os
+from pathlib import Path
+
+from fastapi.testclient import TestClient
+
+import app
+import db
+from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
+
+
+def make_client(tmp_path: Path) -> TestClient:
+    os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
+    db.DB_PATH = tmp_path / "jarvischat-presets.db"
+    SESSIONS.clear()
+    PIN_ATTEMPTS.clear()
+    RATE_EVENTS.clear()
+    db.init_db()
+    return TestClient(app.app, raise_server_exceptions=False)
+
+
+def _admin_headers(client: TestClient) -> dict:
+    login = client.post("/api/auth/login", json={"pin": "1234"}, headers={"Origin": "http://testserver"})
+    sid = login.json()["session_id"]
+    return {"X-Session-ID": sid, "Origin": "http://testserver"}
+
+
+def _guest_headers(client: TestClient) -> dict:
+    sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
+    return {"X-Session-ID": sid, "Origin": "http://testserver"}
+
+
+def test_list_presets_returns_defaults(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.get("/api/presets", headers=_guest_headers(client))
+        assert resp.status_code == 200
+        presets = resp.json()
+        assert len(presets) >= 3
+        names = [p["name"] for p in presets]
+        assert "Coding Companion" in names
+
+
+def test_create_preset(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        headers = _admin_headers(client)
+        resp = client.post("/api/presets", json={"name": "My Preset", "prompt": "You are helpful."}, headers=headers)
+        assert resp.status_code == 200
+        data = resp.json()
+        assert data["name"] == "My Preset"
+        assert data["prompt"] == "You are helpful."
+
+        presets = client.get("/api/presets", headers=_guest_headers(client)).json()
+        assert any(p["name"] == "My Preset" for p in presets)
+
+
+def test_create_preset_requires_name_and_prompt(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        headers = _admin_headers(client)
+        resp = client.post("/api/presets", json={"name": "", "prompt": ""}, headers=headers)
+        assert resp.status_code == 400
+
+        resp = client.post("/api/presets", json={"name": "Only Name"}, headers=headers)
+        assert resp.status_code == 400
+
+
+def test_update_preset(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        headers = _admin_headers(client)
+        create = client.post("/api/presets", json={"name": "Old", "prompt": "Old prompt."}, headers=headers)
+        preset_id = create.json()["id"]
+
+        update = client.put(f"/api/presets/{preset_id}", json={"name": "New", "prompt": "New prompt."}, headers=headers)
+        assert update.status_code == 200
+
+        presets = client.get("/api/presets", headers=_guest_headers(client)).json()
+        updated = next(p for p in presets if p["id"] == preset_id)
+        assert updated["name"] == "New"
+        assert updated["prompt"] == "New prompt."
+
+
+def test_update_preset_requires_fields(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        headers = _admin_headers(client)
+        resp = client.put("/api/presets/nope", json={"name": "", "prompt": ""}, headers=headers)
+        assert resp.status_code == 400
+
+
+def test_delete_preset(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        headers = _admin_headers(client)
+        create = client.post("/api/presets", json={"name": "Temp", "prompt": "Temp."}, headers=headers)
+        preset_id = create.json()["id"]
+
+        delete = client.delete(f"/api/presets/{preset_id}", headers=headers)
+        assert delete.status_code == 200
+
+        presets = client.get("/api/presets", headers=_guest_headers(client)).json()
+        assert not any(p["id"] == preset_id for p in presets)
+
+
+def test_delete_default_preset_is_noop(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        headers = _admin_headers(client)
+        presets_before = client.get("/api/presets", headers=_guest_headers(client)).json()
+        default = next(p for p in presets_before if p["is_default"])
+
+        delete = client.delete(f"/api/presets/{default['id']}", headers=headers)
+        assert delete.status_code == 200
+
+        presets_after = client.get("/api/presets", headers=_guest_headers(client)).json()
+        assert any(p["id"] == default["id"] for p in presets_after)
+
+
+def test_guest_cannot_create_preset(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.post("/api/presets", json={"name": "Hack", "prompt": "Hack"}, headers=_guest_headers(client))
+        assert resp.status_code == 403
+
+
+def test_guest_cannot_update_preset(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.put("/api/presets/some-id", json={"name": "Hack", "prompt": "Hack"}, headers=_guest_headers(client))
+        assert resp.status_code == 403
+
+
+def test_guest_cannot_delete_preset(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.delete("/api/presets/some-id", headers=_guest_headers(client))
+        assert resp.status_code == 403
72
tests/test_profile.py
Normal file
@@ -0,0 +1,72 @@
+import os
+from pathlib import Path
+
+from fastapi.testclient import TestClient
+
+import app
+import config
+import db
+from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
+
+
+def make_client(tmp_path: Path) -> TestClient:
+    os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
+    db.DB_PATH = tmp_path / "jarvischat-profile.db"
+    SESSIONS.clear()
+    PIN_ATTEMPTS.clear()
+    RATE_EVENTS.clear()
+    db.init_db()
+    return TestClient(app.app, raise_server_exceptions=False)
+
+
+def _admin_headers(client: TestClient) -> dict:
+    login = client.post("/api/auth/login", json={"pin": "1234"}, headers={"Origin": "http://testserver"})
+    sid = login.json()["session_id"]
+    return {"X-Session-ID": sid, "Origin": "http://testserver"}
+
+
+def _guest_headers(client: TestClient) -> dict:
+    sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
+    return {"X-Session-ID": sid, "Origin": "http://testserver"}
+
+
+def test_get_profile_returns_content(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.get("/api/profile", headers=_guest_headers(client))
+        assert resp.status_code == 200
+        data = resp.json()
+        assert "content" in data
+        assert "updated_at" in data
+        assert len(data["content"]) > 0
+
+
+def test_get_default_profile(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.get("/api/profile/default", headers=_guest_headers(client))
+        assert resp.status_code == 200
+        assert resp.json()["content"] == config.DEFAULT_PROFILE
+
+
+def test_update_profile(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        headers = _admin_headers(client)
+        resp = client.put("/api/profile", json={"content": "Custom profile text."}, headers=headers)
+        assert resp.status_code == 200
+        assert "updated_at" in resp.json()
+
+        get_resp = client.get("/api/profile", headers=_guest_headers(client))
+        assert get_resp.json()["content"] == "Custom profile text."
+
+
+def test_update_profile_too_long(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        headers = _admin_headers(client)
+        long_content = "x" * (config.MAX_PROFILE_CHARS + 1)
+        resp = client.put("/api/profile", json={"content": long_content}, headers=headers)
+        assert resp.status_code == 413
+
+
+def test_guest_cannot_update_profile(tmp_path: Path):
+    with make_client(tmp_path) as client:
+        resp = client.put("/api/profile", json={"content": "hack"}, headers=_guest_headers(client))
+        assert resp.status_code == 403
@@ -4,28 +4,32 @@ from pathlib import Path
 
 from fastapi.testclient import TestClient
 
-import app as app_module
+import app
+import config
+import db
+import security
+from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
 
 
 def make_client(tmp_path: Path) -> TestClient:
     os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
-    app_module.DB_PATH = tmp_path / "jarvischat-rate.db"
-    app_module.SESSIONS.clear()
-    app_module.PIN_ATTEMPTS.clear()
-    app_module.RATE_EVENTS.clear()
-    app_module.init_db()
-    return TestClient(app_module.app)
+    db.DB_PATH = tmp_path / "jarvischat-rate.db"
+    SESSIONS.clear()
+    PIN_ATTEMPTS.clear()
+    RATE_EVENTS.clear()
+    db.init_db()
+    return TestClient(app.app)
 
 
 def test_stats_rate_limit_hits_429(tmp_path: Path):
-    old_limit = app_module.RL_STATS_PER_WINDOW
-    old_window = app_module.RATE_WINDOW_SECONDS
-    app_module.RL_STATS_PER_WINDOW = 2
-    app_module.RATE_WINDOW_SECONDS = 60
+    old_limit = security.RL_STATS_PER_WINDOW
+    old_window = app.RATE_WINDOW_SECONDS
+    security.RL_STATS_PER_WINDOW = 2
+    app.RATE_WINDOW_SECONDS = 60
     try:
         with make_client(tmp_path) as client:
-            sid = client.post("/api/auth/guest").json()["session_id"]
-            headers = {"X-Session-ID": sid}
+            sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
+            headers = {"X-Session-ID": sid, "Origin": "http://testserver"}
 
             r1 = client.get("/api/stats", headers=headers)
             r2 = client.get("/api/stats", headers=headers)
@@ -35,13 +39,13 @@ def test_stats_rate_limit_hits_429(tmp_path: Path):
|
|||||||
assert r2.status_code == 200
|
assert r2.status_code == 200
|
||||||
assert r3.status_code == 429
|
assert r3.status_code == 429
|
||||||
finally:
|
finally:
|
||||||
app_module.RL_STATS_PER_WINDOW = old_limit
|
security.RL_STATS_PER_WINDOW = old_limit
|
||||||
app_module.RATE_WINDOW_SECONDS = old_window
|
app.RATE_WINDOW_SECONDS = old_window
|
||||||
|
|
||||||
|
|
||||||
def test_large_login_payload_rejected_413(tmp_path: Path):
|
def test_large_login_payload_rejected_413(tmp_path: Path):
|
||||||
with make_client(tmp_path) as client:
|
with make_client(tmp_path) as client:
|
||||||
huge_pin = "1" * (app_module.BODY_LIMIT_DEFAULT_BYTES + 100)
|
huge_pin = "1" * (config.BODY_LIMIT_DEFAULT_BYTES + 100)
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
"/api/auth/login",
|
"/api/auth/login",
|
||||||
data=json.dumps({"pin": huge_pin}),
|
data=json.dumps({"pin": huge_pin}),
|
||||||
@@ -52,12 +56,12 @@ def test_large_login_payload_rejected_413(tmp_path: Path):
|
|||||||
|
|
||||||
def test_chat_message_length_rejected_413(tmp_path: Path):
|
def test_chat_message_length_rejected_413(tmp_path: Path):
|
||||||
with make_client(tmp_path) as client:
|
with make_client(tmp_path) as client:
|
||||||
sid = client.post("/api/auth/guest").json()["session_id"]
|
sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
|
||||||
headers = {"X-Session-ID": sid, "Origin": "http://testserver"}
|
headers = {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||||
message = "x" * (app_module.MAX_CHAT_MESSAGE_CHARS + 1)
|
message = "x" * (config.MAX_CHAT_MESSAGE_CHARS + 1)
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
"/api/chat",
|
"/api/chat",
|
||||||
json={"message": message, "model": app_module.DEFAULT_MODEL},
|
json={"message": message, "model": config.DEFAULT_MODEL},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
assert resp.status_code == 413
|
assert resp.status_code == 413
|
||||||
@@ -65,12 +69,12 @@ def test_chat_message_length_rejected_413(tmp_path: Path):
|
|||||||
|
|
||||||
def test_search_query_length_rejected_413(tmp_path: Path):
|
def test_search_query_length_rejected_413(tmp_path: Path):
|
||||||
with make_client(tmp_path) as client:
|
with make_client(tmp_path) as client:
|
||||||
sid = client.post("/api/auth/guest").json()["session_id"]
|
sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
|
||||||
headers = {"X-Session-ID": sid, "Origin": "http://testserver"}
|
headers = {"X-Session-ID": sid, "Origin": "http://testserver"}
|
||||||
query = "q" * (app_module.MAX_SEARCH_QUERY_CHARS + 1)
|
query = "q" * (config.MAX_SEARCH_QUERY_CHARS + 1)
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
"/api/search",
|
"/api/search",
|
||||||
json={"query": query, "model": app_module.DEFAULT_MODEL},
|
json={"query": query, "model": config.DEFAULT_MODEL},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
assert resp.status_code == 413
|
assert resp.status_code == 413
|
||||||
|
|||||||
186
tests/test_search_route.py
Normal file
@@ -0,0 +1,186 @@
+import json
+import os
+from pathlib import Path
+
+import httpx
+from fastapi.testclient import TestClient
+
+import app
+import config
+import db
+import routers.search_route
+from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
+
+
+def make_client(tmp_path: Path) -> TestClient:
+    os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
+    db.DB_PATH = tmp_path / "jarvischat-search-route.db"
+    SESSIONS.clear()
+    PIN_ATTEMPTS.clear()
+    RATE_EVENTS.clear()
+    db.init_db()
+    return TestClient(app.app, raise_server_exceptions=False)
+
+
+def _guest_headers(client: TestClient) -> dict:
+    sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()["session_id"]
+    return {"X-Session-ID": sid, "Origin": "http://testserver"}
+
+
+def parse_sse_payloads(body: str) -> list[dict]:
+    payloads: list[dict] = []
+    for chunk in body.split("\n\n"):
+        chunk = chunk.strip()
+        if not chunk.startswith("data: "):
+            continue
+        raw = chunk[len("data: ") :]
+        payloads.append(json.loads(raw))
+    return payloads
+
+
+class _MockStreamResponse:
+    def __init__(self, lines: list[str]):
+        self._lines = lines
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        return False
+
+    async def aiter_lines(self):
+        for line in self._lines:
+            yield line
+
+
+def _stream_json_lines(events: list[dict]) -> list[str]:
+    return [json.dumps(event) for event in events]
+
+
+def test_explicit_search_with_results(tmp_path: Path, monkeypatch):
+    with make_client(tmp_path) as client:
+        headers = _guest_headers(client)
+
+        async def search_stub(query: str, max_results: int = 5):
+            return [
+                {"title": "Result One", "url": "https://example.com/1", "content": "First result content."},
+                {"title": "Result Two", "url": "https://example.com/2", "content": "Second result content."},
+            ]
+
+        monkeypatch.setattr(routers.search_route, "query_searxng", search_stub)
+
+        events = _stream_json_lines([
+            {"choices": [{"delta": {"content": "Here's what I found"}, "logprobs": None}]},
+            {"choices": [{"delta": {"content": " about your query."}, "logprobs": None}]},
+            {"choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {}},
+        ])
+
+        def stream_stub(self, method, url, json=None, timeout=None):
+            return _MockStreamResponse(events)
+
+        monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)
+
+        resp = client.post(
+            "/api/search",
+            json={"query": "current events", "model": config.DEFAULT_MODEL},
+            headers=headers,
+        )
+        assert resp.status_code == 200
+        payloads = parse_sse_payloads(resp.text)
+
+        assert any(p.get("searching") is True for p in payloads)
+        assert any("search_results" in p for p in payloads)
+        token_text = "".join(p.get("token", "") for p in payloads if "token" in p)
+        assert "found" in token_text.lower()
+        assert any(p.get("done") and p.get("searched") for p in payloads)
+
+
+def test_explicit_search_no_results(tmp_path: Path, monkeypatch):
+    with make_client(tmp_path) as client:
+        headers = _guest_headers(client)
+
+        async def empty_search(query: str, max_results: int = 5):
+            return []
+
+        monkeypatch.setattr(routers.search_route, "query_searxng", empty_search)
+
+        resp = client.post(
+            "/api/search",
+            json={"query": "nothingness", "model": config.DEFAULT_MODEL},
+            headers=headers,
+        )
+        assert resp.status_code == 200
+        payloads = parse_sse_payloads(resp.text)
+
+        assert any("No search results found" in p.get("token", "") for p in payloads)
+        assert any(p.get("done") for p in payloads)
+        assert not any("search_results" in p for p in payloads)
+
+
+def test_explicit_search_new_conversation_created(tmp_path: Path, monkeypatch):
+    with make_client(tmp_path) as client:
+        headers = _guest_headers(client)
+
+        async def search_stub(query: str, max_results: int = 5):
+            return [{"title": "T", "url": "https://ex.com", "content": "Content."}]
+
+        monkeypatch.setattr(routers.search_route, "query_searxng", search_stub)
+
+        events = _stream_json_lines([
+            {"choices": [{"delta": {"content": "Answer."}, "logprobs": None}]},
+            {"choices": [{"delta": {}, "finish_reason": "stop"}], "usage": {}},
+        ])
+
+        def stream_stub(self, method, url, json=None, timeout=None):
+            return _MockStreamResponse(events)
+
+        monkeypatch.setattr(httpx.AsyncClient, "stream", stream_stub)
+
+        resp = client.post(
+            "/api/search",
+            json={"query": "tell me something", "model": config.DEFAULT_MODEL},
+            headers=headers,
+        )
+        assert resp.status_code == 200
+        payloads = parse_sse_payloads(resp.text)
+
+        conv_id = None
+        for p in payloads:
+            if "conversation_id" in p:
+                conv_id = p["conversation_id"]
+                break
+        assert conv_id is not None
+
+        conv_resp = client.get(f"/api/conversations/{conv_id}", headers=_guest_headers(client))
+        assert conv_resp.status_code == 200
+        data = conv_resp.json()
+        assert len(data["messages"]) >= 2
+
+
+def test_explicit_search_stream_error(tmp_path: Path, monkeypatch):
+    with make_client(tmp_path) as client:
+        headers = _guest_headers(client)
+
+        async def search_stub(query: str, max_results: int = 5):
+            return [{"title": "T", "url": "https://ex.com", "content": "Content."}]
+
+        monkeypatch.setattr(routers.search_route, "query_searxng", search_stub)
+
+        def broken_stream(self, method, url, json=None, timeout=None):
+            class BrokenCtx:
+                async def __aenter__(self):
+                    raise RuntimeError("summarization failed")
+                async def __aexit__(self, exc_type, exc, tb):
+                    return False
+            return BrokenCtx()
+
+        monkeypatch.setattr(httpx.AsyncClient, "stream", broken_stream)
+
+        resp = client.post(
+            "/api/search",
+            json={"query": "breaking news", "model": config.DEFAULT_MODEL},
+            headers=headers,
+        )
+        assert resp.status_code == 200
+        assert "error_key" in resp.text
+        assert "INC-" in resp.text
@@ -1,17 +1,17 @@
-import app as app_module
+from search import sanitize_outbound_url
 
 
 def test_sanitize_outbound_url_allows_http_https():
-    assert app_module.sanitize_outbound_url("https://example.com/path") == "https://example.com/path"
-    assert app_module.sanitize_outbound_url("http://example.com") == "http://example.com"
+    assert sanitize_outbound_url("https://example.com/path") == "https://example.com/path"
+    assert sanitize_outbound_url("http://example.com") == "http://example.com"
 
 
 def test_sanitize_outbound_url_blocks_unsafe_schemes():
-    assert app_module.sanitize_outbound_url("javascript:alert(1)") == ""
-    assert app_module.sanitize_outbound_url("data:text/html,evil") == ""
-    assert app_module.sanitize_outbound_url("file:///etc/passwd") == ""
+    assert sanitize_outbound_url("javascript:alert(1)") == ""
+    assert sanitize_outbound_url("data:text/html,evil") == ""
+    assert sanitize_outbound_url("file:///etc/passwd") == ""
 
 
 def test_sanitize_outbound_url_blocks_relative_and_empty():
-    assert app_module.sanitize_outbound_url("/relative/path") == ""
-    assert app_module.sanitize_outbound_url("") == ""
+    assert sanitize_outbound_url("/relative/path") == ""
+    assert sanitize_outbound_url("") == ""
@@ -3,17 +3,19 @@ from pathlib import Path
 
 from fastapi.testclient import TestClient
 
-import app as app_module
+import app
+import db
+from security import SESSIONS, PIN_ATTEMPTS
 
 
 def make_admin_client(tmp_path: Path) -> tuple[TestClient, dict[str, str]]:
     os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
-    app_module.DB_PATH = tmp_path / "jarvischat-settings.db"
-    app_module.SESSIONS.clear()
-    app_module.PIN_ATTEMPTS.clear()
-    app_module.init_db()
+    db.DB_PATH = tmp_path / "jarvischat-settings.db"
+    SESSIONS.clear()
+    PIN_ATTEMPTS.clear()
+    db.init_db()
 
-    client = TestClient(app_module.app)
+    client = TestClient(app.app)
     login = client.post(
         "/api/auth/login",
         json={"pin": "1234"},
@@ -1,19 +1,23 @@
+import asyncio
 import os
 from pathlib import Path
 
 from fastapi.testclient import TestClient
 
-import app as app_module
+import app
+import db
+from rag import build_system_prompt
+from security import SESSIONS, PIN_ATTEMPTS, RATE_EVENTS
 
 
 def make_client(tmp_path: Path) -> TestClient:
     os.environ["JARVISCHAT_ADMIN_PIN"] = "1234"
-    app_module.DB_PATH = tmp_path / "jarvischat-skills.db"
-    app_module.SESSIONS.clear()
-    app_module.PIN_ATTEMPTS.clear()
-    app_module.RATE_EVENTS.clear()
-    app_module.init_db()
-    return TestClient(app_module.app, raise_server_exceptions=False)
+    db.DB_PATH = tmp_path / "jarvischat-skills.db"
+    SESSIONS.clear()
+    PIN_ATTEMPTS.clear()
+    RATE_EVENTS.clear()
+    db.init_db()
+    return TestClient(app.app, raise_server_exceptions=False)
 
 
 def test_guest_can_list_skills(tmp_path: Path):
@@ -21,7 +25,7 @@ def test_guest_can_list_skills(tmp_path: Path):
         sid = client.post("/api/auth/guest", headers={"Origin": "http://testserver"}).json()[
             "session_id"
         ]
-        resp = client.get("/api/skills", headers={"X-Session-ID": sid})
+        resp = client.get("/api/skills", headers={"X-Session-ID": sid, "Origin": "http://testserver"})
         assert resp.status_code == 200
         payload = resp.json()
         assert payload["count"] >= 1
@@ -46,7 +50,7 @@ def test_admin_can_toggle_skill_enabled_state(tmp_path: Path):
         assert disable.status_code == 200
         assert disable.json()["skill"]["enabled"] is False
 
-        active = client.get("/api/skills/active", headers={"X-Session-ID": sid})
+        active = client.get("/api/skills/active", headers={"X-Session-ID": sid, "Origin": "http://testserver"})
         assert active.status_code == 200
         assert all(skill["key"] != "search.web" for skill in active.json()["skills"])
 
@@ -71,23 +75,23 @@ def test_unknown_skill_update_is_rejected(tmp_path: Path):
 
 def test_prompt_injection_respects_skills_enabled_setting(tmp_path: Path):
     with make_client(tmp_path):
-        db = app_module.get_db()
+        conn = db.get_db()
         try:
-            db.execute(
+            conn.execute(
                 "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
                 ("skills_enabled", "false"),
             )
-            db.commit()
-            without_skills = app_module.build_system_prompt(db, "", "hello")
+            conn.commit()
+            without_skills = asyncio.run(build_system_prompt(conn, "", "hello"))
             assert "## Active Skills" not in without_skills
 
-            db.execute(
+            conn.execute(
                 "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
                 ("skills_enabled", "true"),
             )
-            db.commit()
-            with_skills = app_module.build_system_prompt(db, "", "hello")
+            conn.commit()
+            with_skills = asyncio.run(build_system_prompt(conn, "", "hello"))
             assert "## Active Skills" in with_skills
             assert "memory.search" in with_skills
         finally:
-            db.close()
+            conn.close()
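For readers following the new tests/test_search_route.py, the sketch below shows how its parse_sse_payloads helper turns a streamed response body into a list of dicts. The sample body is invented for illustration; only the field names (searching, token, done, searched) come from the assertions in the diff, and the real /api/search stream may carry additional fields.

```python
import json


def parse_sse_payloads(body: str) -> list[dict]:
    """Collect the JSON payload of every `data: ` event in an SSE body."""
    payloads: list[dict] = []
    for chunk in body.split("\n\n"):  # events are separated by a blank line
        chunk = chunk.strip()
        if not chunk.startswith("data: "):
            continue  # skip empty chunks and non-data lines such as comments
        payloads.append(json.loads(chunk[len("data: "):]))
    return payloads


# Hypothetical body, shaped like the events the tests assert on.
sample_body = (
    'data: {"searching": true}\n\n'
    'data: {"token": "Here\'s what I found"}\n\n'
    ': keep-alive comment\n\n'
    'data: {"done": true, "searched": true}\n\n'
)

payloads = parse_sse_payloads(sample_body)
# Concatenate streamed tokens the same way the tests build token_text.
token_text = "".join(p.get("token", "") for p in payloads)
```

The helper only handles single-line data events, which is all these tests need; an event that spreads its payload over several data lines would have to be joined before decoding.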