-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add start_record interface #3128
Changes from 4 commits
3322e8b
f15c62b
7327697
0379f03
59f43d5
e888456
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ | |
be used in user program. | ||
""" | ||
|
||
__all__ = ['np_array', 'text_file', "recordio"] | ||
__all__ = ['np_array', 'text_file', "cloud_reader"] | ||
|
||
|
||
def np_array(x): | ||
|
@@ -81,35 +81,41 @@ def reader(): | |
return dec.buffered(reader, buf_size) | ||
|
||
|
||
def recordio(paths, buf_size=100): | ||
pass_num = 0 | ||
|
||
|
||
def cloud_reader(paths, etcd_endpoints, timeout_sec=5, buf_size=64): | ||
""" | ||
Creates a data reader that outputs record one one by one | ||
from given local or cloud recordio path. | ||
Create a data reader that yields records one by one from
the paths: | ||
:path: path of recordio files. | ||
:etcd_endpoints: the endpoints for etcd cluster | ||
:returns: data reader of recordio files. | ||
|
||
.. code-block:: python | ||
from paddle.v2.reader.creator import cloud_reader | ||
etcd_endpoints = "http://127.0.0.1:2379" | ||
trainer.train.( | ||
reader=cloud_reader(["/work/dataset/uci_housing/uci_housing*"], etcd_endpoints), | ||
) | ||
""" | ||
import os | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Need some demo code of how to use this reader. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
import paddle.v2.master.client as cloud | ||
|
||
if "KUBERNETES_SERVICE_HOST" not in os.environ.keys(): | ||
return recordio_local(paths) | ||
|
||
host_name = "MASTER_SERVICE_HOST" | ||
if host_name not in os.environ.keys(): | ||
raise Exception('not find ' + host_name + ' in environment variable.') | ||
|
||
addr = os.environ(host) | ||
import cPickle as pickle | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we need this? I do not see it being used anywhere. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry, it's my mistake, we need |
||
import paddle.v2.master as master | ||
c = master.client(etcd_endpoints, timeout_sec, buf_size) | ||
c.set_dataset(paths) | ||
|
||
def reader(): | ||
c = cloud(addr, buf_size) | ||
c.set_dataset(paths) | ||
global pass_num | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Put There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
c.paddle_start_get_records(pass_num) | ||
pass_num += 1 | ||
|
||
while True: | ||
r, err = client.next_record() | ||
if err < 0: | ||
r, e = c.next_record() | ||
if not r: | ||
if e != -2: | ||
print "get record error: ", e | ||
break | ||
yield r | ||
|
||
c.release() | ||
yield pickle.loads(r) | ||
|
||
return reader |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one bye one => one by one
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Thanks!