@@ -277,50 +277,30 @@ def cloud_train(train_file_pattern,
   print(job_request)
 
 
-def local_predict(model_dir, prediction_input_file):
+def local_predict():
   """Runs local prediction.
 
   Runs local prediction in memory and prints the results to the screen. For
   running prediction on a large dataset or saving the results, run
   local_batch_prediction or batch_prediction.
 
   Args:
-    model_dir: Path to folder that contains the model. This is usully OUT/model
-      where OUT is the value of output_dir when local_training was ran.
-    prediction_input_file: csv file that has the same schem as the input
-      files used during local_preprocess, except that the target column is
-      removed.
+
   """
-  pass
-  #TODO(brandondutra): remove this hack once cloudml 1.8 is released.
-  # Check that the model folder has a metadata.yaml file. If not, copy it.
-  # if not os.path.isfile(os.path.join(model_dir, 'metadata.yaml')):
-  #   shutil.copy2(os.path.join(model_dir, 'metadata.json'),
-  #                os.path.join(model_dir, 'metadata.yaml'))
+  # Save the instances to a file, call local batch prediction, and print it back
+
 
-  # cmd = ['gcloud beta ml local predict',
-  #        '--model-dir=%s' % model_dir,
-  #        '--text-instances=%s' % prediction_input_file]
-  # print('Local prediction, running command: %s' % ' '.join(cmd))
-  # _run_cmd(' '.join(cmd))
-  # print('Local prediction done.')
 
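The new local_predict body is only a plan for now: save the instances to a file, call local batch prediction, and print the results back. A minimal sketch of that plan, assuming the function eventually takes the model location and an instances file again and delegates to the local_batch_predict defined later in this file; the name, signature, and temp-directory handling below are illustrative, not part of this commit.

import os
import shutil
import tempfile

def local_predict_sketch(model_dir, prediction_input_file):
  """Hypothetical version of the plan described in the comment above."""
  output_dir = tempfile.mkdtemp()
  try:
    # Delegate to local_batch_predict and write a single, unsharded csv file
    # so the results are easy to echo back to the screen.
    local_batch_predict(model_dir=model_dir,
                        prediction_input_file=prediction_input_file,
                        output_dir=output_dir,
                        shard_files=False)
    for name in sorted(os.listdir(output_dir)):
      with open(os.path.join(output_dir, name)) as f:
        print(f.read())
  finally:
    shutil.rmtree(output_dir)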
 
-def cloud_predict(model_name, prediction_input_file, version_name=None):
+def cloud_predict():
   """Use Online prediction.
 
   Runs online prediction in the cloud and prints the results to the screen. For
   running prediction on a large dataset or saving the results, run
   local_batch_prediction or batch_prediction.
 
   Args:
-    model_dir: Path to folder that contains the model. This is usully OUT/model
-      where OUT is the value of output_dir when local_training was ran.
-    prediction_input_file: csv file that has the same schem as the input
-      files used during local_preprocess, except that the target column is
-      removed.
-    vsersion_name: Optional version of the model to use. If None, the default
-      version is used.
+
 
   Before using this, the model must be created. This can be done by running
   two gcloud commands:
@@ -334,91 +314,67 @@ def cloud_predict(model_name, prediction_input_file, version_name=None):
   Note that the model must be on GCS.
   """
   pass
-  # cmd = ['gcloud beta ml predict',
-  #        '--model=%s' % model_name,
-  #        '--text-instances=%s' % prediction_input_file]
-  # if version_name:
-  #   cmd += ['--version=%s' % version_name]
 
-  # print('CloudML online prediction, running command: %s' % ' '.join(cmd))
-  # _run_cmd(' '.join(cmd))
-  # print('CloudML online prediction done.')
 
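The two gcloud commands that the cloud_predict docstring points to fall outside this hunk. Based on the gcloud beta ml command group used in the code removed above, model creation most likely looked roughly like the sketch below; the model name, version name, and GCS path are placeholders, not values taken from this commit.

import subprocess

# Placeholders: substitute your own model name, version name, and the GCS
# location of the trained model (command shapes assumed from the gcloud beta
# ml command group of that era).
MODEL_NAME = 'my_model'
VERSION_NAME = 'v1'
MODEL_GCS_PATH = 'gs://my-bucket/training_output/model'

# 1) Create the model resource.
subprocess.check_call(
    ['gcloud', 'beta', 'ml', 'models', 'create', MODEL_NAME])

# 2) Create a version that points at the exported model on GCS.
subprocess.check_call(
    ['gcloud', 'beta', 'ml', 'versions', 'create', VERSION_NAME,
     '--model', MODEL_NAME,
     '--origin', MODEL_GCS_PATH])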
 
-def local_batch_predict(model_dir, prediction_input_file, output_dir):
+def local_batch_predict(model_dir, prediction_input_file, output_dir,
+                        batch_size=1000, shard_files=True):
   """Local batch prediction.
 
   Args:
-    model_dir: local path to trained model.
-    prediction_input_file: File path to input files. May contain a file pattern.
-      Only csv files are supported, and the scema must match what was used
-      in preprocessing except that the target column is removed.
-    output_dir: folder to save results to.
+    model_dir: local file path to trained model. Usually, this is
+      training_output_dir/model.
+    prediction_input_file: csv file pattern to a local file.
+    output_dir: local output location to save the results.
+    batch_size: Int. How many instances to run in memory at once. Larger values
+      mean better performance but more memory consumed.
+    shard_files: If False, the output files are not sharded.
   """
-  pass
-  #TODO(brandondutra): remove this hack once cloudml 1.8 is released.
-  # Check that the model folder has a metadata.yaml file. If not, copy it.
-  # if not os.path.isfile(os.path.join(model_dir, 'metadata.yaml')):
-  #   shutil.copy2(os.path.join(model_dir, 'metadata.json'),
-  #                os.path.join(model_dir, 'metadata.yaml'))
+  cmd = ['predict.py',
+         '--predict_data=%s' % prediction_input_file,
+         '--trained_model_dir=%s' % model_dir,
+         '--output_dir=%s' % output_dir,
+         '--output_format=csv',
+         '--batch_size=%s' % str(batch_size)]
+
+  if shard_files:
+    cmd.append('--shard_files')
+  else:
+    cmd.append('--no-shard_files')
 
-  # cmd = ['python -m google.cloud.ml.dataflow.batch_prediction_main',
-  #        '--input_file_format=text',
-  #        '--input_file_patterns=%s' % prediction_input_file,
-  #        '--output_location=%s' % output_dir,
-  #        '--model_dir=%s' % model_dir]
+  print('Starting local batch prediction.')
+  predict.predict.main(cmd)
+  print('Local batch prediction done.')
 
-  # print('Local batch prediction, running command: %s' % ' '.join(cmd))
-  # _run_cmd(' '.join(cmd))
-  # print('Local batch prediction done.')
 
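A short usage sketch for the new local_batch_predict signature; the local paths below are hypothetical, and model_dir follows the training_output_dir/model convention noted in the docstring.

# Hypothetical local paths.
local_batch_predict(model_dir='./training_output/model',
                    prediction_input_file='./data/predict-*.csv',
                    output_dir='./batch_predictions',
                    batch_size=1000,     # larger batches run faster but use more memory
                    shard_files=False)   # write a single, unsharded output file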
 
-def cloud_batch_predict(model_name, prediction_input_file, output_dir, region,
-                        job_name=None, version_name=None):
-  """Cloud batch prediction.
+def cloud_batch_predict(model_dir, prediction_input_file, output_dir,
+                        batch_size=1000, shard_files=True):
+  """Cloud batch prediction. Submits a Dataflow job.
 
   Args:
-    model_name: name of the model. The model must already exist.
-    prediction_input_file: File path to input files. May contain a file pattern.
-      Only csv files are supported, and the scema must match what was used
-      in preprocessing except that the target column is removed. Files must
-      be on GCS
-    output_dir: GCS folder to safe results to.
-    region: GCP compute region to run the batch job. Try using your default
-      region first, as this cloud batch prediction is not avaliable in all
-      regions.
-    job_name: job name used for the cloud job.
-    version_name: model version to use. If node, the default version of the
-      model is used.
+    model_dir: GCS file path to trained model. Usually, this is
+      training_output_dir/model.
+    prediction_input_file: csv file pattern to a GCS file.
+    output_dir: Location to save the results on GCS.
+    batch_size: Int. How many instances to run in memory at once. Larger values
+      mean better performance but more memory consumed.
+    shard_files: If False, the output files are not sharded.
   """
-  pass
-  # job_name = job_name or ('structured_data_batch_predict_' +
-  #                         datetime.datetime.now().strftime('%Y%m%d%H%M%S'))
-
-  # if (not prediction_input_file.startswith('gs://') or
-  #     not output_dir.startswith('gs://')):
-  #   print('ERROR: prediction_input_file and output_dir must point to a '
-  #         'location on GCS.')
-  #   return
-
-  # cmd = ['gcloud beta ml jobs submit prediction %s' % job_name,
-  #        '--model=%s' % model_name,
-  #        '--region=%s' % region,
-  #        '--data-format=TEXT',
-  #        '--input-paths=%s' % prediction_input_file,
-  #        '--output-path=%s' % output_dir]
-  # if version_name:
-  #   cmd += ['--version=%s' % version_name]
-
-  # print('CloudML batch prediction, running command: %s' % ' '.join(cmd))
-  # _run_cmd(' '.join(cmd))
-  # print('CloudML batch prediction job submitted.')
-
-  # if _is_in_IPython():
-  #   import IPython
-
-  #   dataflow_url = ('https://console.developers.google.com/ml/jobs?project=%s'
-  #                   % _default_project())
-  #   html = ('<p>Click <a href="%s" target="_blank">here</a> to track '
-  #           'the prediction job %s.</p><br/>' % (dataflow_url, job_name))
-  #   IPython.display.display_html(html, raw=True)
+  cmd = ['predict.py',
+         '--cloud',
+         '--project_id=%s' % _default_project(),
+         '--predict_data=%s' % prediction_input_file,
+         '--trained_model_dir=%s' % model_dir,
+         '--output_dir=%s' % output_dir,
+         '--output_format=csv',
+         '--batch_size=%s' % str(batch_size)]
+
+  if shard_files:
+    cmd.append('--shard_files')
+  else:
+    cmd.append('--no-shard_files')
+
+  print('Starting cloud batch prediction.')
+  predict.predict.main(cmd)
+  print('See above link for job status.')
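A matching usage sketch for the cloud path; the GCS locations below are placeholders, and the project id is picked up from _default_project() inside the function as shown above.

# Placeholder GCS locations; inputs and outputs must live on GCS because the
# work runs as a Dataflow job.
cloud_batch_predict(model_dir='gs://my-bucket/training_output/model',
                    prediction_input_file='gs://my-bucket/data/predict-*.csv',
                    output_dir='gs://my-bucket/batch_predictions',
                    batch_size=1000,
                    shard_files=True)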