Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic support for StringMarshalling in GeneratedComInterface #86404

Merged
merged 6 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
AttributeData? generatedComAttribute = null;
foreach (var attr in symbol.ContainingType.GetAttributes())
{
if (generatedComAttribute is not null && attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)
if (generatedComAttribute is null
&& attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute)
{
generatedComAttribute = attr;
}
Expand All @@ -256,8 +257,23 @@ private static IncrementalMethodStubGenerationContext CalculateStubInformation(M
generatorDiagnostics.ReportConfigurationNotSupported(lcidConversionAttr, nameof(TypeNames.LCIDConversionAttribute));
}

var generatedComInterfaceAttributeData = new InteropAttributeCompilationData();
if (generatedComAttribute is not null)
{
var args = generatedComAttribute.NamedArguments.ToImmutableDictionary();
generatedComInterfaceAttributeData = generatedComInterfaceAttributeData.WithValuesFromNamedArguments(args);
}
// Create the stub.
var signatureContext = SignatureContext.Create(symbol, DefaultMarshallingInfoParser.Create(environment, generatorDiagnostics, symbol, new InteropAttributeCompilationData(), generatedComAttribute), environment, typeof(VtableIndexStubGenerator).Assembly);
var signatureContext = SignatureContext.Create(
symbol,
DefaultMarshallingInfoParser.Create(
environment,
generatorDiagnostics,
symbol,
generatedComInterfaceAttributeData,
generatedComAttribute),
environment,
typeof(VtableIndexStubGenerator).Assembly);

if (!symbol.MethodImplementationFlags.HasFlag(MethodImplAttributes.PreserveSig))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public interface IMarshallingInfoAttributeParser
}

/// <summary>
/// A provider of marshalling info based only on the managed type any any previously parsed use-site attribute information
/// A provider of marshalling info based only on the managed type and any previously parsed use-site attribute information
/// </summary>
public interface ITypeBasedMarshallingInfoProvider
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,8 @@ public GeneratedComClassAttribute() { }
public partial class GeneratedComInterfaceAttribute : System.Attribute
{
public GeneratedComInterfaceAttribute() { }
public StringMarshalling StringMarshalling { get { throw null; } set { } }
public Type? StringMarshallingCustomType { get { throw null; } set { } }
}
[System.CLSCompliantAttribute(false)]
public partial interface IComExposedClass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,26 @@ namespace System.Runtime.InteropServices.Marshalling
[AttributeUsage(AttributeTargets.Interface)]
public class GeneratedComInterfaceAttribute : Attribute
{
/// <summary>
/// Gets or sets how to marshal string arguments to all methods on the interface.
/// If the attributed interface inherits from another interface with <see cref="GeneratedComInterfaceAttribute"/>,
/// it must have the same values for <see cref="StringMarshalling"/> and <see cref="StringMarshallingCustomType"/>.
/// </summary>
/// <remarks>
/// If this field is set to a value other than <see cref="StringMarshalling.Custom" />,
/// <see cref="StringMarshallingCustomType" /> must not be specified.
/// </remarks>
public StringMarshalling StringMarshalling { get; set; }

/// <summary>
/// Gets or sets the <see cref="Type"/> used to control how string arguments are marshalled for all methods on the interface.
/// If the attributed interface inherits from another interface with <see cref="GeneratedComInterfaceAttribute"/>,
/// it must have the same values for <see cref="StringMarshalling"/> and <see cref="StringMarshallingCustomType"/>.
/// </summary>
/// <remarks>
/// If this field is specified, <see cref="StringMarshalling" /> must not be specified
/// or must be set to <see cref="StringMarshalling.Custom" />.
/// </remarks>
public Type? StringMarshallingCustomType { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using SharedTypes.ComInterfaces;
using Xunit;

namespace ComInterfaceGenerator.Tests
{
public unsafe partial class StringMarshallingTests
{
[LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_utf8_marshalling")]
public static partial void* NewIUtf8Marshalling();

[LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_utf16_marshalling")]
public static partial void* NewIUtf16Marshalling();

[GeneratedComClass]
internal partial class Utf8MarshalledClass : IUTF8Marshalling
{
string _data = "Hello, World!";

public string GetString() => _data;
public void SetString(string value) => _data = value;
}

[GeneratedComClass]
internal partial class Utf16MarshalledClass : IUTF16Marshalling
{
string _data = "Hello, World!";

public string GetString() => _data;
public void SetString(string value) => _data = value;
}

[GeneratedComClass]
internal partial class CustomUtf16MarshalledClass : ICustomStringMarshallingUtf16
{
string _data = "Hello, World!";

public string GetString() => _data;
public void SetString(string value) => _data = value;
}

[Fact]
public void ValidateStringMarshallingRCW()
{
var cw = new StrategyBasedComWrappers();
var utf8 = NewIUtf8Marshalling();
IUTF8Marshalling obj8 = (IUTF8Marshalling)cw.GetOrCreateObjectForComInstance((nint)utf8, CreateObjectFlags.None);
string value = obj8.GetString();
Assert.Equal("Hello, World!", value);
obj8.SetString("TestString");
value = obj8.GetString();
Assert.Equal("TestString", value);

var utf16 = NewIUtf16Marshalling();
IUTF16Marshalling obj16 = (IUTF16Marshalling)cw.GetOrCreateObjectForComInstance((nint)utf16, CreateObjectFlags.None);
Assert.Equal("Hello, World!", obj16.GetString());
obj16.SetString("TestString");
Assert.Equal("TestString", obj16.GetString());

var utf16custom = NewIUtf16Marshalling();
ICustomStringMarshallingUtf16 objCustom = (ICustomStringMarshallingUtf16)cw.GetOrCreateObjectForComInstance((nint)utf16custom, CreateObjectFlags.None);
Assert.Equal("Hello, World!", objCustom.GetString());
objCustom.SetString("TestString");
Assert.Equal("TestString", objCustom.GetString());
}

[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/85795", TargetFrameworkMonikers.Any)]
public void RcwToCcw()
{
var cw = new StrategyBasedComWrappers();

var utf8 = new Utf8MarshalledClass();
var utf8ComInstance = cw.GetOrCreateComInterfaceForObject(utf8, CreateComInterfaceFlags.None);
var utf8ComObject = (IUTF8Marshalling)cw.GetOrCreateObjectForComInstance(utf8ComInstance, CreateObjectFlags.None);
Assert.Equal(utf8.GetString(), utf8ComObject.GetString());
utf8.SetString("Set from CLR object");
Assert.Equal(utf8.GetString(), utf8ComObject.GetString());
utf8ComObject.SetString("Set from COM object");
Assert.Equal(utf8.GetString(), utf8ComObject.GetString());

var utf16 = new Utf16MarshalledClass();
var utf16ComInstance = cw.GetOrCreateComInterfaceForObject(utf16, CreateComInterfaceFlags.None);
var utf16ComObject = (IUTF16Marshalling)cw.GetOrCreateObjectForComInstance(utf16ComInstance, CreateObjectFlags.None);
Assert.Equal(utf16.GetString(), utf16ComObject.GetString());
utf16.SetString("Set from CLR object");
Assert.Equal(utf16.GetString(), utf16ComObject.GetString());
utf16ComObject.SetString("Set from COM object");
Assert.Equal(utf16.GetString(), utf16ComObject.GetString());

var customUtf16 = new CustomUtf16MarshalledClass();
var customUtf16ComInstance = cw.GetOrCreateComInterfaceForObject(customUtf16, CreateComInterfaceFlags.None);
var customUtf16ComObject = (ICustomStringMarshallingUtf16)cw.GetOrCreateObjectForComInstance(customUtf16ComInstance, CreateObjectFlags.None);
Assert.Equal(customUtf16.GetString(), customUtf16ComObject.GetString());
customUtf16.SetString("Set from CLR object");
Assert.Equal(customUtf16.GetString(), customUtf16ComObject.GetString());
customUtf16ComObject.SetString("Set from COM object");
Assert.Equal(customUtf16.GetString(), customUtf16ComObject.GetString());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.Marshalling;
using System.Text;
using System.Threading.Tasks;
using SharedTypes.ComInterfaces;
using static System.Runtime.InteropServices.ComWrappers;

namespace NativeExports.ComInterfaceGenerator
{
public unsafe class StringMarshalling
{
[UnmanagedCallersOnly(EntryPoint = "new_utf8_marshalling")]
public static void* CreateUtf8ComObject()
{
MyComWrapper cw = new();
var myObject = new Utf8Implementation();
nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);

return (void*)ptr;
}

[UnmanagedCallersOnly(EntryPoint = "new_utf16_marshalling")]
public static void* CreateUtf16ComObject()
{
MyComWrapper cw = new();
var myObject = new Utf16Implementation();
nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);

return (void*)ptr;
}

class MyComWrapper : ComWrappers
{
static void* _s_comInterface1VTable = null;
static void* _s_comInterface2VTable = null;
static void* S_Utf8VTable
{
get
{
if (_s_comInterface1VTable != null)
return _s_comInterface1VTable;
void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(GetAndSetInt), sizeof(void*) * 5);
GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease);
vtable[0] = (void*)fpQueryInterface;
vtable[1] = (void*)fpAddReference;
vtable[2] = (void*)fpRelease;
vtable[3] = (delegate* unmanaged<void*, byte**, int>)&Utf8Implementation.ABI.GetStringUtf8;
vtable[4] = (delegate* unmanaged<void*, byte*, int>)&Utf8Implementation.ABI.SetStringUtf8;
_s_comInterface1VTable = vtable;
return _s_comInterface1VTable;
}
}
static void* S_Utf16VTable
{
get
{
if (_s_comInterface2VTable != null)
return _s_comInterface2VTable;
void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(GetAndSetInt), sizeof(void*) * 5);
GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease);
vtable[0] = (void*)fpQueryInterface;
vtable[1] = (void*)fpAddReference;
vtable[2] = (void*)fpRelease;
vtable[3] = (delegate* unmanaged<void*, ushort**, int>)&Utf16Implementation.ABI.GetStringUtf16;
vtable[4] = (delegate* unmanaged<void*, ushort*, int>)&Utf16Implementation.ABI.SetStringUtf16;
_s_comInterface2VTable = vtable;
return _s_comInterface2VTable;
}
}

protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
{
if (obj is IUTF8Marshalling)
{
ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(Utf8Implementation), sizeof(ComInterfaceEntry));
comInterfaceEntry->IID = new Guid(IUTF8Marshalling._guid);
comInterfaceEntry->Vtable = (nint)S_Utf8VTable;
count = 1;
return comInterfaceEntry;
}
else if (obj is IUTF16Marshalling)
{
ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(Utf16Implementation), sizeof(ComInterfaceEntry));
comInterfaceEntry->IID = new Guid(IUTF16Marshalling._guid);
comInterfaceEntry->Vtable = (nint)S_Utf16VTable;
count = 1;
return comInterfaceEntry;
}
count = 0;
return null;
}

protected override object? CreateObject(nint externalComObject, CreateObjectFlags flags) => throw new NotImplementedException();
protected override void ReleaseObjects(IEnumerable objects) => throw new NotImplementedException();
}

class Utf8Implementation : IUTF8Marshalling
{
string _data = "Hello, World!";

string IUTF8Marshalling.GetString()
{
return _data;
}
void IUTF8Marshalling.SetString(string x)
{
_data = x;
}

// Provides function pointers in the COM format to use in COM VTables
public static class ABI
{
[UnmanagedCallersOnly]
public static int GetStringUtf8(void* @this, byte** value)
{
try
{
string currValue = ComInterfaceDispatch.GetInstance<IUTF8Marshalling>((ComInterfaceDispatch*)@this).GetString();
*value = Utf8StringMarshaller.ConvertToUnmanaged(currValue);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}

[UnmanagedCallersOnly]
public static int SetStringUtf8(void* @this, byte* newValue)
{
try
{
string value = Utf8StringMarshaller.ConvertToManaged(newValue);
ComInterfaceDispatch.GetInstance<IUTF8Marshalling>((ComInterfaceDispatch*)@this).SetString(value);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}
}
}

class Utf16Implementation : IUTF16Marshalling
{
string _data = "Hello, World!";

string IUTF16Marshalling.GetString()
{
return _data;
}
void IUTF16Marshalling.SetString(string x)
{
_data = x;
}

// Provides function pointers in the COM format to use in COM VTables
public static class ABI
{
[UnmanagedCallersOnly]
public static int GetStringUtf16(void* @this, ushort** value)
{
try
{
string currValue = ComInterfaceDispatch.GetInstance<IUTF16Marshalling>((ComInterfaceDispatch*)@this).GetString();
*value = Utf16StringMarshaller.ConvertToUnmanaged(currValue);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}

[UnmanagedCallersOnly]
public static int SetStringUtf16(void* @this, ushort* newValue)
{
try
{
string value = Utf16StringMarshaller.ConvertToManaged(newValue);
ComInterfaceDispatch.GetInstance<IUTF16Marshalling>((ComInterfaceDispatch*)@this).SetString(value);
return 0;
}
catch (Exception e)
{
return e.HResult;
}
}
}
}
}
}
Loading