From b2605fad7a3491831b1728c70933fe048efe6048 Mon Sep 17 00:00:00 2001 From: Mazdak Farrokhzad Date: Wed, 11 Oct 2023 15:49:34 +0200 Subject: [PATCH] flatten SATS + convenience APIs --- .../client/module_bindings/Message.cs | 6 +- .../module_bindings/SendMessageReducer.cs | 2 +- .../client/module_bindings/SetNameReducer.cs | 2 +- .../quickstart/client/module_bindings/User.cs | 6 +- src/Identity.cs | 20 +- src/SATS/AlgebraicType.cs | 111 +-- src/SATS/AlgebraicValue.cs | 632 ++++++++---------- 7 files changed, 353 insertions(+), 426 deletions(-) diff --git a/examples/quickstart/client/module_bindings/Message.cs b/examples/quickstart/client/module_bindings/Message.cs index 223f57b4..f24904a9 100644 --- a/examples/quickstart/client/module_bindings/Message.cs +++ b/examples/quickstart/client/module_bindings/Message.cs @@ -33,10 +33,10 @@ public static SpacetimeDB.SATS.AlgebraicType GetAlgebraicType() { new SpacetimeDB.SATS.ProductTypeElement("sender", SpacetimeDB.SATS.AlgebraicType.CreateProductType(new SpacetimeDB.SATS.ProductTypeElement[] { - new SpacetimeDB.SATS.ProductTypeElement("__identity_bytes", SpacetimeDB.SATS.AlgebraicType.CreateArrayType(SpacetimeDB.SATS.AlgebraicType.CreatePrimitiveType(SpacetimeDB.SATS.BuiltinType.Type.U8))), + new SpacetimeDB.SATS.ProductTypeElement("__identity_bytes", SpacetimeDB.SATS.AlgebraicType.CreateBytesType()), })), - new SpacetimeDB.SATS.ProductTypeElement("sent", SpacetimeDB.SATS.AlgebraicType.CreatePrimitiveType(SpacetimeDB.SATS.BuiltinType.Type.U64)), - new SpacetimeDB.SATS.ProductTypeElement("text", SpacetimeDB.SATS.AlgebraicType.CreatePrimitiveType(SpacetimeDB.SATS.BuiltinType.Type.String)), + new SpacetimeDB.SATS.ProductTypeElement("sent", SpacetimeDB.SATS.AlgebraicType.CreateU64Type()), + new SpacetimeDB.SATS.ProductTypeElement("text", SpacetimeDB.SATS.AlgebraicType.CreateStringType()), }); } diff --git a/examples/quickstart/client/module_bindings/SendMessageReducer.cs b/examples/quickstart/client/module_bindings/SendMessageReducer.cs index 9b37c0f9..70ee537f 100644 --- a/examples/quickstart/client/module_bindings/SendMessageReducer.cs +++ b/examples/quickstart/client/module_bindings/SendMessageReducer.cs @@ -52,7 +52,7 @@ public static void SendMessageDeserializeEventArgs(ClientApi.Event dbEvent) bsatnBytes.CopyTo(ms.GetBuffer(), 0); ms.Position = 0; using var reader = new System.IO.BinaryReader(ms); - var args_0_value = SpacetimeDB.SATS.AlgebraicValue.Deserialize(SpacetimeDB.SATS.AlgebraicType.CreatePrimitiveType(SpacetimeDB.SATS.BuiltinType.Type.String), reader); + var args_0_value = SpacetimeDB.SATS.AlgebraicValue.Deserialize(SpacetimeDB.SATS.AlgebraicType.CreateStringType(), reader); args.Text = args_0_value.AsString(); dbEvent.FunctionCall.CallInfo = new ReducerEvent(ReducerType.SendMessage, "send_message", dbEvent.Timestamp, Identity.From(dbEvent.CallerIdentity.ToByteArray()), dbEvent.Message, dbEvent.Status, args); } diff --git a/examples/quickstart/client/module_bindings/SetNameReducer.cs b/examples/quickstart/client/module_bindings/SetNameReducer.cs index c4806057..e9cbc87c 100644 --- a/examples/quickstart/client/module_bindings/SetNameReducer.cs +++ b/examples/quickstart/client/module_bindings/SetNameReducer.cs @@ -52,7 +52,7 @@ public static void SetNameDeserializeEventArgs(ClientApi.Event dbEvent) bsatnBytes.CopyTo(ms.GetBuffer(), 0); ms.Position = 0; using var reader = new System.IO.BinaryReader(ms); - var args_0_value = SpacetimeDB.SATS.AlgebraicValue.Deserialize(SpacetimeDB.SATS.AlgebraicType.CreatePrimitiveType(SpacetimeDB.SATS.BuiltinType.Type.String), reader); + var args_0_value = SpacetimeDB.SATS.AlgebraicValue.Deserialize(SpacetimeDB.SATS.AlgebraicType.CreateStringType(), reader); args.Name = args_0_value.AsString(); dbEvent.FunctionCall.CallInfo = new ReducerEvent(ReducerType.SetName, "set_name", dbEvent.Timestamp, Identity.From(dbEvent.CallerIdentity.ToByteArray()), dbEvent.Message, dbEvent.Status, args); } diff --git a/examples/quickstart/client/module_bindings/User.cs b/examples/quickstart/client/module_bindings/User.cs index b50a4502..ee64ea3d 100644 --- a/examples/quickstart/client/module_bindings/User.cs +++ b/examples/quickstart/client/module_bindings/User.cs @@ -37,16 +37,16 @@ public static SpacetimeDB.SATS.AlgebraicType GetAlgebraicType() { new SpacetimeDB.SATS.ProductTypeElement("identity", SpacetimeDB.SATS.AlgebraicType.CreateProductType(new SpacetimeDB.SATS.ProductTypeElement[] { - new SpacetimeDB.SATS.ProductTypeElement("__identity_bytes", SpacetimeDB.SATS.AlgebraicType.CreateArrayType(SpacetimeDB.SATS.AlgebraicType.CreatePrimitiveType(SpacetimeDB.SATS.BuiltinType.Type.U8))), + new SpacetimeDB.SATS.ProductTypeElement("__identity_bytes", SpacetimeDB.SATS.AlgebraicType.CreateBytesType()), })), new SpacetimeDB.SATS.ProductTypeElement("name", SpacetimeDB.SATS.AlgebraicType.CreateSumType(new System.Collections.Generic.List { - new SpacetimeDB.SATS.SumTypeVariant("some", SpacetimeDB.SATS.AlgebraicType.CreatePrimitiveType(SpacetimeDB.SATS.BuiltinType.Type.String)), + new SpacetimeDB.SATS.SumTypeVariant("some", SpacetimeDB.SATS.AlgebraicType.CreateStringType()), new SpacetimeDB.SATS.SumTypeVariant("none", SpacetimeDB.SATS.AlgebraicType.CreateProductType(new SpacetimeDB.SATS.ProductTypeElement[] { })), })), - new SpacetimeDB.SATS.ProductTypeElement("online", SpacetimeDB.SATS.AlgebraicType.CreatePrimitiveType(SpacetimeDB.SATS.BuiltinType.Type.Bool)), + new SpacetimeDB.SATS.ProductTypeElement("online", SpacetimeDB.SATS.AlgebraicType.CreateBoolType()), }); } diff --git a/src/Identity.cs b/src/Identity.cs index 0251a3f9..3f579fef 100644 --- a/src/Identity.cs +++ b/src/Identity.cs @@ -14,25 +14,7 @@ public struct Identity : IEquatable public byte[] Bytes => bytes; - public static AlgebraicType GetAlgebraicType() - { - return new AlgebraicType - { - type = AlgebraicType.Type.Builtin, - builtin = new BuiltinType - { - type = BuiltinType.Type.Array, - arrayType = new AlgebraicType - { - type = AlgebraicType.Type.Builtin, - builtin = new BuiltinType - { - type = BuiltinType.Type.U8 - } - } - } - }; - } + public static AlgebraicType GetAlgebraicType() => AlgebraicType.CreateBytesType(); public static explicit operator Identity(AlgebraicValue v) => new Identity { diff --git a/src/SATS/AlgebraicType.cs b/src/SATS/AlgebraicType.cs index d499a11d..14d74c36 100644 --- a/src/SATS/AlgebraicType.cs +++ b/src/SATS/AlgebraicType.cs @@ -10,20 +10,13 @@ public class SumType { public List variants; - public SumType() - { - variants = new List(); - } + public SumType() => variants = new List(); // TODO(jdetter): Perhaps not needed? - public SumType NewUnnamed() + public SumType NewUnnamed() => new SumType { - var s = new SumType - { - variants = variants.Select(a => new SumTypeVariant(null, a.algebraicType)).ToList() - }; - return s; - } + variants = variants.Select(a => new SumTypeVariant(null, a.algebraicType)).ToList() + }; } public struct SumTypeVariant @@ -42,10 +35,7 @@ public class ProductType { public List elements; - public ProductType() - { - elements = new List(); - } + public ProductType() => elements = new List(); } public struct ProductTypeElement @@ -72,10 +62,13 @@ public struct MapType public AlgebraicType valueType; } - public class BuiltinType + public class AlgebraicType { public enum Type { + TypeRef, + Sum, + Product, Bool, I8, U8, @@ -91,23 +84,7 @@ public enum Type F64, String, Array, - Map - } - - public Type type; - - public AlgebraicType arrayType; - public MapType mapType; - } - - public class AlgebraicType - { - public enum Type - { - Sum, - Product, - Builtin, - TypeRef, + Map, None, } @@ -121,7 +98,7 @@ public SumType sum { type = value == null ? Type.None : Type.Sum; } } - + public ProductType product { get { return type == Type.Product ? (ProductType)type_ : null; } set { @@ -129,12 +106,20 @@ public ProductType product { type = value == null ? Type.None : Type.Product; } } - - public BuiltinType builtin { - get { return type == Type.Builtin ? (BuiltinType)type_ : null; } + + public AlgebraicType array { + get { return type == Type.Array ? (AlgebraicType)type_ : null; } set { type_ = value; - type = value == null ? Type.None : Type.Builtin; + type = value == null ? Type.None : Type.Array; + } + } + + public MapType map { + get { return type == Type.Map ? (MapType)type_ : null; } + set { + type_ = value; + type = value == null ? Type.None : Type.Map; } } @@ -151,19 +136,18 @@ public static AlgebraicType CreateProductType(IEnumerable el return new AlgebraicType { type = Type.Product, - product = new ProductType - { + type_ = new ProductType { elements = elements.ToList() } }; } - + public static AlgebraicType CreateSumType(IEnumerable variants) { return new AlgebraicType { type = Type.Sum, - sum = new SumType + type_ = new SumType { variants = variants.ToList(), } @@ -173,25 +157,42 @@ public static AlgebraicType CreateSumType(IEnumerable variants) public static AlgebraicType CreateArrayType(AlgebraicType elementType) { return new AlgebraicType { - type = Type.Builtin, - builtin = new BuiltinType - { - type = BuiltinType.Type.Array, - arrayType = elementType - } + type = Type.Array, + type_ = elementType }; } - public static AlgebraicType CreatePrimitiveType(BuiltinType.Type type) { + public static AlgebraicType CreateBytesType() => AlgebraicType.CreateArrayType(AlgebraicType.CreateU8Type()); + + public static AlgebraicType CreateMapType(MapType type) { return new AlgebraicType { - type = Type.Builtin, - builtin = new BuiltinType - { - type = type, - } + type = Type.Map, + type_ = type + }; + } + + public static AlgebraicType CreateTypeRef(int idx) { + return new AlgebraicType + { + type = Type.TypeRef, + type_ = idx }; } + public static AlgebraicType CreateBoolType() => new AlgebraicType { type = Type.Bool }; + public static AlgebraicType CreateI8Type() => new AlgebraicType { type = Type.I8 }; + public static AlgebraicType CreateU8Type() => new AlgebraicType { type = Type.U8 }; + public static AlgebraicType CreateI16Type() => new AlgebraicType { type = Type.I16 }; + public static AlgebraicType CreateU16Type() => new AlgebraicType { type = Type.U16 }; + public static AlgebraicType CreateI32Type() => new AlgebraicType { type = Type.I32 }; + public static AlgebraicType CreateU32Type() => new AlgebraicType { type = Type.U32 }; + public static AlgebraicType CreateI64Type() => new AlgebraicType { type = Type.I64 }; + public static AlgebraicType CreateU64Type() => new AlgebraicType { type = Type.U64 }; + public static AlgebraicType CreateI128Type() => new AlgebraicType { type = Type.I128 }; + public static AlgebraicType CreateU128Type() => new AlgebraicType { type = Type.U128 }; + public static AlgebraicType CreateF32Type() => new AlgebraicType { type = Type.F32 }; + public static AlgebraicType CreateF64Type() => new AlgebraicType { type = Type.F64 }; + public static AlgebraicType CreateStringType() => new AlgebraicType { type = Type.String }; } -} \ No newline at end of file +} diff --git a/src/SATS/AlgebraicValue.cs b/src/SATS/AlgebraicValue.cs index 9d78de35..fe82ca14 100644 --- a/src/SATS/AlgebraicValue.cs +++ b/src/SATS/AlgebraicValue.cs @@ -5,291 +5,6 @@ namespace SpacetimeDB.SATS { - public struct BuiltinValue - { - private object value; - - public bool AsBool() => (bool)value; - public sbyte AsI8() => (sbyte)value; - public byte AsU8() => (byte)value; - public short AsI16() => (short)value; - public ushort AsU16() => (ushort)value; - public int AsI32() => (int)value; - public uint AsU32() => (uint)value; - public long AsI64() => (long)value; - public ulong AsU64() => (ulong)value; - public byte[] AsI128() => (byte[])value; - public byte[] AsU128() => (byte[])value; - public float AsF32() => (float)value; - public double AsF64() => (double)value; - public byte[] AsBytes() => (byte[])value; - public string AsString() => (string)value; - public List AsArray() => (List)value; - public Dictionary AsMap() => (Dictionary)value; - - public static BuiltinValue CreateBool(bool value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateI8(sbyte value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateU8(byte value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateI16(short value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateU16(ushort value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateI32(int value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateU32(uint value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateI64(long value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateU64(ulong value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateI128(byte[] value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateU128(byte[] value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateF32(float value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateF64(double value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateString(string value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateBytes(byte[] value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateArray(List value) => new BuiltinValue { value = value }; - public static BuiltinValue CreateMap(Dictionary value) => new BuiltinValue { value = value }; - - public void Serialize(BuiltinType type, BinaryWriter writer) - { - void WriteByteBuffer(byte[] buf) - { - if (buf.LongLength > uint.MaxValue) - { - throw new Exception("Serializing a buffer that is too large for SATS."); - } - - writer.Write((uint)buf.LongLength); - writer.Write(buf); - } - - switch (type.type) - { - case BuiltinType.Type.Bool: - writer.Write(AsBool()); - break; - case BuiltinType.Type.I8: - writer.Write(AsI8()); - break; - case BuiltinType.Type.U8: - writer.Write(AsU8()); - break; - case BuiltinType.Type.I16: - writer.Write(AsI16()); - break; - case BuiltinType.Type.U16: - writer.Write(AsU16()); - break; - case BuiltinType.Type.I32: - writer.Write(AsI32()); - break; - case BuiltinType.Type.U32: - writer.Write(AsU32()); - break; - case BuiltinType.Type.I64: - writer.Write(AsI64()); - break; - case BuiltinType.Type.U64: - writer.Write(AsU64()); - break; - case BuiltinType.Type.I128: - writer.Write(AsI128()); - break; - case BuiltinType.Type.U128: - writer.Write(AsU128()); - break; - case BuiltinType.Type.F32: - writer.Write(AsF32()); - break; - case BuiltinType.Type.F64: - writer.Write(AsF64()); - break; - case BuiltinType.Type.String: - WriteByteBuffer(System.Text.Encoding.UTF8.GetBytes((string)value)); - break; - case BuiltinType.Type.Array: - if (type.arrayType.type == AlgebraicType.Type.Builtin && - type.arrayType.builtin.type == BuiltinType.Type.U8) - { - WriteByteBuffer(AsBytes()); - break; - } - - var array = AsArray(); - writer.Write(array.Count); - foreach (var entry in array) - { - entry.Serialize(type.arrayType, writer); - } - break; - case BuiltinType.Type.Map: - throw new NotImplementedException(); - } - } - - public static BuiltinValue Deserialize(BuiltinType type, BinaryReader reader) - { - byte[] ReadByteArray() - { - var len = reader.ReadUInt32(); - if (len > int.MaxValue) - { - var arrays = new List(); - long read = 0; - while (read < len) - { - var remaining = len - read; - var readResult = reader.ReadBytes(remaining > int.MaxValue ? int.MaxValue : (int)remaining); - arrays.Add(readResult); - read += readResult.Length; - } - - var result = new byte[len]; - long pos = 0; - foreach (var array in arrays) - { - Array.Copy(array, 0, result, pos, array.LongLength); - pos += array.LongLength; - } - - return result; - } - - return reader.ReadBytes((int)len); - } - - switch (type.type) - { - case BuiltinType.Type.Bool: - return CreateBool(reader.ReadByte() != 0); - case BuiltinType.Type.I8: - return CreateI8(reader.ReadSByte()); - case BuiltinType.Type.U8: - return CreateU8(reader.ReadByte()); - case BuiltinType.Type.I16: - return CreateI16(reader.ReadInt16()); - case BuiltinType.Type.U16: - return CreateU16(reader.ReadUInt16()); - case BuiltinType.Type.I32: - return CreateI32(reader.ReadInt32()); - case BuiltinType.Type.U32: - return CreateU32(reader.ReadUInt32()); - case BuiltinType.Type.I64: - return CreateI64(reader.ReadInt64()); - case BuiltinType.Type.U64: - return CreateU64(reader.ReadUInt64()); - case BuiltinType.Type.I128: - return CreateI128(reader.ReadBytes(16)); - case BuiltinType.Type.U128: - return CreateU128(reader.ReadBytes(16)); - case BuiltinType.Type.F32: - return CreateF32(reader.ReadSingle()); - case BuiltinType.Type.F64: - return CreateF64(reader.ReadDouble()); - case BuiltinType.Type.String: - return CreateString(System.Text.Encoding.UTF8.GetString(ReadByteArray())); - case BuiltinType.Type.Array: - if (type.arrayType.type == AlgebraicType.Type.Builtin && - type.arrayType.builtin.type == BuiltinType.Type.U8) - { - return CreateBytes(ReadByteArray()); - } - - var length = reader.ReadInt32(); - var arrayResult = new List(); - for (var x = 0; x < length; x++) - { - arrayResult.Add(AlgebraicValue.Deserialize(type.arrayType, reader)); - } - - return CreateArray(arrayResult); - case BuiltinType.Type.Map: - { - var len = reader.ReadUInt32(); - var mapResult = new Dictionary(); - for (var x = 0; x < len; x++) - { - var key = AlgebraicValue.Deserialize(type.mapType.keyType, reader); - var value = AlgebraicValue.Deserialize(type.mapType.valueType, reader); - mapResult.Add(key, value); - } - - return CreateMap(mapResult); - } - default: - throw new NotImplementedException(); - } - } - - public static bool Compare(BuiltinType t, BuiltinValue v1, BuiltinValue v2) - { - switch (t.type) - { - case BuiltinType.Type.Bool: - return v1.AsBool() == v2.AsBool(); - case BuiltinType.Type.U8: - return v1.AsU8() == v2.AsU8(); - case BuiltinType.Type.I8: - return v1.AsI8() == v2.AsI8(); - case BuiltinType.Type.U16: - return v1.AsU16() == v2.AsU16(); - case BuiltinType.Type.I16: - return v1.AsI16() == v2.AsI16(); - case BuiltinType.Type.U32: - return v1.AsU32() == v2.AsU32(); - case BuiltinType.Type.I32: - return v1.AsI32() == v2.AsI32(); - case BuiltinType.Type.U64: - return v1.AsU64() == v2.AsU64(); - case BuiltinType.Type.I64: - return v1.AsI64() == v2.AsI64(); - case BuiltinType.Type.U128: - case BuiltinType.Type.I128: - case BuiltinType.Type.F32: - case BuiltinType.Type.F64: - case BuiltinType.Type.Map: - throw new NotImplementedException(); - case BuiltinType.Type.String: - return v1.AsString() == v2.AsString(); - case BuiltinType.Type.Array: - if (t.arrayType.type == AlgebraicType.Type.Builtin && - t.arrayType.builtin.type == BuiltinType.Type.U8) - { - var arr1 = v1.AsBytes(); - var arr2 = v2.AsBytes(); - - if (arr1.Length != arr2.Length) - { - return false; - } - - for (var i = 0; i < arr1.Length; i++) - { - if (arr1[i] != arr2[i]) - { - return false; - } - } - - return true; - } - - var list1 = v1.AsArray(); - var list2 = v2.AsArray(); - if (list1.Count != list2.Count) - { - return false; - } - - for (var i = 0; i < list1.Count; i++) - { - if (!AlgebraicValue.Compare(t.arrayType, list1[i], list2[i])) - { - return false; - } - } - return true; - default: - throw new NotImplementedException(); - } - } - } - public class SumValue { public byte tag; @@ -380,63 +95,135 @@ public static bool Compare(ProductType type, ProductValue v1, ProductValue v2) public class AlgebraicValue { - public SumValue sum; - public ProductValue product; - public BuiltinValue builtin; - - public bool AsBool() => builtin.AsBool(); - public sbyte AsI8() => builtin.AsI8(); - public byte AsU8() => builtin.AsU8(); - public short AsI16() => builtin.AsI16(); - public ushort AsU16() => builtin.AsU16(); - public int AsI32() => builtin.AsI32(); - public uint AsU32() => builtin.AsU32(); - public long AsI64() => builtin.AsI64(); - public ulong AsU64() => builtin.AsU64(); - public byte[] AsI128() => builtin.AsI128(); - public byte[] AsU128() => builtin.AsU128(); - public float AsF32() => builtin.AsF32(); - public double AsF64() => builtin.AsF64(); - public string AsString() => builtin.AsString(); - public byte[] AsBytes() => builtin.AsBytes(); - public List AsArray() => builtin.AsArray(); - public Dictionary AsMap() => builtin.AsMap(); - public static AlgebraicValue CreateBool(bool v) => new AlgebraicValue { builtin = BuiltinValue.CreateBool(v) }; - public static AlgebraicValue CreateI8(sbyte v) => new AlgebraicValue { builtin = BuiltinValue.CreateI8(v) }; - public static AlgebraicValue CreateU8(byte v) => new AlgebraicValue { builtin = BuiltinValue.CreateU8(v) }; - public static AlgebraicValue CreateI16(short v) => new AlgebraicValue { builtin = BuiltinValue.CreateI16(v) }; - public static AlgebraicValue CreateU16(ushort v) => new AlgebraicValue { builtin = BuiltinValue.CreateU16(v) }; - public static AlgebraicValue CreateI32(int v) => new AlgebraicValue { builtin = BuiltinValue.CreateI32(v) }; - public static AlgebraicValue CreateU32(uint v) => new AlgebraicValue { builtin = BuiltinValue.CreateU32(v) }; - public static AlgebraicValue CreateI64(long v) => new AlgebraicValue { builtin = BuiltinValue.CreateI64(v) }; - public static AlgebraicValue CreateU64(ulong v) => new AlgebraicValue { builtin = BuiltinValue.CreateU64(v) }; - public static AlgebraicValue CreateI128(byte[] v) => new AlgebraicValue { builtin = BuiltinValue.CreateI128(v) }; - public static AlgebraicValue CreateU128(byte[] v) => new AlgebraicValue { builtin = BuiltinValue.CreateU128(v) }; - public static AlgebraicValue CreateF32(float v) => new AlgebraicValue { builtin = BuiltinValue.CreateF32(v) }; - public static AlgebraicValue CreateF64(double v) => new AlgebraicValue { builtin = BuiltinValue.CreateF64(v) }; - public static AlgebraicValue CreateString(string v) => new AlgebraicValue { builtin = BuiltinValue.CreateString(v) }; - public static AlgebraicValue CreateBytes(byte[] v) => new AlgebraicValue { builtin = BuiltinValue.CreateBytes(v) }; - public static AlgebraicValue CreateArray(List v) => new AlgebraicValue { builtin = BuiltinValue.CreateArray(v) }; - public static AlgebraicValue CreateMap(Dictionary v) => new AlgebraicValue { builtin = BuiltinValue.CreateMap(v) }; - - public BuiltinValue AsBuiltInValue() => builtin; - public ProductValue AsProductValue() => product; - public SumValue AsSumValue() => sum; - - public static AlgebraicValue Create(BuiltinValue value) => new AlgebraicValue { builtin = value }; - public static AlgebraicValue Create(ProductValue value) => new AlgebraicValue { product = value }; - public static AlgebraicValue Create(SumValue value) => new AlgebraicValue { sum = value }; + private object value; + + public SumValue AsSumValue() => (SumValue)value; + public ProductValue AsProductValue() => (ProductValue)value; + public List AsArray() => (List)value; + public SortedDictionary AsMap() => (SortedDictionary)value; + public bool AsBool() => (bool)value; + public sbyte AsI8() => (sbyte)value; + public byte AsU8() => (byte)value; + public short AsI16() => (short)value; + public ushort AsU16() => (ushort)value; + public int AsI32() => (int)value; + public uint AsU32() => (uint)value; + public long AsI64() => (long)value; + public ulong AsU64() => (ulong)value; + public byte[] AsI128() => (byte[])value; + public byte[] AsU128() => (byte[])value; + public float AsF32() => (float)value; + public double AsF64() => (double)value; + public string AsString() => (string)value; + public byte[] AsBytes() => (byte[])value; + + public static AlgebraicValue CreateProduct(ProductValue value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateSum(SumValue value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateArray(List value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateMap(SortedDictionary value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateBool(bool value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateI8(sbyte value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateU8(byte value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateI16(short value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateU16(ushort value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateI32(int value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateU32(uint value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateI64(long value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateU64(ulong value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateI128(byte[] value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateU128(byte[] value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateF32(float value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateF64(double value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateString(string value) => new AlgebraicValue { value = value }; + public static AlgebraicValue CreateBytes(byte[] value) => new AlgebraicValue { value = value }; + + private static bool compareBytes(byte[] arr1, byte[] arr2) + { + if (arr1.Length != arr2.Length) + { + return false; + } + + for (var i = 0; i < arr1.Length; i++) + { + if (arr1[i] != arr2[i]) + { + return false; + } + } + + return true; + } public static bool Compare(AlgebraicType t, AlgebraicValue v1, AlgebraicValue v2) { switch (t.type) { - case AlgebraicType.Type.Builtin: - return BuiltinValue.Compare(t.builtin, v1.builtin, v2.builtin); case AlgebraicType.Type.Sum: - return SumValue.Compare(t.sum, v1.sum, v2.sum); + return SumValue.Compare(t.sum, v1.AsSumValue(), v2.AsSumValue()); case AlgebraicType.Type.Product: - return ProductValue.Compare(t.product, v1.product, v2.product); + return ProductValue.Compare(t.product, v1.AsProductValue(), v2.AsProductValue()); + case AlgebraicType.Type.Array: + // Fast path for byte arrays. + if (t.arrayType.type == AlgebraicType.Type.U8) + { + return AlgebraicValue.compareBytes(v1.AsBytes(), v2.AsBytes()); + } + + var list1 = v1.AsArray(); + var list2 = v2.AsArray(); + if (list1.Count != list2.Count) + { + return false; + } + + for (var i = 0; i < list1.Count; i++) + { + if (!AlgebraicValue.Compare(t.arrayType, list1[i], list2[i])) + { + return false; + } + } + return true; + case AlgebraicType.Type.Map: + var dict1 = v1.AsMap(); + var dict2 = v2.AsMap(); + // First a fast length check and then ensure that + // for every key in the first dict there's a key in the second + // where their values match. + return dict1.Count == dict2.Count && dict1.All( + (dict1KV) => d2.TryGetValue(dict1KV.Key, out var dict2Value) && + AlgebraicValue.Compare(t.valueType, dict1KV.Value, dict2Value) + ); + case AlgebraicType.Type.Bool: + return v1.AsBool() == v2.AsBool(); + case AlgebraicType.Type.U8: + return v1.AsU8() == v2.AsU8(); + case AlgebraicType.Type.I8: + return v1.AsI8() == v2.AsI8(); + case AlgebraicType.Type.U16: + return v1.AsU16() == v2.AsU16(); + case AlgebraicType.Type.I16: + return v1.AsI16() == v2.AsI16(); + case AlgebraicType.Type.U32: + return v1.AsU32() == v2.AsU32(); + case AlgebraicType.Type.I32: + return v1.AsI32() == v2.AsI32(); + case AlgebraicType.Type.U64: + return v1.AsU64() == v2.AsU64(); + case AlgebraicType.Type.I64: + return v1.AsI64() == v2.AsI64(); + case AlgebraicType.Type.U128: + return AlgebraicValue.compareBytes(v1.AsU128(), v2.AsU128()); + case AlgebraicType.Type.I128: + return AlgebraicValue.compareBytes(v1.AsI128(), v2.AsI128()); + // For floats, match the semantics of Rust in not accounting for epsilon. + case AlgebraicType.Type.F32: + return v1.AsF32() == v2.AsF32(); + case AlgebraicType.Type.F64: + return v1.AsF64() == v2.AsF64(); + case AlgebraicType.Type.String: + return v1.AsString() == v2.AsString(); case AlgebraicType.Type.TypeRef: case AlgebraicType.Type.None: default: @@ -448,33 +235,190 @@ public static AlgebraicValue Deserialize(AlgebraicType type, BinaryReader reader { switch (type.type) { - case AlgebraicType.Type.Builtin: - return Create(BuiltinValue.Deserialize(type.builtin, reader)); - case AlgebraicType.Type.Product: - return Create(ProductValue.Deserialize(type.product, reader)); case AlgebraicType.Type.Sum: return Create(SumValue.Deserialize(type.sum, reader)); + case AlgebraicType.Type.Product: + return Create(ProductValue.Deserialize(type.product, reader)); + case AlgebraicType.Type.Array: + if (type.arrayType.type == AlgebraicType.Type.U8) + { + return CreateBytes(ReadByteArray()); + } + + var length = reader.ReadInt32(); + var arrayResult = new List(); + for (var x = 0; x < length; x++) + { + arrayResult.Add(AlgebraicValue.Deserialize(type.arrayType, reader)); + } + + return CreateArray(arrayResult); + case AlgebraicType.Type.Map: + { + var len = reader.ReadUInt32(); + var mapResult = new SortedDictionary(); + for (var x = 0; x < len; x++) + { + var key = AlgebraicValue.Deserialize(type.mapType.keyType, reader); + var value = AlgebraicValue.Deserialize(type.mapType.valueType, reader); + mapResult.Add(key, value); + } + + return CreateMap(mapResult); + } + case AlgebraicType.Type.Bool: + return CreateBool(reader.ReadByte() != 0); + case AlgebraicType.Type.I8: + return CreateI8(reader.ReadSByte()); + case AlgebraicType.Type.U8: + return CreateU8(reader.ReadByte()); + case AlgebraicType.Type.I16: + return CreateI16(reader.ReadInt16()); + case AlgebraicType.Type.U16: + return CreateU16(reader.ReadUInt16()); + case AlgebraicType.Type.I32: + return CreateI32(reader.ReadInt32()); + case AlgebraicType.Type.U32: + return CreateU32(reader.ReadUInt32()); + case AlgebraicType.Type.I64: + return CreateI64(reader.ReadInt64()); + case AlgebraicType.Type.U64: + return CreateU64(reader.ReadUInt64()); + case AlgebraicType.Type.I128: + return CreateI128(reader.ReadBytes(16)); + case AlgebraicType.Type.U128: + return CreateU128(reader.ReadBytes(16)); + case AlgebraicType.Type.F32: + return CreateF32(reader.ReadSingle()); + case AlgebraicType.Type.F64: + return CreateF64(reader.ReadDouble()); + case AlgebraicType.Type.String: + return CreateString(System.Text.Encoding.UTF8.GetString(ReadByteArray())); default: throw new NotImplementedException(); } + + byte[] ReadByteArray() + { + var len = reader.ReadUInt32(); + if (len > int.MaxValue) + { + var arrays = new List(); + long read = 0; + while (read < len) + { + var remaining = len - read; + var readResult = reader.ReadBytes(remaining > int.MaxValue ? int.MaxValue : (int)remaining); + arrays.Add(readResult); + read += readResult.Length; + } + + var result = new byte[len]; + long pos = 0; + foreach (var array in arrays) + { + Array.Copy(array, 0, result, pos, array.LongLength); + pos += array.LongLength; + } + + return result; + } + + return reader.ReadBytes((int)len); + } } public void Serialize(AlgebraicType type, BinaryWriter writer) { switch (type.type) { - case AlgebraicType.Type.Builtin: - builtin.Serialize(type.builtin, writer); + case AlgebraicType.Type.Sum: + AsSumValue().Serialize(type.sum, writer); break; case AlgebraicType.Type.Product: - product.Serialize(type.product, writer); + AsProductValue().Serialize(type.product, writer); break; - case AlgebraicType.Type.Sum: - sum.Serialize(type.sum, writer); + case AlgebraicType.Type.Array: + if (type.arrayType.type == AlgebraicType.Type.U8) + { + WriteByteBuffer(AsBytes()); + break; + } + + var array = AsArray(); + writer.Write(array.Count); + foreach (var entry in array) + { + entry.Serialize(type.arrayType, writer); + } + break; + case AlgebraicType.Type.Map: + // The map is sorted by key, just like `BTreeMap` in Rust + // so we can serialize deterministically. + var map = AsMap(); + writer.Write(map.Count); + foreach( KeyValuePair kv in map ) + { + kv.Key.Serialize(type.keyType, writer); + kv.Value.Serialize(type.valueType, writer); + } + break; + case AlgebraicType.Type.Bool: + writer.Write(AsBool()); + break; + case AlgebraicType.Type.I8: + writer.Write(AsI8()); + break; + case AlgebraicType.Type.U8: + writer.Write(AsU8()); + break; + case AlgebraicType.Type.I16: + writer.Write(AsI16()); + break; + case AlgebraicType.Type.U16: + writer.Write(AsU16()); + break; + case AlgebraicType.Type.I32: + writer.Write(AsI32()); + break; + case AlgebraicType.Type.U32: + writer.Write(AsU32()); + break; + case AlgebraicType.Type.I64: + writer.Write(AsI64()); + break; + case AlgebraicType.Type.U64: + writer.Write(AsU64()); + break; + case AlgebraicType.Type.I128: + writer.Write(AsI128()); + break; + case AlgebraicType.Type.U128: + writer.Write(AsU128()); + break; + case AlgebraicType.Type.F32: + writer.Write(AsF32()); + break; + case AlgebraicType.Type.F64: + writer.Write(AsF64()); + break; + case AlgebraicType.Type.String: + WriteByteBuffer(System.Text.Encoding.UTF8.GetBytes((string)value)); break; default: throw new NotImplementedException(); } + + void WriteByteBuffer(byte[] buf) + { + if (buf.LongLength > uint.MaxValue) + { + throw new Exception("Serializing a buffer that is too large for SATS."); + } + + writer.Write((uint)buf.LongLength); + writer.Write(buf); + } } public class AlgebraicValueComparer : IEqualityComparer