-
Notifications
You must be signed in to change notification settings - Fork 46
/
train.py
25 lines (22 loc) · 1019 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import datetime
import utils.csv_record as csv_record
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import main
import loan_train
import image_train
import config
import random
def train(helper, start_epoch, local_model, target_model, is_poison, agent_name_keys):
    """Run one round of local training via the dataset-specific trainer.

    The trainer is selected from ``helper.params['type']``:

    * ``config.TYPE_LOAN`` -> ``loan_train.LoanTrain``
    * ``config.TYPE_CIFAR``, ``config.TYPE_MNIST`` or
      ``config.TYPE_TINYIMAGENET`` -> ``image_train.ImageTrain``

    Every argument is forwarded unchanged, in the same order, to the selected
    trainer.

    Returns:
        A 2-tuple ``(epochs_submit_update_dict, num_samples_dict)`` unpacked
        from the selected trainer's result. When the type matches none of the
        constants above, no trainer runs and ``({}, {})`` is returned.
    """
    data_type = helper.params['type']
    trainer_args = (helper, start_epoch, local_model, target_model,
                    is_poison, agent_name_keys)

    if data_type == config.TYPE_LOAN:
        trainer = loan_train.LoanTrain
    elif (data_type == config.TYPE_CIFAR
          or data_type == config.TYPE_MNIST
          or data_type == config.TYPE_TINYIMAGENET):
        trainer = image_train.ImageTrain
    else:
        # NOTE(review): unrecognised dataset types silently yield empty
        # results; confirm callers expect this rather than an error.
        return {}, {}

    # Unpack explicitly so a trainer that does not return exactly two values
    # fails here.
    update_by_agent, samples_by_agent = trainer(*trainer_args)
    return update_by_agent, samples_by_agent