Skip to content

Commit

Permalink
fix windows env (PaddlePaddle#312)
Browse files Browse the repository at this point in the history
* fix windows env

* fix windows envs
  • Loading branch information
xiteng1988 authored May 27, 2020
1 parent cf4d540 commit 3dfa717
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions demo/slimfacenet/train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ def train(exe, train_program, train_out, test_program, test_out, args):


def build_program(program, startup, args, is_train=True):
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
if args.use_gpu:
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
else:
num_trainers = 1
places = fluid.cuda_places() if args.use_gpu else fluid.CPUPlace()

train_dataset = CASIA_Face(root=args.train_data_dir)
Expand All @@ -166,7 +169,7 @@ def build_program(program, startup, args, is_train=True):
image = fluid.data(
name='image', shape=[-1, 3, 112, 96], dtype='float32')
label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
train_reader = paddle.fluid.io.batch(
train_reader = fluid.io.batch(
train_dataset.reader,
batch_size=args.train_batchsize // num_trainers,
drop_last=False)
Expand All @@ -187,7 +190,7 @@ def build_program(program, startup, args, is_train=True):
else:
nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr)
test_reader = paddle.fluid.io.batch(
test_reader = fluid.io.batch(
test_dataset.reader,
batch_size=args.test_batchsize,
drop_last=False)
Expand Down Expand Up @@ -231,7 +234,7 @@ def build_program(program, startup, args, is_train=True):
def quant_val_reader_batch():
nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr)
test_reader = paddle.fluid.io.batch(
test_reader = fluid.io.batch(
test_dataset.reader, batch_size=1, drop_last=False)
shuffle_reader = fluid.io.shuffle(test_reader, 3)

Expand Down Expand Up @@ -298,14 +301,16 @@ def main():
help='The path of the extract features save, must be .mat file')
args = parser.parse_args()

num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
if args.use_gpu:
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
else:
num_trainers = 1
print(args)
print('num_trainers: {}'.format(num_trainers))
if args.save_ckpt == None:
args.save_ckpt = 'output'
if not os.path.exists(args.save_ckpt):
subprocess.call(['mkdir', '-p', args.save_ckpt])

if not os.path.isdir(args.save_ckpt):
os.makedirs(args.save_ckpt)
with open(os.path.join(args.save_ckpt, 'log.txt'), 'w+') as f:
f.writelines(str(args) + '\n')
f.writelines('num_trainers: {}'.format(num_trainers) + '\n')
Expand Down Expand Up @@ -346,7 +351,7 @@ def main():
executor=exe)
nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr)
test_reader = paddle.fluid.io.batch(
test_reader = fluid.io.batch(
test_dataset.reader,
batch_size=args.test_batchsize,
drop_last=False)
Expand Down

0 comments on commit 3dfa717

Please sign in to comment.