Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/libraries/System.Net.Http/ref/System.Net.Http.cs
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ public partial class MultipartContent : System.Net.Http.HttpContent, System.Coll
public MultipartContent() { }
public MultipartContent(string subtype) { }
public MultipartContent(string subtype, string boundary) { }
public System.Net.Http.HeaderEncodingSelector<System.Net.Http.HttpContent>? HeaderEncodingSelector { get { throw null; } set { } }
public virtual void Add(System.Net.Http.HttpContent content) { }
protected override System.IO.Stream CreateContentReadStream(System.Threading.CancellationToken cancellationToken) { throw null; }
protected override System.Threading.Tasks.Task<System.IO.Stream> CreateContentReadStreamAsync() { throw null; }
Expand Down
153 changes: 90 additions & 63 deletions src/libraries/System.Net.Http/src/System/Net/Http/MultipartContent.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
Expand All @@ -18,10 +19,10 @@ public class MultipartContent : HttpContent, IEnumerable<HttpContent>

private const string CrLf = "\r\n";

private static readonly int s_crlfLength = GetEncodedLength(CrLf);
private static readonly int s_dashDashLength = GetEncodedLength("--");
private static readonly int s_colonSpaceLength = GetEncodedLength(": ");
private static readonly int s_commaSpaceLength = GetEncodedLength(", ");
private const int CrLfLength = 2;
private const int DashDashLength = 2;
private const int ColonSpaceLength = 2;
private const int CommaSpaceLength = 2;

private readonly List<HttpContent> _nestedContent;
private readonly string _boundary;
Expand Down Expand Up @@ -157,6 +158,12 @@ Collections.IEnumerator Collections.IEnumerable.GetEnumerator()

#region Serialization

/// <summary>
/// Gets or sets a callback that returns the <see cref="Encoding"/> to decode the value for the specified response header name,
/// or <see langword="null"/> to use the default behavior.
/// </summary>
public HeaderEncodingSelector<HttpContent>? HeaderEncodingSelector { get; set; }

// for-each content
// write "--" + boundary
// for-each content header
Expand All @@ -171,20 +178,19 @@ protected override void SerializeToStream(Stream stream, TransportContext? conte
try
{
// Write start boundary.
EncodeStringToStream(stream, "--" + _boundary + CrLf);
WriteToStream(stream, "--" + _boundary + CrLf);

// Write each nested content.
var output = new StringBuilder();
for (int contentIndex = 0; contentIndex < _nestedContent.Count; contentIndex++)
{
// Write divider, headers, and content.
HttpContent content = _nestedContent[contentIndex];
EncodeStringToStream(stream, SerializeHeadersToString(output, contentIndex, content));
SerializeHeadersToStream(stream, content, writeDivider: contentIndex != 0);
content.CopyTo(stream, context, cancellationToken);
}

// Write footer boundary.
EncodeStringToStream(stream, CrLf + "--" + _boundary + "--" + CrLf);
WriteToStream(stream, CrLf + "--" + _boundary + "--" + CrLf);
}
catch (Exception ex)
{
Expand Down Expand Up @@ -219,12 +225,17 @@ private protected async Task SerializeToStreamAsyncCore(Stream stream, Transport
await EncodeStringToStreamAsync(stream, "--" + _boundary + CrLf, cancellationToken).ConfigureAwait(false);

// Write each nested content.
var output = new StringBuilder();
var output = new MemoryStream();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we write directly into stream to prevent copying?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can at the cost of doubling the SerializeHeaders logic

for (int contentIndex = 0; contentIndex < _nestedContent.Count; contentIndex++)
{
// Write divider, headers, and content.
HttpContent content = _nestedContent[contentIndex];
await EncodeStringToStreamAsync(stream, SerializeHeadersToString(output, contentIndex, content), cancellationToken).ConfigureAwait(false);

output.SetLength(0);
SerializeHeadersToStream(output, content, writeDivider: contentIndex != 0);
output.Position = 0;
await output.CopyToAsync(stream, cancellationToken).ConfigureAwait(false);

await content.CopyToAsync(stream, context, cancellationToken).ConfigureAwait(false);
}

Expand Down Expand Up @@ -259,7 +270,6 @@ private async ValueTask<Stream> CreateContentReadStreamAsyncCore(bool async, Can
try
{
var streams = new Stream[2 + (_nestedContent.Count * 2)];
var scratch = new StringBuilder();
int streamIndex = 0;

// Start boundary.
Expand All @@ -271,7 +281,7 @@ private async ValueTask<Stream> CreateContentReadStreamAsyncCore(bool async, Can
cancellationToken.ThrowIfCancellationRequested();

HttpContent nestedContent = _nestedContent[contentIndex];
streams[streamIndex++] = EncodeStringToNewStream(SerializeHeadersToString(scratch, contentIndex, nestedContent));
streams[streamIndex++] = EncodeHeadersToNewStream(nestedContent, writeDivider: contentIndex != 0);

Stream readStream;
if (async)
Expand Down Expand Up @@ -312,43 +322,35 @@ private async ValueTask<Stream> CreateContentReadStreamAsyncCore(bool async, Can
}
}

private string SerializeHeadersToString(StringBuilder scratch, int contentIndex, HttpContent content)
private void SerializeHeadersToStream(Stream stream, HttpContent content, bool writeDivider)
{
scratch.Clear();

// Add divider.
if (contentIndex != 0) // Write divider for all but the first content.
if (writeDivider) // Write divider for all but the first content.
{
scratch.Append(CrLf + "--"); // const strings
scratch.Append(_boundary);
scratch.Append(CrLf);
WriteToStream(stream, CrLf + "--"); // const strings
WriteToStream(stream, _boundary);
WriteToStream(stream, CrLf);
}

// Add headers.
foreach (KeyValuePair<string, IEnumerable<string>> headerPair in content.Headers)
{
scratch.Append(headerPair.Key);
scratch.Append(": ");
Encoding headerValueEncoding = HeaderEncodingSelector?.Invoke(headerPair.Key, content) ?? HttpRuleParser.DefaultHttpEncoding;

WriteToStream(stream, headerPair.Key);
WriteToStream(stream, ": ");
string delim = string.Empty;
foreach (string value in headerPair.Value)
{
scratch.Append(delim);
scratch.Append(value);
WriteToStream(stream, delim);
WriteToStream(stream, value, headerValueEncoding);
delim = ", ";
}
scratch.Append(CrLf);
WriteToStream(stream, CrLf);
}

// Extra CRLF to end headers (even if there are no headers).
scratch.Append(CrLf);

return scratch.ToString();
}

private static void EncodeStringToStream(Stream stream, string input)
{
byte[] buffer = HttpRuleParser.DefaultHttpEncoding.GetBytes(input);
stream.Write(buffer);
WriteToStream(stream, CrLf);
}

private static ValueTask EncodeStringToStreamAsync(Stream stream, string input, CancellationToken cancellationToken)
Expand All @@ -362,55 +364,55 @@ private static Stream EncodeStringToNewStream(string input)
return new MemoryStream(HttpRuleParser.DefaultHttpEncoding.GetBytes(input), writable: false);
}

private Stream EncodeHeadersToNewStream(HttpContent content, bool writeDivider)
{
var stream = new MemoryStream();
SerializeHeadersToStream(stream, content, writeDivider);
stream.Position = 0;
return stream;
}

internal override bool AllowDuplex => false;

protected internal override bool TryComputeLength(out long length)
{
int boundaryLength = GetEncodedLength(_boundary);

long currentLength = 0;
long internalBoundaryLength = s_crlfLength + s_dashDashLength + boundaryLength + s_crlfLength;

// Start Boundary.
currentLength += s_dashDashLength + boundaryLength + s_crlfLength;
long currentLength = DashDashLength + _boundary.Length + CrLfLength;

bool first = true;
foreach (HttpContent content in _nestedContent)
if (_nestedContent.Count > 1)
{
if (first)
{
first = false; // First boundary already written.
}
else
{
// Internal Boundary.
currentLength += internalBoundaryLength;
}
// Internal boundaries
currentLength += (_nestedContent.Count - 1) * (CrLfLength + DashDashLength + _boundary.Length + CrLfLength);
}

foreach (HttpContent content in _nestedContent)
{
// Headers.
foreach (KeyValuePair<string, IEnumerable<string>> headerPair in content.Headers)
{
currentLength += GetEncodedLength(headerPair.Key) + s_colonSpaceLength;
currentLength += headerPair.Key.Length + ColonSpaceLength;

Encoding headerValueEncoding = HeaderEncodingSelector?.Invoke(headerPair.Key, content) ?? HttpRuleParser.DefaultHttpEncoding;

int valueCount = 0;
foreach (string value in headerPair.Value)
{
currentLength += GetEncodedLength(value);
currentLength += headerValueEncoding.GetByteCount(value);
valueCount++;
}

if (valueCount > 1)
{
currentLength += (valueCount - 1) * s_commaSpaceLength;
currentLength += (valueCount - 1) * CommaSpaceLength;
}

currentLength += s_crlfLength;
currentLength += CrLfLength;
}

currentLength += s_crlfLength;
currentLength += CrLfLength;

// Content.
long tempContentLength = 0;
if (!content.TryComputeLength(out tempContentLength))
if (!content.TryComputeLength(out long tempContentLength))
{
length = 0;
return false;
Expand All @@ -419,17 +421,12 @@ protected internal override bool TryComputeLength(out long length)
}

// Terminating boundary.
currentLength += s_crlfLength + s_dashDashLength + boundaryLength + s_dashDashLength + s_crlfLength;
currentLength += CrLfLength + DashDashLength + _boundary.Length + DashDashLength + CrLfLength;

length = currentLength;
return true;
}

private static int GetEncodedLength(string input)
{
return HttpRuleParser.DefaultHttpEncoding.GetByteCount(input);
}

private sealed class ContentReadStream : Stream
{
private readonly Stream[] _streams;
Expand Down Expand Up @@ -671,6 +668,36 @@ public override void Flush() { }
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { throw new NotSupportedException(); }
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default) { throw new NotSupportedException(); }
}


private static void WriteToStream(Stream stream, string content) =>
WriteToStream(stream, content, HttpRuleParser.DefaultHttpEncoding);

private static void WriteToStream(Stream stream, string content, Encoding encoding)
{
const int StackallocThreshold = 1024;

int maxLength = encoding.GetMaxByteCount(content.Length);

byte[]? rentedBuffer = null;
Span<byte> buffer = maxLength <= StackallocThreshold
? stackalloc byte[StackallocThreshold]
: (rentedBuffer = ArrayPool<byte>.Shared.Rent(maxLength));

try
{
int written = encoding.GetBytes(content, buffer);
stream.Write(buffer.Slice(0, written));
}
finally
{
if (rentedBuffer != null)
{
ArrayPool<byte>.Shared.Return(rentedBuffer);
}
}
}

#endregion Serialization
}
}
Loading