diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..16c4b85 --- /dev/null +++ b/.gitignore @@ -0,0 +1,49 @@ + +# Mac file +.DS_Store + +# Log file +*.log + +# Gradle files +.gradle + +# BlueJ files +*.ctxt + +# Mobile Tools for Java (J2ME) +.mtj.tmp/ + +# Package Files # +*.jar +*.war +*.nar +*.ear +*.zip +*.tar.gz +*.rar + +# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml +hs_err_pid* + +scratch* +*.flake8 +.vscode +.idea/ +*.iml + +# environment file +.env + +# log file +*.json + +# Lock file +*.lock + + +# Python compiled files and env +__pycache__/ +*.py[cod] +.python-version +.venv/ \ No newline at end of file diff --git a/README.md b/README.md index c79c677..e2729cc 100644 --- a/README.md +++ b/README.md @@ -269,4 +269,4 @@ if __name__ == "__main__": # (11) Print the instruction and LLM output print(f"Instruction:\n{instruction}") print(f"LLM Output:\n{llm_output}") -``` +``` \ No newline at end of file diff --git a/cldk/analysis/java/codeanalyzer/codeanalyzer.py b/cldk/analysis/java/codeanalyzer/codeanalyzer.py index b9bb4da..671f991 100644 --- a/cldk/analysis/java/codeanalyzer/codeanalyzer.py +++ b/cldk/analysis/java/codeanalyzer/codeanalyzer.py @@ -65,9 +65,13 @@ def __init__( self.use_graalvm_binary = use_graalvm_binary self.eager_analysis = eager_analysis self.analysis_level = analysis_level + self.application = self._init_codeanalyzer(analysis_level=1 if analysis_level == 'symbol_table' else 2) # Attributes related the Java code analysis... - self.call_graph: DiGraph | None = None - self.application = None + if analysis_level == 'symbol_table': + self.call_graph: DiGraph | None = None + else: + self.call_graph: DiGraph = self._generate_call_graph(using_symbol_table=False) + @staticmethod def _download_or_update_code_analyzer(filepath: Path) -> str: @@ -164,7 +168,7 @@ def _get_codeanalyzer_exec(self) -> List[str]: codeanalyzer_exec = shlex.split(codeanalyzer_bin_path.__str__()) else: if self.analysis_backend_path: - analysis_backend_path = Path(analysis_backend_path) + analysis_backend_path = Path(self.analysis_backend_path) logger.info(f"Using codeanalyzer.jar from {analysis_backend_path}") codeanalyzer_exec = shlex.split(f"java -jar {analysis_backend_path / 'codeanalyzer.jar'}") else: @@ -715,43 +719,19 @@ def get_class_call_graph(self, qualified_class_name: str, method_name: str | Non """ # If the method name is not provided, we'll get the call graph for the entire class. - # TODO: Implement class call graph generation @rahlk - - _class: JType = self.get_class(qualified_class_name) - - edge_list = [] - for method_signature, callable in _class.callable_declarations.items(): - for callsite in callable.callsites: - edge_list.append(((callable.signature, qualified_class_name),)) - - class_call_graph = nx.DiGraph() - - edge_list = [ - ( - (jge.source.method.signature, jge.source.klass), - (jge.target.method.signature, jge.target.klass), - { - "type": jge.type, - "weight": jge.weight, - "calling_lines": tsu.get_calling_lines(jge.source.method.code, jge.target.method.signature), - }, - ) - for jge in sdg - if jge.type == "CONTROL_DEP" or jge.type == "CALL_DEP" - ] - - for jge in sdg: - class_call_graph.add_node( - (jge.source.method.signature, jge.source.klass), - method_detail=jge.source, - ) - class_call_graph.add_node( - (jge.target.method.signature, jge.target.klass), - method_detail=jge.target, - ) - class_call_graph.add_edges_from(edge_list) - - NotImplementedError("Class call graph generation is not implemented yet.") + if method_name is None: + filter_criteria = {node for node in self.call_graph.nodes if node[1] == qualified_class_name} + else: + filter_criteria = {node for node in self.call_graph.nodes if + tuple(node) == (method_name, qualified_class_name)} + + graph_edges: List[Tuple[JMethodDetail, JMethodDetail]] = list() + for edge in self.call_graph.edges(nbunch=filter_criteria): + source: JMethodDetail = self.call_graph.nodes[edge[0]]["method_detail"] + target: JMethodDetail = self.call_graph.nodes[edge[1]]["method_detail"] + graph_edges.append((source, target)) + + return graph_edges def get_all_entry_point_methods(self) -> Dict[str, Dict[str, JCallable]]: """ diff --git a/cldk/analysis/java/java.py b/cldk/analysis/java/java.py index 9403705..2c7f4f9 100644 --- a/cldk/analysis/java/java.py +++ b/cldk/analysis/java/java.py @@ -59,12 +59,12 @@ def __init__( self.analysis_backend_path = analysis_backend_path self.eager_analysis = eager_analysis self.use_graalvm_binary = use_graalvm_binary - + self.analysis_backend = analysis_backend # Initialize the analysis analysis_backend if analysis_backend.lower() == "codeql": self.analysis_backend: JCodeQL = JCodeQL(self.project_dir, self.analysis_json_path) elif analysis_backend.lower() == "codeanalyzer": - self.analysis_backend: JCodeanalyzer = JCodeanalyzer( + self.backend: JCodeanalyzer = JCodeanalyzer( project_dir=self.project_dir, source_code=self.source_code, eager_analysis=self.eager_analysis, @@ -99,7 +99,7 @@ def get_application_view(self) -> JApplication: """ if self.source_code: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_application_view() + return self.backend.get_application_view() def get_symbol_table(self) -> Dict[str, JCompilationUnit]: """ @@ -110,7 +110,7 @@ def get_symbol_table(self) -> Dict[str, JCompilationUnit]: Dict[str, JCompilationUnit] The application view of the Java code. """ - return self.analysis_backend.get_symbol_table() + return self.backend.get_symbol_table() def get_compilation_units(self) -> List[JCompilationUnit]: """ @@ -123,7 +123,7 @@ def get_compilation_units(self) -> List[JCompilationUnit]: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_compilation_units() + return self.backend.get_compilation_units() def get_class_hierarchy(self) -> DiGraph: """ @@ -134,7 +134,7 @@ def get_class_hierarchy(self) -> DiGraph: DiGraph The class hierarchy of the Java code. """ - if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + if self.backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") raise NotImplementedError("Class hierarchy is not implemented yet.") @@ -147,7 +147,7 @@ def get_call_graph(self) -> DiGraph: DiGraph The call graph of the Java code. """ - return self.analysis_backend.get_call_graph() + return self.backend.get_call_graph() def get_call_graph_json(self) -> str: """ @@ -155,7 +155,7 @@ def get_call_graph_json(self) -> str: """ if self.source_code: raise NotImplementedError("Producing a call graph over a single file is not implemented yet.") - return self.analysis_backend.get_call_graph_json() + return self.backend.get_call_graph_json() def get_callers(self, target_class_name: str, target_method_declaration: str): """ @@ -168,7 +168,7 @@ def get_callers(self, target_class_name: str, target_method_declaration: str): """ if self.source_code: raise NotImplementedError("Generating all callers over a single file is not implemented yet.") - return self.analysis_backend.get_all_callers(target_class_name, target_method_declaration) + return self.backend.get_all_callers(target_class_name, target_method_declaration) def get_callees(self, source_class_name: str, source_method_declaration: str): """ @@ -181,7 +181,7 @@ def get_callees(self, source_class_name: str, source_method_declaration: str): """ if self.source_code: raise NotImplementedError("Generating all callees over a single file is not implemented yet.") - return self.analysis_backend.get_all_callees(source_class_name, source_method_declaration) + return self.backend.get_all_callees(source_class_name, source_method_declaration) def get_methods(self) -> Dict[str, Dict[str, JCallable]]: """ @@ -196,7 +196,7 @@ def get_methods(self) -> Dict[str, Dict[str, JCallable]]: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_all_methods_in_application() + return self.backend.get_all_methods_in_application() def get_classes(self) -> Dict[str, JType]: """ @@ -209,7 +209,7 @@ def get_classes(self) -> Dict[str, JType]: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_all_classes() + return self.backend.get_all_classes() def get_classes_by_criteria(self, inclusions=None, exclusions=None) -> Dict[str, JType]: """ @@ -230,7 +230,7 @@ def get_classes_by_criteria(self, inclusions=None, exclusions=None) -> Dict[str, if inclusions is None: inclusions = [] class_dict: Dict[str, JType] = {} - all_classes = self.get_all_classes() + all_classes = self.backend.get_all_classes() for application_class in all_classes: is_selected = False for inclusion in inclusions: @@ -260,7 +260,7 @@ def get_class(self, qualified_class_name) -> JType: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_class(qualified_class_name) + return self.backend.get_class(qualified_class_name) def get_method(self, qualified_class_name, qualified_method_name) -> JCallable: """ @@ -278,7 +278,7 @@ def get_method(self, qualified_class_name, qualified_method_name) -> JCallable: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_method(qualified_class_name, qualified_method_name) + return self.backend.get_method(qualified_class_name, qualified_method_name) def get_java_file(self, qualified_class_name) -> str: """ @@ -296,7 +296,7 @@ def get_java_file(self, qualified_class_name) -> str: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_java_file(qualified_class_name) + return self.backend.get_java_file(qualified_class_name) def get_java_compilation_unit(self, file_path: str) -> JCompilationUnit: """ @@ -314,7 +314,7 @@ def get_java_compilation_unit(self, file_path: str) -> JCompilationUnit: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_java_compilation_unit(file_path) + return self.backend.get_java_compilation_unit(file_path) def get_methods_in_class(self, qualified_class_name) -> Dict[str, JCallable]: """ @@ -332,7 +332,7 @@ def get_methods_in_class(self, qualified_class_name) -> Dict[str, JCallable]: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_all_methods_in_class(qualified_class_name) + return self.backend.get_all_methods_in_class(qualified_class_name) def get_constructors(self, qualified_class_name) -> Dict[str, JCallable]: """ @@ -350,7 +350,7 @@ def get_constructors(self, qualified_class_name) -> Dict[str, JCallable]: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_all_constructors(qualified_class_name) + return self.backend.get_all_constructors(qualified_class_name) def get_fields(self, qualified_class_name) -> List[JField]: """ @@ -368,7 +368,7 @@ def get_fields(self, qualified_class_name) -> List[JField]: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_all_fields(qualified_class_name) + return self.backend.get_all_fields(qualified_class_name) def get_nested_classes(self, qualified_class_name) -> List[JType]: """ @@ -386,7 +386,7 @@ def get_nested_classes(self, qualified_class_name) -> List[JType]: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_all_nested_classes(qualified_class_name) + return self.backend.get_all_nested_classes(qualified_class_name) def get_sub_classes(self, qualified_class_name) -> Dict[str, JType]: """ @@ -399,7 +399,7 @@ def get_sub_classes(self, qualified_class_name) -> Dict[str, JType]: ------- Dict[str, JType]: A dictionary of all sub-classes of the given class, and class details """ - return self.analysis_backend.get_all_sub_classes(qualified_class_name=qualified_class_name) + return self.backend.get_all_sub_classes(qualified_class_name=qualified_class_name) def get_extended_classes(self, qualified_class_name) -> List[str]: """ @@ -417,7 +417,7 @@ def get_extended_classes(self, qualified_class_name) -> List[str]: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_extended_classes(qualified_class_name) + return self.backend.get_extended_classes(qualified_class_name) def get_implemented_interfaces(self, qualified_class_name) -> List[str]: """ @@ -435,7 +435,7 @@ def get_implemented_interfaces(self, qualified_class_name) -> List[str]: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_implemented_interfaces(qualified_class_name) + return self.backend.get_implemented_interfaces(qualified_class_name) def get_class_call_graph(self, qualified_class_name: str, method_name: str | None = None) -> (List)[Tuple[JMethodDetail, JMethodDetail]]: """ @@ -455,7 +455,7 @@ def get_class_call_graph(self, qualified_class_name: str, method_name: str | Non """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_class_call_graph(qualified_class_name, method_name) + return self.backend.get_class_call_graph(qualified_class_name, method_name) def get_entry_point_classes(self) -> Dict[str, JType]: """ @@ -468,7 +468,7 @@ def get_entry_point_classes(self) -> Dict[str, JType]: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_all_entry_point_classes() + return self.backend.get_all_entry_point_classes() def get_entry_point_methods(self) -> Dict[str, Dict[str, JCallable]]: """ @@ -483,7 +483,7 @@ def get_entry_point_methods(self) -> Dict[str, Dict[str, JCallable]]: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_all_entry_point_methods() + return self.backend.get_all_entry_point_methods() def remove_all_comments(self) -> str: """ @@ -503,7 +503,7 @@ def remove_all_comments(self) -> str: # Remove any prefix comments/content before the package declaration if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.CODEANALYZER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.remove_all_comments(self.source_code) + return self.backend.remove_all_comments(self.source_code) def get_methods_with_annotations(self, annotations: List[str]) -> Dict[str, List[Dict]]: """ @@ -523,7 +523,7 @@ def get_methods_with_annotations(self, annotations: List[str]) -> Dict[str, List """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.CODEANALYZER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_methods_with_annotations(self.source_code, annotations) + return self.backend.get_methods_with_annotations(self.source_code, annotations) def get_test_methods(self, source_class_code: str) -> Dict[str, str]: """ @@ -541,7 +541,7 @@ def get_test_methods(self, source_class_code: str) -> Dict[str, str]: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.CODEANALYZER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_test_methods(self.source_code) + return self.backend.get_test_methods(self.source_code) def get_calling_lines(self, target_method_name: str) -> List[int]: """ @@ -562,7 +562,7 @@ def get_calling_lines(self, target_method_name: str) -> List[int]: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_calling_lines(self.source_code, target_method_name) + return self.backend.get_calling_lines(self.source_code, target_method_name) def get_call_targets(self, declared_methods: dict) -> Set[str]: """Generate a list of call targets from the method body. @@ -584,4 +584,4 @@ def get_call_targets(self, declared_methods: dict) -> Set[str]: """ if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") - return self.analysis_backend.get_call_targets(self.source_code, declared_methods) + return self.backend.get_call_targets(self.source_code, declared_methods) diff --git a/cldk/analysis/python/treesitter/python_sitter.py b/cldk/analysis/python/treesitter/python_sitter.py index 3460987..ec3ba65 100644 --- a/cldk/analysis/python/treesitter/python_sitter.py +++ b/cldk/analysis/python/treesitter/python_sitter.py @@ -3,12 +3,12 @@ from pathlib import Path from typing import List -from sphinx.domains.python import PyField from tree_sitter import Language, Parser, Query, Node import tree_sitter_python as tspython from cldk.models.python.models import PyMethod, PyClass, PyArg, PyImport, PyModule, PyCallSite from cldk.models.treesitter import Captures +from cldk.utils.treesitter.tree_sitter_utils import TreeSitterUtils class PythonSitter: @@ -19,6 +19,7 @@ class PythonSitter: def __init__(self) -> None: self.language: Language = Language(tspython.language()) self.parser: Parser = Parser(self.language) + self.utils: TreeSitterUtils = TreeSitterUtils() def get_all_methods(self, module: str) -> List[PyMethod]: """ @@ -104,9 +105,9 @@ def get_all_imports(self, module: str) -> List[str]: List[str]: List of imports """ import_list = [] - captures_from_import: Captures = self.__frame_query_and_capture_output("(((import_from_statement) @imports))", + captures_from_import: Captures = self.utils.frame_query_and_capture_output(self.parser, self.language, "(((import_from_statement) @imports))", module) - captures_import: Captures = self.__frame_query_and_capture_output("(((import_statement) @imports))", module) + captures_import: Captures = self.utils.frame_query_and_capture_output(self.parser, self.language, "(((import_statement) @imports))", module) for capture in captures_import: import_list.append(capture.node.text.decode()) for capture in captures_from_import: @@ -131,9 +132,9 @@ def get_all_imports_details(self, module: str) -> List[PyImport]: List[PyImport]: List of imports """ import_list = [] - captures_from_import: Captures = self.__frame_query_and_capture_output("(((import_from_statement) @imports))", + captures_from_import: Captures = self.utils.frame_query_and_capture_output(self.parser, self.language, "(((import_from_statement) @imports))", module) - captures_import: Captures = self.__frame_query_and_capture_output("(((import_statement) @imports))", module) + captures_import: Captures = self.utils.frame_query_and_capture_output(self.parser, self.language, "(((import_statement) @imports))", module) for capture in captures_import: imports = [] for import_name in capture.node.children: @@ -152,7 +153,7 @@ def get_all_imports_details(self, module: str) -> List[PyImport]: import_list.append(PyImport(from_statement=capture.node.children[1].text.decode(), imports=imports)) return import_list - def get_all_fields(self, module: str) -> List[PyField]: + def get_all_fields(self, module: str): pass def get_all_classes(self, module: str) -> List[PyClass]: @@ -167,7 +168,7 @@ def get_all_classes(self, module: str) -> List[PyClass]: List[PyClass]: returns details of all classes in it """ classes: List[PyClass] = [] - all_class_details: Captures = self.__frame_query_and_capture_output("(((class_definition) @class_name))", + all_class_details: Captures = self.utils.frame_query_and_capture_output(self.parser, self.language, "(((class_definition) @class_name))", module) for class_name in all_class_details: code_body = class_name.node.text.decode() @@ -280,7 +281,7 @@ def __get_function_details(self, node: Node, klass_name: str = "") -> PyMethod: is_constructor = False is_static = False call_sites: List[PyCallSite] = [] - call_nodes: Captures = self.__frame_query_and_capture_output("(((call) @call_name))", node.text.decode()) + call_nodes: Captures = self.utils.frame_query_and_capture_output(self.parser, self.language, "(((call) @call_name))", node.text.decode()) for call_node in call_nodes: call_sites.append(self.__get_call_site_details(call_node.node)) for function_detail in node.children: @@ -342,9 +343,9 @@ def __get_function_details(self, node: Node, klass_name: str = "") -> PyMethod: def __get_class_nodes(self, module: str) -> Captures: - captures: Captures = self.__frame_query_and_capture_output("(((class_definition) @class_name))", module) + captures: Captures = self.utils.frame_query_and_capture_output(self.parser, self.language, "(((class_definition) @class_name))", module) return captures def __get_method_nodes(self, module: str) -> Captures: - captures: Captures = self.__frame_query_and_capture_output("(((function_definition) @function_name))", module) + captures: Captures = self.utils.frame_query_and_capture_output(self.parser, self.language, "(((function_definition) @function_name))", module) return captures diff --git a/cldk/utils/treesitter/tree_sitter_utils.py b/cldk/utils/treesitter/tree_sitter_utils.py index 8621232..967e7f9 100644 --- a/cldk/utils/treesitter/tree_sitter_utils.py +++ b/cldk/utils/treesitter/tree_sitter_utils.py @@ -4,7 +4,7 @@ class TreeSitterUtils: - def __frame_query_and_capture_output(self, query: str, code_to_process: str) -> Captures: + def frame_query_and_capture_output(self, parser, language, query: str, code_to_process: str) -> Captures: """Frame a query for the tree-sitter parser. Parameters @@ -14,11 +14,11 @@ def __frame_query_and_capture_output(self, query: str, code_to_process: str) -> code_to_process : str The code to process. """ - framed_query: Query = self.language.query(query) - tree = self.parser.parse(bytes(code_to_process, "utf-8")) + framed_query: Query = language.query(query) + tree = parser.parse(bytes(code_to_process, "utf-8")) return Captures(framed_query.captures(tree.root_node)) - def __safe_ascend(self, node: Node, ascend_count: int) -> Node: + def safe_ascend(self, node: Node, ascend_count: int) -> Node: """Safely ascend the tree. If the node does not exist or if it has no parent, raise an error. Parameters @@ -45,4 +45,4 @@ def __safe_ascend(self, node: Node, ascend_count: int) -> Node: if ascend_count == 0: return node else: - return self.__safe_ascend(node.parent, ascend_count - 1) \ No newline at end of file + return self.safe_ascend(node.parent, ascend_count - 1) \ No newline at end of file diff --git a/docs/examples/java/code_summarization.py b/docs/examples/java/code_summarization.py new file mode 100644 index 0000000..bc00eda --- /dev/null +++ b/docs/examples/java/code_summarization.py @@ -0,0 +1,66 @@ +import os +from pathlib import Path +import ollama +from cldk import CLDK + + +def format_inst(code, focal_method, focal_class, language): + """ + Format the instruction for the given focal method and class. + """ + inst = f"Question: Can you write a brief summary for the method `{focal_method}` in the class `{focal_class}` below?\n" + + inst += "\n" + inst += f"```{language}\n" + inst += code + inst += "```" if code.endswith("\n") else "\n```" + inst += "\n" + return inst + + +def prompt_ollama(message: str, model_id: str = "granite-code:8b-instruct") -> str: + """Prompt local model on Ollama""" + response_object = ollama.generate(model=model_id, prompt=message) + return response_object["response"] + + +if __name__ == "__main__": + # (1) Create a new instance of the CLDK class + cldk = CLDK(language="java") + + # (2) Create an analysis object over the java application + analysis = cldk.analysis(project_path="JAVA_APP_PATH") + + # (3) Iterate over all the files in the project + for file_path, class_file in analysis.get_symbol_table().items(): + class_file_path = Path(file_path).absolute().resolve() + # (4) Iterate over all the classes in the file + for type_name, type_declaration in class_file.type_declarations.items(): + # (5) Iterate over all the methods in the class + for method in type_declaration.callable_declarations.values(): + # (6) Get code body of the method + code_body = class_file_path.read_text() + + # (7) Initialize the treesitter utils for the class file content + tree_sitter_utils = cldk.tree_sitter_utils(source_code=code_body) + + # (8) Sanitize the class for analysis + sanitized_class = tree_sitter_utils.sanitize_focal_class(method.declaration) + + # (9) Format the instruction for the given focal method and class + instruction = format_inst( + code=sanitized_class, + focal_method=method.declaration, + focal_class=type_name, + language="java" + ) + + # (10) Prompt the local model on Ollama + llm_output = prompt_ollama( + message=instruction, + model_id="granite-code:20b-instruct", + ) + + # (11) Print the instruction and LLM output + print(f"Instruction:\n{instruction}") + print(f"LLM Output:\n{llm_output}") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 9bba721..4218f11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "cldk" version = "0.1.0-dev" description = "codellm-devkit: A python library for seamless interation with LLMs." -authors = ["Rahul Krishna ", "Rangeet Pan ", "Saurabh Sinhas ", +authors = ["Rahul Krishna ", "Rangeet Pan ", "Saurabh Sinhas ", "Raju Pavuluri "] license = "Apache 2.0" readme = "README.md" @@ -22,7 +22,7 @@ requests = "^2.31.0" tree-sitter-java = "^0.21.0" tree-sitter-c = "^0.21.0" tree-sitter-go = "^0.21.0" -tree-sitter-python = {git = "https://github.com/tree-sitter/tree-sitter-python", rev = "0f9047c"} # Points to 0.21.0 +tree-sitter-python = "^0.21.0" tree-sitter-javascript = "^0.21.0" # Test dependencies diff --git a/tests/analysis/java/test_java.py b/tests/analysis/java/test_java.py index 01fcf2f..4be9882 100644 --- a/tests/analysis/java/test_java.py +++ b/tests/analysis/java/test_java.py @@ -12,6 +12,7 @@ def test_get_class_call_graph(test_fixture): analysis_backend="codeanalyzer", analysis_json_path="/tmp", eager=True, + analysis_level='call-graph' ) class_call_graph: List[Tuple[JMethodDetail, JMethodDetail]] = analysis.get_class_call_graph( qualified_class_name="com.ibm.websphere.samples.daytrader.impl.direct.TradeDirectDBUtils" diff --git a/tests/tree_sitter/python/test_python_tree_sitter.py b/tests/tree_sitter/python/test_python_tree_sitter.py index 474f0df..830d06f 100644 --- a/tests/tree_sitter/python/test_python_tree_sitter.py +++ b/tests/tree_sitter/python/test_python_tree_sitter.py @@ -37,7 +37,7 @@ def __str__(self): self.assertFalse(all_methods[0].is_static) self.assertEquals(all_methods[0].class_signature, "Person") self.assertEquals(all_functions[0].class_signature, "") - self.assertTrue(all_functions[0].is_static) + self.assertFalse(all_functions[0].is_static) def test_get_all_imports(self): module_str = """