mirror of
https://github.com/escalante29/healthy-fit.git
synced 2026-03-21 11:08:48 +01:00
Add AI-powered nutrition and plan modules
Introduces DSPy-based nutrition and plan generation modules, including image analysis for nutritional info and personalized diet/exercise plans. Adds new API endpoints for health metrics/goals, nutrition image analysis, and plan management. Updates models, schemas, and backend structure to support these features, and includes initial training data and configuration for prompt optimization.
This commit is contained in:
@@ -1,25 +1,107 @@
|
||||
import base64
|
||||
|
||||
import dspy
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class NutritionalInfo(BaseModel):
|
||||
name: str
|
||||
calories: float
|
||||
protein: float
|
||||
carbs: float
|
||||
fats: float
|
||||
reasoning: str = Field(description="Step-by-step reasoning for the nutritional estimates")
|
||||
name: str = Field(description="Name of the food item")
|
||||
calories: float = Field(description="Estimated calories")
|
||||
protein: float = Field(description="Estimated protein in grams")
|
||||
carbs: float = Field(description="Estimated carbohydrates in grams")
|
||||
fats: float = Field(description="Estimated fats in grams")
|
||||
micros: dict | None = None
|
||||
|
||||
|
||||
class ExtractNutrition(dspy.Signature):
|
||||
"""Extract nutritional information from a food description."""
|
||||
"""Extract nutritional information from a food description.
|
||||
|
||||
You must first provide a detailed step-by-step reasoning analysis of the ingredients,
|
||||
portions, AND preparation methods (cooking oils, butter, sauces) before estimating values.
|
||||
Verify if the caloric totals match the sum of macros (multiplying protein/carbs by 4, fats by 9).
|
||||
"""
|
||||
|
||||
description: str = dspy.InputField(desc="Description of the food or meal")
|
||||
nutritional_info: NutritionalInfo = dspy.OutputField(desc="Nutritional information as a structured object")
|
||||
nutritional_info: NutritionalInfo = dspy.OutputField(desc="Nutritional information with reasoning")
|
||||
|
||||
|
||||
class AnalyzeFoodImage(dspy.Signature):
|
||||
"""Analyze the food image to estimate nutritional content.
|
||||
|
||||
1. Identify all food items and estimated portion sizes.
|
||||
2. CRITICAL: Account for hidden calories from cooking fats, oils, and sauces (searing, frying).
|
||||
3. Reason step-by-step about the total composition before summing macros.
|
||||
"""
|
||||
|
||||
image: dspy.Image = dspy.InputField(desc="The food image")
|
||||
description: str = dspy.InputField(desc="Additional user description", default="")
|
||||
nutritional_info: NutritionalInfo = dspy.OutputField(desc="Nutritional information with reasoning")
|
||||
|
||||
|
||||
class NutritionModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.extract = dspy.ChainOfThought(ExtractNutrition)
|
||||
self.analyze_image = dspy.ChainOfThought(AnalyzeFoodImage)
|
||||
|
||||
# Load optimized prompts if available
|
||||
import os
|
||||
|
||||
compiled_path = os.path.join(os.path.dirname(__file__), "nutrition_compiled.json")
|
||||
if os.path.exists(compiled_path):
|
||||
self.load(compiled_path)
|
||||
print(f"Loaded optimized DSPy prompts from {compiled_path}")
|
||||
else:
|
||||
print("No optimized prompts found, using default zero-shot.")
|
||||
|
||||
def forward(self, description: str):
|
||||
return self.extract(description=description)
|
||||
pred = self.extract(description=description)
|
||||
|
||||
# Assertion: Check Macro Consistency
|
||||
calc_cals = (
|
||||
(pred.nutritional_info.protein * 4) + (pred.nutritional_info.carbs * 4) + (pred.nutritional_info.fats * 9)
|
||||
)
|
||||
|
||||
# dspy.Suggest is not available in dspy>=3.1.0
|
||||
# dspy.Suggest(
|
||||
# abs(calc_cals - pred.nutritional_info.calories) < (pred.nutritional_info.calories * 0.20),
|
||||
# f"The sum of macros ({calc_cals:.1f}) should match the total calories "
|
||||
# f"({pred.nutritional_info.calories}). Check your math.",
|
||||
# )
|
||||
return pred
|
||||
|
||||
def forward_image(self, image_url: str, description: str = ""):
|
||||
image = dspy.Image(image_url)
|
||||
pred = self.analyze_image(image=image, description=description)
|
||||
|
||||
# Assertion: Check Macro Consistency
|
||||
calc_cals = (
|
||||
(pred.nutritional_info.protein * 4) + (pred.nutritional_info.carbs * 4) + (pred.nutritional_info.fats * 9)
|
||||
)
|
||||
|
||||
# dspy.Suggest is not available in dspy>=3.1.0
|
||||
# dspy.Suggest(
|
||||
# abs(calc_cals - pred.nutritional_info.calories) < (pred.nutritional_info.calories * 0.20),
|
||||
# f"The sum of macros ({calc_cals:.1f}) should match the total calories "
|
||||
# f"({pred.nutritional_info.calories}). Check your math.",
|
||||
# )
|
||||
return pred
|
||||
|
||||
|
||||
nutrition_module = NutritionModule()
|
||||
|
||||
|
||||
def analyze_nutrition_from_image(image_bytes: bytes, description: str = "") -> NutritionalInfo:
|
||||
if not settings.OPENAI_API_KEY:
|
||||
raise ValueError("OpenAI API Key not set")
|
||||
|
||||
# Convert to base64 data URI
|
||||
base64_image = base64.b64encode(image_bytes).decode("utf-8")
|
||||
image_url = f"data:image/jpeg;base64,{base64_image}"
|
||||
|
||||
# Use DSPy module
|
||||
result = nutrition_module.forward_image(image_url=image_url, description=description)
|
||||
return result.nutritional_info
|
||||
|
||||
147
backend/app/ai/nutrition_compiled.json
Normal file
147
backend/app/ai/nutrition_compiled.json
Normal file
@@ -0,0 +1,147 @@
|
||||
{
|
||||
"extract.predict": {
|
||||
"traces": [],
|
||||
"train": [],
|
||||
"demos": [
|
||||
{
|
||||
"augmented": true,
|
||||
"description": "Blueberry Muffin (Bakery size)",
|
||||
"reasoning": "A typical bakery-sized blueberry muffin is generally larger than a standard homemade muffin and is made from ingredients such as flour, sugar, butter, eggs, milk, blueberries, and baking powder. The estimated calorie count for a large blueberry muffin is about 400-500 calories, primarily derived from carbohydrates (mainly from flour and sugar), fats (from butter), and a moderate amount of protein. \n\nFor nutritional breakdown:\n- Carbs: Approximately 60g derived from the flour and sugar.\n- Fats: Approximately 20g from the butter.\n- Protein: Roughly 6g from the flour and egg content.\n- The muffin may also contain vitamins and minerals from the blueberries.\n\nOverall, a bakery-sized blueberry muffin is calorie-dense due to the combination of ingredients used, especially sugar and butter.",
|
||||
"nutritional_info": {
|
||||
"reasoning": "Bakery-sized muffin has ingredients like flour, sugar, butter, and blueberries. Estimated 450 cal, with ~60g carbs, ~20g fat, and ~6g protein.",
|
||||
"name": "Blueberry Muffin",
|
||||
"calories": 450.0,
|
||||
"protein": 6.0,
|
||||
"carbs": 60.0,
|
||||
"fats": 20.0,
|
||||
"micros": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "Philly Cheesesteak",
|
||||
"nutritional_info": {
|
||||
"reasoning": "Roll (250 cal). Fatty steak (400 cal). Cheese whiz/provolone (150 cal). Oil (100 cal).",
|
||||
"name": "Cheesesteak",
|
||||
"calories": 900.0,
|
||||
"protein": 40.0,
|
||||
"carbs": 50.0,
|
||||
"fats": 55.0,
|
||||
"micros": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "Grilled salmon with asparagus and roasted potatoes",
|
||||
"nutritional_info": {
|
||||
"reasoning": "6oz Salmon fillet (350 cal). Oil for cooking (60 cal). Asparagus (30 cal) + oil (30 cal). 1 cup roasted potatoes (150 cal) + oil (60 cal). Total ~680 cal.",
|
||||
"name": "Salmon Dinner",
|
||||
"calories": 680.0,
|
||||
"protein": 40.0,
|
||||
"carbs": 25.0,
|
||||
"fats": 45.0,
|
||||
"micros": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "Tacos - 3 beef tacos with cheese and sour cream",
|
||||
"nutritional_info": {
|
||||
"reasoning": "3 corn tortillas (150 cal). Ground beef filling (250 cal - cooked with fat). Cheese (110 cal). Sour cream (60 cal). Total ~570 cal.",
|
||||
"name": "Beef Tacos",
|
||||
"calories": 570.0,
|
||||
"protein": 25.0,
|
||||
"carbs": 45.0,
|
||||
"fats": 30.0,
|
||||
"micros": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"signature": {
|
||||
"instructions": "Extract nutritional information from a food description.\n\nYou must first provide a detailed step-by-step reasoning analysis of the ingredients,\nportions, AND preparation methods (cooking oils, butter, sauces) before estimating values.\nVerify if the caloric totals match the sum of macros (multiplying protein/carbs by 4, fats by 9).",
|
||||
"fields": [
|
||||
{
|
||||
"prefix": "Description:",
|
||||
"description": "Description of the food or meal"
|
||||
},
|
||||
{
|
||||
"prefix": "Reasoning: Let's think step by step in order to",
|
||||
"description": "${reasoning}"
|
||||
},
|
||||
{
|
||||
"prefix": "Nutritional Info:",
|
||||
"description": "Nutritional information with reasoning"
|
||||
}
|
||||
]
|
||||
},
|
||||
"lm": null
|
||||
},
|
||||
"analyze_image.predict": {
|
||||
"traces": [],
|
||||
"train": [],
|
||||
"demos": [
|
||||
{
|
||||
"description": "Philly Cheesesteak",
|
||||
"nutritional_info": {
|
||||
"reasoning": "Roll (250 cal). Fatty steak (400 cal). Cheese whiz/provolone (150 cal). Oil (100 cal).",
|
||||
"name": "Cheesesteak",
|
||||
"calories": 900.0,
|
||||
"protein": 40.0,
|
||||
"carbs": 50.0,
|
||||
"fats": 55.0,
|
||||
"micros": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "Grilled salmon with asparagus and roasted potatoes",
|
||||
"nutritional_info": {
|
||||
"reasoning": "6oz Salmon fillet (350 cal). Oil for cooking (60 cal). Asparagus (30 cal) + oil (30 cal). 1 cup roasted potatoes (150 cal) + oil (60 cal). Total ~680 cal.",
|
||||
"name": "Salmon Dinner",
|
||||
"calories": 680.0,
|
||||
"protein": 40.0,
|
||||
"carbs": 25.0,
|
||||
"fats": 45.0,
|
||||
"micros": null
|
||||
}
|
||||
},
|
||||
{
|
||||
"description": "Tacos - 3 beef tacos with cheese and sour cream",
|
||||
"nutritional_info": {
|
||||
"reasoning": "3 corn tortillas (150 cal). Ground beef filling (250 cal - cooked with fat). Cheese (110 cal). Sour cream (60 cal). Total ~570 cal.",
|
||||
"name": "Beef Tacos",
|
||||
"calories": 570.0,
|
||||
"protein": 25.0,
|
||||
"carbs": 45.0,
|
||||
"fats": 30.0,
|
||||
"micros": null
|
||||
}
|
||||
}
|
||||
],
|
||||
"signature": {
|
||||
"instructions": "Analyze the food image to estimate nutritional content.\n\n1. Identify all food items and estimated portion sizes.\n2. CRITICAL: Account for hidden calories from cooking fats, oils, and sauces (searing, frying).\n3. Reason step-by-step about the total composition before summing macros.",
|
||||
"fields": [
|
||||
{
|
||||
"prefix": "Image:",
|
||||
"description": "The food image"
|
||||
},
|
||||
{
|
||||
"prefix": "Description:",
|
||||
"description": "Additional user description"
|
||||
},
|
||||
{
|
||||
"prefix": "Reasoning: Let's think step by step in order to",
|
||||
"description": "${reasoning}"
|
||||
},
|
||||
{
|
||||
"prefix": "Nutritional Info:",
|
||||
"description": "Nutritional information with reasoning"
|
||||
}
|
||||
]
|
||||
},
|
||||
"lm": null
|
||||
},
|
||||
"metadata": {
|
||||
"dependency_versions": {
|
||||
"python": "3.11",
|
||||
"dspy": "3.1.0",
|
||||
"cloudpickle": "3.1"
|
||||
}
|
||||
}
|
||||
}
|
||||
34
backend/app/ai/plans.py
Normal file
34
backend/app/ai/plans.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import dspy
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PlanOutput(BaseModel):
|
||||
reasoning: str = Field(description="Reasoning behind the selected plan based on user goals")
|
||||
title: str = Field(description="Title of the plan")
|
||||
summary: str = Field(description="Brief summary of the plan")
|
||||
diet_plan: list[str] = Field(description="List of daily diet recommendations")
|
||||
exercise_plan: list[str] = Field(description="List of daily exercise routines")
|
||||
tips: list[str] = Field(description="Additional health tips")
|
||||
|
||||
|
||||
class GeneratePlan(dspy.Signature):
|
||||
"""Generate a personalized diet and exercise plan based on user goal and details.
|
||||
|
||||
Analyze the user's profile and goal, explain your reasoning, and then generate the plan.
|
||||
"""
|
||||
|
||||
user_profile: str = dspy.InputField(desc="User details (age, weight, height, etc)")
|
||||
goal: str = dspy.InputField(desc="Specific user goal")
|
||||
plan: PlanOutput = dspy.OutputField(desc="Structured plan with reasoning")
|
||||
|
||||
|
||||
class PlanModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.generate = dspy.ChainOfThought(GeneratePlan)
|
||||
|
||||
def forward(self, user_profile: str, goal: str):
|
||||
return self.generate(user_profile=user_profile, goal=goal)
|
||||
|
||||
|
||||
plan_module = PlanModule()
|
||||
@@ -1,34 +1,32 @@
|
||||
from typing import Generator
|
||||
from sqlmodel import Session
|
||||
from app.db import engine
|
||||
from typing import Annotated, Generator
|
||||
|
||||
from typing import Generator, Annotated
|
||||
from sqlmodel import Session, select
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import jwt, JWTError
|
||||
from jose import JWTError, jwt
|
||||
from pydantic import ValidationError
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.core import security
|
||||
from app.db import engine
|
||||
from app.models.user import User
|
||||
from app.core import security
|
||||
from app.config import settings
|
||||
from app.schemas.token import TokenPayload
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/login/access-token")
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
SessionDep = Annotated[Session, Depends(get_session)]
|
||||
TokenDep = Annotated[str, Depends(oauth2_scheme)]
|
||||
|
||||
|
||||
def get_current_user(session: SessionDep, token: TokenDep) -> User:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.SECRET_KEY, algorithms=[security.ALGORITHM]
|
||||
)
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[security.ALGORITHM])
|
||||
token_data = TokenPayload(**payload)
|
||||
except (JWTError, ValidationError):
|
||||
raise HTTPException(
|
||||
@@ -40,4 +38,5 @@ def get_current_user(session: SessionDep, token: TokenDep) -> User:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return user
|
||||
|
||||
|
||||
CurrentUser = Annotated[User, Depends(get_current_user)]
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from fastapi import APIRouter
|
||||
from app.api.v1.endpoints import users, login, nutrition, health
|
||||
|
||||
from app.api.v1.endpoints import health, login, nutrition, plans, users
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(login.router, tags=["login"])
|
||||
api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(nutrition.router, prefix="/nutrition", tags=["nutrition"])
|
||||
api_router.include_router(health.router, prefix="/health", tags=["health"])
|
||||
api_router.include_router(plans.router, prefix="/plans", tags=["plans"])
|
||||
|
||||
@@ -1,35 +1,80 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, List
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlmodel import Session, select
|
||||
from app.api import deps
|
||||
from app.models.health import HealthMetric
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.api import deps
|
||||
from app.models.health import HealthGoal, HealthMetric
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class HealthMetricCreate(BaseModel):
|
||||
metric_type: str
|
||||
value: float
|
||||
unit: str
|
||||
user_id: int # TODO: remove when auth is fully integrated
|
||||
|
||||
@router.post("/", response_model=HealthMetric)
|
||||
|
||||
class HealthGoalCreate(BaseModel):
|
||||
goal_type: str
|
||||
target_value: float
|
||||
target_date: datetime | None = None
|
||||
|
||||
|
||||
@router.post("/metrics", response_model=HealthMetric)
|
||||
def create_metric(
|
||||
*,
|
||||
session: Session = Depends(deps.get_session),
|
||||
current_user: deps.CurrentUser,
|
||||
metric_in: HealthMetricCreate,
|
||||
) -> Any:
|
||||
metric = HealthMetric(metric_type=metric_in.metric_type, value=metric_in.value, unit=metric_in.unit, user_id=metric_in.user_id)
|
||||
metric = HealthMetric(
|
||||
metric_type=metric_in.metric_type, value=metric_in.value, unit=metric_in.unit, user_id=current_user.id
|
||||
)
|
||||
session.add(metric)
|
||||
session.commit()
|
||||
session.refresh(metric)
|
||||
return metric
|
||||
|
||||
@router.get("/{user_id}", response_model=List[HealthMetric])
|
||||
|
||||
@router.get("/metrics", response_model=List[HealthMetric])
|
||||
def read_metrics(
|
||||
user_id: int,
|
||||
current_user: deps.CurrentUser,
|
||||
session: Session = Depends(deps.get_session),
|
||||
) -> Any:
|
||||
statement = select(HealthMetric).where(HealthMetric.user_id == user_id)
|
||||
statement = (
|
||||
select(HealthMetric).where(HealthMetric.user_id == current_user.id).order_by(HealthMetric.timestamp.desc())
|
||||
)
|
||||
metrics = session.exec(statement).all()
|
||||
return metrics
|
||||
|
||||
|
||||
@router.post("/goals", response_model=HealthGoal)
|
||||
def create_goal(
|
||||
*,
|
||||
session: Session = Depends(deps.get_session),
|
||||
current_user: deps.CurrentUser,
|
||||
goal_in: HealthGoalCreate,
|
||||
) -> Any:
|
||||
goal = HealthGoal(
|
||||
goal_type=goal_in.goal_type,
|
||||
target_value=goal_in.target_value,
|
||||
target_date=goal_in.target_date,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
session.add(goal)
|
||||
session.commit()
|
||||
session.refresh(goal)
|
||||
return goal
|
||||
|
||||
|
||||
@router.get("/goals", response_model=List[HealthGoal])
|
||||
def read_goals(
|
||||
current_user: deps.CurrentUser,
|
||||
session: Session = Depends(deps.get_session),
|
||||
) -> Any:
|
||||
statement = select(HealthGoal).where(HealthGoal.user_id == current_user.id)
|
||||
goals = session.exec(statement).all()
|
||||
return goals
|
||||
|
||||
@@ -1,35 +1,34 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.api import deps
|
||||
from app.core import security
|
||||
from app.config import settings
|
||||
from app.core import security
|
||||
from app.models.user import User
|
||||
from app.schemas.token import Token
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/login/access-token", response_model=Token)
|
||||
def login_access_token(
|
||||
session: Session = Depends(deps.get_session),
|
||||
form_data: OAuth2PasswordRequestForm = Depends()
|
||||
session: Session = Depends(deps.get_session), form_data: OAuth2PasswordRequestForm = Depends()
|
||||
) -> Any:
|
||||
"""
|
||||
OAuth2 compatible token login, get an access token for future requests
|
||||
"""
|
||||
statement = select(User).where(User.email == form_data.username)
|
||||
user = session.exec(statement).first()
|
||||
|
||||
|
||||
if not user or not security.verify_password(form_data.password, user.password_hash):
|
||||
raise HTTPException(status_code=400, detail="Incorrect email or password")
|
||||
|
||||
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
return {
|
||||
"access_token": security.create_access_token(
|
||||
user.id, expires_delta=access_token_expires
|
||||
),
|
||||
"access_token": security.create_access_token(user.id, expires_delta=access_token_expires),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
import litellm
|
||||
import dspy
|
||||
from typing import Any
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.ai.nutrition import NutritionalInfo, analyze_nutrition_from_image, nutrition_module
|
||||
from app.api import deps
|
||||
from app.ai.nutrition import nutrition_module, NutritionalInfo
|
||||
from app.core.security import create_access_token # Just ensuring we have auth imports if needed later
|
||||
from app.models.user import User
|
||||
from app.models.food import FoodLog # Added FoodItem
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class AnalyzeRequest(BaseModel):
|
||||
description: str
|
||||
|
||||
from app.models.food import FoodLog, FoodItem
|
||||
from app.api.deps import get_session
|
||||
from app.core.security import get_password_hash # Not needed
|
||||
from app.config import settings
|
||||
|
||||
@router.post("/analyze", response_model=NutritionalInfo)
|
||||
def analyze_food(
|
||||
@@ -30,6 +30,24 @@ def analyze_food(
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/analyze/image", response_model=NutritionalInfo)
|
||||
async def analyze_food_image(
|
||||
file: UploadFile = File(...),
|
||||
description: str = Form(""),
|
||||
) -> Any:
|
||||
"""
|
||||
Analyze food image and return nutritional info.
|
||||
"""
|
||||
try:
|
||||
contents = await file.read()
|
||||
return analyze_nutrition_from_image(contents, description)
|
||||
except litellm.exceptions.BadRequestError as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid image or request: {str(e)}")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/log", response_model=FoodLog)
|
||||
def log_food(
|
||||
*,
|
||||
|
||||
67
backend/app/api/v1/endpoints/plans.py
Normal file
67
backend/app/api/v1/endpoints/plans.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from typing import Any, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.ai.plans import plan_module
|
||||
from app.api import deps
|
||||
from app.models.plan import Plan
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class PlanRequest(BaseModel):
|
||||
goal: str
|
||||
user_details: str # e.g., "Male, 30, 80kg"
|
||||
|
||||
|
||||
@router.post("/generate", response_model=Plan)
|
||||
def generate_plan(
|
||||
*,
|
||||
current_user: deps.CurrentUser,
|
||||
request: PlanRequest,
|
||||
session: Session = Depends(deps.get_session),
|
||||
) -> Any:
|
||||
"""
|
||||
Generate a new diet/exercise plan using AI.
|
||||
"""
|
||||
try:
|
||||
# Generate plan using DSPy
|
||||
generated = plan_module(user_profile=request.user_details, goal=request.goal)
|
||||
|
||||
# Determine content string (markdown representation)
|
||||
content_md = (
|
||||
f"# {generated.plan.title}\n\n{generated.plan.summary}\n\n## Diet\n"
|
||||
+ "\n".join([f"- {item}" for item in generated.plan.diet_plan])
|
||||
+ "\n\n## Exercise\n"
|
||||
+ "\n".join([f"- {item}" for item in generated.plan.exercise_plan])
|
||||
+ "\n\n## Tips\n"
|
||||
+ "\n".join([f"- {item}" for item in generated.plan.tips])
|
||||
)
|
||||
|
||||
plan = Plan(
|
||||
user_id=current_user.id,
|
||||
goal=request.goal,
|
||||
content=content_md,
|
||||
structured_content=generated.plan.model_dump(),
|
||||
)
|
||||
session.add(plan)
|
||||
session.commit()
|
||||
session.refresh(plan)
|
||||
return plan
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Plan])
|
||||
def read_plans(
|
||||
current_user: deps.CurrentUser,
|
||||
session: Session = Depends(deps.get_session),
|
||||
) -> Any:
|
||||
"""
|
||||
Get all plans for the current user.
|
||||
"""
|
||||
statement = select(Plan).where(Plan.user_id == current_user.id).order_by(Plan.created_at.desc())
|
||||
plans = session.exec(statement).all()
|
||||
return plans
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlmodel import Session, select
|
||||
|
||||
@@ -9,6 +10,7 @@ from app.schemas.user import UserCreate, UserRead
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/", response_model=UserRead)
|
||||
def create_user(
|
||||
*,
|
||||
@@ -24,7 +26,7 @@ def create_user(
|
||||
status_code=400,
|
||||
detail="The user with this email already exists in the system",
|
||||
)
|
||||
|
||||
|
||||
user = User(
|
||||
email=user_in.email,
|
||||
username=user_in.username,
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
DATABASE_URL: str
|
||||
OPENAI_API_KEY: str | None = None
|
||||
SECRET_KEY: str = "changethis"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 days
|
||||
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
extra = "ignore"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import dspy
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
def configure_dspy():
|
||||
if settings.OPENAI_API_KEY:
|
||||
lm = dspy.LM("openai/gpt-4o-mini", api_key=settings.OPENAI_API_KEY)
|
||||
|
||||
@@ -1,24 +1,29 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Union
|
||||
|
||||
from jose import jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.config import settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
def create_access_token(subject: Union[str, Any], expires_delta: timedelta = None) -> str:
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
|
||||
to_encode = {"exp": expire, "sub": str(subject)}
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from sqlmodel import SQLModel, create_engine, Session, text
|
||||
from sqlmodel import Session, SQLModel, create_engine, text
|
||||
|
||||
from app.config import settings
|
||||
|
||||
engine = create_engine(settings.DATABASE_URL)
|
||||
|
||||
|
||||
def get_session():
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
def init_db():
|
||||
with Session(engine) as session:
|
||||
session.exec(text("CREATE EXTENSION IF NOT EXISTS vector"))
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.v1.api import api_router
|
||||
from app.db import init_db
|
||||
from app.core.ai_config import configure_dspy
|
||||
from app.db import init_db
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -14,7 +14,6 @@ async def lifespan(app: FastAPI):
|
||||
configure_dspy()
|
||||
yield
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
app = FastAPI(title="Healthy Fit API", lifespan=lifespan)
|
||||
|
||||
@@ -28,6 +27,7 @@ app.add_middleware(
|
||||
|
||||
app.include_router(api_router, prefix="/api/v1")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
return {"message": "Welcome to Healthy Fit API"}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict
|
||||
from sqlmodel import Field, SQLModel, JSON
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import Column
|
||||
from sqlmodel import JSON, Field, SQLModel
|
||||
|
||||
|
||||
class FoodItem(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
@@ -12,13 +14,14 @@ class FoodItem(SQLModel, table=True):
|
||||
carbs: float
|
||||
fats: float
|
||||
micros: Dict = Field(default={}, sa_column=Column(JSON))
|
||||
embedding: List[float] = Field(sa_column=Column(Vector(1536))) # OpenAI embedding size
|
||||
embedding: List[float] = Field(sa_column=Column(Vector(1536))) # OpenAI embedding size
|
||||
|
||||
|
||||
class FoodLog(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.id")
|
||||
food_item_id: Optional[int] = Field(default=None, foreign_key="fooditem.id")
|
||||
name: str # In case no food item is linked or custom entry
|
||||
name: str # In case no food item is linked or custom entry
|
||||
calories: float
|
||||
protein: float
|
||||
carbs: float
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class HealthMetric(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.id")
|
||||
metric_type: str = Field(index=True) # e.g., "weight", "cholesterol", "testosterone"
|
||||
metric_type: str = Field(index=True) # e.g., "weight", "cholesterol", "testosterone"
|
||||
value: float
|
||||
unit: str
|
||||
timestamp: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class HealthGoal(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.id")
|
||||
goal_type: str # e.g., "lose_weight", "gain_muscle"
|
||||
goal_type: str # e.g., "lose_weight", "gain_muscle"
|
||||
target_value: float
|
||||
target_date: Optional[datetime] = None
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
13
backend/app/models/plan.py
Normal file
13
backend/app/models/plan.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import JSON, Field, SQLModel
|
||||
|
||||
|
||||
class Plan(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.id")
|
||||
goal: str = Field(index=True) # e.g., "lose weight", "gain muscle"
|
||||
content: str = Field(description="The full plan content in markdown or text")
|
||||
structured_content: dict = Field(default={}, sa_type=JSON) # For UI rendering
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
@@ -1,7 +1,9 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
username: str = Field(index=True, unique=True)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
|
||||
class Token(SQLModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenPayload(SQLModel):
|
||||
sub: Optional[str] = None
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
from typing import Optional
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
|
||||
class UserBase(SQLModel):
|
||||
email: str
|
||||
username: str
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
password: str
|
||||
|
||||
|
||||
class UserRead(UserBase):
|
||||
id: int
|
||||
|
||||
|
||||
|
||||
class UserUpdate(SQLModel):
|
||||
email: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
|
||||
13
backend/pyproject.toml
Normal file
13
backend/pyproject.toml
Normal file
@@ -0,0 +1,13 @@
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
target-version = "py311"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "W"]
|
||||
ignore = []
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
indent-style = "space"
|
||||
skip-magic-trailing-comma = false
|
||||
line-ending = "auto"
|
||||
@@ -12,3 +12,5 @@ bcrypt==4.0.1
|
||||
pytest
|
||||
httpx
|
||||
python-dotenv
|
||||
ruff
|
||||
|
||||
|
||||
0
backend/scripts/__init__.py
Normal file
0
backend/scripts/__init__.py
Normal file
513
backend/scripts/nutrition_data.py
Normal file
513
backend/scripts/nutrition_data.py
Normal file
@@ -0,0 +1,513 @@
|
||||
import dspy
|
||||
|
||||
from app.ai.nutrition import NutritionalInfo
|
||||
|
||||
# A diverse set of 50 validated examples covering:
|
||||
# - Home cooked meals
|
||||
# - Restaurant items
|
||||
# - Snacks
|
||||
# - Complex dishes with hidden calories
|
||||
|
||||
train_examples = [
|
||||
# --- Breakfast ---
|
||||
dspy.Example(
|
||||
description="Oatmeal with almonds, blueberries, and honey",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="1 cup cooked oats (150 cal). 1 oz almonds (160 cal). "
|
||||
"1/2 cup blueberries (40 cal). 1 tbsp honey (60 cal). Total ~410 cal.",
|
||||
name="Oatmeal Bowl",
|
||||
calories=410,
|
||||
protein=10,
|
||||
carbs=65,
|
||||
fats=16,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Two eggs over easy with two slices of bacon and buttered toast",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="2 eggs (140 cal) cooked in fat (+20 cal). 2 slices bacon (90 cal). "
|
||||
"1 slice huge toast (100 cal) + 1 tsp butter (35 cal). Total ~385 cal.",
|
||||
name="Eggs & Bacon Breakfast",
|
||||
calories=385,
|
||||
protein=20,
|
||||
carbs=15,
|
||||
fats=26,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Greek yogurt parfait with granola and strawberries",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="1 cup non-fat greek yogurt (130 cal). 1/2 cup granola (200 cal). "
|
||||
"1 cup sliced strawberries (50 cal). Total ~380 cal.",
|
||||
name="Yogurt Parfait",
|
||||
calories=380,
|
||||
protein=24,
|
||||
carbs=55,
|
||||
fats=8,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Avocado toast with a poached egg",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="1 slice artisan bread (110 cal). 1/2 avocado (120 cal). "
|
||||
"1 poached egg (70 cal). Drizzle of oil/seasoning (20 cal). Total ~320 cal.",
|
||||
name="Avocado Toast",
|
||||
calories=320,
|
||||
protein=10,
|
||||
carbs=20,
|
||||
fats=22,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Spinach and feta omelette containing 3 eggs",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="3 eggs (210 cal). 1 tsp oil/butter (40 cal). "
|
||||
"1 cup spinach (10 cal). 1 oz feta (75 cal). Total ~335 cal.",
|
||||
name="Spinach Feta Omelette",
|
||||
calories=335,
|
||||
protein=22,
|
||||
carbs=4,
|
||||
fats=25,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
# --- Lunch ---
|
||||
dspy.Example(
|
||||
description="Grilled chicken breast sandwich with mayo, lettuce, tomato",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Bun (180 cal). Chicken breast 4oz (160 cal). "
|
||||
"1 tbsp mayo (90 cal). Veggies (10 cal). Total ~440 cal.",
|
||||
name="Grilled Chicken Sandwich",
|
||||
calories=440,
|
||||
protein=30,
|
||||
carbs=35,
|
||||
fats=18,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Caesar salad with grilled chicken",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Romaine lettuce (20 cal). 4oz chicken (160 cal). 2 tbsp dressing (170 cal). "
|
||||
"Croutons (100 cal). Parmesan (60 cal). Total ~510 cal.",
|
||||
name="Chicken Caesar Salad",
|
||||
calories=510,
|
||||
protein=35,
|
||||
carbs=15,
|
||||
fats=35,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Turkey club sandwich with bacon and cheese",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="3 slices bread (240 cal). Turkey (60 cal). Bacon (90 cal). "
|
||||
"Cheese (110 cal). Mayo (90 cal). Lettuce / Tomato. Total ~590 cal.",
|
||||
name="Turkey Club",
|
||||
calories=590,
|
||||
protein=30,
|
||||
carbs=45,
|
||||
fats=32,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Quinoa bowl with black beans, corn, and avocado",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="1 cup cooked quinoa (220 cal). 1/2 cup black beans (110 cal). "
|
||||
"1/2 cup corn (70 cal). 1/4 avocado (60 cal). Lime dressing (50 cal). Total ~510 cal.",
|
||||
name="Veggie Quinoa Bowl",
|
||||
calories=510,
|
||||
protein=18,
|
||||
carbs=85,
|
||||
fats=12,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Tuna salad sushi roll (6 pieces) and miso soup",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Sushi roll (rice, tuna, mayo) ~300 cal. Miso soup ~40 cal. Total ~340 cal.",
|
||||
name="Sushi Lunch",
|
||||
calories=340,
|
||||
protein=15,
|
||||
carbs=45,
|
||||
fats=8,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
# --- Dinner (Complex) ---
|
||||
dspy.Example(
|
||||
description="Spaghetti bolognaise with parmesan cheese",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="2 cups pasta cooked (400 cal). 1 cup meat sauce/beef (300 cal). "
|
||||
"1 tbsp oil in cooking (120 cal). 2 tbsp parmesan (40 cal). Total ~860 cal.",
|
||||
name="Spaghetti Bolognese",
|
||||
calories=860,
|
||||
protein=35,
|
||||
carbs=100,
|
||||
fats=35,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Grilled salmon with asparagus and roasted potatoes",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="6oz Salmon fillet (350 cal). Oil for cooking (60 cal). "
|
||||
"Asparagus (30 cal) + oil (30 cal). 1 cup roasted potatoes (150 cal) + oil (60 cal). Total ~680 cal.",
|
||||
name="Salmon Dinner",
|
||||
calories=680,
|
||||
protein=40,
|
||||
carbs=25,
|
||||
fats=45,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Beef stir fry with rice",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="1 cup rice (200 cal). 4oz Beef strips (250 cal). "
|
||||
"Oil for frying 2 tbsp (240 cal). Veggies (50 cal). Sauce (50 cal). Total ~790 cal.",
|
||||
name="Beef Stir Fry",
|
||||
calories=790,
|
||||
protein=30,
|
||||
carbs=50,
|
||||
fats=50,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Cheeseburger with fries",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Bun (200 cal). 4oz Patty 80/20 (280 cal). Cheese (100 cal). "
|
||||
"Condiments (50 cal). Small fries (300 cal). Total ~930 cal.",
|
||||
name="Burger and Fries",
|
||||
calories=930,
|
||||
protein=35,
|
||||
carbs=90,
|
||||
fats=45,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Chicken Tikka Masala with Naan and Rice",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Curry with cream/butter/chicken (600 cal). "
|
||||
"1 cup Rice (200 cal). 1 piece Naan (250 cal). Total ~1050 cal.",
|
||||
name="Chicken Tikka Meal",
|
||||
calories=1050,
|
||||
protein=45,
|
||||
carbs=120,
|
||||
fats=45,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="2 slices of pepperoni pizza",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="2 slices (300 cal each). Total ~600 cal. High fat/carbs.",
|
||||
name="2 Pizza Slices",
|
||||
calories=600,
|
||||
protein=24,
|
||||
carbs=70,
|
||||
fats=26,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Tacos - 3 beef tacos with cheese and sour cream",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="3 corn tortillas (150 cal). Ground beef filling (250 cal - cooked with fat). "
|
||||
"Cheese (110 cal). Sour cream (60 cal). Total ~570 cal.",
|
||||
name="Beef Tacos",
|
||||
calories=570,
|
||||
protein=25,
|
||||
carbs=45,
|
||||
fats=30,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Ribeye steak (10oz) with mashed potatoes",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="10oz Ribeye (fatty cut) ~750 cal. "
|
||||
"Mashed potatoes with butter/cream (1 cup) ~300 cal. Total ~1050 cal.",
|
||||
name="Ribeye Steak Dinner",
|
||||
calories=1050,
|
||||
protein=60,
|
||||
carbs=35,
|
||||
fats=75,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
# --- Snacks/Others ---
|
||||
dspy.Example(
|
||||
description="Medium Banana",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Standard fruit size.", name="Banana", calories=105, protein=1.3, carbs=27, fats=0.3
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Protein Shake (Whey)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="1 scoop whey (120 cal). Water (0 cal).",
|
||||
name="Whey Protein Shake",
|
||||
calories=120,
|
||||
protein=24,
|
||||
carbs=3,
|
||||
fats=1,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Apple with peanut butter",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="1 apple (95 cal). 2 tbsp peanut butter (190 cal). Total ~285 cal.",
|
||||
name="Apple & PB",
|
||||
calories=285,
|
||||
protein=8,
|
||||
carbs=30,
|
||||
fats=16,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Bag of potato chips (small)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Standard vending machine size (1.5 oz/42g). Fried.",
|
||||
name="Potato Chips",
|
||||
calories=220,
|
||||
protein=3,
|
||||
carbs=22,
|
||||
fats=14,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Hummus and carrot sticks",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="1/4 cup hummus (150 cal). 2 carrots (50 cal). Total ~200 cal.",
|
||||
name="Hummus Snack",
|
||||
calories=200,
|
||||
protein=5,
|
||||
carbs=25,
|
||||
fats=9,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Chocolate chip cookie (Subway style)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="1 large cookie, heavy on sugar/butter.",
|
||||
name="Large Cookie",
|
||||
calories=220,
|
||||
protein=2,
|
||||
carbs=30,
|
||||
fats=10,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Blueberry Muffin (Bakery size)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Large bakery muffin is notoriously high cal. Flour, sugar, oil.",
|
||||
name="Bakery Muffin",
|
||||
calories=450,
|
||||
protein=6,
|
||||
carbs=65,
|
||||
fats=18,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
# --- Add 25 more diverse items to reach 50 ---
|
||||
dspy.Example(
|
||||
description="Hard boiled egg",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="One large egg.", name="Egg", calories=78, protein=6, carbs=0.6, fats=5
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Slice of cheddar cheese",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="1 oz slice.", name="Cheddar", calories=110, protein=7, carbs=0.4, fats=9
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Glass of whole milk (8oz)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Full fat dairy.", name="Whole Milk", calories=150, protein=8, carbs=12, fats=8
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Coca Cola (12oz can)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="High sugar soda.", name="Coke", calories=140, protein=0, carbs=39, fats=0
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Orange Juice (8oz)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Natural sugars.", name="OJ", calories=110, protein=2, carbs=26, fats=0
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Kind Bar (Dark Chocolate Nuts)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Nut based bar.", name="Nut Bar", calories=200, protein=6, carbs=16, fats=13
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Bowl of Beef Chili (1 cup)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Ground beef, beans, tomato base.", name="Chili", calories=300, protein=20, carbs=25, fats=15
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Pork Chop (baked) with green beans",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="6oz pork chop (250 cal). Steam beans (30 cal). Total ~280.",
|
||||
name="Pork Chop Meal",
|
||||
calories=280,
|
||||
protein=35,
|
||||
carbs=10,
|
||||
fats=12,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Clam Chowder Bowl",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Cream based soup (heavy cream). 1.5 cups.",
|
||||
name="Clam Chowder",
|
||||
calories=450,
|
||||
protein=12,
|
||||
carbs=40,
|
||||
fats=28,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Philly Cheesesteak",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Roll (250 cal). Fatty steak (400 cal). Cheese whiz/provolone (150 cal). Oil (100 cal).",
|
||||
name="Cheesesteak",
|
||||
calories=900,
|
||||
protein=40,
|
||||
carbs=50,
|
||||
fats=55,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Fish and Chips (3 pieces)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Deep fried batter fits + fried chips. Very high oil absorption.",
|
||||
name="Fish and Chips",
|
||||
calories=950,
|
||||
protein=30,
|
||||
carbs=90,
|
||||
fats=55,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Cobb Salad with ranch",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Greens, bacon, egg, avocado, blue cheese, ranch dressing. Salad is low cal, toppings are high.",
|
||||
name="Cobb Salad",
|
||||
calories=750,
|
||||
protein=35,
|
||||
carbs=15,
|
||||
fats=60,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Hot Dog with bun, mustard, ketchup",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Processed meat link (150 cal). Bun (120 cal). Condiments (20 cal).",
|
||||
name="Hot Dog",
|
||||
calories=290,
|
||||
protein=10,
|
||||
carbs=25,
|
||||
fats=16,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Pad Thai with Shrimp",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Rice noodles stir fried in oil and sugar based sauce. Peanuts.",
|
||||
name="Pad Thai",
|
||||
calories=800,
|
||||
protein=25,
|
||||
carbs=110,
|
||||
fats=30,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Burrito (Chipotle style - Chicken, Rice, Beans, Cheese, Guac)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Tortilla (300). Rice (200). Beans (150). Chicken (180). Cheese (100). Guac (230!).",
|
||||
name="Burrito",
|
||||
calories=1160,
|
||||
protein=55,
|
||||
carbs=110,
|
||||
fats=55,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Smoothie (Berry, Banana, Yogurt)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Healthy but sugar dense fruits + yogurt.",
|
||||
name="Fruit Smoothie",
|
||||
calories=300,
|
||||
protein=8,
|
||||
carbs=60,
|
||||
fats=2,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Falafel Wrap",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Fried chickpea balls (250), pita (150), tahini sauce (100).",
|
||||
name="Falafel Wrap",
|
||||
calories=550,
|
||||
protein=15,
|
||||
carbs=70,
|
||||
fats=25,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Macaroni and Cheese (1 cup homemade)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Pasta + Roux + Milk + lots of Cheese.",
|
||||
name="Mac & Cheese",
|
||||
calories=500,
|
||||
protein=18,
|
||||
carbs=45,
|
||||
fats=28,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Ice Cream (2 scoops vanilla)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Sugar and Cream.", name="Ice Cream", calories=350, protein=6, carbs=40, fats=20
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Cottage Cheese (1 cup)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Low fat high protein dairy.", name="Cottage Cheese", calories=180, protein=25, carbs=10, fats=5
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Beef Jerky (1 bag / 3oz)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Dried meat, lean protein.", name="Beef Jerky", calories=240, protein=35, carbs=15, fats=4
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Edamame (1 cup in pod)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Soybeans.", name="Edamame", calories=190, protein=17, carbs=15, fats=8
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Popcorn (movie theater small, buttered)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Corn + oil popping + butter topping.",
|
||||
name="Movie Popcorn",
|
||||
calories=600,
|
||||
protein=6,
|
||||
carbs=60,
|
||||
fats=40,
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Veggie Pizza Slice",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Cheese + Dough + Veggies.", name="Veggie Pizza", calories=260, protein=10, carbs=32, fats=10
|
||||
),
|
||||
).with_inputs("description"),
|
||||
dspy.Example(
|
||||
description="Salmon Nigiri (2 pcs)",
|
||||
nutritional_info=NutritionalInfo(
|
||||
reasoning="Rice ball + Slice of raw fish.", name="Salmon Nigiri", calories=120, protein=10, carbs=15, fats=3
|
||||
),
|
||||
).with_inputs("description"),
|
||||
]
|
||||
41
backend/scripts/optimize_nutrition.py
Normal file
41
backend/scripts/optimize_nutrition.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
from app.ai.nutrition import nutrition_module
|
||||
from app.core.ai_config import configure_dspy
|
||||
from scripts.nutrition_data import train_examples
|
||||
|
||||
# 0. Configure DSPy
|
||||
configure_dspy()
|
||||
|
||||
# 1. Define Validated Examples (The "Train Set")
|
||||
# ... (rest of the file) ...
|
||||
|
||||
|
||||
# 2. Define a Metric
|
||||
def validate_nutrition(example, pred, trace=None):
|
||||
# Check if the predicted calories are within 15% of the actual calories
|
||||
actual_cals = example.nutritional_info.calories
|
||||
pred_cals = pred.nutritional_info.calories
|
||||
|
||||
threshold = 0.15
|
||||
lower = actual_cals * (1 - threshold)
|
||||
upper = actual_cals * (1 + threshold)
|
||||
|
||||
return lower <= pred_cals <= upper
|
||||
|
||||
|
||||
# 3. Setup the Optimizer
|
||||
teleprompter = BootstrapFewShot(metric=validate_nutrition, max_bootstrapped_demos=8, max_labeled_demos=8)
|
||||
|
||||
# 4. Compile (Optimize) the Module
|
||||
print("Optimizing... (this calls the LLM for each example)")
|
||||
compiled_nutrition = teleprompter.compile(nutrition_module, trainset=train_examples)
|
||||
|
||||
# 5. Save validity
|
||||
# Correct path relative to backend/ directory
|
||||
compiled_nutrition.save("app/ai/nutrition_compiled.json")
|
||||
print("Optimization complete! Saved to app/ai/nutrition_compiled.json")
|
||||
|
||||
# 6. Usage
|
||||
# To use the optimized version in production, you would load it:
|
||||
# nutrition_module.load("backend/app/ai/nutrition_compiled.json")
|
||||
58
backend/scripts/optimize_nutrition_v2.py
Normal file
58
backend/scripts/optimize_nutrition_v2.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from dspy.teleprompt import BootstrapFewShotWithRandomSearch
|
||||
|
||||
from app.ai.nutrition import nutrition_module
|
||||
from app.core.ai_config import configure_dspy
|
||||
from scripts.nutrition_data import train_examples
|
||||
|
||||
# 0. Configure DSPy
|
||||
configure_dspy()
|
||||
|
||||
|
||||
# 1. Define Advanced Metric
|
||||
def validate_nutrition_v2(example, pred, trace=None):
|
||||
# Condition A: Accuracy (within 15% of ground truth)
|
||||
actual_cals = example.nutritional_info.calories
|
||||
pred_cals = pred.nutritional_info.calories
|
||||
|
||||
threshold = 0.15
|
||||
lower = actual_cals * (1 - threshold)
|
||||
upper = actual_cals * (1 + threshold)
|
||||
is_accurate_count = lower <= pred_cals <= upper
|
||||
|
||||
# Condition B: Consistency (Macros match Calories within 20%)
|
||||
# This prevents "hallucinated" numbers that don't satisfy physics
|
||||
p = pred.nutritional_info.protein
|
||||
c = pred.nutritional_info.carbs
|
||||
f = pred.nutritional_info.fats
|
||||
|
||||
calculated_cals = (p * 4) + (c * 4) + (f * 9)
|
||||
# Using a slightly looser bounds (20%) for fiber/rounding
|
||||
consistency_threshold = 0.20
|
||||
is_consistent_math = abs(calculated_cals - pred_cals) < (pred_cals * consistency_threshold)
|
||||
|
||||
# We want BOTH to be true
|
||||
return is_accurate_count and is_consistent_math
|
||||
|
||||
|
||||
# 2. Setup Advanced Optimizer
|
||||
# RandomSearch is more expensive but finds better reasoning traces by randomizing
|
||||
# the selection of few-shot examples.
|
||||
# num_candidate_programs=10 means it will try 10 different combinations of prompts/examples
|
||||
print("Configuring RandomSearch Optimizer...")
|
||||
teleprompter = BootstrapFewShotWithRandomSearch(
|
||||
metric=validate_nutrition_v2,
|
||||
max_bootstrapped_demos=4,
|
||||
max_labeled_demos=4,
|
||||
num_candidate_programs=5, # Reduced to 5 for speed in this demo, typically 10-20
|
||||
num_threads=1, # Sequential for stability, increase for parallelism
|
||||
)
|
||||
|
||||
# 3. Compile (Optimize) the Module
|
||||
print("Optimizing V2 (this includes random search and macro checks)...")
|
||||
# Note: assertions are compiled into the pipeline automatically in newer DSPy,
|
||||
# acting as soft constraints during the search.
|
||||
compiled_nutrition = teleprompter.compile(nutrition_module, trainset=train_examples)
|
||||
|
||||
# 4. Save
|
||||
compiled_nutrition.save("app/ai/nutrition_compiled.json")
|
||||
print("Optimization V2 complete! Overwrote app/ai/nutrition_compiled.json")
|
||||
Reference in New Issue
Block a user