Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix reply ordering per endpoint #224

Merged
merged 1 commit into from
Jan 28, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 109 additions & 73 deletions UsbIpServer/AttachedClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
// SPDX-License-Identifier: GPL-2.0-only

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Windows.Win32;
Expand Down Expand Up @@ -39,13 +42,43 @@ public AttachedClient(ILogger<AttachedClient> logger, ClientContext clientContex
readonly ILogger Logger;
readonly ClientContext ClientContext;
readonly NetworkStream Stream;
readonly Channel<byte[]> ReplyChannel = Channel.CreateUnbounded<byte[]>();
readonly DeviceFile Device;
readonly SemaphoreSlim WriteMutex = new(1);
readonly object PendingSubmitsLock = new();

/// <summary>
/// Mapping from USBIP seqnum to raw USB endpoint number.
/// Used for UNLINK.
/// </summary>
readonly ConcurrentDictionary<uint, byte> PendingSubmits = new();

/// <summary>
/// Mapping from endpoint to its channel for ordered replies.
/// </summary>
readonly SortedDictionary<uint, byte> PendingSubmits = new();
readonly ConcurrentDictionary<byte, ChannelWriter<Task<byte[]>>> EndpointChannels = new();

/// <summary>
/// Returns the channel writer for the given endpoint.
/// </summary>
ChannelWriter<Task<byte[]>> GetEndpointWriter(byte endpoint, CancellationToken cancellationToken)
{
return EndpointChannels.GetOrAdd(endpoint, (_) =>
{
var channel = Channel.CreateUnbounded<Task<byte[]>>();
Task.Run(async () =>
{
// This task ensures that all replies for this specific endpoint are
// returned in the same order as the requests.
while (!cancellationToken.IsCancellationRequested)
{
var nextEndpointTask = await channel.Reader.ReadAsync(cancellationToken);
var nextEndpointReply = await nextEndpointTask;
// This multiplexes the replies for this endpoint with the other endpoints.
await ReplyChannel.Writer.WriteAsync(nextEndpointReply);
}
}, cancellationToken);
return channel.Writer;
});
}

async Task HandleSubmitIsochronousAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdSubmit submit, CancellationToken cancellationToken)
{
Expand All @@ -70,16 +103,13 @@ async Task HandleSubmitIsochronousAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdSu

// Everything has been read and validated, now process...

lock (PendingSubmitsLock)
// To support UNLINK, we must be able to abort the pipe that is used for this URB.
// We need the raw USB endpoint number, i.e. including the high bit for input pipes.
if (!PendingSubmits.TryAdd(basic.seqnum, (byte)(basic.ep | (basic.direction == UsbIpDir.USBIP_DIR_IN ? 0x80u : 0x00u))))
{
// To support UNLINK, we must be able to abort the pipe that is used for this URB.
// We need the raw USB endpoint number, i.e. including the high bit for input pipes.
if (!PendingSubmits.TryAdd(basic.seqnum, (byte)(basic.ep | (basic.direction == UsbIpDir.USBIP_DIR_IN ? 0x80u : 0x00u))))
{
throw new ProtocolViolationException($"duplicate sequence number {basic.seqnum}");
}
Logger.Trace($"Scheduled seqnum={basic.seqnum}, pending count = {PendingSubmits.Count}");
throw new ProtocolViolationException($"duplicate sequence number {basic.seqnum}");
}
Logger.Trace($"Scheduled seqnum={basic.seqnum}, pending count = {PendingSubmits.Count}");

// VBoxUSB only excepts up to 8 iso packets per ioctl, so we may have to split
// the request into multiple ioctls.
Expand Down Expand Up @@ -139,22 +169,10 @@ async Task HandleSubmitIsochronousAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdSu
}

// Continue when all ioctls *and* their continuations have been completed.
_ = Task.WhenAll(ioctls).ContinueWith(async (task, state) =>
var replyTask = Task.WhenAll(ioctls).ContinueWith(byte[] (task, _) =>
{
using var writeLock = await Lock.CreateAsync(WriteMutex, cancellationToken);

// Now we are synchronous with the sender.

lock (PendingSubmitsLock)
{
// We are racing with possible UNLINK commands.
if (!PendingSubmits.Remove(basic.seqnum))
{
// Apparently, the client has already UNLINK-ed (canceled) the request; we're done.
Logger.Trace($"Completed seqnum={basic.seqnum} after UNLINK, pending count = {PendingSubmits.Count}");
return;
}
}
// The pending request is now completed; no need to support UNLINK any longer.
PendingSubmits.TryRemove(basic.seqnum, out var _);

var header = new UsbIpHeader
{
Expand Down Expand Up @@ -190,13 +208,25 @@ async Task HandleSubmitIsochronousAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdSu
}
}

await Stream.WriteAsync(header.ToBytes(), cancellationToken);
using var replyStream = new MemoryStream();
replyStream.Write(header.ToBytes());
if (basic.direction == UsbIpDir.USBIP_DIR_IN)
{
await Stream.WriteAsync(retBuf, cancellationToken);
replyStream.Write(retBuf);
}
await Stream.WriteAsync(packetDescriptors.ToBytes(), cancellationToken);
replyStream.Write(packetDescriptors.ToBytes());
return replyStream.ToArray();
}, cancellationToken, TaskScheduler.Default);

// Now we queue the task that creates the response, so that all replies for a single
// endpoint remain ordered.

var endpointWriter = GetEndpointWriter((byte)basic.ep, cancellationToken);
await endpointWriter.WriteAsync(replyTask, cancellationToken);

// We return to the caller, so that the next request can be handled. As a result, multiple requests
// can be outstanding (either for the same, or for multiple endpoints). Requests are completed
// asynchronously and in any order, but the replies for each endpoint are sent to the client in original order.
}
finally
{
Expand All @@ -209,8 +239,6 @@ async Task HandleSubmitIsochronousAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdSu

async Task HandleSubmitAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdSubmit submit, CancellationToken cancellationToken)
{
// We are synchronous with the receiver.

if (submit.number_of_packets != 0)
{
await HandleSubmitIsochronousAsync(basic, submit, cancellationToken);
Expand Down Expand Up @@ -257,7 +285,7 @@ async Task HandleSubmitAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdSubmit submit
// This means no further requests will be read until the special request has completed.
// - Otherwise, we will start a new task so that the receiver can continue.
// This means multiple URBs can be outstanding awaiting completion.
// The pending URBs can be completed out of order, but the replies must be sent atomically.
// The pending URBs can be completed out of order, but for each endpoint the replies must be sent in order.

Task ioctl;
var pending = false;
Expand Down Expand Up @@ -309,17 +337,15 @@ async Task HandleSubmitAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdSubmit submit
{
Logger.Trace($"{submit.setup.bmRequestType.B} {submit.setup.bRequest} {submit.setup.wValue.W} {submit.setup.wIndex.W} {submit.setup.wLength}");
}
lock (PendingSubmitsLock)

// To support UNLINK, we must be able to abort the pipe that is used for this URB.
// We need the raw USB endpoint number, i.e. including the high bit for input pipes.
if (!PendingSubmits.TryAdd(basic.seqnum, (byte)(basic.ep | (basic.direction == UsbIpDir.USBIP_DIR_IN ? 0x80u : 0x00u))))
{
// To support UNLINK, we must be able to abort the pipe that is used for this URB.
// We need the raw USB endpoint number, i.e. including the high bit for input pipes.
if (!PendingSubmits.TryAdd(basic.seqnum, (byte)(basic.ep | (basic.direction == UsbIpDir.USBIP_DIR_IN ? 0x80u : 0x00u))))
{
throw new ProtocolViolationException($"duplicate sequence number {basic.seqnum}");
}
Logger.Trace($"Scheduled seqnum={basic.seqnum}, pending count = {PendingSubmits.Count}");
throw new ProtocolViolationException($"duplicate sequence number {basic.seqnum}");
}
pending = true;

// Input or output, exceptions or not, this buffer must be locked until after the ioctl has completed.
var gcHandle = GCHandle.Alloc(buf, GCHandleType.Pinned);
try
Expand All @@ -340,27 +366,14 @@ async Task HandleSubmitAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdSubmit submit
}

// At this point we have initiated the ioctl (and possibly awaited it for special cases).
// Now we schedule a continuation to write the response once the ioctl completes.
// This is fire-and-forget; we'll return to the caller so it can already receive the next request.
// Now we schedule a continuation to create the response once the ioctl completes.

_ = ioctl.ContinueWith(async (task, state) =>
var replyTask = ioctl.ContinueWith(byte[] (task, _) =>
{
using var writeLock = await Lock.CreateAsync(WriteMutex, cancellationToken);

// Now we are synchronous with the sender.

if (pending)
{
lock (PendingSubmitsLock)
{
// We are racing with possible UNLINK commands.
if (!PendingSubmits.Remove(basic.seqnum))
{
// Apparently, the client has already UNLINK-ed (canceled) the request; we're done.
Logger.Trace($"Completed seqnum={basic.seqnum} after UNLINK, pending count = {PendingSubmits.Count}");
return;
}
}
// The pending request is now completed; no need to support UNLINK any longer.
PendingSubmits.Remove(basic.seqnum, out var _);
BytesToStruct(bytes, out urb);
}

Expand Down Expand Up @@ -413,26 +426,32 @@ async Task HandleSubmitAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdSubmit submit
}
Logger.Trace($"actual: {header.ret_submit.actual_length}, requested: {requestLength}");

await Stream.WriteAsync(header.ToBytes(), cancellationToken);
using var replyStream = new MemoryStream();
replyStream.Write(header.ToBytes());
if (basic.direction == UsbIpDir.USBIP_DIR_IN)
{
await Stream.WriteAsync(buf.AsMemory(payloadOffset, header.ret_submit.actual_length), cancellationToken);
replyStream.Write(buf.AsSpan(payloadOffset, header.ret_submit.actual_length));
}
return replyStream.ToArray();
}, cancellationToken, TaskScheduler.Default);

// Now we queue the task that creates the response, so that all replies for a single
// endpoint remain ordered.

var endpointWriter = GetEndpointWriter((byte)basic.ep, cancellationToken);
await endpointWriter.WriteAsync(replyTask, cancellationToken);

// We return to the caller, so that the next request can be handled. As a result, multiple requests
// can be outstanding (either for the same, or for multiple endpoints). Requests are completed
// asynchronously and in any order, but the replies for each endpoint are sent to the client in original order.
}

async Task HandleUnlinkAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdUnlink unlink, CancellationToken cancellationToken)
{
// We are synchronous with the receiver.

var pending = false;
byte endpoint;

lock (PendingSubmitsLock)
{
pending = PendingSubmits.Remove(unlink.seqnum, out endpoint);
Logger.Trace($"Unlinking {unlink.seqnum}, pending = {pending}, pending count = {PendingSubmits.Count}");
}
var pending = PendingSubmits.TryGetValue(unlink.seqnum, out var endpoint);
Logger.Trace($"Unlinking {unlink.seqnum}, pending = {pending}, pending count = {PendingSubmits.Count}");

if (pending)
{
Expand All @@ -448,10 +467,6 @@ async Task HandleUnlinkAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdUnlink unlink
await Device.IoControlAsync(SUPUSB_IOCTL.USB_ABORT_ENDPOINT, StructToBytes(clearEndpoint), null);
}

using var writeLock = await Lock.CreateAsync(WriteMutex, cancellationToken);

// Now we are synchronous with the sender.

var header = new UsbIpHeader
{
basic = new()
Expand All @@ -461,15 +476,36 @@ async Task HandleUnlinkAsync(UsbIpHeaderBasic basic, UsbIpHeaderCmdUnlink unlink
},
ret_submit = new()
{
status = -(int)(pending ? Errno.ECONNRESET : Errno.SUCCESS),
status = -(int)Errno.SUCCESS,
},
};

await Stream.WriteAsync(header.ToBytes(), cancellationToken);
if (pending)
{
// We need to queue this on the same endpoint that the UNLINK was for, such
// that the reply of the aborted request gets sent first.
var endpointWriter = GetEndpointWriter((byte)(endpoint & 0x7f), cancellationToken);
await endpointWriter.WriteAsync(Task.FromResult(header.ToBytes()), cancellationToken);
}
else
{
// We didn't actually need to abort anything, so we can write the reply immediately.
await ReplyChannel.Writer.WriteAsync(header.ToBytes(), cancellationToken);
}
}

public async Task RunAsync(CancellationToken cancellationToken)
{
_ = Task.Run(async () =>
{
// This task multiplexes all the replies.
while (!cancellationToken.IsCancellationRequested)
{
var nextReply = await ReplyChannel.Reader.ReadAsync(cancellationToken);
await Stream.WriteAsync(nextReply, cancellationToken);
}
}, cancellationToken);

while (true)
{
var header = await Stream.ReadUsbIpHeaderAsync(cancellationToken);
Expand Down