"""API routes for managing a user's supplements and their intake logs."""

from datetime import date, datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Set, Tuple

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlmodel import Session, select

from app.api import deps
from app.models.supplement import Supplement, SupplementLog

router = APIRouter()

# Longest streak reported by GET /today. The previous implementation walked
# back one day at a time and stopped once the streak exceeded 365, so it could
# report at most 366 consecutive days; keep that cap.
_STREAK_MAX_DAYS = 366


# Request body for POST / (creating a supplement).
class SupplementCreate(BaseModel):
    name: str
    dosage: float
    unit: str
    frequency: str = "daily"
    scheduled_times: List[str] = []
    notes: Optional[str] = None


# Request body for PUT /{supplement_id}; only the fields the client actually
# sends are applied (see update_supplement).
class SupplementUpdate(BaseModel):
    name: Optional[str] = None
    dosage: Optional[float] = None
    unit: Optional[str] = None
    frequency: Optional[str] = None
    scheduled_times: Optional[List[str]] = None
    notes: Optional[str] = None
    is_active: Optional[bool] = None


# Request body for POST /{supplement_id}/log; taken_at defaults to the current
# UTC time when omitted.
class SupplementLogCreate(BaseModel):
    dose_taken: Optional[float] = None
    notes: Optional[str] = None
    taken_at: Optional[datetime] = None


# Response item for GET /today: a supplement plus its intake status for today.
class SupplementWithStatus(BaseModel):
    id: int
    name: str
    dosage: float
    unit: str
    frequency: str
    scheduled_times: List[str]
    notes: Optional[str]
    is_active: bool
    created_at: datetime
    taken_today: bool
    streak: int


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime."""
    # Same value as the deprecated datetime.utcnow(), so comparisons against
    # timestamps stored by earlier versions of this code keep working.
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _day_bounds(day: date) -> Tuple[datetime, datetime]:
    """Return the half-open range [start of *day*, start of the next day)."""
    start = datetime(day.year, day.month, day.day)
    return start, start + timedelta(days=1)


def _parse_date_param(value: str, param_name: str) -> date:
    """Parse a YYYY-MM-DD query parameter.

    Raises:
        HTTPException: 422 when *value* is not a valid YYYY-MM-DD date.
    """
    try:
        return datetime.strptime(value, "%Y-%m-%d").date()
    except ValueError as exc:
        # Previously a malformed date escaped as an unhandled 500 error.
        raise HTTPException(
            status_code=422,
            detail=f"{param_name} must be a date in YYYY-MM-DD format",
        ) from exc


def _get_owned_supplement(
    session: Session, supplement_id: int, owner_id: Any
) -> Supplement:
    """Load a supplement by primary key and check that *owner_id* owns it.

    Raises:
        HTTPException: 404 if no such supplement exists, 403 if it belongs
            to a different user.
    """
    supplement = session.get(Supplement, supplement_id)
    if not supplement:
        raise HTTPException(status_code=404, detail="Supplement not found")
    if supplement.user_id != owner_id:
        raise HTTPException(status_code=403, detail="Not authorized")
    return supplement


def _active_supplements_query(owner_id: Any):
    """Build a query for the owner's active (not soft-deleted) supplements, ordered by name."""
    return (
        select(Supplement)
        .where(Supplement.user_id == owner_id)
        .where(Supplement.is_active)
        .order_by(Supplement.name)
    )


def _compute_streak(
    days_taken: Set[date], today: date, max_days: int = _STREAK_MAX_DAYS
) -> int:
    """Count consecutive days, ending at *today*, that appear in *days_taken*.

    Returns 0 when *today* itself is absent and never more than *max_days*.
    """
    streak = 0
    day = today
    while streak < max_days and day in days_taken:
        streak += 1
        day -= timedelta(days=1)
    return streak


@router.get("/", response_model=List[Supplement])
def list_supplements(
    current_user: deps.CurrentUser,
    session: Session = Depends(deps.get_session),
) -> Any:
    """List the current user's active supplements, ordered by name."""
    return session.exec(_active_supplements_query(current_user.id)).all()


@router.post("/", response_model=Supplement)
def create_supplement(
    *,
    current_user: deps.CurrentUser,
    session: Session = Depends(deps.get_session),
    data: SupplementCreate,
) -> Any:
    """Create a supplement owned by the current user."""
    supplement = Supplement(
        user_id=current_user.id,
        name=data.name,
        dosage=data.dosage,
        unit=data.unit,
        frequency=data.frequency,
        scheduled_times=data.scheduled_times,
        notes=data.notes,
    )
    session.add(supplement)
    session.commit()
    session.refresh(supplement)
    return supplement


@router.put("/{supplement_id}", response_model=Supplement)
def update_supplement(
    supplement_id: int,
    data: SupplementUpdate,
    current_user: deps.CurrentUser,
    session: Session = Depends(deps.get_session),
) -> Any:
    """Partially update one of the current user's supplements.

    Only fields present in the request body are written.
    """
    supplement = _get_owned_supplement(session, supplement_id, current_user.id)
    # exclude_unset: omitted fields are left untouched, while fields explicitly
    # sent (even as null) are applied.
    for key, value in data.model_dump(exclude_unset=True).items():
        setattr(supplement, key, value)
    session.add(supplement)
    session.commit()
    session.refresh(supplement)
    return supplement


@router.delete("/{supplement_id}", status_code=204)
def delete_supplement(
    supplement_id: int,
    current_user: deps.CurrentUser,
    session: Session = Depends(deps.get_session),
) -> None:
    """Soft-delete one of the current user's supplements."""
    supplement = _get_owned_supplement(session, supplement_id, current_user.id)
    # Soft delete: the row and its logs are kept; the supplement just stops
    # appearing in the active listings.
    supplement.is_active = False
    session.add(supplement)
    session.commit()


@router.post("/{supplement_id}/log", response_model=SupplementLog)
def log_supplement(
    supplement_id: int,
    data: SupplementLogCreate,
    current_user: deps.CurrentUser,
    session: Session = Depends(deps.get_session),
) -> Any:
    """Record an intake of one of the current user's supplements."""
    _get_owned_supplement(session, supplement_id, current_user.id)
    log = SupplementLog(
        user_id=current_user.id,
        supplement_id=supplement_id,
        # Default timestamps are naive UTC; GET /today buckets by UTC day.
        taken_at=data.taken_at or _utcnow(),
        dose_taken=data.dose_taken,
        notes=data.notes,
    )
    session.add(log)
    session.commit()
    session.refresh(log)
    return log


@router.get("/logs", response_model=List[SupplementLog])
def get_supplement_logs(
    current_user: deps.CurrentUser,
    session: Session = Depends(deps.get_session),
    supplement_id: Optional[int] = Query(default=None),
    start_date: Optional[str] = Query(default=None, description="YYYY-MM-DD"),
    end_date: Optional[str] = Query(default=None, description="YYYY-MM-DD"),
) -> Any:
    """List the current user's intake logs, newest first.

    Optionally filter by supplement and by an inclusive calendar-date range.
    A malformed date yields a 422 response.
    """
    statement = select(SupplementLog).where(SupplementLog.user_id == current_user.id)
    if supplement_id is not None:
        statement = statement.where(SupplementLog.supplement_id == supplement_id)
    if start_date:
        range_start, _ = _day_bounds(_parse_date_param(start_date, "start_date"))
        statement = statement.where(SupplementLog.taken_at >= range_start)
    if end_date:
        # Exclusive bound at the following midnight so logs in the last
        # second of end_date (e.g. 23:59:59.5) are still included.
        _, range_end = _day_bounds(_parse_date_param(end_date, "end_date"))
        statement = statement.where(SupplementLog.taken_at < range_end)
    statement = statement.order_by(SupplementLog.taken_at.desc())
    return session.exec(statement).all()


@router.get("/today", response_model=List[SupplementWithStatus])
def get_today_supplements(
    current_user: deps.CurrentUser,
    session: Session = Depends(deps.get_session),
) -> Any:
    """List active supplements with whether each was taken today and its streak.

    The streak counts consecutive days, ending today, with at least one log.
    It is 0 until the supplement has been logged today.
    """
    # Default log timestamps are naive UTC (see log_supplement), so "today" is
    # the UTC calendar day; date.today() would use the server's local zone.
    today = _utcnow().date()

    supplements = session.exec(_active_supplements_query(current_user.id)).all()
    if not supplements:
        return []

    # A single query over the whole streak window replaces one query per
    # supplement per day.
    window_start, _ = _day_bounds(today - timedelta(days=_STREAK_MAX_DAYS - 1))
    _, window_end = _day_bounds(today)
    rows = session.exec(
        select(SupplementLog.supplement_id, SupplementLog.taken_at)
        .where(SupplementLog.user_id == current_user.id)
        .where(SupplementLog.taken_at >= window_start)
        .where(SupplementLog.taken_at < window_end)
    ).all()

    days_by_supplement: Dict[int, Set[date]] = {}
    for log_supplement_id, taken_at in rows:
        days_by_supplement.setdefault(log_supplement_id, set()).add(taken_at.date())

    result: List[SupplementWithStatus] = []
    for s in supplements:
        days_taken = days_by_supplement.get(s.id, set())
        result.append(
            SupplementWithStatus(
                id=s.id,
                name=s.name,
                dosage=s.dosage,
                unit=s.unit,
                frequency=s.frequency,
                scheduled_times=s.scheduled_times or [],
                notes=s.notes,
                is_active=s.is_active,
                created_at=s.created_at,
                taken_today=today in days_taken,
                streak=_compute_streak(days_taken, today),
            )
        )
    return result