Skip to content

Commit

Permalink
adds <Clone> method in record classes (#291)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianoc committed Jun 17, 2024
1 parent c28313c commit 44029fa
Showing 1 changed file with 47 additions and 9 deletions.
56 changes: 47 additions & 9 deletions Cecilifier.Core/CodeGeneration/Record.Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ internal void AddSyntheticMembers()
AddPropertiesFrom();
PrimaryConstructorGenerator.AddPrimaryConstructor(context, recordTypeDefinitionVariable, record);
AddCopyConstructor();
AddCloneMethod();
AddIEquatableEquals();
AddToStringAndRelatedMethods();
AddGetHashCodeMethod();
Expand All @@ -79,6 +80,43 @@ internal void AddSyntheticMembers()
AddDeconstructMethod();
}

private void AddCloneMethod()
{
if (_recordSymbol.IsValueType || record.ParameterList?.Parameters.Count == 0)
return;

const string CloneMethodName = "<Clone>$";
context.WriteComment($"{_recordSymbol.Name} {CloneMethodName} method");

var cloneMethodVar = context.Naming.SyntheticVariable("clone", ElementKind.Method);
var cloneMethodExps = CecilDefinitionsFactory.Method(context, cloneMethodVar, CloneMethodName, Constants.Cecil.HideBySigNewSlotVirtual.AppendModifier("MethodAttributes.Public"), _recordSymbol, false, []);
context.WriteCecilExpressions(
[
..cloneMethodExps,
$"{recordTypeDefinitionVariable}.Methods.Add({cloneMethodVar});"
]);

var copyCtorVarToFind = _recordSymbol.GetMembers(".ctor").OfType<IMethodSymbol>().Single(c => c.Parameters.Length == 1 && SymbolEqualityComparer.Default.Equals(c.Parameters[0].Type, c.ContainingType)).AsMethodDefinitionVariable();
var copyCtorVar = context.DefinitionVariables.GetMethodVariable(copyCtorVarToFind);
if (!copyCtorVar.IsValid)
{
throw new Exception($"Copy constructor definition variable for record {_recordSymbol.Name} could not be found.");
}

InstructionRepresentation[] instructions =
[
OpCodes.Ldarg_0,
OpCodes.Newobj.WithOperand(ClosedGenericMethodForMethodVariable(copyCtorVar.VariableName, recordTypeDefinitionVariable, context.TypeResolver.Resolve(_recordSymbol))),
OpCodes.Ret
];

context.WriteCecilExpressions(
CecilDefinitionsFactory.MethodBody(context.Naming, "clone", cloneMethodVar, context.Naming.ILProcessor("clone"), [], instructions)
);

AddCompilerGeneratedAttributeTo(context, cloneMethodVar);
}

private void AddCopyConstructor()
{
if (_recordSymbol.IsValueType || record.ParameterList?.Parameters.Count == 0)
Expand All @@ -92,7 +130,7 @@ private void AddCopyConstructor()
if (!found.IsValid)
{
copyCtorVar = context.Naming.Constructor(record, false);
var copyCtor = CecilDefinitionsFactory.Constructor(context, copyCtorVar, _recordSymbol.Name, false, Constants.Cecil.CtorAttributes.AppendModifier("MethodAttributes.Family | MethodAttributes.HideBySig"), [_recordSymbol.Name]);
var copyCtor = CecilDefinitionsFactory.Constructor(context, copyCtorVar, _recordSymbol.Name, false, Constants.Cecil.CtorAttributes.AppendModifier("MethodAttributes.Family | MethodAttributes.HideBySig"), [_recordSymbol.ToDisplayString()]);
context.WriteCecilExpressions(
[
copyCtor,
Expand Down Expand Up @@ -354,7 +392,7 @@ void AddRecordClassHashCodeSpecificCode()
[
OpCodes.Call.WithOperand(equalityComparerMembersForSystemType.GetDefaultMethodVar),
OpCodes.Ldarg_0, // Load this
OpCodes.Call.WithOperand(ClosedGenericMethodForMethodVariable(context, getEqualityContractMethodVar.VariableName, recordTypeDefinitionVariable)), // load EqualityContract
OpCodes.Call.WithOperand(ClosedGenericMethodForMethodVariable(getEqualityContractMethodVar.VariableName, recordTypeDefinitionVariable)), // load EqualityContract
OpCodes.Callvirt.WithOperand(equalityComparerMembersForSystemType.GetHashCodeMethodVar)
]);
}
Expand Down Expand Up @@ -436,7 +474,7 @@ void AddPrintMembersMethod()
OpCodes.Pop,
OpCodes.Ldarg_1,
OpCodes.Ldarg_0,
OpCodes.Call.WithOperand(ClosedGenericMethodFor(context, $"get_{property.Name}", recordTypeDefinitionVariable)),
OpCodes.Call.WithOperand(ClosedGenericMethodFor($"get_{property.Name}", recordTypeDefinitionVariable)),
OpCodes.Box.WithOperand(context.TypeResolver.Resolve(property.Type)).IgnoreIf(property.Type.TypeKind != TypeKind.TypeParameter),
OpCodes.Callvirt.WithOperand(stringBuilderAppendMethod.MethodResolverExpression(context)),
OpCodes.Pop
Expand Down Expand Up @@ -547,18 +585,19 @@ static IMethodSymbol StringBuilderAppendMethodFor(IVisitorContext context, IType
}
}

private string ClosedGenericMethodFor(IVisitorContext context, string memberName, string recordVar)
private string ClosedGenericMethodFor(string memberName, string recordVar)
{
var methodVar = context.DefinitionVariables.GetVariable(memberName, VariableMemberKind.Method, _recordSymbol.Name);
return ClosedGenericMethodForMethodVariable(context, methodVar.VariableName, recordVar);
return ClosedGenericMethodForMethodVariable(methodVar.VariableName, recordVar);
}

private string ClosedGenericMethodForMethodVariable(IVisitorContext context, string methodVar, string recordVar)
private string ClosedGenericMethodForMethodVariable(string methodVar, string recordVar, params string[] parameterReferences)
{
if (_recordSymbol is INamedTypeSymbol { IsGenericType: true } genericRecord)
{
var parameters = parameterReferences.Length > 0 ? $$""", Parameters = { {{string.Join(',', parameterReferences.Select(p => $"new ParameterDefinition({p})"))}} }""" : "";
var typeArguments = string.Join(',', genericRecord.TypeParameters.Select(tp => context.TypeResolver.Resolve(tp)));
return $"new MethodReference({methodVar}.Name, {methodVar}.ReturnType) {{ DeclaringType = {recordVar}.MakeGenericInstanceType([{typeArguments}]), HasThis = {methodVar}.HasThis, ExplicitThis = {methodVar}.ExplicitThis, CallingConvention = {methodVar}.CallingConvention }}";
return $"new MethodReference({methodVar}.Name, {methodVar}.ReturnType) {{ DeclaringType = {recordVar}.MakeGenericInstanceType([{typeArguments}]), HasThis = {methodVar}.HasThis, ExplicitThis = {methodVar}.ExplicitThis, CallingConvention = {methodVar}.CallingConvention{parameters} }}";
}

return methodVar;
Expand Down Expand Up @@ -669,7 +708,7 @@ void CheckForReferenceEqualityIfApplicable(List<InstructionRepresentation> instr
OpCodes.Brfalse_S.WithBranchOperand("NotEquals"),
]);

var equalityContractGetter = ClosedGenericMethodForMethodVariable(context, _equalityContractGetMethodVar, recordTypeDefinitionVariable);
var equalityContractGetter = ClosedGenericMethodForMethodVariable(_equalityContractGetMethodVar, recordTypeDefinitionVariable);
instructions.AddRange(
[
OpCodes.Ldarg_0,
Expand All @@ -689,7 +728,6 @@ void CheckForReferenceEqualityIfApplicable(List<InstructionRepresentation> instr
foreach (var targetType in targetTypes)
{
var openEqualityComparerType = context.TypeResolver.Resolve(context.SemanticModel.Compilation.GetTypeByMetadataName(typeof(EqualityComparer<>).FullName!));
//var parameterType = context.SemanticModel.GetTypeInfo(targetType.Type!).Type.EnsureNotNull();
if (equalityComparerDataByType.ContainsKey(targetType.Name))
continue;

Expand Down

0 comments on commit 44029fa

Please sign in to comment.