Skip to content

Commit

Permalink
Support #allow-multiple! predicate to enable multiple content ranges (
Browse files Browse the repository at this point in the history
  • Loading branch information
pokey authored Jun 6, 2023
1 parent 0b7756e commit f759734
Show file tree
Hide file tree
Showing 9 changed files with 158 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ export interface QueryCapture {

/** The range of the capture. */
readonly range: Range;

/** Whether it is ok for the same capture to appear multiple times with the
* same domain. If set to `true`, then the scope handler should merge all
* captures with the same name and domain into a single scope with multiple
* content ranges. */
readonly allowMultiple: boolean;
}

/**
Expand All @@ -40,6 +46,7 @@ export interface MutableQueryCapture extends QueryCapture {
readonly node: SyntaxNode;

range: Range;
allowMultiple: boolean;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ export class TreeSitterQuery {
name,
node,
range: getNodeRange(node),
allowMultiple: false,
})),
}),
)
Expand Down Expand Up @@ -108,6 +109,7 @@ export class TreeSitterQuery {
range: captures
.map(({ range }) => range)
.reduce((accumulator, range) => range.union(accumulator)),
allowMultiple: captures.some((capture) => capture.allowMultiple),
};
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import assert = require("assert");

interface TestCase {
name: string;
captures: QueryCapture[];
captures: Omit<QueryCapture, "allowMultiple">[];
isValid: boolean;
expectedErrorMessageIds: string[];
}
Expand Down Expand Up @@ -188,7 +188,13 @@ suite("checkCaptureStartEnd", () => {
},
};

const result = checkCaptureStartEnd(testCase.captures, messages);
const result = checkCaptureStartEnd(
testCase.captures.map((capture) => ({
...capture,
allowMultiple: false,
})),
messages,
);
assert(result === testCase.isValid);
assert.deepStrictEqual(actualErrorIds, testCase.expectedErrorMessageIds);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,23 @@ class ChildRange extends QueryPredicateOperator<ChildRange> {
}
}

class AllowMultiple extends QueryPredicateOperator<AllowMultiple> {
name = "allow-multiple!" as const;
schema = z.tuple([q.node]);

run(nodeInfo: MutableQueryCapture) {
nodeInfo.allowMultiple = true;

return true;
}
}

export const queryPredicateOperators = [
new NotType(),
new NotParentType(),
new IsNthChild(),
new StartPosition(),
new EndPosition(),
new ChildRange(),
new AllowMultiple(),
];
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ export abstract class BaseTreeSitterScopeHandler extends BaseScopeHandler {
const scopes = this.query
.matches(document, start, end)
.map((match) => this.matchToScope(editor, match))
.filter((scope): scope is TargetScope => scope != null)
.filter((scope): scope is ExtendedTargetScope => scope != null)
.sort((a, b) => compareTargetScopes(direction, position, a, b));

// Merge scopes that have the same domain into a single scope with multiple
Expand All @@ -56,11 +56,23 @@ export abstract class BaseTreeSitterScopeHandler extends BaseScopeHandler {

return {
...equivalentScopes[0],

getTargets(isReversed: boolean) {
return uniqWith(
const targets = uniqWith(
equivalentScopes.flatMap((scope) => scope.getTargets(isReversed)),
(a, b) => a.isEqual(b),
);

if (
targets.length > 1 &&
!equivalentScopes.every((scope) => scope.allowMultiple)
) {
throw Error(
"Please use #allow-multiple! predicate in your query to allow multiple matches for this scope type",
);
}

return targets;
},
};
},
Expand All @@ -78,7 +90,11 @@ export abstract class BaseTreeSitterScopeHandler extends BaseScopeHandler {
protected abstract matchToScope(
editor: TextEditor,
match: QueryMatch,
): TargetScope | undefined;
): ExtendedTargetScope | undefined;
}

export interface ExtendedTargetScope extends TargetScope {
allowMultiple: boolean;
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import { ScopeType, SimpleScopeType, TextEditor } from "@cursorless/common";
import { TreeSitterQuery } from "../../../../languages/TreeSitterQuery";
import { PlainTarget } from "../../../targets";
import { TargetScope } from "../scope.types";
import { BaseTreeSitterScopeHandler } from "./BaseTreeSitterScopeHandler";
import { getCaptureRangeByName, getRelatedRange } from "./captureUtils";
import { QueryMatch } from "../../../../languages/TreeSitterQuery/QueryCapture";
import { PlainTarget } from "../../../targets";
import {
BaseTreeSitterScopeHandler,
ExtendedTargetScope,
} from "./BaseTreeSitterScopeHandler";
import { getRelatedCapture, getRelatedRange } from "./captureUtils";

/** Scope handler to be used for iteration scopes of tree-sitter scope types */
export class TreeSitterIterationScopeHandler extends BaseTreeSitterScopeHandler {
Expand All @@ -30,26 +32,26 @@ export class TreeSitterIterationScopeHandler extends BaseTreeSitterScopeHandler
protected matchToScope(
editor: TextEditor,
match: QueryMatch,
): TargetScope | undefined {
): ExtendedTargetScope | undefined {
const scopeTypeType = this.iterateeScopeType.type;

const contentRange = getRelatedRange(match, scopeTypeType, "iteration")!;
const capture = getRelatedCapture(match, scopeTypeType, "iteration", false);

if (contentRange == null) {
if (capture == null) {
// This capture was for some unrelated scope type
return undefined;
}

const { range: contentRange, allowMultiple } = capture;

const domain =
getCaptureRangeByName(
match,
`${scopeTypeType}.iteration.domain`,
`_.iteration.domain`,
) ?? contentRange;
getRelatedRange(match, scopeTypeType, "iteration.domain", false) ??
contentRange;

return {
editor,
domain,
allowMultiple,
getTargets: (isReversed) => [
new PlainTarget({
editor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ import { SimpleScopeType, TextEditor } from "@cursorless/common";
import { TreeSitterQuery } from "../../../../languages/TreeSitterQuery";
import { QueryMatch } from "../../../../languages/TreeSitterQuery/QueryCapture";
import ScopeTypeTarget from "../../../targets/ScopeTypeTarget";
import { TargetScope } from "../scope.types";
import { CustomScopeType } from "../scopeHandler.types";
import { BaseTreeSitterScopeHandler } from "./BaseTreeSitterScopeHandler";
import {
BaseTreeSitterScopeHandler,
ExtendedTargetScope,
} from "./BaseTreeSitterScopeHandler";
import { TreeSitterIterationScopeHandler } from "./TreeSitterIterationScopeHandler";
import { getCaptureRangeByName, getRelatedRange } from "./captureUtils";
import { findCaptureByName, getRelatedRange } from "./captureUtils";

/**
* Handles scopes that are implemented using tree-sitter.
Expand All @@ -33,38 +35,48 @@ export class TreeSitterScopeHandler extends BaseTreeSitterScopeHandler {
protected matchToScope(
editor: TextEditor,
match: QueryMatch,
): TargetScope | undefined {
): ExtendedTargetScope | undefined {
const scopeTypeType = this.scopeType.type;

const contentRange = getCaptureRangeByName(match, scopeTypeType);
const capture = findCaptureByName(match, scopeTypeType);

if (contentRange == null) {
if (capture == null) {
// This capture was for some unrelated scope type
return undefined;
}

const { range: contentRange, allowMultiple } = capture;

const domain =
getRelatedRange(match, scopeTypeType, "domain") ?? contentRange;
getRelatedRange(match, scopeTypeType, "domain", true) ?? contentRange;

const removalRange = getRelatedRange(match, scopeTypeType, "removal");
const removalRange = getRelatedRange(match, scopeTypeType, "removal", true);

const leadingDelimiterRange = getRelatedRange(
match,
scopeTypeType,
"leading",
true,
);

const trailingDelimiterRange = getRelatedRange(
match,
scopeTypeType,
"trailing",
true,
);

const interiorRange = getRelatedRange(match, scopeTypeType, "interior");
const interiorRange = getRelatedRange(
match,
scopeTypeType,
"interior",
true,
);

return {
editor,
domain,
allowMultiple,
getTargets: (isReversed) => [
new ScopeTypeTarget({
scopeTypeType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ import { TreeSitterQuery } from "../../../../languages/TreeSitterQuery";
import { QueryMatch } from "../../../../languages/TreeSitterQuery/QueryCapture";
import { TEXT_FRAGMENT_CAPTURE_NAME } from "../../../../languages/captureNames";
import { PlainTarget } from "../../../targets";
import { TargetScope } from "../scope.types";
import { BaseTreeSitterScopeHandler } from "./BaseTreeSitterScopeHandler";
import { getCaptureRangeByName } from "./captureUtils";
import {
BaseTreeSitterScopeHandler,
ExtendedTargetScope,
} from "./BaseTreeSitterScopeHandler";
import { findCaptureByName } from "./captureUtils";

/** Scope handler to be used for extracting text fragments from the perspective
* of surrounding pairs */
Expand All @@ -28,20 +30,26 @@ export class TreeSitterTextFragmentScopeHandler extends BaseTreeSitterScopeHandl
protected matchToScope(
editor: TextEditor,
match: QueryMatch,
): TargetScope | undefined {
const contentRange = getCaptureRangeByName(
match,
TEXT_FRAGMENT_CAPTURE_NAME,
);
): ExtendedTargetScope | undefined {
const capture = findCaptureByName(match, TEXT_FRAGMENT_CAPTURE_NAME);

if (contentRange == null) {
if (capture == null) {
// This capture was for some unrelated scope type
return undefined;
}

const { range: contentRange, allowMultiple } = capture;

if (allowMultiple) {
throw Error(
"The #allow-multiple! predicate is not supported for text fragments",
);
}

return {
editor,
domain: contentRange,
allowMultiple,
getTargets: (isReversed) => [
new PlainTarget({
editor,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,63 @@
import { QueryMatch } from "../../../../languages/TreeSitterQuery/QueryCapture";

/**
* Gets a capture that is related to the scope. For example, if the scope is
* "class name", the `domain` node would be the containing class.
*
* @param match The match to get the range from
* @param scopeTypeType The type of the scope
* @param relationship The relationship to get the range for, eg "domain", or
* "removal"
* @param matchHasScopeType Set to `true` if this match is known to have a
* capture for the given scope type
* @returns A capture or undefined if no capture was found
*/
export function getRelatedCapture(
match: QueryMatch,
scopeTypeType: string,
relationship: string,
matchHasScopeType: boolean,
) {
if (matchHasScopeType) {
return findCaptureByName(
match,
`${scopeTypeType}.${relationship}`,
`_.${relationship}`,
);
}

return (
findCaptureByName(match, `${scopeTypeType}.${relationship}`) ??
(findCaptureByName(match, scopeTypeType) != null
? findCaptureByName(match, `_.${relationship}`)
: undefined)
);
}

/**
* Gets the range of a node that is related to the scope. For example, if the
* scope is "class name", the `domain` node would be the containing class.
*
* @param match The match to get the range from
* @param scopeTypeType The type of the scope
* @param relationship The relationship to get the range for, eg "domain", or "removal"
* @param relationship The relationship to get the range for, eg "domain", or
* "removal"
* @param matchHasScopeType Set to `true` if this match is known to have a
* capture for the given scope type
* @returns A range or undefined if no range was found
*/

export function getRelatedRange(
match: QueryMatch,
scopeTypeType: string,
relationship: string,
matchHasScopeType: boolean,
) {
return getCaptureRangeByName(
return getRelatedCapture(
match,
`${scopeTypeType}.${relationship}`,
`_.${relationship}`,
);
scopeTypeType,
relationship,
matchHasScopeType,
)?.range;
}

/**
Expand All @@ -30,8 +68,20 @@ export function getRelatedRange(
* @param names The possible names of the capture to get the range for
* @returns A range or undefined if no matching capture was found
*/
export function getCaptureRangeByName(match: QueryMatch, ...names: string[]) {
export function findCaptureRangeByName(match: QueryMatch, ...names: string[]) {
return findCaptureByName(match, ...names)?.range;
}

/**
* Looks in the captures of a match for a capture with one of the given names, and
* returns that capture, or undefined if no matching capture was found
*
* @param match The match to get the range from
* @param names The possible names of the capture to get the range for
* @returns A range or undefined if no matching capture was found
*/
export function findCaptureByName(match: QueryMatch, ...names: string[]) {
return match.captures.find((capture) =>
names.some((name) => capture.name === name),
)?.range;
);
}

0 comments on commit f759734

Please sign in to comment.