Skip to content

Commit

Permalink
Harden user PII protection logic and extend TrainingAnalytics to expo…
Browse files Browse the repository at this point in the history
…se detailed configuration parameters. (#5512)

* Hash128 is not a cryptographic hash, replace with HMAC-SHA256.

* Extend TrainingAnalytics side channel to expose configuration details

* Change member function scopes and hash demo_paths

* Extract tbiEvent hashing method and add test coverage
  • Loading branch information
sini authored and maryamhonari committed Nov 4, 2021
1 parent 64fa713 commit 3fe4e03
Show file tree
Hide file tree
Showing 11 changed files with 249 additions and 35 deletions.
36 changes: 31 additions & 5 deletions com.unity.ml-agents/Runtime/Analytics/AnalyticsUtils.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,45 @@
using System;
using System.Text;
using System.Security.Cryptography;
using UnityEngine;

namespace Unity.MLAgents.Analytics
{

internal static class AnalyticsUtils
{
/// <summary>
/// Conversion function from byte array to hex string.
/// </summary>
/// <param name="array">A byte array to be hex encoded.</param>
/// <returns>
/// A lowercase hexadecimal string with exactly two characters per input byte
/// (an empty string for an empty array).
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="array"/> is null.</exception>
private static string ToHexString(byte[] array)
{
    if (array == null)
    {
        throw new ArgumentNullException(nameof(array));
    }

    // Two hex characters per byte, so the builder never has to grow.
    StringBuilder hex = new StringBuilder(array.Length * 2);
    foreach (byte b in array)
    {
        // "x2" = lowercase hex, zero-padded to two digits. Formatting the byte
        // directly avoids boxing it and re-parsing a composite format string
        // on every iteration.
        hex.Append(b.ToString("x2"));
    }
    return hex.ToString();
}

/// <summary>
/// Hash a string to remove PII or secret info before sending to analytics.
/// </summary>
/// <param name="key">A string containing the key to be used for HMAC encoding.</param>
/// <param name="value">A string containing the value to be encoded.</param>
/// <returns>
/// The HMAC-SHA256 of <paramref name="value"/> keyed with <paramref name="key"/>,
/// rendered as a hex string. Both inputs are encoded as UTF-8 before hashing.
/// </returns>
public static string Hash(string key, string value)
{
    UTF8Encoding encoder = new UTF8Encoding();

    // HMACSHA256 is IDisposable; the using block releases it as soon as the
    // digest has been computed.
    using (HMACSHA256 hmac = new HMACSHA256(encoder.GetBytes(key)))
    {
        byte[] hmBytes = hmac.ComputeHash(encoder.GetBytes(value));
        return ToHexString(hmBytes);
    }
}

internal static bool s_SendEditorAnalytics = true;
Expand Down
2 changes: 2 additions & 0 deletions com.unity.ml-agents/Runtime/Analytics/Events.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ internal struct TrainingEnvironmentInitializedEvent
public string TorchDeviceType;
public int NumEnvironments;
public int NumEnvironmentParameters;
public string RunOptions;
}

[Flags]
Expand Down Expand Up @@ -188,5 +189,6 @@ internal struct TrainingBehaviorInitializedEvent
public string VisualEncoder;
public int NumNetworkLayers;
public int NumNetworkHiddenUnits;
public string Config;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ IList<IActuator> actuators
var inferenceEvent = new InferenceEvent();

// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
inferenceEvent.BehaviorName = AnalyticsUtils.Hash(behaviorName);
inferenceEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, behaviorName);

inferenceEvent.BarracudaModelSource = barracudaModel.IrSource;
inferenceEvent.BarracudaModelVersion = barracudaModel.IrVersion;
Expand Down
20 changes: 16 additions & 4 deletions com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,21 @@ internal static string ParseBehaviorName(string fullyQualifiedBehaviorName)
return fullyQualifiedBehaviorName.Substring(0, lastQuestionIndex);
}

/// <summary>
/// Make sure the behavior name carried by a training event is hashed before the event is reported.
/// </summary>
/// <param name="tbiEvent">The event as received; it is a struct, so the caller's copy is not modified.</param>
/// <returns>
/// A copy of the event whose BehaviorName is either left untouched (when it already looks like a
/// trainer-side hash) or replaced by its keyed hash.
/// </returns>
internal static TrainingBehaviorInitializedEvent SanitizeTrainingBehaviorInitializedEvent(TrainingBehaviorInitializedEvent tbiEvent)
{
    // A hex-encoded SHA256 digest is 32 bytes, i.e. 64 hex characters.
    const int expectedHashLength = 64;

    // String fields of a default-constructed struct are null, so normalise before inspecting them.
    var behaviorName = tbiEvent.BehaviorName ?? string.Empty;

    // Hash the behavior name if the message is from an older version of ml-agents that doesn't do trainer-side hashing.
    // Context: The config field was added at the same time as trainer side hashing, so messages including it should already be hashed.
    // For extra safety, also verify that the BehaviorName has the size of the expected SHA256 hash and
    // consists only of hex digits, so a raw name that merely happens to be 64 characters long is still hashed.
    var alreadyHashed = !string.IsNullOrEmpty(tbiEvent.Config)
        && behaviorName.Length == expectedHashLength
        && IsHexString(behaviorName);

    if (!alreadyHashed)
    {
        tbiEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, behaviorName);
    }

    return tbiEvent;
}

/// <summary>
/// Check whether every character of a string is a hexadecimal digit (either case).
/// </summary>
private static bool IsHexString(string s)
{
    foreach (var c in s)
    {
        var isHexDigit = (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F');
        if (!isHexDigit)
        {
            return false;
        }
    }
    return true;
}

[Conditional("MLA_UNITY_ANALYTICS_MODULE")]
public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent tbiEvent)
public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent rawTbiEvent)
{
#if UNITY_EDITOR && MLA_UNITY_ANALYTICS_MODULE
if (!IsAnalyticsEnabled())
Expand All @@ -202,6 +215,7 @@ public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent
if (!EnableAnalytics())
return;

var tbiEvent = SanitizeTrainingBehaviorInitializedEvent(rawTbiEvent);
var behaviorName = tbiEvent.BehaviorName;
var added = s_SentTrainingBehaviorInitialized.Add(behaviorName);

Expand All @@ -211,9 +225,7 @@ public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent
return;
}

// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
tbiEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString();
tbiEvent.BehaviorName = AnalyticsUtils.Hash(tbiEvent.BehaviorName);

// Note - to debug, use JsonUtility.ToJson on the event.
// Debug.Log(
Expand All @@ -236,7 +248,7 @@ IList<IActuator> actuators
var remotePolicyEvent = new RemotePolicyInitializedEvent();

// Hash the behavior name so that there's no concern about PII or "secret" data being leaked.
remotePolicyEvent.BehaviorName = AnalyticsUtils.Hash(behaviorName);
remotePolicyEvent.BehaviorName = AnalyticsUtils.Hash(k_VendorKey, behaviorName);

remotePolicyEvent.TrainingSessionGuid = s_TrainingSessionGuid.ToString();
remotePolicyEvent.ActionSpec = EventActionSpec.FromActionSpec(actionSpec);
Expand Down
2 changes: 2 additions & 0 deletions com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ internal static TrainingEnvironmentInitializedEvent ToTrainingEnvironmentInitial
TorchDeviceType = inputProto.TorchDeviceType,
NumEnvironments = inputProto.NumEnvs,
NumEnvironmentParameters = inputProto.NumEnvironmentParameters,
RunOptions = inputProto.RunOptions,
};
}

Expand Down Expand Up @@ -530,6 +531,7 @@ internal static TrainingBehaviorInitializedEvent ToTrainingBehaviorInitializedEv
VisualEncoder = inputProto.VisualEncoder,
NumNetworkLayers = inputProto.NumNetworkLayers,
NumNetworkHiddenUnits = inputProto.NumNetworkHiddenUnits,
Config = inputProto.Config,
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,29 @@ static TrainingAnalyticsReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjttbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL3RyYWluaW5n",
"X2FuYWx5dGljcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi2QEKHlRy",
"X2FuYWx5dGljcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi7gEKHlRy",
"YWluaW5nRW52aXJvbm1lbnRJbml0aWFsaXplZBIYChBtbGFnZW50c192ZXJz",
"aW9uGAEgASgJEh0KFW1sYWdlbnRzX2VudnNfdmVyc2lvbhgCIAEoCRIWCg5w",
"eXRob25fdmVyc2lvbhgDIAEoCRIVCg10b3JjaF92ZXJzaW9uGAQgASgJEhkK",
"EXRvcmNoX2RldmljZV90eXBlGAUgASgJEhAKCG51bV9lbnZzGAYgASgFEiIK",
"Gm51bV9lbnZpcm9ubWVudF9wYXJhbWV0ZXJzGAcgASgFIq0DChtUcmFpbmlu",
"Z0JlaGF2aW9ySW5pdGlhbGl6ZWQSFQoNYmVoYXZpb3JfbmFtZRgBIAEoCRIU",
"Cgx0cmFpbmVyX3R5cGUYAiABKAkSIAoYZXh0cmluc2ljX3Jld2FyZF9lbmFi",
"bGVkGAMgASgIEhsKE2dhaWxfcmV3YXJkX2VuYWJsZWQYBCABKAgSIAoYY3Vy",
"aW9zaXR5X3Jld2FyZF9lbmFibGVkGAUgASgIEhoKEnJuZF9yZXdhcmRfZW5h",
"YmxlZBgGIAEoCBIiChpiZWhhdmlvcmFsX2Nsb25pbmdfZW5hYmxlZBgHIAEo",
"CBIZChFyZWN1cnJlbnRfZW5hYmxlZBgIIAEoCBIWCg52aXN1YWxfZW5jb2Rl",
"chgJIAEoCRIaChJudW1fbmV0d29ya19sYXllcnMYCiABKAUSIAoYbnVtX25l",
"dHdvcmtfaGlkZGVuX3VuaXRzGAsgASgFEhgKEHRyYWluZXJfdGhyZWFkZWQY",
"DCABKAgSGQoRc2VsZl9wbGF5X2VuYWJsZWQYDSABKAgSGgoSY3VycmljdWx1",
"bV9lbmFibGVkGA4gASgIQiWqAiJVbml0eS5NTEFnZW50cy5Db21tdW5pY2F0",
"b3JPYmplY3RzYgZwcm90bzM="));
"Gm51bV9lbnZpcm9ubWVudF9wYXJhbWV0ZXJzGAcgASgFEhMKC3J1bl9vcHRp",
"b25zGAggASgJIr0DChtUcmFpbmluZ0JlaGF2aW9ySW5pdGlhbGl6ZWQSFQoN",
"YmVoYXZpb3JfbmFtZRgBIAEoCRIUCgx0cmFpbmVyX3R5cGUYAiABKAkSIAoY",
"ZXh0cmluc2ljX3Jld2FyZF9lbmFibGVkGAMgASgIEhsKE2dhaWxfcmV3YXJk",
"X2VuYWJsZWQYBCABKAgSIAoYY3VyaW9zaXR5X3Jld2FyZF9lbmFibGVkGAUg",
"ASgIEhoKEnJuZF9yZXdhcmRfZW5hYmxlZBgGIAEoCBIiChpiZWhhdmlvcmFs",
"X2Nsb25pbmdfZW5hYmxlZBgHIAEoCBIZChFyZWN1cnJlbnRfZW5hYmxlZBgI",
"IAEoCBIWCg52aXN1YWxfZW5jb2RlchgJIAEoCRIaChJudW1fbmV0d29ya19s",
"YXllcnMYCiABKAUSIAoYbnVtX25ldHdvcmtfaGlkZGVuX3VuaXRzGAsgASgF",
"EhgKEHRyYWluZXJfdGhyZWFkZWQYDCABKAgSGQoRc2VsZl9wbGF5X2VuYWJs",
"ZWQYDSABKAgSGgoSY3VycmljdWx1bV9lbmFibGVkGA4gASgIEg4KBmNvbmZp",
"ZxgPIAEoCUIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0",
"c2IGcHJvdG8z"));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized.Parser, new[]{ "MlagentsVersion", "MlagentsEnvsVersion", "PythonVersion", "TorchVersion", "TorchDeviceType", "NumEnvs", "NumEnvironmentParameters" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized.Parser, new[]{ "BehaviorName", "TrainerType", "ExtrinsicRewardEnabled", "GailRewardEnabled", "CuriosityRewardEnabled", "RndRewardEnabled", "BehavioralCloningEnabled", "RecurrentEnabled", "VisualEncoder", "NumNetworkLayers", "NumNetworkHiddenUnits", "TrainerThreaded", "SelfPlayEnabled", "CurriculumEnabled" }, null, null, null)
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingEnvironmentInitialized.Parser, new[]{ "MlagentsVersion", "MlagentsEnvsVersion", "PythonVersion", "TorchVersion", "TorchDeviceType", "NumEnvs", "NumEnvironmentParameters", "RunOptions" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized), global::Unity.MLAgents.CommunicatorObjects.TrainingBehaviorInitialized.Parser, new[]{ "BehaviorName", "TrainerType", "ExtrinsicRewardEnabled", "GailRewardEnabled", "CuriosityRewardEnabled", "RndRewardEnabled", "BehavioralCloningEnabled", "RecurrentEnabled", "VisualEncoder", "NumNetworkLayers", "NumNetworkHiddenUnits", "TrainerThreaded", "SelfPlayEnabled", "CurriculumEnabled", "Config" }, null, null, null)
}));
}
#endregion
Expand Down Expand Up @@ -85,6 +86,7 @@ public TrainingEnvironmentInitialized(TrainingEnvironmentInitialized other) : th
torchDeviceType_ = other.torchDeviceType_;
numEnvs_ = other.numEnvs_;
numEnvironmentParameters_ = other.numEnvironmentParameters_;
runOptions_ = other.runOptions_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

Expand Down Expand Up @@ -170,6 +172,17 @@ public int NumEnvironmentParameters {
}
}

/// <summary>Field number for the "run_options" field.</summary>
public const int RunOptionsFieldNumber = 8;
// Backing store for RunOptions; starts as the empty string rather than null.
private string runOptions_ = "";
/// <summary>
/// Value of the "run_options" string field. The setter routes the incoming value through
/// ProtoPreconditions.CheckNotNull before storing it.
/// </summary>
// NOTE(review): CheckNotNull presumably rejects null by throwing - confirm against Google.Protobuf.
// NOTE(review): this appears to be protoc-generated code; manual edits would likely be lost on regeneration.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string RunOptions {
get { return runOptions_; }
set {
runOptions_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as TrainingEnvironmentInitialized);
Expand All @@ -190,6 +203,7 @@ public bool Equals(TrainingEnvironmentInitialized other) {
if (TorchDeviceType != other.TorchDeviceType) return false;
if (NumEnvs != other.NumEnvs) return false;
if (NumEnvironmentParameters != other.NumEnvironmentParameters) return false;
if (RunOptions != other.RunOptions) return false;
return Equals(_unknownFields, other._unknownFields);
}

Expand All @@ -203,6 +217,7 @@ public override int GetHashCode() {
if (TorchDeviceType.Length != 0) hash ^= TorchDeviceType.GetHashCode();
if (NumEnvs != 0) hash ^= NumEnvs.GetHashCode();
if (NumEnvironmentParameters != 0) hash ^= NumEnvironmentParameters.GetHashCode();
if (RunOptions.Length != 0) hash ^= RunOptions.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
Expand Down Expand Up @@ -244,6 +259,10 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteRawTag(56);
output.WriteInt32(NumEnvironmentParameters);
}
if (RunOptions.Length != 0) {
output.WriteRawTag(66);
output.WriteString(RunOptions);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
Expand Down Expand Up @@ -273,6 +292,9 @@ public int CalculateSize() {
if (NumEnvironmentParameters != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(NumEnvironmentParameters);
}
if (RunOptions.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(RunOptions);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
Expand Down Expand Up @@ -305,6 +327,9 @@ public void MergeFrom(TrainingEnvironmentInitialized other) {
if (other.NumEnvironmentParameters != 0) {
NumEnvironmentParameters = other.NumEnvironmentParameters;
}
if (other.RunOptions.Length != 0) {
RunOptions = other.RunOptions;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

Expand Down Expand Up @@ -344,6 +369,10 @@ public void MergeFrom(pb::CodedInputStream input) {
NumEnvironmentParameters = input.ReadInt32();
break;
}
case 66: {
RunOptions = input.ReadString();
break;
}
}
}
}
Expand Down Expand Up @@ -389,6 +418,7 @@ public TrainingBehaviorInitialized(TrainingBehaviorInitialized other) : this() {
trainerThreaded_ = other.trainerThreaded_;
selfPlayEnabled_ = other.selfPlayEnabled_;
curriculumEnabled_ = other.curriculumEnabled_;
config_ = other.config_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}

Expand Down Expand Up @@ -551,6 +581,17 @@ public bool CurriculumEnabled {
}
}

/// <summary>Field number for the "config" field.</summary>
public const int ConfigFieldNumber = 15;
// Backing store for Config; starts as the empty string rather than null.
private string config_ = "";
/// <summary>
/// Value of the "config" string field. The setter routes the incoming value through
/// ProtoPreconditions.CheckNotNull before storing it.
/// </summary>
// NOTE(review): CheckNotNull presumably rejects null by throwing - confirm against Google.Protobuf.
// NOTE(review): this appears to be protoc-generated code; manual edits would likely be lost on regeneration.
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string Config {
get { return config_; }
set {
config_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}

[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as TrainingBehaviorInitialized);
Expand Down Expand Up @@ -578,6 +619,7 @@ public bool Equals(TrainingBehaviorInitialized other) {
if (TrainerThreaded != other.TrainerThreaded) return false;
if (SelfPlayEnabled != other.SelfPlayEnabled) return false;
if (CurriculumEnabled != other.CurriculumEnabled) return false;
if (Config != other.Config) return false;
return Equals(_unknownFields, other._unknownFields);
}

Expand All @@ -598,6 +640,7 @@ public override int GetHashCode() {
if (TrainerThreaded != false) hash ^= TrainerThreaded.GetHashCode();
if (SelfPlayEnabled != false) hash ^= SelfPlayEnabled.GetHashCode();
if (CurriculumEnabled != false) hash ^= CurriculumEnabled.GetHashCode();
if (Config.Length != 0) hash ^= Config.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
Expand Down Expand Up @@ -667,6 +710,10 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteRawTag(112);
output.WriteBool(CurriculumEnabled);
}
if (Config.Length != 0) {
output.WriteRawTag(122);
output.WriteString(Config);
}
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
Expand Down Expand Up @@ -717,6 +764,9 @@ public int CalculateSize() {
if (CurriculumEnabled != false) {
size += 1 + 1;
}
if (Config.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(Config);
}
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
Expand Down Expand Up @@ -770,6 +820,9 @@ public void MergeFrom(TrainingBehaviorInitialized other) {
if (other.CurriculumEnabled != false) {
CurriculumEnabled = other.CurriculumEnabled;
}
if (other.Config.Length != 0) {
Config = other.Config;
}
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}

Expand Down Expand Up @@ -837,6 +890,10 @@ public void MergeFrom(pb::CodedInputStream input) {
CurriculumEnabled = input.ReadBool();
break;
}
case 122: {
Config = input.ReadString();
break;
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,19 @@ public void TestRemotePolicy()
Academy.Instance.Dispose();
}

[TestCase("a name we expect to hash", ExpectedResult = "d084a8b6da6a6a1c097cdc9ffea95e1546da4647352113ed77cbe7b4192e6d73")]
[TestCase("another_name", ExpectedResult = "0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b")]
[TestCase("0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b", ExpectedResult = "0b74613c872e79aba11e06eda3538f2b646eb2b459e75087829ea500bd703d0b")]
public string TestTrainingBehaviorInitialized(string stringToMaybeHash)
{
    // Config is deliberately non-empty, so the behavior name alone decides whether
    // the sanitizer hashes it: plain names come back hashed, while a name that is
    // already a 64-character digest comes back unchanged.
    var rawEvent = new TrainingBehaviorInitializedEvent
    {
        BehaviorName = stringToMaybeHash,
        Config = "{}",
    };

    return TrainingAnalytics.SanitizeTrainingBehaviorInitializedEvent(rawEvent).BehaviorName;
}

[Test]
public void TestEnableAnalytics()
{
Expand Down
Loading

0 comments on commit 3fe4e03

Please sign in to comment.