"""WealthySmart API application module.

Builds the FastAPI app together with its lifespan hook, the AG-UI agent
middleware and the cookie-based auth endpoints.
"""

import asyncio
import json
import re
import uuid
from contextlib import asynccontextmanager

from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint
from fastapi import FastAPI, HTTPException, Request, Response, status
from fastapi.middleware.cors import CORSMiddleware
from jose import JWTError, jwt
from pydantic import BaseModel

from app.agent.agent import build_agent
from app.agent.tools import reset_session, set_session
from app.api.v1.router import api_router
from app.auth import ALGORITHM, create_access_token
from app.config import settings
from app.db import get_session, init_db, run_migrations
from app.seed import seed_db
from app.services.exchange_rate import refresh_rates_periodically

# URL prefix under which the AG-UI agent endpoint is mounted.
AGENT_PATH = "/api/v1/agent/agui"


def _pair_orphan_tool_calls(messages: list) -> list:
    """Return *messages* with a synthetic reply for every unanswered tool call.

    OpenAI rejects histories where a ``tool_calls`` entry is not immediately
    followed by the corresponding tool response, so an empty
    ``{"role": "tool"}`` message is injected for each assistant tool call
    that has no matching tool message.

    Both snake_case (``tool_calls`` / ``tool_call_id``) and camelCase
    (``toolCalls`` / ``toolCallId``) keys are recognised. Entries that are
    not dicts are passed through unchanged instead of raising, so one
    malformed message does not prevent the rest of the history from being
    repaired.

    Args:
        messages: Chat messages in request order.

    Returns:
        A new list holding the original message objects plus any synthetic
        tool responses; the input list is not modified.
    """
    out: list = []
    pending: list[str] = []  # tool-call ids still waiting for a response

    def flush() -> None:
        # Answer every still-open call with an empty tool message.
        for call_id in pending:
            out.append({"role": "tool", "tool_call_id": call_id, "content": ""})
        pending.clear()

    for msg in messages:
        # A non-dict entry has no role; handle it like any other non-tool message.
        role = msg.get("role", "") if isinstance(msg, dict) else None

        if role == "tool":
            call_id = msg.get("tool_call_id") or msg.get("toolCallId")
            if call_id and call_id in pending:
                pending.remove(call_id)
            out.append(msg)
            continue

        # Any non-tool message ends the window in which earlier calls could
        # still be answered, so close them before emitting it.
        flush()
        out.append(msg)

        if role == "assistant":
            for tc in msg.get("tool_calls") or msg.get("toolCalls") or []:
                tc_id = tc.get("id") if isinstance(tc, dict) else None
                if tc_id:
                    pending.append(tc_id)

    # Calls from the final assistant message may still be open.
    flush()
    return out


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown work for the application.

    On startup calls ``init_db``, ``run_migrations`` and ``seed_db`` in that
    order, then starts ``refresh_rates_periodically`` as a background task.
    On shutdown the task is cancelled and awaited so it can finish unwinding.
    """
    init_db()
    run_migrations()
    seed_db()
    rate_refresh_task = asyncio.create_task(refresh_rates_periodically())
    try:
        yield
    finally:
        rate_refresh_task.cancel()
        try:
            await rate_refresh_task
        except asyncio.CancelledError:
            # Expected result of the cancel() above.
            pass


app = FastAPI(title="WealthySmart API", version="0.1.0", lifespan=lifespan)
# NOTE(review): a wildcard origin together with allow_credentials=True is
# discouraged by the FastAPI CORS documentation, which asks for explicit
# origins when credentials are enabled — confirm this is intended outside
# local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Matches the ws_token cookie inside a raw Cookie header.
_WS_TOKEN_COOKIE_RE = re.compile(r"(?:^|;\s*)ws_token=([^;]+)")


def _extract_token(request: Request) -> str | None:
    """Return the JWT from the Authorization header or the ws_token cookie.

    A ``Bearer`` Authorization header takes precedence; the cookie is only
    consulted when no such header is present. Returns ``None`` (or an empty
    string for an empty bearer value) when no token was supplied.
    """
    auth_header = request.headers.get("authorization", "")
    if auth_header.lower().startswith("bearer "):
        return auth_header.split(" ", 1)[1].strip()

    cookie_header = request.headers.get("cookie", "")
    match = _WS_TOKEN_COOKIE_RE.search(cookie_header)
    return match.group(1) if match else None


def _repair_message_history(raw: bytes) -> bytes:
    """Return *raw* with orphan tool calls in its ``messages`` list paired up.

    This is best-effort: if the payload is not a JSON object with a
    ``messages`` list, or anything goes wrong while repairing it, the
    original bytes are returned so the endpoint can handle (or reject) the
    request itself. The original bytes are also returned when the history
    needs no repair, so the body is only re-serialised when it changed.
    """
    import logging  # local import: logging is not imported at module level

    try:
        body = json.loads(raw)
        if not isinstance(body, dict):
            return raw
        messages = body.get("messages")
        if not isinstance(messages, list):
            return raw
        repaired = _pair_orphan_tool_calls(messages)
        if repaired == messages:
            return raw
        body["messages"] = repaired
        return json.dumps(body).encode()
    except Exception:
        # Deliberately broad: a failed repair must never fail the request,
        # but it should be visible in the logs rather than silently dropped.
        logging.getLogger(__name__).warning(
            "Could not repair agent message history; forwarding body unchanged",
            exc_info=True,
        )
        return raw


@app.middleware("http")
async def agent_auth_and_session(request: Request, call_next):
    """For the AG-UI route, validate the JWT, repair message history, and
    bind a DB session to a ContextVar so agent tools can query without going
    through Depends.

    Requests for other paths, and OPTIONS requests, are passed straight
    through. A missing token yields ``401 Missing auth``; a token that fails
    to decode or has no ``sub`` claim yields ``401 Invalid token``.
    """
    if not request.url.path.startswith(AGENT_PATH):
        return await call_next(request)
    if request.method == "OPTIONS":
        # Let CORS preflight requests through without credentials.
        return await call_next(request)

    token = _extract_token(request)
    if not token:
        return Response(status_code=401, content="Missing auth")

    try:
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        return Response(status_code=401, content="Invalid token")
    if not payload.get("sub"):
        return Response(status_code=401, content="Invalid token")

    # Repair orphan tool_calls before the MAF agent sees the message history.
    if request.method == "POST" and "application/json" in request.headers.get(
        "content-type", ""
    ):
        raw = await request.body()
        # Starlette caches the body; replace it so call_next sees the fixed bytes.
        request._body = _repair_message_history(raw)  # type: ignore[attr-defined]

    session_gen = get_session()
    session = next(session_gen)
    token_var = set_session(session)
    try:
        # NOTE(review): if the agent endpoint streams its response, confirm
        # the cleanup below does not run before streaming has finished.
        return await call_next(request)
    finally:
        reset_session(token_var)
        # Resume the generator so it can run its own cleanup; the default
        # argument absorbs the StopIteration raised when it finishes.
        next(session_gen, None)


# Register app routes
app.include_router(api_router)

# Mount the AG-UI agent endpoint.
add_agent_framework_fastapi_endpoint(app, build_agent(), AGENT_PATH)


@app.get("/")
def root():
    """Return the application name and version."""
    return {"app": "WealthySmart", "version": "0.1.0"}


# ── Cookie-based auth endpoints (used by the Vite SPA) ──────────────────────


class LoginRequest(BaseModel):
    """Request body for the cookie login endpoint."""

    username: str
    password: str


def _constant_time_equals(supplied: str, expected: object) -> bool:
    """Compare a client-supplied string with a configured value in constant time.

    A configured value that is not a string never matches, which is what a
    plain ``!=`` comparison against the supplied string would conclude too.
    """
    import secrets  # local import: secrets is not imported at module level

    if not isinstance(expected, str):
        return False
    # compare_digest only accepts ASCII-only str, so compare the encoded
    # bytes; surrogatepass keeps odd input from raising during encoding.
    return secrets.compare_digest(
        supplied.encode("utf-8", "surrogatepass"),
        expected.encode("utf-8", "surrogatepass"),
    )


@app.post("/api/auth/login")
def cookie_login(body: LoginRequest, response: Response):
    """Check the admin credentials and set the ``ws_token`` auth cookie.

    Raises:
        HTTPException: 401 with detail ``Invalid credentials`` when the
            username or password does not match the configured values.
    """
    # Evaluate both comparisons unconditionally, in constant time, so the
    # response timing does not reveal which field was wrong or how much of
    # it matched.
    username_ok = _constant_time_equals(body.username, settings.ADMIN_USERNAME)
    password_ok = _constant_time_equals(body.password, settings.ADMIN_PASSWORD)
    if not (username_ok and password_ok):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
        )

    token = create_access_token(body.username)
    response.set_cookie(
        key="ws_token",
        value=token,
        httponly=True,
        samesite="lax",
        max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        secure=False,  # set True behind TLS in production via nginx
    )
    return {"ok": True}


@app.post("/api/auth/logout", status_code=204)
def cookie_logout(response: Response):
    """Clear the ``ws_token`` auth cookie."""
    response.delete_cookie("ws_token")


@app.get("/api/health")
def health():
    """Return a static payload for health checks."""
    return {"ok": True}