-
Notifications
You must be signed in to change notification settings - Fork 181
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Data Converter #2487
Data Converter #2487
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import struct | ||
from io import BytesIO | ||
from typing import List | ||
|
||
SIGNATURE = "NVDADAM1" # DAM (Direct Accessible Marshalling) V1 | ||
PREFIX_LEN = 24 | ||
|
||
DATA_TYPE_INT = 1 | ||
DATA_TYPE_FLOAT = 2 | ||
DATA_TYPE_STRING = 3 | ||
DATA_TYPE_INT_ARRAY = 257 | ||
DATA_TYPE_FLOAT_ARRAY = 258 | ||
|
||
|
||
class DamEncoder:
    """Encodes arrays into the DAM (Direct Accessible Marshalling) V1 binary format.

    Layout: an 8-byte ASCII signature, then two int64 values (total buffer size
    and data-set id), then one entry per added array. Each entry is an int64
    data type, an int64 element count, and the elements as int64 or float64.
    All multi-byte values use native byte order, matching DamDecoder.
    """

    def __init__(self, data_set_id: int):
        """Create an encoder.

        Args:
            data_set_id: application-defined id describing what the buffer carries.
        """
        self.data_set_id = data_set_id
        self.entries = []  # list of (data_type, value_list) in insertion order
        self.buffer = BytesIO()

    def add_int_array(self, value: List[int]):
        """Queue an int array to be serialized by finish()."""
        self.entries.append((DATA_TYPE_INT_ARRAY, value))

    def add_float_array(self, value: List[float]):
        """Queue a float array to be serialized by finish()."""
        self.entries.append((DATA_TYPE_FLOAT_ARRAY, value))

    def finish(self) -> bytes:
        """Serialize all queued entries and return the encoded buffer."""
        # Compute the total size recorded in the prefix.
        size = PREFIX_LEN
        for _, value in self.entries:
            size += 16  # per-entry header: data type + element count, one int64 each
            # Bug fix: this previously was `len(entry) * 8` where entry is the
            # (data_type, value) tuple, so it always added 16 regardless of the
            # actual payload length. Use the array length instead.
            size += len(value) * 8

        self.write_str(SIGNATURE)
        self.write_int64(size)
        self.write_int64(self.data_set_id)

        for data_type, value in self.entries:
            self.write_int64(data_type)
            self.write_int64(len(value))

            for x in value:
                if data_type == DATA_TYPE_INT_ARRAY:
                    self.write_int64(x)
                else:
                    self.write_float(x)

        return self.buffer.getvalue()

    def write_int64(self, value: int):
        """Append a native-order signed 64-bit integer to the buffer."""
        self.buffer.write(struct.pack("q", value))

    def write_float(self, value: float):
        """Append a native-order 64-bit float to the buffer."""
        self.buffer.write(struct.pack("d", value))

    def write_str(self, value: str):
        """Append a string as raw UTF-8 bytes (no length prefix or terminator)."""
        self.buffer.write(value.encode("utf-8"))
|
||
|
||
class DamDecoder:
    """Decodes buffers produced by DamEncoder.

    The 24-byte prefix (signature, size, data-set id) is parsed eagerly in the
    constructor; arrays are then decoded sequentially, in the same order they
    were encoded.
    """

    def __init__(self, buffer: bytes):
        """Parse the prefix of the buffer.

        A buffer too short to hold the prefix is treated as invalid
        (is_valid() returns False) instead of raising.

        Args:
            buffer: the encoded bytes to decode.
        """
        self.buffer = buffer
        self.pos = 0
        if len(buffer) < PREFIX_LEN:
            # Not enough bytes for signature + size + data_set_id; mark invalid
            # rather than letting struct.unpack_from raise.
            self.signature = None
            self.size = 0
            self.data_set_id = 0
            return
        self.signature = self.read_string(8)
        self.size = self.read_int64()
        self.data_set_id = self.read_int64()

    def is_valid(self):
        """Return True if the buffer starts with the DAM signature."""
        return self.signature == SIGNATURE

    def get_data_set_id(self):
        """Return the data-set id from the prefix (0 for invalid buffers)."""
        return self.data_set_id

    def decode_int_array(self) -> List[int]:
        """Decode the next entry as an int array.

        Raises:
            RuntimeError: if the next entry is not an int array.
        """
        data_type = self.read_int64()
        if data_type != DATA_TYPE_INT_ARRAY:
            raise RuntimeError("Invalid data type for int array")

        num = self.read_int64()
        return [self.read_int64() for _ in range(num)]

    def decode_float_array(self) -> List[float]:
        """Decode the next entry as a float array.

        Raises:
            RuntimeError: if the next entry is not a float array.
        """
        data_type = self.read_int64()
        if data_type != DATA_TYPE_FLOAT_ARRAY:
            raise RuntimeError("Invalid data type for float array")

        num = self.read_int64()
        return [self.read_float() for _ in range(num)]

    def read_string(self, length: int) -> str:
        """Read `length` bytes as UTF-8.

        Undecodable bytes are replaced (U+FFFD) so that binary junk fails the
        signature comparison instead of raising UnicodeDecodeError.
        """
        result = self.buffer[self.pos : self.pos + length].decode("utf-8", errors="replace")
        self.pos += length
        return result

    def read_int64(self) -> int:
        """Read a native-order signed 64-bit integer.

        Raises:
            struct.error: if fewer than 8 bytes remain in the buffer.
        """
        (result,) = struct.unpack_from("q", self.buffer, self.pos)
        self.pos += 8
        return result

    def read_float(self) -> float:
        """Read a native-order 64-bit float.

        Raises:
            struct.error: if fewer than 8 bytes remain in the buffer.
        """
        (result,) = struct.unpack_from("d", self.buffer, self.pos)
        self.pos += 8
        return result
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Dict, List, Tuple | ||
|
||
from nvflare.apis.fl_context import FLContext | ||
|
||
|
||
class FeatureContext:
    """Per-feature context used during histogram aggregation."""

    def __init__(self, feature_id, sample_bin_assignment, num_bins: int):
        """Capture the bin layout of a single feature.

        Args:
            feature_id: id of the feature.
            sample_bin_assignment: bin index for each sample, normalized to
                the range [0 .. num_bins - 1].
            num_bins: how many bins this feature has.
        """
        self.feature_id = feature_id
        self.sample_bin_assignment = sample_bin_assignment
        self.num_bins = num_bins
|
||
|
||
class AggregationContext:
    """All the information needed to aggregate histograms for one round."""

    def __init__(self, features: List[FeatureContext], sample_groups: Dict[int, List[int]]):  # group_id => sample Ids
        """Bundle feature contexts with the sample grouping.

        Args:
            features: contexts for all active features.
            sample_groups: mapping of group id to the ids of samples in that group.
        """
        self.sample_groups = sample_groups
        self.features = features
|
||
|
||
class FeatureAggregationResult:
    """Aggregated histogram for a single feature."""

    def __init__(self, feature_id: int, aggregated_hist: List[Tuple[int, int]]):
        """Record the aggregation output of one feature.

        Args:
            feature_id: id of the feature.
            aggregated_hist: one (G, H) pair for each bin of the feature.
        """
        self.aggregated_hist = aggregated_hist
        self.feature_id = feature_id
|
||
|
||
class DataConverter:
    """Interface for converting between XGBoost binary buffers and plugin data
    structures. Decode methods return None when the buffer does not carry the
    expected kind of data; subclasses provide the concrete encoding.
    """

    def decode_gh_pairs(self, buffer: bytes, fl_ctx: FLContext) -> List[Tuple[int, int]]:
        """Decode the buffer to extract (g, h) pairs.

        Args:
            buffer: the buffer to be decoded
            fl_ctx: FLContext info

        Returns:
            A list of (g, h) pairs, one per sample, if the buffer contains
            GH pairs; otherwise None.
            (Docstring fix: previously described the result as "a tuple of
            (g_numbers, h_numbers)", which contradicts the annotation and the
            concrete implementation.)
        """
        pass

    def decode_aggregation_context(self, buffer: bytes, fl_ctx: FLContext) -> AggregationContext:
        """Decode the buffer to extract aggregation context info.

        Args:
            buffer: buffer to be decoded
            fl_ctx: FLContext info

        Returns:
            An AggregationContext object if the buffer contains aggregation
            context; otherwise None.
        """
        pass

    def encode_aggregation_result(
        self, aggr_results: Dict[int, List[FeatureAggregationResult]], fl_ctx: FLContext
    ) -> bytes:
        """Encode an individual rank's aggregation result to a buffer based on
        the XGBoost data structure.

        Args:
            aggr_results: aggregation results for all features and all groups
                from all clients, as group_id => list of feature aggr results
            fl_ctx: FLContext info

        Returns:
            A buffer of bytes.
        """
        pass
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Dict, List, Tuple | ||
|
||
from nvflare.apis.fl_context import FLContext | ||
from nvflare.app_common.xgb.sec.dam import DamDecoder, DamEncoder | ||
from nvflare.app_common.xgb.sec.data_converter import ( | ||
AggregationContext, | ||
DataConverter, | ||
FeatureAggregationResult, | ||
FeatureContext, | ||
) | ||
|
||
DATA_SET_GH_PAIRS = 1 | ||
DATA_SET_AGGREGATION = 2 | ||
DATA_SET_AGGREGATION_WITH_FEATURES = 3 | ||
DATA_SET_AGGREGATION_RESULT = 4 | ||
|
||
SCALE_FACTOR = 1000000.0 # Preserve 6 decimal places | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be configurable using ConfgService. |
||
|
||
|
||
class ProcessorDataConverter(DataConverter):
    """DataConverter implementation based on the DAM encoding.

    Decoding the aggregation context relies on state captured while decoding
    the GH pairs (num_samples), so decode_gh_pairs() must be called before
    decode_aggregation_context() in each round.
    """

    def __init__(self):
        super().__init__()
        self.features = []  # FeatureContext for every active feature
        self.feature_list = None  # feature ids, captured from the aggregation context
        self.num_samples = 0  # captured while decoding GH pairs

    def decode_gh_pairs(self, buffer: bytes, fl_ctx: FLContext) -> List[Tuple[int, int]]:
        """Decode (g, h) pairs from a DAM buffer.

        Encoding format: a single float array of 2 * num_samples values laid
        out as g0, h0, g1, h1, ... Floats are converted to fixed-point ints.

        Returns:
            A list of (g, h) int pairs, or None if the buffer is not a
            GH-pairs DAM buffer (receiving unrelated traffic is normal,
            so this is not treated as an error).

        Raises:
            RuntimeError: if the float array has an odd number of values.
        """
        decoder = DamDecoder(buffer)
        if not decoder.is_valid():
            return None

        if decoder.get_data_set_id() != DATA_SET_GH_PAIRS:
            # A valid DAM buffer, but not GH pairs.
            return None

        float_array = decoder.decode_float_array()
        if len(float_array) % 2 != 0:
            raise RuntimeError(f"GH pairs buffer is malformed: odd value count {len(float_array)}")

        self.num_samples = len(float_array) // 2
        return [
            (self.float_to_int(float_array[2 * i]), self.float_to_int(float_array[2 * i + 1]))
            for i in range(self.num_samples)
        ]

    def decode_aggregation_context(self, buffer: bytes, fl_ctx: FLContext) -> AggregationContext:
        """Decode aggregation context info from a DAM buffer.

        Encoding format:
            int array: cuts -- cumulative bin boundaries per feature; feature f
                owns the global slots [cuts[f], cuts[f+1]).
            If data set is DATA_SET_AGGREGATION_WITH_FEATURES:
                int array: ids of active features
                int array: slot assignment, row-major [sample][feature]
            int array: node (group) ids, then one int array of sample ids
                per node, in the same order.

        Returns:
            An AggregationContext, or None if the buffer is not a DAM buffer.

        Raises:
            RuntimeError: if the data-set id is unexpected, or GH pairs have
                not been decoded yet (num_samples unknown).
        """
        decoder = DamDecoder(buffer)
        if not decoder.is_valid():
            return None

        data_set_id = decoder.get_data_set_id()
        cuts = decoder.decode_int_array()  # cumulative bin boundaries, see docstring

        if data_set_id == DATA_SET_AGGREGATION_WITH_FEATURES:
            if self.num_samples == 0:
                # Without num_samples the slot matrix cannot be interpreted.
                raise RuntimeError("decode_gh_pairs() must be called before decode_aggregation_context()")
            self.feature_list = decoder.decode_int_array()
            num = len(self.feature_list)
            slots = decoder.decode_int_array()
            for i in range(num):
                bin_assignment = []
                for sample_id in range(self.num_samples):
                    _, bin_num = self.slot_to_bin(cuts, slots[sample_id * num + i])
                    bin_assignment.append(bin_num)

                bin_size = self.get_bin_size(cuts, self.feature_list[i])
                feature_ctx = FeatureContext(self.feature_list[i], bin_assignment, bin_size)
                self.features.append(feature_ctx)
        elif data_set_id != DATA_SET_AGGREGATION:
            raise RuntimeError(f"Invalid DataSet: {data_set_id}")

        node_list = decoder.decode_int_array()
        sample_groups = {}
        for node in node_list:
            sample_groups[node] = decoder.decode_int_array()

        return AggregationContext(self.features, sample_groups)

    def encode_aggregation_result(
        self, aggr_results: Dict[int, List[FeatureAggregationResult]], fl_ctx: FLContext
    ) -> bytes:
        """Encode per-node aggregation results.

        Encoding format: an int array of sorted node ids, then for each node
        one float array per feature in self.feature_list, holding interleaved
        (g, h) histogram values converted back to floats.

        Returns:
            A buffer of bytes.
        """
        encoder = DamEncoder(DATA_SET_AGGREGATION_RESULT)
        node_list = sorted(aggr_results.keys())
        encoder.add_int_array(node_list)

        for node in node_list:
            result_list = aggr_results.get(node)
            # feature_list may be None if no feature context was received.
            for feature_id in self.feature_list or []:
                encoder.add_float_array(self.find_histo_for_feature(result_list, feature_id))

        return encoder.finish()

    @staticmethod
    def get_bin_size(cuts: List[int], feature_id: int) -> int:
        """Return the number of bins owned by feature_id, given cumulative cuts."""
        return cuts[feature_id + 1] - cuts[feature_id]

    @staticmethod
    def slot_to_bin(cuts: List[int], slot: int) -> Tuple[int, int]:
        """Map a global slot to (feature index, bin number within that feature).

        Raises:
            RuntimeError: if slot is outside [0, cuts[-1]).
        """
        if slot < 0 or slot >= cuts[-1]:
            raise RuntimeError(f"Invalid slot {slot}, out of range [0-{cuts[-1]-1}]")

        for i in range(len(cuts) - 1):
            if cuts[i] <= slot < cuts[i + 1]:
                return i, slot - cuts[i]

        raise RuntimeError(f"Logic error. Slot {slot}, out of range [0-{cuts[-1] - 1}]")

    @staticmethod
    def float_to_int(value: float) -> int:
        """Convert a float to a fixed-point int preserving 6 decimal places."""
        return int(value * SCALE_FACTOR)

    @staticmethod
    def int_to_float(value: int) -> float:
        """Inverse of float_to_int."""
        return value / SCALE_FACTOR

    @staticmethod
    def find_histo_for_feature(result_list: List[FeatureAggregationResult], feature_id: int) -> List[float]:
        """Return the interleaved (g, h) float histogram for feature_id.

        Raises:
            RuntimeError: if the feature is not present in result_list.
        """
        for result in result_list:
            if result.feature_id == feature_id:
                float_array = []
                for g, h in result.aggregated_hist:
                    float_array.append(ProcessorDataConverter.int_to_float(g))
                    float_array.append(ProcessorDataConverter.int_to_float(h))
                return float_array

        raise RuntimeError(f"Logic error. Feature {feature_id} not found in the list")
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from nvflare.app_common.xgb.sec.dam import DamDecoder, DamEncoder | ||
|
||
DATA_SET = 123456 | ||
INT_ARRAY = [123, 456, 789] | ||
FLOAT_ARRAY = [1.2, 2.3, 3.4, 4.5] | ||
|
||
|
||
class TestDam:
    """Round-trip test for DamEncoder / DamDecoder."""

    def test_encode_decode(self):
        """Encode two arrays, then verify the decoded buffer matches."""
        enc = DamEncoder(DATA_SET)
        enc.add_int_array(INT_ARRAY)
        enc.add_float_array(FLOAT_ARRAY)
        encoded = enc.finish()

        dec = DamDecoder(encoded)
        assert dec.is_valid()
        assert dec.get_data_set_id() == DATA_SET
        assert dec.decode_int_array() == INT_ARRAY
        assert dec.decode_float_array() == FLOAT_ARRAY
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please document the encoding format