Skip to content

Commit

Permalink
feat: optimize CachedRequestBuilder (#1716)
Browse files Browse the repository at this point in the history
Co-authored-by: Chris Pulman <chris.pulman@yahoo.com>
  • Loading branch information
TimothyMakkison and ChrisPulman authored Jun 25, 2024
1 parent 107d716 commit b320e4e
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 24 deletions.
148 changes: 148 additions & 0 deletions Refit.Tests/CachedRequestBuilder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
using System.Net;
using System.Net.Http;
using System.Reflection;

using RichardSzalay.MockHttp;

using Xunit;

namespace Refit.Tests;

public interface IGeneralRequests
{
[Post("/foo")]
Task Empty();

[Post("/foo")]
Task SingleParameter(string id);

[Post("/foo")]
Task MultiParameter(string id, string name);

[Post("/foo")]
Task SingleGenericMultiParameter<TValue>(string id, string name, TValue generic);
}

public interface IDuplicateNames
{
[Post("/foo")]
Task SingleParameter(string id);

[Post("/foo")]
Task SingleParameter(int id);
}

public class CachedRequestBuilderTests
{
[Fact]
public async Task CacheHasCorrectNumberOfElementsTest()
{
var mockHttp = new MockHttpMessageHandler();
var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp };

var fixture = RestService.For<IGeneralRequests>("http://bar", settings);

// get internal dictionary to check count
var requestBuilderField = fixture.GetType().GetFields(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public).Single(x => x.Name == "requestBuilder");
var requestBuilder = requestBuilderField.GetValue(fixture) as CachedRequestBuilderImplementation;

mockHttp
.Expect(HttpMethod.Post, "http://bar/foo")
.Respond(HttpStatusCode.OK);
await fixture.Empty();
Assert.Single(requestBuilder.MethodDictionary);

mockHttp
.Expect(HttpMethod.Post, "http://bar/foo")
.WithQueryString("id", "id")
.Respond(HttpStatusCode.OK);
await fixture.SingleParameter("id");
Assert.Equal(2, requestBuilder.MethodDictionary.Count);

mockHttp
.Expect(HttpMethod.Post, "http://bar/foo")
.WithQueryString("id", "id")
.WithQueryString("name", "name")
.Respond(HttpStatusCode.OK);
await fixture.MultiParameter("id", "name");
Assert.Equal(3, requestBuilder.MethodDictionary.Count);

mockHttp
.Expect(HttpMethod.Post, "http://bar/foo")
.WithQueryString("id", "id")
.WithQueryString("name", "name")
.WithQueryString("generic", "generic")
.Respond(HttpStatusCode.OK);
await fixture.SingleGenericMultiParameter("id", "name", "generic");
Assert.Equal(4, requestBuilder.MethodDictionary.Count);

mockHttp.VerifyNoOutstandingExpectation();
}

[Fact]
public async Task NoDuplicateEntriesTest()
{
var mockHttp = new MockHttpMessageHandler();
var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp };

var fixture = RestService.For<IGeneralRequests>("http://bar", settings);

// get internal dictionary to check count
var requestBuilderField = fixture.GetType().GetFields(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public).Single(x => x.Name == "requestBuilder");
var requestBuilder = requestBuilderField.GetValue(fixture) as CachedRequestBuilderImplementation;

// send the same request repeatedly to ensure that multiple dictionary entries are not created
mockHttp
.Expect(HttpMethod.Post, "http://bar/foo")
.WithQueryString("id", "id")
.Respond(HttpStatusCode.OK);
await fixture.SingleParameter("id");
Assert.Single(requestBuilder.MethodDictionary);

mockHttp
.Expect(HttpMethod.Post, "http://bar/foo")
.WithQueryString("id", "id")
.Respond(HttpStatusCode.OK);
await fixture.SingleParameter("id");
Assert.Single(requestBuilder.MethodDictionary);

mockHttp
.Expect(HttpMethod.Post, "http://bar/foo")
.WithQueryString("id", "id")
.Respond(HttpStatusCode.OK);
await fixture.SingleParameter("id");
Assert.Single(requestBuilder.MethodDictionary);

mockHttp.VerifyNoOutstandingExpectation();
}

[Fact]
public async Task SameNameDuplicateEntriesTest()
{
var mockHttp = new MockHttpMessageHandler();
var settings = new RefitSettings { HttpMessageHandlerFactory = () => mockHttp };

var fixture = RestService.For<IDuplicateNames>("http://bar", settings);

// get internal dictionary to check count
var requestBuilderField = fixture.GetType().GetFields(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public).Single(x => x.Name == "requestBuilder");
var requestBuilder = requestBuilderField.GetValue(fixture) as CachedRequestBuilderImplementation;

// send the two different requests with the same name
mockHttp
.Expect(HttpMethod.Post, "http://bar/foo")
.WithQueryString("id", "id")
.Respond(HttpStatusCode.OK);
await fixture.SingleParameter("id");
Assert.Single(requestBuilder.MethodDictionary);

mockHttp
.Expect(HttpMethod.Post, "http://bar/foo")
.WithQueryString("id", "10")
.Respond(HttpStatusCode.OK);
await fixture.SingleParameter(10);
Assert.Equal(2, requestBuilder.MethodDictionary.Count);

mockHttp.VerifyNoOutstandingExpectation();
}
}
108 changes: 84 additions & 24 deletions Refit/CachedRequestBuilderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,33 @@ public CachedRequestBuilderImplementation(IRequestBuilder innerBuilder)
}

readonly IRequestBuilder innerBuilder;
readonly ConcurrentDictionary<
string,
internal readonly ConcurrentDictionary<
MethodTableKey,
Func<HttpClient, object[], object?>
> methodDictionary = new();
> MethodDictionary = new();

public Func<HttpClient, object[], object?> BuildRestResultFuncForMethod(
string methodName,
Type[]? parameterTypes = null,
Type[]? genericArgumentTypes = null
)
{
var cacheKey = GetCacheKey(
var cacheKey = new MethodTableKey(
methodName,
parameterTypes ?? Array.Empty<Type>(),
genericArgumentTypes ?? Array.Empty<Type>()
);
var func = methodDictionary.GetOrAdd(
cacheKey,

if (MethodDictionary.TryGetValue(cacheKey, out var methodFunc))
{
return methodFunc;
}

// use GetOrAdd with cloned array method table key. This prevents the array from being modified, breaking the dictionary.
var func = MethodDictionary.GetOrAdd(
new MethodTableKey(methodName,
parameterTypes?.ToArray() ?? Array.Empty<Type>(),
genericArgumentTypes?.ToArray() ?? Array.Empty<Type>()),
_ =>
innerBuilder.BuildRestResultFuncForMethod(
methodName,
Expand All @@ -48,37 +57,88 @@ readonly ConcurrentDictionary<

return func;
}
}

static string GetCacheKey(
string methodName,
Type[] parameterTypes,
Type[] genericArgumentTypes
)
/// <summary>
/// Represents a method composed of its name, generic arguments and parameters.
/// </summary>
internal readonly struct MethodTableKey : IEquatable<MethodTableKey>
{
/// <summary>
/// Constructs an instance of <see cref="MethodTableKey"/>.
/// </summary>
/// <param name="methodName">Represents the methods name.</param>
/// <param name="parameters">Array containing the methods parameters.</param>
/// <param name="genericArguments">Array containing the methods generic arguments.</param>
public MethodTableKey (string methodName, Type[] parameters, Type[] genericArguments)
{
var genericDefinition = GetGenericString(genericArgumentTypes);
var argumentString = GetArgumentString(parameterTypes);

return $"{methodName}{genericDefinition}({argumentString})";
MethodName = methodName;
Parameters = parameters;
GenericArguments = genericArguments;
}

static string GetArgumentString(Type[] parameterTypes)
/// <summary>
/// The methods name.
/// </summary>
string MethodName { get; }

/// <summary>
/// Array containing the methods parameters.
/// </summary>
Type[] Parameters { get; }

/// <summary>
/// Array containing the methods generic arguments.
/// </summary>
Type[] GenericArguments { get; }

public override int GetHashCode()
{
if (parameterTypes == null || parameterTypes.Length == 0)
unchecked
{
return "";
}
var hashCode = MethodName.GetHashCode();

foreach (var argument in Parameters)
{
hashCode = (hashCode * 397) ^ argument.GetHashCode();
}

return string.Join(", ", parameterTypes.Select(t => t.FullName));
foreach (var genericArgument in GenericArguments)
{
hashCode = (hashCode * 397) ^ genericArgument.GetHashCode();
}
return hashCode;
}
}

static string GetGenericString(Type[] genericArgumentTypes)
public bool Equals(MethodTableKey other)
{
if (genericArgumentTypes == null || genericArgumentTypes.Length == 0)
if (Parameters.Length != other.Parameters.Length
|| GenericArguments.Length != other.GenericArguments.Length
|| MethodName != other.MethodName)
{
return "";
return false;
}

return "<" + string.Join(", ", genericArgumentTypes.Select(t => t.FullName)) + ">";
for (var i = 0; i < Parameters.Length; i++)
{
if (Parameters[i] != other.Parameters[i])
{
return false;
}
}

for (var i = 0; i < GenericArguments.Length; i++)
{
if (GenericArguments[i] != other.GenericArguments[i])
{
return false;
}
}

return true;
}

public override bool Equals(object? obj) => obj is MethodTableKey other && Equals(other);
}
}

0 comments on commit b320e4e

Please sign in to comment.