Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JoinableTaskFactory integration #886

Merged
merged 1 commit into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
<CentralPackageTransitivePinningEnabled>true</CentralPackageTransitivePinningEnabled>

<MicroBuildVersion>2.0.107</MicroBuildVersion>
<VisualStudioThreadingVersion>17.6.29-alpha</VisualStudioThreadingVersion>
</PropertyGroup>
<ItemGroup>
<PackageVersion Include="BenchmarkDotNet.Diagnostics.Windows" Version="0.13.2" />
Expand All @@ -18,8 +19,8 @@
<PackageVersion Include="Microsoft.CodeCoverage" Version="17.5.0-release-20230131-04" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.4.1" />
<PackageVersion Include="Microsoft.VisualStudio.Internal.MicroBuild.NonShipping" Version="$(MicroBuildVersion)" />
<PackageVersion Include="Microsoft.VisualStudio.Threading.Analyzers" Version="17.4.27" />
<PackageVersion Include="Microsoft.VisualStudio.Threading" Version="17.1.46" />
<PackageVersion Include="Microsoft.VisualStudio.Threading.Analyzers" Version="$(VisualStudioThreadingVersion)" />
<PackageVersion Include="Microsoft.VisualStudio.Threading" Version="$(VisualStudioThreadingVersion)" />
<PackageVersion Include="Nerdbank.Streams" Version="2.9.112" />
<PackageVersion Include="Newtonsoft.Json" Version="13.0.1" />
<PackageVersion Include="System.Collections.Immutable" Version="7.0.0" />
Expand All @@ -34,6 +35,7 @@
<PackageVersion Include="xunit.runner.console" Version="2.4.2" />
<PackageVersion Include="xunit.runner.visualstudio" Version="2.4.5" />
<PackageVersion Include="xunit.skippablefact" Version="1.4.13" />
<PackageVersion Include="xunit.stafact" Version="1.1.11" />
<PackageVersion Include="xunit" Version="2.4.2" />
</ItemGroup>
<ItemGroup>
Expand Down
55 changes: 54 additions & 1 deletion src/StreamJsonRpc/JsonRpc.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ public class JsonRpc : IDisposableObservable, IJsonRpcFormatterCallbacks, IJsonR
/// </summary>
internal static readonly SynchronizationContext DefaultSynchronizationContext = new SynchronizationContext();

/// <summary>
/// The name of the top-level field that we add to JSON-RPC messages to track JoinableTask context to mitigate deadlocks.
/// </summary>
private const string JoinableTaskTokenHeaderName = "joinableTaskToken";

private static readonly MethodInfo MarshalWithControlledLifetimeOpenGenericMethodInfo = typeof(JsonRpc).GetMethods(BindingFlags.Static | BindingFlags.NonPublic).Single(m => m.Name == nameof(MarshalWithControlledLifetime) && m.IsGenericMethod);

[DebuggerBrowsable(DebuggerBrowsableState.Never)]
Expand Down Expand Up @@ -78,6 +83,11 @@ public class JsonRpc : IDisposableObservable, IJsonRpcFormatterCallbacks, IJsonR
/// </summary>
private readonly RpcTargetInfo rpcTargetInfo;

/// <summary>
/// Carries the value from a <see cref="JoinableTaskTokenHeaderName"/> when <see cref="JoinableTaskFactory"/> has not been set.
/// </summary>
private readonly System.Threading.AsyncLocal<string?> joinableTaskTokenWithoutJtf = new();

/// <summary>
/// List of remote RPC targets to call if connection should be relayed.
/// </summary>
Expand Down Expand Up @@ -106,6 +116,12 @@ public class JsonRpc : IDisposableObservable, IJsonRpcFormatterCallbacks, IJsonR
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private SynchronizationContext? synchronizationContext;

/// <summary>
/// Backing field for the <see cref="JoinableTaskFactory"/> property.
/// </summary>
[DebuggerBrowsable(DebuggerBrowsableState.Never)]
private JoinableTaskFactory? joinableTaskFactory;

/// <summary>
/// Backing field for the <see cref="CancellationStrategy"/> property.
/// </summary>
Expand Down Expand Up @@ -429,6 +445,21 @@ public SynchronizationContext? SynchronizationContext
}
}

/// <summary>
/// Gets or sets the <see cref="JoinableTaskFactory"/> to participate in to mitigate deadlocks with the main thread.
/// </summary>
/// <value>Defaults to null.</value>
public JoinableTaskFactory? JoinableTaskFactory
{
get => this.joinableTaskFactory;

set
{
this.ThrowIfConfigurationLocked();
this.joinableTaskFactory = value;
}
}

/// <summary>
/// Gets a <see cref="Task"/> that completes when this instance is disposed or when listening has stopped
/// whether by error, disposal or the stream closing.
Expand Down Expand Up @@ -1894,6 +1925,12 @@ private JsonRpcError CreateCancellationResponse(JsonRpcRequest request)
JsonRpcEventSource.Instance.SendingRequest(request.RequestId.NumberIfPossibleForEvent, request.Method, JsonRpcEventSource.GetArgumentsString(request));
}

string? parentToken = this.JoinableTaskFactory is not null ? this.JoinableTaskFactory.Context.Capture() : this.joinableTaskTokenWithoutJtf.Value;
if (parentToken is not null)
{
request.TrySetTopLevelProperty(JoinableTaskTokenHeaderName, parentToken);
}

// IMPORTANT: This should be the first await in this async code path.
// This is crucial to the guarantee that overrides of SendAsync can assume they are executed
// before the first await when a JsonRpc call is made.
Expand Down Expand Up @@ -2044,7 +2081,23 @@ private async ValueTask<JsonRpcMessage> DispatchIncomingRequestAsync(JsonRpcRequ
}
}

return await this.DispatchRequestAsync(request, targetMethod, cancellationToken).ConfigureAwait(false);
request.TryGetTopLevelProperty<string>(JoinableTaskTokenHeaderName, out string? parentToken);
if (this.JoinableTaskFactory is null)
{
this.joinableTaskTokenWithoutJtf.Value = parentToken;
}

if (this.JoinableTaskFactory is null || parentToken is null)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is a JTF instance but no token was provided in the header, should we still use JTF to dispatch for the benefit of the current process? (although I am not sure what benefit there would be)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lifeng and I couldn't think of any benefit in doing so. But there is a non-zero overhead from doing so. Thus, we only do it when we feel it's justified by skipping when there is no parent token.

{
return await this.DispatchRequestAsync(request, targetMethod, cancellationToken).ConfigureAwait(false);
}
else
{
return await this.JoinableTaskFactory.RunAsync(
async () => await this.DispatchRequestAsync(request, targetMethod, cancellationToken).ConfigureAwait(false),
parentToken,
JoinableTaskCreationOptions.None);
}
}
else
{
Expand Down
2 changes: 2 additions & 0 deletions src/StreamJsonRpc/netstandard2.0/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ StreamJsonRpc.ExceptionSettings.RecursionLimit.get -> int
StreamJsonRpc.ExceptionSettings.RecursionLimit.init -> void
StreamJsonRpc.JsonRpc.ExceptionOptions.get -> StreamJsonRpc.ExceptionSettings!
StreamJsonRpc.JsonRpc.ExceptionOptions.set -> void
StreamJsonRpc.JsonRpc.JoinableTaskFactory.get -> Microsoft.VisualStudio.Threading.JoinableTaskFactory?
StreamJsonRpc.JsonRpc.JoinableTaskFactory.set -> void
StreamJsonRpc.JsonRpcIgnoreAttribute
StreamJsonRpc.JsonRpcIgnoreAttribute.JsonRpcIgnoreAttribute() -> void
StreamJsonRpc.JsonRpcMethodAttribute.ClientRequiresNamedArguments.get -> bool
Expand Down
2 changes: 2 additions & 0 deletions src/StreamJsonRpc/netstandard2.1/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ StreamJsonRpc.ExceptionSettings.RecursionLimit.get -> int
StreamJsonRpc.ExceptionSettings.RecursionLimit.init -> void
StreamJsonRpc.JsonRpc.ExceptionOptions.get -> StreamJsonRpc.ExceptionSettings!
StreamJsonRpc.JsonRpc.ExceptionOptions.set -> void
StreamJsonRpc.JsonRpc.JoinableTaskFactory.get -> Microsoft.VisualStudio.Threading.JoinableTaskFactory?
StreamJsonRpc.JsonRpc.JoinableTaskFactory.set -> void
StreamJsonRpc.JsonRpcIgnoreAttribute
StreamJsonRpc.JsonRpcIgnoreAttribute.JsonRpcIgnoreAttribute() -> void
StreamJsonRpc.JsonRpcMethodAttribute.ClientRequiresNamedArguments.get -> bool
Expand Down
66 changes: 66 additions & 0 deletions test/StreamJsonRpc.Tests/JsonRpcTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2848,6 +2848,65 @@ public async Task IncomingActivityStopsAfterAsyncTargetMethodCompletes()
await stopped.Task.WithCancellation(this.TimeoutToken);
}

[Fact]
public void JoinableTaskFactory_ThrowsAfterRunning()
{
Assert.Throws<InvalidOperationException>(() => this.clientRpc.JoinableTaskFactory = null);
}

/// <summary>
/// Asserts that when both client and server are JTF-aware, that no deadlock occurs when the client blocks the main thread that the server needs.
/// </summary>
[UIFact]
public void JoinableTaskFactory_IntegrationBothSides()
{
// Set up a main thread and JoinableTaskContext.
JoinableTaskContext jtc = new();

// Configure the client and server to understand JTF.
this.clientRpc.AllowModificationWhileListening = true;
this.clientRpc.JoinableTaskFactory = jtc.Factory;
this.serverRpc.AllowModificationWhileListening = true;
this.serverRpc.JoinableTaskFactory = jtc.Factory;

// Tell the server to require the main thread to get something done.
this.server.JoinableTaskFactory = jtc.Factory;

jtc.Factory.Run(async delegate
{
string result = await this.clientRpc.InvokeWithCancellationAsync<string>(nameof(this.server.AsyncMethod), new object?[] { "hi" }, this.TimeoutToken).WithCancellation(this.TimeoutToken);
Assert.Equal("hi!", result);
});
}

/// <summary>
/// Asserts that when only the client is JTF-aware, that no deadlock occurs when the client blocks the main thread
/// and the server calls back to the client for something that needs the main thread as part of processing the client's request.
/// </summary>
[UIFact]
public void JoinableTaskFactory_IntegrationClientSideOnly()
{
// Set up a main thread and JoinableTaskContext.
JoinableTaskContext jtc = new();

// Configure the client and server to understand JTF.
this.clientRpc.AllowModificationWhileListening = true;
this.clientRpc.JoinableTaskFactory = jtc.Factory;

const string CallbackMethodName = "ClientNeedsMainThread";
this.clientRpc.AddLocalRpcMethod(CallbackMethodName, new Func<Task>(async delegate
{
await jtc.Factory.SwitchToMainThreadAsync();
}));

this.server.Tests = this;

jtc.Factory.Run(async delegate
{
await this.clientRpc.InvokeWithCancellationAsync(nameof(this.server.Callback), new object?[] { CallbackMethodName }, this.TimeoutToken).WithCancellation(this.TimeoutToken);
});
}

protected static Exception CreateExceptionToBeThrownByDeserializer() => new Exception("This exception is meant to be thrown.");

protected override void Dispose(bool disposing)
Expand Down Expand Up @@ -3059,6 +3118,8 @@ public class Server : BaseClass, IServerDerived

internal TraceSource? TraceSource { get; set; }

internal JoinableTaskFactory? JoinableTaskFactory { get; set; }

public static string ServerMethod(string argument)
{
return argument + "!";
Expand Down Expand Up @@ -3300,6 +3361,11 @@ public string ExpectEncodedA(string arg)

public async Task<string> AsyncMethod(string arg)
{
if (this.JoinableTaskFactory is not null)
{
await this.JoinableTaskFactory.SwitchToMainThreadAsync();
}

await Task.Yield();
return arg + "!";
}
Expand Down
1 change: 1 addition & 0 deletions test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
<PackageReference Include="xunit.runner.console" />
<PackageReference Include="xunit.runner.visualstudio" />
<PackageReference Include="xunit.skippablefact" />
<PackageReference Include="xunit.stafact" />
<PackageReference Include="xunit" />
</ItemGroup>
<ItemGroup>
Expand Down