""" Read-only tools exposed to the MAF ChatAgent. Each tool is a thin wrapper around existing SQLModel queries / service helpers — they do NOT duplicate business logic. The active DB session is resolved via a ContextVar so tool signatures stay clean for the LLM. """ from __future__ import annotations import contextvars from datetime import datetime from typing import Annotated, Optional from pydantic import Field from sqlalchemy import case from sqlmodel import Session, col, func, select from app.models.models import ( Account, BalanceOverride, Category, MunicipalReceipt, PensionSnapshot, RecurringItem, Transaction, TransactionSource, TransactionType, WaterMeterReading, ) from app.services.budget_projection import ( compute_monthly_projection, compute_yearly_projection_with_cumulative, get_cycle_range, ) from app.services.exchange_rate import ( get_converted_amount_expr, get_current_rate, ) _session_ctx: contextvars.ContextVar[Session] = contextvars.ContextVar("agent_session") def set_session(session: Session) -> contextvars.Token: return _session_ctx.set(session) def reset_session(token: contextvars.Token) -> None: _session_ctx.reset(token) def _s() -> Session: return _session_ctx.get() # ─── Tools ────────────────────────────────────────────────────────────────── def get_accounts() -> list[dict]: """List every account with current balance, currency, bank and type (BANK, PENSION, CRYPTO, SAVINGS, LIABILITY). Use this for net-worth and balance questions.""" rows = _s().exec(select(Account).order_by(Account.account_type, Account.label)).all() return [ { "id": a.id, "bank": a.bank.value, "label": a.label, "currency": a.currency.value, "balance": a.balance, "account_type": a.account_type.value, "next_payment": a.next_payment, } for a in rows ] def get_net_worth() -> dict: """Return total assets, liabilities and net worth in CRC (primary currency). USD/EUR balances are converted at the latest exchange rate.""" accounts = _s().exec(select(Account)).all() rate = get_current_rate(_s()) sell = rate.sell_rate if rate else 600.0 assets_crc = 0.0 liabilities_crc = 0.0 for a in accounts: amt = a.balance if a.currency.value == "USD": amt = a.balance * sell elif a.currency.value == "EUR": amt = a.balance * sell * 1.08 # rough; real conversion is endpoint-side if a.account_type.value == "LIABILITY": liabilities_crc += amt else: assets_crc += amt return { "assets_crc": round(assets_crc, 2), "liabilities_crc": round(liabilities_crc, 2), "net_crc": round(assets_crc - liabilities_crc, 2), } def get_recent_transactions( limit: Annotated[int, Field(ge=1, le=100, description="How many rows to return")] = 20, source: Annotated[ Optional[str], Field(description="Filter by source: CREDIT_CARD, CASH, or TRANSFER"), ] = None, category_id: Annotated[Optional[int], Field(description="Filter by category id")] = None, search: Annotated[ Optional[str], Field(description="Substring match against merchant name") ] = None, start_date: Annotated[ Optional[str], Field(description="ISO date lower bound, inclusive") ] = None, end_date: Annotated[ Optional[str], Field(description="ISO date upper bound, exclusive") ] = None, ) -> list[dict]: """Recent transactions, newest first. Use filters to narrow down. For billing-cycle scoped totals prefer get_cycle_summary.""" q = select(Transaction) if source: q = q.where(Transaction.source == TransactionSource(source)) if category_id is not None: q = q.where(Transaction.category_id == category_id) if search: q = q.where(col(Transaction.merchant).ilike(f"%{search}%")) if start_date: q = q.where(Transaction.date >= datetime.fromisoformat(start_date)) if end_date: q = q.where(Transaction.date < datetime.fromisoformat(end_date)) q = q.order_by(col(Transaction.date).desc()).limit(limit) return [ { "id": t.id, "date": t.date.isoformat(), "merchant": t.merchant, "amount": t.amount, "currency": t.currency.value, "source": t.source.value, "transaction_type": t.transaction_type.value, "bank": t.bank.value, "category_id": t.category_id, } for t in _s().exec(q).all() ] def get_cycle_summary( cycle_year: Annotated[int, Field(description="Billing cycle year, e.g. 2026")], cycle_month: Annotated[ int, Field(ge=1, le=12, description="Billing cycle month (cycle runs 18th→18th)"), ], ) -> dict: """Totals for a credit-card billing cycle (18th of month → 18th of next). Returns spend by source, count, and spend by category.""" session = _s() amount_crc = get_converted_amount_expr(session) start, end = get_cycle_range(cycle_year, cycle_month) totals = session.exec( select( Transaction.source, func.count(), func.coalesce(func.sum(amount_crc), 0), ) .where( Transaction.transaction_type == TransactionType.COMPRA, Transaction.date >= start, Transaction.date < end, ) .group_by(Transaction.source) ).all() by_category = session.exec( select( Category.name, func.coalesce(func.sum(amount_crc), 0), func.count(), ) .join(Category, Category.id == Transaction.category_id, isouter=True) .where( Transaction.transaction_type == TransactionType.COMPRA, Transaction.date >= start, Transaction.date < end, ) .group_by(Category.name) .order_by(func.sum(amount_crc).desc()) ).all() return { "cycle_year": cycle_year, "cycle_month": cycle_month, "range": [start.isoformat(), end.isoformat()], "by_source": [ {"source": s.value, "count": c, "total_crc": float(t)} for s, c, t in totals ], "by_category": [ {"category": n or "Uncategorized", "total_crc": float(t), "count": c} for n, t, c in by_category ], } def get_budget_projection( year: Annotated[int, Field(description="Year to project")], month: Annotated[ Optional[int], Field(ge=1, le=12, description="If given, return only that month's detail"), ] = None, ) -> dict: """Budget projection. If month is omitted, returns the yearly rollup; if given, returns the monthly detail with income items, expense items and actuals by source.""" session = _s() if month is None: months_data = compute_yearly_projection_with_cumulative(session, year) return { "year": year, "months": months_data, "annual_income": sum(m["projected_income"] for m in months_data), "annual_expenses": sum(m["gran_total_egresos"] for m in months_data), "annual_net": sum(m["net_balance"] for m in months_data), } return compute_monthly_projection(session, year, month) def list_recurring_items() -> list[dict]: """All recurring items (income and expense, SAVINGS excluded) used by the budget projection. Useful to explain what's driving a month's projection.""" rows = _s().exec( select(RecurringItem) .where(RecurringItem.is_active == True) # noqa: E712 .order_by(RecurringItem.item_type, RecurringItem.name) ).all() return [ { "id": r.id, "name": r.name, "amount": r.amount, "currency": r.currency.value, "item_type": r.item_type.value, "frequency": r.frequency.value, "day_of_month": r.day_of_month, "category_id": r.category_id, } for r in rows ] def get_pension_snapshots( fund: Annotated[ Optional[str], Field(description="Filter by fund bank code (FCL, ROP, VOL, etc.)"), ] = None, latest_only: Annotated[ bool, Field(description="If true, return only the latest snapshot per fund"), ] = True, ) -> list[dict]: """Pension fund snapshots. Each snapshot covers a period with balances, contributions, returns, fees and the ending balance (saldo_final).""" q = select(PensionSnapshot).order_by(col(PensionSnapshot.period_end).desc()) if fund: q = q.where(PensionSnapshot.fund == fund) rows = _s().exec(q).all() if latest_only: seen: dict[str, PensionSnapshot] = {} for r in rows: if r.fund.value not in seen: seen[r.fund.value] = r rows = list(seen.values()) return [ { "fund": r.fund.value, "period_start": r.period_start.isoformat(), "period_end": r.period_end.isoformat(), "saldo_anterior": r.saldo_anterior, "aportes": r.aportes, "rendimientos": r.rendimientos, "retiros": r.retiros, "comision": r.comision, "saldo_final": r.saldo_final, } for r in rows ] def get_salary_summary() -> dict: """Summary of salary deposits (count, total in CRC, latest date).""" session = _s() amount_crc = get_converted_amount_expr(session) row = session.exec( select( func.count(), func.coalesce(func.sum(amount_crc), 0), func.max(Transaction.date), ).where(Transaction.transaction_type == TransactionType.SALARY) ).first() count = row[0] if row else 0 total = float(row[1]) if row else 0.0 latest = row[2].isoformat() if row and row[2] else None return {"count": count, "total_crc": total, "latest_date": latest} def get_municipal_receipts( limit: Annotated[int, Field(ge=1, le=50)] = 12, account: Annotated[ Optional[str], Field(description="Municipal account/contract id") ] = None, ) -> list[dict]: """Recent municipal receipts (water + related services) with totals and water consumption in m³.""" q = select(MunicipalReceipt).order_by(col(MunicipalReceipt.receipt_date).desc()) if account: q = q.where(MunicipalReceipt.account == account) q = q.limit(limit) rows = _s().exec(q).all() out: list[dict] = [] for r in rows: readings = _s().exec( select(WaterMeterReading).where(WaterMeterReading.receipt_id == r.id) ).all() out.append( { "id": r.id, "receipt_date": r.receipt_date.isoformat(), "period": r.period, "account": r.account, "finca": r.finca, "subtotal": r.subtotal, "interests": r.interests, "iva": r.iva, "total": r.total, "water_consumption_m3": sum(w.consumption_m3 for w in readings), } ) return out def get_analytics_by_category( cycle_year: Annotated[Optional[int], Field(description="Scope to a billing cycle")] = None, cycle_month: Annotated[Optional[int], Field(ge=1, le=12)] = None, ) -> list[dict]: """Spending breakdown by category in CRC (optionally scoped to a billing cycle). Percentages sum to 100.""" session = _s() amount_crc = get_converted_amount_expr(session) q = ( select( Transaction.category_id, func.sum(amount_crc).label("total"), func.count().label("count"), ) .where(Transaction.transaction_type == TransactionType.COMPRA) .group_by(Transaction.category_id) ) if cycle_year and cycle_month: start, end = get_cycle_range(cycle_year, cycle_month) q = q.where(Transaction.date >= start, Transaction.date < end) rows = session.exec(q).all() grand = sum(float(r[1]) for r in rows) or 1.0 out = [] for cat_id, total, count in rows: name = "Uncategorized" if cat_id: cat = session.get(Category, cat_id) if cat: name = cat.name out.append( { "category_id": cat_id, "category": name, "total_crc": float(total), "count": count, "percentage": round(float(total) / grand * 100, 1), } ) out.sort(key=lambda x: x["total_crc"], reverse=True) return out def get_monthly_trend( months: Annotated[int, Field(ge=1, le=24, description="How many months back")] = 6, ) -> list[dict]: """Spending trend by billing cycle for the last N months.""" session = _s() amount_crc = get_converted_amount_expr(session) now = datetime.now() results: list[dict] = [] y, m = now.year, now.month for _ in range(months): start, end = get_cycle_range(y, m) row = session.exec( select( func.count(), func.coalesce(func.sum(amount_crc), 0), func.coalesce( func.sum( case((Transaction.currency == "USD", Transaction.amount), else_=0) ), 0, ), ).where( Transaction.transaction_type == TransactionType.COMPRA, Transaction.date >= start, Transaction.date < end, ) ).first() results.append( { "year": y, "month": m, "total_crc": float(row[1]) if row else 0.0, "total_usd_raw": float(row[2]) if row else 0.0, "count": row[0] if row else 0, } ) if m == 1: y, m = y - 1, 12 else: m -= 1 return list(reversed(results)) def get_exchange_rate() -> dict: """Latest USD/CRC exchange rate (buy and sell). All multi-currency data in the app is normalized to CRC using these rates.""" rate = get_current_rate(_s()) if not rate: return {"buy_rate": None, "sell_rate": None, "date": None} return { "buy_rate": rate.buy_rate, "sell_rate": rate.sell_rate, "date": rate.date.isoformat(), } def list_categories() -> list[dict]: """All transaction categories (id, name, icon). Use when the user asks about a category and you need the id to filter by.""" rows = _s().exec(select(Category).order_by(Category.name)).all() return [{"id": c.id, "name": c.name, "icon": c.icon} for c in rows] # Registered with the agent in agent.py TOOLS = [ get_accounts, get_net_worth, get_recent_transactions, get_cycle_summary, get_budget_projection, list_recurring_items, get_pension_snapshots, get_salary_summary, get_municipal_receipts, get_analytics_by_category, get_monthly_trend, get_exchange_rate, list_categories, ]