Skip to content

Commit

Permalink
Make StreamingRequestHandler methods virtual (#5828) (#5835)
Browse files Browse the repository at this point in the history
* Make methods virtual

* Move package reference
  • Loading branch information
mrivera-ms authored Aug 12, 2021
1 parent 50ddade commit 0286c05
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public StreamingRequestHandler(IBot bot, IStreamingActivityProcessor activityPro
_userAgent = GetUserAgent();
_server = new WebSocketServer(socket, this);
_serverIsConnected = true;
_server.Disconnected += Server_Disconnected;
_server.Disconnected += ServerDisconnected;
}

/// <summary>
Expand Down Expand Up @@ -138,7 +138,7 @@ public StreamingRequestHandler(IBot bot, IStreamingActivityProcessor activityPro
_userAgent = GetUserAgent();
_server = new NamedPipeServer(pipeName, this);
_serverIsConnected = true;
_server.Disconnected += Server_Disconnected;
_server.Disconnected += ServerDisconnected;
}

/// <summary>
Expand All @@ -163,7 +163,7 @@ public StreamingRequestHandler(IBot bot, IStreamingActivityProcessor activityPro
/// Begins listening for incoming requests over this StreamingRequestHandler's server.
/// </summary>
/// <returns>A task that completes once the server is no longer listening.</returns>
public async Task ListenAsync()
public virtual async Task ListenAsync()
{
await _server.StartAsync().ConfigureAwait(false);
_logger.LogInformation("Streaming request handler started listening");
Expand Down Expand Up @@ -380,7 +380,7 @@ public override async Task<StreamingResponse> ProcessRequestAsync(ReceiveRequest
/// <param name="activity">The activity to send.</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns>A task that resolves to a <see cref="ResourceResponse"/>.</returns>
public async Task<ResourceResponse> SendActivityAsync(Activity activity, CancellationToken cancellationToken = default)
public virtual async Task<ResourceResponse> SendActivityAsync(Activity activity, CancellationToken cancellationToken = default)
{
string requestPath;
if (!string.IsNullOrWhiteSpace(activity.ReplyToId) && activity.ReplyToId.Length >= 1)
Expand Down Expand Up @@ -439,6 +439,19 @@ public Task<ReceiveResponse> SendStreamingRequestAsync(StreamingRequest request,
return _server.SendAsync(request, cancellationToken);
}

/// <summary>
/// An event handler for server disconnected events.
/// </summary>
/// <param name="sender">The source of the disconnection event.</param>
/// <param name="e">The arguments specified by the disconnection event.</param>
protected virtual void ServerDisconnected(object sender, DisconnectedEventArgs e)
{
_serverIsConnected = false;

// remove ourselves from the global collection
_requestHandlers.TryRemove(_instanceId, out var _);
}

/// <summary>
/// Build and return versioning information used for telemetry, including:
/// The Schema version is 3.1, put into the Microsoft-BotFramework header,
Expand Down Expand Up @@ -485,14 +498,6 @@ private static IEnumerable<HttpContent> UpdateAttachmentStreams(Activity activit
return null;
}

private void Server_Disconnected(object sender, DisconnectedEventArgs e)
{
_serverIsConnected = false;

// remove ourselves from the global collection
_requestHandlers.TryRemove(_instanceId, out var _);
}

/// <summary>
/// Checks the validity of the request and attempts to map it the correct custom endpoint,
/// then generates and returns a response if appropriate.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

<ItemGroup Condition="'$(TargetFramework)' != 'netcoreapp3.1'">
<PackageReference Include="Microsoft.AspNetCore.App" />
<PackageReference Include="Microsoft.AspNetCore.Http" Version="2.1.1" />
</ItemGroup>

<ItemGroup>
Expand Down
100 changes: 99 additions & 1 deletion tests/Microsoft.Bot.Streaming.Tests/StreamingRequestHandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.Bot.Builder.Integration;
using Microsoft.Bot.Builder.Integration.AspNet.Core;
using Microsoft.Bot.Schema;
using Microsoft.Bot.Streaming;
using Microsoft.Bot.Streaming.Payloads;
using Microsoft.Bot.Streaming.Transport;
using Microsoft.Bot.Streaming.UnitTests.Mocks;
using Microsoft.Extensions.Logging;
using Microsoft.Rest.Serialization;
using Moq;
using Newtonsoft.Json;
using Xunit;

Expand Down Expand Up @@ -324,7 +330,59 @@ public async void ItGetsUserAgentInfo()
// Assert
Assert.Matches(expectation, response.Streams[0].Content.ReadAsStringAsync().Result);
}


[Fact]
public async void CallStreamingRequestHandlerOverrides()
{
var activity = new Activity()
{
Type = "message",
Text = "received from bot",
ServiceUrl = "http://localhost",
ChannelId = "ChannelId",
From = new Schema.ChannelAccount()
{
Id = "bot",
Name = "bot",
},
Conversation = new Schema.ConversationAccount(null, null, Guid.NewGuid().ToString(), null, null, null, null),
};

// Arrange
var headerDictionaryMock = new Mock<IHeaderDictionary>();
headerDictionaryMock.Setup(h => h[It.Is<string>(v => v == "Authorization")]).Returns<string>(null);

var httpRequestMock = new Mock<HttpRequest>();
httpRequestMock.Setup(r => r.Body).Returns(CreateStream(activity));
httpRequestMock.Setup(r => r.Headers).Returns(headerDictionaryMock.Object);
httpRequestMock.Setup(r => r.Method).Returns(HttpMethods.Get);
httpRequestMock.Setup(r => r.HttpContext.WebSockets.IsWebSocketRequest).Returns(true);

var httpResponseMock = new Mock<HttpResponse>();

var botMock = new Mock<IBot>();
botMock.Setup(b => b.OnTurnAsync(It.IsAny<TurnContext>(), It.IsAny<CancellationToken>())).Returns(Task.CompletedTask);

// Act
var methodCalls = new List<string>();
var adapter = new BotFrameworkHttpAdapterSub(methodCalls);
await adapter.ProcessAsync(httpRequestMock.Object, httpResponseMock.Object, botMock.Object);

Assert.Contains("ListenAsync()", methodCalls);
Assert.Contains("ServerDisconnected()", methodCalls);
}

private static Stream CreateStream(Activity activity)
{
string json = SafeJsonConvert.SerializeObject(activity, MessageSerializerSettings.Create());
var stream = new MemoryStream();
var textWriter = new StreamWriter(stream);
textWriter.Write(json);
textWriter.Flush();
stream.Seek(0, SeekOrigin.Begin);
return stream;
}

public class FauxSock : WebSocket
{
public override WebSocketCloseStatus? CloseStatus => throw new NotImplementedException();
Expand Down Expand Up @@ -388,5 +446,45 @@ public FakeContentStream(Guid id, string contentType, Stream stream)

public Stream Stream { get; set; }
}

private class BotFrameworkHttpAdapterSub : BotFrameworkHttpAdapter
{
private List<string> _methodCalls;

public BotFrameworkHttpAdapterSub(List<string> methodCalls)
: base()
{
_methodCalls = methodCalls;
}

public override StreamingRequestHandler CreateStreamingRequestHandler(IBot bot, WebSocket socket, string audience)
{
var socketMock = new Mock<WebSocket>();
return new StreamingRequestHandlerSub(bot, this, socketMock.Object, audience, Logger, _methodCalls);
}
}

private class StreamingRequestHandlerSub : StreamingRequestHandler
{
private List<string> _methodCalls;

public StreamingRequestHandlerSub(IBot bot, IStreamingActivityProcessor activityProcessor, WebSocket socket, string audience, ILogger logger = null, List<string> methodCalls = null)
: base(bot, activityProcessor, socket, audience, logger)
{
_methodCalls = methodCalls;
}

public override async Task ListenAsync()
{
_methodCalls.Add("ListenAsync()");
await base.ListenAsync();
}

protected override void ServerDisconnected(object sender, DisconnectedEventArgs e)
{
_methodCalls.Add("ServerDisconnected()");
base.ServerDisconnected(sender, e);
}
}
}
}

0 comments on commit 0286c05

Please sign in to comment.