Skip to content

Commit

Permalink
Merge pull request #61187 from sharwell/unnecessary-yield
Browse files Browse the repository at this point in the history
Avoid yield in ReadAsync/WriteAsync unless on main thread
  • Loading branch information
sharwell authored May 11, 2022
2 parents e6b5dd8 + 6dac65d commit d797dcc
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ public WorkspaceThreadingService(IThreadingContext threadingContext)
_threadingContext = threadingContext;
}

public bool IsOnMainThread => _threadingContext.JoinableTaskContext.IsOnMainThread;

public TResult Run<TResult>(Func<Task<TResult>> asyncMethod)
{
return _threadingContext.JoinableTaskFactory.Run(asyncMethod);
Expand Down
78 changes: 24 additions & 54 deletions src/EditorFeatures/Test/Workspaces/TextFactoryTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable disable

using System;
using System.IO;
using System.Linq;
using System.Text;
Expand All @@ -13,9 +10,6 @@
using Microsoft.CodeAnalysis.Host;
using Microsoft.CodeAnalysis.Test.Utilities;
using Microsoft.CodeAnalysis.Text;
using Microsoft.VisualStudio.Text;
using Microsoft.VisualStudio.Utilities;
using Moq;
using Roslyn.Test.Utilities;
using Xunit;

Expand All @@ -29,7 +23,11 @@ public class TextFactoryTests
[Fact, WorkItem(1038018, "http://vstfdevdiv:8080/DevDiv2/DevDiv/_workitems/edit/1038018"), WorkItem(1041792, "http://vstfdevdiv:8080/DevDiv2/DevDiv/_workitems/edit/1041792")]
public void TestCreateTextFallsBackToSystemDefaultEncoding()
{
using var workspace = new AdhocWorkspace(EditorTestCompositions.EditorFeatures.GetHostServices());
var textFactoryService = Assert.IsType<EditorTextFactoryService>(workspace.Services.GetRequiredService<ITextFactoryService>());

TestCreateTextInferredEncoding(
textFactoryService,
_nonUTF8StringBytes,
defaultEncoding: null,
expectedEncoding: Encoding.Default);
Expand All @@ -38,7 +36,11 @@ public void TestCreateTextFallsBackToSystemDefaultEncoding()
[Fact, WorkItem(1038018, "http://vstfdevdiv:8080/DevDiv2/DevDiv/_workitems/edit/1038018")]
public void TestCreateTextFallsBackToUTF8Encoding()
{
using var workspace = new AdhocWorkspace(EditorTestCompositions.EditorFeatures.GetHostServices());
var textFactoryService = Assert.IsType<EditorTextFactoryService>(workspace.Services.GetRequiredService<ITextFactoryService>());

TestCreateTextInferredEncoding(
textFactoryService,
new ASCIIEncoding().GetBytes("Test"),
defaultEncoding: null,
expectedEncoding: new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true));
Expand All @@ -47,7 +49,11 @@ public void TestCreateTextFallsBackToUTF8Encoding()
[Fact, WorkItem(1038018, "http://vstfdevdiv:8080/DevDiv2/DevDiv/_workitems/edit/1038018")]
public void TestCreateTextFallsBackToProvidedDefaultEncoding()
{
using var workspace = new AdhocWorkspace(EditorTestCompositions.EditorFeatures.GetHostServices());
var textFactoryService = Assert.IsType<EditorTextFactoryService>(workspace.Services.GetRequiredService<ITextFactoryService>());

TestCreateTextInferredEncoding(
textFactoryService,
new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true).GetBytes("Test"),
defaultEncoding: Encoding.GetEncoding(1254),
expectedEncoding: Encoding.GetEncoding(1254));
Expand All @@ -56,7 +62,11 @@ public void TestCreateTextFallsBackToProvidedDefaultEncoding()
[Fact, WorkItem(1038018, "http://vstfdevdiv:8080/DevDiv2/DevDiv/_workitems/edit/1038018")]
public void TestCreateTextUsesByteOrderMarkIfPresent()
{
using var workspace = new AdhocWorkspace(EditorTestCompositions.EditorFeatures.GetHostServices());
var textFactoryService = Assert.IsType<EditorTextFactoryService>(workspace.Services.GetRequiredService<ITextFactoryService>());

TestCreateTextInferredEncoding(
textFactoryService,
Encoding.UTF8.GetPreamble().Concat(new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true).GetBytes("Test")).ToArray(),
defaultEncoding: Encoding.GetEncoding(1254),
expectedEncoding: Encoding.UTF8);
Expand All @@ -65,8 +75,9 @@ public void TestCreateTextUsesByteOrderMarkIfPresent()
[Fact]
public async Task TestCreateFromTemporaryStorage()
{
var textFactory = CreateMockTextFactoryService();
var temporaryStorageService = new TemporaryStorageServiceFactory.TemporaryStorageService(textFactory);
using var workspace = new AdhocWorkspace(EditorTestCompositions.EditorFeatures.GetHostServices());

var temporaryStorageService = Assert.IsType<TemporaryStorageServiceFactory.TemporaryStorageService>(workspace.Services.GetRequiredService<ITemporaryStorageService>());

var text = SourceText.From("Hello, World!");

Expand All @@ -86,8 +97,9 @@ public async Task TestCreateFromTemporaryStorage()
[Fact]
public async Task TestCreateFromTemporaryStorageWithEncoding()
{
var textFactory = CreateMockTextFactoryService();
var temporaryStorageService = new TemporaryStorageServiceFactory.TemporaryStorageService(textFactory);
using var workspace = new AdhocWorkspace(EditorTestCompositions.EditorFeatures.GetHostServices());

var temporaryStorageService = Assert.IsType<TemporaryStorageServiceFactory.TemporaryStorageService>(workspace.Services.GetRequiredService<ITemporaryStorageService>());

var text = SourceText.From("Hello, World!", Encoding.ASCII);

Expand All @@ -104,53 +116,11 @@ public async Task TestCreateFromTemporaryStorageWithEncoding()
Assert.Equal(text2.Encoding, Encoding.ASCII);
}

private static EditorTextFactoryService CreateMockTextFactoryService()
private static void TestCreateTextInferredEncoding(ITextFactoryService textFactoryService, byte[] bytes, Encoding? defaultEncoding, Encoding expectedEncoding)
{
var mockTextBufferFactoryService = new Mock<ITextBufferFactoryService>(MockBehavior.Strict);
mockTextBufferFactoryService
.Setup(t => t.CreateTextBuffer(It.IsAny<TextReader>(), It.IsAny<IContentType>()))
.Returns<TextReader, IContentType>((reader, contentType) =>
{
var text = reader.ReadToEnd();

var mockImage = new Mock<ITextImage>(MockBehavior.Strict);
mockImage.Setup(i => i.GetText(It.IsAny<Span>())).Returns(text);
mockImage.Setup(i => i.Length).Returns(text.Length);

var mockSnapshot = new Mock<ITextSnapshot2>(MockBehavior.Strict);
mockSnapshot.Setup(s => s.TextImage).Returns(mockImage.Object);
mockSnapshot.Setup(s => s.GetText()).Returns(text);

var mockTextBuffer = new Mock<ITextBuffer>(MockBehavior.Strict);
mockTextBuffer.Setup(b => b.CurrentSnapshot).Returns(mockSnapshot.Object);
return mockTextBuffer.Object;
});

var mockUnknownContentType = new Mock<IContentType>(MockBehavior.Strict);
var mockContentTypeRegistryService = new Mock<IContentTypeRegistryService>(MockBehavior.Strict);
mockContentTypeRegistryService.Setup(r => r.UnknownContentType).Returns(mockUnknownContentType.Object);

return new EditorTextFactoryService(new FakeTextBufferCloneService(), mockTextBufferFactoryService.Object, mockContentTypeRegistryService.Object);
}

private static void TestCreateTextInferredEncoding(byte[] bytes, Encoding defaultEncoding, Encoding expectedEncoding)
{
var factory = CreateMockTextFactoryService();
using var stream = new MemoryStream(bytes);
var text = factory.CreateText(stream, defaultEncoding);
var text = textFactoryService.CreateText(stream, defaultEncoding);
Assert.Equal(expectedEncoding, text.Encoding);
}

private class FakeTextBufferCloneService : ITextBufferCloneService
{
public ITextBuffer CloneWithUnknownContentType(SnapshotSpan span) => throw new NotImplementedException();

public ITextBuffer CloneWithUnknownContentType(ITextImage textImage) => throw new NotImplementedException();

public ITextBuffer CloneWithRoslynContentType(SourceText sourceText) => throw new NotImplementedException();

public ITextBuffer Clone(SourceText sourceText, IContentType contentType) => throw new NotImplementedException();

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ namespace Microsoft.CodeAnalysis.Shared.Utilities
/// </remarks>
internal interface IWorkspaceThreadingService
{
bool IsOnMainThread { get; }

TResult Run<TResult>(Func<Task<TResult>> asyncMethod);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,17 @@ namespace Microsoft.CodeAnalysis.Host
[ExportWorkspaceServiceFactory(typeof(ITemporaryStorageService), ServiceLayer.Default), Shared]
internal partial class TemporaryStorageServiceFactory : IWorkspaceServiceFactory
{
private readonly IWorkspaceThreadingService? _workspaceThreadingService;

[ImportingConstructor]
[Obsolete(MefConstruction.ImportingConstructorMessage, error: true)]
public TemporaryStorageServiceFactory()
public TemporaryStorageServiceFactory(
[Import(AllowDefault = true)] IWorkspaceThreadingService? workspaceThreadingService)
{
_workspaceThreadingService = workspaceThreadingService;
}

[Obsolete(MefConstruction.FactoryMethodMessage, error: true)]
public IWorkspaceService CreateService(HostWorkspaceServices workspaceServices)
{
var textFactory = workspaceServices.GetRequiredService<ITextFactoryService>();
Expand All @@ -37,7 +42,7 @@ public IWorkspaceService CreateService(HostWorkspaceServices workspaceServices)
// and .NET Core Windows. For non-Windows .NET Core scenarios, we can return the TrivialTemporaryStorageService
// until https://github.com/dotnet/runtime/issues/30878 is fixed.
return PlatformInformation.IsWindows || PlatformInformation.IsRunningOnMono
? new TemporaryStorageService(textFactory)
? new TemporaryStorageService(_workspaceThreadingService, textFactory)
: TrivialTemporaryStorageService.Instance;
}

Expand Down Expand Up @@ -67,6 +72,7 @@ internal class TemporaryStorageService : ITemporaryStorageService2
/// <seealso cref="_weakFileReference"/>
private const long MultiFileBlockSize = SingleFileThreshold * 32;

private readonly IWorkspaceThreadingService? _workspaceThreadingService;
private readonly ITextFactoryService _textFactory;

/// <summary>
Expand Down Expand Up @@ -112,8 +118,12 @@ internal class TemporaryStorageService : ITemporaryStorageService2
/// <seealso cref="_weakFileReference"/>
private long _offset;

public TemporaryStorageService(ITextFactoryService textFactory)
=> _textFactory = textFactory;
[Obsolete(MefConstruction.FactoryMethodMessage, error: true)]
public TemporaryStorageService(IWorkspaceThreadingService? workspaceThreadingService, ITextFactoryService textFactory)
{
_workspaceThreadingService = workspaceThreadingService;
_textFactory = textFactory;
}

public ITemporaryTextStorage CreateTemporaryTextStorage(CancellationToken cancellationToken)
=> new TemporaryTextStorage(this);
Expand Down Expand Up @@ -242,7 +252,7 @@ public SourceText ReadText(CancellationToken cancellationToken)
}
}

public Task<SourceText> ReadTextAsync(CancellationToken cancellationToken)
public async Task<SourceText> ReadTextAsync(CancellationToken cancellationToken)
{
// There is a reason for implementing it like this: proper async implementation
// that reads the underlying memory mapped file stream in an asynchronous fashion
Expand All @@ -254,7 +264,13 @@ public Task<SourceText> ReadTextAsync(CancellationToken cancellationToken)
// of a page fault. Therefore, if we're going to be blocking a thread, we should
// just block one thread and do the whole thing at once vs. a fake "async"
// implementation which will continue to requeue work back to the thread pool.
return Task.Factory.StartNew(() => ReadText(cancellationToken), cancellationToken, TaskCreationOptions.None, TaskScheduler.Default);
if (_service._workspaceThreadingService is { IsOnMainThread: true })
{
await Task.Yield().ConfigureAwait(false);
cancellationToken.ThrowIfCancellationRequested();
}

return ReadText(cancellationToken);
}

public void WriteText(SourceText text, CancellationToken cancellationToken)
Expand All @@ -281,10 +297,16 @@ public void WriteText(SourceText text, CancellationToken cancellationToken)
}
}

public Task WriteTextAsync(SourceText text, CancellationToken cancellationToken = default)
public async Task WriteTextAsync(SourceText text, CancellationToken cancellationToken)
{
// See commentary in ReadTextAsync for why this is implemented this way.
return Task.Factory.StartNew(() => WriteText(text, cancellationToken), cancellationToken, TaskCreationOptions.None, TaskScheduler.Default);
if (_service._workspaceThreadingService is { IsOnMainThread: true })
{
await Task.Yield().ConfigureAwait(false);
cancellationToken.ThrowIfCancellationRequested();
}

WriteText(text, cancellationToken);
}

private static unsafe TextReader CreateTextReaderFromTemporaryStorage(ISupportDirectMemoryAccess accessor, int streamLength)
Expand Down
Loading

0 comments on commit d797dcc

Please sign in to comment.