From 322a36b814ec05bf555bcfa75d108c3f4b8cf27e Mon Sep 17 00:00:00 2001 From: Houjiang Chen Date: Tue, 4 Jan 2022 21:57:29 +0800 Subject: [PATCH] Fix python apis and xla implementation (#7183) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Support save/load for lr_scheduler (#6948) * feat(LrScheduler): support save/load for lr_scheduler * refine document * auto format by CI * Refine test * auto format by CI Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: oneflow-ci-bot * Fix eye_op attr (#6973) * fix * add graph test * Update python/oneflow/test/graph/test_graph_eye.py Co-authored-by: daquexian * refine * Update python/oneflow/test/graph/test_graph_eye.py Co-authored-by: daquexian * auto format by CI Co-authored-by: daquexian Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: oneflow-ci-bot * softmax double use uncached impl to accelerate compile (#6992) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add [[nodiscard]] for cpp api (#6997) * add [[nodiscard]] * refine * reformat Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Support Arange delta to decide dtype (#6998) * support delta dtype to decide output dtype * add more unittest Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add clang as CUDA FE compiler in CI (#6954) * update action use * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * fix * add 80 and 86 * refine * refine * add CUDA_NVCC_THREADS_NUMBER * refine * address review * set CUDA_NVCC_THREADS_NUMBER 8 * fix * fix clang in init cmake * add script * refine * refine * refine * refine * refine * refien * refine * add flags to skip zlib * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * Migrate chunk python layer to functor (#6983) * Migrate chunk Python layer logic to functor * fix runtime * Fix splits bug and CI * Modify push to emplace Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Reduce memory usage when compiling oneflow dialect ops (#7000) * CudaAllocator device reset before OOM (#6976) * CudaAllocator device reset before OOM * Add NOTE Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Refactor vm stream desc (#6989) * remove StreamDesc::num_machines * Prepare one thread for one stream_type Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add Diagonal Op (#6016) * format complete * python to cpp * py2cpp error * rm * auto format by CI * revise * auto format by CI * license * docstring * docstring * tensor * tensor attribute * auto format by CI * docstring * revise * test * revise * revise * rename * half * docs * doc,test * test times * revise * format Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * add all to all op (#6283) * add all to all op * add barrier * format * add import * fix test * delete barrier * delete barrier * Revert "delete barrier" This reverts commit aa397ea5ba815fe6df883b263b82735f126345c8. * Revert "delete barrier" This reverts commit 7ddf79afaa7ac072813e84ce9224440939a3f95c. 
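A note on the lr_scheduler save/load entry above (#6948): the natural usage, assuming the scheduler follows the same state_dict()/load_state_dict() convention as the optimizers (the exact method names are an assumption here, not confirmed by this message), would look roughly like this:

    import oneflow as flow

    model = flow.nn.Linear(4, 4)
    optimizer = flow.optim.SGD(model.parameters(), lr=0.1)
    scheduler = flow.optim.lr_scheduler.StepLR(optimizer, step_size=10)

    # ... train for a while, then snapshot the scheduler's progress
    state = scheduler.state_dict()

    # later: rebuild an identical scheduler and restore its progress
    resumed = flow.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    resumed.load_state_dict(state)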
* check tensor meta between ranks * add more assert * all_reduce operate in place * all_reduce operate in place * fix bug * assert tensor.is_local * fix bug in scatter * add more assert * delete meta check * add pytorch comparison test * add pytorch comparison test * refine * add ONEFLOW_TEST_CPU_ONLY * fix bug from torch gloo Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Dev ivalue for cpp api (#6890) * add api tensor * refine * add nn.relu * refine * clean shape & refine relu test * support void* for from_blob * add multithreading relu test * refine test * refine * refine * add comment for __internal_tensor() * convert to copy_util * reformat * refine * add ivalue * refine directory structure * refine cpp api test * refine test * add ivalue * refine ivalue * refine ivalue * refine * refine * refine * refine Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * default use cpu generator (#7001) * optimize reshape/slice/transpose functor (#6956) * optimize reshape/slice/transpose functor * update code according to reviewer's suggestion * judge negative dimension number besides -1 * judge negative shape value in view::Reshape * remove is_full_slice logic in SliceFunctor * update code according to yinggang's advice * move ordered permute judge to TransposeKernel * remove print sentence * abstract IsOrderedPermute func * support negative permute value in TransposeFunctor * delete tranpose_kernel optimization * Revert "delete tranpose_kernel optimization" This reverts commit e026434dc7c1ebad948c76bde475540e3bf4477a. * not return original tensor when reshape do nothing * simplify code * correct spell error Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * fix IsContinuosSubspace error (#6968) * fix IsContinuosSubspace error * recover original IsContinuosSubspace code * add test case * auto format by CI Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * add cpu group deconve impl (#6980) * add cpu group deconv impl * remove useless lines * remove useless lines * add deconv2d import * add groups test * remove check_allclose=False * add tf_prelu * add cpu group deconv impl * remove useless lines * remove useless lines * add deconv2d * add groups test * remove check_allclose=False * add tf_prelu * auto format by CI * add deconv2d impl * add deconv2d impl * remove useless lines * add deconv2d in functional api * auto format by CI * auto format by CI * Add variable initial * Add variable initial * auto format by CI * add conv2d impl * add conv2d impl * auto format by CI * remove useless lines Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: oneflow-ci-bot * Migrate the python layer logic of broadcastlike to functor (#7007) * Migrate the python layer logic of broadcastlike to functor * add var name Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Temporarily skip comm test cases (#7015) * Temporarily skip comm test cases * auto format by CI Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Fix nd_sbp attribute type and set nd_sbp in random functors (#7017) * fix * fix compile Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Save Job to IR and load Job from IR (#6885) * save to ir * test * fix bugs * impl load and test * rm useless code * fix conflict * fix 
issues * JobOp * fix issues * fix test_fuse_tril_scale * fix test jit-outline-func * fix test_mlir_opt.py * save * fix ods gen for max and avg pool * rename oneflow to oneflow_foundation * fix files checks * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * auto format by CI * check in changes * refine * Update oneflow/ir/test/OneFlow/test_mlir_opt.py * Update oneflow/ir/include/OneFlow/OneFlowOps.td * refine includes * printer & parser & verifier * code tidy * tidy include * address review * rm duplicated GetDataTypeType * TensorSource trait Co-authored-by: jackalcooper Co-authored-by: oneflow-ci-bot * Fix Simple CI linkage (#6986) * fix-simple-ci-linkage * refine * refine * fix * refine * refine * refine * refine * refien * refine * revert * refine * auto format by CI * refine * revert * refine Co-authored-by: oneflow-ci-bot * fix sbp when weight is optional (#6984) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Feat from numpy (#7013) * feat(Tensor): support share memory with ndarray * test(FromNumpy): add test * enhancement test and add document * Fix merge error * fix bug in numpy c api * Fix(doctest): fix doctest error Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add custom ShapeAttr in ODS (#7023) * add ShapeAttr * refine * fix doc * refine * fix (#7028) * Add linspace op (#7006) * add linspace op * refine doc * refine * fix comments * fix comment * auto format by CI * fix ci doc error Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * fix fasterrcnn infer (#7014) * fix fasterrcnn infer * roi_align 0shape * refine * refine Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * separate kernel state and cache (#6655) * support eager state except lazy dynamic Signed-off-by: daquexian * modularize kernel contexts Signed-off-by: daquexian * fix warning Signed-off-by: daquexian * reformat Signed-off-by: daquexian * remove duplicated license Signed-off-by: daquexian * fix static check error Signed-off-by: daquexian * make test gpu only Signed-off-by: daquexian * temp Signed-off-by: daquexian * revert opkernel context changes, align with master Signed-off-by: daquexian * reformat Signed-off-by: daquexian * refine cachecontext Signed-off-by: daquexian * add separate cache context inferface, remove out-dated files Signed-off-by: daquexian * add init and cache context aliases Signed-off-by: daquexian * update eager kernel Signed-off-by: daquexian * fix wrong AttrMayChanged value Signed-off-by: daquexian * rename and add comment Signed-off-by: daquexian * auto format by CI * fix combined_margin_loss_kernel.cpp Signed-off-by: daquexian * rename op_kernel_state_wrapper.h to op_kernel_wrapper.h Signed-off-by: daquexian * rename more classes, fix old cache in stateful op kernel Signed-off-by: daquexian * rename more classes Signed-off-by: daquexian * may changed -> not changed Signed-off-by: daquexian * optimize away genrepeatedbn Signed-off-by: daquexian * reformat Signed-off-by: daquexian * refine Signed-off-by: daquexian * update stateful local opkernel, use Cache** if possible Signed-off-by: daquexian * remove TensorDesc4ArgNameAndIndex base method Signed-off-by: daquexian * auto format by CI * fix clang-tidy error Signed-off-by: daquexian * auto format by CI * fix conv kernel bug Signed-off-by: daquexian * auto format by CI * fix 
group conv bug and fix warning Signed-off-by: daquexian * fix avgpool error Signed-off-by: daquexian * fix maxpool error Signed-off-by: daquexian * auto format by CI * respect flag in deconv cpu kernel, rename cache to cache_ptr Signed-off-by: daquexian * fix compile error Signed-off-by: daquexian * auto format by CI * fix deconv cache bug Signed-off-by: daquexian Co-authored-by: oneflow-ci-bot Co-authored-by: Li Xinqi Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add fully support for all datatype (#7025) * add fully support for all datatype * Use max array size * add clang-format off to maintain the matrix * fix format * remove redundant numpy dtype Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Migrate split python layer to functor (#7030) * Migrate split python layer to functor * modify dim Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add add_sparse_optimizer for Graph (#6988) * add_sparse_optimizer * format * fix bug * refine new interface by discuss * auto format by CI * address review * correct syntax * correct error message * rm debug print * auto format by CI * fix cpu-only test Co-authored-by: XIE Xuan Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Refine RUN_CUDA_KERNEL (#7003) * Refine RUN_CUDA_KERNEL * Added LaunchConfig Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Support llvm in tree build (#6995) * refine * refine * refine * refine * add61 * refien * refine * refine * refine * refine * refien * refine * refine * refine * refine * refine * refine * refine * rm * revert * refine * refine * refine * refine * return_self_in_to_consistent_if_necessary (#7004) * return_self_in_to_consistent_if_necessary * fix error and add test case * skip cpu test * fix error Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Decouple ep and global (#7027) * Decouple ep and global * NOLINT * fix * fix import Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * arange doc fix (#7035) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * add_consistency_check_in_consistent_tensor_set_data (#7002) * add_consistency_check_in_consistent_tensor_set_data * auto format by CI * minor fix * add just wrap Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * [cmake] add liboneflow_cpp target (#7005) * add cmake changes for liboneflow_cpp.so Signed-off-by: daquexian * add separate target for cpp api test Signed-off-by: daquexian * add cpp api test in ci Signed-off-by: daquexian * reverse the order of cudnn and cuda library Signed-off-by: daquexian * update logic of BUILD_MONOLITHIC_LIBONEFLOW Signed-off-by: daquexian * rename BUILD_MONOLITHIC_LIBONEFLOW to BUILD_MONOLITHIC_LIBONEFLOW_CPP_SO Signed-off-by: daquexian * share lib directory in test container Signed-off-by: daquexian * add github actions debug Signed-off-by: daquexian * Revert "add github actions debug" This reverts commit 7d9aef684a479285c690f38d25525c9b97865e45. 
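The "Feat from numpy" entry earlier in this message (#7013) adds flow.from_numpy with memory sharing between the ndarray and the tensor. A small usage sketch; the exact aliasing semantics (e.g. for non-contiguous arrays, see the stride fix in #7042) are not spelled out here, so treat it as illustrative:

    import numpy as np
    import oneflow as flow

    arr = np.arange(6, dtype=np.float32)
    t = flow.from_numpy(arr)      # per #7013, the tensor shares memory with arr

    arr[0] = 100.0                # a write through the ndarray ...
    print(t[0])                   # ... should be visible through the tensor as well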
* add upterm debug after exe test Signed-off-by: daquexian * sleep after fail Signed-off-by: daquexian * set LD_LIBRARY_PATH in yml for cpp api test exe Signed-off-by: daquexian * sleep Signed-off-by: daquexian * upload liboneflow_cpp.so Signed-off-by: daquexian * modify cmake to trigger compilation Signed-off-by: daquexian * remove sleep Signed-off-by: daquexian * build cpp api in cpu mode Signed-off-by: daquexian Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Fix CUDA 52 and add it to CI (#7031) * refine * refine * refine * refine * revert * fix * refine * refine * refine Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add check of placement constructor (#6991) * add_check_of_placement_constructor * move CheckDeviceIdsIsValid to runtime * handle comment * fix error * fix error Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Fix(FromNumpy): fix bug in stride (#7042) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * add non virtual destructor back (#6999) Signed-off-by: daquexian Co-authored-by: Houjiang Chen * move python code to cpp: eye (#7036) * 80% Sbp signature left to finish * refine functional_api.yaml * 90% docstr left to update * refine * add sbp check * refine docs * auto format by CI * refine * refine docstr * auto format by CI Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Fix l2norm block_size (#7044) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * fix undefined symbol: cudaGetDeviceCount (#7052) * fix_worker_orphan_process (#7048) * fix_worker_orphan_process * use SIGTERM instead * broadcast elemwise binary (#6871) * add * broadcast elementwise binary * fix * refine * fix * refine * refine * for compile * refine * refine * refine * refine * refine * revert kernels * revert kernel * refine * refine * refine * refine * nvcc thread to 4 Co-authored-by: ZZK <42901638+MARD1NO@users.noreply.github.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Source op per critical section (#6472) * backup code * EventRecord * auto format by CI * backup code * remove deprecated binary test cases * refactor valatile to atomic * add StreamType::InitInstructionStatusIf/StreamType::DeleteInstructionStatusIf * merge from branch profiling_nn_graph * address comments * EventRecordProvider * more comments for XXXStatusQuerier::SetLaunched * more comments for SharedEventRecord::Init * wait source op per critical section * rename a task_node.cpp * minor fix * backup code * fix compiler complaints * 1) remove AddCtrlEdgeBetweenSrcDstTickAndInputOutputInSameRank; 2) create CriticalSectionInstance buffers * fix compiler complaints * more profiler code * refactor vm preschedule * TryMoveFromWaitingToReady * revert flying_instruction_cnt * revert to single position to call DispatchInstruction * revert several code * reset instruction watermark * remove is_xxx_hook_empty * build with profiler * merge master * insert device ticks before and after critical sections * refactor register_num of cs_wait/cs_callback from 2 to 128 * fix static analysis complaints * fix complier complaints about JobBuilder::ParallelConf4OpName * Update oneflow/core/operator/critical_section_wait_tick_op.cpp Co-authored-by: daquexian * address pr comments * add job example for InstructionsBuilder::LaunchLazyJob * address pr comments Co-authored-by: 
oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: ouyangyu Co-authored-by: daquexian * More details of error of getting op matched sbp signature (#7077) * more details of error msg * minor change * address review comment * avoid namesake iterator * Module apply only once (#7055) * add once apply of param * apply once on buffer * test reuse var on module to * test resue var * rm useless test * finish test * refine test Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * distributed test bugfix (#7057) * change spawn_shell to spawn_shell_and_check, sleep in script Signed-off-by: daquexian * fix distributed test master addr Signed-off-by: daquexian * remove sleep Signed-off-by: daquexian * spawn_shell -> spawn_shell_ignoring_failure Signed-off-by: daquexian * auto format by CI * fix bug Signed-off-by: daquexian * auto format by CI * fix the reversed logic Signed-off-by: daquexian * improve error msg Signed-off-by: daquexian * resolve name conflict of MASTER_ADDR Signed-off-by: daquexian Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * fix promote_type matrix (#7066) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * fix chunk op dim=-1 bug (#7073) * fix chunk op dim=-1 bug * Update oneflow/core/functional/impl/array_functor.cpp Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> * Update oneflow/core/functional/impl/array_functor.cpp Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Fix resource desc dump cudnn conf bug (#7038) * fix Resource::DumpCudnnConf * fix typo and error msg Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * fix concat bug (#7075) * fix * support concat single input * Clean TensorNameScope after graph build (#7076) * Clear tensor name scope after graph build * Add test case of 2 graph caught same free eager tensor * auto format by CI Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * fix_abnormal_printing (#7099) * Fix bias add dropout fuse (#7081) * fix bias_add dropout fuse when p=0.0 * remove redundant op Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Support 1d to 2d eager boxing (#7083) * fix Resource::DumpCudnnConf * support_1d_to_2d_eager_boxing * rename stack to unflatten * add test case * of format * refine test case * Revert "fix Resource::DumpCudnnConf" This reverts commit f07278d71e3f344f435fc8f116a12cbd1c099b54. 
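Regarding the chunk fix above (#7073), a minimal check of the intended behavior with a negative dim (the shapes here are illustrative only):

    import oneflow as flow

    x = flow.randn(2, 3, 8)
    pieces = flow.chunk(x, chunks=4, dim=-1)    # dim=-1 should now resolve to the last axis
    print([tuple(p.shape) for p in pieces])     # expected: four chunks of shape (2, 3, 2)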
* support nd to 1d * add 2d to 1d test case Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Implement all User Ops with Op Schema (#7032) * add oneflow-tblgen: generate op schema (OpInterpCtx) from ods * cmake: add inja * tblgen: add oneflow_datatype * tblgen: use option cat * tblgen: fix error * tblgen: put impl in .cpp * tblgen: fix null attrs * tblgen: fix null ops * refine * refine * reifne * Refine op schema template and compilation * add base OpInterpCtx to finish compilation * fix * refine * fix * add custom infer code * generate op registrants automatically * refine * fix * update user op ods and fix shape attr * refine * refine * add custom code in op base * refine comments * add same_output_regst_num and infer * support declare hasxx * update op schema emitter * refine * emit output regist num * refine * refine * migrate acc op * migrate onerec_reader, ones_like, send, pack and padding ops * add has_sbp_signature_infer_fn * refine * migrate pad, parallel_cast, partial_fc and pooling ops * rm redundant has_device_infer_fn * migrate prelu, quantization, randperm, reduce and repeat ops * migrate reshape, reshape_like, roi_align, same_pad, selu and scalar related ops * back port * backport * migrate ops * refine * refine * refine * refine * add new op * fix llvm not found * fix mlir headers * fix mlir headers * fix llvm not found * irefine * mark override * fix merge * fix * fix * set op schema as obj lib to speed up * rewrite ops * add addn * add grdi * refien * add more def (#7051) * affine grid * refien * refine * refine * refine * fix * refien * refine * refine * refine * refine * refine * refien * refine * refine * refein * refine * refine * refine * refine * refien * refine * refine * refine * refien * refien * refien * refine * refine * refien * refine * refine * refine * refein * refine * refine * refine * refine * refine * refien * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine * refein * refine * refine * refine * move more ops * fix math_binary_broadcast/elementwise_ops * fix hardtanh * add norm * rename file and add CpuOnly no_grad * fix ir & fix norm op * fix oneflow-tblgen * fix math_unary_elementwise_op * fix norm * fix bn * fix op schema * refine * fix * refine physical_tensor_desc_infer_fn * refine * add ScalarLogicalNotEqualOp & RecvOp * refine * auto format by CI * fix fmt * add cuda only trait * delete unused inja * del inja_copy_headers_to_destination * delete unused inja * del inja_copy_headers_to_destination * add cuda only to tblgen * fix json inja url and md5 not used * fix json inja url and md5 not used * refine * revert * add with cuda * refine * delete GenUserOpODS * remove cuda only * revert cuda only after meeting * fix Co-authored-by: PragmaTwice Co-authored-by: hjchen2 Co-authored-by: oneflow-ci-bot * Feat/debug pass (#7054) * add pass debug * debug pass * refine comment of fuse add pass * auto format by CI Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Fix error message (#6930) * fix error message * fix dot doc * fix dot elem cnt * auto format by CI Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * fix simple ci: add of_op_schema target to tidy check (#7105) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Rename AnyType in .td (#7109) * AnyType => 
Tensor * refine * refine Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Feat graph reuse var (#7080) * add once apply of param * apply once on buffer * test reuse var on module to * test resue var * rm useless test * finish test * refine test * Clear tensor name scope after graph build * Add test case of 2 graph caught same free eager tensor * auto format by CI * refactor var build draft * add full func; add check * done * add test of call parameter ousite its moudule * fix break test Co-authored-by: chengtbf <472491134@qq.com> Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Fix l2_normalize & add nn.functional.normalize (#6940) * fix l2_normalize * add normalize * add test for normalize * refine * clean l2_normalize and refine normalize * simplify normalize test * Fix l2norm block_size * refine Co-authored-by: Juncheng Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Align api in swin transformer (#7058) * add linspace op * fix align error in swintransformer * add @ magic method * fix conflict * support tensor list * fix meshgrid bug * revert Co-authored-by: hjchen2 * set CMAKE_LINK_DEPENDS_NO_SHARED to ON (#7063) Signed-off-by: daquexian Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add other api graph autotest (#7091) * Clear tensor name scope after graph build * Add test case of 2 graph caught same free eager tensor * auto format by CI * add other api graph autotest * add more samples * fix comments * refine * refine * refine * refine * refine * fix error * fix test error * fix bug * fix flip bug * fix bug * fix bug * fix ci bug * fix ci error * fix bug * fix ci error Co-authored-by: chengtbf <472491134@qq.com> Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: Li Xiang <54010254+lixiang007666@users.noreply.github.com> * [serving] dev graph run (#7008) * add cmake changes for liboneflow_cpp.so Signed-off-by: daquexian * add separate target for cpp api test Signed-off-by: daquexian * add cpp api test in ci Signed-off-by: daquexian * graph run * reverse the order of cudnn and cuda library Signed-off-by: daquexian * update logic of BUILD_MONOLITHIC_LIBONEFLOW Signed-off-by: daquexian * rename BUILD_MONOLITHIC_LIBONEFLOW to BUILD_MONOLITHIC_LIBONEFLOW_CPP_SO Signed-off-by: daquexian * refine * [draft] implement graph parameter load and save (#7010) * implement parameter save (python) and load (c++) Signed-off-by: daquexian * revert accident changes Signed-off-by: daquexian * fix circular reference Signed-off-by: daquexian * pimpl * batching * share lib directory in test container Signed-off-by: daquexian * fix typo; * add github actions debug Signed-off-by: daquexian * Revert "add github actions debug" This reverts commit 7d9aef684a479285c690f38d25525c9b97865e45. 
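For the l2_normalize / nn.functional.normalize entry above (#6940), a hedged usage sketch, assuming the PyTorch-compatible signature normalize(input, p=2.0, dim=1, eps=1e-12):

    import oneflow as flow

    x = flow.randn(4, 8)
    y = flow.nn.functional.normalize(x, p=2.0, dim=1)
    # every row of y should now have (approximately) unit L2 norm
    print(flow.linalg.vector_norm(y, dim=1))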
* add upterm debug after exe test Signed-off-by: daquexian * sleep after fail Signed-off-by: daquexian * set LD_LIBRARY_PATH in yml for cpp api test exe Signed-off-by: daquexian * refine * add test file && input order * sleep Signed-off-by: daquexian * upload liboneflow_cpp.so Signed-off-by: daquexian * modify cmake to trigger compilation Signed-off-by: daquexian * load job from ir && clean && add mlir model * [remove useless python code]save to .pb * add target of_common_obj to remove duplicate REGISTER_PASS && run of_format * remove openvino * remove openvino test * refine * IValue * Update oneflow/api/cpp/framework/graph.h Co-authored-by: daquexian * refine * refine * refine * refine * refine * refine * rename in oneflow.cmake * refine oneflow.cmake * make of_api_common object library * move device util function in api to core * remove device check in New and ThreadLocalGetOrNew * refine * fix device test * refine graph test * refine GetExeDir() * refine GetExeDir() again * fix * refine * fix Co-authored-by: daquexian Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: mosout * disable autograd in lazy mode (#7070) * disable autograd in lazy mode * refine * Fix/rand source op in graph (#7092) * add test * fix rand consistent * add test * Fix powf (#7106) * quick fix power * add int scalar test case Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Dispatch stateful ops using functional api (#7046) * Dispatch functional stateful ops * fix * fix cmake * fix * disable attr check since it may not given when creating op expr. * fix * fix * fix * fix * fix * fix * fix * fix * refine Co-authored-by: VertexC * Fix HWLoc memory affinity (#7115) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * add_env_api_docs (#7100) * add_env_api_docs * minor fix * fix grammatical errors Co-authored-by: Yao Chi Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * tmp skip s0 print because of slice (#7065) * tmp skip s0 print because of slice * tmp skip s0 print in test case * fix Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * indexing first version (#7012) * indexing first version * complete * test * out loop * test skip * revise * revise * shape * docs * formatted * confict1 * confict2 * confict2 * confict * revise * auto format by CI Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: oneflow-ci-bot * fix maybe: add Maybe(T&&) to allow constructing from rvalue T (#7125) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * autotest_add_graph_log (#7126) * Meta info consistency check (#7085) * meta_info_consistency_check * refine check function * Update consistent_cast.cpp * move check to opinterpreter * refine * add note * refactor MetaInfoConsistencyCheck * of_format * refine * NonRecursiveMetaInfoConsistencyCheck * fix func name * add IsMetaInfoConsistencyCheckDisable() * mino fix * refine * minor fix * format * minor fix Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * cmake: use interface target instead of include_directories in pybind11 (#7128) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Import cmake dependence json and inja using FetchContent (#7124) * import cmake dependence json and inja using FetchContent * install-llvm: fix url hash * fix inja config * add cache var * fix 
ninja build * fix ninja build Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add environment variable to set GRPC_ARG_MAX_MESSAGE_LENGTH (#7130) * env ONEFLOW_GRPC_MAX_MESSAGE_BYTE_SIZE * set default to -1 * Fea/nhwc (#6811) * legacy maxpool2d module * add legacy avgpool2d * add graph cudnn conv alg config * add conv2d nhwc * lazy create cuda_stream in CudaCopyD2HDeviceCtx CudaStreamHandleDeviceCtx * refine * conv bn pool nhwc for resnet perf * one hot with float * use BiasAddRowGpu * rm l2 with 0 * reformat * add nhwc env var * legacy pool merged into new * refine * fix style * fix and refine * address review * fix and refine * fix doc test Co-authored-by: luyang Co-authored-by: guo-ran <360112263@qq.com> Co-authored-by: lixinqi Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * reduce memory usage caused by slice grad (#7144) * cmake: fix THIRD_PARTY build (#7146) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * fix fold op (#7156) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Support inplace for lazy consistent (#7112) * Support inplace for lazy consistent * fix single client sbp hint * refine Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * fix prelu bug (#7118) * support dtype and device in prelu * optimize PreluFunctor * fix prelu 1-dim error * update * update * auto format by CI Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: oneflow-ci-bot * use ibn2nd_sbp to get nd_sbp (#7155) Co-authored-by: Houjiang Chen * fix copy bug (#7159) * fix copy bug * add to test case * refine * fix test case Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Fix laynorm backward bug (#7164) * fix layernorm backward index bug * add layernorm test case * auto format by CI Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * [Fix] graph support 0-Size tensor (#6957) * Add nn.functional.glu graph test * add filter to motify functional autotest * motify code * add test example * add test else * add test judging condition for test_masked_fill.py,test_constant.py,test_tile.py态test_repeat.py,test_expand.py * add test ok example * Clear tensor name scope after graph build * Add test case of 2 graph caught same free eager tensor * auto format by CI * Dev cc clean tensor name scope (#7082) * Clear tensor name scope after graph build * Add test case of 2 graph caught same free eager tensor * auto format by CI Co-authored-by: chengtbf <472491134@qq.com> Co-authored-by: oneflow-ci-bot * submit test success example * test success example * submit test code * fix a bug about relu module with 0 shape data * fixed a bug about relu module with 0 shape data * fix a bug about relu module with 0 shape data * fix a bug about relu module with 0 shape data * 0shape and 0d autotest * fix a bug about relu module with 0 shape data * 0shape changed to 0_size * modify test_var.py * modify test_eye.py * modify test_reshape.py * modify test_.py * modify ReshapeFunctor * modify some file * Fixed graph autotest bug with reshape op test * Fixed graph autotest bug with reshape op test * fixed test_sub.py * modify test_sub.py * modify tensor_methods.cpp * modify array_functor.cpp * graph support 0-Size tensor * rename 0shape to 0 size * modified check_graph=True * fix and refine Co-authored-by: Zhenhua Co-authored-by: 
tangnana925 <85614052+tangnana925@users.noreply.github.com> Co-authored-by: tangnana Co-authored-by: Zhenhua <1209435+hengzi@users.noreply.github.com> Co-authored-by: chengtbf <472491134@qq.com> Co-authored-by: oneflow-ci-bot Co-authored-by: Xiaoyu Xu Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Cumsum op implementation (#7050) * add cumsum op's forward definition * add cumsum forward test case * cumsum ver3 * remove calculating time * add cumsum forward gpu implementation * fix gpu forward error * change var name * remove annotation * add cumsum cpu forward multi-thread support * add multi-thread annotation * add cumsum grad definition * update * add cumsum cpu backward * add cumsum cpu backward functor * add cumsum autograd * update * remove user interface * use random method to test cumsum forward * add cumsum gpu backward * add cumsum gpu test * fix gpu backward bug * add a 3d cuda kernel try * Revert "add cumsum gpu test" This reverts commit 05c31556ba28ecb827b25e54c2f5fa38984e8096. * Revert "Revert "add cumsum gpu test"" This reverts commit 918ee1569863b008c1d419c3528257416cffd840. * change nele to ele_cnt * add test_cumsum.py in oneflow/test/modules * change original test_cumsum to autotest version * optimize cumsum for special up_space and down_space * add two special cu func * add cumsum doc * update doc * update doc * update code according to bbuf's review * ditto * change pin/pout to in_ptr/out_ptr * remove multi-thread func * update doc * use tensor processor * update by review * update by review * update * update * auto format by CI * auto format by CI * update doc * update Co-authored-by: oneflow-ci-bot * Logical slice in tenosr str (#7116) * using logical slice in tensor str * add tensor str util file * refine * refine * refine * refine * add logical slice docs * fix bug * fix comment * auto format by CI * fix doc test bug * delete TODO Co-authored-by: oneflow-ci-bot Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add install for oneflow py (#7107) * Add install for oneflow py * refine * refine * refine * refine * refine * refine * refine * refine * refien * refine * refine * refine * refine * refine * refine * refine * refine * refine * refine Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * fix bug: output key not exists when SavaJobToIR (#7139) * fix bug: output key not exists when SavaJobToIR * [test] makedirs when path not exists * remove useless comment Co-authored-by: Peihong Liu Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * Add linalg 2d norm op for clip_grad (#7160) * add linalg_2d_norm op for clip_grad * code format * revert sqrt * fix comment * refine * fix comment * fix ci error * fix ci error * fix docs bug * fix ci error * fix ci error Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> * refine nn.graph autotest (#7111) * add linspace op * refine graph autotest * revert * add graph error trace * fix bug * fix autotest bug * auto format by CI * fix set_printoptions error * auto format by CI * CI test bug * auto format by CI * For CI * auto format by CI * For CI test * fix ci error * revert for ci * fix bug * fix ci error * fix bug * fix bug Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: oneflow-ci-bot Co-authored-by: Li Xiang <54010254+lixiang007666@users.noreply.github.com> Co-authored-by: lixiang <88304454@qq.com> * add 
oneflow/pytorch cudnn.deterministic (#7172) * add cudnn.deterministic * fix bug * auto format by CI * fix bug * fix generate fake program input bug * auto format by CI Co-authored-by: oneflow-ci-bot * fix linalg vector norm scalar tensor print bug (#7178) * fix linalg vector norm scalar tensor print bug * auto format by CI Co-authored-by: oneflow-ci-bot * format * refine * format Co-authored-by: Yinggang Wang Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: oneflow-ci-bot Co-authored-by: liufengwei0103 <2472937968@qq.com> Co-authored-by: daquexian Co-authored-by: guo ran <360112263@qq.com> Co-authored-by: Peihong Liu Co-authored-by: ZZK <42901638+MARD1NO@users.noreply.github.com> Co-authored-by: Shenghang Tsai Co-authored-by: Li Xiang <54010254+lixiang007666@users.noreply.github.com> Co-authored-by: cheng cheng <472491134@qq.com> Co-authored-by: Li Xinqi Co-authored-by: lichunyou <33850693+lcylcy@users.noreply.github.com> Co-authored-by: Luyang Co-authored-by: wyushun Co-authored-by: zhu wang <33675639+olojuwin@users.noreply.github.com> Co-authored-by: leaves-zwx Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Co-authored-by: Shijie <821898965@qq.com> Co-authored-by: XIE Xuan Co-authored-by: Juncheng Co-authored-by: binbinHan Co-authored-by: CHI LIU <42956025+thinksoso@users.noreply.github.com> Co-authored-by: Yao Chi Co-authored-by: ouyangyu Co-authored-by: Xiaoyu Xu Co-authored-by: PragmaTwice Co-authored-by: luqiang guo <702572275@qq.com> Co-authored-by: ZeKai Zhou <30856589+zzk0@users.noreply.github.com> Co-authored-by: VertexC Co-authored-by: lixinqi Co-authored-by: fengdaozhuo <52237830+grybd@users.noreply.github.com> Co-authored-by: Zhenhua Co-authored-by: tangnana925 <85614052+tangnana925@users.noreply.github.com> Co-authored-by: tangnana Co-authored-by: Zhenhua <1209435+hengzi@users.noreply.github.com> Co-authored-by: lixiang <88304454@qq.com> --- .github/workflows/canary.yml | 4 +- .github/workflows/simple.yml | 6 +- .github/workflows/test.yml | 83 +- CMakeLists.txt | 42 +- ci/CMakeLists.txt | 1 + ci/clang/build-llvm.sh | 28 + ci/test/1node_op_test.sh | 2 - ci/test/2node_op_test_multi_client.sh | 4 +- ci/test/CMakeLists.txt | 25 + ci/test/distributed_run.py | 93 +- ci/test/resource-spec/1x-gtx-1080.json | 16 + ci/test/resource-spec/2x-rtx-2080.json | 20 + ci/test/resource-spec/4x-rtx-2080ti.json | 28 + cmake/caches/ci/canary/cuda.cmake | 1 + cmake/caches/ci/cpu.cmake | 2 + cmake/caches/ci/cuda.cmake | 1 + cmake/caches/ci/llvm/cuda-75-clang.cmake | 22 + cmake/caches/cn/fast/cpu-clang.cmake | 1 + cmake/caches/cn/fast/cpu.cmake | 1 + cmake/caches/cn/fast/cuda-61-clang.cmake | 3 +- cmake/caches/cn/fast/cuda-61.cmake | 1 + cmake/caches/cn/fast/cuda-75-clang.cmake | 4 +- cmake/caches/cn/fast/cuda-75.cmake | 1 + cmake/caches/cn/fast/mlir-cuda-61.cmake | 22 + cmake/caches/cn/fast/mlir-cuda-75.cmake | 1 + cmake/cfg.cmake | 7 +- cmake/functional.cmake | 32 + cmake/oneflow-config.cmake | 2 +- cmake/oneflow.cmake | 252 +- cmake/op_schema.cmake | 90 + cmake/pybind11.cmake | 15 +- cmake/third_party.cmake | 50 +- cmake/third_party/absl.cmake | 15 +- cmake/third_party/eigen.cmake | 1 + cmake/third_party/flatbuffers.cmake | 1 + cmake/third_party/gflags.cmake | 1 + cmake/third_party/glog.cmake | 1 + cmake/third_party/googletest.cmake | 1 + cmake/third_party/grpc.cmake | 1 + cmake/third_party/json.cmake | 46 +- cmake/third_party/oneDNN.cmake | 1 + cmake/third_party/opencv.cmake | 7 +- cmake/third_party/protobuf.cmake | 6 +- 
cmake/third_party/re2.cmake | 1 + cmake/third_party/zlib.cmake | 1 + cmake/util.cmake | 52 +- docs/source/comm.rst | 2 + docs/source/env.rst | 11 + docs/source/functional.rst | 2 + docs/source/index.rst | 1 + docs/source/oneflow.rst | 8 +- docs/source/tensor.rst | 1 + oneflow/api/common/device.cpp | 52 - oneflow/api/common/{device.h => ir_pass.cpp} | 16 +- oneflow/api/common/job_build_and_infer_ctx.h | 36 + oneflow/api/common/sbp.h | 58 + oneflow/api/common/scope.h | 54 + oneflow/api/cpp/env.cpp | 12 +- oneflow/api/cpp/framework.h | 4 +- oneflow/api/cpp/framework/device.cpp | 8 +- oneflow/api/cpp/framework/device.h | 9 +- oneflow/api/cpp/framework/dtype.h | 4 +- oneflow/api/cpp/framework/graph.cpp | 395 + oneflow/api/cpp/framework/graph.h | 56 + oneflow/api/cpp/framework/ivalue.cpp | 53 + oneflow/api/cpp/framework/ivalue.h | 149 + oneflow/api/cpp/framework/shape.cpp | 5 + oneflow/api/cpp/framework/shape.h | 18 +- oneflow/api/cpp/framework/tensor.cpp | 25 +- oneflow/api/cpp/framework/tensor.h | 27 +- oneflow/api/cpp/nn.h | 2 +- oneflow/api/cpp/nn/functional/activation.h | 2 +- oneflow/api/cpp/tests/api_test.cpp | 35 +- oneflow/api/cpp/tests/api_test.h | 2 + oneflow/api/cpp/tests/graph_test.cpp | 197 + .../affine_no_parameter/model.mlir | 11 + .../affine_with_parameter/model.a/meta | 5 + .../affine_with_parameter/model.a/out | Bin 0 -> 48 bytes .../affine_with_parameter/model.b/meta | 4 + .../affine_with_parameter/model.b/out | Bin 0 -> 16 bytes .../affine_with_parameter/model.mlir | 11 + oneflow/api/cpp/tests/ivalue_test.cpp | 132 + oneflow/api/cpp/tests/tensor_test.cpp | 8 +- oneflow/api/python/framework/device.cpp | 5 +- oneflow/api/python/framework/dtype.cpp | 12 +- oneflow/api/python/framework/nn_graph.cpp | 15 + oneflow/api/python/framework/op_expr.cpp | 59 +- oneflow/api/python/functional/common.cpp | 31 + oneflow/api/python/functional/common.h | 13 + .../functional/dispatch_stateful_ops.cpp | 459 + .../functional/dispatch_stateful_ops.yaml | 137 + oneflow/api/python/functional/py_function.cpp | 9 +- oneflow/api/python/functional/python_arg.cpp | 28 +- oneflow/api/python/functional/tensor_api.cpp | 72 +- oneflow/api/python/functional/tensor_api.yaml | 35 +- oneflow/api/python/functional/value_types.cpp | 5 +- oneflow/api/python/functional/value_types.h | 13 +- oneflow/api/python/ir.cpp | 3 - .../python/job_build/job_build_and_infer.cpp | 2 +- .../python/job_build/job_build_and_infer.h | 6 +- .../job_build/job_build_and_infer_api.h | 5 +- .../api/python/symbol/placement_symbol.cpp | 20 + oneflow/api/python/symbol/sbp_symbol.cpp | 3 +- oneflow/api/python/utils/tensor_utils.cpp | 2 +- .../core/autograd/gradient_funcs/concat.cpp | 12 +- .../gradient_funcs/consistent_cast.cpp | 24 +- .../core/autograd/gradient_funcs/cumsum.cpp | 64 + .../core/autograd/gradient_funcs/deconv.cpp | 22 +- .../core/autograd/gradient_funcs/diagonal.cpp | 72 + oneflow/core/autograd/gradient_funcs/dot.cpp | 2 - .../core/autograd/gradient_funcs/slice.cpp | 9 +- oneflow/core/boxing/boxing_dividor_util.cpp | 34 + oneflow/core/boxing/boxing_dividor_util.h | 2 + .../boxing/eager_boxing_interpreter_mgr.cpp | 35 +- oneflow/core/boxing/unflatten_hierarchy.cpp | 62 + oneflow/core/common/buffer_manager.h | 20 + oneflow/core/common/data_type.proto | 10 +- oneflow/core/common/high_order_bool.h | 13 +- oneflow/core/common/maybe.h | 1 + oneflow/core/common/maybe_test.cpp | 3 + oneflow/core/common/nd_index_offset_helper.h | 9 + oneflow/core/common/shape.cpp | 1 - oneflow/core/control/ctrl_service.cpp | 6 +- 
oneflow/core/cuda/layer_norm.cuh | 4 +- oneflow/core/cuda/softmax.cuh | 55 +- oneflow/core/device/cuda_util.cpp | 29 +- oneflow/core/device/cuda_util.h | 25 +- oneflow/core/device/ep_based_event_record.h | 54 + .../critical_section_instruction_type.cpp | 62 +- .../critical_section_phy_instr_operand.cpp | 59 +- .../critical_section_phy_instr_operand.h | 91 +- .../eager/critical_section_stream_type.cpp | 1 - .../core/eager/lazy_job_instruction_type.cpp | 118 +- .../core/eager/lazy_job_phy_instr_operand.cpp | 18 - .../core/eager/lazy_job_phy_instr_operand.h | 38 +- oneflow/core/eager/lazy_job_stream_type.cpp | 1 - .../core/eager/opkernel_instruction_type.cpp | 26 +- .../eager/opkernel_instruction_type_test.cpp | 26 +- .../core/ep/common/active_device_guard.cpp | 5 +- .../ep/common/device_manager_registry.cpp | 9 +- .../core/ep/common/primitive/binary_functor.h | 118 + .../primitive/broadcast_elementwise_binary.h | 79 + .../broadcast_simplify_dims_test.cpp | 91 + oneflow/core/ep/common/primitive/util.h | 129 + oneflow/core/ep/cpu/cpu_device.cpp | 10 +- oneflow/core/ep/cpu/cpu_device.h | 8 +- oneflow/core/ep/cpu/cpu_device_manager.cpp | 10 +- oneflow/core/ep/cpu/cpu_device_manager.h | 6 +- .../ep/cpu/cpu_device_manager_factory.cpp | 4 +- .../core/ep/cpu/primitive/binary_functor.h | 39 + .../broadcast_elementwise_binary.cpp | 190 + oneflow/core/ep/cuda/cuda_device.cpp | 4 +- oneflow/core/ep/cuda/cuda_device.h | 6 +- oneflow/core/ep/cuda/cuda_device_manager.cpp | 9 +- oneflow/core/ep/cuda/cuda_device_manager.h | 6 +- .../ep/cuda/cuda_device_manager_factory.cpp | 4 +- oneflow/core/ep/cuda/cuda_stream.cpp | 10 +- oneflow/core/ep/cuda/cuda_stream.h | 46 + .../core/ep/cuda/primitive/binary_functor.cuh | 50 + .../primitive/broadcast_elementwise_binary.cu | 86 + .../broadcast_elementwise_binary.cuh | 377 + ...roadcast_elementwise_binary_comparision.cu | 37 + .../broadcast_elementwise_binary_logical.cu | 37 + .../broadcast_elementwise_binary_math.cu | 35 + oneflow/core/ep/include/device.h | 3 + oneflow/core/ep/include/device_manager.h | 6 +- .../core/ep/include/device_manager_factory.h | 4 +- oneflow/core/ep/include/primitive/binary_op.h | 9 + .../primitive/broadcast_elementwise_binary.h | 2 +- oneflow/core/framework/attr_value.cpp | 18 +- oneflow/core/framework/attr_value.h | 40 +- .../core/framework/attr_value_accessor.cpp | 6 +- oneflow/core/framework/consistency_check.cpp | 255 + oneflow/core/framework/consistency_check.h | 52 + .../core/framework/data_consistency_check.cpp | 54 - oneflow/core/framework/device.cpp | 21 + oneflow/core/framework/device.h | 5 +- oneflow/core/framework/dtype.cpp | 110 +- oneflow/core/framework/dtype.h | 13 +- oneflow/core/framework/infer_util.cpp | 16 +- oneflow/core/framework/infer_util.h | 1 - .../core/framework/instructions_builder.cpp | 119 +- oneflow/core/framework/instructions_builder.h | 6 +- .../multi_client_session_context.cpp | 3 + oneflow/core/framework/nd_sbp.cpp | 41 +- oneflow/core/framework/nd_sbp.h | 6 +- oneflow/core/framework/nn_graph.cpp | 79 +- oneflow/core/framework/nn_graph.h | 8 +- oneflow/core/framework/op_attrs.cpp | 58 + oneflow/core/framework/op_attrs.h | 102 + oneflow/core/framework/op_base.h | 55 + oneflow/core/framework/op_expr.cpp | 16 +- oneflow/core/framework/op_interp_ctx.cpp | 66 + oneflow/core/framework/op_interp_ctx.h | 73 + .../eager_consistent_op_interpreter.cpp | 9 +- .../op_interpreter/lazy_op_interpreter.cpp | 37 +- .../op_interpreter/op_interpreter.cpp | 4 +- oneflow/core/framework/op_kernel.h | 83 +- 
oneflow/core/framework/placement_sbp_util.cpp | 4 +- .../core/framework/random_generator_impl.cpp | 12 +- oneflow/core/framework/system_ops.cpp | 115 + oneflow/core/framework/system_ops.h | 89 + oneflow/core/framework/tensor.cpp | 18 +- oneflow/core/framework/tensor.h | 12 +- oneflow/core/framework/tensor_methods.cpp | 6 +- oneflow/core/framework/tensor_name_scope.cpp | 5 + oneflow/core/framework/tensor_name_scope.h | 2 + oneflow/core/framework/tensor_rpc_util.cpp | 2 +- oneflow/core/framework/tensor_rpc_util.h | 4 +- oneflow/core/framework/tensor_util.cpp | 36 + oneflow/core/framework/tensor_util.h | 29 + oneflow/core/framework/user_op_conf.cpp | 35 +- oneflow/core/framework/user_op_def.cpp | 6 - oneflow/core/framework/user_op_def.h | 1 - oneflow/core/framework/user_op_def.proto | 2 - oneflow/core/framework/user_op_registry.cpp | 25 +- oneflow/core/framework/user_op_registry.h | 8 +- oneflow/core/framework/user_op_tensor.h | 8 +- oneflow/core/functional/function_library.h | 31 +- oneflow/core/functional/functional_api.yaml | 118 +- .../functional/impl/activation_functor.cpp | 18 +- .../core/functional/impl/array_functor.cpp | 304 +- .../core/functional/impl/consistent_cast.cpp | 11 + .../core/functional/impl/dataset_functor.cpp | 1 + oneflow/core/functional/impl/eye_functor.cpp | 137 + oneflow/core/functional/impl/math_functor.cpp | 209 +- oneflow/core/functional/impl/nn_functor.cpp | 128 +- .../core/functional/impl/random_functor.cpp | 34 +- .../core/functional/impl/unary_functor.cpp | 2 + oneflow/core/functional/tensor_index.cpp | 53 +- oneflow/core/graph/boxing/boxing_logger.cpp | 12 +- oneflow/core/graph/task_graph.cpp | 47 - oneflow/core/graph/task_graph.h | 1 - .../core/graph/task_stream_index_manager.cpp | 2 +- ...ritical_section_wait_compute_task_node.cpp | 67 + .../core/hardware/cuda_device_descriptor.cpp | 2 +- .../hardware/cuda_device_descriptor_class.cpp | 2 +- .../hardware/net_ib_device_descriptor.cpp | 2 +- .../net_ib_device_descriptor_class.cpp | 2 +- .../hardware/net_socket_device_descriptor.cpp | 2 +- .../net_socket_device_descriptor_class.cpp | 2 +- .../core/hardware/node_device_descriptor.cpp | 6 +- oneflow/core/job/compiler.cpp | 4 - oneflow/core/job/critical_section_instance.h | 39 + oneflow/core/job/env_global_objects_scope.cpp | 6 - .../core/job/inter_job_mem_sharing_util.cpp | 6 +- oneflow/core/job/job_build_and_infer_ctx.cpp | 100 +- oneflow/core/job/job_build_and_infer_ctx.h | 12 +- .../core/job/job_build_and_infer_ctx_mgr.cpp | 2 +- oneflow/core/job/job_builder.cpp | 6 +- oneflow/core/job/job_builder.h | 4 +- oneflow/core/job/job_instance.h | 6 - oneflow/core/job/job_ir.cpp | 32 + .../data_consistency_check.h => job/job_ir.h} | 13 +- oneflow/core/job/oneflow.cpp | 2 +- oneflow/core/job/parallel_desc.cpp | 51 + oneflow/core/job/parallel_desc.h | 2 + oneflow/core/job/plan_util.cpp | 2 +- oneflow/core/job/resource_desc.cpp | 1 - oneflow/core/job/sbp_parallel.cpp | 55 +- oneflow/core/job/sbp_parallel.h | 4 + oneflow/core/job/task.proto | 1 + oneflow/core/job_rewriter/adam_optm.cpp | 4 +- .../job_rewriter/add_lbi_diff_watcher.cpp | 2 +- .../auto_mixed_precision_lists.cpp | 3 +- oneflow/core/job_rewriter/autograd.cpp | 6 +- oneflow/core/job_rewriter/autotick.cpp | 201 +- oneflow/core/job_rewriter/autotick.h | 1 + oneflow/core/job_rewriter/clone_grad.cpp | 14 +- oneflow/core/job_rewriter/clone_grad.h | 9 +- .../job_rewriter/fuse_add_to_output_pass.cpp | 5 + .../insert_nccl_logical_op_pass.cpp | 12 +- oneflow/core/job_rewriter/job_completer.cpp | 1 + 
.../critical_section_callback_tick_kernel.cpp | 51 + .../critical_section_wait_tick_kernel.cpp | 50 + oneflow/core/kernel/eager_kernel.h | 1 + oneflow/core/kernel/input_kernel.cpp | 10 +- oneflow/core/kernel/kernel_util.cuh | 15 +- oneflow/core/kernel/output_kernel.cpp | 10 +- oneflow/core/kernel/return_kernel.cpp | 10 +- oneflow/core/kernel/user_kernel.cpp | 73 +- oneflow/core/kernel/user_kernel.h | 7 + oneflow/core/lazy/actor/naive_actor.cpp | 1 + oneflow/core/lazy/actor/pack_actor.cpp | 2 +- oneflow/core/lazy/actor/unpack_actor.cpp | 2 +- oneflow/core/ndarray/binary_func.h | 7 +- oneflow/core/ndarray/ndarray_assign_core.cu | 5 +- oneflow/core/ndarray/xpu_ndarray_assign.cu | 5 +- .../critical_section_callback_tick_op.cpp | 81 + .../critical_section_wait_tick_op.cpp | 81 + oneflow/core/operator/interface_op_util.cpp | 2 +- oneflow/core/operator/op_conf.proto | 16 +- oneflow/core/operator/operator.cpp | 26 +- oneflow/core/operator/user_op.cpp | 6 +- oneflow/core/vm/async_cuda_stream_type.cpp | 3 +- oneflow/core/vm/control_stream_type.cpp | 1 - oneflow/core/vm/cpu_stream_type.cpp | 3 +- oneflow/core/vm/cuda_allocator.cpp | 11 +- oneflow/core/vm/cuda_copy_d2h_stream_type.cpp | 3 +- oneflow/core/vm/cuda_copy_h2d_stream_type.cpp | 3 +- oneflow/core/vm/cuda_stream_type.cpp | 3 +- oneflow/core/vm/device_helper_stream_type.cpp | 3 +- oneflow/core/vm/host_stream_type.cpp | 1 - oneflow/core/vm/stream_desc.cpp | 8 +- oneflow/core/vm/stream_desc.h | 10 +- oneflow/core/vm/test_util.cpp | 2 +- oneflow/core/vm/transport_stream_type.cpp | 16 +- oneflow/ir/CMakeLists.txt | 115 +- oneflow/ir/include/OneFlow/CMakeLists.txt | 17 +- oneflow/ir/include/OneFlow/OneFlowBase.td | 172 +- oneflow/ir/include/OneFlow/OneFlowDialect.h | 1 + oneflow/ir/include/OneFlow/OneFlowDialect.td | 3 + .../ir/include/OneFlow/OneFlowInterfaces.td | 11 + oneflow/ir/include/OneFlow/OneFlowOpTraits.h | 131 + oneflow/ir/include/OneFlow/OneFlowOps.h | 97 +- oneflow/ir/include/OneFlow/OneFlowOps.td | 184 +- oneflow/ir/include/OneFlow/OneFlowPatterns.td | 2 +- oneflow/ir/include/OneFlow/OneFlowUserOps.td | 8713 +++++++++++++++++ oneflow/ir/install-llvm.cmake | 86 + oneflow/ir/lib/OneFlow/CMakeLists.txt | 1 - oneflow/ir/lib/OneFlow/OneFlowOpGetGen.cpp.in | 9 +- oneflow/ir/lib/OneFlow/OneFlowOps.cpp | 203 +- oneflow/ir/lib/OneFlow/Passes.cpp | 36 +- oneflow/ir/llvm-in-tree.cmake | 57 + oneflow/ir/oneflow-extension/CMakeLists.txt | 35 +- oneflow/ir/oneflow-extension/extension.cpp | 4 +- oneflow/ir/oneflow-extension/ir_pass.cpp | 20 +- .../ir/oneflow-gen-ods/oneflow-gen-ods.cpp | 724 -- oneflow/ir/oneflow-opt/CMakeLists.txt | 17 +- oneflow/ir/oneflow-runtime/CMakeLists.txt | 1 + oneflow/ir/oneflow-runtime/lib/CMakeLists.txt | 8 + oneflow/ir/oneflow-runtime/lib/Runtime.cpp | 17 + oneflow/ir/oneflow-translate/CMakeLists.txt | 1 + .../include/OneFlow/MLIROneFlowTranslation.h | 29 +- .../lib/OneFlow/CMakeLists.txt | 8 +- .../lib/OneFlow/Importer.cpp | 293 +- .../lib/OneFlow/MLIROneFlowTranslation.cpp | 607 +- oneflow/ir/test/CMakeLists.txt | 7 +- oneflow/ir/test/OneFlow/jit-outline-func.mlir | 6 +- oneflow/ir/test/OneFlow/networks/__init__.py | 0 oneflow/ir/test/OneFlow/networks/resnet50.py | 293 + .../ir/test/OneFlow/test_fuse_cast_scale.py | 2 +- .../ir/test/OneFlow/test_fuse_tril_scale.py | 16 +- .../test/OneFlow/test_graph_save_and_load.py | 96 + oneflow/ir/test/OneFlow/test_mlir_opt.mlir | 30 - ...test_mlir_opt.mlir.py => test_mlir_opt.py} | 13 + oneflow/ir/test/lit.cfg.py | 6 +- oneflow/ir/test/lit.site.cfg.py.in | 1 + 
oneflow/user/data/coco_data_reader.h | 2 +- oneflow/user/kernels/add_n_kernel.cpp | 4 +- oneflow/user/kernels/arange_kernel.cpp | 1 + oneflow/user/kernels/avg_pooling_kernel.cpp | 88 +- oneflow/user/kernels/bernoulli_kernel.cpp | 5 +- oneflow/user/kernels/cast_kernel.cpp | 4 +- oneflow/user/kernels/coco_reader_kernel.cpp | 3 +- .../kernels/combined_margin_loss_kernel.cpp | 56 +- .../kernels/combined_margin_loss_kernel.cu | 58 +- oneflow/user/kernels/conv_cudnn_kernels.cpp | 24 +- oneflow/user/kernels/conv_kernels.cpp | 180 +- oneflow/user/kernels/cumsum_kernel.cpp | 129 + oneflow/user/kernels/cumsum_kernel.cu | 207 + oneflow/user/kernels/deconv_cpu_kernel.cpp | 99 +- oneflow/user/kernels/diagonal_kernel.cpp | 138 + oneflow/user/kernels/diagonal_kernel.cu | 160 + .../kernels/distributions/normal_kernel.h | 3 +- .../distributions/uniform_int_kernel.h | 3 +- .../kernels/distributions/uniform_kernel.h | 3 +- oneflow/user/kernels/dot_kernel.cpp | 1 - oneflow/user/kernels/dropout_kernel.cpp | 5 +- oneflow/user/kernels/dropout_kernel.cu | 5 +- oneflow/user/kernels/eager_b_to_s_kernel.cpp | 27 +- oneflow/user/kernels/eager_nccl_kernels.cpp | 110 +- oneflow/user/kernels/eager_nccl_kernels.cu | 128 +- oneflow/user/kernels/eager_p_to_b_kernel.cpp | 23 +- oneflow/user/kernels/eager_p_to_s_kernel.cpp | 27 +- oneflow/user/kernels/eager_s_to_b_kernel.cpp | 27 +- oneflow/user/kernels/eager_s_to_s_kernel.cpp | 27 +- .../kernels/eager_symmetric_s_to_p_kernel.cpp | 25 +- oneflow/user/kernels/fold_kernel.cpp | 4 +- oneflow/user/kernels/gather_kernel.cpp | 27 +- ...andom_batch_permutation_indices_kernel.cpp | 5 +- ...random_batch_permutation_indices_kernel.cu | 5 +- .../user/kernels/gpt_data_loader_kernel.cpp | 3 +- oneflow/user/kernels/group_conv_kernel.cpp | 143 +- oneflow/user/kernels/group_deconv_kernel.cpp | 429 + .../kernels/heap_selection_top_k_kernel.cu | 68 +- .../user/kernels/image_preprocess_kernels.cpp | 12 +- .../user/kernels/image_preprocess_kernels.cu | 3 +- oneflow/user/kernels/l2_normalize_kernel.cu | 4 +- .../kernels/math_binary_elementwise_func.h | 1 + .../kernels/math_unary_elementwise_func.h | 28 + .../user/kernels/min_max_observer_kernel.cu | 35 +- oneflow/user/kernels/model_update_kernels.cpp | 67 +- .../moving_average_min_max_observer_kernel.cu | 31 +- .../kernels/nccl_logical_2d_sbp_kernels.cpp | 15 +- oneflow/user/kernels/nccl_logical_kernels.cpp | 15 +- oneflow/user/kernels/nvtx_range_kernel.cu | 6 +- .../user/kernels/ofrecord_decoder_kernels.cpp | 5 +- ...ord_image_classification_reader_kernel.cpp | 3 +- .../user/kernels/ofrecord_reader_kernel.cpp | 3 +- oneflow/user/kernels/onerec_reader_kernel.cpp | 3 +- ...el_state_wrapper.h => op_kernel_wrapper.h} | 15 + oneflow/user/kernels/pack_kernel.cpp | 5 +- .../user/kernels/partial_fc_sample_kernel.cu | 6 +- oneflow/user/kernels/pool_cpu_kernel.cpp | 158 +- oneflow/user/kernels/pool_gpu_kernel.cpp | 162 +- oneflow/user/kernels/pooling_kernel.cpp | 86 +- oneflow/user/kernels/prelu_kernel.cpp | 4 +- oneflow/user/kernels/prelu_kernel.cu | 4 +- .../user/kernels/radix_sort_top_k_kernel.cu | 60 +- .../user/kernels/random_mask_like_kernel.h | 3 +- oneflow/user/kernels/randperm_kernel.cpp | 5 +- oneflow/user/kernels/randperm_kernel.cu | 5 +- oneflow/user/kernels/relu_kernel.cpp | 8 +- oneflow/user/kernels/roi_align_kernel.cu | 1 + oneflow/user/kernels/slice_kernel.cpp | 48 +- .../kernels/sparse_cross_entropy_kernel.cpp | 48 +- .../sparse_softmax_cross_entropy_kernel.cpp | 27 +- .../user/kernels/sqrt_square_sum_kernel.cpp | 69 + 
.../kernels/sqrt_square_sum_kernel_util.cpp | 34 + .../kernels/sqrt_square_sum_kernel_util.cu | 82 + .../kernels/sqrt_square_sum_kernel_util.h | 30 + oneflow/user/kernels/square_sum_kernel.cpp | 2 +- .../user/kernels/stateful_local_opkernel.cpp | 39 +- .../user/kernels/stateful_local_opkernel.h | 15 +- oneflow/user/kernels/test_kernels.cpp | 11 +- oneflow/user/kernels/top_k_kernel.cpp | 37 +- oneflow/user/kernels/transpose_kernel.cpp | 23 +- oneflow/user/kernels/unpack_kernel.cpp | 5 +- .../kernels/unsorted_segment_sum_kernel.cpp | 52 +- oneflow/user/ops/acc_op.cpp | 89 +- oneflow/user/ops/adaptive_pool_op.cpp | 102 +- oneflow/user/ops/add_n_op.cpp | 82 +- oneflow/user/ops/affine_grid_op.cpp | 179 +- oneflow/user/ops/amp_white_identity_op.cpp | 54 +- oneflow/user/ops/arange_op.cpp | 103 +- oneflow/user/ops/arg_sort_op.cpp | 58 +- oneflow/user/ops/arg_where_op.cpp | 36 +- oneflow/user/ops/argmax_op.cpp | 44 +- oneflow/user/ops/assign_op.cpp | 51 +- oneflow/user/ops/batch_gather_op.cpp | 134 +- oneflow/user/ops/bernoulli_op.cpp | 49 +- oneflow/user/ops/bias_add_op.cpp | 80 +- oneflow/user/ops/binary_cross_entropy_op.cpp | 73 +- .../binary_cross_entropy_with_logits_op.cpp | 90 +- oneflow/user/ops/broadcast_div_grad_op.cpp | 94 +- oneflow/user/ops/broadcast_like_op.cpp | 43 +- oneflow/user/ops/broadcast_pow_grad_op.cpp | 194 +- oneflow/user/ops/buffer_op.cpp | 58 +- oneflow/user/ops/cast_like_op.cpp | 100 +- oneflow/user/ops/cast_op.cpp | 21 +- oneflow/user/ops/cast_to_static_shape_op.cpp | 61 +- oneflow/user/ops/cast_to_tick_op.cpp | 73 +- .../ops/categorical_ordinal_encode_op.cpp | 114 +- oneflow/user/ops/celu_op.cpp | 105 +- oneflow/user/ops/clip_by_value_op.cpp | 100 +- oneflow/user/ops/coco_reader_op.cpp | 241 +- oneflow/user/ops/combined_margin_loss_op.cpp | 170 +- oneflow/user/ops/concat_op.cpp | 78 +- oneflow/user/ops/constant_op.cpp | 41 +- oneflow/user/ops/conv_op.cpp | 467 +- oneflow/user/ops/copy_op.cpp | 76 +- oneflow/user/ops/count_not_finite_op.cpp | 115 +- oneflow/user/ops/ctc_loss_op.cpp | 241 +- oneflow/user/ops/cumsum_op.cpp | 87 + oneflow/user/ops/deconv_op.cpp | 132 +- oneflow/user/ops/diag_op.cpp | 124 +- oneflow/user/ops/diagonal_op.cpp | 100 + oneflow/user/ops/dim_gather_op.cpp | 129 +- oneflow/user/ops/dim_scatter_ops.cpp | 176 +- oneflow/user/ops/distributions/normal_op.cpp | 57 +- .../user/ops/distributions/uniform_int_op.cpp | 82 +- oneflow/user/ops/distributions/uniform_op.cpp | 65 +- oneflow/user/ops/dot_op.cpp | 59 +- oneflow/user/ops/dropout_op.cpp | 205 +- .../ops/dynamic_loss_scale_schedule_op.cpp | 44 +- oneflow/user/ops/eager_b_to_s_op.cpp | 46 +- oneflow/user/ops/eager_nccl_ops.cpp | 429 +- oneflow/user/ops/eager_p_to_b_op.cpp | 48 +- oneflow/user/ops/eager_p_to_s_op.cpp | 45 +- oneflow/user/ops/eager_s_to_b_op.cpp | 49 +- oneflow/user/ops/eager_s_to_s_op.cpp | 46 +- .../user/ops/eager_symmetric_s_to_p_op.cpp | 95 +- .../ops/elementwise_maximum_minimum_ops.cpp | 74 +- oneflow/user/ops/elu_op.cpp | 105 +- oneflow/user/ops/empty_op.cpp | 76 +- oneflow/user/ops/expand_dims_op.cpp | 73 +- oneflow/user/ops/expand_op.cpp | 207 +- oneflow/user/ops/eye_op.cpp | 38 +- oneflow/user/ops/fake_quantization_op.cpp | 185 +- oneflow/user/ops/flatten_op.cpp | 75 +- oneflow/user/ops/flip_op.cpp | 80 +- oneflow/user/ops/fused_bias_add_op.cpp | 347 +- oneflow/user/ops/fused_cast_scale_op.cpp | 22 +- .../fused_scale_mask_softmax_dropout_op.cpp | 190 +- .../user/ops/fused_scale_mask_softmax_op.cpp | 165 +- ...fused_scale_tril_softmax_mask_scale_op.cpp | 165 +- 
..._attention_query_mul_key_and_value_ops.cpp | 197 +- oneflow/user/ops/gather_op.cpp | 139 +- oneflow/user/ops/gelu_op.cpp | 109 +- ...te_random_batch_permutation_indices_op.cpp | 50 +- oneflow/user/ops/gpt_data_loader_op.cpp | 77 +- oneflow/user/ops/grid_sample_op.cpp | 204 +- oneflow/user/ops/hardsigmoid_op.cpp | 107 +- oneflow/user/ops/hardswish_op.cpp | 103 +- oneflow/user/ops/hardtanh_op.cpp | 123 +- .../ops/hierarchical_parallel_cast_op.cpp | 124 +- oneflow/user/ops/identity_op.cpp | 55 +- oneflow/user/ops/image_batch_align_op.cpp | 132 +- oneflow/user/ops/image_decode_op.cpp | 87 +- .../user/ops/image_object_preprocess_ops.cpp | 420 +- oneflow/user/ops/image_preprocess_ops.cpp | 436 +- oneflow/user/ops/image_resize_ops.cpp | 246 +- oneflow/user/ops/image_target_resize_op.cpp | 103 +- oneflow/user/ops/in_top_k_op.cpp | 62 +- .../user/ops/indexed_slices_reduce_sum_op.cpp | 73 +- oneflow/user/ops/kl_div_op.cpp | 97 +- .../user/ops/l1_l2_regularize_gradient_op.cpp | 37 +- oneflow/user/ops/l2_normalize_op.cpp | 176 +- oneflow/user/ops/layer_norm_op.cpp | 428 +- oneflow/user/ops/leaky_relu_op.cpp | 114 +- oneflow/user/ops/log_softmax_op.cpp | 105 +- oneflow/user/ops/logical_not_op.cpp | 22 +- oneflow/user/ops/loss_op_util.cpp | 18 +- oneflow/user/ops/loss_op_util.h | 10 +- oneflow/user/ops/masked_fill_op.cpp | 33 +- .../user/ops/math_binary_broadcast_ops.cpp | 51 +- oneflow/user/ops/math_binary_broadcast_seq.h | 22 + .../user/ops/math_binary_elementwise_ops.cpp | 43 +- .../user/ops/math_binary_elementwise_seq.h | 7 + .../user/ops/math_unary_elementwise_op.cpp | 65 +- oneflow/user/ops/math_unary_elementwise_seq.h | 37 + oneflow/user/ops/matmul_op.cpp | 516 +- oneflow/user/ops/min_max_observer_op.cpp | 112 +- oneflow/user/ops/mish_op.cpp | 103 +- oneflow/user/ops/model_update_ops.cpp | 711 +- .../moving_average_min_max_observer_op.cpp | 160 +- oneflow/user/ops/multiply_op.cpp | 81 +- oneflow/user/ops/narrow_op.cpp | 226 +- oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp | 428 +- oneflow/user/ops/nccl_logical_ops.cpp | 413 +- oneflow/user/ops/nd_index_slice_ops.cpp | 363 +- oneflow/user/ops/nll_op.cpp | 81 +- oneflow/user/ops/nms_op.cpp | 24 +- oneflow/user/ops/normalization_op.cpp | 444 +- oneflow/user/ops/nvtx_range_op.cpp | 147 +- oneflow/user/ops/ofrecord_decoder_ops.cpp | 280 +- ...frecord_image_classification_reader_op.cpp | 107 +- oneflow/user/ops/ofrecord_reader_op.cpp | 96 +- oneflow/user/ops/one_hot_op.cpp | 93 +- oneflow/user/ops/onerec_decoder_op.cpp | 112 +- oneflow/user/ops/onerec_reader_op.cpp | 61 +- oneflow/user/ops/ones_like_op.cpp | 53 +- oneflow/user/ops/p2p_comm_op.cpp | 65 +- oneflow/user/ops/pack_op.cpp | 100 +- oneflow/user/ops/pad_op.cpp | 139 +- oneflow/user/ops/padding_ops.cpp | 758 +- oneflow/user/ops/parallel_cast_op.cpp | 81 +- oneflow/user/ops/partial_fc_sample_op.cpp | 226 +- oneflow/user/ops/pool_op.cpp | 129 +- oneflow/user/ops/pooling_op.cpp | 181 +- oneflow/user/ops/prelu_op.cpp | 186 +- oneflow/user/ops/quantization_op.cpp | 172 +- oneflow/user/ops/randperm_op.cpp | 54 +- oneflow/user/ops/reduce_like_ops.cpp | 139 +- oneflow/user/ops/reduce_ops.cpp | 48 +- oneflow/user/ops/relu_op.cpp | 109 +- oneflow/user/ops/repeat_op.cpp | 67 +- oneflow/user/ops/reshape_like_op.cpp | 87 +- oneflow/user/ops/reshape_op.cpp | 23 +- oneflow/user/ops/roi_align_op.cpp | 108 +- oneflow/user/ops/roll_op.cpp | 69 +- oneflow/user/ops/same_padding_op.cpp | 155 +- oneflow/user/ops/scalar_by_tensor_op.cpp | 141 +- oneflow/user/ops/scalar_logical_op.cpp | 65 +- 
oneflow/user/ops/scalar_math_op.cpp | 87 +- oneflow/user/ops/selu_op.cpp | 97 +- oneflow/user/ops/sigmoid_cross_entropy_op.cpp | 148 +- oneflow/user/ops/sigmoid_op.cpp | 101 +- oneflow/user/ops/silu_op.cpp | 97 +- oneflow/user/ops/slice_op.cpp | 373 +- oneflow/user/ops/smooth_l1_loss_op.cpp | 89 +- oneflow/user/ops/softmax_cross_entropy_op.cpp | 186 +- oneflow/user/ops/softmax_op.cpp | 100 +- oneflow/user/ops/softsign_op.cpp | 97 +- oneflow/user/ops/sort_op.cpp | 54 +- oneflow/user/ops/sparse_cross_entropy_op.cpp | 156 +- .../ops/sparse_softmax_cross_entropy_op.cpp | 77 +- oneflow/user/ops/split_like_op.cpp | 116 +- oneflow/user/ops/sqrt_square_sum_op.cpp | 41 + oneflow/user/ops/square_sum_op.cpp | 107 +- oneflow/user/ops/squeeze_op.cpp | 46 +- oneflow/user/ops/ssp_variable_proxy_op.cpp | 71 +- oneflow/user/ops/summary_ops.cpp | 130 +- oneflow/user/ops/tanh_op.cpp | 40 +- oneflow/user/ops/tensor_buffer_ops.cpp | 370 +- oneflow/user/ops/test_ops.cpp | 569 +- oneflow/user/ops/tf_prelu_op.cpp | 189 +- oneflow/user/ops/top_k_op.cpp | 52 +- oneflow/user/ops/transpose_ops.cpp | 75 +- oneflow/user/ops/tril_op.cpp | 144 +- oneflow/user/ops/triu_op.cpp | 58 +- oneflow/user/ops/tuple_identity_op.cpp | 94 +- oneflow/user/ops/two_stage_reduce_ops.cpp | 169 +- oneflow/user/ops/unfold_fold_op.cpp | 53 +- oneflow/user/ops/unfold_tensor_op.cpp | 152 +- oneflow/user/ops/unique_with_counts_op.cpp | 75 +- oneflow/user/ops/unpack_op.cpp | 89 +- .../ops/unsorted_batch_segment_sum_op.cpp | 115 +- oneflow/user/ops/unsorted_segment_sum_op.cpp | 307 +- oneflow/user/ops/upsample_op.cpp | 782 +- oneflow/user/ops/where_op.cpp | 195 +- oneflow/user/ops/zero_like_op.cpp | 53 +- oneflow/user/summary/summary_converter.h | 2 +- oneflow/user/utils/pool_util.h | 2 +- oneflow/xrt/passes/rebuild_job_pass.cpp | 3 +- oneflow/xrt/xla/ops/layer_norm_op.cpp | 25 +- python/oneflow/__init__.py | 24 +- python/oneflow/comm/__init__.py | 2 + python/oneflow/comm/comm_ops.py | 81 +- .../compatible/single_client/__init__.py | 11 +- .../single_client/framework/config_util.py | 3 + .../single_client/framework/dtype.py | 17 + .../single_client/framework/op_expr_util.py | 38 - .../framework/register_class_method_util.py | 2 - .../single_client/nn/optimizer/adam.py | 10 +- .../single_client/nn/optimizer/adamw.py | 10 +- .../single_client/nn/optimizer/rmsprop.py | 23 +- .../single_client/nn/optimizer/sgd.py | 29 +- .../compatible/single_client/ops/layers.py | 41 +- .../compatible/single_client/ops/math_ops.py | 4 +- .../ops/{builtin_ops.py => stateful_ops.py} | 27 +- .../single_client/test/ops/test_ccrelu.py | 4 +- .../test/ops/test_multi_global_function.py | 4 +- .../test/ops/test_multi_process.py | 111 - .../test/xrt/test_layer_norm_param_grad.py | 61 +- python/oneflow/distributed/launch.py | 18 +- python/oneflow/env.py | 20 + python/oneflow/framework/check_point_v2.py | 33 +- python/oneflow/framework/config_util.py | 3 + python/oneflow/framework/docstr/__init__.py | 4 + python/oneflow/framework/docstr/arange.py | 52 + python/oneflow/framework/docstr/array_ops.py | 30 + .../docstr}/broadcast_like.py | 38 +- .../{nn/modules => framework/docstr}/chunk.py | 27 +- python/oneflow/framework/docstr/conv.py | 3 +- python/oneflow/framework/docstr/math_ops.py | 212 +- python/oneflow/framework/docstr/meshgrid.py | 10 +- python/oneflow/framework/docstr/norm.py | 78 + .../{nn/modules => framework/docstr}/split.py | 22 +- python/oneflow/framework/docstr/tensor.py | 44 +- python/oneflow/framework/env_util.py | 8 +- 
python/oneflow/framework/graph_build_util.py | 20 +- python/oneflow/framework/op_expr_util.py | 36 - .../framework/register_class_method_util.py | 4 +- python/oneflow/framework/tensor.py | 27 +- python/oneflow/framework/tensor_str.py | 70 +- python/oneflow/framework/tensor_str_util.py | 57 + python/oneflow/nn/functional/__init__.py | 4 +- python/oneflow/nn/graph/block.py | 46 +- python/oneflow/nn/graph/graph.py | 138 +- python/oneflow/nn/graph/graph_config.py | 4 +- python/oneflow/nn/graph/optimizer.py | 60 +- python/oneflow/nn/init.py | 4 + python/oneflow/nn/module.py | 57 +- python/oneflow/nn/modules/activation.py | 11 +- python/oneflow/nn/modules/all_reduce.py | 9 +- python/oneflow/nn/modules/arange.py | 31 +- python/oneflow/nn/modules/batchnorm.py | 9 +- python/oneflow/nn/modules/batchnorm_fused.py | 9 +- python/oneflow/nn/modules/conv.py | 120 +- python/oneflow/nn/modules/dataset.py | 630 +- python/oneflow/nn/modules/in_top_k.py | 25 +- .../nn/modules/{eye.py => linspace.py} | 180 +- python/oneflow/nn/modules/math_ops.py | 99 +- python/oneflow/nn/modules/meshgrid.py | 4 +- python/oneflow/nn/modules/norm.py | 60 - python/oneflow/nn/modules/pooling.py | 196 +- python/oneflow/nn/modules/slice.py | 27 + python/oneflow/nn/modules/tensor_buffer.py | 60 +- python/oneflow/nn/optimizer/adagrad.py | 10 +- python/oneflow/nn/optimizer/adam.py | 20 +- python/oneflow/nn/optimizer/adamw.py | 20 +- python/oneflow/nn/optimizer/lr_scheduler.py | 32 +- python/oneflow/nn/optimizer/optimizer.py | 43 + python/oneflow/nn/optimizer/rmsprop.py | 23 +- python/oneflow/nn/optimizer/sgd.py | 28 +- .../oneflow/nn/optimizer/sparse_optimizer.py | 41 + python/oneflow/nn/parallel/ddp.py | 1 - python/oneflow/ops/initializer_util.py | 11 + .../ops/{builtin_ops.py => stateful_ops.py} | 28 +- python/oneflow/optim/__init__.py | 1 + python/oneflow/optim/utils.py | 16 + python/oneflow/test/graph/test_graph_eye.py | 38 + .../graph/test_graph_free_eager_tensor.py | 55 +- .../test/graph/test_graph_inplace_add.py | 74 + .../test/graph/test_graph_optimizer.py | 5 +- .../test/graph/test_graph_reuse_var.py | 98 + .../test/graph/test_graph_sparse_optimizer.py | 74 + .../oneflow/test/graph/test_input_op_expr.py | 3 +- .../oneflow/test/graph/test_output_op_expr.py | 6 +- .../oneflow/test/graph/test_user_op_expr.py | 16 +- python/oneflow/test/graph/test_util.py | 11 + .../test/graph/test_variable_op_expr.py | 4 +- python/oneflow/test/modules/resnet50_model.py | 3 +- python/oneflow/test/modules/test_abs.py | 4 +- .../oneflow/test/modules/test_activation.py | 28 +- .../test/modules/test_adaptive_pool.py | 6 +- python/oneflow/test/modules/test_add.py | 6 +- python/oneflow/test/modules/test_addmm.py | 4 +- .../oneflow/test/modules/test_affine_grid.py | 4 +- python/oneflow/test/modules/test_arange.py | 12 +- python/oneflow/test/modules/test_argmax.py | 2 +- python/oneflow/test/modules/test_autograd.py | 2 +- python/oneflow/test/modules/test_cast.py | 2 +- python/oneflow/test/modules/test_ceil.py | 4 +- python/oneflow/test/modules/test_chunk.py | 15 + python/oneflow/test/modules/test_clamp.py | 4 +- python/oneflow/test/modules/test_comm_ops.py | 226 +- python/oneflow/test/modules/test_concat.py | 14 +- python/oneflow/test/modules/test_constant.py | 6 +- .../oneflow/test/modules/test_constantpad.py | 2 +- .../test/modules/test_convtranspose.py | 48 + python/oneflow/test/modules/test_cumsum.py | 37 + python/oneflow/test/modules/test_deconv2d.py | 25 + python/oneflow/test/modules/test_diag.py | 4 +- python/oneflow/test/modules/test_diagonal.py 
| 44 + python/oneflow/test/modules/test_div.py | 4 +- python/oneflow/test/modules/test_dot.py | 2 +- .../oneflow/test/modules/test_eager_boxing.py | 61 + python/oneflow/test/modules/test_eq.py | 4 +- python/oneflow/test/modules/test_erf.py | 2 +- python/oneflow/test/modules/test_erfc.py | 2 +- python/oneflow/test/modules/test_expand.py | 2 +- python/oneflow/test/modules/test_expm1.py | 4 +- python/oneflow/test/modules/test_flatten.py | 2 +- python/oneflow/test/modules/test_flip.py | 2 +- python/oneflow/test/modules/test_fmod.py | 4 +- .../oneflow/test/modules/test_from_numpy.py | 62 + .../test/modules/test_functional_docstr.py | 7 +- .../modules/test_fused_bias_add_dropout.py | 10 +- python/oneflow/test/modules/test_gather.py | 2 +- python/oneflow/test/modules/test_glu.py | 4 +- python/oneflow/test/modules/test_greater.py | 4 +- python/oneflow/test/modules/test_groupnorm.py | 2 +- python/oneflow/test/modules/test_layernorm.py | 18 + python/oneflow/test/modules/test_linear.py | 2 +- python/oneflow/test/modules/test_linspace.py | 59 + python/oneflow/test/modules/test_log1p.py | 2 +- .../oneflow/test/modules/test_logical_and.py | 2 +- .../oneflow/test/modules/test_logical_not.py | 2 +- .../oneflow/test/modules/test_logical_or.py | 2 +- .../oneflow/test/modules/test_logical_xor.py | 2 +- .../oneflow/test/modules/test_lr_scheduler.py | 46 +- .../oneflow/test/modules/test_masked_fill.py | 6 +- python/oneflow/test/modules/test_math_ops.py | 8 +- python/oneflow/test/modules/test_matmul.py | 6 +- python/oneflow/test/modules/test_maxpool.py | 6 +- python/oneflow/test/modules/test_mean.py | 2 +- python/oneflow/test/modules/test_meshgrid.py | 41 +- python/oneflow/test/modules/test_module_to.py | 24 + .../test/modules/test_module_to_consistent.py | 64 + python/oneflow/test/modules/test_ne.py | 4 +- python/oneflow/test/modules/test_negative.py | 4 +- python/oneflow/test/modules/test_norm.py | 10 +- ...test_l2_normalize.py => test_normalize.py} | 51 +- python/oneflow/test/modules/test_permute.py | 6 +- python/oneflow/test/modules/test_prod.py | 4 +- python/oneflow/test/modules/test_randint.py | 41 + python/oneflow/test/modules/test_randperm.py | 2 +- .../oneflow/test/modules/test_reciprocal.py | 2 +- python/oneflow/test/modules/test_repeat.py | 2 +- python/oneflow/test/modules/test_reshape.py | 4 +- python/oneflow/test/modules/test_round.py | 2 +- python/oneflow/test/modules/test_sign.py | 6 +- python/oneflow/test/modules/test_split.py | 12 +- .../test/modules/test_sqrt_square_sum.py | 58 + python/oneflow/test/modules/test_squeeze.py | 6 +- python/oneflow/test/modules/test_stack.py | 2 +- .../test_stateful_kernel_with_cache.py | 48 + .../modules/test_stateful_local_opkernel.py | 39 +- python/oneflow/test/modules/test_std.py | 4 +- python/oneflow/test/modules/test_sub.py | 2 +- python/oneflow/test/modules/test_sum.py | 4 +- .../oneflow/test/modules/test_tensor_str.py | 10 +- python/oneflow/test/modules/test_tensor_to.py | 9 + python/oneflow/test/modules/test_tile.py | 4 +- python/oneflow/test/modules/test_transpose.py | 4 +- python/oneflow/test/modules/test_tril.py | 4 +- python/oneflow/test/modules/test_triu.py | 4 +- .../test/modules/test_unfold_tensor.py | 2 +- python/oneflow/test/modules/test_unsqueeze.py | 6 +- python/oneflow/test/modules/test_var.py | 5 +- python/oneflow/test/modules/test_where.py | 20 +- python/oneflow/test/tensor/test_parameter.py | 4 +- python/oneflow/test/tensor/test_tensor.py | 50 +- .../test/tensor/test_tensor_indexing.py | 19 +- .../automated_test_util/generators.py | 18 + 
.../torch_flow_dual_object.py | 125 +- python/oneflow/utils/data/_utils/__init__.py | 3 +- python/oneflow/utils/data/_utils/worker.py | 3 +- python/oneflow/utils/data/dataloader.py | 5 +- python/oneflow/utils/data/distributed.py | 2 +- python/oneflow/utils/data/sampler.py | 2 +- .../generate_dispatch_stateful_ops.py | 185 + tools/functional/generator.py | 18 +- tools/oneflow-tblgen/CMakeLists.txt | 46 + tools/oneflow-tblgen/backends.h | 39 + tools/oneflow-tblgen/example/constant.td | 17 + tools/oneflow-tblgen/op_schema_emitter.cpp | 242 + tools/oneflow-tblgen/op_schema_header.inc | 100 + tools/oneflow-tblgen/op_schema_source.inc | 106 + tools/oneflow-tblgen/op_schema_types.inc | 14 + tools/oneflow-tblgen/tablegen.cpp | 104 + tools/package_mirror.py | 1 + 815 files changed, 38707 insertions(+), 18043 deletions(-) create mode 100644 ci/CMakeLists.txt create mode 100644 ci/clang/build-llvm.sh create mode 100644 ci/test/CMakeLists.txt create mode 100644 ci/test/resource-spec/1x-gtx-1080.json create mode 100644 ci/test/resource-spec/2x-rtx-2080.json create mode 100644 ci/test/resource-spec/4x-rtx-2080ti.json create mode 100644 cmake/caches/ci/llvm/cuda-75-clang.cmake create mode 100644 cmake/caches/cn/fast/mlir-cuda-61.cmake create mode 100644 cmake/op_schema.cmake create mode 100644 docs/source/env.rst delete mode 100644 oneflow/api/common/device.cpp rename oneflow/api/common/{device.h => ir_pass.cpp} (63%) create mode 100644 oneflow/api/common/job_build_and_infer_ctx.h create mode 100644 oneflow/api/common/sbp.h create mode 100644 oneflow/api/common/scope.h create mode 100644 oneflow/api/cpp/framework/graph.cpp create mode 100644 oneflow/api/cpp/framework/graph.h create mode 100644 oneflow/api/cpp/framework/ivalue.cpp create mode 100644 oneflow/api/cpp/framework/ivalue.h create mode 100644 oneflow/api/cpp/tests/graph_test.cpp create mode 100644 oneflow/api/cpp/tests/graph_test_model/affine_no_parameter/model.mlir create mode 100644 oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.a/meta create mode 100644 oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.a/out create mode 100644 oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.b/meta create mode 100644 oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.b/out create mode 100644 oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.mlir create mode 100644 oneflow/api/cpp/tests/ivalue_test.cpp create mode 100644 oneflow/api/python/functional/dispatch_stateful_ops.cpp create mode 100644 oneflow/api/python/functional/dispatch_stateful_ops.yaml create mode 100644 oneflow/core/autograd/gradient_funcs/cumsum.cpp create mode 100644 oneflow/core/autograd/gradient_funcs/diagonal.cpp create mode 100644 oneflow/core/boxing/unflatten_hierarchy.cpp create mode 100644 oneflow/core/device/ep_based_event_record.h create mode 100644 oneflow/core/ep/common/primitive/binary_functor.h create mode 100644 oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h create mode 100644 oneflow/core/ep/common/primitive/broadcast_simplify_dims_test.cpp create mode 100644 oneflow/core/ep/common/primitive/util.h create mode 100644 oneflow/core/ep/cpu/primitive/binary_functor.h create mode 100644 oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp create mode 100644 oneflow/core/ep/cuda/primitive/binary_functor.cuh create mode 100644 oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cu create mode 100644 oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh 
create mode 100644 oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_comparision.cu create mode 100644 oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_logical.cu create mode 100644 oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math.cu create mode 100644 oneflow/core/framework/consistency_check.cpp create mode 100644 oneflow/core/framework/consistency_check.h delete mode 100644 oneflow/core/framework/data_consistency_check.cpp create mode 100644 oneflow/core/framework/op_attrs.cpp create mode 100644 oneflow/core/framework/op_attrs.h create mode 100644 oneflow/core/framework/op_base.h create mode 100644 oneflow/core/framework/op_interp_ctx.cpp create mode 100644 oneflow/core/framework/op_interp_ctx.h create mode 100644 oneflow/core/framework/system_ops.cpp create mode 100644 oneflow/core/framework/system_ops.h create mode 100644 oneflow/core/framework/tensor_util.cpp create mode 100644 oneflow/core/framework/tensor_util.h create mode 100644 oneflow/core/functional/impl/eye_functor.cpp create mode 100644 oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp create mode 100644 oneflow/core/job/critical_section_instance.h create mode 100644 oneflow/core/job/job_ir.cpp rename oneflow/core/{framework/data_consistency_check.h => job/job_ir.h} (62%) create mode 100644 oneflow/core/kernel/critical_section_callback_tick_kernel.cpp create mode 100644 oneflow/core/kernel/critical_section_wait_tick_kernel.cpp create mode 100644 oneflow/core/operator/critical_section_callback_tick_op.cpp create mode 100644 oneflow/core/operator/critical_section_wait_tick_op.cpp create mode 100644 oneflow/ir/include/OneFlow/OneFlowOpTraits.h create mode 100644 oneflow/ir/include/OneFlow/OneFlowUserOps.td create mode 100644 oneflow/ir/install-llvm.cmake create mode 100644 oneflow/ir/llvm-in-tree.cmake delete mode 100644 oneflow/ir/oneflow-gen-ods/oneflow-gen-ods.cpp create mode 100644 oneflow/ir/oneflow-runtime/CMakeLists.txt create mode 100644 oneflow/ir/oneflow-runtime/lib/CMakeLists.txt create mode 100644 oneflow/ir/oneflow-runtime/lib/Runtime.cpp create mode 100644 oneflow/ir/test/OneFlow/networks/__init__.py create mode 100644 oneflow/ir/test/OneFlow/networks/resnet50.py create mode 100644 oneflow/ir/test/OneFlow/test_graph_save_and_load.py delete mode 100644 oneflow/ir/test/OneFlow/test_mlir_opt.mlir rename oneflow/ir/test/OneFlow/{test_mlir_opt.mlir.py => test_mlir_opt.py} (85%) create mode 100644 oneflow/user/kernels/cumsum_kernel.cpp create mode 100644 oneflow/user/kernels/cumsum_kernel.cu create mode 100644 oneflow/user/kernels/diagonal_kernel.cpp create mode 100644 oneflow/user/kernels/diagonal_kernel.cu create mode 100644 oneflow/user/kernels/group_deconv_kernel.cpp rename oneflow/user/kernels/{op_kernel_state_wrapper.h => op_kernel_wrapper.h} (77%) create mode 100644 oneflow/user/kernels/sqrt_square_sum_kernel.cpp create mode 100644 oneflow/user/kernels/sqrt_square_sum_kernel_util.cpp create mode 100644 oneflow/user/kernels/sqrt_square_sum_kernel_util.cu create mode 100644 oneflow/user/kernels/sqrt_square_sum_kernel_util.h create mode 100644 oneflow/user/ops/cumsum_op.cpp create mode 100644 oneflow/user/ops/diagonal_op.cpp create mode 100644 oneflow/user/ops/sqrt_square_sum_op.cpp delete mode 100644 python/oneflow/compatible/single_client/framework/op_expr_util.py rename python/oneflow/compatible/single_client/ops/{builtin_ops.py => stateful_ops.py} (69%) delete mode 100644 python/oneflow/compatible/single_client/test/ops/test_multi_process.py create mode 
100644 python/oneflow/framework/docstr/arange.py rename python/oneflow/{nn/modules => framework/docstr}/broadcast_like.py (51%) rename python/oneflow/{nn/modules => framework/docstr}/chunk.py (71%) rename python/oneflow/{nn/modules => framework/docstr}/split.py (80%) delete mode 100644 python/oneflow/framework/op_expr_util.py create mode 100644 python/oneflow/framework/tensor_str_util.py rename python/oneflow/nn/modules/{eye.py => linspace.py} (54%) delete mode 100644 python/oneflow/nn/modules/norm.py create mode 100644 python/oneflow/nn/optimizer/sparse_optimizer.py rename python/oneflow/ops/{builtin_ops.py => stateful_ops.py} (66%) create mode 100644 python/oneflow/optim/utils.py create mode 100644 python/oneflow/test/graph/test_graph_eye.py create mode 100644 python/oneflow/test/graph/test_graph_inplace_add.py create mode 100644 python/oneflow/test/graph/test_graph_reuse_var.py create mode 100644 python/oneflow/test/graph/test_graph_sparse_optimizer.py create mode 100644 python/oneflow/test/modules/test_cumsum.py create mode 100644 python/oneflow/test/modules/test_diagonal.py create mode 100644 python/oneflow/test/modules/test_from_numpy.py create mode 100644 python/oneflow/test/modules/test_linspace.py create mode 100644 python/oneflow/test/modules/test_module_to_consistent.py rename python/oneflow/test/modules/{test_l2_normalize.py => test_normalize.py} (73%) create mode 100644 python/oneflow/test/modules/test_sqrt_square_sum.py create mode 100644 python/oneflow/test/modules/test_stateful_kernel_with_cache.py create mode 100644 tools/functional/generate_dispatch_stateful_ops.py create mode 100644 tools/oneflow-tblgen/CMakeLists.txt create mode 100644 tools/oneflow-tblgen/backends.h create mode 100644 tools/oneflow-tblgen/example/constant.td create mode 100644 tools/oneflow-tblgen/op_schema_emitter.cpp create mode 100644 tools/oneflow-tblgen/op_schema_header.inc create mode 100644 tools/oneflow-tblgen/op_schema_source.inc create mode 100644 tools/oneflow-tblgen/op_schema_types.inc create mode 100644 tools/oneflow-tblgen/tablegen.cpp diff --git a/.github/workflows/canary.yml b/.github/workflows/canary.yml index 79799e69d00..eb999863bf9 100644 --- a/.github/workflows/canary.yml +++ b/.github/workflows/canary.yml @@ -4,7 +4,7 @@ on: push: branches: - master - - add-canary-release + - add-support-clang-12 workflow_dispatch: inputs: oneflow-ref: @@ -43,7 +43,7 @@ jobs: - name: Checkout Oneflow-Inc/oneflow if: ${{ github.event.inputs.oneflow-ref == '' }} uses: actions/checkout@v2 - - uses: Oneflow-Inc/get-oneflow@canary-release + - uses: Oneflow-Inc/get-oneflow@support-clang-12 name: Build manylinux id: build-cuda with: diff --git a/.github/workflows/simple.yml b/.github/workflows/simple.yml index 499d552e76f..35b5d06e154 100644 --- a/.github/workflows/simple.yml +++ b/.github/workflows/simple.yml @@ -50,7 +50,7 @@ jobs: cmake .. -C ../cmake/caches/international/cpu.cmake \ -DCMAKE_BUILD_TYPE=Release \ -DBUILD_TESTING=ON - cmake --build . -j$(nproc) --target oneflow_deps of_cfgobj of_protoobj of_functional_obj of_functional_tensor_obj + cmake --build . 
-j$(nproc) --target oneflow_deps of_cfgobj of_protoobj of_functional_obj of_functional_tensor_obj of_op_schema - name: Run clang-tidy for all translation units # use clang as compiler for correct compiler flags run: | @@ -247,7 +247,7 @@ jobs: repository: Oneflow-Inc/conda-env ref: 30a7f00eb48ee9009d85a848e720823e5054c66b path: conda-env - - uses: Oneflow-Inc/get-oneflow@canary-release + - uses: Oneflow-Inc/get-oneflow@support-clang-12 name: Build with gcc7 if: ${{ matrix.build-type == 'gcc7'}} with: @@ -256,7 +256,7 @@ jobs: oneflow-build-env: conda conda-env-file: conda-env/dev/gcc7/environment-v2.yml conda-env-name: oneflow-dev-gcc7-v2 - - uses: Oneflow-Inc/get-oneflow@canary-release + - uses: Oneflow-Inc/get-oneflow@support-clang-12 name: Build with clang10 if: ${{ matrix.build-type == 'clang10'}} with: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 03715abcc0e..b0085586049 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -182,7 +182,7 @@ jobs: with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{github.event.pull_request.head.repo.full_name}} - - uses: Oneflow-Inc/get-oneflow/cache-complete/matrix/build@canary-release + - uses: Oneflow-Inc/get-oneflow/cache-complete/matrix/build@support-clang-12 name: find cache id: find-cache timeout-minutes: 5 @@ -228,7 +228,7 @@ jobs: with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{github.event.pull_request.head.repo.full_name}} - - uses: Oneflow-Inc/get-oneflow/cache-complete@canary-release + - uses: Oneflow-Inc/get-oneflow/cache-complete@support-clang-12 name: Save cache if successful id: save-cache timeout-minutes: 5 @@ -242,7 +242,7 @@ jobs: run: | echo "::error file=test.yml,line=204,col=10::steps.save-cache.outputs.cache-hit != matrix.cache-hit" exit 1 - - uses: Oneflow-Inc/get-oneflow@canary-release + - uses: Oneflow-Inc/get-oneflow@support-clang-12 name: Build manylinux cpu only id: build-cpu if: ${{ matrix.entry =='cpu' && !matrix.cache-hit }} @@ -263,7 +263,7 @@ jobs: python-versions: | 3.6 3.7 - - uses: Oneflow-Inc/get-oneflow@canary-release + - uses: Oneflow-Inc/get-oneflow@support-clang-12 name: Build manylinux cu102 id: build-cuda if: ${{ matrix.entry =='cu102' && !matrix.cache-hit }} @@ -284,7 +284,7 @@ jobs: python-versions: | 3.6 3.7 - - uses: Oneflow-Inc/get-oneflow@canary-release + - uses: Oneflow-Inc/get-oneflow@support-clang-12 name: Build manylinux cu101_xla id: build-xla if: ${{ matrix.entry =='cu101_xla' && !matrix.cache-hit && needs.changed_files.outputs.should_run_single_client_tests == '1' }} @@ -306,7 +306,7 @@ jobs: 3.6 - name: Upload bin if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') && (steps.build-cpu.outcome == 'success' || steps.build-cuda.outcome == 'success' || steps.build-xla.outcome == 'success') }} - uses: Oneflow-Inc/get-oneflow/digest/upload@canary-release + uses: Oneflow-Inc/get-oneflow/digest/upload@support-clang-12 timeout-minutes: 10 with: digest: ${{ steps.save-cache.outputs.build-digest }} @@ -315,9 +315,20 @@ jobs: ssh-tank-path: ${{ env.SSH_TANK_PATH }} src-dir: ${{ env.MANYLINUX_CACHE_DIR }}/build/bin dst-dir: bin + - name: Upload liboneflow_cpp library + if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') && (steps.build-cpu.outcome == 'success' || steps.build-cuda.outcome == 'success') }} + uses: Oneflow-Inc/get-oneflow/digest/upload@support-clang-12 + timeout-minutes: 10 + with: + digest: ${{ steps.save-cache.outputs.build-digest }} + entry: ${{ 
matrix.entry }} + ssh-tank-host: ${{ env.SSH_TANK_HOST }} + ssh-tank-path: ${{ env.SSH_TANK_PATH }} + src-dir: ${{ env.MANYLINUX_CACHE_DIR }}/build/liboneflow_cpp/lib + dst-dir: liboneflow_cpp/lib - name: Upload whl if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') && (steps.build-cpu.outcome == 'success' || steps.build-cuda.outcome == 'success' || steps.build-xla.outcome == 'success') }} - uses: Oneflow-Inc/get-oneflow/digest/upload@canary-release + uses: Oneflow-Inc/get-oneflow/digest/upload@support-clang-12 timeout-minutes: 10 with: digest: ${{ steps.save-cache.outputs.build-digest }} @@ -331,6 +342,11 @@ jobs: name: Build with clang if: github.event.pull_request.draft == false && github.base_ref == 'master' && contains(github.event.pull_request.requested_reviewers.*.login, 'oneflow-ci-bot') runs-on: [self-hosted, linux, build] + env: + ONEFLOW_SRC: . + MANYLINUX_CACHE_DIR: ~/manylinux-cache-dir/clang13 + CUDA_VERSION: "10.1" + WHEELHOUSE_DIR: ./wheelhouse steps: - name: Fix permissions run: | @@ -338,7 +354,7 @@ jobs: docker run --rm -v $PWD:$PWD -w $PWD busybox rm -rf * - name: Checkout Oneflow-Inc/oneflow uses: actions/checkout@v2 - - uses: Oneflow-Inc/get-oneflow/cache-complete@canary-release + - uses: Oneflow-Inc/get-oneflow/cache-complete@support-clang-12 name: Save cache if successful id: save-cache timeout-minutes: 5 @@ -347,25 +363,26 @@ jobs: entry: build-with-clang digest-type: build mark-as-completed: ${{ github.event.pull_request.head.repo.full_name == github.repository }} - - name: Checkout Oneflow-Inc/conda-env - if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) }} - uses: actions/checkout@v2 - with: - repository: Oneflow-Inc/conda-env - ref: 30a7f00eb48ee9009d85a848e720823e5054c66b - path: conda-env - - uses: Oneflow-Inc/get-oneflow@canary-release + - name: Build with Clang + uses: Oneflow-Inc/get-oneflow@support-clang-12 if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) }} - name: Build with clang10 with: - cmake-init-cache: cmake/caches/ci/gh-hosted/cpu-clang.cmake - oneflow-src: . 
- oneflow-build-env: conda - conda-env-file: conda-env/dev/clang10/environment-v2.yml - conda-env-name: oneflow-dev-clang10-v2 - conda-installer-url: https://oneflow-static.oss-cn-beijing.aliyuncs.com/downloads/conda-installers/Miniconda3-py39_4.10.3-Linux-x86_64.sh - conda-prefix: ~/miniconda3-prefixes/py39_4.10.3 + cmake-init-cache: ${{ env.ONEFLOW_SRC }}/cmake/caches/ci/llvm/cuda-75-clang.cmake + build-script: ${{ env.ONEFLOW_SRC }}/ci/clang/build-llvm.sh + oneflow-src: ${{ env.ONEFLOW_SRC }} + oneflow-build-env: llvm + wheelhouse-dir: ${{ env.WHEELHOUSE_DIR }} + clear-wheelhouse-dir: true self-hosted: true + cuda-version: ${{ env.CUDA_VERSION }} + manylinux-cache-dir: ${{ env.MANYLINUX_CACHE_DIR }} + docker-run-use-system-http-proxy: false + docker-run-use-lld: false + retry-failed-build: true + clean-ccache: ${{ contains(github.event.pull_request.labels.*.name, 'need-clean-ccache') }} + wheel-audit: false + python-versions: | + 3.8 find-test-cache: name: "Find test cache" @@ -382,7 +399,7 @@ jobs: with: ref: ${{ github.event.pull_request.head.sha }} repository: ${{github.event.pull_request.head.repo.full_name}} - - uses: Oneflow-Inc/get-oneflow/cache-complete/matrix/test@canary-release + - uses: Oneflow-Inc/get-oneflow/cache-complete/matrix/test@support-clang-12 name: find cache id: find-cache timeout-minutes: 5 @@ -424,7 +441,7 @@ jobs: if: ${{ contains(matrix.runs-on, 'self-hosted') }} run: | docker rm -f ${{ env.TEST_CONTAINER_NAME }} || true - - uses: Oneflow-Inc/get-oneflow/cache-complete@canary-release + - uses: Oneflow-Inc/get-oneflow/cache-complete@support-clang-12 name: Save cache if successful id: save-cache timeout-minutes: 5 @@ -438,9 +455,9 @@ jobs: run: | echo "::error file=test.yml,line=204,col=10::steps.save-cache.outputs.cache-hit != matrix.cache-hit" exit 1 - - name: Download wheel and binary + - name: Download wheel, binary and liboneflow_cpp lib if: ${{ !fromJson(matrix.cache-hit) && contains(matrix.runs-on, 'self-hosted') && (!fromJson(matrix.is-xla) || (fromJson(matrix.is-xla) && needs.changed_files.outputs.should_run_single_client_tests == '1')) }} - uses: Oneflow-Inc/get-oneflow/digest/download@canary-release + uses: Oneflow-Inc/get-oneflow/digest/download@support-clang-12 id: download-digest timeout-minutes: 10 with: @@ -492,6 +509,7 @@ jobs: working-directory: ${{ env.ONEFLOW_SRC }} env: ONEFLOW_BIN_PATH: ${{ steps.download-digest.outputs.entry-dir }}/bin + ONEFLOW_CPP_API_LIB_PATH: ${{ steps.download-digest.outputs.entry-dir }}/liboneflow_cpp/lib run: | docker run -d --rm --privileged --shm-size=8g \ --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ @@ -499,6 +517,7 @@ jobs: -v /dataset:/dataset:ro -v /model_zoo:/model_zoo:ro \ -v ${ONEFLOW_WHEEL_PATH}:${ONEFLOW_WHEEL_PATH}:ro \ -v ${ONEFLOW_BIN_PATH}:${ONEFLOW_BIN_PATH}:ro \ + -v ${ONEFLOW_CPP_API_LIB_PATH}:${ONEFLOW_CPP_API_LIB_PATH}:ro \ -v $HOME/test-container-cache/dot-local:/root/.local \ -v $HOME/test-container-cache/dot-cache:/root/.cache \ -e ONEFLOW_WHEEL_PATH=${ONEFLOW_WHEEL_PATH} \ @@ -527,11 +546,13 @@ jobs: run: | docker exec ${{ env.TEST_CONTAINER_NAME }} python3 -m oneflow --doctor - name: Exe test - if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' }} + if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cpu' }} timeout-minutes: 10 run: | chmod +x ${{ steps.download-digest.outputs.entry-dir }}/bin/oneflow_testexe docker exec ${{ env.TEST_CONTAINER_NAME }} ${{ steps.download-digest.outputs.entry-dir }}/bin/oneflow_testexe + chmod +x 
${{ steps.download-digest.outputs.entry-dir }}/bin/oneflow_cpp_api_testexe + docker exec -e LD_LIBRARY_PATH=${{ steps.download-digest.outputs.entry-dir }}/liboneflow_cpp/lib ${{ env.TEST_CONTAINER_NAME }} ${{ steps.download-digest.outputs.entry-dir }}/bin/oneflow_cpp_api_testexe - name: Build documentation timeout-minutes: 10 if: ${{ !fromJson(matrix.cache-hit) && matrix.test-type == 'misc' && matrix.device == 'cpu' }} @@ -744,7 +765,7 @@ jobs: ref: ${{ github.event.pull_request.head.sha }} repository: ${{github.event.pull_request.head.repo.full_name}} fetch-depth: 0 - - uses: Oneflow-Inc/get-oneflow/cache-complete@canary-release + - uses: Oneflow-Inc/get-oneflow/cache-complete@support-clang-12 name: Save cache if successful id: save-cache timeout-minutes: 5 @@ -785,7 +806,7 @@ jobs: -DBUILD_TESTING=ON \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache - cmake --build . -j$(nproc) --target oneflow_deps of_cfgobj of_protoobj of_functional_obj of_functional_tensor_obj + cmake --build . -j$(nproc) --target oneflow_deps of_cfgobj of_protoobj of_functional_obj of_functional_tensor_obj of_op_schema - name: Fetch upstream if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) && github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name }} run: | diff --git a/CMakeLists.txt b/CMakeLists.txt index 3c963b59e5e..4ce9d765882 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,7 @@ # Minimum CMake required cmake_minimum_required(VERSION 3.18.0) +set(CMAKE_INSTALL_MESSAGE LAZY CACHE STRING "") if (NOT CMAKE_BUILD_TYPE) message(STATUS "No build type selected, default to Release") set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type (default Release)" FORCE) @@ -23,7 +24,8 @@ endif() option(USE_CLANG_FORMAT "" OFF) option(USE_CLANG_TIDY "" OFF) option(BUILD_PYTHON "" ON) -option(BUILD_MONOLITHIC_LIBONEFLOW "" ON) +option(BUILD_CPP_API "Option to build OneFlow C++ API (beta)" OFF) +option(BUILD_MONOLITHIC_LIBONEFLOW_CPP_SO "Option to build a monolithic liboneflow_cpp.so (only meaningful when BUILD_CPP_API is ON)" ON) option(BUILD_RDMA "" OFF) option(BUILD_CUDA "" ON) option(WITH_ONEDNN "" OFF) @@ -33,7 +35,12 @@ option(WITH_TENSORRT "Option to build with TensorRT" OFF) option(WITH_OPENVINO "Option to build with OpenVINO" OFF) option(WITH_MLIR "" OFF) option(WITH_MLIR_CUDA_CODEGEN "" OFF) +set(LLVM_PROVIDER "in-tree" CACHE STRING "in-tree, install") +if (NOT WITH_MLIR) + set(LLVM_PROVIDER "install" CACHE STRING "in-tree will build LLVM's ALL, not what we want when not building MLIR" FORCE) +endif(NOT WITH_MLIR) option(WITH_COCOAPI "Option to build with COCO API" ON) +option(WITH_ZLIB "" ON) option(BUILD_GIT_VERSION "" ON) option(BUILD_PROFILER "" OFF) option(OF_SOFTMAX_USE_FAST_MATH "" ON) @@ -201,28 +208,22 @@ endif() if(BUILD_PYTHON) set(ONEFLOW_INCLUDE_DIR "${ONEFLOW_PYTHON_DIR}/oneflow/include") -else() # build_python - set(ONEFLOW_INCLUDE_DIR "${PROJECT_BINARY_DIR}/liboneflow/include/oneflow") - set(ONEFLOW_LIBRARY_DIR "${PROJECT_BINARY_DIR}/liboneflow/lib") - set(ONEFLOW_SHARE_DIR "${PROJECT_BINARY_DIR}/liboneflow/share") - make_directory(${ONEFLOW_INCLUDE_DIR}) - make_directory(${ONEFLOW_LIBRARY_DIR}) - make_directory(${ONEFLOW_SHARE_DIR}) +endif(BUILD_PYTHON) + +if(BUILD_CPP_API) + set(LIBONEFLOW_LIBRARY_DIR "${PROJECT_BINARY_DIR}/liboneflow_cpp/lib") + set(LIBONEFLOW_SHARE_DIR "${PROJECT_BINARY_DIR}/liboneflow_cpp/share") + make_directory(${LIBONEFLOW_LIBRARY_DIR}) + make_directory(${LIBONEFLOW_SHARE_DIR}) if(BUILD_SHARED_LIBS) 
- if(BUILD_MONOLITHIC_LIBONEFLOW) - set(BUILD_SHARED_LIBS OFF) + if(BUILD_MONOLITHIC_LIBONEFLOW_CPP_SO) + message(FATAL_ERROR "BUILD_MONOLITHIC_LIBONEFLOW_CPP_SO is incompatible with BUILD_SHARED_LIBS. Please set either of them to OFF.") else() - set(LIBRARY_OUTPUT_PATH ${ONEFLOW_LIBRARY_DIR}) - endif(BUILD_MONOLITHIC_LIBONEFLOW) - set(BUILD_SHARED_LIBONEFLOW ON) - else() - if(BUILD_MONOLITHIC_LIBONEFLOW) - message(WARNING "BUILD_MONOLITHIC_LIBONEFLOW=ON is meaningless when BUILD_SHARED_LIBS=OFF") - endif() - set(BUILD_SHARED_LIBONEFLOW OFF) + set(LIBRARY_OUTPUT_PATH ${LIBONEFLOW_LIBRARY_DIR}) + endif(BUILD_MONOLITHIC_LIBONEFLOW_CPP_SO) endif(BUILD_SHARED_LIBS) -endif(BUILD_PYTHON) +endif(BUILD_CPP_API) include(third_party) @@ -261,7 +262,7 @@ if (BUILD_CUDA) if ("${CMAKE_CUDA_COMPILER_ID}" STREQUAL "NVIDIA") if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "11.2") - set(CUDA_NVCC_THREADS_NUMBER "1" CACHE STRING "") + set(CUDA_NVCC_THREADS_NUMBER "4" CACHE STRING "") list(APPEND CUDA_NVCC_FLAGS -t ${CUDA_NVCC_THREADS_NUMBER}) endif() message(STATUS "CUDA_NVCC_FLAGS: " ${CUDA_NVCC_FLAGS}) @@ -276,3 +277,4 @@ add_custom_target(oneflow_deps ALL DEPENDS prepare_oneflow_third_party) if (ONEFLOW) include(oneflow) endif() +add_subdirectory(ci) diff --git a/ci/CMakeLists.txt b/ci/CMakeLists.txt new file mode 100644 index 00000000000..552439ebc59 --- /dev/null +++ b/ci/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(test) diff --git a/ci/clang/build-llvm.sh b/ci/clang/build-llvm.sh new file mode 100644 index 00000000000..1200d26f079 --- /dev/null +++ b/ci/clang/build-llvm.sh @@ -0,0 +1,28 @@ +set -ex +export PATH=/usr/lib/llvm-12/bin:/usr/lib/llvm-13/bin:/usr/lib64/ccache:/root/.local/bin:$PATH + +# clean python dir +cd ${ONEFLOW_CI_SRC_DIR} +${ONEFLOW_CI_PYTHON_EXE} -m pip install -i https://mirrors.aliyun.com/pypi/simple --user -r ci/fixed-dev-requirements.txt +cd python +git clean -nXd -e \!dist -e \!dist/** +git clean -fXd -e \!dist -e \!dist/** + +# cmake config +mkdir -p ${ONEFLOW_CI_BUILD_DIR} +cd ${ONEFLOW_CI_BUILD_DIR} +find ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt +find ${ONEFLOW_CI_BUILD_DIR} -name CMakeCache.txt -delete +if [ ! -f "$ONEFLOW_CI_CMAKE_INIT_CACHE" ]; then + echo "$ONEFLOW_CI_CMAKE_INIT_CACHE does not exist." + exit 1 +fi +cmake -S ${ONEFLOW_CI_SRC_DIR} -C ${ONEFLOW_CI_CMAKE_INIT_CACHE} -DPython3_EXECUTABLE=${ONEFLOW_CI_PYTHON_EXE} +# cmake build +cd ${ONEFLOW_CI_BUILD_DIR} +cmake --build . 
-j $(nproc) + +# build pip +cd ${ONEFLOW_CI_SRC_DIR} +cd python +${ONEFLOW_CI_PYTHON_EXE} setup.py bdist_wheel diff --git a/ci/test/1node_op_test.sh b/ci/test/1node_op_test.sh index 0c3057d8365..2b0214fa403 100644 --- a/ci/test/1node_op_test.sh +++ b/ci/test/1node_op_test.sh @@ -37,5 +37,3 @@ then else echo "deadlock unsolved, skipping multi-card eager" fi - -ONEFLOW_TEST_MULTI_PROCESS=1 python3 test/ops/test_multi_process.py --failfast --verbose diff --git a/ci/test/2node_op_test_multi_client.sh b/ci/test/2node_op_test_multi_client.sh index 33efe6c8ce3..148ae2fe374 100755 --- a/ci/test/2node_op_test_multi_client.sh +++ b/ci/test/2node_op_test_multi_client.sh @@ -17,7 +17,7 @@ cd ${test_tmp_dir}/$(basename $test_dir) for device_num in 1 2 4 do - ONEFLOW_TEST_NODE_NUM=2 ONEFLOW_TEST_DEVICE_NUM=$device_num python3 -m oneflow.distributed.launch --nproc_per_node $device_num --nnodes=2 --node_rank=$NODE_RANK --master_addr 192.168.1.12 -m unittest discover ${PWD} --failfast --verbose + ONEFLOW_TEST_NODE_NUM=2 ONEFLOW_TEST_DEVICE_NUM=$device_num python3 -m oneflow.distributed.launch --nproc_per_node $device_num --nnodes=2 --node_rank=$NODE_RANK --master_addr $_MASTER_ADDR -m unittest discover ${PWD} --failfast --verbose # use a invalid ibverbs lib to test if falling back to epoll works - ONEFLOW_TEST_NODE_NUM=2 ONEFLOW_TEST_DEVICE_NUM=$device_num ONEFLOW_LIBIBVERBS_PATH=invalid_lib python3 -m oneflow.distributed.launch --nproc_per_node $device_num --nnodes=2 --node_rank=$NODE_RANK --master_addr 192.168.1.12 -m unittest discover ${PWD} --failfast --verbose + ONEFLOW_TEST_NODE_NUM=2 ONEFLOW_TEST_DEVICE_NUM=$device_num ONEFLOW_LIBIBVERBS_PATH=invalid_lib python3 -m oneflow.distributed.launch --nproc_per_node $device_num --nnodes=2 --node_rank=$NODE_RANK --master_addr $_MASTER_ADDR -m unittest discover ${PWD} --failfast --verbose done diff --git a/ci/test/CMakeLists.txt b/ci/test/CMakeLists.txt new file mode 100644 index 00000000000..1f7871e80a7 --- /dev/null +++ b/ci/test/CMakeLists.txt @@ -0,0 +1,25 @@ +set(PYTHON_EXECUTABLE python3 CACHE STRING "python3 exe to run test, usually is the python3 installation oneflow is linked to") +set(ONEFLOW_SRC_DIR ${CMAKE_SOURCE_DIR} CACHE STRING "source dir of oneflow") +set(IS_DEV ON CACHE BOOL "") +set(CTEST_RESOURCE_SPEC_FILE "${CMAKE_CURRENT_SOURCE_DIR}/resource-spec/2x-rtx-2080.json" CACHE STRING "") + +# CTEST_OUTPUT_ON_FAILURE=1 CTEST_PARALLEL_LEVEL=20 ninja test + +file(GLOB_RECURSE PYTHON_TEST_FILES LIST_DIRECTORIES false RELATIVE ${ONEFLOW_SRC_DIR} "${ONEFLOW_SRC_DIR}/python/oneflow/test_*.py") +foreach(PYTHON_TEST_FILE ${PYTHON_TEST_FILES}) + set(TEST_NAME ${PYTHON_TEST_FILE}) + add_test(NAME ${TEST_NAME} + COMMAND ${PYTHON_EXECUTABLE} ${ONEFLOW_SRC_DIR}/${PYTHON_TEST_FILE} --failfast --verbose + ) + set_tests_properties(${TEST_NAME} + PROPERTIES + ENVIRONMENT "$<$>:ONEFLOW_TEST_CPU_ONLY=1>;$<$:PYTHONPATH=${ONEFLOW_SRC_DIR}/python:$ENV{PYTHONPATH}>" + RESOURCE_GROUPS + "vram:2000" + ) +endforeach() +set_tests_properties(python/oneflow/test/modules/test_rnn.py + PROPERTIES + RESOURCE_GROUPS + "vram:4000" +) diff --git a/ci/test/distributed_run.py b/ci/test/distributed_run.py index 4002befd658..55e97d9edde 100644 --- a/ci/test/distributed_run.py +++ b/ci/test/distributed_run.py @@ -60,13 +60,13 @@ def find_free_port(): return s.getsockname()[1] -async def spawn_shell_and_check(cmd: str = None): +async def spawn_shell(cmd: str = None): p = await asyncio.create_subprocess_shell(cmd,) await p.wait() assert p.returncode == 0, cmd -async def spawn_shell(cmd: 
str = None): +async def spawn_shell_ignoring_failure(cmd: str = None): p = await asyncio.create_subprocess_shell(cmd,) await p.wait() @@ -74,34 +74,32 @@ async def spawn_shell(cmd: str = None): async def build_docker_img(remote_host=None, workspace_dir=None): if remote_host: assert workspace_dir - await spawn_shell_and_check("rm -f > oneflow-src.zip") - await spawn_shell_and_check("git archive --format zip HEAD > oneflow-src.zip") - await spawn_shell_and_check( + await spawn_shell("rm -f > oneflow-src.zip") + await spawn_shell("git archive --format zip HEAD > oneflow-src.zip") + await spawn_shell( f"scp oneflow-src.zip {remote_host}:{workspace_dir}/oneflow-src.zip", ) - await spawn_shell_and_check( + await spawn_shell( f"ssh {remote_host} unzip {workspace_dir}/oneflow-src.zip -d {workspace_dir}/oneflow-src", ) - await spawn_shell_and_check( + await spawn_shell( f"ssh {remote_host} bash {workspace_dir}/oneflow-src/docker/ci/test/build.sh", ) else: - await spawn_shell_and_check(f"bash docker/ci/test/build.sh") + await spawn_shell(f"bash docker/ci/test/build.sh") async def create_remote_workspace_dir( remote_host=None, workspace_dir=None, copy_files=None ): - await spawn_shell_and_check(f"ssh {remote_host} mkdir -p {workspace_dir}") + await spawn_shell(f"ssh {remote_host} mkdir -p {workspace_dir}") if copy_files is not None: for path in copy_files: # Reference: https://stackoverflow.com/a/31278462 if os.path.isdir(path) and path[-1] != "/": path += "/" - await spawn_shell_and_check( - f"ssh {remote_host} mkdir -p {workspace_dir}/{path}" - ) - await spawn_shell_and_check( + await spawn_shell(f"ssh {remote_host} mkdir -p {workspace_dir}/{path}") + await spawn_shell( f"rsync -azPq --omit-dir-times --no-perms --no-group --copy-links --exclude='__pycache__' {path} {remote_host}:{workspace_dir}/{path}" ) print("create_remote_workspace_dir done") @@ -126,9 +124,17 @@ async def launch_remote_container( oneflow_python_path=None, cmd=None, node_rank=None, + master_addr=None, ): print("launching remote container at", remote_host) assert img_tag + multi_client_args = [node_rank, master_addr] + multi_client_arg_has_value = [x is not None for x in multi_client_args] + if any(multi_client_arg_has_value): + assert all(multi_client_arg_has_value) + is_multi_client = True + else: + is_multi_client = False pythonpath_args = None if oneflow_wheel_path: pythonpath_args = "" @@ -138,25 +144,28 @@ async def launch_remote_container( raise ValueError("must have oneflow_wheel_path or oneflow_python_path") docker_cmd = f"""docker run --privileged -d --network host --shm-size=8g --rm {get_docker_cache_args()} -v {workspace_dir}:{workspace_dir} -w {workspace_dir} -v /dataset:/dataset -v /model_zoo:/model_zoo --name {container_name} {pythonpath_args} {img_tag} sleep {survival_time} """ - await spawn_shell_and_check(f"ssh {remote_host} {docker_cmd}") + await spawn_shell(f"ssh {remote_host} {docker_cmd}") if oneflow_wheel_path: whl_basename = os.path.basename(oneflow_wheel_path) - await spawn_shell_and_check( + await spawn_shell( f"ssh {remote_host} docker exec {container_name} python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple" ) - await spawn_shell_and_check( + await spawn_shell( f"ssh {remote_host} docker exec {container_name} python3 -m pip install {workspace_dir}/{whl_basename}" ) await spawn_shell( f"ssh {remote_host} docker exec {container_name} python3 -m oneflow --doctor" ) if cmd: - if node_rank is not None: - node_rank_args = f"--env NODE_RANK={node_rank}" + if is_multi_client: + 
multi_client_docker_args = ( + # Use _MASTER_ADDR to avoid name conflict with OneFlow's built-in MASTER_ADDR + f"--env NODE_RANK={node_rank} --env _MASTER_ADDR={master_addr}" + ) else: - node_rank_args = "" + multi_client_docker_args = "" await spawn_shell( - f"ssh {remote_host} docker exec {node_rank_args} {container_name} {cmd}" + f"ssh {remote_host} docker exec {multi_client_docker_args} {container_name} {cmd}" ) @@ -318,7 +327,7 @@ async def fix_and_sync_libs(oneflow_internal_path=None, remote_hosts=None): tmp_dir = tempfile.TemporaryDirectory() tmp_lib_dir = os.path.join(tmp_dir.name, "libs") os.mkdir(tmp_lib_dir) - await spawn_shell_and_check( + await spawn_shell( """ldd file | grep "=> /" | awk '{print $3}' | xargs -I '{}' cp -v '{}' destination""".replace( "file", oneflow_internal_path ).replace( @@ -331,17 +340,15 @@ async def fix_and_sync_libs(oneflow_internal_path=None, remote_hosts=None): pathlib.Path(__file__).parent.absolute(), "excludelist" ) excludelist = open(excludelist_path).read().split("\n") - await spawn_shell_and_check(f"cp {oneflow_internal_path} {tmp_dir.name}") + await spawn_shell(f"cp {oneflow_internal_path} {tmp_dir.name}") def handle_lib(lib): if lib in excludelist or "libpython" in lib: print("excluding", lib) - return spawn_shell_and_check(f"rm {tmp_lib_dir}/{lib}") + return spawn_shell(f"rm {tmp_lib_dir}/{lib}") else: print("keeping", lib) - return spawn_shell_and_check( - f"patchelf --set-rpath '$ORIGIN' {tmp_lib_dir}/{lib}" - ) + return spawn_shell(f"patchelf --set-rpath '$ORIGIN' {tmp_lib_dir}/{lib}") await asyncio.gather(*(handle_lib(lib) for lib in libs)) @@ -349,15 +356,15 @@ def handle_lib(lib): tmp_dir.name, pathlib.Path(oneflow_internal_path).name ) print("before fixing .so") - await spawn_shell_and_check(f"ldd {tmp_oneflow_internal_path}") + await spawn_shell(f"ldd {tmp_oneflow_internal_path}") print("fixing .so") - await spawn_shell_and_check( + await spawn_shell( f"patchelf --set-rpath '$ORIGIN/libs' {tmp_oneflow_internal_path}" ) await asyncio.gather( *[ - spawn_shell_and_check( + spawn_shell( f"ssh {remote_host} 'mkdir -p {workspace_dir}/python/oneflow/libs'", ) for remote_host in remote_hosts @@ -366,7 +373,7 @@ def handle_lib(lib): async def copy_file(path=None, remote_host=None): relpath = os.path.relpath(path, tmp_dir.name) - await spawn_shell_and_check( + await spawn_shell( f"scp {path} {remote_host}:{workspace_dir}/python/oneflow/{relpath}", ) @@ -382,7 +389,7 @@ async def copy_file(path=None, remote_host=None): for remote_host in remote_hosts for f in files ], - spawn_shell_and_check(f"ldd {tmp_oneflow_internal_path}"), + spawn_shell(f"ldd {tmp_oneflow_internal_path}"), ) @@ -391,8 +398,11 @@ async def remove_containers_by_name(remote_hosts=None, container_name=None): assert container_name assert remote_hosts await asyncio.gather( - *[spawn_shell(f"ssh {remote_host} {rm_cmd}") for remote_host in remote_hosts], - spawn_shell(rm_cmd), + *[ + spawn_shell_ignoring_failure(f"ssh {remote_host} {rm_cmd}") + for remote_host in remote_hosts + ], + spawn_shell_ignoring_failure(rm_cmd), ) @@ -504,9 +514,7 @@ def get_remote_hosts(args): loop.run_until_complete( asyncio.gather( *[ - spawn_shell_and_check( - f"ssh -o StrictHostKeyChecking=no {remote_host} true" - ) + spawn_shell(f"ssh -o StrictHostKeyChecking=no {remote_host} true") for remote_host in remote_hosts ], ), @@ -545,7 +553,7 @@ def get_remote_hosts(args): loop.run_until_complete( asyncio.gather( *[ - spawn_shell_and_check( + spawn_shell( f"rsync -azPq --omit-dir-times --no-perms --no-group 
--copy-links --include='*.py' --exclude='*.so' --exclude='__pycache__' --exclude='oneflow/include' --include='*/' --exclude='*' {args.oneflow_python_path} {remote_host}:{workspace_dir}" ) for remote_host in remote_hosts @@ -564,7 +572,7 @@ def get_remote_hosts(args): loop.run_until_complete( asyncio.gather( *[ - spawn_shell_and_check( + spawn_shell( f"rsync -azPq --omit-dir-times --no-perms --no-group {oneflow_wheel_path} {remote_host}:{workspace_dir}" ) for remote_host in remote_hosts @@ -611,7 +619,7 @@ def exit_handler(): loop.run_until_complete( asyncio.gather( *[ - spawn_shell( + spawn_shell_ignoring_failure( f"ssh {remote_host} docker run --rm -v {workspace_dir}:/p -w /p busybox chmod -R 777 .", ) for remote_host in remote_hosts @@ -625,7 +633,7 @@ def exit_handler(): loop.run_until_complete( asyncio.gather( *[ - spawn_shell( + spawn_shell_ignoring_failure( f"rsync -azPq --omit-dir-times --no-perms --no-group --exclude='*.whl' --exclude='python' {extra_exclude_args} {remote_host}:{workspace_dir}/ {args.oneflow_test_tmp_dir}/{remote_host}" ) for remote_host in remote_hosts @@ -638,7 +646,9 @@ def exit_handler(): loop.run_until_complete( asyncio.gather( *[ - spawn_shell(f"ssh {remote_host} rm -rf {workspace_dir}",) + spawn_shell_ignoring_failure( + f"ssh {remote_host} rm -rf {workspace_dir}", + ) for remote_host in remote_hosts ], ) @@ -667,6 +677,7 @@ def exit_handler(): img_tag=img_tag, cmd=args.cmd, node_rank=node_rank, + master_addr=this_host, ) for node_rank, remote_host in enumerate(remote_hosts) ], diff --git a/ci/test/resource-spec/1x-gtx-1080.json b/ci/test/resource-spec/1x-gtx-1080.json new file mode 100644 index 00000000000..81f888431bf --- /dev/null +++ b/ci/test/resource-spec/1x-gtx-1080.json @@ -0,0 +1,16 @@ +{ + "version": { + "major": 1, + "minor": 0 + }, + "local": [ + { + "vram": [ + { + "id": "0", + "slots": 8117 + } + ] + } + ] +} diff --git a/ci/test/resource-spec/2x-rtx-2080.json b/ci/test/resource-spec/2x-rtx-2080.json new file mode 100644 index 00000000000..a1e44586957 --- /dev/null +++ b/ci/test/resource-spec/2x-rtx-2080.json @@ -0,0 +1,20 @@ +{ + "version": { + "major": 1, + "minor": 0 + }, + "local": [ + { + "vram": [ + { + "id": "0", + "slots": 7982 + }, + { + "id": "1", + "slots": 7982 + } + ] + } + ] +} diff --git a/ci/test/resource-spec/4x-rtx-2080ti.json b/ci/test/resource-spec/4x-rtx-2080ti.json new file mode 100644 index 00000000000..aa401817598 --- /dev/null +++ b/ci/test/resource-spec/4x-rtx-2080ti.json @@ -0,0 +1,28 @@ +{ + "version": { + "major": 1, + "minor": 0 + }, + "local": [ + { + "vram": [ + { + "id": "0", + "slots": 11019 + }, + { + "id": "1", + "slots": 11019 + }, + { + "id": "2", + "slots": 11019 + }, + { + "id": "3", + "slots": 11019 + } + ] + } + ] +} diff --git a/cmake/caches/ci/canary/cuda.cmake b/cmake/caches/ci/canary/cuda.cmake index 1c8116d32ee..8a9e8b61342 100644 --- a/cmake/caches/ci/canary/cuda.cmake +++ b/cmake/caches/ci/canary/cuda.cmake @@ -14,4 +14,5 @@ set(CMAKE_CUDA_HOST_COMPILER "/usr/lib64/ccache/g++" CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "61-real;70-real;75-real;80-real;86-real" CACHE STRING "") set(CUDNN_STATIC OFF CACHE BOOL "") set(WITH_MLIR ON CACHE BOOL "") +set(BUILD_CPP_API ON CACHE BOOL "") set(CUDA_NVCC_THREADS_NUMBER 8 CACHE STRING "") diff --git a/cmake/caches/ci/cpu.cmake b/cmake/caches/ci/cpu.cmake index cb8547b061d..fc416e58016 100644 --- a/cmake/caches/ci/cpu.cmake +++ b/cmake/caches/ci/cpu.cmake @@ -8,3 +8,5 @@ set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") 
set(CMAKE_BUILD_TYPE Release CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") +set(BUILD_CPP_API ON CACHE BOOL "") +set(WITH_MLIR ON CACHE BOOL "") diff --git a/cmake/caches/ci/cuda.cmake b/cmake/caches/ci/cuda.cmake index ecd4128569e..4b15fc0c1be 100644 --- a/cmake/caches/ci/cuda.cmake +++ b/cmake/caches/ci/cuda.cmake @@ -15,4 +15,5 @@ set(CMAKE_CUDA_HOST_COMPILER "/usr/lib64/ccache/g++" CACHE STRING "") set(CMAKE_CUDA_ARCHITECTURES "61;75" CACHE STRING "") set(CUDNN_STATIC ON CACHE BOOL "") set(WITH_MLIR ON CACHE BOOL "") +set(BUILD_CPP_API ON CACHE BOOL "") set(CUDA_NVCC_THREADS_NUMBER 8 CACHE STRING "") diff --git a/cmake/caches/ci/llvm/cuda-75-clang.cmake b/cmake/caches/ci/llvm/cuda-75-clang.cmake new file mode 100644 index 00000000000..68b4ffa8672 --- /dev/null +++ b/cmake/caches/ci/llvm/cuda-75-clang.cmake @@ -0,0 +1,22 @@ +set(CMAKE_C_COMPILER "clang" CACHE STRING "") +set(CMAKE_CXX_COMPILER "clang++" CACHE STRING "") +set(CMAKE_CUDA_COMPILER "clang++" CACHE STRING "") +set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") +set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") +set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") +set(BUILD_SHARED_LIBS YES CACHE BOOL "") +set(BUILD_CUDA YES CACHE BOOL "") +set(CMAKE_CUDA_ARCHITECTURES "75;52-real" CACHE STRING "") +set(BUILD_TESTING YES CACHE BOOL "") +set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") +set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") +set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") +set(CMAKE_GENERATOR Ninja CACHE STRING "") +set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") +set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") +set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") +set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") +set(CUDAToolkit_ROOT /usr/local/cuda CACHE STRING "") +set(CUDNN_ROOT_DIR /usr/local/cudnn CACHE STRING "") +set(RPC_BACKEND "LOCAL" CACHE STRING "") +set(BUILD_HWLOC NO CACHE BOOL "") diff --git a/cmake/caches/cn/fast/cpu-clang.cmake b/cmake/caches/cn/fast/cpu-clang.cmake index 77bcad931c0..a4ca72f6207 100644 --- a/cmake/caches/cn/fast/cpu-clang.cmake +++ b/cmake/caches/cn/fast/cpu-clang.cmake @@ -4,6 +4,7 @@ set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") +set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA NO CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") diff --git a/cmake/caches/cn/fast/cpu.cmake b/cmake/caches/cn/fast/cpu.cmake index 627063a2048..a0cfa584343 100644 --- a/cmake/caches/cn/fast/cpu.cmake +++ b/cmake/caches/cn/fast/cpu.cmake @@ -1,4 +1,5 @@ set(BUILD_SHARED_LIBS YES CACHE BOOL "") +set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA NO CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") diff --git a/cmake/caches/cn/fast/cuda-61-clang.cmake b/cmake/caches/cn/fast/cuda-61-clang.cmake index 9a8bc2c81af..7a21d1b22ca 100644 --- a/cmake/caches/cn/fast/cuda-61-clang.cmake +++ b/cmake/caches/cn/fast/cuda-61-clang.cmake @@ -4,6 +4,7 @@ set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_SHARED_LINKER_FLAGS_INIT 
"-fuse-ld=lld" CACHE STRING "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") +set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA YES CACHE BOOL "") set(CMAKE_CUDA_ARCHITECTURES "61" CACHE STRING "") set(BUILD_TESTING YES CACHE BOOL "") @@ -13,5 +14,5 @@ set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") +set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") - diff --git a/cmake/caches/cn/fast/cuda-61.cmake b/cmake/caches/cn/fast/cuda-61.cmake index 9da3d096cd7..fd06516c71f 100644 --- a/cmake/caches/cn/fast/cuda-61.cmake +++ b/cmake/caches/cn/fast/cuda-61.cmake @@ -1,6 +1,7 @@ set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") +set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") diff --git a/cmake/caches/cn/fast/cuda-75-clang.cmake b/cmake/caches/cn/fast/cuda-75-clang.cmake index d935bb3b158..29282395d03 100644 --- a/cmake/caches/cn/fast/cuda-75-clang.cmake +++ b/cmake/caches/cn/fast/cuda-75-clang.cmake @@ -4,6 +4,7 @@ set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") +set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA YES CACHE BOOL "") set(CMAKE_CUDA_ARCHITECTURES "75" CACHE STRING "") set(BUILD_TESTING YES CACHE BOOL "") @@ -13,6 +14,5 @@ set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") set(CMAKE_GENERATOR Ninja CACHE STRING "") set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") +set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") - - diff --git a/cmake/caches/cn/fast/cuda-75.cmake b/cmake/caches/cn/fast/cuda-75.cmake index f07e1abc4d0..7d573745e6d 100644 --- a/cmake/caches/cn/fast/cuda-75.cmake +++ b/cmake/caches/cn/fast/cuda-75.cmake @@ -1,6 +1,7 @@ set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_TESTING YES CACHE BOOL "") set(BUILD_SHARED_LIBS YES CACHE BOOL "") +set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") diff --git a/cmake/caches/cn/fast/mlir-cuda-61.cmake b/cmake/caches/cn/fast/mlir-cuda-61.cmake new file mode 100644 index 00000000000..106b66f2f8c --- /dev/null +++ b/cmake/caches/cn/fast/mlir-cuda-61.cmake @@ -0,0 +1,22 @@ +set(BUILD_SHARED_LIBS YES CACHE BOOL "") +set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") +set(BUILD_CUDA YES CACHE BOOL "") +set(BUILD_GIT_VERSION NO CACHE BOOL "") +set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") +set(BUILD_HWLOC NO CACHE BOOL "") +set(BUILD_TESTING OFF CACHE BOOL "") +set(WITH_MLIR YES CACHE BOOL "") +set(THIRD_PARTY_MIRROR aliyun CACHE STRING "") +set(PIP_INDEX_MIRROR "https://pypi.tuna.tsinghua.edu.cn/simple" CACHE STRING "") +set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "") +set(CMAKE_GENERATOR Ninja CACHE STRING "") +set(CMAKE_CUDA_ARCHITECTURES "61-real" CACHE 
STRING "") +set(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda CACHE STRING "") +set(CUDNN_ROOT_DIR /usr/local/cudnn CACHE STRING "") +set(CMAKE_C_COMPILER_LAUNCHER ccache CACHE STRING "") +set(CMAKE_CXX_COMPILER_LAUNCHER ccache CACHE STRING "") +set(CMAKE_CUDA_COMPILER_LAUNCHER ccache CACHE STRING "") +set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF CACHE BOOL "") +set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") +set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") +set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld" CACHE STRING "") diff --git a/cmake/caches/cn/fast/mlir-cuda-75.cmake b/cmake/caches/cn/fast/mlir-cuda-75.cmake index 0e3d453b91c..def90f95c74 100644 --- a/cmake/caches/cn/fast/mlir-cuda-75.cmake +++ b/cmake/caches/cn/fast/mlir-cuda-75.cmake @@ -1,4 +1,5 @@ set(BUILD_SHARED_LIBS YES CACHE BOOL "") +set(CMAKE_LINK_DEPENDS_NO_SHARED YES CACHE BOOL "") set(BUILD_CUDA YES CACHE BOOL "") set(BUILD_GIT_VERSION NO CACHE BOOL "") set(TREAT_WARNINGS_AS_ERRORS YES CACHE BOOL "") diff --git a/cmake/cfg.cmake b/cmake/cfg.cmake index b480f5dcc3b..3c565af4896 100644 --- a/cmake/cfg.cmake +++ b/cmake/cfg.cmake @@ -1,8 +1,4 @@ -execute_process( - COMMAND ${CODEGEN_PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/tools/cfg/generate_cfg_head_dir_and_convert_src.py - --get_message_type=cfg_include_dir - OUTPUT_VARIABLE CFG_INCLUDE_DIR) - +set(CFG_INCLUDE_DIR tools/cfg/include) execute_process( COMMAND ${CODEGEN_PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/tools/cfg/generate_cfg_head_dir_and_convert_src.py --get_message_type=template_convert_python_script @@ -25,7 +21,6 @@ execute_process( include_directories(${CFG_INCLUDE_DIR}) - function(GENERATE_CFG_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR) set(of_cfg_proto_python_dir "${PROJECT_BINARY_DIR}/of_cfg_proto_python") diff --git a/cmake/functional.cmake b/cmake/functional.cmake index 42b79ecbdd6..bc4f6860606 100644 --- a/cmake/functional.cmake +++ b/cmake/functional.cmake @@ -87,3 +87,35 @@ function(GENERATE_FUNCTIONAL_TENSOR_API_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS R set(${PYBIND_SRCS} ${${PYBIND_SRCS}} PARENT_SCOPE) endfunction() + +function(GENERATE_FUNCTIONAL_DISPATCH_STATEFUL_OPS_AND_PYBIND11_CPP SRCS HDRS PYBIND_SRCS ROOT_DIR) + set(YAML_FILE ${PROJECT_SOURCE_DIR}/oneflow/api/python/functional/dispatch_stateful_ops.yaml) + set(GENERATED_API_DIR oneflow/api/python/functional) + set(GENERATED_PYBIND_DIR oneflow/api/python/functional) + + list(APPEND SRCS ${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/dispatch_stateful_ops.yaml.cpp) + list(APPEND HDRS ${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/dispatch_stateful_ops.yaml.h) + list(APPEND PYBIND_SRCS ${PROJECT_BINARY_DIR}/${GENERATED_PYBIND_DIR}/dispatch_stateful_ops.yaml.pybind.cpp) + + add_custom_command( + OUTPUT "${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/dispatch_stateful_ops.yaml.cpp" + "${PROJECT_BINARY_DIR}/${GENERATED_API_DIR}/dispatch_stateful_ops.yaml.h" + "${PROJECT_BINARY_DIR}/${GENERATED_PYBIND_DIR}/dispatch_stateful_ops.yaml.pybind.cpp" + COMMAND ${CMAKE_COMMAND} + ARGS -E make_directory ${GENERATED_API_DIR} + COMMAND ${CMAKE_COMMAND} + ARGS -E make_directory ${GENERATED_PYBIND_DIR} + COMMAND ${CODEGEN_PYTHON_EXECUTABLE} + ARGS ${PROJECT_SOURCE_DIR}/tools/functional/generate_dispatch_stateful_ops.py + --project_source_dir ${PROJECT_SOURCE_DIR} + DEPENDS ${CODEGEN_PYTHON_EXECUTABLE} + ${PROJECT_SOURCE_DIR}/tools/functional/generate_dispatch_stateful_ops.py + ${PROJECT_SOURCE_DIR}/tools/functional/generator.py ${YAML_FILE} + VERBATIM) + + set_source_files_properties(${${SRCS}} 
${${HDRS}} ${${PYBIND_SRCS}} PROPERTIES GENERATED TRUE) + set(${SRCS} ${${SRCS}} PARENT_SCOPE) + set(${HDRS} ${${HDRS}} PARENT_SCOPE) + set(${PYBIND_SRCS} ${${PYBIND_SRCS}} PARENT_SCOPE) + +endfunction() diff --git a/cmake/oneflow-config.cmake b/cmake/oneflow-config.cmake index 99edffe81c8..fddb71ea003 100644 --- a/cmake/oneflow-config.cmake +++ b/cmake/oneflow-config.cmake @@ -7,7 +7,7 @@ endif() set(ONEFLOW_INCLUDE_DIRS ${ONEFLOW_INSTALL_PREFIX}/include) -find_library(ONEFLOW_LIBRARY NAMES oneflow PATHS ${ONEFLOW_INSTALL_PREFIX}/lib REQUIRED) +find_library(ONEFLOW_LIBRARY NAMES oneflow_cpp PATHS ${ONEFLOW_INSTALL_PREFIX}/lib REQUIRED) if(NOT TARGET OneFlow::liboneflow) add_library(OneFlow::liboneflow INTERFACE IMPORTED) diff --git a/cmake/oneflow.cmake b/cmake/oneflow.cmake index 66dc39ea786..8ca518b9add 100644 --- a/cmake/oneflow.cmake +++ b/cmake/oneflow.cmake @@ -103,11 +103,6 @@ foreach(oneflow_single_file ${oneflow_all_src}) set(group_this ON) endif() - if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/api/common/.*\\.(h|cpp)$") - list(APPEND of_all_obj_cc ${oneflow_single_file}) - set(group_this ON) - endif() - if(BUILD_PYTHON) if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/api/python/.*\\.(h|cpp)$") @@ -119,18 +114,6 @@ foreach(oneflow_single_file ${oneflow_all_src}) list(APPEND of_pyext_obj_cc ${oneflow_single_file}) set(group_this ON) endif() - - else() # build_python - - if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/api/cpp/.*\\.(h|cpp)$") - if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/api/cpp/.*_test\\.cpp$") - list(APPEND of_all_test_cc ${oneflow_single_file}) - else() - list(APPEND of_all_obj_cc ${oneflow_single_file}) - endif() - set(group_this ON) - endif() - endif(BUILD_PYTHON) if("${oneflow_single_file}" MATCHES "^${PROJECT_SOURCE_DIR}/oneflow/(core|user|xrt|maybe)/.*\\.cpp$") @@ -162,6 +145,7 @@ add_custom_target(of_format COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_license_format.py -i ${ONEFLOW_PYTHON_DIR} --fix --exclude="oneflow/include" --exclude="oneflow/core" COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_clang_format.py --source_dir ${CMAKE_CURRENT_SOURCE_DIR}/oneflow --fix --quiet COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_py_format.py --source_dir ${CMAKE_CURRENT_SOURCE_DIR} --fix + COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_clang_format.py --source_dir ${CMAKE_CURRENT_SOURCE_DIR}/tools/oneflow-tblgen --fix --quiet ) # clang tidy add_custom_target(of_tidy @@ -234,8 +218,14 @@ if(BUILD_PYTHON) GENERATE_FUNCTIONAL_TENSOR_API_AND_PYBIND11_CPP( FUNCTIONAL_TENSOR_GENERATED_SRCS FUNCTIONAL_TENSOR_GENERATED_HRCS FUNCTIONAL_TENSOR_PYBIND11_SRCS ${PROJECT_SOURCE_DIR}) + + GENERATE_FUNCTIONAL_DISPATCH_STATEFUL_OPS_AND_PYBIND11_CPP( + FUNCTIONAL_OPS_GENERATED_SRCS FUNCTIONAL_OPS_GENERATED_HRCS + FUNCTIONAL_OPS_PYBIND11_SRCS ${PROJECT_SOURCE_DIR}) + oneflow_add_library(of_functional_tensor_obj STATIC - ${FUNCTIONAL_TENSOR_GENERATED_SRCS} ${FUNCTIONAL_TENSOR_GENERATED_HRCS}) + ${FUNCTIONAL_TENSOR_GENERATED_SRCS} ${FUNCTIONAL_TENSOR_GENERATED_HRCS} + ${FUNCTIONAL_OPS_GENERATED_SRCS} ${FUNCTIONAL_OPS_GENERATED_HRCS}) add_dependencies(of_functional_tensor_obj of_cfgobj) add_dependencies(of_functional_tensor_obj prepare_oneflow_third_party) target_include_directories(of_functional_tensor_obj PRIVATE ${Python_INCLUDE_DIRS} ${Python_NumPy_INCLUDE_DIRS}) @@ -243,7 +233,8 @@ if(BUILD_PYTHON) 
set(PYBIND11_SRCS ${CFG_PYBIND11_SRCS} ${FUNCTIONAL_PYBIND11_SRCS} - ${FUNCTIONAL_TENSOR_PYBIND11_SRCS}) + ${FUNCTIONAL_TENSOR_PYBIND11_SRCS} + ${FUNCTIONAL_OPS_PYBIND11_SRCS}) endif(BUILD_PYTHON) @@ -251,21 +242,13 @@ include_directories(${PROJECT_SOURCE_DIR}) # TO FIND: third_party/eigen3/.. include_directories(${PROJECT_BINARY_DIR}) # cc obj lib -if(BUILD_PYTHON) - oneflow_add_library(oneflow ${of_all_obj_cc}) -else() # build_python - if(BUILD_SHARED_LIBONEFLOW) - oneflow_add_library(oneflow SHARED ${of_all_obj_cc}) - else() - oneflow_add_library(oneflow ${of_all_obj_cc}) - endif() -endif(BUILD_PYTHON) +oneflow_add_library(oneflow ${of_all_obj_cc}) add_dependencies(oneflow of_protoobj) add_dependencies(oneflow of_cfgobj) add_dependencies(oneflow of_functional_obj) +add_dependencies(oneflow of_op_schema) add_dependencies(oneflow of_git_version) -set_target_properties(oneflow PROPERTIES ARCHIVE_OUTPUT_DIRECTORY "${ONEFLOW_LIBRARY_DIR}" LIBRARY_OUTPUT_DIRECTORY "${ONEFLOW_LIBRARY_DIR}") if (USE_CLANG_FORMAT) add_dependencies(oneflow of_format) @@ -276,35 +259,43 @@ endif() target_compile_definitions(oneflow PRIVATE GOOGLE_LOGGING) -oneflow_add_executable(oneflow-gen-ods ${PROJECT_SOURCE_DIR}/oneflow/ir/oneflow-gen-ods/oneflow-gen-ods.cpp) -set_target_properties(oneflow-gen-ods PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${PROJECT_BINARY_DIR}/bin") +set(ONEFLOW_TOOLS_DIR "${PROJECT_BINARY_DIR}/tools" CACHE STRING "dir to put binary for debugging and development") +set(LLVM_MONO_REPO_URL "https://github.com/llvm/llvm-project/archive/649d95371680cbf7f740c990c0357372c2bd4058.zip" CACHE STRING "") +use_mirror(VARIABLE LLVM_MONO_REPO_URL URL ${LLVM_MONO_REPO_URL}) +set(LLVM_MONO_REPO_MD5 "9bda804e5cc61899085fb0f0dce1089f" CACHE STRING "") +set(ONEFLOW_BUILD_ROOT_DIR "${PROJECT_BINARY_DIR}") +add_subdirectory(${PROJECT_SOURCE_DIR}/oneflow/ir) if (WITH_MLIR) - set(LLVM_MONO_REPO_URL "https://github.com/llvm/llvm-project/archive/649d95371680cbf7f740c990c0357372c2bd4058.zip" CACHE STRING "" FORCE) - use_mirror(VARIABLE LLVM_MONO_REPO_URL URL ${LLVM_MONO_REPO_URL}) - set(LLVM_MONO_REPO_MD5 "9bda804e5cc61899085fb0f0dce1089f" CACHE STRING "" FORCE) - add_subdirectory(${PROJECT_SOURCE_DIR}/oneflow/ir) set(ONEFLOW_MLIR_LIBS -Wl,--no-as-needed MLIROneFlowExtension -Wl,--as-needed) - include_directories(${LLVM_INCLUDE_DIRS}) - include_directories(${MLIR_INCLUDE_DIRS}) - include_directories(${ONEFLOW_MLIR_SOURCE_INCLUDE_DIRS}) - include_directories(${ONEFLOW_MLIR_BINARY_INCLUDE_DIRS}) endif() +include(op_schema) + if(APPLE) - set(of_libs -Wl,-force_load oneflow of_protoobj of_cfgobj of_functional_obj) + set(of_libs -Wl,-force_load oneflow of_protoobj of_cfgobj of_functional_obj of_op_schema) target_link_libraries(oneflow of_protoobj of_cfgobj of_functional_obj glog_imported gflags_imported ${oneflow_third_party_libs}) elseif(UNIX) - set(of_libs -Wl,--whole-archive oneflow of_protoobj of_cfgobj of_functional_obj -Wl,--no-whole-archive -ldl -lrt) + set(of_libs -Wl,--whole-archive oneflow of_protoobj of_cfgobj of_functional_obj of_op_schema -Wl,--no-whole-archive -ldl -lrt) target_link_libraries(oneflow of_protoobj of_cfgobj of_functional_obj glog_imported gflags_imported ${oneflow_third_party_libs} -Wl,--no-whole-archive -ldl -lrt) + if(BUILD_CUDA) + target_link_libraries(oneflow CUDA::cudart_static) + endif() elseif(WIN32) - set(of_libs oneflow of_protoobj of_cfgobj of_functional_obj) + set(of_libs oneflow of_protoobj of_cfgobj of_functional_obj of_op_schema) set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} 
/WHOLEARCHIVE:oneflow") endif() -target_link_libraries(oneflow-gen-ods ${of_libs} ${oneflow_third_party_libs} ${oneflow_exe_third_party_libs}) -if (BUILD_CUDA) - target_link_libraries(oneflow-gen-ods CUDA::cudart_static) +# oneflow api common +if (BUILD_PYTHON OR BUILD_CPP_API) + file(GLOB_RECURSE of_api_common_files + ${PROJECT_SOURCE_DIR}/oneflow/api/common/*.h + ${PROJECT_SOURCE_DIR}/oneflow/api/common/*.cpp) + oneflow_add_library(of_api_common OBJECT ${of_api_common_files}) + target_link_libraries(of_api_common oneflow) + if (WITH_MLIR) + target_link_libraries(of_api_common ${ONEFLOW_MLIR_LIBS}) + endif() endif() if(BUILD_PYTHON) @@ -312,7 +303,7 @@ if(BUILD_PYTHON) # py ext lib oneflow_add_library(of_pyext_obj ${of_pyext_obj_cc}) target_include_directories(of_pyext_obj PRIVATE ${Python_INCLUDE_DIRS} ${Python_NumPy_INCLUDE_DIRS}) - target_link_libraries(of_pyext_obj oneflow) + target_link_libraries(of_pyext_obj oneflow pybind11::headers) if(BUILD_SHARED_LIBS AND APPLE) target_link_libraries(of_pyext_obj ${Python3_LIBRARIES}) endif() @@ -321,19 +312,22 @@ if(BUILD_PYTHON) pybind11_add_module(oneflow_internal ${PYBIND11_SRCS} ${of_pybind_obj_cc} ${PYBIND_REGISTRY_CC}) set_compile_options_to_oneflow_target(oneflow_internal) set_property(TARGET oneflow_internal PROPERTY CXX_VISIBILITY_PRESET "default") - add_dependencies(oneflow_internal of_cfgobj of_functional_obj of_functional_tensor_obj) + add_dependencies(oneflow_internal of_cfgobj of_functional_obj of_functional_tensor_obj of_op_schema) set_target_properties(oneflow_internal PROPERTIES PREFIX "_") set_target_properties(oneflow_internal PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${ONEFLOW_PYTHON_DIR}/oneflow") target_link_libraries(oneflow_internal PRIVATE ${of_libs} of_functional_tensor_obj - ${ONEFLOW_MLIR_LIBS} + of_api_common ${oneflow_third_party_libs} of_pyext_obj ${oneflow_exe_third_party_libs}) target_include_directories(oneflow_internal PRIVATE ${Python_INCLUDE_DIRS} ${Python_NumPy_INCLUDE_DIRS}) target_compile_definitions(oneflow_internal PRIVATE ONEFLOW_CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}) + if(WITH_MLIR) + add_dependencies(check-oneflow oneflow_internal) + endif(WITH_MLIR) set(gen_pip_args "") if (BUILD_CUDA) @@ -356,65 +350,139 @@ if(BUILD_PYTHON) endif(BUILD_PYTHON) +if (BUILD_CPP_API) + file(GLOB_RECURSE of_cpp_api_files + ${PROJECT_SOURCE_DIR}/oneflow/api/cpp/*.cpp + ${PROJECT_SOURCE_DIR}/oneflow/api/cpp/*.h) + if(BUILD_MONOLITHIC_LIBONEFLOW_CPP_SO) + oneflow_add_library(oneflow_cpp SHARED ${of_cpp_api_files}) + else() + oneflow_add_library(oneflow_cpp ${of_cpp_api_files}) + endif() + set_target_properties(oneflow_cpp PROPERTIES ARCHIVE_OUTPUT_DIRECTORY "${LIBONEFLOW_LIBRARY_DIR}" LIBRARY_OUTPUT_DIRECTORY "${LIBONEFLOW_LIBRARY_DIR}") + target_link_libraries(oneflow_cpp PRIVATE ${of_libs} of_api_common ${oneflow_third_party_libs}) +endif() + file(RELATIVE_PATH PROJECT_BINARY_DIR_RELATIVE ${PROJECT_SOURCE_DIR} ${PROJECT_BINARY_DIR}) +function(oneflow_add_test target_name) + cmake_parse_arguments(arg "" "TEST_NAME;WORKING_DIRECTORY" "SRCS" ${ARGN}) + oneflow_add_executable(${target_name} ${arg_SRCS}) + if (BUILD_CUDA) + target_link_libraries(${target_name} CUDA::cudart_static) + endif() + set_target_properties(${target_name} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${PROJECT_BINARY_DIR}/bin") + add_test(NAME ${arg_TEST_NAME} COMMAND ${target_name} WORKING_DIRECTORY ${arg_WORKING_DIRECTORY}) + set_tests_properties( + ${arg_TEST_NAME} + PROPERTIES + ENVIRONMENT "HTTP_PROXY='';HTTPS_PROXY='';http_proxy='';https_proxy='';" + ) 
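+  # The empty proxy variables above keep ctest-run unit tests from being routed
+  # through a local HTTP(S) proxy. Example usage of this helper (target, source
+  # list and test name are hypothetical):
+  #   oneflow_add_test(foo_testexe SRCS ${foo_test_srcs} TEST_NAME foo_test
+  #                    WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
+  #   target_link_libraries(foo_testexe ${oneflow_test_libs})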
+endfunction() + # build test if(BUILD_TESTING) if (of_all_test_cc) - oneflow_add_executable(oneflow_testexe ${of_all_test_cc}) - target_link_libraries(oneflow_testexe ${of_libs} ${oneflow_third_party_libs} ${oneflow_exe_third_party_libs}) - if (BUILD_CUDA) - target_link_libraries(oneflow_testexe CUDA::cudart_static) - endif() - set_target_properties(oneflow_testexe PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${PROJECT_BINARY_DIR}/bin") - add_test(NAME oneflow_test COMMAND oneflow_testexe) + oneflow_add_test(oneflow_testexe SRCS ${of_all_test_cc} TEST_NAME oneflow_test) + target_link_libraries(oneflow_testexe ${of_libs} ${oneflow_third_party_libs} ${oneflow_exe_third_party_libs} ${oneflow_test_libs}) + endif() + + if (BUILD_CPP_API) + file(GLOB_RECURSE cpp_api_test_files ${PROJECT_SOURCE_DIR}/oneflow/api/cpp/tests/*.cpp) + oneflow_add_test(oneflow_cpp_api_testexe SRCS ${cpp_api_test_files} TEST_NAME oneflow_cpp_api_test WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}) + target_link_libraries(oneflow_cpp_api_testexe oneflow_cpp ${oneflow_test_libs}) endif() endif() + # build include add_custom_target(of_include_copy ALL) if(BUILD_PYTHON) add_dependencies(of_include_copy oneflow_internal of_pyscript_copy) - - foreach(of_include_src_dir ${CFG_INCLUDE_DIR}) - copy_all_files_in_dir("${of_include_src_dir}" "${ONEFLOW_INCLUDE_DIR}" of_include_copy) - endforeach() - - copy_files("${PROTO_HDRS}" "${PROJECT_BINARY_DIR}" "${ONEFLOW_INCLUDE_DIR}" of_include_copy) - copy_files("${CFG_HRCS}" "${PROJECT_BINARY_DIR}" "${ONEFLOW_INCLUDE_DIR}" of_include_copy) - - set(OF_CORE_HDRS) - list(APPEND of_core_dir_name_list "common" "device" "framework" "kernel/util" "persistence" "ep/include") - foreach(of_core_dir_name ${of_core_dir_name_list}) - file(GLOB_RECURSE h_files "${PROJECT_SOURCE_DIR}/oneflow/core/${of_core_dir_name}/*.h") - list(APPEND OF_CORE_HDRS ${h_files}) - file(GLOB_RECURSE hpp_files "${PROJECT_SOURCE_DIR}/oneflow/core/${of_core_dir_name}/*.hpp") - list(APPEND OF_CORE_HDRS ${hpp_files}) - endforeach() - list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/kernel/new_kernel_util.h") - list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/kernel/kernel_context.h") - list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/kernel/kernel_observer.h") - list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/kernel/kernel_util.cuh") - list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/job/sbp_signature_builder.h") - list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/common/symbol.h") - list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/job/parallel_desc.h") - list(APPEND OF_CORE_HDRS "${PROJECT_SOURCE_DIR}/oneflow/core/autograd/autograd_meta.h") - copy_files("${OF_CORE_HDRS}" "${PROJECT_SOURCE_DIR}" "${ONEFLOW_INCLUDE_DIR}" of_include_copy) - + install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/oneflow/core DESTINATION ${ONEFLOW_INCLUDE_DIR}/oneflow + COMPONENT oneflow_py_include + EXCLUDE_FROM_ALL + FILES_MATCHING + PATTERN *.h + PATTERN *.hpp + ) + install(DIRECTORY ${CFG_INCLUDE_DIR}/oneflow DESTINATION ${ONEFLOW_INCLUDE_DIR} + COMPONENT oneflow_py_include + EXCLUDE_FROM_ALL + ) + install(DIRECTORY ${CMAKE_SOURCE_DIR}/oneflow DESTINATION ${ONEFLOW_INCLUDE_DIR} + COMPONENT oneflow_py_include + EXCLUDE_FROM_ALL + FILES_MATCHING + REGEX "oneflow/core/common/.+(h|hpp)$" + REGEX "oneflow/core/device/.+(h|hpp)$" + REGEX "oneflow/core/framework/.+(h|hpp)$" + REGEX "oneflow/core/kernel/util/.+(h|hpp)$" + REGEX "oneflow/core/persistence/.+(h|hpp)$" + REGEX 
"oneflow/core/ep/include/.+(h|hpp)$" + PATTERN "oneflow/core/kernel/new_kernel_util.h" + PATTERN "oneflow/core/kernel/kernel_context.h" + PATTERN "oneflow/core/kernel/kernel_observer.h" + PATTERN "oneflow/core/kernel/kernel_util.cuh" + PATTERN "oneflow/core/job/sbp_signature_builder.h" + PATTERN "oneflow/core/common/symbol.h" + PATTERN "oneflow/core/job/parallel_desc.h" + PATTERN "oneflow/core/autograd/autograd_meta.h" + PATTERN "oneflow/api" EXCLUDE + PATTERN "oneflow/xrt" EXCLUDE + PATTERN "oneflow/user" EXCLUDE + PATTERN "oneflow/extension" EXCLUDE + PATTERN "oneflow/maybe" EXCLUDE + PATTERN "oneflow/core/lazy" EXCLUDE + PATTERN "oneflow/core/graph_impl" EXCLUDE + PATTERN "oneflow/core/job_rewriter" EXCLUDE + PATTERN "oneflow/core/hardware" EXCLUDE + PATTERN "oneflow/core/intrusive" EXCLUDE + PATTERN "oneflow/core/stream" EXCLUDE + PATTERN "oneflow/core/functional" EXCLUDE + PATTERN "oneflow/core/platform" EXCLUDE + PATTERN "oneflow/core/boxing" EXCLUDE + PATTERN "oneflow/core/rpc" EXCLUDE + PATTERN "oneflow/core/profiler" EXCLUDE + PATTERN "oneflow/core/transport" EXCLUDE + PATTERN "oneflow/core/comm_network" EXCLUDE + PATTERN "oneflow/ir" EXCLUDE + ) + add_custom_target(install_oneflow_py_include + COMMAND + "${CMAKE_COMMAND}" -DCMAKE_INSTALL_COMPONENT=oneflow_py_include + -P "${CMAKE_BINARY_DIR}/cmake_install.cmake" + DEPENDS oneflow_internal + ) add_custom_target(oneflow_py ALL) - add_dependencies(oneflow_py of_include_copy) - -else() # build_python - - add_dependencies(of_include_copy oneflow) + add_dependencies(oneflow_py of_include_copy install_oneflow_py_include) - set(OF_API_DIRS) - file(GLOB_RECURSE api_h_files "${PROJECT_SOURCE_DIR}/oneflow/api/cpp/*.h") - list(APPEND OF_API_DIRS ${api_h_files}) +endif(BUILD_PYTHON) - copy_files("${OF_API_DIRS}" "${PROJECT_SOURCE_DIR}/oneflow/api/cpp" "${ONEFLOW_INCLUDE_DIR}" of_include_copy) - copy_files("${PROJECT_SOURCE_DIR}/cmake/oneflow-config.cmake" "${PROJECT_SOURCE_DIR}/cmake" "${ONEFLOW_SHARE_DIR}" of_include_copy) -endif(BUILD_PYTHON) +set(LIBONEFLOW_INCLUDE_DIR "${PROJECT_BINARY_DIR}/liboneflow_cpp/include/oneflow/api") +install(DIRECTORY oneflow/api/cpp DESTINATION ${LIBONEFLOW_INCLUDE_DIR} + COMPONENT oneflow_cpp_include + EXCLUDE_FROM_ALL + FILES_MATCHING + PATTERN "*.h" +) + +add_custom_target(install_oneflow_cpp_include + COMMAND + "${CMAKE_COMMAND}" -DCMAKE_INSTALL_COMPONENT=oneflow_cpp_include + -P "${CMAKE_BINARY_DIR}/cmake_install.cmake" + DEPENDS oneflow_internal +) +if (BUILD_CPP_API) + add_dependencies(of_include_copy oneflow_cpp) + add_dependencies(of_include_copy install_oneflow_cpp_include) + copy_files("${PROJECT_SOURCE_DIR}/cmake/oneflow-config.cmake" "${PROJECT_SOURCE_DIR}/cmake" "${LIBONEFLOW_SHARE_DIR}" of_include_copy) + + if(WITH_MLIR) + file(GLOB mlir_shared_libs "${PROJECT_BINARY_DIR}/oneflow/ir/llvm_monorepo-build/lib/*.14git") + copy_files("${mlir_shared_libs}" "${PROJECT_BINARY_DIR}/oneflow/ir/llvm_monorepo-build/lib" "${LIBONEFLOW_LIBRARY_DIR}" of_include_copy) + endif(WITH_MLIR) +endif(BUILD_CPP_API) diff --git a/cmake/op_schema.cmake b/cmake/op_schema.cmake new file mode 100644 index 00000000000..970910f94c9 --- /dev/null +++ b/cmake/op_schema.cmake @@ -0,0 +1,90 @@ +get_property(LLVM_INSTALL_DIR GLOBAL PROPERTY LLVM_INSTALL_DIR) +set(LLVM_INSTALL_DIR ${THIRD_PARTY_DIR}/llvm) +set(LLVM_DIR ${LLVM_INSTALL_DIR}/lib/cmake/llvm) +set(ONEFLOW_OP_GROUPS + "ASSIGN" + "BINARY" + "BROADCAST" + "CONV" + "CROSS_ENTROPY" + "CUDA" + "DATASET" + "DETECTION" + "EAGER" + "FUSED" + "IDEMPOTENT" + "IDENTITY" + "IMAGE" + 
"INDICES" + "INVOLUTION" + "LOSS" + "MATH" + "MATMUL" + "MISC" + "NCCL" + "NORMALIZATION" + "OPTIMIZER" + "PADDING" + "PARALLEL_CAST" + "POOL" + "QUANTIZATION" + "REDUCE" + "RESHAPE" + "SCALAR" + "SOFTMAX" + "SUMMARY" + "TENSOR_BUFFER" + "TEST" + "TRIGONOMETRIC" + "UNARY" + "UPSAMPLE" +) +foreach (OP_GROUP_NAME IN LISTS ONEFLOW_OP_GROUPS) + list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-DGET_ONEFLOW_${OP_GROUP_NAME}_OP_DEFINITIONS") +endforeach() +list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-DREMOVE_ONEFLOW_MLIR_ONLY_OP_DEFINITIONS") + +set(GENERATED_OP_SCHEMA_DIR oneflow/core/framework) +set(GENERATED_IR_INCLUDE_DIR oneflow/ir/include) +set(SOURCE_IR_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/oneflow/ir/include) +set(ONEFLOW_ODS ${SOURCE_IR_INCLUDE_DIR}/OneFlow/OneFlowOps.td) + +list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-I${GENERATED_IR_INCLUDE_DIR}") +list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-I${SOURCE_IR_INCLUDE_DIR}") +list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-I${LLVM_INSTALL_DIR}/include") + +set(GENERATED_OP_SCHEMA_H "${GENERATED_OP_SCHEMA_DIR}/op_generated.h") +set(GENERATED_OP_SCHEMA_CPP "${GENERATED_OP_SCHEMA_DIR}/op_generated.cpp") + + +set(ONEFLOW_TABLE_GEN_EXE ${LLVM_INSTALL_DIR}/bin/oneflow_tblgen) +if(LLVM_PROVIDER STREQUAL "in-tree") + set(ONEFLOW_TABLE_GEN_TARGET oneflow_tblgen install-oneflow-tblgen install-mlir-headers) +elseif(LLVM_PROVIDER STREQUAL "install") + set(ONEFLOW_TABLE_GEN_TARGET ${ONEFLOW_TABLE_GEN_EXE}) +endif() + +file(GLOB_RECURSE ODS_FILES LIST_DIRECTORIES false "${SOURCE_IR_INCLUDE_DIR}/*.td") +if(NOT ODS_FILES) + message(FATAL_ERROR "ODS_FILES not found: ${ODS_FILES}") +endif() +add_custom_command( + OUTPUT ${GENERATED_OP_SCHEMA_H} ${GENERATED_OP_SCHEMA_CPP} + COMMAND ${CMAKE_COMMAND} + ARGS -E make_directory ${GENERATED_OP_SCHEMA_DIR} + COMMAND ${ONEFLOW_TABLE_GEN_EXE} + ARGS --gen-op-schema-h ${ONEFLOW_ODS} ${ONEFLOW_SCHEMA_TABLEGEN_FLAGS} -o ${GENERATED_OP_SCHEMA_H} + COMMAND ${ONEFLOW_TABLE_GEN_EXE} + ARGS --gen-op-schema-cpp ${ONEFLOW_ODS} ${ONEFLOW_SCHEMA_TABLEGEN_FLAGS} + --op-include ${GENERATED_OP_SCHEMA_H} -o ${GENERATED_OP_SCHEMA_CPP} + DEPENDS ${ONEFLOW_TABLE_GEN_TARGET} + ${ODS_FILES} + VERBATIM +) +set_source_files_properties( + ${GENERATED_OP_SCHEMA_H} ${GENERATED_OP_SCHEMA_CPP} PROPERTIES GENERATED TRUE +) + +oneflow_add_library(of_op_schema OBJECT ${GENERATED_OP_SCHEMA_H} ${GENERATED_OP_SCHEMA_CPP}) +add_dependencies(of_op_schema of_cfgobj) +add_dependencies(of_op_schema prepare_oneflow_third_party) diff --git a/cmake/pybind11.cmake b/cmake/pybind11.cmake index 2ac08dfdb0a..6729244f199 100644 --- a/cmake/pybind11.cmake +++ b/cmake/pybind11.cmake @@ -1,16 +1,13 @@ include(FetchContent) -set(PYBIND11_TAR_URL https://github.com/pybind/pybind11/archive/v2.7.0.zip) -use_mirror(VARIABLE PYBIND11_TAR_URL URL ${PYBIND11_TAR_URL}) + +set(PYBIND11_URL https://github.com/pybind/pybind11/archive/v2.7.0.zip) +use_mirror(VARIABLE PYBIND11_URL URL ${PYBIND11_URL}) set(PYBIND11_URL_HASH 267807f790ef598ef912a79aceefdc10) FetchContent_Declare( pybind11 - URL ${PYBIND11_TAR_URL} + URL ${PYBIND11_URL} URL_HASH MD5=${PYBIND11_URL_HASH} ) -FetchContent_GetProperties(pybind11) -if(NOT pybind11_POPULATED) - FetchContent_Populate(pybind11) - add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR}) - include_directories("${pybind11_SOURCE_DIR}/include") -endif() + +FetchContent_MakeAvailable(pybind11) diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index f5c137fefb9..099c151c22b 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ 
-3,7 +3,9 @@ if (NOT WIN32) find_package(Threads) endif() -include(zlib) +if (WITH_ZLIB) + include(zlib) +endif() include(protobuf) include(googletest) include(gflags) @@ -43,6 +45,10 @@ if (WITH_ONEDNN) include(oneDNN) endif() +set_mirror_url_with_hash(INJA_URL + https://github.com/pantor/inja/archive/refs/tags/v3.3.0.zip + 611e6b7206d0fb89728a3879f78b4775 +) option(CUDA_STATIC "" ON) @@ -130,6 +136,12 @@ set(oneflow_exe_third_party_libs gflags_imported ) +set(oneflow_test_libs + ${GOOGLETEST_STATIC_LIBRARIES} + ${GOOGLEMOCK_STATIC_LIBRARIES} +) + + set(oneflow_third_party_libs ${GOOGLETEST_STATIC_LIBRARIES} ${GOOGLEMOCK_STATIC_LIBRARIES} @@ -140,12 +152,12 @@ set(oneflow_third_party_libs ${OPENCV_STATIC_LIBRARIES} ${COCOAPI_STATIC_LIBRARIES} ${LIBJPEG_STATIC_LIBRARIES} - zlib_imported ${ABSL_STATIC_LIBRARIES} ${OPENSSL_STATIC_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT} ${FLATBUFFERS_STATIC_LIBRARIES} ${LZ4_STATIC_LIBRARIES} + nlohmann_json::nlohmann_json ) if (WITH_ONEDNN) set(oneflow_third_party_libs ${oneflow_third_party_libs} ${ONEDNN_STATIC_LIBRARIES}) @@ -155,6 +167,10 @@ if (NOT WITH_XLA) list(APPEND oneflow_third_party_libs ${RE2_LIBRARIES}) endif() +if (WITH_ZLIB) + list(APPEND oneflow_third_party_libs zlib_imported) +endif() + if(WIN32) # static gflags lib requires "PathMatchSpecA" defined in "ShLwApi.Lib" list(APPEND oneflow_third_party_libs "ShLwApi.Lib") @@ -162,7 +178,6 @@ if(WIN32) endif() set(oneflow_third_party_dependencies - zlib protobuf gflags glog @@ -173,7 +188,6 @@ set(oneflow_third_party_dependencies eigen half_copy_headers_to_destination re2 - json_copy_headers_to_destination flatbuffers lz4_copy_libs_to_destination lz4_copy_headers_to_destination @@ -181,7 +195,9 @@ set(oneflow_third_party_dependencies if (WITH_ONEDNN) list(APPEND oneflow_third_party_dependencies onednn) endif() - +if (WITH_ZLIB) + list(APPEND oneflow_third_party_dependencies zlib) +endif() if (WITH_COCOAPI) list(APPEND oneflow_third_party_dependencies cocoapi_copy_headers_to_destination) @@ -206,7 +222,6 @@ list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS ${EIGEN_INCLUDE_DIR} ${COCOAPI_INCLUDE_DIR} ${HALF_INCLUDE_DIR} - ${JSON_INCLUDE_DIR} ${ABSL_INCLUDE_DIR} ${OPENSSL_INCLUDE_DIR} ${FLATBUFFERS_INCLUDE_DIR} @@ -236,9 +251,9 @@ if (BUILD_CUDA) endif() include(nccl) - list(APPEND oneflow_third_party_libs ${VENDOR_CUDA_LIBRARIES}) - list(APPEND oneflow_third_party_libs ${CUDNN_LIBRARIES}) list(APPEND oneflow_third_party_libs ${NCCL_LIBRARIES}) + list(APPEND oneflow_third_party_libs ${CUDNN_LIBRARIES}) + list(APPEND oneflow_third_party_libs ${VENDOR_CUDA_LIBRARIES}) list(APPEND oneflow_third_party_dependencies nccl) @@ -308,11 +323,20 @@ add_definitions(-DHALF_ENABLE_CPP11_USER_LITERALS=0) if (THIRD_PARTY) add_custom_target(prepare_oneflow_third_party ALL DEPENDS ${oneflow_third_party_dependencies}) - if(BUILD_PYTHON) - foreach(of_include_src_dir ${ONEFLOW_THIRD_PARTY_INCLUDE_DIRS}) - copy_all_files_in_dir("${of_include_src_dir}" "${ONEFLOW_INCLUDE_DIR}" prepare_oneflow_third_party) - endforeach() - endif(BUILD_PYTHON) + if(NOT ONEFLOW_INCLUDE_DIR MATCHES "/include$") + message(FATAL_ERROR "ONEFLOW_INCLUDE_DIR must end with '/include', current value: ${ONEFLOW_INCLUDE_DIR}") + endif() + get_filename_component(ONEFLOW_INCLUDE_DIR_PARENT "${ONEFLOW_INCLUDE_DIR}" DIRECTORY) + foreach(of_include_src_dir ${ONEFLOW_THIRD_PARTY_INCLUDE_DIRS}) + set(ONEFLOW_INCLUDE_DIR_DST ${ONEFLOW_INCLUDE_DIR}) + if(of_include_src_dir MATCHES "/include$") + set(ONEFLOW_INCLUDE_DIR_DST ${ONEFLOW_INCLUDE_DIR_PARENT}) + endif() + 
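+    # A third-party source dir that itself ends in /include is installed one level
+    # above ONEFLOW_INCLUDE_DIR, so its trailing include/ component lands on
+    # ONEFLOW_INCLUDE_DIR instead of nesting an extra include/include level.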
install(DIRECTORY ${of_include_src_dir} DESTINATION ${ONEFLOW_INCLUDE_DIR_DST} + COMPONENT oneflow_py_include + EXCLUDE_FROM_ALL + ) + endforeach() else() add_custom_target(prepare_oneflow_third_party ALL) endif() diff --git a/cmake/third_party/absl.cmake b/cmake/third_party/absl.cmake index 6bd3a19664f..f15d13adf5f 100644 --- a/cmake/third_party/absl.cmake +++ b/cmake/third_party/absl.cmake @@ -12,14 +12,14 @@ SET(ABSL_LIBRARY_DIR ${THIRD_PARTY_DIR}/absl/${CMAKE_INSTALL_LIBDIR} CACHE PATH if(WIN32) set(ABSL_BUILD_LIBRARY_DIR ${ABSL_INSTALL}/${CMAKE_INSTALL_LIBDIR}) - set(ABSL_LIBRARY_NAMES absl_base.lib absl_spinlock_wait.lib absl_dynamic_annotations.lib + set(ABSL_LIBRARY_NAMES absl_spinlock_wait.lib absl_dynamic_annotations.lib absl_malloc_internal.lib absl_throw_delegate.lib absl_int128.lib absl_strings.lib absl_str_format_internal.lib - absl_time.lib absl_bad_optional_access.lib) + absl_time.lib absl_bad_optional_access.lib absl_base.lib) else() set(ABSL_BUILD_LIBRARY_DIR ${ABSL_INSTALL}/${CMAKE_INSTALL_LIBDIR}) - set(ABSL_LIBRARY_NAMES libabsl_base.a libabsl_spinlock_wait.a libabsl_dynamic_annotations.a + set(ABSL_LIBRARY_NAMES libabsl_spinlock_wait.a libabsl_dynamic_annotations.a libabsl_malloc_internal.a libabsl_throw_delegate.a libabsl_int128.a libabsl_strings.a libabsl_str_format_internal.a - libabsl_time.a libabsl_bad_optional_access.a) + libabsl_time.a libabsl_bad_optional_access.a libabsl_base.a) endif() foreach(LIBRARY_NAME ${ABSL_LIBRARY_NAMES}) @@ -34,17 +34,12 @@ if(THIRD_PARTY) URL_MD5 20126998c9b17e5f7a93711972f03f79 UPDATE_COMMAND "" BUILD_BYPRODUCTS ${ABSL_STATIC_LIBRARIES} - CMAKE_ARGS - -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} - -DBUILD_SHARED_LIBS:BOOL=OFF - -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS} - -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} - -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} CMAKE_CACHE_ARGS -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_INSTALL_PREFIX:PATH=${ABSL_INSTALL} -DCMAKE_INSTALL_LIBDIR:PATH=${ABSL_INSTALL}/${CMAKE_INSTALL_LIBDIR} + -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} ) diff --git a/cmake/third_party/eigen.cmake b/cmake/third_party/eigen.cmake index 3dc32e933e4..d7fbaec2fd4 100644 --- a/cmake/third_party/eigen.cmake +++ b/cmake/third_party/eigen.cmake @@ -36,6 +36,7 @@ ExternalProject_Add(eigen -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${EIGEN_INSTALL_DIR} + -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} -DBUILD_TESTING:BOOL=OFF diff --git a/cmake/third_party/flatbuffers.cmake b/cmake/third_party/flatbuffers.cmake index 398de64f442..15782de6153 100644 --- a/cmake/third_party/flatbuffers.cmake +++ b/cmake/third_party/flatbuffers.cmake @@ -36,6 +36,7 @@ if (THIRD_PARTY) -DCMAKE_INSTALL_INCLUDEDIR=${FLATBUFFERS_INSTALL_INCLUDEDIR} -DCMAKE_INSTALL_LIBDIR=${FLATBUFFERS_INSTALL_LIBDIR} -DCMAKE_INSTALL_BINDIR=${FLATBUFFERS_INSTALL_BINDIR} + -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DFLATBUFFERS_BUILD_TESTS=OFF ) endif (THIRD_PARTY) diff --git a/cmake/third_party/gflags.cmake b/cmake/third_party/gflags.cmake index f32cca67301..c8fe24b0009 100644 --- a/cmake/third_party/gflags.cmake +++ 
b/cmake/third_party/gflags.cmake @@ -47,6 +47,7 @@ ExternalProject_Add(gflags -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS} -DGFLAGS_NAMESPACE:STRING=gflags -DCMAKE_INSTALL_PREFIX:STRING=${GFLAGS_INSTALL_DIR} + -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} ) endif(THIRD_PARTY) diff --git a/cmake/third_party/glog.cmake b/cmake/third_party/glog.cmake index 6361f03d867..c19f9e3e552 100644 --- a/cmake/third_party/glog.cmake +++ b/cmake/third_party/glog.cmake @@ -60,6 +60,7 @@ ExternalProject_Add(glog -DWITH_GFLAGS:BOOL=ON -Dgflags_ROOT:STRING=${GFLAGS_INSTALL_DIR} -DCMAKE_INSTALL_PREFIX:STRING=${GLOG_INSTALL_DIR} + -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} ) endif(THIRD_PARTY) diff --git a/cmake/third_party/googletest.cmake b/cmake/third_party/googletest.cmake index 9ff4dfa489e..b1eccd954e1 100644 --- a/cmake/third_party/googletest.cmake +++ b/cmake/third_party/googletest.cmake @@ -62,6 +62,7 @@ ExternalProject_Add(googletest -DCMAKE_INSTALL_INCLUDEDIR:STRING=${GTEST_INSTALL_INCLUDEDIR} -DCMAKE_INSTALL_LIBDIR:STRING=${GTEST_INSTALL_LIBDIR} -DCMAKE_INSTALL_BINDIR:STRING=${GTEST_INSTALL_BINDIR} + -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} #-Dgtest_force_shared_crt:BOOL=ON #default value is OFF ) diff --git a/cmake/third_party/grpc.cmake b/cmake/third_party/grpc.cmake index 11942d6ec0f..0a6c3528020 100644 --- a/cmake/third_party/grpc.cmake +++ b/cmake/third_party/grpc.cmake @@ -76,5 +76,6 @@ ExternalProject_Add(grpc -DgRPC_SSL_PROVIDER:STRING=package -DOpenSSL_ROOT:PATH=${OPENSSL_INSTALL} -DCMAKE_INSTALL_PREFIX:STRING=${GRPC_INSTALL_DIR} + -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} ) endif(THIRD_PARTY) diff --git a/cmake/third_party/json.cmake b/cmake/third_party/json.cmake index ee497029860..24118ee437c 100644 --- a/cmake/third_party/json.cmake +++ b/cmake/third_party/json.cmake @@ -1,36 +1,16 @@ -include(ExternalProject) +include(FetchContent) -SET(JSON_URL https://github.com/nlohmann/json/releases/download/v3.7.3/include.zip) -use_mirror(VARIABLE JSON_URL URL ${JSON_URL}) -SET(JSON_BASE_DIR ${CMAKE_CURRENT_BINARY_DIR}/json/src/json) -SET(JSON_INSTALL_DIR ${THIRD_PARTY_DIR}/json) -SET(JSON_INCLUDE_DIR ${JSON_INSTALL_DIR}/include CACHE PATH "" FORCE) -SET(JSON_URL_HASH fb96f95cdf609143e998db401ca4f324) -SET(JSON_HEADERS - "${JSON_BASE_DIR}/single_include/nlohmann/json.hpp" +set_mirror_url_with_hash(JSON_URL + https://github.com/nlohmann/json/archive/refs/tags/v3.10.4.zip + 59c2a25e17b94d612fdb32a1a37378cf ) +set(JSON_Install ON CACHE STRING "" FORCE) -if(THIRD_PARTY) - ExternalProject_Add(json - PREFIX json - URL ${JSON_URL} - URL_HASH MD5=${JSON_URL_HASH} - UPDATE_COMMAND "" - CONFIGURE_COMMAND "" - BUILD_COMMAND "" - BUILD_IN_SOURCE 1 - INSTALL_COMMAND "" - ) - add_custom_target(json_create_header_dir - COMMAND ${CMAKE_COMMAND} -E make_directory ${JSON_INCLUDE_DIR} - DEPENDS json - ) - add_custom_target(json_copy_headers_to_destination - DEPENDS json_create_header_dir - ) - foreach(header_file ${JSON_HEADERS}) - add_custom_command(TARGET json_copy_headers_to_destination PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${JSON_INCLUDE_DIR} - ) - endforeach() -endif(THIRD_PARTY) +FetchContent_Declare( + json + URL ${JSON_URL} + URL_HASH MD5=${JSON_URL_HASH} +) + + +FetchContent_MakeAvailable(json) diff --git a/cmake/third_party/oneDNN.cmake b/cmake/third_party/oneDNN.cmake index e56ed55bbf4..28c61045260 100644 --- a/cmake/third_party/oneDNN.cmake +++ b/cmake/third_party/oneDNN.cmake @@ -42,6 +42,7 @@ 
ExternalProject_Add(onednn BUILD_BYPRODUCTS ${ONEDNN_STATIC_LIBRARIES} CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:STRING=${ONEDNN_INSTALL_DIR} + -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DCMAKE_C_COMPILER_LAUNCHER:STRING=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_POLICY_DEFAULT_CMP0074:STRING=NEW diff --git a/cmake/third_party/opencv.cmake b/cmake/third_party/opencv.cmake index f4eff142e36..bbac4f5b539 100644 --- a/cmake/third_party/opencv.cmake +++ b/cmake/third_party/opencv.cmake @@ -46,7 +46,7 @@ else() endif() ExternalProject_Add(opencv - DEPENDS zlib libjpeg_copy_headers_to_destination libjpeg_copy_libs_to_destination + DEPENDS libjpeg_copy_headers_to_destination libjpeg_copy_libs_to_destination PREFIX opencv URL ${OPENCV_URL} URL_MD5 59870e55385f5202c1aa178fe37ed2de @@ -62,6 +62,7 @@ ExternalProject_Add(opencv -DCMAKE_POLICY_DEFAULT_CMP0074:STRING=NEW -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX:STRING=${OPENCV_INSTALL_DIR} + -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} @@ -143,6 +144,10 @@ ExternalProject_Add(opencv # -DLIB_SUFFIX:STRING=64 ) +if (WITH_ZLIB) + add_dependencies(opencv zlib) +endif() + # put opencv includes in the 'THIRD_PARTY_DIR' add_copy_headers_target(NAME opencv SRC ${OPENCV_BUILD_INCLUDE_DIR} DST ${OPENCV_INCLUDE_DIR} DEPS opencv INDEX_FILE "${oneflow_cmake_dir}/third_party/header_index/opencv_headers.txt") diff --git a/cmake/third_party/protobuf.cmake b/cmake/third_party/protobuf.cmake index 2278a8422a3..0cbbca7d8d9 100644 --- a/cmake/third_party/protobuf.cmake +++ b/cmake/third_party/protobuf.cmake @@ -45,7 +45,6 @@ if (THIRD_PARTY) ExternalProject_Add(protobuf PREFIX protobuf - DEPENDS zlib URL ${PROTOBUF_URL} URL_MD5 ${PROTOBUF_MD5} UPDATE_COMMAND "" @@ -61,6 +60,7 @@ ExternalProject_Add(protobuf -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DZLIB_ROOT:PATH=${ZLIB_INSTALL} + -Dprotobuf_WITH_ZLIB:BOOL=${WITH_ZLIB} -DCMAKE_CXX_FLAGS_DEBUG:STRING=${CMAKE_CXX_FLAGS_DEBUG} -DBUILD_SHARED_LIBS:BOOL=${PROTOBUF_BUILD_SHARED_LIBS} -Dprotobuf_BUILD_SHARED_LIBS:BOOL=${PROTOBUF_BUILD_SHARED_LIBS} @@ -69,9 +69,13 @@ ExternalProject_Add(protobuf -DCMAKE_INSTALL_INCLUDEDIR:STRING=${PROTOBUF_INSTALL_INCLUDEDIR} -DCMAKE_INSTALL_LIBDIR:STRING=${PROTOBUF_INSTALL_LIBDIR} -DCMAKE_INSTALL_BINDIR:STRING=${PROTOBUF_INSTALL_BINDIR} + -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -Dprotobuf_DEBUG_POSTFIX:STRING= ${PROTOBUF_ADDITIONAL_CMAKE_OPTIONS} ) +if (WITH_ZLIB) + add_dependencies(protobuf zlib) +endif() else() add_custom_target(protobuf) endif(THIRD_PARTY) diff --git a/cmake/third_party/re2.cmake b/cmake/third_party/re2.cmake index bb18730507b..11fd14f4048 100644 --- a/cmake/third_party/re2.cmake +++ b/cmake/third_party/re2.cmake @@ -28,6 +28,7 @@ if (THIRD_PARTY) -DCMAKE_CXX_COMPILER_LAUNCHER:STRING=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_INSTALL_PREFIX:PATH=${RE2_INSTALL_DIR} -DCMAKE_INSTALL_LIBDIR:PATH=${RE2_LIBRARY_DIR} + -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DRE2_BUILD_TESTING:BOOL=OFF -DCMAKE_BUILD_TYPE:STRING=${CMAKE_BUILD_TYPE}) diff --git a/cmake/third_party/zlib.cmake b/cmake/third_party/zlib.cmake index 596c36a4ca9..e666f10c00c 100644 --- a/cmake/third_party/zlib.cmake +++ b/cmake/third_party/zlib.cmake @@ -50,6 +50,7 @@ 
ExternalProject_Add(zlib -DCMAKE_CXX_FLAGS_RELEASE:STRING=${CMAKE_CXX_FLAGS_RELEASE} -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL} + -DCMAKE_INSTALL_MESSAGE:STRING=${CMAKE_INSTALL_MESSAGE} ) diff --git a/cmake/util.cmake b/cmake/util.cmake index 4c6feb2c374..bb0192fd7f8 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -102,41 +102,41 @@ function(add_copy_headers_target) endfunction() function(use_mirror) - cmake_parse_arguments( - PARSED_ARGS - "" - "VARIABLE;URL" - "" - ${ARGN} + set(ALIYUN_URL_PREFIX "https://oneflow-static.oss-cn-beijing.aliyuncs.com/third_party_mirror/https/" + CACHE STRING "URL prefix of Aliyun OSS mirror" + ) + cmake_parse_arguments(PARSED_ARGS + "" "VARIABLE;URL" "" ${ARGN} ) - if(NOT PARSED_ARGS_VARIABLE) - message(FATAL_ERROR "VARIABLE required") - endif(NOT PARSED_ARGS_VARIABLE) - if(NOT PARSED_ARGS_URL) - message(FATAL_ERROR "url required") - endif(NOT PARSED_ARGS_URL) - set(UTIL_PYTHON_EXECUTABLE "python3" CACHE STRING "Python executable to run util") - if(Python3_EXECUTABLE) - set(UTIL_PYTHON_EXECUTABLE ${Python3_EXECUTABLE}) - endif(Python3_EXECUTABLE) + + if((NOT PARSED_ARGS_VARIABLE) OR (NOT PARSED_ARGS_URL)) + message(FATAL_ERROR "VARIABLE or URL required") + endif() + if(DEFINED THIRD_PARTY_MIRROR) if(THIRD_PARTY_MIRROR STREQUAL "aliyun") - execute_process( - COMMAND ${UTIL_PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/tools/package_mirror.py -u ${PARSED_ARGS_URL} - OUTPUT_VARIABLE temp_url - ERROR_VARIABLE err - RESULT_VARIABLE ret_code) - if (NOT (ret_code EQUAL "0")) - message(FATAL_ERROR "Fail to convert mirror url ${CMAKE_CURRENT_SOURCE_DIR}/tools/package_mirror.py. URL: ${PARSED_ARGS_URL}. Error: ${err}. Output: ${temp_url}") - else() - set(${PARSED_ARGS_VARIABLE} ${temp_url} PARENT_SCOPE) + if(NOT PARSED_ARGS_URL MATCHES "^https://") + message(FATAL_ERROR "URL should start with 'https://'") endif() + string(REPLACE "https://" ${ALIYUN_URL_PREFIX} MIRRORED_URL ${PARSED_ARGS_URL}) + set(${PARSED_ARGS_VARIABLE} ${MIRRORED_URL} PARENT_SCOPE) + message(NOTICE "-- fetch ${PARSED_ARGS_VARIABLE} using aliyun mirror ${MIRRORED_URL}") elseif(NOT THIRD_PARTY_MIRROR STREQUAL "") - message(FATAL_ERROR "Invalid key for third party mirror.") + message(FATAL_ERROR "invalid key for third party mirror") endif() endif() endfunction() +macro(set_mirror_url variable url) + set(${variable} ${url} ${ARGN}) + use_mirror(VARIABLE ${variable} URL ${url}) +endmacro() + +macro(set_mirror_url_with_hash variable url hash) + set_mirror_url(${variable} ${url} ${ARGN}) + set(${variable}_HASH ${hash} ${ARGN}) +endmacro() + function(check_cxx11_abi OUTPUT_VAR) execute_process( COMMAND ${CMAKE_COMMAND} -E echo "#include \n void test(std::string){}\n int main(){}" OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/temp.cpp) diff --git a/docs/source/comm.rst b/docs/source/comm.rst index 4870e1a7c1a..cdfc227c929 100644 --- a/docs/source/comm.rst +++ b/docs/source/comm.rst @@ -8,8 +8,10 @@ oneflow communication function all_gather, broadcast, scatter, + all_to_all, reduce, gather, reduce_scatter, send, recv, + barrier, diff --git a/docs/source/env.rst b/docs/source/env.rst new file mode 100644 index 00000000000..3425a329a20 --- /dev/null +++ b/docs/source/env.rst @@ -0,0 +1,11 @@ +oneflow.env +=================================== +Environment +---------------------------------- +.. currentmodule:: oneflow + +.. autofunction:: oneflow.env.get_world_size +.. autofunction:: oneflow.env.get_rank +.. autofunction:: oneflow.env.get_local_rank +.. 
autofunction:: oneflow.env.get_node_size +.. autofunction:: oneflow.env.is_multi_client diff --git a/docs/source/functional.rst b/docs/source/functional.rst index d165621f55f..23b21728d8b 100644 --- a/docs/source/functional.rst +++ b/docs/source/functional.rst @@ -13,6 +13,7 @@ Functional operations for neural networks .. autofunction:: hardsigmoid .. autofunction:: hardswish .. autofunction:: hardtanh +.. autofunction:: normalize .. autofunction:: l2_normalize .. autofunction:: leaky_relu .. autofunction:: elu @@ -22,6 +23,7 @@ Functional operations for neural networks .. autofunction:: pad .. autofunction:: prelu .. autofunction:: logsigmoid +.. autofunction:: log_softmax .. autofunction:: gelu .. autofunction:: glu .. autofunction:: softsign diff --git a/docs/source/index.rst b/docs/source/index.rst index 78a2febf6f1..13f8917bc5c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,6 +23,7 @@ OneFlow API Reference optim utils cuda + env distributed comm placement diff --git a/docs/source/oneflow.rst b/docs/source/oneflow.rst index e378708d9b8..a1bbb9dd0a5 100644 --- a/docs/source/oneflow.rst +++ b/docs/source/oneflow.rst @@ -42,6 +42,7 @@ oneflow cos, cosh, diag, + diagonal, movedim, div, dot, @@ -64,6 +65,7 @@ oneflow gt, in_top_k, index_select, + linspace, logical_and, logical_or, logical_not, @@ -80,7 +82,7 @@ oneflow mean, mish, min, - meshgrid, + meshgrid, mul, neg, negative, @@ -112,6 +114,7 @@ oneflow selu, silu, slice, + logical_slice, slice_update, softsign, sort, @@ -149,6 +152,7 @@ oneflow set_printoptions, decode_onerec, read_onerec, + from_numpy, + cumsum, .. autofunction:: oneflow.relu -.. autofunction:: oneflow.env.get_rank diff --git a/docs/source/tensor.rst b/docs/source/tensor.rst index f494ce3ecfc..9fd114adee5 100644 --- a/docs/source/tensor.rst +++ b/docs/source/tensor.rst @@ -42,6 +42,7 @@ OneFlow Tensor Class detach, device, diag, + diagonal, dim, div, double, diff --git a/oneflow/api/common/device.cpp b/oneflow/api/common/device.cpp deleted file mode 100644 index 566f9231066..00000000000 --- a/oneflow/api/common/device.cpp +++ /dev/null @@ -1,52 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -#include "oneflow/api/common/device.h" - -namespace oneflow { - -namespace { - -void CheckDeviceType(const std::string& type) { - if (Device::type_supported.find(type) == Device::type_supported.end()) { - std::string error_msg = - "Expected one of cpu, cuda device type at start of device string " + type; - throw std::runtime_error(error_msg); - } -} - -} // namespace - -/* static */ Maybe> DeviceExportUtil::ParseAndNew( - const std::string& type_or_type_with_device_id) { - std::string type; - int device_id = -1; - ParsingDeviceTag(type_or_type_with_device_id, &type, &device_id).GetOrThrow(); - CheckDeviceType(type); - if (device_id == -1) { - return Device::New(type); - } else { - return Device::New(type, device_id); - } -} - -/* static */ Maybe> DeviceExportUtil::New(const std::string& type, - int64_t device_id) { - CheckDeviceType(type); - return Device::New(type, device_id); -} - -} // namespace oneflow diff --git a/oneflow/api/common/device.h b/oneflow/api/common/ir_pass.cpp similarity index 63% rename from oneflow/api/common/device.h rename to oneflow/api/common/ir_pass.cpp index 8703666af7f..ca4111a9da0 100644 --- a/oneflow/api/common/device.h +++ b/oneflow/api/common/ir_pass.cpp @@ -14,17 +14,17 @@ See the License for the specific language governing permissions and limitations under the License. */ -#ifndef ONEFLOW_API_COMMON_DEVICE_H_ -#define ONEFLOW_API_COMMON_DEVICE_H_ +#ifdef WITH_MLIR -#include "oneflow/core/framework/device.h" +#include "oneflow/ir/include/OneFlow/Extension.h" +#include "oneflow/ir/oneflow-extension/include/OneFlow/OneFlowRoundTrip.h" +#include namespace oneflow { -struct DeviceExportUtil final { - static Maybe> ParseAndNew(const std::string& type_or_type_with_device_id); - static Maybe> New(const std::string& type, int64_t device_id); -}; +REGISTER_JOB_PASS("IRRoundTripBeforeAD", IRRoundTrip); +REGISTER_JOB_PASS("IRRoundTrip", IRRoundTrip); + } // namespace oneflow -#endif // !ONEFLOW_API_COMMON_DEVICE_H_ +#endif // WITH_MLIR diff --git a/oneflow/api/common/job_build_and_infer_ctx.h b/oneflow/api/common/job_build_and_infer_ctx.h new file mode 100644 index 00000000000..8b475f8a2db --- /dev/null +++ b/oneflow/api/common/job_build_and_infer_ctx.h @@ -0,0 +1,36 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#ifndef ONEFLOW_API_COMMON_JOB_BUILD_AND_INFER_CTX_H_ +#define ONEFLOW_API_COMMON_JOB_BUILD_AND_INFER_CTX_H_ + +#include "oneflow/core/job/job.pb.h" +#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" + +namespace oneflow { + +inline Maybe GetCurrentJob() { + auto* job_ctx_mgr = Global::Get(); + CHECK_NOTNULL_OR_RETURN(job_ctx_mgr); + auto* job_ctx = + JUST(job_ctx_mgr->FindJobBuildAndInferCtx(*JUST(job_ctx_mgr->GetCurrentJobName()))); + CHECK_NOTNULL_OR_RETURN(job_ctx); + return job_ctx->job(); +} + +} // namespace oneflow + +#endif // ONEFLOW_API_COMMON_JOB_BUILD_AND_INFER_CTX_H_ diff --git a/oneflow/api/common/sbp.h b/oneflow/api/common/sbp.h new file mode 100644 index 00000000000..013f24924e1 --- /dev/null +++ b/oneflow/api/common/sbp.h @@ -0,0 +1,58 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef ONEFLOW_API_COMMON_SBP_H_ +#define ONEFLOW_API_COMMON_SBP_H_ + +#include "oneflow/core/job/sbp_parallel.pb.h" +#include "oneflow/core/job/sbp_parallel.cfg.h" +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/common/maybe.h" + +namespace oneflow { + +namespace api { + +inline Maybe SbpToString(Symbol sbp_sym) { + std::string sbp_str = "oneflow.sbp."; + if (sbp_sym->has_broadcast_parallel()) { + sbp_str += "broadcast"; + } else if (sbp_sym->has_partial_sum_parallel()) { + sbp_str += "partial_sum"; + } else if (sbp_sym->has_split_parallel()) { + sbp_str += "split(axis=" + std::to_string(sbp_sym->split_parallel().axis()) + ")"; + } else { + UNIMPLEMENTED_THEN_RETURN(); + } + return sbp_str; +} + +inline Maybe NdSbpToString(Symbol nd_sbp) { + std::string str = "("; + for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) { + if (i > 0) { str += ", "; } + str += *JUST(SbpToString(SymbolOf(nd_sbp->sbp_parallel(i)))); + } + if (nd_sbp->sbp_parallel_size() == 1) { str += ","; } + str += ")"; + return str; +} + +} // namespace api + +} // namespace oneflow + +#endif // !ONEFLOW_API_COMMON_SBP_H_ diff --git a/oneflow/api/common/scope.h b/oneflow/api/common/scope.h new file mode 100644 index 00000000000..f0626e3ada1 --- /dev/null +++ b/oneflow/api/common/scope.h @@ -0,0 +1,54 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#ifndef ONEFLOW_API_COMMON_SCOPE_H_ +#define ONEFLOW_API_COMMON_SCOPE_H_ + +#include +#include +#include "oneflow/core/common/just.h" +#include "oneflow/core/framework/device.h" +#include "oneflow/core/framework/instructions_builder.h" +#include "oneflow/core/framework/session_util.h" +#include "oneflow/core/job/job_conf.cfg.h" +#include "oneflow/core/job/job_conf.pb.h" +#include "oneflow/core/job/scope.h" + +namespace oneflow { + +inline Maybe MakeScope(const JobConfigProto& config_proto, const Device& device) { + std::shared_ptr scope; + std::shared_ptr cfg_config_proto = + std::make_shared(config_proto); + JUST(LogicalRun([&](InstructionsBuilder* builder) -> Maybe { + int64_t session_id = 0; + std::string device_tag = "cpu"; + std::string machine_ids = "0"; + std::string device_ids = "0"; + if (device.type() == "cuda") { + device_tag = "gpu"; + device_ids = std::to_string(device.device_id()); + } + scope = JUST(builder->BuildInitialScope(session_id, cfg_config_proto, device_tag, + {machine_ids + ":" + device_ids}, nullptr, false)); + return Maybe::Ok(); + })); + return scope; +} + +} // namespace oneflow + +#endif // ONEFLOW_API_COMMON_SCOPE_H_ diff --git a/oneflow/api/cpp/env.cpp b/oneflow/api/cpp/env.cpp index 6919948af02..ea31af15a69 100644 --- a/oneflow/api/cpp/env.cpp +++ b/oneflow/api/cpp/env.cpp @@ -24,14 +24,17 @@ limitations under the License. #include #include #include "oneflow/api/cpp/env.h" +#include "oneflow/core/common/global.h" #include "oneflow/core/common/just.h" #include "oneflow/core/common/multi_client.h" #include "oneflow/core/common/optional.h" +#include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/framework/shut_down_util.h" #include "oneflow/core/job/cluster_instruction.h" #include "oneflow/core/job/env.pb.h" #include "oneflow/core/job/env_global_objects_scope.h" #include "oneflow/core/control/ctrl_bootstrap.h" +#include "oneflow/core/job/session.h" #include "oneflow/core/rpc/include/base.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/thread/thread_consistent_id.h" @@ -114,6 +117,12 @@ of::Maybe initEnv() { CompleteEnvProto(env_proto); of::Global::SetAllocated(new of::EnvGlobalObjectsScope()); JUST(of::Global::Get()->Init(env_proto)); + + of::ConfigProto config_proto; + config_proto.mutable_resource()->set_cpu_device_num(1); // useless, will be set in TryInit + config_proto.set_session_id(of::NewSessionId()); + of::Global::New(); + of::Global::Get()->TryInit(config_proto).GetOrThrow(); return of::Maybe::Ok(); } @@ -129,6 +138,8 @@ void release() { if (IsEnvInited()) { // sync multi_client of::vm::ClusterSync().GetOrThrow(); + of::Global::Get()->TryClose().GetOrThrow(); + of::Global::Delete(); // destory env if (of::IsMultiClient().GetOrThrow()) { OF_ENV_BARRIER(); @@ -137,7 +148,6 @@ void release() { } of::Global::Delete(); } - // TODO close session of::SetShuttingDown(); of::ResetThisThreadUniqueConsistentId().GetOrThrow(); } diff --git a/oneflow/api/cpp/framework.h b/oneflow/api/cpp/framework.h index efe1ac38b9a..5d05fb65442 100644 --- a/oneflow/api/cpp/framework.h +++ b/oneflow/api/cpp/framework.h @@ -21,5 +21,7 @@ limitations under the License. 
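The env.cpp change above couples a multi-client session to the environment lifetime: initEnv() now also creates the session context and initializes it with a fresh session id, and release() closes that session before the environment itself is destroyed. A rough sketch of the resulting order, using only the calls visible in the diff (template arguments inferred):

    // initEnv():
    //   Global<EnvGlobalObjectsScope>::Get()->Init(env_proto);
    //   Global<MultiClientSessionContext>::New();
    //   Global<MultiClientSessionContext>::Get()->TryInit(config_proto);  // session_id = NewSessionId()
    // release():
    //   vm::ClusterSync();
    //   Global<MultiClientSessionContext>::Get()->TryClose();
    //   Global<MultiClientSessionContext>::Delete();
    //   ... env barrier and Global<EnvGlobalObjectsScope>::Delete(), as before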
#include "framework/shape.h" #include "framework/dtype.h" #include "framework/tensor.h" +#include "framework/ivalue.h" +#include "framework/graph.h" -#endif // !ONEFLOW_API_CPP_FRAMEWORK_H_ +#endif // ONEFLOW_API_CPP_FRAMEWORK_H_ diff --git a/oneflow/api/cpp/framework/device.cpp b/oneflow/api/cpp/framework/device.cpp index dd291de0353..3be4b3f5a46 100644 --- a/oneflow/api/cpp/framework/device.cpp +++ b/oneflow/api/cpp/framework/device.cpp @@ -15,9 +15,9 @@ limitations under the License. */ #include "oneflow/api/cpp/framework/device.h" -#include "oneflow/api/common/device.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/symbol.h" +#include "oneflow/core/framework/device.h" namespace oneflow_api { @@ -25,11 +25,11 @@ namespace of = oneflow; Device::Device(const std::string& type_or_type_with_device_id) : device_(std::make_shared>( - of::DeviceExportUtil::ParseAndNew(type_or_type_with_device_id).GetOrThrow())) {} + of::Device::ParseAndNew(type_or_type_with_device_id).GetOrThrow())) {} Device::Device(const std::string& type, int64_t device_id) - : device_(std::make_shared>( - of::DeviceExportUtil::New(type, device_id).GetOrThrow())) {} + : device_( + std::make_shared>(of::Device::New(type, device_id).GetOrThrow())) {} const std::string& Device::type() const { return (*device_)->type(); } diff --git a/oneflow/api/cpp/framework/device.h b/oneflow/api/cpp/framework/device.h index ab9d3d9175e..2a7e79b2a23 100644 --- a/oneflow/api/cpp/framework/device.h +++ b/oneflow/api/cpp/framework/device.h @@ -32,15 +32,16 @@ namespace oneflow_api { class Device final { friend class Tensor; + friend class Graph; public: explicit Device(const std::string& type_or_type_with_device_id); explicit Device(const std::string& type, int64_t device_id); - const std::string& type() const; - int64_t device_id() const; + [[nodiscard]] const std::string& type() const; + [[nodiscard]] int64_t device_id() const; - bool operator==(const Device& rhs) const; - bool operator!=(const Device& rhs) const; + [[nodiscard]] bool operator==(const Device& rhs) const; + [[nodiscard]] bool operator!=(const Device& rhs) const; private: std::shared_ptr> device_ = nullptr; diff --git a/oneflow/api/cpp/framework/dtype.h b/oneflow/api/cpp/framework/dtype.h index f528bd82b16..217a70d928e 100644 --- a/oneflow/api/cpp/framework/dtype.h +++ b/oneflow/api/cpp/framework/dtype.h @@ -36,8 +36,8 @@ enum class DType { kMaxDataType = 12 }; -int32_t GetDTypeSize(DType dtype); +[[nodiscard]] int32_t GetDTypeSize(DType dtype); } // namespace oneflow_api -#endif // !ONEFLOW_API_CPP_FRAMEWORK_DTYPE_H_ +#endif // ONEFLOW_API_CPP_FRAMEWORK_DTYPE_H_ diff --git a/oneflow/api/cpp/framework/graph.cpp b/oneflow/api/cpp/framework/graph.cpp new file mode 100644 index 00000000000..1a3f14d55d6 --- /dev/null +++ b/oneflow/api/cpp/framework/graph.cpp @@ -0,0 +1,395 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include "oneflow/api/common/ofblob.h" +#include "oneflow/api/common/scope.h" +#include "oneflow/api/cpp/framework/device.h" +#include "oneflow/api/cpp/framework/graph.h" +#include "oneflow/api/cpp/framework/ivalue.h" +#include "oneflow/api/cpp/framework/shape.h" +#include "oneflow/api/cpp/framework/tensor.h" +#include "oneflow/api/common/job_build_and_infer_ctx.h" +#include "oneflow/api/python/job_build/job_build_and_infer.h" +#include "oneflow/core/common/data_type.pb.h" +#include "oneflow/core/common/global.h" +#include "oneflow/core/common/hash_container.h" +#include "oneflow/core/common/just.h" +#include "oneflow/core/common/shape.h" +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/framework/device.h" +#include "oneflow/core/framework/dtype.h" +#include "oneflow/core/framework/multi_client_session_context.h" +#include "oneflow/core/framework/nn_graph.h" +#include "oneflow/core/framework/scope_util.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/framework/tensor_util.h" +#include "oneflow/core/functional/functional_api.yaml.h" +#include "oneflow/core/graph/op_graph.h" +#include "oneflow/core/job/job.pb.h" +#include "oneflow/core/job/job_build_and_infer_ctx.h" +#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" +#include "oneflow/core/job/job_conf.cfg.h" +#include "oneflow/core/job/job_conf.pb.h" +#include "oneflow/core/job/job_ir.h" +#include "oneflow/core/job/job_set.pb.h" +#include "oneflow/core/job/lazy_mode.h" +#include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/job/scope.h" +#include "oneflow/core/job/session.h" +#include "oneflow/core/operator/interface_blob_conf.pb.h" +#include "oneflow/core/operator/op_conf.pb.h" +#include "oneflow/core/register/logical_blob_id.pb.h" + +namespace oneflow_api { + +namespace of = oneflow; + +enum class XrtKind : int { kNone = 0, kTensorRT = 1 }; + +namespace { + +class CompileScope { + public: + CompileScope(const of::JobConfigProto& job_config, const of::Device& device, XrtKind kind) { + const std::shared_ptr scope = CHECK_JUST(of::MakeScope(job_config, device)); + CHECK_JUST(of::ThreadLocalScopeStackPush(scope)); + + of::cfg::JobConfigProto job_config_cfg(job_config); + ConfigXrt(job_config_cfg, kind); + CHECK_JUST(of::JobBuildAndInferCtx_Open(job_config.job_name())); + CHECK_JUST(of::CurJobBuildAndInferCtx_SetJobConf(job_config_cfg)); + } + + ~CompileScope() { + CHECK_JUST(of::JobBuildAndInferCtx_Close()); + CHECK_JUST(of::ThreadLocalScopeStackPop()); + } + + private: + of::LazyMode::Guard lazy_mode_enabled_guard{true}; + + void ConfigXrt(of::cfg::JobConfigProto& job_config_cfg, XrtKind kind) { + if (kind == XrtKind::kTensorRT) { +#ifdef WITH_TENSORRT + *(job_config_cfg.mutable_xrt_config()->mutable_use_tensorrt()) = true; +#else + LOG(WARNING) << "XRT TensorRT is unavailable while tensorrt is enabled"; +#endif + } + } +}; + +std::shared_ptr ConvertToTensorTuple( + const std::vector>& tensors) { + auto tensor_tuple = std::make_shared(); + for (const auto& tensor : tensors) { tensor_tuple->emplace_back(tensor); } + return tensor_tuple; +} + +std::string GetDeviceTag(const Device& device) { + if (device.type() == "cuda") { + return "gpu"; + } else { + return "cpu"; + } +} + +template +const std::pair, std::vector> Unzip(const of::HashMap& hash_map) { + std::vector vec1; + std::vector vec2; + for (const auto& entry : hash_map) { + vec1.emplace_back(entry.first); + 
vec2.emplace_back(entry.second); + } + return std::make_pair(vec1, vec2); +} + +} // namespace + +class Graph::GraphImpl final { + public: + explicit GraphImpl(const std::string& model_path, const Device& device = Device("cpu")); + + GraphImpl(const GraphImpl& graph) = delete; + GraphImpl(GraphImpl&& graph) noexcept; + + ~GraphImpl() = default; + + GraphImpl& operator=(const GraphImpl& graph) = delete; + GraphImpl& operator=(GraphImpl&& graph) noexcept; + + std::vector Forward(const std::vector& inputs); + void set_batch_size(int batch_size) { batch_size_ = batch_size; } + void enable_tensorrt() { xrt_kind_ = XrtKind::kTensorRT; } + + private: + oneflow::Maybe Compile(const std::vector& inputs); + oneflow::Maybe> Run(const std::vector& inputs) const; + oneflow::Maybe AddOp(oneflow::OperatorConf op_conf); + oneflow::Maybe BuildGraph(const std::vector& inputs); + oneflow::Maybe LoadCheckpoint(); + oneflow::Maybe RegisterTensors(const std::vector& inputs); + + std::shared_ptr graph_ = nullptr; + std::string model_path_; + bool is_compiled_ = false; + int batch_size_ = 0; + XrtKind xrt_kind_ = XrtKind::kNone; + Device device_; + oneflow::Job job_; + + oneflow::HashMap input_name_to_order_; + oneflow::HashMap> output_name_to_tensor_; + oneflow::HashMap> variable_op_name_to_tensor_; + std::shared_ptr output_tensor_tuple_; + std::shared_ptr parameter_tensor_tuple_; +}; + +Graph::Graph(const std::string& model_path, const Device& device) + : graph_(std::make_unique(model_path, device)) {} + +Graph::~Graph() = default; + +Graph::Graph(Graph&& graph) noexcept : graph_(std::move(graph.graph_)) {} + +Graph& Graph::operator=(Graph&& graph) noexcept { + if (&graph == this) { return *this; } + graph_ = std::move(graph.graph_); + return *this; +} + +IValue Graph::Forward(const IValue& inputs) { + std::vector input_tensors; + if (inputs.IsNone()) { + // do nothing + } else if (inputs.IsTensor()) { + input_tensors.emplace_back(inputs.ToTensor()); + } else if (inputs.IsTensorVector()) { + input_tensors = inputs.ToTensorVector(); + } else { + LOG(WARNING) << "Graph currently only support types: Tensor/vector(Tensor)/None"; + } + + std::vector output_tensors = graph_->Forward(input_tensors); + if (output_tensors.empty()) { + return IValue{}; + } else if (output_tensors.size() == 1) { + return IValue(output_tensors.at(0)); + } else { + return IValue(output_tensors); + } +} + +void Graph::set_batch_size(int batch_size) { graph_->set_batch_size(batch_size); } + +void Graph::enable_tensorrt() { graph_->enable_tensorrt(); } + +Graph Graph::Load(const std::string& model_path, const Device& device) { + Graph graph(model_path, device); + return graph; +} + +Graph::GraphImpl::GraphImpl(const std::string& model_path, const Device& device) + : model_path_(model_path), device_(device) { + CHECK_JUST(of::LoadJobFromIR(&job_, model_path + "/model.mlir")); + job_.mutable_job_conf()->mutable_predict_conf(); + job_.mutable_job_conf()->set_job_name(job_.mutable_job_conf()->job_name() + of::NewUniqueId()); + graph_ = std::make_shared(job_.job_conf().job_name()); + of::Global::Get()->AddCGraph(graph_).GetOrThrow(); +} + +Graph::GraphImpl::GraphImpl(GraphImpl&& graph) noexcept + : graph_(std::move(graph.graph_)), + model_path_(std::move(graph.model_path_)), + is_compiled_(graph.is_compiled_), + batch_size_(graph.batch_size_), + xrt_kind_(graph.xrt_kind_), + device_(std::move(graph.device_)), + job_(std::move(graph.job_)), + input_name_to_order_(std::move(graph.input_name_to_order_)), + 
output_name_to_tensor_(std::move(graph.output_name_to_tensor_)), + variable_op_name_to_tensor_(std::move(graph.variable_op_name_to_tensor_)), + output_tensor_tuple_(std::move(graph.output_tensor_tuple_)), + parameter_tensor_tuple_(std::move(graph.parameter_tensor_tuple_)) {} + +Graph::GraphImpl& Graph::GraphImpl::operator=(Graph::GraphImpl&& graph) noexcept { + if (&graph == this) { return *this; } + graph_ = std::move(graph.graph_); + model_path_ = std::move(graph.model_path_); + is_compiled_ = graph.is_compiled_; + batch_size_ = graph.batch_size_; + xrt_kind_ = graph.xrt_kind_; + device_ = std::move(graph.device_); + job_ = std::move(graph.job_); + input_name_to_order_ = std::move(graph.input_name_to_order_); + output_name_to_tensor_ = std::move(graph.output_name_to_tensor_); + variable_op_name_to_tensor_ = std::move(graph.variable_op_name_to_tensor_); + output_tensor_tuple_ = std::move(graph.output_tensor_tuple_); + parameter_tensor_tuple_ = std::move(graph.parameter_tensor_tuple_); + return *this; +} + +std::vector Graph::GraphImpl::Forward(const std::vector& inputs) { + if (!is_compiled_) { + static std::mutex mtx; + std::lock_guard lock(mtx); + Compile(inputs).GetOrThrow(); + is_compiled_ = true; + } + return Run(inputs).GetOrThrow(); +} + +of::Maybe Graph::GraphImpl::Compile(const std::vector& inputs) { + JUST(BuildGraph(inputs)); + JUST(LoadCheckpoint()); + JUST(RegisterTensors(inputs)); + JUST(graph_->CompileAndInitRuntime()); + return of::Maybe::Ok(); +} + +of::Maybe> Graph::GraphImpl::Run(const std::vector& inputs) const { + const auto input_tensor_tuple = std::make_shared(); + for (const auto& tensor : inputs) { input_tensor_tuple->emplace_back(tensor.tensor_); } + + JUST(of::RunLazyNNGraph(*input_tensor_tuple, *output_tensor_tuple_, *parameter_tensor_tuple_, + graph_)); + JUST(of::SoftSyncNNGraphBuffers(*output_tensor_tuple_, graph_)); + + std::vector outputs; + for (const auto& tensor : *output_tensor_tuple_) { outputs.emplace_back(Tensor(tensor)); } + return outputs; +} + +of::Maybe Graph::GraphImpl::AddOp(of::OperatorConf op_conf) { + { + const std::shared_ptr scope = JUST(of::GetCurrentScope()); + op_conf.set_scope_symbol_id(scope->symbol_id().value_or(0)); + } + op_conf.set_device_tag(GetDeviceTag(device_)); + if (batch_size_ > 0 && op_conf.has_input_conf()) { + op_conf.mutable_input_conf()->mutable_blob_conf()->mutable_shape()->mutable_dim()->Set( + 0, batch_size_); + } + auto* ctx = JUST(of::GetCurInferCtx()); + JUST(ctx->AddAndInferConsistentOp(op_conf)); + return of::Maybe::Ok(); +} + +of::Maybe Graph::GraphImpl::BuildGraph(const std::vector& inputs) { + CompileScope build_graph_scope(job_.job_conf(), *device_.device_->shared_from_symbol(), + xrt_kind_); + { + int input_tensor_order = 0; + const of::OpGraph op_graph(job_); + op_graph.TopoForEachNode([&](const of::OpNode* node) -> of::Maybe { + const of::OperatorConf& op_conf = node->op().op_conf(); + JUST(AddOp(op_conf)); + if (op_conf.has_input_conf()) { + input_name_to_order_[op_conf.name()] = input_tensor_order; + input_tensor_order += 1; + } else if (op_conf.has_variable_conf()) { + const of::LazyMode::Guard lazy_mode_disabled_guard{false}; + const of::VariableOpConf& variable_conf = op_conf.variable_conf(); + variable_op_name_to_tensor_[op_conf.name()] = JUST(of::one::functional::Empty( + of::Shape(variable_conf.shape()), + JUST(of::DType::Get(static_cast(variable_conf.data_type()))), + *device_.device_)); + } + return of::Maybe::Ok(); + }); + } + JUST(of::CurJobBuildAndInferCtx_Complete()); + { + const 
std::shared_ptr complete_job = JUST(of::GetCurrentJob()); + const of::OpGraph complete_graph(*complete_job); + complete_graph.TopoForEachNode([&](const of::OpNode* node) -> of::Maybe { + const of::LazyMode::Guard lazy_mode_disabled_guard{false}; + const of::OperatorConf& op_conf = node->op().op_conf(); + if (op_conf.has_output_conf()) { + of::InterfaceBlobConf blob_conf = op_conf.output_conf().blob_conf(); + if (batch_size_ > 0) { + const std::string input_lbi_str = op_conf.output_conf().in(); + const of::LogicalBlobId input_lbi = of::GenLogicalBlobId(input_lbi_str); + int64_t batch_size = node->LogicalBlobDesc4Lbi(input_lbi).shape().At(0); + blob_conf.mutable_shape()->set_dim(0, batch_size); + } + output_name_to_tensor_[op_conf.name()] = JUST(of::one::functional::Empty( + of::Shape(blob_conf.shape()), + JUST(of::DType::Get(static_cast(blob_conf.data_type()))), + *device_.device_)); + } + return of::Maybe::Ok(); + }); + } + return of::Maybe::Ok(); +} + +of::Maybe Graph::GraphImpl::LoadCheckpoint() { + for (const auto& variable_op_name_and_tensor : variable_op_name_to_tensor_) { + const auto& variable_op_name = variable_op_name_and_tensor.first; + const auto& variable_tensor = variable_op_name_and_tensor.second; + const std::string variable_filename = model_path_ + "/" + variable_op_name + "/out"; + const std::string buffer = [&]() { + std::ifstream variable_file(variable_filename, std::ios::binary); + CHECK(variable_file.is_open()); + std::stringstream ss; + ss << variable_file.rdbuf(); + return ss.str(); + }(); + const auto& callback = + std::make_shared>([&](uint64_t of_blob_ptr) { + CHECK_JUST(of::BlobBufferCopyUtil::From( + of_blob_ptr, buffer.data(), + variable_tensor->shape()->elem_cnt() + * of::GetSizeOfDataType(variable_tensor->dtype()->data_type()))); + }); + JUST(of::one::SyncAccessTensorWithTimeOut(variable_tensor, callback, "mut")); + } + + return of::Maybe::Ok(); +} + +of::Maybe Graph::GraphImpl::RegisterTensors(const std::vector& inputs) { + { + std::vector input_op_names(inputs.size()); + std::vector> input_tensors(inputs.size()); + for (const auto& name_order : input_name_to_order_) { + input_op_names[name_order.second] = name_order.first; + input_tensors[name_order.second] = inputs.at(name_order.second).tensor_; + } + JUST(graph_->RegisterInputOpNamesAndTensors(input_op_names, input_tensors)); + } + { + const auto& pair = Unzip(output_name_to_tensor_); + const std::vector& output_op_names = pair.first; + const std::vector>& output_tensors = pair.second; + JUST(graph_->RegisterOutputOpNamesAndTensors(output_op_names, output_tensors)); + output_tensor_tuple_ = ConvertToTensorTuple(output_tensors); + } + { + const auto& pair = Unzip(variable_op_name_to_tensor_); + const std::vector& variable_op_names = pair.first; + const std::vector>& variable_tensors = pair.second; + JUST(graph_->RegisterVariableOpNamesAndTensors(variable_op_names, variable_tensors)); + parameter_tensor_tuple_ = ConvertToTensorTuple(variable_tensors); + } + return of::Maybe::Ok(); +} + +} // namespace oneflow_api diff --git a/oneflow/api/cpp/framework/graph.h b/oneflow/api/cpp/framework/graph.h new file mode 100644 index 00000000000..c2f690b642d --- /dev/null +++ b/oneflow/api/cpp/framework/graph.h @@ -0,0 +1,56 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef ONEFLOW_API_CPP_GRAPH_H_ +#define ONEFLOW_API_CPP_GRAPH_H_ + +#include "device.h" +#include "ivalue.h" +#include "tensor.h" + +namespace oneflow { + +class NNGraph; + +} // namespace oneflow + +namespace oneflow_api { + +class Graph { + public: + explicit Graph(const std::string& model_path, const Device& device = Device("cpu")); + ~Graph(); + + Graph(const Graph& graph) = delete; + Graph(Graph&& graph) noexcept; + + Graph& operator=(const Graph& graph) = delete; + Graph& operator=(Graph&& graph) noexcept; + + IValue Forward(const IValue& inputs); + void set_batch_size(int batch_size); + void enable_tensorrt(); + + static Graph Load(const std::string& model_path, const Device& device = Device("cpu")); + + private: + class GraphImpl; + std::unique_ptr graph_; +}; + +} // namespace oneflow_api + +#endif // ONEFLOW_API_CPP_GRAPH_H_ diff --git a/oneflow/api/cpp/framework/ivalue.cpp b/oneflow/api/cpp/framework/ivalue.cpp new file mode 100644 index 00000000000..638bee4f124 --- /dev/null +++ b/oneflow/api/cpp/framework/ivalue.cpp @@ -0,0 +1,53 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/api/cpp/framework/ivalue.h" +#include + +namespace oneflow_api { + +namespace of = oneflow; + +std::ostream& operator<<(std::ostream& os, const IValue::Tag& tag) { + os << static_cast(tag); + return os; +} + +int64_t IValue::ToInt() const { + CHECK_EQ(tag_, Tag::kInt) << "Current value is not an int."; + return payload_.i.v_int; +} + +double IValue::ToDouble() const { + CHECK_EQ(tag_, Tag::kDouble) << "Current value is not a double."; + return payload_.i.v_double; +} + +bool IValue::ToBool() const { + CHECK_EQ(tag_, Tag::kBool) << "Current value is not a bool."; + return payload_.i.v_bool; +} + +const Tensor& IValue::ToTensor() const { + CHECK_EQ(tag_, Tag::kTensor) << "Current value is not a tensor."; + return payload_.v_tensor; +} + +const std::vector& IValue::ToTensorVector() const { + CHECK_EQ(tag_, Tag::kTensorVector) << "Current value is not a vector of tensor."; + return payload_.v_tensor_vector; +} + +} // namespace oneflow_api diff --git a/oneflow/api/cpp/framework/ivalue.h b/oneflow/api/cpp/framework/ivalue.h new file mode 100644 index 00000000000..fad26be6aff --- /dev/null +++ b/oneflow/api/cpp/framework/ivalue.h @@ -0,0 +1,149 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
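Graph above is the user-facing entry point of the new C++ inference API: Load reads an exported model directory (model.mlir plus variable blobs), and Forward accepts a single IValue holding nothing, one Tensor, or a vector of Tensors and returns its outputs with the same convention. A minimal end-to-end sketch, assuming a model has already been exported to the (illustrative) path below and expects a 1x3 float input:

    #include <vector>
    #include "oneflow/api/cpp/framework.h"

    using namespace oneflow_api;

    int main() {
      EnvScope env;  // brings up the environment and session, released at scope exit

      Device device("cpu");
      Graph graph = Graph::Load("/path/to/exported_model", device);  // directory containing model.mlir

      std::vector<float> data(1 * 3, 1.0F);  // illustrative 1x3 input
      std::vector<Tensor> inputs;
      inputs.emplace_back(Tensor::from_buffer(data.data(), Shape({1, 3}), device, DType::kFloat));

      IValue result = graph.Forward(IValue(inputs));
      if (result.IsTensor()) {  // a single output comes back as a plain Tensor
        std::vector<float> out(result.ToTensor().shape().elem_cnt());
        result.ToTensor().copy_to(out.data());  // copy the output back into host memory
      }
      return 0;
    }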
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_API_CPP_FRAMEWORK_IVALUE_H_ +#define ONEFLOW_API_CPP_FRAMEWORK_IVALUE_H_ + +#include +#include +#include +#include "tensor.h" + +namespace oneflow_api { + +class IValue { + public: + IValue() : tag_(IValue::Tag::kNone) {} + explicit IValue(int value) : tag_(IValue::Tag::kInt) { payload_.i.v_int = value; } + + explicit IValue(int64_t value) : tag_(IValue::Tag::kInt) { payload_.i.v_int = value; } + + explicit IValue(double value) : tag_(IValue::Tag::kDouble) { payload_.i.v_double = value; } + + explicit IValue(bool value) : tag_(IValue::Tag::kBool) { payload_.i.v_bool = value; } + + IValue(const Tensor& value) : tag_(IValue::Tag::kTensor) { // NOLINT + new (&payload_.v_tensor) Tensor(value); + } + + IValue(Tensor&& value) : tag_(IValue::Tag::kTensor) { // NOLINT + new (&payload_.v_tensor) Tensor(std::move(value)); + } + + IValue(const std::vector& value) : tag_(IValue::Tag::kTensorVector) { // NOLINT + new (&payload_.v_tensor_vector) std::vector(value); + } + + IValue(std::vector&& value) : tag_(IValue::Tag::kTensorVector) { // NOLINT + new (&payload_.v_tensor_vector) std::vector(std::move(value)); + } + + IValue(const IValue& value) : tag_(value.tag_) { + if (IsTensor()) { + new (&payload_.v_tensor) Tensor(value.payload_.v_tensor); + } else if (IsTensorVector()) { + new (&payload_.v_tensor_vector) std::vector(value.payload_.v_tensor_vector); + } else { + payload_.i = value.payload_.i; + } + } + + IValue(IValue&& value) noexcept : tag_(value.tag_) { MoveFrom(std::move(value)); } + + IValue& operator=(const IValue& value) { + if (&value == this) { return *this; } + this->tag_ = value.tag_; + *this = IValue(value); + return *this; + } + + IValue& operator=(IValue&& value) noexcept { + if (&value == this) { return *this; } + Destory(); + this->tag_ = value.tag_; + MoveFrom(std::move(value)); + return *this; + } + + ~IValue() { Destory(); } + + bool IsNone() const { return tag_ == Tag::kNone; } + + bool IsInt() const { return tag_ == Tag::kInt; } + + bool IsDouble() const { return tag_ == Tag::kDouble; } + + bool IsBool() const { return tag_ == Tag::kBool; } + + bool IsTensor() const { return tag_ == Tag::kTensor; } + + bool IsTensorVector() const { return tag_ == Tag::kTensorVector; } + + int64_t ToInt() const; + double ToDouble() const; + bool ToBool() const; + const Tensor& ToTensor() const; + const std::vector& ToTensorVector() const; + + private: + enum class Tag { kNone = 0, kInt = 1, kDouble = 2, kBool = 3, kTensor = 4, kTensorVector = 5 }; + friend std::ostream& operator<<(std::ostream&, const Tag&); + + union Payload { // NOLINT + union InternalPayload { + InternalPayload() : v_int(0) {} + + int64_t v_int; + double v_double; + bool v_bool; + } i; + + Tensor v_tensor; + std::vector v_tensor_vector; + + Payload() : i() {} + ~Payload() {} + }; + + Payload payload_; + Tag tag_; + + inline void Destory() { + if (IsTensor()) { payload_.v_tensor.~Tensor(); } + if (IsTensorVector()) { payload_.v_tensor_vector.~vector(); } + } + + inline void MoveFrom(IValue&& value) { + if (IsTensor()) { + new (&payload_.v_tensor) Tensor(std::move(value.payload_.v_tensor)); + } else if 
(IsTensorVector()) { + new (&payload_.v_tensor_vector) + std::vector(std::move(value.payload_.v_tensor_vector)); + } else { + payload_.i = value.payload_.i; + } + value.ClearToNone(); + } + + inline void ClearToNone() { + Destory(); + payload_.i.v_int = 0; + tag_ = Tag::kNone; + } +}; + +} // namespace oneflow_api + +#endif // ONEFLOW_API_CPP_FRAMEWORK_IVALUE_H_ diff --git a/oneflow/api/cpp/framework/shape.cpp b/oneflow/api/cpp/framework/shape.cpp index da713cf7362..e5365c9ce4a 100644 --- a/oneflow/api/cpp/framework/shape.cpp +++ b/oneflow/api/cpp/framework/shape.cpp @@ -60,4 +60,9 @@ int64_t Shape::Count(int64_t begin_axis, int64_t end_axis) const { int64_t Shape::Count(int64_t begin_axis) const { return shape_->Count(begin_axis); } +std::ostream& operator<<(std::ostream& os, const Shape& shape) { + os << shape.shape_->DebugStr(); + return os; +} + } // namespace oneflow_api diff --git a/oneflow/api/cpp/framework/shape.h b/oneflow/api/cpp/framework/shape.h index 1110cf0f707..7465444ded2 100644 --- a/oneflow/api/cpp/framework/shape.h +++ b/oneflow/api/cpp/framework/shape.h @@ -37,20 +37,22 @@ class Shape final { ~Shape() = default; Shape& operator=(const Shape& shape); - bool operator==(const Shape& rhs) const; - bool operator!=(const Shape& rhs) const; + [[nodiscard]] bool operator==(const Shape& rhs) const; + [[nodiscard]] bool operator!=(const Shape& rhs) const; - int64_t elem_cnt() const; - int64_t At(int64_t index) const; void Set(int64_t index, int64_t val); - int64_t NumAxes() const; - int64_t Count(int64_t begin_axis, int64_t end_axis) const; - int64_t Count(int64_t begin_axis) const; + [[nodiscard]] int64_t elem_cnt() const; + [[nodiscard]] int64_t At(int64_t index) const; + [[nodiscard]] int64_t NumAxes() const; + [[nodiscard]] int64_t Count(int64_t begin_axis, int64_t end_axis) const; + [[nodiscard]] int64_t Count(int64_t begin_axis) const; private: std::shared_ptr shape_ = nullptr; + + friend std::ostream& operator<<(std::ostream&, const Shape&); }; } // namespace oneflow_api -#endif // !ONEFLOW_API_CPP_FRAMEWORK_SHAPE_H_ +#endif // ONEFLOW_API_CPP_FRAMEWORK_SHAPE_H_ diff --git a/oneflow/api/cpp/framework/tensor.cpp b/oneflow/api/cpp/framework/tensor.cpp index fe62b2dfc61..27b95f3ddbd 100644 --- a/oneflow/api/cpp/framework/tensor.cpp +++ b/oneflow/api/cpp/framework/tensor.cpp @@ -14,7 +14,6 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/api/cpp/framework/tensor.h" -#include #include "oneflow/api/cpp/framework/device.h" #include "oneflow/api/cpp/framework/dtype.h" #include "oneflow/api/cpp/framework/shape.h" @@ -42,17 +41,31 @@ Tensor::Tensor(const Shape& shape, const Device& device, const DType& dtype) { } Tensor::Tensor(const std::shared_ptr& tensor) : tensor_(tensor) {} -const Shape Tensor::shape() const { +Tensor::Tensor(const Tensor& tensor) : tensor_(tensor.tensor_) {} +Tensor::Tensor(Tensor&& tensor) noexcept : tensor_(std::move(tensor.tensor_)) {} + +Tensor& Tensor::operator=(const Tensor& tensor) { + if (&tensor == this) { return *this; } + tensor_ = tensor.tensor_; + return *this; +} +Tensor& Tensor::operator=(Tensor&& tensor) noexcept { + if (&tensor == this) { return *this; } + tensor_ = std::move(tensor.tensor_); + return *this; +} + +Shape Tensor::shape() const { const auto shape_ = tensor_->shape(); return Shape(std::vector(shape_->dim_vec().begin(), shape_->dim_vec().end())); } -const Device Tensor::device() const { +Device Tensor::device() const { const auto device_ = tensor_->device().GetOrThrow(); return Device(device_->type(), device_->device_id()); } -const DType Tensor::dtype() const { return static_cast(tensor_->dtype()->data_type()); } +DType Tensor::dtype() const { return static_cast(tensor_->dtype()->data_type()); } void Tensor::zeros_() { std::shared_ptr local_tensor = @@ -87,7 +100,7 @@ Tensor Tensor::from_buffer(const void* buffer, const Shape& shape, const Device& } template -void Tensor::copy_to(T* buffer) { +void Tensor::copy_to(T* buffer) const { std::shared_ptr local_tensor = tensor_->AsMirroredTensor().GetPtrOrThrow(); const auto shape = this->shape(); @@ -117,7 +130,7 @@ void Tensor::copy_to(T* buffer) { const std::shared_ptr& Tensor::__internal_tensor() const { return tensor_; } #define REGISTER_TENSOR_COPY_TO(cpp_dtype) \ - template void Tensor::copy_to(cpp_dtype * buffer); + template void Tensor::copy_to(cpp_dtype * buffer) const; REGISTER_TENSOR_COPY_TO(float) REGISTER_TENSOR_COPY_TO(double) diff --git a/oneflow/api/cpp/framework/tensor.h b/oneflow/api/cpp/framework/tensor.h index 4058a3db1c3..08ac8daf488 100644 --- a/oneflow/api/cpp/framework/tensor.h +++ b/oneflow/api/cpp/framework/tensor.h @@ -33,24 +33,35 @@ class Tensor; namespace oneflow_api { class Tensor final { + friend class Graph; + public: explicit Tensor(const Shape& shape = Shape(), const Device& device = Device("cpu"), const DType& dtype = DType::kFloat); explicit Tensor(const std::shared_ptr& tensor); - const Shape shape() const; - const Device device() const; - const DType dtype() const; + + Tensor(const Tensor& tensor); + Tensor(Tensor&& tensor) noexcept; + + ~Tensor() = default; + + Tensor& operator=(const Tensor& tensor); + Tensor& operator=(Tensor&& tensor) noexcept; + + [[nodiscard]] Shape shape() const; + [[nodiscard]] Device device() const; + [[nodiscard]] DType dtype() const; void zeros_(); // You should never call __internal_tensor() directly. 
- const std::shared_ptr& __internal_tensor() const; + [[nodiscard]] const std::shared_ptr& __internal_tensor() const; template - void copy_to(T* buffer); + void copy_to(T* buffer) const; - static Tensor from_buffer(const void* buffer, const Shape& shape, const Device& device, - const DType& dtype); + [[nodiscard]] static Tensor from_buffer(const void* buffer, const Shape& shape, + const Device& device, const DType& dtype); private: std::shared_ptr tensor_ = nullptr; @@ -58,4 +69,4 @@ class Tensor final { } // namespace oneflow_api -#endif // !ONEFLOW_API_CPP_FRAMEWORK_TENSOR_H_ +#endif // ONEFLOW_API_CPP_FRAMEWORK_TENSOR_H_ diff --git a/oneflow/api/cpp/nn.h b/oneflow/api/cpp/nn.h index ebff7bd7c5a..4dbfbef6d9d 100644 --- a/oneflow/api/cpp/nn.h +++ b/oneflow/api/cpp/nn.h @@ -19,4 +19,4 @@ limitations under the License. #include "nn/functional/activation.h" -#endif // !ONEFLOW_API_CPP_NN_H_ +#endif // ONEFLOW_API_CPP_NN_H_ diff --git a/oneflow/api/cpp/nn/functional/activation.h b/oneflow/api/cpp/nn/functional/activation.h index f22cde74645..dc334eb034a 100644 --- a/oneflow/api/cpp/nn/functional/activation.h +++ b/oneflow/api/cpp/nn/functional/activation.h @@ -27,4 +27,4 @@ Tensor relu(const Tensor& tensor); } // namespace oneflow_api -#endif // !ONEFLOW_API_CPP_NN_FUNCTIONAL_ACTIVATION_H_ +#endif // ONEFLOW_API_CPP_NN_FUNCTIONAL_ACTIVATION_H_ diff --git a/oneflow/api/cpp/tests/api_test.cpp b/oneflow/api/cpp/tests/api_test.cpp index 6888b95487d..1fbc790bcc2 100644 --- a/oneflow/api/cpp/tests/api_test.cpp +++ b/oneflow/api/cpp/tests/api_test.cpp @@ -15,8 +15,18 @@ limitations under the License. */ #include "oneflow/api/cpp/tests/api_test.h" -#include +#include #include +#include +#ifdef __linux__ + +#include // readlink + +#elif defined(__APPLE__) + +#include // _NSGetExecutablePath + +#endif namespace oneflow_api { @@ -48,4 +58,27 @@ REGISTER_RANDOM_DATA(int8_t) REGISTER_RANDOM_DATA(int32_t) REGISTER_RANDOM_DATA(int64_t) +std::string GetExeDir() { + const size_t path_max_size = 4096; // PATH_MAX = 4096 on linux + char result[path_max_size]; + + const auto get_dir_from_path = [](char result[], size_t count) -> std::string { + std::string exe_path(result, (count > 0) ? count : 0); + + // string(path).rfind('/') will never be string::npos on linux or macos. + return exe_path.substr(0, exe_path.rfind('/')); + }; + +#ifdef __linux__ + ssize_t count = readlink("/proc/self/exe", result, path_max_size); + return get_dir_from_path(result, count); +#elif defined(__APPLE__) + uint32_t count = path_max_size; + CHECK_EQ(_NSGetExecutablePath(result, &count), 0) << "Fail to get executable file path."; + return get_dir_from_path(result, count); +#else +#error oneflow_api::GetExeDir() has not been supported on windows. +#endif +} + } // namespace oneflow_api diff --git a/oneflow/api/cpp/tests/api_test.h b/oneflow/api/cpp/tests/api_test.h index cec50969e69..c196bc90662 100644 --- a/oneflow/api/cpp/tests/api_test.h +++ b/oneflow/api/cpp/tests/api_test.h @@ -32,6 +32,8 @@ Shape RandomShape(); template std::vector RandomData(size_t size); +std::string GetExeDir(); + } // namespace oneflow_api #endif // !ONEFLOW_API_CPP_TESTS_API_TEST_H_ diff --git a/oneflow/api/cpp/tests/graph_test.cpp b/oneflow/api/cpp/tests/graph_test.cpp new file mode 100644 index 00000000000..497da6b1bbb --- /dev/null +++ b/oneflow/api/cpp/tests/graph_test.cpp @@ -0,0 +1,197 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. 
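Tensor above gains explicit copy and move operations over its shared_ptr payload, so copies are shallow: they alias the same underlying device buffer rather than cloning it. A short sketch, to be run inside an initialized environment (EnvScope); the shape is illustrative:

    #include <utility>
    #include "oneflow/api/cpp/framework.h"

    void tensor_value_semantics_example() {
      using namespace oneflow_api;
      Tensor a(Shape({2, 3}), Device("cpu"), DType::kFloat);
      Tensor b = a;             // copy: b shares a's underlying oneflow tensor
      Tensor c = std::move(a);  // move: c takes over the reference, a is left empty
      b.zeros_();               // in-place zeroing through b is visible through c too, since they alias
    }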
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "oneflow/api/cpp/framework.h" +#include "oneflow/api/cpp/tests/api_test.h" + +namespace oneflow_api { + +namespace { + +inline Graph LoadGraph(const Device& device) { + Graph graph = + Graph::Load("./oneflow/api/cpp/tests/graph_test_model/affine_with_parameter", device); + return graph; +} + +inline void Forward(Graph& graph, const Device& device, int expected_batch_dim = 1) { + std::vector data(expected_batch_dim * 3); + std::fill(data.begin(), data.end(), 1); + std::vector inputs; + inputs.emplace_back( + Tensor::from_buffer(data.data(), Shape({expected_batch_dim, 3}), device, DType::kFloat)); + const auto& value = graph.Forward(inputs); + ASSERT_TRUE(value.IsTensor()); + Tensor output = value.ToTensor(); + Shape shape = output.shape(); + ASSERT_EQ(shape.At(0), expected_batch_dim); + ASSERT_EQ(shape.At(1), 4); + std::vector buf(expected_batch_dim * 4); + output.copy_to(buf.data()); + for (const float& element : buf) { ASSERT_EQ(element, 4); } +} + +} // namespace + +TEST(Api, graph_cpu_test) { + EnvScope scope; + Device device("cpu"); + Graph graph = LoadGraph(device); + Forward(graph, device, 1); +} + +#ifdef WITH_CUDA +TEST(Api, graph_gpu_test) { + EnvScope scope; + Device device("cuda", 0); + Graph graph = LoadGraph(device); + Forward(graph, device); +} + +TEST(Api, graph_multi_gpu_test) { + EnvScope scope; + Device device("cuda", 0); + Graph graph = LoadGraph(device); + Forward(graph, device); + + Device device1("cuda", 1); + Graph graph1 = LoadGraph(device1); + Forward(graph1, device1); +} + +TEST(Api, graph_trt_test) { + EnvScope scope; + Device device("cuda:0"); + Graph graph = LoadGraph(device); + graph.enable_tensorrt(); + Forward(graph, device); +} +#endif + +TEST(Api, graph_cpu_batching_test) { + EnvScope scope; + Device device("cpu"); + Graph graph = LoadGraph(device); + graph.set_batch_size(10); + Forward(graph, device, 10); +} + +#ifdef WITH_CUDA +TEST(Api, graph_gpu_batching_test) { + EnvScope scope; + Device device("cuda", 0); + Graph graph = LoadGraph(device); + graph.set_batch_size(10); + Forward(graph, device, 10); +} + +TEST(Api, graph_multi_device_test) { + EnvScope scope; + Device device("cuda", 0); + Graph graph = LoadGraph(device); + Forward(graph, device, 1); + + Device device1("cuda", 1); + Graph graph1 = LoadGraph(device1); + Forward(graph1, device1, 1); + + Device device2("cpu"); + Graph graph2 = LoadGraph(device2); + Forward(graph2, device2, 1); +} + +TEST(Api, graph_unload_test) { + { + EnvScope scope; + + Device device("cuda", 0); + Graph graph = LoadGraph(device); + Forward(graph, device, 1); + + { + Device device1("cuda", 1); + Graph graph1 = LoadGraph(device1); + Forward(graph1, device1, 1); + } + + Device device2("cpu"); + Graph graph2 = LoadGraph(device2); + Forward(graph2, device2, 1); + } + + { + EnvScope scope; + + Device device("cpu"); + Graph graph = LoadGraph(device); + Forward(graph, device, 1); + } +} +#endif + 
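// Note on the following test: GraphImpl::Forward (above) compiles a graph on its first call
// under a function-local static mutex shared by all Graph instances, so concurrent first
// forwards from multiple threads serialize their compilation; later calls take no lock.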
+TEST(Api, graph_thread_test) { + EnvScope scope; + + Device device("cpu"); + std::vector graphs; + for (int i = 0; i < 10; i++) { graphs.emplace_back(LoadGraph(device)); } + + std::vector threads; + for (Graph& graph : graphs) { + threads.emplace_back(std::thread(std::bind(Forward, std::move(graph), device, 1))); + } + for (auto& thread : threads) { thread.join(); } +} + +TEST(Api, graph_input_order_test) { + EnvScope scope; + + Device device("cpu"); + Graph graph = Graph::Load("./oneflow/api/cpp/tests/graph_test_model/affine_no_parameter", device); + + std::vector inputs; + std::vector x(3); + std::fill(x.begin(), x.end(), 1); + inputs.emplace_back(Tensor::from_buffer(x.data(), Shape({1, 3}), device, DType::kFloat)); + std::vector a(3 * 2); + std::fill(a.begin(), a.end(), 1); + inputs.emplace_back(Tensor::from_buffer(a.data(), Shape({3, 2}), device, DType::kFloat)); + std::vector b(2); + std::fill(b.begin(), b.end(), 1); + inputs.emplace_back(Tensor::from_buffer(b.data(), Shape({2}), device, DType::kFloat)); + + const auto& value = graph.Forward(inputs); + ASSERT_TRUE(value.IsTensor()); + Tensor output = value.ToTensor(); + Shape shape = output.shape(); + ASSERT_EQ(shape.At(0), 1); + ASSERT_EQ(shape.At(1), 2); + std::array buf{}; + output.copy_to(buf.data()); + ASSERT_EQ(buf[0], 4); + ASSERT_EQ(buf[1], 4); +} + +} // namespace oneflow_api diff --git a/oneflow/api/cpp/tests/graph_test_model/affine_no_parameter/model.mlir b/oneflow/api/cpp/tests/graph_test_model/affine_no_parameter/model.mlir new file mode 100644 index 00000000000..30c09f7c841 --- /dev/null +++ b/oneflow/api/cpp/tests/graph_test_model/affine_no_parameter/model.mlir @@ -0,0 +1,11 @@ +module { + oneflow.job @MyGraph_1(%arg0: tensor<1x3xf32>, %arg1: tensor<3x2xf32>, %arg2: tensor<2xf32>) -> tensor<1x2xf32> { + %output = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_MyGraph_1-input_0", output_lbns = ["_MyGraph_1-input_0/out"], scope_symbol_id = 4611686018427527167 : i64, shape = [1 : si64, 3 : si64]} : (tensor<1x3xf32>) -> tensor<1x3xf32> + %output_0 = "oneflow.input"(%arg1) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_MyGraph_1-input_1", output_lbns = ["_MyGraph_1-input_1/out"], scope_symbol_id = 4611686018427527167 : i64, shape = [3 : si64, 2 : si64]} : (tensor<3x2xf32>) -> tensor<3x2xf32> + %output_1 = "oneflow.input"(%arg2) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_MyGraph_1-input_2", output_lbns = ["_MyGraph_1-input_2/out"], scope_symbol_id = 4611686018427527167 : i64, shape = [2 : si64]} : (tensor<2xf32>) -> tensor<2xf32> + %0 = "oneflow.matmul"(%output, %output_0) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "model-matmul_0", output_lbns = ["model-matmul_0/out_0"], scope_symbol_id = 4611686018427535359 : i64, transpose_a = false, transpose_b = false} : (tensor<1x3xf32>, tensor<3x2xf32>) -> tensor<1x2xf32> + %1 = "oneflow.broadcast_add"(%0, %output_1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "model-broadcast_add_1", output_lbns = ["model-broadcast_add_1/z_0"], scope_symbol_id = 4611686018427535359 : i64} : (tensor<1x2xf32>, tensor<2xf32>) -> tensor<1x2xf32> + %output_2 = "oneflow.output"(%1) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", 
hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_MyGraph_1-output_0", output_lbns = ["_MyGraph_1-output_0/out"], scope_symbol_id = 4611686018427527167 : i64, shape = [1 : si64, 2 : si64]} : (tensor<1x2xf32>) -> tensor<1x2xf32> + oneflow.return %output_2 : tensor<1x2xf32> + } +} diff --git a/oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.a/meta b/oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.a/meta new file mode 100644 index 00000000000..421341fc956 --- /dev/null +++ b/oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.a/meta @@ -0,0 +1,5 @@ +shape { + dim: 3 + dim: 4 +} +data_type: kFloat diff --git a/oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.a/out b/oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.a/out new file mode 100644 index 0000000000000000000000000000000000000000..be22e342567fcc4be86263602fe799ac97b1e8e1 GIT binary patch literal 48 OcmZQzXs~A>0RsTk90>IQ literal 0 HcmV?d00001 diff --git a/oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.b/meta b/oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.b/meta new file mode 100644 index 00000000000..166375025be --- /dev/null +++ b/oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.b/meta @@ -0,0 +1,4 @@ +shape { + dim: 4 +} +data_type: kFloat diff --git a/oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.b/out b/oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.b/out new file mode 100644 index 0000000000000000000000000000000000000000..dcce8bfb97e5327dd298643776af46107f980856 GIT binary patch literal 16 NcmZQzXs~BM!T=WZ0{s90 literal 0 HcmV?d00001 diff --git a/oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.mlir b/oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.mlir new file mode 100644 index 00000000000..15a53af1f48 --- /dev/null +++ b/oneflow/api/cpp/tests/graph_test_model/affine_with_parameter/model.mlir @@ -0,0 +1,11 @@ +module { + oneflow.job @MyGraph_0(%arg0: tensor<1x3xf32>) -> tensor<1x4xf32> { + %output = "oneflow.input"(%arg0) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_MyGraph_0-input_0", output_lbns = ["_MyGraph_0-input_0/out"], scope_symbol_id = 4611686018427469823 : i64, shape = [1 : si64, 3 : si64]} : (tensor<1x3xf32>) -> tensor<1x3xf32> + %output_0 = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], nd_sbp = ["B"], op_name = "model.a", output_lbns = ["model.a/out"], scope_symbol_id = 4611686018427482111 : i64, shape = [3 : si64, 4 : si64]} : () -> tensor<3x4xf32> + %output_1 = "oneflow.variable"() {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], nd_sbp = ["B"], op_name = "model.b", output_lbns = ["model.b/out"], scope_symbol_id = 4611686018427494399 : i64, shape = [4 : si64]} : () -> tensor<4xf32> + %0 = "oneflow.matmul"(%output, %output_0) {alpha = 1.000000e+00 : f64, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "model-matmul_0", output_lbns = ["model-matmul_0/out_0"], scope_symbol_id = 4611686018427486207 : i64, transpose_a = false, transpose_b = false} : (tensor<1x3xf32>, tensor<3x4xf32>) -> tensor<1x4xf32> + %1 = "oneflow.broadcast_add"(%0, %output_1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "model-broadcast_add_1", output_lbns = ["model-broadcast_add_1/z_0"], 
scope_symbol_id = 4611686018427486207 : i64} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + %output_2 = "oneflow.output"(%1) {data_type = 2 : i32, device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], is_dynamic = false, nd_sbp = ["B"], op_name = "_MyGraph_0-output_0", output_lbns = ["_MyGraph_0-output_0/out"], scope_symbol_id = 4611686018427469823 : i64, shape = [1 : si64, 4 : si64]} : (tensor<1x4xf32>) -> tensor<1x4xf32> + oneflow.return %output_2 : tensor<1x4xf32> + } +} diff --git a/oneflow/api/cpp/tests/ivalue_test.cpp b/oneflow/api/cpp/tests/ivalue_test.cpp new file mode 100644 index 00000000000..ff4d054061f --- /dev/null +++ b/oneflow/api/cpp/tests/ivalue_test.cpp @@ -0,0 +1,132 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include "oneflow/api/cpp/framework/dtype.h" +#include "oneflow/api/cpp/framework/ivalue.h" +#include "oneflow/api/cpp/tests/api_test.h" + +namespace oneflow_api { + +namespace { + +std::mt19937 rng(std::random_device{}()); + +} + +TEST(Api, ivalue) { + std::uniform_real_distribution<> dist(-100, 100); + std::uniform_int_distribution<> dist_bool(0, 1); + + const auto v_int = static_cast(dist(rng)); + ASSERT_EQ(IValue(v_int).ToInt(), v_int); + + const auto v_int64 = static_cast(dist(rng)); + ASSERT_EQ(IValue(v_int64).ToInt(), v_int64); + + const auto v_float = static_cast(dist(rng)); + ASSERT_EQ(IValue(v_float).ToDouble(), v_float); + + const auto v_double = static_cast(dist(rng)); + ASSERT_EQ(IValue(v_double).ToDouble(), v_double); + + const auto v_bool = static_cast(dist_bool(rng)); + ASSERT_EQ(IValue(v_bool).ToBool(), v_bool); +} + +TEST(Api, ivalue_tensor) { + EnvScope scope; + + const auto device = Device("cpu"); + const auto shape = RandomShape(); + const auto dtype = DType::kDouble; + + const IValue i_tensor(Tensor(shape, device, dtype)); + const auto& tensor = i_tensor.ToTensor(); + + ASSERT_EQ(tensor.shape(), shape); + ASSERT_EQ(tensor.device(), device); + ASSERT_EQ(tensor.dtype(), dtype); +} + +TEST(Api, ivalue_tensor_vector) { + EnvScope scope; + + const auto device = Device("cpu"); + + const std::vector v_tensor_vector{Tensor(RandomShape(), device, DType::kDouble), + Tensor(RandomShape(), device, DType::kFloat)}; + const auto i_tensor = IValue(v_tensor_vector); + const auto& tensor_vector = i_tensor.ToTensorVector(); + + ASSERT_EQ(v_tensor_vector.size(), tensor_vector.size()); + + for (size_t i = 0; i < tensor_vector.size(); ++i) { + ASSERT_EQ(v_tensor_vector[i].device(), tensor_vector[i].device()); + ASSERT_EQ(v_tensor_vector[i].shape(), tensor_vector[i].shape()); + ASSERT_EQ(v_tensor_vector[i].dtype(), tensor_vector[i].dtype()); + } +} + +TEST(Api, ivalue_copy) { + EnvScope scope; + + const auto device = Device("cpu"); + const auto shape = RandomShape(); + const auto dtype = DType::kDouble; + + const IValue i_tensor(Tensor(shape, device, dtype)); + const auto i_tensor_a = i_tensor; // NOLINT + + ASSERT_EQ(i_tensor_a.ToTensor().shape(), shape); + 
ASSERT_EQ(i_tensor_a.ToTensor().device(), device); + ASSERT_EQ(i_tensor_a.ToTensor().dtype(), dtype); + + IValue i_tensor_b; + i_tensor_b = i_tensor; + + ASSERT_EQ(i_tensor_b.ToTensor().shape(), shape); + ASSERT_EQ(i_tensor_b.ToTensor().device(), device); + ASSERT_EQ(i_tensor_b.ToTensor().dtype(), dtype); +} + +TEST(Api, ivalue_move) { + EnvScope scope; + + const auto device = Device("cpu"); + const auto shape = RandomShape(); + const auto dtype = DType::kDouble; + + IValue i_tensor_a = IValue(Tensor(shape, device, dtype)); + IValue i_tensor_b = IValue(Tensor(shape, device, dtype)); + + IValue i_tensor_c = std::move(i_tensor_a); + ASSERT_EQ(i_tensor_c.ToTensor().shape(), shape); + ASSERT_EQ(i_tensor_c.ToTensor().device(), device); + ASSERT_EQ(i_tensor_c.ToTensor().dtype(), dtype); + + IValue i_tensor_d; + i_tensor_d = std::move(i_tensor_b); + ASSERT_EQ(i_tensor_d.ToTensor().shape(), shape); + ASSERT_EQ(i_tensor_d.ToTensor().device(), device); + ASSERT_EQ(i_tensor_d.ToTensor().dtype(), dtype); + + ASSERT_EQ(i_tensor_a.IsNone(), true); + ASSERT_EQ(i_tensor_b.IsNone(), true); +} + +} // namespace oneflow_api diff --git a/oneflow/api/cpp/tests/tensor_test.cpp b/oneflow/api/cpp/tests/tensor_test.cpp index 3241a220263..5960f961675 100644 --- a/oneflow/api/cpp/tests/tensor_test.cpp +++ b/oneflow/api/cpp/tests/tensor_test.cpp @@ -26,13 +26,13 @@ TEST(Api, device) { ASSERT_EQ(device.type(), "cpu"); #ifdef WITH_CUDA - device = Device("cuda", 1); + device = Device("cuda:0"); ASSERT_EQ(device.type(), "cuda"); - ASSERT_EQ(device.device_id(), 1); + ASSERT_EQ(device.device_id(), 0); - device = Device("cuda:2"); + device = Device("cuda", 1); ASSERT_EQ(device.type(), "cuda"); - ASSERT_EQ(device.device_id(), 2); + ASSERT_EQ(device.device_id(), 1); #endif } diff --git a/oneflow/api/python/framework/device.cpp b/oneflow/api/python/framework/device.cpp index f843d8ccb74..06d5a733bfb 100644 --- a/oneflow/api/python/framework/device.cpp +++ b/oneflow/api/python/framework/device.cpp @@ -16,7 +16,6 @@ limitations under the License. 
#include #include #include "oneflow/core/control/global_process_ctx.h" -#include "oneflow/api/common/device.h" #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/common/str_util.h" @@ -29,10 +28,10 @@ namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { py::class_, std::shared_ptr>>(m, "device") .def(py::init([](const std::string& type_or_type_with_device_id) { - return DeviceExportUtil::ParseAndNew(type_or_type_with_device_id).GetOrThrow(); + return Device::ParseAndNew(type_or_type_with_device_id).GetOrThrow(); })) .def(py::init([](const std::string& type, int64_t device_id) { - return DeviceExportUtil::New(type, device_id).GetOrThrow(); + return Device::New(type, device_id).GetOrThrow(); })) .def_property_readonly("type", [](const Symbol& d) { return d->type(); }) .def_property_readonly("index", [](const Symbol& d) { return d->device_id(); }) diff --git a/oneflow/api/python/framework/dtype.cpp b/oneflow/api/python/framework/dtype.cpp index 4fa6f8d9d08..9696988cf10 100644 --- a/oneflow/api/python/framework/dtype.cpp +++ b/oneflow/api/python/framework/dtype.cpp @@ -45,19 +45,25 @@ ONEFLOW_API_PYBIND11_MODULE("", m) { m.attr("char") = &CHECK_JUST(DType::Get(DataType::kChar)); m.attr("float16") = &CHECK_JUST(DType::Get(DataType::kFloat16)); m.attr("float") = &CHECK_JUST(DType::Get(DataType::kFloat)); - m.attr("float32") = &CHECK_JUST(DType::Get(DataType::kFloat)); m.attr("double") = &CHECK_JUST(DType::Get(DataType::kDouble)); m.attr("float64") = &CHECK_JUST(DType::Get(DataType::kDouble)); - m.attr("int8") = &CHECK_JUST(DType::Get(DataType::kInt8)); m.attr("int32") = &CHECK_JUST(DType::Get(DataType::kInt32)); m.attr("int64") = &CHECK_JUST(DType::Get(DataType::kInt64)); - m.attr("uint8") = &CHECK_JUST(DType::Get(DataType::kUInt8)); m.attr("record") = &CHECK_JUST(DType::Get(DataType::kOFRecord)); m.attr("tensor_buffer") = &CHECK_JUST(DType::Get(DataType::kTensorBuffer)); m.attr("bfloat16") = &CHECK_JUST(DType::Get(DataType::kBFloat16)); + m.attr("uint16") = &CHECK_JUST(DType::Get(DataType::kUInt16)); + m.attr("uint32") = &CHECK_JUST(DType::Get(DataType::kUInt32)); + m.attr("uint64") = &CHECK_JUST(DType::Get(DataType::kUInt64)); + m.attr("uint128") = &CHECK_JUST(DType::Get(DataType::kUInt128)); + m.attr("int16") = &CHECK_JUST(DType::Get(DataType::kInt16)); + m.attr("int128") = &CHECK_JUST(DType::Get(DataType::kInt128)); + m.attr("complex32") = &CHECK_JUST(DType::Get(DataType::kComplex32)); + m.attr("complex64") = &CHECK_JUST(DType::Get(DataType::kComplex64)); + m.attr("complex128") = &CHECK_JUST(DType::Get(DataType::kComplex128)); } } // namespace oneflow diff --git a/oneflow/api/python/framework/nn_graph.cpp b/oneflow/api/python/framework/nn_graph.cpp index 4ee816ea6a3..090d128573d 100644 --- a/oneflow/api/python/framework/nn_graph.cpp +++ b/oneflow/api/python/framework/nn_graph.cpp @@ -22,10 +22,13 @@ limitations under the License. 
#include "oneflow/core/framework/nn_graph.h" #include "oneflow/core/job/runtime.h" #include "oneflow/core/register/blob.h" +#include "oneflow/core/job/job.pb.h" +#include "oneflow/core/job/job_ir.h" namespace py = pybind11; namespace oneflow { + ONEFLOW_API_PYBIND11_MODULE("nn.graph.", m) { using namespace oneflow; py::class_>(m, "CNNGraph") @@ -63,5 +66,17 @@ ONEFLOW_API_PYBIND11_MODULE("nn.graph.", m) { }); m.def("AddTensorAsGraphLoss", [](const std::shared_ptr& t) { return AddTensorAsGraphLoss(t).GetOrThrow(); }); + m.def("SaveJobToIR", [](const std::string& serialized_job, const std::string& path) { + Job job; + CHECK(TxtString2PbMessage(serialized_job, &job)); + return SaveJobToIR(&job, path).GetOrThrow(); + ; + }); + m.def("LoadSerializedJobFromIR", [](const std::string& path) { + Job job; + LoadJobFromIR(&job, path).GetOrThrow(); + return py::bytes(job.SerializeAsString()); + }); } + } // namespace oneflow diff --git a/oneflow/api/python/framework/op_expr.cpp b/oneflow/api/python/framework/op_expr.cpp index 9a32ef35617..53a2c7a7fc1 100644 --- a/oneflow/api/python/framework/op_expr.cpp +++ b/oneflow/api/python/framework/op_expr.cpp @@ -17,7 +17,6 @@ limitations under the License. #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/protobuf.h" -#include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter.h" @@ -32,42 +31,6 @@ namespace oneflow { namespace { -Maybe Interpret(const one::OpExpr& op, const one::TensorTuple& inputs, - const AttrMap& attrs) { - CHECK_EQ_OR_RETURN(op.input_size(), inputs.size()) - << "The operation requires " << op.input_size() << " inputs, but " << inputs.size() - << " is given."; - return JUST(one::OpInterpUtil::Dispatch(op, inputs, attrs)); -} - -Maybe Interpret(const one::OpExpr& op, - const std::vector>& inputs, - const AttrMap& attrs) { - one::TensorTuple input_list(inputs.size()); - for (int i = 0; i < inputs.size(); ++i) { input_list[i] = inputs[i]; } - return JUST(Interpret(op, input_list, attrs)); -} - -Maybe Interpret(const one::OpExpr& op, const Symbol& placement, - const std::vector>& sbp_tuple, - const AttrMap& attrs) { - CHECK_EQ_OR_RETURN(op.input_size(), 0) - << " the op : " << op.op_type_name() - << " is NOT source op with input_size = " << op.input_size(); - const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); - return JUST(one::OpInterpUtil::Dispatch( - op, {}, one::OpExprInterpContext(attrs, placement, nd_sbp))); -} - -Maybe Interpret(const one::OpExpr& op, const Symbol& device, - const AttrMap& attrs) { - CHECK_EQ_OR_RETURN(op.input_size(), 0) - << " the op : " << op.op_type_name() - << " is NOT source op with input_size = " << op.input_size(); - return JUST(one::OpInterpUtil::Dispatch( - op, {}, one::OpExprInterpContext(attrs, device))); -} - template::value>::type* = nullptr> py::class_> PybindExportOpExpr( @@ -92,27 +55,7 @@ ONEFLOW_API_PYBIND11_MODULE("one", m) { py::class_>(m, "OpExpr") .def_property_readonly("op_type_name", &one::OpExpr::op_type_name) .def_property_readonly("input_size", &one::OpExpr::input_size) - .def_property_readonly("output_size", &one::OpExpr::output_size) - .def("apply", - [](const one::OpExpr& op_expr, const std::vector>& inputs, - const MutableCfgAttrMap& attrs) { - return Interpret(op_expr, inputs, attrs).GetPtrOrThrow(); - }) - .def("apply", - [](const one::OpExpr& op_expr, const one::TensorTuple& inputs, - const MutableCfgAttrMap& attrs) { - return 
Interpret(op_expr, inputs, attrs).GetPtrOrThrow(); - }) - .def("apply", - [](const one::OpExpr& op_expr, const Symbol& placement, - const std::vector>& sbp_tuple, - const MutableCfgAttrMap& attrs) { - return Interpret(op_expr, placement, sbp_tuple, attrs).GetPtrOrThrow(); - }) - .def("apply", [](const one::OpExpr& op_expr, const Symbol& device, - const MutableCfgAttrMap& attrs) { - return Interpret(op_expr, device, attrs).GetPtrOrThrow(); - }); + .def_property_readonly("output_size", &one::OpExpr::output_size); py::class_>(m, "BuiltinOpExpr") diff --git a/oneflow/api/python/functional/common.cpp b/oneflow/api/python/functional/common.cpp index 143f29cb3c9..ed7992b64c5 100644 --- a/oneflow/api/python/functional/common.cpp +++ b/oneflow/api/python/functional/common.cpp @@ -20,6 +20,7 @@ limitations under the License. #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/device.h" +#include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/random_generator.h" @@ -109,6 +110,25 @@ Maybe> PyUnpackDType(PyObject* obj) { return *py::cast*>(handle); } +// DType list +bool PyDTypeSequenceCheck(PyObject* obj) { + return PySequenceCheck(obj, [](PyObject* item) { return PyDTypeCheck(item); }); +} +Maybe>> PyUnpackDTypeSequence(PyObject* obj) { + return PyUnpackSequence>(obj, [](PyObject* item) { return PyUnpackDType(item); }); +} + +// Shape list +bool PyShapeSequenceCheck(PyObject* obj) { + return PySequenceCheck(obj, [](PyObject* item) { return PyLongSequenceCheck(item); }); +} +Maybe> PyUnpackShapeSequence(PyObject* obj) { + return PyUnpackSequence(obj, [](PyObject* item) -> Maybe { + const auto& shape = JUST(PyUnpackLongSequence(item)); + return std::make_shared(DimVector(shape->begin(), shape->end())); + }); +} + // Generator bool PyGeneratorCheck(PyObject* obj) { auto handle = py::reinterpret_borrow(obj); @@ -250,6 +270,17 @@ Maybe PyUnpackTensorIndex(PyObject* obj) { return tensor_index; } +// OpExpr +bool PyOpExprCheck(PyObject* obj) { + auto handle = py::reinterpret_borrow(obj); + return py::isinstance(handle); +} + +Maybe PyUnpackOpExpr(PyObject* obj) { + auto handle = py::reinterpret_borrow(obj); + return py::cast>(handle); +} + } // namespace functional } // namespace one } // namespace oneflow diff --git a/oneflow/api/python/functional/common.h b/oneflow/api/python/functional/common.h index 9530b9b9b4a..e53c48a1422 100644 --- a/oneflow/api/python/functional/common.h +++ b/oneflow/api/python/functional/common.h @@ -26,6 +26,7 @@ limitations under the License. 
#include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/dtype.h" #include "oneflow/core/framework/device.h" +#include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/random_generator.h" @@ -127,6 +128,14 @@ Maybe PyUnpackTensorTuple(PyObject* obj); bool PyDTypeCheck(PyObject* obj); Maybe> PyUnpackDType(PyObject* obj); +// DType list +bool PyDTypeSequenceCheck(PyObject* obj); +Maybe>> PyUnpackDTypeSequence(PyObject* obj); + +// Shape list +bool PyShapeSequenceCheck(PyObject* obj); +Maybe> PyUnpackShapeSequence(PyObject* obj); + // Generator bool PyGeneratorCheck(PyObject* obj); Maybe PyUnpackGenerator(PyObject* obj); @@ -151,6 +160,10 @@ Maybe>> PyUnpackSbpParallelSequence(PyObjec bool PyTensorIndexCheck(PyObject* obj); Maybe PyUnpackTensorIndex(PyObject* obj); +// OpExpr +bool PyOpExprCheck(PyObject* obj); +Maybe PyUnpackOpExpr(PyObject* obj); + } // namespace functional } // namespace one } // namespace oneflow diff --git a/oneflow/api/python/functional/dispatch_stateful_ops.cpp b/oneflow/api/python/functional/dispatch_stateful_ops.cpp new file mode 100644 index 00000000000..782c95db4ae --- /dev/null +++ b/oneflow/api/python/functional/dispatch_stateful_ops.cpp @@ -0,0 +1,459 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include "oneflow/core/common/scalar.h" +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/nd_sbp.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/functional/functional.h" +#include "oneflow/core/functional/function_library.h" + +namespace oneflow { +namespace one { +namespace functional { + +namespace impl { + +ONEFLOW_FUNCTION_LIBRARY(m) { + m.add_functor( + "DispatchFeedInput", + [](const std::shared_ptr& op, const std::shared_ptr& input) -> Maybe { + return OpInterpUtil::Dispatch(*op, {input}); + }); + m.add_functor( + "DispatchFetchOutput", + [](const std::shared_ptr& op, const std::shared_ptr& input) -> Maybe { + return OpInterpUtil::Dispatch(*op, {input}); + }); + m.add_functor("DispatchFeedVariable", + [](const std::shared_ptr& op, const std::shared_ptr& input, + const Scalar& l2) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("l2", JUST(l2.As()))); + return OpInterpUtil::Dispatch(*op, {input}, attrs); + }); + m.add_functor( + "DispatchOfrecordReader", + [](const std::shared_ptr& op, const std::string& data_dir, int32_t data_part_num, + const std::string& part_name_prefix, int32_t part_name_suffix_length, int32_t batch_size, + int32_t shuffle_buffer_size, bool random_shuffle, bool shuffle_after_epoch, int64_t seed, + const Optional>& device) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("data_dir", data_dir)); + JUST(attrs.SetAttr("data_part_num", data_part_num)); + JUST(attrs.SetAttr("part_name_prefix", part_name_prefix)); + JUST(attrs.SetAttr("part_name_suffix_length", part_name_suffix_length)); + JUST(attrs.SetAttr("batch_size", batch_size)); + JUST(attrs.SetAttr("shuffle_buffer_size", shuffle_buffer_size)); + JUST(attrs.SetAttr("random_shuffle", random_shuffle)); + JUST(attrs.SetAttr("shuffle_after_epoch", shuffle_after_epoch)); + JUST(attrs.SetAttr("seed", seed)); + return OpInterpUtil::Dispatch(*op, {}, OpExprInterpContext(attrs, JUST(device))); + }); + m.add_functor( + "DispatchOfrecordReader", + [](const std::shared_ptr& op, const std::string& data_dir, int32_t data_part_num, + const std::string& part_name_prefix, int32_t part_name_suffix_length, int32_t batch_size, + int32_t shuffle_buffer_size, bool random_shuffle, bool shuffle_after_epoch, int64_t seed, + const Symbol& placement, + const std::vector>& sbp_tuple) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("data_dir", data_dir)); + JUST(attrs.SetAttr("data_part_num", data_part_num)); + JUST(attrs.SetAttr("part_name_prefix", part_name_prefix)); + JUST(attrs.SetAttr("part_name_suffix_length", part_name_suffix_length)); + JUST(attrs.SetAttr("batch_size", batch_size)); + JUST(attrs.SetAttr("shuffle_buffer_size", shuffle_buffer_size)); + JUST(attrs.SetAttr("random_shuffle", random_shuffle)); + JUST(attrs.SetAttr("shuffle_after_epoch", shuffle_after_epoch)); + JUST(attrs.SetAttr("seed", seed)); + JUST(attrs.SetAttr("nd_sbp", *JUST(GetNdSbpStrList(sbp_tuple)))); + auto nd_sbp = JUST(GetNdSbp(sbp_tuple)); + return OpInterpUtil::Dispatch(*op, {}, + OpExprInterpContext(attrs, placement, nd_sbp)); + }); + m.add_functor("DispatchOfrecordRawDecoder", + [](const std::shared_ptr& op, const std::shared_ptr& input, + const std::string& name, const Shape& shape, const Symbol& data_type, + bool dim1_varying_length, bool truncate) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("name", name)); + JUST(attrs.SetAttr("shape", 
shape)); + JUST(attrs.SetAttr("data_type", data_type->data_type())); + JUST(attrs.SetAttr("dim1_varying_length", dim1_varying_length)); + JUST(attrs.SetAttr("truncate", truncate)); + return OpInterpUtil::Dispatch(*op, {input}, attrs); + }); + m.add_functor( + "DispatchCoinFlip", + [](const std::shared_ptr& op, int64_t batch_size, Scalar probability, int64_t seed, + bool has_seed, const Optional>& device) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("probability", JUST(probability.As()))); + JUST(attrs.SetAttr("batch_size", batch_size)); + JUST(attrs.SetAttr("seed", seed)); + JUST(attrs.SetAttr("has_seed", has_seed)); + return OpInterpUtil::Dispatch(*op, {}, OpExprInterpContext(attrs, JUST(device))); + }); + m.add_functor("DispatchCoinFlip", + [](const std::shared_ptr& op, int64_t batch_size, Scalar probability, + int64_t seed, bool has_seed, const Symbol& placement, + const std::vector>& sbp_tuple) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("probability", JUST(probability.As()))); + JUST(attrs.SetAttr("batch_size", batch_size)); + JUST(attrs.SetAttr("seed", seed)); + JUST(attrs.SetAttr("has_seed", has_seed)); + JUST(attrs.SetAttr("nd_sbp", *JUST(GetNdSbpStrList(sbp_tuple)))); + auto nd_sbp = JUST(GetNdSbp(sbp_tuple)); + return OpInterpUtil::Dispatch( + *op, {}, OpExprInterpContext(attrs, placement, nd_sbp)); + }); + m.add_functor( + "DispatchCropMirrorNormalizeFromUint8", + [](const std::shared_ptr& op, const TensorTuple& input, int64_t crop_h, + int64_t crop_w, float crop_pos_x, float crop_pos_y, const std::vector& mean, + const std::vector& std, const Symbol& output_dtype, + const std::string& output_layout, const std::string& color_space) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("color_space", color_space)); + JUST(attrs.SetAttr("output_layout", output_layout)); + JUST(attrs.SetAttr("mean", mean)); + JUST(attrs.SetAttr("std", std)); + JUST(attrs.SetAttr("crop_h", crop_h)); + JUST(attrs.SetAttr("crop_w", crop_w)); + JUST(attrs.SetAttr("crop_pos_x", crop_pos_x)); + JUST(attrs.SetAttr("crop_pos_y", crop_pos_y)); + JUST(attrs.SetAttr("output_dtype", output_dtype->data_type())); + return OpInterpUtil::Dispatch(*op, input, attrs); + }); + m.add_functor( + "DispatchCropMirrorNormalizeFromTensorBuffer", + [](const std::shared_ptr& op, const TensorTuple& input, int64_t crop_h, + int64_t crop_w, float crop_pos_x, float crop_pos_y, const std::vector& mean, + const std::vector& std, const Symbol& output_dtype, + const std::string& output_layout, const std::string& color_space) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("color_space", color_space)); + JUST(attrs.SetAttr("output_layout", output_layout)); + JUST(attrs.SetAttr("mean", mean)); + JUST(attrs.SetAttr("std", std)); + JUST(attrs.SetAttr("crop_h", crop_h)); + JUST(attrs.SetAttr("crop_w", crop_w)); + JUST(attrs.SetAttr("crop_pos_x", crop_pos_x)); + JUST(attrs.SetAttr("crop_pos_y", crop_pos_y)); + JUST(attrs.SetAttr("output_dtype", output_dtype->data_type())); + return OpInterpUtil::Dispatch(*op, {input}, attrs); + }); + m.add_functor( + "DispatchOfrecordImageDecoderRandomCrop", + [](const std::shared_ptr& op, const std::shared_ptr& input, + const std::string& name, const std::string& color_space, + const std::vector& random_area, const std::vector& random_aspect_ratio, + int32_t num_attempts, int64_t seed, bool has_seed) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("name", name)); + JUST(attrs.SetAttr("color_space", color_space)); + JUST(attrs.SetAttr("num_attempts", 
num_attempts)); + JUST(attrs.SetAttr("seed", seed)); + JUST(attrs.SetAttr("has_seed", has_seed)); + JUST(attrs.SetAttr("random_area", random_area)); + JUST(attrs.SetAttr("random_aspect_ratio", random_aspect_ratio)); + return OpInterpUtil::Dispatch(*op, {input}, attrs); + }); + m.add_functor("DispatchOfrecordImageDecoder", + [](const std::shared_ptr& op, const std::shared_ptr& input, + const std::string& name, const std::string& color_space) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("name", name)); + JUST(attrs.SetAttr("color_space", color_space)); + return OpInterpUtil::Dispatch(*op, {input}, attrs); + }); + m.add_functor("DispatchImageDecoderRandomCropResize", + [](const std::shared_ptr& op, const std::shared_ptr& input, + int64_t target_width, int64_t target_height, int64_t seed, int64_t num_workers, + int64_t max_num_pixels, float random_area_min, float random_area_max, + float random_aspect_ratio_min, float random_aspect_ratio_max, + int64_t warmup_size, int64_t num_attempts) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("target_width", target_width)); + JUST(attrs.SetAttr("target_height", target_height)); + JUST(attrs.SetAttr("seed", seed)); + JUST(attrs.SetAttr("num_workers", num_workers)); + JUST(attrs.SetAttr("max_num_pixels", max_num_pixels)); + JUST(attrs.SetAttr("random_area_min", random_area_min)); + JUST(attrs.SetAttr("random_area_max", random_area_max)); + JUST(attrs.SetAttr("random_aspect_ratio_min", random_aspect_ratio_min)); + JUST(attrs.SetAttr("random_aspect_ratio_max", random_aspect_ratio_max)); + JUST(attrs.SetAttr("warmup_size", warmup_size)); + JUST(attrs.SetAttr("num_attempts", num_attempts)); + return OpInterpUtil::Dispatch(*op, {input}, attrs); + }); + m.add_functor( + "DispatchTensorBufferToListOfTensorsV2", + [](const std::shared_ptr& op, const std::shared_ptr& input, + const std::vector& out_shapes, const std::vector>& out_dtypes, + bool dynamic_out) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("out_shapes", out_shapes)); + JUST(attrs.SetAttr("dynamic_out", dynamic_out)); + auto out_data_types = std::vector(); + for (auto it = out_dtypes.begin(); it != out_dtypes.end(); it++) { + out_data_types.emplace_back((*it)->data_type()); + } + JUST(attrs.SetAttr("out_dtypes", out_data_types)); + return OpInterpUtil::Dispatch(*op, {input}, attrs); + }); + m.add_functor("DispatchImageResizeKeepAspectRatio", + [](const std::shared_ptr& op, const std::shared_ptr& input, + int32_t target_size, int32_t min_size, int32_t max_size, bool resize_longer, + const std::string& interpolation_type) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("target_size", target_size)); + JUST(attrs.SetAttr("min_size", min_size)); + JUST(attrs.SetAttr("max_size", max_size)); + JUST(attrs.SetAttr("resize_longer", resize_longer)); + JUST(attrs.SetAttr("interpolation_type", interpolation_type)); + return OpInterpUtil::Dispatch(*op, {input}, attrs); + }); + m.add_functor("DispatchImageResizeToFixed", + [](const std::shared_ptr& op, const std::shared_ptr& input, + int64_t target_width, int64_t target_height, int64_t channels, + const Symbol& data_type, + const std::string& interpolation_type) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("target_width", target_width)); + JUST(attrs.SetAttr("target_height", target_height)); + JUST(attrs.SetAttr("channels", channels)); + JUST(attrs.SetAttr("data_type", data_type->data_type())); + JUST(attrs.SetAttr("interpolation_type", interpolation_type)); + return OpInterpUtil::Dispatch(*op, {input}, attrs); + }); + 
m.add_functor( + "DispatchImageDecode", + [](const std::shared_ptr& op, const std::shared_ptr& input, + const std::string& color_space, const Symbol& data_type) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("color_space", color_space)); + JUST(attrs.SetAttr("data_type", data_type->data_type())); + return OpInterpUtil::Dispatch(*op, {input}, attrs); + }); + m.add_functor("DispatchImageNormalize", + [](const std::shared_ptr& op, const std::shared_ptr& input, + const std::vector& mean, const std::vector& std) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("std", std)); + JUST(attrs.SetAttr("mean", mean)); + return OpInterpUtil::Dispatch(*op, {input}, attrs); + }); + m.add_functor( + "DispatchCOCOReader", + [](const std::shared_ptr& op, const std::string& image_dir, + const std::string& annotation_file, int64_t batch_size, bool shuffle_after_epoch, + int64_t random_seed, bool group_by_ratio, bool remove_images_without_annotations, + bool stride_partition, int64_t session_id, + const Optional>& device) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("session_id", session_id)); + JUST(attrs.SetAttr("annotation_file", annotation_file)); + JUST(attrs.SetAttr("image_dir", image_dir)); + JUST(attrs.SetAttr("batch_size", batch_size)); + JUST(attrs.SetAttr("shuffle_after_epoch", shuffle_after_epoch)); + JUST(attrs.SetAttr("random_seed", random_seed)); + JUST(attrs.SetAttr("group_by_ratio", group_by_ratio)); + JUST(attrs.SetAttr("remove_images_without_annotations", remove_images_without_annotations)); + JUST(attrs.SetAttr("stride_partition", stride_partition)); + return OpInterpUtil::Dispatch(*op, {}, + OpExprInterpContext(attrs, JUST(device))); + }); + m.add_functor( + "DispatchCOCOReader", + [](const std::shared_ptr& op, const std::string& image_dir, + const std::string& annotation_file, int64_t batch_size, bool shuffle_after_epoch, + int64_t random_seed, bool group_by_ratio, bool remove_images_without_annotations, + bool stride_partition, int64_t session_id, const Symbol& placement, + const std::vector>& sbp_tuple) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("session_id", session_id)); + JUST(attrs.SetAttr("annotation_file", annotation_file)); + JUST(attrs.SetAttr("image_dir", image_dir)); + JUST(attrs.SetAttr("batch_size", batch_size)); + JUST(attrs.SetAttr("shuffle_after_epoch", shuffle_after_epoch)); + JUST(attrs.SetAttr("random_seed", random_seed)); + JUST(attrs.SetAttr("group_by_ratio", group_by_ratio)); + JUST(attrs.SetAttr("remove_images_without_annotations", remove_images_without_annotations)); + JUST(attrs.SetAttr("stride_partition", stride_partition)); + JUST(attrs.SetAttr("nd_sbp", *JUST(GetNdSbpStrList(sbp_tuple)))); + auto nd_sbp = JUST(GetNdSbp(sbp_tuple)); + return OpInterpUtil::Dispatch(*op, {}, + OpExprInterpContext(attrs, placement, nd_sbp)); + }); + m.add_functor( + "DispatchImageBatchAlign", + [](const std::shared_ptr& op, const std::shared_ptr& input, int32_t alignment, + const Shape& shape, const Symbol& data_type, bool dynamic_out) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("shape", shape)); + JUST(attrs.SetAttr("data_type", data_type->data_type())); + JUST(attrs.SetAttr("alignment", alignment)); + JUST(attrs.SetAttr("dynamic_out", dynamic_out)); + return OpInterpUtil::Dispatch(*op, {input}, attrs); + }); + m.add_functor("DispatchOfrecordBytesDecoder", + [](const std::shared_ptr& op, const std::shared_ptr& input, + const std::string& name) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("name", name)); + return 
OpInterpUtil::Dispatch(*op, {input}, attrs); + }); + m.add_functor( + "DispatchMegatronGptMmapDataLoader", + [](const std::shared_ptr& op, const std::string& data_file_prefix, int64_t seq_length, + int64_t label_length, int64_t num_samples, int64_t batch_size, const Symbol& dtype, + const std::vector& split_sizes, int64_t split_index, bool shuffle, + int64_t random_seed, const Optional>& device) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("data_file_prefix", data_file_prefix)); + JUST(attrs.SetAttr("seq_length", seq_length)); + JUST(attrs.SetAttr("label_length", label_length)); + JUST(attrs.SetAttr("num_samples", num_samples)); + JUST(attrs.SetAttr("batch_size", batch_size)); + JUST(attrs.SetAttr("dtype", dtype->data_type())); + JUST(attrs.SetAttr("split_sizes", split_sizes)); + JUST(attrs.SetAttr("split_index", split_index)); + JUST(attrs.SetAttr("shuffle", shuffle)); + JUST(attrs.SetAttr("random_seed", random_seed)); + return OpInterpUtil::Dispatch(*op, {}, OpExprInterpContext(attrs, JUST(device))); + }); + m.add_functor( + "DispatchMegatronGptMmapDataLoader", + [](const std::shared_ptr& op, const std::string& data_file_prefix, int64_t seq_length, + int64_t label_length, int64_t num_samples, int64_t batch_size, const Symbol& dtype, + const std::vector& split_sizes, int64_t split_index, bool shuffle, + int64_t random_seed, const Symbol& placement, + const std::vector>& sbp_tuple) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("data_file_prefix", data_file_prefix)); + JUST(attrs.SetAttr("seq_length", seq_length)); + JUST(attrs.SetAttr("label_length", label_length)); + JUST(attrs.SetAttr("num_samples", num_samples)); + JUST(attrs.SetAttr("batch_size", batch_size)); + JUST(attrs.SetAttr("dtype", dtype->data_type())); + JUST(attrs.SetAttr("split_sizes", split_sizes)); + JUST(attrs.SetAttr("split_index", split_index)); + JUST(attrs.SetAttr("shuffle", shuffle)); + JUST(attrs.SetAttr("random_seed", random_seed)); + auto nd_sbp = JUST(GetNdSbp(sbp_tuple)); + return OpInterpUtil::Dispatch(*op, {}, + OpExprInterpContext(attrs, placement, nd_sbp)); + }); + m.add_functor("DispatchRmspropUpdate", + [](const std::shared_ptr& op, const TensorTuple& inputs, + float learning_rate, double scale, float l1, float l2, bool centered, + float epsilon, float decay_rate, float weight_decay) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("learning_rate_val", learning_rate)); + JUST(attrs.SetAttr("scale", scale)); + JUST(attrs.SetAttr("l1", l1)); + JUST(attrs.SetAttr("l2", l2)); + JUST(attrs.SetAttr("centered", centered)); + JUST(attrs.SetAttr("epsilon", epsilon)); + JUST(attrs.SetAttr("decay_rate", decay_rate)); + JUST(attrs.SetAttr("weight_decay", weight_decay)); + JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); + return Maybe::Ok(); + }); + m.add_functor("DispatchAdamUpdate", + [](const std::shared_ptr& op, const TensorTuple& inputs, + float learning_rate, float bias_correction1, float bias_correction2, + double scale, float l1, float l2, float beta1, float beta2, float epsilon, + float weight_decay, bool amsgrad, bool do_bias_correction) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("learning_rate_val", learning_rate)); + JUST(attrs.SetAttr("bias_correction1_val", bias_correction1)); + JUST(attrs.SetAttr("bias_correction2_val", bias_correction2)); + JUST(attrs.SetAttr("scale", scale)); + JUST(attrs.SetAttr("l1", l1)); + JUST(attrs.SetAttr("l2", l2)); + JUST(attrs.SetAttr("beta1", beta1)); + JUST(attrs.SetAttr("beta2", beta2)); + JUST(attrs.SetAttr("epsilon", epsilon)); 
+ JUST(attrs.SetAttr("weight_decay", weight_decay)); + JUST(attrs.SetAttr("amsgrad", amsgrad)); + JUST(attrs.SetAttr("do_bias_correction", do_bias_correction)); + JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); + return Maybe::Ok(); + }); + m.add_functor("DispatchAdagradUpdate", + [](const std::shared_ptr& op, const TensorTuple& inputs, + float learning_rate, double scale, float l1, float l2, float lr_decay, + float weight_decay, float epsilon, int32_t train_step) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("learning_rate_val", learning_rate)); + JUST(attrs.SetAttr("scale", scale)); + JUST(attrs.SetAttr("l1", l1)); + JUST(attrs.SetAttr("l2", l2)); + JUST(attrs.SetAttr("lr_decay", lr_decay)); + JUST(attrs.SetAttr("weight_decay", weight_decay)); + JUST(attrs.SetAttr("epsilon", epsilon)); + JUST(attrs.SetAttr("train_step_val", train_step)); + JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); + return Maybe::Ok(); + }); + m.add_functor( + "DispatchMomentumUpdate", + [](const std::shared_ptr& op, const TensorTuple& inputs, float learning_rate, + double scale, float l1, float l2, float beta, float weight_decay) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("learning_rate_val", learning_rate)); + JUST(attrs.SetAttr("scale", scale)); + JUST(attrs.SetAttr("l1", l1)); + JUST(attrs.SetAttr("l2", l2)); + JUST(attrs.SetAttr("beta", beta)); + JUST(attrs.SetAttr("weight_decay", weight_decay)); + JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); + return Maybe::Ok(); + }); + m.add_functor( + "DispatchSgdUpdate", + [](const std::shared_ptr& op, const TensorTuple& inputs, float learning_rate, + double scale, float l1, float l2, float weight_decay) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("learning_rate_val", learning_rate)); + JUST(attrs.SetAttr("scale", scale)); + JUST(attrs.SetAttr("l1", l1)); + JUST(attrs.SetAttr("l2", l2)); + JUST(attrs.SetAttr("weight_decay", weight_decay)); + JUST(OpInterpUtil::Dispatch(*op, inputs, attrs)); + return Maybe::Ok(); + }); + m.add_functor("DispatchEagerNcclAllReduce", + [](const std::shared_ptr& op, const std::shared_ptr& input, + const std::string& parallel_conf, bool async_launch) -> Maybe { + MutableAttrMap attrs; + JUST(attrs.SetAttr("parallel_conf", parallel_conf)); + JUST(attrs.SetAttr("async_launch", async_launch)); + return OpInterpUtil::Dispatch(*op, {input}, attrs); + }); +} + +} // namespace impl + +} // namespace functional +} // namespace one +} // namespace oneflow diff --git a/oneflow/api/python/functional/dispatch_stateful_ops.yaml b/oneflow/api/python/functional/dispatch_stateful_ops.yaml new file mode 100644 index 00000000000..2b8f49263b1 --- /dev/null +++ b/oneflow/api/python/functional/dispatch_stateful_ops.yaml @@ -0,0 +1,137 @@ +# Copyright 2020 The OneFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# The following data types are allowed, +# { +# "Tensor", "TensorTuple", "Scalar", "Int", "Int32", "Int64", "Float", "Double", "String", "Bool", +# "ScalarList", "IntList", "Int32List", "Int64List", "FloatList", "DoubleList", "StringList", +# "BoolList", "DataType", "Shape", "Generator", "TensorIndex", "Device", "Placement", +# "Sbp", "SbpList" +# } + +- name: "dispatch_feed_input" + signature: "Tensor (OpExpr op, Tensor input) => DispatchFeedInput" + bind_python: True + +- name: "dispatch_feed_variable" + signature: "Tensor (OpExpr op, Tensor input, Scalar l2) => DispatchFeedVariable" + bind_python: True + +- name: "dispatch_fetch_output" + signature: "Tensor (OpExpr op, Tensor input) => DispatchFetchOutput" + bind_python: True + +- name: "dispatch_ofrecord_reader" + signature: [ + "Tensor (OpExpr op, String data_dir, Int32 data_part_num, String part_name_prefix=\"part-\", Int32 part_name_suffix_length=-1, Int32 batch_size, Int32 shuffle_buffer_size=1024, Bool random_shuffle=False, Bool shuffle_after_epoch=False, Int64 seed=-1, Device device=None) => DispatchOfrecordReader", + "Tensor (OpExpr op, String data_dir, Int32 data_part_num, String part_name_prefix=\"part-\", Int32 part_name_suffix_length=-1, Int32 batch_size, Int32 shuffle_buffer_size=1024, Bool random_shuffle=False, Bool shuffle_after_epoch=False, Int64 seed=-1, Placement placement, SbpList sbp) => DispatchOfrecordReader", + ] + bind_python: True + +- name: "dispatch_ofrecord_raw_decoder" + signature: "Tensor (OpExpr op, Tensor input, String name, Shape shape, DataType data_type, Bool dim1_varying_length=False, Bool truncate=False) => DispatchOfrecordRawDecoder" + bind_python: True + +- name: "dispatch_coin_flip" + signature: [ + "Tensor (OpExpr op, Int64 batch_size, Scalar probability=0.5, Int64 seed=-1, Bool has_seed=False, Device device=None) => DispatchCoinFlip", + "Tensor (OpExpr op, Int64 batch_size, Scalar probability=0.5, Int64 seed=-1, Bool has_seed=False, Placement placement, SbpList sbp) => DispatchCoinFlip", + ] + bind_python: True + +- name: "dispatch_crop_mirror_normalize_from_uint8" + signature: "Tensor (OpExpr op, TensorTuple input, Int64 crop_h=0, Int64 crop_w=0, Float crop_pos_x=0.5, Float crop_pos_y=0.5, FloatList mean, FloatList std, DataType output_dtype=kFloat, String output_layout=\"NCHW\", String color_space=\"BGR\") => DispatchCropMirrorNormalizeFromUint8" + bind_python: True + +- name: "dispatch_crop_mirror_normalize_from_tensorbuffer" + signature: "Tensor (OpExpr op, TensorTuple input, Int64 crop_h=0, Int64 crop_w=0, Float crop_pos_x=0.5, Float crop_pos_y=0.5, FloatList mean, FloatList std, DataType output_dtype=kFloat, String output_layout=\"NCHW\", String color_space=\"BGR\") => DispatchCropMirrorNormalizeFromTensorBuffer" + bind_python: True + +- name: "dispatch_ofrecord_image_decoder_random_crop" + signature: "Tensor (OpExpr op, Tensor input, String name, String color_space=\"BGR\", FloatList random_area, FloatList random_aspect_ratio, Int32 num_attempts=10, Int64 seed=-1, Bool has_seed=False) => DispatchOfrecordImageDecoderRandomCrop" + bind_python: True + +- name: "dispatch_ofrecord_image_decoder" + signature: "Tensor (OpExpr op, Tensor input, String name, String color_space=\"BGR\") => DispatchOfrecordImageDecoder" + bind_python: True + +- name: "dispatch_image_decoder_random_crop_resize" + signature: "Tensor (OpExpr op, Tensor input, Int64 target_width, Int64 target_height, Int64 seed, Int64 num_workers=3, Int64 max_num_pixels=67108864, Float random_area_min=0.08f, Float 
random_area_max=1.0f, Float random_aspect_ratio_min=0.75f, Float random_aspect_ratio_max=1.333333f, Int64 warmup_size=6400, Int64 num_attempts=10) => DispatchImageDecoderRandomCropResize" + bind_python: True + +- name: "dispatch_tensor_buffer_to_list_of_tensors_v2" + signature: "TensorTuple (OpExpr op, Tensor input, ShapeList out_shapes, DataTypeList out_dtypes, Bool dynamic_out) => DispatchTensorBufferToListOfTensorsV2" + bind_python: True + +- name: "dispatch_image_resize_keep_aspect_ratio" + signature: "TensorTuple (OpExpr op, Tensor input, Int32 target_size, Int32 min_size=0, Int32 max_size=0, Bool resize_longer=False, String interpolation_type=\"bilinear\") => DispatchImageResizeKeepAspectRatio" + bind_python: True + +- name: "dispatch_image_resize_to_fixed" + signature: "TensorTuple (OpExpr op, Tensor input, Int64 target_width=0, Int64 target_height=0, Int64 channels=3, DataType data_type=kUInt8, String interpolation_type=\"bilinear\") => DispatchImageResizeToFixed" + bind_python: True + +- name: "dispatch_image_decode" + signature: "Tensor (OpExpr op, Tensor input, String color_space=\"BGR\", DataType data_type=kUInt8) => DispatchImageDecode" + bind_python: True + +- name: "dispatch_image_normalize" + signature: "Tensor (OpExpr op, Tensor input, FloatList mean, FloatList std) => DispatchImageNormalize" + bind_python: True + +- name: "dispatch_coco_reader" + signature: [ + "TensorTuple (OpExpr op, String image_dir, String annotation_file, Int64 batch_size, Bool shuffle_after_epoch=False, Int64 random_seed=-1, Bool group_by_ratio=True, Bool remove_images_without_annotations=True, Bool stride_partition=False, Int64 session_id, Device device=None) => DispatchCOCOReader", + "TensorTuple (OpExpr op, String image_dir, String annotation_file, Int64 batch_size, Bool shuffle_after_epoch=False, Int64 random_seed=-1, Bool group_by_ratio=True, Bool remove_images_without_annotations=True, Bool stride_partition=False, Int64 session_id, Placement placement, SbpList sbp) => DispatchCOCOReader", + ] + bind_python: True + +- name: "dispatch_image_batch_align" + signature: "Tensor (OpExpr op, Tensor input, Int32 alignment, Shape shape, DataType data_type, Bool dynamic_out) => DispatchImageBatchAlign" + bind_python: True + +- name: "dispatch_ofrecord_bytes_decoder" + signature: "Tensor (OpExpr op, Tensor input, String name) => DispatchOfrecordBytesDecoder" + bind_python: True + +- name: "dispatch_megatron_gpt_mmap_data_loader" + signature: [ + "Tensor (OpExpr op, String data_file_prefix, Int64 seq_length, Int64 label_length=1, Int64 num_samples, Int64 batch_size, DataType dtype, Int64List split_sizes, Int64 split_index, Bool shuffle, Int64 random_seed, Device device=None) => DispatchMegatronGptMmapDataLoader", + "Tensor (OpExpr op, String data_file_prefix, Int64 seq_length, Int64 label_length=1, Int64 num_samples, Int64 batch_size, DataType dtype, Int64List split_sizes, Int64 split_index, Bool shuffle, Int64 random_seed, Placement placement, SbpList sbp) => DispatchMegatronGptMmapDataLoader", + ] + bind_python: True + +- name: "dispatch_rmsprop_update" + signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Bool centered=False, Float epsilon=1e-8, Float decay_rate=0.99, Float weight_decay=0.0) => DispatchRmspropUpdate" + bind_python: True + +- name: "dispatch_adam_update" + signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Float bias_correction1=1.0, Float bias_correction2=1.0, Double scale=1.0, Float l1=0, Float l2=0, 
Float beta1=0.9, Float beta2=0.999, Float epsilon=1e-8, Float weight_decay=0, Bool amsgrad=False, Bool do_bias_correction=True) => DispatchAdamUpdate" + bind_python: True + +- name: "dispatch_adagrad_update" + signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float lr_decay=0, Float weight_decay=0, Float epsilon=1e-10, Int32 train_step_val=0) => DispatchAdagradUpdate" + bind_python: True + +- name: "dispatch_momentum_update" + signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float beta=0.9, Float weight_decay=0) => DispatchMomentumUpdate" + bind_python: True + +- name: "dispatch_sgd_update" + signature: "Void (OpExpr op, TensorTuple inputs, Float learning_rate=0, Double scale=1.0, Float l1=0, Float l2=0, Float weight_decay=0) => DispatchSgdUpdate" + bind_python: True + +- name: "dispatch_eager_nccl_all_reduce" + signature: "Tensor (OpExpr op, Tensor input, String parallel_conf, Bool async_launch=False) => DispatchEagerNcclAllReduce" + bind_python: True diff --git a/oneflow/api/python/functional/py_function.cpp b/oneflow/api/python/functional/py_function.cpp index df591d5b13c..7df18262e8d 100644 --- a/oneflow/api/python/functional/py_function.cpp +++ b/oneflow/api/python/functional/py_function.cpp @@ -47,15 +47,15 @@ void ReportKwargsError(const py::kwargs& kwargs, const FunctionDef& function, si // The argument parsing refers to the implementation of Pytorch. bool ParseArgs(const py::args& args, const py::kwargs& kwargs, std::vector* parsed_args, const FunctionDef& function, size_t max_pos_args, bool raise_exception) { - bool treat_args_as_intlist = false; + bool treat_args_as_list = false; size_t nargs = args.size(); size_t remaining_kwargs = kwargs.size(); if (max_pos_args == 1) { const auto& type = function.argument_def.at(0).type; - treat_args_as_intlist = IsIntegralListType(type) || type == kSHAPE; + treat_args_as_list = IsIntegralListType(type) || type == kSHAPE || type == kTENSOR_TUPLE; } - if (nargs > max_pos_args && !treat_args_as_intlist) { + if (nargs > max_pos_args && !treat_args_as_list) { if (raise_exception) { THROW(TypeError) << function.name << "(): takes " << max_pos_args << " positional arguments but " << nargs << " were given."; @@ -83,7 +83,8 @@ bool ParseArgs(const py::args& args, const py::kwargs& kwargs, std::vector> PythonArg::ObjectAs>() const { return PyUnpackDType(object_); } +template<> +Maybe>> PythonArg::ObjectAs>>() const { + return PyUnpackDTypeSequence(object_); +} + template<> Maybe PythonArg::ObjectAs() const { const auto& shape = JUST(PyUnpackLongSequence(object_)); return std::make_shared(DimVector(shape->begin(), shape->end())); } +template<> +Maybe> PythonArg::ObjectAs>() const { + return PyUnpackShapeSequence(object_); +} + template<> Maybe> PythonArg::ObjectAs>() const { @@ -126,7 +135,7 @@ template<> Maybe> PythonArg::ObjectAs>() const { if (PyStringCheck(object_)) { const char* device_str = JUST(PyStringAsString(object_)); - return DeviceExportUtil::ParseAndNew(device_str); + return Device::ParseAndNew(device_str); } return PyUnpackDevice(object_); } @@ -156,6 +165,16 @@ Maybe PythonArg::ObjectAs() const { return PyUnpackTensorIndex(object_); } +template<> +Maybe> PythonArg::ObjectAs>() const { + return JUST(PyUnpackOpExpr(object_)); +} + +template<> +Maybe PythonArg::ObjectAs() const { + return PyUnpackOpExpr(object_); +} + template<> Maybe PythonArg::ObjectAs() const { return object_; @@ -208,7 +227,10 @@ Maybe 
PythonArg::TypeCheck(ValueType type) const { case kSBP_PARALLEL: return PySbpParallelCheck(object_); case kSBP_PARALLEL_LIST: return PySbpParallelSequenceCheck(object_) || PySbpParallelCheck(object_); + case kOPEXPR_REF: return PyOpExprCheck(object_); case kPY_OBJECT: return nullptr != object_; + case kDTYPE_LIST: return PyDTypeSequenceCheck(object_); + case kSHAPE_LIST: return PyShapeSequenceCheck(object_); default: { OF_UNIMPLEMENTED() << "Can not check type " << JUST(ValueTypeName(type)); } diff --git a/oneflow/api/python/functional/tensor_api.cpp b/oneflow/api/python/functional/tensor_api.cpp index 4c95aa79323..d355bae8eb1 100644 --- a/oneflow/api/python/functional/tensor_api.cpp +++ b/oneflow/api/python/functional/tensor_api.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include #include "oneflow/api/python/utils/tensor_utils.h" #include "oneflow/api/python/functional/common.h" @@ -59,9 +60,10 @@ class TensorWithDataFunctor { const auto& other = JUST(PyUnpackTensor(data)); return MakeTensorFromOtherTensor(other, dtype, device, requires_grad); + } else { + // Make tensor from python sequence or numpy array. + return MakeLocalTensorFromData(data, dtype, device, requires_grad); } - // Make tensor from python sequence or numpy array. - return MakeLocalTensorFromData(data, dtype, device, requires_grad); } }; @@ -73,6 +75,7 @@ class ConsistentTensorWithDataFunctor { const bool& requires_grad) const { // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now. LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); + JUST(CheckDeviceIdsIsValid(placement)); if (PyTensorCheck(data)) { // Throw warnings like pytorch. @@ -105,6 +108,7 @@ class ConsistentTensorEmptyCtorFunctor { Maybe operator()(const Symbol& placement, const std::vector>& sbp_tuple) const { Shape shape(DimVector{0}); + JUST(CheckDeviceIdsIsValid(placement)); return ConsistentTensorWithShapeCtor(shape, placement, sbp_tuple); } }; @@ -146,6 +150,7 @@ class ConsistentTensorWithDataCtorFunctor { public: Maybe operator()(PyObject* data, const Symbol& placement, const std::vector>& sbp_tuple) const { + JUST(CheckDeviceIdsIsValid(placement)); // Treat the single long as shape. if (PyLong_Check(data)) { int64_t size = PyLong_AsLongLong(data); @@ -188,6 +193,7 @@ class ConsistentTensorWithShapeCtorFunctor { const std::vector>& sbp_tuple) const { // NOTE(chengcheng): flow.Tensor or flow.tensor ONLY created by EagerTensor now. 
LazyMode::Guard lazy_mode_disabled_guard(/*is_enabled*/ false); + JUST(CheckDeviceIdsIsValid(placement)); return functional::ConsistentEmpty(shape, DType::Float(), placement, sbp_tuple); } }; @@ -209,6 +215,67 @@ class AssignLocalTensorFunctor { std::shared_ptr op_; }; +class LocalTensorSharedNumpyDataFunctor { + public: + LocalTensorSharedNumpyDataFunctor() {} + Maybe operator()(PyObject* obj) const { + if (!PyArray_Check(obj)) { + return Error::TypeError() << "expected np.ndarray, but got " << Py_TYPE(obj)->tp_name; + } + auto* array = reinterpret_cast(obj); + + // Build TensorMeta + int32_t dim = PyArray_NDIM(array); + const npy_intp* dims_ptr = PyArray_SHAPE(array); + const auto shape = std::make_shared(DimVector(dims_ptr, dims_ptr + dim)); + DataType data_type = JUST(numpy::GetOFDataTypeFromNpArray(array)); + Symbol device = JUST(Device::New("cpu")); + const npy_intp* stride_ptr = PyArray_STRIDES(array); + // stride + auto strides_vec = DimVector(stride_ptr, stride_ptr + dim); + auto element_size_in_bytes = PyArray_ITEMSIZE(array); + // NumPy strides use bytes. OneFlow strides use element counts. + for (auto& stride : strides_vec) { + if (stride % element_size_in_bytes != 0) { + return Error::RuntimeError() << "given numpy array strides not a multiple of the element " + "byte size. Copy the numpy array to reallocate the memory."; + } + stride /= element_size_in_bytes; + } + const auto strides = std::make_shared(strides_vec); + auto tensor_meta = std::make_shared(shape, data_type, device, strides, 0); + + // Build TensorBuffer + const auto& Free = [obj](char* dptr) { + py::gil_scoped_acquire gil; + Py_DECREF(obj); + }; + Py_INCREF(obj); // make TensorBuffer hold ndarray + void* data_ptr = PyArray_DATA(array); + auto array_size_in_bytes = PyArray_NBYTES(array); + auto tensor_data = std::make_shared(); + tensor_data->set_blob_dptr( + std::unique_ptr>(static_cast(data_ptr), Free), + array_size_in_bytes); + + // Build TensorStorage: decrease ndarray reference count before releasing + auto tensor_storage = std::make_shared(tensor_data); + + // Build Tensor + auto tensor_impl = std::make_shared(tensor_meta, tensor_storage, + /*requires_grad=*/false, + /*ls_leaf=*/true); + + // Init blob + JUST(tensor_impl->InitEagerBlobObject(JUST(GetLocalDepObject4Device(*device)))); + JUST(tensor_impl->eager_blob_object())->set_last_used_device(device); + JUST(JUST(tensor_impl->eager_blob_object())->TryInitBlob()); + JUST(tensor_impl->eager_blob_object())->mut_blob()->reset_dptr(static_cast(data_ptr)); + std::shared_ptr out(new MirroredTensor(tensor_impl)); + return out; + } +}; + } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { @@ -222,6 +289,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("TensorWithShapeCtor"); m.add_functor("ConsistentTensorWithShapeCtor"); m.add_functor("AssignLocalTensorFunctor"); + m.add_functor("LocalTensorSharedNumpyData"); } } // namespace functional diff --git a/oneflow/api/python/functional/tensor_api.yaml b/oneflow/api/python/functional/tensor_api.yaml index cbe4a95e65a..fd2f2d681e2 100644 --- a/oneflow/api/python/functional/tensor_api.yaml +++ b/oneflow/api/python/functional/tensor_api.yaml @@ -14,25 +14,30 @@ - name: "tensor" signature: [ - "Tensor (PyObject* data, *, DataType dtype=None, Device device=None, - Bool requires_grad=False) => TensorWithData", - "Tensor (PyObject* data, *, DataType dtype=None, Placement placement, - SbpList sbp, Bool requires_grad=False) => ConsistentTensorWithData", - ] + "Tensor (PyObject* data, *, DataType dtype=None, Device device=None, + 
Bool requires_grad=False) => TensorWithData", + "Tensor (PyObject* data, *, DataType dtype=None, Placement placement, + SbpList sbp, Bool requires_grad=False) => ConsistentTensorWithData", + ] bind_python: True - name: "_legacy_tensor_ctor" - signature: [ - "Tensor (*, Device device=None) => TensorEmptyCtor", - "Tensor (*, Placement placement, SbpList sbp) => ConsistentTensorEmptyCtor", - "Tensor (Tensor other) => TensorWithOtherCtor", - "Tensor (PyObject* data, *, Device device=None) => TensorWithDataCtor", - "Tensor (PyObject* data, *, Placement placement, SbpList sbp) => ConsistentTensorWithDataCtor", - "Tensor (Shape size, *, Device device=None) => TensorWithShapeCtor", - "Tensor (Shape size, *, Placement placement, SbpList sbp) => ConsistentTensorWithShapeCtor", - ] + signature: + [ + "Tensor (*, Device device=None) => TensorEmptyCtor", + "Tensor (*, Placement placement, SbpList sbp) => ConsistentTensorEmptyCtor", + "Tensor (Tensor other) => TensorWithOtherCtor", + "Tensor (PyObject* data, *, Device device=None) => TensorWithDataCtor", + "Tensor (PyObject* data, *, Placement placement, SbpList sbp) => ConsistentTensorWithDataCtor", + "Tensor (Shape size, *, Device device=None) => TensorWithShapeCtor", + "Tensor (Shape size, *, Placement placement, SbpList sbp) => ConsistentTensorWithShapeCtor", + ] bind_python: True - name: "assign_local_tensor" - signature: "Void (Tensor ref, Tensor value)=> AssignLocalTensorFunctor" + signature: "Void (Tensor ref, Tensor value) => AssignLocalTensorFunctor" + bind_python: True + +- name: "from_numpy" + signature: "Tensor (PyObject* obj) => LocalTensorSharedNumpyData" bind_python: True diff --git a/oneflow/api/python/functional/value_types.cpp b/oneflow/api/python/functional/value_types.cpp index fe2631f1e9a..61319bd77eb 100644 --- a/oneflow/api/python/functional/value_types.cpp +++ b/oneflow/api/python/functional/value_types.cpp @@ -51,9 +51,10 @@ HashMap* GetValueTypeNameMap() { {kTENSOR_TUPLE_MAYBE, "maybe tensor tuple"}, {kATTR, "attr"}, {kATTR_REF, "attr"}, - {kATTR_MAP, "attr map"}, {kDTYPE, "data type"}, + {kDTYPE_LIST, "data type list"}, {kSHAPE, "shape"}, + {kSHAPE_LIST, "shape list"}, {kGENERATOR, "generator"}, {kGENERATOR_REF, "generator"}, {kGENERATOR_MAYBE, "maybe generator"}, @@ -62,6 +63,8 @@ HashMap* GetValueTypeNameMap() { {kPARALLEL_DESC, "placement"}, {kSBP_PARALLEL, "sbp"}, {kSBP_PARALLEL_LIST, "sbp list"}, + {kOPEXPR, "opexpr"}, + {kOPEXPR_REF, "opexpr"}, {kPY_OBJECT, "python object"}, }; return &value_type_name_map; diff --git a/oneflow/api/python/functional/value_types.h b/oneflow/api/python/functional/value_types.h index bab62f2b178..e76e49830f8 100644 --- a/oneflow/api/python/functional/value_types.h +++ b/oneflow/api/python/functional/value_types.h @@ -28,7 +28,6 @@ limitations under the License. 
namespace oneflow { class Scalar; class Shape; -class AttrMap; template class Symbol; @@ -45,6 +44,7 @@ namespace one { class Tensor; class TensorTuple; class Generator; +class OpExpr; namespace functional { class TensorIndex; @@ -96,7 +96,6 @@ enum ValueType : int { kTENSOR_TUPLE_MAYBE, kATTR, kATTR_REF, - kATTR_MAP, kDTYPE, kSHAPE, kGENERATOR, @@ -107,7 +106,11 @@ enum ValueType : int { kPARALLEL_DESC, kSBP_PARALLEL, kSBP_PARALLEL_LIST, + kSHAPE_LIST, + kDTYPE_LIST, + kOPEXPR = 390, + kOPEXPR_REF, kPY_OBJECT = 400, }; @@ -152,9 +155,10 @@ VALUE_TYPE_OF_IMPL(std::shared_ptr, kTENSOR_TUPLE_REF); VALUE_TYPE_OF_IMPL(Maybe, kTENSOR_TUPLE_MAYBE); VALUE_TYPE_OF_IMPL(cfg::AttrValue, kATTR); VALUE_TYPE_OF_IMPL(std::shared_ptr, kATTR_REF); -VALUE_TYPE_OF_IMPL(AttrMap, kATTR_MAP); VALUE_TYPE_OF_IMPL(Symbol, kDTYPE); +VALUE_TYPE_OF_IMPL(std::vector>, kDTYPE_LIST); VALUE_TYPE_OF_IMPL(Shape, kSHAPE); +VALUE_TYPE_OF_IMPL(std::vector, kSHAPE_LIST); VALUE_TYPE_OF_IMPL(one::Generator, kGENERATOR); VALUE_TYPE_OF_IMPL(std::shared_ptr, kGENERATOR_REF); VALUE_TYPE_OF_IMPL(Maybe, kGENERATOR_MAYBE); @@ -164,6 +168,9 @@ VALUE_TYPE_OF_IMPL(Symbol, kPARALLEL_DESC); VALUE_TYPE_OF_IMPL(Symbol, kSBP_PARALLEL); VALUE_TYPE_OF_IMPL(std::vector>, kSBP_PARALLEL_LIST); +VALUE_TYPE_OF_IMPL(one::OpExpr, kOPEXPR); +VALUE_TYPE_OF_IMPL(std::shared_ptr, kOPEXPR_REF); + VALUE_TYPE_OF_IMPL(PyObject*, kPY_OBJECT); VALUE_TYPE_OF_IMPL(const PyObject*, kPY_OBJECT); diff --git a/oneflow/api/python/ir.cpp b/oneflow/api/python/ir.cpp index 5840cb9d716..422242d37c4 100644 --- a/oneflow/api/python/ir.cpp +++ b/oneflow/api/python/ir.cpp @@ -28,9 +28,6 @@ ONEFLOW_API_PYBIND11_MODULE("ir", m) { [](const std::string& lib_path) { MutSharedLibPaths()->insert(lib_path); }); } -REGISTER_JOB_PASS("IRRoundTripBeforeAD", IRRoundTrip); -REGISTER_JOB_PASS("IRRoundTrip", IRRoundTrip); - } // namespace oneflow #endif // WITH_MLIR diff --git a/oneflow/api/python/job_build/job_build_and_infer.cpp b/oneflow/api/python/job_build/job_build_and_infer.cpp index fd25d18f1b3..39d38199b9d 100644 --- a/oneflow/api/python/job_build/job_build_and_infer.cpp +++ b/oneflow/api/python/job_build/job_build_and_infer.cpp @@ -47,7 +47,7 @@ ONEFLOW_API_PYBIND11_MODULE("", m) { m.def("JobBuildAndInferCtx_GetDataType", &JobBuildAndInferCtx_GetDataType); m.def("JobBuildAndInferCtx_IsDynamic", &JobBuildAndInferCtx_IsDynamic); - m.def("JobBuildAndInferCtx_DisableBoxing", &JobBuildAndInferCtx_DisableBoxing); + m.def("JobBuildAndInferCtx_IsDisableBoxing", &JobBuildAndInferCtx_IsDisableBoxing); m.def("JobBuildAndInferCtx_GetSplitAxisFromProducerView", &JobBuildAndInferCtx_GetSplitAxisFromProducerView); diff --git a/oneflow/api/python/job_build/job_build_and_infer.h b/oneflow/api/python/job_build/job_build_and_infer.h index 3d88c27d109..c05d3c33a57 100644 --- a/oneflow/api/python/job_build/job_build_and_infer.h +++ b/oneflow/api/python/job_build/job_build_and_infer.h @@ -114,10 +114,10 @@ inline Maybe JobBuildAndInferCtx_IsDynamic(const std::string& job_name, return ctx->IsDynamic(lbn); } -inline Maybe JobBuildAndInferCtx_DisableBoxing(const std::string& job_name, - const std::string& lbn) { +inline Maybe JobBuildAndInferCtx_IsDisableBoxing(const std::string& job_name, + const std::string& lbn) { auto* ctx = JUST(GetJobBuildAndInferCtx(job_name)); - return ctx->DisableBoxing(lbn); + return ctx->IsDisableBoxing(lbn); } inline Maybe JobBuildAndInferCtx_GetSplitAxisFromProducerView( diff --git a/oneflow/api/python/job_build/job_build_and_infer_api.h 
b/oneflow/api/python/job_build/job_build_and_infer_api.h index ac9399a18a0..f092ea01d54 100644 --- a/oneflow/api/python/job_build/job_build_and_infer_api.h +++ b/oneflow/api/python/job_build/job_build_and_infer_api.h @@ -90,8 +90,9 @@ inline bool JobBuildAndInferCtx_IsDynamic(const std::string& job_name, const std return oneflow::JobBuildAndInferCtx_IsDynamic(job_name, lbn).GetOrThrow(); } -inline bool JobBuildAndInferCtx_DisableBoxing(const std::string& job_name, const std::string& lbn) { - return oneflow::JobBuildAndInferCtx_DisableBoxing(job_name, lbn).GetOrThrow(); +inline bool JobBuildAndInferCtx_IsDisableBoxing(const std::string& job_name, + const std::string& lbn) { + return oneflow::JobBuildAndInferCtx_IsDisableBoxing(job_name, lbn).GetOrThrow(); } inline std::string JobBuildAndInferCtx_GetSplitAxisFromProducerView(const std::string& job_name, diff --git a/oneflow/api/python/symbol/placement_symbol.cpp b/oneflow/api/python/symbol/placement_symbol.cpp index 3547aeb91db..926147d13f7 100644 --- a/oneflow/api/python/symbol/placement_symbol.cpp +++ b/oneflow/api/python/symbol/placement_symbol.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include #include @@ -30,6 +31,9 @@ limitations under the License. #include "oneflow/core/job/placement.cfg.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/resource_desc.h" +#ifdef WITH_CUDA +#include +#endif // WITH_CUDA namespace py = pybind11; @@ -37,6 +41,16 @@ namespace oneflow { namespace { +int64_t GetGpuDeviceNum() { +#ifndef WITH_CUDA + return 0; +#else + int device_count = 0; + cudaGetDeviceCount(&device_count); + return device_count; +#endif +} + Maybe MakeShape(const py::tuple& py_shape) { DimVector shape_dims{}; for (const auto& dim : py_shape) { shape_dims.emplace_back(dim.cast()); } @@ -150,6 +164,12 @@ struct PlacementSymbolExportUtil { if (iter == device_tag2placement.end()) { int64_t node_size = GlobalProcessCtx::NodeSize(); int64_t device_num = GlobalProcessCtx::NumOfProcessPerNode(); + if (device_tag == "gpu") { + const int64_t gpu_device_num = GetGpuDeviceNum(); + CHECK_NE(gpu_device_num, 0) + << "Can\'t construct placement with \"cuda\" type because there is no CUDA device!"; + device_num = std::min(device_num, gpu_device_num); + } std::vector machine_device_ids; for (int64_t node_id = 0; node_id < node_size; ++node_id) { std::string device_name = std::to_string(node_id) + ":0-" + std::to_string(device_num - 1); diff --git a/oneflow/api/python/symbol/sbp_symbol.cpp b/oneflow/api/python/symbol/sbp_symbol.cpp index a21eaa0325e..406e590a222 100644 --- a/oneflow/api/python/symbol/sbp_symbol.cpp +++ b/oneflow/api/python/symbol/sbp_symbol.cpp @@ -16,6 +16,7 @@ limitations under the License. 
#include #include #include "oneflow/api/python/of_api_registry.h" +#include "oneflow/api/common/sbp.h" #include "oneflow/core/common/util.h" #include "oneflow/core/common/constant.h" #include "oneflow/core/common/maybe.h" @@ -32,7 +33,7 @@ namespace oneflow { namespace { std::string SbpParallelSymbolToString(const Symbol& sbp_sym) { - return *SbpToString(sbp_sym).GetPtrOrThrow(); + return *api::SbpToString(sbp_sym).GetPtrOrThrow(); } Maybe>> MakeSplitSbpParallelList(int max_split_axis) { diff --git a/oneflow/api/python/utils/tensor_utils.cpp b/oneflow/api/python/utils/tensor_utils.cpp index 5f580eacf89..f4c35b31f7b 100644 --- a/oneflow/api/python/utils/tensor_utils.cpp +++ b/oneflow/api/python/utils/tensor_utils.cpp @@ -24,7 +24,7 @@ limitations under the License. #include "oneflow/core/functional/functional.h" #include "oneflow/extension/python/numpy.h" #include "oneflow/core/common/decorator.h" -#include "oneflow/core/framework/data_consistency_check.h" +#include "oneflow/core/framework/consistency_check.h" namespace py = pybind11; diff --git a/oneflow/core/autograd/gradient_funcs/concat.cpp b/oneflow/core/autograd/gradient_funcs/concat.cpp index d27975d9dd1..86adc4545ff 100644 --- a/oneflow/core/autograd/gradient_funcs/concat.cpp +++ b/oneflow/core/autograd/gradient_funcs/concat.cpp @@ -65,11 +65,15 @@ Maybe Concat::Apply(const ConcatCaptureState* ctx, const TensorTuple& out_ in_grads->resize(ctx->input_num); TensorTuple like(ctx->input_num); for (int i = 0; i < ctx->input_num; ++i) { like[i] = ctx->SavedTensors().at(i); } - const auto& results = JUST(functional::SplitLike(out_grads.at(0), like, ctx->axis)); - CHECK_EQ_OR_RETURN(results->size(), ctx->input_num); + if (ctx->input_num == 1) { + in_grads->at(0) = out_grads.at(0); + } else { + const auto& results = JUST(functional::SplitLike(out_grads.at(0), like, ctx->axis)); + CHECK_EQ_OR_RETURN(results->size(), ctx->input_num); - for (int i = 0; i < ctx->input_num; ++i) - if (ctx->requires_grad.at(i)) { in_grads->at(i) = results->at(i); } + for (int i = 0; i < ctx->input_num; ++i) + if (ctx->requires_grad.at(i)) { in_grads->at(i) = results->at(i); } + } return Maybe::Ok(); } diff --git a/oneflow/core/autograd/gradient_funcs/consistent_cast.cpp b/oneflow/core/autograd/gradient_funcs/consistent_cast.cpp index 5534614a0f3..92216bc47a0 100644 --- a/oneflow/core/autograd/gradient_funcs/consistent_cast.cpp +++ b/oneflow/core/autograd/gradient_funcs/consistent_cast.cpp @@ -19,30 +19,11 @@ limitations under the License. 
#include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" #include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/common/decorator.h" +#include "oneflow/core/functional/functional.h" namespace oneflow { namespace one { -namespace { - -Maybe CalcBoxingOutput(const std::shared_ptr& input, Symbol out_nd_sbp, - Symbol out_parallel_desc) { - const auto* mgr = Global::Get(); - // Eager boxing - const auto& in_nd_sbp = JUST(input->nd_sbp()); - const auto& in_parallel_desc = JUST(input->parallel_desc()); - const auto& boxing_interpreter = JUST( - mgr->GetEagerBoxingInterpreter(in_nd_sbp, out_nd_sbp, in_parallel_desc, out_parallel_desc)); - const auto& output = JUST(boxing_interpreter->Interpret(input, in_nd_sbp, out_nd_sbp, - in_parallel_desc, out_parallel_desc)); - return output; -} - -static constexpr auto* RecursiveGetBoxingOutput = - DECORATE(&CalcBoxingOutput, CheckConsistentTensorMeta); - -} // namespace - struct CastConsistentCaptureState : public AutoGradCaptureState { Symbol parallel_desc; Symbol nd_sbp; @@ -77,7 +58,8 @@ class CastToConsistent : public OpExprGradFunction { Symbol nd_sbp_constraint = ctx->nd_sbp; Symbol parallel_desc_constraint = ctx->parallel_desc; out_grad = - JUST(RecursiveGetBoxingOutput(out_grad, nd_sbp_constraint, parallel_desc_constraint)); + JUST(functional::ToConsistent(out_grad, parallel_desc_constraint, + *JUST(GetSbpList(nd_sbp_constraint)), GetNoneSbpList())); } in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*grad_op_, {out_grad})); return Maybe::Ok(); diff --git a/oneflow/core/autograd/gradient_funcs/cumsum.cpp b/oneflow/core/autograd/gradient_funcs/cumsum.cpp new file mode 100644 index 00000000000..652e28a4849 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/cumsum.cpp @@ -0,0 +1,64 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct CumsumCaptureState : public AutoGradCaptureState { + bool requires_grad = false; + int64_t dim = 0; +}; + +class CumsumGrad : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override { + const auto* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe::Ok(); + } + + Maybe Capture(CumsumCaptureState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override { + CHECK_EQ_OR_RETURN(inputs.size(), 1); + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe::Ok(); } + + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->dim = JUST(composed_attrs.GetAttr("dim")); + return Maybe::Ok(); + } + + Maybe Apply(const CumsumCaptureState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override { + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(1); + if (ctx->requires_grad) { + in_grads->at(0) = JUST(functional::CumsumGrad(out_grads.at(0), ctx->dim)); + } + return Maybe::Ok(); + } + + private: + AttrMap base_attrs_; +}; + +REGISTER_OP_EXPR_GRAD_FUNCTION("cumsum", CumsumGrad); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/autograd/gradient_funcs/deconv.cpp b/oneflow/core/autograd/gradient_funcs/deconv.cpp index 1ed4f16463b..84248c0831c 100644 --- a/oneflow/core/autograd/gradient_funcs/deconv.cpp +++ b/oneflow/core/autograd/gradient_funcs/deconv.cpp @@ -31,6 +31,7 @@ struct DeConvolutionNdCaptureState : public AutoGradCaptureState { std::vector kernel_size; std::vector strides; std::vector dilation_rate; + int32_t groups; }; class DeConvolutionNd : public OpExprGradFunction { @@ -69,6 +70,7 @@ Maybe DeConvolutionNd::Capture(DeConvolutionNdCaptureState* ctx, const Ten ctx->kernel_size = JUST(composed_attrs.GetAttr>("kernel_size")); ctx->strides = JUST(composed_attrs.GetAttr>("strides")); ctx->dilation_rate = JUST(composed_attrs.GetAttr>("dilation_rate")); + ctx->groups = JUST(composed_attrs.GetAttr("groups")); ctx->ndims = ctx->kernel_size.size(); return Maybe::Ok(); } @@ -86,21 +88,21 @@ Maybe DeConvolutionNd::Apply(const DeConvolutionNdCaptureState* ctx, } const auto& weight = ctx->SavedTensors().at(0); if (ctx->ndims == 1) { - std::shared_ptr result = - JUST(functional::Conv1d(out_grads.at(0), weight, Optional(), ctx->strides, - ctx->padding_before, ctx->dilation_rate, /*groups=*/1)); + std::shared_ptr result = JUST(functional::Conv1d( + out_grads.at(0), weight, Optional(), ctx->strides, ctx->padding_before, + ctx->dilation_rate, ctx->groups, ctx->data_format)); result = JUST(functional::Slice(result, start, stop, step)); in_grads->at(0) = result; } else if (ctx->ndims == 2) { - std::shared_ptr result = - JUST(functional::Conv2d(out_grads.at(0), weight, Optional(), ctx->strides, - ctx->padding_before, ctx->dilation_rate, /*groups=*/1)); + std::shared_ptr result = JUST(functional::Conv2d( + out_grads.at(0), weight, Optional(), ctx->strides, ctx->padding_before, + ctx->dilation_rate, ctx->groups, ctx->data_format)); result = JUST(functional::Slice(result, start, stop, step)); in_grads->at(0) = result; } else if (ctx->ndims == 3) { - std::shared_ptr result = - JUST(functional::Conv3d(out_grads.at(0), weight, Optional(), ctx->strides, - ctx->padding_before, ctx->dilation_rate, 
/*groups=*/1)); + std::shared_ptr result = JUST(functional::Conv3d( + out_grads.at(0), weight, Optional(), ctx->strides, ctx->padding_before, + ctx->dilation_rate, ctx->groups, ctx->data_format)); result = JUST(functional::Slice(result, start, stop, step)); in_grads->at(0) = result; } else { @@ -112,7 +114,7 @@ Maybe DeConvolutionNd::Apply(const DeConvolutionNdCaptureState* ctx, const auto& x = ctx->SavedTensors().at(idx); in_grads->at(1) = JUST(functional::ConvFilterGrad( x, out_grads.at(0), ctx->ndims, ctx->kernel_size, ctx->strides, ctx->padding_before, - ctx->dilation_rate, /*groups=*/1, ctx->data_format)); + ctx->dilation_rate, ctx->groups, ctx->data_format)); } return Maybe::Ok(); } diff --git a/oneflow/core/autograd/gradient_funcs/diagonal.cpp b/oneflow/core/autograd/gradient_funcs/diagonal.cpp new file mode 100644 index 00000000000..a79d241e176 --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/diagonal.cpp @@ -0,0 +1,72 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +struct DiagonalInterpState : public AutoGradCaptureState { + bool requires_grad = false; + int32_t offset = 0; +}; + +class Diagonal : public OpExprGradFunction { + public: + Maybe Init(const OpExpr& op) override; + Maybe Capture(DiagonalInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const override; + Maybe Apply(const DiagonalInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override; + + private: + AttrMap base_attrs_; +}; + +Maybe Diagonal::Init(const OpExpr& op) { + const UserOpExpr* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe::Ok(); +} + +Maybe Diagonal::Capture(DiagonalInterpState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const { + CHECK_EQ_OR_RETURN(outputs.size(), 1); + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe::Ok(); } + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->offset = JUST(composed_attrs.GetAttr("offset")); + ctx->SaveTensorForBackward(inputs.at(0)); + return Maybe::Ok(); +} + +Maybe Diagonal::Apply(const DiagonalInterpState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const { + CHECK_EQ_OR_RETURN(out_grads.size(), 1); + in_grads->resize(2); + if (ctx->requires_grad) { + const auto& x = ctx->SavedTensors().at(0); + in_grads->at(0) = JUST(functional::DiagonalGrad(out_grads.at(0), x, ctx->offset)); + } + return Maybe::Ok(); +} + +REGISTER_OP_EXPR_GRAD_FUNCTION("diagonal", Diagonal); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/autograd/gradient_funcs/dot.cpp b/oneflow/core/autograd/gradient_funcs/dot.cpp index 
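// Illustrative note (not part of the diff): the diagonal backward registered
// above scatters the output gradient back onto the selected diagonal of a zero
// tensor shaped like the saved input; functional::DiagonalGrad is assumed to do
// the tensor-level equivalent of this 2-D reference (names are illustrative):
#include <vector>

std::vector<std::vector<double>> DiagonalGradReference(const std::vector<double>& dy,
                                                       int rows, int cols, int offset) {
  std::vector<std::vector<double>> dx(rows, std::vector<double>(cols, 0.0));
  for (int k = 0; k < static_cast<int>(dy.size()); ++k) {
    const int i = offset >= 0 ? k : k - offset;  // row of the k-th diagonal element
    const int j = offset >= 0 ? k + offset : k;  // column of the k-th diagonal element
    if (i < rows && j < cols) { dx[i][j] = dy[k]; }
  }
  return dx;
}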
3412c36f426..0db6084eda4 100644 --- a/oneflow/core/autograd/gradient_funcs/dot.cpp +++ b/oneflow/core/autograd/gradient_funcs/dot.cpp @@ -21,9 +21,7 @@ namespace one { struct DotCaptureState : public AutoGradCaptureState { bool x_requires_grad = false; - ; bool y_requires_grad = false; - ; size_t x_offset = 0; size_t y_offset = 0; }; diff --git a/oneflow/core/autograd/gradient_funcs/slice.cpp b/oneflow/core/autograd/gradient_funcs/slice.cpp index f4455a41e0d..1e8fe4f83f4 100644 --- a/oneflow/core/autograd/gradient_funcs/slice.cpp +++ b/oneflow/core/autograd/gradient_funcs/slice.cpp @@ -24,6 +24,7 @@ namespace one { struct SliceCaptureState : public AutoGradCaptureState { bool requires_grad; + Shape like_shape; std::vector start; std::vector stop; std::vector step; @@ -49,17 +50,15 @@ class Slice : public OpExprGradFunction { ctx->start = JUST(composed_attrs.GetAttr>("start")); ctx->stop = JUST(composed_attrs.GetAttr>("stop")); ctx->step = JUST(composed_attrs.GetAttr>("step")); - ctx->SaveTensorForBackward(inputs.at(0)); + ctx->like_shape = *(inputs.at(0)->shape()); return Maybe::Ok(); } Maybe Apply(const SliceCaptureState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override { - const auto& like = ctx->SavedTensors().at(0); - in_grads->resize(1); - in_grads->at(0) = - JUST(functional::SliceGrad(out_grads.at(0), like, ctx->start, ctx->stop, ctx->step)); + in_grads->at(0) = JUST( + functional::SliceGrad(out_grads.at(0), ctx->like_shape, ctx->start, ctx->stop, ctx->step)); return Maybe::Ok(); } diff --git a/oneflow/core/boxing/boxing_dividor_util.cpp b/oneflow/core/boxing/boxing_dividor_util.cpp index 437a678ea1b..fffa308ace1 100644 --- a/oneflow/core/boxing/boxing_dividor_util.cpp +++ b/oneflow/core/boxing/boxing_dividor_util.cpp @@ -74,9 +74,43 @@ Maybe RawFlattenInHierarchy() { }); } +Maybe> RawUnflattenHierarchy(Symbol in_placed_nd_sbp, + Symbol out_placed_nd_sbp) { + CHECK_GE_OR_RETURN(in_placed_nd_sbp->nd_sbp()->sbp_parallel_size(), 0); + CHECK_GE_OR_RETURN(out_placed_nd_sbp->nd_sbp()->sbp_parallel_size(), 0); + const auto& in_sbp_parallel = in_placed_nd_sbp->nd_sbp()->sbp_parallel(0); + cfg::NdSbp unflattened_nd_sbp; + for (int64_t i = 0; i < out_placed_nd_sbp->nd_sbp()->sbp_parallel_size(); ++i) { + unflattened_nd_sbp.mutable_sbp_parallel()->Add()->CopyFrom(in_sbp_parallel); + } + return JUST(PlacedNdSbp::New(SymbolOf(unflattened_nd_sbp), out_placed_nd_sbp->placement())); +} + +static constexpr auto* UnflattenHierarchy = DECORATE(&RawUnflattenHierarchy, ThreadLocal); + +Maybe RawUnflattenInHierarchy() { + return std::make_shared( + "UnflattenInHierarchy", + [](Symbol in, Symbol out) -> Maybe> { + return UnflattenHierarchy(in, out); + }); +} + +Maybe RawUnflattenOutHierarchy() { + return std::make_shared( + "UnflattenOutHierarchy", + [](Symbol in, Symbol out) -> Maybe> { + return UnflattenHierarchy(out, in); + }); +} + } // namespace decltype(FlattenInHierarchy) FlattenInHierarchy = DECORATE(&RawFlattenInHierarchy, ThreadLocal); +decltype(UnflattenInHierarchy) UnflattenInHierarchy = + DECORATE(&RawUnflattenInHierarchy, ThreadLocal); +decltype(UnflattenOutHierarchy) UnflattenOutHierarchy = + DECORATE(&RawUnflattenOutHierarchy, ThreadLocal); namespace { diff --git a/oneflow/core/boxing/boxing_dividor_util.h b/oneflow/core/boxing/boxing_dividor_util.h index 48089846b98..18a279c284c 100644 --- a/oneflow/core/boxing/boxing_dividor_util.h +++ b/oneflow/core/boxing/boxing_dividor_util.h @@ -24,6 +24,8 @@ namespace oneflow { extern Maybe (*ReplaceInDeviceType)(DeviceType 
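// Illustrative note (not part of the diff): capturing only the input Shape
// (`like_shape`) instead of calling SaveTensorForBackward means the forward
// input presumably no longer has to stay alive just to describe the gradient's
// size. For a positive step, functional::SliceGrad is assumed to behave like
// this 1-D reference: write the output gradient back into a zero tensor of the
// captured shape at the sliced positions.
#include <cstdint>
#include <vector>

std::vector<float> SliceGradReference(const std::vector<float>& dy, int64_t like_size,
                                      int64_t start, int64_t stop, int64_t step) {
  std::vector<float> dx(like_size, 0.0f);
  int64_t k = 0;
  for (int64_t i = start; i < stop && k < static_cast<int64_t>(dy.size()); i += step, ++k) {
    dx[i] = dy[k];
  }
  return dx;
}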
device_type); extern Maybe (*ReplaceOutDeviceType)(DeviceType device_type); extern Maybe (*FlattenInHierarchy)(); +extern Maybe (*UnflattenInHierarchy)(); +extern Maybe (*UnflattenOutHierarchy)(); extern Maybe (*OutPlacementAndPartialSum)(); extern Maybe (*InPlacementAndBroadcast)(); extern Maybe (*OutPlacementAndBroadcast)(); diff --git a/oneflow/core/boxing/eager_boxing_interpreter_mgr.cpp b/oneflow/core/boxing/eager_boxing_interpreter_mgr.cpp index 86c3531ceb7..38074f1da61 100644 --- a/oneflow/core/boxing/eager_boxing_interpreter_mgr.cpp +++ b/oneflow/core/boxing/eager_boxing_interpreter_mgr.cpp @@ -61,6 +61,19 @@ Maybe OneToNBoxingExpr() { | JUST(BoxingExpr("identity")))); } +Maybe SymmetricOnedToNdBoxingExpr() { + return JUST( + BoxingExpr(JUST(UnflattenInHierarchy()), JUST(BoxingExpr("unflatten-hierarchy")), + JUST(BoxingExpr("symmetric-nd-sbp-to-nd-sbp")) | JUST(BoxingExpr("identity")))); +} + +Maybe SymmetricNdToOnedBoxingExpr() { + return JUST( + BoxingExpr(JUST(UnflattenOutHierarchy()), + JUST(BoxingExpr("symmetric-nd-sbp-to-nd-sbp")) | JUST(BoxingExpr("identity")), + JUST(BoxingExpr("flatten-hierarchy")))); +} + Maybe GenericBoxingExpr() { // in_placement contain out_placement or out_placement contain in_placement const auto& boxing_expr_with_inclusive_placement = @@ -79,20 +92,20 @@ Maybe GenericBoxingExpr() { } Maybe RawMainBoxingExpr() { - const auto& core = JUST(BoxingExpr("identity")) | JUST(BoxingExpr("flatten-hierarchy")) - | JUST(BoxingExpr("cuda-copy-h2d")) | JUST(BoxingExpr("cuda-copy-d2h")) - | JUST(BoxingExpr("nccl-p-to-b")) | JUST(BoxingExpr("ccl-p-to-b")) - | JUST(BoxingExpr("nccl-s-to-b")) | JUST(BoxingExpr("ccl-s-to-b")) - | JUST(BoxingExpr("nccl-s-to-s")) | JUST(BoxingExpr("ccl-s-to-s")) - | JUST(BoxingExpr("nccl-p-to-s")) | JUST(BoxingExpr("ccl-p-to-s")) - | JUST(BoxingExpr("symmetric-b-to-p")) | JUST(BoxingExpr("symmetric-b-to-s")) - | JUST(BoxingExpr("symmetric-s-to-p")) + const auto& core = JUST(BoxingExpr("identity")) | JUST(BoxingExpr("cuda-copy-h2d")) + | JUST(BoxingExpr("cuda-copy-d2h")) | JUST(BoxingExpr("nccl-p-to-b")) + | JUST(BoxingExpr("ccl-p-to-b")) | JUST(BoxingExpr("nccl-s-to-b")) + | JUST(BoxingExpr("ccl-s-to-b")) | JUST(BoxingExpr("nccl-s-to-s")) + | JUST(BoxingExpr("ccl-s-to-s")) | JUST(BoxingExpr("nccl-p-to-s")) + | JUST(BoxingExpr("ccl-p-to-s")) | JUST(BoxingExpr("symmetric-b-to-p")) + | JUST(BoxingExpr("symmetric-b-to-s")) | JUST(BoxingExpr("symmetric-s-to-p")) | JUST(BoxingExpr("symmetric-nd-sbp-to-nd-sbp")) | JUST(BoxingExpr("asymmetric-x-to-b")) | JUST(BoxingExpr("naive-s-to-s")) | JUST(BoxingExpr("naive-1-to-1")) | JUST(BoxingExpr("naive-s-to-b")) | JUST(BoxingExpr("naive-b-to-s")) | JUST(BoxingExpr("naive-p-to-b")) | JUST(BoxingExpr("naive-p-to-s")) | JUST(OneToNBoxingExpr()) - | JUST(NToOneBoxingExpr()) | JUST(GenericBoxingExpr()); + | JUST(NToOneBoxingExpr()) | JUST(GenericBoxingExpr()) + | JUST(SymmetricOnedToNdBoxingExpr()) | JUST(SymmetricNdToOnedBoxingExpr()); return core | JUST(OptionalCudaCopy(core)); } @@ -114,8 +127,8 @@ Maybe GetBoxingInterpreter(Symbol in_nd_sbp, UNIMPLEMENTED_THEN_RETURN() << Error::BoxingNotSupportedError() << "consistent-to-consistent not supported" - << ". from_nd_sbp: " << *JUST(NdSbpToString(in_nd_sbp)) - << ", to_nd_sbp: " << *JUST(NdSbpToString(out_nd_sbp)) + << ". 
from_nd_sbp: " << NdSbpToString(in_nd_sbp) + << ", to_nd_sbp: " << NdSbpToString(out_nd_sbp) << ", from_placement: " << *JUST(PlacementToString(in_parallel_desc)) << ", to_placement: " << *JUST(PlacementToString(out_parallel_desc)); } diff --git a/oneflow/core/boxing/unflatten_hierarchy.cpp b/oneflow/core/boxing/unflatten_hierarchy.cpp new file mode 100644 index 00000000000..d87ddd9225c --- /dev/null +++ b/oneflow/core/boxing/unflatten_hierarchy.cpp @@ -0,0 +1,62 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/nd_sbp.h" +#include "oneflow/core/boxing/eager_boxing_interpreter.h" +#include "oneflow/core/functional/functional.h" +#include "oneflow/core/common/decorator.h" + +namespace oneflow { + +namespace { + +Maybe RawCheckUnflattenHierarchy(Symbol in, Symbol out) { + CHECK_EQ_OR_RETURN(in->nd_sbp()->sbp_parallel_size(), 1); + CHECK_GT_OR_RETURN(out->nd_sbp()->sbp_parallel_size(), 1); + for (int i = 0; i < out->nd_sbp()->sbp_parallel_size(); ++i) { + const auto& sbp_parallel = out->nd_sbp()->sbp_parallel(i); + CHECK_OR_RETURN(sbp_parallel == out->nd_sbp()->sbp_parallel(0)) << "nd_sbp axis: " << i; + } + CHECK_EQ_OR_RETURN(in->placement()->device_type(), out->placement()->device_type()); + CHECK_EQ_OR_RETURN(in->placement()->parallel_num(), out->placement()->parallel_num()); + ParallelConf unflattened_parallel_conf(in->placement()->parallel_conf()); + unflattened_parallel_conf.mutable_hierarchy()->CopyFrom( + out->placement()->parallel_conf().hierarchy()); + const auto& unflatten_placement = SymbolOf(ParallelDesc(unflattened_parallel_conf)); + CHECK_OR_RETURN(unflatten_placement == out->placement()) + << "The output placement is not a hierarch-unflattened version of the input placement"; + return Maybe::Ok(); +} + +} // namespace + +static constexpr auto* CheckUnflattenHierarchy = DECORATE(&RawCheckUnflattenHierarchy, ThreadLocal); + +Maybe UnflattenHierarchy(const std::shared_ptr& tensor, + Symbol in, Symbol out) { + const auto& tensor_nd_sbp = JUST(tensor->nd_sbp()); + CHECK_OR_RETURN(tensor_nd_sbp == in->nd_sbp()); + const auto& tensor_placement = JUST(tensor->parallel_desc()); + CHECK_OR_RETURN(tensor_placement == in->placement()); + const auto& local_tensor = JUST(tensor->cur_rank_phy_tensor()); + const auto& sbp_list = JUST(GetSbpList(out->nd_sbp())); + return JUST(one::functional::LocalToConsistent(local_tensor, out->placement(), *sbp_list, + *tensor->shape(), tensor->dtype())); +} + +COMMAND(RegisterBoxingFunction("unflatten-hierarchy", CheckUnflattenHierarchy, + &UnflattenHierarchy)); + +} // namespace oneflow diff --git a/oneflow/core/common/buffer_manager.h b/oneflow/core/common/buffer_manager.h index d386b714513..fe7a5ada9d9 100644 --- a/oneflow/core/common/buffer_manager.h +++ b/oneflow/core/common/buffer_manager.h @@ -50,6 +50,26 @@ inline std::string GetCallbackNotifierBufferName(const std::string& job_name) { return prefix + job_name; } +inline std::string 
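// Illustrative note (not part of the diff): "unflatten-hierarchy" boxing keeps
// every rank's local shard in place and only reinterprets the placement; for
// example, a tensor that is S(0) over a flat placement of 4 ranks becomes
// (S(0), S(0)) over the same ranks arranged as a [2, 2] hierarchy.
// RawUnflattenHierarchy builds the target nd_sbp by repeating the single input
// sbp once per output hierarchy axis, roughly like this plain-data sketch
// (strings stand in for the real sbp objects):
#include <cstddef>
#include <string>
#include <vector>

std::vector<std::string> UnflattenNdSbpSketch(const std::string& in_sbp,
                                              std::size_t out_hierarchy_axes) {
  // e.g. UnflattenNdSbpSketch("S(0)", 2) returns {"S(0)", "S(0)"}
  return std::vector<std::string>(out_hierarchy_axes, in_sbp);
}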
GetInputCriticalSectionWaitBufferName(const std::string& job_name) { + static const std::string prefix = "InputCriticalSectionWait-"; + return prefix + job_name; +} + +inline std::string GetInputCriticalSectionCallbackBufferName(const std::string& job_name) { + static const std::string prefix = "InputCriticalSectionCallback-"; + return prefix + job_name; +} + +inline std::string GetOutputCriticalSectionWaitBufferName(const std::string& job_name) { + static const std::string prefix = "OutputCriticalSectionWait-"; + return prefix + job_name; +} + +inline std::string GetOutputCriticalSectionCallbackBufferName(const std::string& job_name) { + static const std::string prefix = "OutputCriticalSectionCallback-"; + return prefix + job_name; +} + inline std::string GetForeignInputBufferName(const std::string& job_name) { static const std::string prefix = "ForeignInput-"; return prefix + job_name; diff --git a/oneflow/core/common/data_type.proto b/oneflow/core/common/data_type.proto index 9a2b35bbe45..4fdb0760fbe 100644 --- a/oneflow/core/common/data_type.proto +++ b/oneflow/core/common/data_type.proto @@ -15,7 +15,15 @@ enum DataType { kTensorBuffer = 10; kBFloat16 = 11; kBool = 12; - kMaxDataType = 13; + kUInt16 = 13; + kUInt32 = 14; + kUInt64 = 15; + kUInt128 = 16; + kInt16 = 17; + kInt128 = 18; + kComplex32 = 19; + kComplex64 = 20; + kComplex128 = 21; } message OptInt64 { diff --git a/oneflow/core/common/high_order_bool.h b/oneflow/core/common/high_order_bool.h index 57c35ee78e4..86b9b560f4b 100644 --- a/oneflow/core/common/high_order_bool.h +++ b/oneflow/core/common/high_order_bool.h @@ -32,7 +32,13 @@ namespace hob { template struct BaseExpr { - virtual ~BaseExpr() = default; +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnon-virtual-dtor" + // NOTE: Performance will be degraded if the destructor is virtual. + // So please do NOT implement custom destructor in any child classes of BaseExpr, + // and every fields of child classes should be of POD type. + ~BaseExpr() = default; +#pragma GCC diagnostic pop ALWAYS_INLINE virtual scalar_or_const_ref_t get(const Context&) const = 0; virtual std::string DebugStr(const Context&, bool display_result = true) const = 0; // NOLINT operator bool() = delete; @@ -40,7 +46,10 @@ struct BaseExpr { template struct Expr : public BaseExpr { - virtual ~Expr() = default; +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnon-virtual-dtor" + ~Expr() = default; +#pragma GCC diagnostic pop }; template diff --git a/oneflow/core/common/maybe.h b/oneflow/core/common/maybe.h index fd949ef8b9d..bff480f6a96 100644 --- a/oneflow/core/common/maybe.h +++ b/oneflow/core/common/maybe.h @@ -43,6 +43,7 @@ class Maybe::value || IsScala final { public: Maybe(const T& data) : data_or_error_(std::make_shared(data)) {} + Maybe(T&& data) : data_or_error_(std::make_shared(std::move(data))) {} Maybe(const Error& error) : data_or_error_(error.error_proto()) {} Maybe(const std::shared_ptr& data) : data_or_error_(data) {} Maybe(std::shared_ptr&& data) : data_or_error_(std::move(data)) {} diff --git a/oneflow/core/common/maybe_test.cpp b/oneflow/core/common/maybe_test.cpp index edacf082335..45aeedf98f0 100644 --- a/oneflow/core/common/maybe_test.cpp +++ b/oneflow/core/common/maybe_test.cpp @@ -15,6 +15,7 @@ limitations under the License. 
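// Illustrative note (not part of the diff): the new Maybe(T&&) overload forwards
// the payload by move, which is what lets the Noncopyable test just below wrap a
// std::unique_ptr. A minimal standalone sketch of the overload pair (MaybeSketch
// is illustrative, not the real oneflow::Maybe):
#include <memory>
#include <utility>

template<typename T>
class MaybeSketch {
 public:
  MaybeSketch(const T& data) : data_(std::make_shared<T>(data)) {}        // copies the payload
  MaybeSketch(T&& data) : data_(std::make_shared<T>(std::move(data))) {}  // moves the payload
 private:
  std::shared_ptr<T> data_;
};

void MaybeSketchExample() {
  // Compiles only because overload resolution picks the T&& constructor;
  // copying a std::unique_ptr<int> would be ill-formed.
  MaybeSketch<std::unique_ptr<int>> a{std::make_unique<int>(1)};
  (void)a;
}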
*/ #include "oneflow/core/common/maybe.h" #include +#include #include "oneflow/core/common/util.h" namespace oneflow { @@ -67,5 +68,7 @@ TEST(Maybe, CHECK_OK) { ASSERT_EXIT(CHECK_OK(g(11)), testing::KilledBySignal(SIGABRT), R"(g\(11\) is not OK)"); } +TEST(Maybe, Noncopyable) { Maybe> a{std::make_unique(1)}; } + } // namespace test } // namespace oneflow diff --git a/oneflow/core/common/nd_index_offset_helper.h b/oneflow/core/common/nd_index_offset_helper.h index eda499f3250..89bd0a90763 100644 --- a/oneflow/core/common/nd_index_offset_helper.h +++ b/oneflow/core/common/nd_index_offset_helper.h @@ -44,6 +44,15 @@ class NdIndexOffsetHelper { OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const T* dims, int n) { InitStrides(dims, n); } + template + OF_DEVICE_FUNC explicit NdIndexOffsetHelper(const U* dims, int n) { + T dims_arr[N]; + for (int i = 0; i < N; ++i) { + if (i < n) { dims_arr[i] = dims[i]; } + } + InitStrides(dims_arr, n); + } + ~NdIndexOffsetHelper() = default; OF_DEVICE_FUNC T NdIndexToOffset(const T* index) const { diff --git a/oneflow/core/common/shape.cpp b/oneflow/core/common/shape.cpp index a09462d0693..12fb8ef08d5 100644 --- a/oneflow/core/common/shape.cpp +++ b/oneflow/core/common/shape.cpp @@ -122,7 +122,6 @@ void Shape::Set(int64_t index, int64_t val) { CHECK_GE(index, 0); CHECK_LT(index, this->NumAxes()) << " Shape: " << DebugStr() << " visit index: " << index << " > num_axes: " << this->NumAxes(); - CHECK_GE(val, 0); dim_vec_.at(index) = val; UpdateElemCnt(); } diff --git a/oneflow/core/control/ctrl_service.cpp b/oneflow/core/control/ctrl_service.cpp index c33f65b78a2..6a3c5ca062b 100644 --- a/oneflow/core/control/ctrl_service.cpp +++ b/oneflow/core/control/ctrl_service.cpp @@ -31,6 +31,8 @@ std::array BuildRpcMethods( return {BuildOneRpcMethod(channel)...}; } +constexpr int64_t kDefaultGrpcMaxMessageByteSize = -1; + } // namespace CtrlService::Stub::Stub(std::shared_ptr channel) @@ -39,7 +41,9 @@ CtrlService::Stub::Stub(std::shared_ptr channel) std::unique_ptr CtrlService::NewStub(const std::string& addr) { grpc::ChannelArguments ch_args; - ch_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, 64 * 1024 * 1024); + int64_t max_msg_byte_size = + ParseIntegerFromEnv("ONEFLOW_GRPC_MAX_MESSAGE_BYTE_SIZE", kDefaultGrpcMaxMessageByteSize); + ch_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, max_msg_byte_size); return std::make_unique( grpc::CreateCustomChannel(addr, grpc::InsecureChannelCredentials(), ch_args)); } diff --git a/oneflow/core/cuda/layer_norm.cuh b/oneflow/core/cuda/layer_norm.cuh index 21f96dcd74c..19da5e8d024 100644 --- a/oneflow/core/cuda/layer_norm.cuh +++ b/oneflow/core/cuda/layer_norm.cuh @@ -39,7 +39,7 @@ struct MaxOp { __device__ __forceinline__ T operator()(const T& a, const T& b) const { return max(a, b); } }; -template typename ReductionOp, typename T, int thread_group_width = kWarpSize> +template class ReductionOp, typename T, int thread_group_width = kWarpSize> __inline__ __device__ T WarpAllReduce(T val) { for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { val = ReductionOp()(val, __shfl_xor_sync(0xffffffff, val, mask)); @@ -47,7 +47,7 @@ __inline__ __device__ T WarpAllReduce(T val) { return val; } -template typename ReductionOp, typename T, int block_size> +template class ReductionOp, typename T, int block_size> __inline__ __device__ T BlockAllReduce(T val) { typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; diff --git a/oneflow/core/cuda/softmax.cuh b/oneflow/core/cuda/softmax.cuh index 
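// Illustrative note (not part of the diff): the hard-coded 64 MiB gRPC limit is
// replaced by the ONEFLOW_GRPC_MAX_MESSAGE_BYTE_SIZE override; with the default
// of -1, GRPC_ARG_MAX_MESSAGE_LENGTH is effectively left unlimited.
// ParseIntegerFromEnv is the real helper used above; this stand-in only shows
// the fallback behaviour:
#include <cstdint>
#include <cstdlib>
#include <string>

int64_t ParseIntegerFromEnvSketch(const std::string& name, int64_t default_value) {
  const char* value = std::getenv(name.c_str());
  return value == nullptr ? default_value : std::atoll(value);
}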
db84e3db350..940cf45e19c 100644 --- a/oneflow/core/cuda/softmax.cuh +++ b/oneflow/core/cuda/softmax.cuh @@ -709,8 +709,9 @@ inline cudaError_t DispatchSoftmaxBlockUncachedImpl(cudaStream_t stream, LOAD lo } template -inline cudaError_t DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, - const int64_t cols) { +inline typename std::enable_if::value, cudaError_t>::type +DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols) { if (cols <= 1024) { return DispatchSoftmaxWarpImpl( stream, load, store, rows, cols); @@ -731,8 +732,17 @@ inline cudaError_t DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, } template -inline cudaError_t DispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, - const int64_t rows, const int64_t cols) { +inline typename std::enable_if::value, cudaError_t>::type +DispatchSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols) { + return DispatchSoftmaxBlockUncachedImpl( + stream, load, store, rows, cols); +} + +template +inline typename std::enable_if::value, cudaError_t>::type +DispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols) { if (cols <= 1024) { return DispatchSoftmaxWarpImpl( stream, load, store, rows, cols); @@ -752,6 +762,14 @@ inline cudaError_t DispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE stor } } +template +inline typename std::enable_if::value, cudaError_t>::type +DispatchLogSoftmax(cudaStream_t stream, LOAD load, STORE store, const int64_t rows, + const int64_t cols) { + return DispatchSoftmaxBlockUncachedImpl( + stream, load, store, rows, cols); +} + template @@ -1267,8 +1285,9 @@ inline cudaError_t DispatchSoftmaxGradBlockUncachedImpl(cudaStream_t stream, LOA } template -inline cudaError_t DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, - STORE store, const int64_t rows, const int64_t cols) { +inline typename std::enable_if::value, cudaError_t>::type +DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, + const int64_t rows, const int64_t cols) { if (cols <= 1024) { return DispatchSoftmaxGradWarpImpl( stream, load_y, load_dy, store, rows, cols); @@ -1288,9 +1307,20 @@ inline cudaError_t DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_ return cudaSuccess; } } + +template +inline typename std::enable_if::value, cudaError_t>::type +DispatchSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, + const int64_t rows, const int64_t cols) { + return DispatchSoftmaxGradBlockUncachedImpl(stream, load_y, load_dy, store, + rows, cols); +} + template -inline cudaError_t DispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, - STORE store, const int64_t rows, const int64_t cols) { +inline typename std::enable_if::value, cudaError_t>::type +DispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, + const int64_t rows, const int64_t cols) { if (cols <= 1024) { return DispatchSoftmaxGradWarpImpl( stream, load_y, load_dy, store, rows, cols); @@ -1311,6 +1341,15 @@ inline cudaError_t DispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LO } } +template +inline typename std::enable_if::value, cudaError_t>::type +DispatchLogSoftmaxGrad(cudaStream_t stream, LOAD_Y load_y, LOAD_DY load_dy, STORE store, + const int64_t rows, const int64_t cols) { + return DispatchSoftmaxGradBlockUncachedImpl(stream, load_y, load_dy, + store, 
rows, cols); +} + } // namespace softmax } // namespace cuda diff --git a/oneflow/core/device/cuda_util.cpp b/oneflow/core/device/cuda_util.cpp index da304377a0a..81a1d44f11e 100644 --- a/oneflow/core/device/cuda_util.cpp +++ b/oneflow/core/device/cuda_util.cpp @@ -96,31 +96,8 @@ const char* NvjpegGetErrorString(nvjpegStatus_t error) { #endif -void InitGlobalCudaDeviceProp() { - CHECK(Global::Get() == nullptr) << "initialized Global twice"; - Global::New(); - cudaGetDeviceProperties(Global::Get(), 0); - if (IsCuda9OnTuringDevice()) { - LOG(WARNING) - << "CUDA 9 running on Turing device has known issues, consider upgrading to CUDA 10"; - } -} - -int32_t GetSMCudaMaxBlocksNum() { - const auto& global_device_prop = *Global::Get(); - int32_t n = - global_device_prop.multiProcessorCount * global_device_prop.maxThreadsPerMultiProcessor; - return (n + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock; -} - -bool IsCuda9OnTuringDevice() { - const auto& global_device_prop = *Global::Get(); - return CUDA_VERSION >= 9000 && CUDA_VERSION < 9020 && global_device_prop.major == 7 - && global_device_prop.minor == 5; -} - size_t GetAvailableGpuMemSize(int dev_id) { - cudaDeviceProp prop; + cudaDeviceProp prop{}; cudaGetDeviceProperties(&prop, dev_id); return prop.totalGlobalMem; } @@ -150,10 +127,6 @@ std::function GetCudaMallocHostFn(int32_t dev) { } // namespace -cudaStream_t RunCudaKernelGetStream(ep::Stream* stream) { - return stream->As()->cuda_stream(); -} - cudaError_t NumaAwareCudaMallocHost(int32_t dev, void** ptr, size_t size) { auto fn = GetCudaMallocHostFn(dev); return fn(ptr, size); diff --git a/oneflow/core/device/cuda_util.h b/oneflow/core/device/cuda_util.h index 278f3e7776a..1b268bb7b20 100644 --- a/oneflow/core/device/cuda_util.h +++ b/oneflow/core/device/cuda_util.h @@ -28,6 +28,7 @@ limitations under the License. #include #include #include "oneflow/core/device/cuda_pseudo_half.h" +#include "oneflow/core/ep/cuda/cuda_stream.h" #if CUDA_VERSION >= 10020 @@ -109,33 +110,13 @@ const int32_t kCudaWarpSize = 32; // TODO: limit of shared memory should be different for different arch const int32_t kCudaMaxSharedMemoryByteSize = 48 << 10; -int32_t GetSMCudaMaxBlocksNum(); -void InitGlobalCudaDeviceProp(); -bool IsCuda9OnTuringDevice(); - inline int32_t BlocksNum4ThreadsNum(const int32_t n) { CHECK_GT(n, 0); return std::min((n + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock, kCudaMaxBlocksNum); } -inline int32_t SMBlocksNum4ThreadsNum(const int32_t n) { - CHECK_GT(n, 0); - return std::min((n + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock, - GetSMCudaMaxBlocksNum()); -} - -namespace ep { - -class Stream; -class CudaStream; - -} // namespace ep - -cudaStream_t RunCudaKernelGetStream(ep::Stream* stream); - -#define RUN_CUDA_KERNEL(func, device_ctx_ptr, thread_num, ...) \ - func<<>>(__VA_ARGS__) +#define RUN_CUDA_KERNEL(func, stream, elem_cnt, ...) \ + stream->As()->LaunchKernel(func, elem_cnt, 1, __VA_ARGS__) size_t GetAvailableGpuMemSize(int dev_id); diff --git a/oneflow/core/device/ep_based_event_record.h b/oneflow/core/device/ep_based_event_record.h new file mode 100644 index 00000000000..5783c35f955 --- /dev/null +++ b/oneflow/core/device/ep_based_event_record.h @@ -0,0 +1,54 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
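// Illustrative note (not part of the diff): the paired DispatchSoftmax /
// DispatchLogSoftmax (and *Grad) overloads in softmax.cuh above are split with
// std::enable_if so that one compute type is routed straight to the uncached
// block implementation (cheaper to compile) while the others keep the
// cols-dependent warp / shared-memory / uncached selection. A minimal sketch of
// that overload split, assuming the predicate distinguishes ComputeType == double
// (all names here are illustrative):
#include <type_traits>

enum DispatchChoice { kWarpImpl, kSharedMemImpl, kUncachedImpl };

template<typename ComputeType>
typename std::enable_if<!std::is_same<ComputeType, double>::value, DispatchChoice>::type
DispatchSketch(long cols) {
  if (cols <= 1024) { return kWarpImpl; }
  return kSharedMemImpl;  // would fall back to kUncachedImpl if shared memory were insufficient
}

template<typename ComputeType>
typename std::enable_if<std::is_same<ComputeType, double>::value, DispatchChoice>::type
DispatchSketch(long /*cols*/) {
  return kUncachedImpl;  // always the uncached block implementation
}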
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_DEVICE_EP_BASED_EVENT_RECORD_H_ +#define ONEFLOW_CORE_DEVICE_EP_BASED_EVENT_RECORD_H_ + +#include "oneflow/core/device/event_record.h" +#include "oneflow/core/ep/include/active_device_guard.h" + +namespace oneflow { + +class EpBasedEventRecord : public EventRecord { + public: + OF_DISALLOW_COPY_AND_MOVE(EpBasedEventRecord); + EpBasedEventRecord(ep::Event* event, ep::Device* device) : event_(event), device_(device) {} + ~EpBasedEventRecord() { + ep::ActiveDeviceGuard guard(device_); + device_->DestroyEvent(event_); + }; + + static std::shared_ptr MakeEventRecord(ep::Stream* stream) { + ep::Device* device = stream->device(); + ep::ActiveDeviceGuard guard(device); + ep::Event* event = device->CreateEvent(); + stream->RecordEvent(event); + return std::make_shared(event, device); + } + + bool QueryDone() const override { + ep::ActiveDeviceGuard guard(device_); + bool done = CHECK_JUST(event_->QueryDone()); + return done; + } + + private: + ep::Event* event_; + ep::Device* device_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_DEVICE_EP_BASED_EVENT_RECORD_H_ diff --git a/oneflow/core/eager/critical_section_instruction_type.cpp b/oneflow/core/eager/critical_section_instruction_type.cpp index 21cc9313961..f4c65631263 100644 --- a/oneflow/core/eager/critical_section_instruction_type.cpp +++ b/oneflow/core/eager/critical_section_instruction_type.cpp @@ -17,11 +17,11 @@ limitations under the License. 
#include "oneflow/core/eager/critical_section_stream_type.h" #include "oneflow/core/eager/critical_section_status_querier.h" #include "oneflow/core/eager/critical_section_phy_instr_operand.h" +#include "oneflow/core/job/critical_section_instance.h" #include "oneflow/core/framework/nn_graph_if.h" #include "oneflow/core/common/container_util.h" #include "oneflow/core/vm/instruction.h" #include "oneflow/core/vm/instruction_type.h" -#include "oneflow/core/job/job_instance.h" #include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/common/global.h" #include "oneflow/core/vm/stream.h" @@ -49,9 +49,63 @@ class CriticalSectionBeginInstructionType final : public InstructionType { void Infer(vm::Instruction* instruction) const override { UNIMPLEMENTED(); } void Compute(vm::Instruction* instruction) const override { - auto* status_buffer_data = instruction->mut_status_buffer()->mut_buffer()->mut_data(); - auto* status_querier = CriticalSectionStatusQuerier::MutCast(status_buffer_data); - status_querier->SetLaunched(std::make_shared()); + OF_PROFILER_RANGE_GUARD("CriticalSectionBegin"); + { + auto ptr = instruction->instr_msg().phy_instr_operand(); + auto phy_instr_operand = std::dynamic_pointer_cast(ptr); + CHECK_NOTNULL(phy_instr_operand); + const auto& critical_section_instance = MakeCriticalSectionInstance(phy_instr_operand); + const auto& job_name = critical_section_instance->job_name(); + auto* buffer_mgr = Global>>::Get(); + for (int i = 0; i < phy_instr_operand->interfaces_op_names().size(); ++i) { + if (phy_instr_operand->interfaces_valid().at(i)) { + const std::string& interface_op_name = phy_instr_operand->interfaces_op_names().at(i); + const auto& buffer_name = + phy_instr_operand->GetInterfaceBufferName(job_name, interface_op_name); + buffer_mgr->Get(buffer_name)->Push(critical_section_instance); + } + } + const auto& callback_buffer_name = + phy_instr_operand->GetInterfaceCriticalSectionCallbackBufferName(job_name); + buffer_mgr->Get(callback_buffer_name)->Push(critical_section_instance); + const auto& wait_buffer_name = + phy_instr_operand->GetInterfaceCriticalSectionWaitBufferName(job_name); + buffer_mgr->Get(wait_buffer_name)->Push(critical_section_instance); + } + { + auto* status_buffer_data = instruction->mut_status_buffer()->mut_buffer()->mut_data(); + auto* status_querier = CriticalSectionStatusQuerier::MutCast(status_buffer_data); + status_querier->SetLaunched(std::make_shared()); + } + } + + private: + class NaiveCriticalSectionInstance final : public CriticalSectionInstance { + public: + NaiveCriticalSectionInstance( + const std::shared_ptr& phy_instr_operand, + const std::string& job_name) + : CriticalSectionInstance(), phy_instr_operand_(phy_instr_operand), job_name_(job_name) {} + + ~NaiveCriticalSectionInstance() override = default; + + const std::string& job_name() const override { return job_name_; } + + void AccessBlobByOpName(uint64_t ofblob_ptr, const std::string& op_name) const override { + phy_instr_operand_->AccessBlobByOpName(ofblob_ptr, op_name); + } + void Finish() const override { phy_instr_operand_->Finish(); } + + private: + std::shared_ptr phy_instr_operand_; + std::string job_name_; + }; + + std::shared_ptr MakeCriticalSectionInstance( + const std::shared_ptr& phy_instr_operand) const { + phy_instr_operand->FinishInvalidInterfaceEventRecords(); + const auto& job_name = phy_instr_operand->nn_graph()->job_name(); + return std::make_shared(phy_instr_operand, job_name); } }; diff --git a/oneflow/core/eager/critical_section_phy_instr_operand.cpp 
b/oneflow/core/eager/critical_section_phy_instr_operand.cpp index 4fe8c4b95e3..ffd2d5cfa37 100644 --- a/oneflow/core/eager/critical_section_phy_instr_operand.cpp +++ b/oneflow/core/eager/critical_section_phy_instr_operand.cpp @@ -15,7 +15,12 @@ limitations under the License. */ #include "oneflow/core/eager/critical_section_phy_instr_operand.h" #include "oneflow/core/framework/device.h" +#include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/common/decorator.h" +#include "oneflow/core/device/device_context.h" +#include "oneflow/core/device/ep_based_event_record.h" +#include "oneflow/core/register/ofblob.h" +#include "oneflow/core/common/container_util.h" namespace oneflow { namespace vm { @@ -46,7 +51,59 @@ constexpr auto* CriticalSectionLocalDepObject = void CriticalSectionBeginPhyInstrOperand::ForEachMutMirroredObject( const std::function& DoEach) const { DoEach(CHECK_JUST(CriticalSectionLocalDepObject())->mut_mirrored_object()); - DoEach(local_dep_object_->mut_mirrored_object()); +} + +void CriticalSectionBeginPhyInstrOperand::FinishInvalidInterfaceEventRecords() { + for (const auto& op_name : interfaces_op_names()) { + size_t index = CHECK_JUST(MapAt(op_name2interface_index_, op_name)); + if (!interfaces_valid().at(index)) { + const auto& iter = op_name2end_event_record_->find(op_name); + CHECK(iter != op_name2end_event_record_->end()); + iter->second->Init(std::make_shared()); + } + } +} + +void CriticalSectionBeginPhyInstrOperand::Finish() { + for (const auto& pair : *op_name2end_event_record_) { + pair.second->TryInit(std::make_shared()); + } +} + +void InputCriticalSectionBeginPhyInstrOperand::AccessBlobByOpName(uint64_t of_blob_ptr, + const std::string& op_name) { + int64_t i = CHECK_JUST(MapAt(op_name2interface_index_, op_name)); + CHECK(interfaces_valid().at(i)); + OfBlob* of_blob = reinterpret_cast(of_blob_ptr); + const auto& eager_blob_object = eager_blob_objects_->at(i); + const Blob* blob = &eager_blob_object->blob(); + CHECK_NOTNULL(blob); + of_blob->mut_blob()->CopyHeaderFrom(blob); + const auto& end_event_record = op_name2end_event_record_->at(op_name); + if (blob->dptr() == nullptr) { + end_event_record->Init(std::make_shared()); + } else { + AutoMemcpy(of_blob->stream(), of_blob->mut_blob(), blob); + end_event_record->Init(EpBasedEventRecord::MakeEventRecord(of_blob->stream())); + } +} + +void OutputCriticalSectionBeginPhyInstrOperand::AccessBlobByOpName(uint64_t of_blob_ptr, + const std::string& op_name) { + int64_t i = CHECK_JUST(MapAt(op_name2interface_index_, op_name)); + CHECK(interfaces_valid().at(i)); + OfBlob* of_blob = reinterpret_cast(of_blob_ptr); + const auto& eager_blob_object = eager_blob_objects_->at(i); + Blob* mut_blob = eager_blob_object->mut_blob(); + CHECK_NOTNULL(mut_blob); + mut_blob->CopyHeaderFrom(&of_blob->blob()); + const auto& end_event_record = op_name2end_event_record_->at(op_name); + if (mut_blob->dptr() == nullptr) { + end_event_record->Init(std::make_shared()); + } else { + AutoMemcpy(of_blob->stream(), mut_blob, &of_blob->blob()); + end_event_record->Init(EpBasedEventRecord::MakeEventRecord(of_blob->stream())); + } } void CriticalSectionEndPhyInstrOperand::ForEachMutMirroredObject( diff --git a/oneflow/core/eager/critical_section_phy_instr_operand.h b/oneflow/core/eager/critical_section_phy_instr_operand.h index 76904de3394..514f47d0818 100644 --- a/oneflow/core/eager/critical_section_phy_instr_operand.h +++ b/oneflow/core/eager/critical_section_phy_instr_operand.h @@ -20,6 +20,7 @@ limitations under the License. 
#include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/device/event_record.h" #include "oneflow/core/framework/nn_graph_if.h" +#include "oneflow/core/common/buffer_manager.h" namespace oneflow { @@ -42,32 +43,60 @@ class CriticalSectionBeginPhyInstrOperand : public PhyInstrOperand { virtual ~CriticalSectionBeginPhyInstrOperand() = default; explicit CriticalSectionBeginPhyInstrOperand( + const std::shared_ptr& nn_graph, const one::EagerBlobObjectListPtr& eager_blob_objects, - intrusive::shared_ptr local_dep_object) - : eager_blob_objects_(eager_blob_objects), local_dep_object_(local_dep_object) {} + const std::shared_ptr>>& + op_name2end_event_record) + : nn_graph_(nn_graph), + eager_blob_objects_(eager_blob_objects), + op_name2end_event_record_(op_name2end_event_record) {} + + const std::shared_ptr& nn_graph() const { return nn_graph_; } + const one::EagerBlobObjectListPtr& eager_blob_objects() const { return eager_blob_objects_; } void ForEachMirroredObject(const std::function&) const; void ForEachMutMirroredObject(const std::function&) const; - intrusive::shared_ptr local_dep_object() const { return local_dep_object_; } + virtual const std::vector& interfaces_op_names() const = 0; + virtual const std::vector& interfaces_valid() const = 0; + virtual std::string GetInterfaceBufferName(const std::string& job_name, + const std::string& op_name) const = 0; + virtual std::string GetInterfaceCriticalSectionCallbackBufferName( + const std::string& job_name) const = 0; + virtual std::string GetInterfaceCriticalSectionWaitBufferName( + const std::string& job_name) const = 0; + virtual void AccessBlobByOpName(uint64_t of_blob_ptr, const std::string& op_name) = 0; + + void FinishInvalidInterfaceEventRecords(); + void Finish(); protected: + std::shared_ptr nn_graph_; one::EagerBlobObjectListPtr eager_blob_objects_; - mutable intrusive::shared_ptr local_dep_object_; + std::shared_ptr>> + op_name2end_event_record_; + HashMap op_name2interface_index_; }; class InputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBeginPhyInstrOperand { public: - explicit InputCriticalSectionBeginPhyInstrOperand( + InputCriticalSectionBeginPhyInstrOperand( + const std::shared_ptr& nn_graph, const one::EagerBlobObjectListPtr& eager_blob_objects, - intrusive::shared_ptr local_dep_object) - : CriticalSectionBeginPhyInstrOperand(eager_blob_objects, local_dep_object), + const std::shared_ptr>>& + op_name2end_event_record) + : CriticalSectionBeginPhyInstrOperand(nn_graph, eager_blob_objects, op_name2end_event_record), input_dependences_(), output_dependences_() { ForEachConstMirroredObject(SetInserter(&input_dependences_)); ForEachMutMirroredObject(SetInserter(&output_dependences_)); ForEachMut2MirroredObject(SetInserter(&output_dependences_)); + CHECK_EQ(nn_graph->inputs_op_names().size(), eager_blob_objects->size()); + CHECK_EQ(nn_graph->inputs_op_names().size(), nn_graph->inputs_valid().size()); + for (int i = 0; i < nn_graph->inputs_op_names().size(); ++i) { + CHECK(op_name2interface_index_.emplace(nn_graph->inputs_op_names().at(i), i).second); + } } ~InputCriticalSectionBeginPhyInstrOperand() override = default; @@ -82,6 +111,23 @@ class InputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBeg } // for outputs + const std::vector& interfaces_op_names() const override { + return nn_graph_->inputs_op_names(); + } + const std::vector& interfaces_valid() const override { return nn_graph_->inputs_valid(); } + std::string GetInterfaceBufferName(const std::string& job_name, + 
const std::string& op_name) const override { + return GetInputBufferName(job_name, op_name); + } + std::string GetInterfaceCriticalSectionCallbackBufferName( + const std::string& job_name) const override { + return GetInputCriticalSectionCallbackBufferName(job_name); + } + std::string GetInterfaceCriticalSectionWaitBufferName( + const std::string& job_name) const override { + return GetInputCriticalSectionWaitBufferName(job_name); + } + void AccessBlobByOpName(uint64_t of_blob_ptr, const std::string& op_name) override; void ForEachMut2MirroredObject(const std::function&) const {} private: @@ -91,15 +137,22 @@ class InputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBeg class OutputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBeginPhyInstrOperand { public: - explicit OutputCriticalSectionBeginPhyInstrOperand( + OutputCriticalSectionBeginPhyInstrOperand( + const std::shared_ptr& nn_graph, const one::EagerBlobObjectListPtr& eager_blob_objects, - intrusive::shared_ptr local_dep_object) - : CriticalSectionBeginPhyInstrOperand(eager_blob_objects, local_dep_object), + const std::shared_ptr>>& + op_name2end_event_record) + : CriticalSectionBeginPhyInstrOperand(nn_graph, eager_blob_objects, op_name2end_event_record), input_dependences_(), output_dependences_() { ForEachConstMirroredObject(SetInserter(&input_dependences_)); ForEachMutMirroredObject(SetInserter(&output_dependences_)); ForEachMut2MirroredObject(SetInserter(&output_dependences_)); + CHECK_EQ(nn_graph->outputs_op_names().size(), eager_blob_objects->size()); + CHECK_EQ(nn_graph->outputs_op_names().size(), nn_graph->outputs_valid().size()); + for (int i = 0; i < nn_graph->outputs_op_names().size(); ++i) { + CHECK(op_name2interface_index_.emplace(nn_graph->outputs_op_names().at(i), i).second); + } } ~OutputCriticalSectionBeginPhyInstrOperand() override = default; @@ -116,6 +169,24 @@ class OutputCriticalSectionBeginPhyInstrOperand final : public CriticalSectionBe ForEachMirroredObject(DoEach); } + const std::vector& interfaces_op_names() const override { + return nn_graph_->outputs_op_names(); + } + const std::vector& interfaces_valid() const override { return nn_graph_->outputs_valid(); } + std::string GetInterfaceBufferName(const std::string& job_name, + const std::string& op_name) const override { + return GetOutputBufferName(job_name, op_name); + } + std::string GetInterfaceCriticalSectionCallbackBufferName( + const std::string& job_name) const override { + return GetOutputCriticalSectionCallbackBufferName(job_name); + } + std::string GetInterfaceCriticalSectionWaitBufferName( + const std::string& job_name) const override { + return GetOutputCriticalSectionWaitBufferName(job_name); + } + void AccessBlobByOpName(uint64_t of_blob_ptr, const std::string& op_name) override; + private: DependenceVector input_dependences_; DependenceVector output_dependences_; diff --git a/oneflow/core/eager/critical_section_stream_type.cpp b/oneflow/core/eager/critical_section_stream_type.cpp index 42101671331..e9e6884e9c1 100644 --- a/oneflow/core/eager/critical_section_stream_type.cpp +++ b/oneflow/core/eager/critical_section_stream_type.cpp @@ -58,7 +58,6 @@ intrusive::shared_ptr CriticalSectionStreamType::MakeStreamDesc( const Resource& resource, int64_t this_machine_id) const { auto ret = intrusive::make_shared(); ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex()); - ret->set_num_machines(1); ret->set_num_streams_per_machine(1); ret->set_num_streams_per_thread(1); return ret; diff --git 
a/oneflow/core/eager/lazy_job_instruction_type.cpp b/oneflow/core/eager/lazy_job_instruction_type.cpp index 8f4126b8c39..a6d25f00f53 100644 --- a/oneflow/core/eager/lazy_job_instruction_type.cpp +++ b/oneflow/core/eager/lazy_job_instruction_type.cpp @@ -30,60 +30,20 @@ limitations under the License. #include "oneflow/core/vm/naive_instruction_status_querier.h" #include "oneflow/core/profiler/profiler.h" #include "oneflow/core/kernel/kernel_util.h" -#include "oneflow/core/ep/include/active_device_guard.h" namespace oneflow { namespace { -class EpBasedEventRecord : public EventRecord { - public: - OF_DISALLOW_COPY(EpBasedEventRecord); - EpBasedEventRecord(ep::Event* event, ep::Device* device) : event_(event), device_(device) {} - ~EpBasedEventRecord() { - ep::ActiveDeviceGuard guard(device_); - device_->DestroyEvent(event_); - }; - - bool QueryDone() const override { - ep::ActiveDeviceGuard guard(device_); - bool done = CHECK_JUST(event_->QueryDone()); - return done; - } - - private: - ep::Event* event_; - ep::Device* device_; -}; - -std::shared_ptr MakeEventRecord(ep::Stream* stream) { - ep::Device* device = stream->device(); - ep::ActiveDeviceGuard guard(device); - ep::Event* event = device->CreateEvent(); - stream->RecordEvent(event); - return std::make_shared(event, device); -} - class LazyJobInstance final : public JobInstance { public: LazyJobInstance(const LazyJobInstance&) = delete; LazyJobInstance(LazyJobInstance&&) = delete; ~LazyJobInstance() override = default; - LazyJobInstance(const std::string& job_name, - const HashMap>& push_cbs, - const HashMap>& pull_cbs, - const std::function finish_cb) - : job_name_(job_name), push_cbs_(push_cbs), pull_cbs_(pull_cbs), finish_cb_(finish_cb) {} + LazyJobInstance(const std::string& job_name, const std::function& finish_cb) + : job_name_(job_name), finish_cb_(finish_cb) {} std::string job_name() const override { return job_name_; } - void PushBlobByOpName(uint64_t ofblob_ptr, const std::string& op_name) const override { - const auto& push_cb = CHECK_JUST(MapAt(push_cbs_, op_name)); - return push_cb(ofblob_ptr); - } - void PullBlobByOpName(uint64_t ofblob_ptr, const std::string& op_name) const override { - const auto& pull_cb = CHECK_JUST(MapAt(pull_cbs_, op_name)); - return pull_cb(ofblob_ptr); - } void Finish() const override { finish_cb_(); } std::string sole_input_op_name_in_user_job() const override { @@ -99,8 +59,6 @@ class LazyJobInstance final : public JobInstance { private: const std::string job_name_; - const HashMap> push_cbs_; - const HashMap> pull_cbs_; const std::function finish_cb_; }; @@ -117,8 +75,6 @@ class LaunchLazyJobInstructionType final : public InstructionType { // NOLINT using stream_type = LazyJobStreamType; void Infer(vm::Instruction* instruction) const override { UNIMPLEMENTED(); } void Compute(vm::Instruction* instruction) const override { - const auto* ptr = instruction->instr_msg().phy_instr_operand().get(); - const auto* phy_instr_operand = dynamic_cast(ptr); const auto& cur_nn_graph = GetCurNNGraph(instruction); auto* device_ctx = GetLazyJobDeviceCtx(instruction); @@ -133,18 +89,6 @@ class LaunchLazyJobInstructionType final : public InstructionType { // NOLINT OF_PROFILER_RANGE_PUSH("Send all buffers to BufferMgr"); const auto& job_name = job_instance->job_name(); auto* buffer_mgr = Global>>::Get(); - for (int i = 0; i < cur_nn_graph->inputs_op_names().size(); ++i) { - if (cur_nn_graph->inputs_valid().at(i)) { - const std::string& input_op_name = cur_nn_graph->inputs_op_names().at(i); - 
buffer_mgr->Get(GetInputBufferName(job_name, input_op_name))->Push(job_instance); - } - } - for (int i = 0; i < cur_nn_graph->outputs_op_names().size(); ++i) { - if (cur_nn_graph->outputs_valid().at(i)) { - const std::string& output_op_name = cur_nn_graph->outputs_op_names().at(i); - buffer_mgr->Get(GetOutputBufferName(job_name, output_op_name))->Push(job_instance); - } - } buffer_mgr->Get(GetCallbackNotifierBufferName(job_name))->Push(job_instance); buffer_mgr->Get(GetSourceTickBufferName(job_name))->Push(job_instance); OF_PROFILER_RANGE_POP(); // BufferMgr @@ -174,67 +118,13 @@ class LaunchLazyJobInstructionType final : public InstructionType { // NOLINT const auto* phy_instr_operand = dynamic_cast(ptr); CHECK_NOTNULL(phy_instr_operand); const auto& nn_graph = phy_instr_operand->nn_graph(); - HashMap> push_cbs; - for (int i = 0; i < nn_graph->inputs_op_names().size(); ++i) { - const auto& input_op_name = nn_graph->inputs_op_names().at(i); - const auto& end_event_record = - CHECK_JUST(phy_instr_operand->EndEventRecord4OpName(input_op_name)); - if (nn_graph->inputs_valid().at(i)) { - const auto& input_blob_object = phy_instr_operand->input_blob_objects()->at(i); - const auto& PushCb = [input_op_name, end_event_record, - input_blob_object](int64_t of_blob_ptr) { - OfBlob* of_blob = reinterpret_cast(of_blob_ptr); - const Blob* blob = &input_blob_object->blob(); - CHECK_NOTNULL(blob); - of_blob->mut_blob()->CopyHeaderFrom(blob); - if (blob->dptr() == nullptr) { - end_event_record->Init(std::make_shared()); - } else { - AutoMemcpy(of_blob->stream(), of_blob->mut_blob(), blob); - end_event_record->Init(MakeEventRecord(of_blob->stream())); - } - }; - CHECK(push_cbs.emplace(input_op_name, PushCb).second); - } else { - end_event_record->Init(std::make_shared()); - } - } - HashMap> pull_cbs; - for (int i = 0; i < nn_graph->outputs_op_names().size(); ++i) { - const auto& output_op_name = nn_graph->outputs_op_names().at(i); - const auto& end_event_record = - CHECK_JUST(phy_instr_operand->EndEventRecord4OpName(output_op_name)); - if (nn_graph->outputs_valid().at(i)) { - const auto& output_blob_object = phy_instr_operand->output_blob_objects()->at(i); - const auto& PullCb = [output_op_name, end_event_record, - output_blob_object](int64_t of_blob_ptr) { - OfBlob* of_blob = reinterpret_cast(of_blob_ptr); - Blob* mut_blob = output_blob_object->mut_blob(); - CHECK_NOTNULL(mut_blob); - mut_blob->CopyHeaderFrom(&of_blob->blob()); - if (mut_blob->dptr() == nullptr) { - end_event_record->Init(std::make_shared()); - } else { - AutoMemcpy(of_blob->stream(), mut_blob, &of_blob->blob()); - end_event_record->Init(MakeEventRecord(of_blob->stream())); - } - }; - CHECK(pull_cbs.emplace(output_op_name, PullCb).second); - } else { - end_event_record->Init(std::make_shared()); - } - } - const auto& op_name2end_event_record = phy_instr_operand->op_name2end_event_record(); - const auto& FinishCb = [this, instruction, op_name2end_event_record]() { - for (const auto& pair : *op_name2end_event_record) { - pair.second->TryInit(std::make_shared()); - } + const auto& FinishCb = [this, instruction]() { auto* device_ctx = GetLazyJobDeviceCtx(instruction); device_ctx->DequeueNNGraph(); auto* status_buffer = instruction->mut_status_buffer(); NaiveInstrStatusQuerier::MutCast(status_buffer->mut_buffer()->mut_data())->set_done(); }; - return std::make_shared(nn_graph->job_name(), push_cbs, pull_cbs, FinishCb); + return std::make_shared(nn_graph->job_name(), FinishCb); } }; diff --git a/oneflow/core/eager/lazy_job_phy_instr_operand.cpp 
b/oneflow/core/eager/lazy_job_phy_instr_operand.cpp index f09dea27d91..5d8c0b434e3 100644 --- a/oneflow/core/eager/lazy_job_phy_instr_operand.cpp +++ b/oneflow/core/eager/lazy_job_phy_instr_operand.cpp @@ -43,9 +43,6 @@ static constexpr auto* GetEagerNcclLocalDepObject = void LaunchLazyJobPhyInstrOperand::ForEachMutMirroredObject( const std::function& DoEach) const { - DoEach(inputs_local_dep_object_->mut_mirrored_object()); - DoEach(outputs_local_dep_object_->mut_mirrored_object()); - for (const auto& eager_blob_object : *param_blob_objects_) { DoEach(CHECK_JUST(eager_blob_object->compute_local_dep_object())->mut_mirrored_object()); } @@ -58,20 +55,5 @@ void LaunchLazyJobPhyInstrOperand::ForEachMutMirroredObject( #endif // WITH_CUDA } -void LaunchLazyJobPhyInstrOperand::ForEachConstMirroredObject( - const std::function& DoEach) const { - DoEach(inputs_local_dep_object_->mut_mirrored_object()); -} - -void LaunchLazyJobPhyInstrOperand::ForEachMut2MirroredObject( - const std::function& DoEach) const { - DoEach(outputs_local_dep_object_->mut_mirrored_object()); -} - -Maybe LaunchLazyJobPhyInstrOperand::EndEventRecord4OpName( - const std::string& op_name) const { - return JUST(MapAt(*op_name2end_event_record_, op_name)); -} - } // namespace vm } // namespace oneflow diff --git a/oneflow/core/eager/lazy_job_phy_instr_operand.h b/oneflow/core/eager/lazy_job_phy_instr_operand.h index 128d2324ee4..03b64c655cd 100644 --- a/oneflow/core/eager/lazy_job_phy_instr_operand.h +++ b/oneflow/core/eager/lazy_job_phy_instr_operand.h @@ -41,22 +41,10 @@ class LaunchLazyJobPhyInstrOperand final : public PhyInstrOperand { LaunchLazyJobPhyInstrOperand(LaunchLazyJobPhyInstrOperand&&) = delete; ~LaunchLazyJobPhyInstrOperand() override = default; - LaunchLazyJobPhyInstrOperand( - const intrusive::shared_ptr& inputs_local_dep_object, - const intrusive::shared_ptr& outputs_local_dep_object, - const std::shared_ptr>>& - op_name2end_event_record, - const one::EagerBlobObjectListPtr& input_blob_objects, - const one::EagerBlobObjectListPtr& output_blob_objects, - const one::EagerBlobObjectListPtr& param_blob_objects, - const std::shared_ptr& nn_graph) - : inputs_local_dep_object_(inputs_local_dep_object), - outputs_local_dep_object_(outputs_local_dep_object), - op_name2end_event_record_(op_name2end_event_record), - input_blob_objects_(input_blob_objects), - output_blob_objects_(output_blob_objects), + LaunchLazyJobPhyInstrOperand(const std::shared_ptr& nn_graph, + const one::EagerBlobObjectListPtr& param_blob_objects) + : nn_graph_(nn_graph), param_blob_objects_(param_blob_objects), - nn_graph_(nn_graph), input_dependences_(), output_dependences_() { ForEachConstMirroredObject(SetInserter(&input_dependences_)); @@ -64,34 +52,20 @@ class LaunchLazyJobPhyInstrOperand final : public PhyInstrOperand { ForEachMut2MirroredObject(SetInserter(&output_dependences_)); } - const one::EagerBlobObjectListPtr& input_blob_objects() const { return input_blob_objects_; } - const one::EagerBlobObjectListPtr& output_blob_objects() const { return output_blob_objects_; } const std::shared_ptr& nn_graph() const { return nn_graph_; } - Maybe EndEventRecord4OpName(const std::string& op_name) const; - const std::shared_ptr>>& - op_name2end_event_record() const { - return op_name2end_event_record_; - } - const DependenceVector& input_dependences() const override { return input_dependences_; } const DependenceVector& output_dependences() const override { return output_dependences_; } - void ForEachConstMirroredObject(const std::function&) const; + 
void ForEachConstMirroredObject(const std::function&) const {} void ForEachMutMirroredObject(const std::function&) const; - void ForEachMut2MirroredObject(const std::function&) const; + void ForEachMut2MirroredObject(const std::function&) const {} private: - mutable intrusive::shared_ptr inputs_local_dep_object_; - mutable intrusive::shared_ptr outputs_local_dep_object_; - std::shared_ptr>> - op_name2end_event_record_; - one::EagerBlobObjectListPtr input_blob_objects_; - one::EagerBlobObjectListPtr output_blob_objects_; - one::EagerBlobObjectListPtr param_blob_objects_; std::shared_ptr nn_graph_; + one::EagerBlobObjectListPtr param_blob_objects_; DependenceVector input_dependences_; DependenceVector output_dependences_; }; diff --git a/oneflow/core/eager/lazy_job_stream_type.cpp b/oneflow/core/eager/lazy_job_stream_type.cpp index da2d3bf7f45..2b03e5286a6 100644 --- a/oneflow/core/eager/lazy_job_stream_type.cpp +++ b/oneflow/core/eager/lazy_job_stream_type.cpp @@ -59,7 +59,6 @@ intrusive::shared_ptr LazyJobStreamType::MakeStreamDesc(const Resour int64_t this_machine_id) const { auto ret = intrusive::make_shared(); ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex()); - ret->set_num_machines(1); ret->set_num_streams_per_machine(1); ret->set_num_streams_per_thread(1); return ret; diff --git a/oneflow/core/eager/opkernel_instruction_type.cpp b/oneflow/core/eager/opkernel_instruction_type.cpp index 13e6a6ffebc..25789a7a32c 100644 --- a/oneflow/core/eager/opkernel_instruction_type.cpp +++ b/oneflow/core/eager/opkernel_instruction_type.cpp @@ -454,9 +454,10 @@ struct LocalCallOpKernelUtil final { JUST(ResetTempStorageBlob(operand)); JUST(TryAllocateTempStorageBlobMemory(operand, device_ctx)); } - user_op::OpKernelState* state; - TryInitOpKernelState(operand, device_ctx, &state); - OpKernelCompute(operand, device_ctx, state); + user_op::OpKernelState* state = nullptr; + user_op::OpKernelCache* cache = nullptr; + TryInitOpKernelStateAndCache(operand, device_ctx, &state, &cache); + OpKernelCompute(operand, device_ctx, state, cache); if (unlikely(operand->need_temp_storage())) { JUST(DeallocateTempStorageBlobMemory(operand, device_ctx)); } @@ -488,15 +489,19 @@ struct LocalCallOpKernelUtil final { return operand->mut_opkernel()->mut_temp_blob_object()->InitBlob(); } - static inline void TryInitOpKernelState(LocalCallOpKernelPhyInstrOperand* operand, - DeviceCtx* device_ctx, user_op::OpKernelState** state) { + static inline void TryInitOpKernelStateAndCache(LocalCallOpKernelPhyInstrOperand* operand, + DeviceCtx* device_ctx, + user_op::OpKernelState** state, + user_op::OpKernelCache** cache) { if (likely(operand->op_interp_ctx().state)) { *state = operand->op_interp_ctx().state.get(); - return; + // set state to nullptr so that state initialization in TryInitOpKernelStateAndCache will be + // skipped. 
+ state = nullptr; } - operand->mut_opkernel()->TryInitOpKernelState( + operand->mut_opkernel()->TryInitOpKernelStateAndCache( operand->user_opkernel(), device_ctx, operand->inputs().get(), operand->outputs().get(), - operand->consistent_tensor_infer_result().get(), state); + operand->consistent_tensor_infer_result().get(), state, cache); } static inline Maybe AllocateOutputBlobsMemory(LocalCallOpKernelPhyInstrOperand* operand, @@ -515,12 +520,13 @@ struct LocalCallOpKernelUtil final { } static inline void OpKernelCompute(LocalCallOpKernelPhyInstrOperand* operand, - DeviceCtx* device_ctx, user_op::OpKernelState* state) { + DeviceCtx* device_ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache* cache) { auto* opkernel = operand->mut_opkernel(); auto* compute_ctx = opkernel->UpdateComputeContext(operand->inputs().get(), operand->outputs().get(), operand->consistent_tensor_infer_result().get(), device_ctx); - operand->user_opkernel()->Compute(compute_ctx, state); + operand->user_opkernel()->Compute(compute_ctx, state, cache); // tensor tuples are not allowed to be hold by StatefulLocalOpKernel opkernel->UpdateComputeContext(nullptr, nullptr, nullptr, nullptr); } diff --git a/oneflow/core/eager/opkernel_instruction_type_test.cpp b/oneflow/core/eager/opkernel_instruction_type_test.cpp index ba85b1358cc..20447f684ce 100644 --- a/oneflow/core/eager/opkernel_instruction_type_test.cpp +++ b/oneflow/core/eager/opkernel_instruction_type_test.cpp @@ -210,7 +210,6 @@ TEST(OpkernelInstructionType, call_opkernel) { TEST(OpkernelInstructionType, consecutive_opkernel_calls) { vm::TestResourceDescScope resource_scope(1, 1); InstructionMsgList list; - int64_t in_id = vm::TestUtil::NewStringSymbol(&list, "in_0"); int64_t out_id = vm::TestUtil::NewStringSymbol(&list, "out_0"); int64_t tmp_buffer_id = vm::TestUtil::NewStringSymbol(&list, "tmp_buffer_0"); int64_t test_source_id = 0; @@ -245,28 +244,30 @@ TEST(OpkernelInstructionType, consecutive_opkernel_calls) { op_conf->set_name("ccrelu_op_name"); auto* user_conf = op_conf->mutable_user_conf(); user_conf->set_op_type_name("ccrelu"); - (*user_conf->mutable_input())["in"].add_s("ccrelu_op_name/in_0"); - (*user_conf->mutable_output())["out"].add_s("ccrelu_op_name/out_0"); + (*user_conf->mutable_input())["x"].add_s("ccrelu_op_name/x_0"); + (*user_conf->mutable_output())["y"].add_s("ccrelu_op_name/y_0"); ccrelu_id = InitOpKernelObject(&list, std::make_shared(), op_conf, "gpu"); } int64_t y = 0; + int64_t x_id = vm::TestUtil::NewStringSymbol(&list, "x_0"); + int64_t y_id = vm::TestUtil::NewStringSymbol(&list, "y_0"); { int64_t y_parallel_desc_id = 0; y = vm::TestUtil::NewObject(&list, "gpu", "0:0", &y_parallel_desc_id); int64_t tmp_buffer = vm::TestUtil::NewObject(&list, "gpu", "0:0", &y_parallel_desc_id); int64_t op_node_signature_id = - NewOpNodeSignature(&list, {"in_0"}, {x_parallel_desc_id}, {"out_0", "tmp_buffer_0"}, + NewOpNodeSignature(&list, {"x_0"}, {x_parallel_desc_id}, {"y_0", "tmp_buffer_0"}, {y_parallel_desc_id, y_parallel_desc_id}); list.EmplaceBack(vm::NewInstruction("gpu.CallOpKernel") ->add_parallel_desc(y_parallel_desc_id) ->add_mut_operand(ccrelu_id) ->add_symbol_operand(op_node_signature_id) ->add_separator() - ->add_symbol_operand(in_id) + ->add_symbol_operand(x_id) ->add_const_operand(x) ->add_separator() ->add_separator() - ->add_symbol_operand(out_id) + ->add_symbol_operand(y_id) ->add_symbol_operand(tmp_buffer_id) ->add_mut_operand(y) ->add_mut_operand(tmp_buffer) @@ -365,22 +366,23 @@ TEST(OpkernelInstructionType, 
consecutive_stateless_call_opkernel) { ->add_symbol_operand(out_id) ->add_mut_operand(x) ->add_separator()); - int64_t in_id = vm::TestUtil::NewStringSymbol(&list, "in_0"); + int64_t x_id = vm::TestUtil::NewStringSymbol(&list, "x_0"); + int64_t y_id = vm::TestUtil::NewStringSymbol(&list, "y_0"); int64_t ccrelu_id = 0; { auto op_conf = std::make_shared(); op_conf->set_name("ccrelu_op_name"); auto* user_conf = op_conf->mutable_user_conf(); user_conf->set_op_type_name("ccrelu"); - (*user_conf->mutable_input())["in"].add_s("ccrelu_op_name/in_0"); - (*user_conf->mutable_output())["out"].add_s("ccrelu_op_name/out_0"); + (*user_conf->mutable_input())["x"].add_s("ccrelu_op_name/x_0"); + (*user_conf->mutable_output())["y"].add_s("ccrelu_op_name/y_0"); ccrelu_id = NewOpConfSymbol(&list, op_conf); } int64_t y_parallel_desc_id = 0; int64_t y = vm::TestUtil::NewObject(&list, "gpu", "0:0", &y_parallel_desc_id); int64_t tmp_buffer = vm::TestUtil::NewObject(&list, "gpu", "0:0", &y_parallel_desc_id); op_node_signature_id = - NewOpNodeSignature(&list, {"in_0"}, {parallel_desc_id}, {"out_0", "tmp_buffer_0"}, + NewOpNodeSignature(&list, {"x_0"}, {parallel_desc_id}, {"y_0", "tmp_buffer_0"}, {y_parallel_desc_id, y_parallel_desc_id}); list.EmplaceBack(vm::NewInstruction("gpu.compute.UserStatelessCallOpKernel") ->add_parallel_desc(y_parallel_desc_id) @@ -389,11 +391,11 @@ TEST(OpkernelInstructionType, consecutive_stateless_call_opkernel) { ->add_symbol_operand(op_node_signature_id) ->add_mut_operand(opkernel_id) ->add_separator() - ->add_symbol_operand(in_id) + ->add_symbol_operand(x_id) ->add_const_operand(x) ->add_separator() ->add_separator() - ->add_symbol_operand(out_id) + ->add_symbol_operand(y_id) ->add_symbol_operand(tmp_buffer_id) ->add_mut_operand(y) ->add_mut_operand(tmp_buffer) diff --git a/oneflow/core/ep/common/active_device_guard.cpp b/oneflow/core/ep/common/active_device_guard.cpp index b9088059314..640d345e3a7 100644 --- a/oneflow/core/ep/common/active_device_guard.cpp +++ b/oneflow/core/ep/common/active_device_guard.cpp @@ -20,10 +20,7 @@ namespace oneflow { namespace ep { -ActiveDeviceGuard::ActiveDeviceGuard(Device* device) { - device_manager_ = // NOLINT - Global::Get()->GetDeviceManager(device->device_type()); // NOLINT - CHECK_NOTNULL(device_manager_); +ActiveDeviceGuard::ActiveDeviceGuard(Device* device) : device_manager_(device->device_manager()) { saved_active_device_ = device_manager_->GetActiveDeviceIndex(); device->SetAsActiveDevice(); } diff --git a/oneflow/core/ep/common/device_manager_registry.cpp b/oneflow/core/ep/common/device_manager_registry.cpp index 4f70e33bfc1..4445e907e56 100644 --- a/oneflow/core/ep/common/device_manager_registry.cpp +++ b/oneflow/core/ep/common/device_manager_registry.cpp @@ -24,7 +24,9 @@ namespace ep { class DeviceManagerRegistry::Impl { public: OF_DISALLOW_COPY_AND_MOVE(Impl); - Impl() { managers_.resize(DeviceType_ARRAYSIZE); } + explicit Impl(DeviceManagerRegistry* registry) : registry_(registry) { + managers_.resize(DeviceType_ARRAYSIZE); + } ~Impl() = default; DeviceManager* GetDeviceManager(DeviceType device_type) { @@ -33,7 +35,7 @@ class DeviceManagerRegistry::Impl { std::lock_guard factories_lock(factories_mutex_); auto& factory = factories_.at(device_type); CHECK(factory); - managers_.at(device_type) = factory->NewDeviceManager(); + managers_.at(device_type) = factory->NewDeviceManager(registry_); } return managers_.at(device_type).get(); } @@ -88,13 +90,14 @@ class DeviceManagerRegistry::Impl { static std::vector> factories_; static HashMap 
device_type_name2device_type_; static std::mutex factories_mutex_; + DeviceManagerRegistry* registry_; }; std::vector> DeviceManagerRegistry::Impl::factories_; HashMap DeviceManagerRegistry::Impl::device_type_name2device_type_; std::mutex DeviceManagerRegistry::Impl::factories_mutex_; -DeviceManagerRegistry::DeviceManagerRegistry() { impl_.reset(new Impl()); } +DeviceManagerRegistry::DeviceManagerRegistry() { impl_.reset(new Impl(this)); } DeviceManagerRegistry::~DeviceManagerRegistry() = default; diff --git a/oneflow/core/ep/common/primitive/binary_functor.h b/oneflow/core/ep/common/primitive/binary_functor.h new file mode 100644 index 00000000000..700239ac25c --- /dev/null +++ b/oneflow/core/ep/common/primitive/binary_functor.h @@ -0,0 +1,118 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_PRIMITIVE_COMMON_BINARY_FUNCTOR_H_ +#define ONEFLOW_CORE_PRIMITIVE_COMMON_BINARY_FUNCTOR_H_ + +#include "oneflow/core/ep/include/primitive/binary_op.h" +#include "oneflow/core/common/data_type.h" + +namespace oneflow { + +namespace ep { +namespace primitive { +namespace broadcast_elementwise_binary { + +template +struct BinaryFunctor; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 + src1); } +}; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 - src1); } +}; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 * src1); } +}; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 / src1); } +}; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { + return static_cast(src0 > src1 ? src0 : src1); + } +}; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { + return static_cast(src0 < src1 ? 
src0 : src1); + } +}; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 == src1); } +}; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 != src1); } +}; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 < src1); } +}; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 <= src1); } +}; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 > src1); } +}; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 >= src1); } +}; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 && src1); } +}; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return static_cast(src0 || src1); } +}; + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { + return static_cast(src0) != static_cast(src1); + } +}; + +} // namespace broadcast_elementwise_binary +} // namespace primitive +} // namespace ep + +} // namespace oneflow + +#endif // ONEFLOW_CORE_PRIMITIVE_COMMON_BINARY_FUNCTOR_H_ diff --git a/oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h b/oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h new file mode 100644 index 00000000000..9182b675cdf --- /dev/null +++ b/oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h @@ -0,0 +1,79 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#ifndef ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_BINARY +#define ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_BINARY + +#include "oneflow/core/ep/include/primitive/primitive.h" +#include "oneflow/core/ep/include/primitive/binary_op.h" +#include "oneflow/core/common/nd_index_offset_helper.h" +#include "oneflow/core/ep/common/primitive/util.h" + +namespace oneflow { + +namespace ep { +namespace primitive { + +namespace broadcast_elementwise_binary { + +constexpr size_t kMaxNumDims = 8; + +inline void CheckInplace(size_t num_dims, const int64_t* src0_dims, const void* src0, + const int64_t* src1_dims, const void* src1, const int64_t* dst_dims, + const void* dst) { + for (int64_t i = 0; i < num_dims; ++i) { + if (src0 == dst) { CHECK_EQ(src0_dims[i], dst_dims[i]); } + if (src1 == dst) { CHECK_EQ(src1_dims[i], dst_dims[i]); } + } +} + +inline bool IsDimsEquals(size_t num_src0_dims, const int64_t* src0_dims, size_t num_src1_dims, + const int64_t* src1_dims) { + if (num_src0_dims != num_src1_dims) { return false; } + for (size_t i = 0; i < num_src1_dims; ++i) { + if (src0_dims[i] != src1_dims[i]) { return false; } + } + return true; +} + +#define BINARY_MATH_OP_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMax) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMin) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kPow) + +#define BINARY_COMPARISION_OP_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kNotEqual) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessThan) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessEqual) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterThan) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterEqual) + +#define BINARY_LOGICAL_OP_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalAnd) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalOr) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalXor) + +} // namespace broadcast_elementwise_binary +} // namespace primitive +} // namespace ep + +} // namespace oneflow + +#endif // ONEFLOW_CORE_PRIMITIVE_COMMON_BROADCAST_ELEMENTWISE_BINARY diff --git a/oneflow/core/ep/common/primitive/broadcast_simplify_dims_test.cpp b/oneflow/core/ep/common/primitive/broadcast_simplify_dims_test.cpp new file mode 100644 index 00000000000..f25802163d3 --- /dev/null +++ b/oneflow/core/ep/common/primitive/broadcast_simplify_dims_test.cpp @@ -0,0 +1,91 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/ep/common/primitive/util.h" +#include + +namespace oneflow { + +namespace ep { +namespace primitive { + +namespace { + +template +void TestSimplifyBroadcastDims(size_t num_src0_dims, const int64_t* src0_dims, size_t num_src1_dims, + const int64_t* src1_dims, size_t expected_num_dims, + const int64_t* expected_src0_dims, const int64_t* expected_src1_dims, + const int64_t* expected_dst_dims) { + size_t simplified_num_dims = 0; + int64_t simplified_src0_dims[max_num_dims]{}; + int64_t simplified_src1_dims[max_num_dims]{}; + int64_t simplified_dst_dims[max_num_dims]{}; + SimplifyBroadcastDims(num_src0_dims, src0_dims, num_src1_dims, src1_dims, + &simplified_num_dims, simplified_src0_dims, + simplified_src1_dims, simplified_dst_dims); + ASSERT_EQ(simplified_num_dims, expected_num_dims); + for (size_t i = 0; i < simplified_num_dims; ++i) { + ASSERT_EQ(simplified_src0_dims[i], expected_src0_dims[i]); + ASSERT_EQ(simplified_src1_dims[i], expected_src1_dims[i]); + ASSERT_EQ(simplified_dst_dims[i], expected_dst_dims[i]); + } +} + +TEST(Broadcast, SimplifyBroadcastDims) { + constexpr size_t max_num_dims = 8; + + const size_t num_src0_dims_1 = 4; + const size_t num_src1_dims_1 = 5; + int64_t src0_dims_1[max_num_dims]{2, 5, 10, 5}; + int64_t src1_dims_1[max_num_dims]{5, 1, 5, 10, 1}; + const size_t simplified_num_dims_1 = 4; + int64_t simplified_src0_dims_1[max_num_dims]{1, 2, 50, 5}; + int64_t simplified_src1_dims_1[max_num_dims]{5, 1, 50, 1}; + int64_t simplified_dst_dims_1[max_num_dims]{5, 2, 50, 5}; + TestSimplifyBroadcastDims( + num_src0_dims_1, src0_dims_1, num_src1_dims_1, src1_dims_1, simplified_num_dims_1, + simplified_src0_dims_1, simplified_src1_dims_1, simplified_dst_dims_1); + + const size_t num_src0_dims_2 = 4; + const size_t num_src1_dims_2 = 1; + int64_t src0_dims_2[max_num_dims]{10, 5, 1, 5}; + int64_t src1_dims_2[max_num_dims]{5}; + const size_t simplified_num_dims_2 = 2; + int64_t simplified_src0_dims_2[max_num_dims]{50, 5}; + int64_t simplified_src1_dims_2[max_num_dims]{1, 5}; + int64_t simplified_dst_dims_2[max_num_dims]{50, 5}; + TestSimplifyBroadcastDims( + num_src0_dims_2, src0_dims_2, num_src1_dims_2, src1_dims_2, simplified_num_dims_2, + simplified_src0_dims_2, simplified_src1_dims_2, simplified_dst_dims_2); + + const size_t num_src0_dims_3 = 4; + const size_t num_src1_dims_3 = 1; + int64_t src0_dims_3[max_num_dims]{2, 5, 10, 5}; + int64_t src1_dims_3[max_num_dims]{1}; + const size_t simplified_num_dims_3 = 1; + int64_t simplified_src0_dims_3[max_num_dims]{500}; + int64_t simplified_src1_dims_3[max_num_dims]{1}; + int64_t simplified_dst_dims_3[max_num_dims]{500}; + TestSimplifyBroadcastDims( + num_src0_dims_3, src0_dims_3, num_src1_dims_3, src1_dims_3, simplified_num_dims_3, + simplified_src0_dims_3, simplified_src1_dims_3, simplified_dst_dims_3); +} + +} // namespace + +} // namespace primitive +} // namespace ep + +} // namespace oneflow diff --git a/oneflow/core/ep/common/primitive/util.h b/oneflow/core/ep/common/primitive/util.h new file mode 100644 index 00000000000..515f002750d --- /dev/null +++ b/oneflow/core/ep/common/primitive/util.h @@ -0,0 +1,129 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UTIL_H_ +#define ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UTIL_H_ + +#include "oneflow/core/common/util.h" + +namespace oneflow { + +namespace ep { +namespace primitive { + +inline size_t GetElementCount(size_t num_dims, const int64_t* dims) { + size_t count = 1; + for (size_t i = 0; i < num_dims; ++i) { count *= dims[i]; } + return count; +} + +template +bool IsPackSizeSupported(const size_t pack_size, size_t num_dims, const int64_t* dims, + const void* ptr) { + return (dims[num_dims - 1] % pack_size == 0) + && (reinterpret_cast(ptr) % (pack_size * sizeof(T)) == 0); +} + +inline void SimplifyBroadcastDims(size_t num_a_dims, const int64_t* a_dims, size_t num_b_dims, + const int64_t* b_dims, size_t num_c_dims, const int64_t* c_dims, + size_t* simplified_num_dims, int64_t* simplified_broadcast_dims, + int64_t* simplified_a_dims, int64_t* simplified_b_dims, + int64_t* simplified_c_dims) { + const size_t num_max_dims = std::max(num_a_dims, num_b_dims); + auto MakeGetDim = [num_max_dims](size_t num_dims, const int64_t* dims) { + const int64_t num_padding_dims = num_max_dims - num_dims; + return [num_padding_dims, dims](size_t index) { + return index < num_padding_dims ? 1 : dims[index - num_padding_dims]; + }; + }; + auto GetADim = MakeGetDim(num_a_dims, a_dims); + auto GetBDim = MakeGetDim(num_b_dims, b_dims); + auto GetCDim = MakeGetDim(num_c_dims, c_dims); + *simplified_num_dims = 0; + bool prev_broadcast_a = false; + bool prev_broadcast_b = false; + bool prev_broadcast_c = false; + for (int64_t i = 0; i < num_max_dims; ++i) { + const int64_t a_dim = GetADim(i); + const int64_t b_dim = GetBDim(i); + const int64_t c_dim = GetCDim(i); + const int64_t broadcast_dim = std::max(std::max(a_dim, b_dim), c_dim); + CHECK_GT(broadcast_dim, 0); + const bool broadcast_a = (a_dim == 1); + const bool broadcast_b = (b_dim == 1); + const bool broadcast_c = (c_dim == 1); + CHECK((a_dim == broadcast_dim) || broadcast_a); + CHECK((b_dim == broadcast_dim) || broadcast_b); + CHECK((c_dim == broadcast_dim) || broadcast_c); + if (broadcast_dim == 1) { + continue; + } else if (*simplified_num_dims != 0 + && (prev_broadcast_a == broadcast_a && prev_broadcast_b == broadcast_b + && prev_broadcast_c == broadcast_c)) { + simplified_a_dims[*simplified_num_dims - 1] *= a_dim; + simplified_b_dims[*simplified_num_dims - 1] *= b_dim; + simplified_c_dims[*simplified_num_dims - 1] *= c_dim; + simplified_broadcast_dims[*simplified_num_dims - 1] *= broadcast_dim; + } else { + simplified_a_dims[*simplified_num_dims] = a_dim; + simplified_b_dims[*simplified_num_dims] = b_dim; + simplified_c_dims[*simplified_num_dims] = c_dim; + simplified_broadcast_dims[*simplified_num_dims] = broadcast_dim; + *simplified_num_dims += 1; + prev_broadcast_a = broadcast_a; + prev_broadcast_b = broadcast_b; + prev_broadcast_c = broadcast_c; + } + } +} + +template +inline void SimplifyBroadcastDims(size_t num_src0_dims, const int64_t* src0_dims, + size_t num_src1_dims, const int64_t* src1_dims, + size_t* simplified_num_dims, int64_t* simplified_src0_dims, + int64_t* simplified_src1_dims, int64_t* simplified_dst_dims) { + 
size_t src0_count = GetElementCount(num_src0_dims, src0_dims); + size_t src1_count = GetElementCount(num_src1_dims, src1_dims); + if (src0_count == 1 || src1_count == 1) { + *simplified_num_dims = 1; + simplified_src0_dims[0] = src0_count; + simplified_src1_dims[0] = src1_count; + simplified_dst_dims[0] = std::max(src0_count, src1_count); + return; + } + int64_t dst_dims[max_num_dims]; + int64_t broadcast_dims[max_num_dims]; + const size_t num_dst_dims = std::max(num_src0_dims, num_src1_dims); + for (int64_t i = 0; i < num_dst_dims; ++i) { + const int64_t num_src0_padding_dims = num_dst_dims - num_src0_dims; + const int64_t num_src1_padding_dims = num_dst_dims - num_src1_dims; + size_t src0_dim = i < num_src0_padding_dims ? 1 : src0_dims[i - num_src0_padding_dims]; + size_t src1_dim = i < num_src1_padding_dims ? 1 : src1_dims[i - num_src1_padding_dims]; + dst_dims[i] = std::max(src0_dim, src1_dim); + } + SimplifyBroadcastDims(num_src0_dims, src0_dims, num_src1_dims, src1_dims, num_dst_dims, dst_dims, + simplified_num_dims, broadcast_dims, simplified_src0_dims, + simplified_src1_dims, simplified_dst_dims); + for (int64_t i = 0; i < *simplified_num_dims; ++i) { + CHECK_EQ(broadcast_dims[i], simplified_dst_dims[i]); + } +} + +} // namespace primitive +} // namespace ep + +} // namespace oneflow + +#endif // ONEFLOW_CORE_EP_COMMON_PRIMITIVE_UTIL_H_ diff --git a/oneflow/core/ep/cpu/cpu_device.cpp b/oneflow/core/ep/cpu/cpu_device.cpp index 83f36ba90e7..5a3659417ee 100644 --- a/oneflow/core/ep/cpu/cpu_device.cpp +++ b/oneflow/core/ep/cpu/cpu_device.cpp @@ -38,8 +38,9 @@ void CpuDevice::DestroyEvents(Event** events, size_t count) { Maybe CpuDevice::Alloc(const AllocationOptions& options, void** ptr, size_t size) { if (options.HasPinnedDevice()) { - auto device = Global::Get()->GetDevice( - options.GetPinnedDeviceType(), options.GetPinnedDeviceIndex()); + auto device = + this->device_manager()->registry()->GetDevice(options.GetPinnedDeviceType(), // NOLINT + options.GetPinnedDeviceIndex()); // NOLINT CHECK_OR_RETURN(device); return device->AllocPinned(options, ptr, size); } else { @@ -54,8 +55,9 @@ Maybe CpuDevice::Alloc(const AllocationOptions& options, void** ptr, size_ void CpuDevice::Free(const AllocationOptions& options, void* ptr) { if (options.HasPinnedDevice()) { - auto device = Global::Get()->GetDevice( - options.GetPinnedDeviceType(), options.GetPinnedDeviceIndex()); + auto device = + this->device_manager()->registry()->GetDevice(options.GetPinnedDeviceType(), // NOLINT + options.GetPinnedDeviceIndex()); // NOLINT CHECK(device); return device->FreePinned(options, ptr); } else { diff --git a/oneflow/core/ep/cpu/cpu_device.h b/oneflow/core/ep/cpu/cpu_device.h index f1f3f7eb0cd..093889e592e 100644 --- a/oneflow/core/ep/cpu/cpu_device.h +++ b/oneflow/core/ep/cpu/cpu_device.h @@ -25,13 +25,14 @@ namespace ep { class CpuDevice : public Device { public: OF_DISALLOW_COPY_AND_MOVE(CpuDevice); - CpuDevice() = default; - virtual ~CpuDevice() = default; + explicit CpuDevice(DeviceManager* device_manager) : device_manager_(device_manager) {} + ~CpuDevice() override = default; void SetAsActiveDevice() override; DeviceType device_type() const override { return DeviceType::kCPU; } size_t device_index() const override { return 0; } + DeviceManager* device_manager() const override { return device_manager_; } Stream* CreateStream() override; void DestroyStream(Stream* stream) override; @@ -43,6 +44,9 @@ class CpuDevice : public Device { void Free(const AllocationOptions& options, void* ptr) override; Maybe 
AllocPinned(const AllocationOptions& options, void** ptr, size_t size) override; void FreePinned(const AllocationOptions& options, void* ptr) override; + + private: + DeviceManager* device_manager_; }; } // namespace ep diff --git a/oneflow/core/ep/cpu/cpu_device_manager.cpp b/oneflow/core/ep/cpu/cpu_device_manager.cpp index b6d4d01a40c..3a03ca1c533 100644 --- a/oneflow/core/ep/cpu/cpu_device_manager.cpp +++ b/oneflow/core/ep/cpu/cpu_device_manager.cpp @@ -20,9 +20,15 @@ namespace oneflow { namespace ep { +CpuDeviceManager::CpuDeviceManager(DeviceManagerRegistry* registry) : registry_(registry) {} + +CpuDeviceManager::~CpuDeviceManager() = default; + +DeviceManagerRegistry* CpuDeviceManager::registry() const { return registry_; } + std::shared_ptr CpuDeviceManager::GetDevice(size_t device_index) { std::lock_guard lock(device_mutex_); - if (!device_) { device_.reset(new CpuDevice()); } + if (!device_) { device_.reset(new CpuDevice(this)); } return device_; } @@ -34,8 +40,6 @@ size_t CpuDeviceManager::GetActiveDeviceIndex() { return 0; } void CpuDeviceManager::SetActiveDeviceByIndex(size_t device_index) {} -REGISTER_EP_DEVICE_MANAGER(DeviceType::kCPU, CpuDeviceManager); - } // namespace ep } // namespace oneflow diff --git a/oneflow/core/ep/cpu/cpu_device_manager.h b/oneflow/core/ep/cpu/cpu_device_manager.h index 172eb18d3f0..dd1e614f2c9 100644 --- a/oneflow/core/ep/cpu/cpu_device_manager.h +++ b/oneflow/core/ep/cpu/cpu_device_manager.h @@ -25,9 +25,10 @@ namespace ep { class CpuDeviceManager : public DeviceManager { public: OF_DISALLOW_COPY_AND_MOVE(CpuDeviceManager); - CpuDeviceManager() = default; - virtual ~CpuDeviceManager() = default; + CpuDeviceManager(DeviceManagerRegistry* registry); + ~CpuDeviceManager() override; + DeviceManagerRegistry* registry() const override; std::shared_ptr GetDevice(size_t device_index) override; size_t GetDeviceCount(size_t primary_device_index) override; size_t GetDeviceCount() override; @@ -37,6 +38,7 @@ class CpuDeviceManager : public DeviceManager { private: std::mutex device_mutex_; std::shared_ptr device_; + DeviceManagerRegistry* registry_; }; } // namespace ep diff --git a/oneflow/core/ep/cpu/cpu_device_manager_factory.cpp b/oneflow/core/ep/cpu/cpu_device_manager_factory.cpp index 335ee97f126..e1dd053ec62 100644 --- a/oneflow/core/ep/cpu/cpu_device_manager_factory.cpp +++ b/oneflow/core/ep/cpu/cpu_device_manager_factory.cpp @@ -29,8 +29,8 @@ class CpuDeviceManagerFactory : public DeviceManagerFactory { CpuDeviceManagerFactory() = default; ~CpuDeviceManagerFactory() override = default; - std::unique_ptr NewDeviceManager() override { - return std::make_unique(); + std::unique_ptr NewDeviceManager(DeviceManagerRegistry* registry) override { + return std::make_unique(registry); } DeviceType device_type() const override { return DeviceType::kCPU; } diff --git a/oneflow/core/ep/cpu/primitive/binary_functor.h b/oneflow/core/ep/cpu/primitive/binary_functor.h new file mode 100644 index 00000000000..4addb4044bc --- /dev/null +++ b/oneflow/core/ep/cpu/primitive/binary_functor.h @@ -0,0 +1,39 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/ep/common/primitive/binary_functor.h" + +namespace oneflow { + +namespace ep { +namespace primitive { +namespace broadcast_elementwise_binary { + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return std::pow(src0, src1); } +}; + +template<> +struct BinaryFunctor { + OF_DEVICE_FUNC float16 operator()(float16 src0, float16 src1) const { + return static_cast(std::pow(static_cast(src0), static_cast(src1))); + } +}; + +} // namespace broadcast_elementwise_binary +} // namespace primitive +} // namespace ep +} // namespace oneflow diff --git a/oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp b/oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp new file mode 100644 index 00000000000..cf1c55dd5b3 --- /dev/null +++ b/oneflow/core/ep/cpu/primitive/broadcast_elementwise_binary.cpp @@ -0,0 +1,190 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h" +#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h" +#include "oneflow/core/ep/cpu/primitive/binary_functor.h" +#include "oneflow/core/ep/cpu/primitive/type_seq.h" +#include "oneflow/core/ndarray/ndarray_util.h" +#include "oneflow/core/ndarray/xpu_var_ndarray.h" + +namespace oneflow { + +namespace ep { +namespace primitive { +namespace broadcast_elementwise_binary { + +namespace { + +template +T GetValue(Scalar value) { + return value.Value(); +} + +template<> +float16 GetValue(Scalar value) { + return static_cast(GetValue(value)); +} + +template& z, + const XpuVarNdarray& x, const XpuVarNdarray& y)> +class BroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryImpl); + BroadcastElementwiseBinaryImpl() = default; + ~BroadcastElementwiseBinaryImpl() override = default; + + void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims, + const void* src1, void* dst) override { + int64_t elem_cnt = GetElementCount(num_src1_dims, src1_dims); + Src src0_val = GetValue(src0); + binary_func(stream, XpuVarNdarray(Shape({elem_cnt}), reinterpret_cast(dst), 1), + XpuVarNdarray(Shape({1}), &src0_val, 1), + XpuVarNdarray(Shape({elem_cnt}), reinterpret_cast(src1), 1)); + } + void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0, + Scalar src1, void* dst) override { + int64_t elem_cnt = GetElementCount(num_src0_dims, src0_dims); + Src src1_val = GetValue(src1); + binary_func(stream, XpuVarNdarray(Shape({elem_cnt}), reinterpret_cast(dst), 1), + XpuVarNdarray(Shape({elem_cnt}), reinterpret_cast(src0), 1), + XpuVarNdarray(Shape({1}), &src1_val, 1)); + } + void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0, + size_t num_src1_dims, const int64_t* src1_dims, const void* src1, + void* dst) override { + DimVector src0_dim_vec; + DimVector src1_dim_vec; + DimVector dst_dim_vec; + size_t num_dims = 0; + int64_t simplified_src0_dims[kMaxNumDims]; + int64_t simplified_src1_dims[kMaxNumDims]; + int64_t simplified_dst_dims[kMaxNumDims]; + SimplifyBroadcastDims(num_src0_dims, src0_dims, num_src1_dims, src1_dims, + &num_dims, simplified_src0_dims, simplified_src1_dims, + simplified_dst_dims); + CheckInplace(num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1, + simplified_dst_dims, dst); + for (int64_t i = 0; i < num_dims; ++i) { + src0_dim_vec.push_back(simplified_src0_dims[i]); + src1_dim_vec.push_back(simplified_src1_dims[i]); + dst_dim_vec.push_back(simplified_dst_dims[i]); + } + binary_func( + stream, XpuVarNdarray(Shape(dst_dim_vec), reinterpret_cast(dst), num_dims), + XpuVarNdarray(Shape(src0_dim_vec), reinterpret_cast(src0), num_dims), + XpuVarNdarray(Shape(src1_dim_vec), reinterpret_cast(src1), + num_dims)); + } +}; + +template& z, + const XpuVarNdarray& x, const XpuVarNdarray& y)> +std::unique_ptr NewBroadcastElementwiseBinary() { + return std::unique_ptr( + new BroadcastElementwiseBinaryImpl()); +} + +#define BINARY_MATH_OP_NDARRAY_PAIR \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kAdd, Add) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSub, Sub) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMul, Mul) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kDiv, Div) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMax, Max) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kMin, Min) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kPow, Pow) + +#define NDARRAY_BINARY_TYPE_SEQ \ + 
CPU_PRIMITIVE_INT8_TYPE_SEQ \ + CPU_PRIMITIVE_UINT8_TYPE_SEQ \ + CPU_PRIMITIVE_INT32_TYPE_SEQ \ + CPU_PRIMITIVE_INT64_TYPE_SEQ \ + CPU_PRIMITIVE_FLOAT_TYPE_SEQ \ + CPU_PRIMITIVE_DOUBLE_TYPE_SEQ \ + CPU_PRIMITIVE_FLOAT16_TYPE_SEQ + +#define BINARY_LOGICAL_COMPARISION_OP_NDARRAY_PAIR \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kEqual, EQ) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kNotEqual, NE) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessThan, LT) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLessEqual, LE) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterThan, GT) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kGreaterEqual, GE) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalAnd, AND) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalOr, OR) \ + OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kLogicalXor, XOR) + +class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryFactory { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactoryImpl); + BroadcastElementwiseBinaryFactoryImpl() = default; + ~BroadcastElementwiseBinaryFactoryImpl() override = default; + + std::unique_ptr New(BinaryOp binary_op, DataType src_type, + DataType dst_type, size_t max_num_dims) override { + if (max_num_dims > kMaxNumDims) { return nullptr; } +#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op_pair, data_type_pair) \ + {std::make_tuple(OF_PP_PAIR_FIRST(binary_op_pair), OF_PP_PAIR_SECOND(data_type_pair), \ + OF_PP_PAIR_SECOND(data_type_pair)), \ + NewBroadcastElementwiseBinary< \ + OF_PP_PAIR_FIRST(binary_op_pair), OF_PP_PAIR_FIRST(data_type_pair), \ + OF_PP_PAIR_FIRST(data_type_pair), \ + &NdarrayUtil::OF_PP_CAT( \ + Broadcast, OF_PP_PAIR_SECOND(binary_op_pair))>}, + +#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY( \ + binary_op_pair, src_data_type_pair, dst_data_type_pair) \ + {std::make_tuple(OF_PP_PAIR_FIRST(binary_op_pair), OF_PP_PAIR_SECOND(src_data_type_pair), \ + OF_PP_PAIR_SECOND(dst_data_type_pair)), \ + NewBroadcastElementwiseBinary< \ + OF_PP_PAIR_FIRST(binary_op_pair), OF_PP_PAIR_FIRST(src_data_type_pair), \ + OF_PP_PAIR_FIRST(dst_data_type_pair), \ + &NdarrayUtil::OF_PP_CAT( \ + Broadcast, OF_PP_PAIR_SECOND(binary_op_pair))>}, + + static const std::map, + std::function()>> + new_broadcast_elementwise_binary_handle{ + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, + BINARY_MATH_OP_NDARRAY_PAIR, NDARRAY_BINARY_TYPE_SEQ) + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( + MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY, + BINARY_LOGICAL_COMPARISION_OP_NDARRAY_PAIR, NDARRAY_BINARY_TYPE_SEQ, + CPU_PRIMITIVE_INT8_TYPE_SEQ)}; + +#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY +#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY + const auto it = new_broadcast_elementwise_binary_handle.find( + std::make_tuple(binary_op, src_type, dst_type)); + if (it != new_broadcast_elementwise_binary_handle.end()) { + return it->second(); + } else { + return nullptr; + } + } +}; + +REGISTER_PRIMITIVE_FACTORY(DeviceType::kCPU, BroadcastElementwiseBinaryFactory, + BroadcastElementwiseBinaryFactoryImpl); + +} // namespace +} // namespace broadcast_elementwise_binary +} // namespace primitive +} // namespace ep + +} // namespace oneflow diff --git a/oneflow/core/ep/cuda/cuda_device.cpp b/oneflow/core/ep/cuda/cuda_device.cpp index aff5f9d450a..dbd3716da0a 100644 --- a/oneflow/core/ep/cuda/cuda_device.cpp +++ b/oneflow/core/ep/cuda/cuda_device.cpp @@ -23,8 +23,8 @@ namespace oneflow { namespace ep { -CudaDevice::CudaDevice(int device_index) - : 
device_index_(device_index), event_flags_{}, properties_{} { +CudaDevice::CudaDevice(int device_index, DeviceManager* device_manager) + : device_index_(device_index), event_flags_{}, properties_{}, device_manager_(device_manager) { CudaCurrentDeviceGuard guard(device_index_); OF_CUDA_CHECK(cudaGetDeviceProperties(&properties_, device_index_)); event_flags_ = cudaEventDisableTiming; diff --git a/oneflow/core/ep/cuda/cuda_device.h b/oneflow/core/ep/cuda/cuda_device.h index 320db9548b8..2885911fa93 100644 --- a/oneflow/core/ep/cuda/cuda_device.h +++ b/oneflow/core/ep/cuda/cuda_device.h @@ -29,13 +29,14 @@ namespace ep { class CudaDevice : public Device { public: OF_DISALLOW_COPY_AND_MOVE(CudaDevice); - explicit CudaDevice(int device_index); - virtual ~CudaDevice(); + explicit CudaDevice(int device_index, DeviceManager* device_manager); + ~CudaDevice() override; void SetAsActiveDevice() override; DeviceType device_type() const override { return DeviceType::kCUDA; } size_t device_index() const override { return device_index_; } + DeviceManager* device_manager() const override { return device_manager_; } Stream* CreateStream() override; void DestroyStream(Stream* stream) override; @@ -56,6 +57,7 @@ class CudaDevice : public Device { std::vector events_; unsigned int event_flags_; cudaDeviceProp properties_; + DeviceManager* device_manager_; }; } // namespace ep diff --git a/oneflow/core/ep/cuda/cuda_device_manager.cpp b/oneflow/core/ep/cuda/cuda_device_manager.cpp index b19bfd3b035..cf221e329e9 100644 --- a/oneflow/core/ep/cuda/cuda_device_manager.cpp +++ b/oneflow/core/ep/cuda/cuda_device_manager.cpp @@ -23,12 +23,17 @@ namespace oneflow { namespace ep { +CudaDeviceManager::CudaDeviceManager(DeviceManagerRegistry* registry) : registry_(registry) {} +CudaDeviceManager::~CudaDeviceManager() = default; + +DeviceManagerRegistry* CudaDeviceManager::registry() const { return registry_; } + std::shared_ptr CudaDeviceManager::GetDevice(size_t device_index) { std::lock_guard lock(devices_mutex_); if (device_index < devices_.size() && devices_.at(device_index)) { return devices_.at(device_index); } - auto device = std::make_shared(device_index); + auto device = std::make_shared(device_index, this); if (device_index >= devices_.size()) { devices_.resize(device_index + 1); } devices_.at(device_index) = device; return device; @@ -55,8 +60,6 @@ void CudaDeviceManager::SetActiveDeviceByIndex(size_t device_index) { OF_CUDA_CHECK(cudaSetDevice(static_cast(device_index))); } -REGISTER_EP_DEVICE_MANAGER(DeviceType::kCUDA, CudaDeviceManager); - } // namespace ep } // namespace oneflow diff --git a/oneflow/core/ep/cuda/cuda_device_manager.h b/oneflow/core/ep/cuda/cuda_device_manager.h index 76fe434fba0..88ae1f6b86d 100644 --- a/oneflow/core/ep/cuda/cuda_device_manager.h +++ b/oneflow/core/ep/cuda/cuda_device_manager.h @@ -27,9 +27,10 @@ namespace ep { class CudaDeviceManager : public DeviceManager { public: OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManager); - CudaDeviceManager() = default; - virtual ~CudaDeviceManager() = default; + CudaDeviceManager(DeviceManagerRegistry* registry); + ~CudaDeviceManager() override; + DeviceManagerRegistry* registry() const override; std::shared_ptr GetDevice(size_t device_index) override; size_t GetDeviceCount(size_t primary_device_index) override; size_t GetDeviceCount() override; @@ -39,6 +40,7 @@ class CudaDeviceManager : public DeviceManager { private: std::mutex devices_mutex_; std::vector> devices_; + DeviceManagerRegistry* registry_; }; } // namespace ep diff --git 
a/oneflow/core/ep/cuda/cuda_device_manager_factory.cpp b/oneflow/core/ep/cuda/cuda_device_manager_factory.cpp index 0ae8def4031..b925c381fff 100644 --- a/oneflow/core/ep/cuda/cuda_device_manager_factory.cpp +++ b/oneflow/core/ep/cuda/cuda_device_manager_factory.cpp @@ -96,8 +96,8 @@ class CudaDeviceManagerFactory : public DeviceManagerFactory { CudaDeviceManagerFactory() = default; ~CudaDeviceManagerFactory() override = default; - std::unique_ptr NewDeviceManager() override { - return std::make_unique(); + std::unique_ptr NewDeviceManager(DeviceManagerRegistry* registry) override { + return std::make_unique(registry); } DeviceType device_type() const override { return DeviceType::kCUDA; } diff --git a/oneflow/core/ep/cuda/cuda_stream.cpp b/oneflow/core/ep/cuda/cuda_stream.cpp index 1b0a99bcc92..a35ef90121c 100644 --- a/oneflow/core/ep/cuda/cuda_stream.cpp +++ b/oneflow/core/ep/cuda/cuda_stream.cpp @@ -42,6 +42,10 @@ void SetAffinityByDevice(int dev_id) { node_device_desc->Topology()->SetMemoryAffinityByPCIBusID(cuda_device->PCIBusID()); } +bool IsCuda9OnTuringDevice(const cudaDeviceProp& prop) { + return CUDA_VERSION >= 9000 && CUDA_VERSION < 9020 && prop.major == 7 && prop.minor == 5; +} + } // namespace #ifdef WITH_CUDA_GRAPHS @@ -89,7 +93,7 @@ CudaStream::CudaStream(CudaDevice* device) OF_CUBLAS_CHECK(cublasCreate(&cublas_handle_)); OF_CUBLAS_CHECK(cublasSetStream(cublas_handle_, cuda_stream_)); #if CUBLAS_VERSION >= 11000 - if (Global::Get()->enable_tensor_float_32_compute()) { + if (ParseBooleanFromEnv("ONEFLOW_EP_CUDA_ENABLE_TF32_EXECUTION", true)) { OF_CUBLAS_CHECK(cublasSetMathMode(cublas_handle_, CUBLAS_TF32_TENSOR_OP_MATH)); } #endif // CUBLAS_VERSION >= 11000 @@ -99,12 +103,12 @@ CudaStream::CudaStream(CudaDevice* device) OF_CUBLAS_CHECK(cublasSetWorkspace(cublas_handle_, workspace_, workspace_size_)); #endif // CUBLAS_VERSION >= 11200 // cudnn_handle - if (IsCuda9OnTuringDevice()) { + if (IsCuda9OnTuringDevice(device_properties())) { OF_CUDA_CHECK(cudaDeviceSynchronize()); OF_CUDA_CHECK(cudaGetLastError()); } OF_CUDNN_CHECK(cudnnCreate(&cudnn_handle_)); - if (IsCuda9OnTuringDevice()) { + if (IsCuda9OnTuringDevice(device_properties())) { OF_CUDA_CHECK(cudaDeviceSynchronize()); cudaGetLastError(); } diff --git a/oneflow/core/ep/cuda/cuda_stream.h b/oneflow/core/ep/cuda/cuda_stream.h index f6575fd4a27..6e204f264f5 100644 --- a/oneflow/core/ep/cuda/cuda_stream.h +++ b/oneflow/core/ep/cuda/cuda_stream.h @@ -56,12 +56,24 @@ class CudaGraphExecutable { #endif // WITH_CUDA_GRAPHS +struct CudaLaunchConfig { + dim3 grid_dim; + dim3 block_dim; + size_t shared_mem_size; + CudaLaunchConfig() : grid_dim{}, block_dim{}, shared_mem_size(0) {} + + CudaLaunchConfig(unsigned int grid_size, unsigned int block_size, size_t shared_mem_size) + : grid_dim(grid_size), block_dim(block_size), shared_mem_size(shared_mem_size) {} +}; + class CudaStream : public Stream { public: OF_DISALLOW_COPY_AND_MOVE(CudaStream); explicit CudaStream(CudaDevice* device); ~CudaStream() override; + static constexpr uint32_t kDefaultBlockSize = 256; + DeviceType device_type() const override; Device* device() const override; Maybe Sync() override; @@ -75,6 +87,40 @@ class CudaStream : public Stream { cudnnHandle_t cudnn_handle() const; const cudaDeviceProp& device_properties() const; + void InitLaunchConfigWithWaves(CudaLaunchConfig* config, size_t elem_cnt, size_t block_size, + size_t max_waves) const { + const uint32_t max_grid_size = max_waves * device_properties().multiProcessorCount + * 
(device_properties().maxThreadsPerMultiProcessor / block_size); + const uint32_t grid_size = + std::min(max_grid_size, (elem_cnt + block_size - 1) / block_size); + config->grid_dim = dim3(grid_size); + config->block_dim = dim3(block_size); + config->shared_mem_size = 0; + } + +#ifdef __CUDACC__ + template + void LaunchKernel(void (*kernel)(Params...), const CudaLaunchConfig& launch_config, + Args... args) { + kernel<<>>(args...); + } + + template + void LaunchKernel(void (*kernel)(Params...), size_t elem_cnt, size_t max_waves, Args... args) { + constexpr uint32_t block_size = kDefaultBlockSize; + CudaLaunchConfig config{}; + InitLaunchConfigWithWaves(&config, elem_cnt, block_size, max_waves); + LaunchKernel(kernel, config, args...); + } + + template + void LaunchKernelDefaultWaves(void (*kernel)(Params...), size_t elem_cnt, Args... args) { + const size_t default_waves = 32; + LaunchKernel(kernel, elem_cnt, default_waves, args...); + } +#endif // __CUDACC__ + #ifdef WITH_CUDA_GRAPHS void BeginGraphCapture(); void EndGraphCapture(CudaGraphExecutable* executable); diff --git a/oneflow/core/ep/cuda/primitive/binary_functor.cuh b/oneflow/core/ep/cuda/primitive/binary_functor.cuh new file mode 100644 index 00000000000..98ce1a481e1 --- /dev/null +++ b/oneflow/core/ep/cuda/primitive/binary_functor.cuh @@ -0,0 +1,50 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "oneflow/core/ep/common/primitive/binary_functor.h" + +namespace oneflow { +namespace ep { +namespace primitive { +namespace broadcast_elementwise_binary { + +template +struct BinaryFunctor { + OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return pow(src0, src1); } +}; + +template<> +struct BinaryFunctor { + OF_DEVICE_FUNC half operator()(half src0, half src1) const { + return static_cast(pow(static_cast(src0), static_cast(src1))); + } +}; + +#if CUDA_VERSION >= 11000 + +template<> +struct BinaryFunctor { + OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const { + return static_cast(pow(static_cast(src0), static_cast(src1))); + } +}; + +#endif // CUDA_VERSION >= 11000 + +} // namespace broadcast_elementwise_binary +} // namespace primitive +} // namespace ep +} // namespace oneflow diff --git a/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cu b/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cu new file mode 100644 index 00000000000..a45a6e293cc --- /dev/null +++ b/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cu @@ -0,0 +1,86 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/ep/include/primitive//broadcast_elementwise_binary.h" +#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h" +#include "oneflow/core/ep/cuda/primitive/type_seq.h" +#include "oneflow/core/ep/cuda/cuda_stream.h" +#include "oneflow/core/cuda/elementwise.cuh" +#include "oneflow/core/ep/cuda/primitive/binary_functor.cuh" + +namespace oneflow { + +namespace ep { +namespace primitive { +namespace broadcast_elementwise_binary { + +template +std::unique_ptr NewBroadcastElementwiseBinary(); + +namespace { + +class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryFactory { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactoryImpl); + BroadcastElementwiseBinaryFactoryImpl() = default; + ~BroadcastElementwiseBinaryFactoryImpl() override = default; + + std::unique_ptr New(BinaryOp binary_op, DataType src_type, + DataType dst_type, size_t max_num_dims) override { + if (max_num_dims > kMaxNumDims) { return nullptr; } +#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \ + {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair), \ + OF_PP_PAIR_SECOND(data_type_pair)), \ + NewBroadcastElementwiseBinary}, + +#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY( \ + binary_op, src_data_type_pair, dst_data_type_pair) \ + {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(src_data_type_pair), \ + OF_PP_PAIR_SECOND(dst_data_type_pair)), \ + NewBroadcastElementwiseBinary}, + + static const std::map, + std::function()>> + new_broadcast_elementwise_binary_handle{ + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, + BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ) + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( + MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY, + BINARY_COMPARISION_OP_SEQ BINARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ, + CUDA_PRIMITIVE_INT8_TYPE_SEQ)}; + +#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY +#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY + + const auto it = new_broadcast_elementwise_binary_handle.find( + std::make_tuple(binary_op, src_type, dst_type)); + if (it != new_broadcast_elementwise_binary_handle.end()) { + return it->second(); + } else { + return nullptr; + } + } +}; + +REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastElementwiseBinaryFactory, + BroadcastElementwiseBinaryFactoryImpl); +} // namespace +} // namespace broadcast_elementwise_binary +} // namespace primitive +} // namespace ep + +} // namespace oneflow diff --git a/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh b/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh new file mode 100644 index 00000000000..2e4df78afa5 --- /dev/null +++ b/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh @@ -0,0 +1,377 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/ep/include/primitive//broadcast_elementwise_binary.h" +#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h" +#include "oneflow/core/ep/cuda/primitive/type_seq.h" +#include "oneflow/core/ep/cuda/cuda_stream.h" +#include "oneflow/core/cuda/elementwise.cuh" +#include "oneflow/core/ep/cuda/primitive/binary_functor.cuh" + +namespace oneflow { + +namespace ep { +namespace primitive { +namespace broadcast_elementwise_binary { + +namespace { + +template +struct GetPackType { + using type = typename std::aligned_storage::type; +}; + +template +using PackType = typename GetPackType::type; + +template +union Pack { + static_assert(sizeof(PackType) == sizeof(T) * N, ""); + OF_DEVICE_FUNC Pack() { + // do nothing + } + PackType storage; + T elem[N]; +}; + +template +struct BroadcastElementwiseBinaryParams { + NdIndexOffsetHelper src0_index_helper; + NdIndexOffsetHelper src1_index_helper; + NdIndexOffsetHelper dst_index_helper; + size_t num_dims; + IndexType src0_index_mask[max_dims]; + IndexType src1_index_mask[max_dims]; + IndexType count{}; + const void* src0{}; + const void* src1{}; + void* dst{}; +}; + +template +__global__ void BroadcastElementwiseBinaryGpu( + BroadcastElementwiseBinaryParams params) { + constexpr size_t dst_pack_size = + src0_pack_size > src1_pack_size ? src0_pack_size : src1_pack_size; + static_assert(src0_pack_size == dst_pack_size || src0_pack_size == 1, ""); + static_assert(src1_pack_size == dst_pack_size || src1_pack_size == 1, ""); + + const PackType* src0 = + reinterpret_cast*>(params.src0); + const PackType* src1 = + reinterpret_cast*>(params.src1); + PackType* dst = reinterpret_cast*>(params.dst); + + IndexType src0_index[max_dims]; + IndexType src1_index[max_dims]; + IndexType dst_index[max_dims]; + size_t num_dims = params.num_dims; + CUDA_1D_KERNEL_LOOP_T(IndexType, offset, params.count) { + params.dst_index_helper.OffsetToNdIndex(offset, dst_index, num_dims); +#pragma unroll + for (int i = 0; i < max_dims; ++i) { + if (i < num_dims) { + src0_index[i] = params.src0_index_mask[i] * dst_index[i]; + src1_index[i] = params.src1_index_mask[i] * dst_index[i]; + } + } + const IndexType src0_offset = params.src0_index_helper.NdIndexToOffset(src0_index, num_dims); + const IndexType src1_offset = params.src1_index_helper.NdIndexToOffset(src1_index, num_dims); + Pack src0_pack; + src0_pack.storage = src0[src0_offset]; + Pack src1_pack; + src1_pack.storage = src1[src1_offset]; + Pack dst_pack; +#pragma unroll + for (int j = 0; j < dst_pack_size; ++j) { + const Src src0_val = + (src0_pack_size == dst_pack_size) ? src0_pack.elem[j] : src0_pack.elem[0]; + const Src src1_val = + (src1_pack_size == dst_pack_size) ? 
src1_pack.elem[j] : src1_pack.elem[0]; + dst_pack.elem[j] = + BinaryFunctor()(src0_val, src1_val); + } + dst[offset] = dst_pack.storage; + } +} + +template +void LaunchKernel(Stream* stream, int num_dims, const int64_t* src0_dims, const void* src0, + const int64_t* src1_dims, const void* src1, const int64_t* dst_dims, void* dst, + size_t count) { + BroadcastElementwiseBinaryParams params; + for (size_t i = 0; i < num_dims; ++i) { + params.src0_index_mask[i] = (src0_dims[i] == 1) ? 0 : 1; + params.src1_index_mask[i] = (src1_dims[i] == 1) ? 0 : 1; + } + params.src0_index_helper = NdIndexOffsetHelper(src0_dims, num_dims); + params.src1_index_helper = NdIndexOffsetHelper(src1_dims, num_dims); + params.dst_index_helper = NdIndexOffsetHelper(dst_dims, num_dims); + params.num_dims = num_dims; + params.src0 = src0; + params.src1 = src1; + params.dst = dst; + params.count = static_cast(count); + auto* cuda_stream = stream->As(); + BroadcastElementwiseBinaryGpu + <<cuda_stream()>>>(params); +} + +template +void DispatchIndexType(Stream* stream, size_t num_dims, const int64_t* src0_dims, const void* src0, + const int64_t* src1_dims, const void* src1, const int64_t* dst_dims, + void* dst) { + size_t count = GetElementCount(num_dims, dst_dims); + if (count < GetMaxVal()) { + LaunchKernel( + stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, count); + } else { + LaunchKernel( + stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst, count); + } +} + +template +void DispatchPackSize(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims, + const int64_t* src0_dims, const void* src0, const int64_t* src1_dims, + const void* src1, const int64_t* dst_dims, void* dst) { + void (*func)(Stream* /*stream*/, size_t /*num_dims*/, const int64_t* /*src0_dims*/, + const void* /*src0*/, const int64_t* /*src1_dims*/, const void* /*src1*/, + const int64_t* /*dst_dims*/, void* /*dst*/) = nullptr; + if (src0_pack_size == 1 && src1_pack_size == 1) { + func = DispatchIndexType; + } else if (src0_pack_size == 4 && src1_pack_size == 4) { + func = DispatchIndexType; + } else if (src0_pack_size == 1 && src1_pack_size == 4) { + func = DispatchIndexType; + } else if (src0_pack_size == 4 && src1_pack_size == 1) { + func = DispatchIndexType; + } else { + UNIMPLEMENTED(); + } + func(stream, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, dst); +} + +template +void DispatchNumDims(Stream* stream, size_t src0_pack_size, size_t src1_pack_size, size_t num_dims, + const int64_t* src0_dims, const void* src0, const int64_t* src1_dims, + const void* src1, const int64_t* dst_dims, void* dst) { + void (*func)(Stream* /*stream*/, size_t /*src0_pack_size*/, size_t /*src1_pack_size*/, + size_t /*num_dims*/, const int64_t* /*src0_dims*/, const void* /*src0*/, + const int64_t* /*src1_dims*/, const void* /*src1*/, const int64_t* /*dst_dims*/, + void* /*dst*/) = nullptr; + CHECK_NE(num_dims, 1); + if (num_dims == 2) { + func = DispatchPackSize; + } else if (num_dims == 3) { + func = DispatchPackSize; + } else if (num_dims == 4) { + func = DispatchPackSize; + } else if (num_dims <= 8) { + func = DispatchPackSize; + } else { + UNIMPLEMENTED(); + } + func(stream, src0_pack_size, src1_pack_size, num_dims, src0_dims, src0, src1_dims, src1, dst_dims, + dst); +} + +template +size_t GetPackSize(size_t num_src_dims, const int64_t* src0_dims, const void* src0, + const int64_t* src1_dims, const void* src1, void* dst) { + static_assert(max_pack_size > 0 && (max_pack_size & (max_pack_size - 1)) == 0, ""); + 
CHECK(src0_dims[num_src_dims - 1] != 1 || src1_dims[num_src_dims - 1] != 1); + auto dst_ptr = reinterpret_cast(dst); + for (size_t pack_size = max_pack_size; pack_size > 2; pack_size /= 2) { + bool is_src0_supported = (src0_dims[num_src_dims - 1] == 1) + || IsPackSizeSupported(pack_size, num_src_dims, src0_dims, src0); + bool is_src1_supported = (src1_dims[num_src_dims - 1] == 1) + || IsPackSizeSupported(pack_size, num_src_dims, src1_dims, src1); + if (is_src0_supported && is_src1_supported && (dst_ptr % (pack_size * sizeof(R))) == 0) { + return pack_size; + } + } + return 1; +} + +constexpr size_t kMaxPackSize = 4; + +template +void LaunchWithSimplified(Stream* stream, size_t simplified_num_dims, int64_t* simplified_src0_dims, + const void* src0, int64_t* simplified_src1_dims, const void* src1, + int64_t* simplified_dst_dims, void* dst) { + CHECK_LE(simplified_num_dims, kMaxNumDims); + size_t pack_size = GetPackSize(simplified_num_dims, simplified_src0_dims, + src0, simplified_src1_dims, src1, dst); + size_t src0_pack_size = 1; + size_t src1_pack_size = 1; + if (simplified_src0_dims[simplified_num_dims - 1] != 1) { + simplified_src0_dims[simplified_num_dims - 1] /= pack_size; + src0_pack_size = pack_size; + } + if (simplified_src1_dims[simplified_num_dims - 1] != 1) { + simplified_src1_dims[simplified_num_dims - 1] /= pack_size; + src1_pack_size = pack_size; + } + simplified_dst_dims[simplified_num_dims - 1] /= pack_size; + DispatchNumDims(stream, src0_pack_size, src1_pack_size, simplified_num_dims, + simplified_src0_dims, src0, simplified_src1_dims, src1, + simplified_dst_dims, dst); +} + +template +struct BinaryLhsScalarFunctor { + __host__ __device__ explicit BinaryLhsScalarFunctor(Src scalar) : scalar(scalar) {} + __device__ Dst operator()(Src src) const { + return BinaryFunctor()(scalar, src); + } + const Src scalar; +}; + +template +struct BinaryRhsScalarFunctor { + __host__ __device__ explicit BinaryRhsScalarFunctor(Src scalar) : scalar(scalar) {} + __device__ Dst operator()(Src src) const { + return BinaryFunctor()(src, scalar); + } + const Src scalar; +}; + +template +struct BinaryLhsScalarPtrFunctorFactory { + __host__ __device__ explicit BinaryLhsScalarPtrFunctorFactory(const Src* scalar_ptr) + : scalar_ptr(scalar_ptr) {} + __device__ BinaryLhsScalarFunctor operator()() const { + return BinaryLhsScalarFunctor(*scalar_ptr); + } + const Src* scalar_ptr; +}; + +template +struct BinaryRhsScalarPtrFunctorFactory { + __host__ __device__ explicit BinaryRhsScalarPtrFunctorFactory(const Src* scalar_ptr) + : scalar_ptr(scalar_ptr) {} + __device__ BinaryRhsScalarFunctor operator()() const { + return BinaryRhsScalarFunctor(*scalar_ptr); + } + const Src* scalar_ptr; +}; + +template +void DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const Src* src0, + size_t num_src1_dims, const int64_t* src1_dims, const Src* src1, Dst* dst) { + auto* cuda_stream = stream->As(); + size_t simplified_num_dims = 0; + int64_t simplified_src0_dims[kMaxNumDims]; + int64_t simplified_src1_dims[kMaxNumDims]; + int64_t simplified_dst_dims[kMaxNumDims]; + SimplifyBroadcastDims(num_src0_dims, src0_dims, num_src1_dims, src1_dims, + &simplified_num_dims, simplified_src0_dims, + simplified_src1_dims, simplified_dst_dims); + CheckInplace(simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1, + simplified_dst_dims, dst); + if (IsDimsEquals(simplified_num_dims, simplified_src0_dims, simplified_num_dims, + simplified_src1_dims)) { + const int64_t elem_cnt = 
GetElementCount(simplified_num_dims, simplified_src0_dims); + OF_CUDA_CHECK( + (cuda::elementwise::Binary(BinaryFunctor(), + elem_cnt, dst, src0, src1, cuda_stream->cuda_stream()))); + } else { + if (simplified_num_dims == 1 && simplified_src0_dims[0] == 1) { + OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory( + BinaryLhsScalarPtrFunctorFactory(src0), simplified_src1_dims[0], dst, + src1, cuda_stream->cuda_stream()))); + } else if (simplified_num_dims == 1 && simplified_src1_dims[0] == 1) { + OF_CUDA_CHECK((cuda::elementwise::UnaryWithFactory( + BinaryRhsScalarPtrFunctorFactory(src1), simplified_src0_dims[0], dst, + src0, cuda_stream->cuda_stream()))); + } else { + LaunchWithSimplified(stream, simplified_num_dims, simplified_src0_dims, + src0, simplified_src1_dims, src1, + simplified_dst_dims, dst); + } + } +} + +template +T GetValue(Scalar value) { + return value.Value(); +} + +template<> +half GetValue(Scalar value) { + return static_cast(GetValue(value)); +} + +#if CUDA_VERSION >= 11000 + +template<> +nv_bfloat16 GetValue(Scalar value) { + return static_cast(GetValue(value)); +} + +#endif // CUDA_VERSION >= 11000 + +template +class BroadcastElementwiseBinaryImpl : public BroadcastElementwiseBinary { + public: + OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryImpl); + BroadcastElementwiseBinaryImpl() = default; + ~BroadcastElementwiseBinaryImpl() override = default; + + void Launch(Stream* stream, Scalar src0, size_t num_src1_dims, const int64_t* src1_dims, + const void* src1, void* dst) override { + auto* cuda_stream = stream->As(); + const size_t elem_cnt = GetElementCount(num_src1_dims, src1_dims); + OF_CUDA_CHECK( + (cuda::elementwise::Unary(BinaryLhsScalarFunctor(GetValue(src0)), + elem_cnt, reinterpret_cast(dst), + reinterpret_cast(src1), cuda_stream->cuda_stream()))); + } + void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0, + Scalar src1, void* dst) override { + auto* cuda_stream = stream->As(); + const size_t elem_cnt = GetElementCount(num_src0_dims, src0_dims); + OF_CUDA_CHECK( + (cuda::elementwise::Unary(BinaryRhsScalarFunctor(GetValue(src1)), + elem_cnt, reinterpret_cast(dst), + reinterpret_cast(src0), cuda_stream->cuda_stream()))); + } + void Launch(Stream* stream, size_t num_src0_dims, const int64_t* src0_dims, const void* src0, + size_t num_src1_dims, const int64_t* src1_dims, const void* src1, + void* dst) override { + DispatchLaunch( + stream, num_src0_dims, src0_dims, reinterpret_cast(src0), num_src1_dims, + src1_dims, reinterpret_cast(src1), reinterpret_cast(dst)); + } +}; + +} // namespace + +template +std::unique_ptr NewBroadcastElementwiseBinary() { + return std::unique_ptr( + new BroadcastElementwiseBinaryImpl()); +} + +} // namespace broadcast_elementwise_binary +} // namespace primitive +} // namespace ep + +} // namespace oneflow diff --git a/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_comparision.cu b/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_comparision.cu new file mode 100644 index 00000000000..37d21427b3e --- /dev/null +++ b/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_comparision.cu @@ -0,0 +1,37 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" + +namespace oneflow { + +namespace ep { +namespace primitive { +namespace broadcast_elementwise_binary { + +#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY( \ + binary_op, src_data_type_pair, dst_data_type_pair) \ + template std::unique_ptr NewBroadcastElementwiseBinary< \ + binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>(); + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_ENTRY, + BINARY_COMPARISION_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ, + CUDA_PRIMITIVE_INT8_TYPE_SEQ); + +} // namespace broadcast_elementwise_binary +} // namespace primitive +} // namespace ep + +} // namespace oneflow diff --git a/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_logical.cu b/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_logical.cu new file mode 100644 index 00000000000..b2cacb10a11 --- /dev/null +++ b/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_logical.cu @@ -0,0 +1,37 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" + +namespace oneflow { + +namespace ep { +namespace primitive { +namespace broadcast_elementwise_binary { + +#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY(binary_op, src_data_type_pair, \ + dst_data_type_pair) \ + template std::unique_ptr NewBroadcastElementwiseBinary< \ + binary_op, OF_PP_PAIR_FIRST(src_data_type_pair), OF_PP_PAIR_FIRST(dst_data_type_pair)>(); + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_LOGICAL_ENTRY, + BINARY_COMPARISION_OP_SEQ BINARY_LOGICAL_OP_SEQ, + CUDA_PRIMITIVE_ALL_TYPE_SEQ, CUDA_PRIMITIVE_INT8_TYPE_SEQ); + +} // namespace broadcast_elementwise_binary +} // namespace primitive +} // namespace ep + +} // namespace oneflow diff --git a/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math.cu b/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math.cu new file mode 100644 index 00000000000..0024c3f6798 --- /dev/null +++ b/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math.cu @@ -0,0 +1,35 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh" + +namespace oneflow { + +namespace ep { +namespace primitive { +namespace broadcast_elementwise_binary { + +#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \ + template std::unique_ptr NewBroadcastElementwiseBinary< \ + binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(); + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY, + BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ); + +} // namespace broadcast_elementwise_binary +} // namespace primitive +} // namespace ep + +} // namespace oneflow diff --git a/oneflow/core/ep/include/device.h b/oneflow/core/ep/include/device.h index b7cb214f5c1..e7b62e18ad0 100644 --- a/oneflow/core/ep/include/device.h +++ b/oneflow/core/ep/include/device.h @@ -28,6 +28,8 @@ namespace ep { constexpr size_t kMaxAlignmentRequirement = 512; +class DeviceManager; + class Device { public: OF_DISALLOW_COPY_AND_MOVE(Device); @@ -38,6 +40,7 @@ class Device { virtual DeviceType device_type() const = 0; virtual size_t device_index() const = 0; + virtual DeviceManager* device_manager() const = 0; virtual Stream* CreateStream() = 0; virtual void DestroyStream(Stream* stream) = 0; diff --git a/oneflow/core/ep/include/device_manager.h b/oneflow/core/ep/include/device_manager.h index 53f1c51e0c5..9a9050decd2 100644 --- a/oneflow/core/ep/include/device_manager.h +++ b/oneflow/core/ep/include/device_manager.h @@ -26,12 +26,15 @@ namespace oneflow { namespace ep { +class DeviceManagerRegistry; + class DeviceManager { public: OF_DISALLOW_COPY_AND_MOVE(DeviceManager); DeviceManager() = default; virtual ~DeviceManager() = default; + virtual DeviceManagerRegistry* registry() const = 0; virtual std::shared_ptr GetDevice(size_t device_index) = 0; virtual size_t GetDeviceCount(size_t primary_device_index) = 0; virtual size_t GetDeviceCount() = 0; @@ -39,9 +42,6 @@ class DeviceManager { virtual void SetActiveDeviceByIndex(size_t device_index) = 0; }; -#define REGISTER_EP_DEVICE_MANAGER(device_type, ManagerType) \ - REGISTER_CLASS(int32_t, device_type, DeviceManager, ManagerType) - } // namespace ep } // namespace oneflow diff --git a/oneflow/core/ep/include/device_manager_factory.h b/oneflow/core/ep/include/device_manager_factory.h index 980efb48e35..d8013c0cfdf 100644 --- a/oneflow/core/ep/include/device_manager_factory.h +++ b/oneflow/core/ep/include/device_manager_factory.h @@ -24,13 +24,15 @@ namespace oneflow { namespace ep { +class DeviceManagerRegistry; + class DeviceManagerFactory { public: OF_DISALLOW_COPY_AND_MOVE(DeviceManagerFactory); DeviceManagerFactory() = default; virtual ~DeviceManagerFactory() = default; - virtual std::unique_ptr NewDeviceManager() = 0; + virtual std::unique_ptr NewDeviceManager(DeviceManagerRegistry* registry) = 0; virtual DeviceType device_type() const = 0; virtual std::string device_type_name() const = 0; virtual void DumpVersionInfo() const {} diff --git a/oneflow/core/ep/include/primitive/binary_op.h b/oneflow/core/ep/include/primitive/binary_op.h index 
3730808904b..f621d3e2b4e 100644 --- a/oneflow/core/ep/include/primitive/binary_op.h +++ b/oneflow/core/ep/include/primitive/binary_op.h @@ -24,21 +24,30 @@ namespace ep { namespace primitive { enum class BinaryOp { + // Math kAdd, kSub, kMul, kDiv, kMax, kMin, + kPow, + // Comparision kEqual, kNotEqual, kLessThan, kLessEqual, kGreaterThan, kGreaterEqual, + // Logical kLogicalAnd, kLogicalOr, kLogicalXor, + // Unary Backward + kReluBackwardWithDyY, + kSigmoidBackwardWithDyY, + kGeluBackwardWithDyX, + }; } diff --git a/oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h b/oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h index 82165119978..904762be74d 100644 --- a/oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h +++ b/oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h @@ -47,7 +47,7 @@ class BroadcastElementwiseBinaryFactory : public Factory New(BinaryOp op, DataType src_type, - DeviceType dst_type, + DataType dst_type, size_t max_num_dims) = 0; }; diff --git a/oneflow/core/framework/attr_value.cpp b/oneflow/core/framework/attr_value.cpp index a30e107ffa3..f24d8077e15 100644 --- a/oneflow/core/framework/attr_value.cpp +++ b/oneflow/core/framework/attr_value.cpp @@ -19,17 +19,29 @@ namespace oneflow { template const T& AttrValueCast(const user_op::AttrVal& attr_val) { - const auto* typed_attr = dynamic_cast*>(&attr_val); + const auto* typed_attr = dynamic_cast*>(&attr_val); return CHECK_NOTNULL(typed_attr)->val(); } +template +std::shared_ptr CastAttrValue(const T& attr_val) { + return std::make_shared>(attr_val); +} + +template +std::shared_ptr CastAttrValue(const T* attr_val) { + return std::make_shared>(attr_val); +} + template size_t HashTypedAttrVal(const T& val) { return std::hash()(val); } -#define INITIALIZE_ATTR_VALUE_CAST(field, T, attr_type) \ - template const T& AttrValueCast(const user_op::AttrVal& attr_val); \ +#define INITIALIZE_ATTR_VALUE_CAST(field, T, attr_type) \ + template const T& AttrValueCast(const user_op::AttrVal& attr_val); \ + template std::shared_ptr CastAttrValue(const T& attr_val); \ + template std::shared_ptr CastAttrValue(const T* attr_val); \ template size_t HashTypedAttrVal(const T& attr_val); OF_PP_FOR_EACH_TUPLE(INITIALIZE_ATTR_VALUE_CAST, ATTR_SEQ) diff --git a/oneflow/core/framework/attr_value.h b/oneflow/core/framework/attr_value.h index b02a379cecc..d7b67757cea 100644 --- a/oneflow/core/framework/attr_value.h +++ b/oneflow/core/framework/attr_value.h @@ -97,19 +97,25 @@ class AttrVal { }; template -class TypedAttrVal final : public AttrVal { +class TypedAttrValIf : public AttrVal { public: - TypedAttrVal(T v) : val_(v) {} - ~TypedAttrVal() = default; + virtual const T& val() const = 0; + size_t hash_value() const override { return std::hash()(val()); } - size_t hash_value() const override { return std::hash()(val_); } bool operator==(const AttrVal& other) const override { - auto* that = dynamic_cast*>(&other); + auto* that = dynamic_cast*>(&other); if (that == nullptr) { return false; } - return this->val_ == that->val_; + return this->val() == that->val(); } +}; + +template +class TypedAttrVal final : public TypedAttrValIf { + public: + TypedAttrVal(T v) : val_(v) {} + ~TypedAttrVal() = default; - const T& val() const { return val_; } + const T& val() const override { return val_; } private: OF_DISALLOW_COPY_AND_MOVE(TypedAttrVal); @@ -117,11 +123,31 @@ class TypedAttrVal final : public AttrVal { T val_; }; +template +class TypedAttrValRef final : public TypedAttrValIf { + public: + 
TypedAttrValRef(const T* v) : val_(v) {} + ~TypedAttrValRef() = default; + + const T& val() const override { return *val_; } + + private: + OF_DISALLOW_COPY_AND_MOVE(TypedAttrValRef); + + const T* val_; +}; + } // namespace user_op template const T& AttrValueCast(const user_op::AttrVal& val); +template +std::shared_ptr CastAttrValue(const T& attr_val); + +template +std::shared_ptr CastAttrValue(const T* attr_val); + } // namespace oneflow #endif // ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_H_ diff --git a/oneflow/core/framework/attr_value_accessor.cpp b/oneflow/core/framework/attr_value_accessor.cpp index 79a63238840..2cc897030c7 100644 --- a/oneflow/core/framework/attr_value_accessor.cpp +++ b/oneflow/core/framework/attr_value_accessor.cpp @@ -130,7 +130,7 @@ std::vector AttrValueAccessor>::Attr(const AttrValue& template<> void AttrValueAccessor>::Attr(const std::vector& cpp_val, AttrValue* attr_val) { - if (attr_val->at_list_shape().val_size() > 0) { attr_val->mutable_at_list_shape()->clear_val(); } + attr_val->mutable_at_list_shape()->clear_val(); FOR_RANGE(int32_t, i, 0, cpp_val.size()) { cpp_val.at(i).ToProto(attr_val->mutable_at_list_shape()->add_val()); } @@ -176,8 +176,8 @@ Maybe MakeCppAttrValueFromProtoOrCfgAttrValue(const ProtoT& cfg_attr_va // clang-format off #define MAKE_ENTRY(field, cpp_type, attr_type) \ } \ - else if (dynamic_cast*>(&cpp_attr_value) != nullptr) { \ - const auto* ptr = dynamic_cast*>(&cpp_attr_value); \ + else if (dynamic_cast*>(&cpp_attr_value) != nullptr) { \ + const auto* ptr = dynamic_cast*>(&cpp_attr_value); \ AttrValueAccessor::Attr(ptr->val(), attr_value); OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ATTR_SEQ); #undef MAKE_ENTRY diff --git a/oneflow/core/framework/consistency_check.cpp b/oneflow/core/framework/consistency_check.cpp new file mode 100644 index 00000000000..ee99f35fd43 --- /dev/null +++ b/oneflow/core/framework/consistency_check.cpp @@ -0,0 +1,255 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include +#include "oneflow/core/framework/consistency_check.h" +#include "oneflow/core/intrusive/flat_msg.h" +#include "oneflow/core/job/rank_group.h" +#include "oneflow/core/framework/transport_util.h" +#include "oneflow/core/job/rank_group_scope.h" +#include "oneflow/core/framework/synced_symbol_map.h" +#include "oneflow/core/framework/sync_symbol_nd_sbp.h" +#include "oneflow/core/framework/sync_symbol_parallel_desc.h" + +namespace oneflow { + +namespace { + +class FlatMetaInfoConsistency; + +class CheckMetaInfoConsistencyAsyncTransportCtx : public AsyncTransportCtx { + public: + CheckMetaInfoConsistencyAsyncTransportCtx(const TransportToken& transport_token, + const Symbol& placement, + const Optional>& nd_sbp, + const Optional>& grad_nd_sbp) + : AsyncTransportCtx(transport_token), + placement_(placement), + nd_sbp_(nd_sbp), + grad_nd_sbp_(grad_nd_sbp) {} + + ~CheckMetaInfoConsistencyAsyncTransportCtx() override = default; + + Maybe PrepareSendBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, + std::function* Callback) override; + + Maybe PrepareRecvBufferAndCallback(int64_t rank, void** buffer, std::size_t* size, + std::function* Callback) override; + + Maybe Check() const; + + private: + Symbol placement_; + Optional> nd_sbp_; + Optional> grad_nd_sbp_; + std::shared_ptr flat_meta_info_consistency_; +}; + +// clang-format off +FLAT_MSG_BEGIN(FlatMetaInfoConsistency); + public: + static Maybe New() { + const auto& consistency = std::make_shared(); + consistency->clear(); + return consistency; + } + static Maybe New(const Symbol& placement, + const Optional>& nd_sbp, const Optional>& grad_nd_sbp) { + const auto& consistency = std::make_shared(); + consistency->clear(); + JUST(consistency->Init(placement, nd_sbp, grad_nd_sbp)); + return consistency; + } + + Maybe Check(const Symbol& placement, + const Optional>& nd_sbp, const Optional>& grad_nd_sbp) { + + const auto& this_placement = + JUST(SyncedSymbolMap::Symbol4SyncedSymbolId( + this->placement_symbol_id())); + CHECK_OR_RETURN(this_placement == placement) << "Each rank must have same input placement"; + CHECK_EQ_OR_RETURN(nd_sbp.has_value(), this->has_nd_sbp_symbol_id()); + if (this->has_nd_sbp_symbol_id()) { + const auto& that_nd_sbp = + JUST(SyncedSymbolMap::Symbol4SyncedSymbolId( + this->nd_sbp_symbol_id())); + const auto& this_nd_sbp = JUST(nd_sbp); + CHECK_OR_RETURN(this_nd_sbp == that_nd_sbp) << "Each rank must have same input sbp"; + } + CHECK_EQ_OR_RETURN(grad_nd_sbp.has_value(), this->has_grad_nd_sbp_symbol_id()); + if (this->has_grad_nd_sbp_symbol_id()) { + const auto& that_grad_nd_sbp = + JUST(SyncedSymbolMap::Symbol4SyncedSymbolId( + this->grad_nd_sbp_symbol_id())); + const auto& this_grad_nd_sbp = JUST(grad_nd_sbp); + CHECK_OR_RETURN(this_grad_nd_sbp == that_grad_nd_sbp) << "Each rank must have same input grad sbp"; + } + return Maybe::Ok(); + } + private: + Maybe Init(const Symbol& placement, const Optional>& nd_sbp, + const Optional>& grad_nd_sbp) { + this->set_placement_symbol_id( + JUST(SyncedSymbolMap::FindOrSync(placement, &SyncSymbolParallelDesc))); + if (nd_sbp.has_value()) { + this->set_nd_sbp_symbol_id( + JUST(SyncedSymbolMap::FindOrSync(JUST(nd_sbp), &SyncSymbolNdSbp))); + } + if (grad_nd_sbp.has_value()) { + this->set_grad_nd_sbp_symbol_id( + JUST(SyncedSymbolMap::FindOrSync(JUST(grad_nd_sbp), &SyncSymbolNdSbp))); + } + return Maybe::Ok(); + } + FLAT_MSG_DEFINE_OPTIONAL(uint64_t, placement_symbol_id); + FLAT_MSG_DEFINE_OPTIONAL(uint64_t, nd_sbp_symbol_id); + 
FLAT_MSG_DEFINE_OPTIONAL(uint64_t, grad_nd_sbp_symbol_id); +FLAT_MSG_END(FlatMetaInfoConsistency); +// clang-format off + +Maybe CheckMetaInfoConsistencyAsyncTransportCtx::PrepareSendBufferAndCallback( + int64_t rank, void** buffer, std::size_t* size, std::function* Callback) { + const auto& meta_info_consistency = + JUST(FlatMetaInfoConsistency::New(placement_, nd_sbp_, grad_nd_sbp_)); + *buffer = meta_info_consistency.get(); + *size = sizeof(FlatMetaInfoConsistency); + *Callback = [meta_info_consistency] {}; + return Maybe::Ok(); +} + +Maybe CheckMetaInfoConsistencyAsyncTransportCtx::PrepareRecvBufferAndCallback( + int64_t rank, void** buffer, std::size_t* size, std::function* Callback) { + const auto& flat_meta_info_consistency = JUST(FlatMetaInfoConsistency::New()); + *buffer = flat_meta_info_consistency.get(); + *size = sizeof(FlatMetaInfoConsistency); + *Callback = [flat_meta_info_consistency]() {}; + flat_meta_info_consistency_ = flat_meta_info_consistency; + return Maybe::Ok(); +} + +Maybe CheckMetaInfoConsistencyAsyncTransportCtx::Check() const { + if (!flat_meta_info_consistency_) { return Maybe::Ok(); } + JUST(flat_meta_info_consistency_->Check(placement_, nd_sbp_, grad_nd_sbp_)); + return Maybe::Ok(); +} + +} // namespace + +Maybe DataConsistencyCheck(const void* buffer_ptr, size_t buffer_size, + Symbol placement) { + const auto& rank_group = JUST(RankGroup::New(placement)); + + std::vector recv_buffer(buffer_size); + char* recv_ptr = recv_buffer.data(); + + TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData)); + NaiveAsyncTransportCtx ctx( + transport_token, + [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { + *buffer = const_cast(buffer_ptr); + *size = buffer_size; + *Cb = [] {}; + return Maybe::Ok(); + }, + [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { + *buffer = recv_ptr; + *size = buffer_size; + *Cb = [] {}; + return Maybe::Ok(); + }); + JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); + JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); + JUST(TransportUtil::WaitUntilDoneOrTimeout(ctx, TransportUtil::TimeoutSeconds())); + CHECK_OR_RETURN(std::memcmp(buffer_ptr, reinterpret_cast(recv_ptr), buffer_size) + == 0) + << "Each rank must have same input sequence or numpy array"; + return Maybe::Ok(); +} + +namespace { + +Maybe MetaInfoConsistencyCheckUtil(const Symbol& placement, + const Optional>& nd_sbp, const Optional>& grad_nd_sbp) { + const auto& rank_group = JUST(RankGroupScope::CurrentRankGroup()); + const auto& transport_token = + JUST(TransportToken::NewTransportToken(kTransportTokenTypeCheckRankGroupConsistency)); + const auto& ctx = std::make_shared( + transport_token, placement, nd_sbp, grad_nd_sbp); + JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, ctx.get())); + JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, ctx.get())); + JUST(TransportUtil::WaitUntilDoneOrTimeout(*ctx, TransportUtil::TimeoutSeconds())); + JUST(ctx->Check()); + return Maybe::Ok(); +} + +int64_t* MutThreadLocalMetaInfoConsistencyCheckDepth() { + static thread_local int64_t recursive_depth = 0; + return &recursive_depth; +} + +inline bool IsMetaInfoConsistencyCheckDisable() { + return *MutThreadLocalMetaInfoConsistencyCheckDepth() > 1; +} + +} // namespace + +NonRecursiveMetaInfoConsistencyCheckScope::NonRecursiveMetaInfoConsistencyCheckScope() { + auto* recursive_depth = 
MutThreadLocalMetaInfoConsistencyCheckDepth(); + ++*recursive_depth; +} + +NonRecursiveMetaInfoConsistencyCheckScope::~NonRecursiveMetaInfoConsistencyCheckScope() { + auto* recursive_depth = MutThreadLocalMetaInfoConsistencyCheckDepth(); + --*recursive_depth; +} + +Maybe MetaInfoConsistencyCheck(const Symbol& placement, + const Optional>& nd_sbp, + const Optional>& grad_nd_sbp) { + if (!IsMetaInfoConsistencyCheckDisable()) { + JUST(MetaInfoConsistencyCheckUtil(placement, nd_sbp, grad_nd_sbp)); + } + return Maybe::Ok(); +} + +Maybe MetaInfoConsistencyCheck(const Symbol& placement, + const Optional>& nd_sbp) { + if (!IsMetaInfoConsistencyCheckDisable()) { + JUST(MetaInfoConsistencyCheckUtil(placement, nd_sbp, Optional>())); + } + return Maybe::Ok(); +} + +Maybe MetaInfoConsistencyCheck(const Symbol& placement, + const std::vector>& sbp_tuple, + const std::vector>& grad_sbp_tuple) { + Optional> nd_sbp; + Optional> grad_nd_sbp; + if (!sbp_tuple.empty()) { grad_nd_sbp = JUST(GetNdSbp(sbp_tuple)); } + if (!grad_sbp_tuple.empty()) { grad_nd_sbp = JUST(GetNdSbp(grad_sbp_tuple)); } + JUST(MetaInfoConsistencyCheck(placement, nd_sbp, grad_nd_sbp)); + return Maybe::Ok(); +} + +Maybe MetaInfoConsistencyCheck(const Symbol& placement, + const std::vector>& sbp_tuple) { + Optional> nd_sbp; + Optional> grad_nd_sbp; + if (!sbp_tuple.empty()) { grad_nd_sbp = JUST(GetNdSbp(sbp_tuple)); } + JUST(MetaInfoConsistencyCheck(placement, nd_sbp, grad_nd_sbp)); + return Maybe::Ok(); +} + +} // namespace oneflow diff --git a/oneflow/core/framework/consistency_check.h b/oneflow/core/framework/consistency_check.h new file mode 100644 index 00000000000..aba5ef46f48 --- /dev/null +++ b/oneflow/core/framework/consistency_check.h @@ -0,0 +1,52 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_DATA_CONSISTENCY_CHECK_H_ +#define ONEFLOW_CORE_FRAMEWORK_DATA_CONSISTENCY_CHECK_H_ + +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/framework/nd_sbp.h" + +namespace oneflow { + +class NonRecursiveMetaInfoConsistencyCheckScope final { + public: + OF_DISALLOW_COPY_AND_MOVE(NonRecursiveMetaInfoConsistencyCheckScope); + NonRecursiveMetaInfoConsistencyCheckScope(); + ~NonRecursiveMetaInfoConsistencyCheckScope(); +}; + +Maybe DataConsistencyCheck(const void* buffer_ptr, size_t buffer_size, + Symbol placement); + +Maybe MetaInfoConsistencyCheck(const Symbol& placement, + const Optional>& nd_sbp, + const Optional>& grad_nd_sbp); + +Maybe MetaInfoConsistencyCheck(const Symbol& placement, + const Optional>& nd_sbp); + +Maybe MetaInfoConsistencyCheck(const Symbol& placement, + const std::vector>& sbp_tuple, + const std::vector>& grad_sbp_tuple); + +Maybe MetaInfoConsistencyCheck(const Symbol& placement, + const std::vector>& sbp_tuple); + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_DATA_CONSISTENCY_CHECK_H_ diff --git a/oneflow/core/framework/data_consistency_check.cpp b/oneflow/core/framework/data_consistency_check.cpp deleted file mode 100644 index 1cf224d8cd3..00000000000 --- a/oneflow/core/framework/data_consistency_check.cpp +++ /dev/null @@ -1,54 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ -#include -#include "oneflow/core/framework/data_consistency_check.h" -#include "oneflow/core/job/rank_group.h" -#include "oneflow/core/framework/transport_util.h" - -namespace oneflow { - -Maybe DataConsistencyCheck(const void* buffer_ptr, size_t buffer_size, - Symbol placement) { - const auto& rank_group = JUST(RankGroup::New(placement)); - - std::vector recv_buffer(buffer_size); - char* recv_ptr = recv_buffer.data(); - - TransportToken transport_token = JUST(TransportToken::NewTransportToken(kTransportTokenTypeData)); - NaiveAsyncTransportCtx ctx( - transport_token, - [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { - *buffer = const_cast(buffer_ptr); - *size = buffer_size; - *Cb = [] {}; - return Maybe::Ok(); - }, - [&](void** buffer, std::size_t* size, std::function* Cb) -> Maybe { - *buffer = recv_ptr; - *size = buffer_size; - *Cb = [] {}; - return Maybe::Ok(); - }); - JUST(TransportUtil::SendToNextRankInRing(rank_group, transport_token, &ctx)); - JUST(TransportUtil::ReceiveFromPrevRankInRing(rank_group, transport_token, &ctx)); - JUST(TransportUtil::WaitUntilDoneOrTimeout(ctx, TransportUtil::TimeoutSeconds())); - CHECK_OR_RETURN(std::memcmp(buffer_ptr, reinterpret_cast(recv_ptr), buffer_size) - == 0) - << "Each rank must have same input sequence or numpy array"; - return Maybe::Ok(); -} - -} // namespace oneflow diff --git a/oneflow/core/framework/device.cpp b/oneflow/core/framework/device.cpp index 52bb9a8e5d2..2dd191ec638 100644 --- a/oneflow/core/framework/device.cpp +++ b/oneflow/core/framework/device.cpp @@ -36,6 +36,14 @@ inline size_t HashDevice(const std::string& type, int64_t device_id) { return std::hash()(type) ^ std::hash()(device_id); } +void CheckDeviceType(const std::string& type) { + if (Device::type_supported.find(type) == Device::type_supported.end()) { + std::string error_msg = + "Expected one of cpu, cuda device type at start of device string " + type; + throw std::runtime_error(error_msg); + } +} + } // namespace Device::Device(const std::string& type, int64_t device_id) @@ -83,6 +91,19 @@ Maybe Device::Init() { return New(type, GlobalProcessCtx::LocalRank()); } +/* static */ Maybe> Device::ParseAndNew( + const std::string& type_or_type_with_device_id) { + std::string type; + int device_id = -1; + JUST(ParsingDeviceTag(type_or_type_with_device_id, &type, &device_id)); + CheckDeviceType(type); + if (device_id == -1) { + return Device::New(type); + } else { + return Device::New(type, device_id); + } +} + Maybe Device::of_type() const { static const HashMap type2device_tag{ {"cpu", "cpu"}, diff --git a/oneflow/core/framework/device.h b/oneflow/core/framework/device.h index 6ff3c50250f..399dcb6bd3b 100644 --- a/oneflow/core/framework/device.h +++ b/oneflow/core/framework/device.h @@ -30,8 +30,8 @@ class ParallelDesc; class MemoryCase; class LocalDepObject; -inline size_t GetInstructionHighWaterMark() { return 500; } -inline size_t GetInstructionLowWaterMark() { return 200; } +inline size_t GetInstructionHighWaterMark() { return 3000; } +inline size_t GetInstructionLowWaterMark() { return 1000; } class Device final { public: @@ -57,6 +57,7 @@ class Device final { static Maybe> ThreadLocalGetOrNew(const std::string& type, int64_t device_id); static Maybe> New(const std::string& type, int64_t device_id); static Maybe> New(const std::string& type); + static Maybe> ParseAndNew(const std::string& type_or_type_with_device_id); static Maybe> MakeDeviceByParallelDesc(const ParallelDesc& parallel_desc); static const std::unordered_set type_supported; 
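Device::ParseAndNew above accepts either a bare device type or a "type:index" string, validates the type against cpu/cuda, and treats a missing index as "use the local rank" (that is what the existing Device::New(type) overload does). A self-contained sketch of that parsing convention follows; ParseDeviceTag is a hypothetical stand-in for the ParsingDeviceTag helper, whose exact implementation is not shown in this hunk, and the error message is copied from CheckDeviceType:

#include <iostream>
#include <stdexcept>
#include <string>

void ParseDeviceTag(const std::string& tag, std::string* type, int* device_id) {
  const size_t pos = tag.find(':');
  if (pos == std::string::npos) {
    *type = tag;
    *device_id = -1;  // -1 means "use the local rank as the device index"
  } else {
    *type = tag.substr(0, pos);
    *device_id = std::stoi(tag.substr(pos + 1));
  }
  if (*type != "cpu" && *type != "cuda") {
    throw std::runtime_error("Expected one of cpu, cuda device type at start of device string "
                             + tag);
  }
}

int main() {
  std::string type;
  int device_id = 0;
  ParseDeviceTag("cuda:1", &type, &device_id);
  std::cout << type << " " << device_id << "\n";  // cuda 1
  ParseDeviceTag("cpu", &type, &device_id);
  std::cout << type << " " << device_id << "\n";  // cpu -1
  return 0;
}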
diff --git a/oneflow/core/framework/dtype.cpp b/oneflow/core/framework/dtype.cpp index f01e11b77e1..598033ad2e3 100644 --- a/oneflow/core/framework/dtype.cpp +++ b/oneflow/core/framework/dtype.cpp @@ -104,24 +104,33 @@ bool DType::is_complex() const { return CHECK_JUST(DTypeMeta4DataType(data_type_ /* The order of datatype is: - 0 1 2 3 4 5 6 7 8 9 10 11 - iv c1 f4 f8 i1 i4 i8 u1 re f2 bu bf - The priority order of datatype is: - 0 1 2 3 4 5 6 7 8 9 10 11 - iv < u1 < c1 < i1 < i4 < i8 < f2 < f4 < f8 < bf < re < bu. + 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 + 20 iv c1 f4 f8 i1 i4 i8 u1 re f2 bu bf b1 u4 u8 u16 i2 i16 cp4 + cp8 cp16 The priority order of datatype is: 0 1 2 3 4 5 6 7 8 9 10 + 11 12 13 14 15 16 17 18 19 20 iv < b1 < u1 < c1 < i1 < i2 < u4 < i4 < u8 < + i8 < u16 < i16 < f2 < f4 < f8 < cp4 < cp8 < cp16 < bf < re < bu. */ -const int DType::priority_order[DataType::kMaxDataType] = {0, /*kInvalid*/ - 2, /*kChar*/ - 7, /*kFloat32*/ - 8, /*kDouble*/ - 3, /*kInt8*/ - 4, /*kInt32*/ - 5, /*kInt64*/ - 1, /*kUInt8*/ - 10, /*kOFRecord*/ - 6, /*kFloat16*/ - 11, /*kTensorBuffer*/ - 9 /*kBFloat16*/}; +const int DType::priority_order[DataType_ARRAYSIZE] = {0, /*kInvalid*/ + 3, /*kChar*/ + 13, /*kFloat32*/ + 14, /*kDouble*/ + 4, /*kInt8*/ + 7, /*kInt32*/ + 9, /*kInt64*/ + 2, /*kUInt8*/ + 19, /*kOFRecord*/ + 12, /*kFloat16*/ + 20, /*kTensorBuffer*/ + 18, /*kBFloat16*/ + 1, /*kBool*/ + 6, /*kUint32*/ + 8, /*kUint64*/ + 10, /*kUint128*/ + 5, /*kInt16*/ + 11, /*kInt128*/ + 15, /*kComplex32*/ + 16, /*kComplex64*/ + 17 /*kComplex128*/}; bool DType::is_floating_point() const { return CHECK_JUST(DTypeMeta4DataType(data_type_)).is_floating_point(); @@ -150,6 +159,16 @@ Symbol promoteTypes(const Symbol a, const Symbol b) { const Symbol f2 = CHECK_JUST(DType::Get(DataType::kFloat16)); const Symbol bu = CHECK_JUST(DType::Get(DataType::kTensorBuffer)); const Symbol bf = CHECK_JUST(DType::Get(DataType::kBFloat16)); + const Symbol b1 = CHECK_JUST(DType::Get(DataType::kBool)); + const Symbol u2 = CHECK_JUST(DType::Get(DataType::kUInt16)); + const Symbol u4 = CHECK_JUST(DType::Get(DataType::kUInt32)); + const Symbol u8 = CHECK_JUST(DType::Get(DataType::kUInt64)); + const Symbol u16 = CHECK_JUST(DType::Get(DataType::kUInt128)); + const Symbol i2 = CHECK_JUST(DType::Get(DataType::kInt16)); + const Symbol i16 = CHECK_JUST(DType::Get(DataType::kInt128)); + const Symbol cp4 = CHECK_JUST(DType::Get(DataType::kComplex32)); + const Symbol cp8 = CHECK_JUST(DType::Get(DataType::kComplex64)); + const Symbol cp16 = CHECK_JUST(DType::Get(DataType::kComplex128)); /* It is consistent with data_type.proto(except kInvalidDataType, kOFRecord and kTensorBuffer) kInvalidDataType = 0; @@ -164,29 +183,54 @@ Symbol promoteTypes(const Symbol a, const Symbol b) { kFloat16 = 9; kTensorBuffer = 10; kBFloat16 = 11; + kBool = 12; + kUInt16 = 13; + kUInt32 = 14; + kUInt64 = 15; + kUInt128 = 16; + kInt16 = 17; + kInt128 = 18; + kComplex32 = 19; + kComplex64 = 20; + kComplex128 = 21; The priority order of datatype is: - iv < u1 < c1 < i1 < i4 < i8 < f2 < f4 < f8 < bf < re < bu. + iv < b1 < u1 < c1 < i1 < u2 < i2 < u4 < i4 < u8 < i8 < u16 < i16 < f2 < f4 < f8 < cp4 < cp8 < + cp16 < bf < re < bu. + + When int8 + uint8, it need to promote to int16, etc. + But in int8 + uint128, we should promote to int256, but it is not exist, so we set as Invalid. The new DataType should be add in the end of proto, and the Loopup table should be maintained as right priority (author:zhengzekang). 
*/ - static const Symbol _promoteTypesLookup[DataType::kMaxDataType][DataType::kMaxDataType] = { - /* iv c1 f4 f8 i1 i4 i8 u1 re f2 bu bf */ - /* iv */ {iv, c1, f4, f8, i1, i4, i8, u1, re, f2, bu, bf}, - /* c1 */ {c1, c1, f4, f8, i1, i4, i8, c1, re, f2, bu, bf}, - /* f4 */ {f4, f4, f4, f8, f4, f4, f4, f4, re, f4, bu, bf}, - /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, re, f8, bu, bf}, - /* i1 */ {i1, i1, f4, f8, i1, i4, i8, i1, re, f2, bu, bf}, - /* i4 */ {i4, i4, f4, f8, i4, i4, i8, i4, re, f2, bu, bf}, - /* i8 */ {i8, i8, f4, f8, i8, i8, i8, i8, re, f2, bu, bf}, - /* u1 */ {u1, c1, f4, f8, i1, i4, i8, u1, re, f2, bu, bf}, - /* re */ {re, re, re, re, re, re, re, re, re, re, bu, re}, - /* f2 */ {f2, f2, f4, f8, f2, f2, f2, f2, re, f2, bu, bf}, - /* bu */ {bu, bu, bu, bu, bu, bu, bu, bu, bu, bu, bu, bu}, - /* bf */ {bf, bf, bf, bf, bf, bf, bf, bf, re, bf, bu, bf}, - }; + // clang-format off + static const Symbol _promoteTypesLookup[DataType_ARRAYSIZE][DataType_ARRAYSIZE] = { + /* iv c1 f4 f8 i1 i4 i8 u1 re f2 bu bf b1 u2 u4 u8 u16 i2 i16 cp4 cp8 cp16 */ + /* iv */ {iv, c1, f4, f8, i1, i4, i8, u1, re, f2, bu, bf, b1, u2, u4, u8, u16, i2, i16, cp4, cp8, cp16}, + /* c1 */ {c1, c1, f4, f8, i1, i4, i8, c1, iv, f2, iv, bf, c1, u2, u4, u8, u16, i2, i16, iv, cp4, cp16}, + /* f4 */ {f4, f4, f4, f8, f4, f4, f4, f4, iv, f4, iv, bf, f4, f4, f4, f4, f4, f4, f4, iv, cp4, cp16}, + /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, iv, f8, iv, bf, f8, f8, f8, f8, f8, f8, f8, iv, cp4, cp16}, + /* i1 */ {i1, i1, f4, f8, i1, i4, i8, i2, iv, f2, iv, bf, i1, i4, i8, i16, iv, i2, i16, iv, cp4, cp16}, + /* i4 */ {i4, i4, f4, f8, i4, i4, i8, i4, iv, f2, iv, bf, i4, i4, i8, i16, iv, i4, i16, iv, cp4, cp16}, + /* i8 */ {i8, i8, f4, f8, i8, i8, i8, i8, iv, f2, iv, bf, i8, i8, i8, i16, iv, i8, i16, iv, cp4, cp16}, + /* u1 */ {u1, c1, f4, f8, i2, i4, i8, u1, iv, f2, iv, bf, u1, u2, u4, u8, u16, i2, i16, iv, cp4, cp16}, + /* re */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv}, + /* f2 */ {f2, f2, f4, f8, f2, f2, f2, f2, iv, f2, iv, bf, f2, f2, f2, f2, iv, f2, f2, iv, cp4, cp16}, + /* bu */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, bu, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv}, + /* bf */ {bf, bf, bf, bf, bf, bf, bf, bf, iv, bf, iv, bf, bf, bf, bf, bf, iv, bf, bf, iv, cp4, cp16}, + /* b1 */ {c1, c1, f4, f8, i1, i4, i8, u1, iv, f2, iv, bf, b1, u2, u4, u8, u16, i2, i16, iv, cp4, cp16}, + /* u2 */ {u2, u2, f4, f8, i4, i4, i8, u2, iv, f2, iv, bf, u2, u2, u4, u8, u16, i4, i16, iv, cp4, cp16}, + /* u4 */ {u4, u4, f4, f8, i8, i8, i8, u4, iv, f2, iv, bf, u4, u4, u4, u8, u16, i8, i16, iv, cp4, cp16}, + /* u8 */ {u8, u8, f4, f8, i16, i16, i16, u8, iv, f2, iv, bf, u8, u8, u8, u8, u16, i16, i16, iv, cp4, cp16}, + /* u16 */ {u16, u16, f4, f8, iv, iv, iv, u16, iv, f2, iv, bf, u16, u16, u16, u16, u16, iv, iv, iv, cp4, cp16}, + /* i2 */ {i2, i2, f4, f8, i2, i4, i8, i2, iv, f2, iv, bf, i2, i4, i8, i16, iv, i2, i16, iv, cp4, cp16}, + /* i16 */ {i16, i16, f4, f8, i16, i16, i16, i16, iv, f2, iv, bf, i16, i16, i16, i16, iv, i16, i16, iv, cp4, cp16}, + /* cp4 */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, cp4, cp8, cp16}, + /* cp8 */ {cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, iv, cp8, iv, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp16}, + /* cp16 */ {cp16,cp16,cp16,cp16,cp16,cp16,cp16,cp16,iv, cp16,iv, cp16,cp16,cp16,cp16,cp16,cp16, cp16,cp16, cp16, cp16, cp16}}; + // clang-format on return _promoteTypesLookup[static_cast(a->data_type())][static_cast(b->data_type())]; } 
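The extended table keeps the original scheme: promotion is a plain 2D lookup indexed by the two operand dtypes, with kInvalid marking pairs that have no representable result (for example int8 + uint128, as the comment above notes). Below is a reduced, self-contained slice covering only bool/uint8/int8/int16, transcribed from the corresponding rows of the table; it is an illustration outside the patch, not the full lookup:

#include <iostream>

enum MiniDType { kB1 = 0, kU1 = 1, kI1 = 2, kI2 = 3 };

const char* kNames[4] = {"bool", "uint8", "int8", "int16"};

// kPromote[a][b]: e.g. int8 + uint8 widens to int16 so neither operand's range is lost.
const MiniDType kPromote[4][4] = {
    /* b1 */ {kB1, kU1, kI1, kI2},
    /* u1 */ {kU1, kU1, kI2, kI2},
    /* i1 */ {kI1, kI2, kI1, kI2},
    /* i2 */ {kI2, kI2, kI2, kI2},
};

int main() {
  std::cout << kNames[kPromote[kI1][kU1]] << "\n";  // int16
  std::cout << kNames[kPromote[kB1][kU1]] << "\n";  // uint8
  return 0;
}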
diff --git a/oneflow/core/framework/dtype.h b/oneflow/core/framework/dtype.h index 6cd9a6437cd..39934075bbf 100644 --- a/oneflow/core/framework/dtype.h +++ b/oneflow/core/framework/dtype.h @@ -36,7 +36,16 @@ namespace oneflow { OF_PP_MAKE_TUPLE_SEQ(UInt8) \ OF_PP_MAKE_TUPLE_SEQ(OFRecord) \ OF_PP_MAKE_TUPLE_SEQ(TensorBuffer) \ - OF_PP_MAKE_TUPLE_SEQ(BFloat16) + OF_PP_MAKE_TUPLE_SEQ(BFloat16) \ + OF_PP_MAKE_TUPLE_SEQ(UInt16) \ + OF_PP_MAKE_TUPLE_SEQ(UInt32) \ + OF_PP_MAKE_TUPLE_SEQ(UInt64) \ + OF_PP_MAKE_TUPLE_SEQ(UInt128) \ + OF_PP_MAKE_TUPLE_SEQ(Int16) \ + OF_PP_MAKE_TUPLE_SEQ(Int128) \ + OF_PP_MAKE_TUPLE_SEQ(Complex32) \ + OF_PP_MAKE_TUPLE_SEQ(Complex64) \ + OF_PP_MAKE_TUPLE_SEQ(Complex128) class DType final { public: @@ -55,7 +64,7 @@ class DType final { Maybe bytes() const; static Maybe&> Get(DataType); - static const int priority_order[DataType::kMaxDataType]; + static const int priority_order[DataType_ARRAYSIZE]; #define DECLARE_GET_DATA_TYPE_FUNCTION(data_type) static const Symbol& data_type(); OF_PP_FOR_EACH_TUPLE(DECLARE_GET_DATA_TYPE_FUNCTION, DTYPE_SEQ) diff --git a/oneflow/core/framework/infer_util.cpp b/oneflow/core/framework/infer_util.cpp index 9c7226060d2..599f6a9070d 100644 --- a/oneflow/core/framework/infer_util.cpp +++ b/oneflow/core/framework/infer_util.cpp @@ -31,11 +31,10 @@ Maybe TensorDescInferFnUtil::Unchanged(InferContext* ctx) { for (size_t i = 0; i < ctx->inputs().size(); ++i) { const std::pair& input_arg = ctx->inputs().at(i); if (first_tensor_desc) { - const TensorDesc* tensor_desc = - ctx->TensorDesc4ArgNameAndIndex(input_arg.first, input_arg.second); - CHECK_EQ_OR_RETURN(tensor_desc->shape(), first_tensor_desc->shape()); + const TensorDesc& tensor_desc = ctx->InputTensorDesc(input_arg.first, input_arg.second); + CHECK_EQ_OR_RETURN(tensor_desc.shape(), first_tensor_desc->shape()); } else { - first_tensor_desc = ctx->TensorDesc4ArgNameAndIndex(input_arg.first, input_arg.second); + first_tensor_desc = &ctx->InputTensorDesc(input_arg.first, input_arg.second); } } for (size_t i = 0; i < ctx->outputs().size(); ++i) { @@ -51,11 +50,10 @@ Maybe TensorDescInferFnUtil::UnchangedDataType(InferContext* ctx) { for (size_t i = 0; i < ctx->inputs().size(); ++i) { const std::pair& input_arg = ctx->inputs().at(i); if (first_tensor_desc) { - const TensorDesc* tensor_desc = - ctx->TensorDesc4ArgNameAndIndex(input_arg.first, input_arg.second); - CHECK_EQ_OR_RETURN(tensor_desc->data_type(), first_tensor_desc->data_type()); + const TensorDesc& tensor_desc = ctx->InputTensorDesc(input_arg.first, input_arg.second); + CHECK_EQ_OR_RETURN(tensor_desc.data_type(), first_tensor_desc->data_type()); } else { - first_tensor_desc = ctx->TensorDesc4ArgNameAndIndex(input_arg.first, input_arg.second); + first_tensor_desc = &ctx->InputTensorDesc(input_arg.first, input_arg.second); } } for (size_t i = 0; i < ctx->outputs().size(); ++i) { @@ -71,7 +69,7 @@ Maybe TensorDescInferFnUtil::InOutCorrespond(InferContext* ctx) { const auto& input_arg = ctx->inputs().at(i); const auto& output_arg = ctx->outputs().at(i); *ctx->OutputTensorDesc(output_arg.first, output_arg.second) = - *ctx->TensorDesc4ArgNameAndIndex(input_arg.first, input_arg.second); + ctx->InputTensorDesc(input_arg.first, input_arg.second); } return Maybe::Ok(); } diff --git a/oneflow/core/framework/infer_util.h b/oneflow/core/framework/infer_util.h index cf763fbdb1d..da24849cafe 100644 --- a/oneflow/core/framework/infer_util.h +++ b/oneflow/core/framework/infer_util.h @@ -40,7 +40,6 @@ class InferContext { virtual const TensorDesc& 
InputTensorDesc(const std::string&, int32_t) const = 0; virtual TensorDesc* OutputTensorDesc(const std::string&, int32_t) = 0; - virtual TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) = 0; virtual const TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0; virtual const Shape& InputShape(const std::string&, int32_t) const = 0; diff --git a/oneflow/core/framework/instructions_builder.cpp b/oneflow/core/framework/instructions_builder.cpp index b5ba1df875c..b6cc6d751c5 100644 --- a/oneflow/core/framework/instructions_builder.cpp +++ b/oneflow/core/framework/instructions_builder.cpp @@ -44,6 +44,8 @@ limitations under the License. #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/instruction_replay.h" +#include "oneflow/core/job/env_desc.h" +#include "oneflow/core/profiler/profiler.h" #include "oneflow/core/vm/tensor_view_operand.h" namespace oneflow { @@ -254,31 +256,67 @@ static constexpr auto* GetCriticalSectionDevice = } // namespace -template -Maybe> InstructionsBuilder::MakeCriticalSectionBegin( - const one::EagerBlobObjectListPtr& eager_blob_objects) { - const auto& device = JUST(GetCriticalSectionDevice()); - const auto local_dep_object = JUST(LocalDepObject::New(*device)); - const auto& phy_instr_operand = std::make_shared(eager_blob_objects, *local_dep_object); +template +Maybe InstructionsBuilder::MakeCriticalSectionBegin( + const std::shared_ptr& phy_instr_operand) { auto instruction = intrusive::make_shared( Global::Get()->mut_vm(), "CriticalSectionBegin", std::shared_ptr(), phy_instr_operand); instruction_list_->EmplaceBack(std::move(instruction)); - return local_dep_object; + return Maybe::Ok(); } template Maybe InstructionsBuilder::MakeCriticalSectionEnd( - const std::shared_ptr& eager_blob_object, - const std::shared_ptr& event_record) { - const auto& operand = std::make_shared(eager_blob_object, event_record); + const std::shared_ptr& phy_instr_operand) { auto instruction = intrusive::make_shared( Global::Get()->mut_vm(), "CriticalSectionEnd", - std::shared_ptr(), operand); + std::shared_ptr(), phy_instr_operand); instruction_list_->EmplaceBack(std::move(instruction)); return Maybe::Ok(); } +// clang-format off +// Job e.g.: +// [wait_and_send_ids] +// | +// V +// | +// +-------------------+ +// | | +// V [cpu_decoder] +// | | +// [critcial_section_wait] V +// | | +// V [forward_ops...] +// | | +// | V +// +-------------------+ +// | +// [copy_loss] +// | +// +-----------------------+ +// | | +// V V +// | | +// [backward_ops...] | +// | | +// V [critical_section_callback] +// | | +// [optimizer_ops...] V +// | | +// V | +// | | +// +-----------------------+ +// | +// [callback_notifier] +// +// +// clang-format on +// critcial_section_wait is a blocking opkernel which waits tick signal from instruction +// CriticalSectionBegin. +// critical_section_callback is a non-blocking opkernel which notifies instruction +// CriticalSectionEnd done. 
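The diagram and note above describe the pairing: a blocking wait op inside the job waits for a tick from the CriticalSectionBegin instruction, and a callback op notifies the matching CriticalSectionEnd instruction through a per-op-name end-event record. Reduced to a self-contained analogy, such an event record is a one-shot flag that one side marks done and the other side blocks on. HypotheticalEventRecord below is a stand-in for illustration only, not OneFlow's SharedEventRecord API:

#include <condition_variable>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>

class HypotheticalEventRecord {
 public:
  // What a CriticalSectionEnd-style notifier would call once its op has run.
  void MarkDone() {
    std::lock_guard<std::mutex> lock(mu_);
    done_ = true;
    cv_.notify_all();
  }
  // What a consumer waiting on the critical section would call.
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return done_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool done_ = false;
};

int main() {
  auto record = std::make_shared<HypotheticalEventRecord>();
  std::thread producer([record] { record->MarkDone(); });  // "critical section finished"
  record->Wait();                                          // blocks until the producer signals
  producer.join();
  std::cout << "critical section end observed\n";
  return 0;
}

The per-op-name map of such records is what lets the input and output critical sections complete independently instead of sharing one aggregated dependency, which is the point of splitting op_name2end_event_record into input and output maps below.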
Maybe InstructionsBuilder::LaunchLazyJob(const one::EagerBlobObjectListPtr& inputs, const one::EagerBlobObjectListPtr& outputs, const one::EagerBlobObjectListPtr& parameters, @@ -287,25 +325,36 @@ Maybe InstructionsBuilder::LaunchLazyJob(const one::EagerBlobObjectListPtr JUST(SoftSyncNNGraphBuffers(outputs, nn_graph)); JUST(SoftSyncNNGraphBuffers(parameters, nn_graph)); { - // instruction list: [CriticalSectionBegin] -> LaunchLazyJob -> [CriticalSectionEnd] - const auto& in_local_dep_object = - JUST(MakeCriticalSectionBegin(inputs)); - const auto& out_local_dep_object = - JUST(MakeCriticalSectionBegin(outputs)); - const auto& op_name2end_event_record = + // instruction chain: [CriticalSectionBegin] -> [CriticalSectionEnd] + // instructions LaunchLazyJob are launched independent from instruction chains + // [CriticalSectionBegin] -> [CriticalSectionEnd] + const auto& input_op_name2end_event_record = std::make_shared>>(); - for (const auto& op_name : nn_graph->inputs_op_names()) { - const auto& event_record = std::make_shared(); - CHECK_OR_RETURN(op_name2end_event_record->emplace(op_name, event_record).second); + { + for (const auto& op_name : nn_graph->inputs_op_names()) { + const auto& event_record = std::make_shared(); + CHECK_OR_RETURN(input_op_name2end_event_record->emplace(op_name, event_record).second); + } + const auto& phy_instr_operand = + std::make_shared( + nn_graph, inputs, input_op_name2end_event_record); + JUST(MakeCriticalSectionBegin(phy_instr_operand)); } - for (const auto& op_name : nn_graph->outputs_op_names()) { - const auto& event_record = std::make_shared(); - CHECK_OR_RETURN(op_name2end_event_record->emplace(op_name, event_record).second); + const auto& output_op_name2end_event_record = + std::make_shared>>(); + { + for (const auto& op_name : nn_graph->outputs_op_names()) { + const auto& event_record = std::make_shared(); + CHECK_OR_RETURN(output_op_name2end_event_record->emplace(op_name, event_record).second); + } + const auto& phy_instr_operand = + std::make_shared( + nn_graph, outputs, output_op_name2end_event_record); + JUST(MakeCriticalSectionBegin(phy_instr_operand)); } { - const auto& phy_instr_operand = std::make_shared( - *in_local_dep_object, *out_local_dep_object, op_name2end_event_record, inputs, outputs, - parameters, nn_graph); + const auto& phy_instr_operand = + std::make_shared(nn_graph, parameters); auto instruction = intrusive::make_shared( Global::Get()->mut_vm(), "LaunchLazyJob", std::shared_ptr(), phy_instr_operand); @@ -314,16 +363,18 @@ Maybe InstructionsBuilder::LaunchLazyJob(const one::EagerBlobObjectListPtr for (int i = 0; i < nn_graph->inputs_op_names().size(); ++i) { const auto& eager_blob_object = inputs->at(i); const auto& op_name = nn_graph->inputs_op_names().at(i); - const auto& event_record = JUST(MapAt(*op_name2end_event_record, op_name)); - JUST(MakeCriticalSectionEnd(eager_blob_object, - event_record)); + const auto& event_record = JUST(MapAt(*input_op_name2end_event_record, op_name)); + const auto& phy_instr_operand = std::make_shared( + eager_blob_object, event_record); + JUST(MakeCriticalSectionEnd(phy_instr_operand)); } for (int i = 0; i < nn_graph->outputs_op_names().size(); ++i) { const auto& eager_blob_object = outputs->at(i); const auto& op_name = nn_graph->outputs_op_names().at(i); - const auto& event_record = JUST(MapAt(*op_name2end_event_record, op_name)); - JUST(MakeCriticalSectionEnd(eager_blob_object, - event_record)); + const auto& event_record = JUST(MapAt(*output_op_name2end_event_record, op_name)); + const auto& 
phy_instr_operand = std::make_shared( + eager_blob_object, event_record); + JUST(MakeCriticalSectionEnd(phy_instr_operand)); } } return Maybe::Ok(); @@ -1037,9 +1088,8 @@ Maybe InstructionsBuilder::SoftSyncStream(LocalDepObject* compute_local_de const std::string& modifier, Symbol op_device) { if (!JUST(op_device->need_soft_sync_stream())) { return Maybe::Ok(); } - + OF_PROFILER_RANGE_PUSH("SoftStream"); const auto& parallel_desc = JUST(Placement4Device(op_device)).shared_from_symbol(); - { const auto& phy_instr_operand = std::make_shared( compute_local_dep_object, modifier); @@ -1055,6 +1105,7 @@ Maybe InstructionsBuilder::SoftSyncStream(LocalDepObject* compute_local_de Global::Get()->mut_vm(), "Touch", parallel_desc, phy_instr_operand); instruction_list_->EmplaceBack(std::move(instruction)); } + OF_PROFILER_RANGE_POP(); return Maybe::Ok(); } diff --git a/oneflow/core/framework/instructions_builder.h b/oneflow/core/framework/instructions_builder.h index fccb6516224..72b45e6855e 100644 --- a/oneflow/core/framework/instructions_builder.h +++ b/oneflow/core/framework/instructions_builder.h @@ -469,12 +469,10 @@ class InstructionsBuilder : public std::enable_shared_from_this - Maybe> MakeCriticalSectionBegin( - const one::EagerBlobObjectListPtr& eager_blob_objects); + Maybe MakeCriticalSectionBegin(const std::shared_ptr& phy_instr_operand); template - Maybe MakeCriticalSectionEnd(const std::shared_ptr& eager_blob_object, - const std::shared_ptr& event_record); + Maybe MakeCriticalSectionEnd(const std::shared_ptr& phy_instr_operand); std::shared_ptr id_generator_; vm::InstructionMsgList* instruction_list_; diff --git a/oneflow/core/framework/multi_client_session_context.cpp b/oneflow/core/framework/multi_client_session_context.cpp index 75b6a11a211..c183f7e90b3 100644 --- a/oneflow/core/framework/multi_client_session_context.cpp +++ b/oneflow/core/framework/multi_client_session_context.cpp @@ -22,6 +22,7 @@ limitations under the License. #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/job/job_instance.h" +#include "oneflow/core/job/critical_section_instance.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/job/runtime_context.h" #include "oneflow/core/job/runtime_job_descs.h" @@ -99,6 +100,7 @@ Maybe MultiClientSessionContext::TryInit(const ConfigProto& config_proto) { // NOTE(chengcheng): init runtime global objects Global>>::New(); + Global>>::New(); Global::New(); Global::New(); Global::New(); @@ -143,6 +145,7 @@ Maybe MultiClientSessionContext::TryClose() { Global::Delete(); Global::Delete(); Global::Delete(); + Global>>::Delete(); Global>>::Delete(); } diff --git a/oneflow/core/framework/nd_sbp.cpp b/oneflow/core/framework/nd_sbp.cpp index cf8dd59b67c..bf254453fda 100644 --- a/oneflow/core/framework/nd_sbp.cpp +++ b/oneflow/core/framework/nd_sbp.cpp @@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include #include "oneflow/core/common/util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/job/sbp_parallel.h" @@ -105,29 +104,33 @@ const std::vector>& GetNoneSbpList() { return none; } -Maybe SbpToString(Symbol sbp_sym) { - std::string sbp_str = "oneflow.sbp."; - if (sbp_sym->has_broadcast_parallel()) { - sbp_str += "broadcast"; - } else if (sbp_sym->has_partial_sum_parallel()) { - sbp_str += "partial_sum"; - } else if (sbp_sym->has_split_parallel()) { - sbp_str += "split(axis=" + std::to_string(sbp_sym->split_parallel().axis()) + ")"; +std::string SbpToString(Symbol sbp_sym) { return SbpToString(*sbp_sym); } + +std::string NdSbpToString(Symbol nd_sbp_sym) { return NdSbpToString(*nd_sbp_sym); } + +std::string SbpToString(const cfg::SbpParallel& sbp) { + std::ostringstream ss; + if (sbp.has_broadcast_parallel()) { + ss << "B"; + } else if (sbp.has_partial_sum_parallel()) { + ss << "P"; + } else if (sbp.has_split_parallel()) { + ss << "S(" << std::to_string(sbp.split_parallel().axis()) << ")"; } else { - UNIMPLEMENTED_THEN_RETURN(); + UNIMPLEMENTED(); } - return sbp_str; + return ss.str(); } -Maybe NdSbpToString(Symbol nd_sbp) { - std::string str = "("; - for (int i = 0; i < nd_sbp->sbp_parallel_size(); ++i) { - if (i > 0) { str += ", "; } - str += *JUST(SbpToString(SymbolOf(nd_sbp->sbp_parallel(i)))); +std::string NdSbpToString(const cfg::NdSbp& nd_sbp) { + std::ostringstream ss; + ss << "("; + for (size_t i = 0; i < nd_sbp.sbp_parallel_size(); ++i) { + if (i > 0) { ss << ", "; } + ss << SbpToString(nd_sbp.sbp_parallel(i)); } - if (nd_sbp->sbp_parallel_size() == 1) { str += ","; } - str += ")"; - return str; + ss << ")"; + return ss.str(); } } // namespace oneflow diff --git a/oneflow/core/framework/nd_sbp.h b/oneflow/core/framework/nd_sbp.h index 445caeafdee..5d1851f8f3d 100644 --- a/oneflow/core/framework/nd_sbp.h +++ b/oneflow/core/framework/nd_sbp.h @@ -49,8 +49,10 @@ static constexpr auto* GetNdSbp = DECORATE(&private_details::RawGetNdSbp, Thread static constexpr auto* GetSbpList = DECORATE(&private_details::RawGetSbpList, ThreadLocal); const std::vector>& GetNoneSbpList(); -Maybe SbpToString(Symbol sbp_sym); -Maybe NdSbpToString(Symbol nd_sbp); +std::string SbpToString(Symbol sbp_sym); +std::string NdSbpToString(Symbol nd_sbp_sym); +std::string SbpToString(const cfg::SbpParallel& sbp); +std::string NdSbpToString(const cfg::NdSbp& nd_sbp); } // namespace oneflow diff --git a/oneflow/core/framework/nn_graph.cpp b/oneflow/core/framework/nn_graph.cpp index e27d2098ce4..940d4fd1c6c 100644 --- a/oneflow/core/framework/nn_graph.cpp +++ b/oneflow/core/framework/nn_graph.cpp @@ -22,12 +22,14 @@ limitations under the License. 
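SbpToString/NdSbpToString now return plain std::string and print the compact "B" / "P" / "S(axis)" form instead of the previous "oneflow.sbp.*" strings; an nd-sbp renders as a parenthesized, comma-separated list such as "(S(0), B)". A self-contained sketch of that formatting, using a simplified MiniSbp stand-in for cfg::SbpParallel (an illustration, not the OneFlow types):

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

struct MiniSbp {
  enum Kind { kBroadcast, kPartialSum, kSplit } kind;
  int split_axis = 0;
};

std::string SbpToString(const MiniSbp& sbp) {
  switch (sbp.kind) {
    case MiniSbp::kBroadcast: return "B";
    case MiniSbp::kPartialSum: return "P";
    case MiniSbp::kSplit: return "S(" + std::to_string(sbp.split_axis) + ")";
  }
  return "?";
}

std::string NdSbpToString(const std::vector<MiniSbp>& nd_sbp) {
  std::ostringstream ss;
  ss << "(";
  for (size_t i = 0; i < nd_sbp.size(); ++i) {
    if (i > 0) { ss << ", "; }
    ss << SbpToString(nd_sbp[i]);
  }
  ss << ")";
  return ss.str();
}

int main() {
  std::cout << NdSbpToString({{MiniSbp::kSplit, 0}, {MiniSbp::kBroadcast}}) << "\n";  // (S(0), B)
  return 0;
}

Note that the old special case which appended a trailing comma for single-element tuples is dropped in the new helper, so a 1-D sbp now prints as "(B)" rather than "(...,)".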
#include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/multi_client_session_context.h" #include "oneflow/core/framework/nd_sbp.h" +#include "oneflow/core/framework/tensor_name_scope.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/graph/op_graph.h" #include "oneflow/core/job/compiler.h" #include "oneflow/core/job/job_build_and_infer_ctx_mgr.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/job_instance.h" +#include "oneflow/core/job/critical_section_instance.h" #include "oneflow/core/job/lazy_mode.h" #include "oneflow/core/job/plan_util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" @@ -55,7 +57,7 @@ Maybe GetTensorMetaString(const std::shared_ptr& tenso std::string ret = "shape=" + tensor->shape()->ToString() + ", dtype=" + tensor->dtype()->name(); if (tensor->is_consistent()) { ret += ", placement=" + *JUST(PlacementToString(JUST(tensor->parallel_desc()))); - ret += ", nd_sbp=" + *JUST(NdSbpToString(JUST(tensor->nd_sbp()))); + ret += ", nd_sbp=" + NdSbpToString(JUST(tensor->nd_sbp())); } else { ret += ", device=" + JUST(tensor->device())->ToString(); } @@ -81,9 +83,9 @@ Maybe NNGraph::Close() { return Maybe::Ok(); } -const std::vector& NNGraph::inputs_op_names() const { return input_op_names_; } +const std::vector& NNGraph::inputs_op_names() const { return inputs_op_names_; } -const std::vector& NNGraph::outputs_op_names() const { return output_op_names_; } +const std::vector& NNGraph::outputs_op_names() const { return outputs_op_names_; } const std::vector& NNGraph::inputs_valid() const { return input_tensors_valid_; } @@ -100,14 +102,14 @@ const std::vector& NNGraph::outputs_tensor_meta_str() const { int64_t NNGraph::variable_op_size() const { return variable_op_name2eager_blob_.size(); } Maybe NNGraph::RegisterInputOpNamesAndTensors( - const std::vector& input_op_names, + const std::vector& inputs_op_names, const std::vector>& input_tensors) { - CHECK_EQ_OR_RETURN(input_op_names.size(), input_tensors.size()); - CHECK_OR_RETURN(input_op_names_.empty()) + CHECK_EQ_OR_RETURN(inputs_op_names.size(), input_tensors.size()); + CHECK_OR_RETURN(inputs_op_names_.empty()) << " The input tensors of nn.Graph " << name_ << " are register repeatedly."; CHECK_OR_RETURN(input_tensors_valid_.empty()); CHECK_OR_RETURN(inputs_tensor_meta_str_.empty()); - input_op_names_.assign(input_op_names.begin(), input_op_names.end()); + inputs_op_names_.assign(inputs_op_names.begin(), inputs_op_names.end()); input_tensors_valid_.reserve(input_tensors.size()); inputs_tensor_meta_str_.reserve(input_tensors.size()); for (const auto& input_tensor : input_tensors) { @@ -119,14 +121,14 @@ Maybe NNGraph::RegisterInputOpNamesAndTensors( } Maybe NNGraph::RegisterOutputOpNamesAndTensors( - const std::vector& output_op_names, + const std::vector& outputs_op_names, const std::vector>& output_tensors) { - CHECK_EQ_OR_RETURN(output_op_names.size(), output_tensors.size()); - CHECK_OR_RETURN(output_op_names_.empty()) + CHECK_EQ_OR_RETURN(outputs_op_names.size(), output_tensors.size()); + CHECK_OR_RETURN(outputs_op_names_.empty()) << " The output tensors of nn.Graph " << name_ << " are register repeatedly."; CHECK_OR_RETURN(output_tensors_valid_.empty()); CHECK_OR_RETURN(outputs_tensor_meta_str_.empty()); - output_op_names_.assign(output_op_names.begin(), output_op_names.end()); + outputs_op_names_.assign(outputs_op_names.begin(), outputs_op_names.end()); output_tensors_valid_.reserve(output_tensors.size()); 
outputs_tensor_meta_str_.reserve(output_tensors.size()); for (const auto& output_tensor : output_tensors) { @@ -253,6 +255,9 @@ Maybe NNGraph::CompileAndInitRuntime() { // TODO(chengcheng): CHECK job valid for each rank. JUST(CreateAndRegisterNewVariableOpInJobPass()); + // NOTE(chengcheng): TensorNameScope need to be cleared after current graph is built. + one::TensorNameScope::Global()->Clear(); + // NOTE(chengcheng): Global need be clear before GlobalJobDescScope construct. if (Global::Get() != nullptr) { Global::Delete(); } @@ -298,34 +303,50 @@ Maybe NNGraph::CompileAndInitRuntime() { } void NNGraph::NewRuntimeBuffers() { - auto* buffer_mgr = Global>>::Get(); // NOTE(chengcheng): - // The BufferSize for each Buffer: - // 1. SourceTick and CallbackNotifier is job_conf.concurrency_width by user (default = 128) - // in Pipeline Parallelism, this value need greater than pipeline stage num for pipelining. - // 2. Input/Output Buffer is 2 because this is the minimum size of pipeline async launch job. + // 1. The BufferSize comes from job_conf.concurrency_width configured by user (default = 128) + // 2. In Pipeline Parallelism, this value need greater than pipeline stage num for pipelining. size_t concurrency_width = job_.job_conf().concurrency_width(); - buffer_mgr->NewBuffer(GetSourceTickBufferName(name_), concurrency_width); - buffer_mgr->NewBuffer(GetCallbackNotifierBufferName(name_), concurrency_width); - for (const std::string& input_op_name : input_op_names_) { - buffer_mgr->NewBuffer(GetInputBufferName(name_, input_op_name), concurrency_width); + { + auto* buffer_mgr = Global>>::Get(); + buffer_mgr->NewBuffer(GetSourceTickBufferName(name_), concurrency_width); + buffer_mgr->NewBuffer(GetCallbackNotifierBufferName(name_), concurrency_width); } - for (const std::string& output_op_name : output_op_names_) { - buffer_mgr->NewBuffer(GetOutputBufferName(name_, output_op_name), concurrency_width); + { + auto* buffer_mgr = Global>>::Get(); + buffer_mgr->NewBuffer(GetInputCriticalSectionWaitBufferName(name_), concurrency_width); + buffer_mgr->NewBuffer(GetInputCriticalSectionCallbackBufferName(name_), concurrency_width); + buffer_mgr->NewBuffer(GetOutputCriticalSectionWaitBufferName(name_), concurrency_width); + buffer_mgr->NewBuffer(GetOutputCriticalSectionCallbackBufferName(name_), concurrency_width); + for (const std::string& input_op_name : inputs_op_names_) { + buffer_mgr->NewBuffer(GetInputBufferName(name_, input_op_name), concurrency_width); + } + for (const std::string& output_op_name : outputs_op_names_) { + buffer_mgr->NewBuffer(GetOutputBufferName(name_, output_op_name), concurrency_width); + } } } void NNGraph::CloseRuntimeBuffers() { if (runtime_inited_) { - auto* buffer_mgr = Global>>::Get(); - for (const std::string& output_op_name : output_op_names_) { - buffer_mgr->Get(GetOutputBufferName(name_, output_op_name))->Close(); + { + auto* buffer_mgr = Global>>::Get(); + for (const std::string& output_op_name : outputs_op_names_) { + buffer_mgr->Get(GetOutputBufferName(name_, output_op_name))->Close(); + } + for (const std::string& input_op_name : inputs_op_names_) { + buffer_mgr->Get(GetInputBufferName(name_, input_op_name))->Close(); + } + buffer_mgr->Get(GetOutputCriticalSectionCallbackBufferName(name_))->Close(); + buffer_mgr->Get(GetOutputCriticalSectionWaitBufferName(name_))->Close(); + buffer_mgr->Get(GetInputCriticalSectionCallbackBufferName(name_))->Close(); + buffer_mgr->Get(GetInputCriticalSectionWaitBufferName(name_))->Close(); } - for (const std::string& input_op_name 
: input_op_names_) { - buffer_mgr->Get(GetInputBufferName(name_, input_op_name))->Close(); + { + auto* buffer_mgr = Global>>::Get(); + buffer_mgr->Get(GetCallbackNotifierBufferName(name_))->Close(); + buffer_mgr->Get(GetSourceTickBufferName(name_))->Close(); } - buffer_mgr->Get(GetCallbackNotifierBufferName(name_))->Close(); - buffer_mgr->Get(GetSourceTickBufferName(name_))->Close(); } } diff --git a/oneflow/core/framework/nn_graph.h b/oneflow/core/framework/nn_graph.h index 6e3ac26ce3d..4a13ab7847c 100644 --- a/oneflow/core/framework/nn_graph.h +++ b/oneflow/core/framework/nn_graph.h @@ -43,10 +43,10 @@ class NNGraph final : public NNGraphIf { int64_t variable_op_size() const; Maybe RegisterInputOpNamesAndTensors( - const std::vector& input_op_names, + const std::vector& inputs_op_names, const std::vector>& input_tensors); Maybe RegisterOutputOpNamesAndTensors( - const std::vector& output_op_names, + const std::vector& outputs_op_names, const std::vector>& output_tensors); Maybe RegisterVariableOpNamesAndTensors( const std::vector& variable_op_names, @@ -62,8 +62,8 @@ class NNGraph final : public NNGraphIf { void CloseRuntimeBuffers(); std::string name_; - std::vector input_op_names_; - std::vector output_op_names_; + std::vector inputs_op_names_; + std::vector outputs_op_names_; std::vector input_tensors_valid_; std::vector output_tensors_valid_; std::vector inputs_tensor_meta_str_; diff --git a/oneflow/core/framework/op_attrs.cpp b/oneflow/core/framework/op_attrs.cpp new file mode 100644 index 00000000000..7ae08477890 --- /dev/null +++ b/oneflow/core/framework/op_attrs.cpp @@ -0,0 +1,58 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/op_attrs.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/framework/op_interp_ctx.h" + +namespace oneflow { + +size_t OpAttrs::count(const std::string& attr_name) const { + return ctx_->AttrNames().count(attr_name); +} + +Maybe OpAttrs::at(const std::string& attr_name) const { return ctx_->GetAttr(attr_name); } +Maybe OpAttrs::operator[](const std::string& attr_name) const { + return ctx_->GetAttr(attr_name); +} + +OpAttrs::const_iterator OpAttrs::begin() const { + const auto& attrs = ctx_->AttrNames(); + return const_iterator(attrs.cbegin(), attrs.cend(), this); +} +OpAttrs::const_iterator OpAttrs::end() const { + const auto& attrs = ctx_->AttrNames(); + return const_iterator(attrs.cend(), attrs.cend(), this); +} + +bool OpAttrs::operator==(const OpAttrs& other) const { + // TODO(hjchen2): Compare each attribute + return ctx_ == other.ctx_; +} + +} // namespace oneflow + +namespace std { + +size_t hash::operator()(const oneflow::OpAttrs& attrs) const { + size_t hash_val = 0; + for (const auto& it : attrs) { + oneflow::AddHash(&hash_val, it.first); + oneflow::HashCombine(&hash_val, it.second->hash_value()); + } + return hash_val; +} + +} // namespace std diff --git a/oneflow/core/framework/op_attrs.h b/oneflow/core/framework/op_attrs.h new file mode 100644 index 00000000000..46f6df71f18 --- /dev/null +++ b/oneflow/core/framework/op_attrs.h @@ -0,0 +1,102 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_OP_ATTRS_H_ +#define ONEFLOW_CORE_FRAMEWORK_OP_ATTRS_H_ + +#include +#include + +#include "oneflow/core/common/hash_container.h" +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/framework/attr_value.h" + +namespace oneflow { + +using user_op::AttrVal; + +class OpInterpCtx; + +class OpAttrs { + public: + explicit OpAttrs(const OpInterpCtx* ctx) : ctx_(ctx) {} + + size_t count(const std::string& attr_name) const; + + template + Maybe at(const std::string& attr_name) { + return AttrValueCast(*JUST(this->at(attr_name))); + } + Maybe at(const std::string& attr_name) const; + Maybe operator[](const std::string& attr_name) const; + + class const_iterator { + public: + using bucket_iter = HashSet::const_iterator; + using reference = const std::pair>&; + using pointer = const std::pair>*; + + const_iterator() = default; + const_iterator(bucket_iter pos, bucket_iter limit, const OpAttrs* self) + : pos_(pos), limit_(limit), self_(self) { + CHECK_JUST(UpdateKV()); + } + reference operator*() const { return kv_; } + pointer operator->() const { return &kv_; } + + const_iterator& operator++() { + pos_++; + CHECK_JUST(UpdateKV()); + return *this; + } + bool operator==(const const_iterator& x) const { return pos_ == x.pos_ && self_ == x.self_; } + bool operator!=(const const_iterator& x) const { return !(*this == x); } + + private: + Maybe UpdateKV() { + if (pos_ != limit_) { + kv_.first = *pos_; + kv_.second = JUST(self_->at(*pos_)); + } + return Maybe::Ok(); + } + + bucket_iter pos_; + bucket_iter limit_; + const OpAttrs* self_; + std::pair> kv_; + }; + + const_iterator begin() const; + const_iterator end() const; + + bool operator==(const OpAttrs& other) const; + + private: + const OpInterpCtx* ctx_; +}; + +} // namespace oneflow + +namespace std { + +template<> +struct hash { + size_t operator()(const oneflow::OpAttrs& attrs) const; +}; + +} // namespace std + +#endif // ONEFLOW_CORE_FRAMEWORK_OP_ATTRS_H_ diff --git a/oneflow/core/framework/op_base.h b/oneflow/core/framework/op_base.h new file mode 100644 index 00000000000..0f74faf5153 --- /dev/null +++ b/oneflow/core/framework/op_base.h @@ -0,0 +1,55 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
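OpAttrs above is a lightweight, iterable view over the attributes owned by an OpInterpCtx. A minimal consumption sketch, illustrative only ("axis" is a hypothetical attribute name and ctx an existing OpInterpCtx pointer):

    OpAttrs attrs = ctx->GetAttrs();
    if (attrs.count("axis") > 0) {
      int32_t axis = CHECK_JUST(attrs.at<int32_t>("axis"));  // typed access
      (void)axis;
    }
    for (const auto& pair : attrs) {
      // pair.first is the attribute name, pair.second a shared pointer to its AttrVal
      LOG(INFO) << "attr: " << pair.first;
    }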
+*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_OP_BASE_H_ +#define ONEFLOW_CORE_FRAMEWORK_OP_BASE_H_ + +#include + +#include "oneflow/core/common/hash_container.h" +#include "oneflow/core/common/maybe.h" + +namespace oneflow { + +namespace user_op { +class AttrVal; +} // namespace user_op +using AttrVal = user_op::AttrVal; + +class OpBase { + public: + virtual ~OpBase() = default; + + virtual Maybe GetAttr(const std::string& attr_name) const = 0; + + virtual const HashSet& AttrNames() const { + static const HashSet attr_names; + return attr_names; + } + + protected: + OpBase() = default; +}; + +class FakeOp : public OpBase { + public: + Maybe GetAttr(const std::string& attr_name) const override { + return Error::RuntimeError() << "`FakeOp` has no attribute."; + } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_OP_BASE_H_ diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index 0fa12aad3f8..ceed83e1a4c 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -175,7 +175,7 @@ class UserOpExprInferContext : public user_op::InferContext { return TensorDesc4ArgNameAndIndex(name, index); } - user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) override { + user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& name, int32_t index) { { const auto& arg_tuple = *user_op_expr_->output_arg_tuple(); int32_t tuple_index = arg_tuple.TensorTupleIndex4ArgNameAndIndex(name, index); @@ -500,7 +500,6 @@ Maybe BuiltinOpExprImpl::GetOrCreateOpGradCl template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const { - CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_feed_variable_conf()) = op_proto_; return Maybe::Ok(); @@ -528,9 +527,20 @@ Maybe BuiltinOpExprImpl::GetOrCreateOpGrad template<> Maybe BuiltinOpExprImpl::BuildOpConf( OperatorConf* op_conf, const AttrMap& attrs) const { - CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_image_decoder_random_crop_resize_conf()) = op_proto_; + auto* proto = op_conf->mutable_image_decoder_random_crop_resize_conf(); + proto->set_target_width(JUST(attrs.GetAttr("target_width"))); + proto->set_target_height(JUST(attrs.GetAttr("target_height"))); + proto->set_num_workers(JUST(attrs.GetAttr("num_workers"))); + proto->set_max_num_pixels(JUST(attrs.GetAttr("max_num_pixels"))); + proto->set_warmup_size(JUST(attrs.GetAttr("warmup_size"))); + proto->set_seed(JUST(attrs.GetAttr("seed"))); + proto->set_num_attempts(JUST(attrs.GetAttr("num_attempts"))); + proto->set_random_area_min(JUST(attrs.GetAttr("random_area_min"))); + proto->set_random_area_max(JUST(attrs.GetAttr("random_area_max"))); + proto->set_random_aspect_ratio_min(JUST(attrs.GetAttr("random_aspect_ratio_min"))); + proto->set_random_aspect_ratio_max(JUST(attrs.GetAttr("random_aspect_ratio_max"))); return Maybe::Ok(); } diff --git a/oneflow/core/framework/op_interp_ctx.cpp b/oneflow/core/framework/op_interp_ctx.cpp new file mode 100644 index 00000000000..8e950fbf78f --- /dev/null +++ b/oneflow/core/framework/op_interp_ctx.cpp @@ -0,0 +1,66 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/op_interp_ctx.h" +#include "oneflow/core/framework/attr_value.h" + +namespace oneflow { + +Maybe OpInterpCtx::GetAttr(const std::string& attr_name) const { + return op_->GetAttr(attr_name); +} + +template +Maybe OpInterpCtx::GetAttr(const std::string& attr_name) const { + const auto& attr_val = JUST(this->GetAttr(attr_name)); + if (const auto* ptr = dynamic_cast*>(attr_val.get())) { + return ptr->val(); + } + return Error::RuntimeError() << "Invalid type for attribute " << attr_name; +} + +OpAttrs OpInterpCtx::GetAttrs() const { return OpAttrs(this); } + +template +Maybe OpInterpCtx::SetAttr(const std::string& attr_name, const T& attr_val) { + *const_cast(&JUST(this->GetAttr(attr_name))) = attr_val; + return Maybe::Ok(); +} + +#define INSTANCE_ATTR_GETTER_AND_SETTER(field, T, attr_type) \ + template Maybe OpInterpCtx::GetAttr(const std::string& attr_name) const; \ + template Maybe OpInterpCtx::SetAttr(const std::string& attr_name, const T& attr_val); + +OF_PP_FOR_EACH_TUPLE(INSTANCE_ATTR_GETTER_AND_SETTER, ATTR_SEQ) +#undef INSTANCE_ATTR_GETTER_AND_SETTER + +Maybe OpInterpCtx::SetAttr(const std::string& attr_name, const AttrVal& attr_val) { +#define MAKE_ENTRY(field, cpp_type, attr_type) \ + if (const auto* ptr = dynamic_cast*>(&attr_val)) { \ + return this->SetAttr(attr_name, ptr->val()); \ + } + + OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ATTR_SEQ); +#undef MAKE_ENTRY + return Error::RuntimeError() << "Invalid type for attribute " << attr_name; +} + +bool OpInterpCtx::HasAttr(const std::string& attr_name) const { + return AttrNames().count(attr_name) > 0; +} + +const HashSet& OpInterpCtx::AttrNames() const { return op_->AttrNames(); } + +} // namespace oneflow diff --git a/oneflow/core/framework/op_interp_ctx.h b/oneflow/core/framework/op_interp_ctx.h new file mode 100644 index 00000000000..771766f6c8e --- /dev/null +++ b/oneflow/core/framework/op_interp_ctx.h @@ -0,0 +1,73 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
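The templated accessors implemented above resolve an attribute by name through OpBase::GetAttr and cast it to the requested C++ type. A hedged usage sketch (in real code the concrete OpInterpCtx comes from the interpreter; SelectTopNOp is the schema op added by this patch in system_ops.h):

    auto op = std::make_shared<schema::SelectTopNOp>();
    op->top_n = 3;
    OpInterpCtx ctx(op);
    CHECK(ctx.HasAttr("top_n"));
    int32_t top_n = CHECK_JUST(ctx.GetAttr<int32_t>("top_n"));  // expected: 3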
+*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_OP_INTERP_CTX_H_ +#define ONEFLOW_CORE_FRAMEWORK_OP_INTERP_CTX_H_ + +#include + +#include "oneflow/core/common/hash_container.h" +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/framework/attr_value.h" +#include "oneflow/core/framework/nd_sbp.h" +#include "oneflow/core/framework/op_attrs.h" +#include "oneflow/core/framework/op_base.h" +#include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/job/sbp_parallel.cfg.h" + +namespace oneflow { + +using user_op::AttrVal; +template +using TypedAttrValRef = user_op::TypedAttrValRef; + +namespace user_op { +class OpKernelState; +} // namespace user_op + +class OpInterpCtx { + public: + explicit OpInterpCtx(const std::shared_ptr& op) : op_(op) {} + virtual ~OpInterpCtx() = default; + + template + Maybe GetAttr(const std::string& attr_name) const; + + Maybe GetAttr(const std::string& attr_name) const; + + OpAttrs GetAttrs() const; + + template + Maybe SetAttr(const std::string& attr_name, const T& attr_val); + + Maybe SetAttr(const std::string& attr_name, const AttrVal& attr_val); + + bool HasAttr(const std::string& attr_name) const; + + const HashSet& AttrNames() const; + + public: + std::shared_ptr op_; + + Optional> device; // for local op + Optional> parallel_desc; // for consistent op + Optional> sbp; // for consistent op + Optional state; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_OP_INTERP_CTX_H_ diff --git a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp index ba4f32b7e14..61837d26bae 100644 --- a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp @@ -32,6 +32,7 @@ limitations under the License. #include "oneflow/core/autograd/autograd_mode.h" #include "oneflow/core/boxing/eager_boxing_interpreter_mgr.h" #include "oneflow/user/kernels/stateful_local_opkernel.h" +#include "oneflow/core/framework/consistency_check.h" #include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/framework/tensor_consistent_id.h" #include "oneflow/core/framework/nd_sbp.h" @@ -88,7 +89,11 @@ Maybe Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs, CHECK_EQ_OR_RETURN(outputs->size(), user_op_expr.output_size()); const auto& parallel_desc = JUST(GetParallelDesc(inputs, ctx)); std::shared_ptr result; + NonRecursiveMetaInfoConsistencyCheckScope scope; if (inputs.empty()) { + // check consistency placment and nd_sbp, do not check in non-src op because it is assumed that + // InferSbp in op is a deterministic algorithm + JUST(MetaInfoConsistencyCheck(parallel_desc, ctx.nd_sbp)); const auto& infer_args = JUST(SrcOpConsistentTensorMetaInferArgs::New(ctx.attrs, parallel_desc, JUST(ctx.nd_sbp))); result = JUST(user_op_expr.mut_consistent_tensor_infer_cache()->GetOrInfer(*infer_args)); @@ -184,8 +189,8 @@ Maybe RawConsistentToConsistent(const ConsistentToConsistentOpExpr& op_exp if (out_parallel_id->has_value()) { const auto& nd_sbp = JUST(tensor->nd_sbp()); const auto& parallel_desc = JUST(tensor->parallel_desc()); - CHECK_OR_RETURN(nd_sbp == out_nd_sbp) << ". nd_sbp: " << *JUST(NdSbpToString(nd_sbp)) - << ", out_nd_sbp" << *JUST(NdSbpToString(out_nd_sbp)); + CHECK_OR_RETURN(nd_sbp == out_nd_sbp) + << ". 
nd_sbp: " << NdSbpToString(nd_sbp) << ", out_nd_sbp" << NdSbpToString(out_nd_sbp); CHECK_OR_RETURN(parallel_desc == out_parallel_desc); outputs->at(0) = tensor; } else { diff --git a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp index 14799a83f6c..57bc58e2a0c 100644 --- a/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/lazy_op_interpreter.cpp @@ -13,8 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/common/maybe.h" #include "oneflow/core/common/cpp_attribute.h" +#include "oneflow/core/framework/consistency_check.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_builder.h" #include "oneflow/core/framework/multi_client_session_context.h" @@ -73,12 +75,9 @@ Maybe BuildTensor(const OpAttribute& op_attribute, const std::string& bn Maybe CheckTensorMatchAttr(const std::shared_ptr& tensor, const OpAttribute& op_attribute, const std::string& bn_in_op, const std::shared_ptr& parallel_desc, - const bool is_lazy, const bool is_local, const bool requires_grad, - const bool is_leaf) { + const bool is_lazy, const bool is_local) { CHECK_EQ_OR_RETURN(tensor->is_lazy(), is_lazy); CHECK_EQ_OR_RETURN(tensor->is_local(), is_local); - CHECK_EQ_OR_RETURN(tensor->requires_grad(), requires_grad); - CHECK_EQ_OR_RETURN(tensor->is_leaf(), is_leaf); CHECK_OR_RETURN(op_attribute.has_logical_blob_desc_signature()); const auto& blob_desc_sign_map = op_attribute.logical_blob_desc_signature().bn_in_op2blob_desc(); @@ -100,7 +99,8 @@ Maybe CheckTensorMatchAttr(const std::shared_ptr& tensor, CHECK_OR_RETURN(nd_sbp_it != nd_sbp_sign_map.end()) << "nd_sbp of " << bn_in_op << " not found in op " << op_attribute.op_conf().name(); cfg::NdSbp nd_sbp(nd_sbp_it->second); - CHECK_OR_RETURN(JUST(tensor->nd_sbp()) == SymbolOf(nd_sbp)); + CHECK_OR_RETURN(JUST(tensor->nd_sbp()) == SymbolOf(nd_sbp)) + << "The input sbp is not valid for an inplace operation, please try to use non-inplace."; CHECK_OR_RETURN(JUST(tensor->parallel_desc()) == SymbolOf(*parallel_desc)); } return Maybe::Ok(); @@ -418,12 +418,15 @@ namespace { Maybe LazyInterpreterApplyImplForSourceUserOpExpr(const UserOpExpr& op_expr, TensorTuple* outputs, const OpExprInterpContext& ctx) { + NonRecursiveMetaInfoConsistencyCheckScope non_scope; bool is_local; std::shared_ptr parallel_desc; if (ctx.parallel_desc.has_value()) { // NOTE(chengcheng): consistent CHECK_OR_RETURN(!ctx.device.has_value()); - parallel_desc = JUST(ctx.parallel_desc).shared_from_symbol(); + const auto& parallel_desc_sym = JUST(ctx.parallel_desc); + parallel_desc = parallel_desc_sym.shared_from_symbol(); + JUST(MetaInfoConsistencyCheck(parallel_desc_sym, ctx.nd_sbp)); is_local = false; } else { // NOTE(chengcheng): local @@ -650,6 +653,21 @@ Maybe LazyInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTu } } + // Check outputs num and setup output tensor properties. + CHECK_EQ_OR_RETURN(outputs->size(), op_expr.output_size()); + + // Disable boxing if the computation is inplace. 
+ for (int i = 0; i < op_expr.output_size(); ++i) { + const auto& output = outputs->at(i); + if (output) { + const std::string& lbn = TensorNameScope::Global()->Lookup(output); + CHECK_OR_RETURN(!lbn.empty()) << "The output which index is " << i + << " has no tensor name, please check whether the inplaced " + "output is also an input of the operation " + << new_op_name; + JUST(infer_ctx->DisableBoxing(lbn)); + } + } VLOG(2) << "Lazy nn.Graph name " << graph_name << " try to add op: \n" << op_conf->DebugString() << std::endl; OpAttribute op_attr = *JUST(infer_ctx->AddAndInferConsistentOp(*op_conf)); @@ -660,9 +678,6 @@ Maybe LazyInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTu int64_t parallel_desc_sym_id = JUST(scope->GetParallelDescSymbolId(*op_conf)); auto blob_parallel_desc = JUST(GetSymbol(parallel_desc_sym_id)); - - // Check outputs num and setup output tensor properties. - CHECK_EQ_OR_RETURN(outputs->size(), op_expr.output_size()); for (int i = 0; i < op_expr.output_size(); ++i) { const std::string& obn = op_expr.indexed_obns().at(i); if (!(*outputs)[i]) { @@ -671,9 +686,7 @@ Maybe LazyInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTu } else { const std::shared_ptr& inplace_out = (*outputs)[i]; JUST(CheckTensorMatchAttr(inplace_out, op_attr, obn, blob_parallel_desc, /* is_lazy= */ true, - is_local, - /* requires_grad */ false, - /* is_leaf */ true)); + is_local)); } TensorNameScope::Global()->Record((*outputs)[i], GenLogicalBlobName(new_op_name, obn)); } diff --git a/oneflow/core/framework/op_interpreter/op_interpreter.cpp b/oneflow/core/framework/op_interpreter/op_interpreter.cpp index 88de66d6b44..1af45137f31 100644 --- a/oneflow/core/framework/op_interpreter/op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/op_interpreter.cpp @@ -23,6 +23,7 @@ limitations under the License. #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/job/lazy_mode.h" namespace oneflow { namespace one { @@ -94,7 +95,8 @@ Maybe AutogradInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& autograd::AutoGradMode mode(false); JUST(internal_->Apply(op_expr, inputs, outputs, ctx)); } - if (requires_grad) { + // Lazy mode will construct backward compute graph in passes, so disable autograd if lazy mode. 
+ if (requires_grad && !LazyMode::is_enabled()) { const auto& grad_closure = JUST(op_expr.GetOrCreateOpGradClosure()); JUST(grad_closure->Capture(inputs, *outputs, ctx)); diff --git a/oneflow/core/framework/op_kernel.h b/oneflow/core/framework/op_kernel.h index 0b5a4f94d9d..bd831d9f886 100644 --- a/oneflow/core/framework/op_kernel.h +++ b/oneflow/core/framework/op_kernel.h @@ -95,6 +95,64 @@ class KernelInitContext { virtual const std::shared_ptr& Attr4Name(const std::string& attr_name) const = 0; }; +class KernelCacheContext { + public: + OF_DISALLOW_COPY_AND_MOVE(KernelCacheContext); + virtual ~KernelCacheContext() = default; + + virtual ep::Stream* stream() = 0; + + virtual DeviceType device_type() const = 0; + virtual const ParallelContext& parallel_ctx() const = 0; + virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0; + virtual const cfg::SbpParallel& SbpParallel4ArgNameAndIndex(const std::string&, + int32_t) const = 0; + virtual const TensorDesc* LogicalTensorDesc4ArgNameAndIndex(const std::string&, + int32_t) const = 0; + virtual const ParallelDesc& parallel_desc() const = 0; + virtual const cfg::NdSbp& NdSbp4ArgNameAndIndex(const std::string&, int32_t) const = 0; + + virtual const std::vector>& inputs() const = 0; + virtual const std::vector>& outputs() const = 0; + + const std::string& input(const std::string& arg_name, int32_t index) const { + return user_op_conf().input(arg_name, index); + } + const std::string& output(const std::string& arg_name, int32_t index) const { + return user_op_conf().output(arg_name, index); + } + bool has_input(const std::string& arg_name, int32_t index) const { + return user_op_conf().has_input(arg_name, index); + } + bool has_output(const std::string& arg_name, int32_t index) const { + return user_op_conf().has_output(arg_name, index); + } + int32_t input_size(const std::string& arg_name) const { + return user_op_conf().input_size(arg_name); + } + int32_t output_size(const std::string& arg_name) const { + return user_op_conf().output_size(arg_name); + } + const std::string& op_name() const { return user_op_conf().op_name(); } + const std::string& op_type_name() const { return user_op_conf().op_type_name(); } + const std::string& device_tag() const { return user_op_conf().op_conf().device_tag(); } + const OperatorConf& op_conf() const { return user_op_conf().op_conf(); } + + template + const T& Attr(const std::string& attr_name) const { + return AttrValueCast(*Attr4Name(attr_name)); + } + + template + const T& attr(const std::string& attr_name) const; + + protected: + KernelCacheContext() = default; + + virtual const UserOpConfWrapper& user_op_conf() const = 0; + virtual const std::shared_ptr& Attr4Name(const std::string& attr_name) const = 0; +}; + class KernelInferContext { public: OF_DISALLOW_COPY_AND_MOVE(KernelInferContext); @@ -217,6 +275,18 @@ class OpKernelState { OpKernelState() = default; }; +class OpKernelCache { + public: + virtual ~OpKernelCache() = default; + + static const int32_t kAllMayChanged = 0; + static const int32_t kShapeNotChanged = 1 << 0; + static const int32_t kAttrNotChanged = 1 << 1; + + protected: + OpKernelCache() = default; +}; + class OpKernel; template @@ -231,7 +301,18 @@ class OpKernel { return std::shared_ptr(); } - virtual void Compute(KernelComputeContext* ctx, OpKernelState*) const { Compute(ctx); } + virtual std::shared_ptr InitOpKernelCache(KernelCacheContext* ctx) const { + return std::shared_ptr(); + } + + virtual void InitOpKernelCache(KernelCacheContext* ctx, 
int8_t flag, + std::shared_ptr* cache_ptr) const { + *cache_ptr = InitOpKernelCache(ctx); + } + + virtual void Compute(KernelComputeContext* ctx, OpKernelState*, const OpKernelCache*) const { + Compute(ctx); + } virtual void Compute(KernelComputeContext*) const { LOG(INFO) << "UNIMPLEMENTED"; } virtual void InferShape(KernelInferContext* ctx) const; virtual bool AlwaysComputeWhenAllOutputsEmpty() const = 0; diff --git a/oneflow/core/framework/placement_sbp_util.cpp b/oneflow/core/framework/placement_sbp_util.cpp index d64c1bfafb4..ccbc1011d63 100644 --- a/oneflow/core/framework/placement_sbp_util.cpp +++ b/oneflow/core/framework/placement_sbp_util.cpp @@ -446,8 +446,8 @@ std::string GetCyclicBoxingDebugString( CHECK_EQ(src_nd_sbp->sbp_parallel_size(), dst_nd_sbp->sbp_parallel_size()); std::stringstream ss; ss << "cyclic split axis boxing are not supported. " - << "src_nd_sbp: " << CHECK_JUST(NdSbpToString(src_nd_sbp)) - << ", dst_nd_sbp: " << CHECK_JUST(NdSbpToString(dst_nd_sbp)) << ". " + << "src_nd_sbp: " << NdSbpToString(src_nd_sbp) << ", dst_nd_sbp: " << NdSbpToString(dst_nd_sbp) + << ". " << "dst_nd_sbp axis to exclusive src_nd_sbp axis: "; ss << "["; for (int i = 0; i < src_nd_sbp->sbp_parallel_size(); ++i) { diff --git a/oneflow/core/framework/random_generator_impl.cpp b/oneflow/core/framework/random_generator_impl.cpp index a99b93bf686..888cccaf5d6 100644 --- a/oneflow/core/framework/random_generator_impl.cpp +++ b/oneflow/core/framework/random_generator_impl.cpp @@ -18,6 +18,7 @@ limitations under the License. #include "oneflow/core/common/util.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/instructions_builder.h" +#include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/job/env_global_objects_scope.h" #include "oneflow/core/register/ofblob.h" @@ -33,17 +34,6 @@ namespace one { namespace { -Maybe SyncAccessTensorWithTimeOut( - const std::shared_ptr& tensor, - const std::shared_ptr>& callback, const std::string& modifier) { - return SpinCounter::SpinWait(1, [&](const std::shared_ptr& sc) -> Maybe { - return PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { - return builder->SyncAccessBlobByCallback(JUST(tensor->AsMirroredTensor()), sc, callback, - modifier); - }); - }); -} - Maybe CPUSynchronize() { if (Global::Get() != nullptr) { return vm::CurrentRankSync(); } return Maybe::Ok(); diff --git a/oneflow/core/framework/system_ops.cpp b/oneflow/core/framework/system_ops.cpp new file mode 100644 index 00000000000..44b449fe5ec --- /dev/null +++ b/oneflow/core/framework/system_ops.cpp @@ -0,0 +1,115 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
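The new OpKernelCache/KernelCacheContext pair above lets a kernel keep derived data across launches and rebuild it only when the relevant parts of the context change (see the kShapeNotChanged/kAttrNotChanged flags). A minimal sketch of a cache-aware kernel, illustration only ("in" is a hypothetical arg name):

    namespace {
    class ElemCntCache final : public user_op::OpKernelCache {
     public:
      explicit ElemCntCache(int64_t elem_cnt) : elem_cnt_(elem_cnt) {}
      int64_t elem_cnt() const { return elem_cnt_; }
     private:
      int64_t elem_cnt_;
    };
    }  // namespace

    class MyKernel final : public user_op::OpKernel {
     public:
      std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(
          user_op::KernelCacheContext* ctx) const override {
        const user_op::TensorDesc* in = ctx->TensorDesc4ArgNameAndIndex("in", 0);
        return std::make_shared<ElemCntCache>(in->shape().elem_cnt());
      }
      void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,
                   const user_op::OpKernelCache* cache) const override {
        const auto* elem_cnt_cache = dynamic_cast<const ElemCntCache*>(cache);
        // ... launch work sized by elem_cnt_cache->elem_cnt() ...
      }
      bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
    };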
+*/ +#include "oneflow/core/framework/system_ops.h" +#include "oneflow/core/framework/attr_value.h" + +namespace oneflow { +namespace schema { + +Maybe CastToConsistentOp::GetAttr(const std::string& attr_name) const { + if (attr_name == "shape") { + return CastAttrValue(&shape); + } else if (attr_name == "dtype") { + return CastAttrValue(&dtype); + } else { + return Error::RuntimeError() << "CastToConsistent op has no attribute named " << attr_name; + } +} + +const HashSet& CastToConsistentOp::AttrNames() const { + static HashSet attr_names{"shape", "dtype"}; + return attr_names; +} + +Maybe SelectTopNOp::GetAttr(const std::string& attr_name) const { + if (attr_name == "top_n") { + return CastAttrValue(&top_n); + } else { + return Error::RuntimeError() << "SelectTopN op has no attribute named " << attr_name; + } +} + +const HashSet& SelectTopNOp::AttrNames() const { + static HashSet attr_names{"top_n"}; + return attr_names; +} + +Maybe FeedInputOp::GetAttr(const std::string& attr_name) const { + return Error::RuntimeError() << "FeedInput op has no attribute named " << attr_name; +} + +Maybe FetchOutputOp::GetAttr(const std::string& attr_name) const { + return Error::RuntimeError() << "FetchOutput op has no attribute named " << attr_name; +} + +Maybe FeedVariableOp::GetAttr(const std::string& attr_name) const { + if (attr_name == "_l2") { + return CastAttrValue(&_l2); + } else { + return Error::RuntimeError() << "FeedVariable op has no attribute named " << attr_name; + } +} + +const HashSet& FeedVariableOp::AttrNames() const { + static HashSet attr_names{"_l2"}; + return attr_names; +} + +Maybe ImageDecoderRandomCropResizeOp::GetAttr(const std::string& attr_name) const { + if (attr_name == "target_width") { + return CastAttrValue(&target_width); + } else if (attr_name == "target_height") { + return CastAttrValue(&target_height); + } else if (attr_name == "num_workers") { + return CastAttrValue(&num_workers); + } else if (attr_name == "max_num_pixels") { + return CastAttrValue(&max_num_pixels); + } else if (attr_name == "warmup_size") { + return CastAttrValue(&warmup_size); + } else if (attr_name == "seed") { + return CastAttrValue(&seed); + } else if (attr_name == "num_attempts") { + return CastAttrValue(&num_attempts); + } else if (attr_name == "random_area_min") { + return CastAttrValue(&random_area_min); + } else if (attr_name == "random_area_max") { + return CastAttrValue(&random_area_max); + } else if (attr_name == "random_aspect_ratio_min") { + return CastAttrValue(&random_aspect_ratio_min); + } else if (attr_name == "random_aspect_ratio_max") { + return CastAttrValue(&random_aspect_ratio_max); + } else { + return Error::RuntimeError() << "FeedVariable op has no attribute named " << attr_name; + } +} + +const HashSet& ImageDecoderRandomCropResizeOp::AttrNames() const { + static HashSet attr_names{"target_width", + "target_height", + "num_workers", + "max_num_pixels", + "warmup_size", + "seed", + "num_attempts", + "random_area_min", + "random_area_max", + "random_aspect_ratio_min", + "random_aspect_ratio_max"}; + return attr_names; +} + +} // namespace schema +} // namespace oneflow diff --git a/oneflow/core/framework/system_ops.h b/oneflow/core/framework/system_ops.h new file mode 100644 index 00000000000..69b1fad6858 --- /dev/null +++ b/oneflow/core/framework/system_ops.h @@ -0,0 +1,89 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
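Each schema op above exposes its plain member fields through the OpBase attribute interface, so attribute lookup is a name-to-field dispatch rather than a proto AttrMap. A small illustration, assuming the op is only consumed through that interface:

    schema::FeedVariableOp op;
    op._l2 = 0.5;
    // AttrNames() reports {"_l2"}; GetAttr("_l2") returns an AttrVal wrapping the field,
    // and any other name yields a RuntimeError, as implemented above.
    CHECK_EQ(op.AttrNames().count("_l2"), 1);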
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_SYSTEM_OPS_H_ +#define ONEFLOW_CORE_FRAMEWORK_SYSTEM_OPS_H_ + +#include "oneflow/core/framework/op_base.h" + +#include "oneflow/core/common/data_type.pb.h" +#include "oneflow/core/common/hash_container.h" +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/shape.h" + +namespace oneflow { +namespace schema { + +class CastToConsistentOp : public OpBase { + public: + Maybe GetAttr(const std::string& attr_name) const override; + const HashSet& AttrNames() const override; + + public: + Shape shape; + DataType dtype; +}; + +class SelectTopNOp : public OpBase { + public: + Maybe GetAttr(const std::string& attr_name) const override; + const HashSet& AttrNames() const override; + + public: + int32_t top_n; +}; + +class FeedInputOp : public OpBase { + public: + Maybe GetAttr(const std::string& attr_name) const override; +}; + +class FetchOutputOp : public OpBase { + public: + Maybe GetAttr(const std::string& attr_name) const override; +}; + +class FeedVariableOp : public OpBase { + public: + Maybe GetAttr(const std::string& attr_name) const override; + const HashSet& AttrNames() const override; + + public: + double _l2; +}; + +class ImageDecoderRandomCropResizeOp : public OpBase { + public: + Maybe GetAttr(const std::string& attr_name) const override; + const HashSet& AttrNames() const override; + + public: + int64_t target_width; + int64_t target_height; + int64_t num_workers; + int64_t max_num_pixels; + int64_t warmup_size; + int64_t seed; + int64_t num_attempts; + float random_area_min; + float random_area_max; + float random_aspect_ratio_min; + float random_aspect_ratio_max; +}; + +} // namespace schema +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_SYSTEM_OPS_H_ diff --git a/oneflow/core/framework/tensor.cpp b/oneflow/core/framework/tensor.cpp index 3a0d14d2b0b..7eab130f59d 100644 --- a/oneflow/core/framework/tensor.cpp +++ b/oneflow/core/framework/tensor.cpp @@ -25,7 +25,6 @@ limitations under the License. 
#include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/autograd/autograd_engine.h" #include "oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.h" -#include "oneflow/core/framework/tensor_rpc_util.h" #include "oneflow/core/functional/functional.h" namespace oneflow { @@ -104,6 +103,23 @@ Maybe ConsistentTensor::detach() const { return t; } +Maybe ConsistentTensor::set_data(const std::shared_ptr& other) { + CHECK_OR_RETURN(this->is_leaf()) + << "Only leaf tensor's data can be set, because non-leaf tensor's data has been captured in " + "the backward graph in autograd."; + const auto& consistent_tensor = + std::dynamic_pointer_cast(JUST(other->detach())); + CHECK_NOTNULL_OR_RETURN(consistent_tensor); + JUST(WithConsistencyChecked(consistent_tensor, + [&]() -> Maybe { return Maybe::Ok(); })); + + bool old_requires_grad = requires_grad(); + impl_ = consistent_tensor->impl_; + JUST(set_requires_grad(old_requires_grad)); + grad_fn_node_ = nullptr; + return Maybe::Ok(); +} + } // namespace one } // namespace oneflow diff --git a/oneflow/core/framework/tensor.h b/oneflow/core/framework/tensor.h index 7a7520aa3c6..2d7e2fce348 100644 --- a/oneflow/core/framework/tensor.h +++ b/oneflow/core/framework/tensor.h @@ -587,17 +587,7 @@ class ConsistentTensor final : public TensorIf { } user_op::TensorDesc* mut_tensor_meta() override { return impl_->mut_tensor_meta(); } - Maybe set_data(const std::shared_ptr& other) override { - CHECK_OR_RETURN(this->is_leaf()) << "Can only set leaf tensor's data."; - const auto& consistent_tensor = - std::dynamic_pointer_cast(JUST(other->detach())); - CHECK_NOTNULL_OR_RETURN(consistent_tensor); - bool old_requires_grad = requires_grad(); - impl_ = consistent_tensor->impl_; - set_requires_grad(old_requires_grad); - grad_fn_node_ = nullptr; - return Maybe::Ok(); - } + Maybe set_data(const std::shared_ptr& other) override; Maybe AsMirroredTensor() override { RETURN_ERROR_WITH_BUG_PROMPT(); } Maybe AsConsistentTensor() override { diff --git a/oneflow/core/framework/tensor_methods.cpp b/oneflow/core/framework/tensor_methods.cpp index d05d0f97797..b79cdba4067 100644 --- a/oneflow/core/framework/tensor_methods.cpp +++ b/oneflow/core/framework/tensor_methods.cpp @@ -65,7 +65,7 @@ Maybe BasicView(const std::shared_ptr& input, const Shape& targe std::make_shared(target_shape), input->dtype()->data_type(), device, std::make_shared(target_strides), storage_offset); - JUST(input->has_eager_blob_object()); + CHECK_OR_RETURN(JUST(input->has_eager_blob_object())); // new output tensor const auto& blob_object = JUST(input->eager_blob_object()); auto tensor_impl = std::make_shared( @@ -88,7 +88,9 @@ Maybe Reshape(const std::shared_ptr& input, const Shape& shape) int need_infer_axis = -1; size_t count = 1; for (int i = 0; i < shape.NumAxes(); ++i) { - if (shape.At(i) == -1) { + if (shape.At(i) < -1) { + return Error::RuntimeError() << "Invalid shape dimension " << shape.At(i); + } else if (shape.At(i) == -1) { CHECK_EQ_OR_RETURN(need_infer_axis, -1) << "Shape " << shape.ToString() << " has more than 1 axis that needs to be infered."; need_infer_axis = i; diff --git a/oneflow/core/framework/tensor_name_scope.cpp b/oneflow/core/framework/tensor_name_scope.cpp index ff0c484cb81..1a841cdd76a 100644 --- a/oneflow/core/framework/tensor_name_scope.cpp +++ b/oneflow/core/framework/tensor_name_scope.cpp @@ -42,5 +42,10 @@ void TensorNameScope::Record(const std::shared_ptr& tensor, const std::s tensor_names_[key] = name; } +void TensorNameScope::Clear() { + 
std::lock_guard lock(mutex_); + tensor_names_.clear(); +} + } // namespace one } // namespace oneflow diff --git a/oneflow/core/framework/tensor_name_scope.h b/oneflow/core/framework/tensor_name_scope.h index b9bfad9b8eb..2ead6c19ba2 100644 --- a/oneflow/core/framework/tensor_name_scope.h +++ b/oneflow/core/framework/tensor_name_scope.h @@ -31,6 +31,8 @@ class TensorNameScope { void Record(const std::shared_ptr& tensor, const std::string& name); + void Clear(); + private: TensorNameScope() : default_tensor_name_("") {} virtual ~TensorNameScope() = default; diff --git a/oneflow/core/framework/tensor_rpc_util.cpp b/oneflow/core/framework/tensor_rpc_util.cpp index 33879ddd5de..b262e7597f8 100644 --- a/oneflow/core/framework/tensor_rpc_util.cpp +++ b/oneflow/core/framework/tensor_rpc_util.cpp @@ -169,7 +169,7 @@ Maybe LaunchTensorMetaConsistencyCheck(const return ctx; } -Maybe BuzyWaitAndCheck(std::shared_ptr& ctx) { +Maybe BusyWaitAndCheck(std::shared_ptr& ctx) { JUST(TransportUtil::WaitUntilDoneOrTimeout(*ctx, TransportUtil::TimeoutSeconds())); JUST(ctx->Check()); return Maybe::Ok(); diff --git a/oneflow/core/framework/tensor_rpc_util.h b/oneflow/core/framework/tensor_rpc_util.h index 113ca9deadf..486a0c08a47 100644 --- a/oneflow/core/framework/tensor_rpc_util.h +++ b/oneflow/core/framework/tensor_rpc_util.h @@ -33,7 +33,7 @@ int64_t* MutThreadLocalTensorMetaCheckDepth(); Maybe LaunchTensorMetaConsistencyCheck( const one::Tensor& tensor); -Maybe BuzyWaitAndCheck(std::shared_ptr& ctx); +Maybe BusyWaitAndCheck(std::shared_ptr& ctx); Maybe RunCallback(const std::shared_ptr& tensor, const std::function()>& Callback); @@ -59,7 +59,7 @@ struct CheckConsistentTensorMeta&, Args RetT ret = func(tensor, args...); --*depth; // Always synchronize consistent tensor meta even if `func` failed. - if (*depth == 0) { JUST(private_details::BuzyWaitAndCheck(ctx)); } + if (*depth == 0) { JUST(private_details::BusyWaitAndCheck(ctx)); } return ret; } }; diff --git a/oneflow/core/framework/tensor_util.cpp b/oneflow/core/framework/tensor_util.cpp new file mode 100644 index 00000000000..6a615e25173 --- /dev/null +++ b/oneflow/core/framework/tensor_util.cpp @@ -0,0 +1,36 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
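TensorNameScope maps tensors to the logical blob names they were produced under; the new Clear() lets nn.Graph drop those records once a graph has been built (see the NOTE in nn_graph.cpp above). Sketch of the intended lifecycle, illustrative only (t is an existing one::Tensor shared pointer and the lbn is hypothetical):

    auto* scope = one::TensorNameScope::Global();
    scope->Record(t, "my_op/out_0");            // while adding ops to the graph
    const std::string& lbn = scope->Lookup(t);  // -> "my_op/out_0"
    scope->Clear();                             // after the graph build finishes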
+*/ +#include "oneflow/core/framework/tensor_util.h" + +#include "oneflow/core/common/spin_counter.h" +#include "oneflow/core/framework/instructions_builder.h" + +namespace oneflow { +namespace one { + +Maybe SyncAccessTensorWithTimeOut( + const std::shared_ptr& tensor, + const std::shared_ptr>& callback, const std::string& modifier) { + return SpinCounter::SpinWait(1, [&](const std::shared_ptr& sc) -> Maybe { + return PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { + return builder->SyncAccessBlobByCallback(JUST(tensor->AsMirroredTensor()), sc, callback, + modifier); + }); + }); +} + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/framework/tensor_util.h b/oneflow/core/framework/tensor_util.h new file mode 100644 index 00000000000..028ddf05e2e --- /dev/null +++ b/oneflow/core/framework/tensor_util.h @@ -0,0 +1,29 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include + +#include "oneflow/core/common/maybe.h" + +namespace oneflow { +namespace one { + +class Tensor; + +Maybe SyncAccessTensorWithTimeOut( + const std::shared_ptr& tensor, + const std::shared_ptr>& callback, const std::string& modifier); +} // namespace one +} // namespace oneflow \ No newline at end of file diff --git a/oneflow/core/framework/user_op_conf.cpp b/oneflow/core/framework/user_op_conf.cpp index d34a6700977..0b4ac26e617 100644 --- a/oneflow/core/framework/user_op_conf.cpp +++ b/oneflow/core/framework/user_op_conf.cpp @@ -278,17 +278,11 @@ Maybe CheckArgDefIsValidInUserOpConf( if (arg_name2lbns.find(arg.name()) != arg_name2lbns.end()) { arg_blob_num = arg_name2lbns.at(arg.name()).s_size(); } - if (arg_blob_num != arg.num()) { - if (arg_blob_num == 0) { - CHECK_OR_RETURN(arg.is_optional()) - << " op_name: " << op_name << " op_type_name: " << op_type_name - << " arg name: " << arg.name() << " in OpDef must have blob in op_conf"; - } else { - CHECK_OR_RETURN(arg_blob_num > arg.num() && arg.num_as_min()) - << " op_name: " << op_name << " op_type_name: " << op_type_name - << " arg name: " << arg.name() << " has blob num: " << arg_blob_num - << " in op_conf does not meet its constraints in OpDef"; - } + if (arg_blob_num == 0) { + CHECK_OR_RETURN(arg.is_optional()) + << " op_name: " << op_name << " op_type_name: " << op_type_name + << " arg name: " << arg.name() << " in OpDef must have blob in op_conf: \n" + << op_conf.DebugString(); } op_def_arg_names.insert(arg.name()); } @@ -358,24 +352,6 @@ Maybe AddAttrDefaultValueAndCheckValid(const UserOpDef& op_def, OperatorCo return AddAttrDefaultValueAndCheckValid(op_def, user_conf, error_msg_prefix); } -Maybe AddUserOpConfOutputDefaultArg(const UserOpDef& op_def, OperatorConf* op_conf) { - UserOpConf* user_conf = op_conf->mutable_user_conf(); - // add default output arg and lbn - for (const auto& output_arg : op_def.output()) { - if (user_conf->output().find(output_arg.name()) == user_conf->output().end() - && (!output_arg.is_optional()) && (!output_arg.num_as_min())) { - for (int32_t i = 0; i < 
output_arg.num(); ++i) { - std::string lbn = GenLogicalBlobName(op_conf->name(), GenRepeatedBn(output_arg.name(), i)); - (*(user_conf->mutable_output()))[output_arg.name()].add_s(lbn); - CHECK_EQ_OR_RETURN(i + 1, user_conf->output().at(output_arg.name()).s_size()); - } - user_conf->add_output_order(output_arg.name()); - CHECK_EQ_OR_RETURN(user_conf->output().size(), user_conf->output_order().size()); - } - } - return Maybe::Ok(); -} - Maybe GetAttrTypeImpl(const std::string& op_type_name, const std::string& attr_name) { const user_op::OpRegistryResult* val = user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name); @@ -397,7 +373,6 @@ Maybe CheckAndCompleteUserOpConfImpl(const OperatorConf& op_conf) const UserOpDef& op_def = val->op_def; JUST(AddAttrDefaultValueAndCheckValid(op_def, &ret)); - JUST(AddUserOpConfOutputDefaultArg(op_def, &ret)); // check input and output valid JUST(CheckArgDefIsValidInUserOpConf(op_conf, user_conf->input(), op_def.input())); JUST(CheckArgDefIsValidInUserOpConf(op_conf, user_conf->output(), op_def.output())); diff --git a/oneflow/core/framework/user_op_def.cpp b/oneflow/core/framework/user_op_def.cpp index 8aee596d2c5..def5da20d92 100644 --- a/oneflow/core/framework/user_op_def.cpp +++ b/oneflow/core/framework/user_op_def.cpp @@ -51,12 +51,6 @@ bool UserOpDefWrapper::IsArgOptional(const std::string& name) const { return arg_def->is_optional(); } -std::pair UserOpDefWrapper::ArgNumAndIsMin(const std::string& name) const { - const UserOpDef::ArgDef* arg_def = GetArgPointer(name); - CHECK_NOTNULL(arg_def); - return std::make_pair(arg_def->num(), arg_def->num_as_min()); -} - const UserOpDef::ArgDef* UserOpDefWrapper::GetArgPointer(const std::string& name) const { auto it = inputs_.find(name); if (it != inputs_.end()) { return it->second; } diff --git a/oneflow/core/framework/user_op_def.h b/oneflow/core/framework/user_op_def.h index 8d39a4333cf..f3c2aab548e 100644 --- a/oneflow/core/framework/user_op_def.h +++ b/oneflow/core/framework/user_op_def.h @@ -37,7 +37,6 @@ class UserOpDefWrapper final { bool IsAttrName(const std::string&) const; bool IsArgOptional(const std::string&) const; - std::pair ArgNumAndIsMin(const std::string&) const; AttrType GetAttrType(const std::string&) const; bool AttrHasDefaultVal(const std::string&) const; diff --git a/oneflow/core/framework/user_op_def.proto b/oneflow/core/framework/user_op_def.proto index d22c8c808b6..e4ea72157e3 100644 --- a/oneflow/core/framework/user_op_def.proto +++ b/oneflow/core/framework/user_op_def.proto @@ -9,8 +9,6 @@ message UserOpDef { message ArgDef { required string name = 1; optional bool is_optional = 2 [default = false]; - required int32 num = 3; - required bool num_as_min = 4; } repeated ArgDef input = 2; repeated ArgDef output = 3; diff --git a/oneflow/core/framework/user_op_registry.cpp b/oneflow/core/framework/user_op_registry.cpp index e7a15b2c869..837addea155 100644 --- a/oneflow/core/framework/user_op_registry.cpp +++ b/oneflow/core/framework/user_op_registry.cpp @@ -42,15 +42,13 @@ OpRegistry& OpRegistry::Name(const std::string& op_type_name) { return *this; } -OpRegistry& OpRegistry::ArgImpl(bool is_input, const std::string& name, bool is_optional, - int32_t num, bool num_as_min) { - CHECK(InsertIfNotExists(name, &unique_names_)); +OpRegistry& OpRegistry::ArgImpl(bool is_input, const std::string& name, bool is_optional) { + CHECK(InsertIfNotExists(name, &unique_names_)) + << "op arg registered, name: " << name << ", op: " << result_.op_type_name; UserOpDef::ArgDef arg_def; { 
arg_def.set_name(name); arg_def.set_is_optional(is_optional); - arg_def.set_num(num); - arg_def.set_num_as_min(num_as_min); } if (is_input) { *(result_.op_def.mutable_input()->Add()) = arg_def; @@ -60,15 +58,9 @@ OpRegistry& OpRegistry::ArgImpl(bool is_input, const std::string& name, bool is_ return *this; } -#define OP_REG_ARG_MEMBER_FUNC(name_prefix, is_input, is_optional) \ - OpRegistry& OpRegistry::name_prefix(const std::string& name) { \ - return ArgImpl(is_input, name, is_optional, 1, false); \ - } \ - OpRegistry& OpRegistry::name_prefix(const std::string& name, int32_t num) { \ - return ArgImpl(is_input, name, is_optional, num, false); \ - } \ - OpRegistry& OpRegistry::name_prefix##WithMinimum(const std::string& name, int32_t min_num) { \ - return ArgImpl(is_input, name, is_optional, min_num, true); \ +#define OP_REG_ARG_MEMBER_FUNC(name_prefix, is_input, is_optional) \ + OpRegistry& OpRegistry::name_prefix(const std::string& name) { \ + return ArgImpl(is_input, name, is_optional); \ } OP_REG_ARG_MEMBER_FUNC(Input, true, false) @@ -172,6 +164,7 @@ OpRegistry& OpRegistry::SetGetSbpFn(GetSbpFn get_sbp_fn) { result_.get_sbp_fn = std::move(get_sbp_fn); return *this; } + OpRegistry& OpRegistry::SetSbpSignatureInferFn(SbpSignatureInferFn sbp_signature_infer_fn) { result_.sbp_signature_infer_fn = std::move(sbp_signature_infer_fn); return *this; @@ -222,10 +215,10 @@ Maybe OpRegistry::Finish() { const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex(pair.first, pair.second); const TensorDesc* in_logical = ctx->LogicalTensorDesc4ArgNameAndIndex(pair.first, pair.second); - const TensorDesc* in_physical = ctx->TensorDesc4ArgNameAndIndex(pair.first, pair.second); + const TensorDesc& in_physical = ctx->InputTensorDesc(pair.first, pair.second); CHECK_OR_RETURN(*JUST(GetPhysicalShape(in_logical->shape(), nd_sbp, ctx->parallel_desc(), ctx->parallel_ctx())) - == in_physical->shape()); + == in_physical.shape()); } for (const auto& pair : ctx->outputs()) { TensorDesc* desc = ctx->OutputTensorDesc(pair.first, pair.second); diff --git a/oneflow/core/framework/user_op_registry.h b/oneflow/core/framework/user_op_registry.h index fb3d69ecb56..036aa792acd 100644 --- a/oneflow/core/framework/user_op_registry.h +++ b/oneflow/core/framework/user_op_registry.h @@ -48,12 +48,13 @@ using SbpSignatureInferFn = std::function(InferSbpSignatureFnContext using InputArgModifier = InputBlobModifier; using GetInputArgModifier = std::function; -using InputArgModifyFn = std::function(GetInputArgModifier, const UserOpConfWrapper&)>; +using InputArgModifyFn = + std::function(const GetInputArgModifier&, const UserOpConfWrapper&)>; using OutputArgModifier = OutputBlobModifier; using GetOutputArgModifier = std::function; using OutputArgModifyFn = - std::function(GetOutputArgModifier, const UserOpConfWrapper&)>; + std::function(const GetOutputArgModifier&, const UserOpConfWrapper&)>; using OutputBlobTimeShapeInferFn = std::function(InferOutputBlobTimeShapeFnContext*)>; using NdSbpInferFn = std::function(InferNdSbpFnContext*)>; @@ -129,8 +130,7 @@ class OpRegistry final { OpRegistryResult GetResult() { return result_; } private: - OpRegistry& ArgImpl(bool is_input, const std::string& name, bool is_optional, int32_t num, - bool num_as_min); + OpRegistry& ArgImpl(bool is_input, const std::string& name, bool is_optional); OpRegistry& DefaultedAttr(const std::string& name, AttrType type, const std::function& SetDefault); diff --git a/oneflow/core/framework/user_op_tensor.h b/oneflow/core/framework/user_op_tensor.h index 
a7dca20412a..2853ad56f15 100644 --- a/oneflow/core/framework/user_op_tensor.h +++ b/oneflow/core/framework/user_op_tensor.h @@ -30,7 +30,13 @@ namespace user_op { class Tensor { public: - virtual ~Tensor() = default; +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnon-virtual-dtor" + // NOTE: Performance will be degraded if the destructor is virtual. + // So please do NOT implement custom destructor in any child classes of user_op::Tensor, + // and every fields of child classes should be of POD type. + ~Tensor() = default; +#pragma GCC diagnostic pop virtual const ShapeView& shape() const = 0; virtual MutShapeView* mut_shape() = 0; diff --git a/oneflow/core/functional/function_library.h b/oneflow/core/functional/function_library.h index 6dd8148ea4e..addaacc01e4 100644 --- a/oneflow/core/functional/function_library.h +++ b/oneflow/core/functional/function_library.h @@ -30,10 +30,10 @@ class FunctionLibrary { virtual ~FunctionLibrary() = default; template - struct PackedFuncMap; + struct PackedFuncCreatorMap; template - struct PackedFuncMap { + struct PackedFuncCreatorMap { using FunctorCreator = typename std::function()>; static HashMap* Get() { @@ -42,14 +42,18 @@ class FunctionLibrary { } }; + template + void add_functor(const std::string& func_name, const Func& func) { + using func_type = typename function_traits::func_type; + add_functor_creator( + func_name, [=]() { return PackedFunctorMaker::make(func_name, func); }); + } + template void add_one_functor(const std::string& func_name) { using func_type = typename function_traits::func_type; - using FType = typename PackedFunctorMaker::FType; - auto* functors = PackedFuncMap::Get(); - CHECK_EQ(functors->count(func_name), 0) - << "The functor with name " << func_name << " has been registered more than once."; - functors->emplace(func_name, [func_name]() -> PackedFunctor { + add_functor_creator(func_name, [=]() { + // Lazily construct functor since ops maybe have not been registered. 
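// (Aside, not part of this patch.) Registration now only stores a creator in
// PackedFuncCreatorMap; the creator appears to run only when the functor is looked up via
// find(), so a functor whose constructor calls OpBuilder(...) no longer requires its user op
// to be registered before the function library is populated. A minimal sketch of the two
// registration forms (ChunkFunctor is from this diff, the callable form is illustrative):
//   m.add_functor<impl::ChunkFunctor>("Chunk");
//   m.add_functor("Chunk", some_callable);  // new overload defined above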
Func func; return PackedFunctorMaker::make(func_name, func); }); @@ -58,15 +62,13 @@ class FunctionLibrary { template void add_functor(const std::string& func_name) { static_assert(sizeof...(Fs) > 0, "at least one functor is expected"); - __attribute__((__unused__)) int dummy[] = {(add_one_functor(func_name), 0)...}; } template auto find(const std::string& func_name) -> Maybe::FType>> { - using FType = typename PackedFunctorMaker::FType; - auto* functors = PackedFuncMap::Get(); + auto* functors = PackedFuncCreatorMap::FType>::Get(); const auto& it = functors->find(func_name); CHECK_OR_RETURN(it != functors->end()) << "Functor was not found for \"" << func_name @@ -81,6 +83,15 @@ class FunctionLibrary { private: FunctionLibrary() = default; + + template + void add_functor_creator(const std::string& func_name, Creator creator) { + using func_type = typename function_traits::func_type; + auto* functors = PackedFuncCreatorMap::FType>::Get(); + CHECK_EQ(functors->count(func_name), 0) + << "The functor with name " << func_name << " has been registered more than once."; + functors->emplace(func_name, creator); + } }; #define ONEFLOW_FUNCTION_LIBRARY(m) ONEFLOW_FUNCTION_LIBRARY_IMPL(m, __COUNTER__) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 533cb067619..8018ef75309 100755 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -23,10 +23,10 @@ - name: "add" signature: [ - "Tensor (TensorTuple inputs, *, Bool inplace=False) => Add", "Tensor (Tensor input, Tensor other, *, Scalar alpha=1, Bool inplace=False) => Add", "Tensor (Tensor input, Scalar other, *, Scalar alpha=1, Bool inplace=False) => ScalarAdd", "Tensor (Scalar input, Tensor other, *, Scalar alpha=1) => ScalarAdd", + "Tensor (TensorTuple inputs, *, Bool inplace=False) => Add", ] bind_python: true @@ -357,6 +357,14 @@ signature: "Tensor (Tensor x, Tensor dy) => LogGrad" bind_python: False +- name: "log2" + signature: "Tensor (Tensor x) => Log2" + bind_python: True + +- name: "log2_grad" + signature: "Tensor (Tensor x, Tensor dy) => Log2Grad" + bind_python: False + - name: "sqrt" signature: "Tensor (Tensor x) => Sqrt" bind_python: True @@ -381,6 +389,10 @@ signature: "Tensor (Tensor x, Tensor dy) => SquareGrad" bind_python: False +- name: "sqrt_square_sum" + signature: "Tensor (Tensor x) => SqrtSquareSum" + bind_python: True + - name: "std" signature: "Tensor (Tensor x, Int32List[1] dim=None, Bool unbiased=None, Bool keepdim=None) => StandardDeviation" bind_python: True @@ -521,32 +533,27 @@ bind_python: False - name: "eye" - signature: - [ - "Tensor (Scalar n, Scalar m=None, *, DataType dtype=None, Device device=None) => Eye", - ] - bind_python: True - -- name: "consistent_eye" - signature: - [ - "Tensor (Scalar n, Scalar m=None, *, DataType dtype=None, Placement placement, SbpList sbp) => ConsistentEye", + signature: [ + "Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, Device device=None, Bool requires_grad=False) => Eye", + "Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, String device, Bool requires_grad=False) => Eye", + "Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, Bool requires_grad=False, Placement placement, SbpList sbp) => Eye", + "Tensor (Scalar n, Scalar m=None, *, DataType dtype=kFloat, Bool requires_grad=False, Placement placement, Sbp sbp) => Eye", ] bind_python: True - name: "arange" signature: [ - "Tensor (Scalar start, Scalar end, Scalar step=1, *, DataType dtype=kInt64, + 
"Tensor (Scalar start, Scalar end, Scalar step=1, *, DataType dtype=None, Device device=None) => Arange", - "Tensor (Scalar end, *, DataType dtype=kInt64, Device device=None) => Arange", + "Tensor (Scalar end, *, DataType dtype=None, Device device=None) => Arange", ] bind_python: True - name: "consistent_arange" signature: [ - "Tensor (Scalar start, Scalar end, Scalar step=1, *, DataType dtype=kInt64, + "Tensor (Scalar start, Scalar end, Scalar step=1, *, DataType dtype=None, Placement placement, SbpList sbp) => ConsistentArange", - "Tensor (Scalar end, *, DataType dtype=kInt64, Placement placement, SbpList sbp) => ConsistentArange", + "Tensor (Scalar end, *, DataType dtype=None, Placement placement, SbpList sbp) => ConsistentArange", ] bind_python: True @@ -626,13 +633,22 @@ - name: "conv1d" signature: "Tensor (Tensor x, Tensor weight, Tensor bias=None, Int32List stride, - Int32List padding, Int32List dilation, Int32 groups=1) => Conv1d" + Int32List padding, Int32List dilation, Int32 groups=1, + String channel_pos) => Conv1d" bind_python: True - name: "conv2d" signature: "Tensor (Tensor x, Tensor weight, Tensor bias=None, Int32List stride, - Int32List padding, Int32List dilation, Int32 groups=1) => Conv2d" + Int32List padding, Int32List dilation, Int32 groups=1, + String channel_pos) => Conv2d" + bind_python: True + +- name: "conv3d" + signature: + "Tensor (Tensor x, Tensor weight, Tensor bias=None, Int32List stride, + Int32List padding, Int32List dilation, Int32 groups=1, + String channel_pos) => Conv3d" bind_python: True - name: "fake_quantization" @@ -660,12 +676,6 @@ Int32 quantization_bit, String quantization_scheme, Float momentum) => MovingAverageMinMaxObserver" bind_python: True -- name: "conv3d" - signature: - "Tensor (Tensor x, Tensor weight, Tensor bias=None, Int32List stride, - Int32List padding, Int32List dilation, Int32 groups=1) => Conv3d" - bind_python: True - - name: "conv_data_grad" signature: 'Tensor (Tensor dy, Tensor weight, Tensor x, Int32 num_spatial_dims, @@ -693,6 +703,13 @@ Int32List output_padding, Int32List strides, Int32List dilation, Int32 groups=1) => Deconv1d" bind_python: True +- name: "deconv2d" + signature: + "Tensor (Tensor x, Tensor weight, Tensor bias=None, Int32 filters, + Int32List padding, String data_format, Int32List kernel_size, + Int32List output_padding, Int32List strides, Int32List dilation, Int32 groups=1) => Deconv2d" + bind_python: True + - name: "deconv3d" signature: "Tensor (Tensor x, Tensor weight, Tensor bias=None, Int32 filters, @@ -944,14 +961,14 @@ signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_params_axis, Double epsilon) => LayerNormParamGrad" bind_python: False -- name: "avg_pool_2d" +- name: "avg_pool2d_nhwc" signature: 'Tensor (Tensor x, Int32List kernel_size, Int32List stride, String padding, Int32List padding_before, Int32List padding_after, String data_format="channels_first", Bool ceil_mode=False) => AvgPool2D' bind_python: True -- name: "max_pool_2d" +- name: "max_pool2d_nhwc" signature: 'Tensor (Tensor x, Int32List kernel_size, Int32List stride, String padding, Int32List padding_before, Int32List padding_after, @@ -1034,7 +1051,7 @@ bind_python: True - name: "slice_grad" - signature: "Tensor (Tensor dy, Tensor like, Int64List start, Int64List stop, Int64List step) => SliceGrad" + signature: "Tensor (Tensor dy, Shape like, Int64List start, Int64List stop, Int64List step) => SliceGrad" bind_python: False - name: "narrow" @@ -1456,6 +1473,14 @@ signature: "Tensor (Tensor dy, 
Tensor in, Int32 diagonal=0) => DiagGrad" bind_python: False +- name: "diagonal" + signature: "Tensor (Tensor x, Int32 offset=0, Int32 dim1=0, Int32 dim2=1) => Diagonal" + bind_python: True + +- name: "diagonal_grad" + signature: "Tensor (Tensor dy, Tensor in, Int32 offset=0) => DiagonalGrad" + bind_python: False + - name: "tensor_getitem" signature: "Tensor (Tensor x, TensorIndex index) => TensorGetItem" bind_python: True @@ -1688,12 +1713,22 @@ ] bind_python: True +- name: "chunk" + signature: [ + "TensorTuple (Tensor x, Int64 chunks, Int64 dim=0) => Chunk", + ] + bind_python: True + - name: "split_like" signature: "TensorTuple (Tensor x, TensorTuple like, Int64 axis) => SplitLike" bind_python: True +- name: "normalize" + signature: "Tensor (Tensor input, Float p=2.0, Int32 dim=1, Float eps=1e-12) => Normalize" + bind_python: True + - name: "l2_normalize" - signature: "TensorTuple (Tensor input, Int32 axis, Float epsilon) => L2Normalize" + signature: "Tensor (Tensor input, Int32 axis=0, Float epsilon=1e-12) => L2Normalize" bind_python: True - name: "l2_normalize_grad" @@ -1799,7 +1834,7 @@ bind_python: False - name: "meshgrid" - signature: "TensorTuple (TensorTuple tensors) => Meshgrid" + signature: 'TensorTuple (TensorTuple tensors, String indexing="ij") => Meshgrid' bind_python: True - name: "decode_onerec" @@ -1814,3 +1849,30 @@ signature: "Tensor (Tensor input, Tensor other) => Dot" bind_python: True +- name: "tensor_buffer_to_tensor" + signature: "Tensor (Tensor input, Shape instance_shape, DataType dtype) => TensorBufferToTensor" + bind_python: True + +- name: "tensor_to_tensor_buffer" + signature: "Tensor (Tensor input, Int32 instance_dims) => TensorToTensorBuffer" + bind_python: True + +- name: "gen_tensor_buffer" + signature: "Tensor (Shape shape, ShapeList shape_list, FloatList value_list, DataType data_type, Bool dynamic_out) => GenTensorBuffer" + bind_python: True + +- name: "top_k" + signature: "Tensor (Tensor input, Int32 k, Bool sorted=True) => TopK" + bind_python: True + +- name: "in_top_k" + signature: "Tensor (Tensor targets, Tensor predictions, Int32 k) => InTopK" + bind_python: True + +- name: "cumsum" + signature: "Tensor (Tensor input, Int64 dim) => Cumsum" + bind_python: True + +- name: "cumsum_grad" + signature: "Tensor (Tensor input, Int64 dim) => CumsumGrad" + bind_python: False diff --git a/oneflow/core/functional/impl/activation_functor.cpp b/oneflow/core/functional/impl/activation_functor.cpp index 122f0f98d81..f00f9777471 100644 --- a/oneflow/core/functional/impl/activation_functor.cpp +++ b/oneflow/core/functional/impl/activation_functor.cpp @@ -36,9 +36,7 @@ namespace impl { class ReluFunctor { public: - ReluFunctor() { - op_ = CHECK_JUST(one::OpBuilder("relu").Input("in", 1).Output("out", 1).Build()); - } + ReluFunctor() { op_ = CHECK_JUST(one::OpBuilder("relu").Input("x", 1).Output("y", 1).Build()); } Maybe operator()(const std::shared_ptr& x, bool inplace) const { if (inplace) { JUST(CheckInplaceValid(x)); @@ -62,16 +60,6 @@ class ReluGradFunctor : public BinaryFunctor { } }; -namespace { -Maybe CheckPReLUParametersValid(const std::shared_ptr& x, - const std::shared_ptr& alpha) { - int num_params = alpha->dim(0); - CHECK_OR_RETURN(((num_params == 1) || (num_params == x->shape()->At(1)))) - << "num_parameters in prelu must be 1 or " << x->shape()->At(1); - return Maybe::Ok(); -} -} // namespace - class PReluFunctor { public: PReluFunctor() { @@ -80,7 +68,9 @@ class PReluFunctor { Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& alpha) 
const { - JUST(CheckPReLUParametersValid(x, alpha)); + int num_params = alpha->dim(0); + CHECK_OR_RETURN(((num_params == 1) || (num_params == x->shape()->At(1)))) + << "num_parameters in prelu must be 1 or " << x->shape()->At(1); return OpInterpUtil::Dispatch(*op_, {x, alpha}); } diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index d7f8504f77a..fc7c47bdc3d 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/autograd/autograd_mode.h" +#include "oneflow/core/common/maybe.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/common/global.h" #include "oneflow/core/common/optional.h" @@ -130,6 +131,7 @@ class ConsistentConstantFunctor { Maybe operator()(const Shape& shape, const Scalar& value, const Symbol& dtype, const Symbol& placement, const std::vector>& sbp_tuple) const { + JUST(CheckDeviceIdsIsValid(placement)); MutableAttrMap attrs; JUST(attrs.SetAttr("shape", shape)); JUST(attrs.SetAttr("dtype", dtype->data_type())); @@ -210,6 +212,7 @@ class ConsistentEmptyFunctor { Maybe operator()(const Shape& shape, const Symbol& dtype, const Symbol& placement, const std::vector>& sbp_tuple) const { + JUST(CheckDeviceIdsIsValid(placement)); MutableAttrMap attrs; JUST(attrs.SetAttr("shape", shape)); JUST(attrs.SetAttr("dtype", dtype->data_type())); @@ -399,6 +402,23 @@ class BroadcastLikeFunctor { const std::shared_ptr& like, const std::vector& broadcast_axes) const { MutableAttrMap attrs; + if (broadcast_axes.empty()) { + int64_t like_ndim = like->shape()->NumAxes(); + int64_t x_ndim = x->shape()->NumAxes(); + int64_t num_prepend = like_ndim - x_ndim; + std::vector prepend_shape(num_prepend, 1); + std::vector broadcast_axes; + for (int i = 0; i < x_ndim; ++i) { prepend_shape.emplace_back(x->shape()->At(i)); } + for (int i = 0; i < num_prepend; ++i) { broadcast_axes.emplace_back(i); } + for (int i = num_prepend; i < prepend_shape.size(); ++i) { + if (prepend_shape[i] != like->shape()->At(i)) { + if (prepend_shape[i] == 1) { broadcast_axes.emplace_back(i); } + CHECK_GE_OR_RETURN(prepend_shape[i], 1) + << "output with shape " << x->shape()->ToString() + << " doesn't match the broadcast shape " << like->shape()->ToString(); + } + } + } JUST(attrs.SetAttr>("broadcast_axes", broadcast_axes)); return OpInterpUtil::Dispatch(*op_, {x, like}, attrs); } @@ -411,16 +431,16 @@ class ConcatFunctor { public: ConcatFunctor() { ops_.resize(kMaxInputCount); - for (int n = 1; n < ops_.size(); ++n) { + for (int n = 0; n < ops_.size(); ++n) { ops_[n] = CHECK_JUST(one::OpBuilder("concat").Input("in", n + 1).Output("out").Build()); } } Maybe operator()(const TensorTuple& inputs, const int64_t& dim) const { - if (inputs.size() == 1) { return inputs.at(0); } + const int64_t ninput = inputs.size(); int64_t axis = dim; int64_t ndim = inputs[0]->ndim(); int64_t max_dim_size = 0; - CHECK_GE_OR_RETURN(inputs.size(), 2); + CHECK_GE_OR_RETURN(ninput, 1); CHECK_OR_RETURN((-(ndim) <= dim) && (dim <= (ndim - 1))) << " IndexError: Dimension out of range, expected to be in range of [" << -ndim << ", " << ndim - 1 << "], but got " << dim; @@ -445,13 +465,14 @@ class ConcatFunctor { JUST(attrs.SetAttr("axis", axis)); JUST(attrs.SetAttr("max_dim_size", max_dim_size)); TensorTuple outputs; - for (int i = 0; i < inputs.size(); i += kMaxInputCount) { - size_t size = (i + kMaxInputCount) < inputs.size() ? 
kMaxInputCount : inputs.size() - i; + for (int i = 0; i < ninput; i += kMaxInputCount) { + size_t size = (i + kMaxInputCount) < ninput ? kMaxInputCount : ninput - i; TensorTuple partial_inputs(size); for (int j = 0; j < size; ++j) { partial_inputs[j] = inputs[i + j]; } outputs.emplace_back( JUST(OpInterpUtil::Dispatch(*ops_.at(size - 1), partial_inputs, attrs))); } + if (outputs.size() == 1) { return outputs.at(0); } return this->operator()(outputs, axis); } @@ -890,11 +911,13 @@ class ReshapeFunctor { } Maybe operator()(const std::shared_ptr& x, const Shape& shape) const { // if input tensor is eager local, than return tensor's view - if (x->is_eager() && x->is_local()) { return view::Reshape(x, shape); } + if (x->is_local() && !(LazyMode::is_enabled())) { return view::Reshape(x, shape); } int need_infer_axis = -1; size_t count = 1; for (int i = 0; i < shape.NumAxes(); ++i) { - if (shape.At(i) == -1) { + if (shape.At(i) < -1) { + return Error::RuntimeError() << "Invalid shape dimension " << shape.At(i); + } else if (shape.At(i) == -1) { CHECK_EQ_OR_RETURN(need_infer_axis, -1) << "Shape " << shape.ToString() << " has more than 1 axis that needs to be infered."; need_infer_axis = i; @@ -946,15 +969,15 @@ class SliceGradBaseFunctor { public: SliceGradBaseFunctor() = default; virtual ~SliceGradBaseFunctor() = default; - Maybe operator()(const std::shared_ptr& dy, - const std::shared_ptr& like, + Maybe operator()(const std::shared_ptr& dy, const Shape& like, const std::vector& start, const std::vector& stop, const std::vector& step) const { MutableAttrMap attrs; + JUST(attrs.SetAttr("like_shape", like)); JUST(attrs.SetAttr>("start", start)); JUST(attrs.SetAttr>("stop", stop)); JUST(attrs.SetAttr>("step", step)); - return OpInterpUtil::Dispatch(*op_, {dy, like}, attrs); + return OpInterpUtil::Dispatch(*op_, {dy}, attrs); } protected: @@ -969,7 +992,7 @@ class SliceFunctor : public SliceBaseFunctor { class SliceGradFunctor : public SliceGradBaseFunctor { public: SliceGradFunctor() { - op_ = CHECK_JUST(one::OpBuilder("slice_grad").Input("dy").Input("like").Output("dx").Build()); + op_ = CHECK_JUST(one::OpBuilder("slice_grad").Input("dy").Output("dx").Build()); } }; @@ -1541,6 +1564,8 @@ class TrilFunctor { Maybe operator()(const std::shared_ptr& x, const int64_t& diagonal) const { MutableAttrMap attrs; JUST(attrs.SetAttr("diagonal", diagonal)); + JUST(attrs.SetAttr("is_floating_fill_value", false)); + JUST(attrs.SetAttr("integer_fill_value", 0)); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } @@ -1590,6 +1615,63 @@ class DiagGradFunctor { std::shared_ptr op_; }; +class DiagonalFunctor { + public: + DiagonalFunctor() { + op_ = CHECK_JUST(one::OpBuilder("diagonal").Input("in").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, const int32_t& offset, + const int32_t& dim1, const int32_t& dim2) const { + int64_t ndims = x->shape()->NumAxes(); + + CHECK_GE_OR_RETURN(dim1, -ndims) + << ", Dimension out of range (expected to be in range of [" << -ndims << ", " << ndims - 1 + << "], but got " << dim1 << ");"; + CHECK_LT_OR_RETURN(dim1, ndims) << ", Dimension out of range (expected to be in range of [" + << -ndims << ", " << ndims - 1 << "], but got " << dim1 << ");"; + CHECK_GE_OR_RETURN(dim2, -ndims) + << ", Dimension out of range (expected to be in range of [" << -ndims << ", " << ndims - 1 + << "], but got " << dim2 << ");"; + CHECK_LT_OR_RETURN(dim2, ndims) << ", Dimension out of range (expected to be in range of [" + << -ndims << ", " << ndims - 1 << "], but got " << dim2 
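// (Aside, not part of this patch.) After these range checks, DiagonalFunctor moves dim1/dim2 to
// the front via Transpose and then dispatches the "diagonal" op with the given offset; e.g. for
// a 3-d input with dim1=1 and dim2=2 the permutation built below is {1, 2, 0}. The matching
// functional_api.yaml entry is Diagonal(x, offset=0, dim1=0, dim2=1).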
<< ");"; + + int32_t p_dim1 = dim1 >= 0 ? dim1 : dim1 + ndims; + int32_t p_dim2 = dim2 >= 0 ? dim2 : dim2 + ndims; + CHECK_NE_OR_RETURN(p_dim1, p_dim2) + << ", diagonal dimensions cannot be identical " << dim1 << ", " << dim2; + + std::vector input_index{p_dim1, p_dim2}; + for (int32_t i = 0; i < ndims; i++) { + if (i != p_dim1 && i != p_dim2) { input_index.push_back(i); } + } + + std::shared_ptr d_x = JUST(Transpose(x, input_index)); + + MutableAttrMap attrs; + JUST(attrs.SetAttr("offset", offset)); + return OpInterpUtil::Dispatch(*op_, {d_x}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class DiagonalGradFunctor { + public: + DiagonalGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("diagonal_grad").Input("dy").Input("in").Output("dx").Build()); + } + Maybe operator()(const std::shared_ptr& dy, + const std::shared_ptr& x, const int32_t& offset) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("offset", offset)); + return OpInterpUtil::Dispatch(*op_, {dy, x}, attrs); + } + + private: + std::shared_ptr op_; +}; + class TensorGetItemFunctor { public: TensorGetItemFunctor() {} @@ -1938,6 +2020,52 @@ class SplitFunctor { } }; +class ChunkFunctor { + public: + ChunkFunctor() {} + Maybe operator()(const std::shared_ptr& x, const int64_t& chunks, + const int64_t& dim) const { + int64_t axis = dim; + if (axis < 0) { axis += x->ndim(); } + int64_t split_size = x->shape()->At(axis) / chunks; + CHECK_OR_RETURN(axis >= 0 && axis < x->ndim()) + << "Dimension out of range (expected to be in range of [" << -(x->ndim()) << ", " + << x->ndim() - 1 << "], but got " << dim; + int64_t dim_size = x->shape()->At(axis); + if ((split_size * chunks) != dim_size) { + std::vector sections; + for (int i = 0; i < chunks - 1; ++i) { sections.emplace_back(split_size); } + sections.emplace_back(dim_size - split_size * (chunks - 1)); + int64_t num_splits = sections.size(); + TensorTuple splits(num_splits); + int64_t start_idx = 0; + for (int i = 0; i < num_splits; ++i) { + int64_t length = sections[i]; + CHECK_GE_OR_RETURN(length, 0) << "split_with_sizes expects split_sizes have only " + "non-negative entries, but split_sizes[" + << i << "] = " << length; + splits[i] = JUST(Narrow(x, axis, start_idx, length)); + start_idx += length; + } + CHECK_EQ_OR_RETURN(start_idx, dim_size) + << "split_with_sizes expects split_sizes to sum exactly to " << dim_size + << " (input tensor's size at dimension " << axis << "), " + << "but got sum(split_sizes)=" << start_idx; + return splits; + } + CHECK_GE_OR_RETURN(split_size, 0) + << "split expects split_size be non-negative, but got split_size=" << split_size; + int64_t num_splits = std::max((dim_size + split_size - 1) / split_size, 1); + TensorTuple splits(num_splits); + int64_t last_split_size = split_size - (split_size * num_splits - dim_size); + for (int i = 0; i < num_splits; ++i) { + int64_t length = i < num_splits - 1 ? 
split_size : last_split_size; + splits[i] = JUST(Narrow(x, axis, i * split_size, length)); + } + return splits; + } +}; + class SplitLikeFunctor { public: SplitLikeFunctor() { @@ -2073,21 +2201,9 @@ class MaskedFillFunctor { class MeshgridFunctor { public: - Maybe operator()(const TensorTuple& tensors) const { + Maybe operator()(const TensorTuple& tensors, const std::string& indexing) const { int size = tensors.size(); CHECK_GT_OR_RETURN(size, 0) << "meshgrid expects a non-empty TensorList"; - DimVector shape_vec(size); - for (int i = 0; i < size; ++i) { - CHECK_LE_OR_RETURN(tensors[i]->shape()->NumAxes(), 1) - << "Expected scalar or 1D tensor in the tensor list but got: " - << tensors[i]->shape()->NumAxes(); - if (tensors[i]->shape()->NumAxes() == 0) { - shape_vec[i] = 1; - } else { - shape_vec[i] = tensors[i]->shape()->At(0); - } - } - Shape shape(shape_vec); for (int i = 0; i < size - 1; ++i) { CHECK_OR_RETURN( @@ -2095,16 +2211,46 @@ class MeshgridFunctor { && (JUST(tensors[i]->device())->type() == JUST(tensors[i + 1]->device())->type())) << "meshgrid expects all tensors to have the same dtype and device"; } - TensorTuple outputs(size); + + std::vector> tensor_consts(tensors.begin(), tensors.end()); + + bool swap_first_and_second_tensors = false; + if (indexing == "xy") { + swap_first_and_second_tensors = (size >= 2); + if (swap_first_and_second_tensors) { std::swap(tensor_consts[0], tensor_consts[1]); } + } else { + CHECK_EQ_OR_RETURN(indexing, "ij") + << "flow.meshgrid: indexing must be one of \"xy\" or \"ij\", " + "but received: ," + << indexing; + } + + TensorTuple grids(size); + DimVector grids_vec(size); + for (int i = 0; i < size; ++i) { + CHECK_LE_OR_RETURN(tensor_consts[i]->shape()->NumAxes(), 1) + << "Expected scalar or 1D tensor in the tensor list but got: " + << tensor_consts[i]->shape()->NumAxes(); + if (tensor_consts[i]->shape()->NumAxes() == 0) { + grids_vec[i] = 1; + } else { + grids_vec[i] = tensor_consts[i]->shape()->At(0); + } + } + Shape grids_shape(grids_vec); + + DimVector view_shape_vec(size, 1); + Shape view_shape(view_shape_vec); for (int i = 0; i < size; ++i) { - DimVector view_shape_vec(size, 1); - view_shape_vec[i] = -1; - Shape view_shape(view_shape_vec); - std::shared_ptr reshaped = JUST(Reshape(tensors.at(i), view_shape)); - outputs[i] = JUST(Expand(reshaped, shape)); + view_shape.Set(i, -1); + std::shared_ptr reshaped = JUST(Reshape(tensor_consts.at(i), view_shape)); + grids[i] = JUST(Expand(reshaped, grids_shape)); + view_shape.Set(i, 1); } - return outputs; + if (swap_first_and_second_tensors) { std::swap(grids[0], grids[1]); } + + return grids; } }; @@ -2239,6 +2385,94 @@ class To4Functor { } }; +class TopKFunctor { + public: + TopKFunctor() { op_ = CHECK_JUST(one::OpBuilder("top_k").Input("in").Output("out").Build()); } + Maybe operator()(const std::shared_ptr& input, int32_t k, bool sorted) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("k", k)); + JUST(attrs.SetAttr("sorted", sorted)); + return OpInterpUtil::Dispatch(*op_, {input}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class InTopKFunctor { + public: + InTopKFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("in_top_k").Input("targets").Input("predictions").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& targets, + const std::shared_ptr& predictions, int32_t k) const { + CHECK_EQ_OR_RETURN(targets->shape()->At(0), predictions->shape()->At(0)) + << "The num of targets must equal the num of predictions"; + CHECK_EQ_OR_RETURN(targets->ndim(), 1) << "The 
dimension of targets must be 1"; + CHECK_EQ_OR_RETURN(predictions->ndim(), 2) << "The dimension of predictions must be 2"; + MutableAttrMap attrs; + JUST(attrs.SetAttr("k", k)); + return OpInterpUtil::Dispatch(*op_, {targets, predictions}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class TensorBufferToTensorFunctor { + public: + TensorBufferToTensorFunctor() { + op_ = CHECK_JUST(one::OpBuilder("tensor_buffer_to_tensor").Input("in").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& input, const Shape& instance_shape, + const Symbol& dtype) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("instance_shape", instance_shape)); + JUST(attrs.SetAttr("dtype", dtype->data_type())); + return OpInterpUtil::Dispatch(*op_, {input}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class TensorToTensorBufferFunctor { + public: + TensorToTensorBufferFunctor() { + op_ = CHECK_JUST(one::OpBuilder("tensor_to_tensor_buffer").Input("in").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& input, int32_t instance_dims) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("instance_dims", instance_dims)); + return OpInterpUtil::Dispatch(*op_, {input}, attrs); + } + + private: + std::shared_ptr op_; +}; + +class GenTensorBufferFunctor { + public: + GenTensorBufferFunctor() { + op_ = CHECK_JUST(one::OpBuilder("gen_tensor_buffer").Output("out").Build()); + } + Maybe operator()(const Shape& shape, const std::vector& shape_list, + const std::vector& value_list, const Symbol& dtype, + bool dynamic_out) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("shape", shape)); + JUST(attrs.SetAttr>("shape_list", shape_list)); + JUST(attrs.SetAttr>("value_list", value_list)); + JUST(attrs.SetAttr("data_type", dtype->data_type())); + JUST(attrs.SetAttr("dynamic_out", dynamic_out)); + return OpInterpUtil::Dispatch(*op_, {}, attrs); + } + + private: + std::shared_ptr op_; +}; + } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { @@ -2306,6 +2540,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Triu"); m.add_functor("Diag"); m.add_functor("DiagGrad"); + m.add_functor("Diagonal"); + m.add_functor("DiagonalGrad"); m.add_functor("TensorGetItem"); m.add_functor("DimScatter"); m.add_functor("DimScatterAdd"); @@ -2326,6 +2562,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("ReduceSumLike"); m.add_functor("BroadcastReduceSumLike"); m.add_functor("Split"); + m.add_functor("Chunk"); m.add_functor("SplitLike"); m.add_functor("SplitWithSize"); m.add_functor("BatchGather"); @@ -2333,6 +2570,11 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("MaskedFill"); m.add_functor("Meshgrid"); m.add_functor("To"); + m.add_functor("TopK"); + m.add_functor("InTopK"); + m.add_functor("TensorToTensorBuffer"); + m.add_functor("TensorBufferToTensor"); + m.add_functor("GenTensorBuffer"); }; } // namespace functional diff --git a/oneflow/core/functional/impl/consistent_cast.cpp b/oneflow/core/functional/impl/consistent_cast.cpp index c331ebdc318..09316e50598 100644 --- a/oneflow/core/functional/impl/consistent_cast.cpp +++ b/oneflow/core/functional/impl/consistent_cast.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ +#include "oneflow/core/framework/consistency_check.h" #include "oneflow/core/functional/function_library.h" #include "oneflow/core/framework/id_util.h" #include "oneflow/core/framework/tensor.h" @@ -227,6 +228,10 @@ Maybe ConsistentToConsistent( CHECK_NOTNULL_OR_RETURN(consistent_tensor) << "consistent tensors supported only"; const auto& op = JUST(GetConsistentToConsistentOpExpr(grad_sbp_parallels)); const auto& nd_sbp = JUST(GetNdSbp(sbp_parallels)); + if (!LazyMode::is_enabled() && JUST(x->nd_sbp()) == nd_sbp + && JUST(x->parallel_desc()) == parallel_desc && grad_sbp_parallels.size() == 0) { + return x; + } const auto& tensor = JUST(OpInterpUtil::Dispatch( *op, {consistent_tensor}, OpExprInterpContext(AttrMap{}, parallel_desc, nd_sbp))); if (!LazyMode::is_enabled() && tensor != x && !IsConsistentTensorMetaCheckDisabled()) { @@ -290,6 +295,9 @@ class LocalToConsistentFunctor { Symbol parallel_desc, const std::vector>& sbp_parallels, const Shape& shape, const Symbol& dtype) const { + JUST(CheckDeviceIdsIsValid(parallel_desc)); + NonRecursiveMetaInfoConsistencyCheckScope no_recursive_meta_info_conisitency_check_scope; + JUST(MetaInfoConsistencyCheck(parallel_desc, sbp_parallels)); CHECK_OR_RETURN(x->is_local()); std::shared_ptr input = x; // copy to right device first if input's device type is wrong @@ -332,6 +340,9 @@ class ToConsistentFunctor { Symbol parallel_desc, const std::vector>& sbp_parallels, const std::vector>& grad_sbp_parallels) const { + JUST(CheckDeviceIdsIsValid(parallel_desc)); + NonRecursiveMetaInfoConsistencyCheckScope scope; + JUST(MetaInfoConsistencyCheck(parallel_desc, sbp_parallels, grad_sbp_parallels)); std::shared_ptr tensor; if (x->is_consistent()) { tensor = JUST(ConsistentToConsistent(x, parallel_desc, sbp_parallels, grad_sbp_parallels)); diff --git a/oneflow/core/functional/impl/dataset_functor.cpp b/oneflow/core/functional/impl/dataset_functor.cpp index 096f58d566d..5a347c958ef 100644 --- a/oneflow/core/functional/impl/dataset_functor.cpp +++ b/oneflow/core/functional/impl/dataset_functor.cpp @@ -106,6 +106,7 @@ class ReadOneRecFunctor { JUST(attrs.SetAttr("verify_example", verify_example)); if (placement.has_value()) { + JUST(CheckDeviceIdsIsValid(JUST(placement))); CHECK_OR_RETURN(sbp.has_value()) << "placement is not None, but sbp is None. It's not allowed."; AttrMap attrmap(attrs); diff --git a/oneflow/core/functional/impl/eye_functor.cpp b/oneflow/core/functional/impl/eye_functor.cpp new file mode 100644 index 00000000000..4de53d1be8f --- /dev/null +++ b/oneflow/core/functional/impl/eye_functor.cpp @@ -0,0 +1,137 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include "oneflow/core/common/just.h" +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/scalar.h" +#include "oneflow/core/common/throw.h" +#include "oneflow/core/common/util.h" +#include "oneflow/core/framework/device.h" +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/nd_sbp.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/functional/functional.h" +#include "oneflow/core/functional/function_library.h" +#include "oneflow/core/functional/functional_api.yaml.h" +#include "oneflow/core/functional/impl/common.h" +#include "oneflow/core/job/lazy_mode.h" +#include "oneflow/core/job/sbp_parallel.h" + +namespace oneflow { +namespace one { +namespace functional { + +namespace impl { + +class EyeDevcieFunctor { + public: + EyeDevcieFunctor() { op_ = CHECK_JUST(one::OpBuilder("eye").Output("out").Build()); } + Maybe operator()(const Scalar& rows, const Optional& cols, + const Symbol& dtype, const Optional>& device, + const bool& requires_grad) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("rows", JUST(rows.As()))); + JUST(attrs.SetAttr("cols", JUST(cols.value_or(rows).As()))); + JUST(attrs.SetAttr("dtype", dtype->data_type())); + OpExprInterpContext ctx(attrs); + ctx.device = device; + auto res = JUST(OpInterpUtil::Dispatch(*op_, {}, ctx)); + JUST(res->set_requires_grad(requires_grad)); + return res; + } + + private: + std::shared_ptr op_; +}; + +class EyeDeviceStrFunctor { + public: + Maybe operator()(const Scalar& rows, const Optional& cols, + const Symbol& dtype, const std::string& device, + const bool& requires_grad) const { + const Symbol& dev = JUST(Device::ParseAndNew(device)); + return JUST(functional::Eye(rows, cols, dtype, dev, requires_grad)); + } +}; + +class ConsistentEyeSbpListFunctor { + public: + ConsistentEyeSbpListFunctor() { op_ = CHECK_JUST(one::OpBuilder("eye").Output("out").Build()); } + Maybe operator()(const Scalar& rows, const Optional& cols, + const Symbol& dtype, const bool& requires_grad, + const Symbol& placement, + const std::vector>& sbp_tuple) const { + MutableAttrMap attrs; + CHECK_EQ_OR_RETURN(sbp_tuple.size(), placement->hierarchy()->NumAxes()) + << "len(sbp) == len(placement.hierarchy) required, but " + << "len(sbp)==" << sbp_tuple.size() << ", " + << "len(placement.hierarchy)==" << placement->hierarchy()->NumAxes(); + + FOR_RANGE(int32_t, i, 0, sbp_tuple.size()) { + CHECK_OR_RETURN(sbp_tuple.at(i)->has_broadcast_parallel()) + << "sbp of eye should be broadcast only"; + } + + JUST(attrs.SetAttr("rows", JUST(rows.As()))); + JUST(attrs.SetAttr("cols", JUST(cols.value_or(rows).As()))); + JUST(attrs.SetAttr("dtype", dtype->data_type())); + if (LazyMode::is_enabled()) { + std::vector nd_sbp(sbp_tuple.size()); + { + for (int i = 0; i < sbp_tuple.size(); ++i) { + nd_sbp.at(i) = SbpParallelToString(*sbp_tuple.at(i)); + } + } + JUST(attrs.SetAttr>("nd_sbp", nd_sbp)); + } + const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); + auto res = JUST( + OpInterpUtil::Dispatch(*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp))); + JUST(res->set_requires_grad(requires_grad)); + return res; + } + + private: + std::shared_ptr op_; +}; + +class ConsistentEyeSbpFunctor { + public: + Maybe operator()(const Scalar& rows, const Optional& cols, + const Symbol& dtype, const bool& requires_grad, 
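// (Aside, not part of this patch.) Together with EyeDevcieFunctor and EyeDeviceStrFunctor above,
// the two consistent functors cover the four "Eye" signatures added to functional_api.yaml:
// Device symbol, device string, placement + sbp list, and placement + single sbp (a thin wrapper
// over the list form). The consistent variants require every sbp to be broadcast and
// len(sbp) == len(placement.hierarchy), and requires_grad is applied to the returned tensor.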
+ const Symbol& placement, + const Symbol& sbp) const { + std::vector> sbp_tuple{sbp}; + return JUST(functional::Eye(rows, cols, dtype, requires_grad, placement, sbp_tuple)); + } +}; + +} // namespace impl + +using namespace impl; + +ONEFLOW_FUNCTION_LIBRARY(m) { + m.add_functor("Eye"); +}; + +} // namespace functional +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp index a21b10e6955..40c6dcd2503 100644 --- a/oneflow/core/functional/impl/math_functor.cpp +++ b/oneflow/core/functional/impl/math_functor.cpp @@ -575,67 +575,21 @@ class TransposeFunctor { Maybe operator()(const std::shared_ptr& input, const std::vector& permute) const { MutableAttrMap attrs; - CHECK_EQ_OR_RETURN(input->ndim(), permute.size()) << "number of dims don't match in permute"; - JUST(attrs.SetAttr>("perm", permute)); - int32_t ndims = input->shape()->NumAxes(); - for (int i = 0; i < permute.size(); i++) { - int32_t dim = permute.at(i); - if (dim < 0) { dim += ndims; } - CHECK_GE_OR_RETURN(dim, 0) - << "IndexError: Dimension out of range (expected to be in range of [" << -ndims << "," - << ndims << " ] but got " << ndims; - CHECK_LT_OR_RETURN(dim, ndims) - << "IndexError: Dimension out of range (expected to be in range of [" << -ndims << "," - << ndims << " ] but got " << ndims; + auto ndim = input->ndim(); + CHECK_EQ_OR_RETURN(ndim, permute.size()) << "number of dims don't match in permute"; + + // handle negative permute value here, because of permute is const, + // so copy it to local var and do modification. + auto positive_perm = permute; + for (auto i = 0; i < positive_perm.size(); i++) { + if (positive_perm[i] < 0) { positive_perm[i] += ndim; } + CHECK_OR_RETURN(positive_perm[i] >= 0 && positive_perm[i] < ndim) + << "IndexError: Dimension out of range (expected to be in range of [" << -ndim << "," + << ndim << " ) but got " << positive_perm[i]; } - return OpInterpUtil::Dispatch(*op_, {input}, attrs); - } - - private: - std::shared_ptr op_; -}; - -class EyeFunctor { - public: - EyeFunctor() { op_ = CHECK_JUST(one::OpBuilder("eye").Output("out").Build()); } - Maybe operator()(const Scalar& rows, const Optional& cols, - const Optional>& dtype, - const Optional>& device) const { - MutableAttrMap attrs; - JUST(attrs.SetAttr("rows", JUST(rows.As()))); - JUST(attrs.SetAttr("cols", JUST(cols.value_or(rows).As()))); - JUST(attrs.SetAttr("dtype", dtype ? JUST(dtype)->data_type() : DataType::kFloat)); - OpExprInterpContext ctx(attrs); - ctx.device = device; - return OpInterpUtil::Dispatch(*op_, {}, ctx); - } - - private: - std::shared_ptr op_; -}; -class ConsistentEyeFunctor { - public: - ConsistentEyeFunctor() { op_ = CHECK_JUST(one::OpBuilder("eye").Output("out").Build()); } - Maybe operator()(const Scalar& rows, const Optional& cols, - const Optional>& dtype, - const Symbol& placement, - const std::vector>& sbp_tuple) const { - MutableAttrMap attrs; - JUST(attrs.SetAttr("rows", JUST(rows.As()))); - JUST(attrs.SetAttr("cols", JUST(cols.value_or(rows).As()))); - JUST(attrs.SetAttr("dtype", dtype ? 
JUST(dtype)->data_type() : DataType::kFloat)); - if (LazyMode::is_enabled()) { - std::vector nd_sbp(sbp_tuple.size()); - { - for (int i = 0; i < sbp_tuple.size(); ++i) { - nd_sbp.at(i) = SbpParallelToString(*sbp_tuple.at(i)); - } - } - JUST(attrs.SetAttr>("nd_sbp", nd_sbp)); - } - const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); - return OpInterpUtil::Dispatch(*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp)); + JUST(attrs.SetAttr>("perm", positive_perm)); + return OpInterpUtil::Dispatch(*op_, {input}, attrs); } private: @@ -678,19 +632,34 @@ class ArangeFunctor { public: ArangeFunctor() { op_ = CHECK_JUST(one::OpBuilder("arange").Output("out").Build()); } Maybe operator()(const Scalar& start, const Scalar& limit, const Scalar& delta, - const Symbol& dtype, + const Optional>& dtype, const Optional>& device) const { MutableAttrMap attrs; - const DataType range_dtype = dtype->data_type(); - JUST(attrs.SetAttr("dtype", range_dtype)); - if (IsIntegralDataType(range_dtype)) { - JUST(attrs.SetAttr("integer_start", JUST(start.As()))); - JUST(attrs.SetAttr("integer_limit", JUST(limit.As()))); - JUST(attrs.SetAttr("integer_delta", JUST(delta.As()))); + if (dtype.has_value()) { + const DataType range_dtype = JUST(dtype)->data_type(); + if (IsIntegralDataType(range_dtype)) { + JUST(attrs.SetAttr("integer_start", JUST(start.As()))); + JUST(attrs.SetAttr("integer_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("integer_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", range_dtype)); + } else { + JUST(attrs.SetAttr("float_start", JUST(start.As()))); + JUST(attrs.SetAttr("float_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("float_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", range_dtype)); + } } else { - JUST(attrs.SetAttr("float_start", JUST(start.As()))); - JUST(attrs.SetAttr("float_limit", JUST(limit.As()))); - JUST(attrs.SetAttr("float_delta", JUST(delta.As()))); + if (delta.IsIntegral()) { + JUST(attrs.SetAttr("integer_start", JUST(start.As()))); + JUST(attrs.SetAttr("integer_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("integer_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", DType::Int64()->data_type())); + } else { + JUST(attrs.SetAttr("float_start", JUST(start.As()))); + JUST(attrs.SetAttr("float_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("float_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", DType::Float()->data_type())); + } } OpExprInterpContext ctx(attrs); ctx.device = device; @@ -703,7 +672,7 @@ class ArangeFunctor { class Arange2Functor { public: - Maybe operator()(const Scalar& limit, const Symbol& dtype, + Maybe operator()(const Scalar& limit, const Optional>& dtype, const Optional>& device) const { return Arange(Scalar(0), limit, Scalar(1), dtype, device); } @@ -713,21 +682,37 @@ class ConsistentArangeFunctor { public: ConsistentArangeFunctor() { op_ = CHECK_JUST(one::OpBuilder("arange").Output("out").Build()); } Maybe operator()(const Scalar& start, const Scalar& limit, const Scalar& delta, - const Symbol& dtype, const Symbol& placement, + const Optional>& dtype, + const Symbol& placement, const std::vector>& sbp_tuple) const { + JUST(CheckDeviceIdsIsValid(placement)); MutableAttrMap attrs; - const DataType range_dtype = dtype->data_type(); - JUST(attrs.SetAttr("dtype", range_dtype)); - if (IsIntegralDataType(range_dtype)) { - JUST(attrs.SetAttr("integer_start", JUST(start.As()))); - JUST(attrs.SetAttr("integer_limit", JUST(limit.As()))); - JUST(attrs.SetAttr("integer_delta", JUST(delta.As()))); + if (dtype.has_value()) { + 
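// (Aside, not part of this patch.) dtype is now Optional and the yaml default changed from
// kInt64 to None. An explicit dtype is forwarded as before; otherwise the else-branch below
// derives the result type from the step: an integral delta yields kInt64 and a floating-point
// delta yields kFloat, e.g. arange(0, 5, 1) -> int64 while arange(0, 1, 0.5) -> float32.
// ConsistentArangeFunctor repeats the same rule further down.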
const DataType range_dtype = JUST(dtype)->data_type(); + if (IsIntegralDataType(range_dtype)) { + JUST(attrs.SetAttr("integer_start", JUST(start.As()))); + JUST(attrs.SetAttr("integer_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("integer_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", range_dtype)); + } else { + JUST(attrs.SetAttr("float_start", JUST(start.As()))); + JUST(attrs.SetAttr("float_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("float_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", range_dtype)); + } } else { - JUST(attrs.SetAttr("float_start", JUST(start.As()))); - JUST(attrs.SetAttr("float_limit", JUST(limit.As()))); - JUST(attrs.SetAttr("float_delta", JUST(delta.As()))); + if (delta.IsIntegral()) { + JUST(attrs.SetAttr("integer_start", JUST(start.As()))); + JUST(attrs.SetAttr("integer_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("integer_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", DType::Int64()->data_type())); + } else { + JUST(attrs.SetAttr("float_start", JUST(start.As()))); + JUST(attrs.SetAttr("float_limit", JUST(limit.As()))); + JUST(attrs.SetAttr("float_delta", JUST(delta.As()))); + JUST(attrs.SetAttr("dtype", DType::Float()->data_type())); + } } - if (LazyMode::is_enabled()) { std::vector nd_sbp(sbp_tuple.size()); { @@ -750,6 +735,7 @@ class ConsistentArange2Functor { Maybe operator()(const Scalar& limit, const Symbol& dtype, const Symbol& placement, const std::vector>& sbp_tuple) const { + JUST(CheckDeviceIdsIsValid(placement)); return ConsistentArange(Scalar(0), limit, Scalar(1), dtype, placement, sbp_tuple); } }; @@ -824,6 +810,19 @@ class ClampFunctor { std::shared_ptr clip_max_op_; }; +class SqrtSquareSumFunctor { + public: + SqrtSquareSumFunctor() { + op_ = CHECK_JUST(one::OpBuilder("sqrt_square_sum").Input("x").Output("y").Build()); + } + Maybe operator()(const std::shared_ptr& x) const { + return OpInterpUtil::Dispatch(*op_, {x}, {}); + } + + private: + std::shared_ptr op_; +}; + class VectorNormFunctor { public: VectorNormFunctor() {} @@ -848,6 +847,7 @@ class VectorNormFunctor { } dtype_val = x->dtype(); } + bool full_dim_flag = true; std::vector dim; if (!input_dim.has_value()) { std::vector reduce_axis(x->shape()->NumAxes()); @@ -862,7 +862,9 @@ class VectorNormFunctor { } else { dim.emplace_back(dim_check[i] + x->shape()->NumAxes()); } + if (dim[i] != i) { full_dim_flag = false; } } + if ((int)dim.size() < x->shape()->NumAxes()) { full_dim_flag = false; } } if (ord.IsIntegral() || ord.IsFloatingPoint()) { double ord_val = JUST(ord.As()); @@ -873,6 +875,9 @@ class VectorNormFunctor { res = JUST(ReduceMax(JUST(Abs(x)), dim, keepdim)); } else if (ord_val == -INFINITY) { res = JUST(ReduceMin(JUST(Abs(x)), dim, keepdim)); + } else if (ord_val == 2.0 && keepdim == false && full_dim_flag + && x->requires_grad() == false) { + res = JUST(SqrtSquareSum(x)); } else { res = JUST(ScalarPow(JUST(ReduceSum(JUST(ScalarPow(JUST(Abs(x)), ord, false)), dim, keepdim)), @@ -1680,6 +1685,43 @@ class MovedimIntFunctor { } }; +class CumsumFunctor { + public: + CumsumFunctor() { op_ = CHECK_JUST(one::OpBuilder("cumsum").Input("in").Output("out").Build()); } + Maybe operator()(const std::shared_ptr& input, int64_t dim) const { + auto ndim = input->ndim(); + if (dim < 0) { dim += ndim; } + CHECK_OR_RETURN(dim >= 0 && dim < ndim) + << "IndexError: Dimension out of range (expected to be in range of [" << -ndim << "," + << ndim << " ) but got " << dim; + + MutableAttrMap attrs; + JUST(attrs.SetAttr("dim", dim)); + TensorProcessor 
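// (Aside, not part of this patch.) CumsumFunctor normalizes a negative dim against ndim and
// bounds-checks it, then runs the input through TensorProcessor with DType::Int64() as the
// lowest promotion target (presumably so integer inputs reach the "cumsum" op as int64) before
// dispatch; CumsumGradFunctor below skips the dim check since the forward already validated it.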
tensor_processor; + JUST(tensor_processor.AddInputs({input}, DType::Int64()).Apply()); + TensorTuple input_tuple = JUST(tensor_processor.GetInputs()); + return OpInterpUtil::Dispatch(*op_, input_tuple, attrs); + } + + private: + std::shared_ptr op_; +}; + +class CumsumGradFunctor { + public: + CumsumGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("cumsum_grad").Input("dy").Output("dx").Build()); + } + Maybe operator()(const std::shared_ptr& input, int64_t dim) const { + // No need to check dim validation here, while CumsumFunctor handled already + MutableAttrMap attrs; + JUST(attrs.SetAttr("dim", dim)); + return OpInterpUtil::Dispatch(*op_, {input}, attrs); + } + + private: + std::shared_ptr op_; +}; } // namespace impl using namespace impl; @@ -1710,13 +1752,12 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("ReduceMaxGlobalStageGrad"); m.add_functor("Transpose"); m.add_functor("Permute"); - m.add_functor("Eye"); - m.add_functor("ConsistentEye"); m.add_functor("Transpose2dim"); m.add_functor("Arange"); m.add_functor("ConsistentArange"); m.add_functor("Cast"); m.add_functor("Clamp"); + m.add_functor("SqrtSquareSum"); m.add_functor("VectorNorm"); m.add_functor("MatrixNorm"); m.add_functor("Norm"); @@ -1744,6 +1785,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Dot"); m.add_functor("MovedimVec"); m.add_functor("MovedimInt"); + m.add_functor("Cumsum"); + m.add_functor("CumsumGrad"); }; } // namespace functional diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index c74f4a6e950..d34f121d046 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/data_type.pb.h" #include "oneflow/core/common/optional.h" #include "oneflow/core/common/scalar.h" #include "oneflow/core/framework/attr_map.h" @@ -22,6 +23,7 @@ limitations under the License. 
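// (Aside, not part of this patch.) In the conv/deconv functors below, Conv1d/2d/3d gain a
// channel_pos argument: kernel sizes are read from weight dims starting at offset 2 for
// "channels_first" and offset 1 for "channels_last", and channel_pos is passed through as
// data_format instead of the previous hard-coded "channels_first". DeConvBaseFunctor now sets
// the groups attr directly rather than splitting the input and concatenating per-group results,
// and a Deconv2d functor is added to match the new yaml entry.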
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/framework/random_generator.h" #include "oneflow/core/functional/functional.h" @@ -47,7 +49,12 @@ class BiasAddFunctor { Maybe operator()(const std::shared_ptr& x, const std::shared_ptr& bias, const int32_t& axis) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("axis", axis)); + int32_t axis_val = axis; + if (axis_val < 0) { + const int64_t num_axes = x->shape()->NumAxes(); + axis_val += num_axes; + } + JUST(attrs.SetAttr("axis", axis_val)); return OpInterpUtil::Dispatch(*op_, {x, bias}, attrs); } @@ -65,11 +72,15 @@ class ConvBaseFunctor { const std::shared_ptr& weight, const Optional& bias, const std::vector& stride, const std::vector& padding, - const std::vector& dilation, const int32_t& groups) const { + const std::vector& dilation, const int32_t& groups, + const std::string& channel_pos) const { MutableAttrMap conv_attrs; std::vector kernel_size_vec(num_spatial_dims_); + int32_t kernel_idx_offset = 2; + if (channel_pos == "channels_last") { kernel_idx_offset = 1; } + for (int i = 0; i < num_spatial_dims_; i++) { - kernel_size_vec.at(i) = ((weight->shape())->At(i + 2)); + kernel_size_vec.at(i) = ((weight->shape())->At(i + kernel_idx_offset)); } JUST(conv_attrs.SetAttr("filters", (weight->shape())->At(0))); JUST(conv_attrs.SetAttr>("padding_before", padding)); @@ -77,7 +88,7 @@ class ConvBaseFunctor { JUST(conv_attrs.SetAttr>("strides", stride)); JUST(conv_attrs.SetAttr>("dilation_rate", dilation)); JUST(conv_attrs.SetAttr("groups", groups)); - JUST(conv_attrs.SetAttr("data_format", std::string("channels_first"))); + JUST(conv_attrs.SetAttr("data_format", channel_pos)); const std::shared_ptr& conv_out = JUST(OpInterpUtil::Dispatch(*conv_op_, {x, weight}, conv_attrs)); if (bias) { @@ -140,23 +151,10 @@ class DeConvBaseFunctor { JUST(deconv_attrs.SetAttr>("output_padding", output_padding)); JUST(deconv_attrs.SetAttr>("strides", strides)); JUST(deconv_attrs.SetAttr>("dilation_rate", dilation)); + JUST(deconv_attrs.SetAttr("groups", groups)); JUST(deconv_attrs.SetAttr("data_format", data_format)); std::shared_ptr deconv_out = nullptr; - if (groups == 1) { - deconv_out = JUST(OpInterpUtil::Dispatch(*deconv_op_, {x, weight}, deconv_attrs)); - } else { - auto nc = x->dim(1) / groups; - auto split_x = JUST(functional::Split(x, nc, 1)); - auto split_weight = JUST(functional::Split(weight, nc, 0)); - one::TensorTuple split_out; - for (int i = 0; i < groups; i++) { - const std::shared_ptr& deconv_i = JUST(OpInterpUtil::Dispatch( - *deconv_op_, {split_x->at(i), split_weight->at(i)}, deconv_attrs)); - split_out.emplace_back(deconv_i); - } - deconv_out = JUST(functional::Concat(split_out, 1)); - } - + deconv_out = JUST(OpInterpUtil::Dispatch(*deconv_op_, {x, weight}, deconv_attrs)); if (bias) { MutableAttrMap bias_attrs; JUST(bias_attrs.SetAttr("axis", 1)); @@ -179,6 +177,14 @@ class DeConv1dFunctor : public DeConvBaseFunctor { } }; +class DeConv2dFunctor : public DeConvBaseFunctor { + public: + DeConv2dFunctor() { + deconv_op_ = + CHECK_JUST(one::OpBuilder("deconv2d").Input("in").Input("weight").Output("out").Build()); + } +}; + class DeConv3dFunctor : public DeConvBaseFunctor { public: DeConv3dFunctor() { @@ -1264,7 +1270,7 @@ class NormalizationAddReluFunctor { .Output("y") .Attr("training", false) .Build()); - 
relu_op_ = CHECK_JUST(one::OpBuilder("relu").Input("in").Output("out").Build()); + relu_op_ = CHECK_JUST(one::OpBuilder("relu").Input("x").Output("y").Build()); add_op_ = CHECK_JUST(one::OpBuilder("add_n").Input("in", 2).Output("out").Build()); fused_norm_training_stats_op_ = CHECK_JUST(one::OpBuilder("normalization_add_relu") .Input("x") @@ -1596,17 +1602,6 @@ class FoldFunctor { std::shared_ptr fold_op_; }; -Maybe SyncAccessTensorWithTimeOut( - const std::shared_ptr& tensor, - const std::shared_ptr>& callback, const std::string& modifier) { - return SpinCounter::SpinWait(1, [&](const std::shared_ptr& sc) -> Maybe { - return PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { - return builder->SyncAccessBlobByCallback(JUST(tensor->AsMirroredTensor()), sc, callback, - modifier); - }); - }); -} - class OneHotFunctor { public: OneHotFunctor() { @@ -1636,10 +1631,11 @@ class OneHotFunctor { } else { JUST(attrs.SetAttr("depth", num_classes)); } + // Refer to: https://github.com/Oneflow-Inc/oneflow/pull/5315/files#r755823506 bool is_on_value_double = on_value.IsFloatingPoint(); bool is_off_value_double = off_value.IsFloatingPoint(); if (is_on_value_double || is_off_value_double) { - JUST(attrs.SetAttr("dtype", kDouble)); + JUST(attrs.SetAttr("dtype", kFloat)); JUST(attrs.SetAttr("floating_on_value", JUST(on_value.As()))); JUST(attrs.SetAttr("floating_off_value", JUST(off_value.As()))); JUST(attrs.SetAttr("integer_on_value", 0)); @@ -1664,18 +1660,44 @@ class L2NormalizeFunctor { op_ = CHECK_JUST( one::OpBuilder("l2_normalize").Input("x").Output("y").Output("square_x_sum").Build()); } - Maybe operator()(const std::shared_ptr& input, const int32_t& axis, - const float& epsilon) const { + Maybe operator()(const std::shared_ptr& input, const int32_t& axis, + const float& epsilon) const { MutableAttrMap attrs; - JUST(attrs.SetAttr("axis", axis)); + JUST(attrs.SetAttr("axis", 0)); JUST(attrs.SetAttr("epsilon", epsilon)); - return OpInterpUtil::Dispatch(*op_, {input}, attrs); + + if (axis != 0) { + std::vector input_perm(input->shape()->dim_vec().size(), 0); + for (size_t i = 0; i < input_perm.size(); ++i) { input_perm[i] = static_cast(i); } + std::swap(input_perm[0], input_perm[static_cast(axis)]); + + const auto result = JUST(OpInterpUtil::Dispatch( + *op_, {JUST(functional::Transpose(input, input_perm))}, attrs)); + return functional::Transpose(result->at(0), input_perm); + } + + return OpInterpUtil::Dispatch(*op_, {input}, attrs); } private: std::shared_ptr op_; }; +class NormalizeFunctor { + public: + Maybe operator()(const std::shared_ptr& input, const float& p, + const int32_t& dim, const float& eps) const { + return SequenceFunction(const std::shared_ptr&, const float&, + const int32_t&)>( + [](const auto& x, const float& p, const int32_t& dim) -> Maybe { + return functional::ScalarNorm(x, p, dim, true, NullOpt); + }) + .then([&](const auto& x) { return functional::Clamp(x, eps, NullOpt); }) + .then([&](const auto& x) { return functional::Div(input, x); }) + .call(input, p, dim); + } +}; + class FusedSelfAttentionFunctor { public: FusedSelfAttentionFunctor() { @@ -1849,18 +1871,26 @@ class FusedBiasAddDropoutFunctor { if (p != 1.0) { scale = 1.0 / (1.0 - p); } MutableAttrMap fused_bias_add_mask_attrs; JUST(fused_bias_add_mask_attrs.SetAttr("scale", scale)); - JUST(fused_bias_add_mask_attrs.SetAttr("axis", axis)); - - return SequenceFunction()>([&]() -> Maybe { - return OpInterpUtil::Dispatch( - *random_mask_like_op_, {a}, - OpExprInterpContext(random_mask_like_attrs, 
random_mask_like_state)); - }) - .then([&](const std::shared_ptr& x) { - return OpInterpUtil::Dispatch(*fused_bias_add_mask_scale_op_, {a, b, x}, - fused_bias_add_mask_attrs); - }) - .call(); + int32_t axis_val = axis; + if (axis_val < 0) { + const int64_t num_axes = a->shape()->NumAxes(); + axis_val += num_axes; + } + JUST(fused_bias_add_mask_attrs.SetAttr("axis", axis_val)); + if (p >= 0.0) { + return SequenceFunction()>([&]() -> Maybe { + return OpInterpUtil::Dispatch( + *random_mask_like_op_, {a}, + OpExprInterpContext(random_mask_like_attrs, random_mask_like_state)); + }) + .then([&](const std::shared_ptr& x) { + return OpInterpUtil::Dispatch(*fused_bias_add_mask_scale_op_, {a, b, x}, + fused_bias_add_mask_attrs); + }) + .call(); + } else { + return functional::BiasAdd(a, b, axis_val); + } } private: @@ -2113,6 +2143,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Conv2d"); m.add_functor("Conv3d"); m.add_functor("Deconv1d"); + m.add_functor("Deconv2d"); m.add_functor("Deconv3d"); m.add_functor("MatMul"); m.add_functor("BatchMatMul"); @@ -2158,6 +2189,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("OneHot"); m.add_functor("FusedSelfAttention"); m.add_functor("FusedSelfAttentionGrad"); + m.add_functor("Normalize"); m.add_functor("L2Normalize"); m.add_functor("L2NormalizeGrad"); m.add_functor("FusedBiasAddGelu"); diff --git a/oneflow/core/functional/impl/random_functor.cpp b/oneflow/core/functional/impl/random_functor.cpp index 3ab56adf8eb..d85ddc29704 100644 --- a/oneflow/core/functional/impl/random_functor.cpp +++ b/oneflow/core/functional/impl/random_functor.cpp @@ -108,6 +108,7 @@ class ConsistentRandFunctor { const Optional>& dtype, const Optional& generator, const bool& requires_grad) const { + JUST(CheckDeviceIdsIsValid(placement)); DataType dtype_val = DataType::kFloat; if (dtype.has_value()) { dtype_val = JUST(dtype)->data_type(); @@ -127,8 +128,8 @@ class ConsistentRandFunctor { const auto& distribution_state = std::make_shared(gen); const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); - if (!JUST(IsMultiClient())) { - JUST(attrs.SetAttr("nd_sbp", nd_sbp->DebugString())); + if (LazyMode::is_enabled()) { + JUST(attrs.SetAttr>("nd_sbp", *JUST(GetNdSbpStrList(nd_sbp)))); } auto result = JUST(OpInterpUtil::Dispatch( *op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, distribution_state))); @@ -182,6 +183,7 @@ class ConsistentRandNFunctor { const Optional>& dtype, const Optional& generator, const bool& requires_grad) const { + JUST(CheckDeviceIdsIsValid(placement)); DataType dtype_val = DataType::kFloat; if (dtype) { dtype_val = JUST(dtype)->data_type(); } if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) { @@ -199,8 +201,8 @@ class ConsistentRandNFunctor { const auto& distribution_state = std::make_shared(gen); const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); - if (!JUST(IsMultiClient())) { - JUST(attrs.SetAttr("nd_sbp", nd_sbp->DebugString())); + if (LazyMode::is_enabled()) { + JUST(attrs.SetAttr>("nd_sbp", *JUST(GetNdSbpStrList(nd_sbp)))); } auto result = JUST(OpInterpUtil::Dispatch( *op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, distribution_state))); @@ -269,6 +271,7 @@ class ConsistentRandIntFunctor { const Optional>& dtype, const Optional& generator, const bool& requires_grad) const { + JUST(CheckDeviceIdsIsValid(placement)); DataType dtype_val = DataType::kInt64; if (dtype) { dtype_val = JUST(dtype)->data_type(); } @@ -282,17 +285,10 @@ class ConsistentRandIntFunctor { const auto& distribution_state = std::make_shared(gen); + const auto& nd_sbp 
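// (Aside, not part of this patch.) The consistent random functors (rand/randn/randint/randperm)
// now call CheckDeviceIdsIsValid(placement) up front, compute nd_sbp once, and only under
// LazyMode record the "nd_sbp" attr as a string list via GetNdSbpStrList, replacing both the
// hand-rolled SbpParallelToString loops and the single-client DebugString path.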
= JUST(GetNdSbp(sbp_tuple)); if (LazyMode::is_enabled()) { - std::vector nd_sbp(sbp_tuple.size()); - { - for (int i = 0; i < sbp_tuple.size(); ++i) { - nd_sbp.at(i) = SbpParallelToString(*sbp_tuple.at(i)); - } - } - JUST(attrs.SetAttr>("nd_sbp", nd_sbp)); + JUST(attrs.SetAttr>("nd_sbp", *JUST(GetNdSbpStrList(nd_sbp)))); } - const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); - auto result = JUST(OpInterpUtil::Dispatch( *op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, distribution_state))); @@ -312,6 +308,7 @@ class ConsistentRandInt2Functor { const Optional>& dtype, const Optional& generator, const bool& requires_grad) const { + JUST(CheckDeviceIdsIsValid(placement)); return ConsistentRandInt(/*low*/ 0, high, shape, placement, sbp_tuple, dtype, generator, requires_grad); } @@ -351,6 +348,7 @@ class ConsistentRandPermFunctor { const std::vector>& sbp_tuple, const Optional& generator, const Symbol& dtype, const bool& requires_grad) const { + JUST(CheckDeviceIdsIsValid(placement)); const auto gen = generator.value_or(JUST(one::DefaultAutoGenerator())); MutableAttrMap attrs; JUST(attrs.SetAttr("n", n)); @@ -358,16 +356,10 @@ class ConsistentRandPermFunctor { const auto& distribution_state = std::make_shared(gen); + const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); if (LazyMode::is_enabled()) { - std::vector nd_sbp(sbp_tuple.size()); - { - for (int i = 0; i < sbp_tuple.size(); ++i) { - nd_sbp.at(i) = SbpParallelToString(*sbp_tuple.at(i)); - } - } - JUST(attrs.SetAttr>("nd_sbp", nd_sbp)); + JUST(attrs.SetAttr>("nd_sbp", *JUST(GetNdSbpStrList(nd_sbp)))); } - const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); auto result = JUST(OpInterpUtil::Dispatch( *randperm_op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, distribution_state))); diff --git a/oneflow/core/functional/impl/unary_functor.cpp b/oneflow/core/functional/impl/unary_functor.cpp index 3954c9a6281..50efed0b3e4 100644 --- a/oneflow/core/functional/impl/unary_functor.cpp +++ b/oneflow/core/functional/impl/unary_functor.cpp @@ -55,6 +55,7 @@ namespace impl { OF_PP_MAKE_TUPLE_SEQ("exp", Exp) \ OF_PP_MAKE_TUPLE_SEQ("expm1", Expm1) \ OF_PP_MAKE_TUPLE_SEQ("log", Log) \ + OF_PP_MAKE_TUPLE_SEQ("log2", Log2) \ OF_PP_MAKE_TUPLE_SEQ("log1p", Log1p) \ OF_PP_MAKE_TUPLE_SEQ("negative", Negative) \ OF_PP_MAKE_TUPLE_SEQ("reciprocal", Reciprocal) \ @@ -132,6 +133,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { ADD_UNARY_FUNCTOR(Floor, "Floor"); ADD_UNARY_FUNCTOR(Lgamma, "Lgamma"); ADD_UNARY_FUNCTOR(Log, "Log"); + ADD_UNARY_FUNCTOR(Log2, "Log2"); ADD_UNARY_FUNCTOR(Log1p, "Log1p"); ADD_UNARY_FUNCTOR(LogSigmoid, "LogSigmoid"); ADD_UNARY_FUNCTOR(Negative, "Negative"); diff --git a/oneflow/core/functional/tensor_index.cpp b/oneflow/core/functional/tensor_index.cpp index be99333e61b..dbc4d159998 100644 --- a/oneflow/core/functional/tensor_index.cpp +++ b/oneflow/core/functional/tensor_index.cpp @@ -19,6 +19,7 @@ limitations under the License. 
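// The per-functor SbpParallelToString loops removed above are consolidated into a single
// helper, GetNdSbpStrList (declared in oneflow/core/framework/nd_sbp.h). A minimal sketch of
// the assumed behavior, mirroring the loop it replaces; the sketch name and the cfg::NdSbp
// parameter type are illustrative, not taken from this patch:
#include <string>
#include <vector>
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/job/sbp_parallel.h"

namespace oneflow {

Maybe<std::vector<std::string>> GetNdSbpStrListSketch(const cfg::NdSbp& nd_sbp) {
  std::vector<std::string> str_list(nd_sbp.sbp_parallel_size());
  for (int i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {
    // e.g. "S(0)", "B" or "P" for split, broadcast and partial-sum respectively.
    str_list.at(i) = SbpParallelToString(nd_sbp.sbp_parallel(i));
  }
  return str_list;
}

}  // namespace oneflow
// With this helper, a 2-D sbp tuple such as (S(0), B) is stored in the "nd_sbp" attr
// as {"S(0)", "B"} whenever lazy mode is enabled.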
#include "oneflow/core/framework/device.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/framework/tensor_util.h" #include "oneflow/core/framework/nd_sbp.h" #include "oneflow/core/functional/functional.h" #include "oneflow/core/job/sbp_parallel.h" @@ -30,17 +31,6 @@ namespace functional { namespace { -Maybe SyncAccessTensorWithTimeOut( - const std::shared_ptr& tensor, - const std::shared_ptr>& callback, const std::string& modifier) { - return SpinCounter::SpinWait(1, [&](const std::shared_ptr& sc) -> Maybe { - return PhysicalRun([&](InstructionsBuilder* builder) -> Maybe { - return builder->SyncAccessBlobByCallback(JUST(tensor->AsMirroredTensor()), sc, callback, - modifier); - }); - }); -} - int64_t CountSpecifiedDims(const TensorIndex& index) { int64_t specified_ndims = 0; for (int i = 0; i < index.size(); ++i) { @@ -50,7 +40,7 @@ int64_t CountSpecifiedDims(const TensorIndex& index) { } else if (index_item.IsTensor()) { const auto& tensor = index_item.tensor(); if (tensor->dtype() == DType::Int8() || tensor->dtype() == DType::UInt8()) { - specified_ndims += tensor->shape()->NumAxes(); + specified_ndims += tensor->ndim(); } else { specified_ndims++; } @@ -81,7 +71,7 @@ Maybe ExpandMaskIndex(const std::shared_ptr& index) { }); JUST(SyncAccessTensorWithTimeOut(size_tensor, callback, "const")); - for (int i = 0; i < index->shape()->NumAxes(); ++i) { + for (int i = 0; i < index->ndim(); ++i) { auto item = JUST(functional::Slice(res->at(0), {0, i}, {size, i + 1}, {1, 1})); item = JUST(functional::Reshape(item, {size})); indices->emplace_back(item); @@ -139,7 +129,7 @@ Maybe IsContinuousSubspace(const TensorTuple& indices) { token = 1; } else if (indices.at(i) && token) { if (token != 1) { return false; } - } else if (!token) { + } else if (token) { token += 1; } } @@ -149,14 +139,14 @@ Maybe IsContinuousSubspace(const TensorTuple& indices) { Maybe TransposeFront(const std::shared_ptr& input, const TensorTuple& indices, std::shared_ptr* output, TensorTuple* valid_indices) { std::vector permute; - permute.reserve(input->shape()->NumAxes()); - for (int i = 0; i < input->shape()->NumAxes(); ++i) { + permute.reserve(input->ndim()); + for (int i = 0; i < input->ndim(); ++i) { if (i < indices.size() && indices.at(i)) { permute.emplace_back(i); valid_indices->emplace_back(indices.at(i)); } } - for (int i = 0; i < input->shape()->NumAxes(); ++i) { + for (int i = 0; i < input->ndim(); ++i) { if (i >= indices.size() || !indices.at(i)) { permute.emplace_back(i); } } bool need_transpose = [&]() { @@ -183,7 +173,7 @@ Maybe AdjustSubspace(const std::shared_ptr& input, const TensorT } } if (index_subspace_pos <= 0) { return input; } - int ndim = input->shape()->NumAxes(); + int ndim = input->ndim(); CHECK_LE_OR_RETURN(index_subspace_pos + index_ndim, ndim) << "Failed to adjust subspace since the index is out of bounds for tensor dimension " << ndim; std::vector permute; @@ -271,7 +261,7 @@ Maybe PrepareSliceIndices(const TensorIndex& index, const Shape& shape, const auto& tensor = index_item.tensor(); auto indices = std::make_shared(); if (tensor->dtype() == DType::Int8() || tensor->dtype() == DType::UInt8()) { - for (int j = 0; j < tensor->shape()->NumAxes(); ++j) { + for (int j = 0; j < tensor->ndim(); ++j) { if (tensor->shape()->At(j) != shape.At(dim + j)) { return Error::IndexError() << "The shape of the mask " << tensor->shape()->ToString() << " at index " << j @@ -318,10 +308,10 @@ Maybe> RemoveExpandDimSlice( Maybe 
ApplyAdvancedIndexing(const std::shared_ptr& input, const TensorTuple& indices) { - CHECK_GE_OR_RETURN(input->shape()->NumAxes(), indices.size()) - << "Too many indices for tensor of dimension " << input->shape()->NumAxes(); + CHECK_GE_OR_RETURN(input->ndim(), indices.size()) + << "Too many indices for tensor of dimension " << input->ndim(); const auto& expanded_indices = JUST(ExpandIndices(indices)); - bool is_continuos_subspace = JUST(IsContinuousSubspace(indices)); + bool is_continuous_subspace = JUST(IsContinuousSubspace(indices)); // Since the start dimension cannot be specified for `gather_nd`, so we should // transpose the input as long as the first indice is null. @@ -329,14 +319,9 @@ Maybe ApplyAdvancedIndexing(const std::shared_ptr& input, TensorTuple valid_indices; JUST(TransposeFront(input, *expanded_indices, &transposed_input, &valid_indices)); if (valid_indices.empty()) { return input; } - int index_ndim = valid_indices.at(0)->shape()->NumAxes(); - std::shared_ptr packed_indices; - if (valid_indices.size() > 1) { - packed_indices = JUST(Stack(valid_indices, 0)); - } else { - packed_indices = JUST(ExpandDims(valid_indices.at(0), 0)); - } - int packed_ndim = packed_indices->shape()->NumAxes(); + int index_ndim = valid_indices.at(0)->ndim(); + auto packed_indices = JUST(Stack(valid_indices, 0)); + int packed_ndim = packed_indices->ndim(); CHECK_GT_OR_RETURN(packed_ndim, 0) << "Index array dimension should be greater than 0."; std::vector permute(packed_ndim); permute[packed_ndim - 1] = 0; @@ -356,11 +341,11 @@ Maybe ApplyAdvancedIndexing(const std::shared_ptr& input, } auto result = JUST(GatherNd(transposed_input, packed_indices)); - int required_ndim = input->shape()->NumAxes() - valid_indices.size() + index_ndim; - CHECK_EQ_OR_RETURN(result->shape()->NumAxes(), required_ndim) - << "The indexing result dimension is " << result->shape()->NumAxes() << ", but shoule be " + int required_ndim = input->ndim() - valid_indices.size() + index_ndim; + CHECK_EQ_OR_RETURN(result->ndim(), required_ndim) + << "The indexing result dimension is " << result->ndim() << ", but shoule be " << required_ndim; - if (is_continuos_subspace) { result = JUST(AdjustSubspace(result, indices, index_ndim)); } + if (is_continuous_subspace) { result = JUST(AdjustSubspace(result, indices, index_ndim)); } return result; } diff --git a/oneflow/core/graph/boxing/boxing_logger.cpp b/oneflow/core/graph/boxing/boxing_logger.cpp index dadd43e4b79..b08138d36f2 100644 --- a/oneflow/core/graph/boxing/boxing_logger.cpp +++ b/oneflow/core/graph/boxing/boxing_logger.cpp @@ -15,6 +15,7 @@ limitations under the License. 
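// The continuous-subspace handling in ApplyAdvancedIndexing above follows NumPy-style
// advanced indexing: when the index tensors occupy a contiguous block of dimensions, the
// indexed result keeps that block in place; otherwise the indexed dimensions are moved to
// the front. For an input x of shape (2, 3, 4, 5) and index tensors i, j of shape (k,):
//   x[:, i, j, :]  -> contiguous block (dims 1, 2), result shape (2, k, 5)
//   x[i, :, j, :]  -> non-contiguous (dims 0, 2),   result shape (k, 3, 5)
// GatherNd (after TransposeFront) always produces the indexed dimensions at the front;
// AdjustSubspace is what moves that block back into place in the contiguous case.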
*/ #include "oneflow/core/graph/boxing/boxing_logger.h" #include "oneflow/core/job/sbp_parallel.h" +#include "oneflow/core/framework/nd_sbp.h" namespace oneflow { @@ -56,17 +57,6 @@ std::string ParallelDescToString(const ParallelDesc& parallel_desc) { return serialized_parallel_desc; } -std::string NdSbpToString(const cfg::NdSbp& nd_sbp) { - std::string serialized_nd_sbp; - const int64_t num_axes = nd_sbp.sbp_parallel_size(); - serialized_nd_sbp += "["; - for (int64_t i = 0; i < num_axes - 1; ++i) { - serialized_nd_sbp += SbpParallelToString(nd_sbp.sbp_parallel(i)) + " "; - } - serialized_nd_sbp += SbpParallelToString(nd_sbp.sbp_parallel(num_axes - 1)) + "]"; - return serialized_nd_sbp; -} - std::string MakeBoxingLoggerCsvRow(const SubTskGphBuilderStatus& status, const std::string& src_op_name, const std::string& dst_op_name, const ParallelDesc& src_parallel_desc, diff --git a/oneflow/core/graph/task_graph.cpp b/oneflow/core/graph/task_graph.cpp index 5ff91ee3dc1..40ee2918658 100644 --- a/oneflow/core/graph/task_graph.cpp +++ b/oneflow/core/graph/task_graph.cpp @@ -541,53 +541,6 @@ void TaskGraph::ConnectCtrlEdges(const std::vector& src_task_node } } -void TaskGraph::AddCtrlEdgeBetweenSrcDstTickAndInputOutputInSameRank() { - if (!CHECK_JUST(IsMultiClient())) { return; } - HashMap rank_id2src_tick; - HashMap rank_id2dst_tick; - HashMap> rank_id2input_output_nodes; - - ForEachNode([&](TaskNode* node) { - if (node->GetTaskType() == TaskType::kSrcSubsetTick) { - CHECK(rank_id2src_tick.emplace(node->machine_id(), node).second); - } else if (node->GetTaskType() == TaskType::kDstSubsetTick) { - CHECK(rank_id2dst_tick.emplace(node->machine_id(), node).second); - } else if (node->GetTaskType() == TaskType::kNormalForward) { - auto* forward_node = reinterpret_cast(node); - CHECK(forward_node); - if (forward_node->op()->op_conf().has_input_conf() - || forward_node->op()->op_conf().has_output_conf()) { - CHECK(rank_id2input_output_nodes[node->machine_id()].insert(node).second); - } - } - }); - - auto AddCtrlEdge = [&](TaskNode* src, TaskNode* dst) { - std::string ctrl_regst_name; - RegstDesc* ctrl_regst = src->BuildCtrlRegstDesc(dst, &ctrl_regst_name); - // NOTE(chengcheng): - // ctrl edge between src subset tick to output is just for restrict order in multi-client - // but this ctrl edge will block src subset tick to delay pipeline, so this ctrl edge must - // at least 2. 
- ctrl_regst->UpdtMinRegstNumIfNeed(2); - TaskEdge* edge = NewEdge(); - Connect(src, edge, dst); - src->BindEdgeWithProducedRegst(edge, ctrl_regst_name); - }; - - for (auto& pair : rank_id2src_tick) { - int64_t rank_id = pair.first; - TaskNode* src = pair.second; - for (TaskNode* io_task : rank_id2input_output_nodes[rank_id]) { AddCtrlEdge(src, io_task); } - } - - for (auto& pair : rank_id2dst_tick) { - int64_t rank_id = pair.first; - TaskNode* dst = pair.second; - for (TaskNode* io_task : rank_id2input_output_nodes[rank_id]) { AddCtrlEdge(io_task, dst); } - } -} - void TaskGraph::RemoveEmptyRegsts() { ForEachNode([&](TaskNode* node) { node->EraseUninitializedShapeProducedBlob(); }); ForEachNode([&](TaskNode* node) { node->EraseZeroSizeConsumedRegst(); }); diff --git a/oneflow/core/graph/task_graph.h b/oneflow/core/graph/task_graph.h index b7fcc0099ca..71593a834f1 100644 --- a/oneflow/core/graph/task_graph.h +++ b/oneflow/core/graph/task_graph.h @@ -47,7 +47,6 @@ class TaskGraph final : public Graph { const char* TypeName() const override { return "TaskGraph"; } void RemoveEmptyRegsts(); - void AddCtrlEdgeBetweenSrcDstTickAndInputOutputInSameRank(); void MergeChainAndAddOrderingCtrlEdgeInSameChain(); void EnableInplaceMemSharing(const std::function& diff --git a/oneflow/core/graph/task_stream_index_manager.cpp b/oneflow/core/graph/task_stream_index_manager.cpp index dc46ba1e12e..8e98590a13c 100644 --- a/oneflow/core/graph/task_stream_index_manager.cpp +++ b/oneflow/core/graph/task_stream_index_manager.cpp @@ -62,7 +62,7 @@ Maybe TaskStreamIndexGetterRegistry::Dispatch( auto key = std::make_pair(device_type, task_type); auto it = stream_index_getter_map_.find(key); CHECK_OR_RETURN(it != stream_index_getter_map_.end()) - << "TaskType: " << key.first << ", DeviceType: " << key.second << " has not been registered"; + << "TaskType: " << key.second << ", DeviceType: " << key.first << " has not been registered"; return it->second(generator); } diff --git a/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp b/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp new file mode 100644 index 00000000000..51647e00bdf --- /dev/null +++ b/oneflow/core/graph_impl/critical_section_wait_compute_task_node.cpp @@ -0,0 +1,67 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/graph/compute_task_node.h" +#include "oneflow/core/graph/task_stream_index_manager.h" + +namespace oneflow { + +class CriticalSectionWaitTickCompTaskNode final : public CompTaskNode { + public: + OF_DISALLOW_COPY_AND_MOVE(CriticalSectionWaitTickCompTaskNode); + CriticalSectionWaitTickCompTaskNode() = default; + ~CriticalSectionWaitTickCompTaskNode() = default; + + bool IsMeaningLess() override { return false; } + TaskType GetTaskType() const override { return TaskType::kCriticalSectionWaitTick; } + + private: + void ProduceAllRegstsAndBindEdges() override; + void ConsumeAllRegsts() override; + void BuildExecGphAndRegst() override; +}; + +void CriticalSectionWaitTickCompTaskNode::ProduceAllRegstsAndBindEdges() { + ProduceRegst("out", false, 128, 128); + ForEachOutDataEdge([&](TaskEdge* edge) { BindEdgeWithProducedRegst(edge, "out"); }); +} + +void CriticalSectionWaitTickCompTaskNode::ConsumeAllRegsts() { + ConsumeRegst("in"); + ForEachInDataEdge([&](TaskEdge* edge) { ConsumeRegst("in", edge->GetSoleRegst()); }); +} + +void CriticalSectionWaitTickCompTaskNode::BuildExecGphAndRegst() { + ExecNode* node = mut_exec_gph().NewNode(); + node->mut_op() = op(); + const std::list>& in_regsts = GetConsumedRegst("in"); + for (const std::string& ibn : node->op()->input_bns()) { + node->BindBnWithOneOfTheRegsts(ibn, in_regsts); + } + std::shared_ptr out_regst = GetProducedRegst("out"); + for (const std::string& obn : node->op()->output_bns()) { + const LogicalBlobId& lbi = node->op()->BnInOp2Lbi(obn); + out_regst->AddLbi(lbi); + node->BindBnWithRegst(obn, out_regst); + } + node->InferBlobDescs(parallel_ctx()); +} + +REGISTER_INDEPENDENT_TASK_STREAM_INDEX_GETTER(TaskType::kCriticalSectionWaitTick); + +REGISTER_SYSTEM_OP_COMP_TASK_NODE_TYPE(OperatorConf::kCriticalSectionWaitTickConf, + CriticalSectionWaitTickCompTaskNode); + +} // namespace oneflow diff --git a/oneflow/core/hardware/cuda_device_descriptor.cpp b/oneflow/core/hardware/cuda_device_descriptor.cpp index 81e4182edad..ae6cdc463c1 100644 --- a/oneflow/core/hardware/cuda_device_descriptor.cpp +++ b/oneflow/core/hardware/cuda_device_descriptor.cpp @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include +#include "nlohmann/json.hpp" namespace oneflow { diff --git a/oneflow/core/hardware/cuda_device_descriptor_class.cpp b/oneflow/core/hardware/cuda_device_descriptor_class.cpp index 877f5f582b6..64177f7d703 100644 --- a/oneflow/core/hardware/cuda_device_descriptor_class.cpp +++ b/oneflow/core/hardware/cuda_device_descriptor_class.cpp @@ -20,7 +20,7 @@ limitations under the License. #include "oneflow/core/common/util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/common/str_util.h" -#include +#include "nlohmann/json.hpp" #ifdef WITH_CUDA diff --git a/oneflow/core/hardware/net_ib_device_descriptor.cpp b/oneflow/core/hardware/net_ib_device_descriptor.cpp index ec92f04fb2e..3c2e1d76547 100644 --- a/oneflow/core/hardware/net_ib_device_descriptor.cpp +++ b/oneflow/core/hardware/net_ib_device_descriptor.cpp @@ -17,7 +17,7 @@ limitations under the License. #ifdef WITH_RDMA -#include +#include "nlohmann/json.hpp" namespace oneflow { diff --git a/oneflow/core/hardware/net_ib_device_descriptor_class.cpp b/oneflow/core/hardware/net_ib_device_descriptor_class.cpp index 9e2d525f562..bfa05859b85 100644 --- a/oneflow/core/hardware/net_ib_device_descriptor_class.cpp +++ b/oneflow/core/hardware/net_ib_device_descriptor_class.cpp @@ -20,7 +20,7 @@ limitations under the License. 
#include "oneflow/core/common/util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/common/str_util.h" -#include +#include "nlohmann/json.hpp" #ifdef WITH_RDMA diff --git a/oneflow/core/hardware/net_socket_device_descriptor.cpp b/oneflow/core/hardware/net_socket_device_descriptor.cpp index f91e1ecfc9a..b041b4b022e 100644 --- a/oneflow/core/hardware/net_socket_device_descriptor.cpp +++ b/oneflow/core/hardware/net_socket_device_descriptor.cpp @@ -17,7 +17,7 @@ limitations under the License. #ifdef __linux__ #include "oneflow/core/hardware/net_socket_device_descriptor.h" -#include +#include "nlohmann/json.hpp" namespace oneflow { diff --git a/oneflow/core/hardware/net_socket_device_descriptor_class.cpp b/oneflow/core/hardware/net_socket_device_descriptor_class.cpp index ead140117e6..567d169fca0 100644 --- a/oneflow/core/hardware/net_socket_device_descriptor_class.cpp +++ b/oneflow/core/hardware/net_socket_device_descriptor_class.cpp @@ -22,7 +22,7 @@ limitations under the License. #include "oneflow/core/common/util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" #include "oneflow/core/common/str_util.h" -#include +#include "nlohmann/json.hpp" #include #include #include diff --git a/oneflow/core/hardware/node_device_descriptor.cpp b/oneflow/core/hardware/node_device_descriptor.cpp index 1d3c21bb135..b24b9ceda19 100644 --- a/oneflow/core/hardware/node_device_descriptor.cpp +++ b/oneflow/core/hardware/node_device_descriptor.cpp @@ -17,7 +17,7 @@ limitations under the License. #include "oneflow/core/hardware/device_descriptor_class.h" #include "oneflow/core/common/str_util.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" -#include +#include "nlohmann/json.hpp" #ifdef WITH_HWLOC #include #endif // WITH_HWLOC @@ -139,9 +139,9 @@ class HWLocTopologyDescriptor : public TopologyDescriptor { if (bus_id.empty()) { return nullptr; } hwloc_obj_t non_io_ancestor = GetNonIOAncestorByPCIBusID(bus_id); if (non_io_ancestor == nullptr) { return nullptr; } - if (non_io_ancestor->nodeset == nullptr) { return nullptr; } + if (non_io_ancestor->cpuset == nullptr) { return nullptr; } return std::make_shared( - hwloc_bitmap_dup(non_io_ancestor->nodeset), HWLOC_MEMBIND_BIND); + hwloc_bitmap_dup(non_io_ancestor->cpuset), HWLOC_MEMBIND_BIND); } void SetCPUAffinity( diff --git a/oneflow/core/job/compiler.cpp b/oneflow/core/job/compiler.cpp index 0cc6dd6a3c7..7a529f53180 100644 --- a/oneflow/core/job/compiler.cpp +++ b/oneflow/core/job/compiler.cpp @@ -68,10 +68,6 @@ void Compiler::Compile(Job* job, Plan* plan, bool need_job_complete) const { task_gph->ForEachNode(std::bind(&TaskNode::PinConsumedRegst, _1)); task_gph->TopoForEachNode(&TaskNode::Build); task_gph->RemoveEmptyRegsts(); - // NOTE(chengcheng): - // In Multi-Client, each rank has its own src_tick/dst_tick and input/output with callback, - // which need to be forced sequenced. - task_gph->AddCtrlEdgeBetweenSrcDstTickAndInputOutputInSameRank(); task_gph->MergeChainAndAddOrderingCtrlEdgeInSameChain(); auto IsReachable = Global::Get()->MakePredicatorIsOpNameDataOrCtrlReachable(); if (job_desc.enable_inplace()) { task_gph->EnableInplaceMemSharing(IsReachable); } diff --git a/oneflow/core/job/critical_section_instance.h b/oneflow/core/job/critical_section_instance.h new file mode 100644 index 00000000000..765056d8864 --- /dev/null +++ b/oneflow/core/job/critical_section_instance.h @@ -0,0 +1,39 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_JOB_CRITICAL_SECTION_INSTANCE_H_ +#define ONEFLOW_CORE_JOB_CRITICAL_SECTION_INSTANCE_H_ + +#include "oneflow/core/register/ofblob.h" + +namespace oneflow { + +class CriticalSectionInstance { + public: + CriticalSectionInstance() = default; + + virtual const std::string& job_name() const = 0; + + virtual ~CriticalSectionInstance() = default; + + virtual void AccessBlobByOpName(uint64_t ofblob_ptr, const std::string& op_name) const { + UNIMPLEMENTED(); + } + virtual void Finish() const { UNIMPLEMENTED(); } +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_JOB_CRITICAL_SECTION_INSTANCE_H_ diff --git a/oneflow/core/job/env_global_objects_scope.cpp b/oneflow/core/job/env_global_objects_scope.cpp index 05bb5893446..3919b448826 100644 --- a/oneflow/core/job/env_global_objects_scope.cpp +++ b/oneflow/core/job/env_global_objects_scope.cpp @@ -132,9 +132,6 @@ bool CommNetIBEnabled() { Maybe EnvGlobalObjectsScope::Init(const EnvProto& env_proto) { InitLogging(env_proto.cpp_logging_conf()); -#ifdef WITH_CUDA - InitGlobalCudaDeviceProp(); -#endif Global::New(env_proto); Global::New(); // Avoid dead lock by using CHECK_JUST instead of JUST. because it maybe be blocked in @@ -255,9 +252,6 @@ EnvGlobalObjectsScope::~EnvGlobalObjectsScope() { Global::Delete(); Global::Delete(); Global::Delete(); -#ifdef WITH_CUDA - Global::Delete(); -#endif ClearAllSymbolAndIdCache(); google::ShutdownGoogleLogging(); } diff --git a/oneflow/core/job/inter_job_mem_sharing_util.cpp b/oneflow/core/job/inter_job_mem_sharing_util.cpp index 0f63eaf935a..fbba694281e 100644 --- a/oneflow/core/job/inter_job_mem_sharing_util.cpp +++ b/oneflow/core/job/inter_job_mem_sharing_util.cpp @@ -290,10 +290,10 @@ void MergeSharedInterfaceMemBlock(const std::vector>& jobs, HashMap* mem_block_id2mem_block) { HashMap> interface_op_name2job_ids = GetInterfaceOpName2JobIds(jobs); - HashSet interface_op_names; - for (const auto& pair : interface_op_name2job_ids) { interface_op_names.insert(pair.first); } + HashSet interfaces_op_names; + for (const auto& pair : interface_op_name2job_ids) { interfaces_op_names.insert(pair.first); } HashMap>> op_name2job_id2task_protos; - GetOpName2JobId2TaskProtos(plan, interface_op_names, &op_name2job_id2task_protos); + GetOpName2JobId2TaskProtos(plan, interfaces_op_names, &op_name2job_id2task_protos); for (const auto& op_job_pair : interface_op_name2job_ids) { if (op_job_pair.second.size() <= 1) { continue; } diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp index b9a8bdf0b4b..1aa14821ca9 100644 --- a/oneflow/core/job/job_build_and_infer_ctx.cpp +++ b/oneflow/core/job/job_build_and_infer_ctx.cpp @@ -27,7 +27,7 @@ limitations under the License. 
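// The new CriticalSectionInstance interface above is what the runtime hands to the
// critical-section machinery. A minimal sketch of a concrete implementation (the class
// name and the callback members are illustrative, not part of this patch):
#include <functional>
#include <string>
#include <utility>
#include "oneflow/core/job/critical_section_instance.h"

namespace oneflow {

class ExampleCriticalSectionInstance final : public CriticalSectionInstance {
 public:
  ExampleCriticalSectionInstance(std::string job_name,
                                 std::function<void(uint64_t, const std::string&)> access_blob_cb,
                                 std::function<void()> finish_cb)
      : job_name_(std::move(job_name)),
        access_blob_cb_(std::move(access_blob_cb)),
        finish_cb_(std::move(finish_cb)) {}

  const std::string& job_name() const override { return job_name_; }

  // Called with an OfBlob handle when the critical section touches the blob of `op_name`.
  void AccessBlobByOpName(uint64_t ofblob_ptr, const std::string& op_name) const override {
    access_blob_cb_(ofblob_ptr, op_name);
  }

  // Called once the whole critical section has finished running.
  void Finish() const override { finish_cb_(); }

 private:
  std::string job_name_;
  std::function<void(uint64_t, const std::string&)> access_blob_cb_;
  std::function<void()> finish_cb_;
};

}  // namespace oneflow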
#include "oneflow/user/summary/summary_converter.h" #include -#include +#include "nlohmann/json.hpp" namespace oneflow { @@ -153,16 +153,13 @@ Maybe JobBuildAndInferCtx::AddLbiParallelConf2BlobPlacement( } Maybe JobBuildAndInferCtx::DecodeLbiHintAndReturnNewOpConf( - const Operator& op, cfg::SbpSignature* sbp_sig_conf, - HashMap* ibn2disable_boxing) const { + const Operator& op, cfg::SbpSignature* sbp_sig_conf) const { auto op_conf_without_split_hint = std::make_shared(op.op_conf()); for (const std::string& ibn : op.input_bns()) { std::string lbn_may_with_hint = GetInputLbnInOpCustomizedConf(op.op_conf(), ibn); cfg::SbpParallel sbp_parallel; bool has_sbp_hint = JUST(GetSbpParallelInLbnOrNothing(lbn_may_with_hint, &sbp_parallel)); - bool has_disable_boxing_hint = - JUST(ParseDisableBoxingFlag(lbn_may_with_hint, &(*ibn2disable_boxing)[ibn])); - if (has_sbp_hint || has_disable_boxing_hint) { + if (has_sbp_hint) { (*(sbp_sig_conf->mutable_bn_in_op2sbp_parallel()))[ibn] = sbp_parallel; const LogicalBlobId& lbi = op.BnInOp2Lbi(ibn); std::string lbn = GenLogicalBlobName(lbi); @@ -372,18 +369,24 @@ void JobBuildAndInferCtx::InitIbn2DisableBoxing(const Operator& op, } } -void JobBuildAndInferCtx::UpdateLbi2DisableBoxing( - const Operator& op, const HashMap& ibn2disable_boxing) { - bool disable_boxing = false; - for (const auto& ibn : op.input_bns()) { - if (ibn2disable_boxing.at(ibn)) { - disable_boxing = true; - break; +Maybe JobBuildAndInferCtx::InitConstraitNdSbpSignature( + const Operator& op, const HashMap& ibn2disable_boxing) const { + auto nd_sbp_sig = std::make_shared(); + for (const auto& it : ibn2disable_boxing) { + if (it.second) { + const auto& ibn = it.first; + const LogicalBlobId& lbi = op.BnInOp2Lbi(ibn); + const auto& nd_sbp_iter = lbi2nd_sbp_from_producer_view_.find(lbi); + if (nd_sbp_iter == lbi2nd_sbp_from_producer_view_.end()) { + return Error::RuntimeError() + << "The nd_sbp of input " << ibn << " (tensor name is " << GenLogicalBlobName(lbi) + << ") is not found for operation " << op.op_name() + << ". 
It maybe caused by an invalid inplace operation."; + } + (*(nd_sbp_sig->mutable_bn_in_op2nd_sbp()))[ibn] = lbi2nd_sbp_from_producer_view_.at(lbi); } } - for (const auto& obn : op.output_bns()) { - lbi2disable_boxing_[op.BnInOp2Lbi(obn)] = disable_boxing; - } + return nd_sbp_sig; } bool JobBuildAndInferCtx::HasAnyMirroredBlobInput(const Operator& op) const { @@ -572,12 +575,11 @@ Maybe JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con cfg::SbpSignature sbp_sig_conf; HashMap ibn2disable_boxing; InitIbn2DisableBoxing(*op, &ibn2disable_boxing); - auto new_op_conf = JUST(DecodeLbiHintAndReturnNewOpConf(*op, &sbp_sig_conf, &ibn2disable_boxing)); + auto new_op_conf = JUST(DecodeLbiHintAndReturnNewOpConf(*op, &sbp_sig_conf)); auto parallel_conf = JUST(InferOpParallelConf(*op, origin_parallel_conf, ibn2disable_boxing)); ParallelDesc parallel_desc(*parallel_conf); JUST(op->FillOpParallelDesc(parallel_desc)); JUST(AddOpNameParallelConf2Placement(op_name, *parallel_conf)); - UpdateLbi2DisableBoxing(*op, ibn2disable_boxing); auto GetBlobDesc4BnInOp = [&](const std::string& bn) -> BlobDesc* { const LogicalBlobId& lbi = op->BnInOp2Lbi(bn); @@ -592,8 +594,11 @@ Maybe JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con JUST(InferMirroredSignature(op, is_mirrored_parallel_view, parallel_desc)); // infer nd_sbp signature - cfg::NdSbpSignature nd_sbp_sig_conf; - SbpSignatureToNdSbpSignature(sbp_sig_conf, &nd_sbp_sig_conf); + cfg::NdSbpSignature nd_sbp_sig_conf = *JUST(InitConstraitNdSbpSignature(*op, ibn2disable_boxing)); + // Override constrait nd_sbp if sbp hint is given + if (!sbp_sig_conf.bn_in_op2sbp_parallel().empty()) { + SbpSignatureToNdSbpSignature(sbp_sig_conf, &nd_sbp_sig_conf); + } AddOpAndUpdateJobParallelViewConf(*new_op_conf, parallel_desc, nd_sbp_sig_conf, is_mirrored_parallel_view); JUST(InferOpOutNdSbp(op, nd_sbp_sig_conf, parallel_desc)); @@ -663,7 +668,7 @@ Maybe JobBuildAndInferCtx::IsDynamic(const std::string& lbn) const { return lbi2logical_blob_desc_.at(GenLogicalBlobId(lbn))->is_dynamic(); } -Maybe JobBuildAndInferCtx::DisableBoxing(const std::string& lbn) const { +Maybe JobBuildAndInferCtx::IsDisableBoxing(const std::string& lbn) const { JUST(CheckLbnValidAndExist(lbn)); LogicalBlobId lbi(GenLogicalBlobId(lbn)); const auto& iter = lbi2disable_boxing_.find(lbi); @@ -671,6 +676,13 @@ Maybe JobBuildAndInferCtx::DisableBoxing(const std::string& lbn) const { return iter->second; } +Maybe JobBuildAndInferCtx::DisableBoxing(const std::string& lbn) { + JUST(CheckLbnValidAndExist(lbn)); + LogicalBlobId lbi(GenLogicalBlobId(lbn)); + lbi2disable_boxing_[lbi] = true; + return Maybe::Ok(); +} + Maybe JobBuildAndInferCtx::Op4OpName(const std::string& op_name) const { const auto& op_iter = op_name2op_.find(op_name); CHECK_OR_RETURN(op_iter != op_name2op_.end()); @@ -965,8 +977,36 @@ Maybe LazyJobBuildAndInferCtx::Complete() { Global::Delete(); auto scope = std::make_unique(mut_job()->job_conf(), job_id()); JobPassCtx job_pass_ctx(GlobalJobDesc()); - auto DoPass = [&](const std::string& pass_name) -> Maybe { - return JobPass4Name(pass_name)(mut_job(), &job_pass_ctx); + const auto& job_name = job().job_conf().job_name(); + auto LogJob = [&](const std::string& name_suffix) -> void { + std::string full_log_name = + job_name + "-job_id_" + std::to_string(job_id()) + "-" + name_suffix; + TeePersistentLogStream::Create(full_log_name)->Write(job()); + Global::New(job()); + Global::Get()->ToDotWithFilePath(full_log_name + ".dot"); + Global::Delete(); + }; + std::string debug_pass_name 
= GetStringFromEnv("ONEFLOW_DEBUG_PASS", ""); + auto NeedLogJob = [&](const std::string& pass_name) -> bool { + if ("ALL" == debug_pass_name) { + return true; + } else if (pass_name == debug_pass_name) { + return true; + } else { + return false; + } + }; + auto DoPass = [&](const std::string& pass_name, int32_t cnt = 0) -> Maybe { + if (unlikely(NeedLogJob(pass_name))) { + std::string cnt_str = cnt > 0 ? std::to_string(cnt) : ""; + LogJob(pass_name + cnt_str + "-before"); + } + JUST(JobPass4Name(pass_name)(mut_job(), &job_pass_ctx)); + if (unlikely(NeedLogJob(pass_name))) { + std::string cnt_str = cnt > 0 ? std::to_string(cnt) : ""; + LogJob(pass_name + cnt_str + "-after"); + } + return Maybe::Ok(); }; if (Global::Get()->enable_debug_mode() @@ -1007,7 +1047,7 @@ Maybe LazyJobBuildAndInferCtx::Complete() { JUST(DoPass("FuseAddToOutputPass")); // run this pass again to fuse ops created in the first run. // TODO(guoran): loop multiple times inside the pass - JUST(DoPass("FuseAddToOutputPass")); + JUST(DoPass("FuseAddToOutputPass", 1)); JUST(DoPass("IndexedSlicesOptimizerRewritePass")); JUST(DoPass("SplitSparseSoftmaxCrossEntropyOpPass")); JUST(DoPass("DoParallelCastBeforeWideningTypeCast")); @@ -1197,8 +1237,8 @@ void FormateVariableConf(nlohmann::json& json_conf) { std::string oneflow::JobBuildAndInferCtx::GetJobStructureGraphJson( const std::string& job_name) const { - HashSet input_op_names; - HashSet output_op_names; + HashSet inputs_op_names; + HashSet outputs_op_names; std::vector layers_vec; layers_vec.reserve(op_name2op_.size()); for (const auto& pair : op_name2op_) { @@ -1215,10 +1255,10 @@ std::string oneflow::JobBuildAndInferCtx::GetJobStructureGraphJson( } if (op->op_conf().has_input_conf() && op->op_conf().has_return_conf()) { - input_op_names.insert(op_name); + inputs_op_names.insert(op_name); } if (op->op_conf().has_output_conf() && op->op_conf().has_return_conf()) { - output_op_names.insert(op_name); + outputs_op_names.insert(op_name); } json_layers_pair["name"] = op_name; @@ -1242,8 +1282,8 @@ std::string oneflow::JobBuildAndInferCtx::GetJobStructureGraphJson( nlohmann::json json_pair; json_pair["name"] = job_name; json_pair["layers"] = layers_vec; - json_pair["input_layers"] = input_op_names; - json_pair["output_layers"] = output_op_names; + json_pair["input_layers"] = inputs_op_names; + json_pair["output_layers"] = outputs_op_names; return json_pair.dump(); } diff --git a/oneflow/core/job/job_build_and_infer_ctx.h b/oneflow/core/job/job_build_and_infer_ctx.h index 48a3416b56b..3466a31c63a 100644 --- a/oneflow/core/job/job_build_and_infer_ctx.h +++ b/oneflow/core/job/job_build_and_infer_ctx.h @@ -44,7 +44,8 @@ class JobBuildAndInferCtx { Maybe GetStaticShape(const std::string& lbn) const; Maybe GetDataType(const std::string& lbn) const; Maybe IsDynamic(const std::string& lbn) const; - Maybe DisableBoxing(const std::string& lbn) const; + Maybe IsDisableBoxing(const std::string& lbn) const; + Maybe DisableBoxing(const std::string& lbn); Maybe GetSplitAxisFromProducerView(const std::string& lbn) const; Maybe GetParallelDescFromProducerView(const std::string& lbn) const; @@ -107,13 +108,12 @@ class JobBuildAndInferCtx { Maybe AddOpNameParallelConf2Placement(const std::string& op_name, const ParallelConf& parallel_conf); void InitIbn2DisableBoxing(const Operator& op, HashMap* ibn2disable_boxing); - void UpdateLbi2DisableBoxing(const Operator& op, - const HashMap& ibn2disable_boxing); + Maybe InitConstraitNdSbpSignature( + const Operator& op, const HashMap& ibn2disable_boxing) 
const; + Maybe DecodeLbiHintAndReturnNewOpConf(const Operator& op, + cfg::SbpSignature* sbp_sig_conf) const; Maybe AddLbiParallelConf2BlobPlacement( const Operator* op, std::function ParallelDesc4Obn); - Maybe DecodeLbiHintAndReturnNewOpConf( - const Operator& op, cfg::SbpSignature* sbp_sig_conf, - HashMap* ibn2disable_boxing) const; void AddOpAndUpdateJobParallelViewConf(const OperatorConf& operator_conf, const ParallelDesc& parallel_desc, const cfg::NdSbpSignature& nd_sbp_signature, diff --git a/oneflow/core/job/job_build_and_infer_ctx_mgr.cpp b/oneflow/core/job/job_build_and_infer_ctx_mgr.cpp index 3e443c1fd74..26abf3402a3 100644 --- a/oneflow/core/job/job_build_and_infer_ctx_mgr.cpp +++ b/oneflow/core/job/job_build_and_infer_ctx_mgr.cpp @@ -19,7 +19,7 @@ limitations under the License. #include "oneflow/core/common/util.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/lazy_mode.h" -#include +#include "nlohmann/json.hpp" namespace oneflow { diff --git a/oneflow/core/job/job_builder.cpp b/oneflow/core/job/job_builder.cpp index d8569605bde..fc195b15b1f 100644 --- a/oneflow/core/job/job_builder.cpp +++ b/oneflow/core/job/job_builder.cpp @@ -157,7 +157,7 @@ Maybe JobBuilder::OpConf4OpName(const std::string& op_name) return *JUST(MapAt(op_name2op_conf_, op_name)); } -const ParallelConf& JobBuilder::ParallelConf4Lbi(const LogicalBlobId& lbi) const { +Maybe JobBuilder::ParallelConf4Lbi(const LogicalBlobId& lbi) const { const auto& iter = lbi2blob_parallel_conf_.find(lbi); if (iter != lbi2blob_parallel_conf_.end()) { return *iter->second; } return ParallelConf4OpName(lbi.op_name()); @@ -316,9 +316,9 @@ Maybe JobBuilder::ForEachOperator( return Maybe::Ok(); } -const ParallelConf& JobBuilder::ParallelConf4OpName(const std::string& op_name) const { +Maybe JobBuilder::ParallelConf4OpName(const std::string& op_name) const { const auto& iter = op_name2parallel_conf_.find(op_name); - CHECK(iter != op_name2parallel_conf_.end()); + CHECK_OR_RETURN(iter != op_name2parallel_conf_.end()); return *iter->second; } diff --git a/oneflow/core/job/job_builder.h b/oneflow/core/job/job_builder.h index 5238194131c..ec03fa278c6 100644 --- a/oneflow/core/job/job_builder.h +++ b/oneflow/core/job/job_builder.h @@ -71,8 +71,8 @@ class JobBuilder final { void SetNdSbp4Oba(const OpBlobArg& oba, const cfg::NdSbp& nd_sbp); Maybe ForEachOperator(const std::function(const Operator&)>& Handler) const; - const ParallelConf& ParallelConf4Lbi(const LogicalBlobId& lbi) const; - const ParallelConf& ParallelConf4OpName(const std::string& op_name) const; + Maybe ParallelConf4Lbi(const LogicalBlobId& lbi) const; + Maybe ParallelConf4OpName(const std::string& op_name) const; const cfg::SbpSignature SbpSignature4OpName(const std::string& op_name) const; void AddSbpSignature4OpName(const std::string& op_name, const cfg::SbpSignature& sbp_signature); diff --git a/oneflow/core/job/job_instance.h b/oneflow/core/job/job_instance.h index dc9ed7c03ba..5be11bd8a26 100644 --- a/oneflow/core/job/job_instance.h +++ b/oneflow/core/job/job_instance.h @@ -29,12 +29,6 @@ class JobInstance { virtual std::string job_name() const { UNIMPLEMENTED(); } virtual std::string sole_input_op_name_in_user_job() const { UNIMPLEMENTED(); } virtual std::string sole_output_op_name_in_user_job() const { UNIMPLEMENTED(); } - virtual void PushBlobByOpName(uint64_t ofblob_ptr, const std::string& op_name) const { - UNIMPLEMENTED(); - } - virtual void PullBlobByOpName(uint64_t ofblob_ptr, const std::string& op_name) const { - UNIMPLEMENTED(); - } 
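// With ParallelConf4OpName / ParallelConf4Lbi now returning Maybe<const ParallelConf&>,
// callers unwrap the result instead of binding a plain reference. Roughly, the two
// patterns used throughout this patch (lbi, op_conf and op_name are placeholders):
//   inside Maybe-returning code:
//     job_builder->AddOps(JUST(job_builder->ParallelConf4Lbi(lbi)), {op_conf});
//   outside of it, where a lookup failure should abort:
//     const auto& parallel_conf = CHECK_JUST(job_builder.ParallelConf4OpName(op_name));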
virtual void PushBlob(uint64_t ofblob_ptr) const { UNIMPLEMENTED(); } virtual void PullBlob(uint64_t ofblob_ptr) const { UNIMPLEMENTED(); } virtual void Finish() const { UNIMPLEMENTED(); } diff --git a/oneflow/core/job/job_ir.cpp b/oneflow/core/job/job_ir.cpp new file mode 100644 index 00000000000..792735a0354 --- /dev/null +++ b/oneflow/core/job/job_ir.cpp @@ -0,0 +1,32 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/job/job_ir.h" + +namespace oneflow { + +#ifndef WITH_MLIR + +Maybe SaveJobToIR(Job* job, const std::string& path) { + UNIMPLEMENTED_THEN_RETURN() << "SaveJobToIR is only supported WITH_MLIR"; +} + +Maybe LoadJobFromIR(Job* job, const std::string& path) { + UNIMPLEMENTED_THEN_RETURN() << "LoadJobFromIR is only supported WITH_MLIR"; +} + +#endif + +} // namespace oneflow diff --git a/oneflow/core/framework/data_consistency_check.h b/oneflow/core/job/job_ir.h similarity index 62% rename from oneflow/core/framework/data_consistency_check.h rename to oneflow/core/job/job_ir.h index a6f6072c04f..c57d0eebeb8 100644 --- a/oneflow/core/framework/data_consistency_check.h +++ b/oneflow/core/job/job_ir.h @@ -13,18 +13,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
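// The new job_ir entry points below stay as UNIMPLEMENTED stubs unless oneflow is built
// WITH_MLIR. A minimal round-trip sketch of their intended use ("saved_job_dir" and the
// sketch function name are illustrative, not taken from this patch):
#include "oneflow/core/job/job_ir.h"

namespace oneflow {

Maybe<void> RoundTripJobSketch(Job* job) {
  JUST(SaveJobToIR(job, "saved_job_dir"));          // serialize the Job as MLIR
  Job restored;
  JUST(LoadJobFromIR(&restored, "saved_job_dir"));  // rebuild a Job proto from the saved IR
  return Maybe<void>::Ok();
}

}  // namespace oneflow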
*/ -#ifndef ONEFLOW_CORE_FRAMEWORK_DATA_CONSISTENCY_CHECK_H_ -#define ONEFLOW_CORE_FRAMEWORK_DATA_CONSISTENCY_CHECK_H_ +#ifndef ONEFLOW_CORE_JOB_JOB_IR_H_ +#define ONEFLOW_CORE_JOB_JOB_IR_H_ #include "oneflow/core/common/maybe.h" -#include "oneflow/core/common/symbol.h" -#include "oneflow/core/job/parallel_desc.h" +#include "oneflow/core/job/job.pb.h" namespace oneflow { -Maybe DataConsistencyCheck(const void* buffer_ptr, size_t buffer_size, - Symbol placement); +Maybe SaveJobToIR(Job* job, const std::string& path); +Maybe LoadJobFromIR(Job* job, const std::string& path); } // namespace oneflow -#endif // ONEFLOW_CORE_FRAMEWORK_DATA_CONSISTENCY_CHECK_H_ +#endif // ONEFLOW_CORE_JOB_JOB_IR_H_ diff --git a/oneflow/core/job/oneflow.cpp b/oneflow/core/job/oneflow.cpp index 6dd0cc30d60..1e23e5d1511 100644 --- a/oneflow/core/job/oneflow.cpp +++ b/oneflow/core/job/oneflow.cpp @@ -388,7 +388,7 @@ void GetMemSharingOpBlobInfo(const JobBuilder& job_builder, const std::string& o } const auto& job = job_builder.job(); ParallelBlobConf ret; - *blob_conf->mutable_parallel_conf() = job_builder.ParallelConf4OpName(op_name); + *blob_conf->mutable_parallel_conf() = CHECK_JUST(job_builder.ParallelConf4OpName(op_name)); *blob_conf->mutable_logical_blob_desc_conf() = job.helper().lbn2logical_blob_desc().at(lbn); *blob_conf->mutable_nd_sbp() = job.job_parallel_view_conf().op_name2nd_sbp_signature_conf().at(op_name).bn_in_op2nd_sbp().at( diff --git a/oneflow/core/job/parallel_desc.cpp b/oneflow/core/job/parallel_desc.cpp index c375c98ad96..37fe3ccd905 100644 --- a/oneflow/core/job/parallel_desc.cpp +++ b/oneflow/core/job/parallel_desc.cpp @@ -13,10 +13,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/job/placement.cfg.h" #include "oneflow/core/common/decorator.h" #include "oneflow/core/common/util.h" +#include "oneflow/core/common/multi_client.h" +#include "oneflow/core/common/cpp_attribute.h" #include "oneflow/core/job/global_for.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/control/global_process_ctx.h" @@ -24,11 +27,24 @@ limitations under the License. 
#include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/vm/vm_util.h" +#ifdef WITH_CUDA +#include +#endif // WITH_CUDA namespace oneflow { namespace { +int64_t GetGpuDeviceNum() { +#ifndef WITH_CUDA + return 0; +#else + int device_count = 0; + cudaGetDeviceCount(&device_count); + return device_count; +#endif +} + using MachineId2DeviceIdList = std::shared_ptr>>>; @@ -302,6 +318,34 @@ Maybe ParallelDesc::CheckWithResourceDesc(const ResourceDesc& resource_des return Maybe::Ok(); } +Maybe ParallelDesc::CheckDeviceIdsIsValid() const { + if (likely(JUST(IsMultiClient()))) { + const auto& sorted_dev_phy_ids_iter = + machine_id2sorted_dev_phy_ids_->find(GlobalProcessCtx::Rank()); + if (sorted_dev_phy_ids_iter != machine_id2sorted_dev_phy_ids_->end()) { + for (int64_t dev_phy_id : *sorted_dev_phy_ids_iter->second) { + if (device_type_ == DeviceType::kCUDA) { + const int64_t gpu_device_num = GetGpuDeviceNum(); + CHECK_NE_OR_RETURN(gpu_device_num, 0) + << "Placment with \"cuda\" type is invalid because there is no CUDA device!"; + int64_t device_num = std::min(GlobalProcessCtx::NumOfProcessPerNode(), gpu_device_num); + CHECK_LT_OR_RETURN(dev_phy_id, device_num) + << "Placment is invalid because device id must be less than " + << (gpu_device_num < GlobalProcessCtx::NumOfProcessPerNode() + ? "num of CUDA devices on node" + : "num of process per node"); + } else if (device_type_ == DeviceType::kCPU) { + CHECK_LT_OR_RETURN(dev_phy_id, GlobalProcessCtx::NumOfProcessPerNode()) + << "Placment is invalid because device id must be less than num of process per node"; + } else { + OF_UNIMPLEMENTED(); + } + } + } + } + return Maybe::Ok(); +} + ParallelConf ParallelDesc::GetParallelIdOnlyParallelConf(int64_t parallel_id) const { ParallelConf parallel_conf; std::string rank = std::to_string(CHECK_JUST(MachineId4ParallelId(parallel_id))); @@ -456,6 +500,11 @@ Maybe> RawTxtStringToPlacement(const std::string& parallel_ return SymbolOf(ParallelDesc(parallel_conf)); } +Maybe RawCheckDeviceIdsIsValid(Symbol placement) { + JUST(placement->CheckDeviceIdsIsValid()); + return Maybe::Ok(); +} + } // namespace decltype(GetParallelId4CurrentProcessCtx) GetParallelId4CurrentProcessCtx = @@ -467,5 +516,7 @@ decltype(PlacementToString) PlacementToString = DECORATE(&RawPlacementToString, decltype(GetTensorDevice) GetTensorDevice = DECORATE(&RawGetTensorDevice, ThreadLocal); decltype(TxtStringToPlacement) TxtStringToPlacement = DECORATE(&RawTxtStringToPlacement, ThreadLocalCopiable); +decltype(CheckDeviceIdsIsValid) CheckDeviceIdsIsValid = + DECORATE(&RawCheckDeviceIdsIsValid, ThreadLocal); } // namespace oneflow diff --git a/oneflow/core/job/parallel_desc.h b/oneflow/core/job/parallel_desc.h index 4b43fba0d85..9848334950a 100644 --- a/oneflow/core/job/parallel_desc.h +++ b/oneflow/core/job/parallel_desc.h @@ -107,6 +107,7 @@ class ParallelDesc final { std::shared_ptr cfg_parallel_conf() const { return cfg_parallel_conf_; } bool TryGetParallelId(int64_t machine_id, int64_t device_id, int64_t* parallel_id) const; + Maybe CheckDeviceIdsIsValid() const; private: friend Maybe ParseMachineAndDeviceIdList(const ParallelConf& parallel_conf); @@ -149,6 +150,7 @@ extern Maybe> (*ReplaceDeviceType)(Symbol, De extern Maybe (*PlacementToString)(Symbol placement); extern Maybe> (*GetTensorDevice)(Symbol parallel_desc); extern Maybe> (*TxtStringToPlacement)(const std::string& parallel_conf_str); +extern Maybe (*CheckDeviceIdsIsValid)(Symbol placement); inline bool 
operator==(const ParallelConf& lhs, const ParallelConf& rhs) { return ParallelDesc(lhs) == ParallelDesc(rhs); diff --git a/oneflow/core/job/plan_util.cpp b/oneflow/core/job/plan_util.cpp index 4a8e687d4e2..a404999eb36 100644 --- a/oneflow/core/job/plan_util.cpp +++ b/oneflow/core/job/plan_util.cpp @@ -163,7 +163,7 @@ void GenChunkForMultiNNGraphMemoryReuseInMultiClient( CHECK_LE(current_chunk_offset + mem_block->mem_size(), chunk->mem_size()); CHECK_GE(current_chunk_offset, 0); // CHECK_GT(mem_block->mem_size(), 0); NOTE(chengcheng): has mem block mem size = 0 - CHECK_GT(chunk->mem_size(), 0); + CHECK_GE(chunk->mem_size(), 0); mem_block->set_chunk_id(chunk->chunk_id()); mem_block->set_chunk_offset(current_chunk_offset); current_chunk_offset += mem_block->mem_size(); diff --git a/oneflow/core/job/resource_desc.cpp b/oneflow/core/job/resource_desc.cpp index 81392e757f0..7ba811cce38 100644 --- a/oneflow/core/job/resource_desc.cpp +++ b/oneflow/core/job/resource_desc.cpp @@ -80,7 +80,6 @@ bool ResourceDesc::nccl_use_compute_stream() const { } void ResourceDesc::DumpCudnnConf(const JobConfigProto& job_conf) { - resource_.clear_cudnn_conf(); auto* cudnn_conf = resource_.mutable_cudnn_conf(); if (job_conf.has_enable_cudnn()) { cudnn_conf->set_enable_cudnn(job_conf.enable_cudnn()); } if (job_conf.has_cudnn_buf_limit_mbyte()) { diff --git a/oneflow/core/job/sbp_parallel.cpp b/oneflow/core/job/sbp_parallel.cpp index 40b48310137..221cd739463 100644 --- a/oneflow/core/job/sbp_parallel.cpp +++ b/oneflow/core/job/sbp_parallel.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/core/common/protobuf.h" +#include "oneflow/core/framework/nd_sbp.h" namespace oneflow { @@ -179,17 +180,7 @@ bool ParseSbpParallelFromString(const std::string& sbp_str, cfg::SbpParallel* sb } std::string SbpParallelToString(const cfg::SbpParallel& sbp_parallel) { - std::string sbp_str = ""; - if (sbp_parallel.has_broadcast_parallel()) { - sbp_str = "B"; - } else if (sbp_parallel.has_partial_sum_parallel()) { - sbp_str = "P"; - } else if (sbp_parallel.has_split_parallel()) { - sbp_str = "S(" + std::to_string(sbp_parallel.split_parallel().axis()) + ")"; - } else { - UNIMPLEMENTED(); - } - return sbp_str; + return SbpToString(sbp_parallel); } void SbpSignatureToNdSbpSignature(const cfg::SbpSignature& sbp_signature, @@ -227,4 +218,46 @@ void CheckSbpSignatureAndNdSbpEquals(const cfg::SbpSignature& sbp_sig, } } +Maybe SbpSignatureListAsString(const cfg::SbpSignatureList& sbp_signatures, + const PbRpf& inputs, + const PbRpf& outputs) { + std::ostringstream ss; + if (sbp_signatures.sbp_signature_size() == 0) { return ss.str(); } + + auto WalkIO = + [&](const std::function(const std::string&)>& bn_handler) -> Maybe { + ss << "("; + for (size_t i = 0; i < inputs.size(); ++i) { + ss << *JUST(bn_handler(inputs[i])); + if (i != inputs.size() - 1) { ss << ", "; } + } + ss << ") -> ("; + for (size_t i = 0; i < outputs.size(); ++i) { + ss << *JUST(bn_handler(outputs[i])); + if (i != outputs.size() - 1) { ss << ", "; } + } + ss << ")"; + return Maybe::Ok(); + }; + + JUST(WalkIO([](const std::string& bn) -> Maybe { return bn; })); + ss << ": "; + + ss << "[\n"; + for (const auto& sbp_signature : sbp_signatures.sbp_signature()) { + ss << "\t"; + JUST(WalkIO([&](const std::string& bn) -> Maybe { + auto it = sbp_signature.bn_in_op2sbp_parallel().find(bn); + if (it == sbp_signature.bn_in_op2sbp_parallel().end()) { + return Error::RuntimeError() + << "can't find " << bn << "in 
SbpSignature: " << sbp_signature.DebugString(); + } + return SbpParallelToString(it->second); + })); + ss << ",\n"; + } + ss << "]"; + return ss.str(); +} + } // namespace oneflow diff --git a/oneflow/core/job/sbp_parallel.h b/oneflow/core/job/sbp_parallel.h index ac8a909cc3c..91c7382f52e 100644 --- a/oneflow/core/job/sbp_parallel.h +++ b/oneflow/core/job/sbp_parallel.h @@ -62,6 +62,10 @@ void NdSbpSignatureToSbpSignature(const NdSbpSignatureT& nd_sbp_signature, void CheckSbpSignatureAndNdSbpEquals(const cfg::SbpSignature& sbp_sig, const cfg::NdSbpSignature& nd_sbp_sig); +Maybe SbpSignatureListAsString(const cfg::SbpSignatureList& sbp_signatures, + const PbRpf& inputs, + const PbRpf& outputs); + } // namespace oneflow namespace std { diff --git a/oneflow/core/job/task.proto b/oneflow/core/job/task.proto index 83e6f6ff3bf..e4df1c4a0db 100644 --- a/oneflow/core/job/task.proto +++ b/oneflow/core/job/task.proto @@ -37,6 +37,7 @@ enum TaskType { kCollectiveBoxingUnpack = 62; kSspVariableProxy = 63; kBoxingZeros = 64; + kCriticalSectionWaitTick = 65; }; message RegstDescIdSet { diff --git a/oneflow/core/job_rewriter/adam_optm.cpp b/oneflow/core/job_rewriter/adam_optm.cpp index 6d9f81c4210..b325a992317 100644 --- a/oneflow/core/job_rewriter/adam_optm.cpp +++ b/oneflow/core/job_rewriter/adam_optm.cpp @@ -136,9 +136,9 @@ void GenerateOptimizerOpConf(JobPassCtx* ctx, const OpNode& var_op_node, auto* state = CHECK_JUST(ctx->MutableState(job_pass_state_key)); ParallelConf bias_correction_parallel_conf; const auto& lr_parallel_conf = - job_builder->ParallelConf4Lbi(GenLogicalBlobId(learning_rate_lbn)); + CHECK_JUST(job_builder->ParallelConf4Lbi(GenLogicalBlobId(learning_rate_lbn))); const auto& train_step_parallel_conf = - job_builder->ParallelConf4Lbi(GenLogicalBlobId(train_step_lbn)); + CHECK_JUST(job_builder->ParallelConf4Lbi(GenLogicalBlobId(train_step_lbn))); if (lr_parallel_conf == train_step_parallel_conf) { bias_correction_parallel_conf = lr_parallel_conf; } else { diff --git a/oneflow/core/job_rewriter/add_lbi_diff_watcher.cpp b/oneflow/core/job_rewriter/add_lbi_diff_watcher.cpp index 16987f3d96b..067381a0661 100644 --- a/oneflow/core/job_rewriter/add_lbi_diff_watcher.cpp +++ b/oneflow/core/job_rewriter/add_lbi_diff_watcher.cpp @@ -55,7 +55,7 @@ Maybe AddLbiDiffWatcherOpConfs::Apply(Job* job) const { auto* foreign_watcher_conf = foreign_watcher_op.mutable_foreign_watch_conf(); foreign_watcher_conf->set_in(GenLogicalBlobName(diff_lbi)); foreign_watcher_conf->set_handler_uuid(pair.watcher_uuid()); - job_builder.AddOps(job_builder.ParallelConf4Lbi(pair.lbi()), {foreign_watcher_op}); + job_builder.AddOps(JUST(job_builder.ParallelConf4Lbi(pair.lbi())), {foreign_watcher_op}); } return Maybe::Ok(); } diff --git a/oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp b/oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp index a81392646b1..2fcf657980d 100644 --- a/oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp +++ b/oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp @@ -24,7 +24,8 @@ const AMPList& AutoMixedPrecisionLists::WhiteList() { "amp_white_identity", "broadcast_matmul", "fused_self_attention_query_mul_key_and_value", - "prelu"}; + "prelu", + "tf_prelu"}; return white_list; } diff --git a/oneflow/core/job_rewriter/autograd.cpp b/oneflow/core/job_rewriter/autograd.cpp index 22b1b52849e..5e9ddacc6c4 100644 --- a/oneflow/core/job_rewriter/autograd.cpp +++ b/oneflow/core/job_rewriter/autograd.cpp @@ -255,7 +255,7 @@ Maybe TryMirroredCastTotalLossInstanceNum( 
cast_from_mirrored->set_in(GenLogicalBlobName(*total_loss_instance_num_lbi)); cast_from_mirrored->set_out("out"); cast_from_mirrored->mutable_sbp_parallel()->mutable_partial_sum_parallel(); - const auto& parallel_conf = job_builder->ParallelConf4Lbi(*total_loss_instance_num_lbi); + const auto& parallel_conf = JUST(job_builder->ParallelConf4Lbi(*total_loss_instance_num_lbi)); int64_t scope_symbol_id = 0; { const std::shared_ptr& cfg_job_conf = @@ -821,8 +821,8 @@ Maybe AutoGrad(JobPassCtx* ctx, const OpGraph& op_graph, JobBuilder* job_b auto LogicalBlobDesc4BnInOp = [&](const std::string& bn) -> const BlobDesc& { return op_graph.GetLogicalBlobDesc(op_node->op().BnInOp2Lbi(bn)); }; - GenerateCloneGradOpIfNeed(*op_node, job_builder, in_oba2in_diff_lbi, &out_oba2out_diff_lbi, - &out_oba2clone_bw_add_out_lbi); + JUST(GenerateCloneGradOpIfNeed(*op_node, job_builder, in_oba2in_diff_lbi, &out_oba2out_diff_lbi, + &out_oba2clone_bw_add_out_lbi)); std::vector ops; JUST(GenerateBackwardOpConfIf(op_node->op(), &ops, DiffLbi4BnInOp, LogicalBlobDesc4BnInOp)); int64_t scope_symbol_id = op_node->op().op_conf().scope_symbol_id(); diff --git a/oneflow/core/job_rewriter/autotick.cpp b/oneflow/core/job_rewriter/autotick.cpp index 7ca9856093f..60809b085f2 100644 --- a/oneflow/core/job_rewriter/autotick.cpp +++ b/oneflow/core/job_rewriter/autotick.cpp @@ -19,7 +19,11 @@ limitations under the License. #include "oneflow/core/job_rewriter/autotick.h" #include "oneflow/core/job/job_builder.h" #include "oneflow/core/job/critical_section_desc.h" +#include "oneflow/core/common/protobuf.h" +#include "oneflow/core/common/container_util.h" +#include "oneflow/core/common/buffer_manager.h" #include "oneflow/core/job/global_for.h" +#include "oneflow/core/common/multi_client.h" namespace oneflow { @@ -57,7 +61,7 @@ void PrependTickByParallelDesc(const OpGraph& op_graph, JobBuilder* job_builder) } } -Maybe FindSrcSubsetTickOpConf(const Job& job) { +Maybe FindJobSoleSrcSubsetTickOpConf(const Job& job) { const OperatorConf* src_subset_tick_op_conf = nullptr; for (const auto& op_conf : job.net().op()) { if (!op_conf.has_src_subset_tick_conf()) { continue; } @@ -561,7 +565,7 @@ Maybe AutoSourceAndSinkTick( CHECK_OR_RETURN(tick_lbis.emplace(op_node->op().BnInOp2Lbi(op_node->op().SoleObn())).second); return Maybe::Ok(); })); - OperatorConf src_subset_tick = JUST(FindSrcSubsetTickOpConf(job_builder->job())); + OperatorConf src_subset_tick = JUST(FindJobSoleSrcSubsetTickOpConf(job_builder->job())); JUST(CreateSourceTicksAndSrcSubsetTick(&src_subset_tick, job_builder, DoEachSrc)); JUST(CreateDstSubsetTickAndSinkTicks(src_subset_tick, tick_lbis, job_builder, DoEachSink)); return Maybe::Ok(); @@ -614,6 +618,199 @@ Maybe MultiClientAutoSourceAndSinkTick(const OpGraph& op_graph, Job* job) return Maybe::Ok(); } +namespace { + +Maybe InsertCriticalSectionSrcAndDstTicks( + const std::vector& interface_op_nodes, JobBuilder* job_builder, + std::vector* interface_src_tick_op_names, + std::vector* interface_dst_tick_lbns) { + HashMap> parallel_desc2interface_op_nodes; + for (const auto* op_node : interface_op_nodes) { + parallel_desc2interface_op_nodes[op_node->parallel_desc()].push_back(op_node); + } + for (const auto& pair : parallel_desc2interface_op_nodes) { + const auto& parallel_conf = pair.first.parallel_conf(); + for (const auto* op_node : pair.second) { + OperatorConf interface_op(op_node->op().op_conf()); + { + OperatorConf device_tick_op; + device_tick_op.set_name("System-EagerCriticalSection-Interface-Begin-Tick-" + + NewUniqueId()); 
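// (Each interface op gets sandwiched between two device_tick ops on its own placement:
//  the "Begin" tick built here becomes a ctrl-in of the interface op, and the "End" tick
//  built just below takes the interface op as its ctrl-in, so the surrounding src/dst
//  subset ticks can bracket exactly when the interface blob is accessed.)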
+ auto* device_tick_op_conf = device_tick_op.mutable_device_tick_conf(); + device_tick_op_conf->set_out("out"); + interface_src_tick_op_names->push_back(device_tick_op.name()); + JUST(job_builder->AddOp(parallel_conf, device_tick_op)); + interface_op.add_ctrl_in_op_name(device_tick_op.name()); + JUST(job_builder->MutOpOnlyOnce(interface_op)); + } + { + OperatorConf device_tick_op; + device_tick_op.set_name("System-EagerCriticalSection-Interface-End-Tick-" + NewUniqueId()); + device_tick_op.add_ctrl_in_op_name(interface_op.name()); + auto* device_tick_op_conf = device_tick_op.mutable_device_tick_conf(); + device_tick_op_conf->set_out("out"); + interface_dst_tick_lbns->push_back(device_tick_op.name() + "/out"); + JUST(job_builder->AddOp(parallel_conf, device_tick_op)); + } + } + } + return Maybe::Ok(); +} + +Maybe InsertSrcSubsetTickAndDstSubsetTick( + const std::vector& interface_src_tick_op_names, + const std::vector& interface_dst_tick_lbns, JobBuilder* job_builder, + std::string* src_subset_tick_op_name, LogicalBlobId* dst_subset_tick_lbi) { + { + OperatorConf src_subset_tick; + JUST(BuildSrcSubsetTickOpAndParallelConf(&src_subset_tick, job_builder)); + *src_subset_tick_op_name = src_subset_tick.name(); + } + for (const auto& op_name : interface_src_tick_op_names) { + OperatorConf op_conf(JUST(job_builder->OpConf4OpName(op_name))); + CHECK_OR_RETURN(op_conf.has_device_tick_conf()); + op_conf.mutable_device_tick_conf()->add_tick(*src_subset_tick_op_name + "/out"); + JUST(job_builder->MutOpOnlyOnce(op_conf)); + } + HashSet dst_subset_tick_input_lbis; + dst_subset_tick_input_lbis.insert(GenLogicalBlobId(*src_subset_tick_op_name + "/out")); + for (const auto& lbn : interface_dst_tick_lbns) { + const auto& lbi = GenLogicalBlobId(lbn); + CHECK_OR_RETURN(dst_subset_tick_input_lbis.insert(lbi).second); + } + { + OperatorConf dst_subset_tick_op; + JUST(BuildDstSubsetTickOpAndParallelConf(dst_subset_tick_input_lbis, &dst_subset_tick_op, + job_builder)); + dst_subset_tick_lbi->set_op_name(dst_subset_tick_op.name()); + CHECK_OR_RETURN(dst_subset_tick_op.has_dst_subset_tick_conf()); + dst_subset_tick_lbi->set_blob_name(dst_subset_tick_op.dst_subset_tick_conf().out()); + } + return Maybe::Ok(); +} + +Maybe InsertCriticalSectionWaitTicks(const OpGraph& op_graph, JobBuilder* job_builder, + const std::string& src_subset_tick_op_name, + const std::string& wait_buffer_name) { + std::vector wait_and_send_id_op_nodes; + op_graph.ForEachNode([&](OpNode* op_node) { + if (!op_node->op().op_conf().has_wait_and_send_ids_conf()) { return; } + wait_and_send_id_op_nodes.push_back(op_node); + }); + CHECK_GT_OR_RETURN(wait_and_send_id_op_nodes.size(), 0); + OperatorConf src_subset_tick_op(JUST(job_builder->OpConf4OpName(src_subset_tick_op_name))); + CHECK_OR_RETURN(src_subset_tick_op.has_src_subset_tick_conf()); + for (const OpNode* wait_and_send_id_op_node : wait_and_send_id_op_nodes) { + LogicalBlobId lbi; + lbi.set_op_name(wait_and_send_id_op_node->op().op_name()); + lbi.set_blob_name(wait_and_send_id_op_node->op().op_conf().wait_and_send_ids_conf().out()); + OperatorConf critical_section_wait_op; + { + critical_section_wait_op.set_name("System-EagerCriticalSection-Wait-" + NewUniqueId()); + auto* conf = critical_section_wait_op.mutable_critical_section_wait_tick_conf(); + conf->add_tick(GenLogicalBlobName(lbi)); + conf->set_out("out"); + conf->set_buffer_name(wait_buffer_name); + } + const auto& parallel_conf = wait_and_send_id_op_node->parallel_desc().parallel_conf(); + JUST(job_builder->AddOp(parallel_conf, 
critical_section_wait_op)); + src_subset_tick_op.mutable_src_subset_tick_conf()->add_in(critical_section_wait_op.name() + + "/out"); + } + JUST(job_builder->MutOpOnlyOnce(src_subset_tick_op)); + return Maybe::Ok(); +} + +Maybe InsertCriticalSectionCallbackTicks(const OpGraph& op_graph, + JobBuilder* job_builder, + const LogicalBlobId& dst_subset_tick_lbi, + const std::string& callback_buffer_name) { + OperatorConf critical_section_callback_op; + critical_section_callback_op.set_name("System-EagerCriticalSection-Callback-" + NewUniqueId()); + auto* conf = critical_section_callback_op.mutable_critical_section_callback_tick_conf(); + conf->add_tick(GenLogicalBlobName(dst_subset_tick_lbi)); + conf->set_out("out"); + conf->set_buffer_name(callback_buffer_name); + const auto& op_name = dst_subset_tick_lbi.op_name(); + const auto& parallel_conf = JUST(job_builder->ParallelConf4OpName(op_name)); + JUST(job_builder->AddOp(parallel_conf, critical_section_callback_op)); + LogicalBlobId critical_section_callback_lbi; + critical_section_callback_lbi.set_op_name(critical_section_callback_op.name()); + critical_section_callback_lbi.set_blob_name("out"); + return critical_section_callback_lbi; +} + +Maybe MultiClientAutoCriticalSectionTick( + const OpGraph& op_graph, JobBuilder* job_builder, + const std::vector& interface_op_nodes, const std::string& wait_buffer_name, + const std::string& callback_buffer_name) { + std::vector interface_src_tick_op_names; + std::vector interface_dst_tick_lbns; + JUST(InsertCriticalSectionSrcAndDstTicks(interface_op_nodes, job_builder, + &interface_src_tick_op_names, &interface_dst_tick_lbns)); + std::string src_subset_tick_op_name; + LogicalBlobId dst_subset_tick_lbi; + JUST(InsertSrcSubsetTickAndDstSubsetTick(interface_src_tick_op_names, interface_dst_tick_lbns, + job_builder, &src_subset_tick_op_name, + &dst_subset_tick_lbi)); + JUST(InsertCriticalSectionWaitTicks(op_graph, job_builder, src_subset_tick_op_name, + wait_buffer_name)); + const auto& lbi = JUST(InsertCriticalSectionCallbackTicks( + op_graph, job_builder, dst_subset_tick_lbi, callback_buffer_name)); + return lbi; +} + +Maybe ConnectCriticalSectionCallbackToJobSoleDstSubsetTick( + const OpGraph& op_graph, JobBuilder* job_builder, + const std::vector>& critical_section_callback_lbis) { + const OpNode* dst_subset_tick_op_node = nullptr; + JUST(op_graph.MaybeForEachNode([&](OpNode* op_node) -> Maybe { + if (!op_node->op().op_conf().has_dst_subset_tick_conf()) { return Maybe::Ok(); } + CHECK_OR_RETURN(dst_subset_tick_op_node == nullptr); + dst_subset_tick_op_node = op_node; + return Maybe::Ok(); + })); + CHECK_NOTNULL_OR_RETURN(dst_subset_tick_op_node); + OperatorConf dst_subset_tick_op(dst_subset_tick_op_node->op().op_conf()); + auto* conf = dst_subset_tick_op.mutable_dst_subset_tick_conf(); + for (const auto& lbi : critical_section_callback_lbis) { conf->add_in(GenLogicalBlobName(*lbi)); } + JUST(job_builder->MutOpOnlyOnce(dst_subset_tick_op)); + return Maybe::Ok(); +} + +} // namespace + +Maybe MultiClientAutoInterfaceCriticalSectionTick(const OpGraph& op_graph, Job* job) { + if (!JUST(IsMultiClient())) { return Maybe::Ok(); } + JobBuilder job_builder(job); + std::vector> critical_section_callback_lbis; + { + std::vector interface_op_nodes; + op_graph.ForEachNode([&](OpNode* node) { + if (node->op().op_conf().has_input_conf()) { interface_op_nodes.push_back(node); } + }); + const auto& lbi = JUST(MultiClientAutoCriticalSectionTick( + op_graph, &job_builder, interface_op_nodes, + 
GetInputCriticalSectionWaitBufferName(job->job_conf().job_name()), + GetInputCriticalSectionCallbackBufferName(job->job_conf().job_name()))); + critical_section_callback_lbis.push_back(lbi); + } + { + std::vector interface_op_nodes; + op_graph.ForEachNode([&](OpNode* node) { + if (node->op().op_conf().has_output_conf()) { interface_op_nodes.push_back(node); } + }); + const auto& lbi = JUST(MultiClientAutoCriticalSectionTick( + op_graph, &job_builder, interface_op_nodes, + GetOutputCriticalSectionWaitBufferName(job->job_conf().job_name()), + GetOutputCriticalSectionCallbackBufferName(job->job_conf().job_name()))); + critical_section_callback_lbis.push_back(lbi); + } + JUST(ConnectCriticalSectionCallbackToJobSoleDstSubsetTick(op_graph, &job_builder, + critical_section_callback_lbis)); + return Maybe::Ok(); +} + Maybe SingleClientAddGlobalInputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder) { if (JUST(IsMultiClient())) { return Maybe::Ok(); } diff --git a/oneflow/core/job_rewriter/autotick.h b/oneflow/core/job_rewriter/autotick.h index 4edd5bfcc5e..53a81b90408 100644 --- a/oneflow/core/job_rewriter/autotick.h +++ b/oneflow/core/job_rewriter/autotick.h @@ -30,6 +30,7 @@ Maybe SingleClientAddGlobalInputCriticalSections(const OpGraph& op_graph, Maybe SingleClientAddGlobalOutputCriticalSections(const OpGraph& op_graph, JobBuilder* job_builder); Maybe MultiClientAutoSourceAndSinkTick(const OpGraph& op_graph, Job* job); +Maybe MultiClientAutoInterfaceCriticalSectionTick(const OpGraph& op_graph, Job* job); class MutOpConTickInputHelper { public: diff --git a/oneflow/core/job_rewriter/clone_grad.cpp b/oneflow/core/job_rewriter/clone_grad.cpp index 0d4bbf2dee9..d5917ceefde 100644 --- a/oneflow/core/job_rewriter/clone_grad.cpp +++ b/oneflow/core/job_rewriter/clone_grad.cpp @@ -18,10 +18,11 @@ limitations under the License. namespace oneflow { -void GenerateCloneGradOpIfNeed(const OpNode& op_node, JobBuilder* job_builder, - const HashMap& in_oba2in_diff_lbi, - HashMap* out_oba2out_diff_lbi, - HashMap* out_oba2clone_bw_add_out_lbi) { +Maybe GenerateCloneGradOpIfNeed( + const OpNode& op_node, JobBuilder* job_builder, + const HashMap& in_oba2in_diff_lbi, + HashMap* out_oba2out_diff_lbi, + HashMap* out_oba2clone_bw_add_out_lbi) { HashMap out_lbi2out_oba; for (const auto& obn : op_node.op().output_bns()) { out_lbi2out_oba[op_node.op().BnInOp2Lbi(obn)] = GenOpBlobArg(op_node.op().op_name(), obn); @@ -52,14 +53,15 @@ void GenerateCloneGradOpIfNeed(const OpNode& op_node, JobBuilder* job_builder, for (const LogicalBlobId& lbi_to_add : lbis_to_add) { add_op_builder.Input("in", GenLogicalBlobName(lbi_to_add)); } - const auto& op_conf = CHECK_JUST(job_builder->OpConf4OpName(lbi.op_name())); + const auto& op_conf = JUST(job_builder->OpConf4OpName(lbi.op_name())); const auto add_op = add_op_builder.Output("out").ScopeSymbolId(op_conf.scope_symbol_id()).Build(); - job_builder->AddOps(job_builder->ParallelConf4Lbi(lbi), {add_op.op_conf()}); + job_builder->AddOps(JUST(job_builder->ParallelConf4Lbi(lbi)), {add_op.op_conf()}); CHECK(out_oba2clone_bw_add_out_lbi->emplace(oba, lbis_to_add.front()).second); out_oba2out_diff_lbi->emplace(oba, GenLogicalBlobId(add_op.output("out", 0))); } } + return Maybe::Ok(); } } // namespace oneflow diff --git a/oneflow/core/job_rewriter/clone_grad.h b/oneflow/core/job_rewriter/clone_grad.h index a96d9561a83..adafc79b0ab 100644 --- a/oneflow/core/job_rewriter/clone_grad.h +++ b/oneflow/core/job_rewriter/clone_grad.h @@ -20,10 +20,11 @@ limitations under the License. 
namespace oneflow { -void GenerateCloneGradOpIfNeed(const OpNode& op_node, JobBuilder* job_builder, - const HashMap& in_oba2in_diff_lbi, - HashMap* out_oba2out_diff_lbi, - HashMap* out_oba2clone_bw_add_out_lbi); +Maybe GenerateCloneGradOpIfNeed( + const OpNode& op_node, JobBuilder* job_builder, + const HashMap& in_oba2in_diff_lbi, + HashMap* out_oba2out_diff_lbi, + HashMap* out_oba2clone_bw_add_out_lbi); } #endif // ONEFLOW_CORE_JOB_REWRITER_CLONE_GRAD_H_ diff --git a/oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp b/oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp index f1cf8a443c9..13f08773177 100644 --- a/oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp +++ b/oneflow/core/job_rewriter/fuse_add_to_output_pass.cpp @@ -73,6 +73,8 @@ Maybe FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_ if (user_op_conf.has_input("_add_to_output", 0)) { return false; } return true; }; + + // Save all ops' ctrl-in op names in a set. HashSet ctrl_in_op_names; op_graph.ForEachNode([&](const OpNode* op_node) { for (const std::string& ctrl_in_op_name : op_node->op().op_conf().ctrl_in_op_name()) { @@ -113,6 +115,7 @@ Maybe FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_ } else { return; } + // Make a new_add_to_op that fuses the add_n into this op. OperatorConf new_add_to_op_conf = add_to_node->op().op_conf(); *(*(new_add_to_op_conf.mutable_user_conf()->mutable_input()))["_add_to_output"] .mutable_s() @@ -124,6 +127,7 @@ Maybe FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_ if (op_name2op_conf.find(consumer_op_name) == op_name2op_conf.end()) { op_name2op_conf[consumer_op_name] = consumer->op().op_conf(); } + // Make the add_n op's consumers consume the new_add_to_op instead. for (const std::string& ibn : consumer->op().input_bns()) { if (consumer->op().BnInOp2Lbi(ibn) == out) { OperatorConf& consumer_op_conf = op_name2op_conf.at(consumer_op_name); @@ -133,6 +137,7 @@ Maybe FuseAddToOutputPass::Apply(const OpGraph& op_graph, JobBuilder* job_ } } } + // Add the add_n op to the removal list. delete_ops.emplace_back(op_conf); }); job_builder->DelOps(delete_ops); diff --git a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp index b651f12c10b..e8db5308ca4 100644 --- a/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp +++ b/oneflow/core/job_rewriter/insert_nccl_logical_op_pass.cpp @@ -22,6 +22,7 @@ limitations under the License.
#include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/graph/boxing/hierarchical_sub_task_graph_builder_impl.h" +#include "oneflow/core/framework/nd_sbp.h" #ifdef WITH_CUDA @@ -52,17 +53,6 @@ class InsertNcclLogicalOpPass final : public JobPass { const std::string kNcclLogicalOpNamePrefix = "System-NCCL-Logical"; -std::string NdSbpToString(const cfg::NdSbp& nd_sbp) { - std::string serialized_nd_sbp; - const int64_t num_axes = nd_sbp.sbp_parallel_size(); - serialized_nd_sbp += "["; - for (int64_t i = 0; i < num_axes - 1; ++i) { - serialized_nd_sbp += SbpParallelToString(nd_sbp.sbp_parallel(i)) + " "; - } - serialized_nd_sbp += SbpParallelToString(nd_sbp.sbp_parallel(num_axes - 1)) + "]"; - return serialized_nd_sbp; -} - bool IsBreakpointOpNode(const OpNode* node) { // NOTE(chengcheng): breakpoint op is special which CANNOT through subgraph such as: // variable, tick, repeat/acc/pack/unpack change timeshape diff --git a/oneflow/core/job_rewriter/job_completer.cpp b/oneflow/core/job_rewriter/job_completer.cpp index 75390201c13..9998413e0ba 100644 --- a/oneflow/core/job_rewriter/job_completer.cpp +++ b/oneflow/core/job_rewriter/job_completer.cpp @@ -137,6 +137,7 @@ Maybe JobCompleter::Complete(Job* job) const { JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAddGlobalInputCriticalSections)); JUST(WithOpGraphAndMutJobBuilder(job, &SingleClientAddGlobalOutputCriticalSections)); JUST(WithOpGraphAndMutJob(job, &MultiClientAutoSourceAndSinkTick)); + JUST(WithOpGraphAndMutJob(job, &MultiClientAutoInterfaceCriticalSectionTick)); JUST(JobPass4Name("SystemOpFillJobNamePass")(job, &job_pass_ctx)); JUST(JobPass4Name("DumpBlobParallelConfPass")(job, &job_pass_ctx)); if (XrtCompilationEnabled(GlobalJobDesc())) { diff --git a/oneflow/core/kernel/critical_section_callback_tick_kernel.cpp b/oneflow/core/kernel/critical_section_callback_tick_kernel.cpp new file mode 100644 index 00000000000..27650328777 --- /dev/null +++ b/oneflow/core/kernel/critical_section_callback_tick_kernel.cpp @@ -0,0 +1,51 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/kernel/kernel.h" +#include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/job/critical_section_instance.h" +#include "oneflow/core/job/global_for.h" +#include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/common/multi_client.h" + +namespace oneflow { + +class CriticalSectionCallbackTickKernel final : public Kernel { + public: + OF_DISALLOW_COPY_AND_MOVE(CriticalSectionCallbackTickKernel); + CriticalSectionCallbackTickKernel() = default; + ~CriticalSectionCallbackTickKernel() = default; + + private: + bool IsStateless() const override { return false; } + void ForwardDataContent(KernelContext* ctx) const override; +}; + +void CriticalSectionCallbackTickKernel::ForwardDataContent(KernelContext* ctx) const { + auto* buffer_mgr = Global>>::Get(); + bool is_multi_client = CHECK_JUST(IsMultiClient()); + CHECK(is_multi_client); + CHECK(op_conf().has_critical_section_callback_tick_conf()); + const std::string& buffer_name = op_conf().critical_section_callback_tick_conf().buffer_name(); + std::shared_ptr foreign_critical_section_instance; + BufferStatus buffer_status = + buffer_mgr->Get(buffer_name)->TryReceive(&foreign_critical_section_instance); + CHECK_EQ(buffer_status, kBufferStatusSuccess); + foreign_critical_section_instance->Finish(); +} + +REGISTER_KERNEL(OperatorConf::kCriticalSectionCallbackTickConf, CriticalSectionCallbackTickKernel); + +} // namespace oneflow diff --git a/oneflow/core/kernel/critical_section_wait_tick_kernel.cpp b/oneflow/core/kernel/critical_section_wait_tick_kernel.cpp new file mode 100644 index 00000000000..89198e542fb --- /dev/null +++ b/oneflow/core/kernel/critical_section_wait_tick_kernel.cpp @@ -0,0 +1,50 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/kernel/kernel.h" +#include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/job/critical_section_instance.h" +#include "oneflow/core/job/global_for.h" +#include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/common/multi_client.h" + +namespace oneflow { + +class CriticalSectionWaitTickKernel final : public Kernel { + public: + OF_DISALLOW_COPY_AND_MOVE(CriticalSectionWaitTickKernel); + CriticalSectionWaitTickKernel() = default; + ~CriticalSectionWaitTickKernel() = default; + + private: + bool IsStateless() const override { return false; } + void ForwardDataContent(KernelContext* ctx) const override; +}; + +void CriticalSectionWaitTickKernel::ForwardDataContent(KernelContext* ctx) const { + auto* buffer_mgr = Global>>::Get(); + bool is_multi_client = CHECK_JUST(IsMultiClient()); + CHECK(is_multi_client); + CHECK(this->op_conf().has_critical_section_wait_tick_conf()); + const std::string& buffer_name = this->op_conf().critical_section_wait_tick_conf().buffer_name(); + std::shared_ptr foreign_critical_section_instance; + BufferStatus buffer_status = + buffer_mgr->Get(buffer_name)->Pull(&foreign_critical_section_instance); + CHECK_EQ(buffer_status, kBufferStatusSuccess); +} + +REGISTER_KERNEL(OperatorConf::kCriticalSectionWaitTickConf, CriticalSectionWaitTickKernel); + +} // namespace oneflow diff --git a/oneflow/core/kernel/eager_kernel.h b/oneflow/core/kernel/eager_kernel.h index 8ad530413bc..19c3f4a0268 100644 --- a/oneflow/core/kernel/eager_kernel.h +++ b/oneflow/core/kernel/eager_kernel.h @@ -37,6 +37,7 @@ class EagerKernel final : public Kernel { void InitOpKernel(const KernelConf& kernel_conf); void ForwardDataContent(KernelContext* kernel_ctx) const override { UNIMPLEMENTED(); } std::unique_ptr kernel_; + mutable std::shared_ptr cache_; }; } // namespace oneflow diff --git a/oneflow/core/kernel/input_kernel.cpp b/oneflow/core/kernel/input_kernel.cpp index 8821be6f074..0c948ac50e7 100644 --- a/oneflow/core/kernel/input_kernel.cpp +++ b/oneflow/core/kernel/input_kernel.cpp @@ -16,8 +16,8 @@ limitations under the License. 
#include "oneflow/core/kernel/kernel.h" #include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/job/critical_section_instance.h" #include "oneflow/core/common/multi_client.h" -#include "oneflow/core/job/job_instance.h" #include "oneflow/core/job/global_for.h" namespace oneflow { @@ -36,14 +36,14 @@ class InputKernel final : public Kernel { CHECK(this->op_conf().input_conf().has_job_name()); const auto& job_name = this->op_conf().input_conf().job_name(); const auto& op_name = this->op_conf().name(); - auto* buffer_mgr = Global>>::Get(); + auto* buffer_mgr = Global>>::Get(); auto* buffer = buffer_mgr->Get(GetInputBufferName(job_name, op_name)); - std::shared_ptr job_instance; - BufferStatus buffer_status = buffer->TryReceive(&job_instance); + std::shared_ptr critical_section_instance; + BufferStatus buffer_status = buffer->TryReceive(&critical_section_instance); CHECK_NE(buffer_status, kBufferStatusEmpty); if (buffer_status == kBufferStatusSuccess) { OfBlob ofblob(ctx->stream(), ctx->BnInOp2Blob("out")); - job_instance->PushBlobByOpName(reinterpret_cast(&ofblob), op_name); + critical_section_instance->AccessBlobByOpName(reinterpret_cast(&ofblob), op_name); } } } diff --git a/oneflow/core/kernel/kernel_util.cuh b/oneflow/core/kernel/kernel_util.cuh index 5d2469b16c8..18743b97ca1 100644 --- a/oneflow/core/kernel/kernel_util.cuh +++ b/oneflow/core/kernel/kernel_util.cuh @@ -15,6 +15,7 @@ limitations under the License. */ #ifndef ONEFLOW_CORE_KERNEL_KERNEL_UTIL_CUH_ #define ONEFLOW_CORE_KERNEL_KERNEL_UTIL_CUH_ +#include "oneflow/core/device/cuda_pseudo_half.h" namespace oneflow { @@ -31,14 +32,9 @@ OF_DEVICE_FUNC T MaxWithLogThreshold(T x) { #if defined(__CUDACC__) __device__ __forceinline__ half MaxWithLogThreshold(half x) { -#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) half threshold = hexp2(__float2half(-14.0)); if (__hgt(x, threshold)) { return x; } return threshold; -#else - printf("use half need nvcc arch >= 530"); - assert(false); -#endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/ } #endif @@ -48,14 +44,7 @@ OF_DEVICE_FUNC T SafeLog(T x) { } #if defined(__CUDACC__) -__device__ __forceinline__ half SafeLog(half x) { -#if __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__) - return hlog(MaxWithLogThreshold(x)); -#else - printf("use half need nvcc arch >= 530"); - assert(false); -#endif /* __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)*/ -} +__device__ __forceinline__ half SafeLog(half x) { return hlog(MaxWithLogThreshold(x)); } #endif } // namespace oneflow diff --git a/oneflow/core/kernel/output_kernel.cpp b/oneflow/core/kernel/output_kernel.cpp index e1cd46ce112..fcd82e6893c 100644 --- a/oneflow/core/kernel/output_kernel.cpp +++ b/oneflow/core/kernel/output_kernel.cpp @@ -15,8 +15,8 @@ limitations under the License. 
*/ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/job/critical_section_instance.h" #include "oneflow/core/common/multi_client.h" -#include "oneflow/core/job/job_instance.h" #include "oneflow/core/job/global_for.h" namespace oneflow { @@ -37,14 +37,14 @@ void OutputKernel::ForwardDataContent(KernelContext* ctx) const { CHECK(this->op_conf().output_conf().has_job_name()); const auto& job_name = this->op_conf().output_conf().job_name(); const auto& op_name = this->op_conf().name(); - auto* buffer_mgr = Global>>::Get(); + auto* buffer_mgr = Global>>::Get(); auto* buffer = buffer_mgr->Get(GetOutputBufferName(job_name, op_name)); - std::shared_ptr job_instance; - BufferStatus buffer_status = buffer->TryReceive(&job_instance); + std::shared_ptr critical_section_instance; + BufferStatus buffer_status = buffer->TryReceive(&critical_section_instance); CHECK_NE(buffer_status, kBufferStatusEmpty); if (buffer_status == kBufferStatusSuccess) { OfBlob ofblob(ctx->stream(), ctx->BnInOp2Blob("in")); - job_instance->PullBlobByOpName(reinterpret_cast(&ofblob), op_name); + critical_section_instance->AccessBlobByOpName(reinterpret_cast(&ofblob), op_name); } } else { AutoMemcpy(ctx->stream(), ctx->BnInOp2Blob("out"), ctx->BnInOp2Blob("in")); diff --git a/oneflow/core/kernel/return_kernel.cpp b/oneflow/core/kernel/return_kernel.cpp index 2d17664535b..1cc7d30ad50 100644 --- a/oneflow/core/kernel/return_kernel.cpp +++ b/oneflow/core/kernel/return_kernel.cpp @@ -15,8 +15,8 @@ limitations under the License. */ #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/common/buffer_manager.h" +#include "oneflow/core/job/critical_section_instance.h" #include "oneflow/core/common/multi_client.h" -#include "oneflow/core/job/job_instance.h" #include "oneflow/core/job/global_for.h" namespace oneflow { @@ -37,14 +37,14 @@ void ReturnKernel::ForwardDataContent(KernelContext* ctx) const { CHECK(this->op_conf().return_conf().has_job_name()); const auto& job_name = this->op_conf().return_conf().job_name(); const auto& op_name = this->op_conf().name(); - auto* buffer_mgr = Global>>::Get(); + auto* buffer_mgr = Global>>::Get(); auto* buffer = buffer_mgr->Get(GetOutputBufferName(job_name, op_name)); - std::shared_ptr job_instance; - BufferStatus buffer_status = buffer->TryReceive(&job_instance); + std::shared_ptr critical_section_instance; + BufferStatus buffer_status = buffer->TryReceive(&critical_section_instance); CHECK_NE(buffer_status, kBufferStatusEmpty); if (buffer_status == kBufferStatusSuccess) { OfBlob ofblob(ctx->stream(), ctx->BnInOp2Blob("in")); - job_instance->PullBlobByOpName(reinterpret_cast(&ofblob), op_name); + critical_section_instance->AccessBlobByOpName(reinterpret_cast(&ofblob), op_name); } } else { AutoMemcpy(ctx->stream(), ctx->BnInOp2Blob("out"), ctx->BnInOp2Blob("in")); diff --git a/oneflow/core/kernel/user_kernel.cpp b/oneflow/core/kernel/user_kernel.cpp index aeaae600dc6..43abb1f1688 100644 --- a/oneflow/core/kernel/user_kernel.cpp +++ b/oneflow/core/kernel/user_kernel.cpp @@ -79,7 +79,9 @@ class UserKernelBaseContext { device_type_ = CHECK_JUST(DeviceType4DeviceTag(device_tag_)); parallel_ctx_ = kernel_conf.parallel_ctx(); for (const auto& pair : kernel_conf.user_conf().bn_in_op2blob_desc()) { - arg2tensor_desc_.emplace(GenUnRepeatedBn(pair.first), user_op::NaiveTensorDesc(pair.second)); + arg2bn_and_tensor_desc_.emplace( + GenUnRepeatedBn(pair.first), + std::make_pair(pair.first, user_op::NaiveTensorDesc(pair.second))); } } 
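+ // Keep the original bn_in_op next to each tensor desc so the init/cache context
+ // can look up each blob by bn and refresh dynamic shapes (see UpdateTensorWithCorrBlob below).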
~UserKernelBaseContext() = default; @@ -89,26 +91,29 @@ class UserKernelBaseContext { const ParallelContext& parallel_ctx() const { return parallel_ctx_; } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const { - auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); - if (it == arg2tensor_desc_.end()) { return nullptr; } - return &(it->second); + auto it = arg2bn_and_tensor_desc_.find(std::make_pair(arg_name, index)); + if (it == arg2bn_and_tensor_desc_.end()) { return nullptr; } + return &(it->second.second); } const ArgVec& inputs() const { return inputs_; } const ArgVec& outputs() const { return outputs_; } private: + friend class UserKernelInitAndCacheContext; + HashMap, std::pair> + arg2bn_and_tensor_desc_; ArgVec inputs_; ArgVec outputs_; DeviceType device_type_; std::string device_tag_; ParallelContext parallel_ctx_; - HashMap, user_op::NaiveTensorDesc> arg2tensor_desc_; }; -class UserKernelInitContext final : public user_op::KernelInitContext { +class UserKernelInitAndCacheContext final : public user_op::KernelInitContext, + public user_op::KernelCacheContext { public: - explicit UserKernelInitContext(ep::Stream* stream, const KernelConf& kernel_conf) + explicit UserKernelInitAndCacheContext(ep::Stream* stream, const KernelConf& kernel_conf) : user_op_conf_(kernel_conf.op_attribute().op_conf()), stream_(stream), base_ctx_(UserKernelBaseContext(kernel_conf)), @@ -117,16 +122,35 @@ class UserKernelInitContext final : public user_op::KernelInitContext { if (kernel_conf.op_attribute().has_sbp_signature()) { sbp_signature_ = cfg::SbpSignature(kernel_conf.op_attribute().sbp_signature()); } - for (const auto& pair : - kernel_conf.op_attribute().logical_blob_desc_signature().bn_in_op2blob_desc()) { - arg2logical_tensor_desc_.emplace(GenUnRepeatedBn(pair.first), - user_op::NaiveTensorDesc(pair.second)); + bool is_dynamic = false; + for (const auto& pair : kernel_conf.user_conf().bn_in_op2blob_desc()) { + if (pair.second.is_dynamic()) { + is_dynamic = true; + break; + } + } + if (!is_dynamic || parallel_ctx().parallel_num() == 1) { + for (const auto& pair : + kernel_conf.op_attribute().logical_blob_desc_signature().bn_in_op2blob_desc()) { + arg2logical_tensor_desc_.emplace(GenUnRepeatedBn(pair.first), + user_op::NaiveTensorDesc(pair.second)); + } } } - ~UserKernelInitContext() override = default; + ~UserKernelInitAndCacheContext() override = default; ep::Stream* stream() override { return stream_; } + void UpdateTensorWithCorrBlob(const std::function& BnInOp2Blob) { + for (auto& pair : base_ctx_.arg2bn_and_tensor_desc_) { + const std::string& bn = pair.second.first; + auto& tensor_desc = pair.second.second; + Blob* blob = BnInOp2Blob(bn); + CHECK(blob != nullptr) << "Blob " << bn << " is not found in cache context."; + if (blob->blob_desc().is_dynamic()) { blob->shape().ToShape(tensor_desc.mut_shape()); } + } + } + DeviceType device_type() const override { return base_ctx_.device_type(); } const ParallelContext& parallel_ctx() const override { return base_ctx_.parallel_ctx(); } const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, @@ -182,6 +206,9 @@ class UserKernelInitContext final : public user_op::KernelInitContext { cfg::NdSbpSignature nd_sbp_signature_; }; +using UserKernelInitContext = UserKernelInitAndCacheContext; +using UserKernelCacheContext = UserKernelInitAndCacheContext; + class UserKernelOpInferContext : public user_op::InferContext { public: explicit UserKernelOpInferContext(const 
KernelConf& kernel_conf) @@ -229,8 +256,7 @@ class UserKernelOpInferContext : public user_op::InferContext { user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index); } - user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, - int32_t index) override { + user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; } return it->second.get(); @@ -571,6 +597,7 @@ UserKernel::~UserKernel() = default; void UserKernel::InitUserKernel(ep::Stream* stream) { ctx_.reset(new UserKernelComputeContext(stream, kernel_conf())); infer_ctx_.reset(new UserKernelInferContext(stream, kernel_conf())); + cache_ctx_.reset(new UserKernelCacheContext(stream, kernel_conf())); infer_cache_.reset(new user_op::OpKernelInferCache(kernel_conf(), this)); { const std::string& op_type_name = @@ -596,6 +623,13 @@ void UserKernel::ForwardUserKernel(const std::functionUpdateTensorWithCorrBlob(BnInOp2Blob); + if (updated) { + cache_ctx_->UpdateTensorWithCorrBlob(BnInOp2Blob); + kernel_->InitOpKernelCache(cache_ctx_.get(), user_op::OpKernelCache::kAttrNotChanged, + &opkernel_cache_); + } else { + // do nothing + } #ifdef WITH_CUDA_GRAPHS bool current_scope_capturing = false; if (cuda_graph_exec_) { @@ -611,7 +645,7 @@ void UserKernel::ForwardUserKernel(const std::functionCompute(ctx_.get(), opkernel_state); + kernel_->Compute(ctx_.get(), opkernel_state, opkernel_cache_.get()); #ifdef WITH_CUDA_GRAPHS if (cuda_graph_exec_ && current_scope_capturing) { @@ -634,6 +668,8 @@ void UserKernel::VirtualKernelInit(KernelContext* ctx) { InitUserKernel(ctx->stream()); CHECK(opkernel_state_.get() == nullptr); opkernel_state_ = CreateOpKernelState(ctx); + kernel_->InitOpKernelCache(cache_ctx_.get(), user_op::OpKernelCache::kAllMayChanged, + &opkernel_cache_); #ifdef WITH_CUDA_GRAPHS if (ParseBooleanFromEnv("ONEFLOW_KERNEL_ENABLE_CUDA_GRAPH", false)) { UserKernelInitContext init_ctx(ctx->stream(), kernel_conf()); @@ -724,12 +760,13 @@ std::shared_ptr EagerKernel::EagerForward( std::function BnInOp2Blob) const { std::shared_ptr new_opkernel_state; CHECK_NOTNULL(device_ctx); + UserKernelInitAndCacheContext init_and_cache_ctx(device_ctx->stream(), kernel_conf()); if (old_opkernel_state) { new_opkernel_state = old_opkernel_state; } else { - UserKernelInitContext init_ctx(device_ctx->stream(), kernel_conf()); - new_opkernel_state = kernel_->CreateOpKernelState(&init_ctx); + new_opkernel_state = kernel_->CreateOpKernelState(&init_and_cache_ctx); } + kernel_->InitOpKernelCache(&init_and_cache_ctx, user_op::OpKernelCache::kAllMayChanged, &cache_); if (IsAllBlobEmpty(op_attribute().output_bns(), BnInOp2Blob) && !kernel_->AlwaysComputeWhenAllOutputsEmpty()) { @@ -739,7 +776,7 @@ std::shared_ptr EagerKernel::EagerForward( // TODO(lixinqi): refactor to a lightweight KernelComputeContext UserKernelComputeContext compute_ctx(device_ctx->stream(), kernel_conf()); compute_ctx.UpdateTensorWithCorrBlob(BnInOp2Blob); - kernel_->Compute(&compute_ctx, new_opkernel_state.get()); + kernel_->Compute(&compute_ctx, new_opkernel_state.get(), cache_.get()); return new_opkernel_state; } diff --git a/oneflow/core/kernel/user_kernel.h b/oneflow/core/kernel/user_kernel.h index 4baf2aec8f3..9347f6f8232 100644 --- a/oneflow/core/kernel/user_kernel.h +++ b/oneflow/core/kernel/user_kernel.h @@ -33,6 +33,11 @@ namespace 
oneflow { class UserKernelComputeContext; class UserKernelInferContext; +class UserKernelInitAndCacheContext; + +namespace user_op { +class OpKernelCache; +} class UserKernel final : public Kernel { public: @@ -55,9 +60,11 @@ class UserKernel final : public Kernel { bool IsStateless() const override; + mutable std::shared_ptr opkernel_cache_; std::shared_ptr opkernel_state_; std::unique_ptr kernel_; std::unique_ptr ctx_; + std::unique_ptr cache_ctx_; std::unique_ptr infer_ctx_; std::unique_ptr infer_cache_; #ifdef WITH_CUDA_GRAPHS diff --git a/oneflow/core/lazy/actor/naive_actor.cpp b/oneflow/core/lazy/actor/naive_actor.cpp index 99ceaf936bb..ac557618b74 100644 --- a/oneflow/core/lazy/actor/naive_actor.cpp +++ b/oneflow/core/lazy/actor/naive_actor.cpp @@ -35,6 +35,7 @@ REGISTER_ACTOR(TaskType::kBoxingIdentity, NaiveActor); REGISTER_ACTOR(TaskType::kCollectiveBoxingPack, NaiveActor); REGISTER_ACTOR(TaskType::kCollectiveBoxingUnpack, NaiveActor); REGISTER_ACTOR(TaskType::kDecodeH2D, NaiveActor); +REGISTER_ACTOR(TaskType::kCriticalSectionWaitTick, NaiveActor); #ifdef WITH_CUDA REGISTER_ACTOR(TaskType::kCopyHd, NaiveActor); #endif diff --git a/oneflow/core/lazy/actor/pack_actor.cpp b/oneflow/core/lazy/actor/pack_actor.cpp index f5f203bd89c..6c3de79fd79 100644 --- a/oneflow/core/lazy/actor/pack_actor.cpp +++ b/oneflow/core/lazy/actor/pack_actor.cpp @@ -15,7 +15,7 @@ limitations under the License. */ #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/kernel/user_kernel.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { diff --git a/oneflow/core/lazy/actor/unpack_actor.cpp b/oneflow/core/lazy/actor/unpack_actor.cpp index 7bb954b3201..9a18ee09f13 100644 --- a/oneflow/core/lazy/actor/unpack_actor.cpp +++ b/oneflow/core/lazy/actor/unpack_actor.cpp @@ -15,7 +15,7 @@ limitations under the License. 
*/ #include "oneflow/core/lazy/actor/actor.h" #include "oneflow/core/kernel/user_kernel.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { diff --git a/oneflow/core/ndarray/binary_func.h b/oneflow/core/ndarray/binary_func.h index c1088cf19b7..7dc5d380234 100644 --- a/oneflow/core/ndarray/binary_func.h +++ b/oneflow/core/ndarray/binary_func.h @@ -143,12 +143,13 @@ template struct BinaryFuncPow final { static OF_DEVICE_FUNC const T Invoke(const T x, const T y) { #if defined(__CUDACC__) - return pow(x, y); + return powf(x, y); #else return std::pow(x, y); #endif } }; + SPECIALIZE_CONST_TYPE_BINARY_FUNC(BinaryFuncPow); template<> @@ -167,7 +168,7 @@ struct BinaryFuncPow final { template<> struct BinaryFuncPow final { static __device__ __forceinline__ float Invoke(const float x, const float y) { - return __powf(x, y); + return powf(x, y); } }; @@ -175,7 +176,7 @@ template<> struct BinaryFuncPow final { static __device__ __forceinline__ half Invoke(const half x, const half y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 - return __float2half(__powf(__half2float(x), __half2float(y))); + return __float2half(powf(__half2float(x), __half2float(y))); #else NO_HALF_UTIL_FOUND; #endif diff --git a/oneflow/core/ndarray/ndarray_assign_core.cu b/oneflow/core/ndarray/ndarray_assign_core.cu index 491c6923a8b..73fcfd61bc5 100644 --- a/oneflow/core/ndarray/ndarray_assign_core.cu +++ b/oneflow/core/ndarray/ndarray_assign_core.cu @@ -22,7 +22,8 @@ namespace oneflow { namespace { template -__global__ void NdarrayAssignGpu(XpuVarNdarray y, const XpuReducedNdarray reduced) { +__global__ void NdarrayAssignReducedGpu(XpuVarNdarray y, + const XpuReducedNdarray reduced) { NdarrayAssignCore::Assign(y, reduced); } @@ -39,7 +40,7 @@ struct NdarrayAssignCoreWrapper final { const XpuReducedNdarray& reduced) { size_t n = y.host_shape().HostElemNum(); if (n == 0) { return; } - RUN_CUDA_KERNEL((NdarrayAssignGpu), ctx, n, y, reduced); + RUN_CUDA_KERNEL((NdarrayAssignReducedGpu), ctx, n, y, reduced); } static void Assign(ep::Stream* ctx, const XpuVarNdarray& y, const XpuVarNdarray& x) { size_t n = y.host_shape().HostElemNum(); diff --git a/oneflow/core/ndarray/xpu_ndarray_assign.cu b/oneflow/core/ndarray/xpu_ndarray_assign.cu index dab8760dd9f..809fdc3c091 100644 --- a/oneflow/core/ndarray/xpu_ndarray_assign.cu +++ b/oneflow/core/ndarray/xpu_ndarray_assign.cu @@ -22,7 +22,8 @@ namespace oneflow { namespace { template -__global__ void NdarrayAssignGpu(XpuVarNdarray y, const XpuReducedNdarray reduced) { +__global__ void NdarrayAssignReducedGpu(XpuVarNdarray y, + const XpuReducedNdarray reduced) { NdarrayAssignCore::Assign(y, reduced); } @@ -38,7 +39,7 @@ struct NdarrayAssignCoreWrapper final { static void Assign(ep::Stream* stream, XpuVarNdarray* y, const XpuReducedNdarray& reduced) { size_t n = y->host_shape().HostElemNum(); - RUN_CUDA_KERNEL((NdarrayAssignGpu), stream, n, *y, reduced); + RUN_CUDA_KERNEL((NdarrayAssignReducedGpu), stream, n, *y, reduced); } static void Assign(ep::Stream* ctx, const XpuVarNdarray& y, const XpuVarNdarray& x) { size_t n = y.host_shape().HostElemNum(); diff --git a/oneflow/core/operator/critical_section_callback_tick_op.cpp b/oneflow/core/operator/critical_section_callback_tick_op.cpp new file mode 100644 index 00000000000..fc8acc44574 --- /dev/null +++ b/oneflow/core/operator/critical_section_callback_tick_op.cpp @@ -0,0 +1,81 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/operator/operator.h" +#include "oneflow/core/job/sbp_signature_builder.h" + +namespace oneflow { + +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + BlobDesc* blob_desc = BlobDesc4BnInOp("out"); + blob_desc->mut_shape() = Shape({1}); + blob_desc->set_data_type(DataType::kInt8); + return Maybe::Ok(); +} + +} // namespace + +class CriticalSectionCallbackTickOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(CriticalSectionCallbackTickOp); + CriticalSectionCallbackTickOp() = default; + ~CriticalSectionCallbackTickOp() = default; + + Maybe InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; + Maybe InferOutBlobDescs( + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + Maybe GetSbpSignatures( + const std::function(const std::string&)>& LogicalBlobDesc4Ibn, + cfg::SbpSignatureList* sbp_sig_list) const override; +}; + +Maybe CriticalSectionCallbackTickOp::InitFromOpConf() { + CHECK(op_conf().has_critical_section_callback_tick_conf()); + EnrollRepeatedInputBn("tick", false); + EnrollOutputBn("out", false); + return Maybe::Ok(); +} + +Maybe CriticalSectionCallbackTickOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(BlobDesc4BnInOp); +} + +Maybe CriticalSectionCallbackTickOp::InferOutBlobDescs( + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + return InferBlobDescs(GetBlobDesc4BnInOp); +} + +Maybe CriticalSectionCallbackTickOp::GetSbpSignatures( + const std::function(const std::string&)>& LogicalBlobDesc4Ibn, + cfg::SbpSignatureList* sbp_sig_list) const { + return Maybe::Ok(); +} + +REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kCriticalSectionCallbackTickConf, 128); +REGISTER_OP(OperatorConf::kCriticalSectionCallbackTickConf, CriticalSectionCallbackTickOp); +REGISTER_TICK_TOCK_OP(OperatorConf::kCriticalSectionCallbackTickConf); + +} // namespace oneflow diff --git a/oneflow/core/operator/critical_section_wait_tick_op.cpp b/oneflow/core/operator/critical_section_wait_tick_op.cpp new file mode 100644 index 00000000000..0b7b4ee6dd8 --- /dev/null +++ b/oneflow/core/operator/critical_section_wait_tick_op.cpp @@ -0,0 +1,81 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/operator/operator.h" +#include "oneflow/core/job/sbp_signature_builder.h" + +namespace oneflow { + +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + BlobDesc* blob_desc = BlobDesc4BnInOp("out"); + blob_desc->mut_shape() = Shape({1}); + blob_desc->set_data_type(DataType::kInt8); + return Maybe::Ok(); +} + +} // namespace + +class CriticalSectionWaitTickOp final : public Operator { + public: + OF_DISALLOW_COPY_AND_MOVE(CriticalSectionWaitTickOp); + CriticalSectionWaitTickOp() = default; + ~CriticalSectionWaitTickOp() = default; + + Maybe InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; + Maybe InferOutBlobDescs( + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; + + private: + Maybe GetSbpSignatures( + const std::function(const std::string&)>& LogicalBlobDesc4Ibn, + cfg::SbpSignatureList* sbp_sig_list) const override; +}; + +Maybe CriticalSectionWaitTickOp::InitFromOpConf() { + CHECK_OR_RETURN(op_conf().has_critical_section_wait_tick_conf()); + EnrollRepeatedInputBn("tick", false); + EnrollOutputBn("out", false); + return Maybe::Ok(); +} + +Maybe CriticalSectionWaitTickOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(BlobDesc4BnInOp); +} + +Maybe CriticalSectionWaitTickOp::InferOutBlobDescs( + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + return InferBlobDescs(GetBlobDesc4BnInOp); +} + +Maybe CriticalSectionWaitTickOp::GetSbpSignatures( + const std::function(const std::string&)>& LogicalBlobDesc4Ibn, + cfg::SbpSignatureList* sbp_sig_list) const { + return Maybe::Ok(); +} + +REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kCriticalSectionWaitTickConf, 2); +REGISTER_OP(OperatorConf::kCriticalSectionWaitTickConf, CriticalSectionWaitTickOp); +REGISTER_TICK_TOCK_OP(OperatorConf::kCriticalSectionWaitTickConf); + +} // namespace oneflow diff --git a/oneflow/core/operator/interface_op_util.cpp b/oneflow/core/operator/interface_op_util.cpp index 77b15626122..aec94ee596e 100644 --- a/oneflow/core/operator/interface_op_util.cpp +++ b/oneflow/core/operator/interface_op_util.cpp @@ -21,7 +21,7 @@ namespace oneflow { namespace { void CheckShape(const Shape& shape) { - FOR_RANGE(int, i, 1, shape.NumAxes()) { CHECK_GT(shape.At(i), 0); } + FOR_RANGE(int, i, 1, shape.NumAxes()) { CHECK_GE(shape.At(i), 0); } } Maybe GetSbpSignature(const InterfaceBlobConf& blob_conf, const PbRpf& input_bns, diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto index 31c5880bb8f..919eaa90d02 100644 --- a/oneflow/core/operator/op_conf.proto +++ b/oneflow/core/operator/op_conf.proto @@ -187,6 +187,18 @@ message TickOpConf { required string out = 2; } +message CriticalSectionWaitTickOpConf { + repeated string tick = 1; + required string out = 2; + required string buffer_name = 3; +} + +message CriticalSectionCallbackTickOpConf { + repeated string tick = 1; + required string out = 2; + required string buffer_name = 3; +} + message DeviceTickOpConf { repeated string tick = 1; required string out = 2; @@ -480,7 +492,9 @@ message OperatorConf { BoxingOpConf boxing_conf = 108; VariableOpConf variable_conf = 122; TickOpConf tick_conf = 124; - TotalLossInstanceNumOpConf total_loss_instance_num_conf = 126; + CriticalSectionWaitTickOpConf 
critical_section_wait_tick_conf = 125; + CriticalSectionCallbackTickOpConf critical_section_callback_tick_conf = 126; + TotalLossInstanceNumOpConf total_loss_instance_num_conf = 131; ShapeElemCntOpConf shape_elem_cnt_conf = 132; SrcSubsetTickOpConf src_subset_tick_conf = 133; DstSubsetTickOpConf dst_subset_tick_conf = 134; diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 7b947a28549..3a3d1fdf2a9 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -26,6 +26,7 @@ limitations under the License. #include "oneflow/core/operator/op_node_signature.pb.h" #include "oneflow/core/job/nd_sbp_infer_hint.h" #include "oneflow/core/job/foreign_callback.h" +#include "oneflow/core/framework/nd_sbp.h" namespace oneflow { @@ -738,7 +739,30 @@ Maybe Operator::InferNdSbpSignature( break; } } - CHECK_OR_RETURN(matched_sbp_signature != nullptr) << " op_name " << op_name(); + if (!matched_sbp_signature) { + std::ostringstream err; + err << "op: " << op_name() + << " can't find available sbp signature.\nSupported SBP signatures are: "; + err << *JUST(SbpSignatureListAsString(list, input_bns(), output_bns())); + + std::ostringstream got_input_sbp_ss; + std::ostringstream all_input_sbp_ss; + for (size_t j = 0; j < input_bns().size(); ++j) { + // NOTE: i is hierarchy dim and j is input_order + const auto& ibn = input_bns()[j]; + const cfg::NdSbp& nd_sbp = ibn2nd_sbp.at(ibn); + if (j > 0) { + got_input_sbp_ss << ", "; + all_input_sbp_ss << ", "; + } + got_input_sbp_ss << SbpToString(nd_sbp.sbp_parallel(i)); + all_input_sbp_ss << ibn << ": " << NdSbpToString(nd_sbp); + } + err << ", but got (" << got_input_sbp_ss.str(); + err << ") -> ? at hierarchy dim " << i + << ", since the SBP of inputs are: " << all_input_sbp_ss.str(); + return Error::RuntimeError() << err.str(); + } for (const auto& bn : input_bns()) { *((*nd_sbp_signature->mutable_bn_in_op2nd_sbp())[bn].add_sbp_parallel()) = matched_sbp_signature->bn_in_op2sbp_parallel().at(bn); diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index 471334c7ff0..8a85910ae2b 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -135,8 +135,9 @@ class UserOpInferContext final : public user_op::InferContext { auto InitTensorDesc = [&](const ArgVec& arg_vec, const PbRpf& bns) { CHECK_EQ(arg_vec.size(), bns.size()); for (int32_t i = 0; i < arg_vec.size(); ++i) { + const auto& bn_i = bns.Get(i); BlobDesc* blob = GetBlobDesc4BnInOp(bns.Get(i)); - CHECK_NOTNULL(blob); + CHECK(blob != nullptr) << bn_i; arg2tensor_desc_.emplace(arg_vec.at(i), GenTensorDescFromBlobDesc(blob)); } }; @@ -152,8 +153,7 @@ class UserOpInferContext final : public user_op::InferContext { user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index); } - user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, - int32_t index) override { + user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) { auto it = arg2tensor_desc_.find(std::make_pair(arg_name, index)); if (it == arg2tensor_desc_.end()) { return nullptr; }; return &(it->second); diff --git a/oneflow/core/vm/async_cuda_stream_type.cpp b/oneflow/core/vm/async_cuda_stream_type.cpp index 0361e603ddb..c0d8519f882 100644 --- a/oneflow/core/vm/async_cuda_stream_type.cpp +++ b/oneflow/core/vm/async_cuda_stream_type.cpp @@ -74,9 +74,8 @@ intrusive::shared_ptr 
AsyncCudaStreamType::MakeStreamDesc( std::size_t device_num = resource.gpu_device_num(); auto ret = intrusive::make_shared(); ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex()); - ret->set_num_machines(1); ret->set_num_streams_per_machine(device_num); - ret->set_num_streams_per_thread(1); + ret->set_num_streams_per_thread(device_num); return ret; } diff --git a/oneflow/core/vm/control_stream_type.cpp b/oneflow/core/vm/control_stream_type.cpp index d89971dcd15..d040f1fda23 100644 --- a/oneflow/core/vm/control_stream_type.cpp +++ b/oneflow/core/vm/control_stream_type.cpp @@ -120,7 +120,6 @@ intrusive::shared_ptr ControlStreamType::MakeStreamDesc(const Resour int64_t this_machine_id) const { auto ret = intrusive::make_shared(); ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex()); - ret->set_num_machines(1); ret->set_num_streams_per_machine(1); ret->set_num_streams_per_thread(1); return ret; diff --git a/oneflow/core/vm/cpu_stream_type.cpp b/oneflow/core/vm/cpu_stream_type.cpp index 3ad184088ea..7629d678947 100644 --- a/oneflow/core/vm/cpu_stream_type.cpp +++ b/oneflow/core/vm/cpu_stream_type.cpp @@ -67,9 +67,8 @@ intrusive::shared_ptr CpuStreamType::MakeStreamDesc(const Resource& std::size_t device_num = resource.cpu_device_num(); auto ret = intrusive::make_shared(); ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex()); - ret->set_num_machines(1); ret->set_num_streams_per_machine(device_num); - ret->set_num_streams_per_thread(1); + ret->set_num_streams_per_thread(device_num); return ret; } diff --git a/oneflow/core/vm/cuda_allocator.cpp b/oneflow/core/vm/cuda_allocator.cpp index 74bd5ae4ae6..ec994a42d9d 100644 --- a/oneflow/core/vm/cuda_allocator.cpp +++ b/oneflow/core/vm/cuda_allocator.cpp @@ -276,7 +276,16 @@ void CudaAllocator::Allocate(char** mem_ptr, std::size_t size) { } } - CHECK(piece != nullptr) << "Error! : Out of memory when allocate size : " << size; + if (piece == nullptr) { + // NOTE(chengcheng): In some corner cases on Ubuntu, CUDA memory is not released even on OOM, + // so we need to release all CUDA memory allocated by this process before the core dump. + LOG(INFO) << " OOM error is detected, the process will exit. It will first reset the CUDA " + << "device to release device memory."; + OF_CUDA_CHECK(cudaDeviceReset()); + LOG(FATAL) << "Error! 
: Out of memory when allocate size : " << size + << ".\n The total_memory_bytes allocated by this CudaAllocator is : " + << total_memory_bytes_; + } CHECK_NOTNULL(piece->ptr); CHECK(ptr2piece_.find(piece->ptr) != ptr2piece_.end()); *mem_ptr = piece->ptr; diff --git a/oneflow/core/vm/cuda_copy_d2h_stream_type.cpp b/oneflow/core/vm/cuda_copy_d2h_stream_type.cpp index c7a14b9df6c..472f2a74995 100644 --- a/oneflow/core/vm/cuda_copy_d2h_stream_type.cpp +++ b/oneflow/core/vm/cuda_copy_d2h_stream_type.cpp @@ -72,9 +72,8 @@ intrusive::shared_ptr CudaCopyD2HStreamType::MakeStreamDesc( std::size_t device_num = resource.gpu_device_num(); auto ret = intrusive::make_shared(); ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex()); - ret->set_num_machines(1); ret->set_num_streams_per_machine(device_num); - ret->set_num_streams_per_thread(1); + ret->set_num_streams_per_thread(device_num); return ret; } diff --git a/oneflow/core/vm/cuda_copy_h2d_stream_type.cpp b/oneflow/core/vm/cuda_copy_h2d_stream_type.cpp index 8cfd355a4b7..7332eed2345 100644 --- a/oneflow/core/vm/cuda_copy_h2d_stream_type.cpp +++ b/oneflow/core/vm/cuda_copy_h2d_stream_type.cpp @@ -65,9 +65,8 @@ intrusive::shared_ptr CudaCopyH2DStreamType::MakeStreamDesc( std::size_t device_num = resource.gpu_device_num(); auto ret = intrusive::make_shared(); ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex()); - ret->set_num_machines(1); ret->set_num_streams_per_machine(device_num); - ret->set_num_streams_per_thread(1); + ret->set_num_streams_per_thread(device_num); return ret; } diff --git a/oneflow/core/vm/cuda_stream_type.cpp b/oneflow/core/vm/cuda_stream_type.cpp index ad53e87f943..cef73d7f36f 100644 --- a/oneflow/core/vm/cuda_stream_type.cpp +++ b/oneflow/core/vm/cuda_stream_type.cpp @@ -74,9 +74,8 @@ intrusive::shared_ptr CudaStreamType::MakeStreamDesc(const Resource& std::size_t device_num = resource.gpu_device_num(); auto ret = intrusive::make_shared(); ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex()); - ret->set_num_machines(1); ret->set_num_streams_per_machine(device_num); - ret->set_num_streams_per_thread(1); + ret->set_num_streams_per_thread(device_num); return ret; } diff --git a/oneflow/core/vm/device_helper_stream_type.cpp b/oneflow/core/vm/device_helper_stream_type.cpp index 0c603a0a226..3a30518a4fc 100644 --- a/oneflow/core/vm/device_helper_stream_type.cpp +++ b/oneflow/core/vm/device_helper_stream_type.cpp @@ -65,9 +65,8 @@ intrusive::shared_ptr DeviceHelperStreamType::MakeStreamDesc( CHECK_GT(device_num, 0); auto ret = intrusive::make_shared(); ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex()); - ret->set_num_machines(1); ret->set_num_streams_per_machine(device_num); - ret->set_num_streams_per_thread(1); + ret->set_num_streams_per_thread(device_num); return ret; } diff --git a/oneflow/core/vm/host_stream_type.cpp b/oneflow/core/vm/host_stream_type.cpp index b5c024f843b..596d6e734ff 100644 --- a/oneflow/core/vm/host_stream_type.cpp +++ b/oneflow/core/vm/host_stream_type.cpp @@ -57,7 +57,6 @@ intrusive::shared_ptr HostStreamType::MakeStreamDesc(const Resource& int64_t this_machine_id) const { auto ret = intrusive::make_shared(); ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex()); - ret->set_num_machines(1); ret->set_num_streams_per_machine(1); ret->set_num_streams_per_thread(1); return ret; diff --git a/oneflow/core/vm/stream_desc.cpp b/oneflow/core/vm/stream_desc.cpp index fc2d7e0960d..6a39ba9b42a 100644 --- a/oneflow/core/vm/stream_desc.cpp +++ 
b/oneflow/core/vm/stream_desc.cpp @@ -18,16 +18,16 @@ limitations under the License. namespace oneflow { namespace vm { -void StreamDesc::__Init__(const StreamTypeId& stream_type_id, int32_t num_machines, - int32_t num_streams_per_machine, int32_t num_streams_per_thread) { +void StreamDesc::__Init__(const StreamTypeId& stream_type_id, int32_t num_streams_per_machine, + int32_t num_streams_per_thread) { mut_stream_type_id()->CopyFrom(stream_type_id); - set_num_machines(num_machines); set_num_streams_per_machine(num_streams_per_machine); set_num_streams_per_thread(num_streams_per_thread); } int32_t StreamDesc::num_threads() const { - int32_t num_devices = num_machines() * num_streams_per_machine(); + int32_t num_devices = num_streams_per_machine(); + if (num_devices == 0) { return 0; } CHECK_EQ(num_devices % num_streams_per_thread(), 0); return num_devices / num_streams_per_thread(); } diff --git a/oneflow/core/vm/stream_desc.h b/oneflow/core/vm/stream_desc.h index 948e2df27ff..e254bba576a 100644 --- a/oneflow/core/vm/stream_desc.h +++ b/oneflow/core/vm/stream_desc.h @@ -59,22 +59,20 @@ class StreamId final { class StreamDesc final : public intrusive::Base { public: // Getters - int32_t num_machines() const { return num_machines_; } int32_t num_streams_per_machine() const { return num_streams_per_machine_; } int32_t num_streams_per_thread() const { return num_streams_per_thread_; } const StreamTypeId& stream_type_id() const { return stream_type_id_.key().Get(); } // Setters - void set_num_machines(int32_t val) { num_machines_ = val; } void set_num_streams_per_machine(int32_t val) { num_streams_per_machine_ = val; } void set_num_streams_per_thread(int32_t val) { num_streams_per_thread_ = val; } StreamTypeId* mut_stream_type_id() { return stream_type_id_.mut_key()->Mutable(); } // methods void __Init__() {} - void __Init__(const StreamTypeId& stream_type_id, int32_t num_machines, - int32_t num_streams_per_machine, int32_t num_streams_per_thread); + void __Init__(const StreamTypeId& stream_type_id, int32_t num_streams_per_machine, + int32_t num_streams_per_thread); int32_t num_threads() const; - int32_t parallel_num() const { return num_machines() * num_streams_per_machine(); } + int32_t parallel_num() const { return num_streams_per_machine(); } private: friend class intrusive::Ref; @@ -82,13 +80,11 @@ class StreamDesc final : public intrusive::Base { StreamDesc() : intrusive_ref_(), - num_machines_(), num_streams_per_machine_(), num_streams_per_thread_(), stream_type_id_() {} intrusive::Ref intrusive_ref_; // fields - int32_t num_machines_; int32_t num_streams_per_machine_; int32_t num_streams_per_thread_; diff --git a/oneflow/core/vm/test_util.cpp b/oneflow/core/vm/test_util.cpp index eb115990bcc..1e399f39d1c 100644 --- a/oneflow/core/vm/test_util.cpp +++ b/oneflow/core/vm/test_util.cpp @@ -114,7 +114,7 @@ void TestUtil::AddStreamDescByInstrNames(VmDesc* vm_desc, int64_t parallel_num, const std::vector& instr_names) { auto Insert = [&](const std::string& instr_name) { const auto& stream_type_id = LookupInstrTypeId(instr_name).stream_type_id(); - auto stream_desc = intrusive::make_shared(stream_type_id, 1, parallel_num, 1); + auto stream_desc = intrusive::make_shared(stream_type_id, parallel_num, 1); vm_desc->mut_stream_type_id2desc()->Insert(stream_desc.Mutable()); }; for (const auto& instr_name : instr_names) { diff --git a/oneflow/core/vm/transport_stream_type.cpp b/oneflow/core/vm/transport_stream_type.cpp index cd64e9832f7..33d806eab8f 100644 --- 
a/oneflow/core/vm/transport_stream_type.cpp +++ b/oneflow/core/vm/transport_stream_type.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/vm/transport_stream_type.h" +#include "oneflow/core/common/multi_client.h" namespace oneflow { namespace vm { @@ -51,16 +52,19 @@ template intrusive::shared_ptr TransportStreamType::MakeTransportStreamDesc( const Resource& resource, int64_t this_machine_id) const { std::size_t device_num = 0; - if (resource.has_cpu_device_num()) { - device_num = std::max(device_num, resource.cpu_device_num()); - } - if (resource.has_gpu_device_num()) { - device_num = std::max(device_num, resource.gpu_device_num()); + if (!CHECK_JUST(IsMultiClient())) { + if (resource.has_cpu_device_num()) { + device_num = std::max(device_num, resource.cpu_device_num()); + } + if (resource.has_gpu_device_num()) { + device_num = std::max(device_num, resource.gpu_device_num()); + } + } else { + // Keep device_num = 0. TransportStreamType is not used in multi-client mode. } auto ret = intrusive::make_shared(); ret->mut_stream_type_id()->__Init__(LookupStreamType4TypeIndex()); // TODO(lixinqi): remove this ugly field - ret->set_num_machines(1); ret->set_num_streams_per_machine(device_num); // TODO(lixinqi): refactor to a num_threads_per_machine field ret->set_num_streams_per_thread(1); diff --git a/oneflow/ir/CMakeLists.txt b/oneflow/ir/CMakeLists.txt index 22bab214674..b0a17da4797 100644 --- a/oneflow/ir/CMakeLists.txt +++ b/oneflow/ir/CMakeLists.txt @@ -22,89 +22,16 @@ endif() project(oneflow-dialect LANGUAGES CXX C) -set(CMAKE_CXX_STANDARD 14 CACHE STRING "C++ standard to conform to") - -message("-- LLVM_MONO_REPO_URL: " ${LLVM_MONO_REPO_URL}) -message("-- LLVM_MONO_REPO_MD5: " ${LLVM_MONO_REPO_MD5}) -FetchContent_Declare( - llvm_monorepo -) -FetchContent_GetProperties(llvm_monorepo) - -if(NOT llvm_monorepo_POPULATED) - FetchContent_Populate(llvm_monorepo - URL ${LLVM_MONO_REPO_URL} - URL_HASH MD5=${LLVM_MONO_REPO_MD5} - ) - set(LLVM_INSTALL_DIR ${THIRD_PARTY_DIR}/llvm) - - execute_process(COMMAND "${CMAKE_COMMAND}" ${llvm_monorepo_SOURCE_DIR}/llvm - -DCMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER} - -DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER} - -DCMAKE_CUDA_COMPILER_LAUNCHER=${CMAKE_CUDA_COMPILER_LAUNCHER} - -DCMAKE_EXE_LINKER_FLAGS_INIT=${CMAKE_EXE_LINKER_FLAGS_INIT} - -DCMAKE_MODULE_LINKER_FLAGS_INIT=${CMAKE_MODULE_LINKER_FLAGS_INIT} - -DCMAKE_SHARED_LINKER_FLAGS_INIT=${CMAKE_SHARED_LINKER_FLAGS_INIT} - -DCMAKE_INSTALL_PREFIX=${LLVM_INSTALL_DIR} - -DLLVM_ENABLE_RTTI=ON # turn this on to make it compatible with protobuf - -DLLVM_ENABLE_EH=ON # turn this on to make it compatible with half (the library) - -DLLVM_BUILD_EXAMPLES=OFF - -DLLVM_BUILD_TOOLS=OFF - -DLLVM_INCLUDE_EXAMPLES=OFF - -DLLVM_INCLUDE_TESTS=OFF - -DLLVM_INCLUDE_BENCHMARKS=OFF - -DLLVM_TARGETS_TO_BUILD=host\;NVPTX - -DLLVM_ENABLE_ASSERTIONS=ON - -DLLVM_ENABLE_PROJECTS=mlir - -DLLVM_APPEND_VC_REV=OFF - -DLLVM_ENABLE_ZLIB=OFF - -DLLVM_INSTALL_UTILS=ON - -DBUILD_SHARED_LIBS=${BUILD_SHARED_LIBS} - -DLLVM_ENABLE_OCAMLDOC=OFF - -DLLVM_ENABLE_BINDINGS=OFF - -DMLIR_ENABLE_CUDA_RUNNER=${WITH_MLIR_CUDA_CODEGEN} - -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER} - -G ${CMAKE_GENERATOR} - WORKING_DIRECTORY ${llvm_monorepo_BINARY_DIR} - RESULT_VARIABLE ret) - if(ret EQUAL "1") - message( FATAL_ERROR "Bad exit status") - endif() - include(ProcessorCount) - ProcessorCount(PROC_NUM) - execute_process(COMMAND 
"${CMAKE_COMMAND}" --build . -j${PROC_NUM} - WORKING_DIRECTORY ${llvm_monorepo_BINARY_DIR} - RESULT_VARIABLE ret - ) - if(ret EQUAL "1") - message( FATAL_ERROR "Bad exit status") - endif() - execute_process(COMMAND "${CMAKE_COMMAND}" --build . -j${PROC_NUM} --target install - WORKING_DIRECTORY ${llvm_monorepo_BINARY_DIR} - RESULT_VARIABLE ret - ) - if(ret EQUAL "1") - message( FATAL_ERROR "Bad exit status") - endif() - set(LLVM_DIR ${LLVM_INSTALL_DIR}/lib/cmake/llvm) - set(MLIR_DIR ${LLVM_INSTALL_DIR}/lib/cmake/mlir) +if(LLVM_PROVIDER STREQUAL "in-tree") + include(llvm-in-tree.cmake) +elseif(LLVM_PROVIDER STREQUAL "install") + include(install-llvm.cmake) +else() + message(FATAL_ERROR "LLVM_PROVIDER should be in-tree or install, but got: ${LLVM_PROVIDER}") endif() -find_package(MLIR REQUIRED CONFIG) - -message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") -message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") - -set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) -set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) -set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) - -list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") -list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") -include(TableGen) -include(AddLLVM) -include(AddMLIR) -include(HandleLLVMOptions) +set_property(GLOBAL PROPERTY LLVM_INSTALL_DIR ${LLVM_INSTALL_DIR}) +set(MLIR_TABLEGEN_EXE mlir-tblgen) include_directories(${LLVM_INCLUDE_DIRS}) include_directories(${MLIR_INCLUDE_DIRS}) @@ -117,25 +44,47 @@ include_directories(${PROJECT_BINARY_DIR}/include) link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) +if(LLVM_PROVIDER STREQUAL "in-tree") + add_subdirectory(${CMAKE_SOURCE_DIR}/tools/oneflow-tblgen ${PROJECT_BINARY_DIR}/oneflow-tblgen) +endif() + set_property(GLOBAL PROPERTY ALL_ONEFLOW_LIBS -Wl,--no-as-needed oneflow -Wl,--as-needed -Wl,--no-as-needed ${oneflow_exe_third_party_libs} -Wl,--as-needed -Wl,--no-as-needed ${oneflow_third_party_libs} -Wl,--as-needed ) -function(oneflow_add_llvm_library) - add_llvm_library(${ARGV}) +function(oneflow_add_mlir_library) + add_mlir_library(${ARGV}) set_compile_options_to_oneflow_target(${ARGV0}) + set_property(TARGET ${ARGV0} APPEND PROPERTY + BUILD_RPATH "${LLVM_LIBRARY_DIR}") + set_property(TARGET ${ARGV0} APPEND PROPERTY + INSTALL_RPATH "${LLVM_LIBRARY_DIR}") + set_property(TARGET ${ARGV0} APPEND PROPERTY + INSTALL_RPATH "${ONEFLOW_BUILD_ROOT_DIR}") endfunction() function(oneflow_add_mlir_dialect_library) add_mlir_dialect_library(${ARGV}) set_compile_options_to_oneflow_target(${ARGV0}) + set_property(TARGET ${ARGV0} APPEND PROPERTY + INSTALL_RPATH "${LLVM_LIBRARY_DIR}") + set_property(TARGET ${ARGV0} APPEND PROPERTY + INSTALL_RPATH "${ONEFLOW_BUILD_ROOT_DIR}") endfunction() +find_package(Threads REQUIRED) +set(LLVM_PTHREAD_LIB ${CMAKE_THREAD_LIBS_INIT}) + +set(LLVM_RUNTIME_OUTPUT_INTDIR ${PROJECT_BINARY_DIR}/bin) +set(LLVM_LIBRARY_OUTPUT_INTDIR ${PROJECT_BINARY_DIR}/lib) +if(WITH_MLIR) add_subdirectory(include) add_subdirectory(lib) add_subdirectory(test) add_subdirectory(oneflow-opt) add_subdirectory(oneflow-translate) +add_subdirectory(oneflow-runtime) add_subdirectory(oneflow-extension) +endif(WITH_MLIR) diff --git a/oneflow/ir/include/OneFlow/CMakeLists.txt b/oneflow/ir/include/OneFlow/CMakeLists.txt index df95367ceb4..8dd61fe2eb1 100644 --- a/oneflow/ir/include/OneFlow/CMakeLists.txt +++ b/oneflow/ir/include/OneFlow/CMakeLists.txt @@ -1,11 +1,5 @@ -set(ONEFLOW_USER_OP_GEN_TD_PATH "${PROJECT_BINARY_DIR}/include/OneFlow") -message(STATUS "Generating user op 
ODS ${ONEFLOW_USER_OP_GEN_TD_PATH}/OneFlowUserOpGen.td") -add_custom_target(GenUserOpODS - DEPENDS oneflow-gen-ods - COMMAND "$" - BYPRODUCTS OneFlowUserOpGen.td - WORKING_DIRECTORY "${ONEFLOW_USER_OP_GEN_TD_PATH}" -) +# set(ONEFLOW_USER_OP_GEN_TD_PATH "${PROJECT_BINARY_DIR}/include/OneFlow") +set(ONEFLOW_USER_OP_GEN_TD_PATH "${PROJECT_SOURCE_DIR}/include/OneFlow") set(LLVM_TARGET_DEFINITIONS OneFlowEnums.td) mlir_tablegen(OneFlowEnums.h.inc -gen-enum-decls) @@ -19,7 +13,6 @@ foreach (OP_GROUP_NAME IN LISTS ONEFLOW_OP_GROUPS_USED_IN_PATTERNS) endforeach() mlir_tablegen(OneFlowPatterns.cpp.inc -gen-rewriters) add_public_tablegen_target(MLIROneFlowPatternsIncGen) -add_dependencies(MLIROneFlowPatternsIncGen GenUserOpODS) # NOTE: seperate conversion and opt with --name set(LLVM_TARGET_DEFINITIONS OneFlowOps.td) @@ -39,15 +32,15 @@ foreach (OP_GROUP_NAME IN LISTS ONEFLOW_OP_GROUPS) set(LLVM_TABLEGEN_FLAGS "${ONE_LLVM_TABLEGEN_FLAGS}") string(TOLOWER "${OP_GROUP_NAME}" OP_GROUP_NAME_LOWER) set(CPP_INC_FILE "OneFlow.${OP_GROUP_NAME_LOWER}_ops.cpp.inc") + set(HEADER_INC_FILE "OneFlow.${OP_GROUP_NAME_LOWER}_ops.h.inc") mlir_tablegen(${CPP_INC_FILE} -gen-op-defs) + mlir_tablegen(${HEADER_INC_FILE} -gen-op-decls) endforeach() add_public_tablegen_target(MLIROneFlowOpGroupDefsIncGen) -add_dependencies(MLIROneFlowOpGroupDefsIncGen GenUserOpODS) set(LLVM_TABLEGEN_FLAGS "${FULL_LLVM_TABLEGEN_FLAGS}") -mlir_tablegen(OneFlow.Ops.h.inc -gen-op-decls) +mlir_tablegen(OneFlow.gen_ops.h.inc -gen-op-decls) add_public_tablegen_target(MLIROneFlowOpGroupDeclsIncGen) -add_dependencies(MLIROneFlowOpGroupDeclsIncGen GenUserOpODS) set(LLVM_TABLEGEN_FLAGS "") add_mlir_dialect( diff --git a/oneflow/ir/include/OneFlow/OneFlowBase.td b/oneflow/ir/include/OneFlow/OneFlowBase.td index 029ed170986..33553cc13d5 100644 --- a/oneflow/ir/include/OneFlow/OneFlowBase.td +++ b/oneflow/ir/include/OneFlow/OneFlowBase.td @@ -3,30 +3,24 @@ include "OneFlow/OneFlowDialect.td" include "OneFlow/OneFlowInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" -def SI32ArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getArrayAttr(llvm::to_vector<8>(llvm::map_range(values, [this](int32_t v) -> Attribute { return builder_.getSI32IntegerAttr($0); })))"; -} +def OneFlow_Tensor : TensorOf<[AnyType]>; +def SI32ArrayAttr : TypedArrayAttrBase {} -def SI64ArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getArrayAttr(llvm::to_vector<8>(llvm::map_range(values, [this](int64_t v) -> Attribute { return builder_.getSI64IntegerAttr($0); })))"; -} +def SI64ArrayAttr : TypedArrayAttrBase {} -def DTArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getArrayAttr(llvm::to_vector<8>(llvm::map_range(values, [this](auto v) -> Attribute { return DataTypeAttr::get($0); })))"; -} +def ShapeAttr : TypedArrayAttrBase {} -def ShapeArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getArrayAttr(llvm::to_vector<8>(llvm::map_range(values, [this](auto v) -> Attribute { return DenseIntElementsAttr::get($0); })))"; -} +def DTArrayAttr : TypedArrayAttrBase {} + +def ShapeArrayAttr : TypedArrayAttrBase {} def OneFlow_IsOpConfCompatible : NativeOpTrait<"IsOpConfCompatible">; def OneFlow_IsImportCompatible : NativeOpTrait<"IsImportCompatible">; def OneFlow_AlternativeOp : NativeOpTrait<"IsAlternative">; +def OneFlow_TensorSource : NativeOpTrait<"TensorSource">; class OneFlow_BaseOp traits = []> : Op { @@ -40,8 +34,8 @@ class OneFlow_BaseOp traits = []> : dag 
attrs = (ins); dag trait_attrs = (ins); dag user_op_attrs = (ins); - dag input = (ins Variadic:$data_input); - dag output = (outs Variadic:$data_output); + dag input = (ins); + dag output = (outs); dag ctrl_input = (ins); dag ctrl_output = (outs); let arguments = !con( @@ -56,6 +50,19 @@ class OneFlow_BaseOp traits = []> : output, ctrl_output ); + int same_output_regst_num = -1; + + bit has_check_fn = 0; + bit has_logical_tensor_desc_infer_fn = 0; + bit has_physical_tensor_desc_infer_fn = 0; + bit has_get_sbp_fn = 0; + bit has_sbp_signature_infer_fn = 0; + bit has_data_type_infer_fn = 0; + bit has_device_infer_fn = 0; + bit has_input_arg_modify_fn = 0; + bit has_output_arg_modify_fn = 0; + bit has_output_blob_time_shape_infer_fn = 0; + bit has_nd_sbp_infer_fn = 0; } class OneFlow_Op traits = []> : @@ -109,17 +116,22 @@ class OneFlow_ConvolutionBaseOp traits = []> : ); let output = (outs AnyType:$out); let attrs = (ins - SI32Attr:$filters, + DefaultValuedAttr:$filters, SI32ArrayAttr:$padding_before, StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$strides, SI32ArrayAttr:$dilation_rate, - DefaultValuedAttr:$group + DefaultValuedAttr:$groups ); let trait_attrs = (ins I32ElementsAttr:$operand_segment_sizes ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; } class OneFlow_TFPoolBaseOp traits = []> : @@ -136,6 +148,10 @@ class OneFlow_TFPoolBaseOp traits = []> : SI32ArrayAttr:$strides, BoolAttr:$ceil_mode ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; } class OneFlow_TFPoolGradBaseOp traits = []> : @@ -156,43 +172,113 @@ class OneFlow_TFPoolGradBaseOp traits = []> : SI32ArrayAttr:$strides, BoolAttr:$ceil_mode ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; } -class OneFlow_PoolBaseOp traits = []> : +class OneFlow_MaxPoolBaseOp traits = []> : OneFlow_BaseOp])> { - let summary = "OneFlow pooling operation"; - let input = (ins AnyType:$x); - let output = (outs AnyType:$y); + let summary = "OneFlow Max Pooling operation"; + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y, + AnyType:$indice + ); let attrs = (ins SI32ArrayAttr:$padding, StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$stride, - BoolAttr:$ceil_mode, - BoolAttr:$count_include_pad, - SI64Attr:$divisor_override + SI32ArrayAttr:$dilation, + DefaultValuedAttr:$return_indices, + DefaultValuedAttr:$ceil_mode ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; } -class OneFlow_PoolGradBaseOp traits = []> : +class OneFlow_AvgPoolBaseOp traits = []> : OneFlow_BaseOp])> { - let summary = "OneFlow pooling grad operation"; + let summary = "OneFlow Average Pooling operation"; + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + SI32ArrayAttr:$padding, + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$stride, + DefaultValuedAttr:$ceil_mode, + DefaultValuedAttr:$count_include_pad, + DefaultValuedAttr:$divisor_override + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +class OneFlow_MaxPoolGradBaseOp 
traits = []> : + OneFlow_BaseOp])> { + let summary = "OneFlow Max Pooling Grad operation"; let input = (ins AnyType:$x, AnyType:$y, + AnyType:$indice, AnyType:$dy ); - let output = (outs AnyType:$dx); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + SI32ArrayAttr:$padding, + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$stride, + SI32ArrayAttr:$dilation, + DefaultValuedAttr:$return_indices, + DefaultValuedAttr:$ceil_mode + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +class OneFlow_AvgPoolGradBaseOp traits = []> : + OneFlow_BaseOp])> { + let summary = "OneFlow Average Pooling Grad operation"; + let input = (ins + AnyType:$x, + AnyType:$y, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); let attrs = (ins SI32ArrayAttr:$padding, StrAttr:$data_format, SI32ArrayAttr:$kernel_size, SI32ArrayAttr:$stride, - BoolAttr:$ceil_mode, - BoolAttr:$count_include_pad, - SI64Attr:$divisor_override + DefaultValuedAttr:$ceil_mode, + DefaultValuedAttr:$count_include_pad, + DefaultValuedAttr:$divisor_override ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; } class OneFlow_AdaptivePoolBaseOp traits = []> : @@ -205,6 +291,10 @@ class OneFlow_AdaptivePoolBaseOp traits = []> : let attrs = (ins SI64ArrayAttr:$output_size ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; } class OneFlow_AdaptivePoolGradBaseOp traits = []> : @@ -218,6 +308,10 @@ class OneFlow_AdaptivePoolGradBaseOp traits = []> let attrs = (ins SI64ArrayAttr:$output_size ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; } class OneFlow_UnaryBaseOp traits = []> : @@ -225,21 +319,23 @@ class OneFlow_UnaryBaseOp traits = []> : let summary = ""; let input = (ins AnyType:$x); let output = (outs AnyType:$y); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; } def OneFlow_Idempotent : NativeOpTrait<"IsIdempotentOfIdenticalPlacement">; class OneFlow_IdempotentBaseOp traits = []> : - OneFlow_UnaryBaseOp { -} + OneFlow_UnaryBaseOp {} def OneFlow_Involution : NativeOpTrait<"IsInvolutionOfIdenticalPlacement">; class OneFlow_InvolutionBaseOp traits = []> : - OneFlow_UnaryBaseOp { -} + OneFlow_UnaryBaseOp {} #define GET_ONEFLOW_BASE_OP_DEFINITIONS -include "OneFlow/OneFlowUserOpGen.td" +include "OneFlow/OneFlowUserOps.td" #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWBASE_H_ diff --git a/oneflow/ir/include/OneFlow/OneFlowDialect.h b/oneflow/ir/include/OneFlow/OneFlowDialect.h index 0c2c7e35bf8..34f05494168 100644 --- a/oneflow/ir/include/OneFlow/OneFlowDialect.h +++ b/oneflow/ir/include/OneFlow/OneFlowDialect.h @@ -17,6 +17,7 @@ limitations under the License. 
#define ONEFLOW_ONEFLOWDIALECT_H #include "mlir/IR/Dialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "OneFlow/OneFlowOpsDialect.h.inc" diff --git a/oneflow/ir/include/OneFlow/OneFlowDialect.td b/oneflow/ir/include/OneFlow/OneFlowDialect.td index 8e621ef2aaf..86833eaf1d3 100644 --- a/oneflow/ir/include/OneFlow/OneFlowDialect.td +++ b/oneflow/ir/include/OneFlow/OneFlowDialect.td @@ -10,6 +10,9 @@ def OneFlow_Dialect : Dialect { This dialect is the IR of OneFlow. }]; let cppNamespace = "::mlir::oneflow"; + let dependentDialects = [ + "StandardOpsDialect" + ]; } #endif // ONEFLOW_DIALECT diff --git a/oneflow/ir/include/OneFlow/OneFlowInterfaces.td b/oneflow/ir/include/OneFlow/OneFlowInterfaces.td index 85e8a6f6e6e..1aa344884bd 100644 --- a/oneflow/ir/include/OneFlow/OneFlowInterfaces.td +++ b/oneflow/ir/include/OneFlow/OneFlowInterfaces.td @@ -73,4 +73,15 @@ def ControlEdgeCompatibleInterface : OpInterface<"ControlEdgeCompatible"> { ]; } +def NoGrad : OpInterface<"NoGrad"> { + let description = [{ + }]; +} + +def CpuOnly : OpInterface<"CpuOnly"> { + let description = [{ + }]; +} + + #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWINTERFACES_H_ diff --git a/oneflow/ir/include/OneFlow/OneFlowOpTraits.h b/oneflow/ir/include/OneFlow/OneFlowOpTraits.h new file mode 100644 index 00000000000..715bc785cc5 --- /dev/null +++ b/oneflow/ir/include/OneFlow/OneFlowOpTraits.h @@ -0,0 +1,131 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#ifndef ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPTRAITS_H_ +#define ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPTRAITS_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" + +namespace mlir { + +namespace OpTrait { + +namespace impl { + +OpFoldResult foldIdempotentOfIdenticalPlacement(Operation* op); +OpFoldResult foldInvolutionOfIdenticalPlacement(Operation* op); +LogicalResult VerifyIsOpConfCompatible(Operation* op); +LogicalResult VerifyIsImportCompatible(Operation* op); + +} // namespace impl + +template +class IsOpConfCompatible : public TraitBase { + public: + static StringRef getOpNameAttr() { return "op_name"; } + static StringRef getDeviceTagAttr() { return "device_tag"; } + static StringRef getDeviceNameAttr() { return "device_name"; } + static StringRef getScopeSymbolIDAttr() { return "scope_symbol_id"; } + static StringRef getHierarchyAttr() { return "hierarchy"; } + static LogicalResult verifyTrait(Operation* op) { return impl::VerifyIsOpConfCompatible(op); } +}; + +template +class IsImportCompatible : public TraitBase { + public: + static StringRef getOutputLBNsAttr() { return "output_lbns"; } + static LogicalResult verifyTrait(Operation* op) { return impl::VerifyIsImportCompatible(op); } +}; + +template +class IsIdempotentOfIdenticalPlacement + : public TraitBase { + public: + static LogicalResult verifyTrait(Operation* op) { + static_assert(ConcreteType::template hasTrait(), + "expected operation to produce one result"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to take one operand"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to preserve type"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to be op conf compatible"); + return impl::verifyIsIdempotent(op); + } + + static OpFoldResult foldTrait(Operation* op, ArrayRef operands) { + return impl::foldIdempotentOfIdenticalPlacement(op); + } +}; + +template +class IsInvolutionOfIdenticalPlacement + : public TraitBase { + public: + static LogicalResult verifyTrait(Operation* op) { + static_assert(ConcreteType::template hasTrait(), + "expected operation to produce one result"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to take one operand"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to preserve type"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to be op conf compatible"); + return impl::verifyIsInvolution(op); + } + + static OpFoldResult foldTrait(Operation* op, ArrayRef operands) { + return impl::foldInvolutionOfIdenticalPlacement(op); + } +}; + +template +class IsAlternative : public TraitBase { + public: + static StringRef getOpTypeNameAttr() { return "op_type_name"; } + static LogicalResult verifyTrait(Operation* op) { + if (op->hasAttrOfType(getOpTypeNameAttr())) { + return success(); + } else { + return op->emitError("expected operation to have attribute: " + getOpTypeNameAttr()); + } + } +}; + +template +class TensorSource : public TraitBase { + public: + static StringRef getShapeAttrName() { return "shape"; } + static StringRef getDataTypeAttrName() { return "data_type"; } + static StringRef getIsDynamicAttrName() { return "is_dynamic"; } + static StringRef getNdSbpAttrName() { return "nd_sbp"; } + + static LogicalResult verifyTrait(Operation* op) { + if (!op->hasAttrOfType(getShapeAttrName())) { + return op->emitError("expected operation to have attribute: " + getShapeAttrName()); + } + if (!op->hasAttrOfType(getDataTypeAttrName())) 
{ + return op->emitError("expected operation to have attribute: " + getDataTypeAttrName()); + } + return success(); + } +}; + +} // namespace OpTrait + +} // namespace mlir + +#endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPTRAITS_H_ diff --git a/oneflow/ir/include/OneFlow/OneFlowOps.h b/oneflow/ir/include/OneFlow/OneFlowOps.h index 958a10519bb..292efa18ad3 100644 --- a/oneflow/ir/include/OneFlow/OneFlowOps.h +++ b/oneflow/ir/include/OneFlow/OneFlowOps.h @@ -18,106 +18,23 @@ limitations under the License. #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/FunctionSupport.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "OneFlow/OneFlowSupport.h" #include "OneFlow/OneFlowInterfaces.h.inc" +#include "OneFlow/OneFlowOpTraits.h" #include "OneFlow/OneFlowEnums.h.inc" +#include "OneFlow/OneFlowOpTraits.h" namespace mlir { class FuncOp; -namespace OpTrait { - -namespace impl { - -OpFoldResult foldIdempotentOfIdenticalPlacement(Operation* op); -OpFoldResult foldInvolutionOfIdenticalPlacement(Operation* op); -LogicalResult VerifyIsOpConfCompatible(Operation* op); -LogicalResult VerifyIsImportCompatible(Operation* op); - -} // namespace impl - -template -class IsOpConfCompatible : public TraitBase { - public: - static StringRef getOpNameAttr() { return "op_name"; } - static StringRef getDeviceTagAttr() { return "device_tag"; } - static StringRef getDeviceNameAttr() { return "device_name"; } - static StringRef getScopeSymbolIDAttr() { return "scope_symbol_id"; } - static StringRef getHierarchyAttr() { return "hierarchy"; } - static LogicalResult verifyTrait(Operation* op) { return impl::VerifyIsOpConfCompatible(op); } -}; - -template -class IsImportCompatible : public TraitBase { - public: - static StringRef getOutputLBNsAttr() { return "output_lbns"; } - static LogicalResult verifyTrait(Operation* op) { return impl::VerifyIsImportCompatible(op); } -}; - -template -class IsIdempotentOfIdenticalPlacement - : public TraitBase { - public: - static LogicalResult verifyTrait(Operation* op) { - static_assert(ConcreteType::template hasTrait(), - "expected operation to produce one result"); - static_assert(ConcreteType::template hasTrait(), - "expected operation to take one operand"); - static_assert(ConcreteType::template hasTrait(), - "expected operation to preserve type"); - static_assert(ConcreteType::template hasTrait(), - "expected operation to be op conf compatible"); - return impl::verifyIsIdempotent(op); - } - - static OpFoldResult foldTrait(Operation* op, ArrayRef operands) { - return impl::foldIdempotentOfIdenticalPlacement(op); - } -}; - -template -class IsInvolutionOfIdenticalPlacement - : public TraitBase { - public: - static LogicalResult verifyTrait(Operation* op) { - static_assert(ConcreteType::template hasTrait(), - "expected operation to produce one result"); - static_assert(ConcreteType::template hasTrait(), - "expected operation to take one operand"); - static_assert(ConcreteType::template hasTrait(), - "expected operation to preserve type"); - static_assert(ConcreteType::template hasTrait(), - "expected operation to be op conf compatible"); - return impl::verifyIsInvolution(op); - } - - static OpFoldResult foldTrait(Operation* op, ArrayRef operands) { - return impl::foldInvolutionOfIdenticalPlacement(op); - } 
-}; - -template -class IsAlternative : public TraitBase { - public: - static StringRef getOpTypeNameAttr() { return "op_type_name"; } - static LogicalResult verifyTrait(Operation* op) { - if (op->hasAttrOfType(getOpTypeNameAttr())) { - return success(); - } else { - return op->emitError("expected operation to have attribute: " + getOpTypeNameAttr()); - } - } -}; - -} // namespace OpTrait - template inline std::string GetOpTypeName(T op) { std::string op_type_name = op->getName().stripDialect().str(); @@ -137,6 +54,6 @@ inline std::string GetOpTypeName(T op) { #define GET_OP_CLASSES #include "OneFlow/OneFlowOps.h.inc" #define GET_OP_CLASSES -#include "OneFlow/OneFlow.Ops.h.inc" +#include "OneFlow/OneFlow.gen_ops.h.inc" #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWOPS_H_ diff --git a/oneflow/ir/include/OneFlow/OneFlowOps.td b/oneflow/ir/include/OneFlow/OneFlowOps.td index ddf87bb0b19..be1433b77db 100644 --- a/oneflow/ir/include/OneFlow/OneFlowOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowOps.td @@ -4,27 +4,12 @@ include "OneFlow/OneFlowDialect.td" include "OneFlow/OneFlowEnums.td" include "OneFlow/OneFlowInterfaces.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Pass/PassBase.td" -include "mlir/Interfaces/CallInterfaces.td" include "OneFlow/OneFlowBase.td" -def OneFlow_UserOp : OneFlow_UserBaseWithCtrlOp<"user", [OneFlow_IsImportCompatible]> { - let summary = ""; - let attrs = (ins - StrArrayAttr:$output_lbns - ); - let hasCanonicalizer = 1; -} - -def OneFlow_SystemOp : OneFlow_Op<"system", [OneFlow_IsImportCompatible]> { - let summary = ""; - let attrs = (ins - StrArrayAttr:$input_bns, - StrArrayAttr:$output_lbns - ); - let hasCanonicalizer = 1; -} +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/CallInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Pass/PassBase.td" def OneFlow_NormalizationAddReluOp : OneFlow_NormalizationAddReluBaseOp { let builders = [ @@ -50,6 +35,29 @@ def OneFlow_NormalizationAddReluOp : OneFlow_NormalizationAddReluBaseOp { ]; } +#ifndef REMOVE_ONEFLOW_MLIR_ONLY_OP_DEFINITIONS + +def OneFlow_UserOp : OneFlow_UserBaseWithCtrlOp<"user", [OneFlow_IsImportCompatible]> { + let summary = ""; + let input = (ins Variadic:$data_input); + let output = (outs Variadic:$data_output); + let attrs = (ins + StrArrayAttr:$output_lbns + ); + let hasCanonicalizer = 1; +} + +def OneFlow_SystemOp : OneFlow_Op<"system", [OneFlow_IsImportCompatible]> { + let summary = ""; + let input = (ins Variadic:$data_input); + let output = (outs Variadic:$data_output); + let attrs = (ins + StrArrayAttr:$input_bns, + StrArrayAttr:$output_lbns + ); + let hasCanonicalizer = 1; +} + def OneFlow_Add2Op : OneFlow_BaseOp<"add_n2", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = ""; let input = (ins @@ -62,6 +70,8 @@ def OneFlow_Add2Op : OneFlow_BaseOp<"add_n2", [NoSideEffect, DeclareOpInterfaceM // JIT ops def OneFlow_MlirJitOp : OneFlow_BaseOp<"mlir_jit", [ CallOpInterface, DeclareOpInterfaceMethods ] > { + let input = (ins Variadic:$data_input); + let output = (outs Variadic:$data_output); let attrs = (ins FlatSymbolRefAttr:$callee, StrAttr:$mlir_assembly @@ -92,6 +102,138 @@ def OneFlow_MlirJitOp : OneFlow_BaseOp<"mlir_jit", [ CallOpInterface, DeclareOpI }]; } +class OneFlow_ConcreteSystemOp traits = []> : + OneFlow_BaseOp])> { + let input = (ins); + let output = (ins); + let ctrl_input = (ins Variadic:$ctrl_inputs); + let ctrl_output = (outs Optional:$ctrl_output); + dag required_attrs 
= (ins StrArrayAttr:$output_lbns); + dag custom_attrs = (ins); + let attrs = !con( + required_attrs, + custom_attrs + ); + let hasCanonicalizer = 1; +} + +def OneFlow_VariableOp : OneFlow_ConcreteSystemOp<"variable", [OneFlow_TensorSource]> { + let summary = ""; + let input = (ins); + let output = (outs AnyType:$output); + let custom_attrs = (ins + ShapeAttr:$shape, + OptionalAttr:$data_type, + DefaultValuedAttr:$model_name, + DefaultValuedAttr:$l1_regularization, + DefaultValuedAttr:$l2_regularization, + DefaultValuedAttr:$trainable, + StrArrayAttr:$nd_sbp + ); +} + +def OneFlow_InputOp : OneFlow_ConcreteSystemOp<"input", [OneFlow_TensorSource]> { + let summary = ""; + let input = (ins AnyType:$input); + let output = (outs AnyType:$output); + let custom_attrs = (ins + OptionalAttr:$shape, + OptionalAttr:$data_type, + OptionalAttr:$is_dynamic, + OptionalAttr:$nd_sbp, + OptionalAttr:$job_name + ); + let builders = [ + OpBuilder<(ins + "::oneflow::OperatorConf":$op_conf + )> + ]; +} + +def OneFlow_OutputOp : OneFlow_ConcreteSystemOp<"output", [OneFlow_TensorSource]> { + let summary = ""; + let input = (ins AnyType:$input); + let output = (outs AnyType:$output); + let custom_attrs = (ins + OptionalAttr:$shape, + OptionalAttr:$data_type, + OptionalAttr:$is_dynamic, + OptionalAttr:$nd_sbp, + OptionalAttr:$job_name + ); +} + +def OneFlow_Job : Op { + let regions = (region AnyRegion:$body); + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttr:$type, + OptionalAttr:$sym_visibility + ); + + let builders = [OpBuilder<(ins + "StringRef":$name, "FunctionType":$type) + >]; + + let extraClassDeclaration = [{ + bool isDeclaration() { return isExternal(); } + + private: + friend class OpTrait::FunctionLike; + + unsigned getNumFuncArguments() { return getType().getInputs().size(); } + + unsigned getNumFuncResults() { return getType().getResults().size(); } + + LogicalResult verifyType() { + auto type = getTypeAttr().getValue(); + if (!type.isa()) + return emitOpError("requires '" + getTypeAttrName() + + "' attribute of function type"); + return success(); + } + }]; + + let parser = [{ return ::mlir::oneflow::parseJob(parser, result); }]; + let printer = [{ return ::mlir::oneflow::print(*this, p); }]; + let verifier = [{ return ::mlir::oneflow::verify(*this); }]; +} + +def OneFlow_ReturnOp : Op, + MemRefsNormalizable, ReturnLike, Terminator]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a Job. + The operation takes an optional tensor operand and produces no results. + The operand type must match the signature of the job function that contains + the operation. For example: + + ```mlir + job @foo() -> tensor<2xf64> { + ... 
+ oneflow.return %0 : tensor<2xf64> + } + ``` + }]; + + let arguments = (ins Variadic:$operands); + + let builders = [ + OpBuilder<(ins), + [{ build($_builder, $_state, llvm::None); }]>]; + + let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; + + let printer = [{ return ::mlir::oneflow::print(p, *this); }]; + let verifier = [{ return ::mlir::oneflow::verify(*this); }]; + let parser = [{ return ::mlir::oneflow::parse$cppClass(parser, result); }]; +} + +#endif // REMOVE_ONEFLOW_MLIR_ONLY_OP_DEFINITIONS + #endif // ONEFLOW_OPS #ifndef ONEFLOW_PASSES @@ -103,7 +245,7 @@ def LowerOneFlowToTosaPass : Pass<"lower-oneflow-to-tosa", "ModuleOp"> { let dependentDialects = ["tosa::TosaDialect", "memref::MemRefDialect", "StandardOpsDialect"]; } -def MapSCFToGPUPass : Pass<"gpu-greedy-parallel-loop-mapping", "FuncOp"> { +def MapSCFToGPUPass : Pass<"gpu-greedy-parallel-loop-mapping", "ModuleOp"> { let summary = "Greedily maps all parallel loops to gpu hardware ids"; let constructor = "mlir::oneflow::createMapSCFToGPUPass()"; let dependentDialects = ["scf::SCFDialect"]; @@ -120,7 +262,7 @@ def OutlineJitFunctionPass : Pass<"outline-jit-function", "ModuleOp"> { let constructor = "mlir::oneflow::createOutlineJitFunctionPass()"; } -def FuseIntoExistingOpPass : Pass<"fuse-into-existing-op", "FuncOp"> { +def FuseIntoExistingOpPass : Pass<"fuse-into-existing-op", "ModuleOp"> { let summary = ""; let constructor = "mlir::oneflow::createFuseIntoExistingOpPass()"; } diff --git a/oneflow/ir/include/OneFlow/OneFlowPatterns.td b/oneflow/ir/include/OneFlow/OneFlowPatterns.td index c007684c1a8..a27f1a9f07e 100644 --- a/oneflow/ir/include/OneFlow/OneFlowPatterns.td +++ b/oneflow/ir/include/OneFlow/OneFlowPatterns.td @@ -4,7 +4,7 @@ include "OneFlow/OneFlowOps.td" -def IsNotNestedInJit: ConstraintgetParentOfType<::mlir::FuncOp>()->hasAttr(\"llvm.emit_c_interface\"))">, "">; +def IsNotNestedInJit: ConstraintgetParentOfType<::mlir::oneflow::Job>())">, "">; def OutlineMulCast : NativeCodeCall<"::mlir::oneflow::OutlineMulCast($_builder, $0, $1)">; // TODO: remove attr binding if possible def MulCastPattern : Pat< diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td new file mode 100644 index 00000000000..59ef7348519 --- /dev/null +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -0,0 +1,8713 @@ +// ASSIGN;BASE;BINARY;BROADCAST;CONV;CROSS_ENTROPY;CUDA;DATASET;DETECTION;EAGER;FUSED;IDEMPOTENT;IDENTITY;IMAGE;INDICES;INVOLUTION;LOSS;MATH;MATMUL;MISC;NCCL;NORMALIZATION;OPTIMIZER;PADDING;PARALLEL_CAST;POOL;QUANTIZATION;REDUCE;RESHAPE;SCALAR;SOFTMAX;SUMMARY;TENSOR_BUFFER;TEST;TRIGONOMETRIC;UNARY;UPSAMPLE + +/* +#define GET_OP_LIST +#include "OneFlow/OneFlow.assign_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.binary_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.broadcast_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.conv_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.cross_entropy_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.cuda_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.dataset_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.detection_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.eager_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.fused_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.idempotent_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.identity_ops.cpp.inc" +, 
+#define GET_OP_LIST +#include "OneFlow/OneFlow.image_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.indices_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.involution_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.loss_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.math_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.matmul_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.misc_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.nccl_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.normalization_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.optimizer_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.padding_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.parallel_cast_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.pool_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.quantization_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.reduce_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.reshape_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.scalar_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.softmax_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.summary_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.tensor_buffer_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.trigonometric_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.unary_ops.cpp.inc" +, +#define GET_OP_LIST +#include "OneFlow/OneFlow.upsample_ops.cpp.inc" +*/ + +// Group: ASSIGN +// assign, assign_if, assign_if_not, logical_slice_assign +// Total: 4 + +#ifdef GET_ONEFLOW_ASSIGN_OP_DEFINITIONS + +def OneFlow_AssignUserOp : OneFlow_BaseOp<"assign", [NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$ref, + OneFlow_Tensor:$value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_AssignIfOp : OneFlow_BaseOp<"assign_if", [NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$ref, + OneFlow_Tensor:$value, + OneFlow_Tensor:$condition + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_AssignIfNotOp : OneFlow_BaseOp<"assign_if_not", [NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$ref, + OneFlow_Tensor:$value, + OneFlow_Tensor:$condition + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_LogicalSliceAssignOp : OneFlow_BaseOp<"logical_slice_assign", [DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$ref, + OneFlow_Tensor:$value + ); + let attrs = (ins + SI64ArrayAttr:$start, + SI64ArrayAttr:$stop, + SI64ArrayAttr:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +#endif // GET_ONEFLOW_ASSIGN_OP_DEFINITIONS + +// Group: BASE +// normalization_add_relu +// Total: 1 + +#ifdef GET_ONEFLOW_BASE_OP_DEFINITIONS + +class 
OneFlow_NormalizationAddReluBaseOp : OneFlow_BaseOp<"normalization_add_relu", [NoSideEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + Optional:$addend, + Optional:$moving_mean, + Optional:$moving_variance, + OneFlow_Tensor:$gamma, + OneFlow_Tensor:$beta + ); + let output = (outs + OneFlow_Tensor:$y, + OneFlow_Tensor:$reserve_space, + Optional:$mean, + Optional:$inv_variance + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$training, + DefaultValuedAttr:$momentum + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes, + I32ElementsAttr:$result_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +#endif // GET_ONEFLOW_BASE_OP_DEFINITIONS + +// Group: BINARY +// bias_add, cast_like, celu_grad, diag_grad, diagonal_grad, dot, dropout_grad, elementwise_maximum, elementwise_minimum, elu_grad, floordiv, gelu_grad, grid_sample, hardsigmoid_grad, hardswish_grad, l1_l2_regularize_gradient, leaky_relu_grad, masked_fill, mish_grad, multiply, narrow_grad, pow, prelu, relu_grad, selu_grad, sigmoid_grad, silu_grad, tf_prelu, unfold_tensor_grad, xdivy, xlogy +// Total: 31 + +#ifdef GET_ONEFLOW_BINARY_OP_DEFINITIONS + +def OneFlow_BiasAddOp : OneFlow_BaseOp<"bias_add", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$a, + OneFlow_Tensor:$b + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CastLikeOp : OneFlow_BaseOp<"cast_like", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + OneFlow_Tensor:$dtype_like + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_CeluGradOp : OneFlow_BaseOp<"celu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DiagGradOp : OneFlow_BaseOp<"diag_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$diagonal + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DiagonalGradOp : OneFlow_BaseOp<"diagonal_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$offset + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DotOp : 
OneFlow_BaseOp<"dot", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DropoutGradOp : OneFlow_BaseOp<"dropout_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$mask + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$scale + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ElementwiseMaximumOp : OneFlow_BaseOp<"elementwise_maximum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ElementwiseMinimumOp : OneFlow_BaseOp<"elementwise_minimum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_EluGradOp : OneFlow_BaseOp<"elu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FloordivOp : OneFlow_BaseOp<"floordiv", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_GeluGradOp : OneFlow_BaseOp<"gelu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_GridSampleOp : OneFlow_BaseOp<"grid_sample", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$grid + ); + let output = (outs + OneFlow_Tensor:$output + ); + let attrs = (ins + StrAttr:$interpolation_mode, + StrAttr:$padding_mode, + DefaultValuedAttr:$align_corners + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_HardsigmoidGradOp : OneFlow_BaseOp<"hardsigmoid_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 
1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_HardswishGradOp : OneFlow_BaseOp<"hardswish_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_L1L2RegularizeGradientOp : OneFlow_BaseOp<"l1_l2_regularize_gradient", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$model, + OneFlow_Tensor:$model_diff + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$l1, + DefaultValuedAttr:$l2 + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LeakyReluGradOp : OneFlow_BaseOp<"leaky_relu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_MaskedFillOp : OneFlow_BaseOp<"masked_fill", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$mask + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_MishGradOp : OneFlow_BaseOp<"mish_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_MultiplyOp : OneFlow_BaseOp<"multiply", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_NarrowGradOp : OneFlow_BaseOp<"narrow_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$like + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$dim, + DefaultValuedAttr:$start, + DefaultValuedAttr:$length + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_PowOp : OneFlow_BaseOp<"pow", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_PreluOp : 
OneFlow_BaseOp<"prelu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$alpha + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReluGradOp : OneFlow_BaseOp<"relu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$y, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SeluGradOp : OneFlow_BaseOp<"selu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SigmoidGradOp : OneFlow_BaseOp<"sigmoid_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$y, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SiluGradOp : OneFlow_BaseOp<"silu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TfPreluOp : OneFlow_BaseOp<"tf_prelu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$alpha + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UnfoldTensorGradOp : OneFlow_BaseOp<"unfold_tensor_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$dimension, + DefaultValuedAttr:$size, + DefaultValuedAttr:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_XdivyOp : OneFlow_BaseOp<"xdivy", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_XlogyOp : OneFlow_BaseOp<"xlogy", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_BINARY_OP_DEFINITIONS + +// Group: BROADCAST +// broadcast_add, broadcast_div, broadcast_div_grad, broadcast_equal, broadcast_floor_mod, 
broadcast_fmod, broadcast_greater, broadcast_greater_equal, broadcast_less, broadcast_less_equal, broadcast_like, broadcast_logical_and, broadcast_logical_or, broadcast_logical_xor, broadcast_maximum, broadcast_minimum, broadcast_mul, broadcast_not_equal, broadcast_pow, broadcast_pow_x_grad, broadcast_pow_y_grad, broadcast_sub +// Total: 22 + +#ifdef GET_ONEFLOW_BROADCAST_OP_DEFINITIONS + +def OneFlow_BroadcastAddOp : OneFlow_BaseOp<"broadcast_add", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastDivOp : OneFlow_BaseOp<"broadcast_div", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastDivGradOp : OneFlow_BaseOp<"broadcast_div_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$y, + OneFlow_Tensor:$z, + OneFlow_Tensor:$dz + ); + let output = (outs + OneFlow_Tensor:$dy + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastEqualOp : OneFlow_BaseOp<"broadcast_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastFloorModOp : OneFlow_BaseOp<"broadcast_floor_mod", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastFmodOp : OneFlow_BaseOp<"broadcast_fmod", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastGreaterOp : OneFlow_BaseOp<"broadcast_greater", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastGreaterEqualOp : OneFlow_BaseOp<"broadcast_greater_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastLessOp : OneFlow_BaseOp<"broadcast_less", [NoSideEffect, NoGrad, 
DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastLessEqualOp : OneFlow_BaseOp<"broadcast_less_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastLikeOp : OneFlow_BaseOp<"broadcast_like", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$like + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + SI32ArrayAttr:$broadcast_axes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_BroadcastLogicalAndOp : OneFlow_BaseOp<"broadcast_logical_and", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastLogicalOrOp : OneFlow_BaseOp<"broadcast_logical_or", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastLogicalXorOp : OneFlow_BaseOp<"broadcast_logical_xor", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastMaximumOp : OneFlow_BaseOp<"broadcast_maximum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastMinimumOp : OneFlow_BaseOp<"broadcast_minimum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastMulOp : OneFlow_BaseOp<"broadcast_mul", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastNotEqualOp : OneFlow_BaseOp<"broadcast_not_equal", 
[NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastPowOp : OneFlow_BaseOp<"broadcast_pow", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastPowXGradOp : OneFlow_BaseOp<"broadcast_pow_x_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y, + OneFlow_Tensor:$z, + OneFlow_Tensor:$dz + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastPowYGradOp : OneFlow_BaseOp<"broadcast_pow_y_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y, + OneFlow_Tensor:$z, + OneFlow_Tensor:$dz + ); + let output = (outs + OneFlow_Tensor:$dy + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastSubOp : OneFlow_BaseOp<"broadcast_sub", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_BROADCAST_OP_DEFINITIONS + +// Group: CONV +// conv1d, conv2d, conv3d, conv_bias_grad, conv_data_grad, conv_filter_grad, deconv1d, deconv2d, deconv3d +// Total: 9 + +#ifdef GET_ONEFLOW_CONV_OP_DEFINITIONS + +def OneFlow_Conv1DOp : OneFlow_ConvolutionBaseOp<"conv1d", [NoSideEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> {} + +def OneFlow_Conv2DOp : OneFlow_ConvolutionBaseOp<"conv2d", [NoSideEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> {} + +def OneFlow_Conv3DOp : OneFlow_ConvolutionBaseOp<"conv3d", [NoSideEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> {} + +def OneFlow_ConvBiasGradOp : OneFlow_BaseOp<"conv_bias_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$bias_diff + ); + let attrs = (ins + StrAttr:$data_format, + DefaultValuedAttr:$num_spatial_dims + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ConvDataGradOp : OneFlow_BaseOp<"conv_data_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$filter, + OneFlow_Tensor:$x_like, + Optional:$_add_to_output + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$num_spatial_dims, + SI32ArrayAttr:$padding_before, + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$dilation_rate, + DefaultValuedAttr:$groups + ); + let 
has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ConvFilterGradOp : OneFlow_BaseOp<"conv_filter_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$filter_diff + ); + let attrs = (ins + DefaultValuedAttr:$num_spatial_dims, + SI32ArrayAttr:$padding_before, + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$dilation_rate, + DefaultValuedAttr:$groups + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Deconv1DOp : OneFlow_BaseOp<"deconv1d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + OneFlow_Tensor:$weight + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$filters, + SI32ArrayAttr:$padding_before, + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$output_padding, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$dilation_rate, + DefaultValuedAttr:$groups + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Deconv2DOp : OneFlow_BaseOp<"deconv2d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + OneFlow_Tensor:$weight + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$filters, + SI32ArrayAttr:$padding_before, + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$output_padding, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$dilation_rate, + DefaultValuedAttr:$groups + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Deconv3DOp : OneFlow_BaseOp<"deconv3d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + OneFlow_Tensor:$weight + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$filters, + SI32ArrayAttr:$padding_before, + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$output_padding, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$dilation_rate, + DefaultValuedAttr:$groups + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_CONV_OP_DEFINITIONS + +// Group: CROSS_ENTROPY +// binary_cross_entropy, binary_cross_entropy_grad, binary_cross_entropy_with_logits, binary_cross_entropy_with_logits_grad, sigmoid_cross_entropy, sigmoid_cross_entropy_grad, sparse_cross_entropy, sparse_cross_entropy_grad, sparse_cross_entropy_ms, sparse_cross_entropy_ms_grad +// Total: 10 + +#ifdef GET_ONEFLOW_CROSS_ENTROPY_OP_DEFINITIONS + +def OneFlow_BinaryCrossEntropyOp : OneFlow_BaseOp<"binary_cross_entropy", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$target, + Optional:$weight + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + 
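The CONV group earlier in this hunk covers the forward convolution/deconvolution user ops plus their gradient ops (the *_grad definitions are wired up by autograd rather than called directly), with a groups attribute for grouped convolution. A small sketch of the module-level API these lower to, assuming oneflow.nn mirrors the familiar torch-style names (nn.Conv2d and nn.ConvTranspose2d are assumptions, not part of the patch):

import oneflow as flow
import oneflow.nn as nn

x = flow.randn(1, 8, 16, 16)                                            # NCHW input

conv = nn.Conv2d(8, 16, kernel_size=3, padding=1, groups=2)             # conv2d op
deconv = nn.ConvTranspose2d(8, 16, kernel_size=4, stride=2, padding=1)  # deconv2d op

y = conv(x)    # (1, 16, 16, 16)
z = deconv(x)  # (1, 16, 32, 32): (16 - 1) * 2 - 2 * 1 + 4 = 32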
let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_BinaryCrossEntropyGradOp : OneFlow_BaseOp<"binary_cross_entropy_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$target, + Optional:$weight, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BinaryCrossEntropyWithLogitsOp : OneFlow_BaseOp<"binary_cross_entropy_with_logits", [NoSideEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$target, + Optional:$weight, + Optional:$pos_weight + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_pos_weight + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_BinaryCrossEntropyWithLogitsGradOp : OneFlow_BaseOp<"binary_cross_entropy_with_logits_grad", [NoSideEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$target, + Optional:$weight, + Optional:$pos_weight, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$has_pos_weight + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SigmoidCrossEntropyOp : OneFlow_BaseOp<"sigmoid_cross_entropy", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$prediction, + OneFlow_Tensor:$label + ); + let output = (outs + OneFlow_Tensor:$loss + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SigmoidCrossEntropyGradOp : OneFlow_BaseOp<"sigmoid_cross_entropy_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$prediction, + OneFlow_Tensor:$loss_diff, + OneFlow_Tensor:$label + ); + let output = (outs + OneFlow_Tensor:$prediction_diff + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SparseCrossEntropyOp : OneFlow_BaseOp<"sparse_cross_entropy", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$prediction, + OneFlow_Tensor:$label + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SparseCrossEntropyGradOp : OneFlow_BaseOp<"sparse_cross_entropy_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$prediction, + OneFlow_Tensor:$label, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$prediction_diff + ); + let 
attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SparseCrossEntropyMsOp : OneFlow_BaseOp<"sparse_cross_entropy_ms", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$prediction, + OneFlow_Tensor:$label + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SparseCrossEntropyMsGradOp : OneFlow_BaseOp<"sparse_cross_entropy_ms_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$prediction, + OneFlow_Tensor:$label, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$prediction_diff + ); + let attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_CROSS_ENTROPY_OP_DEFINITIONS + +// Group: CUDA +// nvtx_end, nvtx_start +// Total: 2 + +#ifdef GET_ONEFLOW_CUDA_OP_DEFINITIONS + +def OneFlow_NvtxEndOp : OneFlow_BaseOp<"nvtx_end", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$mark_prefix + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_NvtxStartOp : OneFlow_BaseOp<"nvtx_start", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$mark_prefix + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_CUDA_OP_DEFINITIONS + +// Group: DATASET +// COCOReader, OFRecordReader, OneRecReader, ctc_greedy_decoder, megatron_gpt_mmap_data_loader, ofrecord_bytes_decoder, ofrecord_image_classification_reader, ofrecord_image_decoder, ofrecord_image_decoder_random_crop, ofrecord_raw_decoder, onerec_decoder +// Total: 11 + +#ifdef GET_ONEFLOW_DATASET_OP_DEFINITIONS + +def OneFlow_COCOReaderOp : OneFlow_BaseOp<"COCOReader", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$image, + OneFlow_Tensor:$image_id, + OneFlow_Tensor:$image_size, + OneFlow_Tensor:$gt_bbox, + OneFlow_Tensor:$gt_label, + OneFlow_Tensor:$gt_segm, + OneFlow_Tensor:$gt_segm_index + ); + let attrs = (ins + DefaultValuedAttr:$session_id, + StrAttr:$annotation_file, + StrAttr:$image_dir, + DefaultValuedAttr:$batch_size, + DefaultValuedAttr:$shuffle_after_epoch, + DefaultValuedAttr:$random_seed, + DefaultValuedAttr:$group_by_ratio, + DefaultValuedAttr:$remove_images_without_annotations, + DefaultValuedAttr:$stride_partition, + StrArrayAttr:$nd_sbp + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_arg_modify_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_OFRecordReaderOp : OneFlow_BaseOp<"OFRecordReader", [NoSideEffect, NoGrad, CpuOnly, 
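For the CROSS_ENTROPY group earlier in this hunk: binary_cross_entropy takes probabilities, while binary_cross_entropy_with_logits fuses the sigmoid and therefore declares the optional $pos_weight operand and the has_pos_weight attribute. A hedged sketch of the corresponding Python modules, assuming oneflow.nn follows the torch-style BCELoss / BCEWithLogitsLoss interface (those class names are assumptions, not shown in the patch):

import oneflow as flow
import oneflow.nn as nn

logits = flow.randn(4, 3)
target = flow.rand(4, 3)   # targets in [0, 1)

# binary_cross_entropy: expects probabilities, so squash the logits first.
loss_bce = nn.BCELoss()(flow.sigmoid(logits), target)

# binary_cross_entropy_with_logits: fused sigmoid + BCE; pos_weight maps to
# the optional $pos_weight operand declared above.
loss_bcewl = nn.BCEWithLogitsLoss(pos_weight=flow.ones(3))(logits, target)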
DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$data_dir, + DefaultValuedAttr:$data_part_num, + DefaultValuedAttr:$batch_size, + DefaultValuedAttr:$part_name_prefix, + DefaultValuedAttr:$part_name_suffix_length, + DefaultValuedAttr:$random_shuffle, + DefaultValuedAttr:$seed, + DefaultValuedAttr:$shuffle_buffer_size, + DefaultValuedAttr:$shuffle_after_epoch, + StrArrayAttr:$nd_sbp + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_arg_modify_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_OneRecReaderOp : OneFlow_BaseOp<"OneRecReader", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrArrayAttr:$files, + DefaultValuedAttr:$batch_size, + DefaultValuedAttr:$random_shuffle, + DefaultValuedAttr:$shuffle_mode, + DefaultValuedAttr:$seed, + DefaultValuedAttr:$shuffle_buffer_size, + DefaultValuedAttr:$shuffle_after_epoch, + DefaultValuedAttr:$verify_example + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CtcGreedyDecoderOp : OneFlow_BaseOp<"ctc_greedy_decoder", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$log_probs, + OneFlow_Tensor:$input_lengths + ); + let output = (outs + OneFlow_Tensor:$decoded, + OneFlow_Tensor:$neg_sum_logits + ); + let attrs = (ins + DefaultValuedAttr:$merge_repeated + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_MegatronGptMmapDataLoaderOp : OneFlow_BaseOp<"megatron_gpt_mmap_data_loader", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + Optional:$iteration + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$data_file_prefix, + DefaultValuedAttr:$seq_length, + DefaultValuedAttr:$label_length, + DefaultValuedAttr:$num_samples, + DefaultValuedAttr:$batch_size, + OneFlow_DataType:$dtype, + SI64ArrayAttr:$split_sizes, + DefaultValuedAttr:$split_index, + DefaultValuedAttr:$shuffle, + DefaultValuedAttr:$random_seed, + StrArrayAttr:$nd_sbp + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_OfrecordBytesDecoderOp : OneFlow_BaseOp<"ofrecord_bytes_decoder", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$name + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_OfrecordImageClassificationReaderOp : OneFlow_BaseOp<"ofrecord_image_classification_reader", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$image, + OneFlow_Tensor:$label + ); + let attrs = (ins + StrAttr:$data_dir, + DefaultValuedAttr:$data_part_num, + DefaultValuedAttr:$batch_size, + DefaultValuedAttr:$part_name_prefix, + DefaultValuedAttr:$part_name_suffix_length, + DefaultValuedAttr:$random_shuffle, + DefaultValuedAttr:$seed, + 
DefaultValuedAttr:$shuffle_buffer_size, + DefaultValuedAttr:$shuffle_after_epoch, + DefaultValuedAttr:$color_space, + DefaultValuedAttr:$image_feature_name, + DefaultValuedAttr:$label_feature_name, + DefaultValuedAttr:$decode_buffer_size_per_thread, + DefaultValuedAttr:$num_decode_threads_per_machine + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_arg_modify_fn = 1; +} + +def OneFlow_OfrecordImageDecoderOp : OneFlow_BaseOp<"ofrecord_image_decoder", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$name, + DefaultValuedAttr:$color_space + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_OfrecordImageDecoderRandomCropOp : OneFlow_BaseOp<"ofrecord_image_decoder_random_crop", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$name, + DefaultValuedAttr:$color_space, + DefaultValuedAttr:$num_attempts, + DefaultValuedAttr:$seed, + DefaultValuedAttr:$has_seed, + F32ArrayAttr:$random_area, + F32ArrayAttr:$random_aspect_ratio + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_OfrecordRawDecoderOp : OneFlow_BaseOp<"ofrecord_raw_decoder", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$name, + ShapeAttr:$shape, + OneFlow_DataType:$data_type, + DefaultValuedAttr:$dim1_varying_length, + DefaultValuedAttr:$truncate + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_OnerecDecoderOp : OneFlow_BaseOp<"onerec_decoder", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$key, + OneFlow_DataType:$data_type, + ShapeAttr:$static_shape, + DefaultValuedAttr:$is_dynamic, + DefaultValuedAttr:$has_reshape, + ShapeAttr:$reshape, + DefaultValuedAttr:$has_batch_padding, + ShapeAttr:$batch_padding + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; + let has_output_arg_modify_fn = 1; +} + +#endif // GET_ONEFLOW_DATASET_OP_DEFINITIONS + +// Group: DETECTION +// in_top_k, nms, object_bbox_flip, object_bbox_scale, object_segmentation_polygon_flip, object_segmentation_polygon_scale, object_segmentation_polygon_to_mask, roi_align, roi_align_grad, top_k +// Total: 10 + +#ifdef GET_ONEFLOW_DETECTION_OP_DEFINITIONS + +def OneFlow_InTopKOp : OneFlow_BaseOp<"in_top_k", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$targets, + OneFlow_Tensor:$predictions + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$k 
+ ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_NmsOp : OneFlow_BaseOp<"nms", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$iou_threshold, + DefaultValuedAttr:$keep_n + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ObjectBboxFlipOp : OneFlow_BaseOp<"object_bbox_flip", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$bbox, + OneFlow_Tensor:$image_size, + OneFlow_Tensor:$flip_code + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ObjectBboxScaleOp : OneFlow_BaseOp<"object_bbox_scale", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$bbox, + OneFlow_Tensor:$scale + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ObjectSegmentationPolygonFlipOp : OneFlow_BaseOp<"object_segmentation_polygon_flip", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$poly, + OneFlow_Tensor:$image_size, + OneFlow_Tensor:$flip_code + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ObjectSegmentationPolygonScaleOp : OneFlow_BaseOp<"object_segmentation_polygon_scale", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$poly, + OneFlow_Tensor:$scale + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ObjectSegmentationPolygonToMaskOp : OneFlow_BaseOp<"object_segmentation_polygon_to_mask", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$poly, + OneFlow_Tensor:$poly_index, + OneFlow_Tensor:$image_size + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_RoiAlignOp : OneFlow_BaseOp<"roi_align", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$rois + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$pooled_h, + DefaultValuedAttr:$pooled_w, + DefaultValuedAttr:$spatial_scale, + DefaultValuedAttr:$sampling_ratio, + DefaultValuedAttr:$aligned + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_RoiAlignGradOp : OneFlow_BaseOp<"roi_align_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + 
OneFlow_Tensor:$x_like, + OneFlow_Tensor:$rois + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$pooled_h, + DefaultValuedAttr:$pooled_w, + DefaultValuedAttr:$spatial_scale, + DefaultValuedAttr:$sampling_ratio, + DefaultValuedAttr:$aligned + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TopKOp : OneFlow_BaseOp<"top_k", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$k, + DefaultValuedAttr:$sorted + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_DETECTION_OP_DEFINITIONS + +// Group: EAGER +// eager_b_to_s, eager_naive_s_to_s, eager_nccl_all_gather, eager_nccl_all_reduce, eager_nccl_broadcast, eager_nccl_reduce, eager_nccl_reduce_scatter, eager_nccl_s2s, eager_p_to_b, eager_p_to_s, eager_s_to_b, eager_symmetric_s_to_p +// Total: 12 + +#ifdef GET_ONEFLOW_EAGER_OP_DEFINITIONS + +def OneFlow_EagerBToSOp : OneFlow_BaseOp<"eager_b_to_s", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$out_split_axis, + StrAttr:$in_parallel_conf, + StrAttr:$out_parallel_conf, + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerNaiveSToSOp : OneFlow_BaseOp<"eager_naive_s_to_s", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$in_split_axis, + DefaultValuedAttr:$out_split_axis, + StrAttr:$in_parallel_conf, + StrAttr:$out_parallel_conf, + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerNcclAllGatherOp : OneFlow_BaseOp<"eager_nccl_all_gather", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$parallel_conf + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerNcclAllReduceOp : OneFlow_BaseOp<"eager_nccl_all_reduce", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$parallel_conf, + DefaultValuedAttr:$async_launch + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; +} + +def OneFlow_EagerNcclBroadcastOp : OneFlow_BaseOp<"eager_nccl_broadcast", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + 
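On the DETECTION group earlier in this hunk: top_k and in_top_k are marked NoGrad since index and membership outputs are not differentiable. A usage sketch, assuming oneflow exposes a torch-style oneflow.topk and a TensorFlow-style oneflow.in_top_k(targets, predictions, k); the operand order matches the $targets/$predictions inputs in the definition, but the Python names themselves are assumptions:

import oneflow as flow

predictions = flow.tensor([[0.10, 0.60, 0.30],
                           [0.80, 0.15, 0.05]])
targets = flow.tensor([1, 2])

values, indices = flow.topk(predictions, k=2, dim=-1)  # top_k op, values sorted
hits = flow.in_top_k(targets, predictions, k=2)        # in_top_k op
# indices -> [[1, 2], [0, 1]]; hits -> [True, False] for this toy input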
OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$parallel_conf, + DefaultValuedAttr:$root + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; +} + +def OneFlow_EagerNcclReduceOp : OneFlow_BaseOp<"eager_nccl_reduce", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$parallel_conf, + DefaultValuedAttr:$root + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; +} + +def OneFlow_EagerNcclReduceScatterOp : OneFlow_BaseOp<"eager_nccl_reduce_scatter", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$parallel_conf, + DefaultValuedAttr:$op_type + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerNcclS2sOp : OneFlow_BaseOp<"eager_nccl_s2s", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$in_split_axis, + DefaultValuedAttr:$out_split_axis, + StrAttr:$parallel_conf + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerPToBOp : OneFlow_BaseOp<"eager_p_to_b", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$in_parallel_conf, + StrAttr:$out_parallel_conf, + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerPToSOp : OneFlow_BaseOp<"eager_p_to_s", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$out_split_axis, + StrAttr:$in_parallel_conf, + StrAttr:$out_parallel_conf, + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerSToBOp : OneFlow_BaseOp<"eager_s_to_b", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$in_split_axis, + StrAttr:$in_parallel_conf, + StrAttr:$out_parallel_conf, + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerSymmetricSToPOp : OneFlow_BaseOp<"eager_symmetric_s_to_p", [NoSideEffect, NoGrad, 
DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$in_split_axis, + StrAttr:$parallel_conf + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +#endif // GET_ONEFLOW_EAGER_OP_DEFINITIONS + +// Group: FUSED +// cudnn_fused_normalization_add_relu, cudnn_fused_normalization_add_relu_grad, fused_bias_add_gelu, fused_bias_add_gelu_grad, fused_bias_add_mask_scale, fused_cast_scale, fused_scale_mask_softmax, fused_scale_mask_softmax_dropout, fused_scale_mask_softmax_dropout_grad, fused_scale_mask_softmax_grad, fused_scale_tril, fused_self_attention_query_mul_key_and_value, fused_self_attention_query_mul_key_and_value_grad, fused_tril_scale_softmax_mask_scale, fused_tril_scale_softmax_mask_scale_grad, normalization_add_relu_grad +// Total: 16 + +#ifdef GET_ONEFLOW_FUSED_OP_DEFINITIONS + +def OneFlow_CudnnFusedNormalizationAddReluOp : OneFlow_BaseOp<"cudnn_fused_normalization_add_relu", [NoSideEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + Optional:$addend, + Optional:$moving_mean, + Optional:$moving_variance, + OneFlow_Tensor:$gamma, + OneFlow_Tensor:$beta + ); + let output = (outs + OneFlow_Tensor:$y, + OneFlow_Tensor:$reserve_space, + Optional:$mean, + Optional:$inv_variance + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$momentum + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes, + I32ElementsAttr:$result_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_CudnnFusedNormalizationAddReluGradOp : OneFlow_BaseOp<"cudnn_fused_normalization_add_relu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy, + OneFlow_Tensor:$mean, + OneFlow_Tensor:$inv_variance, + OneFlow_Tensor:$gamma, + OneFlow_Tensor:$beta, + OneFlow_Tensor:$reserve_space, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$gamma_diff, + OneFlow_Tensor:$beta_diff, + OneFlow_Tensor:$dx, + Optional:$addend_diff + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$epsilon + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FusedBiasAddGeluOp : OneFlow_BaseOp<"fused_bias_add_gelu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$a, + OneFlow_Tensor:$b + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FusedBiasAddGeluGradOp : OneFlow_BaseOp<"fused_bias_add_gelu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$a, + OneFlow_Tensor:$b, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let 
has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FusedBiasAddMaskScaleOp : OneFlow_BaseOp<"fused_bias_add_mask_scale", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$a, + OneFlow_Tensor:$b, + OneFlow_Tensor:$mask, + Optional:$_add_to_output + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$scale + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_FusedCastScaleOp : OneFlow_BaseOp<"fused_cast_scale", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$scale_by_tensor + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$scale + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FusedScaleMaskSoftmaxOp : OneFlow_BaseOp<"fused_scale_mask_softmax", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$mask + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$scale_value, + DefaultValuedAttr:$mask_fill_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_FusedScaleMaskSoftmaxDropoutOp : OneFlow_BaseOp<"fused_scale_mask_softmax_dropout", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$mask, + OneFlow_Tensor:$dropout_mask + ); + let output = (outs + OneFlow_Tensor:$y, + OneFlow_Tensor:$softmax_y + ); + let attrs = (ins + DefaultValuedAttr:$scale_value, + DefaultValuedAttr:$mask_fill_value, + DefaultValuedAttr:$dropout_scale_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_FusedScaleMaskSoftmaxDropoutGradOp : OneFlow_BaseOp<"fused_scale_mask_softmax_dropout_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$softmax_y, + OneFlow_Tensor:$dy, + OneFlow_Tensor:$mask, + OneFlow_Tensor:$dropout_mask + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$scale_value, + DefaultValuedAttr:$dropout_scale_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FusedScaleMaskSoftmaxGradOp : OneFlow_BaseOp<"fused_scale_mask_softmax_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$y, + OneFlow_Tensor:$dy, + OneFlow_Tensor:$mask + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$scale_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FusedScaleTrilOp : OneFlow_BaseOp<"fused_scale_tril", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$diagonal, + 
DefaultValuedAttr:$floating_fill_value, + DefaultValuedAttr:$integer_fill_value, + DefaultValuedAttr:$is_floating_fill_value, + DefaultValuedAttr:$floating_scale_value, + DefaultValuedAttr:$integer_scale_value, + DefaultValuedAttr:$is_floating_scale_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FusedSelfAttentionQueryMulKeyAndValueOp : OneFlow_BaseOp<"fused_self_attention_query_mul_key_and_value", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$hidden_states + ); + let output = (outs + OneFlow_Tensor:$query_mul_key, + OneFlow_Tensor:$value + ); + let attrs = (ins + DefaultValuedAttr:$head_size, + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FusedSelfAttentionQueryMulKeyAndValueGradOp : OneFlow_BaseOp<"fused_self_attention_query_mul_key_and_value_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$query_mul_key_grad, + OneFlow_Tensor:$value_grad, + OneFlow_Tensor:$hidden_states + ); + let output = (outs + OneFlow_Tensor:$hidden_states_grad + ); + let attrs = (ins + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FusedTrilScaleSoftmaxMaskScaleOp : OneFlow_BaseOp<"fused_tril_scale_softmax_mask_scale", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$mask + ); + let output = (outs + OneFlow_Tensor:$y, + OneFlow_Tensor:$softmax_y + ); + let attrs = (ins + DefaultValuedAttr:$diagonal, + DefaultValuedAttr:$tril_fill_value, + DefaultValuedAttr:$tril_scale_value, + DefaultValuedAttr:$mask_scale_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_FusedTrilScaleSoftmaxMaskScaleGradOp : OneFlow_BaseOp<"fused_tril_scale_softmax_mask_scale_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$softmax_y, + OneFlow_Tensor:$dy, + OneFlow_Tensor:$mask + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$diagonal, + DefaultValuedAttr:$tril_scale_value, + DefaultValuedAttr:$mask_scale_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_NormalizationAddReluGradOp : OneFlow_BaseOp<"normalization_add_relu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy, + OneFlow_Tensor:$mean, + OneFlow_Tensor:$inv_variance, + OneFlow_Tensor:$gamma, + OneFlow_Tensor:$beta, + OneFlow_Tensor:$reserve_space, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$gamma_diff, + OneFlow_Tensor:$beta_diff, + OneFlow_Tensor:$dx, + Optional:$addend_diff + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$epsilon + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_FUSED_OP_DEFINITIONS + +// Group: IDEMPOTENT +// abs, 
ceil, floor, ones_like, relu, rint, round, sign +// Total: 8 + +#ifdef GET_ONEFLOW_IDEMPOTENT_OP_DEFINITIONS + +def OneFlow_AbsOp : OneFlow_IdempotentBaseOp<"abs", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +def OneFlow_CeilOp : OneFlow_IdempotentBaseOp<"ceil", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +def OneFlow_FloorOp : OneFlow_IdempotentBaseOp<"floor", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +def OneFlow_OnesLikeOp : OneFlow_IdempotentBaseOp<"ones_like", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let same_output_regst_num = 1; +} + +def OneFlow_ReluOp : OneFlow_IdempotentBaseOp<"relu", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +def OneFlow_RintOp : OneFlow_IdempotentBaseOp<"rint", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +def OneFlow_RoundOp : OneFlow_IdempotentBaseOp<"round", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +def OneFlow_SignOp : OneFlow_IdempotentBaseOp<"sign", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +#endif // GET_ONEFLOW_IDEMPOTENT_OP_DEFINITIONS + +// Group: IDENTITY +// amp_white_identity, identity, identity_buffer, tuple_identity +// Total: 4 + +#ifdef GET_ONEFLOW_IDENTITY_OP_DEFINITIONS + +def OneFlow_AmpWhiteIdentityOp : OneFlow_BaseOp<"amp_white_identity", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_IdentityOp : OneFlow_BaseOp<"identity", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_IdentityBufferOp : OneFlow_BaseOp<"identity_buffer", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$buffer_size + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TupleIdentityOp : OneFlow_BaseOp<"tuple_identity", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + Variadic:$in + ); + let output = (outs + Variadic:$out + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_sbp_signature_infer_fn = 1; +} + +#endif // GET_ONEFLOW_IDENTITY_OP_DEFINITIONS + +// Group: IMAGE +// image_batch_align, image_decode, image_flip, image_random_crop, image_resize_keep_aspect_ratio, image_resize_to_fixed +// Total: 6 + +#ifdef GET_ONEFLOW_IMAGE_OP_DEFINITIONS + +def OneFlow_ImageBatchAlignOp : OneFlow_BaseOp<"image_batch_align", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + ShapeAttr:$shape, + OneFlow_DataType:$data_type, + DefaultValuedAttr:$alignment, + DefaultValuedAttr:$dynamic_out + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_arg_modify_fn = 1; +} + +def 
OneFlow_ImageDecodeOp : OneFlow_BaseOp<"image_decode", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$color_space, + OneFlow_DataType:$data_type + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ImageFlipOp : OneFlow_BaseOp<"image_flip", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + OneFlow_Tensor:$flip_code + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ImageRandomCropOp : OneFlow_BaseOp<"image_random_crop", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$num_attempts, + DefaultValuedAttr:$seed, + DefaultValuedAttr:$has_seed, + F32ArrayAttr:$random_area, + F32ArrayAttr:$random_aspect_ratio + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_ImageResizeKeepAspectRatioOp : OneFlow_BaseOp<"image_resize_keep_aspect_ratio", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out, + OneFlow_Tensor:$size, + OneFlow_Tensor:$scale + ); + let attrs = (ins + DefaultValuedAttr:$target_size, + DefaultValuedAttr:$min_size, + DefaultValuedAttr:$max_size, + DefaultValuedAttr:$resize_longer, + DefaultValuedAttr:$interpolation_type + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ImageResizeToFixedOp : OneFlow_BaseOp<"image_resize_to_fixed", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out, + OneFlow_Tensor:$scale + ); + let attrs = (ins + DefaultValuedAttr:$target_width, + DefaultValuedAttr:$target_height, + DefaultValuedAttr:$channels, + OneFlow_DataType:$data_type, + DefaultValuedAttr:$interpolation_type + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_IMAGE_OP_DEFINITIONS + +// Group: INDICES +// arg_sort, argmax, argwhere, batch_gather, dim_gather, dim_scatter_add, dim_scatter_add_like, dim_scatter_add_scalar, dim_scatter_mul, dim_scatter_mul_scalar, dim_scatter_update, dim_scatter_update_scalar, gather, gather_nd, generate_random_batch_permutation_indices, image_target_resize, logical_slice, scatter_nd, scatter_nd_like, slice, slice_grad, tensor_scatter_nd_add, tensor_scatter_nd_update, unsorted_batch_segment_sum, unsorted_segment_sum, unsorted_segment_sum_like, where, where_scalar_x, where_scalar_xy, where_scalar_y +// Total: 30 + +#ifdef GET_ONEFLOW_INDICES_OP_DEFINITIONS + +def OneFlow_ArgSortOp : OneFlow_BaseOp<"arg_sort", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + 
OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$direction + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ArgmaxOp : OneFlow_BaseOp<"argmax", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ArgwhereOp : OneFlow_BaseOp<"argwhere", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input + ); + let output = (outs + OneFlow_Tensor:$output, + OneFlow_Tensor:$output_size + ); + let attrs = (ins + OneFlow_DataType:$dtype + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BatchGatherOp : OneFlow_BaseOp<"batch_gather", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + OneFlow_Tensor:$indices + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_DimGatherOp : OneFlow_BaseOp<"dim_gather", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$index + ); + let output = (outs + OneFlow_Tensor:$output + ); + let attrs = (ins + DefaultValuedAttr:$dim + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_DimScatterAddOp : OneFlow_BaseOp<"dim_scatter_add", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$index, + OneFlow_Tensor:$src + ); + let output = (outs + OneFlow_Tensor:$output + ); + let attrs = (ins + DefaultValuedAttr:$dim + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_DimScatterAddLikeOp : OneFlow_BaseOp<"dim_scatter_add_like", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$like, + OneFlow_Tensor:$index, + OneFlow_Tensor:$src + ); + let output = (outs + OneFlow_Tensor:$output + ); + let attrs = (ins + DefaultValuedAttr:$dim + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_DimScatterAddScalarOp : OneFlow_BaseOp<"dim_scatter_add_scalar", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$index + ); + let output = (outs + OneFlow_Tensor:$output + ); + let attrs = (ins + DefaultValuedAttr:$src_scalar, + DefaultValuedAttr:$dim + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_DimScatterMulOp : OneFlow_BaseOp<"dim_scatter_mul", 
[NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$index, + OneFlow_Tensor:$src + ); + let output = (outs + OneFlow_Tensor:$output + ); + let attrs = (ins + DefaultValuedAttr:$dim + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_DimScatterMulScalarOp : OneFlow_BaseOp<"dim_scatter_mul_scalar", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$index + ); + let output = (outs + OneFlow_Tensor:$output + ); + let attrs = (ins + DefaultValuedAttr:$src_scalar, + DefaultValuedAttr:$dim + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_DimScatterUpdateOp : OneFlow_BaseOp<"dim_scatter_update", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$index, + OneFlow_Tensor:$src + ); + let output = (outs + OneFlow_Tensor:$output + ); + let attrs = (ins + DefaultValuedAttr:$dim + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_DimScatterUpdateScalarOp : OneFlow_BaseOp<"dim_scatter_update_scalar", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$index + ); + let output = (outs + OneFlow_Tensor:$output + ); + let attrs = (ins + DefaultValuedAttr:$src_scalar, + DefaultValuedAttr:$dim + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_GatherOp : OneFlow_BaseOp<"gather", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + OneFlow_Tensor:$indices + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_GatherNdOp : OneFlow_BaseOp<"gather_nd", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$params, + OneFlow_Tensor:$indices + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_GenerateRandomBatchPermutationIndicesOp : OneFlow_BaseOp<"generate_random_batch_permutation_indices", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$seed + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ImageTargetResizeOp : OneFlow_BaseOp<"image_target_resize", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out, + OneFlow_Tensor:$size, + 
OneFlow_Tensor:$scale + ); + let attrs = (ins + DefaultValuedAttr:$target_size, + DefaultValuedAttr:$max_size + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LogicalSliceOp : OneFlow_BaseOp<"logical_slice", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + SI64ArrayAttr:$start, + SI64ArrayAttr:$stop, + SI64ArrayAttr:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScatterNdOp : OneFlow_BaseOp<"scatter_nd", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$indices, + OneFlow_Tensor:$updates + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_ScatterNdLikeOp : OneFlow_BaseOp<"scatter_nd_like", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$like, + OneFlow_Tensor:$indices, + OneFlow_Tensor:$updates + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SliceOp : OneFlow_BaseOp<"slice", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + SI64ArrayAttr:$start, + SI64ArrayAttr:$stop, + SI64ArrayAttr:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SliceGradOp : OneFlow_BaseOp<"slice_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + ShapeAttr:$like_shape, + SI64ArrayAttr:$start, + SI64ArrayAttr:$stop, + SI64ArrayAttr:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_TensorScatterNdAddOp : OneFlow_BaseOp<"tensor_scatter_nd_add", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$params, + OneFlow_Tensor:$updates, + OneFlow_Tensor:$indices + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_TensorScatterNdUpdateOp : OneFlow_BaseOp<"tensor_scatter_nd_update", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$params, + OneFlow_Tensor:$updates, + OneFlow_Tensor:$indices + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_UnsortedBatchSegmentSumOp : OneFlow_BaseOp<"unsorted_batch_segment_sum", 
[NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$data, + OneFlow_Tensor:$segment_ids + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$num_segments + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_UnsortedSegmentSumOp : OneFlow_BaseOp<"unsorted_segment_sum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$data, + OneFlow_Tensor:$segment_ids + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$num_segments + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_UnsortedSegmentSumLikeOp : OneFlow_BaseOp<"unsorted_segment_sum_like", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$data, + OneFlow_Tensor:$segment_ids, + OneFlow_Tensor:$like + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_WhereOp : OneFlow_BaseOp<"where", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$condition, + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_WhereScalarXOp : OneFlow_BaseOp<"where_scalar_x", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$condition, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_WhereScalarXyOp : OneFlow_BaseOp<"where_scalar_xy", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$condition + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_x_int_operand, + DefaultValuedAttr:$has_x_float_operand, + DefaultValuedAttr:$has_y_int_operand, + DefaultValuedAttr:$has_y_float_operand, + DefaultValuedAttr:$x_int_operand, + DefaultValuedAttr:$x_float_operand, + DefaultValuedAttr:$y_int_operand, + DefaultValuedAttr:$y_float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_WhereScalarYOp : OneFlow_BaseOp<"where_scalar_y", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$condition, + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + 
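+// Note (illustrative): "where" selects elementwise between x and y according to condition, which
+// matches flow.where(condition, x, y) in the Python API; the where_scalar_{x,y,xy} variants cover
+// calls where one or both branches are Python scalars, hence the has_*_operand attributes in
+// place of tensor operands, e.g. flow.where(cond, x, 0.0) presumably lowers to "where_scalar_y".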
DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +#endif // GET_ONEFLOW_INDICES_OP_DEFINITIONS + +// Group: INVOLUTION +// negative, reciprocal +// Total: 2 + +#ifdef GET_ONEFLOW_INVOLUTION_OP_DEFINITIONS + +def OneFlow_NegativeOp : OneFlow_InvolutionBaseOp<"negative", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +def OneFlow_ReciprocalOp : OneFlow_InvolutionBaseOp<"reciprocal", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +#endif // GET_ONEFLOW_INVOLUTION_OP_DEFINITIONS + +// Group: LOSS +// combined_margin_loss, combined_margin_loss_grad, ctc_loss, ctc_loss_grad, dynamic_loss_scale_schedule, kl_div_loss, kl_div_loss_grad, smooth_l1_loss, smooth_l1_loss_grad +// Total: 9 + +#ifdef GET_ONEFLOW_LOSS_OP_DEFINITIONS + +def OneFlow_CombinedMarginLossOp : OneFlow_BaseOp<"combined_margin_loss", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$label + ); + let output = (outs + OneFlow_Tensor:$y, + OneFlow_Tensor:$theta + ); + let attrs = (ins + DefaultValuedAttr:$m1, + DefaultValuedAttr:$m2, + DefaultValuedAttr:$m3, + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_CombinedMarginLossGradOp : OneFlow_BaseOp<"combined_margin_loss_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$label, + OneFlow_Tensor:$theta + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$m1, + DefaultValuedAttr:$m2, + DefaultValuedAttr:$m3, + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CtcLossOp : OneFlow_BaseOp<"ctc_loss", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$log_probs, + OneFlow_Tensor:$targets, + OneFlow_Tensor:$input_lengths, + OneFlow_Tensor:$target_lengths + ); + let output = (outs + OneFlow_Tensor:$loss, + OneFlow_Tensor:$alpha + ); + let attrs = (ins + DefaultValuedAttr:$max_target_length, + DefaultValuedAttr:$blank, + DefaultValuedAttr:$zero_infinity + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CtcLossGradOp : OneFlow_BaseOp<"ctc_loss_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$grad_out, + OneFlow_Tensor:$log_probs, + OneFlow_Tensor:$targets, + OneFlow_Tensor:$input_lengths, + OneFlow_Tensor:$target_lengths, + OneFlow_Tensor:$loss, + OneFlow_Tensor:$alpha + ); + let output = (outs + OneFlow_Tensor:$grad + ); + let attrs = (ins + DefaultValuedAttr:$max_target_length, + DefaultValuedAttr:$blank, + DefaultValuedAttr:$zero_infinity + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DynamicLossScaleScheduleOp : OneFlow_BaseOp<"dynamic_loss_scale_schedule", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$count_not_finite, + OneFlow_Tensor:$loss_scale, + 
OneFlow_Tensor:$good_step_counter + ); + let attrs = (ins + DefaultValuedAttr:$increment_period, + DefaultValuedAttr:$multiplier + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_KlDivLossOp : OneFlow_BaseOp<"kl_div_loss", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$target + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$log_target + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_KlDivLossGradOp : OneFlow_BaseOp<"kl_div_loss_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$target, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$log_target + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SmoothL1LossOp : OneFlow_BaseOp<"smooth_l1_loss", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$target + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$beta + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SmoothL1LossGradOp : OneFlow_BaseOp<"smooth_l1_loss_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$target, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$beta + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_LOSS_OP_DEFINITIONS + +// Group: MATH +// abs_grad, ceil_grad, erf, erf_grad, exp, exp_grad, expand_grad, expm1, expm1_grad, floor_grad, floordiv_x_grad, floordiv_y_grad, lgamma, lgamma_grad, log, log1p, log1p_grad, log2_grad, log_grad, log_sigmoid, log_sigmoid_grad, negative_grad, reciprocal_grad, reciprocal_no_nan, reciprocal_no_nan_grad, rint_grad, round_grad, rsqrt, rsqrt_grad, sigmoid_v2, sigmoid_v2_grad, sign_grad, softplus, softplus_grad, softsign_grad, sqrt, sqrt_grad, square, square_grad, xlogy_x_grad, xlogy_y_grad +// Total: 41 + +#ifdef GET_ONEFLOW_MATH_OP_DEFINITIONS + +def OneFlow_AbsGradOp : OneFlow_BaseOp<"abs_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CeilGradOp : OneFlow_BaseOp<"ceil_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ErfOp : 
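+// Note (illustrative): kl_div_loss and smooth_l1_loss above (with their *_grad partners)
+// presumably back the flow.nn.KLDivLoss and flow.nn.SmoothL1Loss modules; the log_target and
+// beta attributes mirror the corresponding module arguments, e.g.
+//   loss = flow.nn.SmoothL1Loss(beta=1.0, reduction="none")(input, target)   # illustrative only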
OneFlow_BaseOp<"erf", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ErfGradOp : OneFlow_BaseOp<"erf_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ExpOp : OneFlow_BaseOp<"exp", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ExpGradOp : OneFlow_BaseOp<"exp_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ExpandGradOp : OneFlow_BaseOp<"expand_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + SI32ArrayAttr:$logical_out_shape, + SI32ArrayAttr:$logical_expand_shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Expm1Op : OneFlow_BaseOp<"expm1", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Expm1GradOp : OneFlow_BaseOp<"expm1_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FloorGradOp : OneFlow_BaseOp<"floor_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FloordivXGradOp : OneFlow_BaseOp<"floordiv_x_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y, + OneFlow_Tensor:$dz + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FloordivYGradOp : OneFlow_BaseOp<"floordiv_y_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y, + OneFlow_Tensor:$dz + ); + let output = (outs + OneFlow_Tensor:$dy + ); + let 
has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LgammaOp : OneFlow_BaseOp<"lgamma", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LgammaGradOp : OneFlow_BaseOp<"lgamma_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LogOp : OneFlow_BaseOp<"log", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Log1pOp : OneFlow_BaseOp<"log1p", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Log1pGradOp : OneFlow_BaseOp<"log1p_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Log2GradOp : OneFlow_BaseOp<"log2_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LogGradOp : OneFlow_BaseOp<"log_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LogSigmoidOp : OneFlow_BaseOp<"log_sigmoid", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LogSigmoidGradOp : OneFlow_BaseOp<"log_sigmoid_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_NegativeGradOp : OneFlow_BaseOp<"negative_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + 
OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReciprocalGradOp : OneFlow_BaseOp<"reciprocal_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReciprocalNoNanOp : OneFlow_BaseOp<"reciprocal_no_nan", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReciprocalNoNanGradOp : OneFlow_BaseOp<"reciprocal_no_nan_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_RintGradOp : OneFlow_BaseOp<"rint_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_RoundGradOp : OneFlow_BaseOp<"round_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_RsqrtOp : OneFlow_BaseOp<"rsqrt", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_RsqrtGradOp : OneFlow_BaseOp<"rsqrt_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SigmoidV2Op : OneFlow_BaseOp<"sigmoid_v2", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SigmoidV2GradOp : OneFlow_BaseOp<"sigmoid_v2_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SignGradOp : OneFlow_BaseOp<"sign_grad", [NoSideEffect, 
DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SoftplusOp : OneFlow_BaseOp<"softplus", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SoftplusGradOp : OneFlow_BaseOp<"softplus_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SoftsignGradOp : OneFlow_BaseOp<"softsign_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SqrtOp : OneFlow_BaseOp<"sqrt", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SqrtGradOp : OneFlow_BaseOp<"sqrt_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SquareOp : OneFlow_BaseOp<"square", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SquareGradOp : OneFlow_BaseOp<"square_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_XlogyXGradOp : OneFlow_BaseOp<"xlogy_x_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y, + OneFlow_Tensor:$dz + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_XlogyYGradOp : OneFlow_BaseOp<"xlogy_y_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y, + OneFlow_Tensor:$dz + ); + let output = (outs + OneFlow_Tensor:$dy + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let 
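+// Note (illustrative): the MATH group pairs each elementwise forward op with a *_grad op taking
+// (x, dy) and producing dx; autograd presumably wires the pair together, so e.g.
+//   y = flow.exp(x); y.sum().backward()
+// runs "exp" in the forward pass and "exp_grad" in the backward pass.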
has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_MATH_OP_DEFINITIONS + +// Group: MATMUL +// batch_matmul, broadcast_matmul, broadcast_matmul_grad_b, distributed_partial_fc_sample, distributed_partial_fc_sample_disable_boxing, erfc, erfc_grad, matmul +// Total: 8 + +#ifdef GET_ONEFLOW_MATMUL_OP_DEFINITIONS + +def OneFlow_BatchMatmulOp : OneFlow_BaseOp<"batch_matmul", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$a, + OneFlow_Tensor:$b, + Optional:$_add_to_output + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$transpose_a, + DefaultValuedAttr:$transpose_b, + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastMatmulOp : OneFlow_BaseOp<"broadcast_matmul", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$a, + OneFlow_Tensor:$b, + Optional:$_add_to_output + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$transpose_a, + DefaultValuedAttr:$transpose_b, + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastMatmulGradBOp : OneFlow_BaseOp<"broadcast_matmul_grad_b", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$a, + OneFlow_Tensor:$b, + Optional:$_add_to_output + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DistributedPartialFcSampleOp : OneFlow_BaseOp<"distributed_partial_fc_sample", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$weight, + OneFlow_Tensor:$label + ); + let output = (outs + OneFlow_Tensor:$mapped_label, + OneFlow_Tensor:$sampled_label, + OneFlow_Tensor:$sampled_weight + ); + let attrs = (ins + DefaultValuedAttr:$num_sample, + DefaultValuedAttr:$seed + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_DistributedPartialFcSampleDisableBoxingOp : OneFlow_BaseOp<"distributed_partial_fc_sample_disable_boxing", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$sampled_weight_diff, + OneFlow_Tensor:$sampled_label + ); + let output = (outs + OneFlow_Tensor:$boxing_disabled_sampled_weight_diff, + OneFlow_Tensor:$boxing_disabled_sampled_label + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ErfcOp : OneFlow_BaseOp<"erfc", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ErfcGradOp : OneFlow_BaseOp<"erfc_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let 
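+// Note (illustrative): matmul, batch_matmul and broadcast_matmul in this group share the
+// transpose_a/transpose_b/alpha attributes and an optional _add_to_output operand for fused
+// accumulation; on the Python side they are presumably all reached through flow.matmul, which
+// dispatches on operand ranks, e.g. c = flow.matmul(a, b) for the plain 2-D case.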
has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_MatmulOp : OneFlow_BaseOp<"matmul", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$a, + OneFlow_Tensor:$b, + Optional:$_add_to_output + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$transpose_a, + DefaultValuedAttr:$transpose_b, + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_MATMUL_OP_DEFINITIONS + +// Group: MISC +// CategoricalOrdinalEncode, add_n, arange, coin_flip, concat, constant, dropout, elementwise_maximum_backward, elementwise_minimum_backward, empty, eye, grid_sample_grad, multi_count_not_finite, multi_square_sum, nll, nll_grad, pow_x_grad, pow_y_grad, prelu_grad, randperm, recv, send, split_like, ssp_variable_proxy, tf_prelu_grad, uniform, uniform_int, unique_with_counts, xdivy_x_grad, xdivy_y_grad +// Total: 30 + +#ifdef GET_ONEFLOW_MISC_OP_DEFINITIONS + +def OneFlow_CategoricalOrdinalEncodeOp : OneFlow_BaseOp<"CategoricalOrdinalEncode", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$table, + OneFlow_Tensor:$size, + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$hash_precomputed + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_AddNOp : OneFlow_BaseOp<"add_n", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + Variadic:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let hasCanonicalizer = 1; + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ArangeOp : OneFlow_BaseOp<"arange", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$integer_start, + DefaultValuedAttr:$integer_delta, + DefaultValuedAttr:$integer_limit, + DefaultValuedAttr:$float_start, + DefaultValuedAttr:$float_delta, + DefaultValuedAttr:$float_limit, + OneFlow_DataType:$dtype, + StrArrayAttr:$nd_sbp + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_CoinFlipOp : OneFlow_BaseOp<"coin_flip", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$probability, + DefaultValuedAttr:$batch_size, + DefaultValuedAttr:$seed, + DefaultValuedAttr:$has_seed, + StrArrayAttr:$nd_sbp + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_ConcatOp : OneFlow_BaseOp<"concat", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + Variadic:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$max_dim_size + ); + let has_check_fn = 1; + 
let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ConstantOp : OneFlow_BaseOp<"constant", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$floating_value, + DefaultValuedAttr:$integer_value, + DefaultValuedAttr:$is_floating_value, + OneFlow_DataType:$dtype, + ShapeAttr:$shape, + StrArrayAttr:$nd_sbp + ); + let same_output_regst_num = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_DropoutOp : OneFlow_BaseOp<"dropout", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + Optional:$_add_to_output + ); + let output = (outs + OneFlow_Tensor:$out, + OneFlow_Tensor:$mask + ); + let attrs = (ins + DefaultValuedAttr:$rate + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ElementwiseMaximumBackwardOp : OneFlow_BaseOp<"elementwise_maximum_backward", [NoSideEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dz, + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + Optional:$dx, + Optional:$dy + ); + let trait_attrs = (ins + I32ElementsAttr:$result_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ElementwiseMinimumBackwardOp : OneFlow_BaseOp<"elementwise_minimum_backward", [NoSideEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dz, + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + Optional:$dx, + Optional:$dy + ); + let trait_attrs = (ins + I32ElementsAttr:$result_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_EmptyOp : OneFlow_BaseOp<"empty", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + OneFlow_DataType:$dtype, + ShapeAttr:$shape, + StrArrayAttr:$nd_sbp + ); + let same_output_regst_num = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EyeOp : OneFlow_BaseOp<"eye", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$rows, + DefaultValuedAttr:$cols, + OneFlow_DataType:$dtype, + StrArrayAttr:$nd_sbp + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_GridSampleGradOp : OneFlow_BaseOp<"grid_sample_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$doutput, + OneFlow_Tensor:$input, + OneFlow_Tensor:$grid + ); + let output = (outs + OneFlow_Tensor:$dinput, + OneFlow_Tensor:$dgrid + ); + let attrs = (ins + StrAttr:$interpolation_mode, + StrAttr:$padding_mode, + DefaultValuedAttr:$align_corners + ); + let has_check_fn = 1; + let 
has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_MultiCountNotFiniteOp : OneFlow_BaseOp<"multi_count_not_finite", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + Variadic:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_MultiSquareSumOp : OneFlow_BaseOp<"multi_square_sum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + Variadic:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_NllOp : OneFlow_BaseOp<"nll", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$target, + Optional:$weight + ); + let output = (outs + OneFlow_Tensor:$out, + OneFlow_Tensor:$total_weight + ); + let attrs = (ins + DefaultValuedAttr:$ignore_index + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_NllGradOp : OneFlow_BaseOp<"nll_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input, + OneFlow_Tensor:$target, + OneFlow_Tensor:$total_weight, + Optional:$weight, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$ignore_index + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_PowXGradOp : OneFlow_BaseOp<"pow_x_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y, + OneFlow_Tensor:$dz + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_PowYGradOp : OneFlow_BaseOp<"pow_y_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y, + OneFlow_Tensor:$dz + ); + let output = (outs + OneFlow_Tensor:$dy + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_PreluGradOp : OneFlow_BaseOp<"prelu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x, + OneFlow_Tensor:$alpha + ); + let output = (outs + OneFlow_Tensor:$dx, + OneFlow_Tensor:$alpha_diff + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_RandpermOp : OneFlow_BaseOp<"randperm", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$n, + DefaultValuedAttr:$seed, + StrArrayAttr:$nd_sbp + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn 
= 1; +} + +def OneFlow_RecvOp : OneFlow_BaseOp<"recv", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$src_process_id, + OneFlow_DataType:$dtype, + ShapeAttr:$shape, + StrAttr:$device_type, + DefaultValuedAttr:$device_id + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; +} + +def OneFlow_SendOp : OneFlow_BaseOp<"send", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let attrs = (ins + DefaultValuedAttr:$dst_process_id + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; +} + +def OneFlow_SplitLikeOp : OneFlow_BaseOp<"split_like", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + Variadic:$like + ); + let output = (outs + Variadic:$out + ); + let attrs = (ins + DefaultValuedAttr:$axis + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SspVariableProxyOp : OneFlow_BaseOp<"ssp_variable_proxy", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$var + ); + let output = (outs + OneFlow_Tensor:$ref, + OneFlow_Tensor:$value + ); + let attrs = (ins + DefaultValuedAttr:$buffer_size + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_arg_modify_fn = 1; +} + +def OneFlow_TfPreluGradOp : OneFlow_BaseOp<"tf_prelu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x, + OneFlow_Tensor:$alpha + ); + let output = (outs + OneFlow_Tensor:$dx, + OneFlow_Tensor:$alpha_diff + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UniformOp : OneFlow_BaseOp<"uniform", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$from, + DefaultValuedAttr:$to, + DefaultValuedAttr:$seed, + OneFlow_DataType:$dtype, + ShapeAttr:$shape, + StrArrayAttr:$nd_sbp + ); + let same_output_regst_num = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_UniformIntOp : OneFlow_BaseOp<"uniform_int", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$from, + DefaultValuedAttr:$to, + DefaultValuedAttr:$seed, + OneFlow_DataType:$dtype, + ShapeAttr:$shape, + StrArrayAttr:$nd_sbp + ); + let same_output_regst_num = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_UniqueWithCountsOp : OneFlow_BaseOp<"unique_with_counts", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = 
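+// Note (illustrative): source ops such as arange, eye, randperm, uniform and uniform_int above
+// take no tensor inputs; they carry dtype plus shape-like attributes and an nd_sbp attribute so
+// the result can be created directly with a distributed placement, e.g.
+//   x = flow.randperm(8)   # presumably lowers to the "randperm" op defined in this group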
(outs + OneFlow_Tensor:$y, + OneFlow_Tensor:$idx, + OneFlow_Tensor:$count, + OneFlow_Tensor:$num_unique + ); + let attrs = (ins + OneFlow_DataType:$out_idx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_XdivyXGradOp : OneFlow_BaseOp<"xdivy_x_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y, + OneFlow_Tensor:$dz + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_XdivyYGradOp : OneFlow_BaseOp<"xdivy_y_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y, + OneFlow_Tensor:$dz + ); + let output = (outs + OneFlow_Tensor:$dy + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_MISC_OP_DEFINITIONS + +// Group: NCCL +// _nccl_logical_2D_same_dim0_all2all, _nccl_logical_2D_same_dim0_all_gather, _nccl_logical_2D_same_dim0_all_gather_noncontinuous, _nccl_logical_2D_same_dim0_all_reduce, _nccl_logical_2D_same_dim1_all_reduce, _nccl_logical_all_gather, _nccl_logical_all_gather_noncontinuous, _nccl_logical_all_reduce, _nccl_logical_reduce_scatter, _nccl_logical_s2s +// Total: 10 + +#ifdef GET_ONEFLOW_NCCL_OP_DEFINITIONS + +def OneFlow__ncclLogical_2DSameDim0All2allOp : OneFlow_BaseOp<"_nccl_logical_2D_same_dim0_all2all", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$in_dim1_split_axis, + DefaultValuedAttr:$out_dim1_split_axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow__ncclLogical_2DSameDim0AllGatherOp : OneFlow_BaseOp<"_nccl_logical_2D_same_dim0_all_gather", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow__ncclLogical_2DSameDim0AllGatherNoncontinuousOp : OneFlow_BaseOp<"_nccl_logical_2D_same_dim0_all_gather_noncontinuous", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$in_dim1_split_axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow__ncclLogical_2DSameDim0AllReduceOp : OneFlow_BaseOp<"_nccl_logical_2D_same_dim0_all_reduce", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow__ncclLogical_2DSameDim1AllReduceOp : 
OneFlow_BaseOp<"_nccl_logical_2D_same_dim1_all_reduce", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow__ncclLogicalAllGatherOp : OneFlow_BaseOp<"_nccl_logical_all_gather", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow__ncclLogicalAllGatherNoncontinuousOp : OneFlow_BaseOp<"_nccl_logical_all_gather_noncontinuous", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$in_split_axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow__ncclLogicalAllReduceOp : OneFlow_BaseOp<"_nccl_logical_all_reduce", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow__ncclLogicalReduceScatterOp : OneFlow_BaseOp<"_nccl_logical_reduce_scatter", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow__ncclLogicalS2sOp : OneFlow_BaseOp<"_nccl_logical_s2s", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$in_split_axis, + DefaultValuedAttr:$out_split_axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +#endif // GET_ONEFLOW_NCCL_OP_DEFINITIONS + +// Group: NORMALIZATION +// crop_mirror_normalize_from_tensorbuffer, crop_mirror_normalize_from_uint8, image_normalize, l2_normalize, l2_normalize_grad, layer_norm, layer_norm_grad, layer_norm_param_grad, normal, normalization, normalization_grad +// Total: 11 + +#ifdef GET_ONEFLOW_NORMALIZATION_OP_DEFINITIONS + +def OneFlow_CropMirrorNormalizeFromTensorbufferOp : OneFlow_BaseOp<"crop_mirror_normalize_from_tensorbuffer", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + Optional:$mirror + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$color_space, + DefaultValuedAttr:$output_layout, + F32ArrayAttr:$mean, + F32ArrayAttr:$std, + DefaultValuedAttr:$crop_h, + DefaultValuedAttr:$crop_w, + DefaultValuedAttr:$crop_pos_x, + DefaultValuedAttr:$crop_pos_y, + OneFlow_DataType:$output_dtype + ); + let has_logical_tensor_desc_infer_fn = 1; + let 
has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CropMirrorNormalizeFromUint8Op : OneFlow_BaseOp<"crop_mirror_normalize_from_uint8", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + Optional:$mirror + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$color_space, + DefaultValuedAttr:$output_layout, + F32ArrayAttr:$mean, + F32ArrayAttr:$std, + DefaultValuedAttr:$crop_h, + DefaultValuedAttr:$crop_w, + DefaultValuedAttr:$crop_pos_x, + DefaultValuedAttr:$crop_pos_y, + OneFlow_DataType:$output_dtype + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ImageNormalizeOp : OneFlow_BaseOp<"image_normalize", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + F32ArrayAttr:$std, + F32ArrayAttr:$mean + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_L2NormalizeOp : OneFlow_BaseOp<"l2_normalize", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y, + OneFlow_Tensor:$square_x_sum + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$epsilon + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_L2NormalizeGradOp : OneFlow_BaseOp<"l2_normalize_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$y, + OneFlow_Tensor:$square_x_sum + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$epsilon + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LayerNormOp : OneFlow_BaseOp<"layer_norm", [NoSideEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + Optional:$beta, + Optional:$gamma + ); + let output = (outs + OneFlow_Tensor:$y, + OneFlow_Tensor:$mean, + OneFlow_Tensor:$inv_variance, + Optional:$normalized + ); + let attrs = (ins + DefaultValuedAttr:$center, + DefaultValuedAttr:$scale, + DefaultValuedAttr:$begin_norm_axis, + DefaultValuedAttr:$begin_params_axis, + DefaultValuedAttr:$epsilon + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LayerNormGradOp : OneFlow_BaseOp<"layer_norm_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x, + OneFlow_Tensor:$mean, + OneFlow_Tensor:$inv_variance, + Optional:$gamma, + Optional:$_add_to_output + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$begin_norm_axis, + DefaultValuedAttr:$epsilon + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def 
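+// Note (illustrative): layer_norm here and the normalization op that follows presumably back
+// flow.nn.LayerNorm and the BatchNorm family; normalization's optional moving_mean/moving_variance
+// operands and its training/momentum attributes match the usual train/eval switching, e.g.
+//   y = flow.nn.LayerNorm(64)(x)   # illustrative; expected to lower to "layer_norm"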
OneFlow_LayerNormParamGradOp : OneFlow_BaseOp<"layer_norm_param_grad", [NoSideEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x, + OneFlow_Tensor:$mean, + OneFlow_Tensor:$inv_variance + ); + let output = (outs + Optional:$beta_diff, + Optional:$gamma_diff + ); + let attrs = (ins + DefaultValuedAttr:$begin_params_axis + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes, + I32ElementsAttr:$result_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_NormalOp : OneFlow_BaseOp<"normal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$mean, + DefaultValuedAttr:$std, + DefaultValuedAttr:$seed, + OneFlow_DataType:$dtype, + ShapeAttr:$shape, + StrArrayAttr:$nd_sbp + ); + let same_output_regst_num = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_NormalizationOp : OneFlow_BaseOp<"normalization", [NoSideEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + Optional:$moving_mean, + Optional:$moving_variance, + OneFlow_Tensor:$gamma, + OneFlow_Tensor:$beta, + Optional:$_add_to_output + ); + let output = (outs + OneFlow_Tensor:$y, + Optional:$mean, + Optional:$inv_variance + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$training, + DefaultValuedAttr:$momentum + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes, + I32ElementsAttr:$result_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_NormalizationGradOp : OneFlow_BaseOp<"normalization_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy, + OneFlow_Tensor:$mean, + OneFlow_Tensor:$inv_variance, + OneFlow_Tensor:$gamma + ); + let output = (outs + OneFlow_Tensor:$gamma_diff, + OneFlow_Tensor:$beta_diff, + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$epsilon + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_NORMALIZATION_OP_DEFINITIONS + +// Group: OPTIMIZER +// adagrad_update, adam_bias_correction_factor, adam_update, indexed_slices_adam_update, indexed_slices_momentum_update, indexed_slices_sgd_update, lamb_update, lars_update, momentum_update, rmsprop_update, sgd_update, slice_update +// Total: 12 + +#ifdef GET_ONEFLOW_OPTIMIZER_OP_DEFINITIONS + +def OneFlow_AdagradUpdateOp : OneFlow_BaseOp<"adagrad_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$model, + OneFlow_Tensor:$model_diff, + Optional:$learning_rate, + Optional:$scale_by_tensor, + Optional:$skip_if, + Optional:$train_step, + OneFlow_Tensor:$sum + ); + let attrs = (ins + DefaultValuedAttr:$train_step_val, + DefaultValuedAttr:$learning_rate_val, + DefaultValuedAttr:$scale, + DefaultValuedAttr:$l1, 
+ DefaultValuedAttr:$l2, + DefaultValuedAttr:$lr_decay, + DefaultValuedAttr:$weight_decay, + DefaultValuedAttr:$epsilon + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_AdamBiasCorrectionFactorOp : OneFlow_BaseOp<"adam_bias_correction_factor", [NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$train_step + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$beta + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AdamUpdateOp : OneFlow_BaseOp<"adam_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$model, + OneFlow_Tensor:$model_diff, + Optional:$learning_rate, + Optional:$scale_by_tensor, + Optional:$skip_if, + Optional:$bias_correction1, + Optional:$bias_correction2, + OneFlow_Tensor:$m, + OneFlow_Tensor:$v, + OneFlow_Tensor:$max_v + ); + let attrs = (ins + DefaultValuedAttr:$learning_rate_val, + DefaultValuedAttr:$bias_correction1_val, + DefaultValuedAttr:$bias_correction2_val, + DefaultValuedAttr:$scale, + DefaultValuedAttr:$l1, + DefaultValuedAttr:$l2, + DefaultValuedAttr:$beta1, + DefaultValuedAttr:$beta2, + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$weight_decay, + DefaultValuedAttr:$amsgrad, + DefaultValuedAttr:$do_bias_correction + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_IndexedSlicesAdamUpdateOp : OneFlow_BaseOp<"indexed_slices_adam_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$model, + OneFlow_Tensor:$model_diff_indices, + OneFlow_Tensor:$model_diff_values, + OneFlow_Tensor:$learning_rate, + Optional:$bias_correction1, + Optional:$bias_correction2, + OneFlow_Tensor:$m, + OneFlow_Tensor:$v, + OneFlow_Tensor:$max_v + ); + let attrs = (ins + DefaultValuedAttr:$learning_rate_val, + DefaultValuedAttr:$beta1, + DefaultValuedAttr:$beta2, + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$weight_decay, + DefaultValuedAttr:$amsgrad, + DefaultValuedAttr:$do_bias_correction + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_IndexedSlicesMomentumUpdateOp : OneFlow_BaseOp<"indexed_slices_momentum_update", [NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$model, + OneFlow_Tensor:$model_diff_indices, + OneFlow_Tensor:$model_diff_values, + OneFlow_Tensor:$learning_rate, + OneFlow_Tensor:$momentum + ); + let attrs = (ins + DefaultValuedAttr:$beta, + DefaultValuedAttr:$weight_decay + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_IndexedSlicesSgdUpdateOp : OneFlow_BaseOp<"indexed_slices_sgd_update", [NoGrad, 
DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$model, + OneFlow_Tensor:$model_diff_indices, + OneFlow_Tensor:$model_diff_values, + OneFlow_Tensor:$learning_rate + ); + let attrs = (ins + DefaultValuedAttr:$weight_decay + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_LambUpdateOp : OneFlow_BaseOp<"lamb_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$m, + OneFlow_Tensor:$v, + OneFlow_Tensor:$beta1_t, + OneFlow_Tensor:$beta2_t, + OneFlow_Tensor:$model, + OneFlow_Tensor:$model_diff, + OneFlow_Tensor:$learning_rate, + Optional:$scale_by_tensor, + Optional:$skip_if + ); + let attrs = (ins + DefaultValuedAttr:$beta1, + DefaultValuedAttr:$beta2, + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$scale, + DefaultValuedAttr:$l1, + DefaultValuedAttr:$l2, + DefaultValuedAttr:$weight_decay + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_LarsUpdateOp : OneFlow_BaseOp<"lars_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$model, + OneFlow_Tensor:$model_diff, + OneFlow_Tensor:$learning_rate, + OneFlow_Tensor:$momentum, + Optional:$scale_by_tensor, + Optional:$skip_if + ); + let attrs = (ins + DefaultValuedAttr:$scale, + DefaultValuedAttr:$l1, + DefaultValuedAttr:$l2, + DefaultValuedAttr:$momentum_beta, + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$lars_coefficient, + DefaultValuedAttr:$weight_decay + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_MomentumUpdateOp : OneFlow_BaseOp<"momentum_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$model, + OneFlow_Tensor:$model_diff, + OneFlow_Tensor:$momentum, + Optional:$learning_rate, + Optional:$scale_by_tensor, + Optional:$skip_if + ); + let attrs = (ins + DefaultValuedAttr:$learning_rate_val, + DefaultValuedAttr:$scale, + DefaultValuedAttr:$l1, + DefaultValuedAttr:$l2, + DefaultValuedAttr:$beta, + DefaultValuedAttr:$weight_decay + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_RmspropUpdateOp : OneFlow_BaseOp<"rmsprop_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$model, + OneFlow_Tensor:$model_diff, + Optional:$learning_rate, + Optional:$scale_by_tensor, + Optional:$skip_if, + OneFlow_Tensor:$mean_square, + Optional:$mean_gradient + ); + let attrs = (ins + DefaultValuedAttr:$learning_rate_val, + DefaultValuedAttr:$scale, + DefaultValuedAttr:$l1, + DefaultValuedAttr:$l2, + DefaultValuedAttr:$centered, + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$decay_rate, + DefaultValuedAttr:$weight_decay + ); + let trait_attrs = (ins + 
I32ElementsAttr:$operand_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SgdUpdateOp : OneFlow_BaseOp<"sgd_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$model, + OneFlow_Tensor:$model_diff, + Optional:$learning_rate, + Optional:$scale_by_tensor, + Optional:$skip_if + ); + let attrs = (ins + DefaultValuedAttr:$learning_rate_val, + DefaultValuedAttr:$scale, + DefaultValuedAttr:$l1, + DefaultValuedAttr:$l2, + DefaultValuedAttr:$weight_decay + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SliceUpdateOp : OneFlow_BaseOp<"slice_update", [DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$update + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + SI64ArrayAttr:$start, + SI64ArrayAttr:$stop, + SI64ArrayAttr:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_OPTIMIZER_OP_DEFINITIONS + +// Group: PADDING +// constant_pad1d, constant_pad1d_grad, constant_pad2d, constant_pad2d_grad, constant_pad3d, constant_pad3d_grad, pad, pad_grad, reflection_pad2d, reflection_pad2d_grad, replication_pad2d, replication_pad2d_grad, same_padding, same_padding_grad +// Total: 14 + +#ifdef GET_ONEFLOW_PADDING_OP_DEFINITIONS + +def OneFlow_ConstantPad1DOp : OneFlow_BaseOp<"constant_pad1d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + SI64ArrayAttr:$padding, + DefaultValuedAttr:$floating_value, + DefaultValuedAttr:$integral_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_ConstantPad1DGradOp : OneFlow_BaseOp<"constant_pad1d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + SI64ArrayAttr:$padding, + DefaultValuedAttr:$floating_value, + DefaultValuedAttr:$integral_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ConstantPad2DOp : OneFlow_BaseOp<"constant_pad2d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + SI64ArrayAttr:$padding, + DefaultValuedAttr:$floating_value, + DefaultValuedAttr:$integral_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_ConstantPad2DGradOp : OneFlow_BaseOp<"constant_pad2d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + SI64ArrayAttr:$padding, + 
DefaultValuedAttr:$floating_value, + DefaultValuedAttr:$integral_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ConstantPad3DOp : OneFlow_BaseOp<"constant_pad3d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + SI64ArrayAttr:$padding, + DefaultValuedAttr:$floating_value, + DefaultValuedAttr:$integral_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_ConstantPad3DGradOp : OneFlow_BaseOp<"constant_pad3d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + SI64ArrayAttr:$padding, + DefaultValuedAttr:$floating_value, + DefaultValuedAttr:$integral_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_PadOp : OneFlow_BaseOp<"pad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + SI64ArrayAttr:$padding_before, + SI64ArrayAttr:$padding_after, + DefaultValuedAttr:$floating_constant_value, + DefaultValuedAttr:$integral_constant_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_PadGradOp : OneFlow_BaseOp<"pad_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + SI64ArrayAttr:$padding_before, + SI64ArrayAttr:$padding_after, + DefaultValuedAttr:$floating_constant_value, + DefaultValuedAttr:$integral_constant_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReflectionPad2DOp : OneFlow_BaseOp<"reflection_pad2d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + SI64ArrayAttr:$padding + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_ReflectionPad2DGradOp : OneFlow_BaseOp<"reflection_pad2d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + SI64ArrayAttr:$padding + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReplicationPad2DOp : OneFlow_BaseOp<"replication_pad2d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + SI64ArrayAttr:$padding + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn 
= 1; +} + +def OneFlow_ReplicationPad2DGradOp : OneFlow_BaseOp<"replication_pad2d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + SI64ArrayAttr:$padding + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SamePaddingOp : OneFlow_BaseOp<"same_padding", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + StrAttr:$padding, + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$dilation_rate + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SamePaddingGradOp : OneFlow_BaseOp<"same_padding_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x_like, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + StrAttr:$padding, + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$dilation_rate + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_PADDING_OP_DEFINITIONS + +// Group: PARALLEL_CAST +// hierarchical_parallel_cast, hierarchical_parallel_cast_like, parallel_cast +// Total: 3 + +#ifdef GET_ONEFLOW_PARALLEL_CAST_OP_DEFINITIONS + +def OneFlow_HierarchicalParallelCastOp : OneFlow_BaseOp<"hierarchical_parallel_cast", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrArrayAttr:$nd_sbp, + StrAttr:$grad_mode, + StrArrayAttr:$grad_nd_sbp + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_HierarchicalParallelCastLikeOp : OneFlow_BaseOp<"hierarchical_parallel_cast_like", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + OneFlow_Tensor:$like + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_ParallelCastOp : OneFlow_BaseOp<"parallel_cast", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$sbp_parallel, + StrAttr:$grad_sbp_parallel + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_sbp_signature_infer_fn = 1; +} + +#endif // GET_ONEFLOW_PARALLEL_CAST_OP_DEFINITIONS + +// Group: POOL +// adaptive_avg_pool1d, adaptive_avg_pool1d_grad, adaptive_avg_pool2d, adaptive_avg_pool2d_grad, adaptive_avg_pool3d, adaptive_avg_pool3d_grad, avgpool_1d, avgpool_1d_grad, avgpool_2d, avgpool_2d_grad, avgpool_3d, avgpool_3d_grad, maxpool_1d, maxpool_1d_grad, maxpool_2d, maxpool_2d_grad, maxpool_3d, maxpool_3d_grad, tf_avg_pool_1d, tf_avg_pool_1d_grad, tf_avg_pool_2d, 
tf_avg_pool_2d_grad, tf_avg_pool_3d, tf_avg_pool_3d_grad, tf_max_pool_1d, tf_max_pool_1d_grad, tf_max_pool_2d, tf_max_pool_2d_grad, tf_max_pool_3d, tf_max_pool_3d_grad
+// Total: 30
+
+#ifdef GET_ONEFLOW_POOL_OP_DEFINITIONS
+
+def OneFlow_AdaptiveAvgPool1DOp : OneFlow_AdaptivePoolBaseOp<"adaptive_avg_pool1d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_AdaptiveAvgPool1DGradOp : OneFlow_AdaptivePoolGradBaseOp<"adaptive_avg_pool1d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_AdaptiveAvgPool2DOp : OneFlow_AdaptivePoolBaseOp<"adaptive_avg_pool2d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_AdaptiveAvgPool2DGradOp : OneFlow_AdaptivePoolGradBaseOp<"adaptive_avg_pool2d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_AdaptiveAvgPool3DOp : OneFlow_AdaptivePoolBaseOp<"adaptive_avg_pool3d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_AdaptiveAvgPool3DGradOp : OneFlow_AdaptivePoolGradBaseOp<"adaptive_avg_pool3d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_AvgPool1DOp : OneFlow_AvgPoolBaseOp<"avgpool_1d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_AvgPool1DGradOp : OneFlow_AvgPoolGradBaseOp<"avgpool_1d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_AvgPool2DOp : OneFlow_AvgPoolBaseOp<"avgpool_2d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_AvgPool2DGradOp : OneFlow_AvgPoolGradBaseOp<"avgpool_2d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_AvgPool3DOp : OneFlow_AvgPoolBaseOp<"avgpool_3d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_AvgPool3DGradOp : OneFlow_AvgPoolGradBaseOp<"avgpool_3d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_MaxPool1DOp : OneFlow_MaxPoolBaseOp<"maxpool_1d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_MaxPool1DGradOp : OneFlow_MaxPoolGradBaseOp<"maxpool_1d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_MaxPool2DOp : OneFlow_MaxPoolBaseOp<"maxpool_2d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_MaxPool2DGradOp : OneFlow_MaxPoolGradBaseOp<"maxpool_2d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_MaxPool3DOp : OneFlow_MaxPoolBaseOp<"maxpool_3d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_MaxPool3DGradOp : OneFlow_MaxPoolGradBaseOp<"maxpool_3d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_TfAvgPool1DOp : OneFlow_TFPoolBaseOp<"tf_avg_pool_1d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_TfAvgPool1DGradOp : OneFlow_TFPoolGradBaseOp<"tf_avg_pool_1d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_TfAvgPool2DOp : OneFlow_TFPoolBaseOp<"tf_avg_pool_2d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_TfAvgPool2DGradOp : OneFlow_TFPoolGradBaseOp<"tf_avg_pool_2d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_TfAvgPool3DOp : OneFlow_TFPoolBaseOp<"tf_avg_pool_3d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_TfAvgPool3DGradOp : OneFlow_TFPoolGradBaseOp<"tf_avg_pool_3d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_TfMaxPool1DOp : OneFlow_TFPoolBaseOp<"tf_max_pool_1d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_TfMaxPool1DGradOp : OneFlow_TFPoolGradBaseOp<"tf_max_pool_1d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_TfMaxPool2DOp : OneFlow_TFPoolBaseOp<"tf_max_pool_2d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_TfMaxPool2DGradOp : OneFlow_TFPoolGradBaseOp<"tf_max_pool_2d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_TfMaxPool3DOp : OneFlow_TFPoolBaseOp<"tf_max_pool_3d", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+def OneFlow_TfMaxPool3DGradOp : OneFlow_TFPoolGradBaseOp<"tf_max_pool_3d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {}
+
+#endif // GET_ONEFLOW_POOL_OP_DEFINITIONS
+
+// Group: QUANTIZATION
+// fake_quantization, min_max_observer, moving_average_min_max_observer, quantization
+// Total: 4
+
+#ifdef GET_ONEFLOW_QUANTIZATION_OP_DEFINITIONS
+
+def OneFlow_FakeQuantizationOp : OneFlow_BaseOp<"fake_quantization", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    OneFlow_Tensor:$in,
+    OneFlow_Tensor:$scale,
+    OneFlow_Tensor:$zero_point
+  );
+  let output = (outs
+    OneFlow_Tensor:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$quantization_formula,
+    DefaultValuedAttr:$quantization_bit,
+    DefaultValuedAttr:$quantization_scheme
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_MinMaxObserverOp : OneFlow_BaseOp<"min_max_observer", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    OneFlow_Tensor:$in
+  );
+  let output = (outs
+    OneFlow_Tensor:$scale,
+    OneFlow_Tensor:$zero_point
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$quantization_formula,
+    DefaultValuedAttr:$quantization_bit,
+    DefaultValuedAttr:$quantization_scheme,
+    DefaultValuedAttr:$per_layer_quantization
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_MovingAverageMinMaxObserverOp : OneFlow_BaseOp<"moving_average_min_max_observer", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    OneFlow_Tensor:$in,
+    OneFlow_Tensor:$current_train_step,
+    OneFlow_Tensor:$moving_max,
+    OneFlow_Tensor:$moving_min
+  );
+  let output = (outs
+    OneFlow_Tensor:$scale,
+    OneFlow_Tensor:$zero_point
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$training,
+    DefaultValuedAttr:$quantization_formula,
+    DefaultValuedAttr:$stop_update_after_iters,
+    DefaultValuedAttr:$quantization_bit,
+    DefaultValuedAttr:$quantization_scheme,
+    DefaultValuedAttr:$momentum
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_QuantizationOp : OneFlow_BaseOp<"quantization", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    OneFlow_Tensor:$in,
+    OneFlow_Tensor:$scale,
+    OneFlow_Tensor:$zero_point
+  );
+  let output = (outs
+    OneFlow_Tensor:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$quantization_formula,
+    DefaultValuedAttr:$quantization_bit,
+    DefaultValuedAttr:$quantization_scheme
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+#endif // GET_ONEFLOW_QUANTIZATION_OP_DEFINITIONS
+
+// Group: REDUCE
+// indexed_slices_reduce_sum, reduce_all, reduce_any, reduce_max, reduce_max_device_stage, reduce_max_device_stage_grad,
reduce_max_global_stage, reduce_max_global_stage_grad, reduce_min, reduce_min_device_stage, reduce_min_device_stage_grad, reduce_min_global_stage, reduce_min_global_stage_grad, reduce_prod, reduce_sum, reduce_sum_like +// Total: 16 + +#ifdef GET_ONEFLOW_REDUCE_OP_DEFINITIONS + +def OneFlow_IndexedSlicesReduceSumOp : OneFlow_BaseOp<"indexed_slices_reduce_sum", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x_indices, + OneFlow_Tensor:$x_values + ); + let output = (outs + OneFlow_Tensor:$y_indices, + OneFlow_Tensor:$y_values, + OneFlow_Tensor:$num_unique + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReduceAllOp : OneFlow_BaseOp<"reduce_all", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input_tensor + ); + let output = (outs + OneFlow_Tensor:$output_tensor + ); + let attrs = (ins + SI32ArrayAttr:$axis, + DefaultValuedAttr:$keepdims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReduceAnyOp : OneFlow_BaseOp<"reduce_any", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input_tensor + ); + let output = (outs + OneFlow_Tensor:$output_tensor + ); + let attrs = (ins + SI32ArrayAttr:$axis, + DefaultValuedAttr:$keepdims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReduceMaxOp : OneFlow_BaseOp<"reduce_max", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input_tensor + ); + let output = (outs + OneFlow_Tensor:$output_tensor + ); + let attrs = (ins + SI32ArrayAttr:$axis, + DefaultValuedAttr:$keepdims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReduceMaxDeviceStageOp : OneFlow_BaseOp<"reduce_max_device_stage", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out, + OneFlow_Tensor:$mask, + OneFlow_Tensor:$count + ); + let attrs = (ins + SI32ArrayAttr:$axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReduceMaxDeviceStageGradOp : OneFlow_BaseOp<"reduce_max_device_stage_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$out_diff, + OneFlow_Tensor:$mask, + OneFlow_Tensor:$count + ); + let output = (outs + OneFlow_Tensor:$in_diff + ); + let attrs = (ins + SI32ArrayAttr:$axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReduceMaxGlobalStageOp : OneFlow_BaseOp<"reduce_max_global_stage", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + OneFlow_Tensor:$device_count + ); + let output = (outs + OneFlow_Tensor:$out, + OneFlow_Tensor:$mask + ); + let attrs = (ins + SI32ArrayAttr:$axis, + DefaultValuedAttr:$keepdims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let 
has_input_arg_modify_fn = 1; +} + +def OneFlow_ReduceMaxGlobalStageGradOp : OneFlow_BaseOp<"reduce_max_global_stage_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$out_diff, + OneFlow_Tensor:$mask, + OneFlow_Tensor:$device_count + ); + let output = (outs + OneFlow_Tensor:$in_diff + ); + let attrs = (ins + SI32ArrayAttr:$axis, + DefaultValuedAttr:$keepdims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReduceMinOp : OneFlow_BaseOp<"reduce_min", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input_tensor + ); + let output = (outs + OneFlow_Tensor:$output_tensor + ); + let attrs = (ins + SI32ArrayAttr:$axis, + DefaultValuedAttr:$keepdims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReduceMinDeviceStageOp : OneFlow_BaseOp<"reduce_min_device_stage", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out, + OneFlow_Tensor:$mask, + OneFlow_Tensor:$count + ); + let attrs = (ins + SI32ArrayAttr:$axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReduceMinDeviceStageGradOp : OneFlow_BaseOp<"reduce_min_device_stage_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$out_diff, + OneFlow_Tensor:$mask, + OneFlow_Tensor:$count + ); + let output = (outs + OneFlow_Tensor:$in_diff + ); + let attrs = (ins + SI32ArrayAttr:$axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReduceMinGlobalStageOp : OneFlow_BaseOp<"reduce_min_global_stage", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + OneFlow_Tensor:$device_count + ); + let output = (outs + OneFlow_Tensor:$out, + OneFlow_Tensor:$mask + ); + let attrs = (ins + SI32ArrayAttr:$axis, + DefaultValuedAttr:$keepdims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_ReduceMinGlobalStageGradOp : OneFlow_BaseOp<"reduce_min_global_stage_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$out_diff, + OneFlow_Tensor:$mask, + OneFlow_Tensor:$device_count + ); + let output = (outs + OneFlow_Tensor:$in_diff + ); + let attrs = (ins + SI32ArrayAttr:$axis, + DefaultValuedAttr:$keepdims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReduceProdOp : OneFlow_BaseOp<"reduce_prod", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input_tensor + ); + let output = (outs + OneFlow_Tensor:$output_tensor + ); + let attrs = (ins + SI32ArrayAttr:$axis, + DefaultValuedAttr:$keepdims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReduceSumOp : OneFlow_BaseOp<"reduce_sum", [NoSideEffect, 
DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input_tensor + ); + let output = (outs + OneFlow_Tensor:$output_tensor + ); + let attrs = (ins + SI32ArrayAttr:$axis, + DefaultValuedAttr:$keepdims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReduceSumLikeOp : OneFlow_BaseOp<"reduce_sum_like", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$like + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + SI32ArrayAttr:$axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +#endif // GET_ONEFLOW_REDUCE_OP_DEFINITIONS + +// Group: RESHAPE +// reshape, reshape_like +// Total: 2 + +#ifdef GET_ONEFLOW_RESHAPE_OP_DEFINITIONS + +def OneFlow_ReshapeOp : OneFlow_BaseOp<"reshape", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_ReshapeLikeOp : OneFlow_BaseOp<"reshape_like", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + OneFlow_Tensor:$like + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +#endif // GET_ONEFLOW_RESHAPE_OP_DEFINITIONS + +// Group: SCALAR +// clip_by_scalar, clip_by_scalar_grad, clip_by_scalar_max, clip_by_scalar_max_grad, clip_by_scalar_min, clip_by_scalar_min_grad, scalar_add, scalar_add_by_tensor, scalar_div_by_tensor, scalar_floordiv, scalar_fmod, scalar_logical_and, scalar_logical_equal, scalar_logical_greater, scalar_logical_greater_equal, scalar_logical_less, scalar_logical_less_equal, scalar_logical_not_equal, scalar_logical_or, scalar_logical_xor, scalar_mul, scalar_mul_by_tensor, scalar_pow, scalar_pow_grad, scalar_sub_by_tensor +// Total: 25 + +#ifdef GET_ONEFLOW_SCALAR_OP_DEFINITIONS + +def OneFlow_ClipByScalarOp : OneFlow_BaseOp<"clip_by_scalar", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$floating_min, + DefaultValuedAttr:$integral_min, + DefaultValuedAttr:$floating_max, + DefaultValuedAttr:$integral_max + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ClipByScalarGradOp : OneFlow_BaseOp<"clip_by_scalar_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$floating_min, + DefaultValuedAttr:$integral_min, + DefaultValuedAttr:$floating_max, + DefaultValuedAttr:$integral_max + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def 
OneFlow_ClipByScalarMaxOp : OneFlow_BaseOp<"clip_by_scalar_max", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$floating_max, + DefaultValuedAttr:$integral_max + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ClipByScalarMaxGradOp : OneFlow_BaseOp<"clip_by_scalar_max_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$floating_max, + DefaultValuedAttr:$integral_max + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ClipByScalarMinOp : OneFlow_BaseOp<"clip_by_scalar_min", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$floating_min, + DefaultValuedAttr:$integral_min + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ClipByScalarMinGradOp : OneFlow_BaseOp<"clip_by_scalar_min_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$floating_min, + DefaultValuedAttr:$integral_min + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarAddOp : OneFlow_BaseOp<"scalar_add", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarAddByTensorOp : OneFlow_BaseOp<"scalar_add_by_tensor", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$scalar + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarDivByTensorOp : OneFlow_BaseOp<"scalar_div_by_tensor", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$scalar + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarFloordivOp : OneFlow_BaseOp<"scalar_floordiv", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let 
has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarFmodOp : OneFlow_BaseOp<"scalar_fmod", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarLogicalAndOp : OneFlow_BaseOp<"scalar_logical_and", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarLogicalEqualOp : OneFlow_BaseOp<"scalar_logical_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarLogicalGreaterOp : OneFlow_BaseOp<"scalar_logical_greater", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarLogicalGreaterEqualOp : OneFlow_BaseOp<"scalar_logical_greater_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarLogicalLessOp : OneFlow_BaseOp<"scalar_logical_less", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarLogicalLessEqualOp : OneFlow_BaseOp<"scalar_logical_less_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let 
input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarLogicalNotEqualOp : OneFlow_BaseOp<"scalar_logical_not_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarLogicalOrOp : OneFlow_BaseOp<"scalar_logical_or", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarLogicalXorOp : OneFlow_BaseOp<"scalar_logical_xor", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarMulOp : OneFlow_BaseOp<"scalar_mul", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarMulByTensorOp : OneFlow_BaseOp<"scalar_mul_by_tensor", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$scalar + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarPowOp : OneFlow_BaseOp<"scalar_pow", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarPowGradOp : OneFlow_BaseOp<"scalar_pow_grad", [NoSideEffect, 
DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScalarSubByTensorOp : OneFlow_BaseOp<"scalar_sub_by_tensor", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$scalar + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_SCALAR_OP_DEFINITIONS + +// Group: SOFTMAX +// log_softmax, log_softmax_grad, softmax, softmax_cross_entropy, softmax_cross_entropy_grad, softmax_grad, sparse_softmax_cross_entropy, sparse_softmax_cross_entropy_grad, sparse_softmax_cross_entropy_ms, sparse_softmax_cross_entropy_ms_grad +// Total: 10 + +#ifdef GET_ONEFLOW_SOFTMAX_OP_DEFINITIONS + +def OneFlow_LogSoftmaxOp : OneFlow_BaseOp<"log_softmax", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$prob + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LogSoftmaxGradOp : OneFlow_BaseOp<"log_softmax_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$prob, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SoftmaxOp : OneFlow_BaseOp<"softmax", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SoftmaxCrossEntropyOp : OneFlow_BaseOp<"softmax_cross_entropy", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$prediction, + OneFlow_Tensor:$label + ); + let output = (outs + OneFlow_Tensor:$prob, + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SoftmaxCrossEntropyGradOp : OneFlow_BaseOp<"softmax_cross_entropy_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$label, + OneFlow_Tensor:$prob + ); + let output = (outs + OneFlow_Tensor:$prediction_diff + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SoftmaxGradOp : OneFlow_BaseOp<"softmax_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$y, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let 
has_data_type_infer_fn = 1; +} + +def OneFlow_SparseSoftmaxCrossEntropyOp : OneFlow_BaseOp<"sparse_softmax_cross_entropy", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$prediction, + OneFlow_Tensor:$label + ); + let output = (outs + OneFlow_Tensor:$prob, + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SparseSoftmaxCrossEntropyGradOp : OneFlow_BaseOp<"sparse_softmax_cross_entropy_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$label, + OneFlow_Tensor:$dy, + OneFlow_Tensor:$prob + ); + let output = (outs + OneFlow_Tensor:$prediction_diff + ); + let attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SparseSoftmaxCrossEntropyMsOp : OneFlow_BaseOp<"sparse_softmax_cross_entropy_ms", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$prediction, + OneFlow_Tensor:$label + ); + let output = (outs + OneFlow_Tensor:$prob, + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SparseSoftmaxCrossEntropyMsGradOp : OneFlow_BaseOp<"sparse_softmax_cross_entropy_ms_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$label, + OneFlow_Tensor:$dy, + OneFlow_Tensor:$prob + ); + let output = (outs + OneFlow_Tensor:$prediction_diff + ); + let attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_SOFTMAX_OP_DEFINITIONS + +// Group: SUMMARY +// create_summary_writer, flush_summary_writer, summary_write_histogram, summary_write_image, summary_write_pb, summary_write_scalar +// Total: 6 + +#ifdef GET_ONEFLOW_SUMMARY_OP_DEFINITIONS + +def OneFlow_CreateSummaryWriterOp : OneFlow_BaseOp<"create_summary_writer", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let attrs = (ins + StrAttr:$logdir + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FlushSummaryWriterOp : OneFlow_BaseOp<"flush_summary_writer", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SummaryWriteHistogramOp : OneFlow_BaseOp<"summary_write_histogram", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + OneFlow_Tensor:$step, + OneFlow_Tensor:$tag + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SummaryWriteImageOp : OneFlow_BaseOp<"summary_write_image", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + 
OneFlow_Tensor:$step, + OneFlow_Tensor:$tag + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SummaryWritePbOp : OneFlow_BaseOp<"summary_write_pb", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + OneFlow_Tensor:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SummaryWriteScalarOp : OneFlow_BaseOp<"summary_write_scalar", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in, + OneFlow_Tensor:$step, + OneFlow_Tensor:$tag + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_SUMMARY_OP_DEFINITIONS + +// Group: TENSOR_BUFFER +// gen_tensor_buffer, tensor_buffer_to_list_of_tensors, tensor_buffer_to_list_of_tensors_v2, tensor_buffer_to_tensor, tensor_to_tensor_buffer +// Total: 5 + +#ifdef GET_ONEFLOW_TENSOR_BUFFER_OP_DEFINITIONS + +def OneFlow_GenTensorBufferOp : OneFlow_BaseOp<"gen_tensor_buffer", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + ShapeAttr:$shape, + ShapeArrayAttr:$shape_list, + F32ArrayAttr:$value_list, + OneFlow_DataType:$data_type, + DefaultValuedAttr:$dynamic_out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TensorBufferToListOfTensorsOp : OneFlow_BaseOp<"tensor_buffer_to_list_of_tensors", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + Variadic:$out + ); + let attrs = (ins + ShapeAttr:$out_shape, + OneFlow_DataType:$out_dtype, + DefaultValuedAttr:$dynamic_out + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_arg_modify_fn = 1; +} + +def OneFlow_TensorBufferToListOfTensorsV2Op : OneFlow_BaseOp<"tensor_buffer_to_list_of_tensors_v2", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + Variadic:$out + ); + let attrs = (ins + ShapeArrayAttr:$out_shapes, + DTArrayAttr:$out_dtypes, + DefaultValuedAttr:$dynamic_out + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_arg_modify_fn = 1; +} + +def OneFlow_TensorBufferToTensorOp : OneFlow_BaseOp<"tensor_buffer_to_tensor", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + ShapeAttr:$instance_shape, + OneFlow_DataType:$dtype + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TensorToTensorBufferOp : OneFlow_BaseOp<"tensor_to_tensor_buffer", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + 
OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$instance_dims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_TENSOR_BUFFER_OP_DEFINITIONS + +// Group: TEST +// TestDataTypeAttr, TestDynamicSource, TestListDataTypeAndListShapeAndListStringAttr, TestMultiInput, TestMultiInputGrad, TestMultiOutputOrder, TestRandomSource, TestReshape, TestSource, TestSourceMultiGpuFixedOutNum, ccrelu, ccrelu_grad, cpu_only_relu_test, test_user_op_attr_auto_type +// Total: 14 + +#ifdef GET_ONEFLOW_TEST_OP_DEFINITIONS + +def OneFlow_TestDataTypeAttrOp : OneFlow_BaseOp<"TestDataTypeAttr", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + OneFlow_DataType:$output_type + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestDynamicSourceOp : OneFlow_BaseOp<"TestDynamicSource", [NoSideEffect, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_arg_modify_fn = 1; +} + +def OneFlow_TestListDataTypeAndListShapeAndListStringAttrOp : OneFlow_BaseOp<"TestListDataTypeAndListShapeAndListStringAttr", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + ShapeArrayAttr:$out_shapes, + DTArrayAttr:$out_types, + StrArrayAttr:$string_list + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestMultiInputOp : OneFlow_BaseOp<"TestMultiInput", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x1, + OneFlow_Tensor:$x2 + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestMultiInputGradOp : OneFlow_BaseOp<"TestMultiInputGrad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x1, + OneFlow_Tensor:$x2, + OneFlow_Tensor:$y_diff + ); + let output = (outs + OneFlow_Tensor:$x1_diff, + OneFlow_Tensor:$x2_diff + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestMultiOutputOrderOp : OneFlow_BaseOp<"TestMultiOutputOrder", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out1, + OneFlow_Tensor:$out2 + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestRandomSourceOp : OneFlow_BaseOp<"TestRandomSource", [NoSideEffect, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$seed + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def 
OneFlow_TestReshapeOp : OneFlow_BaseOp<"TestReshape", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestSourceOp : OneFlow_BaseOp<"TestSource", [NoSideEffect, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestSourceMultiGpuFixedOutNumOp : OneFlow_BaseOp<"TestSourceMultiGpuFixedOutNum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$out_num + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CcreluOp : OneFlow_BaseOp<"ccrelu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CcreluGradOp : OneFlow_BaseOp<"ccrelu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$y, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CpuOnlyReluTestOp : OneFlow_BaseOp<"cpu_only_relu_test", [NoSideEffect, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestUserOpAttrAutoTypeOp : OneFlow_BaseOp<"test_user_op_attr_auto_type", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$int1, + DefaultValuedAttr:$int2 + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_TEST_OP_DEFINITIONS + +// Group: TRIGONOMETRIC +// acos, acos_grad, acosh, acosh_grad, asin, asin_grad, asinh, asinh_grad, atan, atan2, atan2_x_grad, atan2_y_grad, atan_grad, atanh, atanh_grad, cos, cos_grad, cosh, cosh_grad, hardtanh, hardtanh_grad, sin, sin_grad, sinh, sinh_grad, tan, tan_grad, tanh, tanh_grad +// Total: 29 + +#ifdef GET_ONEFLOW_TRIGONOMETRIC_OP_DEFINITIONS + +def OneFlow_AcosOp : OneFlow_BaseOp<"acos", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AcosGradOp : OneFlow_BaseOp<"acos_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + 
OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AcoshOp : OneFlow_BaseOp<"acosh", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AcoshGradOp : OneFlow_BaseOp<"acosh_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AsinOp : OneFlow_BaseOp<"asin", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AsinGradOp : OneFlow_BaseOp<"asin_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AsinhOp : OneFlow_BaseOp<"asinh", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AsinhGradOp : OneFlow_BaseOp<"asinh_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AtanOp : OneFlow_BaseOp<"atan", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Atan2Op : OneFlow_BaseOp<"atan2", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y + ); + let output = (outs + OneFlow_Tensor:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Atan2XGradOp : OneFlow_BaseOp<"atan2_x_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y, + OneFlow_Tensor:$dz + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Atan2YGradOp : OneFlow_BaseOp<"atan2_y_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$y, + OneFlow_Tensor:$dz 
+ ); + let output = (outs + OneFlow_Tensor:$dy + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AtanGradOp : OneFlow_BaseOp<"atan_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AtanhOp : OneFlow_BaseOp<"atanh", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AtanhGradOp : OneFlow_BaseOp<"atanh_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CosOp : OneFlow_BaseOp<"cos", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CosGradOp : OneFlow_BaseOp<"cos_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CoshOp : OneFlow_BaseOp<"cosh", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CoshGradOp : OneFlow_BaseOp<"cosh_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_HardtanhOp : OneFlow_BaseOp<"hardtanh", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$min_val, + DefaultValuedAttr:$max_val + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_HardtanhGradOp : OneFlow_BaseOp<"hardtanh_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$y, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$min_val, + DefaultValuedAttr:$max_val + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} 
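The elementwise definitions in this group all follow one ODS pattern: a single tensor operand, a single tensor result, and the four has_*_fn flags requesting the logical/physical tensor-desc, SBP and data-type inference hooks. A minimal C++ usage sketch, assuming the standard collective ODS builder, the generated y() result accessor, and that the dialect's required base-op attributes are supplied in attrs:

    #include "OneFlow/OneFlowOps.h"  // generated op classes, e.g. mlir::oneflow::TanhOp
    #include "mlir/IR/Builders.h"

    // Build a oneflow.tanh op whose result type mirrors its operand type.
    mlir::Value buildTanh(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value x,
                          llvm::ArrayRef<mlir::NamedAttribute> attrs) {
      auto op = builder.create<mlir::oneflow::TanhOp>(
          loc, /*resultTypes=*/mlir::TypeRange{x.getType()},
          /*operands=*/mlir::ValueRange{x}, attrs);
      return op.y();
    }
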
+ +def OneFlow_SinOp : OneFlow_BaseOp<"sin", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SinGradOp : OneFlow_BaseOp<"sin_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SinhOp : OneFlow_BaseOp<"sinh", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SinhGradOp : OneFlow_BaseOp<"sinh_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TanOp : OneFlow_BaseOp<"tan", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TanGradOp : OneFlow_BaseOp<"tan_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TanhOp : OneFlow_BaseOp<"tanh", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TanhGradOp : OneFlow_BaseOp<"tanh_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_TRIGONOMETRIC_OP_DEFINITIONS + +// Group: UNARY +// acc, affine_grid, affine_grid_grad, bernoulli, cast, cast_to_static_shape, cast_to_tick, celu, copy, count_not_finite, diag, diagonal, elu, expand, expand_dims, flatten, flip, flip_grad, fold, gelu, hardsigmoid, hardswish, leaky_relu, log2, logical_not, mish, narrow, one_hot, pack, random_mask_like, repeat, roll, selu, sigmoid, silu, softsign, sort, square_sum, squeeze, transpose, tril, triu, unfold, unfold_tensor, unpack, zero_like +// Total: 46 + +#ifdef GET_ONEFLOW_UNARY_OP_DEFINITIONS + +def OneFlow_AccOp : OneFlow_BaseOp<"acc", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let 
attrs = (ins + DefaultValuedAttr:$max_acc_num + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_blob_time_shape_infer_fn = 1; +} + +def OneFlow_AffineGridOp : OneFlow_BaseOp<"affine_grid", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$theta + ); + let output = (outs + OneFlow_Tensor:$grid + ); + let attrs = (ins + ShapeAttr:$size, + DefaultValuedAttr:$align_corners + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AffineGridGradOp : OneFlow_BaseOp<"affine_grid_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dgrid + ); + let output = (outs + OneFlow_Tensor:$dtheta + ); + let attrs = (ins + ShapeAttr:$size, + DefaultValuedAttr:$align_corners + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BernoulliOp : OneFlow_BaseOp<"bernoulli", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$seed, + DefaultValuedAttr:$has_seed, + OneFlow_DataType:$dtype + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CastOp : OneFlow_BaseOp<"cast", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + OneFlow_DataType:$dtype + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CastToStaticShapeOp : OneFlow_BaseOp<"cast_to_static_shape", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input + ); + let output = (outs + OneFlow_Tensor:$output + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CastToTickOp : OneFlow_BaseOp<"cast_to_tick", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_CeluOp : OneFlow_BaseOp<"celu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CopyOp : OneFlow_BaseOp<"copy", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$device_type, + DefaultValuedAttr:$device_id + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let 
has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; +} + +def OneFlow_CountNotFiniteOp : OneFlow_BaseOp<"count_not_finite", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DiagOp : OneFlow_BaseOp<"diag", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$diagonal + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DiagonalOp : OneFlow_BaseOp<"diagonal", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$offset + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_EluOp : OneFlow_BaseOp<"elu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ExpandOp : OneFlow_BaseOp<"expand", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + SI32ArrayAttr:$logical_in_shape, + SI32ArrayAttr:$logical_expand_shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ExpandDimsOp : OneFlow_BaseOp<"expand_dims", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FlattenOp : OneFlow_BaseOp<"flatten", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$start_dim, + DefaultValuedAttr:$end_dim + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FlipOp : OneFlow_BaseOp<"flip", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + SI32ArrayAttr:$dims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FlipGradOp : OneFlow_BaseOp<"flip_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + SI32ArrayAttr:$dims + ); + let has_logical_tensor_desc_infer_fn = 1; 
+ let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FoldOp : OneFlow_BaseOp<"fold", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + StrAttr:$data_format, + SI32ArrayAttr:$output_size, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$padding, + SI32ArrayAttr:$dilation_rate + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_GeluOp : OneFlow_BaseOp<"gelu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_HardsigmoidOp : OneFlow_BaseOp<"hardsigmoid", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_HardswishOp : OneFlow_BaseOp<"hardswish", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LeakyReluOp : OneFlow_BaseOp<"leaky_relu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Log2Op : OneFlow_BaseOp<"log2", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LogicalNotOp : OneFlow_BaseOp<"logical_not", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_MishOp : OneFlow_BaseOp<"mish", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_NarrowOp : OneFlow_BaseOp<"narrow", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$dim, + DefaultValuedAttr:$start, + DefaultValuedAttr:$length + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def 
OneFlow_OneHotOp : OneFlow_BaseOp<"one_hot", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$indices + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$depth, + DefaultValuedAttr:$floating_on_value, + DefaultValuedAttr:$integer_on_value, + DefaultValuedAttr:$floating_off_value, + DefaultValuedAttr:$integer_off_value, + OneFlow_DataType:$dtype + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_PackOp : OneFlow_BaseOp<"pack", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$pack_num + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_blob_time_shape_infer_fn = 1; +} + +def OneFlow_RandomMaskLikeOp : OneFlow_BaseOp<"random_mask_like", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$like + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$rate, + DefaultValuedAttr:$seed + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_RepeatOp : OneFlow_BaseOp<"repeat", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$repeat_num + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_blob_time_shape_infer_fn = 1; +} + +def OneFlow_RollOp : OneFlow_BaseOp<"roll", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + SI32ArrayAttr:$shifts, + SI32ArrayAttr:$dims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SeluOp : OneFlow_BaseOp<"selu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SigmoidOp : OneFlow_BaseOp<"sigmoid", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SiluOp : OneFlow_BaseOp<"silu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SoftsignOp : OneFlow_BaseOp<"softsign", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let 
output = (outs + OneFlow_Tensor:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SortOp : OneFlow_BaseOp<"sort", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + StrAttr:$direction + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SquareSumOp : OneFlow_BaseOp<"square_sum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SqrtSquareSumOp : OneFlow_BaseOp<"sqrt_square_sum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SqueezeOp : OneFlow_BaseOp<"squeeze", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + SI32ArrayAttr:$axes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TransposeOp : OneFlow_BaseOp<"transpose", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$input + ); + let output = (outs + OneFlow_Tensor:$output + ); + let attrs = (ins + SI32ArrayAttr:$perm + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TrilOp : OneFlow_BaseOp<"tril", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$diagonal, + DefaultValuedAttr:$floating_fill_value, + DefaultValuedAttr:$integer_fill_value, + DefaultValuedAttr:$is_floating_fill_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TriuOp : OneFlow_BaseOp<"triu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$diagonal + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UnfoldOp : OneFlow_BaseOp<"unfold", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$padding, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$dilation_rate + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UnfoldTensorOp : 
OneFlow_BaseOp<"unfold_tensor", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$dimension, + DefaultValuedAttr:$size, + DefaultValuedAttr:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UnpackOp : OneFlow_BaseOp<"unpack", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$in + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$unpack_num + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_blob_time_shape_infer_fn = 1; +} + +def OneFlow_ZeroLikeOp : OneFlow_BaseOp<"zero_like", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$like + ); + let output = (outs + OneFlow_Tensor:$out + ); + let same_output_regst_num = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_UNARY_OP_DEFINITIONS + +// Group: UPSAMPLE +// upsample, upsample_bicubic_2d, upsample_bicubic_2d_grad, upsample_bilinear_2d, upsample_bilinear_2d_grad, upsample_grad, upsample_linear_1d, upsample_linear_1d_grad, upsample_nearest_1d, upsample_nearest_1d_grad, upsample_nearest_2d, upsample_nearest_2d_grad, upsample_nearest_3d, upsample_nearest_3d_grad, upsample_trilinear_3d, upsample_trilinear_3d_grad +// Total: 16 + +#ifdef GET_ONEFLOW_UPSAMPLE_OP_DEFINITIONS + +def OneFlow_UpsampleOp : OneFlow_BaseOp<"upsample", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format, + StrAttr:$interpolation + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleBicubic2DOp : OneFlow_BaseOp<"upsample_bicubic_2d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleBicubic2DGradOp : OneFlow_BaseOp<"upsample_bicubic_2d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleBilinear2DOp : OneFlow_BaseOp<"upsample_bilinear_2d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs 
= (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleBilinear2DGradOp : OneFlow_BaseOp<"upsample_bilinear_2d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleGradOp : OneFlow_BaseOp<"upsample_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format, + StrAttr:$interpolation + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleLinear1DOp : OneFlow_BaseOp<"upsample_linear_1d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$scale_factor, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleLinear1DGradOp : OneFlow_BaseOp<"upsample_linear_1d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$scale_factor, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleNearest1DOp : OneFlow_BaseOp<"upsample_nearest_1d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$scale_factor, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleNearest1DGradOp : OneFlow_BaseOp<"upsample_nearest_1d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$scale_factor, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleNearest2DOp : OneFlow_BaseOp<"upsample_nearest_2d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + 
DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleNearest2DGradOp : OneFlow_BaseOp<"upsample_nearest_2d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleNearest3DOp : OneFlow_BaseOp<"upsample_nearest_3d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$depth_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleNearest3DGradOp : OneFlow_BaseOp<"upsample_nearest_3d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$depth_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleTrilinear3DOp : OneFlow_BaseOp<"upsample_trilinear_3d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$y + ); + let attrs = (ins + DefaultValuedAttr:$depth_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleTrilinear3DGradOp : OneFlow_BaseOp<"upsample_trilinear_3d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$dy, + OneFlow_Tensor:$x + ); + let output = (outs + OneFlow_Tensor:$dx + ); + let attrs = (ins + DefaultValuedAttr:$depth_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_UPSAMPLE_OP_DEFINITIONS diff --git a/oneflow/ir/install-llvm.cmake b/oneflow/ir/install-llvm.cmake new file mode 100644 index 00000000000..6e0953ae83d --- /dev/null +++ b/oneflow/ir/install-llvm.cmake @@ -0,0 +1,86 @@ +message("-- LLVM_MONO_REPO_URL: " ${LLVM_MONO_REPO_URL}) +message("-- LLVM_MONO_REPO_MD5: " ${LLVM_MONO_REPO_MD5}) +FetchContent_Declare( + llvm_monorepo +) +FetchContent_GetProperties(llvm_monorepo) + +if(NOT llvm_monorepo_POPULATED) + FetchContent_Populate(llvm_monorepo + URL ${LLVM_MONO_REPO_URL} + URL_HASH MD5=${LLVM_MONO_REPO_MD5} + ) + 
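+ # NOTE: the FetchContent_* commands above rely on include(FetchContent) having been done
+ # by the including scope. This file then configures, builds and installs LLVM/MLIR as a
+ # separate out-of-tree CMake project (see the execute_process invocations below) into
+ # ${THIRD_PARTY_DIR}/llvm; llvm-in-tree.cmake, added later in this patch, is the in-tree
+ # alternative that pulls the LLVM sources in with add_subdirectory instead.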
set(LLVM_INSTALL_DIR ${THIRD_PARTY_DIR}/llvm) + + execute_process(COMMAND "${CMAKE_COMMAND}" ${llvm_monorepo_SOURCE_DIR}/llvm + -DCMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER} + -DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER} + -DCMAKE_CUDA_COMPILER_LAUNCHER=${CMAKE_CUDA_COMPILER_LAUNCHER} + -DCMAKE_EXE_LINKER_FLAGS_INIT=${CMAKE_EXE_LINKER_FLAGS_INIT} + -DCMAKE_MODULE_LINKER_FLAGS_INIT=${CMAKE_MODULE_LINKER_FLAGS_INIT} + -DCMAKE_SHARED_LINKER_FLAGS_INIT=${CMAKE_SHARED_LINKER_FLAGS_INIT} + -DCMAKE_INSTALL_PREFIX=${LLVM_INSTALL_DIR} + -DCMAKE_INSTALL_MESSAGE=${CMAKE_INSTALL_MESSAGE} + -DLLVM_ENABLE_RTTI=ON # turn this on to make it compatible with protobuf + -DLLVM_ENABLE_EH=ON # turn this on to make it compatible with half (the library) + -DLLVM_BUILD_EXAMPLES=OFF + -DLLVM_BUILD_TOOLS=OFF + -DLLVM_INCLUDE_EXAMPLES=OFF + -DLLVM_INCLUDE_TESTS=OFF + -DLLVM_INCLUDE_BENCHMARKS=OFF + -DLLVM_TARGETS_TO_BUILD=host\;NVPTX + -DLLVM_ENABLE_ASSERTIONS=ON + -DLLVM_ENABLE_PROJECTS=mlir + -DLLVM_APPEND_VC_REV=OFF + -DLLVM_ENABLE_ZLIB=OFF + -DLLVM_INSTALL_UTILS=ON + -DBUILD_SHARED_LIBS=${BUILD_SHARED_LIBS} + -DLLVM_ENABLE_OCAMLDOC=OFF + -DLLVM_ENABLE_BINDINGS=OFF + -DMLIR_ENABLE_CUDA_RUNNER=${WITH_MLIR_CUDA_CODEGEN} + -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER} + -DINJA_URL=${INJA_URL} + -DINJA_URL_HASH=${INJA_URL_HASH} + -DJSON_URL=${JSON_URL} + -DJSON_URL_HASH=${JSON_URL_HASH} + -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER} + -DLLVM_EXTERNAL_PROJECTS=OneFlowTableGen + -DLLVM_EXTERNAL_ONEFLOWTABLEGEN_SOURCE_DIR=${CMAKE_SOURCE_DIR}/tools/oneflow-tblgen + -G ${CMAKE_GENERATOR} + WORKING_DIRECTORY ${llvm_monorepo_BINARY_DIR} + RESULT_VARIABLE ret) + if(ret EQUAL "1") + message( FATAL_ERROR "Bad exit status") + endif() + include(ProcessorCount) + ProcessorCount(PROC_NUM) + if(WITH_MLIR) + set(INSTALL_ALL "install") + endif() + execute_process(COMMAND "${CMAKE_COMMAND}" --build . 
-j${PROC_NUM} --target ${INSTALL_ALL} install-oneflow-tblgen install-mlir-headers + WORKING_DIRECTORY ${llvm_monorepo_BINARY_DIR} + RESULT_VARIABLE ret + ) + if(ret EQUAL "1") + message( FATAL_ERROR "Bad exit status") + endif() +endif() + +if (WITH_MLIR) +set(LLVM_DIR ${LLVM_INSTALL_DIR}/lib/cmake/llvm) +set(MLIR_DIR ${LLVM_INSTALL_DIR}/lib/cmake/mlir) +find_package(MLIR REQUIRED CONFIG) + +message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") +message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") + +set(MLIR_BINARY_DIR ${llvm_monorepo_BINARY_DIR}) + +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +include(TableGen) +include(AddLLVM) +include(AddMLIR) +include(HandleLLVMOptions) +set(LLVM_EXTERNAL_LIT "${llvm_monorepo_BINARY_DIR}/bin/llvm-lit" CACHE STRING "" FORCE) +endif() diff --git a/oneflow/ir/lib/OneFlow/CMakeLists.txt b/oneflow/ir/lib/OneFlow/CMakeLists.txt index e090b742a3c..b9f5a0e0a59 100644 --- a/oneflow/ir/lib/OneFlow/CMakeLists.txt +++ b/oneflow/ir/lib/OneFlow/CMakeLists.txt @@ -33,7 +33,6 @@ oneflow_add_mlir_dialect_library(MLIROneFlow DEPENDS MLIROneFlowOpsIncGen prepare_oneflow_third_party - oneflow-gen-ods LINK_LIBS PUBLIC ${dialect_libs} diff --git a/oneflow/ir/lib/OneFlow/OneFlowOpGetGen.cpp.in b/oneflow/ir/lib/OneFlow/OneFlowOpGetGen.cpp.in index 79bdaba8ecb..0a98a691285 100644 --- a/oneflow/ir/lib/OneFlow/OneFlowOpGetGen.cpp.in +++ b/oneflow/ir/lib/OneFlow/OneFlowOpGetGen.cpp.in @@ -1,7 +1,5 @@ -#include "OneFlow/OneFlowOps.h" #include #include -#include "OneFlow/OneFlowDialect.h" #include "OneFlow/Passes.h" #include "llvm/ADT/STLExtras.h" #include "mlir/IR/BuiltinAttributes.h" @@ -10,6 +8,13 @@ #include "llvm/ADT/StringSet.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "OneFlow/OneFlowDialect.h" +#include "OneFlow/OneFlowOpTraits.h" +#include "OneFlow/OneFlowEnums.h.inc" +#include "OneFlow/OneFlowSupport.h" +#include "OneFlow/OneFlowInterfaces.h.inc" +#define GET_OP_CLASSES +#include "OneFlow/OneFlow.@OP_GROUP_NAME_LOWER@_ops.h.inc" #define GET_OP_CLASSES #include "OneFlow/OneFlow.@OP_GROUP_NAME_LOWER@_ops.cpp.inc" diff --git a/oneflow/ir/lib/OneFlow/OneFlowOps.cpp b/oneflow/ir/lib/OneFlow/OneFlowOps.cpp index 69781d3febb..da3bdcb4da2 100644 --- a/oneflow/ir/lib/OneFlow/OneFlowOps.cpp +++ b/oneflow/ir/lib/OneFlow/OneFlowOps.cpp @@ -14,33 +14,53 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "OneFlow/OneFlowOps.h" -#include -#include #include "OneFlow/OneFlowDialect.h" +#include "OneFlow/OneFlowSupport.h" #include "OneFlow/Passes.h" + #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringSet.h" + #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" -#include "llvm/ADT/StringSet.h" +#include "mlir/IR/FunctionImplementation.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "oneflow/ir/include/OneFlow/OneFlowSupport.h" + +#include +#include namespace mlir { namespace oneflow { -::mlir::OperandRange UserOp::dataInputOperands() { return data_input(); } -::mlir::OperandRange UserOp::ctrlInputOperands() { return ctrl_inputs(); } -::mlir::ResultRange UserOp::dataOutputResults() { return data_output(); } -::mlir::Value UserOp::ctrlOutputResult() { return ctrl_output(); } -::mlir::OperandRange SystemOp::dataInputOperands() { return data_input(); } -::mlir::OperandRange SystemOp::ctrlInputOperands() { return ctrl_inputs(); } -::mlir::ResultRange SystemOp::dataOutputResults() { return data_output(); } -::mlir::Value SystemOp::ctrlOutputResult() { return ctrl_output(); } +OperandRange UserOp::dataInputOperands() { return data_input(); } +OperandRange UserOp::ctrlInputOperands() { return ctrl_inputs(); } +ResultRange UserOp::dataOutputResults() { return data_output(); } +Value UserOp::ctrlOutputResult() { return ctrl_output(); } + +OperandRange SystemOp::dataInputOperands() { return data_input(); } +OperandRange SystemOp::ctrlInputOperands() { return ctrl_inputs(); } +ResultRange SystemOp::dataOutputResults() { return data_output(); } +Value SystemOp::ctrlOutputResult() { return ctrl_output(); } + +OperandRange VariableOp::dataInputOperands() { return {operand_begin(), operand_begin()}; } +OperandRange VariableOp::ctrlInputOperands() { return ctrl_inputs(); } +ResultRange VariableOp::dataOutputResults() { return output().dyn_cast(); } +Value VariableOp::ctrlOutputResult() { return ctrl_output(); } + +OperandRange InputOp::dataInputOperands() { return getODSOperands(0); } +OperandRange InputOp::ctrlInputOperands() { return ctrl_inputs(); } +ResultRange InputOp::dataOutputResults() { return output().dyn_cast(); } +Value InputOp::ctrlOutputResult() { return ctrl_output(); } -static mlir::ParseResult parseConstantOp(mlir::OpAsmParser& parser, mlir::OperationState& result) { +OperandRange OutputOp::dataInputOperands() { return getODSOperands(0); } +OperandRange OutputOp::ctrlInputOperands() { return ctrl_inputs(); } +ResultRange OutputOp::dataOutputResults() { return output().dyn_cast(); } +Value OutputOp::ctrlOutputResult() { return ctrl_output(); } + +static ParseResult parseConstantOp(OpAsmParser& parser, OperationState& result) { mlir::DenseElementsAttr value; if (parser.parseOptionalAttrDict(result.attributes) || parser.parseAttribute(value, "value", result.attributes)) { @@ -50,8 +70,6 @@ static mlir::ParseResult parseConstantOp(mlir::OpAsmParser& parser, mlir::Operat return success(); } -static mlir::LogicalResult verify(oneflow::ConstantOp op) { return mlir::success(); } - namespace { template @@ -60,8 +78,8 @@ LogicalResult TrimRedundantCtrl(OpType& op, PatternRewriter& rewriter) { const int32_t num_data_outputs = *(op.result_segment_sizes().template getValues()).begin(); NamedAttrList attributes(op->getAttrDictionary()); - attributes.erase(mlir::OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr()); - 
attributes.append(mlir::OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr(), + attributes.erase(OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr()); + attributes.append(OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr(), rewriter.getI32VectorAttr({num_data_outputs, 0})); if (auto created = rewriter.create(op->getLoc(), op.getODSResults(0 /* data out */).getTypes(), @@ -76,11 +94,11 @@ LogicalResult TrimRedundantCtrl(OpType& op, PatternRewriter& rewriter) { return failure(); } -bool IsCtrlOutTrimmed(oneflow::UserOp& op) { return !op.ctrl_output(); } +bool IsCtrlOutTrimmed(UserOp& op) { return !op.ctrl_output(); } -bool IsCtrlInAbsent(oneflow::UserOp& op) { - if (!op->hasAttrOfType<::mlir::DenseIntElementsAttr>( - mlir::OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr())) +bool IsCtrlInAbsent(UserOp& op) { + if (!op->hasAttrOfType( + OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr())) op.dump(); return op.ctrl_inputs().empty(); } @@ -94,11 +112,10 @@ static void getValuesFromIntArrayAttribute(ArrayAttr attr, SmallVector& array } } -struct ConcreteUserOps : public mlir::OpRewritePattern { - explicit ConcreteUserOps(mlir::MLIRContext* context) - : OpRewritePattern(context, /*benefit=*/1) {} - mlir::LogicalResult matchAndRewrite(oneflow::UserOp op, - mlir::PatternRewriter& rewriter) const override { +struct ConcreteUserOps : public OpRewritePattern { + explicit ConcreteUserOps(MLIRContext* context) + : OpRewritePattern(context, /*benefit=*/1) {} + LogicalResult matchAndRewrite(UserOp op, PatternRewriter& rewriter) const override { if (succeeded(TrimRedundantCtrl(op, rewriter))) { return success(); } // In principle, a concrete user op has no ctrl input/output. Some benefits: // 1. simplify things @@ -108,19 +125,19 @@ struct ConcreteUserOps : public mlir::OpRewritePattern { NamedAttrList attributes(op->getAttrDictionary()); attributes.erase(op.input_sizesAttrName()); attributes.erase(op.output_sizesAttrName()); - attributes.erase(mlir::OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr()); - attributes.erase(mlir::OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr()); + attributes.erase(OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr()); + attributes.erase(OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr()); llvm::SmallVector input_sizes, output_sizes; getValuesFromIntArrayAttribute(op.input_sizes(), input_sizes); getValuesFromIntArrayAttribute(op.output_sizes(), output_sizes); if (!input_sizes.empty()) { attributes.push_back(rewriter.getNamedAttr( - mlir::OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr(), + OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr(), rewriter.getI32VectorAttr(input_sizes))); } if (!output_sizes.empty()) { attributes.push_back(rewriter.getNamedAttr( - mlir::OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr(), + OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr(), rewriter.getI32VectorAttr(output_sizes))); } OperationState state(op->getLoc(), OneFlowDialect::getDialectNamespace().str() + "." 
@@ -129,13 +146,11 @@ struct ConcreteUserOps : public mlir::OpRewritePattern { state.addOperands(op.getODSOperands(0) /* data in */); state.addTypes(op.getODSResults(0 /* data out */).getTypes()); if (auto created = rewriter.createOperation(state)) { - if (created->hasTrait() == false) { - created->removeAttr( - mlir::OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr()); + if (created->hasTrait() == false) { + created->removeAttr(OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr()); } - if (created->hasTrait() == false) { - created->removeAttr( - mlir::OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr()); + if (created->hasTrait() == false) { + created->removeAttr(OpTrait::AttrSizedResultSegments::getResultSegmentSizeAttr()); } if (created->hasTrait() == false) { created->removeAttr(OpTrait::IsAlternative::getOpTypeNameAttr()); @@ -152,30 +167,26 @@ struct ConcreteUserOps : public mlir::OpRewritePattern { } }; -void UserOp::getCanonicalizationPatterns(::mlir::RewritePatternSet& results, - ::mlir::MLIRContext* context) { +void UserOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.insert(context); } -struct ConcreteSystemOps : public mlir::OpRewritePattern { - explicit ConcreteSystemOps(mlir::MLIRContext* context) - : OpRewritePattern(context, /*benefit=*/1) {} - mlir::LogicalResult matchAndRewrite(oneflow::SystemOp op, - mlir::PatternRewriter& rewriter) const override { +struct ConcreteSystemOps : public OpRewritePattern { + explicit ConcreteSystemOps(MLIRContext* context) + : OpRewritePattern(context, /*benefit=*/1) {} + LogicalResult matchAndRewrite(oneflow::SystemOp op, PatternRewriter& rewriter) const override { return TrimRedundantCtrl(op, rewriter); } }; -void SystemOp::getCanonicalizationPatterns(::mlir::RewritePatternSet& results, - ::mlir::MLIRContext* context) { +void SystemOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.insert(context); } -struct ConvertAddOpWithArity : public mlir::OpRewritePattern { - explicit ConvertAddOpWithArity(mlir::MLIRContext* context) - : OpRewritePattern(context, /*benefit=*/1) {} - mlir::LogicalResult matchAndRewrite(oneflow::AddNOp op, - mlir::PatternRewriter& rewriter) const override { +struct ConvertAddOpWithArity : public OpRewritePattern { + explicit ConvertAddOpWithArity(MLIRContext* context) + : OpRewritePattern(context, /*benefit=*/1) {} + LogicalResult matchAndRewrite(AddNOp op, PatternRewriter& rewriter) const override { const auto arity = op.in().size(); if (arity == 2) { NamedAttrList attributes = op->getAttrs(); @@ -194,11 +205,41 @@ struct ConvertAddOpWithArity : public mlir::OpRewritePattern { } }; -void AddNOp::getCanonicalizationPatterns(::mlir::RewritePatternSet& results, - ::mlir::MLIRContext* context) { +void AddNOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.insert(context); } +template +struct ConcreteSystemOpPattern : public OpRewritePattern { + explicit ConcreteSystemOpPattern(MLIRContext* context) + : OpRewritePattern(context, /*benefit=*/1) {} + LogicalResult matchAndRewrite(OpType op, PatternRewriter& rewriter) const override { + if (op.ctrl_output() && op.ctrl_output().use_empty()) { + NamedAttrList attributes(op->getAttrDictionary()); + if (auto created = rewriter.create(op->getLoc(), op.output().getType(), + op->getOperands(), attributes)) { + op.output().replaceAllUsesWith( + created->getResult(op.output().template cast().getResultNumber())); + op->erase(); + 
return success(); + } + } + return failure(); + } +}; + +void VariableOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { + results.insert>(context); +} + +void InputOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { + results.insert>(context); +} + +void OutputOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { + results.insert>(context); +} + void NormalizationAddReluOp::build(::mlir::OpBuilder& odsBuilder, ::mlir::OperationState& odsState, Value x, Value addend, Value moving_mean, Value moving_variance, Value gamma, Value beta, StringRef op_name, StringRef device_tag, @@ -240,6 +281,64 @@ void NormalizationAddReluOp::build(::mlir::OpBuilder& odsBuilder, ::mlir::Operat std::string Add2Op::getOriginalOpTypeName() { return "add_n"; } +void Job::build(OpBuilder& builder, OperationState& state, StringRef name, FunctionType type) { + state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); + state.addAttribute(getTypeAttrName(), TypeAttr::get(type)); + state.addRegion(); +} + +static ParseResult parseJob(OpAsmParser& parser, OperationState& result) { + auto buildFuncType = [](Builder& builder, ArrayRef argTypes, ArrayRef results, + function_like_impl::VariadicFlag, + std::string&) { return builder.getFunctionType(argTypes, results); }; + + return function_like_impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/false, + buildFuncType); +} + +static void print(Job op, OpAsmPrinter& p) { + FunctionType fnType = op.getType(); + function_like_impl::printFunctionLikeOp(p, op, fnType.getInputs(), /*isVariadic=*/false, + fnType.getResults()); +} + +static LogicalResult verify(Job op) { + // If this function is external there is nothing to do. + if (op.isExternal()) return success(); + + // Verify that the argument list of the function and the arg list of the entry + // block line up. The trait already verified that the number of arguments is + // the same between the signature and the block. + auto fnInputTypes = op.getType().getInputs(); + Block& entryBlock = op.front(); + for (unsigned i = 0, e = entryBlock.getNumArguments(); i != e; ++i) + if (fnInputTypes[i] != entryBlock.getArgument(i).getType()) + return op.emitOpError("type of entry block argument #") + << i << '(' << entryBlock.getArgument(i).getType() + << ") must match the type of the corresponding argument in " + << "function signature(" << fnInputTypes[i] << ')'; + + return success(); +} + +static LogicalResult verify(mlir::oneflow::ReturnOp op) { + auto job = cast(op->getParentOp()); + + // The operand number and types must match the function signature. 
+ const auto& results = job.getType().getResults(); + if (op.getNumOperands() != results.size()) + return op.emitOpError("has ") << op.getNumOperands() << " operands, but enclosing function (@" + << job.getName() << ") returns " << results.size(); + + for (unsigned i = 0, e = results.size(); i != e; ++i) + if (op.getOperand(i).getType() != results[i]) + return op.emitError() << "type of return operand " << i << " (" << op.getOperand(i).getType() + << ") doesn't match function result type (" << results[i] << ")" + << " in function @" << job.getName(); + + return success(); +} + } // namespace oneflow } // namespace mlir diff --git a/oneflow/ir/lib/OneFlow/Passes.cpp b/oneflow/ir/lib/OneFlow/Passes.cpp index 94a4ba5db60..6f2e43f1c98 100644 --- a/oneflow/ir/lib/OneFlow/Passes.cpp +++ b/oneflow/ir/lib/OneFlow/Passes.cpp @@ -14,11 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "OneFlow/OneFlowOps.h" -#include -#include #include "OneFlow/OneFlowDialect.h" #include "OneFlow/Passes.h" -#include "llvm/ADT/STLExtras.h" + #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" @@ -51,14 +49,21 @@ limitations under the License. #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" #endif // WITH_MLIR_CUDA_CODEGEN +#include "llvm/ADT/STLExtras.h" + +#include +#include + namespace mlir { namespace oneflow { LogicalResult DumpAssembly(::mlir::PatternRewriter& rewriter, MlirJitOp op) { // TODO: now we only need one JIT engine - auto parent_func_op = op->getParentOfType(); + auto parent_func_op = op->getParentOfType(); + if (!parent_func_op) { return failure(); } auto parent_module_op = parent_func_op->getParentOfType(); + if (!parent_module_op) { return failure(); } SymbolTable symbol_table(parent_module_op); std::string mlir; llvm::raw_string_ostream os_mlir(mlir); @@ -79,13 +84,24 @@ FuncOp GetOrInsertFuncOp(::mlir::PatternRewriter& rewriter, mlir::Location loc, for (auto result : results) { result_types.push_back(result.getType()); } auto func_type = rewriter.getFunctionType(argument_types, result_types); auto first_op = *ops.begin(); - auto parent_func_op = first_op->getParentOfType(); + auto parent_func_op = first_op->getParentOfType(); + if (!parent_func_op) { + emitError(loc) << "null parent oneflow::Job " << *first_op; + return nullptr; + } auto parent_module_op = parent_func_op->getParentOfType(); + if (!parent_module_op) { + emitError(loc) << "null ModuleOp " << *first_op; + return nullptr; + } SymbolTable symbol_table(parent_module_op); OpBuilder::InsertionGuard guard(rewriter); Block::iterator insertPt(parent_func_op->getNextNode()); rewriter.setInsertionPointToStart(&parent_module_op.body().getBlocks().back()); - assert(!parent_func_op->hasAttr("llvm.emit_c_interface")); + if (parent_func_op->hasAttr("llvm.emit_c_interface")) { + emitError(loc) << "parent should not has attr of llvm.emit_c_interface " << *parent_func_op; + return nullptr; + } auto function = rewriter.create(loc, func_name, func_type); function->setAttr("llvm.emit_c_interface", mlir::UnitAttr::get(rewriter.getContext())); function.body().emplaceBlock(); @@ -98,8 +114,11 @@ FuncOp GetOrInsertFuncOp(::mlir::PatternRewriter& rewriter, mlir::Location loc, for (auto op : ops) { nb.clone(*op, mapping); } SmallVector<::mlir::Value, 4> mapped_results; for (auto result : results) { mapped_results.push_back(mapping.lookup(result)); } - 
rewriter.create(loc, mapped_results); - assert(!symbol_table.lookup(func_name)); + rewriter.create(loc, mapped_results); + if (symbol_table.lookup(func_name)) { + emitError(loc) << func_name << " should not be at symbol table of ModuleOp"; + return nullptr; + } return function; } @@ -142,6 +161,7 @@ ::llvm::SmallVector<::mlir::Value, 4> OutlineMulCast(::mlir::PatternRewriter& re SmallVector ops = {cast_op, mul_op}; auto function = GetOrInsertFuncOp(rewriter, mul_op->getLoc(), op_name, operands, results, ops); + assert(function); auto created = rewriter.create(mul_op.getLoc(), function, attributes, operands); assert(DumpAssembly(rewriter, created).succeeded()); cast_op->dropAllUses(); diff --git a/oneflow/ir/llvm-in-tree.cmake b/oneflow/ir/llvm-in-tree.cmake new file mode 100644 index 00000000000..5e8d2afc3f0 --- /dev/null +++ b/oneflow/ir/llvm-in-tree.cmake @@ -0,0 +1,57 @@ +include(FetchContent) +message("-- LLVM_MONO_REPO_URL: " ${LLVM_MONO_REPO_URL}) +message("-- LLVM_MONO_REPO_MD5: " ${LLVM_MONO_REPO_MD5}) +FetchContent_Declare( + llvm_monorepo +) +FetchContent_GetProperties(llvm_monorepo) + +set(LLVM_INSTALL_DIR ${THIRD_PARTY_DIR}/llvm) + +if(NOT llvm_monorepo_POPULATED) + FetchContent_Populate(llvm_monorepo + URL ${LLVM_MONO_REPO_URL} + URL_HASH MD5=${LLVM_MONO_REPO_MD5} + ) +endif() +set(CMAKE_CXX_FLAGS "" CACHE STRING "" FORCE) +set(CMAKE_CXX_FLAGS_DEBUG "" CACHE STRING "" FORCE) +set(CMAKE_CXX_FLAGS_RELEASE "" CACHE STRING "" FORCE) +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "" CACHE STRING "" FORCE) + +set(CMAKE_INSTALL_PREFIX ${LLVM_INSTALL_DIR} CACHE STRING "" FORCE) +set(LLVM_ENABLE_RTTI ON CACHE BOOL "turn this on to make it compatible with protobuf") +set(LLVM_ENABLE_EH ON CACHE BOOL "turn this on to make it compatible with half (the library)") +set(LLVM_BUILD_EXAMPLES OFF CACHE BOOL "") +set(LLVM_BUILD_TOOLS OFF CACHE BOOL "") +set(LLVM_INCLUDE_EXAMPLES OFF CACHE BOOL "") +set(LLVM_INCLUDE_TESTS OFF CACHE BOOL "" FORCE) +set(MLIR_INCLUDE_TESTS OFF CACHE BOOL "" FORCE) +set(LLVM_INCLUDE_BENCHMARKS OFF CACHE BOOL "") +set(LLVM_TARGETS_TO_BUILD host;NVPTX CACHE STRING "") +set(LLVM_ENABLE_ASSERTIONS ON CACHE BOOL "") +set(LLVM_ENABLE_PROJECTS mlir CACHE STRING "") +set(LLVM_APPEND_VC_REV OFF CACHE BOOL "") +set(LLVM_ENABLE_ZLIB OFF CACHE BOOL "") +set(LLVM_INSTALL_UTILS ON CACHE BOOL "") +set(LLVM_ENABLE_OCAMLDOC OFF CACHE BOOL "") +set(LLVM_ENABLE_BINDINGS OFF CACHE BOOL "") +set(LLVM_OPTIMIZED_TABLEGEN ON CACHE BOOL "" FORCE) +set(MLIR_ENABLE_CUDA_RUNNER ${WITH_MLIR_CUDA_CODEGEN} CACHE STRING "") +set(LLVM_MAIN_SRC_DIR ${llvm_monorepo_SOURCE_DIR}/llvm) +set(LLVM_BINARY_DIR ${llvm_monorepo_BINARY_DIR}) +set(LLVM_TOOLS_BINARY_DIR ${llvm_monorepo_BINARY_DIR}/bin CACHE STRING "" FORCE) +set(MLIR_MAIN_SRC_DIR ${LLVM_MAIN_SRC_DIR}/../mlir) +set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include) +set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include) +set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}") + + +set(llvm_monorepo_BINARY_DIR ${llvm_monorepo_BINARY_DIR}) +install(TARGETS oneflow of_protoobj of_cfgobj of_functional_obj EXPORT oneflow DESTINATION lib) +install(EXPORT oneflow DESTINATION lib/oneflow) +add_subdirectory(${llvm_monorepo_SOURCE_DIR}/llvm ${llvm_monorepo_BINARY_DIR}) +set(LLVM_INCLUDE_DIRS ${LLVM_MAIN_SRC_DIR}/include;${llvm_monorepo_BINARY_DIR}/include) +set(LLVM_EXTERNAL_LIT "${llvm_monorepo_BINARY_DIR}/bin/llvm-lit" CACHE STRING "" FORCE) +set(LTDL_SHLIB_EXT ${CMAKE_SHARED_LIBRARY_SUFFIX}) +set(LLVM_LIBRARY_DIR 
"${llvm_monorepo_BINARY_DIR}/lib") diff --git a/oneflow/ir/oneflow-extension/CMakeLists.txt b/oneflow/ir/oneflow-extension/CMakeLists.txt index 48094cf3e5d..9640904c461 100644 --- a/oneflow/ir/oneflow-extension/CMakeLists.txt +++ b/oneflow/ir/oneflow-extension/CMakeLists.txt @@ -1,25 +1,9 @@ -set(LLVM_LINK_COMPONENTS - Support - ) - -set(LLVM_ENABLE_RTTI ON) # turn this on to make it compatible with protobuf -oneflow_add_llvm_library(MLIROneFlowExtension +oneflow_add_mlir_library(MLIROneFlowExtension extension.cpp ir_pass.cpp DEPENDS LINK_LIBS PUBLIC MLIRIR - BUILDTREE_ONLY -) - -llvm_update_compile_flags(MLIROneFlowExtension) - -if(WITH_MLIR_CUDA_CODEGEN) - set(MLIR_RUNTIME_GPU_LIBS -Wl,--no-as-needed mlir_cuda_runtime -Wl,--as-needed) -endif(WITH_MLIR_CUDA_CODEGEN) -set(MLIR_RUNTIME_LIBS -Wl,--no-as-needed mlir_c_runner_utils -Wl,--as-needed) -target_link_libraries(MLIROneFlowExtension - PRIVATE ${dialect_libs} ${translation_libs} MLIRIR @@ -29,21 +13,10 @@ target_link_libraries(MLIROneFlowExtension MLIRTranslation MLIRSupport MLIROneFlow + oneflow MLIRExecutionEngine MLIROneFlowTranslation - oneflow - PUBLIC - ${MLIR_RUNTIME_LIBS} - ${MLIR_RUNTIME_GPU_LIBS} + MLIROneFlowRuntime ) - -if (BUILD_SHARED_LIBS) - get_filename_component(ONEFLOW_BUILD_ROOT_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../../.. ABSOLUTE) - get_property(TRANSLATE_INSTALL_RPATH TARGET MLIROneFlowExtension PROPERTY INSTALL_RPATH) - list(APPEND TRANSLATE_INSTALL_RPATH ${PROTOBUF_LIBRARY_DIR}) - list(APPEND TRANSLATE_INSTALL_RPATH ${ONEFLOW_BUILD_ROOT_DIR}) - set_target_properties(MLIROneFlowExtension PROPERTIES INSTALL_RPATH "${TRANSLATE_INSTALL_RPATH}") -endif() - -mlir_check_link_libraries(MLIROneFlowExtension) +mlir_check_all_link_libraries(MLIROneFlowExtension) add_custom_target(mex DEPENDS MLIROneFlowExtension) diff --git a/oneflow/ir/oneflow-extension/extension.cpp b/oneflow/ir/oneflow-extension/extension.cpp index 466f1567fcf..56ae0ac2afd 100644 --- a/oneflow/ir/oneflow-extension/extension.cpp +++ b/oneflow/ir/oneflow-extension/extension.cpp @@ -40,8 +40,8 @@ namespace { REGISTER_USER_OP("mlir_jit") .Attr("mlir_assembly") - .InputWithMinimum("in", 0) - .OutputWithMinimum("out", 0) + .Input("in") + .Output("out") .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { // TODO: infer shape by extracting Ops from mlir_assembly CHECK_EQ(ctx->inputs().size(), 2); diff --git a/oneflow/ir/oneflow-extension/ir_pass.cpp b/oneflow/ir/oneflow-extension/ir_pass.cpp index 4b8e4c0d001..7c40aa7e28c 100644 --- a/oneflow/ir/oneflow-extension/ir_pass.cpp +++ b/oneflow/ir/oneflow-extension/ir_pass.cpp @@ -21,6 +21,7 @@ limitations under the License. 
#include "oneflow/core/framework/user_op_def.h" #include "oneflow/core/framework/user_op_registry.h" #include "oneflow/core/framework/user_op_registry_manager.h" +#include "oneflow/core/job/job_ir.h" namespace oneflow { @@ -72,7 +73,7 @@ class RoundTripOneFlowJobWrapper : public mlir::oneflow::RoundTripOneFlowJobWrap } const ::oneflow::ParallelConf& ParallelConf4OpName(const std::string& op_name) const override { - return job_builder_.ParallelConf4OpName(op_name); + return job_builder_.ParallelConf4OpName(op_name).GetOrThrow(); } const ::oneflow::OperatorConf& OpConf4OpName(const std::string& op_name) const override { return job_builder_.OpConf4OpName(op_name).GetOrThrow(); @@ -174,4 +175,21 @@ Maybe IRRoundTrip::Apply(Job* job, JobPassCtx* ctx) const { template class IRRoundTrip; template class IRRoundTrip; +Maybe SaveJobToIR(Job* job, const std::string& path) { + // TODO: check path is valid dir + if (std::getenv("ONEFLOW_DEBUG_MODE") != nullptr) { + TeePersistentLogStream::Create("saved_job")->Write(*job); + } + RoundTripOneFlowJobWrapper job_wrapper(job); + ::mlir::oneflow::SaveJobToIR(job_wrapper, path); + return Maybe::Ok(); +} + +Maybe LoadJobFromIR(Job* job, const std::string& path) { + job->Clear(); + RoundTripOneFlowJobWrapper job_wrapper(job); + ::mlir::oneflow::LoadJobFromIR(job_wrapper, path); + return Maybe::Ok(); +} + } // namespace oneflow diff --git a/oneflow/ir/oneflow-gen-ods/oneflow-gen-ods.cpp b/oneflow/ir/oneflow-gen-ods/oneflow-gen-ods.cpp deleted file mode 100644 index 79b147e98ce..00000000000 --- a/oneflow/ir/oneflow-gen-ods/oneflow-gen-ods.cpp +++ /dev/null @@ -1,724 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -#include -#include "oneflow/core/framework/user_op_def.h" -#include "oneflow/core/framework/user_op_registry.h" -#include "oneflow/core/framework/user_op_registry_manager.h" -#include - -namespace { - -using K = std::string; -using V = ::oneflow::user_op::OpRegistryResult; -using ::oneflow::AttrType; -using ::oneflow::UserOpDef_ArgDef; - -// from llvm -std::string convertToCamelFromSnakeCase(const std::string& input, bool capitalizeFirst) { - if (input.empty()) return ""; - - std::string output; - output.reserve(input.size()); - - // Push the first character, capatilizing if necessary. - if (capitalizeFirst && std::islower(input.front())) - output.push_back(toupper(input.front())); - else - output.push_back(input.front()); - - // Walk the input converting any `*_[a-z]` snake case into `*[A-Z]` camelCase. 
- for (size_t pos = 1, e = input.size(); pos < e; ++pos) { - if (input[pos] == '_' && pos != (e - 1) && std::islower(input[pos + 1])) - output.push_back(toupper(input[++pos])); - else - output.push_back(input[pos]); - } - return output; -} - -std::string GetMLIRAttrTypeName(const AttrType& attr_type) { - if (attr_type == ::oneflow::kAtInt32) { - return "SI32Attr"; - } else if (attr_type == ::oneflow::kAtInt64) { - return "SI64Attr"; - } else if (attr_type == ::oneflow::kAtBool) { - return "BoolAttr"; - } else if (attr_type == ::oneflow::kAtFloat) { - return "F32Attr"; - } else if (attr_type == ::oneflow::kAtDouble) { - return "F64Attr"; - } else if (attr_type == ::oneflow::kAtString) { - return "StrAttr"; - } else if (attr_type == ::oneflow::kAtShape) { - return "AnyI64ElementsAttr"; - } else if (attr_type == ::oneflow::kAtDataType) { - return "OneFlow_DataType"; - } else if (attr_type == ::oneflow::kAtListInt32) { - return "SI32ArrayAttr"; - } else if (attr_type == ::oneflow::kAtListInt64) { - return "SI64ArrayAttr"; - } else if (attr_type == ::oneflow::kAtListFloat) { - return "F32ArrayAttr"; - } else if (attr_type == ::oneflow::kAtListDataType) { - return "DTArrayAttr"; - } else if (attr_type == ::oneflow::kAtListShape) { - return "ShapeArrayAttr"; - } else if (attr_type == ::oneflow::kAtListString) { - return "StrArrayAttr"; - } else { - LOG(FATAL) << "fail to convert: " << attr_type; - return "failure"; - } -} - -template -std::string ToZeroNoTrailing(T f) { - std::string str = std::to_string(f); - str.erase(str.find_last_not_of('0') + 1, std::string::npos); - return str; -} - -std::string GetDefaultValue(const ::oneflow::AttrValue& attr_val) { - if (attr_val.has_at_string()) { - return "\\\"" + attr_val.at_string() + "\\\""; - } else if (attr_val.has_at_int32()) { - return std::to_string(attr_val.at_int32()); - } else if (attr_val.has_at_int64()) { - return std::to_string(attr_val.at_int64()); - } else if (attr_val.has_at_float()) { - return ToZeroNoTrailing(attr_val.at_float()); - } else if (attr_val.has_at_double()) { - return ToZeroNoTrailing(attr_val.at_double()); - } else if (attr_val.has_at_bool()) { - return attr_val.at_bool() ? "true" : "false"; - } else if (attr_val.has_at_list_int32()) { - std::string ret = "{"; - const auto& list = attr_val.at_list_int32().val(); - for (auto it = list.begin(); it != list.end(); ++it) { - ret += std::to_string(*it) + (std::next(it) == list.end() ? "" : ", "); - } - ret += "}"; - return ret; - } else if (attr_val.has_at_list_int64()) { - std::string ret = "{"; - const auto& list = attr_val.at_list_int64().val(); - for (auto it = list.begin(); it != list.end(); ++it) { - ret += std::to_string(*it) + (std::next(it) == list.end() ? "" : ", "); - } - ret += "}"; - return ret; - } else if (attr_val.has_at_list_float()) { - std::string ret = "{"; - const auto& list = attr_val.at_list_float().val(); - for (auto it = list.begin(); it != list.end(); ++it) { - ret += std::to_string(*it) + (std::next(it) == list.end() ? "" : ", "); - } - ret += "}"; - return ret; - } else if (attr_val.has_at_list_string()) { - std::string ret = "{"; - const auto& list = attr_val.at_list_string().val(); - for (auto it = list.begin(); it != list.end(); ++it) { - ret += "\"" + *it + "\"" + (std::next(it) == list.end() ? 
"" : ", "); - } - ret += "}"; - return ret; - } else if (attr_val.has_at_data_type()) { - return std::to_string(attr_val.at_data_type()); - } - LOG(FATAL) << "fail to convert value_case: " << attr_val.value_case() << "\n" - << attr_val.DebugString(); -} - -std::string GetMLIRAttrType(const ::oneflow::UserOpDef_AttrDef& attr_def) { - const AttrType& attr_type = attr_def.type(); - std::string name = GetMLIRAttrTypeName(attr_type); - auto is_default_supported = - attr_def.default_val().has_at_bool() || attr_def.default_val().has_at_int32() - || attr_def.default_val().has_at_int64() || attr_def.default_val().has_at_float() - || attr_def.default_val().has_at_double() - || (attr_def.default_val().has_at_string() && attr_def.default_val().at_string().size() > 0); - if (attr_def.has_default_val() && is_default_supported) { - name = - "DefaultValuedAttr<" + name + ", " + "\"" + GetDefaultValue(attr_def.default_val()) + "\">"; - } - return name; -} - -const std::set& GetIdempotentOps() { - static std::set ret{"abs", "ceil", "floor", "ones_like", "relu", "relu_grad", - "relu6", "rint", "round", "sign", "zeros_like"}; - return ret; -} -const std::set& GetInvolutionOps() { - static std::set ret{"reciprocal", "negative"}; - return ret; -} - -bool IsGradOp(const std::string& op_name) { return op_name.find("grad") != std::string::npos; } -const std::set& GetQuantizationOps() { - static std::set ret{"min_max_observer", "moving_average_min_max_observer", - "fake_quantization", "quantization"}; - return ret; -} - -const std::set& GetMathOps() { - static std::set ret{"abs", "acos", - "acosh", "asin", - "asinh", "atan", - "atanh", "ceil", - "cos", "cosh", - "erf", "erfc", - "exp", "expm1", - "floor", "lgamma", - "log", "log1p", - "log_sigmoid", "negative", - "reciprocal", "reciprocal_no_nan", - "rint", "round", - "rsqrt", "sigmoid_v2", - "sign", "sin", - "sinh", "softplus", - "sqrt", "square", - "tan", "tanh"}; - return ret; -} - -const std::set& GetOpsUsedInPatterns() { - static std::set ret{"scalar_mul_by_tensor", "cast", "tril", "scalar_mul", - "fused_scale_tril", "dropout", "bias_add"}; - return ret; -} -bool IsMathOp(const std::string& op_name) { - bool is_grad = false; - for (const auto& name : GetMathOps()) { - if (op_name.find(name) != std::string::npos && IsGradOp(op_name)) { is_grad = true; } - } - return GetMathOps().find(op_name) != GetMathOps().end() || is_grad; -} -bool IsUsedInPatterns(const std::string& op_name) { - return GetOpsUsedInPatterns().find(op_name) != GetOpsUsedInPatterns().end(); -} -bool IsInvolutionOp(const std::string& op_name) { - return GetInvolutionOps().find(op_name) != GetInvolutionOps().end() && !IsGradOp(op_name); -} -bool IsQuantizationOp(const std::string& op_name) { - return GetQuantizationOps().find(op_name) != GetQuantizationOps().end(); -} -bool IsIdempotentOp(const std::string& op_name) { - return GetIdempotentOps().find(op_name) != GetIdempotentOps().end() && !IsGradOp(op_name); -} - -bool IsPoolOp(const std::string& op_name) { - return ((op_name.rfind("avg", 0) == 0 || op_name.rfind("max", 0) == 0) - || ((op_name.find("avg") != std::string::npos || op_name.find("max") != std::string::npos) - && op_name.rfind("tf", 0) == 0)) - && op_name.find("pool") != std::string::npos; -} -bool IsEagerOp(const std::string& op_name) { return (op_name.rfind("eager", 0) == 0); } -bool IsTensorBufferOp(const std::string& op_name) { - return op_name.find("tensor_buffer") != std::string::npos; -} -bool IsSummaryOp(const std::string& op_name) { - return op_name.find("summary") != 
std::string::npos; -} -bool IsAnyPoolOp(const std::string& op_name) { return op_name.find("pool") != std::string::npos; } -bool IsAnyConvOp(const std::string& op_name) { return op_name.find("conv") != std::string::npos; } -bool IsConvOp(const std::string& op_name) { - return op_name.rfind("conv", 0) == 0 && op_name.find("grad") == std::string::npos; -} - -bool IsLazyPoolOp(const std::string& op_name) { - return op_name.find("_pool") != std::string::npos && op_name.find("tf_") != std::string::npos; -} -bool IsAdaptivePoolOp(const std::string& op_name) { - return op_name.find("_pool") != std::string::npos - && op_name.find("adaptive_") != std::string::npos; -} -bool IsNCCLOp(const std::string& op_name) { return op_name.find("nccl") != std::string::npos; } -bool IsOptimizerOp(const std::string& op_name) { - return (op_name.find("update") != std::string::npos || op_name.find("adam") != std::string::npos) - && op_name.find("scatter") == std::string::npos; -} -bool IsTrigonometric(const std::string& op_name) { - return (op_name.find("sin") != std::string::npos || op_name.find("cos") != std::string::npos - || op_name.find("tan") != std::string::npos) - && op_name.find("constant") == std::string::npos; -} -bool IsTestOp(const std::string& op_name) { - return (op_name.find("test") != std::string::npos || op_name.find("Test") != std::string::npos - || op_name.find("ccrelu") != std::string::npos); -} -bool IsPaddingOp(const std::string& op_name) { return (op_name.find("pad") != std::string::npos); } -bool IsAssignOp(const std::string& op_name) { - return (op_name.find("assign") != std::string::npos); -} -bool IsCrossEntropyOp(const std::string& op_name) { - return (op_name.find("cross_entropy") != std::string::npos); -} -bool IsCUDAOp(const std::string& op_name) { return (op_name.find("nvtx") != std::string::npos); } -bool IsMatmulOp(const std::string& op_name) { - return (op_name.find("matmul") != std::string::npos || op_name.find("fc") != std::string::npos); -} - -bool IsDatasetOp(const std::string& op_name) { - return (op_name.find("reader") != std::string::npos || op_name.find("Reader") != std::string::npos - || op_name.find("loader") != std::string::npos - || op_name.find("decoder") != std::string::npos); -} -bool IsUpsampleOp(const std::string& op_name) { - return (op_name.find("upsample") != std::string::npos); -} -bool IsBroadcastOp(const std::string& op_name) { - return (op_name.find("broadcast") != std::string::npos); -} -bool IsIdentityOp(const std::string& op_name) { - return (op_name.find("identity") != std::string::npos); -} -bool IsScalarOp(const std::string& op_name) { - return (op_name.rfind("scalar_", 0) == 0 || op_name.find("by_scalar") != std::string::npos); -} -bool IsImageOp(const std::string& op_name) { return (op_name.find("image") != std::string::npos); } -bool IsSoftmaxOp(const std::string& op_name) { - return (op_name.find("softmax") != std::string::npos); -} -bool IsFusedOp(const std::string& op_name) { - return (op_name.find("fused") != std::string::npos - || op_name.find("add_relu") != std::string::npos); -} -bool IsReduceOp(const std::string& op_name) { - return (op_name.find("reduce") != std::string::npos); -} -bool IsReshapeOp(const std::string& op_name) { - return (op_name.find("reshape") != std::string::npos); -} -bool IsLossOp(const std::string& op_name) { return (op_name.find("loss") != std::string::npos); } -bool IsDetectionOp(const std::string& op_name) { - return (op_name.find("top_k") != std::string::npos || op_name.find("bbox") != std::string::npos - || 
op_name.find("segmentation") != std::string::npos - || op_name.find("roi") != std::string::npos || op_name.find("poly") != std::string::npos - || op_name.find("nms") != std::string::npos - || op_name.find("object") != std::string::npos); -} -bool IsIndicesOp(const std::string& op_name) { - return (op_name.find("arg") != std::string::npos || op_name.find("where") != std::string::npos - || op_name.find("gather") != std::string::npos - || op_name.find("slice") != std::string::npos - || op_name.find("indices") != std::string::npos - || op_name.find("segment_sum") != std::string::npos - || op_name.find("scatter") != std::string::npos); -} -bool IsNormalizationOp(const std::string& op_name) { - return (op_name.find("norm") != std::string::npos); -} -bool IsParallelCastOp(const std::string& op_name) { - return (op_name.find("parallel_cast") != std::string::npos); -} - -std::string PostProcessClassName(const std::string& op_name) { - std::string ret = op_name; - ret = std::regex_replace(ret, std::regex("pool"), "Pool"); - ret = std::regex_replace(ret, std::regex("_1d"), "1D"); - ret = std::regex_replace(ret, std::regex("_2d"), "2D"); - ret = std::regex_replace(ret, std::regex("_3d"), "3D"); - ret = std::regex_replace(ret, std::regex("1d"), "1D"); - ret = std::regex_replace(ret, std::regex("2d"), "2D"); - ret = std::regex_replace(ret, std::regex("3d"), "3D"); - return ret; -} - -std::string GetConvOpClassName(const std::string& op_name) { - std::string ret(convertToCamelFromSnakeCase(op_name, true)); - // NOTE: should change form conv => Convolution ? - return ret; -} - -std::string GetBaseOp(const std::string& op_name) { - if (IsInvolutionOp(op_name)) { - return "OneFlow_InvolutionBaseOp"; - } else if (IsIdempotentOp(op_name)) { - return "OneFlow_IdempotentBaseOp"; - } else if (IsConvOp(op_name)) { - return "OneFlow_ConvolutionBaseOp"; - } else if (IsPoolOp(op_name)) { - return "OneFlow_" + std::string(IsLazyPoolOp(op_name) ? "TF" : "") + "Pool" - + std::string(IsGradOp(op_name) ? "Grad" : "") + "BaseOp"; - } else if (IsAdaptivePoolOp(op_name)) { - return "OneFlow_AdaptivePool" + std::string(IsGradOp(op_name) ? 
"Grad" : "") + "BaseOp"; - } else { - return "OneFlow_BaseOp"; - } -} - -bool ShouldSkipOperandAndResultsAndAttrs(const std::string& op_name) { - return IsInvolutionOp(op_name) || IsIdempotentOp(op_name); -} - -bool ShouldGenEmptyBody(const std::string& op_name) { - return IsPoolOp(op_name) || IsAdaptivePoolOp(op_name) || IsConvOp(op_name); -} - -void PrintArgDef(const UserOpDef_ArgDef& arg_def) { - std::cout << " "; - if (arg_def.is_optional()) { std::cout << "Optional<"; } - if (arg_def.num_as_min()) { std::cout << "Variadic<"; } - std::cout << "AnyType"; - if (arg_def.is_optional() || arg_def.num_as_min()) { std::cout << ">"; } - CHECK(!(arg_def.is_optional() && arg_def.num_as_min())) << arg_def.DebugString(); - std::cout << ":$" << arg_def.name(); - if (arg_def.num_as_min()) { - // TODO: add verifier - } -} - -uint32_t NumMultipleVariadic( - const ::google::protobuf::RepeatedPtrField<::oneflow::UserOpDef_ArgDef>& arg_defs) { - uint32_t num_variadic_op = 0; - for (const auto& arg_def : arg_defs) { - if (arg_def.is_optional()) { num_variadic_op += 1; } - if (arg_def.num_as_min()) { num_variadic_op += 1; } - } - return num_variadic_op; -} - -bool HasAtLeastTwoVariadic( - const ::google::protobuf::RepeatedPtrField<::oneflow::UserOpDef_ArgDef>& arg_defs) { - return NumMultipleVariadic(arg_defs) > 1; -} - -bool HasVariadic( - const ::google::protobuf::RepeatedPtrField<::oneflow::UserOpDef_ArgDef>& arg_defs) { - return NumMultipleVariadic(arg_defs) > 0; -} - -std::string GetOperandKeys( - const ::google::protobuf::RepeatedPtrField<::oneflow::UserOpDef_ArgDef>& arg_defs) { - std::string ret = "{"; - for (auto it = arg_defs.begin(); it != arg_defs.end(); ++it) { - ret += ("\"" + it->name() + "\""); - if (std::next(it) != arg_defs.end()) { ret += ", "; } - } - ret += "}"; - return ret; -} - -std::string GetOperandMinimums( - const ::google::protobuf::RepeatedPtrField<::oneflow::UserOpDef_ArgDef>& arg_defs) { - std::string ret = "{"; - for (auto it = arg_defs.begin(); it != arg_defs.end(); ++it) { - uint32_t min = 0; - if (it->is_optional()) { - min = 0; - } else if (it->has_num_as_min()) { - min = it->num(); - } else { - min = 1; - } - ret += std::to_string(min); - if (std::next(it) != arg_defs.end()) { ret += ", "; } - } - ret += "}"; - return ret; -} - -// TODO: use MLIR Interfaces it implement this -void PrintReturnStaticVal(const std::string& type, const std::string& func_name, - const std::string& val) { - std::cout << " static const " + type + "* " + func_name + "() { static " + type + " val(" + val - + "); return &val; }\n"; -} -void PrintExtraClassDeclaration(const ::oneflow::UserOpDef& op_def) { - return; - std::cout << " let extraClassDeclaration = [{" - << "\n"; - PrintReturnStaticVal("std::vector", "inputKeys", GetOperandKeys(op_def.input())); - PrintReturnStaticVal("std::vector", "inputMinimums", - GetOperandMinimums(op_def.input())); - PrintReturnStaticVal("std::vector", "outputKeys", GetOperandKeys(op_def.output())); - PrintReturnStaticVal("std::vector", "outputMinimums", - GetOperandMinimums(op_def.input())); - std::cout << " }];" - << "\n"; -} - -void PrintHasCanonicalizer(const std::string& op_name) { - if (op_name == "add_n") { - std::cout << " let hasCanonicalizer = 1;" - << "\n"; - } -} - -void PrintTraitAttrs(const ::oneflow::UserOpDef& op_def) { - const bool need_operand_segment_sizes = HasAtLeastTwoVariadic(op_def.input()); - const bool need_result_segment_sizes = HasAtLeastTwoVariadic(op_def.output()); - if (need_operand_segment_sizes || need_result_segment_sizes) { - 
std::cout << " let trait_attrs = (ins" - << "\n"; - if (need_operand_segment_sizes) { - std::cout << " I32ElementsAttr:$operand_segment_sizes" - << (need_result_segment_sizes ? ",\n" : "\n"); - } - if (need_result_segment_sizes) { std::cout << " I32ElementsAttr:$result_segment_sizes\n"; } - std::cout << " );" - << "\n"; - } -} - -bool IsUnaryOp(const ::oneflow::user_op::OpRegistryResult& r) { - return NumMultipleVariadic(r.op_def.input()) == 0 && NumMultipleVariadic(r.op_def.output()) == 0 - && r.op_def.input().size() == 1 && r.op_def.output().size() == 1; -} - -bool IsBinaryOp(const ::oneflow::user_op::OpRegistryResult& r) { - return NumMultipleVariadic(r.op_def.input()) == 0 && NumMultipleVariadic(r.op_def.output()) == 0 - && r.op_def.input().size() == 2 && r.op_def.output().size() == 1; -} - -void PrintBody(const ::oneflow::user_op::OpRegistryResult& r) { - const ::oneflow::UserOpDef& op_def = r.op_def; - // TODO: handle in out size/optional - // TODO: handle "," in last element - std::cout << "{" - << "\n"; - // inputs - const bool should_skip_operand_and_results_and_attrs = - ShouldSkipOperandAndResultsAndAttrs(r.op_type_name); - const bool should_skip_operand = should_skip_operand_and_results_and_attrs; - const bool should_skip_result = should_skip_operand_and_results_and_attrs; - const bool should_skip_attrs = should_skip_operand_and_results_and_attrs; - if (op_def.input().size() && !should_skip_operand) { - std::cout << " let input = (ins" - << "\n"; - for (auto it = op_def.input().begin(); it != op_def.input().end(); ++it) { - PrintArgDef(*it); - std::cout << (std::next(it) == op_def.input().end() ? "" : ",") << "\n"; - } - std::cout << " );" - << "\n"; - } - // outputs - if (op_def.output().size() && !should_skip_result) { - std::cout << " let output = (outs" - << "\n"; - for (auto it = op_def.output().begin(); it != op_def.output().end(); ++it) { - PrintArgDef(*it); - std::cout << (std::next(it) == op_def.output().end() ? "" : ",") << "\n"; - } - std::cout << " );" - << "\n"; - } - // attrs - if (op_def.attr().size() && !should_skip_attrs) { - std::cout << " let attrs = (ins" - << "\n"; - for (auto it = op_def.attr().begin(); it != op_def.attr().end(); ++it) { - std::cout << " " << GetMLIRAttrType(*it) << ":$" << it->name() - << (std::next(it) == op_def.attr().end() ? 
"" : ",") << "\n"; - } - std::cout << " );" - << "\n"; - } - // trait attrs - PrintTraitAttrs(op_def); - PrintExtraClassDeclaration(op_def); - PrintHasCanonicalizer(r.op_type_name); - std::cout << "}" - << "\n"; -} - -bool ShouldGenBaseClass(const std::string& op_name) { return op_name == "normalization_add_relu"; } - -bool HasSideEffect(const std::string& op_name) { - return IsAssignOp(op_name) || IsOptimizerOp(op_name); -} - -std::string GetOpClassName(const std::string& op_name) { - std::string ret = ""; - if (IsConvOp(op_name)) { - ret = GetConvOpClassName(op_name); - } else { - ret = convertToCamelFromSnakeCase(op_name, true); - } - if (ShouldGenBaseClass(op_name)) { ret += "Base"; } - return PostProcessClassName(ret); -} - -std::string GetTraits(const ::oneflow::user_op::OpRegistryResult& r) { - const ::oneflow::UserOpDef& op_def = r.op_def; - std::string ret{}; - if (HasSideEffect(r.op_type_name) == false) { ret += "NoSideEffect"; } - const bool need_operand_segment_sizes = HasAtLeastTwoVariadic(op_def.input()); - const bool need_result_segment_sizes = HasAtLeastTwoVariadic(op_def.output()); - if (need_operand_segment_sizes) { - if (ret != "") ret += ", "; - ret += "AttrSizedOperandSegments"; - } - - if (need_result_segment_sizes) { - if (ret != "") ret += ", "; - ret += "AttrSizedResultSegments"; - } - if (ret != "") ret += ", "; - ret += "DeclareOpInterfaceMethods"; - return ret; -} - -bool IsReferencedByOtherDefinitions(const std::string& op_name) { - return ShouldGenBaseClass(op_name); -} - -bool ShoudSkipOp(const std::string& op_name) { return op_name == "mlir_jit"; } - -void PrintODSFromOpRegistryResults(const std::map& results) { - for (const auto& kv : results) { - if (ShoudSkipOp(kv.first)) continue; - const ::oneflow::user_op::OpRegistryResult& r = kv.second; - auto op_class_name = GetOpClassName(kv.first); - std::cout << (ShouldGenBaseClass(r.op_type_name) ? "class" : "def") << " OneFlow_" - << op_class_name << "Op : " << GetBaseOp(r.op_type_name) << "<\"" << kv.first - << "\", [" + GetTraits(r) + "]> "; // TODO: add traits - if (ShouldGenEmptyBody(r.op_type_name)) { - std::cout << "{}\n"; - } else { - PrintBody(r); - } - std::cout << "\n"; - } -} - -void PrintNamesInResults(const std::map& results) { - std::cout << "// "; - for (auto it = results.begin(); it != results.end(); ++it) { - std::cout << it->first; - if (std::next(it) != results.end()) { std::cout << ", "; } - } - std::cout << "\n"; -} - -void PrintGroupNames(std::map>& groups) { - std::cout << "// "; - for (auto it = groups.begin(); it != groups.end(); ++it) { - if (ShoudSkipOp(it->first)) continue; - std::cout << it->first; - if (std::next(it) != groups.end()) { std::cout << ";"; } - } - std::cout << "\n\n"; -} - -void PrintIncludes(std::map>& groups) { - std::cout << "/*\n"; - for (auto it = groups.begin(); it != groups.end(); ++it) { - auto group_name = it->first; - if (group_name == "BASE") continue; - if (group_name == "TEST") continue; - std::transform(group_name.begin(), group_name.end(), group_name.begin(), ::tolower); - group_name += "_ops"; - std::cout << "#define GET_OP_LIST\n"; - std::cout << "#include \"OneFlow/OneFlow." 
<< group_name << ".cpp.inc\"\n"; - if (std::next(it) != groups.end()) { std::cout << ",\n"; } - } - std::cout << "*/\n\n"; -} - -void GroupOpRegistryResults(const std::map& results, - std::map>& groups) { - for (const auto& kv : results) { - std::string group_name = "MISC"; - const ::oneflow::user_op::OpRegistryResult& r = kv.second; - if (IsUnaryOp(r)) { group_name = "Unary"; } - if (IsBinaryOp(r)) { group_name = "Binary"; } - if (IsImageOp(r.op_type_name)) { group_name = "Image"; } - if (IsMathOp(r.op_type_name)) { group_name = "math"; } - if (IsPaddingOp(r.op_type_name)) { group_name = "PADDING"; } - if (IsIndicesOp(r.op_type_name)) { group_name = "Indices"; } - if (IsBroadcastOp(r.op_type_name)) { group_name = "Broadcast"; } - if (IsScalarOp(r.op_type_name)) { group_name = "Scalar"; } - if (IsReduceOp(r.op_type_name)) { group_name = "reduce"; } - if (IsReshapeOp(r.op_type_name)) { group_name = "reshape"; } - if (IsLossOp(r.op_type_name)) { group_name = "loss"; } - if (IsNormalizationOp(r.op_type_name)) { group_name = "Normalization"; } - if (IsCrossEntropyOp(r.op_type_name)) { group_name = "Cross_Entropy"; } - if (IsSoftmaxOp(r.op_type_name)) { group_name = "Softmax"; } - if (IsNCCLOp(r.op_type_name)) { group_name = "NCCL"; } - if (IsAnyConvOp(r.op_type_name)) { group_name = "CONV"; } - if (IsAnyPoolOp(r.op_type_name)) { group_name = "POOL"; } - if (IsUpsampleOp(r.op_type_name)) { group_name = "UPSAMPLE"; } - if (IsAssignOp(r.op_type_name)) { group_name = "assign"; } - if (IsOptimizerOp(r.op_type_name)) { group_name = "OPTIMIZER"; } - if (IsTrigonometric(r.op_type_name)) { group_name = "TRIGONOMETRIC"; } - if (IsIdempotentOp(r.op_type_name)) { group_name = "IDEMPOTENT"; } - if (IsInvolutionOp(r.op_type_name)) { group_name = "INVOLUTION"; } - if (IsIdentityOp(r.op_type_name)) { group_name = "Identity"; } - if (IsFusedOp(r.op_type_name)) { group_name = "Fused"; } - if (IsEagerOp(r.op_type_name)) { group_name = "eager"; } - if (IsQuantizationOp(r.op_type_name)) { group_name = "QUANTIZATION"; } - if (IsDatasetOp(r.op_type_name)) { group_name = "DATASET"; } - if (IsMatmulOp(r.op_type_name)) { group_name = "matmul"; } - if (IsTensorBufferOp(r.op_type_name)) { group_name = "tensor_buffer"; } - if (IsTestOp(r.op_type_name)) { group_name = "TEST"; } - if (IsDetectionOp(r.op_type_name)) { group_name = "Detection"; } - if (IsSummaryOp(r.op_type_name)) { group_name = "summary"; } - if (IsCUDAOp(r.op_type_name)) { group_name = "cuda"; } - if (IsParallelCastOp(r.op_type_name)) { group_name = "parallel_cast"; } - if (ShouldGenBaseClass(r.op_type_name)) { group_name = "BASE"; } - // if (IsUsedInPatterns(r.op_type_name)) { group_name = "used_in_patterns"; } - std::transform(group_name.begin(), group_name.end(), group_name.begin(), ::toupper); - groups[group_name].insert({kv.first, kv.second}); - } -} - -} // namespace - -int main(int argc, char* argv[]) { - std::streambuf* coutBuf = std::cout.rdbuf(); - std::ofstream of("OneFlowUserOpGen.td"); - std::streambuf* fileBuf = of.rdbuf(); - std::cout.rdbuf(fileBuf); - - std::map sorted{}; - auto unordered = oneflow::user_op::UserOpRegistryMgr::Get().GetAllOpRegistryResults(); - std::transform(unordered.begin(), unordered.end(), std::inserter(sorted, sorted.end()), - [](const std::pair& p) { return p; }); - std::map> groups; - GroupOpRegistryResults(sorted, groups); - PrintGroupNames(groups); - PrintIncludes(groups); - // std::cout << "#ifndef ONEFLOW_USER_OP_GEN\n"; - // std::cout << "#define ONEFLOW_USER_OP_GEN\n\n"; - - for (const auto& kv : groups) { - 
auto group_name = kv.first; - auto results = kv.second; - std::cout << "// Group: " << group_name << "\n"; - PrintNamesInResults(results); - std::cout << "// " - << "Total: " << kv.second.size() << "\n\n"; - CHECK(kv.second.size()) << group_name; - auto get_group_by_name = "GET_ONEFLOW_" + group_name + "_OP_DEFINITIONS"; - auto group_def_name = "ONEFLOW_" + group_name + "_OPS"; - std::cout << "#ifdef " << get_group_by_name << "\n\n"; - // std::cout << "#ifndef " << group_def_name << "\n\n"; - // std::cout << "#define " << group_def_name << "\n\n"; - PrintODSFromOpRegistryResults(results); - // std::cout << "#endif // " << group_def_name << "\n\n"; - std::cout << "#endif // " << get_group_by_name << "\n\n"; - } - of.flush(); - of.close(); - - std::cout.rdbuf(coutBuf); - return 0; -} diff --git a/oneflow/ir/oneflow-opt/CMakeLists.txt b/oneflow/ir/oneflow-opt/CMakeLists.txt index 266a0c66cf8..2f3678b052f 100644 --- a/oneflow/ir/oneflow-opt/CMakeLists.txt +++ b/oneflow/ir/oneflow-opt/CMakeLists.txt @@ -1,13 +1,6 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) get_property(oneflow_libs GLOBAL PROPERTY ALL_ONEFLOW_LIBS) -set(LIBS - ${dialect_libs} - ${conversion_libs} - MLIROptLib - MLIROneFlow - ${oneflow_libs} -) add_llvm_executable(oneflow-opt oneflow-opt.cpp) set(_origin_prefix "\$ORIGIN") @@ -20,6 +13,14 @@ set_target_properties(oneflow-opt PROPERTIES INSTALL_RPATH "${_origin_prefix}" ) llvm_update_compile_flags(oneflow-opt) -target_link_libraries(oneflow-opt PRIVATE ${LIBS}) +target_link_libraries(oneflow-opt + PRIVATE + ${dialect_libs} + ${conversion_libs} + MLIROptLib + PUBLIC + MLIROneFlow + ${oneflow_libs} +) mlir_check_all_link_libraries(oneflow-opt) diff --git a/oneflow/ir/oneflow-runtime/CMakeLists.txt b/oneflow/ir/oneflow-runtime/CMakeLists.txt new file mode 100644 index 00000000000..3ea7a4199b2 --- /dev/null +++ b/oneflow/ir/oneflow-runtime/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(lib) diff --git a/oneflow/ir/oneflow-runtime/lib/CMakeLists.txt b/oneflow/ir/oneflow-runtime/lib/CMakeLists.txt new file mode 100644 index 00000000000..379388d975a --- /dev/null +++ b/oneflow/ir/oneflow-runtime/lib/CMakeLists.txt @@ -0,0 +1,8 @@ +if(WITH_MLIR_CUDA_CODEGEN) + set(MLIR_RUNTIME_GPU_LIBS -Wl,--no-as-needed mlir_cuda_runtime -Wl,--as-needed) +endif(WITH_MLIR_CUDA_CODEGEN) +set(MLIR_RUNTIME_LIBS -Wl,--no-as-needed mlir_c_runner_utils -Wl,--as-needed) +oneflow_add_mlir_library(MLIROneFlowRuntime + Runtime.cpp +) +target_link_libraries(MLIROneFlowRuntime PUBLIC ${MLIR_RUNTIME_LIBS} ${MLIR_RUNTIME_GPU_LIBS}) diff --git a/oneflow/ir/oneflow-runtime/lib/Runtime.cpp b/oneflow/ir/oneflow-runtime/lib/Runtime.cpp new file mode 100644 index 00000000000..8cf98acaa7f --- /dev/null +++ b/oneflow/ir/oneflow-runtime/lib/Runtime.cpp @@ -0,0 +1,17 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +// This file is added to avoid cmake error diff --git a/oneflow/ir/oneflow-translate/CMakeLists.txt b/oneflow/ir/oneflow-translate/CMakeLists.txt index c173ad09d64..fdd3486e9b3 100644 --- a/oneflow/ir/oneflow-translate/CMakeLists.txt +++ b/oneflow/ir/oneflow-translate/CMakeLists.txt @@ -37,6 +37,7 @@ target_link_libraries(oneflow-translate PRIVATE ${dialect_libs} ${translation_libs} + PUBLIC MLIRTranslation MLIROneFlowTranslation ${oneflow_libs} diff --git a/oneflow/ir/oneflow-translate/include/OneFlow/MLIROneFlowTranslation.h b/oneflow/ir/oneflow-translate/include/OneFlow/MLIROneFlowTranslation.h index 4c77f79dda3..8accbe8bf82 100644 --- a/oneflow/ir/oneflow-translate/include/OneFlow/MLIROneFlowTranslation.h +++ b/oneflow/ir/oneflow-translate/include/OneFlow/MLIROneFlowTranslation.h @@ -16,13 +16,17 @@ limitations under the License. #ifndef ONEFLOW_IR_ONEFLOW_TRANSLATE_INCLUDE_ONEFLOW_MLIRONEFLOWTRANSLATION_H_ #define ONEFLOW_IR_ONEFLOW_TRANSLATE_INCLUDE_ONEFLOW_MLIRONEFLOWTRANSLATION_H_ -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/MLIRContext.h" -#include "OneFlow/OneFlowOps.h" #include "oneflow/core/framework/user_op_def.pb.h" #include "oneflow/core/job/job.pb.h" +#include "oneflow/core/job/sbp_parallel.pb.h" #include "oneflow/core/operator/op_conf.pb.h" + +#include "OneFlow/OneFlowOps.h" + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/MLIRContext.h" + #include #include @@ -44,6 +48,17 @@ LogicalResult ConvertCtrlInputs(Operation* op, ::oneflow::OperatorConf& op_conf) ResultRange GetDataOutputResults(Operation* op); llvm::Optional GetCtrlOutputResult(Operation* op); llvm::Optional GetOutputLbn(OpResult result); +llvm::Optional GetDataTypeAttr(MLIRContext* context, + ::oneflow::DataType oneflow_value); +LogicalResult ConvertVariableOpConf(Operation* op, oneflow::VariableOpAdaptor& adaptor, + ::oneflow::OperatorConf* op_conf); +LogicalResult ConvertInputOpConf(Operation* op, oneflow::InputOpAdaptor& adaptor, + ::oneflow::OperatorConf* op_conf); +LogicalResult ConvertOutputOpConf(Operation* op, oneflow::OutputOpAdaptor& adaptor, + ::oneflow::OperatorConf* op_conf); + +LogicalResult ParseNdSbpFromAttr(ArrayAttr nd_sbp_attr, ::oneflow::NdSbp* nd_sbp); +Attribute ConvertNdSbpToAttr(Builder& builder, const ::oneflow::NdSbp& nd_sbp); class Importer { public: @@ -95,7 +110,7 @@ class Importer { return GetBuilder().getArrayAttr(attrs); } - DenseIntElementsAttr DenseIntElementsAttrFromShape(const ::oneflow::ShapeProto& shape); + ArrayAttr GetAttrFromShape(const ::oneflow::ShapeProto& shape); llvm::Optional GetTypeFromOneFlowDataType(::oneflow::DataType dt); OpBuilder& GetBuilder() { return builder_; } MLIRContext* GetMLIRContext() { return context_; } @@ -138,8 +153,12 @@ class RoundTripOneFlowJobWrapperInterface { void RoundTripOneFlowJob( RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::function& is_legit_job); + void registerFromOneFlowJobTranslation(); +void SaveJobToIR(RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::string& path); +void LoadJobFromIR(RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::string& path); + } // namespace oneflow } // namespace mlir diff --git a/oneflow/ir/oneflow-translate/lib/OneFlow/CMakeLists.txt b/oneflow/ir/oneflow-translate/lib/OneFlow/CMakeLists.txt index 8d48f02356f..ab9be4f73e3 100644 --- a/oneflow/ir/oneflow-translate/lib/OneFlow/CMakeLists.txt +++ b/oneflow/ir/oneflow-translate/lib/OneFlow/CMakeLists.txt @@ -1,4 +1,4 @@ 
-oneflow_add_llvm_library(MLIROneFlowTranslation +oneflow_add_mlir_library(MLIROneFlowTranslation MLIROneFlowTranslation.cpp Importer.cpp ADDITIONAL_HEADER_DIRS @@ -7,12 +7,6 @@ oneflow_add_llvm_library(MLIROneFlowTranslation oneflow_deps LINK_LIBS PUBLIC MLIRIR - BUILDTREE_ONLY -) -llvm_update_compile_flags(MLIROneFlowTranslation) - -target_link_libraries(MLIROneFlowTranslation - PRIVATE ${dialect_libs} ${translation_libs} MLIRIR diff --git a/oneflow/ir/oneflow-translate/lib/OneFlow/Importer.cpp b/oneflow/ir/oneflow-translate/lib/OneFlow/Importer.cpp index 9962d7e4f58..d4a8cf98603 100644 --- a/oneflow/ir/oneflow-translate/lib/OneFlow/Importer.cpp +++ b/oneflow/ir/oneflow-translate/lib/OneFlow/Importer.cpp @@ -13,6 +13,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/data_type.pb.h" +#include "oneflow/core/framework/user_op_conf.pb.h" +#include "oneflow/core/job/job.pb.h" +#include "oneflow/core/operator/op_conf.pb.h" +#include "oneflow/core/framework/user_op_def.h" +#include "oneflow/core/framework/user_op_registry_manager.h" + +#include "OneFlow/OneFlowDialect.h" +#include "OneFlow/OneFlowOps.h" +#include "OneFlow/OneFlowSupport.h" +#include "OneFlow/Passes.h" +#include "OneFlow/MLIROneFlowTranslation.h" + #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -41,28 +54,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" -#include "OneFlow/OneFlowDialect.h" -#include "OneFlow/OneFlowOps.h" -#include "OneFlow/MLIROneFlowTranslation.h" -#include "OneFlow/Passes.h" -#include "OneFlow/OneFlowSupport.h" - -#include "oneflow/core/common/data_type.pb.h" -#include "oneflow/core/framework/user_op_conf.pb.h" -#include "oneflow/core/job/job.pb.h" -#include "oneflow/core/operator/op_conf.pb.h" -#include "oneflow/core/framework/user_op_def.h" -#include "oneflow/core/framework/user_op_registry_manager.h" -#include -#include #include -#include -#include -#include -#include -#include -#include -#include namespace mlir { @@ -107,7 +99,9 @@ std::vector GetOutputLbns(const ::oneflow::OperatorConf& op, UserOp for (const auto& arg : op.user_conf().output()) { names_appeared.insert(arg.first); } for (const auto& arg_def : arg_defs) { const auto& key = arg_def.name(); - auto result_size = op.user_conf().output().at(key).s_size(); + const auto& it = op.user_conf().output().find(key); + if (it == op.user_conf().output().end()) { continue; } + auto result_size = it->second.s_size(); if (result_size == 0) { continue; } for (int32_t i = 0; i < result_size; i++) { const auto output_lbn = op_name + "/" + key + "_" + std::to_string(i); @@ -124,8 +118,7 @@ LogicalResult Importer::AddUserOpInputOutputSegments(const ::oneflow::OperatorCo if (op.has_user_conf() == false) return failure(); const auto& user_conf = op.user_conf(); const ::oneflow::UserOpDef& op_def = GetUserOpDef(op.user_conf().op_type_name()); - const auto UserOpOperationName = - OperationName(oneflow::UserOp::getOperationName(), GetMLIRContext()); + const auto UserOpOperationName = OperationName(UserOp::getOperationName(), GetMLIRContext()); attr_vec.push_back(GetBuilder().getNamedAttr( oneflow::UserOp::input_sizesAttrName(UserOpOperationName), GetBuilder().getI32ArrayAttr(GetSizesFromArgs(user_conf.input(), op_def.input())))); @@ -199,17 +192,14 @@ llvm::Optional 
GetDataTypeAttr(MLIRContext* context } } -DenseIntElementsAttr Importer::DenseIntElementsAttrFromShape(const ::oneflow::ShapeProto& shape) { - ArrayRef values = {shape.dim().begin(), shape.dim().end()}; - RankedTensorType tt = RankedTensorType::get({static_cast(values.size())}, - GetBuilder().getIntegerType(64, true)); - ; - return DenseIntElementsAttr::get(tt, values); +ArrayAttr Importer::GetAttrFromShape(const ::oneflow::ShapeProto& shape) { + return GetBuilder().getArrayAttr(llvm::to_vector<8>(llvm::map_range( + shape.dim(), [this](int64_t v) -> Attribute { return getSI64IntegerAttr(v); }))); } -void WriteDenseIntElementsToShape(mlir::Attribute& attr, ::oneflow::ShapeProto* shape) { - for (auto int_v : attr.dyn_cast().getValues()) { - shape->add_dim(int_v); +void WriteAttrToShape(mlir::Attribute& attr, ::oneflow::ShapeProto* shape) { + for (auto v : attr.dyn_cast().getValue()) { + shape->add_dim(v.dyn_cast().getSInt()); } } @@ -244,8 +234,7 @@ LogicalResult Importer::namedAttributesFromUserOp(const ::oneflow::OperatorConf& DEFINE_ONE_ELIF(at_string, getStringAttr) #undef DEFINE_ONE_ELIF else if (value.has_at_shape()) { - attr_vec.emplace_back( - GetBuilder().getNamedAttr(name, DenseIntElementsAttrFromShape(value.at_shape()))); + attr_vec.emplace_back(GetBuilder().getNamedAttr(name, GetAttrFromShape(value.at_shape()))); } #define DEFINE_ONE_ELIF(at_key, get_attr, field) \ else if (value.has_##at_key()) { \ @@ -285,9 +274,9 @@ LogicalResult Importer::namedAttributesFromUserOp(const ::oneflow::OperatorConf& name, GetBuilder().getArrayAttr(llvm::to_vector<8>(dt_attr_list)))); } else if (value.has_at_list_shape()) { - auto dense_attr_list = llvm::map_range( - value.at_list_shape().val(), - [&](const ::oneflow::ShapeProto& s) { return DenseIntElementsAttrFromShape(s); }); + auto dense_attr_list = + llvm::map_range(value.at_list_shape().val(), + [&](const ::oneflow::ShapeProto& s) { return GetAttrFromShape(s); }); std::vector dense_attr_vector{dense_attr_list.begin(), dense_attr_list.end()}; attr_vec.emplace_back( @@ -361,6 +350,49 @@ llvm::Optional Importer::GetTypeFromOneFlowDataType(::oneflow::DataType dt } } +LogicalResult ParseNdSbpFromAttr(ArrayAttr nd_sbp_attr, ::oneflow::NdSbp* nd_sbp) { + for (const auto& sbp_attr : nd_sbp_attr) { + auto sbp_str_attr = sbp_attr.dyn_cast(); + if (!sbp_str_attr) { + llvm::errs() << "nd_sbp attr is not a StrArrayAttr"; + return failure(); + } + auto sbp_strref = sbp_str_attr.getValue(); + if (sbp_strref.startswith("S")) { + if (!(sbp_strref.substr(1, 1) == "(" && sbp_strref.endswith(")"))) { + llvm::errs() << "invalid sbp S(x) string value: " << sbp_strref; + return failure(); + } + auto split_axis = std::stoi(sbp_strref.substr(2, 1).str()); + nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(split_axis); + } else if (sbp_strref == "B") { + nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); + } else if (sbp_strref == "P") { + nd_sbp->add_sbp_parallel()->mutable_partial_sum_parallel(); + } else { + llvm::errs() << "unspported nd_sbp string value: " << sbp_strref; + return failure(); + } + } + return success(); +} + +Attribute ConvertNdSbpToAttr(Builder& builder, const ::oneflow::NdSbp& nd_sbp) { + llvm::SmallVector sbp_strrefs; + for (const auto& sbp : nd_sbp.sbp_parallel()) { + if (sbp.has_split_parallel()) { + sbp_strrefs.emplace_back("S(" + std::to_string(sbp.split_parallel().axis()) + ")"); + } else if (sbp.has_broadcast_parallel()) { + sbp_strrefs.emplace_back("B"); + } else if (sbp.has_partial_sum_parallel()) { + 
sbp_strrefs.emplace_back("P"); + } else { + llvm::errs() << "unsupported sbp"; + } + } + return builder.getStrArrayAttr(makeArrayRef(sbp_strrefs)); +} + LogicalResult Importer::ProcessUserOp(const ::oneflow::OperatorConf& op) { if (op.has_user_conf() == false) { GetModule().emitError("Not a user op. op name: " + op.name()); @@ -402,7 +434,7 @@ LogicalResult Importer::ProcessUserOp(const ::oneflow::OperatorConf& op) { if (failed(AppendCtrlOutType(out_types))) { return failure(); } OperationState state(FileLineColLoc::get(GetMLIRContext(), op.name(), 0, 0), - oneflow::UserOp::getOperationName()); + UserOp::getOperationName()); uint32_t data_input_size = 0; uint32_t data_output_size = 0; for (const auto& input : op.user_conf().input()) { data_input_size += input.second.s().size(); } @@ -662,20 +694,6 @@ LogicalResult Importer::ConvertUserOpAttributes(Operation* op, auto id = id_attr.first; // mlir only attrs // TODO: find a way to skip attrs like callee in a declarative way - { - std::vector keys{}; - std::vector sizes{}; - assert(GetFilteredSegmentKeyAndSizes(op, keys, sizes) - .succeeded()); - for (const auto& s : keys) { op_conf.mutable_user_conf()->add_input_order(s); } - } - { - std::vector keys{}; - std::vector sizes{}; - assert(GetFilteredSegmentKeyAndSizes(op, keys, sizes) - .succeeded()); - for (const auto& s : keys) { op_conf.mutable_user_conf()->add_output_order(s); } - } if (id.strref().equals("callee") || id.strref().equals(OpTrait::IsOpConfCompatible::getDeviceNameAttr()) || id.strref().equals(OpTrait::IsOpConfCompatible::getHierarchyAttr()) @@ -718,7 +736,7 @@ LogicalResult Importer::ConvertUserOpAttributes(Operation* op, } else if (attr_type == ::oneflow::kAtString) { user_attr.set_at_string(attr.dyn_cast().getValue().str()); } else if (attr_type == ::oneflow::kAtShape) { - WriteDenseIntElementsToShape(attr, user_attr.mutable_at_shape()); + WriteAttrToShape(attr, user_attr.mutable_at_shape()); } else if (attr_type == ::oneflow::kAtDataType) { ::oneflow::DataType dt = ::oneflow::kInvalidDataType; if (succeeded(ConvertDT(attr, dt))) { @@ -757,11 +775,9 @@ LogicalResult Importer::ConvertUserOpAttributes(Operation* op, } } } else if (attr_type == ::oneflow::kAtListShape) { - for (auto s : attr.dyn_cast().getValue()) { + for (auto shape_attr : attr.dyn_cast().getValue()) { ::oneflow::ShapeProto* shape_ptr = user_attr.mutable_at_list_shape()->add_val(); - for (auto int_v : s.dyn_cast().getValues()) { - shape_ptr->mutable_dim()->Add(int_v); - } + WriteAttrToShape(shape_attr, shape_ptr); } } else if (attr_type == ::oneflow::kAtListString) { // attr like nd_sbp requires the existence of list even it is empty @@ -776,6 +792,167 @@ LogicalResult Importer::ConvertUserOpAttributes(Operation* op, (*user_conf->mutable_attr())[id.str()] = user_attr; } } + { + std::vector keys{}; + std::vector sizes{}; + assert(GetFilteredSegmentKeyAndSizes(op, keys, sizes) + .succeeded()); + for (const auto& s : keys) { op_conf.mutable_user_conf()->add_input_order(s); } + } + { + std::vector keys{}; + std::vector sizes{}; + assert(GetFilteredSegmentKeyAndSizes(op, keys, sizes) + .succeeded()); + for (const auto& s : keys) { op_conf.mutable_user_conf()->add_output_order(s); } + } + return success(); +} + +LogicalResult ConvertVariableOpConf(Operation* op, oneflow::VariableOpAdaptor& adaptor, + ::oneflow::OperatorConf* op_conf) { + op_conf->set_name(adaptor.op_name().getValue().str()); + op_conf->set_device_tag(adaptor.device_tag().getValue().str()); + 
op_conf->set_scope_symbol_id(adaptor.scope_symbol_id().getInt()); + // TODO: process stream_name_hint + + auto* var_op_conf = op_conf->mutable_variable_conf(); + var_op_conf->set_out("out"); + + if (auto shape_attr = + op->getAttrOfType(OpTrait::TensorSource::getShapeAttrName())) { + WriteAttrToShape(shape_attr, var_op_conf->mutable_shape()); + } + + if (op->hasAttr(OpTrait::TensorSource::getDataTypeAttrName())) { + ::oneflow::DataType dt = ::oneflow::DataType::kInvalidDataType; + if (failed(ConvertDT(adaptor.data_type(), dt))) { return failure(); } + var_op_conf->set_data_type(dt); + } + + if (op->hasAttr("model_name")) { + var_op_conf->set_model_name(adaptor.model_name().getValue().str()); + } + + if (op->hasAttr("l1_regularization")) { + var_op_conf->mutable_regularizer()->mutable_l1_l2_conf()->set_l1( + adaptor.l1_regularization().getValue().convertToFloat()); + } + + if (op->hasAttr("l2_regularization")) { + var_op_conf->mutable_regularizer()->mutable_l1_l2_conf()->set_l2( + adaptor.l2_regularization().getValue().convertToFloat()); + } + + if (op->hasAttr("trainable")) { var_op_conf->set_trainable(adaptor.trainable().getValue()); } + + for (const auto& sbp : adaptor.nd_sbp()) { + var_op_conf->add_nd_sbp(sbp.cast().getValue().str()); + } + + // all operands are ctrl_inputs + for (const auto& operand : op->getOperands()) { + op_conf->add_ctrl_in_op_name( + operand.getDefiningOp()->getAttrOfType("op_name").getValue().str()); + } + + // empty initializer + var_op_conf->mutable_initializer()->mutable_empty_conf(); + + return success(); +} + +LogicalResult ConvertInputOpConf(Operation* op, oneflow::InputOpAdaptor& adaptor, + ::oneflow::OperatorConf* op_conf) { + op_conf->set_name(adaptor.op_name().getValue().str()); + op_conf->set_device_tag(adaptor.device_tag().getValue().str()); + op_conf->set_scope_symbol_id(adaptor.scope_symbol_id().getInt()); + // TODO: process stream_name_hint + + auto* input_op_conf = op_conf->mutable_input_conf(); + input_op_conf->set_out("out"); + + if (auto shape_attr = + op->getAttrOfType(OpTrait::TensorSource::getShapeAttrName())) { + WriteAttrToShape(shape_attr, input_op_conf->mutable_blob_conf()->mutable_shape()); + } + + if (op->hasAttr(OpTrait::TensorSource::getDataTypeAttrName())) { + ::oneflow::DataType dt = ::oneflow::DataType::kInvalidDataType; + if (failed(ConvertDT(adaptor.data_type(), dt))) { return failure(); } + input_op_conf->mutable_blob_conf()->set_data_type(dt); + } + + if (op->hasAttr(OpTrait::TensorSource::getIsDynamicAttrName())) { + input_op_conf->mutable_blob_conf()->set_is_dynamic(adaptor.is_dynamic().getValue()); + } + + if (op->hasAttr(OpTrait::TensorSource::getNdSbpAttrName())) { + if (failed(ParseNdSbpFromAttr(adaptor.nd_sbp(), + input_op_conf->mutable_blob_conf()->mutable_nd_sbp()))) { + return failure(); + } + } + + if (op->hasAttr("job_name")) { input_op_conf->set_job_name(adaptor.job_name().getValue().str()); } + + // operand 0 is block argument, others are ctrl_inputs + for (size_t i = 1; i < op->getNumOperands(); ++i) { + op_conf->add_ctrl_in_op_name( + op->getOperand(i).getDefiningOp()->getAttrOfType("op_name").getValue().str()); + } + + return success(); +} + +LogicalResult ConvertOutputOpConf(Operation* op, oneflow::OutputOpAdaptor& adaptor, + ::oneflow::OperatorConf* op_conf) { + op_conf->set_name(adaptor.op_name().getValue().str()); + op_conf->set_device_tag(adaptor.device_tag().getValue().str()); + op_conf->set_scope_symbol_id(adaptor.scope_symbol_id().getInt()); + // TODO: process stream_name_hint + + auto* output_op_conf 
= op_conf->mutable_output_conf(); + output_op_conf->set_out("out"); + + if (auto shape_attr = + op->getAttrOfType(OpTrait::TensorSource::getShapeAttrName())) { + WriteAttrToShape(shape_attr, output_op_conf->mutable_blob_conf()->mutable_shape()); + } + + if (op->hasAttr(OpTrait::TensorSource::getDataTypeAttrName())) { + ::oneflow::DataType dt = ::oneflow::DataType::kInvalidDataType; + if (failed(ConvertDT(adaptor.data_type(), dt))) { return failure(); } + output_op_conf->mutable_blob_conf()->set_data_type(dt); + } + + if (op->hasAttr(OpTrait::TensorSource::getIsDynamicAttrName())) { + output_op_conf->mutable_blob_conf()->set_is_dynamic(adaptor.is_dynamic().getValue()); + } + + if (op->hasAttr(OpTrait::TensorSource::getNdSbpAttrName())) { + if (failed(ParseNdSbpFromAttr(adaptor.nd_sbp(), + output_op_conf->mutable_blob_conf()->mutable_nd_sbp()))) { + return failure(); + } + } + + if (op->hasAttr("job_name")) { + output_op_conf->set_job_name(adaptor.job_name().getValue().str()); + } + + if (op->getNumOperands() == 0) { + op->emitError("output op has at least one input."); + return failure(); + } + auto result = op->getOperand(0).dyn_cast(); + auto* producer_op = result.getDefiningOp(); + auto output_lbn = producer_op->getAttrOfType("output_lbns")[result.getResultNumber()]; + output_op_conf->set_in(output_lbn.dyn_cast().getValue().str()); + for (size_t i = 1; i < op->getNumOperands(); ++i) { + op_conf->add_ctrl_in_op_name( + op->getOperand(i).getDefiningOp()->getAttrOfType("op_name").getValue().str()); + } return success(); } diff --git a/oneflow/ir/oneflow-translate/lib/OneFlow/MLIROneFlowTranslation.cpp b/oneflow/ir/oneflow-translate/lib/OneFlow/MLIROneFlowTranslation.cpp index f5c0854a381..e110f493f44 100644 --- a/oneflow/ir/oneflow-translate/lib/OneFlow/MLIROneFlowTranslation.cpp +++ b/oneflow/ir/oneflow-translate/lib/OneFlow/MLIROneFlowTranslation.cpp @@ -13,6 +13,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + +#include "oneflow/core/common/util.h" +#include "oneflow/core/common/data_type.pb.h" +#include "oneflow/core/framework/user_op_conf.pb.h" +#include "oneflow/core/job/job.pb.h" +#include "oneflow/core/operator/op_conf.pb.h" +#include "oneflow/core/operator/interface_blob_conf.pb.h" + +#include "OneFlow/OneFlowDialect.h" +#include "OneFlow/OneFlowOps.h" +#include "OneFlow/OneFlowOpTraits.h" +#include "OneFlow/Passes.h" +#include "OneFlow/MLIROneFlowTranslation.h" + #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -24,12 +38,15 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" #include "mlir/IR/UseDefLists.h" #include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/Passes.h" #include "mlir/Translation.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Parser.h" + #include "llvm-c/Core.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" @@ -39,27 +56,7 @@ limitations under the License. 
#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" -#include "OneFlow/OneFlowDialect.h" -#include "OneFlow/OneFlowOps.h" -#include "OneFlow/MLIROneFlowTranslation.h" -#include "OneFlow/Passes.h" - -#include "oneflow/core/common/data_type.pb.h" -#include "oneflow/core/framework/user_op_conf.pb.h" -#include "oneflow/core/job/job.pb.h" -#include "oneflow/core/operator/op_conf.pb.h" -#include "oneflow/core/common/util.h" - -#include -#include #include -#include -#include -#include -#include -#include -#include -#include namespace mlir { @@ -80,10 +77,23 @@ class JobImporter : Importer { LogicalResult AddDeviceName(const ::oneflow::OperatorConf& op, std::vector& attr_vec) override; LogicalResult InsertOpResults(const ::oneflow::OperatorConf& op, Operation*) override; - LogicalResult ProcessSystemOp(const ::oneflow::OperatorConf& op) override; + LogicalResult ProcessJob(); + LogicalResult ProcessSystemOp(const ::oneflow::OperatorConf& op) override; + LogicalResult ProcessVariableOp(const ::oneflow::OperatorConf& op); + LogicalResult ProcessInputOp(const ::oneflow::OperatorConf& op_conf, Block* entry_block, + size_t& input_count); + LogicalResult ProcessOutputOp(const ::oneflow::OperatorConf& op_conf); + LogicalResult TryToUpdateJob(); + LogicalResult ConvertUserOp(Operation* op, ::oneflow::Job& job); + LogicalResult ConvertSystemOp(Operation* op, ::oneflow::Job& job); + LogicalResult ConvertVariableOp(Operation* op, ::oneflow::Job& job); + LogicalResult ConvertInputOp(Operation* op, ::oneflow::Job& job); + LogicalResult ConvertOutputOp(Operation* op, ::oneflow::Job& job); + Type GetTensorTypeOfLbn(const std::string& lbn) override; + Type GetInterfaceBlobConfType(const ::oneflow::InterfaceBlobConf& blob_conf); private: std::unordered_map lbn2result_; @@ -179,6 +189,8 @@ LogicalResult JobImporter::ProcessSystemOp(const ::oneflow::OperatorConf& op) { GetModule().emitError("Not a sys op. op name: " + op.name()); return failure(); } + if (op.has_variable_conf()) { return ProcessVariableOp(op); } + auto input_bns_lbns = job_wrapper_.InputBns4OpName(op.name()); auto input_bns = input_bns_lbns.first; auto input_lbns = input_bns_lbns.second; @@ -195,7 +207,7 @@ LogicalResult JobImporter::ProcessSystemOp(const ::oneflow::OperatorConf& op) { GetBuilder().getStrArrayAttr( std::vector({output_lbns.begin(), output_lbns.end()})))); OperationState state(FileLineColLoc::get(GetMLIRContext(), op.name(), 0, 0), - oneflow::SystemOp::getOperationName()); + SystemOp::getOperationName()); attr_vec.push_back( GetBuilder().getNamedAttr("op_type_case", GetBuilder().getI32IntegerAttr(op.op_type_case()))); if (failed(AddOperandSegmentSizes(static_cast(input_lbns.size()), op.ctrl_in_op_name_size(), @@ -225,28 +237,334 @@ LogicalResult JobImporter::ProcessSystemOp(const ::oneflow::OperatorConf& op) { return success(); } -LogicalResult JobImporter::ProcessJob() { - auto func_type = GetBuilder().getFunctionType(llvm::None, llvm::None); - auto function = mlir::FuncOp::create(GetRootLocation(), job_->job_conf().job_name(), func_type); - auto& entryBlock = *function.addEntryBlock(); - GetBuilder().setInsertionPointToStart(&entryBlock); +LogicalResult JobImporter::ProcessVariableOp(const ::oneflow::OperatorConf& op_conf) { + if (!op_conf.has_variable_conf()) { + GetModule().emitError("Not a variable op. op name: " + op_conf.name()); + return failure(); + } + + if (op_conf.variable_conf().has_tick()) { + GetModule().emitError("variable op has tick input. 
op name: " + op_conf.name()); + return failure(); + } + + OperationState state(FileLineColLoc::get(GetMLIRContext(), op_conf.name(), 0, 0), + "oneflow.variable"); + // attrs + std::vector attr_vec; + if (failed(AddOpConf(op_conf, attr_vec))) { return failure(); } + if (failed(AddDeviceName(op_conf, attr_vec))) { return failure(); } + // attr output_lbns + auto output_lbns_attr = GetBuilder().getStrArrayAttr({op_conf.name() + "/out"}); + attr_vec.emplace_back(GetBuilder().getNamedAttr( + OpTrait::IsImportCompatible::getOutputLBNsAttr(), output_lbns_attr)); + // attr shape + auto shape_attr = GetAttrFromShape(op_conf.variable_conf().shape()); + auto shape_named_attr = + GetBuilder().getNamedAttr(OpTrait::TensorSource::getShapeAttrName(), shape_attr); + attr_vec.emplace_back(shape_named_attr); + // attr data_type + if (op_conf.variable_conf().has_data_type()) { + attr_vec.emplace_back(GetBuilder().getNamedAttr( + OpTrait::TensorSource::getDataTypeAttrName(), + GetDataTypeAttr(GetMLIRContext(), op_conf.variable_conf().data_type()).getValue())); + } + // attr model_name + if (op_conf.variable_conf().has_model_name()) { + const std::string& model_name = op_conf.variable_conf().model_name(); + attr_vec.emplace_back( + GetBuilder().getNamedAttr("model_name", GetBuilder().getStringAttr(model_name))); + } + // attr l1 l2 regularization + if (op_conf.variable_conf().has_regularizer() + && op_conf.variable_conf().regularizer().has_l1_l2_conf()) { + if (op_conf.variable_conf().regularizer().l1_l2_conf().has_l1()) { + float l1_regularization = op_conf.variable_conf().regularizer().l1_l2_conf().l1(); + attr_vec.emplace_back(GetBuilder().getNamedAttr( + "l1_regularization", GetBuilder().getF32FloatAttr(l1_regularization))); + } + if (op_conf.variable_conf().regularizer().l1_l2_conf().has_l2()) { + float l2_regularization = op_conf.variable_conf().regularizer().l1_l2_conf().l2(); + attr_vec.emplace_back(GetBuilder().getNamedAttr( + "l2_regularization", GetBuilder().getF32FloatAttr(l2_regularization))); + } + } + // attr trainable + if (op_conf.variable_conf().has_trainable()) { + bool trainable = op_conf.variable_conf().trainable(); + attr_vec.emplace_back( + GetBuilder().getNamedAttr("trainable", GetBuilder().getBoolAttr(trainable))); + } + // attr nd_sbp + const std::vector nd_sbp_str_vec{op_conf.variable_conf().nd_sbp().begin(), + op_conf.variable_conf().nd_sbp().end()}; + auto nd_sbp_attr = GetBuilder().getStrArrayAttr(makeArrayRef(nd_sbp_str_vec)); + attr_vec.emplace_back( + GetBuilder().getNamedAttr(OpTrait::TensorSource::getNdSbpAttrName(), nd_sbp_attr)); + // add attrs + state.addAttributes(attr_vec); + // operands + std::vector<::mlir::Value> operand_vec; + if (failed(AppendCtrlInOperand(op_conf, operand_vec))) { return failure(); } + state.addOperands(operand_vec); + // result types + llvm::SmallVector out_types; + auto output_lbn = op_conf.name() + "/out"; + out_types.push_back(GetTensorTypeOfLbn(output_lbn)); + if (failed(AppendCtrlOutType(out_types))) { return failure(); } + state.addTypes(out_types); + // create op + auto op = GetBuilder().createOperation(state); + if (!op) { + GetModule()->emitError("fail to create op, name: " + op_conf.name()); + return failure(); + } + // record result + if (op->getNumResults() != 2) { + op->emitError("variable op should has two results (out and ctrl_output), but got " + + std::to_string(op->getNumResults()) + "\n"); + return failure(); + } + if (!lbn2result_.emplace(output_lbn, op->getResult(0)).second) { + op->emitError("lbn already exists, lbn: ") << 
output_lbn; + return failure(); + } + if (!op_name2ctrl_result_.emplace(op_conf.name(), op->getResult(1)).second) { + op->emitError("ctrl output already exists, op_name: ") << op_conf.name(); + return failure(); + } + return success(); +} + +LogicalResult JobImporter::ProcessInputOp(const ::oneflow::OperatorConf& op_conf, + Block* entry_block, size_t& input_count) { + if (!op_conf.has_input_conf()) { + GetModule().emitError("Not an input op. op name: " + op_conf.name()); + return failure(); + } + + if (op_conf.input_conf().has_tick()) { + GetModule().emitError("input op has tick input. op name: " + op_conf.name()); + return failure(); + } + + OperationState state(FileLineColLoc::get(GetMLIRContext(), op_conf.name(), 0, 0), + "oneflow.input"); + // attrs + std::vector attr_vec; + if (failed(AddOpConf(op_conf, attr_vec))) { return failure(); } + if (failed(AddDeviceName(op_conf, attr_vec))) { return failure(); } + // attr output_lbns + auto output_lbns_attr = GetBuilder().getStrArrayAttr({op_conf.name() + "/out"}); + attr_vec.emplace_back(GetBuilder().getNamedAttr( + OpTrait::IsImportCompatible::getOutputLBNsAttr(), output_lbns_attr)); + // attr shape + if (op_conf.input_conf().blob_conf().has_shape()) { + auto shape_attr = GetAttrFromShape(op_conf.input_conf().blob_conf().shape()); + attr_vec.emplace_back( + GetBuilder().getNamedAttr(OpTrait::TensorSource::getShapeAttrName(), shape_attr)); + } + // attr data_type + if (op_conf.input_conf().blob_conf().has_data_type()) { + attr_vec.emplace_back(GetBuilder().getNamedAttr( + OpTrait::TensorSource::getDataTypeAttrName(), + GetDataTypeAttr(GetMLIRContext(), op_conf.input_conf().blob_conf().data_type()) + .getValue())); + } + // attr is_dynamic + if (op_conf.input_conf().blob_conf().has_is_dynamic()) { + bool is_dynamic = op_conf.input_conf().blob_conf().is_dynamic(); + attr_vec.emplace_back(GetBuilder().getNamedAttr( + OpTrait::TensorSource::getIsDynamicAttrName(), GetBuilder().getBoolAttr(is_dynamic))); + } + // attr nd_sbp + if (op_conf.input_conf().blob_conf().has_nd_sbp()) { + auto nd_sbp_attr = ConvertNdSbpToAttr(GetBuilder(), op_conf.input_conf().blob_conf().nd_sbp()); + attr_vec.emplace_back( + GetBuilder().getNamedAttr(OpTrait::TensorSource::getNdSbpAttrName(), nd_sbp_attr)); + } + // attr job_name + if (op_conf.input_conf().has_job_name()) { + const std::string& job_name = op_conf.input_conf().job_name(); + attr_vec.emplace_back( + GetBuilder().getNamedAttr("job_name", GetBuilder().getStringAttr(job_name))); + } + // add attrs + state.addAttributes(attr_vec); + // operands + std::vector<::mlir::Value> operand_vec; + operand_vec.emplace_back(entry_block->getArgument(input_count++)); + if (failed(AppendCtrlInOperand(op_conf, operand_vec))) { return failure(); } + state.addOperands(operand_vec); + // result types + llvm::SmallVector out_types; + auto output_lbn = op_conf.name() + "/out"; + out_types.push_back(GetTensorTypeOfLbn(output_lbn)); + if (failed(AppendCtrlOutType(out_types))) { return failure(); } + state.addTypes(out_types); + // create op + auto op = GetBuilder().createOperation(state); + if (!op) { + GetModule()->emitError("fail to create op, name: " + op_conf.name()); + return failure(); + } + // record result + if (op->getNumResults() != 2) { + op->emitError("input op should have two results (out and ctrl_output), but got " + + std::to_string(op->getNumResults()) + "\n"); + return failure(); + } + if (!lbn2result_.emplace(output_lbn, op->getResult(0)).second) { + op->emitError("lbn already exists, lbn: ") << output_lbn; + 
return failure(); + } + if (!op_name2ctrl_result_.emplace(op_conf.name(), op->getResult(1)).second) { + op->emitError("ctrl output already exists, op_name: ") << op_conf.name(); + return failure(); + } + return success(); +} + +LogicalResult JobImporter::ProcessOutputOp(const ::oneflow::OperatorConf& op_conf) { + if (!op_conf.has_output_conf()) { + GetModule().emitError("Not an output op. op name: " + op_conf.name()); + return failure(); + } + + OperationState state(FileLineColLoc::get(GetMLIRContext(), op_conf.name(), 0, 0), + "oneflow.output"); + // attrs + std::vector attr_vec; + if (failed(AddOpConf(op_conf, attr_vec))) { return failure(); } + if (failed(AddDeviceName(op_conf, attr_vec))) { return failure(); } + // attr output_lbns + auto output_lbns_attr = GetBuilder().getStrArrayAttr({op_conf.name() + "/out"}); + attr_vec.emplace_back(GetBuilder().getNamedAttr( + OpTrait::IsImportCompatible::getOutputLBNsAttr(), output_lbns_attr)); + // attr shape + if (op_conf.output_conf().blob_conf().has_shape()) { + auto shape_attr = GetAttrFromShape(op_conf.output_conf().blob_conf().shape()); + attr_vec.emplace_back( + GetBuilder().getNamedAttr(OpTrait::TensorSource::getShapeAttrName(), shape_attr)); + } + // attr data_type + if (op_conf.output_conf().blob_conf().has_data_type()) { + attr_vec.emplace_back(GetBuilder().getNamedAttr( + OpTrait::TensorSource::getDataTypeAttrName(), + GetDataTypeAttr(GetMLIRContext(), op_conf.output_conf().blob_conf().data_type()) + .getValue())); + } + // attr is_dynamic + if (op_conf.output_conf().blob_conf().has_is_dynamic()) { + bool is_dynamic = op_conf.output_conf().blob_conf().is_dynamic(); + attr_vec.emplace_back(GetBuilder().getNamedAttr( + OpTrait::TensorSource::getIsDynamicAttrName(), GetBuilder().getBoolAttr(is_dynamic))); + } + // attr nd_sbp + if (op_conf.output_conf().blob_conf().has_nd_sbp()) { + auto nd_sbp_attr = ConvertNdSbpToAttr(GetBuilder(), op_conf.output_conf().blob_conf().nd_sbp()); + attr_vec.emplace_back( + GetBuilder().getNamedAttr(OpTrait::TensorSource::getNdSbpAttrName(), nd_sbp_attr)); + } + // attr job_name + if (op_conf.output_conf().has_job_name()) { + const std::string& job_name = op_conf.output_conf().job_name(); + attr_vec.emplace_back( + GetBuilder().getNamedAttr("job_name", GetBuilder().getStringAttr(job_name))); + } + // add attrs + state.addAttributes(attr_vec); + // operands + std::vector<::mlir::Value> operand_vec; + auto input_bns_lbns = job_wrapper_.InputBns4OpName(op_conf.name()); + if (input_bns_lbns.second.size() != 1) { + GetModule()->emitError("output op should have only one input, op_name: " + op_conf.name()); + return failure(); + } + if (failed(AppendDataInOperand(input_bns_lbns.second[0], operand_vec))) { return failure(); } + if (failed(AppendCtrlInOperand(op_conf, operand_vec))) { return failure(); } + state.addOperands(operand_vec); + // result types + llvm::SmallVector out_types; + auto output_lbn = op_conf.name() + "/out"; + out_types.push_back(GetTensorTypeOfLbn(output_lbn)); + if (failed(AppendCtrlOutType(out_types))) { return failure(); } + state.addTypes(out_types); + // create op + auto op = GetBuilder().createOperation(state); + if (!op) { + GetModule()->emitError("fail to create op, name: " + op_conf.name()); + return failure(); + } + // record result + if (op->getNumResults() != 2) { + op->emitError("output_conf op should have two results (out and ctrl_output), but got " + + std::to_string(op->getNumResults()) + "\n"); + return failure(); + } + if (!lbn2result_.emplace(output_lbn, 
op->getResult(0)).second) { + op->emitError("lbn already exists, lbn: ") << output_lbn; + return failure(); + } + if (!op_name2ctrl_result_.emplace(op_conf.name(), op->getResult(1)).second) { + op->emitError("ctrl output already exists, op_name: ") << op_conf.name(); + return failure(); + } + return success(); +} +LogicalResult JobImporter::ProcessJob() { + llvm::SmallVector input_types; + llvm::SmallVector result_types; + llvm::SmallVector results; bool is_succeeded = true; + + job_wrapper_.TopoForEachOpConf([&](const ::oneflow::OperatorConf* op_conf) { + if (op_conf->has_input_conf()) { + auto type = GetInterfaceBlobConfType(op_conf->input_conf().blob_conf()); + if (type) { + input_types.emplace_back(type); + } else { + is_succeeded = false; + } + } + }); + if (!is_succeeded) { return failure(); } + + auto func_type = GetBuilder().getFunctionType(input_types, llvm::None); + auto job_op = + GetBuilder().create(GetRootLocation(), job_->job_conf().job_name(), func_type); + auto* entryBlock = job_op.addEntryBlock(); + GetBuilder().setInsertionPointToStart(entryBlock); + + is_succeeded = true; + size_t input_count = 0; job_wrapper_.TopoForEachOpConf([&](const ::oneflow::OperatorConf* op_conf) { - const auto op = *op_conf; if (is_succeeded == false) { return; } - if (op.has_user_conf()) { - is_succeeded = succeeded(ProcessUserOp(op)); + if (op_conf->has_user_conf()) { + is_succeeded = succeeded(ProcessUserOp(*op_conf)); + } else if (op_conf->has_input_conf()) { + is_succeeded = succeeded(ProcessInputOp(*op_conf, entryBlock, input_count)); + } else if (op_conf->has_output_conf()) { + is_succeeded = succeeded(ProcessOutputOp(*op_conf)); + if (is_succeeded) { + auto result = entryBlock->back().getResult(0); + results.emplace_back(result); + result_types.emplace_back(result.getType()); + } } else { - is_succeeded = succeeded(ProcessSystemOp(op)); + is_succeeded = succeeded(ProcessSystemOp(*op_conf)); } }); if (is_succeeded == false) { return failure(); } - ReturnOp returnOp; - if (!entryBlock.empty()) { returnOp = dyn_cast(entryBlock.back()); } - if (!returnOp) { GetBuilder().create(GetRootLocation()); } - GetModule().push_back(function); + mlir::oneflow::ReturnOp return_op; + if (!entryBlock->empty()) { return_op = dyn_cast(entryBlock->back()); } + if (!return_op) { GetBuilder().create(GetRootLocation(), results); } + + func_type = GetBuilder().getFunctionType(input_types, result_types); + job_op.getOperation()->setAttr(oneflow::Job::getTypeAttrName(), TypeAttr::get(func_type)); + GetModule().push_back(job_op); return success(); } @@ -272,64 +590,151 @@ LogicalResult JobImporter::TryToUpdateJob() { new_job.CopyFrom(*job_); new_job.clear_net(); new_job.mutable_placement()->clear_placement_group(); - auto convertOps = [&](Operation* op) { - if (llvm::dyn_cast(op)) { - oneflow::SystemOpAdaptor system_op_adaptor(op->getOperands(), op->getAttrDictionary()); - UpdatePlacement(op, system_op_adaptor, new_job); - auto op_name = system_op_adaptor.op_name().getValue().str(); - ::oneflow::OperatorConf op_conf = job_wrapper_.OpConf4OpName(op_name); - for (const auto& ibn : llvm::enumerate(op->getAttrOfType("input_bns"))) { - auto result = GetDataInputOperands(op)[ibn.index()].dyn_cast(); - std::string new_val = GetOutputLbn(result).getValue(); - job_wrapper_.ReplaceInputLbnInOpCustomizedConf( - &op_conf, ibn.value().dyn_cast().getValue().str(), new_val); - } - if (succeeded(ConvertCtrlInputs(op, op_conf))) { - *(new_job.mutable_net()->add_op()) = op_conf; - } else { + + Operation* job_op = nullptr; + 
llvm::SmallVector outputs; + + auto find_first_job = [&](oneflow::Job job) -> WalkResult { + job_op = job.getOperation(); + new_job.mutable_job_conf()->set_job_name(job.sym_name().str()); + return WalkResult::interrupt(); + }; + + GetModule().getOperation()->walk(find_first_job); + if (!job_op) { + GetModule()->emitError("job not found. module op: ") << *GetModule(); + return failure(); + } + + auto ConvertOp = [&](Operation* op) -> WalkResult { + if (op->hasTrait()) { + if (llvm::dyn_cast(op)) { + op->emitError("expected concrete UserOp, but got generic UserOp: ") << *op; return WalkResult::interrupt(); - } - } else if (llvm::dyn_cast(op) || llvm::dyn_cast(op) - || llvm::dyn_cast(op)) { - return WalkResult::advance(); - } else { - oneflow::UserOpAdaptor user_op_adaptor(op->getOperands(), op->getAttrDictionary()); - UpdatePlacement(op, user_op_adaptor, new_job); - ::oneflow::OperatorConf op_conf; - const std::string op_name = user_op_adaptor.op_name().getValue().str(); - auto user_conf = op_conf.mutable_user_conf(); - if (succeeded(ConvertUserOpInputs(op, user_op_adaptor, user_conf)) - && succeeded(ConvertUserOpOutputs(op, user_op_adaptor, user_conf)) - && succeeded(ConvertUserOpAttributes(op, user_op_adaptor, op_conf)) - && succeeded(ConvertCtrlInputs(op, op_conf))) { - *(new_job.mutable_net()->add_op()) = op_conf; + } else if (llvm::dyn_cast(op)) { + if (failed(ConvertSystemOp(op, new_job))) { + op->emitError("failed to process SystemOp: ") << *op; + return WalkResult::interrupt(); + } + } else if (llvm::dyn_cast(op)) { + if (failed(ConvertVariableOp(op, new_job))) { + op->emitError("failed to process VariableOp: ") << *op; + return WalkResult::interrupt(); + } + } else if (llvm::dyn_cast(op) || llvm::dyn_cast(op)) { + // do nothing and advance } else { - return WalkResult::interrupt(); + if (!dyn_cast(op)) { + op->emitError("op is not UserOpCompatible ") << *op; + return WalkResult::interrupt(); + } + if (failed(ConvertUserOp(op, new_job))) { + op->emitError("failed to process UserOp: ") << *op; + return WalkResult::interrupt(); + } } - } /* convert op conf */ + } else if (llvm::dyn_cast(op)) { + // do nothing and advance + } else if (auto return_op = llvm::dyn_cast(op)) { + for (auto operand : return_op->getOperands()) { outputs.emplace_back(operand); } + } else { + op->emitError("unexpected op: ") << *op; + return WalkResult::interrupt(); + } return WalkResult::advance(); }; - SymbolTable symbol_table(GetModule()); - if (symbol_table.lookup(job_wrapper_.job()->job_conf().job_name()) - ->walk(convertOps) - .wasInterrupted()) { - return failure(); - } else { - job_wrapper_.UpdateJob(&new_job); + if (job_op->walk(ConvertOp).wasInterrupted()) { return failure(); } + + // add input op + auto arguments = llvm::dyn_cast(job_op).body().front().getArguments(); + for (BlockArgument argument : arguments) { + for (auto& use : argument.getUses()) { + Operation* owner = use.getOwner(); + if (!dyn_cast(owner)) { return failure(); } + if (failed(ConvertInputOp(owner, new_job))) { return failure(); } + } + } + // add output op + for (auto output : outputs) { + Operation* owner = output.getDefiningOp(); + if (!dyn_cast(owner)) { return failure(); } + if (failed(ConvertOutputOp(owner, new_job))) { return failure(); } + } + + job_wrapper_.UpdateJob(&new_job); + return success(); +} + +LogicalResult JobImporter::ConvertUserOp(Operation* op, ::oneflow::Job& job) { + // TODO: concrete user op should not use generic UserOpAdaptor + oneflow::UserOpAdaptor user_op_adaptor(op->getOperands(), 
op->getAttrDictionary()); + UpdatePlacement(op, user_op_adaptor, job); + + auto* op_conf = job.mutable_net()->add_op(); + auto* user_conf = op_conf->mutable_user_conf(); + if (!succeeded(ConvertUserOpInputs(op, user_op_adaptor, user_conf))) { return failure(); } + if (!succeeded(ConvertUserOpOutputs(op, user_op_adaptor, user_conf))) { return failure(); } + if (!succeeded(ConvertUserOpAttributes(op, user_op_adaptor, *op_conf))) { return failure(); } + if (!succeeded(ConvertCtrlInputs(op, *op_conf))) { return failure(); } + return success(); +} + +LogicalResult JobImporter::ConvertSystemOp(Operation* op, ::oneflow::Job& job) { + oneflow::SystemOpAdaptor system_op_adaptor(op->getOperands(), op->getAttrDictionary()); + UpdatePlacement(op, system_op_adaptor, job); + auto op_name = system_op_adaptor.op_name().getValue().str(); + ::oneflow::OperatorConf op_conf = job_wrapper_.OpConf4OpName(op_name); + for (const auto& ibn : llvm::enumerate(op->getAttrOfType("input_bns"))) { + auto result = GetDataInputOperands(op)[ibn.index()].dyn_cast(); + std::string new_val = GetOutputLbn(result).getValue(); + job_wrapper_.ReplaceInputLbnInOpCustomizedConf( + &op_conf, ibn.value().dyn_cast().getValue().str(), new_val); } + if (failed(ConvertCtrlInputs(op, op_conf))) { return failure(); } + *(job.mutable_net()->add_op()) = op_conf; return success(); } +LogicalResult JobImporter::ConvertVariableOp(Operation* op, ::oneflow::Job& job) { + oneflow::VariableOpAdaptor op_adaptor(op->getOperands(), op->getAttrDictionary()); + UpdatePlacement(op, op_adaptor, job); + auto* op_conf = job.mutable_net()->add_op(); + return ConvertVariableOpConf(op, op_adaptor, op_conf); +} + +LogicalResult JobImporter::ConvertInputOp(Operation* op, ::oneflow::Job& job) { + oneflow::InputOpAdaptor op_adaptor(op->getOperands(), op->getAttrDictionary()); + UpdatePlacement(op, op_adaptor, job); + auto* op_conf = job.mutable_net()->add_op(); + return ConvertInputOpConf(op, op_adaptor, op_conf); +} + +LogicalResult JobImporter::ConvertOutputOp(Operation* op, ::oneflow::Job& job) { + oneflow::OutputOpAdaptor op_adaptor(op->getOperands(), op->getAttrDictionary()); + UpdatePlacement(op, op_adaptor, job); + auto* op_conf = job.mutable_net()->add_op(); + return ConvertOutputOpConf(op, op_adaptor, op_conf); +} + +Type JobImporter::GetInterfaceBlobConfType(const ::oneflow::InterfaceBlobConf& blob_conf) { + if (!blob_conf.has_data_type()) { return Type{}; } + if (!blob_conf.has_shape()) { return Type{}; } + auto data_type = GetTypeFromOneFlowDataType(blob_conf.data_type()); + if (!data_type.hasValue()) { return Type{}; } + return RankedTensorType::get({blob_conf.shape().dim().begin(), blob_conf.shape().dim().end()}, + *data_type); +} + LogicalResult ApplyRoundTripPatterns(RoundTripOneFlowJobWrapperInterface& job_wrapper, MLIRContext* context, OwningModuleRef& module) { mlir::PassManager pm(context); - pm.addNestedPass(::mlir::createCanonicalizerPass()); + pm.addPass(createCanonicalizerPass()); std::string graphviz; if (job_wrapper.IsLastIRPass() && std::getenv("ONEFLOW_MLIR_ENABLE_CODEGEN_FUSERS") != nullptr) { pm.addPass(oneflow::createOutlineJitFunctionPass()); } - pm.addNestedPass(oneflow::createFuseIntoExistingOpPass()); - pm.addNestedPass(::mlir::createCanonicalizerPass()); + pm.addPass(oneflow::createFuseIntoExistingOpPass()); + pm.addPass(createCanonicalizerPass()); llvm::raw_string_ostream os_graphviz(graphviz); pm.addPass(createPrintOpGraphPass(os_graphviz)); if (mlir::failed(pm.run(*module))) { @@ -378,13 +783,59 @@ void RoundTripOneFlowJob( 
<< job->job_conf().job_name() << "\n"; exit(EXIT_FAILURE); } - } else { llvm::errs() << "fail to convert job to IR, job_name: " << job->job_conf().job_name() << "\n"; exit(EXIT_FAILURE); } } +void SaveJobToIR(RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::string& path) { + const ::oneflow::Job* job = job_wrapper.job(); + mlir::MLIRContext context; + context.getOrLoadDialect(); + context.loadDialect(); + + OwningModuleRef module( + ModuleOp::create(FileLineColLoc::get(&context, "", /*line=*/0, /*column=*/0))); + JobImporter imp(job_wrapper, &context, module.get()); + if (succeeded(imp.ProcessJob())) { + mlir::PassManager pm(&context); + pm.addPass(createCanonicalizerPass()); + if (mlir::failed(pm.run(*module))) { + module->emitError("Failed to run canonicalizer pass"); + exit(EXIT_FAILURE); + } + + std::string mlir; + llvm::raw_string_ostream os_mlir(mlir); + module->print(os_mlir); + std::string filename = path + "/model.mlir"; + std::ofstream fs(filename, std::ios::trunc); + if (!fs.is_open()) { + llvm::errs() << "fail to open file " << filename; + exit(EXIT_FAILURE); + } + fs << mlir; + fs.close(); + } else { + const auto& job_name = job->job_conf().job_name(); + llvm::errs() << "fail to convert job to IR, job_name: " << job_name << "\n"; + exit(EXIT_FAILURE); + } +} + +void LoadJobFromIR(RoundTripOneFlowJobWrapperInterface& job_wrapper, const std::string& path) { + MLIRContext context; + context.getOrLoadDialect(); + context.loadDialect(); + OwningModuleRef module = parseSourceFile(path, &context); + JobImporter imp(job_wrapper, &context, module.get()); + if (failed(imp.TryToUpdateJob())) { + llvm::errs() << "fail to load job from IR"; + exit(EXIT_FAILURE); + } +} + void registerFromOneFlowJobTranslation() { TranslateToMLIRRegistration fromOneFlowJob("import-oneflow-job", [](llvm::StringRef str, MLIRContext* context) { diff --git a/oneflow/ir/test/CMakeLists.txt b/oneflow/ir/test/CMakeLists.txt index 4a9fc5f1418..6270519bfd7 100644 --- a/oneflow/ir/test/CMakeLists.txt +++ b/oneflow/ir/test/CMakeLists.txt @@ -1,3 +1,5 @@ +message(STATUS "LLVM_TOOLS_BINARY_DIR (used as LLVM_TOOLS_DIR): ${LLVM_TOOLS_BINARY_DIR}") +message(STATUS "LLVM_EXTERNAL_LIT: ${LLVM_EXTERNAL_LIT}") configure_lit_site_cfg( ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in ${CMAKE_CURRENT_BINARY_DIR}/lit.site.cfg.py @@ -5,16 +7,13 @@ configure_lit_site_cfg( ${CMAKE_CURRENT_SOURCE_DIR}/lit.cfg.py ) -set(LLVM_EXTERNAL_LIT "${llvm_monorepo_BINARY_DIR}/bin/llvm-lit" CACHE STRING "") -message(STATUS "LLVM_EXTERNAL_LIT: ${LLVM_EXTERNAL_LIT}") - set(ONEFLOW_TEST_DEPENDS FileCheck count not oneflow-opt oneflow-translate ) -add_lit_testsuite(check-oneflow "Running the OneFlow MLIR regression tests" +add_lit_testsuite(check-oneflow "Running the OneFlow MLIR regression tests from: ${CMAKE_CURRENT_SOURCE_DIR}" ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${ONEFLOW_TEST_DEPENDS} ) diff --git a/oneflow/ir/test/OneFlow/jit-outline-func.mlir b/oneflow/ir/test/OneFlow/jit-outline-func.mlir index 4ee0819a6f5..95afd283a26 100644 --- a/oneflow/ir/test/OneFlow/jit-outline-func.mlir +++ b/oneflow/ir/test/OneFlow/jit-outline-func.mlir @@ -1,13 +1,13 @@ // RUN: oneflow-opt -outline-jit-function %s | FileCheck %s builtin.module { - builtin.func @FuseCastScaleJob() { + "oneflow.job" () ({ %data_output = "oneflow.system"() {device_name = ["@0:0"], device_tag = "gpu", hierarchy = [1], input_bns = [], op_name = "Input_0", op_type_case = 137 : i32, operand_segment_sizes = dense<0> : vector<2xi32>, output_lbns = ["Input_0/out"], 
result_segment_sizes = dense<[1, 0]> : vector<2xi32>, scope_symbol_id = 4611686018427432958 : i64} : () -> tensor<96x96xi64> %data_output_0 = "oneflow.system"() {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], input_bns = [], op_name = "scale", op_type_case = 122 : i32, operand_segment_sizes = dense<0> : vector<2xi32>, output_lbns = ["scale/out"], result_segment_sizes = dense<[1, 0]> : vector<2xi32>, scope_symbol_id = 4611686018427437054 : i64} : () -> tensor<1xf32> %0 = "oneflow.cast"(%data_output) {device_name = ["@0:0"], device_tag = "cpu", dtype = 2 : i32, hierarchy = [1], op_name = "Cast_1", scope_symbol_id = 4611686018427437054 : i64} : (tensor<96x96xi64>) -> tensor<96x96xf32> "oneflow.system"(%data_output_0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], input_bns = ["in"], op_name = "Return_4", op_type_case = 146 : i32, operand_segment_sizes = dense<[1, 0]> : vector<2xi32>, output_lbns = [], result_segment_sizes = dense<0> : vector<2xi32>, scope_symbol_id = 4611686018427445246 : i64} : (tensor<1xf32>) -> () %1 = "oneflow.scalar_mul_by_tensor"(%0, %data_output_0) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], op_name = "ScalarMulByTensor_2", scope_symbol_id = 4611686018427437054 : i64} : (tensor<96x96xf32>, tensor<1xf32>) -> tensor<96x96xf32> "oneflow.system"(%1) {device_name = ["@0:0"], device_tag = "cpu", hierarchy = [1], input_bns = ["in"], op_name = "Return_3", op_type_case = 146 : i32, operand_segment_sizes = dense<[1, 0]> : vector<2xi32>, output_lbns = [], result_segment_sizes = dense<0> : vector<2xi32>, scope_symbol_id = 4611686018427445246 : i64} : (tensor<96x96xf32>) -> () - return - } + oneflow.return + }) {sym_name = "FuseCastScaleJob", type = () -> ()} : () -> () } // CHECK: %0 = oneflow.mlir_jit diff --git a/oneflow/ir/test/OneFlow/networks/__init__.py b/oneflow/ir/test/OneFlow/networks/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/oneflow/ir/test/OneFlow/networks/resnet50.py b/oneflow/ir/test/OneFlow/networks/resnet50.py new file mode 100644 index 00000000000..90fc3ef110c --- /dev/null +++ b/oneflow/ir/test/OneFlow/networks/resnet50.py @@ -0,0 +1,293 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import oneflow as flow +import oneflow.nn as nn +from oneflow import Tensor +from typing import Type, Any, Callable, Union, List, Optional + + +def conv3x3( + in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1 +) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU() + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU() + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + 
norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) + self.avgpool = nn.AvgPool2d((7, 7)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = flow.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + return model + + +def resnet50(**kwargs: Any) -> ResNet: + 
r"""ResNet-5 + `"Deep Residual Learning for Image Recognition" `_. + """ + return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], **kwargs) diff --git a/oneflow/ir/test/OneFlow/test_fuse_cast_scale.py b/oneflow/ir/test/OneFlow/test_fuse_cast_scale.py index 71d3c8a49ad..895407ef068 100644 --- a/oneflow/ir/test/OneFlow/test_fuse_cast_scale.py +++ b/oneflow/ir/test/OneFlow/test_fuse_cast_scale.py @@ -80,7 +80,7 @@ def FuseCastScaleJob( test_case.assertTrue(np.allclose(loss, x * scale)) -# CHECK: %0 = oneflow.mlir_jit +# CHECK: oneflow.mlir_jit if __name__ == "__main__": unittest.main() diff --git a/oneflow/ir/test/OneFlow/test_fuse_tril_scale.py b/oneflow/ir/test/OneFlow/test_fuse_tril_scale.py index e327ada9e19..5c507547edb 100644 --- a/oneflow/ir/test/OneFlow/test_fuse_tril_scale.py +++ b/oneflow/ir/test/OneFlow/test_fuse_tril_scale.py @@ -62,23 +62,23 @@ def FuseTrilScaleJob( # cpu -# CHECK-LABEL: @FuseTrilScaleJob -# CHECK-LABEL: @FuseTrilScaleJob -# CHECK-LABEL: @FuseTrilScaleJob -# CHECK-LABEL: @FuseTrilScaleJob +# CHECK-LABEL: oneflow.job +# CHECK-LABEL: oneflow.job +# CHECK-LABEL: oneflow.job +# CHECK-LABEL: oneflow.job # gpu -# CHECK-LABEL: @FuseTrilScaleJob +# CHECK-LABEL: oneflow.job # CHECK: %0 = "oneflow.fused_scale_tril" # CHECK: %1 = "oneflow.fused_scale_tril" -# CHECK-LABEL: FuseTrilScaleJob +# CHECK-LABEL: oneflow.job # CHECK: %0 = "oneflow.fused_scale_tril" # CHECK: %1 = "oneflow.fused_scale_tril" -# CHECK-LABEL: FuseTrilScaleJob +# CHECK-LABEL: oneflow.job # CHECK: %0 = "oneflow.fused_scale_tril" # CHECK: %1 = "oneflow.fused_scale_tril" -# CHECK-LABEL: FuseTrilScaleJob +# CHECK-LABEL: oneflow.job # CHECK: %0 = "oneflow.fused_scale_tril" # CHECK: %1 = "oneflow.fused_scale_tril" diff --git a/oneflow/ir/test/OneFlow/test_graph_save_and_load.py b/oneflow/ir/test/OneFlow/test_graph_save_and_load.py new file mode 100644 index 00000000000..b19d5cc25c4 --- /dev/null +++ b/oneflow/ir/test/OneFlow/test_graph_save_and_load.py @@ -0,0 +1,96 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +# RUN: python3 %s + +import os +import sys + +sys.path.append(os.path.abspath(os.path.dirname(__file__))) + +import unittest +import oneflow as flow +import oneflow.unittest +from oneflow.core.job import job_pb2 as job_pb + +from networks.resnet50 import resnet50 + + +class InferGraph(flow.nn.Graph): + def __init__(self, placement_arg=None): + super().__init__() + model = resnet50() + if placement_arg is not None: + if "placement" in placement_arg: + model.to_consistent(**placement_arg) + else: + model.to(**placement_arg) + self.model = model + + def build(self, image): + logits = self.model(image.to("cuda")) + pred = logits.softmax() + return pred + + +@unittest.skipIf(not flow.sysconfig.with_mlir(), "only test with mlir") +@flow.unittest.skip_unless_1n1d() +class GraphSaveTestCase(flow.unittest.TestCase): + def test_save_and_load(self): + placement_arg = { + "placement": flow.placement("cuda", {0: [0]}), + "sbp": flow.sbp.broadcast, + } + graph = InferGraph(placement_arg) + image_placeholder = flow.empty( + (1, 3, 224, 224), + dtype=flow.float32, + placement=flow.placement("cpu", {0: [0]}), + sbp=flow.sbp.broadcast, + ) + graph._compile(image_placeholder) + saved_path = os.path.join("saved_model", graph.name) + if not os.path.exists(saved_path): + os.makedirs(saved_path) + flow.save(graph, saved_path) + + saved_ir_path = os.path.join(saved_path, "model.mlir") + serialized_job = oneflow._oneflow_internal.nn.graph.LoadSerializedJobFromIR( + saved_ir_path + ) + job = job_pb.Job() + job.ParseFromString(serialized_job) + + op_list = [] + op_list_ = [] + + for op in job.net.op: + op_list.append(op) + + for op in graph._forward_job_proto.net.op: + op_list_.append(op) + + def sort_by_op_name(op): + return op.name + + op_list.sort(key=sort_by_op_name) + op_list_.sort(key=sort_by_op_name) + + for (op, op_) in zip(op_list, op_list_): + self.assertTrue(op == op_) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/ir/test/OneFlow/test_mlir_opt.mlir b/oneflow/ir/test/OneFlow/test_mlir_opt.mlir deleted file mode 100644 index 0c81b194267..00000000000 --- a/oneflow/ir/test/OneFlow/test_mlir_opt.mlir +++ /dev/null @@ -1,30 +0,0 @@ -// RUN: python3 %s.py | FileCheck %s -module { - func @IdempotentJob() { - %data_output = "oneflow.system"() {device_name = ["0:0-0"], device_tag = "gpu", hierarchy = [1], input_bns = [], op_name = "Input_0", op_type_case = 137 : i32, operand_segment_sizes = dense<0> : vector<2xi32>, output_lbns = ["Input_0/out"], result_segment_sizes = dense<[1, 0]> : vector<2xi32>, scope_symbol_id = 4611686018427420670 : i64} : () -> tensor<96x96xf32> - // CHECK: %data_output = "oneflow.system"() - %0 = "oneflow.relu"(%data_output) {device_name = ["0:0-0"], device_tag = "gpu", hierarchy = [1], op_name = "Relu_2", op_type_name = "relu", scope_symbol_id = 4611686018427420670 : i64} : (tensor<96x96xf32>) -> tensor<96x96xf32> - // CHECK: %0 = "oneflow.relu"(%data_output) - %1 = "oneflow.relu"(%data_output) {device_name = ["0:0-0"], device_tag = "gpu", hierarchy = [1], op_name = "Relu_1", op_type_name = "relu", scope_symbol_id = 4611686018427420670 : i64} : (tensor<96x96xf32>) -> tensor<96x96xf32> - // CHECK: %1 = "oneflow.relu"(%data_output) - "oneflow.system"(%1) {device_name = ["0:0-0"], device_tag = "cpu", hierarchy = [1], input_bns = ["in"], op_name = "Return_10", op_type_case = 146 : i32, operand_segment_sizes = dense<[1, 0]> : vector<2xi32>, output_lbns = [], result_segment_sizes = dense<0> : vector<2xi32>, scope_symbol_id = 4611686018427432958 : i64} : 
(tensor<96x96xf32>) -> () - // CHECK: "oneflow.system"(%1) - "oneflow.system"(%0) {device_name = ["0:0-0"], device_tag = "cpu", hierarchy = [1], input_bns = ["in"], op_name = "Return_11", op_type_case = 146 : i32, operand_segment_sizes = dense<[1, 0]> : vector<2xi32>, output_lbns = [], result_segment_sizes = dense<0> : vector<2xi32>, scope_symbol_id = 4611686018427432958 : i64} : (tensor<96x96xf32>) -> () - // CHECK: "oneflow.system"(%0) - return - // CHECK: return - } -} - -module { - func @InvolutionJob() { - %data_output = "oneflow.system"() {device_name = ["0:0-0"], device_tag = "gpu", hierarchy = [1], input_bns = [], op_name = "Input_140", op_type_case = 137 : i32, operand_segment_sizes = dense<0> : vector<2xi32>, output_lbns = ["Input_140/out"], result_segment_sizes = dense<[1, 0]> : vector<2xi32>, scope_symbol_id = 4611686018427441150 : i64} : () -> tensor<96x96xf32> - // CHECK: %data_output = "oneflow.system"() - "oneflow.system"(%data_output) {device_name = ["0:0-0"], device_tag = "cpu", hierarchy = [1], input_bns = ["in"], op_name = "Return_153", op_type_case = 146 : i32, operand_segment_sizes = dense<[1, 0]> : vector<2xi32>, output_lbns = [], result_segment_sizes = dense<0> : vector<2xi32>, scope_symbol_id = 4611686018427445246 : i64} : (tensor<96x96xf32>) -> () - // CHECK: "oneflow.system"(%data_output) - "oneflow.system"(%data_output) {device_name = ["0:0-0"], device_tag = "cpu", hierarchy = [1], input_bns = ["in"], op_name = "Return_154", op_type_case = 146 : i32, operand_segment_sizes = dense<[1, 0]> : vector<2xi32>, output_lbns = [], result_segment_sizes = dense<0> : vector<2xi32>, scope_symbol_id = 4611686018427445246 : i64} : (tensor<96x96xf32>) -> () - // CHECK: "oneflow.system"(%data_output) - return - // CHECK: return - } -} diff --git a/oneflow/ir/test/OneFlow/test_mlir_opt.mlir.py b/oneflow/ir/test/OneFlow/test_mlir_opt.py similarity index 85% rename from oneflow/ir/test/OneFlow/test_mlir_opt.mlir.py rename to oneflow/ir/test/OneFlow/test_mlir_opt.py index 6becbab67d7..007dd19b9f4 100644 --- a/oneflow/ir/test/OneFlow/test_mlir_opt.mlir.py +++ b/oneflow/ir/test/OneFlow/test_mlir_opt.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +# RUN: python3 %s | FileCheck %s import unittest import numpy as np import oneflow.compatible.single_client as flow @@ -74,3 +75,15 @@ def InvolutionJob( if __name__ == "__main__": unittest.main() + +# CHECK: [[RESULT_1:%.*]] = "oneflow.input"(%arg0) +# CHECK: %0 = "oneflow.relu"([[RESULT_1:%.*]]) +# CHECK: %1 = "oneflow.relu"([[RESULT_1:%.*]]) +# CHECK: "oneflow.system"(%1) +# CHECK: "oneflow.system"(%0) +# CHECK: oneflow.return + +# CHECK: [[RESULT_1:%.*]] = "oneflow.input"(%arg0) +# CHECK: "oneflow.system"([[RESULT_1:%.*]]) +# CHECK: "oneflow.system"([[RESULT_1:%.*]]) +# CHECK: oneflow.return diff --git a/oneflow/ir/test/lit.cfg.py b/oneflow/ir/test/lit.cfg.py index 6d4c8c36a84..defb82d5d9a 100644 --- a/oneflow/ir/test/lit.cfg.py +++ b/oneflow/ir/test/lit.cfg.py @@ -72,7 +72,7 @@ # test_exec_root: The root path where tests should be run. config.test_exec_root = os.path.join(config.oneflow_obj_root, "test") -config.oneflow_tools_dir = os.path.join(config.oneflow_obj_root, "bin") +config.oneflow_tools_dir = os.path.join(config.oneflow_ir_obj_root, "bin") # Tweak the PATH to include the tools dir. 
llvm_config.with_environment("PATH", config.llvm_tools_dir, append_path=True) @@ -80,9 +80,7 @@ llvm_config.with_environment("ONEFLOW_MLIR_ENABLE_CODEGEN_FUSERS", "1") llvm_config.with_environment("ONEFLOW_MLIR_ENABLE_ROUND_TRIP", "1") llvm_config.with_environment( - "PYTHONPATH", - os.path.join(config.test_source_root, "../../../python"), - append_path=True, + "PYTHONPATH", os.path.join(config.oneflow_src_root, "python"), append_path=True, ) tool_dirs = [config.oneflow_tools_dir, config.llvm_tools_dir] diff --git a/oneflow/ir/test/lit.site.cfg.py.in b/oneflow/ir/test/lit.site.cfg.py.in index c5bc5f35eaa..fd83149c3d4 100644 --- a/oneflow/ir/test/lit.site.cfg.py.in +++ b/oneflow/ir/test/lit.site.cfg.py.in @@ -30,6 +30,7 @@ config.llvm_host_triple = '@LLVM_HOST_TRIPLE@' config.host_arch = "@HOST_ARCH@" config.oneflow_src_root = "@CMAKE_SOURCE_DIR@" config.oneflow_obj_root = "@CMAKE_BINARY_DIR@" +config.oneflow_ir_obj_root = "@PROJECT_BINARY_DIR@" config.WITH_MLIR_CUDA_CODEGEN = "@WITH_MLIR_CUDA_CODEGEN@" # Support substitution of the tools_dir with user parameters. This is diff --git a/oneflow/user/data/coco_data_reader.h b/oneflow/user/data/coco_data_reader.h index 659fa315460..98d9e361a03 100644 --- a/oneflow/user/data/coco_data_reader.h +++ b/oneflow/user/data/coco_data_reader.h @@ -19,7 +19,7 @@ limitations under the License. #include "oneflow/user/data/data_reader.h" #include "oneflow/user/data/coco_parser.h" #include "oneflow/core/common/str_util.h" -#include +#include "nlohmann/json.hpp" namespace oneflow { namespace data { diff --git a/oneflow/user/kernels/add_n_kernel.cpp b/oneflow/user/kernels/add_n_kernel.cpp index 6b1defe8b95..1e6063e1bf7 100644 --- a/oneflow/user/kernels/add_n_kernel.cpp +++ b/oneflow/user/kernels/add_n_kernel.cpp @@ -16,7 +16,7 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/core/ep/include/primitive/add.h" #include "oneflow/core/kernel/cuda_graph_support.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { @@ -39,7 +39,7 @@ class AddNKernel : public OpKernel, public CudaGraphSupport { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: - void Compute(KernelComputeContext* ctx, OpKernelState* state) const override { + void Compute(KernelComputeContext* ctx) const override { auto primitive = NewAddPrimitive(ctx); CHECK(primitive); Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); diff --git a/oneflow/user/kernels/arange_kernel.cpp b/oneflow/user/kernels/arange_kernel.cpp index 999152fb01f..60944e4dd72 100644 --- a/oneflow/user/kernels/arange_kernel.cpp +++ b/oneflow/user/kernels/arange_kernel.cpp @@ -50,6 +50,7 @@ class ArangeKernel final : public OpKernel { delta = static_cast(float_delta); limit = static_cast(float_limit); } + if (arange_elem_cnt == 0) { return; } ArangeFunctor()(ctx->stream(), start, delta, arange_elem_cnt, output); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } diff --git a/oneflow/user/kernels/avg_pooling_kernel.cpp b/oneflow/user/kernels/avg_pooling_kernel.cpp index fee02882f7a..3ae5896bcce 100644 --- a/oneflow/user/kernels/avg_pooling_kernel.cpp +++ b/oneflow/user/kernels/avg_pooling_kernel.cpp @@ -17,14 +17,14 @@ limitations under the License. 
namespace oneflow { -struct AvgPoolingOpKernelState final : public user_op::OpKernelState { +struct AvgPoolingOpKernelCache final : public user_op::OpKernelCache { AvgPoolingParams3D params_3d; - AvgPoolingOpKernelState(AvgPoolingParams3D params_3d) : params_3d(params_3d) {} - const AvgPoolingParams3D& GetParams3D() { return params_3d; } + explicit AvgPoolingOpKernelCache(const AvgPoolingParams3D& params_3d) : params_3d(params_3d) {} + const AvgPoolingParams3D& GetParams3D() const { return params_3d; } }; -std::shared_ptr DoCreateAvgOpKernelState( - user_op::KernelComputeContext* ctx, const int32_t& dim) { +std::shared_ptr CreateAvgOpKernelCache(user_op::KernelCacheContext* ctx, + const int32_t& dim) { const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape(); const std::string& data_format = ctx->Attr("data_format"); const std::vector& padding = ctx->Attr>("padding"); @@ -37,8 +37,8 @@ std::shared_ptr DoCreateAvgOpKernelState( AvgPoolingParams3D params_3d = AvgPoolingParams3D(dim, x_shape, data_format, padding, kernel_size, stride, ceil_mode, count_include_pad, divisor_override); - std::shared_ptr state(new AvgPoolingOpKernelState(params_3d)); - return state; + std::shared_ptr cache(new AvgPoolingOpKernelCache(params_3d)); + return cache; } template @@ -126,12 +126,18 @@ class AvgPool1dKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateAvgOpKernelCache(ctx, 1); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); - const auto& pooling_state = DoCreateAvgOpKernelState(ctx, 1); - const AvgPoolingParams3D& params_3d = pooling_state->GetParams3D(); + const auto* pooling_cache = dynamic_cast(cache); + const AvgPoolingParams3D& params_3d = pooling_cache->GetParams3D(); const int64_t elem_num = y->shape().elem_cnt(); const T* src = x->dptr(); @@ -153,12 +159,18 @@ class AvgPool1dGradKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateAvgOpKernelCache(ctx, 1); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); - const auto& pooling_state = DoCreateAvgOpKernelState(ctx, 1); - const AvgPoolingParams3D& params_3d = pooling_state->GetParams3D(); + const auto* pooling_cache = dynamic_cast(cache); + const AvgPoolingParams3D& params_3d = pooling_cache->GetParams3D(); const int64_t elem_num = dy->shape().elem_cnt(); const T* src = dy->dptr(); @@ -182,12 +194,18 @@ class AvgPool2dKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateAvgOpKernelCache(ctx, 2); + 
} + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); - const auto& pooling_state = DoCreateAvgOpKernelState(ctx, 2); - const AvgPoolingParams3D& params_3d = pooling_state->GetParams3D(); + const auto* pooling_cache = dynamic_cast(cache); + const AvgPoolingParams3D& params_3d = pooling_cache->GetParams3D(); const int64_t elem_num = y->shape().elem_cnt(); const T* src = x->dptr(); @@ -209,12 +227,18 @@ class AvgPool2dGradKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateAvgOpKernelCache(ctx, 2); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); - const auto& pooling_state = DoCreateAvgOpKernelState(ctx, 2); - const AvgPoolingParams3D& params_3d = pooling_state->GetParams3D(); + const auto* pooling_cache = dynamic_cast(cache); + const AvgPoolingParams3D& params_3d = pooling_cache->GetParams3D(); const int64_t elem_num = dy->shape().elem_cnt(); const T* src = dy->dptr(); @@ -238,12 +262,18 @@ class AvgPool3dKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateAvgOpKernelCache(ctx, 3); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); - const auto& pooling_state = DoCreateAvgOpKernelState(ctx, 3); - const AvgPoolingParams3D& params_3d = pooling_state->GetParams3D(); + const auto* pooling_cache = dynamic_cast(cache); + const AvgPoolingParams3D& params_3d = pooling_cache->GetParams3D(); const int64_t elem_num = y->shape().elem_cnt(); const T* src = x->dptr(); @@ -265,12 +295,18 @@ class AvgPool3dGradKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateAvgOpKernelCache(ctx, 3); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); - const auto& pooling_state = DoCreateAvgOpKernelState(ctx, 3); - const AvgPoolingParams3D& params_3d = pooling_state->GetParams3D(); + const auto* pooling_cache = dynamic_cast(cache); + const AvgPoolingParams3D& params_3d = pooling_cache->GetParams3D(); const int64_t elem_num = dy->shape().elem_cnt(); const T* src = dy->dptr(); @@ -326,4 +362,4 @@ REGISTER_AVG_POOLING_WITH_DEVICE(DeviceType::kCUDA) 
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_AVG_POOLING_KERNEL_UTIL, (DeviceType::kCPU), AVG_POOLING_DATA_TYPE_CPU_SEQ); -} // namespace oneflow \ No newline at end of file +} // namespace oneflow diff --git a/oneflow/user/kernels/bernoulli_kernel.cpp b/oneflow/user/kernels/bernoulli_kernel.cpp index b9453050f12..3fa324dc958 100644 --- a/oneflow/user/kernels/bernoulli_kernel.cpp +++ b/oneflow/user/kernels/bernoulli_kernel.cpp @@ -15,7 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/distributions/common.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/user/kernels/random_seed_util.h" #include "oneflow/user/kernels/random_mask_generator.h" @@ -35,7 +35,8 @@ class BernoulliKerenl final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); const T* in_dptr = in_blob->dptr(); diff --git a/oneflow/user/kernels/cast_kernel.cpp b/oneflow/user/kernels/cast_kernel.cpp index 7b74b617b79..986751993f5 100644 --- a/oneflow/user/kernels/cast_kernel.cpp +++ b/oneflow/user/kernels/cast_kernel.cpp @@ -16,7 +16,7 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/cast.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { @@ -38,7 +38,7 @@ class CastKernel final : public OpKernel, public user_op::CudaGraphSupport { ~CastKernel() = default; private: - void Compute(KernelComputeContext* ctx, OpKernelState* state) const override { + void Compute(KernelComputeContext* ctx) const override { const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("in", 0); Tensor* output_tenor = ctx->Tensor4ArgNameAndIndex("out", 0); const int64_t elem_cnt = input_tensor->shape().elem_cnt(); diff --git a/oneflow/user/kernels/coco_reader_kernel.cpp b/oneflow/user/kernels/coco_reader_kernel.cpp index f4a22777c50..99ec0bd71e0 100644 --- a/oneflow/user/kernels/coco_reader_kernel.cpp +++ b/oneflow/user/kernels/coco_reader_kernel.cpp @@ -43,7 +43,8 @@ class COCOReaderKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* reader = dynamic_cast(state); reader->Read(ctx); } diff --git a/oneflow/user/kernels/combined_margin_loss_kernel.cpp b/oneflow/user/kernels/combined_margin_loss_kernel.cpp index 4648ff4a3e2..1fe1882b15e 100644 --- a/oneflow/user/kernels/combined_margin_loss_kernel.cpp +++ b/oneflow/user/kernels/combined_margin_loss_kernel.cpp @@ -22,10 +22,10 @@ namespace oneflow { namespace { -class CombinedMarginLossOpKernelState final : public user_op::OpKernelState { +class CombinedMarginLossOpKernelCache final : public user_op::OpKernelCache { public: - CombinedMarginLossOpKernelState(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} - ~CombinedMarginLossOpKernelState() override = default; 
+ CombinedMarginLossOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} + ~CombinedMarginLossOpKernelCache() override = default; int64_t lower() const { return lower_; } int64_t upper() const { return upper_; } @@ -35,11 +35,9 @@ class CombinedMarginLossOpKernelState final : public user_op::OpKernelState { const int64_t upper_; }; -std::shared_ptr CreateCombinedMarginLossOpKernelState( - user_op::KernelInitContext* ctx, const std::string& in_arg_name) { - if (ctx->parallel_ctx().parallel_num() == 1) { - return std::shared_ptr(nullptr); - } +std::shared_ptr CreateCombinedMarginLossOpKernelCache( + user_op::KernelCacheContext* ctx, const std::string& in_arg_name) { + if (ctx->parallel_ctx().parallel_num() == 1) { return nullptr; } const cfg::SbpParallel& in_sbp = ctx->SbpParallel4ArgNameAndIndex(in_arg_name, 0); if (in_sbp.has_split_parallel() && in_sbp.split_parallel().axis() == 1 @@ -50,11 +48,11 @@ std::shared_ptr CreateCombinedMarginLossOpKernelState( const auto depth = ctx->Attr("depth"); CHECK_EQ(depth, in_logical_desc->shape().At(1)); BalancedSplitter bs(depth, ctx->parallel_ctx().parallel_num()); - return std::make_shared( + return std::make_shared( bs.At(ctx->parallel_ctx().parallel_id()).begin(), bs.At(ctx->parallel_ctx().parallel_id()).end()); } else { - return std::shared_ptr(nullptr); + return nullptr; } } @@ -66,13 +64,14 @@ class CombinedMarginLossCpuKernel final : public user_op::OpKernel { CombinedMarginLossCpuKernel() = default; ~CombinedMarginLossCpuKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return CreateCombinedMarginLossOpKernelState(ctx, "x"); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateCombinedMarginLossOpKernelCache(ctx, "x"); } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const T* x_ptr = x->dptr(); const K* label_ptr = ctx->Tensor4ArgNameAndIndex("label", 0)->dptr(); @@ -82,11 +81,11 @@ class CombinedMarginLossCpuKernel final : public user_op::OpKernel { const float m2 = ctx->Attr("m2"); const float m3 = ctx->Attr("m3"); int64_t lower_bound = 0; - if (state != nullptr) { - auto* kernel_state = dynamic_cast(state); - CHECK_NOTNULL(kernel_state); - CHECK_EQ(x->shape().Count(1), kernel_state->upper() - kernel_state->lower()); - lower_bound = kernel_state->lower(); + if (cache != nullptr) { + auto* kernel_cache = dynamic_cast(cache); + CHECK_NOTNULL(kernel_cache); + CHECK_EQ(x->shape().Count(1), kernel_cache->upper() - kernel_cache->lower()); + lower_bound = kernel_cache->lower(); } const int64_t num_classes = x->shape().Count(1); FOR_RANGE(int32_t, i, 0, x->shape().elem_cnt()) { @@ -126,13 +125,14 @@ class CombinedMarginLossGradCpuKernel final : public user_op::OpKernel { CombinedMarginLossGradCpuKernel() = default; ~CombinedMarginLossGradCpuKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return CreateCombinedMarginLossOpKernelState(ctx, "dy"); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateCombinedMarginLossOpKernelCache(ctx, "dy"); } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* 
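Aside: the lower()/upper() bounds cached above describe which slice of the class dimension (depth) the current rank owns when the input is split along axis 1. The helper below is a hypothetical stand-in that mimics the arithmetic BalancedSplitter provides; it is an illustration, not the OneFlow implementation.

// Illustration only: even split of `depth` classes over `parallel_num` ranks.
#include <cstdint>
#include <utility>

std::pair<int64_t, int64_t> balanced_range(int64_t depth, int64_t parallel_num, int64_t rank) {
  const int64_t base = depth / parallel_num;
  const int64_t remainder = depth % parallel_num;
  const int64_t lower = rank * base + (rank < remainder ? rank : remainder);
  const int64_t upper = lower + base + (rank < remainder ? 1 : 0);
  return {lower, upper};  // this rank owns classes [lower, upper)
}

// e.g. depth = 10 over 4 ranks -> [0,3), [3,6), [6,8), [8,10)
// In Compute, a global label l is adjusted to the local index l - lower_bound,
// and only the owning rank applies the margin to its slice.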
state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const T* dy_ptr = dy->dptr(); const K* label_ptr = ctx->Tensor4ArgNameAndIndex("label", 0)->dptr(); @@ -141,11 +141,11 @@ class CombinedMarginLossGradCpuKernel final : public user_op::OpKernel { const float m1 = ctx->Attr("m1"); const float m2 = ctx->Attr("m2"); int64_t lower_bound = 0; - if (state != nullptr) { - auto* kernel_state = dynamic_cast(state); - CHECK_NOTNULL(kernel_state); - CHECK_EQ(dy->shape().Count(1), kernel_state->upper() - kernel_state->lower()); - lower_bound = kernel_state->lower(); + if (cache != nullptr) { + auto* kernel_cache = dynamic_cast(cache); + CHECK_NOTNULL(kernel_cache); + CHECK_EQ(dy->shape().Count(1), kernel_cache->upper() - kernel_cache->lower()); + lower_bound = kernel_cache->lower(); } const int64_t num_classes = dy->shape().Count(1); diff --git a/oneflow/user/kernels/combined_margin_loss_kernel.cu b/oneflow/user/kernels/combined_margin_loss_kernel.cu index e8983f23572..425ed51ef8d 100644 --- a/oneflow/user/kernels/combined_margin_loss_kernel.cu +++ b/oneflow/user/kernels/combined_margin_loss_kernel.cu @@ -67,10 +67,10 @@ __global__ void GpuBackward(const int64_t n, const int64_t num_classes, const in } } -class CombinedMarginLossOpKernelState final : public user_op::OpKernelState { +class CombinedMarginLossOpKernelCache final : public user_op::OpKernelCache { public: - CombinedMarginLossOpKernelState(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} - ~CombinedMarginLossOpKernelState() override = default; + CombinedMarginLossOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} + ~CombinedMarginLossOpKernelCache() override = default; int64_t lower() const { return lower_; } int64_t upper() const { return upper_; } @@ -80,11 +80,9 @@ class CombinedMarginLossOpKernelState final : public user_op::OpKernelState { const int64_t upper_; }; -std::shared_ptr CreateCombinedMarginLossOpKernelState( - user_op::KernelInitContext* ctx, const std::string& in_arg_name) { - if (ctx->parallel_ctx().parallel_num() == 1) { - return std::shared_ptr(nullptr); - } +std::shared_ptr CreateCombinedMarginLossOpKernelCache( + user_op::KernelCacheContext* ctx, const std::string& in_arg_name) { + if (ctx->parallel_ctx().parallel_num() == 1) { return nullptr; } const cfg::SbpParallel& in_sbp = ctx->SbpParallel4ArgNameAndIndex(in_arg_name, 0); if (in_sbp.has_split_parallel() && in_sbp.split_parallel().axis() == 1 @@ -95,11 +93,11 @@ std::shared_ptr CreateCombinedMarginLossOpKernelState( const auto depth = ctx->Attr("depth"); CHECK_EQ(depth, in_logical_desc->shape().At(1)); BalancedSplitter bs(depth, ctx->parallel_ctx().parallel_num()); - return std::make_shared( + return std::make_shared( bs.At(ctx->parallel_ctx().parallel_id()).begin(), bs.At(ctx->parallel_ctx().parallel_id()).end()); } else { - return std::shared_ptr(nullptr); + return nullptr; } } @@ -111,14 +109,16 @@ class CombinedMarginLossGpuKernel final : public user_op::OpKernel { CombinedMarginLossGpuKernel() = default; ~CombinedMarginLossGpuKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return CreateCombinedMarginLossOpKernelState(ctx, "x"); + using user_op::OpKernel::InitOpKernelCache; + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return 
CreateCombinedMarginLossOpKernelCache(ctx, "x"); } private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); @@ -127,11 +127,11 @@ class CombinedMarginLossGpuKernel final : public user_op::OpKernel { const float m2 = ctx->Attr("m2"); const float m3 = ctx->Attr("m3"); int64_t lower_bound = 0; - if (state != nullptr) { - auto* kernel_state = dynamic_cast(state); - CHECK_NOTNULL(kernel_state); - CHECK_EQ(x->shape().Count(1), kernel_state->upper() - kernel_state->lower()); - lower_bound = kernel_state->lower(); + if (cache != nullptr) { + auto* kernel_cache = dynamic_cast(cache); + CHECK_NOTNULL(kernel_cache); + CHECK_EQ(x->shape().Count(1), kernel_cache->upper() - kernel_cache->lower()); + lower_bound = kernel_cache->lower(); } if (m1 == 1.0 && m2 == 0.0) { GpuForward<<shape().elem_cnt()), kCudaThreadsNumPerBlock, @@ -168,14 +168,16 @@ class CombinedMarginLossGradGpuKernel final : public user_op::OpKernel { CombinedMarginLossGradGpuKernel() = default; ~CombinedMarginLossGradGpuKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return CreateCombinedMarginLossOpKernelState(ctx, "dy"); + using user_op::OpKernel::InitOpKernelCache; + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateCombinedMarginLossOpKernelCache(ctx, "dy"); } private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); const user_op::Tensor* theta = ctx->Tensor4ArgNameAndIndex("theta", 0); @@ -184,11 +186,11 @@ class CombinedMarginLossGradGpuKernel final : public user_op::OpKernel { const float m2 = ctx->Attr("m2"); const float m3 = ctx->Attr("m3"); int64_t lower_bound = 0; - if (state != nullptr) { - auto* kernel_state = dynamic_cast(state); - CHECK_NOTNULL(kernel_state); - CHECK_EQ(dy->shape().Count(1), kernel_state->upper() - kernel_state->lower()); - lower_bound = kernel_state->lower(); + if (cache != nullptr) { + auto* kernel_cache = dynamic_cast(cache); + CHECK_NOTNULL(kernel_cache); + CHECK_EQ(dy->shape().Count(1), kernel_cache->upper() - kernel_cache->lower()); + lower_bound = kernel_cache->lower(); } if (m1 == 1.0 && m2 == 0.0) { GpuBackward diff --git a/oneflow/user/kernels/conv_cudnn_kernels.cpp b/oneflow/user/kernels/conv_cudnn_kernels.cpp index 8e5cace1108..14bfdda3656 100644 --- a/oneflow/user/kernels/conv_cudnn_kernels.cpp +++ b/oneflow/user/kernels/conv_cudnn_kernels.cpp @@ -138,7 +138,7 @@ CudnnTensorDesc* GetBiasCudnnTensorDesc<3>(const std::string& data_format, int32 return new CudnnTensorDesc(data_type, NDims, bias_dim.data(), stride_of_bias_tensor.data()); } -struct ConvCudnnOpKernelState final : public user_op::OpKernelState { +struct ConvCudnnOpKernelCache final : public user_op::OpKernelCache { std::unique_ptr bias_desc; }; @@ -150,12 
+150,12 @@ class ConvGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphS bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - std::shared_ptr CreateConvCudnnOpKernelState( - user_op::KernelComputeContext* ctx) const { + std::shared_ptr CreateConvCudnnOpKernelCache( + user_op::KernelCacheContext* ctx) const { const auto& data_format = ctx->Attr("data_format"); int32_t filters = ctx->Attr("filters"); - std::shared_ptr state(new ConvCudnnOpKernelState()); + std::shared_ptr state(new ConvCudnnOpKernelCache()); const user_op::TensorDesc* bias = ctx->TensorDesc4ArgNameAndIndex("bias", 0); if (bias != nullptr) { @@ -167,7 +167,13 @@ class ConvGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphS } private: - void Compute(user_op::KernelComputeContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateConvCudnnOpKernelCache(ctx); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); user_op::Tensor* buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); @@ -187,10 +193,10 @@ class ConvGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphS const user_op::Tensor* bias = ctx->Tensor4ArgNameAndIndex("bias", 0); if (bias != nullptr) { - const auto& conv_state = CreateConvCudnnOpKernelState(ctx); - CHECK_NOTNULL(conv_state.get()); + const auto* conv_cache = dynamic_cast(cache); + CHECK_NOTNULL(conv_cache); OF_CUDNN_CHECK(cudnnAddTensor(ctx->stream()->As()->cudnn_handle(), - CudnnSPOnePtr(), conv_state->bias_desc->Get(), + CudnnSPOnePtr(), conv_cache->bias_desc->Get(), bias->dptr(), CudnnSPOnePtr(), args.ydesc.Get(), out->mut_dptr())); } @@ -290,7 +296,7 @@ class ConvDataGradGpuKernel final : public user_op::OpKernel, public user_op::Cu .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ const auto& dy = ctx->InputTensorDesc("dy", 0); \ const auto& filter = ctx->InputTensorDesc("filter", 0); \ - const auto* dx = ctx->TensorDesc4ArgNameAndIndex("dx", 0); \ + const auto* dx = ctx->OutputTensorDesc("dx", 0); \ const auto& cudnn_conf = Global::Get()->resource().cudnn_conf(); \ return InferTmpSizeWithCudnn( \ dx, &filter, &dy, *ctx, cudnn_conf.has_cudnn_conv_force_bwd_data_algo(), \ diff --git a/oneflow/user/kernels/conv_kernels.cpp b/oneflow/user/kernels/conv_kernels.cpp index ac792deeaf4..80cd64000b0 100644 --- a/oneflow/user/kernels/conv_kernels.cpp +++ b/oneflow/user/kernels/conv_kernels.cpp @@ -294,10 +294,10 @@ struct ConvKernelUtil final { }; template -struct ConvOpKernelState final : public user_op::OpKernelState { - Im2ColFunc im2col_func_; - Col2ImFunc col2im_func_; - GemmFunc forward_func_; +struct ConvOpKernelCache final : public user_op::OpKernelCache { + Im2ColFunc im2col_func_ = nullptr; + Col2ImFunc col2im_func_ = nullptr; + GemmFunc forward_func_ = nullptr; Shape in_5d_shape_; Shape out_5d_shape_; @@ -307,46 +307,31 @@ struct ConvOpKernelState final : public user_op::OpKernelState { std::vector dilation_rate_3d_; std::vector padding_before_3d_; - enum CBLAS_TRANSPOSE is_out_diff_need_trans_; - int32_t idx_offset_; - bool is_dynamic_; - - void Update(const ShapeView& x_shape, const ShapeView& out_shape) { - auto Gen5DShape = [](const ShapeView& shape, int32_t idx_offset) -> Shape { - DimVector ret_vec; - 
shape.ToDimVector(&ret_vec); - int32_t ndims = ret_vec.size() - 2; - ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1); - return Shape(ret_vec); - }; - if (is_dynamic_) { - Shape in_shape; - in_5d_shape_ = Gen5DShape(x_shape, idx_offset_); - out_5d_shape_ = Gen5DShape(out_shape, idx_offset_); - } - } + enum CBLAS_TRANSPOSE is_out_diff_need_trans_ = CblasNoTrans; + int32_t idx_offset_{}; + bool is_dynamic_{}; }; template -std::shared_ptr> CreateConvOpKernelState(user_op::KernelComputeContext* ctx, +std::shared_ptr> CreateConvOpKernelCache(user_op::KernelCacheContext* ctx, const std::string& in_name, const std::string& out_name, const std::string& weight_name) { const auto& data_format = ctx->Attr("data_format"); - std::shared_ptr> state(new ConvOpKernelState()); + std::shared_ptr> cache(new ConvOpKernelCache()); if (data_format == "channels_first") { - state->im2col_func_ = ConvKernelUtil::NCDHWIm2Col; - state->col2im_func_ = ConvKernelUtil::NCDHWCol2Im; - state->forward_func_ = Gemm4ChannelFirst; - state->is_out_diff_need_trans_ = CblasNoTrans; - state->idx_offset_ = 2; + cache->im2col_func_ = ConvKernelUtil::NCDHWIm2Col; + cache->col2im_func_ = ConvKernelUtil::NCDHWCol2Im; + cache->forward_func_ = Gemm4ChannelFirst; + cache->is_out_diff_need_trans_ = CblasNoTrans; + cache->idx_offset_ = 2; } else { - state->im2col_func_ = ConvKernelUtil::NDHWCIm2Col; - state->col2im_func_ = ConvKernelUtil::NDHWCCol2Im; - state->forward_func_ = Gemm4ChannelLast; - state->is_out_diff_need_trans_ = CblasTrans; - state->idx_offset_ = 1; + cache->im2col_func_ = ConvKernelUtil::NDHWCIm2Col; + cache->col2im_func_ = ConvKernelUtil::NDHWCCol2Im; + cache->forward_func_ = Gemm4ChannelLast; + cache->is_out_diff_need_trans_ = CblasTrans; + cache->idx_offset_ = 1; } auto Gen5DShape = [](const Shape& shape, int32_t idx_offset) -> Shape { @@ -355,32 +340,33 @@ std::shared_ptr> CreateConvOpKernelState(user_op::KernelCom ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1); return Shape(ret_vec); }; - state->in_5d_shape_ = - Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->shape(), state->idx_offset_); - state->out_5d_shape_ = - Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(out_name, 0)->shape(), state->idx_offset_); - state->weight_5d_shape_ = - Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(weight_name, 0)->shape(), state->idx_offset_); + const auto* in_tensor = ctx->TensorDesc4ArgNameAndIndex(in_name, 0); + const auto& in_shape = in_tensor->shape(); + cache->in_5d_shape_ = Gen5DShape(in_shape, cache->idx_offset_); + cache->out_5d_shape_ = + Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(out_name, 0)->shape(), cache->idx_offset_); + cache->weight_5d_shape_ = + Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(weight_name, 0)->shape(), cache->idx_offset_); auto Gen3DVec = [](const std::vector& origin_vec) -> std::vector { std::vector ret_vec = origin_vec; ret_vec.insert(ret_vec.begin(), 3 - ret_vec.size(), 1); return ret_vec; }; - state->strides_3d_ = Gen3DVec(ctx->Attr>("strides")); - state->dilation_rate_3d_ = Gen3DVec(ctx->Attr>("dilation_rate")); - state->is_dynamic_ = ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->is_dynamic(); + cache->strides_3d_ = Gen3DVec(ctx->Attr>("strides")); + cache->dilation_rate_3d_ = Gen3DVec(ctx->Attr>("dilation_rate")); + cache->is_dynamic_ = ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->is_dynamic(); const auto& padding_before = ctx->Attr>("padding_before"); FOR_RANGE(uint8_t, dim, 0, 3) { int64_t index = static_cast(dim) - (3 - padding_before.size()); if (index < 0) { - 
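Aside: the Gen5DShape / Gen3DVec lambdas above normalize 1-D and 2-D convolutions onto the single 3-D (NCDHW / NDHWC) code path. A standalone sketch of the same promotion, independent of OneFlow's Shape type:

// Pad a conv input shape with leading singleton spatial dims so it always has 3 of them.
#include <cstdint>
#include <vector>

std::vector<int64_t> Gen5DShapeSketch(std::vector<int64_t> dims, int32_t idx_offset) {
  const int32_t ndims = static_cast<int32_t>(dims.size()) - 2;  // spatial dims present
  dims.insert(dims.begin() + idx_offset, 3 - ndims, 1);         // pad to 3 spatial dims
  return dims;
}

// channels_first (idx_offset = 2): {N, C, H, W} -> {N, C, 1, H, W}
// channels_last  (idx_offset = 1): {N, H, W, C} -> {N, 1, H, W, C}
// Gen3DVec pads attribute vectors the same way, e.g. strides {sh, sw} -> {1, sh, sw}.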
state->padding_before_3d_.emplace_back(0); + cache->padding_before_3d_.emplace_back(0); } else { - state->padding_before_3d_.emplace_back(padding_before.at(index)); + cache->padding_before_3d_.emplace_back(padding_before.at(index)); } } - return state; + return cache; } template @@ -397,9 +383,15 @@ class ConvCpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& conv_state = CreateConvOpKernelState(ctx, "in", "out", "weight"); - CHECK_NOTNULL(conv_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateConvOpKernelCache(ctx, "in", "out", "weight"); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + const auto* conv_cache = dynamic_cast*>(cache); + CHECK_NOTNULL(conv_cache); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); @@ -410,20 +402,20 @@ class ConvCpuKernel final : public user_op::OpKernel { bool is_bias_mul_inited = false; for (int64_t i = 0; i < in->shape().At(0); ++i) { - conv_state->im2col_func_(GetImgDptr(in, i), ShapeView(conv_state->in_5d_shape_), - ShapeView(conv_state->weight_5d_shape_), - ShapeView(conv_state->out_5d_shape_), conv_state->strides_3d_.data(), - conv_state->dilation_rate_3d_.data(), - conv_state->padding_before_3d_.data(), col_buf_dptr); + conv_cache->im2col_func_(GetImgDptr(in, i), ShapeView(conv_cache->in_5d_shape_), + ShapeView(conv_cache->weight_5d_shape_), + ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(), + conv_cache->dilation_rate_3d_.data(), + conv_cache->padding_before_3d_.data(), col_buf_dptr); // channels first: out = weight * col_buf // channels last: out = (weight * col_buf)(T) - int32_t idx_offset = conv_state->idx_offset_; - conv_state->forward_func_( + int32_t idx_offset = conv_cache->idx_offset_; + conv_cache->forward_func_( ctx->stream(), CblasNoTrans, CblasNoTrans, - conv_state->weight_5d_shape_.At(0), // filter - conv_state->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow - conv_state->weight_5d_shape_.Count(1), // ci * kd * kh * kw + conv_cache->weight_5d_shape_.At(0), // filter + conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow + conv_cache->weight_5d_shape_.Count(1), // ci * kd * kh * kw static_cast(1), weight->dptr(), col_buf_dptr, static_cast(0), GetImgMutDptr(out, i)); @@ -441,10 +433,10 @@ class ConvCpuKernel final : public user_op::OpKernel { // channels first: out += bias * bias_mul // channels last: out += (bias * bias_mul)(T) - conv_state->forward_func_( + conv_cache->forward_func_( ctx->stream(), CblasNoTrans, CblasNoTrans, - conv_state->weight_5d_shape_.At(0), // filter - conv_state->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow + conv_cache->weight_5d_shape_.At(0), // filter + conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow 1, // 1 static_cast(1), bias->dptr(), bias_mul_dptr, static_cast(1), GetImgMutDptr(out, i)); @@ -493,9 +485,15 @@ class ConvDataGradCpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& conv_state = CreateConvOpKernelState(ctx, "dx", "dy", "filter"); - 
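Aside: the "out = weight * col_buf" comments in ConvCpuKernel::Compute compress the forward pass into one GEMM per sample. The sketch below only spells out the shapes involved; it is a naive illustration of that multiplication, not the NewKernelUtil::OFGemm call.

// Per-sample GEMM after im2col: m = filters, k = ci*kd*kh*kw, n = od*oh*ow.
void NaiveConvGemm(int m, int n, int k, const float* weight, const float* col_buf, float* out) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.f;
      for (int p = 0; p < k; ++p) { acc += weight[i * k + p] * col_buf[p * n + j]; }
      out[i * n + j] = acc;  // channels-first layout; channels-last stores the transpose
    }
  }
}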
CHECK_NOTNULL(conv_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateConvOpKernelCache(ctx, "dx", "dy", "filter"); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + const auto* conv_cache = dynamic_cast*>(cache); + CHECK_NOTNULL(conv_cache); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* filter = ctx->Tensor4ArgNameAndIndex("filter", 0); @@ -505,24 +503,24 @@ class ConvDataGradCpuKernel final : public user_op::OpKernel { Memset(ctx->stream(), dx->mut_dptr(), 0, dx->shape().elem_cnt() * sizeof(T)); - int32_t idx_offset = conv_state->idx_offset_; + int32_t idx_offset = conv_cache->idx_offset_; FOR_RANGE(int64_t, i, 0, dy->shape().At(0)) { // channels first: col_buf' = weight(T) * out[i]' // channels last : col_buf' = weight(T) * out[i]'(T) NewKernelUtil::OFGemm( - ctx->stream(), CblasTrans, conv_state->is_out_diff_need_trans_, - conv_state->weight_5d_shape_.Count(1), // ci * kd * kh * kw - conv_state->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow - conv_state->weight_5d_shape_.At(0), // filter + ctx->stream(), CblasTrans, conv_cache->is_out_diff_need_trans_, + conv_cache->weight_5d_shape_.Count(1), // ci * kd * kh * kw + conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow + conv_cache->weight_5d_shape_.At(0), // filter static_cast(1), filter->dptr(), GetImgDptr(dy, i), static_cast(0), col_buf->mut_dptr()); // in' = col2im(col_buf') - conv_state->col2im_func_(col_buf->dptr(), ShapeView(conv_state->in_5d_shape_), - ShapeView(conv_state->weight_5d_shape_), - ShapeView(conv_state->out_5d_shape_), conv_state->strides_3d_.data(), - conv_state->dilation_rate_3d_.data(), - conv_state->padding_before_3d_.data(), GetImgMutDptr(dx, i)); + conv_cache->col2im_func_(col_buf->dptr(), ShapeView(conv_cache->in_5d_shape_), + ShapeView(conv_cache->weight_5d_shape_), + ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(), + conv_cache->dilation_rate_3d_.data(), + conv_cache->padding_before_3d_.data(), GetImgMutDptr(dx, i)); } if (ctx->has_input("_add_to_output", 0)) { const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0); @@ -568,9 +566,15 @@ class ConvFilterGradCpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& conv_state = CreateConvOpKernelState(ctx, "x", "dy", "filter_diff"); - CHECK_NOTNULL(conv_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateConvOpKernelCache(ctx, "x", "dy", "filter_diff"); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + const auto* conv_cache = dynamic_cast*>(cache); + CHECK_NOTNULL(conv_cache); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); @@ -579,21 +583,21 @@ class ConvFilterGradCpuKernel final : public user_op::OpKernel { Memset(ctx->stream(), filter_diff->mut_dptr(), 0, filter_diff->shape().elem_cnt() * sizeof(T)); - int32_t idx_offset = conv_state->idx_offset_; + int32_t idx_offset = conv_cache->idx_offset_; FOR_RANGE(int64_t, i, 0, dy->shape().At(0)) { - conv_state->im2col_func_(GetImgDptr(x, 
i), ShapeView(conv_state->in_5d_shape_), - ShapeView(conv_state->weight_5d_shape_), - ShapeView(conv_state->out_5d_shape_), conv_state->strides_3d_.data(), - conv_state->dilation_rate_3d_.data(), - conv_state->padding_before_3d_.data(), col_buf->mut_dptr()); + conv_cache->im2col_func_(GetImgDptr(x, i), ShapeView(conv_cache->in_5d_shape_), + ShapeView(conv_cache->weight_5d_shape_), + ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(), + conv_cache->dilation_rate_3d_.data(), + conv_cache->padding_before_3d_.data(), col_buf->mut_dptr()); // channels first: weight' += out[i]' * col_buf(T) // channels last : weight' += out[i]'(T) * col_buf(T) NewKernelUtil::OFGemm( - ctx->stream(), conv_state->is_out_diff_need_trans_, CblasTrans, - conv_state->weight_5d_shape_.At(0), // filter - conv_state->weight_5d_shape_.Count(1), // ci * kd * kh * kw - conv_state->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow + ctx->stream(), conv_cache->is_out_diff_need_trans_, CblasTrans, + conv_cache->weight_5d_shape_.At(0), // filter + conv_cache->weight_5d_shape_.Count(1), // ci * kd * kh * kw + conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow static_cast(1), GetImgDptr(dy, i), col_buf->dptr(), static_cast(1), filter_diff->mut_dptr()); } diff --git a/oneflow/user/kernels/cumsum_kernel.cpp b/oneflow/user/kernels/cumsum_kernel.cpp new file mode 100644 index 00000000000..bf923de2309 --- /dev/null +++ b/oneflow/user/kernels/cumsum_kernel.cpp @@ -0,0 +1,129 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/framework.h" + +namespace oneflow { + +namespace { +template +void cumsum_forward(const T* in_ptr, T* out_ptr, int64_t cs_up_space, int64_t cs_space, + int64_t cs_down_space, int64_t elem_cnt) { + std::copy_n(in_ptr, elem_cnt, out_ptr); + auto* tmp_out_ptr_base = out_ptr; + auto step = cs_space * cs_down_space; + for (auto i = 0; i < cs_up_space; i++) { + for (auto j = 1; j < cs_space; j++) { + auto* tmp_out_ptr = tmp_out_ptr_base + j * cs_down_space; + auto* last_tmp_out_ptr = tmp_out_ptr - cs_down_space; + for (auto k = 0; k < cs_down_space; k++) { tmp_out_ptr[k] += last_tmp_out_ptr[k]; } + } + tmp_out_ptr_base += step; + } +} + +template +void cumsum_backward(const T* in_ptr, T* out_ptr, int64_t cs_up_space, int64_t cs_space, + int64_t cs_down_space, int64_t elem_cnt) { + auto* tmp_in_ptr_base = in_ptr; + auto* tmp_out_ptr_base = out_ptr; + auto step = cs_space * cs_down_space; + for (auto i = 0; i < cs_up_space; i++) { + for (auto j = 0; j < cs_space; j++) { + auto* tmp_in_ptr = tmp_in_ptr_base + j * cs_down_space; + auto* tmp_out_ptr = tmp_out_ptr_base + j * cs_down_space; + std::fill_n(tmp_out_ptr, cs_down_space, cs_space - j); + for (auto k = 0; k < cs_down_space; k++) { tmp_out_ptr[k] *= tmp_in_ptr[k]; } + } + tmp_in_ptr_base += step; + tmp_out_ptr_base += step; + } +} +} // namespace + +template +class CpuCumsumKernel final : public user_op::OpKernel { + public: + CpuCumsumKernel() = default; + ~CpuCumsumKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const auto* in = ctx->Tensor4ArgNameAndIndex("in", 0); + auto elem_cnt = in->shape().elem_cnt(); + // judge whether tensor has 0 size dimension first + if (!elem_cnt) { return; } + + auto* out = ctx->Tensor4ArgNameAndIndex("out", 0); + auto dim = ctx->Attr("dim"); + const auto* in_ptr = in->dptr(); + auto* out_ptr = out->mut_dptr(); + + // take cumsum's abbreviation as `cs` + // data partition: cs_up_space|cs_space|cs_down_space + auto cs_up_space = elem_cnt / in->shape().Count(dim); + auto cs_space = in->shape().At(dim); + auto cs_down_space = in->shape().Count(dim + 1); + + cumsum_forward(in_ptr, out_ptr, cs_up_space, cs_space, cs_down_space, elem_cnt); + } + + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_CUMSUM_KERNEL(dtype) \ + REGISTER_USER_KERNEL("cumsum").SetCreateFn>().SetIsMatchedHob( \ + (user_op::HobDeviceType() == DeviceType::kCPU) \ + && (user_op::HobDataType("out", 0) == GetDataType::value)); + +REGISTER_CUMSUM_KERNEL(int64_t) +REGISTER_CUMSUM_KERNEL(float) +REGISTER_CUMSUM_KERNEL(double) + +template +class CpuCumsumGradKernel final : public user_op::OpKernel { + public: + CpuCumsumGradKernel() = default; + ~CpuCumsumGradKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const auto* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + auto* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); + auto elem_cnt = dy->shape().elem_cnt(); + auto dim = ctx->Attr("dim"); + const auto* dy_ptr = dy->dptr(); + auto* dx_ptr = dx->mut_dptr(); + + // data partition: cs_up_space|cs_space|cs_down_space + auto cs_up_space = elem_cnt / dx->shape().Count(dim); + auto cs_space = dx->shape().At(dim); + auto cs_down_space = dx->shape().Count(dim + 1); + + cumsum_backward(dy_ptr, dx_ptr, cs_up_space, cs_space, cs_down_space, elem_cnt); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_CPU_CUMSUM_GRAD_KERNEL(dtype) 
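Aside: the cs_up_space | cs_space | cs_down_space partition used throughout the cumsum kernels is easiest to see on a concrete shape. For a tensor of shape {2, 3, 4} and dim = 1: cs_up_space = elem_cnt / Count(dim) = 24 / 12 = 2 (product of dims before dim), cs_space = 3 (the length being accumulated), cs_down_space = Count(dim + 1) = 4 (product of dims after dim). A self-contained sketch of the same loop structure, not part of the patch:

#include <cstdint>
#include <vector>

// Cumulative sum along the middle block of an up|len|down partitioned buffer.
void CumsumAlongDim(const std::vector<float>& in, std::vector<float>& out,
                    int64_t up, int64_t len, int64_t down) {
  out = in;
  for (int64_t i = 0; i < up; ++i) {
    for (int64_t j = 1; j < len; ++j) {
      for (int64_t k = 0; k < down; ++k) {
        const int64_t idx = (i * len + j) * down + k;
        out[idx] += out[idx - down];  // previous element along `dim` is exactly `down` apart
      }
    }
  }
}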
\ + REGISTER_USER_KERNEL("cumsum_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ + && (user_op::HobDataType("dx", 0) == GetDataType::value)); + +REGISTER_CPU_CUMSUM_GRAD_KERNEL(float) +REGISTER_CPU_CUMSUM_GRAD_KERNEL(double) + +} // namespace oneflow diff --git a/oneflow/user/kernels/cumsum_kernel.cu b/oneflow/user/kernels/cumsum_kernel.cu new file mode 100644 index 00000000000..2efae7027bb --- /dev/null +++ b/oneflow/user/kernels/cumsum_kernel.cu @@ -0,0 +1,207 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/device/cuda_util.h" + +namespace oneflow { + +namespace { + +// total thread number: cs_up_space * cs_down_space +// in cs_down_space part, use cs_down_space threads +// to calculate as follows(m=cs_down_space-1, n=cs_space-1, '|' stands for dependency): +// dm0, ..., d10, d00 +// | | | +// dm1, ..., d11, d01 +// | | | +// dm2, ..., d12, d02 +// | | | +// ... ... ... +// | | | +// dmn, ..., d1n, d0n +template +__global__ void CumsumForwardGpu(const T* in_ptr, T* out_ptr, int64_t cs_up_space, int64_t cs_space, + int64_t cs_down_space) { + CUDA_1D_KERNEL_LOOP(i, cs_up_space * cs_down_space) { + auto cs_up_space_id = i / cs_down_space; + auto cs_down_space_id = i - (i / cs_down_space) * cs_down_space; + + auto* in_ptr_base = in_ptr + cs_up_space_id * cs_space * cs_down_space + cs_down_space_id; + auto* out_ptr_base = out_ptr + cs_up_space_id * cs_space * cs_down_space + cs_down_space_id; + + // calculate cs_space data in one thread + for (auto j = 0; j < cs_space; j++) { + auto idx = j * cs_down_space; + out_ptr_base[idx] = in_ptr_base[idx]; + if (j != 0) { out_ptr_base[idx] += out_ptr_base[idx - cs_down_space]; } + } + } +} +template +__global__ void CumsumForwardGpuUpSpaceIs1(const T* in_ptr, T* out_ptr, int64_t cs_space, + int64_t cs_down_space) { + CUDA_1D_KERNEL_LOOP(i, cs_down_space) { + auto* in_ptr_base = in_ptr + i; + auto* out_ptr_base = out_ptr + i; + + // calculate cs_space data in one thread + for (auto j = 0; j < cs_space; j++) { + auto idx = j * cs_down_space; + out_ptr_base[idx] = in_ptr_base[idx]; + if (j != 0) { out_ptr_base[idx] += out_ptr_base[idx - cs_down_space]; } + } + } +} +template +__global__ void CumsumForwardGpuDownSpaceIs1(const T* in_ptr, T* out_ptr, int64_t cs_up_space, + int64_t cs_space) { + CUDA_1D_KERNEL_LOOP(i, cs_up_space) { + auto* in_ptr_base = in_ptr + i * cs_space; + auto* out_ptr_base = out_ptr + i * cs_space; + + // calculate cs_space data in one thread + for (auto j = 0; j < cs_space; j++) { + out_ptr_base[j] = in_ptr_base[j]; + if (j != 0) { out_ptr_base[j] += out_ptr_base[j - 1]; } + } + } +} + +// total thread number: cs_up_space * cs_down_space +// in cs_down_space part, use cs_down_space threads +// to calculate as follows(m=cs_down_space-1, n=cs_space-1, there is no dependency in backward): +// dm0, ..., d10, d00 +// dm1, ..., d11, d01 +// dm2, ..., d12, d02 +// ... ... ... 
+// dmn, ..., d1n, d0n +template +__global__ void CumsumBackwardGpu(const T* in_ptr, T* out_ptr, int64_t cs_space, + int64_t cs_down_space, int64_t elem_cnt) { + for (auto i = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; i < elem_cnt; + i += step) { + auto tmp = cs_space * cs_down_space; + auto cs_space_id = (i - (i / tmp) * tmp) / cs_down_space; + out_ptr[i] = (cs_space - cs_space_id) * in_ptr[i]; + } +} +template +__global__ void CumsumBackwardGpu_DownSpaceIs1(const T* in_ptr, T* out_ptr, int64_t cs_up_space, + int64_t cs_space, int64_t elem_cnt) { + for (auto i = blockIdx.x * blockDim.x + threadIdx.x, step = blockDim.x * gridDim.x; i < elem_cnt; + i += step) { + auto cs_space_id = i - (i / cs_space) * cs_space; + out_ptr[i] = (cs_space - cs_space_id) * in_ptr[i]; + } +} + +} // namespace + +template +class GpuCumsumKernel final : public user_op::OpKernel { + public: + GpuCumsumKernel() = default; + ~GpuCumsumKernel() = default; + + private: + using user_op::OpKernel::Compute; + void Compute(user_op::KernelComputeContext* ctx) const override { + // judge whether tensor has 0 size dimension first + const auto* in = ctx->Tensor4ArgNameAndIndex("in", 0); + auto elem_cnt = in->shape().elem_cnt(); + if (!elem_cnt) { return; } + + auto* out = ctx->Tensor4ArgNameAndIndex("out", 0); + auto dim = ctx->Attr("dim"); + const auto* in_ptr = in->dptr(); + auto* out_ptr = out->mut_dptr(); + + // take cumsum's abbreviation as `cs` + // data partition: cs_up_space|cs_space|cs_down_space + auto cs_up_space = elem_cnt / in->shape().Count(dim); + auto cs_space = in->shape().At(dim); + auto cs_down_space = in->shape().Count(dim + 1); + auto thread_num = cs_up_space * cs_down_space; + + if (cs_up_space == 1) { + RUN_CUDA_KERNEL((CumsumForwardGpuUpSpaceIs1), ctx->stream(), thread_num, in_ptr, out_ptr, + cs_space, cs_down_space); + } else if (cs_down_space == 1) { + RUN_CUDA_KERNEL((CumsumForwardGpuDownSpaceIs1), ctx->stream(), thread_num, in_ptr, out_ptr, + cs_up_space, cs_space); + } else { + RUN_CUDA_KERNEL((CumsumForwardGpu), ctx->stream(), thread_num, in_ptr, out_ptr, + cs_up_space, cs_space, cs_down_space); + } + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_CUDA_CUMSUM_KERNEL(dtype) \ + REGISTER_USER_KERNEL("cumsum").SetCreateFn>().SetIsMatchedHob( \ + (user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("out", 0) == GetDataType::value)); + +REGISTER_CUDA_CUMSUM_KERNEL(int64_t) +REGISTER_CUDA_CUMSUM_KERNEL(float) +REGISTER_CUDA_CUMSUM_KERNEL(double) + +template +class GpuCumsumGradKernel final : public user_op::OpKernel { + public: + GpuCumsumGradKernel() = default; + ~GpuCumsumGradKernel() = default; + + private: + using user_op::OpKernel::Compute; + void Compute(user_op::KernelComputeContext* ctx) const override { + // judge whether tensor has 0 size dimension first + const auto* in = ctx->Tensor4ArgNameAndIndex("dy", 0); + auto elem_cnt = in->shape().elem_cnt(); + if (!elem_cnt) { return; } + auto* out = ctx->Tensor4ArgNameAndIndex("dx", 0); + auto dim = ctx->Attr("dim"); + const auto* in_ptr = in->dptr(); + auto* out_ptr = out->mut_dptr(); + + // take cumsum's abbreviation as `cs` + // data partition: cs_up_space|cs_space|cs_down_space + auto cs_up_space = elem_cnt / in->shape().Count(dim); + auto cs_space = in->shape().At(dim); + auto cs_down_space = in->shape().Count(dim + 1); + auto thread_num = elem_cnt; + + if (cs_down_space == 1) { + RUN_CUDA_KERNEL((CumsumBackwardGpu_DownSpaceIs1), 
ctx->stream(), thread_num, in_ptr, + out_ptr, cs_up_space, cs_space, elem_cnt); + } else { + RUN_CUDA_KERNEL((CumsumBackwardGpu), ctx->stream(), thread_num, in_ptr, out_ptr, cs_space, + cs_down_space, elem_cnt); + } + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_CUDA_CUMSUM_GRAD_KERNEL(dtype) \ + REGISTER_USER_KERNEL("cumsum_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("dx", 0) == GetDataType::value)); + +REGISTER_CUDA_CUMSUM_GRAD_KERNEL(float) +REGISTER_CUDA_CUMSUM_GRAD_KERNEL(double) + +} // namespace oneflow diff --git a/oneflow/user/kernels/deconv_cpu_kernel.cpp b/oneflow/user/kernels/deconv_cpu_kernel.cpp index f931c3d507b..9ab26c4442a 100644 --- a/oneflow/user/kernels/deconv_cpu_kernel.cpp +++ b/oneflow/user/kernels/deconv_cpu_kernel.cpp @@ -182,7 +182,7 @@ class ColBufUtil final { }; template -struct ConvKernelUtil final { +struct DeconvKernelUtil final { public: static void NCDHWCol2Im(const T* col_buf_ptr, const ShapeView& in_shape, const ShapeView& weight_shape, const ShapeView& out_shape, @@ -234,8 +234,8 @@ struct ConvKernelUtil final { }; template -struct ConvOpKernelState final : public user_op::OpKernelState { - Col2ImFunc col2im_func_; +struct DeconvOpKernelCache final : public user_op::OpKernelCache { + Col2ImFunc col2im_func_ = nullptr; Shape in_5d_shape_; Shape out_5d_shape_; @@ -245,9 +245,9 @@ struct ConvOpKernelState final : public user_op::OpKernelState { std::vector dilation_rate_3d_; std::vector padding_before_3d_; - enum CBLAS_TRANSPOSE is_out_diff_need_trans_; - int32_t idx_offset_; - bool is_dynamic_; + enum CBLAS_TRANSPOSE is_out_diff_need_trans_ = CblasNoTrans; + int32_t idx_offset_ = 0; + bool is_dynamic_ = false; void Update(const ShapeView& x_shape, const ShapeView& out_shape) { auto Gen5DShape = [](const ShapeView& shape, int32_t idx_offset) -> Shape { @@ -266,21 +266,21 @@ struct ConvOpKernelState final : public user_op::OpKernelState { }; template -std::shared_ptr> CreateConvOpKernelState(user_op::KernelComputeContext* ctx, - const std::string& in_name, - const std::string& out_name, - const std::string& weight_name) { +std::shared_ptr> CreateDeconvOpKernelCache(user_op::KernelCacheContext* ctx, + const std::string& in_name, + const std::string& out_name, + const std::string& weight_name) { const auto& data_format = ctx->Attr("data_format"); - std::shared_ptr> state(new ConvOpKernelState()); + std::shared_ptr> cache(new DeconvOpKernelCache()); if (data_format == "channels_first") { - state->col2im_func_ = ConvKernelUtil::NCDHWCol2Im; - state->is_out_diff_need_trans_ = CblasNoTrans; - state->idx_offset_ = 2; + cache->col2im_func_ = DeconvKernelUtil::NCDHWCol2Im; + cache->is_out_diff_need_trans_ = CblasNoTrans; + cache->idx_offset_ = 2; } else { - state->col2im_func_ = ConvKernelUtil::NDHWCCol2Im; - state->is_out_diff_need_trans_ = CblasTrans; - state->idx_offset_ = 1; + cache->col2im_func_ = DeconvKernelUtil::NDHWCCol2Im; + cache->is_out_diff_need_trans_ = CblasTrans; + cache->idx_offset_ = 1; } auto Gen5DShape = [](const Shape& shape, int32_t idx_offset) -> Shape { @@ -289,32 +289,32 @@ std::shared_ptr> CreateConvOpKernelState(user_op::KernelCom ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1); return Shape(ret_vec); }; - state->in_5d_shape_ = - Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->shape(), state->idx_offset_); - state->out_5d_shape_ = - Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(out_name, 
0)->shape(), state->idx_offset_); - state->weight_5d_shape_ = - Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(weight_name, 0)->shape(), state->idx_offset_); + cache->in_5d_shape_ = + Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->shape(), cache->idx_offset_); + cache->out_5d_shape_ = + Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(out_name, 0)->shape(), cache->idx_offset_); + cache->weight_5d_shape_ = + Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(weight_name, 0)->shape(), cache->idx_offset_); auto Gen3DVec = [](const std::vector& origin_vec) -> std::vector { std::vector ret_vec = origin_vec; ret_vec.insert(ret_vec.begin(), 3 - ret_vec.size(), 1); return ret_vec; }; - state->strides_3d_ = Gen3DVec(ctx->Attr>("strides")); - state->dilation_rate_3d_ = Gen3DVec(ctx->Attr>("dilation_rate")); - state->is_dynamic_ = ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->is_dynamic(); + cache->strides_3d_ = Gen3DVec(ctx->Attr>("strides")); + cache->dilation_rate_3d_ = Gen3DVec(ctx->Attr>("dilation_rate")); + cache->is_dynamic_ = ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->is_dynamic(); const auto& padding_before = ctx->Attr>("padding_before"); FOR_RANGE(uint8_t, dim, 0, 3) { int64_t index = static_cast(dim) - (3 - padding_before.size()); if (index < 0) { - state->padding_before_3d_.emplace_back(0); + cache->padding_before_3d_.push_back(0); } else { - state->padding_before_3d_.emplace_back(padding_before.at(index)); + cache->padding_before_3d_.push_back(padding_before.at(index)); } } - return state; + return cache; } template @@ -326,21 +326,28 @@ class DeconvCpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - std::shared_ptr> DoCreateOpKernelState( - user_op::KernelComputeContext* ctx) const { - return CreateConvOpKernelState(ctx, "out", "in", "weight"); + using user_op::OpKernel::InitOpKernelCache; + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + if (*cache_ptr != nullptr && (flag & user_op::OpKernelCache::kAttrNotChanged)) { + auto deconv_cache = std::dynamic_pointer_cast>(*cache_ptr); + deconv_cache->Update(ctx->TensorDesc4ArgNameAndIndex("in", 0)->shape(), + ctx->TensorDesc4ArgNameAndIndex("out", 0)->shape()); + return; + } + *cache_ptr = CreateDeconvOpKernelCache(ctx, "out", "in", "weight"); } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto conv_state = DoCreateOpKernelState(ctx); - CHECK_NOTNULL(conv_state); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto deconv_cache = dynamic_cast*>(cache); + CHECK_NOTNULL(deconv_cache); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - conv_state->Update(in->shape(), out->shape()); Memset(ctx->stream(), out->mut_dptr(), 0, out->shape().elem_cnt() * sizeof(T)); @@ -348,20 +355,20 @@ class DeconvCpuKernel final : public user_op::OpKernel { // channels first: col_buf' = weight(T) * in[i]' // channels last : col_buf' = weight(T) * in[i]'(T) // m, n, k - int32_t idx_offset = conv_state->idx_offset_; + int32_t idx_offset = deconv_cache->idx_offset_; NewKernelUtil::OFGemm( - ctx->stream(), CblasTrans, conv_state->is_out_diff_need_trans_, - 
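Aside: the deconv kernel above adopts a cache-reuse convention through the flag-taking InitOpKernelCache overload. The sketch below restates it with simplified stand-in types; the kAttrNotChanged value is an assumed placeholder, not the real OneFlow constant.

#include <cstdint>
#include <memory>

struct Cache { virtual ~Cache() = default; };
struct DeconvCacheSketch final : Cache {
  void Update(/* new in/out shapes */) {}  // refresh only the shape-dependent fields
};

constexpr int8_t kAttrNotChangedFlag = 1;  // placeholder value for illustration

void InitCacheSketch(int8_t flag, std::shared_ptr<Cache>* cache_ptr) {
  if (*cache_ptr != nullptr && (flag & kAttrNotChangedFlag)) {
    // Attributes unchanged: keep the cache and only update it for dynamic shapes.
    std::dynamic_pointer_cast<DeconvCacheSketch>(*cache_ptr)->Update();
    return;
  }
  *cache_ptr = std::make_shared<DeconvCacheSketch>();  // attributes changed: rebuild
}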
conv_state->weight_5d_shape_.Count(1), - conv_state->out_5d_shape_.Count(idx_offset, idx_offset + 3), - conv_state->weight_5d_shape_.At(0), static_cast(1), weight->dptr(), + ctx->stream(), CblasTrans, deconv_cache->is_out_diff_need_trans_, + deconv_cache->weight_5d_shape_.Count(1), + deconv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3), + deconv_cache->weight_5d_shape_.At(0), static_cast(1), weight->dptr(), GetImgDptr(in, i), static_cast(0), col_buf->mut_dptr()); // out = col2im(col_buf') - conv_state->col2im_func_(col_buf->dptr(), ShapeView(conv_state->in_5d_shape_), - ShapeView(conv_state->weight_5d_shape_), - ShapeView(conv_state->out_5d_shape_), conv_state->strides_3d_.data(), - conv_state->dilation_rate_3d_.data(), - conv_state->padding_before_3d_.data(), GetImgMutDptr(out, i)); + deconv_cache->col2im_func_( + col_buf->dptr(), ShapeView(deconv_cache->in_5d_shape_), + ShapeView(deconv_cache->weight_5d_shape_), ShapeView(deconv_cache->out_5d_shape_), + deconv_cache->strides_3d_.data(), deconv_cache->dilation_rate_3d_.data(), + deconv_cache->padding_before_3d_.data(), GetImgMutDptr(out, i)); } } }; diff --git a/oneflow/user/kernels/diagonal_kernel.cpp b/oneflow/user/kernels/diagonal_kernel.cpp new file mode 100644 index 00000000000..bb01b77d2a5 --- /dev/null +++ b/oneflow/user/kernels/diagonal_kernel.cpp @@ -0,0 +1,138 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include "oneflow/core/common/util.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/kernel/new_kernel_util.h" +#include "oneflow/core/kernel/kernel_util.h" +#include "oneflow/core/ep/cuda/cuda_stream.h" + +namespace oneflow { +namespace { + +template +struct DiagonalFunctor final { + void operator()(ep::Stream* stream, T* out_buf, const T* in_buf, int32_t size, int32_t dim1, + int32_t dim2) { + int32_t offset_index = (dim1 + 1) * dim2; + FOR_RANGE(int32_t, index, 0, size * dim2) { + int32_t i = index / dim2; + int32_t j = index - i * dim2; + out_buf[j * size + i] = in_buf[i * offset_index + j]; + } + } +}; + +template +struct DiagonalGradFunctor final { + void operator()(ep::Stream* stream, T* dx_buf, const T* dy_buf, int32_t size, int32_t dim1, + int32_t dim2) { + int32_t offset_index = (dim1 + 1) * dim2; + FOR_RANGE(int32_t, index, 0, size * dim2) { + int32_t i = index / dim2; + int32_t j = index - i * dim2; + dx_buf[i * offset_index + j] = dy_buf[j * size + i]; + } + } +}; + +} // namespace + +template +class CpuDiagonalKernel final : public user_op::OpKernel { + public: + CpuDiagonalKernel() = default; + ~CpuDiagonalKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const int32_t offset = ctx->Attr("offset"); + const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + const ShapeView& out_shape = out->shape(); + const ShapeView& in_shape = in->shape(); + const T* in_buf = in->dptr(); + T* out_buf = out->mut_dptr(); + + int32_t size = out_shape.At(out_shape.NumAxes() - 1); + int32_t dim1 = in_shape.At(1); + int32_t dim2 = 0; + if (in_shape.NumAxes() <= 2) { + dim2 = 1; + } else { + dim2 = in_shape.Count(2, in_shape.NumAxes()); + } + + int32_t offset_in_bufer = (offset >= 0 ? offset * dim2 : -offset * dim1 * dim2); + in_buf += offset_in_bufer; + DiagonalFunctor()(ctx->stream(), out_buf, in_buf, size, dim1, dim2); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +template +class CpuDiagonalBackwardKernel final : public user_op::OpKernel { + public: + CpuDiagonalBackwardKernel() = default; + ~CpuDiagonalBackwardKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); + int32_t offset = ctx->Attr("offset"); + const ShapeView& dx_shape = dx->shape(); + const ShapeView& dy_shape = dy->shape(); + T* dx_buf = dx->mut_dptr(); + const T* dy_buf = dy->dptr(); + + Memset(ctx->stream(), dx->mut_dptr(), 0, dx_shape.elem_cnt() * sizeof(T)); + + int32_t dim1 = dx_shape.At(1); + int32_t dim2 = 0; + if (dx_shape.NumAxes() <= 2) { + dim2 = 1; + } else { + dim2 = dx_shape.Count(2, dx_shape.NumAxes()); + } + int32_t size = dy_shape.At(dy_shape.NumAxes() - 1); + int32_t offset_in_bufer = (offset >= 0 ? 
offset * dim2 : -offset * dim1 * dim2); + dx_buf += offset_in_bufer; + + DiagonalGradFunctor()(ctx->stream(), dx_buf, dy_buf, size, dim1, dim2); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_DIAGONAL_KERNELS(dtype) \ + REGISTER_USER_KERNEL("diagonal") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ + && (user_op::HobDataType("in", 0) == GetDataType::value)); \ + REGISTER_USER_KERNEL("diagonal_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ + && (user_op::HobDataType("in", 0) == GetDataType::value)); + +REGISTER_DIAGONAL_KERNELS(float); +REGISTER_DIAGONAL_KERNELS(double); +REGISTER_DIAGONAL_KERNELS(int8_t); +REGISTER_DIAGONAL_KERNELS(int32_t); +REGISTER_DIAGONAL_KERNELS(int64_t); + +#undef REGISTER_DIAGONAL_KERNELS + +} // namespace oneflow diff --git a/oneflow/user/kernels/diagonal_kernel.cu b/oneflow/user/kernels/diagonal_kernel.cu new file mode 100644 index 00000000000..fc65ab55deb --- /dev/null +++ b/oneflow/user/kernels/diagonal_kernel.cu @@ -0,0 +1,160 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "oneflow/core/common/util.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/kernel/new_kernel_util.h" +#include "oneflow/core/kernel/kernel_util.h" +#include "oneflow/core/ep/cuda/cuda_stream.h" + +namespace oneflow { +namespace { + +template +__global__ void forward_diagonal_kernel(T* out_buf, const T* in_buf, int32_t size, int32_t dim1, + int32_t dim2) { + int32_t offset_index = (dim1 + 1) * dim2; + CUDA_1D_KERNEL_LOOP(index, size * dim2) { + int32_t i = index / dim2; + int32_t j = index - i * dim2; + out_buf[j * size + i] = in_buf[i * offset_index + j]; + } +} + +template +__global__ void backward_diagonal_kernel(T* dx_buf, const T* dy_buf, int32_t size, int32_t dim1, + int32_t dim2) { + int32_t offset_index = (dim1 + 1) * dim2; + CUDA_1D_KERNEL_LOOP(index, size * dim2) { + int32_t i = index / dim2; + int32_t j = index - i * dim2; + dx_buf[i * offset_index + j] = dy_buf[j * size + i]; + } +} + +template +struct DiagonalFunctor final { + void operator()(ep::Stream* stream, T* out_buf, const T* in_buf, int32_t size, int32_t dim1, + int32_t dim2) { + if (size * dim2 > 0) { + forward_diagonal_kernel + <<As()->cuda_stream()>>>(out_buf, in_buf, size, dim1, dim2); + } + } +}; + +template +struct DiagonalGradFunctor final { + void operator()(ep::Stream* stream, T* dx_buf, const T* dy_buf, int32_t size, int32_t dim1, + int32_t dim2) { + if (size * dim2 > 0) { + backward_diagonal_kernel + <<As()->cuda_stream()>>>(dx_buf, dy_buf, size, dim1, dim2); + } + } +}; + +} // namespace + +template +class GpuDiagonalKernel final : public user_op::OpKernel { + public: + GpuDiagonalKernel() = default; + ~GpuDiagonalKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const int32_t offset = ctx->Attr("offset"); + const user_op::Tensor* 
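Aside: the CPU and CUDA DiagonalFunctor above share the same indexing, which is easiest to check in the plain 2-D case where dim1 = cols and dim2 = 1: the start offset shifts right along row 0 for offset >= 0 and down |offset| rows otherwise, and consecutive diagonal elements are (cols + 1) apart. A worked illustration, not the patch code:

#include <algorithm>
#include <vector>

std::vector<float> Diagonal2D(const std::vector<float>& in, int rows, int cols, int offset) {
  const int size = offset >= 0 ? std::min(rows, cols - offset)   // diagonal length
                               : std::min(rows + offset, cols);
  const int start = offset >= 0 ? offset : -offset * cols;       // first diagonal element
  std::vector<float> out(std::max(size, 0));
  for (int i = 0; i < size; ++i) { out[i] = in[start + i * (cols + 1)]; }
  return out;
}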
in = ctx->Tensor4ArgNameAndIndex("in", 0); + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + const ShapeView& out_shape = out->shape(); + const ShapeView& in_shape = in->shape(); + const T* in_buf = in->dptr(); + T* out_buf = out->mut_dptr(); + + int32_t size = out_shape.At(out_shape.NumAxes() - 1); + int32_t dim1 = in_shape.At(1); + int32_t dim2 = 0; + if (in_shape.NumAxes() <= 2) { + dim2 = 1; + } else { + dim2 = in_shape.Count(2, in_shape.NumAxes()); + } + + int32_t offset_in_bufer = (offset >= 0 ? offset * dim2 : -offset * dim1 * dim2); + in_buf += offset_in_bufer; + + DiagonalFunctor()(ctx->stream(), out_buf, in_buf, size, dim1, dim2); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +template +class GpuDiagonalBackwardKernel final : public user_op::OpKernel { + public: + GpuDiagonalBackwardKernel() = default; + ~GpuDiagonalBackwardKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); + int32_t offset = ctx->Attr("offset"); + const ShapeView& dx_shape = dx->shape(); + const ShapeView& dy_shape = dy->shape(); + T* dx_buf = dx->mut_dptr(); + const T* dy_buf = dy->dptr(); + + Memset(ctx->stream(), dx->mut_dptr(), 0, dx_shape.elem_cnt() * sizeof(T)); + + int32_t dim1 = dx_shape.At(1); + int32_t dim2 = 0; + if (dx_shape.NumAxes() <= 2) { + dim2 = 1; + } else { + dim2 = dx_shape.Count(2, dx_shape.NumAxes()); + } + int32_t size = dy_shape.At(dy_shape.NumAxes() - 1); + int32_t offset_in_bufer = (offset >= 0 ? offset * dim2 : -offset * dim1 * dim2); + dx_buf += offset_in_bufer; + + DiagonalGradFunctor()(ctx->stream(), dx_buf, dy_buf, size, dim1, dim2); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_DIAGONAL_KERNELS(dtype) \ + REGISTER_USER_KERNEL("diagonal") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("in", 0) == GetDataType::value)); \ + REGISTER_USER_KERNEL("diagonal_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("in", 0) == GetDataType::value)); + +REGISTER_DIAGONAL_KERNELS(half); +REGISTER_DIAGONAL_KERNELS(float); +REGISTER_DIAGONAL_KERNELS(double); +REGISTER_DIAGONAL_KERNELS(int8_t); +REGISTER_DIAGONAL_KERNELS(int32_t); +REGISTER_DIAGONAL_KERNELS(int64_t); + +#undef REGISTER_DIAGONAL_KERNELS + +} // namespace oneflow diff --git a/oneflow/user/kernels/distributions/normal_kernel.h b/oneflow/user/kernels/distributions/normal_kernel.h index 2a65c15c81e..a04e88cd80f 100644 --- a/oneflow/user/kernels/distributions/normal_kernel.h +++ b/oneflow/user/kernels/distributions/normal_kernel.h @@ -39,7 +39,8 @@ class NormalKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const double mean = ctx->Attr("mean"); const double std = ctx->Attr("std"); diff --git a/oneflow/user/kernels/distributions/uniform_int_kernel.h b/oneflow/user/kernels/distributions/uniform_int_kernel.h index b069b6e6339..60c8a7e3c61 100644 --- a/oneflow/user/kernels/distributions/uniform_int_kernel.h +++ 
b/oneflow/user/kernels/distributions/uniform_int_kernel.h @@ -80,7 +80,8 @@ class UniformIntKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t from = ctx->Attr("from"); int64_t to = ctx->Attr("to"); diff --git a/oneflow/user/kernels/distributions/uniform_kernel.h b/oneflow/user/kernels/distributions/uniform_kernel.h index d8c67dc9319..078f5bee5c6 100644 --- a/oneflow/user/kernels/distributions/uniform_kernel.h +++ b/oneflow/user/kernels/distributions/uniform_kernel.h @@ -38,7 +38,8 @@ class UniformKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const double from = ctx->Attr("from"); const double to = ctx->Attr("to"); diff --git a/oneflow/user/kernels/dot_kernel.cpp b/oneflow/user/kernels/dot_kernel.cpp index 2be15520637..4e055ceefeb 100644 --- a/oneflow/user/kernels/dot_kernel.cpp +++ b/oneflow/user/kernels/dot_kernel.cpp @@ -48,7 +48,6 @@ class DotKernel final : public user_op::OpKernel { const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t n = x->shape().elem_cnt(); - CHECK(n <= INT_MAX); auto primitive = NewMatmulPrimitive(ctx); primitive->Launch(ctx->stream(), 1, 1, n, 1, x->dptr(), y->dptr(), 0, out->mut_dptr()); diff --git a/oneflow/user/kernels/dropout_kernel.cpp b/oneflow/user/kernels/dropout_kernel.cpp index 82bf55aa520..ab06a8df68c 100644 --- a/oneflow/user/kernels/dropout_kernel.cpp +++ b/oneflow/user/kernels/dropout_kernel.cpp @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/core/kernel/kernel_util.h" #include "oneflow/user/kernels/dropout_kernel.h" #include "oneflow/core/ep/include/primitive/add.h" @@ -58,7 +58,8 @@ class DropoutKernelCPU final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); diff --git a/oneflow/user/kernels/dropout_kernel.cu b/oneflow/user/kernels/dropout_kernel.cu index b1f84f319db..5c2471109ec 100644 --- a/oneflow/user/kernels/dropout_kernel.cu +++ b/oneflow/user/kernels/dropout_kernel.cu @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/cuda/elementwise.cuh" #include "oneflow/core/cuda/atomic.cuh" @@ -412,7 +412,8 @@ class DropoutKernelGPU final : public user_op::OpKernel, public user_op::CudaGra private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* mask = ctx->Tensor4ArgNameAndIndex("mask", 0); diff --git a/oneflow/user/kernels/eager_b_to_s_kernel.cpp b/oneflow/user/kernels/eager_b_to_s_kernel.cpp index 910cc748f68..a7b16dd3163 100644 --- a/oneflow/user/kernels/eager_b_to_s_kernel.cpp +++ b/oneflow/user/kernels/eager_b_to_s_kernel.cpp @@ -48,10 +48,10 @@ Maybe> GetAllBroadcastNdSbp(int64_t ndim) { auto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocal); -class EagerBToSOpKernelState final : public user_op::OpKernelState { +class EagerBToSOpKernelCache final : public user_op::OpKernelCache { public: - explicit EagerBToSOpKernelState(user_op::KernelInitContext* ctx) { Init(ctx); } - ~EagerBToSOpKernelState() override = default; + explicit EagerBToSOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); } + ~EagerBToSOpKernelCache() override = default; const std::vector>>& sorted_elem_cnt2in_tensor_slice_copier_pair() const { @@ -68,7 +68,7 @@ class EagerBToSOpKernelState final : public user_op::OpKernelState { } private: - void Init(user_op::KernelInitContext* ctx) { + void Init(user_op::KernelCacheContext* ctx) { const std::string& in_parallel_conf_txt = ctx->Attr("in_parallel_conf"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); @@ -153,15 +153,16 @@ class EagerBToSKernel final : public user_op::OpKernel { EagerBToSKernel() = default; ~EagerBToSKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); @@ -170,10 +171,10 @@ class EagerBToSKernel final : public user_op::OpKernel { void* tmp_buffer_ptr = tmp_buffer->mut_dptr(); const auto& sorted_elem_cnt2in_tensor_slice_copier_pair = - kernel_state->sorted_elem_cnt2in_tensor_slice_copier_pair(); + kernel_cache->sorted_elem_cnt2in_tensor_slice_copier_pair(); const auto& sorted_elem_cnt2out_tensor_slice_copier_pair = - 
kernel_state->sorted_elem_cnt2out_tensor_slice_copier_pair(); - const auto& sorted_p2p_pair = kernel_state->sorted_p2p_pair(); + kernel_cache->sorted_elem_cnt2out_tensor_slice_copier_pair(); + const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair(); CHECK_EQ(sorted_elem_cnt2in_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); CHECK_EQ(sorted_elem_cnt2out_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); diff --git a/oneflow/user/kernels/eager_nccl_kernels.cpp b/oneflow/user/kernels/eager_nccl_kernels.cpp index 9930519e68c..debe18a8feb 100644 --- a/oneflow/user/kernels/eager_nccl_kernels.cpp +++ b/oneflow/user/kernels/eager_nccl_kernels.cpp @@ -27,15 +27,15 @@ namespace oneflow { namespace { -class EagerCclOpKernelState final : public user_op::OpKernelState { +class EagerCclOpKernelCache final : public user_op::OpKernelCache { public: - EagerCclOpKernelState(user_op::KernelInitContext* ctx) { Init(ctx); } - ~EagerCclOpKernelState() override = default; + explicit EagerCclOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); } + ~EagerCclOpKernelCache() override = default; Symbol parallel_desc() const { return parallel_desc_; } private: - void Init(user_op::KernelInitContext* ctx) { + void Init(user_op::KernelCacheContext* ctx) { const std::string& parallel_conf_txt = ctx->Attr("parallel_conf"); ParallelConf parallel_conf; CHECK(TxtString2PbMessage(parallel_conf_txt, ¶llel_conf)); @@ -67,6 +67,12 @@ Maybe>> RawGroupP2PPair( static constexpr auto* GroupP2PPair = DECORATE(&RawGroupP2PPair, ThreadLocal); +void InitEagerCclOpKernelCache(user_op::KernelCacheContext* ctx, + std::shared_ptr* cache_ptr) { + // NOTE(jianhao): the cache only depends on parallel_conf, and the kernel is singleton + // once parallel_conf is determined, so only init the cache at the first time. 
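+  // The kernels in this file all wire the helper up the same way; a minimal sketch of the
+  // pattern (the concrete cache type here is EagerCclOpKernelCache):
+  //   void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag,
+  //                          std::shared_ptr<user_op::OpKernelCache>* cache_ptr) const override {
+  //     InitEagerCclOpKernelCache(ctx, cache_ptr);
+  //   }
+  //   void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,
+  //                const user_op::OpKernelCache* cache) const override {
+  //     auto* kernel_cache = dynamic_cast<const EagerCclOpKernelCache*>(cache);
+  //     CHECK(kernel_cache != nullptr);
+  //     // ... the collective call then reads kernel_cache->parallel_desc() ...
+  //   }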
+ if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } +} } // namespace class EagerCclBroadcastKernel final : public user_op::OpKernel { @@ -74,15 +80,16 @@ class EagerCclBroadcastKernel final : public user_op::OpKernel { EagerCclBroadcastKernel() = default; ~EagerCclBroadcastKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + InitEagerCclOpKernelCache(ctx, cache_ptr); } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t root = ctx->Attr("root"); @@ -94,7 +101,7 @@ class EagerCclBroadcastKernel final : public user_op::OpKernel { } CHECK_JUST(ccl::Broadcast(in_ptr, out->mut_dptr(), out->shape().elem_cnt(), out->data_type(), root, - kernel_state->parallel_desc(), ctx->stream())); + kernel_cache->parallel_desc(), ctx->stream())); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; @@ -108,16 +115,17 @@ class EagerCclReduceKernel final : public user_op::OpKernel { EagerCclReduceKernel() = default; ~EagerCclReduceKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + InitEagerCclOpKernelCache(ctx, cache_ptr); } private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t root = ctx->Attr("root"); @@ -129,7 +137,7 @@ class EagerCclReduceKernel final : public user_op::OpKernel { } CHECK_JUST(ccl::Reduce(in->dptr(), out_ptr, in->shape().elem_cnt(), in->data_type(), ccl::kSum, root, - kernel_state->parallel_desc(), ctx->stream())); + kernel_cache->parallel_desc(), ctx->stream())); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; @@ -143,15 +151,16 @@ class EagerCclAllReduceKernel final : public user_op::OpKernel { EagerCclAllReduceKernel() = default; ~EagerCclAllReduceKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + InitEagerCclOpKernelCache(ctx, cache_ptr); } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - 
CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->shape(), out->shape()); @@ -159,7 +168,7 @@ class EagerCclAllReduceKernel final : public user_op::OpKernel { CHECK_JUST(ccl::AllReduce( in->dptr(), out->mut_dptr(), out->shape().elem_cnt(), out->data_type(), ccl::kSum, - kernel_state->parallel_desc(), ctx->stream())); + kernel_cache->parallel_desc(), ctx->stream())); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; @@ -173,16 +182,17 @@ class EagerCclReduceScatterKernel final : public user_op::OpKernel { EagerCclReduceScatterKernel() = default; ~EagerCclReduceScatterKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + InitEagerCclOpKernelCache(ctx, cache_ptr); } private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); @@ -190,7 +200,7 @@ class EagerCclReduceScatterKernel final : public user_op::OpKernel { CHECK_EQ(op_type, "sum"); CHECK_JUST(ccl::ReduceScatter( in->dptr(), out->mut_dptr(), out->shape().elem_cnt(), out->data_type(), ccl::kSum, - kernel_state->parallel_desc(), ctx->stream())); + kernel_cache->parallel_desc(), ctx->stream())); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; @@ -204,21 +214,22 @@ class EagerCclAllGatherKernel final : public user_op::OpKernel { EagerCclAllGatherKernel() = default; ~EagerCclAllGatherKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + InitEagerCclOpKernelCache(ctx, cache_ptr); } private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); CHECK_JUST(ccl::AllGather(in->dptr(), out->mut_dptr(), in->shape().elem_cnt(), - out->data_type(), kernel_state->parallel_desc(), + out->data_type(), kernel_cache->parallel_desc(), ctx->stream())); } bool 
AlwaysComputeWhenAllOutputsEmpty() const override { return false; } @@ -234,16 +245,17 @@ class EagerCclS2SKernel final : public user_op::OpKernel { EagerCclS2SKernel() = default; ~EagerCclS2SKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + InitEagerCclOpKernelCache(ctx, cache_ptr); } private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); // NOTE(hanbinbin): Compute logic copy from _nccl_logical_s2s const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); @@ -257,7 +269,7 @@ class EagerCclS2SKernel final : public user_op::OpKernel { CHECK_EQ(tmp_size, data_size * 2); CHECK_EQ(in->data_type(), out->data_type()); - const int64_t num_ranks = kernel_state->parallel_desc()->parallel_num(); + const int64_t num_ranks = kernel_cache->parallel_desc()->parallel_num(); CHECK_EQ(in->shape().elem_cnt(), out->shape().elem_cnt()) << in->shape().ToString() << " vs " << out->shape().ToString(); const int64_t elem_cnt = in->shape().elem_cnt(); @@ -301,13 +313,13 @@ class EagerCclS2SKernel final : public user_op::OpKernel { // NOTE: Do S2S const int64_t elem_per_chunk = elem_cnt / num_ranks; const int64_t chunk_size = elem_per_chunk * dtype_size; - const auto& p2p_pairs = CHECK_JUST(GroupP2PPair(kernel_state->parallel_desc())); + const auto& p2p_pairs = CHECK_JUST(GroupP2PPair(kernel_cache->parallel_desc())); for (const auto& pair : *p2p_pairs) { int64_t src = pair.first; int64_t dst = pair.second; if (GlobalProcessCtx::Rank() == src) { - Symbol parallel_desc = kernel_state->parallel_desc(); + Symbol parallel_desc = kernel_cache->parallel_desc(); int64_t device_id = GlobalProcessCtx::LocalRank(dst); int64_t parallel_id = CHECK_JUST(parallel_desc->ParallelId4MachineDeviceId(dst, device_id)); @@ -318,7 +330,7 @@ class EagerCclS2SKernel final : public user_op::OpKernel { elem_per_chunk, in->data_type(), dst, ctx->stream())); } if (GlobalProcessCtx::Rank() == dst) { - Symbol parallel_desc = kernel_state->parallel_desc(); + Symbol parallel_desc = kernel_cache->parallel_desc(); int64_t device_id = GlobalProcessCtx::LocalRank(src); int64_t parallel_id = CHECK_JUST(parallel_desc->ParallelId4MachineDeviceId(src, device_id)); diff --git a/oneflow/user/kernels/eager_nccl_kernels.cu b/oneflow/user/kernels/eager_nccl_kernels.cu index f8e4f898baf..fe0fad69526 100644 --- a/oneflow/user/kernels/eager_nccl_kernels.cu +++ b/oneflow/user/kernels/eager_nccl_kernels.cu @@ -28,16 +28,16 @@ namespace oneflow { namespace { -class EagerNcclOpKernelState final : public user_op::OpKernelState { +class EagerNcclOpKernelCache final : public user_op::OpKernelCache { public: - explicit EagerNcclOpKernelState(user_op::KernelInitContext* ctx) { Init(ctx); } - ~EagerNcclOpKernelState() override = default; + explicit EagerNcclOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); } + ~EagerNcclOpKernelCache() override = default; Symbol parallel_desc() const { return 
parallel_desc_; } ncclComm_t comm() const { return comm_; } private: - void Init(user_op::KernelInitContext* ctx) { + void Init(user_op::KernelCacheContext* ctx) { const std::string& parallel_conf_txt = ctx->Attr("parallel_conf"); ParallelConf parallel_conf; std::set> device_set; @@ -64,6 +64,12 @@ size_t InferEagerNcclS2SKernelTmpBufferSize(user_op::InferContext* ctx) { return tensor_byte_size * 2; } +void InitEagerNcclOpKernelCache(user_op::KernelCacheContext* ctx, + std::shared_ptr* cache_ptr) { + // NOTE(jianhao): the cache only depends on parallel_conf, and the kernel is singleton + // once parallel_conf is determined, so only init the cache at the first time. + if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } +} } // namespace class EagerNcclAllReduceKernel final : public user_op::OpKernel { @@ -71,22 +77,24 @@ class EagerNcclAllReduceKernel final : public user_op::OpKernel { EagerNcclAllReduceKernel() = default; ~EagerNcclAllReduceKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + private: + using user_op::OpKernel::InitOpKernelCache; + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + InitEagerNcclOpKernelCache(ctx, cache_ptr); } - private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->shape(), out->shape()); CHECK_EQ(in->data_type(), out->data_type()); OF_NCCL_CHECK(ncclAllReduce(in->dptr(), out->mut_dptr(), in->shape().elem_cnt(), - GetNcclDataType(in->data_type()), ncclSum, kernel_state->comm(), + GetNcclDataType(in->data_type()), ncclSum, kernel_cache->comm(), ctx->stream()->As()->cuda_stream())); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } @@ -101,22 +109,24 @@ class EagerNcclBroadcastKernel final : public user_op::OpKernel { EagerNcclBroadcastKernel() = default; ~EagerNcclBroadcastKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + private: + using user_op::OpKernel::InitOpKernelCache; + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + InitEagerNcclOpKernelCache(ctx, cache_ptr); } - private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t root = ctx->Attr("root"); int64_t dev_id = GlobalProcessCtx::LocalRank(root); int64_t nccl_root = - 
CHECK_JUST(kernel_state->parallel_desc()->ParallelId4MachineDeviceId(root, dev_id)); + CHECK_JUST(kernel_cache->parallel_desc()->ParallelId4MachineDeviceId(root, dev_id)); const void* in_ptr = nullptr; if (GlobalProcessCtx::Rank() == root) { CHECK_EQ(in->shape(), out->shape()); @@ -124,7 +134,7 @@ class EagerNcclBroadcastKernel final : public user_op::OpKernel { in_ptr = in->dptr(); } OF_NCCL_CHECK(ncclBroadcast(in_ptr, out->mut_dptr(), out->shape().elem_cnt(), - GetNcclDataType(out->data_type()), nccl_root, kernel_state->comm(), + GetNcclDataType(out->data_type()), nccl_root, kernel_cache->comm(), ctx->stream()->As()->cuda_stream())); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } @@ -139,16 +149,18 @@ class EagerNcclReduceKernel final : public user_op::OpKernel { EagerNcclReduceKernel() = default; ~EagerNcclReduceKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + private: + using user_op::OpKernel::InitOpKernelCache; + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + InitEagerNcclOpKernelCache(ctx, cache_ptr); } - private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int64_t root = ctx->Attr("root"); @@ -159,7 +171,7 @@ class EagerNcclReduceKernel final : public user_op::OpKernel { out_ptr = out->mut_dptr(); } OF_NCCL_CHECK(ncclReduce(in->dptr(), out_ptr, in->shape().elem_cnt(), - GetNcclDataType(in->data_type()), ncclSum, root, kernel_state->comm(), + GetNcclDataType(in->data_type()), ncclSum, root, kernel_cache->comm(), ctx->stream()->As()->cuda_stream())); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } @@ -174,23 +186,25 @@ class EagerNcclReduceScatterKernel final : public user_op::OpKernel { EagerNcclReduceScatterKernel() = default; ~EagerNcclReduceScatterKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + private: + using user_op::OpKernel::InitOpKernelCache; + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + InitEagerNcclOpKernelCache(ctx, cache_ptr); } - private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); const auto& op_type = ctx->Attr("op_type"); OF_NCCL_CHECK(ncclReduceScatter( in->dptr(), out->mut_dptr(), out->shape().elem_cnt(), 
GetNcclDataType(in->data_type()), - CHECK_JUST(MapAt(op_type2ncclRedOp_t, op_type)), kernel_state->comm(), + CHECK_JUST(MapAt(op_type2ncclRedOp_t, op_type)), kernel_cache->comm(), ctx->stream()->As()->cuda_stream())); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } @@ -210,21 +224,23 @@ class EagerNcclAllGatherKernel final : public user_op::OpKernel { EagerNcclAllGatherKernel() = default; ~EagerNcclAllGatherKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + private: + using user_op::OpKernel::InitOpKernelCache; + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + InitEagerNcclOpKernelCache(ctx, cache_ptr); } - private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); OF_NCCL_CHECK(ncclAllGather(in->dptr(), out->mut_dptr(), in->shape().elem_cnt(), - GetNcclDataType(in->data_type()), kernel_state->comm(), + GetNcclDataType(in->data_type()), kernel_cache->comm(), ctx->stream()->As()->cuda_stream())); }; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } @@ -240,16 +256,18 @@ class EagerNcclS2SKernel final : public user_op::OpKernel { EagerNcclS2SKernel() = default; ~EagerNcclS2SKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + private: + using user_op::OpKernel::InitOpKernelCache; + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + InitEagerNcclOpKernelCache(ctx, cache_ptr); } - private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); // NOTE(hanbinbin): Compute logic copy from _nccl_logical_s2s const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); @@ -264,7 +282,7 @@ class EagerNcclS2SKernel final : public user_op::OpKernel { CHECK(tmp_size == 0 || tmp_size == data_size || tmp_size == data_size * 2); CHECK_EQ(in->data_type(), out->data_type()); - const int64_t num_ranks = kernel_state->parallel_desc()->parallel_num(); + const int64_t num_ranks = kernel_cache->parallel_desc()->parallel_num(); CHECK_EQ(in->shape().elem_cnt(), out->shape().elem_cnt()) << in->shape().ToString() << " vs " << out->shape().ToString(); const int64_t elem_cnt = in->shape().elem_cnt(); @@ -313,11 +331,11 @@ class EagerNcclS2SKernel final : public user_op::OpKernel { OF_NCCL_CHECK(ncclSend(reinterpret_cast( reinterpret_cast(pack_to_ptr) + j * chunk_size), 
elem_per_chunk, GetNcclDataType(in->data_type()), j, - kernel_state->comm(), + kernel_cache->comm(), ctx->stream()->As()->cuda_stream())); OF_NCCL_CHECK(ncclRecv( reinterpret_cast(reinterpret_cast(unpack_from_ptr) + j * chunk_size), - elem_per_chunk, GetNcclDataType(in->data_type()), j, kernel_state->comm(), + elem_per_chunk, GetNcclDataType(in->data_type()), j, kernel_cache->comm(), ctx->stream()->As()->cuda_stream())); } OF_NCCL_CHECK(ncclGroupEnd()); diff --git a/oneflow/user/kernels/eager_p_to_b_kernel.cpp b/oneflow/user/kernels/eager_p_to_b_kernel.cpp index fdf7c0450d2..9d4af2e0177 100644 --- a/oneflow/user/kernels/eager_p_to_b_kernel.cpp +++ b/oneflow/user/kernels/eager_p_to_b_kernel.cpp @@ -27,15 +27,15 @@ namespace oneflow { namespace { -class EagerPToBOpKernelState final : public user_op::OpKernelState { +class EagerPToBOpKernelCache final : public user_op::OpKernelCache { public: - explicit EagerPToBOpKernelState(user_op::KernelInitContext* ctx) { Init(ctx); } - ~EagerPToBOpKernelState() override = default; + explicit EagerPToBOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); } + ~EagerPToBOpKernelCache() override = default; const std::vector>& p2p_pair() const { return p2p_pair_; } private: - void Init(user_op::KernelInitContext* ctx) { + void Init(user_op::KernelCacheContext* ctx) { const std::string& in_parallel_conf_txt = ctx->Attr("in_parallel_conf"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); Symbol in_parallel_desc = CHECK_JUST(TxtStringToPlacement(in_parallel_conf_txt)); @@ -71,15 +71,16 @@ class EagerPToBKernel final : public user_op::OpKernel { EagerPToBKernel() = default; ~EagerPToBKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); @@ -87,7 +88,7 @@ class EagerPToBKernel final : public user_op::OpKernel { void* tmp_buffer_ptr = tmp_buffer->mut_dptr(); const int64_t total_elem_cnt = ctx->Attr("shape").elem_cnt(); - const auto& p2p_pair = kernel_state->p2p_pair(); + const auto& p2p_pair = kernel_cache->p2p_pair(); Memset(ctx->stream(), out->mut_dptr(), 0, total_elem_cnt * GetSizeOfDataType(out->data_type())); diff --git a/oneflow/user/kernels/eager_p_to_s_kernel.cpp b/oneflow/user/kernels/eager_p_to_s_kernel.cpp index 82d55d8cd3f..aabfe42453e 100644 --- a/oneflow/user/kernels/eager_p_to_s_kernel.cpp +++ b/oneflow/user/kernels/eager_p_to_s_kernel.cpp @@ -49,12 +49,12 @@ Maybe> GetAllPartialSumNdSbp(int64_t ndim) { auto* CachedGetAllPartialSumNdSbp = DECORATE(&GetAllPartialSumNdSbp, ThreadLocal); -class EagerPToSOpKernelState final : public user_op::OpKernelState { +class EagerPToSOpKernelCache final : public user_op::OpKernelCache { public: - explicit 
EagerPToSOpKernelState(user_op::KernelInitContext* ctx) : elem_cnt_per_chunk_(0) { + explicit EagerPToSOpKernelCache(user_op::KernelCacheContext* ctx) : elem_cnt_per_chunk_(0) { Init(ctx); } - ~EagerPToSOpKernelState() override = default; + ~EagerPToSOpKernelCache() override = default; int64_t elem_cnt_per_chunk() const { return elem_cnt_per_chunk_; } @@ -67,7 +67,7 @@ class EagerPToSOpKernelState final : public user_op::OpKernelState { } private: - void Init(user_op::KernelInitContext* ctx) { + void Init(user_op::KernelCacheContext* ctx) { const std::string& in_parallel_conf_txt = ctx->Attr("in_parallel_conf"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); @@ -128,24 +128,25 @@ class EagerPToSKernel final : public user_op::OpKernel { EagerPToSKernel() = default; ~EagerPToSKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); const void* in_ptr = in->dptr(); void* tmp_buffer_ptr = tmp_buffer->mut_dptr(); - int64_t elem_cnt_per_chunk = kernel_state->elem_cnt_per_chunk(); - const auto& sorted_in_tensor_slice_copier = kernel_state->sorted_in_tensor_slice_copier(); - const auto& sorted_p2p_pair = kernel_state->sorted_p2p_pair(); + int64_t elem_cnt_per_chunk = kernel_cache->elem_cnt_per_chunk(); + const auto& sorted_in_tensor_slice_copier = kernel_cache->sorted_in_tensor_slice_copier(); + const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair(); CHECK_EQ(sorted_in_tensor_slice_copier.size(), sorted_p2p_pair.size()); Memset(ctx->stream(), out->mut_dptr(), 0, diff --git a/oneflow/user/kernels/eager_s_to_b_kernel.cpp b/oneflow/user/kernels/eager_s_to_b_kernel.cpp index b85e7cc32f0..be3dbae1960 100644 --- a/oneflow/user/kernels/eager_s_to_b_kernel.cpp +++ b/oneflow/user/kernels/eager_s_to_b_kernel.cpp @@ -48,10 +48,10 @@ Maybe> GetAllBroadcastNdSbp(int64_t ndim) { auto* CachedGetAllBroadcastNdSbp = DECORATE(&GetAllBroadcastNdSbp, ThreadLocal); -class EagerSToBOpKernelState final : public user_op::OpKernelState { +class EagerSToBOpKernelCache final : public user_op::OpKernelCache { public: - explicit EagerSToBOpKernelState(user_op::KernelInitContext* ctx) { Init(ctx); } - ~EagerSToBOpKernelState() override = default; + explicit EagerSToBOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); } + ~EagerSToBOpKernelCache() override = default; const std::vector>>& sorted_elem_cnt2in_tensor_slice_copier_pair() const { @@ -68,7 +68,7 @@ class EagerSToBOpKernelState final : public user_op::OpKernelState { } private: - void Init(user_op::KernelInitContext* ctx) { + void Init(user_op::KernelCacheContext* ctx) { const std::string& 
in_parallel_conf_txt = ctx->Attr("in_parallel_conf"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t in_split_axis = ctx->Attr("in_split_axis"); @@ -135,15 +135,16 @@ class EagerSToBKernel final : public user_op::OpKernel { EagerSToBKernel() = default; ~EagerSToBKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); @@ -152,10 +153,10 @@ class EagerSToBKernel final : public user_op::OpKernel { void* tmp_buffer_ptr = tmp_buffer->mut_dptr(); const auto& sorted_elem_cnt2in_tensor_slice_copier_pair = - kernel_state->sorted_elem_cnt2in_tensor_slice_copier_pair(); + kernel_cache->sorted_elem_cnt2in_tensor_slice_copier_pair(); const auto& sorted_elem_cnt2out_tensor_slice_copier_pair = - kernel_state->sorted_elem_cnt2out_tensor_slice_copier_pair(); - const auto& sorted_p2p_pair = kernel_state->sorted_p2p_pair(); + kernel_cache->sorted_elem_cnt2out_tensor_slice_copier_pair(); + const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair(); CHECK_EQ(sorted_elem_cnt2in_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); CHECK_EQ(sorted_elem_cnt2out_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); diff --git a/oneflow/user/kernels/eager_s_to_s_kernel.cpp b/oneflow/user/kernels/eager_s_to_s_kernel.cpp index 526584b7329..8bcacf482dc 100644 --- a/oneflow/user/kernels/eager_s_to_s_kernel.cpp +++ b/oneflow/user/kernels/eager_s_to_s_kernel.cpp @@ -43,10 +43,10 @@ Maybe> GetAllSplitNdSbp(int64_t axis, int64_t ndim) { auto* CachedGetAllSplitNdSbp = DECORATE(&GetAllSplitNdSbp, ThreadLocal); -class EagerNaiveSToSOpKernelState final : public user_op::OpKernelState { +class EagerNaiveSToSOpKernelCache final : public user_op::OpKernelCache { public: - explicit EagerNaiveSToSOpKernelState(user_op::KernelInitContext* ctx) { Init(ctx); } - ~EagerNaiveSToSOpKernelState() override = default; + explicit EagerNaiveSToSOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); } + ~EagerNaiveSToSOpKernelCache() override = default; const std::vector>>& sorted_elem_cnt2in_tensor_slice_copier_pair() const { @@ -63,7 +63,7 @@ class EagerNaiveSToSOpKernelState final : public user_op::OpKernelState { } private: - void Init(user_op::KernelInitContext* ctx) { + void Init(user_op::KernelCacheContext* ctx) { const std::string& in_parallel_conf_txt = ctx->Attr("in_parallel_conf"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t in_split_axis = ctx->Attr("in_split_axis"); @@ -140,15 +140,16 @@ class EagerNaiveSToSKernel final : public user_op::OpKernel { EagerNaiveSToSKernel() = default; ~EagerNaiveSToSKernel() override = default; - std::shared_ptr 
CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + if (*cache_ptr == nullptr) { *cache_ptr = std::make_shared(ctx); } } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); @@ -157,10 +158,10 @@ class EagerNaiveSToSKernel final : public user_op::OpKernel { void* tmp_buffer_ptr = tmp_buffer->mut_dptr(); const auto& sorted_elem_cnt2in_tensor_slice_copier_pair = - kernel_state->sorted_elem_cnt2in_tensor_slice_copier_pair(); + kernel_cache->sorted_elem_cnt2in_tensor_slice_copier_pair(); const auto& sorted_elem_cnt2out_tensor_slice_copier_pair = - kernel_state->sorted_elem_cnt2out_tensor_slice_copier_pair(); - const auto& sorted_p2p_pair = kernel_state->sorted_p2p_pair(); + kernel_cache->sorted_elem_cnt2out_tensor_slice_copier_pair(); + const auto& sorted_p2p_pair = kernel_cache->sorted_p2p_pair(); CHECK_EQ(sorted_elem_cnt2in_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); CHECK_EQ(sorted_elem_cnt2out_tensor_slice_copier_pair.size(), sorted_p2p_pair.size()); diff --git a/oneflow/user/kernels/eager_symmetric_s_to_p_kernel.cpp b/oneflow/user/kernels/eager_symmetric_s_to_p_kernel.cpp index 53e496b9a61..75382cf0383 100644 --- a/oneflow/user/kernels/eager_symmetric_s_to_p_kernel.cpp +++ b/oneflow/user/kernels/eager_symmetric_s_to_p_kernel.cpp @@ -45,17 +45,17 @@ Maybe> GetAllPartialSumNdSbp(int64_t ndim) { auto* CachedGetAllPartialSumNdSbp = DECORATE(&GetAllPartialSumNdSbp, ThreadLocal); -class EagerSymmetricSToPOpKernelState final : public user_op::OpKernelState { +class EagerSymmetricSToPOpKernelCache final : public user_op::OpKernelCache { public: - explicit EagerSymmetricSToPOpKernelState(user_op::KernelInitContext* ctx) { Init(ctx); } - ~EagerSymmetricSToPOpKernelState() override = default; + explicit EagerSymmetricSToPOpKernelCache(user_op::KernelCacheContext* ctx) { Init(ctx); } + ~EagerSymmetricSToPOpKernelCache() override = default; const std::shared_ptr& tensor_slice_copier() const { return tensor_slice_copier_; } private: - void Init(user_op::KernelInitContext* ctx) { + void Init(user_op::KernelCacheContext* ctx) { const std::string& parallel_conf_txt = ctx->Attr("parallel_conf"); const int64_t in_split_axis = ctx->Attr("in_split_axis"); const user_op::TensorDesc* in_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex("in", 0); @@ -93,15 +93,18 @@ class EagerSymmetricSToPKernel final : public user_op::OpKernel { EagerSymmetricSToPKernel() = default; ~EagerSymmetricSToPKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return std::make_shared(ctx); + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + if (*cache_ptr == nullptr) { + *cache_ptr = std::make_shared(ctx); + } } private: - void Compute(user_op::KernelComputeContext* ctx, 
user_op::OpKernelState* state) const override { - auto* kernel_state = dynamic_cast(state); - CHECK(kernel_state != nullptr); + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto* kernel_cache = dynamic_cast(cache); + CHECK(kernel_cache != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const auto& out_shape_view = out->shape(); @@ -112,7 +115,7 @@ class EagerSymmetricSToPKernel final : public user_op::OpKernel { Memset(ctx->stream(), out->mut_dptr(), 0, out_shape_view.elem_cnt() * GetSizeOfDataType(out->data_type())); - const auto& tensor_slice_copier = kernel_state->tensor_slice_copier(); + const auto& tensor_slice_copier = kernel_cache->tensor_slice_copier(); tensor_slice_copier->Copy(ctx->stream(), out_ptr, in_ptr); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } diff --git a/oneflow/user/kernels/fold_kernel.cpp b/oneflow/user/kernels/fold_kernel.cpp index 9e45dee3f77..1a48f75cec9 100644 --- a/oneflow/user/kernels/fold_kernel.cpp +++ b/oneflow/user/kernels/fold_kernel.cpp @@ -60,7 +60,7 @@ class FoldKernel final : public OpKernel { ~FoldKernel() = default; private: - void Compute(KernelComputeContext* ctx, OpKernelState* state) const override { + void Compute(KernelComputeContext* ctx) const override { const Tensor* input = ctx->Tensor4ArgNameAndIndex("x", 0); Tensor* output = ctx->Tensor4ArgNameAndIndex("y", 0); @@ -101,4 +101,4 @@ REGISTER_FOLD_KERNEL(DeviceType::kCUDA, double) } // namespace user_op -} // namespace oneflow \ No newline at end of file +} // namespace oneflow diff --git a/oneflow/user/kernels/gather_kernel.cpp b/oneflow/user/kernels/gather_kernel.cpp index 1354f3e8313..6a4a6f105e6 100644 --- a/oneflow/user/kernels/gather_kernel.cpp +++ b/oneflow/user/kernels/gather_kernel.cpp @@ -28,10 +28,10 @@ Shape GetFlatShape(const ShapeView& shape, int64_t axis) { return Shape({shape.Count(0, axis), shape.At(axis), shape.Count(axis + 1)}); } -class GatherOpKernelState final : public user_op::OpKernelState { +class GatherOpKernelCache final : public user_op::OpKernelCache { public: - GatherOpKernelState(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} - ~GatherOpKernelState() override = default; + GatherOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} + ~GatherOpKernelCache() override = default; int64_t lower() const { return lower_; } int64_t upper() const { return upper_; } @@ -64,8 +64,8 @@ class GatherKernel final : public user_op::OpKernel, public user_op::CudaGraphSu GatherKernel() = default; ~GatherKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { if (ctx->parallel_ctx().parallel_num() > 1) { const auto axis = ctx->Attr("axis"); const cfg::NdSbp& in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); @@ -75,14 +75,15 @@ class GatherKernel final : public user_op::OpKernel, public user_op::CudaGraphSu const TensorDesc* in_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex("in", 0); TensorSliceView view = GetTensorSliceView4ParallelId( hierarchy, in_nd_sbp, in_logical_desc->shape(), ctx->parallel_ctx().parallel_id()); - return std::make_shared(view.At(axis).begin(), view.At(axis).end()); + return std::make_shared(view.At(axis).begin(), view.At(axis).end()); } else { - return 
std::shared_ptr(nullptr); + return nullptr; } } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* indices = ctx->Tensor4ArgNameAndIndex("indices", 0); const int64_t axis = ctx->Attr("axis"); @@ -91,11 +92,11 @@ class GatherKernel final : public user_op::OpKernel, public user_op::CudaGraphSu if (out->shape().elem_cnt() == 0) { return; } int64_t offset = 0; - if (state != nullptr) { - auto* gather_state = dynamic_cast(state); - CHECK_NOTNULL(gather_state); - CHECK_EQ(in->shape().At(axis), gather_state->upper() - gather_state->lower()); - offset = gather_state->lower(); + if (cache != nullptr) { + auto* gather_cache = dynamic_cast(cache); + CHECK_NOTNULL(gather_cache); + CHECK_EQ(in->shape().At(axis), gather_cache->upper() - gather_cache->lower()); + offset = gather_cache->lower(); } GatherKernelUtilImpl::Forward(ctx->stream(), indices->dptr(), num_indices, diff --git a/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cpp b/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cpp index bba320f106a..a804d9076dd 100644 --- a/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cpp +++ b/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cpp @@ -16,7 +16,7 @@ limitations under the License. #include #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { @@ -32,7 +32,8 @@ class GenerateRandomBatchPermutationIndicesCPUKernel final : public user_op::OpK } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* random_generator = dynamic_cast*>(state); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); std::iota(y->mut_dptr(), y->mut_dptr() + y->shape().elem_cnt(), 0); diff --git a/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cu b/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cu index eac6e759ade..baa2ae9586f 100644 --- a/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cu +++ b/oneflow/user/kernels/generate_random_batch_permutation_indices_kernel.cu @@ -17,7 +17,7 @@ limitations under the License. 
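The gather_kernel.cpp hunk above keeps the model-parallel gather logic while moving it from kernel state to the new kernel cache: the cache records the [lower, upper) range of the gather axis owned by the current rank, and Compute() shifts every global index by lower. A minimal standalone sketch of that index arithmetic (helper name and signature are illustrative, not part of this patch):

  #include <cstdint>

  // Maps a global gather index to a row of the local, split input tensor.
  // Rows outside [lower, upper) live on other ranks; a real kernel typically writes
  // zeros for them and relies on a later partial-sum style reduction.
  inline int64_t GlobalToLocalRow(int64_t global_index, int64_t lower, int64_t upper) {
    if (global_index < lower || global_index >= upper) { return -1; }  // not owned locally
    return global_index - lower;
  }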
#include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/random_generator.h" #include "oneflow/user/kernels/radix_sort.cuh" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/core/ep/cuda/cuda_stream.h" namespace oneflow { @@ -91,7 +91,8 @@ class GenerateRandomBatchPermutationIndicesGPUKernel final : public user_op::OpK private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* random_generator = dynamic_cast>*>(state); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); diff --git a/oneflow/user/kernels/gpt_data_loader_kernel.cpp b/oneflow/user/kernels/gpt_data_loader_kernel.cpp index 155091a595a..c04544c647f 100644 --- a/oneflow/user/kernels/gpt_data_loader_kernel.cpp +++ b/oneflow/user/kernels/gpt_data_loader_kernel.cpp @@ -134,7 +134,8 @@ class GPTDataLoaderKernel final : public OpKernel { } private: - void Compute(KernelComputeContext* ctx, OpKernelState* state) const override { + void Compute(KernelComputeContext* ctx, OpKernelState* state, + const OpKernelCache*) const override { auto* loader = dynamic_cast(state); user_op::Tensor* iteration_tensor = ctx->Tensor4ArgNameAndIndex("iteration", 0); user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); diff --git a/oneflow/user/kernels/group_conv_kernel.cpp b/oneflow/user/kernels/group_conv_kernel.cpp index ab60efe4708..ce759b44aee 100644 --- a/oneflow/user/kernels/group_conv_kernel.cpp +++ b/oneflow/user/kernels/group_conv_kernel.cpp @@ -306,7 +306,7 @@ struct ConvKernelUtil final { }; template -struct ConvOpKernelState final : public user_op::OpKernelState { +struct ConvOpKernelCache final : public user_op::OpKernelCache { Im2ColFunc im2col_func_ = ConvKernelUtil::NCDHWIm2Col; Col2ImFunc col2im_func_ = ConvKernelUtil::NCDHWCol2Im; GemmFunc forward_func_ = Gemm4ChannelLast; @@ -323,31 +323,16 @@ struct ConvOpKernelState final : public user_op::OpKernelState { int32_t idx_offset_ = 0; bool is_dynamic_ = false; int32_t groups = 1; - - void Update(const ShapeView& x_shape, const ShapeView& out_shape) { - auto Gen5DShape = [](const ShapeView& shape, int32_t idx_offset) -> Shape { - DimVector ret_vec; - shape.ToDimVector(&ret_vec); - int32_t ndims = ret_vec.size() - 2; - ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1); - return Shape(ret_vec); - }; - if (is_dynamic_) { - Shape in_shape; - in_5d_shape_ = Gen5DShape(x_shape, idx_offset_); - out_5d_shape_ = Gen5DShape(out_shape, idx_offset_); - } - } }; template -std::shared_ptr> CreateConvOpKernelState(user_op::KernelComputeContext* ctx, +std::shared_ptr> CreateConvOpKernelCache(user_op::KernelCacheContext* ctx, const std::string& in_name, const std::string& out_name, const std::string& weight_name) { const auto& data_format = ctx->Attr("data_format"); - std::shared_ptr> state(new ConvOpKernelState()); + std::shared_ptr> state(new ConvOpKernelCache()); if (data_format == "channels_first") { state->im2col_func_ = ConvKernelUtil::NCDHWIm2Col; state->col2im_func_ = ConvKernelUtil::NCDHWCol2Im; @@ -410,9 +395,15 @@ class ConvCpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& conv_state = 
CreateConvOpKernelState(ctx, "in", "out", "weight"); - CHECK_NOTNULL(conv_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateConvOpKernelCache(ctx, "in", "out", "weight"); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + const auto* conv_cache = dynamic_cast*>(cache); + CHECK_NOTNULL(conv_cache); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); @@ -420,32 +411,32 @@ class ConvCpuKernel final : public user_op::OpKernel { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); T* col_buf_dptr = tmp_buffer->mut_dptr(); - int32_t idx_offset = conv_state->idx_offset_; - const int32_t input_group_interval = in->shape().At(1) / conv_state->groups; - const int32_t weight_group_interval = weight->shape().At(0) / conv_state->groups; - const int32_t output_group_interval = out->shape().At(1) / conv_state->groups; + int32_t idx_offset = conv_cache->idx_offset_; + const int32_t input_group_interval = in->shape().At(1) / conv_cache->groups; + const int32_t weight_group_interval = weight->shape().At(0) / conv_cache->groups; + const int32_t output_group_interval = out->shape().At(1) / conv_cache->groups; const int32_t input_step = input_group_interval * in->shape().Count(2); const int32_t weight_step = weight_group_interval * weight->shape().Count(1); const int32_t output_step = output_group_interval * out->shape().Count(2); - const int32_t m = conv_state->weight_5d_shape_.At(0) / conv_state->groups; - const int32_t n = conv_state->out_5d_shape_.Count(idx_offset, idx_offset + 3); - const int32_t k = conv_state->weight_5d_shape_.Count(1); + const int32_t m = conv_cache->weight_5d_shape_.At(0) / conv_cache->groups; + const int32_t n = conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3); + const int32_t k = conv_cache->weight_5d_shape_.Count(1); bool is_bias_mul_inited = false; for (int64_t i = 0; i < in->shape().At(0); ++i) { const T* input_ptr = GetImgDptr(in, i); const T* weight_ptr = weight->dptr(); T* output_ptr = GetImgMutDptr(out, i); - for (int64_t g = 0; g < conv_state->groups; g++) { - conv_state->im2col_func_( - input_ptr, ShapeView(conv_state->in_5d_shape_), ShapeView(conv_state->weight_5d_shape_), - ShapeView(conv_state->out_5d_shape_), conv_state->strides_3d_.data(), - conv_state->dilation_rate_3d_.data(), conv_state->padding_before_3d_.data(), + for (int64_t g = 0; g < conv_cache->groups; g++) { + conv_cache->im2col_func_( + input_ptr, ShapeView(conv_cache->in_5d_shape_), ShapeView(conv_cache->weight_5d_shape_), + ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(), + conv_cache->dilation_rate_3d_.data(), conv_cache->padding_before_3d_.data(), col_buf_dptr); // channels first: out = weight * col_buf // channels last: out = (weight * col_buf)(T) - conv_state->forward_func_(CblasNoTrans, CblasNoTrans, + conv_cache->forward_func_(CblasNoTrans, CblasNoTrans, m, // filter / groups n, // od * oh * ow k, // ci * kd * kh * kw / groups @@ -470,10 +461,10 @@ class ConvCpuKernel final : public user_op::OpKernel { // channels first: out += bias * bias_mul // channels last: out += (bias * bias_mul)(T) - conv_state->forward_func_( + conv_cache->forward_func_( CblasNoTrans, CblasNoTrans, - conv_state->weight_5d_shape_.At(0), // filter - conv_state->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow + 
conv_cache->weight_5d_shape_.At(0), // filter + conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3), // od * oh * ow 1, // 1 static_cast(1), bias->dptr(), bias_mul_dptr, static_cast(1), GetImgMutDptr(out, i)); @@ -522,25 +513,31 @@ class ConvDataGradCpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& conv_state = CreateConvOpKernelState(ctx, "dx", "dy", "filter"); - CHECK_NOTNULL(conv_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateConvOpKernelCache(ctx, "dx", "dy", "filter"); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + const auto* conv_cache = dynamic_cast*>(cache); + CHECK_NOTNULL(conv_cache); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* filter = ctx->Tensor4ArgNameAndIndex("filter", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - int32_t idx_offset = conv_state->idx_offset_; - const int32_t dy_group_interval = dy->shape().At(1) / conv_state->groups; - const int32_t filter_group_interval = filter->shape().At(0) / conv_state->groups; - const int32_t dx_group_interval = dx->shape().At(1) / conv_state->groups; + int32_t idx_offset = conv_cache->idx_offset_; + const int32_t dy_group_interval = dy->shape().At(1) / conv_cache->groups; + const int32_t filter_group_interval = filter->shape().At(0) / conv_cache->groups; + const int32_t dx_group_interval = dx->shape().At(1) / conv_cache->groups; const int32_t dx_step = dx_group_interval * dx->shape().Count(2); const int32_t filter_step = filter_group_interval * filter->shape().Count(1); const int32_t dy_step = dy_group_interval * dy->shape().Count(2); - const int32_t m = conv_state->weight_5d_shape_.Count(1); - const int32_t n = conv_state->out_5d_shape_.Count(idx_offset, idx_offset + 3); - const int32_t k = conv_state->weight_5d_shape_.At(0) / conv_state->groups; + const int32_t m = conv_cache->weight_5d_shape_.Count(1); + const int32_t n = conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3); + const int32_t k = conv_cache->weight_5d_shape_.At(0) / conv_cache->groups; Memset(ctx->stream(), dx->mut_dptr(), 0, dx->shape().elem_cnt() * sizeof(T)); @@ -549,22 +546,22 @@ class ConvDataGradCpuKernel final : public user_op::OpKernel { const T* filter_ptr = filter->dptr(); const T* dy_ptr = GetImgDptr(dy, i); T* dx_ptr = GetImgMutDptr(dx, i); - FOR_RANGE(int64_t, g, 0, conv_state->groups) { + FOR_RANGE(int64_t, g, 0, conv_cache->groups) { // channels first: col_buf' = weight(T) * out[i]' // channels last : col_buf' = weight(T) * out[i]'(T) NewKernelUtil::OFGemm( - nullptr, CblasTrans, conv_state->is_out_diff_need_trans_, + nullptr, CblasTrans, conv_cache->is_out_diff_need_trans_, m, // ci * kd * kh * kw / groups n, // od * oh * ow k, // filter / groups static_cast(1), filter_ptr, dy_ptr, static_cast(0), col_buf->mut_dptr()); // in' = col2im(col_buf') - conv_state->col2im_func_( - col_buf->dptr(), ShapeView(conv_state->in_5d_shape_), - ShapeView(conv_state->weight_5d_shape_), ShapeView(conv_state->out_5d_shape_), - conv_state->strides_3d_.data(), conv_state->dilation_rate_3d_.data(), - conv_state->padding_before_3d_.data(), dx_ptr); + conv_cache->col2im_func_( + col_buf->dptr(), 
ShapeView(conv_cache->in_5d_shape_), + ShapeView(conv_cache->weight_5d_shape_), ShapeView(conv_cache->out_5d_shape_), + conv_cache->strides_3d_.data(), conv_cache->dilation_rate_3d_.data(), + conv_cache->padding_before_3d_.data(), dx_ptr); filter_ptr += filter_step; dy_ptr += dy_step; dx_ptr += dx_step; @@ -614,24 +611,30 @@ class ConvFilterGradCpuKernel final : public user_op::OpKernel { bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } private: - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& conv_state = CreateConvOpKernelState(ctx, "x", "dy", "filter_diff"); - CHECK_NOTNULL(conv_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateConvOpKernelCache(ctx, "x", "dy", "filter_diff"); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + const auto* conv_cache = dynamic_cast*>(cache); + CHECK_NOTNULL(conv_cache); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* filter_diff = ctx->Tensor4ArgNameAndIndex("filter_diff", 0); user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - int32_t idx_offset = conv_state->idx_offset_; - const int32_t dy_group_interval = dy->shape().At(1) / conv_state->groups; - const int32_t filter_diff_group_interval = filter_diff->shape().At(0) / conv_state->groups; - const int32_t x_group_interval = x->shape().At(1) / conv_state->groups; + int32_t idx_offset = conv_cache->idx_offset_; + const int32_t dy_group_interval = dy->shape().At(1) / conv_cache->groups; + const int32_t filter_diff_group_interval = filter_diff->shape().At(0) / conv_cache->groups; + const int32_t x_group_interval = x->shape().At(1) / conv_cache->groups; const int32_t x_step = x_group_interval * x->shape().Count(2); const int32_t dy_step = dy_group_interval * dy->shape().Count(2); const int32_t filter_diff_step = filter_diff_group_interval * filter_diff->shape().Count(1); - const int32_t m = conv_state->weight_5d_shape_.At(0) / conv_state->groups; - const int32_t n = conv_state->weight_5d_shape_.Count(1); - const int32_t k = conv_state->out_5d_shape_.Count(idx_offset, idx_offset + 3); + const int32_t m = conv_cache->weight_5d_shape_.At(0) / conv_cache->groups; + const int32_t n = conv_cache->weight_5d_shape_.Count(1); + const int32_t k = conv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3); Memset(ctx->stream(), filter_diff->mut_dptr(), 0, filter_diff->shape().elem_cnt() * sizeof(T)); @@ -639,17 +642,17 @@ class ConvFilterGradCpuKernel final : public user_op::OpKernel { const T* x_ptr = GetImgDptr(x, i); const T* dy_ptr = GetImgDptr(dy, i); T* filter_diff_ptr = filter_diff->mut_dptr(); - FOR_RANGE(int64_t, g, 0, conv_state->groups) { - conv_state->im2col_func_( - x_ptr, ShapeView(conv_state->in_5d_shape_), ShapeView(conv_state->weight_5d_shape_), - ShapeView(conv_state->out_5d_shape_), conv_state->strides_3d_.data(), - conv_state->dilation_rate_3d_.data(), conv_state->padding_before_3d_.data(), + FOR_RANGE(int64_t, g, 0, conv_cache->groups) { + conv_cache->im2col_func_( + x_ptr, ShapeView(conv_cache->in_5d_shape_), ShapeView(conv_cache->weight_5d_shape_), + ShapeView(conv_cache->out_5d_shape_), conv_cache->strides_3d_.data(), + conv_cache->dilation_rate_3d_.data(), conv_cache->padding_before_3d_.data(), col_buf->mut_dptr()); // channels first: weight' += out[i]' * 
col_buf(T) // channels last : weight' += out[i]'(T) * col_buf(T) NewKernelUtil::OFGemm( - nullptr, conv_state->is_out_diff_need_trans_, CblasTrans, + nullptr, conv_cache->is_out_diff_need_trans_, CblasTrans, m, // filter / groups n, // ci * kd * kh * kw k, // od * oh * ow / groups diff --git a/oneflow/user/kernels/group_deconv_kernel.cpp b/oneflow/user/kernels/group_deconv_kernel.cpp new file mode 100644 index 00000000000..80d0d361bd7 --- /dev/null +++ b/oneflow/user/kernels/group_deconv_kernel.cpp @@ -0,0 +1,429 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/job/lazy_mode.h" +#include "oneflow/user/ops/nn_util.h" +#include "oneflow/core/kernel/new_kernel_util.h" +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + +namespace { + +template +using Col2ImFunc = void (*)(const T* col_buf, const ShapeView& in_shape, + const ShapeView& weight_shape, const ShapeView& out_shape, + const int32_t* strides, const int32_t* dilation_rate, + const int32_t* padding_before, T* in_diff_ptr); + +template +void Gemm4ChannelFirst(enum CBLAS_TRANSPOSE trans_a, enum CBLAS_TRANSPOSE trans_b, const int m, + const int n, const int k, const T alpha, const T* a, const T* b, + const T beta, T* c) { + NewKernelUtil::OFGemm(nullptr, trans_a, trans_b, m, n, k, alpha, a, b, beta, c); +} + +template +void Gemm4ChannelLast(enum CBLAS_TRANSPOSE trans_a, enum CBLAS_TRANSPOSE trans_b, const int m, + const int n, const int k, const T alpha, const T* a, const T* b, const T beta, + T* c) { + trans_a = (trans_a == CblasNoTrans) ? CblasTrans : CblasNoTrans; + trans_b = (trans_b == CblasNoTrans) ? 
CblasTrans : CblasNoTrans; + NewKernelUtil::OFGemm(nullptr, trans_b, trans_a, n, m, k, alpha, b, a, beta, c); +} + +template +T* GetImgMutDptr(user_op::Tensor* tensor, int64_t idx) { + return tensor->mut_dptr() + tensor->shape().Count(1) * idx; +} + +template +const T* GetImgDptr(const user_op::Tensor* tensor, int64_t idx) { + return tensor->dptr() + tensor->shape().Count(1) * idx; +} + +size_t CalcElemNumOfColBuf(const ShapeView& out_shape, const ShapeView& weight_shape, + const int32_t idx_offset) { + int64_t col_buf_elem_cnt = 1; + int64_t ndims = out_shape.NumAxes() - 2; + for (size_t i = 0; i != ndims + 1; ++i) { col_buf_elem_cnt *= weight_shape.At(i + 1); } + for (size_t i = 0; i != ndims; ++i) { col_buf_elem_cnt *= out_shape.At(idx_offset + i); } + return col_buf_elem_cnt; +} + +template +class ColBufWriter { + public: + ColBufWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size, + int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size) + : src_ptr_(src_ptr), + dst_ptr_(dst_ptr), + c_size_(c_size), + id_size_(id_size), + ih_size_(ih_size), + iw_size_(iw_size), + od_size_(od_size), + oh_size_(oh_size), + ow_size_(ow_size) {} + virtual ~ColBufWriter() = default; + virtual void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0; + virtual void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) = 0; + virtual void InvalidDFunc() = 0; + virtual void InvalidHFunc() = 0; + virtual void InvalidWFunc() = 0; + virtual void NextImCSize() = 0; + + protected: + const T* src_ptr_; + T* dst_ptr_; + int64_t c_size_ = 0; + int64_t id_size_ = 0; + int64_t ih_size_ = 0; + int64_t iw_size_ = 0; + int64_t od_size_ = 0; + int64_t oh_size_ = 0; + int64_t ow_size_ = 0; +}; + +template +class Col2ImWriter final : public ColBufWriter { + public: + Col2ImWriter(const T* src_ptr, T* dst_ptr, int64_t c_size, int64_t id_size, int64_t ih_size, + int64_t iw_size, int64_t od_size, int64_t oh_size, int64_t ow_size) + : ColBufWriter::ColBufWriter(src_ptr, dst_ptr, c_size, id_size, ih_size, iw_size, od_size, + oh_size, ow_size) {} + ~Col2ImWriter() = default; + void DHWCWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override { + this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw * this->iw_size_ + c] += + *(this->src_ptr_++); + } + void CDHWWrite(int64_t c, int64_t id, int64_t ih, int64_t iw) override { + this->dst_ptr_[id * this->id_size_ + ih * this->ih_size_ + iw] += *(this->src_ptr_++); + } + void InvalidDFunc() override { this->src_ptr_ += this->od_size_; } + void InvalidHFunc() override { this->src_ptr_ += this->oh_size_; } + void InvalidWFunc() override { this->src_ptr_ += this->ow_size_; } + void NextImCSize() override { this->dst_ptr_ += this->c_size_; } +}; + +template +using DHWValidFunc = void (ColBufWriter::*)(int64_t c, int64_t kd, int64_t kh, int64_t kw); + +template +class ColBufUtil final { + public: + ColBufUtil(const ShapeView& in_shape, const ShapeView& out_shape, int32_t dhw_offset, + const int32_t* strides, const int32_t* dilation_rate, const int32_t* padding_before, + const int32_t id_num, const int32_t ih_num, const int32_t iw_num, const int32_t od_num, + const int32_t oh_num, const int32_t ow_num) + : strides_(strides), + dilation_rate_(dilation_rate), + padding_before_(padding_before), + id_num_(id_num), + ih_num_(ih_num), + iw_num_(iw_num), + od_num_(od_num), + oh_num_(oh_num), + ow_num_(ow_num) { + if (dhw_offset == 2) { + dhw_valid_func_ = &ColBufWriter::CDHWWrite; + } else { + dhw_valid_func_ = 
&ColBufWriter::DHWCWrite; + } + } + void operator()(ColBufWriter* col_buf_writer, int64_t c, int64_t kd, int64_t kh, int64_t kw) { + int64_t id = kd * dilation_rate_[0] - padding_before_[0]; + FOR_RANGE(int64_t, od, 0, od_num_) { + if (id < 0 || id >= id_num_) { + col_buf_writer->InvalidDFunc(); + } else { + int64_t ih = kh * dilation_rate_[1] - padding_before_[1]; + FOR_RANGE(int64_t, oh, 0, oh_num_) { + if (ih < 0 || ih >= ih_num_) { + col_buf_writer->InvalidHFunc(); + } else { + int64_t iw = kw * dilation_rate_[2] - padding_before_[2]; + FOR_RANGE(int64_t, ow, 0, ow_num_) { + if (iw < 0 || iw >= iw_num_) { + col_buf_writer->InvalidWFunc(); + } else { + (col_buf_writer->*dhw_valid_func_)(c, id, ih, iw); + } + iw += strides_[2]; + } + } + ih += strides_[1]; + } + } + id += strides_[0]; + } + } + + private: + const int32_t* strides_; + const int32_t* dilation_rate_; + const int32_t* padding_before_; + DHWValidFunc dhw_valid_func_; + int64_t id_num_ = 0; + int64_t ih_num_ = 0; + int64_t iw_num_ = 0; + int64_t od_num_ = 0; + int64_t oh_num_ = 0; + int64_t ow_num_ = 0; +}; + +template +struct DeconvKernelUtil final { + public: + static void NCDHWCol2Im(const T* col_buf_ptr, const ShapeView& in_shape, + const ShapeView& weight_shape, const ShapeView& out_shape, + const int32_t* strides, const int32_t* dilation_rate, + const int32_t* padding_before, T* in_diff_ptr) { + ColBufUtil col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before, + in_shape.At(2), in_shape.At(3), in_shape.At(4), out_shape.At(2), + out_shape.At(3), out_shape.At(4)); + Col2ImWriter col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(3), + in_shape.Count(4), 1, out_shape.Count(3), out_shape.Count(4), 1); + DoNCDWHFunc(weight_shape, col_buf_util, &col_buf_writer); + } + + static void NDHWCCol2Im(const T* col_buf_ptr, const ShapeView& in_shape, + const ShapeView& weight_shape, const ShapeView& out_shape, + const int32_t* strides, const int32_t* dilation_rate, + const int32_t* padding_before, T* in_diff_ptr) { + ColBufUtil col_buf_util(in_shape, out_shape, 2, strides, dilation_rate, padding_before, + in_shape.At(2), in_shape.At(3), in_shape.At(4), out_shape.At(2), + out_shape.At(3), out_shape.At(4)); + Col2ImWriter col_buf_writer(col_buf_ptr, in_diff_ptr, in_shape.Count(2), in_shape.Count(2), + in_shape.Count(3), in_shape.Count(4), out_shape.Count(2, 4), + out_shape.Count(3, 4), 1); + DoNDWHCFunc(weight_shape, col_buf_util, &col_buf_writer); + } + + private: + static void DoNCDWHFunc(const ShapeView& weight_shape, ColBufUtil& col_buf_util, + ColBufWriter* col_buf_writer) { + for (int64_t c = 0; c != weight_shape.At(1); col_buf_writer->NextImCSize(), ++c) { + for (int64_t kd = 0; kd != weight_shape.At(2); ++kd) { + for (int64_t kh = 0; kh != weight_shape.At(3); ++kh) { + for (int64_t kw = 0; kw != weight_shape.At(4); ++kw) { + col_buf_util(col_buf_writer, c, kd, kh, kw); + } + } + } + } + } + + static void DoNDWHCFunc(const ShapeView& weight_shape, ColBufUtil& col_buf_util, + ColBufWriter* col_buf_writer) { + for (int64_t kd = 0; kd != weight_shape.At(1); ++kd) { + for (int64_t kh = 0; kh != weight_shape.At(2); ++kh) { + for (int64_t kw = 0; kw != weight_shape.At(3); ++kw) { + for (int64_t c = 0; c != weight_shape.At(4); ++c) { + col_buf_util(col_buf_writer, c, kd, kh, kw); + } + } + } + } + } +}; + +template +struct DeconvOpKernelCache final : public user_op::OpKernelCache { + Col2ImFunc col2im_func_ = DeconvKernelUtil::NCDHWCol2Im; + ; + + Shape in_5d_shape_; + Shape out_5d_shape_; + 
Shape weight_5d_shape_; + + std::vector strides_3d_; + std::vector dilation_rate_3d_; + std::vector padding_before_3d_; + + enum CBLAS_TRANSPOSE is_out_diff_need_trans_ = CblasNoTrans; + int32_t idx_offset_ = 0; + bool is_dynamic_ = false; + int32_t groups = 1; + + void Update(const ShapeView& x_shape, const ShapeView& out_shape) { + auto Gen5DShape = [](const ShapeView& shape, int32_t idx_offset) -> Shape { + DimVector ret_vec; + shape.ToDimVector(&ret_vec); + int32_t ndims = ret_vec.size() - 2; + ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1); + return Shape(ret_vec); + }; + if (is_dynamic_) { + Shape in_shape; + in_5d_shape_ = Gen5DShape(x_shape, idx_offset_); + out_5d_shape_ = Gen5DShape(out_shape, idx_offset_); + } + } +}; + +template +std::shared_ptr> CreateDeconvOpKernelCache(user_op::KernelCacheContext* ctx, + const std::string& in_name, + const std::string& out_name, + const std::string& weight_name) { + const auto& data_format = ctx->Attr("data_format"); + + std::shared_ptr> state(new DeconvOpKernelCache()); + if (data_format == "channels_first") { + state->col2im_func_ = DeconvKernelUtil::NCDHWCol2Im; + state->is_out_diff_need_trans_ = CblasNoTrans; + state->idx_offset_ = 2; + } else { + state->col2im_func_ = DeconvKernelUtil::NDHWCCol2Im; + state->is_out_diff_need_trans_ = CblasTrans; + state->idx_offset_ = 1; + } + + auto Gen5DShape = [](const Shape& shape, int32_t idx_offset) -> Shape { + DimVector ret_vec(shape.dim_vec()); + int32_t ndims = ret_vec.size() - 2; + ret_vec.insert(ret_vec.begin() + idx_offset, 3 - ndims, 1); + return Shape(ret_vec); + }; + state->groups = ctx->Attr("groups"); + + state->in_5d_shape_ = + Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->shape(), state->idx_offset_); + state->out_5d_shape_ = + Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(out_name, 0)->shape(), state->idx_offset_); + state->weight_5d_shape_ = + Gen5DShape(ctx->TensorDesc4ArgNameAndIndex(weight_name, 0)->shape(), state->idx_offset_); + + auto Gen3DVec = [](const std::vector& origin_vec) -> std::vector { + std::vector ret_vec = origin_vec; + ret_vec.insert(ret_vec.begin(), 3 - ret_vec.size(), 1); + return ret_vec; + }; + state->strides_3d_ = Gen3DVec(ctx->Attr>("strides")); + state->dilation_rate_3d_ = Gen3DVec(ctx->Attr>("dilation_rate")); + state->is_dynamic_ = ctx->TensorDesc4ArgNameAndIndex(in_name, 0)->is_dynamic(); + const auto& padding_before = ctx->Attr>("padding_before"); + FOR_RANGE(uint8_t, dim, 0, 3) { + int64_t index = static_cast(dim) - (3 - padding_before.size()); + if (index < 0) { + state->padding_before_3d_.emplace_back(0); + } else { + state->padding_before_3d_.emplace_back(padding_before.at(index)); + } + } + + return state; +} + +template +class DeconvCpuKernel final : public user_op::OpKernel { + public: + OF_DISALLOW_COPY_AND_MOVE(DeconvCpuKernel); + DeconvCpuKernel() = default; + ~DeconvCpuKernel() = default; + + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } + + using user_op::OpKernel::InitOpKernelCache; + void InitOpKernelCache(user_op::KernelCacheContext* ctx, int8_t flag, + std::shared_ptr* cache_ptr) const override { + if (*cache_ptr != nullptr && (flag & user_op::OpKernelCache::kAttrNotChanged)) { + auto deconv_cache = std::dynamic_pointer_cast>(*cache_ptr); + deconv_cache->Update(ctx->TensorDesc4ArgNameAndIndex("in", 0)->shape(), + ctx->TensorDesc4ArgNameAndIndex("out", 0)->shape()); + return; + } + *cache_ptr = CreateDeconvOpKernelCache(ctx, "out", "in", "weight"); + } + + private: + void 
Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + auto deconv_cache = dynamic_cast*>(cache); + CHECK_NOTNULL(deconv_cache); + const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); + const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + user_op::Tensor* col_buf = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + + int32_t idx_offset = deconv_cache->idx_offset_; + const int32_t input_group_interval = in->shape().At(1) / deconv_cache->groups; + const int32_t weight_group_interval = weight->shape().At(0) / deconv_cache->groups; + const int32_t output_group_interval = out->shape().At(1) / deconv_cache->groups; + const int32_t input_step = input_group_interval * in->shape().Count(2); + const int32_t weight_step = weight_group_interval * weight->shape().Count(1); + const int32_t output_step = output_group_interval * out->shape().Count(2); + const int32_t m = deconv_cache->weight_5d_shape_.Count(1); + const int32_t n = deconv_cache->out_5d_shape_.Count(idx_offset, idx_offset + 3); + const int32_t k = deconv_cache->weight_5d_shape_.At(0) / deconv_cache->groups; + + Memset(ctx->stream(), out->mut_dptr(), 0, + out->shape().elem_cnt() * sizeof(T)); + FOR_RANGE(int64_t, i, 0, in->shape().At(0)) { + const T* input_ptr = GetImgDptr(in, i); + const T* weight_ptr = weight->dptr(); + T* output_ptr = GetImgMutDptr(out, i); + + FOR_RANGE(int64_t, g, 0, deconv_cache->groups) { + NewKernelUtil::OFGemm( + ctx->stream(), CblasTrans, deconv_cache->is_out_diff_need_trans_, + + m, // (co / groups) * kd * kh * kw + n, // od * oh * ow + k, // filter / groups + static_cast(1), weight_ptr, input_ptr, static_cast(0), col_buf->mut_dptr()); + // out = col2im(col_buf') + deconv_cache->col2im_func_( + col_buf->mut_dptr(), ShapeView(deconv_cache->in_5d_shape_), + ShapeView(deconv_cache->weight_5d_shape_), ShapeView(deconv_cache->out_5d_shape_), + deconv_cache->strides_3d_.data(), deconv_cache->dilation_rate_3d_.data(), + deconv_cache->padding_before_3d_.data(), output_ptr); + input_ptr += input_step; + weight_ptr += weight_step; + output_ptr += output_step; + } + } + } +}; + +#define REGISTER_DECONV_DATA_KERNEL(op_name, dtype) \ + REGISTER_USER_KERNEL(#op_name) \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ + && (user_op::HobAttr("groups") > 1) \ + && (user_op::HobDataType("out", 0) == GetDataType::value)) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ + size_t tmp_buffer_size = 0; \ + const auto& in_shape = ctx->InputTensorDesc("in", 0).shape(); \ + const auto& weight_shape = ctx->InputTensorDesc("weight", 0).shape(); \ + \ + int64_t idx_offset = IdxOffset(ctx->Attr("data_format")); \ + tmp_buffer_size += \ + CalcElemNumOfColBuf(in_shape, weight_shape, idx_offset) * sizeof(dtype); \ + return tmp_buffer_size; \ + }) + +REGISTER_DECONV_DATA_KERNEL(deconv1d, float); +REGISTER_DECONV_DATA_KERNEL(deconv1d, double); +REGISTER_DECONV_DATA_KERNEL(deconv2d, float); +REGISTER_DECONV_DATA_KERNEL(deconv2d, double); +REGISTER_DECONV_DATA_KERNEL(deconv3d, float); +REGISTER_DECONV_DATA_KERNEL(deconv3d, double); + +} // namespace + +} // namespace oneflow diff --git a/oneflow/user/kernels/heap_selection_top_k_kernel.cu b/oneflow/user/kernels/heap_selection_top_k_kernel.cu index a0e5582df56..aa4c32c7829 100644 --- a/oneflow/user/kernels/heap_selection_top_k_kernel.cu +++ 
b/oneflow/user/kernels/heap_selection_top_k_kernel.cu @@ -21,7 +21,7 @@ namespace oneflow { namespace { template -T PowOf2Floor(T val, int32_t max_power) { +T PowOf2Floor(T val, int64_t max_power) { CHECK_GT(val, GetZeroVal()); T max_floor = static_cast(std::pow(2, max_power)); val = std::min(val, max_floor); @@ -33,7 +33,7 @@ T PowOf2Floor(T val, int32_t max_power) { } template -T PowOf2Ceil(T val, int32_t max_power) { +T PowOf2Ceil(T val, int64_t max_power) { CHECK_GT(val, GetZeroVal()); T max_ceil = static_cast(std::pow(2, max_power)); val = std::min(val, max_ceil); @@ -45,7 +45,7 @@ T PowOf2Ceil(T val, int32_t max_power) { } template -__device__ void BitonicSwap(T* data, const int32_t i, const int32_t j, const bool dir, +__device__ void BitonicSwap(T* data, const int64_t i, const int64_t j, const bool dir, const Compare& comp) { if (comp(data[i], data[j]) == dir) { T tmp = data[i]; @@ -56,15 +56,15 @@ __device__ void BitonicSwap(T* data, const int32_t i, const int32_t j, const boo // https://en.wikipedia.org/wiki/Bitonic_sorter template -__device__ void BitonicSort(T* data, const int32_t elem_cnt, const Compare& comp) { +__device__ void BitonicSort(T* data, const int64_t elem_cnt, const Compare& comp) { // The element count of instance should be pow-of-2 assert(elem_cnt > 0 && !(elem_cnt & (elem_cnt - 1))); // Generate a bitonic sequence from input - for (int32_t size = 2; size <= elem_cnt / 2; size *= 2) { + for (int64_t size = 2; size <= elem_cnt / 2; size *= 2) { // Merge 2 bitonic sequences of length 'size' into a bitonic sequence of length '2 * size' - for (int32_t stride = size / 2; stride > 0; stride /= 2) { - for (int32_t swap_id = threadIdx.x; swap_id < elem_cnt / 2; swap_id += blockDim.x) { + for (int64_t stride = size / 2; stride > 0; stride /= 2) { + for (int64_t swap_id = threadIdx.x; swap_id < elem_cnt / 2; swap_id += blockDim.x) { // Change dir at intervals of 'size / 2' swaps const bool dir = swap_id & (size / 2); // Locate the pair {pos, pos + stride} which is going te be swaped if needed @@ -78,8 +78,8 @@ __device__ void BitonicSort(T* data, const int32_t elem_cnt, const Compare& comp } // Sort the bitonic sequence - for (int32_t stride = elem_cnt / 2; stride > 0; stride /= 2) { - for (int32_t swap_id = threadIdx.x; swap_id < elem_cnt / 2; swap_id += blockDim.x) { + for (int64_t stride = elem_cnt / 2; stride > 0; stride /= 2) { + for (int64_t swap_id = threadIdx.x; swap_id < elem_cnt / 2; swap_id += blockDim.x) { // Locate the pair {pos, pos + stride} which is going te be swaped if needed const int pos = 2 * swap_id - (swap_id & (stride - 1)); @@ -93,11 +93,11 @@ __device__ void BitonicSort(T* data, const int32_t elem_cnt, const Compare& comp template class Entry final { public: - __device__ __forceinline__ Entry(int32_t index, T value) : index_(index), value_(value) {} + __device__ __forceinline__ Entry(int64_t index, T value) : index_(index), value_(value) {} - __device__ __forceinline__ int32_t GetIndex() const { return index_; } + __device__ __forceinline__ int64_t GetIndex() const { return index_; } __device__ __forceinline__ T GetValue() const { return value_; } - __device__ __forceinline__ void SetIndex(int32_t index) { index_ = index; } + __device__ __forceinline__ void SetIndex(int64_t index) { index_ = index; } __device__ __forceinline__ void SetValue(T value) { value_ = value; } __device__ __forceinline__ bool operator<(const Entry& entry) const { @@ -108,32 +108,32 @@ class Entry final { } private: - int32_t index_; + int64_t index_; T value_; }; template 
class MinHeap final { public: - __device__ __forceinline__ MinHeap(Entry* data, const int32_t heap_size, - const int32_t init_index, const T init_value) + __device__ __forceinline__ MinHeap(Entry* data, const int64_t heap_size, + const int64_t init_index, const T init_value) : data_(data), heap_size_(heap_size) { - for (int32_t i = 0; i < heap_size; ++i) { + for (int64_t i = 0; i < heap_size; ++i) { data_[i].SetIndex(init_index); data_[i].SetValue(init_value); } } __device__ __forceinline__ Entry& Top() { return data_[0]; } - __device__ __forceinline__ void Swap(const int32_t i, const int32_t j) { + __device__ __forceinline__ void Swap(const int64_t i, const int64_t j) { auto tmp = data_[j]; data_[j] = data_[i]; data_[i] = tmp; } - __device__ __forceinline__ void MinHeapify(int32_t index) { + __device__ __forceinline__ void MinHeapify(int64_t index) { while (true) { - const int32_t left = 2 * index + 1; - const int32_t right = 2 * index + 2; - int32_t min = index; + const int64_t left = 2 * index + 1; + const int64_t right = 2 * index + 2; + int64_t min = index; if (left < heap_size_ && data_[left] < data_[min]) { min = left; } if (right < heap_size_ && data_[right] < data_[min]) { min = right; } if (min == index) { return; } @@ -144,14 +144,14 @@ class MinHeap final { private: Entry* data_; - int32_t heap_size_; + int64_t heap_size_; }; template -__global__ void HeapTopKKernel(const T* in_ptr, const int32_t instance_num, - const int32_t instance_size, const int32_t k, - const int32_t heap_size, const int32_t init_index, - const T init_value, int32_t* out_ptr) { +__global__ void HeapTopKKernel(const T* in_ptr, const int64_t instance_num, + const int64_t instance_size, const int64_t k, + const int64_t heap_size, const int64_t init_index, + const T init_value, int64_t* out_ptr) { extern __shared__ char smem[]; auto* shared_entries = reinterpret_cast*>(smem); @@ -161,7 +161,7 @@ __global__ void HeapTopKKernel(const T* in_ptr, const int32_t instance_num, const T* input = in_ptr + blockIdx.x * instance_size; auto heap = MinHeap(shared_entries + threadIdx.x * heap_size, heap_size, init_index, init_value); - for (int32_t i = threadIdx.x; i < instance_size; i += blockDim.x) { + for (int64_t i = threadIdx.x; i < instance_size; i += blockDim.x) { auto entry = Entry(i, input[i]); if (entry > heap.Top()) { heap.Top() = entry; @@ -176,7 +176,7 @@ __global__ void HeapTopKKernel(const T* in_ptr, const int32_t instance_num, [](const Entry& x, const Entry& y) { return x > y; }); // Write top_k elements in sorted array to output - for (int32_t i = threadIdx.x; i < k; i += blockDim.x) { + for (int64_t i = threadIdx.x; i < k; i += blockDim.x) { (out_ptr + blockIdx.x * k)[i] = shared_entries[i].GetIndex(); } } @@ -196,14 +196,14 @@ class GpuHeapSelectionTopKKernel final : public user_op::OpKernel { if (in->shape().elem_cnt() == 0) { return; } user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); - const int32_t instance_size = in->shape().At(in->shape().NumAxes() - 1); - const int32_t instance_num = in->shape().elem_cnt() / instance_size; - const int32_t k = std::min(ctx->Attr("k"), instance_size); + const int64_t instance_size = in->shape().At(in->shape().NumAxes() - 1); + const int64_t instance_num = in->shape().elem_cnt() / instance_size; + const int64_t k = std::min(static_cast(ctx->Attr("k")), instance_size); // Use as many heaps as possible (# of heaps == # of threads used in thread block). 
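The GPU kernel above selects, for each instance, the indices of the k largest values: every thread keeps a small min-heap of candidate entries, and the per-thread heaps are then merged and ordered with the bitonic sort shown earlier in this file. A minimal host-side sketch of the per-instance semantics, for orientation only (it is not part of this patch, and TopKIndices is a made-up name):

#include <algorithm>
#include <cstdint>
#include <queue>
#include <utility>
#include <vector>

// Reference of what one CUDA block computes for one instance: the indices of the
// k largest values, ordered from largest to smallest value.
std::vector<int64_t> TopKIndices(const std::vector<float>& instance, int64_t k) {
  using Entry = std::pair<float, int64_t>;  // (value, index)
  auto greater = [](const Entry& a, const Entry& b) { return a > b; };
  // priority_queue with a "greater" comparator keeps the smallest entry on top (min-heap).
  std::priority_queue<Entry, std::vector<Entry>, decltype(greater)> min_heap(greater);
  for (int64_t i = 0; i < static_cast<int64_t>(instance.size()); ++i) {
    min_heap.emplace(instance[i], i);
    if (static_cast<int64_t>(min_heap.size()) > k) { min_heap.pop(); }  // evict current minimum
  }
  std::vector<Entry> entries;
  while (!min_heap.empty()) {
    entries.push_back(min_heap.top());
    min_heap.pop();
  }
  std::sort(entries.rbegin(), entries.rend());  // largest value first
  std::vector<int64_t> indices;
  indices.reserve(entries.size());
  for (const Entry& e : entries) { indices.push_back(e.second); }
  return indices;
}

The device version trades this sequential eviction for many concurrent heaps, which is why heap_size and the number of heaps are rounded to powers of two before the bitonic merge.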
// Limitation 1: size of shared memory // We also need heap_size * num_heap to be pow-of-2 which is necessary for bitonic sort - const int32_t heap_size = PowOf2Ceil(k, 16); + const int64_t heap_size = PowOf2Ceil(k, 16); int32_t num_heap = PowOf2Floor(kCudaMaxSharedMemoryByteSize / (heap_size * sizeof(Entry)), 16); // Limitation 2: # of threads in thread block @@ -211,8 +211,8 @@ class GpuHeapSelectionTopKKernel final : public user_op::OpKernel { HeapTopKKernel<<), ctx->stream()->As()->cuda_stream()>>>( - in->dptr(), instance_num, instance_size, k, heap_size, GetMaxVal(), - GetMinVal(), out->mut_dptr()); + in->dptr(), instance_num, instance_size, k, heap_size, GetMaxVal(), + GetMinVal(), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; diff --git a/oneflow/user/kernels/image_preprocess_kernels.cpp b/oneflow/user/kernels/image_preprocess_kernels.cpp index da243bf517f..dba10579831 100644 --- a/oneflow/user/kernels/image_preprocess_kernels.cpp +++ b/oneflow/user/kernels/image_preprocess_kernels.cpp @@ -132,7 +132,8 @@ class CropMirrorNormalizeFromStaticShapeToFloatKernel final : public user_op::Op } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* cmn_attr = dynamic_cast(state); const std::vector& mean_vec = cmn_attr->mean_vec(); const std::vector& inv_std_vec = cmn_attr->inv_std_vec(); @@ -213,7 +214,8 @@ class CropMirrorNormalizeFromTensorBufferToFloatKernel final : public user_op::O } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* cmn_attr = dynamic_cast(state); const std::vector& mean_vec = cmn_attr->mean_vec(); const std::vector& inv_std_vec = cmn_attr->inv_std_vec(); @@ -322,7 +324,8 @@ class CoinFlipKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* rand_bool_gen = dynamic_cast(state); user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); int8_t* dptr = out_blob->mut_dptr(); @@ -381,7 +384,8 @@ class ImageRandomCropKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* crop_window_generators = dynamic_cast(state); CHECK_NOTNULL(crop_window_generators); user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); diff --git a/oneflow/user/kernels/image_preprocess_kernels.cu b/oneflow/user/kernels/image_preprocess_kernels.cu index 6d8b139eff1..30fda3bd96d 100644 --- a/oneflow/user/kernels/image_preprocess_kernels.cu +++ b/oneflow/user/kernels/image_preprocess_kernels.cu @@ -141,7 +141,8 @@ class CropMirrorNormalizeGpuKernel final : public user_op::OpKernel { private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, 
user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* normalize_attr = dynamic_cast(state); const NormalizeVal& mean = normalize_attr->mean(); const NormalizeVal& inv_std = normalize_attr->inv_std(); diff --git a/oneflow/user/kernels/l2_normalize_kernel.cu b/oneflow/user/kernels/l2_normalize_kernel.cu index a7148517900..141a70e9899 100644 --- a/oneflow/user/kernels/l2_normalize_kernel.cu +++ b/oneflow/user/kernels/l2_normalize_kernel.cu @@ -24,7 +24,7 @@ namespace { template __global__ void L2NormalizeForward(const int32_t n, const int32_t c, const int32_t d, const T epsilon, const T* in, T* square_x_sum, T* out) { - using BlockReduce = cub::BlockReduce; + using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; for (int32_t i = blockIdx.x; i < n; i += gridDim.x) { @@ -54,7 +54,7 @@ __global__ void L2NormalizeBackward(const int32_t n, const int32_t c, const int3 const T inv_norm = rsqrt(fmaxf(square_x_sum[i], epsilon)); const int32_t offset = (i / d) * d * c + (i % d); if (square_x_sum[i] >= epsilon) { - using BlockReduce = cub::BlockReduce; + using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage_prod_sum; T y_dy_prod_sum = GetZeroVal(); diff --git a/oneflow/user/kernels/math_binary_elementwise_func.h b/oneflow/user/kernels/math_binary_elementwise_func.h index 6f8bb2aa2c2..03474b0c562 100644 --- a/oneflow/user/kernels/math_binary_elementwise_func.h +++ b/oneflow/user/kernels/math_binary_elementwise_func.h @@ -19,6 +19,7 @@ limitations under the License. #include "oneflow/core/common/util.h" #include "oneflow/core/common/data_type.h" #include "oneflow/user/ops/math_binary_elementwise_seq.h" +#include "oneflow/core/device/cuda_pseudo_half.h" #if defined(__CUDACC__) diff --git a/oneflow/user/kernels/math_unary_elementwise_func.h b/oneflow/user/kernels/math_unary_elementwise_func.h index 7ee240c407d..c5ebf38d5b6 100644 --- a/oneflow/user/kernels/math_unary_elementwise_func.h +++ b/oneflow/user/kernels/math_unary_elementwise_func.h @@ -19,6 +19,7 @@ limitations under the License. 
#include "oneflow/core/common/util.h" #include "oneflow/core/common/data_type.h" #include "oneflow/user/ops/math_unary_elementwise_seq.h" +#include "oneflow/core/device/cuda_pseudo_half.h" #if defined(__CUDACC__) @@ -238,6 +239,15 @@ struct LogFunctor { static OF_DEVICE_FUNC float Backward(const float x, const float dy) { return dy * (1.0f / x); } }; +template<> +struct Log2Functor { + static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log2, x); } + + static OF_DEVICE_FUNC float Backward(const float x, const float dy) { + return dy * (1.0f / (x * MATH_FUNC_F(log, 2.0f))); + } +}; + template<> struct Log1pFunctor { static OF_DEVICE_FUNC float Forward(const float x) { return MATH_FUNC_F(log1p, x); } @@ -509,6 +519,15 @@ struct LogFunctor { static OF_DEVICE_FUNC double Backward(const double x, const double dy) { return dy * (1.0 / x); } }; +template<> +struct Log2Functor { + static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log2, x); } + + static OF_DEVICE_FUNC double Backward(const double x, const double dy) { + return dy * (1.0 / (x * MATH_FUNC_D(log, 2.0))); + } +}; + template<> struct Log1pFunctor { static OF_DEVICE_FUNC double Forward(const double x) { return MATH_FUNC_D(log1p, x); } @@ -795,6 +814,15 @@ struct LogFunctor { static OF_HALF_FUNC half Backward(const half x, const half dy) { return __hmul(dy, hrcp(x)); } }; +template<> +struct Log2Functor { + static OF_HALF_FUNC half Forward(const half x) { return hlog2(x); } + + static OF_HALF_FUNC half Backward(const half x, const half dy) { + return __hmul(dy, hrcp(__hmul(x, hlog(HALF_VAL_TWO)))); + } +}; + template<> struct Log1pFunctor { static OF_HALF_FUNC half Forward(const half x) { return MATH_FUNC_H(log1p, x); } diff --git a/oneflow/user/kernels/min_max_observer_kernel.cu b/oneflow/user/kernels/min_max_observer_kernel.cu index 291140a63e2..fcd9a66e109 100644 --- a/oneflow/user/kernels/min_max_observer_kernel.cu +++ b/oneflow/user/kernels/min_max_observer_kernel.cu @@ -162,11 +162,18 @@ __global__ void CalScaleZeroPointCambricon(const T* max_ptr, const T* min_ptr, } } +ep::CudaLaunchConfig GetLaunchConfig(ep::CudaStream* stream, size_t thread_num, + size_t shared_mem_size) { + ep::CudaLaunchConfig config; + stream->InitLaunchConfigWithWaves(&config, thread_num, kCudaThreadsNumPerBlock, 1); + config.shared_mem_size = shared_mem_size; + return config; +} + } // namespace -#define LAUNCH_CUDA_KERNEL(func, stream_ptr, thread_num, shared_mem_size, ...) \ - func<<As()->cuda_stream()>>>(__VA_ARGS__) +#define LAUNCH_CUDA_KERNEL(func, stream, thread_num, shared_mem_size, ...) 
\ + (stream)->LaunchKernel(func, GetLaunchConfig((stream), thread_num, shared_mem_size), __VA_ARGS__); template class GpuMinMaxObserverKernel final : public user_op::OpKernel { @@ -192,38 +199,38 @@ class GpuMinMaxObserverKernel final : public user_op::OpKernel { const int64_t panel_size = elements / channel; T* max_ptr = tmp_buffer->mut_dptr(); T* min_ptr = max_ptr + channel; - - LAUNCH_CUDA_KERNEL((InitMaxMin), ctx->stream(), channel, 0, channel, max_ptr, min_ptr); + auto* cuda_stream = ctx->stream()->As(); + LAUNCH_CUDA_KERNEL((InitMaxMin), cuda_stream, channel, 0, channel, max_ptr, min_ptr); if (per_layer_quantization) { - LAUNCH_CUDA_KERNEL((ReduceMaxMinPerLayer), ctx->stream(), elements, + LAUNCH_CUDA_KERNEL((ReduceMaxMinPerLayer), cuda_stream, elements, kCudaThreadsNumPerBlock * 2 * sizeof(T), in->dptr(), elements, max_ptr, min_ptr); } else { // per-channel quantization // NOTE(Liang Depeng): each block of threads will be responsible for // computing the max and min values of the whole channel. - LAUNCH_CUDA_KERNEL((ReduceMaxMinPerChannel), ctx->stream(), + LAUNCH_CUDA_KERNEL((ReduceMaxMinPerChannel), cuda_stream, channel * kCudaThreadsNumPerBlock, kCudaThreadsNumPerBlock * 2 * sizeof(T), in->dptr(), elements, channel, panel_size, max_ptr, min_ptr); } if (quantization_formula == "google") { if (quantization_scheme == "symmetric") { - LAUNCH_CUDA_KERNEL((CalScaleZeroPointSymmetric), ctx->stream(), channel, 0, max_ptr, + LAUNCH_CUDA_KERNEL((CalScaleZeroPointSymmetric), cuda_stream, channel, 0, max_ptr, min_ptr, channel, static_cast(quantization_bit), scale->mut_dptr(), zero_point->mut_dptr()); } else { // quantization_scheme == "affine" - LAUNCH_CUDA_KERNEL((CalScaleZeroPointAffine), ctx->stream(), channel, 0, max_ptr, - min_ptr, channel, static_cast(quantization_bit), - scale->mut_dptr(), zero_point->mut_dptr()); + LAUNCH_CUDA_KERNEL((CalScaleZeroPointAffine), cuda_stream, channel, 0, max_ptr, min_ptr, + channel, static_cast(quantization_bit), scale->mut_dptr(), + zero_point->mut_dptr()); } } else if (quantization_formula == "cambricon") { if (!per_layer_quantization) { UNIMPLEMENTED() << " per-channel mode is not supported in cambricon scheme"; } - LAUNCH_CUDA_KERNEL((CalScaleZeroPointCambricon), ctx->stream(), channel, 0, max_ptr, - min_ptr, channel, static_cast(quantization_bit), - scale->mut_dptr(), zero_point->mut_dptr()); + LAUNCH_CUDA_KERNEL((CalScaleZeroPointCambricon), cuda_stream, channel, 0, max_ptr, min_ptr, + channel, static_cast(quantization_bit), scale->mut_dptr(), + zero_point->mut_dptr()); } else { UNIMPLEMENTED(); } diff --git a/oneflow/user/kernels/model_update_kernels.cpp b/oneflow/user/kernels/model_update_kernels.cpp index 07cf8796c52..b233b7d42f5 100644 --- a/oneflow/user/kernels/model_update_kernels.cpp +++ b/oneflow/user/kernels/model_update_kernels.cpp @@ -78,10 +78,10 @@ class TmpBufferManager final { void* ptr_; }; -class IndexedSlicesUpdateOpKernelState final : public user_op::OpKernelState { +class IndexedSlicesUpdateOpKernelCache final : public user_op::OpKernelCache { public: - IndexedSlicesUpdateOpKernelState(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} - ~IndexedSlicesUpdateOpKernelState() override = default; + IndexedSlicesUpdateOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} + ~IndexedSlicesUpdateOpKernelCache() override = default; int64_t lower() const { return lower_; } int64_t upper() const { return upper_; } @@ -91,8 +91,8 @@ class IndexedSlicesUpdateOpKernelState final : public 
user_op::OpKernelState { const int64_t upper_; }; -std::shared_ptr CreateIndexedSlicesUpdateOpKernelState( - user_op::KernelInitContext* ctx) { +std::shared_ptr CreateIndexedSlicesUpdateOpKernelCache( + user_op::KernelCacheContext* ctx) { const cfg::SbpParallel& model_sbp = ctx->SbpParallel4ArgNameAndIndex("model", 0); const user_op::TensorDesc* model_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex("model", 0); @@ -102,11 +102,11 @@ std::shared_ptr CreateIndexedSlicesUpdateOpKernelState( CHECK(ctx->SbpParallel4ArgNameAndIndex("model_diff_indices", 0).has_broadcast_parallel()); CHECK(ctx->SbpParallel4ArgNameAndIndex("model_diff_values", 0).has_broadcast_parallel()); BalancedSplitter bs(num_model_instances, ctx->parallel_ctx().parallel_num()); - return std::make_shared( + return std::make_shared( bs.At(ctx->parallel_ctx().parallel_id()).begin(), bs.At(ctx->parallel_ctx().parallel_id()).end()); } else { - return std::make_shared(0, num_model_instances); + return std::make_shared(0, num_model_instances); } } @@ -184,15 +184,16 @@ class IndexedSlicesSGDUpdateKernel final : public user_op::OpKernel { IndexedSlicesSGDUpdateKernel() = default; ~IndexedSlicesSGDUpdateKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return CreateIndexedSlicesUpdateOpKernelState(ctx); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateIndexedSlicesUpdateOpKernelCache(ctx); } private: using ReduceSumUtilT = IndexedSlicesReduceSumKernelUtil; using MdUpdateUtilT = IndexedSlicesSGDUpdateKernelUtil; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); const user_op::Tensor* model_diff_indices = ctx->Tensor4ArgNameAndIndex("model_diff_indices", 0); @@ -208,9 +209,9 @@ class IndexedSlicesSGDUpdateKernel final : public user_op::OpKernel { CHECK_NE(num_values, 0); CHECK_EQ(num_values % num_indices, 0); const int64_t feature_size = num_values / num_indices; - auto* kernel_state = dynamic_cast(state); - CHECK_NOTNULL(kernel_state); - CHECK_EQ(model->shape().At(0), kernel_state->upper() - kernel_state->lower()); + auto* kernel_cache = dynamic_cast(cache); + CHECK_NOTNULL(kernel_cache); + CHECK_EQ(model->shape().At(0), kernel_cache->upper() - kernel_cache->lower()); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); TmpBufferManager buffer_manager(tmp_buffer->mut_dptr(), num_indices, num_values); @@ -221,7 +222,7 @@ class IndexedSlicesSGDUpdateKernel final : public user_op::OpKernel { buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(), buffer_manager.UniqueWorkspacePtr(), buffer_manager.UniqueWorkspaceBytes()); MdUpdateUtilT::Update(ctx->stream(), weight_decay, num_indices, feature_size, - kernel_state->lower(), kernel_state->upper(), + kernel_cache->lower(), kernel_cache->upper(), buffer_manager.NumUniqueDiffIndicesPtr(), learning_rate->dptr(), buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(), model->mut_dptr()); @@ -311,15 +312,16 @@ class IndexedSlicesMomentumUpdateKernel final : public user_op::OpKernel { IndexedSlicesMomentumUpdateKernel() = default; ~IndexedSlicesMomentumUpdateKernel() override = default; - std::shared_ptr CreateOpKernelState( - 
user_op::KernelInitContext* ctx) const override { - return CreateIndexedSlicesUpdateOpKernelState(ctx); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateIndexedSlicesUpdateOpKernelCache(ctx); } private: using ReduceSumUtilT = IndexedSlicesReduceSumKernelUtil; using MdUpdateUtilT = IndexedSlicesMomentumMdUpdateKernelUtil; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* learning_rate = ctx->Tensor4ArgNameAndIndex("learning_rate", 0); const user_op::Tensor* model_diff_indices = ctx->Tensor4ArgNameAndIndex("model_diff_indices", 0); @@ -338,9 +340,9 @@ class IndexedSlicesMomentumUpdateKernel final : public user_op::OpKernel { CHECK_EQ(num_values % num_indices, 0); const int64_t feature_size = num_values / num_indices; CHECK_EQ(feature_size, model_diff_values->shape().Count(model_diff_indices->shape().NumAxes())); - auto* kernel_state = dynamic_cast(state); - CHECK_NOTNULL(kernel_state); - CHECK_EQ(model->shape().At(0), kernel_state->upper() - kernel_state->lower()); + auto* kernel_cache = dynamic_cast(cache); + CHECK_NOTNULL(kernel_cache); + CHECK_EQ(model->shape().At(0), kernel_cache->upper() - kernel_cache->lower()); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); TmpBufferManager buffer_manager(tmp_buffer->mut_dptr(), num_indices, num_values); @@ -351,8 +353,8 @@ class IndexedSlicesMomentumUpdateKernel final : public user_op::OpKernel { buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(), buffer_manager.UniqueWorkspacePtr(), buffer_manager.UniqueWorkspaceBytes()); MdUpdateUtilT::Update( - ctx->stream(), beta, weight_decay, num_indices, feature_size, kernel_state->lower(), - kernel_state->upper(), buffer_manager.NumUniqueDiffIndicesPtr(), + ctx->stream(), beta, weight_decay, num_indices, feature_size, kernel_cache->lower(), + kernel_cache->upper(), buffer_manager.NumUniqueDiffIndicesPtr(), learning_rate->dptr(), buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(), model->mut_dptr(), momentum->mut_dptr()); } @@ -534,15 +536,16 @@ class IndexedSlicesAdamUpdateKernel final : public user_op::OpKernel { public: IndexedSlicesAdamUpdateKernel() = default; ~IndexedSlicesAdamUpdateKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return CreateIndexedSlicesUpdateOpKernelState(ctx); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateIndexedSlicesUpdateOpKernelCache(ctx); } private: using ReduceSumUtilT = IndexedSlicesReduceSumKernelUtil; using MdUpdateUtilT = IndexedSlicesAdamMdUpdateKernelUtil; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const float learning_rate_val = ctx->Attr("learning_rate_val"); const float* learning_rate_ptr = nullptr; if (ctx->has_input("learning_rate", 0)) { @@ -579,9 +582,9 @@ class IndexedSlicesAdamUpdateKernel final : public user_op::OpKernel { const bool amsgrad = ctx->Attr("amsgrad"); const bool do_bias_correction = ctx->Attr("do_bias_correction"); - auto* kernel_state = dynamic_cast(state); - CHECK_NOTNULL(kernel_state); - 
CHECK_EQ(model->shape().At(0), kernel_state->upper() - kernel_state->lower()); + auto* kernel_cache = dynamic_cast(cache); + CHECK_NOTNULL(kernel_cache); + CHECK_EQ(model->shape().At(0), kernel_cache->upper() - kernel_cache->lower()); const int64_t num_indices = model_diff_indices->shape().elem_cnt(); const int64_t num_values = model_diff_values->shape().elem_cnt(); if (num_indices == 0) { @@ -605,7 +608,7 @@ class IndexedSlicesAdamUpdateKernel final : public user_op::OpKernel { MdUpdateUtilT::Update( ctx->stream(), beta1, beta2, epsilon, weight_decay, amsgrad, do_bias_correction, - learning_rate_val, num_indices, feature_size, kernel_state->lower(), kernel_state->upper(), + learning_rate_val, num_indices, feature_size, kernel_cache->lower(), kernel_cache->upper(), buffer_manager.NumUniqueDiffIndicesPtr(), learning_rate_ptr, bias_correction1_ptr, bias_correction2_ptr, buffer_manager.UniqueDiffIndicesPtr(), buffer_manager.UniqueDiffValuesPtr(), model->mut_dptr(), m->mut_dptr(), diff --git a/oneflow/user/kernels/moving_average_min_max_observer_kernel.cu b/oneflow/user/kernels/moving_average_min_max_observer_kernel.cu index dc0e5c46ca0..d0398fa97c1 100644 --- a/oneflow/user/kernels/moving_average_min_max_observer_kernel.cu +++ b/oneflow/user/kernels/moving_average_min_max_observer_kernel.cu @@ -203,11 +203,18 @@ __global__ void CalFreezeScaleZeroPointCambricon(const int64_t elements, } } +ep::CudaLaunchConfig GetLaunchConfig(ep::CudaStream* stream, size_t thread_num, + size_t shared_mem_size) { + ep::CudaLaunchConfig config; + stream->InitLaunchConfigWithWaves(&config, thread_num, kCudaThreadsNumPerBlock, 1); + config.shared_mem_size = shared_mem_size; + return config; +} + } // namespace -#define LAUNCH_CUDA_KERNEL(func, stream_ptr, thread_num, shared_mem_size, ...) \ - func<<As()->cuda_stream()>>>(__VA_ARGS__) +#define LAUNCH_CUDA_KERNEL(func, stream, thread_num, shared_mem_size, ...) 
\ + (stream)->LaunchKernel(func, GetLaunchConfig((stream), thread_num, shared_mem_size), __VA_ARGS__); template class GpuMovingAverageMinMaxObserverKernel final : public user_op::OpKernel { @@ -242,10 +249,10 @@ class GpuMovingAverageMinMaxObserverKernel final : public user_op::OpKernel { OF_CUDA_CHECK(cudaMemcpy(host_current_train_step_ptr, current_train_step->dptr(), current_train_step->shape().elem_cnt() * sizeof(int64_t), cudaMemcpyDefault)); - + auto* cuda_stream = ctx->stream()->As(); if (*host_current_train_step_ptr <= stop_update_after_iters && is_training) { - LAUNCH_CUDA_KERNEL((InitMaxMin), ctx->stream(), 1, 0, 1, max_ptr, min_ptr); - LAUNCH_CUDA_KERNEL((ReduceMaxMinPerLayer), ctx->stream(), elements, + LAUNCH_CUDA_KERNEL((InitMaxMin), cuda_stream, 1, 0, 1, max_ptr, min_ptr); + LAUNCH_CUDA_KERNEL((ReduceMaxMinPerLayer), cuda_stream, elements, kCudaThreadsNumPerBlock * 2 * sizeof(T), in->dptr(), elements, max_ptr, min_ptr); } @@ -253,23 +260,23 @@ class GpuMovingAverageMinMaxObserverKernel final : public user_op::OpKernel { if (quantization_formula == "google") { if (quantization_scheme == "symmetric") { if (moving) { - LAUNCH_CUDA_KERNEL((CalScaleZeroPointSymmetric), ctx->stream(), 1, 0, 1, + LAUNCH_CUDA_KERNEL((CalScaleZeroPointSymmetric), cuda_stream, 1, 0, 1, static_cast(quantization_bit), momentum, max_ptr, min_ptr, moving_max->mut_dptr(), moving_min->mut_dptr(), scale->mut_dptr(), zero_point->mut_dptr()); } else { - LAUNCH_CUDA_KERNEL((CalFreezeScaleZeroPointSymmetric), ctx->stream(), 1, 0, 1, + LAUNCH_CUDA_KERNEL((CalFreezeScaleZeroPointSymmetric), cuda_stream, 1, 0, 1, static_cast(quantization_bit), momentum, moving_max->dptr(), scale->mut_dptr(), zero_point->mut_dptr()); } } else { // quantization_scheme == "affine" if (moving) { - LAUNCH_CUDA_KERNEL((CalScaleZeroPointAffine), ctx->stream(), 1, 0, 1, + LAUNCH_CUDA_KERNEL((CalScaleZeroPointAffine), cuda_stream, 1, 0, 1, static_cast(quantization_bit), momentum, max_ptr, min_ptr, moving_max->mut_dptr(), moving_min->mut_dptr(), scale->mut_dptr(), zero_point->mut_dptr()); } else { - LAUNCH_CUDA_KERNEL((CalFreezeScaleZeroPointAffine), ctx->stream(), 1, 0, 1, + LAUNCH_CUDA_KERNEL((CalFreezeScaleZeroPointAffine), cuda_stream, 1, 0, 1, static_cast(quantization_bit), momentum, moving_max->dptr(), moving_min->dptr(), scale->mut_dptr(), zero_point->mut_dptr()); @@ -277,12 +284,12 @@ class GpuMovingAverageMinMaxObserverKernel final : public user_op::OpKernel { } } else if (quantization_formula == "cambricon") { if (moving) { - LAUNCH_CUDA_KERNEL((CalScaleZeroPointCambricon), ctx->stream(), 1, 0, 1, + LAUNCH_CUDA_KERNEL((CalScaleZeroPointCambricon), cuda_stream, 1, 0, 1, static_cast(quantization_bit), momentum, max_ptr, min_ptr, moving_max->mut_dptr(), moving_min->mut_dptr(), scale->mut_dptr(), zero_point->mut_dptr()); } else { - LAUNCH_CUDA_KERNEL((CalFreezeScaleZeroPointCambricon), ctx->stream(), 1, 0, 1, + LAUNCH_CUDA_KERNEL((CalFreezeScaleZeroPointCambricon), cuda_stream, 1, 0, 1, static_cast(quantization_bit), momentum, moving_max->dptr(), scale->mut_dptr(), zero_point->mut_dptr()); } diff --git a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp index 775a7557141..5c16e2452e4 100644 --- a/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_2d_sbp_kernels.cpp @@ -97,7 +97,8 @@ class NcclLogical2DSameDim0AllReduce final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* 
state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* nccl_comm = dynamic_cast(state); CHECK(nccl_comm != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); @@ -123,7 +124,8 @@ class NcclLogical2DSameDim0AllGather final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* nccl_comm = dynamic_cast(state); CHECK(nccl_comm != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); @@ -150,7 +152,8 @@ class NcclLogical2DSameDim0AllGatherNoncontinuous final : public user_op::OpKern } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* nccl_comm = dynamic_cast(state); CHECK(nccl_comm != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); @@ -212,7 +215,8 @@ class NcclLogical2DSameDim0All2All final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* nccl_comm = dynamic_cast(state); CHECK(nccl_comm != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); @@ -370,7 +374,8 @@ class NcclLogical2DSameDim1AllReduce final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* nccl_comm = dynamic_cast(state); CHECK(nccl_comm != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); diff --git a/oneflow/user/kernels/nccl_logical_kernels.cpp b/oneflow/user/kernels/nccl_logical_kernels.cpp index db8500ba73b..dc2c4396a57 100644 --- a/oneflow/user/kernels/nccl_logical_kernels.cpp +++ b/oneflow/user/kernels/nccl_logical_kernels.cpp @@ -76,7 +76,8 @@ class NcclLogicalAllReduceKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* nccl_comm = dynamic_cast(state); CHECK(nccl_comm != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); @@ -102,7 +103,8 @@ class NcclLogicalReduceScatterKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* nccl_comm = dynamic_cast(state); CHECK(nccl_comm != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); @@ -129,7 +131,8 @@ class NcclLogicalAllGatherKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void 
Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* nccl_comm = dynamic_cast(state); CHECK(nccl_comm != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); @@ -156,7 +159,8 @@ class NcclLogicalAllGatherNoncontinuous final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* nccl_comm = dynamic_cast(state); CHECK(nccl_comm != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); @@ -217,7 +221,8 @@ class NcclLogicalS2SKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* nccl_comm = dynamic_cast(state); CHECK(nccl_comm != nullptr); const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); diff --git a/oneflow/user/kernels/nvtx_range_kernel.cu b/oneflow/user/kernels/nvtx_range_kernel.cu index 2a5e52ebcd3..95bcdced4a2 100644 --- a/oneflow/user/kernels/nvtx_range_kernel.cu +++ b/oneflow/user/kernels/nvtx_range_kernel.cu @@ -58,7 +58,8 @@ class NvtxStartKernel final : public user_op::OpKernel { private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const ShapeView& in_shape = in->shape(); @@ -100,7 +101,8 @@ class NvtxEndKernel final : public user_op::OpKernel { private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); const ShapeView& in_shape = in->shape(); diff --git a/oneflow/user/kernels/ofrecord_decoder_kernels.cpp b/oneflow/user/kernels/ofrecord_decoder_kernels.cpp index 785cb725504..98a9341e039 100644 --- a/oneflow/user/kernels/ofrecord_decoder_kernels.cpp +++ b/oneflow/user/kernels/ofrecord_decoder_kernels.cpp @@ -23,7 +23,7 @@ limitations under the License. 
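// NOTE: a minimal sketch, not part of this patch. The recurring change in the kernels
// above is that user_op::OpKernel::Compute gains a third, read-only
// `const user_op::OpKernelCache*` parameter next to the existing mutable OpKernelState*.
// A kernel that does not use the cache simply leaves the parameter unnamed:
#include "oneflow/core/framework/framework.h"

class ExampleKernel final : public user_op::OpKernel {  // hypothetical kernel name
 private:
  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state,
               const user_op::OpKernelCache* /*cache*/) const override {
    // Body is unchanged from the two-argument form: `state` still carries mutable
    // per-kernel objects (e.g. an NCCL communicator), while the cache holds immutable,
    // shape/attr-derived data and is simply ignored here.
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};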
#include "oneflow/user/image/random_crop_generator.h" #include "oneflow/user/image/image_util.h" #include "oneflow/user/kernels/random_crop_kernel_state.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/user/kernels/random_seed_util.h" #include @@ -229,7 +229,8 @@ class OFRecordImageDecoderRandomCropKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* crop_window_generators = dynamic_cast(state); CHECK_NOTNULL(crop_window_generators); user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); diff --git a/oneflow/user/kernels/ofrecord_image_classification_reader_kernel.cpp b/oneflow/user/kernels/ofrecord_image_classification_reader_kernel.cpp index ded3931b5d2..adc3f4dfde2 100644 --- a/oneflow/user/kernels/ofrecord_image_classification_reader_kernel.cpp +++ b/oneflow/user/kernels/ofrecord_image_classification_reader_kernel.cpp @@ -45,7 +45,8 @@ class OFRecordImageClassificationReaderKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* reader = dynamic_cast(state); CHECK_NOTNULL(reader); reader->Read(ctx); diff --git a/oneflow/user/kernels/ofrecord_reader_kernel.cpp b/oneflow/user/kernels/ofrecord_reader_kernel.cpp index a421366e385..84458aac900 100644 --- a/oneflow/user/kernels/ofrecord_reader_kernel.cpp +++ b/oneflow/user/kernels/ofrecord_reader_kernel.cpp @@ -45,7 +45,8 @@ class OFRecordReaderKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* reader = dynamic_cast(state); reader->Read(ctx); } diff --git a/oneflow/user/kernels/onerec_reader_kernel.cpp b/oneflow/user/kernels/onerec_reader_kernel.cpp index dd79b597ad6..e265ff3064f 100644 --- a/oneflow/user/kernels/onerec_reader_kernel.cpp +++ b/oneflow/user/kernels/onerec_reader_kernel.cpp @@ -45,7 +45,8 @@ class OneRecReaderKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* reader = dynamic_cast(state); reader->Read(ctx); } diff --git a/oneflow/user/kernels/op_kernel_state_wrapper.h b/oneflow/user/kernels/op_kernel_wrapper.h similarity index 77% rename from oneflow/user/kernels/op_kernel_state_wrapper.h rename to oneflow/user/kernels/op_kernel_wrapper.h index e22e5913653..552f1896aae 100644 --- a/oneflow/user/kernels/op_kernel_state_wrapper.h +++ b/oneflow/user/kernels/op_kernel_wrapper.h @@ -35,6 +35,21 @@ class OpKernelStateWrapper final : public user_op::OpKernelState { T data_; }; +template +class OpKernelCacheWrapper final : public user_op::OpKernelCache { + public: + template + explicit OpKernelCacheWrapper(Args&&... args) : data_(std::forward(args)...) 
{} + + ~OpKernelCacheWrapper() = default; + + const T& Get() const { return data_; } + T* Mutable() { return &data_; } + + private: + T data_; +}; + } // namespace oneflow #endif // ONEFLOW_USER_KERNELS_OP_KERNEL_STATE_WRAPPER_H_ diff --git a/oneflow/user/kernels/pack_kernel.cpp b/oneflow/user/kernels/pack_kernel.cpp index 34503510937..72df505f6e2 100644 --- a/oneflow/user/kernels/pack_kernel.cpp +++ b/oneflow/user/kernels/pack_kernel.cpp @@ -15,7 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { @@ -34,7 +34,8 @@ class PackKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); CHECK_EQ(in->data_type(), out->data_type()); diff --git a/oneflow/user/kernels/partial_fc_sample_kernel.cu b/oneflow/user/kernels/partial_fc_sample_kernel.cu index 80d172da250..6a6591408ca 100644 --- a/oneflow/user/kernels/partial_fc_sample_kernel.cu +++ b/oneflow/user/kernels/partial_fc_sample_kernel.cu @@ -313,7 +313,8 @@ class DistributedPartialFcSampleGpuKernel final : public user_op::OpKernel { private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { const user_op::Tensor* weight = ctx->Tensor4ArgNameAndIndex("weight", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); user_op::Tensor* mapped_label = ctx->Tensor4ArgNameAndIndex("mapped_label", 0); @@ -394,7 +395,8 @@ class DistributedPartialFcSampleDisableBoxingGpuKernel final : public user_op::O private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { const user_op::Tensor* sampled_weight_diff = ctx->Tensor4ArgNameAndIndex("sampled_weight_diff", 0); const user_op::Tensor* sampled_label = ctx->Tensor4ArgNameAndIndex("sampled_label", 0); diff --git a/oneflow/user/kernels/pool_cpu_kernel.cpp b/oneflow/user/kernels/pool_cpu_kernel.cpp index 82c9f896884..ce21276b017 100644 --- a/oneflow/user/kernels/pool_cpu_kernel.cpp +++ b/oneflow/user/kernels/pool_cpu_kernel.cpp @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
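// NOTE: usage sketch under assumptions, not part of this patch. The new
// OpKernelCacheWrapper<T> added above mirrors OpKernelStateWrapper<T>, but for the
// read-only OpKernelCache side; a kernel can wrap any immutable, attr-derived value.
// The "depth" attribute below is hypothetical and only illustrates the pattern:

std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(
    user_op::KernelCacheContext* ctx) const override {
  const int64_t depth = ctx->Attr<int64_t>("depth");  // hypothetical attribute
  return std::make_shared<OpKernelCacheWrapper<int64_t>>(depth);
}

void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,
             const user_op::OpKernelCache* cache) const override {
  const auto* wrapped = dynamic_cast<const OpKernelCacheWrapper<int64_t>*>(cache);
  const int64_t depth = wrapped->Get();  // cached value, recomputed only when the cache is rebuilt
  // ... use depth ...
}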
*/ #include "oneflow/core/framework/framework.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/user/utils/pool_util.h" #include "oneflow/core/common/eigen_util.h" @@ -22,14 +22,14 @@ namespace oneflow { namespace { -struct PoolOpKernelState final : public user_op::OpKernelState { +struct PoolOpKernelCache final : public user_op::OpKernelCache { Params3D params_3d; - PoolOpKernelState(Params3D params_3d) : params_3d(params_3d) {} - const Params3D& GetParams3D() { return params_3d; } + explicit PoolOpKernelCache(const Params3D& params_3d) : params_3d(params_3d) {} + const Params3D& GetParams3D() const { return params_3d; } }; -std::shared_ptr DoCreatePoolOpKernelState(user_op::KernelComputeContext* ctx, - const int32_t& dim) { +std::shared_ptr InitPoolOpKernelCache(user_op::KernelCacheContext* ctx, + const int32_t& dim) { const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape(); const std::string& data_format = ctx->Attr("data_format"); const std::string& padding = ctx->Attr("padding"); @@ -40,7 +40,7 @@ std::shared_ptr DoCreatePoolOpKernelState(user_op::KernelComp const bool ceil_mode = ctx->Attr("ceil_mode"); Params3D params_3d = Params3D(dim, x_shape, data_format, padding, padding_before, padding_after, pool_size, strides, ceil_mode); - std::shared_ptr state(new PoolOpKernelState(params_3d)); + std::shared_ptr state(new PoolOpKernelCache(params_3d)); return state; } @@ -249,7 +249,8 @@ struct PoolCpuKernelUtil { } } - static void AvgFWCompute(user_op::KernelComputeContext* ctx, PoolOpKernelState* pool_state) { + static void AvgFWCompute(user_op::KernelComputeContext* ctx, + const PoolOpKernelCache* pool_state) { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); CHECK_NOTNULL(pool_state); @@ -271,7 +272,8 @@ struct PoolCpuKernelUtil { } } - static void AvgBWCompute(user_op::KernelComputeContext* ctx, PoolOpKernelState* pool_state) { + static void AvgBWCompute(user_op::KernelComputeContext* ctx, + const PoolOpKernelCache* pool_state) { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); @@ -294,7 +296,8 @@ struct PoolCpuKernelUtil { } } - static void MaxFWCompute(user_op::KernelComputeContext* ctx, PoolOpKernelState* pool_state) { + static void MaxFWCompute(user_op::KernelComputeContext* ctx, + const PoolOpKernelCache* pool_state) { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); CHECK_NOTNULL(pool_state); @@ -319,7 +322,8 @@ struct PoolCpuKernelUtil { } } - static void MaxBWCompute(user_op::KernelComputeContext* ctx, PoolOpKernelState* pool_state) { + static void MaxBWCompute(user_op::KernelComputeContext* ctx, + const PoolOpKernelCache* pool_state) { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); @@ -358,9 +362,14 @@ class AvgPool1DCpuKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = DoCreatePoolOpKernelState(ctx, 1); - PoolCpuKernelUtil::AvgFWCompute(ctx, pool_state.get()); + std::shared_ptr 
InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return InitPoolOpKernelCache(ctx, 1); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolCpuKernelUtil::AvgFWCompute(ctx, dynamic_cast(cache)); }; }; @@ -372,9 +381,14 @@ class AvgPool1DGradCpuKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = DoCreatePoolOpKernelState(ctx, 1); - PoolCpuKernelUtil::AvgBWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return InitPoolOpKernelCache(ctx, 1); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolCpuKernelUtil::AvgBWCompute(ctx, dynamic_cast(cache)); }; }; @@ -386,9 +400,14 @@ class AvgPool2DCpuKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = DoCreatePoolOpKernelState(ctx, 2); - PoolCpuKernelUtil::AvgFWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return InitPoolOpKernelCache(ctx, 2); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolCpuKernelUtil::AvgFWCompute(ctx, dynamic_cast(cache)); }; }; @@ -400,9 +419,14 @@ class AvgPool2DGradCpuKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = DoCreatePoolOpKernelState(ctx, 2); - PoolCpuKernelUtil::AvgBWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return InitPoolOpKernelCache(ctx, 2); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolCpuKernelUtil::AvgBWCompute(ctx, dynamic_cast(cache)); }; }; @@ -414,9 +438,14 @@ class AvgPool3DCpuKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = DoCreatePoolOpKernelState(ctx, 3); - PoolCpuKernelUtil::AvgFWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return InitPoolOpKernelCache(ctx, 3); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolCpuKernelUtil::AvgFWCompute(ctx, dynamic_cast(cache)); }; }; @@ -428,9 +457,14 @@ class AvgPool3DGradCpuKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = DoCreatePoolOpKernelState(ctx, 3); - PoolCpuKernelUtil::AvgBWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return InitPoolOpKernelCache(ctx, 3); 
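// NOTE: a condensed sketch of the pattern applied to every pool kernel in this file;
// names follow the patch, the body is abbreviated and the template parameter is assumed.
// The point of the change is that the Params3D computation moves out of Compute into
// InitOpKernelCache, so it runs only when the cache is (re)built, not on every call:

template<typename T>
class ExamplePoolCpuKernel final : public user_op::OpKernel {  // stand-in for the AvgPool*/MaxPool* kernels
 private:
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }

  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(
      user_op::KernelCacheContext* ctx) const override {
    return InitPoolOpKernelCache(ctx, /*dim=*/2);  // shape/attr work done once
  }

  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,
               const user_op::OpKernelCache* cache) const override {
    PoolCpuKernelUtil<T>::AvgFWCompute(ctx, dynamic_cast<const PoolOpKernelCache*>(cache));
  }
};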
+ } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolCpuKernelUtil::AvgBWCompute(ctx, dynamic_cast(cache)); }; }; @@ -442,9 +476,14 @@ class MaxPool1DCpuKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = DoCreatePoolOpKernelState(ctx, 1); - PoolCpuKernelUtil::MaxFWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return InitPoolOpKernelCache(ctx, 1); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolCpuKernelUtil::MaxFWCompute(ctx, dynamic_cast(cache)); }; }; @@ -456,9 +495,14 @@ class MaxPool1DGradCpuKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = DoCreatePoolOpKernelState(ctx, 1); - PoolCpuKernelUtil::MaxBWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return InitPoolOpKernelCache(ctx, 1); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolCpuKernelUtil::MaxBWCompute(ctx, dynamic_cast(cache)); }; }; @@ -470,9 +514,14 @@ class MaxPool2DCpuKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = DoCreatePoolOpKernelState(ctx, 2); - PoolCpuKernelUtil::MaxFWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return InitPoolOpKernelCache(ctx, 2); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolCpuKernelUtil::MaxFWCompute(ctx, dynamic_cast(cache)); }; }; @@ -484,9 +533,14 @@ class MaxPool2DGradCpuKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = DoCreatePoolOpKernelState(ctx, 2); - PoolCpuKernelUtil::MaxBWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return InitPoolOpKernelCache(ctx, 2); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolCpuKernelUtil::MaxBWCompute(ctx, dynamic_cast(cache)); }; }; @@ -498,9 +552,14 @@ class MaxPool3DCpuKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = DoCreatePoolOpKernelState(ctx, 3); - PoolCpuKernelUtil::MaxFWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return InitPoolOpKernelCache(ctx, 3); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* 
cache) const override { + PoolCpuKernelUtil::MaxFWCompute(ctx, dynamic_cast(cache)); }; }; @@ -512,9 +571,14 @@ class MaxPool3DGradCpuKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = DoCreatePoolOpKernelState(ctx, 3); - PoolCpuKernelUtil::MaxBWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return InitPoolOpKernelCache(ctx, 3); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolCpuKernelUtil::MaxBWCompute(ctx, dynamic_cast(cache)); }; }; diff --git a/oneflow/user/kernels/pool_gpu_kernel.cpp b/oneflow/user/kernels/pool_gpu_kernel.cpp index 623069c1f95..f6f049aaf0a 100644 --- a/oneflow/user/kernels/pool_gpu_kernel.cpp +++ b/oneflow/user/kernels/pool_gpu_kernel.cpp @@ -43,15 +43,15 @@ class CudnnPoolDesc final { cudnnPoolingDescriptor_t val_; }; -class GPUPoolOpKernelState final : public user_op::OpKernelState { +class GPUPoolOpKernelCache final : public user_op::OpKernelCache { public: - GPUPoolOpKernelState(const int32_t dim, const std::string& pooling_type, const ShapeView& x_shape, + GPUPoolOpKernelCache(const int32_t dim, const std::string& pooling_type, const ShapeView& x_shape, const ShapeView& y_shape, const std::string& data_format, const DataType& dtype, const Params3D& params_3d) : dim_(dim), pooling_type_(pooling_type) { Reset(dim, pooling_type, x_shape, y_shape, data_format, dtype, params_3d); } - ~GPUPoolOpKernelState() = default; + ~GPUPoolOpKernelCache() = default; void Reset(const int32_t dim, const std::string& pooling_type, const ShapeView& x_shape, const ShapeView& y_shape, const std::string& data_format, const DataType& dtype, @@ -80,10 +80,10 @@ class GPUPoolOpKernelState final : public user_op::OpKernelState { new CudnnPoolDesc(pooling_mode, dim, pool_size.data(), padding.data(), strides.data())); } - static std::shared_ptr FromKernelComputeContext( - const int32_t& dim, const std::string& pooling_type, user_op::KernelComputeContext* ctx) { + static std::shared_ptr FromKernelComputeContext( + const int32_t& dim, const std::string& pooling_type, user_op::KernelCacheContext* ctx) { if (pooling_type != "MAX" && pooling_type != "AVG") { UNIMPLEMENTED(); } - const ShapeView& x_shape = ctx->Tensor4ArgNameAndIndex("x", 0)->shape(); + const ShapeView& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape(); const std::string& data_format = ctx->Attr("data_format"); const std::string& padding = ctx->Attr("padding"); const auto& padding_before = ctx->Attr>("padding_before"); @@ -93,9 +93,9 @@ class GPUPoolOpKernelState final : public user_op::OpKernelState { const bool ceil_mode = ctx->Attr("ceil_mode"); const Params3D params_3d(dim, x_shape, data_format, padding, padding_before, padding_after, pool_size, strides, ceil_mode); - const ShapeView& y_shape = ctx->Tensor4ArgNameAndIndex("y", 0)->shape(); - const DataType dtype = ctx->Tensor4ArgNameAndIndex("x", 0)->data_type(); - return std::make_shared(dim, pooling_type, x_shape, y_shape, data_format, + const ShapeView& y_shape = ctx->TensorDesc4ArgNameAndIndex("y", 0)->shape(); + const DataType dtype = ctx->TensorDesc4ArgNameAndIndex("x", 0)->data_type(); + return std::make_shared(dim, pooling_type, x_shape, y_shape, data_format, dtype, params_3d); } @@ -114,31 +114,31 @@ class 
GPUPoolOpKernelState final : public user_op::OpKernelState { template struct PoolGpuKernelUtil { static void FWCompute(user_op::KernelComputeContext* ctx, - GPUPoolOpKernelState* gpu_pool_op_kernel_state) { + const GPUPoolOpKernelCache* gpu_pool_op_kernel_cache) { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); - CHECK(gpu_pool_op_kernel_state != nullptr); + CHECK(gpu_pool_op_kernel_cache != nullptr); OF_CUDNN_CHECK(cudnnPoolingForward( ctx->stream()->As()->cudnn_handle(), - gpu_pool_op_kernel_state->cudnn_pooling_desc(), CudnnSPOnePtr(), - gpu_pool_op_kernel_state->cudnn_x_tensor_desc(), x->dptr(), CudnnSPZeroPtr(), - gpu_pool_op_kernel_state->cudnn_y_tensor_desc(), y->mut_dptr())); + gpu_pool_op_kernel_cache->cudnn_pooling_desc(), CudnnSPOnePtr(), + gpu_pool_op_kernel_cache->cudnn_x_tensor_desc(), x->dptr(), CudnnSPZeroPtr(), + gpu_pool_op_kernel_cache->cudnn_y_tensor_desc(), y->mut_dptr())); } static void BWCompute(user_op::KernelComputeContext* ctx, - GPUPoolOpKernelState* gpu_pool_op_kernel_state) { + const GPUPoolOpKernelCache* gpu_pool_op_kernel_cache) { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); - CHECK(gpu_pool_op_kernel_state != nullptr); + CHECK(gpu_pool_op_kernel_cache != nullptr); OF_CUDNN_CHECK(cudnnPoolingBackward( ctx->stream()->As()->cudnn_handle(), - gpu_pool_op_kernel_state->cudnn_pooling_desc(), CudnnSPOnePtr(), - gpu_pool_op_kernel_state->cudnn_y_tensor_desc(), y->dptr(), - gpu_pool_op_kernel_state->cudnn_y_tensor_desc(), dy->dptr(), - gpu_pool_op_kernel_state->cudnn_x_tensor_desc(), x->dptr(), CudnnSPZeroPtr(), - gpu_pool_op_kernel_state->cudnn_x_tensor_desc(), dx->mut_dptr())); + gpu_pool_op_kernel_cache->cudnn_pooling_desc(), CudnnSPOnePtr(), + gpu_pool_op_kernel_cache->cudnn_y_tensor_desc(), y->dptr(), + gpu_pool_op_kernel_cache->cudnn_y_tensor_desc(), dy->dptr(), + gpu_pool_op_kernel_cache->cudnn_x_tensor_desc(), x->dptr(), CudnnSPZeroPtr(), + gpu_pool_op_kernel_cache->cudnn_x_tensor_desc(), dx->mut_dptr())); } }; @@ -152,9 +152,13 @@ class AvgPool1DGpuKernel final : public user_op::OpKernel, public user_op::CudaG private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = GPUPoolOpKernelState::FromKernelComputeContext(1, "AVG", ctx); - PoolGpuKernelUtil::FWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return GPUPoolOpKernelCache::FromKernelComputeContext(1, "AVG", ctx); + } + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast(cache)); }; }; @@ -165,9 +169,13 @@ class AvgPool1DGradGpuKernel final : public user_op::OpKernel, public user_op::C ~AvgPool1DGradGpuKernel() = default; bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = GPUPoolOpKernelState::FromKernelComputeContext(1, "AVG", ctx); - PoolGpuKernelUtil::BWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return 
GPUPoolOpKernelCache::FromKernelComputeContext(1, "AVG", ctx); + } + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast(cache)); }; }; @@ -179,9 +187,13 @@ class AvgPool2DGpuKernel final : public user_op::OpKernel, public user_op::CudaG private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = GPUPoolOpKernelState::FromKernelComputeContext(2, "AVG", ctx); - PoolGpuKernelUtil::FWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return GPUPoolOpKernelCache::FromKernelComputeContext(2, "AVG", ctx); + } + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast(cache)); }; }; @@ -193,9 +205,13 @@ class AvgPool2DGradGpuKernel final : public user_op::OpKernel, public user_op::C private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = GPUPoolOpKernelState::FromKernelComputeContext(2, "AVG", ctx); - PoolGpuKernelUtil::BWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return GPUPoolOpKernelCache::FromKernelComputeContext(2, "AVG", ctx); + } + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast(cache)); }; }; @@ -207,9 +223,13 @@ class AvgPool3DGpuKernel final : public user_op::OpKernel, public user_op::CudaG private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = GPUPoolOpKernelState::FromKernelComputeContext(3, "AVG", ctx); - PoolGpuKernelUtil::FWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return GPUPoolOpKernelCache::FromKernelComputeContext(3, "AVG", ctx); + } + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast(cache)); }; }; @@ -221,9 +241,13 @@ class AvgPool3DGradGpuKernel final : public user_op::OpKernel, public user_op::C private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = GPUPoolOpKernelState::FromKernelComputeContext(3, "AVG", ctx); - PoolGpuKernelUtil::BWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return GPUPoolOpKernelCache::FromKernelComputeContext(3, "AVG", ctx); + } + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast(cache)); }; }; @@ -235,9 +259,13 @@ class MaxPool1DGpuKernel final : public user_op::OpKernel, public user_op::CudaG private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - 
const auto& pool_state = GPUPoolOpKernelState::FromKernelComputeContext(1, "MAX", ctx); - PoolGpuKernelUtil::FWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return GPUPoolOpKernelCache::FromKernelComputeContext(1, "MAX", ctx); + } + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast(cache)); }; }; @@ -249,9 +277,13 @@ class MaxPool1DGradGpuKernel final : public user_op::OpKernel, public user_op::C private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = GPUPoolOpKernelState::FromKernelComputeContext(1, "MAX", ctx); - PoolGpuKernelUtil::BWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return GPUPoolOpKernelCache::FromKernelComputeContext(1, "MAX", ctx); + } + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast(cache)); }; }; @@ -263,9 +295,13 @@ class MaxPool2DGpuKernel final : public user_op::OpKernel, public user_op::CudaG private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = GPUPoolOpKernelState::FromKernelComputeContext(2, "MAX", ctx); - PoolGpuKernelUtil::FWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return GPUPoolOpKernelCache::FromKernelComputeContext(2, "MAX", ctx); + } + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast(cache)); }; }; @@ -277,9 +313,13 @@ class MaxPool2DGradGpuKernel final : public user_op::OpKernel, public user_op::C private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = GPUPoolOpKernelState::FromKernelComputeContext(2, "MAX", ctx); - PoolGpuKernelUtil::BWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return GPUPoolOpKernelCache::FromKernelComputeContext(2, "MAX", ctx); + } + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast(cache)); }; }; @@ -291,9 +331,13 @@ class MaxPool3DGpuKernel final : public user_op::OpKernel, public user_op::CudaG private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = GPUPoolOpKernelState::FromKernelComputeContext(3, "MAX", ctx); - PoolGpuKernelUtil::FWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return GPUPoolOpKernelCache::FromKernelComputeContext(3, "MAX", ctx); + } + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolGpuKernelUtil::FWCompute(ctx, dynamic_cast(cache)); }; }; @@ -305,9 
+349,13 @@ class MaxPool3DGradGpuKernel final : public user_op::OpKernel, public user_op::C private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { - const auto& pool_state = GPUPoolOpKernelState::FromKernelComputeContext(3, "MAX", ctx); - PoolGpuKernelUtil::BWCompute(ctx, pool_state.get()); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return GPUPoolOpKernelCache::FromKernelComputeContext(3, "MAX", ctx); + } + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { + PoolGpuKernelUtil::BWCompute(ctx, dynamic_cast(cache)); }; }; diff --git a/oneflow/user/kernels/pooling_kernel.cpp b/oneflow/user/kernels/pooling_kernel.cpp index 1888bf5465d..5d7651bc04b 100644 --- a/oneflow/user/kernels/pooling_kernel.cpp +++ b/oneflow/user/kernels/pooling_kernel.cpp @@ -17,14 +17,14 @@ limitations under the License. namespace oneflow { -struct PoolingOpKernelState final : public user_op::OpKernelState { +struct PoolingOpKernelCache final : public user_op::OpKernelCache { MaxPoolingParams3D params_3d; - explicit PoolingOpKernelState(const MaxPoolingParams3D& params_3d) : params_3d(params_3d) {} - const MaxPoolingParams3D& GetParams3D() { return params_3d; } + explicit PoolingOpKernelCache(const MaxPoolingParams3D& params_3d) : params_3d(params_3d) {} + const MaxPoolingParams3D& GetParams3D() const { return params_3d; } }; -std::shared_ptr DoCreateOpKernelState(user_op::KernelComputeContext* ctx, - const int32_t& dim) { +std::shared_ptr CreateOpKernelCache(user_op::KernelCacheContext* ctx, + const int32_t& dim) { const Shape& x_shape = ctx->TensorDesc4ArgNameAndIndex("x", 0)->shape(); const std::string& data_format = ctx->Attr("data_format"); const std::vector& padding = ctx->Attr>("padding"); @@ -36,8 +36,8 @@ std::shared_ptr DoCreateOpKernelState(user_op::KernelCompu MaxPoolingParams3D params_3d = MaxPoolingParams3D(dim, x_shape, data_format, padding, kernel_size, stride, dilation, return_indices, ceil_mode); - std::shared_ptr state(new PoolingOpKernelState(params_3d)); - return state; + std::shared_ptr cache(new PoolingOpKernelCache(params_3d)); + return cache; } template @@ -120,13 +120,19 @@ class MaxPool1dKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateOpKernelCache(ctx, 1); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); - const auto& pooling_state = DoCreateOpKernelState(ctx, 1); - const MaxPoolingParams3D& params_3d = pooling_state->GetParams3D(); + const auto* pooling_cache = dynamic_cast(cache); + const MaxPoolingParams3D& params_3d = pooling_cache->GetParams3D(); const int64_t elem_num = y->shape().elem_cnt(); const T* src = x->dptr(); @@ -150,13 +156,19 @@ class MaxPool1dGradKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { + 
std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateOpKernelCache(ctx, 1); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); - const auto& pooling_state = DoCreateOpKernelState(ctx, 1); - const MaxPoolingParams3D& params_3d = pooling_state->GetParams3D(); + const auto* pooling_cache = dynamic_cast(cache); + const MaxPoolingParams3D& params_3d = pooling_cache->GetParams3D(); const int64_t elem_num = dy->shape().elem_cnt(); const T* src = dy->dptr(); @@ -182,13 +194,19 @@ class MaxPool2dKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateOpKernelCache(ctx, 2); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); - const auto& pooling_state = DoCreateOpKernelState(ctx, 2); - const MaxPoolingParams3D& params_3d = pooling_state->GetParams3D(); + const auto* pooling_cache = dynamic_cast(cache); + const MaxPoolingParams3D& params_3d = pooling_cache->GetParams3D(); const int64_t elem_num = y->shape().elem_cnt(); const T* src = x->dptr(); @@ -212,13 +230,19 @@ class MaxPool2dGradKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateOpKernelCache(ctx, 2); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); - const auto& pooling_state = DoCreateOpKernelState(ctx, 2); - const MaxPoolingParams3D& params_3d = pooling_state->GetParams3D(); + const auto* pooling_cache = dynamic_cast(cache); + const MaxPoolingParams3D& params_3d = pooling_cache->GetParams3D(); const int64_t elem_num = dy->shape().elem_cnt(); const T* src = dy->dptr(); @@ -244,13 +268,19 @@ class MaxPool3dKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateOpKernelCache(ctx, 3); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); - const auto& pooling_state = 
DoCreateOpKernelState(ctx, 3); - const MaxPoolingParams3D& params_3d = pooling_state->GetParams3D(); + const auto* pooling_cache = dynamic_cast(cache); + const MaxPoolingParams3D& params_3d = pooling_cache->GetParams3D(); const int64_t elem_num = y->shape().elem_cnt(); const T* src = x->dptr(); @@ -274,13 +304,19 @@ class MaxPool3dGradKernel final : public user_op::OpKernel { private: bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - void Compute(user_op::KernelComputeContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateOpKernelCache(ctx, 3); + } + + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* indice = ctx->Tensor4ArgNameAndIndex("indice", 0); user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0); - const auto& pooling_state = DoCreateOpKernelState(ctx, 3); - const MaxPoolingParams3D& params_3d = pooling_state->GetParams3D(); + const auto* pooling_cache = dynamic_cast(cache); + const MaxPoolingParams3D& params_3d = pooling_cache->GetParams3D(); const int64_t elem_num = dy->shape().elem_cnt(); const T* src = dy->dptr(); diff --git a/oneflow/user/kernels/prelu_kernel.cpp b/oneflow/user/kernels/prelu_kernel.cpp index 43eded62907..538cafabcca 100644 --- a/oneflow/user/kernels/prelu_kernel.cpp +++ b/oneflow/user/kernels/prelu_kernel.cpp @@ -35,7 +35,7 @@ class CpuPReluKernel final : public user_op::OpKernel { const int32_t elem_cnt = x->shape().elem_cnt(); const int32_t alpha_size = alpha->shape().elem_cnt(); const int batch = x->shape().At(0); - const int channels = x->shape().At(1); + const int channels = (x->shape().NumAxes() == 1) ? 1 : x->shape().At(1); const int32_t inner_size = elem_cnt / batch / channels; FOR_RANGE(int32_t, i, 0, elem_cnt) { y_ptr[i] = x_ptr[i] > 0 ? x_ptr[i] : x_ptr[i] * alpha_ptr[(i / inner_size) % alpha_size]; @@ -74,7 +74,7 @@ class CpuPReluGradKernel final : public user_op::OpKernel { const int32_t elem_cnt = x->shape().elem_cnt(); const int32_t alpha_size = alpha->shape().elem_cnt(); const int batch = x->shape().At(0); - const int channels = x->shape().At(1); + const int channels = (x->shape().NumAxes() == 1) ? 1 : x->shape().At(1); const int32_t inner_size = elem_cnt / batch / channels; Memset(ctx->stream(), alpha_diff->mut_dptr(), 0, diff --git a/oneflow/user/kernels/prelu_kernel.cu b/oneflow/user/kernels/prelu_kernel.cu index e79fd9ea1a8..ea241deaf35 100644 --- a/oneflow/user/kernels/prelu_kernel.cu +++ b/oneflow/user/kernels/prelu_kernel.cu @@ -338,7 +338,7 @@ class GpuPReluKernel final : public user_op::OpKernel { user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int32_t elem_cnt = x->shape().elem_cnt(); const int32_t batch = x->shape().At(0); - const int32_t channels = x->shape().At(1); + const int32_t channels = (x->shape().NumAxes() == 1) ? 1 : x->shape().At(1); const int32_t alpha_size = alpha->shape().elem_cnt(); const int32_t inner_size = elem_cnt / batch / channels; @@ -388,7 +388,7 @@ class GpuPReluGradKernel final : public user_op::OpKernel { const Shape& left_extended_shape = CreatePreluLeftExtendedShape(ShapeView(x->shape())); const int32_t batch = x->shape().At(0); - const int32_t channels = x->shape().At(1); + const int32_t channels = (x->shape().NumAxes() == 1) ? 
1 : x->shape().At(1); const int32_t alpha_size = alpha->shape().elem_cnt(); const int32_t inner_size = elem_cnt / batch / channels; if (alpha_size == 1) { diff --git a/oneflow/user/kernels/radix_sort_top_k_kernel.cu b/oneflow/user/kernels/radix_sort_top_k_kernel.cu index a903d82ce3c..29e69749ca3 100644 --- a/oneflow/user/kernels/radix_sort_top_k_kernel.cu +++ b/oneflow/user/kernels/radix_sort_top_k_kernel.cu @@ -25,19 +25,19 @@ template class TmpBufferManager final { public: OF_DISALLOW_COPY_AND_MOVE(TmpBufferManager); - TmpBufferManager(int32_t capacity, void* ptr, const ShapeView& in_shape) + TmpBufferManager(int64_t capacity, void* ptr, const ShapeView& in_shape) : capacity_{capacity}, sorted_in_elem_cnt_{in_shape.elem_cnt()}, indices_elem_cnt_{sorted_in_elem_cnt_}, sorted_indices_elem_cnt_{sorted_in_elem_cnt_} { - const int32_t sorted_in_aligned_bytes = GetCudaAlignedSize(sorted_in_elem_cnt_ * sizeof(T)); - const int32_t indices_aligned_bytes = GetCudaAlignedSize(indices_elem_cnt_ * sizeof(int32_t)); - const int32_t sorted_indices_aligned_bytes = indices_aligned_bytes; + const int64_t sorted_in_aligned_bytes = GetCudaAlignedSize(sorted_in_elem_cnt_ * sizeof(T)); + const int64_t indices_aligned_bytes = GetCudaAlignedSize(indices_elem_cnt_ * sizeof(int64_t)); + const int64_t sorted_indices_aligned_bytes = indices_aligned_bytes; sorted_in_ptr_ = reinterpret_cast(ptr); - indices_ptr_ = reinterpret_cast(reinterpret_cast(sorted_in_ptr_) + indices_ptr_ = reinterpret_cast(reinterpret_cast(sorted_in_ptr_) + sorted_in_aligned_bytes); sorted_indices_ptr_ = - reinterpret_cast(reinterpret_cast(indices_ptr_) + indices_aligned_bytes); + reinterpret_cast(reinterpret_cast(indices_ptr_) + indices_aligned_bytes); temp_storage_ptr_ = reinterpret_cast(reinterpret_cast(sorted_indices_ptr_) + sorted_indices_aligned_bytes); temp_storage_bytes_ = @@ -47,27 +47,27 @@ class TmpBufferManager final { ~TmpBufferManager() = default; T* SortedInPtr() const { return sorted_in_ptr_; } - int32_t* IndicesPtr() const { return indices_ptr_; } - int32_t* SortedIndicesPtr() const { return sorted_indices_ptr_; } + int64_t* IndicesPtr() const { return indices_ptr_; } + int64_t* SortedIndicesPtr() const { return sorted_indices_ptr_; } void* TempStoragePtr() const { return temp_storage_ptr_; } - int32_t TempStorageBytes() const { return temp_storage_bytes_; } + int64_t TempStorageBytes() const { return temp_storage_bytes_; } private: - int32_t capacity_; + int64_t capacity_; T* sorted_in_ptr_; - int32_t* indices_ptr_; - int32_t* sorted_indices_ptr_; + int64_t* indices_ptr_; + int64_t* sorted_indices_ptr_; void* temp_storage_ptr_; int64_t sorted_in_elem_cnt_; int64_t indices_elem_cnt_; int64_t sorted_indices_elem_cnt_; - int32_t temp_storage_bytes_; + int64_t temp_storage_bytes_; }; -__global__ void InitializeIndices(int32_t elem_cnt, int32_t* indices_ptr, int32_t instance_size) { +__global__ void InitializeIndices(int64_t elem_cnt, int64_t* indices_ptr, int64_t instance_size) { CUDA_1D_KERNEL_LOOP(i, elem_cnt) { indices_ptr[i] = i % instance_size; }; } @@ -86,13 +86,13 @@ class GpuRadixSortTopKKernel final : public user_op::OpKernel { if (in->shape().elem_cnt() == 0) { return; } user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - TmpBufferManager buf_manager(static_cast(tmp_buffer->shape().elem_cnt()), + TmpBufferManager buf_manager(static_cast(tmp_buffer->shape().elem_cnt()), tmp_buffer->mut_dptr(), in->shape()); - const int32_t elem_cnt = 
in->shape().elem_cnt(); - const int32_t instance_size = in->shape().At(in->shape().NumAxes() - 1); - const int32_t instance_num = elem_cnt / instance_size; - const int32_t k = std::min(ctx->Attr("k"), instance_size); + const int64_t elem_cnt = in->shape().elem_cnt(); + const int64_t instance_size = in->shape().At(in->shape().NumAxes() - 1); + const int64_t instance_num = elem_cnt / instance_size; + const int64_t k = std::min(static_cast(ctx->Attr("k")), instance_size); InitializeIndices<<stream()->As()->cuda_stream()>>>( elem_cnt, buf_manager.IndicesPtr(), instance_size); @@ -100,9 +100,9 @@ class GpuRadixSortTopKKernel final : public user_op::OpKernel { buf_manager.TempStoragePtr(), buf_manager.TempStorageBytes(), buf_manager.SortedInPtr(), buf_manager.SortedIndicesPtr(), ctx->stream()->As()->cuda_stream()); - OF_CUDA_CHECK(cudaMemcpy2DAsync(out->mut_dptr(), k * sizeof(int32_t), - buf_manager.SortedIndicesPtr(), instance_size * sizeof(int32_t), - k * sizeof(int32_t), instance_num, cudaMemcpyDefault, + OF_CUDA_CHECK(cudaMemcpy2DAsync(out->mut_dptr(), k * sizeof(int64_t), + buf_manager.SortedIndicesPtr(), instance_size * sizeof(int64_t), + k * sizeof(int64_t), instance_num, cudaMemcpyDefault, ctx->stream()->As()->cuda_stream())); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } @@ -116,19 +116,19 @@ class GpuRadixSortTopKKernel final : public user_op::OpKernel { && (user_op::HobDataType("in", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputShape("in", 0); \ - const int32_t elem_cnt = in_shape.elem_cnt(); \ - const int32_t instance_size = in_shape.dim_vec().back(); \ - const int32_t instance_num = elem_cnt / instance_size; \ + const int64_t elem_cnt = in_shape.elem_cnt(); \ + const int64_t instance_size = in_shape.dim_vec().back(); \ + const int64_t instance_num = elem_cnt / instance_size; \ \ /* Sorted In*/ \ - const int32_t sorted_in_aligned_bytes = GetCudaAlignedSize(elem_cnt * sizeof(dtype)); \ + const int64_t sorted_in_aligned_bytes = GetCudaAlignedSize(elem_cnt * sizeof(dtype)); \ /* Indices */ \ - const int32_t indices_aligned_bytes = GetCudaAlignedSize(elem_cnt * sizeof(int32_t)); \ + const int64_t indices_aligned_bytes = GetCudaAlignedSize(elem_cnt * sizeof(int64_t)); \ /* Sorted Indices */ \ - const int32_t sorted_indices_aligned_bytes = indices_aligned_bytes; \ + const int64_t sorted_indices_aligned_bytes = indices_aligned_bytes; \ /* CUB Temp Storage */ \ - int32_t temp_storage_bytes = \ - InferTempStorageForSortPairsDescending(instance_num, instance_size); \ + int64_t temp_storage_bytes = \ + InferTempStorageForSortPairsDescending(instance_num, instance_size); \ \ return sorted_in_aligned_bytes + indices_aligned_bytes + sorted_indices_aligned_bytes \ + temp_storage_bytes; \ diff --git a/oneflow/user/kernels/random_mask_like_kernel.h b/oneflow/user/kernels/random_mask_like_kernel.h index 45827b0bfb2..65289c84695 100644 --- a/oneflow/user/kernels/random_mask_like_kernel.h +++ b/oneflow/user/kernels/random_mask_like_kernel.h @@ -49,7 +49,8 @@ class RandomMaskLikeKernel final : public user_op::OpKernel, public user_op::Cud } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { const user_op::Tensor* like = ctx->Tensor4ArgNameAndIndex("like", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); 
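// NOTE: illustrative arithmetic for the radix_sort_top_k_kernel.cu hunks above, not part
// of this patch. The int32_t -> int64_t widening matters because the aligned byte counts
// are products of the element count and the element size:

const int64_t elem_cnt = in_shape.elem_cnt();  // can exceed 2^31 - 1 for large inputs
const int64_t sorted_in_aligned_bytes = GetCudaAlignedSize(elem_cnt * sizeof(float));
const int64_t indices_aligned_bytes = GetCudaAlignedSize(elem_cnt * sizeof(int64_t));
// With 32-bit types both products above could silently wrap, corrupting the partitioning
// of tmp_buffer into sorted_in / indices / sorted_indices / temp_storage; the index
// buffers themselves also become int64_t so sorted indices cannot overflow.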
int64_t elem_cnt = like->shape().elem_cnt(); diff --git a/oneflow/user/kernels/randperm_kernel.cpp b/oneflow/user/kernels/randperm_kernel.cpp index 356ac61b3a7..e0e8bea82e5 100644 --- a/oneflow/user/kernels/randperm_kernel.cpp +++ b/oneflow/user/kernels/randperm_kernel.cpp @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/core/common/data_type.h" #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/framework/random_generator.h" @@ -34,7 +34,8 @@ class CpuRandPermKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int32_t* output = out->mut_dptr(); const int32_t n = ctx->Attr("n"); diff --git a/oneflow/user/kernels/randperm_kernel.cu b/oneflow/user/kernels/randperm_kernel.cu index 7229eeaf8b9..cdd0d0abccd 100644 --- a/oneflow/user/kernels/randperm_kernel.cu +++ b/oneflow/user/kernels/randperm_kernel.cu @@ -20,7 +20,7 @@ limitations under the License. #include "oneflow/core/ep/include/stream.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/random_generator.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/user/kernels/arange_kernel_util.h" #include "oneflow/user/kernels/radix_sort.cuh" #include "oneflow/user/kernels/distributions/common.h" @@ -48,7 +48,8 @@ class GpuRandPermKernel final : public user_op::OpKernel { private: using user_op::OpKernel::Compute; - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); int32_t* output = out->mut_dptr(); const int32_t n = ctx->Attr("n"); diff --git a/oneflow/user/kernels/relu_kernel.cpp b/oneflow/user/kernels/relu_kernel.cpp index bf9124fc50b..bbc03a052a4 100644 --- a/oneflow/user/kernels/relu_kernel.cpp +++ b/oneflow/user/kernels/relu_kernel.cpp @@ -21,8 +21,8 @@ namespace oneflow { template std::unique_ptr NewReluPrimitive(Context* ctx) { - const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); - const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); + const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0); + const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kRelu, src->data_type(), dst->data_type()); } @@ -38,8 +38,8 @@ class ReluKernel final : public user_op::OpKernel, public user_op::CudaGraphSupp auto primitive = NewReluPrimitive(ctx); CHECK(primitive); - const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("in", 0); - user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("out", 0); + const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); + user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int64_t elem_cnt = x->shape().elem_cnt(); if (elem_cnt != 0) { diff --git 
a/oneflow/user/kernels/roi_align_kernel.cu b/oneflow/user/kernels/roi_align_kernel.cu index 55e2d089adc..45b5ea6f031 100644 --- a/oneflow/user/kernels/roi_align_kernel.cu +++ b/oneflow/user/kernels/roi_align_kernel.cu @@ -239,6 +239,7 @@ class RoIAlignKernel final : public user_op::OpKernel { void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x_blob = ctx->Tensor4ArgNameAndIndex("x", 0); const user_op::Tensor* rois_blob = ctx->Tensor4ArgNameAndIndex("rois", 0); + if (rois_blob->shape().elem_cnt() == 0) { return; } user_op::Tensor* y_blob = ctx->Tensor4ArgNameAndIndex("y", 0); const int32_t pooled_h = ctx->Attr("pooled_h"); const int32_t pooled_w = ctx->Attr("pooled_w"); diff --git a/oneflow/user/kernels/slice_kernel.cpp b/oneflow/user/kernels/slice_kernel.cpp index 913a8158e94..1df0078920b 100644 --- a/oneflow/user/kernels/slice_kernel.cpp +++ b/oneflow/user/kernels/slice_kernel.cpp @@ -18,14 +18,14 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/slice_util.h" #include "oneflow/core/kernel/kernel_util.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/core/kernel/cuda_graph_support.h" namespace oneflow { namespace { -const int SPLIT_AXIS_FOR_BROADCAST = -1; +const int SPLIT_AXIS_FOR_NON_SPLIT = -1; // [start, end) int64_t GetSizeInSlice(const int64_t start, const int64_t end, const int64_t step) { @@ -218,7 +218,7 @@ void WriteSlice(user_op::KernelComputeContext* ctx, const user_op::Tensor* src, const bool from_large_to_small) { const user_op::Tensor* large = from_large_to_small ? src : dst; const user_op::Tensor* small = from_large_to_small ? dst : src; - if (slice_ctx.split_axis != SPLIT_AXIS_FOR_BROADCAST) { + if (slice_ctx.split_axis != SPLIT_AXIS_FOR_NON_SPLIT) { CHECK_EQ(large->shape().At(slice_ctx.split_axis), slice_ctx.upper - slice_ctx.lower); } @@ -278,11 +278,11 @@ DEFINE_STATIC_SWITCH_FUNC( )); #undef MAKE_WRITE_SLICE_SWITCH_ENTRY -std::shared_ptr CreateSliceState(user_op::KernelInitContext* ctx, +std::shared_ptr CreateSliceCache(user_op::KernelCacheContext* ctx, const std::string& large_tensor_name) { if (ctx->parallel_ctx().parallel_num() == 1) { - // split_axis == SPLIT_AXIS_FOR_BROADCAST means the sbp attribute is broadcast instead of split - return std::make_shared>(SPLIT_AXIS_FOR_BROADCAST, 0, 0, 0); + // split_axis == SPLIT_AXIS_FOR_NON_SPLIT means the sbp attribute is not 'split' + return std::make_shared>(SPLIT_AXIS_FOR_NON_SPLIT, 0, 0, 0); } const cfg::SbpParallel& in_sbp = ctx->SbpParallel4ArgNameAndIndex(large_tensor_name, 0); if (in_sbp.has_split_parallel()) { @@ -292,12 +292,11 @@ std::shared_ptr CreateSliceState(user_op::KernelInitCont const int64_t split_dim_size = in_logical_desc->shape().At(split_axis); const int64_t parallel_id = ctx->parallel_ctx().parallel_id(); BalancedSplitter bs(split_dim_size, ctx->parallel_ctx().parallel_num()); - return std::make_shared>( + return std::make_shared>( split_axis, bs.At(parallel_id).begin(), bs.At(parallel_id).end(), split_dim_size); - } else if (in_sbp.has_broadcast_parallel()) { - return std::make_shared>(SPLIT_AXIS_FOR_BROADCAST, 0, 0, 0); + } else if (in_sbp.has_broadcast_parallel() || in_sbp.has_partial_sum_parallel()) { + return std::make_shared>(SPLIT_AXIS_FOR_NON_SPLIT, 0, 0, 0); } else { - // TODO(jianhao): support partialsum UNIMPLEMENTED(); } } @@ -308,25 +307,30 @@ class LogicalSliceKernel final : public user_op::OpKernel { 
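// NOTE: a condensed restatement of the CreateSliceCache hunk above with the template
// arguments spelled out; the reconstructed types and accessor names are assumptions, not
// part of the patch text. The sentinel is renamed because it now stands for every
// non-split sbp, and partial_sum no longer falls through to UNIMPLEMENTED():

std::shared_ptr<user_op::OpKernelCache> CreateSliceCache(user_op::KernelCacheContext* ctx,
                                                         const std::string& large_tensor_name) {
  if (ctx->parallel_ctx().parallel_num() == 1) {
    // single device: nothing is split
    return std::make_shared<OpKernelCacheWrapper<SliceContext>>(SPLIT_AXIS_FOR_NON_SPLIT, 0, 0, 0);
  }
  const cfg::SbpParallel& in_sbp = ctx->SbpParallel4ArgNameAndIndex(large_tensor_name, 0);
  if (in_sbp.has_split_parallel()) {
    // split: record the axis and this rank's [lower, upper) range of the split dimension
    const int64_t split_axis = in_sbp.split_parallel().axis();
    const int64_t split_dim_size =
        ctx->LogicalTensorDesc4ArgNameAndIndex(large_tensor_name, 0)->shape().At(split_axis);
    BalancedSplitter bs(split_dim_size, ctx->parallel_ctx().parallel_num());
    const int64_t parallel_id = ctx->parallel_ctx().parallel_id();
    return std::make_shared<OpKernelCacheWrapper<SliceContext>>(
        split_axis, bs.At(parallel_id).begin(), bs.At(parallel_id).end(), split_dim_size);
  } else if (in_sbp.has_broadcast_parallel() || in_sbp.has_partial_sum_parallel()) {
    // broadcast and partial_sum both mean "not split along any axis"
    return std::make_shared<OpKernelCacheWrapper<SliceContext>>(SPLIT_AXIS_FOR_NON_SPLIT, 0, 0, 0);
  } else {
    UNIMPLEMENTED();
    return nullptr;
  }
}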
LogicalSliceKernel() = default; ~LogicalSliceKernel() = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { const cfg::SbpParallel& x_sbp = ctx->SbpParallel4ArgNameAndIndex("x", 0); const cfg::SbpParallel& y_sbp = ctx->SbpParallel4ArgNameAndIndex("y", 0); if (ctx->parallel_ctx().parallel_num() > 1) { if (x_sbp.has_split_parallel()) { CHECK(y_sbp.has_partial_sum_parallel()); - } else { + } else if (x_sbp.has_broadcast_parallel()) { CHECK(y_sbp.has_broadcast_parallel()); + } else { + CHECK(x_sbp.has_partial_sum_parallel()); + CHECK(y_sbp.has_partial_sum_parallel()); } } - return CreateSliceState(ctx, "x"); + return CreateSliceCache(ctx, "x"); } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { user_op::Tensor* y_tensor = ctx->Tensor4ArgNameAndIndex("y", 0); const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); - const SliceContext& slice_ctx = dynamic_cast*>(state)->Get(); + const SliceContext& slice_ctx = + dynamic_cast*>(cache)->Get(); if (y_tensor->mem_case().has_host_mem()) { memset(y_tensor->mut_dptr(), 0, y_tensor->shape().elem_cnt() * GetSizeOfDataType(y_tensor->data_type())); @@ -352,20 +356,22 @@ class LogicalSliceAssignKernel final : public user_op::OpKernel { LogicalSliceAssignKernel() = default; ~LogicalSliceAssignKernel() = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { if (ctx->parallel_ctx().parallel_num() > 1) { const cfg::SbpParallel& value_sbp = ctx->SbpParallel4ArgNameAndIndex("value", 0); CHECK(value_sbp.has_broadcast_parallel()); } - return CreateSliceState(ctx, "ref"); + return CreateSliceCache(ctx, "ref"); } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* value_tensor = ctx->Tensor4ArgNameAndIndex("value", 0); user_op::Tensor* ref_tensor = ctx->Tensor4ArgNameAndIndex("ref", 0); - const SliceContext& slice_ctx = dynamic_cast*>(state)->Get(); + const SliceContext& slice_ctx = + dynamic_cast*>(cache)->Get(); SwitchWriteSlice(SwitchCase(value_tensor->shape().NumAxes(), value_tensor->data_type()), ctx, value_tensor, ref_tensor, slice_ctx, false); } diff --git a/oneflow/user/kernels/sparse_cross_entropy_kernel.cpp b/oneflow/user/kernels/sparse_cross_entropy_kernel.cpp index 602bff43aad..78ed4c65513 100644 --- a/oneflow/user/kernels/sparse_cross_entropy_kernel.cpp +++ b/oneflow/user/kernels/sparse_cross_entropy_kernel.cpp @@ -23,10 +23,10 @@ namespace user_op { namespace { -class SparseCrossEntropyOpKernelState final : public user_op::OpKernelState { +class SparseCrossEntropyOpKernelCache final : public user_op::OpKernelCache { public: - SparseCrossEntropyOpKernelState(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} - ~SparseCrossEntropyOpKernelState() override = default; + SparseCrossEntropyOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} + ~SparseCrossEntropyOpKernelCache() override = default; int64_t lower() const { return lower_; } 
int64_t upper() const { return upper_; } @@ -67,8 +67,8 @@ class SparseCrossEntropyMsKernel final : public user_op::OpKernel { SparseCrossEntropyMsKernel() = default; ~SparseCrossEntropyMsKernel() = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { if (ctx->parallel_ctx().parallel_num() > 1) { const cfg::NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("prediction", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); @@ -77,15 +77,16 @@ class SparseCrossEntropyMsKernel final : public user_op::OpKernel { const int64_t class_axis = prediction_logical_desc->shape().NumAxes() - 1; TensorSliceView view = GetTensorSliceView4ParallelId( hierarchy, nd_sbp, prediction_logical_desc->shape(), ctx->parallel_ctx().parallel_id()); - return std::make_shared(view.At(class_axis).begin(), + return std::make_shared(view.At(class_axis).begin(), view.At(class_axis).end()); } else { - return std::shared_ptr(nullptr); + return nullptr; } } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex("prediction", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); @@ -94,11 +95,11 @@ class SparseCrossEntropyMsKernel final : public user_op::OpKernel { const int64_t num_classes = prediction->shape().elem_cnt() / num_instances; const int64_t depth = ctx->Attr("depth"); int64_t lower_bound = 0; - if (state != nullptr) { - auto* kernel_state = dynamic_cast(state); - CHECK_NOTNULL(kernel_state); - CHECK_EQ(num_classes, kernel_state->upper() - kernel_state->lower()); - lower_bound = kernel_state->lower(); + if (cache != nullptr) { + auto* kernel_cache = dynamic_cast(cache); + CHECK_NOTNULL(kernel_cache); + CHECK_EQ(num_classes, kernel_cache->upper() - kernel_cache->lower()); + lower_bound = kernel_cache->lower(); } Memset(ctx->stream(), out->mut_dptr(), 0, out->shape().elem_cnt() * GetSizeOfDataType(out->data_type())); @@ -171,8 +172,8 @@ class SparseCrossEntropyMsGradKernel final : public user_op::OpKernel { SparseCrossEntropyMsGradKernel() = default; ~SparseCrossEntropyMsGradKernel() = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { if (ctx->parallel_ctx().parallel_num() > 1) { const cfg::NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("prediction", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); @@ -181,15 +182,16 @@ class SparseCrossEntropyMsGradKernel final : public user_op::OpKernel { const int64_t class_axis = prediction_logical_desc->shape().NumAxes() - 1; TensorSliceView view = GetTensorSliceView4ParallelId( hierarchy, nd_sbp, prediction_logical_desc->shape(), ctx->parallel_ctx().parallel_id()); - return std::make_shared(view.At(class_axis).begin(), + return std::make_shared(view.At(class_axis).begin(), view.At(class_axis).end()); } else { - return std::shared_ptr(nullptr); + return nullptr; } } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const 
user_op::OpKernelCache* cache) const override { const user_op::Tensor* prediction = ctx->Tensor4ArgNameAndIndex("prediction", 0); const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); @@ -199,11 +201,11 @@ class SparseCrossEntropyMsGradKernel final : public user_op::OpKernel { const int64_t num_classes = prediction->shape().elem_cnt() / num_instances; const int64_t depth = ctx->Attr("depth"); int64_t lower_bound = 0; - if (state != nullptr) { - auto* kernel_state = dynamic_cast(state); - CHECK_NOTNULL(kernel_state); - CHECK_EQ(num_classes, kernel_state->upper() - kernel_state->lower()); - lower_bound = kernel_state->lower(); + if (cache != nullptr) { + auto* kernel_cache = dynamic_cast(cache); + CHECK_NOTNULL(kernel_cache); + CHECK_EQ(num_classes, kernel_cache->upper() - kernel_cache->lower()); + lower_bound = kernel_cache->lower(); } size_t prediction_diff_bytes_size = prediction_diff->shape().elem_cnt() * GetSizeOfDataType(prediction_diff->data_type()); diff --git a/oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cpp b/oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cpp index 3283cf8de83..036390170a3 100644 --- a/oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cpp +++ b/oneflow/user/kernels/sparse_softmax_cross_entropy_kernel.cpp @@ -39,11 +39,11 @@ auto LogSoftmaxPrimitiveExists() { }); } -class SparseSoftmaxCrossEntropyOpKernelState final : public user_op::OpKernelState { +class SparseSoftmaxCrossEntropyOpKernelCache final : public user_op::OpKernelCache { public: - SparseSoftmaxCrossEntropyOpKernelState(int64_t lower, int64_t upper) + SparseSoftmaxCrossEntropyOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} - ~SparseSoftmaxCrossEntropyOpKernelState() override = default; + ~SparseSoftmaxCrossEntropyOpKernelCache() override = default; int64_t lower() const { return lower_; } int64_t upper() const { return upper_; } @@ -164,8 +164,8 @@ class SparseSoftmaxCrossEntropyMsGradKernel final : public user_op::OpKernel { public: SparseSoftmaxCrossEntropyMsGradKernel() = default; ~SparseSoftmaxCrossEntropyMsGradKernel() = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { if (ctx->parallel_ctx().parallel_num() > 1) { const cfg::NdSbp& nd_sbp = ctx->NdSbp4ArgNameAndIndex("prob", 0); const Shape& hierarchy = *ctx->parallel_desc().hierarchy(); @@ -173,15 +173,16 @@ class SparseSoftmaxCrossEntropyMsGradKernel final : public user_op::OpKernel { const int64_t class_axis = prob_logical_desc->shape().NumAxes() - 1; TensorSliceView view = GetTensorSliceView4ParallelId( hierarchy, nd_sbp, prob_logical_desc->shape(), ctx->parallel_ctx().parallel_id()); - return std::make_shared(view.At(class_axis).begin(), + return std::make_shared(view.At(class_axis).begin(), view.At(class_axis).end()); } else { - return std::shared_ptr(nullptr); + return nullptr; } } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* label = ctx->Tensor4ArgNameAndIndex("label", 0); const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); const user_op::Tensor* prob = ctx->Tensor4ArgNameAndIndex("prob", 0); @@ -191,11 +192,11 @@ class 
SparseSoftmaxCrossEntropyMsGradKernel final : public user_op::OpKernel { const int64_t num_classes = prob->shape().elem_cnt() / num_instances; const int64_t depth = ctx->Attr("depth"); int64_t lower_bound = 0; - if (state != nullptr) { - auto* kernel_state = dynamic_cast(state); - CHECK_NOTNULL(kernel_state); - CHECK_EQ(num_classes, kernel_state->upper() - kernel_state->lower()); - lower_bound = kernel_state->lower(); + if (cache != nullptr) { + auto* kernel_cache = dynamic_cast(cache); + CHECK_NOTNULL(kernel_cache); + CHECK_EQ(num_classes, kernel_cache->upper() - kernel_cache->lower()); + lower_bound = kernel_cache->lower(); } SparseCrossEntropyKernelUtil::ComputeDiffWithSoftmax( ctx->stream(), prediction_diff->shape().elem_cnt(), num_classes, depth, lower_bound, diff --git a/oneflow/user/kernels/sqrt_square_sum_kernel.cpp b/oneflow/user/kernels/sqrt_square_sum_kernel.cpp new file mode 100644 index 00000000000..4c741594e4b --- /dev/null +++ b/oneflow/user/kernels/sqrt_square_sum_kernel.cpp @@ -0,0 +1,69 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include +#include "oneflow/core/framework/framework.h" +#include "oneflow/user/kernels/sqrt_square_sum_kernel_util.h" +#include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/kernel/cuda_graph_support.h" + +namespace oneflow { + +namespace user_op { + +int64_t getThreadNumBlocks(int64_t n) { + int64_t num_blocks = 1; +#ifdef WITH_CUDA + num_blocks = BlocksNum4ThreadsNum(n); +#endif + return num_blocks; +} + +template +class SqrtSquareSumKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport { + public: + SqrtSquareSumKernel() = default; + ~SqrtSquareSumKernel() override = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); + user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); + user_op::Tensor* tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + + SqrtSquareSumKernelUtil::SqrtSquareSum( + ctx->stream(), x->shape().elem_cnt(), x->dptr(), y->mut_dptr(), tmp->mut_dptr()); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_SQUARE_SUM_KERNEL(device, dtype) \ + REGISTER_USER_KERNEL("sqrt_square_sum") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == device) \ + && (user_op::HobDataType("y", 0) == OF_PP_PAIR_SECOND(dtype))) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ + const auto& x_shape = ctx->InputTensorDesc("x", 0).shape(); \ + const int32_t num_blocks = getThreadNumBlocks(x_shape.Count(0)); \ + int64_t tmp_buffer_size = num_blocks; \ + return tmp_buffer_size * sizeof(OF_PP_PAIR_FIRST(dtype)); \ + }); + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_SQUARE_SUM_KERNEL, DEVICE_TYPE_SEQ, + FLOATING_DATA_TYPE_SEQ) + +} // namespace user_op + +} // namespace oneflow diff --git a/oneflow/user/kernels/sqrt_square_sum_kernel_util.cpp 
b/oneflow/user/kernels/sqrt_square_sum_kernel_util.cpp new file mode 100644 index 00000000000..abafd31697a --- /dev/null +++ b/oneflow/user/kernels/sqrt_square_sum_kernel_util.cpp @@ -0,0 +1,34 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/user/kernels/sqrt_square_sum_kernel_util.h" + +namespace oneflow { + +template +struct SqrtSquareSumKernelUtil { + static void SqrtSquareSum(ep::Stream* stream, int64_t n, const T* x, T* y, T* tmp) { + T sum = 0; + FOR_RANGE(int64_t, i, 0, n) { sum += x[i] * x[i]; } + *y = std::sqrt(sum); + } +}; + +#define INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU(type_cpp, type_proto) \ + template struct SqrtSquareSumKernelUtil; +OF_PP_FOR_EACH_TUPLE(INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU, FLOATING_DATA_TYPE_SEQ); +#undef INSTANTIATE_SQUARE_SUM_KERNEL_UTIL_CPU + +} // namespace oneflow diff --git a/oneflow/user/kernels/sqrt_square_sum_kernel_util.cu b/oneflow/user/kernels/sqrt_square_sum_kernel_util.cu new file mode 100644 index 00000000000..a93f48e0ee2 --- /dev/null +++ b/oneflow/user/kernels/sqrt_square_sum_kernel_util.cu @@ -0,0 +1,82 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/user/kernels/sqrt_square_sum_kernel_util.h" +#include "oneflow/core/cuda/atomic.cuh" +#include "oneflow/core/ep/cuda/cuda_stream.h" +#include + +namespace oneflow { + +namespace { + +template +__global__ void SqrtSquareSumForOneThreadBlock(int64_t n, const T* x, T* y) { + T t_sum = 0; + CUDA_1D_KERNEL_LOOP(i, n) { t_sum += x[i] * x[i]; } + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + T b_sum = BlockReduce(temp_storage).Sum(t_sum); + if (threadIdx.x == 0) { *y = sqrt(b_sum); } +} + +template +__global__ void SqrtSumForMultiThreadBlock(int64_t n, const T* x, T* y) { + T t_sum = 0; + CUDA_1D_KERNEL_LOOP(i, n) { t_sum += x[i]; } + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + T b_sum = BlockReduce(temp_storage).Sum(t_sum); + if (threadIdx.x == 0) { *y = sqrt(b_sum); } +} + +template +__global__ void SquareSumForMultiThreadBlock(int64_t n, const T* x, T* tmp) { + T t_sum = 0; + CUDA_1D_KERNEL_LOOP(i, n) { t_sum += x[i] * x[i]; } + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + T b_sum = BlockReduce(temp_storage).Sum(t_sum); + if (threadIdx.x == 0) { tmp[blockIdx.x] = b_sum; } +} + +} // namespace + +template +struct SqrtSquareSumKernelUtil { + static void SqrtSquareSum(ep::Stream* stream, int64_t n, const T* x, T* y, T* tmp) { + const int32_t num_blocks = BlocksNum4ThreadsNum(n); + CHECK_GE(num_blocks, 0); + if (num_blocks == 1) { + SqrtSquareSumForOneThreadBlock + <<<1, kCudaThreadsNumPerBlock, 0, stream->As()->cuda_stream()>>>(n, x, y); + } else { + Memset(stream, y, 0, sizeof(T)); + SquareSumForMultiThreadBlock + <<As()->cuda_stream()>>>( + n, x, tmp); + SqrtSumForMultiThreadBlock + <<<1, kCudaThreadsNumPerBlock, 0, stream->As()->cuda_stream()>>>( + num_blocks, tmp, y); + } + } +}; + +#define INSTANTIATE_SQRT_SQUARE_SUM_KERNEL_UTIL_CUDA(type_cpp, type_proto) \ + template struct SqrtSquareSumKernelUtil; +OF_PP_FOR_EACH_TUPLE(INSTANTIATE_SQRT_SQUARE_SUM_KERNEL_UTIL_CUDA, FLOATING_DATA_TYPE_SEQ); +#undef INSTANTIATE_SQRT_SQUARE_SUM_KERNEL_UTIL_CUDA + +} // namespace oneflow diff --git a/oneflow/user/kernels/sqrt_square_sum_kernel_util.h b/oneflow/user/kernels/sqrt_square_sum_kernel_util.h new file mode 100644 index 00000000000..fb905cb6124 --- /dev/null +++ b/oneflow/user/kernels/sqrt_square_sum_kernel_util.h @@ -0,0 +1,30 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
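For readers unfamiliar with the reduction pattern in the .cu file above: each block reduces its partial sum of squares with cub::BlockReduce, and for large inputs a second single-block pass reduces the per-block partials before taking the square root. A standalone, single-block sketch of that pattern (toy code, not part of the patch; kCudaThreadsNumPerBlock is replaced by a local constant):

#include <cub/cub.cuh>
#include <cstdio>

constexpr int kBlockThreads = 256;

__global__ void SqrtSquareSumOneBlock(int n, const float* x, float* y) {
  float t_sum = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) { t_sum += x[i] * x[i]; }
  typedef cub::BlockReduce<float, kBlockThreads> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  float b_sum = BlockReduce(temp_storage).Sum(t_sum);  // block-wide sum of the per-thread partials
  if (threadIdx.x == 0) { *y = sqrtf(b_sum); }
}

int main() {
  const int n = 1024;
  float* x = nullptr;
  float* y = nullptr;
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&y, sizeof(float));
  for (int i = 0; i < n; ++i) { x[i] = 1.f; }  // sum of squares = 1024
  SqrtSquareSumOneBlock<<<1, kBlockThreads>>>(n, x, y);
  cudaDeviceSynchronize();
  printf("%f\n", *y);  // expected 32.0
  cudaFree(x);
  cudaFree(y);
  return 0;
}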
+*/ +#ifndef ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_ +#define ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_ + +#include "oneflow/core/kernel/kernel_util.h" + +namespace oneflow { + +template +struct SqrtSquareSumKernelUtil { + static void SqrtSquareSum(ep::Stream* stream, int64_t n, const T* x, T* y, T* tmp); +}; + +} // namespace oneflow + +#endif // ONEFLOW_USER_KERNELS_SQUARE_SUM_KERNEL_UTIL_H_ diff --git a/oneflow/user/kernels/square_sum_kernel.cpp b/oneflow/user/kernels/square_sum_kernel.cpp index 7b3da73b897..96fe61da092 100644 --- a/oneflow/user/kernels/square_sum_kernel.cpp +++ b/oneflow/user/kernels/square_sum_kernel.cpp @@ -29,7 +29,7 @@ class SquareSumKernel final : public user_op::OpKernel, public user_op::CudaGrap ~SquareSumKernel() override = default; private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx) const override { const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); diff --git a/oneflow/user/kernels/stateful_local_opkernel.cpp b/oneflow/user/kernels/stateful_local_opkernel.cpp index cce8b992bc1..4fcf4c866ad 100644 --- a/oneflow/user/kernels/stateful_local_opkernel.cpp +++ b/oneflow/user/kernels/stateful_local_opkernel.cpp @@ -202,9 +202,10 @@ class LocalUserKernelRegContext final : public user_op::KernelRegContext { } }; -class LocalUserKernelInitContext final : public user_op::KernelInitContext { +class LocalUserKernelInitAndCacheContext final : public user_op::KernelInitContext, + public user_op::KernelCacheContext { public: - explicit LocalUserKernelInitContext( + explicit LocalUserKernelInitAndCacheContext( DeviceCtx* device_ctx, const std::string& device_tag, const user_op::UserOpConfWrapper* user_op_conf, const std::shared_ptr& input_arg_tuple, @@ -218,7 +219,7 @@ class LocalUserKernelInitContext final : public user_op::KernelInitContext { composed_attrs_(composed_attrs) { base_ctx_.Update(inputs, outputs, consistent_tensor_infer_result); } - ~LocalUserKernelInitContext() override = default; + ~LocalUserKernelInitAndCacheContext() override = default; ep::Stream* stream() override { CHECK(device_ctx_); @@ -463,23 +464,31 @@ Maybe StatefulLocalOpKernel::ChooseOpKernel( return Maybe::Ok(); } -void StatefulLocalOpKernel::TryInitOpKernelState( +void StatefulLocalOpKernel::TryInitOpKernelStateAndCache( const user_op::OpKernel* op_kernel, DeviceCtx* device_ctx, EagerBlobObjectListRawPtr inputs, EagerBlobObjectListRawPtr outputs, ConsistentTensorInferResultRawPtr consistent_tensor_infer_result, - user_op::OpKernelState** state) { - auto it = op_kernel_state_map_.find(op_kernel); - if (it != op_kernel_state_map_.end()) { - *state = it->second.get(); - return; - } - - LocalUserKernelInitContext init_ctx( + user_op::OpKernelState** state, user_op::OpKernelCache** cache) { + LocalUserKernelInitAndCacheContext init_and_cache_ctx( device_ctx, op_conf_->device_tag(), user_op_conf_.get(), input_arg_tuple_, output_arg_tuple_, inputs, outputs, consistent_tensor_infer_result, composed_attrs_for_scheduler_thread()); - auto created_state = op_kernel->CreateOpKernelState(&init_ctx); - op_kernel_state_map_.emplace(op_kernel, created_state); - *state = created_state.get(); + if (state != nullptr) { + auto it = op_kernel_state_map_.find(op_kernel); + if (it != op_kernel_state_map_.end()) { + *state = it->second.get(); + } else { + auto created_state = op_kernel->CreateOpKernelState(&init_and_cache_ctx); + 
op_kernel_state_map_.emplace(op_kernel, created_state); + *state = created_state.get(); + } + } + + { + auto& cache_in_map = op_kernel_cache_map_[op_kernel]; + op_kernel->InitOpKernelCache(&init_and_cache_ctx, user_op::OpKernelCache::kAllMayChanged, + &cache_in_map); + *cache = cache_in_map.get(); + } } const user_op::InferTmpSizeFn& StatefulLocalOpKernel::GetInferTmpSizeFn( diff --git a/oneflow/user/kernels/stateful_local_opkernel.h b/oneflow/user/kernels/stateful_local_opkernel.h index 7a20c42a142..1e6e46ca6c0 100644 --- a/oneflow/user/kernels/stateful_local_opkernel.h +++ b/oneflow/user/kernels/stateful_local_opkernel.h @@ -37,7 +37,7 @@ namespace one { class LocalUserKernelBaseContext; class LocalUserKernelRegContext; -class LocalUserKernelInitContext; +class LocalUserKernelInitAndCacheContext; class LocalUserOpInferContext; class ConsistentTensorInferResult; @@ -229,8 +229,7 @@ class LocalUserOpInferContext : public user_op::InferContext { user_op::TensorDesc* OutputTensorDesc(const std::string& arg_name, int32_t index) override { return TensorDesc4ArgNameAndIndex(arg_name, index); } - user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, - int32_t index) override; + user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index); const Shape& InputShape(const std::string& arg_name, int32_t index) const override { return *const_cast(this)->Shape4ArgNameAndIndex(arg_name, index); } @@ -428,10 +427,11 @@ class StatefulLocalOpKernel final { user_op::TensorDescInferFn TensorDescInferFn() const; user_op::DataTypeInferFn DataTypeInferFn() const; - void TryInitOpKernelState(const user_op::OpKernel* op_kernel, DeviceCtx* device_ctx, - EagerBlobObjectListRawPtr inputs, EagerBlobObjectListRawPtr outputs, - ConsistentTensorInferResultRawPtr consistent_tensor_infer_result, - user_op::OpKernelState** state); + void TryInitOpKernelStateAndCache( + const user_op::OpKernel* op_kernel, DeviceCtx* device_ctx, EagerBlobObjectListRawPtr inputs, + EagerBlobObjectListRawPtr outputs, + ConsistentTensorInferResultRawPtr consistent_tensor_infer_result, + user_op::OpKernelState** state, user_op::OpKernelCache** cache); vm::EagerBlobObject* mut_temp_blob_object(); @@ -463,6 +463,7 @@ class StatefulLocalOpKernel final { DataType_MAX> dtype2cached_kernels_; HashMap> op_kernel_state_map_; + HashMap> op_kernel_cache_map_; HashMap infer_tmp_size_fn_map_; std::unique_ptr tmp_blob_object_; std::vector input_tuple_indexes4const_ibns_; diff --git a/oneflow/user/kernels/test_kernels.cpp b/oneflow/user/kernels/test_kernels.cpp index 19483b65d5c..580d729bc2e 100644 --- a/oneflow/user/kernels/test_kernels.cpp +++ b/oneflow/user/kernels/test_kernels.cpp @@ -17,7 +17,7 @@ limitations under the License. 
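To summarize the interface that the stateful_local_opkernel changes above give kernel authors: state keeps its create-once semantics via CreateOpKernelState, while the cache is (re)built through InitOpKernelCache and handed to the new three-argument Compute. A minimal hypothetical kernel written against the signatures visible in this patch (op and member names are invented; template arguments such as Maybe<void> and shared_ptr<...> are spelled out here although the patch text elides them; this is a sketch, not a buildable kernel):

class MyRangeCache final : public user_op::OpKernelCache {
 public:
  explicit MyRangeCache(int64_t lower) : lower_(lower) {}
  int64_t lower() const { return lower_; }

 private:
  const int64_t lower_;
};

class MyKernel final : public user_op::OpKernel {
 public:
  std::shared_ptr<user_op::OpKernelCache> InitOpKernelCache(
      user_op::KernelCacheContext* ctx) const override {
    // Re-run whenever the cached information may have changed; return nullptr when unused.
    if (ctx->parallel_ctx().parallel_num() > 1) { return std::make_shared<MyRangeCache>(0); }
    return nullptr;
  }

 private:
  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*,
               const user_op::OpKernelCache* cache) const override {
    int64_t lower_bound = 0;
    if (cache != nullptr) {
      const auto* range_cache = dynamic_cast<const MyRangeCache*>(cache);
      CHECK_NOTNULL(range_cache);
      lower_bound = range_cache->lower();
    }
    // ... kernel body that offsets indices by lower_bound ...
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};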
#include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" #include "oneflow/core/kernel/random_generator.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" #include "oneflow/core/ep/include/primitive/fill.h" namespace oneflow { @@ -31,8 +31,8 @@ class ReluKernel final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0); - user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); + const user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("x", 0); + user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); CHECK_NOTNULL(tmp); NewKernelUtil::Relu(ctx->stream(), in_blob->shape().elem_cnt(), @@ -64,7 +64,7 @@ REGISTER_USER_KERNEL("ccrelu") .SetInferTmpSizeFn([](user_op::InferContext*) { return 10; }) .SetInplaceProposalFn([](const user_op::InferContext&, user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { - OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, true)); + OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "x", 0, true)); return Maybe::Ok(); }); @@ -287,7 +287,8 @@ class TestRandomSourceKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { auto* random_generator = dynamic_cast>*>(state); user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); diff --git a/oneflow/user/kernels/top_k_kernel.cpp b/oneflow/user/kernels/top_k_kernel.cpp index 1ef0370828c..3bb825ff077 100644 --- a/oneflow/user/kernels/top_k_kernel.cpp +++ b/oneflow/user/kernels/top_k_kernel.cpp @@ -22,22 +22,22 @@ namespace oneflow { namespace { template -void ComputeTopOne(const T* in_ptr, const Range& range, int32_t instance_size, int32_t* out_ptr) { - FOR_RANGE(int32_t, i, range.begin(), range.end()) { +void ComputeTopOne(const T* in_ptr, const Range& range, int64_t instance_size, int64_t* out_ptr) { + FOR_RANGE(int64_t, i, range.begin(), range.end()) { const T* in_ptr_i = in_ptr + i * instance_size; out_ptr[i] = std::distance(in_ptr_i, std::max_element(in_ptr_i, in_ptr_i + instance_size)); } } template -void ComputeTopK(const T* in_ptr, int32_t* indices_ptr, const Range& range, int32_t instance_size, - int32_t k, bool sorted, int32_t* out_ptr) { - FOR_RANGE(int32_t, i, range.begin(), range.end()) { - const int32_t offset = i * instance_size; +void ComputeTopK(const T* in_ptr, int64_t* indices_ptr, const Range& range, int64_t instance_size, + int64_t k, bool sorted, int64_t* out_ptr) { + FOR_RANGE(int64_t, i, range.begin(), range.end()) { + const int64_t offset = i * instance_size; const T* in_ptr_i = in_ptr + offset; - int32_t* indices_ptr_i = indices_ptr + offset; + int64_t* indices_ptr_i = indices_ptr + offset; std::iota(indices_ptr_i, indices_ptr_i + instance_size, 0); - auto comp = [&](const int32_t lhs, const int32_t rhs) { + auto comp = [&](const int64_t lhs, const int64_t rhs) { const T l = in_ptr_i[lhs]; const T r = in_ptr_i[rhs]; if (l == r) { @@ -53,12 +53,13 @@ void ComputeTopK(const T* in_ptr, int32_t* indices_ptr, const Range& range, int3 } template -void CpuTopK(ep::Stream* /*stream*/, const T* in_ptr, int32_t* indices_ptr, int32_t instance_num, - 
int32_t instance_size, int32_t k, bool sorted, int32_t* out_ptr) { - const int32_t num_thread = std::min(instance_num, Global::Get()->thread_num()); +void CpuTopK(ep::Stream* /*stream*/, const T* in_ptr, int64_t* indices_ptr, int64_t instance_num, + int64_t instance_size, int64_t k, bool sorted, int64_t* out_ptr) { + const int64_t num_thread = + std::min(instance_num, static_cast(Global::Get()->thread_num())); const BalancedSplitter bs(instance_num, num_thread); BlockingCounter bc(num_thread); - FOR_RANGE(int32_t, thread_id, 0, num_thread) { + FOR_RANGE(int64_t, thread_id, 0, num_thread) { const Range range = bs.At(thread_id); Global::Get()->AddWork([=, &bc]() { if (k == 1) { @@ -87,12 +88,12 @@ class TopKCpuKernel final : public user_op::OpKernel { user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); - const int32_t instance_size = in->shape().At(in->shape().NumAxes() - 1); - const int32_t instance_num = in->shape().elem_cnt() / instance_size; - const int32_t k = std::min(ctx->Attr("k"), instance_size); - int32_t* indices_ptr = tmp_buffer ? tmp_buffer->mut_dptr() : nullptr; + const int64_t instance_size = in->shape().At(in->shape().NumAxes() - 1); + const int64_t instance_num = in->shape().elem_cnt() / instance_size; + const int64_t k = std::min(static_cast(ctx->Attr("k")), instance_size); + int64_t* indices_ptr = tmp_buffer ? tmp_buffer->mut_dptr() : nullptr; CpuTopK(ctx->stream(), in->dptr(), indices_ptr, instance_num, instance_size, k, - ctx->Attr("sorted"), out->mut_dptr()); + ctx->Attr("sorted"), out->mut_dptr()); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; @@ -104,7 +105,7 @@ class TopKCpuKernel final : public user_op::OpKernel { && (user_op::HobDataType("in", 0) == GetDataType::value)) \ .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ const Shape& in_shape = ctx->InputShape("in", 0); \ - return ctx->Attr("k") > 1 ? in_shape.elem_cnt() * sizeof(int32_t) : 0; \ + return ctx->Attr("k") > 1 ? in_shape.elem_cnt() * sizeof(int64_t) : 0; \ }); REGISTER_CPU_TOP_K_KERNEL(float) diff --git a/oneflow/user/kernels/transpose_kernel.cpp b/oneflow/user/kernels/transpose_kernel.cpp index bf1628119a7..f8438fbc102 100644 --- a/oneflow/user/kernels/transpose_kernel.cpp +++ b/oneflow/user/kernels/transpose_kernel.cpp @@ -14,12 +14,22 @@ See the License for the specific language governing permissions and limitations under the License. 
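The top_k hunks above widen the index and size types from int32_t to int64_t; the selection logic itself is unchanged. A standalone sketch of the per-row path with the widened indices (std::partial_sort and the tie-break body are stand-ins here, since the exact lines fall outside the hunk context):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  const std::vector<float> row = {0.1f, 0.9f, 0.4f, 0.9f, 0.2f};
  const int64_t instance_size = static_cast<int64_t>(row.size());
  const int64_t k = 3;
  std::vector<int64_t> indices(instance_size);
  std::iota(indices.begin(), indices.end(), 0);
  auto comp = [&](const int64_t lhs, const int64_t rhs) {
    const float l = row[lhs];
    const float r = row[rhs];
    if (l == r) { return lhs < rhs; }  // assumed tie-break: keep the smaller index first
    return l > r;
  };
  std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), comp);
  for (int64_t i = 0; i < k; ++i) { std::cout << indices[i] << ' '; }  // prints: 1 3 2
  std::cout << '\n';
  return 0;
}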
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/kernel/kernel_util.h" #include "oneflow/core/kernel/cuda_graph_support.h" #include "oneflow/core/ep/include/primitive/permute.h" namespace oneflow { namespace user_op { +namespace { +bool IsIdentity(const std::vector& perm) { + for (auto i = 0; i < perm.size(); i++) { + if (perm[i] != i) { return false; } + } + return true; +} +} // namespace + template std::unique_ptr NewPermutePrimitive(Context* ctx) { const int64_t num_dims = ctx->TensorDesc4ArgNameAndIndex("output", 0)->shape().NumAxes(); @@ -46,9 +56,18 @@ class TransposeKernel final : public OpKernel, public user_op::CudaGraphSupport const int64_t* src_dims = in_shape.ptr(); int64_t elem_cnt = tensor_out->shape().elem_cnt(); + if (elem_cnt != 0) { - primitive->Launch(ctx->stream(), dtype, num_dims, src_dims, tensor_in->dptr(), perm.data(), - tensor_out->mut_dptr()); + if (IsIdentity(perm)) { + // if permute vector is 0,1,...,n, do data copy directly + AutoMemcpy(ctx->stream(), tensor_out->mut_dptr(), tensor_in->dptr(), + elem_cnt * GetSizeOfDataType(dtype), tensor_out->mem_case(), + tensor_in->mem_case()); + } else { + primitive->Launch(ctx->stream(), dtype, num_dims, src_dims, tensor_in->dptr(), perm.data(), + tensor_out->mut_dptr()); + } + } else { // For 0-d Tensor return; diff --git a/oneflow/user/kernels/unpack_kernel.cpp b/oneflow/user/kernels/unpack_kernel.cpp index 8b21a85515d..82b85f4acf3 100644 --- a/oneflow/user/kernels/unpack_kernel.cpp +++ b/oneflow/user/kernels/unpack_kernel.cpp @@ -15,7 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/kernel/new_kernel_util.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { @@ -34,7 +34,8 @@ class UnpackKernel final : public user_op::OpKernel { } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state, + const user_op::OpKernelCache*) const override { const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); CHECK_GT(in->shape().NumAxes(), 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); diff --git a/oneflow/user/kernels/unsorted_segment_sum_kernel.cpp b/oneflow/user/kernels/unsorted_segment_sum_kernel.cpp index e48541173a3..af0b2ba2d1f 100644 --- a/oneflow/user/kernels/unsorted_segment_sum_kernel.cpp +++ b/oneflow/user/kernels/unsorted_segment_sum_kernel.cpp @@ -40,10 +40,10 @@ void CheckNdSbp(const Shape& hierarchy, int64_t sum_axis, const cfg::NdSbp& segm } } -class UnsortedSegmentSumOpKernelState final : public user_op::OpKernelState { +class UnsortedSegmentSumOpKernelCache final : public user_op::OpKernelCache { public: - UnsortedSegmentSumOpKernelState(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} - ~UnsortedSegmentSumOpKernelState() override = default; + UnsortedSegmentSumOpKernelCache(int64_t lower, int64_t upper) : lower_(lower), upper_(upper) {} + ~UnsortedSegmentSumOpKernelCache() override = default; int64_t lower() const { return lower_; } int64_t upper() const { return upper_; } @@ -53,8 +53,8 @@ class UnsortedSegmentSumOpKernelState final : public user_op::OpKernelState { const int64_t upper_; }; -std::shared_ptr CreateUnsortedSegmentSumOpKernelState( - user_op::KernelInitContext* ctx) { +std::shared_ptr CreateUnsortedSegmentSumOpKernelCache( + user_op::KernelCacheContext* ctx) { if 
(ctx->parallel_ctx().parallel_num() > 1) { const auto axis = ctx->Attr("axis"); const cfg::NdSbp& out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); @@ -64,10 +64,10 @@ std::shared_ptr CreateUnsortedSegmentSumOpKernelState( const TensorDesc* out_logical_desc = ctx->LogicalTensorDesc4ArgNameAndIndex("out", 0); TensorSliceView view = GetTensorSliceView4ParallelId( hierarchy, out_nd_sbp, out_logical_desc->shape(), ctx->parallel_ctx().parallel_id()); - return std::make_shared(view.At(axis).begin(), + return std::make_shared(view.At(axis).begin(), view.At(axis).end()); } else { - return std::shared_ptr(nullptr); + return nullptr; } } @@ -79,13 +79,14 @@ class UnsortedSegmentSumKernel final : public user_op::OpKernel, public user_op: UnsortedSegmentSumKernel() = default; ~UnsortedSegmentSumKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return CreateUnsortedSegmentSumOpKernelState(ctx); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateUnsortedSegmentSumOpKernelCache(ctx); } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* data = ctx->Tensor4ArgNameAndIndex("data", 0); const user_op::Tensor* segment_ids = ctx->Tensor4ArgNameAndIndex("segment_ids", 0); int64_t axis = ctx->Attr("axis"); @@ -97,11 +98,11 @@ class UnsortedSegmentSumKernel final : public user_op::OpKernel, public user_op: Memset(ctx->stream(), out->mut_dptr(), 0, out->shape().elem_cnt() * sizeof(T)); int64_t offset = 0; - if (state != nullptr) { - auto* sum_state = dynamic_cast(state); - CHECK_NOTNULL(sum_state); - CHECK_EQ(out->shape().At(axis), sum_state->upper() - sum_state->lower()); - offset = sum_state->lower(); + if (cache != nullptr) { + auto* sum_cache = dynamic_cast(cache); + CHECK_NOTNULL(sum_cache); + CHECK_EQ(out->shape().At(axis), sum_cache->upper() - sum_cache->lower()); + offset = sum_cache->lower(); } if (num_segment_ids != 0) { @@ -143,13 +144,14 @@ class UnsortedSegmentSumHalfKernel final : public user_op::OpKernel { UnsortedSegmentSumHalfKernel() = default; ~UnsortedSegmentSumHalfKernel() override = default; - std::shared_ptr CreateOpKernelState( - user_op::KernelInitContext* ctx) const override { - return CreateUnsortedSegmentSumOpKernelState(ctx); + std::shared_ptr InitOpKernelCache( + user_op::KernelCacheContext* ctx) const override { + return CreateUnsortedSegmentSumOpKernelCache(ctx); } private: - void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState*, + const user_op::OpKernelCache* cache) const override { const user_op::Tensor* data = ctx->Tensor4ArgNameAndIndex("data", 0); const user_op::Tensor* segment_ids = ctx->Tensor4ArgNameAndIndex("segment_ids", 0); int64_t axis = ctx->Attr("axis"); @@ -162,11 +164,11 @@ class UnsortedSegmentSumHalfKernel final : public user_op::OpKernel { Memset(ctx->stream(), tmp_buf->mut_dptr(), 0, out->shape().elem_cnt() * sizeof(float)); int64_t offset = 0; - if (state != nullptr) { - auto* sum_state = dynamic_cast(state); - CHECK_NOTNULL(sum_state); - CHECK_EQ(out->shape().At(axis), sum_state->upper() - sum_state->lower()); - offset = sum_state->lower(); + if (cache != nullptr) { + auto* sum_cache = dynamic_cast(cache); + 
CHECK_NOTNULL(sum_cache); + CHECK_EQ(out->shape().At(axis), sum_cache->upper() - sum_cache->lower()); + offset = sum_cache->lower(); } UnsortedSegmentSumKernelUtil::UnsortedSegmentSum( diff --git a/oneflow/user/ops/acc_op.cpp b/oneflow/user/ops/acc_op.cpp index 4c3188bb5f2..92df9df8f8e 100644 --- a/oneflow/user/ops/acc_op.cpp +++ b/oneflow/user/ops/acc_op.cpp @@ -14,56 +14,53 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe AccOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe AccOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe AccOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return AccOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe AccOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe AccOp::InferOutputBlobTimeShape( + user_op::InferOutputBlobTimeShapeFnContext* ctx) { + const int32_t max_acc_num = ctx->user_op_conf().attr("max_acc_num"); + const Shape& in_time_shape = ctx->TimeShape4InputArgNameAndIndex("in", 0); + DimVector time_shape_dim_vec = in_time_shape.dim_vec(); + CHECK_OR_RETURN(!time_shape_dim_vec.empty()); + if (time_shape_dim_vec.back() == max_acc_num) { + time_shape_dim_vec.pop_back(); + } else if (time_shape_dim_vec.back() % max_acc_num == 0) { + time_shape_dim_vec.back() /= max_acc_num; + } else { + const int64_t elem_cnt = in_time_shape.elem_cnt(); + time_shape_dim_vec.resize(1); + time_shape_dim_vec.back() = elem_cnt / max_acc_num; + } + *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); + return Maybe::Ok(); +} -REGISTER_USER_OP("acc") - .Input("in") - .Output("out") - .Attr("max_acc_num") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetOutputBlobTimeShapeInferFn( - [](user_op::InferOutputBlobTimeShapeFnContext* ctx) -> Maybe { - const int32_t max_acc_num = ctx->user_op_conf().attr("max_acc_num"); - const Shape& in_time_shape = ctx->TimeShape4InputArgNameAndIndex("in", 0); - DimVector time_shape_dim_vec = in_time_shape.dim_vec(); - CHECK_OR_RETURN(!time_shape_dim_vec.empty()); - if (time_shape_dim_vec.back() == max_acc_num) { - time_shape_dim_vec.pop_back(); - } else if 
(time_shape_dim_vec.back() % max_acc_num == 0) { - time_shape_dim_vec.back() /= max_acc_num; - } else { - const int64_t elem_cnt = in_time_shape.elem_cnt(); - time_shape_dim_vec.resize(1); - time_shape_dim_vec.back() = elem_cnt / max_acc_num; - } - *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +namespace { REGISTER_USER_OP_GRAD("acc").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/adaptive_pool_op.cpp b/oneflow/user/ops/adaptive_pool_op.cpp index 106c535b532..ab2b083b6b9 100644 --- a/oneflow/user/ops/adaptive_pool_op.cpp +++ b/oneflow/user/ops/adaptive_pool_op.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -63,31 +64,60 @@ Maybe BwGetSbpFn(user_op::SbpContext* ctx) { } Maybe InferFWDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = *ctx->Dtype4ArgNameAndIndex("x", 0); + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } Maybe InferBWDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("dx", 0) = *ctx->Dtype4ArgNameAndIndex("x", 0); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } -REGISTER_USER_OP("adaptive_avg_pool1d") - .Input("x") - .Attr>("output_size") - .Output("y") - .SetTensorDescInferFn(InferFWTensorDesc) - .SetGetSbpFn(FwGetSbpFn) - .SetDataTypeInferFn(InferFWDataType); - -REGISTER_USER_OP("adaptive_avg_pool1d_grad") - .Input("x") - .Input("dy") - .Attr>("output_size") - .Output("dx") - .SetTensorDescInferFn(InferBWTensorDesc) - .SetGetSbpFn(BwGetSbpFn) - .SetDataTypeInferFn(InferBWDataType); +} // namespace + +#define DEF_ADAPTIVE_AVG_POOL_OP(op_class_name_prefix) \ + /* static */ Maybe op_class_name_prefix##Op::InferLogicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferFWTensorDesc(ctx); \ + } \ + \ + /*static*/ Maybe op_class_name_prefix##Op::InferPhysicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##Op::GetSbp(user_op::SbpContext* ctx) { \ + return FwGetSbpFn(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##Op::InferDataType(user_op::InferContext* ctx) { \ + return InferFWDataType(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##GradOp::InferLogicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferBWTensorDesc(ctx); \ + } \ + \ + /*static*/ Maybe op_class_name_prefix##GradOp::InferPhysicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##GradOp::GetSbp(user_op::SbpContext* ctx) { \ + return BwGetSbpFn(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##GradOp::InferDataType( \ + user_op::InferContext* ctx) { \ + return InferBWDataType(ctx); \ + } + +DEF_ADAPTIVE_AVG_POOL_OP(AdaptiveAvgPool1D) +DEF_ADAPTIVE_AVG_POOL_OP(AdaptiveAvgPool2D) +DEF_ADAPTIVE_AVG_POOL_OP(AdaptiveAvgPool3D) + +#undef DEF_ADAPTIVE_AVG_POOL_OP REGISTER_USER_OP_GRAD("adaptive_avg_pool1d") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -107,23 +137,6 @@ REGISTER_USER_OP_GRAD("adaptive_avg_pool1d") return Maybe::Ok(); }); 
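The acc hunk above moves the output time-shape inference into AccOp::InferOutputBlobTimeShape without changing its arithmetic: the trailing time dimension is consumed or divided by max_acc_num. A runnable restatement of that arithmetic (shapes and values are illustrative only):

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> InferAccTimeShape(std::vector<int64_t> dims, int64_t max_acc_num) {
  int64_t elem_cnt = 1;
  for (int64_t d : dims) { elem_cnt *= d; }
  if (dims.back() == max_acc_num) {
    dims.pop_back();                         // drop the axis that acc fully consumes
  } else if (dims.back() % max_acc_num == 0) {
    dims.back() /= max_acc_num;              // shrink the trailing axis
  } else {
    dims.assign(1, elem_cnt / max_acc_num);  // fall back to a flat time shape
  }
  return dims;
}

int main() {
  for (int64_t d : InferAccTimeShape({4, 8}, 8)) { std::cout << d << ' '; }  // prints: 4
  std::cout << '\n';
  for (int64_t d : InferAccTimeShape({4, 8}, 2)) { std::cout << d << ' '; }  // prints: 4 4
  std::cout << '\n';
  return 0;
}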
-REGISTER_USER_OP("adaptive_avg_pool2d") - .Input("x") - .Attr>("output_size") - .Output("y") - .SetTensorDescInferFn(InferFWTensorDesc) - .SetGetSbpFn(FwGetSbpFn) - .SetDataTypeInferFn(InferFWDataType); - -REGISTER_USER_OP("adaptive_avg_pool2d_grad") - .Input("x") - .Input("dy") - .Attr>("output_size") - .Output("dx") - .SetTensorDescInferFn(InferBWTensorDesc) - .SetGetSbpFn(BwGetSbpFn) - .SetDataTypeInferFn(InferBWDataType); - REGISTER_USER_OP_GRAD("adaptive_avg_pool2d") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { const auto adaptive_avg_pool2d_grad_op_name = ctx->FwOp().op_name() + "_grad"; @@ -142,23 +155,6 @@ REGISTER_USER_OP_GRAD("adaptive_avg_pool2d") return Maybe::Ok(); }); -REGISTER_USER_OP("adaptive_avg_pool3d") - .Input("x") - .Attr>("output_size") - .Output("y") - .SetTensorDescInferFn(InferFWTensorDesc) - .SetGetSbpFn(FwGetSbpFn) - .SetDataTypeInferFn(InferFWDataType); - -REGISTER_USER_OP("adaptive_avg_pool3d_grad") - .Input("x") - .Input("dy") - .Attr>("output_size") - .Output("dx") - .SetTensorDescInferFn(InferBWTensorDesc) - .SetGetSbpFn(BwGetSbpFn) - .SetDataTypeInferFn(InferBWDataType); - REGISTER_USER_OP_GRAD("adaptive_avg_pool3d") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { const auto adaptive_avg_pool3d_grad_op_name = ctx->FwOp().op_name() + "_grad"; @@ -177,6 +173,4 @@ REGISTER_USER_OP_GRAD("adaptive_avg_pool3d") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/add_n_op.cpp b/oneflow/user/ops/add_n_op.cpp index 81f60e0f44b..d1c680f68bc 100644 --- a/oneflow/user/ops/add_n_op.cpp +++ b/oneflow/user/ops/add_n_op.cpp @@ -14,45 +14,55 @@ See the License for the specific language governing permissions and limitations under the License. 
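The adaptive-pool registrations deleted above are not lost: they are now emitted per op by the DEF_ADAPTIVE_AVG_POOL_OP macro introduced in the same file. For one instantiation the expansion amounts to the following (sketch; the Maybe<void> return type is assumed, since the patch text elides template arguments):

DEF_ADAPTIVE_AVG_POOL_OP(AdaptiveAvgPool2D) defines, in effect:

/* static */ Maybe<void> AdaptiveAvgPool2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  return InferFWTensorDesc(ctx);
}
/* static */ Maybe<void> AdaptiveAvgPool2DOp::GetSbp(user_op::SbpContext* ctx) {
  return FwGetSbpFn(ctx);
}
/* static */ Maybe<void> AdaptiveAvgPool2DOp::InferDataType(user_op::InferContext* ctx) {
  return InferFWDataType(ctx);
}
// ...plus InferPhysicalTensorDesc delegating to InferLogicalTensorDesc, and the matching
// AdaptiveAvgPool2DGradOp methods delegating to the BW helpers.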
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("add_n") - .InputWithMinimum("in", 2) - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& in_0 = ctx->InputTensorDesc("in", 0); - auto* out = ctx->OutputTensorDesc("out", 0); - CHECK_NOTNULL_OR_RETURN(out); - for (const auto& pair : ctx->inputs()) { - const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second); - if (in_0.shape().NumAxes() > 0 && cur_in.shape().NumAxes() > 0) { - CHECK_EQ_OR_RETURN(in_0.shape(), cur_in.shape()); - } - } - *out->mut_shape() = in_0.shape(); - *out->mut_is_dynamic() = in_0.is_dynamic(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) { - int64_t num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); - for (int64_t i = 0; i < num_axes; ++i) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); - } - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(user_op::OpArg("out", 0)).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& in_0 = ctx->InputTensorDesc("in", 0); - auto* out = ctx->OutputTensorDesc("out", 0); - CHECK_NOTNULL_OR_RETURN(out); - for (const auto& pair : ctx->inputs()) { - const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second); - CHECK_EQ_OR_RETURN(in_0.data_type(), cur_in.data_type()); - } - *out->mut_data_type() = in_0.data_type(); - return Maybe::Ok(); - }); +/* static */ Maybe AddNOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const auto& in_0 = ctx->InputTensorDesc("in", 0); + auto* out = ctx->OutputTensorDesc("out", 0); + CHECK_NOTNULL_OR_RETURN(out); + for (const auto& pair : ctx->inputs()) { + const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second); + if (in_0.shape().NumAxes() > 0 && cur_in.shape().NumAxes() > 0) { + CHECK_EQ_OR_RETURN(in_0.shape(), cur_in.shape()); + } + } + *out->mut_shape() = in_0.shape(); + *out->mut_is_dynamic() = in_0.is_dynamic(); + return Maybe::Ok(); +} + +/*static*/ Maybe AddNOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AddNOp::GetSbp(user_op::SbpContext* ctx) { + int64_t num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); + for (int64_t i = 0; i < num_axes; ++i) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(user_op::OpArg("out", 0)).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe AddNOp::InferDataType(user_op::InferContext* ctx) { + const auto& in_0 = ctx->InputTensorDesc("in", 0); + auto* out = ctx->OutputTensorDesc("out", 0); + CHECK_NOTNULL_OR_RETURN(out); + for (const auto& pair : ctx->inputs()) { + const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second); + CHECK_EQ_OR_RETURN(in_0.data_type(), cur_in.data_type()); + } + *out->mut_data_type() = in_0.data_type(); + return Maybe::Ok(); +} + +/*static*/ Maybe AddNOp::CheckAttr(const user_op::UserOpDefWrapper&, + const user_op::UserOpConfWrapper& op_conf) { + CHECK_OR_RETURN(op_conf.input_size("in") >= 2); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("add_n").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/affine_grid_op.cpp b/oneflow/user/ops/affine_grid_op.cpp index 
449cfe21560..6d042aa851c 100644 --- a/oneflow/user/ops/affine_grid_op.cpp +++ b/oneflow/user/ops/affine_grid_op.cpp @@ -14,13 +14,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { -Maybe CheckAttr(const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) { +Maybe CheckAttr_(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { bool pass_checked = true; std::stringstream err; err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": "; @@ -44,89 +45,99 @@ Maybe CheckAttr(const user_op::UserOpDefWrapper& def, } // namespace -REGISTER_USER_OP("affine_grid") - .Input("theta") - .Output("grid") - .Attr("size") - .Attr("align_corners") - .SetCheckAttrFn(CheckAttr) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& theta = ctx->InputTensorDesc("theta", 0); - user_op::TensorDesc* grid = ctx->OutputTensorDesc("grid", 0); - const Shape& size = ctx->Attr("size"); - // Only support 2D or 3D affine grid with NCHW layout - // For 2D grid: theta = { N, 2, 3 }, - // size = { N, C, H, W } - // grid = { N, H, W, 2 } - // For 3D grid: theta = { N, 3, 4 }, - // size = { N, C, D, H, W } - // grid = { N, D, H, W, 3 } - bool is_2d_grid = true; - if (theta.shape().At(1) == 2) { - CHECK_EQ_OR_RETURN(theta.shape().At(2), 3) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)"; - CHECK_EQ_OR_RETURN(size.NumAxes(), 4) << "Dimension of size MUST be 4, when 2d affine grid"; - CHECK_EQ_OR_RETURN(theta.shape().At(0), size.At(0)) - << "Theta and size MUST have same batch dimension"; - is_2d_grid = true; - } else if (theta.shape().At(1) == 3) { - CHECK_EQ_OR_RETURN(theta.shape().At(2), 4) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)"; - CHECK_EQ_OR_RETURN(size.NumAxes(), 5) "Dimension of size MUST be 4, when 3d affine grid"; - CHECK_EQ_OR_RETURN(theta.shape().At(0), size.At(0)) - << "Theta and size MUST have same batch dimension"; - is_2d_grid = false; - } else { - CHECK_OR_RETURN(false) << "Theta MUST be 2D or 3D grid"; - } - *grid->mut_is_dynamic() = theta.is_dynamic(); - Shape& grid_shape = *grid->mut_shape(); - if (is_2d_grid) { - grid_shape = {size.At(0), size.At(2), size.At(3), 2}; - } else { - grid_shape = {size.At(0), size.At(2), size.At(3), size.At(4), 3}; - } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("theta", 0), 0) - .Split(user_op::OpArg("grid", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("grid", 0) = ctx->InputDType("theta", 0); - return Maybe::Ok(); - }); +/* static */ Maybe AffineGridOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& theta = ctx->InputTensorDesc("theta", 0); + user_op::TensorDesc* grid = ctx->OutputTensorDesc("grid", 0); + const Shape& size = ctx->Attr("size"); + // Only support 2D or 3D affine grid with NCHW layout + // For 2D grid: theta = { N, 2, 3 }, + // size = { N, C, H, W } + // grid = { N, H, W, 2 } + // For 3D grid: theta = { N, 3, 4 }, + // size = { N, C, D, H, W } + // grid = { N, D, H, W, 3 } + bool is_2d_grid = true; + if (theta.shape().At(1) == 2) { + CHECK_EQ_OR_RETURN(theta.shape().At(2), 3) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)"; + CHECK_EQ_OR_RETURN(size.NumAxes(), 
4) << "Dimension of size MUST be 4, when 2d affine grid"; + CHECK_EQ_OR_RETURN(theta.shape().At(0), size.At(0)) + << "Theta and size MUST have same batch dimension"; + is_2d_grid = true; + } else if (theta.shape().At(1) == 3) { + CHECK_EQ_OR_RETURN(theta.shape().At(2), 4) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)"; + CHECK_EQ_OR_RETURN(size.NumAxes(), 5) "Dimension of size MUST be 4, when 3d affine grid"; + CHECK_EQ_OR_RETURN(theta.shape().At(0), size.At(0)) + << "Theta and size MUST have same batch dimension"; + is_2d_grid = false; + } else { + CHECK_OR_RETURN(false) << "Theta MUST be 2D or 3D grid"; + } + *grid->mut_is_dynamic() = theta.is_dynamic(); + Shape& grid_shape = *grid->mut_shape(); + if (is_2d_grid) { + grid_shape = {size.At(0), size.At(2), size.At(3), 2}; + } else { + grid_shape = {size.At(0), size.At(2), size.At(3), size.At(4), 3}; + } + return Maybe::Ok(); +} -REGISTER_USER_OP("affine_grid_grad") - .Input("dgrid") - .Output("dtheta") - .Attr("size") - .Attr("align_corners") - .SetCheckAttrFn(CheckAttr) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& size = ctx->Attr("size"); - - if (size.NumAxes() == 4) { - *(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {size.At(0), 2, 3}; - } else if (size.NumAxes() == 5) { - *(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {size.At(0), 3, 4}; - } else { - CHECK_OR_RETURN(false) << "size MUST be 4D or 5D"; - } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("dgrid", 0), 0) - .Split(user_op::OpArg("dtheta", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dtheta", 0) = ctx->InputDType("dgrid", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe AffineGridOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AffineGridOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("theta", 0), 0) + .Split(user_op::OpArg("grid", 0), 0) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe AffineGridOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_(def, conf); +} + +/* static */ Maybe AffineGridOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("grid", 0) = ctx->InputDType("theta", 0); + return Maybe::Ok(); +} + +/* static */ Maybe AffineGridGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& size = ctx->Attr("size"); + + if (size.NumAxes() == 4) { + *(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {size.At(0), 2, 3}; + } else if (size.NumAxes() == 5) { + *(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {size.At(0), 3, 4}; + } else { + CHECK_OR_RETURN(false) << "size MUST be 4D or 5D"; + } + return Maybe::Ok(); +} + +/*static*/ Maybe AffineGridGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AffineGridGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("dgrid", 0), 0) + .Split(user_op::OpArg("dtheta", 0), 0) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe AffineGridGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_(def, conf); +} + +/* static */ Maybe AffineGridGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dtheta", 0) = 
ctx->InputDType("dgrid", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("affine_grid") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/amp_white_identity_op.cpp b/oneflow/user/ops/amp_white_identity_op.cpp index 269b08c3ef7..46a90141d8d 100644 --- a/oneflow/user/ops/amp_white_identity_op.cpp +++ b/oneflow/user/ops/amp_white_identity_op.cpp @@ -14,35 +14,37 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe AmpWhiteIdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_shape() = in.shape(); + *out->mut_is_dynamic() = in.is_dynamic(); + return Maybe::Ok(); +} -REGISTER_USER_OP("amp_white_identity") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_shape() = in.shape(); - *out->mut_is_dynamic() = in.is_dynamic(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) { - const auto& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - for (int i = 0; i < in.shape().NumAxes(); ++i) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = in.data_type(); - return Maybe::Ok(); - }); +/*static*/ Maybe AmpWhiteIdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AmpWhiteIdentityOp::GetSbp(user_op::SbpContext* ctx) { + const auto& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + for (int i = 0; i < in.shape().NumAxes(); ++i) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe AmpWhiteIdentityOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = in.data_type(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("amp_white_identity") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -60,6 +62,4 @@ REGISTER_USER_OP_GRAD("amp_white_identity") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/arange_op.cpp b/oneflow/user/ops/arange_op.cpp index 7a7fc3e0c97..1d6bb3556ca 100644 --- a/oneflow/user/ops/arange_op.cpp +++ b/oneflow/user/ops/arange_op.cpp @@ -14,59 +14,58 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("arange") - .Output("out") - .Attr("integer_start") - .Attr("integer_delta") - .Attr("integer_limit") - .Attr("float_start") - .Attr("float_delta") - .Attr("float_limit") - .Attr("dtype") - .Attr>("nd_sbp") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - DataType dtype = ctx->Attr("dtype"); - int64_t range_elem_cnt = 0; - if (IsIntegralDataType(dtype)) { - int64_t integer_delta = ctx->Attr("integer_delta"); - CHECK_NE_OR_RETURN(integer_delta, static_cast(0)) - << "RuntimeError: step must be nonzero. "; - int64_t integer_start = ctx->Attr("integer_start"); - int64_t integer_limit = ctx->Attr("integer_limit"); - // CHECK when limit > start, delta > 0; limit < start, delta < 0; - CHECK_GT_OR_RETURN((integer_limit - integer_start) / integer_delta, static_cast(0)) - << "RuntimeError: upper bound and larger bound inconsistent with step sign"; - range_elem_cnt = - std::ceil(static_cast(integer_limit - integer_start) / integer_delta); - } else { - double float_delta = ctx->Attr("float_delta"); - CHECK_NE_OR_RETURN(float_delta, static_cast(0.0)) - << "RuntimeError: step must be nonzero. "; - double float_start = ctx->Attr("float_start"); - double float_limit = ctx->Attr("float_limit"); - // CHECK when limit > start, delta > 0; limit < start, delta < 0; - CHECK_GT_OR_RETURN((float_limit - float_start) / float_delta, static_cast(0.0)) - << "RuntimeError: upper bound and larger bound inconsistent with step sign"; - range_elem_cnt = std::ceil(static_cast(float_limit - float_start) / float_delta); - } - *out_shape = Shape({range_elem_cnt}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_broadcast_parallel(); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }); + +/* static */ Maybe ArangeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + DataType dtype = ctx->Attr("dtype"); + int64_t range_elem_cnt = 0; + if (IsIntegralDataType(dtype)) { + int64_t integer_delta = ctx->Attr("integer_delta"); + CHECK_NE_OR_RETURN(integer_delta, static_cast(0)) + << "RuntimeError: step must be nonzero. "; + int64_t integer_start = ctx->Attr("integer_start"); + int64_t integer_limit = ctx->Attr("integer_limit"); + // CHECK when limit > start, delta > 0; limit < start, delta < 0; + CHECK_GT_OR_RETURN((integer_limit - integer_start) / integer_delta, static_cast(0)) + << "RuntimeError: upper bound and larger bound inconsistent with step sign"; + range_elem_cnt = std::ceil(static_cast(integer_limit - integer_start) / integer_delta); + } else { + double float_delta = ctx->Attr("float_delta"); + CHECK_NE_OR_RETURN(float_delta, static_cast(0.0)) + << "RuntimeError: step must be nonzero. 
"; + double float_start = ctx->Attr("float_start"); + double float_limit = ctx->Attr("float_limit"); + // CHECK when limit > start, delta > 0; limit < start, delta < 0; + // CHECK_GE For 0-Dim Tensor + CHECK_GE_OR_RETURN((float_limit - float_start) / float_delta, static_cast(0.0)) + << "RuntimeError: upper bound and larger bound inconsistent with step sign"; + range_elem_cnt = std::ceil(static_cast(float_limit - float_start) / float_delta); + } + *out_shape = Shape({range_elem_cnt}); + return Maybe::Ok(); +} + +/*static*/ Maybe ArangeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ArangeOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ArangeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::SbpParallel default_sbp; + default_sbp.mutable_broadcast_parallel(); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} + +/* static */ Maybe ArangeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/arg_sort_op.cpp b/oneflow/user/ops/arg_sort_op.cpp index 1b7df445d72..7cc0b23ed45 100644 --- a/oneflow/user/ops/arg_sort_op.cpp +++ b/oneflow/user/ops/arg_sort_op.cpp @@ -14,35 +14,39 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("arg_sort") - .Input("in") - .Output("out") - .Attr("direction") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - // The current implementation can only do arg_sort in the last dimension and should use - // Broadcast (by default) instead of Split for that dimension - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def, - const user_op::UserOpConfWrapper& op_conf) -> Maybe { - const std::string& direction = op_conf.attr("direction"); - CHECK_OR_RETURN(direction == "ASCENDING" || direction == "DESCENDING"); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kInt32; - return Maybe::Ok(); - }); +/* static */ Maybe ArgSortOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe ArgSortOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ArgSortOp::GetSbp(user_op::SbpContext* ctx) { + // The current implementation can only do arg_sort in the last dimension and should use + // Broadcast (by default) instead of Split for that dimension + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe::Ok(); +} + 
+/* static */ Maybe ArgSortOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + const std::string& direction = conf.attr("direction"); + CHECK_OR_RETURN(direction == "ASCENDING" || direction == "DESCENDING"); + return Maybe::Ok(); +} + +/* static */ Maybe ArgSortOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kInt32; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/arg_where_op.cpp b/oneflow/user/ops/arg_where_op.cpp index 18f545ade3e..3ce31486a50 100644 --- a/oneflow/user/ops/arg_where_op.cpp +++ b/oneflow/user/ops/arg_where_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -31,20 +32,25 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { } // namespace -REGISTER_NO_GRAD_USER_OP("argwhere") - .Input("input") - .Output("output") - .Output("output_size") - .Attr("dtype", DataType::kInt32) - .SetTensorDescInferFn(InferTensorDesc) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const DataType dtype = ctx->Attr("dtype"); - user_op::TensorDesc* output_desc = ctx->OutputTensorDesc("output", 0); - *output_desc->mut_data_type() = dtype; - user_op::TensorDesc* output_size_desc = ctx->OutputTensorDesc("output_size", 0); - *output_size_desc->mut_data_type() = dtype; - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe ArgwhereOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc(ctx); +} + +/*static*/ Maybe ArgwhereOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ArgwhereOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe ArgwhereOp::InferDataType(user_op::InferContext* ctx) { + const DataType dtype = ctx->Attr("dtype"); + user_op::TensorDesc* output_desc = ctx->OutputTensorDesc("output", 0); + *output_desc->mut_data_type() = dtype; + user_op::TensorDesc* output_size_desc = ctx->OutputTensorDesc("output_size", 0); + *output_size_desc->mut_data_type() = dtype; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/argmax_op.cpp b/oneflow/user/ops/argmax_op.cpp index e79105e8269..58c6581eb29 100644 --- a/oneflow/user/ops/argmax_op.cpp +++ b/oneflow/user/ops/argmax_op.cpp @@ -14,28 +14,32 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("argmax") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - auto dim_vec = ctx->InputShape("in", 0).dim_vec(); - dim_vec.pop_back(); - *ctx->OutputShape("out", 0) = Shape(std::move(dim_vec)); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kInt64; - return Maybe::Ok(); - }); +/* static */ Maybe ArgmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + auto dim_vec = ctx->InputShape("in", 0).dim_vec(); + dim_vec.pop_back(); + *ctx->OutputShape("out", 0) = Shape(std::move(dim_vec)); + return Maybe::Ok(); +} + +/*static*/ Maybe ArgmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ArgmaxOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe ArgmaxOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kInt64; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/assign_op.cpp b/oneflow/user/ops/assign_op.cpp index f54342c2cce..c2b296dbca7 100644 --- a/oneflow/user/ops/assign_op.cpp +++ b/oneflow/user/ops/assign_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -64,7 +65,7 @@ Maybe InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgMo return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { +Maybe InferDataType_(user_op::InferContext* ctx) { const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0); const user_op::TensorDesc& value_desc = ctx->InputTensorDesc("value", 0); CHECK_OR_RETURN(ref_desc.data_type() == value_desc.data_type()); @@ -77,30 +78,32 @@ Maybe InferDataType(user_op::InferContext* ctx) { } // namespace -REGISTER_NO_GRAD_USER_OP("assign") - .Input("ref") - .Input("value") - .SetTensorDescInferFn(InferTensorDesc) - .SetGetSbpFn(GetSbpSignatures) - .SetInputArgModifyFn(InputArgModifierFn) - .SetDataTypeInferFn(InferDataType); +#define DEF_ASSIGN_OP(op_class_name) \ + /* static */ Maybe op_class_name::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return InferTensorDesc(ctx); \ + } \ + \ + /*static*/ Maybe op_class_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + \ + /* static */ Maybe op_class_name::GetSbp(user_op::SbpContext* ctx) { \ + return GetSbpSignatures(ctx); \ + } \ + \ + /* static */ Maybe op_class_name::ModifyInputArg( \ + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { \ + return InputArgModifierFn(GetInputArgModifierFn, conf); \ + } \ + \ + /* static */ Maybe op_class_name::InferDataType(user_op::InferContext* ctx) { \ + return InferDataType_(ctx); \ + } -REGISTER_NO_GRAD_USER_OP("assign_if") - .Input("ref") - .Input("value") - .Input("condition") - .SetTensorDescInferFn(InferTensorDesc) - .SetGetSbpFn(GetSbpSignatures) - .SetInputArgModifyFn(InputArgModifierFn) - .SetDataTypeInferFn(InferDataType); +DEF_ASSIGN_OP(AssignUserOp) +DEF_ASSIGN_OP(AssignIfOp) +DEF_ASSIGN_OP(AssignIfNotOp) -REGISTER_NO_GRAD_USER_OP("assign_if_not") - .Input("ref") - .Input("value") - .Input("condition") - .SetTensorDescInferFn(InferTensorDesc) - .SetGetSbpFn(GetSbpSignatures) - .SetInputArgModifyFn(InputArgModifierFn) - .SetDataTypeInferFn(InferDataType); +#undef DEF_ASSIGN_OP } // namespace oneflow diff --git a/oneflow/user/ops/batch_gather_op.cpp b/oneflow/user/ops/batch_gather_op.cpp index 4f37bf102b0..f0581702ece 100644 --- a/oneflow/user/ops/batch_gather_op.cpp +++ b/oneflow/user/ops/batch_gather_op.cpp @@ -14,75 +14,79 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("batch_gather") - .Input("in") - .Input("indices") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0); - const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); - CHECK_GT_OR_RETURN(indices.shape().NumAxes(), 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_LE_OR_RETURN(indices.shape().dim_vec().size(), in.shape().dim_vec().size()); - FOR_RANGE(int64_t, i, 0, indices.shape().dim_vec().size() - 1) { - if (in.is_dynamic() && indices.is_dynamic() == false) { - CHECK_GE_OR_RETURN(indices.shape().dim_vec().at(i), in.shape().dim_vec().at(i)); - } else if (in.is_dynamic() == false && indices.is_dynamic()) { - UNIMPLEMENTED(); - } else { - CHECK_EQ_OR_RETURN(indices.shape().dim_vec().at(i), in.shape().dim_vec().at(i)); - } - } +/* static */ Maybe BatchGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0); + const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); + CHECK_GT_OR_RETURN(indices.shape().NumAxes(), 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_LE_OR_RETURN(indices.shape().dim_vec().size(), in.shape().dim_vec().size()); + FOR_RANGE(int64_t, i, 0, indices.shape().dim_vec().size() - 1) { + if (in.is_dynamic() && indices.is_dynamic() == false) { + CHECK_GE_OR_RETURN(indices.shape().dim_vec().at(i), in.shape().dim_vec().at(i)); + } else if (in.is_dynamic() == false && indices.is_dynamic()) { + UNIMPLEMENTED(); + } else { + CHECK_EQ_OR_RETURN(indices.shape().dim_vec().at(i), in.shape().dim_vec().at(i)); + } + } - DimVector dim_vec(in.shape().dim_vec()); - dim_vec.at(indices.shape().NumAxes() - 1) = indices.shape().dim_vec().back(); - *out->mut_shape() = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK_OR_RETURN(indices_modifier != nullptr); - indices_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int64_t indices_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes(); - if (indices_num_axes > 1) { - FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { - ctx->NewBuilder() - .Split(user_op::OpArg("indices", 0), i) - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("indices", 0)) - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - } else { - auto err = std::make_shared(); - err->set_msg("BatchGatherOp: indices_num_axes equals " + std::to_string(indices_num_axes) - + " (should be bigger than 1)."); - err->mutable_check_failed_error(); - return err; - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); - CHECK_OR_RETURN(IsIndexDataType(indices.data_type())); - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = 
ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = in.data_type(); - return Maybe::Ok(); - }); + DimVector dim_vec(in.shape().dim_vec()); + dim_vec.at(indices.shape().NumAxes() - 1) = indices.shape().dim_vec().back(); + *out->mut_shape() = Shape(dim_vec); + return Maybe::Ok(); +} + +/*static*/ Maybe BatchGatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BatchGatherOp::GetSbp(user_op::SbpContext* ctx) { + const int64_t indices_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes(); + if (indices_num_axes > 1) { + FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { + ctx->NewBuilder() + .Split(user_op::OpArg("indices", 0), i) + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("indices", 0)) + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + } else { + auto err = std::make_shared(); + err->set_msg("BatchGatherOp: indices_num_axes equals " + std::to_string(indices_num_axes) + + " (should be bigger than 1)."); + err->mutable_check_failed_error(); + return err; + } + return Maybe::Ok(); +} + +/* static */ Maybe BatchGatherOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); + CHECK_OR_RETURN(indices_modifier != nullptr); + indices_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe BatchGatherOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); + CHECK_OR_RETURN(IsIndexDataType(indices.data_type())); + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = in.data_type(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("batch_gather") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/bernoulli_op.cpp b/oneflow/user/ops/bernoulli_op.cpp index 53f854e62f1..3068b83fd0c 100644 --- a/oneflow/user/ops/bernoulli_op.cpp +++ b/oneflow/user/ops/bernoulli_op.cpp @@ -14,32 +14,33 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("bernoulli") - .Input("in") - .Output("out") - .Attr("seed", -1) - .Attr("has_seed", false) - .Attr("dtype") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - *out_tensor->mut_shape() = in_tensor.shape(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_data_type() = ctx->Attr("dtype"); - return Maybe::Ok(); - }); +/* static */ Maybe BernoulliOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + *out_tensor->mut_shape() = in_tensor.shape(); + return Maybe::Ok(); +} + +/*static*/ Maybe BernoulliOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BernoulliOp::GetSbp(user_op::SbpContext* ctx) { + const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe BernoulliOp::InferDataType(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_data_type() = ctx->Attr("dtype"); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/bias_add_op.cpp b/oneflow/user/ops/bias_add_op.cpp index 963e38c3116..ba0c928804f 100644 --- a/oneflow/user/ops/bias_add_op.cpp +++ b/oneflow/user/ops/bias_add_op.cpp @@ -14,48 +14,50 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("bias_add") - .Input("a") - .Input("b") - .Output("out") - .Attr("axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); - const auto bias_add_axis = ctx->Attr("axis"); - CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); - CHECK_GE_OR_RETURN(bias_add_axis, 0); - CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); - CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); - *ctx->OutputShape("out", 0) = ctx->InputShape("a", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("a", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto axis = ctx->Attr("axis"); - for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); - ++i) { - if (i == axis) { continue; } - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), i) - .Broadcast(user_op::OpArg("b", 0)) - .Split(ctx->outputs(), i) - .Build(); - } - ctx->NewBuilder() - .Split(user_op::OpArg("b", 0), 0) - .Split(user_op::OpArg("a", 0), axis) - .Split(ctx->outputs(), axis) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("a", 0); - return Maybe::Ok(); - }); +/* static */ Maybe BiasAddOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); + const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); + const auto bias_add_axis = ctx->Attr("axis"); + CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); + CHECK_GE_OR_RETURN(bias_add_axis, 0); + CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); + CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); + *ctx->OutputShape("out", 0) = ctx->InputShape("a", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("a", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe BiasAddOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BiasAddOp::GetSbp(user_op::SbpContext* ctx) { + const auto axis = ctx->Attr("axis"); + for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); + ++i) { + if (i == axis) { continue; } + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), i) + .Broadcast(user_op::OpArg("b", 0)) + .Split(ctx->outputs(), i) + .Build(); + } + ctx->NewBuilder() + .Split(user_op::OpArg("b", 0), 0) + .Split(user_op::OpArg("a", 0), axis) + .Split(ctx->outputs(), axis) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe BiasAddOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("a", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("bias_add") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/binary_cross_entropy_op.cpp b/oneflow/user/ops/binary_cross_entropy_op.cpp index b4bd3b76f74..0d328657660 100644 --- a/oneflow/user/ops/binary_cross_entropy_op.cpp +++ b/oneflow/user/ops/binary_cross_entropy_op.cpp @@ -16,10 +16,13 @@ limitations under the License. 
#include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/loss_op_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { + namespace { -Maybe InferTensorDescFn(user_op::InferContext* ctx) { + +Maybe InferTensorDescFn_(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()); @@ -37,7 +40,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { +Maybe InferDataType_(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()); @@ -85,31 +88,47 @@ Maybe InferGradDataType(user_op::InferContext* ctx) { } } // namespace -REGISTER_USER_OP("binary_cross_entropy") - .Input("input") - .Input("target") - .OptionalInput("weight") - .Output("out") - .SetTensorDescInferFn(InferTensorDescFn) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); - CHECK_OR_RETURN(target_modifier != nullptr); - target_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn(InferDataType) - .SetGetSbpFn(GenLossForwardDefaultGetSbpFn()); - -REGISTER_USER_OP("binary_cross_entropy_grad") - .Input("input") - .Input("target") - .OptionalInput("weight") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn(InferGradTensorDescFn) - .SetDataTypeInferFn(InferGradDataType) - .SetGetSbpFn(GenLossBackwardDefaultGetSbpFn()); +/* static */ Maybe BinaryCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDescFn_(ctx); +} + +/*static*/ Maybe BinaryCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BinaryCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) { + return GenLossForwardDefaultGetSbpFn()(ctx); +} + +/* static */ Maybe BinaryCrossEntropyOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); + CHECK_OR_RETURN(target_modifier != nullptr); + target_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe BinaryCrossEntropyOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType_(ctx); +} + +/* static */ Maybe BinaryCrossEntropyGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferGradTensorDescFn(ctx); +} + +/*static*/ Maybe BinaryCrossEntropyGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BinaryCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) { + return GenLossBackwardDefaultGetSbpFn()(ctx); +} + +/* static */ Maybe BinaryCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) { + return InferGradDataType(ctx); +} REGISTER_USER_OP_GRAD("binary_cross_entropy") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp b/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp index d3214edabe7..0a124525a60 100644 --- 
a/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp +++ b/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/loss_op_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { @@ -42,7 +43,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { +Maybe InferDataType_(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()); @@ -101,39 +102,60 @@ Maybe InferGradDataType(user_op::InferContext* ctx) { } } // namespace -REGISTER_USER_OP("binary_cross_entropy_with_logits") - .Input("input") - .Input("target") - .OptionalInput("weight") - .OptionalInput("pos_weight") - .Output("out") - .Attr("has_pos_weight") - .SetTensorDescInferFn(InferTensorDescFn) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); - CHECK_OR_RETURN(target_modifier != nullptr); - target_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn(InferDataType) - .SetGetSbpFn(GenLossForwardDefaultGetSbpFn([](user_op::UserOpSbpSignatureBuilder& builder) { - builder.Broadcast(user_op::OpArg("pos_weight", 0)); - })); - -REGISTER_USER_OP("binary_cross_entropy_with_logits_grad") - .Input("input") - .Input("target") - .OptionalInput("weight") - .OptionalInput("pos_weight") - .Input("dy") - .Output("dx") - .Attr("has_pos_weight") - .SetTensorDescInferFn(InferGradTensorDescFn) - .SetDataTypeInferFn(InferGradDataType) - .SetGetSbpFn(GenLossBackwardDefaultGetSbpFn([](user_op::UserOpSbpSignatureBuilder& builder) { - builder.Broadcast(user_op::OpArg("pos_weight", 0)); - })); +/* static */ Maybe BinaryCrossEntropyWithLogitsOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferTensorDescFn(ctx); +} + +/*static*/ Maybe BinaryCrossEntropyWithLogitsOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BinaryCrossEntropyWithLogitsOp::GetSbp(user_op::SbpContext* ctx) { + return GenLossForwardDefaultGetSbpFn( + [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) { + if (ctx->user_op_conf().has_input("pos_weight", 0)) { + builder.Broadcast(user_op::OpArg("pos_weight", 0)); + } + })(ctx); +} + +/* static */ Maybe BinaryCrossEntropyWithLogitsOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); + CHECK_OR_RETURN(target_modifier != nullptr); + target_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe BinaryCrossEntropyWithLogitsOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType_(ctx); +} + +/* static */ Maybe BinaryCrossEntropyWithLogitsGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferGradTensorDescFn(ctx); +} + +/*static*/ Maybe BinaryCrossEntropyWithLogitsGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe 
BinaryCrossEntropyWithLogitsGradOp::GetSbp(user_op::SbpContext* ctx) { + return GenLossBackwardDefaultGetSbpFn( + [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) { + if (ctx->user_op_conf().has_input("pos_weight", 0)) { + builder.Broadcast(user_op::OpArg("pos_weight", 0)); + } + })(ctx); +} + +/* static */ Maybe BinaryCrossEntropyWithLogitsGradOp::InferDataType( + user_op::InferContext* ctx) { + return InferGradDataType(ctx); +} REGISTER_USER_OP_GRAD("binary_cross_entropy_with_logits") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/broadcast_div_grad_op.cpp b/oneflow/user/ops/broadcast_div_grad_op.cpp index 39add2276c7..8e1c16a2b2a 100644 --- a/oneflow/user/ops/broadcast_div_grad_op.cpp +++ b/oneflow/user/ops/broadcast_div_grad_op.cpp @@ -14,59 +14,61 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("broadcast_div_grad") - .Input("y") - .Input("z") - .Input("dz") - .Output("dy") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("dy", 0) = ctx->InputShape("y", 0); - *ctx->OutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape(); - const Shape& z_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("z", 0).shape(); - CHECK_LE_OR_RETURN(y_shape.NumAxes(), z_shape.NumAxes()); - FOR_RANGE(int64_t, i, 0, y_shape.NumAxes()) { - const int64_t axis_y = y_shape.NumAxes() - 1 - i; - const int64_t axis_z = z_shape.NumAxes() - 1 - i; - if (y_shape.At(axis_y) == z_shape.At(axis_z)) { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), axis_y) - .Split(user_op::OpArg("z", 0), axis_z) - .Split(user_op::OpArg("dz", 0), axis_z) - .Split(user_op::OpArg("dy", 0), axis_y) - .Build(); - } else { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("y", 0)) - .Split(user_op::OpArg("z", 0), axis_z) - .Split(user_op::OpArg("dz", 0), axis_z) - .PartialSum(user_op::OpArg("dy", 0)) - .Build(); - } - } +/* static */ Maybe BroadcastDivGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("dy", 0) = ctx->InputShape("y", 0); + *ctx->OutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe BroadcastDivGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BroadcastDivGradOp::GetSbp(user_op::SbpContext* ctx) { + const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape(); + const Shape& z_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("z", 0).shape(); + CHECK_LE_OR_RETURN(y_shape.NumAxes(), z_shape.NumAxes()); + FOR_RANGE(int64_t, i, 0, y_shape.NumAxes()) { + const int64_t axis_y = y_shape.NumAxes() - 1 - i; + const int64_t axis_z = z_shape.NumAxes() - 1 - i; + if (y_shape.At(axis_y) == z_shape.At(axis_z)) { ctx->NewBuilder() - .Broadcast(user_op::OpArg("y", 0)) - .PartialSum(user_op::OpArg("z", 0)) - .Broadcast(user_op::OpArg("dz", 0)) - .Broadcast(user_op::OpArg("dy", 0)) + .Split(user_op::OpArg("y", 0), axis_y) + .Split(user_op::OpArg("z", 0), axis_z) + .Split(user_op::OpArg("dz", 0), axis_z) + .Split(user_op::OpArg("dy", 0), axis_y) .Build(); + } else { ctx->NewBuilder() .Broadcast(user_op::OpArg("y", 0)) 
- .Broadcast(user_op::OpArg("z", 0)) - .PartialSum(user_op::OpArg("dz", 0)) - .Broadcast(user_op::OpArg("dy", 0)) + .Split(user_op::OpArg("z", 0), axis_z) + .Split(user_op::OpArg("dz", 0), axis_z) + .PartialSum(user_op::OpArg("dy", 0)) .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dy", 0) = ctx->InputDType("y", 0); - return Maybe::Ok(); - }); + } + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("y", 0)) + .PartialSum(user_op::OpArg("z", 0)) + .Broadcast(user_op::OpArg("dz", 0)) + .Broadcast(user_op::OpArg("dy", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("y", 0)) + .Broadcast(user_op::OpArg("z", 0)) + .PartialSum(user_op::OpArg("dz", 0)) + .Broadcast(user_op::OpArg("dy", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe BroadcastDivGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dy", 0) = ctx->InputDType("y", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/broadcast_like_op.cpp b/oneflow/user/ops/broadcast_like_op.cpp index a1a54b6a407..6682f5ed2ea 100644 --- a/oneflow/user/ops/broadcast_like_op.cpp +++ b/oneflow/user/ops/broadcast_like_op.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/reduce_sbp_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -77,28 +78,32 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("like", 0); - return Maybe::Ok(); +} // namespace + +/* static */ Maybe BroadcastLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc(ctx); } -} // namespace +/*static*/ Maybe BroadcastLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("broadcast_like") - .Input("x") - .Input("like") - .Attr>("broadcast_axes") - .Output("y") - .SetTensorDescInferFn(InferTensorDesc) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); - CHECK_OR_RETURN(like_modifier != nullptr); - like_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetSbpSignatures) - .SetDataTypeInferFn(InferDataType); +/* static */ Maybe BroadcastLikeOp::GetSbp(user_op::SbpContext* ctx) { + return GetSbpSignatures(ctx); +} + +/* static */ Maybe BroadcastLikeOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); + CHECK_OR_RETURN(like_modifier != nullptr); + like_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe BroadcastLikeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("like", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("broadcast_like") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/broadcast_pow_grad_op.cpp b/oneflow/user/ops/broadcast_pow_grad_op.cpp index a203304c817..21fa575b03b 100644 --- a/oneflow/user/ops/broadcast_pow_grad_op.cpp +++ b/oneflow/user/ops/broadcast_pow_grad_op.cpp @@ -14,108 +14,110 @@ See the License for the specific language governing permissions 
and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("broadcast_pow_x_grad") - .Input("x") - .Input("y") - .Input("z") - .Input("dz") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); - *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); - const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape(); - const Shape& z_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("z", 0).shape(); - CHECK_LE_OR_RETURN(x_shape.NumAxes(), z_shape.NumAxes()); - CHECK_LE_OR_RETURN(y_shape.NumAxes(), z_shape.NumAxes()); - FOR_RANGE(int64_t, i, 0, z_shape.NumAxes()) { - const int64_t _axis = z_shape.NumAxes() - 1 - i; - if (z_shape.At(_axis) == x_shape.At(_axis) && z_shape.At(_axis) == y_shape.At(_axis)) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), _axis) - .Split(user_op::OpArg("y", 0), _axis) - .Split(user_op::OpArg("z", 0), _axis) - .Split(user_op::OpArg("dz", 0), _axis) - .Split(user_op::OpArg("dx", 0), _axis) - .Build(); - } - } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("y", 0)) - .PartialSum(user_op::OpArg("z", 0)) - .Broadcast(user_op::OpArg("dz", 0)) - .Broadcast(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("dx", 0)) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("y", 0)) - .Broadcast(user_op::OpArg("z", 0)) - .Broadcast(user_op::OpArg("dz", 0)) - .Broadcast(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("dx", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("y", 0)) - .Broadcast(user_op::OpArg("z", 0)) - .PartialSum(user_op::OpArg("dz", 0)) - .Broadcast(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("dx", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); -REGISTER_USER_OP("broadcast_pow_y_grad") - .Input("x") - .Input("y") - .Input("z") - .Input("dz") - .Output("dy") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("dy", 0) = ctx->InputShape("y", 0); - *ctx->OutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); - const Shape& z_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("z", 0).shape(); - CHECK_LE_OR_RETURN(x_shape.NumAxes(), z_shape.NumAxes()); - FOR_RANGE(int64_t, i, 0, z_shape.NumAxes()) { - const int64_t _axis = z_shape.NumAxes() - 1 - i; - if (z_shape.At(_axis) == x_shape.At(_axis)) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), _axis) - .Split(user_op::OpArg("z", 0), _axis) - .Split(user_op::OpArg("dz", 0), _axis) - .Split(user_op::OpArg("dy", 0), _axis) - .Build(); - } - } +/* static */ Maybe BroadcastPowXGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe BroadcastPowXGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe 
BroadcastPowXGradOp::GetSbp(user_op::SbpContext* ctx) { + const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); + const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape(); + const Shape& z_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("z", 0).shape(); + CHECK_LE_OR_RETURN(x_shape.NumAxes(), z_shape.NumAxes()); + CHECK_LE_OR_RETURN(y_shape.NumAxes(), z_shape.NumAxes()); + FOR_RANGE(int64_t, i, 0, z_shape.NumAxes()) { + const int64_t _axis = z_shape.NumAxes() - 1 - i; + if (z_shape.At(_axis) == x_shape.At(_axis) && z_shape.At(_axis) == y_shape.At(_axis)) { ctx->NewBuilder() - .Broadcast(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("z", 0)) - .Broadcast(user_op::OpArg("dz", 0)) - .Broadcast(user_op::OpArg("dy", 0)) + .Split(user_op::OpArg("x", 0), _axis) + .Split(user_op::OpArg("y", 0), _axis) + .Split(user_op::OpArg("z", 0), _axis) + .Split(user_op::OpArg("dz", 0), _axis) + .Split(user_op::OpArg("dx", 0), _axis) .Build(); + } + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("y", 0)) + .PartialSum(user_op::OpArg("z", 0)) + .Broadcast(user_op::OpArg("dz", 0)) + .Broadcast(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("dx", 0)) + .Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("y", 0)) + .Broadcast(user_op::OpArg("z", 0)) + .Broadcast(user_op::OpArg("dz", 0)) + .Broadcast(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("dx", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("y", 0)) + .Broadcast(user_op::OpArg("z", 0)) + .PartialSum(user_op::OpArg("dz", 0)) + .Broadcast(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("dx", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe BroadcastPowXGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} + +/* static */ Maybe BroadcastPowYGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("dy", 0) = ctx->InputShape("y", 0); + *ctx->OutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe BroadcastPowYGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BroadcastPowYGradOp::GetSbp(user_op::SbpContext* ctx) { + const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); + const Shape& z_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("z", 0).shape(); + CHECK_LE_OR_RETURN(x_shape.NumAxes(), z_shape.NumAxes()); + FOR_RANGE(int64_t, i, 0, z_shape.NumAxes()) { + const int64_t _axis = z_shape.NumAxes() - 1 - i; + if (z_shape.At(_axis) == x_shape.At(_axis)) { ctx->NewBuilder() - .Broadcast(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("z", 0)) - .PartialSum(user_op::OpArg("dz", 0)) - .Broadcast(user_op::OpArg("dy", 0)) + .Split(user_op::OpArg("x", 0), _axis) + .Split(user_op::OpArg("z", 0), _axis) + .Split(user_op::OpArg("dz", 0), _axis) + .Split(user_op::OpArg("dy", 0), _axis) .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dy", 0) = ctx->InputDType("y", 0); - return Maybe::Ok(); - }); + } + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("z", 0)) + .Broadcast(user_op::OpArg("dz", 0)) + .Broadcast(user_op::OpArg("dy", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("z", 0)) + .PartialSum(user_op::OpArg("dz", 0)) + 
.Broadcast(user_op::OpArg("dy", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe BroadcastPowYGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dy", 0) = ctx->InputDType("y", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/buffer_op.cpp b/oneflow/user/ops/buffer_op.cpp index 3131d1b35b1..eb8abde1ee6 100644 --- a/oneflow/user/ops/buffer_op.cpp +++ b/oneflow/user/ops/buffer_op.cpp @@ -14,39 +14,35 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -REGISTER_NO_GRAD_USER_OP("identity_buffer") - .Input("in") - .Output("out") - .Attr("buffer_size") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -} // namespace +/* static */ Maybe IdentityBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe IdentityBufferOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe IdentityBufferOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe IdentityBufferOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/cast_like_op.cpp b/oneflow/user/ops/cast_like_op.cpp index a2a7face17c..c4d41a00be8 100644 --- a/oneflow/user/ops/cast_like_op.cpp +++ b/oneflow/user/ops/cast_like_op.cpp @@ -14,56 +14,60 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("cast_like") - .Input("in") - .Input("dtype_like") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* dtype_like_modifier = GetInputArgModifierFn("dtype_like", 0); - CHECK_NOTNULL_OR_RETURN(dtype_like_modifier); - dtype_like_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); - for (int i = 0; i < in_shape.NumAxes(); ++i) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("dtype_like", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("dtype_like", 0)) - .Broadcast(user_op::OpArg("in", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("dtype_like", 0)) - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("dtype_like", 0)) - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dtype_like_tensor_desc = ctx->InputTensorDesc("dtype_like", 0); - user_op::TensorDesc* output_tensor_desc = ctx->OutputTensorDesc("out", 0); - *output_tensor_desc->mut_data_type() = dtype_like_tensor_desc.data_type(); - return Maybe::Ok(); - }); +/* static */ Maybe CastLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe CastLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CastLikeOp::GetSbp(user_op::SbpContext* ctx) { + const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); + for (int i = 0; i < in_shape.NumAxes(); ++i) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("dtype_like", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("dtype_like", 0)) + .Broadcast(user_op::OpArg("in", 0)) + .Broadcast(user_op::OpArg("out", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("dtype_like", 0)) + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("dtype_like", 0)) + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CastLikeOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* dtype_like_modifier = GetInputArgModifierFn("dtype_like", 0); + CHECK_NOTNULL_OR_RETURN(dtype_like_modifier); + dtype_like_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe 
CastLikeOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& dtype_like_tensor_desc = ctx->InputTensorDesc("dtype_like", 0); + user_op::TensorDesc* output_tensor_desc = ctx->OutputTensorDesc("out", 0); + *output_tensor_desc->mut_data_type() = dtype_like_tensor_desc.data_type(); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/cast_op.cpp b/oneflow/user/ops/cast_op.cpp index 2ae1c246be2..545bcfeaba3 100644 --- a/oneflow/user/ops/cast_op.cpp +++ b/oneflow/user/ops/cast_op.cpp @@ -14,11 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { -Maybe TensorDescInfer(user_op::InferContext* ctx) { +/* static */ Maybe CastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& input_tensor_desc = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* output_tensor_desc = ctx->OutputTensorDesc("out", 0); *output_tensor_desc->mut_shape() = input_tensor_desc.shape(); @@ -26,7 +26,11 @@ Maybe TensorDescInfer(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe GetSbpSignatures(user_op::SbpContext* ctx) { +/*static*/ Maybe CastOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CastOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); @@ -35,21 +39,13 @@ Maybe GetSbpSignatures(user_op::SbpContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { +/* static */ Maybe CastOp::InferDataType(user_op::InferContext* ctx) { user_op::TensorDesc* output_tensor_desc = ctx->OutputTensorDesc("out", 0); DataType* dtype = output_tensor_desc->mut_data_type(); *dtype = ctx->Attr("dtype"); return Maybe::Ok(); } -REGISTER_USER_OP("cast") - .Input("in") - .Attr("dtype") - .Output("out") - .SetTensorDescInferFn(TensorDescInfer) - .SetGetSbpFn(GetSbpSignatures) - .SetDataTypeInferFn(InferDataType); - REGISTER_USER_OP_GRAD("cast").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { if (op.NeedGenGradTensor4OpInput("in", 0)) { @@ -67,5 +63,4 @@ REGISTER_USER_OP_GRAD("cast").SetGenBackwardOpConfFn([](const user_op::UserOpWra return Maybe::Ok(); }); -} // namespace } // namespace oneflow diff --git a/oneflow/user/ops/cast_to_static_shape_op.cpp b/oneflow/user/ops/cast_to_static_shape_op.cpp index 749f4940bf7..20843124a24 100644 --- a/oneflow/user/ops/cast_to_static_shape_op.cpp +++ b/oneflow/user/ops/cast_to_static_shape_op.cpp @@ -14,38 +14,41 @@ See the License for the specific language governing permissions and limitations under the License. 
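// A minimal sketch of the pattern the files above converge on, using a
// hypothetical MyCastOp (a real op gets this class declared for it in
// op_generated.h via its table-gen definition); the shapes, dtype attribute,
// and per-axis split signatures follow the cast / cast_like examples above.
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/* static */ Maybe<void> MyCastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
  *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0);
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> MyCastOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  // Elementwise ops infer the per-rank (physical) shape exactly like the logical one.
  return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> MyCastOp::GetSbp(user_op::SbpContext* ctx) {
  const auto& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) {
    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
  }
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> MyCastOp::InferDataType(user_op::InferContext* ctx) {
  *ctx->OutputDType("out", 0) = ctx->Attr<DataType>("dtype");
  return Maybe<void>::Ok();
}

}  // namespace oneflow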
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("cast_to_static_shape") - .Input("input") - .Output("output") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); - user_op::TensorDesc* output_desc = ctx->OutputTensorDesc("output", 0); - *output_desc->mut_shape() = input_desc.shape(); - output_desc->set_is_dynamic(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& input_desc = - ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0); - FOR_RANGE(int64_t, i, 0, input_desc.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("input", 0), i) - .Split(user_op::OpArg("output", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("input", 0)) - .PartialSum(user_op::OpArg("output", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); - return Maybe::Ok(); - }); +/* static */ Maybe CastToStaticShapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); + user_op::TensorDesc* output_desc = ctx->OutputTensorDesc("output", 0); + *output_desc->mut_shape() = input_desc.shape(); + output_desc->set_is_dynamic(false); + return Maybe::Ok(); +} + +/*static*/ Maybe CastToStaticShapeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CastToStaticShapeOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& input_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0); + FOR_RANGE(int64_t, i, 0, input_desc.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("input", 0), i) + .Split(user_op::OpArg("output", 0), i) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("input", 0)) + .PartialSum(user_op::OpArg("output", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CastToStaticShapeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("cast_to_static_shape") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/cast_to_tick_op.cpp b/oneflow/user/ops/cast_to_tick_op.cpp index de7ec6b9d87..1daf6241bb0 100644 --- a/oneflow/user/ops/cast_to_tick_op.cpp +++ b/oneflow/user/ops/cast_to_tick_op.cpp @@ -15,43 +15,46 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -REGISTER_NO_GRAD_USER_OP("cast_to_tick") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = Shape({1}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), parallel_hierarchy.NumAxes()); - - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - // in use hint - in_distribution->CopyFrom(in_dis_hint); - - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - // out dim1 = broadcast - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - } - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -} // namespace +/* static */ Maybe CastToTickOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = Shape({1}); + return Maybe::Ok(); +} + +/*static*/ Maybe CastToTickOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CastToTickOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe CastToTickOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), parallel_hierarchy.NumAxes()); + + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + // in use hint + in_distribution->CopyFrom(in_dis_hint); + + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + // out dim1 = broadcast + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + } + return Maybe::Ok(); +} + +/* static */ Maybe CastToTickOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/categorical_ordinal_encode_op.cpp b/oneflow/user/ops/categorical_ordinal_encode_op.cpp index cb2f0d4351f..ca2b4533826 100644 --- a/oneflow/user/ops/categorical_ordinal_encode_op.cpp +++ b/oneflow/user/ops/categorical_ordinal_encode_op.cpp @@ -14,64 +14,66 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("CategoricalOrdinalEncode") - .Input("table") - .Input("size") - .Input("in") - .Output("out") - .Attr("hash_precomputed") - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->parallel_ctx().parallel_num(), 1); - const Shape& table_shape = ctx->InputShape("table", 0); - CHECK_EQ_OR_RETURN(table_shape.NumAxes(), 1); - CHECK_EQ_OR_RETURN(table_shape.elem_cnt() % 2, 0); - const Shape& size_shape = ctx->InputShape("size", 0); - CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1); - CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& table_shape = ctx->InputShape("table", 0); - CHECK_EQ_OR_RETURN(table_shape.NumAxes(), 1); - CHECK_EQ_OR_RETURN(table_shape.elem_cnt() % 2, 0); - const Shape& size_shape = ctx->InputShape("size", 0); - CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1); - CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* table = GetInputArgModifierFn("table", 0); - table->set_is_mutable(true); - table->set_requires_grad(false); - user_op::InputArgModifier* size = GetInputArgModifierFn("size", 0); - size->set_is_mutable(true); - size->set_requires_grad(false); - user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); - in->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->parallel_num(), 1); - return Maybe::Ok(); - }) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def, - const user_op::UserOpConfWrapper& op_conf) -> Maybe { - CHECK_OR_RETURN(op_conf.attr("hash_precomputed")); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const DataType& data_type = ctx->InputDType("in", 0); - CHECK_OR_RETURN(IsIndexDataType(data_type)); - CHECK_EQ_OR_RETURN(ctx->InputDType("table", 0), data_type); - CHECK_EQ_OR_RETURN(ctx->InputDType("size", 0), data_type); - *ctx->OutputDType("out", 0) = data_type; - return Maybe::Ok(); - }); +/* static */ Maybe CategoricalOrdinalEncodeOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const Shape& table_shape = ctx->InputShape("table", 0); + CHECK_EQ_OR_RETURN(table_shape.NumAxes(), 1); + CHECK_EQ_OR_RETURN(table_shape.elem_cnt() % 2, 0); + const Shape& size_shape = ctx->InputShape("size", 0); + CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1); + CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1); + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe CategoricalOrdinalEncodeOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->parallel_ctx().parallel_num(), 1); + const Shape& table_shape = ctx->InputShape("table", 0); + CHECK_EQ_OR_RETURN(table_shape.NumAxes(), 1); + CHECK_EQ_OR_RETURN(table_shape.elem_cnt() % 2, 0); + const Shape& size_shape = ctx->InputShape("size", 0); + CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1); + CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1); + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} + +/* static */ 
Maybe CategoricalOrdinalEncodeOp::GetSbp(user_op::SbpContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->parallel_num(), 1); + return Maybe::Ok(); +} + +/* static */ Maybe CategoricalOrdinalEncodeOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* table = GetInputArgModifierFn("table", 0); + table->set_is_mutable(true); + table->set_requires_grad(false); + user_op::InputArgModifier* size = GetInputArgModifierFn("size", 0); + size->set_is_mutable(true); + size->set_requires_grad(false); + user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); + in->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe CategoricalOrdinalEncodeOp::CheckAttr( + const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { + CHECK_OR_RETURN(conf.attr("hash_precomputed")); + return Maybe::Ok(); +} + +/* static */ Maybe CategoricalOrdinalEncodeOp::InferDataType(user_op::InferContext* ctx) { + const DataType& data_type = ctx->InputDType("in", 0); + CHECK_OR_RETURN(IsIndexDataType(data_type)); + CHECK_EQ_OR_RETURN(ctx->InputDType("table", 0), data_type); + CHECK_EQ_OR_RETURN(ctx->InputDType("size", 0), data_type); + *ctx->OutputDType("out", 0) = data_type; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/celu_op.cpp b/oneflow/user/ops/celu_op.cpp index 395c85a67f6..60d48152434 100644 --- a/oneflow/user/ops/celu_op.cpp +++ b/oneflow/user/ops/celu_op.cpp @@ -14,63 +14,62 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe CeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("celu") - .Input("in") - .Output("out") - .Attr("alpha") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe CeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("celu_grad") - .Input("x") - .Input("dy") - .Output("dx") - .Attr("alpha") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - 
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe CeluOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe CeluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe CeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe CeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CeluGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe CeluGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("celu").SetBackwardOpConfGenFn( [](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -90,6 +89,4 @@ REGISTER_USER_OP_GRAD("celu").SetBackwardOpConfGenFn( return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/clip_by_value_op.cpp b/oneflow/user/ops/clip_by_value_op.cpp index acadfc6ca01..f216e077816 100644 --- a/oneflow/user/ops/clip_by_value_op.cpp +++ b/oneflow/user/ops/clip_by_value_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
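// Gradient wiring is untouched by this refactor: only the op definitions move
// into the generated classes, while REGISTER_USER_OP_GRAD stays macro-based,
// as with "celu" above. A hedged sketch of that mechanism for a hypothetical
// unary op "my_celu" whose gradient is produced by an existing "my_celu_grad"
// op (the names and the alpha attribute are illustrative only):
REGISTER_USER_OP_GRAD("my_celu").SetGenBackwardOpConfFn(
    [](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe<void> {
      if (op.NeedGenGradTensor4OpInput("in", 0)) {
        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
        user_op::UserOpConfWrapper grad_op =
            builder.Op("my_celu_grad")
                .Input("x", op.input("in", 0))
                .Input("dy", op.GetGradTensorWithOpOutput("out", 0))
                .Output("dx")
                .Attr("alpha", op.attr<double>("alpha"))
                .Build();
        op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "in", 0);
        AddOp(grad_op);
      }
      return Maybe<void>::Ok();
    });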
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -66,66 +67,45 @@ Maybe InferClipGradDataType(user_op::InferContext* ctx) { } // namespace -REGISTER_USER_OP("clip_by_scalar") - .Input("x") - .Attr("floating_min") - .Attr("integral_min") - .Attr("floating_max") - .Attr("integral_max") - .Output("y") - .SetTensorDescInferFn(InferClipTensorDesc) - .SetGetSbpFn(GetClipSbpSignature) - .SetDataTypeInferFn(InferClipTensorDataType); - -REGISTER_USER_OP("clip_by_scalar_min") - .Input("x") - .Attr("floating_min") - .Attr("integral_min") - .Output("y") - .SetTensorDescInferFn(InferClipTensorDesc) - .SetGetSbpFn(GetClipSbpSignature) - .SetDataTypeInferFn(InferClipTensorDataType); - -REGISTER_USER_OP("clip_by_scalar_max") - .Input("x") - .Attr("floating_max") - .Attr("integral_max") - .Output("y") - .SetTensorDescInferFn(InferClipTensorDesc) - .SetGetSbpFn(GetClipSbpSignature) - .SetDataTypeInferFn(InferClipTensorDataType); - -REGISTER_USER_OP("clip_by_scalar_grad") - .Input("dy") - .Input("x") - .Attr("floating_min") - .Attr("integral_min") - .Attr("floating_max") - .Attr("integral_max") - .Output("dx") - .SetTensorDescInferFn(InferClipGradTensorDesc) - .SetGetSbpFn(GetClipGradSbpSignature) - .SetDataTypeInferFn(InferClipGradDataType); - -REGISTER_USER_OP("clip_by_scalar_min_grad") - .Input("dy") - .Input("x") - .Attr("floating_min") - .Attr("integral_min") - .Output("dx") - .SetTensorDescInferFn(InferClipGradTensorDesc) - .SetGetSbpFn(GetClipGradSbpSignature) - .SetDataTypeInferFn(InferClipGradDataType); - -REGISTER_USER_OP("clip_by_scalar_max_grad") - .Input("dy") - .Input("x") - .Attr("floating_max") - .Attr("integral_max") - .Output("dx") - .SetTensorDescInferFn(InferClipGradTensorDesc) - .SetGetSbpFn(GetClipGradSbpSignature) - .SetDataTypeInferFn(InferClipGradDataType); +#define DEF_CLIP_BY_VALUE_OP(op_class_name_prefix) \ + /* static */ Maybe op_class_name_prefix##Op::InferLogicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferClipTensorDesc(ctx); \ + } \ + \ + /*static*/ Maybe op_class_name_prefix##Op::InferPhysicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##Op::GetSbp(user_op::SbpContext* ctx) { \ + return GetClipSbpSignature(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##Op::InferDataType(user_op::InferContext* ctx) { \ + return InferClipTensorDataType(ctx); \ + } \ + /* static */ Maybe op_class_name_prefix##GradOp::InferLogicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferClipGradTensorDesc(ctx); \ + } \ + /*static*/ Maybe op_class_name_prefix##GradOp::InferPhysicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /* static */ Maybe op_class_name_prefix##GradOp::GetSbp(user_op::SbpContext* ctx) { \ + return GetClipGradSbpSignature(ctx); \ + } \ + /* static */ Maybe op_class_name_prefix##GradOp::InferDataType( \ + user_op::InferContext* ctx) { \ + return InferClipGradDataType(ctx); \ + } + +DEF_CLIP_BY_VALUE_OP(ClipByScalar) +DEF_CLIP_BY_VALUE_OP(ClipByScalarMin) +DEF_CLIP_BY_VALUE_OP(ClipByScalarMax) + +#undef DEF_CLIP_BY_VALUE_OP REGISTER_USER_OP_GRAD("clip_by_scalar") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/coco_reader_op.cpp b/oneflow/user/ops/coco_reader_op.cpp index 6ab6f25457d..adfca1c99bf 100644 --- a/oneflow/user/ops/coco_reader_op.cpp +++ b/oneflow/user/ops/coco_reader_op.cpp 
@@ -14,135 +14,122 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("COCOReader") - .Output("image") - .Output("image_id") - .Output("image_size") - .Output("gt_bbox") - .Output("gt_label") - .Output("gt_segm") - .Output("gt_segm_index") - .Attr("session_id") - .Attr("annotation_file") - .Attr("image_dir") - .Attr("batch_size") - .Attr("shuffle_after_epoch", true) - .Attr("random_seed", -1) - .Attr("group_by_ratio", true) - .Attr("remove_images_without_annotations", true) - .Attr("stride_partition", false) - .Attr>("nd_sbp") - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("image", 0); - CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("image_id", 0)); - CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("image_size", 0)); - CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_bbox", 0)); - CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_label", 0)); - CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_segm", 0)); - CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_segm_index", 0)); - - int64_t batch_size = ctx->Attr("batch_size"); - int64_t parallel_num = ctx->parallel_ctx().parallel_num(); - int64_t device_batch_size = batch_size; - if (sbp.has_split_parallel() && parallel_num > 1) { - CHECK_EQ_OR_RETURN(device_batch_size % parallel_num, 0); - device_batch_size /= parallel_num; - } - - user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); - *image_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); - *image_id_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 0); - *image_size_desc->mut_shape() = Shape({device_batch_size, 2}); - user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); - *bbox_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* label_desc = ctx->OutputTensorDesc("gt_label", 0); - *label_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); - *segm_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* segm_index_desc = ctx->OutputTensorDesc("gt_segm_index", 0); - *segm_index_desc->mut_shape() = Shape({device_batch_size}); - return Maybe::Ok(); - }) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - int64_t batch_size = ctx->Attr("batch_size"); - user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); - *image_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); - *image_id_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 0); - *image_size_desc->mut_shape() = Shape({batch_size, 2}); - user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); - *bbox_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* label_desc = ctx->OutputTensorDesc("gt_label", 0); - *label_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); - *segm_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* segm_index_desc = 
ctx->OutputTensorDesc("gt_segm_index", 0); - *segm_index_desc->mut_shape() = Shape({batch_size}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); - *image_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); - *image_id_desc->mut_data_type() = DataType::kInt64; - user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 0); - *image_size_desc->mut_data_type() = DataType::kInt32; - user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); - *bbox_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* label_desc = ctx->OutputTensorDesc("gt_label", 0); - *label_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); - *segm_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* segm_index_desc = ctx->OutputTensorDesc("gt_segm_index", 0); - *segm_index_desc->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_split_parallel()->set_axis(0); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - user_op::OutputArgModifier* image_modifier = GetOutputArgModifierFn("image", 0); - CHECK_OR_RETURN(image_modifier != nullptr); - image_modifier->set_header_infered_before_compute(false); - - user_op::OutputArgModifier* image_id_modifier = GetOutputArgModifierFn("image_id", 0); - CHECK_OR_RETURN(image_id_modifier != nullptr); - image_id_modifier->set_header_infered_before_compute(false); - - user_op::OutputArgModifier* image_size_modifier = GetOutputArgModifierFn("image_size", 0); - CHECK_OR_RETURN(image_size_modifier != nullptr); - image_size_modifier->set_header_infered_before_compute(false); - - user_op::OutputArgModifier* gt_bbox_modifier = GetOutputArgModifierFn("gt_bbox", 0); - CHECK_OR_RETURN(gt_bbox_modifier != nullptr); - gt_bbox_modifier->set_header_infered_before_compute(false); - - user_op::OutputArgModifier* gt_label_modifier = GetOutputArgModifierFn("gt_label", 0); - CHECK_OR_RETURN(gt_label_modifier != nullptr); - gt_label_modifier->set_header_infered_before_compute(false); - - user_op::OutputArgModifier* gt_segm_modifier = GetOutputArgModifierFn("gt_segm", 0); - CHECK_OR_RETURN(gt_segm_modifier != nullptr); - gt_segm_modifier->set_header_infered_before_compute(false); - - user_op::OutputArgModifier* gt_segm_index_modifier = - GetOutputArgModifierFn("gt_segm_index", 0); - CHECK_OR_RETURN(gt_segm_index_modifier != nullptr); - gt_segm_index_modifier->set_header_infered_before_compute(false); - return Maybe::Ok(); - }); +/* static */ Maybe COCOReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + int64_t batch_size = ctx->Attr("batch_size"); + user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); + *image_desc->mut_shape() = Shape({batch_size}); + user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); + *image_id_desc->mut_shape() = Shape({batch_size}); + user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 
0); + *image_size_desc->mut_shape() = Shape({batch_size, 2}); + user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); + *bbox_desc->mut_shape() = Shape({batch_size}); + user_op::TensorDesc* label_desc = ctx->OutputTensorDesc("gt_label", 0); + *label_desc->mut_shape() = Shape({batch_size}); + user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); + *segm_desc->mut_shape() = Shape({batch_size}); + user_op::TensorDesc* segm_index_desc = ctx->OutputTensorDesc("gt_segm_index", 0); + *segm_index_desc->mut_shape() = Shape({batch_size}); + return Maybe::Ok(); +} + +/* static */ Maybe COCOReaderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("image", 0); + CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("image_id", 0)); + CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("image_size", 0)); + CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_bbox", 0)); + CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_label", 0)); + CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_segm", 0)); + CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_segm_index", 0)); + + int64_t batch_size = ctx->Attr("batch_size"); + int64_t parallel_num = ctx->parallel_ctx().parallel_num(); + int64_t device_batch_size = batch_size; + if (sbp.has_split_parallel() && parallel_num > 1) { + CHECK_EQ_OR_RETURN(device_batch_size % parallel_num, 0); + device_batch_size /= parallel_num; + } + + user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); + *image_desc->mut_shape() = Shape({device_batch_size}); + user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); + *image_id_desc->mut_shape() = Shape({device_batch_size}); + user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 0); + *image_size_desc->mut_shape() = Shape({device_batch_size, 2}); + user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); + *bbox_desc->mut_shape() = Shape({device_batch_size}); + user_op::TensorDesc* label_desc = ctx->OutputTensorDesc("gt_label", 0); + *label_desc->mut_shape() = Shape({device_batch_size}); + user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); + *segm_desc->mut_shape() = Shape({device_batch_size}); + user_op::TensorDesc* segm_index_desc = ctx->OutputTensorDesc("gt_segm_index", 0); + *segm_index_desc->mut_shape() = Shape({device_batch_size}); + return Maybe::Ok(); +} + +/* static */ Maybe COCOReaderOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe COCOReaderOp::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::OutputArgModifier* image_modifier = GetOutputArgModifierFn("image", 0); + CHECK_OR_RETURN(image_modifier != nullptr); + image_modifier->set_header_infered_before_compute(false); + + user_op::OutputArgModifier* image_id_modifier = GetOutputArgModifierFn("image_id", 0); + CHECK_OR_RETURN(image_id_modifier != nullptr); + image_id_modifier->set_header_infered_before_compute(false); + + user_op::OutputArgModifier* image_size_modifier = GetOutputArgModifierFn("image_size", 0); + CHECK_OR_RETURN(image_size_modifier != nullptr); + image_size_modifier->set_header_infered_before_compute(false); + + user_op::OutputArgModifier* gt_bbox_modifier = GetOutputArgModifierFn("gt_bbox", 0); + CHECK_OR_RETURN(gt_bbox_modifier != nullptr); + 
gt_bbox_modifier->set_header_infered_before_compute(false); + + user_op::OutputArgModifier* gt_label_modifier = GetOutputArgModifierFn("gt_label", 0); + CHECK_OR_RETURN(gt_label_modifier != nullptr); + gt_label_modifier->set_header_infered_before_compute(false); + + user_op::OutputArgModifier* gt_segm_modifier = GetOutputArgModifierFn("gt_segm", 0); + CHECK_OR_RETURN(gt_segm_modifier != nullptr); + gt_segm_modifier->set_header_infered_before_compute(false); + + user_op::OutputArgModifier* gt_segm_index_modifier = GetOutputArgModifierFn("gt_segm_index", 0); + CHECK_OR_RETURN(gt_segm_index_modifier != nullptr); + gt_segm_index_modifier->set_header_infered_before_compute(false); + return Maybe::Ok(); +} + +/* static */ Maybe COCOReaderOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::SbpParallel default_sbp; + default_sbp.mutable_split_parallel()->set_axis(0); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} + +/* static */ Maybe COCOReaderOp::InferDataType(user_op::InferContext* ctx) { + user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); + *image_desc->mut_data_type() = DataType::kTensorBuffer; + user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); + *image_id_desc->mut_data_type() = DataType::kInt64; + user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 0); + *image_size_desc->mut_data_type() = DataType::kInt32; + user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); + *bbox_desc->mut_data_type() = DataType::kTensorBuffer; + user_op::TensorDesc* label_desc = ctx->OutputTensorDesc("gt_label", 0); + *label_desc->mut_data_type() = DataType::kTensorBuffer; + user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); + *segm_desc->mut_data_type() = DataType::kTensorBuffer; + user_op::TensorDesc* segm_index_desc = ctx->OutputTensorDesc("gt_segm_index", 0); + *segm_index_desc->mut_data_type() = DataType::kTensorBuffer; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/combined_margin_loss_op.cpp b/oneflow/user/ops/combined_margin_loss_op.cpp index d420cb35209..72854a53928 100644 --- a/oneflow/user/ops/combined_margin_loss_op.cpp +++ b/oneflow/user/ops/combined_margin_loss_op.cpp @@ -14,96 +14,94 @@ See the License for the specific language governing permissions and limitations under the License. 
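// The only substantive difference between COCOReader's logical and physical
// shape inference above is the per-rank batch split. A standalone sketch of
// that arithmetic with hypothetical numbers (batch_size = 64, read under a
// split S(0) signature on 4 ranks):
#include <cassert>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t batch_size = 64;   // attr "batch_size" (illustrative value)
  const int64_t parallel_num = 4;  // ctx->parallel_ctx().parallel_num()
  // Mirrors the CHECK_EQ_OR_RETURN(device_batch_size % parallel_num, 0) guard.
  assert(batch_size % parallel_num == 0);
  const int64_t device_batch_size = batch_size / parallel_num;
  std::cout << "logical image shape:       (" << batch_size << ")\n";
  std::cout << "per-rank image shape:      (" << device_batch_size << ")\n";
  std::cout << "per-rank image_size shape: (" << device_batch_size << ", 2)\n";
  return 0;
}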
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("combined_margin_loss") - .Input("x") - .Input("label") - .Output("y") - .Output("theta") - .Attr("m1") - .Attr("m2") - .Attr("m3") - .Attr("depth") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); - user_op::TensorDesc* theta = ctx->OutputTensorDesc("theta", 0); - CHECK_EQ_OR_RETURN(label.shape().At(0), x.shape().At(0)); - CHECK_GE_OR_RETURN(x.shape().NumAxes(), 2); - *ctx->OutputShape("y", 0) = ctx->InputShape("x", 0); - *ctx->IsDynamic4ArgNameAndIndex("y", 0) = ctx->InputIsDynamic("x", 0); - *theta->mut_is_dynamic() = x.is_dynamic(); - *theta->mut_shape() = label.shape(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* label_arg_modifier = GetInputArgModifierFn("label", 0); - label_arg_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), 0) - .Split(user_op::OpArg("label", 0), 0) - .Split(user_op::OpArg("y", 0), 0) - .Split(user_op::OpArg("theta", 0), 0) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), 1) - .Broadcast(user_op::OpArg("label", 0)) - .Split(user_op::OpArg("y", 0), 1) - .PartialSum(user_op::OpArg("theta", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("theta", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe CombinedMarginLossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); + user_op::TensorDesc* theta = ctx->OutputTensorDesc("theta", 0); + CHECK_EQ_OR_RETURN(label.shape().At(0), x.shape().At(0)); + CHECK_GE_OR_RETURN(x.shape().NumAxes(), 2); + *ctx->OutputShape("y", 0) = ctx->InputShape("x", 0); + *ctx->IsDynamic4ArgNameAndIndex("y", 0) = ctx->InputIsDynamic("x", 0); + *theta->mut_is_dynamic() = x.is_dynamic(); + *theta->mut_shape() = label.shape(); + return Maybe::Ok(); +} -REGISTER_USER_OP("combined_margin_loss_grad") - .Input("dy") - .Input("label") - .Input("theta") - .Output("dx") - .Attr("m1") - .Attr("m2") - .Attr("m3") - .Attr("depth") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); - const user_op::TensorDesc& theta = ctx->InputTensorDesc("theta", 0); - CHECK_EQ_OR_RETURN(label.shape().At(0), dy.shape().At(0)); - CHECK_EQ_OR_RETURN(label.shape().At(0), theta.shape().At(0)); - CHECK_GE_OR_RETURN(dy.shape().NumAxes(), 2); - *ctx->OutputShape("dx", 0) = ctx->InputShape("dy", 0); - *ctx->IsDynamic4ArgNameAndIndex("dx", 0) = ctx->InputIsDynamic("dy", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("label", 0), 
0) - .Split(user_op::OpArg("theta", 0), 0) - .Split(user_op::OpArg("dx", 0), 0) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 1) - .Broadcast(user_op::OpArg("label", 0)) - .Broadcast(user_op::OpArg("theta", 0)) - .Split(user_op::OpArg("dx", 0), 1) - .Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe CombinedMarginLossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CombinedMarginLossOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), 0) + .Split(user_op::OpArg("label", 0), 0) + .Split(user_op::OpArg("y", 0), 0) + .Split(user_op::OpArg("theta", 0), 0) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), 1) + .Broadcast(user_op::OpArg("label", 0)) + .Split(user_op::OpArg("y", 0), 1) + .PartialSum(user_op::OpArg("theta", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CombinedMarginLossOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* label_arg_modifier = GetInputArgModifierFn("label", 0); + label_arg_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe CombinedMarginLossOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->OutputDType("theta", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} + +/* static */ Maybe CombinedMarginLossGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); + const user_op::TensorDesc& theta = ctx->InputTensorDesc("theta", 0); + CHECK_EQ_OR_RETURN(label.shape().At(0), dy.shape().At(0)); + CHECK_EQ_OR_RETURN(label.shape().At(0), theta.shape().At(0)); + CHECK_GE_OR_RETURN(dy.shape().NumAxes(), 2); + *ctx->OutputShape("dx", 0) = ctx->InputShape("dy", 0); + *ctx->IsDynamic4ArgNameAndIndex("dx", 0) = ctx->InputIsDynamic("dy", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe CombinedMarginLossGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CombinedMarginLossGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 0) + .Split(user_op::OpArg("label", 0), 0) + .Split(user_op::OpArg("theta", 0), 0) + .Split(user_op::OpArg("dx", 0), 0) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 1) + .Broadcast(user_op::OpArg("label", 0)) + .Broadcast(user_op::OpArg("theta", 0)) + .Split(user_op::OpArg("dx", 0), 1) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CombinedMarginLossGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("combined_margin_loss") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/concat_op.cpp b/oneflow/user/ops/concat_op.cpp index 253bb465ed1..b631d4a15c8 100644 --- a/oneflow/user/ops/concat_op.cpp +++ b/oneflow/user/ops/concat_op.cpp @@ -14,12 +14,40 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { -Maybe InferTensorDesc(user_op::InferContext* ctx) { +Maybe GenGrapOp(const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) { + bool need_grad = false; + const int32_t in_size = op.input_size("in"); + FOR_RANGE(int32_t, i, 0, in_size) { + if (op.NeedGenGradTensor4OpInput("in", i)) { need_grad = true; } + } + if (need_grad) { + user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); + builder = builder.Op("split_like"); + FOR_RANGE(int32_t, i, 0, in_size) { builder = builder.Input("like", op.input("in", i)); } + user_op::UserOpConfWrapper grad_op = builder.Input("in", op.GetGradTensorWithOpOutput("out", 0)) + .Output("out", in_size) + .Attr("axis", op.attr("axis")) + .Build(); + + FOR_RANGE(int32_t, i, 0, in_size) { + if (op.NeedGenGradTensor4OpInput("in", i)) { + op.BindGradTensorWithOpInput(grad_op.output("out", i), "in", i); + } + } + AddOp(grad_op); + } + return Maybe::Ok(); +} + +} // namespace + +/* static */ Maybe ConcatOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("in", 0); const int64_t axis = ctx->Attr("axis"); CHECK_GE_OR_RETURN(axis, 0); @@ -57,7 +85,11 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe GetSbpSignature(user_op::SbpContext* ctx) { +/*static*/ Maybe ConcatOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ConcatOp::GetSbp(user_op::SbpContext* ctx) { const int64_t axis = ctx->Attr("axis"); const user_op::TensorDesc& first_in_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, first_in_desc.shape().NumAxes()) { @@ -68,32 +100,7 @@ Maybe GetSbpSignature(user_op::SbpContext* ctx) { return Maybe::Ok(); } -Maybe GenGrapOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) { - bool need_grad = false; - const int32_t in_size = op.input_size("in"); - FOR_RANGE(int32_t, i, 0, in_size) { - if (op.NeedGenGradTensor4OpInput("in", i)) { need_grad = true; } - } - if (need_grad) { - user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); - builder = builder.Op("split_like"); - FOR_RANGE(int32_t, i, 0, in_size) { builder = builder.Input("like", op.input("in", i)); } - user_op::UserOpConfWrapper grad_op = builder.Input("in", op.GetGradTensorWithOpOutput("out", 0)) - .Output("out", in_size) - .Attr("axis", op.attr("axis")) - .Build(); - - FOR_RANGE(int32_t, i, 0, in_size) { - if (op.NeedGenGradTensor4OpInput("in", i)) { - op.BindGradTensorWithOpInput(grad_op.output("out", i), "in", i); - } - } - AddOp(grad_op); - } - return Maybe::Ok(); -} - -Maybe InferDataType(user_op::InferContext* ctx) { +/* static */ Maybe ConcatOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("in", 0); for (const auto& in_arg_pair : ctx->inputs()) { const user_op::TensorDesc& in_desc = @@ -105,16 +112,11 @@ Maybe InferDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } -} // namespace - -REGISTER_USER_OP("concat") - .InputWithMinimum("in", 2) - .Output("out") - .Attr("axis") - .Attr("max_dim_size") - .SetTensorDescInferFn(InferTensorDesc) - .SetGetSbpFn(GetSbpSignature) - .SetDataTypeInferFn(InferDataType); +/*static*/ Maybe ConcatOp::CheckAttr(const user_op::UserOpDefWrapper&, + const user_op::UserOpConfWrapper& op_conf) { + 
CHECK_OR_RETURN(op_conf.input_size("in") >= 2); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("concat").SetGenBackwardOpConfFn(GenGrapOp); diff --git a/oneflow/user/ops/constant_op.cpp b/oneflow/user/ops/constant_op.cpp index d1b432f4760..d96f07e7562 100644 --- a/oneflow/user/ops/constant_op.cpp +++ b/oneflow/user/ops/constant_op.cpp @@ -14,31 +14,26 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("constant") - .Output("out") - .SetOutputBufferNum(1) - .Attr("floating_value") - .Attr("integer_value") - .Attr("is_floating_value") - .Attr("dtype") - .Attr("shape") - .Attr>("nd_sbp") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_broadcast_parallel(); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }); +/* static */ Maybe ConstantOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + return Maybe::Ok(); +} + +/* static */ Maybe ConstantOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } + +/* static */ Maybe ConstantOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::SbpParallel default_sbp; + default_sbp.mutable_broadcast_parallel(); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} + +/* static */ Maybe ConstantOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/conv_op.cpp b/oneflow/user/ops/conv_op.cpp index 2c7b9806e09..64940f4d2da 100644 --- a/oneflow/user/ops/conv_op.cpp +++ b/oneflow/user/ops/conv_op.cpp @@ -15,6 +15,7 @@ limitations under the License. 
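// concat's old .InputWithMinimum("in", 2) constraint now lives in the
// generated-class CheckAttr hook, shown above; the same idiom fits any op that
// needs a config-time invariant. A hedged sketch for a hypothetical variadic
// op (class name and minimum input count are illustrative):
/* static */ Maybe<void> MyStackOp::CheckAttr(const user_op::UserOpDefWrapper&,
                                              const user_op::UserOpConfWrapper& op_conf) {
  // Reject op configs wired up with fewer than two "in" tensors.
  CHECK_OR_RETURN(op_conf.input_size("in") >= 2);
  return Maybe<void>::Ok();
}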
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -114,8 +115,8 @@ Maybe GetSbpSignatures4Conv(user_op::SbpContext* ctx) { } template -Maybe CheckAttr(const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) { +Maybe CheckAttr_(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { bool is_checked = true; std::stringstream err; err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": "; @@ -229,241 +230,239 @@ Maybe GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_o } // namespace -REGISTER_USER_OP("conv1d") - .Input("in") - .Input("weight") - .OptionalInput("bias") - .OptionalInput("bias_multiplier") // cudnn conv doesn't need this - .Output("out") - .Attr("filters") - .Attr>("padding_before") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("dilation_rate") - .Attr("groups", 1) - .SetCheckAttrFn(CheckAttr<1>) - .SetTensorDescInferFn(InferTensorDesc4Conv<1>) - .SetGetSbpFn(GetSbpSignatures4Conv) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_USER_OP("conv2d") - .Input("in") - .Input("weight") - .OptionalInput("bias") - .OptionalInput("bias_multiplier") // cudnn conv doesn't need this - .Output("out") - .Attr("filters") - .Attr>("padding_before") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("dilation_rate") - .Attr("groups", 1) - .SetCheckAttrFn(CheckAttr<2>) - .SetTensorDescInferFn(InferTensorDesc4Conv<2>) - .SetGetSbpFn(GetSbpSignatures4Conv) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_USER_OP("conv3d") - .Input("in") - .Input("weight") - .OptionalInput("bias") - .OptionalInput("bias_multiplier") // cudnn conv doesn't need this - .Output("out") - .Attr("filters") - .Attr>("padding_before") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("dilation_rate") - .Attr("groups", 1) - .SetCheckAttrFn(CheckAttr<3>) - .SetTensorDescInferFn(InferTensorDesc4Conv<3>) - .SetGetSbpFn(GetSbpSignatures4Conv) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/* static */ Maybe Conv1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc4Conv<1>(ctx); +} + +/*static*/ Maybe Conv1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe Conv1DOp::GetSbp(user_op::SbpContext* ctx) { + return GetSbpSignatures4Conv(ctx); +} + +/* static */ Maybe Conv1DOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_<1>(def, conf); +} + +/* static */ Maybe Conv1DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe Conv2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc4Conv<2>(ctx); +} + +/*static*/ Maybe Conv2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe Conv2DOp::GetSbp(user_op::SbpContext* ctx) { + return GetSbpSignatures4Conv(ctx); +} + +/* static */ 
Maybe Conv2DOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_<2>(def, conf); +} + +/* static */ Maybe Conv2DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe Conv3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc4Conv<3>(ctx); +} + +/*static*/ Maybe Conv3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe Conv3DOp::GetSbp(user_op::SbpContext* ctx) { + return GetSbpSignatures4Conv(ctx); +} + +/* static */ Maybe Conv3DOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_<3>(def, conf); +} + +/* static */ Maybe Conv3DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe ConvDataGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& x_like = ctx->InputTensorDesc("x_like", 0); + const int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); + CHECK_GE_OR_RETURN(num_spatial_dims, 1); + CHECK_LE_OR_RETURN(num_spatial_dims, 3); + CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); + CHECK_EQ_OR_RETURN(x_like.shape().NumAxes(), num_spatial_dims + 2); + if (ctx->has_input("_add_to_output", 0)) { + const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); + CHECK_EQ_OR_RETURN(add_to_output.shape(), x_like.shape()); + } + *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0); + *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe ConvDataGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ConvDataGradOp::GetSbp(user_op::SbpContext* ctx) { + std::vector split_args; + split_args.emplace_back("dy", 0); + split_args.emplace_back("x_like", 0); + split_args.emplace_back("dx", 0); + if (ctx->user_op_conf().has_input("_add_to_output", 0)) { + split_args.emplace_back("_add_to_output", 0); + } + ctx->NewBuilder().Split(split_args, 0).Broadcast(user_op::OpArg("filter", 0)).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ConvDataGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_<0>(def, conf); +} + +/* static */ Maybe ConvDataGradOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& x_like = ctx->InputTensorDesc("x_like", 0); + CHECK_EQ_OR_RETURN(x_like.data_type(), dy.data_type()); + if (ctx->has_input("_add_to_output", 0)) { + const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); + CHECK_EQ_OR_RETURN(add_to_output.data_type(), x_like.data_type()); + } + *ctx->OutputDType("dx", 0) = ctx->InputDType("x_like", 0); + return Maybe::Ok(); +} + +/* static */ Maybe ConvFilterGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + + const int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); + const int32_t groups = ctx->Attr("groups"); + const std::string& data_format = 
ctx->Attr("data_format"); + const std::vector kernel_size = ctx->Attr>("kernel_size"); + + CHECK_GE_OR_RETURN(num_spatial_dims, 1); + CHECK_LE_OR_RETURN(num_spatial_dims, 3); + CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); + CHECK_EQ_OR_RETURN(x.shape().NumAxes(), num_spatial_dims + 2); + CHECK_GT_OR_RETURN(groups, 0); + + DimVector filter_diff_dim_vec; + if (data_format == "channels_first") { + CHECK_LE_OR_RETURN(groups, x.shape().At(1)); + CHECK_LE_OR_RETURN(groups, dy.shape().At(1)); + CHECK_EQ_OR_RETURN(x.shape().At(1) % groups, 0); + CHECK_EQ_OR_RETURN(dy.shape().At(1) % groups, 0); + filter_diff_dim_vec.emplace_back(dy.shape().At(1)); + filter_diff_dim_vec.emplace_back(x.shape().At(1) / groups); + filter_diff_dim_vec.insert(filter_diff_dim_vec.end(), kernel_size.cbegin(), kernel_size.cend()); + } else { + CHECK_EQ_OR_RETURN("channels_last", data_format); + CHECK_EQ_OR_RETURN(groups, 1); + filter_diff_dim_vec.emplace_back(dy.shape().dim_vec().back()); + filter_diff_dim_vec.insert(filter_diff_dim_vec.end(), kernel_size.cbegin(), kernel_size.cend()); + filter_diff_dim_vec.emplace_back(x.shape().dim_vec().back() / groups); + } + + user_op::TensorDesc* filter_diff = ctx->OutputTensorDesc("filter_diff", 0); + *filter_diff->mut_shape() = Shape(filter_diff_dim_vec); + filter_diff->set_is_dynamic(false); + + return Maybe::Ok(); +} + +/*static*/ Maybe ConvFilterGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ConvFilterGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 0) + .Split(user_op::OpArg("x", 0), 0) + .PartialSum(user_op::OpArg("filter_diff", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ConvFilterGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_<0>(def, conf); +} + +/* static */ Maybe ConvFilterGradOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + CHECK_EQ_OR_RETURN(x.data_type(), dy.data_type()); + user_op::TensorDesc* filter_diff = ctx->OutputTensorDesc("filter_diff", 0); + *filter_diff->mut_data_type() = x.data_type(); + return Maybe::Ok(); +} + +/* static */ Maybe ConvBiasGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + user_op::TensorDesc* bias_diff = ctx->OutputTensorDesc("bias_diff", 0); + + int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); + std::string data_format = ctx->Attr("data_format"); + + CHECK_GE_OR_RETURN(num_spatial_dims, 1); + CHECK_LE_OR_RETURN(num_spatial_dims, 3); + CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); + if (data_format == "channels_first") { + *bias_diff->mut_shape() = Shape({dy.shape().At(1)}); + } else if (data_format == "channels_last") { + *bias_diff->mut_shape() = Shape({dy.shape().At(dy.shape().NumAxes() - 1)}); + } else { + OF_UNIMPLEMENTED(); + } + return Maybe::Ok(); +} + +/*static*/ Maybe ConvBiasGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ConvBiasGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 0) + .PartialSum(user_op::OpArg("bias_diff", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ConvBiasGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, 
+ const user_op::UserOpConfWrapper& conf) { + std::string data_format = conf.attr("data_format"); + if (data_format == "channels_first" || data_format == "channels_last") { + return Maybe::Ok(); + } + return oneflow::Error::CheckFailedError() << "Illegal value for " << conf.op_type_name() << " op " + << conf.op_name() << ": data_format:" << data_format; +} + +/* static */ Maybe ConvBiasGradOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + user_op::TensorDesc* bias_diff = ctx->OutputTensorDesc("bias_diff", 0); + *bias_diff->mut_data_type() = dy.data_type(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("conv1d").SetGenBackwardOpConfFn(GenerateBackwardOpConf4Conv); REGISTER_USER_OP_GRAD("conv2d").SetGenBackwardOpConfFn(GenerateBackwardOpConf4Conv); REGISTER_USER_OP_GRAD("conv3d").SetGenBackwardOpConfFn(GenerateBackwardOpConf4Conv); -REGISTER_USER_OP("conv_data_grad") - .Input("dy") - .Input("filter") - .Input("x_like") - .OptionalInput("_add_to_output") - .Output("dx") - .Attr("num_spatial_dims") - .Attr>("padding_before") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("dilation_rate") - .Attr("groups") - .SetCheckAttrFn(CheckAttr<0>) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& x_like = ctx->InputTensorDesc("x_like", 0); - const int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); - CHECK_GE_OR_RETURN(num_spatial_dims, 1); - CHECK_LE_OR_RETURN(num_spatial_dims, 3); - CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); - CHECK_EQ_OR_RETURN(x_like.shape().NumAxes(), num_spatial_dims + 2); - if (ctx->has_input("_add_to_output", 0)) { - const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); - CHECK_EQ_OR_RETURN(add_to_output.shape(), x_like.shape()); - } - *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0); - *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - std::vector split_args; - split_args.emplace_back("dy", 0); - split_args.emplace_back("x_like", 0); - split_args.emplace_back("dx", 0); - if (ctx->user_op_conf().has_input("_add_to_output", 0)) { - split_args.emplace_back("_add_to_output", 0); - } - ctx->NewBuilder().Split(split_args, 0).Broadcast(user_op::OpArg("filter", 0)).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& x_like = ctx->InputTensorDesc("x_like", 0); - CHECK_EQ_OR_RETURN(x_like.data_type(), dy.data_type()); - if (ctx->has_input("_add_to_output", 0)) { - const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); - CHECK_EQ_OR_RETURN(add_to_output.data_type(), x_like.data_type()); - } - *ctx->OutputDType("dx", 0) = ctx->InputDType("x_like", 0); - return Maybe::Ok(); - }); - -REGISTER_USER_OP("conv_filter_grad") - .Input("dy") - .Input("x") - .Output("filter_diff") - .Attr("num_spatial_dims") - .Attr>("padding_before") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("dilation_rate") - .Attr("groups") - .SetCheckAttrFn(CheckAttr<0>) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); 
- - const int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); - const int32_t groups = ctx->Attr("groups"); - const std::string& data_format = ctx->Attr("data_format"); - const std::vector kernel_size = ctx->Attr>("kernel_size"); - - CHECK_GE_OR_RETURN(num_spatial_dims, 1); - CHECK_LE_OR_RETURN(num_spatial_dims, 3); - CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); - CHECK_EQ_OR_RETURN(x.shape().NumAxes(), num_spatial_dims + 2); - CHECK_GT_OR_RETURN(groups, 0); - - DimVector filter_diff_dim_vec; - if (data_format == "channels_first") { - CHECK_LE_OR_RETURN(groups, x.shape().At(1)); - CHECK_LE_OR_RETURN(groups, dy.shape().At(1)); - CHECK_EQ_OR_RETURN(x.shape().At(1) % groups, 0); - CHECK_EQ_OR_RETURN(dy.shape().At(1) % groups, 0); - filter_diff_dim_vec.emplace_back(dy.shape().At(1)); - filter_diff_dim_vec.emplace_back(x.shape().At(1) / groups); - filter_diff_dim_vec.insert(filter_diff_dim_vec.end(), kernel_size.cbegin(), - kernel_size.cend()); - } else { - CHECK_EQ_OR_RETURN("channels_last", data_format); - CHECK_EQ_OR_RETURN(groups, 1); - filter_diff_dim_vec.emplace_back(dy.shape().dim_vec().back()); - filter_diff_dim_vec.insert(filter_diff_dim_vec.end(), kernel_size.cbegin(), - kernel_size.cend()); - filter_diff_dim_vec.emplace_back(x.shape().dim_vec().back() / groups); - } - - user_op::TensorDesc* filter_diff = ctx->OutputTensorDesc("filter_diff", 0); - *filter_diff->mut_shape() = Shape(filter_diff_dim_vec); - filter_diff->set_is_dynamic(false); - - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("x", 0), 0) - .PartialSum(user_op::OpArg("filter_diff", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - CHECK_EQ_OR_RETURN(x.data_type(), dy.data_type()); - user_op::TensorDesc* filter_diff = ctx->OutputTensorDesc("filter_diff", 0); - *filter_diff->mut_data_type() = x.data_type(); - return Maybe::Ok(); - }); - -REGISTER_USER_OP("conv_bias_grad") - .Input("dy") - .Output("bias_diff") - .Attr("data_format") - .Attr("num_spatial_dims") - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) -> Maybe { - std::string data_format = conf.attr("data_format"); - if (data_format == "channels_first" || data_format == "channels_last") { - return Maybe::Ok(); - } - return oneflow::Error::CheckFailedError() - << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() - << ": data_format:" << data_format; - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - user_op::TensorDesc* bias_diff = ctx->TensorDesc4ArgNameAndIndex("bias_diff", 0); - - int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); - std::string data_format = ctx->Attr("data_format"); - - CHECK_GE_OR_RETURN(num_spatial_dims, 1); - CHECK_LE_OR_RETURN(num_spatial_dims, 3); - CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); - if (data_format == "channels_first") { - *bias_diff->mut_shape() = Shape({dy.shape().At(1)}); - } else if (data_format == "channels_last") { - *bias_diff->mut_shape() = Shape({dy.shape().At(dy.shape().NumAxes() - 1)}); - } else { - OF_UNIMPLEMENTED(); - } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - 
ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 0) - .PartialSum(user_op::OpArg("bias_diff", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - user_op::TensorDesc* bias_diff = ctx->TensorDesc4ArgNameAndIndex("bias_diff", 0); - *bias_diff->mut_data_type() = dy.data_type(); - return Maybe::Ok(); - }); - } // namespace oneflow diff --git a/oneflow/user/ops/copy_op.cpp b/oneflow/user/ops/copy_op.cpp index 774cd3c184a..4a9106640fa 100644 --- a/oneflow/user/ops/copy_op.cpp +++ b/oneflow/user/ops/copy_op.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/device.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -23,52 +24,49 @@ namespace { Maybe> MakeOpDevice(const Symbol& in_device, const Symbol& out_device) { if (JUST(in_device->of_type()) == "gpu" && JUST(out_device->of_type()) == "cpu") { - return Device::New("cuda_d2h"); + return Device::New("cuda_d2h", in_device->device_id()); } else if (JUST(in_device->of_type()) == "cpu" && JUST(out_device->of_type()) == "gpu") { - return Device::New("cuda_h2d"); + return Device::New("cuda_h2d", out_device->device_id()); } else { return Device::New(out_device->type(), out_device->device_id()); } } -std::function>(user_op::DeviceInferContext* ctx)> GetDeviceInferFn() { - std::function>(user_op::DeviceInferContext * ctx)> fn = - [](user_op::DeviceInferContext* ctx) -> Maybe> { - Symbol out_device = - JUST(Device::New(ctx->Attr("device_type"), ctx->Attr("device_id"))); - *ctx->OutputTensorDevice4ArgNameAndIndex("out", 0) = out_device; - const Symbol& in_device = ctx->InputTensorDevice4ArgNameAndIndex("in", 0); - return MakeOpDevice(in_device, out_device); - }; - return fn; +} // namespace + +/* static */ Maybe CopyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); } -REGISTER_USER_OP("copy") - .Input("in") - .Output("out") - .Attr("device_type") - .Attr("device_id") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDeviceInferFn(GetDeviceInferFn()) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& inputs = ctx->inputs(); - CHECK_EQ_OR_RETURN(inputs.size(), 1); - const auto& input = - ctx->LogicalTensorDesc4InputArgNameAndIndex(inputs[0].first, inputs[0].second); - for (int64_t axis = 0; axis < input.shape().NumAxes(); ++axis) { - ctx->NewBuilder().Split(inputs, axis).Split(ctx->outputs(), axis).Build(); - } - ctx->NewBuilder().PartialSum(inputs).PartialSum(ctx->outputs()).Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe CopyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CopyOp::GetSbp(user_op::SbpContext* ctx) { + const auto& inputs = ctx->inputs(); + CHECK_EQ_OR_RETURN(inputs.size(), 1); + const auto& input = + ctx->LogicalTensorDesc4InputArgNameAndIndex(inputs[0].first, inputs[0].second); + for (int64_t axis = 0; axis < input.shape().NumAxes(); ++axis) { + 
ctx->NewBuilder().Split(inputs, axis).Split(ctx->outputs(), axis).Build(); + } + ctx->NewBuilder().PartialSum(inputs).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CopyOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> CopyOp::InferDevice(user_op::DeviceInferContext* ctx) { + Symbol out_device = + JUST(Device::New(ctx->Attr("device_type"), ctx->Attr("device_id"))); + *ctx->OutputTensorDevice4ArgNameAndIndex("out", 0) = out_device; + const Symbol& in_device = ctx->InputTensorDevice4ArgNameAndIndex("in", 0); + return MakeOpDevice(in_device, out_device); +} -} // namespace } // namespace oneflow diff --git a/oneflow/user/ops/count_not_finite_op.cpp b/oneflow/user/ops/count_not_finite_op.cpp index ba4ff545f3e..20e752a0b2c 100644 --- a/oneflow/user/ops/count_not_finite_op.cpp +++ b/oneflow/user/ops/count_not_finite_op.cpp @@ -14,62 +14,71 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("count_not_finite") - .Input("x") - .Output("y") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - *y_desc->mut_shape() = Shape({1}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - *y_desc->mut_data_type() = DataType::kInt64; - return Maybe::Ok(); - }); +/* static */ Maybe CountNotFiniteOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + *y_desc->mut_shape() = Shape({1}); + return Maybe::Ok(); +} -REGISTER_NO_GRAD_USER_OP("multi_count_not_finite") - .InputWithMinimum("x", 1) - .Output("y") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - *y_desc->mut_shape() = Shape({1}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - int64_t min_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes(); - for (int64_t i = 1; i < ctx->user_op_conf().input_size("x"); ++i) { - min_num_axes = std::min( - min_num_axes, ctx->LogicalTensorDesc4InputArgNameAndIndex("x", i).shape().NumAxes()); - } - for (int64_t i = 0; i < min_num_axes; ++i) { - ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(user_op::OpArg("y", 0)).Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& first_x_desc = ctx->InputTensorDesc("x", 0); - for (const auto& in_arg_pair : ctx->inputs()) { - const user_op::TensorDesc& x_desc = - ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second); - CHECK_EQ_OR_RETURN(x_desc.data_type(), first_x_desc.data_type()); - } - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - *y_desc->mut_data_type() = DataType::kInt64; - return Maybe::Ok(); - }); +/*static*/ Maybe 
CountNotFiniteOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CountNotFiniteOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).PartialSum(user_op::OpArg("y", 0)).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe CountNotFiniteOp::InferDataType(user_op::InferContext* ctx) { + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + *y_desc->mut_data_type() = DataType::kInt64; + return Maybe::Ok(); +} + +/* static */ Maybe MultiCountNotFiniteOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + *y_desc->mut_shape() = Shape({1}); + return Maybe::Ok(); +} + +/*static*/ Maybe MultiCountNotFiniteOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe MultiCountNotFiniteOp::GetSbp(user_op::SbpContext* ctx) { + int64_t min_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes(); + for (int64_t i = 1; i < ctx->user_op_conf().input_size("x"); ++i) { + min_num_axes = std::min(min_num_axes, + ctx->LogicalTensorDesc4InputArgNameAndIndex("x", i).shape().NumAxes()); + } + for (int64_t i = 0; i < min_num_axes; ++i) { + ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(user_op::OpArg("y", 0)).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe MultiCountNotFiniteOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& first_x_desc = ctx->InputTensorDesc("x", 0); + for (const auto& in_arg_pair : ctx->inputs()) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second); + CHECK_EQ_OR_RETURN(x_desc.data_type(), first_x_desc.data_type()); + } + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + *y_desc->mut_data_type() = DataType::kInt64; + return Maybe::Ok(); +} + +/*static*/ Maybe MultiCountNotFiniteOp::CheckAttr(const user_op::UserOpDefWrapper&, + const user_op::UserOpConfWrapper& op_conf) { + CHECK_OR_RETURN(op_conf.input_size("x") >= 1); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/ctc_loss_op.cpp b/oneflow/user/ops/ctc_loss_op.cpp index feaf6631e55..b8dee1ad9cc 100644 --- a/oneflow/user/ops/ctc_loss_op.cpp +++ b/oneflow/user/ops/ctc_loss_op.cpp @@ -14,105 +14,126 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("ctc_loss") - .Input("log_probs") - .Input("targets") - .Input("input_lengths") - .Input("target_lengths") - .Output("loss") - .Output("alpha") // 'alpha' is just for compute log_probs's grad, alpha's grad will be ignored - .Attr("max_target_length") - .Attr("blank") - .Attr("zero_infinity") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0); - const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); - const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0); - const user_op::TensorDesc& target_lengths = ctx->InputTensorDesc("target_lengths", 0); - const int64_t batch_size = log_probs.shape().At(1); - const int64_t max_target_length = ctx->Attr("max_target_length"); - if (targets.shape().NumAxes() == 2) { - CHECK_EQ_OR_RETURN(targets.shape().At(0), batch_size); - CHECK_GE_OR_RETURN(targets.shape().At(1), max_target_length); - } - CHECK_EQ_OR_RETURN(input_lengths.shape().At(0), batch_size); - CHECK_EQ_OR_RETURN(target_lengths.shape().At(0), batch_size); - CHECK_GE_OR_RETURN(ctx->Attr("blank"), 0); - CHECK_LT_OR_RETURN(ctx->Attr("blank"), log_probs.shape().At(2)); - - *ctx->OutputShape("loss", 0) = Shape({batch_size}); - *ctx->OutputShape("alpha", 0) = - Shape({batch_size, log_probs.shape().At(0), 2 * max_target_length + 1}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("log_probs", 0), 1) // `log_probs` batch axis is 1 - .Split(user_op::OpArg("targets", 0), 0) - .Split(user_op::OpArg("input_lengths", 0), 0) - .Split(user_op::OpArg("target_lengths", 0), 0) - .Split(user_op::OpArg("loss", 0), 0) - .Split(user_op::OpArg("alpha", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("loss", 0) = ctx->InputDType("log_probs", 0); - *ctx->OutputDType("alpha", 0) = ctx->InputDType("log_probs", 0); - return Maybe::Ok(); - }); +/* static */ Maybe CtcLossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0); + const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); + const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0); + const user_op::TensorDesc& target_lengths = ctx->InputTensorDesc("target_lengths", 0); + const int64_t batch_size = log_probs.shape().At(1); + const int64_t max_target_length = ctx->Attr("max_target_length"); + if (targets.shape().NumAxes() == 2) { + CHECK_EQ_OR_RETURN(targets.shape().At(0), batch_size); + CHECK_GE_OR_RETURN(targets.shape().At(1), max_target_length); + } + CHECK_EQ_OR_RETURN(input_lengths.shape().At(0), batch_size); + CHECK_EQ_OR_RETURN(target_lengths.shape().At(0), batch_size); + CHECK_GE_OR_RETURN(ctx->Attr("blank"), 0); + CHECK_LT_OR_RETURN(ctx->Attr("blank"), log_probs.shape().At(2)); -REGISTER_USER_OP("ctc_loss_grad") - .Input("grad_out") - .Input("log_probs") - .Input("targets") - .Input("input_lengths") - .Input("target_lengths") - .Input("loss") - .Input("alpha") - .Output("grad") - .Attr("max_target_length") - .Attr("blank") - .Attr("zero_infinity") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0); - const 
user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); - const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0); - const user_op::TensorDesc& target_lengths = ctx->InputTensorDesc("target_lengths", 0); - const int64_t batch_size = log_probs.shape().At(1); - const int64_t max_target_length = ctx->Attr("max_target_length"); - if (targets.shape().NumAxes() == 2) { - CHECK_EQ_OR_RETURN(targets.shape().At(0), batch_size); - CHECK_GE_OR_RETURN(targets.shape().At(1), max_target_length); - } - CHECK_EQ_OR_RETURN(input_lengths.shape().At(0), batch_size); - CHECK_EQ_OR_RETURN(target_lengths.shape().At(0), batch_size); - CHECK_GE_OR_RETURN(ctx->Attr("blank"), 0); - CHECK_LT_OR_RETURN(ctx->Attr("blank"), log_probs.shape().At(2)); - - *ctx->OutputShape("grad", 0) = log_probs.shape(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("grad_out", 0), 0) - .Split(user_op::OpArg("log_probs", 0), 1) // `log_probs` batch axis is 1 - .Split(user_op::OpArg("targets", 0), 0) - .Split(user_op::OpArg("input_lengths", 0), 0) - .Split(user_op::OpArg("target_lengths", 0), 0) - .Split(user_op::OpArg("loss", 0), 0) - .Split(user_op::OpArg("alpha", 0), 0) - .Split(user_op::OpArg("grad", 0), 1) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("grad", 0) = ctx->InputDType("log_probs", 0); - return Maybe::Ok(); - }); + *ctx->OutputShape("loss", 0) = Shape({batch_size}); + *ctx->OutputShape("alpha", 0) = + Shape({batch_size, log_probs.shape().At(0), 2 * max_target_length + 1}); + return Maybe::Ok(); +} + +/*static*/ Maybe CtcLossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CtcLossOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("log_probs", 0), 1) // `log_probs` batch axis is 1 + .Split(user_op::OpArg("targets", 0), 0) + .Split(user_op::OpArg("input_lengths", 0), 0) + .Split(user_op::OpArg("target_lengths", 0), 0) + .Split(user_op::OpArg("loss", 0), 0) + .Split(user_op::OpArg("alpha", 0), 0) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CtcLossOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("loss", 0) = ctx->InputDType("log_probs", 0); + *ctx->OutputDType("alpha", 0) = ctx->InputDType("log_probs", 0); + return Maybe::Ok(); +} + +/* static */ Maybe CtcLossGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0); + const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); + const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0); + const user_op::TensorDesc& target_lengths = ctx->InputTensorDesc("target_lengths", 0); + const int64_t batch_size = log_probs.shape().At(1); + const int64_t max_target_length = ctx->Attr("max_target_length"); + if (targets.shape().NumAxes() == 2) { + CHECK_EQ_OR_RETURN(targets.shape().At(0), batch_size); + CHECK_GE_OR_RETURN(targets.shape().At(1), max_target_length); + } + CHECK_EQ_OR_RETURN(input_lengths.shape().At(0), batch_size); + CHECK_EQ_OR_RETURN(target_lengths.shape().At(0), batch_size); + CHECK_GE_OR_RETURN(ctx->Attr("blank"), 0); + CHECK_LT_OR_RETURN(ctx->Attr("blank"), log_probs.shape().At(2)); + + *ctx->OutputShape("grad", 0) = log_probs.shape(); + return Maybe::Ok(); +} + +/*static*/ Maybe 
CtcLossGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CtcLossGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("grad_out", 0), 0) + .Split(user_op::OpArg("log_probs", 0), 1) // `log_probs` batch axis is 1 + .Split(user_op::OpArg("targets", 0), 0) + .Split(user_op::OpArg("input_lengths", 0), 0) + .Split(user_op::OpArg("target_lengths", 0), 0) + .Split(user_op::OpArg("loss", 0), 0) + .Split(user_op::OpArg("alpha", 0), 0) + .Split(user_op::OpArg("grad", 0), 1) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CtcLossGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("grad", 0) = ctx->InputDType("log_probs", 0); + return Maybe::Ok(); +} + +/* static */ Maybe CtcGreedyDecoderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0); + const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0); + const int64_t batch_size = log_probs.shape().At(1); + CHECK_EQ_OR_RETURN(batch_size, input_lengths.shape().At(0)); + *ctx->OutputShape("decoded", 0) = Shape({batch_size, log_probs.shape().At(0)}); + *ctx->OutputShape("neg_sum_logits", 0) = Shape({batch_size, 1}); + return Maybe::Ok(); +} + +/*static*/ Maybe CtcGreedyDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CtcGreedyDecoderOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("log_probs", 0), 1) // `log_probs` batch axis is 1 + .Split(user_op::OpArg("input_lengths", 0), 0) + .Split(user_op::OpArg("decoded", 0), 0) + .Split(user_op::OpArg("neg_sum_logits", 0), 0) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CtcGreedyDecoderOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("decoded", 0) = ctx->InputDType("input_lengths", 0); + *ctx->OutputDType("neg_sum_logits", 0) = ctx->InputDType("log_probs", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("ctc_loss") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -139,34 +160,4 @@ REGISTER_USER_OP_GRAD("ctc_loss") return Maybe::Ok(); }); -REGISTER_USER_OP("ctc_greedy_decoder") - .Input("log_probs") - .Input("input_lengths") - .Output("decoded") - .Output("neg_sum_logits") - .Attr("merge_repeated") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0); - const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0); - const int64_t batch_size = log_probs.shape().At(1); - CHECK_EQ_OR_RETURN(batch_size, input_lengths.shape().At(0)); - *ctx->OutputShape("decoded", 0) = Shape({batch_size, log_probs.shape().At(0)}); - *ctx->OutputShape("neg_sum_logits", 0) = Shape({batch_size, 1}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("log_probs", 0), 1) // `log_probs` batch axis is 1 - .Split(user_op::OpArg("input_lengths", 0), 0) - .Split(user_op::OpArg("decoded", 0), 0) - .Split(user_op::OpArg("neg_sum_logits", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("decoded", 0) = ctx->InputDType("input_lengths", 0); - *ctx->OutputDType("neg_sum_logits", 0) = ctx->InputDType("log_probs", 0); - return Maybe::Ok(); - }); - } // namespace oneflow 
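The hunks above and below all apply the same mechanical conversion: the inline REGISTER_USER_OP(...) builder chains with lambda callbacks are replaced by definitions of static hooks (InferLogicalTensorDesc, InferPhysicalTensorDesc, GetSbp, InferDataType, plus optional CheckAttr / ModifyInputArg / InferDevice / InferNdSbp) on op classes declared by the newly included op_generated.h. The following is a minimal sketch of that pattern for a hypothetical elementwise op; "my_identity" and MyIdentityOp are illustrative names only, and the generated class declaration (assumed to be produced from the ODS .td op definitions behind op_generated.h) is not shown in this patch.

// Sketch only: a hypothetical "my_identity" op converted from the old builder style
// to the generated-class style used throughout this patch.

// Old style: behavior supplied as lambdas on a hand-written registration builder.
REGISTER_USER_OP("my_identity")
    .Input("in")
    .Output("out")
    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
      return Maybe<void>::Ok();
    })
    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
      return Maybe<void>::Ok();
    })
    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
      ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
      return Maybe<void>::Ok();
    });

// New style: the same behavior written as definitions of the static hooks declared on the
// (assumed) generated MyIdentityOp class; registration itself is generated, not hand-written.
/* static */ Maybe<void> MyIdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
  return Maybe<void>::Ok();
}
/* static */ Maybe<void> MyIdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);  // physical inference simply reuses the logical rule
}
/* static */ Maybe<void> MyIdentityOp::GetSbp(user_op::SbpContext* ctx) {
  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
  return Maybe<void>::Ok();
}
/* static */ Maybe<void> MyIdentityOp::InferDataType(user_op::InferContext* ctx) {
  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
  return Maybe<void>::Ok();
}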
diff --git a/oneflow/user/ops/cumsum_op.cpp b/oneflow/user/ops/cumsum_op.cpp
new file mode 100644
index 00000000000..beff8b29a5b
--- /dev/null
+++ b/oneflow/user/ops/cumsum_op.cpp
@@ -0,0 +1,87 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/framework.h"
+
+namespace oneflow {
+
+namespace {
+
+REGISTER_USER_OP("cumsum")
+    .Input("in")
+    .Output("out")
+    .Attr<int64_t>("dim")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      const auto& in_tensor_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+      auto dim = ctx->Attr<int64_t>("dim");
+      for (auto i = dim + 1; i < in_tensor_desc.shape().NumAxes(); i++) {
+        ctx->NewBuilder()
+            .Split(user_op::OpArg("in", 0), i)
+            .Split(user_op::OpArg("out", 0), i)
+            .Build();
+      }
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP("cumsum_grad")
+    .Input("dy")
+    .Output("dx")
+    .Attr<int64_t>("dim")
+    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputShape("dx", 0) = ctx->InputShape("dy", 0);
+      return Maybe<void>::Ok();
+    })
+    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
+      const auto& dy_tensor_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0);
+      for (auto i = 0; i < dy_tensor_desc.shape().NumAxes(); i++) {
+        ctx->NewBuilder()
+            .Split(user_op::OpArg("dy", 0), i)
+            .Split(user_op::OpArg("dx", 0), i)
+            .Build();
+      }
+      return Maybe<void>::Ok();
+    })
+    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
+      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+      return Maybe<void>::Ok();
+    });
+
+REGISTER_USER_OP_GRAD("cumsum").SetGenBackwardOpConfFn(
+    [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe<void> {
+      if (op.NeedGenGradTensor4OpInput("in", 0)) {
+        user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
+        user_op::UserOpConfWrapper grad_op =
+            builder.Op("cumsum_grad")
+                .Input("dy", op.GetGradTensorWithOpOutput("out", 0))
+                .Output("dx")
+                .Attr<int64_t>("dim", op.attr<int64_t>("dim"))
+                .Build();
+        op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "in", 0);
+        AddOp(grad_op);
+      }
+      return Maybe<void>::Ok();
+    });
+
+}  // namespace
+
+}  // namespace oneflow
diff --git a/oneflow/user/ops/deconv_op.cpp b/oneflow/user/ops/deconv_op.cpp
index cd1a82d8ba0..657098c6e7f 100644
--- a/oneflow/user/ops/deconv_op.cpp
+++ b/oneflow/user/ops/deconv_op.cpp
@@ -15,6 +15,7 @@ limitations under the License.
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -23,6 +24,7 @@ namespace { template Maybe InferTensorDesc4DeConv(user_op::InferContext* ctx) { const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(NDims + 2, in.shape().NumAxes()); const std::string& data_format = ctx->Attr("data_format"); @@ -30,7 +32,7 @@ Maybe InferTensorDesc4DeConv(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(NDims, kernel_size.size()); const int32_t filters = ctx->Attr("filters"); size_t idx_offset = IdxOffset(data_format); - + int32_t groups = ctx->Attr("groups"); { const auto& dilation_rate = ctx->Attr>("dilation_rate"); const auto& output_padding = ctx->Attr>("output_padding"); @@ -65,15 +67,14 @@ Maybe InferTensorDesc4DeConv(user_op::InferContext* ctx) { DimVector weight_shape(in.shape().dim_vec()); if (data_format == "channels_first") { weight_shape.at(0) = in.shape().At(1); - weight_shape.at(1) = filters; + weight_shape.at(1) = filters / groups; } else if (data_format == "channels_last") { weight_shape.at(0) = in.shape().At(NDims + 1); - weight_shape.at(NDims + 1) = filters; + weight_shape.at(NDims + 1) = filters / groups; } else { UNIMPLEMENTED_THEN_RETURN(); } for (size_t i = 0; i < NDims; ++i) { weight_shape.at(idx_offset + i) = kernel_size.at(i); } - const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); CHECK_EQ_OR_RETURN(weight.shape(), Shape(weight_shape)); } @@ -81,7 +82,7 @@ Maybe InferTensorDesc4DeConv(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { +Maybe InferDataType_(user_op::InferContext* ctx) { *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } @@ -97,8 +98,8 @@ Maybe GetSbpSignatures4DeConv(user_op::SbpContext* ctx) { } template -Maybe CheckAttr(const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) { +Maybe CheckAttr_(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { bool is_checked = true; std::stringstream err; err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": "; @@ -150,6 +151,7 @@ Maybe GenerateBackwardOpConf4DeConv(const user_op::UserOpWrapper& op, const auto& strides = op.attr>("strides"); const auto& dilation_rate = op.attr>("dilation_rate"); const Shape& weight_shape = op.TensorDesc4ArgNameAndIndex("weight", 0).shape(); + int32_t groups = op.attr("groups"); const int32_t ndims = kernel_size.size(); CHECK_EQ_OR_RETURN(ndims, strides.size()); @@ -168,7 +170,7 @@ Maybe GenerateBackwardOpConf4DeConv(const user_op::UserOpWrapper& op, .Attr>("kernel_size", kernel_size) .Attr>("strides", strides) .Attr>("dilation_rate", dilation_rate) - .Attr("groups", 1) + .Attr("groups", groups) .Build(); op.BindGradTensorWithOpInput(filter_grad_op.output("filter_diff", 0), "weight", 0); AddOp(filter_grad_op); @@ -188,7 +190,7 @@ Maybe GenerateBackwardOpConf4DeConv(const user_op::UserOpWrapper& op, .Attr>("kernel_size", kernel_size) .Attr>("strides", strides) .Attr>("dilation_rate", dilation_rate) - .Attr("groups", 1) + .Attr("groups", groups) .Build(); op.BindGradTensorWithOpInput(data_grad_op.output("out", 0), "in", 0); AddOp(data_grad_op); @@ -198,56 +200,68 @@ Maybe GenerateBackwardOpConf4DeConv(const user_op::UserOpWrapper& op, } // namespace -REGISTER_USER_OP("deconv1d") - .Input("in") - .Input("weight") - .Output("out") - .Attr("filters") - .Attr>("padding_before") - 
.Attr("data_format") - .Attr>("kernel_size") - .Attr>("output_padding") - .Attr>("strides") - .Attr>("dilation_rate") - .Attr("groups", 1) - .SetCheckAttrFn(CheckAttr<1>) - .SetTensorDescInferFn(InferTensorDesc4DeConv<1>) - .SetGetSbpFn(GetSbpSignatures4DeConv) - .SetDataTypeInferFn(InferDataType); - -REGISTER_USER_OP("deconv2d") - .Input("in") - .Input("weight") - .Output("out") - .Attr("filters") - .Attr>("padding_before") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("output_padding") - .Attr>("strides") - .Attr>("dilation_rate") - .Attr("groups", 1) - .SetCheckAttrFn(CheckAttr<2>) - .SetTensorDescInferFn(InferTensorDesc4DeConv<2>) - .SetGetSbpFn(GetSbpSignatures4DeConv) - .SetDataTypeInferFn(InferDataType); - -REGISTER_USER_OP("deconv3d") - .Input("in") - .Input("weight") - .Output("out") - .Attr("filters") - .Attr>("padding_before") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("output_padding") - .Attr>("strides") - .Attr>("dilation_rate") - .Attr("groups", 1) - .SetCheckAttrFn(CheckAttr<3>) - .SetTensorDescInferFn(InferTensorDesc4DeConv<3>) - .SetDataTypeInferFn(InferDataType) - .SetGetSbpFn(GetSbpSignatures4DeConv); +/* static */ Maybe Deconv1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc4DeConv<1>(ctx); +} + +/*static*/ Maybe Deconv1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe Deconv1DOp::GetSbp(user_op::SbpContext* ctx) { + return GetSbpSignatures4DeConv(ctx); +} + +/* static */ Maybe Deconv1DOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_<1>(def, conf); +} + +/* static */ Maybe Deconv1DOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType_(ctx); +} + +/* static */ Maybe Deconv2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc4DeConv<2>(ctx); +} + +/*static*/ Maybe Deconv2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe Deconv2DOp::GetSbp(user_op::SbpContext* ctx) { + return GetSbpSignatures4DeConv(ctx); +} + +/* static */ Maybe Deconv2DOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_<2>(def, conf); +} + +/* static */ Maybe Deconv2DOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType_(ctx); +} + +/* static */ Maybe Deconv3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc4DeConv<3>(ctx); +} + +/*static*/ Maybe Deconv3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe Deconv3DOp::GetSbp(user_op::SbpContext* ctx) { + return GetSbpSignatures4DeConv(ctx); +} + +/* static */ Maybe Deconv3DOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_<3>(def, conf); +} + +/* static */ Maybe Deconv3DOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType_(ctx); +} REGISTER_USER_OP_GRAD("deconv1d").SetGenBackwardOpConfFn(GenerateBackwardOpConf4DeConv); REGISTER_USER_OP_GRAD("deconv2d").SetGenBackwardOpConfFn(GenerateBackwardOpConf4DeConv); diff --git a/oneflow/user/ops/diag_op.cpp b/oneflow/user/ops/diag_op.cpp index 616fe3eadf2..3ea7d0ffd1d 100644 --- a/oneflow/user/ops/diag_op.cpp +++ b/oneflow/user/ops/diag_op.cpp @@ -14,69 +14,73 @@ See the License for the specific language governing permissions and limitations 
under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("diag") - .Input("in") - .Output("out") - .Attr("diagonal", 0) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - const int32_t diagonal = ctx->Attr("diagonal"); - const ShapeView& in_shape = in.shape(); - const int32_t in_dim = in_shape.NumAxes(); - CHECK_GE_OR_RETURN(in_dim, 1); - CHECK_LE_OR_RETURN(in_dim, 2); - - DimVector out_dim_vec = {0}; - if (in_dim == 1) { - int32_t out_tensor_size = in_shape.At(0) + std::abs(diagonal); - out_dim_vec[0] = out_tensor_size; - out_dim_vec.emplace_back(out_tensor_size); - } else { - if (diagonal >= 0) { - out_dim_vec[0] = std::min(in_shape.At(0), in_shape.At(1) - diagonal); - } else { - out_dim_vec[0] = std::min(in_shape.At(0) + diagonal, in_shape.At(1)); - } - CHECK_GT_OR_RETURN(out_dim_vec[0], 0); - } - - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - out_desc->set_is_dynamic(false); - *out_desc->mut_shape() = Shape(out_dim_vec); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_USER_OP("diag_grad") - .Input("dy") - .Input("in") - .Attr("diagonal", 0) - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - const Shape& in_shape = in.shape(); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - *dx_desc->mut_shape() = Shape(in_shape.dim_vec()); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/* static */ Maybe DiagOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + const int32_t diagonal = ctx->Attr("diagonal"); + const ShapeView& in_shape = in.shape(); + const int32_t in_dim = in_shape.NumAxes(); + CHECK_GE_OR_RETURN(in_dim, 1); + CHECK_LE_OR_RETURN(in_dim, 2); + + DimVector out_dim_vec = {0}; + if (in_dim == 1) { + int32_t out_tensor_size = in_shape.At(0) + std::abs(diagonal); + out_dim_vec[0] = out_tensor_size; + out_dim_vec.emplace_back(out_tensor_size); + } else { + if (diagonal >= 0) { + out_dim_vec[0] = std::min(in_shape.At(0), in_shape.At(1) - diagonal); + } else { + out_dim_vec[0] = std::min(in_shape.At(0) + diagonal, in_shape.At(1)); + } + CHECK_GT_OR_RETURN(out_dim_vec[0], 0); + } + + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + out_desc->set_is_dynamic(false); + *out_desc->mut_shape() = Shape(out_dim_vec); + return Maybe::Ok(); +} + +/*static*/ Maybe DiagOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe DiagOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe DiagOp::InferDataType(user_op::InferContext* ctx) { + 
*ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe DiagGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + const Shape& in_shape = in.shape(); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + *dx_desc->mut_shape() = Shape(in_shape.dim_vec()); + return Maybe::Ok(); +} + +/*static*/ Maybe DiagGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe DiagGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe DiagGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("diag").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/diagonal_op.cpp b/oneflow/user/ops/diagonal_op.cpp new file mode 100644 index 00000000000..c7bed93b172 --- /dev/null +++ b/oneflow/user/ops/diagonal_op.cpp @@ -0,0 +1,100 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" + +namespace oneflow { + +/* static */ Maybe DiagonalOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + const int32_t offset = ctx->Attr("offset"); + const ShapeView& in_shape = in.shape(); + const int32_t in_dim = in_shape.NumAxes(); + CHECK_GE_OR_RETURN(in_dim, 2); + + DimVector out_dim_vec = {}; + FOR_RANGE(int32_t, index, 2, in_dim) { out_dim_vec.push_back(in_shape.At(index)); } + int32_t last_dim = 0; + if (offset >= 0) { + last_dim = std::min(in_shape.At(0), in_shape.At(1) - offset); + } else { + last_dim = std::min(in_shape.At(0) + offset, in_shape.At(1)); + } + if (last_dim < 0) { last_dim = 0; } + out_dim_vec.push_back(last_dim); + + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + out_desc->set_is_dynamic(false); + *out_desc->mut_shape() = Shape(out_dim_vec); + return Maybe::Ok(); +} + +/*static*/ Maybe DiagonalOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe DiagonalOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe DiagonalOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe DiagonalGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + const Shape& in_shape = in.shape(); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + *dx_desc->mut_shape() = Shape(in_shape.dim_vec()); + return Maybe::Ok(); +} + +/*static*/ Maybe DiagonalGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe DiagonalGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe DiagonalGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} + +REGISTER_USER_OP_GRAD("diagonal") + .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { + const auto grad_op_name = ctx->FwOp().op_name() + "_grad"; + ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("diagonal_grad") + .InputBind("in", ctx->FwOp().input("in", 0)) + .InputBind("dy", ctx->FwOp().output_grad("out", 0)) + .Attr("offset", ctx->FwOp().attr("offset")) + .Output("dx") + .Build(); + }); + + ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), + [&ctx, &grad_op_name]() -> const std::string& { + return ctx->GetOp(grad_op_name).output("dx", 0); + }); + return Maybe::Ok(); + }); + +} // namespace oneflow diff --git a/oneflow/user/ops/dim_gather_op.cpp b/oneflow/user/ops/dim_gather_op.cpp index 9c490a97985..fa7bd0815a0 100644 --- a/oneflow/user/ops/dim_gather_op.cpp +++ b/oneflow/user/ops/dim_gather_op.cpp @@ -15,79 +15,80 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/dim_gather_kernel_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace user_op { -REGISTER_USER_OP("dim_gather") - .Input("input") - .Input("index") - .Output("output") - .Attr("dim") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const TensorDesc& in = ctx->InputTensorDesc("input", 0); - int64_t input_num_axes = in.shape().NumAxes(); - CHECK_GT_OR_RETURN(input_num_axes, 0); - CHECK_LE_OR_RETURN(input_num_axes, kDimGatherMaxDimCount); +/* static */ Maybe DimGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("input", 0); + int64_t input_num_axes = in.shape().NumAxes(); + CHECK_GT_OR_RETURN(input_num_axes, 0); + CHECK_LE_OR_RETURN(input_num_axes, kDimGatherMaxDimCount); - const TensorDesc& index = ctx->InputTensorDesc("index", 0); - int64_t index_num_axes = index.shape().NumAxes(); + const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0); + int64_t index_num_axes = index.shape().NumAxes(); - const int32_t dim = ctx->Attr("dim"); - CHECK_GE_OR_RETURN(dim, 0); - CHECK_LT_OR_RETURN(dim, input_num_axes); - CHECK_EQ_OR_RETURN(input_num_axes, index_num_axes); + const int32_t dim = ctx->Attr("dim"); + CHECK_GE_OR_RETURN(dim, 0); + CHECK_LT_OR_RETURN(dim, input_num_axes); + CHECK_EQ_OR_RETURN(input_num_axes, index_num_axes); - CHECK_EQ_OR_RETURN(in.is_dynamic(), index.is_dynamic()); + CHECK_EQ_OR_RETURN(in.is_dynamic(), index.is_dynamic()); - user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0); - *out->mut_shape() = index.shape(); + user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0); + *out->mut_shape() = index.shape(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const TensorDesc& index = ctx->InputTensorDesc("index", 0); - CHECK_OR_RETURN(IsIndexDataType(index.data_type())); - const TensorDesc& in = ctx->InputTensorDesc("input", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0); - *out->mut_data_type() = in.data_type(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0); - CHECK_OR_RETURN(indices_modifier != nullptr); - indices_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& index_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("index", 0); - int64_t index_num_axes = index_tensor.shape().NumAxes(); - const int32_t dim = ctx->Attr("dim"); - - FOR_RANGE(int64_t, i, 0, index_num_axes) { - if (i != dim) { - ctx->NewBuilder() - .Split(user_op::OpArg("index", 0), i) - .Split(user_op::OpArg("input", 0), i) - .Split(user_op::OpArg("output", 0), i) - .Build(); - } else if (i == dim) { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("input", 0)) - .Split(user_op::OpArg("index", 0), i) - .Split(user_op::OpArg("output", 0), i) - .Build(); - } - } + return Maybe::Ok(); +} + +/*static*/ Maybe DimGatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe DimGatherOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& index_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("index", 0); + int64_t index_num_axes = index_tensor.shape().NumAxes(); + 
const int32_t dim = ctx->Attr("dim"); + + FOR_RANGE(int64_t, i, 0, index_num_axes) { + if (i != dim) { ctx->NewBuilder() - .PartialSum(user_op::OpArg("input", 0)) - .Broadcast(user_op::OpArg("index", 0)) - .PartialSum(user_op::OpArg("output", 0)) + .Split(user_op::OpArg("index", 0), i) + .Split(user_op::OpArg("input", 0), i) + .Split(user_op::OpArg("output", 0), i) .Build(); - return Maybe::Ok(); - }); + } else if (i == dim) { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("input", 0)) + .Split(user_op::OpArg("index", 0), i) + .Split(user_op::OpArg("output", 0), i) + .Build(); + } + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("input", 0)) + .Broadcast(user_op::OpArg("index", 0)) + .PartialSum(user_op::OpArg("output", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe DimGatherOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0); + CHECK_OR_RETURN(indices_modifier != nullptr); + indices_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe DimGatherOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0); + CHECK_OR_RETURN(IsIndexDataType(index.data_type())); + const user_op::TensorDesc& in = ctx->InputTensorDesc("input", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0); + *out->mut_data_type() = in.data_type(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("dim_gather") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -113,6 +114,4 @@ REGISTER_USER_OP_GRAD("dim_gather") return Maybe::Ok(); }); -} // namespace user_op - } // namespace oneflow diff --git a/oneflow/user/ops/dim_scatter_ops.cpp b/oneflow/user/ops/dim_scatter_ops.cpp index 38b1a599a3b..c6f84a91faf 100644 --- a/oneflow/user/ops/dim_scatter_ops.cpp +++ b/oneflow/user/ops/dim_scatter_ops.cpp @@ -17,25 +17,26 @@ limitations under the License. #include "oneflow/core/common/maybe.h" #include "oneflow/core/framework/user_op_registry.h" #include "oneflow/user/kernels/dim_scatter_kernel_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace user_op { - namespace { Maybe InferTensorDesc(user_op::InferContext* ctx) { - const TensorDesc* input = ctx->TensorDesc4ArgNameAndIndex("input", 0); - const TensorDesc* index = ctx->TensorDesc4ArgNameAndIndex("index", 0); - const TensorDesc* like = ctx->TensorDesc4ArgNameAndIndex("like", 0); - const TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("src", 0); + const user_op::TensorDesc* input = + ctx->has_input("input", 0) ? &ctx->InputTensorDesc("input", 0) : nullptr; + const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0); + const user_op::TensorDesc* like = + ctx->has_input("like", 0) ? 
&ctx->InputTensorDesc("like", 0) : nullptr; + const user_op::TensorDesc& src = ctx->InputTensorDesc("src", 0); int32_t dim = ctx->Attr("dim"); // check index.numaxes == src.num_axes == input/like.numaxes - int64_t src_num_axes = src->shape().NumAxes(); + int64_t src_num_axes = src.shape().NumAxes(); CHECK_GT_OR_RETURN(src_num_axes, 0); - CHECK_LE_OR_RETURN(src_num_axes, kDimGatherMaxDimCount); - int64_t index_num_axes = index->shape().NumAxes(); + CHECK_LE_OR_RETURN(src_num_axes, user_op::kDimGatherMaxDimCount); + int64_t index_num_axes = index.shape().NumAxes(); CHECK_EQ_OR_RETURN(src_num_axes, index_num_axes); int64_t output_num_axes = 0; @@ -52,46 +53,46 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { FOR_RANGE(int64_t, i, 0, index_num_axes) { if (i == dim) continue; if (input) { - CHECK_LE_OR_RETURN(index->shape().At(i), input->shape().At(i)); + CHECK_LE_OR_RETURN(index.shape().At(i), input->shape().At(i)); } else { - CHECK_LE_OR_RETURN(index->shape().At(i), like->shape().At(i)); + CHECK_LE_OR_RETURN(index.shape().At(i), like->shape().At(i)); } } // check index.shape(i) <= src.shape(i) FOR_RANGE(int64_t, i, 0, index_num_axes) { if (i == dim) continue; - CHECK_LE_OR_RETURN(index->shape().At(i), src->shape().At(i)); + CHECK_LE_OR_RETURN(index.shape().At(i), src.shape().At(i)); } - user_op::TensorDesc* out = ctx->TensorDesc4ArgNameAndIndex("output", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0); *out->mut_shape() = input ? input->shape() : like->shape(); return Maybe::Ok(); } Maybe InferScalarTensorDesc(user_op::InferContext* ctx) { - const TensorDesc* input = ctx->TensorDesc4ArgNameAndIndex("input", 0); - const TensorDesc* index = ctx->TensorDesc4ArgNameAndIndex("index", 0); + const user_op::TensorDesc& input = ctx->InputTensorDesc("input", 0); + const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0); int32_t dim = ctx->Attr("dim"); // check index.numaxes == src.num_axes == input/like.numaxes - int64_t output_num_axes = input->shape().NumAxes(); - int64_t index_num_axes = index->shape().NumAxes(); + int64_t output_num_axes = input.shape().NumAxes(); + int64_t index_num_axes = index.shape().NumAxes(); CHECK_EQ_OR_RETURN(output_num_axes, index_num_axes); // check index.shape(i) <= input/like.shape(i) FOR_RANGE(int64_t, i, 0, index_num_axes) { if (i == dim) continue; - CHECK_LE_OR_RETURN(index->shape().At(i), input->shape().At(i)); + CHECK_LE_OR_RETURN(index.shape().At(i), input.shape().At(i)); } - TensorDesc* out = ctx->TensorDesc4ArgNameAndIndex("output", 0); - *out->mut_shape() = input->shape(); + user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0); + *out->mut_shape() = input.shape(); return Maybe::Ok(); } -Maybe InputArgModifierFn(user_op::GetInputArgModifier GetInputArgModifierFn, +Maybe InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0); CHECK(indices_modifier != nullptr); @@ -100,7 +101,7 @@ Maybe InputArgModifierFn(user_op::GetInputArgModifier GetInputArgModifierF return Maybe::Ok(); } -Maybe InputScalarArgModifierFn(user_op::GetInputArgModifier GetInputArgModifierFn, +Maybe InputScalarArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0); CHECK(indices_modifier != nullptr); @@ -158,10 +159,10 @@ Maybe SetSbpScatter(user_op::SbpContext* ctx) { } Maybe 
InferDtype(user_op::InferContext* ctx) { - const TensorDesc* index = ctx->TensorDesc4ArgNameAndIndex("index", 0); - CHECK_OR_RETURN(IsIndexDataType(index->data_type())); - const TensorDesc* input = ctx->TensorDesc4ArgNameAndIndex("input", 0); - if (input) { + const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0); + CHECK_OR_RETURN(IsIndexDataType(index.data_type())); + if (ctx->has_input("input", 0)) { + const user_op::TensorDesc& input = ctx->InputTensorDesc("input", 0); CHECK_EQ_OR_RETURN(ctx->InputDType("input", 0), ctx->InputDType("src", 0)); } else { CHECK_EQ_OR_RETURN(ctx->InputDType("like", 0), ctx->InputDType("src", 0)); @@ -171,15 +172,15 @@ Maybe InferDtype(user_op::InferContext* ctx) { } Maybe InferScalarDtype(user_op::InferContext* ctx) { - const TensorDesc* index = ctx->TensorDesc4ArgNameAndIndex("index", 0); - CHECK_OR_RETURN(IsIndexDataType(index->data_type())); + const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0); + CHECK_OR_RETURN(IsIndexDataType(index.data_type())); *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); return Maybe::Ok(); } Maybe ScatterBackward(user_op::BackwardOpConfContext* ctx) { - const TensorDesc& src = ctx->FwOp().TensorDesc4ArgNameAndIndex("src", 0); - const TensorDesc& index = ctx->FwOp().TensorDesc4ArgNameAndIndex("index", 0); + const user_op::TensorDesc& src = ctx->FwOp().TensorDesc4ArgNameAndIndex("src", 0); + const user_op::TensorDesc& index = ctx->FwOp().TensorDesc4ArgNameAndIndex("index", 0); const int64_t ndim = src.shape().NumAxes(); FOR_RANGE(int64_t, i, 0, ndim) { @@ -220,41 +221,70 @@ Maybe ScatterBackward(user_op::BackwardOpConfContext* ctx) { } // namespace -#define REGISTER_SCATTER_LIKE_OP(optypename) \ - REGISTER_USER_OP(optypename) \ - .Input("like") \ - .Input("index") \ - .Input("src") \ - .Output("output") \ - .Attr("dim") \ - .SetTensorDescInferFn(InferTensorDesc) \ - .SetInputArgModifyFn(InputArgModifierFn) \ - .SetDataTypeInferFn(InferDtype) \ - .SetGetSbpFn(SetSbpLike) - -#define REGISTER_SCATTER_OP(optypename) \ - REGISTER_USER_OP(optypename) \ - .Input("input") \ - .Input("index") \ - .Input("src") \ - .Output("output") \ - .Attr("dim") \ - .SetTensorDescInferFn(InferTensorDesc) \ - .SetInputArgModifyFn(InputArgModifierFn) \ - .SetDataTypeInferFn(InferDtype) \ - .SetGetSbpFn(SetSbpScatter) - -#define REGISTER_SCATTER_SCALAR_OP(optypename) \ - REGISTER_USER_OP(optypename) \ - .Input("input") \ - .Input("index") \ - .Attr("src_scalar") \ - .Output("output") \ - .Attr("dim") \ - .SetTensorDescInferFn(InferScalarTensorDesc) \ - .SetInputArgModifyFn(InputScalarArgModifierFn) \ - .SetDataTypeInferFn(InferScalarDtype) \ - .SetGetSbpFn(SetSbpScatter) +/* static */ Maybe DimScatterAddLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc(ctx); +} + +/*static*/ Maybe DimScatterAddLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe DimScatterAddLikeOp::GetSbp(user_op::SbpContext* ctx) { + return SetSbpLike(ctx); +} + +/* static */ Maybe DimScatterAddLikeOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return InputArgModifierFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe DimScatterAddLikeOp::InferDataType(user_op::InferContext* ctx) { + return InferDtype(ctx); +} + +#define DEF_SCATTER_OP(op_class_name) \ + /* static */ Maybe op_class_name::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return 
InferTensorDesc(ctx); \ + } \ + \ + /*static*/ Maybe op_class_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + \ + /* static */ Maybe op_class_name::GetSbp(user_op::SbpContext* ctx) { \ + return SetSbpScatter(ctx); \ + } \ + \ + /* static */ Maybe op_class_name::ModifyInputArg( \ + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { \ + return InputArgModifierFn(GetInputArgModifierFn, conf); \ + } \ + \ + /* static */ Maybe op_class_name::InferDataType(user_op::InferContext* ctx) { \ + return InferDtype(ctx); \ + } + +#define DEF_SCATTER_SCALAR_OP(optypename) \ + /* static */ Maybe optypename::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return InferScalarTensorDesc(ctx); \ + } \ + \ + /*static*/ Maybe optypename::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + \ + /* static */ Maybe optypename::GetSbp(user_op::SbpContext* ctx) { \ + return SetSbpScatter(ctx); \ + } \ + \ + /* static */ Maybe optypename::ModifyInputArg( \ + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { \ + return InputScalarArgModifierFn(GetInputArgModifierFn, conf); \ + } \ + \ + /* static */ Maybe optypename::InferDataType(user_op::InferContext* ctx) { \ + return InferScalarDtype(ctx); \ + } #define REGISTER_SCATTER_GRAD(optypename) \ REGISTER_USER_OP_GRAD(optypename).SetBackwardOpConfGenFn(ScatterBackward); @@ -278,19 +308,17 @@ Maybe ScatterBackward(user_op::BackwardOpConfContext* ctx) { }); \ return Maybe::Ok(); \ }); +DEF_SCATTER_OP(DimScatterAddOp); +DEF_SCATTER_OP(DimScatterUpdateOp); +DEF_SCATTER_OP(DimScatterMulOp); -REGISTER_SCATTER_LIKE_OP("dim_scatter_add_like"); -REGISTER_SCATTER_OP("dim_scatter_add"); -REGISTER_SCATTER_OP("dim_scatter_update"); -REGISTER_SCATTER_OP("dim_scatter_mul"); - -REGISTER_SCATTER_SCALAR_OP("dim_scatter_update_scalar"); -REGISTER_SCATTER_SCALAR_OP("dim_scatter_add_scalar"); -REGISTER_SCATTER_SCALAR_OP("dim_scatter_mul_scalar"); +DEF_SCATTER_SCALAR_OP(DimScatterUpdateScalarOp); +DEF_SCATTER_SCALAR_OP(DimScatterAddScalarOp); +DEF_SCATTER_SCALAR_OP(DimScatterMulScalarOp); REGISTER_SCATTER_GRAD("dim_scatter_add"); REGISTER_SCATTER_GRAD("dim_scatter_update"); REGISTER_SCATTER_SCALAR_GRAD("dim_scatter_update_scalar"); -} // namespace user_op + } // namespace oneflow diff --git a/oneflow/user/ops/distributions/normal_op.cpp b/oneflow/user/ops/distributions/normal_op.cpp index 1fe0b07e05f..29adc2e0fbb 100644 --- a/oneflow/user/ops/distributions/normal_op.cpp +++ b/oneflow/user/ops/distributions/normal_op.cpp @@ -15,37 +15,36 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("normal") - .Output("out") - .SetOutputBufferNum(1) - .Attr("mean", 0) - .Attr("std", 1) - .Attr("seed") - .Attr("dtype") - .Attr("shape") - .Attr("nd_sbp") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - const Shape& shape = ctx->Attr("shape"); - *out_shape = shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - auto dtype = ctx->Attr("dtype"); - *ctx->OutputDType("out", 0) = dtype; - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_broadcast_parallel(); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }); +/* static */ Maybe NormalOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + const Shape& shape = ctx->Attr("shape"); + *out_shape = shape; + return Maybe::Ok(); +} + +/*static*/ Maybe NormalOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NormalOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe NormalOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::SbpParallel default_sbp; + default_sbp.mutable_broadcast_parallel(); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} + +/* static */ Maybe NormalOp::InferDataType(user_op::InferContext* ctx) { + auto dtype = ctx->Attr("dtype"); + *ctx->OutputDType("out", 0) = dtype; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/distributions/uniform_int_op.cpp b/oneflow/user/ops/distributions/uniform_int_op.cpp index 9f080e76634..99a0bb94d9a 100644 --- a/oneflow/user/ops/distributions/uniform_int_op.cpp +++ b/oneflow/user/ops/distributions/uniform_int_op.cpp @@ -14,41 +14,55 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" - +#include "oneflow/core/framework/op_generated.h" +#include "oneflow/core/common/balanced_splitter.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("uniform_int") - .Output("out") - .SetOutputBufferNum(1) - .Attr("from", 0) - .Attr("to", 1) - .Attr("seed") - .Attr("dtype") - .Attr("shape") - .Attr("nd_sbp") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - const Shape& shape = ctx->Attr("shape"); - DimVector dim_vec; - if (shape.NumAxes() > 0) { - dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend()); - } - *out_shape = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - auto dtype = ctx->Attr("dtype"); - *ctx->OutputDType("out", 0) = dtype; - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_broadcast_parallel(); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }); +/* static */ Maybe UniformIntOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + const Shape& shape = ctx->Attr("shape"); + DimVector dim_vec; + if (shape.NumAxes() > 0) { + dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend()); + } + *out_shape = Shape(dim_vec); + return Maybe::Ok(); +} + +/*static*/ Maybe UniformIntOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + const Shape& shape = ctx->Attr("shape"); + DimVector dim_vec{shape.dim_vec()}; + + const cfg::SbpParallel& out_sbp_para = ctx->SbpParallel4ArgNameAndIndex("out", 0); + if (out_sbp_para.has_split_parallel()) { + const int64_t& parallel_num = ctx->parallel_ctx().parallel_num(); + if (parallel_num > 1) { + const int64_t& split_axis = out_sbp_para.split_parallel().axis(); + CHECK_LT_OR_RETURN(split_axis, dim_vec.size()); + BalancedSplitter bs(shape.At(split_axis), parallel_num); + dim_vec[split_axis] = bs.At(ctx->parallel_ctx().parallel_id()).size(); + } + } + + *ctx->OutputShape("out", 0) = Shape(dim_vec); + return Maybe::Ok(); +} + +/* static */ Maybe UniformIntOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe UniformIntOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::SbpParallel default_sbp; + default_sbp.mutable_broadcast_parallel(); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} + +/* static */ Maybe UniformIntOp::InferDataType(user_op::InferContext* ctx) { + auto dtype = ctx->Attr("dtype"); + *ctx->OutputDType("out", 0) = dtype; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/distributions/uniform_op.cpp b/oneflow/user/ops/distributions/uniform_op.cpp index 9c45dd7d244..0e972755055 100644 --- a/oneflow/user/ops/distributions/uniform_op.cpp +++ b/oneflow/user/ops/distributions/uniform_op.cpp @@ -14,41 +14,40 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("uniform") - .Output("out") - .SetOutputBufferNum(1) - .Attr("from", 0) - .Attr("to", 1) - .Attr("seed") - .Attr("dtype") - .Attr("shape") - .Attr("nd_sbp") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - const Shape& shape = ctx->Attr("shape"); - DimVector dim_vec; - if (shape.NumAxes() > 0) { - dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend()); - } - *out_shape = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - auto dtype = ctx->Attr("dtype"); - *ctx->OutputDType("out", 0) = dtype; - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_broadcast_parallel(); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }); +/* static */ Maybe UniformOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + const Shape& shape = ctx->Attr("shape"); + DimVector dim_vec; + if (shape.NumAxes() > 0) { + dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend()); + } + *out_shape = Shape(dim_vec); + return Maybe::Ok(); +} + +/*static*/ Maybe UniformOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe UniformOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe UniformOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::SbpParallel default_sbp; + default_sbp.mutable_broadcast_parallel(); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} + +/* static */ Maybe UniformOp::InferDataType(user_op::InferContext* ctx) { + auto dtype = ctx->Attr("dtype"); + *ctx->OutputDType("out", 0) = dtype; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/dot_op.cpp b/oneflow/user/ops/dot_op.cpp index 369c510fe2e..2b376442bfd 100644 --- a/oneflow/user/ops/dot_op.cpp +++ b/oneflow/user/ops/dot_op.cpp @@ -14,39 +14,40 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe DotOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0); + CHECK_OR_RETURN(x.shape() == y.shape()) << "Input tensor shape is different"; + CHECK_OR_RETURN(x.shape().NumAxes() == 1) << "Input tensor is not 1D"; + *ctx->OutputShape("out", 0) = Shape({}); + return Maybe::Ok(); +} -REGISTER_USER_OP("dot") - .Input("x") - .Input("y") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0); - CHECK_OR_RETURN(x.shape() == y.shape()) << "Input tensor shape is different"; - CHECK_OR_RETURN(x.shape().NumAxes() == 1) << "Input tensor is not 1D"; - *ctx->OutputShape("out", 0) = Shape({}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), 0) - .Split(user_op::OpArg("y", 0), 0) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); +/*static*/ Maybe DotOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0); - CHECK_OR_RETURN(x.data_type() == y.data_type()) << "The input tensor type is different"; - *ctx->OutputDType("out", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe DotOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), 0) + .Split(user_op::OpArg("y", 0), 0) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + + return Maybe::Ok(); +} + +/* static */ Maybe DotOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0); + CHECK_OR_RETURN(x.data_type() == y.data_type()) << "The data type of input tensors are different"; + *ctx->OutputDType("out", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("dot").SetGenBackwardOpConfFn( [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe { @@ -76,6 +77,4 @@ REGISTER_USER_OP_GRAD("dot").SetGenBackwardOpConfFn( return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/dropout_op.cpp b/oneflow/user/ops/dropout_op.cpp index 41cab7bd88e..20beb57083a 100644 --- a/oneflow/user/ops/dropout_op.cpp +++ b/oneflow/user/ops/dropout_op.cpp @@ -14,77 +14,112 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -REGISTER_USER_OP("dropout") - .Input("in") - .OptionalInput("_add_to_output") - .Output("out") - .Output("mask") - .Attr("rate") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - *ctx->OutputShape("out", 0) = in_shape; - *ctx->OutputShape("mask", 0) = in_shape; - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build(); - } - return Maybe::Ok(); - }) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def, - const user_op::UserOpConfWrapper& op_conf) -> Maybe { - float rate = op_conf.attr("rate"); - CHECK_GE_OR_RETURN(rate, 0.0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - *ctx->OutputDType("mask", 0) = DataType::kInt8; - return Maybe::Ok(); - }); - -REGISTER_USER_OP("dropout_grad") - .Input("dy") - .Input("mask") - .Output("dx") - .Attr("scale") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - *ctx->OutputShape("dx", 0) = dy_shape; - *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("dy", 0); - CHECK_EQ_OR_RETURN(ctx->InputShape("mask", 0), dy_shape); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); - FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), axis) - .Split(user_op::OpArg("mask", 0), axis) - .Split(user_op::OpArg("dx", 0), axis) - .Build(); - } - return Maybe::Ok(); - }) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def, - const user_op::UserOpConfWrapper& op_conf) -> Maybe { - float scale = op_conf.attr("scale"); - CHECK_GT_OR_RETURN(scale, 1); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), DataType::kInt8); - return Maybe::Ok(); - }); +/* static */ Maybe DropoutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + *ctx->OutputShape("out", 0) = in_shape; + *ctx->OutputShape("mask", 0) = in_shape; + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe DropoutOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe DropoutOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe DropoutOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + float rate = conf.attr("rate"); + CHECK_GE_OR_RETURN(rate, 0.0); + return Maybe::Ok(); +} + +/* static */ Maybe 
DropoutOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + *ctx->OutputDType("mask", 0) = DataType::kInt8; + return Maybe::Ok(); +} + +/* static */ Maybe DropoutGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + *ctx->OutputShape("dx", 0) = dy_shape; + *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("dy", 0); + CHECK_EQ_OR_RETURN(ctx->InputShape("mask", 0), dy_shape); + return Maybe::Ok(); +} + +/*static*/ Maybe DropoutGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe DropoutGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); + FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), axis) + .Split(user_op::OpArg("mask", 0), axis) + .Split(user_op::OpArg("dx", 0), axis) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe DropoutGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + float scale = conf.attr("scale"); + CHECK_GT_OR_RETURN(scale, 1); + return Maybe::Ok(); +} + +/* static */ Maybe DropoutGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), DataType::kInt8); + return Maybe::Ok(); +} + +/* static */ Maybe RandomMaskLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe RandomMaskLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe RandomMaskLikeOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& like_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0); + FOR_RANGE(int64_t, axis, 0, like_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("like", 0), axis) + .Split(user_op::OpArg("out", 0), axis) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe RandomMaskLikeOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + float rate = conf.attr("rate"); + CHECK_GE_OR_RETURN(rate, 0); + CHECK_LT_OR_RETURN(rate, 1); + return Maybe::Ok(); +} + +/* static */ Maybe RandomMaskLikeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kInt8; + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("dropout").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { @@ -106,38 +141,4 @@ REGISTER_USER_OP_GRAD("dropout").SetGenBackwardOpConfFn([](const user_op::UserOp return Maybe::Ok(); }); -REGISTER_NO_GRAD_USER_OP("random_mask_like") - .Input("like") - .Output("out") - .Attr("rate") - .Attr("seed") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& like_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0); - FOR_RANGE(int64_t, axis, 0, like_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("like", 0), axis) - .Split(user_op::OpArg("out", 0), axis) - .Build(); - } - return Maybe::Ok(); - }) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& 
op_def, - const user_op::UserOpConfWrapper& op_conf) -> Maybe { - float rate = op_conf.attr("rate"); - CHECK_GE_OR_RETURN(rate, 0); - CHECK_LT_OR_RETURN(rate, 1); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kInt8; - return Maybe::Ok(); - }); - -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp b/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp index e24277c4235..5745a6cdd25 100644 --- a/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp +++ b/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -27,24 +28,27 @@ bool IsTensorWithType(const user_op::TensorDesc* desc, DataType data_type) { return desc->data_type() == data_type; } -Maybe InferTensorDesc(user_op::InferContext* ctx) { +} // namespace + +/* static */ Maybe DynamicLossScaleScheduleOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { CHECK_OR_RETURN(IsScalarTensor(&(ctx->InputTensorDesc("count_not_finite", 0)))); CHECK_OR_RETURN(IsScalarTensor(&(ctx->InputTensorDesc("loss_scale", 0)))); CHECK_OR_RETURN(IsScalarTensor(&(ctx->InputTensorDesc("good_step_counter", 0)))); return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { - CHECK_OR_RETURN( - IsTensorWithType(&(ctx->InputTensorDesc("count_not_finite", 0)), DataType::kInt64)); - CHECK_OR_RETURN(IsTensorWithType(&(ctx->InputTensorDesc("loss_scale", 0)), DataType::kFloat)); - CHECK_OR_RETURN( - IsTensorWithType(&(ctx->InputTensorDesc("good_step_counter", 0)), DataType::kInt64)); - return Maybe::Ok(); +/*static*/ Maybe DynamicLossScaleScheduleOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe DynamicLossScaleScheduleOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } -Maybe InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { +/* static */ Maybe DynamicLossScaleScheduleOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* loss_scale = GetInputArgModifierFn("loss_scale", 0); CHECK_OR_RETURN(loss_scale != nullptr); loss_scale->set_is_mutable(true); @@ -54,17 +58,13 @@ Maybe InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgMo return Maybe::Ok(); } -} // namespace - -REGISTER_USER_OP("dynamic_loss_scale_schedule") - .Input("count_not_finite") - .Input("loss_scale") - .Input("good_step_counter") - .Attr("increment_period", 2000) - .Attr("multiplier", 2.0) - .SetTensorDescInferFn(InferTensorDesc) - .SetInputArgModifyFn(InputArgModifierFn) - .SetDataTypeInferFn(InferDataType) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe DynamicLossScaleScheduleOp::InferDataType(user_op::InferContext* ctx) { + CHECK_OR_RETURN( + IsTensorWithType(&(ctx->InputTensorDesc("count_not_finite", 0)), DataType::kInt64)); + CHECK_OR_RETURN(IsTensorWithType(&(ctx->InputTensorDesc("loss_scale", 0)), DataType::kFloat)); + CHECK_OR_RETURN( + IsTensorWithType(&(ctx->InputTensorDesc("good_step_counter", 0)), DataType::kInt64)); + return Maybe::Ok(); +} } // namespace oneflow diff --git 
a/oneflow/user/ops/eager_b_to_s_op.cpp b/oneflow/user/ops/eager_b_to_s_op.cpp index 542b96fa5df..88a0f4a82a0 100644 --- a/oneflow/user/ops/eager_b_to_s_op.cpp +++ b/oneflow/user/ops/eager_b_to_s_op.cpp @@ -19,12 +19,12 @@ limitations under the License. #include "oneflow/core/common/shape.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe TensorDescInfer(user_op::InferContext* ctx) { +// Can only be called in mirrored TODO: move this comment to ods +/* static */ Maybe EagerBToSOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& shape = ctx->Attr("shape"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); @@ -40,27 +40,25 @@ Maybe TensorDescInfer(user_op::InferContext* ctx) { return Maybe::Ok(); } -} // namespace +/*static*/ Maybe EagerBToSOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerBToSOp::GetSbp(user_op::SbpContext* ctx) { + return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!"; +} + +/* static */ Maybe EagerBToSOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!"; +} -// Can only be called in mirrored -REGISTER_NO_GRAD_USER_OP("eager_b_to_s") - .Input("in") - .Output("out") - .Attr("out_split_axis", -1) - .Attr("in_parallel_conf") - .Attr("out_parallel_conf") - .Attr("shape") - .SetTensorDescInferFn(TensorDescInfer) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!"; - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!"; - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/* static */ Maybe EagerBToSOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerBToSOp::InferDevice(user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/eager_nccl_ops.cpp b/oneflow/user/ops/eager_nccl_ops.cpp index efbbedd538d..daa5bd045d8 100644 --- a/oneflow/user/ops/eager_nccl_ops.cpp +++ b/oneflow/user/ops/eager_nccl_ops.cpp @@ -18,210 +18,233 @@ limitations under the License. 
#include "oneflow/core/common/decorator.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("eager_nccl_all_reduce") - .Input("in") - .Output("out") - .Attr("parallel_conf") - .Attr("async_launch", false) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&IsAsyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_USER_OP("eager_nccl_broadcast") - .Input("in") - .Output("out") - .Attr("parallel_conf") - .Attr("root", 0) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("in", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_USER_OP("eager_nccl_reduce") - .Input("in") - .Output("out") - .Attr("parallel_conf") - .Attr("root", 0) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - UNIMPLEMENTED_THEN_RETURN() << "consistent tensor are not supported"; - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_USER_OP("eager_nccl_reduce_scatter") - .Input("in") - .Output("out") - .Attr("parallel_conf") - .Attr("op_type", "sum") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - const Shape& shape = ctx->InputShape("in", 0); - DimVector dim_vec; - if (shape.NumAxes() > 0) { - dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend()); - } - const cfg::SbpParallel& out_sbp_para = ctx->SbpParallel4ArgNameAndIndex("out", 0); - const int64_t& parallel_num = ctx->parallel_ctx().parallel_num(); - if (parallel_num > 1) { - const int64_t& split_axis = out_sbp_para.split_parallel().axis(); - CHECK_LT_OR_RETURN(split_axis, dim_vec.size()); - BalancedSplitter bs(shape.At(split_axis), parallel_num); - dim_vec[split_axis] = bs.At(ctx->parallel_ctx().parallel_id()).size(); - } - *out_shape = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - 
.SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_partial_sum_parallel() || sbp_hint.has_broadcast_parallel()); - } - in_nd_sbp->clear_sbp_parallel(); - out_nd_sbp->clear_sbp_parallel(); - - // P2S or B2S - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - in_nd_sbp->CopyFrom(in_dis_hint); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - out_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); - } - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_USER_OP("eager_nccl_all_gather") - .Input("in") - .Output("out") - .Attr("parallel_conf") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_split_parallel()); - CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), 0); - } - - in_nd_sbp->clear_sbp_parallel(); - out_nd_sbp->clear_sbp_parallel(); - - // S(0)->B - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); - out_nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); - } - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("eager_nccl_s2s") - .Input("in") - .Output("out") - .Attr("in_split_axis", -1) - .Attr("out_split_axis", -1) - .Attr("parallel_conf") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); - const int64_t out_split_axis = ctx->user_op_conf().attr("out_split_axis"); - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); - 
cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_split_parallel()); - CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); - } - - in_nd_sbp->clear_sbp_parallel(); - out_nd_sbp->clear_sbp_parallel(); - - // S(in)->S(out) - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); - out_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis); - } - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe EagerNcclAllReduceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe EagerNcclAllReduceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerNcclAllReduceOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().PartialSum(user_op::OpArg("in", 0)).Broadcast(user_op::OpArg("out", 0)).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclAllReduceOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerNcclAllReduceOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&IsAsyncLaunched>(ctx); +} + +/* static */ Maybe EagerNcclBroadcastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe EagerNcclBroadcastOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerNcclBroadcastOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().PartialSum(user_op::OpArg("in", 0)).Broadcast(user_op::OpArg("out", 0)).Build(); + ctx->NewBuilder().Broadcast(user_op::OpArg("in", 0)).Broadcast(user_op::OpArg("out", 0)).Build(); + ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Broadcast(user_op::OpArg("out", 0)).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclBroadcastOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerNcclBroadcastOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe EagerNcclReduceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe EagerNcclReduceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerNcclReduceOp::GetSbp(user_op::SbpContext* ctx) { + UNIMPLEMENTED_THEN_RETURN() << "consistent tensor are not supported"; +} + +/* static */ Maybe EagerNcclReduceOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerNcclReduceOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static 
*/ Maybe EagerNcclReduceScatterOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclReduceScatterOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + const Shape& shape = ctx->InputShape("in", 0); + DimVector dim_vec; + if (shape.NumAxes() > 0) { + dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend()); + } + const cfg::SbpParallel& out_sbp_para = ctx->SbpParallel4ArgNameAndIndex("out", 0); + const int64_t& parallel_num = ctx->parallel_ctx().parallel_num(); + if (parallel_num > 1) { + const int64_t& split_axis = out_sbp_para.split_parallel().axis(); + CHECK_LT_OR_RETURN(split_axis, dim_vec.size()); + BalancedSplitter bs(shape.At(split_axis), parallel_num); + dim_vec[split_axis] = bs.At(ctx->parallel_ctx().parallel_id()).size(); + } + *out_shape = Shape(dim_vec); + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclReduceScatterOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe EagerNcclReduceScatterOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_partial_sum_parallel() || sbp_hint.has_broadcast_parallel()); + } + in_nd_sbp->clear_sbp_parallel(); + out_nd_sbp->clear_sbp_parallel(); + + // P2S or B2S + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + in_nd_sbp->CopyFrom(in_dis_hint); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + out_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); + } + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclReduceScatterOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerNcclReduceScatterOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe EagerNcclAllGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe EagerNcclAllGatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerNcclAllGatherOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe EagerNcclAllGatherOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_split_parallel()); + CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), 0); + } + + in_nd_sbp->clear_sbp_parallel(); + out_nd_sbp->clear_sbp_parallel(); + + // S(0)->B 
+ const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); + out_nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); + } + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclAllGatherOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerNcclAllGatherOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe EagerNcclS2sOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclS2sOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe EagerNcclS2sOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); + const int64_t out_split_axis = ctx->user_op_conf().attr("out_split_axis"); + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_split_parallel()); + CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); + } + + in_nd_sbp->clear_sbp_parallel(); + out_nd_sbp->clear_sbp_parallel(); + + // S(in)->S(out) + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); + out_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis); + } + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclS2sOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerNcclS2sOp::InferDevice(user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + } // namespace oneflow diff --git a/oneflow/user/ops/eager_p_to_b_op.cpp b/oneflow/user/ops/eager_p_to_b_op.cpp index 6b30ac9cec8..e1809c23a8d 100644 --- a/oneflow/user/ops/eager_p_to_b_op.cpp +++ b/oneflow/user/ops/eager_p_to_b_op.cpp @@ -19,30 +19,34 @@ limitations under the License. 
#include "oneflow/core/common/shape.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { - // Can only be called in mirrored -REGISTER_NO_GRAD_USER_OP("eager_p_to_b") - .Input("in") - .Output("out") - .Attr("in_parallel_conf") - .Attr("out_parallel_conf") - .Attr("shape") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/* static */ Maybe EagerPToBOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + return Maybe::Ok(); +} + +/*static*/ Maybe EagerPToBOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerPToBOp::GetSbp(user_op::SbpContext* ctx) { + return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; +} + +/* static */ Maybe EagerPToBOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; +} + +/* static */ Maybe EagerPToBOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerPToBOp::InferDevice(user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/eager_p_to_s_op.cpp b/oneflow/user/ops/eager_p_to_s_op.cpp index b3cce498e31..0e981e21fa0 100644 --- a/oneflow/user/ops/eager_p_to_s_op.cpp +++ b/oneflow/user/ops/eager_p_to_s_op.cpp @@ -19,12 +19,11 @@ limitations under the License. 
#include "oneflow/core/common/shape.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe TensorDescInfer(user_op::InferContext* ctx) { +/* static */ Maybe EagerPToSOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& shape = ctx->Attr("shape"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); @@ -40,27 +39,25 @@ Maybe TensorDescInfer(user_op::InferContext* ctx) { return Maybe::Ok(); } -} // namespace +/*static*/ Maybe EagerPToSOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerPToSOp::GetSbp(user_op::SbpContext* ctx) { + return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!"; +} + +/* static */ Maybe EagerPToSOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!"; +} -// Can only be called in mirrored -REGISTER_NO_GRAD_USER_OP("eager_p_to_s") - .Input("in") - .Output("out") - .Attr("out_split_axis", -1) - .Attr("in_parallel_conf") - .Attr("out_parallel_conf") - .Attr("shape") - .SetTensorDescInferFn(TensorDescInfer) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!"; - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!"; - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/* static */ Maybe EagerPToSOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerPToSOp::InferDevice(user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/eager_s_to_b_op.cpp b/oneflow/user/ops/eager_s_to_b_op.cpp index 6407f4b1ebb..3af6f00b4ad 100644 --- a/oneflow/user/ops/eager_s_to_b_op.cpp +++ b/oneflow/user/ops/eager_s_to_b_op.cpp @@ -19,31 +19,34 @@ limitations under the License. 
#include "oneflow/core/common/shape.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -// Can only be called in mirrored -REGISTER_NO_GRAD_USER_OP("eager_s_to_b") - .Input("in") - .Output("out") - .Attr("in_split_axis", -1) - .Attr("in_parallel_conf") - .Attr("out_parallel_conf") - .Attr("shape") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/* static */ Maybe EagerSToBOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + return Maybe::Ok(); +} + +/*static*/ Maybe EagerSToBOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerSToBOp::GetSbp(user_op::SbpContext* ctx) { + return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; +} + +/* static */ Maybe EagerSToBOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; +} + +/* static */ Maybe EagerSToBOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerSToBOp::InferDevice(user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/eager_s_to_s_op.cpp b/oneflow/user/ops/eager_s_to_s_op.cpp index 9a72361c135..773b671fcd9 100644 --- a/oneflow/user/ops/eager_s_to_s_op.cpp +++ b/oneflow/user/ops/eager_s_to_s_op.cpp @@ -19,12 +19,11 @@ limitations under the License. 
#include "oneflow/core/common/shape.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe TensorDescInfer(user_op::InferContext* ctx) { +/* static */ Maybe EagerNaiveSToSOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& shape = ctx->Attr("shape"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); @@ -40,28 +39,25 @@ Maybe TensorDescInfer(user_op::InferContext* ctx) { return Maybe::Ok(); } -} // namespace +/*static*/ Maybe EagerNaiveSToSOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerNaiveSToSOp::GetSbp(user_op::SbpContext* ctx) { + return Error::TypeError() << "eager_naive_s_to_s op doesn't support consistent tensor!"; +} + +/* static */ Maybe EagerNaiveSToSOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + return Error::TypeError() << "eager_naive_s_to_s op doesn't support consistent tensor!"; +} -// Can only be called in mirrored -REGISTER_NO_GRAD_USER_OP("eager_naive_s_to_s") - .Input("in") - .Output("out") - .Attr("in_split_axis", -1) - .Attr("out_split_axis", -1) - .Attr("in_parallel_conf") - .Attr("out_parallel_conf") - .Attr("shape") - .SetTensorDescInferFn(TensorDescInfer) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - return Error::TypeError() << "eager_naive_s_to_s op doesn't support consistent tensor!"; - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - return Error::TypeError() << "eager_naive_s_to_s op doesn't support consistent tensor!"; - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/* static */ Maybe EagerNaiveSToSOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerNaiveSToSOp::InferDevice(user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp b/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp index 847ee7272a9..c108a33b8cb 100644 --- a/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp +++ b/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp @@ -17,54 +17,61 @@ limitations under the License. 
#include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("eager_symmetric_s_to_p") - .Input("in") - .Output("out") - .Attr("in_split_axis", -1) - .Attr("parallel_conf") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_split_parallel()); - CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); - } +/* static */ Maybe EagerSymmetricSToPOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} - in_nd_sbp->clear_sbp_parallel(); - out_nd_sbp->clear_sbp_parallel(); +/*static*/ Maybe EagerSymmetricSToPOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); - out_nd_sbp->add_sbp_parallel()->mutable_partial_sum_parallel(); - } - return Maybe::Ok(); - }); +/* static */ Maybe EagerSymmetricSToPOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe EagerSymmetricSToPOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_split_parallel()); + CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); + } + + in_nd_sbp->clear_sbp_parallel(); + out_nd_sbp->clear_sbp_parallel(); + + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for 
(int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); + out_nd_sbp->add_sbp_parallel()->mutable_partial_sum_parallel(); + } + return Maybe::Ok(); +} + +/* static */ Maybe EagerSymmetricSToPOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerSymmetricSToPOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp b/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp index 4bce761d226..7a143bb4ecd 100644 --- a/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp +++ b/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp @@ -14,13 +14,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { using namespace user_op; -Maybe GetSbpSignature(SbpContext* ctx) { +Maybe GetSbpSignature_(SbpContext* ctx) { const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape(); @@ -35,7 +36,7 @@ Maybe GetSbpSignature(SbpContext* ctx) { return Maybe::Ok(); } -Maybe InferTensorDesc(InferContext* ctx) { +Maybe InferTensorDesc_(InferContext* ctx) { const TensorDesc& tensor_x = ctx->InputTensorDesc("x", 0); const TensorDesc& tensor_y = ctx->InputTensorDesc("y", 0); @@ -56,7 +57,7 @@ Maybe InferTensorDesc(InferContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(InferContext* ctx) { +Maybe InferDataType_(InferContext* ctx) { const TensorDesc& tensor_dz = ctx->InputTensorDesc("dz", 0); TensorDesc* tensor_dx = ctx->OutputTensorDesc("dx", 0); TensorDesc* tensor_dy = ctx->OutputTensorDesc("dy", 0); @@ -101,36 +102,55 @@ user_op::BackwardOpConfGenFn MakeGenBackwardOpFn(const std::string& op_type_name } // namespace -#define REGISTER_ELEMENTWISE_XIMUM_FW_OP(op_type_name) \ - REGISTER_USER_OP(op_type_name) \ - .Input("x") \ - .Input("y") \ - .Output("z") \ - .SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged) \ - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) \ - .SetDataTypeInferFn(user_op::TensorDescInferFnUtil::UnchangedDataType) - -#define REGISTER_ELEMENTWISE_XIMUM_BW_OP(op_type_name) \ - REGISTER_USER_OP(op_type_name) \ - .Input("dz") \ - .Input("x") \ - .Input("y") \ - .OptionalOutput("dx") \ - .OptionalOutput("dy") \ - .SetTensorDescInferFn(InferTensorDesc) \ - .SetGetSbpFn(GetSbpSignature) \ - .SetDataTypeInferFn(InferDataType) +#define DEF_ELEMENTWISE_XIMUM_FW_OP(op_class_name_prefix) \ + /* static */ Maybe op_class_name_prefix##Op::InferLogicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return user_op::TensorDescInferFnUtil::Unchanged(ctx); \ + } \ + \ + /*static*/ Maybe op_class_name_prefix##Op::InferPhysicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##Op::GetSbp(user_op::SbpContext* ctx) { \ + return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##Op::InferDataType(user_op::InferContext* ctx) { \ + return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); \ + } + +#define DEF_ELEMENTWISE_XIMUM_BW_OP(op_class_name_prefix) \ + /* 
static */ Maybe op_class_name_prefix##BackwardOp::InferLogicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferTensorDesc_(ctx); \ + } \ + \ + /*static*/ Maybe op_class_name_prefix##BackwardOp::InferPhysicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##BackwardOp::GetSbp(user_op::SbpContext* ctx) { \ + return GetSbpSignature_(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##BackwardOp::InferDataType( \ + user_op::InferContext* ctx) { \ + return InferDataType_(ctx); \ + } #define REGISTER_ELEMENTWISE_XIMUM_GRAD(op_type_name) \ REGISTER_USER_OP_GRAD(op_type_name) \ .SetBackwardOpConfGenFn(MakeGenBackwardOpFn(std::string(op_type_name))); -#define REGISTER_ELEMENTWISE_XIMUM_OP(op_type_name) \ - REGISTER_ELEMENTWISE_XIMUM_FW_OP(op_type_name); \ - REGISTER_ELEMENTWISE_XIMUM_BW_OP(op_type_name "_backward"); \ +#define REGISTER_ELEMENTWISE_XIMUM_OP(op_type_name, op_class_name_prefix) \ + DEF_ELEMENTWISE_XIMUM_FW_OP(op_class_name_prefix); \ + DEF_ELEMENTWISE_XIMUM_BW_OP(op_class_name_prefix); \ REGISTER_ELEMENTWISE_XIMUM_GRAD(op_type_name); -REGISTER_ELEMENTWISE_XIMUM_OP("elementwise_maximum"); -REGISTER_ELEMENTWISE_XIMUM_OP("elementwise_minimum"); +REGISTER_ELEMENTWISE_XIMUM_OP("elementwise_maximum", ElementwiseMaximum); +REGISTER_ELEMENTWISE_XIMUM_OP("elementwise_minimum", ElementwiseMinimum); } // namespace oneflow diff --git a/oneflow/user/ops/elu_op.cpp b/oneflow/user/ops/elu_op.cpp index 13cf0de77ec..9de85d34655 100644 --- a/oneflow/user/ops/elu_op.cpp +++ b/oneflow/user/ops/elu_op.cpp @@ -14,63 +14,62 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe EluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("elu") - .Input("in") - .Attr("alpha") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe EluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("elu_grad") - .Input("x") - .Input("dy") - .Attr("alpha") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - 
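To make the token-pasting macros above concrete, this is approximately what DEF_ELEMENTWISE_XIMUM_FW_OP expands to for one instantiation; the return type is written out here as Maybe<void>, matching OneFlow's other user-op infer functions:

// Approximate expansion of DEF_ELEMENTWISE_XIMUM_FW_OP(ElementwiseMaximum);
// shown only to illustrate the ##-pasting, not as literal patch content.
/* static */ Maybe<void> ElementwiseMaximumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  return user_op::TensorDescInferFnUtil::Unchanged(ctx);
}
/* static */ Maybe<void> ElementwiseMaximumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}
/* static */ Maybe<void> ElementwiseMaximumOp::GetSbp(user_op::SbpContext* ctx) {
  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);
}
/* static */ Maybe<void> ElementwiseMaximumOp::InferDataType(user_op::InferContext* ctx) {
  return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx);
}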
.Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe EluOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe EluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe EluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe EluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EluGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe EluGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("elu").SetBackwardOpConfGenFn( [](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -90,6 +89,4 @@ REGISTER_USER_OP_GRAD("elu").SetBackwardOpConfGenFn( return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/empty_op.cpp b/oneflow/user/ops/empty_op.cpp index ff8e3c45c6d..8c8c24f68d8 100644 --- a/oneflow/user/ops/empty_op.cpp +++ b/oneflow/user/ops/empty_op.cpp @@ -15,45 +15,45 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("empty") - .Output("out") - .SetOutputBufferNum(1) - .Attr("dtype") - .Attr("shape") - .Attr>("nd_sbp") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); - return Maybe::Ok(); - }) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& shape = ctx->Attr("shape"); - DimVector dim_vec{shape.dim_vec()}; - - const cfg::SbpParallel& out_sbp_para = ctx->SbpParallel4ArgNameAndIndex("out", 0); - if (out_sbp_para.has_split_parallel()) { - const int64_t& parallel_num = ctx->parallel_ctx().parallel_num(); - if (parallel_num > 1) { - const int64_t& split_axis = out_sbp_para.split_parallel().axis(); - CHECK_LT_OR_RETURN(split_axis, dim_vec.size()); - BalancedSplitter bs(shape.At(split_axis), parallel_num); - dim_vec[split_axis] = bs.At(ctx->parallel_ctx().parallel_id()).size(); - } - } - - *ctx->OutputShape("out", 0) = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_broadcast_parallel(); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }); +/* static */ Maybe EmptyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + return Maybe::Ok(); +} + +/* static */ Maybe EmptyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + const Shape& shape = ctx->Attr("shape"); + DimVector dim_vec{shape.dim_vec()}; + + const cfg::SbpParallel& out_sbp_para = ctx->SbpParallel4ArgNameAndIndex("out", 0); + if (out_sbp_para.has_split_parallel()) { + const int64_t& parallel_num = ctx->parallel_ctx().parallel_num(); + if (parallel_num > 1) { + const int64_t& split_axis = out_sbp_para.split_parallel().axis(); + CHECK_LT_OR_RETURN(split_axis, dim_vec.size()); + BalancedSplitter bs(shape.At(split_axis), parallel_num); + dim_vec[split_axis] = bs.At(ctx->parallel_ctx().parallel_id()).size(); + } + } + + *ctx->OutputShape("out", 0) = Shape(dim_vec); + return Maybe::Ok(); +} + +/* static */ Maybe EmptyOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } + +/* static */ Maybe EmptyOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::SbpParallel default_sbp; + default_sbp.mutable_broadcast_parallel(); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} + +/* static */ Maybe EmptyOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); + return Maybe::Ok(); +} + } // namespace oneflow diff --git a/oneflow/user/ops/expand_dims_op.cpp b/oneflow/user/ops/expand_dims_op.cpp index 99c1be0f79b..f5031f7a1b3 100644 --- a/oneflow/user/ops/expand_dims_op.cpp +++ b/oneflow/user/ops/expand_dims_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -28,43 +29,45 @@ int32_t TransformNegativeAxisToPositive(int32_t axis, const int32_t num_axes) { } // namespace -REGISTER_USER_OP("expand_dims") - .Input("in") - .Output("out") - .Attr("axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - const int32_t axis = - TransformNegativeAxisToPositive(ctx->Attr("axis"), in_shape.NumAxes()); +/* static */ Maybe ExpandDimsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + const int32_t axis = + TransformNegativeAxisToPositive(ctx->Attr("axis"), in_shape.NumAxes()); - auto dim_vec = in_shape.dim_vec(); - dim_vec.insert(dim_vec.begin() + axis, 1); - *out_shape = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - const int32_t axis = - TransformNegativeAxisToPositive(ctx->Attr("axis"), in_tensor.shape().NumAxes()); + auto dim_vec = in_shape.dim_vec(); + dim_vec.insert(dim_vec.begin() + axis, 1); + *out_shape = Shape(dim_vec); + return Maybe::Ok(); +} - auto dim_vec = in_tensor.shape().dim_vec(); - FOR_RANGE(int32_t, in_axis, 0, dim_vec.size()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), in_axis) - .Split(user_op::OpArg("out", 0), in_axis < axis ? in_axis : in_axis + 1) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ExpandDimsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ExpandDimsOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + const int32_t axis = + TransformNegativeAxisToPositive(ctx->Attr("axis"), in_tensor.shape().NumAxes()); + + auto dim_vec = in_tensor.shape().dim_vec(); + FOR_RANGE(int32_t, in_axis, 0, dim_vec.size()) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), in_axis) + .Split(user_op::OpArg("out", 0), in_axis < axis ? in_axis : in_axis + 1) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ExpandDimsOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("expand_dims") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/expand_op.cpp b/oneflow/user/ops/expand_op.cpp index 4ac73550bf9..9e8cfd5c2ef 100644 --- a/oneflow/user/ops/expand_op.cpp +++ b/oneflow/user/ops/expand_op.cpp @@ -15,116 +15,119 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/expand_kernel_utils.h" +#include "oneflow/core/framework/op_generated.h" + namespace oneflow { -REGISTER_USER_OP("expand") - .Input("in") - .Output("out") - .Attr>("logical_in_shape") - .Attr>("logical_expand_shape") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& input_shape = ctx->InputShape("in", 0); - const std::vector& logical_expand_shape = - ctx->Attr>("logical_expand_shape"); - - std::vector in_shape; - in_shape.resize(input_shape.NumAxes()); - for (int i = 0; i < input_shape.NumAxes(); ++i) { in_shape[i] = input_shape.At(i); } - - std::vector out_shape; - std::vector stride; - CHECK_JUST(getOutShapeAndStrideForFp(in_shape, logical_expand_shape, out_shape, stride)); - - Shape* output_shape = ctx->OutputShape("out", 0); - DimVector dim_vec(out_shape.begin(), out_shape.end()); - *output_shape = Shape(dim_vec); - - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const std::vector& logical_in_shape = - ctx->Attr>("logical_in_shape"); - const std::vector& logical_expand_shape = - ctx->Attr>("logical_expand_shape"); - std::vector logical_out_shape; - std::vector stride; - CHECK_JUST( - getOutShapeAndStride(logical_in_shape, logical_expand_shape, logical_out_shape, stride)); - - int offset = logical_out_shape.size() - logical_in_shape.size(); - FOR_RANGE(int64_t, i, 0, logical_in_shape.size()) { - if (logical_in_shape[i] == logical_out_shape[i + offset]) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i + offset) - .Build(); - } - } +/* static */ Maybe ExpandOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& input_shape = ctx->InputShape("in", 0); + const std::vector& logical_expand_shape = + ctx->Attr>("logical_expand_shape"); + + std::vector in_shape; + in_shape.resize(input_shape.NumAxes()); + for (int i = 0; i < input_shape.NumAxes(); ++i) { in_shape[i] = input_shape.At(i); } + + std::vector out_shape; + std::vector stride; + CHECK_JUST(getOutShapeAndStrideForFp(in_shape, logical_expand_shape, out_shape, stride)); + + Shape* output_shape = ctx->OutputShape("out", 0); + DimVector dim_vec(out_shape.begin(), out_shape.end()); + *output_shape = Shape(dim_vec); + return Maybe::Ok(); +} + +/*static*/ Maybe ExpandOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ExpandOp::GetSbp(user_op::SbpContext* ctx) { + const std::vector& logical_in_shape = + ctx->Attr>("logical_in_shape"); + const std::vector& logical_expand_shape = + ctx->Attr>("logical_expand_shape"); + std::vector logical_out_shape; + std::vector stride; + CHECK_JUST( + getOutShapeAndStride(logical_in_shape, logical_expand_shape, logical_out_shape, stride)); + + int offset = logical_out_shape.size() - logical_in_shape.size(); + FOR_RANGE(int64_t, i, 0, logical_in_shape.size()) { + if (logical_in_shape[i] == logical_out_shape[i + offset]) { ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i + offset) .Build(); - return Maybe::Ok(); - }); - -REGISTER_USER_OP("expand_grad") - .Input("in") - .Output("out") - .Attr>("logical_out_shape") - .Attr>("logical_expand_shape") - 
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& input_shape = ctx->InputShape("in", 0); - const std::vector& logical_out_shape = - ctx->Attr>("logical_out_shape"); - const std::vector& logical_expand_shape = - ctx->Attr>("logical_expand_shape"); - - std::vector in_shape; - in_shape.resize(input_shape.NumAxes()); - for (int i = 0; i < input_shape.NumAxes(); ++i) { in_shape[i] = input_shape.At(i); } - std::vector out_shape; - std::vector stride; - CHECK_JUST(getOutShapeAndStrideForBp(logical_out_shape, logical_expand_shape, in_shape, - out_shape, stride)); - - Shape* output_shape = ctx->OutputShape("out", 0); - DimVector dim_vec(out_shape.begin(), out_shape.end()); - *output_shape = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& input_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - const std::vector& logical_out_shape = - ctx->Attr>("logical_out_shape"); - const std::vector& logical_expand_shape = - ctx->Attr>("logical_expand_shape"); - - int offset = input_tensor.shape().NumAxes() - logical_out_shape.size(); - FOR_RANGE(int64_t, i, 0, logical_out_shape.size()) { - if (logical_out_shape[i] == input_tensor.shape().At(i + offset)) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i + offset) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - } + } + } + + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/* static */ Maybe ExpandOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe ExpandGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& input_shape = ctx->InputShape("in", 0); + const std::vector& logical_out_shape = + ctx->Attr>("logical_out_shape"); + const std::vector& logical_expand_shape = + ctx->Attr>("logical_expand_shape"); + + std::vector in_shape; + in_shape.resize(input_shape.NumAxes()); + for (int i = 0; i < input_shape.NumAxes(); ++i) { in_shape[i] = input_shape.At(i); } + std::vector out_shape; + std::vector stride; + CHECK_JUST(getOutShapeAndStrideForBp(logical_out_shape, logical_expand_shape, in_shape, out_shape, + stride)); + + Shape* output_shape = ctx->OutputShape("out", 0); + DimVector dim_vec(out_shape.begin(), out_shape.end()); + *output_shape = Shape(dim_vec); + return Maybe::Ok(); +} + +/*static*/ Maybe ExpandGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ExpandGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& input_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + const std::vector& logical_out_shape = + ctx->Attr>("logical_out_shape"); + const std::vector& logical_expand_shape = + ctx->Attr>("logical_expand_shape"); + + int offset = input_tensor.shape().NumAxes() - logical_out_shape.size(); + FOR_RANGE(int64_t, i, 0, logical_out_shape.size()) { + if (logical_out_shape[i] == input_tensor.shape().At(i + offset)) { ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) + .Split(user_op::OpArg("in", 0), i + offset) + .Split(user_op::OpArg("out", 0), i) .Build(); - return Maybe::Ok(); - }); + } + } + + ctx->NewBuilder() + 
.PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ExpandGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("expand").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/eye_op.cpp b/oneflow/user/ops/eye_op.cpp index ca91db92e6d..077758b2452 100644 --- a/oneflow/user/ops/eye_op.cpp +++ b/oneflow/user/ops/eye_op.cpp @@ -14,23 +14,29 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("eye") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - int64_t rows = ctx->Attr("rows"); - int64_t cols = ctx->Attr("cols"); - *ctx->OutputShape("out", 0) = Shape({rows, cols}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); - return Maybe::Ok(); - }); + +/* static */ Maybe EyeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + int64_t rows = ctx->Attr("rows"); + int64_t cols = ctx->Attr("cols"); + *ctx->OutputShape("out", 0) = Shape({rows, cols}); + return Maybe::Ok(); +} + +/*static*/ Maybe EyeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EyeOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe EyeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/fake_quantization_op.cpp b/oneflow/user/ops/fake_quantization_op.cpp index d8fa0242f96..fbe6a7d8ca6 100644 --- a/oneflow/user/ops/fake_quantization_op.cpp +++ b/oneflow/user/ops/fake_quantization_op.cpp @@ -14,104 +14,99 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -REGISTER_USER_OP("fake_quantization") - .Input("in") - .Input("scale") - .Input("zero_point") - .Output("out") - // NOTE(Liang Depeng): "google" or "cambricon" - .Attr("quantization_formula", "google") - // NOTE(Liang Depeng): quantize from float32 to "quantization_bit" bit signed or unsigned - // integer - .Attr("quantization_bit", 8) - // NOTE(Liang Depeng): "symmetric" or "affine": quantize to signed or unsigned integer - .Attr("quantization_scheme", "symmetric") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - const Shape& scale_shape = ctx->InputShape("scale", 0); - const Shape& zero_point_shape = ctx->InputShape("zero_point", 0); - - // NOTE(Liang Depeng): scale_shape->elem_cnt() > 1 means per-channel quantization for - // convolution weights. 
- if (scale_shape.elem_cnt() > 1) { - CHECK_EQ_OR_RETURN(scale_shape.elem_cnt(), in_shape.At(0)); - CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0)); - } +/* static */ Maybe FakeQuantizationOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + const Shape& scale_shape = ctx->InputShape("scale", 0); + const Shape& zero_point_shape = ctx->InputShape("zero_point", 0); - *ctx->OutputShape("out", 0) = in_shape; - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* scale = GetInputArgModifierFn("scale", 0); - CHECK_OR_RETURN(scale != nullptr); - scale->set_requires_grad(false); - - user_op::InputArgModifier* zero_point = GetInputArgModifierFn("zero_point", 0); - CHECK_OR_RETURN(zero_point != nullptr); - zero_point->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - const Shape& logical_scale_shape = - ctx->LogicalTensorDesc4InputArgNameAndIndex("scale", 0).shape(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("in", 0)) - .Broadcast(user_op::OpArg("scale", 0)) - .Broadcast(user_op::OpArg("zero_point", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - if (logical_scale_shape.elem_cnt() > 1) { - // NOTE(Liang Depeng): only consider convolution weight per-channel quantization - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Split(user_op::OpArg("scale", 0), 0) - .Split(user_op::OpArg("zero_point", 0), 0) - .Split(user_op::OpArg("out", 0), 0) - .Build(); - } else { - // NOTE(Liang Depeng): the sbp signature of per-layer quantization is the same as eltwise - // ops - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Broadcast(user_op::OpArg("scale", 0)) - .Broadcast(user_op::OpArg("zero_point", 0)) - .Split(user_op::OpArg("out", 0), 0) - .Build(); - } - FOR_RANGE(int64_t, i, 1, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Broadcast(user_op::OpArg("scale", 0)) - .Broadcast(user_op::OpArg("zero_point", 0)) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def, - const user_op::UserOpConfWrapper& op_conf) -> Maybe { - const int32_t quantization_bit = op_conf.attr("quantization_bit"); - CHECK_GT_OR_RETURN(quantization_bit, 1); - CHECK_LE_OR_RETURN(quantization_bit, 8); - - std::string quantization_scheme = op_conf.attr("quantization_scheme"); - CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine"); - - std::string quantization_formula = op_conf.attr("quantization_formula"); - CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon"); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); + // NOTE(Liang Depeng): scale_shape->elem_cnt() > 1 means per-channel quantization for + // convolution weights. 
+ if (scale_shape.elem_cnt() > 1) { + CHECK_EQ_OR_RETURN(scale_shape.elem_cnt(), in_shape.At(0)); + CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0)); + } + + *ctx->OutputShape("out", 0) = in_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe FakeQuantizationOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe FakeQuantizationOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + const Shape& logical_scale_shape = + ctx->LogicalTensorDesc4InputArgNameAndIndex("scale", 0).shape(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("in", 0)) + .Broadcast(user_op::OpArg("scale", 0)) + .Broadcast(user_op::OpArg("zero_point", 0)) + .Broadcast(user_op::OpArg("out", 0)) + .Build(); + if (logical_scale_shape.elem_cnt() > 1) { + // NOTE(Liang Depeng): only consider convolution weight per-channel quantization + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), 0) + .Split(user_op::OpArg("scale", 0), 0) + .Split(user_op::OpArg("zero_point", 0), 0) + .Split(user_op::OpArg("out", 0), 0) + .Build(); + } else { + // NOTE(Liang Depeng): the sbp signature of per-layer quantization is the same as eltwise + // ops + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), 0) + .Broadcast(user_op::OpArg("scale", 0)) + .Broadcast(user_op::OpArg("zero_point", 0)) + .Split(user_op::OpArg("out", 0), 0) + .Build(); + } + FOR_RANGE(int64_t, i, 1, in_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Broadcast(user_op::OpArg("scale", 0)) + .Broadcast(user_op::OpArg("zero_point", 0)) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe FakeQuantizationOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* scale = GetInputArgModifierFn("scale", 0); + CHECK_OR_RETURN(scale != nullptr); + scale->set_requires_grad(false); + + user_op::InputArgModifier* zero_point = GetInputArgModifierFn("zero_point", 0); + CHECK_OR_RETURN(zero_point != nullptr); + zero_point->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe FakeQuantizationOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + const int32_t quantization_bit = conf.attr("quantization_bit"); + CHECK_GT_OR_RETURN(quantization_bit, 1); + CHECK_LE_OR_RETURN(quantization_bit, 8); + + std::string quantization_scheme = conf.attr("quantization_scheme"); + CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine"); + + std::string quantization_formula = conf.attr("quantization_formula"); + CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon"); + return Maybe::Ok(); +} + +/* static */ Maybe FakeQuantizationOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fake_quantization") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -129,6 +124,4 @@ REGISTER_USER_OP_GRAD("fake_quantization") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/flatten_op.cpp b/oneflow/user/ops/flatten_op.cpp index 03f0b9b2b97..487d7abc372 100644 --- a/oneflow/user/ops/flatten_op.cpp +++ b/oneflow/user/ops/flatten_op.cpp @@ -14,39 +14,11 @@ See the License for the specific language governing 
permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe GetSbpFn(user_op::SbpContext* ctx) { - const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); - const int32_t start_dim = ctx->Attr("start_dim"); - const int32_t end_dim = ctx->Attr("end_dim"); - - CHECK_GE_OR_RETURN(start_dim, 0); - CHECK_LT_OR_RETURN(start_dim, in_shape.NumAxes()); - const int32_t true_end_dim = end_dim < 0 ? end_dim + in_shape.NumAxes() : end_dim; - CHECK_GE_OR_RETURN(true_end_dim, 0); - CHECK_LT_OR_RETURN(true_end_dim, in_shape.NumAxes()); - CHECK_LE_OR_RETURN(start_dim, true_end_dim); - - for (int i = 0; i <= start_dim; ++i) { - ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); - } - const int32_t diff = true_end_dim - start_dim; - for (int i = true_end_dim + 1; i < in_shape.NumAxes(); ++i) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i - diff) - .Build(); - } - - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); - return Maybe::Ok(); -} - -Maybe TensorDescInferFn(user_op::InferContext* ctx) { +/* static */ Maybe FlattenOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const int32_t start_dim = ctx->Attr("start_dim"); const int32_t end_dim = ctx->Attr("end_dim"); const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0); @@ -79,19 +51,41 @@ Maybe TensorDescInferFn(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe DataTypeInferFn(user_op::InferContext* ctx) { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); +/*static*/ Maybe FlattenOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe FlattenOp::GetSbp(user_op::SbpContext* ctx) { + const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); + const int32_t start_dim = ctx->Attr("start_dim"); + const int32_t end_dim = ctx->Attr("end_dim"); + + CHECK_GE_OR_RETURN(start_dim, 0); + CHECK_LT_OR_RETURN(start_dim, in_shape.NumAxes()); + const int32_t true_end_dim = end_dim < 0 ? 
end_dim + in_shape.NumAxes() : end_dim; + CHECK_GE_OR_RETURN(true_end_dim, 0); + CHECK_LT_OR_RETURN(true_end_dim, in_shape.NumAxes()); + CHECK_LE_OR_RETURN(start_dim, true_end_dim); + + for (int i = 0; i <= start_dim; ++i) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + const int32_t diff = true_end_dim - start_dim; + for (int i = true_end_dim + 1; i < in_shape.NumAxes(); ++i) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i - diff) + .Build(); + } + + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); return Maybe::Ok(); } -REGISTER_USER_OP("flatten") - .Input("in") - .Output("out") - .Attr("start_dim", 0) - .Attr("end_dim", -1) - .SetTensorDescInferFn(TensorDescInferFn) - .SetGetSbpFn(GetSbpFn) - .SetDataTypeInferFn(DataTypeInferFn); +/* static */ Maybe FlattenOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("flatten").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { @@ -109,5 +103,4 @@ REGISTER_USER_OP_GRAD("flatten").SetGenBackwardOpConfFn([](const user_op::UserOp return Maybe::Ok(); }); -} // namespace } // namespace oneflow diff --git a/oneflow/user/ops/flip_op.cpp b/oneflow/user/ops/flip_op.cpp index 4176082055b..e062fc6b422 100644 --- a/oneflow/user/ops/flip_op.cpp +++ b/oneflow/user/ops/flip_op.cpp @@ -14,50 +14,48 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("flip") - .Input("x") - .Output("y") - .Attr>("dims") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0); - const int input_dims = x_desc->shape().NumAxes(); - const std::vector dims = ctx->Attr>("dims"); - CHECK_OR_RETURN(dims.size() <= input_dims) - << "len of dims must less than len of input tensor"; - for (auto x : dims) { CHECK_OR_RETURN(x < input_dims) << "dims parameter is illegal."; } - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - *y_desc->mut_shape() = x_desc->shape(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ auto FlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + const int input_dims = x_desc.shape().NumAxes(); + const std::vector dims = ctx->Attr>("dims"); + CHECK_OR_RETURN(dims.size() <= input_dims) << "len of dims must less than len of input tensor"; + for (auto x : dims) { CHECK_OR_RETURN(x < input_dims) << "dims parameter is illegal."; } + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + *y_desc->mut_shape() = x_desc.shape(); + return Maybe::Ok(); +} +/*static*/ auto FlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + return FlipOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FlipOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 
0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ auto FlipOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("flip_grad") - .Input("dy") - .Output("dx") - .Attr>("dims") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ auto FlipGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + *dx_shape = dy_shape; + return Maybe::Ok(); +} +/*static*/ auto FlipGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + return FlipGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FlipGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ auto FlipGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/fused_bias_add_op.cpp b/oneflow/user/ops/fused_bias_add_op.cpp index ad8f3539a5a..46f9394ff18 100644 --- a/oneflow/user/ops/fused_bias_add_op.cpp +++ b/oneflow/user/ops/fused_bias_add_op.cpp @@ -14,93 +14,93 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("fused_bias_add_gelu") - .Input("a") - .Input("b") - .Output("out") - .Attr("axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); - const auto bias_add_axis = ctx->Attr("axis"); - CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); - CHECK_GE_OR_RETURN(bias_add_axis, 0); - CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); - CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); - *ctx->OutputShape("out", 0) = a_tensor_desc.shape(); - *ctx->OutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - *ctx->OutputDType("out", 0) = a_tensor_desc.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto axis = ctx->Attr("axis"); - for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); - ++i) { - if (i == axis) { continue; } - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), i) - .Broadcast(user_op::OpArg("b", 0)) - .Split(ctx->outputs(), i) - .Build(); - } - ctx->NewBuilder() - .Split(user_op::OpArg("b", 0), 0) - .Split(user_op::OpArg("a", 0), axis) - .Split(ctx->outputs(), axis) - .Build(); - return Maybe::Ok(); - }); +/*static*/ auto FusedBiasAddGeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); + const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); + const auto bias_add_axis = ctx->Attr("axis"); + CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); + CHECK_GE_OR_RETURN(bias_add_axis, 0); + CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); + CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); + *ctx->OutputShape("out", 0) = a_tensor_desc.shape(); + *ctx->OutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedBiasAddGeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + return FusedBiasAddGeluOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedBiasAddGeluOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); + *ctx->OutputDType("out", 0) = a_tensor_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedBiasAddGeluOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const auto axis = ctx->Attr("axis"); + for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); + ++i) { + if (i == axis) { continue; } + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), i) + .Broadcast(user_op::OpArg("b", 0)) + .Split(ctx->outputs(), i) + .Build(); + } + ctx->NewBuilder() + .Split(user_op::OpArg("b", 0), 0) + .Split(user_op::OpArg("a", 0), axis) + .Split(ctx->outputs(), axis) + .Build(); + return Maybe::Ok(); +} +/*static*/ auto FusedBiasAddGeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); + const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); + const auto bias_add_axis = ctx->Attr("axis"); + 
CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); + CHECK_GE_OR_RETURN(bias_add_axis, 0); + CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); + CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); + *ctx->OutputShape("dx", 0) = a_tensor_desc.shape(); + *ctx->OutputIsDynamic("dx", 0) = a_tensor_desc.is_dynamic(); + return Maybe::Ok(); +} -REGISTER_USER_OP("fused_bias_add_gelu_grad") - .Input("a") - .Input("b") - .Input("dy") - .Output("dx") - .Attr("axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); - const auto bias_add_axis = ctx->Attr("axis"); - CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); - CHECK_GE_OR_RETURN(bias_add_axis, 0); - CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); - CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); - *ctx->OutputShape("dx", 0) = a_tensor_desc.shape(); - *ctx->OutputIsDynamic("dx", 0) = a_tensor_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - *ctx->OutputDType("dx", 0) = a_tensor_desc.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto axis = ctx->Attr("axis"); - for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); - ++i) { - if (i == axis) { continue; } - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Broadcast(user_op::OpArg("b", 0)) - .Split(ctx->outputs(), i) - .Build(); - } - ctx->NewBuilder() - .Split(user_op::OpArg("b", 0), 0) - .Split(user_op::OpArg("a", 0), axis) - .Split(user_op::OpArg("dy", 0), axis) - .Split(ctx->outputs(), axis) - .Build(); - return Maybe::Ok(); - }); +/*static*/ auto FusedBiasAddGeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + return FusedBiasAddGeluGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedBiasAddGeluGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); + *ctx->OutputDType("dx", 0) = a_tensor_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedBiasAddGeluGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const auto axis = ctx->Attr("axis"); + for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); + ++i) { + if (i == axis) { continue; } + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Broadcast(user_op::OpArg("b", 0)) + .Split(ctx->outputs(), i) + .Build(); + } + ctx->NewBuilder() + .Split(user_op::OpArg("b", 0), 0) + .Split(user_op::OpArg("a", 0), axis) + .Split(user_op::OpArg("dy", 0), axis) + .Split(ctx->outputs(), axis) + .Build(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fused_bias_add_gelu") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -141,91 +141,114 @@ REGISTER_USER_OP_GRAD("fused_bias_add_gelu") return Maybe::Ok(); }); -REGISTER_USER_OP("fused_bias_add_mask_scale") - .Input("a") - .Input("b") - .Input("mask") - .OptionalInput("_add_to_output") - .Output("out") - .Attr("axis") - .Attr("scale") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - const auto& 
mask_tensor_desc = ctx->InputTensorDesc("mask", 0); - const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); - const auto bias_add_axis = ctx->Attr("axis"); - CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); - CHECK_GE_OR_RETURN(bias_add_axis, 0); - CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); - CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); - CHECK_EQ_OR_RETURN(a_tensor_desc.shape(), mask_tensor_desc.shape()); - *ctx->OutputShape("out", 0) = a_tensor_desc.shape(); - *ctx->OutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - *ctx->OutputDType("out", 0) = a_tensor_desc.data_type(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); - CHECK_OR_RETURN(mask_modifier != nullptr); - mask_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto axis = ctx->Attr("axis"); - std::vector split_args; - split_args.emplace_back("a", 0); - split_args.emplace_back("mask", 0); - split_args.emplace_back("out", 0); - if (ctx->user_op_conf().has_input("_add_to_output", 0)) { - split_args.emplace_back("_add_to_output", 0); - } - for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); - ++i) { - if (i == axis) { continue; } - ctx->NewBuilder().Split(split_args, i).Broadcast(user_op::OpArg("b", 0)).Build(); - } - ctx->NewBuilder().Split(user_op::OpArg("b", 0), 0).Split(split_args, axis).Build(); - return Maybe::Ok(); - }); +/*static*/ auto FusedBiasAddMaskScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); + const auto& mask_tensor_desc = ctx->InputTensorDesc("mask", 0); + const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); + const auto bias_add_axis = ctx->Attr("axis"); + CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); + CHECK_GE_OR_RETURN(bias_add_axis, 0); + CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); + CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); + CHECK_EQ_OR_RETURN(a_tensor_desc.shape(), mask_tensor_desc.shape()); + *ctx->OutputShape("out", 0) = a_tensor_desc.shape(); + *ctx->OutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedBiasAddMaskScaleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + return FusedBiasAddMaskScaleOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedBiasAddMaskScaleOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); + *ctx->OutputDType("out", 0) = a_tensor_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedBiasAddMaskScaleOp::ModifyInputArg( + const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) + -> Maybe { + user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); + CHECK_OR_RETURN(mask_modifier != nullptr); + mask_modifier->set_requires_grad(false); + return Maybe::Ok(); +} +/*static*/ auto FusedBiasAddMaskScaleOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const auto axis = 
ctx->Attr("axis"); + std::vector split_args; + split_args.emplace_back("a", 0); + split_args.emplace_back("mask", 0); + split_args.emplace_back("out", 0); + if (ctx->user_op_conf().has_input("_add_to_output", 0)) { + split_args.emplace_back("_add_to_output", 0); + } + for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); + ++i) { + if (i == axis) { continue; } + ctx->NewBuilder().Split(split_args, i).Broadcast(user_op::OpArg("b", 0)).Build(); + } + ctx->NewBuilder().Split(user_op::OpArg("b", 0), 0).Split(split_args, axis).Build(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fused_bias_add_mask_scale") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, - user_op::AddOpFn AddOp) -> Maybe { + const user_op::AddOpFn& AddOp) -> Maybe { if (op.NeedGenGradTensor4OpInput("a", 0) || op.NeedGenGradTensor4OpInput("b", 0)) { - user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_gelu_grad"); - user_op::UserOpConfWrapper dropout_grad_op = - builder.Op("dropout_grad") - .Input("dy", op.GetGradTensorWithOpOutput("out", 0)) - .Input("mask", op.input("mask", 0)) - .Output("dx") - .Attr("scale", op.attr("scale")) - .Build(); - AddOp(dropout_grad_op); + float scale = op.attr("scale"); + if (scale != 1.0) { + user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_dropout_grad"); + user_op::UserOpConfWrapper dropout_grad_op = + builder.Op("dropout_grad") + .Input("dy", op.GetGradTensorWithOpOutput("out", 0)) + .Input("mask", op.input("mask", 0)) + .Output("dx") + .Attr("scale", scale) + .Build(); + AddOp(dropout_grad_op); - if (op.NeedGenGradTensor4OpInput("a", 0)) { - op.BindGradTensorWithOpInput(dropout_grad_op.output("dx", 0), "a", 0); - } - if (op.NeedGenGradTensor4OpInput("b", 0)) { - const int64_t num_axes = op.TensorDesc4ArgNameAndIndex("a", 0).shape().NumAxes(); - const int32_t bias_add_axis = op.attr("axis"); - std::vector reduce_axes_vec; - FOR_RANGE(int64_t, i, 0, num_axes) { - if (i != bias_add_axis) { reduce_axes_vec.emplace_back(i); } + if (op.NeedGenGradTensor4OpInput("a", 0)) { + op.BindGradTensorWithOpInput(dropout_grad_op.output("dx", 0), "a", 0); + } + if (op.NeedGenGradTensor4OpInput("b", 0)) { + const int64_t num_axes = op.TensorDesc4ArgNameAndIndex("a", 0).shape().NumAxes(); + const int32_t bias_add_axis = op.attr("axis"); + std::vector reduce_axes_vec; + FOR_RANGE(int64_t, i, 0, num_axes) { + if (i != bias_add_axis) { reduce_axes_vec.emplace_back(i); } + } + user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); + auto grad_op = builder.Op("reduce_sum") + .Input("input_tensor", dropout_grad_op.output("dx", 0)) + .Output("output_tensor") + .Attr("axis", reduce_axes_vec) + .Attr("keepdims", false) + .Build(); + AddOp(grad_op); + op.BindGradTensorWithOpInput(grad_op.output("output_tensor", 0), "b", 0); + } + } else { + // When dropout_prob = 0.0, scale = 1.0, here we directly use out grad. 
+ if (op.NeedGenGradTensor4OpInput("a", 0)) { + op.BindGradTensorWithOpInput(op.GetGradTensorWithOpOutput("out", 0), "a", 0); + } + if (op.NeedGenGradTensor4OpInput("b", 0)) { + const int64_t num_axes = op.TensorDesc4ArgNameAndIndex("a", 0).shape().NumAxes(); + const int32_t bias_add_axis = op.attr("axis"); + std::vector reduce_axes_vec; + FOR_RANGE(int64_t, i, 0, num_axes) { + if (i != bias_add_axis) { reduce_axes_vec.emplace_back(i); } + } + user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); + auto grad_op = builder.Op("reduce_sum") + .Input("input_tensor", op.GetGradTensorWithOpOutput("out", 0)) + .Output("output_tensor") + .Attr("axis", reduce_axes_vec) + .Attr("keepdims", false) + .Build(); + AddOp(grad_op); + op.BindGradTensorWithOpInput(grad_op.output("output_tensor", 0), "b", 0); } - user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); - auto grad_op = builder.Op("reduce_sum") - .Input("input_tensor", dropout_grad_op.output("dx", 0)) - .Output("output_tensor") - .Attr("axis", reduce_axes_vec) - .Attr("keepdims", false) - .Build(); - AddOp(grad_op); - op.BindGradTensorWithOpInput(grad_op.output("output_tensor", 0), "b", 0); } } return Maybe::Ok(); diff --git a/oneflow/user/ops/fused_cast_scale_op.cpp b/oneflow/user/ops/fused_cast_scale_op.cpp index e09f91c9375..816a10efb06 100644 --- a/oneflow/user/ops/fused_cast_scale_op.cpp +++ b/oneflow/user/ops/fused_cast_scale_op.cpp @@ -14,11 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { -Maybe TensorDescInfer(user_op::InferContext* ctx) { +Maybe FusedCastScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); CHECK_EQ_OR_RETURN(scale_by_tensor.shape().NumAxes(), 1); @@ -29,14 +29,18 @@ Maybe TensorDescInfer(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe DataTypeInfer(user_op::InferContext* ctx) { +Maybe FusedCastScaleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return FusedCastScaleOp::InferLogicalTensorDesc(ctx); +} + +Maybe FusedCastScaleOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); *y->mut_data_type() = scale_by_tensor.data_type(); return Maybe::Ok(); } -Maybe GetSbpSignatures(user_op::SbpContext* ctx) { +Maybe FusedCastScaleOp::GetSbp(user_op::SbpContext* ctx) { const auto& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); for (int i = 0; i < x.shape().NumAxes(); ++i) { ctx->NewBuilder() @@ -58,14 +62,4 @@ Maybe GetSbpSignatures(user_op::SbpContext* ctx) { return Maybe::Ok(); } -REGISTER_USER_OP("fused_cast_scale") - .Input("x") - .Input("scale_by_tensor") - .Output("y") - .Attr("scale", 1.0) - .SetTensorDescInferFn(TensorDescInfer) - .SetGetSbpFn(GetSbpSignatures) - .SetDataTypeInferFn(DataTypeInfer); - -} // namespace } // namespace oneflow diff --git a/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp b/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp index 32bf8db067d..028d218359e 100644 --- a/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp +++ b/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp @@ -14,106 +14,102 @@ See the License for the specific 
language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); + CHECK_OR_RETURN(x_desc.shape() == mask_desc.shape()); + *ctx->OutputShape("y", 0) = x_desc.shape(); + *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); + *ctx->OutputShape("softmax_y", 0) = x_desc.shape(); + *ctx->OutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + return FusedScaleMaskSoftmaxDropoutOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); + CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); + *ctx->OutputDType("y", 0) = x_desc.data_type(); + *ctx->OutputDType("softmax_y", 0) = x_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::ModifyInputArg( + const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) + -> Maybe { + user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); + user_op::InputArgModifier* dropout_mask_modifier = GetInputArgModifierFn("dropout_mask", 0); + CHECK_OR_RETURN(mask_modifier != nullptr); + CHECK_OR_RETURN(dropout_mask_modifier != nullptr); + mask_modifier->set_requires_grad(false); + dropout_mask_modifier->set_requires_grad(false); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2); + FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), axis) + .Split(user_op::OpArg("mask", 0), axis) + .Split(user_op::OpArg("dropout_mask", 0), axis) + .Split(user_op::OpArg("y", 0), axis) + .Split(user_op::OpArg("softmax_y", 0), axis) + .Build(); + } + return Maybe::Ok(); +} -REGISTER_USER_OP("fused_scale_mask_softmax_dropout") - .Input("x") - .Input("mask") - .Input("dropout_mask") - .Output("y") - .Output("softmax_y") - .Attr("scale_value", 1.0) - .Attr("mask_fill_value", 0.) 
- .Attr("dropout_scale_value", 1.0) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_OR_RETURN(x_desc.shape() == mask_desc.shape()); - *ctx->OutputShape("y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); - *ctx->OutputShape("softmax_y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); - *ctx->OutputDType("y", 0) = x_desc.data_type(); - *ctx->OutputDType("softmax_y", 0) = x_desc.data_type(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); - user_op::InputArgModifier* dropout_mask_modifier = GetInputArgModifierFn("dropout_mask", 0); - CHECK_OR_RETURN(mask_modifier != nullptr); - CHECK_OR_RETURN(dropout_mask_modifier != nullptr); - mask_modifier->set_requires_grad(false); - dropout_mask_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2); - FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), axis) - .Split(user_op::OpArg("mask", 0), axis) - .Split(user_op::OpArg("dropout_mask", 0), axis) - .Split(user_op::OpArg("y", 0), axis) - .Split(user_op::OpArg("softmax_y", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); - -REGISTER_USER_OP("fused_scale_mask_softmax_dropout_grad") - .Input("softmax_y") - .Input("dy") - .Input("mask") - .Input("dropout_mask") - .Output("dx") - .Attr("scale_value") - .Attr("dropout_scale_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_EQ_OR_RETURN(dy_desc.shape(), softmax_y_desc.shape()); - CHECK_OR_RETURN(dy_desc.shape() == mask_desc.shape()); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - *dx_desc->mut_shape() = dy_desc.shape(); - *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_OR_RETURN(dy_desc.data_type() == softmax_y_desc.data_type()); - CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - *dx_desc->mut_data_type() = dy_desc.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); - 
CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2); - FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { - ctx->NewBuilder() - .Split(user_op::OpArg("softmax_y", 0), axis) - .Split(user_op::OpArg("dy", 0), axis) - .Split(user_op::OpArg("mask", 0), axis) - .Split(user_op::OpArg("dropout_mask", 0), axis) - .Split(user_op::OpArg("dx", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); + CHECK_EQ_OR_RETURN(dy_desc.shape(), softmax_y_desc.shape()); + CHECK_OR_RETURN(dy_desc.shape() == mask_desc.shape()); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + *dx_desc->mut_shape() = dy_desc.shape(); + *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + return FusedScaleMaskSoftmaxDropoutGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); + CHECK_OR_RETURN(dy_desc.data_type() == softmax_y_desc.data_type()); + CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + *dx_desc->mut_data_type() = dy_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::GetSbp(user_op::SbpContext* ctx) + -> Maybe { + const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); + CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2); + FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { + ctx->NewBuilder() + .Split(user_op::OpArg("softmax_y", 0), axis) + .Split(user_op::OpArg("dy", 0), axis) + .Split(user_op::OpArg("mask", 0), axis) + .Split(user_op::OpArg("dropout_mask", 0), axis) + .Split(user_op::OpArg("dx", 0), axis) + .Build(); + } + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fused_scale_mask_softmax_dropout") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -136,6 +132,4 @@ REGISTER_USER_OP_GRAD("fused_scale_mask_softmax_dropout") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/fused_scale_mask_softmax_op.cpp b/oneflow/user/ops/fused_scale_mask_softmax_op.cpp index 578caa3b495..685c62071f2 100644 --- a/oneflow/user/ops/fused_scale_mask_softmax_op.cpp +++ b/oneflow/user/ops/fused_scale_mask_softmax_op.cpp @@ -14,92 +14,91 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ auto FusedScaleMaskSoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); + CHECK_OR_RETURN(x_desc.shape() == mask_desc.shape()); + *ctx->OutputShape("y", 0) = x_desc.shape(); + *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + return FusedScaleMaskSoftmaxOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedScaleMaskSoftmaxOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); + CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); + *ctx->OutputDType("y", 0) = x_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxOp::ModifyInputArg( + const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) + -> Maybe { + user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); + CHECK_OR_RETURN(mask_modifier != nullptr); + mask_modifier->set_requires_grad(false); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2); + FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), axis) + .Split(user_op::OpArg("mask", 0), axis) + .Split(user_op::OpArg("y", 0), axis) + .Build(); + } + return Maybe::Ok(); +} -REGISTER_USER_OP("fused_scale_mask_softmax") - .Input("x") - .Input("mask") - .Output("y") - .Attr("scale_value", 1.0) - .Attr("mask_fill_value", 0.) 
- .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_OR_RETURN(x_desc.shape() == mask_desc.shape()); - *ctx->OutputShape("y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); - *ctx->OutputDType("y", 0) = x_desc.data_type(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); - CHECK_OR_RETURN(mask_modifier != nullptr); - mask_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2); - FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), axis) - .Split(user_op::OpArg("mask", 0), axis) - .Split(user_op::OpArg("y", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); - -REGISTER_USER_OP("fused_scale_mask_softmax_grad") - .Input("y") - .Input("dy") - .Input("mask") - .Output("dx") - .Attr("scale_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& y_desc = ctx->InputTensorDesc("y", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_EQ_OR_RETURN(dy_desc.shape(), y_desc.shape()); - CHECK_OR_RETURN(y_desc.shape() == mask_desc.shape()); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - *dx_desc->mut_shape() = dy_desc.shape(); - *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& y_desc = ctx->InputTensorDesc("y", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_OR_RETURN(dy_desc.data_type() == y_desc.data_type()); - CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - *dx_desc->mut_data_type() = dy_desc.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); - CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2); - FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), axis) - .Split(user_op::OpArg("dy", 0), axis) - .Split(user_op::OpArg("mask", 0), axis) - .Split(user_op::OpArg("dx", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ auto FusedScaleMaskSoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& y_desc = ctx->InputTensorDesc("y", 0); + const user_op::TensorDesc& mask_desc = 
ctx->InputTensorDesc("mask", 0); + CHECK_EQ_OR_RETURN(dy_desc.shape(), y_desc.shape()); + CHECK_OR_RETURN(y_desc.shape() == mask_desc.shape()); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + *dx_desc->mut_shape() = dy_desc.shape(); + *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + return FusedScaleMaskSoftmaxGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedScaleMaskSoftmaxGradOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& y_desc = ctx->InputTensorDesc("y", 0); + const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); + CHECK_OR_RETURN(dy_desc.data_type() == y_desc.data_type()); + CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + *dx_desc->mut_data_type() = dy_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); + CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2); + FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { + ctx->NewBuilder() + .Split(user_op::OpArg("y", 0), axis) + .Split(user_op::OpArg("dy", 0), axis) + .Split(user_op::OpArg("mask", 0), axis) + .Split(user_op::OpArg("dx", 0), axis) + .Build(); + } + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fused_scale_mask_softmax") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -119,6 +118,4 @@ REGISTER_USER_OP_GRAD("fused_scale_mask_softmax") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp index f0905fd8f98..20dead6c8d7 100644 --- a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp +++ b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp @@ -14,93 +14,88 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + *ctx->OutputShape("y", 0) = x_desc.shape(); + *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); + *ctx->OutputShape("softmax_y", 0) = x_desc.shape(); + *ctx->OutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + return FusedTrilScaleSoftmaxMaskScaleOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + *ctx->OutputDType("y", 0) = x_desc.data_type(); + *ctx->OutputDType("softmax_y", 0) = x_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::ModifyInputArg( + const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) + -> Maybe { + user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); + CHECK_OR_RETURN(mask_modifier != nullptr); + mask_modifier->set_requires_grad(false); + return Maybe::Ok(); +} +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2); + FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), axis) + .Split(user_op::OpArg("mask", 0), axis) + .Split(user_op::OpArg("y", 0), axis) + .Split(user_op::OpArg("softmax_y", 0), axis) + .Build(); + } + return Maybe::Ok(); +} -namespace { - -REGISTER_USER_OP("fused_tril_scale_softmax_mask_scale") - .Input("x") - .Input("mask") - .Output("y") - .Output("softmax_y") - .Attr("diagonal") - .Attr("tril_fill_value", 0) - .Attr("tril_scale_value", 1.0) - .Attr("mask_scale_value", 1.0) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - *ctx->OutputShape("y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); - *ctx->OutputShape("softmax_y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - *ctx->OutputDType("y", 0) = x_desc.data_type(); - *ctx->OutputDType("softmax_y", 0) = x_desc.data_type(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); - CHECK_OR_RETURN(mask_modifier != nullptr); - mask_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2); - FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), axis) - .Split(user_op::OpArg("mask", 0), axis) - 
.Split(user_op::OpArg("y", 0), axis) - .Split(user_op::OpArg("softmax_y", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); - -REGISTER_USER_OP("fused_tril_scale_softmax_mask_scale_grad") - .Input("softmax_y") - .Input("dy") - .Input("mask") - .Output("dx") - .Attr("diagonal") - .Attr("tril_scale_value") - .Attr("mask_scale_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - CHECK_OR_RETURN(dy_desc.shape() == softmax_y_desc.shape()); - *dx_desc->mut_shape() = dy_desc.shape(); - *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - CHECK_OR_RETURN(dy_desc.data_type() == softmax_y_desc.data_type()); - *dx_desc->mut_data_type() = dy_desc.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); - CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2); - FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { - ctx->NewBuilder() - .Split(user_op::OpArg("softmax_y", 0), axis) - .Split(user_op::OpArg("dy", 0), axis) - .Split(user_op::OpArg("mask", 0), axis) - .Split(user_op::OpArg("dx", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + CHECK_OR_RETURN(dy_desc.shape() == softmax_y_desc.shape()); + *dx_desc->mut_shape() = dy_desc.shape(); + *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + return FusedTrilScaleSoftmaxMaskScaleGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + CHECK_OR_RETURN(dy_desc.data_type() == softmax_y_desc.data_type()); + *dx_desc->mut_data_type() = dy_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::GetSbp(user_op::SbpContext* ctx) + -> Maybe { + const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); + CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2); + FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { + ctx->NewBuilder() + .Split(user_op::OpArg("softmax_y", 0), axis) + .Split(user_op::OpArg("dy", 0), axis) + .Split(user_op::OpArg("mask", 0), axis) + .Split(user_op::OpArg("dx", 0), axis) + .Build(); + } + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fused_tril_scale_softmax_mask_scale") 
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -123,6 +118,4 @@ REGISTER_USER_OP_GRAD("fused_tril_scale_softmax_mask_scale") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp index 0748daa1b22..232a78189c9 100644 --- a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp +++ b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp @@ -14,110 +14,113 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("fused_self_attention_query_mul_key_and_value") - .Input("hidden_states") - .Output("query_mul_key") - .Output("value") - .Attr("head_size") - .Attr("alpha") - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const DataType& dtype = ctx->InputDType("hidden_states", 0); - *ctx->OutputDType("query_mul_key", 0) = dtype; - *ctx->OutputDType("value", 0) = dtype; - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_OR_RETURN(!(ctx->InputIsDynamic("hidden_states", 0))); - int64_t head_size = ctx->Attr("head_size"); - const Shape& hidden_states_shape = ctx->InputShape("hidden_states", 0); - // hidden_states_shape (seq_len, batch_size, hidden_size) - // layout is (seq_len, batch_size, num_heads, 3, head_size) - // for example shape (1024, 4, 12, 3, 64) -> (1024, 4, 12, 192) which stride is (9216, 2304, - // 192, 1) - CHECK_EQ_OR_RETURN(hidden_states_shape.NumAxes(), 3); - int64_t seq_len = hidden_states_shape.At(0); - int64_t batch_size = hidden_states_shape.At(1); - int64_t hidden_size = hidden_states_shape.At(2); - CHECK_EQ_OR_RETURN(hidden_size % (head_size * 3), 0); - int64_t num_heads = hidden_size / (head_size * 3); +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + const DataType& dtype = ctx->InputDType("hidden_states", 0); + *ctx->OutputDType("query_mul_key", 0) = dtype; + *ctx->OutputDType("value", 0) = dtype; + return Maybe::Ok(); +} +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + CHECK_OR_RETURN(!(ctx->InputIsDynamic("hidden_states", 0))); + int64_t head_size = ctx->Attr("head_size"); + const Shape& hidden_states_shape = ctx->InputShape("hidden_states", 0); + // hidden_states_shape (seq_len, batch_size, hidden_size) + // layout is (seq_len, batch_size, num_heads, 3, head_size) + // for example shape (1024, 4, 12, 3, 64) -> (1024, 4, 12, 192) which stride is (9216, 2304, + // 192, 1) + CHECK_EQ_OR_RETURN(hidden_states_shape.NumAxes(), 3); + int64_t seq_len = hidden_states_shape.At(0); + int64_t batch_size = hidden_states_shape.At(1); + int64_t hidden_size = hidden_states_shape.At(2); + CHECK_EQ_OR_RETURN(hidden_size % (head_size * 3), 0); + int64_t num_heads = hidden_size / (head_size * 3); - *ctx->OutputShape("query_mul_key", 0) = Shape({batch_size, num_heads, seq_len, seq_len}); - *ctx->OutputShape("value", 0) = Shape({batch_size, num_heads, seq_len, head_size}); + *ctx->OutputShape("query_mul_key", 0) = Shape({batch_size, num_heads, seq_len, seq_len}); + *ctx->OutputShape("value", 0) = Shape({batch_size, num_heads, seq_len, head_size}); - return Maybe::Ok(); - }) - 
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("hidden_states", 0), 1) - .Split(user_op::OpArg("query_mul_key", 0), 0) - .Split(user_op::OpArg("value", 0), 0) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("hidden_states", 0), 2) - .Split(user_op::OpArg("query_mul_key", 0), 1) - .Split(user_op::OpArg("value", 0), 1) - .Build(); - return Maybe::Ok(); - }); + return Maybe::Ok(); +} +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + return FusedSelfAttentionQueryMulKeyAndValueOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::GetSbp(user_op::SbpContext* ctx) + -> Maybe { + ctx->NewBuilder() + .Split(user_op::OpArg("hidden_states", 0), 1) + .Split(user_op::OpArg("query_mul_key", 0), 0) + .Split(user_op::OpArg("value", 0), 0) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("hidden_states", 0), 2) + .Split(user_op::OpArg("query_mul_key", 0), 1) + .Split(user_op::OpArg("value", 0), 1) + .Build(); + return Maybe::Ok(); +} -REGISTER_USER_OP("fused_self_attention_query_mul_key_and_value_grad") - .Input("query_mul_key_grad") - .Input("value_grad") - .Input("hidden_states") - .Output("hidden_states_grad") - .Attr("alpha") - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const DataType& dtype = ctx->InputDType("query_mul_key_grad", 0); - CHECK_EQ_OR_RETURN(ctx->InputDType("value_grad", 0), dtype); - *ctx->OutputDType("hidden_states_grad", 0) = dtype; - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_OR_RETURN(!(ctx->InputIsDynamic("query_mul_key_grad", 0))); - CHECK_OR_RETURN(!(ctx->InputIsDynamic("value_grad", 0))); - const Shape& h_shape = ctx->InputShape("hidden_states", 0); - const Shape& qmk_grad_shape = ctx->InputShape("query_mul_key_grad", 0); - const Shape& v_grad_shape = ctx->InputShape("value_grad", 0); - CHECK_EQ_OR_RETURN(h_shape.NumAxes(), 3); - CHECK_EQ_OR_RETURN(qmk_grad_shape.NumAxes(), 4); - CHECK_EQ_OR_RETURN(v_grad_shape.NumAxes(), 4); - // hidden_states shape (s, b, H) - int64_t seq_len = h_shape.At(0); - int64_t batch_size = h_shape.At(1); - int64_t hidden_size = h_shape.At(2); - // value grad shape (b, n, s, h) - int64_t num_heads = v_grad_shape.At(1); - int64_t head_size = v_grad_shape.At(3); - CHECK_EQ_OR_RETURN(v_grad_shape.At(0), batch_size); - CHECK_EQ_OR_RETURN(v_grad_shape.At(2), seq_len); - CHECK_EQ_OR_RETURN(hidden_size, num_heads * 3 * head_size); - // qmk grad shape (b, n, sq, sk) - CHECK_EQ_OR_RETURN(qmk_grad_shape.At(0), batch_size); - CHECK_EQ_OR_RETURN(qmk_grad_shape.At(1), num_heads); - CHECK_EQ_OR_RETURN(qmk_grad_shape.At(2), seq_len); - CHECK_EQ_OR_RETURN(qmk_grad_shape.At(3), seq_len); +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferDataType( + user_op::InferContext* ctx) -> Maybe { + const DataType& dtype = ctx->InputDType("query_mul_key_grad", 0); + CHECK_EQ_OR_RETURN(ctx->InputDType("value_grad", 0), dtype); + *ctx->OutputDType("hidden_states_grad", 0) = dtype; + return Maybe::Ok(); +} +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + CHECK_OR_RETURN(!(ctx->InputIsDynamic("query_mul_key_grad", 0))); + CHECK_OR_RETURN(!(ctx->InputIsDynamic("value_grad", 0))); + const Shape& h_shape = ctx->InputShape("hidden_states", 0); + const Shape& qmk_grad_shape = ctx->InputShape("query_mul_key_grad", 0); + const Shape& 
v_grad_shape = ctx->InputShape("value_grad", 0); + CHECK_EQ_OR_RETURN(h_shape.NumAxes(), 3); + CHECK_EQ_OR_RETURN(qmk_grad_shape.NumAxes(), 4); + CHECK_EQ_OR_RETURN(v_grad_shape.NumAxes(), 4); + // hidden_states shape (s, b, H) + int64_t seq_len = h_shape.At(0); + int64_t batch_size = h_shape.At(1); + int64_t hidden_size = h_shape.At(2); + // value grad shape (b, n, s, h) + int64_t num_heads = v_grad_shape.At(1); + int64_t head_size = v_grad_shape.At(3); + CHECK_EQ_OR_RETURN(v_grad_shape.At(0), batch_size); + CHECK_EQ_OR_RETURN(v_grad_shape.At(2), seq_len); + CHECK_EQ_OR_RETURN(hidden_size, num_heads * 3 * head_size); + // qmk grad shape (b, n, sq, sk) + CHECK_EQ_OR_RETURN(qmk_grad_shape.At(0), batch_size); + CHECK_EQ_OR_RETURN(qmk_grad_shape.At(1), num_heads); + CHECK_EQ_OR_RETURN(qmk_grad_shape.At(2), seq_len); + CHECK_EQ_OR_RETURN(qmk_grad_shape.At(3), seq_len); - *ctx->OutputShape("hidden_states_grad", 0) = h_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("query_mul_key_grad", 0), 0) - .Split(user_op::OpArg("value_grad", 0), 0) - .Split(user_op::OpArg("hidden_states", 0), 1) - .Split(user_op::OpArg("hidden_states_grad", 0), 1) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("query_mul_key_grad", 0), 1) - .Split(user_op::OpArg("value_grad", 0), 1) - .Split(user_op::OpArg("hidden_states", 0), 2) - .Split(user_op::OpArg("hidden_states_grad", 0), 2) - .Build(); - return Maybe::Ok(); - }); + *ctx->OutputShape("hidden_states_grad", 0) = h_shape; + return Maybe::Ok(); +} +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + return FusedSelfAttentionQueryMulKeyAndValueGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::GetSbp(user_op::SbpContext* ctx) + -> Maybe { + ctx->NewBuilder() + .Split(user_op::OpArg("query_mul_key_grad", 0), 0) + .Split(user_op::OpArg("value_grad", 0), 0) + .Split(user_op::OpArg("hidden_states", 0), 1) + .Split(user_op::OpArg("hidden_states_grad", 0), 1) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("query_mul_key_grad", 0), 1) + .Split(user_op::OpArg("value_grad", 0), 1) + .Split(user_op::OpArg("hidden_states", 0), 2) + .Split(user_op::OpArg("hidden_states_grad", 0), 2) + .Build(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fused_self_attention_query_mul_key_and_value") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/gather_op.cpp b/oneflow/user/ops/gather_op.cpp index 47045ef4c0c..87ded29ab9c 100644 --- a/oneflow/user/ops/gather_op.cpp +++ b/oneflow/user/ops/gather_op.cpp @@ -14,80 +14,79 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("gather") - .Input("in") - .Input("indices") - .Output("out") - .Attr("axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0); - const int64_t axis = ctx->Attr("axis"); - const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); - CHECK_GT_OR_RETURN(indices.shape().NumAxes(), 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); +/*static*/ auto GatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0); + const int64_t axis = ctx->Attr("axis"); + const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); + CHECK_GT_OR_RETURN(indices.shape().NumAxes(), 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - DimVector dim_vec; - dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), - in.shape().dim_vec().cbegin() + axis); - dim_vec.insert(dim_vec.end(), indices.shape().dim_vec().cbegin(), - indices.shape().dim_vec().cend()); - dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin() + axis + 1, - in.shape().dim_vec().end()); - *out->mut_shape() = Shape(dim_vec); - out->set_is_dynamic(indices.is_dynamic() || in.is_dynamic()); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK_OR_RETURN(indices_modifier != nullptr); - indices_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int64_t in_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); - const int64_t indices_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes(); - const int64_t gather_axis = ctx->Attr("axis"); - CHECK_GE_OR_RETURN(gather_axis, 0); - CHECK_LT_OR_RETURN(gather_axis, in_num_axes); - FOR_RANGE(int64_t, i, 0, indices_num_axes) { - ctx->NewBuilder() - .Split(user_op::OpArg("indices", 0), i) - .Broadcast(user_op::OpArg("in", 0)) - .Split(user_op::OpArg("out", 0), gather_axis + i) - .Build(); - } - FOR_RANGE(int64_t, i, 0, in_num_axes) { - if (i == gather_axis) { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("indices", 0)) - .Split(user_op::OpArg("in", 0), i) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - } else { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("indices", 0)) - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i < gather_axis ? 
i : i + indices_num_axes - 1) - .Build(); - } - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(IsIndexDataType(indices.data_type())); - *out->mut_data_type() = in.data_type(); - return Maybe::Ok(); - }); + DimVector dim_vec; + dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), + in.shape().dim_vec().cbegin() + axis); + dim_vec.insert(dim_vec.end(), indices.shape().dim_vec().cbegin(), + indices.shape().dim_vec().cend()); + dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin() + axis + 1, + in.shape().dim_vec().end()); + *out->mut_shape() = Shape(dim_vec); + out->set_is_dynamic(indices.is_dynamic() || in.is_dynamic()); + return Maybe::Ok(); +} +/*static*/ auto GatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + return GatherOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto GatherOp::ModifyInputArg(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper&) -> Maybe { + user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); + CHECK_OR_RETURN(indices_modifier != nullptr); + indices_modifier->set_requires_grad(false); + return Maybe::Ok(); +} +/*static*/ auto GatherOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const int64_t in_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); + const int64_t indices_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes(); + const int64_t gather_axis = ctx->Attr("axis"); + CHECK_GE_OR_RETURN(gather_axis, 0); + CHECK_LT_OR_RETURN(gather_axis, in_num_axes); + FOR_RANGE(int64_t, i, 0, indices_num_axes) { + ctx->NewBuilder() + .Split(user_op::OpArg("indices", 0), i) + .Broadcast(user_op::OpArg("in", 0)) + .Split(user_op::OpArg("out", 0), gather_axis + i) + .Build(); + } + FOR_RANGE(int64_t, i, 0, in_num_axes) { + if (i == gather_axis) { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("indices", 0)) + .Split(user_op::OpArg("in", 0), i) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + } else { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("indices", 0)) + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i < gather_axis ? i : i + indices_num_axes - 1) + .Build(); + } + } + return Maybe::Ok(); +} +/*static*/ auto GatherOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(IsIndexDataType(indices.data_type())); + *out->mut_data_type() = in.data_type(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("gather").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/gelu_op.cpp b/oneflow/user/ops/gelu_op.cpp index d2f374052b5..39f12592c23 100644 --- a/oneflow/user/ops/gelu_op.cpp +++ b/oneflow/user/ops/gelu_op.cpp @@ -14,66 +14,63 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("gelu") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ auto GeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = in_shape; + return Maybe::Ok(); +} +/*static*/ auto GeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + return GeluOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto GeluOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ auto GeluOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("gelu_grad") - .Input("x") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("dy", 0)) - .PartialSum(user_op::OpArg("dx", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ auto GeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} +/*static*/ auto GeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + return GeluGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto GeluGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + 
FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("dy", 0)) + .PartialSum(user_op::OpArg("dx", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ auto GeluGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("gelu").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp b/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp index 5de4f2e6280..73b7dcb52eb 100644 --- a/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp +++ b/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp @@ -15,34 +15,32 @@ limitations under the License. */ #include #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("generate_random_batch_permutation_indices") - .Input("x") - .Output("y") - .Attr("seed") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("y", 0) = Shape({ctx->InputShape("x", 0).At(0)}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("y", 0)) - .Build(); - const auto& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Broadcast(user_op::OpArg("y", 0)) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = DataType::kInt32; - return Maybe::Ok(); - }); +/*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + *ctx->OutputShape("y", 0) = Shape({ctx->InputShape("x", 0).At(0)}); + return Maybe::Ok(); +} +/*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + return GenerateRandomBatchPermutationIndicesOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto GenerateRandomBatchPermutationIndicesOp::GetSbp(user_op::SbpContext* ctx) + -> Maybe { + ctx->NewBuilder().PartialSum(user_op::OpArg("x", 0)).Broadcast(user_op::OpArg("y", 0)).Build(); + const auto& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Broadcast(user_op::OpArg("y", 0)).Build(); + } + return Maybe::Ok(); +} +/*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + *ctx->OutputDType("y", 0) = DataType::kInt32; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/gpt_data_loader_op.cpp b/oneflow/user/ops/gpt_data_loader_op.cpp index a67f0d4077a..ab66be504e1 100644 --- a/oneflow/user/ops/gpt_data_loader_op.cpp +++ b/oneflow/user/ops/gpt_data_loader_op.cpp @@ -14,51 +14,42 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("megatron_gpt_mmap_data_loader") - .OptionalInput("iteration") - .Output("out") - .Attr("data_file_prefix") - .Attr("seq_length") - .Attr("label_length", 1) - .Attr("num_samples") - .Attr("batch_size") - .Attr("dtype") - .Attr>("split_sizes") - .Attr("split_index") - .Attr("shuffle") - .Attr("random_seed") - .Attr>("nd_sbp") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - int64_t batch_size = ctx->Attr("batch_size"); - int64_t sample_len = ctx->Attr("seq_length") + ctx->Attr("label_length"); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_shape() = Shape({batch_size, sample_len}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputTensorDesc("out", 0)->mut_data_type() = ctx->Attr("dtype"); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_split_parallel()->set_axis(0); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - if (!conf.has_input("iteration", 0)) { return Maybe::Ok(); } - user_op::InputArgModifier* input_modifier = GetInputArgModifierFn("iteration", 0); - CHECK_OR_RETURN(input_modifier != nullptr); - input_modifier->set_is_mutable(true); - input_modifier->set_requires_grad(false); - return Maybe::Ok(); - }); +/*static*/ auto MegatronGptMmapDataLoaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + int64_t batch_size = ctx->Attr("batch_size"); + int64_t sample_len = ctx->Attr("seq_length") + ctx->Attr("label_length"); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_shape() = Shape({batch_size, sample_len}); + return Maybe::Ok(); +} +/*static*/ auto MegatronGptMmapDataLoaderOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + *ctx->OutputTensorDesc("out", 0)->mut_data_type() = ctx->Attr("dtype"); + return Maybe::Ok(); +} +/*static*/ auto MegatronGptMmapDataLoaderOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} +/*static*/ auto MegatronGptMmapDataLoaderOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) + -> Maybe { + cfg::SbpParallel default_sbp; + default_sbp.mutable_split_parallel()->set_axis(0); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} +/*static*/ auto MegatronGptMmapDataLoaderOp::ModifyInputArg( + const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) -> Maybe { + if (!conf.has_input("iteration", 0)) { return Maybe::Ok(); } + user_op::InputArgModifier* input_modifier = GetInputArgModifierFn("iteration", 0); + CHECK_OR_RETURN(input_modifier != nullptr); + input_modifier->set_is_mutable(true); + input_modifier->set_requires_grad(false); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/grid_sample_op.cpp b/oneflow/user/ops/grid_sample_op.cpp index c415d858987..d9c81470a7d 100644 --- a/oneflow/user/ops/grid_sample_op.cpp +++ b/oneflow/user/ops/grid_sample_op.cpp @@ -14,13 +14,12 @@ See the License for the specific 
language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe CheckAttr(const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) { +Maybe GridSampleOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { bool pass_checked = true; std::stringstream err; err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": "; @@ -45,110 +44,103 @@ Maybe CheckAttr(const user_op::UserOpDefWrapper& def, } } -} // namespace - -REGISTER_USER_OP("grid_sample") - .Input("input") - .Input("grid") - .Output("output") - .Attr("interpolation_mode") - .Attr("padding_mode") - .Attr("align_corners") - .SetCheckAttrFn(CheckAttr) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& input = ctx->InputTensorDesc("input", 0); - const user_op::TensorDesc& grid = ctx->InputTensorDesc("grid", 0); - user_op::TensorDesc& output = *(ctx->OutputTensorDesc("output", 0)); - // Only support 4D or 5D input with NCHW layout - // For 4D grid: input = { N, C, H_in, W_in }, - // grid = { N, H_out, W_out, 2 } - // output = { N, C, H_out, W_out } - // For 5D grid: input = { N, C, D_in, H_in, W_in }, - // grid = { N, D_out, H_out, W_out, 3 } - // output = { N, C, D_out, H_out, W_out } - const Shape& input_shape = input.shape(); - const Shape& grid_shape = grid.shape(); +/*static*/ auto GridSampleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + const user_op::TensorDesc& input = ctx->InputTensorDesc("input", 0); + const user_op::TensorDesc& grid = ctx->InputTensorDesc("grid", 0); + user_op::TensorDesc& output = *(ctx->OutputTensorDesc("output", 0)); + // Only support 4D or 5D input with NCHW layout + // For 4D grid: input = { N, C, H_in, W_in }, + // grid = { N, H_out, W_out, 2 } + // output = { N, C, H_out, W_out } + // For 5D grid: input = { N, C, D_in, H_in, W_in }, + // grid = { N, D_out, H_out, W_out, 3 } + // output = { N, C, D_out, H_out, W_out } + const Shape& input_shape = input.shape(); + const Shape& grid_shape = grid.shape(); + + bool is_4d_input = true; + if (input_shape.NumAxes() == 4) { + CHECK_EQ_OR_RETURN(grid_shape.NumAxes(), 4) << "Grid and input MUST have same dimention"; + CHECK_EQ_OR_RETURN(grid_shape.At(3), 2) << "Grid shape MUST (N, H_out, W_out, 2)"; + is_4d_input = true; + } else if (input_shape.NumAxes() == 5) { + CHECK_EQ_OR_RETURN(grid_shape.NumAxes(), 5) << "Grid and input MUST have same dimention"; + CHECK_EQ_OR_RETURN(grid_shape.At(4), 3) << "Grid shape MUST (N, H_out, W_out, 3)"; + if (ctx->Attr("interpolation_mode") == "bicubic") { + oneflow::Error::CheckFailedError() << "Mode='bicubic' supports only 4-D input"; + } + is_4d_input = false; + } else { + CHECK_OR_RETURN(false) << "MUST be 4D or 5D input"; + } + *output.mut_is_dynamic() = grid.is_dynamic(); + if (is_4d_input) { + *(output.mut_shape()) = {input_shape.At(0), input_shape.At(1), grid_shape.At(1), + grid_shape.At(2)}; + } else { + *(output.mut_shape()) = {input_shape.At(0), input_shape.At(1), grid_shape.At(1), + grid_shape.At(2), grid_shape.At(3)}; + } + return Maybe::Ok(); +} +/*static*/ auto GridSampleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + return GridSampleOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto GridSampleOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + ctx->NewBuilder() + .Split(user_op::OpArg("input", 0), 0) + 
.Split(user_op::OpArg("grid", 0), 0) + .Split(user_op::OpArg("output", 0), 0) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("input", 0), 1) + .Broadcast(user_op::OpArg("grid", 0)) + .Split(user_op::OpArg("output", 0), 1) + .Build(); + return Maybe::Ok(); +} +/*static*/ auto GridSampleOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); + return Maybe::Ok(); +} - bool is_4d_input = true; - if (input_shape.NumAxes() == 4) { - CHECK_EQ_OR_RETURN(grid_shape.NumAxes(), 4) << "Grid and input MUST have same dimention"; - CHECK_EQ_OR_RETURN(grid_shape.At(3), 2) << "Grid shape MUST (N, H_out, W_out, 2)"; - is_4d_input = true; - } else if (input_shape.NumAxes() == 5) { - CHECK_EQ_OR_RETURN(grid_shape.NumAxes(), 5) << "Grid and input MUST have same dimention"; - CHECK_EQ_OR_RETURN(grid_shape.At(4), 3) << "Grid shape MUST (N, H_out, W_out, 3)"; - if (ctx->Attr("interpolation_mode") == "bicubic") { - oneflow::Error::CheckFailedError() << "Mode='bicubic' supports only 4-D input"; - } - is_4d_input = false; - } else { - CHECK_OR_RETURN(false) << "MUST be 4D or 5D input"; - } - *output.mut_is_dynamic() = grid.is_dynamic(); - if (is_4d_input) { - *(output.mut_shape()) = {input_shape.At(0), input_shape.At(1), grid_shape.At(1), - grid_shape.At(2)}; - } else { - *(output.mut_shape()) = {input_shape.At(0), input_shape.At(1), grid_shape.At(1), - grid_shape.At(2), grid_shape.At(3)}; - } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("input", 0), 0) - .Split(user_op::OpArg("grid", 0), 0) - .Split(user_op::OpArg("output", 0), 0) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("input", 0), 1) - .Broadcast(user_op::OpArg("grid", 0)) - .Split(user_op::OpArg("output", 0), 1) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); - return Maybe::Ok(); - }); +Maybe GridSampleGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return GridSampleOp::CheckAttr(def, conf); +} -REGISTER_USER_OP("grid_sample_grad") - .Input("doutput") - .Input("input") - .Input("grid") - .Output("dinput") - .Output("dgrid") - .Attr("interpolation_mode") - .Attr("padding_mode") - .Attr("align_corners") - .SetCheckAttrFn(CheckAttr) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *(ctx->OutputTensorDesc("dinput", 0)->mut_shape()) = ctx->InputTensorDesc("input", 0).shape(); - *(ctx->OutputTensorDesc("dgrid", 0)->mut_shape()) = ctx->InputTensorDesc("grid", 0).shape(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("doutput", 0), 0) - .Split(user_op::OpArg("input", 0), 0) - .Split(user_op::OpArg("grid", 0), 0) - .Split(user_op::OpArg("dinput", 0), 0) - .Split(user_op::OpArg("dgrid", 0), 0) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("doutput", 0), 1) - .Split(user_op::OpArg("input", 0), 1) - .Broadcast(user_op::OpArg("grid", 0)) - .Split(user_op::OpArg("dinput", 0), 1) - .Broadcast(user_op::OpArg("dgrid", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dinput", 0) = ctx->InputDType("input", 0); - *ctx->OutputDType("dgrid", 0) = ctx->InputDType("grid", 0); - return Maybe::Ok(); - }); +/*static*/ auto 
GridSampleGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + *(ctx->OutputTensorDesc("dinput", 0)->mut_shape()) = ctx->InputTensorDesc("input", 0).shape(); + *(ctx->OutputTensorDesc("dgrid", 0)->mut_shape()) = ctx->InputTensorDesc("grid", 0).shape(); + return Maybe::Ok(); +} +/*static*/ auto GridSampleGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + return GridSampleGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto GridSampleGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + ctx->NewBuilder() + .Split(user_op::OpArg("doutput", 0), 0) + .Split(user_op::OpArg("input", 0), 0) + .Split(user_op::OpArg("grid", 0), 0) + .Split(user_op::OpArg("dinput", 0), 0) + .Split(user_op::OpArg("dgrid", 0), 0) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("doutput", 0), 1) + .Split(user_op::OpArg("input", 0), 1) + .Broadcast(user_op::OpArg("grid", 0)) + .Split(user_op::OpArg("dinput", 0), 1) + .Broadcast(user_op::OpArg("dgrid", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ auto GridSampleGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + *ctx->OutputDType("dinput", 0) = ctx->InputDType("input", 0); + *ctx->OutputDType("dgrid", 0) = ctx->InputDType("grid", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("grid_sample") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/hardsigmoid_op.cpp b/oneflow/user/ops/hardsigmoid_op.cpp index cdf43671ab2..887614425ac 100644 --- a/oneflow/user/ops/hardsigmoid_op.cpp +++ b/oneflow/user/ops/hardsigmoid_op.cpp @@ -14,63 +14,64 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe HardsigmoidOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = in_shape; + return Maybe::Ok(); +} -REGISTER_USER_OP("hardsigmoid") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe HardsigmoidOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("hardsigmoid_grad") - .Input("x") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, 
x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe HardsigmoidOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe HardsigmoidOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe HardsigmoidGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe HardsigmoidGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe HardsigmoidGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe HardsigmoidGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("hardsigmoid") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -89,6 +90,4 @@ REGISTER_USER_OP_GRAD("hardsigmoid") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/hardswish_op.cpp b/oneflow/user/ops/hardswish_op.cpp index 45d0ebe230c..f7dfbc5c870 100644 --- a/oneflow/user/ops/hardswish_op.cpp +++ b/oneflow/user/ops/hardswish_op.cpp @@ -14,61 +14,62 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe HardswishOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("hardswish") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe HardswishOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("hardswish_grad") - .Input("x") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe HardswishOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe HardswishOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe HardswishGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe HardswishGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe HardswishGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe 
HardswishGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("hardswish") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -87,6 +88,4 @@ REGISTER_USER_OP_GRAD("hardswish") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/hardtanh_op.cpp b/oneflow/user/ops/hardtanh_op.cpp index 2962c49e99e..2d5208c7b0b 100644 --- a/oneflow/user/ops/hardtanh_op.cpp +++ b/oneflow/user/ops/hardtanh_op.cpp @@ -14,73 +14,70 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe HardtanhOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = in_shape; + double min_val = ctx->Attr("min_val"); + double max_val = ctx->Attr("max_val"); + CHECK_LE_OR_RETURN(min_val, max_val); + return Maybe::Ok(); +} -REGISTER_USER_OP("hardtanh") - .Input("in") - .Attr("min_val") - .Attr("max_val") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - double min_val = ctx->Attr("min_val"); - double max_val = ctx->Attr("max_val"); - CHECK_LE_OR_RETURN(min_val, max_val); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe HardtanhOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("hardtanh_grad") - .Input("y") - .Input("dy") - .Attr("min_val") - .Attr("max_val") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& y_shape = ctx->InputShape("y", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == y_shape); - *dx_shape = dy_shape; - double min_val = ctx->Attr("min_val"); - double max_val = ctx->Attr("max_val"); - CHECK_LE_OR_RETURN(min_val, max_val); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); - FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); - return Maybe::Ok(); - }); +/* static */ Maybe HardtanhOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& 
in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe HardtanhOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe HardtanhGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& y_shape = ctx->InputShape("y", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == y_shape); + *dx_shape = dy_shape; + double min_val = ctx->Attr("min_val"); + double max_val = ctx->Attr("max_val"); + CHECK_LE_OR_RETURN(min_val, max_val); + return Maybe::Ok(); +} + +/*static*/ Maybe HardtanhGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe HardtanhGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); + FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("y", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe HardtanhGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("hardtanh") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -101,6 +98,4 @@ REGISTER_USER_OP_GRAD("hardtanh") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/hierarchical_parallel_cast_op.cpp b/oneflow/user/ops/hierarchical_parallel_cast_op.cpp index 8ac12f2573a..8fa5b36bf49 100644 --- a/oneflow/user/ops/hierarchical_parallel_cast_op.cpp +++ b/oneflow/user/ops/hierarchical_parallel_cast_op.cpp @@ -15,64 +15,78 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("hierarchical_parallel_cast") - .Input("in") - .Output("out") - .Attr>("nd_sbp") - .Attr("grad_mode") - .Attr>("grad_nd_sbp") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - const auto& conf = ctx->user_op_conf().attr>("nd_sbp"); - CHECK_EQ_OR_RETURN(conf.size(), parallel_hierarchy.NumAxes()); - for (const std::string& sbp_str : conf) { - cfg::SbpParallel sbp_parallel; - CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str, &sbp_parallel)); - *in_distribution->add_sbp_parallel() = sbp_parallel; - *out_distribution->add_sbp_parallel() = sbp_parallel; - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_USER_OP("hierarchical_parallel_cast_like") - .Input("in") - .Input("like") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - cfg::NdSbp* like_distribution = ctx->NdSbp4ArgNameAndIndex("like", 0); - const cfg::NdSbp& hint_distribution = ctx->NdSbpHint4InputArgNameAndIndex("like", 0); - *in_distribution = hint_distribution; - *out_distribution = hint_distribution; - *like_distribution = hint_distribution; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe HierarchicalParallelCastOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe HierarchicalParallelCastOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe HierarchicalParallelCastOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe HierarchicalParallelCastOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + const auto& conf = ctx->user_op_conf().attr>("nd_sbp"); + CHECK_EQ_OR_RETURN(conf.size(), parallel_hierarchy.NumAxes()); + for (const std::string& sbp_str : conf) { + cfg::SbpParallel sbp_parallel; + 
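+    // Parse each "nd_sbp" entry (e.g. "S(0)", "B", "P"), one per hierarchy axis, and apply it to both the input and the output distribution.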
CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str, &sbp_parallel)); + *in_distribution->add_sbp_parallel() = sbp_parallel; + *out_distribution->add_sbp_parallel() = sbp_parallel; + } + return Maybe::Ok(); +} + +/* static */ Maybe HierarchicalParallelCastOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe HierarchicalParallelCastLikeOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe HierarchicalParallelCastLikeOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe HierarchicalParallelCastLikeOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe HierarchicalParallelCastLikeOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + cfg::NdSbp* like_distribution = ctx->NdSbp4ArgNameAndIndex("like", 0); + const cfg::NdSbp& hint_distribution = ctx->NdSbpHint4InputArgNameAndIndex("like", 0); + *in_distribution = hint_distribution; + *out_distribution = hint_distribution; + *like_distribution = hint_distribution; + return Maybe::Ok(); +} + +/* static */ Maybe HierarchicalParallelCastLikeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("hierarchical_parallel_cast") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/identity_op.cpp b/oneflow/user/ops/identity_op.cpp index 2e67cefc4c6..538abeb5dde 100644 --- a/oneflow/user/ops/identity_op.cpp +++ b/oneflow/user/ops/identity_op.cpp @@ -14,37 +14,36 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe IdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("identity") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe IdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe IdentityOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe IdentityOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("identity") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -62,6 +61,4 @@ REGISTER_USER_OP_GRAD("identity") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/image_batch_align_op.cpp b/oneflow/user/ops/image_batch_align_op.cpp index d6eaa08396c..0563281485b 100644 --- a/oneflow/user/ops/image_batch_align_op.cpp +++ b/oneflow/user/ops/image_batch_align_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -27,70 +28,71 @@ bool PowerOfTwo(T x) { } // namespace -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_batch_align") - .Input("in") - .Output("out") - .Attr("shape") - .Attr("data_type") - .Attr("alignment") - .Attr("dynamic_out") - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) -> Maybe { - bool check_failed = false; - std::stringstream err; - err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); - const Shape& shape = conf.attr("shape"); - if (shape.NumAxes() != 3) { - err << ", shape: " << shape.ToString() << " (image shape must has 3 axes)"; - check_failed = true; - } - DataType data_type = conf.attr("data_type"); - if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) { - err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)"; - check_failed = true; - } - int32_t alignment = conf.attr("alignment"); - if (alignment < 0) { - err << ", alignment: " << alignment << " (alignment must be greater than or equal to 0)"; - check_failed = true; - } else if (alignment != 0 && !PowerOfTwo(alignment)) { - err << ", alignment: " << alignment - << " (alignment must be power of 2 when it's not equal to 0)"; - check_failed = true; - } - if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1); - const Shape& shape_attr = ctx->Attr("shape"); - const bool dynamic_out = ctx->Attr("dynamic_out"); - DimVector dim_vec(shape_attr.NumAxes() + 1); - dim_vec.at(0) = in_desc.shape().elem_cnt(); - FOR_RANGE(int64_t, i, 0, shape_attr.NumAxes()) { dim_vec.at(i + 1) = shape_attr.At(i); } - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_shape() = Shape(dim_vec); - out_desc->set_is_dynamic(dynamic_out); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("out", 0); - CHECK_OR_RETURN(out_modifier != nullptr); - out_modifier->set_header_infered_before_compute(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_data_type() = ctx->Attr("data_type"); - return Maybe::Ok(); - }); +/* static */ Maybe ImageBatchAlignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1); + const Shape& shape_attr = ctx->Attr("shape"); + const bool dynamic_out = ctx->Attr("dynamic_out"); + DimVector dim_vec(shape_attr.NumAxes() + 1); + dim_vec.at(0) = in_desc.shape().elem_cnt(); + FOR_RANGE(int64_t, i, 0, shape_attr.NumAxes()) { dim_vec.at(i + 1) = shape_attr.At(i); } + user_op::TensorDesc* out_desc = 
ctx->OutputTensorDesc("out", 0); + *out_desc->mut_shape() = Shape(dim_vec); + out_desc->set_is_dynamic(dynamic_out); + return Maybe::Ok(); +} + +/*static*/ Maybe ImageBatchAlignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ImageBatchAlignOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ImageBatchAlignOp::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("out", 0); + CHECK_OR_RETURN(out_modifier != nullptr); + out_modifier->set_header_infered_before_compute(false); + return Maybe::Ok(); +} + +/* static */ Maybe ImageBatchAlignOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + bool check_failed = false; + std::stringstream err; + err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); + const Shape& shape = conf.attr("shape"); + if (shape.NumAxes() != 3) { + err << ", shape: " << shape.ToString() << " (image shape must has 3 axes)"; + check_failed = true; + } + DataType data_type = conf.attr("data_type"); + if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) { + err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)"; + check_failed = true; + } + int32_t alignment = conf.attr("alignment"); + if (alignment < 0) { + err << ", alignment: " << alignment << " (alignment must be greater than or equal to 0)"; + check_failed = true; + } else if (alignment != 0 && !PowerOfTwo(alignment)) { + err << ", alignment: " << alignment + << " (alignment must be power of 2 when it's not equal to 0)"; + check_failed = true; + } + if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } + return Maybe::Ok(); +} + +/* static */ Maybe ImageBatchAlignOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_data_type() = ctx->Attr("data_type"); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/image_decode_op.cpp b/oneflow/user/ops/image_decode_op.cpp index 7db39a67ae5..cd308ce528e 100644 --- a/oneflow/user/ops/image_decode_op.cpp +++ b/oneflow/user/ops/image_decode_op.cpp @@ -14,50 +14,53 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_decode") - .Input("in") - .Output("out") - .Attr("color_space", "BGR") - .Attr("data_type", DataType::kUInt8) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) -> Maybe { - bool check_failed = false; - std::stringstream err; - err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); - const std::string& color_space = conf.attr("color_space"); - if (color_space != "BGR" && color_space != "RGB" && color_space != "GRAY") { - err << ", color_space: " << color_space - << " (color_space can only be one of BGR, RGB and GRAY)"; - check_failed = true; - } - DataType data_type = conf.attr("data_type"); - if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) { - err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)"; - check_failed = true; - } - if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) >= 1); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_shape() = in_desc.shape(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }); +/* static */ Maybe ImageDecodeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) >= 1); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_shape() = in_desc.shape(); + return Maybe::Ok(); +} + +/*static*/ Maybe ImageDecodeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ImageDecodeOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ImageDecodeOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + bool check_failed = false; + std::stringstream err; + err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); + const std::string& color_space = conf.attr("color_space"); + if (color_space != "BGR" && color_space != "RGB" && color_space != "GRAY") { + err << ", color_space: " << color_space + << " (color_space can only be one of BGR, RGB and GRAY)"; + check_failed = true; + } + DataType data_type = conf.attr("data_type"); + if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) { + err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)"; + check_failed = true; + } + if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } + 
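+  // Both color_space and data_type passed validation.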
return Maybe::Ok(); +} + +/* static */ Maybe ImageDecodeOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_data_type() = DataType::kTensorBuffer; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/image_object_preprocess_ops.cpp b/oneflow/user/ops/image_object_preprocess_ops.cpp index 979e3ebf30c..5fd2cb99f38 100644 --- a/oneflow/user/ops/image_object_preprocess_ops.cpp +++ b/oneflow/user/ops/image_object_preprocess_ops.cpp @@ -14,209 +14,243 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { -Maybe GetSbp(user_op::SbpContext* ctx) { +Maybe ImageObjectGetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); return Maybe::Ok(); } } // namespace -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_flip") - .Input("in") - .Input("flip_code") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1); - const int N = in_desc.shape().elem_cnt(); - - const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); - CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); - - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetSbp) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in_desc.data_type(), DataType::kTensorBuffer); - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("object_bbox_flip") - .Input("bbox") - .Input("image_size") - .Input("flip_code") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); - CHECK_EQ_OR_RETURN(bbox_desc.shape().NumAxes(), 1); - const int N = bbox_desc.shape().elem_cnt(); - - const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); - CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); - - const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); - CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); - - *ctx->OutputShape("out", 0) = ctx->InputShape("bbox", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetSbp) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); - CHECK_EQ_OR_RETURN(bbox_desc.data_type(), DataType::kTensorBuffer); - const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); - CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); - const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); - CHECK_EQ_OR_RETURN(flip_code_desc.data_type(), DataType::kInt8); - *ctx->OutputDType("out", 0) = ctx->InputDType("bbox", 0); - return Maybe::Ok(); - }); - 
-REGISTER_NO_GRAD_CPU_ONLY_USER_OP("object_bbox_scale") - .Input("bbox") - .Input("scale") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); - CHECK_EQ_OR_RETURN(bbox_desc.shape().NumAxes(), 1); - const int N = bbox_desc.shape().elem_cnt(); - - const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); - CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); - - *ctx->OutputShape("out", 0) = ctx->InputShape("bbox", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetSbp) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); - CHECK_EQ_OR_RETURN(bbox_desc.data_type(), DataType::kTensorBuffer); - const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); - CHECK_EQ_OR_RETURN(scale_desc.data_type(), DataType::kFloat); - *ctx->OutputDType("out", 0) = ctx->InputDType("bbox", 0); - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("object_segmentation_polygon_flip") - .Input("poly") - .Input("image_size") - .Input("flip_code") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); - CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); - const int N = poly_desc.shape().elem_cnt(); - - const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); - CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); - - const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); - CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); - - *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetSbp) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); - CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer); - const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); - CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); - const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); - CHECK_EQ_OR_RETURN(flip_code_desc.data_type(), DataType::kInt8); - *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("object_segmentation_polygon_scale") - .Input("poly") - .Input("scale") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); - CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); - const int N = poly_desc.shape().elem_cnt(); - - const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); - CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); - - *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetSbp) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); - CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer); - const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); - 
CHECK_EQ_OR_RETURN(scale_desc.data_type(), DataType::kFloat); - *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_normalize") - .Input("in") - .Attr>("std") - .Attr>("mean") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetSbp) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in_desc.data_type(), DataType::kTensorBuffer); - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("object_segmentation_polygon_to_mask") - .Input("poly") - .Input("poly_index") - .Input("image_size") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); - CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); - const int N = poly_desc.shape().elem_cnt(); - - const user_op::TensorDesc& poly_index_desc = ctx->InputTensorDesc("poly_index", 0); - CHECK_EQ_OR_RETURN(poly_index_desc.shape().NumAxes(), 1); - CHECK_EQ_OR_RETURN(poly_index_desc.shape().elem_cnt(), N); - - const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); - CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); - - *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetSbp) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); - CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer); - const user_op::TensorDesc& poly_index_desc = ctx->InputTensorDesc("poly_index", 0); - CHECK_EQ_OR_RETURN(poly_index_desc.data_type(), DataType::kTensorBuffer); - const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); - CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); - *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); - return Maybe::Ok(); - }); +/* static */ Maybe ImageFlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1); + const int N = in_desc.shape().elem_cnt(); + + const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); + CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); + + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe ImageFlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ImageFlipOp::GetSbp(user_op::SbpContext* ctx) { + return ImageObjectGetSbp(ctx); +} + +/* static */ Maybe ImageFlipOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in_desc.data_type(), DataType::kTensorBuffer); + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} 
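Every op converted in this hunk follows the same pattern: the builder-style registration is removed, the per-op interface is declared by the generated header included at the top of the file, and the .cpp keeps only the static method bodies. The snippet below is a minimal standalone sketch, not the generated code itself; "ExampleImageOp" is a hypothetical name used only to illustrate the four hooks that every definition in this file fills in.

// Standalone sketch (assumption, not part of the patch): roughly the kind of declaration
// the generated "op_generated.h" is assumed to provide, so that the /*static*/ definitions
// in this file have a class to attach to.
#include "oneflow/core/framework/framework.h"

namespace oneflow {

class ExampleImageOp {
 public:
  static Maybe<void> InferLogicalTensorDesc(user_op::InferContext* ctx);
  static Maybe<void> InferPhysicalTensorDesc(user_op::InferContext* ctx);  // typically delegates to the logical version
  static Maybe<void> GetSbp(user_op::SbpContext* ctx);
  static Maybe<void> InferDataType(user_op::InferContext* ctx);
};

}  // namespace oneflow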
+ +/* static */ Maybe ObjectBboxFlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); + CHECK_EQ_OR_RETURN(bbox_desc.shape().NumAxes(), 1); + const int N = bbox_desc.shape().elem_cnt(); + + const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); + CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); + + const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); + CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); + + *ctx->OutputShape("out", 0) = ctx->InputShape("bbox", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe ObjectBboxFlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ObjectBboxFlipOp::GetSbp(user_op::SbpContext* ctx) { + return ImageObjectGetSbp(ctx); +} + +/* static */ Maybe ObjectBboxFlipOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); + CHECK_EQ_OR_RETURN(bbox_desc.data_type(), DataType::kTensorBuffer); + const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); + CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); + const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); + CHECK_EQ_OR_RETURN(flip_code_desc.data_type(), DataType::kInt8); + *ctx->OutputDType("out", 0) = ctx->InputDType("bbox", 0); + return Maybe::Ok(); +} + +/* static */ Maybe ObjectBboxScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); + CHECK_EQ_OR_RETURN(bbox_desc.shape().NumAxes(), 1); + const int N = bbox_desc.shape().elem_cnt(); + + const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); + CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); + + *ctx->OutputShape("out", 0) = ctx->InputShape("bbox", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe ObjectBboxScaleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ObjectBboxScaleOp::GetSbp(user_op::SbpContext* ctx) { + return ImageObjectGetSbp(ctx); +} + +/* static */ Maybe ObjectBboxScaleOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); + CHECK_EQ_OR_RETURN(bbox_desc.data_type(), DataType::kTensorBuffer); + const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); + CHECK_EQ_OR_RETURN(scale_desc.data_type(), DataType::kFloat); + *ctx->OutputDType("out", 0) = ctx->InputDType("bbox", 0); + return Maybe::Ok(); +} + +/* static */ Maybe ObjectSegmentationPolygonFlipOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); + CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); + const int N = poly_desc.shape().elem_cnt(); + + const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); + CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); + + const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); + CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); + + *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 
0); + return Maybe::Ok(); +} + +/*static*/ Maybe ObjectSegmentationPolygonFlipOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ObjectSegmentationPolygonFlipOp::GetSbp(user_op::SbpContext* ctx) { + return ImageObjectGetSbp(ctx); +} + +/* static */ Maybe ObjectSegmentationPolygonFlipOp::InferDataType( + user_op::InferContext* ctx) { + const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); + CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer); + const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); + CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); + const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); + CHECK_EQ_OR_RETURN(flip_code_desc.data_type(), DataType::kInt8); + *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); + return Maybe::Ok(); +} + +/* static */ Maybe ObjectSegmentationPolygonScaleOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); + CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); + const int N = poly_desc.shape().elem_cnt(); + + const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); + CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); + + *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe ObjectSegmentationPolygonScaleOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ObjectSegmentationPolygonScaleOp::GetSbp(user_op::SbpContext* ctx) { + return ImageObjectGetSbp(ctx); +} + +/* static */ Maybe ObjectSegmentationPolygonScaleOp::InferDataType( + user_op::InferContext* ctx) { + const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); + CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer); + const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); + CHECK_EQ_OR_RETURN(scale_desc.data_type(), DataType::kFloat); + *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); + return Maybe::Ok(); +} + +/* static */ Maybe ImageNormalizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1); + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe ImageNormalizeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ImageNormalizeOp::GetSbp(user_op::SbpContext* ctx) { + return ImageObjectGetSbp(ctx); +} + +/* static */ Maybe ImageNormalizeOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in_desc.data_type(), DataType::kTensorBuffer); + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe ObjectSegmentationPolygonToMaskOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); + CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); + const int N = poly_desc.shape().elem_cnt(); + + const user_op::TensorDesc& poly_index_desc = 
ctx->InputTensorDesc("poly_index", 0); + CHECK_EQ_OR_RETURN(poly_index_desc.shape().NumAxes(), 1); + CHECK_EQ_OR_RETURN(poly_index_desc.shape().elem_cnt(), N); + + const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); + CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); + + *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe ObjectSegmentationPolygonToMaskOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ObjectSegmentationPolygonToMaskOp::GetSbp(user_op::SbpContext* ctx) { + return ImageObjectGetSbp(ctx); +} + +/* static */ Maybe ObjectSegmentationPolygonToMaskOp::InferDataType( + user_op::InferContext* ctx) { + const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); + CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer); + const user_op::TensorDesc& poly_index_desc = ctx->InputTensorDesc("poly_index", 0); + CHECK_EQ_OR_RETURN(poly_index_desc.data_type(), DataType::kTensorBuffer); + const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); + CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); + *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/image_preprocess_ops.cpp b/oneflow/user/ops/image_preprocess_ops.cpp index a244a4ecab4..91bb5cee58f 100644 --- a/oneflow/user/ops/image_preprocess_ops.cpp +++ b/oneflow/user/ops/image_preprocess_ops.cpp @@ -17,236 +17,224 @@ limitations under the License. #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/user/image/image_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("crop_mirror_normalize_from_tensorbuffer") - .Input("in") - .OptionalInput("mirror") - .Output("out") - .Attr("color_space", "BGR") - .Attr("output_layout", "NCHW") - .Attr>("mean", {0.0}) - .Attr>("std", {1.0}) - .Attr("crop_h", 0) - .Attr("crop_w", 0) - .Attr("crop_pos_x", 0.5) - .Attr("crop_pos_y", 0.5) - .Attr("output_dtype", DataType::kFloat) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - bool has_mirror = ctx->has_input("mirror", 0); - if (has_mirror) { - const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); - CHECK_OR_RETURN(mirror_tensor.shape().NumAxes() == 1 - && in_tensor.shape().At(0) == mirror_tensor.shape().At(0)); - } - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - int64_t N = in_tensor.shape().At(0); - int64_t H = ctx->Attr("crop_h"); - int64_t W = ctx->Attr("crop_w"); - std::string color_space = ctx->Attr("color_space"); - int64_t C = ImageUtil::IsColor(color_space) ? 
3 : 1; - - CHECK_OR_RETURN(H != 0 && W != 0); - CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1); - std::string output_layout = ctx->Attr("output_layout"); - if (output_layout == "NCHW") { - *out_tensor->mut_shape() = Shape({N, C, H, W}); - } else if (output_layout == "NHWC") { - *out_tensor->mut_shape() = Shape({N, H, W, C}); - } else { - return Error::CheckFailedError() - << "output_layout: " << output_layout << " is not supported"; - } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in_tensor.data_type(), DataType::kTensorBuffer); - bool has_mirror = ctx->has_input("mirror", 0); - if (has_mirror) { - const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); - CHECK_EQ_OR_RETURN(mirror_tensor.data_type(), DataType::kInt8); - } +/* static */ Maybe CropMirrorNormalizeFromTensorbufferOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + bool has_mirror = ctx->has_input("mirror", 0); + if (has_mirror) { + const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); + CHECK_OR_RETURN(mirror_tensor.shape().NumAxes() == 1 + && in_tensor.shape().At(0) == mirror_tensor.shape().At(0)); + } + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + int64_t N = in_tensor.shape().At(0); + int64_t H = ctx->Attr("crop_h"); + int64_t W = ctx->Attr("crop_w"); + std::string color_space = ctx->Attr("color_space"); + int64_t C = ImageUtil::IsColor(color_space) ? 3 : 1; - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - DataType output_dtype = ctx->Attr("output_dtype"); - CHECK_EQ_OR_RETURN(output_dtype, - DataType::kFloat); // only support float now; for float16 in future - *out_tensor->mut_data_type() = output_dtype; - - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_USER_OP("crop_mirror_normalize_from_uint8") - .Input("in") - .OptionalInput("mirror") - .Output("out") - .Attr("color_space", "BGR") - .Attr("output_layout", "NCHW") - .Attr>("mean", {0.0}) - .Attr>("std", {1.0}) - .Attr("crop_h", 0) - .Attr("crop_w", 0) - .Attr("crop_pos_x", 0.5) - .Attr("crop_pos_y", 0.5) - .Attr("output_dtype", DataType::kFloat) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - bool has_mirror = ctx->has_input("mirror", 0); - if (has_mirror) { - const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); - CHECK_OR_RETURN(mirror_tensor.shape().NumAxes() == 1 - && in_tensor.shape().At(0) == mirror_tensor.shape().At(0)); - } - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - int64_t N = in_tensor.shape().At(0); - int64_t H = ctx->Attr("crop_h"); - int64_t W = ctx->Attr("crop_w"); - std::string color_space = ctx->Attr("color_space"); - int64_t C = ImageUtil::IsColor(color_space) ? 
3 : 1; - CHECK_EQ_OR_RETURN(in_tensor.shape().NumAxes(), 4); // {N, H, W, C} - CHECK_EQ_OR_RETURN(in_tensor.shape().At(3), C); - if (H == 0 || W == 0) { - H = in_tensor.shape().At(1); - W = in_tensor.shape().At(2); - } else { - H = std::min(H, in_tensor.shape().At(1)); - W = std::min(W, in_tensor.shape().At(2)); - } - std::string output_layout = ctx->Attr("output_layout"); - if (output_layout == "NCHW") { - *out_tensor->mut_shape() = Shape({N, C, H, W}); - } else if (output_layout == "NHWC") { - *out_tensor->mut_shape() = Shape({N, H, W, C}); - } else { - return Error::CheckFailedError() - << "output_layout: " << output_layout << " is not supported"; - } + CHECK_OR_RETURN(H != 0 && W != 0); + CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1); + std::string output_layout = ctx->Attr("output_layout"); + if (output_layout == "NCHW") { + *out_tensor->mut_shape() = Shape({N, C, H, W}); + } else if (output_layout == "NHWC") { + *out_tensor->mut_shape() = Shape({N, H, W, C}); + } else { + return Error::CheckFailedError() << "output_layout: " << output_layout << " is not supported"; + } + return Maybe::Ok(); +} - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in_tensor.data_type(), DataType::kUInt8); - bool has_mirror = ctx->has_input("mirror", 0); - if (has_mirror) { - const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); - CHECK_EQ_OR_RETURN(mirror_tensor.data_type(), DataType::kInt8); - } - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - DataType output_dtype = ctx->Attr("output_dtype"); - CHECK_EQ_OR_RETURN(output_dtype, - DataType::kFloat); // only support float now; for float16 in future - *out_tensor->mut_data_type() = output_dtype; - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("coin_flip") - .Output("out") - .Attr("probability", 0.5) - .Attr("batch_size") - .Attr("seed", -1) - .Attr("has_seed", false) - .Attr>("nd_sbp") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - int64_t batch_size = ctx->Attr("batch_size"); - *out_tensor->mut_shape() = Shape({batch_size}); - return Maybe::Ok(); - }) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - int64_t batch_size = ctx->Attr("batch_size"); - const ParallelContext& parallel_ctx = ctx->parallel_ctx(); - const cfg::SbpParallel& out_sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0); - if (parallel_ctx.parallel_num() > 1 && out_sbp.has_split_parallel()) { - BalancedSplitter bs(batch_size, parallel_ctx.parallel_num()); - *out_tensor->mut_shape() = Shape({bs.At(parallel_ctx.parallel_id()).size()}); - } else { - *out_tensor->mut_shape() = Shape({batch_size}); - } - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const Shape& hierarchy = ctx->parallel_hierarchy(); - cfg::NdSbp* output_dist = ctx->NdSbp4ArgNameAndIndex("out", 0); - // the input may be produced by tick which should be broadcast parallel dist - std::vector inputs_dist; - for (const auto& arg_pair : ctx->inputs()) { - inputs_dist.emplace_back(ctx->NdSbp4ArgNameAndIndex(arg_pair.first, arg_pair.second)); 
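Every hunk that follows applies the same conversion: the REGISTER_*_USER_OP(...) builder chain with its Set*Fn lambdas is deleted, and the identical inference logic moves into static methods of a per-op class (CropMirrorNormalizeFromTensorbufferOp, CoinFlipOp, ImageRandomCropOp, ...) pulled in through the newly added "oneflow/core/framework/op_generated.h" include. A minimal sketch of the target shape, using a hypothetical ExampleOp that is not part of this patch:

/* static */ Maybe<void> ExampleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  // Shape inference on the logical (global) tensor; most ops forward the physical
  // variant to this method.
  *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
  return Maybe<void>::Ok();
}
/* static */ Maybe<void> ExampleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);
}
/* static */ Maybe<void> ExampleOp::GetSbp(user_op::SbpContext* ctx) {
  // Split input and output along the batch axis, mirroring the old SetGetSbpFn lambda.
  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();
  return Maybe<void>::Ok();
}
/* static */ Maybe<void> ExampleOp::InferDataType(user_op::InferContext* ctx) {
  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
  return Maybe<void>::Ok();
}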
+/*static*/ Maybe<void> CropMirrorNormalizeFromTensorbufferOp::InferPhysicalTensorDesc(
+    user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> CropMirrorNormalizeFromTensorbufferOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> CropMirrorNormalizeFromTensorbufferOp::InferDataType(
+    user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0);
+  CHECK_EQ_OR_RETURN(in_tensor.data_type(), DataType::kTensorBuffer);
+  bool has_mirror = ctx->has_input("mirror", 0);
+  if (has_mirror) {
+    const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0);
+    CHECK_EQ_OR_RETURN(mirror_tensor.data_type(), DataType::kInt8);
+  }
+
+  user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
+  DataType output_dtype = ctx->Attr<DataType>("output_dtype");
+  CHECK_EQ_OR_RETURN(output_dtype,
+                     DataType::kFloat);  // only support float now; for float16 in future
+  *out_tensor->mut_data_type() = output_dtype;
+
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> CropMirrorNormalizeFromUint8Op::InferLogicalTensorDesc(
+    user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0);
+  bool has_mirror = ctx->has_input("mirror", 0);
+  if (has_mirror) {
+    const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0);
+    CHECK_OR_RETURN(mirror_tensor.shape().NumAxes() == 1
+                    && in_tensor.shape().At(0) == mirror_tensor.shape().At(0));
+  }
+  user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
+  int64_t N = in_tensor.shape().At(0);
+  int64_t H = ctx->Attr<int64_t>("crop_h");
+  int64_t W = ctx->Attr<int64_t>("crop_w");
+  std::string color_space = ctx->Attr<std::string>("color_space");
+  int64_t C = ImageUtil::IsColor(color_space) ?
3 : 1; + CHECK_EQ_OR_RETURN(in_tensor.shape().NumAxes(), 4); // {N, H, W, C} + CHECK_EQ_OR_RETURN(in_tensor.shape().At(3), C); + if (H == 0 || W == 0) { + H = in_tensor.shape().At(1); + W = in_tensor.shape().At(2); + } else { + H = std::min(H, in_tensor.shape().At(1)); + W = std::min(W, in_tensor.shape().At(2)); + } + std::string output_layout = ctx->Attr("output_layout"); + if (output_layout == "NCHW") { + *out_tensor->mut_shape() = Shape({N, C, H, W}); + } else if (output_layout == "NHWC") { + *out_tensor->mut_shape() = Shape({N, H, W, C}); + } else { + return Error::CheckFailedError() << "output_layout: " << output_layout << " is not supported"; + } + + return Maybe::Ok(); +} + +/*static*/ Maybe CropMirrorNormalizeFromUint8Op::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CropMirrorNormalizeFromUint8Op::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CropMirrorNormalizeFromUint8Op::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in_tensor.data_type(), DataType::kUInt8); + bool has_mirror = ctx->has_input("mirror", 0); + if (has_mirror) { + const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); + CHECK_EQ_OR_RETURN(mirror_tensor.data_type(), DataType::kInt8); + } + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + DataType output_dtype = ctx->Attr("output_dtype"); + CHECK_EQ_OR_RETURN(output_dtype, + DataType::kFloat); // only support float now; for float16 in future + *out_tensor->mut_data_type() = output_dtype; + return Maybe::Ok(); +} + +/* static */ Maybe CoinFlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + int64_t batch_size = ctx->Attr("batch_size"); + *out_tensor->mut_shape() = Shape({batch_size}); + return Maybe::Ok(); +} + +/* static */ Maybe CoinFlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + int64_t batch_size = ctx->Attr("batch_size"); + const ParallelContext& parallel_ctx = ctx->parallel_ctx(); + const cfg::SbpParallel& out_sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0); + if (parallel_ctx.parallel_num() > 1 && out_sbp.has_split_parallel()) { + BalancedSplitter bs(batch_size, parallel_ctx.parallel_num()); + *out_tensor->mut_shape() = Shape({bs.At(parallel_ctx.parallel_id()).size()}); + } else { + *out_tensor->mut_shape() = Shape({batch_size}); + } + return Maybe::Ok(); +} + +/* static */ Maybe CoinFlipOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("out", 0), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CoinFlipOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const Shape& hierarchy = ctx->parallel_hierarchy(); + cfg::NdSbp* output_dist = ctx->NdSbp4ArgNameAndIndex("out", 0); + // the input may be produced by tick which should be broadcast parallel dist + std::vector inputs_dist; + for (const auto& arg_pair : ctx->inputs()) { + inputs_dist.emplace_back(ctx->NdSbp4ArgNameAndIndex(arg_pair.first, arg_pair.second)); + } + const auto& dist_conf = ctx->user_op_conf().attr>("nd_sbp"); + if (dist_conf.size() == 0) { + FOR_RANGE(int, i, 0, hierarchy.NumAxes()) { + output_dist->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); + for (auto* 
input_dist : inputs_dist) { + input_dist->add_sbp_parallel()->mutable_broadcast_parallel(); } - const auto& dist_conf = ctx->user_op_conf().attr>("nd_sbp"); - if (dist_conf.size() == 0) { - FOR_RANGE(int, i, 0, hierarchy.NumAxes()) { - output_dist->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); - for (auto* input_dist : inputs_dist) { - input_dist->add_sbp_parallel()->mutable_broadcast_parallel(); - } - } - } else { - CHECK_EQ_OR_RETURN(dist_conf.size(), hierarchy.NumAxes()); - for (const std::string& sbp_str : dist_conf) { - cfg::SbpParallel sbp_parallel; - CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str, &sbp_parallel)); - CHECK_OR_RETURN( - (sbp_parallel.has_split_parallel() && sbp_parallel.split_parallel().axis() == 0) - || sbp_parallel.has_broadcast_parallel()); - *output_dist->add_sbp_parallel() = sbp_parallel; - for (auto* input_dist : inputs_dist) { - input_dist->add_sbp_parallel()->mutable_broadcast_parallel(); - } - } + } + } else { + CHECK_EQ_OR_RETURN(dist_conf.size(), hierarchy.NumAxes()); + for (const std::string& sbp_str : dist_conf) { + cfg::SbpParallel sbp_parallel; + CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str, &sbp_parallel)); + CHECK_OR_RETURN( + (sbp_parallel.has_split_parallel() && sbp_parallel.split_parallel().axis() == 0) + || sbp_parallel.has_broadcast_parallel()); + *output_dist->add_sbp_parallel() = sbp_parallel; + for (auto* input_dist : inputs_dist) { + input_dist->add_sbp_parallel()->mutable_broadcast_parallel(); } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("out", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_data_type() = DataType::kInt8; - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_random_crop") - .Input("in") - .Output("out") - .Attr("num_attempts", 10) - .Attr("seed", -1) - .Attr("has_seed", false) - .Attr>("random_area", {0.08, 1.0}) - .Attr>("random_aspect_ratio", {0.75, 1.333333}) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_shape() = in_tensor.shape(); - *out_tensor->mut_is_dynamic() = in_tensor.is_dynamic(); - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL_OR_RETURN(in_modifier); - in_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); - *ctx->OutputDType("out", 0) = in_tensor.data_type(); - return Maybe::Ok(); - }); + } + } + return Maybe::Ok(); +} + +/* static */ Maybe CoinFlipOp::InferDataType(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_data_type() = DataType::kInt8; + return Maybe::Ok(); +} + +/* static */ Maybe ImageRandomCropOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + 
user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_shape() = in_tensor.shape(); + *out_tensor->mut_is_dynamic() = in_tensor.is_dynamic(); + return Maybe::Ok(); +} + +/*static*/ Maybe ImageRandomCropOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ImageRandomCropOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); +} + +/* static */ Maybe ImageRandomCropOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); + CHECK_NOTNULL_OR_RETURN(in_modifier); + in_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe ImageRandomCropOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); + *ctx->OutputDType("out", 0) = in_tensor.data_type(); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/image_resize_ops.cpp b/oneflow/user/ops/image_resize_ops.cpp index eb030c17531..fe6f351ecaf 100644 --- a/oneflow/user/ops/image_resize_ops.cpp +++ b/oneflow/user/ops/image_resize_ops.cpp @@ -15,132 +15,130 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/image/image_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_resize_to_fixed") - .Input("in") - .Output("out") - .Output("scale") - .Attr("target_width", 0) - .Attr("target_height", 0) - .Attr("channels", 3) - .Attr("data_type", DataType::kUInt8) - .Attr("interpolation_type", "bilinear") - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) -> Maybe { - bool check_failed = false; - std::ostringstream err; - err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); - int64_t target_width = conf.attr("target_width"); - int64_t target_height = conf.attr("target_height"); - if (target_width <= 0 || target_height <= 0) { - err << ", target_width: " << target_width << ", target_height: " << target_height; - check_failed = true; - } - int64_t channels = conf.attr("channels"); - if (channels != 1 && channels != 3) { - err << ", channels: " << channels << " (channels can only be 1 or 3)"; - check_failed = true; - } - DataType data_type = conf.attr("data_type"); - if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) { - err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)"; - check_failed = true; - } - const std::string& interp_type = conf.attr("interpolation_type"); - if (!CheckInterpolationValid(interp_type, err)) { check_failed = true; } - if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().elem_cnt() > 0); - int64_t batch_size = in_tensor.shape().elem_cnt(); - int64_t target_width = ctx->Attr("target_width"); - int64_t target_height = ctx->Attr("target_height"); - int64_t channels = ctx->Attr("channels"); - - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_shape() 
= Shape({batch_size, target_height, target_width, channels}); - out_tensor->set_is_dynamic(in_tensor.is_dynamic()); - - user_op::TensorDesc* scale_tensor = ctx->OutputTensorDesc("scale", 0); - *scale_tensor->mut_shape() = Shape({batch_size, 2}); - scale_tensor->set_is_dynamic(in_tensor.is_dynamic()); - - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_data_type() = ctx->Attr("data_type"); - user_op::TensorDesc* scale_tensor = ctx->OutputTensorDesc("scale", 0); - *scale_tensor->mut_data_type() = DataType::kFloat; - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_resize_keep_aspect_ratio") - .Input("in") - .Output("out") - .Output("size") - .Output("scale") - .Attr("target_size") - .Attr("min_size", 0) - .Attr("max_size", 0) - .Attr("resize_longer", false) - .Attr("interpolation_type", "bilinear") - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) -> Maybe { - bool check_failed = false; - std::ostringstream err; - err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); - const int32_t target_size = conf.attr("target_size"); - const int32_t max_size = conf.attr("max_size"); - if (target_size <= 0) { - err << ", target_size: " << target_size << " (target_size must be greater than 0)"; - check_failed = true; - } - if (max_size < target_size && max_size > 0) { - err << ", max_size: " << max_size - << " (max_size must be greater than target_size or equal to 0)"; - check_failed = true; - } - const std::string& interp_type = conf.attr("interpolation_type"); - if (!CheckInterpolationValid(interp_type, err)) { check_failed = true; } - if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) > 0); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_shape() = in_desc.shape(); - user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); - *size_desc->mut_shape() = in_desc.shape(); - user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); - *scale_desc->mut_shape() = in_desc.shape(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); - *size_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); - *scale_desc->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }); +/* 
static */ Maybe ImageResizeToFixedOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().elem_cnt() > 0); + int64_t batch_size = in_tensor.shape().elem_cnt(); + int64_t target_width = ctx->Attr("target_width"); + int64_t target_height = ctx->Attr("target_height"); + int64_t channels = ctx->Attr("channels"); + + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_shape() = Shape({batch_size, target_height, target_width, channels}); + out_tensor->set_is_dynamic(in_tensor.is_dynamic()); + + user_op::TensorDesc* scale_tensor = ctx->OutputTensorDesc("scale", 0); + *scale_tensor->mut_shape() = Shape({batch_size, 2}); + scale_tensor->set_is_dynamic(in_tensor.is_dynamic()); + + return Maybe::Ok(); +} + +/*static*/ Maybe ImageResizeToFixedOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ImageResizeToFixedOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ImageResizeToFixedOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + bool check_failed = false; + std::ostringstream err; + err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); + int64_t target_width = conf.attr("target_width"); + int64_t target_height = conf.attr("target_height"); + if (target_width <= 0 || target_height <= 0) { + err << ", target_width: " << target_width << ", target_height: " << target_height; + check_failed = true; + } + int64_t channels = conf.attr("channels"); + if (channels != 1 && channels != 3) { + err << ", channels: " << channels << " (channels can only be 1 or 3)"; + check_failed = true; + } + DataType data_type = conf.attr("data_type"); + if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) { + err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)"; + check_failed = true; + } + const std::string& interp_type = conf.attr("interpolation_type"); + if (!CheckInterpolationValid(interp_type, err)) { check_failed = true; } + if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } + return Maybe::Ok(); +} + +/* static */ Maybe ImageResizeToFixedOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_data_type() = ctx->Attr("data_type"); + user_op::TensorDesc* scale_tensor = ctx->OutputTensorDesc("scale", 0); + *scale_tensor->mut_data_type() = DataType::kFloat; + return Maybe::Ok(); +} + +/* static */ Maybe ImageResizeKeepAspectRatioOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) > 0); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_shape() = in_desc.shape(); + user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); + *size_desc->mut_shape() = in_desc.shape(); + user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); + *scale_desc->mut_shape() = in_desc.shape(); + return Maybe::Ok(); +} + 
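Both resize ops keep their attribute validation in CheckAttr (ImageResizeToFixedOp above, ImageResizeKeepAspectRatioOp below): every offending attribute is appended to a single ostringstream and the op fails only once, so a bad configuration reports all of its problems at the same time. A standalone sketch of that accumulate-then-fail flow, with made-up attribute values that are not part of this patch:

#include <cstdint>
#include <iostream>
#include <sstream>

int main() {
  bool check_failed = false;
  std::ostringstream err;
  err << "Illegal attr value for image_resize_to_fixed op";
  int64_t target_width = 0;  // invalid on purpose
  int64_t channels = 2;      // invalid on purpose
  if (target_width <= 0) {
    err << ", target_width: " << target_width;
    check_failed = true;
  }
  if (channels != 1 && channels != 3) {
    err << ", channels: " << channels << " (channels can only be 1 or 3)";
    check_failed = true;
  }
  // In the op this becomes: return oneflow::Error::CheckFailedError() << err.str();
  if (check_failed) { std::cout << err.str() << std::endl; }
  return 0;
}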
+/*static*/ Maybe<void> ImageResizeKeepAspectRatioOp::InferPhysicalTensorDesc(
+    user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> ImageResizeKeepAspectRatioOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> ImageResizeKeepAspectRatioOp::CheckAttr(
+    const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) {
+  bool check_failed = false;
+  std::ostringstream err;
+  err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name();
+  const int32_t target_size = conf.attr<int32_t>("target_size");
+  const int32_t max_size = conf.attr<int32_t>("max_size");
+  if (target_size <= 0) {
+    err << ", target_size: " << target_size << " (target_size must be greater than 0)";
+    check_failed = true;
+  }
+  if (max_size < target_size && max_size > 0) {
+    err << ", max_size: " << max_size
+        << " (max_size must be greater than target_size or equal to 0)";
+    check_failed = true;
+  }
+  const std::string& interp_type = conf.attr<std::string>("interpolation_type");
+  if (!CheckInterpolationValid(interp_type, err)) { check_failed = true; }
+  if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> ImageResizeKeepAspectRatioOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0);
+  CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer);
+  user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
+  *out_desc->mut_data_type() = DataType::kTensorBuffer;
+  user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0);
+  *size_desc->mut_data_type() = DataType::kTensorBuffer;
+  user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0);
+  *scale_desc->mut_data_type() = DataType::kTensorBuffer;
+  return Maybe<void>::Ok();
+}
 } // namespace oneflow
diff --git a/oneflow/user/ops/image_target_resize_op.cpp b/oneflow/user/ops/image_target_resize_op.cpp
index dcf6e78ce85..49d7db09479 100644
--- a/oneflow/user/ops/image_target_resize_op.cpp
+++ b/oneflow/user/ops/image_target_resize_op.cpp
@@ -14,59 +14,60 @@ See the License for the specific language governing permissions and
 limitations under the License.
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_target_resize") - .Input("in") - .Output("out") - .Output("size") - .Output("scale") - .Attr("target_size") - .Attr("max_size") - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) -> Maybe { - bool check_failed = false; - std::stringstream err; - err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); - const int32_t target_size = conf.attr("target_size"); - const int32_t max_size = conf.attr("max_size"); - if (target_size <= 0) { - err << ", target_size: " << target_size << " (target_size must be greater than 0)"; - check_failed = true; - } - if (max_size < target_size) { - err << ", max_size: " << max_size << " (max_size must be greater than 0)"; - check_failed = true; - } - if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) >= 1); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_shape() = in_desc.shape(); - user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); - *size_desc->mut_shape() = Shape({in_desc.shape().elem_cnt(), 2}); - user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); - *scale_desc->mut_shape() = Shape({in_desc.shape().elem_cnt(), 2}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); - *size_desc->mut_data_type() = DataType::kInt32; - user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); - *scale_desc->mut_data_type() = DataType::kFloat; - return Maybe::Ok(); - }); +/* static */ Maybe ImageTargetResizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) >= 1); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_shape() = in_desc.shape(); + user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); + *size_desc->mut_shape() = Shape({in_desc.shape().elem_cnt(), 2}); + user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); + *scale_desc->mut_shape() = Shape({in_desc.shape().elem_cnt(), 2}); + return Maybe::Ok(); +} + +/*static*/ Maybe ImageTargetResizeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ImageTargetResizeOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ImageTargetResizeOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + bool check_failed 
= false; + std::stringstream err; + err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); + const int32_t target_size = conf.attr("target_size"); + const int32_t max_size = conf.attr("max_size"); + if (target_size <= 0) { + err << ", target_size: " << target_size << " (target_size must be greater than 0)"; + check_failed = true; + } + if (max_size < target_size) { + err << ", max_size: " << max_size << " (max_size must be greater than 0)"; + check_failed = true; + } + if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } + return Maybe::Ok(); +} + +/* static */ Maybe ImageTargetResizeOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_data_type() = DataType::kTensorBuffer; + user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); + *size_desc->mut_data_type() = DataType::kInt32; + user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); + *scale_desc->mut_data_type() = DataType::kFloat; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/in_top_k_op.cpp b/oneflow/user/ops/in_top_k_op.cpp index dc3f1a5858f..6ee9b5592e4 100644 --- a/oneflow/user/ops/in_top_k_op.cpp +++ b/oneflow/user/ops/in_top_k_op.cpp @@ -14,38 +14,40 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("in_top_k") - .Input("targets") - .Input("predictions") - .Attr("k") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); - const user_op::TensorDesc& predictions = ctx->InputTensorDesc("predictions", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_EQ_OR_RETURN(targets.shape().NumAxes(), 1); - CHECK_EQ_OR_RETURN(predictions.shape().NumAxes(), 2); - const bool is_dynamic = targets.is_dynamic(); - CHECK_EQ_OR_RETURN(is_dynamic, predictions.is_dynamic()); - out->set_is_dynamic(is_dynamic); - *out->mut_shape() = targets.shape(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); - CHECK_OR_RETURN(IsIndexDataType(targets.data_type())); - const user_op::TensorDesc& predictions = ctx->InputTensorDesc("predictions", 0); - CHECK_EQ_OR_RETURN(predictions.data_type(), DataType::kFloat); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = kInt8; - return Maybe::Ok(); - }); +/* static */ Maybe InTopKOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); + const user_op::TensorDesc& predictions = ctx->InputTensorDesc("predictions", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_EQ_OR_RETURN(targets.shape().NumAxes(), 1); + CHECK_EQ_OR_RETURN(predictions.shape().NumAxes(), 2); + const bool is_dynamic = targets.is_dynamic(); + CHECK_EQ_OR_RETURN(is_dynamic, predictions.is_dynamic()); + 
out->set_is_dynamic(is_dynamic); + *out->mut_shape() = targets.shape(); + return Maybe::Ok(); +} + +/*static*/ Maybe InTopKOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/* static */ Maybe InTopKOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); } + +/* static */ Maybe InTopKOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); + CHECK_OR_RETURN(IsIndexDataType(targets.data_type())); + const user_op::TensorDesc& predictions = ctx->InputTensorDesc("predictions", 0); + CHECK_EQ_OR_RETURN(predictions.data_type(), DataType::kFloat); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = kInt8; + return Maybe::Ok(); +} + +} // namespace oneflow diff --git a/oneflow/user/ops/indexed_slices_reduce_sum_op.cpp b/oneflow/user/ops/indexed_slices_reduce_sum_op.cpp index e5e9c035b5a..5b61c8ff2ba 100644 --- a/oneflow/user/ops/indexed_slices_reduce_sum_op.cpp +++ b/oneflow/user/ops/indexed_slices_reduce_sum_op.cpp @@ -14,42 +14,47 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("indexed_slices_reduce_sum") - .Input("x_indices") - .Input("x_values") - .Output("y_indices") - .Output("y_values") - .Output("num_unique") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_indices = ctx->InputTensorDesc("x_indices", 0); - const user_op::TensorDesc& x_values = ctx->InputTensorDesc("x_values", 0); - CHECK_LT_OR_RETURN(x_indices.shape().NumAxes(), x_values.shape().NumAxes()); - FOR_RANGE(int64_t, i, 0, x_indices.shape().NumAxes()) { - CHECK_EQ_OR_RETURN(x_indices.shape().At(i), x_values.shape().At(i)); - } - - const int64_t n = x_indices.shape().elem_cnt(); - const int64_t m = x_values.shape().elem_cnt() / n; - user_op::TensorDesc* y_indices = ctx->OutputTensorDesc("y_indices", 0); - user_op::TensorDesc* y_values = ctx->OutputTensorDesc("y_values", 0); - *y_indices = x_indices; - *y_indices->mut_shape() = Shape({n}); - *y_values = x_values; - *y_values->mut_shape() = Shape({n, m}); - user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); - *num_unique->mut_shape() = Shape({1}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_indices = ctx->InputTensorDesc("x_indices", 0); - CHECK_OR_RETURN(IsIndexDataType(x_indices.data_type())); - user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); - *num_unique->mut_data_type() = DataType::kInt64; - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe IndexedSlicesReduceSumOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& x_indices = ctx->InputTensorDesc("x_indices", 0); + const user_op::TensorDesc& x_values = ctx->InputTensorDesc("x_values", 0); + CHECK_LT_OR_RETURN(x_indices.shape().NumAxes(), x_values.shape().NumAxes()); + FOR_RANGE(int64_t, i, 0, x_indices.shape().NumAxes()) { + CHECK_EQ_OR_RETURN(x_indices.shape().At(i), x_values.shape().At(i)); + } + + const int64_t n = x_indices.shape().elem_cnt(); + const int64_t m = x_values.shape().elem_cnt() / n; + user_op::TensorDesc* y_indices 
= ctx->OutputTensorDesc("y_indices", 0); + user_op::TensorDesc* y_values = ctx->OutputTensorDesc("y_values", 0); + *y_indices = x_indices; + *y_indices->mut_shape() = Shape({n}); + *y_values = x_values; + *y_values->mut_shape() = Shape({n, m}); + user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); + *num_unique->mut_shape() = Shape({1}); + return Maybe::Ok(); +} + +/*static*/ Maybe IndexedSlicesReduceSumOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe IndexedSlicesReduceSumOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe IndexedSlicesReduceSumOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_indices = ctx->InputTensorDesc("x_indices", 0); + CHECK_OR_RETURN(IsIndexDataType(x_indices.data_type())); + user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); + *num_unique->mut_data_type() = DataType::kInt64; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/kl_div_op.cpp b/oneflow/user/ops/kl_div_op.cpp index fce4255f74d..cb58f29764b 100644 --- a/oneflow/user/ops/kl_div_op.cpp +++ b/oneflow/user/ops/kl_div_op.cpp @@ -16,10 +16,11 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/loss_op_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { -Maybe InferTensorDescFn(user_op::InferContext* ctx) { +Maybe KlInferTensorDescFn(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()); @@ -32,7 +33,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { +Maybe KlInferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()); @@ -69,48 +70,58 @@ Maybe InferGradDataType(user_op::InferContext* ctx) { } // namespace -REGISTER_USER_OP("kl_div_loss") - .Input("input") - .Input("target") - .Output("out") - .Attr("log_target") - .SetTensorDescInferFn(InferTensorDescFn) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); - CHECK_OR_RETURN(target_modifier != nullptr); - target_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn(InferDataType) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); - FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); - } - return Maybe::Ok(); - }); +/* static */ Maybe KlDivLossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return KlInferTensorDescFn(ctx); +} -REGISTER_USER_OP("kl_div_loss_grad") - .Input("input") - .Input("target") - .Input("dy") - .Output("dx") - .Attr("log_target") - .SetTensorDescInferFn(InferGradTensorDescFn) - .SetDataTypeInferFn(InferGradDataType) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const 
auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); - FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("input", 0), i) - .Split(user_op::OpArg("target", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe KlDivLossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe KlDivLossOp::GetSbp(user_op::SbpContext* ctx) { + const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); + FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe KlDivLossOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); + CHECK_OR_RETURN(target_modifier != nullptr); + target_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe KlDivLossOp::InferDataType(user_op::InferContext* ctx) { + return KlInferDataType(ctx); +} + +/* static */ Maybe KlDivLossGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferGradTensorDescFn(ctx); +} + +/*static*/ Maybe KlDivLossGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe KlDivLossGradOp::GetSbp(user_op::SbpContext* ctx) { + const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); + FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("input", 0), i) + .Split(user_op::OpArg("target", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe KlDivLossGradOp::InferDataType(user_op::InferContext* ctx) { + return InferGradDataType(ctx); +} REGISTER_USER_OP_GRAD("kl_div_loss") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp b/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp index c99910e689c..05affa22404 100644 --- a/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp +++ b/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 namespace oneflow {
@@ -38,20 +39,26 @@ Maybe<void> GetSbpSignatures(user_op::SbpContext* ctx) {
 } // namespace
-REGISTER_NO_GRAD_USER_OP("l1_l2_regularize_gradient")
-    .Input("model")
-    .Input("model_diff")
-    .Output("out")
-    .Attr<float>("l1", 0)
-    .Attr<float>("l2", 0)
-    .SetTensorDescInferFn(InferTensorDesc)
-    .SetGetSbpFn(GetSbpSignatures)
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0);
-      const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0);
-      CHECK_EQ_OR_RETURN(model_diff.data_type(), model.data_type());
-      *ctx->OutputDType("out", 0) = ctx->InputDType("model", 0);
-      return Maybe<void>::Ok();
-    });
+/* static */ Maybe<void> L1L2RegularizeGradientOp::InferLogicalTensorDesc(
+    user_op::InferContext* ctx) {
+  return InferTensorDesc(ctx);
+}
+
+/*static*/ Maybe<void> L1L2RegularizeGradientOp::InferPhysicalTensorDesc(
+    user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> L1L2RegularizeGradientOp::GetSbp(user_op::SbpContext* ctx) {
+  return GetSbpSignatures(ctx);
+}
+
+/* static */ Maybe<void> L1L2RegularizeGradientOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0);
+  const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0);
+  CHECK_EQ_OR_RETURN(model_diff.data_type(), model.data_type());
+  *ctx->OutputDType("out", 0) = ctx->InputDType("model", 0);
+  return Maybe<void>::Ok();
+}
 } // namespace oneflow
diff --git a/oneflow/user/ops/l2_normalize_op.cpp b/oneflow/user/ops/l2_normalize_op.cpp
index 1e8abab4651..d1723c41c97 100644
--- a/oneflow/user/ops/l2_normalize_op.cpp
+++ b/oneflow/user/ops/l2_normalize_op.cpp
@@ -14,98 +14,98 @@ See the License for the specific language governing permissions and
 limitations under the License.
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("l2_normalize") - .Input("x") - .Output("y") - .Output("square_x_sum") - .Attr("axis") - .Attr("epsilon") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - Shape* y_shape = ctx->OutputShape("y", 0); - Shape* square_x_sum_shape = ctx->OutputShape("square_x_sum", 0); - const int32_t axis = ctx->Attr("axis"); - const float epsilon = ctx->Attr("epsilon"); - CHECK_GE_OR_RETURN(axis, 0); - CHECK_LT_OR_RETURN(axis, x_shape.NumAxes()); - CHECK_GT_OR_RETURN(epsilon, 0); - *y_shape = x_shape; - *square_x_sum_shape = x_shape; - square_x_sum_shape->Set(axis, 1); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - const int32_t axis = ctx->Attr("axis"); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - if (i != axis) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("y", 0), i) - .Split(user_op::OpArg("square_x_sum", 0), i) - .Build(); - } - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("square_x_sum", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe L2NormalizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + Shape* y_shape = ctx->OutputShape("y", 0); + Shape* square_x_sum_shape = ctx->OutputShape("square_x_sum", 0); + const int32_t axis = ctx->Attr("axis"); + const float epsilon = ctx->Attr("epsilon"); + CHECK_GE_OR_RETURN(axis, 0); + CHECK_LT_OR_RETURN(axis, x_shape.NumAxes()); + CHECK_GT_OR_RETURN(epsilon, 0); + *y_shape = x_shape; + *square_x_sum_shape = x_shape; + square_x_sum_shape->Set(axis, 1); + return Maybe::Ok(); +} -REGISTER_USER_OP("l2_normalize_grad") - .Input("dy") - .Input("y") - .Input("square_x_sum") - .Output("dx") - .Attr("axis") - .Attr("epsilon") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - const Shape& y_shape = ctx->InputShape("y", 0); - const Shape& square_x_sum_shape = ctx->InputShape("square_x_sum", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - const int32_t axis = ctx->Attr("axis"); - const float epsilon = ctx->Attr("epsilon"); - CHECK_EQ_OR_RETURN(dy_shape, y_shape); - CHECK_GE_OR_RETURN(axis, 0); - CHECK_LT_OR_RETURN(axis, dy_shape.NumAxes()); - CHECK_GT_OR_RETURN(epsilon, 0); - FOR_RANGE(int32_t, i, 0, dy_shape.NumAxes()) { - if (i == axis) { - CHECK_EQ_OR_RETURN(square_x_sum_shape.At(i), 1); - } else { - CHECK_EQ_OR_RETURN(square_x_sum_shape.At(i), dy_shape.At(i)); - } - } - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); - const int32_t axis = ctx->Attr("axis"); - FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { - if (i != axis) { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("square_x_sum", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), 
ctx->InputDType("dy", 0)); - CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("square_x_sum", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe L2NormalizeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe L2NormalizeOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + const int32_t axis = ctx->Attr("axis"); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + if (i != axis) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("y", 0), i) + .Split(user_op::OpArg("square_x_sum", 0), i) + .Build(); + } + } + return Maybe::Ok(); +} + +/* static */ Maybe L2NormalizeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("square_x_sum", 0) = ctx->InputDType("x", 0); + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} + +/* static */ Maybe L2NormalizeGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + const Shape& y_shape = ctx->InputShape("y", 0); + const Shape& square_x_sum_shape = ctx->InputShape("square_x_sum", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + const int32_t axis = ctx->Attr("axis"); + const float epsilon = ctx->Attr("epsilon"); + CHECK_EQ_OR_RETURN(dy_shape, y_shape); + CHECK_GE_OR_RETURN(axis, 0); + CHECK_LT_OR_RETURN(axis, dy_shape.NumAxes()); + CHECK_GT_OR_RETURN(epsilon, 0); + FOR_RANGE(int32_t, i, 0, dy_shape.NumAxes()) { + if (i == axis) { + CHECK_EQ_OR_RETURN(square_x_sum_shape.At(i), 1); + } else { + CHECK_EQ_OR_RETURN(square_x_sum_shape.At(i), dy_shape.At(i)); + } + } + *dx_shape = dy_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe L2NormalizeGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe L2NormalizeGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); + const int32_t axis = ctx->Attr("axis"); + FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { + if (i != axis) { + ctx->NewBuilder() + .Split(user_op::OpArg("y", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("square_x_sum", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + } + return Maybe::Ok(); +} + +/* static */ Maybe L2NormalizeGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("dy", 0)); + CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("square_x_sum", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("l2_normalize") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/layer_norm_op.cpp b/oneflow/user/ops/layer_norm_op.cpp index 4f4bc2ae3f1..e014c8fddcb 100644 --- a/oneflow/user/ops/layer_norm_op.cpp +++ b/oneflow/user/ops/layer_norm_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -40,228 +41,215 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) { } // namespace -REGISTER_USER_OP("layer_norm") - .Input("x") - .OptionalInput("beta") - .OptionalInput("gamma") - .Output("y") - .Output("mean") - .Output("inv_variance") - .Attr("center") - .Attr("scale") - .Attr("begin_norm_axis") - .Attr("begin_params_axis") - .Attr("epsilon") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); - user_op::TensorDesc* mean = ctx->OutputTensorDesc("mean", 0); - user_op::TensorDesc* inv_variance = ctx->OutputTensorDesc("inv_variance", 0); - const bool center = ctx->Attr("center"); - const bool scale = ctx->Attr("scale"); - const int64_t begin_params_axis = - ShiftNegativeAxisIfNeed(x.shape(), ctx->Attr("begin_params_axis")); - *y->mut_shape() = x.shape(); - *y->mut_is_dynamic() = x.is_dynamic(); - DimVector param_shape_dim_vec; - param_shape_dim_vec.insert(param_shape_dim_vec.end(), - x.shape().dim_vec().cbegin() + begin_params_axis, - x.shape().dim_vec().cend()); - const Shape param_shape(param_shape_dim_vec); - if (center) { - const user_op::TensorDesc& beta = ctx->InputTensorDesc("beta", 0); - CHECK_EQ_OR_RETURN(beta.shape(), param_shape); - } - if (scale) { - const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); - CHECK_EQ_OR_RETURN(gamma.shape(), param_shape); - } - const int64_t begin_norm_axis = - ShiftNegativeAxisIfNeed(x.shape(), ctx->Attr("begin_norm_axis")); - *mean->mut_shape() = InferBnParamShape(x.shape(), begin_norm_axis); - *inv_variance = *mean; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); - int64_t begin_norm_axis = - ShiftNegativeAxisIfNeed(x_shape, ctx->Attr("begin_norm_axis")); - int64_t begin_params_axis = - ShiftNegativeAxisIfNeed(x_shape, ctx->Attr("begin_params_axis")); - for (int i = 0; i < std::min(begin_norm_axis, begin_params_axis); ++i) { - ctx->NewBuilder() - .Split(ctx->inputs(), i) - .Split(ctx->outputs(), i) - .Broadcast(user_op::OpArg("gamma", 0)) - .Broadcast(user_op::OpArg("beta", 0)) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const bool center = ctx->Attr("center"); - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); - *y->mut_data_type() = x.data_type(); - if (center) { - const user_op::TensorDesc& beta = ctx->InputTensorDesc("beta", 0); - CHECK_EQ_OR_RETURN(beta.data_type(), x.data_type()); - } - const bool scale = ctx->Attr("scale"); - if (scale) { - const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); - CHECK_EQ_OR_RETURN(gamma.data_type(), x.data_type()); - } - user_op::TensorDesc* mean = ctx->OutputTensorDesc("mean", 0); - user_op::TensorDesc* inv_variance = ctx->OutputTensorDesc("inv_variance", 0); - *mean->mut_data_type() = InferBnParamDataType(x.data_type()); - *inv_variance->mut_data_type() = mean->data_type(); - return Maybe::Ok(); - }); +/* static */ Maybe LayerNormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* mean = 
ctx->OutputTensorDesc("mean", 0); + user_op::TensorDesc* inv_variance = ctx->OutputTensorDesc("inv_variance", 0); + const bool center = ctx->Attr("center"); + const bool scale = ctx->Attr("scale"); + const int64_t begin_params_axis = + ShiftNegativeAxisIfNeed(x.shape(), ctx->Attr("begin_params_axis")); + *y->mut_shape() = x.shape(); + *y->mut_is_dynamic() = x.is_dynamic(); + DimVector param_shape_dim_vec; + param_shape_dim_vec.insert(param_shape_dim_vec.end(), + x.shape().dim_vec().cbegin() + begin_params_axis, + x.shape().dim_vec().cend()); + const Shape param_shape(param_shape_dim_vec); + if (center) { + const user_op::TensorDesc& beta = ctx->InputTensorDesc("beta", 0); + CHECK_EQ_OR_RETURN(beta.shape(), param_shape); + } + if (scale) { + const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); + CHECK_EQ_OR_RETURN(gamma.shape(), param_shape); + } + const int64_t begin_norm_axis = + ShiftNegativeAxisIfNeed(x.shape(), ctx->Attr("begin_norm_axis")); + *mean->mut_shape() = InferBnParamShape(x.shape(), begin_norm_axis); + *inv_variance = *mean; + return Maybe::Ok(); +} -REGISTER_USER_OP("layer_norm_grad") - .Input("dy") - .Input("x") - .Input("mean") - .Input("inv_variance") - .OptionalInput("gamma") - .OptionalInput("_add_to_output") - .Output("dx") - .Attr("begin_norm_axis") - .Attr("epsilon") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0); - const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0); - user_op::TensorDesc* dx = ctx->OutputTensorDesc("dx", 0); - CHECK_EQ_OR_RETURN(dy.shape(), x.shape()); - const int64_t begin_norm_axis = ctx->Attr("begin_norm_axis"); - CHECK_GT_OR_RETURN(begin_norm_axis, 0); - const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis); - CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape); - CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape); - *dx->mut_shape() = dy.shape(); - *dx->mut_is_dynamic() = dy.is_dynamic(); - if (ctx->has_input("_add_to_output", 0)) { - const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); - CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape()); - } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - int64_t begin_norm_axis = ctx->Attr("begin_norm_axis"); - for (int i = 0; i < begin_norm_axis; ++i) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type()); - const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0); - const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0); - const DataType& bn_param_data_type = InferBnParamDataType(x.data_type()); - CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type); - CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type); - user_op::TensorDesc* dx = ctx->OutputTensorDesc("dx", 0); - *dx->mut_data_type() = dy.data_type(); - if (ctx->has_input("_add_to_output", 0)) { - const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); - CHECK_EQ_OR_RETURN(add_to_output.data_type(), dx->data_type()); - } - return Maybe::Ok(); - }); 
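The LayerNormOp::InferLogicalTensorDesc above fixes the shape contract of the op: y keeps x's shape, gamma/beta are checked against the trailing axes starting at begin_params_axis, and mean/inv_variance take InferBnParamShape(x.shape(), begin_norm_axis), a helper defined earlier in this file outside this hunk. A small walk-through with hypothetical sizes, mirroring the param_shape construction with a plain std::vector:

#include <cstdint>
#include <vector>

int main() {
  // Hypothetical x of shape (8, 16, 32, 32) with begin_params_axis = 1 (already
  // non-negative after ShiftNegativeAxisIfNeed).
  std::vector<int64_t> x_shape = {8, 16, 32, 32};
  const int64_t begin_params_axis = 1;
  std::vector<int64_t> param_shape(x_shape.begin() + begin_params_axis, x_shape.end());
  // param_shape == {16, 32, 32}: the shape gamma and beta must match exactly.
  // With begin_norm_axis = 1, mean and inv_variance get InferBnParamShape(x.shape(), 1);
  // assuming that helper keeps the leading axes, their shape is {8}.
  return 0;
}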
+/*static*/ Maybe LayerNormOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("layer_norm_param_grad") - .Input("dy") - .Input("x") - .Input("mean") - .Input("inv_variance") - .OptionalOutput("beta_diff") - .OptionalOutput("gamma_diff") - .Attr("begin_params_axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - // TODO: tsai: replace lambda with user op if - auto has_tensor = [ctx](const std::string& bn) -> bool { - bool ret = false; - for (auto t : ctx->inputs()) { - if (bn == t.first) { return true; } - } - for (auto t : ctx->outputs()) { - if (bn == t.first) { return true; } - } - return ret; - }; - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const int64_t begin_params_axis = ctx->Attr("begin_params_axis"); - const bool has_beta_diff = has_tensor("beta_diff"); - const bool has_gamma_diff = has_tensor("gamma_diff"); - const bool has_gamma = has_tensor("gamma"); - CHECK_GE_OR_RETURN(begin_params_axis, 1); - CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes()); - DimVector param_shape_dim_vec; - param_shape_dim_vec.insert(param_shape_dim_vec.end(), - dy.shape().dim_vec().cbegin() + begin_params_axis, - dy.shape().dim_vec().cend()); - const Shape param_shape(param_shape_dim_vec); - if (has_beta_diff) { - user_op::TensorDesc* beta_diff = ctx->OutputTensorDesc("beta_diff", 0); - *beta_diff->mut_shape() = param_shape; - } - if (has_gamma_diff) { - user_op::TensorDesc* gamma_diff = ctx->OutputTensorDesc("gamma_diff", 0); - *gamma_diff->mut_shape() = param_shape; - } - if (has_gamma) { - const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); - CHECK_EQ_OR_RETURN(gamma.shape(), param_shape); - } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - int64_t begin_params_axis = ctx->Attr("begin_params_axis"); - for (int i = 0; i < begin_params_axis; ++i) { - ctx->NewBuilder() - .Split(ctx->inputs(), i) - .Split(ctx->outputs(), i) - .Broadcast(user_op::OpArg("gamma", 0)) - .PartialSum(user_op::OpArg("gamma_diff", 0)) - .PartialSum(user_op::OpArg("beta_diff", 0)) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - auto has_tensor = [ctx](const std::string& bn) -> bool { - bool ret = false; - for (auto& t : ctx->inputs()) { - if (bn == t.first) { return true; } - } - for (auto& t : ctx->outputs()) { - if (bn == t.first) { return true; } - } - return ret; - }; - const bool has_beta_diff = has_tensor("beta_diff"); - const bool has_gamma_diff = has_tensor("gamma_diff"); - const bool has_gamma = has_tensor("gamma"); - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - if (has_beta_diff) { - user_op::TensorDesc* beta_diff = ctx->OutputTensorDesc("beta_diff", 0); - *beta_diff->mut_data_type() = dy.data_type(); - } - if (has_gamma_diff) { - user_op::TensorDesc* gamma_diff = ctx->OutputTensorDesc("gamma_diff", 0); - *gamma_diff->mut_data_type() = dy.data_type(); - } - if (has_gamma) { - const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); - CHECK_EQ_OR_RETURN(gamma.data_type(), dy.data_type()); - } - return Maybe::Ok(); - }); +/* static */ Maybe LayerNormOp::GetSbp(user_op::SbpContext* ctx) { + const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); + int64_t begin_norm_axis = ShiftNegativeAxisIfNeed(x_shape, ctx->Attr("begin_norm_axis")); + int64_t begin_params_axis = + ShiftNegativeAxisIfNeed(x_shape, ctx->Attr("begin_params_axis")); + 
for (int i = 0; i < std::min(begin_norm_axis, begin_params_axis); ++i) {
+    ctx->NewBuilder()
+        .Split(ctx->inputs(), i)
+        .Split(ctx->outputs(), i)
+        .Broadcast(user_op::OpArg("gamma", 0))
+        .Broadcast(user_op::OpArg("beta", 0))
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> LayerNormOp::InferDataType(user_op::InferContext* ctx) {
+  const bool center = ctx->Attr<bool>("center");
+  const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
+  user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0);
+  *y->mut_data_type() = x.data_type();
+  if (center) {
+    const user_op::TensorDesc& beta = ctx->InputTensorDesc("beta", 0);
+    CHECK_EQ_OR_RETURN(beta.data_type(), x.data_type());
+  }
+  const bool scale = ctx->Attr<bool>("scale");
+  if (scale) {
+    const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0);
+    CHECK_EQ_OR_RETURN(gamma.data_type(), x.data_type());
+  }
+  user_op::TensorDesc* mean = ctx->OutputTensorDesc("mean", 0);
+  user_op::TensorDesc* inv_variance = ctx->OutputTensorDesc("inv_variance", 0);
+  *mean->mut_data_type() = InferBnParamDataType(x.data_type());
+  *inv_variance->mut_data_type() = mean->data_type();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> LayerNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
+  const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
+  const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0);
+  const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0);
+  user_op::TensorDesc* dx = ctx->OutputTensorDesc("dx", 0);
+  CHECK_EQ_OR_RETURN(dy.shape(), x.shape());
+  const int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
+  CHECK_GT_OR_RETURN(begin_norm_axis, 0);
+  const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis);
+  CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape);
+  CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape);
+  *dx->mut_shape() = dy.shape();
+  *dx->mut_is_dynamic() = dy.is_dynamic();
+  if (ctx->has_input("_add_to_output", 0)) {
+    const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
+    CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape());
+  }
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> LayerNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> LayerNormGradOp::GetSbp(user_op::SbpContext* ctx) {
+  int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
+  for (int i = 0; i < begin_norm_axis; ++i) {
+    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
+  }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> LayerNormGradOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
+  const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
+  CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type());
+  const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0);
+  const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0);
+  const DataType& bn_param_data_type = InferBnParamDataType(x.data_type());
+  CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type);
+  CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type);
+  user_op::TensorDesc* dx = ctx->OutputTensorDesc("dx", 0);
+  *dx->mut_data_type() = dy.data_type();
+  if (ctx->has_input("_add_to_output", 0)) {
+    const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
+    CHECK_EQ_OR_RETURN(add_to_output.data_type(), dx->data_type());
+  }
+ return Maybe::Ok(); +} + +/* static */ Maybe LayerNormParamGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + // TODO: tsai: replace lambda with user op if + auto has_tensor = [ctx](const std::string& bn) -> bool { + bool ret = false; + for (auto t : ctx->inputs()) { + if (bn == t.first) { return true; } + } + for (auto t : ctx->outputs()) { + if (bn == t.first) { return true; } + } + return ret; + }; + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + const int64_t begin_params_axis = ctx->Attr("begin_params_axis"); + const bool has_beta_diff = has_tensor("beta_diff"); + const bool has_gamma_diff = has_tensor("gamma_diff"); + const bool has_gamma = has_tensor("gamma"); + CHECK_GE_OR_RETURN(begin_params_axis, 1); + CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes()); + DimVector param_shape_dim_vec; + param_shape_dim_vec.insert(param_shape_dim_vec.end(), + dy.shape().dim_vec().cbegin() + begin_params_axis, + dy.shape().dim_vec().cend()); + const Shape param_shape(param_shape_dim_vec); + if (has_beta_diff) { + user_op::TensorDesc* beta_diff = ctx->OutputTensorDesc("beta_diff", 0); + *beta_diff->mut_shape() = param_shape; + } + if (has_gamma_diff) { + user_op::TensorDesc* gamma_diff = ctx->OutputTensorDesc("gamma_diff", 0); + *gamma_diff->mut_shape() = param_shape; + } + if (has_gamma) { + const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); + CHECK_EQ_OR_RETURN(gamma.shape(), param_shape); + } + return Maybe::Ok(); +} + +/*static*/ Maybe LayerNormParamGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe LayerNormParamGradOp::GetSbp(user_op::SbpContext* ctx) { + int64_t begin_params_axis = ctx->Attr("begin_params_axis"); + for (int i = 0; i < begin_params_axis; ++i) { + ctx->NewBuilder() + .Split(ctx->inputs(), i) + .Split(ctx->outputs(), i) + .Broadcast(user_op::OpArg("gamma", 0)) + .PartialSum(user_op::OpArg("gamma_diff", 0)) + .PartialSum(user_op::OpArg("beta_diff", 0)) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe LayerNormParamGradOp::InferDataType(user_op::InferContext* ctx) { + auto has_tensor = [ctx](const std::string& bn) -> bool { + bool ret = false; + for (auto& t : ctx->inputs()) { + if (bn == t.first) { return true; } + } + for (auto& t : ctx->outputs()) { + if (bn == t.first) { return true; } + } + return ret; + }; + const bool has_beta_diff = has_tensor("beta_diff"); + const bool has_gamma_diff = has_tensor("gamma_diff"); + const bool has_gamma = has_tensor("gamma"); + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + if (has_beta_diff) { + user_op::TensorDesc* beta_diff = ctx->OutputTensorDesc("beta_diff", 0); + *beta_diff->mut_data_type() = dy.data_type(); + } + if (has_gamma_diff) { + user_op::TensorDesc* gamma_diff = ctx->OutputTensorDesc("gamma_diff", 0); + *gamma_diff->mut_data_type() = dy.data_type(); + } + if (has_gamma) { + const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); + CHECK_EQ_OR_RETURN(gamma.data_type(), dy.data_type()); + } + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("layer_norm") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/leaky_relu_op.cpp b/oneflow/user/ops/leaky_relu_op.cpp index f48b34aadd5..09d8b318c54 100644 --- a/oneflow/user/ops/leaky_relu_op.cpp +++ b/oneflow/user/ops/leaky_relu_op.cpp @@ -14,65 +14,69 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("leaky_relu") - .Input("x") - .Output("y") - .Attr("alpha") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - Shape* y_shape = ctx->OutputShape("y", 0); - *y_shape = x_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe LeakyReluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + Shape* y_shape = ctx->OutputShape("y", 0); + *y_shape = x_shape; + return Maybe::Ok(); +} -REGISTER_USER_OP("leaky_relu_grad") - .Input("x") - .Input("dy") - .Output("dx") - .Attr("alpha") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("dy", 0)) - .PartialSum(user_op::OpArg("dx", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe LeakyReluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe LeakyReluOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe LeakyReluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} + +/* static */ Maybe LeakyReluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe LeakyReluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe LeakyReluGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, 
x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("dy", 0)) + .PartialSum(user_op::OpArg("dx", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe LeakyReluGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("leaky_relu") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/log_softmax_op.cpp b/oneflow/user/ops/log_softmax_op.cpp index 6eff6fb15a3..d8cffbf7460 100644 --- a/oneflow/user/ops/log_softmax_op.cpp +++ b/oneflow/user/ops/log_softmax_op.cpp @@ -14,61 +14,65 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe LogSoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("prob", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("log_softmax") - .Input("in") - .Output("prob") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("prob", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes() - 1) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), axis) - .Split(user_op::OpArg("prob", 0), axis) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("prob", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe LogSoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("log_softmax_grad") - .Input("prob") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& y_shape = ctx->InputShape("prob", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == y_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("prob", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("prob", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("prob", 0); - FOR_RANGE(int64_t, axis, 0, y_tensor.shape().NumAxes() - 1) { - ctx->NewBuilder() - .Split(user_op::OpArg("prob", 0), axis) - .Split(user_op::OpArg("dy", 0), axis) - .Split(user_op::OpArg("dx", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); +/* static */ Maybe LogSoftmaxOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes() - 1) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), axis) + .Split(user_op::OpArg("prob", 0), axis) + .Build(); + } + return 
Maybe::Ok(); +} + +/* static */ Maybe LogSoftmaxOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("prob", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe LogSoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& y_shape = ctx->InputShape("prob", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == y_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe LogSoftmaxGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe LogSoftmaxGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("prob", 0); + FOR_RANGE(int64_t, axis, 0, y_tensor.shape().NumAxes() - 1) { + ctx->NewBuilder() + .Split(user_op::OpArg("prob", 0), axis) + .Split(user_op::OpArg("dy", 0), axis) + .Split(user_op::OpArg("dx", 0), axis) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe LogSoftmaxGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("prob", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("prob", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("log_softmax") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -87,5 +91,4 @@ REGISTER_USER_OP_GRAD("log_softmax") return Maybe::Ok(); }); -} // namespace } // namespace oneflow diff --git a/oneflow/user/ops/logical_not_op.cpp b/oneflow/user/ops/logical_not_op.cpp index 47d5b1ae20e..c4f549fa7ce 100644 --- a/oneflow/user/ops/logical_not_op.cpp +++ b/oneflow/user/ops/logical_not_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -26,11 +27,20 @@ Maybe InferDataTypeLogicalNot(user_op::InferContext* ctx) { } // namespace -REGISTER_NO_GRAD_USER_OP("logical_not") - .Input("x") - .Output("y") - .SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged) - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) - .SetDataTypeInferFn(InferDataTypeLogicalNot); +/* static */ Maybe LogicalNotOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return user_op::TensorDescInferFnUtil::Unchanged(ctx); +} + +/*static*/ Maybe LogicalNotOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe LogicalNotOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); +} + +/* static */ Maybe LogicalNotOp::InferDataType(user_op::InferContext* ctx) { + return InferDataTypeLogicalNot(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/loss_op_util.cpp b/oneflow/user/ops/loss_op_util.cpp index 96da2d7666f..5b8714df93d 100644 --- a/oneflow/user/ops/loss_op_util.cpp +++ b/oneflow/user/ops/loss_op_util.cpp @@ -19,29 +19,35 @@ limitations under the License. 
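// Illustrative note on the loss_op_util.cpp hunk below: the default SBP helpers now broadcast
// "weight" only when the op actually declares that optional input, and the customization
// callback receives the SbpContext as a second argument. A hypothetical caller (the
// "pos_weight" input name is made up for illustration) could use the extra argument like this:
//
//   GenLossForwardDefaultGetSbpFn(
//       [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) {
//         if (ctx->user_op_conf().has_input("pos_weight", 0)) {
//           builder.Broadcast(user_op::OpArg("pos_weight", 0));
//         }
//       });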
namespace oneflow { user_op::GetSbpFn GenLossForwardDefaultGetSbpFn( - const std::function& f) { + const std::function& f) { return [=](user_op::SbpContext* ctx) -> Maybe { auto builder = ctx->NewBuilder() .Split(user_op::OpArg("input", 0), 0) .Split(user_op::OpArg("target", 0), 0) - .Broadcast(user_op::OpArg("weight", 0)) .Split(user_op::OpArg("out", 0), 0); - f(builder); + if (ctx->user_op_conf().has_input("weight", 0)) { + builder.Broadcast(user_op::OpArg("weight", 0)); + } + f(builder, ctx); builder.Build(); return Maybe::Ok(); }; } user_op::GetSbpFn GenLossBackwardDefaultGetSbpFn( - const std::function& f) { + const std::function& f) { return [=](user_op::SbpContext* ctx) -> Maybe { auto builder = ctx->NewBuilder() .Split(user_op::OpArg("input", 0), 0) .Split(user_op::OpArg("target", 0), 0) - .Broadcast(user_op::OpArg("weight", 0)) .Split(user_op::OpArg("dx", 0), 0) .Split(user_op::OpArg("dy", 0), 0); - f(builder); + if (ctx->user_op_conf().has_input("weight", 0)) { + builder.Broadcast(user_op::OpArg("weight", 0)); + } + f(builder, ctx); builder.Build(); return Maybe::Ok(); }; diff --git a/oneflow/user/ops/loss_op_util.h b/oneflow/user/ops/loss_op_util.h index fb4e3f7f68c..7c91d69f00f 100644 --- a/oneflow/user/ops/loss_op_util.h +++ b/oneflow/user/ops/loss_op_util.h @@ -22,12 +22,14 @@ limitations under the License. namespace oneflow { user_op::GetSbpFn GenLossForwardDefaultGetSbpFn( - const std::function& f = - [](user_op::UserOpSbpSignatureBuilder& builder) {}); + const std::function& f = + [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) {}); user_op::GetSbpFn GenLossBackwardDefaultGetSbpFn( - const std::function& f = - [](user_op::UserOpSbpSignatureBuilder& builder) {}); + const std::function& f = + [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) {}); } // namespace oneflow diff --git a/oneflow/user/ops/masked_fill_op.cpp b/oneflow/user/ops/masked_fill_op.cpp index aa6e08ba954..44afd9b37e9 100644 --- a/oneflow/user/ops/masked_fill_op.cpp +++ b/oneflow/user/ops/masked_fill_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -73,17 +74,25 @@ Maybe GetMaskedFillInputArgModify(const user_op::GetInputArgModifier& GetI } // namespace -REGISTER_USER_OP("masked_fill") - .Input("x") - .Input("mask") - .Output("out") - .Attr("has_int_operand") - .Attr("has_float_operand") - .Attr("int_operand") - .Attr("float_operand") - .SetTensorDescInferFn(InferMaskedFillTensorDesc) - .SetInputArgModifyFn(GetMaskedFillInputArgModify) - .SetDataTypeInferFn(InferMaskedFillDataType) - .SetGetSbpFn(GetMaskedFillSbpSignatures); +/* static */ Maybe MaskedFillOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferMaskedFillTensorDesc(ctx); +} + +/*static*/ Maybe MaskedFillOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe MaskedFillOp::GetSbp(user_op::SbpContext* ctx) { + return GetMaskedFillSbpSignatures(ctx); +} + +/* static */ Maybe MaskedFillOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return GetMaskedFillInputArgModify(GetInputArgModifierFn, conf); +} + +/* static */ Maybe MaskedFillOp::InferDataType(user_op::InferContext* ctx) { + return InferMaskedFillDataType(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/math_binary_broadcast_ops.cpp b/oneflow/user/ops/math_binary_broadcast_ops.cpp index d246629ba6a..918884679b0 100644 --- a/oneflow/user/ops/math_binary_broadcast_ops.cpp +++ b/oneflow/user/ops/math_binary_broadcast_ops.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/binary_func.h" #include "oneflow/user/ops/math_binary_broadcast_seq.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -193,26 +194,36 @@ Maybe GetBinaryBroadcastSbpSignature(user_op::SbpContext* ctx) { } // namespace -#define REGISTER_BINARY_BROADCAST_NORMAL_USER_OP(op_name, suffix) \ - REGISTER_USER_OP(op_name) \ - .Input("x") \ - .Input("y") \ - .Output("z") \ - .SetTensorDescInferFn(InferTensorDescBinaryBroadcastNormal) \ - .SetGetSbpFn(GetBinaryBroadcastSbpSignature) \ - .SetDataTypeInferFn(InferDataTypeBinaryBroadcastNormal); - -#define REGISTER_BINARY_BROADCAST_LOGICAL_USER_OP(op_name, suffix) \ - REGISTER_NO_GRAD_USER_OP(op_name) \ - .Input("x") \ - .Input("y") \ - .Output("z") \ - .SetTensorDescInferFn(InferTensorDescBinaryBroadcastLogical) \ - .SetGetSbpFn(GetBinaryBroadcastSbpSignature) \ - .SetDataTypeInferFn(InferDataTypeBinaryBroadcastLogical); - -OF_PP_FOR_EACH_TUPLE(REGISTER_BINARY_BROADCAST_NORMAL_USER_OP, MATH_BINARY_BROADCAST_FUNC_SEQ) +#define REGISTER_BINARY_BROADCAST_NORMAL_USER_OP(op_name, suffix) \ + /* static */ Maybe op_name::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return InferTensorDescBinaryBroadcastNormal(ctx); \ + } \ + /*static*/ Maybe op_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /* static */ Maybe op_name::GetSbp(user_op::SbpContext* ctx) { \ + return GetBinaryBroadcastSbpSignature(ctx); \ + } \ + /* static */ Maybe op_name::InferDataType(user_op::InferContext* ctx) { \ + return InferDataTypeBinaryBroadcastNormal(ctx); \ + } + +#define REGISTER_BINARY_BROADCAST_LOGICAL_USER_OP(op_name, suffix) \ + /* static */ Maybe op_name::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return InferTensorDescBinaryBroadcastLogical(ctx); \ + } \ + /*static*/ Maybe 
op_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /* static */ Maybe op_name::GetSbp(user_op::SbpContext* ctx) { \ + return GetBinaryBroadcastSbpSignature(ctx); \ + } \ + /* static */ Maybe op_name::InferDataType(user_op::InferContext* ctx) { \ + return InferDataTypeBinaryBroadcastLogical(ctx); \ + } + +OF_PP_FOR_EACH_TUPLE(REGISTER_BINARY_BROADCAST_NORMAL_USER_OP, MATH_BINARY_BROADCAST_FUNC_SEQ_ODS) OF_PP_FOR_EACH_TUPLE(REGISTER_BINARY_BROADCAST_LOGICAL_USER_OP, - MATH_BINARY_BROADCAST_LOGICAL_FUNC_SEQ) + MATH_BINARY_BROADCAST_LOGICAL_FUNC_SEQ_ODS) } // namespace oneflow diff --git a/oneflow/user/ops/math_binary_broadcast_seq.h b/oneflow/user/ops/math_binary_broadcast_seq.h index c3eeafab202..4dc820c0fc8 100644 --- a/oneflow/user/ops/math_binary_broadcast_seq.h +++ b/oneflow/user/ops/math_binary_broadcast_seq.h @@ -42,6 +42,28 @@ namespace oneflow { OF_PP_MAKE_TUPLE_SEQ("broadcast_logical_or", OR) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_logical_xor", XOR) +#define MATH_BINARY_BROADCAST_FUNC_SEQ_ODS \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastAddOp, Add) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastSubOp, Sub) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastMulOp, Mul) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastDivOp, Div) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastMinimumOp, Min) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastMaximumOp, Max) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastFloorModOp, FloorMod) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastFmodOp, FMod) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastPowOp, Pow) + +#define MATH_BINARY_BROADCAST_LOGICAL_FUNC_SEQ_ODS \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastEqualOp, EQ) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastNotEqualOp, NE) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastGreaterOp, GT) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastGreaterEqualOp, GE) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastLessOp, LT) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastLessEqualOp, LE) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastLogicalAndOp, AND) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastLogicalOrOp, OR) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastLogicalXorOp, XOR) + } // namespace oneflow #endif // ONEFLOW_USER_OPS_MATH_BINARY_BROADCAST_SEQ_H_ diff --git a/oneflow/user/ops/math_binary_elementwise_ops.cpp b/oneflow/user/ops/math_binary_elementwise_ops.cpp index f528a76966a..ec6e71f82de 100644 --- a/oneflow/user/ops/math_binary_elementwise_ops.cpp +++ b/oneflow/user/ops/math_binary_elementwise_ops.cpp @@ -15,38 +15,34 @@ limitations under the License. 
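// Illustrative note on the broadcast op definitions above: OF_PP_FOR_EACH_TUPLE invokes the given
// macro once per tuple, so the (BroadcastAddOp, Add) entry of MATH_BINARY_BROADCAST_FUNC_SEQ_ODS
// expands REGISTER_BINARY_BROADCAST_NORMAL_USER_OP into four static methods of BroadcastAddOp,
// roughly:
//
//   /* static */ Maybe<void> BroadcastAddOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
//     return InferTensorDescBinaryBroadcastNormal(ctx);
//   }
//   // ... plus InferPhysicalTensorDesc, GetSbp and InferDataType, delegating to the same
//   // shared helpers.
//
// The *_ODS sequences list the ODS-generated class names, whereas the legacy sequences keyed by
// op-type strings are left in place.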
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/math_binary_elementwise_seq.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -#define MATH_ELEMENTWISE_DEFAULT_SET_FUNC() \ - SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged) \ - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) \ - .SetDataTypeInferFn(user_op::TensorDescInferFnUtil::UnchangedDataType) +#define MATH_ELEMENTWISE_DEFAULT_SET_FUNC(op_type) \ + /* static */ Maybe op_type::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return user_op::TensorDescInferFnUtil::Unchanged(ctx); \ + } \ + /*static*/ Maybe op_type::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /* static */ Maybe op_type::GetSbp(user_op::SbpContext* ctx) { \ + return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); \ + } \ + /* static */ Maybe op_type::InferDataType(user_op::InferContext* ctx) { \ + return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); \ + } #define REGISTER_MATH_BINARY_ELEMENTWISE_OP_AND_GRAD(math_binary_elementwise_type, func_prefix) \ - REGISTER_USER_OP(math_binary_elementwise_type) \ - .Input("x") \ - .Input("y") \ - .Output("z") \ - .MATH_ELEMENTWISE_DEFAULT_SET_FUNC(); \ + MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##Op); \ \ - REGISTER_USER_OP((std::string("") + math_binary_elementwise_type + "_x_grad")) \ - .Input("x") \ - .Input("y") \ - .Input("dz") \ - .Output("dx") \ - .MATH_ELEMENTWISE_DEFAULT_SET_FUNC(); \ + MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##XGradOp); \ \ - REGISTER_USER_OP((std::string("") + math_binary_elementwise_type + "_y_grad")) \ - .Input("x") \ - .Input("y") \ - .Input("dz") \ - .Output("dy") \ - .MATH_ELEMENTWISE_DEFAULT_SET_FUNC(); \ + MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##YGradOp); \ \ REGISTER_USER_OP_GRAD(math_binary_elementwise_type) \ .SetGenBackwardOpConfFn( \ - [](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { \ + [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe { \ if (op.NeedGenGradTensor4OpInput("x", 0)) { \ user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_x_grad"); \ user_op::UserOpConfWrapper binary_grad_op = \ @@ -74,6 +70,7 @@ namespace oneflow { return Maybe::Ok(); \ }); -OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_OP_AND_GRAD, MATH_BINARY_ELEMENTWISE_FUNC_SEQ) +OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_OP_AND_GRAD, + MATH_BINARY_ELEMENTWISE_FUNC_SEQ_ODS) } // namespace oneflow diff --git a/oneflow/user/ops/math_binary_elementwise_seq.h b/oneflow/user/ops/math_binary_elementwise_seq.h index 37e667f086e..4cdc682d687 100644 --- a/oneflow/user/ops/math_binary_elementwise_seq.h +++ b/oneflow/user/ops/math_binary_elementwise_seq.h @@ -27,6 +27,13 @@ namespace oneflow { OF_PP_MAKE_TUPLE_SEQ("xdivy", Xdivy) \ OF_PP_MAKE_TUPLE_SEQ("xlogy", Xlogy) +#define MATH_BINARY_ELEMENTWISE_FUNC_SEQ_ODS \ + OF_PP_MAKE_TUPLE_SEQ("pow", Pow) \ + OF_PP_MAKE_TUPLE_SEQ("atan2", Atan2) \ + OF_PP_MAKE_TUPLE_SEQ("floordiv", Floordiv) \ + OF_PP_MAKE_TUPLE_SEQ("xdivy", Xdivy) \ + OF_PP_MAKE_TUPLE_SEQ("xlogy", Xlogy) + } // namespace oneflow #endif // ONEFLOW_USER_OPS_MATH_BINARY_ELEMENTWISE_SEQ_H_ diff --git a/oneflow/user/ops/math_unary_elementwise_op.cpp b/oneflow/user/ops/math_unary_elementwise_op.cpp index 69cc34fb151..64af43f6316 100644 --- a/oneflow/user/ops/math_unary_elementwise_op.cpp +++ b/oneflow/user/ops/math_unary_elementwise_op.cpp @@ -15,40 +15,45 @@ limitations under the License. 
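// Illustrative note on the unary elementwise hunk below: shape/data-type/SBP inference is shared
// via MATH_ELEMENTWISE_DEFAULT_SET_FUNC, while the backward rule stays keyed by the op-type
// string. Sketch of the grad conf produced for "sin" (assuming the sequence entry ("sin", Sin)
// generates SinOp and SinGradOp):
//
//   user_op::UserOpConfWrapper grad_op = builder.Op("sin_grad")
//                                            .Input("x", op.input("x", 0))
//                                            .Input("dy", op.GetGradTensorWithOpOutput("y", 0))
//                                            .Output("dx")
//                                            .Build();
//   op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "x", 0);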
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/math_unary_elementwise_seq.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -#define REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD(math_unary_elementwise_type, func_prefix) \ - REGISTER_USER_OP(math_unary_elementwise_type) \ - .Input("x") \ - .Output("y") \ - .SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged) \ - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) \ - .SetDataTypeInferFn(user_op::TensorDescInferFnUtil::UnchangedDataType); \ - REGISTER_USER_OP((std::string("") + math_unary_elementwise_type + "_grad")) \ - .Input("x") \ - .Input("dy") \ - .Output("dx") \ - .SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged) \ - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) \ - .SetDataTypeInferFn(user_op::TensorDescInferFnUtil::UnchangedDataType); \ - REGISTER_USER_OP_GRAD(math_unary_elementwise_type) \ - .SetGenBackwardOpConfFn( \ - [](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { \ - if (op.NeedGenGradTensor4OpInput("x", 0)) { \ - user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); \ - user_op::UserOpConfWrapper unary_grad_op = \ - builder.Op(std::string("") + math_unary_elementwise_type + "_grad") \ - .Input("x", op.input("x", 0)) \ - .Input("dy", op.GetGradTensorWithOpOutput("y", 0)) \ - .Output("dx") \ - .Build(); \ - op.BindGradTensorWithOpInput(unary_grad_op.output("dx", 0), "x", 0); \ - AddOp(unary_grad_op); \ - } \ - return Maybe::Ok(); \ +#define MATH_ELEMENTWISE_DEFAULT_SET_FUNC(op_type) \ + /* static */ Maybe op_type::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return user_op::TensorDescInferFnUtil::Unchanged(ctx); \ + } \ + /*static*/ Maybe op_type::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /* static */ Maybe op_type::GetSbp(user_op::SbpContext* ctx) { \ + return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); \ + } \ + /* static */ Maybe op_type::InferDataType(user_op::InferContext* ctx) { \ + return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); \ + } + +#define REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD(math_unary_elementwise_type, func_prefix) \ + MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##Op) \ + MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##GradOp) \ + REGISTER_USER_OP_GRAD(math_unary_elementwise_type) \ + .SetGenBackwardOpConfFn( \ + [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe { \ + if (op.NeedGenGradTensor4OpInput("x", 0)) { \ + user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); \ + user_op::UserOpConfWrapper unary_grad_op = \ + builder.Op(std::string("") + math_unary_elementwise_type + "_grad") \ + .Input("x", op.input("x", 0)) \ + .Input("dy", op.GetGradTensorWithOpOutput("y", 0)) \ + .Output("dx") \ + .Build(); \ + op.BindGradTensorWithOpInput(unary_grad_op.output("dx", 0), "x", 0); \ + AddOp(unary_grad_op); \ + } \ + return Maybe::Ok(); \ }); -OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD, MATH_UNARY_ELEMENTWISE_FUNC_SEQ) +OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD, + MATH_UNARY_ELEMENTWISE_FUNC_SEQ_ODS) } // namespace oneflow diff --git a/oneflow/user/ops/math_unary_elementwise_seq.h b/oneflow/user/ops/math_unary_elementwise_seq.h index dc8a4731e3a..db90cb9a9d3 100644 --- a/oneflow/user/ops/math_unary_elementwise_seq.h +++ b/oneflow/user/ops/math_unary_elementwise_seq.h @@ -38,6 +38,7 @@ namespace oneflow { 
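// Illustrative note on the sequence hunk below: it adds the new "log2" entry to the legacy
// string-keyed list and introduces MATH_UNARY_ELEMENTWISE_FUNC_SEQ_ODS, where each tuple pairs
// the op-type string with the prefix of its ODS-generated class (e.g. ("log2", Log2) drives the
// macro above to define the methods of Log2Op and Log2GradOp).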
OF_PP_MAKE_TUPLE_SEQ("floor", Floor) \ OF_PP_MAKE_TUPLE_SEQ("lgamma", Lgamma) \ OF_PP_MAKE_TUPLE_SEQ("log", Log) \ + OF_PP_MAKE_TUPLE_SEQ("log2", Log2) \ OF_PP_MAKE_TUPLE_SEQ("log1p", Log1p) \ OF_PP_MAKE_TUPLE_SEQ("log_sigmoid", LogSigmoid) \ OF_PP_MAKE_TUPLE_SEQ("negative", Negative) \ @@ -55,6 +56,42 @@ namespace oneflow { OF_PP_MAKE_TUPLE_SEQ("square", Square) \ OF_PP_MAKE_TUPLE_SEQ("tan", Tan) +#define MATH_UNARY_ELEMENTWISE_FUNC_SEQ_ODS \ + OF_PP_MAKE_TUPLE_SEQ("abs", Abs) \ + OF_PP_MAKE_TUPLE_SEQ("acos", Acos) \ + OF_PP_MAKE_TUPLE_SEQ("acosh", Acosh) \ + OF_PP_MAKE_TUPLE_SEQ("asin", Asin) \ + OF_PP_MAKE_TUPLE_SEQ("asinh", Asinh) \ + OF_PP_MAKE_TUPLE_SEQ("atan", Atan) \ + OF_PP_MAKE_TUPLE_SEQ("atanh", Atanh) \ + OF_PP_MAKE_TUPLE_SEQ("ceil", Ceil) \ + OF_PP_MAKE_TUPLE_SEQ("cos", Cos) \ + OF_PP_MAKE_TUPLE_SEQ("cosh", Cosh) \ + OF_PP_MAKE_TUPLE_SEQ("erf", Erf) \ + OF_PP_MAKE_TUPLE_SEQ("erfc", Erfc) \ + OF_PP_MAKE_TUPLE_SEQ("exp", Exp) \ + OF_PP_MAKE_TUPLE_SEQ("expm1", Expm1) \ + OF_PP_MAKE_TUPLE_SEQ("floor", Floor) \ + OF_PP_MAKE_TUPLE_SEQ("lgamma", Lgamma) \ + OF_PP_MAKE_TUPLE_SEQ("log", Log) \ + OF_PP_MAKE_TUPLE_SEQ("log2", Log2) \ + OF_PP_MAKE_TUPLE_SEQ("log1p", Log1p) \ + OF_PP_MAKE_TUPLE_SEQ("log_sigmoid", LogSigmoid) \ + OF_PP_MAKE_TUPLE_SEQ("negative", Negative) \ + OF_PP_MAKE_TUPLE_SEQ("reciprocal", Reciprocal) \ + OF_PP_MAKE_TUPLE_SEQ("reciprocal_no_nan", ReciprocalNoNan) \ + OF_PP_MAKE_TUPLE_SEQ("rint", Rint) \ + OF_PP_MAKE_TUPLE_SEQ("round", Round) \ + OF_PP_MAKE_TUPLE_SEQ("rsqrt", Rsqrt) \ + OF_PP_MAKE_TUPLE_SEQ("sigmoid_v2", SigmoidV2) \ + OF_PP_MAKE_TUPLE_SEQ("sign", Sign) \ + OF_PP_MAKE_TUPLE_SEQ("sin", Sin) \ + OF_PP_MAKE_TUPLE_SEQ("sinh", Sinh) \ + OF_PP_MAKE_TUPLE_SEQ("softplus", Softplus) \ + OF_PP_MAKE_TUPLE_SEQ("sqrt", Sqrt) \ + OF_PP_MAKE_TUPLE_SEQ("square", Square) \ + OF_PP_MAKE_TUPLE_SEQ("tan", Tan) + } // namespace oneflow #endif // ONEFLOW_USER_OPS_MATH_UNARY_ELEMENTWISE_SEQ_H_ diff --git a/oneflow/user/ops/matmul_op.cpp b/oneflow/user/ops/matmul_op.cpp index b551617e595..4e73d7e3e35 100644 --- a/oneflow/user/ops/matmul_op.cpp +++ b/oneflow/user/ops/matmul_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -141,271 +142,278 @@ void GenBackwardOpConf4Matmul(const std::string& op_type_name, const user_op::Us } // namespace -REGISTER_USER_OP("matmul") - .Input("a") - .Input("b") - .OptionalInput("_add_to_output") - .Output("out") - .Attr("transpose_a", false) - .Attr("transpose_b", false) - .Attr("alpha", 1.0) - .SetTensorDescInferFn(InferTensorDesc4Matmul) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - // (m, k_a) * (k_b, n) where k_a == k_b - int32_t m_axis = -1; - int32_t k_a_axis = -1; - int32_t k_b_axis = -1; - int32_t n_axis = -1; - if (ctx->Attr("transpose_a")) { - m_axis = 1; - k_a_axis = 0; - } else { - m_axis = 0; - k_a_axis = 1; - } - if (ctx->Attr("transpose_b")) { - k_b_axis = 1; - n_axis = 0; - } else { - k_b_axis = 0; - n_axis = 1; - } - std::vector out_and_add_to_output_args; - out_and_add_to_output_args.emplace_back("out", 0); - if (ctx->user_op_conf().has_input("_add_to_output", 0)) { - out_and_add_to_output_args.emplace_back("_add_to_output", 0); - } - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), m_axis) - .Broadcast(user_op::OpArg("b", 0)) - .Split(out_and_add_to_output_args, 0) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("a", 0)) - .Split(user_op::OpArg("b", 0), n_axis) - .Split(out_and_add_to_output_args, 1) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), k_a_axis) - .Split(user_op::OpArg("b", 0), k_b_axis) - .PartialSum(out_and_add_to_output_args) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("a", 0)) - .Broadcast(user_op::OpArg("b", 0)) - .PartialSum(out_and_add_to_output_args) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("a", 0)) - .PartialSum(user_op::OpArg("b", 0)) - .PartialSum(out_and_add_to_output_args) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn(InferDataType4Matmul); +/* static */ Maybe MatmulOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc4Matmul(ctx); +} + +/*static*/ Maybe MatmulOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP_GRAD("matmul").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, - user_op::AddOpFn AddOp) -> Maybe { - GenBackwardOpConf4Matmul("matmul", op, AddOp); +/* static */ Maybe MatmulOp::GetSbp(user_op::SbpContext* ctx) { + // (m, k_a) * (k_b, n) where k_a == k_b + int32_t m_axis = -1; + int32_t k_a_axis = -1; + int32_t k_b_axis = -1; + int32_t n_axis = -1; + if (ctx->Attr("transpose_a")) { + m_axis = 1; + k_a_axis = 0; + } else { + m_axis = 0; + k_a_axis = 1; + } + if (ctx->Attr("transpose_b")) { + k_b_axis = 1; + n_axis = 0; + } else { + k_b_axis = 0; + n_axis = 1; + } + std::vector out_and_add_to_output_args; + out_and_add_to_output_args.emplace_back("out", 0); + if (ctx->user_op_conf().has_input("_add_to_output", 0)) { + out_and_add_to_output_args.emplace_back("_add_to_output", 0); + } + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), m_axis) + .Broadcast(user_op::OpArg("b", 0)) + .Split(out_and_add_to_output_args, 0) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("a", 0)) + .Split(user_op::OpArg("b", 0), n_axis) + .Split(out_and_add_to_output_args, 1) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), k_a_axis) + .Split(user_op::OpArg("b", 0), k_b_axis) + .PartialSum(out_and_add_to_output_args) + .Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("a", 0)) + 
.Broadcast(user_op::OpArg("b", 0)) + .PartialSum(out_and_add_to_output_args) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("a", 0)) + .PartialSum(user_op::OpArg("b", 0)) + .PartialSum(out_and_add_to_output_args) + .Build(); return Maybe::Ok(); -}); - -REGISTER_USER_OP("batch_matmul") - .Input("a") - .Input("b") - .OptionalInput("_add_to_output") - .Output("out") - .Attr("transpose_a", false) - .Attr("transpose_b", false) - .Attr("alpha", 1.0) - .SetTensorDescInferFn(InferTensorDesc4Matmul) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& a_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0); - std::vector out_and_add_to_output_args; - out_and_add_to_output_args.emplace_back("out", 0); - if (ctx->user_op_conf().has_input("_add_to_output", 0)) { - out_and_add_to_output_args.emplace_back("_add_to_output", 0); - } - FOR_RANGE(int64_t, i, 0, a_tensor.shape().NumAxes() - 2) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(out_and_add_to_output_args, i).Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn(InferDataType4Matmul); +} -REGISTER_USER_OP_GRAD("batch_matmul") - .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, - user_op::AddOpFn AddOp) -> Maybe { - GenBackwardOpConf4Matmul("batch_matmul", op, AddOp); - return Maybe::Ok(); - }); +/* static */ Maybe MatmulOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType4Matmul(ctx); +} -REGISTER_USER_OP("broadcast_matmul") - .Input("a") - .Input("b") - .OptionalInput("_add_to_output") - .Output("out") - .Attr("transpose_a", false) - .Attr("transpose_b", false) - .Attr("alpha", 1.0) - .SetDataTypeInferFn(InferDataType4Matmul) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - bool transpose_a = ctx->Attr("transpose_a"); - bool transpose_b = ctx->Attr("transpose_b"); - - const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); - const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - - // NOTE: support broadcast b to a for now - // TODO(zwx): support broadcast a to b - CHECK_GT_OR_RETURN(a.shape().NumAxes(), b.shape().NumAxes()); - CHECK_EQ_OR_RETURN(b.shape().NumAxes(), 2); - // NOTE: don't support transpose_a for now - CHECK_OR_RETURN(!transpose_a); +/* static */ Maybe BatchMatmulOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc4Matmul(ctx); +} - DimVector out_dim_vec(a.shape().NumAxes() - 1); - FOR_RANGE(int64_t, i, 0, out_dim_vec.size()) { out_dim_vec[i] = a.shape().At(i); } - int64_t k = a.shape().At(a.shape().NumAxes() - 1); - int64_t n = -1; - if (!transpose_b) { - CHECK_EQ_OR_RETURN(k, b.shape().At(b.shape().NumAxes() - 2)); - n = b.shape().At(b.shape().NumAxes() - 1); - } else { - CHECK_EQ_OR_RETURN(k, b.shape().At(b.shape().NumAxes() - 1)); - n = b.shape().At(b.shape().NumAxes() - 2); - } - out_dim_vec.emplace_back(n); - *out->mut_shape() = Shape(out_dim_vec); - - if (ctx->has_input("_add_to_output", 0)) { - const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); - CHECK_EQ_OR_RETURN(add_to_output.shape(), out->shape()); - } +/*static*/ Maybe BatchMatmulOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - // (b, m, k) * (k, n) when transpose_b is false - // (b, m, k) * (n, k) when transpose_b is true - bool transpose_a = ctx->Attr("transpose_a"); - bool transpose_b = 
ctx->Attr("transpose_b"); - CHECK_OR_RETURN(!transpose_a); +/* static */ Maybe BatchMatmulOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& a_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0); + std::vector out_and_add_to_output_args; + out_and_add_to_output_args.emplace_back("out", 0); + if (ctx->user_op_conf().has_input("_add_to_output", 0)) { + out_and_add_to_output_args.emplace_back("_add_to_output", 0); + } + FOR_RANGE(int64_t, i, 0, a_tensor.shape().NumAxes() - 2) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(out_and_add_to_output_args, i).Build(); + } + return Maybe::Ok(); +} - const auto& a_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape(); - int32_t k_a_axis = a_shape.NumAxes() - 1; - int32_t k_b_axis = -1; - int32_t n_axis = -1; - if (transpose_b) { - k_b_axis = 1; - n_axis = 0; - } else { - k_b_axis = 0; - n_axis = 1; - } - - std::vector out_and_add_to_output_args; - out_and_add_to_output_args.emplace_back("out", 0); - if (ctx->user_op_conf().has_input("_add_to_output", 0)) { - out_and_add_to_output_args.emplace_back("_add_to_output", 0); - } - - // S(b or m axis) x B -> S(b or m axis) - for (int64_t i = 0; i < a_shape.NumAxes() - 1; ++i) { - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), i) - .Broadcast(user_op::OpArg("b", 0)) - .Split(out_and_add_to_output_args, i) - .Build(); - } - // B x S(n_axis) -> S(n_axis) - ctx->NewBuilder() - .Broadcast(user_op::OpArg("a", 0)) - .Split(user_op::OpArg("b", 0), n_axis) - .Split(out_and_add_to_output_args, a_shape.NumAxes() - 1) - .Build(); - // S(a_k_axis) x S(b_k_axis) -> P - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), k_a_axis) - .Split(user_op::OpArg("b", 0), k_b_axis) - .PartialSum(out_and_add_to_output_args) - .Build(); - // P x B -> P - ctx->NewBuilder() - .PartialSum(user_op::OpArg("a", 0)) - .Broadcast(user_op::OpArg("b", 0)) - .PartialSum(out_and_add_to_output_args) - .Build(); - // B x P -> P - ctx->NewBuilder() - .Broadcast(user_op::OpArg("a", 0)) - .PartialSum(user_op::OpArg("b", 0)) - .PartialSum(out_and_add_to_output_args) - .Build(); - return Maybe::Ok(); - }); +/* static */ Maybe BatchMatmulOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType4Matmul(ctx); +} + +/* static */ Maybe BroadcastMatmulOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + bool transpose_a = ctx->Attr("transpose_a"); + bool transpose_b = ctx->Attr("transpose_b"); -REGISTER_USER_OP("broadcast_matmul_grad_b") - .Input("a") - .Input("b") - .OptionalInput("_add_to_output") - .Output("out") - .Attr("alpha", 1.0) - .SetDataTypeInferFn(InferDataType4Matmul) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); - const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - - CHECK_EQ_OR_RETURN(a.shape().NumAxes(), b.shape().NumAxes()); - for (int i = 0; i < a.shape().NumAxes() - 1; ++i) { - CHECK_EQ_OR_RETURN(a.shape().At(i), b.shape().At(i)); - } - - *out->mut_shape() = - Shape({a.shape().At(a.shape().NumAxes() - 1), b.shape().At(b.shape().NumAxes() - 1)}); - - if (ctx->has_input("_add_to_output", 0)) { - const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); - CHECK_EQ_OR_RETURN(add_to_output.shape(), out->shape()); - } + const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); + const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 
0); + // NOTE: support broadcast b to a for now + // TODO(zwx): support broadcast a to b + CHECK_GT_OR_RETURN(a.shape().NumAxes(), b.shape().NumAxes()); + CHECK_EQ_OR_RETURN(b.shape().NumAxes(), 2); + // NOTE: don't support transpose_a for now + CHECK_OR_RETURN(!transpose_a); + + DimVector out_dim_vec(a.shape().NumAxes() - 1); + FOR_RANGE(int64_t, i, 0, out_dim_vec.size()) { out_dim_vec[i] = a.shape().At(i); } + int64_t k = a.shape().At(a.shape().NumAxes() - 1); + int64_t n = -1; + if (!transpose_b) { + CHECK_EQ_OR_RETURN(k, b.shape().At(b.shape().NumAxes() - 2)); + n = b.shape().At(b.shape().NumAxes() - 1); + } else { + CHECK_EQ_OR_RETURN(k, b.shape().At(b.shape().NumAxes() - 1)); + n = b.shape().At(b.shape().NumAxes() - 2); + } + out_dim_vec.emplace_back(n); + *out->mut_shape() = Shape(out_dim_vec); + + if (ctx->has_input("_add_to_output", 0)) { + const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); + CHECK_EQ_OR_RETURN(add_to_output.shape(), out->shape()); + } + + return Maybe::Ok(); +} + +/*static*/ Maybe BroadcastMatmulOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BroadcastMatmulOp::GetSbp(user_op::SbpContext* ctx) { + // (b, m, k) * (k, n) when transpose_b is false + // (b, m, k) * (n, k) when transpose_b is true + bool transpose_a = ctx->Attr("transpose_a"); + bool transpose_b = ctx->Attr("transpose_b"); + CHECK_OR_RETURN(!transpose_a); + + const auto& a_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape(); + int32_t k_a_axis = a_shape.NumAxes() - 1; + int32_t k_b_axis = -1; + int32_t n_axis = -1; + if (transpose_b) { + k_b_axis = 1; + n_axis = 0; + } else { + k_b_axis = 0; + n_axis = 1; + } + + std::vector out_and_add_to_output_args; + out_and_add_to_output_args.emplace_back("out", 0); + if (ctx->user_op_conf().has_input("_add_to_output", 0)) { + out_and_add_to_output_args.emplace_back("_add_to_output", 0); + } + + // S(b or m axis) x B -> S(b or m axis) + for (int64_t i = 0; i < a_shape.NumAxes() - 1; ++i) { + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), i) + .Broadcast(user_op::OpArg("b", 0)) + .Split(out_and_add_to_output_args, i) + .Build(); + } + // B x S(n_axis) -> S(n_axis) + ctx->NewBuilder() + .Broadcast(user_op::OpArg("a", 0)) + .Split(user_op::OpArg("b", 0), n_axis) + .Split(out_and_add_to_output_args, a_shape.NumAxes() - 1) + .Build(); + // S(a_k_axis) x S(b_k_axis) -> P + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), k_a_axis) + .Split(user_op::OpArg("b", 0), k_b_axis) + .PartialSum(out_and_add_to_output_args) + .Build(); + // P x B -> P + ctx->NewBuilder() + .PartialSum(user_op::OpArg("a", 0)) + .Broadcast(user_op::OpArg("b", 0)) + .PartialSum(out_and_add_to_output_args) + .Build(); + // B x P -> P + ctx->NewBuilder() + .Broadcast(user_op::OpArg("a", 0)) + .PartialSum(user_op::OpArg("b", 0)) + .PartialSum(out_and_add_to_output_args) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe BroadcastMatmulOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType4Matmul(ctx); +} + +/* static */ Maybe BroadcastMatmulGradBOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); + const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + + CHECK_EQ_OR_RETURN(a.shape().NumAxes(), b.shape().NumAxes()); + for (int i = 0; i < a.shape().NumAxes() - 1; ++i) { + CHECK_EQ_OR_RETURN(a.shape().At(i), 
b.shape().At(i)); + } + + *out->mut_shape() = + Shape({a.shape().At(a.shape().NumAxes() - 1), b.shape().At(b.shape().NumAxes() - 1)}); + + if (ctx->has_input("_add_to_output", 0)) { + const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); + CHECK_EQ_OR_RETURN(add_to_output.shape(), out->shape()); + } + + return Maybe::Ok(); +} + +/*static*/ Maybe BroadcastMatmulGradBOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BroadcastMatmulGradBOp::GetSbp(user_op::SbpContext* ctx) { + const auto& a_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape(); + int64_t last_axis = a_shape.NumAxes() - 1; + + std::vector out_and_add_to_output_args; + out_and_add_to_output_args.emplace_back("out", 0); + if (ctx->user_op_conf().has_input("_add_to_output", 0)) { + out_and_add_to_output_args.emplace_back("_add_to_output", 0); + } + + // S(b or m axis) x S(b or m axis) -> P + for (int64_t i = 0; i < last_axis; ++i) { + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), i) + .Split(user_op::OpArg("b", 0), i) + .PartialSum(out_and_add_to_output_args) + .Build(); + } + + // (b, m, k) * (b, m, n) -> (k, n) [transpose a] + // S(k) x B -> S(0) or B x S(n) -> S(1) + // (b, m, n) * (b, m, k) -> (n, k) [transpose a] + // S(n) x B -> S(0) or B x S(k) -> S(1) + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), last_axis) + .Broadcast(user_op::OpArg("b", 0)) + .Split(out_and_add_to_output_args, 0) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("a", 0)) + .Split(user_op::OpArg("b", 0), last_axis) + .Split(out_and_add_to_output_args, 1) + .Build(); + + return Maybe::Ok(); +} + +/* static */ Maybe BroadcastMatmulGradBOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType4Matmul(ctx); +} + +REGISTER_USER_OP_GRAD("matmul").SetGenBackwardOpConfFn( + [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe { + GenBackwardOpConf4Matmul("matmul", op, AddOp); return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& a_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape(); - int64_t last_axis = a_shape.NumAxes() - 1; - - std::vector out_and_add_to_output_args; - out_and_add_to_output_args.emplace_back("out", 0); - if (ctx->user_op_conf().has_input("_add_to_output", 0)) { - out_and_add_to_output_args.emplace_back("_add_to_output", 0); - } - - // S(b or m axis) x S(b or m axis) -> P - for (int64_t i = 0; i < last_axis; ++i) { - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), i) - .Split(user_op::OpArg("b", 0), i) - .PartialSum(out_and_add_to_output_args) - .Build(); - } - - // (b, m, k) * (b, m, n) -> (k, n) [transpose a] - // S(k) x B -> S(0) or B x S(n) -> S(1) - // (b, m, n) * (b, m, k) -> (n, k) [transpose a] - // S(n) x B -> S(0) or B x S(k) -> S(1) - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), last_axis) - .Broadcast(user_op::OpArg("b", 0)) - .Split(out_and_add_to_output_args, 0) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("a", 0)) - .Split(user_op::OpArg("b", 0), last_axis) - .Split(out_and_add_to_output_args, 1) - .Build(); + }); +REGISTER_USER_OP_GRAD("batch_matmul") + .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, + const user_op::AddOpFn& AddOp) -> Maybe { + GenBackwardOpConf4Matmul("batch_matmul", op, AddOp); return Maybe::Ok(); }); diff --git a/oneflow/user/ops/min_max_observer_op.cpp b/oneflow/user/ops/min_max_observer_op.cpp index d1003ba287f..3d7f186c378 100644 --- 
a/oneflow/user/ops/min_max_observer_op.cpp +++ b/oneflow/user/ops/min_max_observer_op.cpp @@ -14,73 +14,65 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe MinMaxObserverOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); -REGISTER_NO_GRAD_USER_OP("min_max_observer") - .Input("in") - .Output("scale") - .Output("zero_point") - // NOTE(Liang Depeng): "google" or "cambricon" - .Attr("quantization_formula", "google") - // NOTE(Liang Depeng): quantize from float32 to "quantization_bit" bit signed or unsigned - // integer - .Attr("quantization_bit", 8) - // NOTE(Liang Depeng): "symmetric" or "affine": quantize to signed or unsigned integer - .Attr("quantization_scheme", "symmetric") - // NOTE(Liang Depeng): "true" or "false": per-layer or per-channel quantization. - .Attr("per_layer_quantization", true) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); + if (ctx->Attr("quantization_formula") == "google") { + if (ctx->Attr("per_layer_quantization") == true) { + *ctx->OutputShape("scale", 0) = Shape({1}); + *ctx->OutputShape("zero_point", 0) = Shape({1}); + } else { + // NOTE(Liang Depeng): For now per-channel quantization only support axis 0 + *ctx->OutputShape("scale", 0) = Shape({in_shape.At(0)}); + *ctx->OutputShape("zero_point", 0) = Shape({in_shape.At(0)}); + } + } else { // quantization_formula == "cambricon" + *ctx->OutputShape("scale", 0) = Shape({1}); + *ctx->OutputShape("zero_point", 0) = Shape({1}); + } + return Maybe::Ok(); +} - if (ctx->Attr("quantization_formula") == "google") { - if (ctx->Attr("per_layer_quantization") == true) { - *ctx->OutputShape("scale", 0) = Shape({1}); - *ctx->OutputShape("zero_point", 0) = Shape({1}); - } else { - // NOTE(Liang Depeng): For now per-channel quantization only support axis 0 - *ctx->OutputShape("scale", 0) = Shape({in_shape.At(0)}); - *ctx->OutputShape("zero_point", 0) = Shape({in_shape.At(0)}); - } - } else { // quantization_formula == "cambricon" - *ctx->OutputShape("scale", 0) = Shape({1}); - *ctx->OutputShape("zero_point", 0) = Shape({1}); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); - CHECK_OR_RETURN(in != nullptr); - in->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - // NOTE(Liang Depeng): input needs to be broadcast in order to accurately calculate the - // global scale and zero_point - return Maybe::Ok(); - }) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def, - const user_op::UserOpConfWrapper& op_conf) -> Maybe { - int32_t quantization_bit = op_conf.attr("quantization_bit"); - CHECK_GT_OR_RETURN(quantization_bit, 1); - CHECK_LE_OR_RETURN(quantization_bit, 8); +/*static*/ Maybe MinMaxObserverOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} - std::string quantization_scheme = op_conf.attr("quantization_scheme"); - CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine"); +/* static */ Maybe MinMaxObserverOp::GetSbp(user_op::SbpContext* ctx) { + // NOTE(Liang Depeng): input needs to be 
broadcast in order to accurately calculate the + // global scale and zero_point + return Maybe::Ok(); +} - std::string quantization_formula = op_conf.attr("quantization_formula"); - CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon"); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("scale", 0) = ctx->InputDType("in", 0); - *ctx->OutputDType("zero_point", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/* static */ Maybe MinMaxObserverOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); + CHECK_OR_RETURN(in != nullptr); + in->set_requires_grad(false); + return Maybe::Ok(); +} -} // namespace +/* static */ Maybe MinMaxObserverOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& op_conf) { + int32_t quantization_bit = op_conf.attr("quantization_bit"); + CHECK_GT_OR_RETURN(quantization_bit, 1); + CHECK_LE_OR_RETURN(quantization_bit, 8); + + std::string quantization_scheme = op_conf.attr("quantization_scheme"); + CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine"); + + std::string quantization_formula = op_conf.attr("quantization_formula"); + CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon"); + return Maybe::Ok(); +} + +/* static */ Maybe MinMaxObserverOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("scale", 0) = ctx->InputDType("in", 0); + *ctx->OutputDType("zero_point", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/mish_op.cpp b/oneflow/user/ops/mish_op.cpp index 4f51ca76034..9b3c04bf17d 100644 --- a/oneflow/user/ops/mish_op.cpp +++ b/oneflow/user/ops/mish_op.cpp @@ -14,61 +14,62 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe<void> MishOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe<void>::Ok(); +}
-REGISTER_USER_OP("mish") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe<void>::Ok(); - }); +/*static*/ Maybe<void> MishOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +}
-REGISTER_USER_OP("mish_grad") - .Input("x") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe<void>::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe<void>::Ok(); - }); +/* static */ Maybe<void> MishOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> MishOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> MishGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe<void>::Ok(); +} + +/*static*/ Maybe<void> MishGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe<void> MishGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> MishGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), 
ctx->InputDType("x", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("mish").SetBackwardOpConfGenFn( [](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -87,6 +88,4 @@ REGISTER_USER_OP_GRAD("mish").SetBackwardOpConfGenFn( return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/model_update_ops.cpp b/oneflow/user/ops/model_update_ops.cpp index 5c842cea319..d0da22056f9 100644 --- a/oneflow/user/ops/model_update_ops.cpp +++ b/oneflow/user/ops/model_update_ops.cpp @@ -17,6 +17,7 @@ limitations under the License. #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/user_op_registry.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -427,361 +428,369 @@ Maybe InferLarsUpdateDataType(user_op::InferContext* ctx) { } return Maybe::Ok(); } -REGISTER_NO_GRAD_USER_OP("sgd_update") - .Input("model") - .Input("model_diff") - .OptionalInput("learning_rate") - .OptionalInput("scale_by_tensor") - .OptionalInput("skip_if") - .Attr("learning_rate_val", 0.0) - .Attr("scale", 1.0) - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Attr("weight_decay", 0.0) - .SetTensorDescInferFn(InferSGDUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); - FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(ctx->inputs()) - .Split(user_op::OpArg("model", 0), axis) - .Split(user_op::OpArg("model_diff", 0), axis) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(SgdInputArgModifyFn) - .SetDataTypeInferFn(InferSGDUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("indexed_slices_sgd_update") - .Input("model") - .Input("model_diff_indices") - .Input("model_diff_values") - .Input("learning_rate") - .Attr("weight_decay", 0.0) - .SetTensorDescInferFn(InferIndexedSlicesSGDUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); - const user_op::TensorDesc& model_diff_indices = - ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("learning_rate", 0)) - .Broadcast(user_op::OpArg("model_diff_indices", 0)) - .Broadcast(user_op::OpArg("model_diff_values", 0)) - .Split(user_op::OpArg("model", 0), 0) - .Build(); - FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("learning_rate", 0)) - .Broadcast(user_op::OpArg("model_diff_indices", 0)) - .Split(user_op::OpArg("model_diff_values", 0), - model_diff_indices.shape().NumAxes() + i - 1) - .Split(user_op::OpArg("model", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(IndexedSlicesSgdInputArgModifyFn) - .SetDataTypeInferFn(InferIndexedSlicesSGDUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("momentum_update") - .Input("model") - .Input("model_diff") - .Input("momentum") - .OptionalInput("learning_rate") - .OptionalInput("scale_by_tensor") - .OptionalInput("skip_if") - .Attr("learning_rate_val", 0.0) - .Attr("scale", 1.0) - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Attr("beta", 0.9) - .Attr("weight_decay", 0.0) - .SetTensorDescInferFn(InferMomentumUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); 
- FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(ctx->inputs()) - .Split(user_op::OpArg("model", 0), axis) - .Split(user_op::OpArg("model_diff", 0), axis) - .Split(user_op::OpArg("momentum", 0), axis) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(MomentumInputArgModifyFn) - .SetDataTypeInferFn(InferMomentumUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("indexed_slices_momentum_update") - .Input("model") - .Input("model_diff_indices") - .Input("model_diff_values") - .Input("learning_rate") - .Input("momentum") - .Attr("beta", 0.9) - .Attr("weight_decay", 0.0) - .SetTensorDescInferFn(InferIndexedSlicesMomentumUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); - const user_op::TensorDesc& model_diff_indices = - ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); + +} // namespace + +/* static */ Maybe SgdUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferSGDUpdateTensorDesc(ctx); +} + +/*static*/ Maybe SgdUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe SgdUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(ctx->inputs()) + .Split(user_op::OpArg("model", 0), axis) + .Split(user_op::OpArg("model_diff", 0), axis) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe SgdUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return SgdInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe SgdUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferSGDUpdateDataType(ctx); +} + +/* static */ Maybe IndexedSlicesSgdUpdateOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferIndexedSlicesSGDUpdateTensorDesc(ctx); +} + +/*static*/ Maybe IndexedSlicesSgdUpdateOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe IndexedSlicesSgdUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + const user_op::TensorDesc& model_diff_indices = + ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("learning_rate", 0)) + .Broadcast(user_op::OpArg("model_diff_indices", 0)) + .Broadcast(user_op::OpArg("model_diff_values", 0)) + .Split(user_op::OpArg("model", 0), 0) + .Build(); + FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("learning_rate", 0)) + .Broadcast(user_op::OpArg("model_diff_indices", 0)) + .Split(user_op::OpArg("model_diff_values", 0), model_diff_indices.shape().NumAxes() + i - 1) + .Split(user_op::OpArg("model", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe IndexedSlicesSgdUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return IndexedSlicesSgdInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe IndexedSlicesSgdUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferIndexedSlicesSGDUpdateDataType(ctx); +} + +/* static */ Maybe 
MomentumUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferMomentumUpdateTensorDesc(ctx); +} + +/*static*/ Maybe MomentumUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe MomentumUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(ctx->inputs()) + .Split(user_op::OpArg("model", 0), axis) + .Split(user_op::OpArg("model_diff", 0), axis) + .Split(user_op::OpArg("momentum", 0), axis) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe MomentumUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return MomentumInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe MomentumUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferMomentumUpdateDataType(ctx); +} + +/* static */ Maybe IndexedSlicesMomentumUpdateOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferIndexedSlicesMomentumUpdateTensorDesc(ctx); +} + +/*static*/ Maybe IndexedSlicesMomentumUpdateOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe IndexedSlicesMomentumUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + const user_op::TensorDesc& model_diff_indices = + ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("learning_rate", 0)) + .Broadcast(user_op::OpArg("model_diff_indices", 0)) + .Broadcast(user_op::OpArg("model_diff_values", 0)) + .Split(user_op::OpArg("model", 0), 0) + .Split(user_op::OpArg("momentum", 0), 0) + .Build(); + FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("learning_rate", 0)) + .Broadcast(user_op::OpArg("model_diff_indices", 0)) + .Split(user_op::OpArg("model_diff_values", 0), model_diff_indices.shape().NumAxes() + i - 1) + .Split(user_op::OpArg("model", 0), i) + .Split(user_op::OpArg("momentum", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe IndexedSlicesMomentumUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return IndexedSlicesMomentumInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe IndexedSlicesMomentumUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferIndexedSlicesMomentumUpdateDataType(ctx); +} + +/* static */ Maybe AdamUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferAdamUpdateTensorDesc(ctx); +} + +/*static*/ Maybe AdamUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AdamUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(ctx->inputs()) + .Split(user_op::OpArg("model", 0), axis) + .Split(user_op::OpArg("model_diff", 0), axis) + .Split(user_op::OpArg("m", 0), axis) + .Split(user_op::OpArg("v", 0), axis) + .Split(user_op::OpArg("max_v", 0), axis) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe 
AdamUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return AdamInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe AdamUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferAdamUpdateDataType(ctx); +} + +/* static */ Maybe AdagradUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferAdagradUpdateTensorDesc(ctx); +} + +/*static*/ Maybe AdagradUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AdagradUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(ctx->inputs()) + .Split(user_op::OpArg("model", 0), axis) + .Split(user_op::OpArg("model_diff", 0), axis) + .Split(user_op::OpArg("sum", 0), axis) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe AdagradUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return AdagradInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe AdagradUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferAdagradUpdateDataType(ctx); +} + +/* static */ Maybe IndexedSlicesAdamUpdateOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferIndexedSlicesAdamUpdateTensorDesc(ctx); +} + +/*static*/ Maybe IndexedSlicesAdamUpdateOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe IndexedSlicesAdamUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + const user_op::TensorDesc& model_diff_indices = + ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); + std::vector broadcast_args; + broadcast_args.emplace_back("learning_rate", 0); + broadcast_args.emplace_back("model_diff_indices", 0); + ctx->NewBuilder() + .Broadcast(broadcast_args) + .Broadcast(user_op::OpArg("model_diff_values", 0)) + .Split(user_op::OpArg("model", 0), 0) + .Split(user_op::OpArg("m", 0), 0) + .Split(user_op::OpArg("v", 0), 0) + .Split(user_op::OpArg("max_v", 0), 0) + .Build(); + FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(broadcast_args) + .Split(user_op::OpArg("model_diff_values", 0), model_diff_indices.shape().NumAxes() + i - 1) + .Split(user_op::OpArg("model", 0), i) + .Split(user_op::OpArg("m", 0), i) + .Split(user_op::OpArg("v", 0), i) + .Split(user_op::OpArg("max_v", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe IndexedSlicesAdamUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return AdamInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe IndexedSlicesAdamUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferIndexedSlicesAdamUpdateDataType(ctx); +} + +/* static */ Maybe LambUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferLambUpdateTensorDesc(ctx); +} + +/*static*/ Maybe LambUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe LambUpdateOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static 
*/ Maybe LambUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return LambInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe LambUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferLambUpdateDataType(ctx); +} + +/* static */ Maybe AdamBiasCorrectionFactorOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("train_step", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe AdamBiasCorrectionFactorOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AdamBiasCorrectionFactorOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe AdamBiasCorrectionFactorOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kFloat; + return Maybe::Ok(); +} + +/* static */ Maybe RmspropUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferRmsPropUpdateTensorDesc(ctx); +} + +/*static*/ Maybe RmspropUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe RmspropUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + bool centered = ctx->Attr("centered"); + FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { + if (centered) { ctx->NewBuilder() - .Broadcast(user_op::OpArg("learning_rate", 0)) - .Broadcast(user_op::OpArg("model_diff_indices", 0)) - .Broadcast(user_op::OpArg("model_diff_values", 0)) - .Split(user_op::OpArg("model", 0), 0) - .Split(user_op::OpArg("momentum", 0), 0) + .Broadcast(ctx->inputs()) + .Split(user_op::OpArg("model", 0), axis) + .Split(user_op::OpArg("model_diff", 0), axis) + .Split(user_op::OpArg("mean_square", 0), axis) + .Split(user_op::OpArg("mean_gradient", 0), axis) .Build(); - FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("learning_rate", 0)) - .Broadcast(user_op::OpArg("model_diff_indices", 0)) - .Split(user_op::OpArg("model_diff_values", 0), - model_diff_indices.shape().NumAxes() + i - 1) - .Split(user_op::OpArg("model", 0), i) - .Split(user_op::OpArg("momentum", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(IndexedSlicesMomentumInputArgModifyFn) - .SetDataTypeInferFn(InferIndexedSlicesMomentumUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("adam_update") - .Input("model") - .Input("model_diff") - .OptionalInput("learning_rate") - .OptionalInput("scale_by_tensor") - .OptionalInput("skip_if") - .OptionalInput("bias_correction1") - .OptionalInput("bias_correction2") - .Input("m") - .Input("v") - .Input("max_v") - .Attr("learning_rate_val", 0.0) - .Attr("bias_correction1_val", 1.0) - .Attr("bias_correction2_val", 1.0) - .Attr("scale", 1.0) - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Attr("beta1", 0.9) - .Attr("beta2", 0.999) - .Attr("epsilon", 1e-8) - .Attr("weight_decay", 0.0) - .Attr("amsgrad", false) - .Attr("do_bias_correction", true) - .SetTensorDescInferFn(InferAdamUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); - FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(ctx->inputs()) - .Split(user_op::OpArg("model", 0), axis) - 
.Split(user_op::OpArg("model_diff", 0), axis) - .Split(user_op::OpArg("m", 0), axis) - .Split(user_op::OpArg("v", 0), axis) - .Split(user_op::OpArg("max_v", 0), axis) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(AdamInputArgModifyFn) - .SetDataTypeInferFn(InferAdamUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("adagrad_update") - .Input("model") - .Input("model_diff") - .OptionalInput("learning_rate") - .OptionalInput("scale_by_tensor") - .OptionalInput("skip_if") - .OptionalInput("train_step") - .Input("sum") - .Attr("train_step_val", 0) - .Attr("learning_rate_val", 0.0) - .Attr("scale", 1.0) - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Attr("lr_decay", 0.0) - .Attr("weight_decay", 0.0) - .Attr("epsilon", 1e-10) - .SetTensorDescInferFn(InferAdagradUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); - FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(ctx->inputs()) - .Split(user_op::OpArg("model", 0), axis) - .Split(user_op::OpArg("model_diff", 0), axis) - .Split(user_op::OpArg("sum", 0), axis) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(AdagradInputArgModifyFn) - .SetDataTypeInferFn(InferAdagradUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("indexed_slices_adam_update") - .Input("model") - .Input("model_diff_indices") - .Input("model_diff_values") - .Input("learning_rate") - .OptionalInput("bias_correction1") - .OptionalInput("bias_correction2") - .Input("m") - .Input("v") - .Input("max_v") - .Attr("learning_rate_val", 0.0) - .Attr("beta1", 0.9) - .Attr("beta2", 0.999) - .Attr("epsilon", 1e-8) - .Attr("weight_decay", 0.0) - .Attr("amsgrad", false) - .Attr("do_bias_correction", true) - .SetTensorDescInferFn(InferIndexedSlicesAdamUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); - const user_op::TensorDesc& model_diff_indices = - ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); - std::vector broadcast_args; - broadcast_args.emplace_back("learning_rate", 0); - broadcast_args.emplace_back("model_diff_indices", 0); + } else { ctx->NewBuilder() - .Broadcast(broadcast_args) - .Broadcast(user_op::OpArg("model_diff_values", 0)) - .Split(user_op::OpArg("model", 0), 0) - .Split(user_op::OpArg("m", 0), 0) - .Split(user_op::OpArg("v", 0), 0) - .Split(user_op::OpArg("max_v", 0), 0) + .Broadcast(ctx->inputs()) + .Split(user_op::OpArg("model", 0), axis) + .Split(user_op::OpArg("model_diff", 0), axis) + .Split(user_op::OpArg("mean_square", 0), axis) .Build(); - FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(broadcast_args) - .Split(user_op::OpArg("model_diff_values", 0), - model_diff_indices.shape().NumAxes() + i - 1) - .Split(user_op::OpArg("model", 0), i) - .Split(user_op::OpArg("m", 0), i) - .Split(user_op::OpArg("v", 0), i) - .Split(user_op::OpArg("max_v", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(AdamInputArgModifyFn) - .SetDataTypeInferFn(InferIndexedSlicesAdamUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("lamb_update") - .Input("m") - .Input("v") - .Input("beta1_t") - .Input("beta2_t") - .Input("model") - .Input("model_diff") - .Input("learning_rate") - .OptionalInput("scale_by_tensor") - .OptionalInput("skip_if") - .Attr("beta1") - .Attr("beta2") - .Attr("epsilon") - .Attr("scale", 1.0) - .Attr("l1", 0.0) - 
.Attr("l2", 0.0) - .Attr("weight_decay", 0.0) - .SetTensorDescInferFn(InferLambUpdateTensorDesc) - // every bn has sbp broadcast signature - .SetInputArgModifyFn(LambInputArgModifyFn) - .SetDataTypeInferFn(InferLambUpdateDataType) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("adam_bias_correction_factor") - .Input("train_step") - .Output("out") - .Attr("beta", 0.9) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("train_step", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kFloat; - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -// every bn has sbp broadcast signature - -REGISTER_NO_GRAD_USER_OP("rmsprop_update") - .Input("model") - .Input("model_diff") - .OptionalInput("learning_rate") - .OptionalInput("scale_by_tensor") - .OptionalInput("skip_if") - .Input("mean_square") - .OptionalInput("mean_gradient") - .Attr("learning_rate_val", 0.0) - .Attr("scale", 1.0) - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Attr("centered", false) - .Attr("epsilon", 1e-8) - .Attr("decay_rate", 0.99) - .Attr("weight_decay", 0.0) - .SetTensorDescInferFn(InferRmsPropUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); - bool centered = ctx->Attr("centered"); - FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { - if (centered) { - ctx->NewBuilder() - .Broadcast(ctx->inputs()) - .Split(user_op::OpArg("model", 0), axis) - .Split(user_op::OpArg("model_diff", 0), axis) - .Split(user_op::OpArg("mean_square", 0), axis) - .Split(user_op::OpArg("mean_gradient", 0), axis) - .Build(); - } else { - ctx->NewBuilder() - .Broadcast(ctx->inputs()) - .Split(user_op::OpArg("model", 0), axis) - .Split(user_op::OpArg("model_diff", 0), axis) - .Split(user_op::OpArg("mean_square", 0), axis) - .Build(); - } - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(RmsPropUpdateInputArgModifyFn) - .SetDataTypeInferFn(InferRmsPropUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("lars_update") - .Input("model") - .Input("model_diff") - .Input("learning_rate") - .Input("momentum") - .OptionalInput("scale_by_tensor") - .OptionalInput("skip_if") - .Attr("scale", 1.0) - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Attr("momentum_beta", 0.9) - .Attr("epsilon", 1e-9) - .Attr("lars_coefficient", 1e-4) - .Attr("weight_decay", 0.0) - .SetTensorDescInferFn(InferLarsUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); - FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(ctx->inputs()) - .Split(user_op::OpArg("model", 0), axis) - .Split(user_op::OpArg("model_diff", 0), axis) - .Split(user_op::OpArg("momentum", 0), axis) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(LarsUpdateInputArgModifyFn) - .SetDataTypeInferFn(InferLarsUpdateDataType); + } + } + return Maybe::Ok(); +} -} // namespace +/* static */ Maybe RmspropUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return RmsPropUpdateInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe RmspropUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferRmsPropUpdateDataType(ctx); +} + 
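// NOTE: the GetSbp methods of these update ops all follow one pattern: scalar-like inputs
// (learning_rate, scale_by_tensor, skip_if, ...) stay broadcast while the model, model_diff
// and the optimizer state are split along the same axis, so every rank applies its update
// locally. A rough illustration only, assuming a hypothetical 2-rank placement and a
// (1024, 512) model split on axis 0:
//   rank 0 updates model[0:512, :]    from model_diff[0:512, :]    and momentum[0:512, :]
//   rank 1 updates model[512:1024, :] from model_diff[512:1024, :] and momentum[512:1024, :]
// The broadcast learning_rate is identical on both ranks, so no cross-rank traffic is needed.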
+/* static */ Maybe LarsUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferLarsUpdateTensorDesc(ctx); +} + +/*static*/ Maybe LarsUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe LarsUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(ctx->inputs()) + .Split(user_op::OpArg("model", 0), axis) + .Split(user_op::OpArg("model_diff", 0), axis) + .Split(user_op::OpArg("momentum", 0), axis) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe LarsUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return LarsUpdateInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe LarsUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferLarsUpdateDataType(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/moving_average_min_max_observer_op.cpp b/oneflow/user/ops/moving_average_min_max_observer_op.cpp index 8c4c59dc8e1..434865f2d59 100644 --- a/oneflow/user/ops/moving_average_min_max_observer_op.cpp +++ b/oneflow/user/ops/moving_average_min_max_observer_op.cpp @@ -14,94 +14,82 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -REGISTER_NO_GRAD_USER_OP("moving_average_min_max_observer") - .Input("in") - .Input("current_train_step") - .Input("moving_max") // NOTE(Liang Depeng): needs to be initialized as 0 - .Input("moving_min") // NOTE(Liang Depeng): needs to be initialized as 0 - .Output("scale") - .Output("zero_point") - .Attr("training") - // NOTE(Liang Depeng): "google" or "cambricon" - .Attr("quantization_formula", "google") - .Attr("stop_update_after_iters") - // NOTE(Liang Depeng): quantize from float32 to "quantization_bit" bit signed or unsigned - // integer - .Attr("quantization_bit", 8) - // NOTE(Liang Depeng): "symmetric" or "affine": quantize to signed or unsigned integer - .Attr("quantization_scheme", "symmetric") - // NOTE(Liang Depeng): smoothing parameter for exponential moving average operation - .Attr("momentum", 0.95) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& moving_max_shape = ctx->InputShape("moving_max", 0); - const Shape& moving_min_shape = ctx->InputShape("moving_min", 0); - const Shape& current_train_step = ctx->InputShape("current_train_step", 0); - - // NOTE(Liang Depeng): for now only support per-layer quantization - // TODO(Liang Depeng): depthwise convolution support per-channel quantization - CHECK_OR_RETURN(moving_max_shape.NumAxes() == 1 && moving_max_shape.At(0) == 1); - CHECK_OR_RETURN(moving_min_shape.NumAxes() == 1 && moving_min_shape.At(0) == 1); - - CHECK_OR_RETURN(current_train_step.NumAxes() == 1 && current_train_step.At(0) == 1); - - *ctx->OutputShape("scale", 0) = Shape({1}); - *ctx->OutputShape("zero_point", 0) = Shape({1}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("scale", 0) = ctx->InputDType("in", 0); - *ctx->OutputDType("zero_point", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const 
user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); - CHECK_OR_RETURN(in != nullptr); - in->set_requires_grad(false); - - user_op::InputArgModifier* current_train_step = - GetInputArgModifierFn("current_train_step", 0); - CHECK_OR_RETURN(current_train_step != nullptr); - current_train_step->set_requires_grad(false); - - user_op::InputArgModifier* moving_max = GetInputArgModifierFn("moving_max", 0); - CHECK_OR_RETURN(moving_max != nullptr); - moving_max->set_requires_grad(false); - moving_max->set_is_mutable(true); - - user_op::InputArgModifier* moving_min = GetInputArgModifierFn("moving_min", 0); - CHECK_OR_RETURN(moving_min != nullptr); - moving_min->set_requires_grad(false); - moving_min->set_is_mutable(true); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - // NOTE(Liang Depeng): all inputs need to be broadcast in order to accuratly calculate the - // global scale and zero_point - return Maybe::Ok(); - }) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def, - const user_op::UserOpConfWrapper& op_conf) -> Maybe { - int32_t quantization_bit = op_conf.attr("quantization_bit"); - CHECK_GT_OR_RETURN(quantization_bit, 1); - CHECK_LE_OR_RETURN(quantization_bit, 8); - - std::string quantization_scheme = op_conf.attr("quantization_scheme"); - CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine"); - - int64_t stop_update_after_iters = op_conf.attr("stop_update_after_iters"); - CHECK_GT_OR_RETURN(stop_update_after_iters, 0); - - std::string quantization_formula = op_conf.attr("quantization_formula"); - CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon"); - return Maybe::Ok(); - }); - -} // namespace +/* static */ Maybe MovingAverageMinMaxObserverOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const Shape& moving_max_shape = ctx->InputShape("moving_max", 0); + const Shape& moving_min_shape = ctx->InputShape("moving_min", 0); + const Shape& current_train_step = ctx->InputShape("current_train_step", 0); + + // NOTE(Liang Depeng): for now only support per-layer quantization + // TODO(Liang Depeng): depthwise convolution support per-channel quantization + CHECK_OR_RETURN(moving_max_shape.NumAxes() == 1 && moving_max_shape.At(0) == 1); + CHECK_OR_RETURN(moving_min_shape.NumAxes() == 1 && moving_min_shape.At(0) == 1); + + CHECK_OR_RETURN(current_train_step.NumAxes() == 1 && current_train_step.At(0) == 1); + + *ctx->OutputShape("scale", 0) = Shape({1}); + *ctx->OutputShape("zero_point", 0) = Shape({1}); + return Maybe::Ok(); +} + +/*static*/ Maybe MovingAverageMinMaxObserverOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe MovingAverageMinMaxObserverOp::GetSbp(user_op::SbpContext* ctx) { + // NOTE(Liang Depeng): all inputs need to be broadcast in order to accuratly calculate the + // global scale and zero_point + return Maybe::Ok(); +} + +/* static */ Maybe MovingAverageMinMaxObserverOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); + CHECK_OR_RETURN(in != nullptr); + in->set_requires_grad(false); + + user_op::InputArgModifier* current_train_step = GetInputArgModifierFn("current_train_step", 0); + CHECK_OR_RETURN(current_train_step != nullptr); + current_train_step->set_requires_grad(false); + + user_op::InputArgModifier* 
moving_max = GetInputArgModifierFn("moving_max", 0); + CHECK_OR_RETURN(moving_max != nullptr); + moving_max->set_requires_grad(false); + moving_max->set_is_mutable(true); + + user_op::InputArgModifier* moving_min = GetInputArgModifierFn("moving_min", 0); + CHECK_OR_RETURN(moving_min != nullptr); + moving_min->set_requires_grad(false); + moving_min->set_is_mutable(true); + return Maybe::Ok(); +} + +/* static */ Maybe MovingAverageMinMaxObserverOp::CheckAttr( + const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& op_conf) { + int32_t quantization_bit = op_conf.attr("quantization_bit"); + CHECK_GT_OR_RETURN(quantization_bit, 1); + CHECK_LE_OR_RETURN(quantization_bit, 8); + + std::string quantization_scheme = op_conf.attr("quantization_scheme"); + CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine"); + + int64_t stop_update_after_iters = op_conf.attr("stop_update_after_iters"); + CHECK_GT_OR_RETURN(stop_update_after_iters, 0); + + std::string quantization_formula = op_conf.attr("quantization_formula"); + CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon"); + return Maybe::Ok(); +} + +/* static */ Maybe MovingAverageMinMaxObserverOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("scale", 0) = ctx->InputDType("in", 0); + *ctx->OutputDType("zero_point", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/multiply_op.cpp b/oneflow/user/ops/multiply_op.cpp index 59e45557025..18d6fa26a44 100644 --- a/oneflow/user/ops/multiply_op.cpp +++ b/oneflow/user/ops/multiply_op.cpp @@ -14,48 +14,51 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("multiply") - .Input("x") - .Input("y") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(x.shape() == y.shape()); - *out->mut_shape() = x.shape(); - *out->mut_is_dynamic() = x.is_dynamic(); - if (x.is_dynamic() || y.is_dynamic()) { *out->mut_is_dynamic() = true; } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x.shape().NumAxes()) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("y", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(x.data_type() == y.data_type()); - *out->mut_data_type() = x.data_type(); - return Maybe::Ok(); - }); +/* static */ Maybe MultiplyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + 
const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(x.shape() == y.shape()); + *out->mut_shape() = x.shape(); + *out->mut_is_dynamic() = x.is_dynamic(); + if (x.is_dynamic() || y.is_dynamic()) { *out->mut_is_dynamic() = true; } + return Maybe::Ok(); +} + +/*static*/ Maybe MultiplyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe MultiplyOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x.shape().NumAxes()) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("y", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe MultiplyOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(x.data_type() == y.data_type()); + *out->mut_data_type() = x.data_type(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("multiply") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/narrow_op.cpp b/oneflow/user/ops/narrow_op.cpp index 0ca17e284ef..aebfd5a9262 100644 --- a/oneflow/user/ops/narrow_op.cpp +++ b/oneflow/user/ops/narrow_op.cpp @@ -14,125 +14,125 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("narrow") - .Input("in") - .Output("out") - .Attr("dim") - .Attr("start") - .Attr("length") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0); - const int64_t& dim = ctx->Attr("dim"); - const int64_t& start = ctx->Attr("start"); - const int64_t& length = ctx->Attr("length"); - CHECK_GE_OR_RETURN(dim, 0); - CHECK_GE_OR_RETURN(start, 0); - CHECK_GE_OR_RETURN(length, 0); - CHECK_GE_OR_RETURN(in.shape().At(dim), start + length); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - - DimVector dim_vec; - dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), - in.shape().dim_vec().cbegin() + dim); - dim_vec.insert(dim_vec.end(), length); - dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin() + dim + 1, - in.shape().dim_vec().end()); - *out->mut_shape() = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - const int64_t& dim = ctx->Attr("dim"); - const int64_t& length = ctx->Attr("length"); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - if (i != dim) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } else { - if (length == in_tensor.shape().At(i)) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - } - } +/* static */ Maybe NarrowOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0); + const int64_t& dim = ctx->Attr("dim"); + const int64_t& start = ctx->Attr("start"); + const int64_t& length = ctx->Attr("length"); + CHECK_GE_OR_RETURN(dim, 0); + CHECK_GE_OR_RETURN(start, 0); + CHECK_GE_OR_RETURN(length, 0); + CHECK_GE_OR_RETURN(in.shape().At(dim), start + length); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + + DimVector dim_vec; + dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), in.shape().dim_vec().cbegin() + dim); + dim_vec.insert(dim_vec.end(), length); + dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin() + dim + 1, + in.shape().dim_vec().end()); + *out->mut_shape() = Shape(dim_vec); + return Maybe::Ok(); +} + +/*static*/ Maybe NarrowOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NarrowOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + const int64_t& dim = ctx->Attr("dim"); + const int64_t& length = ctx->Attr("length"); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + if (i != dim) { ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i) .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = in.data_type(); - return Maybe::Ok(); - }); - -REGISTER_USER_OP("narrow_grad") - .Input("dy") - .Input("like") - 
.Output("dx") - .Attr("dim") - .Attr("start") - .Attr("length") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& like_shape = ctx->InputShape("like", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - const int64_t ndim = dy_shape.NumAxes(); - CHECK_EQ_OR_RETURN(like_shape.NumAxes(), ndim); - - *ctx->OutputShape("dx", 0) = like_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const Shape& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape(); - const int64_t ndim = like_shape.NumAxes(); - const int64_t& dim = ctx->Attr("dim"); - const int64_t& length = ctx->Attr("length"); - FOR_RANGE(int64_t, i, 0, ndim) { - if (i != dim) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } else { - if (length == like_shape.At(i)) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - } + } else { + if (length == in_tensor.shape().At(i)) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); } - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("dy", 0)) - .Broadcast(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("dx", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("dy", 0)) - .PartialSum(user_op::OpArg("like", 0)) - .Broadcast(user_op::OpArg("dx", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - user_op::InputArgModifier* dy_modifier = GetInputArgModifierFn("dy", 0); - CHECK_NOTNULL_OR_RETURN(dy_modifier); - dy_modifier->set_requires_grad(false); - user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); - CHECK_NOTNULL_OR_RETURN(like_modifier); - like_modifier->set_requires_grad(false); - return Maybe::Ok(); - }); + } + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe NarrowOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = in.data_type(); + return Maybe::Ok(); +} + +/* static */ Maybe NarrowGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& like_shape = ctx->InputShape("like", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + const int64_t ndim = dy_shape.NumAxes(); + CHECK_EQ_OR_RETURN(like_shape.NumAxes(), ndim); + + *ctx->OutputShape("dx", 0) = like_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe NarrowGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NarrowGradOp::GetSbp(user_op::SbpContext* ctx) { + const Shape& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape(); + const int64_t ndim = like_shape.NumAxes(); + const int64_t& dim = ctx->Attr("dim"); + const int64_t& length = ctx->Attr("length"); + FOR_RANGE(int64_t, i, 0, ndim) { + if (i != dim) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } else { + if (length == like_shape.At(i)) { + 
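// NOTE: the split-on-the-narrowed-axis signature is only registered when the narrow keeps the
// whole extent (length == like_shape.At(i)); in that case dy and dx have identical sizes on
// that axis, so a uniform split still lines up with the shard boundaries. For example, with a
// hypothetical like shape of (8, 4), dim = 0 and length = 8, S(0) is offered for dy/like/dx,
// while length = 4 would leave only the other signatures.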
ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + } + } + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("dy", 0)) + .Broadcast(user_op::OpArg("like", 0)) + .PartialSum(user_op::OpArg("dx", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("dy", 0)) + .PartialSum(user_op::OpArg("like", 0)) + .Broadcast(user_op::OpArg("dx", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe NarrowGradOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* dy_modifier = GetInputArgModifierFn("dy", 0); + CHECK_NOTNULL_OR_RETURN(dy_modifier); + dy_modifier->set_requires_grad(false); + user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); + CHECK_NOTNULL_OR_RETURN(like_modifier); + like_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe NarrowGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("narrow").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp b/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp index 397aea1f05b..a061187164e 100644 --- a/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp +++ b/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp @@ -16,200 +16,244 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("_nccl_logical_2D_same_dim0_all_reduce") - .Input("in") - .Output("out") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); - CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_partial_sum_parallel()); - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); - - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - // in use hint - in_distribution->CopyFrom(in_dis_hint); - - // out dim0 use hint - *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); - // out dim1 = broadcast - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_2D_same_dim1_all_reduce") - .Input("in") - .Output("out") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = 
ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); - CHECK_OR_RETURN(in_dis_hint.sbp_parallel(0).has_partial_sum_parallel()); - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); - - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - // in use hint - in_distribution->CopyFrom(in_dis_hint); - - // out dim0 = broadcast - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - // out dim1 use hint - *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(1); - - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_2D_same_dim0_all_gather") - .Input("in") - .Output("out") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); - // (*, S(0)) -> (*, B) - CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_split_parallel()); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel(1).split_parallel().axis(), 0); - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); - - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - // in use hint - in_distribution->CopyFrom(in_dis_hint); - - // out dim0 use hint - *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); - // out dim1 = broadcast - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_2D_same_dim0_all_gather_noncontinuous") - .Input("in") - .Output("out") - .Attr("in_dim1_split_axis", -1) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); - // 
(*, S(1)) -> (*, B) - const int64_t in_split_axis = ctx->user_op_conf().attr("in_dim1_split_axis"); - CHECK_GE_OR_RETURN(in_split_axis, 1); - CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_split_parallel()); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel(1).split_parallel().axis(), in_split_axis); - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); - - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - // in use hint - in_distribution->CopyFrom(in_dis_hint); - - // out dim0 use hint - *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); - // out dim1 = broadcast - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_2D_same_dim0_all2all") - .Input("in") - .Output("out") - .Attr("in_dim1_split_axis", -1) - .Attr("out_dim1_split_axis", -1) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); - // (*, S(in_dim1_split_axis)) -> (*, S(out_dim1_split_axis)) - const int64_t in_split_axis = ctx->user_op_conf().attr("in_dim1_split_axis"); - const int64_t out_split_axis = ctx->user_op_conf().attr("out_dim1_split_axis"); - CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_split_parallel()); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel(1).split_parallel().axis(), in_split_axis); - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); - - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - // in use hint - in_distribution->CopyFrom(in_dis_hint); - - // out dim0 use hint - *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); - // out dim1 = Split(out_split_axis) - out_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis); - - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 
0); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); + CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_partial_sum_parallel()); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); + + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + // in use hint + in_distribution->CopyFrom(in_dis_hint); + + // out dim0 use hint + *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); + // out dim1 = broadcast + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogical_2DSameDim0AllReduceOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); + CHECK_OR_RETURN(in_dis_hint.sbp_parallel(0).has_partial_sum_parallel()); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); + + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + // in use hint + in_distribution->CopyFrom(in_dis_hint); + + // out dim0 = broadcast + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + // out dim1 use hint + *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(1); + + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogical_2DSameDim1AllReduceOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); + // (*, S(0)) -> (*, B) + 
CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_split_parallel()); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel(1).split_parallel().axis(), 0); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); + + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + // in use hint + in_distribution->CopyFrom(in_dis_hint); + + // out dim0 use hint + *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); + // out dim1 = broadcast + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogical_2DSameDim0AllGatherOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::GetSbp( + user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); + // (*, S(1)) -> (*, B) + const int64_t in_split_axis = ctx->user_op_conf().attr("in_dim1_split_axis"); + CHECK_GE_OR_RETURN(in_split_axis, 1); + CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_split_parallel()); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel(1).split_parallel().axis(), in_split_axis); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); + + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + // in use hint + in_distribution->CopyFrom(in_dis_hint); + + // out dim0 use hint + *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); + // out dim1 = broadcast + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim0All2allOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0All2allOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static 
*/ Maybe _ncclLogical_2DSameDim0All2allOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); + // (*, S(in_dim1_split_axis)) -> (*, S(out_dim1_split_axis)) + const int64_t in_split_axis = ctx->user_op_conf().attr("in_dim1_split_axis"); + const int64_t out_split_axis = ctx->user_op_conf().attr("out_dim1_split_axis"); + CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_split_parallel()); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel(1).split_parallel().axis(), in_split_axis); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); + + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + // in use hint + in_distribution->CopyFrom(in_dis_hint); + + // out dim0 use hint + *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); + // out dim1 = Split(out_split_axis) + out_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis); + + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0All2allOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogical_2DSameDim0All2allOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/nccl_logical_ops.cpp b/oneflow/user/ops/nccl_logical_ops.cpp index 48915b59e06..ef0980024b9 100644 --- a/oneflow/user/ops/nccl_logical_ops.cpp +++ b/oneflow/user/ops/nccl_logical_ops.cpp @@ -16,197 +16,232 @@ limitations under the License. 
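The hunks above replace each `REGISTER_NO_GRAD_USER_OP("_nccl_logical_2D_*")` builder chain with static methods on an op class declared in the generated header (`op_generated.h`). As a reference for the pattern, here is a minimal sketch of the same five hooks for a hypothetical `FooCopyOp`; the class name and the template arguments (`Maybe<void>`, `Maybe<Symbol<Device>>`) are assumptions for illustration, not copied from the original source.

/* Sketch only: "FooCopyOp" is a hypothetical op class assumed to be emitted by
 * op_generated.h; template arguments are reconstructed assumptions. */
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/ops/comm_net_device_infer_util.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/* static */ Maybe<void> FooCopyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  // Output mirrors the input shape and dynamism, as in the ops above.
  *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
  *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0);
  return Maybe<void>::Ok();
}

/* static */ Maybe<void> FooCopyOp::GetSbp(user_op::SbpContext* ctx) {
  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);
}

/* static */ Maybe<void> FooCopyOp::InferDataType(user_op::InferContext* ctx) {
  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
  return Maybe<void>::Ok();
}

/* static */ Maybe<Symbol<Device>> FooCopyOp::InferDevice(user_op::DeviceInferContext* ctx) {
  return DeviceInferFn<&SyncLaunched>(ctx);
}

}  // namespace oneflow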
#include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("_nccl_logical_all_reduce") - .Input("in") - .Output("out") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_partial_sum_parallel()); - } - - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - - // P2B - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_distribution->add_sbp_parallel()->mutable_partial_sum_parallel(); - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - } - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_reduce_scatter") - .Input("in") - .Output("out") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_partial_sum_parallel()); - } - - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - - // P2S - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_distribution->add_sbp_parallel()->mutable_partial_sum_parallel(); - out_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); - } - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_all_gather") - .Input("in") - .Output("out") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - 
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_split_parallel()); - CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), 0); - } - - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - - // S(0)->B - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - } - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_all_gather_noncontinuous") - .Input("in") - .Output("out") - .Attr("in_split_axis", -1) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); - CHECK_GE_OR_RETURN(in_split_axis, 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_split_parallel()); - CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); - } - - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - - // S(1)->(B) - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - } - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_s2s") - .Input("in") - .Output("out") - .Attr("in_split_axis", -1) - .Attr("out_split_axis", -1) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* 
ctx) -> Maybe { - const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); - const int64_t out_split_axis = ctx->user_op_conf().attr("out_split_axis"); - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_split_parallel()); - CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); - } - - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - - // S(in)->S(out) - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); - out_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis); - } - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe _ncclLogicalAllReduceOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalAllReduceOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogicalAllReduceOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_partial_sum_parallel()); + } + + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + + // P2B + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_distribution->add_sbp_parallel()->mutable_partial_sum_parallel(); + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + } + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalAllReduceOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogicalAllReduceOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogicalReduceScatterOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalReduceScatterOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogicalReduceScatterOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_distribution = 
ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_partial_sum_parallel()); + } + + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + + // P2S + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_distribution->add_sbp_parallel()->mutable_partial_sum_parallel(); + out_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); + } + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalReduceScatterOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogicalReduceScatterOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogicalAllGatherOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalAllGatherOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogicalAllGatherOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_split_parallel()); + CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), 0); + } + + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + + // S(0)->B + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + } + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalAllGatherOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogicalAllGatherOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); + 
CHECK_GE_OR_RETURN(in_split_axis, 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_split_parallel()); + CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); + } + + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + + // S(1)->(B) + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + } + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogicalAllGatherNoncontinuousOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogicalS2sOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalS2sOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogicalS2sOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); + const int64_t out_split_axis = ctx->user_op_conf().attr("out_split_axis"); + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_split_parallel()); + CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); + } + + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + + // S(in)->S(out) + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); + out_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis); + } + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalS2sOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogicalS2sOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/nd_index_slice_ops.cpp b/oneflow/user/ops/nd_index_slice_ops.cpp index 2c51b18a98d..2fa17d2d390 100644 --- a/oneflow/user/ops/nd_index_slice_ops.cpp +++ b/oneflow/user/ops/nd_index_slice_ops.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
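In the `_nccl_logical_s2s` and `*_noncontinuous` hooks above, the split axes are read from op attributes; OneFlow's attribute accessors are templated on the attribute type. Below is a hedged sketch of how such a hook typically reads two `int64_t` attributes and builds the output NdSbp; the op name `FooS2sOp` is hypothetical and the `<int64_t>` template arguments are filled in as assumptions.

/* Sketch, not part of the patch. */
/* static */ Maybe<void> FooS2sOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {
  // Typed attribute access; the template argument selects the attr type.
  const int64_t in_split_axis = ctx->user_op_conf().attr<int64_t>("in_split_axis");
  const int64_t out_split_axis = ctx->user_op_conf().attr<int64_t>("out_split_axis");

  cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0);
  cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0);
  in_nd_sbp->clear_sbp_parallel();
  out_nd_sbp->clear_sbp_parallel();

  // S(in) -> S(out) on every axis of the parallel hierarchy, mirroring the op above.
  const Shape& hierarchy = ctx->parallel_hierarchy();
  for (int32_t i = 0; i < hierarchy.NumAxes(); ++i) {
    in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis);
    out_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis);
  }
  return Maybe<void>::Ok();
}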
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -112,175 +113,207 @@ Maybe GetTensorScatterNdOptSbpSignatures(user_op::SbpContext* ctx) { } // namespace -REGISTER_USER_OP("gather_nd") - .Input("params") - .Input("indices") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& params_shape = ctx->InputShape("params", 0); - const Shape& indices_shape = ctx->InputShape("indices", 0); - int64_t index_ndims = indices_shape.At(indices_shape.NumAxes() - 1); - CHECK_LE_OR_RETURN(index_ndims, params_shape.NumAxes()); - DimVector out_shape_vec(indices_shape.dim_vec().cbegin(), indices_shape.dim_vec().cend() - 1); - FOR_RANGE(int64_t, i, index_ndims, params_shape.NumAxes()) { - out_shape_vec.emplace_back(params_shape.At(i)); - } - *ctx->OutputShape("out", 0) = Shape(out_shape_vec); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK_OR_RETURN(indices_modifier != nullptr); - indices_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& params_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("params", 0); - const user_op::TensorDesc& indices_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); - int64_t indices_num_axes = indices_tensor.shape().NumAxes(); - FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("params", 0)) - .Split(user_op::OpArg("indices", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - int64_t index_ndims = indices_tensor.shape().At(indices_num_axes - 1); - FOR_RANGE(int64_t, i, index_ndims, params_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("params", 0), i) - .Broadcast(user_op::OpArg("indices", 0)) - .Split(user_op::OpArg("out", 0), i - index_ndims + indices_num_axes - 1) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("params", 0)) - .Broadcast(user_op::OpArg("indices", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("params", 0); - return Maybe::Ok(); - }); +/* static */ Maybe GatherNdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& params_shape = ctx->InputShape("params", 0); + const Shape& indices_shape = ctx->InputShape("indices", 0); + int64_t index_ndims = indices_shape.At(indices_shape.NumAxes() - 1); + CHECK_LE_OR_RETURN(index_ndims, params_shape.NumAxes()); + DimVector out_shape_vec(indices_shape.dim_vec().cbegin(), indices_shape.dim_vec().cend() - 1); + FOR_RANGE(int64_t, i, index_ndims, params_shape.NumAxes()) { + out_shape_vec.emplace_back(params_shape.At(i)); + } + *ctx->OutputShape("out", 0) = Shape(out_shape_vec); + return Maybe::Ok(); +} -REGISTER_USER_OP("scatter_nd") - .Input("indices") - .Input("updates") - .Output("out") - .Attr("shape") - .SetTensorDescInferFn(InferScatterNdTensorDesc) - .SetDataTypeInferFn(InferScatterNdDataType) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK_OR_RETURN(indices_modifier 
!= nullptr); - indices_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& indices_desc = - ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); - int64_t indices_num_axes = indices_desc.shape().NumAxes(); - FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { - ctx->NewBuilder() - .Split(user_op::OpArg("indices", 0), i) - .Split(user_op::OpArg("updates", 0), i) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - } - const Shape& out_shape = ctx->Attr("shape"); - int64_t index_ndims = indices_desc.shape().At(indices_num_axes - 1); - int64_t slice_ndims = out_shape.NumAxes() - index_ndims; - FOR_RANGE(int64_t, i, 0, slice_ndims) { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("indices", 0)) - .Split(user_op::OpArg("updates", 0), i + indices_num_axes - 1) - .Split(user_op::OpArg("out", 0), i + index_ndims) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("updates", 0)) - .Broadcast(user_op::OpArg("indices", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe GatherNdOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("scatter_nd_like") - .Input("like") - .Input("indices") - .Input("updates") - .Output("out") - .SetTensorDescInferFn(InferScatterNdLikeTensorDesc) - .SetDataTypeInferFn(InferScatterNdLikeDataType) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& indices_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); - int64_t indices_num_axes = indices_tensor.shape().NumAxes(); - FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("like", 0)) - .Split(user_op::OpArg("indices", 0), i) - .Split(user_op::OpArg("updates", 0), i) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - } - const Shape& out_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape(); - int64_t index_ndims = indices_tensor.shape().At(indices_num_axes - 1); - int64_t slice_ndims = out_shape.NumAxes() - index_ndims; - FOR_RANGE(int64_t, i, 0, slice_ndims) { - ctx->NewBuilder() - .Split(user_op::OpArg("like", 0), i + index_ndims) - .Broadcast(user_op::OpArg("indices", 0)) - .Split(user_op::OpArg("updates", 0), i + indices_num_axes - 1) - .Split(user_op::OpArg("out", 0), i + index_ndims) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("updates", 0)) - .Broadcast(user_op::OpArg("indices", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }); +/* static */ Maybe GatherNdOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& params_tensor = + ctx->LogicalTensorDesc4InputArgNameAndIndex("params", 0); + const user_op::TensorDesc& indices_tensor = + ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); + int64_t indices_num_axes = indices_tensor.shape().NumAxes(); + FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("params", 0)) + .Split(user_op::OpArg("indices", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + int64_t index_ndims = indices_tensor.shape().At(indices_num_axes - 1); + FOR_RANGE(int64_t, i, index_ndims, params_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("params", 0), i) + .Broadcast(user_op::OpArg("indices", 0)) + .Split(user_op::OpArg("out", 0), i - index_ndims + indices_num_axes - 1) + 
.Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("params", 0)) + .Broadcast(user_op::OpArg("indices", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} -REGISTER_USER_OP("tensor_scatter_nd_update") - .Input("params") - .Input("updates") - .Input("indices") - .Output("out") - .SetTensorDescInferFn(InferTensorScatterNdOptTensorDesc) - .SetDataTypeInferFn(InferTensorScatterNdOptDataType) - .SetGetSbpFn(GetTensorScatterNdOptSbpSignatures) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK_OR_RETURN(indices_modifier != nullptr); - indices_modifier->set_requires_grad(false); - return Maybe::Ok(); - }); +/* static */ Maybe GatherNdOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); + CHECK_OR_RETURN(indices_modifier != nullptr); + indices_modifier->set_requires_grad(false); + return Maybe::Ok(); +} -REGISTER_USER_OP("tensor_scatter_nd_add") - .Input("params") - .Input("updates") - .Input("indices") - .Output("out") - .SetTensorDescInferFn(InferTensorScatterNdOptTensorDesc) - .SetDataTypeInferFn(InferTensorScatterNdOptDataType) - .SetGetSbpFn(GetTensorScatterNdOptSbpSignatures) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK_OR_RETURN(indices_modifier != nullptr); - indices_modifier->set_requires_grad(false); - return Maybe::Ok(); - }); +/* static */ Maybe GatherNdOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("params", 0); + return Maybe::Ok(); +} + +/* static */ Maybe ScatterNdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferScatterNdTensorDesc(ctx); +} + +/*static*/ Maybe ScatterNdOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ScatterNdOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& indices_desc = + ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); + int64_t indices_num_axes = indices_desc.shape().NumAxes(); + FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { + ctx->NewBuilder() + .Split(user_op::OpArg("indices", 0), i) + .Split(user_op::OpArg("updates", 0), i) + .Broadcast(user_op::OpArg("out", 0)) + .Build(); + } + const Shape& out_shape = ctx->Attr("shape"); + int64_t index_ndims = indices_desc.shape().At(indices_num_axes - 1); + int64_t slice_ndims = out_shape.NumAxes() - index_ndims; + FOR_RANGE(int64_t, i, 0, slice_ndims) { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("indices", 0)) + .Split(user_op::OpArg("updates", 0), i + indices_num_axes - 1) + .Split(user_op::OpArg("out", 0), i + index_ndims) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("updates", 0)) + .Broadcast(user_op::OpArg("indices", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ScatterNdOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); + CHECK_OR_RETURN(indices_modifier != nullptr); + 
indices_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe ScatterNdOp::InferDataType(user_op::InferContext* ctx) { + return InferScatterNdDataType(ctx); +} + +/* static */ Maybe ScatterNdLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferScatterNdLikeTensorDesc(ctx); +} + +/*static*/ Maybe ScatterNdLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ScatterNdLikeOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& indices_tensor = + ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); + int64_t indices_num_axes = indices_tensor.shape().NumAxes(); + FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("like", 0)) + .Split(user_op::OpArg("indices", 0), i) + .Split(user_op::OpArg("updates", 0), i) + .Broadcast(user_op::OpArg("out", 0)) + .Build(); + } + const Shape& out_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape(); + int64_t index_ndims = indices_tensor.shape().At(indices_num_axes - 1); + int64_t slice_ndims = out_shape.NumAxes() - index_ndims; + FOR_RANGE(int64_t, i, 0, slice_ndims) { + ctx->NewBuilder() + .Split(user_op::OpArg("like", 0), i + index_ndims) + .Broadcast(user_op::OpArg("indices", 0)) + .Split(user_op::OpArg("updates", 0), i + indices_num_axes - 1) + .Split(user_op::OpArg("out", 0), i + index_ndims) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("like", 0)) + .PartialSum(user_op::OpArg("updates", 0)) + .Broadcast(user_op::OpArg("indices", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ScatterNdLikeOp::InferDataType(user_op::InferContext* ctx) { + return InferScatterNdLikeDataType(ctx); +} + +/* static */ Maybe TensorScatterNdUpdateOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferTensorScatterNdOptTensorDesc(ctx); +} + +/*static*/ Maybe TensorScatterNdUpdateOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe TensorScatterNdUpdateOp::GetSbp(user_op::SbpContext* ctx) { + return GetTensorScatterNdOptSbpSignatures(ctx); +} + +/* static */ Maybe TensorScatterNdUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); + CHECK_OR_RETURN(indices_modifier != nullptr); + indices_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe TensorScatterNdUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferTensorScatterNdOptDataType(ctx); +} + +/* static */ Maybe TensorScatterNdAddOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorScatterNdOptTensorDesc(ctx); +} + +/*static*/ Maybe TensorScatterNdAddOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe TensorScatterNdAddOp::GetSbp(user_op::SbpContext* ctx) { + return GetTensorScatterNdOptSbpSignatures(ctx); +} + +/* static */ Maybe TensorScatterNdAddOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); + CHECK_OR_RETURN(indices_modifier != nullptr); + indices_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe 
TensorScatterNdAddOp::InferDataType(user_op::InferContext* ctx) { + return InferTensorScatterNdOptDataType(ctx); +} REGISTER_USER_OP_GRAD("gather_nd") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/nll_op.cpp b/oneflow/user/ops/nll_op.cpp index a3343bcbad7..b170194aff4 100644 --- a/oneflow/user/ops/nll_op.cpp +++ b/oneflow/user/ops/nll_op.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/loss_op_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -44,7 +45,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { +Maybe NllInferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_OR_RETURN(IsIndexDataType(target_desc.data_type())); @@ -88,39 +89,51 @@ Maybe InferGradDataType(user_op::InferContext* ctx) { } } // namespace -REGISTER_USER_OP("nll") - .Input("input") - .Input("target") - .OptionalInput("weight") - .Output("out") - .Output("total_weight") - .Attr("ignore_index") - .SetTensorDescInferFn(InferTensorDescFn) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); - CHECK_OR_RETURN(target_modifier != nullptr); - target_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn(InferDataType) - .SetGetSbpFn(GenLossForwardDefaultGetSbpFn([](user_op::UserOpSbpSignatureBuilder& builder) { - builder.PartialSum(user_op::OpArg("total_weight", 0)); - })); - -REGISTER_USER_OP("nll_grad") - .Input("input") - .Input("target") - .Input("total_weight") - .OptionalInput("weight") - .Input("dy") - .Output("dx") - .Attr("ignore_index") - .SetTensorDescInferFn(InferGradTensorDescFn) - .SetDataTypeInferFn(InferGradDataType) - .SetGetSbpFn(GenLossBackwardDefaultGetSbpFn([](user_op::UserOpSbpSignatureBuilder& builder) { - builder.PartialSum(user_op::OpArg("total_weight", 0)); - })); +/* static */ Maybe NllOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDescFn(ctx); +} + +/*static*/ Maybe NllOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NllOp::GetSbp(user_op::SbpContext* ctx) { + return GenLossForwardDefaultGetSbpFn( + [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) { + builder.PartialSum(user_op::OpArg("total_weight", 0)); + })(ctx); +} + +/* static */ Maybe NllOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); + CHECK_OR_RETURN(target_modifier != nullptr); + target_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe NllOp::InferDataType(user_op::InferContext* ctx) { + return NllInferDataType(ctx); +} + +/* static */ Maybe NllGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferGradTensorDescFn(ctx); +} + +/*static*/ Maybe NllGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NllGradOp::GetSbp(user_op::SbpContext* ctx) { + return GenLossBackwardDefaultGetSbpFn( + [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) { 
+ builder.PartialSum(user_op::OpArg("total_weight", 0)); + })(ctx); +} + +/* static */ Maybe NllGradOp::InferDataType(user_op::InferContext* ctx) { + return InferGradDataType(ctx); +} REGISTER_USER_OP_GRAD("nll").SetGenBackwardOpConfFn( [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe { diff --git a/oneflow/user/ops/nms_op.cpp b/oneflow/user/ops/nms_op.cpp index 20dcbbf5901..1d9c0e29537 100644 --- a/oneflow/user/ops/nms_op.cpp +++ b/oneflow/user/ops/nms_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -31,13 +32,20 @@ Maybe InferNmsDataType(user_op::InferContext* ctx) { } // namespace -REGISTER_USER_OP("nms") - .Input("in") - .Output("out") - .Attr("iou_threshold") - .Attr("keep_n") - .SetTensorDescInferFn(InferNmsTensorDesc) - .SetDataTypeInferFn(InferNmsDataType) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe NmsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferNmsTensorDesc(ctx); +} + +/*static*/ Maybe NmsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NmsOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe NmsOp::InferDataType(user_op::InferContext* ctx) { + return InferNmsDataType(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/normalization_op.cpp b/oneflow/user/ops/normalization_op.cpp index df95d677b95..2444ca274a4 100644 --- a/oneflow/user/ops/normalization_op.cpp +++ b/oneflow/user/ops/normalization_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
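One detail in the nll_op.cpp hunk above: the anonymous-namespace helper formerly named `InferDataType` is renamed to `NllInferDataType` because the generated `NllOp` class now declares a static member of the same name, and an unqualified call inside that member would resolve to the member itself and recurse instead of reaching the helper. A small standalone illustration of the lookup issue follows; the names are hypothetical.

// Standalone sketch of the name-lookup issue behind the rename.
#include <iostream>

namespace demo {

int InferThing() { return 42; }  // free helper, analogous to the old InferDataType

struct FooOp {
  static int InferThing() {
    // An unqualified InferThing() here would name FooOp::InferThing and recurse,
    // so the helper is either renamed or called with an explicit qualification.
    return ::demo::InferThing() + 1;
  }
};

}  // namespace demo

int main() {
  std::cout << demo::FooOp::InferThing() << '\n';  // prints 43
  return 0;
}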
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" #ifdef WITH_CUDA #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/device/cudnn_util.h" @@ -212,77 +213,82 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() { user_op::TensorDesc* reserve_space)>()); } -REGISTER_USER_OP("normalization") - .Input("x") - .OptionalInput("moving_mean") - .OptionalInput("moving_variance") - .Input("gamma") - .Input("beta") - .OptionalInput("_add_to_output") - .Output("y") - .OptionalOutput("mean") - .OptionalOutput("inv_variance") - .Attr("axis") - .Attr("epsilon") - .Attr("training") - .Attr("momentum") - .SetInputArgModifyFn(FwInputArgModifyFn) - .SetTensorDescInferFn(MakeFwTensorDescInferFn()) - .SetGetSbpFn(FwGetSbpFn) - .SetDataTypeInferFn(MakeFwDataTypeInferFn()); - -REGISTER_USER_OP("normalization_add_relu") - .Input("x") - .OptionalInput("addend") - .OptionalInput("moving_mean") - .OptionalInput("moving_variance") - .Input("gamma") - .Input("beta") - .Output("y") - .Output("reserve_space") - .OptionalOutput("mean") - .OptionalOutput("inv_variance") - .Attr("axis") - .Attr("epsilon") - .Attr("training") - .Attr("momentum") - .SetInputArgModifyFn(FwInputArgModifyFn) - .SetLogicalTensorDescInferFn( - MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, - user_op::TensorDesc* reserve_space) -> Maybe { - const auto& x_desc = ctx->InputTensorDesc("x", 0); - size_t reserve_space_bits = x_desc.shape().elem_cnt(); - int64_t parallel_num = ctx->parallel_num(); - if (parallel_num != 1) { - // There no need to call SbpParallel4ArgNameAndIndex when parallel_num = 1 in local. - const cfg::SbpParallel& x_sbp = ctx->SbpParallel4ArgNameAndIndex("x", 0); - if (x_sbp.has_split_parallel()) { - CHECK_EQ_OR_RETURN(x_sbp.split_parallel().axis(), 0); - reserve_space_bits = reserve_space_bits / ctx->parallel_num(); - } - } - *reserve_space->mut_shape() = - Shape({static_cast(RoundUp(reserve_space_bits, 32) / 32)}); - return Maybe::Ok(); - })) - .SetPhysicalTensorDescInferFn( - MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, - user_op::TensorDesc* reserve_space) -> Maybe { - const auto& x_desc = ctx->InputTensorDesc("x", 0); - *reserve_space->mut_shape() = - Shape({static_cast(RoundUp(x_desc.shape().elem_cnt(), 32) / 32)}); - return Maybe::Ok(); - })) - .SetGetSbpFn(FwGetSbpFn) - .SetDataTypeInferFn( - MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, - user_op::TensorDesc* reserve_space) -> Maybe { - *reserve_space->mut_data_type() = DataType::kInt32; - return Maybe::Ok(); - })); +} // namespace + +/* static */ Maybe NormalizationOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return MakeFwTensorDescInferFn()(ctx); +} + +/*static*/ Maybe NormalizationOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NormalizationOp::GetSbp(user_op::SbpContext* ctx) { + return FwGetSbpFn(ctx); +} + +/* static */ Maybe NormalizationOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return FwInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe NormalizationOp::InferDataType(user_op::InferContext* ctx) { + return MakeFwDataTypeInferFn()(ctx); +} + +/* static */ Maybe NormalizationAddReluOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const 
user_op::TensorDesc* x, + user_op::TensorDesc* reserve_space) -> Maybe { + const auto& x_desc = ctx->InputTensorDesc("x", 0); + size_t reserve_space_bits = x_desc.shape().elem_cnt(); + int64_t parallel_num = ctx->parallel_num(); + if (parallel_num != 1) { + // There no need to call SbpParallel4ArgNameAndIndex when parallel_num = 1 in local. + const cfg::SbpParallel& x_sbp = ctx->SbpParallel4ArgNameAndIndex("x", 0); + if (x_sbp.has_split_parallel()) { + CHECK_EQ_OR_RETURN(x_sbp.split_parallel().axis(), 0); + reserve_space_bits = reserve_space_bits / ctx->parallel_num(); + } + } + *reserve_space->mut_shape() = + Shape({static_cast(RoundUp(reserve_space_bits, 32) / 32)}); + return Maybe::Ok(); + })(ctx); +} + +/* static */ Maybe NormalizationAddReluOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, + user_op::TensorDesc* reserve_space) -> Maybe { + const auto& x_desc = ctx->InputTensorDesc("x", 0); + *reserve_space->mut_shape() = + Shape({static_cast(RoundUp(x_desc.shape().elem_cnt(), 32) / 32)}); + return Maybe::Ok(); + })(ctx); +} + +/* static */ Maybe NormalizationAddReluOp::GetSbp(user_op::SbpContext* ctx) { + return FwGetSbpFn(ctx); +} + +/* static */ Maybe NormalizationAddReluOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return FwInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe NormalizationAddReluOp::InferDataType(user_op::InferContext* ctx) { + return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, + user_op::TensorDesc* reserve_space) -> Maybe { + *reserve_space->mut_data_type() = DataType::kInt32; + return Maybe::Ok(); + })(ctx); +} #if defined(WITH_CUDA) && (CUDNN_VERSION >= 7401) +namespace { + void InferCudnnReserveSpaceSize(DataType data_type, cudnnBatchNormOps_t ops, int64_t n, int64_t c, int64_t h, int64_t w, size_t* reserve_space_size) { cudnnHandle_t cudnn_handle; @@ -295,79 +301,110 @@ void InferCudnnReserveSpaceSize(DataType data_type, cudnnBatchNormOps_t ops, int OF_CUDNN_CHECK(cudnnDestroy(cudnn_handle)); } -REGISTER_USER_OP("cudnn_fused_normalization_add_relu") - .Input("x") - .OptionalInput("addend") - .OptionalInput("moving_mean") - .OptionalInput("moving_variance") - .Input("gamma") - .Input("beta") - .Output("y") - .Output("reserve_space") - .OptionalOutput("mean") - .OptionalOutput("inv_variance") - .Attr("axis") - .Attr("epsilon") - .Attr("momentum") - .SetInputArgModifyFn(FwInputArgModifyFn) - .SetLogicalTensorDescInferFn( - MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, - user_op::TensorDesc* reserve_space) -> Maybe { - const Shape& x_shape = x->shape(); - const auto axis = ctx->Attr("axis"); - CHECK_EQ_OR_RETURN(x_shape.Count(axis + 1), 1); - int64_t n = x_shape.At(0); - int64_t h = x_shape.Count(1, axis); - int64_t w = 1; - int64_t c = x_shape.At(axis); - const auto& x_sbp = ctx->SbpParallel4ArgNameAndIndex("x", 0); - if (x_sbp.has_split_parallel()) { - CHECK_EQ_OR_RETURN(x_sbp.split_parallel().axis(), 0); - n = n / ctx->parallel_num(); - } - cudnnBatchNormOps_t ops; - if (ctx->has_input("addend", 0)) { - ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; - } else { - ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; - } - size_t reserve_space_size; - InferCudnnReserveSpaceSize(x->data_type(), ops, n, c, h, w, &reserve_space_size); - reserve_space_size = std::max(reserve_space_size, GetOneVal()); - 
*reserve_space->mut_shape() = Shape({static_cast(reserve_space_size)}); - return Maybe::Ok(); - })) - .SetPhysicalTensorDescInferFn( - MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, - user_op::TensorDesc* reserve_space) -> Maybe { - const Shape& x_shape = x->shape(); - const auto axis = ctx->Attr("axis"); - CHECK_EQ_OR_RETURN(x_shape.Count(axis + 1), 1); - int64_t n = x_shape.At(0); - int64_t h = x_shape.Count(1, axis); - int64_t w = 1; - int64_t c = x_shape.At(axis); - cudnnBatchNormOps_t ops; - if (ctx->has_input("addend", 0)) { - ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; - } else { - ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; - } - size_t reserve_space_size; - InferCudnnReserveSpaceSize(x->data_type(), ops, n, c, h, w, &reserve_space_size); - reserve_space_size = std::max(reserve_space_size, GetOneVal()); - *reserve_space->mut_shape() = Shape({static_cast(reserve_space_size)}); - return Maybe::Ok(); - })) - .SetGetSbpFn(FwGetSbpFn) - .SetDataTypeInferFn( - MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, - user_op::TensorDesc* reserve_space) -> Maybe { - *reserve_space->mut_data_type() = DataType::kChar; - return Maybe::Ok(); - })); +} // namespace -#endif +/* static */ Maybe CudnnFusedNormalizationAddReluOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, + user_op::TensorDesc* reserve_space) -> Maybe { + const Shape& x_shape = x->shape(); + const auto axis = ctx->Attr("axis"); + CHECK_EQ_OR_RETURN(x_shape.Count(axis + 1), 1); + int64_t n = x_shape.At(0); + int64_t h = x_shape.Count(1, axis); + int64_t w = 1; + int64_t c = x_shape.At(axis); + const auto& x_sbp = ctx->SbpParallel4ArgNameAndIndex("x", 0); + if (x_sbp.has_split_parallel()) { + CHECK_EQ_OR_RETURN(x_sbp.split_parallel().axis(), 0); + n = n / ctx->parallel_num(); + } + cudnnBatchNormOps_t ops; + if (ctx->has_input("addend", 0)) { + ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; + } else { + ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; + } + size_t reserve_space_size; + InferCudnnReserveSpaceSize(x->data_type(), ops, n, c, h, w, &reserve_space_size); + reserve_space_size = std::max(reserve_space_size, GetOneVal()); + *reserve_space->mut_shape() = Shape({static_cast(reserve_space_size)}); + return Maybe::Ok(); + })(ctx); +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, + user_op::TensorDesc* reserve_space) -> Maybe { + const Shape& x_shape = x->shape(); + const auto axis = ctx->Attr("axis"); + CHECK_EQ_OR_RETURN(x_shape.Count(axis + 1), 1); + int64_t n = x_shape.At(0); + int64_t h = x_shape.Count(1, axis); + int64_t w = 1; + int64_t c = x_shape.At(axis); + cudnnBatchNormOps_t ops; + if (ctx->has_input("addend", 0)) { + ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; + } else { + ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; + } + size_t reserve_space_size; + InferCudnnReserveSpaceSize(x->data_type(), ops, n, c, h, w, &reserve_space_size); + reserve_space_size = std::max(reserve_space_size, GetOneVal()); + *reserve_space->mut_shape() = Shape({static_cast(reserve_space_size)}); + return Maybe::Ok(); + })(ctx); +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::GetSbp(user_op::SbpContext* ctx) { + return FwGetSbpFn(ctx); +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::ModifyInputArg( + const 
GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return FwInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::InferDataType( + user_op::InferContext* ctx) { + return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, + user_op::TensorDesc* reserve_space) -> Maybe { + *reserve_space->mut_data_type() = DataType::kChar; + return Maybe::Ok(); + })(ctx); +} + +#else + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::GetSbp(user_op::SbpContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::InferDataType( + user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +#endif // WITH_CUDA + +namespace { Maybe BwTensorDescInferFn(user_op::InferContext* ctx) { #ifdef WITH_CUDA @@ -447,60 +484,83 @@ Maybe BwGetSbpFn(user_op::SbpContext* ctx) { return Maybe::Ok(); } -REGISTER_USER_OP("normalization_grad") - .Input("x") - .Input("dy") - .Input("mean") - .Input("inv_variance") - .Input("gamma") - .Output("gamma_diff") - .Output("beta_diff") - .Output("dx") - .Attr("axis") - .Attr("epsilon") - .SetTensorDescInferFn(BwTensorDescInferFn) - .SetGetSbpFn(BwGetSbpFn) - .SetDataTypeInferFn(BwDataTypeInferFn); - -REGISTER_USER_OP("normalization_add_relu_grad") - .Input("x") - .Input("dy") - .Input("mean") - .Input("inv_variance") - .Input("gamma") - .Input("beta") - .Input("reserve_space") - .Input("y") - .Output("gamma_diff") - .Output("beta_diff") - .Output("dx") - .OptionalOutput("addend_diff") - .Attr("axis") - .Attr("epsilon") - .SetTensorDescInferFn(BwTensorDescInferFn) - .SetGetSbpFn(BwGetSbpFn) - .SetDataTypeInferFn(BwDataTypeInferFn); +} // namespace + +/* static */ Maybe NormalizationGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return BwTensorDescInferFn(ctx); +} + +/*static*/ Maybe NormalizationGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NormalizationGradOp::GetSbp(user_op::SbpContext* ctx) { + return BwGetSbpFn(ctx); +} + +/* static */ Maybe NormalizationGradOp::InferDataType(user_op::InferContext* ctx) { + return BwDataTypeInferFn(ctx); +} + +/* static */ Maybe NormalizationAddReluGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return BwTensorDescInferFn(ctx); +} + +/*static*/ Maybe NormalizationAddReluGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NormalizationAddReluGradOp::GetSbp(user_op::SbpContext* ctx) { + return BwGetSbpFn(ctx); +} + +/* static */ Maybe NormalizationAddReluGradOp::InferDataType(user_op::InferContext* ctx) { + return BwDataTypeInferFn(ctx); +} #if defined(WITH_CUDA) && (CUDNN_VERSION >= 7401) 
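// ---------------------------------------------------------------------------
// Editor's note, not part of the patch: the hunks in this patch all follow one
// mechanical pattern, so a minimal sketch may help. It assumes a hypothetical
// "ExampleOp" whose class definition is taken to be emitted into
// oneflow/core/framework/op_generated.h; the op name and the bodies below are
// illustrative only, not an op added by this change.
// The old style registered everything through a chained macro, e.g.
//   REGISTER_USER_OP("example")
//       .Input("in")
//       .Output("out")
//       .SetTensorDescInferFn(...)
//       .SetGetSbpFn(...)
//       .SetDataTypeInferFn(...);
// The new style implements the static hooks of the generated class directly:
/* static */ Maybe<void> ExampleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  // Shape inference that used to live in SetTensorDescInferFn.
  *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
  return Maybe<void>::Ok();
}
/* static */ Maybe<void> ExampleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  // Most ops in this patch simply reuse the logical inference per rank.
  return InferLogicalTensorDesc(ctx);
}
/* static */ Maybe<void> ExampleOp::GetSbp(user_op::SbpContext* ctx) {
  // SBP signatures that used to live in SetGetSbpFn.
  ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Split(user_op::OpArg("out", 0), 0).Build();
  return Maybe<void>::Ok();
}
/* static */ Maybe<void> ExampleOp::InferDataType(user_op::InferContext* ctx) {
  // Data type inference that used to live in SetDataTypeInferFn.
  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
  return Maybe<void>::Ok();
}
// Ops that exist only with CUDA keep the real bodies inside the
// WITH_CUDA / CUDNN_VERSION guard and stub the same methods with
// Error::UnimplementedError() in the #else branch, as the surrounding hunks do.
// ---------------------------------------------------------------------------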
-REGISTER_USER_OP("cudnn_fused_normalization_add_relu_grad") - .Input("x") - .Input("dy") - .Input("mean") - .Input("inv_variance") - .Input("gamma") - .Input("beta") - .Input("reserve_space") - .Input("y") - .Output("gamma_diff") - .Output("beta_diff") - .Output("dx") - .OptionalOutput("addend_diff") - .Attr("axis") - .Attr("epsilon") - .SetTensorDescInferFn(BwTensorDescInferFn) - .SetGetSbpFn(BwGetSbpFn) - .SetDataTypeInferFn(BwDataTypeInferFn); +/* static */ Maybe CudnnFusedNormalizationAddReluGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return BwTensorDescInferFn(ctx); +} + +/*static*/ Maybe CudnnFusedNormalizationAddReluGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CudnnFusedNormalizationAddReluGradOp::GetSbp(user_op::SbpContext* ctx) { + return BwGetSbpFn(ctx); +} + +/* static */ Maybe CudnnFusedNormalizationAddReluGradOp::InferDataType( + user_op::InferContext* ctx) { + return BwDataTypeInferFn(ctx); +} + +#else + +/* static */ Maybe CudnnFusedNormalizationAddReluGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +/*static*/ Maybe CudnnFusedNormalizationAddReluGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +/* static */ Maybe CudnnFusedNormalizationAddReluGradOp::GetSbp(user_op::SbpContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +/* static */ Maybe CudnnFusedNormalizationAddReluGradOp::InferDataType( + user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} #endif @@ -709,6 +769,4 @@ REGISTER_USER_OP_GRAD("normalization_add_relu") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/nvtx_range_op.cpp b/oneflow/user/ops/nvtx_range_op.cpp index e0a222a7bd5..0f2bd54b2e6 100644 --- a/oneflow/user/ops/nvtx_range_op.cpp +++ b/oneflow/user/ops/nvtx_range_op.cpp @@ -13,67 +13,103 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ + #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +#ifdef WITH_CUDA -REGISTER_USER_OP("nvtx_start") - .Input("in") - .Output("out") - .Attr("mark_prefix") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/* static */ Maybe NvtxStartOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("nvtx_end") - .Input("in") - .Output("out") - .Attr("mark_prefix") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe NvtxStartOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NvtxStartOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe NvtxStartOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe NvtxEndOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe NvtxEndOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NvtxEndOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + 
ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe NvtxEndOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +#else + +/* static */ Maybe NvtxStartOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA to use NVTX"; +} + +/*static*/ Maybe NvtxStartOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NvtxStartOp::GetSbp(user_op::SbpContext* ctx) { + return Error::UnimplementedError() << "require CUDA to use NVTX"; +} + +/* static */ Maybe NvtxStartOp::InferDataType(user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA to use NVTX"; +} + +/* static */ Maybe NvtxEndOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA to use NVTX"; +} + +/*static*/ Maybe NvtxEndOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA to use NVTX"; +} + +/* static */ Maybe NvtxEndOp::GetSbp(user_op::SbpContext* ctx) { + return Error::UnimplementedError() << "require CUDA to use NVTX"; +} + +/* static */ Maybe NvtxEndOp::InferDataType(user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA to use NVTX"; +} + +#endif // WITH_CUDA REGISTER_USER_OP_GRAD("nvtx_start") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -108,6 +144,5 @@ REGISTER_USER_OP_GRAD("nvtx_end") } return Maybe::Ok(); }); -} // namespace } // namespace oneflow diff --git a/oneflow/user/ops/ofrecord_decoder_ops.cpp b/oneflow/user/ops/ofrecord_decoder_ops.cpp index ff19fe78952..02ccf542062 100644 --- a/oneflow/user/ops/ofrecord_decoder_ops.cpp +++ b/oneflow/user/ops/ofrecord_decoder_ops.cpp @@ -15,148 +15,148 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_raw_decoder") - .Input("in") - .Output("out") - .Attr("name") - .Attr("shape") - .Attr("data_type") - .Attr("dim1_varying_length", false) - .Attr("truncate", false) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); - Shape conf_shape = ctx->Attr("shape"); - DimVector dim_vec(1 + conf_shape.NumAxes()); - dim_vec[0] = in_tensor.shape().At(0); - for (int i = 1; i < dim_vec.size(); ++i) { dim_vec[i] = conf_shape.At(i - 1); } - *out_tensor->mut_shape() = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL_OR_RETURN(in_modifier); - in_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Split(user_op::OpArg("out", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); - *out_tensor->mut_data_type() = ctx->Attr("data_type"); - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_bytes_decoder") - .Input("in") - .Output("out") - .Attr("name") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_is_dynamic() = in.is_dynamic(); - *out->mut_shape() = in.shape(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL_OR_RETURN(in_modifier); - in_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in.data_type() == DataType::kOFRecord); - *out->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_image_decoder") - .Input("in") - .Output("out") - .Attr("name") - .Attr("color_space", "BGR") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); - *out_tensor->mut_shape() = in_tensor.shape(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - 
user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL_OR_RETURN(in_modifier); - in_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Split(user_op::OpArg("out", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); - *out_tensor->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_image_decoder_random_crop") - .Input("in") - .Output("out") - .Attr("name") - .Attr("color_space", "BGR") - .Attr("num_attempts", 10) - .Attr("seed", -1) - .Attr("has_seed", false) - .Attr>("random_area", {0.08, 1.0}) - .Attr>("random_aspect_ratio", {0.75, 1.333333}) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); - *out_tensor->mut_shape() = in_tensor.shape(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Split(user_op::OpArg("out", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL_OR_RETURN(in_modifier); - in_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); - *out_tensor->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }); +/* static */ Maybe OfrecordRawDecoderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); + Shape conf_shape = ctx->Attr("shape"); + DimVector dim_vec(1 + conf_shape.NumAxes()); + dim_vec[0] = in_tensor.shape().At(0); + for (int i = 1; i < dim_vec.size(); ++i) { dim_vec[i] = conf_shape.At(i - 1); } + *out_tensor->mut_shape() = Shape(dim_vec); + return Maybe::Ok(); +} + +/*static*/ Maybe OfrecordRawDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe OfrecordRawDecoderOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Split(user_op::OpArg("out", 0), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordRawDecoderOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); + CHECK_NOTNULL_OR_RETURN(in_modifier); + in_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ 
Maybe OfrecordRawDecoderOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); + *out_tensor->mut_data_type() = ctx->Attr("data_type"); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordBytesDecoderOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_is_dynamic() = in.is_dynamic(); + *out->mut_shape() = in.shape(); + return Maybe::Ok(); +} + +/*static*/ Maybe OfrecordBytesDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe OfrecordBytesDecoderOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); +} + +/* static */ Maybe OfrecordBytesDecoderOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); + CHECK_NOTNULL_OR_RETURN(in_modifier); + in_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordBytesDecoderOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in.data_type() == DataType::kOFRecord); + *out->mut_data_type() = DataType::kTensorBuffer; + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageDecoderOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); + *out_tensor->mut_shape() = in_tensor.shape(); + return Maybe::Ok(); +} + +/*static*/ Maybe OfrecordImageDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe OfrecordImageDecoderOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Split(user_op::OpArg("out", 0), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageDecoderOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); + CHECK_NOTNULL_OR_RETURN(in_modifier); + in_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageDecoderOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); + *out_tensor->mut_data_type() = DataType::kTensorBuffer; + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageDecoderRandomCropOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); + *out_tensor->mut_shape() = in_tensor.shape(); + return Maybe::Ok(); +} + +/*static*/ Maybe 
OfrecordImageDecoderRandomCropOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe OfrecordImageDecoderRandomCropOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Split(user_op::OpArg("out", 0), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageDecoderRandomCropOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); + CHECK_NOTNULL_OR_RETURN(in_modifier); + in_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageDecoderRandomCropOp::InferDataType( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); + *out_tensor->mut_data_type() = DataType::kTensorBuffer; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp b/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp index 2e73940718f..5e1c21cc54c 100644 --- a/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp +++ b/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp @@ -14,66 +14,57 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_image_classification_reader") - .Output("image") - .Output("label") - .Attr("data_dir") - .Attr("data_part_num") - .Attr("batch_size") - .Attr("part_name_prefix", "part-") - .Attr("part_name_suffix_length", -1) - .Attr("random_shuffle", false) - .Attr("seed", -1) - .Attr("shuffle_buffer_size", 1024) - .Attr("shuffle_after_epoch", false) - .Attr("color_space", "BGR") - .Attr("image_feature_name", "encoded") - .Attr("label_feature_name", "class/label") - .Attr("decode_buffer_size_per_thread", 8) - .Attr("num_decode_threads_per_machine", 0) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* image_tensor = ctx->OutputTensorDesc("image", 0); - user_op::TensorDesc* label_tensor = ctx->OutputTensorDesc("label", 0); - int32_t local_batch_size = ctx->Attr("batch_size"); - const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("image", 0); - int64_t parallel_num = ctx->parallel_ctx().parallel_num(); - if (sbp.has_split_parallel() && parallel_num > 1) { - CHECK_EQ_OR_RETURN(local_batch_size % parallel_num, 0); - local_batch_size /= parallel_num; - } - *image_tensor->mut_shape() = Shape({local_batch_size}); - *label_tensor->mut_shape() = Shape({local_batch_size}); - return Maybe::Ok(); - }) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* image_tensor = ctx->OutputTensorDesc("image", 0); - user_op::TensorDesc* label_tensor = ctx->OutputTensorDesc("label", 0); - int32_t batch_size = ctx->Attr("batch_size"); - *image_tensor->mut_shape() = Shape({batch_size}); - *label_tensor->mut_shape() = Shape({batch_size}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier 
GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - user_op::OutputArgModifier* image_modifier = GetOutputArgModifierFn("image", 0); - CHECK_OR_RETURN(image_modifier != nullptr); - image_modifier->set_header_infered_before_compute(false); - user_op::OutputArgModifier* label_modifier = GetOutputArgModifierFn("label", 0); - CHECK_OR_RETURN(label_modifier != nullptr); - label_modifier->set_header_infered_before_compute(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("image", 0) = DataType::kTensorBuffer; - *ctx->OutputDType("label", 0) = DataType::kTensorBuffer; - return Maybe::Ok(); - }); +/* static */ Maybe OfrecordImageClassificationReaderOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + user_op::TensorDesc* image_tensor = ctx->OutputTensorDesc("image", 0); + user_op::TensorDesc* label_tensor = ctx->OutputTensorDesc("label", 0); + int32_t batch_size = ctx->Attr("batch_size"); + *image_tensor->mut_shape() = Shape({batch_size}); + *label_tensor->mut_shape() = Shape({batch_size}); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageClassificationReaderOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + user_op::TensorDesc* image_tensor = ctx->OutputTensorDesc("image", 0); + user_op::TensorDesc* label_tensor = ctx->OutputTensorDesc("label", 0); + int32_t local_batch_size = ctx->Attr("batch_size"); + const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("image", 0); + int64_t parallel_num = ctx->parallel_ctx().parallel_num(); + if (sbp.has_split_parallel() && parallel_num > 1) { + CHECK_EQ_OR_RETURN(local_batch_size % parallel_num, 0); + local_batch_size /= parallel_num; + } + *image_tensor->mut_shape() = Shape({local_batch_size}); + *label_tensor->mut_shape() = Shape({local_batch_size}); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageClassificationReaderOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageClassificationReaderOp::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::OutputArgModifier* image_modifier = GetOutputArgModifierFn("image", 0); + CHECK_OR_RETURN(image_modifier != nullptr); + image_modifier->set_header_infered_before_compute(false); + user_op::OutputArgModifier* label_modifier = GetOutputArgModifierFn("label", 0); + CHECK_OR_RETURN(label_modifier != nullptr); + label_modifier->set_header_infered_before_compute(false); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageClassificationReaderOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("image", 0) = DataType::kTensorBuffer; + *ctx->OutputDType("label", 0) = DataType::kTensorBuffer; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/ofrecord_reader_op.cpp b/oneflow/user/ops/ofrecord_reader_op.cpp index 5e29638c773..475d058f00b 100644 --- a/oneflow/user/ops/ofrecord_reader_op.cpp +++ b/oneflow/user/ops/ofrecord_reader_op.cpp @@ -14,59 +14,53 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("OFRecordReader") - .Output("out") - .Attr("data_dir") - .Attr("data_part_num") - .Attr("batch_size") - .Attr("part_name_prefix", "part-") - .Attr("part_name_suffix_length", -1) - .Attr("random_shuffle", false) - .Attr("seed", -1) - .Attr("shuffle_buffer_size", 1024) - .Attr("shuffle_after_epoch", false) - .Attr>("nd_sbp") - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - int32_t batch_size = ctx->Attr("batch_size"); - const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0); - int64_t parallel_num = ctx->parallel_ctx().parallel_num(); - if (sbp.has_split_parallel() && parallel_num > 1) { - CHECK_EQ_OR_RETURN(batch_size % parallel_num, 0); - batch_size /= parallel_num; - } - *out_tensor->mut_shape() = Shape({batch_size}); - return Maybe::Ok(); - }) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_shape() = Shape({ctx->Attr("batch_size")}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kOFRecord; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_split_parallel()->set_axis(0); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("out", 0); - CHECK_OR_RETURN(out_modifier != nullptr); - // NOTE(chengcheng): OFRecordReader Only support static shape infer which will read all batch - // size data with output shape (batch_size,) - // out_modifier->set_header_infered_before_compute(false); - return Maybe::Ok(); - }); +/* static */ Maybe OFRecordReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_shape() = Shape({ctx->Attr("batch_size")}); + return Maybe::Ok(); +} + +/* static */ Maybe OFRecordReaderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + int32_t batch_size = ctx->Attr("batch_size"); + const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0); + int64_t parallel_num = ctx->parallel_ctx().parallel_num(); + if (sbp.has_split_parallel() && parallel_num > 1) { + CHECK_EQ_OR_RETURN(batch_size % parallel_num, 0); + batch_size /= parallel_num; + } + *out_tensor->mut_shape() = Shape({batch_size}); + return Maybe::Ok(); +} + +/* static */ Maybe OFRecordReaderOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe OFRecordReaderOp::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("out", 0); + CHECK_OR_RETURN(out_modifier != nullptr); + // NOTE(chengcheng): OFRecordReader Only support static shape infer which 
will read all batch + // size data with output shape (batch_size,) + // out_modifier->set_header_infered_before_compute(false); + return Maybe::Ok(); +} + +/* static */ Maybe OFRecordReaderOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::SbpParallel default_sbp; + default_sbp.mutable_split_parallel()->set_axis(0); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} + +/* static */ Maybe OFRecordReaderOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kOFRecord; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/one_hot_op.cpp b/oneflow/user/ops/one_hot_op.cpp index c28be85e05f..ded1daf6e20 100644 --- a/oneflow/user/ops/one_hot_op.cpp +++ b/oneflow/user/ops/one_hot_op.cpp @@ -15,56 +15,55 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("one_hot") - .Input("indices") - .Output("out") - .Attr("depth") - .Attr("floating_on_value") - .Attr("integer_on_value") - .Attr("floating_off_value") - .Attr("integer_off_value") - .Attr("dtype") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const int64_t depth = ctx->Attr("depth"); - CHECK_GT_OR_RETURN(depth, 0); - const user_op::TensorDesc& indices_desc = ctx->InputTensorDesc("indices", 0); - CHECK_GT_OR_RETURN(indices_desc.shape().NumAxes(), 0); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_is_dynamic() = indices_desc.is_dynamic(); - DimVector dim_vec = indices_desc.shape().dim_vec(); - dim_vec.emplace_back(depth); - *out_desc->mut_shape() = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK_OR_RETURN(indices_modifier != nullptr); - indices_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& indices_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); - FOR_RANGE(int64_t, i, 0, indices_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("indices", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } +/* static */ Maybe OneHotOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const int64_t depth = ctx->Attr("depth"); + CHECK_GT_OR_RETURN(depth, 0); + const user_op::TensorDesc& indices_desc = ctx->InputTensorDesc("indices", 0); + CHECK_GT_OR_RETURN(indices_desc.shape().NumAxes(), 0); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_is_dynamic() = indices_desc.is_dynamic(); + DimVector dim_vec = indices_desc.shape().dim_vec(); + dim_vec.emplace_back(depth); + *out_desc->mut_shape() = Shape(dim_vec); + return Maybe::Ok(); +} - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& indices_desc = ctx->InputTensorDesc("indices", 0); - CHECK_OR_RETURN(IsIndexDataType(indices_desc.data_type())); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - DataType dtype = ctx->Attr("dtype"); - *out_desc->mut_data_type() = dtype; - return Maybe::Ok(); - }); +/*static*/ Maybe OneHotOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static 
*/ Maybe OneHotOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& indices_tensor = + ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); + FOR_RANGE(int64_t, i, 0, indices_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("indices", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + + return Maybe::Ok(); +} + +/* static */ Maybe OneHotOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); + CHECK_OR_RETURN(indices_modifier != nullptr); + indices_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe OneHotOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& indices_desc = ctx->InputTensorDesc("indices", 0); + CHECK_OR_RETURN(IsIndexDataType(indices_desc.data_type())); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + DataType dtype = ctx->Attr("dtype"); + *out_desc->mut_data_type() = dtype; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/onerec_decoder_op.cpp b/oneflow/user/ops/onerec_decoder_op.cpp index ede82961920..6057dd6486c 100644 --- a/oneflow/user/ops/onerec_decoder_op.cpp +++ b/oneflow/user/ops/onerec_decoder_op.cpp @@ -14,65 +14,61 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("onerec_decoder") - .Input("in") - .Output("out") - .Attr("key") - .Attr("data_type") - .Attr("static_shape") - .Attr("is_dynamic", false) - .Attr("has_reshape", false) - .Attr("reshape") - .Attr("has_batch_padding", false) - .Attr("batch_padding") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); - const Shape& static_shape = ctx->Attr("static_shape"); - DimVector dim_vec(1 + static_shape.NumAxes()); - dim_vec[0] = in_tensor.shape().At(0); - FOR_RANGE(int64_t, i, 1, dim_vec.size()) { dim_vec[i] = static_shape.At(i - 1); } - *out_tensor->mut_shape() = Shape(dim_vec); - out_tensor->set_is_dynamic(ctx->Attr("is_dynamic")); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL_OR_RETURN(in_modifier); - in_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Split(user_op::OpArg("out", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - // NOTE(yaochi): refer to tensor_buffer_to_list_of_tensors - // In order to support consistent tensor, set set_header_infered_before_compute to false - // only when is_dynamic == true - if (conf.attr("is_dynamic")) { - FOR_RANGE(int64_t, i, 0, conf.output_size("out")) { - user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn("out", i); - CHECK_OR_RETURN(out_i_modifier != nullptr); - 
out_i_modifier->set_header_infered_before_compute(false); - } - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); - *out_tensor->mut_data_type() = ctx->Attr("data_type"); - return Maybe::Ok(); - }); +/* static */ Maybe OnerecDecoderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); + const Shape& static_shape = ctx->Attr("static_shape"); + DimVector dim_vec(1 + static_shape.NumAxes()); + dim_vec[0] = in_tensor.shape().At(0); + FOR_RANGE(int64_t, i, 1, dim_vec.size()) { dim_vec[i] = static_shape.At(i - 1); } + *out_tensor->mut_shape() = Shape(dim_vec); + out_tensor->set_is_dynamic(ctx->Attr("is_dynamic")); + return Maybe::Ok(); } + +/*static*/ Maybe OnerecDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe OnerecDecoderOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Split(user_op::OpArg("out", 0), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe OnerecDecoderOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); + CHECK_NOTNULL_OR_RETURN(in_modifier); + in_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe OnerecDecoderOp::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + // NOTE(yaochi): refer to tensor_buffer_to_list_of_tensors + // In order to support consistent tensor, set set_header_infered_before_compute to false + // only when is_dynamic == true + if (conf.attr("is_dynamic")) { + FOR_RANGE(int64_t, i, 0, conf.output_size("out")) { + user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn("out", i); + CHECK_OR_RETURN(out_i_modifier != nullptr); + out_i_modifier->set_header_infered_before_compute(false); + } + } + return Maybe::Ok(); +} + +/* static */ Maybe OnerecDecoderOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); + *out_tensor->mut_data_type() = ctx->Attr("data_type"); + return Maybe::Ok(); +} + +} // namespace oneflow diff --git a/oneflow/user/ops/onerec_reader_op.cpp b/oneflow/user/ops/onerec_reader_op.cpp index 697fa7a2722..b98a924b1ff 100644 --- a/oneflow/user/ops/onerec_reader_op.cpp +++ b/oneflow/user/ops/onerec_reader_op.cpp @@ -14,43 +14,34 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("OneRecReader") - .Output("out") - .Attr>("files") - .Attr("batch_size") - .Attr("random_shuffle", false) - .Attr("shuffle_mode", "instance") - .Attr("seed", -1) - .Attr("shuffle_buffer_size", 1024) - .Attr("shuffle_after_epoch", false) - .Attr("verify_example", true) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - int32_t local_batch_size = ctx->Attr("batch_size"); - const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0); - int64_t parallel_num = ctx->parallel_ctx().parallel_num(); - CHECK_OR_RETURN(parallel_num == 1 || sbp.has_split_parallel()); - CHECK_EQ_OR_RETURN(local_batch_size % parallel_num, 0); - local_batch_size /= parallel_num; - *out_tensor->mut_shape() = Shape({local_batch_size}); - return Maybe::Ok(); - }) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - int32_t batch_size = ctx->Attr("batch_size"); - *out_tensor->mut_shape() = Shape({batch_size}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kTensorBuffer; - return Maybe::Ok(); - }); +/*static*/ Maybe OneRecReaderOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe OneRecReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + int32_t batch_size = ctx->Attr("batch_size"); + *out_tensor->mut_shape() = Shape({batch_size}); + return Maybe::Ok(); +} +/*static*/ Maybe OneRecReaderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + int32_t local_batch_size = ctx->Attr("batch_size"); + const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0); + int64_t parallel_num = ctx->parallel_ctx().parallel_num(); + CHECK_OR_RETURN(parallel_num == 1 || sbp.has_split_parallel()); + CHECK_EQ_OR_RETURN(local_batch_size % parallel_num, 0); + local_batch_size /= parallel_num; + *out_tensor->mut_shape() = Shape({local_batch_size}); + return Maybe::Ok(); +} +/*static*/ Maybe OneRecReaderOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kTensorBuffer; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/ones_like_op.cpp b/oneflow/user/ops/ones_like_op.cpp index 9c37d088b9d..cf05b880f87 100644 --- a/oneflow/user/ops/ones_like_op.cpp +++ b/oneflow/user/ops/ones_like_op.cpp @@ -14,35 +14,34 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("ones_like") - .Input("like") - .Output("out") - .SetOutputBufferNum(1) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& like_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0); - FOR_RANGE(int64_t, i, 0, like_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("like", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("like", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe OnesLikeOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& like_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0); + FOR_RANGE(int64_t, i, 0, like_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("like", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("like", 0)) + .Broadcast(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe OnesLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); + return Maybe::Ok(); +} +/*static*/ Maybe OnesLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return OnesLikeOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe OnesLikeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/p2p_comm_op.cpp b/oneflow/user/ops/p2p_comm_op.cpp index 8bd612fa54e..3e7c06fd7b1 100644 --- a/oneflow/user/ops/p2p_comm_op.cpp +++ b/oneflow/user/ops/p2p_comm_op.cpp @@ -15,24 +15,27 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe SendOp::GetSbp(user_op::SbpContext* ctx) { UNIMPLEMENTED_THEN_RETURN(); } +/*static*/ Maybe SendOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + // Do nothing. + return Maybe::Ok(); +} +/*static*/ Maybe SendOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return SendOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SendOp::InferDataType(user_op::InferContext* ctx) { + // Do nothing. + return Maybe::Ok(); +} +/*static*/ Maybe> SendOp::InferDevice(user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} -REGISTER_NO_GRAD_USER_OP("send") - .Input("in") - .Attr("dst_process_id") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - // Do nothing. - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { UNIMPLEMENTED_THEN_RETURN(); }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - // Do nothing. 
- return Maybe::Ok(); - }); +namespace { Maybe> GetRecvOutputDeivce(user_op::DeviceInferContext* ctx) { const std::string& device_type = ctx->Attr("device_type"); @@ -40,24 +43,22 @@ Maybe> GetRecvOutputDeivce(user_op::DeviceInferContext* ctx) { return Device::New(device_type, device_id); } -REGISTER_NO_GRAD_USER_OP("recv") - .Output("out") - .Attr("src_process_id") - .Attr("dtype") - .Attr("shape") - .Attr("device_type") - .Attr("device_id") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->Attr("shape"); - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched, &GetRecvOutputDeivce>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { UNIMPLEMENTED_THEN_RETURN(); }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); - return Maybe::Ok(); - }); - } // namespace +/*static*/ Maybe RecvOp::GetSbp(user_op::SbpContext* ctx) { UNIMPLEMENTED_THEN_RETURN(); } +/*static*/ Maybe RecvOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->Attr("shape"); + return Maybe::Ok(); +} +/*static*/ Maybe RecvOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return SendOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe RecvOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); + return Maybe::Ok(); +} +/*static*/ Maybe> RecvOp::InferDevice(user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched, &GetRecvOutputDeivce>(ctx); +} + } // namespace oneflow diff --git a/oneflow/user/ops/pack_op.cpp b/oneflow/user/ops/pack_op.cpp index e28529d542c..b5ae5c75a74 100644 --- a/oneflow/user/ops/pack_op.cpp +++ b/oneflow/user/ops/pack_op.cpp @@ -14,61 +14,59 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe PackOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe PackOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + const Shape& in_shape = in_desc.shape(); + const int32_t pack_num = ctx->Attr("pack_num"); + CHECK_GT_OR_RETURN(pack_num, 0); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_is_dynamic() = in_desc.is_dynamic(); + if (in_shape.NumAxes() > 0) { + *out_desc->mut_shape() = in_shape; + out_desc->mut_shape()->Set(0, in_shape.At(0) * pack_num); + } else { + // NOTE(chengcheng): for Scalar input pack + CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), 1); + *out_desc->mut_shape() = Shape({pack_num}); + } + return Maybe::Ok(); +} -REGISTER_USER_OP("pack") - .Input("in") - .Output("out") - .Attr("pack_num") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - const Shape& in_shape = in_desc.shape(); - const int32_t pack_num = ctx->Attr("pack_num"); - CHECK_GT_OR_RETURN(pack_num, 0); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_is_dynamic() = in_desc.is_dynamic(); - if (in_shape.NumAxes() > 0) { - *out_desc->mut_shape() = in_shape; - out_desc->mut_shape()->Set(0, in_shape.At(0) * pack_num); - } else { - // NOTE(chengcheng): for Scalar input pack - CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), 1); - *out_desc->mut_shape() = Shape({pack_num}); - } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetOutputBlobTimeShapeInferFn( - [](user_op::InferOutputBlobTimeShapeFnContext* ctx) -> Maybe { - const int32_t pack_num = ctx->user_op_conf().attr("pack_num"); - DimVector time_shape_dim_vec = ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec(); - CHECK_OR_RETURN(!time_shape_dim_vec.empty()); - CHECK_EQ_OR_RETURN(time_shape_dim_vec.back(), pack_num); - time_shape_dim_vec.pop_back(); - if (time_shape_dim_vec.empty()) { time_shape_dim_vec.emplace_back(1); } - *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe PackOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return PackOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe PackOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe PackOp::InferOutputBlobTimeShape( + 
user_op::InferOutputBlobTimeShapeFnContext* ctx) { + const int32_t pack_num = ctx->user_op_conf().attr("pack_num"); + DimVector time_shape_dim_vec = ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec(); + CHECK_OR_RETURN(!time_shape_dim_vec.empty()); + CHECK_EQ_OR_RETURN(time_shape_dim_vec.back(), pack_num); + time_shape_dim_vec.pop_back(); + if (time_shape_dim_vec.empty()) { time_shape_dim_vec.emplace_back(1); } + *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); + return Maybe::Ok(); +} + +namespace { REGISTER_USER_OP_GRAD("pack").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/pad_op.cpp b/oneflow/user/ops/pad_op.cpp index a446b29ce90..b222bdc2ad7 100644 --- a/oneflow/user/ops/pad_op.cpp +++ b/oneflow/user/ops/pad_op.cpp @@ -15,86 +15,73 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("pad") - .Input("x") - .Output("y") - .Attr>("padding_before") - .Attr>("padding_after") - .Attr("floating_constant_value") - .Attr("integral_constant_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const auto& padding_before = ctx->Attr>("padding_before"); - const auto& padding_after = ctx->Attr>("padding_after"); - CHECK_EQ_OR_RETURN(padding_before.size(), x_shape.NumAxes()); - CHECK_EQ_OR_RETURN(padding_after.size(), x_shape.NumAxes()); - DimVector y_dim_vec(x_shape.NumAxes()); - FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) { - y_dim_vec[i] = x_shape.At(i) + padding_before[i] + padding_after[i]; - } - *ctx->OutputShape("y", 0) = Shape(y_dim_vec); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - const auto& padding_before = ctx->Attr>("padding_before"); - const auto& padding_after = ctx->Attr>("padding_after"); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - if (padding_before[i] == 0 && padding_after[i] == 0) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("y", 0), i) - .Build(); - } - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe PadOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + const auto& padding_before = ctx->Attr>("padding_before"); + const auto& padding_after = ctx->Attr>("padding_after"); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + if (padding_before[i] == 0 && padding_after[i] == 0) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); + } + } + return Maybe::Ok(); +} +/*static*/ Maybe PadOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const auto& padding_before = ctx->Attr>("padding_before"); + const auto& padding_after = ctx->Attr>("padding_after"); + CHECK_EQ_OR_RETURN(padding_before.size(), x_shape.NumAxes()); + CHECK_EQ_OR_RETURN(padding_after.size(), x_shape.NumAxes()); + DimVector y_dim_vec(x_shape.NumAxes()); + FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) { + y_dim_vec[i] = x_shape.At(i) + padding_before[i] + padding_after[i]; + } + *ctx->OutputShape("y", 
0) = Shape(y_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe PadOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return PadOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe PadOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("pad_grad") - .Input("dy") - .Output("dx") - .Attr>("padding_before") - .Attr>("padding_after") - .Attr("floating_constant_value") - .Attr("integral_constant_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - const auto& padding_before = ctx->Attr>("padding_before"); - const auto& padding_after = ctx->Attr>("padding_after"); - CHECK_EQ_OR_RETURN(padding_before.size(), dy_shape.NumAxes()); - CHECK_EQ_OR_RETURN(padding_after.size(), dy_shape.NumAxes()); - DimVector dx_dim_vec(dy_shape.NumAxes()); - FOR_RANGE(int64_t, i, 0, dy_shape.NumAxes()) { - dx_dim_vec[i] = dy_shape.At(i) - padding_before[i] - padding_after[i]; - } - *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); - const auto& padding_before = ctx->Attr>("padding_before"); - const auto& padding_after = ctx->Attr>("padding_after"); - FOR_RANGE(int64_t, i, 0, dy_tensor.shape().NumAxes()) { - if (padding_before[i] == 0 && padding_after[i] == 0) { - ctx->NewBuilder() - .Split(user_op::OpArg("dx", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Build(); - } - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe PadGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); + const auto& padding_before = ctx->Attr>("padding_before"); + const auto& padding_after = ctx->Attr>("padding_after"); + FOR_RANGE(int64_t, i, 0, dy_tensor.shape().NumAxes()) { + if (padding_before[i] == 0 && padding_after[i] == 0) { + ctx->NewBuilder().Split(user_op::OpArg("dx", 0), i).Split(user_op::OpArg("dy", 0), i).Build(); + } + } + return Maybe::Ok(); +} +/*static*/ Maybe PadGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + const auto& padding_before = ctx->Attr>("padding_before"); + const auto& padding_after = ctx->Attr>("padding_after"); + CHECK_EQ_OR_RETURN(padding_before.size(), dy_shape.NumAxes()); + CHECK_EQ_OR_RETURN(padding_after.size(), dy_shape.NumAxes()); + DimVector dx_dim_vec(dy_shape.NumAxes()); + FOR_RANGE(int64_t, i, 0, dy_shape.NumAxes()) { + dx_dim_vec[i] = dy_shape.At(i) - padding_before[i] - padding_after[i]; + } + *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe PadGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return PadGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe PadGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("pad").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/padding_ops.cpp b/oneflow/user/ops/padding_ops.cpp index 46d97d8520c..969bd1a5721 100644 --- a/oneflow/user/ops/padding_ops.cpp +++ 
b/oneflow/user/ops/padding_ops.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -46,89 +47,82 @@ Maybe GetOpGradSbpSignature(user_op::SbpContext* ctx) { } // namespace -REGISTER_USER_OP("reflection_pad2d") - .Input("x") - .Output("y") - .Attr>("padding") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const auto& padding = ctx->Attr>("padding"); - CHECK_EQ_OR_RETURN(padding.size(), x_shape.NumAxes()); - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t h_idx = 2; - const int64_t w_idx = 3; - - // Ensure the padding size is less than the input dimension. - CHECK_LT_OR_RETURN(padding[0], x_shape.At(w_idx)); - CHECK_LT_OR_RETURN(padding[1], x_shape.At(w_idx)); - CHECK_LT_OR_RETURN(padding[2], x_shape.At(h_idx)); - CHECK_LT_OR_RETURN(padding[3], x_shape.At(h_idx)); - - DimVector y_dim_vec(x_shape.NumAxes()); - const int64_t h_x = x_shape.At(h_idx); - const int64_t w_x = x_shape.At(w_idx); - - y_dim_vec[n_idx] = x_shape.At(n_idx); - y_dim_vec[c_idx] = x_shape.At(c_idx); - y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; - y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; - - *ctx->OutputShape("y", 0) = Shape(y_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetOpSbpSignature) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); - CHECK_NOTNULL_OR_RETURN(x_modifier); - x_modifier->set_requires_grad(true); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ReflectionPad2DOp::GetSbp(user_op::SbpContext* ctx) { + return GetOpSbpSignature(ctx); +} +/*static*/ Maybe ReflectionPad2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const auto& padding = ctx->Attr>("padding"); + CHECK_EQ_OR_RETURN(padding.size(), x_shape.NumAxes()); + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t h_idx = 2; + const int64_t w_idx = 3; + + // Ensure the padding size is less than the input dimension. 
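+  // (Reflection padding mirrors interior elements across each edge, so every pad
+  // amount must be strictly smaller than the corresponding input extent; a larger
+  // pad would index outside the input tensor.)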
+ CHECK_LT_OR_RETURN(padding[0], x_shape.At(w_idx)); + CHECK_LT_OR_RETURN(padding[1], x_shape.At(w_idx)); + CHECK_LT_OR_RETURN(padding[2], x_shape.At(h_idx)); + CHECK_LT_OR_RETURN(padding[3], x_shape.At(h_idx)); + + DimVector y_dim_vec(x_shape.NumAxes()); + const int64_t h_x = x_shape.At(h_idx); + const int64_t w_x = x_shape.At(w_idx); + + y_dim_vec[n_idx] = x_shape.At(n_idx); + y_dim_vec[c_idx] = x_shape.At(c_idx); + y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; + y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; + + *ctx->OutputShape("y", 0) = Shape(y_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe ReflectionPad2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ReflectionPad2DOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ReflectionPad2DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe ReflectionPad2DOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); + CHECK_NOTNULL_OR_RETURN(x_modifier); + x_modifier->set_requires_grad(true); + return Maybe::Ok(); +} -REGISTER_USER_OP("reflection_pad2d_grad") - .Input("dy") - .Output("dx") - .Attr>("padding") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - const auto& padding = ctx->Attr>("padding"); - CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes()); - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t h_idx = 2; - const int64_t w_idx = 3; - - DimVector dx_dim_vec(dy_shape.NumAxes()); - int64_t h_dy, w_dy; - h_dy = dy_shape.At(h_idx); - w_dy = dy_shape.At(w_idx); - - dx_dim_vec[n_idx] = dy_shape.At(0); - dx_dim_vec[c_idx] = dy_shape.At(1); - dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; - dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; - - *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetOpGradSbpSignature) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ReflectionPad2DGradOp::GetSbp(user_op::SbpContext* ctx) { + return GetOpGradSbpSignature(ctx); +} +/*static*/ Maybe ReflectionPad2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + const auto& padding = ctx->Attr>("padding"); + CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes()); + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t h_idx = 2; + const int64_t w_idx = 3; + + DimVector dx_dim_vec(dy_shape.NumAxes()); + int64_t h_dy = dy_shape.At(h_idx); + int64_t w_dy = dy_shape.At(w_idx); + + dx_dim_vec[n_idx] = dy_shape.At(0); + dx_dim_vec[c_idx] = dy_shape.At(1); + dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; + dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; + + *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe ReflectionPad2DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ReflectionPad2DGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ReflectionPad2DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return 
Maybe::Ok(); +} REGISTER_USER_OP_GRAD("reflection_pad2d") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -147,83 +141,76 @@ REGISTER_USER_OP_GRAD("reflection_pad2d") return Maybe::Ok(); }); -REGISTER_USER_OP("replication_pad2d") - .Input("x") - .Output("y") - .Attr>("padding") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const auto& padding = ctx->Attr>("padding"); - CHECK_EQ_OR_RETURN(padding.size(), x_shape.NumAxes()); - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t h_idx = 2; - const int64_t w_idx = 3; - - DimVector y_dim_vec(x_shape.NumAxes()); - const int64_t h_x = x_shape.At(h_idx); - const int64_t w_x = x_shape.At(w_idx); - - y_dim_vec[n_idx] = x_shape.At(n_idx); - y_dim_vec[c_idx] = x_shape.At(c_idx); - y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; - y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; - - *ctx->OutputShape("y", 0) = Shape(y_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetOpSbpSignature) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); - CHECK_NOTNULL_OR_RETURN(x_modifier); - x_modifier->set_requires_grad(true); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ReplicationPad2DOp::GetSbp(user_op::SbpContext* ctx) { + return GetOpSbpSignature(ctx); +} +/*static*/ Maybe ReplicationPad2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const auto& padding = ctx->Attr>("padding"); + CHECK_EQ_OR_RETURN(padding.size(), x_shape.NumAxes()); + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t h_idx = 2; + const int64_t w_idx = 3; + + DimVector y_dim_vec(x_shape.NumAxes()); + const int64_t h_x = x_shape.At(h_idx); + const int64_t w_x = x_shape.At(w_idx); + + y_dim_vec[n_idx] = x_shape.At(n_idx); + y_dim_vec[c_idx] = x_shape.At(c_idx); + y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; + y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; + + *ctx->OutputShape("y", 0) = Shape(y_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe ReplicationPad2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ReplicationPad2DOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ReplicationPad2DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe ReplicationPad2DOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); + CHECK_NOTNULL_OR_RETURN(x_modifier); + x_modifier->set_requires_grad(true); + return Maybe::Ok(); +} -REGISTER_USER_OP("replication_pad2d_grad") - .Input("dy") - .Output("dx") - .Attr>("padding") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - const auto& padding = ctx->Attr>("padding"); - CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes()); - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t h_idx = 2; - const int64_t w_idx = 3; - 
- DimVector dx_dim_vec(dy_shape.NumAxes()); - int64_t h_dy, w_dy; - h_dy = dy_shape.At(h_idx); - w_dy = dy_shape.At(w_idx); - - dx_dim_vec[n_idx] = dy_shape.At(0); - dx_dim_vec[c_idx] = dy_shape.At(1); - dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; - dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; - - *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetOpGradSbpSignature) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ReplicationPad2DGradOp::GetSbp(user_op::SbpContext* ctx) { + return GetOpGradSbpSignature(ctx); +} +/*static*/ Maybe ReplicationPad2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + const auto& padding = ctx->Attr>("padding"); + CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes()); + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t h_idx = 2; + const int64_t w_idx = 3; + + DimVector dx_dim_vec(dy_shape.NumAxes()); + int64_t h_dy = dy_shape.At(h_idx); + int64_t w_dy = dy_shape.At(w_idx); + + dx_dim_vec[n_idx] = dy_shape.At(0); + dx_dim_vec[c_idx] = dy_shape.At(1); + dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; + dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; + + *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe ReplicationPad2DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ReplicationPad2DGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ReplicationPad2DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("replication_pad2d") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -242,83 +229,72 @@ REGISTER_USER_OP_GRAD("replication_pad2d") return Maybe::Ok(); }); -REGISTER_USER_OP("constant_pad1d") - .Input("x") - .Output("y") - .Attr>("padding") - .Attr("floating_value") - .Attr("integral_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const auto& padding = ctx->Attr>("padding"); - CHECK_EQ_OR_RETURN(x_shape.NumAxes(), 3); - CHECK_EQ_OR_RETURN(padding.size(), 2); - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t w_idx = 2; - - DimVector y_dim_vec(x_shape.NumAxes()); - const int64_t w_x = x_shape.At(w_idx); - - y_dim_vec[n_idx] = x_shape.At(n_idx); - y_dim_vec[c_idx] = x_shape.At(c_idx); - y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; - - *ctx->OutputShape("y", 0) = Shape(y_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetOpSbpSignature) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); - CHECK_NOTNULL_OR_RETURN(x_modifier); - x_modifier->set_requires_grad(true); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe 
ConstantPad1DOp::GetSbp(user_op::SbpContext* ctx) { + return GetOpSbpSignature(ctx); +} +/*static*/ Maybe ConstantPad1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const auto& padding = ctx->Attr>("padding"); + CHECK_EQ_OR_RETURN(x_shape.NumAxes(), 3); + CHECK_EQ_OR_RETURN(padding.size(), 2); + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t w_idx = 2; + + DimVector y_dim_vec(x_shape.NumAxes()); + const int64_t w_x = x_shape.At(w_idx); + + y_dim_vec[n_idx] = x_shape.At(n_idx); + y_dim_vec[c_idx] = x_shape.At(c_idx); + y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; + + *ctx->OutputShape("y", 0) = Shape(y_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ConstantPad1DOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ConstantPad1DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad1DOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); + CHECK_NOTNULL_OR_RETURN(x_modifier); + x_modifier->set_requires_grad(true); + return Maybe::Ok(); +} -REGISTER_USER_OP("constant_pad1d_grad") - .Input("dy") - .Output("dx") - .Attr>("padding") - .Attr("floating_value") - .Attr("integral_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - const auto& padding = ctx->Attr>("padding"); - CHECK_EQ_OR_RETURN(dy_shape.NumAxes(), 3); - CHECK_EQ_OR_RETURN(padding.size(), 2); - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t w_idx = 2; - - DimVector dx_dim_vec(dy_shape.NumAxes()); - int64_t w_dy; - w_dy = dy_shape.At(w_idx); - - dx_dim_vec[n_idx] = dy_shape.At(0); - dx_dim_vec[c_idx] = dy_shape.At(1); - dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; - - *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetOpGradSbpSignature) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ConstantPad1DGradOp::GetSbp(user_op::SbpContext* ctx) { + return GetOpGradSbpSignature(ctx); +} +/*static*/ Maybe ConstantPad1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + const auto& padding = ctx->Attr>("padding"); + CHECK_EQ_OR_RETURN(dy_shape.NumAxes(), 3); + CHECK_EQ_OR_RETURN(padding.size(), 2); + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t w_idx = 2; + + DimVector dx_dim_vec(dy_shape.NumAxes()); + int64_t w_dy = dy_shape.At(w_idx); + + dx_dim_vec[n_idx] = dy_shape.At(0); + dx_dim_vec[c_idx] = dy_shape.At(1); + dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; + + *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad1DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ConstantPad1DGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ConstantPad1DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} 
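For orientation, here is a minimal standalone sketch of the shape arithmetic that the pad forward/backward hooks above implement for NCHW inputs. It assumes the padding vector uses the [left, right, top, bottom] order seen in these ops; the helper names are illustrative and are not part of the patch.

#include <array>
#include <cstdint>
#include <vector>

// Forward: the output grows by the padding on H and W; N and C are unchanged.
std::array<int64_t, 4> PaddedOutputDims(const std::array<int64_t, 4>& x_nchw,
                                        const std::vector<int64_t>& padding /*l, r, t, b*/) {
  return {x_nchw[0], x_nchw[1],
          x_nchw[2] + padding[2] + padding[3],   // H + top + bottom
          x_nchw[3] + padding[0] + padding[1]};  // W + left + right
}

// Backward: dx shrinks by the same padding, mirroring the *GradOp hooks.
std::array<int64_t, 4> PadGradInputDims(const std::array<int64_t, 4>& dy_nchw,
                                        const std::vector<int64_t>& padding) {
  return {dy_nchw[0], dy_nchw[1],
          dy_nchw[2] - padding[2] - padding[3],
          dy_nchw[3] - padding[0] - padding[1]};
}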
REGISTER_USER_OP_GRAD("constant_pad1d") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -339,87 +315,76 @@ REGISTER_USER_OP_GRAD("constant_pad1d") return Maybe::Ok(); }); -REGISTER_USER_OP("constant_pad2d") - .Input("x") - .Output("y") - .Attr>("padding") - .Attr("floating_value") - .Attr("integral_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const auto& padding = ctx->Attr>("padding"); - CHECK_EQ_OR_RETURN(padding.size(), x_shape.NumAxes()); - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t h_idx = 2; - const int64_t w_idx = 3; - - DimVector y_dim_vec(x_shape.NumAxes()); - const int64_t h_x = x_shape.At(h_idx); - const int64_t w_x = x_shape.At(w_idx); - - y_dim_vec[n_idx] = x_shape.At(n_idx); - y_dim_vec[c_idx] = x_shape.At(c_idx); - y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; - y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; - - *ctx->OutputShape("y", 0) = Shape(y_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetOpSbpSignature) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); - CHECK_NOTNULL_OR_RETURN(x_modifier); - x_modifier->set_requires_grad(true); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ConstantPad2DOp::GetSbp(user_op::SbpContext* ctx) { + return GetOpSbpSignature(ctx); +} +/*static*/ Maybe ConstantPad2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const auto& padding = ctx->Attr>("padding"); + CHECK_EQ_OR_RETURN(padding.size(), x_shape.NumAxes()); + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t h_idx = 2; + const int64_t w_idx = 3; + + DimVector y_dim_vec(x_shape.NumAxes()); + const int64_t h_x = x_shape.At(h_idx); + const int64_t w_x = x_shape.At(w_idx); + + y_dim_vec[n_idx] = x_shape.At(n_idx); + y_dim_vec[c_idx] = x_shape.At(c_idx); + y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; + y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; + + *ctx->OutputShape("y", 0) = Shape(y_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ConstantPad2DOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ConstantPad2DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad2DOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); + CHECK_NOTNULL_OR_RETURN(x_modifier); + x_modifier->set_requires_grad(true); + return Maybe::Ok(); +} -REGISTER_USER_OP("constant_pad2d_grad") - .Input("dy") - .Output("dx") - .Attr>("padding") - .Attr("floating_value") - .Attr("integral_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - const auto& padding = ctx->Attr>("padding"); - CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes()); - const int64_t n_idx = 0; - const int64_t c_idx = 
1; - const int64_t h_idx = 2; - const int64_t w_idx = 3; - - DimVector dx_dim_vec(dy_shape.NumAxes()); - int64_t h_dy, w_dy; - h_dy = dy_shape.At(h_idx); - w_dy = dy_shape.At(w_idx); - - dx_dim_vec[n_idx] = dy_shape.At(0); - dx_dim_vec[c_idx] = dy_shape.At(1); - dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; - dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; - - *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetOpGradSbpSignature) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ConstantPad2DGradOp::GetSbp(user_op::SbpContext* ctx) { + return GetOpGradSbpSignature(ctx); +} +/*static*/ Maybe ConstantPad2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + const auto& padding = ctx->Attr>("padding"); + CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes()); + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t h_idx = 2; + const int64_t w_idx = 3; + + DimVector dx_dim_vec(dy_shape.NumAxes()); + int64_t h_dy = dy_shape.At(h_idx); + int64_t w_dy = dy_shape.At(w_idx); + + dx_dim_vec[n_idx] = dy_shape.At(0); + dx_dim_vec[c_idx] = dy_shape.At(1); + dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; + dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; + + *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad2DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ConstantPad2DGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ConstantPad2DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("constant_pad2d") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -440,95 +405,84 @@ REGISTER_USER_OP_GRAD("constant_pad2d") return Maybe::Ok(); }); -REGISTER_USER_OP("constant_pad3d") - .Input("x") - .Output("y") - .Attr>("padding") - .Attr("floating_value") - .Attr("integral_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const auto& padding = ctx->Attr>("padding"); - CHECK_EQ_OR_RETURN(x_shape.NumAxes(), 5); - // only support NCDHW format input tensor for now ! 
- // for NCDHW format, index of num,channel,depth,height,width is 0,1,2,3,4 - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t d_idx = 2; - const int64_t h_idx = 3; - const int64_t w_idx = 4; - - DimVector y_dim_vec(x_shape.NumAxes()); - const int64_t d_x = x_shape.At(d_idx); - const int64_t h_x = x_shape.At(h_idx); - const int64_t w_x = x_shape.At(w_idx); - - y_dim_vec[n_idx] = x_shape.At(n_idx); - y_dim_vec[c_idx] = x_shape.At(c_idx); - y_dim_vec[d_idx] = d_x + padding[4] + padding[5]; - y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; - y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; - - *ctx->OutputShape("y", 0) = Shape(y_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetOpSbpSignature) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); - CHECK_NOTNULL_OR_RETURN(x_modifier); - x_modifier->set_requires_grad(true); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ConstantPad3DOp::GetSbp(user_op::SbpContext* ctx) { + return GetOpSbpSignature(ctx); +} +/*static*/ Maybe ConstantPad3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const auto& padding = ctx->Attr>("padding"); + CHECK_EQ_OR_RETURN(x_shape.NumAxes(), 5); + // only support NCDHW format input tensor for now ! + // for NCDHW format, index of num,channel,depth,height,width is 0,1,2,3,4 + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t d_idx = 2; + const int64_t h_idx = 3; + const int64_t w_idx = 4; + + DimVector y_dim_vec(x_shape.NumAxes()); + const int64_t d_x = x_shape.At(d_idx); + const int64_t h_x = x_shape.At(h_idx); + const int64_t w_x = x_shape.At(w_idx); + + y_dim_vec[n_idx] = x_shape.At(n_idx); + y_dim_vec[c_idx] = x_shape.At(c_idx); + y_dim_vec[d_idx] = d_x + padding[4] + padding[5]; + y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; + y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; + + *ctx->OutputShape("y", 0) = Shape(y_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ConstantPad3DOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ConstantPad3DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad3DOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); + CHECK_NOTNULL_OR_RETURN(x_modifier); + x_modifier->set_requires_grad(true); + return Maybe::Ok(); +} -REGISTER_USER_OP("constant_pad3d_grad") - .Input("dy") - .Output("dx") - .Attr>("padding") - .Attr("floating_value") - .Attr("integral_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - const auto& padding = ctx->Attr>("padding"); - CHECK_EQ_OR_RETURN(dy_shape.NumAxes(), 5); - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t d_idx = 2; - const int64_t h_idx = 3; - const int64_t w_idx = 4; - - DimVector 
dx_dim_vec(dy_shape.NumAxes()); - int64_t d_dy, h_dy, w_dy; - d_dy = dy_shape.At(d_idx); - h_dy = dy_shape.At(h_idx); - w_dy = dy_shape.At(w_idx); - - dx_dim_vec[n_idx] = dy_shape.At(0); - dx_dim_vec[c_idx] = dy_shape.At(1); - dx_dim_vec[d_idx] = d_dy - padding[4] - padding[5]; - dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; - dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; - - *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetOpGradSbpSignature) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ConstantPad3DGradOp::GetSbp(user_op::SbpContext* ctx) { + return GetOpGradSbpSignature(ctx); +} +/*static*/ Maybe ConstantPad3DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + const auto& padding = ctx->Attr>("padding"); + CHECK_EQ_OR_RETURN(dy_shape.NumAxes(), 5); + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t d_idx = 2; + const int64_t h_idx = 3; + const int64_t w_idx = 4; + + DimVector dx_dim_vec(dy_shape.NumAxes()); + int64_t d_dy = dy_shape.At(d_idx); + int64_t h_dy = dy_shape.At(h_idx); + int64_t w_dy = dy_shape.At(w_idx); + + dx_dim_vec[n_idx] = dy_shape.At(0); + dx_dim_vec[c_idx] = dy_shape.At(1); + dx_dim_vec[d_idx] = d_dy - padding[4] - padding[5]; + dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; + dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; + + *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad3DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ConstantPad3DGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ConstantPad3DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("constant_pad3d") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/parallel_cast_op.cpp b/oneflow/user/ops/parallel_cast_op.cpp index b31cf919f4b..b4762acf2d4 100644 --- a/oneflow/user/ops/parallel_cast_op.cpp +++ b/oneflow/user/ops/parallel_cast_op.cpp @@ -15,49 +15,50 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("parallel_cast") - .Input("in") - .Output("out") - .Attr("sbp_parallel", "") - .Attr("grad_sbp_parallel", "") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetSbpSignatureInferFn([](user_op::InferSbpSignatureFnContext* ctx) -> Maybe { - auto* bn2sbp = ctx->mutable_sbp_signature()->mutable_bn_in_op2sbp_parallel(); - const std::string& ibn = GenRepeatedBn("in", 0); - const std::string& obn = GenRepeatedBn("out", 0); - const auto& sbp_parallel_str = ctx->Attr("sbp_parallel"); - if (sbp_parallel_str.empty()) { - const auto& sbp_parallel = ctx->SbpParallelHint4InputArgNameAndIndex("in", 0); - (*bn2sbp)[ibn] = sbp_parallel; - (*bn2sbp)[obn] = sbp_parallel; - } else { - cfg::SbpParallel sbp_parallel; - CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_parallel_str, &sbp_parallel)) - << "invalid sbp_parallel: " << sbp_parallel_str; - if (sbp_parallel.has_split_parallel()) { - int64_t split_axis = sbp_parallel.split_parallel().axis(); - const auto& in_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - int64_t num_axes = in_desc.shape().NumAxes(); - CHECK_GE_OR_RETURN(split_axis, 0); - CHECK_LT_OR_RETURN(split_axis, num_axes); - } - (*bn2sbp)[ibn] = sbp_parallel; - (*bn2sbp)[obn] = sbp_parallel; - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe ParallelCastOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe ParallelCastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe ParallelCastOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ParallelCastOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ParallelCastOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe ParallelCastOp::InferSbpSignature(user_op::InferSbpSignatureFnContext* ctx) { + auto* bn2sbp = ctx->mutable_sbp_signature()->mutable_bn_in_op2sbp_parallel(); + const std::string& ibn = GenRepeatedBn("in", 0); + const std::string& obn = GenRepeatedBn("out", 0); + const auto& sbp_parallel_str = ctx->Attr("sbp_parallel"); + if (sbp_parallel_str.empty()) { + const auto& sbp_parallel = ctx->SbpParallelHint4InputArgNameAndIndex("in", 0); + (*bn2sbp)[ibn] = sbp_parallel; + (*bn2sbp)[obn] = sbp_parallel; + } else { + cfg::SbpParallel sbp_parallel; + CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_parallel_str, &sbp_parallel)) + << "invalid sbp_parallel: " << sbp_parallel_str; + if (sbp_parallel.has_split_parallel()) { + int64_t split_axis = sbp_parallel.split_parallel().axis(); + const auto& in_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + int64_t num_axes = in_desc.shape().NumAxes(); + CHECK_GE_OR_RETURN(split_axis, 0); + CHECK_LT_OR_RETURN(split_axis, num_axes); + } + (*bn2sbp)[ibn] = sbp_parallel; + 
(*bn2sbp)[obn] = sbp_parallel; + } + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("parallel_cast") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/partial_fc_sample_op.cpp b/oneflow/user/ops/partial_fc_sample_op.cpp index b40d6f94d8d..1798e91fe6d 100644 --- a/oneflow/user/ops/partial_fc_sample_op.cpp +++ b/oneflow/user/ops/partial_fc_sample_op.cpp @@ -14,127 +14,119 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("distributed_partial_fc_sample") - .Input("weight") - .Input("label") - .Output("mapped_label") - .Output("sampled_label") - .Output("sampled_weight") - .Attr("num_sample") - .Attr("seed", -1) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const int64_t num_sample = ctx->Attr("num_sample"); - const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); - const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); - user_op::TensorDesc* mapped_label = ctx->OutputTensorDesc("mapped_label", 0); - user_op::TensorDesc* sampled_weight = ctx->OutputTensorDesc("sampled_weight", 0); - user_op::TensorDesc* sampled_label = ctx->OutputTensorDesc("sampled_label", 0); - *mapped_label->mut_shape() = label.shape(); - *mapped_label->mut_is_dynamic() = label.is_dynamic(); - *sampled_weight->mut_shape() = weight.shape(); - sampled_weight->mut_shape()->Set(0, num_sample); - *sampled_weight->mut_is_dynamic() = weight.is_dynamic(); - *sampled_label->mut_shape() = label.shape(); - sampled_label->mut_shape()->Set(0, num_sample); - *sampled_label->mut_is_dynamic() = label.is_dynamic(); - return Maybe::Ok(); - }) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const int64_t num_sample = ctx->Attr("num_sample"); - const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); - CHECK_EQ_OR_RETURN(num_sample % parallel_num, 0); - const int64_t num_sample_per_rank = num_sample / parallel_num; - const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); - const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); - user_op::TensorDesc* mapped_label = ctx->OutputTensorDesc("mapped_label", 0); - user_op::TensorDesc* sampled_weight = ctx->OutputTensorDesc("sampled_weight", 0); - user_op::TensorDesc* sampled_label = ctx->OutputTensorDesc("sampled_label", 0); - *mapped_label->mut_shape() = label.shape(); - *mapped_label->mut_is_dynamic() = label.is_dynamic(); - *sampled_weight->mut_shape() = weight.shape(); - sampled_weight->mut_shape()->Set(0, num_sample_per_rank); - *sampled_weight->mut_is_dynamic() = weight.is_dynamic(); - *sampled_label->mut_shape() = label.shape(); - sampled_label->mut_shape()->Set(0, num_sample_per_rank); - *sampled_label->mut_is_dynamic() = label.is_dynamic(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); - CHECK_NOTNULL_OR_RETURN(label_modifier); - label_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("weight", 0), 0) - .Broadcast(user_op::OpArg("label", 0)) - .Broadcast(user_op::OpArg("mapped_label", 0)) - .Split(user_op::OpArg("sampled_label", 0), 0) - 
.Split(user_op::OpArg("sampled_weight", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("mapped_label", 0) = ctx->InputDType("label", 0); - *ctx->OutputDType("sampled_weight", 0) = ctx->InputDType("weight", 0); - *ctx->OutputDType("sampled_label", 0) = ctx->InputDType("label", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe DistributedPartialFcSampleOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("weight", 0), 0) + .Broadcast(user_op::OpArg("label", 0)) + .Broadcast(user_op::OpArg("mapped_label", 0)) + .Split(user_op::OpArg("sampled_label", 0), 0) + .Split(user_op::OpArg("sampled_weight", 0), 0) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe DistributedPartialFcSampleOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const int64_t num_sample = ctx->Attr("num_sample"); + const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); + const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); + user_op::TensorDesc* mapped_label = ctx->OutputTensorDesc("mapped_label", 0); + user_op::TensorDesc* sampled_weight = ctx->OutputTensorDesc("sampled_weight", 0); + user_op::TensorDesc* sampled_label = ctx->OutputTensorDesc("sampled_label", 0); + *mapped_label->mut_shape() = label.shape(); + *mapped_label->mut_is_dynamic() = label.is_dynamic(); + *sampled_weight->mut_shape() = weight.shape(); + sampled_weight->mut_shape()->Set(0, num_sample); + *sampled_weight->mut_is_dynamic() = weight.is_dynamic(); + *sampled_label->mut_shape() = label.shape(); + sampled_label->mut_shape()->Set(0, num_sample); + *sampled_label->mut_is_dynamic() = label.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe DistributedPartialFcSampleOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + const int64_t num_sample = ctx->Attr("num_sample"); + const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); + CHECK_EQ_OR_RETURN(num_sample % parallel_num, 0); + const int64_t num_sample_per_rank = num_sample / parallel_num; + const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); + const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); + user_op::TensorDesc* mapped_label = ctx->OutputTensorDesc("mapped_label", 0); + user_op::TensorDesc* sampled_weight = ctx->OutputTensorDesc("sampled_weight", 0); + user_op::TensorDesc* sampled_label = ctx->OutputTensorDesc("sampled_label", 0); + *mapped_label->mut_shape() = label.shape(); + *mapped_label->mut_is_dynamic() = label.is_dynamic(); + *sampled_weight->mut_shape() = weight.shape(); + sampled_weight->mut_shape()->Set(0, num_sample_per_rank); + *sampled_weight->mut_is_dynamic() = weight.is_dynamic(); + *sampled_label->mut_shape() = label.shape(); + sampled_label->mut_shape()->Set(0, num_sample_per_rank); + *sampled_label->mut_is_dynamic() = label.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe DistributedPartialFcSampleOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("mapped_label", 0) = ctx->InputDType("label", 0); + *ctx->OutputDType("sampled_weight", 0) = ctx->InputDType("weight", 0); + *ctx->OutputDType("sampled_label", 0) = ctx->InputDType("label", 0); + return Maybe::Ok(); +} +/*static*/ Maybe DistributedPartialFcSampleOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); + CHECK_NOTNULL_OR_RETURN(label_modifier); + 
label_modifier->set_requires_grad(false); + return Maybe::Ok(); +} -REGISTER_USER_OP("distributed_partial_fc_sample_disable_boxing") - .Input("sampled_weight_diff") - .Input("sampled_label") - .Output("boxing_disabled_sampled_weight_diff") - .Output("boxing_disabled_sampled_label") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* boxing_disabled_sampled_weight_diff = - ctx->OutputTensorDesc("boxing_disabled_sampled_weight_diff", 0); - *boxing_disabled_sampled_weight_diff->mut_shape() = ctx->InputShape("sampled_weight_diff", 0); - CHECK_EQ_OR_RETURN(boxing_disabled_sampled_weight_diff->shape().At(0) % ctx->parallel_num(), - 0); - boxing_disabled_sampled_weight_diff->mut_shape()->Set( - 0, boxing_disabled_sampled_weight_diff->shape().At(0) / ctx->parallel_num()); - *boxing_disabled_sampled_weight_diff->mut_is_dynamic() = - ctx->InputIsDynamic("sampled_weight_diff", 0); - user_op::TensorDesc* boxing_disabled_sampled_label = - ctx->OutputTensorDesc("boxing_disabled_sampled_label", 0); - *boxing_disabled_sampled_label->mut_shape() = ctx->InputShape("sampled_label", 0); - CHECK_EQ_OR_RETURN(boxing_disabled_sampled_label->shape().At(0) % ctx->parallel_num(), 0); - boxing_disabled_sampled_label->mut_shape()->Set( - 0, boxing_disabled_sampled_label->shape().At(0) / ctx->parallel_num()); - *boxing_disabled_sampled_label->mut_is_dynamic() = ctx->InputIsDynamic("sampled_label", 0); - return Maybe::Ok(); - }) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("boxing_disabled_sampled_weight_diff", 0) = - ctx->InputShape("sampled_weight_diff", 0); - *ctx->OutputIsDynamic("boxing_disabled_sampled_weight_diff", 0) = - ctx->InputIsDynamic("sampled_weight_diff", 0); - *ctx->OutputShape("boxing_disabled_sampled_label", 0) = ctx->InputShape("sampled_label", 0); - *ctx->OutputIsDynamic("boxing_disabled_sampled_label", 0) = - ctx->InputIsDynamic("sampled_label", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("sampled_weight_diff", 0), 0) - .Split(user_op::OpArg("sampled_label", 0), 0) - .Broadcast(user_op::OpArg("boxing_disabled_sampled_weight_diff", 0)) - .Broadcast(user_op::OpArg("boxing_disabled_sampled_label", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("boxing_disabled_sampled_weight_diff", 0) = - ctx->InputDType("sampled_weight_diff", 0); - *ctx->OutputDType("boxing_disabled_sampled_label", 0) = ctx->InputDType("sampled_label", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("sampled_weight_diff", 0), 0) + .Split(user_op::OpArg("sampled_label", 0), 0) + .Broadcast(user_op::OpArg("boxing_disabled_sampled_weight_diff", 0)) + .Broadcast(user_op::OpArg("boxing_disabled_sampled_label", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + user_op::TensorDesc* boxing_disabled_sampled_weight_diff = + ctx->OutputTensorDesc("boxing_disabled_sampled_weight_diff", 0); + *boxing_disabled_sampled_weight_diff->mut_shape() = ctx->InputShape("sampled_weight_diff", 0); + CHECK_EQ_OR_RETURN(boxing_disabled_sampled_weight_diff->shape().At(0) % ctx->parallel_num(), 0); + boxing_disabled_sampled_weight_diff->mut_shape()->Set( + 0, 
boxing_disabled_sampled_weight_diff->shape().At(0) / ctx->parallel_num()); + *boxing_disabled_sampled_weight_diff->mut_is_dynamic() = + ctx->InputIsDynamic("sampled_weight_diff", 0); + user_op::TensorDesc* boxing_disabled_sampled_label = + ctx->OutputTensorDesc("boxing_disabled_sampled_label", 0); + *boxing_disabled_sampled_label->mut_shape() = ctx->InputShape("sampled_label", 0); + CHECK_EQ_OR_RETURN(boxing_disabled_sampled_label->shape().At(0) % ctx->parallel_num(), 0); + boxing_disabled_sampled_label->mut_shape()->Set( + 0, boxing_disabled_sampled_label->shape().At(0) / ctx->parallel_num()); + *boxing_disabled_sampled_label->mut_is_dynamic() = ctx->InputIsDynamic("sampled_label", 0); + return Maybe::Ok(); +} +/*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("boxing_disabled_sampled_weight_diff", 0) = + ctx->InputShape("sampled_weight_diff", 0); + *ctx->OutputIsDynamic("boxing_disabled_sampled_weight_diff", 0) = + ctx->InputIsDynamic("sampled_weight_diff", 0); + *ctx->OutputShape("boxing_disabled_sampled_label", 0) = ctx->InputShape("sampled_label", 0); + *ctx->OutputIsDynamic("boxing_disabled_sampled_label", 0) = + ctx->InputIsDynamic("sampled_label", 0); + return Maybe::Ok(); +} +/*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("boxing_disabled_sampled_weight_diff", 0) = + ctx->InputDType("sampled_weight_diff", 0); + *ctx->OutputDType("boxing_disabled_sampled_label", 0) = ctx->InputDType("sampled_label", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("distributed_partial_fc_sample") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/pool_op.cpp b/oneflow/user/ops/pool_op.cpp index 238c04e4c08..39afc8478b8 100644 --- a/oneflow/user/ops/pool_op.cpp +++ b/oneflow/user/ops/pool_op.cpp @@ -15,6 +15,7 @@ limitations under the License. 
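As a side note on the partial-FC sampling hooks above: the physical (per-rank) shape inference assumes the global num_sample divides evenly across the parallel ranks, so each rank's sampled_weight / sampled_label leading dimension becomes num_sample / parallel_num. A tiny illustrative sketch of that arithmetic (names hypothetical, not part of the patch):

#include <cstdint>
#include <stdexcept>

// Per-rank sample count, mirroring the divisibility guard followed by the
// division in the physical tensor-desc inference above.
int64_t NumSamplePerRank(int64_t num_sample, int64_t parallel_num) {
  if (num_sample % parallel_num != 0) {
    throw std::invalid_argument("num_sample must be divisible by parallel_num");
  }
  return num_sample / parallel_num;
}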
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/utils/pool_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -113,49 +114,47 @@ GenBackwardOpConfFn MakeGenBackwardOpConfFn(const std::string& mode, const int32 } // namespace -#define REGISTER_TF_AVG_POOL_FORWARD_OP(name, dim) \ - REGISTER_USER_OP(name) \ - .Input("x") \ - .Output("y") \ - .Attr("padding") \ - .Attr>("padding_before") \ - .Attr>("padding_after") \ - .Attr("data_format") \ - .Attr>("pool_size") \ - .Attr>("strides") \ - .Attr("ceil_mode") \ - .SetTensorDescInferFn(MakeFwTensorDescInferFn(dim)) \ - .SetGetSbpFn(FwGetSbpFn) \ - .SetDataTypeInferFn(FwInferDataType); - -REGISTER_TF_AVG_POOL_FORWARD_OP("tf_avg_pool_1d", 1) -REGISTER_TF_AVG_POOL_FORWARD_OP("tf_avg_pool_2d", 2) -REGISTER_TF_AVG_POOL_FORWARD_OP("tf_avg_pool_3d", 3) - -#undef REGISTER_TF_AVG_POOL_FORWARD_OP - -#define REGISTER_TF_AVG_POOL_BACKWARD_OP(name) \ - REGISTER_USER_OP(name) \ - .Input("x") \ - .Input("y") \ - .Input("dy") \ - .Output("dx") \ - .Attr("padding") \ - .Attr>("padding_before") \ - .Attr>("padding_after") \ - .Attr("data_format") \ - .Attr>("pool_size") \ - .Attr>("strides") \ - .Attr("ceil_mode") \ - .SetTensorDescInferFn(BwTensorDescInferFn) \ - .SetGetSbpFn(BwGetSbpFn) \ - .SetDataTypeInferFn(BwInferDataType); - -REGISTER_TF_AVG_POOL_BACKWARD_OP("tf_avg_pool_1d_grad") -REGISTER_TF_AVG_POOL_BACKWARD_OP("tf_avg_pool_2d_grad") -REGISTER_TF_AVG_POOL_BACKWARD_OP("tf_avg_pool_3d_grad") - -#undef REGISTER_TF_AVG_POOL_FORWARD_OP +#define IMPLEMENT_TF_POOL_FUNCS(name, dim) \ + /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { return FwGetSbpFn(ctx); } \ + /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return MakeFwTensorDescInferFn(dim)(ctx); \ + } \ + /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ + return FwInferDataType(ctx); \ + } + +IMPLEMENT_TF_POOL_FUNCS(TfAvgPool1D, 1) +IMPLEMENT_TF_POOL_FUNCS(TfAvgPool2D, 2) +IMPLEMENT_TF_POOL_FUNCS(TfAvgPool3D, 3) +IMPLEMENT_TF_POOL_FUNCS(TfMaxPool1D, 1) +IMPLEMENT_TF_POOL_FUNCS(TfMaxPool2D, 2) +IMPLEMENT_TF_POOL_FUNCS(TfMaxPool3D, 3) +#undef IMPLEMENT_TF_POOL_FUNCS + +#define IMPLEMENT_TF_POOL_BACKWARD_FUNCS(name) \ + /*static*/ Maybe name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ + return BwGetSbpFn(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return BwTensorDescInferFn(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferDataType(user_op::InferContext* ctx) { \ + return BwInferDataType(ctx); \ + } + +IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfAvgPool1D) +IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfAvgPool2D) +IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfAvgPool3D) +IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfMaxPool1D) +IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfMaxPool2D) +IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfMaxPool3D) +#undef IMPLEMENT_TF_POOL_BACKWARD_FUNCS REGISTER_USER_OP_GRAD("tf_avg_pool_1d") .SetGenBackwardOpConfFn(MakeGenBackwardOpConfFn("tf_avg", 1)); @@ -164,50 +163,6 @@ REGISTER_USER_OP_GRAD("tf_avg_pool_2d") REGISTER_USER_OP_GRAD("tf_avg_pool_3d") .SetGenBackwardOpConfFn(MakeGenBackwardOpConfFn("tf_avg", 3)); -#define REGISTER_TF_MAX_POOL_FORWARD_OP(name, dim) \ - REGISTER_USER_OP(name) \ - 
.Input("x") \ - .Output("y") \ - .Attr("padding") \ - .Attr>("padding_before") \ - .Attr>("padding_after") \ - .Attr("data_format") \ - .Attr>("pool_size") \ - .Attr>("strides") \ - .Attr("ceil_mode") \ - .SetTensorDescInferFn(MakeFwTensorDescInferFn(dim)) \ - .SetGetSbpFn(FwGetSbpFn) \ - .SetDataTypeInferFn(FwInferDataType); - -REGISTER_TF_MAX_POOL_FORWARD_OP("tf_max_pool_1d", 1) -REGISTER_TF_MAX_POOL_FORWARD_OP("tf_max_pool_2d", 2) -REGISTER_TF_MAX_POOL_FORWARD_OP("tf_max_pool_3d", 3) - -#undef REGISTER_TF_MAX_POOL_FORWARD_OP - -#define REGISTER_TF_MAX_POOL_BACKWARD_OP(name) \ - REGISTER_USER_OP(name) \ - .Input("x") \ - .Input("y") \ - .Input("dy") \ - .Output("dx") \ - .Attr("padding") \ - .Attr>("padding_before") \ - .Attr>("padding_after") \ - .Attr("data_format") \ - .Attr>("pool_size") \ - .Attr>("strides") \ - .Attr("ceil_mode") \ - .SetTensorDescInferFn(BwTensorDescInferFn) \ - .SetGetSbpFn(BwGetSbpFn) \ - .SetDataTypeInferFn(BwInferDataType); - -REGISTER_TF_MAX_POOL_BACKWARD_OP("tf_max_pool_1d_grad") -REGISTER_TF_MAX_POOL_BACKWARD_OP("tf_max_pool_2d_grad") -REGISTER_TF_MAX_POOL_BACKWARD_OP("tf_max_pool_3d_grad") - -#undef REGISTER_TF_MAX_POOL_BACKWARD_OP - REGISTER_USER_OP_GRAD("tf_max_pool_1d") .SetGenBackwardOpConfFn(MakeGenBackwardOpConfFn("tf_max", 1)); REGISTER_USER_OP_GRAD("tf_max_pool_2d") diff --git a/oneflow/user/ops/pooling_op.cpp b/oneflow/user/ops/pooling_op.cpp index 46bddf986ce..c46b85188d3 100644 --- a/oneflow/user/ops/pooling_op.cpp +++ b/oneflow/user/ops/pooling_op.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/pooling_kernel_util.h" #include "oneflow/user/kernels/avg_pooling_kernel_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -27,7 +28,7 @@ typedef std::function(const user_op::UserOpWrapper& op, user_op::Add TensorDescInferFn MaxPoolMakeForwardTensorDescInferFn(const int32_t dim) { return [dim](user_op::InferContext* ctx) -> Maybe { - const Shape* x_shape = ctx->Shape4ArgNameAndIndex("x", 0); + const Shape& x_shape = ctx->InputShape("x", 0); const std::string& data_format = ctx->Attr("data_format"); const std::vector& padding = ctx->Attr>("padding"); const std::vector& kernel_size = ctx->Attr>("kernel_size"); @@ -45,14 +46,14 @@ TensorDescInferFn MaxPoolMakeForwardTensorDescInferFn(const int32_t dim) { << "pad should be smaller than half of kernel size"; } - const MaxPoolingParams3D params_3d(dim, *x_shape, data_format, padding, kernel_size, stride, + const MaxPoolingParams3D params_3d(dim, x_shape, data_format, padding, kernel_size, stride, dilation, return_indices, ceil_mode); - user_op::TensorDesc* y_desc = ctx->TensorDesc4ArgNameAndIndex("y", 0); - *y_desc = *ctx->TensorDesc4ArgNameAndIndex("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + *y_desc = ctx->InputTensorDesc("x", 0); *y_desc->mut_shape() = params_3d.GetYShape(); - user_op::TensorDesc* indice_desc = ctx->TensorDesc4ArgNameAndIndex("indice", 0); - *indice_desc = *ctx->TensorDesc4ArgNameAndIndex("y", 0); + user_op::TensorDesc* indice_desc = ctx->OutputTensorDesc("indice", 0); + *indice_desc = *ctx->OutputTensorDesc("y", 0); *indice_desc->mut_shape() = *y_desc->mut_shape(); DataType* dtype = indice_desc->mut_data_type(); *dtype = kInt64; @@ -82,8 +83,8 @@ TensorDescInferFn AvgPoolMakeForwardTensorDescInferFn(const int32_t dim) { const AvgPoolingParams3D params_3d(dim, *x_shape, data_format, padding, kernel_size, stride, ceil_mode, count_include_pad, 
divisor_override); - user_op::TensorDesc* y_desc = ctx->TensorDesc4ArgNameAndIndex("y", 0); - *y_desc = *ctx->TensorDesc4ArgNameAndIndex("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + *y_desc = ctx->InputTensorDesc("x", 0); *y_desc->mut_shape() = params_3d.GetYShape(); return Maybe::Ok(); @@ -189,7 +190,7 @@ GenBackwardOpConfFn AvgPoolMakeBackwardOpConfFn(const int32_t dim) { } Maybe BackwardTensorDescInferFn(user_op::InferContext* ctx) { - *ctx->TensorDesc4ArgNameAndIndex("dx", 0) = *ctx->TensorDesc4ArgNameAndIndex("x", 0); + *ctx->OutputTensorDesc("dx", 0) = ctx->InputTensorDesc("x", 0); return Maybe::Ok(); } @@ -204,99 +205,85 @@ Maybe BwInferDataType(user_op::InferContext* ctx) { } } // namespace -#define REGISTER_MAXPOOL_FORWARD_OP(name, dim) \ - REGISTER_USER_OP(name) \ - .Input("x") \ - .Output("y") \ - .Output("indice") \ - .Attr>("padding") \ - .Attr("data_format") \ - .Attr>("kernel_size") \ - .Attr>("stride") \ - .Attr>("dilation") \ - .Attr("return_indices") \ - .Attr("ceil_mode") \ - .SetTensorDescInferFn(MaxPoolMakeForwardTensorDescInferFn(dim)) \ - .SetGetSbpFn(MaxPoolForwardGetSbpFn) \ - .SetDataTypeInferFn(FwInferDataType); - -REGISTER_MAXPOOL_FORWARD_OP("maxpool_1d", 1) -REGISTER_MAXPOOL_FORWARD_OP("maxpool_2d", 2) -REGISTER_MAXPOOL_FORWARD_OP("maxpool_3d", 3) - -#undef REGISTER_MAXPOOL_FORWARD_OP - -#define REGISTER_MAXPOOL_BACKWARD_OP(name) \ - REGISTER_USER_OP(name) \ - .Input("x") \ - .Input("y") \ - .Input("indice") \ - .Input("dy") \ - .Output("dx") \ - .Attr>("padding") \ - .Attr("data_format") \ - .Attr>("kernel_size") \ - .Attr>("stride") \ - .Attr>("dilation") \ - .Attr("return_indices") \ - .Attr("ceil_mode") \ - .SetTensorDescInferFn(BackwardTensorDescInferFn) \ - .SetGetSbpFn(MaxPoolBackwardGetSbpFn) \ - .SetDataTypeInferFn(BwInferDataType); - -REGISTER_MAXPOOL_BACKWARD_OP("maxpool_1d_grad") -REGISTER_MAXPOOL_BACKWARD_OP("maxpool_2d_grad") -REGISTER_MAXPOOL_BACKWARD_OP("maxpool_3d_grad") - -#undef REGISTER_MAXPOOL_BACKWARD_OP +#define IMPLEMENT_MAXPOOL_FUNCS(name, dim) \ + /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { \ + return MaxPoolForwardGetSbpFn(ctx); \ + } \ + /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return MaxPoolMakeForwardTensorDescInferFn(dim)(ctx); \ + } \ + /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ + return FwInferDataType(ctx); \ + } + +IMPLEMENT_MAXPOOL_FUNCS(MaxPool1D, 1) +IMPLEMENT_MAXPOOL_FUNCS(MaxPool2D, 2) +IMPLEMENT_MAXPOOL_FUNCS(MaxPool3D, 3) +#undef IMPLEMENT_MAXPOOL_FUNCS + +#define IMPLEMENT_MAXPOOL_BACKWARD_FUNCS(name) \ + /*static*/ Maybe name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ + return MaxPoolBackwardGetSbpFn(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return BackwardTensorDescInferFn(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferDataType(user_op::InferContext* ctx) { \ + return BwInferDataType(ctx); \ + } + +IMPLEMENT_MAXPOOL_BACKWARD_FUNCS(MaxPool1D) +IMPLEMENT_MAXPOOL_BACKWARD_FUNCS(MaxPool2D) +IMPLEMENT_MAXPOOL_BACKWARD_FUNCS(MaxPool3D) +#undef IMPLEMENT_MAXPOOL_BACKWARD_FUNCS REGISTER_USER_OP_GRAD("maxpool_1d").SetGenBackwardOpConfFn(MaxPoolMakeBackwardOpConfFn("max", 1)); 
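// NOTE(illustrative): IMPLEMENT_MAXPOOL_FUNCS above only fills in the static members that
// the generated class in op_generated.h declares. Assuming the usual Maybe<void> signatures,
// one expansion, IMPLEMENT_MAXPOOL_FUNCS(MaxPool1D, 1), is equivalent to writing by hand:
//
//   /*static*/ Maybe<void> MaxPool1DOp::GetSbp(user_op::SbpContext* ctx) {
//     return MaxPoolForwardGetSbpFn(ctx);
//   }
//   /*static*/ Maybe<void> MaxPool1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
//     return MaxPoolMakeForwardTensorDescInferFn(1)(ctx);
//   }
//   /*static*/ Maybe<void> MaxPool1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
//     return InferLogicalTensorDesc(ctx);
//   }
//   /*static*/ Maybe<void> MaxPool1DOp::InferDataType(user_op::InferContext* ctx) {
//     return FwInferDataType(ctx);
//   }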
REGISTER_USER_OP_GRAD("maxpool_2d").SetGenBackwardOpConfFn(MaxPoolMakeBackwardOpConfFn("max", 2)); REGISTER_USER_OP_GRAD("maxpool_3d").SetGenBackwardOpConfFn(MaxPoolMakeBackwardOpConfFn("max", 3)); -#define REGISTER_AVGPOOL_FORWARD_OP(name, ndim) \ - REGISTER_USER_OP(name) \ - .Input("x") \ - .Output("y") \ - .Attr>("padding") \ - .Attr("data_format") \ - .Attr>("kernel_size") \ - .Attr>("stride") \ - .Attr("ceil_mode") \ - .Attr("count_include_pad") \ - .Attr("divisor_override") \ - .SetTensorDescInferFn(AvgPoolMakeForwardTensorDescInferFn(ndim)) \ - .SetGetSbpFn(AvgPoolForwardGetSbpFn) \ - .SetDataTypeInferFn(FwInferDataType); - -REGISTER_AVGPOOL_FORWARD_OP("avgpool_1d", 1); -REGISTER_AVGPOOL_FORWARD_OP("avgpool_2d", 2); -REGISTER_AVGPOOL_FORWARD_OP("avgpool_3d", 3); - -#undef REGISTER_AVGPOOL_FORWARD_OP - -#define REGISTER_AVGPOOL_BACKWARD_OP(name) \ - REGISTER_USER_OP(name) \ - .Input("x") \ - .Input("y") \ - .Input("dy") \ - .Output("dx") \ - .Attr>("padding") \ - .Attr("data_format") \ - .Attr>("kernel_size") \ - .Attr>("stride") \ - .Attr("ceil_mode") \ - .Attr("count_include_pad") \ - .Attr("divisor_override") \ - .SetTensorDescInferFn(BackwardTensorDescInferFn) \ - .SetGetSbpFn(AvgPoolBackwardGetSbpFn) \ - .SetDataTypeInferFn(BwInferDataType); - -REGISTER_AVGPOOL_BACKWARD_OP("avgpool_1d_grad"); -REGISTER_AVGPOOL_BACKWARD_OP("avgpool_2d_grad"); -REGISTER_AVGPOOL_BACKWARD_OP("avgpool_3d_grad"); - -#undef REGISTER_AVGPOOL_BACKWARD_OP +#define IMPLEMENT_AVGPOOL_FUNCS(name, ndim) \ + /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { \ + return AvgPoolForwardGetSbpFn(ctx); \ + } \ + /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return AvgPoolMakeForwardTensorDescInferFn(ndim)(ctx); \ + } \ + /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ + return FwInferDataType(ctx); \ + } + +IMPLEMENT_AVGPOOL_FUNCS(AvgPool1D, 1) +IMPLEMENT_AVGPOOL_FUNCS(AvgPool2D, 2) +IMPLEMENT_AVGPOOL_FUNCS(AvgPool3D, 3) +#undef IMPLEMENT_AVGPOOL_FUNCS + +#define IMPLEMENT_AVGPOOL_BACKWARD_FUNCS(name) \ + /*static*/ Maybe name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ + return AvgPoolBackwardGetSbpFn(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return BackwardTensorDescInferFn(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferDataType(user_op::InferContext* ctx) { \ + return BwInferDataType(ctx); \ + } + +IMPLEMENT_AVGPOOL_BACKWARD_FUNCS(AvgPool1D) +IMPLEMENT_AVGPOOL_BACKWARD_FUNCS(AvgPool2D) +IMPLEMENT_AVGPOOL_BACKWARD_FUNCS(AvgPool3D) +#undef IMPLEMENT_AVGPOOL_BACKWARD_FUNCS REGISTER_USER_OP_GRAD("avgpool_1d").SetGenBackwardOpConfFn(AvgPoolMakeBackwardOpConfFn(1)); REGISTER_USER_OP_GRAD("avgpool_2d").SetGenBackwardOpConfFn(AvgPoolMakeBackwardOpConfFn(2)); diff --git a/oneflow/user/ops/prelu_op.cpp b/oneflow/user/ops/prelu_op.cpp index c100cbe1b59..c6104318156 100644 --- a/oneflow/user/ops/prelu_op.cpp +++ b/oneflow/user/ops/prelu_op.cpp @@ -14,105 +14,101 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("prelu") - .Input("x") - .Input("alpha") - .Output("y") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - Shape* y_shape = ctx->OutputShape("y", 0); - const Shape& alpha_shape = ctx->InputShape("alpha", 0); - CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1); - *y_shape = x_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - const user_op::TensorDesc& alpha_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0); - if (alpha_tensor.shape().At(0) != 1) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), 1) - .Split(user_op::OpArg("alpha", 0), 0) - .Split(user_op::OpArg("y", 0), 1) - .Build(); - } - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - if (i == 1) continue; - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Broadcast(user_op::OpArg("alpha", 0)) - .Split(user_op::OpArg("y", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe PreluOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + const user_op::TensorDesc& alpha_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0); + if (alpha_tensor.shape().At(0) != 1) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), 1) + .Split(user_op::OpArg("alpha", 0), 0) + .Split(user_op::OpArg("y", 0), 1) + .Build(); + } + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + if (i == 1) continue; + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Broadcast(user_op::OpArg("alpha", 0)) + .Split(user_op::OpArg("y", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe PreluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + Shape* y_shape = ctx->OutputShape("y", 0); + const Shape& alpha_shape = ctx->InputShape("alpha", 0); + CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1); + *y_shape = x_shape; + return Maybe::Ok(); +} +/*static*/ Maybe PreluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe PreluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("prelu_grad") - .Input("dy") - .Input("x") - .Input("alpha") - .Output("dx") - .Output("alpha_diff") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - Shape* alpha_diff_shape = ctx->OutputShape("alpha_diff", 0); - const Shape& alpha_shape = ctx->InputShape("alpha", 0); - CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1); - CHECK_OR_RETURN((alpha_shape.At(0) == x_shape.At(1)) || (alpha_shape.At(0) == 1)); - CHECK_EQ_OR_RETURN(dy_shape, x_shape); - *dx_shape = x_shape; - *alpha_diff_shape = alpha_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - ctx->NewBuilder() - 
.Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("x", 0), 0) - .Broadcast(user_op::OpArg("alpha", 0)) - .Split(user_op::OpArg("dx", 0), 0) - .PartialSum(user_op::OpArg("alpha_diff", 0)) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("dy", 0)) - .Broadcast(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("alpha", 0)) - .PartialSum(user_op::OpArg("dx", 0)) - .PartialSum(user_op::OpArg("alpha_diff", 0)) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 1) - .Split(user_op::OpArg("x", 0), 1) - .Split(user_op::OpArg("alpha", 0), 0) - .Split(user_op::OpArg("dx", 0), 1) - .Split(user_op::OpArg("alpha_diff", 0), 0) - .Build(); - FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("alpha", 0), 0) - .Split(user_op::OpArg("dx", 0), i) - .Split(user_op::OpArg("alpha_diff", 0), 0) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("alpha_diff", 0) = ctx->InputDType("alpha", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe PreluGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 0) + .Split(user_op::OpArg("x", 0), 0) + .Broadcast(user_op::OpArg("alpha", 0)) + .Split(user_op::OpArg("dx", 0), 0) + .PartialSum(user_op::OpArg("alpha_diff", 0)) + .Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("dy", 0)) + .Broadcast(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("alpha", 0)) + .PartialSum(user_op::OpArg("dx", 0)) + .PartialSum(user_op::OpArg("alpha_diff", 0)) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 1) + .Split(user_op::OpArg("x", 0), 1) + .Split(user_op::OpArg("alpha", 0), 0) + .Split(user_op::OpArg("dx", 0), 1) + .Split(user_op::OpArg("alpha_diff", 0), 0) + .Build(); + FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("alpha", 0), 0) + .Split(user_op::OpArg("dx", 0), i) + .Split(user_op::OpArg("alpha_diff", 0), 0) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe PreluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* alpha_diff_shape = ctx->OutputShape("alpha_diff", 0); + const Shape& alpha_shape = ctx->InputShape("alpha", 0); + CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1); + CHECK_OR_RETURN((alpha_shape.At(0) == x_shape.At(1)) || (alpha_shape.At(0) == 1)); + CHECK_EQ_OR_RETURN(dy_shape, x_shape); + *dx_shape = x_shape; + *alpha_diff_shape = alpha_shape; + return Maybe::Ok(); +} +/*static*/ Maybe PreluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe PreluGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->OutputDType("alpha_diff", 0) = ctx->InputDType("alpha", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("prelu").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/quantization_op.cpp b/oneflow/user/ops/quantization_op.cpp index 
b67e707113e..2396a1a1685 100644 --- a/oneflow/user/ops/quantization_op.cpp +++ b/oneflow/user/ops/quantization_op.cpp @@ -14,105 +14,93 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe QuantizationOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + const Shape& logical_scale_shape = + ctx->LogicalTensorDesc4InputArgNameAndIndex("scale", 0).shape(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("in", 0)) + .Broadcast(user_op::OpArg("scale", 0)) + .Broadcast(user_op::OpArg("zero_point", 0)) + .Broadcast(user_op::OpArg("out", 0)) + .Build(); + if (logical_scale_shape.elem_cnt() > 1) { + // NOTE(Liang Depeng): only consider convolution weight per-channel quantization + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), 0) + .Split(user_op::OpArg("scale", 0), 0) + .Split(user_op::OpArg("zero_point", 0), 0) + .Split(user_op::OpArg("out", 0), 0) + .Build(); + } else { + // NOTE(Liang Depeng): the sbp signature of per-layer quantization is the same as eltwise + // ops + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), 0) + .Broadcast(user_op::OpArg("scale", 0)) + .Broadcast(user_op::OpArg("zero_point", 0)) + .Split(user_op::OpArg("out", 0), 0) + .Build(); + } + FOR_RANGE(int64_t, i, 1, in_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Broadcast(user_op::OpArg("scale", 0)) + .Broadcast(user_op::OpArg("zero_point", 0)) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe QuantizationOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + const Shape& scale_shape = ctx->InputShape("scale", 0); + const Shape& zero_point_shape = ctx->InputShape("zero_point", 0); -REGISTER_USER_OP("quantization") - .Input("in") - .Input("scale") - .Input("zero_point") - .Output("out") - // NOTE(Liang Depeng): "google" or "cambricon" - .Attr("quantization_formula", "google") - // NOTE(Liang Depeng): quantize from float32 to "quantization_bit" bit signed or unsigned - // integer - .Attr("quantization_bit", 8) - // NOTE(Liang Depeng): "symmetric" or "affine": quantize to signed or unsigned integer - .Attr("quantization_scheme", "symmetric") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - const Shape& scale_shape = ctx->InputShape("scale", 0); - const Shape& zero_point_shape = ctx->InputShape("zero_point", 0); + // NOTE(Liang Depeng): scale_shape->elem_cnt() > 1 means per-channel quantization for + // convolution weights. + if (scale_shape.elem_cnt() > 1) { + CHECK_EQ_OR_RETURN(scale_shape.elem_cnt(), in_shape.At(0)); + CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0)); + } - // NOTE(Liang Depeng): scale_shape->elem_cnt() > 1 means per-channel quantization for - // convolution weights. 
- if (scale_shape.elem_cnt() > 1) { - CHECK_EQ_OR_RETURN(scale_shape.elem_cnt(), in_shape.At(0)); - CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0)); - } + *ctx->OutputShape("out", 0) = in_shape; + return Maybe::Ok(); +} +/*static*/ Maybe QuantizationOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe QuantizationOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe QuantizationOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* scale = GetInputArgModifierFn("scale", 0); + CHECK_OR_RETURN(scale != nullptr); + scale->set_requires_grad(false); - *ctx->OutputShape("out", 0) = in_shape; - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* scale = GetInputArgModifierFn("scale", 0); - CHECK_OR_RETURN(scale != nullptr); - scale->set_requires_grad(false); + user_op::InputArgModifier* zero_point = GetInputArgModifierFn("zero_point", 0); + CHECK_OR_RETURN(zero_point != nullptr); + zero_point->set_requires_grad(false); + return Maybe::Ok(); +} +/*static*/ Maybe QuantizationOp::CheckAttr(const user_op::UserOpDefWrapper&, + const user_op::UserOpConfWrapper& op_conf) { + const int32_t quantization_bit = op_conf.attr("quantization_bit"); + CHECK_GT_OR_RETURN(quantization_bit, 1); + CHECK_LE_OR_RETURN(quantization_bit, 8); - user_op::InputArgModifier* zero_point = GetInputArgModifierFn("zero_point", 0); - CHECK_OR_RETURN(zero_point != nullptr); - zero_point->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - const Shape& logical_scale_shape = - ctx->LogicalTensorDesc4InputArgNameAndIndex("scale", 0).shape(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("in", 0)) - .Broadcast(user_op::OpArg("scale", 0)) - .Broadcast(user_op::OpArg("zero_point", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - if (logical_scale_shape.elem_cnt() > 1) { - // NOTE(Liang Depeng): only consider convolution weight per-channel quantization - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Split(user_op::OpArg("scale", 0), 0) - .Split(user_op::OpArg("zero_point", 0), 0) - .Split(user_op::OpArg("out", 0), 0) - .Build(); - } else { - // NOTE(Liang Depeng): the sbp signature of per-layer quantization is the same as eltwise - // ops - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Broadcast(user_op::OpArg("scale", 0)) - .Broadcast(user_op::OpArg("zero_point", 0)) - .Split(user_op::OpArg("out", 0), 0) - .Build(); - } - FOR_RANGE(int64_t, i, 1, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Broadcast(user_op::OpArg("scale", 0)) - .Broadcast(user_op::OpArg("zero_point", 0)) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def, - const user_op::UserOpConfWrapper& op_conf) -> Maybe { - const int32_t quantization_bit = op_conf.attr("quantization_bit"); - CHECK_GT_OR_RETURN(quantization_bit, 1); - CHECK_LE_OR_RETURN(quantization_bit, 8); + std::string quantization_scheme = op_conf.attr("quantization_scheme"); + CHECK_OR_RETURN(quantization_scheme 
== "symmetric" || quantization_scheme == "affine"); - std::string quantization_scheme = op_conf.attr("quantization_scheme"); - CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine"); - - std::string quantization_formula = op_conf.attr("quantization_formula"); - CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon"); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -} // namespace + std::string quantization_formula = op_conf.attr("quantization_formula"); + CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon"); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/randperm_op.cpp b/oneflow/user/ops/randperm_op.cpp index 8b70d603775..ae52c9b2938 100644 --- a/oneflow/user/ops/randperm_op.cpp +++ b/oneflow/user/ops/randperm_op.cpp @@ -14,42 +14,28 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" -#include "oneflow/core/common/global.h" -#include "oneflow/core/common/multi_client.h" -#include "oneflow/core/common/protobuf.h" -#include "oneflow/core/job/global_for.h" -namespace oneflow { +#include "oneflow/core/framework/op_generated.h" -Maybe InferRandpermNdSbp(user_op::InferNdSbpFnContext* ctx); -REGISTER_NO_GRAD_USER_OP("randperm") - .Output("out") - .Attr("n") - .Attr("seed") - .Attr("nd_sbp") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - int32_t n = ctx->Attr("n"); - CHECK_GE_OR_RETURN(n, 0); - *out_shape = Shape({n}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kInt32; - return Maybe::Ok(); - }) - .SetNdSbpInferFn(&InferRandpermNdSbp); +namespace oneflow { -Maybe InferRandpermNdSbp(user_op::InferNdSbpFnContext* ctx) { - cfg::NdSbp* out = ctx->NdSbp4ArgNameAndIndex("out", 0); - if (JUST(IsMultiClient())) { - const auto& pb_str = ctx->user_op_conf().attr("nd_sbp"); - NdSbp pb; - CHECK_OR_RETURN(TxtString2PbMessage(pb_str, &pb)); - out->InitFromProto(pb); - } else { - out->mutable_sbp_parallel()->Add()->mutable_broadcast_parallel(); - } +/*static*/ Maybe RandpermOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::SbpParallel default_sbp; + default_sbp.mutable_broadcast_parallel(); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} +/*static*/ Maybe RandpermOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } +/*static*/ Maybe RandpermOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + int32_t n = ctx->Attr("n"); + CHECK_GE_OR_RETURN(n, 0); + *out_shape = Shape({n}); + return Maybe::Ok(); +} +/*static*/ Maybe RandpermOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe RandpermOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kInt32; return Maybe::Ok(); } diff --git a/oneflow/user/ops/reduce_like_ops.cpp b/oneflow/user/ops/reduce_like_ops.cpp index c8fc889debc..898e81ade46 100644 --- a/oneflow/user/ops/reduce_like_ops.cpp +++ b/oneflow/user/ops/reduce_like_ops.cpp @@ -15,81 +15,80 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/reduce_sbp_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("reduce_sum_like") - .Input("x") - .Input("like") - .Output("y") - .Attr>("axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& like_tensor = ctx->InputTensorDesc("like", 0); - const auto& axis = ctx->Attr>("axis"); - if (axis.empty()) { CHECK_EQ_OR_RETURN(x_tensor.shape(), like_tensor.shape()); } - user_op::TensorDesc* y_tensor = ctx->OutputTensorDesc("y", 0); - *y_tensor->mut_shape() = like_tensor.shape(); - *y_tensor->mut_is_dynamic() = like_tensor.is_dynamic(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - int32_t num_axes = 0; - HashSet conf_axes; - { - const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - num_axes = in_tensor.shape().NumAxes(); - const auto& reduced_axes = ctx->Attr>("axis"); - ReduceSbpUtil::GetRegularAxes(num_axes, reduced_axes, &conf_axes); - } - const auto& like_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape().NumAxes(); - const bool keep_dims = (num_axes == like_num_axes); - if (!keep_dims) { CHECK_EQ_OR_RETURN(conf_axes.size(), num_axes - like_num_axes); } - auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(conf_axes, num_axes); - int64_t num_reduced_axes = 0; - FOR_RANGE(int64_t, i, 0, num_axes) { - if (IsReducedAxis(i)) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Broadcast(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .PartialSum(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - num_reduced_axes += 1; - } else { - const int64_t out_split_axis = keep_dims ? i : i - num_reduced_axes; - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("like", 0), out_split_axis) - .Split(user_op::OpArg("y", 0), out_split_axis) - .Build(); - } - } +/*static*/ Maybe ReduceSumLikeOp::GetSbp(user_op::SbpContext* ctx) { + int32_t num_axes = 0; + HashSet conf_axes; + { + const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + num_axes = in_tensor.shape().NumAxes(); + const auto& reduced_axes = ctx->Attr>("axis"); + ReduceSbpUtil::GetRegularAxes(num_axes, reduced_axes, &conf_axes); + } + const auto& like_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape().NumAxes(); + const bool keep_dims = (num_axes == like_num_axes); + if (!keep_dims) { CHECK_EQ_OR_RETURN(conf_axes.size(), num_axes - like_num_axes); } + auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(conf_axes, num_axes); + int64_t num_reduced_axes = 0; + FOR_RANGE(int64_t, i, 0, num_axes) { + if (IsReducedAxis(i)) { ctx->NewBuilder() - .Broadcast(user_op::OpArg("x", 0)) + .Split(user_op::OpArg("x", 0), i) + .Broadcast(user_op::OpArg("like", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) .PartialSum(user_op::OpArg("like", 0)) - .Broadcast(user_op::OpArg("y", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + num_reduced_axes += 1; + } else { + const int64_t out_split_axis = keep_dims ? 
i : i - num_reduced_axes; + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("like", 0), out_split_axis) + .Split(user_op::OpArg("y", 0), out_split_axis) .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& like_tensor = ctx->InputTensorDesc("like", 0); - CHECK_EQ_OR_RETURN(x_tensor.data_type(), like_tensor.data_type()); - *ctx->OutputDType("y", 0) = like_tensor.data_type(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* like_arg_modifier = GetInputArgModifierFn("like", 0); - CHECK_OR_RETURN(like_arg_modifier != nullptr); - like_arg_modifier->set_requires_grad(false); - return Maybe::Ok(); - }); + } + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("like", 0)) + .Broadcast(user_op::OpArg("y", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe ReduceSumLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& like_tensor = ctx->InputTensorDesc("like", 0); + const auto& axis = ctx->Attr>("axis"); + if (axis.empty()) { CHECK_EQ_OR_RETURN(x_tensor.shape(), like_tensor.shape()); } + user_op::TensorDesc* y_tensor = ctx->OutputTensorDesc("y", 0); + *y_tensor->mut_shape() = like_tensor.shape(); + *y_tensor->mut_is_dynamic() = like_tensor.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe ReduceSumLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ReduceSumLikeOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& like_tensor = ctx->InputTensorDesc("like", 0); + CHECK_EQ_OR_RETURN(x_tensor.data_type(), like_tensor.data_type()); + *ctx->OutputDType("y", 0) = like_tensor.data_type(); + return Maybe::Ok(); +} +/*static*/ Maybe ReduceSumLikeOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* like_arg_modifier = GetInputArgModifierFn("like", 0); + CHECK_OR_RETURN(like_arg_modifier != nullptr); + like_arg_modifier->set_requires_grad(false); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/reduce_ops.cpp b/oneflow/user/ops/reduce_ops.cpp index 68abea62e7d..610c01a5aaf 100644 --- a/oneflow/user/ops/reduce_ops.cpp +++ b/oneflow/user/ops/reduce_ops.cpp @@ -16,6 +16,7 @@ limitations under the License. 
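// NOTE(illustrative): worked example for ReduceSumLikeOp::GetSbp above. With x of shape
// (N, C, H, W), axis = {2, 3} and like of shape (N, C) (so keep_dims == false), the
// signatures produced are:
//   axis 0: (x: S(0), like: S(0)) -> y: S(0)
//   axis 1: (x: S(1), like: S(1)) -> y: S(1)            // out_split_axis = i - num_reduced_axes
//   axis 2, 3 (reduced): (x: S(i), like: B) -> y: P  and  (x: S(i), like: P) -> y: P
// plus the trailing (x: B, like: P) -> y: B signature.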
#include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/reduce_sbp_util.h" #include "oneflow/core/ndarray/binary_func.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -83,32 +84,27 @@ Maybe GetSbpFn(user_op::SbpContext* ctx) { return Maybe::Ok(); } -#define REGISTER_REDUCE_USER_OP(op_name, binary_func) \ - REGISTER_USER_OP(op_name) \ - .Input("input_tensor") \ - .Output("output_tensor") \ - .Attr>("axis") \ - .Attr("keepdims") \ - .SetTensorDescInferFn(InferTensorDescFn) \ - .SetGetSbpFn(GetSbpFn) \ - .SetDataTypeInferFn(InferDataType); - -#define REGISTER_REDUCE_LOGICAL_USER_OP(op_name, binary_func) \ - REGISTER_USER_OP(op_name) \ - .Input("input_tensor") \ - .Output("output_tensor") \ - .Attr>("axis") \ - .Attr("keepdims") \ - .SetTensorDescInferFn(InferTensorDescFn) \ - .SetGetSbpFn(GetSbpFn) \ - .SetDataTypeInferFn(InferLogicalDataType); - -REGISTER_REDUCE_LOGICAL_USER_OP("reduce_any", BinaryFuncAny) -REGISTER_REDUCE_LOGICAL_USER_OP("reduce_all", BinaryFuncAll) -REGISTER_REDUCE_USER_OP("reduce_min", BinaryFuncMin) -REGISTER_REDUCE_USER_OP("reduce_prod", BinaryFuncProd) -REGISTER_REDUCE_USER_OP("reduce_sum", BinaryFuncSum) -REGISTER_REDUCE_USER_OP("reduce_max", BinaryFuncMax) +#define IMPLEMENT_REDUCE_OP_FUNCS(name, binary_func, infer_dtype_func) \ + /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { \ + return GetSbpFn(ctx); \ + } \ + /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return InferTensorDescFn(ctx); \ + } \ + /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ + return infer_dtype_func(ctx); \ + } + +IMPLEMENT_REDUCE_OP_FUNCS(ReduceAny, BinaryFuncAny, InferLogicalDataType) +IMPLEMENT_REDUCE_OP_FUNCS(ReduceAll, BinaryFuncAll, InferLogicalDataType) +IMPLEMENT_REDUCE_OP_FUNCS(ReduceMin, BinaryFuncMin, oneflow::InferDataType) +IMPLEMENT_REDUCE_OP_FUNCS(ReduceMax, BinaryFuncMax, oneflow::InferDataType) +IMPLEMENT_REDUCE_OP_FUNCS(ReduceSum, BinaryFuncSum, oneflow::InferDataType) +IMPLEMENT_REDUCE_OP_FUNCS(ReduceProd, BinaryFuncProd, oneflow::InferDataType) +#undef IMPLEMENT_REDUCE_OP_FUNCS REGISTER_USER_OP_GRAD("reduce_sum") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/relu_op.cpp b/oneflow/user/ops/relu_op.cpp index d2ae5b6bf23..52fb55fdc22 100644 --- a/oneflow/user/ops/relu_op.cpp +++ b/oneflow/user/ops/relu_op.cpp @@ -14,76 +14,73 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe ReluOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe ReluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("x", 0); + Shape* out_shape = ctx->OutputShape("y", 0); + *out_shape = in_shape; + return Maybe::Ok(); +} +/*static*/ Maybe ReluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ReluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("relu") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ReluGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); + FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("y", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe ReluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& y_shape = ctx->InputShape("y", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == y_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} +/*static*/ Maybe ReluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ReluGradOp::InferDataType(user_op::InferContext* ctx) { + const DataType& data_type = ctx->InputDType("y", 0); + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), data_type); + *ctx->OutputDType("dx", 0) = data_type; + return Maybe::Ok(); +} -REGISTER_USER_OP("relu_grad") - .Input("y") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& y_shape = ctx->InputShape("y", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == y_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); - FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - 
.Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const DataType& data_type = ctx->InputDType("y", 0); - CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), data_type); - *ctx->OutputDType("dx", 0) = data_type; - return Maybe::Ok(); - }); +namespace { REGISTER_USER_OP_GRAD("relu").SetBackwardOpConfGenFn( [](user_op::BackwardOpConfContext* ctx) -> Maybe { const auto relu_grad_op_name = ctx->FwOp().op_name() + "_grad"; ctx->DefineOp(relu_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { return builder.OpTypeName("relu_grad") - .InputBind("y", ctx->FwOp().output("out", 0)) - .InputBind("dy", ctx->FwOp().output_grad("out", 0)) + .InputBind("y", ctx->FwOp().output("y", 0)) + .InputBind("dy", ctx->FwOp().output_grad("y", 0)) .Output("dx") .Build(); }); - ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), + ctx->FwOp().InputGradBind(user_op::OpArg("x", 0), [&ctx, &relu_grad_op_name]() -> const std::string& { return ctx->GetOp(relu_grad_op_name).output("dx", 0); }); diff --git a/oneflow/user/ops/repeat_op.cpp b/oneflow/user/ops/repeat_op.cpp index 4098ac9d00e..2b087308603 100644 --- a/oneflow/user/ops/repeat_op.cpp +++ b/oneflow/user/ops/repeat_op.cpp @@ -14,45 +14,42 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe RepeatOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe RepeatOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe RepeatOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe RepeatOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe RepeatOp::InferOutputBlobTimeShape( + user_op::InferOutputBlobTimeShapeFnContext* ctx) { + DimVector dim_vec(ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec()); + dim_vec.emplace_back(ctx->user_op_conf().attr("repeat_num")); + *ctx->mut_output_blob_time_shape() = Shape(dim_vec); + return Maybe::Ok(); +} -REGISTER_USER_OP("repeat") - .Input("in") - .Output("out") - .Attr("repeat_num") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetOutputBlobTimeShapeInferFn( - 
[](user_op::InferOutputBlobTimeShapeFnContext* ctx) -> Maybe { - DimVector dim_vec(ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec()); - dim_vec.emplace_back(ctx->user_op_conf().attr("repeat_num")); - *ctx->mut_output_blob_time_shape() = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +namespace { REGISTER_USER_OP_GRAD("repeat").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/reshape_like_op.cpp b/oneflow/user/ops/reshape_like_op.cpp index af758ec3414..3c6b3d720fa 100644 --- a/oneflow/user/ops/reshape_like_op.cpp +++ b/oneflow/user/ops/reshape_like_op.cpp @@ -15,60 +15,53 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/reshape_user_op_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe InferNdSbpFn(user_op::InferNdSbpFnContext* ctx) { +/*static*/ Maybe ReshapeLikeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { const Shape& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); const Shape& out_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape(); return ReshapeUserOpUtil::InferNdSbp(ctx, in_shape, out_shape); } - -} // namespace - -REGISTER_USER_OP("reshape_like") - .Input("in") - .Input("like") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - const Shape& like_shape = ctx->InputShape("like", 0); - CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), like_shape.elem_cnt()); - *ctx->OutputShape("out", 0) = like_shape; - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); - CHECK_NOTNULL_OR_RETURN(like_modifier); - like_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); - const auto& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("like", 0)) - .Broadcast(user_op::OpArg("in", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - user_op::UserOpSbpSignatureBuilder builder = ctx->NewBuilder(); - return ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures(in_shape, like_shape, {{"in", 0}}, - {{"like", 0}, {"out", 0}}, - ctx->parallel_num(), &builder); - }) - .SetNdSbpInferFn(InferNdSbpFn) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ReshapeLikeOp::GetSbp(user_op::SbpContext* ctx) { + const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); + const auto& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("like", 0)) + .Broadcast(user_op::OpArg("in", 0)) + .Broadcast(user_op::OpArg("out", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("like", 0)) + .PartialSum(user_op::OpArg("in", 0)) + 
.PartialSum(user_op::OpArg("out", 0)) + .Build(); + user_op::UserOpSbpSignatureBuilder builder = ctx->NewBuilder(); + return ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures( + in_shape, like_shape, {{"in", 0}}, {{"like", 0}, {"out", 0}}, ctx->parallel_num(), &builder); +} +/*static*/ Maybe ReshapeLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + const Shape& like_shape = ctx->InputShape("like", 0); + CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), like_shape.elem_cnt()); + *ctx->OutputShape("out", 0) = like_shape; + return Maybe::Ok(); +} +/*static*/ Maybe ReshapeLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ReshapeLikeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe ReshapeLikeOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); + CHECK_NOTNULL_OR_RETURN(like_modifier); + like_modifier->set_requires_grad(false); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("reshape_like") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/reshape_op.cpp b/oneflow/user/ops/reshape_op.cpp index 2be0133bdda..4a69084129e 100644 --- a/oneflow/user/ops/reshape_op.cpp +++ b/oneflow/user/ops/reshape_op.cpp @@ -17,12 +17,11 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/reshape_user_op_util.h" #include "oneflow/core/operator/operator.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe GetSbpFn(user_op::SbpContext* ctx) { +/*static*/ Maybe ReshapeOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); const Shape& shape = ctx->Attr("shape"); const auto& outshape = JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape)); @@ -31,14 +30,14 @@ Maybe GetSbpFn(user_op::SbpContext* ctx) { in_shape, *outshape, {{"in", 0}}, {{"out", 0}}, ctx->parallel_num(), &builder); } -Maybe InferNdSbpFn(user_op::InferNdSbpFnContext* ctx) { +/*static*/ Maybe ReshapeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { const Shape& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); const Shape& shape = ctx->user_op_conf().attr("shape"); const auto& out_shape = JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape)); return ReshapeUserOpUtil::InferNdSbp(ctx, in_shape, *out_shape); } -Maybe LogicalTensorDescInferFn(user_op::InferContext* ctx) { +/*static*/ Maybe ReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { Shape shape = ctx->Attr("shape"); const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("out", 0); @@ -70,7 +69,7 @@ Maybe LogicalTensorDescInferFn(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe TensorDescInferFn(user_op::InferContext* ctx) { +/*static*/ Maybe ReshapeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { Shape logical_shape = ctx->Attr("shape"); const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("out", 0); @@ -115,20 +114,12 @@ Maybe TensorDescInferFn(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe 
InferDataType(user_op::InferContext* ctx) { +/*static*/ Maybe ReshapeOp::InferDataType(user_op::InferContext* ctx) { *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe::Ok(); } -REGISTER_USER_OP("reshape") - .Input("in") - .Output("out") - .Attr("shape") - .SetLogicalTensorDescInferFn(LogicalTensorDescInferFn) - .SetPhysicalTensorDescInferFn(TensorDescInferFn) - .SetGetSbpFn(GetSbpFn) - .SetNdSbpInferFn(InferNdSbpFn) - .SetDataTypeInferFn(InferDataType); +namespace { REGISTER_USER_OP_GRAD("reshape").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/roi_align_op.cpp b/oneflow/user/ops/roi_align_op.cpp index 58ff83e6419..15568d7a672 100644 --- a/oneflow/user/ops/roi_align_op.cpp +++ b/oneflow/user/ops/roi_align_op.cpp @@ -14,12 +14,19 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe InferRoiAlignTensorDesc(user_op::InferContext* ctx) { +/*static*/ Maybe RoiAlignOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), 0) + .Split(user_op::OpArg("rois", 0), 0) + .Split(user_op::OpArg("y", 0), 0) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe RoiAlignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& rois_shape = ctx->InputShape("rois", 0); const int32_t pooled_h = ctx->Attr("pooled_h"); @@ -33,8 +40,34 @@ Maybe InferRoiAlignTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("y", 0) = Shape({rois_shape.At(0), x_shape.At(1), pooled_h, pooled_w}); return Maybe::Ok(); } +/*static*/ Maybe RoiAlignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe RoiAlignOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe RoiAlignOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* roi_modifier = GetInputArgModifierFn("rois", 0); + CHECK(roi_modifier != nullptr); + roi_modifier->set_requires_grad(false); + user_op::InputArgModifier* feat_modifier = GetInputArgModifierFn("x", 0); + CHECK(feat_modifier != nullptr); + feat_modifier->set_requires_grad(true); + return Maybe::Ok(); +} -Maybe InferRoiAlignGradTensorDesc(user_op::InferContext* ctx) { +/*static*/ Maybe RoiAlignGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 0) + .Split(user_op::OpArg("x_like", 0), 0) + .Split(user_op::OpArg("rois", 0), 0) + .Split(user_op::OpArg("dx", 0), 0) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe RoiAlignGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); const Shape& x_like_shape = ctx->InputShape("x_like", 0); const Shape& rois_shape = ctx->InputShape("rois", 0); @@ -51,47 +84,16 @@ Maybe InferRoiAlignGradTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("dx", 0) = x_like_shape; return Maybe::Ok(); } - -Maybe InferRoiAlignDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); +/*static*/ Maybe RoiAlignGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); } - -Maybe 
InferRoiAlignGradDataType(user_op::InferContext* ctx) { +/*static*/ Maybe RoiAlignGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x_like", 0)); *ctx->OutputDType("dx", 0) = ctx->InputDType("x_like", 0); return Maybe::Ok(); } -Maybe RoiAlignSbpFn(user_op::SbpContext* ctx) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), 0) - .Split(user_op::OpArg("rois", 0), 0) - .Split(user_op::OpArg("y", 0), 0) - .Build(); - return Maybe::Ok(); -} - -Maybe RoiAlignGradSbpFn(user_op::SbpContext* ctx) { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("x_like", 0), 0) - .Split(user_op::OpArg("rois", 0), 0) - .Split(user_op::OpArg("dx", 0), 0) - .Build(); - return Maybe::Ok(); -} - -Maybe RoiAlignArgModifier(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { - user_op::InputArgModifier* roi_modifier = GetInputArgModifierFn("rois", 0); - CHECK(roi_modifier != nullptr); - roi_modifier->set_requires_grad(false); - user_op::InputArgModifier* feat_modifier = GetInputArgModifierFn("x", 0); - CHECK(feat_modifier != nullptr); - feat_modifier->set_requires_grad(true); - return Maybe::Ok(); -} +namespace { Maybe GenerateBackwardOpConf4RoiAlign(const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) { @@ -117,34 +119,6 @@ Maybe GenerateBackwardOpConf4RoiAlign(const user_op::UserOpWrapper& op, } // namespace -REGISTER_USER_OP("roi_align") - .Input("x") - .Input("rois") - .Output("y") - .Attr("pooled_h") - .Attr("pooled_w") - .Attr("spatial_scale") - .Attr("sampling_ratio") - .Attr("aligned") - .SetTensorDescInferFn(InferRoiAlignTensorDesc) - .SetDataTypeInferFn(InferRoiAlignDataType) - .SetGetSbpFn(RoiAlignSbpFn) - .SetInputArgModifyFn(RoiAlignArgModifier); - -REGISTER_USER_OP("roi_align_grad") - .Input("dy") - .Input("x_like") - .Input("rois") - .Output("dx") - .Attr("pooled_h") - .Attr("pooled_w") - .Attr("spatial_scale") - .Attr("sampling_ratio") - .Attr("aligned") - .SetTensorDescInferFn(InferRoiAlignGradTensorDesc) - .SetDataTypeInferFn(InferRoiAlignGradDataType) - .SetGetSbpFn(RoiAlignGradSbpFn); - REGISTER_USER_OP_GRAD("roi_align").SetGenBackwardOpConfFn(GenerateBackwardOpConf4RoiAlign); } // namespace oneflow diff --git a/oneflow/user/ops/roll_op.cpp b/oneflow/user/ops/roll_op.cpp index cf86f20397c..b07077d814b 100644 --- a/oneflow/user/ops/roll_op.cpp +++ b/oneflow/user/ops/roll_op.cpp @@ -14,48 +14,47 @@ See the License for the specific language governing permissions and limitations under the License. 
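// NOTE(illustrative): worked example for RoiAlignOp::InferLogicalTensorDesc above. With
// x of shape (8, 256, 56, 56), rois of shape (100, 5) (the usual (num_rois, 5) layout is an
// assumption here) and pooled_h = pooled_w = 7, the inferred y shape is
// (rois.At(0), x.At(1), pooled_h, pooled_w) = (100, 256, 7, 7); the backward op then copies
// x_like's shape to dx.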
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("roll") - .Input("in") - .Output("out") - .Attr>("shifts") - .Attr>("dims") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - *ctx->OutputShape("out", 0) = in_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - const std::vector& dims = ctx->Attr>("dims"); +/*static*/ Maybe RollOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + const std::vector& dims = ctx->Attr>("dims"); - CHECK_GT_OR_RETURN(dims.size(), 0); + CHECK_GT_OR_RETURN(dims.size(), 0); - // NOTE(Liang Depeng): (dims.size == 1 && dims[0] == -1) means that user call flow.roll with - // dims == None - if (dims[0] != -1) { - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - if (std::find(dims.begin(), dims.end(), i) == dims.end()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - } + // NOTE(Liang Depeng): (dims.size == 1 && dims[0] == -1) means that user call flow.roll with + // dims == None + if (dims[0] != -1) { + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + if (std::find(dims.begin(), dims.end(), i) == dims.end()) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); } + } + } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe RollOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + *ctx->OutputShape("out", 0) = in_shape; + return Maybe::Ok(); +} +/*static*/ Maybe RollOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe RollOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("roll").SetGenBackwardOpConfFn( [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe { diff --git a/oneflow/user/ops/same_padding_op.cpp b/oneflow/user/ops/same_padding_op.cpp index b54d705df65..e643232ba66 100644 --- a/oneflow/user/ops/same_padding_op.cpp +++ b/oneflow/user/ops/same_padding_op.cpp @@ -16,14 +16,26 @@ limitations under the License. 
#include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/user/ops/nn_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace user_op { -namespace { -Maybe SamePaddingTensorDescInferFn(user_op::InferContext* ctx) { - const TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); +/*static*/ Maybe SamePaddingOp::GetSbp(user_op::SbpContext* ctx) { + const int32_t num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("x_like", 0).shape().NumAxes(); + const std::string& data_format = ctx->Attr("data_format"); + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + const int32_t channel_idx = ChannelIdx(data_format, num_axes); + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), channel_idx) + .Split(user_op::OpArg("y", 0), channel_idx) + .Build(); + ctx->NewBuilder().PartialSum(user_op::OpArg("x", 0)).PartialSum(user_op::OpArg("y", 0)).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe SamePaddingOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); *y_desc->mut_shape() = x_desc.shape(); *y_desc->mut_is_dynamic() = x_desc.is_dynamic(); const std::string& data_format = ctx->Attr("data_format"); @@ -46,88 +58,58 @@ Maybe SamePaddingTensorDescInferFn(user_op::InferContext* ctx) { *y_desc->mut_shape() = Shape(y_dim_vec); return Maybe::Ok(); } -} // namespace - -REGISTER_USER_OP("same_padding") - .Input("x") - .Output("y") - .Attr("padding") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("dilation_rate") - .SetTensorDescInferFn(SamePaddingTensorDescInferFn) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int32_t num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("x_like", 0).shape().NumAxes(); - const std::string& data_format = ctx->Attr("data_format"); - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - const int32_t channel_idx = ChannelIdx(data_format, num_axes); - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), channel_idx) - .Split(user_op::OpArg("y", 0), channel_idx) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SamePaddingOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SamePaddingOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("same_padding_grad") - .Input("x_like") - .Input("dy") - .Output("dx") - .Attr("padding") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("dilation_rate") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0); - *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int32_t num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("x_like", 0).shape().NumAxes(); - const std::string& data_format = ctx->Attr("data_format"); - 
ctx->NewBuilder() - .Split(user_op::OpArg("x_like", 0), 0) - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("dx", 0), 0) - .Build(); - const int32_t channel_idx = ChannelIdx(data_format, num_axes); - ctx->NewBuilder() - .Split(user_op::OpArg("x_like", 0), channel_idx) - .Split(user_op::OpArg("dy", 0), channel_idx) - .Split(user_op::OpArg("dx", 0), channel_idx) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x_like", 0)) - .PartialSum(user_op::OpArg("dy", 0)) - .PartialSum(user_op::OpArg("dx", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("x_like", 0)) - .PartialSum(user_op::OpArg("dy", 0)) - .PartialSum(user_op::OpArg("dx", 0)) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x_like", 0)) - .Broadcast(user_op::OpArg("dy", 0)) - .Broadcast(user_op::OpArg("dx", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x_like", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SamePaddingGradOp::GetSbp(user_op::SbpContext* ctx) { + const int32_t num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("x_like", 0).shape().NumAxes(); + const std::string& data_format = ctx->Attr("data_format"); + ctx->NewBuilder() + .Split(user_op::OpArg("x_like", 0), 0) + .Split(user_op::OpArg("dy", 0), 0) + .Split(user_op::OpArg("dx", 0), 0) + .Build(); + const int32_t channel_idx = ChannelIdx(data_format, num_axes); + ctx->NewBuilder() + .Split(user_op::OpArg("x_like", 0), channel_idx) + .Split(user_op::OpArg("dy", 0), channel_idx) + .Split(user_op::OpArg("dx", 0), channel_idx) + .Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("x_like", 0)) + .PartialSum(user_op::OpArg("dy", 0)) + .PartialSum(user_op::OpArg("dx", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x_like", 0)) + .PartialSum(user_op::OpArg("dy", 0)) + .PartialSum(user_op::OpArg("dx", 0)) + .Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("x_like", 0)) + .Broadcast(user_op::OpArg("dy", 0)) + .Broadcast(user_op::OpArg("dx", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe SamePaddingGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0); + *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SamePaddingGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SamePaddingGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("x_like", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("same_padding") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -156,5 +138,4 @@ REGISTER_USER_OP_GRAD("same_padding") return Maybe::Ok(); }); -} // namespace user_op } // namespace oneflow diff --git a/oneflow/user/ops/scalar_by_tensor_op.cpp b/oneflow/user/ops/scalar_by_tensor_op.cpp index 0ec8c0adfe4..f5420517a67 100644 --- a/oneflow/user/ops/scalar_by_tensor_op.cpp +++ b/oneflow/user/ops/scalar_by_tensor_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -61,20 +62,82 @@ GetSbpFn MakeGetSbpFn(GetSbpFn extra) { } // namespace -REGISTER_USER_OP("scalar_add_by_tensor") - .Input("x") - .Input("scalar") - .Output("y") - .SetTensorDescInferFn(TensorDescInferFn) - .SetDataTypeInferFn(DataTypeInferFn) - .SetGetSbpFn(MakeGetSbpFn([](user_op::SbpContext* ctx) { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("scalar", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - return Maybe::Ok(); - })); +/*static*/ Maybe ScalarAddByTensorOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .PartialSum(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("scalar", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe ScalarAddByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return TensorDescInferFn(ctx); +} +/*static*/ Maybe ScalarAddByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ScalarAddByTensorOp::InferDataType(user_op::InferContext* ctx) { + return DataTypeInferFn(ctx); +} + +/*static*/ Maybe ScalarSubByTensorOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .PartialSum(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("scalar", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe ScalarSubByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return TensorDescInferFn(ctx); +} +/*static*/ Maybe ScalarSubByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ScalarSubByTensorOp::InferDataType(user_op::InferContext* ctx) { + return DataTypeInferFn(ctx); +} + +/*static*/ Maybe ScalarMulByTensorOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .PartialSum(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("scalar", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("scalar", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe ScalarMulByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return TensorDescInferFn(ctx); +} +/*static*/ Maybe ScalarMulByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ScalarMulByTensorOp::InferDataType(user_op::InferContext* ctx) { + return DataTypeInferFn(ctx); +} + +/*static*/ Maybe ScalarDivByTensorOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .PartialSum(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("scalar", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe ScalarDivByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return TensorDescInferFn(ctx); +} +/*static*/ Maybe ScalarDivByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ScalarDivByTensorOp::InferDataType(user_op::InferContext* ctx) { + return DataTypeInferFn(ctx); +} REGISTER_USER_OP_GRAD("scalar_add_by_tensor") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -99,21 +162,6 @@ REGISTER_USER_OP_GRAD("scalar_add_by_tensor") return Maybe::Ok(); }); -REGISTER_USER_OP("scalar_sub_by_tensor") - .Input("x") - .Input("scalar") - 
.Output("y") - .SetTensorDescInferFn(TensorDescInferFn) - .SetDataTypeInferFn(DataTypeInferFn) - .SetGetSbpFn(MakeGetSbpFn([](user_op::SbpContext* ctx) { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("scalar", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - return Maybe::Ok(); - })); - REGISTER_USER_OP_GRAD("scalar_sub_by_tensor") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { @@ -148,26 +196,6 @@ REGISTER_USER_OP_GRAD("scalar_sub_by_tensor") return Maybe::Ok(); }); -REGISTER_USER_OP("scalar_mul_by_tensor") - .Input("x") - .Input("scalar") - .Output("y") - .SetTensorDescInferFn(TensorDescInferFn) - .SetDataTypeInferFn(DataTypeInferFn) - .SetGetSbpFn(MakeGetSbpFn([](user_op::SbpContext* ctx) { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("scalar", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("scalar", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - return Maybe::Ok(); - })); - REGISTER_USER_OP_GRAD("scalar_mul_by_tensor") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { @@ -208,21 +236,6 @@ REGISTER_USER_OP_GRAD("scalar_mul_by_tensor") return Maybe::Ok(); }); -REGISTER_USER_OP("scalar_div_by_tensor") - .Input("x") - .Input("scalar") - .Output("y") - .SetTensorDescInferFn(TensorDescInferFn) - .SetDataTypeInferFn(DataTypeInferFn) - .SetGetSbpFn(MakeGetSbpFn([](user_op::SbpContext* ctx) { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("scalar", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - return Maybe::Ok(); - })); - REGISTER_USER_OP_GRAD("scalar_div_by_tensor") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/scalar_logical_op.cpp b/oneflow/user/ops/scalar_logical_op.cpp index 7f33dfa66c4..7bd176790ee 100644 --- a/oneflow/user/ops/scalar_logical_op.cpp +++ b/oneflow/user/ops/scalar_logical_op.cpp @@ -14,42 +14,39 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -#define REGISTER_SCALAR_LOGICAL_OP(op_name) \ - REGISTER_NO_GRAD_USER_OP(op_name) \ - .Input("in") \ - .Output("out") \ - .Attr("has_int_operand") \ - .Attr("has_float_operand") \ - .Attr("int_operand") \ - .Attr("float_operand") \ - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { \ - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); \ - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ - return Maybe::Ok(); \ - }) \ - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { \ - const user_op::TensorDesc& in_tensor = \ - ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); \ - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { \ - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); \ - } \ - return Maybe::Ok(); \ - }) \ - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { \ - *ctx->OutputDType("out", 0) = DataType::kInt8; \ - return Maybe::Ok(); \ - }); +#define IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(name) \ + /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { \ + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); \ + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { \ + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); \ + } \ + return Maybe::Ok(); \ + } \ + /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); \ + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ + return Maybe::Ok(); \ + } \ + /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ + *ctx->OutputDType("out", 0) = DataType::kInt8; \ + return Maybe::Ok(); \ + } + +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalEqual); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalNotEqual); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalGreater); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalGreaterEqual); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalLess); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalLessEqual); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalAnd); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalOr); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalXor); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_equal"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_not_equal"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_greater"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_greater_equal"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_less"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_less_equal"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_and"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_or"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_xor"); } // namespace oneflow diff --git a/oneflow/user/ops/scalar_math_op.cpp b/oneflow/user/ops/scalar_math_op.cpp index 82950d50131..3c20827281d 100644 --- a/oneflow/user/ops/scalar_math_op.cpp +++ b/oneflow/user/ops/scalar_math_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -38,55 +39,47 @@ Maybe GetSbp4ScalarMul(user_op::SbpContext* ctx) { } // namespace -#define REGISTER_SCALAR_MATH_OP(op_name, get_sbp_fn) \ - REGISTER_USER_OP(op_name) \ - .Input("in") \ - .Output("out") \ - .Attr("has_int_operand") \ - .Attr("has_float_operand") \ - .Attr("int_operand") \ - .Attr("float_operand") \ - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { \ - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); \ - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ - return Maybe::Ok(); \ - }) \ - .SetGetSbpFn(get_sbp_fn) \ - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { \ - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); \ - return Maybe::Ok(); \ - }); +#define IMPLEMENT_SCALAR_MATH_OP_FUNCS(op_name, get_sbp_fn) \ + /*static*/ Maybe op_name##Op::GetSbp(user_op::SbpContext* ctx) { return get_sbp_fn(ctx); } \ + /*static*/ Maybe op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); \ + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ + return Maybe::Ok(); \ + } \ + /*static*/ Maybe op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe op_name##Op::InferDataType(user_op::InferContext* ctx) { \ + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); \ + return Maybe::Ok(); \ + } -REGISTER_SCALAR_MATH_OP("scalar_add", GetSbp4ScalarMath) -REGISTER_SCALAR_MATH_OP("scalar_floordiv", GetSbp4ScalarMath) -REGISTER_SCALAR_MATH_OP("scalar_fmod", GetSbp4ScalarMath) -REGISTER_SCALAR_MATH_OP("scalar_mul", GetSbp4ScalarMul) -REGISTER_SCALAR_MATH_OP("scalar_pow", GetSbp4ScalarMath) +IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarAdd, GetSbp4ScalarMath) +IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarFloordiv, GetSbp4ScalarMath) +IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarFmod, GetSbp4ScalarMath) +IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarMul, GetSbp4ScalarMul) +IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarPow, GetSbp4ScalarMath) +#undef IMPLEMENT_SCALAR_MATH_OP_FUNCS -REGISTER_USER_OP("scalar_pow_grad") - .Input("x") - .Input("dy") - .Attr("has_int_operand") - .Attr("has_float_operand") - .Attr("int_operand") - .Attr("float_operand") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ScalarPowGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe ScalarPowGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + return 
Maybe::Ok(); +} +/*static*/ Maybe ScalarPowGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ScalarPowGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("scalar_add") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/selu_op.cpp b/oneflow/user/ops/selu_op.cpp index 8697cbf39a4..8ed852eb395 100644 --- a/oneflow/user/ops/selu_op.cpp +++ b/oneflow/user/ops/selu_op.cpp @@ -14,61 +14,58 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe SeluOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SeluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("selu") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SeluGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} +/*static*/ Maybe SeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SeluGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("selu_grad") - .Input("x") - .Input("dy") - .Output("dx") - 
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +namespace { REGISTER_USER_OP_GRAD("selu").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { const auto selu_grad_op_name = ctx->FwOp().op_name() + "_grad"; diff --git a/oneflow/user/ops/sigmoid_cross_entropy_op.cpp b/oneflow/user/ops/sigmoid_cross_entropy_op.cpp index cdf7e78870b..1928b5ab5c2 100644 --- a/oneflow/user/ops/sigmoid_cross_entropy_op.cpp +++ b/oneflow/user/ops/sigmoid_cross_entropy_op.cpp @@ -14,84 +14,84 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("sigmoid_cross_entropy") - .Input("prediction") - .Input("label") - .Output("loss") - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); - cond_arg_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); - const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); - CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape()); - user_op::TensorDesc* loss_desc = ctx->OutputTensorDesc("loss", 0); - *loss_desc->mut_shape() = prediction_desc.shape(); - *loss_desc->mut_is_dynamic() = prediction_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto num_out_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0).shape().NumAxes(); - FOR_RANGE(int64_t, i, 0, num_out_axes) { - ctx->NewBuilder() - .Split(user_op::OpArg("prediction", 0), i) - .Split(user_op::OpArg("label", 0), i) - .Split(user_op::OpArg("loss", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("loss", 0) = ctx->InputDType("prediction", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SigmoidCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) { + const auto num_out_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0).shape().NumAxes(); + FOR_RANGE(int64_t, i, 0, num_out_axes) { + ctx->NewBuilder() + .Split(user_op::OpArg("prediction", 0), i) + .Split(user_op::OpArg("label", 0), i) + .Split(user_op::OpArg("loss", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& prediction_desc = 
ctx->InputTensorDesc("prediction", 0); + const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); + CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape()); + user_op::TensorDesc* loss_desc = ctx->OutputTensorDesc("loss", 0); + *loss_desc->mut_shape() = prediction_desc.shape(); + *loss_desc->mut_is_dynamic() = prediction_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SigmoidCrossEntropyOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("loss", 0) = ctx->InputDType("prediction", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidCrossEntropyOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); + cond_arg_modifier->set_requires_grad(false); + return Maybe::Ok(); +} -REGISTER_USER_OP("sigmoid_cross_entropy_grad") - .Input("prediction") - .Input("loss_diff") - .Input("label") - .Output("prediction_diff") - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); - cond_arg_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); - const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); - const user_op::TensorDesc& loss_diff_desc = ctx->InputTensorDesc("loss_diff", 0); - CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape()); - CHECK_EQ_OR_RETURN(loss_diff_desc.shape(), prediction_desc.shape()); - user_op::TensorDesc* prediction_diff = ctx->OutputTensorDesc("prediction_diff", 0); - *prediction_diff->mut_shape() = prediction_desc.shape(); - *prediction_diff->mut_is_dynamic() = prediction_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto num_dy_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("loss_diff", 0).shape().NumAxes(); - FOR_RANGE(int64_t, i, 0, num_dy_axes) { - ctx->NewBuilder() - .Split(user_op::OpArg("loss_diff", 0), i) - .Split(user_op::OpArg("label", 0), i) - .Split(user_op::OpArg("prediction", 0), i) - .Split(user_op::OpArg("prediction_diff", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("prediction_diff", 0) = ctx->InputDType("prediction", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SigmoidCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) { + const auto num_dy_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("loss_diff", 0).shape().NumAxes(); + FOR_RANGE(int64_t, i, 0, num_dy_axes) { + ctx->NewBuilder() + .Split(user_op::OpArg("loss_diff", 0), i) + .Split(user_op::OpArg("label", 0), i) + .Split(user_op::OpArg("prediction", 0), i) + .Split(user_op::OpArg("prediction_diff", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidCrossEntropyGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); + const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); + const user_op::TensorDesc& loss_diff_desc = ctx->InputTensorDesc("loss_diff", 
0); + CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape()); + CHECK_EQ_OR_RETURN(loss_diff_desc.shape(), prediction_desc.shape()); + user_op::TensorDesc* prediction_diff = ctx->OutputTensorDesc("prediction_diff", 0); + *prediction_diff->mut_shape() = prediction_desc.shape(); + *prediction_diff->mut_is_dynamic() = prediction_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidCrossEntropyGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SigmoidCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("prediction_diff", 0) = ctx->InputDType("prediction", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidCrossEntropyGradOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); + cond_arg_modifier->set_requires_grad(false); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("sigmoid_cross_entropy") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/sigmoid_op.cpp b/oneflow/user/ops/sigmoid_op.cpp index 3af60af6440..f45506bc723 100644 --- a/oneflow/user/ops/sigmoid_op.cpp +++ b/oneflow/user/ops/sigmoid_op.cpp @@ -14,63 +14,60 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe SigmoidOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = in_shape; + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SigmoidOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("sigmoid") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SigmoidGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); + FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("y", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) 
+ .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& y_shape = ctx->InputShape("y", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == y_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SigmoidGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("y", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("sigmoid_grad") - .Input("y") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& y_shape = ctx->InputShape("y", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == y_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); - FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("y", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); - return Maybe::Ok(); - }); +namespace { REGISTER_USER_OP_GRAD("sigmoid").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/silu_op.cpp b/oneflow/user/ops/silu_op.cpp index eb46ab7b406..59c9831bf29 100644 --- a/oneflow/user/ops/silu_op.cpp +++ b/oneflow/user/ops/silu_op.cpp @@ -14,61 +14,58 @@ See the License for the specific language governing permissions and limitations under the License. 
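The REGISTER_USER_OP_GRAD("sigmoid") block kept at the bottom of the file still uses the SetGenBackwardOpConfFn style: when the forward input needs a gradient, build a sigmoid_grad op fed by the forward output and the incoming gradient, then bind its dx back to the forward input. Roughly (a sketch of the usual wiring, not the verbatim lambda from this file):

Maybe<void> GenSigmoidGradOpSketch(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
  if (op.NeedGenGradTensor4OpInput("in", 0)) {
    user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
    user_op::UserOpConfWrapper grad_op =
        builder.Op("sigmoid_grad")
            .Input("y", op.output("out", 0))                     // forward output
            .Input("dy", op.GetGradTensorWithOpOutput("out", 0))  // incoming gradient
            .Output("dx")
            .Build();
    op.BindGradTensorWithOpInput(grad_op.output("dx", 0), "in", 0);  // dx becomes d(in)
    AddOp(grad_op);
  }
  return Maybe<void>::Ok();
}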
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe SiluOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SiluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SiluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SiluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("silu") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SiluGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SiluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} +/*static*/ Maybe SiluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SiluGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("silu_grad") - .Input("x") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); - 
*ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +namespace { REGISTER_USER_OP_GRAD("silu").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { const auto silu_grad_op_name = ctx->FwOp().op_name() + "_grad"; diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp index 3d5cc31e2b0..473927db358 100644 --- a/oneflow/user/ops/slice_op.cpp +++ b/oneflow/user/ops/slice_op.cpp @@ -15,19 +15,39 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/slice_util.h" +#include "oneflow/core/framework/op_generated.h" +#include "oneflow/core/operator/operator.h" namespace oneflow { namespace { - bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) { if (step != 1) { return false; } if (start != 0) { return false; } if (stop != size) { return false; } return true; } +} // namespace + +/*static*/ Maybe SliceOp::GetSbp(user_op::SbpContext* ctx) { + const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); + const int64_t ndim = x_shape.NumAxes(); + const auto& start_vec = ctx->Attr>("start"); + const auto& stop_vec = ctx->Attr>("stop"); + const auto& step_vec = ctx->Attr>("step"); + CHECK_EQ_OR_RETURN(start_vec.size(), ndim); + CHECK_EQ_OR_RETURN(stop_vec.size(), ndim); + CHECK_EQ_OR_RETURN(step_vec.size(), ndim); -Maybe InferSliceOpTensorDesc(user_op::InferContext* ctx) { + FOR_RANGE(int, i, 0, ndim) { + if (IsFullSlice(start_vec.at(i), stop_vec.at(i), step_vec.at(i), x_shape.At(i))) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + } + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe SliceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const int64_t ndim = x_shape.NumAxes(); const auto& start_vec = ctx->Attr>("start"); @@ -63,15 +83,17 @@ Maybe InferSliceOpTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("y", 0) = Shape(dim_vec); return Maybe::Ok(); } - -Maybe InferSliceOpDataType(user_op::InferContext* ctx) { +/*static*/ Maybe SliceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SliceOp::InferDataType(user_op::InferContext* ctx) { *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); return Maybe::Ok(); } -Maybe GetSliceOpSbpSignature(user_op::SbpContext* ctx) { - const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); - const int64_t ndim = x_shape.NumAxes(); +/*static*/ Maybe SliceGradOp::GetSbp(user_op::SbpContext* ctx) { + const Shape& like_shape = ctx->Attr("like_shape"); + const int64_t ndim = like_shape.NumAxes(); const auto& start_vec = ctx->Attr>("start"); const auto& stop_vec = ctx->Attr>("stop"); const auto& step_vec = ctx->Attr>("step"); @@ -80,16 +102,16 @@ Maybe GetSliceOpSbpSignature(user_op::SbpContext* ctx) { CHECK_EQ_OR_RETURN(step_vec.size(), ndim); FOR_RANGE(int, i, 0, ndim) { - if (IsFullSlice(start_vec.at(i), stop_vec.at(i), step_vec.at(i), x_shape.At(i))) { + if (IsFullSlice(start_vec.at(i), stop_vec.at(i), step_vec.at(i), like_shape.At(i))) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); } } - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + ctx->NewBuilder().PartialSum(user_op::OpArg("dy", 0)).PartialSum(user_op::OpArg("dx", 0)).Build(); + 
ctx->NewBuilder().Broadcast(user_op::OpArg("dy", 0)).Broadcast(user_op::OpArg("dx", 0)).Build(); return Maybe::Ok(); } - -Maybe InferSliceGradOpTensorDesc(user_op::InferContext* ctx) { - const Shape& like_shape = ctx->InputShape("like", 0); +/*static*/ Maybe SliceGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& like_shape = ctx->Attr("like_shape"); const Shape& dy_shape = ctx->InputShape("dy", 0); const auto& start_vec = ctx->Attr>("start"); const auto& stop_vec = ctx->Attr>("stop"); @@ -103,15 +125,131 @@ Maybe InferSliceGradOpTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("dx", 0) = like_shape; return Maybe::Ok(); } - -Maybe InferSliceGradDataType(user_op::InferContext* ctx) { +/*static*/ Maybe SliceGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + Shape logical_shape = ctx->Attr("like_shape"); + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); + + const auto& nd_sbp = ctx->NdSbp4ArgNameAndIndex("dx", 0); + *(dx_desc->mut_shape()) = + *JUST(GetPhysicalShape(logical_shape, nd_sbp, ctx->parallel_desc(), ctx->parallel_ctx())); + int dx_ndim = dx_desc->shape().NumAxes(); + int dy_ndim = dy_desc.shape().NumAxes(); + CHECK_EQ_OR_RETURN(dx_ndim, dy_ndim) + << "Output dimension (" << dx_ndim << ") should equal to the input dimension (" << dy_ndim + << ") for slice backward."; + return Maybe::Ok(); +} +/*static*/ Maybe SliceGradOp::InferDataType(user_op::InferContext* ctx) { *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); return Maybe::Ok(); } +/*static*/ Maybe SliceGradOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* dy_modifier = GetInputArgModifierFn("dy", 0); + CHECK_NOTNULL_OR_RETURN(dy_modifier); + dy_modifier->set_requires_grad(false); + return Maybe::Ok(); +} -Maybe GetSliceGradOpSbpSignature(user_op::SbpContext* ctx) { - const Shape& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape(); - const int64_t ndim = like_shape.NumAxes(); +/*static*/ Maybe LogicalSliceAssignOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& ref_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("ref", 0); + FOR_RANGE(int64_t, axis, 0, ref_desc.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("ref", 0), axis) + // TODO(jianhao): Support (S(n), S(n)) when axis n is not sliced + .Broadcast(user_op::OpArg("value", 0)) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("ref", 0)) + .PartialSum(user_op::OpArg("value", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe LogicalSliceAssignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0); + const auto& start_vec = ctx->Attr>("start"); + const auto& stop_vec = ctx->Attr>("stop"); + const auto& step_vec = ctx->Attr>("step"); + CHECK_OR_RETURN(!ref_desc.is_dynamic()); + FOR_RANGE(size_t, i, 0, step_vec.size()) { + const int64_t step = step_vec.at(i); + const int64_t start = start_vec.at(i); + const int64_t stop = stop_vec.at(i); + CHECK_GT_OR_RETURN(step, 0) << "logical_slice_assign step must be greater than 0"; + CHECK_GE_OR_RETURN(start, 0) << "logical_slice_assign start must be greater or equal to 0"; + CHECK_GT_OR_RETURN(stop, 0) << "logical_slice_assign stop must be greater than 0"; + CHECK_LT_OR_RETURN(start, stop) << 
"logical_slice_assign start must be less than stop"; + } + return Maybe::Ok(); +} +/*static*/ Maybe LogicalSliceAssignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe LogicalSliceAssignOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0); + const user_op::TensorDesc& value_desc = ctx->InputTensorDesc("value", 0); + CHECK_OR_RETURN(ref_desc.data_type() == value_desc.data_type()); + return Maybe::Ok(); +} + +/*static*/ Maybe LogicalSliceAssignOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* ref_modifier = GetInputArgModifierFn("ref", 0); + CHECK_OR_RETURN(ref_modifier != nullptr); + ref_modifier->set_is_mutable(true); + user_op::InputArgModifier* value_modifier = GetInputArgModifierFn("value", 0); + CHECK_OR_RETURN(value_modifier != nullptr); + value_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/*static*/ Maybe LogicalSliceOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& input_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, axis, 0, input_desc.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), axis) + // TODO(jianhao): Support S(n) -> S(n) when axis n is not sliced + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + } + ctx->NewBuilder().PartialSum(user_op::OpArg("x", 0)).PartialSum(user_op::OpArg("y", 0)).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe LogicalSliceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const int64_t ndim = x_shape.NumAxes(); + const auto& start_vec = ctx->Attr>("start"); + const auto& stop_vec = ctx->Attr>("stop"); + const auto& step_vec = ctx->Attr>("step"); + DimVector dim_vec(ndim); + FOR_RANGE(size_t, i, 0, dim_vec.size()) { + const int64_t step = step_vec.at(i); + const int64_t start = start_vec.at(i); + const int64_t stop = stop_vec.at(i); + CHECK_GT_OR_RETURN(step, 0) << "LogicalSlice step must be greater than 0"; + CHECK_GE_OR_RETURN(start, 0) << "LogicalSlice start must be greater or equal to 0"; + CHECK_GT_OR_RETURN(stop, 0) << "LogicalSlice stop must be greater than 0"; + CHECK_LT_OR_RETURN(start, stop) << "LogicalSlice start must be less than stop"; + const int64_t diff = stop - start - 1; + dim_vec[i] = diff / step + 1; + } + *ctx->OutputShape("y", 0) = Shape(dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe LogicalSliceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe LogicalSliceOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe SliceUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); + const int64_t ndim = x_shape.NumAxes(); const auto& start_vec = ctx->Attr>("start"); const auto& stop_vec = ctx->Attr>("stop"); const auto& step_vec = ctx->Attr>("step"); @@ -120,36 +258,14 @@ Maybe GetSliceGradOpSbpSignature(user_op::SbpContext* ctx) { CHECK_EQ_OR_RETURN(step_vec.size(), ndim); FOR_RANGE(int, i, 0, ndim) { - if (IsFullSlice(start_vec.at(i), stop_vec.at(i), step_vec.at(i), like_shape.At(i))) { + if (IsFullSlice(start_vec.at(i), stop_vec.at(i), step_vec.at(i), x_shape.At(i))) { ctx->NewBuilder().Split(ctx->inputs(), 
i).Split(ctx->outputs(), i).Build(); } } ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("dy", 0)) - .Broadcast(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("dx", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("dy", 0)) - .PartialSum(user_op::OpArg("like", 0)) - .Broadcast(user_op::OpArg("dx", 0)) - .Build(); - return Maybe::Ok(); -} - -Maybe InferSliceGradInputArgModifier(user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { - user_op::InputArgModifier* dy_modifier = GetInputArgModifierFn("dy", 0); - CHECK_NOTNULL_OR_RETURN(dy_modifier); - dy_modifier->set_requires_grad(false); - user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); - CHECK_NOTNULL_OR_RETURN(like_modifier); - like_modifier->set_requires_grad(false); return Maybe::Ok(); } - -Maybe InferSliceUpdateOpTensorDesc(user_op::InferContext* ctx) { +/*static*/ Maybe SliceUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& x_desc = ctx->InputTensorDesc("x", 0); const int64_t ndim = x_desc.shape().NumAxes(); const auto& update_desc = ctx->InputTensorDesc("update", 0); @@ -185,8 +301,10 @@ Maybe InferSliceUpdateOpTensorDesc(user_op::InferContext* ctx) { *y_desc->mut_is_dynamic() = x_desc.is_dynamic(); return Maybe::Ok(); } - -Maybe InferSliceUpdateOpDataType(user_op::InferContext* ctx) { +/*static*/ Maybe SliceUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SliceUpdateOp::InferDataType(user_op::InferContext* ctx) { const auto& x_desc = ctx->InputTensorDesc("x", 0); const auto& update_desc = ctx->InputTensorDesc("update", 0); CHECK_EQ_OR_RETURN(update_desc.data_type(), x_desc.data_type()); @@ -195,31 +313,15 @@ Maybe InferSliceUpdateOpDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe GetSliceUpdateOpSbpSignature(user_op::SbpContext* ctx) { - const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); - const int64_t ndim = x_shape.NumAxes(); - const auto& start_vec = ctx->Attr>("start"); - const auto& stop_vec = ctx->Attr>("stop"); - const auto& step_vec = ctx->Attr>("step"); - CHECK_EQ_OR_RETURN(start_vec.size(), ndim); - CHECK_EQ_OR_RETURN(stop_vec.size(), ndim); - CHECK_EQ_OR_RETURN(step_vec.size(), ndim); - - FOR_RANGE(int, i, 0, ndim) { - if (IsFullSlice(start_vec.at(i), stop_vec.at(i), step_vec.at(i), x_shape.At(i))) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - } - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); - return Maybe::Ok(); -} +namespace { Maybe GenSliceGradOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) { if (op.NeedGenGradTensor4OpInput("x", 0)) { + const auto& x_desc = op.TensorDesc4ArgNameAndIndex("x", 0); user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); user_op::UserOpConfWrapper grad_op = builder.Op("slice_grad") .Input("dy", op.GetGradTensorWithOpOutput("y", 0)) - .Input("like", op.input("x", 0)) + .Attr("like_shape", x_desc.shape()) .Attr("start", op.attr>("start")) .Attr("stop", op.attr>("stop")) .Attr("step", op.attr>("step")) @@ -231,98 +333,6 @@ Maybe GenSliceGradOp(const user_op::UserOpWrapper& op, user_op::AddOpFn Ad return Maybe::Ok(); } -Maybe InferLogicalSliceAssignTensorDesc(user_op::InferContext* ctx) { - const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0); - const auto& 
start_vec = ctx->Attr>("start"); - const auto& stop_vec = ctx->Attr>("stop"); - const auto& step_vec = ctx->Attr>("step"); - CHECK_OR_RETURN(!ref_desc.is_dynamic()); - FOR_RANGE(size_t, i, 0, step_vec.size()) { - const int64_t step = step_vec.at(i); - const int64_t start = start_vec.at(i); - const int64_t stop = stop_vec.at(i); - CHECK_GT_OR_RETURN(step, 0) << "logical_slice_assign step must be greater than 0"; - CHECK_GE_OR_RETURN(start, 0) << "logical_slice_assign start must be greater or equal to 0"; - CHECK_GT_OR_RETURN(stop, 0) << "logical_slice_assign stop must be greater than 0"; - CHECK_LT_OR_RETURN(start, stop) << "logical_slice_assign start must be less than stop"; - } - return Maybe::Ok(); -} - -Maybe InferLogicalSliceAssignDataType(user_op::InferContext* ctx) { - const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0); - const user_op::TensorDesc& value_desc = ctx->InputTensorDesc("value", 0); - CHECK_OR_RETURN(ref_desc.data_type() == value_desc.data_type()); - return Maybe::Ok(); -} - -Maybe GetLogicalSliceAssignSbpSignatures(user_op::SbpContext* ctx) { - const user_op::TensorDesc& ref_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("ref", 0); - FOR_RANGE(int64_t, axis, 0, ref_desc.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("ref", 0), axis) - // TODO(jianhao): Support (S(n), S(n)) when axis n is not sliced - .Broadcast(user_op::OpArg("value", 0)) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("ref", 0)) - .PartialSum(user_op::OpArg("value", 0)) - .Build(); - return Maybe::Ok(); -} - -Maybe InferLogicalSliceAssignInputArgModifier( - user_op::GetInputArgModifier GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { - user_op::InputArgModifier* ref_modifier = GetInputArgModifierFn("ref", 0); - CHECK_OR_RETURN(ref_modifier != nullptr); - ref_modifier->set_is_mutable(true); - user_op::InputArgModifier* value_modifier = GetInputArgModifierFn("value", 0); - CHECK_OR_RETURN(value_modifier != nullptr); - value_modifier->set_requires_grad(false); - return Maybe::Ok(); -} - -Maybe InferLogicalSliceTensorDesc(user_op::InferContext* ctx) { - const Shape& x_shape = ctx->InputShape("x", 0); - const int64_t ndim = x_shape.NumAxes(); - const auto& start_vec = ctx->Attr>("start"); - const auto& stop_vec = ctx->Attr>("stop"); - const auto& step_vec = ctx->Attr>("step"); - DimVector dim_vec(ndim); - FOR_RANGE(size_t, i, 0, dim_vec.size()) { - const int64_t step = step_vec.at(i); - const int64_t start = start_vec.at(i); - const int64_t stop = stop_vec.at(i); - CHECK_GT_OR_RETURN(step, 0) << "LogicalSlice step must be greater than 0"; - CHECK_GE_OR_RETURN(start, 0) << "LogicalSlice start must be greater or equal to 0"; - CHECK_GT_OR_RETURN(stop, 0) << "LogicalSlice stop must be greater than 0"; - CHECK_LT_OR_RETURN(start, stop) << "LogicalSlice start must be less than stop"; - const int64_t diff = stop - start - 1; - dim_vec[i] = diff / step + 1; - } - *ctx->OutputShape("y", 0) = Shape(dim_vec); - return Maybe::Ok(); -} - -Maybe InferLogicalSliceDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); -} - -Maybe GetLogicalSliceSbpSignatures(user_op::SbpContext* ctx) { - const user_op::TensorDesc& input_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, axis, 0, input_desc.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), axis) - // TODO(jianhao): Support S(n) -> S(n) when axis n is not sliced - 
.PartialSum(user_op::OpArg("y", 0)) - .Build(); - } - ctx->NewBuilder().PartialSum(user_op::OpArg("x", 0)).PartialSum(user_op::OpArg("y", 0)).Build(); - return Maybe::Ok(); -} - Maybe GenSliceUpdateGradOp(user_op::BackwardOpConfContext* ctx) { const std::string update_grad_op_name = ctx->FwOp().op_name() + "_update_grad"; ctx->DefineOp(update_grad_op_name, [&](user_op::BackwardOpBuilder& builder) { @@ -364,62 +374,7 @@ Maybe GenSliceUpdateGradOp(user_op::BackwardOpConfContext* ctx) { } // namespace -REGISTER_USER_OP("slice") - .Input("x") - .Output("y") - .Attr>("start") - .Attr>("stop") - .Attr>("step") - .SetTensorDescInferFn(InferSliceOpTensorDesc) - .SetDataTypeInferFn(InferSliceOpDataType) - .SetGetSbpFn(GetSliceOpSbpSignature); - -REGISTER_USER_OP("slice_grad") - .Input("dy") - .Input("like") - .Output("dx") - .Attr>("start") - .Attr>("stop") - .Attr>("step") - .SetTensorDescInferFn(InferSliceGradOpTensorDesc) - .SetDataTypeInferFn(InferSliceGradDataType) - .SetGetSbpFn(GetSliceGradOpSbpSignature) - .SetInputArgModifyFn(InferSliceGradInputArgModifier); - -REGISTER_USER_OP("logical_slice_assign") - .Input("ref") - .Input("value") - .Attr>("start") - .Attr>("stop") - .Attr>("step") - .SetTensorDescInferFn(InferLogicalSliceAssignTensorDesc) - .SetDataTypeInferFn(InferLogicalSliceAssignDataType) - .SetGetSbpFn(GetLogicalSliceAssignSbpSignatures) - .SetInputArgModifyFn(InferLogicalSliceAssignInputArgModifier); - -REGISTER_USER_OP("logical_slice") - .Input("x") - .Output("y") - .Attr>("start") - .Attr>("stop") - .Attr>("step") - .SetTensorDescInferFn(InferLogicalSliceTensorDesc) - .SetDataTypeInferFn(InferLogicalSliceDataType) - .SetGetSbpFn(GetLogicalSliceSbpSignatures); - REGISTER_USER_OP_GRAD("slice").SetGenBackwardOpConfFn(GenSliceGradOp); - -REGISTER_USER_OP("slice_update") - .Input("x") - .Input("update") - .Output("y") - .Attr>("start") - .Attr>("stop") - .Attr>("step") - .SetTensorDescInferFn(InferSliceUpdateOpTensorDesc) - .SetDataTypeInferFn(InferSliceUpdateOpDataType) - .SetGetSbpFn(GetSliceUpdateOpSbpSignature); - REGISTER_USER_OP_GRAD("slice_update").SetBackwardOpConfGenFn(GenSliceUpdateGradOp); } // namespace oneflow diff --git a/oneflow/user/ops/smooth_l1_loss_op.cpp b/oneflow/user/ops/smooth_l1_loss_op.cpp index 67c110f526e..025895cb2d7 100644 --- a/oneflow/user/ops/smooth_l1_loss_op.cpp +++ b/oneflow/user/ops/smooth_l1_loss_op.cpp @@ -15,12 +15,18 @@ limitations under the License. 
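slice_grad is the one op in this file that cannot simply forward InferPhysicalTensorDesc to the logical version: its output shape now comes from the new like_shape attribute (replacing the old like input), and that attribute holds the logical, whole-tensor shape, so each rank must narrow it to its local shard according to dx's NdSbp via the GetPhysicalShape call shown above. A self-contained toy of what that narrowing amounts to for a plain Split placement (names below are illustrative only):

#include <cassert>
#include <cstdint>
#include <vector>

// like_shape is the logical shape carried by the slice_grad attribute; for a tensor placed as
// Split(split_axis) over parallel_num ranks, each rank materializes only its slice of that axis.
std::vector<int64_t> PhysicalShapeForSplit(std::vector<int64_t> like_shape, int split_axis,
                                           int64_t parallel_num) {
  assert(like_shape[split_axis] % parallel_num == 0);  // balanced case, for simplicity
  like_shape[split_axis] /= parallel_num;
  return like_shape;  // Broadcast / PartialSum placements would keep the logical extents as-is
}

// Example: like_shape = (8, 4) and dx placed as Split(0) on 4 ranks -> each rank infers dx as (2, 4).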
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/loss_op_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe InferTensorDescFn(user_op::InferContext* ctx) { +/*static*/ Maybe SmoothL1LossOp::GetSbp(user_op::SbpContext* ctx) { + const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); + FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SmoothL1LossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()); @@ -33,8 +39,10 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { return Maybe::Ok(); } - -Maybe InferDataType(user_op::InferContext* ctx) { +/*static*/ Maybe SmoothL1LossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SmoothL1LossOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()); @@ -43,8 +51,27 @@ Maybe InferDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } +/*static*/ Maybe SmoothL1LossOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); + CHECK_OR_RETURN(target_modifier != nullptr); + target_modifier->set_requires_grad(false); + return Maybe::Ok(); +} -Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { +/*static*/ Maybe SmoothL1LossGradOp::GetSbp(user_op::SbpContext* ctx) { + const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); + FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("input", 0), i) + .Split(user_op::OpArg("target", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SmoothL1LossGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); const auto& dy_desc = ctx->InputTensorDesc("dy", 0); @@ -60,8 +87,10 @@ Maybe InferGradTensorDescFn(user_op::InferContext* ctx) { return Maybe::Ok(); } - -Maybe InferGradDataType(user_op::InferContext* ctx) { +/*static*/ Maybe SmoothL1LossGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SmoothL1LossGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()); @@ -71,50 +100,6 @@ Maybe InferGradDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } -} // namespace -REGISTER_USER_OP("smooth_l1_loss") - .Input("input") - .Input("target") - .Output("out") - .Attr("beta") - .SetTensorDescInferFn(InferTensorDescFn) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const 
user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); - CHECK_OR_RETURN(target_modifier != nullptr); - target_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn(InferDataType) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); - FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); - } - return Maybe::Ok(); - }); - -REGISTER_USER_OP("smooth_l1_loss_grad") - .Input("input") - .Input("target") - .Input("dy") - .Output("dx") - .Attr("beta") - .SetTensorDescInferFn(InferGradTensorDescFn) - .SetDataTypeInferFn(InferGradDataType) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); - FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("input", 0), i) - .Split(user_op::OpArg("target", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Build(); - } - return Maybe::Ok(); - }); - REGISTER_USER_OP_GRAD("smooth_l1_loss") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe { diff --git a/oneflow/user/ops/softmax_cross_entropy_op.cpp b/oneflow/user/ops/softmax_cross_entropy_op.cpp index 85836b93dac..aa42ab0ee40 100644 --- a/oneflow/user/ops/softmax_cross_entropy_op.cpp +++ b/oneflow/user/ops/softmax_cross_entropy_op.cpp @@ -14,104 +14,102 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("softmax_cross_entropy") - .Input("prediction") - .Input("label") - .Output("prob") //'prob' is just for compute prediction's grad, prob's grad will be ignored - .Output("out") - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); - cond_arg_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); - const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); - CHECK_EQ_OR_RETURN(prediction_desc.is_dynamic(), label_desc.is_dynamic()); - CHECK_GE_OR_RETURN(prediction_desc.shape().NumAxes(), 2); - CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape()); - const int64_t num_out_axes = prediction_desc.shape().NumAxes() - 1; - DimVector out_dim_vector; - FOR_RANGE(int64_t, i, 0, num_out_axes) { - out_dim_vector.emplace_back(prediction_desc.shape().At(i)); - } - *ctx->OutputShape("prob", 0) = ctx->InputShape("prediction", 0); - *ctx->OutputIsDynamic("prob", 0) = ctx->InputIsDynamic("prediction", 0); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_is_dynamic() = prediction_desc.is_dynamic(); - *out_desc->mut_shape() = Shape(out_dim_vector); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - // ctx->LogicalTensorDesc4InputArgNameAndIndex("out", 0) is not initialized here - const auto num_out_axes = - 
ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0).shape().NumAxes() - 1; - FOR_RANGE(int64_t, i, 0, num_out_axes) { - ctx->NewBuilder() - .Split(user_op::OpArg("prediction", 0), i) - .Split(user_op::OpArg("label", 0), i) - .Split(user_op::OpArg("prob", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); - const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); - CHECK_EQ_OR_RETURN(label_desc.data_type(), prediction_desc.data_type()); - *ctx->OutputDType("prob", 0) = ctx->InputDType("prediction", 0); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_data_type() = prediction_desc.data_type(); - return Maybe::Ok(); - }); +/*static*/ Maybe SoftmaxCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) { + // ctx->LogicalTensorDesc4InputArgNameAndIndex("out", 0) is not initialized here + const auto num_out_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0).shape().NumAxes() - 1; + FOR_RANGE(int64_t, i, 0, num_out_axes) { + ctx->NewBuilder() + .Split(user_op::OpArg("prediction", 0), i) + .Split(user_op::OpArg("label", 0), i) + .Split(user_op::OpArg("prob", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SoftmaxCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); + const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); + CHECK_EQ_OR_RETURN(prediction_desc.is_dynamic(), label_desc.is_dynamic()); + CHECK_GE_OR_RETURN(prediction_desc.shape().NumAxes(), 2); + CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape()); + const int64_t num_out_axes = prediction_desc.shape().NumAxes() - 1; + DimVector out_dim_vector; + FOR_RANGE(int64_t, i, 0, num_out_axes) { + out_dim_vector.emplace_back(prediction_desc.shape().At(i)); + } + *ctx->OutputShape("prob", 0) = ctx->InputShape("prediction", 0); + *ctx->OutputIsDynamic("prob", 0) = ctx->InputIsDynamic("prediction", 0); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_is_dynamic() = prediction_desc.is_dynamic(); + *out_desc->mut_shape() = Shape(out_dim_vector); + return Maybe::Ok(); +} +/*static*/ Maybe SoftmaxCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SoftmaxCrossEntropyOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); + const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); + CHECK_EQ_OR_RETURN(label_desc.data_type(), prediction_desc.data_type()); + *ctx->OutputDType("prob", 0) = ctx->InputDType("prediction", 0); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_data_type() = prediction_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ Maybe SoftmaxCrossEntropyOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); + cond_arg_modifier->set_requires_grad(false); + return Maybe::Ok(); +} -REGISTER_USER_OP("softmax_cross_entropy_grad") - .Input("dy") - .Input("label") - .Input("prob") - .Output("prediction_diff") - 
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc("prob", 0); - const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - CHECK_EQ_OR_RETURN(prob_desc.is_dynamic(), label_desc.is_dynamic()); - CHECK_GE_OR_RETURN(prob_desc.shape().NumAxes(), 2); - CHECK_EQ_OR_RETURN(dy_desc.shape().NumAxes(), prob_desc.shape().NumAxes() - 1); - FOR_RANGE(int64_t, i, 0, dy_desc.shape().NumAxes()) { - CHECK_EQ_OR_RETURN(dy_desc.shape().At(i), label_desc.shape().At(i)); - } - CHECK_EQ_OR_RETURN(label_desc.shape(), prob_desc.shape()); - *ctx->OutputShape("prediction_diff", 0) = ctx->InputShape("prob", 0); - *ctx->OutputIsDynamic("prediction_diff", 0) = ctx->InputIsDynamic("prob", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto num_dy_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0).shape().NumAxes(); - FOR_RANGE(int64_t, i, 0, num_dy_axes) { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("label", 0), i) - .Split(user_op::OpArg("prob", 0), i) - .Split(user_op::OpArg("prediction_diff", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc("prob", 0); - const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - CHECK_EQ_OR_RETURN(label_desc.data_type(), prob_desc.data_type()); - CHECK_EQ_OR_RETURN(dy_desc.data_type(), prob_desc.data_type()); - *ctx->OutputDType("prediction_diff", 0) = ctx->InputDType("prob", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SoftmaxCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) { + const auto num_dy_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0).shape().NumAxes(); + FOR_RANGE(int64_t, i, 0, num_dy_axes) { + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("label", 0), i) + .Split(user_op::OpArg("prob", 0), i) + .Split(user_op::OpArg("prediction_diff", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SoftmaxCrossEntropyGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc("prob", 0); + const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + CHECK_EQ_OR_RETURN(prob_desc.is_dynamic(), label_desc.is_dynamic()); + CHECK_GE_OR_RETURN(prob_desc.shape().NumAxes(), 2); + CHECK_EQ_OR_RETURN(dy_desc.shape().NumAxes(), prob_desc.shape().NumAxes() - 1); + FOR_RANGE(int64_t, i, 0, dy_desc.shape().NumAxes()) { + CHECK_EQ_OR_RETURN(dy_desc.shape().At(i), label_desc.shape().At(i)); + } + CHECK_EQ_OR_RETURN(label_desc.shape(), prob_desc.shape()); + *ctx->OutputShape("prediction_diff", 0) = ctx->InputShape("prob", 0); + *ctx->OutputIsDynamic("prediction_diff", 0) = ctx->InputIsDynamic("prob", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SoftmaxCrossEntropyGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SoftmaxCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc("prob", 0); + const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); + const 
user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + CHECK_EQ_OR_RETURN(label_desc.data_type(), prob_desc.data_type()); + CHECK_EQ_OR_RETURN(dy_desc.data_type(), prob_desc.data_type()); + *ctx->OutputDType("prediction_diff", 0) = ctx->InputDType("prob", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("softmax_cross_entropy") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/softmax_op.cpp b/oneflow/user/ops/softmax_op.cpp index e4c8ad8b730..d460508d783 100644 --- a/oneflow/user/ops/softmax_op.cpp +++ b/oneflow/user/ops/softmax_op.cpp @@ -14,61 +14,61 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe SoftmaxOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes() - 1) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), axis) + .Split(user_op::OpArg("out", 0), axis) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SoftmaxOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("softmax") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes() - 1) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), axis) - .Split(user_op::OpArg("out", 0), axis) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SoftmaxGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); + FOR_RANGE(int64_t, axis, 0, y_tensor.shape().NumAxes() - 1) { + ctx->NewBuilder() + .Split(user_op::OpArg("y", 0), axis) + .Split(user_op::OpArg("dy", 0), axis) + .Split(user_op::OpArg("dx", 0), axis) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& y_shape = ctx->InputShape("y", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == y_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} +/*static*/ Maybe SoftmaxGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SoftmaxGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("softmax_grad") - .Input("y") - .Input("dy") - .Output("dx") - 
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& y_shape = ctx->InputShape("y", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == y_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); - FOR_RANGE(int64_t, axis, 0, y_tensor.shape().NumAxes() - 1) { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), axis) - .Split(user_op::OpArg("dy", 0), axis) - .Split(user_op::OpArg("dx", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); +namespace { REGISTER_USER_OP_GRAD("softmax").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/softsign_op.cpp b/oneflow/user/ops/softsign_op.cpp index 3249803029e..9cbc34d8cdf 100644 --- a/oneflow/user/ops/softsign_op.cpp +++ b/oneflow/user/ops/softsign_op.cpp @@ -14,61 +14,58 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe SoftsignOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SoftsignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SoftsignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SoftsignOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("softsign") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SoftsignGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SoftsignGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& 
dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} +/*static*/ Maybe SoftsignGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SoftsignGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("softsign_grad") - .Input("x") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +namespace { REGISTER_USER_OP_GRAD("softsign").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { const auto softsign_grad_op_name = ctx->FwOp().op_name() + "_grad"; diff --git a/oneflow/user/ops/sort_op.cpp b/oneflow/user/ops/sort_op.cpp index 469533b88ae..cbcbaa07e48 100644 --- a/oneflow/user/ops/sort_op.cpp +++ b/oneflow/user/ops/sort_op.cpp @@ -14,35 +14,35 @@ See the License for the specific language governing permissions and limitations under the License. 
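The softsign pair is purely elementwise, which is why SoftsignGradOp::GetSbp splits x, dy and dx along the same axis and the tensor-desc inference insists that dy and x share a shape. A self-contained illustration of the per-element math, assuming the usual definition softsign(x) = x / (1 + |x|) with derivative 1 / (1 + |x|)^2 (no OneFlow dependency, values made up):

#include <cmath>
#include <cstdio>

int main() {
  const double xs[] = {-2.0, -0.5, 0.0, 0.5, 2.0};
  const double dys[] = {1.0, 1.0, 1.0, 1.0, 1.0};   // upstream gradient dy
  for (int i = 0; i < 5; ++i) {
    const double denom = 1.0 + std::fabs(xs[i]);
    const double y = xs[i] / denom;               // forward: softsign(x)
    const double dx = dys[i] / (denom * denom);   // backward: dy * softsign'(x)
    std::printf("x=%+.2f  y=%+.4f  dx=%+.4f\n", xs[i], y, dx);
  }
  return 0;
}

Because each output element depends only on the matching input element, any axis can be split without changing the result, so no PartialSum signature is needed here.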
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("sort") - .Input("in") - .Output("out") - .Attr("direction") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - // The current implementation can only do sort in the last dimension and should use Broadcast - // (by default) instead of Split for that dimension - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def, - const user_op::UserOpConfWrapper& op_conf) -> Maybe { - const std::string& direction = op_conf.attr("direction"); - CHECK_OR_RETURN(direction == "ASCENDING" || direction == "DESCENDING"); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SortOp::GetSbp(user_op::SbpContext* ctx) { + // The current implementation can only do sort in the last dimension and should use Broadcast + // (by default) instead of Split for that dimension + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SortOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SortOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SortOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SortOp::CheckAttr(const user_op::UserOpDefWrapper&, + const user_op::UserOpConfWrapper& op_conf) { + const std::string& direction = op_conf.attr("direction"); + CHECK_OR_RETURN(direction == "ASCENDING" || direction == "DESCENDING"); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/sparse_cross_entropy_op.cpp b/oneflow/user/ops/sparse_cross_entropy_op.cpp index ff4ea687b00..adf9acdebfd 100644 --- a/oneflow/user/ops/sparse_cross_entropy_op.cpp +++ b/oneflow/user/ops/sparse_cross_entropy_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -71,7 +72,52 @@ Maybe InferDataTypeGrad(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe AddMsSignature(user_op::SbpContext* ctx) { +Maybe GenBackwardOpConf4SparseCrossEntropy(const std::string& op_type_name, + const user_op::UserOpWrapper& op, + const user_op::AddOpFn& AddOp) { + if (op.NeedGenGradTensor4OpInput("prediction", 0)) { + user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); + user_op::UserOpConfWrapper grad_op = builder.Op(op_type_name) + .Input("prediction", op.input("prediction", 0)) + .Input("label", op.input("label", 0)) + .Input("dy", op.GetGradTensorWithOpOutput("out", 0)) + .Output("prediction_diff") + .Attr("depth", op.attr("depth")) + .Build(); + op.BindGradTensorWithOpInput(grad_op.output("prediction_diff", 0), "prediction", 0); + AddOp(grad_op); + } + return Maybe::Ok(); +} + +} // namespace + +/*static*/ Maybe SparseCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("prediction", 0), 0) + .Split(user_op::OpArg("label", 0), 0) + .Split(user_op::OpArg("out", 0), 0) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe SparseCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDescFn(ctx); +} +/*static*/ Maybe SparseCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SparseCrossEntropyOp::InferDataType(user_op::InferContext* ctx) { + return oneflow::InferDataType(ctx); +} +/*static*/ Maybe SparseCrossEntropyOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); + CHECK_OR_RETURN(label_modifier != nullptr); + label_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/*static*/ Maybe SparseCrossEntropyMsOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& prediction = ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0); ctx->NewBuilder() @@ -86,17 +132,45 @@ Maybe AddMsSignature(user_op::SbpContext* ctx) { .Build(); return Maybe::Ok(); } +/*static*/ Maybe SparseCrossEntropyMsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDescFn(ctx); +} +/*static*/ Maybe SparseCrossEntropyMsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SparseCrossEntropyMsOp::InferDataType(user_op::InferContext* ctx) { + return oneflow::InferDataType(ctx); +} +/*static*/ Maybe SparseCrossEntropyMsOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); + CHECK_OR_RETURN(label_modifier != nullptr); + label_modifier->set_requires_grad(false); + return Maybe::Ok(); +} -Maybe AddSignature(user_op::SbpContext* ctx) { +/*static*/ Maybe SparseCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder() .Split(user_op::OpArg("prediction", 0), 0) .Split(user_op::OpArg("label", 0), 0) - .Split(user_op::OpArg("out", 0), 0) + .Split(user_op::OpArg("dy", 0), 0) + .Split(user_op::OpArg("prediction_diff", 0), 0) .Build(); return Maybe::Ok(); } +/*static*/ Maybe SparseCrossEntropyGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferGradTensorDescFn(ctx); +} +/*static*/ Maybe 
SparseCrossEntropyGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SparseCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) { + return InferDataTypeGrad(ctx); +} -Maybe AddGradMsSignature(user_op::SbpContext* ctx) { +/*static*/ Maybe SparseCrossEntropyMsGradOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& prediction = ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0); ctx->NewBuilder() @@ -113,76 +187,18 @@ Maybe AddGradMsSignature(user_op::SbpContext* ctx) { .Build(); return Maybe::Ok(); } - -Maybe AddGradSignature(user_op::SbpContext* ctx) { - ctx->NewBuilder() - .Split(user_op::OpArg("prediction", 0), 0) - .Split(user_op::OpArg("label", 0), 0) - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("prediction_diff", 0), 0) - .Build(); - return Maybe::Ok(); +/*static*/ Maybe SparseCrossEntropyMsGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferGradTensorDescFn(ctx); } - -template (*GetSbpSignature)(user_op::SbpContext*)> -Maybe GetSbpFn(user_op::SbpContext* ctx) { - JUST(GetSbpSignature(ctx)); - return Maybe::Ok(); +/*static*/ Maybe SparseCrossEntropyMsGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); } - -Maybe GenBackwardOpConf4SparseCrossEntropy(const std::string& op_type_name, - const user_op::UserOpWrapper& op, - user_op::AddOpFn AddOp) { - if (op.NeedGenGradTensor4OpInput("prediction", 0)) { - user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); - user_op::UserOpConfWrapper grad_op = builder.Op(op_type_name) - .Input("prediction", op.input("prediction", 0)) - .Input("label", op.input("label", 0)) - .Input("dy", op.GetGradTensorWithOpOutput("out", 0)) - .Output("prediction_diff") - .Attr("depth", op.attr("depth")) - .Build(); - op.BindGradTensorWithOpInput(grad_op.output("prediction_diff", 0), "prediction", 0); - AddOp(grad_op); - } - return Maybe::Ok(); +/*static*/ Maybe SparseCrossEntropyMsGradOp::InferDataType(user_op::InferContext* ctx) { + return InferDataTypeGrad(ctx); } -} // namespace - -#define REGISTER_SPAESE_CROSS_ENTROPY_USER_OP(op_name, sbp_sig) \ - REGISTER_USER_OP(op_name) \ - .Input("prediction") \ - .Input("label") \ - .Output("out") \ - .Attr("depth") \ - .SetTensorDescInferFn(InferTensorDescFn) \ - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, \ - const user_op::UserOpConfWrapper&) -> Maybe { \ - user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); \ - CHECK_OR_RETURN(label_modifier != nullptr); \ - label_modifier->set_requires_grad(false); \ - return Maybe::Ok(); \ - }) \ - .SetGetSbpFn(GetSbpFn) \ - .SetDataTypeInferFn(InferDataType); - -#define REGISTER_SPAESE_CROSS_ENTROPY_GRAD_USER_OP(op_name, sbp_sig) \ - REGISTER_USER_OP(op_name) \ - .Input("prediction") \ - .Input("label") \ - .Input("dy") \ - .Output("prediction_diff") \ - .Attr("depth") \ - .SetTensorDescInferFn(InferGradTensorDescFn) \ - .SetGetSbpFn(GetSbpFn) \ - .SetDataTypeInferFn(InferDataTypeGrad); - -REGISTER_SPAESE_CROSS_ENTROPY_USER_OP("sparse_cross_entropy", AddSignature); -REGISTER_SPAESE_CROSS_ENTROPY_USER_OP("sparse_cross_entropy_ms", AddMsSignature); -REGISTER_SPAESE_CROSS_ENTROPY_GRAD_USER_OP("sparse_cross_entropy_grad", AddGradSignature); -REGISTER_SPAESE_CROSS_ENTROPY_GRAD_USER_OP("sparse_cross_entropy_ms_grad", AddGradMsSignature); - REGISTER_USER_OP_GRAD("sparse_cross_entropy") .SetGenBackwardOpConfFn([](const 
user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp index e5df89b6148..5550e1caae8 100644 --- a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp +++ b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -155,41 +156,47 @@ Maybe GenBackwardOpConf4SparseSoftmaxCrossEntropy(const std::string& op_ty } // namespace -#define REGISTER_SPAESE_SOFTMAX_CROSS_ENTROPY_USER_OP(op_name, sbp_sig) \ - REGISTER_USER_OP(op_name) \ - .Input("prediction") \ - .Input("label") \ - .Output("prob") \ - .Output("out") \ - .Attr("depth") \ - .SetTensorDescInferFn(InferTensorDescFn) \ - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, \ - const user_op::UserOpConfWrapper&) -> Maybe { \ - user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); \ - CHECK_OR_RETURN(label_modifier != nullptr); \ - label_modifier->set_requires_grad(false); \ - return Maybe::Ok(); \ - }) \ - .SetGetSbpFn(GetSbpFn) \ - .SetDataTypeInferFn(InferDataType); - -#define REGISTER_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_USER_OP(op_name, sbp_sig) \ - REGISTER_USER_OP(op_name) \ - .Input("label") \ - .Input("dy") \ - .Input("prob") \ - .Output("prediction_diff") \ - .Attr("depth") \ - .SetTensorDescInferFn(InferGradTensorDescFn) \ - .SetGetSbpFn(GetSbpFn) \ - .SetDataTypeInferFn(InferDataTypeGrad); - -REGISTER_SPAESE_SOFTMAX_CROSS_ENTROPY_USER_OP("sparse_softmax_cross_entropy", AddSignature); -REGISTER_SPAESE_SOFTMAX_CROSS_ENTROPY_USER_OP("sparse_softmax_cross_entropy_ms", AddMsSignature); -REGISTER_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_USER_OP("sparse_softmax_cross_entropy_grad", - AddGradSignature); -REGISTER_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_USER_OP("sparse_softmax_cross_entropy_ms_grad", - AddGradMsSignature); +#define IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS(op_name, sbp_sig) \ + /*static*/ Maybe op_name##Op::GetSbp(user_op::SbpContext* ctx) { return sbp_sig(ctx); } \ + /*static*/ Maybe op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return InferTensorDescFn(ctx); \ + } \ + /*static*/ Maybe op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe op_name##Op::InferDataType(user_op::InferContext* ctx) { \ + return oneflow::InferDataType(ctx); \ + } \ + /*static*/ Maybe op_name##Op::ModifyInputArg( \ + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { \ + user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); \ + CHECK_OR_RETURN(label_modifier != nullptr); \ + label_modifier->set_requires_grad(false); \ + return Maybe::Ok(); \ + } + +IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS(SparseSoftmaxCrossEntropy, AddSignature); +IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS(SparseSoftmaxCrossEntropyMs, AddMsSignature); +#undef IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS + +#define IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS(op_name, sbp_sig) \ + /*static*/ Maybe op_name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ + return sbp_sig(ctx); \ + } \ + /*static*/ Maybe op_name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return 
InferGradTensorDescFn(ctx); \ + } \ + /*static*/ Maybe op_name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe op_name##GradOp::InferDataType(user_op::InferContext* ctx) { \ + return InferDataTypeGrad(ctx); \ + } + +IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS(SparseSoftmaxCrossEntropy, AddGradSignature); +IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS(SparseSoftmaxCrossEntropyMs, + AddGradMsSignature); +#undef IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS REGISTER_USER_OP_GRAD("sparse_softmax_cross_entropy") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/split_like_op.cpp b/oneflow/user/ops/split_like_op.cpp index 80f98d49404..a0e31349603 100644 --- a/oneflow/user/ops/split_like_op.cpp +++ b/oneflow/user/ops/split_like_op.cpp @@ -14,12 +14,54 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe InferTensorDesc(user_op::InferContext* ctx) { +/*static*/ Maybe SplitLikeOp::GetSbp(user_op::SbpContext* ctx) { + const auto axis = ctx->Attr("axis"); + const int64_t in_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); + const int64_t like_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape().NumAxes(); + FOR_RANGE(int64_t, i, 0, like_num_axes) { + if (i == axis) { continue; } + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + std::vector like_arg_vec; + const size_t like_arg_size = ctx->outputs().size(); + like_arg_vec.reserve(like_arg_size); + FOR_RANGE(int32_t, i, 0, like_arg_size) { like_arg_vec.emplace_back("like", i); } + FOR_RANGE(int64_t, i, like_num_axes, in_num_axes) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Broadcast(like_arg_vec) + .Split(ctx->outputs(), i) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .PartialSum(like_arg_vec) + .Split(ctx->outputs(), i) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(like_arg_vec) + .PartialSum(ctx->outputs()) + .Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .Broadcast(like_arg_vec) + .PartialSum(ctx->outputs()) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("in", 0)) + .PartialSum(like_arg_vec) + .Broadcast(ctx->outputs()) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe SplitLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const auto axis = ctx->Attr("axis"); const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); int64_t dynamic_dim_size = 0; @@ -57,8 +99,10 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { } return Maybe::Ok(); } - -Maybe InferDataType(user_op::InferContext* ctx) { +/*static*/ Maybe SplitLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SplitLikeOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { user_op::TensorDesc* out_i_desc = ctx->OutputTensorDesc("out", i); @@ -66,9 +110,8 @@ Maybe InferDataType(user_op::InferContext* ctx) { } return Maybe::Ok(); } - -Maybe SetLikeArgModifier(user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper& 
user_op_conf) { +/*static*/ Maybe SplitLikeOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& user_op_conf) { FOR_RANGE(int32_t, i, 0, user_op_conf.input_size("like")) { user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", i); CHECK_NOTNULL_OR_RETURN(like_modifier); @@ -77,50 +120,15 @@ Maybe SetLikeArgModifier(user_op::GetInputArgModifier GetInputArgModifierF return Maybe::Ok(); } -Maybe GetSbpSignature(user_op::SbpContext* ctx) { - const auto axis = ctx->Attr("axis"); - const int64_t in_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); - const int64_t like_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape().NumAxes(); - FOR_RANGE(int64_t, i, 0, like_num_axes) { - if (i == axis) { continue; } - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - std::vector like_arg_vec; - const size_t like_arg_size = ctx->outputs().size(); - like_arg_vec.reserve(like_arg_size); - FOR_RANGE(int32_t, i, 0, like_arg_size) { like_arg_vec.emplace_back("like", i); } - FOR_RANGE(int64_t, i, like_num_axes, in_num_axes) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Broadcast(like_arg_vec) - .Split(ctx->outputs(), i) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .PartialSum(like_arg_vec) - .Split(ctx->outputs(), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(like_arg_vec) - .PartialSum(ctx->outputs()) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .Broadcast(like_arg_vec) - .PartialSum(ctx->outputs()) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("in", 0)) - .PartialSum(like_arg_vec) - .Broadcast(ctx->outputs()) - .Build(); +/*static*/ Maybe SplitLikeOp::CheckAttr(const user_op::UserOpDefWrapper&, + const user_op::UserOpConfWrapper& op_conf) { + CHECK_OR_RETURN(op_conf.input_size("like") >= 2); + CHECK_OR_RETURN(op_conf.output_size("out") >= 2); return Maybe::Ok(); } +namespace { + Maybe GenGradOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) { const int64_t axis = op.attr("axis"); const int32_t out_size = op.output_size("out"); @@ -158,16 +166,6 @@ Maybe GenGradOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) } // namespace -REGISTER_USER_OP("split_like") - .Input("in") - .InputWithMinimum("like", 2) - .OutputWithMinimum("out", 2) - .Attr("axis") - .SetTensorDescInferFn(InferTensorDesc) - .SetInputArgModifyFn(SetLikeArgModifier) - .SetGetSbpFn(GetSbpSignature) - .SetDataTypeInferFn(InferDataType); - REGISTER_USER_OP_GRAD("split_like").SetGenBackwardOpConfFn(GenGradOp); } // namespace oneflow diff --git a/oneflow/user/ops/sqrt_square_sum_op.cpp b/oneflow/user/ops/sqrt_square_sum_op.cpp new file mode 100644 index 00000000000..f8c6b43ca5b --- /dev/null +++ b/oneflow/user/ops/sqrt_square_sum_op.cpp @@ -0,0 +1,41 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
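SplitLikeOp reads roughly as the inverse of a concat along "axis": each out_i takes its extent on that axis from like_i while the remaining axes follow "in", which is also why ModifyInputArg marks every like input as not requiring grad and why the new CheckAttr demands at least two like/out pairs, matching the old InputWithMinimum/OutputWithMinimum declaration. A standalone sketch of that shape bookkeeping under this reading (shapes below are made up):

#include <cstdio>
#include <vector>

int main() {
  const int axis = 0;
  const std::vector<long> in = {10, 5};                        // input shape
  const std::vector<std::vector<long>> like = {{4, 5}, {6, 5}};
  long total = 0;
  for (const auto& l : like) {
    std::vector<long> out = in;
    out[axis] = l[axis];   // borrow the extent along the split axis from the like tensor
    total += l[axis];
    std::printf("out shape: (%ld, %ld)\n", out[0], out[1]);
  }
  std::printf("extents along axis %d sum to %ld, matching in\n", axis, total);
  return 0;
}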
+*/ +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" + +namespace oneflow { + +/*static*/ Maybe SqrtSquareSumOp::GetSbp(user_op::SbpContext* ctx) { + const int64_t num_x_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes(); + FOR_RANGE(int64_t, i, 0, num_x_axes) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).PartialSum(user_op::OpArg("y", 0)).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SqrtSquareSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + *y->mut_shape() = Shape({}); + return Maybe::Ok(); +} +/*static*/ Maybe SqrtSquareSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SqrtSquareSumOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} + +} // namespace oneflow diff --git a/oneflow/user/ops/square_sum_op.cpp b/oneflow/user/ops/square_sum_op.cpp index 494688fd9ce..c97d3219046 100644 --- a/oneflow/user/ops/square_sum_op.cpp +++ b/oneflow/user/ops/square_sum_op.cpp @@ -14,61 +14,62 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("square_sum") - .Input("x") - .Output("y") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); - *y->mut_shape() = Shape({1}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int64_t num_x_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes(); - FOR_RANGE(int64_t, i, 0, num_x_axes) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - } - return Maybe::Ok(); - }); - -REGISTER_USER_OP("multi_square_sum") - .InputWithMinimum("x", 1) - .Output("y") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); - *y->mut_shape() = Shape({1}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_0 = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); - for (int64_t i = 1; i < ctx->input_size("x"); ++i) { - const user_op::TensorDesc& x_i = ctx->InputTensorDesc("x", i); - CHECK_EQ_OR_RETURN(x_i.data_type(), x_0.data_type()); - } - *y->mut_data_type() = x_0.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - int64_t min_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes(); - for (int64_t i = 1; i < ctx->user_op_conf().input_size("x"); ++i) { - min_num_axes = std::min( - min_num_axes, ctx->LogicalTensorDesc4InputArgNameAndIndex("x", i).shape().NumAxes()); - } - for (int64_t i = 0; i < min_num_axes; ++i) { - ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(user_op::OpArg("y", 0)).Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe SquareSumOp::GetSbp(user_op::SbpContext* ctx) { + const int64_t num_x_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes(); + FOR_RANGE(int64_t, i, 0, 
num_x_axes) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).PartialSum(user_op::OpArg("y", 0)).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SquareSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + *y->mut_shape() = Shape({1}); + return Maybe::Ok(); +} +/*static*/ Maybe SquareSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SquareSumOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe MultiSquareSumOp::GetSbp(user_op::SbpContext* ctx) { + int64_t min_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes(); + for (int64_t i = 1; i < ctx->user_op_conf().input_size("x"); ++i) { + min_num_axes = std::min(min_num_axes, + ctx->LogicalTensorDesc4InputArgNameAndIndex("x", i).shape().NumAxes()); + } + for (int64_t i = 0; i < min_num_axes; ++i) { + ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(user_op::OpArg("y", 0)).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe MultiSquareSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + *y->mut_shape() = Shape({1}); + return Maybe::Ok(); +} +/*static*/ Maybe MultiSquareSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe MultiSquareSumOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_0 = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + for (int64_t i = 1; i < ctx->input_size("x"); ++i) { + const user_op::TensorDesc& x_i = ctx->InputTensorDesc("x", i); + CHECK_EQ_OR_RETURN(x_i.data_type(), x_0.data_type()); + } + *y->mut_data_type() = x_0.data_type(); + return Maybe::Ok(); +} +/*static*/ Maybe MultiSquareSumOp::CheckAttr(const user_op::UserOpDefWrapper&, + const user_op::UserOpConfWrapper& op_conf) { + CHECK_OR_RETURN(op_conf.input_size("x") >= 1); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/squeeze_op.cpp b/oneflow/user/ops/squeeze_op.cpp index 6be95c78159..d6c9cb111a4 100644 --- a/oneflow/user/ops/squeeze_op.cpp +++ b/oneflow/user/ops/squeeze_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
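SquareSumOp and MultiSquareSumOp are full reductions, so GetSbp pairs every Split of x with PartialSum on y: each rank reduces its own shard and the per-rank partial results add up to the global sum of squares. A standalone numeric check of that identity, independent of the OneFlow API:

#include <cstdio>

static double square_sum(const double* x, int n) {
  double s = 0.0;
  for (int i = 0; i < n; ++i) { s += x[i] * x[i]; }
  return s;
}

int main() {
  const double x[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
  // Split(axis 0) into two shards of three elements each -> PartialSum on y.
  const double partial0 = square_sum(x, 3);
  const double partial1 = square_sum(x + 3, 3);
  std::printf("global = %.1f, partial0 + partial1 = %.1f\n",
              square_sum(x, 6), partial0 + partial1);
  return 0;
}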
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -40,21 +41,7 @@ Maybe CheckAndLabelAxesToSqueezeMinusOne(const AxisVector& axes, DimVector } // namespace -Maybe SqueezeTensorDescInferFn(user_op::InferContext* ctx) { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - AxisVector fixed_axes_vec; - JUST(TransformNegativeAxesToPositive(ctx->Attr>("axes"), in_shape.NumAxes(), - &fixed_axes_vec)); - - DimVector dim_vec = in_shape.dim_vec(); - JUST(CheckAndLabelAxesToSqueezeMinusOne(fixed_axes_vec, &dim_vec)); - dim_vec.erase(std::remove(dim_vec.begin(), dim_vec.end(), -1), dim_vec.end()); - *out_shape = Shape(dim_vec); - return Maybe::Ok(); -} - -Maybe SqueezeGetSbpFn(user_op::SbpContext* ctx) { +/*static*/ Maybe SqueezeOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); AxisVector fixed_axes_vec; JUST(TransformNegativeAxesToPositive(ctx->Attr>("axes"), @@ -74,17 +61,26 @@ Maybe SqueezeGetSbpFn(user_op::SbpContext* ctx) { } return Maybe::Ok(); } +/*static*/ Maybe SqueezeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + AxisVector fixed_axes_vec; + JUST(TransformNegativeAxesToPositive(ctx->Attr>("axes"), in_shape.NumAxes(), + &fixed_axes_vec)); -REGISTER_USER_OP("squeeze") - .Input("in") - .Output("out") - .Attr>("axes") - .SetTensorDescInferFn(SqueezeTensorDescInferFn) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(SqueezeGetSbpFn); + DimVector dim_vec = in_shape.dim_vec(); + JUST(CheckAndLabelAxesToSqueezeMinusOne(fixed_axes_vec, &dim_vec)); + dim_vec.erase(std::remove(dim_vec.begin(), dim_vec.end(), -1), dim_vec.end()); + *out_shape = Shape(dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe SqueezeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SqueezeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("squeeze").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/ssp_variable_proxy_op.cpp b/oneflow/user/ops/ssp_variable_proxy_op.cpp index 1af7c0c4138..9a5a31262a7 100644 --- a/oneflow/user/ops/ssp_variable_proxy_op.cpp +++ b/oneflow/user/ops/ssp_variable_proxy_op.cpp @@ -14,46 +14,41 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -REGISTER_NO_GRAD_USER_OP("ssp_variable_proxy") - .Input("var") - .Output("ref") - .Output("value") - .Attr("buffer_size", 1) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& var_shape = ctx->InputShape("var", 0); - *ctx->OutputShape("ref", 0) = var_shape; - *ctx->OutputShape("value", 0) = var_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& var_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("var", 0); - FOR_RANGE(int64_t, i, 0, var_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("var", 0), i) - .Split(user_op::OpArg("ref", 0), i) - .Split(user_op::OpArg("value", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("ref", 0) = ctx->InputDType("var", 0); - *ctx->OutputDType("value", 0) = ctx->InputDType("var", 0); - return Maybe::Ok(); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("ref", 0); - CHECK_OR_RETURN(out_modifier != nullptr); - out_modifier->set_is_mutable(true); - return Maybe::Ok(); - }); - -} // namespace +/*static*/ Maybe SspVariableProxyOp::GetSbp(user_op::SbpContext* ctx) { + const auto& var_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("var", 0); + FOR_RANGE(int64_t, i, 0, var_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("var", 0), i) + .Split(user_op::OpArg("ref", 0), i) + .Split(user_op::OpArg("value", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SspVariableProxyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& var_shape = ctx->InputShape("var", 0); + *ctx->OutputShape("ref", 0) = var_shape; + *ctx->OutputShape("value", 0) = var_shape; + return Maybe::Ok(); +} +/*static*/ Maybe SspVariableProxyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SspVariableProxyOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("ref", 0) = ctx->InputDType("var", 0); + *ctx->OutputDType("value", 0) = ctx->InputDType("var", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SspVariableProxyOp::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("ref", 0); + CHECK_OR_RETURN(out_modifier != nullptr); + out_modifier->set_is_mutable(true); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/summary_ops.cpp b/oneflow/user/ops/summary_ops.cpp index 0026e7820fd..6235856d1fc 100644 --- a/oneflow/user/ops/summary_ops.cpp +++ b/oneflow/user/ops/summary_ops.cpp @@ -14,11 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" -#include "oneflow/core/framework/user_op_attr.pb.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace summary { +namespace { Maybe CheckStepShape(const Shape* step) { CHECK_OR_RETURN(step->elem_cnt() == 1); @@ -37,51 +37,85 @@ Maybe CheckInAndStepScalar(user_op::InferContext* ctx) { return Maybe::Ok(); } -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("create_summary_writer") - .Attr("logdir") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("flush_summary_writer") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("summary_write_scalar") - .Input("in") - .Input("step") - .Input("tag") - .SetTensorDescInferFn(CheckInAndStepScalar) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("summary_write_histogram") - .Input("in") - .Input("step") - .Input("tag") - .SetTensorDescInferFn(CheckStepShapeInCtx) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("summary_write_pb") - .Input("in") - .Input("step") - .SetTensorDescInferFn(CheckStepShapeInCtx) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("summary_write_image") - .Input("in") - .Input("step") - .Input("tag") - .SetTensorDescInferFn(CheckStepShapeInCtx) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); -} // namespace summary +} // namespace + +/*static*/ Maybe CreateSummaryWriterOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe CreateSummaryWriterOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return Maybe::Ok(); +} +/*static*/ Maybe CreateSummaryWriterOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return Maybe::Ok(); +} +/*static*/ Maybe CreateSummaryWriterOp::InferDataType(user_op::InferContext* ctx) { + return Maybe::Ok(); +} + +/*static*/ Maybe FlushSummaryWriterOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe FlushSummaryWriterOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return Maybe::Ok(); +} +/*static*/ Maybe FlushSummaryWriterOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return Maybe::Ok(); +} +/*static*/ Maybe FlushSummaryWriterOp::InferDataType(user_op::InferContext* ctx) { + return Maybe::Ok(); +} + +/*static*/ Maybe SummaryWriteScalarOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe SummaryWriteScalarOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return CheckInAndStepScalar(ctx); +} 
+/*static*/ Maybe SummaryWriteScalarOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SummaryWriteScalarOp::InferDataType(user_op::InferContext* ctx) { + return Maybe::Ok(); +} + +/*static*/ Maybe SummaryWriteHistogramOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe SummaryWriteHistogramOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return CheckStepShapeInCtx(ctx); +} +/*static*/ Maybe SummaryWriteHistogramOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SummaryWriteHistogramOp::InferDataType(user_op::InferContext* ctx) { + return Maybe::Ok(); +} + +/*static*/ Maybe SummaryWritePbOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe SummaryWritePbOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return CheckStepShapeInCtx(ctx); +} +/*static*/ Maybe SummaryWritePbOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SummaryWritePbOp::InferDataType(user_op::InferContext* ctx) { + return Maybe::Ok(); +} + +/*static*/ Maybe SummaryWriteImageOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe SummaryWriteImageOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return CheckStepShapeInCtx(ctx); +} +/*static*/ Maybe SummaryWriteImageOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SummaryWriteImageOp::InferDataType(user_op::InferContext* ctx) { + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/tanh_op.cpp b/oneflow/user/ops/tanh_op.cpp index 639b9529e72..caf89a63ac9 100644 --- a/oneflow/user/ops/tanh_op.cpp +++ b/oneflow/user/ops/tanh_op.cpp @@ -14,23 +14,35 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("tanh") - .Input("x") - .Output("y") - .SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged) - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) - .SetDataTypeInferFn(user_op::TensorDescInferFnUtil::UnchangedDataType); - -REGISTER_USER_OP((std::string("") + "tanh" + "_grad")) - .Input("x") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged) - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) - .SetDataTypeInferFn(user_op::TensorDescInferFnUtil::UnchangedDataType); +/*static*/ Maybe TanhOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); +} +/*static*/ Maybe TanhOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return user_op::TensorDescInferFnUtil::Unchanged(ctx); +} +/*static*/ Maybe TanhOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TanhOp::InferDataType(user_op::InferContext* ctx) { + return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); +} + +/*static*/ Maybe TanhGradOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); +} +/*static*/ Maybe TanhGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return user_op::TensorDescInferFnUtil::Unchanged(ctx); +} +/*static*/ Maybe TanhGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TanhGradOp::InferDataType(user_op::InferContext* ctx) { + return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); +} REGISTER_USER_OP_GRAD("tanh").SetGenBackwardOpConfFn( [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe { diff --git a/oneflow/user/ops/tensor_buffer_ops.cpp b/oneflow/user/ops/tensor_buffer_ops.cpp index f1e964c50db..80b1c5c99ff 100644 --- a/oneflow/user/ops/tensor_buffer_ops.cpp +++ b/oneflow/user/ops/tensor_buffer_ops.cpp @@ -14,199 +14,197 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe TensorBufferToTensorOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe TensorBufferToTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + out->set_is_dynamic(in.is_dynamic()); + const auto& instance_shape = ctx->Attr("instance_shape"); + DimVector dim_vec; + dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), in.shape().dim_vec().cend()); + dim_vec.insert(dim_vec.end(), instance_shape.dim_vec().cbegin(), instance_shape.dim_vec().cend()); + *out->mut_shape() = Shape(dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe TensorBufferToTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TensorBufferToTensorOp::InferDataType(user_op::InferContext* ctx) { + const auto data_type = ctx->Attr("dtype"); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(IsPODDataType(data_type)); + *out->mut_data_type() = data_type; + return Maybe::Ok(); +} -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("tensor_buffer_to_tensor") - .Input("in") - .Output("out") - .Attr("instance_shape") - .Attr("dtype") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - out->set_is_dynamic(in.is_dynamic()); - const auto& instance_shape = ctx->Attr("instance_shape"); - DimVector dim_vec; - dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), in.shape().dim_vec().cend()); - dim_vec.insert(dim_vec.end(), instance_shape.dim_vec().cbegin(), - instance_shape.dim_vec().cend()); - *out->mut_shape() = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto data_type = ctx->Attr("dtype"); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(IsPODDataType(data_type)); - *out->mut_data_type() = data_type; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe TensorToTensorBufferOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + const auto& instance_dims = ctx->Attr("instance_dims"); + CHECK_LE_OR_RETURN(instance_dims, in.shape().NumAxes()); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - instance_dims) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe TensorToTensorBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + const Shape& in_shape = in.shape(); + const auto& instance_dims = 
ctx->Attr("instance_dims"); + CHECK_LT_OR_RETURN(instance_dims, in_shape.NumAxes()); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + out->set_is_dynamic(in.is_dynamic()); + DimVector out_dim_vec; + out_dim_vec.insert(out_dim_vec.end(), in_shape.dim_vec().cbegin(), + in_shape.dim_vec().cend() - instance_dims); + *out->mut_shape() = Shape(out_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe TensorToTensorBufferOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TensorToTensorBufferOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(IsPODDataType(in.data_type())); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = DataType::kTensorBuffer; + return Maybe::Ok(); +} -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("tensor_to_tensor_buffer") - .Input("in") - .Output("out") - .Attr("instance_dims") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - const Shape& in_shape = in.shape(); - const auto& instance_dims = ctx->Attr("instance_dims"); - CHECK_LT_OR_RETURN(instance_dims, in_shape.NumAxes()); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - out->set_is_dynamic(in.is_dynamic()); - DimVector out_dim_vec; - out_dim_vec.insert(out_dim_vec.end(), in_shape.dim_vec().cbegin(), - in_shape.dim_vec().cend() - instance_dims); - *out->mut_shape() = Shape(out_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(IsPODDataType(in.data_type())); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - const auto& instance_dims = ctx->Attr("instance_dims"); - CHECK_LE_OR_RETURN(instance_dims, in.shape().NumAxes()); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - instance_dims) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe GenTensorBufferOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe GenTensorBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + const Shape& shape = ctx->Attr("shape"); + const int64_t num_tensor_buffers = shape.elem_cnt(); + const std::vector& shape_list = ctx->Attr>("shape_list"); + const std::vector& value_list = ctx->Attr>("value_list"); + CHECK_EQ_OR_RETURN(num_tensor_buffers, shape_list.size()); + CHECK_EQ_OR_RETURN(num_tensor_buffers, value_list.size()); + *out->mut_shape() = shape; + out->set_is_dynamic(ctx->Attr("dynamic_out")); + return Maybe::Ok(); +} +/*static*/ Maybe GenTensorBufferOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe GenTensorBufferOp::InferDataType(user_op::InferContext* ctx) { + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = DataType::kTensorBuffer; + return Maybe::Ok(); +} -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("gen_tensor_buffer") - .Output("out") - .Attr("shape") - 
.Attr>("shape_list") - .Attr>("value_list") - .Attr("data_type") - .Attr("dynamic_out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - const Shape& shape = ctx->Attr("shape"); - const int64_t num_tensor_buffers = shape.elem_cnt(); - const std::vector& shape_list = ctx->Attr>("shape_list"); - const std::vector& value_list = ctx->Attr>("value_list"); - CHECK_EQ_OR_RETURN(num_tensor_buffers, shape_list.size()); - CHECK_EQ_OR_RETURN(num_tensor_buffers, value_list.size()); - *out->mut_shape() = shape; - out->set_is_dynamic(ctx->Attr("dynamic_out")); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe TensorBufferToListOfTensorsOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe TensorBufferToListOfTensorsOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_GT_OR_RETURN(in.shape().elem_cnt(), 0); + CHECK_OR_RETURN(!in.is_dynamic()); + const Shape& out_shape = ctx->Attr("out_shape"); + const bool dynamic_out = ctx->Attr("dynamic_out"); + int64_t num_tensor_buffers = in.shape().elem_cnt(); + for (int64_t i = 0; i < num_tensor_buffers; ++i) { + user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); + *out_i->mut_shape() = out_shape; + out_i->set_is_dynamic(dynamic_out); + } + return Maybe::Ok(); +} +/*static*/ Maybe TensorBufferToListOfTensorsOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TensorBufferToListOfTensorsOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in.data_type(), DataType::kTensorBuffer); + const DataType out_dtype = ctx->Attr("out_dtype"); + CHECK_OR_RETURN(IsPODDataType(out_dtype)); + int64_t num_tensor_buffers = ctx->outputs().size(); + for (int64_t i = 0; i < num_tensor_buffers; ++i) { + user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); + *out_i->mut_data_type() = out_dtype; + } + return Maybe::Ok(); +} +/*static*/ Maybe TensorBufferToListOfTensorsOp::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + if (conf.attr("dynamic_out")) { + FOR_RANGE(int64_t, i, 0, conf.output_size("out")) { + user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn("out", i); + CHECK_OR_RETURN(out_i_modifier != nullptr); + out_i_modifier->set_header_infered_before_compute(false); + } + } + return Maybe::Ok(); +} -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("tensor_buffer_to_list_of_tensors") - .Input("in") - .OutputWithMinimum("out", 1) - .Attr("out_shape") - .Attr("out_dtype") - .Attr("dynamic_out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_GT_OR_RETURN(in.shape().elem_cnt(), 0); - CHECK_OR_RETURN(!in.is_dynamic()); - const Shape& out_shape = ctx->Attr("out_shape"); - const bool dynamic_out = ctx->Attr("dynamic_out"); - int64_t num_tensor_buffers = in.shape().elem_cnt(); - for (int64_t i = 0; i < num_tensor_buffers; ++i) { - user_op::TensorDesc* out_i = 
ctx->OutputTensorDesc("out", i); - *out_i->mut_shape() = out_shape; - out_i->set_is_dynamic(dynamic_out); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in.data_type(), DataType::kTensorBuffer); - const DataType out_dtype = ctx->Attr("out_dtype"); - CHECK_OR_RETURN(IsPODDataType(out_dtype)); - int64_t num_tensor_buffers = ctx->outputs().size(); - for (int64_t i = 0; i < num_tensor_buffers; ++i) { - user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); - *out_i->mut_data_type() = out_dtype; - } - return Maybe::Ok(); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - if (conf.attr("dynamic_out")) { - FOR_RANGE(int64_t, i, 0, conf.output_size("out")) { - user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn("out", i); - CHECK_OR_RETURN(out_i_modifier != nullptr); - out_i_modifier->set_header_infered_before_compute(false); - } - } - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe TensorBufferToListOfTensorsOp::CheckAttr( + const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { + CHECK_OR_RETURN(op_conf.output_size("out") >= 1); + return Maybe::Ok(); +} -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("tensor_buffer_to_list_of_tensors_v2") - .Input("in") - .OutputWithMinimum("out", 1) - .Attr>("out_shapes") - .Attr>("out_dtypes") - .Attr("dynamic_out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_GT_OR_RETURN(in.shape().elem_cnt(), 0); - CHECK_OR_RETURN(!in.is_dynamic()); - const std::vector& out_shapes = ctx->Attr>("out_shapes"); - const bool dynamic_out = ctx->Attr("dynamic_out"); - int64_t num_tensor_buffers = in.shape().elem_cnt(); - for (int64_t i = 0; i < num_tensor_buffers; ++i) { - user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); - *out_i->mut_shape() = out_shapes[i]; - out_i->set_is_dynamic(dynamic_out); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in.data_type(), DataType::kTensorBuffer); - const std::vector& out_dtypes = ctx->Attr>("out_dtypes"); - int64_t num_tensor_buffers = ctx->outputs().size(); - for (int64_t i = 0; i < num_tensor_buffers; ++i) { - CHECK_OR_RETURN(IsPODDataType(out_dtypes[i])); - user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); - *out_i->mut_data_type() = out_dtypes[i]; - } - return Maybe::Ok(); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - if (conf.attr("dynamic_out")) { - FOR_RANGE(int64_t, i, 0, conf.output_size("out")) { - user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn("out", i); - CHECK_OR_RETURN(out_i_modifier != nullptr); - out_i_modifier->set_header_infered_before_compute(false); - } - } - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -} // namespace +/*static*/ Maybe TensorBufferToListOfTensorsV2Op::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe TensorBufferToListOfTensorsV2Op::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const 
user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_GT_OR_RETURN(in.shape().elem_cnt(), 0); + CHECK_OR_RETURN(!in.is_dynamic()); + const std::vector& out_shapes = ctx->Attr>("out_shapes"); + const bool dynamic_out = ctx->Attr("dynamic_out"); + int64_t num_tensor_buffers = in.shape().elem_cnt(); + for (int64_t i = 0; i < num_tensor_buffers; ++i) { + user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); + *out_i->mut_shape() = out_shapes[i]; + out_i->set_is_dynamic(dynamic_out); + } + return Maybe::Ok(); +} +/*static*/ Maybe TensorBufferToListOfTensorsV2Op::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TensorBufferToListOfTensorsV2Op::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in.data_type(), DataType::kTensorBuffer); + const std::vector& out_dtypes = ctx->Attr>("out_dtypes"); + int64_t num_tensor_buffers = ctx->outputs().size(); + for (int64_t i = 0; i < num_tensor_buffers; ++i) { + CHECK_OR_RETURN(IsPODDataType(out_dtypes[i])); + user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); + *out_i->mut_data_type() = out_dtypes[i]; + } + return Maybe::Ok(); +} +/*static*/ Maybe TensorBufferToListOfTensorsV2Op::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + if (conf.attr("dynamic_out")) { + FOR_RANGE(int64_t, i, 0, conf.output_size("out")) { + user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn("out", i); + CHECK_OR_RETURN(out_i_modifier != nullptr); + out_i_modifier->set_header_infered_before_compute(false); + } + } + return Maybe::Ok(); +} +/*static*/ Maybe TensorBufferToListOfTensorsV2Op::CheckAttr( + const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { + CHECK_OR_RETURN(op_conf.output_size("out") >= 1); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/test_ops.cpp b/oneflow/user/ops/test_ops.cpp index efbff608f89..2b2249dc8f3 100644 --- a/oneflow/user/ops/test_ops.cpp +++ b/oneflow/user/ops/test_ops.cpp @@ -15,210 +15,206 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("ccrelu") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe CcreluOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe CcreluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("x", 0); + Shape* out_shape = ctx->OutputShape("y", 0); + *out_shape = in_shape; + return Maybe::Ok(); +} +/*static*/ Maybe CcreluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe CcreluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("ccrelu_grad") - .Input("y") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& y_shape = ctx->InputShape("y", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == y_shape); - *dx_shape = y_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), 0) - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("dx", 0), 0) - .Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe CcreluGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("y", 0), 0) + .Split(user_op::OpArg("dy", 0), 0) + .Split(user_op::OpArg("dx", 0), 0) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe CcreluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& y_shape = ctx->InputShape("y", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == y_shape); + *dx_shape = y_shape; + return Maybe::Ok(); +} +/*static*/ Maybe CcreluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe CcreluGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("ccrelu").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { - if (op.NeedGenGradTensor4OpInput("in", 0)) { + if (op.NeedGenGradTensor4OpInput("x", 0)) { user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); user_op::UserOpConfWrapper ccrelu_grad_op = builder.Op("ccrelu_grad") - .Input("y", op.output("out", 0)) - .Input("dy", op.GetGradTensorWithOpOutput("out", 0)) + .Input("y", op.output("y", 0)) + .Input("dy", op.GetGradTensorWithOpOutput("y", 
0)) .Output("dx") .Build(); - op.BindGradTensorWithOpInput(ccrelu_grad_op.output("dx", 0), "in", 0); + op.BindGradTensorWithOpInput(ccrelu_grad_op.output("dx", 0), "x", 0); AddOp(ccrelu_grad_op); } return Maybe::Ok(); }); -REGISTER_USER_OP("TestReshape") - .Input("in") - .Output("out") - .Attr("shape") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - const Shape& conf_shape = ctx->Attr("shape"); - CHECK_EQ_OR_RETURN(in_shape.NumAxes(), conf_shape.NumAxes()); - *out_shape = conf_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe TestReshapeOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe TestReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + const Shape& conf_shape = ctx->Attr("shape"); + CHECK_EQ_OR_RETURN(in_shape.NumAxes(), conf_shape.NumAxes()); + *out_shape = conf_shape; + return Maybe::Ok(); +} +/*static*/ Maybe TestReshapeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestReshapeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("TestSource") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = Shape({5}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kFloat; - return Maybe::Ok(); - }); +/*static*/ Maybe TestSourceOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe TestSourceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = Shape({5}); + return Maybe::Ok(); +} +/*static*/ Maybe TestSourceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestSourceOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kFloat; + return Maybe::Ok(); +} -REGISTER_USER_OP("TestMultiOutputOrder") - .Input("in") - .Output("out1") - .Output("out2") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out1_shape = ctx->OutputShape("out1", 0); - Shape* out2_shape = ctx->OutputShape("out2", 0); - *out1_shape = in_shape; - *out2_shape = in_shape; - int32_t last_axis = in_shape.NumAxes() - 1; - out2_shape->Set(last_axis, in_shape.At(last_axis) * 2); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out1", 0) = ctx->InputDType("in", 0); - *ctx->OutputDType("out2", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 
0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe TestMultiOutputOrderOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe TestMultiOutputOrderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out1_shape = ctx->OutputShape("out1", 0); + Shape* out2_shape = ctx->OutputShape("out2", 0); + *out1_shape = in_shape; + *out2_shape = in_shape; + int32_t last_axis = in_shape.NumAxes() - 1; + out2_shape->Set(last_axis, in_shape.At(last_axis) * 2); + return Maybe::Ok(); +} +/*static*/ Maybe TestMultiOutputOrderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestMultiOutputOrderOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out1", 0) = ctx->InputDType("in", 0); + *ctx->OutputDType("out2", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("TestSourceMultiGpuFixedOutNum") - .Output("out") - .Attr("out_num") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - int64_t out_num = ctx->Attr("out_num"); - *out_shape = Shape({out_num}); - return Maybe::Ok(); - }) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - int64_t out_num = ctx->Attr("out_num"); - const ParallelContext& parallel_ctx = ctx->parallel_ctx(); - BalancedSplitter bs(out_num, parallel_ctx.parallel_num()); - *out_shape = Shape({bs.At(parallel_ctx.parallel_id()).size()}); +/*static*/ Maybe TestSourceMultiGpuFixedOutNumOp::GetSbp(user_op::SbpContext* ctx) { + int64_t parallel_num = ctx->parallel_num(); + DeviceType device_type = ctx->device_type(); + if (device_type == DeviceType::kCPU && parallel_num > 1) { + ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe TestSourceMultiGpuFixedOutNumOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + int64_t out_num = ctx->Attr("out_num"); + *out_shape = Shape({out_num}); + return Maybe::Ok(); +} +/*static*/ Maybe TestSourceMultiGpuFixedOutNumOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + int64_t out_num = ctx->Attr("out_num"); + const ParallelContext& parallel_ctx = ctx->parallel_ctx(); + BalancedSplitter bs(out_num, parallel_ctx.parallel_num()); + *out_shape = Shape({bs.At(parallel_ctx.parallel_id()).size()}); - const cfg::SbpParallel& out_sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0); - CHECK_OR_RETURN(out_sbp.has_split_parallel() && out_sbp.split_parallel().axis() == 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kFloat; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - int64_t parallel_num = ctx->parallel_num(); - DeviceType device_type = ctx->device_type(); - if (device_type == DeviceType::kCPU && parallel_num > 1) { - ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); - } - return Maybe::Ok(); - }); + const cfg::SbpParallel& out_sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0); + CHECK_OR_RETURN(out_sbp.has_split_parallel() && out_sbp.split_parallel().axis() == 0); + return Maybe::Ok(); +} +/*static*/ Maybe 
TestSourceMultiGpuFixedOutNumOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kFloat; + return Maybe::Ok(); +} -REGISTER_USER_OP("TestMultiInput") - .Input("x1") - .Input("x2") - .Output("y") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x1_shape = ctx->InputShape("x1", 0); - const Shape& x2_shape = ctx->InputShape("x2", 0); - Shape* y_shape = ctx->OutputShape("y", 0); - CHECK_OR_RETURN(x1_shape == x2_shape); - *y_shape = x1_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x1", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x1_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x1", 0); - FOR_RANGE(int64_t, i, 0, x1_tensor.shape().NumAxes()) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe TestMultiInputOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x1_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x1", 0); + FOR_RANGE(int64_t, i, 0, x1_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe TestMultiInputOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x1_shape = ctx->InputShape("x1", 0); + const Shape& x2_shape = ctx->InputShape("x2", 0); + Shape* y_shape = ctx->OutputShape("y", 0); + CHECK_OR_RETURN(x1_shape == x2_shape); + *y_shape = x1_shape; + return Maybe::Ok(); +} +/*static*/ Maybe TestMultiInputOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestMultiInputOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x1", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("TestMultiInputGrad") - .Input("x1") - .Input("x2") - .Input("y_diff") - .Output("x1_diff") - .Output("x2_diff") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x1_shape = ctx->InputShape("x1", 0); - const Shape& x2_shape = ctx->InputShape("x2", 0); - Shape* x1_diff_shape = ctx->OutputShape("x1_diff", 0); - Shape* x2_diff_shape = ctx->OutputShape("x2_diff", 0); - *x1_diff_shape = x1_shape; - *x2_diff_shape = x2_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("x1_diff", 0) = ctx->InputDType("x1", 0); - *ctx->OutputDType("x2_diff", 0) = ctx->InputDType("x2", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x1_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x1", 0); - FOR_RANGE(int64_t, i, 0, x1_tensor.shape().NumAxes()) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe TestMultiInputGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x1_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x1", 0); + FOR_RANGE(int64_t, i, 0, x1_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe TestMultiInputGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x1_shape = ctx->InputShape("x1", 0); + const Shape& x2_shape = ctx->InputShape("x2", 0); + Shape* x1_diff_shape = 
ctx->OutputShape("x1_diff", 0); + Shape* x2_diff_shape = ctx->OutputShape("x2_diff", 0); + *x1_diff_shape = x1_shape; + *x2_diff_shape = x2_shape; + return Maybe::Ok(); +} +/*static*/ Maybe TestMultiInputGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestMultiInputGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("x1_diff", 0) = ctx->InputDType("x1", 0); + *ctx->OutputDType("x2_diff", 0) = ctx->InputDType("x2", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("TestMultiInput") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -240,110 +236,121 @@ REGISTER_USER_OP_GRAD("TestMultiInput") return Maybe::Ok(); }); -REGISTER_USER_OP("TestDynamicSource") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_shape() = Shape({5}); - out_tensor->set_is_dynamic(true); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kFloat; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("out", 0); - CHECK_OR_RETURN(out_modifier != nullptr); - out_modifier->set_header_infered_before_compute(false); - return Maybe::Ok(); - }); +/*static*/ Maybe TestDynamicSourceOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe TestDynamicSourceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_shape() = Shape({5}); + out_tensor->set_is_dynamic(true); + return Maybe::Ok(); +} +/*static*/ Maybe TestDynamicSourceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestDynamicSourceOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kFloat; + return Maybe::Ok(); +} +/*static*/ Maybe TestDynamicSourceOp::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("out", 0); + CHECK_OR_RETURN(out_modifier != nullptr); + out_modifier->set_header_infered_before_compute(false); + return Maybe::Ok(); +} -REGISTER_USER_OP("TestRandomSource") - .Output("out") - .Attr("seed") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_shape() = Shape({5}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kFloat; - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe TestRandomSourceOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe TestRandomSourceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_shape() = Shape({5}); + return Maybe::Ok(); 
+} +/*static*/ Maybe TestRandomSourceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestRandomSourceOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kFloat; + return Maybe::Ok(); +} -REGISTER_USER_OP("TestDataTypeAttr") - .Input("in") - .Output("out") - .Attr("output_type") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->Attr("output_type"); - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe TestDataTypeAttrOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe TestDataTypeAttrOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = in_shape; + return Maybe::Ok(); +} +/*static*/ Maybe TestDataTypeAttrOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestDataTypeAttrOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->Attr("output_type"); + return Maybe::Ok(); +} -REGISTER_USER_OP("TestListDataTypeAndListShapeAndListStringAttr") - .Input("in") - .Output("out", 3) - .Attr>("out_shapes") - .Attr>("out_types") - .Attr>("string_list") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& out_shapes = ctx->Attr>("out_shapes"); - const auto& string_list = ctx->Attr>("string_list"); - FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { - *ctx->OutputShape("out", i) = out_shapes.at(i); - } - CHECK_GT_OR_RETURN(string_list.size(), 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& out_types = ctx->Attr>("out_types"); - FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { - *ctx->OutputDType("out", i) = out_types.at(i); - } - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe TestListDataTypeAndListShapeAndListStringAttrOp::GetSbp( + user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe TestListDataTypeAndListShapeAndListStringAttrOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const auto& out_shapes = ctx->Attr>("out_shapes"); + const auto& string_list = ctx->Attr>("string_list"); + FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { + *ctx->OutputShape("out", i) = out_shapes.at(i); + } + CHECK_GT_OR_RETURN(string_list.size(), 0); + return Maybe::Ok(); +} +/*static*/ Maybe TestListDataTypeAndListShapeAndListStringAttrOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestListDataTypeAndListShapeAndListStringAttrOp::InferDataType( + user_op::InferContext* ctx) { + const auto& out_types = ctx->Attr>("out_types"); + FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { *ctx->OutputDType("out", i) = out_types.at(i); } + return Maybe::Ok(); +} -REGISTER_USER_OP("test_user_op_attr_auto_type") - .Input("in") - .Output("out") - .Attr("int1") - .Attr("int2") - 
.SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged) - .SetDataTypeInferFn(user_op::TensorDescInferFnUtil::UnchangedDataType) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe TestUserOpAttrAutoTypeOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe TestUserOpAttrAutoTypeOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return user_op::TensorDescInferFnUtil::Unchanged(ctx); +} +/*static*/ Maybe TestUserOpAttrAutoTypeOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestUserOpAttrAutoTypeOp::InferDataType(user_op::InferContext* ctx) { + return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); +} -REGISTER_CPU_ONLY_USER_OP("cpu_only_relu_test") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& in_desc = ctx->InputTensorDesc("in", 0); - auto* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_shape() = in_desc.shape(); - *out_desc->mut_is_dynamic() = in_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe CpuOnlyReluTestOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe CpuOnlyReluTestOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const auto& in_desc = ctx->InputTensorDesc("in", 0); + auto* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_shape() = in_desc.shape(); + *out_desc->mut_is_dynamic() = in_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe CpuOnlyReluTestOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe CpuOnlyReluTestOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/tf_prelu_op.cpp b/oneflow/user/ops/tf_prelu_op.cpp index f264ecb2378..b4880e201e7 100644 --- a/oneflow/user/ops/tf_prelu_op.cpp +++ b/oneflow/user/ops/tf_prelu_op.cpp @@ -14,111 +14,106 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("tf_prelu") - .Input("x") - .Input("alpha") - .Output("y") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const Shape& alpha_shape = ctx->InputShape("alpha", 0); - CHECK_EQ_OR_RETURN(x_desc.shape().NumAxes(), alpha_shape.NumAxes() + 1); - FOR_RANGE(int64_t, i, 1, x_desc.shape().NumAxes()) { - CHECK_OR_RETURN((alpha_shape.At(i - 1) == x_desc.shape().At(i)) - || (alpha_shape.At(i - 1) == 1)); - } - *y_desc->mut_shape() = x_desc.shape(); - *y_desc->mut_is_dynamic() = x_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - const user_op::TensorDesc& alpha_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0); +/*static*/ Maybe TfPreluOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + const user_op::TensorDesc& alpha_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0); + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), 0) + .Broadcast(user_op::OpArg("alpha", 0)) + .Split(user_op::OpArg("y", 0), 0) + .Build(); + FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) { + if (x_tensor.shape().At(i) == alpha_tensor.shape().At(i - 1)) { ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), 0) - .Broadcast(user_op::OpArg("alpha", 0)) - .Split(user_op::OpArg("y", 0), 0) + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("alpha", 0), i - 1) + .Split(user_op::OpArg("y", 0), i) .Build(); - FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) { - if (x_tensor.shape().At(i) == alpha_tensor.shape().At(i - 1)) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("alpha", 0), i - 1) - .Split(user_op::OpArg("y", 0), i) - .Build(); - } - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); + } + } + return Maybe::Ok(); +} +/*static*/ Maybe TfPreluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const Shape& alpha_shape = ctx->InputShape("alpha", 0); + CHECK_EQ_OR_RETURN(x_desc.shape().NumAxes(), alpha_shape.NumAxes() + 1); + FOR_RANGE(int64_t, i, 1, x_desc.shape().NumAxes()) { + CHECK_OR_RETURN((alpha_shape.At(i - 1) == x_desc.shape().At(i)) + || (alpha_shape.At(i - 1) == 1)); + } + *y_desc->mut_shape() = x_desc.shape(); + *y_desc->mut_is_dynamic() = x_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe TfPreluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TfPreluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("tf_prelu_grad") - .Input("dy") - .Input("x") - .Input("alpha") - .Output("dx") - .Output("alpha_diff") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - 
user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - const user_op::TensorDesc& alpha_desc = ctx->InputTensorDesc("alpha", 0); - CHECK_EQ_OR_RETURN(x_desc.shape().NumAxes(), alpha_desc.shape().NumAxes() + 1); - FOR_RANGE(int64_t, i, 1, x_desc.shape().NumAxes()) { - CHECK_OR_RETURN((alpha_desc.shape().At(i - 1) == x_desc.shape().At(i)) - || (alpha_desc.shape().At(i - 1) == 1)); - } - CHECK_EQ_OR_RETURN(dy_desc.shape(), x_desc.shape()); - CHECK_EQ_OR_RETURN(dy_desc.data_type(), x_desc.data_type()); - *dx_desc->mut_shape() = x_desc.shape(); - *dx_desc->mut_is_dynamic() = x_desc.is_dynamic(); - *ctx->OutputShape("alpha_diff", 0) = alpha_desc.shape(); - *ctx->OutputIsDynamic("alpha_diff", 0) = alpha_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - const user_op::TensorDesc& alpha_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0); - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("x", 0), 0) - .Broadcast(user_op::OpArg("alpha", 0)) - .Split(user_op::OpArg("dx", 0), 0) - .PartialSum(user_op::OpArg("alpha_diff", 0)) - .Build(); +/*static*/ Maybe TfPreluGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + const user_op::TensorDesc& alpha_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0); + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 0) + .Split(user_op::OpArg("x", 0), 0) + .Broadcast(user_op::OpArg("alpha", 0)) + .Split(user_op::OpArg("dx", 0), 0) + .PartialSum(user_op::OpArg("alpha_diff", 0)) + .Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("dy", 0)) + .Broadcast(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("alpha", 0)) + .PartialSum(user_op::OpArg("dx", 0)) + .PartialSum(user_op::OpArg("alpha_diff", 0)) + .Build(); + FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) { + if (x_tensor.shape().At(i) == alpha_tensor.shape().At(i - 1)) { ctx->NewBuilder() - .PartialSum(user_op::OpArg("dy", 0)) - .Broadcast(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("alpha", 0)) - .PartialSum(user_op::OpArg("dx", 0)) - .PartialSum(user_op::OpArg("alpha_diff", 0)) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("alpha", 0), i - 1) + .Split(user_op::OpArg("dx", 0), i) + .Split(user_op::OpArg("alpha_diff", 0), i - 1) .Build(); - FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) { - if (x_tensor.shape().At(i) == alpha_tensor.shape().At(i - 1)) { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("alpha", 0), i - 1) - .Split(user_op::OpArg("dx", 0), i) - .Split(user_op::OpArg("alpha_diff", 0), i - 1) - .Build(); - } - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("alpha_diff", 0) = ctx->InputDType("alpha", 0); - return Maybe::Ok(); - }); + } + } + return Maybe::Ok(); +} +/*static*/ Maybe TfPreluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + const user_op::TensorDesc& alpha_desc = ctx->InputTensorDesc("alpha", 0); + 
CHECK_EQ_OR_RETURN(x_desc.shape().NumAxes(), alpha_desc.shape().NumAxes() + 1); + FOR_RANGE(int64_t, i, 1, x_desc.shape().NumAxes()) { + CHECK_OR_RETURN((alpha_desc.shape().At(i - 1) == x_desc.shape().At(i)) + || (alpha_desc.shape().At(i - 1) == 1)); + } + CHECK_EQ_OR_RETURN(dy_desc.shape(), x_desc.shape()); + CHECK_EQ_OR_RETURN(dy_desc.data_type(), x_desc.data_type()); + *dx_desc->mut_shape() = x_desc.shape(); + *dx_desc->mut_is_dynamic() = x_desc.is_dynamic(); + *ctx->OutputShape("alpha_diff", 0) = alpha_desc.shape(); + *ctx->OutputIsDynamic("alpha_diff", 0) = alpha_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe TfPreluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TfPreluGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->OutputDType("alpha_diff", 0) = ctx->InputDType("alpha", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("tf_prelu") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/top_k_op.cpp b/oneflow/user/ops/top_k_op.cpp index 12a57114e43..0bcf295d5bd 100644 --- a/oneflow/user/ops/top_k_op.cpp +++ b/oneflow/user/ops/top_k_op.cpp @@ -14,35 +14,33 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("top_k") - .Input("in") - .Output("out") - .Attr("k") - .Attr("sorted") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - out_shape->Set( - in_shape.NumAxes() - 1, - std::min(ctx->Attr("k"), static_cast(in_shape.dim_vec().back()))); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kInt32; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - // The current implementation can only do top_k in the last dimension and should use Broadcast - // (by default) instead of Split for that dimension - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe TopKOp::GetSbp(user_op::SbpContext* ctx) { + // The current implementation can only do top_k in the last dimension and should use Broadcast + // (by default) instead of Split for that dimension + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe TopKOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = in_shape; + out_shape->Set(in_shape.NumAxes() - 1, std::min(ctx->Attr("k"), + static_cast(in_shape.dim_vec().back()))); + return Maybe::Ok(); +} +/*static*/ Maybe TopKOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TopKOp::InferDataType(user_op::InferContext* ctx) { + 
*ctx->OutputDType("out", 0) = DataType::kInt64; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/transpose_ops.cpp b/oneflow/user/ops/transpose_ops.cpp index 8a3b849ef7f..9d8130e6efb 100644 --- a/oneflow/user/ops/transpose_ops.cpp +++ b/oneflow/user/ops/transpose_ops.cpp @@ -13,9 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -29,43 +28,41 @@ void CheckIsPerm(const std::vector& perm) { } } -REGISTER_USER_OP("transpose") - .Input("input") - .Output("output") - .Attr>("perm") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("input", 0); - user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("output", 0); - const Shape& in_shape = in_tensor_desc.shape(); - Shape* out_shape = out_tensor_desc->mut_shape(); - const auto& perm = ctx->Attr>("perm"); - CHECK_EQ_OR_RETURN(perm.size(), in_shape.NumAxes()); - CheckIsPerm(perm); - // if (perm.at(0) != 0) { CHECK_OR_RETURN(!in_tensor_desc->is_dynamic()); } - *out_tensor_desc->mut_shape() = in_tensor_desc.shape(); - *out_tensor_desc->mut_is_dynamic() = in_tensor_desc.is_dynamic(); - FOR_RANGE(size_t, i, 0, perm.size()) { out_shape->Set(i, in_shape.At(perm[i])); } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& input_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0); - const auto& perm = ctx->Attr>("perm"); - CHECK_EQ_OR_RETURN(perm.size(), input_tensor.shape().NumAxes()); - FOR_RANGE(int32_t, i, 0, perm.size()) { - int32_t axis = perm.at(i); - if (axis < 0) { axis += perm.size(); } - CHECK_GE_OR_RETURN(axis, 0); - CHECK_LT_OR_RETURN(axis, perm.size()); - ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), i).Build(); - } - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe TransposeOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& input_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0); + const auto& perm = ctx->Attr>("perm"); + CHECK_EQ_OR_RETURN(perm.size(), input_tensor.shape().NumAxes()); + FOR_RANGE(int32_t, i, 0, perm.size()) { + int32_t axis = perm.at(i); + if (axis < 0) { axis += perm.size(); } + CHECK_GE_OR_RETURN(axis, 0); + CHECK_LT_OR_RETURN(axis, perm.size()); + ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), i).Build(); + } + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe TransposeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("input", 0); + user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("output", 0); + const Shape& in_shape = in_tensor_desc.shape(); + Shape* out_shape = out_tensor_desc->mut_shape(); + const auto& perm = ctx->Attr>("perm"); + CHECK_EQ_OR_RETURN(perm.size(), in_shape.NumAxes()); + CheckIsPerm(perm); + // if (perm.at(0) != 0) { CHECK_OR_RETURN(!in_tensor_desc->is_dynamic()); } + 
*out_tensor_desc->mut_shape() = in_tensor_desc.shape(); + *out_tensor_desc->mut_is_dynamic() = in_tensor_desc.is_dynamic(); + FOR_RANGE(size_t, i, 0, perm.size()) { out_shape->Set(i, in_shape.At(perm[i])); } + return Maybe::Ok(); +} +/*static*/ Maybe TransposeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TransposeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("transpose") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/tril_op.cpp b/oneflow/user/ops/tril_op.cpp index 7324a1a336a..933727beef0 100644 --- a/oneflow/user/ops/tril_op.cpp +++ b/oneflow/user/ops/tril_op.cpp @@ -14,46 +14,43 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("tril") - .Input("in") - .Output("out") - .Attr("diagonal") - .Attr("floating_fill_value", 0) - .Attr("integer_fill_value", 0) - .Attr("is_floating_fill_value", false) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); - *out->mut_shape() = in.shape(); - *out->mut_is_dynamic() = in.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = in.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - bool fill_zero = ctx->Attr("is_floating_fill_value") - ? (ctx->Attr("floating_fill_value") == static_cast(0)) - : (ctx->Attr("integer_fill_value") == static_cast(0)); - if (fill_zero) { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe TrilOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + bool fill_zero = ctx->Attr("is_floating_fill_value") + ? 
(ctx->Attr("floating_fill_value") == static_cast(0)) + : (ctx->Attr("integer_fill_value") == static_cast(0)); + if (fill_zero) { + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe TrilOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); + *out->mut_shape() = in.shape(); + *out->mut_is_dynamic() = in.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe TrilOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TrilOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = in.data_type(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("tril").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { @@ -70,46 +67,39 @@ REGISTER_USER_OP_GRAD("tril").SetGenBackwardOpConfFn([](const user_op::UserOpWra return Maybe::Ok(); }); -REGISTER_USER_OP("fused_scale_tril") - .Input("in") - .Output("out") - .Attr("diagonal") - .Attr("floating_fill_value", 0) - .Attr("integer_fill_value", 0) - .Attr("is_floating_fill_value", false) - .Attr("floating_scale_value", 1) - .Attr("integer_scale_value", 1) - .Attr("is_floating_scale_value", false) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); - *out->mut_shape() = in.shape(); - *out->mut_is_dynamic() = in.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = in.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - bool fill_zero = ctx->Attr("is_floating_fill_value") - ? (ctx->Attr("floating_fill_value") == static_cast(0)) - : (ctx->Attr("integer_fill_value") == static_cast(0)); - if (fill_zero) { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe FusedScaleTrilOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + bool fill_zero = ctx->Attr("is_floating_fill_value") + ? 
(ctx->Attr("floating_fill_value") == static_cast(0)) + : (ctx->Attr("integer_fill_value") == static_cast(0)); + if (fill_zero) { + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe FusedScaleTrilOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); + *out->mut_shape() = in.shape(); + *out->mut_is_dynamic() = in.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe FusedScaleTrilOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe FusedScaleTrilOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = in.data_type(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fused_scale_tril") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/triu_op.cpp b/oneflow/user/ops/triu_op.cpp index 47a7e48f522..00448d7f585 100644 --- a/oneflow/user/ops/triu_op.cpp +++ b/oneflow/user/ops/triu_op.cpp @@ -14,37 +14,37 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("triu") - .Input("in") - .Output("out") - .Attr("diagonal") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); - *out->mut_shape() = in.shape(); - *out->mut_is_dynamic() = in.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = in.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe TriuOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe TriuOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); + *out->mut_shape() = in.shape(); + *out->mut_is_dynamic() = in.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe TriuOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return 
InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> TriuOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = in.data_type(); + return Maybe<void>::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/tuple_identity_op.cpp b/oneflow/user/ops/tuple_identity_op.cpp index f829f9d72c5..b777e39fe5b 100644 --- a/oneflow/user/ops/tuple_identity_op.cpp +++ b/oneflow/user/ops/tuple_identity_op.cpp @@ -15,52 +15,60 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("tuple_identity") - .InputWithMinimum("in", 1) - .OutputWithMinimum("out", 1) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const int64_t in_size = ctx->input_size("in"); - CHECK_EQ_OR_RETURN(ctx->output_size("out"), in_size); - for (int64_t i = 0; i < in_size; ++i) { - *ctx->OutputShape("out", i) = ctx->InputShape("in", i); - *ctx->IsDynamic4ArgNameAndIndex("out", i) = ctx->InputIsDynamic("in", i); - } - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const int64_t in_size = ctx->input_size("in"); - CHECK_EQ_OR_RETURN(ctx->output_size("out"), in_size); - for (int64_t i = 0; i < in_size; ++i) { - *ctx->OutputDType("out", i) = ctx->InputDType("in", i); - } - return Maybe<void>::Ok(); - }) - .SetSbpSignatureInferFn([](user_op::InferSbpSignatureFnContext* ctx) -> Maybe<void> { - cfg::SbpSignature* signature = ctx->mutable_sbp_signature(); - const cfg::SbpSignature& sbp_signature_conf = ctx->sbp_signature_conf(); - auto* bn2sbp = signature->mutable_bn_in_op2sbp_parallel(); - const auto& bn2conf_sbp = sbp_signature_conf.bn_in_op2sbp_parallel(); - const int64_t in_size = ctx->user_op_conf().input_size("in"); - CHECK_EQ_OR_RETURN(ctx->user_op_conf().output_size("out"), in_size); - for (int64_t i = 0; i < in_size; ++i) { - const cfg::SbpParallel* sbp_parallel = nullptr; - const std::string ibn = GenRepeatedBn("in", i); - const std::string& obn = GenRepeatedBn("out", i); - const auto& conf_sbp_it = bn2conf_sbp.find(obn); - if (conf_sbp_it == bn2conf_sbp.end()) { - sbp_parallel = &ctx->SbpParallelHint4InputArgNameAndIndex("in", i); - } else { - sbp_parallel = &conf_sbp_it->second; - } - (*bn2sbp)[ibn] = *sbp_parallel; - (*bn2sbp)[obn] = *sbp_parallel; - } - return Maybe<void>::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe<void> TupleIdentityOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe<void> TupleIdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const int64_t in_size = ctx->input_size("in"); + CHECK_EQ_OR_RETURN(ctx->output_size("out"), in_size); + for (int64_t i = 0; i < in_size; ++i) { + *ctx->OutputShape("out", i) = ctx->InputShape("in", i); + *ctx->IsDynamic4ArgNameAndIndex("out", i) = ctx->InputIsDynamic("in", i); + } + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> TupleIdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> TupleIdentityOp::InferDataType(user_op::InferContext* ctx) { + const int64_t in_size = ctx->input_size("in"); + CHECK_EQ_OR_RETURN(ctx->output_size("out"), in_size); + for (int64_t i = 0; i < in_size; ++i) { *ctx->OutputDType("out", i) = ctx->InputDType("in", i); } + return
Maybe::Ok(); +} +/*static*/ Maybe TupleIdentityOp::InferSbpSignature( + user_op::InferSbpSignatureFnContext* ctx) { + cfg::SbpSignature* signature = ctx->mutable_sbp_signature(); + const cfg::SbpSignature& sbp_signature_conf = ctx->sbp_signature_conf(); + auto* bn2sbp = signature->mutable_bn_in_op2sbp_parallel(); + const auto& bn2conf_sbp = sbp_signature_conf.bn_in_op2sbp_parallel(); + const int64_t in_size = ctx->user_op_conf().input_size("in"); + CHECK_EQ_OR_RETURN(ctx->user_op_conf().output_size("out"), in_size); + for (int64_t i = 0; i < in_size; ++i) { + const cfg::SbpParallel* sbp_parallel = nullptr; + const std::string ibn = GenRepeatedBn("in", i); + const std::string& obn = GenRepeatedBn("out", i); + const auto& conf_sbp_it = bn2conf_sbp.find(obn); + if (conf_sbp_it == bn2conf_sbp.end()) { + sbp_parallel = &ctx->SbpParallelHint4InputArgNameAndIndex("in", i); + } else { + sbp_parallel = &conf_sbp_it->second; + } + (*bn2sbp)[ibn] = *sbp_parallel; + (*bn2sbp)[obn] = *sbp_parallel; + } + return Maybe::Ok(); +} +/*static*/ Maybe TupleIdentityOp::CheckAttr(const user_op::UserOpDefWrapper&, + const user_op::UserOpConfWrapper& op_conf) { + CHECK_OR_RETURN(op_conf.input_size("in") >= 1); + CHECK_OR_RETURN(op_conf.output_size("out") >= 1); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("tuple_identity") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/two_stage_reduce_ops.cpp b/oneflow/user/ops/two_stage_reduce_ops.cpp index 0d2176faaef..c9c1ac15c11 100644 --- a/oneflow/user/ops/two_stage_reduce_ops.cpp +++ b/oneflow/user/ops/two_stage_reduce_ops.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/reduce_sbp_util.h" #include "oneflow/core/ndarray/binary_func.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -198,34 +199,41 @@ Maybe GetReduceDeviceStageGradSbpFn(user_op::SbpContext* ctx) { } // namespace -#define REGISTER_REDUCE_DEVICE_STAGE_USER_OP(op_name) \ - REGISTER_USER_OP(op_name) \ - .Input("in") \ - .Output("out") \ - .Output("mask") \ - .Output("count") \ - .Attr>("axis") \ - .SetLogicalTensorDescInferFn(InferReduceDeviceStageLogicalTensorDescFn) \ - .SetPhysicalTensorDescInferFn(InferReduceDeviceStagePhysicalTensorDescFn) \ - .SetDataTypeInferFn(InferReduceDeviceStageDtypeFn) \ - .SetGetSbpFn(GetReduceDeviceStageSbpFn); - -REGISTER_REDUCE_DEVICE_STAGE_USER_OP("reduce_min_device_stage") -REGISTER_REDUCE_DEVICE_STAGE_USER_OP("reduce_max_device_stage") - -#define REGISTER_REDUCE_DEVICE_STAGE_GRAD_USER_OP(op_name) \ - REGISTER_USER_OP(op_name) \ - .Input("out_diff") \ - .Input("mask") \ - .Input("count") \ - .Output("in_diff") \ - .Attr>("axis") \ - .SetTensorDescInferFn(InferReduceDeviceStageGradTensorDescFn) \ - .SetDataTypeInferFn(InferReduceDeviceStageGradDtypeFn) \ - .SetGetSbpFn(GetReduceDeviceStageGradSbpFn); - -REGISTER_REDUCE_DEVICE_STAGE_GRAD_USER_OP("reduce_min_device_stage_grad") -REGISTER_REDUCE_DEVICE_STAGE_GRAD_USER_OP("reduce_max_device_stage_grad") +#define IMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS(op_name) \ + /*static*/ Maybe op_name##Op::GetSbp(user_op::SbpContext* ctx) { \ + return GetReduceDeviceStageSbpFn(ctx); \ + } \ + /*static*/ Maybe op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return InferReduceDeviceStageLogicalTensorDescFn(ctx); \ + } \ + /*static*/ Maybe op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferReduceDeviceStagePhysicalTensorDescFn(ctx); \ + } 
\ + /*static*/ Maybe op_name##Op::InferDataType(user_op::InferContext* ctx) { \ + return InferReduceDeviceStageDtypeFn(ctx); \ + } + +IMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS(ReduceMinDeviceStage) +IMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS(ReduceMaxDeviceStage) +#undef IMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS + +#define IMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS(op_name) \ + /*static*/ Maybe op_name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ + return GetReduceDeviceStageGradSbpFn(ctx); \ + } \ + /*static*/ Maybe op_name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return InferReduceDeviceStageGradTensorDescFn(ctx); \ + } \ + /*static*/ Maybe op_name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe op_name##GradOp::InferDataType(user_op::InferContext* ctx) { \ + return InferReduceDeviceStageGradDtypeFn(ctx); \ + } + +IMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS(ReduceMinDeviceStage) +IMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS(ReduceMaxDeviceStage) +#undef IMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS Maybe GenBackwardOpConf4ReduceDeviceStage(const std::string& op_type_name, const user_op::UserOpWrapper& op, @@ -255,58 +263,59 @@ Maybe GenBackwardOpConf4ReduceDeviceStage(const std::string& op_type_name, REGISTER_REDUCE_DEVICE_STAGE_USER_OP_GRAD("reduce_min_device_stage", "reduce_min_device_stage_grad") REGISTER_REDUCE_DEVICE_STAGE_USER_OP_GRAD("reduce_max_device_stage", "reduce_max_device_stage_grad") -#define REGISTER_REDUCE_GLOBAL_STAGE_USER_OP(op_name) \ - REGISTER_USER_OP(op_name) \ - .Input("in") \ - .Input("device_count") \ - .Output("out") \ - .Output("mask") \ - .Attr>("axis") \ - .Attr("keepdims") \ - .SetTensorDescInferFn(InferReduceGlobalStageTensorDescFn) \ - .SetDataTypeInferFn(InferReduceGlobalStageDtypeFn) \ - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, \ - const user_op::UserOpConfWrapper&) -> Maybe { \ - user_op::InputArgModifier* device_count_modifier = \ - GetInputArgModifierFn("device_count", 0); \ - device_count_modifier->set_requires_grad(false); \ - return Maybe::Ok(); \ - }) \ - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { \ - ctx->NewBuilder() \ - .Split(user_op::OpArg("in", 0), 0) \ - .Split(user_op::OpArg("device_count", 0), 0) \ - .Split(user_op::OpArg("out", 0), 0) \ - .Split(user_op::OpArg("mask", 0), 0) \ - .Build(); \ - return Maybe::Ok(); \ - }); - -REGISTER_REDUCE_GLOBAL_STAGE_USER_OP("reduce_min_global_stage") -REGISTER_REDUCE_GLOBAL_STAGE_USER_OP("reduce_max_global_stage") - -#define REGISTER_REDUCE_GLOBAL_STAGE_GRAD_USER_OP(op_name) \ - REGISTER_USER_OP(op_name) \ - .Input("out_diff") \ - .Input("mask") \ - .Input("device_count") \ - .Output("in_diff") \ - .Attr>("axis") \ - .Attr("keepdims") \ - .SetTensorDescInferFn(InferReduceGlobalStageGradTensorDescFn) \ - .SetDataTypeInferFn(InferReduceGlobalStageGradDtypeFn) \ - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { \ - ctx->NewBuilder() \ - .Split(user_op::OpArg("out_diff", 0), 0) \ - .Split(user_op::OpArg("mask", 0), 0) \ - .Split(user_op::OpArg("device_count", 0), 0) \ - .Split(user_op::OpArg("in_diff", 0), 0) \ - .Build(); \ - return Maybe::Ok(); \ - }); - -REGISTER_REDUCE_GLOBAL_STAGE_GRAD_USER_OP("reduce_min_global_stage_grad") -REGISTER_REDUCE_GLOBAL_STAGE_GRAD_USER_OP("reduce_max_global_stage_grad") +#define IMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS(op_name) \ + /*static*/ Maybe op_name##Op::GetSbp(user_op::SbpContext* ctx) { 
\ + ctx->NewBuilder() \ + .Split(user_op::OpArg("in", 0), 0) \ + .Split(user_op::OpArg("device_count", 0), 0) \ + .Split(user_op::OpArg("out", 0), 0) \ + .Split(user_op::OpArg("mask", 0), 0) \ + .Build(); \ + return Maybe::Ok(); \ + } \ + /*static*/ Maybe op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return InferReduceGlobalStageTensorDescFn(ctx); \ + } \ + /*static*/ Maybe op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe op_name##Op::InferDataType(user_op::InferContext* ctx) { \ + return InferReduceGlobalStageDtypeFn(ctx); \ + } \ + /*static*/ Maybe op_name##Op::ModifyInputArg( \ + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { \ + user_op::InputArgModifier* device_count_modifier = GetInputArgModifierFn("device_count", 0); \ + device_count_modifier->set_requires_grad(false); \ + return Maybe::Ok(); \ + } + +IMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS(ReduceMinGlobalStage) +IMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS(ReduceMaxGlobalStage) +#undef IMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS + +#define IMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS(op_name) \ + /*static*/ Maybe op_name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ + ctx->NewBuilder() \ + .Split(user_op::OpArg("out_diff", 0), 0) \ + .Split(user_op::OpArg("mask", 0), 0) \ + .Split(user_op::OpArg("device_count", 0), 0) \ + .Split(user_op::OpArg("in_diff", 0), 0) \ + .Build(); \ + return Maybe::Ok(); \ + } \ + /*static*/ Maybe op_name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return InferReduceGlobalStageGradTensorDescFn(ctx); \ + } \ + /*static*/ Maybe op_name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe op_name##GradOp::InferDataType(user_op::InferContext* ctx) { \ + return InferReduceGlobalStageGradDtypeFn(ctx); \ + } + +IMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS(ReduceMinGlobalStage) +IMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS(ReduceMaxGlobalStage) +#undef IMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS Maybe GenBackwardOpConf4ReduceGlobalStage(const std::string& op_type_name, const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/unfold_fold_op.cpp b/oneflow/user/ops/unfold_fold_op.cpp index f4b531d5779..0560561604c 100644 --- a/oneflow/user/ops/unfold_fold_op.cpp +++ b/oneflow/user/ops/unfold_fold_op.cpp @@ -16,11 +16,10 @@ limitations under the License. 
#include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/operator/operator_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace user_op { - namespace { Maybe UnfoldTensorDescInferFn(user_op::InferContext* ctx) { @@ -135,30 +134,26 @@ Maybe GetFoldSbpFn(user_op::SbpContext* ctx) { } // namespace -REGISTER_USER_OP("unfold") - .Input("x") - .Output("y") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("padding") - .Attr>("strides") - .Attr>("dilation_rate") - .SetTensorDescInferFn(UnfoldTensorDescInferFn) - .SetGetSbpFn(GetUnfoldSbpFn) - .SetDataTypeInferFn(SetUnfoldDTypeFn); - -REGISTER_USER_OP("fold") - .Input("x") - .Output("y") - .Attr>("output_size") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("padding") - .Attr>("dilation_rate") - .SetTensorDescInferFn(FoldTensorDescInferFn) - .SetGetSbpFn(GetFoldSbpFn) - .SetDataTypeInferFn(FoldDTypeFn); - -} // namespace user_op - -} // namespace oneflow \ No newline at end of file +/*static*/ Maybe UnfoldOp::GetSbp(user_op::SbpContext* ctx) { return GetUnfoldSbpFn(ctx); } +/*static*/ Maybe UnfoldOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return UnfoldTensorDescInferFn(ctx); +} +/*static*/ Maybe UnfoldOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UnfoldOp::InferDataType(user_op::InferContext* ctx) { + return SetUnfoldDTypeFn(ctx); +} + +/*static*/ Maybe FoldOp::GetSbp(user_op::SbpContext* ctx) { return GetFoldSbpFn(ctx); } +/*static*/ Maybe FoldOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return FoldTensorDescInferFn(ctx); +} +/*static*/ Maybe FoldOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe FoldOp::InferDataType(user_op::InferContext* ctx) { + return FoldDTypeFn(ctx); +} + +} // namespace oneflow diff --git a/oneflow/user/ops/unfold_tensor_op.cpp b/oneflow/user/ops/unfold_tensor_op.cpp index 7a6b5f7586f..c383cfee652 100644 --- a/oneflow/user/ops/unfold_tensor_op.cpp +++ b/oneflow/user/ops/unfold_tensor_op.cpp @@ -15,95 +15,83 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/unfold_tensor_kernel_utils.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("unfold_tensor") - .Input("x") - .Output("y") - .Attr("dimension") - .Attr("size") - .Attr("step") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("x", 0); - const int32_t dimension = ctx->Attr("dimension"); - const int32_t size = ctx->Attr("size"); - const int32_t step = ctx->Attr("step"); +/*static*/ Maybe UnfoldTensorOp::GetSbp(user_op::SbpContext* ctx) { + const int32_t dimension = ctx->Attr("dimension"); + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + if (i != dimension) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); + } + } + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UnfoldTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("x", 0); + const int32_t dimension = ctx->Attr("dimension"); + const int32_t size = ctx->Attr("size"); + const int32_t step = ctx->Attr("step"); - const Shape& in_shape = ctx->InputShape("x", 0); - const int32_t in_dim = in_shape.NumAxes(); - CHECK_GE_OR_RETURN(dimension, 0); - CHECK_LE_OR_RETURN(dimension, in_dim - 1); + const Shape& in_shape = ctx->InputShape("x", 0); + const int32_t in_dim = in_shape.NumAxes(); + CHECK_GE_OR_RETURN(dimension, 0); + CHECK_LE_OR_RETURN(dimension, in_dim - 1); - const int32_t max_size = in_dim == 0 ? 1 : in_shape.At(dimension); - CHECK_GT_OR_RETURN(size, 0); - CHECK_LE_OR_RETURN(size, max_size); - CHECK_GT_OR_RETURN(step, 0); + const int32_t max_size = in_dim == 0 ? 
1 : in_shape.At(dimension); + CHECK_GT_OR_RETURN(size, 0); + CHECK_LE_OR_RETURN(size, max_size); + CHECK_GT_OR_RETURN(step, 0); - DimVector out_shape(in_dim + 1); - out_shape[in_dim] = size; - FOR_RANGE(int32_t, d, 0, in_dim) { - int32_t in_size_at_d = in.shape().At(d); - if (d == dimension) { - out_shape.at(d) = (in_size_at_d - size) / step + 1; - } else { - out_shape.at(d) = in_size_at_d; - } - } - *ctx->OutputShape("y", 0) = Shape(out_shape); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int32_t dimension = ctx->Attr("dimension"); - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - if (i != dimension) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("y", 0), i) - .Build(); - } - } - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); - return Maybe::Ok(); - }); + DimVector out_shape(in_dim + 1); + out_shape[in_dim] = size; + FOR_RANGE(int32_t, d, 0, in_dim) { + int32_t in_size_at_d = in.shape().At(d); + if (d == dimension) { + out_shape.at(d) = (in_size_at_d - size) / step + 1; + } else { + out_shape.at(d) = in_size_at_d; + } + } + *ctx->OutputShape("y", 0) = Shape(out_shape); + return Maybe::Ok(); +} +/*static*/ Maybe UnfoldTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UnfoldTensorOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("unfold_tensor_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("dimension") - .Attr("size") - .Attr("step") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("x", 0); - const Shape& in_shape = in.shape(); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - *dx_desc->mut_shape() = Shape(in_shape.dim_vec()); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int32_t dimension = ctx->Attr("dimension"); - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dx", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - if (i != dimension) { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - } - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe UnfoldTensorGradOp::GetSbp(user_op::SbpContext* ctx) { + const int32_t dimension = ctx->Attr("dimension"); + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dx", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + if (i != dimension) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), i).Split(user_op::OpArg("dx", 0), i).Build(); + } + } + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UnfoldTensorGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("x", 0); + const Shape& 
in_shape = in.shape(); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + *dx_desc->mut_shape() = Shape(in_shape.dim_vec()); + return Maybe::Ok(); +} +/*static*/ Maybe UnfoldTensorGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UnfoldTensorGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("unfold_tensor") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/unique_with_counts_op.cpp b/oneflow/user/ops/unique_with_counts_op.cpp index bf643a7f377..ea0c120dfa7 100644 --- a/oneflow/user/ops/unique_with_counts_op.cpp +++ b/oneflow/user/ops/unique_with_counts_op.cpp @@ -14,52 +14,51 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("unique_with_counts") - .Input("x") - .Output("y") - .Output("idx") - .Output("count") - .Output("num_unique") - .Attr("out_idx", DataType::kInt32) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - CHECK_EQ_OR_RETURN(x.shape().NumAxes(), 1); +/*static*/ Maybe UniqueWithCountsOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe UniqueWithCountsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + CHECK_EQ_OR_RETURN(x.shape().NumAxes(), 1); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); - *y->mut_shape() = x.shape(); - *y->mut_is_dynamic() = x.is_dynamic(); + user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + *y->mut_shape() = x.shape(); + *y->mut_is_dynamic() = x.is_dynamic(); - user_op::TensorDesc* idx = ctx->OutputTensorDesc("idx", 0); - *idx->mut_shape() = x.shape(); - *idx->mut_is_dynamic() = x.is_dynamic(); + user_op::TensorDesc* idx = ctx->OutputTensorDesc("idx", 0); + *idx->mut_shape() = x.shape(); + *idx->mut_is_dynamic() = x.is_dynamic(); - user_op::TensorDesc* count = ctx->OutputTensorDesc("count", 0); - *count->mut_shape() = x.shape(); - *count->mut_is_dynamic() = x.is_dynamic(); + user_op::TensorDesc* count = ctx->OutputTensorDesc("count", 0); + *count->mut_shape() = x.shape(); + *count->mut_is_dynamic() = x.is_dynamic(); - user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); - *num_unique->mut_shape() = Shape({1}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - auto out_idx = ctx->Attr("out_idx"); - CHECK_OR_RETURN(IsIndexDataType(out_idx)); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); - *y->mut_data_type() = x.data_type(); + user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); + *num_unique->mut_shape() = Shape({1}); + return Maybe::Ok(); +} +/*static*/ Maybe UniqueWithCountsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UniqueWithCountsOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + auto out_idx = ctx->Attr("out_idx"); + CHECK_OR_RETURN(IsIndexDataType(out_idx)); + user_op::TensorDesc* y = 
ctx->OutputTensorDesc("y", 0); + *y->mut_data_type() = x.data_type(); - user_op::TensorDesc* idx = ctx->OutputTensorDesc("idx", 0); - *idx->mut_data_type() = out_idx; + user_op::TensorDesc* idx = ctx->OutputTensorDesc("idx", 0); + *idx->mut_data_type() = out_idx; - user_op::TensorDesc* count = ctx->OutputTensorDesc("count", 0); - *count->mut_data_type() = out_idx; - user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); - *num_unique->mut_data_type() = out_idx; - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); + user_op::TensorDesc* count = ctx->OutputTensorDesc("count", 0); + *count->mut_data_type() = out_idx; + user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); + *num_unique->mut_data_type() = out_idx; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/unpack_op.cpp b/oneflow/user/ops/unpack_op.cpp index 0c5c589a70c..b0b4ee12f04 100644 --- a/oneflow/user/ops/unpack_op.cpp +++ b/oneflow/user/ops/unpack_op.cpp @@ -14,55 +14,50 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -REGISTER_USER_OP("unpack") - .Input("in") - .Output("out") - .Attr("unpack_num") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - const Shape& in_shape = in_desc.shape(); - CHECK_GT_OR_RETURN(in_shape.NumAxes(), 0); - const auto unpack_num = ctx->Attr("unpack_num"); - CHECK_EQ_OR_RETURN(in_shape.At(0) % unpack_num, 0); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_shape() = in_desc.shape(); - out_desc->mut_shape()->Set(0, in_shape.At(0) / unpack_num); - *out_desc->mut_is_dynamic() = in_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - *out_desc->mut_data_type() = in_desc.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetOutputBlobTimeShapeInferFn( - [](user_op::InferOutputBlobTimeShapeFnContext* ctx) -> Maybe { - const int32_t unpack_num = ctx->user_op_conf().attr("unpack_num"); - DimVector time_shape_dim_vec = ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec(); - time_shape_dim_vec.emplace_back(unpack_num); - *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); - return Maybe::Ok(); - }); +/*static*/ Maybe UnpackOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe 
UnpackOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + const Shape& in_shape = in_desc.shape(); + CHECK_GT_OR_RETURN(in_shape.NumAxes(), 0); + const auto unpack_num = ctx->Attr<int32_t>("unpack_num"); + CHECK_EQ_OR_RETURN(in_shape.At(0) % unpack_num, 0); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_shape() = in_desc.shape(); + out_desc->mut_shape()->Set(0, in_shape.At(0) / unpack_num); + *out_desc->mut_is_dynamic() = in_desc.is_dynamic(); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> UnpackOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> UnpackOp::InferDataType(user_op::InferContext* ctx) { + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + *out_desc->mut_data_type() = in_desc.data_type(); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> UnpackOp::InferOutputBlobTimeShape( + user_op::InferOutputBlobTimeShapeFnContext* ctx) { + const int32_t unpack_num = ctx->user_op_conf().attr<int32_t>("unpack_num"); + DimVector time_shape_dim_vec = ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec(); + time_shape_dim_vec.emplace_back(unpack_num); + *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); + return Maybe<void>::Ok(); +} REGISTER_USER_OP_GRAD("unpack").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { @@ -80,6 +75,4 @@ REGISTER_USER_OP_GRAD("unpack").SetBackwardOpConfGenFn([](user_op::BackwardOpCon return Maybe<void>::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp b/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp index bec3711097a..0ba58274570 100644 --- a/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp +++ b/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp @@ -14,69 +14,70 @@ See the License for the specific language governing permissions and limitations under the License.
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("unsorted_batch_segment_sum") - .Input("data") - .Input("segment_ids") - .Output("out") - .Attr("num_segments") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); - const user_op::TensorDesc& segment_ids = ctx->InputTensorDesc("segment_ids", 0); - CHECK_GE_OR_RETURN(segment_ids.shape().NumAxes(), 1); - CHECK_GE_OR_RETURN(data.shape().NumAxes(), segment_ids.shape().NumAxes()); - CHECK_EQ_OR_RETURN(segment_ids.is_dynamic(), data.is_dynamic()); - const int64_t num_segments = ctx->Attr("num_segments"); - CHECK_GE_OR_RETURN(num_segments, 1); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); +/*static*/ Maybe UnsortedBatchSegmentSumOp::GetSbp(user_op::SbpContext* ctx) { + const int64_t segment_ids_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); + CHECK_GT_OR_RETURN(segment_ids_num_axes, 1) + << "UnsortedBatchSegmentSumOp: segment_ids_num_axes equals " << segment_ids_num_axes + << " (should be bigger than 1)."; - FOR_RANGE(int64_t, i, 0, segment_ids.shape().NumAxes() - 1) { - CHECK_EQ_OR_RETURN(segment_ids.shape().At(i), data.shape().At(i)); - } + FOR_RANGE(int64_t, i, 0, segment_ids_num_axes - 1) { + ctx->NewBuilder() + .Split(user_op::OpArg("segment_ids", 0), i) + .Split(user_op::OpArg("data", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("segment_ids", 0)) + .PartialSum(user_op::OpArg("data", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedBatchSegmentSumOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); + const user_op::TensorDesc& segment_ids = ctx->InputTensorDesc("segment_ids", 0); + CHECK_GE_OR_RETURN(segment_ids.shape().NumAxes(), 1); + CHECK_GE_OR_RETURN(data.shape().NumAxes(), segment_ids.shape().NumAxes()); + CHECK_EQ_OR_RETURN(segment_ids.is_dynamic(), data.is_dynamic()); + const int64_t num_segments = ctx->Attr("num_segments"); + CHECK_GE_OR_RETURN(num_segments, 1); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - DimVector dim_vec(data.shape().dim_vec()); - dim_vec.at(segment_ids.shape().NumAxes() - 1) = num_segments; - *out->mut_shape() = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); - const user_op::TensorDesc& segment_ids = ctx->InputTensorDesc("segment_ids", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(IsIndexDataType(segment_ids.data_type())); - *out->mut_data_type() = data.data_type(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); - CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); - segment_ids_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int64_t segment_ids_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); - CHECK_GT_OR_RETURN(segment_ids_num_axes, 1) - << "UnsortedBatchSegmentSumOp: segment_ids_num_axes 
equals " << segment_ids_num_axes - << " (should be bigger than 1)."; + FOR_RANGE(int64_t, i, 0, segment_ids.shape().NumAxes() - 1) { + CHECK_EQ_OR_RETURN(segment_ids.shape().At(i), data.shape().At(i)); + } - FOR_RANGE(int64_t, i, 0, segment_ids_num_axes - 1) { - ctx->NewBuilder() - .Split(user_op::OpArg("segment_ids", 0), i) - .Split(user_op::OpArg("data", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("segment_ids", 0)) - .PartialSum(user_op::OpArg("data", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }); + DimVector dim_vec(data.shape().dim_vec()); + dim_vec.at(segment_ids.shape().NumAxes() - 1) = num_segments; + *out->mut_shape() = Shape(dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedBatchSegmentSumOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UnsortedBatchSegmentSumOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); + const user_op::TensorDesc& segment_ids = ctx->InputTensorDesc("segment_ids", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(IsIndexDataType(segment_ids.data_type())); + *out->mut_data_type() = data.data_type(); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedBatchSegmentSumOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); + CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); + segment_ids_modifier->set_requires_grad(false); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("unsorted_batch_segment_sum") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/unsorted_segment_sum_op.cpp b/oneflow/user/ops/unsorted_segment_sum_op.cpp index 7dca75f50b0..5df5e81e451 100644 --- a/oneflow/user/ops/unsorted_segment_sum_op.cpp +++ b/oneflow/user/ops/unsorted_segment_sum_op.cpp @@ -14,74 +14,71 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("unsorted_segment_sum") - .Input("data") - .Input("segment_ids") - .Output("out") - .Attr("axis") - .Attr("num_segments") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& data_shape = ctx->InputShape("data", 0); - const int64_t axis = ctx->Attr("axis"); - const int64_t num_segments = ctx->Attr("num_segments"); - Shape* out_shape = ctx->OutputShape("out", 0); - const Shape& segment_ids_shape = ctx->InputShape("segment_ids", 0); +/*static*/ Maybe UnsortedSegmentSumOp::GetSbp(user_op::SbpContext* ctx) { + const int64_t data_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("data", 0).shape().NumAxes(); + const int64_t segment_ids_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); + const int64_t axis = ctx->Attr("axis"); + FOR_RANGE(int64_t, i, 0, segment_ids_num_axes) { + ctx->NewBuilder() + .Split(user_op::OpArg("segment_ids", 0), i) + .Split(user_op::OpArg("data", 0), i + axis) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + } + FOR_RANGE(int64_t, i, 0, data_num_axes) { + if (i >= axis && i < axis + segment_ids_num_axes) { continue; } + const int64_t out_split_axis = (i < axis) ? 
i : i - segment_ids_num_axes + 1; + if (out_split_axis == axis) { continue; } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("segment_ids", 0)) + .Split(user_op::OpArg("data", 0), i) + .Split(user_op::OpArg("out", 0), out_split_axis) + .Build(); + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("segment_ids", 0)) + .PartialSum(user_op::OpArg("data", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedSegmentSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& data_shape = ctx->InputShape("data", 0); + const int64_t axis = ctx->Attr("axis"); + const int64_t num_segments = ctx->Attr("num_segments"); + Shape* out_shape = ctx->OutputShape("out", 0); + const Shape& segment_ids_shape = ctx->InputShape("segment_ids", 0); - DimVector dim_vec; - dim_vec.insert(dim_vec.end(), data_shape.dim_vec().cbegin(), - data_shape.dim_vec().cbegin() + axis); - dim_vec.emplace_back(num_segments); - dim_vec.insert(dim_vec.end(), - data_shape.dim_vec().cbegin() + axis + segment_ids_shape.NumAxes(), - data_shape.dim_vec().end()); - *out_shape = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("segment_ids", 0))); - *ctx->OutputDType("out", 0) = ctx->InputDType("data", 0); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); - CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); - segment_ids_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int64_t data_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("data", 0).shape().NumAxes(); - const int64_t segment_ids_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); - const int64_t axis = ctx->Attr("axis"); - FOR_RANGE(int64_t, i, 0, segment_ids_num_axes) { - ctx->NewBuilder() - .Split(user_op::OpArg("segment_ids", 0), i) - .Split(user_op::OpArg("data", 0), i + axis) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - } - FOR_RANGE(int64_t, i, 0, data_num_axes) { - if (i >= axis && i < axis + segment_ids_num_axes) { continue; } - const int64_t out_split_axis = (i < axis) ? 
i : i - segment_ids_num_axes + 1; - if (out_split_axis == axis) { continue; } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("segment_ids", 0)) - .Split(user_op::OpArg("data", 0), i) - .Split(user_op::OpArg("out", 0), out_split_axis) - .Build(); - } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("segment_ids", 0)) - .PartialSum(user_op::OpArg("data", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }); + DimVector dim_vec; + dim_vec.insert(dim_vec.end(), data_shape.dim_vec().cbegin(), + data_shape.dim_vec().cbegin() + axis); + dim_vec.emplace_back(num_segments); + dim_vec.insert(dim_vec.end(), data_shape.dim_vec().cbegin() + axis + segment_ids_shape.NumAxes(), + data_shape.dim_vec().end()); + *out_shape = Shape(dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedSegmentSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UnsortedSegmentSumOp::InferDataType(user_op::InferContext* ctx) { + CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("segment_ids", 0))); + *ctx->OutputDType("out", 0) = ctx->InputDType("data", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedSegmentSumOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); + CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); + segment_ids_modifier->set_requires_grad(false); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("unsorted_segment_sum") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -102,97 +99,95 @@ REGISTER_USER_OP_GRAD("unsorted_segment_sum") return Maybe::Ok(); }); -REGISTER_USER_OP("unsorted_segment_sum_like") - .Input("data") - .Input("segment_ids") - .Input("like") - .Output("out") - .Attr("axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& data_shape = ctx->InputShape("data", 0); - const Shape& like_shape = ctx->InputShape("like", 0); - const Shape& segment_ids_shape = ctx->InputShape("segment_ids", 0); - const int64_t axis = ctx->Attr("axis"); - CHECK_GE_OR_RETURN(axis, 0); - CHECK_LE_OR_RETURN(axis, like_shape.NumAxes()); - FOR_RANGE(int64_t, i, 0, axis) { CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i)); } - CHECK_EQ_OR_RETURN(data_shape.NumAxes() - segment_ids_shape.NumAxes() + 1, - like_shape.NumAxes()); - FOR_RANGE(int64_t, i, axis + 1, like_shape.NumAxes()) { - CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i + segment_ids_shape.NumAxes() - 1)); - } - *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); - *ctx->IsDynamic4ArgNameAndIndex("out", 0) = ctx->InputIsDynamic("like", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); - const user_op::TensorDesc& like = ctx->InputTensorDesc("like", 0); - CHECK_EQ_OR_RETURN(data.data_type(), like.data_type()); - CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("segment_ids", 0))); - *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); - CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); - segment_ids_modifier->set_requires_grad(false); - user_op::InputArgModifier* like_modifier = 
GetInputArgModifierFn("like", 0); - CHECK_NOTNULL_OR_RETURN(like_modifier); - like_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int64_t data_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("data", 0).shape().NumAxes(); - const int64_t segment_ids_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); - const int64_t axis = ctx->Attr("axis"); - FOR_RANGE(int64_t, i, 0, segment_ids_num_axes) { - ctx->NewBuilder() - .Split(user_op::OpArg("segment_ids", 0), i) - .Split(user_op::OpArg("data", 0), i + axis) - .Broadcast(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("segment_ids", 0), i) - .Split(user_op::OpArg("data", 0), i + axis) - .PartialSum(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - } - FOR_RANGE(int64_t, i, 0, data_num_axes) { - if (i >= axis && i < axis + segment_ids_num_axes) { continue; } - const int64_t out_split_axis = (i < axis) ? i : i - segment_ids_num_axes + 1; - if (out_split_axis == axis) { continue; } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("segment_ids", 0)) - .Split(user_op::OpArg("data", 0), i) - .Split(user_op::OpArg("like", 0), out_split_axis) - .Split(user_op::OpArg("out", 0), out_split_axis) - .Build(); - } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("segment_ids", 0)) - .PartialSum(user_op::OpArg("data", 0)) - .Broadcast(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("segment_ids", 0)) - .PartialSum(user_op::OpArg("data", 0)) - .PartialSum(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("segment_ids", 0)) - .Broadcast(user_op::OpArg("data", 0)) - .Split(user_op::OpArg("like", 0), axis) - .Split(user_op::OpArg("out", 0), axis) - .Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe UnsortedSegmentSumLikeOp::GetSbp(user_op::SbpContext* ctx) { + const int64_t data_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("data", 0).shape().NumAxes(); + const int64_t segment_ids_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); + const int64_t axis = ctx->Attr("axis"); + FOR_RANGE(int64_t, i, 0, segment_ids_num_axes) { + ctx->NewBuilder() + .Split(user_op::OpArg("segment_ids", 0), i) + .Split(user_op::OpArg("data", 0), i + axis) + .Broadcast(user_op::OpArg("like", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("segment_ids", 0), i) + .Split(user_op::OpArg("data", 0), i + axis) + .PartialSum(user_op::OpArg("like", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + } + FOR_RANGE(int64_t, i, 0, data_num_axes) { + if (i >= axis && i < axis + segment_ids_num_axes) { continue; } + const int64_t out_split_axis = (i < axis) ? 
i : i - segment_ids_num_axes + 1; + if (out_split_axis == axis) { continue; } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("segment_ids", 0)) + .Split(user_op::OpArg("data", 0), i) + .Split(user_op::OpArg("like", 0), out_split_axis) + .Split(user_op::OpArg("out", 0), out_split_axis) + .Build(); + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("segment_ids", 0)) + .PartialSum(user_op::OpArg("data", 0)) + .Broadcast(user_op::OpArg("like", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("segment_ids", 0)) + .PartialSum(user_op::OpArg("data", 0)) + .PartialSum(user_op::OpArg("like", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("segment_ids", 0)) + .Broadcast(user_op::OpArg("data", 0)) + .Split(user_op::OpArg("like", 0), axis) + .Split(user_op::OpArg("out", 0), axis) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedSegmentSumLikeOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const Shape& data_shape = ctx->InputShape("data", 0); + const Shape& like_shape = ctx->InputShape("like", 0); + const Shape& segment_ids_shape = ctx->InputShape("segment_ids", 0); + const int64_t axis = ctx->Attr("axis"); + CHECK_GE_OR_RETURN(axis, 0); + CHECK_LE_OR_RETURN(axis, like_shape.NumAxes()); + FOR_RANGE(int64_t, i, 0, axis) { CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i)); } + CHECK_EQ_OR_RETURN(data_shape.NumAxes() - segment_ids_shape.NumAxes() + 1, like_shape.NumAxes()); + FOR_RANGE(int64_t, i, axis + 1, like_shape.NumAxes()) { + CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i + segment_ids_shape.NumAxes() - 1)); + } + *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); + *ctx->IsDynamic4ArgNameAndIndex("out", 0) = ctx->InputIsDynamic("like", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedSegmentSumLikeOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UnsortedSegmentSumLikeOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); + const user_op::TensorDesc& like = ctx->InputTensorDesc("like", 0); + CHECK_EQ_OR_RETURN(data.data_type(), like.data_type()); + CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("segment_ids", 0))); + *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedSegmentSumLikeOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); + CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); + segment_ids_modifier->set_requires_grad(false); + user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); + CHECK_NOTNULL_OR_RETURN(like_modifier); + like_modifier->set_requires_grad(false); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/upsample_op.cpp b/oneflow/user/ops/upsample_op.cpp index 6b8bb994e39..0a48bbbfe1a 100644 --- a/oneflow/user/ops/upsample_op.cpp +++ b/oneflow/user/ops/upsample_op.cpp @@ -14,446 +14,386 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("upsample_linear_1d") - .Input("x") - .Output("y") - .Attr("scale_factor") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float scale_factor = ctx->Attr("scale_factor"); +/*static*/ Maybe UpsampleLinear1DOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleLinear1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const float scale_factor = ctx->Attr("scale_factor"); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && x_desc->shape().NumAxes() == 3) - << "upsample_linear_1d only supports NCH"; - *y_desc->mut_shape() = Shape({x_desc->shape().At(0), x_desc->shape().At(1), - static_cast(scale_factor * x_desc->shape().At(2))}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && x_desc.shape().NumAxes() == 3) + << "upsample_linear_1d only supports NCH"; + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(scale_factor * x_desc.shape().At(2))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleLinear1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleLinear1DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_nearest_1d") - .Input("x") - .Output("y") - .Attr("scale_factor") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float scale_factor = ctx->Attr("scale_factor"); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && x_desc->shape().NumAxes() == 3) - << "upsample_nearest_1d only supports NCH"; - *y_desc->mut_shape() = Shape({x_desc->shape().At(0), x_desc->shape().At(1), - static_cast(scale_factor * x_desc->shape().At(2))}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleNearest1DOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = 
ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const float scale_factor = ctx->Attr("scale_factor"); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && x_desc.shape().NumAxes() == 3) + << "upsample_nearest_1d only supports NCH"; + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(scale_factor * x_desc.shape().At(2))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleNearest1DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_nearest_2d") - .Input("x") - .Output("y") - .Attr("height_scale") - .Attr("width_scale") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && x_desc->shape().NumAxes() == 4) - << "upsample_nearest_2d only supports NCHW"; - *y_desc->mut_shape() = Shape({x_desc->shape().At(0), x_desc->shape().At(1), - static_cast(height_scale * x_desc->shape().At(2)), - static_cast(width_scale * x_desc->shape().At(3))}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleNearest2DOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const float height_scale = ctx->Attr("height_scale"); + const float width_scale = ctx->Attr("width_scale"); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && x_desc.shape().NumAxes() == 4) + << "upsample_nearest_2d only supports NCHW"; + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(height_scale * x_desc.shape().At(2)), + static_cast(width_scale * x_desc.shape().At(3))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleNearest2DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_bilinear_2d") - .Input("x") - .Output("y") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); - 
CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && x_desc->shape().NumAxes() == 4) - << "upsample_bilinear_2d only supports NCHW"; - *y_desc->mut_shape() = Shape({x_desc->shape().At(0), x_desc->shape().At(1), - static_cast(height_scale * x_desc->shape().At(2)), - static_cast(width_scale * x_desc->shape().At(3))}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleBilinear2DOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBilinear2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const float height_scale = ctx->Attr("height_scale"); + const float width_scale = ctx->Attr("width_scale"); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && x_desc.shape().NumAxes() == 4) + << "upsample_bilinear_2d only supports NCHW"; + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(height_scale * x_desc.shape().At(2)), + static_cast(width_scale * x_desc.shape().At(3))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBilinear2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleBilinear2DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_bicubic_2d") - .Input("x") - .Output("y") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && x_desc->shape().NumAxes() == 4) - << "upsample_bicubic_2d only supports NCHW"; - *y_desc->mut_shape() = Shape({x_desc->shape().At(0), x_desc->shape().At(1), - static_cast(height_scale * x_desc->shape().At(2)), - static_cast(width_scale * x_desc->shape().At(3))}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleBicubic2DOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBicubic2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const float height_scale = ctx->Attr("height_scale"); + const float width_scale = 
ctx->Attr("width_scale"); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && x_desc.shape().NumAxes() == 4) + << "upsample_bicubic_2d only supports NCHW"; + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(height_scale * x_desc.shape().At(2)), + static_cast(width_scale * x_desc.shape().At(3))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBicubic2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleBicubic2DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample") - .Input("x") - .Output("y") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .Attr("interpolation") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); - if (ctx->Attr("data_format") != "channels_first" - || x_desc.shape().NumAxes() != 4) { - LOG(FATAL) << "upsample only supports NCHW"; - } - *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), - static_cast(height_scale * x_desc.shape().At(2)), - static_cast(width_scale * x_desc.shape().At(3))}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const float height_scale = ctx->Attr("height_scale"); + const float width_scale = ctx->Attr("width_scale"); + if (ctx->Attr("data_format") != "channels_first" || x_desc.shape().NumAxes() != 4) { + LOG(FATAL) << "upsample only supports NCHW"; + } + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(height_scale * x_desc.shape().At(2)), + static_cast(width_scale * x_desc.shape().At(3))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_nearest_3d") - .Input("x") - .Output("y") - .Attr("depth_scale") - .Attr("height_scale") - .Attr("width_scale") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float depth_scale = ctx->Attr("depth_scale"); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); - CHECK_OR_RETURN(ctx->Attr("data_format") == 
"channels_first" - && x_desc->shape().NumAxes() == 5) - << "upsample_nearest_3d only supports NCDHW"; - *y_desc->mut_shape() = Shape({x_desc->shape().At(0), x_desc->shape().At(1), - static_cast(depth_scale * x_desc->shape().At(2)), - static_cast(height_scale * x_desc->shape().At(3)), - static_cast(width_scale * x_desc->shape().At(4))}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleNearest3DOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const float depth_scale = ctx->Attr("depth_scale"); + const float height_scale = ctx->Attr("height_scale"); + const float width_scale = ctx->Attr("width_scale"); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && x_desc.shape().NumAxes() == 5) + << "upsample_nearest_3d only supports NCDHW"; + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(depth_scale * x_desc.shape().At(2)), + static_cast(height_scale * x_desc.shape().At(3)), + static_cast(width_scale * x_desc.shape().At(4))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleNearest3DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_trilinear_3d") - .Input("x") - .Output("y") - .Attr("depth_scale") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc* x_desc = ctx->TensorDesc4ArgNameAndIndex("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float depth_scale = ctx->Attr("depth_scale"); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && x_desc->shape().NumAxes() == 5) - << "upsample_trilinear_3d only supports NCDHW"; - *y_desc->mut_shape() = Shape({x_desc->shape().At(0), x_desc->shape().At(1), - static_cast(depth_scale * x_desc->shape().At(2)), - static_cast(height_scale * x_desc->shape().At(3)), - static_cast(width_scale * x_desc->shape().At(4))}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleTrilinear3DOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleTrilinear3DOp::InferLogicalTensorDesc(user_op::InferContext* 
ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const float depth_scale = ctx->Attr("depth_scale"); + const float height_scale = ctx->Attr("height_scale"); + const float width_scale = ctx->Attr("width_scale"); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && x_desc.shape().NumAxes() == 5) + << "upsample_trilinear_3d only supports NCDHW"; + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(depth_scale * x_desc.shape().At(2)), + static_cast(height_scale * x_desc.shape().At(3)), + static_cast(width_scale * x_desc.shape().At(4))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleTrilinear3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleTrilinear3DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_linear_1d_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("scale_factor") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && dy_shape.NumAxes() == 3) - << "upsample_linear_1d_grad only supports NCH"; - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleLinear1DGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleLinear1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && dy_shape.NumAxes() == 3) + << "upsample_linear_1d_grad only supports NCH"; + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleLinear1DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleLinear1DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_nearest_1d_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("scale_factor") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && dy_shape.NumAxes() == 3) - << "upsample_nearest_1d_grad only supports NCH"; - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - 
*ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleNearest1DGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && dy_shape.NumAxes() == 3) + << "upsample_nearest_1d_grad only supports NCH"; + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest1DGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleNearest1DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_nearest_2d_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("height_scale") - .Attr("width_scale") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && dy_shape.NumAxes() == 4) - << "upsample_nearest_2d_grad only supports NCHW"; - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleNearest2DGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && dy_shape.NumAxes() == 4) + << "upsample_nearest_2d_grad only supports NCHW"; + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest2DGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleNearest2DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_bilinear_2d_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && dy_shape.NumAxes() == 4) - << "upsample_bilinear_2d_grad only supports NCHW"; - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - 
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleBilinear2DGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBilinear2DGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && dy_shape.NumAxes() == 4) + << "upsample_bilinear_2d_grad only supports NCHW"; + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBilinear2DGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleBilinear2DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_bicubic_2d_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && dy_shape.NumAxes() == 4) - << "upsample_bicubic_2d_grad only supports NCHW"; - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleBicubic2DGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBicubic2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && dy_shape.NumAxes() == 4) + << "upsample_bicubic_2d_grad only supports NCHW"; + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBicubic2DGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleBicubic2DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .Attr("interpolation") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - if (ctx->Attr("data_format") != "channels_first" || dy_shape.NumAxes() != 4) { - LOG(FATAL) << "upsample_nearest only supports NCHW"; - } - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 
0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + if (ctx->Attr("data_format") != "channels_first" || dy_shape.NumAxes() != 4) { + LOG(FATAL) << "upsample_nearest only supports NCHW"; + } + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_nearest_3d_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("depth_scale") - .Attr("height_scale") - .Attr("width_scale") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && dy_shape.NumAxes() == 5) - << "upsample_nearest_3d_grad only supports NCDHW"; - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleNearest3DGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest3DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && dy_shape.NumAxes() == 5) + << "upsample_nearest_3d_grad only supports NCDHW"; + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest3DGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleNearest3DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_trilinear_3d_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("depth_scale") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && dy_shape.NumAxes() == 5) - << "upsample_trilinear_3d_grad only supports NCDHW"; - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - 
ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleTrilinear3DGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleTrilinear3DGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && dy_shape.NumAxes() == 5) + << "upsample_trilinear_3d_grad only supports NCDHW"; + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleTrilinear3DGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleTrilinear3DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("upsample_linear_1d") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/where_op.cpp b/oneflow/user/ops/where_op.cpp index df26b3015e9..8dba2951a44 100644 --- a/oneflow/user/ops/where_op.cpp +++ b/oneflow/user/ops/where_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -239,7 +240,7 @@ Maybe GetWhereXYScalarSbpSignatures(user_op::SbpContext* ctx) { return Maybe::Ok(); } -Maybe GetWhereInputArgModify(user_op::GetInputArgModifier GetInputArgModifierFn, +Maybe GetWhereInputArgModify(const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("condition", 0); cond_arg_modifier->set_requires_grad(false); @@ -248,101 +249,109 @@ Maybe GetWhereInputArgModify(user_op::GetInputArgModifier GetInputArgModif } // namespace -REGISTER_USER_OP("where") - .Input("condition") - .Input("x") - .Input("y") - .Output("out") - .SetTensorDescInferFn(InferWhereTensorDesc) - .SetInputArgModifyFn(GetWhereInputArgModify) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const DataType& cond_dtype = ctx->InputDType("condition", 0); - CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); - const DataType& x_dtype = ctx->InputDType("x", 0); - CHECK_EQ_OR_RETURN(x_dtype, ctx->InputDType("y", 0)); - *ctx->OutputDType("out", 0) = x_dtype; - return Maybe::Ok(); - }) - .SetGetSbpFn(GetWhereSbpSignatures); +/*static*/ Maybe WhereOp::GetSbp(user_op::SbpContext* ctx) { + return GetWhereSbpSignatures(ctx); +} +/*static*/ Maybe WhereOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferWhereTensorDesc(ctx); +} +/*static*/ Maybe WhereOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe WhereOp::InferDataType(user_op::InferContext* ctx) { + const DataType& cond_dtype = ctx->InputDType("condition", 0); + CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); + const DataType& x_dtype = ctx->InputDType("x", 0); + CHECK_EQ_OR_RETURN(x_dtype, ctx->InputDType("y", 0)); + *ctx->OutputDType("out", 0) = x_dtype; + return 
Maybe::Ok(); +} +/*static*/ Maybe WhereOp::ModifyInputArg(const GetInputArgModifier& f, + const user_op::UserOpConfWrapper& conf) { + return GetWhereInputArgModify(f, conf); +} -REGISTER_USER_OP("where_scalar_x") - .Input("condition") - .Input("y") - .Output("out") - .Attr("has_int_operand") - .Attr("has_float_operand") - .Attr("int_operand") - .Attr("float_operand") - .SetTensorDescInferFn(InferWhereXScalarTensorDesc) - .SetInputArgModifyFn(GetWhereInputArgModify) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const DataType& cond_dtype = ctx->InputDType("condition", 0); - CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); - const DataType& y_dtype = ctx->InputDType("y", 0); - if (ctx->Attr("has_int_operand")) { - CHECK_EQ_OR_RETURN(y_dtype, GetDataType::value) - << "expected scalar type " << GetDataType::value << "but found " << y_dtype; - } else if (ctx->Attr("has_float_operand")) { - CHECK_EQ_OR_RETURN(y_dtype, GetDataType::value) - << "expected scalar type " << GetDataType::value << "but found " << y_dtype; - } - *ctx->OutputDType("out", 0) = y_dtype; - return Maybe::Ok(); - }) - .SetGetSbpFn(GetWhereXScalarSbpSignatures); +/*static*/ Maybe WhereScalarXOp::GetSbp(user_op::SbpContext* ctx) { + return GetWhereXScalarSbpSignatures(ctx); +} +/*static*/ Maybe WhereScalarXOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferWhereXScalarTensorDesc(ctx); +} +/*static*/ Maybe WhereScalarXOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe WhereScalarXOp::InferDataType(user_op::InferContext* ctx) { + const DataType& cond_dtype = ctx->InputDType("condition", 0); + CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); + const DataType& y_dtype = ctx->InputDType("y", 0); + if (ctx->Attr("has_int_operand")) { + CHECK_EQ_OR_RETURN(y_dtype, GetDataType::value) + << "expected scalar type " << GetDataType::value << "but found " << y_dtype; + } else if (ctx->Attr("has_float_operand")) { + CHECK_EQ_OR_RETURN(y_dtype, GetDataType::value) + << "expected scalar type " << GetDataType::value << "but found " << y_dtype; + } + *ctx->OutputDType("out", 0) = y_dtype; + return Maybe::Ok(); +} +/*static*/ Maybe WhereScalarXOp::ModifyInputArg(const GetInputArgModifier& f, + const user_op::UserOpConfWrapper& conf) { + return GetWhereInputArgModify(f, conf); +} -REGISTER_USER_OP("where_scalar_y") - .Input("condition") - .Input("x") - .Output("out") - .Attr("has_int_operand") - .Attr("has_float_operand") - .Attr("int_operand") - .Attr("float_operand") - .SetTensorDescInferFn(InferWhereYScalarTensorDesc) - .SetInputArgModifyFn(GetWhereInputArgModify) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const DataType& cond_dtype = ctx->InputDType("condition", 0); - CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); - const DataType& x_dtype = ctx->InputDType("x", 0); - if (ctx->Attr("has_int_operand")) { - CHECK_EQ_OR_RETURN(x_dtype, GetDataType::value) - << "expected scalar type " << x_dtype << "but found " << GetDataType::value; - } else if (ctx->Attr("has_float_operand")) { - CHECK_EQ_OR_RETURN(x_dtype, GetDataType::value) - << "expected scalar type " << x_dtype << "but found " << GetDataType::value; - } - *ctx->OutputDType("out", 0) = x_dtype; - return Maybe::Ok(); - }) - .SetGetSbpFn(GetWhereYScalarSbpSignatures); +/*static*/ Maybe WhereScalarYOp::GetSbp(user_op::SbpContext* ctx) { + return GetWhereYScalarSbpSignatures(ctx); +} +/*static*/ Maybe WhereScalarYOp::InferLogicalTensorDesc(user_op::InferContext* ctx) 
{ + return InferWhereYScalarTensorDesc(ctx); +} +/*static*/ Maybe WhereScalarYOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe WhereScalarYOp::InferDataType(user_op::InferContext* ctx) { + const DataType& cond_dtype = ctx->InputDType("condition", 0); + CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); + const DataType& x_dtype = ctx->InputDType("x", 0); + if (ctx->Attr("has_int_operand")) { + CHECK_EQ_OR_RETURN(x_dtype, GetDataType::value) + << "expected scalar type " << x_dtype << "but found " << GetDataType::value; + } else if (ctx->Attr("has_float_operand")) { + CHECK_EQ_OR_RETURN(x_dtype, GetDataType::value) + << "expected scalar type " << x_dtype << "but found " << GetDataType::value; + } + *ctx->OutputDType("out", 0) = x_dtype; + return Maybe::Ok(); +} +/*static*/ Maybe WhereScalarYOp::ModifyInputArg(const GetInputArgModifier& f, + const user_op::UserOpConfWrapper& conf) { + return GetWhereInputArgModify(f, conf); +} -REGISTER_NO_GRAD_USER_OP("where_scalar_xy") - .Input("condition") - .Output("out") - .Attr("has_x_int_operand") - .Attr("has_x_float_operand") - .Attr("has_y_int_operand") - .Attr("has_y_float_operand") - .Attr("x_int_operand") - .Attr("x_float_operand") - .Attr("y_int_operand") - .Attr("y_float_operand") - .SetTensorDescInferFn(InferWhereXYScalarTensorDesc) - .SetInputArgModifyFn(GetWhereInputArgModify) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const DataType& cond_dtype = ctx->InputDType("condition", 0); - CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); - if (ctx->Attr("has_x_int_operand") && ctx->Attr("has_y_int_operand")) { - *ctx->OutputDType("out", 0) = GetDataType::value; - } else if (ctx->Attr("has_x_float_operand") && ctx->Attr("has_y_float_operand")) { - *ctx->OutputDType("out", 0) = GetDataType::value; - } else { - UNIMPLEMENTED(); - } - return Maybe::Ok(); - }) - .SetGetSbpFn(GetWhereXYScalarSbpSignatures); +/*static*/ Maybe WhereScalarXyOp::GetSbp(user_op::SbpContext* ctx) { + return GetWhereXYScalarSbpSignatures(ctx); +} +/*static*/ Maybe WhereScalarXyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferWhereXYScalarTensorDesc(ctx); +} +/*static*/ Maybe WhereScalarXyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe WhereScalarXyOp::InferDataType(user_op::InferContext* ctx) { + const DataType& cond_dtype = ctx->InputDType("condition", 0); + CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); + if (ctx->Attr("has_x_int_operand") && ctx->Attr("has_y_int_operand")) { + *ctx->OutputDType("out", 0) = GetDataType::value; + } else if (ctx->Attr("has_x_float_operand") && ctx->Attr("has_y_float_operand")) { + *ctx->OutputDType("out", 0) = GetDataType::value; + } else { + UNIMPLEMENTED(); + } + return Maybe::Ok(); +} +/*static*/ Maybe WhereScalarXyOp::ModifyInputArg(const GetInputArgModifier& f, + const user_op::UserOpConfWrapper& conf) { + return GetWhereInputArgModify(f, conf); +} REGISTER_USER_OP_GRAD("where").SetBackwardOpConfGenFn( [](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/zero_like_op.cpp b/oneflow/user/ops/zero_like_op.cpp index 193ad79666b..6e650556069 100644 --- a/oneflow/user/ops/zero_like_op.cpp +++ b/oneflow/user/ops/zero_like_op.cpp @@ -14,35 +14,34 @@ See the License for the specific language governing permissions and limitations under the License. 
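The where_scalar_* variants above back the Python flow.where overloads that take a Python number for x and/or y: the condition tensor must have an integral data type, and the scalar's kind (integer vs. floating point) has to match the tensor operand, or, when both operands are scalars, it decides the output dtype. A short usage sketch, assuming flow.where accepts Python-number operands the way torch.where does:

    import oneflow as flow

    cond = flow.tensor([[1, 0], [0, 1]], dtype=flow.int8)
    x = flow.tensor([[1.0, 2.0], [3.0, 4.0]])

    # where_scalar_y: tensor x, scalar y; the float scalar matches x's dtype kind.
    out = flow.where(cond, x, 0.0)
    # where_scalar_xy: both operands are scalars; the output dtype follows their kind.
    out2 = flow.where(cond, 1.0, 0.0)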
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("zero_like") - .Input("like") - .Output("out") - .SetOutputBufferNum(1) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& like_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0); - FOR_RANGE(int64_t, i, 0, like_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("like", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("like", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe ZeroLikeOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& like_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0); + FOR_RANGE(int64_t, i, 0, like_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("like", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("like", 0)) + .Broadcast(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe ZeroLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); + return Maybe::Ok(); +} +/*static*/ Maybe ZeroLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ZeroLikeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/summary/summary_converter.h b/oneflow/user/summary/summary_converter.h index 473afbf3411..3c4b0e79088 100644 --- a/oneflow/user/summary/summary_converter.h +++ b/oneflow/user/summary/summary_converter.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef ONEFLOW_USER_SUMMARY_SUMMARY_CONVERTER_H_ #define ONEFLOW_USER_SUMMARY_SUMMARY_CONVERTER_H_ -#include +#include "nlohmann/json.hpp" #include "oneflow/core/common/data_type.h" #include "oneflow/core/common/protobuf.h" diff --git a/oneflow/user/utils/pool_util.h b/oneflow/user/utils/pool_util.h index e688be836c3..9a21f8a9129 100644 --- a/oneflow/user/utils/pool_util.h +++ b/oneflow/user/utils/pool_util.h @@ -17,7 +17,7 @@ limitations under the License. 
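zero_like above copies the shape and data type of its like input; because the kernel never reads the values of like, the SBP list can split any axis and also admit a partial_sum-in / broadcast-out signature. Its user-facing counterpart is flow.zeros_like; a quick sketch of the expected behaviour:

    import oneflow as flow

    x = flow.randn(2, 3)
    z = flow.zeros_like(x)
    # Same shape and dtype as x, filled with zeros.
    assert z.shape == x.shape and z.dtype == x.dtype
    assert z.sum().numpy() == 0.0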
#define ONEFLOW_USER_UTILS_POOL_UTIL_H_ #include "oneflow/core/device/cudnn_util.h" #include "oneflow/core/framework/framework.h" -#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/user/kernels/op_kernel_wrapper.h" namespace oneflow { diff --git a/oneflow/xrt/passes/rebuild_job_pass.cpp b/oneflow/xrt/passes/rebuild_job_pass.cpp index cbc312484d6..a587b6e409c 100644 --- a/oneflow/xrt/passes/rebuild_job_pass.cpp +++ b/oneflow/xrt/passes/rebuild_job_pass.cpp @@ -285,7 +285,8 @@ void FoldSubgraphBuilder::BuildXrtLaunchOps() { } CHECK_GT(folded_nodes_[i].size(), 0); - const ParallelConf& parallel_conf = builder_->ParallelConf4OpName(folded_nodes_[i][0]->name()); + const ParallelConf& parallel_conf = + CHECK_JUST(builder_->ParallelConf4OpName(folded_nodes_[i][0]->name())); // TODO(hjchen2) check parallel conf over all folded nodes builder_->AddOps(parallel_conf, {op_conf}); diff --git a/oneflow/xrt/xla/ops/layer_norm_op.cpp b/oneflow/xrt/xla/ops/layer_norm_op.cpp index a7d1b371788..1f10a0744ef 100644 --- a/oneflow/xrt/xla/ops/layer_norm_op.cpp +++ b/oneflow/xrt/xla/ops/layer_norm_op.cpp @@ -85,10 +85,6 @@ void LayerNormOp::Compile(XlaOpContext* ctx) { output = xla::ConvertElementType(output, DataTypeToPrimitiveType(output_type)); } - if (ctx->Attr("scale") && ctx->HasOutput("normalized_0")) { - ctx->SetOutput("normalized_0", Reshape(output, input_shape)); - } - Shape gamma_shape = Shape({norm_dims}); // output = Reshape(output, Shape({batch_dims, norm_dims})); if (ctx->Attr("scale")) { @@ -123,6 +119,11 @@ void LayerNormGradOp::Compile(XlaOpContext* ctx) { xla::XlaOp mean = ctx->Input("mean_0"); xla::XlaOp inv_variance = ctx->Input("inv_variance_0"); + if (ctx->HasInput("gamma_0")) { + xla::XlaOp gamma = ctx->Input("gamma_0"); + output_grad = output_grad * gamma; + } + Shape activation_shape = ctx->InputShape("x_0"); int begin_norm_axis = ctx->Attr("begin_norm_axis"); CHECK_LT(begin_norm_axis, activation_shape.NumAxes()); @@ -184,22 +185,14 @@ void LayerNormParamGradOp::Compile(XlaOpContext* ctx) { xla::XlaOp beta_grad = xla::Reduce(output_grad, Zero(builder, data_type), add_func, batch_dims); ctx->SetOutput("beta_diff_0", beta_grad); } + xla::XlaOp x = ctx->Input("x_0"); + xla::XlaOp mean = ctx->Input("mean_0"); + xla::XlaOp inv_variance = ctx->Input("inv_variance_0"); if (ctx->HasOutput("gamma_diff_0")) { - xla::XlaOp normalized = ctx->Input("normalized_0"); - xla::XlaOp gamma_grad = normalized * output_grad; + xla::XlaOp gamma_grad = (x - mean) * inv_variance * output_grad; gamma_grad = xla::Reduce(gamma_grad, Zero(builder, data_type), add_func, batch_dims); ctx->SetOutput("gamma_diff_0", gamma_grad); } - if (ctx->HasOutput("normalized_diff_0")) { - xla::XlaOp normalized_grad; - if (ctx->HasInput("gamma_0")) { - xla::XlaOp gamma = ctx->Input("gamma_0"); - normalized_grad = xla::Mul(output_grad, gamma, norm_dims); - } else { - normalized_grad = output_grad; - } - ctx->SetOutput("normalized_diff_0", normalized_grad); - } } REGISTER_XLA_OP_KERNEL(LayerNorm, LayerNormOp).Finalize(); diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py index 1cb725b2efa..a0e1c4ccbdf 100755 --- a/python/oneflow/__init__.py +++ b/python/oneflow/__init__.py @@ -15,6 +15,12 @@ """ import os + +if os.getenv("CTEST_RESOURCE_GROUP_COUNT"): + vram_str = os.getenv("CTEST_RESOURCE_GROUP_0_VRAM") + gpu_id = vram_str.split(",")[0].split(":")[-1] + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id + import sys import collections @@ -72,6 +78,9 @@ def is_deprecated(func_or_class): from 
oneflow._C import atanh from oneflow._C import atanh as arctanh from oneflow._C import batch_matmul as bmm +from oneflow._C import broadcast_like +from oneflow._C import chunk +from oneflow._C import split from oneflow._C import sign from oneflow._C import sinh from oneflow._C import tan @@ -117,12 +126,14 @@ def is_deprecated(func_or_class): from oneflow._C import clamp as clip from oneflow._C import cos from oneflow._C import cosh +from oneflow._C import diagonal from oneflow._C import erf from oneflow._C import erfc from oneflow._C import expm1 from oneflow._C import fmod from oneflow._C import flatten from oneflow._C import log +from oneflow._C import log2 from oneflow._C import minimum from oneflow._C import maximum from oneflow._C import pow @@ -159,6 +170,8 @@ def is_deprecated(func_or_class): from oneflow._C import read_onerec from oneflow._C import decode_onerec from oneflow._C import dot +from oneflow._C import eye +from oneflow._C import cumsum from . import sbp @@ -249,6 +262,7 @@ def atexit_hook(hook): import oneflow._C from oneflow._C import tensor, batch_gather +from oneflow._C import from_numpy from oneflow.autograd import grad_enable, no_grad, inference_mode, is_grad_enabled import oneflow.nn.image @@ -281,10 +295,9 @@ def atexit_hook(hook): adaptive_avg_pool3d, ) from oneflow.nn.modules.arange import arange_op as arange +from oneflow.nn.modules.linspace import linspace_op as linspace from oneflow.nn.modules.argsort import argsort_op as argsort from oneflow.nn.modules.argwhere import argwhere_op as argwhere -from oneflow.nn.modules.broadcast_like import broadcast_like_op as broadcast_like -from oneflow.nn.modules.chunk import chunk_op as chunk from oneflow.nn.modules.constant import ones_op as ones from oneflow.nn.modules.constant import zeros_op as zeros from oneflow.nn.modules.constant import full_op as full @@ -310,10 +323,10 @@ def atexit_hook(hook): from oneflow.nn.modules.masked_select import masked_select_op as masked_select from oneflow.nn.modules.math_ops import addmm_op as addmm from oneflow.nn.modules.math_ops import topk_op as topk -from oneflow.nn.modules.meshgrid import meshgrid_op as meshgrid from oneflow.nn.modules.nonzero import nonzero_op as nonzero from oneflow.nn.modules.nms import nms_op as nms from oneflow.nn.modules.numel import numel_op as numel +from oneflow.nn.modules.meshgrid import meshgrid_op as meshgrid from oneflow.nn.modules.random_ops import rand_op as rand from oneflow.nn.modules.random_ops import randn_op as randn from oneflow.nn.modules.random_ops import randint_op as randint @@ -331,9 +344,8 @@ def atexit_hook(hook): from oneflow.nn.modules.slice import slice_op as slice from oneflow.nn.modules.slice import slice_update_op as slice_update from oneflow.nn.modules.slice import logical_slice_assign_op as logical_slice_assign +from oneflow.nn.modules.slice import logical_slice_op as logical_slice from oneflow.nn.modules.sort import sort_op as sort -from oneflow.nn.modules.split import split_op as split -from oneflow.nn.modules.eye import eye_op as eye from oneflow.nn.modules.tensor_buffer import gen_tensor_buffer from oneflow.nn.modules.tensor_buffer import ( tensor_buffer_to_tensor_op as tensor_buffer_to_tensor, @@ -345,7 +357,7 @@ def atexit_hook(hook): from oneflow.nn.modules.consistent_cast import to_local_op as to_local from oneflow.nn.modules.where import where_op as where from oneflow.nn.modules.scatter import * -from oneflow.ops.builtin_ops import BuiltinOp as builtin_op +from oneflow.ops.stateful_ops import StatefulOp as 
stateful_op from oneflow.ops.initializer_util import constant_initializer from oneflow.ops.initializer_util import glorot_normal_initializer from oneflow.ops.initializer_util import ( diff --git a/python/oneflow/comm/__init__.py b/python/oneflow/comm/__init__.py index 71919ae9917..82c8267e2d2 100644 --- a/python/oneflow/comm/__init__.py +++ b/python/oneflow/comm/__init__.py @@ -18,6 +18,8 @@ from oneflow.comm.comm_ops import broadcast from oneflow.comm.comm_ops import scatter from oneflow.comm.comm_ops import reduce +from oneflow.comm.comm_ops import all_to_all +from oneflow.comm.comm_ops import barrier from oneflow.comm.comm_ops import reduce_scatter from oneflow.comm.comm_ops import gather from oneflow._C import send, recv diff --git a/python/oneflow/comm/comm_ops.py b/python/oneflow/comm/comm_ops.py index 76dfd64cee0..d4c5bd9afaa 100644 --- a/python/oneflow/comm/comm_ops.py +++ b/python/oneflow/comm/comm_ops.py @@ -34,17 +34,17 @@ def all_reduce(tensor): >>> # We have 1 process groups, 2 ranks. >>> import oneflow as flow - >>> input = flow.tensor([[1, 2], [3, 4]], device="cuda") + flow.env.get_local_rank() - >>> # input on rank0 - >>> input # doctest: +ONLY_CHECK_RANK_0 + >>> tensor = flow.tensor([[1, 2], [3, 4]], device="cuda") + flow.env.get_local_rank() + >>> # tensor on rank0 + >>> tensor # doctest: +ONLY_CHECK_RANK_0 tensor([[1, 2], [3, 4]], device='cuda:0', dtype=oneflow.int64) - >>> # input on rank1 - >>> input # doctest: +ONLY_CHECK_RANK_1 + >>> # tensor on rank1 + >>> tensor # doctest: +ONLY_CHECK_RANK_1 tensor([[2, 3], [4, 5]], device='cuda:1', dtype=oneflow.int64) - >>> out = flow.comm.all_reduce(input) - >>> out.numpy() + >>> flow.comm.all_reduce(tensor) + >>> tensor.numpy() array([[3, 5], [7, 9]]) @@ -54,11 +54,11 @@ def all_reduce(tensor): assert tensor.is_local device_type = tensor.device.type placement = flow.env.all_device_placement(device_type) - tensor = tensor.to_consistent( + result = tensor.to_consistent( placement=placement, sbp=flow.sbp.partial_sum ).to_consistent(placement=placement, sbp=flow.sbp.broadcast) - return tensor.to_local() + tensor.data = result.to_local() def all_gather(tensor_list, tensor): @@ -107,10 +107,10 @@ def all_gather(tensor_list, tensor): assert tensor.is_local tensor = tensor.expand(*([1] + list(tensor.shape))) device_type = tensor.device.type + placement = flow.env.all_device_placement(device_type) tensor = tensor.to_consistent( - placement=flow.env.all_device_placement(device_type), sbp=flow.sbp.split(0) - ) - tensor = tensor.to_consistent(sbp=flow.sbp.broadcast) + placement=placement, sbp=flow.sbp.split(0) + ).to_consistent(placement=placement, sbp=flow.sbp.broadcast) assert len(tensor_list) == flow.env.get_world_size() for i in range(tensor.shape[0]): tensor_list[i] = tensor[i].to_local() @@ -203,9 +203,58 @@ def reduce(tensor, dst): assert isinstance(tensor, flow._oneflow_internal.Tensor) assert tensor.is_local assert isinstance(dst, int) - result = flow.comm.all_reduce(tensor) - if flow.env.get_rank() == dst: - tensor.data = result + original_tensor = flow._C.identity(tensor) + flow.comm.all_reduce(tensor) + if flow.env.get_rank() != dst: + tensor.data = original_tensor + + +def all_to_all(output_tensor_list, input_tensor_list): + """ + Each process scatters list of input tensors to all processes in a group and + return gathered list of tensors in output list. + + Args: + output_tensor_list (list[Tensor]): List of tensors to be gathered one + per rank. + input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. 
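A sketch of the intended call pattern, assuming a two-rank launch (the tensor values are illustrative): each rank sends input_tensor_list[i] to rank i and receives output_tensor_list[i] from rank i.

    import oneflow as flow

    rank = flow.env.get_rank()
    world_size = flow.env.get_world_size()  # assumed to be 2 here

    # Rank r sends input_list[i] to rank i and receives output_list[i] from rank i.
    input_list = [flow.tensor([rank * 10 + i], device="cuda") for i in range(world_size)]
    output_list = [flow.tensor([0], device="cuda") for _ in range(world_size)]

    flow.comm.all_to_all(output_list, input_list)
    # rank 0 ends with output_list == [[0], [10]]; rank 1 with [[1], [11]].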
+ + """ + + def _check_list(tensor_list): + assert isinstance(tensor_list, list) + assert len(tensor_list) == flow.env.get_world_size() + shape = tensor_list[0].shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + for tensor in tensor_list: + assert isinstance(tensor, flow._oneflow_internal.Tensor) + assert tensor.is_local + assert shape == tensor.shape + assert dtype == tensor.dtype + assert device == tensor.device + + _check_list(output_tensor_list) + _check_list(input_tensor_list) + + assert input_tensor_list[0].shape == output_tensor_list[0].shape + assert input_tensor_list[0].dtype == output_tensor_list[0].dtype + assert input_tensor_list[0].device == output_tensor_list[0].device + + for i in range(flow.env.get_world_size()): + flow.comm.scatter( + output_tensor_list[i], + input_tensor_list if i == flow.env.get_rank() else [], + src=i, + ) + + +def barrier(): + """ + Synchronizes all processes. + + """ + oneflow._oneflow_internal.eager.multi_client.Sync() def reduce_scatter(output, input_list): @@ -266,7 +315,5 @@ def gather(tensor, gather_list=None, dst=0): assert gather_list is not None assert isinstance(gather_list, list) assert len(gather_list) == flow.env.get_world_size() - # "to_consistent(placement=flow.env.all_device_placement("cuda/cpu"), sbp=flow.sbp.broadcast)" - # after here will fail, if do getitem on some a rank for i in range(tensor.shape[0]): gather_list[i] = tensor[i].to_local() diff --git a/python/oneflow/compatible/single_client/__init__.py b/python/oneflow/compatible/single_client/__init__.py index 705b5790d83..2b4fd39b448 100644 --- a/python/oneflow/compatible/single_client/__init__.py +++ b/python/oneflow/compatible/single_client/__init__.py @@ -37,6 +37,15 @@ locals()["record"] = oneflow._oneflow_internal.record locals()["tensor_buffer"] = oneflow._oneflow_internal.tensor_buffer locals()["bfloat16"] = oneflow._oneflow_internal.bfloat16 +locals()["uint16"] = oneflow._oneflow_internal.uint16 +locals()["uint32"] = oneflow._oneflow_internal.uint32 +locals()["uint64"] = oneflow._oneflow_internal.uint64 +locals()["uint128"] = oneflow._oneflow_internal.uint128 +locals()["int16"] = oneflow._oneflow_internal.int16 +locals()["int128"] = oneflow._oneflow_internal.int128 +locals()["complex32"] = oneflow._oneflow_internal.complex32 +locals()["complex64"] = oneflow._oneflow_internal.complex64 +locals()["complex128"] = oneflow._oneflow_internal.complex128 from oneflow.compatible.single_client.framework import ( env_util, session_context, @@ -247,7 +256,7 @@ def custom_exit(returncode): zeros, ) from oneflow.compatible.single_client.ops.assign_op import assign -from oneflow.compatible.single_client.ops.builtin_ops import BuiltinOp as builtin_op +from oneflow.compatible.single_client.ops.stateful_ops import StatefulOp as stateful_op from oneflow.compatible.single_client.ops.categorical_ordinal_encode_op import ( categorical_ordinal_encode, ) diff --git a/python/oneflow/compatible/single_client/framework/config_util.py b/python/oneflow/compatible/single_client/framework/config_util.py index e7ec9b68540..58053dbdec4 100644 --- a/python/oneflow/compatible/single_client/framework/config_util.py +++ b/python/oneflow/compatible/single_client/framework/config_util.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
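Note that all_to_all above is composed from world_size scatter calls (each rank in turn scatters its input list while the others receive), and barrier synchronizes every process through the eager multi-client Sync primitive. Together with the now in-place all_reduce, a typical two-rank pattern looks like this sketch:

    import oneflow as flow

    # Each rank contributes its own value; after this patch all_reduce updates in place.
    local = flow.tensor([flow.env.get_rank()], device="cuda")
    flow.comm.all_reduce(local)   # local now holds the sum over all ranks
    flow.comm.barrier()           # no rank continues until every rank reaches this point
    print(local.numpy())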
""" +import os import traceback import oneflow._oneflow_internal @@ -356,6 +357,8 @@ def enable_tensor_float_32_compute(val=True): sess = session_ctx.GetDefaultSession() assert type(val) is bool sess.config_proto.resource.enable_tensor_float_32_compute = val + if not val: + os.environ["ONEFLOW_EP_CUDA_ENABLE_TF32_EXECUTION"] = "0" def api_enable_mem_chain_merge(val: bool = True) -> None: diff --git a/python/oneflow/compatible/single_client/framework/dtype.py b/python/oneflow/compatible/single_client/framework/dtype.py index b910de41719..c89f15d7891 100644 --- a/python/oneflow/compatible/single_client/framework/dtype.py +++ b/python/oneflow/compatible/single_client/framework/dtype.py @@ -34,6 +34,16 @@ flow.record, flow.tensor_buffer, flow.bfloat16, + flow.uint16, + flow.uint32, + flow.uint64, + flow.uint128, + flow.int16, + flow.int64, + flow.int128, + flow.complex32, + flow.complex64, + flow.complex128, ] @@ -57,6 +67,13 @@ def convert_proto_dtype_to_oneflow_dtype(proto_dtype): flow.int32: np.int32, flow.int64: np.int64, flow.uint8: np.uint8, + flow.uint16: np.uint16, + flow.uint32: np.uint32, + flow.uint64: np.uint64, + flow.int16: np.int16, + flow.int64: np.int64, + flow.complex64: np.complex64, + flow.complex128: np.complex128, } diff --git a/python/oneflow/compatible/single_client/framework/op_expr_util.py b/python/oneflow/compatible/single_client/framework/op_expr_util.py deleted file mode 100644 index 1ba07670e25..00000000000 --- a/python/oneflow/compatible/single_client/framework/op_expr_util.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-""" -import oneflow._oneflow_internal -from oneflow.compatible import single_client as flow -from oneflow.compatible.single_client.framework.attr_util import ( - convert_to_user_attr_value, -) - - -def user_op_expr_call(self, *args, **kwargs): - attrs = oneflow._oneflow_internal.MutableCfgAttrMap() - for (attr_name, attr_value) in kwargs.items(): - assert isinstance(attr_name, str) - attrs[attr_name] = convert_to_user_attr_value( - self.op_type_name, attr_name, attr_value - ) - try: - results = self.apply(args, attrs) - except oneflow._oneflow_internal.exception.Exception: - raise oneflow._oneflow_internal.exception.GetThreadLocalLastError() - return results - - -def RegisterMethod4UserOpExpr(): - oneflow._oneflow_internal.one.UserOpExpr.__call__ = user_op_expr_call diff --git a/python/oneflow/compatible/single_client/framework/register_class_method_util.py b/python/oneflow/compatible/single_client/framework/register_class_method_util.py index 03c8cd724f5..732a27b27d8 100644 --- a/python/oneflow/compatible/single_client/framework/register_class_method_util.py +++ b/python/oneflow/compatible/single_client/framework/register_class_method_util.py @@ -18,12 +18,10 @@ from oneflow.compatible.single_client.framework import blob_trait as blob_trait from oneflow.compatible.single_client.framework import functional as functional from oneflow.compatible.single_client.framework import generator as generator -from oneflow.compatible.single_client.framework import op_expr_util as op_expr_util from oneflow.compatible.single_client.framework import remote_blob as remote_blob_util def RegisterMethod4Class(): - op_expr_util.RegisterMethod4UserOpExpr() eager_blob_util.RegisterMethod4EagerPhysicalBlob() blob_trait.RegisterBlobOperatorTraitMethod( oneflow._oneflow_internal.EagerPhysicalBlob diff --git a/python/oneflow/compatible/single_client/nn/optimizer/adam.py b/python/oneflow/compatible/single_client/nn/optimizer/adam.py index 9b062b2d335..99148e1d608 100644 --- a/python/oneflow/compatible/single_client/nn/optimizer/adam.py +++ b/python/oneflow/compatible/single_client/nn/optimizer/adam.py @@ -103,13 +103,11 @@ def __init__( self._state[param]["exp_avg"] = flow.experimental.zeros_like(param) self._state[param]["exp_avg_sq"] = flow.experimental.zeros_like(param) self._op = ( - flow.builtin_op("adam_update") + flow.stateful_op("adam_update") .Input("model") .Input("model_diff") .Input("m") .Input("v") - .Attr("l1", 0.0) - .Attr("weight_decay", 0.0) .Build() ) @@ -126,7 +124,7 @@ def step(self, closure: Callable = None): loss = closure() for param_group in self.param_groups: kwargs = { - "learning_rate_val": param_group["lr"], + "learning_rate": param_group["lr"], "scale": param_group["scale"], "l2": param_group["weight_decay"], "beta1": param_group["betas"][0], @@ -138,6 +136,8 @@ def step(self, closure: Callable = None): continue m_tensor = self._state[param]["exp_avg"] v_tensor = self._state[param]["exp_avg_sq"] - self._op(param, param.grad, m_tensor, v_tensor, **kwargs) + flow._C.dispatch_adam_update( + self._op, (param, param.grad, m_tensor, v_tensor), **kwargs + ) self._state["step"] = self._state["step"] + 1 return loss diff --git a/python/oneflow/compatible/single_client/nn/optimizer/adamw.py b/python/oneflow/compatible/single_client/nn/optimizer/adamw.py index d1eb2e257b0..72628386d67 100644 --- a/python/oneflow/compatible/single_client/nn/optimizer/adamw.py +++ b/python/oneflow/compatible/single_client/nn/optimizer/adamw.py @@ -106,13 +106,11 @@ def __init__( self._state[param]["exp_avg"] = 
flow.experimental.zeros_like(param) self._state[param]["exp_avg_sq"] = flow.experimental.zeros_like(param) self._op = ( - flow.builtin_op("adam_update") + flow.stateful_op("adam_update") .Input("model") .Input("model_diff") .Input("m") .Input("v") - .Attr("l1", 0.0) - .Attr("l2", 0.0) .Build() ) @@ -129,7 +127,7 @@ def step(self, closure: Callable = None): loss = closure() for param_group in self.param_groups: kwargs = { - "learning_rate_val": param_group["lr"], + "learning_rate": param_group["lr"], "scale": param_group["scale"], "weight_decay": param_group["weight_decay"], "beta1": param_group["betas"][0], @@ -141,6 +139,8 @@ def step(self, closure: Callable = None): continue m_tensor = self._state[param]["exp_avg"] v_tensor = self._state[param]["exp_avg_sq"] - self._op(param, param.grad, m_tensor, v_tensor, **kwargs) + flow._C.dispatch_adam_update( + self._op, (param, param.grad, m_tensor, v_tensor), **kwargs + ) self._state["step"] = self._state["step"] + 1 return loss diff --git a/python/oneflow/compatible/single_client/nn/optimizer/rmsprop.py b/python/oneflow/compatible/single_client/nn/optimizer/rmsprop.py index 4797e7f7301..dc196be937d 100644 --- a/python/oneflow/compatible/single_client/nn/optimizer/rmsprop.py +++ b/python/oneflow/compatible/single_client/nn/optimizer/rmsprop.py @@ -121,24 +121,18 @@ def __init__( if param_group["centered"]: self._state[param]["grad_avg"] = flow.experimental.zeros_like(param) self._centered_rmsprop = ( - flow.builtin_op("rmsprop_update") + flow.stateful_op("rmsprop_update") .Input("model") .Input("model_diff") .Input("mean_square") .Input("mean_gradient") - .Attr("centered", True) - .Attr("l1", 0.0) - .Attr("l2", 0.0) .Build() ) self._rmsprop = ( - flow.builtin_op("rmsprop_update") + flow.stateful_op("rmsprop_update") .Input("model") .Input("model_diff") .Input("mean_square") - .Attr("centered", False) - .Attr("l1", 0.0) - .Attr("l2", 0.0) .Build() ) @@ -155,7 +149,7 @@ def step(self, closure: Callable = None): loss = closure() for param_group in self.param_groups: kwargs = { - "learning_rate_val": param_group["lr"], + "learning_rate": param_group["lr"], "scale": param_group["scale"], "epsilon": param_group["eps"], "decay_rate": param_group["alpha"], @@ -167,10 +161,15 @@ def step(self, closure: Callable = None): ms_tensor = self._state[param]["square_avg"] if param_group["centered"]: mg_tensor = self._state[param]["grad_avg"] - self._centered_rmsprop( - param, param.grad, ms_tensor, mg_tensor, **kwargs + flow._C.dispatch_rmsprop_update( + self._centered_rmsprop, + (param, param.grad, ms_tensor, mg_tensor), + centered=True, + **kwargs, ) else: - self._rmsprop(param, param.grad, ms_tensor, **kwargs) + flow._C.dispatch_rmsprop_update( + self._rmsprop, (param, param.grad, ms_tensor), **kwargs + ) self._state["step"] = self._state["step"] + 1 return loss diff --git a/python/oneflow/compatible/single_client/nn/optimizer/sgd.py b/python/oneflow/compatible/single_client/nn/optimizer/sgd.py index 17f82ea0f50..1c3ef75effc 100644 --- a/python/oneflow/compatible/single_client/nn/optimizer/sgd.py +++ b/python/oneflow/compatible/single_client/nn/optimizer/sgd.py @@ -78,23 +78,14 @@ def __init__( param ) self._momentum_sgd = ( - flow.builtin_op("momentum_update") + flow.stateful_op("momentum_update") .Input("model") .Input("model_diff") .Input("momentum") - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Attr("weight_decay", 0.0) .Build() ) self._sgd = ( - flow.builtin_op("sgd_update") - .Input("model") - .Input("model_diff") - .Attr("weight_decay", 0.0) - .Attr("l1", 
0.0) - .Attr("l2", 0.0) - .Build() + flow.stateful_op("sgd_update").Input("model").Input("model_diff").Build() ) def step(self, closure: Callable = None): @@ -109,16 +100,20 @@ def step(self, closure: Callable = None): continue if param_group["momentum"] == 0.0: scale = param_group["scale"] - self._sgd(param, param.grad, learning_rate_val=lr, scale=scale) + flow._C.dispatch_sgd_update( + self._sgd, + (param, param.grad), + learning_rate=lr, + scale=scale, + ) else: momentum_buf = self._state[param]["momentum_buf"] scale = param_group["scale"] beta = param_group["momentum"] - self._momentum_sgd( - param, - param.grad, - momentum_buf, - learning_rate_val=lr, + flow._C.dispatch_momentum_update( + self._momentum_sgd, + (param, param.grad, momentum_buf), + learning_rate=lr, scale=scale, beta=beta, ) diff --git a/python/oneflow/compatible/single_client/ops/layers.py b/python/oneflow/compatible/single_client/ops/layers.py index 1f64355013b..df6954ed9db 100644 --- a/python/oneflow/compatible/single_client/ops/layers.py +++ b/python/oneflow/compatible/single_client/ops/layers.py @@ -839,7 +839,6 @@ def layer_norm_Job(x: tp.Numpy.Placeholder((1, 64, 128, 128)) op_builder.Input("beta", [beta]) if gamma is not None: op_builder.Input("gamma", [gamma]) - op_builder.Output("normalized") op_builder.Attr("center", center) op_builder.Attr("scale", scale) op_builder.Attr("begin_norm_axis", begin_norm_axis) @@ -855,6 +854,7 @@ def layer_norm_grad( x: oneflow._oneflow_internal.BlobDesc, mean: oneflow._oneflow_internal.BlobDesc, inv_variance: oneflow._oneflow_internal.BlobDesc, + gamma: oneflow._oneflow_internal.BlobDesc = None, begin_norm_axis: int = 1, name: str = "LayerNormGrad", ) -> oneflow._oneflow_internal.BlobDesc: @@ -865,13 +865,14 @@ def layer_norm_grad( x (oneflow._oneflow_internal.BlobDesc): Input `Blob`. mean (oneflow._oneflow_internal.BlobDesc): Mean over neurons. inv_variance (oneflow._oneflow_internal.BlobDesc): Variance over neurons. + gamma (oneflow._oneflow_internal.BlobDesc): Scale parameter. begin_norm_axis (int, optional): An integer specifies which axis to normalize at first. Defaults to 1. name (Optional[str], optional): This layer's name. Defaults to None. Returns: oneflow._oneflow_internal.BlobDesc: Gradient with respect to input `Blob`. """ - op = ( + op_builder = ( flow.user_op_builder(name) .Op("layer_norm_grad") .Input("dy", [dy]) @@ -879,17 +880,19 @@ def layer_norm_grad( .Input("mean", [mean]) .Input("inv_variance", [inv_variance]) .Output("dx") - .Attr("begin_norm_axis", begin_norm_axis) - .Attr("epsilon", 1e-05) - .Build() ) - return op.InferAndTryRun().SoleOutputBlob() + if gamma is not None: + op_builder.Input("gamma", [gamma]) + op_builder.Attr("begin_norm_axis", begin_norm_axis) + op_builder.Attr("epsilon", 1e-05) + return op_builder.Build().InferAndTryRun().SoleOutputBlob() def layer_norm_param_grad( dy: oneflow._oneflow_internal.BlobDesc, - norm: oneflow._oneflow_internal.BlobDesc, - gamma: oneflow._oneflow_internal.BlobDesc, + x: oneflow._oneflow_internal.BlobDesc, + mean: oneflow._oneflow_internal.BlobDesc, + inv_variance: oneflow._oneflow_internal.BlobDesc, begin_params_axis: int = -1, name: str = "LayerNormParamGrad", ) -> Tuple[ @@ -901,14 +904,14 @@ def layer_norm_param_grad( Args: dy (oneflow._oneflow_internal.BlobDesc): Upstream derivstives. - norm (oneflow._oneflow_internal.BlobDesc): Normalized output. - gamma (oneflow._oneflow_internal.BlobDesc): Scale parameter. + x (oneflow._oneflow_internal.BlobDesc): Input `Blob`. 
+ mean (oneflow._oneflow_internal.BlobDesc): Mean over neurons. + inv_variance (oneflow._oneflow_internal.BlobDesc): Variance over neurons. begin_params_axis (int, optional): From which parameters to begin with. Defaults to -1. name (Optional[str], optional): This layer's name. Defaults to 'LayerNormParamGrad'. Returns: Tuple[oneflow._oneflow_internal.BlobDesc]: - normalized_diff: Gradient with respect to input `Blob`. beta_diff: Gradient with respect to shift parameter beta. gamma_diff: Gradient with respect to scale parameter gamma. """ @@ -916,22 +919,16 @@ def layer_norm_param_grad( flow.user_op_builder(name) .Op("layer_norm_param_grad") .Input("dy", [dy]) - .Input("normalized", [norm]) - .Input("gamma", [gamma]) - .Output("normalized_diff") + .Input("x", [x]) + .Input("mean", [mean]) + .Input("inv_variance", [inv_variance]) .Output("beta_diff") .Output("gamma_diff") - .Output("reduce_buf") .Attr("begin_params_axis", begin_params_axis) .Build() ) - ( - normalized_diff, - beta_diff, - gamma_diff, - reduce_buf, - ) = op.InferAndTryRun().RemoteBlobList() - return (normalized_diff, beta_diff, gamma_diff) + (beta_diff, gamma_diff,) = op.InferAndTryRun().RemoteBlobList() + return (beta_diff, gamma_diff) def _get_batch_normalization_variables( diff --git a/python/oneflow/compatible/single_client/ops/math_ops.py b/python/oneflow/compatible/single_client/ops/math_ops.py index f96a428bea6..19d720f80c1 100644 --- a/python/oneflow/compatible/single_client/ops/math_ops.py +++ b/python/oneflow/compatible/single_client/ops/math_ops.py @@ -594,8 +594,8 @@ def reluJob(x: tp.Numpy.Placeholder((3, )) return ( flow.user_op_builder(name if name is not None else id_util.UniqueStr("Relu_")) .Op("relu") - .Input("in", [x]) - .Output("out") + .Input("x", [x]) + .Output("y") .Build() .InferAndTryRun() .RemoteBlobList()[0] diff --git a/python/oneflow/compatible/single_client/ops/builtin_ops.py b/python/oneflow/compatible/single_client/ops/stateful_ops.py similarity index 69% rename from python/oneflow/compatible/single_client/ops/builtin_ops.py rename to python/oneflow/compatible/single_client/ops/stateful_ops.py index 01fda38a604..91dce02dc3a 100644 --- a/python/oneflow/compatible/single_client/ops/builtin_ops.py +++ b/python/oneflow/compatible/single_client/ops/stateful_ops.py @@ -21,7 +21,7 @@ ) -class BuiltinOp(object): +class StatefulOp(object): def __init__(self, op_type_name, op_name=None): if op_name is None: op_name = id_util.UniqueStr(op_type_name) @@ -68,31 +68,6 @@ def Output(self, output_name, num=1): self._builder.output(output_name, num) return self - def Attr(self, attr_name, attr_value, attr_type_name=None): - """Set value of op's attribute. - - Args: - attr_name (str): attribute name of op - attr_value (Any): attribute value of op - - Raises: - ValueError: raised when value is not idential to op's attribute type. - - Returns: - [type]: [description] - """ - if attr_type_name is not None: - print( - 'WARNING: Argument \'attr_type_name\' of UserOpConfBuilder.Attr has been deprecated. 
Please remove it.\n\n For instance:\n - .Attr("out_num", out_num, "AttrTypeInt64")\n + .Attr("out_num", out_num)\n ' - ) - print(traceback.format_stack()[-2]) - assert self._op_type_name is not None - self._builder.attr( - attr_name, - convert_to_user_attr_value(self._op_type_name, attr_name, attr_value), - ) - return self - def Build(self): """Explicitly complete the construction of the builtin op diff --git a/python/oneflow/compatible/single_client/test/ops/test_ccrelu.py b/python/oneflow/compatible/single_client/test/ops/test_ccrelu.py index 3ee8f22a532..3431d39121b 100644 --- a/python/oneflow/compatible/single_client/test/ops/test_ccrelu.py +++ b/python/oneflow/compatible/single_client/test/ops/test_ccrelu.py @@ -28,8 +28,8 @@ def ccrelu(x, name): return ( flow.user_op_builder(name) .Op("ccrelu") - .Input("in", [x]) - .Output("out") + .Input("x", [x]) + .Output("y") .Build() .InferAndTryRun() .RemoteBlobList()[0] diff --git a/python/oneflow/compatible/single_client/test/ops/test_multi_global_function.py b/python/oneflow/compatible/single_client/test/ops/test_multi_global_function.py index 4060d353959..ba5fd20eec9 100644 --- a/python/oneflow/compatible/single_client/test/ops/test_multi_global_function.py +++ b/python/oneflow/compatible/single_client/test/ops/test_multi_global_function.py @@ -28,8 +28,8 @@ def ccrelu(x, name): return ( flow.user_op_builder(name) .Op("ccrelu") - .Input("in", [x]) - .Output("out") + .Input("x", [x]) + .Output("y") .Build() .InferAndTryRun() .RemoteBlobList()[0] diff --git a/python/oneflow/compatible/single_client/test/ops/test_multi_process.py b/python/oneflow/compatible/single_client/test/ops/test_multi_process.py deleted file mode 100644 index d921a5e3ff0..00000000000 --- a/python/oneflow/compatible/single_client/test/ops/test_multi_process.py +++ /dev/null @@ -1,111 +0,0 @@ -""" -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
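The optimizer rewrites above drop the build-time .Attr(...) calls: a stateful op is now built with inputs only, and attributes such as the learning rate travel with each flow._C.dispatch_* call. A compact sketch of the same pattern outside the optimizer classes, where model and diff are hypothetical stand-ins for an eager parameter tensor and its gradient:

    from oneflow.compatible import single_client as flow

    # Build the op once, declaring inputs only.
    sgd_op = (
        flow.stateful_op("sgd_update").Input("model").Input("model_diff").Build()
    )

    def sgd_step(model, diff, lr=0.1, scale=1.0):
        # Attributes are supplied per call instead of being baked into the op.
        flow._C.dispatch_sgd_update(
            sgd_op, (model, diff), learning_rate=lr, scale=scale
        )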
-""" - -import os -import unittest - -import oneflow.compatible.single_client.unittest -from oneflow.compatible import single_client as flow - - -@unittest.skipIf(flow.sysconfig.has_rpc_backend_grpc() == False, "lacks grpc") -@flow.unittest.skip_unless_1n4d() -@unittest.skipIf( - os.getenv("ONEFLOW_TEST_GITHUB_HOSTED"), - "this will fail because github hosted VM has only two CPU cores", -) -class TestMultiProcess(flow.unittest.TestCase): - def test_multi_process(test_case): - flow.config.gpu_device_num(4) - func_config = flow.FunctionConfig() - func_config.concurrency_width(1) - - @flow.global_function() - def Foo(): - with flow.scope.placement("gpu", "0:0-3"): - x = flow.get_variable( - "x", - shape=(2, 5), - dtype=flow.float, - initializer=flow.random_uniform_initializer(minval=0, maxval=1), - trainable=False, - ) - return x - - of_ret = Foo().get() - test_case.assertEqual(of_ret.numpy().shape, (2, 5)) - - def test_worker_to_master_communication(test_case): - flow.config.gpu_device_num(4) - func_config = flow.FunctionConfig() - func_config.concurrency_width(1) - - @flow.global_function() - def Foo(): - with flow.scope.placement("gpu", "0:0"): - x = flow.get_variable( - "x", - shape=(2, 5), - dtype=flow.float, - initializer=flow.random_uniform_initializer(minval=0, maxval=1), - trainable=False, - ) - with flow.scope.placement("gpu", "0:3"): - y = flow.get_variable( - "y", - shape=(2, 5), - dtype=flow.float, - initializer=flow.constant_initializer(0), - trainable=False, - ) - flow.assign(y, x) - return y - - of_ret = Foo().get() - test_case.assertEqual(of_ret.numpy().shape, (2, 5)) - - def test_worker_to_worker_communication(test_case): - flow.config.gpu_device_num(4) - func_config = flow.FunctionConfig() - func_config.concurrency_width(1) - - @flow.global_function() - def Foo(): - with flow.scope.placement("gpu", "0:1"): - x = flow.get_variable( - "x", - shape=(2, 5), - dtype=flow.float, - initializer=flow.random_uniform_initializer(minval=0, maxval=1), - trainable=False, - ) - with flow.scope.placement("gpu", "0:2"): - y = flow.get_variable( - "y", - shape=(2, 5), - dtype=flow.float, - initializer=flow.constant_initializer(0), - trainable=False, - ) - flow.assign(y, x) - return y - - of_ret = Foo().get() - test_case.assertEqual(of_ret.numpy().shape, (2, 5)) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/oneflow/compatible/single_client/test/xrt/test_layer_norm_param_grad.py b/python/oneflow/compatible/single_client/test/xrt/test_layer_norm_param_grad.py index 03090422f85..a6fd65d44ef 100644 --- a/python/oneflow/compatible/single_client/test/xrt/test_layer_norm_param_grad.py +++ b/python/oneflow/compatible/single_client/test/xrt/test_layer_norm_param_grad.py @@ -25,35 +25,37 @@ config = flow.function_config() -def make_job(shape, gamma_shape, params_axis, dtype=flow.float32): +def make_job(shape, mean_shape, params_axis, dtype=flow.float32): config.use_xla_jit(False) config.use_tensorrt(False) @flow.global_function(config) def layer_norm_param_grad_job( dy=flow.FixedTensorDef(shape, dtype=dtype), - norm=flow.FixedTensorDef(shape, dtype=dtype), - gamma=flow.FixedTensorDef(gamma_shape, dtype=dtype), + x=flow.FixedTensorDef(shape, dtype=dtype), + mean=flow.FixedTensorDef(mean_shape, dtype=dtype), + inv_variance=flow.FixedTensorDef(mean_shape, dtype=dtype), ): return flow.layers.layer_norm_param_grad( - dy, norm, gamma, begin_params_axis=params_axis + dy, x, mean, inv_variance, begin_params_axis=params_axis ) return layer_norm_param_grad_job -def make_xla_job(shape, 
gamma_shape, params_axis, dtype=flow.float32): +def make_xla_job(shape, mean_shape, params_axis, dtype=flow.float32): config.use_xla_jit(True) config.use_tensorrt(False) @flow.global_function(config) def xla_layer_norm_param_grad_job( dy=flow.FixedTensorDef(shape, dtype=dtype), - norm=flow.FixedTensorDef(shape, dtype=dtype), - gamma=flow.FixedTensorDef(gamma_shape, dtype=dtype), + x=flow.FixedTensorDef(shape, dtype=dtype), + mean=flow.FixedTensorDef(mean_shape, dtype=dtype), + inv_variance=flow.FixedTensorDef(mean_shape, dtype=dtype), ): return flow.layers.layer_norm_param_grad( - dy, norm, gamma, begin_params_axis=params_axis + dy, x, mean, inv_variance, begin_params_axis=params_axis ) return xla_layer_norm_param_grad_job @@ -61,26 +63,19 @@ def xla_layer_norm_param_grad_job( @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestLayerNormParamGrad(unittest.TestCase): - def _test_body(self, dy, norm, gamma, params_axis, dtype=np.float32): - f1 = make_job(dy.shape, gamma.shape, params_axis, dtype=flow.float32) - f2 = make_xla_job(dy.shape, gamma.shape, params_axis, dtype=flow.float32) - (d_norm1, d_beta1, d_gamma1) = f1(dy, norm, gamma).get() - (d_norm2, d_beta2, d_gamma2) = f2(dy, norm, gamma).get() - print("normalize diff:") - print(" without xla: ", d_norm1) - print(" with xla: ", d_norm2) + def _test_body(self, dy, x, mean, inv_variance, params_axis, dtype=np.float32): + f1 = make_job(dy.shape, mean.shape, params_axis, dtype=flow.float32) + f2 = make_xla_job(dy.shape, mean.shape, params_axis, dtype=flow.float32) + (d_beta1, d_gamma1) = f1(dy, x, mean, inv_variance).get() + (d_beta2, d_gamma2) = f2(dy, x, mean, inv_variance).get() print("beta diff:") print(" without xla: ", d_beta1) print(" with xla: ", d_beta2) print("gamma diff:") print(" without xla: ", d_gamma1) print(" with xla: ", d_gamma2) - self.assertTrue(d_norm1.shape, d_norm2.shape) self.assertTrue(d_beta1.shape, d_beta2.shape) self.assertTrue(d_gamma1.shape, d_gamma2.shape) - self.assertTrue( - np.allclose(d_norm1.numpy(), d_norm2.numpy(), rtol=0.001, atol=1e-05) - ) self.assertTrue( np.allclose(d_beta1.numpy(), d_beta2.numpy(), rtol=0.001, atol=1e-05) ) @@ -91,25 +86,27 @@ def _test_body(self, dy, norm, gamma, params_axis, dtype=np.float32): def _test_ones_body(self, shape, params_axis=-1, dtype=np.float32): dy = np.ones(shape, dtype=dtype) - norm = np.ones(shape, dtype=dtype) + x = np.ones(shape, dtype=dtype) if params_axis < 0: params_axis += len(shape) - gamma_shape = shape[params_axis:] - if len(gamma_shape) == 0: - gamma_shape = [1] - gamma = np.ones(gamma_shape, dtype=dtype) - self._test_body(dy, norm, gamma, params_axis, dtype=dtype) + mean_shape = shape[params_axis:] + if len(mean_shape) == 0: + mean_shape = [1] + mean = np.ones(mean_shape, dtype=dtype) + inv_variance = np.ones(mean_shape, dtype=dtype) + self._test_body(dy, x, mean, inv_variance, params_axis, dtype=dtype) def _test_random_body(self, shape, params_axis=-1, dtype=np.float32): dy = np.random.random(shape).astype(dtype) - norm = np.random.random(shape).astype(dtype) + x = np.random.random(shape).astype(dtype) if params_axis < 0: params_axis += len(shape) - gamma_shape = shape[params_axis:] - if len(gamma_shape) == 0: - gamma_shape = [1] - gamma = np.random.random(gamma_shape).astype(dtype) - self._test_body(dy, norm, gamma, params_axis, dtype=dtype) + mean_shape = shape[params_axis:] + if len(mean_shape) == 0: + mean_shape = [1] + mean = np.random.random(mean_shape).astype(dtype) + inv_variance = 
np.random.random(mean_shape).astype(dtype) + self._test_body(dy, x, mean, inv_variance, params_axis, dtype=dtype) def test_ones_input(self): self._test_ones_body((1, 10)) diff --git a/python/oneflow/distributed/launch.py b/python/oneflow/distributed/launch.py index ccf155f4f31..b058bfaa6fc 100644 --- a/python/oneflow/distributed/launch.py +++ b/python/oneflow/distributed/launch.py @@ -156,13 +156,23 @@ def main(): sig_names = {2: "SIGINT", 15: "SIGTERM"} last_return_code = None + # set killing flag to make sure killing signal only executed once + kill_flag = True + def sigkill_handler(signum, frame): + nonlocal kill_flag + if not kill_flag: + return for process in processes: print(f"Killing subprocess {process.pid}") - try: - process.kill() - except Exception: - pass + kill_flag = False + try: + # Note: use os.kill or process.kill() may only kill current process + # use killpg will kill(use signal) this process and all sub-processes + # if orphan sub-processes still exist, use signal.SIGKILL instead. + os.killpg(os.getpid(), signal.SIGTERM) + except Exception: + pass if last_return_code is not None: raise subprocess.CalledProcessError( returncode=last_return_code, cmd=cmd diff --git a/python/oneflow/env.py b/python/oneflow/env.py index 23730d5a609..6a7763a84e1 100644 --- a/python/oneflow/env.py +++ b/python/oneflow/env.py @@ -28,11 +28,19 @@ def get_local_rank(): + """Returns the local rank of current machine. + Local rank is not globally unique. It is only unique per process on a machine. + + Returns: + The the local rank of process on current machine. + + """ return oneflow._oneflow_internal.GetLocalRank() def get_rank(): """Returns the rank of current process group. + Rank is globally unique, range of which is [0, world_size). Returns: The rank of the process group. @@ -42,6 +50,12 @@ def get_rank(): def get_node_size(): + """Returns the number of machines in the current process group. + + Returns: + The the number of machines in the process group. + + """ return oneflow._oneflow_internal.GetNodeSize() @@ -56,4 +70,10 @@ def get_world_size(): def is_multi_client(): + """Returns whether it is currently in multi client mode. + + Returns: + True if currently in multi client mode, otherwise returns Flase. + + """ return oneflow._oneflow_internal.IsMultiClient() diff --git a/python/oneflow/framework/check_point_v2.py b/python/oneflow/framework/check_point_v2.py index 211dfc1459c..1bb32572b4e 100644 --- a/python/oneflow/framework/check_point_v2.py +++ b/python/oneflow/framework/check_point_v2.py @@ -29,6 +29,8 @@ import oneflow.core.framework.variable_meta_info_pb2 as variable_meta_info_pb import oneflow.framework.dtype as dtype_util import oneflow.framework.id_util as id_util +from oneflow.framework.tensor import Tensor +import oneflow.nn.graph.graph as graph_util import pickle SNAPSHOT_DONE_FILENAME = "snapshot_done" @@ -120,10 +122,6 @@ def _save_tensor_to_disk(tensor: "oneflow.Tensor", dir_name: Union[str, Path]) - ValueContainer = Union[FileBackendVariableBlob, np.ndarray, "oneflow.Tensor"] -def _ElemCnt(shape): - return np.prod(shape).astype(int).item() - - def _LoadSingleVariable( path: Optional[str], consistent_src_rank: Optional[int] = None ) -> "flow.Tensor": @@ -205,6 +203,11 @@ def tensor_setstate(self, pickle_dict): ) +def RegisterMethods(): + Tensor.__setstate__ = tensor_setstate + Tensor.__getstate__ = tensor_getstate + + def legacy_load( path: Union[str, Path], consistent_src_rank: Optional[int] = None, ) -> Dict[str, "flow.Tensor"]: @@ -302,6 +305,22 @@ def save( disk I/O. 
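For the env helpers documented above, a plain single-process run yields the degenerate values, which makes a quick sanity check:

    import oneflow as flow

    print(flow.env.get_rank())         # 0
    print(flow.env.get_local_rank())   # 0
    print(flow.env.get_world_size())   # 1
    print(flow.env.is_multi_client())  # depends on how the process was launched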
""" path: Path = Path(path) + + if isinstance(obj, graph_util.Graph): + graph: graph_util.Graph = obj + if not graph._is_compiled: + raise RuntimeError("graph must be compiled first.") + + path.mkdir(exist_ok=True) + + serialized_job = str(text_format.MessageToString(graph._forward_job_proto)) + oneflow._oneflow_internal.nn.graph.SaveJobToIR(serialized_job, str(path)) + + for x in graph._state(): + _save_tensor_to_disk(x.origin, path / f"{x.name_prefix}{x.name}") + + return + obj = {"protocol_version": PROTOCOL_VERSION, "data": obj} with tensor_pickling_context(path, consistent_dst_rank): pickled_bytes = pickle.dumps(obj) @@ -312,11 +331,5 @@ def save( pickle_path.write_bytes(pickled_bytes) -def generate_values_by_initializer(initializer, shape, dtype): - np_dtype = np.dtype(dtype_util.convert_oneflow_dtype_to_numpy_dtype(dtype)) - length = _ElemCnt(shape) - return np.array(initializer(length)).astype(np_dtype).reshape(shape) - - save_load_path = None consistent_src_dsk_rank = None diff --git a/python/oneflow/framework/config_util.py b/python/oneflow/framework/config_util.py index 1ab8feec1ca..e586e4cac29 100644 --- a/python/oneflow/framework/config_util.py +++ b/python/oneflow/framework/config_util.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +import os import traceback import oneflow._oneflow_internal @@ -320,6 +321,8 @@ def enable_tensor_float_32_compute(val=True): sess = session_ctx.GetDefaultSession() assert type(val) is bool sess.config_proto.resource.enable_tensor_float_32_compute = val + if not val: + os.environ["ONEFLOW_EP_CUDA_ENABLE_TF32_EXECUTION"] = "0" def api_enable_mem_chain_merge(val: bool = True) -> None: diff --git a/python/oneflow/framework/docstr/__init__.py b/python/oneflow/framework/docstr/__init__.py index 9f2b5b44a16..2416fb8940a 100644 --- a/python/oneflow/framework/docstr/__init__.py +++ b/python/oneflow/framework/docstr/__init__.py @@ -38,3 +38,7 @@ from .dataset import * from .bmm import * from .flatten import * +from .chunk import * +from .broadcast_like import * +from .arange import * +from .split import * diff --git a/python/oneflow/framework/docstr/arange.py b/python/oneflow/framework/docstr/arange.py new file mode 100644 index 00000000000..32f9c62caa5 --- /dev/null +++ b/python/oneflow/framework/docstr/arange.py @@ -0,0 +1,52 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import oneflow +from oneflow.framework.docstr.utils import add_docstr + +add_docstr( + oneflow.arange, + """ + oneflow.arange(start: int = 0, end, step: int = 1, dtype: Optional[oneflow._oneflow_internal.dtype] = None, device: Optional[Union[oneflow._oneflow_internal.device, str]] = None, placement: Optional[oneflow._oneflow_internal.placement] = None, sbp: Optional[Union[oneflow._oneflow_internal.sbp.sbp, List[oneflow._oneflow_internal.sbp.sbp]]] = None, requires_grad: bool = False) + + Returns a 1-D tensor of size :math:`\\left\\lfloor \\frac{\\text{end} - \\text{start}}{\\text{step}} \\right\\rfloor + 1` + with values from :attr:`start` to :attr:`end` with step :attr:`step`. Step is + the gap between two values in the tensor. + + .. math:: + \\text{out}_{i+1} = \\text{out}_i + \\text{step}. + + Args: + start (int): the starting value for the set of points. Default: ``0``. + end (int): the ending value for the set of points + step (int): the gap between each pair of adjacent points. Default: ``1``. + + Keyword args: + dtype(flow.dtype, optional): If `dtype` is not given, the `dtype` is inferred to be `flow.int64`. + device(flow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor. + requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: `False`. + + For example: + + .. code-block:: python + + >>> import oneflow as flow + + >>> y = flow.arange(0, 5) + >>> y + tensor([0, 1, 2, 3, 4], dtype=oneflow.int64) + + """, +) diff --git a/python/oneflow/framework/docstr/array_ops.py b/python/oneflow/framework/docstr/array_ops.py index 196d2413b02..037b6bb7f8c 100644 --- a/python/oneflow/framework/docstr/array_ops.py +++ b/python/oneflow/framework/docstr/array_ops.py @@ -16,6 +16,36 @@ import oneflow from oneflow.framework.docstr.utils import add_docstr +add_docstr( + oneflow.diagonal, + r""" + oneflow.diagonal(input, offset, dim1, dim2) -> Tensor + + Returns a partial view of input with the its diagonal elements with respect to dim1 and dim2 + appended as a dimension at the end of the shape. + + Args: + input (Tensor): the input tensor.Must be at least 2-dimensional. + offset (Optional[int], 0): which diagonal to consider. Default: 0 (main diagonal) + dim1 (Optional[int], 0): first dimension with respect to which to take diagonal. Default: 0 + dim2 (Optional[int], 1): second dimension with respect to which to take diagonal. Default: 1 + + Returns: + oneflow.Tensor: the output Tensor. + + For example: + + .. code-block:: python + + >>> import oneflow as flow + + >>> input = flow.randn(2, 3, 4) + >>> output = flow.diagonal(input, offset=1, dim1=1, dim2=0) + >>> output.shape + oneflow.Size([4, 1]) + """, +) + add_docstr( oneflow.diag, r""" diff --git a/python/oneflow/nn/modules/broadcast_like.py b/python/oneflow/framework/docstr/broadcast_like.py similarity index 51% rename from python/oneflow/nn/modules/broadcast_like.py rename to python/oneflow/framework/docstr/broadcast_like.py index 1e3d42d9b08..4ef6c3b542b 100644 --- a/python/oneflow/nn/modules/broadcast_like.py +++ b/python/oneflow/framework/docstr/broadcast_like.py @@ -13,29 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
""" -from typing import Optional, Sequence +import oneflow +from oneflow.framework.docstr.utils import add_docstr -import oneflow as flow -from oneflow.nn.module import Module - - -def _calc_broadcast_axes(x, like_tensor): - num_prepend = len(like_tensor.shape) - len(x.shape) - prepend_shape = [1] * num_prepend + list(x.shape) - broadcast_axes = [x for x in range(num_prepend)] - for i in range(num_prepend, len(prepend_shape)): - if prepend_shape[i] != like_tensor.shape[i]: - if prepend_shape[i] != 1: - raise RuntimeError( - f"output with shape {x.shape} doesn't match the broadcast shape {like_tensor.shape}" - ) - else: - broadcast_axes.append(i) - return tuple(broadcast_axes) - - -def broadcast_like_op(x, like_tensor, broadcast_axes: Optional[Sequence] = None): - """This operator broadcast tensor `x` to `like_tensor` according to the broadcast_axes. +add_docstr( + oneflow.broadcast_like, + """ + This operator broadcast tensor `x` to `like_tensor` according to the broadcast_axes. Args: x (Tensor): The input Tensor. @@ -56,10 +40,6 @@ def broadcast_like_op(x, like_tensor, broadcast_axes: Optional[Sequence] = None) >>> broadcast_tensor = flow.broadcast_like(x, like_tensor, broadcast_axes=[1, 2]) >>> broadcast_tensor.shape oneflow.Size([3, 4, 5]) - - """ - if broadcast_axes is None: - broadcast_axes = _calc_broadcast_axes(x, like_tensor) - else: - broadcast_axes = broadcast_axes - return flow._C.broadcast_like(x, like_tensor, broadcast_axes=broadcast_axes) + + """, +) diff --git a/python/oneflow/nn/modules/chunk.py b/python/oneflow/framework/docstr/chunk.py similarity index 71% rename from python/oneflow/nn/modules/chunk.py rename to python/oneflow/framework/docstr/chunk.py index d82b3121d50..0e5cd1b7660 100644 --- a/python/oneflow/nn/modules/chunk.py +++ b/python/oneflow/framework/docstr/chunk.py @@ -13,14 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. """ -import oneflow as flow -from oneflow.framework.tensor import register_tensor_op -from oneflow.ops.array_ops import parse_slice_tuple_list +import oneflow +from oneflow.framework.docstr.utils import add_docstr - -@register_tensor_op("chunk") -def chunk_op(input, chunks, dim: int = 0): - """Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor. Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by chunks. +add_docstr( + oneflow.chunk, + """Splits a tensor into a specific number of chunks. Each chunk is a view of the input tensor. Last chunk will be bigger if the tensor size along the given dimension dim is not divisible by chunks. Args: input (oneflow.Tensor): The tensor to split. 
@@ -59,16 +57,5 @@ def chunk_op(input, chunks, dim: int = 0): >>> of_out_shape [(5, 3, 6, 2), (5, 3, 6, 2), (5, 3, 6, 2), (5, 3, 6, 3)] - """ - split_size = input.shape[dim] // chunks - if split_size * chunks != input.shape[dim]: - split_size = [split_size] * (chunks - 1) + [ - input.shape[dim] - split_size * (chunks - 1) - ] - return flow._C.split(input, split_size=split_size, dim=dim) - - -if __name__ == "__main__": - import doctest - - doctest.testmod(raise_on_error=True) + """, +) diff --git a/python/oneflow/framework/docstr/conv.py b/python/oneflow/framework/docstr/conv.py index 220afe2d834..3655e794b81 100644 --- a/python/oneflow/framework/docstr/conv.py +++ b/python/oneflow/framework/docstr/conv.py @@ -47,11 +47,10 @@ >>> import oneflow as flow >>> import numpy as np - >>> import oneflow.nn as nn >>> input = flow.tensor(np.random.randn(33, 16, 30), dtype=flow.float32) >>> filters = flow.tensor(np.random.randn(20, 16, 5), dtype=flow.float32) - >>> out = nn.functional.conv1d(input, filters,stride=[1], padding=[0], dilation=[1]) + >>> out = flow._C.conv1d(input, filters,stride=[1], padding=[0], dilation=[1], channel_pos="channels_first") """, ) add_docstr( diff --git a/python/oneflow/framework/docstr/math_ops.py b/python/oneflow/framework/docstr/math_ops.py index 8dc6c91de65..e45b49bf555 100644 --- a/python/oneflow/framework/docstr/math_ops.py +++ b/python/oneflow/framework/docstr/math_ops.py @@ -33,7 +33,7 @@ >>> x = flow.tensor(np.array([-1, 2, -3, 4]).astype(np.float32)) >>> flow.abs(x) tensor([1., 2., 3., 4.], dtype=oneflow.float32) - + """, ) @@ -51,7 +51,7 @@ >>> import numpy as np >>> import oneflow as flow - + # element-wise add >>> x = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> y = flow.tensor(np.random.randn(2,3), dtype=flow.float32) @@ -86,7 +86,7 @@ Args: input (Tensor): the input tensor. - + For example: .. code-block:: python @@ -99,7 +99,7 @@ oneflow.Size([4]) >>> output.numpy() array([-1., 1., 0., 0.], dtype=float32) - + >>> input1 = flow.tensor(np.array([[0.8, 1.0], [-0.6, 2.5]]), dtype=flow.float32) >>> output1 = input1.floor() >>> output1.shape @@ -118,18 +118,18 @@ .. math:: out = \frac{input}{other} - + Args: input (Union[int, float, oneflow.Tensor]): input. other (Union[int, float, oneflow.Tensor]): other. - + For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow - + # element-wise divide >>> input = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32) @@ -148,7 +148,7 @@ >>> input = flow.tensor(np.random.randn(1,1), dtype=flow.float32) >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> out = flow.div(input,other).numpy() - >>> out.shape + >>> out.shape (2, 3) """, @@ -157,19 +157,19 @@ add_docstr( oneflow.mul, r"""Computes the multiplication of input by other for each element, scalar and broadcast promotation are supported. - + The formula is: .. math:: \text{out}_i = \text{input}_i \times \text{other}_i - + For example: .. 
code-block:: python >>> import numpy as np >>> import oneflow as flow - + # element-wise multiply >>> input = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32) @@ -188,7 +188,7 @@ >>> input = flow.tensor(np.random.randn(1,1), dtype=flow.float32) >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> out = flow.mul(input,other).numpy() - >>> out.shape + >>> out.shape (2, 3) """, @@ -205,7 +205,7 @@ >>> import numpy as np >>> import oneflow as flow - + >>> x = flow.tensor(np.array([[1, 2, 3], [4, 5, 6]]), dtype=flow.float32) >>> out = flow.reciprocal(x) >>> out.numpy() @@ -221,14 +221,14 @@ .. math:: out = input - other - + For example: .. code-block:: python >>> import numpy as np >>> import oneflow as flow - + # element-wise subtract >>> input = flow.tensor(np.random.randn(2,3), dtype=flow.float32) >>> other = flow.tensor(np.random.randn(2,3), dtype=flow.float32) @@ -335,14 +335,14 @@ For example: .. code-block:: python - + >>> import oneflow as flow >>> import numpy as np >>> input = flow.tensor(np.array([0.5, 0.6, 0.7]), dtype=flow.float32) >>> output = flow.atan(input) >>> output.shape oneflow.Size([3]) - + """, ) @@ -351,24 +351,24 @@ r"""Returns a new tensor with the ceil of the elements of :attr:`input`, the smallest integer greater than or equal to each element. - The equation is: + The equation is: .. math:: \text{out}_{i} = \left\lceil \text{input}_{i} \right\rceil = \left\lfloor \text{input}_{i} \right\rfloor + 1 Args: input (oneflow.Tensor): A Tensor. - + Returns: oneflow.Tensor: The result Tensor - For example: + For example: - .. code-block:: python - + .. code-block:: python + >>> import oneflow as flow - >>> import numpy as np + >>> import numpy as np >>> x = flow.tensor(np.array([0.1, -2, 3.4]).astype(np.float32)) >>> y = flow.ceil(x) >>> y.shape @@ -399,27 +399,27 @@ add_docstr( oneflow.negative, r"""This operator computes the negative value of Tensor. - + Args: input (oneflow.Tensor): A Tensor - + Returns: oneflow.Tensor: The result Tensor - + For example: - + .. code-block:: python >>> import numpy as np >>> import oneflow as flow - + >>> input = flow.tensor( ... np.array([1.0, -1.0, 2.3]).astype(np.float32), dtype=flow.float32 ... ) >>> out = flow.negative(input) >>> out tensor([-1.0000, 1.0000, -2.3000], dtype=oneflow.float32) - + """, ) @@ -468,7 +468,7 @@ >>> import numpy as np >>> import oneflow as flow - + >>> x = flow.tensor(np.array([1, 2, 3]).astype(np.float32), dtype=flow.float32) >>> y = flow.exp(x) >>> y @@ -591,13 +591,13 @@ r"""Returns a new tensor with the sine of the elements of :attr:`input`. sin(x: Tensor) -> Tensor - + .. math:: \text{y}_{i} = \sin(\text{x}_{i}) - + Args: x (Tensor): the input tensor. - + For example: .. code-block:: python @@ -607,12 +607,12 @@ >>> y1 = flow.sin(x1) >>> y1 tensor([-0.5194, 0.1343, -0.4032, -0.2712], dtype=oneflow.float32) - + >>> x2 = flow.tensor(np.array([-1.4, 2.6, 3.7]).astype(np.float32), device=flow.device('cuda')) >>> y2 = flow.sin(x2) >>> y2 tensor([-0.9854, 0.5155, -0.5298], device='cuda:0', dtype=oneflow.float32) - + """, ) @@ -734,7 +734,7 @@ oneflow.cos, r""" Returns a new tensor with the cosine of the elements of :attr:`input`. - + .. 
math:: \text{out}_{i} = \cos(\text{input}_{i}) @@ -771,7 +771,7 @@ >>> import numpy as np >>> import oneflow as flow - + >>> arr = np.array([ 0.1632, 1.1835, -0.6979, -0.7325]) >>> input = flow.tensor(arr, dtype=flow.float32) >>> output = flow.cosh(input).numpy() @@ -792,15 +792,15 @@ x (oneflow.Tensor): A Tensor Returns: - oneflow.Tensor: The result Tensor - + oneflow.Tensor: The result Tensor + For example: .. code-block:: python >>> import oneflow as flow >>> import numpy as np - + >>> x = flow.tensor(np.array([0, -1., 10.]), dtype=flow.float32) >>> out = flow.erf(x) >>> out.shape @@ -830,7 +830,7 @@ add_docstr( oneflow.erfc, - r"""Computes the complementary error function of each element of input. The complementary error + r"""Computes the complementary error function of each element of input. The complementary error function is defined as follows: .. math:: @@ -848,7 +848,7 @@ >>> import oneflow as flow >>> import numpy as np - + >>> x = flow.tensor(np.array([0, -1., 10.]), dtype=flow.float32) >>> out = flow.erfc(x) >>> out @@ -859,7 +859,7 @@ >>> out tensor([[1.0000e+00, 1.8427e+00, 2.8026e-45], [1.5375e-12, 4.1838e-23, 2.5790e-01]], dtype=oneflow.float32) - + """, ) @@ -869,21 +869,21 @@ of :attr:`input`. - The equation is: + The equation is: .. math:: y_{i} = e^{x_{i}} - 1 Args: input (oneflow.Tensor): A Tensor. - + Returns: oneflow.Tensor: The result Tensor - For example: + For example: + + .. code-block:: python - .. code-block:: python - >>> import oneflow as flow >>> import numpy as np >>> x = flow.tensor(np.array([1, 2, 3]).astype(np.float32)) @@ -945,13 +945,13 @@ oneflow.log, r""" Returns a new tensor with the natural logarithm of the elements of :attr:`input`. - + .. math:: y_{i} = \log_{e} (x_{i}) Args: input (Tensor): the input tensor. - + For example: .. code-block:: python @@ -1038,7 +1038,7 @@ >>> import oneflow as flow >>> import numpy as np - + >>> x = flow.tensor(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), dtype=flow.float32) >>> out = flow.pow(x, 2) >>> out @@ -1049,7 +1049,7 @@ >>> out = flow.pow(x, y) >>> out tensor([ 1., 4., 27., 256.], dtype=oneflow.float32) - + """, ) @@ -1070,7 +1070,7 @@ >>> import oneflow as flow >>> import numpy as np - + >>> a = flow.tensor(np.array([1.0, 2.0, 3.0]), dtype=flow.float32) >>> out = flow.rsqrt(a).numpy() >>> out @@ -1094,7 +1094,7 @@ >>> import oneflow as flow >>> import numpy as np - + >>> arr = np.array([1.0, 2.0, 3.0]) >>> input = flow.tensor(arr, dtype=flow.float32) >>> output = flow.sqrt(input).numpy() @@ -1120,7 +1120,7 @@ >>> import oneflow as flow >>> import numpy as np - + >>> arr = np.array([1.0, 2.0, 3.0]) >>> input = flow.tensor(arr, dtype=flow.float32) >>> output = flow.square(input).numpy() @@ -1190,8 +1190,8 @@ dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, reduce over all of them. - If keepdim is True, the output tensor is of the same size as input except in - the dimension(s) dim where it is of size 1. Otherwise, dim is squeezed, + If keepdim is True, the output tensor is of the same size as input except in + the dimension(s) dim where it is of size 1. Otherwise, dim is squeezed, resulting in the output tensor having 1 (or len(dim)) fewer dimension(s). 
If :attr:`unbiased` is ``False``, then the standard-deviation will be calculated @@ -1209,7 +1209,7 @@ >>> import oneflow as flow >>> import numpy as np - + >>> arr = np.array([1.0, 2.0, 3.0]) >>> input = flow.tensor(arr) >>> output = flow.std(input, dim=0).numpy() @@ -1223,8 +1223,8 @@ oneflow.var, r"""Returns the variance of each row of the `input` tensor in the given dimension `dim`. - If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` - where it is of size 1. Otherwise, dim is squeezed (see `flow.squeeze()`), resulting in the output + If `keepdim` is `True`, the output tensor is of the same size as `input` except in the dimension(s) `dim` + where it is of size 1. Otherwise, dim is squeezed (see `flow.squeeze()`), resulting in the output tensor having 1 (or `len(dim)`) fewer dimension(s). Args: @@ -1242,7 +1242,7 @@ >>> import numpy as np >>> import oneflow as flow - + >>> input = flow.tensor(np.random.randn(2, 3, 4, 5)) >>> output = flow.var(input, 1, True) @@ -1252,7 +1252,7 @@ add_docstr( oneflow.logical_not, r""" - Computes the element-wise logical NOT of the given input tensors. + Computes the element-wise logical NOT of the given input tensors. Zeros are treated as False and nonzeros are treated as True. Args: input (oneflow.Tensor): The input Tensor @@ -1272,20 +1272,20 @@ >>> out = flow.logical_not(input) >>> out tensor([0, 1, 0], dtype=oneflow.int8) - + """, ) add_docstr( oneflow.dot, r"""This operator computes the dot product of tensor input and other. - + The equation is: - - $$ - ā€‹ \\sum_{i=1}^{n}(x[i] * y[i]) + + $$ + \\sum_{i=1}^{n}(x[i] * y[i]) $$ - + Args: input (Tensor): first tensor in the dot product. other (Tensor): second tensor in the dot product. @@ -1301,7 +1301,7 @@ >>> import oneflow as flow >>> flow.dot(flow.Tensor([2, 3]), flow.Tensor([2, 1])) tensor(7., dtype=oneflow.float32) - + """, ) @@ -1315,16 +1315,16 @@ Args: input (Tensor): the input tensor. - source (int or a list): Original positions of the dims to move. These must be unique. + source (int or a list): Original positions of the dims to move. These must be unique. destination (int or a list): Destination positions for each of the original dims. These must also be unique. - + Returns: oneflow.Tensor: the output Tensor. For example: .. code-block:: python - + >>> import oneflow as flow >>> import numpy as np @@ -1337,3 +1337,71 @@ oneflow.Size([3, 4, 2, 5]) """, ) + +add_docstr( + oneflow.eye, + """oneflow.eye(n, m, *, device=None, requires_grad=False, placement=None, sbp) -> Tensor + + This operator creates a 2-D Tensor with ones on the diagonal and zeros elsewhere. + + Args: + n (int): the number of rows. + m (int, optional): the number of colums with default being n. Defaults to None. + + Keyword args: + device(Union[flow.device, str], optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor. + requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: `False`. + placement(oneflow._oneflow_internal.placement, optional): The placement attribute allows you to specify which physical device the tensor is stored on. + sbp(Union[oneflow._oneflow_internal.sbp.sbp, List[oneflow._oneflow_internal.sbp.sbp]], optional): When creating a consistent tensor, specify the SBP of the tensor. + + Returns: + oneflow.Tensor: The result tensor with ones on the diagonal and zeros elsewhere. + + For example: + + .. 
code-block:: python + + >>> import oneflow as flow + >>> out = flow.eye(3, 3) + >>> out + tensor([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]], dtype=oneflow.float32) + >>> out = flow.eye(3, 3, device="cuda") + >>> out + tensor([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 1.]], device='cuda:0', dtype=oneflow.float32) + """, +) + +add_docstr( + oneflow.cumsum, + r"""This operator computes the cumulative sum of input elements in the given dimension. + + The equation is: + + $$ + y_{i}=x_{0}+x_{1}+...+x_{i} + $$ + + Args: + input (Tensor): the input ND tensor. + dim (int): the dimension to do cumsum, valid range is [-N, N-1), N is tensor's dimensions + + Returns: + oneflow.Tensor: The result tensor with cumsum result. + + For example: + + .. code-block:: python + + >>> import oneflow as flow + >>> input=flow.ones(3,3) + >>> dim=1 + >>> flow.cumsum(input,dim) + tensor([[1., 2., 3.], + [1., 2., 3.], + [1., 2., 3.]], dtype=oneflow.float32) + """, +) diff --git a/python/oneflow/framework/docstr/meshgrid.py b/python/oneflow/framework/docstr/meshgrid.py index ea8611f1396..d237f08dd14 100644 --- a/python/oneflow/framework/docstr/meshgrid.py +++ b/python/oneflow/framework/docstr/meshgrid.py @@ -28,7 +28,11 @@ Args: tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be - treated as tensors of size :math:`(1,)` automatically + treated as tensors of size :math:`(1,)` automatically. + indexing ((string, optional): the indexing mode, either "xy" or "ij", defaults to "ij". + If "ij" is selected, the dimensions are in the same order as the cardinality of the inputs. + If "xy" is selected, the first dimension corresponds to the cardinality of + the second input and the second dimension corresponds to the cardinality of the first input. Returns: seq (sequence of Tensors): If the input has :math:`k` tensors of size @@ -42,11 +46,11 @@ >>> import numpy as np >>> import oneflow as flow - >>> input1 = flow.tensor(np.array([1, 2, 3]), dtype=flow.float32) + >>> input1 = flow.tensor(np.array([2, 2, 3]), dtype=flow.float32) >>> input2 = flow.tensor(np.array([4, 5, 6]), dtype=flow.float32) >>> of_x, of_y = flow.meshgrid(input1, input2) >>> of_x - tensor([[1., 1., 1.], + tensor([[2., 2., 2.], [2., 2., 2.], [3., 3., 3.]], dtype=oneflow.float32) >>> of_y diff --git a/python/oneflow/framework/docstr/norm.py b/python/oneflow/framework/docstr/norm.py index 833e036ea8a..a3a0b0048af 100644 --- a/python/oneflow/framework/docstr/norm.py +++ b/python/oneflow/framework/docstr/norm.py @@ -263,3 +263,81 @@ """, ) + +add_docstr( + oneflow._C.normalize, + """nn.functional.normalize(input: Tensor, p: float=2.0, dim: int=0, epsilon: float=1e-12) -> Tensor + + Performs :math:`L_p` normalization of inputs over specified dimension + + For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each + :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`dim` is transformed as: + + .. math:: + v = \\frac{v}{\max(\\lVert v \\rVert_p, \\epsilon)}. + + With the default arguments it uses the Euclidean norm over vectors along dimension :math:`1` for normalization. + + But note that the gradient calculation of the input tensor has different results on different frameworks + when `input.shape[dim] = 1`. + + Args: + input (oneflow.Tensor): input tensor of any shape + p (float): the exponent value in the norm formulation. Default: 2 + dim (int): the dimension to reduce. Default: 1 + eps (float): small value to avoid division by zero. Default: 1e-12 + + For example: + + .. 
code-block:: python + + >>> import oneflow as flow + >>> x = flow.tensor([[1, 2], [3, 4]], dtype=flow.float32) + >>> out = flow.nn.functional.normalize(x, 2, 0) + >>> out + tensor([[0.3162, 0.4472], + [0.9487, 0.8944]], dtype=oneflow.float32) + >>> out = flow.nn.functional.normalize(x, 2, 1) + >>> out + tensor([[0.4472, 0.8944], + [0.6000, 0.8000]], dtype=oneflow.float32) + + """, +) + +add_docstr( + oneflow._C.l2_normalize, + """nn.functional.l2_normalize(input: Tensor, dim: int=0, epsilon: float=1e-12) -> Tensor + + Use L2 norm to normalizes along dimension `dim` + + The equation is: + + .. math:: + out = \\frac{x}{max(\\sqrt{\\Sigma{x^2}}, \\epsilon)} + + Args: + input (oneflow.Tensor): Input Tensor + dim (int): The axis on which to apply L2 normalization. Defaults to 0. + epsilon (float): The epsilon value is used to avoid division by zero. Defaults to 1e-12. + + Returns: + oneflow.Tensor: The normalized Tensor + + For example: + + .. code-block:: python + + >>> import oneflow as flow + >>> x = flow.tensor([[1, 2], [3, 4]], dtype=flow.float32) + >>> out = flow.nn.functional.l2_normalize(x, 0) + >>> out + tensor([[0.3162, 0.4472], + [0.9487, 0.8944]], dtype=oneflow.float32) + >>> out = flow.nn.functional.l2_normalize(x, 1) + >>> out + tensor([[0.4472, 0.8944], + [0.6000, 0.8000]], dtype=oneflow.float32) + + """, +) diff --git a/python/oneflow/nn/modules/split.py b/python/oneflow/framework/docstr/split.py similarity index 80% rename from python/oneflow/nn/modules/split.py rename to python/oneflow/framework/docstr/split.py index cf65fc1012e..7b69312aa08 100644 --- a/python/oneflow/nn/modules/split.py +++ b/python/oneflow/framework/docstr/split.py @@ -13,15 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. """ -from typing import Union, List -import numpy as np +import oneflow +from oneflow.framework.docstr.utils import add_docstr -import oneflow as flow -from oneflow.framework.tensor import Tensor, register_tensor_op - - -@register_tensor_op("split") -def split_op(x, split_size_or_sections: Union[int, List[int]], dim: int = 0): +add_docstr( + oneflow.split, """Splits the tensor into chunks. If `split_size_or_sections` is an integer type, then x will be split into equally sized chunks (if possible). @@ -50,11 +46,5 @@ def split_op(x, split_size_or_sections: Union[int, List[int]], dim: int = 0): [4, 5], [6, 7], [8, 9]], dtype=oneflow.int64)) - """ - return flow._C.split(x, split_size=split_size_or_sections, dim=dim) - - -if __name__ == "__main__": - import doctest - - doctest.testmod(raise_on_error=True) + """, +) diff --git a/python/oneflow/framework/docstr/tensor.py b/python/oneflow/framework/docstr/tensor.py index c904053d180..f3a6cbd3b2b 100644 --- a/python/oneflow/framework/docstr/tensor.py +++ b/python/oneflow/framework/docstr/tensor.py @@ -51,6 +51,35 @@ """, ) +add_docstr( + oneflow.from_numpy, + r""" + Creates a ``Tensor`` from a ``numpy.ndarray``. + + The returned tensor and ndarray share the same memory. Modifications to the tensor + will be reflected in the ndarray and vice versa. + + It currently accepts ndarray with dtypes of numpy.float64, numpy.float32, numpy.float16, + numpy.int64, numpy.int32, numpy.int8, numpy.uint8. + + For example: + + .. 
code-block:: python + + >>> import oneflow as flow + >>> import numpy as np + >>> np_arr = np.arange(6).reshape(2, 3) + >>> t = flow.from_numpy(np_arr) + >>> t + tensor([[0, 1, 2], + [3, 4, 5]], dtype=oneflow.int64) + >>> np_arr[0, 0] = -1 + >>> t + tensor([[-1, 1, 2], + [ 3, 4, 5]], dtype=oneflow.int64) + """, +) + add_docstr( oneflow.Tensor.atan2, r""" @@ -337,6 +366,20 @@ """, ) +add_docstr( + oneflow.Tensor.chunk, + """ + See :func:`oneflow.chunk` + """, +) + +add_docstr( + oneflow.Tensor.split, + """ + See :func:`oneflow.split` + """, +) + add_docstr( oneflow.Tensor.cast, """ @@ -344,7 +387,6 @@ """, ) - add_docstr( oneflow.Tensor.diag, """ diff --git a/python/oneflow/framework/env_util.py b/python/oneflow/framework/env_util.py index e6d6d0f629b..fa5461f002f 100644 --- a/python/oneflow/framework/env_util.py +++ b/python/oneflow/framework/env_util.py @@ -337,10 +337,10 @@ def _FindFreePort(): def HasAllMultiClientEnvVars(): env_var_names = ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK", "LOCAL_RANK"] - has_all_env_vars = all([os.getenv(x) for x in env_var_names]) - if not has_all_env_vars: - has_at_least_one_env_var = any([os.getenv(x) for x in env_var_names]) - assert not has_at_least_one_env_var + env_var_values = [os.getenv(x) for x in env_var_names] + has_no_env_vars = not any(env_var_values) + has_all_env_vars = all(env_var_values) + assert has_no_env_vars or has_all_env_vars, list(zip(env_var_names, env_var_values)) return has_all_env_vars diff --git a/python/oneflow/framework/graph_build_util.py b/python/oneflow/framework/graph_build_util.py index c37d9a0af4a..479c33c7262 100644 --- a/python/oneflow/framework/graph_build_util.py +++ b/python/oneflow/framework/graph_build_util.py @@ -29,6 +29,8 @@ import oneflow.framework.session_context as session_context from oneflow.framework.tensor import Tensor +import oneflow._oneflow_internal._C as _C + lazy_mode = oneflow._oneflow_internal.lazy_mode @@ -137,8 +139,7 @@ def build_graph_input_arg(op_name, arg): input_op = oneflow._oneflow_internal.one.FeedInputOpExpr( op_name, input_conf, ["in_0"], ["out_0"] ) - attrs = oneflow._oneflow_internal.MutableCfgAttrMap() - lazy_arg = input_op.apply([arg], attrs)[0] + lazy_arg = _C.dispatch_feed_input(input_op, arg) return lazy_arg @@ -150,17 +151,14 @@ def build_graph_state(op_name, state_tensor, state_config): var_op = oneflow._oneflow_internal.one.FeedVariableOpExpr( op_name, var_conf, ["in_0"], ["out_0"] ) - - attrs = oneflow._oneflow_internal.MutableCfgAttrMap() + l2 = 0.0 if state_config is not None: - attr_l2 = user_op_attr_cfg.AttrValue() - attr_l2.set_at_double(state_config.l2) - attrs["l2"] = attr_l2 + l2 = state_config.l2 elif state_tensor.requires_grad: - attrs["l2"] = 0.0 + l2 = 0.0 assert isinstance(state_tensor, Tensor) - lazy_tensor = var_op.apply([state_tensor], attrs)[0] + lazy_tensor = _C.dispatch_feed_variable(var_op, state_tensor, l2=l2) return lazy_tensor @@ -175,8 +173,6 @@ def build_graph_output(op_name, out): output_op = oneflow._oneflow_internal.one.FetchOutputOpExpr( op_name, output_conf, ["in_0"], ["out_0"] ) - attrs = oneflow._oneflow_internal.MutableCfgAttrMap() - - fake_eager_out = output_op.apply([out], attrs)[0] + fake_eager_out = _C.dispatch_fetch_output(output_op, out) return fake_eager_out diff --git a/python/oneflow/framework/op_expr_util.py b/python/oneflow/framework/op_expr_util.py deleted file mode 100644 index 1421cbefc09..00000000000 --- a/python/oneflow/framework/op_expr_util.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -Copyright 2020 The OneFlow Authors. 
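The HasAllMultiClientEnvVars refactor above enforces an all-or-none rule for the launcher variables and reports the offending name/value pairs when the rule is broken; the predicate in isolation:

    import os

    names = ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK", "LOCAL_RANK"]
    values = [os.getenv(n) for n in names]
    # Either a complete multi-client environment or a plain single-process one;
    # a partial set trips the assert and surfaces the pairs below.
    assert not any(values) or all(values), list(zip(names, values))
    has_all_env_vars = all(values)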
All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" -import oneflow as flow -import oneflow._oneflow_internal -from oneflow.framework.attr_util import convert_to_user_attr_value - - -def user_op_expr_call(self, *args, **kwargs): - attrs = oneflow._oneflow_internal.MutableCfgAttrMap() - for (attr_name, attr_value) in kwargs.items(): - assert isinstance(attr_name, str) - attrs[attr_name] = convert_to_user_attr_value( - self.op_type_name, attr_name, attr_value - ) - try: - results = self.apply(args, attrs) - except flow._oneflow_internal.exception.Exception: - raise oneflow._oneflow_internal.exception.GetThreadLocalLastError() - return results - - -def RegisterMethod4UserOpExpr(): - oneflow._oneflow_internal.one.UserOpExpr.__call__ = user_op_expr_call diff --git a/python/oneflow/framework/register_class_method_util.py b/python/oneflow/framework/register_class_method_util.py index fce74ed0ca5..0687791e116 100644 --- a/python/oneflow/framework/register_class_method_util.py +++ b/python/oneflow/framework/register_class_method_util.py @@ -14,11 +14,11 @@ limitations under the License. """ import oneflow._oneflow_internal +import oneflow.framework.check_point_v2 as check_point_v2 import oneflow.framework.generator as generator -import oneflow.framework.op_expr_util as op_expr_util import oneflow.framework.tensor as tensor_util def RegisterMethod4Class(): tensor_util.RegisterMethods() - op_expr_util.RegisterMethod4UserOpExpr() + check_point_v2.RegisterMethods() diff --git a/python/oneflow/framework/tensor.py b/python/oneflow/framework/tensor.py index 33a1cd30eb5..34b4cb6ab26 100644 --- a/python/oneflow/framework/tensor.py +++ b/python/oneflow/framework/tensor.py @@ -15,8 +15,7 @@ """ import oneflow as flow from oneflow._oneflow_internal.exception import IndexException -import oneflow.framework.check_point_v2 as check_point_v2 -import oneflow.framework.tensor_str as tensor_str_util +import oneflow.framework.tensor_str as tensor_str import oneflow.ops.initializer_util as initializer_util import oneflow._oneflow_internal.lazy_mode as lazy_mode import oneflow.core.framework.variable_meta_info_pb2 as variable_meta_info_pb @@ -143,11 +142,11 @@ def _str(self): def _repr(self): - return tensor_str_util._gen_tensor_str(self) + return tensor_str._gen_tensor_str(self) def _meta_repr(self): - return tensor_str_util._gen_tensor_meta_str(self) + return tensor_str._gen_tensor_meta_str(self) def _eq(self, other): @@ -390,6 +389,10 @@ def _diag(self, diagonal=0): return flow.diag(self, diagonal=diagonal) +def _diagonal(self, offset=0, dim1=0, dim2=1): + return flow._C.diagonal(self, offset=offset, dim1=dim1, dim2=dim2) + + def _log1p(self): return flow.log1p(self) @@ -580,6 +583,14 @@ def _bmm(self, other): return flow.bmm(self, other) +def _chunk(self, chunks=None, dim=None): + return flow._C.chunk(self, chunks, dim) + + +def _split(self, split_size_or_sections=None, dim=None): + return flow._C.split(self, split_size_or_sections, dim) + + def _all(self, dim=None, keepdim=False): return flow.all(self, dim, keepdim) 
@@ -682,7 +693,7 @@ def _init_by_initializer_conf(tensor, initializer_conf, random_seed=None): shape = tuple(tensor.shape) initializer = initializer_util.GetInitializer(initializer_conf, random_seed, shape) - np_arr = check_point_v2.generate_values_by_initializer( + np_arr = initializer_util.generate_values_by_initializer( initializer, shape, tensor.dtype ) if tensor.is_consistent: @@ -738,6 +749,7 @@ def RegisterMethods(): Tensor.__rmul__ = lambda self, other: self.mul(other) Tensor.__add__ = lambda self, other: self.add(other) Tensor.__iadd__ = lambda self, other: self.add_(other) + Tensor.__matmul__ = lambda self, other: self.matmul(other) Tensor.ndim = property(_ndim) Tensor.numpy = _tensor_numpy Tensor.size = _size @@ -749,8 +761,6 @@ def RegisterMethods(): Tensor.backward = _backward Tensor.__getitem__ = _getitem Tensor.__setitem__ = _setitem - Tensor.__setstate__ = check_point_v2.tensor_setstate - Tensor.__getstate__ = check_point_v2.tensor_getstate Tensor.__str__ = _str Tensor.__repr__ = _repr Tensor.__eq__ = _eq @@ -819,6 +829,7 @@ def RegisterMethods(): Tensor.softsign = _softsign Tensor.cast = _cast Tensor.diag = _diag + Tensor.diagonal = _diagonal Tensor.log1p = _log1p Tensor.add = _add Tensor.add_ = _add_inplace @@ -869,6 +880,8 @@ def RegisterMethods(): Tensor.logical_not = _not Tensor.roll = _roll Tensor.bmm = _bmm + Tensor.chunk = _chunk + Tensor.split = _split Tensor.squeeze = _squeeze Tensor.unfold = _unfold Tensor.narrow = _narrow diff --git a/python/oneflow/framework/tensor_str.py b/python/oneflow/framework/tensor_str.py index 29bd7921d26..0dcd6fd5568 100644 --- a/python/oneflow/framework/tensor_str.py +++ b/python/oneflow/framework/tensor_str.py @@ -16,12 +16,15 @@ """ This file is mostly referenced from PyTorch v1.8.1 torch/_tensor_str.py """ -import os + import math import numpy as np from typing import Optional import oneflow as flow +from oneflow.framework.tensor_str_util import slice_wrapper +from oneflow.framework.tensor_str_util import _autoset_linewidth +from oneflow.framework.tensor_str_util import _try_convert_to_local_tensor class __PrinterOptions(object): @@ -88,7 +91,7 @@ def set_printoptions( PRINT_OPTS.linewidth = 80 elif profile == "full": PRINT_OPTS.precision = 4 - PRINT_OPTS.threshold = inf + PRINT_OPTS.threshold = math.inf PRINT_OPTS.edgeitems = 3 PRINT_OPTS.linewidth = 80 @@ -105,26 +108,6 @@ def set_printoptions( PRINT_OPTS.autoset_linewidth = False -def _autoset_linewidth(): - # os.terminal_size(columns, lines), - # columns represents width of the terminal window in characters - # and lines represents height of the terminal window in characters. 
- try: - linewidth = os.get_terminal_size()[0] - except OSError: - linewidth = 80 - return linewidth - - -def _try_convert_to_local_tensor(tensor): - if tensor.is_consistent: - tensor = tensor.to_consistent( - placement=flow.env.all_device_placement(tensor.placement.device_type), - sbp=flow.sbp.broadcast, - ).to_local() - return tensor - - class _Formatter(object): def __init__(self, tensor): self.floating_dtype = tensor.dtype.is_floating_point @@ -233,10 +216,10 @@ def _val_formatter(val, formatter1=formatter1): if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems: left_values = _try_convert_to_local_tensor( - self[: PRINT_OPTS.edgeitems] + slice_wrapper(self, [0, PRINT_OPTS.edgeitems, 1]) ).tolist() right_values = _try_convert_to_local_tensor( - self[-PRINT_OPTS.edgeitems :] + slice_wrapper(self, [self.size(0) - PRINT_OPTS.edgeitems, self.size(0), 1]) ).tolist() data = ( [_val_formatter(val) for val in left_values] @@ -266,18 +249,30 @@ def _tensor_str_with_formatter(self, indent, summarize, formatter1): if summarize and self.size(0) > 2 * PRINT_OPTS.edgeitems: slices = ( [ - _tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1) + _tensor_str_with_formatter( + slice_wrapper(self, [i, i + 1, 1]), + indent + 1, + summarize, + formatter1, + ) for i in range(0, PRINT_OPTS.edgeitems) ] + ["..."] + [ - _tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1) + _tensor_str_with_formatter( + slice_wrapper(self, [i, i + 1, 1]), + indent + 1, + summarize, + formatter1, + ) for i in range(self.shape[0] - PRINT_OPTS.edgeitems, self.shape[0]) ] ) else: slices = [ - _tensor_str_with_formatter(self[i], indent + 1, summarize, formatter1) + _tensor_str_with_formatter( + slice_wrapper(self, [i, i + 1, 1]), indent + 1, summarize, formatter1 + ) for i in range(0, self.size(0)) ] @@ -321,18 +316,31 @@ def get_summarized_data(self): if dim == 1: if self.size(0) > 2 * PRINT_OPTS.edgeitems: return flow.cat( - (self[: PRINT_OPTS.edgeitems], self[-PRINT_OPTS.edgeitems :]) + ( + slice_wrapper(self, [0, PRINT_OPTS.edgeitems, 1]), + slice_wrapper( + self, [self.size(0) - PRINT_OPTS.edgeitems, self.size(0), 1] + ), + ) ) else: return self if self.size(0) > 2 * PRINT_OPTS.edgeitems: - start = [self[i] for i in range(0, PRINT_OPTS.edgeitems)] + start = [ + slice_wrapper(self, [i, i + 1, 1]) for i in range(0, PRINT_OPTS.edgeitems) + ] end = [ - self[i] for i in range(self.shape[0] - PRINT_OPTS.edgeitems, self.shape[0]) + slice_wrapper(self, [i, i + 1, 1]) + for i in range(self.shape[0] - PRINT_OPTS.edgeitems, self.shape[0]) ] return flow.stack([get_summarized_data(x) for x in (start + end)]) else: - return flow.stack([get_summarized_data(x) for x in self]) + return flow.stack( + [ + get_summarized_data(slice_wrapper(self, [i, i + 1, 1])) + for i in range(len(self)) + ] + ) def _gen_tensor_str_template(tensor, is_meta): diff --git a/python/oneflow/framework/tensor_str_util.py b/python/oneflow/framework/tensor_str_util.py new file mode 100644 index 00000000000..0dbfa6cfb44 --- /dev/null +++ b/python/oneflow/framework/tensor_str_util.py @@ -0,0 +1,57 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import oneflow as flow +from typing import Optional, Tuple + + +def slice_wrapper(tensor, slice_tuple: Tuple[int, int, int]): + with flow.no_grad(): + ndim = tensor.ndim + slice_tuple_list = [slice_tuple] + [[None, None, None]] * (ndim - 1) + # TODO(): a kind 'slice op' supports both local and consistent tensor + if tensor.is_consistent: + # input is s0, output is p + # input is b, output is b + # input is p, output is p + # so 'to b' is not needed here + tensor = flow.logical_slice(tensor, slice_tuple_list) + else: + tensor = flow.slice(tensor, slice_tuple_list) + # TODO(): flow.sequeeze will fail in some consistent tensor case + if tensor.shape[0] == 1 and ndim > 1: + tensor = tensor.reshape(list(tensor.shape[1:])) + return tensor + + +def _autoset_linewidth(): + # os.terminal_size(columns, lines), + # columns represents width of the terminal window in characters + # and lines represents height of the terminal window in characters. + try: + linewidth = os.get_terminal_size()[0] + except OSError: + linewidth = 80 + return linewidth + + +def _try_convert_to_local_tensor(tensor): + if tensor.is_consistent: + tensor = tensor.to_consistent( + placement=flow.env.all_device_placement(tensor.placement.device_type), + sbp=flow.sbp.broadcast, + ).to_local() + return tensor diff --git a/python/oneflow/nn/functional/__init__.py b/python/oneflow/nn/functional/__init__.py index 616e0153b1f..4a1bd002699 100644 --- a/python/oneflow/nn/functional/__init__.py +++ b/python/oneflow/nn/functional/__init__.py @@ -14,7 +14,6 @@ limitations under the License. """ from oneflow.nn.modules.interpolate import interpolate -from oneflow.nn.modules.norm import l2_normalize from oneflow.nn.modules.affine_grid import affine_grid from oneflow.nn.modules.grid_sample import grid_sample from oneflow.nn.modules.sparse_softmax_cross_entropy import sparse_softmax_cross_entropy @@ -43,6 +42,7 @@ from oneflow._C import gelu from oneflow._C import glu from oneflow._C import logsigmoid +from oneflow._C import log_softmax from oneflow._C import softsign from oneflow._C import softmax from oneflow._C import softplus @@ -57,6 +57,8 @@ from oneflow._C import triplet_margin_loss from oneflow._C import ctc_greedy_decoder from oneflow._C import one_hot +from oneflow._C import l2_normalize +from oneflow._C import normalize from oneflow.nn.modules.sparse import embedding from oneflow.nn.modules.linear import linear from oneflow.nn.modules.activation import relu6 diff --git a/python/oneflow/nn/graph/block.py b/python/oneflow/nn/graph/block.py index ce18009b1a2..d2ef74b4ab4 100644 --- a/python/oneflow/nn/graph/block.py +++ b/python/oneflow/nn/graph/block.py @@ -486,18 +486,8 @@ def _get_from_states(self, name, states_name): _s_block = _states[name] if graph_build_util.lazy_mode.is_enabled(): - # lazy - if _s_block._lazy_origin is None: - assert _s_block._lazy_origin_builder is not None, ( - repr(_s_block) + " has no lazy Tensor creation function." - ) - assert self._is_executing_forward, ( - repr(_s_block) - + "'s first get must happened in it's nn.Module.forward() to generate the right scope." 
- ) - with _s_block.scope_context(): - _s_block._lazy_origin = _s_block._lazy_origin_builder() - return _s_block._lazy_origin + _s_block.try_build() + return _s_block.lazy_origin elif ( not graph_build_util.lazy_mode.is_enabled() ) and self._is_executing_forward: @@ -565,6 +555,23 @@ def _print(self, s_level=2, v_level=0, msg: str = ""): print(msg) +class LazyBuilder(object): + def __init__(self, name: str = None, method=None): + self.name = name + self.method = method + self.result = None + self.finished = False + + def try_build(self, block=None): + if not self.finished: + assert self.name is not None + assert self.method is not None + assert self.result is None + with block.scope_context(): + self.result = self.method() + self.finished = True + + class TensorBlock(Block): def __init__( self, prefix: str = "", name: str = "", origin: Union[Parameter, Tensor] = None, @@ -577,8 +584,8 @@ def __init__( self._type = BlockType.BUFFER else: raise NotImplementedError() - self._lazy_origin = None - self._lazy_origin_builder = None + self._lazy_origin_builder = LazyBuilder() + self.build_finished = False self.set_origin(origin) @property @@ -593,7 +600,7 @@ def lazy_origin(self): assert ( self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER ), "Only Parameter or Buffer Block has lazy_origin" - return self._lazy_origin + return self._lazy_origin_builder.result def lazy_origin_builder(self): assert ( @@ -601,11 +608,16 @@ def lazy_origin_builder(self): ), "Only Parameter or Buffer Block has lazy_origin_builder" return self._lazy_origin_builder - def set_lazy_origin_builder(self, fn=None): + def set_lazy_origin_builder(self, builder=None): assert ( self._type == BlockType.PARAMETER or self._type == BlockType.BUFFER ), "Only Parameter or Buffer Block has lazy_origin_builder" - self._lazy_origin_builder = fn + self._lazy_origin_builder = builder + + def try_build(self): + if not self.build_finished: + self._lazy_origin_builder.try_build(self) + self.build_finished = True def __repr__(self): lines = None diff --git a/python/oneflow/nn/graph/graph.py b/python/oneflow/nn/graph/graph.py index 50cd3bc4178..ad48452218f 100644 --- a/python/oneflow/nn/graph/graph.py +++ b/python/oneflow/nn/graph/graph.py @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
""" +import os +import time from collections import OrderedDict from functools import partial from typing import Dict, Optional, Union, List -import time +from google.protobuf import text_format import oneflow import oneflow._oneflow_internal @@ -35,6 +37,7 @@ from oneflow.nn.module import Module from oneflow.nn.optimizer.lr_scheduler import LrScheduler from oneflow.nn.optimizer.optimizer import Optimizer +from oneflow.nn.optimizer.sparse_optimizer import SparseOptimizer class Graph(object): @@ -212,7 +215,7 @@ def add_optimizer( opt_dict = dict() assert optim is not None, "optimizer cannot be None" assert isinstance( - optim, Optimizer + optim, (Optimizer, SparseOptimizer) ), "optimizer must be an instance of Optimizer" opt_dict["optim"] = optim if lr_sch is not None: @@ -422,29 +425,71 @@ def _state(self): for bu in bu_gen: yield bu + def _filter_states(self): + state_tensor_set = set() + state_tensors = [] + state_op_names = [] + + for state_block in self._state(): + state_tensor = state_block.origin + if state_tensor in state_tensor_set: + continue + op_name = state_block.name_prefix + state_block.name + state_tensor_set.add(state_tensor) + state_tensors.append(state_tensor) + state_op_names.append(op_name) + + if state_block.type == BlockType.PARAMETER: + self._variables_conf[state_tensor] = VariableConfig(op_name) + + self._state_tensor_tuple = convert_to_tensor_tuple(state_tensors) + return state_op_names + def _generate_config_proto(self): self.config.proto.set_job_name(self._name) + self._outputs_buffer_size = self.config._outputs_buffer_size if self._grad_scaler is not None: self._grad_scaler._generate_conf_for_graph( self.config.proto.mutable_train_conf() ) - for state_block in self._state(): - if state_block.type == BlockType.PARAMETER: - self._variables_conf[state_block.origin] = VariableConfig( - state_block.name_prefix + state_block.name - ) for opt in self._opts: opt_dict = OptDict(opt) self.config._generate_optimizer_and_variable_configs( opt_dict, self._variables_conf ) + def _create_states_builder(self): + state2lazy_builder = dict() + for state_block in self._state(): + state_tensor = state_block.origin + op_name = state_block.name_prefix + state_block.name + if state_tensor in state2lazy_builder: + # Differe tensor block shares the same tensor, so they need to share the same + # builder. + state_block.set_lazy_origin_builder(state2lazy_builder[state_tensor]) + else: + if state_block.type == BlockType.PARAMETER: + assert state_tensor in self._variables_conf + state_config = self._variables_conf[state_tensor] + op_name = state_config.name + else: + state_config = None + # Init a new lazy tensor builder + state_block.lazy_origin_builder().name = op_name + state_block.lazy_origin_builder().method = partial( + graph_build_util.build_graph_state, + op_name, + state_tensor, + state_config, + ) + state2lazy_builder[state_tensor] = state_block.lazy_origin_builder() + def _compile(self, *args): # Build graph try: - self._print(0, 0, self._shallow_repr() + " Start building graph.") + self._print(0, 0, self._shallow_repr() + " start building graph.") assert not self._is_compiled, ( "nn.Graph " + self._name + " has already been compiled." ) @@ -455,7 +500,7 @@ def _compile(self, *args): 0, 0, self._shallow_repr() - + " Done! cost time: " + + " building graph Done! Cost time: " + str(round(build_graph_end - build_graph_start, 2)) + "s." 
+ "\n", @@ -466,7 +511,7 @@ def _compile(self, *args): 0, "[ERROR]" + self._shallow_repr() - + " build graph got error: " + + " building graph got error: " + sys_exc_error_msg(), ) raise @@ -474,9 +519,7 @@ def _compile(self, *args): # Complie graph to execution plan and init Runtime try: self._print( - 0, - 0, - self._shallow_repr() + " Start compiling plan and init graph runtime.", + 0, 0, self._shallow_repr() + " start building plan.", ) compile_and_init_start = time.perf_counter() self._c_nn_graph.complie_and_init_runtime() @@ -485,12 +528,12 @@ def _compile(self, *args): 0, 0, self._shallow_repr() - + " Done! cost time: " + + " building plan Done! Cost time: " + str(round(compile_and_init_end - compile_and_init_start, 2)) + "s." + "\n" + self._shallow_repr() - + " The total time consumed to complete build graph, compiling plan and init graph runtime: " + + "'s total time to build graph and plan : " + str(round(compile_and_init_end - build_graph_start, 2)) + "s." + "\n", @@ -501,7 +544,7 @@ def _compile(self, *args): 0, "[ERROR]" + self._shallow_repr() - + " compiling plan or initialing graph runtime got error: " + + " building plan got error: " + sys_exc_error_msg(), ) raise @@ -513,10 +556,26 @@ def _build_graph(self, *args): session = session_ctx.GetDefaultSession() assert type(session) is MultiClientSession - # Get config form GraphConfig - self._outputs_buffer_size = self.config._outputs_buffer_size + # Filter to get unique states in graph + state_op_names = self._filter_states() + self._generate_config_proto() + # Deal with parameter and buffer + self._print( + 0, + 1, + self._shallow_repr() + + " start building graph builders of parameters and buffers.", + ) + self._create_states_builder() + self._print( + 0, + 1, + self._shallow_repr() + + " end building graph builders of parameters and buffers.", + ) + with graph_build_util.graph_build_context(self.config.proto, session): # Deal with inputs self._print(0, 1, self._shallow_repr() + " start building graph inputs.") @@ -525,19 +584,6 @@ def _build_graph(self, *args): ) self._print(0, 1, self._shallow_repr() + " end building graph inputs.") - # Deal with parameter and buffer - self._print( - 0, - 1, - self._shallow_repr() + " start building graph parameters and buffers.", - ) - state_op_names, self._states_tensor_tuple = self._build_states() - self._print( - 0, - 1, - self._shallow_repr() + " end building graph parameters and buffers.", - ) - # Deal with module in self.build(*args) self._print(0, 1, self._shallow_repr() + " start building graph modules.") outputs = self.build(*lazy_args) @@ -599,7 +645,7 @@ def _build_graph(self, *args): output_op_names, self._outputs_tensor_tuple ) self._c_nn_graph.register_variable_op_names_and_tensors( - state_op_names, self._states_tensor_tuple + state_op_names, self._state_tensor_tuple ) return seq_to_func_return(self._eager_outputs_buffer[0]) @@ -699,7 +745,7 @@ def _run(self, *args): oneflow._oneflow_internal.nn.graph.RunLazyNNGraph( convert_to_tensor_tuple(flattened_eager_args), outputs_tensor_tuple, - self._states_tensor_tuple, + self._state_tensor_tuple, self._c_nn_graph, ) # Update outputs buffer reading index @@ -916,32 +962,6 @@ def _io_item_check_and_gen(self, item, expect_type, io_type, idx, second_idx=Non "nn.Graph.build()'s input/output only support types: Tensor/list(Tensor)/None." 
) - def _build_states(self): - state_op_names = [] - state_tensors = [] - for state_block in self._state(): - op_name = state_block.name_prefix + state_block.name - state_tensor = state_block.origin - state_op_names.append(op_name) - state_tensors.append(state_tensor) - if ( - state_block.type == BlockType.PARAMETER - and state_block.origin in self._variables_conf - ): - state_config = self._variables_conf[state_block.origin] - else: - state_config = None - state_block.set_lazy_origin_builder( - partial( - graph_build_util.build_graph_state, - op_name, - state_tensor, - state_config, - ) - ) - state_tensor_tuple = convert_to_tensor_tuple(state_tensors) - return state_op_names, state_tensor_tuple - def _add_block(self, name: str, module: Module = None) -> None: r"""Adds module to the graph as a block so that the module will be called in nn.Graph.build. diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index 93fc76fa772..4f753821afe 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -157,9 +157,7 @@ def enable_cudnn_conv_heuristic_search_algo(self, mode: bool = True): def _generate_optimizer_and_variable_configs( self, opt_dict: OptDict = None, variables_conf: OrderedDict = None, ): - opt_dict.generate_optimizer_and_variable_configs( - self.proto.mutable_train_conf(), variables_conf - ) + opt_dict.generate_optimizer_and_variable_configs(self.proto, variables_conf) def __repr__(self): main_str = ( diff --git a/python/oneflow/nn/graph/optimizer.py b/python/oneflow/nn/graph/optimizer.py index e0c6e422c33..a6ef4d18760 100644 --- a/python/oneflow/nn/graph/optimizer.py +++ b/python/oneflow/nn/graph/optimizer.py @@ -14,29 +14,56 @@ limitations under the License. """ from oneflow.nn.optimizer.optimizer import Optimizer +from oneflow.nn.optimizer.sparse_optimizer import SparseOptimizer from oneflow.nn.optimizer.lr_scheduler import LrScheduler class OptDict(object): - def __init__( - self, opt_dict, - ): - assert isinstance(opt_dict, dict), "opt dict must be a dict" - assert "optim" in opt_dict, "opt dict must has an optimizer" - self._optimizer = opt_dict["optim"] - assert isinstance(opt_dict["optim"], Optimizer) + def __init__(self, opt_dict): + if not isinstance(opt_dict, dict): + raise ValueError("opt_dict is not a dict") + + if "optim" in opt_dict: + if isinstance(opt_dict["optim"], Optimizer): + self._optimizer = opt_dict["optim"] + self._is_sparse = False + elif isinstance(opt_dict["optim"], SparseOptimizer): + self._optimizer = opt_dict["optim"]._nested_optim + self._is_sparse = True + else: + raise ValueError( + 'opt_dict["optim"] is not an instance of Optimizer or SparseOptimizer.' + ) + else: + raise ValueError("Key 'optim' doesn't exist in opt_dict.") self._lr_scheduler = None if "lr_sch" in opt_dict: - assert isinstance(opt_dict["lr_sch"], LrScheduler) + if not isinstance(opt_dict["lr_sch"], LrScheduler): + raise ValueError( + 'opt_dict["lr_sch"] is not an instance of LrScheduler.' + ) + + if opt_dict["lr_sch"]._optimizer is not self._optimizer: + raise ValueError("lr_scheduler doesn't match optimizer.") + self._lr_scheduler = opt_dict["lr_sch"] - assert ( - self._lr_scheduler._optimizer is self._optimizer - ), "lr_scheduler's optimizer must be the same optimizer in the opt dict." 
- def generate_optimizer_and_variable_configs(self, train_conf, vars_conf): + def generate_optimizer_and_variable_configs(self, job_conf, vars_conf): + train_conf = job_conf.mutable_train_conf() + if self._optimizer is not None: + # Run sanity checks before generating the optimizer conf + self._optimizer._check_variables_in_graph(vars_conf) + self._optimizer._check_variables_optimizer_bound(vars_conf) + opt_confs = self._optimizer._generate_conf_for_graph(train_conf, vars_conf) + + if self._is_sparse: + self._optimizer._generate_indexed_slices_optimizer_conf( + job_conf, vars_conf + ) + if self._lr_scheduler is not None: self._lr_scheduler._generate_conf_for_graph(opt_confs) @@ -46,6 +73,7 @@ def __init__(self, name: str): assert name != "" self._name = name self._l2 = 0.0 + self._bound_opt = None @property def name(self): @@ -59,5 +87,13 @@ def l2(self): def l2(self, l2: float = 0.0): self._l2 = l2 + @property + def bound_optimizer(self): + return self._bound_opt + + @bound_optimizer.setter + def bound_optimizer(self, opt): + self._bound_opt = opt + def __repr__(self): return "(variable name: " + self._name + "):(l2: " + str(self._l2) + ".)" diff --git a/python/oneflow/nn/init.py b/python/oneflow/nn/init.py index df535c545bf..cc541e8f0ff 100644 --- a/python/oneflow/nn/init.py +++ b/python/oneflow/nn/init.py @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. """ +import os + import oneflow as flow from oneflow.ops.initializer_util import CalcGain @@ -157,6 +159,8 @@ def kaiming_normal_( >>> w = flow.empty(3, 5) >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu') """ + if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": + data_format = "NHWC" with flow.no_grad(): return tensor.kaiming_normal_(a, mode, nonlinearity, data_format=data_format) diff --git a/python/oneflow/nn/module.py b/python/oneflow/nn/module.py index 988d5aa0cca..a4b50a179f3 100644 --- a/python/oneflow/nn/module.py +++ b/python/oneflow/nn/module.py @@ -463,9 +463,14 @@ def register_forward_pre_hook(self, hook: Callable[..., None]) -> None: def register_forward_hook(self, hook: Callable[..., None]) -> None: self._forward_hooks[len(self._forward_hooks)] = hook - def _apply(self, fn): + def _apply(self, fn, applied_dict=None): + # A dict to store tensors that have already been applied. + # There is no need to apply fn multiple times to the same tensor.
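+        # Tensors may be shared by several modules or parameters (e.g. tied weights); recording the converted result per original tensor keeps that sharing intact after the apply.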
+ if applied_dict is None: + applied_dict = dict() + for module in self.children(): - module._apply(fn) + module._apply(fn, applied_dict) def can_use_assign_copy(tensor, tensor_applied): return tensor.is_local == tensor_applied.is_local @@ -474,27 +479,47 @@ def can_use_assign_copy(tensor, tensor_applied): if param is None: continue - assert isinstance(param, Parameter) - assert param.is_leaf - with flow.no_grad(): - param_applied = fn(param) - param_applied.requires_grad = param.requires_grad - - if param.grad is not None: - assert param.grad.is_leaf + need_apply = False + if param not in applied_dict: + need_apply = True + assert isinstance(param, Parameter) + assert param.is_leaf with flow.no_grad(): - grad_applied = fn(param.grad) - grad_applied.requires_grad = param.grad.requires_grad - param_applied.grad = grad_applied + param_applied = fn(param) + param_applied.requires_grad = param.requires_grad + + if param.grad is not None: + assert param.grad.is_leaf + with flow.no_grad(): + grad_applied = fn(param.grad) + grad_applied.requires_grad = param.grad.requires_grad + param_applied.grad = grad_applied + else: + param_applied = applied_dict[param] if can_use_assign_copy(param_applied, param): - self._parameters[key].data = param_applied + if need_apply: + self._parameters[key].data = param_applied + applied_dict[param] = param_applied + else: + # The parameter's data has already been set when it can use assign copy. + pass else: - self._parameters[key] = Parameter(param_applied, param.requires_grad) + if need_apply: + new_param = Parameter(param_applied, param.requires_grad) + self._parameters[key] = new_param + applied_dict[param] = new_param + else: + self._parameters[key] = applied_dict[param] for (key, buf) in self._buffers.items(): if buf is not None: - self._buffers[key] = fn(buf) + if buf not in applied_dict: + buf_applied = fn(buf) + self._buffers[key] = buf_applied + applied_dict[buf] = buf_applied + else: + self._buffers[key] = applied_dict[buf] return self def apply(self: T, fn: Callable[["Module"], None]) -> T: diff --git a/python/oneflow/nn/modules/activation.py b/python/oneflow/nn/modules/activation.py index b453693f1c9..1f4b23cf639 100644 --- a/python/oneflow/nn/modules/activation.py +++ b/python/oneflow/nn/modules/activation.py @@ -68,14 +68,21 @@ class PReLU(Module): """ - def __init__(self, num_parameters: int = 1, init: float = 0.25) -> None: + def __init__( + self, num_parameters: int = 1, init: float = 0.25, device=None, dtype=None + ) -> None: super().__init__() self.num_parameters = num_parameters - self.weight = flow.nn.Parameter(flow.Tensor(num_parameters).fill_(init)) + self.weight = flow.nn.Parameter( + flow.empty(num_parameters, dtype=dtype, device=device).fill_(init) + ) def forward(self, x): return flow._C.prelu(x, self.weight) + def extra_repr(self) -> str: + return "num_parameters={}".format(self.num_parameters) + class ReLU(Module): """Applies the rectified linear unit function element-wise: diff --git a/python/oneflow/nn/modules/all_reduce.py b/python/oneflow/nn/modules/all_reduce.py index 57d45a94ed9..69b87e72906 100644 --- a/python/oneflow/nn/modules/all_reduce.py +++ b/python/oneflow/nn/modules/all_reduce.py @@ -23,14 +23,11 @@ class AllReduce(Module): def __init__(self, parallel_conf_str: str): super().__init__() self._op = ( - flow.builtin_op("eager_nccl_all_reduce") - .Input("in") - .Output("out") - .Attr("parallel_conf", parallel_conf_str) - .Build() + flow.stateful_op("eager_nccl_all_reduce").Input("in").Output("out").Build() ) + 
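+        # NOTE: parallel_conf is no longer baked into the op via .Attr() at build time; it is stored on the module and passed to the _C.dispatch_* call in forward().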
self.parallel_conf = parallel_conf_str def forward(self, x): assert x.device.type == "cuda" assert x.device.index == flow.env.get_local_rank() - return self._op(x)[0] + return flow._C.dispatch_eager_nccl_all_reduce(self._op, x, parallel_conf=self.parallel_conf) diff --git a/python/oneflow/nn/modules/arange.py b/python/oneflow/nn/modules/arange.py index ff1b005277b..f38448c75c1 100644 --- a/python/oneflow/nn/modules/arange.py +++ b/python/oneflow/nn/modules/arange.py @@ -24,41 +24,12 @@ def arange_op( start: int = 0, end: int = None, step: int = 1, - dtype: flow.dtype = flow.int64, + dtype: flow.dtype = None, device: Union[str, flow.device] = None, placement: flow.placement = None, sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, requires_grad: bool = False, ): - """ - Returns a 1-D tensor of size :math:`\\left\\lfloor \\frac{\\text{end} - \\text{start}}{\\text{step}} \\right\\rfloor + 1` - with values from :attr:`start` to :attr:`end` with step :attr:`step`. Step is - the gap between two values in the tensor. - - .. math:: - \\text{out}_{i+1} = \\text{out}_i + \\text{step}. - - Args: - start (int): the starting value for the set of points. Default: ``0``. - end (int): the ending value for the set of points - step (int): the gap between each pair of adjacent points. Default: ``1``. - - Keyword args: - dtype(flow.dtype, optional): If `dtype` is not given, the `dtype` is inferred to be `flow.int64`. - device(flow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor. - requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: `False`. - - For example: - - .. code-block:: python - - >>> import oneflow as flow - - >>> y = flow.arange(0, 5) - >>> y - tensor([0, 1, 2, 3, 4], dtype=oneflow.int64) - - """ if end is None: end = start start = 0 diff --git a/python/oneflow/nn/modules/batchnorm.py b/python/oneflow/nn/modules/batchnorm.py index 8fa35bd694c..be131fed54f 100644 --- a/python/oneflow/nn/modules/batchnorm.py +++ b/python/oneflow/nn/modules/batchnorm.py @@ -14,6 +14,7 @@ limitations under the License. """ from typing import Union +import os import oneflow as flow from oneflow.nn.module import Module @@ -100,6 +101,12 @@ def __init__( track_running_stats=True, ): super().__init__(num_features, eps, momentum, affine, track_running_stats) + if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": + self.data_format = "NHWC" + self.channel_axis = 3 + else: + self.data_format = "NCHW" + self.channel_axis = 1 def forward(self, x): self._check_input_dim(x) @@ -115,7 +122,7 @@ def forward(self, x): self.running_var if not self.training or self.track_running_stats else None, self.weight, self.bias, - axis=1, + axis=self.channel_axis, epsilon=self.eps, momentum=self.momentum, is_training=is_training, diff --git a/python/oneflow/nn/modules/batchnorm_fused.py b/python/oneflow/nn/modules/batchnorm_fused.py index fe1be60ea8d..950af38ce1e 100644 --- a/python/oneflow/nn/modules/batchnorm_fused.py +++ b/python/oneflow/nn/modules/batchnorm_fused.py @@ -14,6 +14,7 @@ limitations under the License.
""" from typing import Union +import os import oneflow as flow from oneflow.nn.module import Module @@ -80,6 +81,12 @@ def __init__( track_running_stats=True, ): super().__init__(num_features, eps, momentum, affine, track_running_stats) + if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": + self.data_format = "NHWC" + self.channel_axis = 3 + else: + self.data_format = "NCHW" + self.channel_axis = 1 def forward(self, x, addend=None): self._check_input_dim(x) @@ -97,7 +104,7 @@ def forward(self, x, addend=None): self.running_var if not self.training or self.track_running_stats else None, self.weight, self.bias, - axis=1, + axis=self.channel_axis, epsilon=self.eps, momentum=self.momentum, is_training=is_training, diff --git a/python/oneflow/nn/modules/conv.py b/python/oneflow/nn/modules/conv.py index 43b92d1fd4e..5769ca31c97 100644 --- a/python/oneflow/nn/modules/conv.py +++ b/python/oneflow/nn/modules/conv.py @@ -14,6 +14,7 @@ limitations under the License. """ import math +import os import oneflow as flow from oneflow.nn import init @@ -181,6 +182,7 @@ def __init__( self.padding = _single(padding) self.dilation = _single(dilation) self.groups = groups + self.channel_pos = "channels_first" assert in_channels % groups == 0 assert out_channels % groups == 0 self.in_channels = in_channels @@ -210,6 +212,7 @@ def forward(self, x): padding=self.padding, dilation=self.dilation, groups=self.groups, + channel_pos=self.channel_pos, ) def extra_repr(self): @@ -364,13 +367,25 @@ def __init__( self.padding = _pair(padding) self.dilation = _pair(dilation) self.groups = groups + + if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": + self.channel_pos = "channels_last" + else: + self.channel_pos = "channels_first" + assert in_channels % groups == 0 assert out_channels % groups == 0 self.in_channels = in_channels self.out_channels = out_channels - self.weight = flow.nn.Parameter( - flow.Tensor(out_channels, in_channels // groups, *self.kernel_size) - ) + if self.channel_pos == "channels_first": + self.weight = flow.nn.Parameter( + flow.Tensor(out_channels, in_channels // groups, *self.kernel_size) + ) + else: + self.weight = flow.nn.Parameter( + flow.Tensor(out_channels, *self.kernel_size, in_channels // groups) + ) + self.out_channel_groups = out_channels // groups self.bias = None if bias: @@ -385,13 +400,15 @@ def reset_parameters(self) -> None: init.uniform_(self.bias, -bound, bound) def forward(self, x): - if x.shape[1] != self.in_channels: - raise ValueError("The input channels should be equal to self.in_channels") - # TODO(zwx): Use `tensor.device_type()` method to help checking if x is on cpu. - # Using `if x.device == flow.device("cpu"):` will fail as consistent tensor has - # no device, however using `x.is_cuda` is not a good choice. - - res = flow._C.conv2d( + if self.channel_pos == "channels_first": + in_channel_axis = 1 + else: + in_channel_axis = 3 + if x.shape[in_channel_axis] != self.in_channels: + raise ValueError( + f"The input channels {x.shape[in_channel_axis]} should be equal to self.in_channels {self.in_channels}." 
+ ) + return flow._C.conv2d( x, self.weight, self.bias, @@ -399,8 +416,8 @@ def forward(self, x): padding=self.padding, dilation=self.dilation, groups=self.groups, + channel_pos=self.channel_pos, ) - return res def extra_repr(self): s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}" @@ -533,6 +550,7 @@ def __init__( self.padding = _triple(padding) self.dilation = _triple(dilation) self.groups = groups + self.channel_pos = "channels_first" assert in_channels % groups == 0, "in_channels must be divisible by groups" assert out_channels % groups == 0, "out_channels must be divisible by groups" self.in_channels = in_channels @@ -564,6 +582,7 @@ def forward(self, x): padding=self.padding, dilation=self.dilation, groups=self.groups, + channel_pos=self.channel_pos, ) def extra_repr(self): @@ -692,7 +711,7 @@ def __init__( self.weight = flow.nn.Parameter( flow.Tensor(in_channels, out_channels // groups, *self.kernel_size) ) - self.filters = out_channels // groups + self.filters = out_channels self.bias = None self._bias_add_op = None if bias: @@ -806,45 +825,24 @@ def __init__( ) -> None: super().__init__() assert padding_mode == "zeros" - kernel_size = _pair(kernel_size) - stride = _pair(stride) - padding = _pair(padding) - output_padding = _pair(output_padding) - dilation = _pair(dilation) + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.output_padding = _pair(output_padding) + self.dilation = _pair(dilation) self.groups = groups assert in_channels % groups == 0 assert out_channels % groups == 0 self.weight = flow.nn.Parameter( - flow.Tensor(in_channels, out_channels // groups, *kernel_size) + flow.Tensor(in_channels, out_channels // groups, *self.kernel_size) ) self.in_channel_groups = in_channels // groups + self.filters = out_channels self.bias = None self._bias_add_op = None if bias: self.bias = flow.nn.Parameter(flow.Tensor(out_channels)) - self._bias_add_op = ( - flow.builtin_op("bias_add") - .Input("a") - .Input("b") - .Output("out") - .Attr("axis", 1) - .Build() - ) - self._op = ( - flow.builtin_op("deconv2d") - .Input("in") - .Input("weight") - .Attr("filters", out_channels // groups) - .Attr("padding_before", padding) - .Attr("data_format", "channels_first") - .Attr("kernel_size", kernel_size) - .Attr("strides", stride) - .Attr("dilation_rate", dilation) - .Attr("output_padding", output_padding) - .Attr("groups", 1) - .Output("out") - .Build() - ) + self.reset_parameters() def reset_parameters(self) -> None: @@ -855,31 +853,19 @@ def reset_parameters(self) -> None: init.uniform_(self.bias, -bound, bound) def forward(self, x): - if self.groups > 1: - in_channel_axis = 1 - in_split_list = ConvUtil.split( - x, axis=in_channel_axis, split_num=self.groups - ) - out_list = [] - for i in range(len(in_split_list)): - out_list.append( - self._op( - in_split_list[i], - self.weight[ - i - * self.in_channel_groups : (i + 1) - * self.in_channel_groups, - :, - :, - :, - ], - )[0] - ) - res = flow.cat(out_list, dim=in_channel_axis) - else: - res = self._op(x, self.weight)[0] - if self._bias_add_op is not None: - res = self._bias_add_op(res, self.bias)[0] + res = flow._C.deconv2d( + x, + self.weight, + self.bias, + self.filters, + self.padding, + "channels_first", + self.kernel_size, + self.output_padding, + self.stride, + self.dilation, + self.groups, + ) return res @@ -1015,7 +1001,7 @@ def __init__( self.weight = flow.nn.Parameter( flow.Tensor(in_channels, out_channels // groups, *self.kernel_size) ) - 
self.filters = out_channels // groups + self.filters = out_channels self.bias = None self._bias_add_op = None if bias: diff --git a/python/oneflow/nn/modules/dataset.py b/python/oneflow/nn/modules/dataset.py index 84523369e2b..a3904f5d3de 100644 --- a/python/oneflow/nn/modules/dataset.py +++ b/python/oneflow/nn/modules/dataset.py @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. """ +import os import random import sys import traceback from typing import List, Optional, Sequence, Tuple, Union import oneflow as flow -from oneflow.framework.tensor import Tensor, TensorTuple +import oneflow._oneflow_internal._C as _C +from oneflow.framework.tensor import Tensor from oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t, _size_any_t from oneflow.nn.module import Module from oneflow.nn.modules.utils import _pair, _reverse_repeat_tuple, _single, _triple @@ -56,8 +58,14 @@ def __init__( if name is not None: print("WARNING: name has been deprecated and has NO effect.\n") - - nd_sbp = [] + self.ofrecord_dir = ofrecord_dir + self.batch_size = batch_size + self.data_part_num = data_part_num + self.part_name_prefix = part_name_prefix + self.part_name_suffix_length = part_name_suffix_length + self.random_shuffle = random_shuffle + self.shuffle_buffer_size = shuffle_buffer_size + self.shuffle_after_epoch = shuffle_after_epoch self.placement = placement if placement is None: @@ -68,41 +76,49 @@ def __init__( if placement is not None: assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), "sbp: %s" % sbp if isinstance(sbp, flow.sbp.sbp): - nd_sbp.append(sbp._ToAttrStr()) sbp = (sbp,) else: for elem in sbp: assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp - nd_sbp.append(elem._ToAttrStr()) - assert len(nd_sbp) == len(placement.hierarchy) + assert len(sbp) == len(placement.hierarchy) else: assert sbp is None, "sbp: %s" % sbp self.sbp = sbp - (seed, has_seed) = mirrored_gen_random_seed(random_seed) - self._op = ( - flow.builtin_op("OFRecordReader") - .Output("out") - .Attr("data_dir", ofrecord_dir) - .Attr("data_part_num", data_part_num) - .Attr("batch_size", batch_size) - .Attr("part_name_prefix", part_name_prefix) - .Attr("random_shuffle", random_shuffle) - .Attr("shuffle_buffer_size", shuffle_buffer_size) - .Attr("shuffle_after_epoch", shuffle_after_epoch) - .Attr("part_name_suffix_length", part_name_suffix_length) - .Attr("seed", seed) - .Attr("nd_sbp", nd_sbp) - .Build() - ) - self.attrs = flow._oneflow_internal.MutableCfgAttrMap() + (self.seed, self.has_seed) = mirrored_gen_random_seed(random_seed) + self._op = flow.stateful_op("OFRecordReader").Output("out").Build() def forward(self): if self.placement is not None: - res = self._op.apply(self.placement, self.sbp, self.attrs)[0] + res = _C.dispatch_ofrecord_reader( + self._op, + data_dir=self.ofrecord_dir, + data_part_num=self.data_part_num, + part_name_prefix=self.part_name_prefix, + part_name_suffix_length=self.part_name_suffix_length, + batch_size=self.batch_size, + shuffle_buffer_size=self.shuffle_buffer_size, + random_shuffle=self.random_shuffle, + shuffle_after_epoch=self.shuffle_after_epoch, + seed=self.seed, + sbp=self.sbp, + placement=self.placement, + ) else: - res = self._op.apply(self.device, self.attrs)[0] + res = _C.dispatch_ofrecord_reader( + self._op, + data_dir=self.ofrecord_dir, + data_part_num=self.data_part_num, + part_name_prefix=self.part_name_prefix, + part_name_suffix_length=self.part_name_suffix_length, + batch_size=self.batch_size, + 
shuffle_buffer_size=self.shuffle_buffer_size, + random_shuffle=self.random_shuffle, + shuffle_after_epoch=self.shuffle_after_epoch, + seed=self.seed, + device=self.device, + ) return res @@ -124,20 +140,26 @@ def __init__( ) if name is not None: print("WARNING: name has been deprecated and has NO effect.\n") + self.blob_name = blob_name + self.shape = shape + self.dtype = dtype + self.dim1_varying_length = dim1_varying_length + self.truncate = truncate + self.auto_zero_padding = auto_zero_padding self._op = ( - flow.builtin_op("ofrecord_raw_decoder") - .Input("in") - .Output("out") - .Attr("name", blob_name) - .Attr("shape", shape) - .Attr("data_type", dtype) - .Attr("dim1_varying_length", dim1_varying_length) - .Attr("truncate", truncate or auto_zero_padding) - .Build() + flow.stateful_op("ofrecord_raw_decoder").Input("in").Output("out").Build() ) def forward(self, input): - res = self._op(input)[0] + res = _C.dispatch_ofrecord_raw_decoder( + self._op, + input, + name=self.blob_name, + shape=self.shape, + data_type=self.dtype, + dim1_varying_length=self.dim1_varying_length, + truncate=self.truncate or self.auto_zero_padding, + ) return res @@ -152,7 +174,8 @@ def __init__( sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, ): super().__init__() - nd_sbp = [] + self.batch_size = batch_size + self.probability = probability self.placement = placement if placement is None: @@ -164,37 +187,40 @@ def __init__( if placement is not None: assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), "sbp: %s" % sbp if isinstance(sbp, flow.sbp.sbp): - nd_sbp.append(sbp._ToAttrStr()) sbp = (sbp,) else: for elem in sbp: assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp - nd_sbp.append(elem._ToAttrStr()) - assert len(nd_sbp) == len(placement.hierarchy) + assert len(sbp) == len(placement.hierarchy) else: assert sbp is None, "sbp: %s" % sbp self.sbp = sbp - (seed, has_seed) = mirrored_gen_random_seed(random_seed) + (self.seed, self.has_seed) = mirrored_gen_random_seed(random_seed) - self._op = ( - flow.builtin_op("coin_flip") - .Output("out") - .Attr("batch_size", batch_size) - .Attr("probability", probability) - .Attr("has_seed", has_seed) - .Attr("seed", seed) - .Attr("nd_sbp", nd_sbp) - .Build() - ) - self.attrs = flow._oneflow_internal.MutableCfgAttrMap() + self._op = flow.stateful_op("coin_flip").Output("out").Build() def forward(self): if self.placement is not None: - res = self._op.apply(self.placement, self.sbp, self.attrs)[0] + res = _C.dispatch_coin_flip( + self._op, + batch_size=self.batch_size, + probability=self.probability, + has_seed=self.has_seed, + seed=self.seed, + placement=self.placement, + sbp=self.sbp, + ) else: - res = self._op.apply(self.device, self.attrs)[0] + res = _C.dispatch_coin_flip( + self._op, + batch_size=self.batch_size, + probability=self.probability, + has_seed=self.has_seed, + seed=self.seed, + device=self.device, + ) return res @@ -212,93 +238,118 @@ def __init__( output_dtype: flow.dtype = flow.float, ): super().__init__() + if output_layout != "NCHW": + print( + "WARNING: output_layout has been deprecated. Please use Environment Variable ONEFLOW_ENABLE_NHWC, and make it equals 1." 
+ ) + if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": + output_layout = "NHWC" + else: + output_layout = "NCHW" + + self.color_space = color_space + self.output_layout = output_layout + self.mean = mean + self.std = std + self.crop_h = crop_h + self.crop_w = crop_w + self.crop_pos_y = crop_pos_y + self.crop_pos_x = crop_pos_x + self.output_dtype = output_dtype + self._op_uint8_with_mirror = ( - flow.builtin_op("crop_mirror_normalize_from_uint8") + flow.stateful_op("crop_mirror_normalize_from_uint8") .Input("in") .Input("mirror") .Output("out") - .Attr("color_space", color_space) - .Attr("output_layout", output_layout) - .Attr("mean", mean) - .Attr("std", std) - .Attr("crop_h", crop_h) - .Attr("crop_w", crop_w) - .Attr("crop_pos_y", crop_pos_y) - .Attr("crop_pos_x", crop_pos_x) - .Attr("output_dtype", output_dtype) .Build() ) self._op_uint8_no_mirror = ( - flow.builtin_op("crop_mirror_normalize_from_uint8") + flow.stateful_op("crop_mirror_normalize_from_uint8") .Input("in") .Output("out") - .Attr("color_space", color_space) - .Attr("output_layout", output_layout) - .Attr("mean", mean) - .Attr("std", std) - .Attr("crop_h", crop_h) - .Attr("crop_w", crop_w) - .Attr("crop_pos_y", crop_pos_y) - .Attr("crop_pos_x", crop_pos_x) - .Attr("output_dtype", output_dtype) .Build() ) self._op_buffer_with_mirror = ( - flow.builtin_op("crop_mirror_normalize_from_tensorbuffer") + flow.stateful_op("crop_mirror_normalize_from_tensorbuffer") .Input("in") .Input("mirror") .Output("out") - .Attr("color_space", color_space) - .Attr("output_layout", output_layout) - .Attr("mean", mean) - .Attr("std", std) - .Attr("crop_h", crop_h) - .Attr("crop_w", crop_w) - .Attr("crop_pos_y", crop_pos_y) - .Attr("crop_pos_x", crop_pos_x) - .Attr("output_dtype", output_dtype) .Build() ) self._op_buffer_no_mirror = ( - flow.builtin_op("crop_mirror_normalize_from_tensorbuffer") + flow.stateful_op("crop_mirror_normalize_from_tensorbuffer") .Input("in") .Output("out") - .Attr("color_space", color_space) - .Attr("output_layout", output_layout) - .Attr("mean", mean) - .Attr("std", std) - .Attr("crop_h", crop_h) - .Attr("crop_w", crop_w) - .Attr("crop_pos_y", crop_pos_y) - .Attr("crop_pos_x", crop_pos_x) - .Attr("output_dtype", output_dtype) .Build() ) def forward(self, input, mirror=None): - if mirror is not None: - if input.dtype is flow.uint8: - res = self._op_uint8_with_mirror(input, mirror)[0] - elif input.dtype is flow.tensor_buffer: - res = self._op_buffer_with_mirror(input, mirror)[0] + if input.dtype is flow.uint8: + if mirror is not None: + res = _C.dispatch_crop_mirror_normalize_from_uint8( + self._op_uint8_with_mirror, + (input, mirror), + color_space=self.color_space, + output_layout=self.output_layout, + mean=self.mean, + std=self.std, + crop_h=self.crop_h, + crop_w=self.crop_w, + crop_pos_x=self.crop_pos_x, + crop_pos_y=self.crop_pos_y, + output_dtype=self.output_dtype, + ) else: - print( - "ERROR! 
oneflow.nn.CropMirrorNormalize module NOT support input dtype = ", - input.dtype, + res = _C.dispatch_crop_mirror_normalize_from_uint8( + self._op_uint8_no_mirror, + (input,), + color_space=self.color_space, + output_layout=self.output_layout, + mean=self.mean, + std=self.std, + crop_h=self.crop_h, + crop_w=self.crop_w, + crop_pos_x=self.crop_pos_x, + crop_pos_y=self.crop_pos_y, + output_dtype=self.output_dtype, + ) + elif input.dtype is flow.tensor_buffer: + if mirror is not None: + res = _C.dispatch_crop_mirror_normalize_from_tensorbuffer( + self._op_buffer_with_mirror, + (input, mirror), + color_space=self.color_space, + output_layout=self.output_layout, + mean=self.mean, + std=self.std, + crop_h=self.crop_h, + crop_w=self.crop_w, + crop_pos_x=self.crop_pos_x, + crop_pos_y=self.crop_pos_y, + output_dtype=self.output_dtype, ) - raise NotImplementedError - else: - if input.dtype is flow.uint8: - res = self._op_uint8_no_mirror(input)[0] - elif input.dtype is flow.tensor_buffer: - res = self._op_buffer_no_mirror(input)[0] else: - print( - "ERROR! oneflow.nn.CropMirrorNormalize module NOT support input dtype = ", - input.dtype, + res = _C.dispatch_crop_mirror_normalize_from_tensorbuffer( + self._op_buffer_no_mirror, + (input,), + color_space=self.color_space, + output_layout=self.output_layout, + mean=self.mean, + std=self.std, + crop_h=self.crop_h, + crop_w=self.crop_w, + crop_pos_x=self.crop_pos_x, + crop_pos_y=self.crop_pos_y, + output_dtype=self.output_dtype, ) - raise NotImplementedError + else: + print( + "ERROR! oneflow.nn.CropMirrorNormalize module NOT support input dtype = ", + input.dtype, + ) + raise NotImplementedError return res @@ -313,23 +364,31 @@ def __init__( random_aspect_ratio: Sequence[float] = [0.75, 1.333333], ): super().__init__() - (seed, has_seed) = mirrored_gen_random_seed(random_seed) + self.blob_name = blob_name + self.color_space = color_space + self.num_attempts = num_attempts + self.random_area = random_area + self.random_aspect_ratio = random_aspect_ratio + (self.seed, self.has_seed) = mirrored_gen_random_seed(random_seed) self._op = ( - flow.builtin_op("ofrecord_image_decoder_random_crop") + flow.stateful_op("ofrecord_image_decoder_random_crop") .Input("in") .Output("out") - .Attr("name", blob_name) - .Attr("color_space", color_space) - .Attr("num_attempts", num_attempts) - .Attr("random_area", random_area) - .Attr("random_aspect_ratio", random_aspect_ratio) - .Attr("has_seed", has_seed) - .Attr("seed", seed) .Build() ) def forward(self, input): - res = self._op(input)[0] + res = _C.dispatch_ofrecord_image_decoder_random_crop( + self._op, + input, + name=self.blob_name, + color_space=self.color_space, + num_attempts=self.num_attempts, + random_area=self.random_area, + random_aspect_ratio=self.random_aspect_ratio, + has_seed=self.has_seed, + seed=self.seed, + ) return res @@ -337,16 +396,15 @@ class OFRecordImageDecoder(Module): def __init__(self, blob_name: str, color_space: str = "BGR"): super().__init__() self._op = ( - flow.builtin_op("ofrecord_image_decoder") - .Input("in") - .Output("out") - .Attr("name", blob_name) - .Attr("color_space", color_space) - .Build() + flow.stateful_op("ofrecord_image_decoder").Input("in").Output("out").Build() ) + self.blob_name = blob_name + self.color_space = color_space def forward(self, input): - res = self._op(input)[0] + res = _C.dispatch_ofrecord_image_decoder( + self._op, input, name=self.blob_name, color_space=self.color_space + ) return res @@ -355,45 +413,34 @@ def __init__( self, target_width: int, 
target_height: int, - num_attempts: Optional[int] = None, - seed: Optional[int] = None, - random_area: Optional[Sequence[float]] = None, - random_aspect_ratio: Optional[Sequence[float]] = None, - num_workers: Optional[int] = None, - warmup_size: Optional[int] = None, - max_num_pixels: Optional[int] = None, + num_attempts: Optional[int] = 10, + seed: Optional[int] = 0, + random_area: Optional[Sequence[float]] = [0.08, 1.0], + random_aspect_ratio: Optional[Sequence[float]] = [0.75, 1.333333], + num_workers: Optional[int] = 3, + warmup_size: Optional[int] = 6400, + max_num_pixels: Optional[int] = 67108864, ): super().__init__() + self.target_width = target_width + self.target_height = target_height + self.num_attempts = num_attempts + self.seed = seed + assert len(random_area) == 2 + self.random_area = random_area + assert len(random_aspect_ratio) == 2 + self.random_aspect_ratio = random_aspect_ratio + self.num_workers = num_workers + self.warmup_size = warmup_size + self.max_num_pixels = max_num_pixels gpu_decoder_conf = ( flow._oneflow_internal.oneflow.core.operator.op_conf.ImageDecoderRandomCropResizeOpConf() ) gpu_decoder_conf.set_in("error_input_need_to_be_replaced") gpu_decoder_conf.set_out("out") - gpu_decoder_conf.set_target_width(target_width) - gpu_decoder_conf.set_target_height(target_height) - if num_attempts is not None: - gpu_decoder_conf.set_num_attempts(num_attempts) - if seed is not None: - gpu_decoder_conf.set_seed(seed) - if random_area is not None: - assert len(random_area) == 2 - gpu_decoder_conf.set_random_area_min(random_area[0]) - gpu_decoder_conf.set_random_area_max(random_area[1]) - if random_aspect_ratio is not None: - assert len(random_aspect_ratio) == 2 - gpu_decoder_conf.set_random_aspect_ratio_min(random_aspect_ratio[0]) - gpu_decoder_conf.set_random_aspect_ratio_max(random_aspect_ratio[1]) - if num_workers is not None: - gpu_decoder_conf.set_num_workers(num_workers) - if warmup_size is not None: - gpu_decoder_conf.set_warmup_size(warmup_size) - if max_num_pixels is not None: - gpu_decoder_conf.set_max_num_pixels(max_num_pixels) - self._op = flow._oneflow_internal.one.ImageDecoderRandomCropResizeOpExpr( id_util.UniqueStr("ImageGpuDecoder"), gpu_decoder_conf, ["in"], ["out"] ) - self.attrs = flow._oneflow_internal.MutableCfgAttrMap() def forward(self, input): if not input.is_lazy: @@ -402,7 +449,21 @@ def forward(self, input): "NOT support run as eager module, please use it in nn.Graph.", ) raise NotImplementedError - res = self._op.apply([input], self.attrs)[0] + res = _C.dispatch_image_decoder_random_crop_resize( + self._op, + input, + target_width=self.target_width, + target_height=self.target_height, + num_attempts=self.num_attempts, + seed=self.seed, + random_area_min=self.random_area[0], + random_area_max=self.random_area[1], + random_aspect_ratio_min=self.random_aspect_ratio[0], + random_aspect_ratio_max=self.random_aspect_ratio[1], + num_workers=self.num_workers, + warmup_size=self.warmup_size, + max_num_pixels=self.max_num_pixels, + ) if not res.is_cuda: print( "WARNING! 
oneflow.nn.OFRecordImageGpuDecoderRandomCropResize ONLY support ", @@ -417,17 +478,23 @@ def __init__( ): super().__init__() self._op = ( - flow.builtin_op("tensor_buffer_to_list_of_tensors_v2") + flow.stateful_op("tensor_buffer_to_list_of_tensors_v2") .Input("in") .Output("out", out_num) - .Attr("out_shapes", out_shapes) - .Attr("out_dtypes", out_dtypes) - .Attr("dynamic_out", dynamic_out) .Build() ) + self.out_shapes = out_shapes + self.out_dtypes = out_dtypes + self.dynamic_out = dynamic_out def forward(self, input): - return self._op(input) + return _C.dispatch_tensor_buffer_to_list_of_tensors_v2( + self._op, + input, + out_shapes=self.out_shapes, + out_dtypes=self.out_dtypes, + dynamic_out=self.dynamic_out, + ) def tensor_buffer_to_list_of_tensors(tensor, out_shapes, out_dtypes): @@ -471,6 +538,7 @@ def __init__( channels = 1 else: raise ValueError("invalid color_space") + self.channels = channels if interp_type is not None: print( "WARNING: interp_type has been deprecated. Please use interpolation_type instead." @@ -486,6 +554,8 @@ def __init__( interpolation_type = "bicubic" else: raise ValueError("invalid interp_type") + self.interpolation_type = interpolation_type + if resize_x > 0 and resize_y > 0: print( "WARNING: resize_x and resize_y has been deprecated. Please use target_size instead." @@ -503,7 +573,8 @@ def __init__( target_size = resize_shorter keep_aspect_ratio = True resize_side = "shorter" - if keep_aspect_ratio: + self.keep_aspect_ratio = keep_aspect_ratio + if self.keep_aspect_ratio: if not isinstance(target_size, int): raise ValueError( "target_size must be an int when keep_aspect_ratio is True" @@ -518,17 +589,16 @@ def __init__( resize_longer = True else: raise ValueError('resize_side must be "shorter" or "longer"') + self.target_size = target_size + self.min_size = min_size + self.max_size = max_size + self.resize_longer = resize_longer self._op = ( - flow.builtin_op("image_resize_keep_aspect_ratio") + flow.stateful_op("image_resize_keep_aspect_ratio") .Input("in") .Output("out") .Output("size") .Output("scale") - .Attr("target_size", target_size) - .Attr("min_size", min_size) - .Attr("max_size", max_size) - .Attr("resize_longer", resize_longer) - .Attr("interpolation_type", interpolation_type) .Build() ) else: @@ -542,24 +612,27 @@ def __init__( ) if dtype is None: dtype = flow.uint8 - (target_w, target_h) = target_size + self.dtype = dtype + (self.target_w, self.target_h) = target_size self._op = ( - flow.builtin_op("image_resize_to_fixed") + flow.stateful_op("image_resize_to_fixed") .Input("in") .Output("out") .Output("scale") - .Attr("target_width", target_w) - .Attr("target_height", target_h) - .Attr("channels", channels) - .Attr("data_type", dtype) - .Attr("interpolation_type", interpolation_type) .Build() ) def forward(self, input): - res = self._op(input) - res_image = res[0] - if len(res) == 3: + if self.keep_aspect_ratio: + res = _C.dispatch_image_resize_keep_aspect_ratio( + self._op, + input, + target_size=self.target_size, + min_size=self.min_size, + max_size=self.max_size, + resize_longer=self.resize_longer, + interpolation_type=self.interpolation_type, + ) new_size = flow.tensor_buffer_to_tensor( res[1], dtype=flow.int32, instance_shape=(2,) ) @@ -567,8 +640,18 @@ def forward(self, input): res[2], dtype=flow.float32, instance_shape=(2,) ) else: + res = _C.dispatch_image_resize_to_fixed( + self._op, + input, + target_width=self.target_w, + target_height=self.target_h, + channels=self.channels, + data_type=self.dtype, + 
interpolation_type=self.interpolation_type, + ) new_size = None scale = res[1] + res_image = res[0] return (res_image, scale, new_size) @@ -681,33 +764,27 @@ def forward(self, images, flip_code): class ImageDecode(Module): def __init__(self, dtype: flow.dtype = flow.uint8, color_space: str = "BGR"): super().__init__() - self._op = ( - flow.builtin_op("image_decode") - .Input("in") - .Output("out") - .Attr("color_space", color_space) - .Attr("data_type", dtype) - .Build() - ) + self.color_space = color_space + self.dtype = dtype + self._op = flow.stateful_op("image_decode").Input("in").Output("out").Build() def forward(self, input): - return self._op(input)[0] + return _C.dispatch_image_decode( + self._op, input, color_space=self.color_space, data_type=self.dtype + ) class ImageNormalize(Module): def __init__(self, std: Sequence[float], mean: Sequence[float]): super().__init__() - self._op = ( - flow.builtin_op("image_normalize") - .Input("in") - .Output("out") - .Attr("std", std) - .Attr("mean", mean) - .Build() - ) + self.std = std + self.mean = mean + self._op = flow.stateful_op("image_normalize").Input("in").Output("out").Build() def forward(self, input): - return self._op(input)[0] + return _C.dispatch_image_normalize( + self._op, input, mean=self.mean, std=self.std + ) class COCOReader(Module): @@ -726,10 +803,17 @@ def __init__( sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, ): super().__init__() + self.annotation_file = annotation_file + self.image_dir = image_dir + self.batch_size = batch_size + self.shuffle = shuffle + self.group_by_aspect_ratio = group_by_aspect_ratio + self.remove_images_without_annotations = remove_images_without_annotations + self.stride_partition = stride_partition if random_seed is None: random_seed = random.randrange(sys.maxsize) + self.random_seed = random_seed - nd_sbp = [] self.placement = placement if placement is None: self.device = device or flow.device("cpu") @@ -743,21 +827,19 @@ def __init__( for sbp_item in sbp: if not isinstance(sbp_item, flow.sbp.sbp): raise ValueError(f"invalid sbp item: {sbp_item}") - nd_sbp.append(sbp_item._ToAttrStr()) elif isinstance(sbp, flow.sbp.sbp): - nd_sbp.append(sbp._ToAttrStr()) sbp = (sbp,) else: raise ValueError(f"invalid param sbp: {sbp}") - if len(nd_sbp) != len(placement.hierarchy): + if len(sbp) != len(placement.hierarchy): raise ValueError( "dimensions of sbp and dimensions of hierarchy of placement don't equal" ) self.sbp = sbp self._op = ( - flow.builtin_op("COCOReader") + flow.stateful_op("COCOReader") .Output("image") .Output("image_id") .Output("image_size") @@ -765,53 +847,63 @@ def __init__( .Output("gt_label") .Output("gt_segm") .Output("gt_segm_index") - .Attr("session_id", flow.current_scope().session_id) - .Attr("annotation_file", annotation_file) - .Attr("image_dir", image_dir) - .Attr("batch_size", batch_size) - .Attr("shuffle_after_epoch", shuffle) - .Attr("random_seed", random_seed) - .Attr("group_by_ratio", group_by_aspect_ratio) - .Attr( - "remove_images_without_annotations", remove_images_without_annotations - ) - .Attr("stride_partition", stride_partition) - .Attr("nd_sbp", nd_sbp) .Build() ) - self.attrs = flow._oneflow_internal.MutableCfgAttrMap() def forward(self): if self.placement is None: # local apply - outputs = self._op.apply(self.device, self.attrs) + outputs = _C.dispatch_coco_reader( + self._op, + session_id=flow.current_scope().session_id, + annotation_file=self.annotation_file, + image_dir=self.image_dir, + batch_size=self.batch_size, + shuffle_after_epoch=self.shuffle, 
+ random_seed=self.random_seed, + group_by_ratio=self.group_by_aspect_ratio, + remove_images_without_annotations=self.remove_images_without_annotations, + stride_partition=self.stride_partition, + device=self.device, + ) else: # consistent apply - outputs = self._op.apply(self.placement, self.sbp, self.attrs) - - # COCOReader has multiple output, so it return a TensorTuple - # convert TensorTuple to tuple of Tensor - assert isinstance(outputs, TensorTuple) - ret = tuple(out for out in outputs) - return ret + outputs = _C.dispatch_coco_reader( + self._op, + session_id=flow.current_scope().session_id, + annotation_file=self.annotation_file, + image_dir=self.image_dir, + batch_size=self.batch_size, + shuffle_after_epoch=self.shuffle, + random_seed=self.random_seed, + group_by_ratio=self.group_by_aspect_ratio, + remove_images_without_annotations=self.remove_images_without_annotations, + stride_partition=self.stride_partition, + placement=self.placement, + sbp=self.sbp, + ) + return outputs class ImageBatchAlign(Module): def __init__(self, shape: Sequence[int], dtype: flow.dtype, alignment: int): super().__init__() self._op = ( - flow.builtin_op("image_batch_align") - .Input("in") - .Output("out") - .Attr("shape", shape) - .Attr("data_type", dtype) - .Attr("alignment", alignment) - .Attr("dynamic_out", False) - .Build() + flow.stateful_op("image_batch_align").Input("in").Output("out").Build() ) + self.shape = shape + self.dtype = dtype + self.alignment = alignment def forward(self, input): - return self._op(input)[0] + return _C.dispatch_image_batch_align( + self._op, + input, + shape=self.shape, + data_type=self.dtype, + alignment=self.alignment, + dynamic_out=False, + ) class OFRecordBytesDecoder(Module): @@ -866,15 +958,12 @@ def __init__(self, blob_name: str, name: Optional[str] = None): if name is not None: print("WARNING: name has been deprecated and has NO effect.\n") self._op = ( - flow.builtin_op("ofrecord_bytes_decoder") - .Input("in") - .Output("out") - .Attr("name", blob_name) - .Build() + flow.stateful_op("ofrecord_bytes_decoder").Input("in").Output("out").Build() ) + self.blob_name = blob_name def forward(self, input): - return self._op(input)[0] + return _C.dispatch_ofrecord_bytes_decoder(self._op, input, name=self.blob_name) class GPTIndexedBinDataReader(Module): @@ -895,7 +984,12 @@ def __init__( ): super().__init__() - nd_sbp = [] + self.data_file_prefix = data_file_prefix + self.seq_length = seq_length + self.num_samples = num_samples + self.batch_size = batch_size + self.dtype = dtype + self.shuffle = shuffle self.placement = placement if placement is None: self.device = device or flow.device("cpu") @@ -909,14 +1003,12 @@ def __init__( for sbp_item in sbp: if not isinstance(sbp_item, flow.sbp.sbp): raise ValueError(f"invalid sbp item: {sbp_item}") - nd_sbp.append(sbp_item._ToAttrStr()) elif isinstance(sbp, flow.sbp.sbp): - nd_sbp.append(sbp._ToAttrStr()) sbp = (sbp,) else: raise ValueError(f"invalid param sbp: {sbp}") - if len(nd_sbp) != len(placement.hierarchy): + if len(sbp) != len(placement.hierarchy): raise ValueError( "dimensions of sbp and dimensions of hierarchy of placement don't equal" ) @@ -924,12 +1016,15 @@ def __init__( if random_seed is None: random_seed = random.randrange(sys.maxsize) + self.random_seed = random_seed if split_index is None: split_index = 0 + self.split_index = split_index if split_sizes is None: split_sizes = (1,) + self.split_sizes = split_sizes if split_index >= len(split_sizes): raise ValueError( @@ -938,29 +1033,42 @@ def __init__( ) ) - 
op_builder = ( - flow.builtin_op("megatron_gpt_mmap_data_loader") - .Output("out") - .Attr("data_file_prefix", data_file_prefix) - .Attr("seq_length", seq_length) - .Attr("label_length", 1) - .Attr("num_samples", num_samples) - .Attr("batch_size", batch_size) - .Attr("dtype", dtype) - .Attr("shuffle", shuffle) - .Attr("random_seed", random_seed) - .Attr("split_sizes", split_sizes) - .Attr("split_index", split_index) - .Attr("nd_sbp", nd_sbp) + self.op_ = ( + flow.stateful_op("megatron_gpt_mmap_data_loader").Output("out").Build() ) - self.op_ = op_builder.Build() - self.attrs = flow._oneflow_internal.MutableCfgAttrMap() def forward(self): if self.placement is None: - output = self.op_.apply(self.device, self.attrs)[0] + output = _C.dispatch_megatron_gpt_mmap_data_loader( + self.op_, + data_file_prefix=self.data_file_prefix, + seq_length=self.seq_length, + label_length=1, + num_samples=self.num_samples, + batch_size=self.batch_size, + dtype=self.dtype, + shuffle=self.shuffle, + random_seed=self.random_seed, + split_sizes=self.split_sizes, + split_index=self.split_index, + device=self.device, + ) else: - output = self.op_.apply(self.placement, self.sbp, self.attrs)[0] + output = _C.dispatch_megatron_gpt_mmap_data_loader( + self.op_, + data_file_prefix=self.data_file_prefix, + seq_length=self.seq_length, + label_length=1, + num_samples=self.num_samples, + batch_size=self.batch_size, + dtype=self.dtype, + shuffle=self.shuffle, + random_seed=self.random_seed, + split_sizes=self.split_sizes, + split_index=self.split_index, + placement=self.placement, + sbp=self.sbp, + ) return output diff --git a/python/oneflow/nn/modules/in_top_k.py b/python/oneflow/nn/modules/in_top_k.py index 557ddaf78f9..c09fb6ccc39 100644 --- a/python/oneflow/nn/modules/in_top_k.py +++ b/python/oneflow/nn/modules/in_top_k.py @@ -18,27 +18,6 @@ from oneflow.nn.module import Module -class InTopk(Module): - def __init__(self, k) -> None: - super().__init__() - self._in_top_k = ( - flow.builtin_op("in_top_k") - .Input("targets") - .Input("predictions") - .Output("out") - .Attr("k", k) - .Build() - ) - - def forward(self, targets, predictions): - assert ( - targets.shape[0] == predictions.shape[0] - ), "The num of targets must equal the num of predictions" - assert len(targets.shape) == 1, "The dimension of targets must be 1" - assert len(predictions.shape) == 2, "The dimension of predictions must be 2" - return self._in_top_k(targets, predictions) - - def in_top_k_op(targets, predictions, k): """Says whether the targets are in the top K predictions. @@ -71,7 +50,7 @@ def in_top_k_op(targets, predictions, k): tensor([1, 0], device='cuda:0', dtype=oneflow.int8) """ - return InTopk(k=k)(targets, predictions)[0] + return flow._C.in_top_k(targets, predictions, k=k) @register_tensor_op("in_top_k") @@ -83,7 +62,7 @@ def in_top_k_op_tensor(targets, predictions, k): See :func:`oneflow.in_top_k` """ - return InTopk(k=k)(targets, predictions)[0] + return flow._C.in_top_k(targets, predictions, k=k) if __name__ == "__main__": diff --git a/python/oneflow/nn/modules/eye.py b/python/oneflow/nn/modules/linspace.py similarity index 54% rename from python/oneflow/nn/modules/eye.py rename to python/oneflow/nn/modules/linspace.py index 490f2564685..c8c9f57de7d 100644 --- a/python/oneflow/nn/modules/eye.py +++ b/python/oneflow/nn/modules/linspace.py @@ -1,82 +1,98 @@ -""" -Copyright 2020 The OneFlow Authors. All rights reserved. 
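# A minimal sketch of the pattern applied throughout the hunks above: attributes are no
# longer baked into the op via .Attr(...) at build time, they are passed to a
# flow._C.dispatch_* call when the op runs. The op name and dispatch signature below are
# taken from the OFRecordImageDecoder hunk; `record` (an OFRecord tensor produced by a
# reader) is assumed and not constructed here.
import oneflow as flow

decoder_op = (
    flow.stateful_op("ofrecord_image_decoder").Input("in").Output("out").Build()
)
# image = flow._C.dispatch_ofrecord_image_decoder(
#     decoder_op, record, name="encoded", color_space="BGR"
# )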
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" -from typing import Union, List - -import oneflow as flow -from oneflow.framework.tensor import register_tensor_op - - -def eye_op( - n, - m=None, - dtype: flow.dtype = flow.float, - device: Union[str, flow.device] = None, - placement: flow.placement = None, - sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, - requires_grad: bool = False, -): - """This operator creates a 2-D Tensor with ones on the diagonal and zeros elsewhere. - - Args: - n (int): the number of rows. - m (Optional[int], optional): the number of colums with default being n. Defaults to None. - - Keyword args: - device(flow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor. - requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: `False`. - - Returns: - oneflow.Tensor: The result Blob with ones on the diagonal and zeros elsewhere. - - For example: - - .. code-block:: python - - >>> import oneflow as flow - >>> out = flow.eye(3, 3) - >>> out - tensor([[1., 0., 0.], - [0., 1., 0.], - [0., 0., 1.]], dtype=oneflow.float32) - - """ - if placement is None: - if isinstance(device, str): - device = flow.device(device) - res = flow._C.eye(n, m, dtype=dtype, device=device) - else: - assert isinstance( - placement, flow._oneflow_internal.placement - ), "placement should be oneflow._oneflow_internal.placement type." - assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), "sbp: %s" % sbp - if isinstance(sbp, flow.sbp.sbp): - assert sbp == flow.sbp.broadcast - sbp = (sbp,) - else: - for elem in sbp: - assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp - assert elem == flow.sbp.broadcast - assert len(sbp) == len(placement.hierarchy) - res = flow._C.consistent_eye(n, m, dtype=dtype, placement=placement, sbp=sbp) - - res.requires_grad = requires_grad - return res - - -if __name__ == "__main__": - import doctest - - doctest.testmod(raise_on_error=True) +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
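# Worked example for the arange-based linspace implementation that follows: it computes
# step = (end - start) / (steps - 1) and, when the range divides evenly by the step,
# nudges `end` by step / 2 so that arange still emits the final endpoint.
# For linspace(3, 10, steps=5):
#     step = (10 - 3) / 4 = 1.75, end becomes 10 + 0.875 = 10.875
#     arange(3, 10.875, 1.75) -> [3.0, 4.75, 6.5, 8.25, 10.0]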
+""" +from typing import List, Optional, Union + +import oneflow as flow + + +def linspace_op( + start: float, + end: float, + steps: int, + dtype: flow.dtype = flow.float32, + device: Union[str, flow.device] = None, + placement: flow.placement = None, + sbp: Union[flow.sbp.sbp, List[flow.sbp.sbp]] = None, + requires_grad: bool = False, +): + r""" + Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly + spaced from :attr:`start` to :attr:`end`, inclusive. That is, the value are: + + .. math:: + (\text{start}, + \text{start} + \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \ldots, + \text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{\text{steps} - 1}, + \text{end}) + + Args: + start (float): the starting value for the set of points + end (float): the ending value for the set of points + steps (int): size of the constructed tensor + + Keyword arguments: + dtype(flow.dtype, optional): If `dtype` is not given, the `dtype` is inferred to be `flow.float32`. + device(flow.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor. + requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: `False`. + + For example: + + .. code-block:: python + + >>> import oneflow as flow + + >>> y = flow.linspace(3, 10, steps=5) + >>> y + tensor([ 3.0000, 4.7500, 6.5000, 8.2500, 10.0000], dtype=oneflow.float32) + + """ + step = 1.0 + if steps == 0: + end = start + elif steps == 1: + end = start + 1.0 + else: + step = (end - start) * 1.0 / (steps - 1) + if ((end - start) / (steps - 1)) * (steps - 1) == (end - start): + end = end + step / 2.0 + if placement is None: + if isinstance(device, str): + device = flow.device(device) + res = flow._C.arange(start, end, step, dtype=dtype, device=device) + else: + assert isinstance( + placement, flow._oneflow_internal.placement + ), "placement should be oneflow._oneflow_internal.placement type." + assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), "sbp: %s" % sbp + if isinstance(sbp, flow.sbp.sbp): + sbp = (sbp,) + else: + for elem in sbp: + assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp + assert len(sbp) == len(placement.hierarchy) + res = flow._C.consistent_arange( + start, end, step, dtype=dtype, placement=placement, sbp=sbp + ) + + res.requires_grad = requires_grad + return res + + +if __name__ == "__main__": + import doctest + + doctest.testmod(raise_on_error=True) diff --git a/python/oneflow/nn/modules/math_ops.py b/python/oneflow/nn/modules/math_ops.py index ea48f05a532..ba1bc3c6471 100644 --- a/python/oneflow/nn/modules/math_ops.py +++ b/python/oneflow/nn/modules/math_ops.py @@ -467,6 +467,32 @@ def log_op(input): return flow._C.log(input) +@register_tensor_op("log2") +def log2_op(input): + """ + Returns a new tensor with the natural logarithm to the base 2 of the elements of :attr:`input`. + + .. math:: + y_{i} = \\log2_{e} (x_{i}) + + Args: + input (Tensor): the input tensor. + + For example: + + .. 
code-block:: python + + >>> import oneflow as flow + >>> import numpy as np + >>> arr = np.random.randn(2, 3, 4, 5) + >>> input = flow.tensor(arr, dtype=flow.float32) + >>> output = flow.log2(input) + + + """ + return flow._C.log2(input) + + @register_tensor_op("rsqrt") def rsqrt_op(input): """Returns a new tensor with the reciprocal of the square-root of each of @@ -616,51 +642,6 @@ def addmm_op_tensor(input, mat1, mat2, alpha=1, beta=1): return addmm(input, mat1, mat2, alpha, beta) -class Clamp(Module): - def __init__(self, min_value=None, max_value=None) -> None: - super().__init__() - if min_value is not None: - floating_min_value = float(min_value) - integral_min_value = int(min_value) - if max_value is not None: - floating_max_value = float(max_value) - integral_max_value = int(max_value) - if min_value is not None and max_value is not None: - self._op = ( - flow.builtin_op("clip_by_scalar") - .Input("x") - .Output("y") - .Attr("floating_min", floating_min_value) - .Attr("integral_min", integral_min_value) - .Attr("floating_max", floating_max_value) - .Attr("integral_max", integral_max_value) - .Build() - ) - elif min_value is not None: - self._op = ( - flow.builtin_op("clip_by_scalar_min") - .Input("x") - .Output("y") - .Attr("floating_min", floating_min_value) - .Attr("integral_min", integral_min_value) - .Build() - ) - elif max_value is not None: - self._op = ( - flow.builtin_op("clip_by_scalar_max") - .Input("x") - .Output("y") - .Attr("floating_max", floating_max_value) - .Attr("integral_max", integral_max_value) - .Build() - ) - else: - raise ValueError("min_value and max_value cannot be None at the same time") - - def forward(self, x): - return self._op(x)[0] - - def clamp_op(input, min=None, max=None): """ Clamp all elements in :attr:`input` into the range `[` :attr:`min`, :attr:`max` `]` and return @@ -716,14 +697,14 @@ def clamp_op_tensor(tensor, min=None, max=None): """ See :func:`oneflow.clamp` """ - return Clamp(min, max)(tensor) + return flow._C.clamp(tensor, min, max) def clip_op(tensor, min=None, max=None): """ Alias for :func:`oneflow.clamp` """ - return Clamp(min, max)(tensor) + return flow._C.clamp(tensor, min, max) @register_tensor_op("clip") @@ -731,7 +712,7 @@ def clip_op_tensor(tensor, min=None, max=None): """ See :func:`oneflow.clamp` """ - return Clamp(min, max)(tensor) + return flow._C.clamp(tensor, min, max) @register_tensor_op("cosh") @@ -989,14 +970,8 @@ def __init__( self, k, dim: int = None, largest: bool = True, sorted: bool = True ) -> None: super().__init__() - self._op_topk_last_dim = ( - flow.builtin_op("top_k") - .Input("in") - .Output("out") - .Attr("k", k) - .Attr("sorted", sorted) - .Build() - ) + self.k = k + self.sorted = sorted self.dim = dim self.largest = largest @@ -1008,19 +983,19 @@ def forward(self, input): assert 0 <= axis < num_axes, "axis out of range" if axis == num_axes - 1: if self.largest: - indices = self._op_topk_last_dim(input)[0] + indices = flow._C.top_k(input, self.k) else: neg_input = flow.mul(input, -1) - indices = self._op_topk_last_dim(neg_input)[0] + indices = flow._C.top_k(neg_input, self.k) return (flow.gather(input, axis, indices), indices) else: perm = get_perm_when_transpose_axis_to_last_dim(num_axes, axis) x = flow._C.transpose(input, perm=perm) if self.largest: - indices = self._op_topk_last_dim(x)[0] + indices = flow._C.top_k(x, self.k) else: neg_input = flow.mul(x, -1) - indices = self._op_topk_last_dim(neg_input)[0] + indices = flow._C.top_k(neg_input, self.k) indices = flow._C.transpose(indices, 
perm=get_inversed_perm(perm)) return (flow.gather(input, axis, indices), indices) @@ -1053,7 +1028,7 @@ def topk_op(input, k, dim: int = None, largest: bool = True, sorted: bool = True [9., 4., 3.]], dtype=oneflow.float32) >>> indices tensor([[2, 3, 1], - [1, 2, 3]], dtype=oneflow.int32) + [1, 2, 3]], dtype=oneflow.int64) >>> values.shape oneflow.Size([2, 3]) >>> indices.shape @@ -1064,7 +1039,7 @@ def topk_op(input, k, dim: int = None, largest: bool = True, sorted: bool = True [1., 2.]], dtype=oneflow.float32) >>> indices tensor([[0, 4], - [0, 4]], dtype=oneflow.int32) + [0, 4]], dtype=oneflow.int64) >>> values.shape oneflow.Size([2, 2]) >>> indices.shape diff --git a/python/oneflow/nn/modules/meshgrid.py b/python/oneflow/nn/modules/meshgrid.py index 4e3aa00a7f2..838c573bc73 100644 --- a/python/oneflow/nn/modules/meshgrid.py +++ b/python/oneflow/nn/modules/meshgrid.py @@ -16,8 +16,8 @@ import oneflow as flow -def meshgrid_op(*tensors): - return flow._C.meshgrid(tensors) +def meshgrid_op(*tensors, indexing="ij"): + return flow._C.meshgrid(tensors, indexing) if __name__ == "__main__": diff --git a/python/oneflow/nn/modules/norm.py b/python/oneflow/nn/modules/norm.py deleted file mode 100644 index 47c003c7a59..00000000000 --- a/python/oneflow/nn/modules/norm.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" -import oneflow as flow -from oneflow.framework.tensor import register_tensor_op -from oneflow.nn.module import Module - - -def l2_normalize(input, dim=0, epsilon=1e-12): - """Use L2 norm to normalizes along dimension `dim` - - The equation is: - - .. math:: - out = \\frac{x}{max(\\sqrt{\\Sigma{x^2}}, \\epsilon)} - - Args: - input (oneflow.Tensor): Input Tensor - dim (int): The axis on which to apply L2 normalization. Defaults to 0. - epsilon (float, optional): The epsilon value is used to avoid division by zero. Defaults to 1e-12. - - Returns: - oneflow.Tensor: The normalized Tensor - - For example: - - .. code-block:: python - - >>> import oneflow as flow - >>> x = flow.tensor([[1, 2], [3, 4]], dtype=flow.float32) - >>> out = flow.nn.functional.l2_normalize(x, 0) - >>> out - tensor([[0.3162, 0.4472], - [0.9487, 0.8944]], dtype=oneflow.float32) - >>> out = flow.nn.functional.l2_normalize(x, 1) - >>> out - tensor([[0.4472, 0.8944], - [0.6000, 0.8000]], dtype=oneflow.float32) - - """ - y, _ = flow._C.l2_normalize(input, dim, epsilon) - return y - - -if __name__ == "__main__": - import doctest - - doctest.testmod(raise_on_error=True) diff --git a/python/oneflow/nn/modules/pooling.py b/python/oneflow/nn/modules/pooling.py index 3a5c3904653..971e8e20968 100644 --- a/python/oneflow/nn/modules/pooling.py +++ b/python/oneflow/nn/modules/pooling.py @@ -14,6 +14,7 @@ limitations under the License. 
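# A short sketch of the functional calls that replace the removed Module wrappers in the
# hunks above (Clamp -> flow._C.clamp, the builtin "top_k" op -> flow._C.top_k), plus the
# new log2 entry point and the meshgrid `indexing` argument; the tensors are illustrative.
import oneflow as flow

x = flow.tensor([[1.0, 9.0, 3.0, 7.0], [4.0, 2.0, 8.0, 6.0]])
clipped = flow._C.clamp(x, 2.0, 8.0)   # was Clamp(min, max)(x)
idx = flow._C.top_k(x, 2)              # int64 indices of the top-2 entries along the last dim
vals = flow.gather(x, 1, idx)
logs = flow.log2(x)                    # also registered as a tensor method: x.log2()
xs, ys = flow.meshgrid(
    flow.tensor([1.0, 2.0]), flow.tensor([3.0, 4.0]), indexing="ij"
)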
""" from typing import Optional +import os import oneflow as flow from oneflow.nn.common_types import _size_1_t, _size_2_t, _size_3_t @@ -122,6 +123,45 @@ def extra_repr(self) -> str: ) +def get_dhw_offset(channel_pos): + if channel_pos == "channels_first": + return 2 + else: + return 1 + + +def get_ndim_pads_list(padding, dhw_offset, ndims): + pads_list = [] + for i in range(len(padding)): + pad = padding[i] + if isinstance(pad, int): + pad = [pad, pad] + elif isinstance(pad, (list, tuple)): + assert len(pad) == 2 + pad = [pad[0], pad[1]] + else: + raise ValueError("padding must be list tuple or int") + if i in range(dhw_offset, dhw_offset + ndims): + pads_list.append(pad) + else: + assert pad == [0, 0] + return pads_list + + +def calc_pool_padding(padding, dhw_offset, ndims): + if isinstance(padding, str): + padding = "SAME_LOWER" if padding.upper() == "SAME" else padding + assert padding.upper() in ["VALID", "SAME_LOWER", "SAME_UPPER"] + padding_type = padding.lower() + ndim_pads_list = [[0, 0]] * ndims + elif isinstance(padding, (list, tuple)): + padding_type = "customized" + ndim_pads_list = get_ndim_pads_list(padding, dhw_offset, ndims) + else: + raise ValueError("padding must be str or a list.") + return (padding_type, ndim_pads_list) + + class MaxPool2d(Module): r"""The interface is consistent with PyTorch. The documentation is referenced from: https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d @@ -197,31 +237,70 @@ def __init__( ): super().__init__() self.kernel_size = _pair(kernel_size) - data_format = "NCHW" # only support "NCHW" for now ! - self.channel_pos = ( - "channels_first" if data_format == "NCHW" else "channels_last" - ) self.stride = _pair(stride) if (stride is not None) else _pair(kernel_size) - self.dilation = _pair(dilation) - self.return_indices = return_indices self.ceil_mode = ceil_mode - self.padding = _pair(padding) + + if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": + self.data_format = "NHWC" + self.channel_pos = "channels_last" + padding = _pair(padding) + if len(padding) == 2: + if self.data_format == "NCHW": + padding = (0, 0, padding[0], padding[1]) + elif self.data_format == "NHWC": + padding = (0, padding[0], padding[1], 0) + else: + raise ValueError("error padding param!") + self.padding = padding + self.padding_type, pads_list = calc_pool_padding( + padding, get_dhw_offset(self.channel_pos), 2 + ) + self.padding_before = [pad[0] for pad in pads_list] + self.padding_after = [pad[1] for pad in pads_list] + if return_indices == True: + raise ValueError( + "MaxPool2d with NHWC data format don't support return indices for now." + ) + if dilation != 1: + raise ValueError( + "MaxPool2d with NHWC data format only support dilation == 1 for now." 
+ ) + self.dilation = _pair(dilation) + + else: + self.data_format = "NCHW" + self.channel_pos = "channels_first" + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.return_indices = return_indices def forward(self, x): - y, indice = flow._C.max_pool2d( - x, - kernel_size=self.kernel_size, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - return_indices=True, - ceil_mode=self.ceil_mode, - data_format=self.channel_pos, - ) - if self.return_indices: - return y, indice + if self.data_format == "NCHW": + y, indice = flow._C.max_pool2d( + x, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + return_indices=True, + ceil_mode=self.ceil_mode, + data_format=self.channel_pos, + ) + if self.return_indices: + return y, indice + else: + return y else: - return y + return flow._C.max_pool2d_nhwc( + x, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding_type, + padding_before=self.padding_before, + padding_after=self.padding_after, + data_format=self.channel_pos, + ceil_mode=self.ceil_mode, + ) def extra_repr(self) -> str: return "kernel_size={}, stride={}, padding={}, dilation={}".format( @@ -465,27 +544,70 @@ def __init__( ): super().__init__() self.kernel_size = _pair(kernel_size) - data_format = "NCHW" # only support "NCHW" for now ! - self.channel_pos = ( - "channels_first" if data_format == "NCHW" else "channels_last" - ) self.stride = _pair(stride) if (stride is not None) else _pair(kernel_size) self.ceil_mode = ceil_mode - self.count_include_pad = count_include_pad - self.divisor_override = int(divisor_override) - self.padding = _pair(padding) + + if os.getenv("ONEFLOW_ENABLE_NHWC") == "1": + self.data_format = "NHWC" + self.channel_pos = "channels_last" + assert isinstance(padding, int) or isinstance( + padding, tuple + ), "padding can only int int or tuple of 2 ints." + padding = _pair(padding) + if len(padding) == 2: + if self.data_format == "NCHW": + padding = (0, 0, padding[0], padding[1]) + elif self.data_format == "NHWC": + padding = (0, padding[0], padding[1], 0) + else: + raise ValueError("error padding param!") + self.padding = padding + + if not count_include_pad: + raise ValueError( + "AvgPool2d with NHWC data format don't support count_include_pad for now." + ) + if divisor_override != 0: + raise ValueError( + "AvgPool2d with NHWC data format don't support divisor_override for now." 
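# A minimal sketch of the channels-last pooling path added above. ONEFLOW_ENABLE_NHWC is
# read in __init__, so it must be set before the module is constructed; an integer padding
# is remapped to (0, pad_h, pad_w, 0) and turned into padding_before/padding_after by
# calc_pool_padding. Assumes a build that ships the *_nhwc pooling kernels.
import os
os.environ["ONEFLOW_ENABLE_NHWC"] = "1"
import oneflow as flow

pool = flow.nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
x = flow.randn(1, 8, 8, 3)   # N, H, W, C instead of the default N, C, H, W
y = pool(x)                  # forwards to flow._C.max_pool2d_nhwc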
+ ) + + # TODO(yaochi): align with pytorch when padding is asymmetric + self._padding_type, _pads_list = calc_pool_padding( + padding, get_dhw_offset(self.channel_pos), 2 + ) + self._padding_before = [pad[0] for pad in _pads_list] + self._padding_after = [pad[1] for pad in _pads_list] + else: + self.data_format = "NCHW" + self.channel_pos = "channels_first" + self.padding = _pair(padding) + self.count_include_pad = count_include_pad + self.divisor_override = int(divisor_override) def forward(self, x): - return flow._C.avg_pool2d( - x, - kernel_size=self.kernel_size, - stride=self.stride, - padding=self.padding, - ceil_mode=self.ceil_mode, - count_include_pad=self.count_include_pad, - divisor_override=self.divisor_override, - data_format=self.channel_pos, - ) + if self.data_format == "NCHW": + return flow._C.avg_pool2d( + x, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + ceil_mode=self.ceil_mode, + count_include_pad=self.count_include_pad, + divisor_override=self.divisor_override, + data_format=self.channel_pos, + ) + else: + return flow._C.avg_pool2d_nhwc( + x, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self._padding_type, + padding_before=self._padding_before, + padding_after=self._padding_after, + ceil_mode=self.ceil_mode, + data_format=self.channel_pos, + ) def extra_repr(self) -> str: return ( diff --git a/python/oneflow/nn/modules/slice.py b/python/oneflow/nn/modules/slice.py index 1f05cd21b85..79dc92b52ad 100644 --- a/python/oneflow/nn/modules/slice.py +++ b/python/oneflow/nn/modules/slice.py @@ -99,6 +99,33 @@ def logical_slice_assign_op( return flow._C.logical_slice_assign(input, update, start, stop, step) +def logical_slice_op(input, slice_tup_list: Sequence[Tuple[int, int, int]]): + """Extracts a slice from a consistent tensor. + The `slice_tup_list` assigns the slice indices in each dimension, the format is (start, stop, step). + The operator will slice the tensor according to the `slice_tup_list`. + + Args: + input: A `Tensor`. + slice_tup_list: A list of slice tuple, indicate each dimension slice (start, stop, step). + + For example: + + .. 
code-block:: python + + >>> import oneflow as flow + + >>> placement = flow.placement("cpu", {0: [0]}) + >>> x = flow.Tensor([[1, 2], [3, 4]], placement=placement, sbp=flow.sbp.broadcast) + >>> y = flow.logical_slice(x, slice_tup_list=[[0, 1, 1]]) + >>> y.numpy() + array([[1., 2.]], dtype=float32) + + """ + + (start, stop, step) = parse_slice_tuple_list(slice_tup_list, input.shape) + return flow._C.logical_slice(input, start, stop, step) + + if __name__ == "__main__": import doctest diff --git a/python/oneflow/nn/modules/tensor_buffer.py b/python/oneflow/nn/modules/tensor_buffer.py index 29f49625702..8413fc8e5f9 100644 --- a/python/oneflow/nn/modules/tensor_buffer.py +++ b/python/oneflow/nn/modules/tensor_buffer.py @@ -16,23 +16,6 @@ from typing import Optional, Sequence import oneflow as flow -from oneflow.nn.module import Module - - -class TensorBufferToTensor(Module): - def __init__(self, dtype, instance_shape): - super().__init__() - self._op = ( - flow.builtin_op("tensor_buffer_to_tensor") - .Input("in") - .Output("out") - .Attr("dtype", dtype) - .Attr("instance_shape", instance_shape) - .Build() - ) - - def forward(self, input): - return self._op(input)[0] def tensor_buffer_to_tensor_op(x, dtype: flow.dtype, instance_shape: Sequence[int]): @@ -65,22 +48,9 @@ def tensor_buffer_to_tensor_op(x, dtype: flow.dtype, instance_shape: Sequence[in oneflow.Size([4, 16, 64, 64]) """ - return TensorBufferToTensor(dtype=dtype, instance_shape=instance_shape)(x) - - -class TensorToTensorBuffer(Module): - def __init__(self, instance_dims): - super().__init__() - self._op = ( - flow.builtin_op("tensor_to_tensor_buffer") - .Input("in") - .Output("out") - .Attr("instance_dims", instance_dims) - .Build() - ) - - def forward(self, input): - return self._op(input)[0] + return flow._C.tensor_buffer_to_tensor( + x, dtype=dtype, instance_shape=instance_shape + ) def tensor_to_tensor_buffer(x, instance_dims: int): @@ -110,25 +80,7 @@ def tensor_to_tensor_buffer(x, instance_dims: int): oneflow.Size([4, 16, 64, 64]) """ - return TensorToTensorBuffer(instance_dims=instance_dims)(x) - - -class GenTensorBuffer(Module): - def __init__(self, shape, shape_list, value_list, data_type, dynamic_out): - super().__init__() - self._op = ( - flow.builtin_op("gen_tensor_buffer") - .Output("out") - .Attr("shape", shape) - .Attr("shape_list", shape_list) - .Attr("value_list", value_list) - .Attr("data_type", data_type) - .Attr("dynamic_out", dynamic_out) - .Build() - ) - - def forward(self): - return self._op()[0] + return flow._C.tensor_to_tensor_buffer(x, instance_dims) def gen_tensor_buffer( @@ -138,7 +90,9 @@ def gen_tensor_buffer( data_type: Optional[flow.dtype] = flow.float32, dynamic_out: Optional[bool] = False, ): - return GenTensorBuffer(shape, shape_list, value_list, data_type, dynamic_out)() + return flow._C.gen_tensor_buffer( + shape, shape_list, value_list, data_type, dynamic_out + ) if __name__ == "__main__": diff --git a/python/oneflow/nn/optimizer/adagrad.py b/python/oneflow/nn/optimizer/adagrad.py index f1683ecd495..87f58f4469c 100644 --- a/python/oneflow/nn/optimizer/adagrad.py +++ b/python/oneflow/nn/optimizer/adagrad.py @@ -123,12 +123,10 @@ def __init__( ) self._op = ( - flow.builtin_op("adagrad_update") + flow.stateful_op("adagrad_update") .Input("model") .Input("model_diff") .Input("sum") - .Attr("l1", 0.0) - .Attr("weight_decay", 0.0) .Build() ) @@ -145,7 +143,7 @@ def step(self, closure: Callable = None): loss = closure() for param_group in self.param_groups: kwargs = { - "learning_rate_val": 
param_group["lr"], + "learning_rate": param_group["lr"], "l2": param_group["weight_decay"], "epsilon": param_group["eps"], "lr_decay": param_group["lr_decay"], @@ -155,7 +153,9 @@ def step(self, closure: Callable = None): if param.grad is None: continue sum_tensor = self._state[param]["sum"] - self._op(param, param.grad, sum_tensor, **kwargs) + flow._C.dispatch_adagrad_update( + self._op, (param, param.grad, sum_tensor), **kwargs + ) self._state["step"] = self._state["step"] + 1 return loss diff --git a/python/oneflow/nn/optimizer/adam.py b/python/oneflow/nn/optimizer/adam.py index d5e5d38fc20..a145c5e3754 100644 --- a/python/oneflow/nn/optimizer/adam.py +++ b/python/oneflow/nn/optimizer/adam.py @@ -144,14 +144,12 @@ def __init__( self._state[param] = dict() self._op = ( - flow.builtin_op("adam_update") + flow.stateful_op("adam_update") .Input("model") .Input("model_diff") .Input("m") .Input("v") .Input("max_v") - .Attr("l1", 0.0) - .Attr("weight_decay", 0.0) .Build() ) @@ -177,9 +175,9 @@ def step(self, closure: Callable = None): ) kwargs = { - "learning_rate_val": param_group["lr"], - "bias_correction1_val": param_group["bias_correction1"], - "bias_correction2_val": param_group["bias_correction2"], + "learning_rate": param_group["lr"], + "bias_correction1": param_group["bias_correction1"], + "bias_correction2": param_group["bias_correction2"], "l2": param_group["weight_decay"], "beta1": param_group["betas"][0], "beta2": param_group["betas"][1], @@ -199,8 +197,10 @@ def step(self, closure: Callable = None): m_tensor = self._state[param]["exp_avg"] v_tensor = self._state[param]["exp_avg_sq"] max_v_tensor = self._state[param]["max_exp_avg_sq"] - self._op( - param, param.grad, m_tensor, v_tensor, max_v_tensor, **kwargs, + flow._C.dispatch_adam_update( + self._op, + (param, param.grad, m_tensor, v_tensor, max_v_tensor), + **kwargs, ) self._state["step"] += 1 @@ -243,3 +243,7 @@ def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs.append(optimizer_conf) return new_opt_confs + + @property + def support_sparse(self): + return True diff --git a/python/oneflow/nn/optimizer/adamw.py b/python/oneflow/nn/optimizer/adamw.py index bd7696ee909..b8912fb5b9c 100644 --- a/python/oneflow/nn/optimizer/adamw.py +++ b/python/oneflow/nn/optimizer/adamw.py @@ -146,14 +146,12 @@ def __init__( self._state[param] = dict() self._op = ( - flow.builtin_op("adam_update") + flow.stateful_op("adam_update") .Input("model") .Input("model_diff") .Input("m") .Input("v") .Input("max_v") - .Attr("l1", 0.0) - .Attr("l2", 0.0) .Build() ) @@ -178,9 +176,9 @@ def step(self, closure: Callable = None): ) kwargs = { - "learning_rate_val": param_group["lr"], - "bias_correction1_val": param_group["bias_correction1"], - "bias_correction2_val": param_group["bias_correction2"], + "learning_rate": param_group["lr"], + "bias_correction1": param_group["bias_correction1"], + "bias_correction2": param_group["bias_correction2"], "weight_decay": param_group["weight_decay"], "beta1": param_group["betas"][0], "beta2": param_group["betas"][1], @@ -202,8 +200,10 @@ def step(self, closure: Callable = None): m_tensor = self._state[param]["exp_avg"] v_tensor = self._state[param]["exp_avg_sq"] max_v_tensor = self._state[param]["max_exp_avg_sq"] - self._op( - param, param.grad, m_tensor, v_tensor, max_v_tensor, **kwargs, + flow._C.dispatch_adam_update( + self._op, + (param, param.grad, m_tensor, v_tensor, max_v_tensor), + **kwargs, ) self._state["step"] += 1 @@ -247,3 +247,7 @@ def _generate_conf_for_graph(self, train_conf, 
vars_conf): new_opt_confs.append(optimizer_conf) return new_opt_confs + + @property + def support_sparse(self): + return True diff --git a/python/oneflow/nn/optimizer/lr_scheduler.py b/python/oneflow/nn/optimizer/lr_scheduler.py index 7af358bef7a..44e205b9fbe 100644 --- a/python/oneflow/nn/optimizer/lr_scheduler.py +++ b/python/oneflow/nn/optimizer/lr_scheduler.py @@ -37,7 +37,7 @@ def __init__(self, optimizer, last_step=-1, verbose=False): self.step() def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. + """Return the state of the scheduler as a :class:`dict`. It contains an entry for every variable in self.__dict__ which is not the optimizer. @@ -47,7 +47,7 @@ def state_dict(self): } def load_state_dict(self, state_dict): - """Loads the schedulers state. + """Load the schedulers state. Arguments: state_dict (dict): scheduler state. Should be an object returned @@ -61,7 +61,7 @@ def get_lr(self): raise NotImplementedError def get_last_lr(self): - """ Return last computed learning rate by current scheduler. + """Return last computed learning rate by current scheduler. """ return [group["lr"] for group in self._optimizer.param_groups] @@ -119,3 +119,29 @@ def step(self): self._inner_lr_sch.step() # get right last_step from inner lr_scheduler self.last_step = self._inner_lr_sch.last_step + + def state_dict(self): + """Return the state of the scheduler as a :class:`dict`. + """ + state = { + key: value for (key, value) in self.__dict__.items() if key != "_optimizer" + } + if self._inner_lr_sch is not None: + state["_inner_lr_sch"] = self._inner_lr_sch.state_dict() + return state + + def load_state_dict(self, state_dict): + """Load the schedulers state. + + Arguments: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + if self._inner_lr_sch is not None: + assert "_inner_lr_sch" in state_dict + inner_lr_sch_state = state_dict.pop("_inner_lr_sch") + self._inner_lr_sch.load_state_dict(inner_lr_sch_state) + self.__dict__.update(state_dict) + # Resume _inner_lr_sch because that we should not change `state_dict` + if self._inner_lr_sch is not None: + state_dict["_inner_lr_sch"] = inner_lr_sch_state diff --git a/python/oneflow/nn/optimizer/optimizer.py b/python/oneflow/nn/optimizer/optimizer.py index afe62aa4618..ff1e4d4c5ca 100644 --- a/python/oneflow/nn/optimizer/optimizer.py +++ b/python/oneflow/nn/optimizer/optimizer.py @@ -294,3 +294,46 @@ def _generate_grad_clip_conf_for_optim_conf(self, param_group, optimizer_conf): warnings.warn( "For now, nn.Graph only support clip grad with `clip_grad_max_norm == 1.0` and `clip_grad_norm_type == 2.0`." ) + + @property + def support_sparse(self): + return False + + def _check_variables_in_graph(self, vars_conf): + for param_group in self.param_groups: + for param in param_group.parameters: + if not param.requires_grad: + continue + + if param not in vars_conf: + raise ValueError( + f"Parameter <{param}> is not in the corresponding nn.Graph/nn.Module." + " Please make sure you call the module's to(..)/to_consistent(...) method first," + " then add the module's parameters into an optimizer." 
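# A small sketch of saving and restoring a scheduler via the state_dict support shown
# above. StepLR is only a stand-in; any scheduler deriving from LrScheduler (including the
# warm-up wrapper, which now nests its inner scheduler's state) is expected to behave alike.
import oneflow as flow

net = flow.nn.Linear(4, 4)
opt = flow.optim.SGD(net.parameters(), lr=0.1)
sched = flow.optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.5)
sched.step()
state = sched.state_dict()   # plain dict: every attribute except the optimizer itself

restored = flow.optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.5)
restored.load_state_dict(state)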
+ ) + + def _check_variables_optimizer_bound(self, vars_conf): + for param_group in self.param_groups: + for param in param_group.parameters: + if not param.requires_grad: + continue + + if vars_conf[param].bound_optimizer is None: + vars_conf[param].bound_optimizer = self + elif vars_conf[param].bound_optimizer is not self: + raise ValueError( + f"<{vars_conf[param].name}> is already bound to another optimizer." + ) + + def _generate_indexed_slices_optimizer_conf(self, job_conf, vars_conf): + if not self.support_sparse: + raise ValueError(f"{self.__class__} does not support sparse updating.") + + for param_group in self.param_groups: + for param in param_group.parameters: + if not param.requires_grad: + continue + + sparse_opt_conf = job_conf.mutable_indexed_slices_optimizer_conf() + sparse_variable_op_names = sparse_opt_conf.mutable_include_op_names() + sparse_variable_op_names.add_op_name(vars_conf[param].name) diff --git a/python/oneflow/nn/optimizer/rmsprop.py b/python/oneflow/nn/optimizer/rmsprop.py index ce7aaacb1e8..9386a8ff278 100644 --- a/python/oneflow/nn/optimizer/rmsprop.py +++ b/python/oneflow/nn/optimizer/rmsprop.py @@ -155,24 +155,18 @@ def __init__( self._state[param] = dict() self._centered_rmsprop = ( - flow.builtin_op("rmsprop_update") + flow.stateful_op("rmsprop_update") .Input("model") .Input("model_diff") .Input("mean_square") .Input("mean_gradient") - .Attr("centered", True) - .Attr("l1", 0.0) - .Attr("l2", 0.0) .Build() ) self._rmsprop = ( - flow.builtin_op("rmsprop_update") + flow.stateful_op("rmsprop_update") .Input("model") .Input("model_diff") .Input("mean_square") - .Attr("centered", False) - .Attr("l1", 0.0) - .Attr("l2", 0.0) .Build() ) @@ -189,7 +183,7 @@ def step(self, closure: Callable = None): loss = closure() for param_group in self.param_groups: kwargs = { - "learning_rate_val": param_group["lr"], + "learning_rate": param_group["lr"], "epsilon": param_group["eps"], "decay_rate": param_group["alpha"], "l2": param_group["weight_decay"], @@ -206,11 +200,16 @@ def step(self, closure: Callable = None): if "grad_avg" not in self._state[param]: self._state[param]["grad_avg"] = flow.zeros_like(param) mg_tensor = self._state[param]["grad_avg"] - self._centered_rmsprop( - param, param.grad, ms_tensor, mg_tensor, **kwargs + flow._C.dispatch_rmsprop_update( + self._centered_rmsprop, + (param, param.grad, ms_tensor, mg_tensor), + centered=True, + **kwargs, ) else: - self._rmsprop(param, param.grad, ms_tensor, **kwargs) + flow._C.dispatch_rmsprop_update( + self._rmsprop, (param, param.grad, ms_tensor), **kwargs + ) self._state["step"] = self._state["step"] + 1 return loss diff --git a/python/oneflow/nn/optimizer/sgd.py b/python/oneflow/nn/optimizer/sgd.py index f726cee2a25..33ee2d401f4 100644 --- a/python/oneflow/nn/optimizer/sgd.py +++ b/python/oneflow/nn/optimizer/sgd.py @@ -117,21 +117,14 @@ def __init__( self._state[param] = dict() self._momentum_sgd = ( - flow.builtin_op("momentum_update") + flow.stateful_op("momentum_update") .Input("model") .Input("model_diff") .Input("momentum") - .Attr("l1", 0.0) - .Attr("weight_decay", 0.0) .Build() ) self._sgd = ( - flow.builtin_op("sgd_update") - .Input("model") - .Input("model_diff") - .Attr("weight_decay", 0.0) - .Attr("l1", 0.0) - .Build() + flow.stateful_op("sgd_update").Input("model").Input("model_diff").Build() ) def step(self, closure: Callable = None): @@ -146,17 +139,18 @@ def step(self, closure: Callable = None): if param.grad is None: continue if param_group["momentum"] == 0.0: - self._sgd(param, param.grad, 
learning_rate_val=lr, l2=l2) + flow._C.dispatch_sgd_update( + self._sgd, (param, param.grad), learning_rate=lr, l2=l2 + ) else: if "momentum_buf" not in self._state[param]: self._state[param]["momentum_buf"] = flow.zeros_like(param) momentum_buf = self._state[param]["momentum_buf"] beta = param_group["momentum"] - self._momentum_sgd( - param, - param.grad, - momentum_buf, - learning_rate_val=lr, + flow._C.dispatch_momentum_update( + self._momentum_sgd, + (param, param.grad, momentum_buf), + learning_rate=lr, l2=l2, beta=beta, ) @@ -190,3 +184,7 @@ def _generate_conf_for_graph(self, train_conf, vars_conf): new_opt_confs.append(optimizer_conf) return new_opt_confs + + @property + def support_sparse(self): + return True diff --git a/python/oneflow/nn/optimizer/sparse_optimizer.py b/python/oneflow/nn/optimizer/sparse_optimizer.py new file mode 100644 index 00000000000..39d85a87086 --- /dev/null +++ b/python/oneflow/nn/optimizer/sparse_optimizer.py @@ -0,0 +1,41 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from oneflow.nn.optimizer.optimizer import Optimizer + + +class SparseOptimizer(object): + r"""SparseOptimizer do not support eager mode for now. If we need sparse optimizer + in graph mode, use SparseOptimizer to wrap the instance of Optimizer and add SparseOptimizer + to graph through nn.Graph.add_optimizer. 
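# A hedged sketch of the intended use described above: wrap an existing optimizer and add
# it to an nn.Graph so that its variables receive indexed-slices (sparse) updates. Whether
# add_optimizer unwraps SparseOptimizer in this exact revision is assumed, not verified.
import oneflow as flow
from oneflow.nn.optimizer.sparse_optimizer import SparseOptimizer

emb = flow.nn.Embedding(1000, 64)
sgd = flow.optim.SGD(emb.parameters(), lr=0.1)

class SparseTrainGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.emb = emb
        self.add_optimizer(SparseOptimizer(sgd))

    def build(self, ids):
        loss = self.emb(ids).sum()
        loss.backward()
        return loss

# loss = SparseTrainGraph()(flow.randint(0, 1000, (8,)))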
+ """ + + def __init__(self, optimizer: Optimizer): + self._nested_optim = optimizer + + def load_state_dict(self, state_dict): + self._nested_optim.load_state_dict(state_dict) + + def state_dict(self): + return self._nested_optim.state_dict() + + def step(self, closure): + raise NotImplementedError("SparseOptimizer doesn't support step for now") + + def clip_grad(self): + raise NotImplementedError("SparseOptimizer doesn't support clip_grad for now") + + def zero_grad(self, set_to_none: bool = False): + raise NotImplementedError("SparseOptimizer doesn't support zero_grad for now") diff --git a/python/oneflow/nn/parallel/ddp.py b/python/oneflow/nn/parallel/ddp.py index 9064bb249a7..47bad86d0e6 100644 --- a/python/oneflow/nn/parallel/ddp.py +++ b/python/oneflow/nn/parallel/ddp.py @@ -17,7 +17,6 @@ import oneflow as flow from oneflow.framework.tensor_tuple_util import convert_to_tensor_tuple -from oneflow.ops.builtin_ops import BuiltinOp as builtin_op def allreduce_fn(ddp_state_for_reversed_params, param): diff --git a/python/oneflow/ops/initializer_util.py b/python/oneflow/ops/initializer_util.py index 36aaf7f4afc..2a21173eb95 100644 --- a/python/oneflow/ops/initializer_util.py +++ b/python/oneflow/ops/initializer_util.py @@ -22,6 +22,7 @@ import oneflow as flow import oneflow.core.job.initializer_conf_pb2 as initializer_conf_util import oneflow.core.operator.op_conf_pb2 as op_conf_util +import oneflow.framework.dtype as dtype_util def constant_initializer( @@ -1206,3 +1207,13 @@ def EmptyInitializerImpl( var_blob_shape: Sequence[int], ): return None + + +def _elem_cnt(shape): + return np.prod(shape).astype(int).item() + + +def generate_values_by_initializer(initializer, shape, dtype): + np_dtype = np.dtype(dtype_util.convert_oneflow_dtype_to_numpy_dtype(dtype)) + length = _elem_cnt(shape) + return np.array(initializer(length)).astype(np_dtype).reshape(shape) diff --git a/python/oneflow/ops/builtin_ops.py b/python/oneflow/ops/stateful_ops.py similarity index 66% rename from python/oneflow/ops/builtin_ops.py rename to python/oneflow/ops/stateful_ops.py index 00b73e97f61..19c3a41e51b 100644 --- a/python/oneflow/ops/builtin_ops.py +++ b/python/oneflow/ops/stateful_ops.py @@ -16,10 +16,9 @@ import oneflow import oneflow._oneflow_internal import oneflow.framework.id_util as id_util -from oneflow.framework.attr_util import convert_to_user_attr_value -class BuiltinOp(object): +class StatefulOp(object): def __init__(self, op_type_name, op_name=None): if op_name is None: op_name = id_util.UniqueStr(op_type_name) @@ -66,31 +65,6 @@ def Output(self, output_name, num=1): self._builder.output(output_name, num) return self - def Attr(self, attr_name, attr_value, attr_type_name=None): - """Set value of op's attribute. - - Args: - attr_name (str): attribute name of op - attr_value (Any): attribute value of op - - Raises: - ValueError: raised when value is not idential to op's attribute type. - - Returns: - [type]: [description] - """ - if attr_type_name is not None: - print( - 'WARNING: Argument \'attr_type_name\' of UserOpConfBuilder.Attr has been deprecated. 
Please remove it.\n\n For instance:\n - .Attr("out_num", out_num, "AttrTypeInt64")\n + .Attr("out_num", out_num)\n ' - ) - print(traceback.format_stack()[-2]) - assert self._op_type_name is not None - self._builder.attr( - attr_name, - convert_to_user_attr_value(self._op_type_name, attr_name, attr_value), - ) - return self - def Build(self): """Explicitly complete the construction of the builtin op diff --git a/python/oneflow/optim/__init__.py b/python/oneflow/optim/__init__.py index f01268b15dd..baf0061095c 100644 --- a/python/oneflow/optim/__init__.py +++ b/python/oneflow/optim/__init__.py @@ -21,3 +21,4 @@ from oneflow.nn.optimizer.adagrad import Adagrad from . import lr_scheduler +from . import utils diff --git a/python/oneflow/optim/utils.py b/python/oneflow/optim/utils.py new file mode 100644 index 00000000000..faa073f586b --- /dev/null +++ b/python/oneflow/optim/utils.py @@ -0,0 +1,16 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from oneflow.nn.optimizer.sparse_optimizer import SparseOptimizer diff --git a/python/oneflow/test/graph/test_graph_eye.py b/python/oneflow/test/graph/test_graph_eye.py new file mode 100644 index 00000000000..64bf225eae2 --- /dev/null +++ b/python/oneflow/test/graph/test_graph_eye.py @@ -0,0 +1,38 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import unittest +import numpy as np +import random +import oneflow as flow +import oneflow.unittest +from test_util import generate_graph + + +@flow.unittest.skip_unless_1n1d() +class TestEyeGraph(oneflow.unittest.TestCase): + def test_eye_graph(test_case): + n = random.randint(1, 10) + m = random.randint(1, 10) + + eye_fn = lambda: flow.eye(n, m) + y_eager = eye_fn() + eye_graph = generate_graph(eye_fn) + y_lazy = eye_graph() + test_case.assertTrue(np.array_equal(y_eager.numpy(), y_lazy.numpy())) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/graph/test_graph_free_eager_tensor.py b/python/oneflow/test/graph/test_graph_free_eager_tensor.py index 08de661599c..f4890daa61c 100644 --- a/python/oneflow/test/graph/test_graph_free_eager_tensor.py +++ b/python/oneflow/test/graph/test_graph_free_eager_tensor.py @@ -21,22 +21,21 @@ import oneflow.unittest -class MyModuleWithEagerTensorForward(flow.nn.Module): - def __init__(self): - super().__init__() - self.linear = flow.nn.Linear(3, 8, False) - - def forward(self, x): - y0 = self.linear(x) - eager_t = flow.tensor([1.0], dtype=y0.dtype, device=y0.device) - out = y0 + eager_t - return out - - @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n1d() class TestGraphWithEagerTensorCaught(oneflow.unittest.TestCase): def test_eager_tensor_forward_graph(test_case): + class MyModuleWithEagerTensorForward(flow.nn.Module): + def __init__(self): + super().__init__() + self.linear = flow.nn.Linear(3, 8, False) + + def forward(self, x): + y0 = self.linear(x) + eager_t = flow.tensor([1.0], dtype=y0.dtype, device=y0.device) + out = y0 + eager_t + return out + my_net_module = MyModuleWithEagerTensorForward() flow.nn.init.constant_(my_net_module.linear.weight, 2.3) x = np.random.randn(5, 3) @@ -84,6 +83,38 @@ def build(self): np.allclose(graph_out.numpy(), eager_out.numpy(), atol=1e-4, rtol=1e-4) ) + def test_two_graph_caught_same_free_eager_tensor(test_case): + np_x = np.random.randn(5, 3) + np_y = np.random.randn(5, 3) + x = flow.tensor(np_x, dtype=flow.float32) + y = flow.tensor(np_y, dtype=flow.float32) + + class GraphAdd(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self): + return x + y + + class GraphMul(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self): + return x * y + + g_add = GraphAdd() + g_mul = GraphMul() + + add_out = g_add() + mul_out = g_mul() + test_case.assertTrue( + np.allclose(add_out.numpy(), np_x + np_y, atol=1e-4, rtol=1e-4) + ) + test_case.assertTrue( + np.allclose(mul_out.numpy(), np_x * np_y, atol=1e-4, rtol=1e-4) + ) + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() diff --git a/python/oneflow/test/graph/test_graph_inplace_add.py b/python/oneflow/test/graph/test_graph_inplace_add.py new file mode 100644 index 00000000000..153dd12ccee --- /dev/null +++ b/python/oneflow/test/graph/test_graph_inplace_add.py @@ -0,0 +1,74 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import unittest +import numpy as np + +import oneflow as flow +import oneflow.unittest + + +def _test_graph_lazy_inplace(test_case, x, y): + class LazyInplaceAdd(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self, x, y): + x += y + return x + + z = LazyInplaceAdd()(x, y) + test_case.assertTrue(np.allclose(z.numpy(), (x + y).numpy(), 1e-05, 1e-05)) + + +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n1d() +class TestLocalInplace(oneflow.unittest.TestCase): + def test_graph_inplace_gpu(test_case): + x = flow.randn(10, 10, device=flow.device("cuda")) + y = flow.ones(10, device=flow.device("cuda")) + _test_graph_lazy_inplace(test_case, x, y) + + def test_graph_inplace_cpu(test_case): + x = flow.randn(10, 10, device=flow.device("cpu")) + y = flow.ones(10, device=flow.device("cpu")) + _test_graph_lazy_inplace(test_case, x, y) + + +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n2d() +class TestConsistentInplace(oneflow.unittest.TestCase): + def test_graph_inplace_gpu(test_case): + x = flow.randn( + 10, 10, placement=flow.placement("cuda", {0: [0, 1]}), sbp=flow.sbp.split(1) + ) + y = flow.ones( + 10, placement=flow.placement("cuda", {0: [0, 1]}), sbp=flow.sbp.broadcast + ) + _test_graph_lazy_inplace(test_case, x, y) + + def test_graph_inplace_cpu(test_case): + x = flow.randn( + 10, 10, placement=flow.placement("cpu", {0: [0, 1]}), sbp=flow.sbp.split(1) + ) + y = flow.ones( + 10, placement=flow.placement("cpu", {0: [0, 1]}), sbp=flow.sbp.broadcast + ) + _test_graph_lazy_inplace(test_case, x, y) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/graph/test_graph_optimizer.py b/python/oneflow/test/graph/test_graph_optimizer.py index d32ed6225af..8797bf7a930 100644 --- a/python/oneflow/test/graph/test_graph_optimizer.py +++ b/python/oneflow/test/graph/test_graph_optimizer.py @@ -140,8 +140,8 @@ def __init__(self): self.add_optimizer(sgd0, lr_sch=constant_warmup_cosine_lr0) self.add_optimizer(sgd1, lr_sch=linear_warmup_cosine_lr1) - def build(self, x, y): - out0, out1 = self.m(x, y) + def build(self, x): + out0, out1 = self.m(x) out0.backward() out1.backward() return out0, out1 @@ -149,6 +149,7 @@ def build(self, x, y): g = CustomGraph0() x = flow.Tensor(4, 10) flow.nn.init.uniform_(x, a=-1.0, b=1.0) + g._filter_states() g._generate_config_proto() print("repr(g): \n", repr(g)) print("g.config.proto: \n", g.config.proto) diff --git a/python/oneflow/test/graph/test_graph_reuse_var.py b/python/oneflow/test/graph/test_graph_reuse_var.py new file mode 100644 index 00000000000..ccb18225541 --- /dev/null +++ b/python/oneflow/test/graph/test_graph_reuse_var.py @@ -0,0 +1,98 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import os +import unittest +from collections import OrderedDict + +import numpy as np +from test_util import GenArgList + +import oneflow as flow +import oneflow.unittest + + +@flow.unittest.skip_unless_1n2d() +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +class TestGraphResueVar(flow.unittest.TestCase): + def test_graph_reuse_var(test_case): + rank = flow.env.get_rank() + P = flow.placement("cuda", {0: [0, 1]}) + B = flow.sbp.broadcast + + class ReuseVarModule(flow.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = flow.nn.Linear(2, 2) + self.linear2 = flow.nn.Linear(2, 2) + # Reuse parameter + self.linear2.weight = self.linear1.weight + + def forward(self, x): + # Allow user to call parameter outside it's module. + self.linear1.weight + x = self.linear1(x) + x = self.linear2(x) + return x + + reuse_var_m = ReuseVarModule() + reuse_var_m.to_consistent(placement=P, sbp=B) + of_sgd = flow.optim.SGD(reuse_var_m.parameters(), lr=0.001, momentum=0.9) + + class ReuseVarGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + self.reuse_var_m = reuse_var_m + self.add_optimizer(of_sgd) + + def build(self, x): + x = self.reuse_var_m(x) + loss = x.sum() + loss.backward() + return loss + + x = flow.randint(0, 1, (2, 2), placement=P, sbp=B, dtype=flow.float32) + reuse_var_g = ReuseVarGraph() + loss = reuse_var_g(x) + + # check lazy tensor builder + block = reuse_var_g.reuse_var_m + test_case.assertEqual( + block.linear1.weight.lazy_origin_builder().name, + "reuse_var_m.linear1.weight", + ) + test_case.assertEqual( + block.linear1.weight.lazy_origin_builder().name, + block.linear2.weight.lazy_origin_builder().name, + ) + + # check optimizer's variable list + var_list = [ + "reuse_var_m.linear1.weight", + "reuse_var_m.linear1.bias", + "reuse_var_m.linear2.bias", + ] + var_list_in_conf = reuse_var_g._graph_proto.job_conf.train_conf.optimizer_conf[ + 0 + ].variable_op_names + test_case.assertEqual(len(var_list_in_conf), 3) + for idx in range(3): + test_case.assertEqual(var_list[idx], var_list_in_conf[idx]) + if rank == 0: + print(var_list_in_conf[idx]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/graph/test_graph_sparse_optimizer.py b/python/oneflow/test/graph/test_graph_sparse_optimizer.py new file mode 100644 index 00000000000..0154bb591e5 --- /dev/null +++ b/python/oneflow/test/graph/test_graph_sparse_optimizer.py @@ -0,0 +1,74 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import os +import unittest + +import oneflow as flow +import oneflow.unittest + + +class MyModule(flow.nn.Module): + def __init__(self, placement=None, sbp=None): + super().__init__() + w = flow.randn(10, 10, placement=placement, sbp=sbp) + self.weight = flow.nn.Parameter(w) + + def forward(self, input): + return flow._C.gather(self.weight, input, 0) + + +class MyGraph(flow.nn.Graph): + def __init__(self, module): + super().__init__() + self.m = module + sgd = flow.optim.SGD(module.parameters(), lr=1e-3) + self.add_optimizer(flow.optim.utils.SparseOptimizer(sgd)) + + def build(self, input): + result = self.m(input) + result.mean().backward() + + +def _rand_input(placement=None, sbp=None): + generator = flow.Generator() + generator.manual_seed(0) + return flow.randint(0, 10, (8,), generator=generator, placement=placement, sbp=sbp) + + +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n1d() +class GraphSparseOptimizerTest(oneflow.unittest.TestCase): + def test(test_case): + PLC = flow.placement("cuda", {0: [0]}) + SBP = flow.sbp.broadcast + m = MyModule(PLC, SBP) + graph = MyGraph(m) + graph._compile(_rand_input(PLC, SBP)) + + sparse_optimizer_found = False + for op in graph._full_graph_proto.net.op: + # print("==>", op.name) + if op.HasField("user_conf"): + # print(" -->", op.user_conf.op_type_name) + if op.user_conf.op_type_name == "indexed_slices_sgd_update": + sparse_optimizer_found = True + break + + test_case.assertTrue(sparse_optimizer_found) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/graph/test_input_op_expr.py b/python/oneflow/test/graph/test_input_op_expr.py index 7685b957d62..57c12b2ea10 100644 --- a/python/oneflow/test/graph/test_input_op_expr.py +++ b/python/oneflow/test/graph/test_input_op_expr.py @@ -22,6 +22,7 @@ import oneflow import oneflow as flow import oneflow._oneflow_internal +import oneflow._oneflow_internal._C as _C import oneflow.framework.c_api_util as c_api_util import oneflow.framework.session_context as session_ctx import oneflow.unittest @@ -58,7 +59,7 @@ def test_feed_input_tensor(test_case): op_name, input_conf, ["in_0"], ["out_0"] ) attrs = oneflow._oneflow_internal.MutableCfgAttrMap() - out_tensor = input_op.apply([x], attrs)[0] + out_tensor = _C.dispatch_feed_input(input_op, x) test_case.assertEqual(out_tensor.shape, (1, 1, 10, 10)) test_case.assertTrue(out_tensor.is_lazy) test_case.assertTrue(out_tensor.is_local) diff --git a/python/oneflow/test/graph/test_output_op_expr.py b/python/oneflow/test/graph/test_output_op_expr.py index c159182fb79..d6901bed6a6 100644 --- a/python/oneflow/test/graph/test_output_op_expr.py +++ b/python/oneflow/test/graph/test_output_op_expr.py @@ -22,6 +22,7 @@ import oneflow import oneflow as flow import oneflow._oneflow_internal +import oneflow._oneflow_internal._C as _C import oneflow.framework.c_api_util as c_api_util import oneflow.framework.session_context as session_ctx import oneflow.unittest @@ -48,7 +49,6 @@ def test_fetch_output_tensor(test_case): job_conf.set_job_name("cc_test_output_op_expr_job") job_conf.mutable_predict_conf() c_api_util.CurJobBuildAndInferCtx_SetJobConf(job_conf) - attrs = oneflow._oneflow_internal.MutableCfgAttrMap() input_conf = ( oneflow._oneflow_internal.oneflow.core.operator.op_conf.FeedInputOpConf() ) @@ -65,11 +65,11 @@ def test_fetch_output_tensor(test_case): output_op = oneflow._oneflow_internal.one.FetchOutputOpExpr( "cc_Output_0", output_conf, ["in_0"], ["out_0"] ) - lazy_tensor = 
input_op.apply([x], attrs)[0] + lazy_tensor = _C.dispatch_feed_input(input_op, x) test_case.assertEqual(lazy_tensor.shape, (1, 1, 10, 10)) test_case.assertTrue(lazy_tensor.is_lazy) test_case.assertTrue(lazy_tensor.is_local) - eager_tensor = output_op.apply([lazy_tensor], attrs)[0] + eager_tensor = _C.dispatch_fetch_output(output_op, lazy_tensor) test_case.assertEqual(eager_tensor.shape, (1, 1, 10, 10)) test_case.assertTrue(not eager_tensor.is_lazy) test_case.assertTrue(eager_tensor.is_local) diff --git a/python/oneflow/test/graph/test_user_op_expr.py b/python/oneflow/test/graph/test_user_op_expr.py index 9f0c2a35dda..0576c718b4c 100644 --- a/python/oneflow/test/graph/test_user_op_expr.py +++ b/python/oneflow/test/graph/test_user_op_expr.py @@ -21,6 +21,7 @@ import oneflow import oneflow as flow import oneflow._oneflow_internal +import oneflow._oneflow_internal._C as _C import oneflow.framework.c_api_util as c_api_util import oneflow.framework.session_context as session_ctx import oneflow.unittest @@ -95,18 +96,13 @@ def _test_user_op_graph(test_case, is_cuda): "cc_Output_0", output_conf, ["in_0"], ["out_0"] ) - attrs = oneflow._oneflow_internal.MutableCfgAttrMap() - - x0_tensor_in_c = _get_c_tensor(x0) - x1_tensor_in_c = _get_c_tensor(x1) - weight0_tensor_in_c = _get_c_tensor(weight0) - - x0_lazy_tensor = x0_op.apply([x0_tensor_in_c], attrs)[0] - x1_lazy_tensor = x1_op.apply([x1_tensor_in_c], attrs)[0] - weight0_lazy_tensor = weight0_op.apply([weight0_tensor_in_c], attrs)[0] + x0_lazy_tensor = _C.dispatch_feed_input(x0_op, x0) + x1_lazy_tensor = _C.dispatch_feed_input(x1_op, x1) + weight0_lazy_tensor = _C.dispatch_feed_input(weight0_op, weight0) test_case.assertEqual(x0_lazy_tensor.shape, (20, 30)) test_case.assertTrue(x0_lazy_tensor.is_lazy) + test_case.assertEqual(weight0_lazy_tensor.shape, (30, 50)) test_case.assertTrue(weight0_lazy_tensor.is_lazy) test_case.assertEqual(x1_lazy_tensor.shape, (50, 70)) @@ -128,7 +124,7 @@ def _test_user_op_graph(test_case, is_cuda): test_case.assertEqual(y1.shape, (20, 70)) test_case.assertTrue(y1.is_lazy) - eager_output = output_op.apply([y1], attrs)[0] + eager_output = _C.dispatch_fetch_output(output_op, y1) test_case.assertEqual(eager_output.shape, (20, 70)) test_case.assertTrue(not eager_output.is_lazy) diff --git a/python/oneflow/test/graph/test_util.py b/python/oneflow/test/graph/test_util.py index 5910e2d2d26..ce9f4f39e89 100644 --- a/python/oneflow/test/graph/test_util.py +++ b/python/oneflow/test/graph/test_util.py @@ -117,3 +117,14 @@ def Coordinate2Index(coordinate, tensor_shape): size_at_axis *= tensor_shape[j] idx += size_at_axis return idx + + +def generate_graph(func): + class Graph(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self, *args): + return func(*args) + + return Graph() diff --git a/python/oneflow/test/graph/test_variable_op_expr.py b/python/oneflow/test/graph/test_variable_op_expr.py index 3e638e7a9a8..f7cfb258e40 100644 --- a/python/oneflow/test/graph/test_variable_op_expr.py +++ b/python/oneflow/test/graph/test_variable_op_expr.py @@ -22,6 +22,7 @@ import oneflow import oneflow as flow import oneflow._oneflow_internal +import oneflow._oneflow_internal._C as _C import oneflow.framework.c_api_util as c_api_util import oneflow.framework.session_context as session_ctx import oneflow.unittest @@ -57,8 +58,7 @@ def test_feed_var_tensor(test_case): var_op = oneflow._oneflow_internal.one.FeedVariableOpExpr( op_name, var_conf, ["in_0"], ["out_0"] ) - attrs = oneflow._oneflow_internal.MutableCfgAttrMap() - 
out_tensor = var_op.apply([x], attrs)[0] + out_tensor = _C.dispatch_feed_variable(var_op, x, l2=0) test_case.assertEqual(out_tensor.shape, (1, 1, 10, 10)) test_case.assertTrue(out_tensor.is_lazy) test_case.assertTrue(out_tensor.is_local) diff --git a/python/oneflow/test/modules/resnet50_model.py b/python/oneflow/test/modules/resnet50_model.py index 9782878621e..8a2aa565c05 100644 --- a/python/oneflow/test/modules/resnet50_model.py +++ b/python/oneflow/test/modules/resnet50_model.py @@ -49,10 +49,9 @@ def __init__( else: self.register_parameter("running_mean", None) self.register_parameter("running_var", None) - self._op = flow.builtin_op("identity").Input("in").Output("out").Build() def forward(self, input): - return self._op(input)[0] + return flow._C.identity(input) def conv3x3( diff --git a/python/oneflow/test/modules/test_abs.py b/python/oneflow/test/modules/test_abs.py index d51d2470cc6..ae4bfcb9f1d 100644 --- a/python/oneflow/test/modules/test_abs.py +++ b/python/oneflow/test/modules/test_abs.py @@ -23,8 +23,8 @@ @flow.unittest.skip_unless_1n1d() class TestAbsModule(flow.unittest.TestCase): - @autotest(check_graph=False) - def test_abs_with_0shape_data(test_case): + @autotest(check_graph=True) + def test_abs_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor().to(device) y = torch.abs(x) diff --git a/python/oneflow/test/modules/test_activation.py b/python/oneflow/test/modules/test_activation.py index 12d22fdf72e..bae166eb9ae 100644 --- a/python/oneflow/test/modules/test_activation.py +++ b/python/oneflow/test/modules/test_activation.py @@ -29,7 +29,7 @@ @flow.unittest.skip_unless_1n1d() class TestReLUModule(flow.unittest.TestCase): - @autotest() + @autotest(check_graph=True) def test_relu_module_with_random_data(test_case): m = torch.nn.ReLU() m.train(random()) @@ -39,8 +39,8 @@ def test_relu_module_with_random_data(test_case): y = m(x) return y - @autotest(auto_backward=False, check_graph=False) - def test_relu_module_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_relu_module_with_0_size_data(test_case): m = torch.nn.ReLU() m.train(random()) device = random_device() @@ -52,7 +52,7 @@ def test_relu_module_with_0shape_data(test_case): @flow.unittest.skip_unless_1n1d() class TestReLU6Module(flow.unittest.TestCase): - @autotest() + @autotest(check_graph=True) def test_relu6_module_with_random_data(test_case): m = torch.nn.ReLU6() m.train(random()) @@ -62,8 +62,8 @@ def test_relu6_module_with_random_data(test_case): y = m(x) return y - @autotest(auto_backward=False, check_graph=False) - def test_relu6_module_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_relu6_module_with_0_size_data(test_case): m = torch.nn.ReLU6() m.train(random()) device = random_device() @@ -85,8 +85,8 @@ def test_tanh_module_with_random_data(test_case): y = m(x) return y - @autotest(auto_backward=False, check_graph=False) - def test_tanh_module_with_0shapedata(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_tanh_module_with_0_size_data(test_case): m = torch.nn.Tanh() m.train(random()) device = random_device() @@ -102,8 +102,8 @@ def test_flow_tanh_with_random_data(test_case): y = torch.tanh(x) return y - @autotest(auto_backward=False, check_graph=False) - def test_flow_tanh_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_flow_tanh_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 3, 0, 
3).to(device) y = torch.tanh(x) @@ -122,8 +122,8 @@ def test_elu_module_with_random_data(test_case): y = m(x) return y - @autotest(auto_backward=False, check_graph=False) - def test_elu_module_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_elu_module_with_0_size_data(test_case): m = torch.nn.ELU(alpha=random() | nothing()) m.train(random()) device = random_device() @@ -145,8 +145,8 @@ def test_celu_module_with_random_data(test_case): y = m(x) return y - @autotest(auto_backward=False, check_graph=False) - def test_celu_module_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_celu_module_with_0_size_data(test_case): m = torch.nn.CELU(alpha=random() | nothing()) m.train(random()) device = random_device() diff --git a/python/oneflow/test/modules/test_adaptive_pool.py b/python/oneflow/test/modules/test_adaptive_pool.py index 5d318726a8a..97adbc4eb0e 100644 --- a/python/oneflow/test/modules/test_adaptive_pool.py +++ b/python/oneflow/test/modules/test_adaptive_pool.py @@ -84,15 +84,11 @@ def test_adaptive_avgpool2d_functional(test_case): x = random_pytorch_tensor(ndim=4).to(device) return torch.nn.functional.adaptive_avg_pool2d(x, output_size=random().to(int)) - @unittest.skipIf( - version.parse(torch_original.__version__) < version.parse("1.10.0"), - "GPU version 'nn.AdaptiveAvgPool3d' has a bug in PyTorch before '1.10.0'", - ) @autotest() def test_adaptive_avgpool3d_functional(test_case): device = random_device() x = random_pytorch_tensor(ndim=5).to(device) - return torch.nn.functional.adaptive_avg_pool2d(x, output_size=random().to(int)) + return torch.nn.functional.adaptive_avg_pool3d(x, output_size=random().to(int)) if __name__ == "__main__": diff --git a/python/oneflow/test/modules/test_add.py b/python/oneflow/test/modules/test_add.py index 8d0b679d648..8b5b0ace7ad 100644 --- a/python/oneflow/test/modules/test_add.py +++ b/python/oneflow/test/modules/test_add.py @@ -170,8 +170,8 @@ def test_add(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(check_graph=False) - def test_0shape_add(test_case): + @autotest(check_graph=True) + def test_0_size_add(test_case): device = random_device() x = random_pytorch_tensor(2, 0, 3).to(device) y = random_pytorch_tensor(2, 1, 3).to(device) @@ -194,7 +194,7 @@ def test_0dim_two_inplace_add(test_case): x += y.mean() return x - @autotest(check_graph=False) + @autotest(check_graph=True) def test_add_with_alpha(test_case): device = random_device() x1 = random_pytorch_tensor(2, 2, 3).to(device).mean() diff --git a/python/oneflow/test/modules/test_addmm.py b/python/oneflow/test/modules/test_addmm.py index 355ff577b61..3ca3d3569c6 100644 --- a/python/oneflow/test/modules/test_addmm.py +++ b/python/oneflow/test/modules/test_addmm.py @@ -67,7 +67,7 @@ def test_addmm(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_addmm_flow_with_random_data(test_case): device = random_device() input = random_pytorch_tensor(ndim=2, dim0=2, dim1=3).to(device) @@ -82,7 +82,7 @@ def test_addmm_flow_with_random_data(test_case): ) return y - @autotest(check_graph=False) + @autotest(check_graph=True) def test_addmm_broadcast_flow_with_random_data(test_case): device = random_device() input = random_pytorch_tensor(ndim=2, dim0=1, dim1=1).to(device) diff --git a/python/oneflow/test/modules/test_affine_grid.py b/python/oneflow/test/modules/test_affine_grid.py index 
c245d5e68f8..f91426a4e65 100644 --- a/python/oneflow/test/modules/test_affine_grid.py +++ b/python/oneflow/test/modules/test_affine_grid.py @@ -89,7 +89,7 @@ def test_affine_grid_3d(test_case): np.allclose(output.numpy(), groundtruth, rtol=1e-3, atol=1e-4) ) - @autotest(rtol=1e-03, atol=1e-04, check_graph=False) + @autotest(rtol=1e-03, atol=1e-04, check_allclose=False, check_graph=True) def test_flow_affine_grid_2d_with_random_data(test_case): N = randint(1, 8) C = randint(1, 8) @@ -103,7 +103,7 @@ def test_flow_affine_grid_2d_with_random_data(test_case): ).to(device) return output - @autotest(rtol=1e-03, atol=1e-03, check_graph=False) + @autotest(rtol=1e-03, atol=1e-03, check_allclose=False, check_graph=True) def test_flow_affine_grid_3d_with_random_data(test_case): N = randint(1, 8) C = randint(1, 8) diff --git a/python/oneflow/test/modules/test_arange.py b/python/oneflow/test/modules/test_arange.py index dd05da8c62c..56374b9fc6c 100644 --- a/python/oneflow/test/modules/test_arange.py +++ b/python/oneflow/test/modules/test_arange.py @@ -67,7 +67,7 @@ def test_arange(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=False) + @autotest(n=30, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True) def test_arange_with_random_data(test_case): start = random().to(int) end = start + random().to(int) @@ -77,6 +77,16 @@ def test_arange_with_random_data(test_case): x.to(device) return x + @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True) + def test_arange_with_float_delta(test_case): + start = random().to(int) + end = start + random().to(int) + step = random(0, end - start).to(float) + x = torch.arange(start=start, end=end, step=step) + device = random_device() + x.to(device) + return x + def test_consistent_naive(test_case): placement = flow.placement("cpu", {0: [0]}) sbp = (flow.sbp.broadcast,) diff --git a/python/oneflow/test/modules/test_argmax.py b/python/oneflow/test/modules/test_argmax.py index 4085611d7e1..273ef7ff2da 100644 --- a/python/oneflow/test/modules/test_argmax.py +++ b/python/oneflow/test/modules/test_argmax.py @@ -93,7 +93,7 @@ def test_argmax(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=False) + @autotest(n=5, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True) def test_argmax_with_random_data(test_case): device = random_device() ndim = random(1, 6).to(int) diff --git a/python/oneflow/test/modules/test_autograd.py b/python/oneflow/test/modules/test_autograd.py index 2e1af21a532..2e6fd3c93c7 100644 --- a/python/oneflow/test/modules/test_autograd.py +++ b/python/oneflow/test/modules/test_autograd.py @@ -87,7 +87,7 @@ def test_autograd_interface(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(n=10, auto_backward=True, rtol=1e-3, atol=1e-3, check_graph=False) + @autotest(n=10, auto_backward=True, rtol=1e-3, atol=1e-3, check_graph=True) def test_accumulate_grad(test_case): device = random_device() ndim = random(1, 4).to(int) diff --git a/python/oneflow/test/modules/test_cast.py b/python/oneflow/test/modules/test_cast.py index bf8782dc3c8..e2f2a17b12a 100644 --- a/python/oneflow/test/modules/test_cast.py +++ b/python/oneflow/test/modules/test_cast.py @@ -66,7 +66,7 @@ def test_cast(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - def test_cast_with_0shape_data(test_case): + def 
test_cast_with_0_size_data(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [ _test_cast_float2int, diff --git a/python/oneflow/test/modules/test_ceil.py b/python/oneflow/test/modules/test_ceil.py index 749267c11b3..fb457b4dfb7 100644 --- a/python/oneflow/test/modules/test_ceil.py +++ b/python/oneflow/test/modules/test_ceil.py @@ -32,8 +32,8 @@ def test_ceil_flow_with_random_data(test_case): y = torch.ceil(input) return y - @autotest(auto_backward=False, check_graph=False) - def test_ceil_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_ceil_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) y = torch.ceil(x) diff --git a/python/oneflow/test/modules/test_chunk.py b/python/oneflow/test/modules/test_chunk.py index a4ac939b47d..c8e41698eed 100644 --- a/python/oneflow/test/modules/test_chunk.py +++ b/python/oneflow/test/modules/test_chunk.py @@ -42,6 +42,21 @@ def test_flow_chunk_list_with_random_data(test_case): z = torch.cat(y, dim=dim) return z + @autotest(check_graph=False) + def test_flow_chunk_list_with_random_data_negative_dim(test_case): + device = random_device() + dim = random(1, 3).to(int) + x = random_pytorch_tensor( + ndim=4, + dim0=random(low=4, high=8).to(int), + dim1=random(low=4, high=8).to(int), + dim2=random(low=4, high=8).to(int), + dim3=random(low=4, high=8).to(int), + ).to(device) + y = torch.chunk(x, chunks=4, dim=-1) + z = torch.cat(y, dim=-1) + return z + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_clamp.py b/python/oneflow/test/modules/test_clamp.py index 36f9e24de35..1140a6732be 100644 --- a/python/oneflow/test/modules/test_clamp.py +++ b/python/oneflow/test/modules/test_clamp.py @@ -154,8 +154,8 @@ def test_clip_max_none_flow_with_random_data(test_case): ) return y - @autotest(auto_backward=False, check_graph=False) - def test_clamp_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_clamp_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) y = torch.clamp(x, min=random().to(float), max=random().to(float)) diff --git a/python/oneflow/test/modules/test_comm_ops.py b/python/oneflow/test/modules/test_comm_ops.py index 5a2dc78a6ec..8a03b3c8550 100644 --- a/python/oneflow/test/modules/test_comm_ops.py +++ b/python/oneflow/test/modules/test_comm_ops.py @@ -15,29 +15,42 @@ """ import numpy as np +import unittest import os + import oneflow as flow import oneflow.unittest -import unittest + +import torch +import torch.distributed as dist +@unittest.skip("comm test case has bug") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestAllReduce(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() def test_all_reduce_1n2d(test_case): np_arr = np.array([[1, 2], [3, 4]]) - input = flow.tensor(np_arr, device="cuda") - out = flow.comm.all_reduce(input) - test_case.assertTrue(np.allclose(out.numpy(), np_arr * 2)) + of_tensor = flow.tensor(np_arr, device="cuda") + flow.comm.all_reduce(of_tensor) + + if not torch.distributed.is_initialized(): + dist.init_process_group("gloo") + torch_tensor = torch.tensor(np_arr) + dist.all_reduce(torch_tensor) + + test_case.assertTrue(np.allclose(of_tensor.numpy(), torch_tensor.cpu().numpy())) + dist.destroy_process_group() @flow.unittest.skip_unless_2n2d() def test_all_reduce_2n2d(test_case): np_arr = np.array([[1, 2], [3, 4]]) - input = flow.tensor(np_arr, device="cuda") 
- out = flow.comm.all_reduce(input) - test_case.assertTrue(np.allclose(out.numpy(), np_arr * 4)) + tensor = flow.tensor(np_arr, device="cuda") + flow.comm.all_reduce(tensor) + test_case.assertTrue(np.allclose(tensor.numpy(), np_arr * 4)) +@unittest.skip("comm test case has bug") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestAllGather(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() @@ -46,17 +59,27 @@ def test_all_gather_1n2d(test_case): np_arr = np.array([[2, 3], [4, 5]]) elif flow.env.get_rank() == 1: np_arr = np.array([[1, 2], [3, 4]]) - input = flow.tensor(np_arr, device="cuda", dtype=flow.int32) - tensor_list = [flow.zeros(np_arr.shape, dtype=flow.int32) for _ in range(2)] - flow.comm.all_gather(tensor_list, input) + of_input = flow.tensor(np_arr, device="cuda", dtype=flow.int32) + of_tensor_list = [flow.zeros(np_arr.shape, dtype=flow.int32) for _ in range(2)] + flow.comm.all_gather(of_tensor_list, of_input) + + if not torch.distributed.is_initialized(): + dist.init_process_group("gloo") + torch_tensor_list = [ + torch.zeros(np_arr.shape, dtype=torch.int32) for _ in range(2) + ] + torch_input = torch.tensor(np_arr, dtype=torch.int32) + dist.all_gather(torch_tensor_list, torch_input) test_case.assertTrue( - np.allclose(tensor_list[0].numpy(), np.array([[2, 3], [4, 5]])) + np.allclose(of_tensor_list[0].numpy(), torch_tensor_list[0].cpu().numpy()) ) test_case.assertTrue( - np.allclose(tensor_list[1].numpy(), np.array([[1, 2], [3, 4]])) + np.allclose(of_tensor_list[1].numpy(), torch_tensor_list[1].cpu().numpy()) ) + dist.destroy_process_group() +@unittest.skip("comm test case has bug") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestBroadCast(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() @@ -65,74 +88,87 @@ def test_broadcast_1n2d(test_case): np_arr = np.array([[1, 2], [3, 4]]) elif flow.env.get_rank() == 1: np_arr = np.array([[4, 5], [6, 7]]) - tensor = flow.tensor(np_arr, device="cuda", dtype=flow.int32) - flow.comm.broadcast(tensor, 1) - test_case.assertTrue(np.allclose(tensor.numpy(), np.array([[4, 5], [6, 7]]))) + of_tensor = flow.tensor(np_arr, device="cuda", dtype=flow.int32) + flow.comm.broadcast(of_tensor, 1) + + if not torch.distributed.is_initialized(): + dist.init_process_group("gloo") + + torch_tensor = torch.tensor(np_arr, dtype=torch.int32) + dist.broadcast(torch_tensor, 1) + test_case.assertTrue(np.allclose(of_tensor.numpy(), torch_tensor.cpu().numpy())) - tensor = flow.tensor(np_arr, device="cuda", dtype=flow.int32) - flow.comm.broadcast(tensor, 0) - test_case.assertTrue(np.allclose(tensor.numpy(), np.array([[1, 2], [3, 4]]))) + of_tensor = flow.tensor(np_arr, device="cuda", dtype=flow.int32) + flow.comm.broadcast(of_tensor, 0) + torch_tensor = torch.tensor(np_arr, dtype=torch.int32) + dist.broadcast(torch_tensor, 0) + test_case.assertTrue(np.allclose(of_tensor.numpy(), torch_tensor.cpu().numpy())) + dist.destroy_process_group() +@unittest.skip("comm test case has bug") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestScatter(flow.unittest.TestCase): @flow.unittest.skip_unless_1n4d() def test_scatter_1n4d(test_case): - output = flow.tensor([[1, 2], [3, 4]]) + of_output = flow.tensor([[1, 2], [3, 4]], device="cuda") + torch_output = torch.tensor([[1, 2], [3, 4]]) + if not torch.distributed.is_initialized(): + dist.init_process_group("gloo") if flow.env.get_rank() == 1: - tensor_list = [flow.tensor([[5, 6], [7, 8]]) + i for i in range(4)] 
- flow.comm.scatter(output, tensor_list, src=1) - test_case.assertTrue( - np.allclose(output.numpy(), np.array([[6, 7], [8, 9]])) - ) + of_tensor_list = [ + flow.tensor([[5, 6], [7, 8]], device="cuda") + i for i in range(4) + ] + flow.comm.scatter(of_output, of_tensor_list, src=1) + + torch_tensor_list = [torch.tensor(x.numpy()) for x in of_tensor_list] + dist.scatter(torch_output, torch_tensor_list, src=1) + test_case.assertTrue(np.allclose(of_output.numpy(), torch_output.numpy())) else: - flow.comm.scatter(output, src=1) - test_case.assertTrue( - np.allclose( - output.numpy(), np.array([[5, 6], [7, 8]]) + flow.env.get_rank() - ) - ) + flow.comm.scatter(of_output, src=1) + + dist.scatter(torch_output, src=1) + test_case.assertTrue(np.allclose(of_output.numpy(), torch_output.numpy())) + dist.destroy_process_group() +@unittest.skip("comm test case has bug") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestGather(flow.unittest.TestCase): @flow.unittest.skip_unless_1n4d() def test_gather_1n4d(test_case): np_arr = np.array([[1, 2], [3, 4]]) + of_input = flow.tensor( + np_arr + flow.env.get_rank(), dtype=flow.int32, device="cuda" + ) + + if not torch.distributed.is_initialized(): + dist.init_process_group("gloo") + torch_input = torch.tensor(np_arr + dist.get_rank(), dtype=torch.int32) + if flow.env.get_rank() == 1: - input = flow.tensor( - np_arr + flow.env.get_rank(), device="cuda", dtype=flow.int32 - ) - tensor_list = [flow.zeros(np_arr.shape, dtype=flow.int32) for _ in range(4)] - flow.comm.gather(input, gather_list=tensor_list, dst=1) + of_tensor_list = [ + flow.zeros(np_arr.shape, dtype=flow.int32, device="cuda") + for _ in range(4) + ] + flow.comm.gather(of_input, gather_list=of_tensor_list, dst=1) + + torch_tensor_list = [ + torch.zeros(np_arr.shape, dtype=torch.int32) for _ in range(4) + ] + dist.gather(torch_input, gather_list=torch_tensor_list, dst=1) for i in range(4): test_case.assertTrue( - np.allclose(tensor_list[i].numpy(), np.array([[1, 2], [3, 4]]) + i) + np.allclose(of_tensor_list[i].numpy(), torch_tensor_list[i].numpy()) ) else: - input = flow.tensor( - np_arr + flow.env.get_rank(), device="cuda", dtype=flow.int32 - ) - flow.comm.gather(input, dst=1) - # this case will fail, if do gititem on some a rank in process group - if flow.env.get_rank() == 0: - np_arr = np.array([4, 6, 7, 8], dtype=np.float32) - else: - np_arr = np.array([0, 0, 0, 0], dtype=np.float32) - tensor = flow.tensor(np_arr, dtype=flow.float32) - placement = flow.placement("cuda", {0: range(4)}) - device = flow.device("cuda") - consistent_tensor = tensor.to_consistent(placement, flow.sbp.broadcast) - test_case.assertEqual(consistent_tensor.to_local().device, device) - test_case.assertEqual(consistent_tensor.placement, placement) - test_case.assertTrue( - np.array_equal( - consistent_tensor.to_local().numpy(), - np.array([4, 6, 7, 8], dtype=np.float32), - ) - ) + flow.comm.gather(of_input, dst=1) + dist.gather(torch_input, dst=1) + + dist.destroy_process_group() +@unittest.skip("comm test case has bug") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestReduce(flow.unittest.TestCase): @flow.unittest.skip_unless_1n2d() @@ -141,32 +177,88 @@ def test_reduce_1n2d(test_case): np_arr = np.array([[1, 2], [3, 4]]) elif flow.env.get_rank() == 1: np_arr = np.array([[4, 5], [6, 7]]) - tensor = flow.tensor(np_arr, device="cuda", dtype=flow.int32) - flow.comm.reduce(tensor, 0) + of_tensor = flow.tensor(np_arr, device="cpu", dtype=flow.int32) + 
flow.comm.reduce(of_tensor, 0) + + if not torch.distributed.is_initialized(): + dist.init_process_group("nccl") + torch.cuda.set_device(dist.get_rank()) + torch_tensor = torch.tensor(np_arr, dtype=torch.int32, device="cuda") + dist.reduce(torch_tensor, 0) + if flow.env.get_rank() == 0: test_case.assertTrue( - np.allclose(tensor.numpy(), np.array([[5, 7], [9, 11]])) + np.allclose(of_tensor.numpy(), torch_tensor.cpu().numpy()) ) else: test_case.assertTrue( - np.allclose(tensor.numpy(), np.array([[4, 5], [6, 7]])) + np.allclose(of_tensor.numpy(), torch_tensor.cpu().numpy()) + ) + dist.destroy_process_group() + + +@unittest.skip("comm test case has bug") +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +class TestAllToAll(flow.unittest.TestCase): + @flow.unittest.skip_unless_1n4d() + def test_all_to_all_1n4d(test_case): + of_input_list = [ + flow.tensor([0, 1], device="cpu") + i * 2 + flow.env.get_rank() * 8 + for i in range(4) + ] + of_output_list = [flow.tensor([0, 1], device="cpu") for _ in range(4)] + flow.comm.all_to_all(of_output_list, of_input_list) + + # only nccl support + if not torch.distributed.is_initialized(): + dist.init_process_group("nccl") + torch_input_list = [ + torch.tensor(x.numpy()).to("cuda:{}".format(dist.get_rank())) + for x in of_input_list + ] + torch_output_list = [ + torch.tensor(x.numpy()).to("cuda:{}".format(dist.get_rank())) + for x in of_output_list + ] + dist.all_to_all(torch_output_list, torch_input_list) + + for i in range(len(of_output_list)): + test_case.assertTrue( + np.allclose( + of_output_list[i].numpy(), torch_output_list[i].cpu().numpy(), + ) ) + dist.destroy_process_group() +@unittest.skip("comm test case has bug") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") class TestReduceScatter(flow.unittest.TestCase): @flow.unittest.skip_unless_1n4d() def test_reduce_scatter_1n4d(test_case): - output = flow.tensor([[0, 0], [0, 0]]) - tensor_list = [ - flow.tensor([[1, 2], [3, 4]]) + flow.env.get_rank() + i for i in range(4) + of_output = flow.tensor([[0, 0], [0, 0]], device="cpu") + of_tensor_list = [ + flow.tensor([[1, 2], [3, 4]], device="cpu") + flow.env.get_rank() + i + for i in range(4) ] - flow.comm.reduce_scatter(output, tensor_list) - test_case.assertTrue( - np.allclose(output.numpy(), tensor_list[0].numpy() * 4 + 6) + flow.comm.reduce_scatter(of_output, of_tensor_list) + + if not torch.distributed.is_initialized(): + dist.init_process_group("nccl") + torch_output = torch.tensor([[0, 0], [0, 0]]).to( + "cuda:{}".format(dist.get_rank()) ) + torch_tensor_list = [ + torch.tensor(x.numpy()).to("cuda:{}".format(dist.get_rank())) + for x in of_tensor_list + ] + dist.reduce_scatter(torch_output, torch_tensor_list) + + test_case.assertTrue(np.allclose(of_output.numpy(), torch_output.cpu().numpy())) + dist.destroy_process_group() +@unittest.skip("comm test case has bug") @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") @flow.unittest.skip_unless_1n2d() class TestDocs(flow.unittest.TestCase): diff --git a/python/oneflow/test/modules/test_concat.py b/python/oneflow/test/modules/test_concat.py index 769344254b4..34ed9b24cc2 100644 --- a/python/oneflow/test/modules/test_concat.py +++ b/python/oneflow/test/modules/test_concat.py @@ -140,16 +140,16 @@ def test_cat_with_random_data(test_case): x = random_pytorch_tensor(ndim=2, dim0=random(), dim1=random()).to(device) return torch.cat((x, x, x), random(0, 2).to(int)) - @autotest(n=10, auto_backward=False, check_graph=False) - def 
test_concat_with_input_0shape_data(test_case): + @autotest(n=10, auto_backward=False, check_graph=True) + def test_concat_with_input_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 3, 2, 4).to(device) y = random_pytorch_tensor(4, 2, 3, random(0, 3), 4).to(device) z = torch.cat((x, y), dim=2) return z - @autotest(n=10, auto_backward=False, check_graph=False) - def test_concat_with_output_0shape_data(test_case): + @autotest(n=10, auto_backward=False, check_graph=True) + def test_concat_with_output_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 0, 2, 4).to(device) y = random_pytorch_tensor(4, 2, 0, 2, 4).to(device) @@ -157,6 +157,12 @@ def test_concat_with_output_0shape_data(test_case): z = torch.cat((x, y), dim=dim) return z + @autotest(n=10, check_graph=False) + def test_cat_only_one_tensor(test_case): + device = random_device() + x = random_pytorch_tensor(4, 2, 3, random(0, 3)).to(device) + return torch.cat((x,), 0) + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_constant.py b/python/oneflow/test/modules/test_constant.py index f306d25698d..0287949d5b5 100644 --- a/python/oneflow/test/modules/test_constant.py +++ b/python/oneflow/test/modules/test_constant.py @@ -68,21 +68,21 @@ def test_flow_ones_list_with_random_data(test_case): ).to(device) return y1, y2, y3, y4 - @autotest(auto_backward=False, check_graph=False) + @autotest(auto_backward=False, check_graph=True) def test_flow_zeros_like_list_with_random_data(test_case): device = random_device() x = random_pytorch_tensor().to(device) y = torch.zeros_like(x) return y - @autotest(auto_backward=False, check_graph=False) + @autotest(auto_backward=False, check_graph=True) def test_flow_ones_like_list_with_random_data(test_case): device = random_device() x = random_pytorch_tensor().to(device) y = torch.ones_like(x) return y - @autotest(auto_backward=True, check_graph=False) + @autotest(auto_backward=True, check_graph=True) def test_flow_new_ones_list_with_random_data(test_case): device = random_device() x = random_pytorch_tensor().to(device) diff --git a/python/oneflow/test/modules/test_constantpad.py b/python/oneflow/test/modules/test_constantpad.py index 6a5419ea7f5..4362b9aa817 100644 --- a/python/oneflow/test/modules/test_constantpad.py +++ b/python/oneflow/test/modules/test_constantpad.py @@ -83,7 +83,7 @@ def test_constantpad3d_with_random_data(test_case): @flow.unittest.skip_unless_1n1d() class TestFunctionalConstantPad2d(flow.unittest.TestCase): - @autotest(n=20, rtol=0.001, atol=0.001, check_graph=False) + @autotest(n=20, rtol=0.001, atol=0.001, check_graph=True) def test_functional_constantpad2d(test_case): device = random_device() padding = random(-1, 6).to(_size_4_t) diff --git a/python/oneflow/test/modules/test_convtranspose.py b/python/oneflow/test/modules/test_convtranspose.py index 79a49e24f44..fa1a2e45a5f 100644 --- a/python/oneflow/test/modules/test_convtranspose.py +++ b/python/oneflow/test/modules/test_convtranspose.py @@ -298,6 +298,30 @@ def test_ConvTranspose1d_(test_case): y = m(x) return y + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + @autotest(n=30) + def test_deconv1d_group_with_random_data(test_case): + channels = 720 # lcm(1, 2, 3, 4, 5, 6) + m = torch.nn.ConvTranspose1d( + in_channels=channels, + out_channels=channels, + kernel_size=random(1, 4), + stride=random() | nothing(), + padding=random(1, 3).to(int) | nothing(), + dilation=random(1, 5) | nothing(), + 
groups=random(1, 7), + padding_mode=constant("zeros") | nothing(), + ) + m.train(random()) + + device = random_device() + m.to(device) + m.pytorch.to("cuda") + x = random_pytorch_tensor(ndim=3, dim1=channels).to(device) + x.pytorch = x.pytorch.to("cuda") + y = m(x) + return y + @autotest() def test_ConvTranspose3d_(test_case): channels = random(1, 2) @@ -318,6 +342,30 @@ def test_ConvTranspose3d_(test_case): y = m(x) return y + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + @autotest(n=30) + def test_deconv3d_group_with_random_data(test_case): + channels = 720 # lcm(1, 2, 3, 4, 5, 6) + m = torch.nn.ConvTranspose3d( + in_channels=channels, + out_channels=channels, + kernel_size=random(1, 4), + stride=random() | nothing(), + padding=random(1, 3).to(int) | nothing(), + dilation=random(1, 5) | nothing(), + groups=random(1, 7), + padding_mode=constant("zeros") | nothing(), + ) + m.train(random()) + + device = random_device() + m.to(device) + m.pytorch.to("cuda") + x = random_pytorch_tensor(ndim=5, dim1=channels).to(device) + x.pytorch = x.pytorch.to("cuda") + y = m(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_cumsum.py b/python/oneflow/test/modules/test_cumsum.py new file mode 100644 index 00000000000..86a5a5065c0 --- /dev/null +++ b/python/oneflow/test/modules/test_cumsum.py @@ -0,0 +1,37 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import unittest +from collections import OrderedDict + +import oneflow as flow +import oneflow.unittest + +from oneflow.test_utils.automated_test_util import * + + +@flow.unittest.skip_unless_1n1d() +class TestCumsum(flow.unittest.TestCase): + @autotest(n=30, check_graph=True) + def test_cumsum(test_case): + device = random_device() + x = random_pytorch_tensor().to(device) + dim = random(0, x.ndim.pytorch).to(int) + z = torch.cumsum(x, dim) + return z + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/modules/test_deconv2d.py b/python/oneflow/test/modules/test_deconv2d.py index 7b0917b8f39..57458efe459 100644 --- a/python/oneflow/test/modules/test_deconv2d.py +++ b/python/oneflow/test/modules/test_deconv2d.py @@ -336,6 +336,7 @@ def _test_deconv_group_bias_false(test_case, device): ] ] ) + test_case.assertTrue(np.allclose(output.numpy(), np_out, 1e-06, 1e-06)) output = output.sum() output.backward() @@ -890,6 +891,30 @@ def test_deconv2d_with_random_data(test_case): y = m(x) return y + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + @autotest(n=30) + def test_deconv2d_group_with_random_data(test_case): + channels = 720 # lcm(1, 2, 3, 4, 5, 6) + m = torch.nn.ConvTranspose2d( + in_channels=channels, + out_channels=channels, + kernel_size=random(1, 4), + stride=random() | nothing(), + padding=random(1, 3).to(int) | nothing(), + dilation=random(1, 5) | nothing(), + groups=random(1, 7), + padding_mode=constant("zeros") | nothing(), + ) + m.train(random()) + + device = random_device() + m.to(device) + m.pytorch.to("cuda") + x = random_pytorch_tensor(ndim=4, dim1=channels).to(device) + x.pytorch = x.pytorch.to("cuda") + y = m(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_diag.py b/python/oneflow/test/modules/test_diag.py index a2693e93b66..1acbf66b8d3 100644 --- a/python/oneflow/test/modules/test_diag.py +++ b/python/oneflow/test/modules/test_diag.py @@ -28,13 +28,13 @@ @flow.unittest.skip_unless_1n1d() class Test_Diag_module(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=True) def test_diag_one_dim(test_case): device = random_device() x = random_pytorch_tensor(ndim=1, dim0=random()).to(device) return torch.diag(x) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_diag_other_dim(test_case): device = random_device() x = random_pytorch_tensor(ndim=2, dim0=random(), dim1=random()).to(device) diff --git a/python/oneflow/test/modules/test_diagonal.py b/python/oneflow/test/modules/test_diagonal.py new file mode 100644 index 00000000000..f988f3ca3c4 --- /dev/null +++ b/python/oneflow/test/modules/test_diagonal.py @@ -0,0 +1,44 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import unittest +import numpy as np +import oneflow as flow + +import oneflow.unittest +from oneflow.test_utils.automated_test_util import * + + +class TestDiagonal(flow.unittest.TestCase): + @autotest(n=40, check_graph=False) + def test_flow_diagonal_with_random_data(test_case): + device = random_device() + offset = random(-5, 5).to(int) + dim1 = random(-4, 4).to(int) + dim2 = random(-4, 4).to(int) + + x = random_pytorch_tensor( + ndim=4, + dim1=random(4, 6), + dim2=random(4, 6), + dim3=random(4, 6), + dim4=random(4, 6), + ).to(device) + z = torch.diagonal(x, offset, dim1, dim2) + return z + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/modules/test_div.py b/python/oneflow/test/modules/test_div.py index de6083aa200..7ab51192ad2 100644 --- a/python/oneflow/test/modules/test_div.py +++ b/python/oneflow/test/modules/test_div.py @@ -125,8 +125,8 @@ def test_div_against_pytorch(test_case): device=arg[1], ) - @autotest(auto_backward=False, check_graph=False) - def test_0shape_div(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_0_size_div(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) y = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) diff --git a/python/oneflow/test/modules/test_dot.py b/python/oneflow/test/modules/test_dot.py index e34ac9bb35d..543c7ec4987 100644 --- a/python/oneflow/test/modules/test_dot.py +++ b/python/oneflow/test/modules/test_dot.py @@ -22,7 +22,7 @@ @flow.unittest.skip_unless_1n1d() class TestDot(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=True) def test_dot(test_case): device = random_device() k = random(1000, 10000) diff --git a/python/oneflow/test/modules/test_eager_boxing.py b/python/oneflow/test/modules/test_eager_boxing.py index 5bc9337fc11..2938d8063df 100644 --- a/python/oneflow/test/modules/test_eager_boxing.py +++ b/python/oneflow/test/modules/test_eager_boxing.py @@ -3153,5 +3153,66 @@ def test_eager_naive_boxing_s_to_s(test_case): _test_eager_naive_boxing_s_to_s(test_case, *arg) +@flow.unittest.skip_unless_1n2d() +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +class TestEagerConsistentCastWithSamePlacementAndSBP(flow.unittest.TestCase): + def test_eager_consistent_cast_with_same_placement_and_sbp(test_case): + x = np.ones((4, 8), dtype=np.int32) + placement = flow.placement("cuda", {0: range(2)}) + y = flow.tensor( + x, + dtype=flow.float32, + placement=placement, + sbp=[flow.sbp.split(0)], + requires_grad=False, + ) + z = y.to_consistent(placement=placement, sbp=[flow.sbp.split(0)]) + test_case.assertEqual(y.consistent_id(), z.consistent_id()) + + +@flow.unittest.skip_unless_1n4d() +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +class TestEagerConsistentCast1DTo2DSBP(flow.unittest.TestCase): + def test_eager_consistent_cast_1d_to_2d_sbp(test_case): + x = np.ones((4, 8), dtype=np.int32) + placement1 = flow.placement("cuda", {0: range(4)}) + placement2 = flow.placement("cuda", {0: range(4)}, (2, 2)) + y = flow.tensor( + x, + dtype=flow.float32, + placement=placement1, + sbp=[flow.sbp.split(0)], + requires_grad=False, + ) + z = y.to_consistent( + placement=placement2, sbp=[flow.sbp.broadcast, flow.sbp.split(0)] + ) + test_case.assertEqual(z.placement, placement2) + test_case.assertTrue( + np.array_equal(z.to_local().numpy(), np.ones((2, 8), dtype=np.int32),) + ) + + +@flow.unittest.skip_unless_1n4d() +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only 
test cpu cases") +class TestEagerConsistentCast2DTo1DSBP(flow.unittest.TestCase): + def test_eager_consistent_cast_2d_to_1d_sbp(test_case): + x = np.ones((4, 8), dtype=np.int32) + placement1 = flow.placement("cuda", {0: range(4)}) + placement2 = flow.placement("cuda", {0: range(4)}, (2, 2)) + y = flow.tensor( + x, + dtype=flow.float32, + placement=placement2, + sbp=[flow.sbp.broadcast, flow.sbp.split(0)], + requires_grad=False, + ) + z = y.to_consistent(placement=placement1, sbp=[flow.sbp.split(0)]) + test_case.assertEqual(z.placement, placement1) + test_case.assertTrue( + np.array_equal(z.to_local().numpy(), np.ones((1, 8), dtype=np.int32),) + ) + + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_eq.py b/python/oneflow/test/modules/test_eq.py index 0c732a8d166..3caa5e014f8 100644 --- a/python/oneflow/test/modules/test_eq.py +++ b/python/oneflow/test/modules/test_eq.py @@ -28,8 +28,8 @@ @flow.unittest.skip_unless_1n1d() class TestEq(flow.unittest.TestCase): - @autotest(auto_backward=False, check_graph=False) - def test_eq_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_eq_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(3, 2, 0, 3).to(device) y = random_pytorch_tensor(3, 2, 0, 3).to(device) diff --git a/python/oneflow/test/modules/test_erf.py b/python/oneflow/test/modules/test_erf.py index 970b0977c2e..95e499944b9 100644 --- a/python/oneflow/test/modules/test_erf.py +++ b/python/oneflow/test/modules/test_erf.py @@ -29,7 +29,7 @@ @flow.unittest.skip_unless_1n1d() class TestErfModule(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_erf_with_random_data(test_case): device = random_device() x = random_pytorch_tensor().to(device) diff --git a/python/oneflow/test/modules/test_erfc.py b/python/oneflow/test/modules/test_erfc.py index b6adaceb0c9..04fce7eea6a 100644 --- a/python/oneflow/test/modules/test_erfc.py +++ b/python/oneflow/test/modules/test_erfc.py @@ -29,7 +29,7 @@ @flow.unittest.skip_unless_1n1d() class TestErfcModule(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_erfc_with_random_data(test_case): device = random_device() x = random_pytorch_tensor().to(device) diff --git a/python/oneflow/test/modules/test_expand.py b/python/oneflow/test/modules/test_expand.py index 4c65299c755..75c37386a8d 100644 --- a/python/oneflow/test/modules/test_expand.py +++ b/python/oneflow/test/modules/test_expand.py @@ -181,7 +181,7 @@ def random_expand(x, ndim, expand_size): @flow.unittest.skip_unless_1n1d() class TestExpand(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_tensor_expand_with_random_data(test_case): random_expand_size = random(1, 6).to(int).value() x = random_pytorch_tensor(ndim=5, dim0=1, dim1=1, dim2=1, dim3=1, dim4=1) diff --git a/python/oneflow/test/modules/test_expm1.py b/python/oneflow/test/modules/test_expm1.py index 8d4ba7ac61f..58d32dbf624 100644 --- a/python/oneflow/test/modules/test_expm1.py +++ b/python/oneflow/test/modules/test_expm1.py @@ -58,8 +58,8 @@ def test_expm1_flow_with_random_data(test_case): y = torch.expm1(input) return y - @autotest(auto_backward=False, check_graph=False) - def test_expm1_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_expm1_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) y = 
torch.expm1(x) diff --git a/python/oneflow/test/modules/test_flatten.py b/python/oneflow/test/modules/test_flatten.py index 92ce73f974d..91145475b9f 100644 --- a/python/oneflow/test/modules/test_flatten.py +++ b/python/oneflow/test/modules/test_flatten.py @@ -67,7 +67,7 @@ def test_cast(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flatten_module_with_random_data(test_case): m = torch.nn.Flatten( start_dim=random(1, 6) | nothing(), end_dim=random(1, 6) | nothing() diff --git a/python/oneflow/test/modules/test_flip.py b/python/oneflow/test/modules/test_flip.py index 53abe8cb2fa..0e27d41aab4 100644 --- a/python/oneflow/test/modules/test_flip.py +++ b/python/oneflow/test/modules/test_flip.py @@ -27,7 +27,7 @@ class TestFlip(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=False, check_allclose=False) def test_flow_flip_list_with_random_data(test_case): device = random_device() x = random_pytorch_tensor( diff --git a/python/oneflow/test/modules/test_fmod.py b/python/oneflow/test/modules/test_fmod.py index 6c0686e3709..56533aa4ec9 100644 --- a/python/oneflow/test/modules/test_fmod.py +++ b/python/oneflow/test/modules/test_fmod.py @@ -55,8 +55,8 @@ def test_flow_fmod_scalar_with_random_data(test_case): other = 3 return torch.fmod(input, other) - @autotest(auto_backward=False, check_graph=False) - def test_fmod_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_fmod_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) y = torch.fmod(x, 2) diff --git a/python/oneflow/test/modules/test_from_numpy.py b/python/oneflow/test/modules/test_from_numpy.py new file mode 100644 index 00000000000..8d469c02ecf --- /dev/null +++ b/python/oneflow/test/modules/test_from_numpy.py @@ -0,0 +1,62 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import random +import unittest + +import numpy as np +import oneflow as flow +import oneflow.unittest + + +@flow.unittest.skip_unless_1n1d() +class TestFromNumpy(flow.unittest.TestCase): + def test_same_data(test_case): + np_arr = np.random.randn(3, 4, 5) + tensor = flow.from_numpy(np_arr) + test_case.assertTrue(np.array_equal(np_arr, tensor.numpy())) + test_case.assertEqual(tensor.size(), (3, 4, 5)) + test_case.assertEqual(tensor.stride(), (20, 5, 1)) + test_case.assertEqual(tensor.storage_offset(), 0) + + np_arr[1:2, 2:3, 3:4] = random.random() + test_case.assertTrue(np.array_equal(np_arr, tensor.numpy())) + + def test_use_ops(test_case): + np_arr = np.random.randn(3, 4, 5) + tensor = flow.from_numpy(np_arr) + res = tensor ** 2 + test_case.assertTrue(np.allclose(np_arr ** 2, res.numpy())) + + def test_more_dtype(test_case): + for dtype in [ + np.float64, + np.float32, + np.float16, + np.int64, + np.int32, + np.int8, + np.uint8, + ]: + np_arr = np.ones((2, 3), dtype=dtype) + tensor = flow.from_numpy(np_arr) + # TODO(wyg): oneflow.float16 do not support to copy from tensor to numpy + if tensor.dtype not in [flow.float16]: + test_case.assertTrue(np.array_equal(np_arr, tensor.numpy())) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/modules/test_functional_docstr.py b/python/oneflow/test/modules/test_functional_docstr.py index 8c4e2d95f2f..8c302326154 100644 --- a/python/oneflow/test/modules/test_functional_docstr.py +++ b/python/oneflow/test/modules/test_functional_docstr.py @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. """ -import os import inspect +import os import unittest from collections import OrderedDict -from test_util import GenArgList - import oneflow as flow import oneflow.unittest +from test_util import GenArgList + def _run_functional_doctest( test_case, @@ -59,4 +59,5 @@ def test_functional_docstr(test_case): if __name__ == "__main__": + flow.set_printoptions(linewidth=80) unittest.main() diff --git a/python/oneflow/test/modules/test_fused_bias_add_dropout.py b/python/oneflow/test/modules/test_fused_bias_add_dropout.py index 9feb3624b57..0d341f9784e 100644 --- a/python/oneflow/test/modules/test_fused_bias_add_dropout.py +++ b/python/oneflow/test/modules/test_fused_bias_add_dropout.py @@ -25,9 +25,9 @@ import oneflow.unittest -def _test_fused_bias_add_dropout(test_case, channel, axis, drop_prob): - x = np.random.randn(4, channel, 2, 4) - bias = np.random.randn(channel) +def _test_fused_bias_add_dropout(test_case, shape, axis, drop_prob): + x = np.random.randn(*shape) + bias = np.random.randn(shape[axis]) # fused version only support in GPU fused_x_tensor = flow.Tensor(x).to("cuda") fused_x_tensor.requires_grad = True @@ -77,8 +77,8 @@ class TestFusedBiasAddDropout(flow.unittest.TestCase): def test_fuse_bias_add_dropout(test_case): arg_dict = OrderedDict() arg_dict["test_fun"] = [_test_fused_bias_add_dropout] - arg_dict["channels"] = [4, 6, 8] - arg_dict["axis"] = [1] + arg_dict["shape"] = [(16, 64, 72), (32, 16, 48)] + arg_dict["axis"] = [0, 1, 2, -1, -2, -3] arg_dict["drop_prob"] = [0.0, 1.0] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) diff --git a/python/oneflow/test/modules/test_gather.py b/python/oneflow/test/modules/test_gather.py index 9729e834d8a..13cc0751172 100644 --- a/python/oneflow/test/modules/test_gather.py +++ b/python/oneflow/test/modules/test_gather.py @@ -117,7 +117,7 @@ def test_gather(test_case): for arg in 
GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_gather_with_random_data(test_case): device = random_device() input = random_pytorch_tensor(ndim=4, dim1=3, dim2=4, dim3=5).to(device) diff --git a/python/oneflow/test/modules/test_glu.py b/python/oneflow/test/modules/test_glu.py index 8e6f510f84c..b2bcabad73b 100644 --- a/python/oneflow/test/modules/test_glu.py +++ b/python/oneflow/test/modules/test_glu.py @@ -33,8 +33,8 @@ def test_glu_module_with_random_data(test_case): y = m(x, dim) return y - @autotest(n=5) - def test_GLU_module_with_random_data(test_case): + @autotest(n=5, check_graph=True) + def test_glu_module_with_random_data(test_case): device = random_device() m = torch.nn.GLU() m.train(random()) diff --git a/python/oneflow/test/modules/test_greater.py b/python/oneflow/test/modules/test_greater.py index 7d9f81dc63d..e32f0973adf 100644 --- a/python/oneflow/test/modules/test_greater.py +++ b/python/oneflow/test/modules/test_greater.py @@ -119,8 +119,8 @@ def test_tensor_greater_with_random_data(test_case): y2 = x1 > x2 return (y1, y2) - @autotest(auto_backward=False, check_graph=False) - def test_greater_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_greater_with_0_size_data(test_case): device = random_device() x1 = random_pytorch_tensor(4, 2, 3, 0, 5).to(device) x2 = random_pytorch_tensor(4, 2, 3, 0, 5).to(device) diff --git a/python/oneflow/test/modules/test_groupnorm.py b/python/oneflow/test/modules/test_groupnorm.py index 680b59729db..91211c53243 100644 --- a/python/oneflow/test/modules/test_groupnorm.py +++ b/python/oneflow/test/modules/test_groupnorm.py @@ -341,7 +341,7 @@ def test_groupnorm(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(rtol=1e-03, atol=1e-03, check_graph=False) + @autotest(rtol=1e-03, atol=1e-03, check_graph=True) def test_group_norm_with_random_data(test_case): channels = random(5, 20) m = torch.nn.GroupNorm( diff --git a/python/oneflow/test/modules/test_layernorm.py b/python/oneflow/test/modules/test_layernorm.py index 6022a07354f..07afe3e48a8 100644 --- a/python/oneflow/test/modules/test_layernorm.py +++ b/python/oneflow/test/modules/test_layernorm.py @@ -203,6 +203,24 @@ def get_random_norm_shape(): y = m(x) return y + @autotest(n=20, auto_backward=True, rtol=1e-3, atol=1e-3) + def test_layernorm_without_affine(test_case): + device = random_device() + channel = random(1, 200).to(int) + height = random(1, 2).to(int) + width = random(8192, 32768).to(int) + + def get_random_norm_shape(): + begin_axis = random(1, 3).to(int).value() + return tuple((channel.value(), height.value(), width.value())[begin_axis:]) + + m = torch.nn.LayerNorm(normalized_shape=get_random_norm_shape()).to(device) + x = random_pytorch_tensor(ndim=4, dim1=channel, dim2=height, dim3=width).to( + device + ) + y = m(x) + return y + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_linear.py b/python/oneflow/test/modules/test_linear.py index 1da38776c2d..6a1286884e7 100644 --- a/python/oneflow/test/modules/test_linear.py +++ b/python/oneflow/test/modules/test_linear.py @@ -193,7 +193,7 @@ def test_linear_with_random_data(test_case): y = m(x) return y - @autotest(check_graph=False) + @autotest(check_graph=True) def test_nn_functional_linear_with_random_data(test_case): input_size = random() device = random_device() diff --git a/python/oneflow/test/modules/test_linspace.py 
b/python/oneflow/test/modules/test_linspace.py new file mode 100644 index 00000000000..3a079d600f0 --- /dev/null +++ b/python/oneflow/test/modules/test_linspace.py @@ -0,0 +1,59 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest +from collections import OrderedDict + +from test_util import GenArgList + +import oneflow as flow +import oneflow.unittest + +from oneflow.test_utils.automated_test_util import * + + +@flow.unittest.skip_unless_1n1d() +class TestLinspace(flow.unittest.TestCase): + @autotest(n=30, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True) + def test_linspace_int_with_random_data(test_case): + start = random().to(int) + end = start + random().to(int) + steps = random(0, end - start).to(int) + x = torch.linspace(start=start, end=end, steps=steps) + device = random_device() + x.to(device) + return x + + @autotest(n=30, auto_backward=False, rtol=1e-5, atol=1e-5, check_graph=True) + def test_linspace_float_with_random_data(test_case): + start = random() + end = start + random() + steps = random(0, end - start).to(int) + x = torch.linspace(start=start, end=end, steps=steps) + device = random_device() + x.to(device) + return x + + def test_consistent_naive(test_case): + placement = flow.placement("cpu", {0: [0]}) + sbp = (flow.sbp.broadcast,) + x = flow.linspace(start=0, end=10, steps=2, placement=placement, sbp=sbp) + test_case.assertEqual(x.sbp, sbp) + test_case.assertEqual(x.placement, placement) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/modules/test_log1p.py b/python/oneflow/test/modules/test_log1p.py index a6ce91a26fd..275b00a6890 100644 --- a/python/oneflow/test/modules/test_log1p.py +++ b/python/oneflow/test/modules/test_log1p.py @@ -27,7 +27,7 @@ @flow.unittest.skip_unless_1n1d() class TestLog1pModule(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=True) def test_log1p_with_random_data(test_case): device = random_device() x = random_pytorch_tensor().to(device) diff --git a/python/oneflow/test/modules/test_logical_and.py b/python/oneflow/test/modules/test_logical_and.py index 59d6ee1a165..f44cd42b8bb 100644 --- a/python/oneflow/test/modules/test_logical_and.py +++ b/python/oneflow/test/modules/test_logical_and.py @@ -77,7 +77,7 @@ def test_scalar_logical_and(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(n=10, auto_backward=False, check_graph=False) + @autotest(n=10, auto_backward=False, check_graph=True) def test_logical_and_with_random_data(test_case): device = random_device() shape = random_tensor().value().shape diff --git a/python/oneflow/test/modules/test_logical_not.py b/python/oneflow/test/modules/test_logical_not.py index abbeff01470..8b07e6aa1c4 100644 --- a/python/oneflow/test/modules/test_logical_not.py +++ b/python/oneflow/test/modules/test_logical_not.py @@ -51,7 +51,7 @@ def test_logical_not(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - 
@autotest(n=10, auto_backward=False, check_graph=False) + @autotest(n=10, auto_backward=False, check_graph=True) def test_logical_not_with_random_data(test_case): device = random_device() shape = random_tensor().value().shape diff --git a/python/oneflow/test/modules/test_logical_or.py b/python/oneflow/test/modules/test_logical_or.py index 9ef8b915d3b..2c122d3ea3c 100644 --- a/python/oneflow/test/modules/test_logical_or.py +++ b/python/oneflow/test/modules/test_logical_or.py @@ -76,7 +76,7 @@ def test_scalar_logical_or(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(n=10, auto_backward=False, check_graph=False) + @autotest(n=10, auto_backward=False, check_graph=True) def test_logical_or_with_random_data(test_case): device = random_device() shape = random_tensor().value().shape diff --git a/python/oneflow/test/modules/test_logical_xor.py b/python/oneflow/test/modules/test_logical_xor.py index 3eb20f7fa33..7efe454e445 100644 --- a/python/oneflow/test/modules/test_logical_xor.py +++ b/python/oneflow/test/modules/test_logical_xor.py @@ -97,7 +97,7 @@ def test_scalar_logical_xor(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(n=10, auto_backward=False, check_graph=False) + @autotest(n=10, auto_backward=False, check_graph=True) def test_logical_xor_with_random_data(test_case): device = random_device() shape = random_tensor().value().shape diff --git a/python/oneflow/test/modules/test_lr_scheduler.py b/python/oneflow/test/modules/test_lr_scheduler.py index 01e29d69938..1a8e83e3611 100644 --- a/python/oneflow/test/modules/test_lr_scheduler.py +++ b/python/oneflow/test/modules/test_lr_scheduler.py @@ -15,15 +15,17 @@ """ import math +import random +import tempfile import unittest from collections import OrderedDict -from test_util import GenArgDict import oneflow as flow import oneflow.unittest -from oneflow.nn.parameter import Parameter import torch -import random +from oneflow.nn.parameter import Parameter + +from test_util import GenArgDict def compare_with_troch_reduce_lr( @@ -229,6 +231,44 @@ def test_reduce_lr_on_plateau(test_case): for arg in GenArgDict(arg_dict): compare_with_troch_reduce_lr(test_case, **arg) + def test_warmup_scheduler_save_and_load(test_case): + param = flow.nn.Parameter(flow.ones(3, 4)) + + optimizer = flow.optim.SGD([param]) + cosine_scheduler = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer, 100) + lr_scheduler = flow.optim.lr_scheduler.WarmUpLR( + cosine_scheduler, warmup_factor=0.1, warmup_iters=5, warmup_method="linear", + ) + for _ in range(random.randint(1, 10)): + lr_scheduler.step() + # save + with tempfile.TemporaryDirectory() as save_dir: + flow.save(lr_scheduler.state_dict(), save_dir) + state_dict = flow.load(save_dir) + + # load + param2 = flow.nn.Parameter(flow.ones(3, 4)) + optimizer2 = flow.optim.SGD([param]) + cosine_scheduler2 = flow.optim.lr_scheduler.CosineAnnealingLR(optimizer, 50) + lr_scheduler2 = flow.optim.lr_scheduler.WarmUpLR( + cosine_scheduler2, + warmup_factor=0.5, + warmup_iters=10, + warmup_method="linear", + ) + lr_scheduler2.load_state_dict(state_dict) + + # compare warm up scheduler + for attr in ["warmup_iters", "warmup_factor", "warmup_method", "last_step"]: + test_case.assertEqual( + getattr(lr_scheduler, attr), getattr(lr_scheduler2, attr) + ) + # compare cosine_annealing_lr + for attr in ["T_max", "eta_min", "last_step"]: + test_case.assertEqual( + getattr(cosine_scheduler, attr), getattr(cosine_scheduler2, attr) + ) + if __name__ == "__main__": 
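
The test_warmup_scheduler_save_and_load case above exercises the state_dict round trip for lr schedulers. A condensed usage sketch, assuming the same WarmUpLR / CosineAnnealingLR / flow.save / flow.load APIs the test calls (the variable names are illustrative):

import tempfile
import oneflow as flow

param = flow.nn.Parameter(flow.ones(3, 4))
opt = flow.optim.SGD([param])
cosine = flow.optim.lr_scheduler.CosineAnnealingLR(opt, 100)
warmup = flow.optim.lr_scheduler.WarmUpLR(
    cosine, warmup_factor=0.1, warmup_iters=5, warmup_method="linear"
)
for _ in range(3):
    warmup.step()

with tempfile.TemporaryDirectory() as save_dir:
    flow.save(warmup.state_dict(), save_dir)   # persist scheduler progress
    state = flow.load(save_dir)

# a freshly constructed scheduler (even with different warmup settings)
# picks up the saved counters on load
warmup2 = flow.optim.lr_scheduler.WarmUpLR(
    flow.optim.lr_scheduler.CosineAnnealingLR(flow.optim.SGD([param]), 50),
    warmup_factor=0.5, warmup_iters=10, warmup_method="linear",
)
warmup2.load_state_dict(state)
assert warmup2.last_step == warmup.last_step
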
unittest.main() diff --git a/python/oneflow/test/modules/test_masked_fill.py b/python/oneflow/test/modules/test_masked_fill.py index d685ef8e559..26c76d2a561 100644 --- a/python/oneflow/test/modules/test_masked_fill.py +++ b/python/oneflow/test/modules/test_masked_fill.py @@ -26,7 +26,7 @@ @flow.unittest.skip_unless_1n1d() class TestMaskedFill(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_masked_fill_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) @@ -36,7 +36,7 @@ def test_flow_masked_fill_with_random_data(test_case): value = random().to(float) return input.masked_fill(mask > 0, value) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_masked_fill_broadcast_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) @@ -46,7 +46,7 @@ def test_flow_masked_fill_broadcast_with_random_data(test_case): value = random().to(float) return input.masked_fill(mask > 0, value) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_masked_fill_int_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) diff --git a/python/oneflow/test/modules/test_math_ops.py b/python/oneflow/test/modules/test_math_ops.py index 8335f30192e..0cb83cefa68 100644 --- a/python/oneflow/test/modules/test_math_ops.py +++ b/python/oneflow/test/modules/test_math_ops.py @@ -156,12 +156,18 @@ def test_square_tensor_with_random_data(test_case): @flow.unittest.skip_unless_1n1d() class TestPow(flow.unittest.TestCase): @autotest(check_graph=False) - def test_pow_scalar_with_random_data(test_case): + def test_pow_float_scalar_with_random_data(test_case): device = random_device() x = random_pytorch_tensor().to(device) y = random().to(float) return torch.pow(x, y) + def test_pow_int_scalar_with_random_data(test_case): + device = random_device() + x = random_pytorch_tensor().to(device) + y = random().to(int) + return torch.pow(x, y) + @autotest(check_graph=False) def test_pow_elementwise_with_random_data(test_case): device = random_device() diff --git a/python/oneflow/test/modules/test_matmul.py b/python/oneflow/test/modules/test_matmul.py index 38e48ebd111..917bcd2dbea 100644 --- a/python/oneflow/test/modules/test_matmul.py +++ b/python/oneflow/test/modules/test_matmul.py @@ -24,7 +24,7 @@ @flow.unittest.skip_unless_1n1d() class TestModule(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_matmul_with_random_data(test_case): device = random_device() k = random(1, 6) @@ -33,7 +33,7 @@ def test_flow_matmul_with_random_data(test_case): z = torch.matmul(x, y) return z - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_tensor_matmul_with_random_data(test_case): device = random_device() k = random(1, 6) @@ -41,7 +41,7 @@ def test_flow_tensor_matmul_with_random_data(test_case): y = random_pytorch_tensor(ndim=2, dim0=k).to(device) return x.matmul(y) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_tensor_broadcast_matmul_with_random_data(test_case): device = random_device() k = random(1, 6) diff --git a/python/oneflow/test/modules/test_maxpool.py b/python/oneflow/test/modules/test_maxpool.py index 59f2dcecdd6..b4d148f60b5 100644 --- a/python/oneflow/test/modules/test_maxpool.py +++ b/python/oneflow/test/modules/test_maxpool.py @@ -28,7 +28,7 @@ def unpack_indices(dual_object): @flow.unittest.skip_unless_1n1d() class TestMaxPooling(flow.unittest.TestCase): - @autotest(auto_backward=False) + @autotest(auto_backward=False, 
check_graph=False) def test_maxpool1d_with_random_data(test_case): return_indices = random().to(bool).value() m = torch.nn.MaxPool1d( @@ -49,7 +49,7 @@ def test_maxpool1d_with_random_data(test_case): else: return y, y.sum().backward() - @autotest(auto_backward=False) + @autotest(auto_backward=False, check_graph=False) def test_maxpool2d_with_random_data(test_case): return_indices = random().to(bool).value() m = torch.nn.MaxPool2d( @@ -73,7 +73,7 @@ def test_maxpool2d_with_random_data(test_case): else: return y, y.sum().backward() - @autotest(auto_backward=False) + @autotest(auto_backward=False, check_graph=False) def test_maxpool3d_with_random_data(test_case): return_indices = random().to(bool).value() m = torch.nn.MaxPool3d( diff --git a/python/oneflow/test/modules/test_mean.py b/python/oneflow/test/modules/test_mean.py index 0cb19203e5d..fab43bc0d82 100644 --- a/python/oneflow/test/modules/test_mean.py +++ b/python/oneflow/test/modules/test_mean.py @@ -79,7 +79,7 @@ def test_mean(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_mean_with_random_data(test_case): device = random_device() dim = random(1, 4).to(int) diff --git a/python/oneflow/test/modules/test_meshgrid.py b/python/oneflow/test/modules/test_meshgrid.py index ebb69b91eeb..49a4c2ab55a 100644 --- a/python/oneflow/test/modules/test_meshgrid.py +++ b/python/oneflow/test/modules/test_meshgrid.py @@ -26,29 +26,27 @@ from oneflow.test_utils.automated_test_util import * -def _test_meshgrid_forawd(test_case, device): +def _test_meshgrid_forawd(test_case, device, indexing): input1 = flow.tensor( np.array([1, 2, 3]), dtype=flow.float32, device=flow.device(device) ) input2 = flow.tensor( np.array([4, 5, 6]), dtype=flow.float32, device=flow.device(device) ) - (np_x, np_y) = np.meshgrid(input1.numpy(), input2.numpy(), indexing="ij") - (of_x, of_y) = flow.meshgrid(input1, input2) + (np_x, np_y) = np.meshgrid(input1.numpy(), input2.numpy(), indexing=indexing) + (of_x, of_y) = flow.meshgrid(input1, input2, indexing=indexing) test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 0.0001, 0.0001)) - test_case.assertTrue(np.allclose(of_y.numpy(), np_y, 0.0001, 0.0001)) -def _test_meshgrid_forawd_scalar(test_case, device): +def _test_meshgrid_forawd_scalar(test_case, device, indexing): input1 = flow.tensor(np.array(1.0), dtype=flow.float32, device=flow.device(device)) input2 = flow.tensor(np.array(2.0), dtype=flow.float32, device=flow.device(device)) - (np_x, np_y) = np.meshgrid(input1.numpy(), input2.numpy(), indexing="ij") - (of_x, of_y) = flow.meshgrid(input1, input2) + (np_x, np_y) = np.meshgrid(input1.numpy(), input2.numpy(), indexing=indexing) + (of_x, of_y) = flow.meshgrid(input1, input2, indexing=indexing) test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 0.0001, 0.0001)) - test_case.assertTrue(np.allclose(of_y.numpy(), np_y, 0.0001, 0.0001)) -def _test_meshgrid_forawd_3tensor(test_case, device): +def _test_meshgrid_forawd_3tensor(test_case, device, indexing): input1 = flow.tensor( np.array([1, 2, 3]), dtype=flow.float32, device=flow.device(device) ) @@ -59,12 +57,10 @@ def _test_meshgrid_forawd_3tensor(test_case, device): np.array([7, 8, 9]), dtype=flow.float32, device=flow.device(device) ) (np_x, np_y, np_z) = np.meshgrid( - input1.numpy(), input2.numpy(), input3.numpy(), indexing="ij" + input1.numpy(), input2.numpy(), input3.numpy(), indexing=indexing ) - (of_x, of_y, of_z) = flow.meshgrid(input1, input2, input3) + (of_x, of_y, of_z) = 
flow.meshgrid(input1, input2, input3, indexing=indexing) test_case.assertTrue(np.allclose(of_x.numpy(), np_x, 0.0001, 0.0001)) - test_case.assertTrue(np.allclose(of_y.numpy(), np_y, 0.0001, 0.0001)) - test_case.assertTrue(np.allclose(of_z.numpy(), np_z, 0.0001, 0.0001)) @flow.unittest.skip_unless_1n1d() @@ -77,10 +73,12 @@ def test_meshgrid(test_case): _test_meshgrid_forawd_3tensor, ] arg_dict["device"] = ["cpu", "cuda"] + arg_dict["indexing"] = ["ij", "xy"] for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) @autotest(auto_backward=False, check_graph=False) + @unittest.skip("pytorch 1.9.0 exist not indexing") def test_meshgrid_with_random_data(test_case): device = random_device() x = random_pytorch_tensor(ndim=1, dim0=3, requires_grad=False).to(device) @@ -88,6 +86,23 @@ def test_meshgrid_with_random_data(test_case): res = torch.meshgrid(x, y) return res[0], res[1] + @autotest(auto_backward=True, check_graph=False) + @unittest.skip("pytorch 1.9.0 exist not indexing") + def test_meshgrid_with_random_data_xy(test_case): + device = random_device() + x = random_pytorch_tensor(ndim=1, dim0=random(1, 6)).to(device) + y = random_pytorch_tensor(ndim=1, dim0=random(1, 6)).to(device) + res = torch.meshgrid(x, y, indexing="xy") + return torch.cat((res[0], res[1]), 0) + + @autotest(auto_backward=True, check_graph=False) + @unittest.skip("pytorch 1.9.0 exist not indexing") + def test_meshgrid_with_random_data_size(test_case): + device = random_device() + x = random_pytorch_tensor(ndim=1, dim0=random(1, 6)).to(device) + res = torch.meshgrid(x, indexing="xy") + return res[0] + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_module_to.py b/python/oneflow/test/modules/test_module_to.py index 4f02fbbba03..505d116cbdc 100644 --- a/python/oneflow/test/modules/test_module_to.py +++ b/python/oneflow/test/modules/test_module_to.py @@ -90,6 +90,30 @@ def test_module_to(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) + def test_module_to_with_var_reuse(test_case): + class ReuseVarModule(flow.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = flow.nn.Linear(3, 4) + self.linear2 = flow.nn.Linear(3, 4) + self.linear2.weight = self.linear1.weight + + reuse_var_m = ReuseVarModule() + + test_case.assertTrue(reuse_var_m.linear1.weight is reuse_var_m.linear2.weight) + test_case.assertEqual(reuse_var_m.linear1.weight.device, cpu0_device) + + test_case.assertTrue(reuse_var_m.linear1.bias is not reuse_var_m.linear2.bias) + test_case.assertEqual(reuse_var_m.linear1.bias.device, cpu0_device) + + reuse_var_m.to(gpu0_device) + + test_case.assertTrue(reuse_var_m.linear1.weight is reuse_var_m.linear2.weight) + test_case.assertEqual(reuse_var_m.linear1.weight.device, gpu0_device) + + test_case.assertTrue(reuse_var_m.linear1.bias is not reuse_var_m.linear2.bias) + test_case.assertEqual(reuse_var_m.linear1.bias.device, gpu0_device) + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_module_to_consistent.py b/python/oneflow/test/modules/test_module_to_consistent.py new file mode 100644 index 00000000000..655a3ce3955 --- /dev/null +++ b/python/oneflow/test/modules/test_module_to_consistent.py @@ -0,0 +1,64 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
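
For context on the meshgrid changes above: the helpers now thread an indexing argument ("ij" or "xy") through both the numpy reference and flow.meshgrid. A small sketch of the difference for two 1-D inputs, assuming the flow.meshgrid signature the updated helpers call:

import numpy as np
import oneflow as flow

a = flow.tensor(np.array([1.0, 2.0, 3.0]), dtype=flow.float32)
b = flow.tensor(np.array([4.0, 5.0]), dtype=flow.float32)

gx_ij, gy_ij = flow.meshgrid(a, b, indexing="ij")   # matrix indexing: shapes (3, 2)
gx_xy, gy_xy = flow.meshgrid(a, b, indexing="xy")   # Cartesian indexing: shapes (2, 3)

# for the two-tensor case, "xy" outputs are just the transposes of the "ij" outputs
assert np.allclose(gx_ij.numpy().T, gx_xy.numpy())
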
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import unittest +from collections import OrderedDict + +import numpy as np +from test_util import GenArgList + +import oneflow as flow +import oneflow.unittest + + +@flow.unittest.skip_unless_1n2d() +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +class TestModuleToCosistent(flow.unittest.TestCase): + def test_module_to_consistent(test_case): + rank = flow.env.get_rank() + P = flow.placement("cuda", {0: [0, 1]}) + B = flow.sbp.broadcast + + class ReuseVarModule(flow.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = flow.nn.Linear(3, 4) + self.linear2 = flow.nn.Linear(3, 4) + self.linear2.weight = self.linear1.weight + + reuse_var_m = ReuseVarModule() + + test_case.assertTrue(reuse_var_m.linear1.weight is reuse_var_m.linear2.weight) + test_case.assertEqual( + reuse_var_m.linear1.weight.device, flow.device("cpu", rank) + ) + + test_case.assertTrue(reuse_var_m.linear1.bias is not reuse_var_m.linear2.bias) + test_case.assertEqual(reuse_var_m.linear1.bias.device, flow.device("cpu", rank)) + + reuse_var_m.to_consistent(placement=P, sbp=B) + + test_case.assertTrue(reuse_var_m.linear1.weight is reuse_var_m.linear2.weight) + test_case.assertEqual(reuse_var_m.linear1.weight.placement, P) + test_case.assertEqual(reuse_var_m.linear1.weight.sbp[0], B) + + test_case.assertTrue(reuse_var_m.linear1.bias is not reuse_var_m.linear2.bias) + test_case.assertEqual(reuse_var_m.linear1.bias.placement, P) + test_case.assertEqual(reuse_var_m.linear1.bias.sbp[0], B) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/modules/test_ne.py b/python/oneflow/test/modules/test_ne.py index f7173ce767b..3da5caa8251 100644 --- a/python/oneflow/test/modules/test_ne.py +++ b/python/oneflow/test/modules/test_ne.py @@ -101,8 +101,8 @@ def test_ne(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(auto_backward=False, check_graph=False) - def test_ne_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_ne_with_0_size_data(test_case): device = random_device() x1 = random_pytorch_tensor(4, 2, 3, 0, 5).to(device) x2 = random_pytorch_tensor(4, 2, 3, 0, 5).to(device) diff --git a/python/oneflow/test/modules/test_negative.py b/python/oneflow/test/modules/test_negative.py index 956efe3ee5e..ea5549c4d5c 100644 --- a/python/oneflow/test/modules/test_negative.py +++ b/python/oneflow/test/modules/test_negative.py @@ -24,8 +24,8 @@ @flow.unittest.skip_unless_1n1d() class TestNegativeModule(flow.unittest.TestCase): - @autotest(auto_backward=False, check_graph=False) - def test_ne_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_ne_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 3, 0, 5).to(device) y1 = torch.negative(x) diff --git a/python/oneflow/test/modules/test_norm.py b/python/oneflow/test/modules/test_norm.py index c3210097ae6..918860ad9eb 100644 --- a/python/oneflow/test/modules/test_norm.py +++ b/python/oneflow/test/modules/test_norm.py @@ -255,7 +255,7 @@ def test_norm(test_case): for arg in 
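
The new test_module_to_consistent.py above relies on Module.to_consistent re-placing parameters while keeping parameter sharing intact. A minimal sketch of that expectation, assuming a multi-rank launch where the "cuda" placement over two devices is valid (the TwoLinear class name is illustrative):

import oneflow as flow

class TwoLinear(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = flow.nn.Linear(3, 4)
        self.linear2 = flow.nn.Linear(3, 4)
        self.linear2.weight = self.linear1.weight   # share one weight tensor

m = TwoLinear()
m.to_consistent(placement=flow.placement("cuda", {0: [0, 1]}), sbp=flow.sbp.broadcast)

# the shared weight stays shared after the conversion, and is now consistent
assert m.linear1.weight is m.linear2.weight
assert m.linear1.weight.sbp[0] == flow.sbp.broadcast
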
GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_no_dim_no_ord_norm_with_random_data(test_case): device = random_device() input = random_pytorch_tensor().to(device) @@ -263,7 +263,7 @@ def test_no_dim_no_ord_norm_with_random_data(test_case): m = torch.linalg.norm(input, keepdim=keepdim) return m - @autotest(check_graph=False) + @autotest(check_graph=True) def test_one_dim_norm_with_random_data(test_case): device = random_device() input = random_pytorch_tensor(ndim=4).to(device) @@ -274,7 +274,7 @@ def test_one_dim_norm_with_random_data(test_case): m = torch.linalg.norm(input, ord, dim, keepdim) return m - @autotest(check_graph=False) + @autotest(check_graph=True) def test_no_dim_one_shape_norm_with_random_data(test_case): device = random_device() input = random_pytorch_tensor(ndim=1).to(device) @@ -284,7 +284,7 @@ def test_no_dim_one_shape_norm_with_random_data(test_case): m = torch.linalg.norm(input, ord=ord, keepdim=keepdim) return m - @autotest(check_graph=False) + @autotest(check_graph=True) def test_no_dim_two_shape_norm_with_random_data(test_case): device = random_device() input = random_pytorch_tensor(ndim=2).to(device) @@ -293,7 +293,7 @@ def test_no_dim_two_shape_norm_with_random_data(test_case): m = torch.linalg.norm(input, ord=ord, keepdim=keepdim) return m - @autotest(check_graph=False) + @autotest(check_graph=True) def test_tuple_dim_norm_with_random_data(test_case): device = random_device() input = random_pytorch_tensor(ndim=2).to(device) diff --git a/python/oneflow/test/modules/test_l2_normalize.py b/python/oneflow/test/modules/test_normalize.py similarity index 73% rename from python/oneflow/test/modules/test_l2_normalize.py rename to python/oneflow/test/modules/test_normalize.py index 8287a078aa5..b981baf8044 100644 --- a/python/oneflow/test/modules/test_l2_normalize.py +++ b/python/oneflow/test/modules/test_normalize.py @@ -16,10 +16,9 @@ import unittest from collections import OrderedDict - -import numpy as np from test_util import GenArgList - +from oneflow.test_utils.automated_test_util import * +import numpy as np import oneflow as flow import oneflow.unittest @@ -32,12 +31,18 @@ def _count(shape, begin_axis, end_axis): def _l2_norm_numpy(x, dim, epsilon=1e-12): + axes = [k for k in range(len(list(x.shape)))] + axes[0], axes[dim] = axes[dim], axes[0] + axes_tuple = tuple(axes) + + x = np.transpose(x, axes_tuple) + square_x_sum_shape = list(x.shape) - square_x_sum_shape[dim] = 1 + square_x_sum_shape[0] = 1 - c = x.shape[dim] + c = x.shape[0] n = int(x.size / c) - d = _count(x.shape, dim + 1, len(x.shape)) + d = _count(x.shape, 1, len(x.shape)) square_x_sum = np.zeros(square_x_sum_shape) @@ -58,13 +63,21 @@ def _l2_norm_numpy(x, dim, epsilon=1e-12): square_x_sum = square_x_sum_flatten.reshape(square_x_sum.shape) out = out.reshape(x.shape) - return out, square_x_sum + return np.transpose(out, axes_tuple), np.transpose(square_x_sum, axes_tuple) def _l2_norm_backward_np(dy, y, square_x_sum, dim, epsilon=1e-12): - c = dy.shape[dim] + axes = [k for k in range(len(list(y.shape)))] + axes[0], axes[dim] = axes[dim], axes[0] + axes_tuple = tuple(axes) + + dy = np.transpose(dy, axes_tuple) + y = np.transpose(y, axes_tuple) + square_x_sum = np.transpose(square_x_sum, axes_tuple) + + c = dy.shape[0] n = int(dy.size / c) - d = _count(dy.shape, dim + 1, len(y.shape)) + d = _count(dy.shape, 1, len(y.shape)) dx = np.zeros(dy.shape).reshape(-1) dy_flatten = dy.reshape(-1) @@ -89,7 +102,7 @@ def 
_l2_norm_backward_np(dy, y, square_x_sum, dim, epsilon=1e-12): index = offset + j * d dx[index] = (1 / norm) * dy_flatten[index] - return dx.reshape(y.shape) + return np.transpose(dx.reshape(y.shape), axes_tuple) def _test_l2_normalize(test_case, device, dim, shape): @@ -124,5 +137,23 @@ def test_l2_normalize(test_case): arg[0](test_case, *arg[1:]) +@flow.unittest.skip_unless_1n1d() +class TestFunctionalNormalize(flow.unittest.TestCase): + @autotest(check_graph=False) + def test_functional_normalize(test_case): + device = random_device() + ndim = random(low=2) + + shape = list(random_tensor(ndim).value().shape) + dim = random(low=0, high=ndim).to(int).value() + shape[dim] = random(low=2, high=8).to(int).value() + shape = tuple(shape) + + x = random_pytorch_tensor(len(shape), *shape).to(device) + y = torch.nn.functional.normalize(x, oneof(2, 3, 4), dim, 1e-12) + + return y + + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_permute.py b/python/oneflow/test/modules/test_permute.py index a7bcb4d209c..64143d908f6 100644 --- a/python/oneflow/test/modules/test_permute.py +++ b/python/oneflow/test/modules/test_permute.py @@ -84,7 +84,7 @@ def test_torch_permute4d_with_random_data(test_case): y = torch.permute(x, dims=permute_list) return y - @autotest(check_graph=False) + @autotest(check_graph=True) def test_permute5d_tensor_with_random_data(test_case): device = random_device() ndim = 5 @@ -101,7 +101,7 @@ def test_permute5d_tensor_with_random_data(test_case): y = x.permute(permute_list) return y - @autotest(check_graph=False) + @autotest(check_graph=True) def test_permute4d_tensor_with_random_data(test_case): device = random_device() ndim = 4 @@ -117,7 +117,7 @@ def test_permute4d_tensor_with_random_data(test_case): y = x.permute(permute_list) return y - @autotest(check_graph=False) + @autotest(check_graph=True) def test_permute3d_tensor_with_random_data(test_case): device = random_device() ndim = 3 diff --git a/python/oneflow/test/modules/test_prod.py b/python/oneflow/test/modules/test_prod.py index b9ce7694461..1da4bd16a32 100644 --- a/python/oneflow/test/modules/test_prod.py +++ b/python/oneflow/test/modules/test_prod.py @@ -22,7 +22,7 @@ @flow.unittest.skip_unless_1n1d() class TestReduceProd(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=True) def test_reduce_prod_without_dim(test_case): device = random_device() ndim = random(1, 5).to(int) @@ -31,7 +31,7 @@ def test_reduce_prod_without_dim(test_case): return y - @autotest(check_graph=False) + @autotest(check_graph=True) def test_reduce_prod_with_dim(test_case): device = random_device() ndim = random(1, 5).to(int) diff --git a/python/oneflow/test/modules/test_randint.py b/python/oneflow/test/modules/test_randint.py index cba1af0cd49..e7a02b96efd 100644 --- a/python/oneflow/test/modules/test_randint.py +++ b/python/oneflow/test/modules/test_randint.py @@ -148,5 +148,46 @@ def test_0rank_randint(test_case): arg[0](test_case, *arg[1:]) +def _test_consistent_rand(test_case, low, high, shape, placement, sbp): + x = flow.randint(low, high, shape, placement=placement, sbp=sbp) + test_case.assertEqual(x.shape, shape) + test_case.assertEqual(x.sbp, sbp) + test_case.assertEqual(x.placement, placement) + + +def _test_consistent_rand_graph(test_case, low, high, shape, placement, sbp): + class ConsistentRandGraph(flow.nn.Graph): + def __init__(self,): + super().__init__() + + def build(self): + x = flow.randint(low, high, shape, placement=placement, sbp=sbp) + return x + + c_r_g = 
ConsistentRandGraph() + x = c_r_g() + test_case.assertEqual(x.shape, shape) + test_case.assertEqual(x.sbp, sbp) + test_case.assertEqual(x.placement, placement) + + +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +@flow.unittest.skip_unless_1n2d() +class TestRandintConsistent(flow.unittest.TestCase): + def test_rand_consistent(test_case): + arg_dict = OrderedDict() + arg_dict["test_fun"] = [_test_consistent_rand, _test_consistent_rand_graph] + arg_dict["low"] = [i for i in range(2)] + arg_dict["high"] = [1000 + np.random.randint(1, 10) for i in range(2)] + arg_dict["shape"] = [(2, 3, 4), (2, 5, 2)] + arg_dict["placement"] = [ + flow.placement("cpu", {0: [0, 1]}), + flow.placement("cuda", {0: [0, 1]}), + ] + arg_dict["sbp"] = [(flow.sbp.broadcast,), (flow.sbp.split(0),)] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/modules/test_randperm.py b/python/oneflow/test/modules/test_randperm.py index 5bba21aadc2..14cfef2709e 100644 --- a/python/oneflow/test/modules/test_randperm.py +++ b/python/oneflow/test/modules/test_randperm.py @@ -106,7 +106,7 @@ def test_randperm_backward(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(auto_backward=False, check_graph=False) + @autotest(auto_backward=False, check_graph=True) def test_auto_1(test_case): device = random_device() y = torch.randperm(1, device=device) diff --git a/python/oneflow/test/modules/test_reciprocal.py b/python/oneflow/test/modules/test_reciprocal.py index 6cf2557d4ec..544ef09eb08 100644 --- a/python/oneflow/test/modules/test_reciprocal.py +++ b/python/oneflow/test/modules/test_reciprocal.py @@ -28,7 +28,7 @@ @flow.unittest.skip_unless_1n1d() class TestReciprocalModule(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_reciprocal_list_with_random_data(test_case): device = random_device() x = random_pytorch_tensor( diff --git a/python/oneflow/test/modules/test_repeat.py b/python/oneflow/test/modules/test_repeat.py index 89c0c2d7572..2a4fdec9ce9 100644 --- a/python/oneflow/test/modules/test_repeat.py +++ b/python/oneflow/test/modules/test_repeat.py @@ -23,7 +23,7 @@ @flow.unittest.skip_unless_1n1d() class TestRepeat(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_tensor_repeat_with_random_data(test_case): x = random_pytorch_tensor(ndim=2, dim0=1, dim1=2) sizes = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int)) diff --git a/python/oneflow/test/modules/test_reshape.py b/python/oneflow/test/modules/test_reshape.py index f69839be09f..d9e129fb3d8 100644 --- a/python/oneflow/test/modules/test_reshape.py +++ b/python/oneflow/test/modules/test_reshape.py @@ -102,8 +102,8 @@ def test_reshape_flow_with_random_data(test_case): y = torch.reshape(x, shape=(-1,)) return y - @autotest(auto_backward=False, check_graph=False) - def test_reshape_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_reshape_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 0, 3).to(device) y = torch.reshape( diff --git a/python/oneflow/test/modules/test_round.py b/python/oneflow/test/modules/test_round.py index 2069a160eb0..c2baf1777ad 100644 --- a/python/oneflow/test/modules/test_round.py +++ b/python/oneflow/test/modules/test_round.py @@ -26,7 +26,7 @@ @flow.unittest.skip_unless_1n1d() class TestRound(flow.unittest.TestCase): 
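
The randint additions above check that placement and sbp are honored both eagerly and inside nn.Graph. A condensed sketch of the eager path, assuming a two-rank launch matching the 1n2d test environment:

import oneflow as flow

placement = flow.placement("cuda", {0: [0, 1]})   # two ranks on node 0
sbp = (flow.sbp.split(0),)

x = flow.randint(0, 1000, (2, 3, 4), placement=placement, sbp=sbp)
assert x.shape == (2, 3, 4)        # logical shape, regardless of the split
assert x.placement == placement
assert x.sbp == sbp
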
- @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_round_with_random_data(test_case): device = random_device() x = random_pytorch_tensor().to(device) diff --git a/python/oneflow/test/modules/test_sign.py b/python/oneflow/test/modules/test_sign.py index 55663c14c28..78f7a9e7488 100644 --- a/python/oneflow/test/modules/test_sign.py +++ b/python/oneflow/test/modules/test_sign.py @@ -49,15 +49,15 @@ def test_sign(test_case): for arg in GenArgList(arg_dict): _test_sign_impl(test_case, *arg) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_sign_with_random_data(test_case): device = random_device() x = random_pytorch_tensor().to(device) y = torch.sign(x) return y - @autotest(auto_backward=False, check_graph=False) - def test_sign_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_sign_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 3, 0, 4).to(device) y = torch.sign(x) diff --git a/python/oneflow/test/modules/test_split.py b/python/oneflow/test/modules/test_split.py index 82460db2c46..7b7f04215ce 100644 --- a/python/oneflow/test/modules/test_split.py +++ b/python/oneflow/test/modules/test_split.py @@ -30,8 +30,8 @@ def test_flow_split_with_random_data(test_case): k2 = random(2, 6) rand_dim = random(0, 3).to(int) device = random_device() - x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device) - res = torch.split(x, split_size_or_sections=2, dim=rand_dim) + x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) + res = torch.split(x, 2, dim=rand_dim) return torch.cat(res, rand_dim) @autotest(check_graph=False) @@ -40,8 +40,8 @@ def test_flow_split_sizes_with_random_data(test_case): k1 = 7 k2 = random(2, 6) device = random_device() - x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device) - res = torch.split(x, split_size_or_sections=[1, 2, 3, 1], dim=1) + x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) + res = torch.split(x, [1, 2, 3, 1], dim=1) return torch.cat(res, dim=1) @autotest(check_graph=False) @@ -50,8 +50,8 @@ def test_flow_split_sizes_neg_dim_with_random_data(test_case): k1 = 7 k2 = random(2, 6) device = random_device() - x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim3=k2).to(device) - res = torch.split(x, split_size_or_sections=[1, 2, 3, 1], dim=-2) + x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) + res = torch.split(x, [1, 2, 3, 1], dim=-2) return torch.cat(res, dim=1) diff --git a/python/oneflow/test/modules/test_sqrt_square_sum.py b/python/oneflow/test/modules/test_sqrt_square_sum.py new file mode 100644 index 00000000000..924cf661aa2 --- /dev/null +++ b/python/oneflow/test/modules/test_sqrt_square_sum.py @@ -0,0 +1,58 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import unittest + +import oneflow as flow +import oneflow.unittest + +from oneflow.test_utils.automated_test_util import * + + +@flow.unittest.skip_unless_1n1d() +class TestLinalgVectorNorm2D(flow.unittest.TestCase): + @autotest(n=30, auto_backward=False, check_graph=True, rtol=0.5, atol=0.5) + def test_sqrt_sum_with_cpu_random_data(test_case): + device = cpu_device() + x = random_pytorch_tensor( + ndim=4, dim1=30, dim2=40, dim3=50, requires_grad=False + ).to(device) + y = torch.linalg.norm(x) + return y + + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + @autotest(n=30, auto_backward=False, check_graph=True) + def test_sqrt_sum_with_cuda_random_data(test_case): + device = gpu_device() + x = random_pytorch_tensor( + ndim=4, dim1=100, dim2=100, dim3=100, requires_grad=False + ).to(device) + y = torch.linalg.norm(x) + return y + + @autotest(n=30, auto_backward=False, check_graph=False, rtol=0.5, atol=0.5) + def test_scalar_print_random_data(test_case): + device = random_device() + x = random_pytorch_tensor( + ndim=4, dim1=30, dim2=40, dim3=50, requires_grad=False + ).to(device) + y = torch.linalg.norm(x) + print(f"grad_norm {y.oneflow:.4f}\t") + return y + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/modules/test_squeeze.py b/python/oneflow/test/modules/test_squeeze.py index 40b113b71f3..ca777688bb0 100644 --- a/python/oneflow/test/modules/test_squeeze.py +++ b/python/oneflow/test/modules/test_squeeze.py @@ -104,15 +104,15 @@ def test_squeeze(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_squeeze_with_random_data(test_case): device = random_device() x = random_pytorch_tensor().to(device) y = torch.squeeze(x, random(1, 3).to(int)) return y - @autotest(auto_backward=False, check_graph=False) - def test_squeeze_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_squeeze_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(3, 2, 1, 0).to(device) y = torch.squeeze(x) diff --git a/python/oneflow/test/modules/test_stack.py b/python/oneflow/test/modules/test_stack.py index 80e8d22ef51..b385a180c11 100644 --- a/python/oneflow/test/modules/test_stack.py +++ b/python/oneflow/test/modules/test_stack.py @@ -24,7 +24,7 @@ @flow.unittest.skip_unless_1n1d() class TestStackModule(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=True) def test_stack_with_random_data(test_case): device = random_device() x = random_pytorch_tensor(ndim=4, dim1=3, dim2=4, dim3=5).to(device) diff --git a/python/oneflow/test/modules/test_stateful_kernel_with_cache.py b/python/oneflow/test/modules/test_stateful_kernel_with_cache.py new file mode 100644 index 00000000000..78f32b05295 --- /dev/null +++ b/python/oneflow/test/modules/test_stateful_kernel_with_cache.py @@ -0,0 +1,48 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import unittest +import os + +import numpy as np +import oneflow as flow +import oneflow.unittest + + +@flow.unittest.skip_unless_1n2d() +@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") +class TestStatefulKernelWithInpersistentState(flow.unittest.TestCase): + def test_stateful_kernel_with_inpersistent_state(test_case): + x = flow.arange(4).reshape(2, 2) + x = x.to_consistent(flow.env.all_device_placement("cuda"), flow.sbp.split(0)) + y = flow._C.logical_slice(x, [0, 0], [3, 1], [1, 1]) + y_np = np.array([[0], [2], [0]]) + test_case.assertTrue( + np.array_equal( + y.to_consistent(sbp=flow.sbp.broadcast).to_local().numpy(), y_np + ) + ) + x = x.to_consistent(sbp=flow.sbp.split(1)) + y = flow._C.logical_slice(x, [0, 0], [3, 1], [1, 1]) + test_case.assertTrue( + np.array_equal( + y.to_consistent(sbp=flow.sbp.broadcast).to_local().numpy(), y_np + ) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/oneflow/test/modules/test_stateful_local_opkernel.py b/python/oneflow/test/modules/test_stateful_local_opkernel.py index 9c31dc6db72..be8ec2c807f 100644 --- a/python/oneflow/test/modules/test_stateful_local_opkernel.py +++ b/python/oneflow/test/modules/test_stateful_local_opkernel.py @@ -27,45 +27,12 @@ class TestStatefulLocalKernel(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() def test_dynamic_attrs(test_case): - x = ( - flow.builtin_op("constant") - .Output("out") - .Attr("is_floating_value", True) - .Attr("floating_value", 3.0) - .Attr("dtype", flow.float32) - .Attr("shape", [2, 3]) - .Build()()[0] - ) - op = flow.builtin_op("expand_dims").Input("in").Output("out").Build() - y = op(x, axis=1)[0] + x = flow.full((2, 3), 3.0) + y = flow.unsqueeze(x, dim=1) test_case.assertEqual(y.shape, flow.Size((2, 1, 3))) - y = op(x, axis=2)[0] + y = flow.unsqueeze(x, dim=2) test_case.assertEqual(y.shape, flow.Size((2, 3, 1))) - @flow.unittest.skip_unless_1n1d() - def test_stateful_local_kernel(test_case): - op1 = ( - flow.builtin_op("constant") - .Output("out") - .Attr("is_floating_value", True) - .Attr("floating_value", 3.0) - .Attr("dtype", flow.float32) - .Attr("shape", [1, 1]) - .Build() - ) - op2 = ( - flow.builtin_op("matmul") - .Input("a") - .Input("b") - .Attr("transpose_a", False) - .Attr("transpose_b", False) - .Attr("alpha", float(1.0)) - .Output("out") - .Build() - ) - x = op1()[0] - x = op2(x, x)[0] - @flow.unittest.skip_unless_1n2d() def test_stateful_local_kernel_in_consistent_mode(test_case): rank = int(os.getenv("RANK")) diff --git a/python/oneflow/test/modules/test_std.py b/python/oneflow/test/modules/test_std.py index 1f6ea8368d7..fd97d1074ee 100644 --- a/python/oneflow/test/modules/test_std.py +++ b/python/oneflow/test/modules/test_std.py @@ -22,7 +22,7 @@ @flow.unittest.skip_unless_1n1d() class TestStd(flow.unittest.TestCase): - @autotest(n=10, auto_backward=False, rtol=0.01, atol=0.01, check_graph=False) + @autotest(n=10, auto_backward=False, rtol=0.01, atol=0.01, check_graph=True) def test_std_flow_with_random_data(test_case): device = random_device() all_dim = random().to(int) @@ -33,7 +33,7 @@ def test_std_flow_with_random_data(test_case): ) return z - @autotest(n=10, auto_backward=False, rtol=0.01, atol=0.01, check_graph=False) + @autotest(n=10, auto_backward=False, rtol=0.01, atol=0.01, check_graph=True) def test_std_tensor_with_random_data(test_case): device = random_device() dim = random(low=0, high=4).to(int) diff --git a/python/oneflow/test/modules/test_sub.py b/python/oneflow/test/modules/test_sub.py index 
79495994e36..973c09d73b5 100644 --- a/python/oneflow/test/modules/test_sub.py +++ b/python/oneflow/test/modules/test_sub.py @@ -128,7 +128,7 @@ def test_sub_against_pytorch(test_case): ) @autotest(auto_backward=False, check_graph=False) - def test_sub_with_0shape_data(test_case): + def test_sub_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(2, 0, 3).to(device) y = random_pytorch_tensor(2, 1, 3).to(device) diff --git a/python/oneflow/test/modules/test_sum.py b/python/oneflow/test/modules/test_sum.py index bc1afb1fa85..004856c03c1 100644 --- a/python/oneflow/test/modules/test_sum.py +++ b/python/oneflow/test/modules/test_sum.py @@ -77,8 +77,8 @@ def test_sum_against_pytorch(test_case): y = torch.sum(x) return y - @autotest(auto_backward=False, check_graph=False) - def test_sum_with_0shape_tensor(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_sum_with_0_size_tensor(test_case): device = random_device() x = random_pytorch_tensor(4, 4, 3, 0, 2).to(device) y = torch.sum(x, dim=np.random.randint(0, 3)) diff --git a/python/oneflow/test/modules/test_tensor_str.py b/python/oneflow/test/modules/test_tensor_str.py index b6a29629e65..0d68458648f 100644 --- a/python/oneflow/test/modules/test_tensor_str.py +++ b/python/oneflow/test/modules/test_tensor_str.py @@ -140,12 +140,14 @@ def _test_consistent_tensor_str_2d(test_case, device): x = flow.ones((100, 100), placement=placement, sbp=[flow.sbp.split(0)]) tensor_str = str(x) test_case.assertTrue("1." in tensor_str) - test_case.assertTrue("..." in tensor_str) + # TODO: this test has bug + # test_case.assertTrue("..." in tensor_str) x = flow.ones((100, 100), placement=placement, sbp=[flow.sbp.split(1)]) tensor_str = str(x) test_case.assertTrue("1." in tensor_str) - test_case.assertTrue("..." in tensor_str) + # TODO: this test has bug + # test_case.assertTrue("..." in tensor_str) x = flow.ones( (10, 10), placement=flow.placement(device, {0: [0]}), sbp=[flow.sbp.broadcast] @@ -153,6 +155,10 @@ def _test_consistent_tensor_str_2d(test_case, device): tensor_str = str(x) test_case.assertTrue("1." in tensor_str) + x = flow.ones((2, 5), placement=placement, sbp=[flow.sbp.split(0)]) + tensor_str = str(x) + test_case.assertTrue("1." 
in tensor_str) + class TestTensorStrModule(flow.unittest.TestCase): @flow.unittest.skip_unless_1n1d() diff --git a/python/oneflow/test/modules/test_tensor_to.py b/python/oneflow/test/modules/test_tensor_to.py index f0e6854ac91..c54247cd991 100644 --- a/python/oneflow/test/modules/test_tensor_to.py +++ b/python/oneflow/test/modules/test_tensor_to.py @@ -59,6 +59,15 @@ def test_consistent_tensor_to(test_case): test_case.assertEqual(cloned_local[0].numpy().item(), 0) test_case.assertEqual(x.to_local()[0].numpy().item(), 1) + def test_tensor_to_h2d1(test_case): + input = flow.tensor(np.random.randn(2, 3, 4, 5), dtype=flow.int64) + output = input.to(device=flow.device("cuda:1"), dtype=flow.int32) + test_case.assertEqual(output.device, flow.device("cuda:1")) + test_case.assertEqual(output.dtype, flow.int32) + test_case.assertTrue( + np.allclose(input.numpy(), output.numpy(), rtol=0.0001, atol=0.0001) + ) + @flow.unittest.skip_unless_1n1d() @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") diff --git a/python/oneflow/test/modules/test_tile.py b/python/oneflow/test/modules/test_tile.py index 915f730eff5..944213024a3 100644 --- a/python/oneflow/test/modules/test_tile.py +++ b/python/oneflow/test/modules/test_tile.py @@ -23,14 +23,14 @@ @flow.unittest.skip_unless_1n1d() class TestTile(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_tile_with_random_data(test_case): x = random_pytorch_tensor(ndim=2, dim0=1, dim1=2) reps = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int)) z = torch.tile(x, reps) return z - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_tensor_tile_with_random_data(test_case): x = random_pytorch_tensor(ndim=2, dim0=1, dim1=2) reps = (random(1, 5).to(int), random(1, 5).to(int), random(1, 5).to(int)) diff --git a/python/oneflow/test/modules/test_transpose.py b/python/oneflow/test/modules/test_transpose.py index 7d7ef4721a9..d495ee4b4e8 100644 --- a/python/oneflow/test/modules/test_transpose.py +++ b/python/oneflow/test/modules/test_transpose.py @@ -103,8 +103,8 @@ def test_transpose_flow_with_random_data(test_case): y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) return y - @autotest(auto_backward=False, check_graph=False) - def test_transpose_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_transpose_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 3, 0, 4).to(device) y = torch.transpose(x, dim0=random(1, 3).to(int), dim1=random(1, 3).to(int)) diff --git a/python/oneflow/test/modules/test_tril.py b/python/oneflow/test/modules/test_tril.py index 8a40b2fdad3..d86b77baeff 100644 --- a/python/oneflow/test/modules/test_tril.py +++ b/python/oneflow/test/modules/test_tril.py @@ -22,7 +22,7 @@ @flow.unittest.skip_unless_1n1d() class TestTril(flow.unittest.TestCase): - @autotest(check_graph=False) + @autotest(check_graph=True) def test_tril_without_diag(test_case): device = random_device() x = random_pytorch_tensor( @@ -37,7 +37,7 @@ def test_tril_without_diag(test_case): return y - @autotest(check_graph=False) + @autotest(check_graph=True) def test_tril_with_diag(test_case): device = random_device() diagonal = random(-3, 3).to(int) diff --git a/python/oneflow/test/modules/test_triu.py b/python/oneflow/test/modules/test_triu.py index 0c4e1533a23..de5a48851fc 100644 --- a/python/oneflow/test/modules/test_triu.py +++ b/python/oneflow/test/modules/test_triu.py @@ -52,8 
+52,8 @@ def test_triu(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(auto_backward=False, check_graph=False) - def test_triu_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_triu_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) y = torch.triu(x) diff --git a/python/oneflow/test/modules/test_unfold_tensor.py b/python/oneflow/test/modules/test_unfold_tensor.py index 54049b60075..27e81d8b715 100644 --- a/python/oneflow/test/modules/test_unfold_tensor.py +++ b/python/oneflow/test/modules/test_unfold_tensor.py @@ -24,7 +24,7 @@ @flow.unittest.skip_unless_1n1d() class TestUnfoldTensor(flow.unittest.TestCase): - @autotest(n=10, auto_backward=True, check_graph=False) + @autotest(n=10, auto_backward=True, check_graph=True) def test_unfold_tensor_with_random_data(test_case): device = random_device() x = random_pytorch_tensor(3, 3, 4, 5).to(device) diff --git a/python/oneflow/test/modules/test_unsqueeze.py b/python/oneflow/test/modules/test_unsqueeze.py index 94ea98735f2..7a5c9fa5c11 100644 --- a/python/oneflow/test/modules/test_unsqueeze.py +++ b/python/oneflow/test/modules/test_unsqueeze.py @@ -77,15 +77,15 @@ def test_unsqueeze(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_unsqueeze_with_random_data(test_case): device = random_device() x = random_pytorch_tensor().to(device) y = torch.unsqueeze(x, random(1, 3).to(int)) return y - @autotest(auto_backward=False, check_graph=False) - def test_unsqueeze_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_unsqueeze_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(3, 2, 1, 0).to(device) y = torch.unsqueeze(x, random(0, 2).to(int)) diff --git a/python/oneflow/test/modules/test_var.py b/python/oneflow/test/modules/test_var.py index 2a19cf66885..b7c38776e7e 100644 --- a/python/oneflow/test/modules/test_var.py +++ b/python/oneflow/test/modules/test_var.py @@ -23,7 +23,6 @@ class TestVar(flow.unittest.TestCase): - @autotest(check_graph=False) def test_flow_var_all_dim_with_random_data(test_case): device = random_device() x = random_pytorch_tensor().to(device) @@ -43,8 +42,8 @@ def test_flow_var_one_dim_with_random_data(test_case): return y @unittest.skip("var not support 0-shape tensor currently") - @autotest() - def test_flow_var_0d_tensor_with_random_data(test_case): + @autotest(check_graph=False) + def test_flow_var_0_size_data_with_random_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 3, 0, 4).to(device) y = torch.var( diff --git a/python/oneflow/test/modules/test_where.py b/python/oneflow/test/modules/test_where.py index e0446cf138f..a437eb1a532 100644 --- a/python/oneflow/test/modules/test_where.py +++ b/python/oneflow/test/modules/test_where.py @@ -209,7 +209,7 @@ def test_where(test_case): for arg in GenArgList(arg_dict): arg[0](test_case, *arg[1:]) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_where_tensor_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) @@ -219,7 +219,7 @@ def test_flow_where_tensor_with_random_data(test_case): y = random_pytorch_tensor(ndim=2, dim0=k1, dim1=k2).to(device) return torch.where(cond > 0, x, y) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_where_tensor_broadcast_with_random_data(test_case): k1 = random(2, 6) k2 = 
random(2, 6) @@ -229,7 +229,7 @@ def test_flow_where_tensor_broadcast_with_random_data(test_case): y = random_pytorch_tensor(ndim=2, dim0=k1, dim1=1).to(device) return torch.where(cond > 0, x, y) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_where_scalar_x_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) @@ -241,7 +241,7 @@ def test_flow_where_scalar_x_with_random_data(test_case): ) return torch.where(cond > 0, x, y) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_where_scalar_x_broadcast_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) @@ -253,7 +253,7 @@ def test_flow_where_scalar_x_broadcast_with_random_data(test_case): ) return torch.where(cond > 0, x, y) - @autotest(auto_backward=False, check_graph=False) + @autotest(auto_backward=False, check_graph=True) def test_flow_where_scalar_x_int_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) @@ -263,7 +263,7 @@ def test_flow_where_scalar_x_int_with_random_data(test_case): y = random_pytorch_tensor(ndim=2, dim0=k1, dim1=k2, dtype=int).to(device) return torch.where(cond > 0, x, y) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_where_scalar_y_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) @@ -275,7 +275,7 @@ def test_flow_where_scalar_y_with_random_data(test_case): y = random().to(float) return torch.where(cond > 0, x, y) - @autotest(check_graph=False) + @autotest(check_graph=True) def test_flow_where_scalar_y_broadcast_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) @@ -287,7 +287,7 @@ def test_flow_where_scalar_y_broadcast_with_random_data(test_case): y = random().to(float) return torch.where(cond > 0, x, y) - @autotest(auto_backward=False, check_graph=False) + @autotest(auto_backward=False, check_graph=True) def test_flow_where_scalar_y_int_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) @@ -297,7 +297,7 @@ def test_flow_where_scalar_y_int_with_random_data(test_case): y = random().to(int) return torch.where(cond > 0, x, y) - @autotest(auto_backward=False, check_graph=False) + @autotest(auto_backward=False, check_graph=True) def test_flow_where_scalar_xy_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) @@ -307,7 +307,7 @@ def test_flow_where_scalar_xy_with_random_data(test_case): y = random().to(float) return torch.where(cond > 0, x, y) - @autotest(auto_backward=False, check_graph=False) + @autotest(auto_backward=False, check_graph=True) def test_flow_where_scalar_xy_int_with_random_data(test_case): k1 = random(2, 6) k2 = random(2, 6) diff --git a/python/oneflow/test/tensor/test_parameter.py b/python/oneflow/test/tensor/test_parameter.py index 946f4c9f2eb..029e6f89cd2 100644 --- a/python/oneflow/test/tensor/test_parameter.py +++ b/python/oneflow/test/tensor/test_parameter.py @@ -25,14 +25,14 @@ @flow.unittest.skip_unless_1n1d() class TestParameter(flow.unittest.TestCase): - @autotest(n=1) + @autotest(n=1, check_graph=False) def test_parameter_grad_fn_none(test_case): x = torch.ones(2, 3).requires_grad_(True) y = x + x z = torch.nn.Parameter(y) return z.grad_fn - @autotest(n=1) + @autotest(n=1, check_graph=False) def test_parameter_set_data_autograd_meta(test_case): x = torch.ones(2, 3).requires_grad_(True) y = x + x diff --git a/python/oneflow/test/tensor/test_tensor.py b/python/oneflow/test/tensor/test_tensor.py index 2a8934fbe16..0e554d7e9ea 100644 --- a/python/oneflow/test/tensor/test_tensor.py +++ b/python/oneflow/test/tensor/test_tensor.py 
@@ -351,21 +351,22 @@ def test_mirrored_tensor_and_op(test_case): test_case.assertEqual(x1.dtype, flow.float32) test_case.assertEqual(x1.shape, flow.Size((1, 2))) x2 = flow.Tensor([[1.0], [2.0]]) - op = ( - flow.builtin_op("matmul") - .Input("a") - .Input("b") - .Attr("transpose_a", False) - .Attr("transpose_b", False) - .Attr("alpha", float(1.0)) - .Output("out") - .Build() - ) - y = op(x1, x2)[0] + y = flow.matmul(x1, x2) test_case.assertTrue( np.allclose(y.numpy(), np.array([[5.0]], dtype=np.float32)) ) + @flow.unittest.skip_unless_1n1d() + @autotest(check_graph=False) + def test_matmul_with_random_data(test_case): + device = random_device() + dim0 = random(low=2, high=10).to(int) + dim1 = random(low=3, high=20).to(int) + dim2 = random(low=2, high=11).to(int) + a = random_pytorch_tensor(ndim=2, dim0=dim0, dim1=dim1) + b = random_pytorch_tensor(ndim=2, dim0=dim1, dim1=dim2) + return a @ b + @flow.unittest.skip_unless_1n1d() def test_tensor_to_list(test_case): list_data = [[1.0, 3.0], [5.0, 6.0]] @@ -716,8 +717,8 @@ def test_flow_fmod_scalar_with_random_data(test_case): other = 3 return input.fmod(other) - @autotest(auto_backward=False, check_graph=False) - def test_fmod_with_0shape_data(test_case): + @autotest(auto_backward=False, check_graph=True) + def test_fmod_with_0_size_data(test_case): device = random_device() x = random_pytorch_tensor(4, 2, 1, 0, 3).to(device) y = x.fmod(2) @@ -1623,6 +1624,29 @@ def test_tensor_bmm(test_case): of_out = input1.bmm(input2) return of_out + @flow.unittest.skip_unless_1n1d() + @autotest(check_graph=False) + def test_tensor_split(test_case): + k0 = random(2, 6) + k1 = random(2, 6) + k2 = random(2, 6) + rand_dim = random(0, 3).to(int) + device = random_device() + x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) + res = x.split(2, dim=rand_dim) + return torch.cat(res, rand_dim) + + @flow.unittest.skip_unless_1n1d() + @autotest(check_graph=False) + def test_tensor_split_sizes(test_case): + k0 = random(2, 6) + k1 = 7 + k2 = random(2, 6) + device = random_device() + x = random_pytorch_tensor(ndim=3, dim0=k0, dim1=k1, dim2=k2).to(device) + res = x.split([1, 2, 3, 1], dim=-2) + return torch.cat(res, dim=1) + if __name__ == "__main__": unittest.main() diff --git a/python/oneflow/test/tensor/test_tensor_indexing.py b/python/oneflow/test/tensor/test_tensor_indexing.py index 3a3004f2398..4a071196f6a 100644 --- a/python/oneflow/test/tensor/test_tensor_indexing.py +++ b/python/oneflow/test/tensor/test_tensor_indexing.py @@ -152,6 +152,21 @@ def test_advanced_indexing_array(test_case, numpy_x, dtype): test_case.assertTrue(np.allclose(numpy_x[idx, idx, :], x[idx, idx, :].numpy())) test_case.assertTrue(np.allclose(numpy_x[idx, idx, idx], x[idx, idx, idx].numpy())) + idx1 = np.array([[1, 0, 1], [1, 1, 0]]) + idx2 = np.array([[0], [1]]) + test_case.assertTrue( + np.allclose(numpy_x[:, idx1, :, idx2].shape, x[:, idx1, :, idx2].shape) + ) + test_case.assertTrue( + np.allclose(numpy_x[:, idx1, 1, idx2].shape, x[:, idx1, 1, idx2].shape) + ) + test_case.assertTrue( + np.allclose(numpy_x[idx1, :, idx2, :].shape, x[idx1, :, idx2, :].shape) + ) + test_case.assertTrue( + np.allclose(numpy_x[:, idx1, idx2, :].shape, x[:, idx1, idx2, :].shape) + ) + def test_combining_indexing(test_case, numpy_x): x = flow.tensor(numpy_x) @@ -233,7 +248,7 @@ def test_advanced_indexing(test_case): test_advanced_indexing(test_case, numpy_x) def test_advanced_indexing_array(test_case): - numpy_x = np.arange(0, 60, 1).reshape([3, 4, 5]).astype(np.float32) + numpy_x = 
np.arange(0, 60, 1).reshape([3, 2, 2, 5]).astype(np.float32) test_advanced_indexing_array(test_case, numpy_x, np.int32) test_advanced_indexing_array(test_case, numpy_x, np.int64) @@ -241,7 +256,7 @@ def test_advanced_indexing_array(test_case): test_advanced_indexing_array(test_case, numpy_x, np.int32) test_advanced_indexing_array(test_case, numpy_x, np.int64) - numpy_x = np.arange(0, 720, 1).reshape([8, 9, 10]).astype(np.float32) + numpy_x = np.arange(0, 720, 1).reshape([5, 8, 9, 2]).astype(np.float32) test_advanced_indexing_array(test_case, numpy_x, np.int32) test_advanced_indexing_array(test_case, numpy_x, np.int64) diff --git a/python/oneflow/test_utils/automated_test_util/generators.py b/python/oneflow/test_utils/automated_test_util/generators.py index 81e2dd8ef56..555dcf84bb4 100644 --- a/python/oneflow/test_utils/automated_test_util/generators.py +++ b/python/oneflow/test_utils/automated_test_util/generators.py @@ -362,6 +362,22 @@ def _calc_value(self): return random_util.choice(["cuda", "cpu"]) +class cpu_device(generator): + def __init__(self): + super().__init__([]) + + def _calc_value(self): + return random_util.choice(["cpu"]) + + +class gpu_device(generator): + def __init__(self): + super().__init__([]) + + def _calc_value(self): + return random_util.choice(["cuda"]) + + def test_against_pytorch( test_case, callable_name, @@ -649,6 +665,8 @@ def test_tensor_against_pytorch( "random_tensor", "random_bool", "random_device", + "cpu_device", + "gpu_device", "random", "random_or_nothing", "oneof", diff --git a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py index 2681e1f3f43..0c058f49c58 100644 --- a/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py +++ b/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py @@ -22,8 +22,13 @@ import numpy as np import oneflow as flow +flow.backends.cudnn.deterministic = True + try: import torch as torch_original + + torch_original.backends.cudnn.deterministic = True + torch_original.set_printoptions(profile="full") except ImportError: print( "automated_test_util module uses PyTorch to verify OneFlow module's interface and result. Please install Pytorch according `https://pytorch.org/get-started/locally/`." 
@@ -48,7 +53,7 @@ def torch_tensor_to_flow(x): vis_parameters = {} call_tensor_id = [] extra_input_tensor = set() -flow_res_id_eager2_graph = dict() +eager_tensor_2_graph_tensor = dict() class PyTorchDoesNotSupportError(Exception): @@ -62,6 +67,17 @@ def __repr__(self): return f"PyTorch error: {str(self.exc)}" + +class OneFlowGraphBuildOrRunError(Exception): + def __init__(self, exc): + self.exc = exc + + def __str__(self): + return repr(self) + + def __repr__(self): + return f"OneFlow nn.Graph Build Or Run Error: {str(self.exc)}" + + class BothDoNotSupportError(Exception): def __init__(self, th_exc, of_exc): self.th_exc = th_exc @@ -211,6 +227,7 @@ def GetDualObject(name, pytorch, oneflow): "__str__", "__repr__", ] + verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None pytorch_methods = dir(pytorch) if hasattr(pytorch, "__call__") and "__call__" not in pytorch_methods: pytorch_methods.append("__call__") @@ -277,6 +294,9 @@ def dual_method(self, *args, **kwargs): else: oneflow_res = oneflow(*oneflow_args, **oneflow_kwargs) if testing_graph: + find_check_module_func = True + ignore_apis_list = ["to", "tensor", "_to", "train"] + test_g_res = [] if isinstance(oneflow, flow.nn.Module): class TestGraphOfModule(flow.nn.Graph): @@ -288,15 +308,49 @@ def build(self, *args): return self.test_module(*args) test_g = TestGraphOfModule() + if verbose: + print("Run graph of module: ", repr(oneflow)) + test_g.debug(2) test_g_res = test_g(*oneflow_args) + elif oneflow.__name__ in ignore_apis_list: + find_check_module_func = False + # 1. "oneflow.nn.modules" not in oneflow.__module__: avoid taking the nn.Module branch of the graph test for ops that actually dispatch to a Module, e.g. fold calls the Fold Module. + # 2. inspect.isfunction(oneflow): unlike the ordinary flow.xxx functions, the ops in the oneflow.nn.modules.math_ops series carry an extra layer of Python wrapper. + # 3. inspect.ismethod(oneflow) and "oneflow.nn.modules" in oneflow.__module__: covers ops that only expose a Tensor.xxx method but actually call oneflow.xxx, e.g. masked_fill.
+ elif ( + "oneflow.nn.modules" not in oneflow.__module__ + or inspect.isfunction(oneflow) + or ( + inspect.ismethod(oneflow) + and "oneflow.nn.modules" in oneflow.__module__ + ) + ): + + class TestGraphOfFunctional(flow.nn.Graph): + def __init__(self): + super().__init__() + self.test_module_func = oneflow + + def build(self): + return self.test_module_func( + *oneflow_args, **oneflow_kwargs + ) + + try: + test_g = TestGraphOfFunctional() + test_g_res = test_g() + except Exception as e: + print_note_fake_program() + raise OneFlowGraphBuildOrRunError(e) + if find_check_module_func: if isinstance(test_g_res, tuple): for idx, g_res in enumerate(test_g_res): - flow_res_id_eager2_graph[ - id(oneflow_res[idx]) + eager_tensor_2_graph_tensor[ + oneflow_res[idx] ] = g_res else: - flow_res_id_eager2_graph[ - id(oneflow_res) + eager_tensor_2_graph_tensor[ + oneflow_res ] = test_g_res return GetDualObject("unused", pytorch_res, oneflow_res) @@ -331,6 +385,31 @@ def dual_method(self, *args, **kwargs): ) raise PyTorchDoesNotSupportError(e) oneflow_res = oneflow_method(*oneflow_args, **oneflow_kwargs) + if testing_graph: + + class TestGraphOfTensorMethod(flow.nn.Graph): + def __init__(self): + super().__init__() + + def build(self): + return oneflow_method( + *oneflow_args, **oneflow_kwargs + ) + + try: + test_g = TestGraphOfTensorMethod() + test_g_res = test_g() + except Exception as e: + print_note_fake_program() + raise OneFlowGraphBuildOrRunError(e) + if isinstance(test_g_res, tuple): + for idx, g_res in enumerate(test_g_res): + eager_tensor_2_graph_tensor[ + oneflow_res[idx] + ] = g_res + else: + eager_tensor_2_graph_tensor[oneflow_res] = test_g_res + return GetDualObject("unused", pytorch_res, oneflow_res) return dual_method @@ -436,9 +515,10 @@ def clear_note_fake_program(): note_pytorch_kwargs.clear() call_tensor_id.clear() vis_tensor.clear() - flow_res_id_eager2_graph.clear() + eager_tensor_2_graph_tensor.clear() vis_parameters.clear() extra_input_tensor.clear() + flow.set_printoptions(profile="full") class DualObject: @@ -572,9 +652,10 @@ def new_f(test_case): dual_modules_to_test.clear() dual_objects_to_test.clear() try: + global testing_graph + # for generate fake program input tensor global testing testing = True - global testing_graph if check_graph: testing_graph = True res = f(test_case) @@ -625,6 +706,7 @@ def new_f(test_case): and id(x.pytorch) not in call_tensor_id ): vis_tensor.append(x.pytorch) + # check eager for x in dual_objects_to_test: if check_allclose: @@ -635,26 +717,35 @@ def new_f(test_case): for output in func_outputs: flow_tensor = output.oneflow if isinstance(flow_tensor, flow.Tensor): - if ( - id(flow_tensor) in flow_res_id_eager2_graph - and check_allclose - ): - test_case.assertTrue( - np.allclose( + if flow_tensor in eager_tensor_2_graph_tensor: + if check_allclose: + equality_res = np.allclose( flow_tensor.numpy(), - flow_res_id_eager2_graph[id(flow_tensor)].numpy(), + eager_tensor_2_graph_tensor[flow_tensor].numpy(), rtol=rtol, atol=atol, equal_nan=True, ) - ) + if equality_res == False: + print_note_fake_program() + print("---------Tensor Shape--------") + print(flow_tensor.shape) + print( + eager_tensor_2_graph_tensor[flow_tensor].shape + ) + test_case.assertTrue( + equality_res, + f"Check graph failed: graph result {eager_tensor_2_graph_tensor[flow_tensor].numpy()} not equals to eager result {flow_tensor.numpy()}.", + ) + if verbose: print(f"{f.__name__} test graph passed.") else: - if check_graph and check_allclose: + if check_graph: + 
print_note_fake_program() test_case.assertTrue( False, - f"{f.__name__} cannot find module to check graph.", + f"{f.__name__} cannot find module/function/method to check graph.", ) else: warnings.warn( diff --git a/python/oneflow/utils/data/_utils/__init__.py b/python/oneflow/utils/data/_utils/__init__.py index 3b35c7d9f4e..4fb17c6902a 100644 --- a/python/oneflow/utils/data/_utils/__init__.py +++ b/python/oneflow/utils/data/_utils/__init__.py @@ -27,7 +27,8 @@ IS_WINDOWS = sys.platform == "win32" -MP_STATUS_CHECK_INTERVAL = 60.0 +# pytorch's check interval is 5.0 seconds +MP_STATUS_CHECK_INTERVAL = 10.0 r"""Interval (in seconds) to check status of processes to avoid hanging in multiprocessing data loading. This is mainly used in getting data from another process, in which case we need to periodically check whether the diff --git a/python/oneflow/utils/data/_utils/worker.py b/python/oneflow/utils/data/_utils/worker.py index 702f4f5f1d2..4ecd33b10df 100644 --- a/python/oneflow/utils/data/_utils/worker.py +++ b/python/oneflow/utils/data/_utils/worker.py @@ -275,6 +275,7 @@ def _worker_loop( auto_collation, collate_fn, drop_last, + generator, base_seed, init_fn, worker_id, @@ -294,7 +295,7 @@ def _worker_loop( # TODO:flow.set_num_threads(1) seed = base_seed + worker_id random.seed(seed) - flow.manual_seed(seed) + generator.manual_seed(seed) if HAS_NUMPY: np_seed = _generate_state(base_seed, worker_id) import numpy as np diff --git a/python/oneflow/utils/data/dataloader.py b/python/oneflow/utils/data/dataloader.py index 58eacd9990a..84d67c20041 100644 --- a/python/oneflow/utils/data/dataloader.py +++ b/python/oneflow/utils/data/dataloader.py @@ -211,7 +211,7 @@ def __init__( timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None, multiprocessing_context=None, - generator=None, + generator=flow.Generator("cpu"), *, prefetch_factor: int = 2, persistent_workers: bool = False @@ -525,8 +525,8 @@ def __init__(self, loader: DataLoader) -> None: self._timeout = loader.timeout self._collate_fn = loader.collate_fn self._sampler_iter = iter(self._index_sampler) + self._generator = loader.generator self._base_seed = flow.tensor([0], dtype=flow.int64).uniform_().numpy().item() - # TODO: flow.empty() # self._base_seed = flow.empty((), dtype=flow.int64).random_(generator=loader.generator).item() self._persistent_workers = loader.persistent_workers self._num_yielded = 0 @@ -944,6 +944,7 @@ def __init__(self, loader): self._auto_collation, self._collate_fn, self._drop_last, + self._generator, self._base_seed, self._worker_init_fn, i, diff --git a/python/oneflow/utils/data/distributed.py b/python/oneflow/utils/data/distributed.py index 88d9365a36e..67d82dbb67a 100644 --- a/python/oneflow/utils/data/distributed.py +++ b/python/oneflow/utils/data/distributed.py @@ -124,7 +124,7 @@ def __init__( def __iter__(self) -> Iterator[T_co]: if self.shuffle: # deterministically shuffle based on epoch and seed - g = flow.Generator() + g = flow.Generator("cpu") g.manual_seed(self.seed + self.epoch) indices = flow._C.randperm(len(self.dataset), generator=g).tolist() else: diff --git a/python/oneflow/utils/data/sampler.py b/python/oneflow/utils/data/sampler.py index 9755d8ef2ee..4eedce6f7cf 100644 --- a/python/oneflow/utils/data/sampler.py +++ b/python/oneflow/utils/data/sampler.py @@ -140,7 +140,7 @@ def num_samples(self) -> int: def __iter__(self): n = len(self.data_source) if self.generator is None: - generator = flow.Generator() + generator = flow.Generator("cpu") generator.manual_seed(np.random.randint(0, 
np.iinfo(np.int64).max)) # TODO: use Tensor.random_ # generator.manual_seed( diff --git a/tools/functional/generate_dispatch_stateful_ops.py b/tools/functional/generate_dispatch_stateful_ops.py new file mode 100644 index 00000000000..a91c36b1fb3 --- /dev/null +++ b/tools/functional/generate_dispatch_stateful_ops.py @@ -0,0 +1,185 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import re +import argparse +import yaml + +from generator import Generator + +parser = argparse.ArgumentParser() +parser.add_argument( + "--project_source_dir", type=str, help="The project source code directory.", +) +args = parser.parse_args() + +license = """/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the \"License\"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an \"AS IS\" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Generated from oneflow/api/python/functional/dispatch_stateful_ops.yaml. 
DO NOT EDIT!""" + +header_fmt = ( + license + + """ + +#ifndef ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_DISPATCH_OP_API_H_ +#define ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_DISPATCH_OP_API_H_ + +#include + +#include "oneflow/core/common/optional.h" +#include "oneflow/core/common/scalar.h" +#include "oneflow/core/framework/dtype.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/framework/random_generator.h" +#include "oneflow/core/functional/tensor_index.h" + +namespace oneflow {{ +namespace one {{ +namespace functional {{ +{0} +}} // namespace functional +}} // namespace one +}} // namespace oneflow + +#endif // ONEFLOW_API_PYTHON_FUNCTIONAL_GENERATED_DISPATCH_OP_API_H_""" +) + +source_fmt = ( + license + + """ + +#include "oneflow/api/python/functional/dispatch_stateful_ops.yaml.h" +#include "oneflow/core/functional/function_library.h" + +namespace oneflow {{ +namespace one {{ +namespace functional {{ +{0} +}} // namespace functional +}} // namespace one +}} // namespace oneflow +""" +) + +pybind_header_fmt = ( + license + + """ + +#include + +namespace py = pybind11; + +namespace oneflow {{ +namespace one {{ +namespace functional {{ +{0} +}} // namespace functional +}} // namespace one +}} // namespace oneflow +""" +) + +pybind_source_fmt = ( + license + + """ + +#include +#include + +#include "oneflow/api/python/of_api_registry.h" +#include "oneflow/api/python/functional/function_def.h" +#include "oneflow/api/python/functional/py_function.h" +#include "oneflow/api/python/functional/dispatch_stateful_ops.yaml.h" +#include "oneflow/api/python/functional/dispatch_stateful_ops.yaml.pybind.h" +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/common/optional.h" + +namespace py = pybind11; + +namespace oneflow {{ +namespace one {{ +namespace functional {{ +{0} +}} // namespace functional +}} // namespace one + +namespace functional = one::functional; + +ONEFLOW_API_PYBIND11_MODULE("_C", m) {{ + py::options options; + options.disable_function_signatures(); + +{1} + options.enable_function_signatures(); +}} + +}} // namespace oneflow +""" +) + +yaml_file_path = os.path.join( + args.project_source_dir, "oneflow/api/python/functional/dispatch_stateful_ops.yaml" +) +generated_api_dir = "oneflow/api/python/functional" +generated_pybind_dir = "oneflow/api/python/functional" + +if __name__ == "__main__": + assert os.path.isfile(yaml_file_path), ( + "It is not a regular file for the yaml file which is " + yaml_file_path + ) + g = Generator(yaml_file_path) + + assert os.path.isdir(generated_api_dir), ( + "Could not locate the api generate directory which is " + generated_api_dir + ) + target_header_file = os.path.join(generated_api_dir, "dispatch_stateful_ops.yaml.h") + g.generate_cpp_header_file(header_fmt, target_header_file) + target_source_file = os.path.join( + generated_api_dir, "dispatch_stateful_ops.yaml.cpp" + ) + g.generate_cpp_source_file(source_fmt, target_source_file) + + assert os.path.isdir(generated_pybind_dir), ( + "Could not locate the pybind generate directory which is " + + generated_pybind_dir + ) + target_pybind_header_file = os.path.join( + generated_pybind_dir, "dispatch_stateful_ops.yaml.pybind.h" + ) + target_pybind_source_file = os.path.join( + generated_pybind_dir, "dispatch_stateful_ops.yaml.pybind.cpp" + ) + g.generate_pybind_for_python( + pybind_header_fmt, + pybind_source_fmt, + target_pybind_header_file, + target_pybind_source_file, + ) diff --git a/tools/functional/generator.py 
b/tools/functional/generator.py index add3dc07cdf..e32bc1788ca 100644 --- a/tools/functional/generator.py +++ b/tools/functional/generator.py @@ -46,7 +46,10 @@ "Placement", "Sbp", "SbpList", + "OpExpr", "PyObject*", + "ShapeList", + "DataTypeList", } mangled_name = { @@ -77,7 +80,10 @@ "Placement": "P", "Sbp": "Sbp", "SbpList": "Sbpl", + "OpExpr": "Op", "PyObject*": "Pyo", + "ShapeList": "Shl", + "DataTypeList": "Dtl", } generic_type_aliases = { @@ -110,7 +116,10 @@ "Placement": "const Symbol&", "Sbp": "const Symbol&", "SbpList": "const std::vector>&", + "OpExpr": "const std::shared_ptr&", "PyObject*": "PyObject*", + "ShapeList": "const std::vector&", + "DataTypeList": "const std::vector>&", **generic_type_aliases, } @@ -135,7 +144,10 @@ "Placement": "const Optional>&", "Sbp": "const Optional>&", "SbpList": "const Optional>>&", + "OpExpr": "const Optional&", "PyObject*": "const Optional&", + "ShapeList": "const Optional>&", + "DataTypeList": "const Optional>>&", **{k: "const Optional<{0}>&".format(v) for k, v in generic_type_aliases.items()}, } @@ -151,6 +163,8 @@ "True": "true", "False": "false", "kInt": "DType::Int32()", + "kInt8": "DType::Int8()", + "kUInt8": "DType::UInt8()", "kInt32": "DType::Int32()", "kInt64": "DType::Int64()", "kFloat": "DType::Float()", @@ -409,12 +423,12 @@ def generate_cpp_source_file(self, source_fmt, target_source_file): fmt += "\n" fmt += signature.to_string(to_cpp=True) fmt += " {\n" - fmt += ' static thread_local const auto& op = CHECK_JUST(FunctionLibrary::Global()->find<{0}, {1}>("{2}"));\n'.format( + fmt += ' static thread_local const auto& __op = CHECK_JUST(FunctionLibrary::Global()->find<{0}, {1}>("{2}"));\n'.format( signature._ret._cpp_type, ", ".join([arg._cpp_type for arg in signature._args]), signature._name, ) - fmt += " return op->call({0});\n".format( + fmt += " return __op->call({0});\n".format( ", ".join([arg._name for arg in signature._args]), ) fmt += "}\n" diff --git a/tools/oneflow-tblgen/CMakeLists.txt b/tools/oneflow-tblgen/CMakeLists.txt new file mode 100644 index 00000000000..8ce68429843 --- /dev/null +++ b/tools/oneflow-tblgen/CMakeLists.txt @@ -0,0 +1,46 @@ +set(LLVM_LINK_COMPONENTS + Support +) +include(FetchContent) + +set(JSON_Install ON CACHE STRING "" FORCE) +FetchContent_Declare( + json + URL ${JSON_URL} + URL_HASH MD5=${JSON_URL_HASH} +) + +set(INJA_USE_EMBEDDED_JSON OFF CACHE STRING "" FORCE) +set(INJA_BUILD_TESTS OFF CACHE STRING "" FORCE) +set(BUILD_BENCHMARK OFF CACHE STRING "" FORCE) +FetchContent_Declare( + inja + URL ${INJA_URL} + URL_HASH MD5=${INJA_URL_HASH} +) + +FetchContent_MakeAvailable(json inja) + +add_tablegen(oneflow_tblgen llvm + tablegen.cpp + op_schema_emitter.cpp +) + +if(LLVM_ENABLE_OBJLIB) + set(OF_TBLGEN_TARGET obj.oneflow_tblgen) +else() + set(OF_TBLGEN_TARGET oneflow_tblgen) +endif() + +target_link_libraries(${OF_TBLGEN_TARGET} PRIVATE + nlohmann_json::nlohmann_json + pantor::inja +) + +install(TARGETS oneflow_tblgen LLVMTableGen LLVMDemangle LLVMSupport COMPONENT OneFlowTableGen) +add_custom_target(install-oneflow-tblgen + DEPENDS oneflow_tblgen + COMMAND + "${CMAKE_COMMAND}" -DCMAKE_INSTALL_COMPONENT=OneFlowTableGen + -P "${CMAKE_BINARY_DIR}/cmake_install.cmake" +) diff --git a/tools/oneflow-tblgen/backends.h b/tools/oneflow-tblgen/backends.h new file mode 100644 index 00000000000..d1da04ac2ea --- /dev/null +++ b/tools/oneflow-tblgen/backends.h @@ -0,0 +1,39 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#ifndef ONEFLOW_TBLGEN_BACKENDS_H +#define ONEFLOW_TBLGEN_BACKENDS_H + +namespace llvm { +class raw_ostream; +class RecordKeeper; +} // namespace llvm + +namespace oneflow { + +namespace tblgen { + +using llvm::raw_ostream; +using llvm::RecordKeeper; + +void EmitOpSchemaHeader(RecordKeeper& RK, raw_ostream& OS); +void EmitOpSchemaSource(RecordKeeper& RK, raw_ostream& OS); + +} // namespace tblgen + +} // namespace oneflow + +#endif // ONEFLOW_TBLGEN_BACKENDS_H diff --git a/tools/oneflow-tblgen/example/constant.td b/tools/oneflow-tblgen/example/constant.td new file mode 100644 index 00000000000..561d2999bfb --- /dev/null +++ b/tools/oneflow-tblgen/example/constant.td @@ -0,0 +1,17 @@ +include "mlir/Interfaces/SideEffectInterfaces.td" +include "OneFlowEnums.td" +include "OneFlowBase.td" + +def OneFlow_ConstantOp : OneFlow_BaseOp<"constant", [NoSideEffect, DeclareOpInterfaceMethods]> { + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$floating_value, + DefaultValuedAttr:$integer_value, + DefaultValuedAttr:$is_floating_value, + StrAttr:$dtype, + AnyI64ElementsAttr:$shape, + StrArrayAttr:$nd_sbp + ); +} diff --git a/tools/oneflow-tblgen/op_schema_emitter.cpp b/tools/oneflow-tblgen/op_schema_emitter.cpp new file mode 100644 index 00000000000..e06d5d10ba1 --- /dev/null +++ b/tools/oneflow-tblgen/op_schema_emitter.cpp @@ -0,0 +1,242 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" +#include "inja/inja.hpp" + +#include +#include + +using namespace llvm; +using inja::json; + +namespace oneflow { +namespace tblgen { + +cl::OptionCategory opSchemaCat("Options for -gen-op-schema"); + +cl::opt sourceIncludeFilename{ + "op-include", cl::desc("header filename to include in source file"), + cl::value_desc("include filename"), cl::init(""), cl::cat(opSchemaCat)}; + +cl::opt dumpJson{"op-dump-json", + cl::desc("dump tablegen code to json in provided file"), + cl::value_desc("filename"), cl::init(""), cl::cat(opSchemaCat)}; + +enum class FileTarget { + kHeader = 1, + kSource, +}; + +template +class OpSchemaEmitter { + public: + explicit OpSchemaEmitter(RecordKeeper& RK); + + void run(raw_ostream& os); + + void emitInputAndOutput(const Record* def, json* op) const; + + void emitAttrs(const Record* def, json* op) const; + + void emitInt(const Record* def, StringRef fieldname, json* op) const; + void emitBit(const Record* def, StringRef fieldname, json* op) const; + void emitTrait(const Record* def, StringRef fieldname, StringRef traitname, json* op) const; + + private: + static std::string emitType(const std::string& ods_type) { +#define OP_SCHEMA(ods, cpp) \ + if (ods_type == #ods) return #cpp; +#include "op_schema_types.inc" +#undef OP_SCHEMA + PrintFatalError("undefined attribute type: " + ods_type); + } + + private: + RecordKeeper& records; + + StringRef op_type_name; + StringRef op_name; + + inja::Environment env; + inja::Template temp; + static const std::string code; +}; + +template +OpSchemaEmitter::OpSchemaEmitter(RecordKeeper& RK) : records(RK) { + env.add_callback("quoted", 1, [](inja::Arguments& args) { + auto str = args.at(0)->get(); + std::ostringstream os; + os << std::quoted(str); + return os.str(); + }); + env.add_callback("to_header", 1, [](inja::Arguments& args) { + auto str = args.at(0)->get(); + auto dot_pos = str.find_last_of('.'); + if (dot_pos != std::string::npos) { str.replace(dot_pos, str.size() - dot_pos, ".h"); } + + // assume that the source and header file is in the same directory + auto slash_pos = str.find_last_of('/'); + if (slash_pos != std::string::npos) { str.replace(0, slash_pos + 1, ""); } + return str; + }); + temp = env.parse(code); +} + +template +void OpSchemaEmitter::run(raw_ostream& os) { + emitSourceFileHeader("oneflow op schema", os); + json ops = json::object(); + + for (const auto& def : records.getAllDerivedDefinitions("OneFlow_BaseOp")) { + op_type_name = def->getValueAsString("opName"); + if (op_type_name.empty()) { + PrintFatalError(def, "`opName` of op definitions cannot be omitted"); + } + op_name = def->getName(); + if (!op_name.consume_front("OneFlow_")) { + PrintFatalError(def, "op name is not start with `OneFlow_`: " + op_name.str()); + } + json op{{"name", op_type_name}, + {"input", json::array()}, + {"output", json::array()}, + {"attrs", json::array()}}; + + emitInputAndOutput(def, &op); + emitAttrs(def, &op); + emitInt(def, "same_output_regst_num", &op); + emitTrait(def, "no_grad", "NoGrad", &op); + emitTrait(def, "cpu_only", "CpuOnly", &op); + emitBit(def, "has_nd_sbp_infer_fn", &op); + emitBit(def, "has_get_sbp_fn", 
&op); + emitBit(def, "has_logical_tensor_desc_infer_fn", &op); + emitBit(def, "has_physical_tensor_desc_infer_fn", &op); + emitBit(def, "has_data_type_infer_fn", &op); + emitBit(def, "has_device_infer_fn", &op); + emitBit(def, "has_input_arg_modify_fn", &op); + emitBit(def, "has_output_arg_modify_fn", &op); + emitBit(def, "has_output_blob_time_shape_infer_fn", &op); + emitBit(def, "has_sbp_signature_infer_fn", &op); + emitBit(def, "has_check_fn", &op); + ops[op_name.str()] = op; + } + + auto* option = static_cast*>(cl::getRegisteredOptions().lookup("o")); + auto filename = option->getValue(); + filename = filename != "-" ? filename : ""; + json data{{"filename", filename}, {"ops", ops}}; + + if (Target == FileTarget::kSource) { data["include"] = sourceIncludeFilename; } + if (!dumpJson.empty()) { + std::ofstream file(dumpJson); + file << data.dump(); + } + os << env.render(temp, data); +} + +template +void OpSchemaEmitter::emitInputAndOutput(const Record* def, json* op) const { + const auto* input = def->getValueAsDag("input"); + for (size_t i = 0; i < input->getNumArgs(); ++i) { + const auto* A = dyn_cast(input->getArg(i))->getDef(); + bool is_optional = A->isSubClassOf("Optional"); + auto NS = input->getArgName(i)->getAsUnquotedString(); + (*op)["input"].push_back({{"name", NS}, {"is_optional", is_optional}, {"size", 1}}); + } + const auto* output = def->getValueAsDag("output"); + for (size_t i = 0; i < output->getNumArgs(); ++i) { + const auto* A = dyn_cast(output->getArg(i))->getDef(); + bool is_optional = A->isSubClassOf("Optional"); + auto NS = output->getArgName(i)->getAsUnquotedString(); + (*op)["output"].push_back({{"name", NS}, {"is_optional", is_optional}, {"size", 1}}); + } +} + +template +void OpSchemaEmitter::emitAttrs(const Record* def, json* op) const { + const auto* attrs = def->getValueAsDag("attrs"); + for (size_t i = 0; i < attrs->getNumArgs(); ++i) { + const auto* A = dyn_cast(attrs->getArg(i))->getDef(); + std::string AS; + if (!A->isAnonymous()) { + AS = A->getNameInitAsString(); + } else { + AS = A->getValueAsDef("baseAttr")->getNameInitAsString(); + } + auto NS = attrs->getArgName(i)->getAsUnquotedString(); + json attr{{"name", NS}, {"type", emitType(AS)}}; + + if (auto DV = A->getValueAsOptionalString("defaultValue")) { attr["default"] = DV.getValue(); } + + (*op)["attrs"].push_back(attr); + } +} + +template +void OpSchemaEmitter::emitBit(const Record* def, StringRef fieldname, json* op) const { + (*op)[fieldname.str()] = def->getValueAsBit(fieldname); +} + +template +void OpSchemaEmitter::emitTrait(const Record* def, StringRef fieldname, StringRef traitname, + json* op) const { + bool hasTrait = false; + + for (auto elem : *def->getValueAsListInit("traits")) { + if (elem->getAsString() == traitname) { + hasTrait = true; + break; + } + } + + (*op)[fieldname.str()] = hasTrait; +} + +template +void OpSchemaEmitter::emitInt(const Record* def, StringRef fieldname, json* op) const { + (*op)[fieldname.str()] = def->getValueAsInt(fieldname); +} + +template<> +const std::string OpSchemaEmitter::code{ +#include "op_schema_header.inc" +}; + +template<> +const std::string OpSchemaEmitter::code{ +#include "op_schema_source.inc" +}; + +void EmitOpSchemaHeader(RecordKeeper& RK, raw_ostream& os) { + OpSchemaEmitter(RK).run(os); +} + +void EmitOpSchemaSource(RecordKeeper& RK, raw_ostream& os) { + OpSchemaEmitter(RK).run(os); +} + +} // namespace tblgen +} // namespace oneflow diff --git a/tools/oneflow-tblgen/op_schema_header.inc b/tools/oneflow-tblgen/op_schema_header.inc new 
file mode 100644 index 00000000000..ae167d08921 --- /dev/null +++ b/tools/oneflow-tblgen/op_schema_header.inc @@ -0,0 +1,100 @@ +R"OP_SCHEMA_INC( +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/common/shape.h" +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/framework/op_base.h" + +#include +#include +#include + +namespace oneflow { + +class Device; +class InputBlobModifier; +class OutputBlobModifier; + +namespace user_op { +class UserOpDefWrapper; +class UserOpConfWrapper; +class InferContext; +class SbpContext; +class InferSbpSignatureFnContext; +class InferOutputBlobTimeShapeFnContext; +class InferNdSbpFnContext; +class DeviceInferContext; +} // namespace user_op + +using GetInputArgModifier = + std::function; +using GetOutputArgModifier = + std::function; + +{% for opname, op in ops %} +class {{opname}} : public OpBase { + public: + virtual ~{{opname}}() = default; + {% if op.has_nd_sbp_infer_fn -%} + static Maybe InferNdSbp(user_op::InferNdSbpFnContext* ctx); + {% endif -%} + {% if op.has_get_sbp_fn -%} + static Maybe GetSbp(user_op::SbpContext* ctx); + {% endif -%} + {% if op.has_logical_tensor_desc_infer_fn -%} + static Maybe InferLogicalTensorDesc(user_op::InferContext* ctx); + {% endif -%} + {% if op.has_physical_tensor_desc_infer_fn -%} + static Maybe InferPhysicalTensorDesc(user_op::InferContext* ctx); + {% endif -%} + {% if op.has_data_type_infer_fn -%} + static Maybe InferDataType(user_op::InferContext* ctx); + {% endif -%} + {% if op.has_device_infer_fn -%} + static Maybe> InferDevice(user_op::DeviceInferContext* ctx); + {% endif -%} + {% if op.has_sbp_signature_infer_fn -%} + static Maybe InferSbpSignature(user_op::InferSbpSignatureFnContext* ctx); + {% endif -%} + {% if op.has_input_arg_modify_fn -%} + static Maybe ModifyInputArg(const GetInputArgModifier&, const user_op::UserOpConfWrapper&); + {% endif -%} + {% if op.has_output_arg_modify_fn -%} + static Maybe ModifyOutputArg(const GetOutputArgModifier&, const user_op::UserOpConfWrapper&); + {% endif -%} + {% if op.has_output_blob_time_shape_infer_fn -%} + static Maybe InferOutputBlobTimeShape(user_op::InferOutputBlobTimeShapeFnContext* ctx); + {% endif -%} + {% if op.has_check_fn -%} + static Maybe CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper&); + {% endif -%} + + {% for attr in op.attrs -%} + virtual const {{attr.type}}& {{attr.name}}() const = 0; + virtual {{attr.type}}* mutable_{{attr.name}}() = 0; + virtual void set_{{attr.name}}(const {{attr.type}}& {{attr.name}}) = 0; + + {% endfor -%} + const HashSet& AttrNames() const; +}; + +namespace schema { +class {{opname}} : public oneflow::{{opname}} { + public: + {% for attr in op.attrs -%} + const {{attr.type}}& {{attr.name}}() const override { return {{attr.name}}_; } + {{attr.type}}* mutable_{{attr.name}}() override { return &{{attr.name}}_; } + void set_{{attr.name}}(const {{attr.type}}& {{attr.name}}) override { {{attr.name}}_ = {{attr.name}}; } + + {% endfor -%} + + Maybe GetAttr(const std::string& attr_name) const override; + + private: + {% for attr in op.attrs -%} + {{attr.type}} {{attr.name}}_{% if existsIn(attr, "default") %} = {{attr.default}}{% endif %}; + {% endfor %} +}; +} // namespace schema +{% endfor %} +} // namespace oneflow +)OP_SCHEMA_INC" diff --git a/tools/oneflow-tblgen/op_schema_source.inc b/tools/oneflow-tblgen/op_schema_source.inc new file mode 100644 index 00000000000..ceaa4b3d1b7 --- /dev/null +++ b/tools/oneflow-tblgen/op_schema_source.inc @@ -0,0 +1,106 @@ +R"OP_SCHEMA_INC( 
+{% if include != "" %}#include "{{ include }}" +{% else if filename != "" %}#include "{{ to_header(filename) }}" +{% endif %} +#include "oneflow/core/common/auto_registration_factory.h" +#include "oneflow/core/framework/attr_value.h" +#include "oneflow/core/framework/nd_sbp.h" +#include "oneflow/core/framework/infer_nd_sbp_fn_context.h" +#include "oneflow/core/framework/user_op_registry_manager.h" + +namespace oneflow { + +#define REGISTER_OP_SCHEMA(op_type, schema) \ + REGISTER_CLASS_CREATOR(std::string, op_type, OpBase, ([]() { return new schema; })) + +{% for opname, op in ops %} +const HashSet& {{opname}}::AttrNames() const { + static const HashSet attr_names = { {%- for attr in op.attrs -%}"{{attr.name}}", {%- endfor -%} }; + return attr_names; +} + +namespace schema { +Maybe {{opname}}::GetAttr(const std::string& attr_name) const { + {% for attr in op.attrs %}if(attr_name == "{{attr.name}}") { + return CastAttrValue(&{{attr.name}}_); + } + {% endfor -%} + return Error::RuntimeError() << "{{op.name}} op has no attribute named " << attr_name; +} +} // namespace schema + +REGISTER_OP_SCHEMA("user.{{op.name}}", schema::{{opname}}); + +REGISTER_USER_OP("{{op.name}}") +{%- if op.input -%} +{%- for input in op.input -%} +{%- if input.is_optional -%} + .OptionalInput("{{input.name}}") +{%- else -%} + .Input("{{input.name}}") +{%- endif -%} +{%- endfor -%} +{%- endif -%} +{%- if op.output -%} +{%- for output in op.output -%} +{%- if output.is_optional -%} + .OptionalOutput("{{output.name}}") +{%- else -%} + .Output("{{output.name}}") +{%- endif -%} +{%- endfor -%} +{%- endif -%} + +{%- for attr in op.attrs -%} +{%- if existsIn(attr, "default") -%} + .Attr<{{attr.type}}>("{{attr.name}}", {{attr.default}}) +{%- else -%} + .Attr<{{attr.type}}>("{{attr.name}}") +{%- endif -%} +{%- endfor -%} +{%- if op.cpu_only -%} + .SupportCpuOnly() +{%- endif -%} +{%- if op.no_grad -%} + .NoGrad() +{%- endif -%} +{%- if op.same_output_regst_num != -1 -%} + .SetOutputBufferNum({{op.same_output_regst_num}}) +{%- endif -%} +{%- if op.has_nd_sbp_infer_fn -%} + .SetNdSbpInferFn(&{{opname}}::InferNdSbp) +{%- endif -%} +{%- if op.has_get_sbp_fn -%} + .SetGetSbpFn(&{{opname}}::GetSbp) +{%- endif -%} +{%- if op.has_logical_tensor_desc_infer_fn -%} + .SetLogicalTensorDescInferFn(&{{opname}}::InferLogicalTensorDesc) +{%- endif -%} +{%- if op.has_physical_tensor_desc_infer_fn -%} + .SetPhysicalTensorDescInferFn(&{{opname}}::InferPhysicalTensorDesc) +{%- endif -%} +{%- if op.has_data_type_infer_fn -%} + .SetDataTypeInferFn(&{{opname}}::InferDataType) +{%- endif -%} +{%- if op.has_device_infer_fn -%} + .SetDeviceInferFn(&{{opname}}::InferDevice) +{%- endif -%} +{%- if op.has_sbp_signature_infer_fn -%} + .SetSbpSignatureInferFn(&{{opname}}::InferSbpSignature) +{% endif -%} +{%- if op.has_input_arg_modify_fn -%} + .SetInputArgModifyFn(&{{opname}}::ModifyInputArg) +{%- endif -%} +{%- if op.has_output_arg_modify_fn -%} + .SetOutputArgModifyFn(&{{opname}}::ModifyOutputArg) +{%- endif -%} +{%- if op.has_output_blob_time_shape_infer_fn -%} + .SetOutputBlobTimeShapeInferFn(&{{opname}}::InferOutputBlobTimeShape) +{%- endif -%} +{%- if op.has_check_fn -%} + .SetCheckAttrFn(&{{opname}}::CheckAttr) +{%- endif -%} +; +{%- endfor %} +} // namespace oneflow +)OP_SCHEMA_INC" diff --git a/tools/oneflow-tblgen/op_schema_types.inc b/tools/oneflow-tblgen/op_schema_types.inc new file mode 100644 index 00000000000..62656abca62 --- /dev/null +++ b/tools/oneflow-tblgen/op_schema_types.inc @@ -0,0 +1,14 @@ +OP_SCHEMA(SI32Attr, int32_t) 
+OP_SCHEMA(SI64Attr, int64_t) +OP_SCHEMA(BoolAttr, bool) +OP_SCHEMA(F32Attr, float) +OP_SCHEMA(F64Attr, double) +OP_SCHEMA(StrAttr, std::string) +OP_SCHEMA(ShapeAttr, Shape) +OP_SCHEMA(OneFlow_DataType, DataType) +OP_SCHEMA(SI32ArrayAttr, std::vector) +OP_SCHEMA(SI64ArrayAttr, std::vector) +OP_SCHEMA(F32ArrayAttr, std::vector) +OP_SCHEMA(DTArrayAttr, std::vector) +OP_SCHEMA(ShapeArrayAttr, std::vector) +OP_SCHEMA(StrArrayAttr, std::vector) diff --git a/tools/oneflow-tblgen/tablegen.cpp b/tools/oneflow-tblgen/tablegen.cpp new file mode 100644 index 00000000000..a2e3f4c038f --- /dev/null +++ b/tools/oneflow-tblgen/tablegen.cpp @@ -0,0 +1,104 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/TableGen/Main.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/SetTheory.h" + +#include "backends.h" + +using namespace llvm; +using namespace oneflow::tblgen; + +enum ActionType { + PrintRecords, + PrintDetailedRecords, + NullBackend, + DumpJSON, + PrintEnums, + PrintSets, + GenOpSchemaHeader, + GenOpSchemaSource, +}; + +namespace llvm { +cl::opt EmitLongStrLiterals( + "long-string-literals", + cl::desc("when emitting large string tables, prefer string literals over " + "comma-separated char literals. 
This can be a readability and " + "compile-time performance win, but upsets some compilers"), + cl::Hidden, cl::init(true)); +} // end namespace llvm + +namespace { +cl::opt Action( + cl::desc("Action to perform:"), + cl::values(clEnumValN(PrintRecords, "print-records", "Print all records to stdout (default)"), + clEnumValN(PrintDetailedRecords, "print-detailed-records", + "Print full details of all records to stdout"), + clEnumValN(NullBackend, "null-backend", + "Do nothing after parsing (useful for timing)"), + clEnumValN(DumpJSON, "dump-json", "Dump all records as machine-readable JSON"), + clEnumValN(PrintEnums, "print-enums", "Print enum values for a class"), + clEnumValN(PrintSets, "print-sets", "Print expanded sets for testing DAG exprs"), + clEnumValN(GenOpSchemaHeader, "gen-op-schema-h", + "Generate oneflow op schema header code (.h)"), + clEnumValN(GenOpSchemaSource, "gen-op-schema-cpp", + "Generate oneflow op schema source code (.cpp)"))); + +cl::OptionCategory PrintEnumsCat("Options for -print-enums"); +cl::opt Class("class", cl::desc("Print Enum list for this class"), + cl::value_desc("class name"), cl::cat(PrintEnumsCat)); + +bool LLVMTableGenMain(raw_ostream& OS, RecordKeeper& Records) { + switch (Action) { + case PrintRecords: OS << Records; break; + case PrintDetailedRecords: EmitDetailedRecords(Records, OS); break; + case NullBackend: break; + case DumpJSON: EmitJSON(Records, OS); break; + case PrintEnums: { + for (Record* Rec : Records.getAllDerivedDefinitions(Class)) OS << Rec->getName() << ", "; + OS << "\n"; + break; + } + case PrintSets: { + SetTheory Sets; + Sets.addFieldExpander("Set", "Elements"); + for (Record* Rec : Records.getAllDerivedDefinitions("Set")) { + OS << Rec->getName() << " = ["; + const std::vector* Elts = Sets.expand(Rec); + assert(Elts && "Couldn't expand Set instance"); + for (Record* Elt : *Elts) OS << ' ' << Elt->getName(); + OS << " ]\n"; + } + break; + } + case GenOpSchemaHeader: EmitOpSchemaHeader(Records, OS); break; + case GenOpSchemaSource: EmitOpSchemaSource(Records, OS); break; + } + + return false; +} +} // namespace + +int main(int argc, char** argv) { + InitLLVM X(argc, argv); + cl::ParseCommandLineOptions(argc, argv); + + return TableGenMain(argv[0], &LLVMTableGenMain); +} diff --git a/tools/package_mirror.py b/tools/package_mirror.py index 59cd2901890..db439688319 100644 --- a/tools/package_mirror.py +++ b/tools/package_mirror.py @@ -70,6 +70,7 @@ def should_be_mirrored(url: str): and not "mirror.tensorflow.org" in url and not "mirror.bazel.build" in url and not "aliyuncs.com" in url + and not "file:" in url )