|
9 | 9 | using Microsoft.ML.Internal.Utilities; |
10 | 10 | using Microsoft.ML.Runtime; |
11 | 11 |
|
12 | | -namespace Microsoft.ML.EntryPoints |
| 12 | +namespace Microsoft.ML.EntryPoints; |
| 13 | + |
| 14 | +[BestFriend] |
| 15 | +internal static class EntryPointUtils |
13 | 16 | { |
14 | | - [BestFriend] |
15 | | - internal static class EntryPointUtils |
| 17 | + private static readonly FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool> _isValueWithinRangeMethodInfo |
| 18 | + = new FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool>(IsValueWithinRange<int>); |
| 19 | + |
| 20 | + private static bool IsValueWithinRange<T>(TlcModule.RangeAttribute range, object obj) |
16 | 21 | { |
17 | | - private static readonly FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool> _isValueWithinRangeMethodInfo |
18 | | - = new FuncStaticMethodInfo1<TlcModule.RangeAttribute, object, bool>(IsValueWithinRange<int>); |
| 22 | + T val; |
| 23 | + if (obj is Optional<T> asOptional) |
| 24 | + val = asOptional.Value; |
| 25 | + else |
| 26 | + val = (T)obj; |
19 | 27 |
|
20 | | - private static bool IsValueWithinRange<T>(TlcModule.RangeAttribute range, object obj) |
21 | | - { |
22 | | - T val; |
23 | | - if (obj is Optional<T> asOptional) |
24 | | - val = asOptional.Value; |
25 | | - else |
26 | | - val = (T)obj; |
27 | | - |
28 | | - return |
29 | | - (range.Min == null || ((IComparable)range.Min).CompareTo(val) <= 0) && |
30 | | - (range.Inf == null || ((IComparable)range.Inf).CompareTo(val) < 0) && |
31 | | - (range.Max == null || ((IComparable)range.Max).CompareTo(val) >= 0) && |
32 | | - (range.Sup == null || ((IComparable)range.Sup).CompareTo(val) > 0); |
33 | | - } |
| 28 | + return |
| 29 | + (range.Min == null || ((IComparable)range.Min).CompareTo(val) <= 0) && |
| 30 | + (range.Inf == null || ((IComparable)range.Inf).CompareTo(val) < 0) && |
| 31 | + (range.Max == null || ((IComparable)range.Max).CompareTo(val) >= 0) && |
| 32 | + (range.Sup == null || ((IComparable)range.Sup).CompareTo(val) > 0); |
| 33 | + } |
34 | 34 |
|
35 | | - public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, object val) |
36 | | - { |
37 | | - Contracts.AssertValue(range); |
38 | | - Contracts.AssertValue(val); |
39 | | - // Avoid trying to cast double as float. If range |
40 | | - // was specified using floats, but value being checked |
41 | | - // is double, change range to be of type double |
42 | | - if (range.Type == typeof(float) && val is double) |
43 | | - range.CastToDouble(); |
44 | | - return Utils.MarshalInvoke(_isValueWithinRangeMethodInfo, range.Type, range, val); |
45 | | - } |
| 35 | + public static bool IsValueWithinRange(this TlcModule.RangeAttribute range, object val) |
| 36 | + { |
| 37 | + Contracts.AssertValue(range); |
| 38 | + Contracts.AssertValue(val); |
| 39 | + // Avoid trying to cast double as float. If range |
| 40 | + // was specified using floats, but value being checked |
| 41 | + // is double, change range to be of type double |
| 42 | + if (range.Type == typeof(float) && val is double) |
| 43 | + range.CastToDouble(); |
| 44 | + return Utils.MarshalInvoke(_isValueWithinRangeMethodInfo, range.Type, range, val); |
| 45 | + } |
46 | 46 |
|
47 | | - /// <summary> |
48 | | - /// Performs checks on an EntryPoint input class equivalent to the checks that are done |
49 | | - /// when parsing a JSON EntryPoint graph. |
50 | | - /// |
51 | | - /// Call this method from EntryPoint methods to ensure that range and required checks are performed |
52 | | - /// in a consistent manner when EntryPoints are created directly from code. |
53 | | - /// </summary> |
54 | | - public static void CheckInputArgs(IExceptionContext ectx, object args) |
| 47 | + /// <summary> |
| 48 | + /// Performs checks on an EntryPoint input class equivalent to the checks that are done |
| 49 | + /// when parsing a JSON EntryPoint graph. |
| 50 | + /// |
| 51 | + /// Call this method from EntryPoint methods to ensure that range and required checks are performed |
| 52 | + /// in a consistent manner when EntryPoints are created directly from code. |
| 53 | + /// </summary> |
| 54 | + public static void CheckInputArgs(IExceptionContext ectx, object args) |
| 55 | + { |
| 56 | + foreach (var fieldInfo in args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)) |
55 | 57 | { |
56 | | - foreach (var fieldInfo in args.GetType().GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)) |
57 | | - { |
58 | | - var attr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault() |
59 | | - as ArgumentAttribute; |
60 | | - if (attr == null || attr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly) |
61 | | - continue; |
62 | | - |
63 | | - var fieldVal = fieldInfo.GetValue(args); |
64 | | - var fieldType = fieldInfo.FieldType; |
65 | | - |
66 | | - // Optionals are either left in their Implicit constructed state or |
67 | | - // a new Explicit optional is constructed. They should never be set |
68 | | - // to null. |
69 | | - if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>) && fieldVal == null) |
70 | | - throw ectx.Except("Field '{0}' is Optional<> and set to null instead of an explicit value.", fieldInfo.Name); |
71 | | - |
72 | | - if (attr.IsRequired) |
73 | | - { |
74 | | - bool equalToDefault; |
75 | | - if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>)) |
76 | | - equalToDefault = !((Optional)fieldVal).IsExplicit; |
77 | | - else |
78 | | - equalToDefault = fieldType.IsValueType ? Activator.CreateInstance(fieldType).Equals(fieldVal) : fieldVal == null; |
79 | | - |
80 | | - if (equalToDefault) |
81 | | - throw ectx.Except("Field '{0}' is required but is not set.", fieldInfo.Name); |
82 | | - } |
83 | | - |
84 | | - var rangeAttr = fieldInfo.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault() |
85 | | - as TlcModule.RangeAttribute; |
86 | | - if (rangeAttr != null && fieldVal != null && !rangeAttr.IsValueWithinRange(fieldVal)) |
87 | | - throw ectx.Except("Field '{0}' is set to a value that falls outside the range bounds.", fieldInfo.Name); |
88 | | - } |
89 | | - } |
| 58 | + var attr = fieldInfo.GetCustomAttributes(typeof(ArgumentAttribute), false).FirstOrDefault() |
| 59 | + as ArgumentAttribute; |
| 60 | + if (attr == null || attr.Visibility == ArgumentAttribute.VisibilityType.CmdLineOnly) |
| 61 | + continue; |
90 | 62 |
|
91 | | - public static IHost CheckArgsAndCreateHost(IHostEnvironment env, string hostName, object input) |
92 | | - { |
93 | | - Contracts.CheckValue(env, nameof(env)); |
94 | | - var host = env.Register(hostName); |
95 | | - host.CheckValue(input, nameof(input)); |
96 | | - CheckInputArgs(host, input); |
97 | | - return host; |
98 | | - } |
| 63 | + var fieldVal = fieldInfo.GetValue(args); |
| 64 | + var fieldType = fieldInfo.FieldType; |
99 | 65 |
|
100 | | - /// <summary> |
101 | | - /// Searches for the given column name in the schema. This method applies a |
102 | | - /// common policy that throws an exception if the column is not found |
103 | | - /// and the column name was explicitly specified. If the column is not found |
104 | | - /// and the column name was not explicitly specified, it returns null. |
105 | | - /// </summary> |
106 | | - public static string FindColumnOrNull(IExceptionContext ectx, DataViewSchema schema, Optional<string> value) |
107 | | - { |
108 | | - Contracts.CheckValueOrNull(ectx); |
109 | | - ectx.CheckValue(schema, nameof(schema)); |
110 | | - ectx.CheckValue(value, nameof(value)); |
| 66 | + // Optionals are either left in their Implicit constructed state or |
| 67 | + // a new Explicit optional is constructed. They should never be set |
| 68 | + // to null. |
| 69 | + if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>) && fieldVal == null) |
| 70 | + throw ectx.Except("Field '{0}' is Optional<> and set to null instead of an explicit value.", fieldInfo.Name); |
111 | 71 |
|
112 | | - if (value == "") |
113 | | - return null; |
114 | | - if (schema.GetColumnOrNull(value) == null) |
| 72 | + if (attr.IsRequired) |
115 | 73 | { |
116 | | - if (value.IsExplicit) |
117 | | - throw ectx.Except("Column '{0}' not found", value); |
118 | | - return null; |
| 74 | + bool equalToDefault; |
| 75 | + if (fieldType.IsGenericType && fieldType.GetGenericTypeDefinition() == typeof(Optional<>)) |
| 76 | + equalToDefault = !((Optional)fieldVal).IsExplicit; |
| 77 | + else |
| 78 | + equalToDefault = fieldType.IsValueType ? Activator.CreateInstance(fieldType).Equals(fieldVal) : fieldVal == null; |
| 79 | + |
| 80 | + if (equalToDefault) |
| 81 | + throw ectx.Except("Field '{0}' is required but is not set.", fieldInfo.Name); |
119 | 82 | } |
120 | | - return value; |
| 83 | + |
| 84 | + var rangeAttr = fieldInfo.GetCustomAttributes(typeof(TlcModule.RangeAttribute), false).FirstOrDefault() |
| 85 | + as TlcModule.RangeAttribute; |
| 86 | + if (rangeAttr != null && fieldVal != null && !rangeAttr.IsValueWithinRange(fieldVal)) |
| 87 | + throw ectx.Except("Field '{0}' is set to a value that falls outside the range bounds.", fieldInfo.Name); |
121 | 88 | } |
| 89 | + } |
122 | 90 |
|
123 | | - /// <summary> |
124 | | - /// Converts EntryPoint Optional{T} types into nullable types, with the |
125 | | - /// implicit value being converted to the null value. |
126 | | - /// </summary> |
127 | | - public static T? AsNullable<T>(this Optional<T> opt) where T : struct |
| 91 | + public static IHost CheckArgsAndCreateHost(IHostEnvironment env, string hostName, object input) |
| 92 | + { |
| 93 | + Contracts.CheckValue(env, nameof(env)); |
| 94 | + var host = env.Register(hostName); |
| 95 | + host.CheckValue(input, nameof(input)); |
| 96 | + CheckInputArgs(host, input); |
| 97 | + return host; |
| 98 | + } |
| 99 | + |
| 100 | + /// <summary> |
| 101 | + /// Searches for the given column name in the schema. This method applies a |
| 102 | + /// common policy that throws an exception if the column is not found |
| 103 | + /// and the column name was explicitly specified. If the column is not found |
| 104 | + /// and the column name was not explicitly specified, it returns null. |
| 105 | + /// </summary> |
| 106 | + public static string FindColumnOrNull(IExceptionContext ectx, DataViewSchema schema, Optional<string> value) |
| 107 | + { |
| 108 | + Contracts.CheckValueOrNull(ectx); |
| 109 | + ectx.CheckValue(schema, nameof(schema)); |
| 110 | + ectx.CheckValue(value, nameof(value)); |
| 111 | + |
| 112 | + if (value == "") |
| 113 | + return null; |
| 114 | + if (schema.GetColumnOrNull(value) == null) |
128 | 115 | { |
129 | | - if (opt.IsExplicit) |
130 | | - return opt.Value; |
131 | | - else |
132 | | - return null; |
| 116 | + if (value.IsExplicit) |
| 117 | + throw ectx.Except("Column '{0}' not found", value); |
| 118 | + return null; |
133 | 119 | } |
| 120 | + return value; |
| 121 | + } |
| 122 | + |
| 123 | + /// <summary> |
| 124 | + /// Converts EntryPoint Optional{T} types into nullable types, with the |
| 125 | + /// implicit value being converted to the null value. |
| 126 | + /// </summary> |
| 127 | + public static T? AsNullable<T>(this Optional<T> opt) where T : struct |
| 128 | + { |
| 129 | + if (opt.IsExplicit) |
| 130 | + return opt.Value; |
| 131 | + else |
| 132 | + return null; |
134 | 133 | } |
135 | 134 | } |
0 commit comments