Skip to content

Commit 5f9c11b

Browse files
committed
revamp GraphNode locking
1 parent 78202b8 commit 5f9c11b

File tree

2 files changed

+68
-82
lines changed

2 files changed

+68
-82
lines changed

src/Compiler/Facilities/BuildGraph.fs

Lines changed: 53 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -37,86 +37,66 @@ module GraphNode =
3737
| None -> ()
3838

3939
[<Sealed>]
40-
type GraphNode<'T> private (computation: Async<'T>, cachedResult: ValueOption<'T>, cachedResultNode: Async<'T>) =
40+
type GraphNode<'T> private (compute: unit -> unit, tcs: TaskCompletionSource<'T>, cts: CancellationTokenSource) =
4141

42-
let mutable computation = computation
4342
let mutable requestCount = 0
43+
let mutable started = false
44+
45+
// Any locking we do is for very short synchronous state updates.
46+
let gate = obj
47+
48+
new(computation) =
49+
// Apparently a trick to force GC of the original computation:
50+
let mutable computation = computation
51+
52+
let tcs = TaskCompletionSource<'T>()
53+
let cts = new CancellationTokenSource()
54+
55+
let compute () =
56+
Async.StartWithContinuations(
57+
async {
58+
Thread.CurrentThread.CurrentUICulture <- GraphNode.culture
59+
return! computation
60+
},
61+
(fun result ->
62+
tcs.SetResult result
63+
// Allow GC of the original computation.
64+
computation <- Unchecked.defaultof<_>),
65+
(tcs.SetException),
66+
(ignore >> tcs.SetCanceled),
67+
// This is not a requestor's CancellationToken.
68+
cts.Token)
69+
70+
GraphNode(compute, tcs, cts)
4471

45-
let mutable cachedResult = cachedResult
46-
let mutable cachedResultNode: Async<'T> = cachedResultNode
72+
member _.GetOrComputeValue() =
4773

48-
let isCachedResultNodeNotNull () =
49-
not (obj.ReferenceEquals(cachedResultNode, null))
74+
// Lock for the sake of `started` flag.
75+
let startNew = lock gate <| fun () ->
76+
Interlocked.Increment &requestCount = 1 && not started
77+
78+
// The cancellation of the computation is not governed by the requestor's CancellationToken.
79+
// It will continue to run as long as there are requests.
80+
if startNew then started <- true; compute()
5081

51-
let semaphore = new SemaphoreSlim(1, 1)
82+
async {
83+
try
84+
return! tcs.Task |> Async.AwaitTask
85+
finally
86+
if Interlocked.Decrement &requestCount = 0 then
87+
// All requestors either finished or cancelled, so it is safe to cancel either way.
88+
cts.Cancel()
89+
}
5290

53-
member _.GetOrComputeValue() =
54-
// fast path
55-
if isCachedResultNodeNotNull () then
56-
cachedResultNode
57-
else
58-
async {
59-
Interlocked.Increment(&requestCount) |> ignore
60-
61-
try
62-
let! ct = Async.CancellationToken
63-
64-
// We must set 'taken' before any implicit cancellation checks
65-
// occur, making sure we are under the protection of the 'try'.
66-
// For example, NodeCode's 'try/finally' (TryFinally) uses async.TryFinally which does
67-
// implicit cancellation checks even before the try is entered, as do the
68-
// de-sugaring of 'do!' and other NodeCode constructs.
69-
let mutable taken = false
70-
71-
try
72-
do!
73-
semaphore
74-
.WaitAsync(ct)
75-
.ContinueWith(
76-
(fun _ -> taken <- true),
77-
(TaskContinuationOptions.NotOnCanceled
78-
||| TaskContinuationOptions.NotOnFaulted
79-
||| TaskContinuationOptions.ExecuteSynchronously)
80-
)
81-
|> Async.AwaitTask
82-
83-
match cachedResult with
84-
| ValueSome value -> return value
85-
| _ ->
86-
let tcs = TaskCompletionSource<'T>()
87-
let p = computation
88-
89-
Async.StartWithContinuations(
90-
async {
91-
Thread.CurrentThread.CurrentUICulture <- GraphNode.culture
92-
return! p
93-
},
94-
(fun res ->
95-
cachedResult <- ValueSome res
96-
cachedResultNode <- async.Return res
97-
computation <- Unchecked.defaultof<_>
98-
tcs.SetResult(res)),
99-
(fun ex -> tcs.SetException(ex)),
100-
(fun _ -> tcs.SetCanceled()),
101-
ct
102-
)
103-
104-
return! tcs.Task |> Async.AwaitTask
105-
finally
106-
if taken then
107-
semaphore.Release() |> ignore
108-
finally
109-
Interlocked.Decrement(&requestCount) |> ignore
110-
}
111-
112-
member _.TryPeekValue() = cachedResult
113-
114-
member _.HasValue = cachedResult.IsSome
91+
92+
member _.TryPeekValue() = if tcs.Task.IsCompleted then ValueSome tcs.Task.Result else ValueNone
93+
94+
member _.HasValue = tcs.Task.IsCompleted
11595

11696
member _.IsComputing = requestCount > 0
11797

11898
static member FromResult(result: 'T) =
119-
let nodeResult = async.Return result
120-
GraphNode(nodeResult, ValueSome result, nodeResult)
121-
122-
new(computation) = GraphNode(computation, ValueNone, Unchecked.defaultof<_>)
99+
let tcs = TaskCompletionSource()
100+
tcs.SetResult result
101+
GraphNode(ignore, tcs, new CancellationTokenSource())
102+

tests/FSharp.Compiler.UnitTests/BuildGraphTests.fs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,16 +113,22 @@ module BuildGraphTests =
113113
let ``Many requests to get a value asynchronously should have its computation cleaned up by the GC``() =
114114
let requests = 10000
115115

116-
let graphNode, weak = createNode ()
116+
let weak =
117117

118-
GC.Collect(2, GCCollectionMode.Forced, true)
118+
let graphNode, weak = createNode ()
119+
120+
GC.Collect(2, GCCollectionMode.Forced, true)
119121

120-
Assert.shouldBeTrue weak.IsAlive
122+
Assert.shouldBeTrue weak.IsAlive
121123

122-
Async.RunImmediate(Async.Parallel(Array.init requests (fun _ -> graphNode.GetOrComputeValue() )))
123-
|> ignore
124+
Async.RunImmediate(Async.Parallel(Array.init requests (fun _ -> graphNode.GetOrComputeValue() )))
125+
|> ignore
124126

125-
GC.Collect(2, GCCollectionMode.Forced, true)
127+
weak
128+
129+
GC.Collect()
130+
131+
//GC.Collect(2, GCCollectionMode.Forced, true)
126132

127133
Assert.shouldBeFalse weak.IsAlive
128134

@@ -170,7 +176,7 @@ module BuildGraphTests =
170176
if task.Wait(1000) |> not then raise (TimeoutException())
171177

172178
[<Fact>]
173-
let ``Many requests to get a value asynchronously might evaluate the computation more than once even when some requests get canceled``() =
179+
let ``Many requests to get a value asynchronously will never evaluate the value more than once``() =
174180
let requests = 10000
175181
let resetEvent = new ManualResetEvent(false)
176182
let mutable computationCountBeforeSleep = 0
@@ -208,8 +214,8 @@ module BuildGraphTests =
208214
|> ignore
209215

210216
Assert.shouldBeTrue cts.IsCancellationRequested
211-
Assert.shouldBeTrue(computationCountBeforeSleep > 0)
212-
Assert.shouldBeTrue(computationCount >= 0)
217+
Assert.shouldBeTrue(computationCountBeforeSleep = 1)
218+
Assert.shouldBeTrue(computationCount = 1)
213219

214220
tasks
215221
|> Seq.iter (fun x ->

0 commit comments

Comments
 (0)