diff --git a/src/treeSitterColor.ts b/src/treeSitterColor.ts index a87a075f7..60f9ddadd 100644 --- a/src/treeSitterColor.ts +++ b/src/treeSitterColor.ts @@ -84,7 +84,7 @@ export function colorGo(root: Parser.SyntaxNode, visibleRanges: {start: number, } isRoot(): boolean { - return this.parent == null; + return this.parent === null; } } const colors: {[scope: string]: Parser.SyntaxNode[]} = { @@ -156,11 +156,19 @@ export function colorGo(root: Parser.SyntaxNode, visibleRanges: {start: number, } function scanExpr(x: Parser.SyntaxNode, scope: Scope) { switch (x.type) { + case 'ERROR': + return; case 'func_literal': case 'block': + case 'expression_case_clause': + case 'type_case_clause': + case 'for_statement': + case 'if_statement': + case 'type_switch_statement': scope = new Scope(scope); break; case 'parameter_declaration': + case 'variadic_parameter_declaration': case 'var_spec': case 'const_spec': for (const id of x.namedChildren) { @@ -177,6 +185,13 @@ export function colorGo(root: Parser.SyntaxNode, visibleRanges: {start: number, } } break; + case 'type_switch_guard': + if (x.firstChild!.type === 'expression_list') { + for (const id of x.firstChild!.namedChildren) { + scope.declareLocal(id.text); + } + } + break; case 'inc_statement': case 'dec_statement': scope.modifyLocal(x.firstChild!.text); diff --git a/test/unit/treeSitter.test.ts b/test/unit/treeSitter.test.ts index bc3b0efa9..930f661e3 100644 --- a/test/unit/treeSitter.test.ts +++ b/test/unit/treeSitter.test.ts @@ -1,8 +1,8 @@ import Parser = require('web-tree-sitter'); -import {colorGo} from '../../src/treeSitterColor'; +import { colorGo } from '../../src/treeSitterColor'; import * as assert from 'assert'; -type check = [string, string|{not: string}]; +type check = [string, string | { not: string }]; type TestCase = [string, ...check[]]; const testCases: TestCase[] = [ @@ -12,11 +12,11 @@ const testCases: TestCase[] = [ ], [ `type Foo struct { x int }`, - ['Foo', 'entity.name.type'], ['x', {not: 'variable'}] + ['Foo', 'entity.name.type'], ['x', { not: 'variable' }] ], [ `type Foo interface { GetX() int }`, - ['Foo', 'entity.name.type'], ['int', 'entity.name.type'], ['GetX', {not: 'variable'}] + ['Foo', 'entity.name.type'], ['int', 'entity.name.type'], ['GetX', { not: 'variable' }] ], [ `func f() { x := 1; x := 2 }`, @@ -24,7 +24,7 @@ const testCases: TestCase[] = [ ], [ `func f(foo T) { foo.Foo() }`, - ['Foo', {not: 'entity.name.function'}] + ['Foo', { not: 'entity.name.function' }] ], [ `func f() { Foo() }`, @@ -36,7 +36,7 @@ const testCases: TestCase[] = [ ], [ `import "foo"; func f(foo T) { foo.Foo() }`, - ['Foo', {not: 'entity.name.function'}] + ['Foo', { not: 'entity.name.function' }] ], [ `func f(x other.T) { }`, @@ -48,7 +48,92 @@ const testCases: TestCase[] = [ ], [ `import (foo "foobar"); var _ = foo.Bar()`, - ['foo', {not: 'variable'}], ['Bar', 'entity.name.function'], + ['foo', { not: 'variable' }], ['Bar', 'entity.name.function'], + ], + [ + `func f(a int) int { + switch a { + case 1: + x := 1 + return x + case 2: + x := 2 + return x + } + }`, + ['x', { not: 'markup.underline' }] + ], + [ + `func f(a interface{}) int { + switch a.(type) { + case *int: + x := 1 + return x + case *int: + x := 2 + return x + } + }`, + ['x', { not: 'markup.underline' }] + ], + [ + `func f(a interface{}) int { + for i := range 10 { + print(i) + } + for i := range 10 { + print(i) + } + }`, + ['i', { not: 'markup.underline' }] + ], + [ + `func f(a interface{}) int { + if i := 1; i < 10 { + print(i) + } + if i := 1; i < 10 { + print(i) + } + }`, + ['i', { not: 'markup.underline' }] + ], + [ + `func f(a interface{}) { + switch aa := a.(type) { + case *int: + print(aa) + } + }`, + ['aa', { not: 'variable' }] + ], + [ + `func f() { + switch aa.(type) { + case *int: + print(aa) + } + }`, + ['aa', 'variable'] + ], + [ + `func f(a interface{}) { + switch aa := a.(type) { + case *int: + print(aa) + } + switch aa := a.(type) { + case *int: + print(aa) + } + }`, + ['aa', { not: 'markup.underline' }] + ], + [ + `func f(a ...int) { + print(a) + }`, + ['a', { not: 'variable' }] ], ]; @@ -67,7 +152,7 @@ suite('Syntax coloring', () => { for (const [src, ...expect] of testCases) { test(src, async () => { const tree = (await parser).parse(src); - const found = colorGo(tree.rootNode, [{start: 0, end: tree.rootNode.endPosition.row}]); + const found = colorGo(tree.rootNode, [{ start: 0, end: tree.rootNode.endPosition.row }]); const foundMap = new Map>(); for (const scope of Object.keys(found)) { for (const node of found[scope]) {