Skip to content

Commit

Permalink
Merge pull request #6 from Lucs1590/llm-analysis
Browse files Browse the repository at this point in the history
LLM analysis
  • Loading branch information
Lucs1590 authored Sep 27, 2024
2 parents 51b0743 + dda676c commit e9f647d
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 24 deletions.
7 changes: 7 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
defusedxml==0.7.1
langchain_core==0.3.6
langchain_openai==0.2.1
numpy==2.1.1
pandas==2.2.3
python-dotenv==1.0.1
questionary==2.0.1
scipy==1.14.1
tcxreader==0.4.10
tqdm==4.66.5
162 changes: 145 additions & 17 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
import re
import os
import time
import logging
import webbrowser
import time
from defusedxml.minidom import parseString

from typing import Tuple

import questionary
import numpy as np
import pandas as pd

from tqdm import tqdm
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langchain_core.prompts.prompt import PromptTemplate
from defusedxml.minidom import parseString
from scipy.spatial.distance import squareform, pdist
from tcxreader.tcxreader import TCXReader


load_dotenv()
logger = logging.getLogger()

if not logger.handlers:
Expand Down Expand Up @@ -39,15 +50,22 @@ def main():
else:
file_path = ask_file_path(file_location)

if sport in ["Swim", "Other"]:
logger.info("Formatting the TCX file to be imported to TrainingPeaks")
format_to_swim(file_path)
elif sport in ["Bike", "Run"]:
logger.info("Validating the TCX file")
validate_tcx_file(file_path)
else:
logger.error("Invalid sport selected")
raise ValueError("Invalid sport selected")
if file_path:
if sport in ["Swim", "Other"]:
logger.info(
"Formatting the TCX file to be imported to TrainingPeaks"
)
format_to_swim(file_path)
elif sport in ["Bike", "Run"]:
logger.info("Validating the TCX file")
_, tcx_data = validate_tcx_file(file_path)
if ask_llm_analysis():
plan = ask_training_plan()
logger.info("Performing LLM analysis")
perform_llm_analysis(tcx_data, sport, plan)
else:
logger.error("Invalid sport selected")
raise ValueError("Invalid sport selected")

indent_xml_file(file_path)
logger.info("Process completed successfully!")
Expand Down Expand Up @@ -76,7 +94,8 @@ def ask_activity_id() -> str:

def download_tcx_file(activity_id: str, sport: str) -> None:
if sport in ["Swim", "Other"]:
url = f"https://www.strava.com/activities/{activity_id}/export_original"
url = f"https://www.strava.com/activities/{
activity_id}/export_original"
else:
url = f"https://www.strava.com/activities/{activity_id}/export_tcx"
try:
Expand Down Expand Up @@ -104,14 +123,23 @@ def get_latest_download() -> str:
return latest_file


def ask_file_path(file_location) -> str:
question = "Enter the path to the TCX file:" if file_location == "Provide path" else "Check if the TCX file was downloaded and then enter the path to the file:"
def ask_file_path(file_location: str) -> str:
if file_location == "Provide path":
question = "Enter the path to the TCX file:"
else:
question = "Check if the TCX was downloaded and validate the file:"

return questionary.path(
question,
validate=os.path.isfile
validate=validation,
only_directories=False
).ask()


def validation(path: str) -> bool:
return os.path.isfile(path)


def format_to_swim(file_path: str) -> None:
xml_str = read_xml_file(file_path)
xml_str = modify_xml_header(xml_str)
Expand All @@ -138,7 +166,7 @@ def write_xml_file(file_path: str, xml_str: str) -> None:
xml_file.write(xml_str)


def validate_tcx_file(file_path: str) -> bool:
def validate_tcx_file(file_path: str) -> Tuple[bool, TCXReader]:
xml_str = read_xml_file(file_path)
if not xml_str:
logger.error("The TCX file is empty.")
Expand All @@ -151,12 +179,112 @@ def validate_tcx_file(file_path: str) -> bool:
"The TCX file is valid. You covered a significant distance in this activity, with %d meters.",
data.distance
)
return True
return True, data
except Exception as err:
logger.error("Invalid TCX file.")
raise ValueError(f"Error reading the TCX file: {err}") from err


def ask_llm_analysis() -> str:
return questionary.confirm(
"Do you want to perform AI analysis?",
default=False
).ask()


def ask_training_plan() -> str:
return questionary.text(
"Was there anything planned for this training?"
).ask()


def perform_llm_analysis(data: TCXReader, sport: str, plan: str) -> str:
dataframe = preprocess_trackpoints_data(data)

prompt = """SYSTEM: You are an AI Assistant that helps athletes to improve their performance.
Based on the following csv data that is related to a {sport} training session, carry out an analysis highlighting positive points, where the athlete did well and where he did poorly and what he can do to improve in the next {sport}.
<csv_data>
{data}
</csv_data>
"""
prompt += "plan: {plan}" if plan else ""
prompt = PromptTemplate.from_template(prompt)
prompt = prompt.format(
sport=sport,
data=dataframe.to_csv(index=False),
plan=plan
)

openai_llm = ChatOpenAI(
openai_api_key=os.getenv("OPENAI_API_KEY"),
model_name="gpt-4o",
max_tokens=1500,
temperature=0.6,
max_retries=5
)
response = openai_llm.invoke(prompt)
logger.info("AI analysis completed successfully.")
logger.info("\nAI response:\n %s \n", response.content)
return response.content


def preprocess_trackpoints_data(data):
dataframe = pd.DataFrame(data.trackpoints_to_dict())
dataframe.rename(
columns={
"distance": "Distance_Km",
"time": "Time",
"Speed": "Speed_Kmh"
}, inplace=True
)
dataframe["Time"] = dataframe["Time"].apply(lambda x: x.value / 10**9)
dataframe["Distance_Km"] = round(dataframe["Distance_Km"] / 1000, 2)
dataframe["Speed_Kmh"] = dataframe["Speed_Kmh"] * 3.6
dataframe["Pace"] = round(
dataframe["Speed_Kmh"].apply(lambda x: 60 / x if x > 0 else 0),
2
)
if dataframe["cadence"].isnull().sum() >= len(dataframe) / 2:
dataframe.drop(columns=["cadence"], inplace=True)

dataframe = dataframe.drop_duplicates()
dataframe = dataframe.reset_index(drop=True)
dataframe = dataframe.dropna()

if dataframe.shape[0] > 4000:
dataframe = run_euclidean_dist_deletion(dataframe, 0.55)
elif dataframe.shape[0] > 1000:
dataframe = run_euclidean_dist_deletion(dataframe, 0.35)
else:
dataframe = run_euclidean_dist_deletion(dataframe, 0.10)

dataframe["Time"] = pd.to_datetime(
dataframe["Time"],
unit='s'
).dt.strftime('%H:%M:%S')

return dataframe


def run_euclidean_dist_deletion(dataframe: pd.DataFrame, percentage: float) -> pd.DataFrame:
dists = pdist(dataframe, metric='euclidean')
dists = squareform(dists)
np.fill_diagonal(dists, np.inf)

total_rows = int(percentage * len(dataframe))
with tqdm(total=total_rows, desc="Removing similar points") as pbar:
for _ in range(total_rows):
min_idx = np.argmin(dists)
row, col = np.unravel_index(min_idx, dists.shape)
dists[row, :] = np.inf
dists[:, col] = np.inf
dataframe = dataframe.drop(row)
pbar.update(1)

dataframe = dataframe.reset_index(drop=True)
return dataframe


def indent_xml_file(file_path: str) -> None:
try:
with open(file_path, "r", encoding='utf-8') as xml_file:
Expand Down
92 changes: 85 additions & 7 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import os
#import sys
# import sys
import unittest

from unittest.mock import patch
from pandas import DataFrame
from tcxreader.tcxreader import TCXReader

# sys.path.append(os.path.abspath(''))

Expand All @@ -18,13 +21,21 @@
ask_file_location,
ask_activity_id,
ask_file_path,
get_latest_download
get_latest_download,
validation,
ask_training_plan,
ask_llm_analysis,
perform_llm_analysis,
preprocess_trackpoints_data,
run_euclidean_dist_deletion
)


class TestMain(unittest.TestCase):
def setUp(self) -> None:
pass
tcx_reader = TCXReader()
self.running_example_data = tcx_reader.read("assets/run.tcx")
self.biking_example_data = tcx_reader.read("assets/bike.tcx")

@patch('src.main.webbrowser.open')
def test_download_tcx_file(self, mock_open):
Expand Down Expand Up @@ -96,6 +107,7 @@ def test_validate_tcx_file(self):
file_path = "assets/bike.tcx"
result = validate_tcx_file(file_path)
self.assertTrue(result)
self.assertEqual(len(result), 2)

def test_validate_tcx_file_error(self):
file_path = "assets/swim.tcx"
Expand Down Expand Up @@ -185,6 +197,9 @@ def test_main_invalid_sport(self, mock_indent, mock_validate, mock_format, mock_
mock_validate.assert_not_called()
mock_indent.assert_not_called()

@patch('src.main.ask_training_plan')
@patch('src.main.perform_llm_analysis')
@patch('src.main.ask_llm_analysis')
@patch('src.main.ask_sport')
@patch('src.main.ask_file_location')
@patch('src.main.ask_activity_id')
Expand All @@ -194,10 +209,15 @@ def test_main_invalid_sport(self, mock_indent, mock_validate, mock_format, mock_
@patch('src.main.validate_tcx_file')
@patch('src.main.indent_xml_file')
def test_main_bike_sport(self, mock_indent, mock_validate, mock_format, mock_ask_path, mock_download,
mock_ask_id, mock_ask_location, mock_ask_sport):
mock_ask_id, mock_ask_location, mock_ask_sport, mock_llm_analysis, mock_perform_llm,
mock_training_plan):
mock_ask_sport.return_value = "Bike"
mock_ask_location.return_value = "Local"
mock_ask_path.return_value = "assets/bike.tcx"
mock_llm_analysis.return_value = True
mock_validate.return_value = True, "TCX Data"
mock_perform_llm.return_value = "Training Plan"
mock_training_plan.return_value = ""

main()

Expand All @@ -207,6 +227,8 @@ def test_main_bike_sport(self, mock_indent, mock_validate, mock_format, mock_ask
mock_ask_path.assert_called_once()
mock_download.assert_not_called()
mock_format.assert_not_called()
mock_llm_analysis.assert_called_once()
mock_perform_llm.assert_called_once()
mock_validate.assert_called_once_with("assets/bike.tcx")
mock_indent.assert_called_once_with("assets/bike.tcx")

Expand Down Expand Up @@ -245,7 +267,8 @@ def test_ask_file_path(self):
result = ask_file_path("Provide path")
mock_path.assert_called_once_with(
"Enter the path to the TCX file:",
validate=os.path.isfile
validate=validation,
only_directories=False
)
self.assertEqual(result, "assets/test.tcx")

Expand All @@ -254,8 +277,9 @@ def test_ask_file_path(self):
mock_path.return_value.ask.return_value = "assets/downloaded.tcx"
result = ask_file_path("Download")
mock_path.assert_called_once_with(
"Check if the TCX file was downloaded and then enter the path to the file:",
validate=os.path.isfile
"Check if the TCX was downloaded and validate the file:",
validate=validation,
only_directories=False
)
self.assertEqual(result, "assets/downloaded.tcx")

Expand All @@ -275,6 +299,60 @@ def test_get_latest_downloads_with_ask(self, mock_ask_path):

self.assertEqual(result, "assets/bike.tcx")

def test_validation(self):
file_path = "assets/bike.tcx"
result = validation(file_path)

self.assertTrue(result)

def test_ask_training_plan(self):
with patch('src.main.questionary.text') as mock_text:
mock_text.return_value.ask.return_value = ""
result = ask_training_plan()
mock_text.assert_called_once_with(
"Was there anything planned for this training?"
)
self.assertEqual(result, "")

def test_ask_llm_analysis(self):
with patch('src.main.questionary.confirm') as mock_confirm:
mock_confirm.return_value.ask.return_value = True
result = ask_llm_analysis()
mock_confirm.assert_called_once_with(
"Do you want to perform AI analysis?",
default=False
)
self.assertTrue(result)

@patch('src.main.ChatOpenAI')
def test_perform_llm_analysis(self, mock_chat):
mock_invoke = mock_chat.return_value.invoke.return_value
mock_invoke.content = "Training Plan"
tcx_data = self.running_example_data
sport = "Run"
plan = "Training Plan"

result = perform_llm_analysis(tcx_data, sport, plan)
self.assertEqual(result, "Training Plan")

def test_preprocess_running_trackpoints_data(self):
tcx_data = self.running_example_data
result = preprocess_trackpoints_data(tcx_data)
self.assertEqual(len(result), 1646)

def test_preprocess_biking_trackpoints_data(self):
tcx_data = self.biking_example_data
result = preprocess_trackpoints_data(tcx_data)
self.assertEqual(len(result), 2028)

def test_run_euclidean_distance(self):
dataframe = DataFrame({
'latitude': [1, 2, 3, 3.5, 4, 5, 6, 6.5, 7, 8, 9],
'longitude': [1, 2, 3, 3.5, 4, 5, 6, 6.5, 7, 8, 9]
})
result = run_euclidean_dist_deletion(dataframe, 0.1)
self.assertEqual(len(result), 10)


if __name__ == '__main__':
unittest.main()

0 comments on commit e9f647d

Please sign in to comment.