#
# Copyright 2019 The FATE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse

from fate_client.pipeline import FateFlowPipeline
from fate_client.pipeline.components.fate import DataSplit, PSI, Reader, CoordinatedLR
from fate_client.pipeline.utils import test_utils


def main(config="../config.yaml", namespace=""):
    if isinstance(config, str):
        config = test_utils.load_job_config(config)
    parties = config.parties
    guest = parties.guest[0]
    host = parties.host[0]
    arbiter = parties.arbiter[0]

    pipeline = FateFlowPipeline().set_parties(guest=guest, host=host, arbiter=arbiter)
    if config.task_cores:
        pipeline.conf.set("task_cores", config.task_cores)
    if config.timeout:
        pipeline.conf.set("timeout", config.timeout)

    # Read the uploaded hetero breast-cancer tables for both parties.
    reader_0 = Reader("reader_0", runtime_parties=dict(guest=guest, host=host))
    reader_0.guest.task_parameters(
        namespace=f"experiment{namespace}",
        name="breast_hetero_guest"
    )
    reader_0.hosts[0].task_parameters(
        namespace=f"experiment{namespace}",
        name="breast_hetero_host"
    )

    reader_1 = Reader("reader_1", runtime_parties=dict(guest=guest, host=host))
    reader_1.guest.task_parameters(
        namespace=f"experiment{namespace}",
        name="breast_hetero_guest"
    )
    reader_1.hosts[0].task_parameters(
        namespace=f"experiment{namespace}",
        name="breast_hetero_host"
    )

    # Align sample IDs across guest and host before splitting.
    psi_0 = PSI("psi_0", input_data=reader_0.outputs["output_data"])
    psi_1 = PSI("psi_1", input_data=reader_1.outputs["output_data"])

    # Split by fraction: 60% train / 40% test, stratified on the label.
    data_split_0 = DataSplit(
        "data_split_0",
        train_size=0.6,
        validate_size=0.0,
        test_size=0.4,
        stratified=True,
        input_data=psi_0.outputs["output_data"],
        runtime_parties=dict(guest=guest, host=host)
    )

    # Split by absolute counts: 200 train / 50 test samples;
    # hetero_sync keeps the split consistent across parties.
    data_split_1 = DataSplit(
        "data_split_1",
        train_size=200,
        test_size=50,
        stratified=True,
        hetero_sync=True,
        input_data=psi_1.outputs["output_data"]
    )

    # Train a coordinated logistic regression on the count-based split.
    lr_0 = CoordinatedLR(
        "lr_0",
        epochs=10,
        batch_size=300,
        optimizer={"method": "SGD", "optimizer_params": {"lr": 0.1},
                   "penalty": "l2", "alpha": 0.001},
        init_param={"fit_intercept": True, "method": "zeros"},
        train_data=data_split_1.outputs["train_output_data"],
        validate_data=data_split_1.outputs["validate_output_data"],
        learning_rate_scheduler={"method": "linear",
                                 "scheduler_params": {"start_factor": 0.7,
                                                      "total_iters": 100}}
    )

    pipeline.add_tasks([reader_0, reader_1, psi_0, psi_1,
                        data_split_0, data_split_1, lr_0])
    pipeline.compile()
    # print(pipeline.get_dag())
    pipeline.fit()


if __name__ == "__main__":
    parser = argparse.ArgumentParser("PIPELINE DEMO")
    parser.add_argument("--config", type=str, default="../config.yaml",
                        help="config file")
    parser.add_argument("--namespace", type=str, default="",
                        help="namespace for data stored in FATE")
    args = parser.parse_args()
    main(config=args.config, namespace=args.namespace)
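
# A minimal usage sketch (the script filename below is a placeholder; this
# assumes a reachable FATE Flow service configured in ../config.yaml and that
# the breast_hetero_guest / breast_hetero_host tables have already been
# uploaded under the "experiment" namespace):
#
#     python this_script.py --config ../config.yaml
#
# Passing --namespace appends a suffix to the data namespace, which lets
# several test runs keep their uploaded tables separate.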