diff --git a/tests/EdgeDB.Tests.Unit/SharedClientTests.cs b/tests/EdgeDB.Tests.Unit/SharedClientTests.cs new file mode 100644 index 00000000..18d8f383 --- /dev/null +++ b/tests/EdgeDB.Tests.Unit/SharedClientTests.cs @@ -0,0 +1,713 @@ +using EdgeDB.Abstractions; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Security.Cryptography; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace EdgeDB.Tests.Unit; + +[TestClass] +public class SharedClientTests +{ + [TestMethod] + public void TestConnectParams() + { + StreamReader reader = new("shared-client-testcases/connection_testcases.json"); + List<TestCase>? testcases = JsonSerializer.Deserialize<List<TestCase>>(reader.ReadToEnd()); + if (testcases is null) + { + throw new JsonException("Failed to read 'connection_testcases.json.\n" + + "Is the 'shared-client-testcases' submodule initialised? " + + "Try running 'git submodule update --init'."); + } + + foreach ((int textIndex, TestCase testCase) in testcases.Select((x, i) => (i, x))) + { + Console.WriteLine(testCase.Name); + + if (testCase.FileSystem is not null + && ( + !(testCase.Platform is null && RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) || + !(testCase.Platform == "windows" && !RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) || + !(testCase.Platform == "macos" && RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + )) + { + // skipping unsupported platform test + continue; + } + + if ((testCase.Result is null) == (testCase.Error is null)) + { + throw new Exception("invalid test case: either \"result\" or \"error\" key has to be specified"); + } + + EdgeDBConnection? expectedConnection = GetExpectedConnection(testCase); + Exception? expectedException = testCase.Error?.GetExpectedException(); + Result result = ParseConnection(testCase); + + if (expectedConnection is not null) + { + Expect(result, expectedConnection); + } + else if (expectedException is not null) + { + ExpectException(result, expectedException); + } + } + + Assert.IsTrue(false); + } + + private static void Expect(Result result, EdgeDBConnection expected) + { + if (result.Exception is not null) + { + Console.WriteLine($" {result.Exception.Message}"); + } + Assert.IsNotNull(result.Connection); + var actual = result.Connection; + + Assert.AreEqual(expected.Username, actual.Username); + Assert.AreEqual(expected.Password, actual.Password); + static string ResolveHostname(string hostname) + { + return hostname == "localhost" ? "127.0.0.1" : hostname; + } + Assert.AreEqual(ResolveHostname(expected.Hostname), ResolveHostname(actual.Hostname)); + Assert.AreEqual(expected.Port, actual.Port); + Assert.AreEqual(expected.Database, actual.Database); + Assert.AreEqual(expected.TLSCertificateAuthority, actual.TLSCertificateAuthority); + Assert.AreEqual(expected.TLSServerName, actual.TLSServerName); + Assert.AreEqual(expected.TLSSecurity, actual.TLSSecurity); + CollectionAssert.AreEqual(expected.ServerSettings, actual.ServerSettings); + } + + private static void ExpectException(Result result, Exception expected) + { + Assert.IsNotNull(result.Exception); + Assert.IsInstanceOfType(result.Exception, expected.GetType(), + $"Exception type {expected.GetType()} expected but got {result.Exception.GetType()}"); + Assert.AreEqual(expected.Message, result.Exception.Message); + } + + private static Result ParseConnection(TestCase testCase) + { + try + { + MockSystemProvider mockSystem = new(testCase); + + EdgeDBConnection connection = EdgeDBConnection._Parse( + instance: null, + dsn: testCase?.Options?.Instance ?? testCase?.Options?.Dsn, + configure: x => { + TestCase.Credentials? credentials = + JsonSerializer.Deserialize<TestCase.Credentials>( + testCase?.Options?.CredentialsFile is not null ? + mockSystem.FileReadAllText(testCase.Options.CredentialsFile) : + testCase?.Options?.Credentials is not null ? + testCase.Options.Credentials : + "{}" + ); + + if (testCase?.Options?.Host is not null) x.Hostname = testCase.Options.Host; + if (testCase?.Options?.Port is not null) x.Port = int.Parse(testCase.Options.Port); + if (testCase?.Options?.Database is not null) x.Database = testCase.Options.Database; + if (testCase?.Options?.Branch is not null) x.Branch = testCase.Options.Branch; + if (testCase?.Options?.User is not null) x.Username = testCase.Options.User; + if (testCase?.Options?.Password is not null) x.Password = testCase.Options.Password; + if (testCase?.Options?.SecretKey is not null) x.SecretKey = testCase.Options.SecretKey; + if (testCase?.Options?.TlsCA is not null) x.TLSCertificateAuthority = testCase.Options.TlsCA; + if (testCase?.Options?.TlsCAFile is not null) + { + x.TLSCertificateAuthority = mockSystem.FileReadAllText(testCase.Options.TlsCAFile); + } + if (testCase?.Options?.TlsSecurity is not null) x.TLSSecurity = testCase.Options.TlsSecurity.Value; + if (testCase?.Options?.TlsServerName is not null) x.TLSServerName = testCase.Options.TlsServerName; + + if (credentials?.Host is not null) x.Hostname = credentials.Host; + if (credentials?.Port is not null) x.Port = int.Parse(credentials.Port); + if (credentials?.Database is not null) x.Database = credentials.Database; + if (credentials?.Branch is not null) x.Branch = credentials.Branch; + if (credentials?.User is not null) x.Username = credentials.User; + if (credentials?.Password is not null) x.Password = credentials.Password; + if (credentials?.TlsCA is not null) x.TLSCertificateAuthority = credentials.TlsCA; + if (credentials?.TlsSecurity is not null) x.TLSSecurity = credentials.TlsSecurity.Value; + }, + autoResolve: false, + platform: mockSystem); + + return connection; + } + catch (Exception x) + { + return x; + } + } + + private static EdgeDBConnection? GetExpectedConnection(TestCase testCase) + { + if (testCase.Result is null) + { + return null; + } + else + { + EdgeDBConnection connection = new(); + + if (testCase.Result.User is not null) connection.Username = testCase.Result.User; + if (testCase.Result.Password is not null) connection.Password = testCase.Result.Password; + if (testCase.Result.Address is not null) connection.Hostname = testCase.Result.Address[0]; + if (testCase.Result.Address is not null) connection.Port = int.Parse(testCase.Result.Address[1]); + if (testCase.Result.Database is not null) connection.Database = testCase.Result.Database; + if (testCase.Result.TlsCAData is not null) connection.TLSCertificateAuthority = testCase.Result.TlsCAData; + if (testCase.Result.TlsServerName is not null) connection.TLSServerName = testCase.Result.TlsServerName; + if (testCase.Result.TlsSecurity is not null) connection.TLSSecurity = testCase.Result.TlsSecurity.Value; + if (testCase.Result.WaitUntilAvailable is not null) + { + connection.Timeout = EdgeDBConnection.ParseWaitUntilAvailable(testCase.Result.WaitUntilAvailable); + } + if (testCase.Result.ServerSettings is not null) connection.ServerSettings = testCase.Result.ServerSettings; + + return connection; + } + } + + private class Result + { + public EdgeDBConnection? Connection { get; init; } + public Exception? Exception { get; init; } + + public static implicit operator Result(EdgeDBConnection c) => new() {Connection = c}; + public static implicit operator Result(Exception x) => new() {Exception = x}; + } + + class MockSystemProvider : BaseDefaultSystemProvider + { + private readonly string? _homeDir; + private readonly string? _currentDir; + private readonly Dictionary<string, string> _envVars; + private Dictionary<string, string> _files; + + public List<string> Warnings { get; } = new(); + + public MockSystemProvider(TestCase testCase) + { + _homeDir = testCase.FileSystem?.HomeDir; + _currentDir = testCase.FileSystem?.CurrentDir; + _envVars = testCase.EnvVars ?? new(); + _files = CacheFiles(testCase?.FileSystem?.Files); + } + + private Dictionary<string, string> CacheFiles(Dictionary<string, TestCase.File>? files) + { + return files?.SelectMany( + x => { + string path = x.Key; + TestCase.File file = x.Value; + + if (file.Contents is not null) + { + return new List<(string, string)>(){(path, file.Contents)}; + } + else + { + if (file.Fields is null) + { + throw new Exception("File must be either string or json object of fields"); + } + if (!file.Fields.ContainsKey("project-path")) + { + throw new Exception("File as object must have \"project-path\" field"); + } + + List<(string,string)> subfiles = new(); + + string dir = path.Replace("${HASH}", ProjectPathHash(file.Fields["project-path"])); + + foreach (KeyValuePair<string, string> field in file.Fields) + { + subfiles.Add((CombinePaths(new string[]{ dir, field.Key }), field.Value)); + } + + return subfiles; + } + }) + .ToDictionary(x => x.Item1, x => x.Item2) + ?? new(); + } + + private string ProjectPathHash(string path) + { + if (IsOSPlatform(OSPlatform.Windows) && !path.StartsWith("\\\\")) + { + path = "\\\\?\\" + path; + } + + return Convert.ToHexString(SHA1.HashData(Encoding.UTF8.GetBytes(path))); + } + + public override string GetHomeDir() => _homeDir ?? base.GetHomeDir(); + + public override string GetCurrentDirectory() => _currentDir ?? base.GetCurrentDirectory(); + + public override string? GetEnvVariable(string name) + => _envVars.TryGetValue(name, out var val) + ? val + : null; + + public override bool FileExists(string path) => _files.ContainsKey(path); + + public override string FileReadAllText(string path) => _files[path]; + + public override void WriteWarning(string message) + => Warnings.Add(message); + } + + class TestCase + { + [JsonPropertyName("name")] + public string Name { get; init; } = string.Empty; + + [JsonPropertyName("opts")] + public OptionsData? Options { get; init; } + + [JsonPropertyName("env")] + public Dictionary<string, string>? EnvVars { get; init; } + + [JsonPropertyName("platform")] + public string? Platform { get; init; } + + [JsonPropertyName("fs")] + public FileSystemData? FileSystem { get; init; } + + [JsonPropertyName("warnings")] + public List<string>? Warnings { get; init; } + + [JsonPropertyName("result")] + public ResultData? Result { get; init; } + + [JsonPropertyName("error")] + public ErrorData? Error { get; init; } + + public class OptionsData + { + [JsonPropertyName("instance")] + public string? Instance { get; init; } + + [JsonPropertyName("database")] + public string? Database { get; init; } + + [JsonPropertyName("branch")] + public string? Branch { get; init; } + + [JsonPropertyName("host")] + public string? Host { get; init; } + + [JsonPropertyName("dsn")] + public string? Dsn { get; init; } + + [JsonPropertyName("port")] + [JsonConverter(typeof(AsStringConverter))] + public string? Port { get; init; } + + [JsonPropertyName("user")] + public string? User { get; init; } + + [JsonPropertyName("password")] + public string? Password { get; init; } + + [JsonPropertyName("secretKey")] + public string? SecretKey { get; init; } + + [JsonPropertyName("credentials")] + public string? Credentials { get; init; } + + [JsonPropertyName("credentialsFile")] + public string? CredentialsFile { get; init; } + + [JsonPropertyName("tlsCA")] + public string? TlsCA { get; init; } + + [JsonPropertyName("tlsCAFile")] + public string? TlsCAFile { get; init; } + + [JsonPropertyName("tlsSecurity")] + [JsonConverter(typeof(TLSSecurityModeConverter))] + public TLSSecurityMode? TlsSecurity { get; init; } + + [JsonPropertyName("tlsServerName")] + public string? TlsServerName { get; init; } + + [JsonPropertyName("waitUntilAvailable")] + public string? WaitUntilAvailable { get; init; } + + [JsonPropertyName("serverSettings")] + public Dictionary<string, string>? ServerSettings { get; init; } + } + + public class Credentials + { + [JsonPropertyName("host")] + public string? Host { get; init; } + + [JsonPropertyName("port")] + [JsonConverter(typeof(AsStringConverter))] + public string? Port { get; init; } + + [JsonPropertyName("database")] + public string? Database { get; init; } + + [JsonPropertyName("branch")] + public string? Branch { get; init; } + + [JsonPropertyName("user")] + public string? User { get; init; } + + [JsonPropertyName("password")] + public string? Password { get; init; } + + [JsonPropertyName("tls_ca")] + public string? TlsCA { get; init; } + + [JsonPropertyName("tls_security")] + [JsonConverter(typeof(TLSSecurityModeConverter))] + public TLSSecurityMode? TlsSecurity { get; init; } + } + + public class FileSystemData + { + [JsonPropertyName("cwd")] + public string? CurrentDir { get; init; } + + [JsonPropertyName("homedir")] + public string? HomeDir { get; init; } + + [JsonPropertyName("files")] + public Dictionary<string, File>? Files { get; init; } + } + + [JsonConverter(typeof(FileJsonConverter))] + public class File + { + // Has either string contents or has explicitly defined instance information + + // string contents + public string? Contents { get; init; } + + // instance information + public Dictionary<string, string>? Fields { get; init; } + } + + public class ResultData + { + [JsonPropertyName("address")] + [JsonConverter(typeof(AsListStringConverter))] + public List<string> Address { get; init; } = new(); + + [JsonPropertyName("database")] + public string Database { get; init; } = string.Empty; + + [JsonPropertyName("branch")] + public string Branch { get; init; } = string.Empty; + + [JsonPropertyName("user")] + public string User { get; init; } = string.Empty; + + [JsonPropertyName("password")] + public string? Password { get; init; } + + [JsonPropertyName("secretKey")] + public string? SecretKey { get; init; } + + [JsonPropertyName("tlsCAData")] + public string? TlsCAData { get; init; } + + [JsonPropertyName("tlsServerName")] + public string? TlsServerName { get; init; } + + [JsonPropertyName("tlsSecurity")] + [JsonConverter(typeof(TLSSecurityModeConverter))] + public TLSSecurityMode? TlsSecurity { get; init; } + + [JsonPropertyName("waitUntilAvailable")] + public string? WaitUntilAvailable { get; init; } + + [JsonPropertyName("serverSettings")] + public Dictionary<string, string>? ServerSettings { get; init; } + } + + public class ErrorData + { + [JsonPropertyName("type")] + public string Type { get; init; } = string.Empty; + + public Exception? GetExpectedException() + { + (Type, string) error = _errorMapping[Type]; + + return (Exception?)Activator.CreateInstance(error.Item1, new object[]{error.Item2}); + } + + private static readonly Dictionary<string, (Type, string)> _errorMapping = new() + { + {"credentials_file_not_found", ( + typeof(ConfigurationException), + "cannot read credentials")}, + {"project_not_initialised", ( + typeof(ConfigurationException), + "Found `\\w+.toml` but the project is not initialized")}, + {"no_options_or_toml", ( + typeof(ConfigurationException), + "no `gel.toml` found and no connection options specified")}, + {"invalid_credentials_file", ( + typeof(ConfigurationException), + "cannot read credentials")}, + {"invalid_dsn_or_instance_name", (typeof(ConfigurationException), "invalid DSN or instance name")}, + {"invalid_instance_name", (typeof(ConfigurationException), "invalid instance name")}, + {"invalid_dsn", (typeof(ConfigurationException), "invalid DSN")}, + {"unix_socket_unsupported", (typeof(ConfigurationException), "unix socket paths not supported")}, + {"invalid_host", (typeof(ConfigurationException), "invalid host")}, + {"invalid_port", (typeof(ConfigurationException), "invalid port")}, + {"invalid_user", (typeof(ConfigurationException), "invalid user")}, + {"invalid_database", (typeof(ConfigurationException), "invalid database")}, + {"multiple_compound_env", (typeof(ConfigurationException), "Cannot have more than one of the following connection environment variables")}, + {"multiple_compound_opts", ( + typeof(ConfigurationException), + "Cannot have more than one of the following connection options")}, + {"exclusive_options", ( + typeof(ConfigurationException), + "are mutually exclusive")}, + {"env_not_found", ( + typeof(ConfigurationException), + "environment variable \".*\" doesn\"t exist")}, + {"file_not_found", ( + typeof(ConfigurationException), + "No such file or directory")}, + {"invalid_tls_security", ( + typeof(ConfigurationException), + "tls_security can only be one of `insecure`, |tls_security must be set to strict")}, + {"invalid_secret_key", ( + typeof(ConfigurationException), + "Invalid secret key")}, + {"secret_key_not_found", ( + typeof(ConfigurationException), + "Cannot connect to cloud instances without secret key")}, + {"docker_tcp_port", ( + typeof(ConfigurationException), + "EDGEDB_PORT in \"tcp://host:port\" format, so will be ignored")}, + {"gel_and_edgedb", ( + typeof(ConfigurationException), + "Both GEL_\\w+ and EDGEDB_\\w+ are set; EDGEDB_\\w+ will be ignored")} + }; + } + + private class AsStringConverter : JsonConverter<string> + { + // Always reads numbers as strings + + public override string? Read( + ref Utf8JsonReader reader, + Type typeToConvert, + JsonSerializerOptions options) + { + if (reader.TokenType == JsonTokenType.Number) + { + if (reader.TryGetInt32(out int asInt)) + { + return asInt.ToString(); + } + else + { + return reader.GetDouble().ToString(); + } + } + else if (reader.TokenType == JsonTokenType.String) + { + return reader.GetString(); + } + else + { + throw new JsonException("Expected Number or String."); + } + } + + public override void Write( + Utf8JsonWriter writer, + string value, + JsonSerializerOptions options) + { + throw new NotImplementedException(); + } + } + + private class AsListStringConverter : JsonConverter<List<string>> + { + // Always reads numbers as strings + + public override List<string>? Read( + ref Utf8JsonReader reader, + Type typeToConvert, + JsonSerializerOptions options) + { + List<string> result = new(); + + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndArray) + { + break; + } + + if (reader.TokenType == JsonTokenType.Number) + { + if (reader.TryGetInt32(out int asInt)) + { + result.Add(asInt.ToString()); + } + else + { + result.Add(reader.GetDouble().ToString()); + } + } + else if (reader.TokenType == JsonTokenType.String) + { + string? text = reader.GetString(); + if (text is null) + { + throw new JsonException(); + } + result.Add(text); + } + else + { + throw new JsonException("Expected Number or String."); + } + } + + return result; + } + + public override void Write( + Utf8JsonWriter writer, + List<string> value, + JsonSerializerOptions options) + { + throw new NotImplementedException(); + } + } + + private class TLSSecurityModeConverter : JsonConverter<TLSSecurityMode?> + { + // Always reads numbers as strings + + public override TLSSecurityMode? Read( + ref Utf8JsonReader reader, + Type typeToConvert, + JsonSerializerOptions options) + { + string? text = reader.GetString(); + if (text is not null) + { + if (!TLSSecurityModeParser.TryParse(text, true, out TLSSecurityMode? tlsSecurity)) + { + throw new FormatException($"\"{text}\" must be a value of TLSSecurityMode"); + } + return tlsSecurity ?? TLSSecurityMode.Default; + } + else + { + throw new JsonException("Expected String."); + } + } + + public override void Write( + Utf8JsonWriter writer, + TLSSecurityMode? value, + JsonSerializerOptions options) + { + throw new NotImplementedException(); + } + } + + private class FileJsonConverter : JsonConverter<File> + { + public override File? Read( + ref Utf8JsonReader reader, + Type typeToConvert, + JsonSerializerOptions options) + { + if (reader.TokenType == JsonTokenType.String) + { + // string contents + return new() + { + Contents = reader.GetString(), + }; + } + else if (reader.TokenType == JsonTokenType.StartObject) + { + // instance information + Dictionary<string, string> fields = new(); + while (reader.Read()) + { + if (reader.TokenType == JsonTokenType.EndObject) + { + break; + } + + if (reader.TokenType != JsonTokenType.PropertyName) + { + throw new JsonException(); + } + + string? propertyName = reader.GetString(); + if (propertyName is not null) + { + reader.Read(); + if (reader.TokenType != JsonTokenType.String) + { + throw new JsonException($"Expected string for property \"{propertyName}\""); + } + + string? value = reader.GetString(); + if (value == null) + { + throw new JsonException(); + } + + fields[propertyName] = value; + } + else + { + throw new JsonException(); + } + } + + return new() + { + Fields = fields, + }; + } + else + { + throw new JsonException("Could not read File object."); + } + } + + public override void Write( + Utf8JsonWriter writer, + File value, + JsonSerializerOptions options) + { + throw new NotImplementedException(); + } + } + } +}