Skip to content

Commit 6e498ce

Browse files
committed
feat(gazelle): Include types/stubs packages that have them
1 parent f88e083 commit 6e498ce

File tree

7 files changed

+70
-5
lines changed

7 files changed

+70
-5
lines changed

gazelle/modules_mapping/BUILD.bazel

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
load("@rules_python//python:defs.bzl", "py_binary")
1+
load("@rules_python//python:defs.bzl", "py_binary", "py_test")
22

33
# gazelle:exclude *.py
44

@@ -8,6 +8,15 @@ py_binary(
88
visibility = ["//visibility:public"],
99
)
1010

11+
py_test(
12+
name = "test_generator",
13+
srcs = ["test_generator.py"],
14+
data = glob(["testdata/**"]),
15+
imports = ["."],
16+
main = "test_generator.py",
17+
deps = [":generator"],
18+
)
19+
1120
filegroup(
1221
name = "distribution",
1322
srcs = glob(["**"]),

gazelle/modules_mapping/generator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ def __init__(self, stderr, output_file, excluded_patterns):
3535
# dig_wheel analyses the wheel .whl file determining the modules it provides
3636
# by looking at the directory structure.
3737
def dig_wheel(self, whl):
38+
# Skip stubs and types wheels.
39+
wheel_name = get_wheel_name(whl)
40+
if wheel_name.endswith(("_stubs", "_types")):
41+
self.mapping[wheel_name.lower()] = wheel_name.lower()
42+
return
3843
with zipfile.ZipFile(whl, "r") as zip_file:
3944
for path in zip_file.namelist():
4045
if is_metadata(path):
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import pathlib
2+
import unittest
3+
4+
from generator import Generator
5+
6+
7+
class GeneratorTest(unittest.TestCase):
8+
def test_generator(self):
9+
whl = pathlib.Path(
10+
pathlib.Path(__file__).parent, "testdata", "pytest-7.1.1-py3-none-any.whl"
11+
)
12+
gen = Generator(None, None, {})
13+
gen.dig_wheel(whl)
14+
self.assertLessEqual(
15+
{
16+
"_pytest": "pytest",
17+
"_pytest.__init__": "pytest",
18+
"_pytest._argcomplete": "pytest",
19+
"_pytest.config.argparsing": "pytest",
20+
}.items(),
21+
gen.mapping.items(),
22+
)
23+
24+
def test_stub_generator(self):
25+
whl = pathlib.Path(
26+
pathlib.Path(__file__).parent,
27+
"testdata",
28+
"django_types-0.15.0-py3-none-any.whl",
29+
)
30+
gen = Generator(None, None, {})
31+
gen.dig_wheel(whl)
32+
self.assertLessEqual(
33+
{
34+
"django_types": "django_types",
35+
}.items(),
36+
gen.mapping.items(),
37+
)
38+
39+
40+
if __name__ == "__main__":
41+
unittest.main()
Binary file not shown.
290 KB
Binary file not shown.

gazelle/python/resolve.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,18 @@ func (py *Resolver) Resolve(
189189
continue MODULES_LOOP
190190
}
191191
} else {
192-
if dep, ok := cfg.FindThirdPartyDependency(moduleName); ok {
192+
if dep, distributionName, ok := cfg.FindThirdPartyDependency(moduleName); ok {
193193
deps.Add(dep)
194+
// Add the type and stub dependencies if they exist.
195+
typeModule := fmt.Sprintf("%s_types", strings.ToLower(distributionName))
196+
if dep, _, ok := cfg.FindThirdPartyDependency(typeModule); ok {
197+
deps.Add(dep)
198+
199+
}
200+
stubModule := fmt.Sprintf("%s_stubs", strings.ToLower(distributionName))
201+
if dep, _, ok := cfg.FindThirdPartyDependency(stubModule); ok {
202+
deps.Add(dep)
203+
}
194204
if explainDependency == dep {
195205
log.Printf("Explaining dependency (%s): "+
196206
"in the target %q, the file %q imports %q at line %d, "+

gazelle/pythonconfig/pythonconfig.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ func (c *Config) SetGazelleManifest(gazelleManifest *manifest.Manifest) {
278278
// FindThirdPartyDependency scans the gazelle manifests for the current config
279279
// and the parent configs up to the root finding if it can resolve the module
280280
// name.
281-
func (c *Config) FindThirdPartyDependency(modName string) (string, bool) {
281+
func (c *Config) FindThirdPartyDependency(modName string) (string, string, bool) {
282282
for currentCfg := c; currentCfg != nil; currentCfg = currentCfg.parent {
283283
if currentCfg.gazelleManifest != nil {
284284
gazelleManifest := currentCfg.gazelleManifest
@@ -291,11 +291,11 @@ func (c *Config) FindThirdPartyDependency(modName string) (string, bool) {
291291
}
292292

293293
lbl := currentCfg.FormatThirdPartyDependency(distributionRepositoryName, distributionName)
294-
return lbl.String(), true
294+
return lbl.String(), distributionName, true
295295
}
296296
}
297297
}
298-
return "", false
298+
return "", "", false
299299
}
300300

301301
// AddIgnoreFile adds a file to the list of ignored files for a given package.

0 commit comments

Comments
 (0)