Skip to content

Commit

Permalink
Add JoinableTaskFactory integration
Browse files Browse the repository at this point in the history
This avoids deadlocks that RPC would otherwise introduce by carrying a `JoinableTask` token across RPC so that when/if the RPC call ever makes it back to the original AppDomain (or never leaves it) such that the callee needs the main thread, it will be able to reach it if the caller owns the main thread.
  • Loading branch information
AArnott committed Mar 1, 2023
1 parent fe02bff commit 59af916
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 3 deletions.
6 changes: 4 additions & 2 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
<CentralPackageTransitivePinningEnabled>true</CentralPackageTransitivePinningEnabled>

<MicroBuildVersion>2.0.107</MicroBuildVersion>
<VisualStudioThreadingVersion>17.6.35-alpha-gd646d3a5e1</VisualStudioThreadingVersion>
</PropertyGroup>
<ItemGroup>
<PackageVersion Include="BenchmarkDotNet.Diagnostics.Windows" Version="0.13.2" />
Expand All @@ -18,8 +19,8 @@
<PackageVersion Include="Microsoft.CodeCoverage" Version="17.5.0-release-20230131-04" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.4.1" />
<PackageVersion Include="Microsoft.VisualStudio.Internal.MicroBuild.NonShipping" Version="$(MicroBuildVersion)" />
<PackageVersion Include="Microsoft.VisualStudio.Threading.Analyzers" Version="17.4.27" />
<PackageVersion Include="Microsoft.VisualStudio.Threading" Version="17.1.46" />
<PackageVersion Include="Microsoft.VisualStudio.Threading.Analyzers" Version="$(VisualStudioThreadingVersion)" />
<PackageVersion Include="Microsoft.VisualStudio.Threading" Version="$(VisualStudioThreadingVersion)" />
<PackageVersion Include="Nerdbank.Streams" Version="2.9.112" />
<PackageVersion Include="Newtonsoft.Json" Version="13.0.1" />
<PackageVersion Include="System.Collections.Immutable" Version="7.0.0" />
Expand All @@ -34,6 +35,7 @@
<PackageVersion Include="xunit.runner.console" Version="2.4.2" />
<PackageVersion Include="xunit.runner.visualstudio" Version="2.4.5" />
<PackageVersion Include="xunit.skippablefact" Version="1.4.13" />
<PackageVersion Include="xunit.stafact" Version="1.1.11" />
<PackageVersion Include="xunit" Version="2.4.2" />
</ItemGroup>
<ItemGroup>
Expand Down
45 changes: 44 additions & 1 deletion src/StreamJsonRpc/JsonRpc.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ namespace StreamJsonRpc;
/// </summary>
public class JsonRpc : IDisposableObservable, IJsonRpcFormatterCallbacks, IJsonRpcTracingCallbacks
{
private const string JoinableTaskTokenHeaderName = "joinableTaskToken";

/// <summary>
/// The <see cref="System.Threading.SynchronizationContext"/> to use to schedule work on the threadpool.
/// </summary>
Expand Down Expand Up @@ -106,6 +108,12 @@ public class JsonRpc : IDisposableObservable, IJsonRpcFormatterCallbacks, IJsonR
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private SynchronizationContext? synchronizationContext;

/// <summary>
/// Backing field for the <see cref="JoinableTaskFactory"/> property.
/// </summary>
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private JoinableTaskFactory? joinableTaskFactory;

/// <summary>
/// Backing field for the <see cref="CancellationStrategy"/> property.
/// </summary>
Expand Down Expand Up @@ -429,6 +437,21 @@ public SynchronizationContext? SynchronizationContext
}
}

/// <summary>
/// Gets or sets the <see cref="JoinableTaskFactory"/> to participate in to mitigate deadlocks with the main thread.
/// </summary>
/// <value>Defaults to null.</value>
public JoinableTaskFactory? JoinableTaskFactory
{
get => this.joinableTaskFactory;

set
{
this.ThrowIfConfigurationLocked();
this.joinableTaskFactory = value;
}
}

/// <summary>
/// Gets a <see cref="Task"/> that completes when this instance is disposed or when listening has stopped
/// whether by error, disposal or the stream closing.
Expand Down Expand Up @@ -1894,6 +1917,11 @@ private JsonRpcError CreateCancellationResponse(JsonRpcRequest request)
JsonRpcEventSource.Instance.SendingRequest(request.RequestId.NumberIfPossibleForEvent, request.Method, JsonRpcEventSource.GetArgumentsString(request));
}

if (this.joinableTaskFactory?.Context.Capture() is string parentToken)
{
request.TrySetTopLevelProperty(JoinableTaskTokenHeaderName, parentToken);
}

// IMPORTANT: This should be the first await in this async code path.
// This is crucial to the guarantee that overrides of SendAsync can assume they are executed
// before the first await when a JsonRpc call is made.
Expand Down Expand Up @@ -2044,7 +2072,22 @@ private async ValueTask<JsonRpcMessage> DispatchIncomingRequestAsync(JsonRpcRequ
}
}

return await this.DispatchRequestAsync(request, targetMethod, cancellationToken).ConfigureAwait(false);
if (this.JoinableTaskFactory is null || request.TryGetTopLevelProperty<string>(JoinableTaskTokenHeaderName, out string? parentToken) is false || parentToken is null)
{
return await this.DispatchRequestAsync(request, targetMethod, cancellationToken).ConfigureAwait(false);
}
else
{
JsonRpcMessage? result = null;
await this.JoinableTaskFactory.RunAsync(
async delegate
{
result = await this.DispatchRequestAsync(request, targetMethod, cancellationToken).ConfigureAwait(false);
},
parentToken,
JoinableTaskCreationOptions.None);
return result!;
}
}
else
{
Expand Down
35 changes: 35 additions & 0 deletions test/StreamJsonRpc.Tests/JsonRpcTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2848,6 +2848,34 @@ public async Task IncomingActivityStopsAfterAsyncTargetMethodCompletes()
await stopped.Task.WithCancellation(this.TimeoutToken);
}

[Fact]
public void JoinableTaskFactory_ThrowsAfterRunning()
{
Assert.Throws<InvalidOperationException>(() => this.clientRpc.JoinableTaskFactory = null);
}

[UIFact]
public void JoinableTaskFactory_Integration()
{
// Set up a main thread and JoinableTaskContext.
JoinableTaskContext jtc = new();

// Configure the client and server to understand JTF.
this.clientRpc.AllowModificationWhileListening = true;
this.clientRpc.JoinableTaskFactory = jtc.Factory;
this.serverRpc.AllowModificationWhileListening = true;
this.serverRpc.JoinableTaskFactory = jtc.Factory;

// Tell the server to require the main thread to get something done.
this.server.JoinableTaskFactory = jtc.Factory;

jtc.Factory.Run(async delegate
{
string result = await this.clientRpc.InvokeWithCancellationAsync<string>(nameof(this.server.AsyncMethod), new object?[] { "hi" }, this.TimeoutToken).WithCancellation(this.TimeoutToken);
Assert.Equal("hi!", result);
});
}

protected static Exception CreateExceptionToBeThrownByDeserializer() => new Exception("This exception is meant to be thrown.");

protected override void Dispose(bool disposing)
Expand Down Expand Up @@ -3059,6 +3087,8 @@ public class Server : BaseClass, IServerDerived

internal TraceSource? TraceSource { get; set; }

internal JoinableTaskFactory? JoinableTaskFactory { get; set; }

public static string ServerMethod(string argument)
{
return argument + "!";
Expand Down Expand Up @@ -3300,6 +3330,11 @@ public string ExpectEncodedA(string arg)

public async Task<string> AsyncMethod(string arg)
{
if (this.JoinableTaskFactory is not null)
{
await this.JoinableTaskFactory.SwitchToMainThreadAsync();
}

await Task.Yield();
return arg + "!";
}
Expand Down
1 change: 1 addition & 0 deletions test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
<PackageReference Include="xunit.runner.console" />
<PackageReference Include="xunit.runner.visualstudio" />
<PackageReference Include="xunit.skippablefact" />
<PackageReference Include="xunit.stafact" />
<PackageReference Include="xunit" />
</ItemGroup>
<ItemGroup>
Expand Down

0 comments on commit 59af916

Please sign in to comment.