Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix NotEqual position issue #325

Merged
merged 1 commit into from
Jun 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions libcst/_nodes/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def _get_token(self) -> str:

@add_slots
@dataclass(frozen=True)
class NotEqual(BaseCompOp):
class NotEqual(BaseCompOp, _BaseOneTokenOp):
"""
A comparison operator that can be used in a :class:`Comparison` expression.

Expand All @@ -691,7 +691,7 @@ def _validate(self) -> None:
if self.value not in ["!=", "<>"]:
raise CSTValidationError("Invalid value for NotEqual node.")

def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "BaseCompOp":
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "NotEqual":
return self.__class__(
whitespace_before=visit_required(
self, "whitespace_before", self.whitespace_before, visitor
Expand All @@ -702,10 +702,8 @@ def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "BaseCompOp":
),
)

def _codegen_impl(self, state: CodegenState) -> None:
self.whitespace_before._codegen(state)
state.add_token(self.value)
self.whitespace_after._codegen(state)
def _get_token(self) -> str:
return self.value


@add_slots
Expand Down
24 changes: 22 additions & 2 deletions libcst/metadata/tests/test_position_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import libcst as cst
from libcst import parse_module
from libcst._batched_visitor import BatchableCSTVisitor
from libcst._visitors import CSTTransformer
from libcst._visitors import CSTVisitor
from libcst.metadata import (
CodeRange,
MetadataWrapper,
Expand Down Expand Up @@ -38,7 +38,7 @@ def test_visitor_provider(self) -> None:
"""
test = self

class DependentVisitor(CSTTransformer):
class DependentVisitor(CSTVisitor):
METADATA_DEPENDENCIES = (PositionProvider,)

def visit_Pass(self, node: cst.Pass) -> None:
Expand All @@ -49,6 +49,26 @@ def visit_Pass(self, node: cst.Pass) -> None:
wrapper = MetadataWrapper(parse_module("pass"))
wrapper.visit(DependentVisitor())

def test_equal_range(self) -> None:
test = self
expected_range = CodeRange((1, 4), (1, 6))

class EqualPositionVisitor(CSTVisitor):
METADATA_DEPENDENCIES = (PositionProvider,)

def visit_Equal(self, node: cst.Equal) -> None:
test.assertEqual(
self.get_metadata(PositionProvider, node), expected_range
)

def visit_NotEqual(self, node: cst.NotEqual) -> None:
test.assertEqual(
self.get_metadata(PositionProvider, node), expected_range
)

MetadataWrapper(parse_module("var == 1")).visit(EqualPositionVisitor())
MetadataWrapper(parse_module("var != 1")).visit(EqualPositionVisitor())

def test_batchable_provider(self) -> None:
test = self

Expand Down