Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug that whether an Orchestrator object can extract entities or n… #5749

Merged
merged 7 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ public class OrchestratorRecognizer : AdaptiveRecognizer
/// </summary>
public const string EntitiesProperty = "entityResult";
private const float UnknownIntentFilterScore = 0.4F;
private static ConcurrentDictionary<string, BotFramework.Orchestrator.Orchestrator> orchestratorMap = new ConcurrentDictionary<string, BotFramework.Orchestrator.Orchestrator>();
private string _modelFolder;
private string _snapshotFile;
private static ConcurrentDictionary<string, OrchestratorDictionaryEntry> orchestratorMap = new ConcurrentDictionary<string, OrchestratorDictionaryEntry>();
private OrchestratorDictionaryEntry _orchestrator = null;
private ILabelResolver _resolver = null;
private bool _isResolverMockup = false;

/// <summary>
/// Initializes a new instance of the <see cref="OrchestratorRecognizer"/> class.
Expand All @@ -63,30 +63,12 @@ public OrchestratorRecognizer([CallerFilePath] string callerPath = "", [CallerLi
/// </summary>
/// <param name="modelFolder">Specifies the base model folder.</param>
/// <param name="snapshotFile">Specifies full path to the snapshot file.</param>
/// <param name="resolver">Label resolver.</param>
public OrchestratorRecognizer(string modelFolder, string snapshotFile, ILabelResolver resolver = null)
/// <param name="resolverExternal">External label resolver object.</param>
public OrchestratorRecognizer(string modelFolder, string snapshotFile, ILabelResolver resolverExternal = null)
{
_resolver = resolver;
if (modelFolder == null)
{
throw new ArgumentNullException(nameof(modelFolder));
}

if (snapshotFile == null)
{
throw new ArgumentNullException(nameof(snapshotFile));
}

_modelFolder = modelFolder;
_snapshotFile = snapshotFile;
InitializeModel();
InitializeModel(modelFolder, snapshotFile, resolverExternal);
}

[JsonIgnore]
hcyang marked this conversation as resolved.
Show resolved Hide resolved
#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member
public bool ScoreEntities { get; set; } = false;
#pragma warning restore CS1591 // Missing XML comment for publicly visible type or member

/// <summary>
/// Gets or sets the folder path to Orchestrator base model to use.
/// </summary>
Expand Down Expand Up @@ -131,6 +113,15 @@ public OrchestratorRecognizer(string modelFolder, string snapshotFile, ILabelRes
[JsonProperty("detectAmbiguousIntents")]
public BoolExpression DetectAmbiguousIntents { get; set; } = false;

/// <summary>
/// Gets or sets a value indicating whether to enable or disable entity-extraction logic.
/// NOTE: SHOULD consider removing this flag in the next major SDK release (V5).
/// </summary>
/// <value>
/// The flag for enabling or disabling entity-extraction function.
/// </value>
public bool ScoreEntities { get; set; } = true;

/// <summary>
/// Return recognition results.
/// </summary>
Expand All @@ -142,14 +133,16 @@ public OrchestratorRecognizer(string modelFolder, string snapshotFile, ILabelRes
/// <returns>A <see cref="RecognizerResult"/> containing the QnA Maker result.</returns>
public override async Task<RecognizerResult> RecognizeAsync(DialogContext dc, Schema.Activity activity, CancellationToken cancellationToken, Dictionary<string, string> telemetryProperties = null, Dictionary<string, double> telemetryMetrics = null)
{
if (_resolver == null)
{
string modelFolder = ModelFolder.GetValue(dc.State);
string snapshotFile = SnapshotFile.GetValue(dc.State);
InitializeModel(modelFolder, snapshotFile, null);
}

var text = activity.Text ?? string.Empty;
var detectAmbiguity = DetectAmbiguousIntents.GetValue(dc.State);

_modelFolder = ModelFolder.GetValue(dc.State);
_snapshotFile = SnapshotFile.GetValue(dc.State);

InitializeModel();

var recognizerResult = new RecognizerResult()
{
Text = text,
Expand All @@ -165,11 +158,11 @@ public override async Task<RecognizerResult> RecognizeAsync(DialogContext dc, Sc
// Score with orchestrator
var results = _resolver.Score(text);

// Add full recognition result as a 'result' property
recognizerResult.Properties.Add(ResultProperty, results);

if (results.Any())
if ((results != null) && results.Any())
{
// Add full recognition result as a 'result' property
recognizerResult.Properties.Add(ResultProperty, results);

var topScore = results[0].Score;

// if top scoring intent is less than threshold, return None
Expand Down Expand Up @@ -230,7 +223,7 @@ public override async Task<RecognizerResult> RecognizeAsync(DialogContext dc, Sc
// Return 'None' if no intent matched.
recognizerResult.Intents.Add(NoneIntent, new IntentScore() { Score = 1.0 });
}

if (ExternalEntityRecognizer != null)
{
// Run external recognition
Expand All @@ -239,7 +232,7 @@ public override async Task<RecognizerResult> RecognizeAsync(DialogContext dc, Sc
}

TryScoreEntities(text, recognizerResult);

// Add full recognition result as a 'result' property
await dc.Context.TraceActivityAsync($"{nameof(OrchestratorRecognizer)}Result", JObject.FromObject(recognizerResult), nameof(OrchestratorRecognizer), "Orchestrator Recognition", cancellationToken).ConfigureAwait(false);
TrackRecognizerResult(dc, $"{nameof(OrchestratorRecognizer)}Result", FillRecognizerResultTelemetryProperties(recognizerResult, telemetryProperties, dc), telemetryMetrics);
Expand Down Expand Up @@ -317,107 +310,189 @@ private static JToken EntityResultToInstanceJObject(string text, Result result)

private void TryScoreEntities(string text, RecognizerResult recognizerResult)
{
if (!this.ScoreEntities)
// It's impossible to extract entities without a _resolver object.
if (_resolver == null)
{
return;
}

var results = _resolver.Score(text, LabelType.Entity);
recognizerResult.Properties.Add(EntitiesProperty, results);
// Entity extraction can be controlled by the ScoreEntities flag.
// NOTE: SHOULD consider removing this flag in the next major SDK release (V5).
if (!this.ScoreEntities)
{
return;
}

if (results.Any())
// The following check is necessary to ensure that the _resolver object
// is capable of entity exttraction. However, this check can also block
// a mock-up _resolver.
if (!_isResolverMockup)
{
if (recognizerResult.Entities == null)
if ((_orchestrator == null) || (!_orchestrator.IsEntityExtractionCapable))
{
recognizerResult.Entities = new JObject();
return;
}
}

var entitiesResult = recognizerResult.Entities;
foreach (var result in results)
{
// add value
JToken values;
if (!entitiesResult.TryGetValue(result.Label.Name, StringComparison.OrdinalIgnoreCase, out values))
{
values = new JArray();
entitiesResult[result.Label.Name] = values;
}
// As this method is TryScoreEntities, so it's best effort only, there should
// not be any exception thrown out of this method.
try
{
var results = _resolver.Score(text, LabelType.Entity);

((JArray)values).Add(EntityResultToJObject(text, result));
if ((results != null) && results.Any())
{
recognizerResult.Properties.Add(EntitiesProperty, results);

// get/create $instance
JToken instanceRoot;
if (!recognizerResult.Entities.TryGetValue("$instance", StringComparison.OrdinalIgnoreCase, out instanceRoot))
if (recognizerResult.Entities == null)
{
instanceRoot = new JObject();
recognizerResult.Entities["$instance"] = instanceRoot;
recognizerResult.Entities = new JObject();
}

// add instanceData
JToken instanceData;
if (!((JObject)instanceRoot).TryGetValue(result.Label.Name, StringComparison.OrdinalIgnoreCase, out instanceData))
var entitiesResult = recognizerResult.Entities;
foreach (var result in results)
{
instanceData = new JArray();
instanceRoot[result.Label.Name] = instanceData;
}
// add value
JToken values;
if (!entitiesResult.TryGetValue(result.Label.Name, StringComparison.OrdinalIgnoreCase, out values))
{
values = new JArray();
entitiesResult[result.Label.Name] = values;
}

((JArray)instanceData).Add(EntityResultToInstanceJObject(text, result));
// values came from an external entity recognizer, which may not make it a JArray.
if (values.Type != JTokenType.Array)
{
values = new JArray();
}

((JArray)values).Add(EntityResultToJObject(text, result));
hcyang marked this conversation as resolved.
Show resolved Hide resolved

// get/create $instance
JToken instanceRoot;
if (!recognizerResult.Entities.TryGetValue("$instance", StringComparison.OrdinalIgnoreCase, out instanceRoot))
hcyang marked this conversation as resolved.
Show resolved Hide resolved
{
instanceRoot = new JObject();
recognizerResult.Entities["$instance"] = instanceRoot;
}

// instanceRoot came from an external entity recognizer, which may not make it a JObject.
if (instanceRoot.Type != JTokenType.Object)
{
instanceRoot = new JObject();
}

// add instanceData
JToken instanceData;
if (!((JObject)instanceRoot).TryGetValue(result.Label.Name, StringComparison.OrdinalIgnoreCase, out instanceData))
{
instanceData = new JArray();
instanceRoot[result.Label.Name] = instanceData;
}

// instanceData came from an external entity recognizer, which may not make it a JArray.
if (instanceData.Type != JTokenType.Array)
{
instanceData = new JArray();
}

((JArray)instanceData).Add(EntityResultToInstanceJObject(text, result));
tsuwandy marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
catch (ApplicationException)
{
return; // ---- This is a "Try" function, i.e., best effort only, no exception.
}
}

private void InitializeModel()
[MethodImpl(MethodImplOptions.Synchronized)]
private void InitializeModel(string modelFolder, string snapshotFile, ILabelResolver resolverExternal = null)
{
if (_modelFolder == null)
if (resolverExternal != null)
{
#pragma warning disable CA2208 // Instantiate argument exceptions correctly
throw new ArgumentNullException("ModelFolder");
#pragma warning restore CA2208 // Instantiate argument exceptions correctly
_resolver = resolverExternal;
_isResolverMockup = true;
return;
}

if (_snapshotFile == null)
{
#pragma warning disable CA2208 // Instantiate argument exceptions correctly
throw new ArgumentNullException("SnapshotFile");
#pragma warning restore CA2208 // Instantiate argument exceptions correctly
}
if (string.IsNullOrWhiteSpace(modelFolder))
{
throw new ArgumentNullException(nameof(modelFolder));
}

if (_resolver != null)
{
return;
if (string.IsNullOrWhiteSpace(snapshotFile))
{
throw new ArgumentNullException(nameof(snapshotFile));
}
}

var fullModelFolder = Path.GetFullPath(PathUtils.NormalizePath(_modelFolder));
var fullModelFolder = Path.GetFullPath(PathUtils.NormalizePath(modelFolder));

var orchestrator = orchestratorMap.GetOrAdd(fullModelFolder, path =>
_orchestrator = orchestratorMap.GetOrAdd(fullModelFolder, path =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the latency of this method on moderate sized models?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The latency for that is just seconds. Loading the models are quick even though they can be up to 1GB. The actual time consuming part is building the resolver consuming some big snapshot file. Perhaps that's the reason the InitializeModel() function is called once inside RecognizeAsync(), which is async.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the latency of the entire initializeModel?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Asking because 1) this will get run on the first turn and not when the app service starts. This means that the first turn of the first conversation will be really slow, i.e. as long as this method takes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we'd want to transfer this time to application load somehow in the future. We don't need to resolve this now, but it is importnat. I'll create an issue for this.

{
// Create Orchestrator
string entityModelFolder = null;
string entityModelFolder = null;
bool isEntityExtractionCapable = false;
try
{
entityModelFolder = Path.Combine(path, "entity");
ScoreEntities = Directory.Exists(entityModelFolder);
isEntityExtractionCapable = Directory.Exists(entityModelFolder);

return ScoreEntities ?
new BotFramework.Orchestrator.Orchestrator(path, entityModelFolder) :
new BotFramework.Orchestrator.Orchestrator(path);
return new OrchestratorDictionaryEntry()
{
Orchestrator = isEntityExtractionCapable ?
new BotFramework.Orchestrator.Orchestrator(path, entityModelFolder) :
new BotFramework.Orchestrator.Orchestrator(path),
IsEntityExtractionCapable = isEntityExtractionCapable
};
}
catch (Exception ex)
{
throw new InvalidOperationException(
ScoreEntities ? $"Failed to find or load Model with path {path}, entity model path {entityModelFolder}" : $"Failed to find or load Model with path {path}",
isEntityExtractionCapable ? $"Failed to find or load Model with path {path}, entity model path {entityModelFolder}" : $"Failed to find or load Model with path {path}",
ex);
}
});

var fullSnapShotFile = Path.GetFullPath(PathUtils.NormalizePath(_snapshotFile));
var fullSnapShotFile = Path.GetFullPath(PathUtils.NormalizePath(snapshotFile));

// Load the snapshot
string content = File.ReadAllText(fullSnapShotFile);
byte[] snapShotByteArray = Encoding.UTF8.GetBytes(content);
byte[] snapShotByteArray = File.ReadAllBytes(fullSnapShotFile);

// Create label resolver
_resolver = orchestrator.CreateLabelResolver(snapShotByteArray);
_resolver = this._orchestrator.Orchestrator.CreateLabelResolver(snapShotByteArray);
}

/// <summary>
/// OrchestratorDictionaryEntry is used for the static orchestratorMap object.
/// </summary>
private class OrchestratorDictionaryEntry
{
/// <summary>
/// Gets or sets the Orchestrator object.
/// </summary>
/// <value>
/// The Orchestrator object.
/// </value>
public BotFramework.Orchestrator.Orchestrator Orchestrator
{
get;
set;
}

/// <summary>
/// Gets or sets a value indicating whether the Orchestrator object is capable of entity extraction.
/// </summary>
/// <value>
/// The IsEntityExtractionCapable flag.
/// </value>
public bool IsEntityExtractionCapable
{
get;
set;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ public async Task TestEntityRecognize()
{
ModelFolder = new StringExpression("fakePath"),
SnapshotFile = new StringExpression("fakePath"),
ScoreEntities = true,
ExternalEntityRecognizer = new NumberEntityRecognizer()
};

Expand Down