From 5aa22e61a1d0ec5aa26eeb4df5032b8bfa913814 Mon Sep 17 00:00:00 2001 From: Adriano Carlos Verona Date: Wed, 16 Aug 2023 19:29:36 -0400 Subject: [PATCH] fixes plain value type being assigned to nullable variables in some scenarios (#251) --- .../Tests/Unit/NullableTests.cs | 57 ++++++- .../AST/ExpressionVisitor.Conversions.cs | 149 +++++++++++++++++ Cecilifier.Core/AST/ExpressionVisitor.cs | 155 ------------------ .../TypeSystem/RoslynTypeSystem.cs | 2 + .../TypeSystem/SystemTypeSystem.cs | 2 + 5 files changed, 201 insertions(+), 164 deletions(-) create mode 100644 Cecilifier.Core/AST/ExpressionVisitor.Conversions.cs diff --git a/Cecilifier.Core.Tests/Tests/Unit/NullableTests.cs b/Cecilifier.Core.Tests/Tests/Unit/NullableTests.cs index fe579395..ebc12459 100644 --- a/Cecilifier.Core.Tests/Tests/Unit/NullableTests.cs +++ b/Cecilifier.Core.Tests/Tests/Unit/NullableTests.cs @@ -1,6 +1,5 @@ using System.Text.RegularExpressions; using NUnit.Framework; -using NUnit.Framework.Internal; namespace Cecilifier.Core.Tests.Tests.Unit; @@ -54,13 +53,53 @@ class Foo : IFoo actual); } - (\s+il_test_\d+\.Emit\(OpCodes\.)Ldarg_0\); - \1Ldarg_1\); - \1NEWOBJ Nullable - \1Call, m_bar_1\); - \1NEWOBJ Nullable - \1Ret\); - """), - ""); + [TestCase( + """ + class Foo + { + int Bar(int? i) => i.Value; + int? Test(int i1) { return Bar(i1); } // i1 should be converted to Nullable and Bar() return also. + } + """, + + """ + //return Bar\(i1\); + (\s+il_test_\d+\.Emit\(OpCodes\.)Ldarg_0\); + \1Ldarg_1\); + (?\1Newobj, .+ImportReference\(typeof\(System.Nullable<>\).MakeGenericType\(typeof\(System.Int32\)\).GetConstructors\(\).+;) + \1Call, m_bar_1\); + \k + \1Ret\); + """, + TestName = "Method parameter and return value" + )] + + [TestCase( + """ + class Foo + { + void Bar(int? p) + { + p = 41; + + int ?lp; + lp = 42; + } + } + """, + + """ + //p = 41; + (\s+il_bar_\d+\.Emit\(OpCodes\.)Ldc_I4, 41\); + \1Newobj, .+ImportReference\(typeof\(System.Nullable<>\).MakeGenericType\(typeof\(System.Int32\)\).GetConstructors\(\).+; + \1Starg_S, p_p_3\); + """, + TestName = "Variable assignment" + )] + public void ImplicitNullableConversions_AreApplied(string code, string expectedSnippet) + { + //https://github.com/adrianoc/cecilifier/issues/251 + var result = RunCecilifier(code); + Assert.That(result.GeneratedCode.ReadToEnd(), Does.Match(expectedSnippet)); } } diff --git a/Cecilifier.Core/AST/ExpressionVisitor.Conversions.cs b/Cecilifier.Core/AST/ExpressionVisitor.Conversions.cs new file mode 100644 index 00000000..2c2a001d --- /dev/null +++ b/Cecilifier.Core/AST/ExpressionVisitor.Conversions.cs @@ -0,0 +1,149 @@ +using System; +using System.Diagnostics; +using System.Linq; +using Cecilifier.Core.Extensions; +using Cecilifier.Core.Misc; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Mono.Cecil.Cil; + +namespace Cecilifier.Core.AST; + +partial class ExpressionVisitor +{ + private void InjectRequiredConversions(ExpressionSyntax expression, Action loadArrayIntoStack = null) + { + var typeInfo = ModelExtensions.GetTypeInfo(Context.SemanticModel, expression); + if (typeInfo.Type == null) return; + var conversion = Context.SemanticModel.GetConversion(expression); + if (conversion.IsImplicit) + { + if (conversion.IsNullable) + { + Context.EmitCilInstruction( + ilVar, + OpCodes.Newobj, + $"assembly.MainModule.ImportReference(typeof(System.Nullable<>).MakeGenericType(typeof({typeInfo.Type.FullyQualifiedName()})).GetConstructors().Single(ctor => ctor.GetParameters().Length == 1))"); + return; + } + + if (conversion.IsNumeric) + { + Debug.Assert(typeInfo.ConvertedType != null); + switch (typeInfo.ConvertedType.SpecialType) + { + case SpecialType.System_Single: + Context.EmitCilInstruction(ilVar, OpCodes.Conv_R4); + return; + case SpecialType.System_Double: + Context.EmitCilInstruction(ilVar, OpCodes.Conv_R8); + return; + case SpecialType.System_Byte: + Context.EmitCilInstruction(ilVar, OpCodes.Conv_I1); + return; + case SpecialType.System_Int16: + Context.EmitCilInstruction(ilVar, OpCodes.Conv_I2); + return; + case SpecialType.System_Int32: + // byte/char are pushed as Int32 by the runtime + if (typeInfo.Type.SpecialType != SpecialType.System_SByte && typeInfo.Type.SpecialType != SpecialType.System_Byte && typeInfo.Type.SpecialType != SpecialType.System_Char) + Context.EmitCilInstruction(ilVar, OpCodes.Conv_I4); + return; + case SpecialType.System_Int64: + var convOpCode = typeInfo.Type.SpecialType == SpecialType.System_Char || typeInfo.Type.SpecialType == SpecialType.System_Byte ? OpCodes.Conv_U8 : OpCodes.Conv_I8; + Context.EmitCilInstruction(ilVar, convOpCode); + return; + case SpecialType.System_Decimal: + var operand = typeInfo.ConvertedType.GetMembers().OfType() + .Single(m => m.MethodKind == MethodKind.Constructor && m.Parameters.Length == 1 && m.Parameters[0].Type.SpecialType == typeInfo.Type.SpecialType); + Context.EmitCilInstruction(ilVar, OpCodes.Newobj, operand.MethodResolverExpression(Context)); + return; + default: + throw new Exception($"Conversion from {typeInfo.Type} to {typeInfo.ConvertedType} not implemented."); + } + } + + if (conversion.MethodSymbol != null) + { + AddMethodCall(ilVar, conversion.MethodSymbol, false); + } + } + + if (conversion.IsImplicit && (conversion.IsBoxing || NeedsBoxing(Context, expression, typeInfo.Type))) + { + AddCilInstruction(ilVar, OpCodes.Box, typeInfo.Type); + } + else if (conversion.IsIdentity && typeInfo.Type.Name == "Index" && !expression.IsKind(SyntaxKind.IndexExpression) && loadArrayIntoStack != null) + { + // We are indexing an array/indexer (this[]) using a System.Index variable; In this case + // we need to convert from System.Index to *int* which is done through + // the method System.Index::GetOffset(int32) + loadArrayIntoStack(); + var indexed = ModelExtensions.GetTypeInfo(Context.SemanticModel, expression.Ancestors().OfType().Single().Expression); + Utils.EnsureNotNull(indexed.Type, "Cannot be null."); + if (indexed.Type.Name == "Span") + AddMethodCall(ilVar, ((IPropertySymbol) indexed.Type.GetMembers("Length").Single()).GetMethod); + else + Context.EmitCilInstruction(ilVar, OpCodes.Ldlen); + Context.EmitCilInstruction(ilVar, OpCodes.Conv_I4); + AddMethodCall(ilVar, (IMethodSymbol) typeInfo.Type.GetMembers().Single(m => m.Name == "GetOffset")); + } + + // Empirically (verified in generated IL), expressions of type parameter used as: + // 1. Target of a call, unless the type parameter + // - is unconstrained (i.e, method being invoked comes from System.Object) or + // - is constrained to an interface, but not to a reference type or + // - is constrained to 'struct' + // 2. Source of assignment (or variable initialization) to a reference type + // 3. Argument for a reference type parameter + // requires boxing, but for some reason, the conversion returned by GetConversion() does not reflects that. + static bool NeedsBoxing(IVisitorContext context, ExpressionSyntax expression, ITypeSymbol type) + { + var needsBoxing = type.TypeKind == TypeKind.TypeParameter && (NeedsBoxingUsedAsTargetOfReference(context, expression) || AssignmentExpressionNeedsBoxing(context, expression, type) || + TypeIsReferenceType(context, expression, type) || expression.Parent.IsArgumentPassedToReferenceTypeParameter(context, type) || + expression.Parent is BinaryExpressionSyntax binaryExpressionSyntax && binaryExpressionSyntax.OperatorToken.IsKind(SyntaxKind.IsKeyword)); + return needsBoxing; + + bool TypeIsReferenceType(IVisitorContext context, ExpressionSyntax expression, ITypeSymbol rightType) + { + if (expression.Parent is not EqualsValueClauseSyntax equalsValueClauseSyntax) return false; + Debug.Assert(expression.Parent.Parent.IsKind(SyntaxKind.VariableDeclarator)); + + // get the type of the declared variable. For instance, in `int x = 10;`, expression='10', + // expression.Parent.Parent = 'x=10' (VariableDeclaratorSyntax) + var leftType = context.SemanticModel.GetDeclaredSymbol(expression.Parent.Parent).GetMemberType(); + return !SymbolEqualityComparer.Default.Equals(leftType, rightType) && leftType.IsReferenceType; + } + + static bool AssignmentExpressionNeedsBoxing(IVisitorContext context, ExpressionSyntax expression, ITypeSymbol rightType) + { + if (expression.Parent is not AssignmentExpressionSyntax assignment) return false; + var leftType = ModelExtensions.GetTypeInfo(context.SemanticModel, assignment.Left).Type; + return !SymbolEqualityComparer.Default.Equals(leftType, rightType) && leftType.IsReferenceType; + } + + static bool NeedsBoxingUsedAsTargetOfReference(IVisitorContext context, ExpressionSyntax expression) + { + if (((CSharpSyntaxNode) expression.Parent).Accept(UsageVisitor.GetInstance(context)) != UsageKind.CallTarget) return false; + var symbol = ModelExtensions.GetSymbolInfo(context.SemanticModel, expression).Symbol; + // only triggers when expression `T` used in T.Method() (i.e, abstract static methods from an interface) + if (symbol is { Kind: SymbolKind.TypeParameter }) return false; + ITypeParameterSymbol typeParameter = null; + if (symbol == null) + { + typeParameter = context.GetTypeInfo(expression).Type as ITypeParameterSymbol; + } + else + { + // 'expression' represents a local variable, parameters, etc.. so we get its element type + typeParameter = symbol.GetMemberType() as ITypeParameterSymbol; + } + + if (typeParameter == null) return false; + if (typeParameter.HasValueTypeConstraint) return false; + return typeParameter.HasReferenceTypeConstraint || (typeParameter.ConstraintTypes.Length > 0 && typeParameter.ConstraintTypes.Any(candidate => candidate.TypeKind != TypeKind.Interface)); + } + } + } +} diff --git a/Cecilifier.Core/AST/ExpressionVisitor.cs b/Cecilifier.Core/AST/ExpressionVisitor.cs index 75d92f13..48d261bf 100644 --- a/Cecilifier.Core/AST/ExpressionVisitor.cs +++ b/Cecilifier.Core/AST/ExpressionVisitor.cs @@ -1335,161 +1335,6 @@ TypeParameterSyntax[] TypeParameterSyntaxFor(IMethodSymbol method) return declarationNode.TypeParameterList.Parameters.ToArray(); } - - private void InjectRequiredConversions(ExpressionSyntax expression, Action loadArrayIntoStack = null) - { - var typeInfo = Context.SemanticModel.GetTypeInfo(expression); - if (typeInfo.Type == null) - return; - - var conversion = Context.SemanticModel.GetConversion(expression); - if (conversion.IsImplicit) - { - if (conversion.IsNumeric) - { - Debug.Assert(typeInfo.ConvertedType != null); - switch (typeInfo.ConvertedType.SpecialType) - { - case SpecialType.System_Single: - Context.EmitCilInstruction(ilVar, OpCodes.Conv_R4); - return; - - case SpecialType.System_Double: - Context.EmitCilInstruction(ilVar, OpCodes.Conv_R8); - return; - - case SpecialType.System_Byte: - Context.EmitCilInstruction(ilVar, OpCodes.Conv_I1); - return; - - case SpecialType.System_Int16: - Context.EmitCilInstruction(ilVar, OpCodes.Conv_I2); - return; - - case SpecialType.System_Int32: - // byte/char are pushed as Int32 by the runtime - if (typeInfo.Type.SpecialType != SpecialType.System_SByte && typeInfo.Type.SpecialType != SpecialType.System_Byte && typeInfo.Type.SpecialType != SpecialType.System_Char) - Context.EmitCilInstruction(ilVar, OpCodes.Conv_I4); - return; - - case SpecialType.System_Int64: - var convOpCode = typeInfo.Type.SpecialType == SpecialType.System_Char || typeInfo.Type.SpecialType == SpecialType.System_Byte ? OpCodes.Conv_U8 : OpCodes.Conv_I8; - Context.EmitCilInstruction(ilVar, convOpCode); - return; - - case SpecialType.System_Decimal: - var operand = typeInfo.ConvertedType.GetMembers().OfType().Single(m => m.MethodKind == MethodKind.Constructor && m.Parameters.Length == 1 && m.Parameters[0].Type.SpecialType == typeInfo.Type.SpecialType); - Context.EmitCilInstruction(ilVar, OpCodes.Newobj, operand.MethodResolverExpression(Context)); - return; - - default: - throw new Exception($"Conversion from {typeInfo.Type} to {typeInfo.ConvertedType} not implemented."); - } - } - - if (conversion.MethodSymbol != null) - { - AddMethodCall(ilVar, conversion.MethodSymbol, false); - } - } - - if (conversion.IsImplicit && (conversion.IsBoxing || NeedsBoxing(Context, expression, typeInfo.Type))) - { - AddCilInstruction(ilVar, OpCodes.Box, typeInfo.Type); - } - else if (conversion.IsIdentity && typeInfo.Type.Name == "Index" && !expression.IsKind(SyntaxKind.IndexExpression) && loadArrayIntoStack != null) - { - // We are indexing an array/indexer (this[]) using a System.Index variable; In this case - // we need to convert from System.Index to *int* which is done through - // the method System.Index::GetOffset(int32) - loadArrayIntoStack(); - - var indexed = Context.SemanticModel.GetTypeInfo(expression.Ancestors().OfType().Single().Expression); - Utils.EnsureNotNull(indexed.Type, "Cannot be null."); - if (indexed.Type.Name == "Span") - AddMethodCall(ilVar, ((IPropertySymbol) indexed.Type.GetMembers("Length").Single()).GetMethod); - else - Context.EmitCilInstruction(ilVar, OpCodes.Ldlen); - - Context.EmitCilInstruction(ilVar, OpCodes.Conv_I4); - AddMethodCall(ilVar, (IMethodSymbol) typeInfo.Type.GetMembers().Single(m => m.Name == "GetOffset")); - } - - // Empirically (verified in generated IL), expressions of type parameter used as: - // 1. Target of a call, unless the type parameter - // - is unconstrained (i.e, method being invoked comes from System.Object) or - // - is constrained to an interface, but not to a reference type or - // - is constrained to 'struct' - // 2. Source of assignment (or variable initialization) to a reference type - // 3. Argument for a reference type parameter - // requires boxing, but for some reason, the conversion returned by GetConversion() does not reflects that. - static bool NeedsBoxing(IVisitorContext context, ExpressionSyntax expression, ITypeSymbol type) - { - var needsBoxing = type.TypeKind == TypeKind.TypeParameter && - (NeedsBoxingUsedAsTargetOfReference(context, expression) - || AssignmentExpressionNeedsBoxing(context, expression, type) - || TypeIsReferenceType(context, expression, type) - || expression.Parent.IsArgumentPassedToReferenceTypeParameter(context, type) - || expression.Parent is BinaryExpressionSyntax binaryExpressionSyntax && binaryExpressionSyntax.OperatorToken.IsKind(SyntaxKind.IsKeyword) - ); - - return needsBoxing; - - bool TypeIsReferenceType(IVisitorContext context, ExpressionSyntax expression, ITypeSymbol rightType) - { - if (expression.Parent is not EqualsValueClauseSyntax equalsValueClauseSyntax) - return false; - - Debug.Assert(expression.Parent.Parent.IsKind(SyntaxKind.VariableDeclarator)); - - // get the type of the declared variable. For instance, in `int x = 10;`, expression='10', - // expression.Parent.Parent = 'x=10' (VariableDeclaratorSyntax) - var leftType = context.SemanticModel.GetDeclaredSymbol(expression.Parent.Parent).GetMemberType(); - return !SymbolEqualityComparer.Default.Equals(leftType, rightType) && leftType.IsReferenceType; - } - - static bool AssignmentExpressionNeedsBoxing(IVisitorContext context, ExpressionSyntax expression, ITypeSymbol rightType) - { - if (expression.Parent is not AssignmentExpressionSyntax assignment) - return false; - - var leftType = context.SemanticModel.GetTypeInfo(assignment.Left).Type; - - return !SymbolEqualityComparer.Default.Equals(leftType, rightType) && leftType.IsReferenceType; - } - - static bool NeedsBoxingUsedAsTargetOfReference(IVisitorContext context, ExpressionSyntax expression) - { - if (((CSharpSyntaxNode) expression.Parent).Accept(UsageVisitor.GetInstance(context)) != UsageKind.CallTarget) - return false; - - var symbol = context.SemanticModel.GetSymbolInfo(expression).Symbol; - // only triggers when expression `T` used in T.Method() (i.e, abstract static methods from an interface) - if (symbol is { Kind: SymbolKind.TypeParameter }) - return false; - - ITypeParameterSymbol typeParameter = null; - if (symbol == null) - { - typeParameter = context.GetTypeInfo(expression).Type as ITypeParameterSymbol; - } - else - { - // 'expression' represents a local variable, parameters, etc.. so we get its element type - typeParameter = symbol.GetMemberType() as ITypeParameterSymbol; - } - - if (typeParameter == null) - return false; - - if (typeParameter.HasValueTypeConstraint) - return false; - - return typeParameter.HasReferenceTypeConstraint - || (typeParameter.ConstraintTypes.Length > 0 && typeParameter.ConstraintTypes.Any(candidate => candidate.TypeKind != TypeKind.Interface)); - } - } - } private BinaryOperatorHandler OperatorHandlerFor(SyntaxToken operatorToken) { diff --git a/Cecilifier.Core/TypeSystem/RoslynTypeSystem.cs b/Cecilifier.Core/TypeSystem/RoslynTypeSystem.cs index e0623675..2db825f0 100644 --- a/Cecilifier.Core/TypeSystem/RoslynTypeSystem.cs +++ b/Cecilifier.Core/TypeSystem/RoslynTypeSystem.cs @@ -41,6 +41,7 @@ public RoslynTypeSystem(IVisitorContext ctx) SystemCollectionsGenericIEnumeratorOfT = ctx.SemanticModel.Compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerator_T); SystemCollectionsIEnumerable = ctx.SemanticModel.Compilation.GetSpecialType(SpecialType.System_Collections_IEnumerable); SystemCollectionsGenericIEnumerableOfT = ctx.SemanticModel.Compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T); + SystemNullableOfT = ctx.SemanticModel.Compilation.GetSpecialType(SpecialType.System_Nullable_T); } public ITypeSymbol SystemIndex { get; } @@ -68,4 +69,5 @@ public RoslynTypeSystem(IVisitorContext ctx) public ITypeSymbol SystemValueType { get; } public ITypeSymbol SystemRuntimeCompilerServicesRuntimeHelpers { get; } + public ITypeSymbol SystemNullableOfT { get; } } diff --git a/Cecilifier.Core/TypeSystem/SystemTypeSystem.cs b/Cecilifier.Core/TypeSystem/SystemTypeSystem.cs index 67cabe91..ff8837c8 100644 --- a/Cecilifier.Core/TypeSystem/SystemTypeSystem.cs +++ b/Cecilifier.Core/TypeSystem/SystemTypeSystem.cs @@ -21,6 +21,7 @@ public SystemTypeSystem(ITypeResolver typeResolver, IVisitorContext context) [SpecialType.System_MulticastDelegate] = typeResolver.Resolve("System.MulticastDelegate"), [SpecialType.System_AsyncCallback] = typeResolver.Resolve("System.AsyncCallback"), [SpecialType.System_IAsyncResult] = typeResolver.Resolve("System.IAsyncResult"), + [SpecialType.System_Nullable_T] = typeResolver.Resolve("System.Nullable<>"), }; } @@ -35,6 +36,7 @@ public SystemTypeSystem(ITypeResolver typeResolver, IVisitorContext context) public string MulticastDelegate => _resolvedTypes[SpecialType.System_MulticastDelegate]; public string AsyncCallback => _resolvedTypes[SpecialType.System_AsyncCallback]; public string IAsyncResult => _resolvedTypes[SpecialType.System_IAsyncResult]; + public string NullableOfT => _resolvedTypes[SpecialType.System_Nullable_T]; private readonly IReadOnlyDictionary _resolvedTypes; }