17
17
import argparse
18
18
19
19
from fate_client .pipeline import FateFlowPipeline
20
- from fate_client .pipeline .components .fate import CoordinatedLR , Intersection
20
+ from fate_client .pipeline .components .fate import CoordinatedLR , PSI
21
21
from fate_client .pipeline .components .fate import Evaluation
22
22
from fate_client .pipeline .interface import DataWarehouseChannel
23
23
from fate_client .pipeline .utils import test_utils
@@ -48,11 +48,11 @@ def main(config="../../config.yaml", param="./lr_config.yaml", namespace=""):
48
48
if config .timeout :
49
49
pipeline .conf .set ("timeout" , config .timeout )
50
50
51
- intersect_0 = Intersection ( "intersect_0" , method = "raw " )
52
- intersect_0 .guest .component_setting (input_data = DataWarehouseChannel (name = guest_train_data ["name" ],
53
- namespace = guest_train_data ["namespace" ]))
54
- intersect_0 .hosts [0 ].component_setting (input_data = DataWarehouseChannel (name = host_train_data ["name" ],
55
- namespace = host_train_data ["namespace" ]))
51
+ psi_0 = PSI ( "psi_0 " )
52
+ psi_0 .guest .component_setting (input_data = DataWarehouseChannel (name = guest_train_data ["name" ],
53
+ namespace = guest_train_data ["namespace" ]))
54
+ psi_0 .hosts [0 ].component_setting (input_data = DataWarehouseChannel (name = host_train_data ["name" ],
55
+ namespace = host_train_data ["namespace" ]))
56
56
57
57
lr_param = {
58
58
}
@@ -68,10 +68,10 @@ def main(config="../../config.yaml", param="./lr_config.yaml", namespace=""):
68
68
}
69
69
lr_param .update (config_param )
70
70
lr_0 = CoordinatedLR ("lr_0" ,
71
- train_data = intersect_0 .outputs ["output_data" ],
71
+ train_data = psi_0 .outputs ["output_data" ],
72
72
** lr_param )
73
73
lr_1 = CoordinatedLR ("lr_1" ,
74
- test_data = intersect_0 .outputs ["output_data" ],
74
+ test_data = psi_0 .outputs ["output_data" ],
75
75
input_model = lr_0 .outputs ["output_model" ])
76
76
77
77
evaluation_0 = Evaluation ("evaluation_0" ,
@@ -80,7 +80,7 @@ def main(config="../../config.yaml", param="./lr_config.yaml", namespace=""):
80
80
metrics = ["auc" , "binary_precision" , "binary_accuracy" , "binary_recall" ],
81
81
input_data = lr_0 .outputs ["train_output_data" ])
82
82
83
- pipeline .add_task (intersect_0 )
83
+ pipeline .add_task (psi_0 )
84
84
pipeline .add_task (lr_0 )
85
85
pipeline .add_task (lr_1 )
86
86
pipeline .add_task (evaluation_0 )
0 commit comments