Skip to content

Commit

Permalink
Remove LayerWeight, simplify the use of computational graphs (#372)
Browse files Browse the repository at this point in the history
* remove layer_weight

* fix test_ls_layers_new

* remove useless log

* fix

* pre-commit fix format

* format
  • Loading branch information
hexisyztem authored Sep 1, 2022
1 parent b4dbab3 commit e234497
Show file tree
Hide file tree
Showing 23 changed files with 375 additions and 525 deletions.
30 changes: 15 additions & 15 deletions examples/inference/cpp/bert_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,25 @@ Example of how to run Bert inference using our implementation.

int main(int argc, char* argv[]) {
std::string model_weights_path = argv[1];
std::vector<int> example_input = {2859, 2758, 2051, 2157,
2005, 6629, 7566, 1012};
int eg_seq_len = example_input.size();
int max_batch_size = 128;
int batch_size = 1;
int batch_seq_len = eg_seq_len;
std::vector<int> example_input{};

int max_batch_size = 1;
int batch_seq_len = 32;
int rand_seed = 772002;

if (argc == 4) {
batch_size = atoi(argv[2]);
max_batch_size = atoi(argv[2]);
batch_seq_len = atoi(argv[3]);
}
if (batch_size > max_batch_size) {
throw std::runtime_error("batch_size exceeds the maximum (128)!");
} else if (argc == 5) {
max_batch_size = atoi(argv[2]);
batch_seq_len = atoi(argv[3]);
rand_seed = atoi(argv[4]);
}

std::vector<int> host_input;
for (int i = 0; i < batch_size; ++i) {
for (int i = 0; i < max_batch_size; ++i) {
for (int j = 0; j < batch_seq_len; ++j) {
host_input.push_back(example_input[j % eg_seq_len]);
host_input.push_back(rand() % 9000 + 1000);
}
}

Expand All @@ -35,13 +35,13 @@ int main(int argc, char* argv[]) {

void* d_input;
lightseq::cuda::CHECK_GPU_ERROR(
cudaMalloc(&d_input, sizeof(int) * batch_size * batch_seq_len));
cudaMalloc(&d_input, sizeof(int) * max_batch_size * batch_seq_len));
lightseq::cuda::CHECK_GPU_ERROR(cudaMemcpy(
d_input, host_input.data(), sizeof(int) * batch_size * batch_seq_len,
d_input, host_input.data(), sizeof(int) * max_batch_size * batch_seq_len,
cudaMemcpyHostToDevice));

model->set_input_ptr(0, d_input);
model->set_input_shape(0, {batch_size, batch_seq_len});
model->set_input_shape(0, {max_batch_size, batch_seq_len});

for (int i = 0; i < model->get_output_size(); i++) {
void* d_output;
Expand Down
13 changes: 5 additions & 8 deletions lightseq/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
cmake_minimum_required(VERSION 3.10)
project(LightseqProtoType LANGUAGES C CXX CUDA)

set(CMAKE_CUDA_ARCHITECTURES
60
61
70
75
80
86
87)
set(CMAKE_CUDA_ARCHITECTURES 70 75 80 86 87)
find_package(CUDA 11.6 REQUIRED)

set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
Expand All @@ -31,6 +24,10 @@ else()
set(HDF5_USE_STATIC_LIBRARIES ON)
endif()

# When the DEBUG_MODE cache variable is truthy (e.g. `cmake -DDEBUG_MODE=ON ..`),
# define the DEBUG preprocessor macro for the targets configured below.
if(DEBUG_MODE)
add_definitions(-DDEBUG)
endif()

set(Protobuf_USE_STATIC_LIBS ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
Expand Down
2 changes: 1 addition & 1 deletion lightseq/csrc/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ if [ ! -d 'build' ]; then
mkdir build
fi

cd build && cmake -DDEBUG=ON .. && make -j${nproc}
# Configure with debugging disabled, then build with one make job per CPU.
# `$(nproc)` runs the nproc command; the previous `${nproc}` expanded an unset
# shell variable, which left a bare `-j` and therefore an unlimited job count.
cd build && cmake -DDEBUG_MODE=OFF .. && make -j"$(nproc)"
33 changes: 16 additions & 17 deletions lightseq/csrc/example/bert_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,25 @@ Example of how to run Bert inference using our implementation.

int main(int argc, char* argv[]) {
std::string model_weights_path = argv[1];
std::vector<int> example_input = {2859, 2758, 2051, 2157,
2005, 6629, 7566, 1012};
std::vector<int> example_input{};

int eg_seq_len = example_input.size();
int max_batch_size = 128;
int batch_size = 1;
int batch_seq_len = eg_seq_len;
int max_batch_size = 1;
int batch_seq_len = 32;
int rand_seed = 772002;

if (argc == 4) {
batch_size = atoi(argv[2]);
max_batch_size = atoi(argv[2]);
batch_seq_len = atoi(argv[3]);
}
if (batch_size > max_batch_size) {
throw std::runtime_error("batch_size exceeds the maximum (128)!");
} else if (argc == 5) {
max_batch_size = atoi(argv[2]);
batch_seq_len = atoi(argv[3]);
rand_seed = atoi(argv[4]);
}

std::vector<int> host_input;
for (int i = 0; i < batch_size; ++i) {
for (int i = 0; i < max_batch_size; ++i) {
for (int j = 0; j < batch_seq_len; ++j) {
host_input.push_back(example_input[j % eg_seq_len]);
host_input.push_back(rand() % 9000 + 1000);
}
}

Expand All @@ -36,13 +35,13 @@ int main(int argc, char* argv[]) {

void* d_input;
CHECK_GPU_ERROR(
cudaMalloc(&d_input, sizeof(int) * batch_size * batch_seq_len));
cudaMalloc(&d_input, sizeof(int) * max_batch_size * batch_seq_len));
CHECK_GPU_ERROR(cudaMemcpy(d_input, host_input.data(),
sizeof(int) * batch_size * batch_seq_len,
sizeof(int) * max_batch_size * batch_seq_len,
cudaMemcpyHostToDevice));

model->set_input_ptr(0, d_input);
model->set_input_shape(0, {batch_size, batch_seq_len});
model->set_input_shape(0, {max_batch_size, batch_seq_len});

for (int i = 0; i < model->get_output_size(); i++) {
void* d_output;
Expand All @@ -58,10 +57,10 @@ int main(int argc, char* argv[]) {
std::cout << "infer preprocessing finished" << std::endl;

/* ---step5. infer and log--- */
for (int i = 0; i < 1; i++) {
for (int i = 0; i < 10; i++) {
auto start = std::chrono::high_resolution_clock::now();
model->Infer();
// lightseq::cuda::print_time_duration(start, "one infer time", 0);
lightseq::print_time_duration(start, "one infer time", 0);
}

for (int i = 0; i < model->get_output_size(); i++) {
Expand Down
139 changes: 62 additions & 77 deletions lightseq/csrc/layers_new/feed_forward_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,68 +2,12 @@

namespace lightseq {

// Training path: records, for every feed-forward tensor, the address of its
// slice inside the flat parameter buffer and the flat gradient buffer.
// Slices are laid out in this order: inter_w, inter_b, output_w, output_b,
// ffn_nw, ffn_nb. Returns the total number of elements spanned.
template <class T1, class T2>
int FeedForwardLayerWeight::load_para_and_grad(const T1* para_ptr,
                                               T2* grad_ptr) {
  int offset = 0;

  // Point one value/grad pointer pair at the current offset, then step over
  // `element_count` elements so the next pair starts right after it.
  auto assign = [&](char*& value_slot, char*& grad_slot, int element_count) {
    value_slot = (char*)(para_ptr + offset);
    grad_slot = (char*)(grad_ptr + offset);
    offset += element_count;
  };

  assign(_inter_w_ptr, _grad_inter_w_ptr, _hidden_size * _intermediate_size);
  assign(_inter_b_ptr, _grad_inter_b_ptr, _intermediate_size);
  assign(_output_w_ptr, _grad_output_w_ptr, _hidden_size * _intermediate_size);
  assign(_output_b_ptr, _grad_output_b_ptr, _hidden_size);
  assign(_ffn_nw_ptr, _grad_ffn_nw_ptr, _hidden_size);
  assign(_ffn_nb_ptr, _grad_ffn_nb_ptr, _hidden_size);

  return offset;
}

// Explicit instantiations for the supported parameter/gradient types.
template int FeedForwardLayerWeight::load_para_and_grad(const float* para_ptr,
                                                        float* grad_ptr);
template int FeedForwardLayerWeight::load_para_and_grad(const __half* para_ptr,
                                                        __half* grad_ptr);

// Inference path: consumes six consecutive entries of `para_vec`, starting at
// `offset`, in the order ffn_nw, ffn_nb, inter_w, inter_b, output_w, output_b.
// `offset` is taken by reference and is advanced past the consumed entries.
template <typename T>
void FeedForwardLayerWeight::load_params(const std::vector<const T*>& para_vec,
                                         int& offset) {
  // Fetch the entry at the cursor and move the cursor forward by one.
  auto take_next = [&]() { return (char*)para_vec[offset++]; };

  _ffn_nw_ptr = take_next();
  _ffn_nb_ptr = take_next();

  _inter_w_ptr = take_next();
  _inter_b_ptr = take_next();

  _output_w_ptr = take_next();
  _output_b_ptr = take_next();
}

// Explicit instantiations for the supported parameter types.
template void FeedForwardLayerWeight::load_params<float>(
    const std::vector<const float*>& para_vec, int& offset);
template void FeedForwardLayerWeight::load_params<__half>(
    const std::vector<const __half*>& para_vec, int& offset);

template <typename T1, typename T2>
FeedForwardLayer<T1, T2>::FeedForwardLayer(
FeedForwardLayerWeightPtr ffn_wt, int layer_id, int max_batch_tokens,
int max_seq_len, int hidden_size, int num_heads, int intermediate_size,
float activation_dropout_ratio, float hidden_output_dropout_ratio,
bool pre_or_postLayerNorm, std::string activation_fn, bool is_post_ln)
int layer_id, int max_batch_tokens, int max_seq_len, int hidden_size,
int num_heads, int intermediate_size, float activation_dropout_ratio,
float hidden_output_dropout_ratio, bool pre_or_postLayerNorm,
std::string activation_fn, bool is_post_ln)
: Layer("FeedForwardLayer"),
_layer_id(layer_id),
_max_batch_tokens(max_batch_tokens),
Expand All @@ -87,27 +31,21 @@ FeedForwardLayer<T1, T2>::FeedForwardLayer(
_ffn_dropout(new BiasDropoutResOp<T1, T2>(
hidden_output_dropout_ratio, max_batch_tokens * hidden_size)) {
// parameters node
_inter_w = new Variable(this->_name + "_inter_w", ffn_wt->_inter_w_ptr,
ffn_wt->_grad_inter_w_ptr);
_inter_b = new Variable(this->_name + "_inter_b", ffn_wt->_inter_b_ptr,
ffn_wt->_grad_inter_b_ptr);
_inter_w = new Variable(name() + "_inter_w");
_inter_b = new Variable(this->_name + "_inter_b");

_output_w = new Variable(this->_name + "_output_w", ffn_wt->_output_w_ptr,
ffn_wt->_grad_output_w_ptr);
_output_b = new Variable(this->_name + "_output_b", ffn_wt->_output_b_ptr,
ffn_wt->_grad_output_b_ptr);
_output_w = new Variable(this->_name + "_output_w");
_output_b = new Variable(this->_name + "_output_b");

_ffn_nw = new Variable(this->_name + "_ffn_nw", ffn_wt->_ffn_nw_ptr,
ffn_wt->_grad_ffn_nw_ptr);
_ffn_nb = new Variable(this->_name + "_ffn_nb", ffn_wt->_ffn_nb_ptr,
ffn_wt->_grad_ffn_nb_ptr);
_ffn_nw = new Variable(this->_name + "_ffn_nw");
_ffn_nb = new Variable(this->_name + "_ffn_nb");

this->_context_ptr->exit_layer(); // necessary
}

template <typename T1, typename T2>
Variable* FeedForwardLayer<T1, T2>::operator()(Variable* inp) {
this->set_inputs({inp});
LAYER_PRE_INPUTS({inp});
Variable* ff1_out = nullptr;
Variable* ffn_ln_out = nullptr;
if (_pre_or_postLayerNorm) {
Expand All @@ -130,10 +68,10 @@ Variable* FeedForwardLayer<T1, T2>::operator()(Variable* inp) {

if (!_pre_or_postLayerNorm) {
Variable* ffn_ln_out = (*_ffn_ln)(ffn_dropout_residual, _ffn_nw, _ffn_nb);
this->set_outputs({ffn_ln_out});
LAYER_POST_OUTPUTS({ffn_ln_out});
return ffn_ln_out;
} else {
this->set_outputs({ffn_dropout_residual});
LAYER_POST_OUTPUTS({ffn_dropout_residual});
return ffn_dropout_residual;
}
}
Expand All @@ -156,7 +94,54 @@ void FeedForwardLayer<T1, T2>::before_forward(int batch_size, int seq_len) {
template <typename T1, typename T2>
void FeedForwardLayer<T1, T2>::before_backward() {}

// template class FeedForwardLayer<float, float>;
// template class FeedForwardLayer<__half, __half>;
// Training path: binds every parameter node of this layer to its slice of the
// flat parameter buffer (`para_ptr`) and of the flat gradient buffer
// (`grad_ptr`). Slices follow this order: inter_w, inter_b, output_w,
// output_b, ffn_nw, ffn_nb. Returns the total number of elements spanned.
template <typename T1, typename T2>
int FeedForwardLayer<T1, T2>::load_para_and_grad(const T1* para_ptr,
                                                 T2* grad_ptr) {
  int offset = 0;

  // Attach value and grad at the current offset (value first, then grad),
  // then step over `element_count` elements for the next node.
  auto bind = [&](auto* node, int element_count) {
    node->set_value((char*)(para_ptr + offset));
    node->set_grad((char*)(grad_ptr + offset));
    offset += element_count;
  };

  bind(_inter_w, _hidden_size * _intermediate_size);
  bind(_inter_b, _intermediate_size);
  bind(_output_w, _hidden_size * _intermediate_size);
  bind(_output_b, _hidden_size);
  bind(_ffn_nw, _hidden_size);
  bind(_ffn_nb, _hidden_size);

  return offset;
}

// Inference path: binds the six parameter nodes to consecutive entries of
// `para_vec`, starting at index `offset`, in the order ffn_nw, ffn_nb,
// inter_w, inter_b, output_w, output_b. `offset` is passed by value and is
// not modified for the caller; the return value is the number of entries
// consumed.
template <typename T1, typename T2>
int FeedForwardLayer<T1, T2>::load_params(
    const std::vector<const T1*>& para_vec, int offset) {
  int size = 0;

  // Fetch the next unread entry and count it as consumed.
  auto take_next = [&]() { return (char*)para_vec[offset + size++]; };

  _ffn_nw->set_value(take_next());
  _ffn_nb->set_value(take_next());

  _inter_w->set_value(take_next());
  _inter_b->set_value(take_next());

  _output_w->set_value(take_next());
  _output_b->set_value(take_next());

  return size;
}

} // namespace lightseq
39 changes: 6 additions & 33 deletions lightseq/csrc/layers_new/includes/feed_forward_layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,6 @@

namespace lightseq {

// Holds raw pointers to the parameters and gradients of one feed-forward
// layer, together with the two dimensions used to compute their sizes.
// The class declares no destructor, so the pointed-to buffers are not
// released here.
class FeedForwardLayerWeight {
 public:
  // Stores the layer dimensions. Every pointer member starts out as nullptr
  // (via the default member initializers below) until one of the load
  // functions assigns it, so an unloaded weight can be detected instead of
  // holding indeterminate addresses.
  FeedForwardLayerWeight(int hidden_size, int intermediate_size)
      : _hidden_size(hidden_size), _intermediate_size(intermediate_size) {}

  // Parameter pointers.
  char* _inter_w_ptr = nullptr;
  char* _inter_b_ptr = nullptr;
  char* _output_w_ptr = nullptr;
  char* _output_b_ptr = nullptr;
  char* _ffn_nw_ptr = nullptr;
  char* _ffn_nb_ptr = nullptr;

  // Gradient pointers, one per parameter above.
  char* _grad_inter_w_ptr = nullptr;
  char* _grad_inter_b_ptr = nullptr;
  char* _grad_output_w_ptr = nullptr;
  char* _grad_output_b_ptr = nullptr;
  char* _grad_ffn_nw_ptr = nullptr;
  char* _grad_ffn_nb_ptr = nullptr;

  int _hidden_size;
  int _intermediate_size;

  // Sets parameter and gradient pointers from two flat buffers; returns an
  // element count (defined out of line).
  template <class T1, class T2>
  int load_para_and_grad(const T1* para_ptr, T2* grad_ptr);

  // Sets parameter pointers from entries of `para_vec` starting at `offset`;
  // `offset` is an in/out argument (defined out of line).
  template <typename T>
  void load_params(const std::vector<const T*>& para_vec, int& offset);
};

using FeedForwardLayerWeightPtr = std::shared_ptr<FeedForwardLayerWeight>;

template <class T1, class T2>
class FeedForwardLayer : public Layer {
private:
Expand Down Expand Up @@ -71,9 +41,8 @@ class FeedForwardLayer : public Layer {
bool _is_post_ln;

public:
FeedForwardLayer(FeedForwardLayerWeightPtr ffn_wt, int layer_id,
int max_batch_tokens, int max_seq_len, int hidden_size,
int num_heads, int intermediate_size,
FeedForwardLayer(int layer_id, int max_batch_tokens, int max_seq_len,
int hidden_size, int num_heads, int intermediate_size,
float activation_dropout_ratio,
float hidden_output_dropout_ratio, bool pre_or_postLayerNorm,
std::string activation_fn, bool is_post_ln = false);
Expand All @@ -85,6 +54,10 @@ class FeedForwardLayer : public Layer {
void before_forward(int batch_size, int seq_len);

void before_backward();

int load_para_and_grad(const T1* para_ptr, T2* grad_ptr);

int load_params(const std::vector<const T1*>& para_vec, int offset);
};

template class FeedForwardLayer<__half, __half>;
Expand Down
Loading

0 comments on commit e234497

Please sign in to comment.