diff --git a/cli.py b/cli.py index 2e86134..959f440 100644 --- a/cli.py +++ b/cli.py @@ -43,7 +43,7 @@ def process_repository( for folder_name in folders.keys(): folder_path = os.path.join(root_path, folder_name) - folder_tale = document_folder(folder_path, output_path) + folder_tale = process_folder(folder_path, output_path) if folder_tale is not None: folder_tales.append( {"folder_name": folder_name, "folder_summary": folder_tale} @@ -58,7 +58,7 @@ def process_repository( file.write(root_index) -def document_folder( +def process_folder( folder_path: str, output_path: str, model_name: str = DEFAULT_MODEL_NAME, @@ -74,7 +74,7 @@ def document_folder( and os.path.splitext(filename)[1] in ALLOWED_EXTENSIONS ): logger.info(f"processing {file_path}") - file_tale = document_file(file_path, save_path) + file_tale = process_file(file_path, save_path) tales.append( {"file_name": filename, "file_summary": file_tale["file_docstring"]} @@ -93,7 +93,7 @@ def document_folder( return None -def document_file( +def process_file( file_path: str, output_path: str = DEFAULT_OUTPUT_PATH, model_name: str = DEFAULT_MODEL_NAME, @@ -120,7 +120,7 @@ def document_file( logger.info(f"tale section {str(idx+1)}/{len(docs)} done.") logger.info("write dev tale") - file_tales = fuse_tales(tales_list) + file_tales = fuse_tales(tales_list, code) logger.info("add dev tale summary") final_tale = get_tale_summary(file_tales) @@ -143,11 +143,18 @@ def document_file( @click.command() @click.option( - "-r", - "--repository-path", - "repository_path", + "-m", + "--mode", + type=click.Choice(["-r", "-d", "-f"]), required=True, - help="The path to the repository", + help="Select the mode: -r for repository, -d for folder, -f for file", +) +@click.option( + "-p", + "--path", + "path", + required=True, + help="The path to the repository, folder, or file", ) @click.option( "-o", @@ -155,7 +162,7 @@ def document_file( "output_path", required=False, default=DEFAULT_OUTPUT_PATH, - help="The destination folder 
where you want to save the document file", + help="The destination folder where you want to save the documentation outputs", ) @click.option( "-n", @@ -166,13 +173,22 @@ def document_file( help="The OpenAI model name you want to use. \ https://platform.openai.com/docs/models", ) -def main(repository_path: str, output_path: str, model_name: str): +def main(mode: str, path: str, output_path: str, model_name: str): if not os.environ.get("OPENAI_API_KEY"): os.environ["OPENAI_API_KEY"] = getpass.getpass( prompt="Enter your OpenAI API key: " ) - process_repository(repository_path, output_path, model_name) + if mode == "-r": + process_repository(path, output_path, model_name) + elif mode == "-d": + process_folder(path, output_path, model_name) + elif mode == "-f": + process_file(path, output_path, model_name) + else: + raise ValueError( + "Invalid mode. Please select -r (repository), -d (folder), or -f (file)." + ) if __name__ == "__main__": diff --git a/devtale/schema.py b/devtale/schema.py index 1d22b1e..bd027e5 100644 --- a/devtale/schema.py +++ b/devtale/schema.py @@ -7,8 +7,9 @@ class ClassEntities(BaseModel): class_name: str = Field(default=None, description="Name of the class definition.") class_docstring: str = Field( default=None, - description="Google Style Docstring text that provides an explanation of the \ - purpose of the class and its class args. All inside the same str.", + description="The Google Style Docstring text that provides an explanation \ + of the purpose of the class, including its arguments if any. All inside \ + the same str.", ) @@ -18,20 +19,22 @@ class MethodEntities(BaseModel): ) method_docstring: str = Field( default=None, - description="Google Style Docstring text that provides an explanation of the \ - purpose of the method/function, method args, method returns, and method \ - raises. 
All inside the same str.", + description="The Google Style Docstring text that provides an explanation \ + of the purpose of the method/function, including its arguments, returns, and \ + raises if any. All inside the same str.", ) class FileDocumentation(BaseModel): classes: List[ClassEntities] = Field( default=None, - description="Entities containing class definitions along with their respective \ - docstrings.", + description="List of entities containing class definitions along with their \ + respective docstrings. This list must not include imported classes, utility \ + classes, or class instances.", ) methods: List[MethodEntities] = Field( default=None, - description="Entities containing method/function definitions along with their \ - respective docstrings.", + description="List of entities containing method/function definitions along \ + with their respective docstrings. This list must not include imported or \ + method/function instances.", ) diff --git a/devtale/templates.py b/devtale/templates.py index f7c6e23..290ed80 100644 --- a/devtale/templates.py +++ b/devtale/templates.py @@ -1,18 +1,14 @@ CODE_LEVEL_TEMPLATE = """ -Given the provided code, please perform the following actions: +Given the provided code text input enclosed within the <<< >>> delimiters, your \ +task is to create well-structured documentation for the classes, methods, and \ +functions explicitly defined within the code. +You are not allowed to generate new classes, methods or functions. +Skip class instances, imported classes, imported methods, method instances. +Output your answer as a JSON which matches the following output format. -1. Split the code into class definitions and method definitions. -2. For each class definition, generate a Google Style Docstring text that provides an \ -explanation of the purpose of the class, args and returns. -3. 
For each method definition, generate a Google Style Docstring text that provides an \ -explanation of the purpose of the method, args, returns, and raises. +Output format: {format_instructions} -{format_instructions} - -Here is the code: --------- -{code} --------- +Input: <<< {code} >>> """ FILE_LEVEL_TEMPLATE = """ diff --git a/devtale/utils.py b/devtale/utils.py index 3e750eb..03d52b6 100644 --- a/devtale/utils.py +++ b/devtale/utils.py @@ -1,4 +1,6 @@ import json +import re +from json import JSONDecodeError from langchain import LLMChain, PromptTemplate from langchain.chat_models import ChatOpenAI @@ -21,7 +23,7 @@ def split(code, language, chunk_size=1000, chunk_overlap=0): return docs -def get_tale_index(tales, model_name="gpt-3.5-turbo", verbose=True): +def get_tale_index(tales, model_name="gpt-3.5-turbo", verbose=False): prompt = PromptTemplate(template=FOLDER_LEVEL_TEMPLATE, input_variables=["tales"]) llm = ChatOpenAI(model_name=model_name) indixer = LLMChain(llm=llm, prompt=prompt, verbose=verbose) @@ -51,18 +53,39 @@ def get_unit_tale(doc, model_name="gpt-3.5-turbo", verbose=False): result_string = teller_of_tales({"code": doc.page_content}) try: result_json = json.loads(result_string["text"]) - except Exception as e: - print( - f"Error getting the JSON with the docstrings. \ - Error: {e} \n Result {result_string}" - ) - print("Returning empty JSON instead") - empty = {"classes": [], "methods": []} - return empty + except JSONDecodeError: + try: + text = result_string["text"].replace("\\n", "\n") + start_index = text.find("{") + end_index = text.rfind("}") + + if start_index != -1 and end_index != -1 and start_index < end_index: + json_text = text[start_index : end_index + 1] + result_json = json.loads(json_text) + else: + print(f"Invalid JSON {text}") + print("Returning empty JSON instead") + empty = {"classes": [], "methods": []} + return empty + except Exception as e: + print( + f"Error getting the JSON with the docstrings. 
\ + Error: {e} \n Result {json_text}" + ) + print("Returning empty JSON instead") + empty = {"classes": [], "methods": []} + return empty return result_json -def fuse_tales(tales_list): +def is_hallucination(code_definition, code): + # Check if the code_definition exists within the code + if re.search(r"\b" + re.escape(code_definition) + r"\b", code): + return False + return True + + +def fuse_tales(tales_list, code): fused_tale = {"classes": [], "methods": []} unique_methods = set() unique_classes = set() @@ -71,14 +94,18 @@ def fuse_tales(tales_list): if "classes" in tale: for class_info in tale["classes"]: class_name = class_info["class_name"] - if class_name not in unique_classes: + if class_name not in unique_classes and not is_hallucination( + class_name, code + ): unique_classes.add(class_name) fused_tale["classes"].append(class_info) if "methods" in tale: for method in tale["methods"]: method_name = method["method_name"] - if method_name not in unique_methods: + if method_name not in unique_methods and not is_hallucination( + method_name, code + ): unique_methods.add(method_name) fused_tale["methods"].append(method)