# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import struct
from io import BytesIO
from typing import List

SIGNATURE = "NVDADAM1"  # DAM (Direct Accessible Marshalling) V1
PREFIX_LEN = 24  # 8-byte signature + int64 total size + int64 data-set id

DATA_TYPE_INT = 1
DATA_TYPE_FLOAT = 2
DATA_TYPE_STRING = 3
DATA_TYPE_INT_ARRAY = 257
DATA_TYPE_FLOAT_ARRAY = 258


class DamEncoder:
    """Serializer for the DAM V1 binary layout.

    Layout: 8-byte ASCII signature, int64 total buffer size, int64 data-set
    id, then for each entry an int64 type code, an int64 element count, and
    the elements packed as native-order int64 ("q") or float64 ("d") values.
    """

    def __init__(self, data_set_id: int):
        self.data_set_id = data_set_id
        self.entries = []  # list of (data_type, value_list) in add order
        self.buffer = BytesIO()

    def add_int_array(self, value: List[int]):
        """Queue a list of ints to be written by finish()."""
        self.entries.append((DATA_TYPE_INT_ARRAY, value))

    def add_float_array(self, value: List[float]):
        """Queue a list of floats to be written by finish()."""
        self.entries.append((DATA_TYPE_FLOAT_ARRAY, value))

    def finish(self) -> bytes:
        """Write prefix and all queued entries; return the encoded bytes."""
        size = PREFIX_LEN
        for data_type, value in self.entries:
            size += 16  # per-entry header: int64 type code + int64 count
            # BUG FIX: the original computed len(entry) * 8, where entry is
            # the (data_type, value) tuple, so it always added 16 regardless
            # of array length. The size field must count the elements.
            size += len(value) * 8

        self.write_str(SIGNATURE)
        self.write_int64(size)
        self.write_int64(self.data_set_id)

        for data_type, value in self.entries:
            self.write_int64(data_type)
            self.write_int64(len(value))
            for x in value:
                if data_type == DATA_TYPE_INT_ARRAY:
                    self.write_int64(x)
                else:
                    self.write_float(x)

        return self.buffer.getvalue()

    def write_int64(self, value: int):
        self.buffer.write(struct.pack("q", value))

    def write_float(self, value: float):
        self.buffer.write(struct.pack("d", value))

    def write_str(self, value: str):
        self.buffer.write(value.encode("utf-8"))


class DamDecoder:
    """Sequential reader for buffers produced by DamEncoder."""

    def __init__(self, buffer: bytes):
        self.buffer = buffer
        self.pos = 0
        # Robustness fix: the original unconditionally parsed the prefix, so
        # a buffer shorter than PREFIX_LEN raised struct.error before callers
        # could consult is_valid(). Treat short buffers as invalid instead.
        if len(buffer) >= PREFIX_LEN:
            self.signature = self.read_string(8)
            self.size = self.read_int64()
            self.data_set_id = self.read_int64()
        else:
            self.signature = None
            self.size = 0
            self.data_set_id = 0

    def is_valid(self):
        """True if the buffer starts with the DAM V1 signature."""
        return self.signature == SIGNATURE

    def get_data_set_id(self):
        return self.data_set_id

    def decode_int_array(self) -> List[int]:
        """Read the next entry; raise RuntimeError if it is not an int array."""
        data_type = self.read_int64()
        if data_type != DATA_TYPE_INT_ARRAY:
            raise RuntimeError("Invalid data type for int array")

        num = self.read_int64()
        return [self.read_int64() for _ in range(num)]

    def decode_float_array(self) -> List[float]:
        """Read the next entry; raise RuntimeError if it is not a float array."""
        data_type = self.read_int64()
        if data_type != DATA_TYPE_FLOAT_ARRAY:
            raise RuntimeError("Invalid data type for float array")

        num = self.read_int64()
        return [self.read_float() for _ in range(num)]

    def read_string(self, length: int) -> str:
        result = self.buffer[self.pos : self.pos + length].decode("utf-8")
        self.pos += length
        return result

    def read_int64(self) -> int:
        (result,) = struct.unpack_from("q", self.buffer, self.pos)
        self.pos += 8
        return result

    def read_float(self) -> float:
        (result,) = struct.unpack_from("d", self.buffer, self.pos)
        self.pos += 8
        return result
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Tuple

from nvflare.apis.fl_context import FLContext


class FeatureContext:
    """Context for one feature taking part in secure aggregation."""

    def __init__(self, feature_id, sample_bin_assignment, num_bins: int):
        self.feature_id = feature_id
        # Number of histogram bins owned by this feature
        self.num_bins = num_bins
        # Per-sample bin index, normalized to the range [0 .. num_bins-1]
        self.sample_bin_assignment = sample_bin_assignment


class AggregationContext:
    """Everything needed to aggregate histograms for one round."""

    def __init__(self, features: List[FeatureContext], sample_groups: Dict[int, List[int]]):
        self.features = features
        # Maps group_id to the list of sample ids belonging to that group
        self.sample_groups = sample_groups


class FeatureAggregationResult:
    """Aggregated histogram for a single feature."""

    def __init__(self, feature_id: int, aggregated_hist: List[Tuple[int, int]]):
        self.feature_id = feature_id
        # One (G, H) pair per bin of the feature
        self.aggregated_hist = aggregated_hist


class DataConverter:
    """Abstract converter between wire buffers and aggregation objects."""

    def decode_gh_pairs(self, buffer: bytes, fl_ctx: FLContext) -> List[Tuple[int, int]]:
        """Extract (g, h) pairs from an encoded buffer.

        Args:
            buffer: the buffer to be decoded
            fl_ctx: FLContext info

        Returns: a list of (g, h) pairs when the buffer carries them;
            None otherwise
        """
        pass

    def decode_aggregation_context(self, buffer: bytes, fl_ctx: FLContext) -> AggregationContext:
        """Extract aggregation context info from an encoded buffer.

        Args:
            buffer: buffer to be decoded
            fl_ctx: FLContext info

        Returns: an AggregationContext when the buffer carries one;
            None otherwise
        """
        pass

    def encode_aggregation_result(
        self, aggr_results: Dict[int, List[FeatureAggregationResult]], fl_ctx: FLContext
    ) -> bytes:
        """Serialize one rank's aggregation result per the XGB data layout.

        Args:
            aggr_results: aggregation result for all features and all groups
                from all clients, keyed by group_id
            fl_ctx: FLContext info

        Returns: a buffer of bytes
        """
        pass
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Tuple

from nvflare.apis.fl_context import FLContext
from nvflare.app_common.xgb.sec.dam import DamDecoder, DamEncoder
from nvflare.app_common.xgb.sec.data_converter import (
    AggregationContext,
    DataConverter,
    FeatureAggregationResult,
    FeatureContext,
)

DATA_SET_GH_PAIRS = 1
DATA_SET_AGGREGATION = 2
DATA_SET_AGGREGATION_WITH_FEATURES = 3
DATA_SET_AGGREGATION_RESULT = 4

SCALE_FACTOR = 1000000.0  # Preserve 6 decimal places


class ProcessorDataConverter(DataConverter):
    """DataConverter backed by the DAM binary format.

    Floats are transported as fixed-point ints (scaled by SCALE_FACTOR) so
    (g, h) values can be handled by integer-only encryption schemes.

    State accumulated across calls: decode_gh_pairs() records num_samples,
    which decode_aggregation_context() relies on to split the slot array.
    """

    def __init__(self):
        super().__init__()
        self.features = []  # FeatureContext list built by decode_aggregation_context
        self.feature_list = None  # feature ids seen in the last aggregation buffer
        self.num_samples = 0  # set by decode_gh_pairs

    def decode_gh_pairs(self, buffer: bytes, fl_ctx: FLContext) -> List[Tuple[int, int]]:
        """Decode a DATA_SET_GH_PAIRS buffer into scaled-int (g, h) pairs.

        Raises:
            RuntimeError: if the buffer is not valid DAM or not GH data.
        """
        decoder = DamDecoder(buffer)
        if not decoder.is_valid():
            raise RuntimeError("GH Buffer is not properly encoded")

        if decoder.get_data_set_id() != DATA_SET_GH_PAIRS:
            raise RuntimeError(f"Data is not for GH Pairs: {decoder.get_data_set_id()}")

        float_array = decoder.decode_float_array()
        result = []
        # Values arrive interleaved: g0, h0, g1, h1, ...
        self.num_samples = len(float_array) // 2

        for i in range(self.num_samples):
            result.append((self.float_to_int(float_array[2 * i]), self.float_to_int(float_array[2 * i + 1])))

        return result

    def decode_aggregation_context(self, buffer: bytes, fl_ctx: FLContext) -> AggregationContext:
        """Decode an aggregation buffer into an AggregationContext.

        Requires decode_gh_pairs() to have been called first so that
        num_samples is known.

        Raises:
            RuntimeError: if the buffer is not valid DAM or has an
                unexpected data-set id.
        """
        decoder = DamDecoder(buffer)
        if not decoder.is_valid():
            raise RuntimeError("Aggregation Buffer is not properly encoded")
        data_set_id = decoder.get_data_set_id()
        cuts = decoder.decode_int_array()

        # BUG FIX: reset per-call state; the original appended to
        # self.features across calls, so decoding a second buffer with the
        # same converter produced duplicated feature contexts.
        self.features = []

        if data_set_id == DATA_SET_AGGREGATION_WITH_FEATURES:
            self.feature_list = decoder.decode_int_array()
            num = len(self.feature_list)
            # Slots are row-major: one slot per (sample, feature) pair
            slots = decoder.decode_int_array()
            for i in range(num):
                bin_assignment = []
                for row_id in range(self.num_samples):
                    _, bin_num = self.slot_to_bin(cuts, slots[row_id * num + i])
                    bin_assignment.append(bin_num)

                bin_size = self.get_bin_size(cuts, self.feature_list[i])
                feature_ctx = FeatureContext(self.feature_list[i], bin_assignment, bin_size)
                self.features.append(feature_ctx)
        elif data_set_id != DATA_SET_AGGREGATION:
            raise RuntimeError(f"Invalid DataSet: {data_set_id}")

        node_list = decoder.decode_int_array()
        sample_groups = {}
        for node in node_list:
            row_ids = decoder.decode_int_array()
            sample_groups[node] = row_ids

        return AggregationContext(self.features, sample_groups)

    def encode_aggregation_result(
        self, aggr_results: Dict[int, List[FeatureAggregationResult]], fl_ctx: FLContext
    ) -> bytes:
        """Encode per-node aggregation results as a DAM buffer.

        Emits the sorted node list followed by one float array (interleaved
        G, H per bin) for every feature of every node, in feature_list order.
        """
        encoder = DamEncoder(DATA_SET_AGGREGATION_RESULT)
        node_list = sorted(aggr_results.keys())
        encoder.add_int_array(node_list)

        for node in node_list:
            result_list = aggr_results.get(node)
            for f in self.feature_list:
                encoder.add_float_array(self.find_histo_for_feature(result_list, f))

        return encoder.finish()

    @staticmethod
    def get_bin_size(cuts: List[int], feature_id: int) -> int:
        """Number of bins owned by feature_id per the cumulative cuts array."""
        # FIX: annotation was the list literal [int], not a type
        return cuts[feature_id + 1] - cuts[feature_id]

    @staticmethod
    def slot_to_bin(cuts: List[int], slot: int) -> Tuple[int, int]:
        """Map a global slot number to (feature_index, local_bin_number).

        Raises:
            RuntimeError: if slot falls outside [0, cuts[-1]).
        """
        if slot < 0 or slot >= cuts[-1]:
            raise RuntimeError(f"Invalid slot {slot}, out of range [0-{cuts[-1]-1}]")

        for i in range(len(cuts) - 1):
            if cuts[i] <= slot < cuts[i + 1]:
                bin_num = slot - cuts[i]
                return i, bin_num

        raise RuntimeError(f"Logic error. Slot {slot}, out of range [0-{cuts[-1] - 1}]")

    @staticmethod
    def float_to_int(value: float) -> int:
        """Convert a float to its fixed-point int representation."""
        return int(value * SCALE_FACTOR)

    @staticmethod
    def int_to_float(value: int) -> float:
        """Inverse of float_to_int."""
        return value / SCALE_FACTOR

    @staticmethod
    def find_histo_for_feature(result_list: List[FeatureAggregationResult], feature_id: int) -> List[float]:
        """Flatten the (G, H) histogram of feature_id back to floats.

        Raises:
            RuntimeError: if feature_id is not present in result_list.
        """
        for result in result_list:
            if result.feature_id == feature_id:
                float_array = []
                for (g, h) in result.aggregated_hist:
                    float_array.append(ProcessorDataConverter.int_to_float(g))
                    float_array.append(ProcessorDataConverter.int_to_float(h))

                return float_array

        raise RuntimeError(f"Logic error. Feature {feature_id} not found in the list")
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.app_common.xgb.sec.dam import DamDecoder, DamEncoder

DATA_SET = 123456
INT_ARRAY = [123, 456, 789]
FLOAT_ARRAY = [1.2, 2.3, 3.4, 4.5]


class TestDam:
    def test_encode_decode(self):
        """Round-trip an int array and a float array through DAM."""
        encoder = DamEncoder(DATA_SET)
        encoder.add_int_array(INT_ARRAY)
        encoder.add_float_array(FLOAT_ARRAY)
        encoded = encoder.finish()

        decoder = DamDecoder(encoded)
        assert decoder.is_valid()
        assert decoder.get_data_set_id() == DATA_SET
        # Entries must come back in the order they were added
        assert decoder.decode_int_array() == INT_ARRAY
        assert decoder.decode_float_array() == FLOAT_ARRAY
from typing import Dict, List

import pytest

from nvflare.app_common.xgb.sec.dam import DamDecoder, DamEncoder
from nvflare.app_common.xgb.sec.data_converter import FeatureAggregationResult
from nvflare.app_common.xgb.sec.processor_data_converter import (
    DATA_SET_AGGREGATION_WITH_FEATURES,
    DATA_SET_GH_PAIRS,
    ProcessorDataConverter,
)


class TestDataConverter:
    @pytest.fixture()
    def data_converter(self):
        yield ProcessorDataConverter()

    @pytest.fixture()
    def gh_buffer(self):
        # 10 samples; (g, h) values interleaved
        gh = [0.1, 0.2, 1.2, 1.2, 2.1, 2.2, 3.1, 3.2, 4.1, 4.2, 5.1, 5.2, 6.1, 6.2, 7.1, 7.2, 8.1, 8.2, 9.1, 9.2]
        encoder = DamEncoder(DATA_SET_GH_PAIRS)
        encoder.add_float_array(gh)
        return encoder.finish()

    @pytest.fixture()
    def aggr_buffer(self):
        encoder = DamEncoder(DATA_SET_AGGREGATION_WITH_FEATURES)
        # Cumulative bin cuts: feature 0 -> slots [0,2), 1 -> [2,5), 2 -> [5,10)
        encoder.add_int_array([0, 2, 5, 10])
        # Feature ids present in this buffer
        encoder.add_int_array([0, 2])
        # Slot per (sample, feature), row-major: 10 samples x 2 features
        encoder.add_int_array([0, 5, 1, 9, 1, 6, 0, 7, 0, 9, 0, 8, 1, 5, 0, 6, 0, 8, 1, 5])
        # Node ids to build, then each node's row ids
        encoder.add_int_array([0, 1])
        encoder.add_int_array([0, 3, 6, 8])
        encoder.add_int_array([1, 2, 4, 5, 7, 9])
        return encoder.finish()

    @pytest.fixture()
    def aggr_results(self) -> Dict[int, List[FeatureAggregationResult]]:
        feature0 = [(1100000, 1200000), (1200000, 1300000)]
        feature2 = [(1100000, 1200000), (2100000, 2200000), (3100000, 3200000), (4100000, 4200000), (5100000, 5200000)]
        result_list = [FeatureAggregationResult(0, feature0), FeatureAggregationResult(2, feature2)]
        return {0: result_list, 1: result_list}

    def test_decode(self, data_converter, gh_buffer, aggr_buffer):
        pairs = data_converter.decode_gh_pairs(gh_buffer, None)
        assert len(pairs) == data_converter.num_samples

        context = data_converter.decode_aggregation_context(aggr_buffer, None)
        assert len(context.features) == 2

        first = context.features[0]
        assert first.feature_id == 0
        assert first.num_bins == 2
        assert first.sample_bin_assignment == [0, 1, 1, 0, 0, 0, 1, 0, 0, 1]

        second = context.features[1]
        assert second.feature_id == 2
        assert second.num_bins == 5
        assert second.sample_bin_assignment == [0, 4, 1, 2, 4, 3, 0, 1, 3, 0]

    def test_encode(self, data_converter, aggr_results):
        # Simulate the state of converter after decode call
        data_converter.feature_list = [0, 2]
        buffer = data_converter.encode_aggregation_result(aggr_results, None)

        decoder = DamDecoder(buffer)
        assert decoder.decode_int_array() == [0, 1]
        assert decoder.decode_float_array() == [1.1, 1.2, 1.2, 1.3]
        assert decoder.decode_float_array() == [1.1, 1.2, 2.1, 2.2, 3.1, 3.2, 4.1, 4.2, 5.1, 5.2]