Skip to content

Commit

Permalink
refactor: replaced mysql-connector-python methods with ORM: #74
Browse files Browse the repository at this point in the history
  • Loading branch information
hwakabh committed Jan 23, 2024
1 parent 6a9ca9a commit 3a92f1f
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 91 deletions.
150 changes: 69 additions & 81 deletions app/api/v1/cruds.py
Original file line number Diff line number Diff line change
@@ -1,139 +1,127 @@
import json
import random

import mysql.connector as mydb
import numpy as np
from sqlalchemy.orm import Session

# from app.api.v1 import models
from app.api.v1 import models
from app.api.v1 import schemas
from app.api.v1.helpers import dist_on_sphere
from app.config import app_settings


# TODO: should be used with ORM
conn = mydb.connect(
host=app_settings.MYSQL_HOST,
user=app_settings.MYSQL_USER,
password=app_settings.MYSQL_PASSWORD,
database=app_settings.MYSQL_DATABASE,
charset="utf8"
)


# def get_airports_from_db(db: Session) -> schemas.Airport:
# return db.query(models.Airport).limit(5).all()


# Bind cruds functions for return results to router
def get_destination(req: schemas.SearchRequestBody):
def get_destination(db: Session, req: schemas.SearchRequestBody):
#--- get ajax POST data
print(f'User conditions: {req}')

#--- search and get near airport from MySQL (airport table)
near_airport_IATA = get_near_airport(
current_lat=req.current_lat,
current_lng=req.current_lng
db=db,
lat=req.current_lat,
lng=req.current_lng
)
print("near_airport_IATA: " + near_airport_IATA)

#--- search and get reachable location (airport and country) from skyscanner api
#--- exclude if time and travel expenses exceed the user input parameter
#--- select a country at random
destination = get_destination_from_skyscanner_by_random(near_airport_IATA)
destination = get_destination_from_skyscanner_by_random(
db=db,
iata=near_airport_IATA
)
print('Destination: ')
print(destination)

return schemas.SearchResultResponseBody(**destination)


# --- search and get near airport from MySQL (airport table)
def get_near_airport(current_lat,current_lng):

conn.ping(reconnect=True)
cur = conn.cursor()

current = float(current_lat), float(current_lng)
def get_near_airport(db: Session, lat: float, lng: float) -> str:
target = []
dist_result = []
search_key = []
count = 0

cur.execute('select id,IATA,Name,Country,City,Latitude,Longitude from airport where not IATA="NULL"')

for sql_result in cur.fetchall():
target = sql_result[5], sql_result[6]
dist = dist_on_sphere(current, target)
dist_result.append([count,sql_result[0],sql_result[1],sql_result[2],sql_result[3],sql_result[4],dist])
airports = db.query(
models.Airport.id,
models.Airport.IATA,
models.Airport.name,
models.Airport.country,
models.Airport.city,
models.Airport.latitude,
models.Airport.longitude
).filter(
models.Airport.IATA != "NULL"
).all()

for airport in airports:
target = airport[5], airport[6]
dist = dist_on_sphere(
pos0=(lat, lng),
pos1=target
)
dist_result.append([count,airport[0],airport[1],airport[2],airport[3],airport[4],dist])
search_key.append(dist)
count = count + 1

cur.close()
conn.close()

#--- return near airport IATA
return dist_result[np.argmin(search_key)][2]


#--- search and get reachable location (airport and country) from skyscanner api
#--- exclude if time and travel expenses exceed the user input parameter
#--- select a country at random
def get_destination_from_skyscanner_by_random(near_airport_IATA) -> dict:
def get_destination_from_skyscanner_by_random(db: Session, iata: str) -> dict:
# --- search and get reachable location (airport and country) from skyscanner api
# --- exclude if time and travel expenses exceed the user input parameter

##########################################################################################
##################################### Update required #####################################
###########################################################################################

#reachable_airport_IATA = ["TXL","YTD","CQS","NYR","QFG","NZE","IWK"]

conn.ping(reconnect=True)
cur = conn.cursor()

cur.execute('select IATA from airport where not IATA="NULL"')
reachable_airport_IATA = []
for sql_result in cur.fetchall():
reachable_airport_IATA.append(sql_result[0])

cur.close()
conn.close()

##########################################################################################
##########################################################################################
##########################################################################################
airport_codes = db.query(models.Airport.IATA).filter(models.Airport.IATA != "NULL").all()

reachable_airport_IATA = [airport_code[0] for airport_code in airport_codes]
#--- select a country at random
random_airport_IATA = random.choice(reachable_airport_IATA)

#--- get lat/lng of near and selected airport from MySQL (airport table)
conn.ping(reconnect=True)
cur = conn.cursor()
transit_airports = db.query(
models.Airport.country,
models.Airport.city,
models.Airport.IATA,
models.Airport.name,
models.Airport.latitude,
models.Airport.longitude,
).filter(
models.Airport.IATA == iata
).all()

cur.execute('select Country,City,IATA,Name,Latitude,Longitude from airport where IATA="' + near_airport_IATA + '"')
transit = []
for sql_result in cur.fetchall():
transit.append([sql_result[0],sql_result[1],sql_result[2],sql_result[3],sql_result[4],sql_result[5]])
for airport in transit_airports:
transit.append([airport[0],airport[1],airport[2],airport[3],airport[4],airport[5]])

destination_airports = db.query(
models.Airport.country,
models.Airport.city,
models.Airport.IATA,
models.Airport.name,
models.Airport.latitude,
models.Airport.longitude,
).filter(
models.Airport.IATA == random_airport_IATA
).all()

cur.execute('select Country,City,IATA,Name,Latitude,Longitude from airport where IATA="' + random_airport_IATA + '"')
destination = []
for sql_result in cur.fetchall():
destination.append([sql_result[0],sql_result[1],sql_result[2],sql_result[3],sql_result[4],sql_result[5]])

cur.close()
conn.close()
for airport in destination_airports:
destination.append([airport[0],airport[1],airport[2],airport[3],airport[4],airport[5]])

return {
"tran_country":transit[0][0],
"tran_city":transit[0][1],
"tran_iata":transit[0][2],
"tran_airport":transit[0][3],
"tran_lat":transit[0][4],
"tran_lng":transit[0][5],
"dest_country":destination[0][0],
"dest_city":destination[0][1],
"dest_iata":destination[0][2],
"dest_airport":destination[0][3],
"dest_lat":destination[0][4],
"dest_lng":destination[0][5]
"tran_country": transit[0][0],
"tran_city": transit[0][1],
"tran_iata": transit[0][2],
"tran_airport": transit[0][3],
"tran_lat": transit[0][4],
"tran_lng": transit[0][5],
"dest_country": destination[0][0],
"dest_city": destination[0][1],
"dest_iata": destination[0][2],
"dest_airport": destination[0][3],
"dest_lat": destination[0][4],
"dest_lng": destination[0][5]
}
15 changes: 14 additions & 1 deletion app/api/v1/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sqlalchemy import Column, String, Integer
from sqlalchemy import Column, String, Integer, Float

from app.database import Base

Expand All @@ -7,3 +7,16 @@ class Airport(Base):
__tablename__ = 'airport'

id: int = Column(Integer, primary_key=True)
name: str = Column(String(74))
city: str = Column(String(35))
country: str = Column(String(34))
IATA: str = Column(String(3))
ICAO: str = Column(String(4))
latitude: float = Column(Float)
longitude: float = Column(Float)
altitude: int = Column(Integer)
tz_offset: float = Column(Float)
DST: str = Column(String(1))
tz_dbtime: str = Column(String(32))
types: str = Column(String(13))
datasource: str = Column(String(13))
20 changes: 11 additions & 9 deletions app/api/v1/routers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from fastapi import APIRouter, Depends
from fastapi.requests import Request
from fastapi.responses import Response, JSONResponse
# from sqlalchemy.orm import Session
from sqlalchemy.orm import Session

from app.api.v1 import services
from app.api.v1 import cruds
from app.api.v1 import schemas
from app.database import get_db
from app.api.v1 import models
from app.database import get_db, engine

router = APIRouter()
# Create table if not exists
models.Base.metadata.create_all(bind=engine)


@router.get('/')
Expand All @@ -19,22 +22,21 @@ def index(req: Request) -> schemas.RootResponse:
})


# @router.get('/airports')
# def get_airports(db: Session = Depends(get_db)) -> list[schemas.Airport]:
# return cruds.get_airports_from_db(db=db)


@router.get('/fetch')
def fetch() -> Response:
return services.load_google_map()


@router.post('/shuffle')
def get_random_country(payload: schemas.SearchRequestBody) -> schemas.SearchResultResponseBody:
def get_random_country(
payload: schemas.SearchRequestBody,
db: Session = Depends(get_db)
) -> schemas.SearchResultResponseBody:

country = services.get_random_country()
print(f'Randomly selected country: {country}')

return cruds.get_destination(req=payload)
return cruds.get_destination(db=db, req=payload)


@router.post('/translate')
Expand Down

0 comments on commit 3a92f1f

Please sign in to comment.