-
Notifications
You must be signed in to change notification settings - Fork 760
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimize LayerNorm Forward #6842
Conversation
Speed stats:
|
int thread_group_width, int rows_per_access, bool padding> | ||
__global__ void LayerNormWarpImpl(LOAD load, STORE store, const int64_t rows, const int64_t cols, | ||
const double epsilon, ComputeType* mean, | ||
ComputeType* inv_variance) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cols<=1024时的主要逻辑
__global__ void LayerNormBlockSMemImpl(LOAD load, STORE store, const int64_t rows, | ||
const int64_t cols, const double epsilon, ComputeType* mean, | ||
ComputeType* inv_variance) { | ||
extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cols>1024时,使用Shared Memory缓存输入
template<typename LOAD, typename STORE, typename ComputeType, int pack_size, int block_size> | ||
__global__ void LayerNormBlockUncachedImpl(LOAD load, STORE store, const int64_t rows, | ||
const int64_t cols, const double epsilon, | ||
ComputeType* mean, ComputeType* inv_variance) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
无法Launch Shared Memory Kernel,重复读取输入
template<typename LOAD, typename STORE, typename ComputeType> | ||
inline cudaError_t DispatchLayerNorm(cudaStream_t stream, LOAD load, STORE store, | ||
const int64_t rows, const int64_t cols, const double epsilon, | ||
ComputeType* mean, ComputeType* inv_variance) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
前向调用入口函数
return false; | ||
} | ||
} | ||
DST* normalized; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
normalized将在后面的pr优化掉。
if (ctx->has_input("gamma", 0)) { | ||
const user_op::Tensor* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0); | ||
gamma_ptr = gamma->dptr<T>(); | ||
CHECK_EQ(gamma->shape().elem_cnt(), norm_size); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不再考虑begin_norm_axis和begin_param_axis不等的情况,后续pr合并这两个参数。
mean/invariance的长度一定是num_instances, gamma/beta的elem_cnt一定是norm_size
Speed stats:
|
@autotest(n=20, auto_backward=True, rtol=1e-3, atol=1e-3) | ||
def test_layernorm_with_random_data(test_case): | ||
device = random_device() | ||
channel = random(1, 6).to(int) | ||
height = random(1, 6).to(int) | ||
width = random(1, 6).to(int) | ||
height = random(1, 60).to(int) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉测试样例子可以多写几个,分别覆盖你上面那几个情况
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果要确定覆盖,是不是写成height = random(32, 33)这样,例如保证它生成的是32,然后多写几个这个test函数?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以写成 oneof(a, b, c, ...)
的形式,它会以均匀的概率返回函数参数之一,例如 oneof(32, 64)
|
||
template<typename SRC, typename DST> | ||
struct DirectStore { | ||
DirectStore(DST* dst, int64_t row_size) : dst(dst), row_size(row_size) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
int32_t row_size
能不能减少寄存器需求
|
||
template<typename T, int thread_group_width = kWarpSize> | ||
__inline__ __device__ void WelfordWarpAllReduce(T thread_mean, T thread_m2, T thread_count, T* mean, | ||
T* m2, T* count) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
因为不能确认Welford算法a+b和b+a一定相等,WelfordWarpAllReduce不再使用__shfl_xor_sync
,而是改用Reduce+broadcast实现。
device = random_device() | ||
channel = random(1, 2).to(int) | ||
height = random(1, 2).to(int) | ||
width = random(1, 1024).to(int) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测试Warp实现的情况
device = random_device() | ||
channel = random(1, 2).to(int) | ||
height = random(1, 2).to(int) | ||
width = random(1024, 8192).to(int) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测试SharedMemory的情况
width = random(1, 6).to(int) | ||
channel = random(1, 2).to(int) | ||
height = random(1, 2).to(int) | ||
width = random(8192, 32768).to(int) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测试Uncached的情况
Speed stats:
|
…nc/oneflow into dev_layer_norm_forward
Speed stats:
|
* Rename class OneflowVM to VirtualMachine (#6753) * Rename class OneflowVM to VirtualMachine * refine * refine * refine Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * upgrade cub to 1.11.0 for NVIDIA/cub#170 (#6795) Signed-off-by: daquexian <daquexian566@gmail.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * lazy create cuda_stream (#6806) * lazy create cuda_stream in CudaCopyD2HDeviceCtx CudaStreamHandleDeviceCtx * refine * Remove KernelContext::stream_ctx() (#6805) * Remove KernelContext::stream_ctx() * fix GetCudaAlignedSize * refine * Remove StreamContextAdapter * refine include Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * add tensor method docstr (#6800) * add tensor method docstr * add tensor method docstr * add tensor method docstr * add tensor method docstr * add tensor method docstr * add tensor method docstr * add tensor method docstr * fix ci related bug * set common compiler flags in oneflow_add_library(...), enable it for CUDA (#6813) * apply treating warnings as errors in oneflow_add_library(...), enable it to CUDA Signed-off-by: daquexian <daquexian566@gmail.com> * support target_try_compile_options on clang cuda Signed-off-by: daquexian <daquexian566@gmail.com> * reorder oneflow_add_library Signed-off-by: daquexian <daquexian566@gmail.com> * add cuda-61-clang.cmake and cuda-75-clang.cmake Signed-off-by: daquexian <daquexian566@gmail.com> * move oneflow_add_xxx after set_compile_options_to_oneflow_target Signed-off-by: daquexian <daquexian566@gmail.com> * reformat Signed-off-by: daquexian <daquexian566@gmail.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Use ep::Stream instead of DeviceCtx (#6825) * remove redundant code (#6807) * Prevent CI failure when cublas alloc fail (#6826) * Dev nms (#6817) * fix typo * dev nms * fix * fix * fix format * skip distribute test Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Refactor vm consuming (#6748) * refactor PhyInstrOperand::ForEachXXXMirroredObject * remove ForEachXXXMirroredObject4XXXPhyInstrOperand * reduce for-loops for InstructionList * reduce for-loops on InstructionMsgList * refactor MakeInstructions * refactor PhyInstrOperand::ForEachXXXMirroredObject * 1) refactor ConnectInstruction to TryConnectInstruction; 2) refactor BackInserter to SetInserter * create RwMutextObjectAccess/InstructionEdge from intrusive::ObjectPool * refactor profiler range name * fix barrier instruction comment typos * fix compiler complaints * Update oneflow/core/intrusive/object_pool.h Co-authored-by: daquexian <daquexian566@gmail.com> * fix static analysis complaints Co-authored-by: daquexian <daquexian566@gmail.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add init method docstr modify int to int32 (#6828) * Add nn.init method docstr, and modify np.int * Add nn.init method docstr, and modify np.int * Check whether the expand_shape parameter is legal (#6812) * check parameters * simplify logic * fix ci error Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * refactor local call opkernel instruction (#6733) * remove CheckOutputBlobObjectsMemCase * move calling of ChooseOpKernel from scheduler thread to main thread. * address pr comments Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * one_hot primitive interface (#6796) * one_hot primitive interface * refine * refine Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * revert DependenceVector to std::vector (#6835) * fix indexed slice for adam max_x (#6824) Co-authored-by: ZZK <42901638+MARD1NO@users.noreply.github.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * set str option (#6832) * set str option * refine * refine * fix * refine * fix * fix * refine * refine * refine * fix Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Empty op support float16 (#6847) * support fp16 * add float16 test case * add graph cudnn conv alg config (#6799) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Dev vm view instruction (#6815) * shallow copy * try reset blob data * refine * debug * raw implementation * refine * refine * to_contiguous op * reine * refine * refine * set_last_used_device * refine * raw implementation * debug * replace TryResetBlobData with SyncAccessBlobByCallback * tensor_view_instruction * refine * tensor_view_operand * remove tensor_view_phy_instr_operand * refine * refine * refine * restruct * refine * refine * refine * refine * Remove deafult l2 and use bias add in lazy mode (#6844) * remove_deafult_l2_and_use_bias_add_in_lazy_mode * minor fix * minor fix * undo bais add Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add arccos op and docstr (#6841) * Add arccos op and docstr * fix docstr format Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * add some fused kernels (#6635) * fix errors, op with dropout successes, but op without dropout has error * fix errors, success * fix typo error * test dropout * add comments * fix typos * change format * reformat file * fix error * change format * remove useless head file * fix errors * fix errors * reformat * fix errors * reformat * fix errors Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: ZZK <42901638+MARD1NO@users.noreply.github.com> * Add CUDA arch 52 back and compile it in CI (#6802) * Add CUDA arch 53 back and compile it in CI * fix cuda * fix * don't build 52 by default * rm comment Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * [EP] Add ep::Device/ep::Event (#6851) * [EP] Add ep::Device/ep::Event * Refine ActiveDeviceContext * fix * refine include * fix tidy error * fix cudaEventRecord * fix test * refine * Fix FuseBN eval error (#6836) * fix arange bug * fix fuse bn * Remove redundant saved_tensor * fix bug * add more test case * add more random test case * add fuse functor when track_stats=false * fix backward errror when track_stats=false Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Remove KernelXxxContext::device_ctx()/device_ctx() (#6862) * pool code refine (#6853) * pool code refine * refine * format * fix static analysis error * fix max_pool_2d_grad name * prefix tf is used to pool functor name * fix Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add cpu group conv impl (#6823) * add cpu group conv kernel, test success * add group conv cpu backward kernel * rename * update test case * fix comment * fix comments * fix comment * optimize again and fix ci eroor * fix error * fix ci error * fix ci_tidy error * fix ci error * revert code * fix bug * delete useless file * delete useless file * fix ci error Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add nsys profile host thread name (#6865) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Rename DeviceType::kGPU to DeviceType::kCUDA (#6863) * Rename DeviceType::kGPU to DeviceType::kCUDA * fix * fix typo Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Check modify op module (#6860) * Add arccos op and docstr * Check and modify Op module * delete register_tensor_op * Fix random ops (#6868) Co-authored-by: Bowen Chen <bob2420083992@gmail.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * fix docstr problem (#6554) * fix docstr problem * fix * Update random.py Co-authored-by: Yao Chi <later@usopp.net> * fix retinanet (#6870) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Optimize LayerNorm Forward (#6842) * layer_norm forward * test case * rm useless * int count to T count * fix * fix T mask to int mask, refine code * refine * refine * test case * format * fix Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Refactor last used device (#6852) * move last_used_device * refine * refine * fix pipeline delay ctrl edge between src subset tick and output (#6881) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Support OpenVINO in xrt (#6709) * Support openvino in xrt * OpenVINO: add graph input and weight in op * OpenVINO: support more op * update follow review * update follow review * update follow review * Add doc for graph_config.py * update follow review * update follow review * modify after review * format * add xrt in check_src.py Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Dev optimize std vector (#6630) * use reserve * use emplace_back * refine * remove useless codes (#6859) * remove useless codes * fix index_select * fix expand error * fix expand error Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * add alpha parameter in add_op (#6867) * add alpha parameter in add_op * format * refine * refine * refine * fix bug about dtype caused by alpha Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * add fallback to cpu builder (#6582) * add fallback to cpu slice boxing * fix * fix * merge master * format * fix * modify graph.py (#6884) Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org> * fix eye op attr name error (#6873) * fix eye op attr name error * refine * refine * fix * delete useless attr Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * add inplace mul (#6861) * init commit for inplace mul * fix issue, format code * add tests and fix issues * format code * delete redundant code * Update oneflow/core/functional/impl/binary_functor.cpp Co-authored-by: Yinggang Wang <wyg19970408@gmail.com> * refine * fix unit test * fix bug * refine * fix unittests * add boardcast test * refine * refine * fix ci issue Co-authored-by: Yinggang Wang <wyg19970408@gmail.com> * Dev roialign (#6879) * dev roialign * testcase * fux * fix Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Pick Variant from Standalone Maybe (#6856) * refactor maybe: add variant * maybe: add optional and tests * maybe: add hash for optional & variant; support NullOpt for both optional & variant * maybe: more notes * maybe: binary search impl for Variant::Visit * maybe: add more relational operator to optional & variant * maybe: add nonstd::string_view * maybe: fix construct of optional & variant * maybe: support comparision for optional & variant * maybe: add monadic operations for optional * maybe: add error traits * maybe: add JUST and Maybe * maybe: remove useless comment * maybe: add more test * maybe: customizable JUST * maybe: add Map and Bind (AndThen) to Maybe * maybe: re-design JustConfig * maybe: rename xxxT to xxxS * maybe: fix method names * maybe: add maybe to cmake * maybe: fix error traits * maybe: rename fields & add aggregate type checking * maybe: move string_view to new file * maybe: rename fields for optional and error * maybe: new Value (no index checking, protected method) and Get (has check, public method) * maybe: remove DefaultArgument * Pick Variant from Standalone Maybe Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Limit CI run speed test on one machine (#6891) * Run speed test on one machine * fix * Add oneDNN (#6767) * add onednn cmake * add onednn stream engine * Successfully implement addn * add int64 double * optimization voctor * fix * fix merge master error * fix merge master * fix merge error * Add BUILD_ONEDNN cmake flags * fix format * fix onednn datatype * optmizer onedn type * modified for(n) => for(i) * modified ci * modified oneDNN.cmake * fix clang 10 error * rename BUILD_ONEDNN * Delete oneDNN installation path by mistake * fix ci error, c++: error: third_party_install/onednn/lib/libdnnl.a: No such file or directory * include(GNUInstallDirs) * print ci error * reformat * Only the first parameter can be operated inplace * format * fix inlcude onednn, add clang 11 support refernce Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Rework op import with new ods (#6883) * add naive impl * refine * refine * refine * add naive gen td * refine * refine * fix * refine * refine * refine * sort alphabetically * support optional * support Variadic * refine * refine * add conv * add input output order * add todo * add todo * refine * refine * refine * refine * naive bn order interface * fix includes * refine * refine * refine * refine * group ops * refine order * add math * refine * refine * refine * refine * refine * refine * add quantization ops * refine * refine * add detection * fix * refien * add new .td generated * refine * refine * refine * refine * refine * refine * refine * refine * Use generated ods in mlir (#6857) * refine * check in changes * refine * move pattern to another file * compile grouped op * refine * add todo * fix * add GetUserOpDef in wrapper * check in files * refine * refine * fix * refine * refine * refine * refine * refine * refine * refine tablegen * refine * fix * refine * refine * refine * refine * refine * refine * refine * refien * refine * refine * refine * refine * fix * refine * rm log * refine * refine * make ctrl edge type safe * refine * refine * refine * rm legacy code * refine * refien * refine * dirty trick addn2 without variadic deduction * fix jit op * refine * extract GetOutputLbn * refine * fix for single seg * refine * rm todo * update .mlir file * refine * add todo * refine * refine * refine * refine * refine * add log * refine * refine * make op_type_name type safe * refine * refine * refine * delete trainable * add IsOpConfCompatible * add IsImportCompatible * refine * refien * mv ir_pass.cpp out of core * refine * refine * refine * refine * refine * refine * gen new ods from master * refine * refine * update for tf pool ops * refine * refine * refine * refine APIs * refine order * rm * rm output_lbn_segment_keys * output_lbn_segment_sizes * rm output_lbns * refine * refine * refine * refine * fmt * use less cores to prevent OOM in CI * refine * refine Co-authored-by: BBuf <1182563586@qq.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * add cudnn.h (#6886) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: Shenghang Tsai <jackalcooper@gmail.com> * refine * refactor jit interpreter with updated ODS * refine Co-authored-by: Yu OuYang <xuanjiuye@gmail.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: daquexian <daquexian566@gmail.com> Co-authored-by: guo ran <360112263@qq.com> Co-authored-by: Juncheng <liujuncheng1022@gmail.com> Co-authored-by: Li Xiang <54010254+lixiang007666@users.noreply.github.com> Co-authored-by: dssgsra <dssgsra@gmail.com> Co-authored-by: Shijie <821898965@qq.com> Co-authored-by: Li Xinqi <lixinqi2010@gmail.com> Co-authored-by: Liang Depeng <liangdepeng@gmail.com> Co-authored-by: ZZK <42901638+MARD1NO@users.noreply.github.com> Co-authored-by: liufengwei0103 <2472937968@qq.com> Co-authored-by: Luyang <flowingsun007@163.com> Co-authored-by: Xiaoyu Xu <xiaoyulink@gmail.com> Co-authored-by: binbinHan <han_binbin@163.com> Co-authored-by: DangKai <dangkai4u@outlook.com> Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Co-authored-by: cheng cheng <472491134@qq.com> Co-authored-by: Houjiang Chen <chenhoujiangcug@gmail.com> Co-authored-by: Bowen Chen <bob2420083992@gmail.com> Co-authored-by: Derek Zhang <85550485+HENGRui6@users.noreply.github.com> Co-authored-by: Yao Chi <later@usopp.net> Co-authored-by: tingkuanpei <50049308+tingkuanpei@users.noreply.github.com> Co-authored-by: grybd <52237830+grybd@users.noreply.github.com> Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org> Co-authored-by: Zhanghuihong Guan <31779698+Garfieldgzhh@users.noreply.github.com> Co-authored-by: Yinggang Wang <wyg19970408@gmail.com> Co-authored-by: Twice <i@twice.moe> Co-authored-by: luqiang guo <702572275@qq.com> Co-authored-by: BBuf <1182563586@qq.com>
优化LayerNorm分几个pr,本pr为第一个pr:
1、优化前向实现:采用welford算法计算方差,及根据cols大小采用分段函数实现以达到最优性能,替代现有的naive实现。welford算法参考自:https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
2、优化后向Grad实现:根据cols大小采用分段函数实现以达到最优性能,替代现有的调用cudnn的方法。
3、优化后向ParamGrad实现:借助[32, 33]大小的Shared Memory块存储中间结果,借助gemm做第二次reduce。
4、优化输入输出,删除掉前向的输出normalize、后向输入normalize_diff,合并参数begin_norm_axis和begin_param_axis。合并LayerNormGrad和LayerNormParamGrad op。
5、实现LayerNorm primitive,补充cpu实现。