Skip to content

Commit

Permalink
add dynamic evaluation of __all__ as a list comprehension
Browse files Browse the repository at this point in the history
  • Loading branch information
dyc3 committed May 18, 2021
1 parent 816c1c3 commit 4afb034
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 0 deletions.
94 changes: 94 additions & 0 deletions packages/pyright-internal/src/analyzer/binder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ import {
ImportFromNode,
IndexNode,
LambdaNode,
ListComprehensionForNode,
ListComprehensionIfNode,
ListComprehensionNode,
MatchNode,
MemberAccessNode,
Expand Down Expand Up @@ -699,6 +701,98 @@ export class Binder extends ParseTreeWalker {
listEntryNode.strings[0].nodeType === ParseNodeType.String
) {
this._dunderAllNames!.push(listEntryNode.strings[0].value);
} else if (listEntryNode.nodeType === ParseNodeType.ListComprehension) {
// dynamic evaluation of __all__
let forNode: ListComprehensionForNode | undefined = undefined, ifNode: ListComprehensionIfNode | undefined = undefined;
for (let comprehensionNode of listEntryNode.comprehensions) {
if (comprehensionNode.nodeType === ParseNodeType.ListComprehensionFor) {
forNode = comprehensionNode;
} else if (comprehensionNode.nodeType === ParseNodeType.ListComprehensionIf) {
ifNode = comprehensionNode;
} else {
emitDunderAllWarning = true;
}
}

if (!forNode) {
emitDunderAllWarning = true;
return
}

if (forNode.iterableExpression.nodeType === ParseNodeType.Call) {
let call = forNode.iterableExpression;
if (call.leftExpression.nodeType === ParseNodeType.Name) {
if (call.leftExpression.value === "dir") {
this._currentScope.symbolTable.forEach((value, key) => {
this._dunderAllNames!.push(key);
})
} else {
emitDunderAllWarning = true;
}
} else {
emitDunderAllWarning = true;
}
} else {
emitDunderAllWarning = true;
}

let targetName: string;
if (forNode.targetExpression.nodeType === ParseNodeType.Name) {
targetName = forNode.targetExpression.value;
} else {
emitDunderAllWarning = true;
return;
}

function createFilterFunction(ex: ExpressionNode): ((name: string) => boolean) {
if (ex.nodeType === ParseNodeType.UnaryOperation) {
if (ex.operator === OperatorType.Not) {
return (name: string) => !createFilterFunction(ex.expression);
} else {
emitDunderAllWarning = true;
}
} else if (ex.nodeType === ParseNodeType.Call) {
if (ex.leftExpression.nodeType === ParseNodeType.MemberAccess) {
let member = ex.leftExpression;
if (member.leftExpression.nodeType === ParseNodeType.Name && member.leftExpression.value === targetName) {
let firstArg = ex.arguments[0].valueExpression;
if (firstArg.nodeType === ParseNodeType.StringList) {
if (["startswith", "endswith"].includes(member.memberName.value)) {
let funcname = member.memberName.value;
let prefix = firstArg.strings[0].value;
return (name: string) => {
const pyToJsFuncMap = new Map<string, ((name: string) => boolean)>([
["startswith", name.startsWith],
["endswith", name.endsWith],
]);
return pyToJsFuncMap.get(funcname)!(prefix);
};
} else if (member.memberName.value === "startswith") {
let prefix = firstArg.strings[0].value
return (name: string) => name.startsWith(prefix);
} else {
emitDunderAllWarning = true;
}
} else {
emitDunderAllWarning = true;
}
} else {
emitDunderAllWarning = true;
}
} else {
emitDunderAllWarning = true;
}
} else {
emitDunderAllWarning = true;
}
return (name: string) => true;
}

if (ifNode) {
let filterFunc = createFilterFunction(ifNode.testExpression);
this._dunderAllNames = this._dunderAllNames?.filter(filterFunc);
}

} else {
emitDunderAllWarning = true;
}
Expand Down
2 changes: 2 additions & 0 deletions packages/pyright-internal/src/parser/parseNodes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ export type ParseNodeArray = (ParseNode | undefined)[];
export interface ModuleNode extends ParseNodeBase {
readonly nodeType: ParseNodeType.Module;
statements: StatementNode[];
dunderAllNames: string[];
}

export namespace ModuleNode {
Expand All @@ -162,6 +163,7 @@ export namespace ModuleNode {
nodeType: ParseNodeType.Module,
id: _nextNodeId++,
statements: [],
dunderAllNames: [],
};

return node;
Expand Down
1 change: 1 addition & 0 deletions packages/pyright-internal/src/tests/samples/dunderAll1.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
__all__ += ["bar"]
__all__ += mock.__all__
__all__.extend(mock.__all__)
__all__ = [x for x in dir() if not x.startswith("_")]


my_string = "foo"
Expand Down
12 changes: 12 additions & 0 deletions packages/pyright-internal/src/tests/samples/dunderAll2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# This sample tests dynamic __all__ assignments based on dir()

# pyright: reportMissingModuleSource=false

from typing import Any

__all__: Any

foo = 42
_bar = "asdf"

__all__ = [x for x in dir()]
12 changes: 12 additions & 0 deletions packages/pyright-internal/src/tests/samples/dunderAll3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# This sample tests dynamic __all__ assignments based on dir()

# pyright: reportMissingModuleSource=false

from typing import Any

__all__: Any

foo = 42
_bar = "asdf"

__all__ = [x for x in dir() if not x.startswith("_")]
11 changes: 11 additions & 0 deletions packages/pyright-internal/src/tests/typeEvaluator4.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,17 @@ test('DunderAll1', () => {
TestUtils.validateResults(analysisResults, 0, 0);
});

test('DunderAll2', () => {
let analysisResults = TestUtils.typeAnalyzeSampleFiles(['dunderAll2.py']);
expect(analysisResults[0].parseResults?.parseTree.dunderAllNames).toContain("foo");
expect(analysisResults[0].parseResults?.parseTree.dunderAllNames).toContain("_bar");
})

test('DunderAll3', () => {
let analysisResults = TestUtils.typeAnalyzeSampleFiles(['dunderAll3.py']);
expect(analysisResults[0].parseResults?.parseTree.dunderAllNames).not.toContain("_bar");
})

test('Overload1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['overload1.py']);
TestUtils.validateResults(analysisResults, 2);
Expand Down

0 comments on commit 4afb034

Please sign in to comment.