Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

documentation for scrape and database #294

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions src/sherpa_ai/database/user_usage_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@


class UsageTracker(Base):

"""SQLAlchemy base model for tracking LLM token usage on per-user basis"""

__tablename__ = "usage_tracker"
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(String)
Expand All @@ -26,18 +29,32 @@ class UsageTracker(Base):


class Whitelist(Base):

"""Represents a trusted list of users whose usage is not tracked"""

__tablename__ = "whitelist"

id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(String)


class UserUsageTracker:

"""Enables an app to track LLM token usage on per-user basis"""

def __init__(
self,
db_name=cfg.DB_NAME,
verbose_logger: BaseVerboseLogger = DummyVerboseLogger(),
):
"""
Initialize the UserUsageTracker.

Args:
db_name (str): Name of the database.
max_daily_token (int): Maximum daily token limit.
"""

self.engine = create_engine(db_name)
Session = sessionmaker(bind=self.engine)
self.session = Session()
Expand All @@ -48,19 +65,45 @@ def __init__(
self.usage_percentage_allowed = 75

def download_from_s3(self, bucket_name, s3_file_key, local_file_path):
"""
Download a file from Amazon S3.

Args:
bucket_name (str): Name of the S3 bucket.
s3_file_key (str): Key of the file in the S3 bucket.
local_file_path (str): Local path where the file will be downloaded.
"""
file_path = Path("./token_counter.db")
if not file_path.exists():
s3 = boto3.client("s3")
s3.download_file(bucket_name, s3_file_key, local_file_path)

def upload_to_s3(self, local_file_path, bucket_name, s3_file_key):
"""
Upload a file to Amazon S3.

Args:
local_file_path (str): Local path of the file to be uploaded.
bucket_name (str): Name of the S3 bucket.
s3_file_key (str): Key of the file in the S3 bucket.
"""

s3 = boto3.client("s3")
s3.upload_file(local_file_path, bucket_name, s3_file_key)

def create_table(self):
"""Create the necessary tables in the database."""

Base.metadata.create_all(self.engine)

def add_to_whitelist(self, user_id):
"""
Add a user to the whitelist table.

Args:
user_id (str): ID of the user to be added to the whitelist.
"""

user = Whitelist(user_id=user_id)
self.session.add(user)
self.session.commit()
Expand All @@ -70,19 +113,50 @@ def add_to_whitelist(self, user_id):
)

def get_all_whitelisted_ids(self):
"""Get a list of all user IDs in the whitelist."""

whitelisted_ids = [user.user_id for user in self.session.query(Whitelist).all()]
return whitelisted_ids

def get_whitelist_by_user_id(self, user_id):
"""
Get whitelist information for a specific user.

Args:
user_id (str): ID of the user.

Returns:
list: List of dictionaries containing whitelist information.
"""

data = self.session.query(Whitelist).filter_by(user_id=user_id).all()
return [{"id": item.id, "user_id": item.user_id} for item in data]

def is_in_whitelist(self, user_id):
"""
Check if a user is in the whitelist.

Args:
user_id (str): ID of the user.

Returns:
bool: True if the user is in the whitelist, False otherwise.
"""

return bool(self.get_whitelist_by_user_id(user_id))

def add_and_check_data(
self, combined_id, token, reset_timestamp=False, reminded_timestamp=False
):
"""
Add usage data for a user and check for reminders.

Args:
combined_id (str): Combined ID of the user.
token (int): Number of tokens used.
reset_timestamp (bool): Whether to reset the timestamp.
reminded_timestamp (bool): Set reminded_timestamp.
"""
self.add_data(
combined_id=combined_id,
token=token,
Expand All @@ -94,6 +168,16 @@ def add_and_check_data(
def add_data(
self, combined_id, token, reset_timestamp=False, reminded_timestamp=False
):
"""
Add usage data for a user.

Args:
combined_id (str): Combined ID of the user.
token (int): Number of tokens used.
reset_timestamp (bool): Whether to reset the timestamp.
reminded_timestamp (bool): Set reminded_timestamp.

"""
data = UsageTracker(
user_id=combined_id,
token=token,
Expand All @@ -105,12 +189,28 @@ def add_data(
self.session.commit()

def percentage_used(self, combined_id):
"""
Calculate the percentage of daily token quota used by a user.

Args:
combined_id (str): Combined ID of the user.

Returns:
float: Percentage of daily tokens used since last reset.
"""

total_token_since_last_reset = self.get_sum_of_tokens_since_last_reset(
user_id=combined_id
)
return (total_token_since_last_reset * 100) / self.max_daily_token

def remind_user_of_daily_token_limit(self, combined_id):
"""
Remind the user when their token usage exceeds a certain percentage.

Args:
combined_id (str): Combined ID of the user.
"""
split_parts = combined_id.split("_")
user_id = ""
if len(split_parts) > 0:
Expand All @@ -131,6 +231,16 @@ def remind_user_of_daily_token_limit(self, combined_id):
)

def get_data_since_last_reset(self, user_id):
"""
Get usage since the user's usage data was last reset.

Args:
user_id (str): ID of the user.

Returns:
list: List of dictionaries containing usage data.
"""

last_reset_info = self.get_last_reset_info(user_id)

if last_reset_info is None or last_reset_info["id"] is None:
Expand Down Expand Up @@ -176,6 +286,16 @@ def check_if_reminded(self, combined_id):
return is_reminded_true

def get_sum_of_tokens_since_last_reset(self, user_id):
"""
Calculate the sum of tokens used since the last reset for a user.

Args:
user_id (str): ID of the user.

Returns:
int: Sum of tokens used since the last reset.
"""

data_since_last_reset = self.get_data_since_last_reset(user_id)

if len(data_since_last_reset) == 1 and "user_id" in data_since_last_reset[0]:
Expand All @@ -185,11 +305,28 @@ def get_sum_of_tokens_since_last_reset(self, user_id):
return token_sum

def reset_usage(self, combined_id, token_amount):
"""
Reset the usage data for a user to zero.

Args:
combined_id (str): Combined ID of the user.
token_amount (int): Number of tokens to reset.
"""
self.add_and_check_data(
combined_id=combined_id, token=token_amount, reset_timestamp=True
)

def get_last_reset_info(self, combined_id):
"""
Get information about the most recent usage data reset for a user.

Args:
combined_id (str): Combined ID of the user.

Returns:
dict or None: Dictionary containing last reset information or None if not found.
"""

data = (
self.session.query(UsageTracker.id, UsageTracker.timestamp)
.filter(
Expand All @@ -205,6 +342,16 @@ def get_last_reset_info(self, combined_id):
return None

def seconds_to_hms(self, seconds):
"""
Convert seconds to hours, minutes, and seconds.

Args:
seconds (int): Number of seconds.

Returns:
str: Formatted string in the format "hours : minutes : seconds".
"""

remaining_seconds = int(float(cfg.LIMIT_TIME_SIZE_IN_HOURS) * 3600 - seconds)
hours = remaining_seconds // 3600
minutes = (remaining_seconds % 3600) // 60
Expand All @@ -213,6 +360,20 @@ def seconds_to_hms(self, seconds):
return f"{hours} hours : {minutes} min : {seconds} sec"

def check_usage(self, user_id, combined_id, token_amount):
"""
Check user usage and determine whether user is allowed to consume more tokens.

Args:
user_id (str): ID of the user.
combined_id (str): Combined ID of the user.
token_amount (int): Number of tokens to check.

Returns:
dict: Result containing information about tokens remaining,
whether more tokens can be consumed (can_excute),
any associated message, and the time left.
"""

user_is_whitelisted = self.is_in_whitelist(user_id)

if user_is_whitelisted:
Expand Down Expand Up @@ -278,4 +439,6 @@ def get_all_data(self):
]

def close_connection(self):
"""Close the database connection."""

self.session.close()
28 changes: 28 additions & 0 deletions src/sherpa_ai/scrape/extract_github_readme.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,31 @@


def get_owner_and_repo(url):
"""
Extracts the owner and repository name from a GitHub repository URL.

Parameters:
- url (str): The GitHub repository URL.

Returns:
Tuple[str, str]: A tuple containing the owner and repository name.
"""

url_content_list = url.split("/")
return url_content_list[3], url_content_list[4].split("#")[0]


def extract_github_readme(repo_url):
"""
Extracts the content of the README file from a GitHub repository.

Parameters:
- repo_url (str): The GitHub repository URL.

Returns:
str or None: The content of the README file, or None if the file is not found.
"""

pattern = r"(?:https?://)?(?:www\.)?github\.com/.*"
match = re.match(pattern, repo_url)
if match:
Expand Down Expand Up @@ -64,6 +84,14 @@ def extract_github_readme(repo_url):


def save_to_pine_cone(content, metadatas):
"""
Saves the content and metadata to Pinecone vector store.

Parameters:
- content (str): The content to be saved.
- metadatas (list): List of metadata associated with the content.
"""

pinecone.init(api_key=cfg.PINECONE_API_KEY, environment=cfg.PINECONE_ENV)
index = pinecone.Index("langchain")
embeddings = OpenAIEmbeddings(openai_api_key=cfg.OPENAI_API_KEY)
Expand Down
Loading