diff --git a/Makefile b/Makefile
index b19555fe962..baaaa8cb739 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
 PROJECT := caffe
 
-CONFIG_FILE := Makefile.config
+CONFIG_FILE ?= Makefile.config
 include $(CONFIG_FILE)
 
 BUILD_DIR_LINK := $(BUILD_DIR)
@@ -269,7 +269,7 @@ endif
 
 # Debugging
 ifeq ($(DEBUG), 1)
-	COMMON_FLAGS += -DDEBUG -g -O0
+	COMMON_FLAGS += -DDEBUG -g -O0 -DBOOST_NOINLINE='__attribute__ ((noinline))'
 	NVCCFLAGS += -G
 else
 	COMMON_FLAGS += -DNDEBUG -O2
@@ -291,6 +291,11 @@ ifeq ($(CPU_ONLY), 1)
 	COMMON_FLAGS += -DCPU_ONLY
 endif
 
+ifeq ($(RDMA), 1)
+	COMMON_FLAGS += -DRDMA
+	LIBRARIES += ibverbs ibumad
+endif
+
 # BLAS configuration (default = ATLAS)
 BLAS ?= atlas
 ifeq ($(BLAS), mkl)
@@ -549,6 +554,13 @@ $(GTEST_OBJ): $(GTEST_SRC) | $(GTEST_BUILD_DIR)
 	@ cat $@.$(WARNS_EXT)
 	@ echo
 
+$(OBJ_BUILD_DIR)/%.cuo: src/$(PROJECT)/%.cu $(HXX_SRCS) \
+		| $(LAYER_BUILD_DIR)
+	$(CUDA_DIR)/bin/nvcc $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@ 2> $@.$(WARNS_EXT) \
+		|| (cat $@.$(WARNS_EXT); exit 1)
+	@ cat $@.$(WARNS_EXT)
+	@ echo
+
 $(LAYER_BUILD_DIR)/%.cuo: src/$(PROJECT)/layers/%.cu $(HXX_SRCS) \
 		| $(LAYER_BUILD_DIR)
 	$(CUDA_DIR)/bin/nvcc $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@ 2> $@.$(WARNS_EXT) \
 		|| (cat $@.$(WARNS_EXT); exit 1)
diff --git a/Makefile.config.example b/Makefile.config.example
index e11db51395d..9f0269921c3 100644
--- a/Makefile.config.example
+++ b/Makefile.config.example
@@ -7,6 +7,9 @@
 # CPU-only switch (uncomment to build without GPU support).
 # CPU_ONLY := 1
 
+# Parallelization over InfiniBand or RoCE
+# RDMA := 1
+
 # To customize your choice of compiler, uncomment and set the following.
 # N.B. the default for Linux is g++ and the default for OSX is clang++
 # CUSTOM_CXX := g++
diff --git a/examples/parallel/base.hpp b/examples/parallel/base.hpp
new file mode 100644
index 00000000000..d45365355b9
--- /dev/null
+++ b/examples/parallel/base.hpp
@@ -0,0 +1,214 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace std;
+using namespace caffe;
+
+// Shared code for parallel examples. Should be replaced by some kind of cluster
+// deployment and visualization solution.
+
+// Context for a solver running in a thread. Both initialization and run
+// of the solver are done on the thread, to point to the same instance of the
+// thread-local Caffe singleton.
+class SolverContext : public Threaded {
+ public:
+  // Main solver does testing, display, snapshots etc.
+ SolverContext(const Params& params, + const SolverParameter& solver_param, Solver* solver) + : params_(params), + solver_param_(solver_param), + worker_(solver == NULL), + solver_(solver) { + + if (worker_) { + solver_param_.clear_display(); + solver_param_.clear_snapshot(); + } + } + + virtual void create_solver() { + if (worker_) { + solver_ = new SGDSolver(solver_param_, true); + CHECK(!solver_->test_nets().size()); // Only training + } + } + + virtual void delete_solver() { + if (worker_) + delete solver_; + } + + inline Solver* solver() const { + return solver_; + } + + virtual void stats(ostream& s) const { + } + + protected: + const Params& params_; + SolverParameter solver_param_; + const bool worker_; + Solver* solver_; +}; + +// Runs a CPU solver on a thread +class CPUContext : public SolverContext { + public: + CPUContext(const Params& params, const SolverParameter& solver_param, + Solver* solver = NULL) + : SolverContext(params, solver_param, solver) { + } + + void run() { + create_solver(); + params_.configure(solver_); + solver_->Solve(); + // Wait until asked to stop before destroying, monitor might + // still be accessing fields + if (worker_) + while (!must_stop()) + sleep(1); + delete_solver(); + } +}; + +#ifndef CPU_ONLY + +// Runs a GPU solver on a thread +class GPUContext : public SolverContext { + public: + GPUContext(const Params& params, const SolverParameter& solver_param, + GPUParams* gpu_params, Solver* solver = NULL) + : SolverContext(params, solver_param, solver), + gpu_params_(gpu_params) { + } + + void run() { + create_solver(); + gpu_params_->configure(solver_); + solver_->Solve(); + // Wait until asked to stop before destroying, monitor might + // still be accessing fields + if (worker_) + while (!must_stop()) + sleep(1); + delete_solver(); + } + + protected: + GPUParams* gpu_params_; +}; + +// Runs a GPU solver on a thread with CPU sync +class CPUGPUContext : public SolverContext { + public: + CPUGPUContext(const Params& params, + const SolverParameter& solver_param, Solver* solver = + NULL) + : SolverContext(params, solver_param, solver), + gpu_params_(), + sync_() { + } + + void run() { + create_solver(); + gpu_params_ = new GPUParams(params_, solver_param_.device_id()); + sync_ = new CPUGPUSync(*gpu_params_); + gpu_params_->configure(solver_); + sync_->start(); + solver_->Solve(); + // Wait until asked to stop before destroying, monitor might + // still be accessing fields + if (worker_) + while (!must_stop()) + sleep(1); + delete sync_; + delete gpu_params_; + delete_solver(); + } + + virtual void stats(ostream& s) const { + s << "GPU " << solver_param_.device_id() << " "; + if (sync_) { + sync_->calls().show(s); + s << ", "; + sync_->cycles().show(s); + } else + s << "starting"; + s << ", "; + } + + protected: + GPUParams* gpu_params_; + CPUGPUSync* sync_; +}; + +#endif + +// Displays stats about a set of solvers. Also keeps track and updates +// the global count of iterations (needed to adjust hyperparams). +class Monitor : public Threaded { + public: + Monitor(Params& params, const vector& solvers) + : params_(params), + solvers_(solvers), + total_iters_("total") { + } + + virtual ~Monitor() { + } + + void step(ostream* s = NULL) { + *s << "Monitor - iters: "; + + int total = 0; + bool all = true; // TODO remove + for (int i = 0; i < solvers_.size(); ++i) { + SolverContext* ctx = solvers_[i]; + int n = ctx->solver() ? 
ctx->solver()->iter() : 0; + total += n; + if (s) + *s << n << ", "; + if (!n) + all = false; + } + if (all) { + //cudaProfilerStart(); + //LOG(INFO)<< "Started profiler\n"; + } + params_.iterations(total); + total_iters_.value(total); + if (s) { + total_iters_.show(*s); + *s << ", "; + for (int i = 0; i < solvers_.size(); ++i) + solvers_[i]->stats(*s); + } + } + + void run() { + int every_seconds = 10; + time_t start = time(0); + while (!must_stop()) { + sleep(every_seconds); + + ostringstream s; + step(&s); + s << "\n"; + LOG(INFO)<< s.str(); + LOG(INFO)<< "Training time: " << (time(0) - start); + } + } + +protected: + Params& params_; + const vector& solvers_; + Meter total_iters_; +}; diff --git a/examples/parallel/cifar.prototxt b/examples/parallel/cifar.prototxt new file mode 100644 index 00000000000..dcf1b7764cf --- /dev/null +++ b/examples/parallel/cifar.prototxt @@ -0,0 +1,203 @@ +name: "CIFAR10_full" +layers { + name: "cifar" + type: DATA + top: "data" + top: "label" + data_param { + source: "/scratch/cifar10_train" + backend: LMDB + rand_skip: 10000 + batch_size: 16 + } + transform_param { + mean_file: "/data/shared/cifar10_mean" + } + include: { phase: TRAIN } +} +layers { + name: "cifar" + type: DATA + top: "data" + top: "label" + data_param { + source: "/scratch/cifar10_val" + backend: LMDB + batch_size: 100 + } + transform_param { + mean_file: "/data/shared/cifar10_mean" + } + include: { phase: TEST } +} +layers { + name: "conv1" + type: CONVOLUTION + bottom: "data" + top: "conv1" + blobs_lr: 1 + blobs_lr: 2 + convolution_param { + num_output: 32 + pad: 2 + kernel_size: 5 + stride: 1 + weight_filler { + type: "gaussian" + std: 0.0001 + } + bias_filler { + type: "constant" + } + } +} +layers { + name: "pool1" + type: POOLING + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "relu1" + type: RELU + bottom: "pool1" + top: "pool1" +} +layers { + name: "norm1" + type: LRN + bottom: "pool1" + top: "norm1" + lrn_param { + norm_region: WITHIN_CHANNEL + local_size: 3 + alpha: 5e-05 + beta: 0.75 + } +} +layers { + name: "conv2" + type: CONVOLUTION + bottom: "norm1" + top: "conv2" + blobs_lr: 1 + blobs_lr: 2 + convolution_param { + num_output: 32 + pad: 2 + kernel_size: 5 + stride: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + } + } +} +layers { + name: "relu2" + type: RELU + bottom: "conv2" + top: "conv2" +} +layers { + name: "pool2" + type: POOLING + bottom: "conv2" + top: "pool2" + pooling_param { + pool: AVE + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm2" + type: LRN + bottom: "pool2" + top: "norm2" + lrn_param { + norm_region: WITHIN_CHANNEL + local_size: 3 + alpha: 5e-05 + beta: 0.75 + } +} +layers { + name: "conv3" + type: CONVOLUTION + bottom: "norm2" + top: "conv3" + convolution_param { + num_output: 64 + pad: 2 + kernel_size: 5 + stride: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + } + } +} +layers { + name: "relu3" + type: RELU + bottom: "conv3" + top: "conv3" +} +layers { + name: "pool3" + type: POOLING + bottom: "conv3" + top: "pool3" + pooling_param { + pool: AVE + kernel_size: 3 + stride: 2 + } +} +layers { + name: "ip1" + type: INNER_PRODUCT + bottom: "pool3" + top: "ip1" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 250 + weight_decay: 0 + inner_product_param { + num_output: 10 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + } + } +} +layers { + name: 
"accuracy" + type: ACCURACY + bottom: "ip1" + bottom: "label" + top: "accuracy" + include: { phase: TEST } +} +layers { + name: "loss" + type: SOFTMAX_LOSS + bottom: "ip1" + bottom: "label" + top: "loss" +} diff --git a/examples/parallel/cifar_solver.prototxt b/examples/parallel/cifar_solver.prototxt new file mode 100644 index 00000000000..459668326f4 --- /dev/null +++ b/examples/parallel/cifar_solver.prototxt @@ -0,0 +1,20 @@ +# The train/test net protocol buffer definition +net: "examples/parallel/cifar.prototxt" +# test_iter specifies how many forward passes the test should carry out. +# In the case of CIFAR10, we have test batch size 100 and 100 test iterations, +# covering the full 10,000 testing images. +test_iter: 100 +# Carry out testing every 1000 training iterations. +test_interval: 1000 +# The base learning rate, momentum and the weight decay of the network. +base_lr: 0.001 +momentum: 0.0 +weight_decay: 0.004 +# The learning rate policy +lr_policy: "fixed" +# Display every 200 iterations +display: 200 +# The maximum number of iterations +max_iter: 60000 +# snapshot intermediate results +snapshot: 0 diff --git a/examples/parallel/gpus.cpp b/examples/parallel/gpus.cpp new file mode 100644 index 00000000000..1e59841c572 --- /dev/null +++ b/examples/parallel/gpus.cpp @@ -0,0 +1,85 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "base.hpp" + +using namespace std; +using namespace caffe; + +#ifndef CPU_ONLY + +// Trains a net on multiple GPUs on one box. C.f. GPUSync in parallel.h. +// +// Example launch on GPU 0 and 1: +// make -j +// export LD_LIBRARY_PATH=/usr/local/lib:/usr/local/cuda/lib64 +// export GLOG_logtostderr=1 +// build/examples/parallel/gpus.bin examples/parallel/mnist_solver.prototxt 0:1 + +int main(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); + ::google::InstallFailureSignalHandler(); + + if (argc != 3) { + printf("Usage: gpus.bin solver_proto_file gpu_id[:gpu_id][...]\n"); + return 1; + } + + SolverParameter solver_param; + ReadProtoFromTextFile(argv[1], &solver_param); + + vector gpus; + vector gpu_strings; + boost::split(gpu_strings, argv[2], boost::is_any_of(":")); + for (int i = 0; i < gpu_strings.size(); ++i) + gpus.push_back(atoi(gpu_strings[i].c_str())); + + solver_param.set_device_id(gpus[0]); + SGDSolver main(solver_param); + + // Shared network weights + Params params(main.net()->params()); + + // Create contexts + vector solvers(gpus.size()); + solvers[0] = new CPUGPUContext(params, solver_param, &main); + for (int i = 1; i < gpus.size(); ++i) { + solver_param.set_device_id(gpus[i]); + solvers[i] = new CPUGPUContext(params, solver_param); + solvers[i]->start(); + } + + // Start monitor + Monitor monitor(params, solvers); + monitor.start(); + + // Run main on current thread + solvers[0]->run(); + + monitor.stop(); + LOG(INFO)<< "Monitor stop\n"; + + for (int i = 1; i < solvers.size(); ++i) + solvers[i]->stop(); + + for (int i = 1; i < solvers.size(); ++i) + delete solvers[i]; +} + +#else +int main(int argc, char *argv[]) { +} +#endif + diff --git a/examples/parallel/hogwild.cpp b/examples/parallel/hogwild.cpp new file mode 100644 index 00000000000..77e03369880 --- /dev/null +++ b/examples/parallel/hogwild.cpp @@ -0,0 +1,83 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "caffe/filler.hpp" +#include 
"caffe/parallel.hpp" +#include "base.hpp" + +using namespace std; +using namespace caffe; + +// Trains a net in parallel on multiple CPU cores. C.f. CPUSync in parallel.h. +// +// Your BLAS library needs to let the application manage its threads, e.g. +// for OpenBLAS, compile with no threading (USE_THREAD = 0 in Makefile.rule). +// Performance is linear at first, but then plateaus on large nets as the number +// of cores is increased, probably as the CPU runs out of memory bandwidth. +// +// Example launch on 4 cores: +// make -j +// export LD_LIBRARY_PATH=:/usr/local/lib:/usr/local/cuda/lib64 +// export GLOG_logtostderr=1 +// build/examples/parallel/hogwild.bin examples/parallel/mnist_solver.prototxt 4 + +int main(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); + ::google::InstallFailureSignalHandler(); + + if (argc < 2 || argc > 3) { + printf("Usage: hogwild.bin solver_proto_file [number_of_cores]\n"); + return 1; + } + + SolverParameter solver_param; + ReadProtoFromTextFile(argv[1], &solver_param); + + int cores = argc == 3 ? atoi(argv[2]) : sysconf(_SC_NPROCESSORS_ONLN); + + // Override in code so that proto file can be shared with other examples + solver_param.set_solver_mode(SolverParameter::CPU); + + // Main solver + SGDSolver main(solver_param); + + // Shared network weights + Params params(main.net()->params()); + + // Create contexts + vector solvers(cores); + solvers[0] = new CPUContext(params, solver_param, &main); + for (int i = 1; i < cores; ++i) { + solvers[i] = new CPUContext(params, solver_param); + solvers[i]->start(); + } + + // Start monitor + Monitor monitor(params, solvers); + monitor.start(); + + // Run main on current thread + solvers[0]->run(); + + monitor.stop(); + LOG(INFO)<< "Monitor stop\n"; + + for (int i = 1; i < solvers.size(); ++i) + solvers[i]->stop(); + + for (int i = 1; i < solvers.size(); ++i) + delete solvers[i]; +} diff --git a/examples/parallel/imagenet.prototxt b/examples/parallel/imagenet.prototxt new file mode 100644 index 00000000000..b3278bb6eca --- /dev/null +++ b/examples/parallel/imagenet.prototxt @@ -0,0 +1,348 @@ +name: "CaffeNet" +layers { + name: "data" + type: DATA + top: "data" + top: "label" + data_param { + source: "/scratch/ilsvrc12/imagenet_1000_train" + source: "/data/shared/ilsvrc12/imagenet_1000_train" + backend: LMDB + rand_skip: 10000 + batch_size: 256 + } + transform_param { + crop_size: 227 + mean_file: "data/ilsvrc12/imagenet_mean.binaryproto" + mirror: true + } + include: { phase: TRAIN } +} +layers { + name: "data" + type: DATA + top: "data" + top: "label" + data_param { + source: "/scratch/ilsvrc12/imagenet_1000_val" + backend: LMDB + batch_size: 50 + } + transform_param { + crop_size: 227 + mean_file: "data/ilsvrc12/imagenet_mean.binaryproto" + mirror: false + } + include: { phase: TEST } +} +layers { + name: "conv1" + type: CONVOLUTION + bottom: "data" + top: "conv1" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 96 + kernel_size: 11 + stride: 4 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "relu1" + type: RELU + bottom: "conv1" + top: "conv1" +} +layers { + name: "pool1" + type: POOLING + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm1" + type: LRN + bottom: "pool1" + top: "norm1" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "conv2" + type: CONVOLUTION + 
bottom: "norm1" + top: "conv2" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 2 + kernel_size: 5 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu2" + type: RELU + bottom: "conv2" + top: "conv2" +} +layers { + name: "pool2" + type: POOLING + bottom: "conv2" + top: "pool2" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "norm2" + type: LRN + bottom: "pool2" + top: "norm2" + lrn_param { + local_size: 5 + alpha: 0.0001 + beta: 0.75 + } +} +layers { + name: "conv3" + type: CONVOLUTION + bottom: "norm2" + top: "conv3" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "relu3" + type: RELU + bottom: "conv3" + top: "conv3" +} +layers { + name: "conv4" + type: CONVOLUTION + bottom: "conv3" + top: "conv4" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 384 + pad: 1 + kernel_size: 3 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu4" + type: RELU + bottom: "conv4" + top: "conv4" +} +layers { + name: "conv5" + type: CONVOLUTION + bottom: "conv4" + top: "conv5" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + pad: 1 + kernel_size: 3 + group: 2 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu5" + type: RELU + bottom: "conv5" + top: "conv5" +} +layers { + name: "pool5" + type: POOLING + bottom: "conv5" + top: "pool5" + pooling_param { + pool: MAX + kernel_size: 3 + stride: 2 + } +} +layers { + name: "fc6" + type: INNER_PRODUCT + bottom: "pool5" + top: "fc6" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu6" + type: RELU + bottom: "fc6" + top: "fc6" +} +layers { + name: "drop6" + type: DROPOUT + bottom: "fc6" + top: "fc6" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc7" + type: INNER_PRODUCT + bottom: "fc6" + top: "fc7" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu7" + type: RELU + bottom: "fc7" + top: "fc7" +} +layers { + name: "drop7" + type: DROPOUT + bottom: "fc7" + top: "fc7" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc8" + type: INNER_PRODUCT + bottom: "fc7" + top: "fc8" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 1000 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "accuracy" + type: ACCURACY + bottom: "fc8" + bottom: "label" + top: "accuracy" + include: { phase: TEST } +} +layers { + name: "loss" + type: SOFTMAX_LOSS + bottom: "fc8" + bottom: "label" + top: "loss" +} diff --git a/examples/parallel/imagenet_solver.prototxt 
b/examples/parallel/imagenet_solver.prototxt new file mode 100644 index 00000000000..78c8203e806 --- /dev/null +++ b/examples/parallel/imagenet_solver.prototxt @@ -0,0 +1,13 @@ +net: "examples/parallel/imagenet.prototxt" +test_iter: 1000 +test_interval: 1000 +base_lr: 0.01 +lr_policy: "step" +gamma: 0.1 +stepsize: 100000 +display: 20 +max_iter: 450000 +momentum: 0.9 +weight_decay: 0.0005 +snapshot: 10000 +snapshot_prefix: "examples/parallel/caffe_imagenet_train" diff --git a/examples/parallel/mnist.prototxt b/examples/parallel/mnist.prototxt new file mode 100644 index 00000000000..6a1399a19f2 --- /dev/null +++ b/examples/parallel/mnist.prototxt @@ -0,0 +1,149 @@ +name: "LeNet" +layers { + name: "mnist" + type: DATA + top: "data" + top: "label" + data_param { + source: "/scratch/mnist/mnist-train-8m" + source: "/data/shared/mnist/mnist-train-8m" + backend: LMDB + batch_size: 64 + rand_skip: 10000 + } + transform_param { + scale: 0.00390625 + } + include: { phase: TRAIN } +} +layers { + name: "mnist" + type: DATA + top: "data" + top: "label" + data_param { + source: "/scratch/mnist/mnist-val" + backend: LMDB + batch_size: 100 + } + transform_param { + scale: 0.00390625 + } + include: { phase: TEST } +} + +layers { + name: "conv1" + type: CONVOLUTION + bottom: "data" + top: "conv1" + blobs_lr: 1 + blobs_lr: 2 + convolution_param { + num_output: 20 + kernel_size: 5 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + } + } +} +layers { + name: "pool1" + type: POOLING + bottom: "conv1" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 2 + stride: 2 + } +} +layers { + name: "conv2" + type: CONVOLUTION + bottom: "pool1" + top: "conv2" + blobs_lr: 1 + blobs_lr: 2 + convolution_param { + num_output: 50 + kernel_size: 5 + stride: 1 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + } + } +} +layers { + name: "pool2" + type: POOLING + bottom: "conv2" + top: "pool2" + pooling_param { + pool: MAX + kernel_size: 2 + stride: 2 + } +} +layers { + name: "ip1" + type: INNER_PRODUCT + bottom: "pool2" + top: "ip1" + blobs_lr: 1 + blobs_lr: 2 + inner_product_param { + num_output: 500 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + } + } +} +layers { + name: "relu1" + type: RELU + bottom: "ip1" + top: "ip1" +} +layers { + name: "ip2" + type: INNER_PRODUCT + bottom: "ip1" + top: "ip2" + blobs_lr: 1 + blobs_lr: 2 + inner_product_param { + num_output: 10 + weight_filler { + type: "xavier" + } + bias_filler { + type: "constant" + } + } +} +layers { + name: "accuracy" + type: ACCURACY + bottom: "ip2" + bottom: "label" + top: "accuracy" + include: { phase: TEST } +} +layers { + name: "loss" + type: SOFTMAX_LOSS + bottom: "ip2" + bottom: "label" + top: "loss" +} diff --git a/examples/parallel/mnist_solver.prototxt b/examples/parallel/mnist_solver.prototxt new file mode 100644 index 00000000000..2fd7bf8fcc8 --- /dev/null +++ b/examples/parallel/mnist_solver.prototxt @@ -0,0 +1,24 @@ +# The train/test net protocol buffer definition +net: "examples/parallel/mnist.prototxt" +# test_iter specifies how many forward passes the test should carry out. +# In the case of MNIST, we have test batch size 100 and 100 test iterations, +# covering the full 10,000 testing images. +test_iter: 100 +# Carry out testing every 500 training iterations. +test_interval: 500 +# The base learning rate, momentum and the weight decay of the network. 
+base_lr: 0.01 +momentum: 0.0 +weight_decay: 0.0005 +# The learning rate policy +lr_policy: "inv" +gamma: 0.0001 +power: 0.75 +# Display every 100 iterations +display: 100 +# The maximum number of iterations +max_iter: 10000 +# snapshot intermediate results +snapshot: 0 +snapshot_prefix: "examples/parallel/lenet" +# solver_mode: CPU \ No newline at end of file diff --git a/examples/parallel/raw.cpp b/examples/parallel/raw.cpp new file mode 100644 index 00000000000..3d0641c06f6 --- /dev/null +++ b/examples/parallel/raw.cpp @@ -0,0 +1,169 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "base.hpp" + +using namespace std; +using namespace caffe; + +// Trains a net over multiple boxes through high perf. networking. C.f RawSync in +// parallel.h. The application must be launched on each box with the local CPUs & GPUs +// to use, and the list of MAC addresses of all adapters in the cluster. The MAC +// list must be the same on all boxes. +// +// Example launch on GPU 0, GPU 1, 4 cores on two machines: +// make -j +// (single thread BLAS is only needed for CPU training, c.f. hogwild.cpp) +// export LD_LIBRARY_PATH=:/usr/local/lib:/usr/local/cuda/lib64 +// export GLOG_logtostderr=1 +// build/examples/parallel/raw.bin examples/parallel/mnist_solver.prototxt 0:1:4 002590ca9998:002590ca9956 + +#ifdef __linux__ + +// Monitors solvers and network +class RawMonitor : public Monitor { + public: + RawMonitor(Params& params, const vector& solvers, + RawSync& raw) + : Monitor(params, solvers), + raw_(raw) { + } + + void stats(const Ring& r, ostream& s) { + s << r.adapter() + " "; + r.sent().show(s); + s << ", "; + r.recv().show(s); + } + + void run() { + time_t start = time(0); + for (;;) { + sleep(10); + + ostringstream s; + step(&s); + + s << "raw: "; + stats(raw_.master(), s); + s << ", "; + stats(raw_.worker(), s); + s << ", "; + raw_.cycles().show(s); + s << "\n"; + LOG(INFO)<< s.str(); + LOG(INFO)<< "Training time: " << (time(0) - start); + } + } + + const RawSync& raw_; +}; + +int main(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); + ::google::InstallFailureSignalHandler(); + + if (argc < 4 || argc > 5) { + printf("Usage: raw.bin solver_proto_file " // + "[gpu_id][:gpu_id][...]:cpu_cores " + "mac_address[:mac_address][:...] 
[secondary_mac][:secondary_mac][:...]\n"); + printf("Raw socket is a privileged operation, either run as root or " // + "set the capability on the executable: " + "sudo setcap cap_net_raw+ep raw.bin\n"); + return 1; + } + + SolverParameter solver_param; + ReadProtoFromTextFile(argv[1], &solver_param); + + vector procs; + boost::split(procs, argv[2], boost::is_any_of(":")); + vector gpus; + for (int i = 0; i < procs.size() - 1; ++i) + gpus.push_back(atoi(procs[i].c_str())); + int cores = atoi(procs[procs.size() - 1].c_str()); + + vector macs; + boost::split(macs, argv[3], boost::is_any_of(":")); + + vector secs; + if (argc == 5) + boost::split(secs, argv[4], boost::is_any_of(":")); + + // Set main solver to first GPU if available, or CPU + if (gpus.size()) + solver_param.set_device_id(gpus[0]); + else + solver_param.set_solver_mode(SolverParameter::CPU); + SGDSolver main(solver_param); + + // Shared network weights + Params params(main.net()->params(), "/dev/shm/test"); + + // Raw socket synchronization + RawSync raw(params, macs, secs); + raw.start(); + + LOG(INFO)<< "Waiting for other boxes\n"; + while (!raw.ready()) + sleep(1); + LOG(INFO)<< "Start training\n"; + + // Create contexts + vector contexts(gpus.size() + cores); + if (gpus.size()) { +#ifndef CPU_ONLY + contexts[0] = new CPUGPUContext(params, solver_param, &main); +#else + NO_GPU; +#endif + } else { + contexts[0] = new CPUContext(params, solver_param, &main); + } +#ifndef CPU_ONLY + // GPUs + for (int i = 1; i < gpus.size(); ++i) { + solver_param.set_device_id(gpus[i]); + contexts[i] = new CPUGPUContext(params, solver_param); + contexts[i]->start(); + } +#endif + // CPUs + solver_param.set_solver_mode(SolverParameter::CPU); + for (int i = max(1, (int) gpus.size()); i < gpus.size() + cores; ++i) { + contexts[i] = new CPUContext(params, solver_param); + contexts[i]->start(); + } + + // Start monitor + RawMonitor monitor(params, contexts, raw); + monitor.start(); + + // Run main on current thread + contexts[0]->run(); + + monitor.stop(); + LOG(INFO)<< "Monitor stop\n"; + + for (int i = 1; i < contexts.size(); ++i) + contexts[i]->stop(); + + for (int i = 1; i < contexts.size(); ++i) + delete contexts[i]; +} + +#endif diff --git a/examples/parallel/rdma.cpp b/examples/parallel/rdma.cpp new file mode 100644 index 00000000000..13088a3467c --- /dev/null +++ b/examples/parallel/rdma.cpp @@ -0,0 +1,364 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "caffe/caffe.hpp" +#include "caffe/filler.hpp" +#include "caffe/parallel.hpp" +#include "base.hpp" + +#ifdef RDMA +#include + +using namespace std; + +// Trains a net on multiple boxes over InfiniBand or RoCE. RDMA addresses +// are exchanged over a socket, first launch a server instance, then clients: + +// Server: rdma.bin port solver.prototxt :: +// Client: rdma.bin port solver.prototxt :: + +// e.g. 
for 4 machines with 4 GPUs each: +// rdma.bin 3 4444 examples/parallel/mnist_solver.prototxt 0:1:2:3:0 +// then 3 times: +// rdma.bin 4444 examples/parallel/mnist_solver.prototxt 0:1:2:3:0 + +// Monitors solvers and network +class IBMonitor : public Monitor { + public: + IBMonitor(Params& params, const vector& solvers, + const vector*> syncs, + vector*> gpu_params) + : Monitor(params, solvers), + syncs_(syncs), + gpu_params_(gpu_params) { + } + + void stats(const IBChannel& c, ostream& s) { + s << c.adapter() + " "; + c.sent().show(s); + s << ", "; + c.recv().show(s); + } + + void run() { + sleep(5); + time_t start = time(0); + void* d0; + void* d1; + size_t len = gpu_params_[0]->params().len_buff(); + size_t size = len * sizeof(float); + CaffeMallocHost(&d0, size); + CaffeMallocHost(&d1, size); + for (;;) { + sleep(2); + + ostringstream s; + step(&s); + + for (int i = 0; i < syncs_.size(); ++i) { + s << "RDMA " << i << ": ucast "; + stats(syncs_[i]->ucast(), s); + s << ", mcast "; + stats(syncs_[i]->mcast(), s); + s << ", "; + syncs_[i]->cycles().show(s); + s << "\n"; + } + LOG(INFO)<< s.str(); + LOG(INFO)<< "Training time: " << (time(0) - start); + } + } + + const vector*> syncs_; + const vector*> gpu_params_; +}; + +static void exch_server(const int clients, const char* port, + vector* ucast_addrs, + vector* mcast_addrs); +static int exch_client(const char* server, const char* port, + vector* ucast_addrs, + vector* mcast_addrs); + +int main(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); + ::google::InstallFailureSignalHandler(); + + if (argc != 5) { + printf("Usage: ib.bin port solver_proto_file " // + "[gpu_id][:gpu_id][...]:cpu_cores\n"); + return 1; + } + + const int clients = atoi(argv[1]); + const bool server = clients != 0; + char* host = argv[1]; + char* port = argv[2]; + + SolverParameter solver_param; + ReadProtoFromTextFile(argv[3], &solver_param); + + vector procs; + boost::split(procs, argv[4], boost::is_any_of(":")); + vector gpus; + for (int i = 0; i < procs.size() - 1; ++i) + gpus.push_back(atoi(procs[i].c_str())); + int cores = atoi(procs[procs.size() - 1].c_str()); + + // Get IB device + + ibv_device** dev_list; + ibv_device* ib_dev; + dev_list = ibv_get_device_list(NULL); + CHECK(dev_list) << "No IB devices found"; + ib_dev = dev_list[0]; + CHECK(ib_dev) << "No IB devices found"; + + // Create IB channels for exchanging positions and gradients + + const int channels = gpus.size() + (cores > 0 ? 
1 : 0); + vector ucast(channels); + vector mcast(channels); + vector ucast_addrs(channels); + vector mcast_addrs(channels); + for (int i = 0; i < channels; ++i) { + ucast[i] = new IBChannel(ib_dev); + mcast[i] = new IBChannel(ib_dev); + ucast_addrs[i] = ucast[i]->address(); + mcast_addrs[i] = mcast[i]->mcast_create(); + } + + // Exchange IB addresses + + int rank; + if (server) { + if (clients > 0) + exch_server(clients, port, &ucast_addrs, &mcast_addrs); + rank = 0; + } else { + rank = exch_client(host, port, &ucast_addrs, &mcast_addrs); + } + + // Create main solver (first GPU if available, or CPU) + + if (gpus.size()) + solver_param.set_device_id(gpus[0]); + else + solver_param.set_solver_mode(SolverParameter::CPU); + SGDSolver main(solver_param, rank != 0); + + Params params(main.net()->params()); //, "/dev/shm/test"); + + // Create syncs + + bool sync = true; + vector*> gpu_params; + vector*> syncs; + for (int i = 0; i < gpus.size(); ++i) { + gpu_params.push_back(new GPUParams(params, gpus[i])); + syncs.push_back(new GPUIBSync(*gpu_params.back(), rank + i, // + *ucast[i], // + *mcast[i], // + ucast_addrs, // + mcast_addrs)); + if (sync) + syncs.back()->start(); + } + if (cores > 0) { + syncs.push_back(new CPUIBSync(params, rank + gpus.size(), // + *ucast[gpus.size()], // + *mcast[gpus.size()], // + ucast_addrs, // + mcast_addrs)); + syncs.back()->start(); + } + + // Wait for weights to be in sync + + LOG(INFO)<< "Waiting for other boxes\n"; + bool ready = false; + while (sync && !ready) { + sleep(1); + ready = true; + for (int i = 0; i < syncs.size(); ++i) { + if (!syncs[i]->ready()) { + ready = false; + } + } + } + LOG(INFO)<< "Start training\n"; + + // Create contexts + vector contexts(gpus.size() + cores); + if (gpus.size()) { +#ifndef CPU_ONLY + contexts[0] = new GPUContext(params, solver_param, gpu_params[0], &main); +#else + NO_GPU; +#endif + } else { + contexts[0] = new CPUContext(params, solver_param, &main); + } +#ifndef CPU_ONLY + // GPUs + for (int i = 1; i < gpus.size(); ++i) { + solver_param.set_device_id(gpus[i]); + contexts[i] = new GPUContext(params, solver_param, gpu_params[i]); + contexts[i]->start(); + } +#endif + // CPUs + solver_param.set_solver_mode(SolverParameter::CPU); + for (int i = max(1, (int) gpus.size()); i < gpus.size() + cores; ++i) { + contexts[i] = new CPUContext(params, solver_param); + contexts[i]->start(); + } + + // Start monitor + IBMonitor monitor(params, contexts, syncs, gpu_params); + monitor.start(); + + // Run main on current thread + contexts[0]->run(); + + monitor.stop(); + LOG(INFO)<< "Monitor stop\n"; + + for (int i = 1; i < contexts.size(); ++i) + contexts[i]->stop(); + + for (int i = 1; i < contexts.size(); ++i) + delete contexts[i]; + + ibv_free_device_list(dev_list); +} + +// Exchange addresses through socket, c.f. 
IB perftest + +static void exch_server(const int clients, const char* port, + vector* ucast_addrs, + vector* mcast_addrs) { + struct addrinfo *res, *t; + struct addrinfo hints; + memset(&hints, 0, sizeof hints); + hints.ai_flags = AI_PASSIVE; + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + int n = getaddrinfo(NULL, port, &hints, &res); + if (n < 0) { + fprintf(stderr, "%s for port %s\n", gai_strerror(n), port); + return; + } + int s = -1; + for (t = res; t; t = t->ai_next) { + s = socket(t->ai_family, t->ai_socktype, t->ai_protocol); + if (s >= 0) { + int n = 1; + setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &n, sizeof n); + if (!bind(s, t->ai_addr, t->ai_addrlen)) + break; + close(s); + s = -1; + } + } + freeaddrinfo(res); + if (s < 0) { + fprintf(stderr, "Couldn't listen to port %s\n", port); + return; + } + + printf("Listening to port %s\n", port); + listen(s, 1); + vector connections(clients); + vector ranks(clients); + + for (int i = 0; i < connections.size(); ++i) { + connections[i] = accept(s, NULL, 0); + if (connections[i] < 0) { + fprintf(stderr, "accept() failed\n"); + return; + } + LOG(INFO)<< "Client " << i << " of " << connections.size() << " connected\n"; + int count; + CHECK(read(connections[i], &count, sizeof(int)) == sizeof(int)); + ranks[i] = ucast_addrs->size(); + ucast_addrs->resize(ranks[i] + count); + mcast_addrs->resize(ranks[i] + count); + int bytes = sizeof(ib_addr) * count; + CHECK(read(connections[i], &ucast_addrs->at(ranks[i]), bytes) == bytes); + CHECK(read(connections[i], &mcast_addrs->at(ranks[i]), bytes) == bytes); + } + + for (int i = 0; i < connections.size(); ++i) { + int count = ucast_addrs->size(); + CHECK(write(connections[i], &ranks[i], sizeof(int)) == sizeof(int)); + CHECK(write(connections[i], &count, sizeof(int)) == sizeof(int)); + int bytes = sizeof(ib_addr) * count; + CHECK(write(connections[i], &ucast_addrs->at(0), bytes) == bytes); + CHECK(write(connections[i], &mcast_addrs->at(0), bytes) == bytes); + close(connections[i]); + } + + close(s); +} + +static int exch_client(const char* server, const char* port, + vector* ucast_addrs, + vector* mcast_addrs) { + addrinfo *res; + addrinfo hints; + memset(&hints, 0, sizeof hints); + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + int n = getaddrinfo(server, port, &hints, &res); + if (n < 0) { + fprintf(stderr, "%s for %s:%s\n", gai_strerror(n), server, port); + return -1; + } + int s = -1; + for (addrinfo* t = res;; t = t->ai_next) { + s = socket(t->ai_family, t->ai_socktype, t->ai_protocol); + if (s >= 0) { + if (!connect(s, t->ai_addr, t->ai_addrlen)) + break; + close(s); + s = -1; + } + } + freeaddrinfo(res); + if (s < 0) { + fprintf(stderr, "Couldn't connect to %s:%s\n", server, port); + return -1; + } + LOG(INFO)<< "Connected to server\n"; + + int bytes, rank, count = ucast_addrs->size(); + CHECK(write(s, &count, sizeof(int)) == sizeof(int)); + bytes = sizeof(ib_addr) * count; + CHECK(write(s, &ucast_addrs->at(0), bytes) == bytes); + CHECK(write(s, &mcast_addrs->at(0), bytes) == bytes); + + CHECK(read(s, &rank, sizeof(int)) == sizeof(int)); + CHECK(read(s, &count, sizeof(int)) == sizeof(int)); + ucast_addrs->resize(count); + mcast_addrs->resize(count); + bytes = sizeof(ib_addr) * count; + CHECK(read(s, &ucast_addrs->at(0), bytes) == bytes); + CHECK(read(s, &mcast_addrs->at(0), bytes) == bytes); + return rank; +} + +#else +int main(int argc, char *argv[]) { +} +#endif + diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp index 81b2e9ae101..d9d52112aa0 
100644 --- a/include/caffe/common.hpp +++ b/include/caffe/common.hpp @@ -2,6 +2,7 @@ #define CAFFE_COMMON_HPP_ #include +#include #include #include @@ -69,6 +70,7 @@ namespace caffe { // We will use the boost shared_ptr instead of the new C++11 one mainly // because cuda does not work (at least now) well with C++11 features. using boost::shared_ptr; +using boost::thread_specific_ptr; // Common functions and classes from std that caffe often uses. using std::fstream; @@ -95,10 +97,10 @@ class Caffe { public: ~Caffe(); inline static Caffe& Get() { - if (!singleton_.get()) { - singleton_.reset(new Caffe()); + if (!thread_instance_.get()) { + thread_instance_.reset(new Caffe()); } - return *singleton_; + return *(thread_instance_.get()); } enum Brew { CPU, GPU }; enum Phase { TRAIN, TEST }; @@ -161,7 +163,9 @@ class Caffe { Brew mode_; Phase phase_; - static shared_ptr singleton_; + + // Make sure each thread can have different values. + static thread_specific_ptr thread_instance_; private: // The private constructor to avoid duplicate instantiation. diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp index 34b9b30aa3e..4c2736de23a 100644 --- a/include/caffe/data_layers.hpp +++ b/include/caffe/data_layers.hpp @@ -5,7 +5,7 @@ #include #include -#include "boost/scoped_ptr.hpp" +#include "boost/weak_ptr.hpp" #include "hdf5.h" #include "caffe/blob.hpp" @@ -16,9 +16,12 @@ #include "caffe/internal_thread.hpp" #include "caffe/layer.hpp" #include "caffe/proto/caffe.pb.h" +#include "caffe/util/blocking_queue.hpp" namespace caffe { +using boost::weak_ptr; + /** * @brief Provides base for data layers that feed blobs to the Net. * @@ -52,13 +55,18 @@ class BaseDataLayer : public Layer { bool output_labels_; }; +template +class Batch { + public: + Blob data_, label_; +}; + template class BasePrefetchingDataLayer : public BaseDataLayer, public InternalThread { public: - explicit BasePrefetchingDataLayer(const LayerParameter& param) - : BaseDataLayer(param) {} - virtual ~BasePrefetchingDataLayer() {} + explicit BasePrefetchingDataLayer(const LayerParameter& param); + virtual ~BasePrefetchingDataLayer(); // LayerSetUp: implements common data layer setup functionality, and calls // DataLayerSetUp to do special data layer setup for individual layer types. // This method may not be overridden. @@ -70,23 +78,70 @@ class BasePrefetchingDataLayer : virtual void Forward_gpu(const vector*>& bottom, const vector*>& top); - virtual void CreatePrefetchThread(); - virtual void JoinPrefetchThread(); - // The thread's function - virtual void InternalThreadEntry() {} - protected: - Blob prefetch_data_; - Blob prefetch_label_; + virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch) = 0; + + // Prefetches batches (asynchronously if to GPU memory) + static const int PREFETCH_COUNT = 4; + Batch prefetch_[PREFETCH_COUNT]; + blocking_queue*> prefetch_free_; + blocking_queue*> prefetch_full_; + int device_; + Blob transformed_data_; }; +// Single database context per file, prefetches datums to host memory that +// can be read by multiple data layers. 
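As an aside for readers unfamiliar with the pattern, the free/full queue recycling used by BasePrefetchingDataLayer and the DataLoader described above can be sketched in self-contained C++11. BlockingQueue, Batch, and the pool size of 4 below are illustrative stand-ins for the patch's blocking_queue, Batch<Dtype>, and PREFETCH_COUNT, not code from this change:

#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <queue>
#include <thread>

// Minimal blocking queue, same idea as the one added in util/blocking_queue.hpp.
template <typename T>
class BlockingQueue {
 public:
  void push(const T& t) {
    std::unique_lock<std::mutex> lock(mutex_);
    queue_.push(t);
    condition_.notify_one();
  }
  T pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    while (queue_.empty())
      condition_.wait(lock);
    T t = queue_.front();
    queue_.pop();
    return t;
  }
 private:
  std::queue<T> queue_;
  std::mutex mutex_;
  std::condition_variable condition_;
};

struct Batch { int id; };  // stands in for Batch<Dtype> (data_ and label_ blobs)

int main() {
  BlockingQueue<Batch*> free_q, full_q;
  Batch pool[4];                          // like prefetch_[PREFETCH_COUNT]
  for (int i = 0; i < 4; ++i) free_q.push(&pool[i]);

  // Producer: the prefetch thread takes a free batch, fills it, publishes it.
  // It stalls automatically once all batches are in flight (back-pressure).
  std::thread producer([&] {
    for (int i = 0; i < 16; ++i) {
      Batch* b = free_q.pop();            // like prefetch_free_.pop()
      b->id = i;                          // stands in for load_batch(b)
      full_q.push(b);                     // like prefetch_full_.push(b)
    }
  });

  // Consumer: Forward() pops a ready batch, copies it out, then recycles it.
  for (int i = 0; i < 16; ++i) {
    Batch* b = full_q.pop();
    std::printf("consumed batch %d\n", b->id);
    free_q.push(b);
  }
  producer.join();
  return 0;
}

The fixed pool is the point: the loading thread can run ahead by at most the pool size and never allocates while training, which keeps host memory bounded even when several data layers share one DataLoader.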
+class DataLoader { + public: + DataLoader(const DataParameter& param, int index, + blocking_queue* free = NULL, + blocking_queue* full = NULL); + ~DataLoader(); + + inline blocking_queue* free() { + return body_.get()->free_; + } + inline blocking_queue* full() { + return body_.get()->full_; + } + + protected: + class Body: public InternalThread { + public: + Body(const DataParameter& param, int index, + blocking_queue* free, + blocking_queue* full); + ~Body(); + + void InternalThreadEntry(); + + shared_ptr > dataset_; + Dataset::const_iterator iter_; + + blocking_queue* free_; + blocking_queue* full_; + bool own_free_full_; + + DISABLE_COPY_AND_ASSIGN(Body); + }; + + static map > instances_; + static boost::mutex instances_mutex_; + + const string source_; + shared_ptr body_; + + DISABLE_COPY_AND_ASSIGN(DataLoader); +}; + template -class DataLayer : public BasePrefetchingDataLayer { +class DataLayer: public BasePrefetchingDataLayer { public: - explicit DataLayer(const LayerParameter& param) - : BasePrefetchingDataLayer(param) {} - virtual ~DataLayer(); + explicit DataLayer(const LayerParameter& param); + virtual ~DataLayer() {} virtual void DataLayerSetUp(const vector*>& bottom, const vector*>& top); @@ -98,10 +153,11 @@ class DataLayer : public BasePrefetchingDataLayer { virtual inline int MaxTopBlobs() const { return 2; } protected: - virtual void InternalThreadEntry(); + blocking_queue* loaders_free_; + blocking_queue* loaders_full_; - shared_ptr > dataset_; - Dataset::const_iterator iter_; + virtual void load_batch(Batch* batch); + vector > loaders_; }; /** @@ -244,7 +300,7 @@ class ImageDataLayer : public BasePrefetchingDataLayer { protected: shared_ptr prefetch_rng_; virtual void ShuffleImages(); - virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch); vector > lines_; int lines_id_; @@ -317,7 +373,7 @@ class WindowDataLayer : public BasePrefetchingDataLayer { protected: virtual unsigned int PrefetchRand(); - virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch); shared_ptr prefetch_rng_; vector > > image_database_; diff --git a/include/caffe/internal_thread.hpp b/include/caffe/internal_thread.hpp index 6a106e6eefa..dc1ae09ea0b 100644 --- a/include/caffe/internal_thread.hpp +++ b/include/caffe/internal_thread.hpp @@ -26,23 +26,28 @@ class Thread { */ class InternalThread { public: - InternalThread() : thread_(NULL) {} + InternalThread() : thread_(NULL), must_stop_() {} virtual ~InternalThread(); /** Returns true if the thread was successfully started. **/ bool StartInternalThread(); /** Will not return until the internal thread has exited. */ - bool WaitForInternalThreadToExit(); + bool StopInternalThread(); bool is_started() const { return thread_ != NULL && thread_->joinable(); } + bool must_stop() { + return must_stop_; + } + protected: /* Implement this method in your subclass with the code you want your thread to run. 
 */
   virtual void InternalThreadEntry() {}
 
   caffe::Thread* thread_;
+  bool must_stop_;
 };
 
 }  // namespace caffe
diff --git a/include/caffe/parallel.hpp b/include/caffe/parallel.hpp
new file mode 100644
index 00000000000..7032b598821
--- /dev/null
+++ b/include/caffe/parallel.hpp
@@ -0,0 +1,557 @@
+#ifndef CAFFE_PARALLEL_H_
+#define CAFFE_PARALLEL_H_
+
+#include
+#include
+#include
+#include
+#include
+
+#include "caffe/common.hpp"
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/solver.hpp"
+#include "caffe/syncedmem.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/internal_thread.hpp"
+
+using std::deque;
+using boost::dynamic_bitset;
+using boost::posix_time::ptime;
+using boost::posix_time::microsec_clock;
+
+// The following classes enable parallel training, over multiple CPU cores,
+// GPUs, and machines. Gradients are measured and propagated between solvers
+// asynchronously from backprop, to independently max out both networking
+// and compute resources. Only data-parallel training is supported. Models
+// can be trained in parallel without modification.
+
+namespace caffe {
+
+// Helper to write components running in their own threads
+class Threaded : public InternalThread {
+ public:
+  Threaded()
+      : InternalThread() {
+  }
+
+  virtual void start() {
+    this->StartInternalThread();
+  }
+  virtual void stop() {
+    this->StopInternalThread();
+  }
+
+  virtual void run() = 0;
+
+ protected:
+  void InternalThreadEntry() {
+    run();
+  }
+
+DISABLE_COPY_AND_ASSIGN(Threaded);
+};
+
+// Helper for perf metrics
+class Meter {
+ public:
+  // If unit_size is specified, meter will display bandwidth as size * count/s
+  Meter(const string& name, uint64_t unit_size = 0)
+      : name_(name),
+        unit_size_(unit_size),  //
+        value_(),
+        last_(),
+        time_(microsec_clock::local_time()) {
+  }
+
+  inline uint64_t value() const {
+    return value_;
+  }
+  inline void value(uint64_t value) {
+    value_ = value;
+  }
+  inline void operator++(int) {
+    value_++;
+  }
+
+  void show(std::ostream& s) const;
+
+ protected:
+  const string name_;
+  const uint64_t unit_size_;
+  mutable uint64_t value_, last_;
+  mutable ptime time_;  // TODO find a monotonic clock
+
+DISABLE_COPY_AND_ASSIGN(Meter);
+};
+
+// Represents a net's parameters. Once a net is created, its parameter buffers can
+// be replaced by ones from Params, to allow parallelization. E.g. Params ensures
+// that all parameters are allocated in one consecutive array, and that the buffers
+// are sufficiently long for chunking alignments, and potentially other future
+// requirements. Also keeps track of the total iterations on those weights, to get
+// correct hyper-parameter schedules across multiple solvers. TODO keep track
+// of total iterations also between machines.
+template
+class Params {
+ public:
+  // Allocate a buffer compatible with the given blobs, optionally mapped to a
+  // file (/dev/shm) for multi-process configurations or debugging.
+  Params(const vector > >& blobs,  //
+         const string& file_map = "");
+  virtual ~Params();
+
+  inline size_t len_used() const {
+    return len_used_;
+  }
+  inline size_t len_buff() const {
+    return len_buff_;
+  }
+  inline Dtype* cpu() const {
+    return cpu_;
+  }
+  inline int iterations() {
+    return iterations_;
+  }
+  inline void iterations(int value) {
+    iterations_ = value;
+  }
+
+  // Replaces solvers' parameters with the shared buffer. Solvers then run on
+  // the same weights without synchronization (Hogwild). See hogwild.cpp in
+  // /examples for details and BLAS requirements.
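The Hogwild scheme that the comment above refers to is easy to demonstrate in isolation. The following toy sketch (made-up quadratic objective, thread count, and names; not code from this patch) has several threads take SGD steps on one shared, contiguous weight buffer with no locking, which is the role Params plays for the CPU solvers in examples/parallel/hogwild.cpp:

#include <atomic>
#include <cstdio>
#include <random>
#include <thread>
#include <vector>

int main() {
  // One contiguous parameter buffer shared by all "solvers", as Params provides.
  std::vector<float> w(1000, 0.0f);
  const float target = 1.0f, lr = 0.01f;
  std::atomic<int> total_iters(0);        // plays the role of Params::iterations()

  auto solver = [&](unsigned seed) {
    std::minstd_rand rng(seed);
    for (int iter = 0; iter < 2000; ++iter) {
      // Unsynchronized updates to shared weights; occasional lost updates are
      // tolerated by design, which is the Hogwild assumption.
      for (int k = 0; k < 32; ++k) {
        size_t i = rng() % w.size();
        float grad = 2.0f * (w[i] - target);  // gradient of (w[i] - target)^2
        w[i] -= lr * grad;
      }
      ++total_iters;
    }
  };

  std::vector<std::thread> threads;
  for (unsigned t = 0; t < 4; ++t) threads.emplace_back(solver, t + 1);
  for (std::thread& t : threads) t.join();

  std::printf("iters=%d  w[0]=%.3f\n", total_iters.load(), w[0]);
  return 0;
}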
+ void configure(Solver* solver) const; + + protected: + const size_t len_used_; // Actually used + const size_t len_buff_; // Allocated aligned to potential chunks + Dtype* cpu_; + mutable int iterations_; // Total iterations across solvers + + template + friend class GPUParams; + +DISABLE_COPY_AND_ASSIGN(Params); +}; + +#ifndef CPU_ONLY + +// Params on a GPU +template +class GPUParams { + public: + GPUParams(const Params& params, int device); + virtual ~GPUParams(); + void configure(Solver* solver) const; + + inline const Params& params() const { + return params_; + } + inline int device() const { + return device_; + } + inline Dtype* gpu() const { + return gpu_; + } + + protected: + const Params& params_; + const int device_; + Dtype* gpu_; + +DISABLE_COPY_AND_ASSIGN(GPUParams); +}; + +template +class GPUStream { + public: + GPUStream(); + virtual ~GPUStream(); + + const cudaStream_t& stream() const { + return stream_; + } + + protected: + cudaStream_t stream_; + +DISABLE_COPY_AND_ASSIGN(GPUStream); +}; + +// Base class for GPU synchronization. +template +class GPUSync { + protected: + GPUSync(const GPUParams& params); + virtual ~GPUSync(); + + const GPUParams& params_; + Dtype* gpu_last_; + +DISABLE_COPY_AND_ASSIGN(GPUSync); +}; + +// Syncs params between CPU and GPU memory. +template +class CPUGPUSync : // + public GPUSync, // + public Threaded { + + public: + CPUGPUSync(const GPUParams& params); + + virtual ~CPUGPUSync(); + + void run(); + + inline const Meter& calls() const { + return calls_; + } + inline const Meter& cycles() { + return cycles_; + } + + static size_t chunks(const size_t len) { + return (len + CHUNK - 1) / CHUNK; + } + + // TODO bench, auto tune? + static const int CHUNK = 262144; + + protected: + void push(uint32_t chunk); + + const uint32_t chunks_; + + // Perf counters + Meter calls_, cycles_; +}; + +template +void sync_master_kernel(Dtype* gpu, Dtype** grds, size_t* offs, // + int batch_start, int batch_count, // + const cudaStream_t& stream, size_t chunk); + +template +void sync_worker_kernel(Dtype* gpu, Dtype* last, Dtype** pos, size_t* offs, + Dtype** grads, uint8_t* get_grads, // + int batch_start, int batch_count, // + const cudaStream_t& stream, size_t chunk); + +#endif + +// Base class for distributed sync +template +class DistSync { + public: + inline const Meter& cycles() const { + return cycles_; + } + bool ready() { + return remaining_ == 0; + } + + protected: + DistSync(uint32_t nodes, uint32_t chunks); + virtual ~DistSync() { + } + void dist_init(int local); + + // Master node for a given chunk + inline int chunk_master(uint32_t chunk); + + const uint32_t nodes_; + const uint32_t chunks_; + + uint32_t own_start_; // Start of range of chunks for which this node is master + uint32_t own_until_; // End of this range + uint32_t chunk_; // Current chunk sent by master + + // Startup book-keeping, we need to know when nodes are in sync + // TODO replace by transfer of initial weights? 
+ dynamic_bitset<> received_; + uint32_t remaining_; + + // Perf counter + Meter cycles_; + +DISABLE_COPY_AND_ASSIGN(DistSync); +}; + +#ifdef RDMA +#include +#include +#include "caffe/util/multicast_resources.hpp" + +struct ib_addr { + ibv_gid gid; // Only used for multicast addresses + uint16_t lid; + uint32_t qpn; + uint32_t psn; + ibv_ah* ah; +}; + +template +class IBSync; + +class IBChannel { + public: + IBChannel(ibv_device* ib_dev); + ~IBChannel(); + inline const ib_addr& address() const { + return local_; + } + + ib_addr mcast_create() const; + void mcast_join(const ib_addr& addr) const; + void mcast_attach_qp(const ib_addr& addr) const; + + void start(uint8_t* buf_send, size_t buf_size, bool gpu = false) const; + + inline const string adapter() const { + return string(context_->device->dev_name); + } + + // Stats + inline const Meter& sent() const { + return sent_; + } + inline const Meter& recv() const { + return recv_; + } + + static const int MTU = 4096; // TODO get at runtime + static const int FRAMES = 1024; // TODO bench + + protected: + // TODO make port configurable + static const int PORT = 1; + + bool can_send() const; + int send_init(uint8_t*& buf) const; + void send(int id, const ib_addr& addr, uint8_t* buf, uint32_t imm_data) const; + + bool can_recv() const; + int recv(uint8_t*& buf, uint32_t& imm_data) const; + void recv_done(int id) const; + + void poll() const; + + static ibv_context* open_device(ibv_device* ib_dev); + static ibv_pd* alloc_pd(ibv_context*); + + // TODO align recv buffers to CACHE_LINE_SIZE (64) - GRH + static const int GRH = 40; // Global Routing Header + + ibv_context* context_; + ibv_pd* pd_; + ib_addr local_; + ibv_cq* cq_; + ibv_qp* qp_; + + mutable uint8_t* buf_send_; + mutable uint8_t* buf_recv_; + mutable ibv_mr* mr_send_; + mutable ibv_mr* mr_recv_; + + struct recv_msg { + uint32_t id_; + uint32_t imm_; + }; + + mutable ibv_wc wc_[IBChannel::FRAMES * 2]; + mutable deque send_queue_; + mutable deque recv_queue_; + + mutable Meter sent_, recv_; + + template + friend class IBSync; + template + friend class CPUIBSync; + template + friend class GPUIBSync; + +DISABLE_COPY_AND_ASSIGN(IBChannel); +}; + +// Synchronization over InfiniBand +template +class IBSync : public DistSync, public Threaded { + public: + inline const IBChannel& ucast() const { + return ucast_; + } + inline const IBChannel& mcast() const { + return mcast_; + } + + static size_t chunks(const size_t len) { + return (len + CHUNK - 1) / CHUNK; + } + + static const int CHUNK = IBChannel::MTU / sizeof(Dtype); + + protected: + IBSync(const Params& params, int rank, // + const IBChannel& ucast, // + const IBChannel& mcast, // + const vector& ucast_addrs, // + const vector& mcast_addrs); + ~IBSync(); + + const int rank_; + const IBChannel& ucast_; + const IBChannel& mcast_; + vector ucast_addrs_; + ib_addr mcast_addr_; +}; + +// InfiniBand to and from host memory +template +class CPUIBSync : public IBSync { + public: + CPUIBSync(const Params& params, int rank, + const IBChannel& ucast, + const IBChannel& mcast, // + const vector& ucast_addrs, + const vector& mcast_addrs); + ~CPUIBSync(); + void run(); + + protected: + Dtype* cpu_; + Dtype* cpu_last_; +}; + +#ifndef CPU_ONLY + +// InfiniBand to and from GPU memory +template +class GPUIBSync : public GPUSync, public IBSync { + public: + GPUIBSync(const GPUParams& params, int rank, + const IBChannel& ucast, + const IBChannel& mcast, // + const vector& ucast_addrs, + const vector& mcast_addrs); + ~GPUIBSync(); + void run(); + + static const 
int FRAMES = IBChannel::FRAMES; + + protected: + Dtype* gpu_; + Dtype* gpu_last_; +}; + +#endif // not CPU_ONLY + +#endif // RDMA + +#ifdef __linux__ + +#include +#include + +// User-space networking ring buffer. +class Ring { + public: + Ring(const string& adapter, int protocol_send, int protocol_recv); + + ~Ring(); + + inline bool can_send(int frame, struct tpacket_hdr*& hdr); + inline ethhdr* send_init(const struct tpacket_hdr* hdr); + inline void send(struct tpacket_hdr* hdr); + + inline bool can_recv(int frame, struct tpacket_hdr*& hdr); + inline ethhdr* recv(const struct tpacket_hdr* hdr); + inline void recv_done(struct tpacket_hdr* hdr); + + inline const string& adapter() const { + return adapter_; + } + inline int sock() { + return socket_; + } + + // Stats + inline const Meter& sent() const { + return sent_; + } + inline const Meter& recv() const { + return recv_; + } + void socket_stats(uint64_t& received, uint64_t& dropped); + + static const int FRAME_SIZE = 2048; // TODO bench + static const int FRAME_NR = 32; + static const int BLOCK_NR = 1; + + protected: + const string adapter_; + const int socket_; + const uint8_t* map_recv_; + const uint8_t* map_send_; + + Meter sent_, recv_; + +DISABLE_COPY_AND_ASSIGN(Ring); +}; + +// Synchronization using raw sockets and user-space networking. Can be a very +// efficient alternative to RDMA if not available, but cannot read and write +// directly to GPU memory. +// C.f. https://www.kernel.org/doc/Documentation/networking/packet_mmap.txt +template +class RawSync : public DistSync, public Threaded { + public: + RawSync(const Params& params, // + const vector& mac_addresses, // + const vector& secondary_macs); + + ~RawSync(); + + void run(); + + inline const Ring& master() const { + return master_; + } + inline const Ring& worker() const { + return worker_; + } + + static size_t chunks(const size_t len) { + return (len + CHUNK - 1) / CHUNK; + } + + // Offsets of parts of a message + static const int MSG_CHUNK = 0; + static const int MSG_DATA = sizeof(Dtype); + + static const int CHUNK = (ETH_DATA_LEN - MSG_DATA) / sizeof(Dtype); + + protected: + // Next chunk + inline void next(); + + // Currently all nodes are both masters and workers for some chunks, + // so the two vectors should be equal. If machines have two adapters, + // workers can point to the secondary adapters for better performance. + const vector masters_; + const vector workers_; + vector others_; // Workers without the local one + uint32_t other_; // Current node the chunk is sent to + + Ring master_; + Ring worker_; + + Dtype* cpu_; + Dtype* cpu_last_; +}; + +#endif // __linux__ +} + +#endif diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index b20912606f8..6055aa15134 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -17,9 +17,9 @@ namespace caffe { template class Solver { public: - explicit Solver(const SolverParameter& param); + explicit Solver(const SolverParameter& param, bool skip_test_nets); explicit Solver(const string& param_file); - void Init(const SolverParameter& param); + void Init(const SolverParameter& param, bool skip_test_nets); void InitTrainNet(); void InitTestNets(); // The main entry of the solver function. In default, iter will be zero. 
Pass @@ -27,11 +27,14 @@ class Solver { virtual void Solve(const char* resume_file = NULL); inline void Solve(const string resume_file) { Solve(resume_file.c_str()); } virtual ~Solver() {} + inline const SolverParameter& param() const { return param_; } inline shared_ptr > net() { return net_; } inline const vector > >& test_nets() { return test_nets_; } - int iter() { return iter_; } + inline int iter() { return iter_; } + inline int *iter_total() { return iter_total_; } + inline void iter_total(int *value) { iter_total_ = value; } protected: // PreSolve is run before any solving iteration starts, allowing one to @@ -57,6 +60,9 @@ class Solver { SolverParameter param_; int iter_; + // Points to iter_ by default, but can be overriden, e.g. to a global + // counter if multiple solvers contribute iterations to the same model. + int *iter_total_; int current_step_; shared_ptr > net_; vector > > test_nets_; @@ -72,8 +78,8 @@ class Solver { template class SGDSolver : public Solver { public: - explicit SGDSolver(const SolverParameter& param) - : Solver(param) {} + explicit SGDSolver(const SolverParameter& param, bool skip_test_nets = false) + : Solver(param, skip_test_nets) {} explicit SGDSolver(const string& param_file) : Solver(param_file) {} diff --git a/include/caffe/syncedmem.hpp b/include/caffe/syncedmem.hpp index db8d0e80e12..6e6514ffc80 100644 --- a/include/caffe/syncedmem.hpp +++ b/include/caffe/syncedmem.hpp @@ -8,29 +8,22 @@ namespace caffe { -// Theoretically, CaffeMallocHost and CaffeFreeHost should simply call the -// cudaMallocHost and cudaFree functions in order to create pinned memory. -// However, those codes rely on the existence of a cuda GPU (I don't know -// why that is a must since allocating memory should not be accessing the -// GPU resorce, but it just creates an error as of Cuda 5.0) and will cause -// problem when running on a machine without GPU. Thus, we simply define -// these two functions for safety and possible future change if the problem -// of calling cuda functions disappears in a future version. -// -// In practice, although we are creating unpinned memory here, as long as we -// are constantly accessing them the memory pages almost always stays in -// the physical memory (assuming we have large enough memory installed), and -// does not seem to create a memory bottleneck here. - inline void CaffeMallocHost(void** ptr, size_t size) { +#ifndef CPU_ONLY + cudaMallocHost(ptr, size); +#else *ptr = malloc(size); +#endif } inline void CaffeFreeHost(void* ptr) { +#ifndef CPU_ONLY + cudaFreeHost(ptr); +#else free(ptr); +#endif } - /** * @brief Manages memory allocation and synchronization between the host (CPU) * and device (GPU). 
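Context for the change above: only page-locked (pinned) host memory can take part in truly asynchronous cudaMemcpyAsync transfers, which is what the async_gpu_push() helper added below relies on. A minimal standalone sketch of that pattern, assuming CUDA is available; the buffer names are illustrative and not part of this patch:

#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const size_t count = 1 << 20;
  float* host = NULL;
  float* device = NULL;

  // Pinned allocation, mirroring the branch CaffeMallocHost() takes when
  // CPU_ONLY is not defined; a CPU-only build falls back to plain malloc().
  cudaMallocHost(reinterpret_cast<void**>(&host), count * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&device), count * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking);

  // Asynchronous host-to-device copy; with a pageable (malloc) buffer the
  // runtime would stage the transfer and the overlap would be lost.
  cudaMemcpyAsync(device, host, count * sizeof(float),
                  cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);

  cudaStreamDestroy(stream);
  cudaFree(device);
  cudaFreeHost(host);
  std::printf("copied %zu floats via a pinned buffer\n", count);
  return 0;
}
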
@@ -41,20 +34,25 @@ class SyncedMemory { public: SyncedMemory() : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED), - own_cpu_data_(false) {} + own_cpu_data_(false), own_gpu_data_(false) {} explicit SyncedMemory(size_t size) : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED), - own_cpu_data_(false) {} + own_cpu_data_(false), own_gpu_data_(false) {} ~SyncedMemory(); const void* cpu_data(); void set_cpu_data(void* data); const void* gpu_data(); + void set_gpu_data(void* data); void* mutable_cpu_data(); void* mutable_gpu_data(); enum SyncedHead { UNINITIALIZED, HEAD_AT_CPU, HEAD_AT_GPU, SYNCED }; SyncedHead head() { return head_; } size_t size() { return size_; } +#ifndef CPU_ONLY + void async_gpu_push(const cudaStream_t& stream); +#endif + private: void to_cpu(); void to_gpu(); @@ -63,6 +61,7 @@ class SyncedMemory { size_t size_; SyncedHead head_; bool own_cpu_data_; + bool own_gpu_data_; DISABLE_COPY_AND_ASSIGN(SyncedMemory); }; // class SyncedMemory diff --git a/include/caffe/util/blocking_queue.hpp b/include/caffe/util/blocking_queue.hpp new file mode 100644 index 00000000000..a21b31a5395 --- /dev/null +++ b/include/caffe/util/blocking_queue.hpp @@ -0,0 +1,73 @@ +#ifndef CAFFE_UTIL_BLOCKING_QUEUE_H_ +#define CAFFE_UTIL_BLOCKING_QUEUE_H_ + +#include +#include "boost/thread.hpp" + +namespace caffe { + +template +class blocking_queue { +public: + void push(const T& t) { + boost::mutex::scoped_lock lock(mutex_); + queue_.push(t); + lock.unlock(); + condition_.notify_one(); + } + + bool empty() const { + boost::mutex::scoped_lock lock(mutex_); + return queue_.empty(); + } + + bool try_pop(T& t) { + boost::mutex::scoped_lock lock(mutex_); + + if (queue_.empty()) + return false; + + t = queue_.front(); + queue_.pop(); + return true; + } + + T pop(const string& log_on_wait = "") { + boost::mutex::scoped_lock lock(mutex_); + + while (queue_.empty()) { + if (!log_on_wait.empty()) { + time_t now = time(0); + if (now - last_wait_log_ > 5) { + last_wait_log_ = now; + LOG(INFO) << log_on_wait; + } + } + condition_.wait(lock); + } + + T t = queue_.front(); + queue_.pop(); + return t; + } + + // Return element without removing it + T peek() { + boost::mutex::scoped_lock lock(mutex_); + + while (queue_.empty()) + condition_.wait(lock); + + return queue_.front(); + } + +private: + std::queue queue_; + mutable boost::mutex mutex_; + boost::condition_variable condition_; + time_t last_wait_log_; +}; + +} // namespace caffe + +#endif diff --git a/include/caffe/util/multicast_resources.hpp b/include/caffe/util/multicast_resources.hpp new file mode 100644 index 00000000000..ae05b7a39dd --- /dev/null +++ b/include/caffe/util/multicast_resources.hpp @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2009 Mellanox Technologies Ltd. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. 
+ * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * Author: Ido Shamay + */ + +#ifndef MULTICAST_RESOURCES_H +#define MULTICAST_RESOURCES_H + +// Imported from perftest +#ifdef RDMA + + /* Multicast Module for perftest. + * + * Description : + * + * This file contains the structures and methods for implementing a multiple + * multicast groups in user space enviroment. + * The module is in use in "send_bw" and "send_lat" ,but can be used on other + * applications and can generate more methods and serve more benchmarks. + * The Module uses only the structire defined here , enabling generic use of it. + * + * Defined Types : + * + * mcast_parameters - Contains all the parameters needed for this module. + * mcast_group - The multicast group entitiy itself. + * mcg_qp - Is a QP structure that is attahced to the group. + * + */ + + +/************************************************************************ + * Macros , Defines and Files included for work. * + ************************************************************************/ + +#include +#include +//#include "get_clock.h" + +#define QPNUM_MCAST 0xffffff +#define DEF_QKEY 0x11111111 +#define DEF_PKEY_IDX 0 +#define DEF_SLL 0 +#define MAX_POLL_ITERATION_TIMEOUT 1000000 +#define MCG_GID {255,1,0,0,0,2,201,133,0,0,0,0,0,0,0,0} + +// Definitions section for MADs +#define SUBN_ADM_ATTR_MC_MEMBER_RECORD 0x38 +#define MANAGMENT_CLASS_SUBN_ADM 0x03 /* Subnet Administration class */ +#define MCMEMBER_JOINSTATE_FULL_MEMBER 0x1 +#define MAD_SIZE 256 /* The size of a MAD is 256 bytes */ +#define QP1_WELL_KNOWN_Q_KEY 0x80010000 /* Q_Key value of QP1 */ +#define DEF_TRANS_ID 0x12345678 /* TransactionID */ +#define DEF_TCLASS 0 +#define DEF_FLOW_LABLE 0 + +// Macro for 64 bit variables to switch to from net +#ifndef ntohll +#define ntohll(x) (((uint64_t)(ntohl((int)((x << 32) >> 32))) << 32) | (unsigned int)ntohl(((int)(x >> 32)))) +#endif +#ifndef htonll +#define htonll(x) ntohll(x) +#endif + +// generate a bit mask S bits width +#define MASK32(S) ( ((uint32_t) ~0L) >> (32-(S)) ) + +// generate a bit mask with bits O+S..O set (assumes 32 bit integer). +#define BITS32(O,S) ( MASK32(S) << (O) ) + +// extract S bits from (u_int32_t)W with offset O and shifts them O places to the right +#define EXTRACT32(W,O,S) ( ((W)>>(O)) & MASK32(S) ) + +// insert S bits with offset O from field F into word W (u_int32_t) +#define INSERT32(W,F,O,S) (/*(W)=*/ ( ((W) & (~BITS32(O,S)) ) | (((F) & MASK32(S))<<(O)) )) + +#ifndef INSERTF + #define INSERTF(W,O1,F,O2,S) (INSERT32(W, EXTRACT32(F, O2, S), O1, S) ) +#endif + + +// according to Table 187 in the IB spec 1.2.1 +typedef enum { + SUBN_ADM_METHOD_SET = 0x2, + SUBN_ADM_METHOD_DELETE = 0x15 +} subn_adm_method; + +// Utilities for Umad Usage. 
+typedef enum { + SUBN_ADM_COMPMASK_MGID = (1ULL << 0), + SUBN_ADM_COMPMASK_PORT_GID = (1ULL << 1), + SUBN_ADM_COMPMASK_Q_KEY = (1ULL << 2), + SUBN_ADM_COMPMASK_P_KEY = (1ULL << 7), + SUBN_ADM_COMPMASK_TCLASS = (1ULL << 6), + SUBN_ADM_COMPMASK_SL = (1ULL << 12), + SUBN_ADM_COMPMASK_FLOW_LABEL = (1ULL << 13), + SUBN_ADM_COMPMASK_JOIN_STATE = (1ULL << 16), +} subn_adm_component_mask; + +typedef enum { + MCAST_IS_JOINED = 1, + MCAST_IS_ATTACHED = (1 << 1) +} mcast_state; + + +/************************************************************************ + * Multicast data structures. * + ************************************************************************/ + +// Needed parameters for creating a multiple multicast group entity. +struct mcast_parameters { + int num_qps_on_group; + int is_user_mgid; + int mcast_state; + int ib_port; + uint16_t mlid; + uint16_t base_mlid; + const char *user_mgid; + char *ib_devname; + uint16_t pkey; + uint16_t sm_lid; + uint8_t sm_sl; + union ibv_gid port_gid; + union ibv_gid mgid; + // In case it's a latency test. + union ibv_gid base_mgid; + int is_2nd_mgid_used; +}; + +// according to Table 195 in the IB spec 1.2.1 + +struct sa_mad_packet_t { + u_int8_t mad_header_buf[24]; + u_int8_t rmpp_header_buf[12]; + u_int64_t SM_Key; + u_int16_t AttributeOffset; + u_int16_t Reserved1; + u_int64_t ComponentMask; + u_int8_t SubnetAdminData[200]; +}__attribute__((packed)); + +/************************************************************************ + * Multicast resources methods. * + ************************************************************************/ + +/* set_multicast_gid . + * + * Description : + * + * Sets the Multicast GID , and stores it in the "mgid" value of + * mcast resourcs. If the user requested for a specific MGID, which + * is stored in params->user_mgid (in this case params->is_user_mgid should be 1) + * than it will be his MGID, if not the library choose a default one. + * + * Parameters : + * + * params - The parameters of the machine + * my_dest ,rem_dest - The 2 sides that ends the connection. + * + * Return Value : 0 upon success. -1 if it fails. + */ +void set_multicast_gid(struct mcast_parameters *params,uint32_t qp_num,int is_client); + + +/* ctx_close_connection . + * + * Description : + * + * Close the connection between the 2 machines. + * It performs an handshake to ensure the 2 sides are there. + * + * Parameters : + * + * params - The parameters of the machine + * my_dest ,rem_dest - The 2 sides that ends the connection. + * + * Return Value : 0 upon success. -1 if it fails. 
+ */ +int join_multicast_group(subn_adm_method method,struct mcast_parameters *params); + +#endif /* RDMA */ + +#endif /* MULTICAST_RESOURCES_H */ diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index 834d5694aad..e732c189b64 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -7,7 +7,7 @@ namespace caffe { -shared_ptr Caffe::singleton_; +thread_specific_ptr Caffe::thread_instance_; // random seeding int64_t cluster_seedgen(void) { diff --git a/src/caffe/internal_thread.cpp b/src/caffe/internal_thread.cpp index d7b6ae206cf..5a32fd38e93 100644 --- a/src/caffe/internal_thread.cpp +++ b/src/caffe/internal_thread.cpp @@ -5,16 +5,17 @@ namespace caffe { InternalThread::~InternalThread() { - WaitForInternalThreadToExit(); + StopInternalThread(); if (thread_ != NULL) { delete thread_; } } bool InternalThread::StartInternalThread() { - if (!WaitForInternalThreadToExit()) { + if (!StopInternalThread()) { return false; } + must_stop_ = false; try { thread_ = new caffe::Thread (&InternalThread::InternalThreadEntry, this); @@ -25,7 +26,8 @@ bool InternalThread::StartInternalThread() { } /** Will not return until the internal thread has exited. */ -bool InternalThread::WaitForInternalThreadToExit() { +bool InternalThread::StopInternalThread() { + must_stop_ = true; if (is_started()) { try { thread_->join(); diff --git a/src/caffe/layers/base_data_layer.cpp b/src/caffe/layers/base_data_layer.cpp index eb0aaf82120..3bcb075889d 100644 --- a/src/caffe/layers/base_data_layer.cpp +++ b/src/caffe/layers/base_data_layer.cpp @@ -26,52 +26,95 @@ void BaseDataLayer::LayerSetUp(const vector*>& bottom, data_transformer_.InitRand(); } +template +BasePrefetchingDataLayer::BasePrefetchingDataLayer( + const LayerParameter& param) + : BaseDataLayer(param), + prefetch_free_(), prefetch_full_(), device_() { + for(int i = 0; i < PREFETCH_COUNT; ++i) + prefetch_free_.push(&prefetch_[i]); +} + template void BasePrefetchingDataLayer::LayerSetUp( const vector*>& bottom, const vector*>& top) { BaseDataLayer::LayerSetUp(bottom, top); - // Now, start the prefetch thread. Before calling prefetch, we make two - // cpu_data calls so that the prefetch thread does not accidentally make - // simultaneous cudaMalloc calls when the main thread is running. In some - // GPUs this seems to cause failures if we do not so. - this->prefetch_data_.mutable_cpu_data(); - if (this->output_labels_) { - this->prefetch_label_.mutable_cpu_data(); + + // Before starting the prefetch thread, we make cpu_data and gpu_data + // calls so that the prefetch thread does not accidentally make simultaneous + // cudaMalloc calls when the main thread is running. In some GPUs this + // seems to cause failures if we do not so. 
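Beyond the warm-up described in the comment above, the heart of the reworked layer is the pair of queues: the prefetch thread pops an empty Batch from prefetch_free_, fills it, and pushes it onto prefetch_full_, while Forward_cpu/Forward_gpu does the reverse. A rough standalone sketch of that cycle, using std::thread primitives in place of the boost-based blocking_queue this patch adds; Batch is reduced to a plain struct for illustration:

#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <queue>
#include <thread>

struct Batch { int data; };

template <typename T>
class BlockingQueue {
 public:
  void push(T t) {
    { std::lock_guard<std::mutex> lock(mutex_); queue_.push(t); }
    cond_.notify_one();
  }
  T pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    cond_.wait(lock, [this] { return !queue_.empty(); });
    T t = queue_.front(); queue_.pop(); return t;
  }
 private:
  std::queue<T> queue_;
  std::mutex mutex_;
  std::condition_variable cond_;
};

int main() {
  BlockingQueue<Batch*> free_q, full_q;
  Batch batches[3];                        // like prefetch_[PREFETCH_COUNT]
  for (int i = 0; i < 3; ++i) free_q.push(&batches[i]);

  std::thread prefetch([&] {               // plays the role of InternalThreadEntry
    for (int iter = 0; iter < 10; ++iter) {
      Batch* b = free_q.pop();             // reuse a buffer Forward handed back
      b->data = iter;                      // "load_batch"
      full_q.push(b);
    }
  });

  for (int iter = 0; iter < 10; ++iter) {  // plays the role of Forward_cpu
    Batch* b = full_q.pop();
    std::printf("consumed batch %d\n", b->data);
    free_q.push(b);                        // return the buffer for reuse
  }
  prefetch.join();
  return 0;
}

With several buffers in flight, the consumer only stalls if loading a batch takes longer than consuming one.
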
+ for(int i = 0; i < PREFETCH_COUNT; ++i) { + prefetch_[i].data_.mutable_cpu_data(); + if (this->output_labels_) { + prefetch_[i].label_.mutable_cpu_data(); + } } + switch (Caffe::mode()) { + case Caffe::CPU: + device_ = -1; + break; + case Caffe::GPU: +#ifndef CPU_ONLY + for(int i = 0; i < PREFETCH_COUNT; ++i) { + prefetch_[i].data_.mutable_gpu_data(); + if (this->output_labels_) { + prefetch_[i].label_.mutable_gpu_data(); + } + } + CUDA_CHECK(cudaGetDevice(&device_)); +#endif + break; + } + DLOG(INFO) << "Initializing prefetch"; - this->CreatePrefetchThread(); + this->phase_ = Caffe::phase(); + this->data_transformer_.InitRand(); + CHECK(StartInternalThread()) << "Thread execution failed"; DLOG(INFO) << "Prefetch initialized."; } template -void BasePrefetchingDataLayer::CreatePrefetchThread() { - this->phase_ = Caffe::phase(); - this->data_transformer_.InitRand(); - CHECK(StartInternalThread()) << "Thread execution failed"; +BasePrefetchingDataLayer::~BasePrefetchingDataLayer() { + CHECK(StopInternalThread()) << "Stop thread failed"; } template -void BasePrefetchingDataLayer::JoinPrefetchThread() { - CHECK(WaitForInternalThreadToExit()) << "Thread joining failed"; +void BasePrefetchingDataLayer::InternalThreadEntry() { +#ifndef CPU_ONLY + cudaStream_t stream; + if(device_ >= 0) { + CUDA_CHECK(cudaSetDevice(device_)); + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + } +#endif + + while(!must_stop()) { + Batch* batch = prefetch_free_.pop(); + load_batch(batch); +#ifndef CPU_ONLY + if(device_ >= 0) { + batch->data_.data().get()->async_gpu_push(stream); + cudaStreamSynchronize(stream); + } +#endif + prefetch_full_.push(batch); + } } template void BasePrefetchingDataLayer::Forward_cpu( const vector*>& bottom, const vector*>& top) { - // First, join the thread - JoinPrefetchThread(); - DLOG(INFO) << "Thread joined"; - // Copy the data - caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(), - top[0]->mutable_cpu_data()); - DLOG(INFO) << "Prefetch copied"; + Batch* batch = prefetch_full_.pop("Data layer prefetch queue empty"); + + caffe_copy(batch->data_.count(), batch->data_.cpu_data(), + top[0]->mutable_cpu_data()); if (this->output_labels_) { - caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(), - top[1]->mutable_cpu_data()); + caffe_copy(batch->label_.count(), batch->label_.cpu_data(), + top[1]->mutable_cpu_data()); } - // Start a new prefetch thread - DLOG(INFO) << "CreatePrefetchThread"; - CreatePrefetchThread(); + + prefetch_free_.push(batch); } #ifdef CPU_ONLY @@ -79,6 +122,7 @@ STUB_GPU_FORWARD(BasePrefetchingDataLayer, Forward); #endif INSTANTIATE_CLASS(BaseDataLayer); +INSTANTIATE_CLASS(Batch); INSTANTIATE_CLASS(BasePrefetchingDataLayer); } // namespace caffe diff --git a/src/caffe/layers/base_data_layer.cu b/src/caffe/layers/base_data_layer.cu index 204a16d260a..0e0f26a9791 100644 --- a/src/caffe/layers/base_data_layer.cu +++ b/src/caffe/layers/base_data_layer.cu @@ -7,17 +7,16 @@ namespace caffe { template void BasePrefetchingDataLayer::Forward_gpu( const vector*>& bottom, const vector*>& top) { - // First, join the thread - JoinPrefetchThread(); - // Copy the data - caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(), + Batch* batch = prefetch_full_.pop("Data layer prefetch queue empty"); + + caffe_copy(batch->data_.count(), batch->data_.gpu_data(), top[0]->mutable_gpu_data()); if (this->output_labels_) { - caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(), + caffe_copy(batch->label_.count(), batch->label_.gpu_data(), 
top[1]->mutable_gpu_data()); } - // Start a new prefetch thread - CreatePrefetchThread(); + + prefetch_free_.push(batch); } INSTANTIATE_LAYER_GPU_FORWARD(BasePrefetchingDataLayer); diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index 5d6f05802ce..c2f502897cb 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -1,7 +1,9 @@ #include +#include #include #include +#include #include "caffe/common.hpp" #include "caffe/data_layers.hpp" @@ -15,28 +17,55 @@ namespace caffe { -template -DataLayer::~DataLayer() { - this->JoinPrefetchThread(); - // clean up the dataset resources - dataset_->close(); +map > DataLoader::instances_; +boost::mutex DataLoader::instances_mutex_; + +DataLoader::DataLoader(const DataParameter& param, int index, + blocking_queue* free, + blocking_queue* full): + source_(param.source(index)) { + boost::mutex::scoped_lock lock(instances_mutex_); + weak_ptr body = instances_[source_]; + body_ = body.lock(); + if (body_) { + CHECK(!free || free == body_.get()->free_); + CHECK(!full || full == body_.get()->full_); + } else { + body_.reset(new Body(param, index, free, full)); + instances_[source_] = weak_ptr(body_); + } } -template -void DataLayer::DataLayerSetUp(const vector*>& bottom, - const vector*>& top) { +DataLoader::~DataLoader() { + boost::mutex::scoped_lock lock(instances_mutex_); + body_.reset(); + if (instances_[source_].expired()) + instances_.erase(source_); +} + +DataLoader::Body::Body(const DataParameter& param, int index, + blocking_queue* free, + blocking_queue* full) : + free_(free), + full_(full), + own_free_full_() { + + // Initialize queues + if(!free_) { + free_ = new blocking_queue(); + full_ = new blocking_queue(); + own_free_full_ = true; + } + // Initialize DB - dataset_ = DatasetFactory( - this->layer_param_.data_param().backend()); - const string& source = this->layer_param_.data_param().source(); - LOG(INFO) << "Opening dataset " << source; - CHECK(dataset_->open(source, Dataset::ReadOnly)); + dataset_ = DatasetFactory(param.backend()); + LOG(INFO) << "Opening dataset " << param.source(index); + CHECK(dataset_->open(param.source(index), Dataset::ReadOnly)); iter_ = dataset_->begin(); - // Check if we would need to randomly skip a few data points - if (this->layer_param_.data_param().rand_skip()) { - unsigned int skip = caffe_rng_rand() % - this->layer_param_.data_param().rand_skip(); + // Check if we need to randomly skip a few data points + if (param.rand_skip()) { + unsigned int skip = caffe_rng_rand() % param.rand_skip(); LOG(INFO) << "Skipping first " << skip << " data points."; while (skip-- > 0) { if (++iter_ == dataset_->end()) { @@ -44,63 +73,128 @@ void DataLayer::DataLayerSetUp(const vector*>& bottom, } } } - // Read a data point, and use it to initialize the top blob. 
- CHECK(iter_ != dataset_->end()); - Datum datum = iter_->value; - if (DecodeDatum(&datum)) { - LOG(INFO) << "Decoding Datum"; + // Add prefetch datums to layer free queue + int prefetch = param.prefetch() * param.batch_size(); + for(int i = 0; i < prefetch; ++i) { + free_->push(new Datum()); + } + + CHECK(StartInternalThread()) << "DataLoader thread start failed"; +} + +DataLoader::Body::~Body() { + CHECK(StopInternalThread()) << "DataLoader thread stop failed"; + Datum* datum; + while(free_->try_pop(datum)) { + delete datum; + } + while(full_->try_pop(datum)) { + delete datum; + } + + // clean up the dataset resources + dataset_->close(); + + if(own_free_full_) { + delete free_; + delete full_; + } +} + +void DataLoader::Body::InternalThreadEntry() { + while(!must_stop()) { + CHECK(iter_ != dataset_->end()); + + Datum* datum = free_->pop(); + // TODO deserialize in-place instead of copy? + datum->CopyFrom(iter_->value); + full_->push(datum); + + ++iter_; + if (iter_ == dataset_->end()) { + iter_ = dataset_->begin(); + } } +} + +template +DataLayer::DataLayer(const LayerParameter& param) + : BasePrefetchingDataLayer(param) { + DataLoader* ld = new DataLoader(param.data_param(), 0); + loaders_.push_back(shared_ptr(ld)); + loaders_free_ = ld->free(); + loaders_full_ = ld->full(); + + // Loaders share queues in case of multiple sources + for(int i = 1; i < param.data_param().source().size(); ++i) { + ld = new DataLoader(param.data_param(), i, loaders_free_, loaders_full_); + loaders_.push_back(shared_ptr(ld)); + } +} + +template +void DataLayer::DataLayerSetUp(const vector*>& bottom, + const vector*>& top) { + + // Look at first data point to initialize the top blob. + Datum* datum = loaders_full_->peek(); + + if (DecodeDatum(datum)) + LOG(INFO) << "Decoding Datum"; + // image - int crop_size = this->layer_param_.transform_param().crop_size(); + const int crop_size = this->layer_param_.transform_param().crop_size(); + const int batch_size = this->layer_param_.data_param().batch_size(); if (crop_size > 0) { - top[0]->Reshape(this->layer_param_.data_param().batch_size(), - datum.channels(), crop_size, crop_size); - this->prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(), - datum.channels(), crop_size, crop_size); - this->transformed_data_.Reshape(1, datum.channels(), crop_size, crop_size); + top[0]->Reshape(batch_size, datum->channels(), crop_size, crop_size); + for(int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].data_.Reshape(batch_size, datum->channels(), + crop_size, crop_size); + } + this->transformed_data_.Reshape(1, datum->channels(), + crop_size, crop_size); } else { - top[0]->Reshape( - this->layer_param_.data_param().batch_size(), datum.channels(), - datum.height(), datum.width()); - this->prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(), - datum.channels(), datum.height(), datum.width()); - this->transformed_data_.Reshape(1, datum.channels(), - datum.height(), datum.width()); + top[0]->Reshape(batch_size, datum->channels(), + datum->height(), datum->width()); + for(int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].data_.Reshape(batch_size, datum->channels(), + datum->height(), datum->width()); + } + this->transformed_data_.Reshape(1, datum->channels(), + datum->height(), datum->width()); } LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," << top[0]->width(); // label if (this->output_labels_) { - 
top[1]->Reshape(this->layer_param_.data_param().batch_size(), 1, 1, 1); - this->prefetch_label_.Reshape(this->layer_param_.data_param().batch_size(), - 1, 1, 1); + top[1]->Reshape(batch_size, 1, 1, 1); + for(int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].label_.Reshape(batch_size, 1, 1, 1); + } } } -// This function is used to create a thread that prefetches the data. +// This function is called on prefetch thread template -void DataLayer::InternalThreadEntry() { +void DataLayer::load_batch(Batch* batch) { CPUTimer batch_timer; batch_timer.Start(); double read_time = 0; double trans_time = 0; CPUTimer timer; - CHECK(this->prefetch_data_.count()); + CHECK(batch->data_.count()); CHECK(this->transformed_data_.count()); - Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); + Dtype* top_data = batch->data_.mutable_cpu_data(); Dtype* top_label = NULL; // suppress warnings about uninitialized variables + if (this->output_labels_) + top_label = batch->label_.mutable_cpu_data(); - if (this->output_labels_) { - top_label = this->prefetch_label_.mutable_cpu_data(); - } const int batch_size = this->layer_param_.data_param().batch_size(); for (int item_id = 0; item_id < batch_size; ++item_id) { timer.Start(); - // get a blob - CHECK(iter_ != dataset_->end()); - const Datum& datum = iter_->value; + const Datum& datum = *(loaders_full_->pop("Waiting on data loader")); cv::Mat cv_img; if (datum.encoded()) { @@ -110,7 +204,7 @@ void DataLayer::InternalThreadEntry() { timer.Start(); // Apply data transformations (mirror, scale, crop...) - int offset = this->prefetch_data_.offset(item_id); + int offset = batch->data_.offset(item_id); this->transformed_data_.set_cpu_data(top_data + offset); if (datum.encoded()) { this->data_transformer_.Transform(cv_img, &(this->transformed_data_)); @@ -121,16 +215,13 @@ void DataLayer::InternalThreadEntry() { top_label[item_id] = datum.label(); } trans_time += timer.MicroSeconds(); - // go to the next iter - ++iter_; - if (iter_ == dataset_->end()) { - iter_ = dataset_->begin(); - } + + loaders_free_->push((Datum*) &datum); } batch_timer.Stop(); - DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms."; - DLOG(INFO) << " Read time: " << read_time / 1000 << " ms."; - DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms."; +// DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms."; +// DLOG(INFO) << " Read time: " << read_time / 1000 << " ms."; +// DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms."; } INSTANTIATE_CLASS(DataLayer); diff --git a/src/caffe/layers/image_data_layer.cpp b/src/caffe/layers/image_data_layer.cpp index 50997a23bf9..16a27bd5324 100644 --- a/src/caffe/layers/image_data_layer.cpp +++ b/src/caffe/layers/image_data_layer.cpp @@ -15,7 +15,7 @@ namespace caffe { template ImageDataLayer::~ImageDataLayer() { - this->JoinPrefetchThread(); + this->InternalThread::StopInternalThread(); } template @@ -68,11 +68,14 @@ void ImageDataLayer::DataLayerSetUp(const vector*>& bottom, const int batch_size = this->layer_param_.image_data_param().batch_size(); if (crop_size > 0) { top[0]->Reshape(batch_size, channels, crop_size, crop_size); - this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size); + for(int i = 0; i < this->PREFETCH_COUNT; ++i) + this->prefetch_[i].data_.Reshape(batch_size, channels, + crop_size, crop_size); this->transformed_data_.Reshape(1, channels, crop_size, crop_size); } else { top[0]->Reshape(batch_size, channels, height, width); - 
this->prefetch_data_.Reshape(batch_size, channels, height, width); + for(int i = 0; i < this->PREFETCH_COUNT; ++i) + this->prefetch_[i].data_.Reshape(batch_size, channels, height, width); this->transformed_data_.Reshape(1, channels, height, width); } LOG(INFO) << "output data size: " << top[0]->num() << "," @@ -80,7 +83,8 @@ void ImageDataLayer::DataLayerSetUp(const vector*>& bottom, << top[0]->width(); // label top[1]->Reshape(batch_size, 1, 1, 1); - this->prefetch_label_.Reshape(batch_size, 1, 1, 1); + for(int i = 0; i < this->PREFETCH_COUNT; ++i) + this->prefetch_[i].label_.Reshape(batch_size, 1, 1, 1); } template @@ -90,18 +94,18 @@ void ImageDataLayer::ShuffleImages() { shuffle(lines_.begin(), lines_.end(), prefetch_rng); } -// This function is used to create a thread that prefetches the data. +// This function is called on prefetch thread template -void ImageDataLayer::InternalThreadEntry() { +void ImageDataLayer::load_batch(Batch* batch) { CPUTimer batch_timer; batch_timer.Start(); double read_time = 0; double trans_time = 0; CPUTimer timer; - CHECK(this->prefetch_data_.count()); + CHECK(batch->data_.count()); CHECK(this->transformed_data_.count()); - Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); - Dtype* top_label = this->prefetch_label_.mutable_cpu_data(); + Dtype* top_data = batch->data_.mutable_cpu_data(); + Dtype* top_label = batch->label_.mutable_cpu_data(); ImageDataParameter image_data_param = this->layer_param_.image_data_param(); const int batch_size = image_data_param.batch_size(); const int new_height = image_data_param.new_height(); @@ -123,7 +127,7 @@ void ImageDataLayer::InternalThreadEntry() { read_time += timer.MicroSeconds(); timer.Start(); // Apply transformations (mirror, crop...) to the image - int offset = this->prefetch_data_.offset(item_id); + int offset = batch->data_.offset(item_id); this->transformed_data_.set_cpu_data(top_data + offset); this->data_transformer_.Transform(cv_img, &(this->transformed_data_)); trans_time += timer.MicroSeconds(); diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp index 6287b385dc5..45ee2d404a8 100644 --- a/src/caffe/layers/window_data_layer.cpp +++ b/src/caffe/layers/window_data_layer.cpp @@ -30,7 +30,7 @@ namespace caffe { template WindowDataLayer::~WindowDataLayer() { - this->JoinPrefetchThread(); + this->InternalThread::StopInternalThread(); } template @@ -174,14 +174,16 @@ void WindowDataLayer::DataLayerSetUp(const vector*>& bottom, CHECK_GT(crop_size, 0); const int batch_size = this->layer_param_.window_data_param().batch_size(); top[0]->Reshape(batch_size, channels, crop_size, crop_size); - this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size); + for(int i = 0; i < this->PREFETCH_COUNT; ++i) + this->prefetch_[i].data_.Reshape(batch_size, channels, crop_size, crop_size); LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," << top[0]->width(); // label top[1]->Reshape(batch_size, 1, 1, 1); - this->prefetch_label_.Reshape(batch_size, 1, 1, 1); + for(int i = 0; i < this->PREFETCH_COUNT; ++i) + this->prefetch_[i].label_.Reshape(batch_size, 1, 1, 1); // data mean has_mean_file_ = this->transform_param_.has_mean_file(); @@ -219,9 +221,9 @@ unsigned int WindowDataLayer::PrefetchRand() { return (*prefetch_rng)(); } -// Thread fetching the data +// This function is called on prefetch thread template -void WindowDataLayer::InternalThreadEntry() { +void WindowDataLayer::load_batch(Batch* batch) { // 
At each iteration, sample N windows where N*p are foreground (object) // windows and N*(1-p) are background (non-object) windows CPUTimer batch_timer; @@ -229,8 +231,8 @@ void WindowDataLayer::InternalThreadEntry() { double read_time = 0; double trans_time = 0; CPUTimer timer; - Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); - Dtype* top_label = this->prefetch_label_.mutable_cpu_data(); + Dtype* top_data = batch->data_.mutable_cpu_data(); + Dtype* top_label = batch->label_.mutable_cpu_data(); const Dtype scale = this->layer_param_.window_data_param().scale(); const int batch_size = this->layer_param_.window_data_param().batch_size(); const int context_pad = this->layer_param_.window_data_param().context_pad(); @@ -254,7 +256,7 @@ void WindowDataLayer::InternalThreadEntry() { bool use_square = (crop_mode == "square") ? true : false; // zero out batch - caffe_set(this->prefetch_data_.count(), Dtype(0), top_data); + caffe_set(batch->data_.count(), Dtype(0), top_data); const int num_fg = static_cast(static_cast(batch_size) * fg_fraction); diff --git a/src/caffe/lmdb_dataset.cpp b/src/caffe/lmdb_dataset.cpp index 8f8e68e901e..f9a066813c4 100644 --- a/src/caffe/lmdb_dataset.cpp +++ b/src/caffe/lmdb_dataset.cpp @@ -57,10 +57,21 @@ bool LmdbDataset::open(const string& filename, int flag1 = 0; int flag2 = 0; if (mode == Base::ReadOnly) { - flag1 = MDB_RDONLY | MDB_NOTLS; + // No locking, assume db is not written to at the same time, otherwise + // LMDB tries to lock the file, which fails if it's read-only + flag1 = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK; flag2 = MDB_RDONLY; } + // Allow DB to be stand-alone file + { + struct stat st_buf; + stat (filename.c_str(), &st_buf); + if (S_ISREG (st_buf.st_mode)) { + flag1 |= MDB_NOSUBDIR; + } + } + retval = mdb_env_open(env_, filename.c_str(), flag1, 0664); if (MDB_SUCCESS != retval) { LOG(ERROR) << "mdb_env_open failed " << mdb_strerror(retval); diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp new file mode 100644 index 00000000000..d9150752738 --- /dev/null +++ b/src/caffe/parallel.cpp @@ -0,0 +1,1277 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "caffe/filler.hpp" +#include "caffe/parallel.hpp" + +using namespace std; + +namespace caffe { + +void Meter::show(std::ostream& s) const { + ptime now = microsec_clock::local_time(); + uint64_t value = value_; + uint64_t delta = value - last_; + uint64_t u_sec = (now - time_).total_microseconds(); + double per_s = delta * 1e6 / (u_sec ? 
u_sec : 1); + last_ = value; + time_ = now; + s << name_ << " " << value << " ("; + if (unit_size_) + s << (int) (per_s * unit_size_ / (1024 * 1024)) << " mb"; + else + s << std::setprecision(2) << per_s; + s << "/s)"; +} + +// + +template +static size_t len(const vector > >& params) { + size_t len = 0; + for (int i = 0; i < params.size(); ++i) + len += params[i]->count(); + return len; +} + +// Align arrays to all potential chunk sizes to avoid boundary checks +template +static size_t align(const size_t len) { + size_t m = len; +#ifndef CPU_ONLY + m = max(m, CPUGPUSync::chunks(len) * CPUGPUSync::CHUNK); +#endif +#ifdef __linux__ + m = max(m, RawSync::chunks(len) * RawSync::CHUNK); +#endif +#ifdef INFINIBAND + m = max(m, IBSync::chunks(len) * IBSync::CHUNK); +#endif + return m; +} + +template +Params::Params(const vector > >& blobs, + const string& file_map) + : len_used_(len(blobs)), + len_buff_(align(len_used_)) { + + bool exists = false; + if (file_map.empty()) { + CaffeMallocHost((void**) &cpu_, len_buff_ * sizeof(Dtype)); + memset(cpu_, 0, len_buff_ * sizeof(Dtype)); + } else { + struct stat st_buf; + exists = stat(file_map.c_str(), &st_buf) == 0; + int fd = open(file_map.c_str(), O_RDWR | O_CREAT, // + S_IRWXU | S_IRWXG | S_IRWXO); + CHECK(!ftruncate(fd, len_buff_ * sizeof(Dtype))); + cpu_ = (Dtype*) mmap(NULL, // + len_buff_ * sizeof(Dtype), + PROT_READ | PROT_WRITE, + MAP_SHARED, fd, 0); + close(fd); + } + + Dtype* cpu = cpu_; + for (int i = 0; i < blobs.size(); ++i) { + int size = blobs[i]->data()->size(); + // Init to current values of blobs if file doesn't already exists + if (!exists) + memcpy(cpu, blobs[i]->data()->cpu_data(), size); + cpu += size / sizeof(Dtype); + CHECK(cpu <= cpu_ + len_used_); + } + size_t check = 0; + for (int i = 0; i < blobs.size(); ++i) + check += blobs[i]->count(); + Dtype* expect = cpu_ + check; + CHECK_EQ(expect, cpu); + + iterations_ = 0; +} + +template +Params::~Params() { + CaffeFreeHost((void*) cpu_); +} + +template +void Params::configure(Solver* solver) const { + // Replace weights + vector > > &blobs = solver->net()->params(); + Dtype* cpu = cpu_; + for (int i = 0; i < blobs.size(); ++i) { + blobs[i]->data()->set_cpu_data(cpu); + cpu += blobs[i]->data()->size() / sizeof(Dtype); + CHECK(cpu <= cpu_ + len_used_); + } + // Check sizes + size_t check = 0; + for (int i = 0; i < blobs.size(); ++i) + check += blobs[i]->count(); + Dtype* expect = cpu_ + check; + CHECK_EQ(expect, cpu); + + solver->iter_total(&iterations_); +} + +// + +#ifndef CPU_ONLY +#include + +template +GPUParams::GPUParams(const Params& params, int device) + : params_(params), + device_(device) { + + int current; + CUDA_CHECK(cudaGetDevice(¤t)); + CUDA_CHECK(cudaSetDevice(device)); + const size_t size = params.len_buff() * sizeof(Dtype); + CUDA_CHECK(cudaMalloc((void** ) &gpu_, size)); + CUDA_CHECK(cudaMemcpy(gpu_, params.cpu(), size, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaSetDevice(current)); +} + +template +GPUParams::~GPUParams() { + CUDA_CHECK(cudaFree((void* ) gpu_)); +} + +template +void GPUParams::configure(Solver* solver) const { + // Replace GPU weights + vector > > &blobs = solver->net()->params(); + Dtype* gpu = gpu_; + for (int i = 0; i < blobs.size(); ++i) { + blobs[i]->data()->set_gpu_data(gpu); + gpu += blobs[i]->data()->size() / sizeof(Dtype); + CHECK(gpu <= gpu_ + params_.len_used()); + } + size_t check = 0; + for (int i = 0; i < blobs.size(); ++i) + check += blobs[i]->count(); + Dtype* expect = gpu_ + check; + CHECK_EQ(expect, gpu); + + 
solver->iter_total(¶ms_.iterations_); +} + +// + +template +GPUStream::GPUStream() { + int least, greatest; + cudaDeviceGetStreamPriorityRange(&least, &greatest); + cudaStreamCreateWithPriority(&stream_, cudaStreamNonBlocking, least); +} + +template +GPUStream::~GPUStream() { + cudaStreamDestroy(stream_); +} + +// + +template +GPUSync::GPUSync(const GPUParams& params) + : params_(params) { + + size_t size = params.params().len_buff() * sizeof(Dtype); + Dtype* gpu = params.gpu(); + CUDA_CHECK(cudaMalloc((void** ) &gpu_last_, size)); + CUDA_CHECK(cudaMemcpy(gpu_last_, gpu, size, cudaMemcpyDeviceToDevice)); +} + +template +GPUSync::~GPUSync() { + CUDA_CHECK(cudaFree((void* ) gpu_last_)); +} + +// + +template +CPUGPUSync::CPUGPUSync(const GPUParams& params) + : GPUSync(params), + chunks_(chunks(params.params().len_used())), + calls_("calls", CHUNK * sizeof(Dtype)), + cycles_("cycles") { +} + +template +CPUGPUSync::~CPUGPUSync() { + stop(); +} + +template +void CPUGPUSync::run() { + CUDA_CHECK(cudaSetDevice(this->params_.device())); + GPUStream gpu_stream; + const cudaStream_t& stream = gpu_stream.stream(); + + // Current cpu values when invoking kernel, gradients on the way back + Dtype* buf; + Dtype* tmp; + CUDA_CHECK(cudaMalloc((void** ) &buf, CHUNK * sizeof(Dtype))); + CaffeMallocHost((void**) &tmp, CHUNK * sizeof(Dtype)); + + const size_t len = CHUNK * sizeof(Dtype); + // Explicit directions for readability + const cudaMemcpyKind put = cudaMemcpyHostToDevice; + const cudaMemcpyKind get = cudaMemcpyDeviceToHost; + uint32_t index = 0; + Dtype* cpu = this->params_.params().cpu(); + Dtype* gpu = this->params_.gpu(); + Dtype* last = this->gpu_last_; + uint8_t get_grads = true; + + while (!must_stop()) { + size_t off = index * CHUNK; + CUDA_CHECK(cudaMemcpyAsync(buf, &cpu[off], len, put, stream)); + // TODO simpler kernel + sync_worker_kernel(gpu, last, &buf, &off, &buf, &get_grads, // + 0, 1, stream, CHUNK); + CUDA_CHECK(cudaMemcpyAsync(tmp, buf, len, get, stream)); + cudaStreamSynchronize(stream); + for (size_t i = 0; i < CHUNK; ++i) + cpu[off + i] += tmp[i]; + if (++index == chunks_) { + index = 0; + cycles_++; + } + calls_++; + } + + CaffeFreeHost((void*) tmp); + CUDA_CHECK(cudaFree((void* ) buf)); +} + +#endif + +// + +template +DistSync::DistSync(uint32_t nodes, uint32_t chunks) + : nodes_(nodes), + chunks_(chunks), + received_(chunks), + remaining_(chunks), + cycles_("cycles") { + + own_start_ = own_until_ = chunk_ = 0; +} + +template +void DistSync::dist_init(int rank) { + own_start_ = (rank + 0) * chunks_ / nodes_; + own_until_ = (rank + 1) * chunks_ / nodes_; + LOG(INFO)<< "range: " << own_start_ << " " << own_until_; + chunk_ = own_start_; + + for (uint32_t chunk = own_start_; chunk < own_until_; ++chunk) { + received_[chunk] = true; + remaining_--; + } +} + +template +inline int DistSync::chunk_master(uint32_t chunk) { + // TODO find range without loop? 
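On the TODO above: because node i owns the chunk range [i * chunks_ / nodes_, (i + 1) * chunks_ / nodes_), the owner of a chunk can be computed in closed form as ((chunk + 1) * nodes_ - 1) / chunks_ in integer arithmetic. A small self-checking sketch (illustrative only; the linear scan kept below is also fine in practice since nodes_ is small):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Closed-form owner lookup: node i owns [i*chunks/nodes, (i+1)*chunks/nodes),
// so the owner of `chunk` is the largest i with i*chunks/nodes <= chunk,
// which is ((chunk + 1) * nodes - 1) / chunks.
static int chunk_master_no_loop(uint64_t chunk, uint64_t nodes, uint64_t chunks) {
  return static_cast<int>(((chunk + 1) * nodes - 1) / chunks);
}

int main() {
  // Cross-check against the loop-based definition for a few shapes.
  for (uint64_t nodes = 1; nodes <= 7; ++nodes) {
    for (uint64_t chunks = nodes; chunks <= 1000; chunks += 13) {
      for (uint64_t chunk = 0; chunk < chunks; ++chunk) {
        int expect = -1;
        for (int i = static_cast<int>(nodes) - 1; i >= 0; --i) {
          if (i * chunks / nodes <= chunk) { expect = i; break; }
        }
        assert(chunk_master_no_loop(chunk, nodes, chunks) == expect);
      }
    }
  }
  std::printf("closed-form owner matches the loop for all tested shapes\n");
  return 0;
}
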
+ for (int i = nodes_ - 1; i >= 0; --i) { + uint32_t start = i * chunks_ / nodes_; + if (start <= chunk) + return i; + } + CHECK(false); + return -1; +} + +// + +INSTANTIATE_CLASS(Params); +#ifndef CPU_ONLY +INSTANTIATE_CLASS(GPUParams); +INSTANTIATE_CLASS(GPUSync); +INSTANTIATE_CLASS(CPUGPUSync); +#endif +INSTANTIATE_CLASS(DistSync); + +#ifdef RDMA + +ibv_context* IBChannel::open_device(ibv_device* ib_dev) { + ibv_context* context = ibv_open_device(ib_dev); + CHECK(context) << "Open context failed for " << ibv_get_device_name(ib_dev); + return context; +} + +ibv_pd* IBChannel::alloc_pd(ibv_context* context) { + ibv_pd* pd = ibv_alloc_pd(context); + CHECK(pd) << "Failed to allocate protection domain"; + return pd; +} + +IBChannel::IBChannel(ibv_device* ib_dev) + : context_(open_device(ib_dev)), + pd_(alloc_pd(context_)), + buf_send_(), + buf_recv_(), + mr_send_(), + mr_recv_(), + send_queue_(FRAMES), + recv_queue_(FRAMES), + sent_("sent", MTU), + recv_("recv", MTU) { + + cq_ = ibv_create_cq(context_, FRAMES * 2, NULL, NULL, 0); + CHECK(cq_) << "Failed to create completion queue"; + + // Create queue pair + { + ibv_qp_init_attr attr; + memset(&attr, 0, sizeof attr); + attr.send_cq = cq_; + attr.recv_cq = cq_; + attr.cap.max_send_wr = FRAMES; + attr.cap.max_recv_wr = FRAMES; + attr.cap.max_send_sge = 1; + attr.cap.max_recv_sge = 1; + attr.qp_type = IBV_QPT_UD, + + qp_ = ibv_create_qp(pd_, &attr); + CHECK(qp_) << "Failed to create queue pair"; + } + + // Init queue pair + { + ibv_qp_attr attr; + memset(&attr, 0, sizeof attr); + attr.qp_state = IBV_QPS_INIT; + attr.pkey_index = 0; + attr.port_num = PORT; + attr.qkey = 0x11111111; + + int mask = IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_QKEY; + CHECK(!ibv_modify_qp(qp_, &attr, mask)) << "Failed to set QP to INIT"; + } + + // Local address + { + memset(&local_, 0, sizeof(local_)); + ibv_port_attr attr; + CHECK(!ibv_query_port(context_, PORT, &attr)) << "Query port"; + CHECK(attr.active_mtu == IBV_MTU_4096); + local_.lid = attr.lid; + local_.qpn = qp_->qp_num; + local_.psn = caffe_rng_rand() & 0xffffff; + } + + // Queue pair to recv & send + { + struct ibv_qp_attr attr; + attr.qp_state = IBV_QPS_RTR; // Ready to receive + CHECK(!ibv_modify_qp(qp_, &attr, IBV_QP_STATE)) << "QP to RTR"; + attr.qp_state = IBV_QPS_RTS; // Ready to send + attr.sq_psn = local_.psn; + int mask = IBV_QP_STATE | IBV_QP_SQ_PSN; + CHECK(!ibv_modify_qp(qp_, &attr, mask)) << "QP to RTS"; + } + + for (int i = 0; i < 2 * FRAMES; ++i) + wc_[i].wr_id = i; +} + +static ib_addr mcast_init(ibv_context* context, int port, const ibv_gid* mgid) { + mcast_parameters params; + memset(¶ms, 0, sizeof(struct mcast_parameters)); + + string ib_devname(ibv_get_device_name(context->device)); + params.ib_devname = const_cast(ib_devname.c_str()); + CHECK(!ibv_query_gid(context, port, 0, ¶ms.port_gid)); + CHECK(!ibv_query_pkey(context, port, DEF_PKEY_IDX, ¶ms.pkey)); + + ibv_port_attr port_attr; + CHECK(!ibv_query_port(context, port, &port_attr)); + params.sm_lid = port_attr.sm_lid; + params.sm_sl = port_attr.sm_sl; + params.ib_port = port; + + if (mgid) + memcpy(¶ms.mgid.raw, &mgid->raw, 16); + + CHECK(!join_multicast_group(SUBN_ADM_METHOD_SET, ¶ms)) + << "Failed to create multicast group"; + + ib_addr addr; + memcpy(&addr.gid.raw, ¶ms.mgid.raw, 16); + addr.lid = params.mlid; + addr.qpn = QPNUM_MCAST; + return addr; +} + +ib_addr IBChannel::mcast_create() const { + ib_addr addr = mcast_init(context_, PORT, NULL); + addr.psn = caffe_rng_rand() & 0xffffff; + return addr; +} + +void 
IBChannel::mcast_join(const ib_addr& addr) const { + mcast_init(context_, PORT, &addr.gid); +} + +void IBChannel::mcast_attach_qp(const ib_addr& addr) const { + CHECK(!ibv_attach_mcast(qp_, &addr.gid, addr.lid)) + << "Failed to attach to the multicast group"; +} + +void IBChannel::start(uint8_t* buf_send, size_t buf_size, bool gpu) const { + size_t send_size = buf_send ? buf_size : FRAMES * MTU; + size_t recv_size = FRAMES * (GRH + MTU); + + if (gpu) { + if (buf_send) { + buf_send_ = buf_send; + } else { + CUDA_CHECK(cudaMalloc((void** ) &buf_send_, send_size)); + } + CUDA_CHECK(cudaMalloc((void** ) &buf_recv_, recv_size)); + } else { + buf_send_ = buf_send ? buf_send : (uint8_t*) malloc(send_size); + buf_recv_ = (uint8_t*) malloc(recv_size); + } + + LOG(INFO)<< "range: " << hex << (uint64_t) buf_send_ << " " << (uint64_t) send_size; + LOG(INFO)<< "range: " << hex << (uint64_t) buf_recv_ << " " << (uint64_t) recv_size; + + mr_send_ = ibv_reg_mr(pd_, buf_send_, send_size, IBV_ACCESS_LOCAL_WRITE); + mr_recv_ = ibv_reg_mr(pd_, buf_recv_, recv_size, IBV_ACCESS_LOCAL_WRITE); + CHECK(mr_send_ && mr_recv_) << "Failed to register memory regions"; + + // Create initial requests, start the recv ones + for (int i = 0; i < FRAMES; ++i) { + send_queue_[i] = i; + recv_done(i + FRAMES); + } + recv_queue_.clear(); +} + +IBChannel::~IBChannel() { + CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP"; + CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ"; + CHECK(!ibv_dereg_mr(mr_send_)) << "Failed to deregister MR"; + CHECK(!ibv_dereg_mr(mr_recv_)) << "Failed to deregister MR"; + CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD"; + CHECK(!ibv_close_device(context_)) << "Failed to release context"; + free(buf_send_); + free(buf_recv_); +} + +bool IBChannel::can_send() const { + return !send_queue_.empty(); +} + +int IBChannel::send_init(uint8_t*& buf) const { + int id = send_queue_.front(); + send_queue_.pop_front(); + buf = buf_send_ + id * MTU; + return id; +} + +void IBChannel::send(int id, const ib_addr& addr, uint8_t* buf, + uint32_t imm_data) const { + struct ibv_sge list; + struct ibv_send_wr wr; + struct ibv_send_wr *bad_wr; + + list.addr = (uintptr_t) buf; + list.length = MTU; + list.lkey = mr_send_->lkey; + + wr.wr_id = id; + wr.next = NULL; + wr.sg_list = &list; + wr.num_sge = 1; + wr.opcode = IBV_WR_SEND_WITH_IMM; + wr.send_flags = IBV_SEND_SIGNALED; + wr.imm_data = imm_data; + wr.wr.ud.ah = addr.ah; + wr.wr.ud.remote_qpn = addr.qpn; + wr.wr.ud.remote_qkey = 0x11111111; + + CHECK(!ibv_post_send(qp_, &wr, &bad_wr)) << "Failed send"; +} + +bool IBChannel::can_recv() const { + return !recv_queue_.empty(); +} + +int IBChannel::recv(uint8_t*& buf, uint32_t& imm_data) const { + recv_msg& msg = recv_queue_.front(); + int id = msg.id_; + buf = buf_recv_ + (id - FRAMES) * (GRH + MTU) + GRH; + imm_data = msg.imm_; + recv_queue_.pop_front(); + return id; +} + +void IBChannel::recv_done(int id) const { + struct ibv_sge list; + struct ibv_recv_wr wr; + struct ibv_recv_wr* bad_wr; + + list.addr = (uintptr_t) (buf_recv_ + (id - FRAMES) * (GRH + MTU)); + list.length = GRH + MTU; + list.lkey = mr_recv_->lkey; + + wr.wr_id = id; + wr.next = NULL; + wr.sg_list = &list; + wr.num_sge = 1; + + CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed receive"; +} + +void IBChannel::poll() const { + int ne = ibv_poll_cq(cq_, FRAMES * 2, wc_); + CHECK(ne >= 0) << "Poll CQ failed"; + + for (int i = 0; i < ne; ++i) { + CHECK(wc_[i].status == IBV_WC_SUCCESS) << "Failed status \n" + << ibv_wc_status_str(wc_[i].status) 
+ << " " << wc_[i].status << " " + << (int) wc_[i].wr_id << " " + << wc_[i].vendor_err; + + if (wc_[i].wr_id < IBChannel::FRAMES) { + sent_++; + send_queue_.push_back(wc_[i].wr_id); + } else { + recv_++; + CHECK(wc_[i].byte_len == GRH + MTU); + recv_msg msg; + msg.id_ = wc_[i].wr_id; + msg.imm_ = wc_[i].imm_data; + recv_queue_.push_back(msg); + } + } +} + +// + +template +IBSync::IBSync(const Params& params, int rank, + const IBChannel& ucast, const IBChannel& mcast, + const vector& ucast_addrs, + const vector& mcast_addrs) + : DistSync(ucast_addrs.size(), chunks(params.len_used())), + rank_(rank), + ucast_(ucast), + mcast_(mcast), + ucast_addrs_(ucast_addrs), + mcast_addr_(mcast_addrs[rank]) { + + for (int i = 0; i < ucast_addrs_.size(); ++i) { + CHECK(ucast_addrs_[i].ah == NULL); + if (i != rank) { + struct ibv_ah_attr ah_attr; + memset(&ah_attr, 0, sizeof ah_attr); + ah_attr.dlid = (uint16_t) ucast_addrs[i].lid; + ah_attr.sl = (uint8_t) 0; // Service level + ah_attr.src_path_bits = 0; + ah_attr.is_global = 0; + ah_attr.port_num = IBChannel::PORT; + ucast_addrs_[i].ah = ibv_create_ah(ucast.pd_, &ah_attr); + CHECK(ucast_addrs_[i].ah) << "Failed to create address handle"; + } + } + + struct ibv_ah_attr ah_attr; + memset(&ah_attr, 0, sizeof ah_attr); + ah_attr.grh.dgid = mcast_addr_.gid; + ah_attr.dlid = (uint16_t) mcast_addr_.lid; + ah_attr.sl = (uint8_t) 0; // Service level + ah_attr.src_path_bits = 0; + ah_attr.is_global = 1; + ah_attr.port_num = IBChannel::PORT; + mcast_addr_.ah = ibv_create_ah(mcast.pd_, &ah_attr); + CHECK(mcast_addr_.ah) << "Failed to create address handle"; + + for (int i = 0; i < mcast_addrs.size(); ++i) { + if (i != rank) { + mcast_.mcast_join(mcast_addrs[i]); + mcast_.mcast_attach_qp(mcast_addrs[i]); + } + } + + this->dist_init(rank); +} + +template +IBSync::~IBSync() { + for (int i = 0; i < this->ucast_addrs_.size(); ++i) { + if (i == rank_) { + CHECK(!ibv_destroy_ah(this->ucast_addrs_[i].ah)) + << "Failed to destroy ucast AH"; + } + } + CHECK(!ibv_destroy_ah(this->mcast_addr_.ah)) << "Failed to destroy mcast AH"; +} + +// + +template +CPUIBSync::CPUIBSync(const Params& params, int rank, + const IBChannel& ucast, const IBChannel& mcast, + const vector& ucast_addrs, + const vector& mcast_addrs) + : IBSync(params, rank, ucast, mcast, ucast_addrs, mcast_addrs) { + + cpu_ = params.cpu(); + CaffeMallocHost((void**) &cpu_last_, params.len_buff() * sizeof(Dtype)); + memcpy(cpu_last_, cpu_, params.len_used() * sizeof(Dtype)); +} + +template +CPUIBSync::~CPUIBSync() { + CaffeFreeHost((void*) cpu_last_); +} + +template +void CPUIBSync::run() { + // TODO +} + +// + +template +GPUIBSync::GPUIBSync(const GPUParams& params, int rank, + const IBChannel& ucast, const IBChannel& mcast, + const vector& ucast_addrs, + const vector& mcast_addrs) + : GPUSync(params), + IBSync(params.params(), rank, // + ucast, mcast, // + ucast_addrs, mcast_addrs) { + + gpu_ = params.gpu(); + int device; + CUDA_CHECK(cudaGetDevice(&device)); + CUDA_CHECK(cudaSetDevice(params.device())); + size_t size = params.params().len_buff() * sizeof(Dtype); + CUDA_CHECK(cudaMalloc((void** ) &gpu_last_, size)); + CUDA_CHECK(cudaMemcpy(gpu_last_, gpu_, size, cudaMemcpyDeviceToDevice)); + CUDA_CHECK(cudaSetDevice(device)); +} + +template +GPUIBSync::~GPUIBSync() { + CUDA_CHECK(cudaFree((void* ) gpu_last_)); +} + +class Queue { + public: + Queue() + : front_(), + back_(), + size_() { + } + void push() { + CHECK(size_ < IBChannel::FRAMES); + back_ = (back_ + 1) & (IBChannel::FRAMES - 1); + size_++; + } + void pop() { 
+ CHECK(size_ > 0); + front_ = (front_ + 1) & (IBChannel::FRAMES - 1); + size_--; + } + + int front_; + int back_; + int size_; +}; + +class EventQueue : Queue { + public: + EventQueue(const cudaStream_t& stream) + : stream_(stream) { + + for (int i = 0; i < IBChannel::FRAMES; ++i) + cudaEventCreateWithFlags(&items_[i].event_, cudaEventDisableTiming); + } + + ~EventQueue() { + for (int i = 0; i < IBChannel::FRAMES; ++i) + cudaEventDestroy(items_[i].event_); + } + + void record(int tag) { + cudaEventRecord(items_[back_].event_, this->stream_); + items_[back_].tag_ = tag; + push(); + } + + bool query(int& tag) { + if (size_ && cudaEventQuery(items_[front_].event_) == cudaSuccess) { + tag = items_[front_].tag_; + pop(); + return true; + } + return false; + } + + protected: + const cudaStream_t& stream_; + struct item { + cudaEvent_t event_; + int tag_; + }; + item items_[IBChannel::FRAMES]; +}; + +template +void GPUIBSync::run() { + CUDA_CHECK(cudaSetDevice(this->params_.device())); + + const IBChannel& ucast = this->ucast_; + const IBChannel& mcast = this->mcast_; + ucast.start(NULL, 0, true); + mcast.start((uint8_t*) gpu_, (size_t) this->chunks_ * IBChannel::MTU, true); + + GPUStream master_stream; + Queue master_queue; + uint16_t master_ids_[FRAMES]; + EventQueue master_events(master_stream.stream()); + + GPUStream worker_stream; + Queue worker_queue; + struct worker_item { + int recv_id, send_id; + Dtype* grd; + uint32_t chunk; + }; + worker_item worker_items[FRAMES]; + EventQueue worker_events(worker_stream.stream()); + + const size_t real_size = FRAMES * sizeof(Dtype*); + const size_t size_size = FRAMES * sizeof(size_t); + const size_t bool_size = FRAMES * sizeof(size_t); + + Dtype** master_gpu_grds; + size_t* master_gpu_offs; + CUDA_CHECK(cudaMalloc((void** ) &master_gpu_grds, real_size)); + CUDA_CHECK(cudaMalloc((void** ) &master_gpu_offs, size_size)); + Dtype** master_cpu_grds; + size_t* master_cpu_offs; + CUDA_CHECK(cudaMallocHost((void** ) &master_cpu_grds, real_size)); + CUDA_CHECK(cudaMallocHost((void** ) &master_cpu_offs, size_size)); + + Dtype** worker_gpu_pos; + size_t* worker_gpu_offs; + Dtype** worker_gpu_grds; + uint8_t* worker_gpu_gets; + CUDA_CHECK(cudaMalloc((void** ) &worker_gpu_pos, real_size)); + CUDA_CHECK(cudaMalloc((void** ) &worker_gpu_offs, size_size)); + CUDA_CHECK(cudaMalloc((void** ) &worker_gpu_grds, real_size)); + CUDA_CHECK(cudaMalloc((void** ) &worker_gpu_gets, bool_size)); + Dtype** worker_cpu_pos; + size_t* worker_cpu_offs; + Dtype** worker_cpu_grds; + uint8_t* worker_cpu_gets; + CUDA_CHECK(cudaMallocHost((void** ) &worker_cpu_pos, real_size)); + CUDA_CHECK(cudaMallocHost((void** ) &worker_cpu_offs, size_size)); + CUDA_CHECK(cudaMallocHost((void** ) &worker_cpu_grds, real_size)); + CUDA_CHECK(cudaMallocHost((void** ) &worker_cpu_gets, bool_size)); + + int master_batch_start = 0; + int master_batch_count = 0; + int worker_batch_start = 0; + int worker_batch_count = 0; + const int batch = 128; // TODO bench + while (!this->must_stop()) { + ucast.poll(); + mcast.poll(); + + // Receive gradients for chunks for which we are master + { + while (ucast.can_recv()) { + uint8_t* buf; + uint32_t chunk; + int id = ucast.recv(buf, chunk); + Dtype* grd = (Dtype*) buf; + CHECK(this->chunk_master(chunk) == this->rank_); + size_t off = ((size_t) chunk) * IBSync::CHUNK; + + int index = master_queue.back_; + master_ids_[index] = id; + master_cpu_grds[index] = grd; + master_cpu_offs[index] = off; + master_queue.push(); + master_batch_count++; + } + // Add gradients to our 
weights + if (master_batch_count >= batch) { + CUDA_CHECK( + cudaMemcpyAsync(master_gpu_grds, master_cpu_grds, real_size, + cudaMemcpyHostToDevice, master_stream.stream())); + CUDA_CHECK( + cudaMemcpyAsync(master_gpu_offs, master_cpu_offs, size_size, + cudaMemcpyHostToDevice, master_stream.stream())); + sync_master_kernel(gpu_, master_gpu_grds, master_gpu_offs, + master_batch_start, master_batch_count, + master_stream.stream(), IBSync::CHUNK); + master_events.record(master_batch_count); + master_batch_start = master_queue.back_; + master_batch_count = 0; + } + } + // Start receiving again once kernels are done with buffers + for (;;) { + int batch; + if (!master_events.query(batch)) { + break; + } + for (int i = 0; i < batch; ++i) { + int index = master_queue.front_; + master_queue.pop(); + ucast.recv_done(master_ids_[index]); + } + } + // Send absolute positions for chunks for which we are master + while (mcast.can_send()) { + uint8_t* buf; + int id = mcast.send_init(buf); // buf ignored + size_t off = (size_t) this->chunk_ * IBSync::CHUNK; + buf = (uint8_t*) (gpu_ + off); + CHECK(id >= 0 && id < FRAMES); + mcast.send(id, this->mcast_addr_, buf, this->chunk_); + if (++this->chunk_ == this->own_until_) { + this->chunk_ = this->own_start_; + this->cycles_++; + } + } + + // Receive absolute positions for other chunks + { + while (mcast.can_recv()) { + Dtype* pos; + uint32_t chunk; + int recv_id, send_id; + size_t off; + { + uint8_t* buf; + recv_id = mcast.recv(buf, chunk); + pos = (Dtype*) buf; + off = ((size_t) chunk) * IBSync::CHUNK; + } + + // Send back the gradients if frame is available + Dtype* grd = NULL; + if (ucast.can_send()) { + uint8_t* buf; + send_id = ucast.send_init(buf); + grd = (Dtype*) buf; + } + + int index = worker_queue.back_; + worker_items[index].recv_id = recv_id; + worker_items[index].send_id = send_id; + worker_items[index].grd = grd; + worker_items[index].chunk = chunk; + worker_cpu_pos[index] = pos; + worker_cpu_offs[index] = off; + worker_cpu_grds[index] = grd; + worker_cpu_gets[index] = grd != NULL; + worker_queue.push(); + worker_batch_count++; + } + if (worker_batch_count >= batch) { + CUDA_CHECK( + cudaMemcpyAsync(worker_gpu_pos, worker_cpu_pos, real_size, + cudaMemcpyHostToDevice, worker_stream.stream())); + CUDA_CHECK( + cudaMemcpyAsync(worker_gpu_offs, worker_cpu_offs, size_size, + cudaMemcpyHostToDevice, worker_stream.stream())); + CUDA_CHECK( + cudaMemcpyAsync(worker_gpu_grds, worker_cpu_grds, real_size, + cudaMemcpyHostToDevice, worker_stream.stream())); + CUDA_CHECK( + cudaMemcpyAsync(worker_gpu_gets, worker_cpu_gets, bool_size, + cudaMemcpyHostToDevice, worker_stream.stream())); + sync_worker_kernel(gpu_, gpu_last_, worker_gpu_pos, + worker_gpu_offs, worker_gpu_grds, + worker_gpu_gets, worker_batch_start, + worker_batch_count, worker_stream.stream(), + IBSync::CHUNK); + worker_events.record(worker_batch_count); + worker_batch_start = worker_queue.back_; + worker_batch_count = 0; + } + } + for (;;) { + int batch; + if (!worker_events.query(batch)) { + break; + } + for (int i = 0; i < batch; ++i) { + int index = worker_queue.front_; + worker_queue.pop(); + int recv_id = worker_items[index].recv_id; + int send_id = worker_items[index].send_id; + Dtype* grd = worker_items[index].grd; + uint32_t chunk = worker_items[index].chunk; + + mcast.recv_done(recv_id); + if (grd) { + int master = this->chunk_master(chunk); + CHECK(master != this->rank_); + ib_addr& a = this->ucast_addrs_[master]; + ucast.send(send_id, a, (uint8_t*) grd, chunk); + } + if 
(this->remaining_ > 0 && !this->received_[chunk]) { + this->received_[chunk] = true; + this->remaining_--; + } + } + } + } +} + +INSTANTIATE_CLASS(IBSync); +INSTANTIATE_CLASS(CPUIBSync); +INSTANTIATE_CLASS(GPUIBSync); + +#endif + +#ifdef __linux__ + +// Parse MAC address to byte array +// TODO remove optional ':' chars +static uint8_t* parse_mac(const char* str) { + uint8_t* bytes = (uint8_t*) malloc(ETH_ALEN); + for (int i = 0; i < ETH_ALEN; ++i) { + int value; + sscanf(str + 2 * i, "%02x", &value); + bytes[i] = value; + } + return bytes; +} + +static vector parse_macs(const vector& macs) { + vector res; + for (int i = 0; i < macs.size(); ++i) + res.push_back(parse_mac(macs[i].c_str())); + return res; +} + +// Adapter name from MAC address +static string adapter(const uint8_t* mac) { + int s = socket(AF_INET, SOCK_DGRAM, IPPROTO_IP); + CHECK(s != -1); + + // Iterate over adapters + struct ifconf ifc; + char buf[1024]; + ifc.ifc_len = sizeof(buf); + ifc.ifc_buf = buf; + CHECK(ioctl(s, SIOCGIFCONF, &ifc) != -1); + struct ifreq* it = ifc.ifc_req; + const struct ifreq* const end = it + (ifc.ifc_len / sizeof(struct ifreq)); + + // Look for a MAC match + struct ifreq ifr; + for (; it != end; ++it) { + strcpy(ifr.ifr_name, it->ifr_name); + CHECK(!ioctl(s, SIOCGIFHWADDR, &ifr)); + if (!memcmp(mac, ifr.ifr_hwaddr.sa_data, ETH_ALEN)) + return string(it->ifr_name); + } + return ""; +} + +static int local(const vector& macs) { + for (int i = 0; i < macs.size(); ++i) { + string a = adapter(macs[i]); + if (!a.empty()) + return i; + } + CHECK(0) << "Local machine not part of given MAC addresses."; + return -1; +} + +// + +Ring::Ring(const string& adapter, int protocol_send, int protocol_recv) + : adapter_(adapter), // + socket_(socket(PF_PACKET, SOCK_RAW, htons(protocol_recv))), // + sent_("sent", ETH_FRAME_LEN), + recv_("recv", ETH_FRAME_LEN) { + + const int s = socket_; + CHECK(s != -1) << "Cannot open raw socket, make sure to run as root or to " + << "set the capability on the executable: " + << "sudo setcap cap_net_raw+ep " << endl; + + // TODO look at this + // s_ifr.ifr_mtu = c_mtu; + // /* update the mtu through ioctl */ + // ec = ioctl(fd_socket, SIOCSIFMTU, &s_ifr); + // if(ec == -1) + // { + // perror("iotcl"); + // return EXIT_FAILURE; + // } + + // Get adapter info + struct ifreq ifr; + strcpy(ifr.ifr_name, adapter.c_str()); + CHECK(ioctl(s, SIOCGIFINDEX, &ifr) != -1); + int index = ifr.ifr_ifindex; + CHECK(ioctl(s, SIOCGIFHWADDR, &ifr) != -1); + uint8_t* mac = (uint8_t*) ifr.ifr_hwaddr.sa_data; + + // Bind to interface + struct sockaddr_ll addr; + memset(&addr, 0, sizeof(struct sockaddr_ll)); + addr.sll_family = AF_PACKET; + addr.sll_protocol = htons(protocol_recv); + addr.sll_ifindex = index; + CHECK(bind(s, (struct sockaddr* ) &addr, sizeof(struct sockaddr_ll)) != -1); + + // Setup ring buffer + struct tpacket_req req; + req.tp_frame_size = FRAME_SIZE; + req.tp_frame_nr = FRAME_NR; + req.tp_block_size = FRAME_SIZE * FRAME_NR; + req.tp_block_nr = BLOCK_NR; + CHECK(setsockopt(s, SOL_PACKET, PACKET_RX_RING, &req, sizeof(req)) >= 0); + CHECK(setsockopt(s, SOL_PACKET, PACKET_TX_RING, &req, sizeof(req)) >= 0); + uint32_t size = req.tp_block_size * req.tp_block_nr; + int prot = PROT_READ | PROT_WRITE; + map_recv_ = (uint8_t*) mmap(0, 2 * size, prot, MAP_SHARED, s, 0); + map_send_ = map_recv_ + size; + CHECK(map_recv_ != (void* ) -1); + + // Pre-fill send frames with sender address and protocol + const __be16 protocol = htons(protocol_send); + for (int i = 0; i < FRAME_NR; i++) { + struct 
tpacket_hdr* hdr; + hdr = (struct tpacket_hdr*) (map_send_ + FRAME_SIZE * i); + hdr->tp_len = ETH_FRAME_LEN; + uint8_t* eth = ((uint8_t*) hdr) + TPACKET_ALIGN(sizeof(struct tpacket_hdr)); + memcpy(eth + ETH_ALEN, mac, ETH_ALEN); + memcpy(eth + ETH_ALEN * 2, &protocol, 2); + } +} + +Ring::~Ring() { + shutdown(socket_, 2); +} + +bool Ring::can_send(int frame, struct tpacket_hdr*& hdr) { + hdr = (struct tpacket_hdr*) (map_send_ + FRAME_SIZE * frame); + int status = (volatile uint32_t) hdr->tp_status; + CHECK(!(status & TP_STATUS_WRONG_FORMAT)); + return status == TP_STATUS_AVAILABLE; +} + +ethhdr* Ring::send_init(const struct tpacket_hdr* hdr) { + uint8_t* eth = ((uint8_t*) hdr) + TPACKET_ALIGN(sizeof(struct tpacket_hdr)); + return (struct ethhdr*) eth; +} + +void Ring::send(struct tpacket_hdr* hdr) { + hdr->tp_status = TP_STATUS_SEND_REQUEST; + sent_++; +} + +bool Ring::can_recv(int frame, struct tpacket_hdr*& hdr) { + hdr = (struct tpacket_hdr*) (map_recv_ + FRAME_SIZE * frame); + int status = (volatile uint32_t) hdr->tp_status; + CHECK(!(status & TP_STATUS_COPY)); + return status & TP_STATUS_USER; +} + +ethhdr* Ring::recv(const struct tpacket_hdr* hdr) { + return (struct ethhdr*) ((uint8_t*) hdr + hdr->tp_mac); +} + +void Ring::recv_done(struct tpacket_hdr* hdr) { + hdr->tp_status = TP_STATUS_KERNEL; + recv_++; +} + +void Ring::socket_stats(uint64_t& received, uint64_t& dropped) { + struct tpacket_stats st; + unsigned int len = sizeof(st); + int s = socket_; + CHECK(!getsockopt(s, SOL_PACKET, PACKET_STATISTICS, (char* ) &st, &len)); + received = st.tp_packets; + dropped = st.tp_drops; +} + +// + +template +RawSync::RawSync(const Params& params, + const vector& mac_addresses, + const vector& secondary_macs) + : DistSync(mac_addresses.size(), chunks(params.len_used())), + masters_(parse_macs(mac_addresses)), + workers_( + secondary_macs.size() ? + parse_macs(secondary_macs) : parse_macs(mac_addresses)), + others_(), + master_(adapter(this->masters_[local(this->masters_)]), 0x73A, 0x73B), + worker_(adapter(this->workers_[local(this->workers_)]), 0x73B, 0x73A) { + + int rank = local(this->masters_); + ostringstream s; + s << "Raw socket - node: " << rank << ", "; + if (secondary_macs.size()) { + CHECK(master_.adapter() != worker_.adapter()); + CHECK(rank == local(this->workers_)); + s << "adapters: " << master_.adapter() << ", " << worker_.adapter() << endl; + } else { + CHECK(master_.adapter() == worker_.adapter()); + s << "adapter: " << master_.adapter() << endl; + } + LOG(INFO)<< s.str(); + + cpu_ = params.cpu(); + CaffeMallocHost((void**) &cpu_last_, params.len_buff() * sizeof(Dtype)); + memcpy(cpu_last_, cpu_, params.len_used() * sizeof(Dtype)); + + for (int i = 0; i < workers_.size(); ++i) + if (i != rank) + others_.push_back(workers_[i]); + + this->dist_init(rank); +} + +template +RawSync::~RawSync() { + CaffeFreeHost((void*) cpu_last_); +} + +template +inline void RawSync::next() { + if (++other_ == others_.size()) { + other_ = 0; + if (++this->chunk_ == this->own_until_) { + this->chunk_ = this->own_start_; + this->cycles_++; + } + } +} + +template +void RawSync::run() { + struct tpacket_hdr* hdr; + struct tpacket_hdr* hdr_send; + // TODO split over two threads? compact wire format? 
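
The loop that follows interleaves both roles on every node: as master it adds incoming gradient frames into its own chunks of the weights and sends absolute positions back out to each worker in turn; as worker it receives positions for foreign chunks and, when a send frame happens to be free, replies with the delta accumulated locally since the previous sync. The per-element reconciliation rule is the core of the protocol; below is a minimal standalone sketch of it, not part of the patch, with illustrative names (weights, last_seen, master_pos and grad_out stand in for cpu_, cpu_last_, the received position payload and the optional reply payload).

#include <cstddef>

// Sketch only: the worker-side update applied to one received chunk.
static void worker_chunk_update(float* weights, float* last_seen,
                                const float* master_pos, float* grad_out,
                                size_t n) {
  for (size_t i = 0; i < n; ++i) {
    float d = weights[i] - last_seen[i];   // local progress since the last sync
    if (grad_out) {                        // a reply frame is free: ship the delta
      grad_out[i] = d;
      last_seen[i] = master_pos[i] + d;    // assume the master adds d on its side
      weights[i] = last_seen[i];
    } else {                               // no frame: carry the delta to the next pass
      last_seen[i] = master_pos[i];
      weights[i] = last_seen[i] + d;
    }
  }
}
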
+ for (;;) { + // Receive and add gradients for chunks for which we are master + for (int f = 0; f < Ring::FRAME_NR; f++) { + if (master_.can_recv(f, hdr)) { + ethhdr* eth = master_.recv(hdr); + uint8_t* data = (uint8_t*) eth + ETH_HLEN; + uint32_t chunk = ((uint32_t*) &(data[MSG_CHUNK]))[0]; + size_t off = ((size_t) chunk) * CHUNK; + Dtype* grads = (Dtype*) &(data[MSG_DATA]); + for (size_t i = 0; i < CHUNK; ++i) + this->cpu_[off + i] += grads[i]; + master_.recv_done(hdr); + } + } + + // Send absolute positions for chunks for which we are master + // TODO allow broadcast addresses on private networks instead of + // iterating over workers + for (int f = 0; f < Ring::FRAME_NR; f++) { + if (master_.can_send(f, hdr)) { + uint32_t peer = this->other_; + uint32_t chnk = this->chunk_; + ethhdr* eth = master_.send_init(hdr); + memcpy(eth->h_dest, (void*) this->others_[peer], ETH_ALEN); + uint8_t* data = (uint8_t*) eth + ETH_HLEN; + ((uint32_t*) &(data[MSG_CHUNK]))[0] = chnk; + Dtype* pos = (Dtype*) &(data[MSG_DATA]); + size_t off = (size_t) chnk * CHUNK; + memcpy(pos, this->cpu_ + off, CHUNK * sizeof(Dtype)); + master_.send(hdr); + this->next(); + } + } + send(master_.sock(), NULL, 0, MSG_DONTWAIT); + + // Receive absolute positions for other chunks + for (int f = 0; f < Ring::FRAME_NR; f++) { + if (worker_.can_recv(f, hdr)) { + ethhdr* eth = worker_.recv(hdr); + uint8_t* data = (uint8_t*) eth + ETH_HLEN; + uint32_t chunk = ((uint32_t*) &(data[MSG_CHUNK]))[0]; + size_t off = ((size_t) chunk) * CHUNK; + Dtype* pos = (Dtype*) &(data[MSG_DATA]); + + // Send back the gradients if frame is available + Dtype* grads = NULL; + if (worker_.can_send(f, hdr_send)) { + ethhdr* eth_send = worker_.send_init(hdr_send); + uint8_t* m = this->masters_[this->chunk_master(chunk)]; + memcpy(eth_send->h_dest, (void*) m, ETH_ALEN); + uint8_t* data_send = (uint8_t*) eth_send + ETH_HLEN; + ((uint32_t*) &(data_send[MSG_CHUNK]))[0] = chunk; + grads = (Dtype*) &(data_send[MSG_DATA]); + } + + for (size_t i = 0; i < CHUNK; ++i) { + Dtype d = this->cpu_[off + i] - this->cpu_last_[off + i]; + // If gradient is sent, reset last_ to cpu_, otherwise keep them apart + if (grads) { + grads[i] = d; + this->cpu_last_[off + i] = pos[i] + d; + this->cpu_[off + i] = this->cpu_last_[off + i]; + } else { + this->cpu_last_[off + i] = pos[i]; + this->cpu_[off + i] = this->cpu_last_[off + i] + d; + } + } + + worker_.recv_done(hdr); + if (grads) + worker_.send(hdr_send); + + if (this->remaining_ > 0 && !this->received_[chunk]) { + this->received_[chunk] = true; + this->remaining_--; + } + } + } + send(worker_.sock(), NULL, 0, MSG_DONTWAIT); + } +} + +INSTANTIATE_CLASS(RawSync); + +#endif +} diff --git a/src/caffe/parallel.cu b/src/caffe/parallel.cu new file mode 100644 index 00000000000..a88a6769a83 --- /dev/null +++ b/src/caffe/parallel.cu @@ -0,0 +1,84 @@ +#include +#include +#include "caffe/parallel.hpp" + +namespace caffe { + +template +__global__ +void sync_master_kernel(Dtype* gpu, Dtype** grds, size_t* offs, // + int batch_start, int batch_count) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + for (int b = 0; b < batch_count; ++b) { + // Index in queue + int q = (batch_start + b) & (IBChannel::FRAMES - 1); + gpu[offs[q] + i] += grds[q][i]; + } +} + +template +void sync_master_kernel(Dtype* gpu, Dtype** grds, size_t* offs, // + int batch_start, int batch_count, // + const cudaStream_t& stream, size_t chunk) { + int threadsPerBlock = 256; // TODO bench + int blocksPerGrid = chunk / threadsPerBlock; + sync_master_kernel<<>>( + gpu, 
grds, offs, batch_start, batch_count); + CUDA_POST_KERNEL_CHECK; +} + +template void sync_master_kernel(float* gpu, float** grds, size_t* offs, + int batch_start, int batch_count, // + const cudaStream_t& stream, size_t chunk); +template void sync_master_kernel(double* gpu, double** grds, + size_t* offs, // + int batch_start, int batch_count, // + const cudaStream_t& stream, size_t chunk); + +// + +template +__global__ +void sync_worker_kernel(Dtype* gpu, Dtype* last, Dtype** pos, size_t* offs, + Dtype** grads, uint8_t* get_grads, + int batch_start, int batch_count) { + int i = blockDim.x * blockIdx.x + threadIdx.x; + for (int b = 0; b < batch_count; ++b) { + // Index in queue + int q = (batch_start + b) & (IBChannel::FRAMES - 1); + Dtype d = gpu[offs[q] + i] - last[offs[q] + i]; + if(get_grads[q]) { + gpu[offs[q] + i] = last[offs[q] + i] = pos[q][i] + d; + grads[q][i] = d; // Warn: pos and grads can be same, keep assignment last + } else { + last[offs[q] + i] = pos[q][i]; + gpu[offs[q] + i] = pos[q][i] + d; + } + } +} + +template +void sync_worker_kernel(Dtype* gpu, Dtype* last, Dtype** pos, size_t* offs, + Dtype** grads, uint8_t* get_grads, + int batch_start, int batch_count, + const cudaStream_t& stream, size_t chunk) { + int threadsPerBlock = 64; // TODO bench + int blocksPerGrid = chunk / threadsPerBlock; + sync_worker_kernel<<>>( + gpu, last, pos, offs, grads, get_grads, batch_start, batch_count); + CUDA_POST_KERNEL_CHECK; +} + +template void sync_worker_kernel(float* gpu, float* last, float** pos, + size_t* offs, + float** grads, uint8_t* get_grads, + int batch_start, int batch_count, + const cudaStream_t& stream, size_t chunk); + +template void sync_worker_kernel(double* gpu, double* last, + double** pos, size_t* offs, + double** grads, uint8_t* get_grads, + int batch_start, int batch_count, + const cudaStream_t& stream, size_t chunk); + +} diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 8086ad66579..1dfbfc6c742 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -417,13 +417,14 @@ message ConvolutionParameter { } // Message that stores parameters used by DataLayer +// next available ID: 10 (last added: prefetch) message DataParameter { enum DB { LEVELDB = 0; LMDB = 1; } // Specify the data source. - optional string source = 1; + repeated string source = 1; // Specify the batch size. optional uint32 batch_size = 4; // The rand_skip variable is for the data layer to skip a few data points @@ -443,6 +444,8 @@ message DataParameter { // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror // data. optional bool mirror = 6 [default = false]; + // Prefetch queue (Number of batches to prefetch). 
+ optional uint32 prefetch = 9 [default = 4]; } // Message that stores parameters used by DropoutLayer diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 7f5d3d20f35..b5d436b75a9 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -14,30 +14,36 @@ namespace caffe { template -Solver::Solver(const SolverParameter& param) - : net_() { - Init(param); +Solver::Solver(const SolverParameter& param, bool skip_test_nets) + : iter_(), iter_total_(&iter_), net_() { + Init(param, skip_test_nets); } template Solver::Solver(const string& param_file) - : net_() { + : iter_(), iter_total_(&iter_), net_() { SolverParameter param; ReadProtoFromTextFileOrDie(param_file, ¶m); - Init(param); + Init(param, false); } template -void Solver::Init(const SolverParameter& param) { +void Solver::Init(const SolverParameter& param, bool skip_test_nets) { LOG(INFO) << "Initializing solver from parameters: " << std::endl << param.DebugString(); param_ = param; + if (param_.solver_mode() == SolverParameter::GPU && + param_.has_device_id()) { + Caffe::SetDevice(param_.device_id()); + } + Caffe::set_mode(Caffe::Brew(param_.solver_mode())); if (param_.random_seed() >= 0) { Caffe::set_random_seed(param_.random_seed()); } // Scaffolding code InitTrainNet(); - InitTestNets(); + if(!skip_test_nets) + InitTestNets(); LOG(INFO) << "Solver scaffolding done."; } @@ -377,34 +383,35 @@ template Dtype SGDSolver::GetLearningRate() { Dtype rate; const string& lr_policy = this->param_.lr_policy(); + int iter = *(this->iter_total_); if (lr_policy == "fixed") { rate = this->param_.base_lr(); } else if (lr_policy == "step") { - this->current_step_ = this->iter_ / this->param_.stepsize(); + this->current_step_ = iter / this->param_.stepsize(); rate = this->param_.base_lr() * pow(this->param_.gamma(), this->current_step_); } else if (lr_policy == "exp") { - rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_); + rate = this->param_.base_lr() * pow(this->param_.gamma(), iter); } else if (lr_policy == "inv") { rate = this->param_.base_lr() * - pow(Dtype(1) + this->param_.gamma() * this->iter_, + pow(Dtype(1) + this->param_.gamma() * iter, - this->param_.power()); } else if (lr_policy == "multistep") { if (this->current_step_ < this->param_.stepvalue_size() && - this->iter_ >= this->param_.stepvalue(this->current_step_)) { + iter >= this->param_.stepvalue(this->current_step_)) { this->current_step_++; LOG(INFO) << "MultiStep Status: Iteration " << - this->iter_ << ", step = " << this->current_step_; + iter << ", step = " << this->current_step_; } rate = this->param_.base_lr() * pow(this->param_.gamma(), this->current_step_); } else if (lr_policy == "poly") { rate = this->param_.base_lr() * pow(Dtype(1.) - - (Dtype(this->iter_) / Dtype(this->param_.max_iter())), + (Dtype(iter) / Dtype(this->param_.max_iter())), this->param_.power()); } else if (lr_policy == "sigmoid") { rate = this->param_.base_lr() * (Dtype(1.) / - (Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) - + (Dtype(1.) 
+ exp(-this->param_.gamma() * (Dtype(iter) - Dtype(this->param_.stepsize()))))); } else { LOG(FATAL) << "Unknown learning rate policy: " << lr_policy; diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp index 7617ccfb27f..60263f225ee 100644 --- a/src/caffe/syncedmem.cpp +++ b/src/caffe/syncedmem.cpp @@ -12,7 +12,7 @@ SyncedMemory::~SyncedMemory() { } #ifndef CPU_ONLY - if (gpu_ptr_) { + if (gpu_ptr_ && own_gpu_data_) { CUDA_CHECK(cudaFree(gpu_ptr_)); } #endif // CPU_ONLY @@ -51,10 +51,12 @@ inline void SyncedMemory::to_gpu() { CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); caffe_gpu_memset(size_, 0, gpu_ptr_); head_ = HEAD_AT_GPU; + own_gpu_data_ = true; break; case HEAD_AT_CPU: if (gpu_ptr_ == NULL) { CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + own_gpu_data_ = true; } caffe_gpu_memcpy(size_, cpu_ptr_, gpu_ptr_); head_ = SYNCED; @@ -92,6 +94,20 @@ const void* SyncedMemory::gpu_data() { #endif } +void SyncedMemory::set_gpu_data(void* data) { +#ifndef CPU_ONLY + CHECK(data); + if (own_gpu_data_) { + CUDA_CHECK(cudaFree(gpu_ptr_)); + } + gpu_ptr_ = data; + head_ = HEAD_AT_GPU; + own_gpu_data_ = false; +#else + NO_GPU; +#endif +} + void* SyncedMemory::mutable_cpu_data() { to_cpu(); head_ = HEAD_AT_CPU; @@ -108,6 +124,18 @@ void* SyncedMemory::mutable_gpu_data() { #endif } +#ifndef CPU_ONLY +void SyncedMemory::async_gpu_push(const cudaStream_t& stream) { + CHECK(head_ == HEAD_AT_CPU); + if (gpu_ptr_ == NULL) { + CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + own_gpu_data_ = true; + } + CUDA_CHECK(cudaMemcpyAsync(gpu_ptr_, cpu_ptr_, size_, cudaMemcpyHostToDevice, stream)); + // Assume caller will synchronize on the stream before use + head_ = SYNCED; +} +#endif } // namespace caffe diff --git a/src/caffe/test/test_parallel.cpp b/src/caffe/test/test_parallel.cpp new file mode 100644 index 00000000000..8302b00a899 --- /dev/null +++ b/src/caffe/test/test_parallel.cpp @@ -0,0 +1,68 @@ +#include + +#include "gtest/gtest.h" + +#include "caffe/parallel.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class DistSyncTest: public ::testing::Test { +}; + +TYPED_TEST_CASE(DistSyncTest, TestDtypes); + +TYPED_TEST(DistSyncTest, TestNothing) { + // The first test case of a test suite takes the longest time + // due to the set up overhead. 
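
The SyncedMemory changes above (own_gpu_data_, set_gpu_data, async_gpu_push) exist so that the parallel code can point blobs at one externally managed device buffer without SyncedMemory ever freeing it. A hypothetical usage sketch, not taken from the patch: blob is assumed to be a Blob<float>*, the allocation would normally be a slice of a GPUParams buffer, and error handling is omitted.

// Sketch: hand an externally owned allocation to a blob through the new
// SyncedMemory::set_gpu_data(). own_gpu_data_ stays false afterwards, so the
// destructor will not cudaFree() this pointer; its owner must release it.
float* shared;
CUDA_CHECK(cudaMalloc(&shared, blob->count() * sizeof(float)));
blob->data()->set_gpu_data(shared);  // blob now reads and writes the shared buffer
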
+} + +//TYPED_TEST(DistSyncTest, TestMasterIndex) { +// vector > > blobs(); +// blobs.push_back(new shared_ptr(new Blob(1, 1, 1, 1))); +// Params params(blobs); +// +// vector nodes(); +// nodes.push_back(0); +// nodes.push_back(1); +// nodes.push_back(2); +// nodes.push_back(3); +// +// uint32_t chunks = 1000; +// for(int index = 0; index < nodes.size(); ++index) { +// DistSync sync(params, nodes, nodes, chunks); +// sync.Init(index); +// +// for(uint32_t chunk = 0; chunk < chunks; ++chunk) +// EXPECT( +// (sync.master(chunk) == index) +// == +// (chunk >= sync.own_start_ && chunk < sync.own_until_)); +// } +//} + +// test buffers are the same +//bool ready = true; +//for (int i = 0; i < solvers.size(); ++i) +// if (!solvers[i]) +// ready = false; +//if (ready) { +// for (int i = 0; i < solvers.size(); ++i) { +// shared_ptr > n0 = solvers[0]->net(); +// shared_ptr > ni = solvers[i]->net(); +// vector > >& p0 = n0->params(); +// vector > >& pi = ni->params(); +// for (int j = 0; j < p0.size(); ++j) +// CHECK(pi[j]->cpu_data() == p0[j]->cpu_data()); +// } +// shared_ptr > n0 = solvers_[0]->net(); +// vector > >& p0 = n0->params(); +// for (int j = 0; j < p0.size(); ++j) +// for (int k = 0; k < p0[j]->count(); ++k) +// CHECK(!isnan(p0[j]->cpu_data()[k])) << " NAN"; +//} + + +} // namespace caffe diff --git a/src/caffe/util/multicast_resources.cpp b/src/caffe/util/multicast_resources.cpp new file mode 100644 index 00000000000..26496d6e1a5 --- /dev/null +++ b/src/caffe/util/multicast_resources.cpp @@ -0,0 +1,270 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace caffe { + +#ifdef RDMA +#include "caffe/util/multicast_resources.hpp" + +// This is when we get sig handler from the user before we remove the join request. +struct mcast_parameters *sighandler_params; + +/****************************************************************************** + * signalCatcher - cacth user signal in order to reregiser the mcast group + ******************************************************************************/ +static void signalCatcher (int sig) +{ + if (sig == SIGINT) { + + if (join_multicast_group(SUBN_ADM_METHOD_DELETE,sighandler_params)) + fprintf(stderr,"Couldn't Unregister the Mcast group on the SM\n"); + + if (sighandler_params->is_2nd_mgid_used) { + memcpy(sighandler_params->mgid.raw,sighandler_params->base_mgid.raw,16); + if (join_multicast_group(SUBN_ADM_METHOD_DELETE,sighandler_params)) + fprintf(stderr,"Couldn't Unregister the Base Mcast group on the SM\n"); + } + } + exit(1); +} + +/****************************************************************************** + * prepare_mcast_mad + ******************************************************************************/ +static void prepare_mcast_mad(uint8_t method, + struct mcast_parameters *params, + struct sa_mad_packet_t *samad_packet) { + + uint8_t *ptr; + uint64_t comp_mask; + + memset(samad_packet,0,sizeof(*samad_packet)); + + /* prepare the MAD header. 
according to Table 145 in IB spec 1.2.1 */ + ptr = samad_packet->mad_header_buf; + ptr[0] = 0x01; /* BaseVersion */ + ptr[1] = MANAGMENT_CLASS_SUBN_ADM; /* MgmtClass */ + ptr[2] = 0x02; /* ClassVersion */ + ptr[3] = INSERTF(ptr[3], 0, method, 0, 7); /* Method */ + (*(uint64_t *)(ptr + 8)) = htonll((uint64_t)DEF_TRANS_ID); /* TransactionID */ + (*(uint16_t *)(ptr + 16)) = htons(SUBN_ADM_ATTR_MC_MEMBER_RECORD); /* AttributeID */ + + ptr = samad_packet->SubnetAdminData; + + memcpy(&ptr[0],params->mgid.raw, 16); + memcpy(&ptr[16],params->port_gid.raw, 16); + + (*(uint32_t *)(ptr + 32)) = htonl(DEF_QKEY); + (*(uint16_t *)(ptr + 40)) = htons(params->pkey); + ptr[39] = DEF_TCLASS; + ptr[44] = INSERTF(ptr[44], 4, DEF_SLL, 0, 4); + ptr[44] = INSERTF(ptr[44], 0, DEF_FLOW_LABLE, 16, 4); + ptr[45] = INSERTF(ptr[45], 0, DEF_FLOW_LABLE, 8, 8); + ptr[46] = INSERTF(ptr[46], 0, DEF_FLOW_LABLE, 0, 8); + ptr[48] = INSERTF(ptr[48], 0, MCMEMBER_JOINSTATE_FULL_MEMBER, 0, 4); + + comp_mask = SUBN_ADM_COMPMASK_MGID | SUBN_ADM_COMPMASK_PORT_GID | SUBN_ADM_COMPMASK_Q_KEY | + SUBN_ADM_COMPMASK_P_KEY | SUBN_ADM_COMPMASK_TCLASS | SUBN_ADM_COMPMASK_SL | + SUBN_ADM_COMPMASK_FLOW_LABEL | SUBN_ADM_COMPMASK_JOIN_STATE; + + samad_packet->ComponentMask = htonll(comp_mask); +} + +/****************************************************************************** + * check_mad_status + ******************************************************************************/ +static int check_mad_status(struct sa_mad_packet_t *samad_packet) { + + uint8_t *ptr; + uint32_t user_trans_id; + uint16_t mad_header_status; + + ptr = samad_packet->mad_header_buf; + + // the upper 32 bits of TransactionID were set by the kernel + user_trans_id = ntohl(*(uint32_t *)(ptr + 12)); + + // check the TransactionID to make sure this is the response + // for the join/leave multicast group request we posted + if (user_trans_id != DEF_TRANS_ID) { + fprintf(stderr, "received a mad with TransactionID 0x%x, when expecting 0x%x\n", + (unsigned int)user_trans_id, (unsigned int)DEF_TRANS_ID);; + return 1; + } + + mad_header_status = 0x0; + mad_header_status = INSERTF(mad_header_status, 8, ptr[4], 0, 7); + mad_header_status = INSERTF(mad_header_status, 0, ptr[5], 0, 8); + + if (mad_header_status) { + fprintf(stderr,"received UMAD with an error: 0x%x\n", mad_header_status); + return 1; + } + + return 0; +} + + +/****************************************************************************** + * get_mlid_from_mad + ******************************************************************************/ +static void get_mlid_from_mad(struct sa_mad_packet_t *samad_packet,uint16_t *mlid) { + uint8_t *ptr; + ptr = samad_packet->SubnetAdminData; + *mlid = ntohs(*(uint16_t *)(ptr + 36)); +} + +/****************************************************************************** + * get_gid_from_mad + ******************************************************************************/ +static void get_gid_from_mad(struct sa_mad_packet_t *samad_packet,ibv_gid *gid) { + uint8_t *ptr; + ptr = samad_packet->SubnetAdminData; + memcpy(gid->raw, ptr, 16); +} + +/****************************************************************************** + * set_multicast_gid + ******************************************************************************/ +void set_multicast_gid(struct mcast_parameters *params,uint32_t qp_num,int is_client) { + + uint8_t mcg_gid[16] = MCG_GID; + char *pstr = const_cast(params->user_mgid); + char *term = NULL; + char tmp[20]; + int i; + + if (params->user_mgid) { + term = strpbrk(pstr, 
":"); + memcpy(tmp, pstr, term - pstr+1); + tmp[term - pstr] = 0; + + mcg_gid[0] = (unsigned char)strtoll(tmp, NULL, 0); + + for (i = 1; i < 15; ++i) { + pstr += term - pstr + 1; + term = strpbrk(pstr, ":"); + memcpy(tmp, pstr, term - pstr+1); + tmp[term - pstr] = 0; + + mcg_gid[i] = (unsigned char)strtoll(tmp, NULL, 0); + } + pstr += term - pstr + 1; + + strcpy(tmp, pstr); + mcg_gid[15] = (unsigned char)strtoll(tmp, NULL, 0); + } + + memcpy(params->mgid.raw,mcg_gid,16); + if (is_client && params->user_mgid==NULL) + params->mgid.raw[15]++; +} + +/****************************************************************************** + * join_multicast_group + ******************************************************************************/ +int join_multicast_group(subn_adm_method method,struct mcast_parameters *params) { + + int portid = -1; + int agentid = -1; + void *umad_buff = NULL; + void *mad = NULL; + int length = MAD_SIZE; + int test_result = 0; + + // mlid will be assigned to the new LID after the join + if (umad_init() < 0) { + fprintf(stderr, "failed to init the UMAD library\n"); + goto cleanup; + } + /* use casting to loose the "const char0 *" */ + portid = umad_open_port((char*)params->ib_devname,params->ib_port); + if (portid < 0) { + fprintf(stderr,"failed to open UMAD port %d\n",params->ib_port); + goto cleanup; + } + + agentid = umad_register(portid,MANAGMENT_CLASS_SUBN_ADM, 2, 0, 0); + if (agentid < 0) { + fprintf(stderr,"failed to register UMAD agent for MADs\n"); + goto cleanup; + } + + umad_buff = umad_alloc(1, umad_size() + MAD_SIZE); + if (!umad_buff) { + fprintf(stderr, "failed to allocate MAD buffer\n"); + goto cleanup; + } + + mad = umad_get_mad(umad_buff); + prepare_mcast_mad(method,params,(struct sa_mad_packet_t *)mad); + + if (umad_set_addr(umad_buff,params->sm_lid,1,params->sm_sl,QP1_WELL_KNOWN_Q_KEY) < 0) { + fprintf(stderr, "failed to set the destination address of the SMP\n"); + goto cleanup; + } + + if (umad_send(portid,agentid,umad_buff,MAD_SIZE,100,5) < 0) { + fprintf(stderr, "failed to send MAD\n"); + goto cleanup; + } + + if (umad_recv(portid,umad_buff,&length,5000) < 0) { + fprintf(stderr, "failed to receive MAD response\n"); + goto cleanup; + } + + if (check_mad_status((struct sa_mad_packet_t*)mad)) { + fprintf(stderr, "failed to get mlid from MAD\n"); + goto cleanup; + } + + // "Join multicast group" message was sent + if (method == SUBN_ADM_METHOD_SET) { + get_gid_from_mad((struct sa_mad_packet_t*)mad,¶ms->mgid); + get_mlid_from_mad((struct sa_mad_packet_t*)mad,¶ms->mlid); + params->mcast_state |= MCAST_IS_JOINED; + if (params->is_2nd_mgid_used == 0) { + sighandler_params = params; + signal(SIGINT,signalCatcher); + } + } else { + params->mcast_state &= ~MCAST_IS_JOINED; + } + +cleanup: + if (umad_buff) + umad_free(umad_buff); + + if (portid >= 0) { + if (agentid >= 0) { + if (umad_unregister(portid, agentid)) { + fprintf(stderr, "failed to deregister UMAD agent for MADs\n"); + test_result = 1; + } + } + + if (umad_close_port(portid)) { + fprintf(stderr, "failed to close UMAD portid\n"); + test_result = 1; + } + } + + return test_result; +} + +#endif /* RDMA */ +} diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp index cbd6003c948..20638fefc70 100644 --- a/src/caffe/util/upgrade_proto.cpp +++ b/src/caffe/util/upgrade_proto.cpp @@ -295,7 +295,7 @@ bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection, } if (v0_layer_param.has_source()) { if (type == "data") { - 
layer_param->mutable_data_param()->set_source(v0_layer_param.source()); + layer_param->mutable_data_param()->add_source(v0_layer_param.source()); } else if (type == "hdf5_data") { layer_param->mutable_hdf5_data_param()->set_source( v0_layer_param.source());
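
Since source is now a repeated field and DataParameter gained prefetch, a data layer can read from several database shards (typically one per parallel solver) and keep a deeper queue of prepared batches. A hypothetical snippet using the generated protobuf API, with placeholder paths and values; layer_param is a LayerParameter* as in the surrounding upgrade code.

// Sketch (not part of the patch): filling in the new fields programmatically.
DataParameter* data = layer_param->mutable_data_param();
data->add_source("examples/imagenet/shard0_train_lmdb");  // placeholder paths,
data->add_source("examples/imagenet/shard1_train_lmdb");  // one shard per solver
data->set_batch_size(64);
data->set_prefetch(4);  // batches to keep in the prefetch queue (default 4)
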