Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Gluon data 2.0: c++ dataloader and built-in image/bbox transforms #17841

Merged
merged 17 commits into from
May 7, 2020
156 changes: 156 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ typedef void *ExecutorHandle;
typedef void *DataIterCreator;
/*! \brief handle to a DataIterator */
typedef void *DataIterHandle;
/*! \brief handle a dataset creator */
typedef void *DatasetCreator;
/*! \brief handle to a Dataset */
typedef void *DatasetHandle;
/*! \brief handle to a BatchifyFunction creator*/
typedef void *BatchifyFunctionCreator;
/*! \brief handle to a BatchifyFunction */
typedef void *BatchifyFunctionHandle;
/*! \brief handle to KVStore */
typedef void *KVStoreHandle;
/*! \brief handle to RecordIO */
Expand Down Expand Up @@ -2670,6 +2678,13 @@ MXNET_DLL int MXDataIterNext(DataIterHandle handle,
*/
MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle);

/*!
* \brief Call iterator.GetLenHint. Note that some iterators don't provide length.
* \param handle the handle to iterator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterGetLenHint(DataIterHandle handle,
int64_t *len);
/*!
* \brief Get the handle to the NDArray of underlying data
* \param handle the handle pointer to the data iterator
Expand Down Expand Up @@ -2705,6 +2720,147 @@ MXNET_DLL int MXDataIterGetPadNum(DataIterHandle handle,
*/
MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle,
NDArrayHandle *out);
/*!
* \brief Get the handles to specified underlying ndarrays of index
* \param handle the handle pointer to the data iterator
* \param num_outputs the length of outputs
* \param out the handle to an array of NDArrays that stores pointers to handles
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDataIterGetItems(DataIterHandle handle,
int* num_outputs,
NDArrayHandle **outputs);

/*!
* \brief List all the available dataset entries
* \param out_size the size of returned datasets
* \param out_array the output dataset entries
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXListDatasets(uint32_t *out_size,
DatasetCreator **out_array);
/*!
* \brief Init an dataset, init with parameters
* the array size of passed in arguments
* \param handle of the dataset creator
* \param num_param number of parameter
* \param keys parameter keys
* \param vals parameter values
* \param out resulting dataset
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDatasetCreateDataset(DatasetCreator handle,
uint32_t num_param,
const char **keys,
const char **vals,
DatasetHandle *out);
/*!
* \brief Get the detailed information about dataset.
* \param creator the DatasetCreator.
* \param name The returned name of the creator.
* \param description The returned description of the symbol.
* \param num_args Number of arguments.
* \param arg_names Name of the arguments.
* \param arg_type_infos Type informations about the arguments.
* \param arg_descriptions Description information about the arguments.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDatasetGetDatasetInfo(DatasetCreator creator,
const char **name,
const char **description,
uint32_t *num_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions);
/*!
* \brief Free the handle to the IO module
* \param handle the handle pointer to the dataset
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDatasetFree(DatasetHandle handle);
/*!
* \brief Get dataset overal length(size)
* \param handle the handle to dataset
* \param out return value of GetLen
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDatasetGetLen(DatasetHandle handle,
uint64_t *out);
/*!
* \brief Get Output NDArray given specified indices
* \param handle the handle to dataset
* \param index the index of the dataset item to be retrieved
* \param num_outputs the number of output ndarrays
* \param outputs the pointers to handles of ndarrays
* \param is_scalar if not zeros then output should be casted to scalars
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXDatasetGetItems(DatasetHandle handle,
uint64_t index,
int* num_outputs,
NDArrayHandle **outputs);

/*!
* \brief List all the available batchify function entries
* \param out_size the size of returned batchify functions
* \param out_array the output batchify function entries
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXListBatchifyFunctions(uint32_t *out_size,
BatchifyFunctionCreator **out_array);
/*!
* \brief Init an batchify function, init with parameters
* the array size of passed in arguments
* \param handle of the batchify function creator
* \param num_param number of parameter
* \param keys parameter keys
* \param vals parameter values
* \param out resulting batchify function
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXBatchifyFunctionCreateFunction(BatchifyFunctionCreator handle,
uint32_t num_param,
const char **keys,
const char **vals,
BatchifyFunctionHandle *out);
/*!
* \brief Get the detailed information about batchify function.
* \param creator the batchifyFunctionCreator.
* \param name The returned name of the creator.
* \param description The returned description of the symbol.
* \param num_args Number of arguments.
* \param arg_names Name of the arguments.
* \param arg_type_infos Type informations about the arguments.
* \param arg_descriptions Description information about the arguments.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXBatchifyFunctionGetFunctionInfo(BatchifyFunctionCreator creator,
const char **name,
const char **description,
uint32_t *num_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions);
/*!
* \brief Invoke the Batchify Function
* \param handle the handle pointer to the batchify function
* \param batch_size the batch size
* \param num_output the number of ndarrays for output
* \param inputs the pointers to input ndarrays
* \param ouptuts the pointers to output ndarrays
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXBatchifyFunctionInvoke(BatchifyFunctionHandle handle,
int batch_size,
int num_output,
NDArrayHandle *inputs,
NDArrayHandle **outputs);
/*!
* \brief Free the handle to the IO module
* \param handle the handle pointer to the batchify function
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXBatchifyFunctionFree(BatchifyFunctionHandle handle);
//--------------------------------------------
// Part 6: basic KVStore interface
//--------------------------------------------
Expand Down
98 changes: 97 additions & 1 deletion include/mxnet/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ class IIterator : public dmlc::DataIter<DType> {
inline void SetDataName(const std::string data_name) {
data_names.push_back(data_name);
}
/*! \brief request iterator length hint for current epoch.
* Note that the returned value can be < 0, indicating
* that the length of iterator is unknown unless you went through all data.
*/
virtual int64_t GetLenHint(void) const {
return -1;
}
}; // class IIterator

/*! \brief a single data instance */
Expand Down Expand Up @@ -104,7 +111,7 @@ struct DataIteratorReg
*
* \code
* // example of registering a mnist iterator
* REGISTER_IO_ITE(MNISTIter)
* REGISTER_IO_ITER(MNISTIter)
* .describe("Mnist data iterator")
* .set_body([]() {
* return new PrefetcherIter(new MNISTIter());
Expand All @@ -113,5 +120,94 @@ struct DataIteratorReg
*/
#define MXNET_REGISTER_IO_ITER(name) \
DMLC_REGISTRY_REGISTER(::mxnet::DataIteratorReg, DataIteratorReg, name)

/*!
* \brief A random accessable dataset which provides GetLen() and GetItem().
* Unlike DataIter, it's a static lookup storage which is friendly to random access.
* The dataset itself should NOT contain data processing, which should be applied during
* data augmentation or transformation processes.
*/
class Dataset {
public:
/*!
* \brief Get the size of the dataset
*/
virtual uint64_t GetLen(void) const = 0;
/*!
* \brief Get the ndarray items given index in dataset
* \param idx the integer index for required data
* \param ret the returned ndarray items
*/
virtual bool GetItem(uint64_t idx, std::vector<NDArray>* ret) = 0;
// virtual destructor
virtual ~Dataset(void) {}
}; // class Dataset

/*! \brief typedef the factory function of dataset */
typedef std::function<Dataset *(
const std::vector<std::pair<std::string, std::string> >&)> DatasetFactory;
/*!
* \brief Registry entry for Dataset factory functions.
*/
struct DatasetReg
: public dmlc::FunctionRegEntryBase<DatasetReg,
DatasetFactory> {
};
//--------------------------------------------------------------
// The following part are API Registration of Datasets
//--------------------------------------------------------------
/*!
* \brief Macro to register Datasets
*
* \code
* // example of registering an image sequence dataset
* REGISTER_IO_ITE(ImageSequenceDataset)
* .describe("image sequence dataset")
* .set_body([]() {
* return new ImageSequenceDataset();
* });
* \endcode
*/
#define MXNET_REGISTER_IO_DATASET(name) \
DMLC_REGISTRY_REGISTER(::mxnet::DatasetReg, DatasetReg, name)

class BatchifyFunction {
public:
/*! \brief Destructor */
virtual ~BatchifyFunction(void) {}
/*! \brief The batchify logic */
virtual bool Batchify(const std::vector<std::vector<NDArray> >& inputs,
std::vector<NDArray>* outputs) = 0;
}; // class BatchifyFunction

using BatchifyFunctionPtr = std::shared_ptr<BatchifyFunction>;

/*! \brief typedef the factory function of data sampler */
typedef std::function<BatchifyFunction *(
const std::vector<std::pair<std::string, std::string> >&)> BatchifyFunctionFactory;
/*!
* \brief Registry entry for DataSampler factory functions.
*/
struct BatchifyFunctionReg
: public dmlc::FunctionRegEntryBase<BatchifyFunctionReg,
BatchifyFunctionFactory> {
};
//--------------------------------------------------------------
// The following part are API Registration of Batchify Function
//--------------------------------------------------------------
/*!
* \brief Macro to register Batchify Functions
*
* \code
* // example of registering a Batchify Function
* MXNET_REGISTER_IO_BATCHIFY_FUNCTION(StackBatchify)
* .describe("Stack Batchify Function")
* .set_body([]() {
* return new StackBatchify();
* });
* \endcode
*/
#define MXNET_REGISTER_IO_BATCHIFY_FUNCTION(name) \
DMLC_REGISTRY_REGISTER(::mxnet::BatchifyFunctionReg, BatchifyFunctionReg, name)
} // namespace mxnet
#endif // MXNET_IO_H_
2 changes: 2 additions & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,8 @@ def _load_lib():
ExecutorHandle = ctypes.c_void_p
DataIterCreatorHandle = ctypes.c_void_p
DataIterHandle = ctypes.c_void_p
DatasetHandle = ctypes.c_void_p
BatchifyFunctionhandle = ctypes.c_void_p
KVStoreHandle = ctypes.c_void_p
RecordIOHandle = ctypes.c_void_p
RtcHandle = ctypes.c_void_p
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/gluon/contrib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@
"""Contrib datasets."""

from . import text
from . import vision

from .sampler import *
22 changes: 22 additions & 0 deletions python/mxnet/gluon/contrib/data/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import
"""Contrib vision utilities."""
from .transforms import *
from .dataloader import *
Loading