diff --git a/gemini/src/model/parser/response_parser.py b/gemini/src/model/parser/response_parser.py index 8010887..e37f90c 100644 --- a/gemini/src/model/parser/response_parser.py +++ b/gemini/src/model/parser/response_parser.py @@ -1,6 +1,5 @@ import json from typing import Dict -from gemini.src.misc.utils import extract_code from gemini.src.model.parser.base import BaesParser @@ -98,16 +97,12 @@ def __extract_strategy_1(self, response_text: str) -> Dict: body = json.loads(json.loads(response_text.split("\n")[3])[0][2]) if not body[4]: body = json.loads(json.loads(response_text.split("\n")[3])[4][2]) - else: - raise ValueError("Invalid response data received.") return body def __extract_strategy_2(self, response_text: str) -> Dict: body = json.loads(json.loads(response_text.split("\n")[2])[0][2]) if not body[4]: body = json.loads(json.loads(response_text.split("\n")[2])[4][2]) - else: - raise ValueError("Invalid response data received.") return body def __extract_strategy_3(self, response_text: str) -> Dict: @@ -209,7 +204,7 @@ def _parse_code(self, text: str) -> Dict: """ if not text: return {} - extracted_code = extract_code(text) + extracted_code = self.extract_code(text) code_dict = {} if isinstance(extracted_code, str) and extracted_code != text: @@ -221,3 +216,40 @@ def _parse_code(self, text: str) -> Dict: return {} return code_dict + + @staticmethod + def extract_code(text: str) -> str: + """ + Extracts code snippets from the given text. + If only one snippet is found, returns it directly instead of a list. + If no snippets are found, returns the original text. + + Args: + text (str): The text containing mixed code snippets. + + Returns: + str or list of str: A single code snippet string if only one is found, otherwise a list of all extracted code snippets. Returns the original text if no snippets are found. + """ + + snippets = [] + start_pattern = "```" + end_pattern = "```" + start_idx = text.find(start_pattern) + + while start_idx != -1: + end_idx = text.find(end_pattern, start_idx + len(start_pattern)) + if end_idx != -1: + snippet = text[start_idx : end_idx + len(end_pattern)].strip() + snippets.append(snippet) + start_idx = text.find(start_pattern, end_idx + len(end_pattern)) + else: + break + + # Return directly if only one snippet is found + if len(snippets) == 1: + return snippets[0] + elif len(snippets) > 1: + return snippets + else: + # Return the original text if no snippets are found + return text