Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in yield dictionary in DataProvider. #197

Merged
merged 5 commits into from
Oct 17, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 2.8)
project(paddle CXX C)
set(PADDLE_MAJOR_VERSION 0)
set(PADDLE_MINOR_VERSION 8)
set(PADDLE_PATCH_VERSION 0b1)
set(PADDLE_PATCH_VERSION 0b2)
set(PADDLE_VERSION ${PADDLE_MAJOR_VERSION}.${PADDLE_MINOR_VERSION}.${PADDLE_PATCH_VERSION})

set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake")
Expand Down
17 changes: 17 additions & 0 deletions cmake/util.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,20 @@ macro(add_paddle_culib TARGET_NAME)
cuda_add_library(${TARGET_NAME} STATIC ${ARGN})
set(CUDA_NVCC_FLAGS ${NVCC_FLAG})
endmacro()


# Creates C resources file from files in given resource file
function(create_resources res_file output)
# Create empty output file
file(WRITE ${output} "")
# Get short filename
string(REGEX MATCH "([^/]+)$" filename ${res_file})
# Replace filename spaces & extension separator for C compatibility
string(REGEX REPLACE "\\.| |-" "_" filename ${filename})
# Read hex data from file
file(READ ${res_file} filedata HEX)
# Convert hex data for C compatibility
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," filedata ${filedata})
# Append data to output file
file(APPEND ${output} "const unsigned char ${filename}[] = {${filedata}};\nconst unsigned ${filename}_size = sizeof(${filename});\n")
endfunction()
Empty file modified demo/mnist/data/get_mnist_data.sh
100644 → 100755
Empty file.
19 changes: 9 additions & 10 deletions demo/mnist/mnist_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@


# Define a py data provider
@provider(input_types=[
dense_vector(28 * 28),
integer_value(10)
])
@provider(input_types={
'pixel': dense_vector(28 * 28),
'label': integer_value(10)
})
def process(settings, filename): # settings is not used currently.
imgf = filename + "-images-idx3-ubyte"
labelf = filename + "-labels-idx1-ubyte"
Expand All @@ -14,20 +14,19 @@ def process(settings, filename): # settings is not used currently.

f.read(16)
l.read(8)

# Define number of samples for train/test
if "train" in filename:
n = 60000
else:
n = 10000

for i in range(n):
label = ord(l.read(1))
pixels = []
for j in range(28*28):
for j in range(28 * 28):
pixels.append(float(ord(f.read(1))) / 255.0)
yield { "pixel": pixels, 'label': label }
yield {"pixel": pixels, 'label': label}

f.close()
l.close()

1 change: 1 addition & 0 deletions demo/mnist/vgg_16_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

if not is_predict:
lbl = data_layer(name="label", size=label_size)
inputs(img, lbl)
outputs(classification_cost(input=predict, label=lbl))
else:
outputs(predict)
10 changes: 5 additions & 5 deletions doc_cn/ui/data_provider/mnist_provider.dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@


# Define a py data provider
@provider(input_types=[
dense_vector(28 * 28),
integer_value(10)
])
@provider(input_types={
'pixel': dense_vector(28 * 28),
'label': integer_value(10)
})
def process(settings, filename): # settings is not used currently.
f = open(filename, 'r') # open one of training file

Expand All @@ -20,6 +20,6 @@ def process(settings, filename): # settings is not used currently.
pixels_float.append(float(each_pixel_str))

# give data to paddle.
yield { "pixel": pixels_float, 'label': int(label) }
yield {"pixel": pixels_float, 'label': int(label)}

f.close() # close file
2 changes: 0 additions & 2 deletions doc_cn/ui/data_provider/pydataprovider2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,6 @@ DataProvider创建的时候执行。这个初始化函数具有如下参数:
是一个batch size,但是有时为了计算均衡性,可以将一条数据设置成多个batch size
* cache 是数据缓存的策略,参考 `cache`_
* init_hook 是初始化时调用的函数,参考 `init_hook`_
* use_dynamic_order 如果是true的话,可以返回一个dict,key是data_layer的名字,value是特征值。同时,也可以
返回一个list或者tuple。如果是false的话,只能够返回list或者tuple
* check 设置成true的话,会根据input_types检查数据的合法性。
* check_fail_continue 如果设置成true的话,即使在check中数据不合法,也会扔到这条数据,继续训练。 如果
check是false的话,没有作用。
Expand Down
3 changes: 1 addition & 2 deletions paddle/gserver/dataproviders/PyDataProvider2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,7 @@ class PyDataProvider2 : public DataProvider {
PyObjectPtr && kwargs) {
LOG(INFO) << "loading dataprovider " << model <<"::" << className;

PyObjectPtr module(PyImport_ImportModule(model.c_str()));
CHECK_PY(module) << "Cannot imort module " << model.c_str();
PyObjectPtr module = py::import(model);
PyObjectPtr moduleDict(PyModule_GetDict(module.get()));
CHECK_PY(moduleDict) << "Invoke module.__dict__ error";
PyObjectPtr cls(PyDict_GetItemString(moduleDict.get(),
Expand Down
2 changes: 1 addition & 1 deletion paddle/gserver/tests/test_PyDataProvider2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ TEST(PyDataProvider2, index_no_seq) {
}

TEST(PyDataProvider2, init_hook) {
paddle::PyObjectPtr pickle(PyImport_ImportModule("pickle"));
paddle::PyObjectPtr pickle = paddle::py::import("pickle");
paddle::PyObjectPtr globals(
PyModule_GetDict(PyImport_AddModule("__main__")));
PyDict_SetItemString(globals.get(), "pickle", pickle.get());
Expand Down
2 changes: 1 addition & 1 deletion paddle/gserver/tests/test_PyDataProvider2.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_can_over_batch_size(setting, filename):
yield [random.randint(0, 100 - 1) for _ in xrange(seq_len)]


@provider(input_types=[index_slot(10), index_slot(10)])
@provider(input_types={'input1':index_slot(10), 'input2': index_slot(10)})
def test_input_order(setting, filename):
for _ in xrange(1000):
yield {
Expand Down
1 change: 1 addition & 0 deletions paddle/utils/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
enable_virtualenv.c
6 changes: 5 additions & 1 deletion paddle/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@

file(GLOB UTIL_HEADERS . *.h)
file(GLOB UTIL_SOURCES . *.cpp)
create_resources(enable_virtualenv.py enable_virtualenv.c)
set(UTIL_RES enable_virtualenv.c)

if(APPLE)
file(GLOB UTIL_ARCH_SOURCES . arch/osx/*.cpp)
else()
file(GLOB UTIL_ARCH_SOURCES . arch/linux/*.cpp)
endif()
add_library(paddle_utils STATIC
${UTIL_SOURCES}
${UTIL_ARCH_SOURCES})
${UTIL_ARCH_SOURCES}
${UTIL_RES})
add_style_check_target(paddle_utils ${UTIL_HEADERS})
add_style_check_target(paddle_utils ${UTIL_SOURCES}
${UTIL_ARCH_SOURCES})
Expand Down
31 changes: 23 additions & 8 deletions paddle/utils/PythonUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,18 @@ static std::recursive_mutex g_pyMutex;
PyGuard::PyGuard() : guard_(g_pyMutex) {}


static void printPyErrorStack(std::ostream& os, bool withEndl = false) {
static void printPyErrorStack(std::ostream& os, bool withEndl = false,
bool withPyPath = true) {
PyObject * ptype, *pvalue, *ptraceback;
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
PyErr_NormalizeException(&ptype, &pvalue, &ptraceback);
PyErr_Clear();
if (withPyPath) {
os << "Current PYTHONPATH: " << py::repr(PySys_GetObject(strdup("path")));
if (withEndl) {
os << std::endl;
}
}
PyTracebackObject* obj = (PyTracebackObject*)ptraceback;

os << "Python Error: " << PyString_AsString(PyObject_Str(ptype))
Expand Down Expand Up @@ -114,10 +121,7 @@ PyObjectPtr callPythonFuncRetPyObj(const std::string& moduleName,
const std::string& funcName,
const std::vector<std::string>& args) {
PyGuard guard;
PyObjectPtr pyModuleName(PyString_FromString(moduleName.c_str()));
CHECK_PY(pyModuleName) << "Import PyModule failed" << moduleName;
PyObjectPtr pyModule(PyImport_Import(pyModuleName.get()));
CHECK_PY(pyModule) << "Import Python Module"<< moduleName << " failed.";
PyObjectPtr pyModule = py::import(moduleName);
PyObjectPtr pyFunc(PyObject_GetAttrString(pyModule.get(), funcName.c_str()));
CHECK_PY(pyFunc) << "GetAttrString failed.";
PyObjectPtr pyArgs(PyTuple_New(args.size()));
Expand All @@ -143,7 +147,7 @@ PyObjectPtr createPythonClass(
const std::vector<std::string>& args,
const std::map<std::string, std::string>& kwargs) {
PyGuard guard;
PyObjectPtr pyModule(PyImport_ImportModule(moduleName.c_str()));
PyObjectPtr pyModule = py::import(moduleName);
LOG(INFO) << "createPythonClass moduleName.c_str:" << moduleName.c_str();
CHECK_PY(pyModule) << "Import module " << moduleName << " failed.";
PyObjectPtr pyDict(PyModule_GetDict(pyModule.get()));
Expand Down Expand Up @@ -181,18 +185,29 @@ std::string getPyCallStack() {
printPyErrorStack(os, true);
return os.str();
}

PyObjectPtr import(const std::string &moduleName) {
auto module = PyImport_ImportModule(moduleName.c_str());
CHECK_PY(module) << "Import " << moduleName << "Error";
return PyObjectPtr(module);
}

} // namespace py

#endif

extern "C" {
extern const char enable_virtualenv_py[];
}
void initPython(int argc, char** argv) {
#ifndef PADDLE_NO_PYTHON
Py_SetProgramName(argv[0]);
Py_Initialize();
PySys_SetArgv(argc, argv);

// python blocks SIGINT. Need to enable it.
signal(SIGINT, SIG_DFL);

// Manually activate virtualenv when user is using virtualenv
PyRun_SimpleString(enable_virtualenv_py);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to use virtualenv here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

virtualenv is not need by us, but may be used by user.

These codes are used to support running Paddle under a virtualenv. And because Paddle process itself is a python process, so it should activate virtualenv when user have activated a virtualenv before.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will change the code comment here. It makes people confused.

#endif
}

Expand Down
2 changes: 2 additions & 0 deletions paddle/utils/PythonUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ PyObjectPtr createPythonClass(const std::string& moduleName,
CHECK((x) != nullptr) << ::paddle::py::getPyCallStack()

namespace py {
PyObjectPtr import(const std::string& moduleName);

/**
* Cast a PyLong or PyInt to int type T.
* @tparam T return type.
Expand Down
10 changes: 10 additions & 0 deletions paddle/utils/enable_virtualenv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import os

def __activate_virtual_env__():
__path__ = os.getenv('VIRTUAL_ENV')
if __path__ is None:
return
__script__ = os.path.join(__path__, 'bin', 'activate_this.py')
execfile(__script__, {'__file__': __script__})

__activate_virtual_env__()
24 changes: 14 additions & 10 deletions python/paddle/trainer/PyDataProvider2.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def provider(input_types=None, should_shuffle=None, pool_size=-1,
calc_batch_size=None,
cache=CacheType.NO_CACHE,
check=False, check_fail_continue=False,
use_dynamic_order=True,
init_hook=None, **kwargs):
"""
Provider decorator. Use it to make a function into PyDataProvider2 object.
Expand All @@ -228,9 +227,15 @@ def process(settings, file_name):
The configuration of data provider should be setup by\:

:param input_types: Specify the input types, can also be set in init_hook.
It is a list of InputType object. For example, input_types= \
[dense_vector(9), integer_value(2)].
:type input_types: list|tuple
It could be a list of InputType object. For example,
input_types=[dense_vector(9), integer_value(2)]. Or user
can set a dict of InputType object, which key is
data_layer's name. For example, input_types=\
{'img': img_features, 'label': label}. when using dict of
InputType, user could yield a dict of feature values, which
key is also data_layer's name.

:type input_types: list|tuple|dict

:param should_shuffle: True if data should shuffle. Pass None means shuffle
when is training and not to shuffle when is testing.
Expand Down Expand Up @@ -281,12 +286,6 @@ def process(settings, file_name):
drop the wrong format data when it is True. Has
no effect when check set to False.
:type check_fail_continue: bool

:param use_dynamic_order: Allow provider to yield a dictionary object, whose
key is a input data layer name, and value is the
feature value. The tuples are still allowed when
use_dynmaic_order is True.
:type use_dynamic_order: bool
"""

def __wrapper__(generator):
Expand Down Expand Up @@ -340,6 +339,11 @@ def __init__(self, file_list, **kwargs):
assert self.slots is not None
assert self.generator is not None

use_dynamic_order = False
if isinstance(self.slots, dict): # reorder input_types
self.slots = [self.slots[ipt] for ipt in self.input_order]
use_dynamic_order = True

if len(self.slots) == 1:
self.generator = SingleSlotWrapper(self.generator)

Expand Down
4 changes: 4 additions & 0 deletions python/paddle/trainer/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ def Inputs(*args):
if g_current_submodel is g_root_submodel:
g_config.model_config.input_layer_names.append(name)

@config_func
def HasInputsSet():
return len(g_config.model_config.input_layer_names) != 0


# Define the name of the output layers of the NeuralNetwork.
# Usually the output is simply the cost layer.
Expand Down
Loading