Skip to content

Commit

Permalink
Rearrangements.
Browse files Browse the repository at this point in the history
  • Loading branch information
sdxsd committed Dec 8, 2023
1 parent e8b6598 commit 8f20c6a
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions subtitlecorrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ def __init__(self, chosen_prompt=1):
self.token_usage_output = 0
self.client = AsyncOpenAI()

def handle_exception(self, exception):
print("EXCEPTION: Type: ", exception.type, " | Message: ", exception.message)
map(lambda query: query.cancel(), self.queries)

def validate_finish_reason(self, finish_reason):
match finish_reason:
case "stop":
Expand All @@ -96,25 +100,6 @@ def validate_finish_reason(self, finish_reason):
raise QueryException(finish_reason, "Query failed due to exceeding the token limit.")
case "content_filter":
raise QueryException(finish_reason, "Query failed due to violation of content policy")

# Queries ChatGPT with the stripped SRT data.
async def query_chatgpt(self, query_str, query_number):
start = time.time()
response = await self.client.chat.completions.create(
model=self.model,
messages=[
{'role': 'system', 'content': self.prompt_list[int(self.chosen_prompt)]},
{'role': 'user', 'content': query_str},
]
)
print("Query number: ", query_number, " | ", "Response received in: ", round((time.time() - start), 2), " seconds")
self.token_usage_input += response.usage.prompt_tokens
self.token_usage_output += response.usage.completion_tokens
self.validate_finish_reason(response.choices[0].finish_reason)
answer = response.choices[0].message.content
if (answer[-1] != os.linesep):
answer += os.linesep
return answer

# Counts the number of tokens in a given string.
def num_tokens(self, raw_text):
Expand All @@ -135,6 +120,25 @@ def calculate_cost(self):
def report_status(self, token_count):
print("Sending query with token count: ", (token_count + self.prompt_token_count), " | Query count: ", self.query_counter, "/", self.total_queries)

# Queries ChatGPT with the stripped SRT data.
async def query_chatgpt(self, query_str, query_number):
start = time.time()
response = await self.client.chat.completions.create(
model=self.model,
messages=[
{'role': 'system', 'content': self.prompt_list[int(self.chosen_prompt)]},
{'role': 'user', 'content': query_str},
]
)
print("Query number: ", query_number, " | ", "Response received in: ", round((time.time() - start), 2), " seconds")
self.token_usage_input += response.usage.prompt_tokens
self.token_usage_output += response.usage.completion_tokens
self.validate_finish_reason(response.choices[0].finish_reason)
answer = response.choices[0].message.content
if (answer[-1] != os.linesep):
answer += os.linesep
return answer

# Replaces the "content" variable of the original subtitle block list
# using the sum of the responses from GPT.
def replace_sub_content(self, rawlines, slist):
Expand Down Expand Up @@ -164,10 +168,6 @@ async def send_and_receive(self, query_str, token_count):
print("Inconsistent output, resending query: ", token_count, " | Query count: ", self.query_counter)
answer = await self.query_chatgpt(query_str, self.query_counter)
return answer

def handle_exception(self, exception):
print("EXCEPTION: Type: ", exception.type, " | Message: ", exception.message)
map(lambda query: query.cancel(), self.queries)

async def query_loop(self, subtitle_file):
slist = list(srt.parse(open(subtitle_file, "r", encoding=encoding)))
Expand Down

0 comments on commit 8f20c6a

Please sign in to comment.