Skip to content

Commit

Permalink
[red-knot] Encapsulate module resolution logic in module.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood committed Jun 5, 2024
1 parent 2567e14 commit d514307
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 55 deletions.
12 changes: 6 additions & 6 deletions crates/red_knot/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use tracing_subscriber::{Layer, Registry};
use tracing_tree::time::Uptime;

use red_knot::db::{HasJar, ParallelDatabase, QueryError, SourceDb, SourceJar};
use red_knot::module::{set_module_search_paths, ModuleSearchPath, ModuleSearchPathKind};
use red_knot::module::{set_module_search_paths, ResolvedSearchPathOrder};
use red_knot::program::check::ExecutionMode;
use red_knot::program::{FileWatcherChange, Program};
use red_knot::watch::FileWatcher;
Expand Down Expand Up @@ -44,12 +44,12 @@ fn main() -> anyhow::Result<()> {
let workspace_folder = entry_point.parent().unwrap();
let workspace = Workspace::new(workspace_folder.to_path_buf());

let workspace_search_path = ModuleSearchPath::new(
workspace.root().to_path_buf(),
ModuleSearchPathKind::FirstParty,
);
let workspace_search_path = workspace.root().to_path_buf();
let resolved_search_paths =
ResolvedSearchPathOrder::new(vec![], workspace_search_path, None, None);

let mut program = Program::new(workspace);
set_module_search_paths(&mut program, vec![workspace_search_path]);
set_module_search_paths(&mut program, resolved_search_paths);

let entry_id = program.file_id(entry_point);
program.workspace_mut().open_file(entry_id);
Expand Down
146 changes: 104 additions & 42 deletions crates/red_knot/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,14 +227,21 @@ struct ModuleSearchPathInner {

#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum ModuleSearchPathKind {
// Project dependency
FirstParty,
/// "Extra" paths provided by the user in a config file, env var or CLI flag.
/// E.g. mypy's `MYPYPATH` env var, or pyright's `stubPath` configuration setting
Extra,

// e.g. site packages
ThirdParty,
/// Project dependency
FirstParty,

// e.g. built-in modules, typeshed
/// e.g. built-in modules, typeshed
StandardLibrary,

/// Stubs or runtime modules installed in site-packages
SitePackagesThirdParty,

/// Vendored third-party stubs from typeshed
VendoredThirdParty,
}

impl ModuleSearchPathKind {
Expand Down Expand Up @@ -388,12 +395,64 @@ pub fn file_to_module(db: &dyn SemanticDb, file: FileId) -> QueryResult<Option<M
//////////////////////////////////////////////////////

/// Changes the module search paths to `search_paths`.
///
/// The `module_resolver` stored in the database's `SemanticJar` is replaced by a
/// newly constructed `ModuleResolver` built from `search_paths`; the previous
/// resolver value is overwritten.
pub fn set_module_search_paths(db: &mut dyn SemanticDb, search_paths: Vec<ModuleSearchPath>) {
pub fn set_module_search_paths(db: &mut dyn SemanticDb, search_paths: ResolvedSearchPathOrder) {
let jar: &mut SemanticJar = db.jar_mut();

// Install a fresh resolver instead of mutating the existing one in place.
jar.module_resolver = ModuleResolver::new(search_paths);
}

/// A resolved module resolution order, implementing PEP 561
/// (with some small, deliberate differences).
///
/// This is a newtype around an ordered `Vec<ModuleSearchPath>`. The derived
/// `Default` yields an empty list of search paths.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ResolvedSearchPathOrder(Vec<ModuleSearchPath>);

/// Exposes the ordered search paths as a read-only slice.
impl Deref for ResolvedSearchPathOrder {
    type Target = [ModuleSearchPath];

    fn deref(&self) -> &Self::Target {
        self.0.as_slice()
    }
}

impl ResolvedSearchPathOrder {
    /// Builds the module resolution order, based on the order described in PEP 561.
    ///
    /// The search paths are stored in the following order:
    ///
    /// 1. `extra_paths`: a list of user-provided paths that should take first
    ///    priority in the module resolution, kept in the order given.
    ///    Examples in other type checkers are mypy's `MYPYPATH` environment
    ///    variable, or pyright's `stubPath` configuration setting.
    /// 2. `workspace_root`: the root of the workspace, used for finding
    ///    first-party modules.
    /// 3. `custom_typeshed`: if provided, registered as the standard-library
    ///    search path. Nothing is added for the standard library when it is `None`.
    /// 4. `site_packages`: if provided, registered as the search path for
    ///    third-party stubs or runtime modules installed in site-packages.
    pub fn new(
        extra_paths: Vec<PathBuf>,
        workspace_root: PathBuf,
        site_packages: Option<PathBuf>,
        custom_typeshed: Option<PathBuf>,
    ) -> Self {
        // At most three entries are appended after the extra paths.
        let mut ordered = Vec::with_capacity(extra_paths.len() + 3);

        for extra in extra_paths {
            ordered.push(ModuleSearchPath::new(extra, ModuleSearchPathKind::Extra));
        }

        ordered.push(ModuleSearchPath::new(
            workspace_root,
            ModuleSearchPathKind::FirstParty,
        ));

        // TODO fallback to vendored typeshed stubs if no custom typeshed directory is provided by the user
        if let Some(typeshed) = custom_typeshed {
            ordered.push(ModuleSearchPath::new(
                typeshed,
                ModuleSearchPathKind::StandardLibrary,
            ));
        }

        if let Some(third_party) = site_packages {
            ordered.push(ModuleSearchPath::new(
                third_party,
                ModuleSearchPathKind::SitePackagesThirdParty,
            ));
        }

        // TODO vendor typeshed's third-party stubs as well as the stdlib and fallback to them as a final step
        Self(ordered)
    }
}

/// Adds a module located at `path` to the resolver.
///
/// Returns `None` if the path doesn't resolve to a module.
Expand Down Expand Up @@ -460,7 +519,7 @@ pub fn add_module(db: &mut dyn SemanticDb, path: &Path) -> Option<(Module, Vec<A
#[derive(Default)]
pub struct ModuleResolver {
/// The search paths where modules are located (and searched). Corresponds to `sys.path` at runtime.
search_paths: Vec<ModuleSearchPath>,
search_paths: ResolvedSearchPathOrder,

// Locking: Locking is done by acquiring a (write) lock on `by_name`. This is because `by_name` is the primary
// lookup method. Acquiring locks in any other ordering can result in deadlocks.
Expand All @@ -477,7 +536,7 @@ pub struct ModuleResolver {
}

impl ModuleResolver {
pub fn new(search_paths: Vec<ModuleSearchPath>) -> Self {
pub fn new(search_paths: ResolvedSearchPathOrder) -> Self {
Self {
search_paths,
modules: FxDashMap::default(),
Expand Down Expand Up @@ -690,21 +749,22 @@ impl PackageKind {
#[cfg(test)]
mod tests {
use std::num::NonZeroU32;
use std::path::PathBuf;

use crate::db::tests::TestDb;
use crate::db::SourceDb;
use crate::module::{
path_to_module, resolve_module, set_module_search_paths, ModuleKind, ModuleName,
ModuleSearchPath, ModuleSearchPathKind,
ResolvedSearchPathOrder,
};
use crate::symbols::Dependency;

struct TestCase {
temp_dir: tempfile::TempDir,
db: TestDb,

src: ModuleSearchPath,
site_packages: ModuleSearchPath,
src: PathBuf,
site_packages: PathBuf,
}

fn create_resolver() -> std::io::Result<TestCase> {
Expand All @@ -716,16 +776,18 @@ mod tests {
std::fs::create_dir(&src)?;
std::fs::create_dir(&site_packages)?;

let src = ModuleSearchPath::new(src.canonicalize()?, ModuleSearchPathKind::FirstParty);
let site_packages = ModuleSearchPath::new(
site_packages.canonicalize()?,
ModuleSearchPathKind::ThirdParty,
);
let src = src.canonicalize()?;
let site_packages = site_packages.canonicalize()?;

let roots = vec![src.clone(), site_packages.clone()];
let resolved_search_paths = ResolvedSearchPathOrder::new(
vec![],
src.clone(),
Some(site_packages.clone()),
None,
);

let mut db = TestDb::default();
set_module_search_paths(&mut db, roots);
set_module_search_paths(&mut db, resolved_search_paths);

Ok(TestCase {
temp_dir,
Expand All @@ -744,7 +806,7 @@ mod tests {
..
} = create_resolver()?;

let foo_path = src.path().join("foo.py");
let foo_path = src.join("foo.py");
std::fs::write(&foo_path, "print('Hello, world!')")?;

let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap();
Expand All @@ -755,7 +817,7 @@ mod tests {
);

assert_eq!(ModuleName::new("foo"), foo_module.name(&db)?);
assert_eq!(&src, foo_module.path(&db)?.root());
assert_eq!(&src, foo_module.path(&db)?.root().path());
assert_eq!(ModuleKind::Module, foo_module.kind(&db)?);
assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db)?.file()));

Expand All @@ -773,15 +835,15 @@ mod tests {
..
} = create_resolver()?;

let foo_dir = src.path().join("foo");
let foo_dir = src.join("foo");
let foo_path = foo_dir.join("__init__.py");
std::fs::create_dir(&foo_dir)?;
std::fs::write(&foo_path, "print('Hello, world!')")?;

let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap();

assert_eq!(ModuleName::new("foo"), foo_module.name(&db)?);
assert_eq!(&src, foo_module.path(&db)?.root());
assert_eq!(&src, foo_module.path(&db)?.root().path());
assert_eq!(&foo_path, &*db.file_path(foo_module.path(&db)?.file()));

assert_eq!(Some(foo_module), path_to_module(&db, &foo_path)?);
Expand All @@ -801,17 +863,17 @@ mod tests {
..
} = create_resolver()?;

let foo_dir = src.path().join("foo");
let foo_dir = src.join("foo");
let foo_init = foo_dir.join("__init__.py");
std::fs::create_dir(&foo_dir)?;
std::fs::write(&foo_init, "print('Hello, world!')")?;

let foo_py = src.path().join("foo.py");
let foo_py = src.join("foo.py");
std::fs::write(&foo_py, "print('Hello, world!')")?;

let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap();

assert_eq!(&src, foo_module.path(&db)?.root());
assert_eq!(&src, foo_module.path(&db)?.root().path());
assert_eq!(&foo_init, &*db.file_path(foo_module.path(&db)?.file()));
assert_eq!(ModuleKind::Package, foo_module.kind(&db)?);

Expand All @@ -830,14 +892,14 @@ mod tests {
..
} = create_resolver()?;

let foo_stub = src.path().join("foo.pyi");
let foo_py = src.path().join("foo.py");
let foo_stub = src.join("foo.pyi");
let foo_py = src.join("foo.py");
std::fs::write(&foo_stub, "x: int")?;
std::fs::write(&foo_py, "print('Hello, world!')")?;

let foo = resolve_module(&db, ModuleName::new("foo"))?.unwrap();

assert_eq!(&src, foo.path(&db)?.root());
assert_eq!(&src, foo.path(&db)?.root().path());
assert_eq!(&foo_stub, &*db.file_path(foo.path(&db)?.file()));

assert_eq!(Some(foo), path_to_module(&db, &foo_stub)?);
Expand All @@ -855,7 +917,7 @@ mod tests {
..
} = create_resolver()?;

let foo = src.path().join("foo");
let foo = src.join("foo");
let bar = foo.join("bar");
let baz = bar.join("baz.py");

Expand All @@ -866,7 +928,7 @@ mod tests {

let baz_module = resolve_module(&db, ModuleName::new("foo.bar.baz"))?.unwrap();

assert_eq!(&src, baz_module.path(&db)?.root());
assert_eq!(&src, baz_module.path(&db)?.root().path());
assert_eq!(&baz, &*db.file_path(baz_module.path(&db)?.file()));

assert_eq!(Some(baz_module), path_to_module(&db, &baz)?);
Expand Down Expand Up @@ -896,14 +958,14 @@ mod tests {
// two.py
// ```

let parent1 = src.path().join("parent");
let parent1 = src.join("parent");
let child1 = parent1.join("child");
let one = child1.join("one.py");

std::fs::create_dir_all(child1)?;
std::fs::write(&one, "print('Hello, world!')")?;

let parent2 = site_packages.path().join("parent");
let parent2 = site_packages.join("parent");
let child2 = parent2.join("child");
let two = child2.join("two.py");

Expand Down Expand Up @@ -942,15 +1004,15 @@ mod tests {
// two.py
// ```

let parent1 = src.path().join("parent");
let parent1 = src.join("parent");
let child1 = parent1.join("child");
let one = child1.join("one.py");

std::fs::create_dir_all(&child1)?;
std::fs::write(child1.join("__init__.py"), "print('Hello, world!')")?;
std::fs::write(&one, "print('Hello, world!')")?;

let parent2 = site_packages.path().join("parent");
let parent2 = site_packages.join("parent");
let child2 = parent2.join("child");
let two = child2.join("two.py");

Expand All @@ -977,15 +1039,15 @@ mod tests {
temp_dir: _temp_dir,
} = create_resolver()?;

let foo_src = src.path().join("foo.py");
let foo_site_packages = site_packages.path().join("foo.py");
let foo_src = src.join("foo.py");
let foo_site_packages = site_packages.join("foo.py");

std::fs::write(&foo_src, "")?;
std::fs::write(&foo_site_packages, "")?;

let foo_module = resolve_module(&db, ModuleName::new("foo"))?.unwrap();

assert_eq!(&src, foo_module.path(&db)?.root());
assert_eq!(&src, foo_module.path(&db)?.root().path());
assert_eq!(&foo_src, &*db.file_path(foo_module.path(&db)?.file()));

assert_eq!(Some(foo_module), path_to_module(&db, &foo_src)?);
Expand All @@ -1004,8 +1066,8 @@ mod tests {
..
} = create_resolver()?;

let foo = src.path().join("foo.py");
let bar = src.path().join("bar.py");
let foo = src.join("foo.py");
let bar = src.join("bar.py");

std::fs::write(&foo, "")?;
std::os::unix::fs::symlink(&foo, &bar)?;
Expand All @@ -1015,12 +1077,12 @@ mod tests {

assert_ne!(foo_module, bar_module);

assert_eq!(&src, foo_module.path(&db)?.root());
assert_eq!(&src, foo_module.path(&db)?.root().path());
assert_eq!(&foo, &*db.file_path(foo_module.path(&db)?.file()));

// Bar has a different name but it should point to the same file.

assert_eq!(&src, bar_module.path(&db)?.root());
assert_eq!(&src, bar_module.path(&db)?.root().path());
assert_eq!(foo_module.path(&db)?.file(), bar_module.path(&db)?.file());
assert_eq!(&foo, &*db.file_path(bar_module.path(&db)?.file()));

Expand All @@ -1039,7 +1101,7 @@ mod tests {
..
} = create_resolver()?;

let foo_dir = src.path().join("foo");
let foo_dir = src.join("foo");
let foo_path = foo_dir.join("__init__.py");
let bar_path = foo_dir.join("bar.py");

Expand Down
Loading

0 comments on commit d514307

Please sign in to comment.