Cpp parallel executor #9080

Merged 163 commits on Mar 30, 2018

Commits
0621c32
init commit
Mar 13, 2018
e67325c
update readme
Mar 14, 2018
8f061e4
delete param name
Mar 14, 2018
baef112
ParallelExecutor And dependency engine
reyoung Mar 14, 2018
692a0f7
Better name
reyoung Mar 14, 2018
ae88fde
Use thread pool
reyoung Mar 15, 2018
22bb262
Remove out of date design
reyoung Mar 15, 2018
35744e7
Polish code
reyoung Mar 15, 2018
193c0a7
Handle var hazard
reyoung Mar 15, 2018
d84ddcf
Stash
reyoung Mar 15, 2018
6f0dfd8
Single GPU ParallelExecutor complete
reyoung Mar 16, 2018
8c9cd36
Polish code style
reyoung Mar 16, 2018
8b397d1
Make recordio file reader thread-safe by default
reyoung Mar 16, 2018
5e87cd7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
reyoung Mar 16, 2018
0ef9edf
Stash
reyoung Mar 16, 2018
9fc0b59
Test more
reyoung Mar 16, 2018
d470763
Stash
reyoung Mar 16, 2018
c15d2c9
Update
reyoung Mar 16, 2018
8f0590e
Add ncclAllReduce
reyoung Mar 16, 2018
e8a7e5d
Update
reyoung Mar 16, 2018
b2c7a9b
Wait by stream
reyoung Mar 16, 2018
254d7ff
Refactor local_scopes
reyoung Mar 16, 2018
9cb8f50
Complete fetch op
reyoung Mar 19, 2018
e18a269
Add debug code
reyoung Mar 19, 2018
389ea18
Debug code
reyoung Mar 19, 2018
f8141d9
Debug
reyoung Mar 19, 2018
09935ab
Debug
reyoung Mar 19, 2018
0023c3b
Use atomic bool
reyoung Mar 19, 2018
f52714d
Debug
reyoung Mar 19, 2018
5957f28
Debug
reyoung Mar 19, 2018
36e0415
Single Thread
reyoung Mar 19, 2018
f3e983e
Memory order
reyoung Mar 19, 2018
b57b880
Debug
reyoung Mar 19, 2018
b1cb8bb
Debug
reyoung Mar 19, 2018
1f063d0
Memorder
reyoung Mar 19, 2018
515e516
Add more log
reyoung Mar 19, 2018
ea11a0a
Use volitie
reyoung Mar 19, 2018
a87ce91
Use mtx
reyoung Mar 19, 2018
a5ba704
Counter
reyoung Mar 19, 2018
d3e55fd
Guard devctx
reyoung Mar 19, 2018
866f6f1
Debug
reyoung Mar 19, 2018
7bff02b
Change to pending op
reyoung Mar 19, 2018
5fa535b
Wait all thread done
reyoung Mar 19, 2018
c7beac1
Add dummy var
reyoung Mar 19, 2018
1f53193
Use atomic code
reyoung Mar 19, 2018
3aa7051
Remove DevCtx lock
reyoung Mar 19, 2018
d7badb3
Use event to sync stream
reyoung Mar 19, 2018
29cc9f3
SetDev for nccl
reyoung Mar 19, 2018
8af5770
Only wait same device
reyoung Mar 19, 2018
071043c
Add paddle enforce
reyoung Mar 19, 2018
9824e8f
Scale loss op use event
reyoung Mar 19, 2018
4a33009
Add log
reyoung Mar 19, 2018
bade579
Wait code
reyoung Mar 19, 2018
7fd0d24
Add lgo
reyoung Mar 19, 2018
dad7bda
Add setDev
reyoung Mar 19, 2018
932364a
Sync dev
reyoung Mar 19, 2018
d55a03d
Scale loss on place
reyoung Mar 19, 2018
d26f093
Log
reyoung Mar 19, 2018
99f85a9
Set dev
reyoung Mar 19, 2018
b94ffac
SetDev
reyoung Mar 19, 2018
ee697b8
Larger model
reyoung Mar 19, 2018
48619bc
Shrink model
reyoung Mar 19, 2018
c372ce2
Add event for computational op
reyoung Mar 19, 2018
c18c2f6
Sync all computation streams at the end of run
reyoung Mar 20, 2018
d3c82c3
Wait multiple stream
reyoung Mar 20, 2018
3da4159
Add run iter
reyoung Mar 20, 2018
4137bb4
Add wait
reyoung Mar 20, 2018
d2cb379
Wait all evernts
reyoung Mar 20, 2018
8a9de67
Remove wait
reyoung Mar 20, 2018
3238ce0
Add wait
reyoung Mar 20, 2018
e025e28
Exchange wait op
reyoung Mar 20, 2018
260cfe3
Stop Wait NCCL Stream
reyoung Mar 20, 2018
feb569f
Add log
reyoung Mar 20, 2018
9b1f4d5
After nccl add event
reyoung Mar 20, 2018
631aa3d
Wait all inputs ready
reyoung Mar 20, 2018
4185dd4
Disable multi-thread
reyoung Mar 20, 2018
1dd216d
Wait bcast param
reyoung Mar 20, 2018
f251a58
Use base class manage events
reyoung Mar 20, 2018
ca4b3d2
Use 12 threads
reyoung Mar 20, 2018
7643c2c
Add flag for use event
reyoung Mar 20, 2018
fbbcedd
Fix bug
reyoung Mar 20, 2018
f8f1a96
Add debug code
reyoung Mar 20, 2018
3c9cea5
Add more log
reyoung Mar 20, 2018
a8bd7b9
Add log
reyoung Mar 20, 2018
e53b6ab
Use no thread
reyoung Mar 20, 2018
dbed123
Debug
reyoung Mar 20, 2018
4e43b71
Add wait log
reyoung Mar 20, 2018
a0494f8
Mutex lock wait
reyoung Mar 20, 2018
1c2b610
Add
reyoung Mar 20, 2018
798e690
Change mem order
reyoung Mar 20, 2018
95a0d7c
Illegal memory access
reyoung Mar 20, 2018
ed7727e
Fix bug in system allocator
reyoung Mar 20, 2018
176277b
Add log
reyoung Mar 20, 2018
1533bf1
Use event and single thread
reyoung Mar 20, 2018
ba227df
Expose num_threads
reyoung Mar 20, 2018
d42117e
Set NumThreads
reyoung Mar 20, 2018
65bc7d1
Add mtx to ncclAllReduce
reyoung Mar 20, 2018
eb0a580
Add enforce
reyoung Mar 20, 2018
82693e7
Wait nccl all reduce
reyoung Mar 20, 2018
e335f01
Add more logs
reyoung Mar 20, 2018
43e5407
Debug code
reyoung Mar 20, 2018
599f7a8
Refine code
reyoung Mar 20, 2018
7ac969b
Debug
reyoung Mar 21, 2018
90f9801
Do not wait computation stream
reyoung Mar 21, 2018
99fe83a
Move nccl helper
reyoung Mar 21, 2018
41ad632
Add NCCL Group Guard
reyoung Mar 21, 2018
f2685be
Clean code
reyoung Mar 21, 2018
a478a11
NCCL Guard for bcast
reyoung Mar 21, 2018
6ebc6bf
ReorganizeCode
reyoung Mar 21, 2018
fe7ed28
Extract NCCLCtxMap
reyoung Mar 21, 2018
5368e50
Reorganize code
reyoung Mar 21, 2018
15f5f10
AddInput/AddOutput for OpHandle
reyoung Mar 21, 2018
5c333e4
Add dctor for dev_ctx
reyoung Mar 21, 2018
f28ae6e
Reorganize Code
reyoung Mar 21, 2018
3181501
Rerange code
reyoung Mar 21, 2018
8dec4ad
Use int not Place for vars
reyoung Mar 21, 2018
64d7a30
Extract SSAGraph
reyoung Mar 21, 2018
79989c9
Add SSA builder
reyoung Mar 21, 2018
dd73d18
Extract SSAGraph
reyoung Mar 22, 2018
b123e43
extract multi devices graph builder
reyoung Mar 24, 2018
4c3361c
Extract GraphExecutor
reyoung Mar 24, 2018
c70b60d
Make executor steal graph inside
reyoung Mar 24, 2018
e314439
Extract Executors to indie modules
reyoung Mar 24, 2018
a7b0d5b
Clean code
reyoung Mar 24, 2018
edfd741
Add simple python wrapper for ParallelExecutor
reyoung Mar 24, 2018
5c7a523
Add Graphviz output
reyoung Mar 26, 2018
50e7e25
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
reyoung Mar 26, 2018
54bd17f
Complete Flowers
reyoung Mar 26, 2018
02aaecc
Fix CPU compile
reyoung Mar 26, 2018
3aa2a8f
Follow comments
reyoung Mar 26, 2018
ee97687
Fix compile
reyoung Mar 26, 2018
cb40c33
Update unittest
reyoung Mar 26, 2018
9dd64d8
WMT Model
reyoung Mar 26, 2018
aba46f0
Disable P2P
reyoung Mar 27, 2018
833e522
Enhance drop kids
reyoung Mar 27, 2018
f385228
Add Paddle Enforce
reyoung Mar 27, 2018
5a02739
Throw error
reyoung Mar 27, 2018
55e2cc3
FetchOp Force sync
reyoung Mar 27, 2018
b6ca371
Get error
reyoung Mar 27, 2018
76570c2
Wait fetch op
reyoung Mar 27, 2018
2227632
Change fetch op
reyoung Mar 27, 2018
9af8708
Use heap variables
reyoung Mar 27, 2018
dfb8680
Early drop fetch op
reyoung Mar 27, 2018
52dd8ff
Force sync dev
reyoung Mar 27, 2018
5b92dd4
Remove dev sync
reyoung Mar 27, 2018
c42c4a6
Add performance tests
reyoung Mar 27, 2018
3f88fad
Fix merge op
reyoung Mar 27, 2018
c0c2e15
NCCL AllReduce
reyoung Mar 27, 2018
7dcb217
Refine allreduce op
reyoung Mar 27, 2018
50f71f5
Using blocking queue
reyoung Mar 27, 2018
dcf7bd2
Add initP2P
reyoung Mar 27, 2018
201f79d
Use Extend method
reyoung Mar 27, 2018
5408854
Disable model evaluation in unittests
reyoung Mar 28, 2018
9f4a98f
Add design doc
reyoung Mar 28, 2018
084cdd1
Rename code
reyoung Mar 28, 2018
f2d29be
Disable transformer
reyoung Mar 28, 2018
f707a83
Add link
reyoung Mar 28, 2018
b077558
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
reyoung Mar 28, 2018
7da1ea0
Use PopAll
reyoung Mar 28, 2018
ce2f096
Merge branch 'cpp_parallel_executor' of github.com:reyoung/Paddle int…
reyoung Mar 28, 2018
38b53b3
Remove Pop method
reyoung Mar 28, 2018
e868950
Add comments
reyoung Mar 29, 2018
af230d9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
reyoung Mar 29, 2018
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -147,6 +147,7 @@ include(external/cares)
include(external/grpc)
include(external/snappy) # download snappy
include(external/snappystream)
include(external/threadpool)

include(cudnn) # set cudnn libraries, must before configure
include(cupti)
30 changes: 30 additions & 0 deletions cmake/external/threadpool.cmake
@@ -0,0 +1,30 @@
INCLUDE(ExternalProject)

SET(THREADPOOL_SOURCE_DIR ${THIRD_PARTY_PATH}/threadpool)
SET(THREADPOOL_INCLUDE_DIR ${THREADPOOL_SOURCE_DIR}/src/extern_threadpool)
INCLUDE_DIRECTORIES(${THREADPOOL_INCLUDE_DIR})

ExternalProject_Add(
    extern_threadpool
    ${EXTERNAL_PROJECT_LOG_ARGS}
    GIT_REPOSITORY  "https://github.com/progschj/ThreadPool.git"
    GIT_TAG         9a42ec1329f259a5f4881a291db1dcb8f2ad9040
    PREFIX          ${THREADPOOL_SOURCE_DIR}
    UPDATE_COMMAND  ""
    CONFIGURE_COMMAND ""
    BUILD_COMMAND   ""
    INSTALL_COMMAND ""
    TEST_COMMAND    ""
)

if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
    set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/threadpool_dummy.c)
    file(WRITE ${dummyfile} "const char *dummy_threadpool = \"${dummyfile}\";")
    add_library(simple_threadpool STATIC ${dummyfile})
else()
    add_library(simple_threadpool INTERFACE)
endif()

add_dependencies(simple_threadpool extern_threadpool)

LIST(APPEND external_project_dependencies simple_threadpool)
83 changes: 83 additions & 0 deletions doc/design/images/parallel_executor_overview.dot
@@ -0,0 +1,83 @@
digraph G {
    subgraph cluster_init {
        label="Initialization"
        startup_program [label="startup", shape=box]
        node_w_g0 [label="W\nGPU0"]
        startup_program -> node_w_g0 [label="Initialize"]
        node_w_g1 [label="W\nGPU1"]
        node_w_g0 -> node_w_g1 [label="broadcast"]
    }

    subgraph cluster_train {
        label="forward_backward"

        subgraph cluster_gpu0 {
            label="GPU0"
            fc_0 [label="fc\nGPU0", shape=box]
            hidden_0 [label="hidden\nGPU0"]
            node_w_g0 -> fc_0
            fc_0 -> hidden_0
            loss0 [label="loss\nGPU0"]
            hidden_0 -> loss0 [label="many ops omitted"]
            scale_loss_0 [label="scale_loss_gradient\nGPU0", shape=box]
            loss_g0 [label="loss_grad\nGPU0"]
            scale_loss_0 -> loss_g0

            fc_g_0 [label="w_grad\nGPU0", shape=box]
            loss0 -> fc_g_0
            loss_g0 -> fc_g_0
            hidden_0 -> fc_g_0
        }

        subgraph cluster_gpu1 {
            label="GPU1"
            fc_1 [label="fc\nGPU1", shape=box]
            hidden_1 [label="hidden\nGPU1"]
            node_w_g1 -> fc_1
            fc_1 -> hidden_1
            loss1 [label="loss\nGPU1"]
            hidden_1 -> loss1 [label="many ops omitted"]
            scale_loss_1 [label="scale_loss_gradient\nGPU1", shape=box]
            loss_g1 [label="loss_grad\nGPU1"]
            scale_loss_1 -> loss_g1

            fc_g_1 [label="w_grad\nGPU1", shape=box]
            loss1 -> fc_g_1
            loss_g1 -> fc_g_1
            hidden_1 -> fc_g_1
        }
    }

    all_reduce_w [label="Merge Gradients(AllReduce)", shape=box]
    fc_g_0 -> all_reduce_w
    fc_g_1 -> all_reduce_w

    fc_g_0_merged [label="w_grad\nMerged\nGPU0"]
    fc_g_1_merged [label="w_grad\nMerged\nGPU1"]
    all_reduce_w -> fc_g_0_merged
    all_reduce_w -> fc_g_1_merged

    subgraph cluster_optimization {
        label="Optimization"
        subgraph cluster_opt_gpu0 {
            label="GPU0"
            sgd_0 [label="SGD Op\nGPU0", shape=box]

            fc_g_0_merged -> sgd_0
            node_w_g0 -> sgd_0
            optimized_w_0 [label="Optimized W\nGPU0"]
            sgd_0 -> optimized_w_0
        }
        subgraph cluster_opt_gpu1 {
            label="GPU1"
            sgd_1 [label="SGD Op\nGPU1", shape=box]

            fc_g_1_merged -> sgd_1
            node_w_g1 -> sgd_1
            optimized_w_1 [label="Optimized W\nGPU1"]
            sgd_1 -> optimized_w_1
        }
    }
}
Binary file added doc/design/images/parallel_executor_overview.png
104 changes: 104 additions & 0 deletions doc/design/parallel_executor.md
@@ -0,0 +1,104 @@
# ParallelExecutor

## Background

Neural network models are defined as a `ProgramDesc` in Fluid. The `ProgramDesc` can be executed by an interpreter (i.e., the `executor` concept in Fluid). The instructions or operators in a `Program` will be executed, and the results will be fetched on the Python side.

The executor is a very naive interpreter: it runs operators one by one. We can use `Parallel.Do` to support data parallelism; however, since `ProgramDesc` lacks device information, it is not possible to optimize the performance of `Parallel.Do`.

We want a `ProgramDesc` that can run on different nodes, so it is better for the `ProgramDesc` not to contain device information. Instead, we can write a high-performance interpreter that holds an alternative intermediate representation of the `ProgramDesc` and makes full use of multiple GPUs.

ParallelExecutor is an interpreter of `ProgramDesc` that executes the `Program` [out of order](https://en.wikipedia.org/wiki/Out-of-order_execution) in data-parallel mode and maximizes the utilization of multiple GPUs.


## Overview of the multi-GPU logic

The ParallelExecutor takes the startup program and the main program as inputs. The parameters are initialized on `GPU0` by the startup program and then broadcast to the other GPUs. The main program is duplicated onto every GPU. The gradients are merged during each iteration, and each device optimizes its parameters independently. Since the gradients on each device are merged before parameter optimization, the parameters stay identical on every device and there is no need to broadcast them again.

![ParallelExecutor overview](images/parallel_executor_overview.png)

There are several optimizations in this logic.

1. We use an alternative representation in ParallelExecutor because the device information is critical for performance optimization.
2. The execution is out of order, i.e., an operator is executed whenever its inputs are ready.
   * A GPU is a high-performance device that a single CPU thread cannot keep busy, so a thread pool is used to execute operators.
   * Out-of-order execution also helps transpilers that generate a `ProgramDesc`: a transpiler does not need to worry about the best operator order for performance.
3. Computation, gradient merging, and data fetching use different streams.

The performance of `ResNeXt152` on `TitanX` with `batch_size=12` is shown below.

| Number of GPUs | 1 | 2 | 3 | 4 |
| --- | --- | --- | --- | --- |
| Images/sec | 17.9906 | 25.771 | 36.911 | 48.8428 |
| Speedup | N/A | 1.43247029 | 2.05168255 | 2.71490667 |


## Static Single Assignment Graph

[Static single assignment form](https://en.wikipedia.org/wiki/Static_single_assignment_form) (`SSA` for short) is a common form used in compiler optimization. To implement concurrent execution, we use an `SSA` graph as an intermediate representation of the `ProgramDesc`.

The `Program` is a directed acyclic graph, but since a variable in it can be assigned multiple times, it is not yet in SSA form. We enforce that each variable is assigned only once by adding a version number to variables; for example, the second assignment to a variable `w` produces a new version of `w` rather than overwriting the first. We then parse the `Program` into an `SSA` graph. The ParallelExecutor also duplicates the `Program` onto multiple devices, so we add a device number to variables and insert `NCCLAllReduce` operators into the graph.

The data structure of the `SSA` graph is:

```c++
struct VarHandleBase {
  OpHandleBase* generated_op_;
  vector<OpHandleBase*> pending_ops_;

  string name;
  Place place;
  size_t version;
};

struct OpHandleBase {
  vector<VarHandleBase*> inputs_;
  vector<VarHandleBase*> outputs_;
};

struct SSAGraph {
  // Variables on each device:
  // * each map in the vector belongs to one device;
  // * a map maps a variable name to the variable handles
  //   of its different versions.
  vector<std::unordered_map<string, vector<VarHandleBase>>> vars_;

  // All operators.
  vector<OpHandleBase> ops_;
};
```
The variable handles are wrappers of `Variable`s, and the operator handles are wrappers of `OperatorBase`. Some `OpHandle`s are not backed by an `OperatorBase`, such as `NCCLAllReduceOpHandle`, because it uses its own, newly created device contexts.

When the `ProgramDesc` is converted into an `SSA` graph, the [data hazard](https://en.wikipedia.org/wiki/Hazard_(computer_architecture)) problem also needs to be taken care of. Dummy variables, which represent dependencies between operators, are inserted into the SSA graph to resolve data hazards; see the sketch below.
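
A minimal sketch of this idea on a write-after-read hazard, using the structs above plus a hypothetical `DummyVarHandle` subclass (an illustration of the technique, not the actual Paddle code):

```c++
// Suppose `reader` consumes version N of a variable and `writer` produces
// version N+1. Without an explicit edge between them, out-of-order
// execution could run `writer` before `reader` and corrupt the reader's
// input. A dummy variable adds the missing reader -> writer dependency.
struct DummyVarHandle : public VarHandleBase {};

DummyVarHandle *ResolveWriteAfterRead(OpHandleBase *reader,
                                      OpHandleBase *writer) {
  auto *dep = new DummyVarHandle();  // in practice, owned by the SSA graph
  dep->generated_op_ = reader;       // "generated" once the reader finishes
  dep->pending_ops_.push_back(writer);
  reader->outputs_.push_back(dep);
  writer->inputs_.push_back(dep);    // the writer now waits for the reader
  return dep;
}
```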

## Execute SSA Graph

The SSA graph can be executed out of order by an approximate [topological sorting](https://en.wikipedia.org/wiki/Topological_sorting) algorithm:

1. Maintain a map from each operator to the number of inputs it still needs.
2. For every variable that is not generated by an operator, i.e., `var.generated_op == nullptr`, decrease the needed input number of its pending operators.
3. Whenever an operator's needed input number drops to zero, run that operator.
4. After running an operator, mark its output variables as generated and repeat step 2 until all variables are generated.

Running an operator can be asynchronous: a thread pool is used to execute the `SSA` graph.
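
The following is a minimal, self-contained sketch of this scheduling loop. The `Op` and `Var` structs and the inline `run()` call are simplified stand-ins for Paddle's `OpHandleBase`/`VarHandleBase`; the real executor dispatches each ready operator to the thread pool instead of running it inline:

```c++
#include <functional>
#include <queue>
#include <unordered_map>
#include <vector>

struct Op;
struct Var {
  Op *generated_op = nullptr;     // nullptr means ready from the start
  std::vector<Op *> pending_ops;  // operators that consume this variable
};
struct Op {
  std::vector<Var *> inputs;
  std::vector<Var *> outputs;
  std::function<void()> run;      // the actual kernel
};

void Execute(const std::vector<Op *> &ops) {
  // Step 1: count how many not-yet-generated inputs each operator needs.
  std::unordered_map<Op *, size_t> needed;
  std::queue<Op *> ready;
  for (Op *op : ops) {
    size_t n = 0;
    for (Var *in : op->inputs) {
      if (in->generated_op != nullptr) ++n;  // step 2: free vars don't count
    }
    needed[op] = n;
    if (n == 0) ready.push(op);  // step 3: nothing to wait for
  }
  // Step 4: running an operator may make its consumers ready.
  while (!ready.empty()) {
    Op *op = ready.front();
    ready.pop();
    op->run();
    for (Var *out : op->outputs) {
      for (Op *consumer : out->pending_ops) {
        if (--needed[consumer] == 0) ready.push(consumer);
      }
    }
  }
}
```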

## Synchronize GPU Kernels

The GPU is a non-blocking device, so different streams need to be synchronized when switching between streams. In the current implementation, the synchronization is based on the following algorithm:

1. Each `OpHandle` records the `DeviceContext` that it uses.
2. In `OpHandle::Run`, if the `DeviceContext` of the current operator differs from the `DeviceContext` of any input variable, the operator waits for the generating operator of that input variable.

The `wait` is implemented by two strategies:

1. Invoke `DeviceContext->Wait()`, which waits until all operators on that device context complete.
2. Use `cudaStreamWaitEvent` to send an event to the stream. This is a non-blocking call; the wait is performed on the GPU.

Generally, `cudaStreamWaitEvent` has better performance, but the `DeviceContext->Wait()` strategy is easier to debug. The strategy can be switched at runtime.
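
A minimal sketch of the event-based strategy (strategy 2) in terms of raw CUDA streams, outside Paddle's `DeviceContext` wrapper:

```c++
#include <cuda_runtime.h>

// Make `consumer` wait, on the GPU and without blocking the CPU, for all
// work enqueued on `producer` so far.
void StreamWait(cudaStream_t consumer, cudaStream_t producer) {
  cudaEvent_t event;
  // Timing is not needed for pure synchronization, so disable it.
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  cudaEventRecord(event, producer);         // mark producer's current point
  cudaStreamWaitEvent(consumer, event, 0);  // consumer's later work waits
  cudaEventDestroy(event);  // safe: destruction is deferred until completion
}
```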

## What's next?

* Merging gradients of dense parameters is done; however, merging gradients of sparse parameters is not yet implemented.
* The CPU version of ParallelExecutor has not been implemented. The out-of-order logic would make CPU computation faster, too.
* A better strategy for merging gradients can be introduced. We can shrink the gradients from `float32` to `int8` or `int4` while merging, which would significantly speed up multi-GPU training without much loss of precision.
* Combine with the multi-node implementation. Thanks to out-of-order execution, the send and recv operators can be blocking operators, and the transpiler does not need to worry about their best positions.
4 changes: 4 additions & 0 deletions paddle/fluid/framework/CMakeLists.txt
@@ -1,3 +1,4 @@
add_subdirectory(details)
# ddim lib
proto_library(framework_proto SRCS framework.proto)

@@ -87,6 +88,9 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table feed_fetch_method)


cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor)

cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
21 changes: 21 additions & 0 deletions paddle/fluid/framework/details/CMakeLists.txt
@@ -0,0 +1,21 @@
cc_library(var_handle SRCS var_handle.cc DEPS place)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)

cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)

if(WITH_GPU)
    set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
else()
    set(multi_devices_graph_builder_deps)
endif()
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle ${multi_devices_graph_builder_deps})
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
42 changes: 42 additions & 0 deletions paddle/fluid/framework/details/computation_op_handle.cc
@@ -0,0 +1,42 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/computation_op_handle.h"

namespace paddle {
namespace framework {
namespace details {
ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
                                         platform::Place place)
    : op_(framework::OpRegistry::CreateOp(op_desc)),
      scope_(scope),
      place_(place) {}

void ComputationOpHandle::RunImpl() {
  auto *cur_ctx = dev_ctxes_[place_];
  // If an input variable was produced by an operator running on a different
  // device context (stream), wait for that producer before launching this op.
  for (auto *in : inputs_) {
    bool need_wait =
        in->generated_op_ && in->generated_op_->dev_ctxes_[place_] != cur_ctx;
    if (need_wait) {
      in->generated_op_->Wait(cur_ctx);
    }
  }

  // Run the wrapped operator in this device's temporary local scope.
  op_->Run(*scope_->FindVar("@TMP_SCOPE@")->Get<Scope *>(), place_);
}

std::string ComputationOpHandle::Name() const { return op_->Type(); }
} // namespace details
} // namespace framework
} // namespace paddle
41 changes: 41 additions & 0 deletions paddle/fluid/framework/details/computation_op_handle.h
@@ -0,0 +1,41 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace framework {
namespace details {
struct ComputationOpHandle : public OpHandleBase {
  std::unique_ptr<OperatorBase> op_;
  Scope *scope_;
  platform::Place place_;

  ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
                      platform::Place place);

  std::string Name() const override;

 protected:
  void RunImpl() override;
};
} // namespace details
} // namespace framework
} // namespace paddle