Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add start_record interface #3128

Merged
merged 6 commits into from
Aug 1, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 9 additions & 18 deletions go/pserver/client/c/test/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,11 @@
import paddle.v2.master as master
import os
import cPickle as pickle
from paddle.v2.reader.creator import cloud_reader

etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
etcd_endpoint = "http://" + etcd_ip + ":2379"
print "connecting to master, etcd endpoints: ", etcd_endpoint
master_client = master.client(etcd_endpoint, 5, 64)


def cloud_reader():
global master_client
master_client.set_dataset(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*"], passes=30)
while 1:
r, e = master_client.next_record()
if not r:
if e != -2: # other errors
print "get record error:", e
break
yield pickle.loads(r)
etcd_endpoints = "http://" + etcd_ip + ":2379"
print "etcd endpoints: ", etcd_endpoints


def main():
Expand Down Expand Up @@ -49,7 +36,7 @@ def main():
parameters=parameters,
update_equation=optimizer,
is_local=False,
pserver_spec=etcd_endpoint,
pserver_spec=etcd_endpoints,
use_etcd=True)

# event_handler to print training and testing info
Expand All @@ -75,7 +62,11 @@ def event_handler(event):
trainer.train(
reader=paddle.batch(
paddle.reader.shuffle(
cloud_reader, buf_size=500), batch_size=2),
cloud_reader(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing*"],
etcd_endpoints),
buf_size=500),
batch_size=2),
feeding={'x': 0,
'y': 1},
event_handler=event_handler,
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/v2/master/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,6 @@ def next_record(self):
# Memory created from C should be freed.
get_c_lib().mem_free(ret.contents)
return record, 0

def paddle_start_get_records(self, pass_id):
get_c_lib().paddle_start_get_records(self.c, pass_id)
48 changes: 27 additions & 21 deletions python/paddle/v2/reader/creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
be used in user program.
"""

__all__ = ['np_array', 'text_file', "recordio"]
__all__ = ['np_array', 'text_file', "cloud_reader"]


def np_array(x):
Expand Down Expand Up @@ -81,35 +81,41 @@ def reader():
return dec.buffered(reader, buf_size)


def recordio(paths, buf_size=100):
pass_num = 0


def cloud_reader(paths, etcd_endpoints, timeout_sec=5, buf_size=64):
"""
Creates a data reader that outputs record one one by one
from given local or cloud recordio path.
Create a data reader that yield a record one by one from
the paths:
:path: path of recordio files.
:etcd_endpoints: the endpoints for etcd cluster
:returns: data reader of recordio files.

.. code-block:: python
from paddle.v2.reader.creator import cloud_reader
etcd_endpoints = "http://127.0.0.1:2379"
trainer.train.(
reader=cloud_reader(["/work/dataset/uci_housing/uci_housing*"], etcd_endpoints),
)
"""
import os
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need some demo code of how to use this reader.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

import paddle.v2.master.client as cloud

if "KUBERNETES_SERVICE_HOST" not in os.environ.keys():
return recordio_local(paths)

host_name = "MASTER_SERVICE_HOST"
if host_name not in os.environ.keys():
raise Exception('not find ' + host_name + ' in environment variable.')

addr = os.environ(host)
import cPickle as pickle
import paddle.v2.master as master
c = master.client(etcd_endpoints, timeout_sec, buf_size)
c.set_dataset(paths)

def reader():
c = cloud(addr, buf_size)
c.set_dataset(paths)
global pass_num
c.paddle_start_get_records(pass_num)
pass_num += 1

while True:
r, err = client.next_record()
if err < 0:
r, e = c.next_record()
if not r:
if e != -2:
print "get record error: ", e
break
yield r

c.release()
yield pickle.loads(r)

return reader
9 changes: 0 additions & 9 deletions python/paddle/v2/reader/tests/creator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,5 @@ def test_text_file(self):
self.assertEqual(e, str(idx * 2) + " " + str(idx * 2 + 1))


class TestRecordIO(unittest.TestCase):
def test_recordio(self):
path = os.path.join(
os.path.dirname(__file__), "test_recordio_creator.dat")
reader = paddle.v2.reader.creator.recordio([path])
for idx, r in enumerate(reader()):
self.assertSequenceEqual(r, str(idx))


if __name__ == '__main__':
unittest.main()