import datetime
import os
from dataclasses import dataclass
from typing import Literal

from portkey_ai import AsyncPortkey
from pydantic import BaseModel, Field
from pydantic_ai import Agent, ModelRetry, RunContext
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider
from pydantic_ai.usage import Usage, UsageLimits
from rich.prompt import Prompt
# Set up the Portkey client shared by every agent in this module. All of its
# requests carry the same trace ID so they show up as one connected trace.
# The API key is read from the PORTKEY_API_KEY environment variable so a real
# secret never has to be committed; the placeholder remains as a fallback.
# NOTE(review): trace_id is a constant, so separate runs of this script
# presumably land in the same trace — consider a per-run ID if sessions must
# be told apart.
portkey_client = AsyncPortkey(
    api_key=os.environ.get("PORTKEY_API_KEY", "YOUR_PORTKEY_API_KEY"),
    provider="@YOUR_OPENAI_PROVIDER",
    trace_id="flight-booking-session",
    metadata={"app_type": "flight_booking"},
)
# Define structured output types.
# NOTE: pydantic copies a model's docstring and Field descriptions into its
# JSON schema, and pydantic_ai builds the output definition sent to the LLM
# from that schema, so these docstrings are effectively prompt text.
class FlightDetails(BaseModel):
    """Details of the most suitable flight."""

    # Flight identifier as it appears in the page text (e.g. "SFO-AK123").
    flight_number: str
    # Price as a whole number; no currency field is modelled.
    price: int
    origin: str = Field(description='Three-letter airport code')
    destination: str = Field(description='Three-letter airport code')
    # Compared with Deps.req_date by validate_output.
    date: datetime.date
class NoFlightFound(BaseModel):
    """When no valid flight is found."""

    # Field-less: the type itself is the signal. main() stops searching when
    # search_agent returns an instance of it.
class SeatPreference(BaseModel):
    # Successful output of seat_preference_agent. A class docstring here would
    # become part of the output schema's description, so none is added.

    # Row number; pydantic rejects values outside 1..30.
    row: int = Field(ge=1, le=30)
    # Seat letter; seat_preference_agent's prompt says A and F are windows.
    seat: Literal['A', 'B', 'C', 'D', 'E', 'F']
class Failed(BaseModel):
    """Unable to extract a seat selection."""

    # Field-less fallback output of seat_preference_agent; find_seat prints a
    # message and asks the user again when it receives this.
# Dependencies for flight search
@dataclass
class Deps:
    """Per-session inputs available to search_agent's tool and output validator.

    extract_flights reads ``web_page_text``. validate_output compares the
    model's answer against the three ``req_*`` fields.
    """

    web_page_text: str  # raw page text handed to extraction_agent
    req_origin: str  # requested origin code, e.g. 'SFO' in main()
    req_destination: str  # requested destination code, e.g. 'ANC' in main()
    req_date: datetime.date  # requested travel date
# This agent is responsible for controlling the flow of the conversation.
# It is asked for the cheapest flight, can call the `extract_flights` tool to
# obtain candidates, and its answer is checked by `validate_output`, which
# raises ModelRetry on a mismatch. The model talks to OpenAI through the
# shared Portkey client.
search_agent = Agent[Deps, FlightDetails | NoFlightFound](
    model=OpenAIModel(
        model_name="gpt-4o",
        provider=OpenAIProvider(openai_client=portkey_client),
    ),
    output_type=FlightDetails | NoFlightFound,  # type: ignore
    # Retry budget used when a tool or validate_output asks the model to retry.
    retries=4,
    system_prompt=(
        'Your job is to find the cheapest flight for the user on the given date. '
    ),
    instrument=True,  # Enable instrumentation for better tracing
)
# This agent is responsible for extracting flight details from web page text.
# Its visible caller is search_agent's `extract_flights` tool. It is
# instrumented like search_agent so that this nested run is traced as well,
# instead of only the outer agent being visible.
extraction_agent = Agent(
    model=OpenAIModel(
        model_name="gpt-4o",
        provider=OpenAIProvider(openai_client=portkey_client),
    ),
    output_type=list[FlightDetails],
    system_prompt='Extract all the flight details from the given text.',
    instrument=True,
)
# This agent is responsible for extracting the user's seat selection.
# The model may answer with Failed when it cannot map the reply to a seat, in
# which case find_seat asks again. It is instrumented like search_agent so
# that every agent run in the session is traced consistently.
# NOTE(review): 'Rows 14, and 20' reads as if a row number may be missing
# from the list; confirm the intended rows.
seat_preference_agent = Agent[None, SeatPreference | Failed](
    model=OpenAIModel(
        model_name="gpt-4o",
        provider=OpenAIProvider(openai_client=portkey_client),
    ),
    output_type=SeatPreference | Failed,  # type: ignore
    system_prompt=(
        "Extract the user's seat preference. "
        'Seats A and F are window seats. '
        'Row 1 is the front row and has extra leg room. '
        'Rows 14, and 20 also have extra leg room. '
    ),
    instrument=True,
)
@search_agent.tool
async def extract_flights(ctx: RunContext[Deps]) -> list[FlightDetails]:
    """Get details of all flights."""
    # The docstring above is the tool description the model sees; keep it
    # stable. The work is delegated to extraction_agent, and ctx.usage is
    # forwarded so the nested run's requests add to the outer run's usage.
    page_text = ctx.deps.web_page_text
    extraction = await extraction_agent.run(page_text, usage=ctx.usage)
    return extraction.output
@search_agent.output_validator
async def validate_output(
    ctx: RunContext[Deps], output: FlightDetails | NoFlightFound
) -> FlightDetails | NoFlightFound:
    """Check a proposed flight against the requested route and date.

    A ``NoFlightFound`` answer is passed through unchanged. For a
    ``FlightDetails`` answer, the origin, destination and date are compared
    with ``ctx.deps``. Every mismatch is reported in a single ``ModelRetry``
    (one message per line) so the model can fix them all at once.

    Raises:
        ModelRetry: if any of the three fields differs from the request.
    """
    if isinstance(output, NoFlightFound):
        return output

    wanted = ctx.deps
    # (field matches?, message to send back when it does not), in report order.
    checks = (
        (
            output.origin == wanted.req_origin,
            f'Flight should have origin {wanted.req_origin}, not {output.origin}',
        ),
        (
            output.destination == wanted.req_destination,
            f'Flight should have destination {wanted.req_destination}, not {output.destination}',
        ),
        (
            output.date == wanted.req_date,
            f'Flight should be on {wanted.req_date}, not {output.date}',
        ),
    )
    problems = [message for matches, message in checks if not matches]
    if problems:
        raise ModelRetry('\n'.join(problems))
    return output
# Sample flight data (in a real application, this would come from a web
# scraper). Only the first entry matches the request built in main()
# (SFO -> ANC on 2025-01-10). The second one goes to FAI, so validate_output
# would reject it. "... more flights ..." is literal placeholder text and is
# passed to the model as-is.
flights_web_page = """
1. Flight SFO-AK123
- Price: $350
- Origin: San Francisco International Airport (SFO)
- Destination: Ted Stevens Anchorage International Airport (ANC)
- Date: January 10, 2025
2. Flight SFO-AK456
- Price: $370
- Origin: San Francisco International Airport (SFO)
- Destination: Fairbanks International Airport (FAI)
- Date: January 10, 2025
... more flights ...
"""
# Main application flow
async def main():
    """Run the interactive booking session.

    ``search_agent`` is repeatedly asked for a flight from SFO to ANC on
    2025-01-10, and each proposal is shown to the user. The user can:

    * buy it: a seat is chosen with ``find_seat`` and both are passed to
      ``buy_tickets``;
    * ask for another: the previous conversation is fed back so the agent
      proposes a different flight.

    The session ends after a purchase or when the agent returns
    ``NoFlightFound``. Exceptions from agent runs are not caught here.
    """
    # A single Usage object is threaded through every agent run (including
    # seat selection), so request_limit=15 bounds the whole session.
    usage_limits = UsageLimits(request_limit=15)
    usage: Usage = Usage()
    deps = Deps(
        web_page_text=flights_web_page,
        req_origin='SFO',
        req_destination='ANC',
        req_date=datetime.date(2025, 1, 10),
    )
    # deps is never modified, so the user request can be built once.
    request = (
        f'Find me a flight from {deps.req_origin} to {deps.req_destination} '
        f'on {deps.req_date}'
    )
    message_history: list[ModelMessage] | None = None

    while True:
        result = await search_agent.run(
            request,
            deps=deps,
            usage=usage,
            message_history=message_history,
            usage_limits=usage_limits,
        )
        proposal = result.output
        if isinstance(proposal, NoFlightFound):
            print('No flight found')
            return

        print(f'Flight found: {proposal}')
        # An empty answer is an accepted choice; like 'search', it means
        # "keep looking".
        choice = Prompt.ask(
            'Do you want to buy this flight, or keep searching? (buy/*search)',
            choices=['buy', 'search', ''],
            show_choices=False,
        )
        if choice == 'buy':
            chosen_seat = await find_seat(usage, usage_limits)
            await buy_tickets(proposal, chosen_seat)
            return

        # Replay the conversation, answering the output call with a request
        # for a different flight.
        message_history = result.all_messages(
            output_tool_return_content='Please suggest another flight'
        )
async def find_seat(usage: Usage, usage_limits: UsageLimits) -> SeatPreference:
    """Prompt the user until ``seat_preference_agent`` extracts a seat.

    Args:
        usage: Running usage shared with the rest of the session.
        usage_limits: Limits applied to each ``seat_preference_agent`` run.

    Returns:
        The first ``SeatPreference`` the agent produces.
    """
    history: list[ModelMessage] | None = None
    while True:
        reply = Prompt.ask('What seat would you like?')
        result = await seat_preference_agent.run(
            reply,
            message_history=history,
            usage=usage,
            usage_limits=usage_limits,
        )
        extracted = result.output
        if isinstance(extracted, SeatPreference):
            return extracted
        print('Could not understand seat preference. Please try again.')
        # Keep the failed attempt(s) in context for the next try.
        history = result.all_messages()
async def buy_tickets(flight_details: FlightDetails, seat: SeatPreference):
    """Stand-in for the purchase step; it only prints what would be bought.

    No booking is made. The printed line shows the ``repr`` of both
    arguments.
    """
    summary = f'flight_details={flight_details!r} seat={seat!r}'
    print(f'Purchasing flight {summary}...')