-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathParams.py
30 lines (29 loc) · 2.3 KB
/
Params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import argparse
def parse_args(argv=None):
    """Build the command-line parser for the model hyper-parameters and parse it.

    Args:
        argv: Sequence of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which is the
            original behaviour. Passing an explicit list (e.g. ``[]``) lets
            tests and notebooks obtain the defaults without depending on the
            real command line.

    Returns:
        argparse.Namespace with one attribute per option declared below
        (``lr``, ``batch``, ``sBatch``, ..., ``msg_drop``).
    """
    parser = argparse.ArgumentParser(description='Model Params')
    parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--batch', default=2048, type=int, help='batch size')
    parser.add_argument('--sBatch', default=1024, type=int, help='divider for batch size in social graph')
    parser.add_argument('--test_batch', default=256, type=int, help='batch size in testset')
    parser.add_argument('--tstEpoch', default=3, type=int, help='number of epoch to test while training')
    parser.add_argument('--reg', default=1e-5, type=float, help='weight decay regularizer')
    parser.add_argument('--epoch', default=100, type=int, help='number of epochs')
    # NOTE(review): both --reg and --decay are described as "weight decay";
    # confirm with the training code which one is a schedule/decay rate.
    parser.add_argument('--decay', default=0.96, type=float, help='weight decay rate')
    parser.add_argument('--save_name', default='tem', help='file name to save model and training record')
    parser.add_argument('--latdim', default=16, type=int, help='embedding size')
    parser.add_argument('--load_model', default=None, help='model name to load')
    parser.add_argument('--data', default='CiaoDVD', type=str, help='name of dataset')
    parser.add_argument('--gpu', default='0', type=str, help='gpu indices')
    # Integer defaults are given as int literals to match type=int. argparse
    # would convert the former string defaults ('5', '1024') anyway, so the
    # parsed values are unchanged.
    parser.add_argument('--patience', default=5, type=int, help='early stopping patience')
    parser.add_argument('--seed', default=1024, type=int, help='random seed')
    parser.add_argument('--gnn_layer', default=2, type=int, help='number of gnn layers')
    parser.add_argument('--uugnn_layer', default=2, type=int, help='number of gnn layers for social graph')
    parser.add_argument('--topk', default=10, type=int, help='K of top K')
    parser.add_argument('--leaky', default=0.5, type=float, help='slope for leaky relu')
    parser.add_argument('--uuPre_reg', default=1e-7, type=float, help='weights for social graph prediction regularization')
    parser.add_argument('--sal_reg', default=1e-7, type=float, help='weights for SAL regularization')
    parser.add_argument('--dropRate', default=0.5, type=float, help='rate for dropout layer')
    parser.add_argument('--edge_drop', default=0.5, type=float, help='rate for dropout edge')
    parser.add_argument('--msg_drop', default=0.5, type=float, help='rate for dropout message')
    return parser.parse_args(argv)
# Parse the command line once, when this module is first imported, and expose
# the resulting namespace as the module-level `args`.
# NOTE(review): because this runs at import time, importing the module reads
# sys.argv, and argparse exits on unrecognized flags. Confirm this is intended
# for every importer (e.g. test runners or notebooks with their own argv).
args = parse_args()