Skip to content

Commit

Permalink
feat: also handle case where importlib.import_module is aliased
Browse files Browse the repository at this point in the history
  • Loading branch information
lmmx committed Jul 27, 2024
1 parent 025df97 commit 73c790a
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 16 deletions.
58 changes: 44 additions & 14 deletions src/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ use std::collections::HashMap;
#[derive(Debug, Clone)]
pub struct ImportVisitor {
imports: HashMap<String, Vec<TextRange>>,
import_module_name: Option<String>,
}

impl ImportVisitor {
pub fn new() -> Self {
Self {
imports: HashMap::new(),
import_module_name: None,
}
}

Expand All @@ -25,20 +27,39 @@ impl<'a> Visitor<'a> for ImportVisitor {
match stmt {
Stmt::Import(import_stmt) => {
for alias in &import_stmt.names {
let top_level_module = get_top_level_module_name(&alias.name);
let top_level_module = get_top_level_module_name(alias.name.as_str());
self.imports
.entry(top_level_module)
.or_default()
.push(alias.range);

if alias.name.as_str() == "importlib" {
self.import_module_name = Some("import_module".to_string());
}
}
}
Stmt::ImportFrom(import_from_stmt) => {
if let Some(module) = &import_from_stmt.module {
if import_from_stmt.level == 0 {
let module_name = module.as_str();
self.imports
.entry(get_top_level_module_name(module.as_str()))
.entry(get_top_level_module_name(module_name))
.or_default()
.push(import_from_stmt.range);

if module_name == "importlib" {
for alias in &import_from_stmt.names {
if alias.name.as_str() == "import_module" {
self.import_module_name = Some(
alias.asname
.as_ref()
.map(|id| id.as_str().to_string())
.unwrap_or_else(|| "import_module".to_string())
);
break;
}
}
}
}
}
}
Expand All @@ -47,18 +68,27 @@ impl<'a> Visitor<'a> for ImportVisitor {
}
Stmt::Expr(expr_stmt) => {
if let Expr::Call(call_expr) = expr_stmt.value.as_ref() {
if let Expr::Attribute(attr_expr) = call_expr.func.as_ref() {
if let Expr::Name(name) = attr_expr.value.as_ref() {
if name.id.as_str() == "importlib" && attr_expr.attr.as_str() == "import_module" {
if let Some(arg) = call_expr.arguments.args.first() {
if let Expr::StringLiteral(string_literal) = arg {
let top_level_module = get_top_level_module_name(&string_literal.value.to_string());
self.imports
.entry(top_level_module)
.or_default()
.push(expr_stmt.range);
}
}
let is_import_module = match call_expr.func.as_ref() {
Expr::Attribute(attr_expr) => {
// Case: importlib.import_module(...)
matches!(attr_expr.value.as_ref(), Expr::Name(name) if name.id.as_str() == "importlib")
&& attr_expr.attr.as_str() == "import_module"
}
Expr::Name(name) => {
// Case: import_module(...) or aliased version
self.import_module_name.as_ref().map_or(false, |im_name| name.id.as_str() == im_name)
}
_ => false,
};

if is_import_module {
if let Some(arg) = call_expr.arguments.args.first() {
if let Expr::StringLiteral(string_literal) = arg {
let top_level_module = get_top_level_module_name(&string_literal.value.to_string());
self.imports
.entry(top_level_module)
.or_default()
.push(expr_stmt.range);
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions tests/data/some_dyn_imports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from importlib import import_module
from importlib import import_module as im
import importlib

import_module("polars")
importlib.import_module("patito")
im("uvicorn")
13 changes: 11 additions & 2 deletions tests/unit/imports/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,17 @@ def test_dyn_import_parser_py() -> None:
some_dyn_imports_path = Path("tests/data/some_dyn_imports.py")

assert get_imported_modules_from_list_of_files([some_dyn_imports_path]) == {
"importlib": [Location(some_dyn_imports_path, line=1, column=1)],
"polars": [Location(some_dyn_imports_path, line=3, column=1)],
"importlib": [
Location(some_dyn_imports_path, line=1, column=1),
Location(some_dyn_imports_path, line=2, column=1),
Location(some_dyn_imports_path, line=3, column=8),
],
"patito": [
Location(some_dyn_imports_path, line=6, column=1),
],
"polars": [
Location(some_dyn_imports_path, line=5, column=1),
],
}


Expand Down

0 comments on commit 73c790a

Please sign in to comment.