Files
WealthySmart/backend/app/agent/tools.py
Carlos Escalante ec716e698f
All checks were successful
Deploy to VPS / deploy (push) Successful in 13s
Exclude SALARY and DEPOSITO from agent recent-transactions tool
The 'last N transactions' answer was including salary deposits, which the
user reads as expense activity. Filter income types out at the query level.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-29 22:55:53 -06:00

468 lines
15 KiB
Python

"""
Read-only tools exposed to the MAF ChatAgent. Each tool is a thin wrapper
around existing SQLModel queries / service helpers — they do NOT duplicate
business logic. The active DB session is resolved via a ContextVar so tool
signatures stay clean for the LLM.
"""
from __future__ import annotations
import contextvars
from datetime import datetime
from typing import Annotated, Optional
from pydantic import Field
from sqlalchemy import case
from sqlmodel import Session, col, func, select
from app.models.models import (
Account,
BalanceOverride,
Category,
MunicipalReceipt,
PensionSnapshot,
RecurringItem,
Transaction,
TransactionSource,
TransactionType,
WaterMeterReading,
)
from app.services.budget_projection import (
compute_monthly_projection,
compute_yearly_projection_with_cumulative,
get_cycle_range,
)
from app.services.exchange_rate import (
get_converted_amount_expr,
get_current_rate,
)
# DB session for the current agent run; tools read it through _s(). Using a
# ContextVar keeps each execution context (thread / asyncio task) separate, so
# tool signatures never need a session parameter.
_session_ctx: contextvars.ContextVar[Session] = contextvars.ContextVar("agent_session")


def set_session(session: Session) -> contextvars.Token:
    """Bind *session* to the current context; keep the token for reset_session()."""
    token = _session_ctx.set(session)
    return token


def reset_session(token: contextvars.Token) -> None:
    """Restore whatever was bound before the matching set_session() call."""
    _session_ctx.reset(token)
def _s() -> Session:
    """Return the DB session bound to the current context by set_session().

    Raises:
        LookupError: when no session is bound (e.g. a tool was invoked outside
            a set_session()/reset_session() pair). The exception type matches
            what ContextVar.get() raises, so existing handlers keep working;
            only the message is made actionable.
    """
    try:
        return _session_ctx.get()
    except LookupError as exc:
        raise LookupError(
            "No agent DB session bound to this context; call set_session() "
            "before invoking agent tools"
        ) from exc
# ─── Tools ──────────────────────────────────────────────────────────────────
def get_accounts() -> list[dict]:
    """List every account with current balance, currency, bank and type
    (BANK, PENSION, CRYPTO, SAVINGS, LIABILITY). Use this for net-worth and
    balance questions."""
    # Ordered by type first so the agent sees accounts grouped by kind.
    stmt = select(Account).order_by(Account.account_type, Account.label)
    accounts = _s().exec(stmt).all()
    payload: list[dict] = []
    for account in accounts:
        payload.append(
            {
                "id": account.id,
                "bank": account.bank.value,
                "label": account.label,
                "currency": account.currency.value,
                "balance": account.balance,
                "account_type": account.account_type.value,
                "next_payment": account.next_payment,
            }
        )
    return payload
# USD->CRC sell rate used by get_net_worth only when no exchange rate is stored.
_FALLBACK_USD_CRC_SELL_RATE = 600.0
# Rough USD-per-EUR factor: EUR balances are converted EUR -> USD -> CRC.
# NOTE(review): approximate by design; the real EUR conversion is endpoint-side.
_APPROX_USD_PER_EUR = 1.08


def get_net_worth() -> dict:
    """Return total assets, liabilities and net worth in CRC (primary currency).
    USD balances are converted at the latest USD/CRC sell rate and EUR balances
    via an approximate EUR->USD factor. `usd_crc_sell_rate` reports the rate
    applied; `rate_is_fallback` is true when no exchange rate was stored and a
    fixed fallback rate was used instead."""
    session = _s()
    accounts = session.exec(select(Account)).all()
    rate = get_current_rate(session)
    rate_is_fallback = not rate
    sell = _FALLBACK_USD_CRC_SELL_RATE if rate_is_fallback else rate.sell_rate

    assets_crc = 0.0
    liabilities_crc = 0.0
    for a in accounts:
        currency = a.currency.value
        if currency == "USD":
            amt = a.balance * sell
        elif currency == "EUR":
            amt = a.balance * sell * _APPROX_USD_PER_EUR
        else:
            # Any other currency is taken as already being CRC.
            amt = a.balance
        if a.account_type.value == "LIABILITY":
            liabilities_crc += amt
        else:
            assets_crc += amt

    return {
        "assets_crc": round(assets_crc, 2),
        "liabilities_crc": round(liabilities_crc, 2),
        "net_crc": round(assets_crc - liabilities_crc, 2),
        # Surface the conversion basis so the agent can caveat its answer.
        "usd_crc_sell_rate": sell,
        "rate_is_fallback": rate_is_fallback,
    }
def get_recent_transactions(
    limit: Annotated[int, Field(ge=1, le=100, description="How many rows to return")] = 20,
    source: Annotated[
        Optional[str],
        Field(description="Filter by source: CREDIT_CARD, CASH, or TRANSFER"),
    ] = None,
    category_id: Annotated[Optional[int], Field(description="Filter by category id")] = None,
    search: Annotated[
        Optional[str], Field(description="Substring match against merchant name")
    ] = None,
    start_date: Annotated[
        Optional[str], Field(description="ISO date lower bound, inclusive")
    ] = None,
    end_date: Annotated[
        Optional[str], Field(description="ISO date upper bound, exclusive")
    ] = None,
) -> list[dict]:
    """Recent transactions, newest first. Income rows of type SALARY and
    DEPOSITO are excluded (use get_salary_summary for salary). Use filters to
    narrow down. For billing-cycle scoped totals prefer get_cycle_summary.

    Raises:
        ValueError: if `source` is not a known transaction source, or a date
            is not valid ISO format.
    """
    # Income is filtered at the query level: users read this list as
    # expense activity, so salary deposits would be misleading here.
    q = select(Transaction).where(
        Transaction.transaction_type.notin_(
            [TransactionType.SALARY, TransactionType.DEPOSITO]
        )
    )
    if source:
        try:
            source_enum = TransactionSource(source)
        except ValueError as exc:
            # Re-raise with the valid choices so the model can self-correct.
            valid = ", ".join(str(s.value) for s in TransactionSource)
            raise ValueError(
                f"Unknown source {source!r}; expected one of: {valid}"
            ) from exc
        q = q.where(Transaction.source == source_enum)
    if category_id is not None:
        q = q.where(Transaction.category_id == category_id)
    if search:
        # Escape LIKE metacharacters so `search` is a literal substring match;
        # otherwise "%" or "_" in the text would act as wildcards.
        literal = search.replace("/", "//").replace("%", "/%").replace("_", "/_")
        q = q.where(col(Transaction.merchant).ilike(f"%{literal}%", escape="/"))
    if start_date:
        q = q.where(Transaction.date >= datetime.fromisoformat(start_date))
    if end_date:
        q = q.where(Transaction.date < datetime.fromisoformat(end_date))
    q = q.order_by(col(Transaction.date).desc()).limit(limit)
    return [
        {
            "id": t.id,
            "date": t.date.isoformat(),
            "merchant": t.merchant,
            "amount": t.amount,
            "currency": t.currency.value,
            "source": t.source.value,
            "transaction_type": t.transaction_type.value,
            "bank": t.bank.value,
            "category_id": t.category_id,
        }
        for t in _s().exec(q).all()
    ]
def get_cycle_summary(
    cycle_year: Annotated[int, Field(description="Billing cycle year, e.g. 2026")],
    cycle_month: Annotated[
        int,
        Field(ge=1, le=12, description="Billing cycle month (cycle runs 18th→18th)"),
    ],
) -> dict:
    """Totals for a credit-card billing cycle (18th of month → 18th of next).
    Returns spend by source, count, and spend by category."""
    session = _s()
    amount_crc = get_converted_amount_expr(session)
    start, end = get_cycle_range(cycle_year, cycle_month)

    # Both aggregates share one predicate: purchases (COMPRA) in [start, end).
    purchase_in_cycle = (
        Transaction.transaction_type == TransactionType.COMPRA,
        Transaction.date >= start,
        Transaction.date < end,
    )

    source_stmt = (
        select(
            Transaction.source,
            func.count(),
            func.coalesce(func.sum(amount_crc), 0),
        )
        .where(*purchase_in_cycle)
        .group_by(Transaction.source)
    )
    source_rows = session.exec(source_stmt).all()

    # Outer join keeps uncategorized transactions (their name is NULL).
    category_stmt = (
        select(
            Category.name,
            func.coalesce(func.sum(amount_crc), 0),
            func.count(),
        )
        .join(Category, Category.id == Transaction.category_id, isouter=True)
        .where(*purchase_in_cycle)
        .group_by(Category.name)
        .order_by(func.sum(amount_crc).desc())
    )
    category_rows = session.exec(category_stmt).all()

    by_source = []
    for src, n_rows, spent in source_rows:
        by_source.append({"source": src.value, "count": n_rows, "total_crc": float(spent)})

    by_category = []
    for cat_name, spent, n_rows in category_rows:
        by_category.append(
            {
                "category": cat_name or "Uncategorized",
                "total_crc": float(spent),
                "count": n_rows,
            }
        )

    return {
        "cycle_year": cycle_year,
        "cycle_month": cycle_month,
        "range": [start.isoformat(), end.isoformat()],
        "by_source": by_source,
        "by_category": by_category,
    }
def get_budget_projection(
    year: Annotated[int, Field(description="Year to project")],
    month: Annotated[
        Optional[int],
        Field(ge=1, le=12, description="If given, return only that month's detail"),
    ] = None,
) -> dict:
    """Budget projection. If month is omitted, returns the yearly rollup; if
    given, returns the monthly detail with income items, expense items and
    actuals by source."""
    session = _s()
    # Monthly detail is returned straight from the projection service.
    if month is not None:
        return compute_monthly_projection(session, year, month)

    yearly = compute_yearly_projection_with_cumulative(session, year)

    def annual(field_name: str):
        # Sum one per-month figure across the whole year.
        return sum(entry[field_name] for entry in yearly)

    return {
        "year": year,
        "months": yearly,
        "annual_income": annual("projected_income"),
        "annual_expenses": annual("gran_total_egresos"),
        "annual_net": annual("net_balance"),
    }
def list_recurring_items() -> list[dict]:
    """Active recurring items (every item_type, ordered by type then name)
    used by the budget projection. Useful to explain what's driving a month's
    projection."""
    # NOTE(review): a previous docstring told the agent SAVINGS items were
    # excluded, but this query filters only on is_active. Add an item_type
    # filter here if SAVINGS items really should be hidden from the agent.
    rows = _s().exec(
        select(RecurringItem)
        .where(RecurringItem.is_active == True)  # noqa: E712
        .order_by(RecurringItem.item_type, RecurringItem.name)
    ).all()
    return [
        {
            "id": r.id,
            "name": r.name,
            "amount": r.amount,
            "currency": r.currency.value,
            "item_type": r.item_type.value,
            "frequency": r.frequency.value,
            "day_of_month": r.day_of_month,
            "category_id": r.category_id,
        }
        for r in rows
    ]
def get_pension_snapshots(
    fund: Annotated[
        Optional[str],
        Field(description="Filter by fund bank code (FCL, ROP, VOL, etc.)"),
    ] = None,
    latest_only: Annotated[
        bool,
        Field(description="If true, return only the latest snapshot per fund"),
    ] = True,
) -> list[dict]:
    """Pension fund snapshots. Each snapshot covers a period with balances,
    contributions, returns, fees and the ending balance (saldo_final)."""
    stmt = select(PensionSnapshot).order_by(col(PensionSnapshot.period_end).desc())
    if fund:
        stmt = stmt.where(PensionSnapshot.fund == fund)
    snapshots = _s().exec(stmt).all()

    if latest_only:
        # Rows arrive newest-first, so the first one kept per fund is its latest.
        newest_by_fund: dict[str, PensionSnapshot] = {}
        for snap in snapshots:
            newest_by_fund.setdefault(snap.fund.value, snap)
        snapshots = list(newest_by_fund.values())

    payload = []
    for snap in snapshots:
        payload.append(
            {
                "fund": snap.fund.value,
                "period_start": snap.period_start.isoformat(),
                "period_end": snap.period_end.isoformat(),
                "saldo_anterior": snap.saldo_anterior,
                "aportes": snap.aportes,
                "rendimientos": snap.rendimientos,
                "retiros": snap.retiros,
                "comision": snap.comision,
                "saldo_final": snap.saldo_final,
            }
        )
    return payload
def get_salary_summary() -> dict:
    """Summary of salary deposits (count, total in CRC, latest date)."""
    session = _s()
    amount_crc = get_converted_amount_expr(session)
    stmt = select(
        func.count(),
        func.coalesce(func.sum(amount_crc), 0),
        func.max(Transaction.date),
    ).where(Transaction.transaction_type == TransactionType.SALARY)
    row = session.exec(stmt).first()

    # Guard kept for parity, although an aggregate query yields one row.
    if not row:
        return {"count": 0, "total_crc": 0.0, "latest_date": None}

    n_deposits, total, latest = row
    return {
        "count": n_deposits,
        "total_crc": float(total),
        "latest_date": latest.isoformat() if latest else None,
    }
def get_municipal_receipts(
    limit: Annotated[int, Field(ge=1, le=50, description="How many receipts to return")] = 12,
    account: Annotated[
        Optional[str], Field(description="Municipal account/contract id")
    ] = None,
) -> list[dict]:
    """Recent municipal receipts (water + related services) with totals and
    water consumption in m³."""
    session = _s()
    q = select(MunicipalReceipt).order_by(col(MunicipalReceipt.receipt_date).desc())
    if account:
        q = q.where(MunicipalReceipt.account == account)
    receipts = session.exec(q.limit(limit)).all()

    # Load every meter reading for this page in ONE query (instead of one query
    # per receipt), then total consumption per receipt in Python.
    consumption: dict = {}
    receipt_ids = [r.id for r in receipts]
    if receipt_ids:
        readings = session.exec(
            select(WaterMeterReading).where(
                col(WaterMeterReading.receipt_id).in_(receipt_ids)
            )
        ).all()
        for w in readings:
            consumption[w.receipt_id] = consumption.get(w.receipt_id, 0) + w.consumption_m3

    return [
        {
            "id": r.id,
            "receipt_date": r.receipt_date.isoformat(),
            "period": r.period,
            "account": r.account,
            "finca": r.finca,
            "subtotal": r.subtotal,
            "interests": r.interests,
            "iva": r.iva,
            "total": r.total,
            # Receipts without readings report 0, as sum() over nothing did.
            "water_consumption_m3": consumption.get(r.id, 0),
        }
        for r in receipts
    ]
def get_analytics_by_category(
    cycle_year: Annotated[Optional[int], Field(description="Scope to a billing cycle")] = None,
    cycle_month: Annotated[
        Optional[int],
        Field(ge=1, le=12, description="Billing cycle month; requires cycle_year"),
    ] = None,
) -> list[dict]:
    """Spending breakdown by category in CRC. Pass both cycle_year and
    cycle_month to scope to a billing cycle, or neither for all-time totals.
    Percentages are rounded to one decimal, so they sum to roughly 100.

    Raises:
        ValueError: if only one of cycle_year / cycle_month is given.
    """
    # Previously a half-specified cycle silently fell back to all-time data.
    if (cycle_year is None) != (cycle_month is None):
        raise ValueError(
            "cycle_year and cycle_month must be provided together "
            "(or both omitted for all-time totals)"
        )

    session = _s()
    amount_crc = get_converted_amount_expr(session)
    q = (
        select(
            Transaction.category_id,
            # COALESCE (as in get_cycle_summary) so a NULL sum can't break float().
            func.coalesce(func.sum(amount_crc), 0).label("total"),
            func.count().label("count"),
        )
        .where(Transaction.transaction_type == TransactionType.COMPRA)
        .group_by(Transaction.category_id)
    )
    if cycle_year is not None and cycle_month is not None:
        start, end = get_cycle_range(cycle_year, cycle_month)
        q = q.where(Transaction.date >= start, Transaction.date < end)
    rows = session.exec(q).all()

    # Resolve all category names in one query instead of one lookup per row.
    cat_ids = [cat_id for cat_id, _, _ in rows if cat_id]
    names: dict = {}
    if cat_ids:
        names = {
            cid: cname
            for cid, cname in session.exec(
                select(Category.id, Category.name).where(col(Category.id).in_(cat_ids))
            ).all()
        }

    # `or 1.0` avoids division by zero when there is no spend at all.
    grand = sum(float(total) for _, total, _ in rows) or 1.0
    out = [
        {
            "category_id": cat_id,
            "category": names.get(cat_id, "Uncategorized"),
            "total_crc": float(total),
            "count": count,
            "percentage": round(float(total) / grand * 100, 1),
        }
        for cat_id, total, count in rows
    ]
    out.sort(key=lambda x: x["total_crc"], reverse=True)
    return out
def get_monthly_trend(
    months: Annotated[int, Field(ge=1, le=24, description="How many months back")] = 6,
) -> list[dict]:
    """Spending trend by billing cycle for the last N months."""
    session = _s()
    amount_crc = get_converted_amount_expr(session)
    today = datetime.now()
    # Flat month counter (year * 12 + zero-based month): stepping back across
    # a year boundary becomes plain subtraction.
    current_index = today.year * 12 + (today.month - 1)
    # Raw USD amounts, unconverted, summed alongside the CRC total.
    usd_raw_sum = func.coalesce(
        func.sum(case((Transaction.currency == "USD", Transaction.amount), else_=0)),
        0,
    )

    trend: list[dict] = []
    # Oldest cycle first, ending with the current month.
    for offset in range(months - 1, -1, -1):
        year, zero_based_month = divmod(current_index - offset, 12)
        month = zero_based_month + 1
        start, end = get_cycle_range(year, month)
        row = session.exec(
            select(
                func.count(),
                func.coalesce(func.sum(amount_crc), 0),
                usd_raw_sum,
            ).where(
                Transaction.transaction_type == TransactionType.COMPRA,
                Transaction.date >= start,
                Transaction.date < end,
            )
        ).first()
        if row:
            n_rows, total_crc, total_usd = row[0], float(row[1]), float(row[2])
        else:
            n_rows, total_crc, total_usd = 0, 0.0, 0.0
        trend.append(
            {
                "year": year,
                "month": month,
                "total_crc": total_crc,
                "total_usd_raw": total_usd,
                "count": n_rows,
            }
        )
    return trend
def get_exchange_rate() -> dict:
    """Latest USD/CRC exchange rate (buy and sell). All multi-currency data
    in the app is normalized to CRC using these rates."""
    current = get_current_rate(_s())
    if current:
        return {
            "buy_rate": current.buy_rate,
            "sell_rate": current.sell_rate,
            "date": current.date.isoformat(),
        }
    # No stored rate yet: same keys, all null.
    return {"buy_rate": None, "sell_rate": None, "date": None}
def list_categories() -> list[dict]:
    """All transaction categories (id, name, icon). Use when the user asks
    about a category and you need the id to filter by."""
    stmt = select(Category).order_by(Category.name)
    categories = _s().exec(stmt).all()
    return [dict(id=cat.id, name=cat.name, icon=cat.icon) for cat in categories]
# Registered with the agent in agent.py. Order is unchanged; the grouping
# comments below are for readers only.
TOOLS = [
    # Accounts & balances
    get_accounts,
    get_net_worth,
    # Transactions & billing cycles
    get_recent_transactions,
    get_cycle_summary,
    # Budget
    get_budget_projection,
    list_recurring_items,
    # Pensions & income
    get_pension_snapshots,
    get_salary_summary,
    # Utilities
    get_municipal_receipts,
    # Analytics
    get_analytics_by_category,
    get_monthly_trend,
    # Reference data
    get_exchange_rate,
    list_categories,
]