Skip to content

Commit

Permalink
Add JoinableTaskFactory integration
Browse files Browse the repository at this point in the history
This avoids deadlocks that RPC would otherwise introduce by carrying a `JoinableTask` token across RPC so that when/if the RPC call ever makes it back to the original AppDomain (or never leaves it) such that the callee needs the main thread, it will be able to reach it if the caller owns the main thread.
  • Loading branch information
AArnott committed Mar 1, 2023
1 parent fe02bff commit 3116407
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 3 deletions.
6 changes: 4 additions & 2 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
<CentralPackageTransitivePinningEnabled>true</CentralPackageTransitivePinningEnabled>

<MicroBuildVersion>2.0.107</MicroBuildVersion>
<VisualStudioThreadingVersion>17.6.29-alpha</VisualStudioThreadingVersion>
</PropertyGroup>
<ItemGroup>
<PackageVersion Include="BenchmarkDotNet.Diagnostics.Windows" Version="0.13.2" />
Expand All @@ -18,8 +19,8 @@
<PackageVersion Include="Microsoft.CodeCoverage" Version="17.5.0-release-20230131-04" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.4.1" />
<PackageVersion Include="Microsoft.VisualStudio.Internal.MicroBuild.NonShipping" Version="$(MicroBuildVersion)" />
<PackageVersion Include="Microsoft.VisualStudio.Threading.Analyzers" Version="17.4.27" />
<PackageVersion Include="Microsoft.VisualStudio.Threading" Version="17.1.46" />
<PackageVersion Include="Microsoft.VisualStudio.Threading.Analyzers" Version="$(VisualStudioThreadingVersion)" />
<PackageVersion Include="Microsoft.VisualStudio.Threading" Version="$(VisualStudioThreadingVersion)" />
<PackageVersion Include="Nerdbank.Streams" Version="2.9.112" />
<PackageVersion Include="Newtonsoft.Json" Version="13.0.1" />
<PackageVersion Include="System.Collections.Immutable" Version="7.0.0" />
Expand All @@ -34,6 +35,7 @@
<PackageVersion Include="xunit.runner.console" Version="2.4.2" />
<PackageVersion Include="xunit.runner.visualstudio" Version="2.4.5" />
<PackageVersion Include="xunit.skippablefact" Version="1.4.13" />
<PackageVersion Include="xunit.stafact" Version="1.1.11" />
<PackageVersion Include="xunit" Version="2.4.2" />
</ItemGroup>
<ItemGroup>
Expand Down
55 changes: 54 additions & 1 deletion src/StreamJsonRpc/JsonRpc.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ public class JsonRpc : IDisposableObservable, IJsonRpcFormatterCallbacks, IJsonR
/// </summary>
internal static readonly SynchronizationContext DefaultSynchronizationContext = new SynchronizationContext();

/// <summary>
/// The name of the top-level field that we add to JSON-RPC messages to track JoinableTask context to mitigate deadlocks.
/// </summary>
private const string JoinableTaskTokenHeaderName = "joinableTaskToken";

private static readonly MethodInfo MarshalWithControlledLifetimeOpenGenericMethodInfo = typeof(JsonRpc).GetMethods(BindingFlags.Static | BindingFlags.NonPublic).Single(m => m.Name == nameof(MarshalWithControlledLifetime) && m.IsGenericMethod);

[DebuggerBrowsable(DebuggerBrowsableState.Never)]
Expand Down Expand Up @@ -78,6 +83,11 @@ public class JsonRpc : IDisposableObservable, IJsonRpcFormatterCallbacks, IJsonR
/// </summary>
private readonly RpcTargetInfo rpcTargetInfo;

/// <summary>
/// Carries the value from a <see cref="JoinableTaskTokenHeaderName"/> when <see cref="JoinableTaskFactory"/> has not been set.
/// </summary>
private readonly System.Threading.AsyncLocal<string?> joinableTaskTokenWithoutJtf = new();

/// <summary>
/// List of remote RPC targets to call if connection should be relayed.
/// </summary>
Expand Down Expand Up @@ -106,6 +116,12 @@ public class JsonRpc : IDisposableObservable, IJsonRpcFormatterCallbacks, IJsonR
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private SynchronizationContext? synchronizationContext;

/// <summary>
/// Backing field for the <see cref="JoinableTaskFactory"/> property.
/// </summary>
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private JoinableTaskFactory? joinableTaskFactory;

/// <summary>
/// Backing field for the <see cref="CancellationStrategy"/> property.
/// </summary>
Expand Down Expand Up @@ -429,6 +445,21 @@ public SynchronizationContext? SynchronizationContext
}
}

/// <summary>
/// Gets or sets the <see cref="JoinableTaskFactory"/> to participate in to mitigate deadlocks with the main thread.
/// </summary>
/// <value>Defaults to null.</value>
public JoinableTaskFactory? JoinableTaskFactory
{
get => this.joinableTaskFactory;

set
{
this.ThrowIfConfigurationLocked();
this.joinableTaskFactory = value;
}
}

/// <summary>
/// Gets a <see cref="Task"/> that completes when this instance is disposed or when listening has stopped
/// whether by error, disposal or the stream closing.
Expand Down Expand Up @@ -1894,6 +1925,12 @@ private JsonRpcError CreateCancellationResponse(JsonRpcRequest request)
JsonRpcEventSource.Instance.SendingRequest(request.RequestId.NumberIfPossibleForEvent, request.Method, JsonRpcEventSource.GetArgumentsString(request));
}

string? parentToken = this.JoinableTaskFactory is not null ? this.JoinableTaskFactory.Context.Capture() : this.joinableTaskTokenWithoutJtf.Value;
if (parentToken is not null)
{
request.TrySetTopLevelProperty(JoinableTaskTokenHeaderName, parentToken);
}

// IMPORTANT: This should be the first await in this async code path.
// This is crucial to the guarantee that overrides of SendAsync can assume they are executed
// before the first await when a JsonRpc call is made.
Expand Down Expand Up @@ -2044,7 +2081,23 @@ private async ValueTask<JsonRpcMessage> DispatchIncomingRequestAsync(JsonRpcRequ
}
}

return await this.DispatchRequestAsync(request, targetMethod, cancellationToken).ConfigureAwait(false);
request.TryGetTopLevelProperty<string>(JoinableTaskTokenHeaderName, out string? parentToken);
if (this.JoinableTaskFactory is null)
{
this.joinableTaskTokenWithoutJtf.Value = parentToken;
}

if (this.JoinableTaskFactory is null || parentToken is null)
{
return await this.DispatchRequestAsync(request, targetMethod, cancellationToken).ConfigureAwait(false);
}
else
{
return await this.JoinableTaskFactory.RunAsync(
async () => await this.DispatchRequestAsync(request, targetMethod, cancellationToken).ConfigureAwait(false),
parentToken,
JoinableTaskCreationOptions.None);
}
}
else
{
Expand Down
2 changes: 2 additions & 0 deletions src/StreamJsonRpc/netstandard2.0/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ StreamJsonRpc.ExceptionSettings.RecursionLimit.get -> int
StreamJsonRpc.ExceptionSettings.RecursionLimit.init -> void
StreamJsonRpc.JsonRpc.ExceptionOptions.get -> StreamJsonRpc.ExceptionSettings!
StreamJsonRpc.JsonRpc.ExceptionOptions.set -> void
StreamJsonRpc.JsonRpc.JoinableTaskFactory.get -> Microsoft.VisualStudio.Threading.JoinableTaskFactory?
StreamJsonRpc.JsonRpc.JoinableTaskFactory.set -> void
StreamJsonRpc.JsonRpcIgnoreAttribute
StreamJsonRpc.JsonRpcIgnoreAttribute.JsonRpcIgnoreAttribute() -> void
StreamJsonRpc.JsonRpcMethodAttribute.ClientRequiresNamedArguments.get -> bool
Expand Down
2 changes: 2 additions & 0 deletions src/StreamJsonRpc/netstandard2.1/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ StreamJsonRpc.ExceptionSettings.RecursionLimit.get -> int
StreamJsonRpc.ExceptionSettings.RecursionLimit.init -> void
StreamJsonRpc.JsonRpc.ExceptionOptions.get -> StreamJsonRpc.ExceptionSettings!
StreamJsonRpc.JsonRpc.ExceptionOptions.set -> void
StreamJsonRpc.JsonRpc.JoinableTaskFactory.get -> Microsoft.VisualStudio.Threading.JoinableTaskFactory?
StreamJsonRpc.JsonRpc.JoinableTaskFactory.set -> void
StreamJsonRpc.JsonRpcIgnoreAttribute
StreamJsonRpc.JsonRpcIgnoreAttribute.JsonRpcIgnoreAttribute() -> void
StreamJsonRpc.JsonRpcMethodAttribute.ClientRequiresNamedArguments.get -> bool
Expand Down
66 changes: 66 additions & 0 deletions test/StreamJsonRpc.Tests/JsonRpcTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2848,6 +2848,65 @@ public async Task IncomingActivityStopsAfterAsyncTargetMethodCompletes()
await stopped.Task.WithCancellation(this.TimeoutToken);
}

[Fact]
public void JoinableTaskFactory_ThrowsAfterRunning()
{
Assert.Throws<InvalidOperationException>(() => this.clientRpc.JoinableTaskFactory = null);
}

/// <summary>
/// Asserts that when both client and server are JTF-aware, that no deadlock occurs when the client blocks the main thread that the server needs.
/// </summary>
[UIFact]
public void JoinableTaskFactory_IntegrationBothSides()
{
// Set up a main thread and JoinableTaskContext.
JoinableTaskContext jtc = new();

// Configure the client and server to understand JTF.
this.clientRpc.AllowModificationWhileListening = true;
this.clientRpc.JoinableTaskFactory = jtc.Factory;
this.serverRpc.AllowModificationWhileListening = true;
this.serverRpc.JoinableTaskFactory = jtc.Factory;

// Tell the server to require the main thread to get something done.
this.server.JoinableTaskFactory = jtc.Factory;

jtc.Factory.Run(async delegate
{
string result = await this.clientRpc.InvokeWithCancellationAsync<string>(nameof(this.server.AsyncMethod), new object?[] { "hi" }, this.TimeoutToken).WithCancellation(this.TimeoutToken);
Assert.Equal("hi!", result);
});
}

/// <summary>
/// Asserts that when only the client is JTF-aware, that no deadlock occurs when the client blocks the main thread
/// and the server calls back to the client for something that needs the main thread as part of processing the client's request.
/// </summary>
[UIFact]
public void JoinableTaskFactory_IntegrationClientSideOnly()
{
// Set up a main thread and JoinableTaskContext.
JoinableTaskContext jtc = new();

// Configure the client and server to understand JTF.
this.clientRpc.AllowModificationWhileListening = true;
this.clientRpc.JoinableTaskFactory = jtc.Factory;

const string CallbackMethodName = "ClientNeedsMainThread";
this.clientRpc.AddLocalRpcMethod(CallbackMethodName, new Func<Task>(async delegate
{
await jtc.Factory.SwitchToMainThreadAsync();
}));

this.server.Tests = this;

jtc.Factory.Run(async delegate
{
await this.clientRpc.InvokeWithCancellationAsync(nameof(this.server.Callback), new object?[] { CallbackMethodName }, this.TimeoutToken).WithCancellation(this.TimeoutToken);
});
}

protected static Exception CreateExceptionToBeThrownByDeserializer() => new Exception("This exception is meant to be thrown.");

protected override void Dispose(bool disposing)
Expand Down Expand Up @@ -3059,6 +3118,8 @@ public class Server : BaseClass, IServerDerived

internal TraceSource? TraceSource { get; set; }

internal JoinableTaskFactory? JoinableTaskFactory { get; set; }

public static string ServerMethod(string argument)
{
return argument + "!";
Expand Down Expand Up @@ -3300,6 +3361,11 @@ public string ExpectEncodedA(string arg)

public async Task<string> AsyncMethod(string arg)
{
if (this.JoinableTaskFactory is not null)
{
await this.JoinableTaskFactory.SwitchToMainThreadAsync();
}

await Task.Yield();
return arg + "!";
}
Expand Down
1 change: 1 addition & 0 deletions test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
<PackageReference Include="xunit.runner.console" />
<PackageReference Include="xunit.runner.visualstudio" />
<PackageReference Include="xunit.skippablefact" />
<PackageReference Include="xunit.stafact" />
<PackageReference Include="xunit" />
</ItemGroup>
<ItemGroup>
Expand Down

0 comments on commit 3116407

Please sign in to comment.