Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Roslyn nullable annotations #537

Merged
merged 5 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions src/Parquet.Test/DataAnalysis/DataFrameReaderTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@ namespace Parquet.Test.DataAnalysis {
public class DataFrameReaderTest : TestBase {

[Theory]
[InlineData(typeof(short), (short)1, (short)2)]
[InlineData(typeof(short?), null, (short)2)]
[InlineData(typeof(int), 1, 2)]
[InlineData(typeof(int?), null, 2)]
[InlineData(typeof(bool), true, false)]
[InlineData(typeof(bool?), true, null)]
[InlineData(typeof(long), 1L, 2L)]
[InlineData(typeof(long?), 1L, 2L)]
[InlineData(typeof(ulong), 1UL, 2UL)]
[InlineData(typeof(ulong?), 1UL, 2UL)]
[InlineData(typeof(string), "1", "2")]
[InlineData(typeof(string), null, "2")]
public async Task Roundtrip_all_types(Type t, object? el1, object? el2) {
[InlineData(typeof(short), false, (short)1, (short)2)]
[InlineData(typeof(short?), false, null, (short)2)]
[InlineData(typeof(int), false, 1, 2)]
[InlineData(typeof(int?), false, null, 2)]
[InlineData(typeof(bool), false, true, false)]
[InlineData(typeof(bool?), false, true, null)]
[InlineData(typeof(long), false, 1L, 2L)]
[InlineData(typeof(long?), false, 1L, 2L)]
[InlineData(typeof(ulong), false, 1UL, 2UL)]
[InlineData(typeof(ulong?), false, 1UL, 2UL)]
[InlineData(typeof(string), false, "1", "2")]
[InlineData(typeof(string), true, null, "2")]
public async Task Roundtrip_all_types(Type t, bool makeNullable, object? el1, object? el2) {

// arrange
using var ms = new MemoryStream();
Expand All @@ -34,7 +34,7 @@ public async Task Roundtrip_all_types(Type t, object? el1, object? el2) {


// make schema
var schema = new ParquetSchema(new DataField(t.Name, t));
var schema = new ParquetSchema(new DataField(t.Name, t, isNullable: makeNullable?true:null, isCompiledWithNullable: true));

// make data
using(ParquetWriter writer = await ParquetWriter.CreateAsync(schema, ms)) {
Expand Down
2 changes: 1 addition & 1 deletion src/Parquet.Test/DictionaryEncodingTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public async Task DictionaryEncodingTest2() {
"zzz",
};

var dataField = new DataField<string>("string");
var dataField = new DataField<string>("string", true);
var parquetSchema = new ParquetSchema(dataField);

using var stream = new MemoryStream();
Expand Down
4 changes: 2 additions & 2 deletions src/Parquet.Test/Rows/RowsModelTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ public async Task Struct_write_read_with_null_entry() {
new ParquetSchema(
new DataField<string>("isbn"),
new StructField("author",
new DataField<string>("firstName"),
new DataField<string>("lastName"))));
new DataField<string>("firstName", true),
new DataField<string>("lastName", true))));
var ms = new MemoryStream();

table.Add("12345-6", new Row("Hazel", "Nut"));
Expand Down
21 changes: 21 additions & 0 deletions src/Parquet.Test/Serialisation/SchemaReflectorTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,27 @@ public void Strings_OptionalAndRequired() {
Assert.True(s.DataFields[2].IsNullable);
}


// Test fixture used by Strings_NotNullable: a POCO whose three members are all
// declared without a nullable annotation (two value types and one plain `string`).
// NOTE(review): presumably this file is compiled with nullable reference types
// enabled, so that `string Value` is recorded as non-nullable - confirm project settings.
class PocoClassNotNullable {
// 64-bit integer value-type member.
[JsonPropertyName("id")]
public long Id { get; set; }

// Reference-type member declared as non-nullable `string`; the null-forgiving
// `null!` initializer satisfies the declaration without supplying a real value.
[JsonPropertyName("value")]
public string Value { get; set; } = null!;

// Double-precision value-type member.
[JsonPropertyName("frequency")]
public double Frequency { get; set; }
}

/// <summary>
/// Builds the schema for <see cref="PocoClassNotNullable"/> and checks that none of
/// its first three data fields is reported as nullable.
/// </summary>
[Fact]
public void Strings_NotNullable() {
    ParquetSchema schema = typeof(PocoClassNotNullable).GetParquetSchema(true);

    // fields are checked in index order: 0, 1, 2
    for(int i = 0; i < 3; i++) {
        Assert.False(schema.DataFields[i].IsNullable);
    }
}


public interface IInterface {
int Id { get; set; }
}
Expand Down
6 changes: 4 additions & 2 deletions src/Parquet/Encodings/ParquetPlainEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1072,8 +1072,10 @@ public static void Encode(ReadOnlySpan<string> data, Stream destination) {
else
Array.Copy(BitConverter.GetBytes(len), 0, rb, rbOffset, sizeof(int));
rbOffset += sizeof(int);
E.GetBytes(s, 0, s.Length, rb, rbOffset);
rbOffset += len;
if(len > 0) {
E.GetBytes(s, 0, s.Length, rb, rbOffset);
rbOffset += len;
}
}

if(rbOffset > 0)
Expand Down
13 changes: 13 additions & 0 deletions src/Parquet/Extensions/TypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ public static bool IsNullable(this Type t) {
(ti.IsGenericType && ti.GetGenericTypeDefinition() == typeof(Nullable<>));
}

/// <summary>
/// Tells whether <paramref name="t"/> is a generic type whose generic definition is
/// <see cref="Nullable{T}"/>. Non-generic types and generic types built from any
/// other definition yield <c>false</c>.
/// </summary>
/// <param name="t">Type to inspect.</param>
/// <returns><c>true</c> only for <see cref="Nullable{T}"/>-based types.</returns>
public static bool IsNullableStrict(this Type t) {
    TypeInfo info = t.GetTypeInfo();

    if(!info.IsGenericType) {
        return false;
    }

    return info.GetGenericTypeDefinition() == typeof(Nullable<>);
}

public static bool IsSystemNullable(this Type t) {
TypeInfo ti = t.GetTypeInfo();

Expand All @@ -128,6 +135,12 @@ public static Type GetNonNullable(this Type t) {
return ti.GenericTypeArguments[0];
}

/// <summary>
/// Reports whether <paramref name="t"/> is something other than a class, i.e. the
/// negation of <see cref="TypeInfo.IsClass"/> (value types and interfaces return
/// <c>true</c>; classes, including strings and arrays, return <c>false</c>).
/// </summary>
/// <param name="t">Type to inspect.</param>
/// <returns><c>true</c> when <paramref name="t"/> is not a class.</returns>
public static bool CanNullifyType(this Type t) => t.GetTypeInfo().IsClass == false;

public static Type GetNullable(this Type t) {
TypeInfo ti = t.GetTypeInfo();

Expand Down
4 changes: 2 additions & 2 deletions src/Parquet/Parquet.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@
<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="System.Reflection.Emit.Lightweight" Version="4.7.0" />
<PackageReference Include="System.Threading.Tasks.Extensions" Version="4.5.4" />
<PackageReference Include="System.Text.Json" Version="8.0.3" />
<PackageReference Include="System.Text.Json" Version="8.0.4" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.1'">
<PackageReference Include="System.Text.Json" Version="8.0.3" />
<PackageReference Include="System.Text.Json" Version="8.0.4" />
</ItemGroup>


Expand Down
2 changes: 1 addition & 1 deletion src/Parquet/ParquetExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ public static async Task<DataFrame> ReadParquetAsDataFrameAsync(
public static async Task WriteAsync(this DataFrame df, Stream outputStream, CancellationToken cancellationToken = default) {
// create schema
var schema = new ParquetSchema(
df.Columns.Select(col => new DataField(col.Name, col.DataType.GetNullable())));
df.Columns.Select(col => new DataField(col.Name, col.DataType.GetNullable(), isNullable: col.DataType.CanNullifyType() ? null : col.NullCount > 0)));

using ParquetWriter writer = await ParquetWriter.CreateAsync(schema, outputStream, cancellationToken: cancellationToken);
using ParquetRowGroupWriter rgw = writer.CreateRowGroup();
Expand Down
9 changes: 5 additions & 4 deletions src/Parquet/Schema/DataField.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,14 @@ public bool IsArray {
/// </summary>
/// <param name="name">Field name</param>
/// <param name="clrType">CLR type of this field. The type is internally discovered and expanded into appropriate Parquet flags.</param>
/// <param name="isCompiledWithNullable">Indicates if the source type was compiled with nullable enabled or not.</param>
/// <param name="isNullable">When set, will override <see cref="IsNullable"/> attribute regardless whether passed type was nullable or not.</param>
/// <param name="isArray">When set, will override <see cref="IsArray"/> attribute regardless whether passed type was an array or not.</param>
/// <param name="propertyName">When set, uses this property to get the field's data. When not set, uses the property that matches the name parameter.</param>
public DataField(string name, Type clrType, bool? isNullable = null, bool? isArray = null, string? propertyName = null)
public DataField(string name, Type clrType, bool? isNullable = null, bool? isArray = null, string? propertyName = null, bool? isCompiledWithNullable = null)
: base(name, SchemaType.Data) {

Discover(clrType, out Type baseType, out bool discIsArray, out bool discIsNullable);
Discover(clrType, isCompiledWithNullable ?? true, out Type baseType, out bool discIsArray, out bool discIsNullable);
ClrType = baseType;
if(!SchemaEncoder.IsSupported(ClrType)) {
if(baseType == typeof(DateTimeOffset)) {
Expand Down Expand Up @@ -167,7 +168,7 @@ public override bool Equals(object? obj) {

#region [ Type Resolution ]

private static void Discover(Type t, out Type baseType, out bool isArray, out bool isNullable) {
private static void Discover(Type t, bool isCompiledWithNullable, out Type baseType, out bool isArray, out bool isNullable) {
baseType = t;
isArray = false;
isNullable = false;
Expand All @@ -182,7 +183,7 @@ private static void Discover(Type t, out Type baseType, out bool isArray, out bo
isArray = true;
}

if(baseType.IsNullable()) {
if (baseType.IsNullable()) {
baseType = baseType.GetNonNullable();
isNullable = true;
}
Expand Down
103 changes: 83 additions & 20 deletions src/Parquet/Serialization/TypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Parquet.Encodings;
using Parquet.Schema;
using Parquet.Serialization.Attributes;
using Parquet.Utils;

namespace Parquet.Serialization {

Expand Down Expand Up @@ -39,8 +40,8 @@ public string ColumnName {

public abstract Type MemberType { get; }

public int? Order {
get{
public int? Order {
get {
#if NETCOREAPP3_1
return null;
#else
Expand All @@ -66,6 +67,57 @@ public bool ShouldIgnore {

public ParquetDecimalAttribute? DecimalAttribute => _mi.GetCustomAttribute<ParquetDecimalAttribute>();

/// <summary>
/// Reads the Roslyn nullable-reference-type metadata
/// (https://github.com/dotnet/roslyn/blob/main/docs/features/nullable-metadata.md)
/// to decide whether this member was declared nullable.
/// Every class should be considered nullable unless the compiler has been instructed
/// to make it non-nullable; this method reports what the compiler recorded.
/// </summary>
/// <param name="finalType">Type to evaluate. Anything that is not a class yields <c>null</c>.</param>
/// <returns>
/// <c>false</c> when the applicable flag is 1 (not annotated), <c>true</c> when it is 2 (annotated).
/// The flag is taken from a NullableAttribute on the member itself when that carries 1 or 2,
/// otherwise from the NullableContextAttribute on the declaring type.
/// <c>null</c> when <paramref name="finalType"/> is not a class, when the declaring type
/// carries no NullableAttribute, or when neither source yields a flag of 1 or 2.
/// </returns>
public bool? IsNullable(Type finalType) {
    if(!finalType.IsClass)
        return null;

    Type? declaringType = _mi.DeclaringType;

    // Without a NullableAttribute on the declaring type the metadata is treated as absent.
    bool isCompiledWithNullable = declaringType?.CustomAttributes
        .Any(attr => attr.AttributeType.Name == "NullableAttribute") == true;
    if(!isCompiledWithNullable)
        return null;

    // A NullableAttribute placed directly on the member wins over the type-level context.
    CustomAttributeData? memberAttribute = _mi.CustomAttributes
        .FirstOrDefault(attr => attr.AttributeType.Name == "NullableAttribute");
    bool? fromMember = ToNullability(ReadNullableFlag(memberAttribute));
    if(fromMember != null)
        return fromMember;

    // Otherwise fall back to the default recorded on the declaring type.
    CustomAttributeData? contextAttribute = declaringType?.CustomAttributes
        .FirstOrDefault(attr => attr.AttributeType.Name == "NullableContextAttribute");
    return ToNullability(ReadNullableFlag(contextAttribute));
}

/// <summary>
/// Extracts the leading nullable flag from a NullableAttribute / NullableContextAttribute.
/// Handles both the single-byte and the byte-array constructor forms.
/// </summary>
/// <returns>The flag byte, or <c>null</c> when the attribute is missing or carries no usable argument.</returns>
private static byte? ReadNullableFlag(CustomAttributeData? attribute) {
    if(attribute == null || attribute.ConstructorArguments.Count == 0)
        return null;

    object? value = attribute.ConstructorArguments[0].Value;

    if(value is byte single)
        return single;

    // Reflection exposes an array constructor argument as a read-only collection of
    // typed arguments, never as a raw byte[]; the first element describes the member's own type.
    if(value is IReadOnlyList<CustomAttributeTypedArgument> flags
        && flags.Count > 0
        && flags[0].Value is byte first)
        return first;

    return null;
}

/// <summary>
/// Maps a nullable metadata flag to a nullability answer: 1 is not nullable, 2 is nullable,
/// anything else (including 0 and a missing flag) is undetermined.
/// </summary>
private static bool? ToNullability(byte? flag) {
    if(flag == 1)
        return false;
    if(flag == 2)
        return true;
    return null;
}


}

class ClassPropertyMember : ClassMember {
Expand All @@ -82,7 +134,7 @@ public ClassPropertyMember(PropertyInfo propertyInfo) : base(propertyInfo) {
class ClassFieldMember : ClassMember {
private readonly FieldInfo _fi;

public ClassFieldMember(FieldInfo fi) :base(fi) {
public ClassFieldMember(FieldInfo fi) : base(fi) {
_fi = fi;
}

Expand Down Expand Up @@ -129,7 +181,7 @@ private static List<ClassMember> FindMembers(Type t, bool forWriting) {
return members;
}

private static Field ConstructDataField(string name, string propertyName, Type t, ClassMember? member) {
private static Field ConstructDataField(string name, string propertyName, Type t, ClassMember? member, bool isCompiledWithNullable) {
Field r;
bool? isNullable = member == null
? null
Expand Down Expand Up @@ -169,45 +221,54 @@ private static Field ConstructDataField(string name, string propertyName, Type t
if(t.IsEnum) {
t = t.GetEnumUnderlyingType();
}

r = new DataField(name, t, isNullable, null, propertyName);
bool? isMemberNullable = null;
if (isCompiledWithNullable) {
isMemberNullable = member?.IsNullable(t);
}

if(isMemberNullable is not null) {
isNullable = isMemberNullable.Value;
}
r = new DataField(name, t, isNullable, null, propertyName, isCompiledWithNullable);
}

return r;
}

private static MapField ConstructMapField(string name, string propertyName,
Type tKey, Type tValue,
bool forWriting) {
bool forWriting,
bool isCompiledWithNullable) {

Type kvpType = typeof(KeyValuePair<,>).MakeGenericType(tKey, tValue);
PropertyInfo piKey = kvpType.GetProperty("Key")!;
PropertyInfo piValue = kvpType.GetProperty("Value")!;

Field keyField = MakeField(new ClassPropertyMember(piKey), forWriting)!;
Field keyField = MakeField(new ClassPropertyMember(piKey), forWriting, isCompiledWithNullable)!;
if(keyField is DataField keyDataField && keyDataField.IsNullable) {
keyField.IsNullable = false;
}
Field valueField = MakeField(new ClassPropertyMember(piValue), forWriting)!;
Field valueField = MakeField(new ClassPropertyMember(piValue), forWriting, isCompiledWithNullable)!;
var mf = new MapField(name, keyField, valueField);
mf.ClrPropName = propertyName;
return mf;
}

private static ListField ConstructListField(string name, string propertyName,
Type elementType,
bool forWriting) {
bool forWriting,
bool isCompiledWithNullable) {

ListField lf = new ListField(name, MakeField(elementType, ListField.ElementName, propertyName, null, forWriting)!);
ListField lf = new ListField(name, MakeField(elementType, ListField.ElementName, propertyName, null, forWriting, isCompiledWithNullable)!);
lf.ClrPropName = propertyName;
return lf;
}

private static Field? MakeField(ClassMember member, bool forWriting) {
private static Field? MakeField(ClassMember member, bool forWriting, bool isCompiledWithNullable) {
if(member.ShouldIgnore)
return null;

Field r = MakeField(member.MemberType, member.ColumnName, member.Name, member, forWriting);
Field r = MakeField(member.MemberType, member.ColumnName, member.Name, member, forWriting, isCompiledWithNullable);
r.Order = member.Order;
return r;
}
Expand All @@ -220,28 +281,30 @@ private static ListField ConstructListField(string name, string propertyName,
/// <param name="propertyName">Class property name</param>
/// <param name="member">Optional <see cref="PropertyInfo"/> that can be used to get attribute metadata.</param>
/// <param name="forWriting"></param>
/// <param name="isCompiledWithNullable">Whether the type was compiled with nullable reference types enabled.</param>
/// <returns><see cref="DataField"/> or complex field (recursively scans class). Can return null if property is explicitly marked to be ignored.</returns>
/// <exception cref="NotImplementedException"></exception>
private static Field MakeField(Type t, string columnName, string propertyName,
ClassMember? member,
bool forWriting) {
bool forWriting,
bool isCompiledWithNullable) {

Type bt = t.IsNullable() ? t.GetNonNullable() : t;
if(member != null && member.IsLegacyRepeatable && !bt.IsGenericIDictionary() && bt.TryExtractIEnumerableType(out Type? bti)) {
bt = bti!;
}

if(SchemaEncoder.IsSupported(bt)) {
return ConstructDataField(columnName, propertyName, t, member);
return ConstructDataField(columnName, propertyName, t, member, isCompiledWithNullable && !(member?.IsLegacyRepeatable??false));
} else if(t.TryExtractDictionaryType(out Type? tKey, out Type? tValue)) {
return ConstructMapField(columnName, propertyName, tKey!, tValue!, forWriting);
return ConstructMapField(columnName, propertyName, tKey!, tValue!, forWriting, isCompiledWithNullable);
} else if(t.TryExtractIEnumerableType(out Type? elementType)) {
return ConstructListField(columnName, propertyName, elementType!, forWriting);
return ConstructListField(columnName, propertyName, elementType!, forWriting, isCompiledWithNullable);
} else if(t.IsClass || t.IsInterface || t.IsValueType) {
// must be a struct then (c# class or c# struct)!
List<ClassMember> props = FindMembers(t, forWriting);
Field[] fields = props
.Select(p => MakeField(p, forWriting))
.Select(p => MakeField(p, forWriting, isCompiledWithNullable))
.Where(f => f != null)
.Select(f => f!)
.OrderBy(f => f.Order)
Expand All @@ -261,10 +324,10 @@ private static Field MakeField(Type t, string columnName, string propertyName,
private static ParquetSchema CreateSchema(Type t, bool forWriting) {

// get it all, including base class properties (may be a hierarchy)

bool isCompiledWithNullable = NullableChecker.IsNullableEnabled(t);
List<ClassMember> props = FindMembers(t, forWriting);
List<Field> fields = props
.Select(p => MakeField(p, forWriting))
.Select(p => MakeField(p, forWriting, isCompiledWithNullable))
.Where(f => f != null)
.Select(f => f!)
.OrderBy(f => f.Order)
Expand Down
Loading