Skip to content

Commit a627d5b

Browse files
authored
Create API for extracting information about the nodes in a TensorFlow model (#862)
* Add a method that returns TensorFlow model outputs as an ISchema. * Update after merge with master * Address PR comments. * Add metadata with information about the operation type, and the inputs needed for it. * Add method that returns an enumerable of the information about graph nodes, and a console app that displays it * Add the DnnAnalyzer project files. * Address code review comments * Make needed changes after merge with master * Fix bug when there is a node with 1 dimension that is unknown
1 parent 86f4d93 commit a627d5b

File tree

8 files changed

+476
-79
lines changed

8 files changed

+476
-79
lines changed

Microsoft.ML.sln

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Analyzer", "sr
115115
EndProject
116116
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StaticPipelineTesting", "test\Microsoft.ML.StaticPipelineTesting\Microsoft.ML.StaticPipelineTesting.csproj", "{8B38BF24-35F4-4787-A9C5-22D35987106E}"
117117
EndProject
118+
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.DnnAnalyzer", "src\Microsoft.ML.DnnAnalyzer\Microsoft.ML.DnnAnalyzer\Microsoft.ML.DnnAnalyzer.csproj", "{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}"
119+
EndProject
118120
Global
119121
GlobalSection(SolutionConfigurationPlatforms) = preSolution
120122
Debug|Any CPU = Debug|Any CPU
@@ -419,6 +421,14 @@ Global
419421
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release|Any CPU.Build.0 = Release|Any CPU
420422
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
421423
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
424+
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
425+
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug|Any CPU.Build.0 = Debug|Any CPU
426+
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
427+
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
428+
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release|Any CPU.ActiveCfg = Release|Any CPU
429+
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release|Any CPU.Build.0 = Release|Any CPU
430+
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
431+
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
422432
EndGlobalSection
423433
GlobalSection(SolutionProperties) = preSolution
424434
HideSolutionNode = FALSE
@@ -466,6 +476,7 @@ Global
466476
{570A0B8A-5463-44D2-8521-54C0CA4CACA9} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
467477
{6DEF0F40-3853-47B3-8165-5F24BA5E14DF} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
468478
{8B38BF24-35F4-4787-A9C5-22D35987106E} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
479+
{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
469480
EndGlobalSection
470481
GlobalSection(ExtensibilityGlobals) = postSolution
471482
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}

src/Microsoft.ML.Data/DataView/SimpleRow.cs

Lines changed: 82 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -64,97 +64,135 @@ public bool IsColumnActive(int col)
6464
/// An <see cref="ISchema"/> that takes all column names and types as constructor parameters.
6565
/// The columns do not have metadata.
6666
/// </summary>
67-
public sealed class SimpleSchema : ISchema
67+
public abstract class SimpleSchemaBase : ISchema
6868
{
69-
private readonly IExceptionContext _ectx;
69+
protected readonly IExceptionContext Ectx;
7070
private readonly string[] _names;
71-
private readonly ColumnType[] _types;
72-
private readonly Dictionary<string, int> _columnNameMap;
73-
private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[] _keyValueGetters;
71+
protected readonly ColumnType[] Types;
72+
protected readonly Dictionary<string, int> ColumnNameMap;
7473

75-
public int ColumnCount => _types.Length;
74+
public int ColumnCount => Types.Length;
7675

77-
public SimpleSchema(IExceptionContext ectx, params KeyValuePair<string, ColumnType>[] columns)
76+
protected SimpleSchemaBase(IExceptionContext ectx, params KeyValuePair<string, ColumnType>[] columns)
7877
{
7978
Contracts.CheckValueOrNull(ectx);
80-
_ectx = ectx;
81-
_ectx.CheckValue(columns, nameof(columns));
79+
Ectx = ectx;
80+
Ectx.CheckValue(columns, nameof(columns));
8281

8382
_names = new string[columns.Length];
84-
_types = new ColumnType[columns.Length];
85-
_columnNameMap = new Dictionary<string, int>();
83+
Types = new ColumnType[columns.Length];
84+
ColumnNameMap = new Dictionary<string, int>();
8685
for (int i = 0; i < columns.Length; i++)
8786
{
8887
_names[i] = columns[i].Key;
89-
_types[i] = columns[i].Value;
90-
if (_columnNameMap.ContainsKey(columns[i].Key))
88+
Types[i] = columns[i].Value;
89+
if (ColumnNameMap.ContainsKey(columns[i].Key))
9190
throw ectx.ExceptParam(nameof(columns), $"Duplicate column name: '{columns[i].Key}'");
92-
_columnNameMap[columns[i].Key] = i;
93-
}
94-
_keyValueGetters = new MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[ColumnCount];
95-
}
96-
97-
public SimpleSchema(IExceptionContext ectx, KeyValuePair<string, ColumnType>[] columns, Dictionary<string, MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>> keyValues)
98-
: this(ectx, columns)
99-
{
100-
foreach (var kvp in keyValues)
101-
{
102-
var name = kvp.Key;
103-
var getter = kvp.Value;
104-
if (!_columnNameMap.TryGetValue(name, out int col))
105-
throw _ectx.ExceptParam(nameof(keyValues), $"Output schema does not contain column '{name}'");
106-
if (!_types[col].ItemType.IsKey)
107-
throw _ectx.ExceptParam(nameof(keyValues), $"Column '{name}' is not a key column, so it cannot have key value metadata");
108-
_keyValueGetters[col] = getter;
91+
ColumnNameMap[columns[i].Key] = i;
10992
}
11093
}
11194

11295
public bool TryGetColumnIndex(string name, out int col)
11396
{
114-
return _columnNameMap.TryGetValue(name, out col);
97+
return ColumnNameMap.TryGetValue(name, out col);
11598
}
11699

117100
public string GetColumnName(int col)
118101
{
119-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
102+
Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
120103
return _names[col];
121104
}
122105

123106
public ColumnType GetColumnType(int col)
124107
{
125-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
126-
return _types[col];
108+
Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
109+
return Types[col];
127110
}
128111

129112
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
130113
{
131-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
114+
Ectx.Assert(0 <= col && col < ColumnCount);
115+
return GetMetadataTypesCore(col);
116+
}
117+
118+
protected abstract IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypesCore(int col);
119+
120+
public ColumnType GetMetadataTypeOrNull(string kind, int col)
121+
{
122+
Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
123+
return GetMetadataTypeOrNullCore(kind, col);
124+
}
125+
126+
protected abstract ColumnType GetMetadataTypeOrNullCore(string kind, int col);
127+
128+
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
129+
{
130+
Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
131+
GetMetadataCore(kind, col, ref value);
132+
}
133+
134+
protected abstract void GetMetadataCore<TValue>(string kind, int col, ref TValue value);
135+
}
136+
137+
/// <summary>
138+
/// An <see cref="ISchema"/> that takes all column names and types as constructor parameters.
139+
/// The columns can optionally have text <see cref="MetadataUtils.Kinds.KeyValues"/> metadata.
140+
/// </summary>
141+
public sealed class SimpleSchema : SimpleSchemaBase
142+
{
143+
private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[] _keyValueGetters;
144+
145+
public SimpleSchema(IExceptionContext ectx, params KeyValuePair<string, ColumnType>[] columns)
146+
: base(ectx, columns)
147+
{
148+
_keyValueGetters = new MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[ColumnCount];
149+
}
150+
151+
public SimpleSchema(IExceptionContext ectx, KeyValuePair<string, ColumnType>[] columns,
152+
Dictionary<string, MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>> keyValues)
153+
: this(ectx, columns)
154+
{
155+
foreach (var kvp in keyValues)
156+
{
157+
var name = kvp.Key;
158+
var getter = kvp.Value;
159+
if (!ColumnNameMap.TryGetValue(name, out int col))
160+
throw Ectx.ExceptParam(nameof(keyValues), $"Output schema does not contain column '{name}'");
161+
if (!Types[col].ItemType.IsKey)
162+
throw Ectx.ExceptParam(nameof(keyValues), $"Column '{name}' is not a key column, so it cannot have key value metadata");
163+
_keyValueGetters[col] = getter;
164+
}
165+
}
166+
167+
protected override IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypesCore(int col)
168+
{
169+
Ectx.Assert(0 <= col && col < ColumnCount);
132170
if (_keyValueGetters[col] != null)
133171
{
134-
_ectx.Assert(_types[col].ItemType.IsKey);
172+
Ectx.Assert(Types[col].ItemType.IsKey);
135173
yield return new KeyValuePair<string, ColumnType>(MetadataUtils.Kinds.KeyValues,
136-
new VectorType(TextType.Instance, _types[col].ItemType.KeyCount));
174+
new VectorType(TextType.Instance, Types[col].ItemType.KeyCount));
137175
}
138176
}
139177

140-
public ColumnType GetMetadataTypeOrNull(string kind, int col)
178+
protected override ColumnType GetMetadataTypeOrNullCore(string kind, int col)
141179
{
142-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
180+
Ectx.Assert(0 <= col && col < ColumnCount);
143181
if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null)
144182
{
145-
_ectx.Assert(_types[col].ItemType.IsKey);
146-
return new VectorType(TextType.Instance, _types[col].ItemType.KeyCount);
183+
Ectx.Assert(Types[col].ItemType.IsKey);
184+
return new VectorType(TextType.Instance, Types[col].ItemType.KeyCount);
147185
}
148186
return null;
149187
}
150188

151-
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
189+
protected override void GetMetadataCore<TValue>(string kind, int col, ref TValue value)
152190
{
153-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
191+
Ectx.Assert(0 <= col && col < ColumnCount);
154192
if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null)
155193
_keyValueGetters[col].Marshal(col, ref value);
156194
else
157-
throw _ectx.ExceptGetMetadata();
195+
throw Ectx.ExceptGetMetadata();
158196
}
159197
}
160198
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.ML.Runtime;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.Internal.Utilities;
8+
using Microsoft.ML.Transforms.TensorFlow;
9+
using System;
10+
using System.Linq;
11+
12+
namespace Microsoft.ML.DnnAnalyzer
13+
{
14+
public static class DnnAnalyzer
15+
{
16+
public static void Main(string[] args)
17+
{
18+
if (Utils.Size(args) != 1)
19+
{
20+
Console.Error.WriteLine("Usage: dotnet DnnAnalyzer.dll <model_location>");
21+
return;
22+
}
23+
24+
foreach (var (name, opType, type, inputs) in TensorFlowUtils.GetModelNodes(args[0]))
25+
{
26+
var inputsString = inputs.Length == 0 ? "" : $", input nodes: {string.Join(", ", inputs)}";
27+
Console.WriteLine($"Graph node: '{name}', operation type: '{opType}', output type: '{type}'{inputsString}");
28+
}
29+
}
30+
}
31+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<OutputType>Exe</OutputType>
5+
<TargetFramework>netcoreapp2.1</TargetFramework>
6+
<AssemblyName>DnnAnalyzer</AssemblyName>
7+
<IncludeInPackage>Microsoft.ML.TensorFlow</IncludeInPackage>
8+
</PropertyGroup>
9+
10+
<ItemGroup>
11+
<ProjectReference Include="..\..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
12+
<ProjectReference Include="..\..\Microsoft.ML.TensorFlow\Microsoft.ML.TensorFlow.csproj" />
13+
</ItemGroup>
14+
15+
<ItemGroup>
16+
<NativeAssemblyReference Include="tensorflow" />
17+
</ItemGroup>
18+
19+
</Project>

src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs

Lines changed: 71 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
using size_t = System.UIntPtr;
2525
using System.Collections.Generic;
26+
using System.Collections;
2627

2728
#pragma warning disable MSML_GeneralName
2829
#pragma warning disable MSML_PrivateFieldName
@@ -492,7 +493,7 @@ public void SetConfig(IntPtr protoData, int length, TFStatus status = null)
492493
/// "hot", and add a "sub" operation there the result will be "demo/hot/sub".
493494
/// </para>
494495
/// </remarks>
495-
internal partial class TFGraph : TFDisposableThreadSafe
496+
internal partial class TFGraph : TFDisposableThreadSafe, IEnumerable<TFOperation>
496497
{
497498
// extern TF_Graph * TF_NewGraph ();
498499
[DllImport(NativeBinding.TensorFlowLibrary)]
@@ -696,6 +697,33 @@ public override string ToString()
696697
IntPtr len;
697698
return TF_GraphDebugString(Handle, out len);
698699
}
700+
701+
[DllImport(NativeBinding.TensorFlowLibrary)]
702+
private static unsafe extern TF_Operation TF_GraphNextOperation(TF_Graph graph, ref IntPtr pos);
703+
704+
/// <summary>
705+
/// Returns the enumerator that returns all the TFOperations in a graph.
706+
/// </summary>
707+
/// <returns>The enumerator.</returns>
708+
private IEnumerable<TFOperation> GetEnumerable()
709+
{
710+
if (handle == IntPtr.Zero)
711+
ObjectDisposedException();
712+
IntPtr token = IntPtr.Zero;
713+
IntPtr operll;
714+
while ((operll = TF_GraphNextOperation(handle, ref token)) != IntPtr.Zero)
715+
yield return new TFOperation(this, operll);
716+
}
717+
718+
public IEnumerator<TFOperation> GetEnumerator()
719+
{
720+
return GetEnumerable().GetEnumerator();
721+
}
722+
723+
IEnumerator IEnumerable.GetEnumerator()
724+
{
725+
return GetEnumerator();
726+
}
699727
}
700728

701729
/// <summary>
@@ -736,6 +764,48 @@ public TFOutput this[int idx]
736764
return new TFOutput(this, idx);
737765
}
738766
}
767+
768+
// extern TF_Output TF_OperationInput (TF_Input oper_in);
769+
[DllImport(NativeBinding.TensorFlowLibrary)]
770+
private static extern TFOutput TF_OperationInput(TFInput oper_in);
771+
772+
public TFOutput GetInput(int idx)
773+
{
774+
return TF_OperationInput(new TFInput() { Operation = handle, Index = idx });
775+
}
776+
777+
[DllImport(NativeBinding.TensorFlowLibrary)]
778+
private static extern IntPtr TF_OperationName(TF_Operation oper);
779+
780+
/// <summary>
781+
/// The name for this operation/
782+
/// </summary>
783+
/// <value>The name.</value>
784+
public string Name => handle == IntPtr.Zero ? "<ObjectDisposed>" : TF_OperationName(handle).GetStr();
785+
786+
[DllImport(NativeBinding.TensorFlowLibrary)]
787+
private static extern IntPtr TF_OperationOpType(TF_Operation oper);
788+
789+
public string OpType => handle == IntPtr.Zero ? "<ObjectDisposedException>" : TF_OperationOpType(handle).GetStr();
790+
791+
[DllImport(NativeBinding.TensorFlowLibrary)]
792+
private static extern int TF_OperationNumOutputs(TF_Operation oper);
793+
794+
/// <summary>
795+
/// Gets the number of outputs on this operation.
796+
/// </summary>
797+
/// <value>The number outputs.</value>
798+
public int NumOutputs => handle == IntPtr.Zero ? -1 : TF_OperationNumOutputs(handle);
799+
800+
[DllImport(NativeBinding.TensorFlowLibrary)]
801+
private static extern int TF_OperationNumInputs(TF_Operation oper);
802+
803+
/// <summary>
804+
/// Gets the number of inputs for this operation.
805+
/// Import a serialized graph into this graph, using the specified importing options.
806+
/// </summary>
807+
/// <value>The number inputs.</value>
808+
public int NumInputs => TF_OperationNumInputs(handle);
739809
}
740810

741811
/// <summary>
@@ -1768,15 +1838,6 @@ internal struct TFInput
17681838
/// </summary>
17691839
public int Index;
17701840

1771-
// extern TF_Output TF_OperationInput (TF_Input oper_in);
1772-
[DllImport(NativeBinding.TensorFlowLibrary)]
1773-
private static extern TFOutput TF_OperationInput(TFInput oper_in);
1774-
1775-
public TFOutput GetOutput(TFInput operIn)
1776-
{
1777-
return TF_OperationInput(operIn);
1778-
}
1779-
17801841
// extern TF_DataType TF_OperationInputType (TF_Input oper_in);
17811842
[DllImport(NativeBinding.TensorFlowLibrary)]
17821843
private static extern TFDataType TF_OperationInputType(TFInput oper_in);

0 commit comments

Comments
 (0)