Skip to content
This repository has been archived by the owner on Aug 10, 2023. It is now read-only.

Commit

Permalink
added some type hints for function args (#894)
Browse files Browse the repository at this point in the history
  • Loading branch information
linhandev authored Feb 25, 2023
1 parent c19c40c commit 948aef9
Showing 1 changed file with 27 additions and 38 deletions.
65 changes: 27 additions & 38 deletions src/revChatGPT/V1.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""
Standard ChatGPT
"""
import asyncio
from __future__ import annotations

import json
import logging
import time
Expand All @@ -23,7 +24,7 @@
log = logging.getLogger(__name__)


def logger(is_timed):
def logger(is_timed: bool):
"""
Logger decorator
"""
Expand Down Expand Up @@ -82,16 +83,16 @@ class Chatbot:
@logger(is_timed=True)
def __init__(
self,
config,
conversation_id=None,
parent_id=None,
config: dict[str, str],
conversation_id: str | None = None,
parent_id: str | None = None,
session_client=None,
) -> None:
self.config = config
self.session = session_client() if session_client else requests.Session()

if "proxy" in config:
if isinstance(config["proxy"], str) is False:
if not isinstance(config["proxy"], str):
raise Exception("Proxy must be a string!")
proxies = {
"http": config["proxy"],
Expand All @@ -118,15 +119,15 @@ def __check_credentials(self):
elif "session_token" in self.config:
pass
else:
raise Exception("No login details provided!")
raise Exception("Insufficient login details provided!")
if "access_token" not in self.config:
try:
self.__login()
except AuthError as error:
raise error

@logger(is_timed=False)
def __refresh_headers(self, access_token):
def __refresh_headers(self, access_token: str):
self.session.headers.clear()
self.session.headers.update(
{
Expand All @@ -145,8 +146,8 @@ def __login(self):
if (
"email" not in self.config or "password" not in self.config
) and "session_token" not in self.config:
log.error("No login details provided!")
raise Exception("No login details provided!")
log.error("Insufficient login details provided!")
raise Exception("Insufficient login details provided!")
auth = Authenticator(
email_address=self.config.get("email"),
password=self.config.get("password"),
Expand All @@ -171,25 +172,21 @@ def __login(self):
@logger(is_timed=True)
def ask(
self,
prompt,
conversation_id=None,
parent_id=None,
timeout=360,
prompt: str,
conversation_id: str | None = None,
parent_id: str | None = None,
timeout: float = 360,
):
"""
Ask a question to the chatbot
:param prompt: String
:param conversation_id: UUID
:param parent_id: UUID
:param gen_title: Boolean
:param timeout: Float. Unit is second
"""
if parent_id is not None and conversation_id is None:
log.error("conversation_id must be set once parent_id is set")
error = Error()
error.source = "User"
error.message = "conversation_id must be set once parent_id is set"
error.code = -1
raise error
raise Error("User", "conversation_id must be set once parent_id is set", -1)

if conversation_id is not None and conversation_id != self.conversation_id:
log.debug("Updating to new conversation by setting parent_id to None")
Expand All @@ -203,7 +200,6 @@ def ask(

if conversation_id is not None and parent_id is None:
if conversation_id not in self.conversation_mapping:

log.debug(
"Conversation ID %s not found in conversation mapping, mapping conversations",
conversation_id,
Expand Down Expand Up @@ -304,17 +300,12 @@ def __check_fields(self, data: dict) -> bool:

@logger(is_timed=False)
def __check_response(self, response):

if response.status_code != 200:
print(response.text)
error = Error()
error.source = "OpenAI"
error.code = response.status_code
error.message = response.text
raise error
raise Error("OpenAI", response.status_code, response.text)

@logger(is_timed=True)
def get_conversations(self, offset=0, limit=20):
def get_conversations(self, offset: int = 0, limit: int = 20):
"""
Get conversations
:param offset: Integer
Expand All @@ -327,7 +318,7 @@ def get_conversations(self, offset=0, limit=20):
return data["items"]

@logger(is_timed=True)
def get_msg_history(self, convo_id, encoding=None):
def get_msg_history(self, convo_id: str, encoding: str | None = None):
"""
Get message history
:param id: UUID of conversation
Expand All @@ -342,21 +333,20 @@ def get_msg_history(self, convo_id, encoding=None):
return data

@logger(is_timed=True)
def gen_title(self, convo_id, message_id):
def gen_title(self, convo_id: str, message_id: str):
"""
Generate title for conversation
"""
url = BASE_URL + f"api/conversation/gen_title/{convo_id}"
response = self.session.post(
url,
BASE_URL + f"api/conversation/gen_title/{convo_id}",
data=json.dumps(
{"message_id": message_id, "model": "text-davinci-002-render"},
),
)
self.__check_response(response)

@logger(is_timed=True)
def change_title(self, convo_id, title):
def change_title(self, convo_id: str, title: str):
"""
Change title of conversation
:param id: UUID of conversation
Expand All @@ -367,7 +357,7 @@ def change_title(self, convo_id, title):
self.__check_response(response)

@logger(is_timed=True)
def delete_conversation(self, convo_id):
def delete_conversation(self, convo_id: str):
"""
Delete conversation
:param id: UUID of conversation
Expand Down Expand Up @@ -402,11 +392,11 @@ def reset_chat(self) -> None:
self.conversation_id = None
self.parent_id = str(uuid.uuid4())

@logger
def rollback_conversation(self, num=1) -> None:
@logger(is_timed=False)
def rollback_conversation(self, num: int = 1) -> None:
"""
Rollback the conversation.
:param num: The number of messages to rollback
:param num: Integer. The number of messages to rollback
:return: None
"""
for _ in range(num):
Expand Down Expand Up @@ -680,7 +670,6 @@ def handle_commands(command: str) -> bool:
elif command == "!config":
print(json.dumps(chatbot.config, indent=4))
elif command.startswith("!rollback"):

try:
rollback = int(command.split(" ")[1])
except IndexError:
Expand Down

0 comments on commit 948aef9

Please sign in to comment.