import logging
from datetime import datetime
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlmodel import Session, col, func, select

from app.auth import get_current_user
from app.db import get_session
from app.api.v1.endpoints.notifications import send_push_to_all
from app.models.models import (
    Category,
    Currency,
    Transaction,
    TransactionCreate,
    TransactionRead,
    TransactionSource,
    TransactionUpdate,
)

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/transactions", tags=["transactions"])

# Day of the month on which every billing cycle starts (and the previous one ends).
CYCLE_START_DAY = 18

# English month abbreviations indexed 1..12 (index 0 unused) for cycle labels.
# Kept explicit instead of calendar.month_abbr, which depends on the process locale.
_MONTH_ABBR = (
    "", "Jan", "Feb", "Mar", "Apr", "May", "Jun",
    "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
)


def _next_cycle(year: int, month: int) -> tuple[int, int]:
    """Return the (year, month) key of the billing cycle that follows (year, month)."""
    return (year + 1, 1) if month == 12 else (year, month + 1)


def _cycle_containing(moment: datetime) -> tuple[int, int]:
    """Return the (year, month) key of the billing cycle that contains *moment*.

    A cycle keyed (y, m) runs from day CYCLE_START_DAY of month m up to, but not
    including, the same day of the following month, so dates before the start
    day belong to the previous month's cycle.
    """
    if moment.day >= CYCLE_START_DAY:
        return moment.year, moment.month
    if moment.month == 1:
        return moment.year - 1, 12
    return moment.year, moment.month - 1


def _cycle_label(year: int, month: int) -> str:
    """Return a display label for a cycle, e.g. ``"Dec 18 - Jan 18, 2025"``."""
    end_year, end_month = _next_cycle(year, month)
    return (
        f"{_MONTH_ABBR[month]} {CYCLE_START_DAY} - "
        f"{_MONTH_ABBR[end_month]} {CYCLE_START_DAY}, {end_year}"
    )


def get_cycle_range(year: int, month: int) -> tuple[datetime, datetime]:
    """Return the half-open ``[start, end)`` range of the billing cycle keyed (year, month).

    The cycle starts on day CYCLE_START_DAY of *month* and ends on the same day of
    the next month (rolling over to January of the next year for December).

    Raises:
        ValueError: if *year*/*month* do not form a valid ``datetime``.
    """
    end_year, end_month = _next_cycle(year, month)
    start = datetime(year, month, CYCLE_START_DAY)
    end = datetime(end_year, end_month, CYCLE_START_DAY)
    return start, end


def _parse_iso_datetime(value: str, param_name: str) -> datetime:
    """Parse an ISO 8601 query parameter, turning malformed input into HTTP 422.

    Uses ``datetime.fromisoformat``; note that before Python 3.11 it does not
    accept a trailing ``Z``.
    """
    try:
        return datetime.fromisoformat(value)
    except ValueError as err:
        raise HTTPException(
            status_code=422,
            detail=f"Invalid ISO 8601 datetime for '{param_name}': {value!r}",
        ) from err


def _ensure_reference_unique(
    session: Session, reference: str, exclude_id: Optional[int] = None
) -> None:
    """Raise HTTP 409 if a transaction (other than *exclude_id*) already uses *reference*."""
    query = select(Transaction).where(Transaction.reference == reference)
    if exclude_id is not None:
        query = query.where(Transaction.id != exclude_id)
    existing = session.exec(query).first()
    if existing:
        raise HTTPException(
            status_code=409,
            detail=f"Duplicate transaction: reference '{reference}' already exists (id={existing.id})",
        )


def _format_amount(amount, currency: Currency) -> str:
    """Format *amount* for a notification: CRC as ``₡`` with no decimals, others as code + 2 decimals."""
    if currency == Currency.CRC:
        return f"₡{amount:,.0f}"
    return f"{currency.value}{amount:,.2f}"


class BillingCycle(BaseModel):
    """Summary of one billing cycle, as returned by ``GET /transactions/cycles``."""

    year: int  # year in which the cycle starts
    month: int  # month (1-12) in which the cycle starts
    label: str  # display label, e.g. "Dec 18 - Jan 18, 2025"
    count: int  # number of transactions dated inside the cycle
    total: float  # sum of Transaction.amount over the cycle


def auto_categorize(merchant: str, session: Session) -> Optional[int]:
    """Return the id of the first category whose auto-match pattern occurs in *merchant*.

    Each category's ``auto_match_patterns`` is a comma-separated list; patterns are
    trimmed and matched case-insensitively as substrings of the merchant name.
    Categories are checked in id order so the result is deterministic when several
    match. Returns ``None`` when nothing matches or *merchant* is empty/None.
    """
    if not merchant:
        return None
    merchant_lower = merchant.lower()
    categories = session.exec(select(Category).order_by(col(Category.id))).all()
    for cat in categories:
        if not cat.auto_match_patterns:
            continue
        patterns = (p.strip().lower() for p in cat.auto_match_patterns.split(","))
        # Empty entries (e.g. from "a,,b" or a trailing comma) would match everything.
        if any(p and p in merchant_lower for p in patterns):
            return cat.id
    return None


@router.get("/", response_model=list[TransactionRead])
def list_transactions(
    source: Optional[TransactionSource] = None,
    search: Optional[str] = None,
    category_id: Optional[int] = None,
    cycle_year: Optional[int] = Query(default=None, ge=1),
    cycle_month: Optional[int] = Query(default=None, ge=1, le=12),
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    # ge=0 matters: SQLite treats a negative LIMIT as "no limit", bypassing le=500.
    limit: int = Query(default=50, ge=0, le=500),
    offset: int = Query(default=0, ge=0),
    session: Session = Depends(get_session),
    _user: str = Depends(get_current_user),
):
    """List transactions newest first, with optional filters and pagination.

    Filters combine with AND. When both ``cycle_year`` and ``cycle_month`` are given
    the billing-cycle range is used and ``start_date``/``end_date`` are ignored;
    otherwise each of ``start_date`` (inclusive) and ``end_date`` (exclusive) is
    applied when present.

    Raises:
        HTTPException(422): if ``start_date``/``end_date`` is not ISO 8601.
    """
    query = select(Transaction)
    if source is not None:
        query = query.where(Transaction.source == source)
    if category_id is not None:
        query = query.where(Transaction.category_id == category_id)
    if search:
        # Bound as a parameter (no injection), but '%' / '_' typed by the user still
        # act as LIKE wildcards.
        query = query.where(col(Transaction.merchant).ilike(f"%{search}%"))
    if cycle_year is not None and cycle_month is not None:
        start, end = get_cycle_range(cycle_year, cycle_month)
        query = query.where(Transaction.date >= start, Transaction.date < end)
    else:
        if start_date:
            query = query.where(
                Transaction.date >= _parse_iso_datetime(start_date, "start_date")
            )
        if end_date:
            query = query.where(
                Transaction.date < _parse_iso_datetime(end_date, "end_date")
            )
    query = query.order_by(col(Transaction.date).desc()).offset(offset).limit(limit)
    return session.exec(query).all()


@router.get("/cycles", response_model=list[BillingCycle])
def list_billing_cycles(
    session: Session = Depends(get_session),
    _user: str = Depends(get_current_user),
):
    """Return the billing cycles that contain at least one transaction, newest first.

    Walks every cycle from the one holding the oldest transaction to the one holding
    the newest, running one count/sum query per cycle; empty cycles are omitted.
    """
    result = session.exec(
        select(func.min(Transaction.date), func.max(Transaction.date))
    ).first()
    if not result or not result[0]:
        return []
    min_date, max_date = result

    # NOTE(review): cycle bounds are naive datetimes; if Transaction.date is stored
    # timezone-aware, `start > max_date` raises TypeError — confirm.
    cycles: list[BillingCycle] = []
    year, month = _cycle_containing(min_date)
    while True:
        start, end = get_cycle_range(year, month)
        if start > max_date:
            break
        # NOTE(review): sums `amount` across all currencies, so a cycle with both CRC
        # and other-currency transactions mixes units in `total` — confirm intended.
        count_result = session.exec(
            select(func.count(), func.coalesce(func.sum(Transaction.amount), 0)).where(
                Transaction.date >= start, Transaction.date < end
            )
        ).first()
        count = count_result[0] if count_result else 0
        total = float(count_result[1]) if count_result else 0.0
        if count > 0:
            cycles.append(
                BillingCycle(
                    year=year,
                    month=month,
                    label=_cycle_label(year, month),
                    count=count,
                    total=total,
                )
            )
        year, month = _next_cycle(year, month)

    cycles.reverse()
    return cycles


@router.get("/recent", response_model=list[TransactionRead])
def recent_transactions(
    # ge=0 keeps a negative value from becoming an unbounded SQLite LIMIT.
    limit: int = Query(default=5, ge=0, le=20),
    session: Session = Depends(get_session),
    _user: str = Depends(get_current_user),
):
    """Return the most recent credit-card transactions, newest first."""
    query = (
        select(Transaction)
        .where(Transaction.source == TransactionSource.CREDIT_CARD)
        .order_by(col(Transaction.date).desc())
        .limit(limit)
    )
    return session.exec(query).all()


@router.post("/", response_model=TransactionRead, status_code=201)
def create_transaction(
    data: TransactionCreate,
    session: Session = Depends(get_session),
    _user: str = Depends(get_current_user),
):
    """Create a transaction, auto-categorize it if needed, and send a push notification.

    Raises:
        HTTPException(409): if another transaction already has the same ``reference``.
    """
    tx = Transaction.model_validate(data)

    if tx.reference:
        _ensure_reference_unique(session, tx.reference)

    if tx.category_id is None:
        tx.category_id = auto_categorize(tx.merchant, session)

    session.add(tx)
    session.commit()
    session.refresh(tx)

    # The transaction is already committed, so the notification is best-effort: a
    # failure here must not turn the request into a 500 (a client retry would then
    # hit the 409 above for a transaction that was actually saved).
    tx_id = tx.id
    try:
        send_push_to_all(
            session,
            title=f"💳 {tx.merchant}",
            body=(
                f"{_format_amount(tx.amount, tx.currency)} — "
                f"{tx.bank.value} {tx.transaction_type.value.lower()}"
            ),
            url="/budget",
        )
    except Exception:
        # Reset the session in case the failure left it mid-transaction, so the
        # response can still reload `tx`.
        session.rollback()
        logger.exception("Push notification failed for transaction id=%s", tx_id)

    return tx


@router.patch("/{transaction_id}", response_model=TransactionRead)
def update_transaction(
    transaction_id: int,
    data: TransactionUpdate,
    session: Session = Depends(get_session),
    _user: str = Depends(get_current_user),
):
    """Apply the fields explicitly set in *data* to an existing transaction.

    Raises:
        HTTPException(404): if the transaction does not exist.
        HTTPException(409): if the update would give it a ``reference`` already used
            by another transaction (same rule as creation).
    """
    tx = session.get(Transaction, transaction_id)
    if not tx:
        raise HTTPException(status_code=404, detail="Transaction not found")

    update_data = data.model_dump(exclude_unset=True)
    new_reference = update_data.get("reference")
    if new_reference and new_reference != tx.reference:
        _ensure_reference_unique(session, new_reference, exclude_id=tx.id)

    for key, value in update_data.items():
        setattr(tx, key, value)
    session.add(tx)
    session.commit()
    session.refresh(tx)
    return tx


@router.delete("/{transaction_id}", status_code=204)
def delete_transaction(
    transaction_id: int,
    session: Session = Depends(get_session),
    _user: str = Depends(get_current_user),
):
    """Delete a transaction by id.

    Raises:
        HTTPException(404): if the transaction does not exist.
    """
    tx = session.get(Transaction, transaction_id)
    if not tx:
        raise HTTPException(status_code=404, detail="Transaction not found")
    session.delete(tx)
    session.commit()
    return None