diff --git a/tests/NRedisStack.Tests/NRedisStack.Tests.csproj b/tests/NRedisStack.Tests/NRedisStack.Tests.csproj index 7aceede5..cb522fcc 100644 --- a/tests/NRedisStack.Tests/NRedisStack.Tests.csproj +++ b/tests/NRedisStack.Tests/NRedisStack.Tests.csproj @@ -31,6 +31,7 @@ + diff --git a/tests/NRedisStack.Tests/RedisFixture.cs b/tests/NRedisStack.Tests/RedisFixture.cs index 0a145e85..9dfdf2eb 100644 --- a/tests/NRedisStack.Tests/RedisFixture.cs +++ b/tests/NRedisStack.Tests/RedisFixture.cs @@ -52,9 +52,12 @@ public class RedisFixture : IDisposable public bool isEnterprise = Environment.GetEnvironmentVariable("IS_ENTERPRISE") == "true"; public bool isOSSCluster; + private ConnectionMultiplexer redis; + private ConfigurationOptions defaultConfig; + public RedisFixture() { - ConfigurationOptions clusterConfig = new ConfigurationOptions + defaultConfig = new ConfigurationOptions { AsyncTimeout = 10000, SyncTimeout = 10000 @@ -93,8 +96,6 @@ public RedisFixture() isOSSCluster = true; } } - - Redis = GetConnectionById(clusterConfig, defaultEndpointId); } public void Dispose() @@ -102,7 +103,14 @@ public void Dispose() Redis.Close(); } - public ConnectionMultiplexer Redis { get; } + public ConnectionMultiplexer Redis + { + get + { + redis = redis ?? GetConnectionById(defaultConfig, defaultEndpointId); + return redis; + } + } public ConnectionMultiplexer GetConnectionById(ConfigurationOptions configurationOptions, string id) { diff --git a/tests/NRedisStack.Tests/SkipIfRedisAttribute.cs b/tests/NRedisStack.Tests/SkipIfRedisAttribute.cs index b62dc17a..eae76c4d 100644 --- a/tests/NRedisStack.Tests/SkipIfRedisAttribute.cs +++ b/tests/NRedisStack.Tests/SkipIfRedisAttribute.cs @@ -21,6 +21,8 @@ public class SkipIfRedisAttribute : FactAttribute private readonly Comparison _comparison; private readonly List _environments = new List(); + private static Version serverVersion = null; + public SkipIfRedisAttribute( Is environment, Comparison comparison = Comparison.LessThan, @@ -95,7 +97,7 @@ public override string? Skip } // Version check (if Is.Standalone/Is.OSSCluster is set then ) - var serverVersion = redisFixture.Redis.GetServer(redisFixture.Redis.GetEndPoints()[0]).Version; + serverVersion = serverVersion ?? redisFixture.Redis.GetServer(redisFixture.Redis.GetEndPoints()[0]).Version; var targetVersion = new Version(_targetVersion); int comparisonResult = serverVersion.CompareTo(targetVersion); diff --git a/tests/NRedisStack.Tests/TargetEnvironmentAttribute.cs b/tests/NRedisStack.Tests/TargetEnvironmentAttribute.cs new file mode 100644 index 00000000..4497aef0 --- /dev/null +++ b/tests/NRedisStack.Tests/TargetEnvironmentAttribute.cs @@ -0,0 +1,36 @@ +using Xunit; + +namespace NRedisStack.Tests; +[AttributeUsage(AttributeTargets.Method, AllowMultiple = false)] +public class TargetEnvironmentAttribute : SkipIfRedisAttribute +{ + private string targetEnv; + public TargetEnvironmentAttribute(string targetEnv) : base(Comparison.LessThan, "0.0.0") + { + this.targetEnv = targetEnv; + } + + public TargetEnvironmentAttribute(string targetEnv, Is environment, Comparison comparison = Comparison.LessThan, + string targetVersion = "0.0.0") : base(environment, comparison, targetVersion) + { + this.targetEnv = targetEnv; + } + + public TargetEnvironmentAttribute(string targetEnv, Is environment1, Is environment2, Comparison comparison = Comparison.LessThan, + string targetVersion = "0.0.0") : base(environment1, environment2, comparison, targetVersion) + { + this.targetEnv = targetEnv; + } + + public override string? Skip + { + get + { + if (!new RedisFixture().IsTargetConnectionExist(targetEnv)) + { + return "Test skipped, because: target environment not found."; + } + return base.Skip; + } + } +} \ No newline at end of file diff --git a/tests/NRedisStack.Tests/TokenBasedAuthentication/AuthenticationTests.cs b/tests/NRedisStack.Tests/TokenBasedAuthentication/AuthenticationTests.cs new file mode 100644 index 00000000..db8842c2 --- /dev/null +++ b/tests/NRedisStack.Tests/TokenBasedAuthentication/AuthenticationTests.cs @@ -0,0 +1,53 @@ +using Xunit; +using StackExchange.Redis; +using Azure.Identity; +using NRedisStack.RedisStackCommands; +using NRedisStack.Search; + +namespace NRedisStack.Tests.TokenBasedAuthentication +{ + public class AuthenticationTests : AbstractNRedisStackTest + { + static readonly string key = "myKey"; + static readonly string value = "myValue"; + static readonly string index = "myIndex"; + static readonly string field = "myField"; + static readonly string alias = "myAlias"; + public AuthenticationTests(RedisFixture redisFixture) : base(redisFixture) { } + + [TargetEnvironment("standalone-entraid-acl")] + public void TestTokenBasedAuthentication() + { + + var configurationOptions = new ConfigurationOptions().ConfigureForAzureWithTokenCredentialAsync(new DefaultAzureCredential()).Result!; + configurationOptions.Ssl = false; + configurationOptions.AbortOnConnectFail = true; // Fail fast for the purposes of this sample. In production code, this should remain false to retry connections on startup + + ConnectionMultiplexer? connectionMultiplexer = redisFixture.GetConnectionById(configurationOptions, "standalone-entraid-acl"); + + IDatabase db = connectionMultiplexer.GetDatabase(); + + db.KeyDelete(key); + try + { + db.FT().DropIndex(index); + } + catch { } + + db.StringSet(key, value); + string result = db.StringGet(key); + Assert.Equal(value, result); + + var ft = db.FT(); + Schema sc = new Schema().AddTextField(field); + Assert.True(ft.Create(index, FTCreateParams.CreateParams(), sc)); + + db.HashSet(index, new HashEntry[] { new HashEntry(field, value) }); + + Assert.True(ft.AliasAdd(alias, index)); + SearchResult res1 = ft.Search(alias, new Query("*").ReturnFields(field)); + Assert.Equal(1, res1.TotalResults); + Assert.Equal(value, res1.Documents[0][field]); + } + } +} \ No newline at end of file diff --git a/tests/NRedisStack.Tests/TokenBasedAuthentication/FaultInjectorClient.cs b/tests/NRedisStack.Tests/TokenBasedAuthentication/FaultInjectorClient.cs new file mode 100644 index 00000000..cf319cf1 --- /dev/null +++ b/tests/NRedisStack.Tests/TokenBasedAuthentication/FaultInjectorClient.cs @@ -0,0 +1,113 @@ +using System.Text; +using System.Text.Json; +using System.Net.Http; + +public class FaultInjectorClient +{ + private static readonly string BASE_URL; + + static FaultInjectorClient() + { + BASE_URL = Environment.GetEnvironmentVariable("FAULT_INJECTION_API_URL") ?? "http://127.0.0.1:20324"; + } + + public class TriggerActionResponse + { + public string ActionId { get; } + private DateTime? LastRequestTime { get; set; } + private DateTime? CompletedAt { get; set; } + private DateTime? FirstRequestAt { get; set; } + + public TriggerActionResponse(string actionId) + { + ActionId = actionId; + } + + public async Task IsCompletedAsync(TimeSpan checkInterval, TimeSpan delayAfter, TimeSpan timeout) + { + if (CompletedAt.HasValue) + { + return DateTime.UtcNow - CompletedAt.Value >= delayAfter; + } + + if (FirstRequestAt.HasValue && DateTime.UtcNow - FirstRequestAt.Value >= timeout) + { + throw new TimeoutException("Timeout"); + } + + if (!LastRequestTime.HasValue || DateTime.UtcNow - LastRequestTime.Value >= checkInterval) + { + LastRequestTime = DateTime.UtcNow; + + if (!FirstRequestAt.HasValue) + { + FirstRequestAt = LastRequestTime; + } + + using var httpClient = GetHttpClient(); + var request = new HttpRequestMessage(HttpMethod.Get, $"{BASE_URL}/action/{ActionId}"); + + try + { + var response = await httpClient.SendAsync(request); + var result = await response.Content.ReadAsStringAsync(); + + + if (result.Contains("success")) + { + CompletedAt = DateTime.UtcNow; + return DateTime.UtcNow - CompletedAt.Value >= delayAfter; + } + } + catch (HttpRequestException e) + { + throw new Exception("Fault injection proxy error", e); + } + } + return false; + } + } + + private static HttpClient GetHttpClient() + { + var httpClient = new HttpClient + { + Timeout = TimeSpan.FromMilliseconds(5000) + }; + return httpClient; + } + + public async Task TriggerActionAsync(string actionType, Dictionary parameters) + { + var payload = new Dictionary + { + { "type", actionType }, + { "parameters", parameters } + }; + + var jsonString = JsonSerializer.Serialize(payload, new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase + }); + + using var httpClient = GetHttpClient(); + var request = new HttpRequestMessage(HttpMethod.Post, $"{BASE_URL}/action") + { + Content = new StringContent(jsonString, Encoding.UTF8, "application/json") + }; + + try + { + var response = await httpClient.SendAsync(request); + var result = await response.Content.ReadAsStringAsync(); + return JsonSerializer.Deserialize(result, new JsonSerializerOptions + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase + }); + } + catch (HttpRequestException e) + { + throw; + } + } +}