Skip to content

Commit

Permalink
Address comments about moving hash to a class function and not nested.
Browse files Browse the repository at this point in the history
  • Loading branch information
sini committed Aug 30, 2021
1 parent 433c064 commit c67dbdf
Showing 1 changed file with 16 additions and 28 deletions.
44 changes: 16 additions & 28 deletions ml-agents/mlagents/training_analytics_side_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,11 @@ def __init__(self) -> None:
super().__init__(uuid.UUID("b664a4a9-d86f-5a5f-95cb-e8353a7e8356"))
self.run_options: Optional[RunOptions] = None

@staticmethod
def __hash(key: str, data: str) -> str:
@classmethod
def __hash(cls, data: str) -> str:
res = hmac.new(
key.encode("utf-8"), data.encode("utf-8"), hashlib.sha256
cls.__vendorKey.encode("utf-8"), data.encode("utf-8"), hashlib.sha256
).hexdigest()
print(res)
return res

def on_message_received(self, msg: IncomingMessage) -> None:
Expand All @@ -47,36 +46,31 @@ def on_message_received(self, msg: IncomingMessage) -> None:
+ "this should not have happened."
)

@staticmethod
def __sanitize_run_options(config: RunOptions) -> Dict[str, Any]:
@classmethod
def __sanitize_run_options(cls, config: RunOptions) -> Dict[str, Any]:
res = copy.deepcopy(config.as_dict())

def hash(value: str) -> str:
return TrainingAnalyticsSideChannel.__hash(
TrainingAnalyticsSideChannel.__vendorKey, value
)

# Filter potentially PII behavior names
if "behaviors" in res and res["behaviors"]:
res["behaviors"] = {hash(k): v for (k, v) in res["behaviors"].items()}
res["behaviors"] = {cls.__hash(k): v for (k, v) in res["behaviors"].items()}
for (k, v) in res["behaviors"].items():
if "init_path" in v and v["init_path"] is not None:
hashed_path = hash(v["init_path"])
hashed_path = cls.__hash(v["init_path"])
res["behaviors"][k]["init_path"] = hashed_path

# Filter potentially PII curriculum and behavior names from Checkpoint Settings
if "environment_parameters" in res and res["environment_parameters"]:
res["environment_parameters"] = {
hash(k): v for (k, v) in res["environment_parameters"].items()
cls.__hash(k): v for (k, v) in res["environment_parameters"].items()
}
for (curriculumName, curriculum) in res["environment_parameters"].items():
updated_lessons = []
for lesson in curriculum["curriculum"]:
new_lesson = copy.deepcopy(lesson)
if lesson.has_keys("name"):
new_lesson["name"] = hash(lesson["name"])
new_lesson["name"] = cls.__hash(lesson["name"])
if lesson.has_keys("completion_criteria"):
new_lesson["completion_criteria"]["behavior"] = hash(
new_lesson["completion_criteria"]["behavior"] = cls.__hash(
new_lesson["completion_criteria"]["behavior"]
)
updated_lessons.append(new_lesson)
Expand All @@ -90,7 +84,7 @@ def hash(value: str) -> str:
"initialize_from" in res["checkpoint_settings"]
and res["checkpoint_settings"]["initialize_from"] is not None
):
res["checkpoint_settings"]["initialize_from"] = hash(
res["checkpoint_settings"]["initialize_from"] = cls.__hash(
res["checkpoint_settings"]["initialize_from"]
)
if (
Expand Down Expand Up @@ -123,31 +117,25 @@ def environment_initialized(self, run_options: RunOptions) -> None:
run_options=json.dumps(sanitized_run_options),
)

print(msg)

any_message = Any()
any_message.Pack(msg)

env_init_msg = OutgoingMessage()
env_init_msg.set_raw_bytes(any_message.SerializeToString())
super().queue_message_to_send(env_init_msg)

@staticmethod
def __sanitize_trainer_settings(config: TrainerSettings) -> Dict[str, Any]:
@classmethod
def __sanitize_trainer_settings(cls, config: TrainerSettings) -> Dict[str, Any]:
config_dict = copy.deepcopy(config.as_dict())
if "init_path" in config_dict and config_dict["init_path"] is not None:
hashed_path = TrainingAnalyticsSideChannel.__hash(
TrainingAnalyticsSideChannel.__vendorKey, config_dict["init_path"]
)
hashed_path = cls.__hash(config_dict["init_path"])
config_dict["init_path"] = hashed_path
return config_dict

def training_started(self, behavior_name: str, config: TrainerSettings) -> None:
raw_config = TrainingAnalyticsSideChannel.__sanitize_trainer_settings(config)
raw_config = self.__sanitize_trainer_settings(config)
msg = TrainingBehaviorInitialized(
behavior_name=TrainingAnalyticsSideChannel.__hash(
self.__vendorKey, behavior_name
),
behavior_name=self.__hash(behavior_name),
trainer_type=config.trainer_type.value,
extrinsic_reward_enabled=(
RewardSignalType.EXTRINSIC in config.reward_signals
Expand Down

0 comments on commit c67dbdf

Please sign in to comment.