Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add data format #872

Closed
wants to merge 20 commits into from
Closed
Empty file added python/paddle/data/__init__.py
Empty file.
103 changes: 103 additions & 0 deletions python/paddle/data/amazon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-

# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docstring

添加整体文件的注释。docstring

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


import shutil
import os
import sys
import zipfile
import collections
import numpy as np
from six.moves import urllib
import stat

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

添加 __all__

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不需要添加all吧

source_url='http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Electronics_5.json.gz'
moses_url='https://github.com/moses-smt/mosesdecoder/archive/master.zip'
file_source = "mosesdecoder-master"
def fetch():
    """Download the Amazon reviews dataset and the moses decoder tools.

    Both archives are fetched into the per-user cache directory (created
    on demand) and unpacked there.

    Returns:
        The directory that holds the downloaded data.
    """
    source_name = "amazon"
    # Resolve (and create if necessary) the cache directory for this dataset.
    data_home = set_data_path(source_name)
    # Both archives land in the same directory; data_download skips the
    # network round-trip when the file is already present.
    data_download(data_home, source_url)
    data_download(data_home, moses_url)
    return data_home

def _unpickle(file_path):
    """Deserialize one pickled object from *file_path*.

    On Python 3 the ``encoding='bytes'`` flag is passed so that archives
    pickled under Python 2 still load.

    Returns:
        The unpickled object.
    """
    # Bug fix: the original called cPickle, which was never imported
    # (NameError on every call).  `pickle` works on both Python 2 and 3.
    import pickle
    with open(file_path, mode='rb') as file:
        if sys.version_info < (3,):
            data = pickle.load(file)
        else:
            data = pickle.load(file, encoding='bytes')
        return data

def set_data_path(source_name):
    """Return (creating it if necessary) the cache directory for *source_name*.

    Prefers ``~/.paddle``; falls back to ``/tmp/.paddle`` when the home
    base is not writable.

    Returns:
        Absolute path of the dataset directory.
    """
    # Bug fix: the directory component used to be ' .paddle' (leading
    # space), inconsistent with the cifar/mnist modules.
    data_base = os.path.expanduser(os.path.join('~', '.paddle'))
    if not os.access(data_base, os.W_OK):
        data_base = os.path.join('/tmp', '.paddle')
    datadir = os.path.join(data_base, source_name)
    print(datadir)  # print() form works on both Python 2 and 3
    if not os.path.exists(datadir):
        os.makedirs(datadir)
    return datadir

def data_download(download_dir, source_url):
    """Ensure the archive named by *source_url* is present under *download_dir*.

    Downloads the file when missing, then (for ``.zip`` archives) extracts
    the contents into *download_dir*, fixes permissions and deletes the
    archive.

    Returns:
        *download_dir* (unchanged).
    """
    src_file = source_url.strip().split('/')[-1]
    file_path = os.path.join(download_dir, src_file)

    if not os.path.exists(file_path):
        # urlretrieve stores to a temp name in the CWD; rename it and move
        # it into the cache directory.
        temp_file_name, _ = download_with_urlretrieve(source_url)
        os.rename(temp_file_name, src_file)
        move_files(src_file, download_dir)
        print("Download finished, Extracting files.")
        _extract_if_zip(file_path, download_dir)
        print("Unpacking done!")
    else:
        _extract_if_zip(file_path, download_dir)
        print("Data has been already downloaded and unpacked!")
    return download_dir


def _extract_if_zip(file_path, download_dir):
    # Helper (deduplicates the two identical branches of data_download):
    # unpack a .zip archive into download_dir, make every entry readable,
    # then remove the archive itself.  Non-zip files are left untouched.
    if 'zip' not in os.path.basename(file_path):
        return
    # `with` closes the archive handle (the original leaked it).
    with zipfile.ZipFile(file_path, 'r') as archive:
        for member in archive.infolist():
            archive.extract(member, download_dir)
            fpath = os.path.join(download_dir, member.filename)
            os.chmod(fpath, stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH)
    os.remove(file_path)

def move_files(source_dire, target_dire):
    """Relocate *source_dire* into *target_dire*.

    Thin convenience wrapper around :func:`shutil.move`.
    """
    shutil.move(source_dire, target_dire)

def download_with_urlretrieve(url, filename=None):
    """Fetch *url* into *filename* (or a temporary file when ``None``).

    Returns the ``(local_filename, headers)`` pair from urlretrieve.
    """
    result = urllib.request.urlretrieve(url, filename)
    return result


if __name__ == '__main__':
    # print() form keeps this runnable under both Python 2 and 3
    # (the original used a Python-2-only print statement).
    path = fetch()
    print(path)
100 changes: 100 additions & 0 deletions python/paddle/data/cifar_10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

似乎要修改下 python/CMakeLists.txt 类似
file(GLOB UTILS_PY_FILES . ./paddle/data/*.py)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个不需要,我问过廖纲,只需要修改setup.py.in 就可以。本地测试了,可以load上

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# -*- coding:utf-8 -*-

# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import shutil
import os
import sys
import tarfile
import zipfile
import collections
import numpy as np
from six.moves import urllib

source_url='https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
source_file = "cifar-10-batches-py"
label_map = {
0: "airplane",
1: "automobile",
2: "bird",
3: "cat",
4: "deer",
5: "dog",
6: "frog",
7: "horse",
8: "ship",
9: "truck"
}

def fetch():
    """Download and unpack the CIFAR-10 python batches.

    Returns:
        The directory the archive was extracted into.
    """
    source_name = "cifar"
    # Resolve (and create if necessary) the per-user cache directory.
    data_home = set_data_path(source_name)
    # data_download skips the network round-trip when the archive exists.
    return data_download(data_home, source_url)

def _unpickle(file_path):
    """Deserialize one pickled batch from *file_path*.

    On Python 3 the ``encoding='bytes'`` flag is passed so that batches
    pickled under Python 2 still load.

    Returns:
        The unpickled object.
    """
    # Bug fix: the original called cPickle, which was never imported
    # (NameError on every call).  `pickle` works on both Python 2 and 3.
    import pickle
    with open(file_path, mode='rb') as file:
        if sys.version_info < (3,):
            data = pickle.load(file)
        else:
            data = pickle.load(file, encoding='bytes')
        return data

def set_data_path(source_name):
    """Return (creating it if necessary) the cache directory for *source_name*.

    Prefers ``~/.paddle``; falls back to ``/tmp/.paddle`` when the home
    base is not writable.

    Returns:
        Absolute path of the dataset directory.
    """
    data_base = os.path.expanduser(os.path.join('~', '.paddle'))
    # print() form works on both Python 2 and 3 (original used the
    # Python-2-only print statement).
    print(data_base)
    if not os.access(data_base, os.W_OK):
        data_base = os.path.join('/tmp', '.paddle')
    datadir = os.path.join(data_base, source_name)
    print(datadir)
    if not os.path.exists(datadir):
        os.makedirs(datadir)
    return datadir

def data_download(download_dir, source_url):
    """Ensure the ``tar.gz`` named by *source_url* is downloaded and extracted.

    Downloads the archive into *download_dir* when missing, then extracts
    it there in either case.

    Returns:
        *download_dir* (unchanged).
    """
    src_file = source_url.strip().split('/')[-1]
    file_path = os.path.join(download_dir, src_file)
    if not os.path.exists(file_path):
        # urlretrieve stores to a temp name in the CWD; rename it and move
        # it into the cache directory.
        temp_file_name, _ = download_with_urlretrieve(source_url)
        os.rename(temp_file_name, src_file)
        move_files(src_file, download_dir)
        print("Download finished,Extracting files.")
        done_msg = "Unpacking done!"
    else:
        done_msg = "Data has been already downloaded and unpacked!"
    # Single extraction path (the original duplicated it in both branches)
    # with `with` so the tarfile handle is closed (the original leaked it).
    with tarfile.open(name=file_path, mode="r:gz") as archive:
        archive.extractall(download_dir)
    print(done_msg)
    return download_dir
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

方便的话返回解压缩的路径吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def move_files(source_dire, target_dire):
    """Relocate *source_dire* into *target_dire*.

    Thin convenience wrapper around :func:`shutil.move`.
    """
    shutil.move(source_dire, target_dire)

def download_with_urlretrieve(url, filename=None):
    """Fetch *url* into *filename* (or a temporary file when ``None``).

    Returns the ``(local_filename, headers)`` pair from urlretrieve.
    """
    result = urllib.request.urlretrieve(url, filename)
    return result


if __name__ == '__main__':
    # print() form keeps this runnable under both Python 2 and 3
    # (the original used a Python-2-only print statement).
    path = fetch()
    print(path)
83 changes: 83 additions & 0 deletions python/paddle/data/mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-

# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
import os
import sys
import collections
import numpy as np
from six.moves import urllib
import urlparse
import gzip

source_url = 'http://yann.lecun.com/exdb/mnist/'
filename = ['train-images-idx3-ubyte.gz','train-labels-idx1-ubyte.gz','t10k-images-idx3-ubyte.gz','t10k-labels-idx1-ubyte.gz']

def fetch():
    """Download the four MNIST idx archives and unpack them.

    Returns:
        The directory holding the extracted files.
    """
    source_name = "mnist"
    # Resolve (and create if necessary) the per-user cache directory.
    # (The original also set an unused, copy-pasted cifar file name here.)
    data_home = set_data_path(source_name)
    return data_download(data_home, source_url)

def set_data_path(source_name):
    """Return (creating it if necessary) the cache directory for *source_name*.

    Prefers ``~/.paddle``; falls back to ``/tmp/.paddle`` when the home
    base is not writable.

    Returns:
        Absolute path of the dataset directory.
    """
    data_base = os.path.expanduser(os.path.join('~', '.paddle'))
    if not os.access(data_base, os.W_OK):
        data_base = os.path.join('/tmp', '.paddle')
    datadir = os.path.join(data_base, source_name)
    # print() form works on both Python 2 and 3 (original used the
    # Python-2-only print statement).
    print(datadir)
    if not os.path.exists(datadir):
        os.makedirs(datadir)
    return datadir

def data_download(download_dir, source_url):
    """Download and gunzip every MNIST archive listed in ``filename``.

    Archives already present are only re-extracted.  Each ``.gz`` file is
    removed once its payload has been written out.

    Returns:
        *download_dir* (unchanged).
    """
    for gz_name in filename:
        data_url = urlparse.urljoin(source_url, gz_name)
        file_path = os.path.join(download_dir, gz_name)
        untar_path = os.path.join(download_dir, gz_name.replace(".gz", ""))
        if not os.path.exists(file_path):
            # urlretrieve stores to a temp name in the CWD; rename it and
            # move it into the cache directory.
            temp_file_name, _ = download_with_urlretrieve(data_url)
            os.rename(temp_file_name, gz_name)
            move_files(gz_name, download_dir)
            print("Download finished,Extracting files.")
            _gunzip(file_path, untar_path)
            print("Unpacking done!")
        else:
            _gunzip(file_path, untar_path)
            print("Data has been already downloaded and unpacked!")
        # Drop the archive once its payload is extracted.  (The original
        # removed only the last archive, leaving the others behind.)
        os.remove(file_path)
    return download_dir


def _gunzip(gz_path, out_path):
    # Helper (deduplicates the two identical branches above): decompress
    # gz_path into out_path.  Binary mode 'wb' -- the idx files are binary;
    # the original's text-mode 'w+' corrupts them on platforms that
    # translate newlines.  Both handles are closed (the original leaked them).
    g_file = gzip.GzipFile(gz_path)
    try:
        with open(out_path, 'wb') as out:
            out.write(g_file.read())
    finally:
        g_file.close()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

返回这几个文件的路径吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def move_files(source_dire, target_dire):
    """Relocate *source_dire* into *target_dire*.

    Thin convenience wrapper around :func:`shutil.move`.
    """
    shutil.move(source_dire, target_dire)

def download_with_urlretrieve(url, filename=None):
    """Fetch *url* into *filename* (or a temporary file when ``None``).

    Returns the ``(local_filename, headers)`` pair from urlretrieve.
    """
    result = urllib.request.urlretrieve(url, filename)
    return result


if __name__ == '__main__':
    # print() form keeps this runnable under both Python 2 and 3
    # (the original used a Python-2-only print statement).
    path = fetch()
    print(path)
1 change: 1 addition & 0 deletions python/setup.py.in
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from setuptools import setup

packages=['paddle',
'paddle.data',
'paddle.proto',
'paddle.trainer',
'paddle.trainer_config_helpers',
Expand Down