Skip to content

Commit

Permalink
Improve import speed
Browse files Browse the repository at this point in the history
  • Loading branch information
salomartin committed May 29, 2024
1 parent e577d04 commit c430be1
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 25 deletions.
1 change: 1 addition & 0 deletions graphpyshop/extensions/shopify_async_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ async def try_create_bulk_query(
while time.time() - start_time < self.BULK_QUERY_TRY_START_TIMEOUT:
try:
query = self.inject_variables(gql_query, variables)
logging.debug(f"[{query_name}] Query: {query}")
return await bulk_operation_call(query)
except ShopifyGetDataError as e:
current_time = time.time()
Expand Down
82 changes: 57 additions & 25 deletions graphpyshop/extensions/shopify_bulk_queries_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ class ShopifyBulkQueriesPlugin(Plugin):

def __init__(self, schema: GraphQLSchema, config_dict: Dict[str, Any]) -> None:
super().__init__(schema=schema, config_dict=config_dict)
self.pending_imports: Dict[str, Set[str]] = {} # Initialize pending imports
self.pending_imports: Dict[str, Set[Tuple[str, bool]]] = (
{}
) # Initialize pending imports
logging.info("ShopifyBulkQueriesPlugin initialized with schema and config.")

def get_typename_to_class_map(self, module_name: str) -> Dict[str, str]:
Expand Down Expand Up @@ -79,7 +81,7 @@ def _add_necessary_imports(self, module: ast.Module):
)
for stmt in module.body
):
self._add_import("AsyncGenerator", "typing")
self._add_import("AsyncGenerator", "typing", False)

# Ensure BulkOperationStatus from .enums is always imported
if not any(
Expand All @@ -91,7 +93,8 @@ def _add_necessary_imports(self, module: ast.Module):
)
for stmt in module.body
):
self._add_import("BulkOperationStatus", ".enums")
self._add_import("BulkOperationStatus", ".enums", False)
self._add_import("BulkOperationNodeBulkOperation", ".bulk_operation", False)

def _enhance_class_with_bulk_methods(
self, class_def: ast.ClassDef, module: ast.Module
Expand Down Expand Up @@ -153,11 +156,16 @@ def _enhance_class_with_bulk_methods(
def _is_list_return_type(
self, return_type: ast.expr, module: ast.Module
) -> Optional[Tuple[ast.expr, str]]:
if isinstance(return_type, ast.Constant):
class_name = return_type.value
if isinstance(return_type, ast.Constant) or isinstance(return_type, ast.Name):
if isinstance(return_type, ast.Constant):
class_name = return_type.value
else:
class_name = return_type.id
result = self._get_class_ast(class_name, module)
if result:
class_ast, class_module_name = result
if isinstance(return_type, ast.Name):
class_module_name = "." + class_module_name
class_def = self._find_class_in_ast(class_name, class_ast)
if class_def:
for node in ast.walk(class_def):
Expand Down Expand Up @@ -185,7 +193,6 @@ def _is_list_return_type(
):
self._add_import_to_module(
class_node.annotation,
module,
class_module_name,
)
return (
Expand All @@ -200,7 +207,11 @@ def _get_class_ast(
for node in ast.walk(module):
if isinstance(node, ast.ImportFrom):
for alias in node.names:
if isinstance(alias, ast.alias) and alias.name == class_name and node.module:
if (
isinstance(alias, ast.alias)
and alias.name == class_name
and node.module
):
absolute_module = f"graphpyshop.client.{node.module}"
try:
# Get the file path of the module
Expand All @@ -223,9 +234,7 @@ def _find_class_in_ast(
return class_node
return None

def _add_import_to_module(
self, class_name: ast.expr, module: ast.Module, module_name: str
):
def _add_import_to_module(self, class_name: ast.expr, module_name: str):
if isinstance(class_name, ast.Subscript):
# Handle subscript case (e.g., List[str])
if isinstance(class_name.slice, ast.Constant):
Expand All @@ -242,22 +251,44 @@ def _add_import_to_module(
f"Unsupported type for class_name: {type(class_name).__name__}"
)

def _add_import(self, name: str, module_name: str) -> None:
def _add_import(
self, name: str, module_name: str, under_type_checking: bool = True
) -> None:
if module_name not in self.pending_imports:
self.pending_imports[module_name] = set()

self.pending_imports[module_name].add(name)
self.pending_imports[module_name].add((name, under_type_checking))

def _flush_pending_imports(self, module: ast.Module) -> None:
# Generate import statements for all collected names grouped by module
type_checking_node = None
for node in module.body:
if (
isinstance(node, ast.If)
and isinstance(node.test, ast.Name)
and node.test.id == "TYPE_CHECKING"
):
type_checking_node = node
break

for module_name, names in self.pending_imports.items():
import_stmt = ast.ImportFrom(
module=module_name,
names=[ast.alias(name=name, asname=None) for name in names],
level=0,
)
module.body.insert(0, import_stmt)
logging.info(f"Import for {', '.join(names)} from {module_name} added.")
for name, under_type_checking in names:
import_stmt = ast.ImportFrom(
module=module_name,
names=[ast.alias(name=name, asname=None)],
level=0,
)
if type_checking_node and under_type_checking:
type_checking_node.body.append(import_stmt)
logging.info(
f"Import for {name} from {module_name} added under TYPE_CHECKING."
)
else:
module.body.insert(0, import_stmt)
logging.info(
f"Import for {name} from {module_name} added outside TYPE_CHECKING."
)

self.pending_imports = {} # Reset after flushing

def _create_bulk_method(
Expand Down Expand Up @@ -288,25 +319,26 @@ def _create_bulk_method(
),
)
return bulk_method_def

def _create_import_statement(self, class_name: str, module_name: str):
return ast.ImportFrom(
module=module_name,
names=[ast.alias(name=class_name, asname=None)],
level=0
module=module_name, names=[ast.alias(name=class_name, asname=None)], level=0
)

def _generate_bulk_method_body(
self, method_def: ast.AsyncFunctionDef, gql_var_name: str, module_name: str
):
# Create import statements
import_stmt_bulk_operation = []
"""
import_stmt_bulk_operation = ast.ImportFrom(
module=".bulk_operation",
names=[
ast.alias(name="BulkOperationNodeBulkOperation", asname=None),
],
level=0
level=0,
)
"""

# Generate the variables assignment dynamically based on the method definition, excluding 'self'
variables_assignment = ast.AnnAssign(
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ plugins = [
"ariadne_codegen.contrib.extract_operations.ExtractOperationsPlugin",
"ariadne_codegen.contrib.shorter_results.ShorterResultsPlugin",
"ariadne_codegen.contrib.client_forward_refs.ClientForwardRefsPlugin",
"ariadne_codegen.contrib.no_reimports.NoReimportsPlugin",
"graphpyshop.extensions.ShopifyBulkQueriesPlugin",
]

Expand Down

0 comments on commit c430be1

Please sign in to comment.