From d7a98b70f2bd2e60c849cdae8aae189c4d52aabe Mon Sep 17 00:00:00 2001 From: Lehonti Ramos Date: Thu, 10 Aug 2023 22:23:36 +0200 Subject: [PATCH] File-scoped namespaces in files under `ComponentModel` (`Microsoft.ML.Core`) --- .../ComponentModel/AssemblyLoadingUtils.cs | 451 ++-- .../ComponentModel/ComponentCatalog.cs | 1875 ++++++++--------- .../ComponentModel/ComponentFactory.cs | 225 +- .../ComponentModel/ExtensionBaseAttribute.cs | 25 +- .../ComponentModel/LoadableClassAttribute.cs | 391 ++-- 5 files changed, 1481 insertions(+), 1486 deletions(-) diff --git a/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs b/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs index 3c0d2609fe..79bc103980 100644 --- a/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs +++ b/src/Microsoft.ML.Core/ComponentModel/AssemblyLoadingUtils.cs @@ -8,286 +8,285 @@ using System.Reflection; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML.Runtime; + +[Obsolete("The usage for this is intended for the internal command line utilities and is not intended for anything related to the API. " + + "Please consider another way of doing whatever it is you're attempting to accomplish.")] +[BestFriend] +internal static class AssemblyLoadingUtils { - [Obsolete("The usage for this is intended for the internal command line utilities and is not intended for anything related to the API. " + - "Please consider another way of doing whatever it is you're attempting to accomplish.")] - [BestFriend] - internal static class AssemblyLoadingUtils + /// + /// Make sure the given assemblies are loaded and that their loadable classes have been catalogued. + /// + public static void LoadAndRegister(IHostEnvironment env, string[] assemblies) { - /// - /// Make sure the given assemblies are loaded and that their loadable classes have been catalogued. - /// - public static void LoadAndRegister(IHostEnvironment env, string[] assemblies) - { - Contracts.AssertValue(env); + Contracts.AssertValue(env); - if (Utils.Size(assemblies) > 0) + if (Utils.Size(assemblies) > 0) + { + foreach (string path in assemblies) { - foreach (string path in assemblies) + Exception ex = null; + try { - Exception ex = null; - try - { - // REVIEW: Will LoadFrom ever return null? - Contracts.CheckNonEmpty(path, nameof(path)); - if (!File.Exists(path)) - { - throw Contracts.ExceptParam(nameof(path), "File does not exist at path: {0}", path); - } - var assem = LoadAssembly(env, path); - if (assem != null) - continue; - } - catch (Exception e) - { - ex = e; - } - - // If it is a zip file, load it that way. - ZipArchive zip; - try + // REVIEW: Will LoadFrom ever return null? + Contracts.CheckNonEmpty(path, nameof(path)); + if (!File.Exists(path)) { - zip = ZipFile.OpenRead(path); + throw Contracts.ExceptParam(nameof(path), "File does not exist at path: {0}", path); } - catch (Exception e) - { - // Couldn't load as an assembly and not a zip, so warn the user. - ex = ex ?? e; - Console.Error.WriteLine("Warning: Could not load '{0}': {1}", path, ex.Message); + var assem = LoadAssembly(env, path); + if (assem != null) continue; - } + } + catch (Exception e) + { + ex = e; + } - string dir; - try - { - dir = CreateTempDirectory(); - } - catch (Exception e) - { - throw Contracts.ExceptIO(e, "Creating temp directory for extra assembly zip extraction failed: '{0}'", path); - } + // If it is a zip file, load it that way. + ZipArchive zip; + try + { + zip = ZipFile.OpenRead(path); + } + catch (Exception e) + { + // Couldn't load as an assembly and not a zip, so warn the user. + ex = ex ?? e; + Console.Error.WriteLine("Warning: Could not load '{0}': {1}", path, ex.Message); + continue; + } - try - { - zip.ExtractToDirectory(dir); - } - catch (Exception e) - { - throw Contracts.ExceptIO(e, "Extracting extra assembly zip failed: '{0}'", path); - } + string dir; + try + { + dir = CreateTempDirectory(); + } + catch (Exception e) + { + throw Contracts.ExceptIO(e, "Creating temp directory for extra assembly zip extraction failed: '{0}'", path); + } - LoadAssembliesInDir(env, dir, false); + try + { + zip.ExtractToDirectory(dir); } + catch (Exception e) + { + throw Contracts.ExceptIO(e, "Extracting extra assembly zip failed: '{0}'", path); + } + + LoadAssembliesInDir(env, dir, false); } } + } - public static IDisposable CreateAssemblyRegistrar(IHostEnvironment env, string loadAssembliesPath = null) - { - Contracts.CheckValue(env, nameof(env)); - env.CheckValueOrNull(loadAssembliesPath); + public static IDisposable CreateAssemblyRegistrar(IHostEnvironment env, string loadAssembliesPath = null) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValueOrNull(loadAssembliesPath); - return new AssemblyRegistrar(env, loadAssembliesPath); - } + return new AssemblyRegistrar(env, loadAssembliesPath); + } - public static void RegisterCurrentLoadedAssemblies(IHostEnvironment env) - { - Contracts.CheckValue(env, nameof(env)); + public static void RegisterCurrentLoadedAssemblies(IHostEnvironment env) + { + Contracts.CheckValue(env, nameof(env)); - foreach (Assembly a in AppDomain.CurrentDomain.GetAssemblies()) - { - TryRegisterAssembly(env.ComponentCatalog, a); - } + foreach (Assembly a in AppDomain.CurrentDomain.GetAssemblies()) + { + TryRegisterAssembly(env.ComponentCatalog, a); } + } - private static string CreateTempDirectory() + private static string CreateTempDirectory() + { + string dir = GetTempPath(); + Directory.CreateDirectory(dir); + return dir; + } + + private static string GetTempPath() + { + Guid guid = Guid.NewGuid(); + return Path.GetFullPath(Path.Combine(Path.GetTempPath(), "MLNET_" + guid.ToString())); + } + + private static readonly string[] _filePrefixesToAvoid = new string[] { + "api-ms-win", + "clr", + "coreclr", + "dbgshim", + "ext-ms-win", + "microsoft.bond.", + "microsoft.cosmos.", + "microsoft.csharp", + "microsoft.data.", + "microsoft.hpc.", + "microsoft.live.", + "microsoft.platformbuilder.", + "microsoft.visualbasic", + "microsoft.visualstudio.", + "microsoft.win32", + "microsoft.windowsapicodepack.", + "microsoft.windowsazure.", + "mscor", + "msvc", + "petzold.", + "roslyn.", + "sho", + "sni", + "sqm", + "system.", + "zlib", + }; + + private static bool ShouldSkipPath(string path) + { + string name = Path.GetFileName(path).ToLowerInvariant(); + switch (name) { - string dir = GetTempPath(); - Directory.CreateDirectory(dir); - return dir; + case "cpumathnative.dll": + case "cqo.dll": + case "fasttreenative.dll": + case "libiomp5md.dll": + case "ldanative.dll": + case "libvw.dll": + case "matrixinterf.dll": + case "microsoft.ml.neuralnetworks.gpucuda.dll": + case "mklimports.dll": + case "microsoft.research.controls.decisiontrees.dll": + case "microsoft.ml.neuralnetworks.sse.dll": + case "mklproxynative.dll": + case "neuraltreeevaluator.dll": + case "optimizationbuilderdotnet.dll": + case "parallelcommunicator.dll": + case "Microsoft.ML.runtests.dll": + case "scopecompiler.dll": + case "symsgdnative.dll": + case "tbb.dll": + case "internallearnscope.dll": + case "unmanagedlib.dll": + case "vcclient.dll": + case "libxgboost.dll": + case "zedgraph.dll": + case "__scopecodegen__.dll": + case "cosmosClientApi.dll": + return true; } - private static string GetTempPath() + foreach (var s in _filePrefixesToAvoid) { - Guid guid = Guid.NewGuid(); - return Path.GetFullPath(Path.Combine(Path.GetTempPath(), "MLNET_" + guid.ToString())); + if (name.StartsWith(s, StringComparison.OrdinalIgnoreCase)) + return true; } - private static readonly string[] _filePrefixesToAvoid = new string[] { - "api-ms-win", - "clr", - "coreclr", - "dbgshim", - "ext-ms-win", - "microsoft.bond.", - "microsoft.cosmos.", - "microsoft.csharp", - "microsoft.data.", - "microsoft.hpc.", - "microsoft.live.", - "microsoft.platformbuilder.", - "microsoft.visualbasic", - "microsoft.visualstudio.", - "microsoft.win32", - "microsoft.windowsapicodepack.", - "microsoft.windowsazure.", - "mscor", - "msvc", - "petzold.", - "roslyn.", - "sho", - "sni", - "sqm", - "system.", - "zlib", - }; - - private static bool ShouldSkipPath(string path) - { - string name = Path.GetFileName(path).ToLowerInvariant(); - switch (name) - { - case "cpumathnative.dll": - case "cqo.dll": - case "fasttreenative.dll": - case "libiomp5md.dll": - case "ldanative.dll": - case "libvw.dll": - case "matrixinterf.dll": - case "microsoft.ml.neuralnetworks.gpucuda.dll": - case "mklimports.dll": - case "microsoft.research.controls.decisiontrees.dll": - case "microsoft.ml.neuralnetworks.sse.dll": - case "mklproxynative.dll": - case "neuraltreeevaluator.dll": - case "optimizationbuilderdotnet.dll": - case "parallelcommunicator.dll": - case "Microsoft.ML.runtests.dll": - case "scopecompiler.dll": - case "symsgdnative.dll": - case "tbb.dll": - case "internallearnscope.dll": - case "unmanagedlib.dll": - case "vcclient.dll": - case "libxgboost.dll": - case "zedgraph.dll": - case "__scopecodegen__.dll": - case "cosmosClientApi.dll": - return true; - } + return false; + } - foreach (var s in _filePrefixesToAvoid) + private static void LoadAssembliesInDir(IHostEnvironment env, string dir, bool filter) + { + if (!Directory.Exists(dir)) + return; + + // Load all dlls in the given directory. + var paths = Directory.EnumerateFiles(dir, "*.dll"); + foreach (string path in paths) + { + if (filter && ShouldSkipPath(path)) { - if (name.StartsWith(s, StringComparison.OrdinalIgnoreCase)) - return true; + continue; } - return false; + LoadAssembly(env, path); } + } - private static void LoadAssembliesInDir(IHostEnvironment env, string dir, bool filter) + /// + /// Given an assembly path, load the assembly and register it with the ComponentCatalog. + /// + private static Assembly LoadAssembly(IHostEnvironment env, string path) + { + Assembly assembly = null; + try { - if (!Directory.Exists(dir)) - return; - - // Load all dlls in the given directory. - var paths = Directory.EnumerateFiles(dir, "*.dll"); - foreach (string path in paths) - { - if (filter && ShouldSkipPath(path)) - { - continue; - } - - LoadAssembly(env, path); - } + assembly = Assembly.LoadFrom(path); + } + catch (Exception) + { + return null; } - /// - /// Given an assembly path, load the assembly and register it with the ComponentCatalog. - /// - private static Assembly LoadAssembly(IHostEnvironment env, string path) + if (assembly != null) { - Assembly assembly = null; - try - { - assembly = Assembly.LoadFrom(path); - } - catch (Exception) - { - return null; - } + TryRegisterAssembly(env.ComponentCatalog, assembly); + } - if (assembly != null) - { - TryRegisterAssembly(env.ComponentCatalog, assembly); - } + return assembly; + } - return assembly; - } + /// + /// Checks whether references the assembly containing LoadableClassAttributeBase, + /// and therefore can contain components. + /// + private static bool CanContainComponents(Assembly assembly) + { + var targetFullName = typeof(LoadableClassAttributeBase).Assembly.GetName().FullName; - /// - /// Checks whether references the assembly containing LoadableClassAttributeBase, - /// and therefore can contain components. - /// - private static bool CanContainComponents(Assembly assembly) + bool found = false; + foreach (var name in assembly.GetReferencedAssemblies()) { - var targetFullName = typeof(LoadableClassAttributeBase).Assembly.GetName().FullName; - - bool found = false; - foreach (var name in assembly.GetReferencedAssemblies()) + if (name.FullName == targetFullName) { - if (name.FullName == targetFullName) - { - found = true; - break; - } + found = true; + break; } - - return found; } - private static void TryRegisterAssembly(ComponentCatalog catalog, Assembly assembly) - { - // Don't try to index dynamic generated assembly - if (assembly.IsDynamic) - return; - - if (!CanContainComponents(assembly)) - return; + return found; + } - catalog.RegisterAssembly(assembly); - } + private static void TryRegisterAssembly(ComponentCatalog catalog, Assembly assembly) + { + // Don't try to index dynamic generated assembly + if (assembly.IsDynamic) + return; - private sealed class AssemblyRegistrar : IDisposable - { - private readonly IHostEnvironment _env; + if (!CanContainComponents(assembly)) + return; - public AssemblyRegistrar(IHostEnvironment env, string path) - { - _env = env; + catalog.RegisterAssembly(assembly); + } - RegisterCurrentLoadedAssemblies(_env); + private sealed class AssemblyRegistrar : IDisposable + { + private readonly IHostEnvironment _env; - if (!string.IsNullOrEmpty(path)) - { - LoadAssembliesInDir(_env, path, true); - path = Path.Combine(path, "AutoLoad"); - LoadAssembliesInDir(_env, path, true); - } + public AssemblyRegistrar(IHostEnvironment env, string path) + { + _env = env; - AppDomain.CurrentDomain.AssemblyLoad += CurrentDomainAssemblyLoad; - } + RegisterCurrentLoadedAssemblies(_env); - public void Dispose() + if (!string.IsNullOrEmpty(path)) { - AppDomain.CurrentDomain.AssemblyLoad -= CurrentDomainAssemblyLoad; + LoadAssembliesInDir(_env, path, true); + path = Path.Combine(path, "AutoLoad"); + LoadAssembliesInDir(_env, path, true); } - private void CurrentDomainAssemblyLoad(object sender, AssemblyLoadEventArgs args) - { - TryRegisterAssembly(_env.ComponentCatalog, args.LoadedAssembly); - } + AppDomain.CurrentDomain.AssemblyLoad += CurrentDomainAssemblyLoad; + } + + public void Dispose() + { + AppDomain.CurrentDomain.AssemblyLoad -= CurrentDomainAssemblyLoad; + } + + private void CurrentDomainAssemblyLoad(object sender, AssemblyLoadEventArgs args) + { + TryRegisterAssembly(_env.ComponentCatalog, args.LoadedAssembly); } } } diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs index 77604fb439..1c718db0d5 100644 --- a/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs +++ b/src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs @@ -11,1135 +11,1134 @@ using Microsoft.ML.EntryPoints; using Microsoft.ML.Internal.Utilities; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML.Runtime; + + +internal static class Extension { + internal static AccessModifier Accessmodifier(this MethodInfo methodInfo) + { + if (methodInfo.IsFamilyAndAssembly) + return AccessModifier.PrivateProtected; + if (methodInfo.IsPrivate) + return AccessModifier.Private; + if (methodInfo.IsFamily) + return AccessModifier.Protected; + if (methodInfo.IsFamilyOrAssembly) + return AccessModifier.ProtectedInternal; + if (methodInfo.IsAssembly) + return AccessModifier.Internal; + if (methodInfo.IsPublic) + return AccessModifier.Public; + throw new ArgumentException("Did not find access modifier", nameof(methodInfo)); + } - internal static class Extension + internal static AccessModifier Accessmodifier(this ConstructorInfo constructorInfo) { - internal static AccessModifier Accessmodifier(this MethodInfo methodInfo) - { - if (methodInfo.IsFamilyAndAssembly) - return AccessModifier.PrivateProtected; - if (methodInfo.IsPrivate) - return AccessModifier.Private; - if (methodInfo.IsFamily) - return AccessModifier.Protected; - if (methodInfo.IsFamilyOrAssembly) - return AccessModifier.ProtectedInternal; - if (methodInfo.IsAssembly) - return AccessModifier.Internal; - if (methodInfo.IsPublic) - return AccessModifier.Public; - throw new ArgumentException("Did not find access modifier", nameof(methodInfo)); - } + if (constructorInfo.IsFamilyAndAssembly) + return AccessModifier.PrivateProtected; + if (constructorInfo.IsPrivate) + return AccessModifier.Private; + if (constructorInfo.IsFamily) + return AccessModifier.Protected; + if (constructorInfo.IsFamilyOrAssembly) + return AccessModifier.ProtectedInternal; + if (constructorInfo.IsAssembly) + return AccessModifier.Internal; + if (constructorInfo.IsPublic) + return AccessModifier.Public; + throw new ArgumentException("Did not find access modifier", nameof(constructorInfo)); + } - internal static AccessModifier Accessmodifier(this ConstructorInfo constructorInfo) - { - if (constructorInfo.IsFamilyAndAssembly) - return AccessModifier.PrivateProtected; - if (constructorInfo.IsPrivate) - return AccessModifier.Private; - if (constructorInfo.IsFamily) - return AccessModifier.Protected; - if (constructorInfo.IsFamilyOrAssembly) - return AccessModifier.ProtectedInternal; - if (constructorInfo.IsAssembly) - return AccessModifier.Internal; - if (constructorInfo.IsPublic) - return AccessModifier.Public; - throw new ArgumentException("Did not find access modifier", nameof(constructorInfo)); - } + internal enum AccessModifier + { + PrivateProtected, + Private, + Protected, + ProtectedInternal, + Internal, + Public + } +} - internal enum AccessModifier - { - PrivateProtected, - Private, - Protected, - ProtectedInternal, - Internal, - Public - } +/// +/// This catalogs instantiable components (aka, loadable classes). Components are registered via +/// a descendant of , identifying the names and signature types under which the component +/// type should be registered. Signatures are delegate types that return void and specify that parameter +/// types for component instantiation. Each component may also specify an "arguments object" that should +/// be provided at instantiation time. +/// +public sealed class ComponentCatalog +{ + internal ComponentCatalog() + { + _lock = new object(); + _cachedAssemblies = new HashSet(); + _classesByKey = new Dictionary(); + _classes = new List(); + _signatures = new Dictionary(); + + _entryPoints = new List(); + _entryPointMap = new Dictionary(); + _componentMap = new Dictionary(); + _components = new List(); + + _extensionsMap = new Dictionary<(Type AttributeType, string ContractName), Type>(); } /// - /// This catalogs instantiable components (aka, loadable classes). Components are registered via - /// a descendant of , identifying the names and signature types under which the component - /// type should be registered. Signatures are delegate types that return void and specify that parameter - /// types for component instantiation. Each component may also specify an "arguments object" that should - /// be provided at instantiation time. + /// Provides information on an instantiable component, aka, loadable class. /// - public sealed class ComponentCatalog + [BestFriend] + internal sealed class LoadableClassInfo { - internal ComponentCatalog() - { - _lock = new object(); - _cachedAssemblies = new HashSet(); - _classesByKey = new Dictionary(); - _classes = new List(); - _signatures = new Dictionary(); - - _entryPoints = new List(); - _entryPointMap = new Dictionary(); - _componentMap = new Dictionary(); - _components = new List(); - - _extensionsMap = new Dictionary<(Type AttributeType, string ContractName), Type>(); - } - /// - /// Provides information on an instantiable component, aka, loadable class. + /// Used for dictionary lookup based on signature and name. /// - [BestFriend] - internal sealed class LoadableClassInfo + internal readonly struct Key : IEquatable { - /// - /// Used for dictionary lookup based on signature and name. - /// - internal readonly struct Key : IEquatable - { - public readonly string Name; - public readonly Type Signature; - - public Key(string name, Type sig) - { - Name = name; - Signature = sig; - } - - public override int GetHashCode() - { - return Hashing.CombinedHash(Name.GetHashCode(), Signature.GetHashCode()); - } - - public override bool Equals(object obj) - { - return obj is Key && Equals((Key)obj); - } - - public bool Equals(Key other) - { - return other.Name == Name && other.Signature == Signature; - } - } + public readonly string Name; + public readonly Type Signature; - /// - /// Count of component construction arguments, NOT including the arguments object (if there is one). - /// This matches the number of arguments for the signature type delegate(s). - /// - internal int ExtraArgCount => ArgType == null ? CtorTypes.Length : CtorTypes.Length - 1; - - public Type Type { get; } - - /// - /// The type that contains the construction method, whether static Instance property, - /// static Create method, or constructor. - /// - public Type LoaderType { get; } - - public IReadOnlyList SignatureTypes { get; } - - /// - /// Summary of the component. - /// - public string Summary { get; } - - /// - /// UserName may be null or empty, indicating that it should be hidden in UI. - /// - public string UserName { get; } - - /// - /// Whether this is a "hidden" component, that generally shouldn't be displayed - /// to users. - /// - public bool IsHidden => string.IsNullOrWhiteSpace(UserName); - - /// - /// All load names. The first is the default. - /// - public IReadOnlyList LoadNames { get; } - - /// - /// The static property that returns an instance of this loadable class. - /// This creation method does not support an arguments class. - /// Only one of Ctor, Create and InstanceGetter can be non-null. - /// - public MethodInfo InstanceGetter { get; } - - /// - /// The constructor to create an instance of this loadable class. - /// This creation method supports an arguments class. - /// Only one of Ctor, Create and InstanceGetter can be non-null. - /// - public ConstructorInfo Constructor { get; } - - /// - /// The static method that creates an instance of this loadable class. - /// This creation method supports an arguments class. - /// Only one of Ctor, Create and InstanceGetter can be non-null. - /// - public MethodInfo CreateMethod { get; } - - public bool RequireEnvironment { get; } - - /// - /// A name of an embedded resource containing documentation for this - /// loadable class. This is non-null only in the event that we have - /// verified the assembly of actually contains - /// this resource. - /// - public string DocName { get; } - - /// - /// The type that contains the arguments to the component. - /// - public Type ArgType { get; } - - private Type[] CtorTypes { get; } - - internal LoadableClassInfo(LoadableClassAttributeBase attr, MethodInfo getter, ConstructorInfo ctor, MethodInfo create, bool requireEnvironment) + public Key(string name, Type sig) { - Contracts.AssertValue(attr); - Contracts.AssertValue(attr.InstanceType); - Contracts.AssertValue(attr.LoaderType); - Contracts.AssertValueOrNull(attr.Summary); - Contracts.AssertValueOrNull(attr.DocName); - Contracts.AssertValueOrNull(attr.UserName); - Contracts.AssertNonEmpty(attr.LoadNames); - Contracts.Assert(getter == null || Utils.Size(attr.CtorTypes) == 0); - - Type = attr.InstanceType; - LoaderType = attr.LoaderType; - Summary = attr.Summary; - UserName = attr.UserName; - LoadNames = attr.LoadNames.AsReadOnly(); - - if (getter != null) - InstanceGetter = getter; - else if (ctor != null) - Constructor = ctor; - else if (create != null) - CreateMethod = create; - ArgType = attr.ArgType; - SignatureTypes = attr.SigTypes.AsReadOnly(); - CtorTypes = attr.CtorTypes ?? Type.EmptyTypes; - RequireEnvironment = requireEnvironment; - - if (!string.IsNullOrWhiteSpace(attr.DocName)) - DocName = attr.DocName; - - Contracts.Assert(ArgType == null || CtorTypes.Length > 0 && CtorTypes[0] == ArgType); + Name = name; + Signature = sig; } - internal object CreateInstanceCore(object[] ctorArgs) + public override int GetHashCode() { - Contracts.Assert(Utils.Size(ctorArgs) == CtorTypes.Length + ((RequireEnvironment) ? 1 : 0)); - try - { - if (InstanceGetter != null) - { - Contracts.Assert(Utils.Size(ctorArgs) == 0); - return InstanceGetter.Invoke(null, null); - } - if (Constructor != null) - return Constructor.Invoke(ctorArgs); - if (CreateMethod != null) - return CreateMethod.Invoke(null, ctorArgs); - } - catch (TargetInvocationException ex) - { - if (ex.InnerException != null && ex.InnerException.IsMarked()) - throw Contracts.Except(ex, "Error during class instantiation"); - else - throw; - } - throw Contracts.Except("Can't instantiate class '{0}'", Type.Name); + return Hashing.CombinedHash(Name.GetHashCode(), Signature.GetHashCode()); } - /// - /// Create an instance, given the arguments object and arguments to the signature delegate. - /// The args should be non-null iff ArgType is non-null. The length of the extra array should - /// match the number of parameters for the signature delegate. When that number is zero, extra - /// may be null. - /// - public object CreateInstance(IHostEnvironment env, object args, object[] extra) + public override bool Equals(object obj) { - Contracts.CheckValue(env, nameof(env)); - env.Check((ArgType != null) == (args != null)); - env.Check(Utils.Size(extra) == ExtraArgCount); - - List prefix = new List(); - if (RequireEnvironment) - prefix.Add(env); - if (ArgType != null) - prefix.Add(args); - var values = Utils.Concat(prefix.ToArray(), extra); - return CreateInstanceCore(values); + return obj is Key && Equals((Key)obj); } - /// - /// Create an instance, given the arguments object and arguments to the signature delegate. - /// The args should be non-null iff ArgType is non-null. The length of the extra array should - /// match the number of parameters for the signature delegate. When that number is zero, extra - /// may be null. - /// - public TRes CreateInstance(IHostEnvironment env, object args, object[] extra) + public bool Equals(Key other) { - if (!typeof(TRes).IsAssignableFrom(Type)) - throw Contracts.Except("Loadable class '{0}' does not derive from '{1}'", LoadNames[0], typeof(TRes).FullName); - return (TRes)CreateInstance(env, args, extra); + return other.Name == Name && other.Signature == Signature; } + } - /// - /// Create an instance with default arguments. - /// - public TRes CreateInstance(IHostEnvironment env) - { - if (!typeof(TRes).IsAssignableFrom(Type)) - throw Contracts.Except("Loadable class '{0}' does not derive from '{1}'", LoadNames[0], typeof(TRes).FullName); - return (TRes)CreateInstance(env, CreateArguments(), null); - } + /// + /// Count of component construction arguments, NOT including the arguments object (if there is one). + /// This matches the number of arguments for the signature type delegate(s). + /// + internal int ExtraArgCount => ArgType == null ? CtorTypes.Length : CtorTypes.Length - 1; - /// - /// If is not null, returns a new default instance of . - /// Otherwise, returns null. - /// - public object CreateArguments() - { - if (ArgType == null) - return null; + public Type Type { get; } - var ctor = ArgType.GetConstructor(Type.EmptyTypes); - if (ctor == null) - { - throw Contracts.Except("Loadable class '{0}' has ArgType '{1}', which has no suitable constructor", - UserName, ArgType); - } + /// + /// The type that contains the construction method, whether static Instance property, + /// static Create method, or constructor. + /// + public Type LoaderType { get; } - return ctor.Invoke(null); - } - } + public IReadOnlyList SignatureTypes { get; } + + /// + /// Summary of the component. + /// + public string Summary { get; } + + /// + /// UserName may be null or empty, indicating that it should be hidden in UI. + /// + public string UserName { get; } + + /// + /// Whether this is a "hidden" component, that generally shouldn't be displayed + /// to users. + /// + public bool IsHidden => string.IsNullOrWhiteSpace(UserName); + + /// + /// All load names. The first is the default. + /// + public IReadOnlyList LoadNames { get; } + + /// + /// The static property that returns an instance of this loadable class. + /// This creation method does not support an arguments class. + /// Only one of Ctor, Create and InstanceGetter can be non-null. + /// + public MethodInfo InstanceGetter { get; } + + /// + /// The constructor to create an instance of this loadable class. + /// This creation method supports an arguments class. + /// Only one of Ctor, Create and InstanceGetter can be non-null. + /// + public ConstructorInfo Constructor { get; } + + /// + /// The static method that creates an instance of this loadable class. + /// This creation method supports an arguments class. + /// Only one of Ctor, Create and InstanceGetter can be non-null. + /// + public MethodInfo CreateMethod { get; } + + public bool RequireEnvironment { get; } /// - /// A description of a single entry point. + /// A name of an embedded resource containing documentation for this + /// loadable class. This is non-null only in the event that we have + /// verified the assembly of actually contains + /// this resource. /// - [BestFriend] - internal sealed class EntryPointInfo + public string DocName { get; } + + /// + /// The type that contains the arguments to the component. + /// + public Type ArgType { get; } + + private Type[] CtorTypes { get; } + + internal LoadableClassInfo(LoadableClassAttributeBase attr, MethodInfo getter, ConstructorInfo ctor, MethodInfo create, bool requireEnvironment) { - public readonly string Name; - public readonly string Description; - public readonly string ShortName; - public readonly string FriendlyName; - public readonly MethodInfo Method; - public readonly Type InputType; - public readonly Type OutputType; - public readonly Type[] InputKinds; - public readonly Type[] OutputKinds; - public readonly ObsoleteAttribute ObsoleteAttribute; - - internal EntryPointInfo(MethodInfo method, - TlcModule.EntryPointAttribute attribute, ObsoleteAttribute obsoleteAttribute) + Contracts.AssertValue(attr); + Contracts.AssertValue(attr.InstanceType); + Contracts.AssertValue(attr.LoaderType); + Contracts.AssertValueOrNull(attr.Summary); + Contracts.AssertValueOrNull(attr.DocName); + Contracts.AssertValueOrNull(attr.UserName); + Contracts.AssertNonEmpty(attr.LoadNames); + Contracts.Assert(getter == null || Utils.Size(attr.CtorTypes) == 0); + + Type = attr.InstanceType; + LoaderType = attr.LoaderType; + Summary = attr.Summary; + UserName = attr.UserName; + LoadNames = attr.LoadNames.AsReadOnly(); + + if (getter != null) + InstanceGetter = getter; + else if (ctor != null) + Constructor = ctor; + else if (create != null) + CreateMethod = create; + ArgType = attr.ArgType; + SignatureTypes = attr.SigTypes.AsReadOnly(); + CtorTypes = attr.CtorTypes ?? Type.EmptyTypes; + RequireEnvironment = requireEnvironment; + + if (!string.IsNullOrWhiteSpace(attr.DocName)) + DocName = attr.DocName; + + Contracts.Assert(ArgType == null || CtorTypes.Length > 0 && CtorTypes[0] == ArgType); + } + + internal object CreateInstanceCore(object[] ctorArgs) + { + Contracts.Assert(Utils.Size(ctorArgs) == CtorTypes.Length + ((RequireEnvironment) ? 1 : 0)); + try { - Contracts.AssertValue(method); - Contracts.AssertValue(attribute); - - Name = attribute.Name ?? string.Join(".", method.DeclaringType.Name, method.Name); - Description = attribute.Desc; - Method = method; - ShortName = attribute.ShortName; - FriendlyName = attribute.UserName; - ObsoleteAttribute = obsoleteAttribute; - - // There are supposed to be 2 parameters, env and input for non-macro nodes. - // Macro nodes have a 3rd parameter, the entry point node. - var parameters = method.GetParameters(); - if (parameters.Length != 2 && parameters.Length != 3) - throw Contracts.Except("Method '{0}' has {1} parameters, but must have 2 or 3", method.Name, parameters.Length); - if (parameters[0].ParameterType != typeof(IHostEnvironment)) - throw Contracts.Except("Method '{0}', 1st parameter is {1}, but must be IHostEnvironment", method.Name, parameters[0].ParameterType); - InputType = parameters[1].ParameterType; - var outputType = method.ReturnType; - if (!outputType.IsClass) - throw Contracts.Except("Method '{0}' returns {1}, but must return a class", method.Name, outputType); - OutputType = outputType; - - InputKinds = FindEntryPointKinds(InputType); - OutputKinds = FindEntryPointKinds(OutputType); + if (InstanceGetter != null) + { + Contracts.Assert(Utils.Size(ctorArgs) == 0); + return InstanceGetter.Invoke(null, null); + } + if (Constructor != null) + return Constructor.Invoke(ctorArgs); + if (CreateMethod != null) + return CreateMethod.Invoke(null, ctorArgs); } - - private Type[] FindEntryPointKinds(Type type) + catch (TargetInvocationException ex) { - var kindAttr = type.GetTypeInfo().GetCustomAttributes(typeof(TlcModule.EntryPointKindAttribute), false).FirstOrDefault() - as TlcModule.EntryPointKindAttribute; - var baseType = type.BaseType; - - if (baseType == null) - return kindAttr?.Kinds; - var baseKinds = FindEntryPointKinds(baseType); - if (kindAttr == null) - return baseKinds; - if (baseKinds == null) - return kindAttr.Kinds; - return kindAttr.Kinds.Concat(baseKinds).ToArray(); + if (ex.InnerException != null && ex.InnerException.IsMarked()) + throw Contracts.Except(ex, "Error during class instantiation"); + else + throw; } + throw Contracts.Except("Can't instantiate class '{0}'", Type.Name); + } - public override string ToString() => $"{Name}: {Description}"; + /// + /// Create an instance, given the arguments object and arguments to the signature delegate. + /// The args should be non-null iff ArgType is non-null. The length of the extra array should + /// match the number of parameters for the signature delegate. When that number is zero, extra + /// may be null. + /// + public object CreateInstance(IHostEnvironment env, object args, object[] extra) + { + Contracts.CheckValue(env, nameof(env)); + env.Check((ArgType != null) == (args != null)); + env.Check(Utils.Size(extra) == ExtraArgCount); + + List prefix = new List(); + if (RequireEnvironment) + prefix.Add(env); + if (ArgType != null) + prefix.Add(args); + var values = Utils.Concat(prefix.ToArray(), extra); + return CreateInstanceCore(values); } /// - /// A description of a single component. - /// The 'component' is a non-standalone building block that is used to parametrize entry points or other ML.NET components. - /// For example, 'Loss function', or 'similarity calculator' could be components. + /// Create an instance, given the arguments object and arguments to the signature delegate. + /// The args should be non-null iff ArgType is non-null. The length of the extra array should + /// match the number of parameters for the signature delegate. When that number is zero, extra + /// may be null. /// - [BestFriend] - internal sealed class ComponentInfo + public TRes CreateInstance(IHostEnvironment env, object args, object[] extra) { - public readonly string Name; - public readonly string Description; - public readonly string FriendlyName; - public readonly string Kind; - public readonly Type ArgumentType; - public readonly Type InterfaceType; - public readonly string[] Aliases; - - internal ComponentInfo(Type interfaceType, string kind, Type argumentType, TlcModule.ComponentAttribute attribute) - { - Contracts.AssertValue(interfaceType); - Contracts.AssertNonEmpty(kind); - Contracts.AssertValue(argumentType); - Contracts.AssertValue(attribute); - - Name = attribute.Name; - Description = attribute.Desc; - if (string.IsNullOrWhiteSpace(attribute.FriendlyName)) - FriendlyName = Name; - else - FriendlyName = attribute.FriendlyName; + if (!typeof(TRes).IsAssignableFrom(Type)) + throw Contracts.Except("Loadable class '{0}' does not derive from '{1}'", LoadNames[0], typeof(TRes).FullName); + return (TRes)CreateInstance(env, args, extra); + } - Kind = kind; - if (!IsValidName(Kind)) - throw Contracts.Except("Invalid component kind: '{0}'", Kind); + /// + /// Create an instance with default arguments. + /// + public TRes CreateInstance(IHostEnvironment env) + { + if (!typeof(TRes).IsAssignableFrom(Type)) + throw Contracts.Except("Loadable class '{0}' does not derive from '{1}'", LoadNames[0], typeof(TRes).FullName); + return (TRes)CreateInstance(env, CreateArguments(), null); + } - Aliases = attribute.Aliases; - if (!IsValidName(Name)) - throw Contracts.Except("Component name '{0}' is not valid.", Name); + /// + /// If is not null, returns a new default instance of . + /// Otherwise, returns null. + /// + public object CreateArguments() + { + if (ArgType == null) + return null; - if (Aliases != null && Aliases.Any(x => !IsValidName(x))) - throw Contracts.Except("Component '{0}' has an invalid alias '{1}'", Name, Aliases.First(x => !IsValidName(x))); + var ctor = ArgType.GetConstructor(Type.EmptyTypes); + if (ctor == null) + { + throw Contracts.Except("Loadable class '{0}' has ArgType '{1}', which has no suitable constructor", + UserName, ArgType); + } - if (!typeof(IComponentFactory).IsAssignableFrom(argumentType)) - throw Contracts.Except("Component '{0}' must inherit from IComponentFactory", argumentType); + return ctor.Invoke(null); + } + } - ArgumentType = argumentType; - InterfaceType = interfaceType; - } + /// + /// A description of a single entry point. + /// + [BestFriend] + internal sealed class EntryPointInfo + { + public readonly string Name; + public readonly string Description; + public readonly string ShortName; + public readonly string FriendlyName; + public readonly MethodInfo Method; + public readonly Type InputType; + public readonly Type OutputType; + public readonly Type[] InputKinds; + public readonly Type[] OutputKinds; + public readonly ObsoleteAttribute ObsoleteAttribute; + + internal EntryPointInfo(MethodInfo method, + TlcModule.EntryPointAttribute attribute, ObsoleteAttribute obsoleteAttribute) + { + Contracts.AssertValue(method); + Contracts.AssertValue(attribute); + + Name = attribute.Name ?? string.Join(".", method.DeclaringType.Name, method.Name); + Description = attribute.Desc; + Method = method; + ShortName = attribute.ShortName; + FriendlyName = attribute.UserName; + ObsoleteAttribute = obsoleteAttribute; + + // There are supposed to be 2 parameters, env and input for non-macro nodes. + // Macro nodes have a 3rd parameter, the entry point node. + var parameters = method.GetParameters(); + if (parameters.Length != 2 && parameters.Length != 3) + throw Contracts.Except("Method '{0}' has {1} parameters, but must have 2 or 3", method.Name, parameters.Length); + if (parameters[0].ParameterType != typeof(IHostEnvironment)) + throw Contracts.Except("Method '{0}', 1st parameter is {1}, but must be IHostEnvironment", method.Name, parameters[0].ParameterType); + InputType = parameters[1].ParameterType; + var outputType = method.ReturnType; + if (!outputType.IsClass) + throw Contracts.Except("Method '{0}' returns {1}, but must return a class", method.Name, outputType); + OutputType = outputType; + + InputKinds = FindEntryPointKinds(InputType); + OutputKinds = FindEntryPointKinds(OutputType); } - // This lock protects adding to the below collections. - private readonly object _lock; - private readonly HashSet _cachedAssemblies; + private Type[] FindEntryPointKinds(Type type) + { + var kindAttr = type.GetTypeInfo().GetCustomAttributes(typeof(TlcModule.EntryPointKindAttribute), false).FirstOrDefault() + as TlcModule.EntryPointKindAttribute; + var baseType = type.BaseType; + + if (baseType == null) + return kindAttr?.Kinds; + var baseKinds = FindEntryPointKinds(baseType); + if (kindAttr == null) + return baseKinds; + if (baseKinds == null) + return kindAttr.Kinds; + return kindAttr.Kinds.Concat(baseKinds).ToArray(); + } + + public override string ToString() => $"{Name}: {Description}"; + } + + /// + /// A description of a single component. + /// The 'component' is a non-standalone building block that is used to parametrize entry points or other ML.NET components. + /// For example, 'Loss function', or 'similarity calculator' could be components. + /// + [BestFriend] + internal sealed class ComponentInfo + { + public readonly string Name; + public readonly string Description; + public readonly string FriendlyName; + public readonly string Kind; + public readonly Type ArgumentType; + public readonly Type InterfaceType; + public readonly string[] Aliases; + + internal ComponentInfo(Type interfaceType, string kind, Type argumentType, TlcModule.ComponentAttribute attribute) + { + Contracts.AssertValue(interfaceType); + Contracts.AssertNonEmpty(kind); + Contracts.AssertValue(argumentType); + Contracts.AssertValue(attribute); + + Name = attribute.Name; + Description = attribute.Desc; + if (string.IsNullOrWhiteSpace(attribute.FriendlyName)) + FriendlyName = Name; + else + FriendlyName = attribute.FriendlyName; + + Kind = kind; + if (!IsValidName(Kind)) + throw Contracts.Except("Invalid component kind: '{0}'", Kind); + + Aliases = attribute.Aliases; + if (!IsValidName(Name)) + throw Contracts.Except("Component name '{0}' is not valid.", Name); + + if (Aliases != null && Aliases.Any(x => !IsValidName(x))) + throw Contracts.Except("Component '{0}' has an invalid alias '{1}'", Name, Aliases.First(x => !IsValidName(x))); + + if (!typeof(IComponentFactory).IsAssignableFrom(argumentType)) + throw Contracts.Except("Component '{0}' must inherit from IComponentFactory", argumentType); + + ArgumentType = argumentType; + InterfaceType = interfaceType; + } + } - // Map from key/name to loadable class. Note that the same ClassInfo may appear - // multiple times. For the set of unique infos, use _classes. - private readonly Dictionary _classesByKey; + // This lock protects adding to the below collections. + private readonly object _lock; + private readonly HashSet _cachedAssemblies; - // The unique ClassInfos and Signatures. - private readonly List _classes; - private readonly Dictionary _signatures; + // Map from key/name to loadable class. Note that the same ClassInfo may appear + // multiple times. For the set of unique infos, use _classes. + private readonly Dictionary _classesByKey; - private readonly List _entryPoints; - private readonly Dictionary _entryPointMap; + // The unique ClassInfos and Signatures. + private readonly List _classes; + private readonly Dictionary _signatures; - private readonly List _components; - private readonly Dictionary _componentMap; + private readonly List _entryPoints; + private readonly Dictionary _entryPointMap; - private readonly Dictionary<(Type AttributeType, string ContractName), Type> _extensionsMap; + private readonly List _components; + private readonly Dictionary _componentMap; - private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTypes, - out MethodInfo getter, out ConstructorInfo ctor, out MethodInfo create, out bool requireEnvironment) - { - getter = null; - ctor = null; - create = null; - requireEnvironment = false; - bool requireEnvironmentCtor = false; - bool requireEnvironmentCreate = false; - var parmTypesWithEnv = Utils.Concat(new Type[1] { typeof(IHostEnvironment) }, parmTypes); + private readonly Dictionary<(Type AttributeType, string ContractName), Type> _extensionsMap; - if (Utils.Size(parmTypes) == 0 && (getter = FindInstanceGetter(instType, loaderType)) != null) - return true; + private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTypes, + out MethodInfo getter, out ConstructorInfo ctor, out MethodInfo create, out bool requireEnvironment) + { + getter = null; + ctor = null; + create = null; + requireEnvironment = false; + bool requireEnvironmentCtor = false; + bool requireEnvironmentCreate = false; + var parmTypesWithEnv = Utils.Concat(new Type[1] { typeof(IHostEnvironment) }, parmTypes); + + if (Utils.Size(parmTypes) == 0 && (getter = FindInstanceGetter(instType, loaderType)) != null) + return true; - // Find both 'ctor' and 'create' methods if available - if (instType.IsAssignableFrom(loaderType)) + // Find both 'ctor' and 'create' methods if available + if (instType.IsAssignableFrom(loaderType)) + { + if ((ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypes ?? Type.EmptyTypes, null)) == null) { - if ((ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypes ?? Type.EmptyTypes, null)) == null) - { - if ((ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypesWithEnv ?? Type.EmptyTypes, null)) != null) - requireEnvironmentCtor = true; - } + if ((ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypesWithEnv ?? Type.EmptyTypes, null)) != null) + requireEnvironmentCtor = true; } + } - if ((create = FindCreateMethod(instType, loaderType, parmTypes ?? Type.EmptyTypes)) == null) - { - if ((create = FindCreateMethod(instType, loaderType, parmTypesWithEnv ?? Type.EmptyTypes)) != null) - requireEnvironmentCreate = true; - } + if ((create = FindCreateMethod(instType, loaderType, parmTypes ?? Type.EmptyTypes)) == null) + { + if ((create = FindCreateMethod(instType, loaderType, parmTypesWithEnv ?? Type.EmptyTypes)) != null) + requireEnvironmentCreate = true; + } - if (ctor != null && create != null) - { - // If both 'ctor' and 'create' methods were found - // Choose the one that is 'more' public - // If they have the same visibility, then throw an exception, since this shouldn't happen. + if (ctor != null && create != null) + { + // If both 'ctor' and 'create' methods were found + // Choose the one that is 'more' public + // If they have the same visibility, then throw an exception, since this shouldn't happen. - if (ctor.Accessmodifier() == create.Accessmodifier()) - { - throw Contracts.Except($"Can't load type {instType}, because it has both create and constructor methods with the same visibility. Please indicate which one should be used by changing either the signature or the visibility of one of them."); - } - if (ctor.Accessmodifier() > create.Accessmodifier()) - { - create = null; - requireEnvironment = requireEnvironmentCtor; - return true; - } - ctor = null; - requireEnvironment = requireEnvironmentCreate; - return true; + if (ctor.Accessmodifier() == create.Accessmodifier()) + { + throw Contracts.Except($"Can't load type {instType}, because it has both create and constructor methods with the same visibility. Please indicate which one should be used by changing either the signature or the visibility of one of them."); } - - if (ctor != null && create == null) + if (ctor.Accessmodifier() > create.Accessmodifier()) { + create = null; requireEnvironment = requireEnvironmentCtor; return true; } + ctor = null; + requireEnvironment = requireEnvironmentCreate; + return true; + } - if (ctor == null && create != null) - { - requireEnvironment = requireEnvironmentCreate; - return true; - } + if (ctor != null && create == null) + { + requireEnvironment = requireEnvironmentCtor; + return true; + } - return false; + if (ctor == null && create != null) + { + requireEnvironment = requireEnvironmentCreate; + return true; } - private void AddClass(LoadableClassInfo info, string[] loadNames, bool throwOnError) + return false; + } + + private void AddClass(LoadableClassInfo info, string[] loadNames, bool throwOnError) + { + _classes.Add(info); + bool isEntryPoint = false; + foreach (var sigType in info.SignatureTypes) { - _classes.Add(info); - bool isEntryPoint = false; - foreach (var sigType in info.SignatureTypes) + _signatures[sigType] = true; + + foreach (var name in loadNames) { - _signatures[sigType] = true; + string nameCi = name.ToLowerInvariant(); - foreach (var name in loadNames) + var key = new LoadableClassInfo.Key(nameCi, sigType); + if (_classesByKey.TryGetValue(key, out var infoCur)) { - string nameCi = name.ToLowerInvariant(); - - var key = new LoadableClassInfo.Key(nameCi, sigType); - if (_classesByKey.TryGetValue(key, out var infoCur)) + if (throwOnError) { - if (throwOnError) - { - throw Contracts.Except($"ComponentCatalog cannot map name '{name}' and SignatureType '{sigType}' to {info.Type.Name}, already mapped to {infoCur.Type.Name}."); - } - } - else - { - _classesByKey.Add(key, info); + throw Contracts.Except($"ComponentCatalog cannot map name '{name}' and SignatureType '{sigType}' to {info.Type.Name}, already mapped to {infoCur.Type.Name}."); } } - - if (sigType == typeof(SignatureEntryPointModule)) + else { - isEntryPoint = true; + _classesByKey.Add(key, info); } } - if (isEntryPoint) + if (sigType == typeof(SignatureEntryPointModule)) { - ScanForEntryPoints(info); + isEntryPoint = true; } } - private void ScanForEntryPoints(LoadableClassInfo info) + if (isEntryPoint) { - var type = info.LoaderType; + ScanForEntryPoints(info); + } + } - // Scan for entry points. - foreach (var methodInfo in type.GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic)) - { - var attr = methodInfo.GetCustomAttributes(typeof(TlcModule.EntryPointAttribute), false).FirstOrDefault() as TlcModule.EntryPointAttribute; - if (attr == null) - continue; + private void ScanForEntryPoints(LoadableClassInfo info) + { + var type = info.LoaderType; - var entryPointInfo = new EntryPointInfo(methodInfo, attr, - methodInfo.GetCustomAttributes(typeof(ObsoleteAttribute), false).FirstOrDefault() as ObsoleteAttribute); + // Scan for entry points. + foreach (var methodInfo in type.GetMethods(BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic)) + { + var attr = methodInfo.GetCustomAttributes(typeof(TlcModule.EntryPointAttribute), false).FirstOrDefault() as TlcModule.EntryPointAttribute; + if (attr == null) + continue; - _entryPoints.Add(entryPointInfo); - if (_entryPointMap.ContainsKey(entryPointInfo.Name)) - { - // Duplicate entry point name. We need to show a warning here. - // REVIEW: we will be able to do this once catalog becomes a part of env. - continue; - } + var entryPointInfo = new EntryPointInfo(methodInfo, attr, + methodInfo.GetCustomAttributes(typeof(ObsoleteAttribute), false).FirstOrDefault() as ObsoleteAttribute); - _entryPointMap[entryPointInfo.Name] = entryPointInfo; + _entryPoints.Add(entryPointInfo); + if (_entryPointMap.ContainsKey(entryPointInfo.Name)) + { + // Duplicate entry point name. We need to show a warning here. + // REVIEW: we will be able to do this once catalog becomes a part of env. + continue; } - // Scan for components. - // First scan ourself, and then all nested types, for component info. - ScanForComponents(type); - foreach (var nestedType in type.GetTypeInfo().GetNestedTypes()) - ScanForComponents(nestedType); + _entryPointMap[entryPointInfo.Name] = entryPointInfo; } - private bool ScanForComponents(Type nestedType) - { - var attr = nestedType.GetTypeInfo().GetCustomAttributes(typeof(TlcModule.ComponentAttribute), true).FirstOrDefault() - as TlcModule.ComponentAttribute; - if (attr == null) - return false; + // Scan for components. + // First scan ourself, and then all nested types, for component info. + ScanForComponents(type); + foreach (var nestedType in type.GetTypeInfo().GetNestedTypes()) + ScanForComponents(nestedType); + } - bool found = false; - foreach (var faceType in nestedType.GetInterfaces()) - { - var faceAttr = faceType.GetTypeInfo().GetCustomAttributes(typeof(TlcModule.ComponentKindAttribute), false).FirstOrDefault() - as TlcModule.ComponentKindAttribute; - if (faceAttr == null) - continue; + private bool ScanForComponents(Type nestedType) + { + var attr = nestedType.GetTypeInfo().GetCustomAttributes(typeof(TlcModule.ComponentAttribute), true).FirstOrDefault() + as TlcModule.ComponentAttribute; + if (attr == null) + return false; + + bool found = false; + foreach (var faceType in nestedType.GetInterfaces()) + { + var faceAttr = faceType.GetTypeInfo().GetCustomAttributes(typeof(TlcModule.ComponentKindAttribute), false).FirstOrDefault() + as TlcModule.ComponentKindAttribute; + if (faceAttr == null) + continue; - if (!typeof(IComponentFactory).IsAssignableFrom(faceType)) - throw Contracts.Except("Component signature '{0}' doesn't inherit from '{1}'", faceType, typeof(IComponentFactory)); + if (!typeof(IComponentFactory).IsAssignableFrom(faceType)) + throw Contracts.Except("Component signature '{0}' doesn't inherit from '{1}'", faceType, typeof(IComponentFactory)); - try - { - // In order to populate from JSON, we need to invoke the parameterless ctor. Testing that this is possible. - Activator.CreateInstance(nestedType); - } - catch (MissingMemberException ex) - { - throw Contracts.Except(ex, "Component type '{0}' doesn't have a default constructor", faceType); - } + try + { + // In order to populate from JSON, we need to invoke the parameterless ctor. Testing that this is possible. + Activator.CreateInstance(nestedType); + } + catch (MissingMemberException ex) + { + throw Contracts.Except(ex, "Component type '{0}' doesn't have a default constructor", faceType); + } - var info = new ComponentInfo(faceType, faceAttr.Kind, nestedType, attr); - var names = (info.Aliases ?? new string[0]).Concat(new[] { info.Name }).Distinct(); - _components.Add(info); + var info = new ComponentInfo(faceType, faceAttr.Kind, nestedType, attr); + var names = (info.Aliases ?? new string[0]).Concat(new[] { info.Name }).Distinct(); + _components.Add(info); - foreach (var alias in names) + foreach (var alias in names) + { + var tag = $"{info.Kind}:{alias}"; + if (_componentMap.ContainsKey(tag)) { - var tag = $"{info.Kind}:{alias}"; - if (_componentMap.ContainsKey(tag)) - { - // Duplicate component name. We need to show a warning here. - // REVIEW: we will be able to do this once catalog becomes a part of env. - continue; - } - _componentMap[tag] = info; + // Duplicate component name. We need to show a warning here. + // REVIEW: we will be able to do this once catalog becomes a part of env. + continue; } + _componentMap[tag] = info; } - - return found; } - private static MethodInfo FindInstanceGetter(Type instType, Type loaderType) - { - // Look for a public static property named Instance of the correct type. - var prop = loaderType.GetProperty("Instance", instType); - if (prop == null) - return null; - if (prop.DeclaringType != loaderType) - return null; - var meth = prop.GetGetMethod(false); - if (meth == null) - return null; - if (meth.ReturnType != instType) - return null; - if (!meth.IsPublic || !meth.IsStatic) - return null; - return meth; - } + return found; + } - private static MethodInfo FindCreateMethod(Type instType, Type loaderType, Type[] parmTypes) - { - var meth = loaderType.GetMethod("Create", BindingFlags.Public | BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.FlattenHierarchy, null, parmTypes ?? Type.EmptyTypes, null); - if (meth == null) - return null; - if (meth.DeclaringType != loaderType) - return null; - if (meth.ReturnType != instType) - return null; - if (!meth.IsStatic) - return null; - return meth; - } + private static MethodInfo FindInstanceGetter(Type instType, Type loaderType) + { + // Look for a public static property named Instance of the correct type. + var prop = loaderType.GetProperty("Instance", instType); + if (prop == null) + return null; + if (prop.DeclaringType != loaderType) + return null; + var meth = prop.GetGetMethod(false); + if (meth == null) + return null; + if (meth.ReturnType != instType) + return null; + if (!meth.IsPublic || !meth.IsStatic) + return null; + return meth; + } - /// - /// Registers all the components in the specified assembly by looking for loadable classes - /// and adding them to the catalog. - /// - /// - /// The assembly to register. - /// - /// - /// true to throw an exception if there are errors with registering the components; - /// false to skip any errors. - /// - public void RegisterAssembly(Assembly assembly, bool throwOnError = true) + private static MethodInfo FindCreateMethod(Type instType, Type loaderType, Type[] parmTypes) + { + var meth = loaderType.GetMethod("Create", BindingFlags.Public | BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.FlattenHierarchy, null, parmTypes ?? Type.EmptyTypes, null); + if (meth == null) + return null; + if (meth.DeclaringType != loaderType) + return null; + if (meth.ReturnType != instType) + return null; + if (!meth.IsStatic) + return null; + return meth; + } + + /// + /// Registers all the components in the specified assembly by looking for loadable classes + /// and adding them to the catalog. + /// + /// + /// The assembly to register. + /// + /// + /// true to throw an exception if there are errors with registering the components; + /// false to skip any errors. + /// + public void RegisterAssembly(Assembly assembly, bool throwOnError = true) + { + lock (_lock) { - lock (_lock) + if (_cachedAssemblies.Add(assembly.FullName)) { - if (_cachedAssemblies.Add(assembly.FullName)) + foreach (LoadableClassAttributeBase attr in assembly.GetCustomAttributes(typeof(LoadableClassAttributeBase))) { - foreach (LoadableClassAttributeBase attr in assembly.GetCustomAttributes(typeof(LoadableClassAttributeBase))) + MethodInfo getter = null; + ConstructorInfo ctor = null; + MethodInfo create = null; + bool requireEnvironment = false; + if (attr.InstanceType != typeof(void) && !TryGetIniters(attr.InstanceType, attr.LoaderType, attr.CtorTypes, out getter, out ctor, out create, out requireEnvironment)) { - MethodInfo getter = null; - ConstructorInfo ctor = null; - MethodInfo create = null; - bool requireEnvironment = false; - if (attr.InstanceType != typeof(void) && !TryGetIniters(attr.InstanceType, attr.LoaderType, attr.CtorTypes, out getter, out ctor, out create, out requireEnvironment)) + if (throwOnError) { - if (throwOnError) - { - throw Contracts.Except( - $"Can't instantiate loadable class '{attr.InstanceType.Name}' with name '{attr.LoadNames[0]}'"); - } - Contracts.Assert(getter == null && ctor == null && create == null); + throw Contracts.Except( + $"Can't instantiate loadable class '{attr.InstanceType.Name}' with name '{attr.LoadNames[0]}'"); } - var info = new LoadableClassInfo(attr, getter, ctor, create, requireEnvironment); - - AddClass(info, attr.LoadNames, throwOnError); + Contracts.Assert(getter == null && ctor == null && create == null); } + var info = new LoadableClassInfo(attr, getter, ctor, create, requireEnvironment); - LoadExtensions(assembly, throwOnError); + AddClass(info, attr.LoadNames, throwOnError); } + + LoadExtensions(assembly, throwOnError); } } + } - /// - /// Return an array containing information for all instantiable components. - /// If provided, the given set of assemblies is loaded first. - /// - [BestFriend] - internal LoadableClassInfo[] GetAllClasses() - { - return _classes.ToArray(); - } + /// + /// Return an array containing information for all instantiable components. + /// If provided, the given set of assemblies is loaded first. + /// + [BestFriend] + internal LoadableClassInfo[] GetAllClasses() + { + return _classes.ToArray(); + } - /// - /// Return an array containing information for instantiable components with the given - /// signature and base type. If provided, the given set of assemblies is loaded first. - /// - [BestFriend] - internal LoadableClassInfo[] GetAllDerivedClasses(Type typeBase, Type typeSig) - { - Contracts.CheckValue(typeBase, nameof(typeBase)); - Contracts.CheckValueOrNull(typeSig); + /// + /// Return an array containing information for instantiable components with the given + /// signature and base type. If provided, the given set of assemblies is loaded first. + /// + [BestFriend] + internal LoadableClassInfo[] GetAllDerivedClasses(Type typeBase, Type typeSig) + { + Contracts.CheckValue(typeBase, nameof(typeBase)); + Contracts.CheckValueOrNull(typeSig); - // Apply the default. - if (typeSig == null) - typeSig = typeof(SignatureDefault); + // Apply the default. + if (typeSig == null) + typeSig = typeof(SignatureDefault); - return _classes - .Where(info => info.SignatureTypes.Contains(typeSig) && typeBase.IsAssignableFrom(info.Type)) - .ToArray(); - } + return _classes + .Where(info => info.SignatureTypes.Contains(typeSig) && typeBase.IsAssignableFrom(info.Type)) + .ToArray(); + } - /// - /// Return an array containing all the known signature types. If provided, the given set of assemblies - /// is loaded first. - /// - [BestFriend] - internal Type[] GetAllSignatureTypes() - { - return _signatures.Select(kvp => kvp.Key).ToArray(); - } + /// + /// Return an array containing all the known signature types. If provided, the given set of assemblies + /// is loaded first. + /// + [BestFriend] + internal Type[] GetAllSignatureTypes() + { + return _signatures.Select(kvp => kvp.Key).ToArray(); + } - /// - /// Returns a string name for a given signature type. - /// - [BestFriend] - internal static string SignatureToString(Type sig) - { - Contracts.CheckValue(sig, nameof(sig)); - Contracts.CheckParam(sig.BaseType == typeof(MulticastDelegate), nameof(sig), "Must be a delegate type"); - string kind = sig.Name; - if (kind.Length > "Signature".Length && kind.StartsWith("Signature")) - kind = kind.Substring("Signature".Length); - return kind; - } + /// + /// Returns a string name for a given signature type. + /// + [BestFriend] + internal static string SignatureToString(Type sig) + { + Contracts.CheckValue(sig, nameof(sig)); + Contracts.CheckParam(sig.BaseType == typeof(MulticastDelegate), nameof(sig), "Must be a delegate type"); + string kind = sig.Name; + if (kind.Length > "Signature".Length && kind.StartsWith("Signature")) + kind = kind.Substring("Signature".Length); + return kind; + } - private LoadableClassInfo FindClassCore(LoadableClassInfo.Key key) - { - LoadableClassInfo info; - if (_classesByKey.TryGetValue(key, out info)) - return info; + private LoadableClassInfo FindClassCore(LoadableClassInfo.Key key) + { + LoadableClassInfo info; + if (_classesByKey.TryGetValue(key, out info)) + return info; - return null; - } + return null; + } - [BestFriend] - internal LoadableClassInfo[] FindLoadableClasses(string name) - { - name = name.ToLowerInvariant().Trim(); + [BestFriend] + internal LoadableClassInfo[] FindLoadableClasses(string name) + { + name = name.ToLowerInvariant().Trim(); - var res = _classes - .Where(ci => ci.LoadNames.Select(n => n.ToLowerInvariant().Trim()).Contains(name)) - .ToArray(); - return res; - } + var res = _classes + .Where(ci => ci.LoadNames.Select(n => n.ToLowerInvariant().Trim()).Contains(name)) + .ToArray(); + return res; + } - [BestFriend] - internal LoadableClassInfo[] FindLoadableClasses() - { - return _classes - .Where(ci => ci.SignatureTypes.Contains(typeof(TSig))) - .ToArray(); - } + [BestFriend] + internal LoadableClassInfo[] FindLoadableClasses() + { + return _classes + .Where(ci => ci.SignatureTypes.Contains(typeof(TSig))) + .ToArray(); + } - [BestFriend] - internal LoadableClassInfo[] FindLoadableClasses() - { - // REVIEW: this and above methods perform a linear search over all the loadable classes. - // On 6/15/2015, TLC release build contained 431 of them, so adding extra lookups looks unnecessary at this time. - return _classes - .Where(ci => ci.ArgType == typeof(TArgs) && ci.SignatureTypes.Contains(typeof(TSig))) - .ToArray(); - } + [BestFriend] + internal LoadableClassInfo[] FindLoadableClasses() + { + // REVIEW: this and above methods perform a linear search over all the loadable classes. + // On 6/15/2015, TLC release build contained 431 of them, so adding extra lookups looks unnecessary at this time. + return _classes + .Where(ci => ci.ArgType == typeof(TArgs) && ci.SignatureTypes.Contains(typeof(TSig))) + .ToArray(); + } - [BestFriend] - internal LoadableClassInfo GetLoadableClassInfo(string loadName) - { - return GetLoadableClassInfo(loadName, typeof(TSig)); - } + [BestFriend] + internal LoadableClassInfo GetLoadableClassInfo(string loadName) + { + return GetLoadableClassInfo(loadName, typeof(TSig)); + } - [BestFriend] - internal LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureType) - { - Contracts.CheckParam(signatureType.BaseType == typeof(MulticastDelegate), nameof(signatureType), "signatureType must be a delegate type"); - Contracts.CheckValueOrNull(loadName); - loadName = (loadName ?? "").ToLowerInvariant().Trim(); - return FindClassCore(new LoadableClassInfo.Key(loadName, signatureType)); - } + [BestFriend] + internal LoadableClassInfo GetLoadableClassInfo(string loadName, Type signatureType) + { + Contracts.CheckParam(signatureType.BaseType == typeof(MulticastDelegate), nameof(signatureType), "signatureType must be a delegate type"); + Contracts.CheckValueOrNull(loadName); + loadName = (loadName ?? "").ToLowerInvariant().Trim(); + return FindClassCore(new LoadableClassInfo.Key(loadName, signatureType)); + } - /// - /// Get all registered entry points. - /// - [BestFriend] - internal IEnumerable AllEntryPoints() - { - return _entryPoints; - } + /// + /// Get all registered entry points. + /// + [BestFriend] + internal IEnumerable AllEntryPoints() + { + return _entryPoints; + } - [BestFriend] - internal bool TryFindEntryPoint(string name, out EntryPointInfo entryPoint) - { - Contracts.CheckNonEmpty(name, nameof(name)); - return _entryPointMap.TryGetValue(name, out entryPoint); - } + [BestFriend] + internal bool TryFindEntryPoint(string name, out EntryPointInfo entryPoint) + { + Contracts.CheckNonEmpty(name, nameof(name)); + return _entryPointMap.TryGetValue(name, out entryPoint); + } - [BestFriend] - internal bool TryFindComponent(string kind, string alias, out ComponentInfo component) - { - Contracts.CheckNonEmpty(kind, nameof(kind)); - Contracts.CheckNonEmpty(alias, nameof(alias)); + [BestFriend] + internal bool TryFindComponent(string kind, string alias, out ComponentInfo component) + { + Contracts.CheckNonEmpty(kind, nameof(kind)); + Contracts.CheckNonEmpty(alias, nameof(alias)); - // Note that, if kind or alias contain the colon character, the kind:name 'tag' will contain more than one colon. - // Since colon may not appear in any valid name, the dictionary lookup is guaranteed to fail. - return _componentMap.TryGetValue($"{kind}:{alias}", out component); - } + // Note that, if kind or alias contain the colon character, the kind:name 'tag' will contain more than one colon. + // Since colon may not appear in any valid name, the dictionary lookup is guaranteed to fail. + return _componentMap.TryGetValue($"{kind}:{alias}", out component); + } - [BestFriend] - internal bool TryFindComponent(Type argumentType, out ComponentInfo component) - { - Contracts.CheckValue(argumentType, nameof(argumentType)); + [BestFriend] + internal bool TryFindComponent(Type argumentType, out ComponentInfo component) + { + Contracts.CheckValue(argumentType, nameof(argumentType)); - component = _components.FirstOrDefault(x => x.ArgumentType == argumentType); - return component != null; - } + component = _components.FirstOrDefault(x => x.ArgumentType == argumentType); + return component != null; + } - [BestFriend] - internal bool TryFindComponent(Type interfaceType, Type argumentType, out ComponentInfo component) - { - Contracts.CheckValue(interfaceType, nameof(interfaceType)); - Contracts.CheckParam(interfaceType.IsInterface, nameof(interfaceType), "Must be interface"); - Contracts.CheckValue(argumentType, nameof(argumentType)); + [BestFriend] + internal bool TryFindComponent(Type interfaceType, Type argumentType, out ComponentInfo component) + { + Contracts.CheckValue(interfaceType, nameof(interfaceType)); + Contracts.CheckParam(interfaceType.IsInterface, nameof(interfaceType), "Must be interface"); + Contracts.CheckValue(argumentType, nameof(argumentType)); - component = _components.FirstOrDefault(x => x.InterfaceType == interfaceType && x.ArgumentType == argumentType); - return component != null; - } + component = _components.FirstOrDefault(x => x.InterfaceType == interfaceType && x.ArgumentType == argumentType); + return component != null; + } - [BestFriend] - internal bool TryFindComponent(Type interfaceType, string alias, out ComponentInfo component) - { - Contracts.CheckValue(interfaceType, nameof(interfaceType)); - Contracts.CheckParam(interfaceType.IsInterface, nameof(interfaceType), "Must be interface"); - Contracts.CheckNonEmpty(alias, nameof(alias)); - component = _components.FirstOrDefault(x => x.InterfaceType == interfaceType && (x.Name == alias || (x.Aliases != null && x.Aliases.Contains(alias)))); - return component != null; - } + [BestFriend] + internal bool TryFindComponent(Type interfaceType, string alias, out ComponentInfo component) + { + Contracts.CheckValue(interfaceType, nameof(interfaceType)); + Contracts.CheckParam(interfaceType.IsInterface, nameof(interfaceType), "Must be interface"); + Contracts.CheckNonEmpty(alias, nameof(alias)); + component = _components.FirstOrDefault(x => x.InterfaceType == interfaceType && (x.Name == alias || (x.Aliases != null && x.Aliases.Contains(alias)))); + return component != null; + } - /// - /// Akin to , except if the regular (case sensitive) comparison fails, it will - /// attempt to back off to a case-insensitive comparison. - /// - [BestFriend] - internal bool TryFindComponentCaseInsensitive(Type interfaceType, string alias, out ComponentInfo component) - { - Contracts.CheckValue(interfaceType, nameof(interfaceType)); - Contracts.CheckParam(interfaceType.IsInterface, nameof(interfaceType), "Must be interface"); - Contracts.CheckNonEmpty(alias, nameof(alias)); - if (TryFindComponent(interfaceType, alias, out component)) - return true; - alias = alias.ToLowerInvariant(); - component = _components.FirstOrDefault(x => x.InterfaceType == interfaceType && (x.Name.ToLowerInvariant() == alias || AnyMatch(alias, x.Aliases))); - return component != null; - } + /// + /// Akin to , except if the regular (case sensitive) comparison fails, it will + /// attempt to back off to a case-insensitive comparison. + /// + [BestFriend] + internal bool TryFindComponentCaseInsensitive(Type interfaceType, string alias, out ComponentInfo component) + { + Contracts.CheckValue(interfaceType, nameof(interfaceType)); + Contracts.CheckParam(interfaceType.IsInterface, nameof(interfaceType), "Must be interface"); + Contracts.CheckNonEmpty(alias, nameof(alias)); + if (TryFindComponent(interfaceType, alias, out component)) + return true; + alias = alias.ToLowerInvariant(); + component = _components.FirstOrDefault(x => x.InterfaceType == interfaceType && (x.Name.ToLowerInvariant() == alias || AnyMatch(alias, x.Aliases))); + return component != null; + } - private static bool AnyMatch(string name, string[] aliases) - { - if (aliases == null) - return false; - return aliases.Any(a => string.Equals(name, a, StringComparison.OrdinalIgnoreCase)); - } + private static bool AnyMatch(string name, string[] aliases) + { + if (aliases == null) + return false; + return aliases.Any(a => string.Equals(name, a, StringComparison.OrdinalIgnoreCase)); + } - /// - /// Returns all valid component kinds. - /// - [BestFriend] - internal IEnumerable GetAllComponentKinds() - { - return _components.Select(x => x.Kind).Distinct().OrderBy(x => x); - } + /// + /// Returns all valid component kinds. + /// + [BestFriend] + internal IEnumerable GetAllComponentKinds() + { + return _components.Select(x => x.Kind).Distinct().OrderBy(x => x); + } - /// - /// Returns all components of the specified kind. - /// - [BestFriend] - internal IEnumerable GetAllComponents(string kind) - { - Contracts.CheckNonEmpty(kind, nameof(kind)); - Contracts.CheckParam(IsValidName(kind), nameof(kind), "Invalid component kind"); - return _components.Where(x => x.Kind == kind).OrderBy(x => x.Name); - } + /// + /// Returns all components of the specified kind. + /// + [BestFriend] + internal IEnumerable GetAllComponents(string kind) + { + Contracts.CheckNonEmpty(kind, nameof(kind)); + Contracts.CheckParam(IsValidName(kind), nameof(kind), "Invalid component kind"); + return _components.Where(x => x.Kind == kind).OrderBy(x => x.Name); + } - /// - /// Returns all components that implement the specified interface. - /// - [BestFriend] - internal IEnumerable GetAllComponents(Type interfaceType) + /// + /// Returns all components that implement the specified interface. + /// + [BestFriend] + internal IEnumerable GetAllComponents(Type interfaceType) + { + Contracts.CheckValue(interfaceType, nameof(interfaceType)); + return _components.Where(x => x.InterfaceType == interfaceType).OrderBy(x => x.Name); + } + + [BestFriend] + internal bool TryGetComponentKind(Type signatureType, out string kind) + { + Contracts.CheckValue(signatureType, nameof(signatureType)); + // REVIEW: replace with a dictionary lookup. + + var faceAttr = signatureType.GetTypeInfo().GetCustomAttributes(typeof(TlcModule.ComponentKindAttribute), false).FirstOrDefault() + as TlcModule.ComponentKindAttribute; + kind = faceAttr == null ? null : faceAttr.Kind; + return faceAttr != null; + } + + [BestFriend] + internal bool TryGetComponentShortName(Type type, out string name) + { + ComponentInfo component; + if (!TryFindComponent(type, out component)) { - Contracts.CheckValue(interfaceType, nameof(interfaceType)); - return _components.Where(x => x.InterfaceType == interfaceType).OrderBy(x => x.Name); + name = null; + return false; } - [BestFriend] - internal bool TryGetComponentKind(Type signatureType, out string kind) - { - Contracts.CheckValue(signatureType, nameof(signatureType)); - // REVIEW: replace with a dictionary lookup. + name = component.Aliases != null && component.Aliases.Length > 0 ? component.Aliases[0] : component.Name; + return true; + } - var faceAttr = signatureType.GetTypeInfo().GetCustomAttributes(typeof(TlcModule.ComponentKindAttribute), false).FirstOrDefault() - as TlcModule.ComponentKindAttribute; - kind = faceAttr == null ? null : faceAttr.Kind; - return faceAttr != null; - } + /// + /// The valid names for the components and entry points must consist of letters, digits, underscores and dots, + /// and begin with a letter or digit. + /// + private static readonly Regex _nameRegex = new Regex(@"^\w[_\.\w]*$", RegexOptions.Compiled); + private static bool IsValidName(string name) + { + Contracts.AssertValueOrNull(name); + if (string.IsNullOrWhiteSpace(name)) + return false; + return _nameRegex.IsMatch(name); + } - [BestFriend] - internal bool TryGetComponentShortName(Type type, out string name) - { - ComponentInfo component; - if (!TryFindComponent(type, out component)) - { - name = null; - return false; - } + /// + /// Create an instance of the indicated component with the given extra parameters. + /// + [BestFriend] + internal static TRes CreateInstance(IHostEnvironment env, Type signatureType, string name, string options, params object[] extra) + where TRes : class + { + TRes result; + if (TryCreateInstance(env, signatureType, out result, name, options, extra)) + return result; + throw Contracts.Except("Unknown loadable class: {0}", name).MarkSensitive(MessageSensitivity.None); + } - name = component.Aliases != null && component.Aliases.Length > 0 ? component.Aliases[0] : component.Name; - return true; - } + /// + /// Try to create an instance of the indicated component and settings with the given extra parameters. + /// If there is no such component in the catalog, returns false. Any other error results in an exception. + /// + [BestFriend] + internal static bool TryCreateInstance(IHostEnvironment env, out TRes result, string name, string options, params object[] extra) + where TRes : class + { + return TryCreateInstance(env, typeof(TSig), out result, name, options, extra); + } - /// - /// The valid names for the components and entry points must consist of letters, digits, underscores and dots, - /// and begin with a letter or digit. - /// - private static readonly Regex _nameRegex = new Regex(@"^\w[_\.\w]*$", RegexOptions.Compiled); - private static bool IsValidName(string name) - { - Contracts.AssertValueOrNull(name); - if (string.IsNullOrWhiteSpace(name)) - return false; - return _nameRegex.IsMatch(name); - } + [BestFriend] + internal static bool TryCreateInstance(IHostEnvironment env, Type signatureType, out TRes result, string name, string options, params object[] extra) + where TRes : class + { + Contracts.CheckValue(env, nameof(env)); + env.Check(signatureType.BaseType == typeof(MulticastDelegate)); + env.CheckValueOrNull(name); - /// - /// Create an instance of the indicated component with the given extra parameters. - /// - [BestFriend] - internal static TRes CreateInstance(IHostEnvironment env, Type signatureType, string name, string options, params object[] extra) - where TRes : class + string nameLower = (name ?? "").ToLowerInvariant().Trim(); + LoadableClassInfo info = env.ComponentCatalog.FindClassCore(new LoadableClassInfo.Key(nameLower, signatureType)); + if (info == null) { - TRes result; - if (TryCreateInstance(env, signatureType, out result, name, options, extra)) - return result; - throw Contracts.Except("Unknown loadable class: {0}", name).MarkSensitive(MessageSensitivity.None); + result = null; + return false; } - /// - /// Try to create an instance of the indicated component and settings with the given extra parameters. - /// If there is no such component in the catalog, returns false. Any other error results in an exception. - /// - [BestFriend] - internal static bool TryCreateInstance(IHostEnvironment env, out TRes result, string name, string options, params object[] extra) - where TRes : class + if (!typeof(TRes).IsAssignableFrom(info.Type)) + throw env.Except("Loadable class '{0}' does not derive from '{1}'", name, typeof(TRes).FullName); + + int carg = Utils.Size(extra); + + if (info.ExtraArgCount != carg) { - return TryCreateInstance(env, typeof(TSig), out result, name, options, extra); + throw env.Except( + "Wrong number of extra parameters for loadable class '{0}', need '{1}', given '{2}'", + name, info.ExtraArgCount, carg); } - [BestFriend] - internal static bool TryCreateInstance(IHostEnvironment env, Type signatureType, out TRes result, string name, string options, params object[] extra) - where TRes : class + if (info.ArgType == null) { - Contracts.CheckValue(env, nameof(env)); - env.Check(signatureType.BaseType == typeof(MulticastDelegate)); - env.CheckValueOrNull(name); - - string nameLower = (name ?? "").ToLowerInvariant().Trim(); - LoadableClassInfo info = env.ComponentCatalog.FindClassCore(new LoadableClassInfo.Key(nameLower, signatureType)); - if (info == null) - { - result = null; - return false; - } - - if (!typeof(TRes).IsAssignableFrom(info.Type)) - throw env.Except("Loadable class '{0}' does not derive from '{1}'", name, typeof(TRes).FullName); + if (!string.IsNullOrEmpty(options)) + throw env.Except("Loadable class '{0}' doesn't support settings", name); + result = (TRes)info.CreateInstance(env, null, extra); + return true; + } - int carg = Utils.Size(extra); + object args = info.CreateArguments(); + if (args == null) + throw Contracts.Except("Can't instantiate arguments object '{0}' for '{1}'", info.ArgType.Name, name); - if (info.ExtraArgCount != carg) - { - throw env.Except( - "Wrong number of extra parameters for loadable class '{0}', need '{1}', given '{2}'", - name, info.ExtraArgCount, carg); - } + ParseArguments(env, args, options, name); + result = (TRes)info.CreateInstance(env, args, extra); + return true; + } - if (info.ArgType == null) - { - if (!string.IsNullOrEmpty(options)) - throw env.Except("Loadable class '{0}' doesn't support settings", name); - result = (TRes)info.CreateInstance(env, null, extra); - return true; - } + /// + /// Parses arguments using CmdParser. If there's a problem, it throws an InvalidOperationException, + /// with a message giving usage. + /// + /// The host environment + /// The argument object + /// The settings string + /// The name is used for error reporting only + private static void ParseArguments(IHostEnvironment env, object args, string settings, string name) + { + Contracts.AssertValue(args); + Contracts.AssertNonEmpty(name); - object args = info.CreateArguments(); - if (args == null) - throw Contracts.Except("Can't instantiate arguments object '{0}' for '{1}'", info.ArgType.Name, name); + if (string.IsNullOrWhiteSpace(settings)) + return; - ParseArguments(env, args, options, name); - result = (TRes)info.CreateInstance(env, args, extra); - return true; + string errorMsg = null; + try + { + string err = null; + string helpText; + if (!CmdParser.ParseArguments(env, settings, args, e => { err = err ?? e; }, out helpText)) + errorMsg = err + Environment.NewLine + "Usage For '" + name + "':" + Environment.NewLine + helpText; } - - /// - /// Parses arguments using CmdParser. If there's a problem, it throws an InvalidOperationException, - /// with a message giving usage. - /// - /// The host environment - /// The argument object - /// The settings string - /// The name is used for error reporting only - private static void ParseArguments(IHostEnvironment env, object args, string settings, string name) + catch (Exception e) { - Contracts.AssertValue(args); - Contracts.AssertNonEmpty(name); - - if (string.IsNullOrWhiteSpace(settings)) - return; - - string errorMsg = null; - try - { - string err = null; - string helpText; - if (!CmdParser.ParseArguments(env, settings, args, e => { err = err ?? e; }, out helpText)) - errorMsg = err + Environment.NewLine + "Usage For '" + name + "':" + Environment.NewLine + helpText; - } - catch (Exception e) - { - Contracts.Assert(false); - throw Contracts.Except(e, "Unexpected exception thrown while parsing: " + e.Message); - } - - if (errorMsg != null) - throw Contracts.Except(errorMsg); + Contracts.Assert(false); + throw Contracts.Except(e, "Unexpected exception thrown while parsing: " + e.Message); } - private void LoadExtensions(Assembly assembly, bool throwOnError) + if (errorMsg != null) + throw Contracts.Except(errorMsg); + } + + private void LoadExtensions(Assembly assembly, bool throwOnError) + { + // don't waste time looking through all the types of an assembly + // that can't contain extensions + if (CanContainExtensions(assembly)) { - // don't waste time looking through all the types of an assembly - // that can't contain extensions - if (CanContainExtensions(assembly)) + foreach (Type type in assembly.GetTypes()) { - foreach (Type type in assembly.GetTypes()) + if (type.IsClass) { - if (type.IsClass) + foreach (ExtensionBaseAttribute attribute in type.GetCustomAttributes(typeof(ExtensionBaseAttribute))) { - foreach (ExtensionBaseAttribute attribute in type.GetCustomAttributes(typeof(ExtensionBaseAttribute))) + var key = (AttributeType: attribute.GetType(), attribute.ContractName); + if (_extensionsMap.TryGetValue(key, out var existingType)) { - var key = (AttributeType: attribute.GetType(), attribute.ContractName); - if (_extensionsMap.TryGetValue(key, out var existingType)) - { - if (throwOnError) - { - throw Contracts.Except($"An extension for '{key.AttributeType.Name}' with contract '{key.ContractName}' has already been registered in the ComponentCatalog."); - } - } - else + if (throwOnError) { - _extensionsMap.Add(key, type); + throw Contracts.Except($"An extension for '{key.AttributeType.Name}' with contract '{key.ContractName}' has already been registered in the ComponentCatalog."); } } + else + { + _extensionsMap.Add(key, type); + } } } } } + } - /// - /// Gets a value indicating whether can contain extensions. - /// - /// - /// All ML.NET product assemblies won't contain extensions. - /// - private static bool CanContainExtensions(Assembly assembly) + /// + /// Gets a value indicating whether can contain extensions. + /// + /// + /// All ML.NET product assemblies won't contain extensions. + /// + private static bool CanContainExtensions(Assembly assembly) + { + if (assembly.FullName.StartsWith("Microsoft.ML.", StringComparison.Ordinal) + && HasMLNetPublicKey(assembly)) { - if (assembly.FullName.StartsWith("Microsoft.ML.", StringComparison.Ordinal) - && HasMLNetPublicKey(assembly)) - { - return false; - } - - return true; + return false; } - private static bool HasMLNetPublicKey(Assembly assembly) + return true; + } + + private static bool HasMLNetPublicKey(Assembly assembly) + { + return assembly.GetName().GetPublicKey().SequenceEqual( + typeof(ComponentCatalog).Assembly.GetName().GetPublicKey()); + } + + [BestFriend] + internal object GetExtensionValue(IHostEnvironment env, Type attributeType, string contractName) + { + object exportedValue = null; + if (_extensionsMap.TryGetValue((attributeType, contractName), out Type extensionType)) { - return assembly.GetName().GetPublicKey().SequenceEqual( - typeof(ComponentCatalog).Assembly.GetName().GetPublicKey()); + exportedValue = Activator.CreateInstance(extensionType); } - [BestFriend] - internal object GetExtensionValue(IHostEnvironment env, Type attributeType, string contractName) + if (exportedValue == null) { - object exportedValue = null; - if (_extensionsMap.TryGetValue((attributeType, contractName), out Type extensionType)) - { - exportedValue = Activator.CreateInstance(extensionType); - } - - if (exportedValue == null) - { - throw env.Except($"Unable to locate an extension for the contract '{contractName}'. Ensure you have called {nameof(ComponentCatalog)}.{nameof(ComponentCatalog.RegisterAssembly)} with the Assembly that contains a class decorated with a '{attributeType.FullName}'."); - } - - return exportedValue; + throw env.Except($"Unable to locate an extension for the contract '{contractName}'. Ensure you have called {nameof(ComponentCatalog)}.{nameof(ComponentCatalog.RegisterAssembly)} with the Assembly that contains a class decorated with a '{attributeType.FullName}'."); } + + return exportedValue; } } diff --git a/src/Microsoft.ML.Core/ComponentModel/ComponentFactory.cs b/src/Microsoft.ML.Core/ComponentModel/ComponentFactory.cs index d04df9a80a..79f1a337bd 100644 --- a/src/Microsoft.ML.Core/ComponentModel/ComponentFactory.cs +++ b/src/Microsoft.ML.Core/ComponentModel/ComponentFactory.cs @@ -4,166 +4,165 @@ using System; -namespace Microsoft.ML.Runtime +namespace Microsoft.ML.Runtime; + +/// +/// This is a token interface that all component factories must implement. +/// +public interface IComponentFactory { - /// - /// This is a token interface that all component factories must implement. - /// - public interface IComponentFactory - { - } +} + +/// +/// An interface for creating a component with no extra parameters (other than an ). +/// +public interface IComponentFactory : IComponentFactory +{ + TComponent CreateComponent(IHostEnvironment env); +} + +/// +/// An interface for creating a component when we take one extra parameter (and an ). +/// +public interface IComponentFactory : IComponentFactory +{ + TComponent CreateComponent(IHostEnvironment env, TArg1 argument1); +} + +/// +/// An interface for creating a component when we take two extra parameters (and an ). +/// +public interface IComponentFactory : IComponentFactory +{ + TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2); +} +/// +/// An interface for creating a component when we take three extra parameters (and an ). +/// +public interface IComponentFactory : IComponentFactory +{ + TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3); +} + +/// +/// A utility class for creating instances. +/// +[BestFriend] +internal static class ComponentFactoryUtils +{ /// - /// An interface for creating a component with no extra parameters (other than an ). + /// Creates a component factory with no extra parameters (other than an ) + /// that simply wraps a delegate which creates the component. /// - public interface IComponentFactory : IComponentFactory + public static IComponentFactory CreateFromFunction(Func factory) { - TComponent CreateComponent(IHostEnvironment env); + return new SimpleComponentFactory(factory); } /// - /// An interface for creating a component when we take one extra parameter (and an ). + /// Creates a component factory when we take one extra parameter (and an + /// ) that simply wraps a delegate which creates the component. /// - public interface IComponentFactory : IComponentFactory + public static IComponentFactory CreateFromFunction(Func factory) { - TComponent CreateComponent(IHostEnvironment env, TArg1 argument1); + return new SimpleComponentFactory(factory); } /// - /// An interface for creating a component when we take two extra parameters (and an ). + /// Creates a component factory when we take two extra parameters (and an + /// ) that simply wraps a delegate which creates the component. /// - public interface IComponentFactory : IComponentFactory + public static IComponentFactory CreateFromFunction(Func factory) { - TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2); + return new SimpleComponentFactory(factory); } /// - /// An interface for creating a component when we take three extra parameters (and an ). + /// Creates a component factory when we take three extra parameters (and an + /// ) that simply wraps a delegate which creates the component. /// - public interface IComponentFactory : IComponentFactory + public static IComponentFactory CreateFromFunction(Func factory) { - TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3); + return new SimpleComponentFactory(factory); } /// - /// A utility class for creating instances. + /// A class for creating a component with no extra parameters (other than an ) + /// that simply wraps a delegate which creates the component. /// - [BestFriend] - internal static class ComponentFactoryUtils + private sealed class SimpleComponentFactory : IComponentFactory { - /// - /// Creates a component factory with no extra parameters (other than an ) - /// that simply wraps a delegate which creates the component. - /// - public static IComponentFactory CreateFromFunction(Func factory) - { - return new SimpleComponentFactory(factory); - } + private readonly Func _factory; - /// - /// Creates a component factory when we take one extra parameter (and an - /// ) that simply wraps a delegate which creates the component. - /// - public static IComponentFactory CreateFromFunction(Func factory) + public SimpleComponentFactory(Func factory) { - return new SimpleComponentFactory(factory); + _factory = factory; } - /// - /// Creates a component factory when we take two extra parameters (and an - /// ) that simply wraps a delegate which creates the component. - /// - public static IComponentFactory CreateFromFunction(Func factory) + public TComponent CreateComponent(IHostEnvironment env) { - return new SimpleComponentFactory(factory); + return _factory(env); } + } - /// - /// Creates a component factory when we take three extra parameters (and an - /// ) that simply wraps a delegate which creates the component. - /// - public static IComponentFactory CreateFromFunction(Func factory) - { - return new SimpleComponentFactory(factory); - } + /// + /// A class for creating a component when we take one extra parameter + /// (and an ) that simply wraps a delegate which + /// creates the component. + /// + private sealed class SimpleComponentFactory : IComponentFactory + { + private readonly Func _factory; - /// - /// A class for creating a component with no extra parameters (other than an ) - /// that simply wraps a delegate which creates the component. - /// - private sealed class SimpleComponentFactory : IComponentFactory + public SimpleComponentFactory(Func factory) { - private readonly Func _factory; - - public SimpleComponentFactory(Func factory) - { - _factory = factory; - } - - public TComponent CreateComponent(IHostEnvironment env) - { - return _factory(env); - } + _factory = factory; } - /// - /// A class for creating a component when we take one extra parameter - /// (and an ) that simply wraps a delegate which - /// creates the component. - /// - private sealed class SimpleComponentFactory : IComponentFactory + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1) { - private readonly Func _factory; + return _factory(env, argument1); + } + } - public SimpleComponentFactory(Func factory) - { - _factory = factory; - } + /// + /// A class for creating a component when we take one extra parameter + /// (and an ) that simply wraps a delegate which + /// creates the component. + /// + private sealed class SimpleComponentFactory : IComponentFactory + { + private readonly Func _factory; - public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1) - { - return _factory(env, argument1); - } + public SimpleComponentFactory(Func factory) + { + _factory = factory; } - /// - /// A class for creating a component when we take one extra parameter - /// (and an ) that simply wraps a delegate which - /// creates the component. - /// - private sealed class SimpleComponentFactory : IComponentFactory + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2) { - private readonly Func _factory; + return _factory(env, argument1, argument2); + } + } - public SimpleComponentFactory(Func factory) - { - _factory = factory; - } + /// + /// A class for creating a component when we take three extra parameters + /// (and an ) that simply wraps a delegate which + /// creates the component. + /// + private sealed class SimpleComponentFactory : IComponentFactory + { + private readonly Func _factory; - public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2) - { - return _factory(env, argument1, argument2); - } + public SimpleComponentFactory(Func factory) + { + _factory = factory; } - /// - /// A class for creating a component when we take three extra parameters - /// (and an ) that simply wraps a delegate which - /// creates the component. - /// - private sealed class SimpleComponentFactory : IComponentFactory + public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3) { - private readonly Func _factory; - - public SimpleComponentFactory(Func factory) - { - _factory = factory; - } - - public TComponent CreateComponent(IHostEnvironment env, TArg1 argument1, TArg2 argument2, TArg3 argument3) - { - return _factory(env, argument1, argument2, argument3); - } + return _factory(env, argument1, argument2, argument3); } } } diff --git a/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs b/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs index 66e0860310..dcad79ba01 100644 --- a/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs +++ b/src/Microsoft.ML.Core/ComponentModel/ExtensionBaseAttribute.cs @@ -4,20 +4,19 @@ using System; -namespace Microsoft.ML +namespace Microsoft.ML; + +/// +/// The base attribute type for all attributes used for extensibility purposes. +/// +[AttributeUsage(AttributeTargets.Class)] +public abstract class ExtensionBaseAttribute : Attribute { - /// - /// The base attribute type for all attributes used for extensibility purposes. - /// - [AttributeUsage(AttributeTargets.Class)] - public abstract class ExtensionBaseAttribute : Attribute - { - public string ContractName { get; } + public string ContractName { get; } - [BestFriend] - private protected ExtensionBaseAttribute(string contractName) - { - ContractName = contractName; - } + [BestFriend] + private protected ExtensionBaseAttribute(string contractName) + { + ContractName = contractName; } } diff --git a/src/Microsoft.ML.Core/ComponentModel/LoadableClassAttribute.cs b/src/Microsoft.ML.Core/ComponentModel/LoadableClassAttribute.cs index a61ddff719..d7fdc371af 100644 --- a/src/Microsoft.ML.Core/ComponentModel/LoadableClassAttribute.cs +++ b/src/Microsoft.ML.Core/ComponentModel/LoadableClassAttribute.cs @@ -8,229 +8,228 @@ using Microsoft.ML.Internal.Utilities; using Microsoft.ML.Runtime; -namespace Microsoft.ML +namespace Microsoft.ML; + +/// +/// Common signature type with no extra parameters. +/// +[BestFriend] +internal delegate void SignatureDefault(); + +[AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] +[BestFriend] +internal sealed class LoadableClassAttribute : LoadableClassAttributeBase { /// - /// Common signature type with no extra parameters. + /// Assembly attribute used to specify that a class is loadable by a machine learning + /// host environment, such as TLC /// - [BestFriend] - internal delegate void SignatureDefault(); + /// The class type that is loadable + /// The argument type that the constructor takes (may be null) + /// The signature of the constructor of this class (in addition to the arguments parameter) + /// The name to use when presenting a list to users + /// The names that can be used to load the class, for example, from a command line + public LoadableClassAttribute(Type instType, Type argType, Type sigType, string userName, params string[] loadNames) + : base(null, instType, instType, argType, new[] { sigType }, userName, loadNames) + { + } - [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] - [BestFriend] - internal sealed class LoadableClassAttribute : LoadableClassAttributeBase + /// + /// Assembly attribute used to specify that a class is loadable by a machine learning + /// host environment, such as TLC + /// + /// The class type that is loadable + /// The class type that contains the construction method + /// The argument type that the constructor takes (may be null) + /// The signature of the constructor of this class (in addition to the arguments parameter) + /// The name to use when presenting a list to users + /// The names that can be used to load the class, for example, from a command line + public LoadableClassAttribute(Type instType, Type loaderType, Type argType, Type sigType, string userName, params string[] loadNames) + : base(null, instType, loaderType, argType, new[] { sigType }, userName, loadNames) { - /// - /// Assembly attribute used to specify that a class is loadable by a machine learning - /// host environment, such as TLC - /// - /// The class type that is loadable - /// The argument type that the constructor takes (may be null) - /// The signature of the constructor of this class (in addition to the arguments parameter) - /// The name to use when presenting a list to users - /// The names that can be used to load the class, for example, from a command line - public LoadableClassAttribute(Type instType, Type argType, Type sigType, string userName, params string[] loadNames) - : base(null, instType, instType, argType, new[] { sigType }, userName, loadNames) - { - } + } - /// - /// Assembly attribute used to specify that a class is loadable by a machine learning - /// host environment, such as TLC - /// - /// The class type that is loadable - /// The class type that contains the construction method - /// The argument type that the constructor takes (may be null) - /// The signature of the constructor of this class (in addition to the arguments parameter) - /// The name to use when presenting a list to users - /// The names that can be used to load the class, for example, from a command line - public LoadableClassAttribute(Type instType, Type loaderType, Type argType, Type sigType, string userName, params string[] loadNames) - : base(null, instType, loaderType, argType, new[] { sigType }, userName, loadNames) - { - } + public LoadableClassAttribute(Type instType, Type argType, Type[] sigTypes, string userName, params string[] loadNames) + : base(null, instType, instType, argType, sigTypes, userName, loadNames) + { + } - public LoadableClassAttribute(Type instType, Type argType, Type[] sigTypes, string userName, params string[] loadNames) - : base(null, instType, instType, argType, sigTypes, userName, loadNames) - { - } + public LoadableClassAttribute(Type instType, Type loaderType, Type argType, Type[] sigTypes, string userName, params string[] loadNames) + : base(null, instType, loaderType, argType, sigTypes, userName, loadNames) + { + } - public LoadableClassAttribute(Type instType, Type loaderType, Type argType, Type[] sigTypes, string userName, params string[] loadNames) - : base(null, instType, loaderType, argType, sigTypes, userName, loadNames) - { - } + /// + /// Assembly attribute used to specify that a class is loadable by a machine learning + /// host environment, such as TLC + /// + /// The description summary of the class type + /// The class type that is loadable + /// The argument type that the constructor takes (may be null) + /// The signature of the constructor of this class (in addition to the arguments parameter) + /// The name to use when presenting a list to users + /// The names that can be used to load the class, for example, from a command line + public LoadableClassAttribute(string summary, Type instType, Type argType, Type sigType, string userName, params string[] loadNames) + : base(summary, instType, instType, argType, new[] { sigType }, userName, loadNames) + { + } - /// - /// Assembly attribute used to specify that a class is loadable by a machine learning - /// host environment, such as TLC - /// - /// The description summary of the class type - /// The class type that is loadable - /// The argument type that the constructor takes (may be null) - /// The signature of the constructor of this class (in addition to the arguments parameter) - /// The name to use when presenting a list to users - /// The names that can be used to load the class, for example, from a command line - public LoadableClassAttribute(string summary, Type instType, Type argType, Type sigType, string userName, params string[] loadNames) - : base(summary, instType, instType, argType, new[] { sigType }, userName, loadNames) - { - } + /// + /// Assembly attribute used to specify that a class is loadable by a machine learning + /// host environment, such as TLC + /// + /// The description summary of the class type + /// The class type that is loadable + /// The class type that contains the construction method + /// The argument type that the constructor takes (may be null) + /// The signature of the constructor of this class (in addition to the arguments parameter) + /// The name to use when presenting a list to users + /// The names that can be used to load the class, for example, from a command line + public LoadableClassAttribute(string summary, Type instType, Type loaderType, Type argType, Type sigType, string userName, params string[] loadNames) + : base(summary, instType, loaderType, argType, new[] { sigType }, userName, loadNames) + { + } - /// - /// Assembly attribute used to specify that a class is loadable by a machine learning - /// host environment, such as TLC - /// - /// The description summary of the class type - /// The class type that is loadable - /// The class type that contains the construction method - /// The argument type that the constructor takes (may be null) - /// The signature of the constructor of this class (in addition to the arguments parameter) - /// The name to use when presenting a list to users - /// The names that can be used to load the class, for example, from a command line - public LoadableClassAttribute(string summary, Type instType, Type loaderType, Type argType, Type sigType, string userName, params string[] loadNames) - : base(summary, instType, loaderType, argType, new[] { sigType }, userName, loadNames) - { - } + public LoadableClassAttribute(string summary, Type instType, Type argType, Type[] sigTypes, string userName, params string[] loadNames) + : base(summary, instType, instType, argType, sigTypes, userName, loadNames) + { + } - public LoadableClassAttribute(string summary, Type instType, Type argType, Type[] sigTypes, string userName, params string[] loadNames) - : base(summary, instType, instType, argType, sigTypes, userName, loadNames) + public LoadableClassAttribute(string summary, Type instType, Type loaderType, Type argType, Type[] sigTypes, string userName, params string[] loadNames) + : base(summary, instType, loaderType, argType, sigTypes, userName, loadNames) + { + } +} + +internal abstract class LoadableClassAttributeBase : Attribute +{ + // Note: these properties have private setters to make attribute parsing easier - the values + // are all guaranteed to be in the ConstructorArguments of the CustomAttributeData + // (no named arguments). + + /// + /// The type that is created/loaded. + /// + public Type InstanceType { get; private set; } + + /// + /// The type that contains the construction method, whether static Instance property, + /// static Create method, or constructor. Of course, a constructor is only permissible if + /// this type derives from InstanceType. This defaults to the same as InstanceType. + /// + public Type LoaderType { get; private set; } + + /// + /// The command line arguments object type. This should be null if there isn't one. + /// + public Type ArgType { get; private set; } + + /// + /// This indicates the extra parameter types. It must be a delegate type. The return type should be void. + /// The parameter types of the SigType delegate should NOT include the ArgType. + /// + public Type[] SigTypes { get; private set; } + + /// + /// Note that CtorTypes includes the ArgType (if there is one), and the parameter types of the SigType. + /// + public Type[] CtorTypes { get; private set; } + + /// + /// The description summary of the class type. + /// + public string Summary { get; private set; } + + /// + /// UserName may be null or empty indicating that it should be hidden in UI. + /// + public string UserName { get; private set; } + public string[] LoadNames { get; private set; } + + // REVIEW: This is out of step with the remainder of the class. However, my opinion is that the + // LoadableClassAttribute class's design is worth reconsideration: having so many Type and string arguments + // be defined *without names* in a constructor has led to enormous confusion. + + // REVIEW: Presumably it would be beneficial to have multiple documents. + + /// + /// This should indicate a path within the doc/public directory next to the TLC + /// solution, where the documentation lies. This value will be used as part of a URL, so, + /// the path separator should be phrased as '/' forward slashes rather than backslashes. + public string DocName { get; set; } + + protected LoadableClassAttributeBase(string summary, Type instType, Type loaderType, Type argType, Type[] sigTypes, string userName, params string[] loadNames) + { + Contracts.CheckValueOrNull(summary); + Contracts.CheckValue(instType, nameof(instType)); + Contracts.CheckValue(loaderType, nameof(loaderType)); + Contracts.CheckNonEmpty(sigTypes, nameof(sigTypes)); + + if (Utils.Size(loadNames) == 0) + loadNames = new string[] { userName }; + + if (loadNames.Any(s => string.IsNullOrWhiteSpace(s))) + throw Contracts.ExceptEmpty(nameof(loadNames), "LoadableClass loadName parameter can't be empty"); + + var sigType = sigTypes[0]; + Contracts.CheckValue(sigType, nameof(sigTypes)); + Type[] types; + Contracts.CheckParam(sigType.BaseType == typeof(System.MulticastDelegate), nameof(sigTypes), "LoadableClass signature type must be a delegate type"); + + var meth = sigType.GetMethod("Invoke"); + Contracts.CheckParam(meth != null, nameof(sigTypes), "LoadableClass signature type must be a delegate type"); + Contracts.CheckParam(meth.ReturnType == typeof(void), nameof(sigTypes), "LoadableClass signature type must be a delegate type with void return"); + + var parms = meth.GetParameters(); + int itypeBase = 0; + + if (argType != null) { + types = new Type[1 + parms.Length]; + types[itypeBase++] = argType; } + else if (parms.Length > 0) + types = new Type[parms.Length]; + else + types = Type.EmptyTypes; - public LoadableClassAttribute(string summary, Type instType, Type loaderType, Type argType, Type[] sigTypes, string userName, params string[] loadNames) - : base(summary, instType, loaderType, argType, sigTypes, userName, loadNames) + for (int itype = 0; itype < parms.Length; itype++) { + var parm = parms[itype]; + if ((parm.Attributes & (ParameterAttributes.Out | ParameterAttributes.Retval)) != 0) + throw Contracts.Except("Invalid signature parameter attributes"); + types[itypeBase + itype] = parm.ParameterType; } - } - internal abstract class LoadableClassAttributeBase : Attribute - { - // Note: these properties have private setters to make attribute parsing easier - the values - // are all guaranteed to be in the ConstructorArguments of the CustomAttributeData - // (no named arguments). - - /// - /// The type that is created/loaded. - /// - public Type InstanceType { get; private set; } - - /// - /// The type that contains the construction method, whether static Instance property, - /// static Create method, or constructor. Of course, a constructor is only permissible if - /// this type derives from InstanceType. This defaults to the same as InstanceType. - /// - public Type LoaderType { get; private set; } - - /// - /// The command line arguments object type. This should be null if there isn't one. - /// - public Type ArgType { get; private set; } - - /// - /// This indicates the extra parameter types. It must be a delegate type. The return type should be void. - /// The parameter types of the SigType delegate should NOT include the ArgType. - /// - public Type[] SigTypes { get; private set; } - - /// - /// Note that CtorTypes includes the ArgType (if there is one), and the parameter types of the SigType. - /// - public Type[] CtorTypes { get; private set; } - - /// - /// The description summary of the class type. - /// - public string Summary { get; private set; } - - /// - /// UserName may be null or empty indicating that it should be hidden in UI. - /// - public string UserName { get; private set; } - public string[] LoadNames { get; private set; } - - // REVIEW: This is out of step with the remainder of the class. However, my opinion is that the - // LoadableClassAttribute class's design is worth reconsideration: having so many Type and string arguments - // be defined *without names* in a constructor has led to enormous confusion. - - // REVIEW: Presumably it would be beneficial to have multiple documents. - - /// - /// This should indicate a path within the doc/public directory next to the TLC - /// solution, where the documentation lies. This value will be used as part of a URL, so, - /// the path separator should be phrased as '/' forward slashes rather than backslashes. - public string DocName { get; set; } - - protected LoadableClassAttributeBase(string summary, Type instType, Type loaderType, Type argType, Type[] sigTypes, string userName, params string[] loadNames) + for (int i = 1; i < sigTypes.Length; i++) { - Contracts.CheckValueOrNull(summary); - Contracts.CheckValue(instType, nameof(instType)); - Contracts.CheckValue(loaderType, nameof(loaderType)); - Contracts.CheckNonEmpty(sigTypes, nameof(sigTypes)); - - if (Utils.Size(loadNames) == 0) - loadNames = new string[] { userName }; - - if (loadNames.Any(s => string.IsNullOrWhiteSpace(s))) - throw Contracts.ExceptEmpty(nameof(loadNames), "LoadableClass loadName parameter can't be empty"); - - var sigType = sigTypes[0]; + sigType = sigTypes[i]; Contracts.CheckValue(sigType, nameof(sigTypes)); - Type[] types; - Contracts.CheckParam(sigType.BaseType == typeof(System.MulticastDelegate), nameof(sigTypes), "LoadableClass signature type must be a delegate type"); - var meth = sigType.GetMethod("Invoke"); + Contracts.Check(sigType.BaseType == typeof(System.MulticastDelegate), "LoadableClass signature type must be a delegate type"); + + meth = sigType.GetMethod("Invoke"); Contracts.CheckParam(meth != null, nameof(sigTypes), "LoadableClass signature type must be a delegate type"); Contracts.CheckParam(meth.ReturnType == typeof(void), nameof(sigTypes), "LoadableClass signature type must be a delegate type with void return"); - - var parms = meth.GetParameters(); - int itypeBase = 0; - - if (argType != null) - { - types = new Type[1 + parms.Length]; - types[itypeBase++] = argType; - } - else if (parms.Length > 0) - types = new Type[parms.Length]; - else - types = Type.EmptyTypes; - + parms = meth.GetParameters(); + Contracts.CheckParam(parms.Length + itypeBase == types.Length, nameof(sigTypes), "LoadableClass signatures must have the same number of parameters"); for (int itype = 0; itype < parms.Length; itype++) { var parm = parms[itype]; if ((parm.Attributes & (ParameterAttributes.Out | ParameterAttributes.Retval)) != 0) - throw Contracts.Except("Invalid signature parameter attributes"); - types[itypeBase + itype] = parm.ParameterType; + throw Contracts.ExceptParam(nameof(sigTypes), "Invalid signature parameter attributes"); + Contracts.CheckParam(types[itypeBase + itype] == parm.ParameterType, nameof(sigTypes), + "LoadableClass signatures must have the same set of parameters"); } - - for (int i = 1; i < sigTypes.Length; i++) - { - sigType = sigTypes[i]; - Contracts.CheckValue(sigType, nameof(sigTypes)); - - Contracts.Check(sigType.BaseType == typeof(System.MulticastDelegate), "LoadableClass signature type must be a delegate type"); - - meth = sigType.GetMethod("Invoke"); - Contracts.CheckParam(meth != null, nameof(sigTypes), "LoadableClass signature type must be a delegate type"); - Contracts.CheckParam(meth.ReturnType == typeof(void), nameof(sigTypes), "LoadableClass signature type must be a delegate type with void return"); - parms = meth.GetParameters(); - Contracts.CheckParam(parms.Length + itypeBase == types.Length, nameof(sigTypes), "LoadableClass signatures must have the same number of parameters"); - for (int itype = 0; itype < parms.Length; itype++) - { - var parm = parms[itype]; - if ((parm.Attributes & (ParameterAttributes.Out | ParameterAttributes.Retval)) != 0) - throw Contracts.ExceptParam(nameof(sigTypes), "Invalid signature parameter attributes"); - Contracts.CheckParam(types[itypeBase + itype] == parm.ParameterType, nameof(sigTypes), - "LoadableClass signatures must have the same set of parameters"); - } - } - - InstanceType = instType; - LoaderType = loaderType; - ArgType = argType; - SigTypes = sigTypes; - CtorTypes = types; - Summary = summary; - UserName = userName; - LoadNames = loadNames; } + + InstanceType = instType; + LoaderType = loaderType; + ArgType = argType; + SigTypes = sigTypes; + CtorTypes = types; + Summary = summary; + UserName = userName; + LoadNames = loadNames; } }