Skip to content

Commit

Permalink
feat: request cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
mhuisi committed Feb 12, 2025
1 parent 07b0e5b commit 0694047
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 68 deletions.
1 change: 1 addition & 0 deletions src/Lean/Server/CodeActions/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def handleCodeAction (params : CodeActionParams) : RequestM (RequestTask (Array
let caps ← names.mapM evalCodeActionProvider
return (← builtinCodeActionProviders.get).toList.toArray ++ Array.zip names caps
caps.flatMapM fun (providerName, cap) => do
RequestM.checkCancelled
let cas ← cap params snap
cas.mapIdxM fun i lca => do
if lca.lazy?.isNone then return lca.eager
Expand Down
4 changes: 3 additions & 1 deletion src/Lean/Server/Completion.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Leonardo de Moura, Marc Huisinga
-/
prelude
import Lean.Server.Completion.CompletionCollectors
import Lean.Server.RequestCancellation
import Std.Data.HashMap

namespace Lean.Server.Completion
Expand Down Expand Up @@ -61,11 +62,12 @@ partial def find?
(cmdStx : Syntax)
(infoTree : InfoTree)
(caps : ClientCapabilities)
: IO CompletionList := do
: CancellableM CompletionList := do
let prioritizedPartitions := findPrioritizedCompletionPartitionsAt fileMap hoverPos cmdStx infoTree
let mut allCompletions := #[]
for partition in prioritizedPartitions do
for (i, completionInfoPos) in partition do
CancellableM.checkCancelled
let completions : Array ScoredCompletionItem ←
match i.info with
| .id stx id danglingDot lctx .. =>
Expand Down
39 changes: 27 additions & 12 deletions src/Lean/Server/Completion/CompletionCollectors.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import Lean.Data.FuzzyMatching
import Lean.Elab.Tactic.Doc
import Lean.Server.Completion.CompletionResolution
import Lean.Server.Completion.EligibleHeaderDecls
import Lean.Server.RequestCancellation

namespace Lean.Server.Completion
open Elab
Expand Down Expand Up @@ -36,7 +37,7 @@ section Infrastructure
Monad used for completion computation that allows modifying a completion `State` and reading
`CompletionParams`.
-/
private abbrev M := ReaderT Context $ StateRefT State MetaM
private abbrev M := ReaderT Context $ StateRefT State $ CancellableT MetaM

/-- Adds a new completion item to the state in `M`. -/
private def addItem
Expand Down Expand Up @@ -114,10 +115,13 @@ section Infrastructure
(ctx : ContextInfo)
(lctx : LocalContext)
(x : M Unit)
: IO (Array ScoredCompletionItem) :=
ctx.runMetaM lctx do
let (_, s) ← x.run ⟨params, completionInfoPos⟩ |>.run {}
return s.items
: CancellableM (Array ScoredCompletionItem) := do
let tk ← read
let r ← ctx.runMetaM lctx do
x.run ⟨params, completionInfoPos⟩ |>.run {} |>.run tk
match r with
| .error _ => throw .requestCancelled
| .ok (_, s) => return s.items

end Infrastructure

Expand Down Expand Up @@ -161,6 +165,16 @@ section Utils
return fuzzyMatchScoreWithThreshold? s₁ s₂ |>.map (declName, · / (p₂.getNumParts + 1).toFloat)
return none

private def forEligibleDeclsWithCancellationM [Monad m] [MonadEnv m]
[MonadLiftT (ST IO.RealWorld) m] [MonadCancellable m] [MonadLiftT IO m]
(f : Name → ConstantInfo → m PUnit) : m PUnit := do
let _ ← StateT.run (s := 0) <| forEligibleDeclsM fun decl ci => do
modify (· + 1)
if (← get) >= 10000 then
RequestCancellation.check
set <| 0
f decl ci

end Utils

section IdCompletionUtils
Expand Down Expand Up @@ -349,7 +363,7 @@ private def idCompletionCore
addUnresolvedCompletionItem localDecl.userName (.fvar localDecl.fvarId) (kind := CompletionItemKind.variable) score
-- search for matches in the environment
let env ← getEnv
forEligibleDeclsM fun declName c => do
forEligibleDeclsWithCancellationM fun declName c => do
let bestMatch? ← (·.2) <$> StateT.run (s := none) do
let matchUsingNamespace (ns : Name) : StateT (Option (Name × Float)) M Unit := do
let some (label, score) ← matchDecl? ns id danglingDot declName
Expand Down Expand Up @@ -380,6 +394,7 @@ private def idCompletionCore
matchUsingNamespace Name.anonymous
if let some (bestLabel, bestScore) := bestMatch? then
addUnresolvedCompletionItem bestLabel (.const declName) (← getCompletionKindForDecl c) bestScore
RequestCancellation.check
let matchAlias (ns : Name) (alias : Name) : Option Float :=
-- Recall that aliases may not be atomic and include the namespace where they were created.
if ns.isPrefixOf alias then
Expand Down Expand Up @@ -434,7 +449,7 @@ def idCompletion
(id : Name)
(hoverInfo : HoverInfo)
(danglingDot : Bool)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx lctx do
idCompletionCore ctx stx id hoverInfo danglingDot

Expand All @@ -443,7 +458,7 @@ def dotCompletion
(completionInfoPos : Nat)
(ctx : ContextInfo)
(info : TermInfo)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx info.lctx do
let nameSet ← try
getDotCompletionTypeNames (← instantiateMVars (← inferType info.expr))
Expand All @@ -452,7 +467,7 @@ def dotCompletion
if nameSet.isEmpty then
return

forEligibleDeclsM fun declName c => do
forEligibleDeclsWithCancellationM fun declName c => do
let unnormedTypeName := declName.getPrefix
if ! nameSet.contains unnormedTypeName then
return
Expand All @@ -471,7 +486,7 @@ def dotIdCompletion
(lctx : LocalContext)
(id : Name)
(expectedType? : Option Expr)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx lctx do
let some expectedType := expectedType?
| return ()
Expand All @@ -485,7 +500,7 @@ def dotIdCompletion
catch _ =>
pure RBTree.empty

forEligibleDeclsM fun declName c => do
forEligibleDeclsWithCancellationM fun declName c => do
let unnormedTypeName := declName.getPrefix
if ! nameSet.contains unnormedTypeName then
return
Expand Down Expand Up @@ -513,7 +528,7 @@ def fieldIdCompletion
(lctx : LocalContext)
(id : Option Name)
(structName : Name)
: IO (Array ScoredCompletionItem) :=
: CancellableM (Array ScoredCompletionItem) :=
runM params completionInfoPos ctx lctx do
let idStr := id.map (·.toString) |>.getD ""
let fieldNames := getStructureFieldsFlattened (← getEnv) structName (includeSubobjectFields := false)
Expand Down
10 changes: 8 additions & 2 deletions src/Lean/Server/FileWorker.lean
Original file line number Diff line number Diff line change
Expand Up @@ -543,14 +543,14 @@ section NotificationHandling
let newDocText := foldDocumentChanges changes oldDoc.meta.text
updateDocument ⟨docId.uri, newVersion, newDocText, oldDoc.meta.dependencyBuildMode⟩
for (_, r) in st.pendingRequests do
r.cancelTk.cancel .edit
r.cancelTk.cancelByEdit


def handleCancelRequest (p : CancelParams) : WorkerM Unit := do
let st ← get
let some r := st.pendingRequests.find? p.id
| return
r.cancelTk.cancel .cancelRequest
r.cancelTk.cancelByCancelRequest
set <| { st with pendingRequests := st.pendingRequests.erase p.id }

/--
Expand Down Expand Up @@ -741,6 +741,12 @@ section MessageHandling
pure <| Task.pure <| .ok ()
| Except.ok t => (IO.mapTask · t) fun
| Except.ok r => do
if ← cancelTk.wasCancelledByCancelRequest then
-- Try not to emit a partial response if this request was cancelled.
-- Clients usually discard responses for requests that they cancelled anyways,
-- but it's still good to send less over the wire in this case.
emitResponse ctx (isComplete := false) <| RequestError.requestCancelled.toLspResponseError id
return
emitResponse ctx (isComplete := r.isComplete) <| .response id (toJson r.response)
| Except.error e =>
emitResponse ctx (isComplete := false) <| e.toLspResponseError id
Expand Down
3 changes: 1 addition & 2 deletions src/Lean/Server/FileWorker/InlayHints.lean
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def handleInlayHints (_ : InlayHintParams) (s : InlayHintState) :
| some lastEditTimestamp =>
let timeSinceLastEditMs := timestamp - lastEditTimestamp
inlayHintEditDelayMs - timeSinceLastEditMs
let (snaps, _, isComplete) ← ctx.doc.cmdSnaps.getFinishedPrefixWithConsistentLatency editDelayMs.toUInt32 (cancelTk? := ctx.cancelTk.truncatedTask)
let (snaps, _, isComplete) ← ctx.doc.cmdSnaps.getFinishedPrefixWithConsistentLatency editDelayMs.toUInt32 (cancelTk? := ctx.cancelTk.cancellationTask)
let finishedRange? : Option String.Range := do
return ⟨⟨0⟩, ← List.max? <| snaps.map (fun s => s.endPos)⟩
let oldInlayHints :=
Expand All @@ -143,7 +143,6 @@ def handleInlayHints (_ : InlayHintParams) (s : InlayHintState) :
let lspInlayHints ← inlayHints.mapM (·.toLspInlayHint srcSearchPath ctx.doc.meta.text)
let r := { response := lspInlayHints, isComplete }
let s := { s with oldInlayHints := inlayHints }
RequestM.checkCanceled
return (r, s)

def handleInlayHintsDidChange (p : DidChangeTextDocumentParams)
Expand Down
17 changes: 10 additions & 7 deletions src/Lean/Server/FileWorker/RequestHandling.lean
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,14 @@ partial def handleDocumentSymbol (_ : DocumentSymbolParams)
let t := doc.cmdSnaps.waitAll
mapTask t fun (snaps, _) => do
let mut stxs := snaps.map (·.stx)
return { syms := toDocumentSymbols doc.meta.text stxs #[] [] }
return { syms := toDocumentSymbols doc.meta.text stxs #[] [] }
where
toDocumentSymbols (text : FileMap) (stxs : List Syntax)
(syms : Array DocumentSymbol) (stack : List NamespaceEntry) :
Array DocumentSymbol :=
RequestM (Array DocumentSymbol) := do
RequestM.checkCancelled
match stxs with
| [] => stack.foldl (fun syms entry => entry.finish text syms none) syms
| [] => return stack.foldl (fun syms entry => entry.finish text syms none) syms
| stx::stxs => match stx with
| `(namespace $id) =>
let entry := { name := id.getId.componentsRev, stx, selection := id, prevSiblings := syms }
Expand All @@ -411,9 +412,9 @@ where
let syms := entry.finish text syms stx
popStack (n - entry.name.length) syms stack
popStack (id.map (·.getId.getNumParts) |>.getD 1) syms stack
| _ => Id.run do
| _ => do
unless stx.isOfKind ``Lean.Parser.Command.declaration do
return toDocumentSymbols text stxs syms stack
return toDocumentSymbols text stxs syms stack
if let some stxRange := stx.getRange? then
let (name, selection) := match stx with
| `($_:declModifiers $_:attrKind instance $[$np:namedPrio]? $[$id$[.{$ls,*}]?]? $sig:declSig $_) =>
Expand All @@ -431,7 +432,7 @@ where
range := stxRange.toLspRange text
selectionRange := selRange.toLspRange text
}
return toDocumentSymbols text stxs (syms.push sym) stack
return toDocumentSymbols text stxs (syms.push sym) stack
toDocumentSymbols text stxs syms stack

partial def handleFoldingRange (_ : FoldingRangeParams)
Expand All @@ -450,7 +451,9 @@ partial def handleFoldingRange (_ : FoldingRangeParams)
if let (_, start)::rest := sections then
addRange text FoldingRangeKind.region start text.source.endPos
addRanges text rest []
| stx::stxs => match stx with
| stx::stxs => do
RequestM.checkCancelled
match stx with
| `(namespace $id) =>
addRanges text ((id.getId.getNumParts, stx.getPos?)::sections) stxs
| `(section $(id)?) =>
Expand Down
6 changes: 4 additions & 2 deletions src/Lean/Server/FileWorker/SemanticHighlighting.lean
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,12 @@ def handleSemanticTokens (beginPos : String.Pos) (endPos? : Option String.Pos)
-- for the full file before sending a response. This means that the response will be incomplete,
-- which we mitigate by regularly sending `workspace/semanticTokens/refresh` requests in the
-- `FileWorker` to tell the client to re-compute the semantic tokens.
let (snaps, _, isComplete) ← doc.cmdSnaps.getFinishedPrefixWithTimeout 3000 (cancelTk? := ctx.cancelTk.truncatedTask)
let (snaps, _, isComplete) ← doc.cmdSnaps.getFinishedPrefixWithTimeout 3000 (cancelTk? := ctx.cancelTk.cancellationTask)
asTask <| do
return { response := ← run doc snaps, isComplete }
| some endPos =>
let t := doc.cmdSnaps.waitUntil (·.endPos >= endPos)
mapTask t fun (snaps, _) => do
RequestM.checkCanceled
return { response := ← run doc snaps, isComplete := true }
where
run doc snaps : RequestM SemanticTokens := do
Expand All @@ -164,8 +163,11 @@ where
let syntaxBasedSemanticTokens := collectSyntaxBasedSemanticTokens s.stx
let infoBasedSemanticTokens := collectInfoBasedSemanticTokens s.infoTree
leanSemanticTokens := leanSemanticTokens ++ syntaxBasedSemanticTokens ++ infoBasedSemanticTokens
RequestM.checkCancelled
let absoluteLspSemanticTokens := computeAbsoluteLspSemanticTokens doc.meta.text beginPos endPos? leanSemanticTokens
RequestM.checkCancelled
let absoluteLspSemanticTokens := filterDuplicateSemanticTokens absoluteLspSemanticTokens
RequestM.checkCancelled
let semanticTokens := computeDeltaLspSemanticTokens absoluteLspSemanticTokens
return semanticTokens

Expand Down
77 changes: 77 additions & 0 deletions src/Lean/Server/RequestCancellation.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Marc Huisinga
-/
prelude
import Init.System.Promise

namespace Lean.Server

structure RequestCancellationToken where
cancelledByCancelRequest : IO.Ref Bool
cancelledByEdit : IO.Ref Bool
cancellationPromise : IO.Promise Unit

namespace RequestCancellationToken

def new : IO RequestCancellationToken := do
return {
cancelledByCancelRequest := ← IO.mkRef false
cancelledByEdit := ← IO.mkRef false
cancellationPromise := ← IO.Promise.new
}

def cancelByCancelRequest (tk : RequestCancellationToken) : IO Unit := do
tk.cancelledByCancelRequest.set true
tk.cancellationPromise.resolve ()

def cancelByEdit (tk : RequestCancellationToken) : IO Unit := do
tk.cancelledByEdit.set true
tk.cancellationPromise.resolve ()

def cancellationTask (tk : RequestCancellationToken) : Task Unit :=
tk.cancellationPromise.result!

def wasCancelledByCancelRequest (tk : RequestCancellationToken) : IO Bool :=
tk.cancelledByCancelRequest.get

def wasCancelledByEdit (tk : RequestCancellationToken) : IO Bool := do
tk.cancelledByEdit.get

end RequestCancellationToken

structure RequestCancellation where

def RequestCancellation.requestCancelled : RequestCancellation := {}

abbrev CancellableT m := ReaderT RequestCancellationToken (ExceptT RequestCancellation m)
abbrev CancellableM := CancellableT IO

def CancellableT.run (tk : RequestCancellationToken) (x : CancellableT m α) :
m (Except RequestCancellation α) :=
x tk

def CancellableM.run (tk : RequestCancellationToken) (x : CancellableM α) :
IO (Except RequestCancellation α) :=
CancellableT.run tk x

def CancellableT.checkCancelled [Monad m] [MonadLiftT IO m] : CancellableT m Unit := do
let tk ← read
if ← tk.wasCancelledByCancelRequest then
throw .requestCancelled

def CancellableM.checkCancelled : CancellableM Unit :=
CancellableT.checkCancelled

class MonadCancellable (m : TypeType v) where
checkCancelled : m PUnit

instance (m n) [MonadLift m n] [MonadCancellable m] : MonadCancellable n where
checkCancelled := liftM (MonadCancellable.checkCancelled : m PUnit)

instance [Monad m] [MonadLiftT IO m] : MonadCancellable (CancellableT m) where
checkCancelled := CancellableT.checkCancelled

def RequestCancellation.check [MonadCancellable m] : m Unit :=
MonadCancellable.checkCancelled
Loading

0 comments on commit 0694047

Please sign in to comment.